module Hyperion.TokenPool where

import           Control.Concurrent.STM      (atomically, check)
import           Control.Concurrent.STM.TVar (TVar, modifyTVar, modifyTVar',
                                              newTVarIO, readTVar, writeTVar)
import           Control.Monad.Catch         (MonadMask, bracket)
import           Control.Monad.IO.Class      (MonadIO, liftIO)

-- | A 'TokenPool' keeps track of the number of resources of some
-- kind, represented by /tokens/.
--
-- * @TokenPool (Just var)@ is a limited pool. @var@ holds the number of
--   currently available tokens. While @var@ is not positive, callers of
--   'withToken' block until a token is returned to the pool.
--
-- * @TokenPool Nothing@ represents an unlimited number of tokens.
newtype TokenPool = TokenPool (Maybe (TVar Int))

-- | Create a new 'TokenPool' containing the given number of
-- tokens. 'Nothing' creates an unlimited pool.
--
-- A limited pool created with a count of zero or less starts with no
-- available tokens. 'withToken' on such a pool blocks until the count
-- becomes positive, and 'withToken' itself only ever returns tokens it
-- previously took.
newTokenPool :: Maybe Int -> IO TokenPool
newTokenPool (Just n) = TokenPool . Just <$> newTVarIO n
newTokenPool Nothing  = pure (TokenPool Nothing)

-- | Remove a token from the pool, run the given action, and then
-- return the token to the pool. If no token is available, block until
-- one becomes available.
--
-- Acquisition and release are wrapped in 'bracket', so the token is
-- returned even if the action throws. For an unlimited pool
-- (@TokenPool Nothing@) the action is run directly, with no
-- bookkeeping.
withToken :: (MonadIO m, MonadMask m) => TokenPool -> m a -> m a
withToken (TokenPool Nothing) go = go
withToken (TokenPool (Just tokenVar)) go =
  bracket (liftIO getToken) (const (liftIO replaceToken)) (const go)
  where
    -- Wait until at least one token is available, then take it.
    -- 'check' retries the STM transaction while the count is not positive.
    getToken :: IO ()
    getToken = atomically $ do
      tokens <- readTVar tokenVar
      check (tokens > 0)
      -- Reuse the value just read and store it strictly, so no
      -- unevaluated thunk is left inside the TVar.
      writeTVar tokenVar $! tokens - 1

    -- Give the token back. The strict modify avoids accumulating
    -- unevaluated (+ 1) thunks when many releases happen between reads.
    replaceToken :: IO ()
    replaceToken = atomically $ modifyTVar' tokenVar (+ 1)