{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
module Backprop.Learn.Model.Neural (
FCp, fc, fca
, fcWeights, fcBias
, fcr, fcra
, FCRp, fcrBias, fcrInputWeights, fcrStateWeights
) where
import Backprop.Learn.Model.Combinator
import Backprop.Learn.Model.Regression
import Backprop.Learn.Model.State
import Backprop.Learn.Model.Types
import Data.Tuple
import GHC.TypeNats
import Lens.Micro
import Numeric.Backprop
import Numeric.LinearAlgebra.Static.Backprop
import qualified Numeric.LinearAlgebra.Static as H
type FCp = LRp
fcWeights :: Lens (FCp i o) (FCp i' o) (L o i) (L o i')
fcWeights = lrBeta
fcBias :: forall i o. Lens' (FCp i o) (R o)
fcBias = lrAlpha
fc :: (KnownNat i, KnownNat o)
=> Model ('Just (FCp i o)) 'Nothing (R i) (R o)
fc = linReg
fca :: (KnownNat i, KnownNat o)
=> (forall z. Reifies z W => BVar z (R o) -> BVar z (R o))
-> Model ('Just (FCp i o)) 'Nothing (R i) (R o)
fca f = funcD f <~ linReg
fcr :: (KnownNat i, KnownNat o, KnownNat s)
=> (forall z. Reifies z W => BVar z (R o) -> BVar z (R s))
-> Model ('Just (FCRp s i o)) ('Just (R s)) (R i) (R o)
fcr s = recurrent H.split (H.#) s fc
fcra
:: (KnownNat i, KnownNat o, KnownNat s)
=> (forall z. Reifies z W => BVar z (R o) -> BVar z (R o))
-> (forall z. Reifies z W => BVar z (R o) -> BVar z (R s))
-> Model ('Just (FCRp s i o)) ('Just (R s)) (R i) (R o)
fcra f s = funcD f <~ recurrent H.split (H.#) s fc
type FCRp s i o = FCp (i + s) o
lensIso :: (s -> (a, x)) -> ((b, x) -> t) -> Lens s t a b
lensIso f g h x = g <$> _1 h (f x)
fcrInputWeights
:: (KnownNat s, KnownNat i, KnownNat i', KnownNat o)
=> Lens (FCRp s i o) (FCRp s i' o) (L o i) (L o i')
fcrInputWeights = fcWeights
. lensIso H.splitCols (uncurry (H.|||))
fcrStateWeights
:: (KnownNat s, KnownNat s', KnownNat i, KnownNat o)
=> Lens (FCRp s i o) (FCRp s' i o) (L o s) (L o s')
fcrStateWeights = fcWeights
. lensIso (swap . H.splitCols) (uncurry (H.|||) . swap)
fcrBias :: forall s i o. Lens' (FCRp s i o) (R o)
fcrBias = fcBias @(i + s) @o