{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Backprop.Learn.Regularize (
Regularizer
, Regularize(..)
, l1Reg
, l2Reg
, noReg
, l2RegMetric
, l1RegMetric
, RegularizeMetric(..), NoRegularize(..)
, lassoLinear, ridgeLinear
, grnorm_1, grnorm_2
, glasso, gridge
, addReg
, scaleReg
) where
import Control.Applicative
import Control.Monad.Trans.State
import Data.Ratio
import Data.Semigroup hiding (Any(..))
import Data.Type.Functor.Product
import Data.Type.Tuple
import Data.Vinyl
import GHC.Exts
import GHC.Generics
import GHC.TypeNats
import Generics.OneLiner
import Numeric.Backprop as B
import Numeric.LinearAlgebra.Static.Backprop ()
import Numeric.Opto.Update hiding ((<.>))
import qualified Data.Functor.Contravariant as Co
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Sized as SVG
import qualified Numeric.LinearAlgebra.Static as H
-- | A backpropagatable function taking a parameter value of type @p@ to
-- a scalar 'Double'.
type Regularizer p =
    forall s. Reifies s W => BVar s p -> BVar s Double
-- | Types that can take part in regularization.
--
-- Every method has a generic default (see 'grnorm_1', 'grnorm_2',
-- 'glasso' and 'gridge') that is available whenever @p@ is an 'ADT' whose
-- fields all have 'Regularize' instances.
class Backprop p => Regularize p where
    -- | Scalar used by 'l1Reg' as its forward value.
    rnorm_1 :: p -> Double
    default rnorm_1
        :: (ADT p, Constraints p Regularize)
        => p
        -> Double
    rnorm_1 = grnorm_1

    -- | Scalar used by 'l2Reg' as its forward value.
    rnorm_2 :: p -> Double
    default rnorm_2
        :: (ADT p, Constraints p Regularize)
        => p
        -> Double
    rnorm_2 = grnorm_2

    -- | @'lasso' r p@ is used by 'l1Reg' to build the gradient with
    -- respect to @p@, where @r@ is the scaled incoming gradient.
    lasso :: Double -> p -> p
    default lasso
        :: (ADT p, Constraints p Regularize)
        => Double
        -> p
        -> p
    lasso = glasso

    -- | @'ridge' r p@ is used by 'l2Reg' to build the gradient with
    -- respect to @p@, where @r@ is the scaled incoming gradient.
    ridge :: Double -> p -> p
    default ridge
        :: (ADT p, Constraints p Regularize)
        => Double
        -> p
        -> p
    ridge = gridge
-- | Generic implementation of 'rnorm_1': applies 'rnorm_1' to every field
-- of the value and adds the results together.
grnorm_1 :: (ADT p, Constraints p Regularize) => p -> Double
grnorm_1 p = getSum (gfoldMap @Regularize (\field -> Sum (rnorm_1 field)) p)
-- | Generic implementation of 'rnorm_2': applies 'rnorm_2' to every field
-- of the value and adds the results together.
grnorm_2 :: (ADT p, Constraints p Regularize) => p -> Double
grnorm_2 p = getSum (gfoldMap @Regularize (\field -> Sum (rnorm_2 field)) p)
-- | Generic implementation of 'lasso': applies @'lasso' r@ to every field
-- of the value.
glasso :: (ADT p, Constraints p Regularize) => Double -> p -> p
glasso r p = gmap @Regularize (lasso r) p
-- | Generic implementation of 'ridge': applies @'ridge' r@ to every field
-- of the value.
gridge :: (ADT p, Constraints p Regularize) => Double -> p -> p
gridge r p = gmap @Regularize (ridge r) p
-- | L1 regularizer built from a type's 'Regularize' instance.
--
-- The forward value is @λ * 'rnorm_1' x@; the gradient for an incoming
-- scalar gradient @d@ is @'lasso' (d * λ) x@.
l1Reg :: Regularize p => Double -> Regularizer p
l1Reg λ = liftOp1 (op1 withGrad)
  where
    -- Forward value paired with the function that maps the downstream
    -- gradient to the gradient on the parameters.
    withGrad x =
        ( λ * rnorm_1 x
        , \d -> lasso (d * λ) x
        )
-- | L2 regularizer built from a type's 'Regularize' instance.
--
-- The forward value is @λ * 'rnorm_2' x@; the gradient for an incoming
-- scalar gradient @d@ is @'ridge' (d * λ) x@.
--
-- NOTE(review): for the 'Double' instance in this module
-- (@rnorm_2 = (** 2)@, @ridge = (*)@) this gradient is half the analytic
-- derivative of the forward value -- confirm whether a ½ convention is
-- intended.
l2Reg :: Regularize p => Double -> Regularizer p
l2Reg λ = liftOp1 (op1 withGrad)
  where
    -- Forward value paired with the function that maps the downstream
    -- gradient to the gradient on the parameters.
    withGrad x =
        ( λ * rnorm_2 x
        , \d -> ridge (d * λ) x
        )
-- | The regularizer that ignores its input and always yields a constant
-- zero.
noReg :: Regularizer p
noReg = const (auto 0)
-- | An implementation of 'lasso' for types that are both 'Linear' and
-- 'Num': scales @'signum' x@ by @r@.
lassoLinear :: (Linear Double p, Num p) => Double -> p -> p
lassoLinear r x = r .* signum x
-- | An implementation of 'ridge' for 'Linear' types: scales @x@ by @r@.
ridgeLinear :: Linear Double p => Double -> p -> p
ridgeLinear r x = r .* x
-- | L2 regularizer for any 'Metric' type, using 'quadrance' directly
-- instead of a 'Regularize' instance.
--
-- The forward value is @λ * 'quadrance' x@; the gradient for an incoming
-- scalar gradient @d@ is @(d * λ) '.*' x@.
--
-- NOTE(review): assuming 'quadrance' is the sum of squared components,
-- this gradient is half the analytic derivative of the forward value --
-- confirm whether a ½ convention is intended.
l2RegMetric
    :: (Metric Double p, Backprop p)
    => Double
    -> Regularizer p
l2RegMetric λ = liftOp1 (op1 withGrad)
  where
    withGrad x =
        ( λ * quadrance x
        , \d -> (d * λ) .* x
        )
-- | L1 regularizer for any 'Metric' type that is also 'Num', using
-- 'norm_1' directly instead of a 'Regularize' instance.
--
-- The forward value is @λ * 'norm_1' x@; the gradient for an incoming
-- scalar gradient @d@ is @(d * λ) '.*' 'signum' x@.
l1RegMetric
    :: (Num p, Metric Double p, Backprop p)
    => Double
    -> Regularizer p
l1RegMetric λ = liftOp1 (op1 withGrad)
  where
    withGrad x =
        ( λ * norm_1 x
        , \d -> (d * λ) .* signum x
        )
-- | Combine two regularizers by adding their results on the same input.
addReg :: Regularizer p -> Regularizer p -> Regularizer p
addReg regA regB v = fromA + fromB
  where
    fromA = regA v
    fromB = regB v
-- | Multiply the result of a regularizer by a constant factor.
scaleReg :: Double -> Regularizer p -> Regularizer p
scaleReg λ reg v = reg v * auto λ
-- | Newtype wrapper whose 'Regularize' instance is built from the wrapped
-- type's 'Metric' and 'Num' instances.  This module uses it with
-- @deriving via@ for 'H.R' and 'H.L'.
newtype RegularizeMetric a = RegularizeMetric a
    deriving
        ( Show
        , Eq
        , Ord
        , Read
        , Generic
        , Functor
        , Backprop
        )
-- | Regularizes through the wrapped type's 'Metric' instance: 'rnorm_1' is
-- 'norm_1', 'rnorm_2' is 'quadrance', and the gradients come from
-- 'lassoLinear' and 'ridgeLinear'.
instance (Metric Double p, Num p, Backprop p) => Regularize (RegularizeMetric p) where
    rnorm_1 = coerce $ norm_1 @_ @p
    -- Uses 'quadrance' (not 'norm_2') so that 'l2Reg' on this wrapper has
    -- the same forward value as 'l2RegMetric', which pairs 'quadrance'
    -- with the same linear gradient, and so that it agrees with the scalar
    -- instances, whose 'rnorm_2' is @(** 2)@.
    rnorm_2 = coerce $ quadrance @_ @p
    lasso   = coerce $ lassoLinear @p
    ridge   = coerce $ ridgeLinear @p
-- | Newtype wrapper whose 'Regularize' instance contributes nothing: both
-- norms are zero and both gradients are 'B.zero'.
newtype NoRegularize a = NoRegularize a
    deriving
        ( Show
        , Eq
        , Ord
        , Read
        , Generic
        , Functor
        , Backprop
        )
-- | Nothing inside a 'NoRegularize' is regularized: the norms are always
-- zero and the gradients are produced by 'B.zero'.
instance Backprop a => Regularize (NoRegularize a) where
    rnorm_1 = const 0
    rnorm_2 = const 0
    lasso _ x = B.zero x
    ridge _ x = B.zero x
-- | A single 'Double' weight.
instance Regularize Double where
    -- Absolute value, not the raw value: a negative weight must still
    -- contribute a non-negative amount, and 'lasso' below (@r * signum x@)
    -- is the gradient of @abs@.
    rnorm_1 = abs
    rnorm_2 = (** 2)
    lasso r = (r *) . signum
    ridge   = (*)
-- | A single 'Float' weight; norms are reported as 'Double'.
instance Regularize Float where
    -- Absolute value, not the raw value: a negative weight must still
    -- contribute a non-negative amount, and 'lasso' below is the gradient
    -- of @abs@.
    rnorm_1 = abs . realToFrac
    rnorm_2 = (** 2) . realToFrac
    lasso r = (realToFrac r *) . signum
    ridge   = (*) . realToFrac
-- | A single 'Ratio' weight; norms are reported as 'Double'.
instance Integral a => Regularize (Ratio a) where
    -- Absolute value, not the raw value: a negative weight must still
    -- contribute a non-negative amount, and 'lasso' below is the gradient
    -- of @abs@.
    rnorm_1 = abs . realToFrac
    rnorm_2 = (** 2) . realToFrac
    lasso r = (realToFrac r *) . signum
    ridge   = (*) . realToFrac
-- | The unit type has nothing to regularize: norms are zero and the
-- gradients are @()@.
instance Regularize () where
    rnorm_1 = const 0
    rnorm_2 = const 0
    lasso _ = const ()
    ridge _ = const ()
-- The instances below have no method bodies, so they all use the class's
-- generic defaults ('grnorm_1', 'grnorm_2', 'glasso', 'gridge').

instance
    ( Regularize a
    , Regularize b
    ) => Regularize (a, b)

instance
    ( Regularize a
    , Regularize b
    , Regularize c
    ) => Regularize (a, b, c)

instance
    ( Regularize a
    , Regularize b
    , Regularize c
    , Regularize d
    ) => Regularize (a, b, c, d)

instance
    ( Regularize a
    , Regularize b
    , Regularize c
    , Regularize d
    , Regularize e
    ) => Regularize (a, b, c, d, e)

instance
    ( Regularize a
    , Regularize b
    ) => Regularize (a :# b)

instance Regularize a => Regularize (TF a)
-- | Regularizes a 'Rec' of 'TF'-wrapped fields field by field, using each
-- field's own 'Regularize' instance.
instance (RPureConstrained Regularize as, ReifyConstraint Backprop TF as, RMap as, RApply as, RFoldMap as) => Regularize (Rec TF as) where
    -- Build a record holding each field type's 'rnorm_1' (wrapped in
    -- 'Co.Op'), apply it to the matching field through 'coerce', then fold
    -- the per-field results together and unwrap the 'Sum'.
    rnorm_1 = getSum
            . rfoldMap getConst
            . rzipWith coerce (rpureConstrained @Regularize (Co.Op rnorm_1))
    -- Same shape as 'rnorm_1', but with each field's 'rnorm_2'.
    rnorm_2 = getSum
            . rfoldMap getConst
            . rzipWith coerce (rpureConstrained @Regularize (Co.Op rnorm_2))
    -- Build a record holding each field type's @'lasso' r@ (wrapped in
    -- 'Endo') and apply it to the matching field through 'coerce'.
    lasso r = rzipWith coerce (rpureConstrained @Regularize (Endo (lasso r)))
    -- Same shape as 'lasso', but with each field's @'ridge' r@.
    ridge r = rzipWith coerce (rpureConstrained @Regularize (Endo (ridge r)))
-- | Regularizes a 'PMaybe' of a 'TF'-wrapped value using the wrapped type's
-- own 'Regularize' instance.
instance (PureProdC Maybe Backprop as, PureProdC Maybe Regularize as) => Regularize (PMaybe TF as) where
    -- Build a product holding the item type's 'rnorm_1' (wrapped in
    -- 'Co.Op'), apply it to the matching item through 'coerce', then fold
    -- the results together and unwrap the 'Sum'.
    rnorm_1 = getSum
            . foldMapProd getConst
            . zipWithProd coerce (pureProdC @_ @Regularize (Co.Op rnorm_1))
    -- Same shape as 'rnorm_1', but with the item's 'rnorm_2'.
    rnorm_2 = getSum
            . foldMapProd getConst
            . zipWithProd coerce (pureProdC @_ @Regularize (Co.Op rnorm_2))
    -- Build a product holding the item type's @'lasso' r@ (wrapped in
    -- 'Endo') and apply it to the matching item through 'coerce'.
    lasso r = zipWithProd coerce (pureProdC @_ @Regularize (Endo (lasso r)))
    -- Same shape as 'lasso', but with the item's @'ridge' r@.
    ridge r = zipWithProd coerce (pureProdC @_ @Regularize (Endo (ridge r)))
-- | Regularizes a sized vector element by element: the norms are the sums
-- of the element norms, and 'lasso' / 'ridge' are mapped over the
-- elements.
instance (VG.Vector v a, Regularize a, Backprop (SVG.Vector v n a)) => Regularize (SVG.Vector v n a) where
    -- Strict left fold: the accumulator is forced at every step, so no
    -- chain of suspended additions builds up on long vectors.  Elements
    -- are added in index order, each as @elementNorm + accumulator@.
    rnorm_1 = SVG.foldl' (\acc x -> rnorm_1 x + acc) 0
    rnorm_2 = SVG.foldl' (\acc x -> rnorm_2 x + acc) 0
    lasso r = SVG.map (lasso r)
    ridge r = SVG.map (ridge r)
-- | Static vectors reuse the 'RegularizeMetric' instance.
deriving via (RegularizeMetric (H.R n))
    instance KnownNat n => Regularize (H.R n)

-- | Static matrices reuse the 'RegularizeMetric' instance.
deriving via (RegularizeMetric (H.L n m))
    instance (KnownNat n, KnownNat m) => Regularize (H.L n m)