{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE KindSignatures        #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeInType            #-}
{-# LANGUAGE UndecidableInstances  #-}
module Backprop.Learn.Model.Parameter (
    deParam, deParamD
  , reParam, reParamD
  , dummyParam
  ) where
import           Backprop.Learn.Model.Types
import           Control.Monad.Primitive
import           Numeric.Backprop
import qualified System.Random.MWC          as MWC
-- | Freeze one part of a model's parameter so that only the other part is
-- exposed (and trained).
--
-- The wrapped model expects a parameter of type @pq@ that can be split into
-- a @p@ and a @q@. The resulting model only takes the @p@ part. The @q@
-- part is supplied from outside and lifted with 'auto'. It is therefore a
-- constant, and no gradient is propagated back into it.
deParam
    :: forall p q pq s a b. (Backprop p, Backprop q)
    => (pq -> (p, q))
       -- ^ split a full parameter into its two parts
    -> (p -> q -> pq)
       -- ^ join the two parts back into a full parameter
    -> q
       -- ^ fixed @q@ used by deterministic ('runLearn') runs
    -> (forall m. (PrimMonad m) => MWC.Gen (PrimState m) -> m q)
       -- ^ draws the @q@ used by each stochastic ('runLearnStoch') run
    -> Model ('Just pq) s a b
    -> Model ('Just p )  s a b
deParam splitter joiner qFixed qSample = reParam deterministic stochastic
  where
    -- Pair the trainable @p@ with a constant @q@ to rebuild the full parameter.
    combine :: Reifies z W => q -> BVar z p -> BVar z pq
    combine qVal pVar = isoVar2 joiner splitter pVar (auto qVal)
    -- Deterministic path: always uses the caller-supplied fixed @q@.
    deterministic
        :: Reifies z W
        => PMaybe (BVar z) ('Just p)
        -> PMaybe (BVar z) ('Just pq)
    deterministic = PJust . combine qFixed . fromPJust
    -- Stochastic path: draws a fresh @q@ from the generator on every call.
    stochastic
        :: (PrimMonad m, Reifies z W)
        => MWC.Gen (PrimState m)
        -> PMaybe (BVar z) ('Just p)
        -> m (PMaybe (BVar z) ('Just pq))
    stochastic gen pm = do
        qVal <- qSample gen
        pure (PJust (combine qVal (fromPJust pm)))
-- | Variant of 'deParam' in which the frozen @q@ is identical for
-- deterministic and stochastic runs. The random generator is ignored, and
-- the given @q@ is returned every time.
deParamD
    :: (Backprop p, Backprop q)
    => (pq -> (p, q))
       -- ^ split a full parameter into its two parts
    -> (p -> q -> pq)
       -- ^ join the two parts back into a full parameter
    -> q
       -- ^ fixed @q@ used for every run
    -> Model ('Just pq) s a b
    -> Model ('Just p )  s a b
deParamD splitter joiner qFixed =
    deParam splitter joiner qFixed (\_ -> pure qFixed)
-- | Change the parameter type of a model by converting the new parameter
-- into the one the wrapped model expects.
--
-- 'runLearn' uses the first conversion. 'runLearnStoch' uses the second
-- conversion, which may consume randomness from the same generator before
-- it is handed on to the wrapped model. Both conversions act on 'BVar's,
-- so they take part in backpropagation.
reParam
    :: (forall z. Reifies z W => PMaybe (BVar z) q -> PMaybe (BVar z) p)
       -- ^ deterministic conversion from the new parameter to the old one
    -> (forall m z. (PrimMonad m, Reifies z W)
            => MWC.Gen (PrimState m)
            -> PMaybe (BVar z) q
            -> m (PMaybe (BVar z) p)
       )
       -- ^ stochastic conversion, which may draw from the generator
    -> Model p s a b
       -- ^ model to adapt
    -> Model q s a b
reParam r rStoch f = Model
    { runLearn      = \newParam -> runLearn f (r newParam)
    , runLearnStoch = \gen newParam x s ->
        rStoch gen newParam >>= \oldParam ->
          runLearnStoch f gen oldParam x s
    }
-- | Variant of 'reParam' that uses one deterministic conversion for both
-- 'runLearn' and 'runLearnStoch'. In the stochastic case the random
-- generator is ignored.
reParamD
    :: (forall z. Reifies z W => PMaybe (BVar z) q -> PMaybe (BVar z) p)
       -- ^ conversion from the new parameter to the old one
    -> Model p s a b
    -> Model q s a b
reParamD r = reParam r (\_ newParam -> pure (r newParam))
-- | Let a parameterless model accept a parameter of any type.
--
-- The parameter given to the resulting model is discarded, and the wrapped
-- model is always run with 'PNothing'.
dummyParam
    :: Model 'Nothing  s a b
    -> Model p         s a b
dummyParam = reParamD (\_ -> PNothing)