{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
-- |
-- Module      : Numeric.Opto.Optimizer
--
-- A collection of gradient-driven optimizers.  Every function exported
-- here takes a 'Grad' action (plus, for most of them, a hyperparameter
-- value) and produces an 'Opto'.  Each hyperparameter type has a
-- 'Default' instance supplying conventional starting values.
module Numeric.Opto.Optimizer (
steepestDescent
, Momentum(..), momentum
, Nesterov(..), nesterov
, Adagrad(..), adagrad
, Adadelta(..), adadelta
, RMSProp(..), rmsProp
, Adam(..), adam
, AdaMax(..), adaMax
) where
import Data.Default
import Numeric.Opto.Core
import Numeric.Opto.Ref
import Numeric.Opto.Update
-- | Stateless gradient step with a fixed scalar.
--
-- Built with 'fromStateless', so nothing is remembered between steps.
-- Every update evaluates the gradient action on the arguments it is
-- given, forces the result, and hands back the pair @(negate lr, g)@
-- (presumably a scale factor and a direction that the driver in
-- "Numeric.Opto.Core" combines with the current point -- confirm there).
steepestDescent
    :: LinearInPlace m c a
    => c                -- ^ scalar @lr@; returned negated as the first component
    -> Grad m r a       -- ^ gradient action
    -> Opto m r a
steepestDescent lr gr = fromStateless $ \sample point -> do
    -- The bang forces the gradient before the pair is built.
    !grad <- gr sample point
    return (negate lr, grad)
-- | Hyperparameter wrapper consumed by 'momentum'.
newtype Momentum c = Momentum
    { momentumDecay :: c    -- ^ factor the stored accumulator is scaled by on every step
    }
  deriving (Eq, Show)

-- | Default decay factor: @0.9@.
instance Fractional c => Default (Momentum c) where
    def = Momentum 0.9
-- | Gradient step with a decaying accumulator ("velocity").
--
-- The optimizer state is a single value of type @a@, initialised to
-- 'zeroL'.  On each update the gradient is evaluated and forced, the
-- accumulator is scaled in place by 'momentumDecay', @lr@ times the
-- gradient is added to it in place, and the frozen accumulator is returned
-- paired with the scalar @-1@.
momentum
    :: forall m r a c. LinearInPlace m c a
    => Momentum c        -- ^ decay configuration
    -> c                 -- ^ scalar @lr@ applied to each new gradient
    -> Grad m r a        -- ^ gradient action
    -> Opto m r a
momentum (Momentum decay) lr gr = MkOpto
    { oInit   = zeroL :: a
    , oUpdate = \velRef r x -> do
        !grad <- gr r x
        -- velocity <- decay * velocity + lr * grad, performed in place
        velRef .*=  decay
        velRef .*+= (lr, grad)
        (,) (negate 1) <$> freezeRef velRef
    }
-- | Hyperparameter wrapper consumed by 'nesterov'.
newtype Nesterov c = Nesterov
    { nesterovDecay :: c    -- ^ factor the stored accumulator is scaled by on every step
    }
  deriving (Eq, Show)

-- | Default decay factor: @0.9@.
instance Fractional c => Default (Nesterov c) where
    def = Nesterov 0.9
-- | Like 'momentum', but the gradient is evaluated at a shifted point.
--
-- The optimizer state is a single accumulator of type @a@, initialised to
-- 'zeroL'.  Each update:
--
-- 1. scales the accumulator in place by 'nesterovDecay';
-- 2. evaluates the gradient at @x@ minus that decayed accumulator;
-- 3. adds @lr@ times that gradient to the accumulator in place;
-- 4. returns the frozen accumulator paired with the scalar @-1@.
nesterov
    :: forall m r a c. LinearInPlace m c a
    => Nesterov c       -- ^ decay configuration
    -> c                -- ^ scalar @lr@ applied to each new gradient
    -> Grad m r a       -- ^ gradient action
    -> Opto m r a
nesterov (Nesterov decay) lr gr = MkOpto
    { oInit   = zeroL :: a
    , oUpdate = \velRef r x -> do
        velRef .*= decay
        !decayed <- freezeRef velRef
        -- Point at which the gradient is sampled: x - decayed accumulator.
        let shifted = x .+. (negate 1 .* decayed)
        !grad <- gr r shifted
        velRef .*+= (lr, grad)
        !vel <- freezeRef velRef
        return (negate 1, vel)
    }
-- | Hyperparameters consumed by 'adagrad'.
data Adagrad c = Adagrad
    { adagradRate :: c  -- ^ scalar returned (negated) as the first component of every update
    , adagradEps  :: c  -- ^ constant added to the accumulator before the square root is taken
    }
  deriving (Eq, Show)

-- | Defaults: rate @0.01@, epsilon @1e-8@.
instance Fractional c => Default (Adagrad c) where
    def = Adagrad 0.01 1e-8
-- | Gradient step scaled by the running sum of squared gradients.
--
-- The optimizer state is one accumulator of type @a@, initialised to
-- 'zeroL'.  Each update adds @g ** 2@ to the accumulator in place and
-- returns @(negate rate, g / sqrt (accumulator + eps))@.  The accumulator
-- is never decayed inside this function.
adagrad
    :: forall m r a c.
     ( LinearInPlace m c a
     , Floating a
     , Real c
     )
    => Adagrad c        -- ^ rate and epsilon
    -> Grad m r a       -- ^ gradient action
    -> Opto m r a
adagrad (Adagrad rate epsC) gr = MkOpto
    { oInit   = zeroL :: a
    , oUpdate = \sumSqRef r x -> do
        !grad <- gr r x
        sumSqRef .+.= (grad ** 2)
        !sumSq <- freezeRef sumSqRef
        let scaled = grad / sqrt (sumSq + epsA)
        return (negate rate, scaled)
    }
  where
    -- Epsilon converted once from the scalar type @c@ into @a@.
    epsA = realToFrac epsC :: a
-- | Hyperparameters consumed by 'adadelta'.
data Adadelta c = Adadelta
    { adadeltaDecay :: c    -- ^ factor both running averages are scaled by on every step
    , adadeltaEps   :: c    -- ^ constant added to each running average before the square root
    }
  deriving (Eq, Show)

-- | Defaults: decay @0.9@, epsilon @1e-8@.
instance Fractional c => Default (Adadelta c) where
    def = Adadelta 0.9 1e-8
-- | Gradient step scaled by the ratio of two running averages.
--
-- The optimizer state is a pair of accumulators of type @a@, both
-- initialised to 'zeroL': a running average of squared steps and a running
-- average of squared gradients.  Each update:
--
-- 1. snapshots the squared-step average /before/ touching it;
-- 2. blends @g ** 2@ into the squared-gradient average
--    (@avg <- decay * avg + (1 - decay) * g ** 2@);
-- 3. forms the step @sqrt (stepAvg + eps) / sqrt (gradAvg + eps) * g@;
-- 4. blends the squared step into the squared-step average the same way;
-- 5. returns the step paired with the scalar @-1@.
adadelta
    :: forall m r a c.
     ( LinearInPlace m c a
     , Floating a
     , Real c
     )
    => Adadelta c       -- ^ decay and epsilon
    -> Grad m r a       -- ^ gradient action
    -> Opto m r a
adadelta (Adadelta decay epsC) gr = MkOpto
    { oInit   = (zeroL, zeroL) :: (a, a)
    , oUpdate = \(stepAvgRef, gradAvgRef) r x -> do
        !grad    <- gr r x
        -- Must be read before stepAvgRef is updated further down.
        !stepAvg <- freezeRef stepAvgRef
        gradAvgRef .*=  decay
        gradAvgRef .*+= (blend, grad ** 2)
        !gradAvg <- freezeRef gradAvgRef
        let step = (sqrt (stepAvg + epsA) / sqrt (gradAvg + epsA)) * grad
        stepAvgRef .*=  decay
        stepAvgRef .*+= (blend, step ** 2)
        return (negate 1, step)
    }
  where
    -- Epsilon converted once from the scalar type @c@ into @a@.
    epsA  = realToFrac epsC :: a
    -- Weight given to each new sample in the running averages.
    blend = 1 - decay
-- | Hyperparameters consumed by 'rmsProp'.
data RMSProp c = RMSProp
    { rmsPropRate  :: c     -- ^ scalar returned (negated) as the first component of every update
    , rmsPropDecay :: c     -- ^ factor the running average is scaled by on every step
    , rmsPropEps   :: c     -- ^ constant added to the running average before the square root
    }
  deriving (Eq, Show)

-- | Defaults: rate @0.001@, decay @0.9@, epsilon @1e-8@.
instance Fractional c => Default (RMSProp c) where
    def = RMSProp 0.001 0.9 1e-8
-- | Gradient step scaled by a running average of squared gradients.
--
-- The optimizer state is one accumulator of type @a@, initialised to
-- 'zeroL'.  Each update blends @g ** 2@ into it
-- (@avg <- decay * avg + (1 - decay) * g ** 2@) and returns
-- @(negate rate, g / sqrt (avg + eps))@.
rmsProp
    :: forall m r a c.
     ( LinearInPlace m c a
     , Floating a
     , Real c
     )
    => RMSProp c        -- ^ rate, decay and epsilon
    -> Grad m r a       -- ^ gradient action
    -> Opto m r a
rmsProp (RMSProp rate decay epsC) gr = MkOpto
    { oInit   = zeroL :: a
    , oUpdate = \avgRef r x -> do
        !grad <- gr r x
        avgRef .*=  decay
        avgRef .*+= (blend, grad ** 2)
        !avg <- freezeRef avgRef
        let scaled = grad / sqrt (avg + epsA)
        return (negate rate, scaled)
    }
  where
    -- Epsilon converted once from the scalar type @c@ into @a@.
    epsA  = realToFrac epsC :: a
    -- Weight given to each new squared gradient in the running average.
    blend = 1 - decay
-- | Hyperparameters consumed by 'adam'.  All fields are strict.
data Adam c = Adam
    { adamStep   :: !c  -- ^ scalar returned (negated) as the first component of every update
    , adamDecay1 :: !c  -- ^ decay factor of the running average of gradients
    , adamDecay2 :: !c  -- ^ decay factor of the running average of squared gradients
    , adamEps    :: !c  -- ^ constant added to the square root in the denominator
    }
  deriving (Eq, Show)

-- | Defaults: step @0.001@, first decay @0.9@, second decay @0.999@,
-- epsilon @1e-8@.
instance Fractional c => Default (Adam c) where
    def = Adam 0.001 0.9 0.999 1e-8
-- | Adaptive moment estimation (Kingma & Ba, Algorithm 1).
--
-- The optimizer state is a triple @(t, m, v)@: a step counter of the
-- scalar type @c@, a running average of gradients, and a running average
-- of squared gradients (both of type @a@, initialised to 'zeroL').
--
-- Each update:
--
-- 1. evaluates and forces the gradient @g@;
-- 2. blends @g@ into @m@ and @g * g@ into @v@
--    (@avg <- decay * avg + (1 - decay) * sample@);
-- 3. increments the counter and reads it back as @t@;
-- 4. bias-corrects both averages by dividing by @1 - decay ** t@;
-- 5. returns @(negate adamStep, mHat / (sqrt vHat + eps))@.
--
-- The counter starts at @0@ so that the first update is corrected with
-- @t = 1@, as in the reference algorithm.  (Starting it at @1@ while
-- incrementing before use would make the first step use @t = 2@ and
-- under-correct the initial averages.)
adam
    :: forall m r a c.
     ( RealFloat c
     , Floating a
     , LinearInPlace m c a
     , Mutable m c
     )
    => Adam c               -- ^ step size, decay factors and epsilon
    -> Grad m r a           -- ^ gradient action
    -> Opto m r a
adam Adam{..} gr = MkOpto
    { oInit   = (0, zeroL, zeroL) :: (c, a, a)
    , oUpdate = \(rT, rM, rV) r x -> do
        !g <- gr r x
        rM .*= adamDecay1
        rV .*= adamDecay2
        rM .*+= (1 - adamDecay1, g)
        rV .*+= (1 - adamDecay2, g * g)
        (!m, !v) <- freezeRef (rM, rV)
        -- Bump the step counter; the incremented value is both stored and
        -- returned, so the very first update sees t == 1.
        !t <- updateRef' rT $ \t0 ->
            let !t1 = t0 + 1
            in  (t1, t1)
        let !mHat = recip (1 - adamDecay1 ** t) .* m
            !vHat = recip (1 - adamDecay2 ** t) .* v
        return ( -adamStep
               , mHat / (sqrt vHat + realToFrac adamEps)
               )
    }
-- | Hyperparameters consumed by 'adaMax'.  All fields are strict.
data AdaMax c = AdaMax
    { adaMaxStep   :: !c    -- ^ base step size
    , adaMaxDecay1 :: !c    -- ^ decay factor of the running average of gradients
    , adaMaxDecay2 :: !c    -- ^ decay factor applied to the running maximum norm
    , adaMaxEps    :: !c    -- ^ small constant (default @1e-8@)
    }
  deriving (Eq, Show)

-- | Defaults: step @0.002@, first decay @0.9@, second decay @0.999@,
-- epsilon @1e-8@.
instance Fractional c => Default (AdaMax c) where
    def = AdaMax 0.002 0.9 0.999 1e-8
-- | Infinity-norm variant of 'adam' (Kingma & Ba, Algorithm 2).
--
-- The optimizer state is a triple @(t, m, u)@: a step counter of the
-- scalar type @c@, a running average of gradients of type @a@
-- (initialised to 'zeroL'), and a scalar running maximum @u@ (initialised
-- to @0@).
--
-- Each update:
--
-- 1. evaluates and forces the gradient @g@;
-- 2. blends @g@ into @m@ (@m <- decay1 * m + (1 - decay1) * g@);
-- 3. increments the counter and reads it back as @t@;
-- 4. sets @u <- max (decay2 * u) (norm_inf g)@;
-- 5. returns
--    @(negate adaMaxStep / ((1 - decay1 ** t) * (u + eps)), m)@.
--
-- The counter starts at @0@ so that the first update is corrected with
-- @t = 1@, as in the reference algorithm.
--
-- 'adaMaxEps' is added to @u@ in the denominator.  Without it, a gradient
-- whose infinity norm is zero on the first step leaves @u == 0@, making
-- the returned scale infinite and the resulting step NaN.
adaMax
    :: forall m r a c.
     ( RealFloat c
     , Metric c a
     , LinearInPlace m c a
     , Mutable m c
     )
    => AdaMax c             -- ^ step size, decay factors and epsilon
    -> Grad m r a           -- ^ gradient action
    -> Opto m r a
adaMax AdaMax{..} gr = MkOpto
    { oInit   = (0, zeroL, 0) :: (c, a, c)
    , oUpdate = \(rT, rM, rU) r x -> do
        !g <- gr r x
        rM .*= adaMaxDecay1
        rM .*+= (1 - adaMaxDecay1, g)
        -- Bump the step counter; the incremented value is both stored and
        -- returned, so the very first update sees t == 1.
        !t <- updateRef' rT $ \t0 ->
            let !t1 = t0 + 1
            in  (t1, t1)
        !m <- freezeRef rM
        !u <- updateRef' rU $ \u0 ->
            let !u1 = max (adaMaxDecay2 * u0) (norm_inf g)
            in  (u1, u1)
        -- Epsilon keeps the denominator away from zero when u == 0.
        return ( -adaMaxStep / ((1 - adaMaxDecay1 ** t) * (u + adaMaxEps))
               , m
               )
    }