{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Backprop.Learn.Model.Function (
meanModel
, varModel
, stdevModel
, rangeModel
, step
, logistic
, softsign
, reLU
, softPlus
, bentIdentity
, siLU
, softExponential
, sinc
, gaussian
, tanh
, atan
, sin
, vmap
, vmap'
, liftUniform
, isru
, preLU
, sreLU
, sreLUPFP
, eLU
, isrLU
, apl
, aplPFP
, softMax
, maxout
, kSparse
) where
import Control.Category
import Data.Foldable
import Data.Profunctor
import Data.Proxy
import Data.Type.Tuple
import GHC.TypeNats
import Numeric (expm1, log1pexp)
import Numeric.Backprop
import Numeric.LinearAlgebra.Static.Backprop hiding (tr)
import Prelude hiding ((.), id)
import qualified Control.Foldl as F
import qualified Data.Vector.Sized as SV
import qualified Data.Vector.Storable.Sized as SVS
import qualified Numeric.LinearAlgebra as HU
import qualified Numeric.LinearAlgebra.Static as H
import qualified Numeric.LinearAlgebra.Static.Vector as H
-- | Arithmetic mean of the elements of a container.
--
-- Sum and count are gathered in a single fold.  On the backward pass every
-- element receives the same share, @d / n@, of the incoming gradient @d@.
-- An empty container divides by zero.
meanModel
    :: (Backprop (t a), Foldable t, Functor t, Fractional a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
meanModel = liftOp1 . op1 $ \xs ->
    let total :# count = F.fold ((:#) <$> F.sum <*> F.length) xs
        n              = fromIntegral count
    in  (total / n, \d -> fmap (const (d / n)) xs)
-- | Population variance of the elements of a container,
-- @(1/n) * sum_i (x_i - mean)^2@.
--
-- Computed in two passes (mean first, then squared deviations from it)
-- rather than as @E[x^2] - mean^2@.  The one-pass form cancels
-- catastrophically when the mean is large relative to the spread and can
-- even come out negative, so taking its square root would give @NaN@.
--
-- The gradient with respect to @x_i@ is @2 * (x_i - mean) / n@.
-- An empty container divides by zero.
varModel
    :: (Backprop (t a), Foldable t, Functor t, Fractional a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
varModel = liftOp1 . op1 $ \xs ->
    let total :# count = F.fold ((:#) <$> F.sum <*> F.length) xs
        n              = fromIntegral count
        meanx          = total / n
        -- second pass: sum of squared deviations from the mean
        sqDevs         = F.fold (lmap (\x -> (x - meanx) ^ (2 :: Int)) F.sum) xs
    in  ( sqDevs / n
        , \d -> (\x -> d * 2 * (x - meanx) / n) <$> xs
        )
-- | Population standard deviation of the elements of a container: the
-- square root of 'varModel'.
stdevModel
    :: (Backprop (t a), Foldable t, Functor t, Floating a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
stdevModel xs = sqrt (varModel xs)
-- | Range (maximum minus minimum) of the elements of a container.
--
-- On the backward pass, elements equal to the maximum receive @d@, and
-- elements equal to the minimum (but not the maximum) receive @-d@.  All
-- other elements receive @0@.  Throws an error on an empty container.
rangeModel
    :: (Backprop (t a), Foldable t, Functor t, Ord a, Num a, Reifies s W)
    => BVar s (t a)
    -> BVar s a
rangeModel = liftOp1 . op1 $ \xs ->
    case F.fold ((,) <$> F.minimum <*> F.maximum) xs of
      (Just lo, Just hi) ->
        let route d x
              | x == hi   = d
              | x == lo   = -d
              | otherwise = 0
        in  (hi - lo, \d -> route d <$> xs)
      _ -> errorWithoutStackTrace "Backprop.Learn.Model.Function.range: empty range"
-- | Softmax of a vector: @exp x / sum (exp x)@.
--
-- The input is first shifted by its largest component.  Softmax is
-- invariant under adding the same constant to every component, so the
-- result is mathematically unchanged.  However, 'exp' can no longer
-- overflow to @Infinity@ for large inputs, which would otherwise turn the
-- output into @NaN@.
softMax
    :: (KnownNat i, Reifies s W)
    => BVar s (R i)
    -> BVar s (R i)
softMax x = expx / konst (norm_1V expx)
  where
    -- The shift is read off the primal value and re-introduced with 'auto',
    -- so no gradient flows through it.  None is needed, because of the shift
    -- invariance above.  Folding from -Infinity keeps this total for @R 0@.
    shift = SVS.foldl' max (-1 / 0) (H.rVec (evalBVar x))
    -- Components are non-negative, so the 1-norm is their plain sum.
    expx  = exp (x - konst (auto shift))
-- | Logistic sigmoid, @1 / (1 + e^(-x))@.
logistic :: Floating a => a -> a
logistic x = 1 / denom
  where
    denom = 1 + exp (negate x)
-- | Heaviside step: @0@ for negative inputs, @1@ otherwise.
step :: (Ord a, Num a) => a -> a
step x = if x < 0 then 0 else 1
-- | Softsign, @x / (1 + |x|)@.
softsign :: Fractional a => a -> a
softsign x = x / denom
  where
    denom = 1 + abs x
-- | Inverse square root unit, @x / sqrt (1 + alpha * x^2)@.
isru
    :: Floating a
    => a        -- ^ alpha, the curvature parameter
    -> a        -- ^ input
    -> a
isru alpha x = x / scale
  where
    scale = sqrt (1 + alpha * x * x)
-- | Rectified linear unit: negative inputs become @0@, all others pass
-- through unchanged.
reLU :: (Num a, Ord a) => a -> a
reLU x = if x < 0 then 0 else x
-- | Parametric ReLU: negative inputs are scaled by @alpha@, all others pass
-- through unchanged.
preLU
    :: (Num a, Ord a)
    => a        -- ^ alpha, slope for negative inputs
    -> a        -- ^ input
    -> a
preLU alpha x = if x < 0 then alpha * x else x
-- | Exponential linear unit: @α * (e^x - 1)@ for negative inputs, otherwise
-- the input unchanged.
--
-- The negative branch uses 'expm1', which keeps precision for @x@ close to
-- zero, where @exp x - 1@ suffers from cancellation.
eLU :: (Floating a, Ord a)
    => a        -- ^ α, scale of the negative branch
    -> a        -- ^ input
    -> a
eLU α x
    | x < 0     = α * expm1 x
    | otherwise = x
-- | S-shaped ReLU.
--
-- * Below the left threshold @tl@, the output is a line of slope @al@
--   through @(tl, tl)@.
-- * Above the right threshold @tr@, the output is a line of slope @ar@
--   through @(tr, tr)@.
-- * In between, the output is the identity.
sreLU
    :: (Num a, Ord a)
    => a        -- ^ left threshold @tl@
    -> a        -- ^ left slope @al@
    -> a        -- ^ right threshold @tr@
    -> a        -- ^ right slope @ar@
    -> a        -- ^ input
    -> a
sreLU tl al tr ar x =
    if x < tl
      then tl + al * (x - tl)
      else if x > tr
        then tr + ar * (x - tr)
        else x
-- | 'sreLU' applied to every component of a vector.
--
-- The parameters are packed as
-- @(leftThreshold :# leftSlope) :# (rightThreshold :# rightSlope)@.
sreLUPFP
    :: (KnownNat n, Reifies s W)
    => BVar s ((Double :# Double) :# (Double :# Double))
    -> BVar s (R n)
    -> BVar s (R n)
sreLUPFP ((lowT :## lowSlope) :## (highT :## highSlope)) = vmap piecewise
  where
    piecewise = sreLU lowT lowSlope highT highSlope
-- | Inverse square root linear unit.  Negative inputs are mapped to
-- @x / sqrt (1 + alpha * x^2)@ (the same formula as 'isru'); all others
-- pass through unchanged.
isrLU
    :: (Floating a, Ord a)
    => a        -- ^ alpha, the curvature parameter
    -> a        -- ^ input
    -> a
isrLU alpha x =
    if x < 0
      then x / sqrt (1 + alpha * x * x)
      else x
-- | Adaptive piecewise linear activation (Agostinelli et al.).  It is
-- applied componentwise to an @m@-vector, using @n@ learned hinges per
-- component:
--
-- @apl a b x = max 0 x + sum_s a_s * max 0 (b_s - x)@
--
-- Row @s@ of @as@ and @bs@ holds, respectively, the slope and the location
-- of hinge @s@ for every component.
apl :: (KnownNat n, KnownNat m, Reifies s W)
    => BVar s (L n m)       -- ^ hinge slopes @a@
    -> BVar s (L n m)       -- ^ hinge locations @b@
    -> BVar s (R m)         -- ^ input @x@
    -> BVar s (R m)
apl as bs x = vmap' (max 0) x + sum (toRows (as * hinges))
  where
    -- Each row is @b_s - x@ clipped at zero.  Without the clip, the hinge
    -- terms collapse into a single linear function of @x@.
    hinges = mmap (max 0) (bs - fromRows (SV.replicate x))
-- | 'apl' with its slope and location matrices packed into one parameter,
-- @slopes :# locations@.
aplPFP
    :: (KnownNat n, KnownNat m, Reifies s W)
    => BVar s (L n m :# L n m)
    -> BVar s (R m)
    -> BVar s (R m)
aplPFP (slopes :## locations) = apl slopes locations
-- | Softplus, @log (1 + e^x)@, a smooth approximation of 'reLU'.
--
-- Evaluated with 'log1pexp'.  For 'Double', this stays finite for large
-- @x@, where @exp x@ overflows and the naive formula gives @Infinity@.  It
-- also keeps precision for very negative @x@, where @1 + exp x@ rounds to
-- @1@.  Types without a specialised 'log1pexp' use the class default,
-- @log1p (exp x)@.
softPlus :: Floating a => a -> a
softPlus x = log1pexp x
-- | Bent identity, @(sqrt (x^2 + 1) - 1) / 2 + x@.
bentIdentity :: Floating a => a -> a
bentIdentity x = bend / 2 + x
  where
    bend = sqrt (x * x + 1) - 1
-- | Sigmoid-weighted linear unit: the input scaled by its own 'logistic'.
siLU :: Floating a => a -> a
siLU x = x * gate
  where
    gate = logistic x
-- | Soft exponential activation, which interpolates between logarithmic
-- (@α < 0@), identity (@α == 0@) and exponential (@α > 0@) behaviour:
--
-- * @α < 0@: @-log (1 - α * (x + α)) / α@
-- * @α == 0@: @x@
-- * @α > 0@: @(e^(α x) - 1) / α + α@
--
-- The branch is chosen by the sign of @α@ alone, not by comparing @α@
-- with the input.
softExponential
    :: (Floating a, Ord a)
    => a        -- ^ α, the shape parameter
    -> a        -- ^ input
    -> a
softExponential α x = case compare α 0 of
    LT -> - log (1 - α * (x + α)) / α
    EQ -> x
    GT -> (exp (α * x) - 1) / α + α
-- | Unnormalised sinc, @sin x / x@.  The removable singularity at @0@ is
-- filled with @1@.
sinc :: (Floating a, Eq a) => a -> a
sinc x
    | x == 0    = 1
    | otherwise = sin x / x
-- | Unnormalised Gaussian bump, @e^(-x^2)@.
gaussian :: Floating a => a -> a
gaussian x = exp (negate (x * x))
-- | Largest component of a vector.
--
-- On the backward pass, the incoming gradient is sent to every component
-- equal to the maximum; all other components receive @0@.
maxout :: (KnownNat n, Reifies s W) => BVar s (R n) -> BVar s Double
maxout = liftOp1 . op1 $ \v ->
    let top    = HU.maxElement (H.extract v)
        back d = H.vecR (SVS.map (\e -> if e == top then d else 0) (H.rVec v))
    in  (top, back)
-- | Adapts a function on vectors into one taking a single scalar, which is
-- broadcast to every component with 'konst'.
liftUniform
    :: (Reifies s W, KnownNat n)
    => (BVar s (R n) -> r)
    -> BVar s Double
    -> r
liftUniform f c = f (konst c)
-- | k-sparse activation: keeps the @k@ largest components of the input and
-- zeroes the rest.  The gradient passes through only at the kept
-- components.
--
-- @k@ is clamped to @[0, n]@ rather than indexing outside the sorted
-- vector.  @k <= 0@ zeroes everything, and @k >= n@ keeps every
-- component.  Components tied with the @k@-th largest value are all kept,
-- so more than @k@ may survive when there are ties.
kSparse
    :: forall n s. (Reifies s W, KnownNat n)
    => Int                  -- ^ number of components to keep
    -> BVar s (R n)
    -> BVar s (R n)
kSparse k = liftOp1 . op1 $ \xs ->
    let keep = keeper xs
        mask = H.dvmap (\x -> if keep x then 1 else 0) xs
    in  ( H.dvmap (\x -> if keep x then x else 0) xs
        , (* mask)
        )
  where
    n :: Int
    n = fromIntegral $ natVal (Proxy @n)
    kc :: Int
    kc = max 0 (min n k)
    -- Predicate selecting the components to keep.
    keeper :: R n -> Double -> Bool
    keeper xs
      | kc == 0   = const False
      | otherwise = (>= thresh)
      where
        -- The sort is ascending, so index n - kc holds the kc-th largest value.
        thresh = HU.sortVector (H.extract xs) `HU.atIndex` (n - kc)