{-# LANGUAGE BangPatterns           #-}
{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE LambdaCase             #-}
{-# LANGUAGE PolyKinds              #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TupleSections          #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE TypeFamilyDependencies #-}
module Backprop.Learn.Model (
    module M, Backprop(..)
  
  , runModel, runModelStoch, runModelStateless, runModelStochStateless
  , gradModel, gradModelStoch
  
  , initParam, initParamNormal
  , encodeParam, decodeParam, decodeParamOrFail, saveParam, loadParam, loadParamOrFail
  
  , iterateModel, iterateModelM, iterateModelStoch
  , scanModel, scanModelStoch
  
  , iterateModel_, iterateModelM_, iterateModelStoch_
  , scanModel_, scanModelStoch_
  
  , primeModel, primeModelStoch, selfPrime, selfPrimeM
  ) where
import           Backprop.Learn.Initialize
import           Backprop.Learn.Model.Combinator  as M
import           Backprop.Learn.Model.Function    as M
import           Backprop.Learn.Model.Neural      as M
import           Backprop.Learn.Model.Neural.LSTM as M
import           Backprop.Learn.Model.Parameter   as M
import           Backprop.Learn.Model.Regression  as M
import           Backprop.Learn.Model.State       as M
import           Backprop.Learn.Model.Stochastic  as M
import           Backprop.Learn.Model.Types       as M
import           Control.Monad.Primitive
import           Control.Monad.ST
import           Control.Monad.Trans.State
import           Data.Bifunctor
import           Data.Foldable
import           Data.Functor.Identity
import           Data.Type.Functor.Product
import           Data.Type.Tuple
import           Data.Word
import           Numeric.Backprop
import           Statistics.Distribution
import qualified Data.Binary                      as Bi
import qualified Data.ByteString.Lazy             as BSL
import qualified Data.Vector.Unboxed              as VU
import qualified System.Random.MWC                as MWC
-- | Run a model once, purely, on a single input.
--
-- The parameters, the input and the incoming state are all lifted into
-- constant 'BVar's with 'auto', the model is run with 'runLearn', and the
-- result is read back out with 'evalBP0' (no gradients are computed).
runModel
    :: forall p s a b. (AllConstrainedProd Backprop s, Backprop b)
    => Model p s a b
    -> TMaybe p          -- ^ parameters ('PNothing' if the model has none)
    -> a                 -- ^ input
    -> TMaybe s          -- ^ incoming state ('PNothing' if the model has none)
    -> (b, TMaybe s)     -- ^ output paired with the updated state
runModel f mp x ms = evalBP0 result
  where
    result :: forall z. Reifies z W => BVar z (b, TMaybe s)
    result = withState stateOut
      where
        out      :: BVar z b
        stateOut :: PMaybe (BVar z) s
        (out, stateOut) = runLearn f (mapProd (auto . getTF) mp)
                                     (auto x)
                                     (mapProd (auto . getTF) ms)
        -- Pair the output with the new state.  Matching on the GADT
        -- refines @s@, so 'T2' is built separately in each branch.
        withState :: PMaybe (BVar z) s -> BVar z (b, TMaybe s)
        withState = \case
          PJust sNew -> T2 out (isoVar (PJust . TF) (getTF . fromPJust) sNew)
          PNothing   -> T2 out (auto PNothing)
-- | Run a stochastic model once on a single input.
--
-- Two 'Word32's are drawn from the supplied generator and used to seed a
-- fresh generator inside 'runST'.  That inner generator drives the
-- model's randomness, so the backprop computation itself stays pure.
runModelStoch
    :: forall p s a b m. (AllConstrainedProd Backprop s, Backprop b, PrimMonad m)
    => Model p s a b
    -> MWC.Gen (PrimState m)   -- ^ generator used only to draw the seed
    -> TMaybe p                -- ^ parameters ('PNothing' if the model has none)
    -> a                       -- ^ input
    -> TMaybe s                -- ^ incoming state ('PNothing' if the model has none)
    -> m (b, TMaybe s)         -- ^ output paired with the updated state
runModelStoch f g mp x ms =
    (\seed -> evalBP0 (runST (runSeeded seed)))
      <$> MWC.uniformVector @_ @Word32 @VU.Vector g 2
  where
    runSeeded :: forall q z. Reifies z W
              => VU.Vector Word32
              -> ST q (BVar z (b, TMaybe s))
    runSeeded seed = do
        gen <- MWC.initialize seed
        uncurry withState <$> runLearnStoch f gen
          (mapProd (auto . getTF) mp)
          (auto x)
          (mapProd (auto . getTF) ms)
      where
        -- Matching on the GADT refines @s@, so 'T2' is built per branch.
        withState :: BVar z b -> PMaybe (BVar z) s -> BVar z (b, TMaybe s)
        withState out = \case
          PJust sNew -> T2 out (isoVar (PJust . TF) (getTF . fromPJust) sNew)
          PNothing   -> T2 out (auto PNothing)
-- | Run a stateless model on a single input with 'evalBP' / 'evalBP2'.
--
-- When parameters are present they are passed as a second backprop input.
-- Otherwise only the input is lifted.
runModelStateless
    :: Model p 'Nothing a b
    -> TMaybe p          -- ^ parameters ('PNothing' if the model has none)
    -> a                 -- ^ input
    -> b                 -- ^ output
runModelStateless f mp = case mp of
    PJust (TF p) -> evalBP2 (\pv xv -> runLearnStateless f (PJust pv) xv) p
    PNothing     -> evalBP  (\xv -> runLearnStateless f PNothing xv)
-- | Run a stateless, stochastic model on a single input.
--
-- Two 'Word32's drawn from @g@ seed a fresh generator that is created
-- inside the function handed to 'evalBP' / 'evalBP2'.  This keeps the
-- randomness inside 'runST' and the function pure.
runModelStochStateless
    :: PrimMonad m
    => Model p 'Nothing a b
    -> MWC.Gen (PrimState m)   -- ^ generator used only to draw the seed
    -> TMaybe p                -- ^ parameters ('PNothing' if the model has none)
    -> a                       -- ^ input
    -> m b                     -- ^ output
runModelStochStateless f g mp x = do
    seed <- MWC.uniformVector @_ @Word32 @VU.Vector g 2
    pure $ case mp of
      PJust (TF p) -> evalBP2
        (\pv xv -> runST (MWC.initialize seed >>= \gen ->
                            runLearnStochStateless f gen (PJust pv) xv))
        p x
      PNothing     -> evalBP
        (\xv -> runST (MWC.initialize seed >>= \gen ->
                         runLearnStochStateless f gen PNothing xv))
        x
-- | Backpropagate through a stateless model.
--
-- Returns the gradients with respect to the parameters (re-wrapped as a
-- 'TMaybe') and with respect to the input, as computed by 'gradBP2'
-- (or 'gradBP' when there are no parameters).
gradModel
    :: (Backprop a, Backprop b, AllConstrainedProd Backprop p)
    => Model p 'Nothing a b
    -> TMaybe p          -- ^ parameters ('PNothing' if the model has none)
    -> a                 -- ^ input
    -> (TMaybe p, a)     -- ^ (parameter gradient, input gradient)
gradModel f mp = case mp of
    PJust (TF p) -> \x ->
      -- lazy pattern, like 'first' on a tuple
      let (dp, dx) = gradBP2 (\pv xv -> runLearnStateless f (PJust pv) xv) p x
      in  (PJust (TF dp), dx)
    PNothing     -> \x ->
      (PNothing, gradBP (\xv -> runLearnStateless f PNothing xv) x)
-- | Backpropagate through a stateless, stochastic model.
--
-- A seed of two 'Word32's is drawn from @g@.  The differentiated function
-- re-creates a generator from that seed inside 'runST', so the function
-- passed to 'gradBP' / 'gradBP2' is pure.
gradModelStoch
    :: (Backprop a, Backprop b, AllConstrainedProd Backprop p, PrimMonad m)
    => Model p 'Nothing a b
    -> MWC.Gen (PrimState m)   -- ^ generator used only to draw the seed
    -> TMaybe p                -- ^ parameters ('PNothing' if the model has none)
    -> a                       -- ^ input
    -> m (TMaybe p, a)         -- ^ (parameter gradient, input gradient)
gradModelStoch f g mp x = do
    seed <- MWC.uniformVector @_ @Word32 @VU.Vector g 2
    pure $ case mp of
      PJust (TF p) ->
        let (dp, dx) = gradBP2
              (\pv xv -> runST (MWC.initialize seed >>= \gen ->
                                  runLearnStochStateless f gen (PJust pv) xv))
              p x
        in  (PJust (TF dp), dx)
      PNothing ->
        ( PNothing
        , gradBP (\xv -> runST (MWC.initialize seed >>= \gen ->
                                  runLearnStochStateless f gen PNothing xv))
                 x
        )
-- | Feed a model's output back into itself, purely, a fixed number of
-- times, collecting every output and returning the final state.
--
-- This is 'iterateModelM' specialised to 'Identity'.  Step indices run
-- from 0 to @n@ inclusive, so @n + 1@ outputs are produced for @n >= 0@.
-- NOTE(review): confirm whether @n@ was intended as an exact step count.
iterateModel
    :: (Backprop b, AllConstrainedProd Backprop s)
    => (b -> a)         -- ^ turns each output into the next input
    -> Int              -- ^ last step index (see note above)
    -> Model p s a b
    -> TMaybe p         -- ^ parameters
    -> a                -- ^ first input
    -> TMaybe s         -- ^ initial state
    -> ([b], TMaybe s)  -- ^ outputs in order, and the final state
iterateModel l n f p x s = runIdentity (iterateModelM (pure . l) n f p x s)
-- | Feed a model's output back into itself forever, producing an infinite
-- lazy list of outputs.
--
-- The current input and state are forced to WHNF before each step runs.
iterateModel_
    :: (Backprop b, AllConstrainedProd Backprop s)
    => (b -> a)         -- ^ turns each output into the next input
    -> Model p s a b
    -> TMaybe p         -- ^ parameters
    -> a                -- ^ first input
    -> TMaybe s         -- ^ initial state
    -> [b]              -- ^ infinite stream of outputs
iterateModel_ l f p = step
  where
    step xCur sCur =
      xCur `seq` sCur `seq`
        let (yOut, sNext) = runModel f p xCur sCur
        in  yOut : step (l yOut) sNext
-- | Feed a model's output back into itself forever, producing an infinite
-- lazy list of the states reached after each step.
--
-- The current input and state are forced to WHNF before each step runs.
selfPrime
    :: (Backprop b, AllConstrainedProd Backprop s)
    => (b -> a)         -- ^ turns each output into the next input
    -> Model p s a b
    -> TMaybe p         -- ^ parameters
    -> a                -- ^ first input
    -> TMaybe s         -- ^ initial state
    -> [TMaybe s]       -- ^ infinite stream of post-step states
selfPrime l f p = step
  where
    step xCur sCur =
      xCur `seq` sCur `seq`
        let (yOut, sNext) = runModel f p xCur sCur
        in  sNext : step (l yOut) sNext
-- | Feed a model's output back into itself through a monadic callback,
-- collecting every output and returning the final state.
--
-- Step indices run from 0 to @n@ inclusive, so @n + 1@ steps run for
-- @n >= 0@ and none for negative @n@.
-- NOTE(review): confirm whether @n@ was intended as an exact step count.
iterateModelM
    :: (Backprop b, AllConstrainedProd Backprop s, Monad m)
    => (b -> m a)           -- ^ turns each output into the next input
    -> Int                  -- ^ last step index (see note above)
    -> Model p s a b
    -> TMaybe p             -- ^ parameters
    -> a                    -- ^ first input
    -> TMaybe s             -- ^ initial state
    -> m ([b], TMaybe s)    -- ^ outputs in order, and the final state
iterateModelM l n f p = loop 0
  where
    loop !k !xIn !sIn
      | k > n     = pure ([], sIn)
      | otherwise = do
          let (yOut, sNext) = runModel f p xIn sIn
          xNext          <- l yOut
          (rest, sFinal) <- loop (k + 1) xNext sNext
          pure (yOut : rest, sFinal)
-- | Like 'iterateModelM', but keeps only the list of outputs.
iterateModelM_
    :: (Backprop b, AllConstrainedProd Backprop s, Monad m)
    => (b -> m a)           -- ^ turns each output into the next input
    -> Int                  -- ^ last step index, as in 'iterateModelM'
    -> Model p s a b
    -> TMaybe p             -- ^ parameters
    -> a                    -- ^ first input
    -> TMaybe s             -- ^ initial state
    -> m [b]                -- ^ outputs in order
iterateModelM_ l n f p x s = fst <$> iterateModelM l n f p x s
-- | Like 'iterateModelM', but keeps only the final state.
selfPrimeM
    :: (Backprop b, AllConstrainedProd Backprop s, Monad m)
    => (b -> m a)           -- ^ turns each output into the next input
    -> Int                  -- ^ last step index, as in 'iterateModelM'
    -> Model p s a b
    -> TMaybe p             -- ^ parameters
    -> a                    -- ^ first input
    -> TMaybe s             -- ^ initial state
    -> m (TMaybe s)         -- ^ state after the last step
selfPrimeM l n f p x s = snd <$> iterateModelM l n f p x s
-- | Stochastic counterpart of 'iterateModelM'.
--
-- Each step runs 'runModelStoch' with the same generator @g@.  Step
-- indices run from 0 to @n@ inclusive, so @n + 1@ steps run for @n >= 0@.
-- NOTE(review): confirm whether @n@ was intended as an exact step count.
iterateModelStoch
    :: (Backprop b, AllConstrainedProd Backprop s, PrimMonad m)
    => (b -> m a)           -- ^ turns each output into the next input
    -> Int                  -- ^ last step index (see note above)
    -> Model p s a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p             -- ^ parameters
    -> a                    -- ^ first input
    -> TMaybe s             -- ^ initial state
    -> m ([b], TMaybe s)    -- ^ outputs in order, and the final state
iterateModelStoch l n f g p = loop 0
  where
    loop !k !xIn !sIn
      | k > n     = pure ([], sIn)
      | otherwise = do
          (yOut, sNext)  <- runModelStoch f g p xIn sIn
          xNext          <- l yOut
          (rest, sFinal) <- loop (k + 1) xNext sNext
          pure (yOut : rest, sFinal)
-- | Like 'iterateModelStoch', but keeps only the list of outputs.
iterateModelStoch_
    :: (Backprop b, AllConstrainedProd Backprop s, PrimMonad m)
    => (b -> m a)           -- ^ turns each output into the next input
    -> Int                  -- ^ last step index, as in 'iterateModelStoch'
    -> Model p s a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p             -- ^ parameters
    -> a                    -- ^ first input
    -> TMaybe s             -- ^ initial state
    -> m [b]                -- ^ outputs in order
iterateModelStoch_ l n f g p x s = fst <$> iterateModelStoch l n f g p x s
-- | Run the model over every element of a 'Traversable' container,
-- threading the state from one element to the next (lazy 'State').
scanModel
    :: (Traversable t, Backprop b, AllConstrainedProd Backprop s)
    => Model p s a b
    -> TMaybe p             -- ^ parameters
    -> t a                  -- ^ inputs
    -> TMaybe s             -- ^ initial state
    -> (t b, TMaybe s)      -- ^ outputs in the same shape, and the final state
scanModel f p xs = runState (traverse step xs)
  where
    step x = state (runModel f p x)
-- | Like 'scanModel', but discards the final state.
scanModel_
    :: (Traversable t, Backprop b, AllConstrainedProd Backprop s)
    => Model p s a b
    -> TMaybe p             -- ^ parameters
    -> t a                  -- ^ inputs
    -> TMaybe s             -- ^ initial state
    -> t b                  -- ^ outputs in the same shape
scanModel_ f p xs = evalState (traverse (state . runModel f p) xs)
-- | Feed every input of a 'Foldable' container through the model, left to
-- right, discarding the outputs and returning only the final state.
--
-- This is a strict left fold.  Each intermediate state is forced to WHNF
-- before the next input is processed.  A lazy 'State' / 'traverse_'
-- formulation instead builds one deferred 'runModel' call per input before
-- anything is evaluated, a space leak proportional to the input length.
primeModel
    :: (Foldable t, Backprop b, AllConstrainedProd Backprop s)
    => Model p s a b
    -> TMaybe p             -- ^ parameters
    -> t a                  -- ^ inputs, consumed left to right
    -> TMaybe s             -- ^ initial state
    -> TMaybe s             -- ^ state after the last input
primeModel f p xs s0 = foldl' step s0 xs
  where
    step s x = snd (runModel f p x s)
-- | Stochastic counterpart of 'scanModel'.
--
-- Each element is run with 'runModelStoch' using the same generator,
-- threading the state through a lazy 'StateT'.
scanModelStoch
    :: (Traversable t, Backprop b, AllConstrainedProd Backprop s, PrimMonad m)
    => Model p s a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p             -- ^ parameters
    -> t a                  -- ^ inputs
    -> TMaybe s             -- ^ initial state
    -> m (t b, TMaybe s)    -- ^ outputs in the same shape, and the final state
scanModelStoch f g p xs = runStateT (traverse step xs)
  where
    step x = StateT (runModelStoch f g p x)
-- | Like 'scanModelStoch', but discards the final state.
scanModelStoch_
    :: (Traversable t, Backprop b, AllConstrainedProd Backprop s, PrimMonad m)
    => Model p s a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p             -- ^ parameters
    -> t a                  -- ^ inputs
    -> TMaybe s             -- ^ initial state
    -> m (t b)              -- ^ outputs in the same shape
scanModelStoch_ f g p xs =
    evalStateT (traverse (StateT . runModelStoch f g p) xs)
-- | Stochastic counterpart of 'primeModel'.
--
-- Runs every input through 'runModelStoch' in order, discarding outputs
-- and returning only the final state.
primeModelStoch
    :: (Foldable t, Backprop b, AllConstrainedProd Backprop s, PrimMonad m)
    => Model p s a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p             -- ^ parameters
    -> t a                  -- ^ inputs, consumed left to right
    -> TMaybe s             -- ^ initial state
    -> m (TMaybe s)         -- ^ state after the last input
primeModelStoch f g p xs =
    execStateT (for_ xs (\x -> StateT (runModelStoch f g p x)))
-- | Randomly initialise a model's parameters by sampling from a
-- distribution with a 'ContGen' instance.
--
-- The model argument is never inspected.  It only fixes the parameter
-- type @p@.
initParam
    :: (Initialize p, ContGen d, PrimMonad m)
    => model ('Just p) s a b                    -- ^ only used for its type
    -> d                                        -- ^ distribution to sample from
    -> MWC.Gen (PrimState m)
    -> m p
initParam _ dist gen = initialize dist gen
-- | Randomly initialise a model's parameters via 'initializeNormal'.
--
-- The model argument is never inspected.  It only fixes the parameter
-- type @p@.
initParamNormal
    :: (Initialize p, PrimMonad m)
    => model ('Just p) s a b                    -- ^ only used for its type
    -> Double                                   -- ^ presumably the standard deviation; confirm in "Backprop.Learn.Initialize"
    -> MWC.Gen (PrimState m)
    -> m p
initParamNormal _ spread gen = initializeNormal spread gen
-- | Serialise a model's parameters with their 'Bi.Binary' instance.
--
-- The model argument only fixes the parameter type.
encodeParam
    :: Bi.Binary p
    => model ('Just p) s a b                    -- ^ only used for its type
    -> p
    -> BSL.ByteString
encodeParam _ params = Bi.encode params
-- | Deserialise a model's parameters with their 'Bi.Binary' instance.
--
-- Partial: 'Bi.decode' raises an error on malformed input.  Use
-- 'decodeParamOrFail' for a total alternative.  The model argument only
-- fixes the parameter type.
decodeParam
    :: Bi.Binary p
    => model ('Just p) s a b                    -- ^ only used for its type
    -> BSL.ByteString
    -> p
decodeParam _ bytes = Bi.decode bytes
-- | Deserialise a model's parameters, returning the decoder's error
-- message on failure instead of throwing.
--
-- Leftover input and byte offsets reported by 'Bi.decodeOrFail' are
-- discarded.
decodeParamOrFail
    :: Bi.Binary p
    => model ('Just p) s a b                    -- ^ only used for its type
    -> BSL.ByteString
    -> Either String p
decodeParamOrFail _ = either (Left . thrd) (Right . thrd) . Bi.decodeOrFail
-- | Write a model's parameters to a file using their 'Bi.Binary' encoding.
--
-- 'Bi.encodeFile' writes the lazily-encoded bytes with a lazy
-- 'BSL.writeFile'.  The model argument only fixes the parameter type.
saveParam
    :: Bi.Binary p
    => model ('Just p) s a b                    -- ^ only used for its type
    -> FilePath
    -> p
    -> IO ()
saveParam _ = Bi.encodeFile
-- | Read a model's parameters from a file written with 'saveParam'
-- (or any 'Bi.Binary' encoding of @p@).
--
-- Uses 'Bi.decodeFile' rather than lazy 'BSL.readFile' + 'Bi.decode'.
-- The lazy version keeps the file handle open until the stream reaches
-- EOF, which decoding need not do.  It also defers decoding and IO errors
-- into pure code, where they surface only when the result is forced.
-- With 'Bi.decodeFile', a decoding failure is raised (via 'error') when
-- this action runs.  The model argument only fixes the parameter type.
loadParam
    :: Bi.Binary p
    => model ('Just p) s a b                    -- ^ only used for its type
    -> FilePath
    -> IO p
loadParam _ = Bi.decodeFile
-- | Read a model's parameters from a file, returning the decoder's error
-- message on failure instead of throwing.
--
-- Uses 'Bi.decodeFileOrFail' so the file is read strictly and the handle
-- is released before this action returns.  The previous lazy-IO version
-- could keep the handle open past decoding.  The byte offset of a failure
-- is discarded.  The model argument only fixes the parameter type.
loadParamOrFail
    :: Bi.Binary p
    => model ('Just p) s a b                    -- ^ only used for its type
    -> FilePath
    -> IO (Either String p)
loadParamOrFail _ fp = first snd <$> Bi.decodeFileOrFail fp
-- | Project the third component of a triple.
thrd :: (a,b,c) -> c
thrd = \(_, _, third) -> third