{-# LANGUAGE DataKinds                                #-}
{-# LANGUAGE DeriveAnyClass                           #-}
{-# LANGUAGE DeriveDataTypeable                       #-}
{-# LANGUAGE DeriveGeneric                            #-}
{-# LANGUAGE DerivingVia                              #-}
{-# LANGUAGE FlexibleInstances                        #-}
{-# LANGUAGE GADTs                                    #-}
{-# LANGUAGE KindSignatures                           #-}
{-# LANGUAGE MultiParamTypeClasses                    #-}
{-# LANGUAGE PatternSynonyms                          #-}
{-# LANGUAGE RankNTypes                               #-}
{-# LANGUAGE RecordWildCards                          #-}
{-# LANGUAGE ScopedTypeVariables                      #-}
{-# LANGUAGE StandaloneDeriving                       #-}
{-# LANGUAGE TemplateHaskell                          #-}
{-# LANGUAGE TypeApplications                         #-}
{-# LANGUAGE TypeFamilies                             #-}
{-# LANGUAGE TypeInType                               #-}
{-# LANGUAGE TypeOperators                            #-}
{-# LANGUAGE UndecidableInstances                     #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise       #-}
module Backprop.Learn.Model.Neural.LSTM (
    
    lstm
  , LSTMp(..), lstmForget, lstmInput, lstmUpdate, lstmOutput
  , reshapeLSTMpInput
  , reshapeLSTMpOutput
  , lstm'
    
  , gru
  , GRUp(..), gruMemory, gruUpdate, gruOutput
  , gru'
  ) where
import           Backprop.Learn.Initialize
import           Backprop.Learn.Model.Function
import           Backprop.Learn.Model.Neural
import           Backprop.Learn.Model.Regression
import           Backprop.Learn.Model.State
import           Backprop.Learn.Model.Types
import           Control.DeepSeq
import           Control.Monad
import           Control.Monad.Primitive
import           Data.Type.Tuple
import           Data.Typeable
import           GHC.Generics                          (Generic)
import           GHC.TypeNats
import           Lens.Micro
import           Lens.Micro.TH
import           Numeric.Backprop
import           Numeric.LinearAlgebra.Static.Backprop
import           Numeric.OneLiner
import           Numeric.Opto.Ref
import           Numeric.Opto.Update
import           Statistics.Distribution
import qualified Data.Binary                           as Bi
import qualified Numeric.LinearAlgebra.Static          as H
import qualified System.Random.MWC                     as MWC
data LSTMp (i :: Nat) (o :: Nat) =
    LSTMp { _lstmForget :: !(FCp (i + o) o)
          , _lstmInput  :: !(FCp (i + o) o)
          , _lstmUpdate :: !(FCp (i + o) o)
          , _lstmOutput :: !(FCp (i + o) o)
          }
  deriving stock     (Generic, Typeable, Show)
  deriving anyclass  (NFData, Linear Double, Metric Double, Bi.Binary, Regularize, Backprop)
deriving via (GNum (LSTMp i o)) instance (KnownNat i, KnownNat o) => Num (LSTMp i o)
deriving via (GNum (LSTMp i o)) instance (KnownNat i, KnownNat o) => Fractional (LSTMp i o)
deriving via (GNum (LSTMp i o)) instance (KnownNat i, KnownNat o) => Floating (LSTMp i o)
makeLenses ''LSTMp
instance (PrimMonad m, KnownNat i, KnownNat o) => Mutable m (LSTMp i o) where
    type Ref m (LSTMp i o) = GRef m (LSTMp i o)
    thawRef   = gThawRef
    freezeRef = gFreezeRef
    copyRef   = gCopyRef
instance (PrimMonad m, KnownNat i, KnownNat o) => LinearInPlace m Double (LSTMp i o)
instance (PrimMonad m, KnownNat i, KnownNat o) => Learnable m (LSTMp i o)
lstm'
    :: (KnownNat i, KnownNat o)
    => Model ('Just (LSTMp i o)) ('Just (R o)) (R (i + o)) (R o)
lstm' = modelD $ \(PJust p) x (PJust s) ->
    let forget = logistic $ runLRp (p ^^. lstmForget) x
        input  = logistic $ runLRp (p ^^. lstmInput ) x
        update = tanh     $ runLRp (p ^^. lstmUpdate) x
        s'     = forget * s + input * update
        o      = logistic $ runLRp (p ^^. lstmOutput) x
        h      = o * tanh s'
    in  (h, PJust s')
lstm
    :: (KnownNat i, KnownNat o)
    => Model ('Just (LSTMp i o)) ('Just (R o :# R o)) (R i) (R o)
lstm = recurrent H.split (H.#) id lstm'
reshapeLSTMpInput
    :: (ContGen d, PrimMonad m, KnownNat i, KnownNat i', KnownNat o)
    => d
    -> MWC.Gen (PrimState m)
    -> LSTMp i o
    -> m (LSTMp i' o)
reshapeLSTMpInput d g (LSTMp forget input update output) =
    LSTMp <$> reshaper forget
          <*> reshaper input
          <*> reshaper update
          <*> reshaper output
  where
    reshaper = reshapeLRpInput d g
reshapeLSTMpOutput
    :: (ContGen d, PrimMonad m, KnownNat i, KnownNat o, KnownNat o')
    => d
    -> MWC.Gen (PrimState m)
    -> LSTMp i o
    -> m (LSTMp i o')
reshapeLSTMpOutput d g (LSTMp forget input update output) =
    LSTMp <$> reshaper forget
          <*> reshaper input
          <*> reshaper update
          <*> reshaper output
  where
    reshaper = reshapeLRpInput  d g
           <=< reshapeLRpOutput d g
instance (KnownNat i, KnownNat o) => Initialize (LSTMp i o) where
    initialize d g = LSTMp <$> set (mapped . fcBias) 1 (initialize d g)
                           <*> initialize d g
                           <*> initialize d g
                           <*> initialize d g
data GRUp (i :: Nat) (o :: Nat) =
    GRUp { _gruMemory :: !(FCp (i + o) o)
         , _gruUpdate :: !(FCp (i + o) o)
         , _gruOutput :: !(FCp (i + o) o)
         }
  deriving stock     (Generic, Typeable, Show)
  deriving anyclass  (NFData, Linear Double, Metric Double, Bi.Binary, Initialize, Regularize, Backprop)
deriving via (GNum (GRUp i o)) instance (KnownNat i, KnownNat o) => Num (GRUp i o)
deriving via (GNum (GRUp i o)) instance (KnownNat i, KnownNat o) => Fractional (GRUp i o)
deriving via (GNum (GRUp i o)) instance (KnownNat i, KnownNat o) => Floating (GRUp i o)
makeLenses ''GRUp
instance (PrimMonad m, KnownNat i, KnownNat o) => Mutable m (GRUp i o) where
    type Ref m (GRUp i o) = GRef m (GRUp i o)
    thawRef   = gThawRef
    freezeRef = gFreezeRef
    copyRef   = gCopyRef
instance (KnownNat i, KnownNat o, Mutable m (GRUp i o)) => LinearInPlace m Double (GRUp i o)
instance (KnownNat i, KnownNat o, PrimMonad m) => Learnable m (GRUp i o)
gru'
    :: forall i o. (KnownNat i, KnownNat o)
    => Model ('Just (GRUp i o)) 'Nothing (R (i + o)) (R o)
gru' = modelStatelessD $ \(PJust p) x ->
    let z      = logistic $ runLRp (p ^^. gruMemory) x
        r      = logistic $ runLRp (p ^^. gruUpdate) x
        r'     = 1 # r
        h'     = tanh     $ runLRp (p ^^. gruOutput) (r' * x)
    in  (1 - z) * snd (split @i x) + z * h'
gru :: (KnownNat i, KnownNat o)
    => Model ('Just (GRUp i o)) ('Just (R o)) (R i) (R o)
gru = recurrent H.split (H.#) id gru'