{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE MultiWayIf            #-}
{-# LANGUAGE PatternSynonyms       #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeInType            #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE UndecidableInstances  #-}
{-# LANGUAGE ViewPatterns          #-}
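
-- | Assorted parameter-free and parameterized differentiable functions:
-- whole-container statistics and common activation functions, expressed
-- over 'BVar's for use with the /backprop/ library.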
module Backprop.Learn.Model.Function (
    -- * Statistics
    meanModel
  , varModel
  , stdevModel
  , rangeModel
    -- * Activation functions
    -- ** Pointwise, unparameterized
  , step
  , logistic
  , softsign
  , reLU
  , softPlus
  , bentIdentity
  , siLU
  , softExponential
  , sinc
  , gaussian
  , tanh
  , atan
  , sin
  , vmap
  , vmap'
    -- ** Pointwise, parameterized
  , liftUniform
  , isru
  , preLU
  , sreLU
  , sreLUPFP
  , eLU
  , isrLU
  , apl
  , aplPFP
    -- ** Whole-vector
  , softMax
  , maxout
  , kSparse
  ) where
import           Control.Category
import           Data.Foldable
import           Data.Profunctor
import           Data.Proxy
import           Data.Type.Tuple
import           GHC.TypeNats
import           Numeric.Backprop
import           Numeric.LinearAlgebra.Static.Backprop hiding (tr)
import           Prelude hiding                               ((.), id)
import qualified Control.Foldl                                as F
import qualified Data.Vector.Sized                            as SV
import qualified Data.Vector.Storable.Sized                   as SVS
import qualified Numeric.LinearAlgebra                        as HU
import qualified Numeric.LinearAlgebra.Static                 as H
import qualified Numeric.LinearAlgebra.Static.Vector          as H
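
-- | Mean of the items in a container, as a differentiable function.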
meanModel
    :: (Backprop (t a), Foldable t, Functor t, Fractional a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
meanModel = liftOp1 . op1 $ \xs ->
    let x :# n = F.fold ((:#) <$> F.sum <*> F.length) xs
    in  (x / fromIntegral n, \d -> (d / fromIntegral n) <$ xs)
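
-- | Population (uncorrected) variance of the items in a container,
-- computed in a single pass as @E[x^2] - (E[x])^2@.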
varModel
    :: (Backprop (t a), Foldable t, Functor t, Fractional a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
varModel = liftOp1 . op1 $ \xs ->
    let x2 :# x1 :# x0 = F.fold ((\s2 s1 n -> s2 :# s1 :# n)
                                    <$> lmap (^(2::Int)) F.sum
                                    <*> F.sum
                                    <*> F.length
                                ) xs
        meanx  = x1 / fromIntegral x0
        subAll = 2 * x1 / (fromIntegral x0 ^ (2 :: Int))
    in  ( (x2 / fromIntegral x0) - meanx * meanx
        , \d -> let subAllD = d * subAll
                in  (\x -> d * 2 * x / fromIntegral x0 - subAllD) <$> xs
        )
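
-- | Uncorrected standard deviation; the square root of 'varModel'.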
stdevModel
    :: (Backprop (t a), Foldable t, Functor t, Floating a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
stdevModel = sqrt . varModel
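
-- | Range (maximum minus minimum) of the items in a container.  Throws a
-- runtime error if the container is empty.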
rangeModel
    :: (Backprop (t a), Foldable t, Functor t, Ord a, Num a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
rangeModel = liftOp1 . op1 $ \xs ->
    let mn :# mx = F.fold ((:#) <$> F.minimum <*> F.maximum) xs
    in  case (:#) <$> mn <*> mx of
          Nothing           -> errorWithoutStackTrace "Backprop.Learn.Model.Function.rangeModel: empty container"
          Just (mn' :# mx') ->
            ( mx' - mn'
            , \d -> (\x -> if | x == mx'  -> d
                              | x == mn'  -> -d
                              | otherwise -> 0
                    ) <$> xs
            )
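
-- | Softmax activation: normalizes a vector into a probability
-- distribution, @exp x_i / sum_j (exp x_j)@.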
softMax
    :: (KnownNat i, Reifies s W)
    => BVar s (R i)
    -> BVar s (R i)
softMax x = expx / konst (norm_1V expx)
  where
    expx = exp x
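
-- | Logistic (sigmoid) function, @1 / (1 + exp (-x))@.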
logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (-x))
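
-- | Binary step (Heaviside) function.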
step :: (Ord a, Num a) => a -> a
step x | x < 0     = 0
       | otherwise = 1
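
-- | Softsign, @x / (1 + abs x)@.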
softsign :: Fractional a => a -> a
softsign x = x / (1 + abs x)
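
-- | Inverse square root unit (ISRU), @x / sqrt (1 + α * x * x)@.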
isru
    :: Floating a
    => a        -- ^ α
    -> a        -- ^ x
    -> a
isru α x = x / sqrt (1 + α * x * x)
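
-- | Rectified linear unit, @max 0 x@.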
reLU :: (Num a, Ord a) => a -> a
reLU x | x < 0     = 0
       | otherwise = x
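
-- | Parametric rectified linear ("leaky") unit: slope α for negative
-- inputs, identity otherwise.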
preLU
    :: (Num a, Ord a)
    => a        -- ^ α (slope for negative inputs)
    -> a        -- ^ x
    -> a
preLU α x | x < 0     = α * x
          | otherwise = x
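
-- | Exponential linear unit: @α * (exp x - 1)@ for negative inputs,
-- identity otherwise.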
eLU :: (Floating a, Ord a)
    => a        -- ^ α (scale of the negative branch)
    -> a        -- ^ x
    -> a
eLU α x | x < 0     = α * (exp x - 1)
        | otherwise = x
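
-- | S-shaped rectified linear unit: linear with slope @a_l@ below the
-- threshold @t_l@, linear with slope @a_r@ above @t_r@, and the identity
-- in between.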
sreLU
    :: (Num a, Ord a)
    => a        -- ^ t_l, lower threshold
    -> a        -- ^ a_l, slope below t_l
    -> a        -- ^ t_r, upper threshold
    -> a        -- ^ a_r, slope above t_r
    -> a        -- ^ x
    -> a
sreLU tl al tr ar x
    | x < tl    = tl + al * (x - tl)
    | x > tr    = tr + ar * (x - tr)
    | otherwise = x
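
-- | 'sreLU' applied elementwise to a vector, with its four parameters
-- packed into a single 'BVar' of nested tuples.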
sreLUPFP
    :: (KnownNat n, Reifies s W)
    => BVar s ((Double :# Double) :# (Double :# Double))
    -> BVar s (R n)
    -> BVar s (R n)
sreLUPFP ((tl :## al) :## (tr :## ar)) = vmap (sreLU tl al tr ar)
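
-- | Inverse square root linear unit: 'isru' for negative inputs, identity
-- otherwise.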
isrLU
    :: (Floating a, Ord a)
    => a        -- ^ α
    -> a        -- ^ x
    -> a
isrLU α x
    | x < 0     = x / sqrt (1 + α * x * x)
    | otherwise = x
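
-- | Adaptive piecewise linear unit: @max 0 x@ plus, for each of the @n@
-- parameter rows, the elementwise product @a * (b - x)@, summed over rows.
-- (Compare the original APL formulation of Agostinelli et al,
-- <https://arxiv.org/abs/1412.6830>, which rectifies each @b - x@ term.)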
apl :: (KnownNat n, KnownNat m, Reifies s W)
    => BVar s (L n m)     -- ^ slopes, one row per piece
    -> BVar s (L n m)     -- ^ offsets, one row per piece
    -> BVar s (R m)       -- ^ x
    -> BVar s (R m)
apl as bs x = vmap' (max 0) x
            + sum (toRows (as * (bs - fromRows (SV.replicate x))))
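
-- | 'apl' with both parameter matrices packed into a single 'BVar' tuple.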
aplPFP
    :: (KnownNat n, KnownNat m, Reifies s W)
    => BVar s (L n m :# L n m)
    -> BVar s (R m)
    -> BVar s (R m)
aplPFP (a :## b) = apl a b
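
-- | Softplus, @log (1 + exp x)@.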
softPlus :: Floating a => a -> a
softPlus x = log (1 + exp x)
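
-- | Bent identity.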
bentIdentity :: Floating a => a -> a
bentIdentity x = (sqrt (x * x + 1) - 1) / 2 + x
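
-- | Sigmoid-weighted linear unit (SiLU, also known as swish),
-- @x * logistic x@.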
siLU :: Floating a => a -> a
siLU x = x * logistic x
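
-- | Soft exponential: branches on the sign of α, interpolating between
-- logarithmic, identity, and exponential behaviour as α moves from
-- negative to positive.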
softExponential
    :: (Floating a, Ord a)
    => a            -- ^ α
    -> a            -- ^ x
    -> a
softExponential α x = case compare α 0 of
    LT -> - log (1 - α * (x + α)) / α
    EQ -> x
    GT -> (exp (α * x) - 1) / α + α
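
-- | Cardinal sine, with @sinc 0 = 1@.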
sinc :: (Floating a, Eq a) => a -> a
sinc 0 = 1
sinc x = sin x / x
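
-- | Gaussian, @exp (- (x * x))@.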
gaussian :: Floating a => a -> a
gaussian x = exp (- (x * x))
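
-- | Maximum element of a vector; the incoming gradient is routed to the
-- maximal entries (all of them, in case of ties).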
maxout :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s Double
maxout = liftOp1 . op1 $ \x ->
    let n = HU.maxElement . H.extract $ x
    in  ( n
        , \d -> H.vecR . SVS.map (\e -> if e == n then d else 0) . H.rVec $ x
        )
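
-- | Convert a function on a vector 'BVar' into one on a single scalar
-- 'BVar', by broadcasting the scalar to every component with 'konst'.
-- Useful when one parameter should be shared uniformly across all
-- components.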
liftUniform
    :: (Reifies s W, KnownNat n)
    => (BVar s (R n) -> r)
    -> BVar s Double
    -> r
liftUniform f = f . konst
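
-- | Keep only the @k@ largest entries of a vector and zero out the rest
-- (the k-sparse activation used in k-sparse autoencoders).  Expects
-- @1 <= k <= n@.
--
-- A minimal usage sketch (illustrative values only):
--
-- > evalBP (kSparse 2) (H.vec4 0.1 0.9 0.3 0.7)  -- keeps the two largest
-- > gradBP (kSparse 2) (H.vec4 0.1 0.9 0.3 0.7)  -- gradient masked likewise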
kSparse
    :: forall n s. (Reifies s W, KnownNat n)
    => Int                      -- ^ k: number of entries to keep
    -> BVar s (R n)
    -> BVar s (R n)
kSparse k = liftOp1 . op1 $ \xs ->
    let xsSort = HU.sortVector (H.extract xs)
        thresh = xsSort `HU.atIndex` (n - k)
        mask   = H.dvmap (\x -> if x >= thresh then 1 else 0) xs
    in  ( H.dvmap (\x -> if x >= thresh then x else 0) xs
        , (* mask)
        )
  where
    n = fromIntegral $ natVal (Proxy @n)