{-# LANGUAGE ApplicativeDo         #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeInType            #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}

module Backprop.Learn.Model.State (
  -- * To and from statelessness
    trainState, deState, deStateD, zeroState, dummyState
  -- * Manipulate model states
  , unroll, unrollFinal, recurrent
  ) where

import           Backprop.Learn.Model.Types
import           Control.Monad.Primitive
import           Control.Monad.Trans.State
import           Data.Bifunctor
import           Data.Foldable
import           Data.Type.Functor.Product
import           Data.Type.Tuple
import           Numeric.Backprop
import qualified System.Random.MWC          as MWC

-- | Make a model stateless by converting the state to a trained parameter,
-- and dropping the modified state from the result.
--
-- One of the ways to make a model stateless for training purposes. Useful
-- when used after 'unroll'. See 'deState', as well.
--
-- The parameter of the resulting model is:
--
-- * If the input model has no parameter, just the initial state.
-- * If the input model has a parameter, a ':#' of that parameter and the
--   initial state.
trainState
    :: forall p s a b.
     ( PureProd Maybe p
     , PureProd Maybe s
     , AllConstrainedProd Backprop p
     , AllConstrainedProd Backprop s
     )
    => Model p s a b
    -> Model (p :#? s) 'Nothing a b
trainState = withModelFunc $ \f (p :#? s) x n_ ->
    -- run with the trained initial state, then discard the updated state
    -- in favour of the (empty) state slot of the stateless model
    (second . const) n_ <$> f p x s

-- | Make a model stateless by pre-applying a fixed state (or a stochastic
-- one with a fixed distribution) and dropping the modified state from the
-- result.
--
-- One of the ways to make a model stateless for training purposes. Useful
-- when used after 'unroll'. See 'trainState', as well.
deState
    :: s
       -- ^ state fed to the model in deterministic mode ('runLearn')
    -> (forall m. PrimMonad m => MWC.Gen (PrimState m) -> m s)
       -- ^ sampler for the state fed to the model in stochastic mode
       -- ('runLearnStoch'); run once per call
    -> Model p ('Just s) a b
    -> Model p 'Nothing a b
deState s sStoch f = Model
    { runLearn      = \p x n_ ->
        (second . const) n_ $ runLearn f p x (PJust (auto s))
    , runLearnStoch = \g p x n_ -> do
        s' <- sStoch g
        (second . const) n_ <$> runLearnStoch f g p x (PJust (auto s'))
    }

-- | 'deState', except the state is always the same even in stochastic
-- mode.
deStateD :: s -> Model p ('Just s) a b -> Model p 'Nothing a b
deStateD s = deState s (const (pure s))

-- | 'deState' with a constant state of 0.
zeroState :: Num s => Model p ('Just s) a b -> Model p 'Nothing a b
zeroState = deStateD 0

-- | Unroll a (usually) stateful model into one taking a 'Traversable'
-- container (such as a list or vector) of sequential inputs.
--
-- Basically applies the model to every item of input and returns all of
-- the results, but propagating the state between every step.
--
-- Useful when used before 'trainState' or 'deState'. See
-- 'unrollTrainState' and 'unrollDeState'.
--
-- Compare to 'feedbackTrace', which, instead of receiving a container of
-- sequential inputs, receives a single input and uses its output as the
-- next input.
unroll
    :: (Traversable t, Backprop a, Backprop b)
    => Model p s a b
    -> Model p s (t a) (t b)
unroll = withModelFunc $ \f p xs s ->
      (fmap . first) collectVar
    . flip runStateT s
    . traverse (StateT . f p)
    . sequenceVar
    $ xs

-- | Version of 'unroll' that only keeps the "final" result, dropping all
-- of the intermediate results.
--
-- Turns a stateful model into one that runs the model repeatedly on
-- multiple inputs sequentially and outputs the final result after seeing
-- all items.
--
-- Note: this is partial. Given an empty sequence, the returned state is
-- the initial state, and the output is an 'error' that is thrown only
-- when it is forced.
unrollFinal
    :: (Traversable t, Backprop a)
    => Model p s a b
    -> Model p s (t a) b
unrollFinal = withModelFunc $ \f p xs s0 ->
    foldlM (\(_, s) x -> f p x s)
           -- placeholder output; only ever observed if xs is empty
           ( error "Backprop.Learn.Model.State.unrollFinal: empty input sequence"
           , s0
           )
           (sequenceVar xs)

-- | Fix a part of the input of a model to be (a function of) the
-- /previous/ output of the model itself.
--
-- Essentially, takes a \( X \times Y \rightarrow Z \) into a /stateful/
-- \( X \rightarrow Z \), where the Y is given by a function of the
-- /previous output/ of the model.
--
-- This makes a model "recurrent": it receives its previous output as
-- input. The @b@ part of the input is carried as an extra piece of state,
-- appended to the original model's state.
--
-- See 'fcr' for an application.
recurrent
    :: forall p s ab a b c.
     ( AllConstrainedProd Backprop s
     , PureProd Maybe s
     , Backprop a
     , Backprop b
     )
    => (ab -> (a, b))
       -- ^ split: break the original model's input into the external
       -- input @a@ and the fed-back part @b@
    -> (a -> b -> ab)
       -- ^ join: rebuild the original model's input from @a@ and @b@
    -> BFunc c b
       -- ^ store state: compute the next step's @b@ from this step's output
    -> Model p s ab c
    -> Model p (s :#? 'Just b) a c
recurrent spl joi sto = withModelFunc $ \f p x (s :#? y) -> do
    (z, s') <- f p (isoVar2 joi spl x (fromPJust y)) s
    pure (z, s' :#? PJust (sto z))

-- | Give a stateless model a "dummy" state. For now, useful for using
-- with combinators like 'deState' that require state. However, 'deState'
-- could also be made more lenient (to accept non-stateful models) in the
-- future.
--
-- Also useful for usage with combinators like 'Control.Category..' from
-- "Control.Category" that require all input models to share a common
-- state.
--
-- The state passed in is returned unchanged; the wrapped model never
-- sees it.
dummyState
    :: forall s p a b.
       Model p 'Nothing a b
    -> Model p s a b
dummyState = withModelFunc $ \f p x s ->
    (second . const) s <$> f p x PNothing