{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
module Backprop.Learn.Run (
consecutives
, consecutivesN
, leadings
, conduitModel, conduitModelStoch
, oneHot', oneHot, oneHotR
, SVG.maxIndex, maxIndexR
) where
import Backprop.Learn.Model
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.Trans.Class
import Control.Monad.Trans.Maybe
import Data.Bool
import Data.Conduit
import Data.Finite
import Data.Foldable
import Data.Proxy
import Data.Type.Functor.Product
import GHC.TypeNats
import Numeric.LinearAlgebra.Static
import Numeric.LinearAlgebra.Static.Vector
import qualified Data.Conduit.Combinators as C
import qualified Data.Sequence as Seq
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Sized as SVG
import qualified System.Random.MWC as MWC
-- | Pair every input with the input that follows it.
--
-- For an upstream of @x0, x1, x2, ...@ this yields @(x0, x1)@, @(x1, x2)@,
-- and so on. Fewer than two inputs produce no output at all. The conduit
-- returns as soon as upstream is exhausted.
consecutives :: Monad m => ConduitT i (i, i) m ()
consecutives = await >>= maybe (return ()) startFrom
  where
    -- 'prev' is the most recent input. The recursive call is in tail
    -- position, so long streams do not accumulate pending continuations.
    startFrom prev = do
      next <- await
      case next of
        Nothing  -> return ()
        Just cur -> do
          yield (prev, cur)
          startFrom cur
-- | Slide a length-@n@ window over the stream.
--
-- Each output pairs the current window with the window advanced by one
-- input. Both are converted to sized vectors; a conversion that fails
-- (returns 'Nothing') is simply dropped by 'C.concatMap'.
consecutivesN
    :: forall v n i m. (KnownNat n, VG.Vector v i, Monad m)
    => ConduitT i (SVG.Vector v n i, SVG.Vector v n i) m ()
consecutivesN = conseq windowLen .| C.concatMap toVectors
  where
    -- Window size taken from the type-level length of the output vectors.
    windowLen = fromIntegral (natVal (Proxy @n))
    toVectors (window, shifted, _) = do
      before <- SVG.fromList (toList window)
      after  <- SVG.fromList (toList shifted)
      pure (before, after)
-- | Pair each input with the @n@ inputs that immediately preceded it.
--
-- The preceding inputs are packed into a sized vector (oldest first). Inputs
-- arriving before @n@ predecessors have been seen produce no output.
leadings
    :: forall v n i m. (KnownNat n, VG.Vector v i, Monad m)
    => ConduitT i (SVG.Vector v n i, i) m ()
leadings = conseq windowLen .| C.concatMap withTarget
  where
    -- Window size taken from the type-level length of the output vector.
    windowLen = fromIntegral (natVal (Proxy @n))
    withTarget (window, _, target) =
      fmap (\history -> (history, target)) (SVG.fromList (toList window))
-- | Slide a window of @n@ elements over the input stream.
--
-- The first @n@ inputs only fill the initial window; nothing is yielded for
-- them. After that, every further input @y@ yields @(xs, ys, y)@, where:
--
-- * @xs@ is the current window, i.e. the @n@ inputs just before @y@;
-- * @ys@ is that window advanced by one (oldest element dropped, @y@
--   appended);
-- * @y@ is the new input.
--
-- The conduit returns as soon as upstream is exhausted, including when that
-- happens before the initial window is full. With @n = 0@ every input yields
-- @(empty, empty, y)@.
conseq
    :: forall i m. Monad m
    => Int
    -> ConduitT i (Seq.Seq i, Seq.Seq i, i) m ()
conseq n = void . runMaybeT $ do
    xs <- Seq.replicateM n $ MaybeT await
    go xs
  where
    go xs = do
      y <- MaybeT await
      -- Append first, then drop the oldest element. For a non-empty window
      -- this equals @tail xs |> y@. Unlike a @_ :<| xs'@ pattern match it
      -- also works for an empty window (n == 0). The pattern-match version
      -- would hit MaybeT's 'fail' and silently end the stream.
      let ys = Seq.drop 1 (xs Seq.:|> y)
      lift $ yield (xs, ys, y)
      go ys
-- | Run a model over a stream of inputs, threading its state from step to step.
--
-- Each input @x@ is fed to @'runModel' f p x s@. The output is yielded
-- downstream and the returned state is used for the next input. When
-- upstream ends, the latest state is returned. @p@ is passed unchanged to
-- every step.
conduitModel
    :: (Backprop b, AllConstrainedProd Backprop s, Monad m)
    => Model p s a b
    -> TMaybe p
    -> TMaybe s
    -> ConduitT a b m (TMaybe s)
conduitModel f p = loop
  where
    loop s = await >>= maybe (return s) (advance s)
    advance s x =
      let (y, s') = runModel f p x s
      in  yield y >> loop s'
-- | Stochastic counterpart of 'conduitModel'.
--
-- Each input is fed to 'runModelStoch' in the underlying 'PrimMonad'. The
-- same generator @g@ is shared by every step. Outputs are yielded
-- downstream, and the final state is returned once upstream ends.
conduitModelStoch
    :: (Backprop b, AllConstrainedProd Backprop s, PrimMonad m)
    => Model p s a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p
    -> TMaybe s
    -> ConduitT a b m (TMaybe s)
conduitModelStoch f g p = loop
  where
    loop s = await >>= maybe (return s) (advance s)
    advance s x = do
      (y, s') <- lift $ runModelStoch f g p x s
      yield y
      loop s'
-- | Build a length-@n@ vector holding @hot@ at index @i@ and @nothot@
-- everywhere else.
oneHot'
    :: (VG.Vector v a, KnownNat n)
    => a            -- ^ value for every other position
    -> a            -- ^ value at the selected index
    -> Finite n     -- ^ selected index
    -> SVG.Vector v n a
oneHot' nothot hot i = SVG.generate $ \j -> if j == i then hot else nothot
-- | Numeric one-hot vector: @1@ at the given index, @0@ elsewhere.
oneHot
    :: (VG.Vector v a, KnownNat n, Num a)
    => Finite n
    -> SVG.Vector v n a
oneHot i = oneHot' 0 1 i
-- | One-hot vector as an hmatrix static 'R' vector.
oneHotR :: KnownNat n => Finite n -> R n
oneHotR i = vecR (oneHot i)
-- | Index of the largest entry of an hmatrix static vector.
--
-- The length is written @n + 1@ so the vector is statically non-empty, as
-- 'SVG.maxIndex' requires.
maxIndexR :: KnownNat n => R (n + 1) -> Finite (n + 1)
maxIndexR v = SVG.maxIndex (rVec v)