{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeInType          #-}
{-# LANGUAGE TypeOperators       #-}

module TensorOps.Learn.NeuralNet.AutoEncoder
  ( Encoder(..)
  , encode, decode
  , encoderNet, encodeDecode, testEncoder
  , trainEncoder
  ) where

import           Control.Category
import           Data.Kind
import           Data.Singletons
import           Data.Type.Conjunction
import           Data.Type.Length
import           Data.Type.Product                     as TCP
import           Data.Type.Product.Util                as TCP
import           Data.Type.Sing
import           Prelude hiding                        ((.), id)
import           TensorOps.Learn.NeuralNet.FeedForward
import           TensorOps.Types
import           Type.Class.Higher
import           Type.Class.Witness
import           Type.Family.List
import qualified TensorOps.TOp                         as TO
import qualified TensorOps.Tensor                      as TT

-- | An autoencoder: a network from @i@ to @o@ paired with a network from
-- @o@ back to @i@.  Both halves are stored strictly.
data Encoder (t :: [k] -> Type) (i :: k) (o :: k) = E
    { eEncoder :: !(Network t i o)  -- ^ the @i -> o@ half
    , eDecoder :: !(Network t o i)  -- ^ the @o -> i@ half
    }

-- | Run only the encoder half of an 'Encoder' on an input tensor.
encode
    :: (RealFloat (ElemT t), Tensor t)
    => Encoder t i o
    -> t '[i]
    -> t '[o]
encode = runNetwork . eEncoder
{-# INLINE encode #-}

-- | Run only the decoder half of an 'Encoder', mapping an @o@-tensor back
-- to an @i@-tensor.
decode
    :: (RealFloat (ElemT t), Tensor t)
    => Encoder t i o
    -> t '[o]
    -> t '[i]
decode enc = runNetwork dec
  where
    dec = eDecoder enc
{-# INLINE decode #-}

-- | Full round trip: run the composed encoder-then-decoder network
-- (see 'encoderNet') on an input tensor.
encodeDecode
    :: (RealFloat (ElemT t), Tensor t)
    => Encoder t i o
    -> t '[i]
    -> t '[i]
encodeDecode e = runNetwork (encoderNet e)
-- Thin wrapper; inlined like 'encode', 'decode' and 'encoderNet'.
{-# INLINE encodeDecode #-}

-- | Evaluate a loss on how well the autoencoder reconstructs an input.
--
-- The input tensor is fed to a single combined op built from the network
-- returned by 'encoderNet'. 'TO.duplicate' turns the input into two @[i]@
-- tensors. One copy is left untouched ('secondOp' skips it), and the other
-- goes through the encoder/decoder op together with the network
-- parameters. 'TO.swap' then reorders the pair before @loss@ is applied.
-- NOTE(review): presumably @loss@ therefore sees (reconstruction, original);
-- confirm against the loss ops' argument convention.
-- The single scalar output of @loss@ is returned as an 'ElemT'.
testEncoder
    :: forall t i o. (RealFloat (ElemT t), Tensor t, SingI i)
    => TOp '[ '[i], '[i] ] '[ '[] ]
    -> Encoder t i o
    -> t '[i]
    -> ElemT t
testEncoder loss e = case encoderNet e of
    -- Read the pipeline right to left: prepend the input to the network's
    -- parameters, run the combined op, take its only output and unwrap
    -- the scalar.
    N _ o p -> TT.unScalar
             . TCP.head'
             . runTOp ( firstOp TO.duplicate
                    >>> secondOp @'[ '[i] ] o
                    >>> TO.swap
                    >>> loss
                      )
             . (:< p)

-- | Chain the encoder into the decoder, giving one network that maps an
-- @i@-input back to an @i@-output.
encoderNet
    :: Encoder t i o
    -> Network t i i
encoderNet = \case
    E enc dec -> dec . enc
{-# INLINE encoderNet #-}

-- | Perform one gradient-descent step of an autoencoder on a single input.
--
-- Parameter gradients for both halves come from 'encGrad' using @loss@.
-- Every parameter tensor @w@ is then replaced elementwise by
-- @w - r * dw@, where @dw@ is its gradient and @r@ is the step size.
-- The ops and shape singletons of both networks are kept as they were.
trainEncoder
    :: forall t i o. (Tensor t, RealFloat (ElemT t), SingI i)
    => TOp '[ '[i], '[i] ] '[ '[] ]
    -> ElemT t
    -> t '[i]
    -> Encoder t i o
    -> Encoder t i o
trainEncoder loss r x = \case
    E (N sE oE pE) (N sD oD pD) ->
      let grads = encGrad loss x sE oE pE oD pD
          pE'   = map1 descend (zipProd3 (singProd sE) pE (fst grads))
          pD'   = map1 descend (zipProd3 (singProd sD) pD (snd grads))
      in  E (N sE oE pE') (N sD oD pD')
  where
    -- Update one parameter tensor. The singleton supplies the shape
    -- constraint that 'TT.zip' needs, discharged with '\\'.
    descend
        :: forall ns.
           (Sing :&: t :&: t) ns
        -> t ns
    descend !(s :&: w :&: dw) =
      TT.zip (\(!w') (!dw') -> w' - r * dw') w dw \\ s
    {-# INLINE descend #-}
{-# INLINE trainEncoder #-}

-- | Gradients of the reconstruction loss with respect to the encoder
-- parameters and the decoder parameters, for a single input @x@.
--
-- A single combined op @o@ is built. Its inputs are @x@, then the encoder
-- parameters (@psE@), then the decoder parameters (@psD@). @o@ duplicates
-- @x@, runs one copy through @oE@ followed by @oD@, swaps the pair and
-- applies @loss@. This is the same shape of pipeline used by 'testEncoder'.
--
-- 'gradTOp' is applied to @o@ at those inputs. NOTE(review): it presumably
-- yields one gradient per input; confirm. The entry for @x@ is dropped with
-- 'TCP.tail''. The remaining product is split at the length of @psE@ into
-- (encoder gradients, decoder gradients).
encGrad
    :: forall k (t :: [k] -> Type) (i :: k) (o :: k) (psE :: [[k]]) (psD :: [[k]]).
     ( Tensor t
     , RealFloat (ElemT t)
     , SingI i
     )
    => TOp '[ '[i], '[i] ] '[ '[] ]
    -> t '[i]
    -> Sing psE
    -> TOp ( '[i] ': psE ) '[ '[o] ]
    -> Prod t psE
    -> TOp ( '[o] ': psD ) '[ '[i] ]
    -> Prod t psD
    -> (Prod t psE, Prod t psD)
encGrad loss x sE oE pE oD pD = splitProd @psD @t @psE lE gr
  where
    -- Length witness for the encoder parameter list. It is the split point
    -- for 'splitProd', and it is also used (via '\\') to satisfy
    -- constraints needed while composing @o@.
    lE :: Length psE
    lE = singLength sE
    -- Combined loss op over [x, encoder params, decoder params].
    -- 'firstOp @psD oE' runs the encoder while passing the decoder
    -- parameters through, so that 'oD' receives the code plus its own
    -- parameters.
    o :: TOp ('[i] ': psE ++ psD) '[ '[] ]
    o = (\\ lE) $
          firstOp @(psE ++ psD) TO.duplicate
      >>> secondOp @'[ '[i] ] (
            firstOp @psD oE >>> oD
          )
      >>> TO.swap
      >>> loss
    -- All inputs of @o@, in order: x, encoder params, decoder params.
    x' :: Prod t ('[i] ': psE ++ psD)
    x' = x :< pE `TCP.append'` pD
    -- Drop the entry for x, keeping only the parameter gradients.
    gr :: Prod t (psE ++ psD)
    gr = TCP.tail' $ gradTOp o x'
{-# INLINE encGrad #-}