{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Numeric.Opto.Run.Conduit (
shuffling
, shufflingN
, sinkSampleReservoir
, samplingN
, skipSampling
) where
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.Trans.Class
import Control.Monad.Trans.Maybe
import Data.Conduit
import Data.Foldable
import Data.Maybe
import qualified Data.Conduit.Combinators as C
import qualified Data.Vector as V
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable as VG
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as MWC
-- | Collect the entire upstream into a vector, shuffle it with
-- 'MWC.uniformShuffle', and re-emit the elements in shuffled order.
--
-- Nothing is yielded until upstream has finished, and every element is
-- held in memory at once.
shuffling
    :: PrimMonad m
    => MWC.Gen (PrimState m)
    -> ConduitT a a m ()
shuffling g =
    C.sinkVector @V.Vector >>= (`MWC.uniformShuffle` g) >>= C.yieldMany
-- | Gather (at most) the next @n@ upstream elements into a vector, shuffle
-- them with 'MWC.uniformShuffle', and re-emit them in shuffled order.
--
-- Only this one chunk is processed; anything after the first @n@ elements
-- is not consumed here.
shufflingN
    :: PrimMonad m
    => Int                      -- ^ @n@, maximum chunk size
    -> MWC.Gen (PrimState m)
    -> ConduitT a a m ()
shufflingN n g = do
    chunk    <- C.sinkVectorN @V.Vector n
    shuffled <- MWC.uniformShuffle chunk g
    C.yieldMany shuffled
-- | Draw a random sample of (at most) @k@ elements from the incoming
-- stream using reservoir sampling (Algorithm R).
--
-- The whole upstream is consumed.  The first @k@ elements fill the
-- reservoir; for every later element with 1-based position @i > k@, an
-- index @j@ is drawn uniformly from @[1, i]@ and, when @j <= k@, that
-- element overwrites reservoir slot @j - 1@ (i.e. it is kept with
-- probability @k / i@).  If the stream yields fewer than @k@ elements, all
-- of them are returned.
--
-- The /set/ of returned elements is the sample; their order within the
-- result is not randomised (a stream of exactly @k@ elements comes back
-- in input order).
sinkSampleReservoir
    :: forall m v a o. (PrimMonad m, VG.Vector v a)
    => Int                      -- ^ @k@, maximum sample size
    -> MWC.Gen (PrimState m)    -- ^ random generator
    -> ConduitT a o m (v a)
sinkSampleReservoir k g = do
    -- A short stream leaves trailing 'Nothing's, which 'catMaybes' drops.
    initial <- catMaybes <$> replicateM k await
    -- Real reservoir size: @k@, or fewer for a short stream.  Replacement
    -- writes are bounded by this instead of by @k@, so 'VG.unsafeWrite'
    -- stays in bounds regardless of how upstream behaves after ending.
    let filled = length initial
    xs <- VG.thaw (VG.fromList initial)
    void . runMaybeT . for_ [k+1 ..] $ \i -> do
        x <- MaybeT await       -- leave the loop at end of stream
        lift . lift $ do
            j <- MWC.uniformR (1, i) g
            when (j <= filled) $
                VG.unsafeWrite xs (j - 1) x
    lift $ VG.freeze xs
-- | Consume the whole upstream, keep a random sample of at most @k@
-- elements using 'sinkSampleReservoir', then yield the sampled elements
-- downstream.
samplingN
    :: PrimMonad m
    => Int                      -- ^ @k@, maximum sample size
    -> MWC.Gen (PrimState m)
    -> ConduitT a a m ()
samplingN k g =
    sinkSampleReservoir @_ @V.Vector k g >>= C.yieldMany
-- | Keep each incoming element independently with probability @λ@ and
-- drop the rest.
--
-- Instead of flipping a coin per element, this draws the length of each
-- run of rejected elements from a geometric distribution
-- ('MWC.geometric0') and skips it with 'C.drop'.  That costs one random
-- draw per /kept/ element.
--
-- @λ = 0@ keeps nothing: the stream is drained without drawing any random
-- numbers.  Every other value is passed straight to 'MWC.geometric0'.
-- NOTE(review): that function is expected to reject values outside
-- @(0, 1]@; confirm against the mwc-random version in use.
skipSampling
    :: PrimMonad m
    => Double                   -- ^ @λ@, probability of keeping an element
    -> MWC.Gen (PrimState m)
    -> ConduitT a a m ()
skipSampling λ g
    -- 'MWC.geometric0' does not accept a success probability of 0.  The
    -- limiting behaviour is an infinitely long skip, so discard everything.
    | λ == 0    = C.sinkNull
    | otherwise = go
  where
    go = do
        n <- MWC.geometric0 λ g
        C.drop n
        await >>= maybe (return ()) (\x -> yield x >> go)