{-# OPTIONS_GHC -fno-warn-orphans  #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StaticPointers        #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}

module Hyperion.HasWorkers where

import           Control.Distributed.Process (Closure, Process)
import           Control.Monad.Base          (MonadBase (..))
import           Control.Monad.IO.Class      (MonadIO)
import           Control.Monad.Reader        (ReaderT (..), asks, runReaderT)
import           Data.Binary                 (Binary)
import           Data.Constraint             (Dict (..))
import           Data.Typeable               (Typeable)
import           Hyperion.Remote             (RemoteProcessRunner,
                                              WorkerLauncher,
                                              mkSerializableClosureProcess,
                                              withRemoteRunProcess)
import           Hyperion.Slurm              (JobId)
import           Hyperion.Static             (Serializable, Static (..))

-- | Monads that can run things in the 'Process' monad and that have
-- access to a 'WorkerLauncher'.
--
-- An instance of 'HasWorkers' can use 'remoteBind' and 'remoteEval' to
-- run computations in worker processes at remote locations.
class ( MonadBase Process m
      , MonadUnliftProcess m
      , MonadIO m
      ) => HasWorkers m where
  -- | The 'WorkerLauncher' available in this monad. 'withRemoteRun'
  -- hands it to 'withRemoteRunProcess'.
  getWorkerLauncher :: m (WorkerLauncher JobId)

-- | Trivial orphan instance of 'MonadBase' for 'Process': lifting a
-- 'Process' action into 'Process' is the identity.
instance MonadBase Process Process where
  liftBase = id

-- | Monads that can run continuations in the 'Process' monad. The
-- design is modeled after @MonadUnliftIO@
-- (<https://hackage.haskell.org/package/unliftio-core-0.2.0.1/docs/Control-Monad-IO-Unlift.html>).
class MonadUnliftProcess m where
  -- | Run a 'Process' continuation that receives a function for running
  -- @m@ actions inside 'Process', and return its result in @m@.
  withRunInProcess :: ((forall a. m a -> Process a) -> Process b) -> m b

-- | In 'Process' itself nothing needs unlifting: the continuation is
-- simply handed 'id'.
instance MonadUnliftProcess Process where
  withRunInProcess go = go id

-- | Unlifting through 'ReaderT'. The environment @r@ is captured once;
-- each 'ReaderT' action passed to the runner is first run with that
-- environment via 'runReaderT', and then unlifted with the underlying
-- monad's runner.
instance MonadUnliftProcess m => MonadUnliftProcess (ReaderT r m) where
  withRunInProcess inner = ReaderT $ \r ->
    -- Explicit lambdas (rather than '$') at the rank-2 argument positions
    -- keep type checking independent of GHC's special typing rule for '$'.
    withRunInProcess (\run ->
      inner (\action -> run (runReaderT action r)))

-- | Environment types @env@ that contain a 'WorkerLauncher'.
class HasWorkerLauncher env where
  -- | Extract the 'WorkerLauncher' stored in the environment.
  toWorkerLauncher :: env -> WorkerLauncher JobId

-- | This is our main instance for 'HasWorkers'. The 'Cluster' and 'Job'
-- monads are both cases of @ReaderT env Process@ with different @env@s.
-- The 'WorkerLauncher' is read from the environment with
-- 'toWorkerLauncher'.
instance HasWorkerLauncher env => HasWorkers (ReaderT env Process) where
  getWorkerLauncher = asks toWorkerLauncher

-- | Use the 'WorkerLauncher' to get a 'RemoteProcessRunner' and pass it
-- to the given continuation.
--
-- This is essentially 'getWorkerLauncher' composed with
-- 'withRemoteRunProcess', lifted from 'Process' to @m@ using
-- 'MonadUnliftProcess'.
--
-- The unlifting machinery is needed because 'withRemoteRunProcess'
-- expects a continuation in the 'Process' monad, not in @m@. The main
-- use case is @m ~ ReaderT env Process@, where @env@ is an instance of
-- 'HasWorkerLauncher'.
withRemoteRun :: HasWorkers m => (RemoteProcessRunner -> m a) -> m a
withRemoteRun go = do
  workerLauncher <- getWorkerLauncher
  -- Parenthesized lambda at the rank-2 argument of 'withRunInProcess'.
  withRunInProcess (\runInProcess ->
    withRemoteRunProcess workerLauncher $ \remoteRunProcess ->
      runInProcess (go remoteRunProcess))

-- | Compute a closure at a remote location. The user supplies an
-- @m (Closure (Process b))@ action, which is only evaluated when a
-- remote worker becomes available (for example after the worker makes
-- it out of the Slurm queue).
--
-- The @'Closure' ('Dict' ('Serializable' b))@ is passed to
-- 'mkSerializableClosureProcess' along with the unlifted action.
remoteEvalWithDictM
  :: (HasWorkers m, Serializable b)
  => Closure (Dict (Serializable b))
  -> m (Closure (Process b))
  -> m b
remoteEvalWithDictM bDict mb = do
  -- 'runInProcess mb' only builds a 'Process' action here. When it
  -- actually runs is determined by 'mkSerializableClosureProcess' and
  -- the remote runner.
  scp <- withRunInProcess (\runInProcess ->
    mkSerializableClosureProcess bDict (runInProcess mb))
  withRemoteRun (\remoteRun -> liftBase (remoteRun scp))

-- | Evaluate a 'Closure' at a remote location, assuming a
-- @'Static' ('Binary' b)@ instance. The 'Closure' itself is only
-- computed when a worker becomes available.
--
-- This is 'remoteEvalWithDictM' with the serialization dictionary
-- supplied by 'closureDict'.
remoteEvalM
  :: (HasWorkers m, Static (Binary b), Typeable b)
  => m (Closure (Process b))
  -> m b
remoteEvalM = remoteEvalWithDictM closureDict

-- | Evaluate a 'Closure' at a remote location.
--
-- The closure is already built, so it is wrapped with 'pure' and handed
-- to 'remoteEvalM'.
remoteEval
  :: (HasWorkers m, Static (Binary b), Typeable b)
  => Closure (Process b)
  -> m b
remoteEval = remoteEvalM . pure