{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Backprop.Learn.Model.Stochastic (
dropout
, rreLU
, injectNoise, applyNoise
, injectNoiseR, applyNoiseR
) where
import Backprop.Learn.Model.Function
import Backprop.Learn.Model.Types
import Control.Monad.Primitive
import Data.Bool
import GHC.TypeNats
import Numeric.Backprop
import Numeric.LinearAlgebra.Static.Backprop
import Numeric.LinearAlgebra.Static.Vector
import qualified Data.Vector.Storable.Sized as SVS
import qualified Statistics.Distribution as Stat
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as MWC
-- | Dropout layer, parameterized by the dropout rate @r@.
--
-- * 'runFunc' multiplies the input by the constant @1 - r@.
-- * 'runFuncStoch' draws one Bernoulli sample per vector component and
--   multiplies the input componentwise by the resulting 0/1 mask; a
--   component is zeroed when its sample is 'True' (which 'MWC.bernoulli'
--   yields with probability @r@) and left unchanged otherwise.
--
-- NOTE(review): @r@ is not range-checked here; presumably callers pass a
-- value in @[0, 1]@ -- confirm.
dropout
    :: KnownNat n
    => Double
    -> Model 'Nothing 'Nothing (R n) (R n)
dropout r = Func
    { runFunc      = \x -> auto (realToFrac keepProb) * x
    , runFuncStoch = \g x -> do
        keep <- SVS.replicateM (sampleKeep g)
        pure (x * auto (vecR keep))
    }
  where
    -- Constant scale factor used by 'runFunc'.
    keepProb :: Double
    keepProb = 1 - r

    -- One mask entry: 0 when the Bernoulli draw is 'True', 1 otherwise.
    sampleKeep :: PrimMonad m => MWC.Gen (PrimState m) -> m Double
    sampleKeep g = do
        dropped <- MWC.bernoulli r g
        pure (if dropped then 0 else 1)
-- | Layer that applies 'preLU' with a coefficient taken from the
-- distribution @d@ (presumably a randomized leaky ReLU -- the exact
-- activation is whatever 'preLU' computes).
--
-- * 'runFunc' maps @'preLU' ('Stat.mean' d)@ over the input with 'vmap''.
-- * 'runFuncStoch' samples one coefficient per vector component from @d@
--   and combines the (constant) coefficient vector with the input through
--   @'zipWithVector' 'preLU'@.
rreLU
    :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
    => d
    -> Model 'Nothing 'Nothing (R n) (R n)
rreLU d = Func
    { runFunc      = vmap' (preLU meanCoeff)
    , runFuncStoch = \g x ->
        (\coeffs -> zipWithVector preLU (constVar (vecR coeffs)) x)
          <$> SVS.replicateM (Stat.genContVar d g)
    }
  where
    -- The signature keeps this binding polymorphic in @s@, so it can be
    -- used inside the function handed to 'vmap''.
    meanCoeff :: BVar s Double
    meanCoeff = auto (Stat.mean d)
-- | Additive noise drawn from the distribution @d@.
--
-- * 'runFunc' adds the mean of @d@ to the input.
-- * 'runFuncStoch' draws a single sample from @d@ and adds it to the
--   input.
injectNoise
    :: (Stat.ContGen d, Stat.Mean d, Fractional a)
    => d
    -> Model 'Nothing 'Nothing a a
injectNoise d = Func
    { runFunc      = \x -> realToFrac (Stat.mean d) + x
    , runFuncStoch = \g x ->
        (+ x) . realToFrac <$> Stat.genContVar d g
    }
-- | Additive noise on an @'R' n@ vector, with an independent sample per
-- component.
--
-- * 'runFunc' adds the mean of @d@ to the input.
-- * 'runFuncStoch' draws one sample from @d@ for each component, and adds
--   the resulting vector (lifted as a constant) to the input.
injectNoiseR
    :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
    => d
    -> Model 'Nothing 'Nothing (R n) (R n)
injectNoiseR d = Func
    { runFunc      = \x -> realToFrac (Stat.mean d) + x
    , runFuncStoch = \g x -> do
        noise <- SVS.replicateM (Stat.genContVar d g)
        pure (constVar (vecR noise) + x)
    }
-- | Multiplicative noise drawn from the distribution @d@.
--
-- * 'runFunc' multiplies the input by the mean of @d@.
-- * 'runFuncStoch' draws a single sample from @d@ and multiplies the
--   input by it.
applyNoise
    :: (Stat.ContGen d, Stat.Mean d, Fractional a)
    => d
    -> Model 'Nothing 'Nothing a a
applyNoise d = Func
    { runFunc      = \x -> realToFrac (Stat.mean d) * x
    , runFuncStoch = \g x -> do
        factor <- realToFrac <$> Stat.genContVar d g
        pure (factor * x)
    }
-- | Multiplicative noise on an @'R' n@ vector, with an independent sample
-- per component.
--
-- * 'runFunc' multiplies the input by the mean of @d@.
-- * 'runFuncStoch' draws one sample from @d@ for each component, and
--   multiplies the input componentwise by the resulting vector (lifted as
--   a constant).
applyNoiseR
    :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
    => d
    -> Model 'Nothing 'Nothing (R n) (R n)
applyNoiseR d = Func
    { runFunc      = \x -> realToFrac (Stat.mean d) * x
    , runFuncStoch = \g x ->
        (\factors -> constVar (vecR factors) * x)
          <$> SVS.replicateM (Stat.genContVar d g)
    }