{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
module Numeric.Opto.Run (
RunOpts(..)
, hoistRunOpts
, ParallelOpts(..)
, hoistParallelOpts
, opto
, opto'
, optoNonSampling'
, optoNonSampling
, optoConduit, optoConduit'
, optoFold, optoFold'
, optoPar
, optoParChunk
, optoParNonSampling
, optoConduitPar
, optoConduitParChunk
, mean
) where
import Control.Applicative
import Control.Concurrent.STM.TBMQueue
import Control.Monad
import Control.Monad.State
import Control.Monad.Trans.Maybe
import Data.Bifunctor
import Data.Conduit
import Data.Conduit.TQueue
import Data.Default
import Data.Functor
import Data.Functor.Contravariant
import Data.Functor.Invariant
import Data.List
import Data.List.NonEmpty (NonEmpty(..))
import Data.Maybe
import Data.MonoTraversable
import Data.Semigroup.Foldable
import GHC.Natural
import Numeric.Opto.Core
import Numeric.Opto.Ref
import Numeric.Opto.Update
import UnliftIO
import UnliftIO.Concurrent
import qualified Data.Conduit as C
import qualified Data.List.NonEmpty as NE
-- | Options controlling an optimization run.  The sequential runners use
-- every field; the parallel runners use them too (reinterpreting
-- 'roLimit' and 'roFreq' in terms of rounds of work).
data RunOpts m a = RO
    { roStopCond :: Diff a -> a -> m Bool
      -- ^ Called after each step with the step that was applied and the
      -- updated value; returning 'True' ends the run early.
    , roReport   :: a -> m ()
      -- ^ Called with the updated value on the iterations selected by
      -- 'roFreq'.
    , roLimit    :: Maybe Int
      -- ^ Maximum number of iterations; 'Nothing' for no limit.
    , roBatch    :: Int
      -- ^ Number of samples accumulated into each step; values @<= 1@ take
      -- one sample per step.
    , roFreq     :: Maybe Int
      -- ^ Report every this many iterations; 'Nothing' never reports.
    }
-- | Extra options for the parallel runners.
data ParallelOpts m a = PO
    { poThreads :: Maybe Int
      -- ^ Number of worker threads; 'Nothing' uses 'getNumCapabilities'.
    , poSplit   :: Int
      -- ^ Iterations (or chunk size, for the chunked runners) each thread
      -- handles per round before the results are combined.
    , poCombine :: NonEmpty a -> m a
      -- ^ Merges the per-thread results of a round into one value.
    , poPull    :: Bool
      -- ^ Conduit runners only: when 'True', reading from the input queue is
      -- gated on downstream demand for the next output.
    }
-- | Default run options: never stop early, a no-op report action, no
-- iteration limit, single-sample steps, and a report frequency of 1 (which
-- is harmless because the default report action does nothing).
instance Applicative m => Default (RunOpts m a) where
    def = RO
      { roStopCond = const (const (pure False))
      , roReport   = const (pure ())
      , roLimit    = Nothing
      , roBatch    = 1
      , roFreq     = Just 1
      }
-- | Default parallel options: one thread per capability, 1000 iterations per
-- thread per round, results averaged with 'mean', and demand-driven pulling
-- for the conduit runners.
instance (Applicative m, Fractional a) => Default (ParallelOpts m a) where
    def = PO
      { poThreads = Nothing
      , poSplit   = 1000
      , poCombine = \results -> pure (mean results)
      , poPull    = True
      }
-- | Pre-compose both callbacks with the given function.  Applying @f@ to
-- the 'Diff' argument type-checks only because 'Diff' @b@ reduces to @b@
-- (presumably a type synonym in "Numeric.Opto.Core" — confirm there).
instance Contravariant (RunOpts m) where
    contramap f ro = ro
      { roStopCond = \d x -> roStopCond ro (f d) (f x)
      , roReport   = \x -> roReport ro (f x)
      }
-- | 'RunOpts' only consumes values, so only the backwards map is needed.
instance Invariant (RunOpts m) where
    invmap _forward backward ro = contramap backward ro
-- | Convert the inputs of 'poCombine' backwards, then map its result
-- forwards.
instance Functor m => Invariant (ParallelOpts m) where
    invmap forward backward po = po
      { poCombine = \xs -> forward <$> poCombine po (backward <$> xs)
      }
-- | Move the callbacks of a 'RunOpts' into another monad using a natural
-- transformation; the non-monadic fields are left unchanged.
hoistRunOpts
    :: (forall x. m x -> n x)
    -> RunOpts m a
    -> RunOpts n a
hoistRunOpts nat ro = ro
    { roStopCond = \d x -> nat (roStopCond ro d x)
    , roReport   = \x -> nat (roReport ro x)
    }
-- | Move 'poCombine' into another monad using a natural transformation.
hoistParallelOpts
    :: (forall x. m x -> n x)
    -> ParallelOpts m a
    -> ParallelOpts n a
hoistParallelOpts nat po = po { poCombine = \xs -> nat (poCombine po xs) }
-- | Run an optimizer from @x0@, drawing samples from @sampler@ until it
-- yields 'Nothing', the iteration limit is hit, or the stop condition
-- fires.  Returns the final value.
opto
    :: Monad m
    => RunOpts m a
    -> m (Maybe r)
    -> a
    -> Opto m r a
    -> m a
opto ro sampler x0 o = opto_ ro sampler x0 o $ \getValue _ -> getValue
{-# INLINE opto #-}
-- | 'opto' for optimizers that need no samples: the sampler always
-- supplies @()@, so only the limit or stop condition ends the run.
optoNonSampling
    :: Monad m
    => RunOpts m a
    -> a
    -> Opto m () a
    -> m a
optoNonSampling ro x0 o = opto ro (return (Just ())) x0 o
{-# INLINE optoNonSampling #-}
-- | Like 'opto', but also returns the optimizer with its final internal
-- state, so a later run can resume from it.
opto'
    :: Monad m
    => RunOpts m a
    -> m (Maybe r)
    -> a
    -> Opto m r a
    -> m (a, Opto m r a)
opto' ro sampler x0 o =
    opto_ ro sampler x0 o $ \getValue getOpto -> (,) <$> getValue <*> getOpto
{-# INLINE opto' #-}
-- | 'opto'' for optimizers that need no samples: the sampler always
-- supplies @()@.
optoNonSampling'
    :: Monad m
    => RunOpts m a
    -> a
    -> Opto m () a
    -> m (a, Opto m () a)
optoNonSampling' ro x0 o = opto' ro (return (Just ())) x0 o
{-# INLINE optoNonSampling' #-}
-- | Shared implementation of the sequential runners.
--
-- Thaws the optimizer state and the initial value into mutable references,
-- runs 'optoLoop' over them, and then passes @f@ two actions: one freezing
-- the final value, and one rebuilding the optimizer around its final state.
opto_
    :: forall m r a q. Monad m
    => RunOpts m a
    -> m (Maybe r)
    -> a
    -> Opto m r a
    -> (m a -> m (Opto m r a) -> m q)
    -> m q
opto_ RO{..} sampler x0 MkOpto{..} f = do
    stateRef <- thawRef oInit
    valueRef <- thawRef @m @a x0
    optoLoop OL
      { olVar         = valueRef
      , olSample      = sampler
      , olUpdateState = oUpdate stateRef
      , olInitialize  = thawRef @m @a
      , olUpdate      = (.*+=)
      , olRead        = freezeRef
      , olLimit       = roLimit
      , olBatch       = roBatch
      , olReportFreq  = roFreq
      , olStopCond    = roStopCond
      , olReportAct   = roReport
      }
    f (freezeRef valueRef) ((\s -> MkOpto s oUpdate) <$> freezeRef stateRef)
{-# INLINE opto_ #-}
-- | Everything 'optoLoop' needs, bundled as a record of operations.
-- @c@ is the type each step is scaled by (the loop passes @c .* g@ to the
-- stop condition).
data OptoLoop m r a c = OL
    { olLimit       :: Maybe Int
      -- ^ Maximum number of iterations; 'Nothing' for no limit.
    , olBatch       :: Int
      -- ^ Samples accumulated per step; values @<= 1@ use one sample.
    , olReportFreq  :: Maybe Int
      -- ^ Report every this many iterations; 'Nothing' never reports.
    , olInitialize  :: a -> m (Ref m a)
      -- ^ Allocates a fresh reference; used for the batch accumulator,
      -- which starts at 'zeroL'.
    , olUpdate      :: Ref m a -> (c, a) -> m ()
      -- ^ Applies a @(scale, step)@ pair to a reference (opto_ passes
      -- '.*+=', presumably an in-place scaled add — confirm in
      -- "Numeric.Opto.Update").
    , olRead        :: Ref m a -> m a
      -- ^ Reads the current contents of a reference.
    , olVar         :: Ref m a
      -- ^ The reference holding the value being optimized.
    , olSample      :: m (Maybe r)
      -- ^ Draws the next sample; 'Nothing' means the samples are exhausted.
    , olUpdateState :: r -> a -> m (c, a)
      -- ^ Produces a @(scale, step)@ pair from a sample and the current value.
    , olStopCond    :: Diff a -> a -> m Bool
      -- ^ Early-stopping test, given the applied step and the new value.
    , olReportAct   :: a -> m ()
      -- ^ Reporting action, given the new value.
    }
-- | The sequential optimization loop shared by every runner in this module.
--
-- Each iteration reads the current value from 'olVar' and draws a step
-- (from a single sample, or accumulated over 'olBatch' samples evaluated at
-- the same point). It applies the step with 'olUpdate', optionally reports
-- the new value, and consults 'olStopCond'. The loop ends when the iteration
-- limit is reached, the sampler returns 'Nothing', or the stop condition
-- returns 'True'.
--
-- A report frequency below 1 (e.g. @'Just' 0@) is treated as 1, i.e.
-- report every iteration. This mirrors the @'max' 1@ clamp of the parallel
-- loop and avoids a divide-by-zero exception from 'mod'.
optoLoop
    :: forall m r a c. (Monad m, Linear c a)
    => OptoLoop m r a c
    -> m ()
optoLoop OL{..} = go 0
  where
    go !i = when (limCheck i) $ do
        !x <- olRead olVar
        (exhausted, cg) <- batcher x
        -- 'cg' is 'Nothing' only when no step could be formed (no samples).
        forM_ cg $ \(c, g) -> do
            olUpdate olVar (c, g)
            x' <- olRead olVar
            when (reportCheck i) $
                olReportAct x'
            stopper <- olStopCond (c .* g) x'
            when (not exhausted && not stopper) $
                go (i + 1)
    limCheck = case olLimit of
        Nothing -> const True
        Just l  -> (< l)
    reportCheck = case olReportFreq of
        Nothing -> const False
        -- Clamp so a non-positive frequency cannot divide by zero.
        Just r  -> let every = max 1 r
                   in  \i -> (i + 1) `mod` every == 0
    -- Returns (samples exhausted?, step to apply if any).
    batcher
      | olBatch <= 1 = fmap (\y -> (isNothing y, y)) . runMaybeT . batchSingle
      | otherwise    = batchLoop
    batchSingle !x = lift . (`olUpdateState` x) =<< MaybeT olSample
    -- Accumulates up to 'olBatch' scaled steps into a fresh zero reference;
    -- a partially-filled batch is still applied (with scale 1) when the
    -- sampler runs dry mid-batch.
    batchLoop !x = do
        v <- olInitialize zeroL
        k <- fmap isNothing . runMaybeT . replicateM olBatch $
          lift . olUpdate v =<< batchSingle x
        (k,) . Just . (1 :: c,) <$> olRead v
{-# INLINE optoLoop #-}
-- | Run an optimizer as a conduit: samples are awaited from upstream and
-- every intermediate value is yielded downstream.  Discards the final
-- optimizer; see 'optoConduit'' to keep it.
optoConduit
    :: Monad m
    => RunOpts m a
    -> a
    -> Opto (ConduitT r a m) r a
    -> ConduitT r a m ()
optoConduit ro x0 o = () <$ optoConduit' ro x0 o
{-# INLINE optoConduit #-}
-- | Like 'optoConduit', but returns the optimizer with its final internal
-- state.
--
-- Samples come from 'C.await'.  The stop condition is wrapped so that each
-- new value is 'C.yield'ed downstream before the user's condition runs.
optoConduit'
    :: Monad m
    => RunOpts m a
    -> a
    -> Opto (ConduitT r a m) r a
    -> ConduitT r a m (Opto (ConduitT r a m) r a)
optoConduit' ro x0 o = opto_ streamingOpts C.await x0 o (\_ getOpto -> getOpto)
  where
    streamingOpts = (hoistRunOpts lift ro)
        { roStopCond = \d x -> do
            C.yield x
            lift (roStopCond ro d x)
        }
{-# INLINE optoConduit' #-}
-- | Run an optimizer over a list of samples, consuming them in order.
-- Returns the final value together with the samples left unconsumed.
optoFold
    :: Monad m
    => RunOpts m a
    -> a
    -> Opto (StateT [r] m) r a
    -> [r]
    -> m (a, [r])
optoFold ro x0 o samples =
    runStateT (opto (hoistRunOpts lift ro) sampleState x0 o) samples
{-# INLINE optoFold #-}
-- | Like 'optoFold', but also returns the optimizer with its final
-- internal state (as the third component).
optoFold'
    :: Monad m
    => RunOpts m a
    -> a
    -> Opto (StateT [r] m) r a
    -> [r]
    -> m (a, [r], Opto (StateT [r] m) r a)
optoFold' ro x0 o samples =
    (\((x', o'), leftover) -> (x', leftover, o'))
        <$> runStateT (opto' (hoistRunOpts lift ro) sampleState x0 o) samples
{-# INLINE optoFold' #-}
-- | Pop the next sample off the list held in the state; 'Nothing' (leaving
-- the state empty) once the list is exhausted.
sampleState :: Monad m => StateT [r] m (Maybe r)
sampleState = state pop
  where
    pop []       = (Nothing, [])
    pop (r : rs) = (Just r, rs)
{-# INLINE sampleState #-}
-- | Run an optimizer in parallel. Every thread calls the shared @sampler@,
-- which is therefore invoked concurrently and must tolerate that.
--
-- Each round, every thread runs up to 'poSplit' iterations (fewer when the
-- 'roLimit' budget is nearly spent) from the current value, and the
-- per-thread results are merged with 'poCombine'. Any thread whose
-- 'roStopCond' fires records its value, and that value is returned after
-- the round. 'roReport' is only called on combined values.
--
-- A thread that draws no sample in a round contributes no result. Without
-- this, an exhausted sampler made every thread return its input unchanged,
-- so rounds never came back empty and the run never ended when 'roLimit'
-- was 'Nothing'. Now the run ends once the sampler is exhausted, matching
-- the empty-chunk behaviour of 'optoParChunk'.
optoPar
    :: forall m r a. MonadUnliftIO m
    => RunOpts m a
    -> ParallelOpts m a
    -> m (Maybe r)
    -> a
    -> Opto m r a
    -> m a
optoPar ro po sampler x0 o = optoPar_ ro po x0 $ \hitStop lim x ->
    if lim > 0
      then do
        -- Per-thread flag: did this run consume at least one sample?
        drewSample <- newIORef False
        let trackingSampler = do
              mr <- sampler
              when (isJust mr) $ writeIORef drewSample True
              pure mr
            ro' = ro
              { roLimit    = Just lim
              , roReport   = \_ -> pure ()
              , roStopCond = \d x' -> do
                  sc <- roStopCond ro d x'
                  sc <$ when sc (writeIORef hitStop (Just x'))
              , roFreq     = Nothing
              }
        xOut <- opto ro' trackingSampler x o
        drew <- readIORef drewSample
        pure $ if drew then Just xOut else Nothing
      else pure Nothing
{-# INLINE optoPar #-}
-- | 'optoPar' for optimizers that need no samples: the sampler always
-- supplies @()@.
optoParNonSampling
    :: MonadUnliftIO m
    => RunOpts m a
    -> ParallelOpts m a
    -> a
    -> Opto m () a
    -> m a
optoParNonSampling ro po x0 o = optoPar ro po (return (Just ())) x0 o
{-# INLINE optoParNonSampling #-}
-- | Parallel runner where each thread first fetches a whole chunk of
-- samples (@sampler n@ is asked for up to @n@ items) and then folds over it
-- with 'optoFold'.  An empty chunk contributes no result; items a thread
-- leaves unconsumed are dropped.
optoParChunk
    :: forall m r a. MonadUnliftIO m
    => RunOpts m a
    -> ParallelOpts m a
    -> (Int -> m [r])
    -> a
    -> Opto (StateT [r] m) r a
    -> m a
optoParChunk ro po sampler x0 o = optoPar_ ro po x0 $ \hitStop lim x -> do
    chunk <- sampler lim
    let chunkOpts = ro
          { roLimit    = Nothing
          , roReport   = \_ -> pure ()
          , roStopCond = \d x' -> do
              shouldStop <- roStopCond ro d x'
              when shouldStop $ writeIORef hitStop (Just x')
              pure shouldStop
          , roFreq     = Nothing
          }
    if onull chunk
      then pure Nothing
      else fmap (Just . fst) (optoFold chunkOpts x o chunk)
{-# INLINE optoParChunk #-}
-- | Shared setup for the parallel runners. Resolves the thread count
-- ('poThreads', defaulting to 'getNumCapabilities'), creates the shared
-- stop flag and total iteration budget, and hands off to 'optoParLoop'.
--
-- @runner hitStop budget x@ runs one thread's share of a round from @x@
-- with at most @budget@ iterations. It writes to @hitStop@ if its stop
-- condition fires, and returns 'Nothing' to contribute no result. If every
-- thread returns 'Nothing', the loop ends.
--
-- A negative 'roLimit' is clamped to 0. Converting a negative 'Int' to
-- 'Natural' throws "arithmetic underflow", whereas the sequential runners
-- simply perform zero iterations and return @x0@; this now matches them.
optoPar_
    :: forall m a. MonadUnliftIO m
    => RunOpts m a
    -> ParallelOpts m a
    -> a
    -> (IORef (Maybe a) -> Int -> a -> m (Maybe a))
    -> m a
optoPar_ RO{..} PO{..} x0 runner = do
    n       <- maybe getNumCapabilities pure poThreads
    hitStop <- newIORef Nothing
    gas     <- mapM newMVar (fromIntegral . max 0 <$> roLimit)
    optoParLoop OPL
      { oplThreads = n
      , oplFreq    = roFreq
      , oplSplit   = poSplit
      , oplHitStop = hitStop
      , oplGas     = gas
      , oplReport  = roReport
      , oplRunner  = runner hitStop
      , oplCombine = poCombine
      , oplInitial = x0
      }
{-# INLINE optoPar_ #-}
-- | Everything 'optoParLoop' needs, bundled as a record.
data OptoParLoop m a = OPL
    { oplThreads :: Int
      -- ^ Number of concurrent runner invocations per round.
    , oplFreq    :: Maybe Int
      -- ^ Requested report frequency in iterations; 'optoParLoop' converts
      -- it to a number of rounds.  'Nothing' never reports.
    , oplSplit   :: Int
      -- ^ Iteration budget handed to each runner per round (less once the
      -- remaining gas is smaller).
    , oplHitStop :: IORef (Maybe a)
      -- ^ Written by a runner whose stop condition fired; checked after
      -- each round.
    , oplGas     :: Maybe (MVar Natural)
      -- ^ Remaining total iteration budget shared by all threads; 'Nothing'
      -- for unlimited.
    , oplReport  :: a -> m ()
      -- ^ Called with the combined value on report rounds.
    , oplRunner  :: Int -> a -> m (Maybe a)
      -- ^ Runs one thread's share of a round, given its budget and the
      -- current value; 'Nothing' contributes no result.
    , oplCombine :: NonEmpty a -> m a
      -- ^ Merges the results of a round.
    , oplInitial :: a
      -- ^ Starting value.
    }
-- | Round-based driver behind the parallel runners.
--
-- Every round launches 'oplThreads' concurrent calls to 'oplRunner'. Each
-- call gets an iteration budget ('oplSplit', or whatever is left of
-- 'oplGas') and the current value. After the round:
--
--   * if some thread recorded a value in 'oplHitStop', that value is
--     returned;
--   * otherwise, if no thread produced a result, the current value is
--     returned;
--   * otherwise the results are merged with 'oplCombine', possibly reported,
--     and the next round starts from the merged value.
--
-- With @'oplFreq' = 'Just' r@, a report happens every
-- @max 1 (r \`div\` (oplThreads * oplSplit))@ rounds.
optoParLoop
    :: MonadUnliftIO m
    => OptoParLoop m a
    -> m a
optoParLoop OPL{..} = runRound 0 oplInitial
  where
    runRound !k !cur = do
        outcomes <- replicateConcurrently oplThreads $ do
            budget <- case oplGas of
              Nothing  -> pure oplSplit
              Just gas -> drawGas gas
            oplRunner budget cur
        stopped <- readIORef oplHitStop
        case (stopped, NE.nonEmpty (catMaybes outcomes)) of
          (Just found, _      ) -> pure found
          (Nothing   , Nothing) -> pure cur
          (Nothing   , Just ys) -> do
            !next <- oplCombine ys
            when (shouldReport k) $
                oplReport next
            runRound (k + 1) next
    -- Is round number k (0-based) one that should be reported?
    shouldReport = case oplFreq of
        Nothing -> const False
        Just r  ->
          let every = max 1 (r `div` (oplThreads * oplSplit))
          in  \k -> (k + 1) `mod` every == 0
    -- Atomically take up to 'oplSplit' iterations from the shared budget.
    drawGas gas = modifyMVar gas $ \left ->
        pure $ case left `minusNaturalMaybe` fromIntegral oplSplit of
          Just left' -> (left', oplSplit)
          Nothing    -> (0, fromIntegral left)
{-# INLINE optoParLoop #-}
-- | Parallel conduit runner.  Samples are read one at a time from the
-- upstream queue (after passing the pull gate), and every reported value is
-- sent downstream as well as to the user's 'roReport'.
optoConduitPar
    :: forall m r a. MonadUnliftIO m
    => RunOpts m a
    -> ParallelOpts m a
    -> a
    -> Opto m r a
    -> ConduitT () r m ()
    -> ConduitT () a m ()
optoConduitPar ro po x0 o = optoConduitPar_ ro po $ \gate inQueue outVar ->
    let streamingOpts = ro
          { roReport = \x -> putMVar outVar (False, x) >> roReport ro x
          }
        pullOne = gate >> atomically (readTBMQueue inQueue)
    in  optoPar streamingOpts po pullOne x0 o
{-# INLINE optoConduitPar #-}
-- | Chunked parallel conduit runner.  Each thread reads up to @n@ items
-- from the upstream queue (each read passing the pull gate; reads after
-- the queue is closed are dropped), and every reported value is sent
-- downstream as well as to the user's 'roReport'.
optoConduitParChunk
    :: forall m r a. MonadUnliftIO m
    => RunOpts m a
    -> ParallelOpts m a
    -> a
    -> Opto (StateT [r] m) r a
    -> ConduitT () r m ()
    -> ConduitT () a m ()
optoConduitParChunk ro po x0 o = optoConduitPar_ ro po $ \gate inQueue outVar ->
    let streamingOpts = ro
          { roReport = \x -> putMVar outVar (False, x) >> roReport ro x
          }
        pullOne     = gate >> atomically (readTBMQueue inQueue)
        pullChunk k = catMaybes <$> replicateM k pullOne
    in  optoParChunk streamingOpts po pullChunk x0 o
{-# INLINE optoConduitParChunk #-}
-- | Shared plumbing for the conduit-based parallel runners.
--
-- Creates a bounded 'TBMQueue' of inputs with capacity
-- @threads * 'poSplit'@, capped at 'roLimit' when one is given. It then
-- forks one thread that pumps @src@ into that queue and a second thread
-- that runs @runner@. Values the runner puts into the output 'MVar' are
-- yielded downstream; a 'True' flag marks the runner's final value, after
-- which the stream ends.
--
-- When 'poPull' is set, a semaphore 'MVar' gates sampling. The gate action
-- handed to @runner@ ('readMVar') blocks until the downstream loop has put
-- @()@ into it, which happens while the loop waits for the next output. The
-- loop empties the gate again before yielding.
--
-- NOTE(review): both forked threads are fire-and-forget. Nothing here kills
-- them if downstream stops consuming early. If @runner@ throws, the final
-- 'putMVar' never happens, and the 'takeMVar' in the loop would presumably
-- block (or die with BlockedIndefinitelyOnMVar). Consider scoping the
-- threads with async-style combinators — confirm intended behaviour.
optoConduitPar_
    :: forall m r a. MonadUnliftIO m
    => RunOpts m a
    -> ParallelOpts m a
    -> (m () -> TBMQueue r -> MVar (Bool, a) -> m a)
    -> ConduitT () r m ()
    -> ConduitT () a m ()
optoConduitPar_ ro po runner src = do
    n <- lift . maybe getNumCapabilities pure . poThreads $ po
    -- Queue capacity: one round's worth for every thread, but never more
    -- than the total iteration limit.
    let buff0 = n * poSplit po
        buff  = fromIntegral . maybe buff0 (min buff0) $ roLimit ro
    inQueue <- atomically $ newTBMQueue buff
    outVar  <- newEmptyMVar
    -- The pull gate exists only when poPull is set ('Nothing' otherwise).
    sem     <- forM (guard @Maybe (poPull po)) $ \_ -> newEmptyMVar @_ @()
    lift $ do
      -- Producer. Presumably 'sinkTBMQueue' closes the queue when @src@
      -- ends, so readers eventually see 'Nothing' — TODO confirm for the
      -- stm-conduit version in use.
      void . forkIO $ runConduit (src .| sinkTBMQueue inQueue)
      -- Optimizer: its final value is tagged 'True' to end the stream.
      void . forkIO $ do
        x <- runner (mapM_ readMVar sem) inQueue outVar
        putMVar outVar (True, x)
    let loop = do
          -- Open the gate (if any) while waiting for the next output...
          mapM_ (`putMVar` ()) sem
          (done, r) <- takeMVar outVar
          -- ...and close it again before yielding downstream.
          mapM_ takeMVar sem
          C.yield r
          unless done loop
    loop
{-# INLINE optoConduitPar_ #-}
-- | Arithmetic mean of a non-empty container, computed in a single pass by
-- accumulating the running total and element count together.
mean :: (Foldable1 t, Fractional a) => t a -> a
mean xs = case foldMap1 (\x -> Sum2 x 1) xs of
    Sum2 total count -> total / fromInteger count
{-# INLINE mean #-}
-- | A strict pair of running sums, combined component-wise.  Used by
-- 'mean' to track a total and an element count in one traversal.
data Sum2 a b = Sum2 !a !b

instance (Num a, Num b) => Semigroup (Sum2 a b) where
    (<>) (Sum2 a1 b1) (Sum2 a2 b2) = Sum2 (a1 + a2) (b1 + b2)
    {-# INLINE (<>) #-}

instance (Num a, Num b) => Monoid (Sum2 a b) where
    mempty = Sum2 0 0
    {-# INLINE mempty #-}
    mappend = (<>)
    {-# INLINE mappend #-}