{-# LANGUAGE DeriveAnyClass    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}
{-# LANGUAGE TypeApplications  #-}

module Hyperion.Slurm.Sbatch where

import           Control.Monad.Catch   (Exception)
import           Data.Attoparsec.Text  (Parser, parseOnly, takeWhile1)
import           Data.Char             (isSpace)
import           Data.Maybe            (catMaybes)
import           Data.Text             (Text)
import qualified Data.Text             as T
import           Data.Time.Clock       (NominalDiffTime)
import qualified Hyperion.Log          as Log
import           Hyperion.Slurm.JobId  (JobId (..))
import           Hyperion.Util         (hour)
import           System.Directory      (createDirectoryIfMissing)
import           System.Exit           (ExitCode (..))
import           System.FilePath.Posix (takeDirectory)
import           System.Process        (readCreateProcessWithExitCode, shell)

-- | Error from running @sbatch@: the process exited unsuccessfully, or its
-- standard output could not be parsed as a submitted job id.
data SbatchError = SbatchError
  { exitCodeStdinStderr :: (ExitCode, String, String)
    -- ^ Exit code, standard output and standard error of the @sbatch@
    -- process, in that order. Despite the field name, the middle component
    -- is stdout: this is the tuple returned by 'readCreateProcessWithExitCode'.
  , input               :: String
    -- ^ Description of what was run, as recorded by the code that threw the
    -- error (in 'sbatchScript' this is the command line together with the
    -- batch script).
  } deriving Show

instance Exception SbatchError

-- | Options for an @sbatch@ submission. Each field corresponds to a single
-- @sbatch@ command-line option; see @man sbatch@ for the exact semantics.
data SbatchOptions = SbatchOptions
  { jobName       :: Maybe Text
    -- ^ Job name (@--job-name@)
  , chdir         :: Maybe FilePath
    -- ^ Working directory for the job (@-D@; the long form is @--chdir@,
    -- formerly @--workdir@)
  , output        :: Maybe FilePath
    -- ^ File that receives the job's standard output (@--output@)
  , nodes         :: Int
    -- ^ Number of nodes (@--nodes@)
  , nTasksPerNode :: Int
    -- ^ Number of tasks per node (@--ntasks-per-node@)
  , time          :: NominalDiffTime
    -- ^ Job time limit (@--time@)
  , mem           :: Maybe Text
    -- ^ Memory per node; use a K, M, G or T suffix for the unit (@--mem@)
  , mailType      :: Maybe Text
    -- ^ Events that trigger e-mail notification (@--mail-type@)
  , mailUser      :: Maybe Text
    -- ^ Recipient of e-mail notifications (@--mail-user@)
  , partition     :: Maybe Text
    -- ^ @SLURM@ partition (@--partition@)
  , constraint    :: Maybe Text
    -- ^ Node feature constraint (@--constraint@)
  , account       :: Maybe Text
    -- ^ Account to charge (@--account@)
  , qos           :: Maybe Text
    -- ^ Quality of service (@--qos@)
  } deriving Show

-- | Baseline 'SbatchOptions': a single task on a single node with a 24 hour
-- time limit. Every optional setting is left as 'Nothing'.
defaultSbatchOptions :: SbatchOptions
defaultSbatchOptions =
  SbatchOptions
    { -- the only settings given concrete values
      nodes         = 1
    , nTasksPerNode = 1
    , time          = 24 * hour
      -- everything else is unset
    , jobName       = Nothing
    , chdir         = Nothing
    , output        = Nothing
    , mem           = Nothing
    , mailType      = Nothing
    , mailUser      = Nothing
    , partition     = Nothing
    , constraint    = Nothing
    , account       = Nothing
    , qos           = Nothing
    }

-- | Renders 'SbatchOptions' as one space-separated string of @sbatch@
-- command-line options. Each present option becomes @flag value@; options
-- whose value is 'Nothing' are omitted. Values are inserted verbatim, with
-- no shell quoting.
sBatchOptionString :: SbatchOptions -> String
sBatchOptionString opts = unwords (concatMap render optPairs)
  where
    -- One rendered option, or none when the value is absent.
    render :: (String, Maybe String) -> [String]
    render (flag, mVal) = maybe [] (\v -> [flag ++ " " ++ v]) mVal

    -- Reads an optional 'Text' field and converts it to a 'String'.
    textOpt :: (SbatchOptions -> Maybe Text) -> Maybe String
    textOpt field = T.unpack <$> field opts

    optPairs :: [(String, Maybe String)]
    optPairs =
      [ ("--job-name",        textOpt jobName)
      -- sbatch renamed the long form of this option (workdir -> chdir) at
      -- some point, so the short name is used instead
      , ("-D",                chdir opts)
      , ("--output",          output opts)
      , ("--nodes",           Just (show (nodes opts)))
      , ("--ntasks-per-node", Just (show (nTasksPerNode opts)))
      , ("--time",            Just (formatRuntime (time opts)))
      , ("--mem",             textOpt mem)
      , ("--mail-type",       textOpt mailType)
      , ("--mail-user",       textOpt mailUser)
      , ("--partition",       textOpt partition)
      , ("--constraint",      textOpt constraint)
      , ("--account",         textOpt account)
      , ("--qos",             textOpt qos)
      ]

-- | Parses the job id from @sbatch@'s standard output, which must start with
-- @Submitted batch job @, then the id, then a newline. The id is the longest
-- non-empty run of non-whitespace characters after the prefix.
sbatchOutputParser :: Parser JobId
sbatchOutputParser = do
  _ <- "Submitted batch job "
  jobIdText <- takeWhile1 (not . isSpace)
  _ <- "\n"
  return (JobId jobIdText)

-- | Runs @sbatch@ with options pulled from 'SbatchOptions' on a batch script
-- whose body is the 'String' parameter. The body is prefixed with a
-- @#!/bin/sh@ line and written to @sbatch@'s standard input.
--
-- Before submitting, this creates the working directory ('chdir') and the
-- directory that will contain the 'output' file, when those are set.
--
-- Throws 'SbatchError' if @sbatch@ exits with failure, or if its standard
-- output cannot be parsed by 'sbatchOutputParser'.
sbatchScript :: SbatchOptions -> String -> IO JobId
sbatchScript opts script = do
  mapM_ (createDirectoryIfMissing True) $
    catMaybes [ chdir opts
              , fmap takeDirectory (output opts)
              ]
  -- The script goes in on stdin rather than inside a quoted
  -- @printf '...'@ argument. That form broke on any single quote in the
  -- script and let printf reinterpret % directives and backslash escapes.
  result@(exit, out, _) <-
    readCreateProcessWithExitCode (shell sbatchCmd) wrappedScript
  case (exit, parseOnly sbatchOutputParser (T.pack out)) of
    (ExitSuccess, Right j) -> return j
    _                      -> Log.throw (SbatchError result sbatchInput)
  where
    -- The options are still interpolated into a shell command line, as
    -- before, so they undergo the same shell word splitting.
    sbatchCmd :: String
    sbatchCmd = "sbatch " ++ sBatchOptionString opts

    wrappedScript :: String
    wrappedScript = "#!/bin/sh\n " ++ script

    -- Recorded in 'SbatchError': the command line, then the script that was
    -- fed to it on stdin.
    sbatchInput :: String
    sbatchInput = sbatchCmd ++ "\n" ++ wrappedScript

-- | Formats a 'NominalDiffTime' as @hh:mm:ss@. Fractional seconds are
-- truncated toward zero. Each component is zero-padded to at least two
-- characters, and hours are not wrapped, so 100 hours renders as
-- @100:00:00@.
formatRuntime :: NominalDiffTime -> String
formatRuntime t = pad hours ++ ':' : pad minutes ++ ':' : pad seconds
  where
    -- Whole seconds, truncated toward zero.
    totalSeconds :: Integer
    totalSeconds = truncate (toRational t)

    hours, minutes, seconds :: Integer
    hours   = totalSeconds `quot` 3600
    minutes = (totalSeconds `quot` 60) `rem` 60
    seconds = totalSeconds `rem` 60

    -- Prepend a '0' only when the rendered number is a single character.
    pad :: Integer -> String
    pad n
      | [digit] <- shown = ['0', digit]
      | otherwise        = shown
      where
        shown = show n

-- | Runs the command given by 'FilePath' with arguments @['Text']@ inside an
-- @sbatch@ script submitted through 'sbatchScript'.
--
-- Each argument is wrapped in double quotes. The characters that stay
-- special inside double quotes (backslash, double quote, @$@ and backquote)
-- are backslash-escaped, so each argument reaches the command as one literal
-- word. The command itself is inserted unquoted.
--
-- Throws 'SbatchError' if @sbatch@ fails.
sbatchCommand :: SbatchOptions -> FilePath -> [Text] -> IO JobId
sbatchCommand opts cmd args = sbatchScript opts script
  where
    script :: String
    script = cmd ++ " " ++ unwords (map quote args)

    -- Without escaping, an argument containing " would end the quoted word
    -- early, and $ or ` would trigger expansion.
    quote :: Text -> String
    quote a = "\"" ++ concatMap escape (T.unpack a) ++ "\""

    escape :: Char -> String
    escape c
      | c `elem` ("\\\"$`" :: String) = ['\\', c]
      | otherwise                     = [c]