{-# LANGUAGE RankNTypes #-}

module Hyperion.CallClosure where

import           Control.Distributed.Process
import           Control.Distributed.Process.Closure                  (SerializableDict,
                                                                       bindCP,
                                                                       returnCP,
                                                                       sdictUnit,
                                                                       seqCP)
import           Control.Distributed.Process.Internal.Closure.BuiltIn (cpDelayed)
import           Control.Distributed.Static                           (closureApply,
                                                                       staticClosure,
                                                                       staticLabel)
import           Data.Binary                                          (Binary,
                                                                       encode)
import           Data.ByteString.Lazy                                 (ByteString)
import           Data.Typeable                                        (Typeable)

-- | The purpose of this module is to generalize 'call' from
-- 'Control.Distributed.Process' so that it takes a 'Closure
-- (SerializableDict a)' instead of a 'Static (SerializableDict
-- a)'. Note that this is a strict generalization because any 'Static
-- a' can be turned into 'Closure a' via 'staticClosure', while a
-- 'Closure a' cannot be turned into a 'Static a' in general.
--
-- Note: The extra flexibility afforded by call' is needed in
-- conjunction with the 'Hyperion.Static (KnownNat j)'
-- instance. In that case, we cannot construct a
-- 'Control.Distributed.Static.Static (Dict (KnownNat j))', but we can
-- construct a 'Closure (Dict (KnownNat j))'. NB: The name 'Static' is
-- used in two places: 'Control.Distributed.Static.Static' and
-- 'Hyperion.Static'. The former is a datatype and the latter
-- is a typeclass.
--
-- Most of the code here has been copied from
-- 'Control.Distributed.Process' and
-- 'Control.Distributed.Process.Closure', with small modifications.

-- | 'CP' version of 'send' that uses a 'Closure (SerializableDict a)'
-- instead of a 'Static (SerializableDict a)'.
--
-- The result is assembled purely from closures: the static labelled
-- @"$sendDict"@ is lifted with 'staticClosure', applied to the
-- dictionary closure, and then applied to a closure that rebuilds the
-- target 'ProcessId' from its 'encode'd bytes via the static labelled
-- @"$decodeProcessId"@.
--
-- NOTE(review): both labels are looked up by name only; this presumably
-- relies on them being registered in the remote table by
-- distributed-process itself -- confirm.
cpSend' :: forall a . Closure (SerializableDict a) -> ProcessId -> Closure (a -> Process ())
cpSend' dict pid =
  staticClosure sendDictStatic
    `closureApply` dict
    `closureApply` closure decodeProcessIdStatic (encode pid)
  where
    -- Static reference to the dictionary-taking send function.
    sendDictStatic :: Static (SerializableDict a -> ProcessId -> a -> Process ())
    sendDictStatic = staticLabel "$sendDict"

    -- Static reference to the decoder that turns the serialized bytes
    -- back into a 'ProcessId'.
    decodeProcessIdStatic :: Static (ByteString -> ProcessId)
    decodeProcessIdStatic = staticLabel "$decodeProcessId"

-- | 'call' that uses a 'Closure (SerializableDict a)' instead of a
-- 'Static (SerializableDict a)'.
--
-- Spawns a monitored process on the given node that runs the supplied
-- closure and sends its result back to the caller (via 'cpSend''), then
-- blocks until either the result or the monitor notification for the
-- spawned process arrives.
--
-- On success the spawned process is acknowledged with a unit message,
-- the pending monitor notification is consumed, the connection to the
-- spawned process is reset with 'reconnect', and the result is returned.
-- If the monitor notification arrives first, the call 'fail's with a
-- message containing the 'DiedReason'.
--
-- NOTE(review): the result is received with an unfiltered 'match', so
-- any message of type @a@ already in the caller's mailbox would be
-- accepted as the reply -- confirm callers do not rely on otherwise.
call'
  :: (Binary a, Typeable a)
  => Closure (SerializableDict a)
  -> NodeId
  -> Closure (Process a)
  -> Process a
call' dict nid proc = do
  us <- getSelfPid
  (pid, mRef) <- spawnMonitor nid (proc `bindCP`
                                   cpSend' dict us `seqCP`
                                   -- Delay so the process does not terminate
                                   -- before the response arrives.
                                   cpDelayed us (returnCP sdictUnit ())
                                  )
  -- Selects only the monitor notification that belongs to this call.
  let isOurMonitor (ProcessMonitorNotification ref _ _) = ref == mRef
  mResult <- receiveWait
    [ -- Got the result: acknowledge it to the spawned process.
      match $ \a -> usend pid () >> return (Right a)
    , -- The spawned process ended before delivering a result.
      matchIf isOurMonitor
              (\(ProcessMonitorNotification _ _ reason) -> return (Left reason))
    ]
  case mResult of
    Right a -> do
      -- Wait for the monitor message so that the mailbox doesn't grow
      receiveWait
        [ matchIf isOurMonitor
                  (\(ProcessMonitorNotification {}) -> return ())
        ]
      -- Clean up connection to pid
      reconnect pid
      return a
    Left err ->
      fail $ "call: remote process died: " ++ show err