{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
module Backprop.Learn.Model.Parameter (
deParam, deParamD
, reParam, reParamD
, dummyParam
) where
import Backprop.Learn.Model.Types
import Control.Monad.Primitive
import Numeric.Backprop
import qualified System.Random.MWC as MWC
-- | Turn a model whose parameter is @pq@ into a model whose parameter is
-- only @p@, with the remaining @q@ portion supplied from outside.
--
-- The first two arguments convert between the combined parameter and its
-- two parts: one splits @pq@ into @(p, q)@, the other joins @p@ and @q@
-- back into @pq@.
--
-- In deterministic mode ('runLearn') the given @q@ value is used as-is;
-- in stochastic mode ('runLearnStoch') a fresh @q@ is drawn from the
-- supplied sampler on every run.  In both cases the @q@ part is injected
-- with 'auto'.
deParam
    :: forall p q pq s a b. (Backprop p, Backprop q)
    => (pq -> (p, q))
    -> (p -> q -> pq)
    -> q
    -> (forall m. (PrimMonad m) => MWC.Gen (PrimState m) -> m q)
    -> Model ('Just pq) s a b
    -> Model ('Just p ) s a b
deParam spl joi q qStoch =
    reParam
      (\mp   -> PJust (fillFixed (fromPJust mp)))
      (\g mp -> PJust <$> fillSampled g (fromPJust mp))
  where
    -- Rebuild the full parameter from the @p@ part and a given @q@ part.
    combine :: Reifies z W => BVar z p -> BVar z q -> BVar z pq
    combine = isoVar2 joi spl

    -- Deterministic case: pair the @p@ part with the fixed @q@.
    fillFixed :: Reifies z W => BVar z p -> BVar z pq
    fillFixed p = combine p (auto q)

    -- Stochastic case: pair the @p@ part with a freshly sampled @q@.
    fillSampled
        :: (PrimMonad m, Reifies z W)
        => MWC.Gen (PrimState m)
        -> BVar z p
        -> m (BVar z pq)
    fillSampled g p = (\sampled -> combine p (auto sampled)) <$> qStoch g
-- | Variant of 'deParam' without a separate sampler: the same fixed @q@
-- value is used in both deterministic and stochastic mode (the generator
-- is ignored).
deParamD
    :: (Backprop p, Backprop q)
    => (pq -> (p, q))
    -> (p -> q -> pq)
    -> q
    -> Model ('Just pq) s a b
    -> Model ('Just p ) s a b
deParamD spl joi q = deParam spl joi q (\_ -> pure q)
-- | Re-parameterize a model: given a way to turn the new parameter @q@
-- into the model's original parameter @p@, produce a model that takes @q@.
--
-- The first function is used for 'runLearn'; the second, which also
-- receives the random generator and runs in the monad, is used for
-- 'runLearnStoch'.  In both cases the converted parameter is simply handed
-- to the corresponding field of the wrapped model.
reParam
    :: (forall z. Reifies z W => PMaybe (BVar z) q -> PMaybe (BVar z) p)
    -> (forall m z. (PrimMonad m, Reifies z W) => MWC.Gen (PrimState m) -> PMaybe (BVar z) q -> m (PMaybe (BVar z) p))
    -> Model p s a b
    -> Model q s a b
reParam r rStoch f = Model
    { runLearn      = \outer -> runLearn f (r outer)
    , runLearnStoch = \g outer x s ->
        -- Convert the parameter first, then run the wrapped model with the
        -- same generator, input and state.
        rStoch g outer >>= \inner -> runLearnStoch f g inner x s
    }
-- | Variant of 'reParam' that uses the same pure conversion in both
-- deterministic and stochastic mode; the generator is ignored.
reParamD
    :: (forall z. Reifies z W => PMaybe (BVar z) q -> PMaybe (BVar z) p)
    -> Model p s a b
    -> Model q s a b
reParamD r = reParam r (\_ outer -> pure (r outer))
-- | Give a parameterless model an arbitrary parameter type @p@.  Whatever
-- parameter is passed in is discarded and the wrapped model always
-- receives 'PNothing'.
dummyParam
    :: Model 'Nothing s a b
    -> Model p        s a b
dummyParam f = reParamD (\_ -> PNothing) f