{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeApplications #-}
module Hyperion.WorkerCpuPool where
import           Control.Concurrent.STM      (atomically, check, throwSTM)
import           Control.Concurrent.STM.TVar (TVar, modifyTVar, modifyTVar',
                                              newTVarIO, readTVar, readTVarIO)
import           Control.Exception           (ErrorCall (..), Exception)
import           Control.Monad               (when)
import           Control.Monad.Catch         (MonadMask, bracket, try)
import           Control.Monad.IO.Class      (MonadIO, liftIO)
import           Data.List.Extra             (maximumOn)
import           Data.Map.Strict             (Map)
import qualified Data.Map.Strict             as Map
import           Data.Maybe                  (fromMaybe)
import qualified Hyperion.Log                as Log
import qualified Hyperion.Slurm              as Slurm
import           Hyperion.Util               (retryRepeated, shellEsc)
import           System.Exit                 (ExitCode (..))
import           System.Process              (proc,
                                              readCreateProcessWithExitCode)
-- | A count of CPUs, represented as a wrapper around 'Int'.
--
-- The 'Eq', 'Ord' and 'Num' instances are reused from 'Int' through newtype
-- deriving, so counts can be compared, added and subtracted directly, and
-- integer literals can be used where a 'NumCPUs' is expected.
newtype NumCPUs = NumCPUs Int
  deriving newtype (Eq, Ord, Num)
-- | A pool of workers together with a per-worker CPU count.
--
-- The counts live in a 'TVar' so that they can be read and updated
-- atomically from several threads (see 'withWorkerAddr', which decrements a
-- worker's count while CPUs are reserved and restores it afterwards).
data WorkerCpuPool = WorkerCpuPool
  { cpuMap :: TVar (Map WorkerAddr NumCPUs)
    -- ^ For each worker address, the CPU count currently recorded for it.
  }
-- | Create a 'WorkerCpuPool' whose 'cpuMap' is a fresh 'TVar' initialised
-- with the given map from worker addresses to CPU counts.
newWorkerCpuPool :: Map WorkerAddr NumCPUs -> IO WorkerCpuPool
newWorkerCpuPool cpus = WorkerCpuPool <$> newTVarIO cpus
-- | List the addresses of all workers currently stored in the pool.
--
-- This takes a snapshot of 'cpuMap' with 'readTVarIO' and returns its keys,
-- in the ascending order produced by 'Map.keys'.
getAddrs :: WorkerCpuPool -> IO [WorkerAddr]
getAddrs pool = Map.keys <$> readTVarIO (cpuMap pool)
-- | The address of a worker.
--
-- 'getSlurmAddrs' produces 'LocalHost' for the node whose name equals the
-- head node reported by Slurm and 'RemoteAddr' for every other node; in both
-- cases the 'String' is the node name.
data WorkerAddr
  = LocalHost String
  | RemoteAddr String
  deriving (Eq, Ord, Show)
-- | Build the list of worker addresses for the current Slurm job.
--
-- The node names come from 'Slurm.getJobNodes'. A node whose name equals the
-- head node returned by 'Slurm.lookupHeadNode' becomes a 'LocalHost'; every
-- other node becomes a 'RemoteAddr'. If no head node is reported
-- ('Nothing'), all nodes are 'RemoteAddr's.
getSlurmAddrs :: IO [WorkerAddr]
getSlurmAddrs = do
  jobNodes  <- Slurm.getJobNodes
  mHeadNode <- Slurm.lookupHeadNode
  return (map (toAddr mHeadNode) jobNodes)
  where
    -- Classify a single node name relative to the (optional) head node.
    toAddr :: Maybe String -> String -> WorkerAddr
    toAddr mh n
      | mh == Just n = LocalHost n
      | otherwise    = RemoteAddr n
-- | Create a 'WorkerCpuPool' for the given nodes, assigning every node the
-- same CPU count, namely the value returned by 'Slurm.getNTasksPerNode'.
--
-- Fails via 'Log.throwError' with the message @"Empty node list"@ when the
-- node list is empty. Duplicate addresses collapse to a single map entry.
newJobPool :: [WorkerAddr] -> IO WorkerCpuPool
newJobPool nodes = do
  when (null nodes) (Log.throwError "Empty node list")
  cpusPerNode <- NumCPUs <$> Slurm.getNTasksPerNode
  newWorkerCpuPool (Map.fromList [ (node, cpusPerNode) | node <- nodes ])
-- | Reserve @cpus@ CPUs on a worker, run the given action with that worker's
-- address, and give the CPUs back afterwards.
--
-- The worker chosen is the one with the largest recorded CPU count. If even
-- that worker has fewer than @cpus@ CPUs, the STM transaction retries, i.e.
-- the call blocks until 'cpuMap' changes so that enough CPUs are available.
-- NOTE(review): nothing in this function raises a worker's count above its
-- initial value, so a request larger than any worker's capacity appears to
-- block indefinitely -- confirm against callers.
--
-- Acquisition and release are paired with 'bracket', so the CPUs are
-- returned to the pool even if the action throws.
--
-- Throws 'ErrorCall' if the pool contains no workers at all.
withWorkerAddr
  :: (MonadIO m, MonadMask m)
  => WorkerCpuPool
  -> NumCPUs
  -> (WorkerAddr -> m a)
  -> m a
withWorkerAddr WorkerCpuPool{..} cpus go =
  bracket (liftIO getWorkerAddr) (liftIO . replaceWorkerAddr) go
  where
    getWorkerAddr :: IO WorkerAddr
    getWorkerAddr = atomically $ do
      workers <- readTVar cpuMap
      case Map.toList workers of
        -- 'maximumOn' is partial; report an empty pool explicitly instead of
        -- failing with an opaque "empty list" error.
        [] -> throwSTM (ErrorCall "withWorkerAddr: worker pool is empty")
        ws -> do
          -- Pick the worker with the most CPUs recorded as available.
          let (addr, availCpus) = maximumOn snd ws
          -- Not enough CPUs yet: retry the transaction when the map changes.
          check (availCpus >= cpus)
          -- Strict update so no chain of thunks builds up inside the TVar.
          modifyTVar' cpuMap (Map.adjust (subtract cpus) addr)
          return addr

    replaceWorkerAddr :: WorkerAddr -> IO ()
    replaceWorkerAddr addr = atomically $
      -- Return the reserved CPUs to the worker's count.
      modifyTVar' cpuMap (Map.adjust (+ cpus) addr)
-- | Exception thrown by 'sshRunCmd' when the @ssh@ process exits with a
-- non-success exit code.
--
-- The 'String' is the address that was being contacted; the triple is the
-- result of 'System.Process.readCreateProcessWithExitCode', i.e. the exit
-- code followed by the captured standard output and standard error.
data SSHError = SSHError String (ExitCode, String, String)
  deriving stock (Show)
  deriving anyclass (Exception)
-- | An optional override for how to invoke ssh: the executable name paired
-- with the options passed before the host address. 'sshRunCmd' substitutes
-- its built-in default when this is 'Nothing'.
type SSHCommand = Maybe (String, [String])
-- | Run a command on the host @addr@ through ssh.
--
-- The remote command line is @sh -c "nohup \<cmd\> \<args\> &"@, built with
-- 'shellEsc' (presumably shell-quoting its arguments -- see "Hyperion.Util"),
-- so the command is started in the background on the remote side.
--
-- The ssh executable and its leading options come from the 'SSHCommand'
-- argument, or from 'defaultCmd' (@ssh -f -o "UserKnownHostsFile /dev/null"@)
-- when it is 'Nothing'.
--
-- A non-success exit code from ssh is raised as an 'SSHError' via
-- 'Log.throw'. The whole attempt is wrapped in @'retryRepeated' 10@, which
-- is handed a 'try' specialised to 'SSHError'; presumably this retries up to
-- 10 times on 'SSHError' -- confirm in "Hyperion.Util".
sshRunCmd :: String -> SSHCommand -> (String, [String]) -> IO ()
sshRunCmd addr sshCmd (cmd, args) =
  retryRepeated 10 (try @IO @SSHError) $ do
    -- Empty stdin; stdout and stderr are captured only for error reporting.
    result@(exit, _, _) <- readCreateProcessWithExitCode (proc ssh sshArgs) ""
    case exit of
      ExitSuccess -> return ()
      _           -> Log.throw (SSHError addr result)
  where
    -- Executable and leading options, falling back to the default.
    (ssh, sshOpts) = fromMaybe defaultCmd sshCmd

    -- Full argument vector: options, then the host, then the remote command.
    sshArgs :: [String]
    sshArgs =
      sshOpts ++
      [ addr
      , shellEsc "sh"
        [ "-c"
        , shellEsc "nohup" (cmd : args) ++ " &"
        ]
      ]

    defaultCmd :: (String, [String])
    defaultCmd = ("ssh", ["-f", "-o", "UserKnownHostsFile /dev/null"])