{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeInType            #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ViewPatterns          #-}
module Backprop.Learn.Model.Stochastic (
    dropout
  , rreLU
  , injectNoise, applyNoise
  , injectNoiseR, applyNoiseR
  ) where
import           Backprop.Learn.Model.Function
import           Backprop.Learn.Model.Types
import           Control.Monad.Primitive
import           Data.Bool
import           GHC.TypeNats
import           Numeric.Backprop
import           Numeric.LinearAlgebra.Static.Backprop
import           Numeric.LinearAlgebra.Static.Vector
import qualified Data.Vector.Storable.Sized            as SVS
import qualified Statistics.Distribution               as Stat
import qualified System.Random.MWC                     as MWC
import qualified System.Random.MWC.Distributions       as MWC
-- | Dropout with drop probability @r@.
--
-- * Stochastic mode: a mask is sampled with one independent Bernoulli
--   trial per component.  A component is zeroed when its trial succeeds
--   (probability @r@) and passed through unchanged otherwise.
-- * Deterministic mode: the whole input is scaled by @1 - r@, which is the
--   expected value of a single mask entry.
--
-- NOTE(review): @r@ is not range-checked; values outside @[0, 1]@ make the
-- two modes disagree.
dropout
    :: KnownNat n
    => Double
    -> Model 'Nothing 'Nothing (R n) (R n)
dropout r = Func
    { runFunc      = \x -> constVar (realToFrac keepProb) * x
    , runFuncStoch = \g x ->
        (\m -> x * constVar (vecR m)) <$> SVS.replicateM (sampleMask g)
    }
  where
    -- Probability that a single component survives.
    keepProb :: Double
    keepProb = 1 - r
    -- One mask entry: 0 when the component is dropped, 1 when it is kept.
    sampleMask :: PrimMonad m => MWC.Gen (PrimState m) -> m Double
    sampleMask gen = bool 1 0 <$> MWC.bernoulli r gen
-- | Randomized leaky-ReLU-style activation built on 'preLU'.
--
-- * Stochastic mode: a separate coefficient is drawn from @d@ for every
--   component of the input.  Each component is passed through 'preLU'
--   with its own coefficient.
-- * Deterministic mode: every component is passed through 'preLU' with
--   the mean of @d@ ('Stat.mean') as a shared coefficient.
--
-- NOTE(review): the meaning of the coefficient is whatever 'preLU' defines
-- (presumably the negative-side slope) — confirm in
-- "Backprop.Learn.Model.Function".
rreLU
    :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
    => d
    -> Model 'Nothing 'Nothing (R n) (R n)
rreLU d = Func
    { runFunc      = vmap' (preLU (auto (Stat.mean d)))
    , runFuncStoch = \g x ->
        let withCoeffs cs = zipWithVector preLU (auto cs) x
        in  withCoeffs . vecR <$> SVS.replicateM (Stat.genContVar d g)
    }
-- | Additive noise on a single scalar-like value.
--
-- * Stochastic mode: one sample drawn from @d@ is added to the input.
-- * Deterministic mode: the mean of @d@ ('Stat.mean') is added instead.
injectNoise
    :: (Stat.ContGen d, Stat.Mean d, Fractional a)
    => d
    -> Model 'Nothing 'Nothing a a
injectNoise d = Func
    { runFunc      = \x -> realToFrac (Stat.mean d) + x
    , runFuncStoch = \g x -> (+ x) . realToFrac <$> Stat.genContVar d g
    }
-- | Component-wise additive noise on a vector.
--
-- * Stochastic mode: an independent sample from @d@ is drawn for every
--   component, and the resulting vector is added to the input.
-- * Deterministic mode: the mean of @d@ ('Stat.mean') is added to the
--   input.
injectNoiseR
    :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
    => d
    -> Model 'Nothing 'Nothing (R n) (R n)
injectNoiseR d = Func
    { runFunc      = \x -> realToFrac (Stat.mean d) + x
    , runFuncStoch = \g x ->
        (+ x) . auto . vecR <$> SVS.replicateM (Stat.genContVar d g)
    }
-- | Multiplicative noise on a single scalar-like value.
--
-- * Stochastic mode: the input is multiplied by one sample drawn from @d@.
-- * Deterministic mode: the input is multiplied by the mean of @d@
--   ('Stat.mean').
applyNoise
    :: (Stat.ContGen d, Stat.Mean d, Fractional a)
    => d
    -> Model 'Nothing 'Nothing a a
applyNoise d = Func
    { runFunc      = \x -> realToFrac (Stat.mean d) * x
    , runFuncStoch = \g x ->
        Stat.genContVar d g >>= \sample -> pure (realToFrac sample * x)
    }
-- | Component-wise multiplicative noise on a vector.
--
-- * Stochastic mode: an independent sample from @d@ is drawn for every
--   component, and the input is multiplied component-wise by that vector.
-- * Deterministic mode: the input is multiplied by the mean of @d@
--   ('Stat.mean').
applyNoiseR
    :: (Stat.ContGen d, Stat.Mean d, KnownNat n)
    => d
    -> Model 'Nothing 'Nothing (R n) (R n)
applyNoiseR d = Func
    { runFunc      = \x -> realToFrac (Stat.mean d) * x
    , runFuncStoch = \g x -> do
        samples <- SVS.replicateM (Stat.genContVar d g)
        let factors = auto (vecR samples)
        pure (factors * x)
    }