{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeInType #-}
module Numeric.Opto.Run.Simple (
sampleCollect, trainReport
, dispEpoch, dispBatch
, simpleReport
, simpleRunner
, SimpleOpts(..)
, SOOptimizer(..)
) where
import Control.Concurrent hiding (yield)
import Control.Concurrent.STM
import Control.DeepSeq
import Control.Exception
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.IO.Unlift
import Control.Monad.Primitive
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Data.Conduit
import Data.Default
import Data.Functor
import Data.Kind
import Data.MonoTraversable
import Data.Time
import Numeric.Opto.Core
import Numeric.Opto.Run
import Numeric.Opto.Run.Conduit
import Text.Printf
import qualified Data.Conduit.Combinators as C
import qualified System.Random.MWC as MWC
-- | Print a @[Epoch n]@ banner line for the given epoch number.
dispEpoch :: MonadIO m => Int -> m ()
dispEpoch epoch = liftIO $ printf "[Epoch %d]\n" epoch
-- | Print a @(Batch n)@ banner line for the given batch number.
dispBatch :: MonadIO m => Int -> m ()
dispBatch batch = liftIO $ printf "(Batch %d)\n" batch
-- | Stream the samples in @train@ once per epoch, recording every emitted
-- sample in @sampleQueue@.
--
-- @n@ is the number of epochs to run ('Nothing' means no limit). At the start
-- of epoch @e@ (counting from 1), @report e@ is run. The epoch's samples are
-- then streamed through 'shuffling' using the generator @g@. Each sample that
-- comes out is written to @sampleQueue@ before being passed downstream, so
-- the queue's consumer can see which samples were drawn.
sampleCollect
    :: (PrimMonad m, MonadIO m, MonoFoldable t)
    => TBQueue (Element t)
    -> Maybe Int
    -> (Int -> m ())
    -> t
    -> MWC.Gen (PrimState m)
    -> ConduitT i (Element t) m ()
sampleCollect sampleQueue n report train g = epochs .| C.iterM enqueue
  where
    epochNums = maybe [1 ..] (\lastEpoch -> [1 .. lastEpoch]) n
    epochs = forM_ epochNums $ \e -> do
      lift (report e)
      C.yieldMany train .| shuffling g
    enqueue x = liftIO (atomically (writeTBQueue sampleQueue x))
-- | Print a short progress report to stdout.
--
-- The report contains three parts:
--
-- * how many samples were trained on and how long that took;
-- * the score of the model @net@ on those samples, as rendered by @testNet@;
-- * when a validation set is supplied, the model's score on that set.
simpleReport
    :: MonadIO m
    => Maybe [i]
    -> ([i] -> a -> String)
    -> NominalDiffTime
    -> [i]
    -> a
    -> m ()
simpleReport testSet testNet t chnk net = liftIO $ do
    printf "Trained on %d points in %s.\n" (length chnk) (show t)
    printf "Training>\t%s\n" (testNet chnk net)
    mapM_ (printf "Validation>\t%s\n" . (`testNet` net)) testSet
-- | Thin a stream of optimizer outputs and report on each item that is kept.
--
-- Of every @n@ items from upstream, only the last is kept; the rest are
-- dropped. A non-positive @n@ keeps every item. For each kept item, in order:
--
-- * @db@ is called with the 1-based batch number;
-- * the item is awaited and fully forced with 'force';
-- * the samples accumulated in @sampleQueue@ so far are flushed;
-- * @reportFunc@ receives the wall-clock time spent waiting for and forcing
--   the item, the flushed samples, and the item itself;
-- * the item is yielded downstream.
--
-- The conduit terminates once upstream is exhausted. Earlier versions looped
-- forever on the repeated 'Nothing' from 'await', so finite runs never
-- finished. Because @db@ runs before it is known whether another item will
-- arrive, one final batch banner is emitted without an accompanying report.
trainReport
    :: (MonadIO m, NFData a)
    => TBQueue i
    -> (Int -> m ())
    -> Int
    -> (NominalDiffTime -> [i] -> a -> m ())
    -> ConduitT a a m ()
trainReport sampleQueue db n reportFunc = thinned .| batch 1
  where
    -- Drop @n - 1@ items, pass the next one on, and repeat.
    -- Stop (instead of spinning) once upstream has no more items.
    thinned = do
      C.drop (n - 1)
      await >>= mapM_ (\x -> yield x *> thinned)
    -- Report on one kept item. Recurse only if an item actually arrived.
    batch b = do
      lift $ db b
      t0 <- liftIO getCurrentTime
      net' <- mapM (liftIO . evaluate . force) =<< await
      chnk <- liftIO . atomically $ flushTBQueue sampleQueue
      t1 <- liftIO getCurrentTime
      forM_ net' $ \net -> do
        lift $ reportFunc (t1 `diffUTCTime` t0) chnk net
        yield net
        batch (b + 1)
-- | Configuration for 'simpleRunner'.
data SimpleOpts m i a b = SO
    { soEpochs :: Maybe Int
      -- ^ Number of passes over the training samples.
      -- 'Nothing' means no limit.
    , soDispEpoch :: Int -> m ()
      -- ^ Called with the 1-based epoch number at the start of each epoch.
    , soDispBatch :: Int -> m ()
      -- ^ Called with the 1-based batch number before each progress report.
    , soTestSet :: Maybe [i]
      -- ^ Optional validation samples, scored with 'soEvaluate' in each
      -- report.
    , soSkipSamps :: Int
      -- ^ Approximate number of samples to train on between reports.
      -- Also used to size the internal sample queue.
    , soEvaluate :: [i] -> a -> String
      -- ^ Renders a score for a model on a list of samples. It is used for
      -- both the training chunk and the validation set.
    , soSink :: ConduitT a Void m b
      -- ^ Receives each reported model. Its result is what 'simpleRunner'
      -- returns.
    }
-- | Selects how 'simpleRunner' drives the optimizer.
--
-- The second index is the monad the 'Opto' runs in, which differs per
-- driver.
data SOOptimizer :: (Type -> Type) -> (Type -> Type) -> Type -> Type -> Type where
    -- | Drive the optimizer with 'optoConduit'.
    -- The 'Opto' runs in @ConduitT i a m@.
    SOSingle :: SOOptimizer m (ConduitT i a m) i a
    -- | Drive the optimizer with 'optoConduitPar' using the given
    -- 'ParallelOpts'. The 'Opto' runs directly in @m@.
    SOParallel :: MonadUnliftIO m => ParallelOpts m a -> SOOptimizer m m i a
    -- | Drive the optimizer with 'optoConduitParChunk' using the given
    -- 'ParallelOpts'. The 'Opto' runs in @StateT [i] m@.
    SOParChunked :: MonadUnliftIO m => ParallelOpts m a -> SOOptimizer m (StateT [i] m) i a
-- | Attach the selected optimizer driver to a sample source, producing a
-- stream of optimizer outputs.
--
-- * 'SOSingle' fuses 'optoConduit' downstream of the source.
-- * 'SOParallel' hands the source, along with its 'ParallelOpts', to
--   'optoConduitPar'.
-- * 'SOParChunked' hands the source, along with its 'ParallelOpts', to
--   'optoConduitParChunk'.
runSOOptimizer
    :: MonadIO m
    => SOOptimizer m n i a
    -> RunOpts m a
    -> a
    -> Opto n i a
    -> ConduitT () i m ()
    -> ConduitT () a m ()
runSOOptimizer soo ro x0 o src = case soo of
    SOSingle        -> src .| optoConduit ro x0 o
    SOParallel po   -> optoConduitPar ro po x0 o src
    SOParChunked po -> optoConduitParChunk ro po x0 o src
-- | Default options:
--
-- * no epoch limit and no validation set;
-- * a report roughly every 1000 samples;
-- * the stock epoch and batch banners;
-- * a placeholder evaluation string;
-- * a sink that discards every model and returns 'def'.
instance (MonadIO m, Default b) => Default (SimpleOpts m i a b) where
    def = SO { soEpochs    = Nothing
             , soTestSet   = Nothing
             , soSkipSamps = 1000
             , soDispEpoch = dispEpoch
             , soDispBatch = dispBatch
             , soEvaluate  = const (const "<unevaluated>")
             , soSink      = C.sinkNull $> def
             }
-- | Train a model on a collection of samples, printing progress reports.
--
-- Samples from @samps@ are streamed once per epoch through 'shuffling' (see
-- 'sampleCollect') into the optimizer driver chosen by the 'SOOptimizer'.
-- Roughly every 'soSkipSamps' samples, 'trainReport' does the following:
--
-- * forces the current model;
-- * reports on it with 'simpleReport', using 'soEvaluate' and 'soTestSet';
-- * passes it to 'soSink'.
--
-- The sink's result is returned once the run finishes.
simpleRunner
    :: forall m t a b n. (MonadIO m, PrimMonad m, NFData a, MonoFoldable t)
    => SimpleOpts m (Element t) a b
    -> t
    -> SOOptimizer m n (Element t) a
    -> RunOpts m a
    -> a
    -> Opto n (Element t) a
    -> MWC.Gen (PrimState m)
    -> m b
simpleRunner SO{..} samps soo ro x0 o g = do
    (skipAmt, queueCap) <- liftIO $ case soo of
      SOSingle        -> pure (soSkipSamps, soSkipSamps)
      SOParallel po   -> parallelSizes (poSplit po)
      SOParChunked po -> parallelSizes (poSplit po)
    -- The queue is drained only by trainReport, so it must hold every sample
    -- drawn between two reports. A non-positive capacity leaves no room at
    -- all: the very first writeTBQueue would block forever.
    sampleQueue <- liftIO . atomically $
      newTBQueue (fromIntegral (max 1 queueCap))
    let source    = sampleCollect sampleQueue soEpochs soDispEpoch samps g
        optimizer = runSOOptimizer soo ro x0 o source
        reporter  = simpleReport soTestSet soEvaluate
    runConduit $ optimizer
              .| trainReport sampleQueue soDispBatch skipAmt reporter
              .| soSink
  where
    -- Skip count and queue capacity for the parallel drivers.
    --
    -- Like the skip formula itself, this assumes each parallel output
    -- consumes about @caps * split@ samples. The "- 1" in the skip count
    -- leaves one such batch of headroom in the queue.
    --
    -- When soSkipSamps is below two batches, the skip count clamps to 0, so
    -- a report still spans a full batch plus that headroom. The queue is
    -- grown to fit, or the source would block on a full queue that nothing
    -- else drains. (NOTE(review): the per-output sample count is an
    -- assumption inherited from the skip formula; confirm against
    -- optoConduitPar / optoConduitParChunk.)
    --
    -- The divisor is kept positive so a zero split cannot divide by zero.
    parallelSizes s = getNumCapabilities <&> \caps ->
      let perOutput = max 1 (caps * s)
      in  ( max 0 ((soSkipSamps `div` perOutput) - 1)
          , max soSkipSamps (2 * perOutput)
          )