{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeInType #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} module TensorOps.Learn.NeuralNet.FeedForward ( Network(..) , buildNet , runNetwork , trainNetwork , induceNetwork , nmap , (~*) , (*~) , liftNet , netParams , networkGradient , genNet , ffLayer ) where import Control.Category import Control.DeepSeq import Control.Monad.Primitive import Data.Kind import Data.Singletons import Data.Singletons.Prelude (Sing(..)) import Data.Type.Conjunction import Data.Type.Length import Data.Type.Product as TCP import Data.Type.Product.Util as TCP import Data.Type.Sing import Prelude hiding ((.), id) import Statistics.Distribution.Normal import System.Random.MWC import TensorOps.Learn.NeuralNet import TensorOps.NatKind import TensorOps.TOp as TO import TensorOps.Tensor as TT import TensorOps.Types import Type.Class.Higher import Type.Class.Higher.Util import Type.Class.Known import Type.Class.Witness import Type.Family.List import Type.Family.List.Util data Network :: ([k] -> Type) -> k -> k -> Type where N :: { _nsPs :: !(Sing ps) , _nOp :: !(TOp ('[i] ': ps) '[ '[o] ]) , _nParams :: !(Prod t ps) } -> Network t i o instance NFData1 t => NFData (Network t i o) where rnf = \case N _ o p -> o `seq` p `deepseq1` () {-# INLINE rnf #-} buildNet :: SingI ps => TOp ('[i] ': ps) '[ '[o] ] -> Prod t ps -> Network t i o buildNet = N sing netParams :: Network t i o -> (forall ps. SingI ps => Prod t ps -> r) -> r netParams n f = case n of N o _ p -> f p \\ o (~*~) :: Network t a b -> Network t b c -> Network t a c N sPs1 o1 p1 ~*~ N sPs2 o2 p2 = N (sPs1 %:++ sPs2) (o1 *>> o2) (p1 `TCP.append'` p2) \\ singLength sPs1 infixr 4 ~*~ {-# INLINE (~*~) #-} instance Category (Network t) where id = N SNil idOp Ø (.) = flip (~*~) (~*) :: TOp '[ '[a] ] '[ '[b] ] -> Network t b c -> Network t a c f ~* N sO o p = N sO (f *>> o) p infixr 4 ~* {-# INLINE (~*) #-} (*~) :: Network t a b -> TOp '[ '[b] ] '[ '[c] ] -> Network t a c N sO o p *~ f = N sO (o >>> f) p infixl 5 *~ {-# INLINE (*~) #-} liftNet :: TOp '[ '[i] ] '[ '[o] ] -> Network t i o liftNet o = buildNet o Ø nmap :: SingI o => (forall a. RealFloat a => a -> a) -> Network t i o -> Network t i o nmap f n = n *~ TO.map f {-# INLINE nmap #-} runNetwork :: (RealFloat (ElemT t), Tensor t) => Network t i o -> t '[i] -> t '[o] runNetwork (N _ o p) = head' . runTOp o . (:< p) {-# INLINE runNetwork #-} trainNetwork :: forall i o t. (Tensor t, RealFloat (ElemT t)) => TOp '[ '[o], '[o] ] '[ '[] ] -> ElemT t -> t '[i] -> t '[o] -> Network t i o -> Network t i o trainNetwork loss r x y = \case N s o p -> let p' = map1 (\(!(s1 :&: o1 :&: g1)) -> TT.zip stepFunc o1 g1 \\ s1) $ zipProd3 (singProd s) p (tail' $ netGrad loss x y s o p) in N s o p' where stepFunc :: ElemT t -> ElemT t -> ElemT t stepFunc !o' !g' = o' - r * g' {-# INLINE stepFunc #-} {-# INLINE trainNetwork #-} induceNetwork :: forall i o t. (Tensor t, RealFloat (ElemT t), SingI i) => TOp '[ '[o], '[o] ] '[ '[] ] -> ElemT t -> t '[o] -> Network t i o -> t '[i] -> t '[i] induceNetwork loss r y = \case N s o p -> \x -> TT.zip stepFunc x (head' $ netGrad loss x y s o p) where stepFunc :: ElemT t -> ElemT t -> ElemT t stepFunc o' g' = o' - r * g' {-# INLINE stepFunc #-} {-# INLINE induceNetwork #-} networkGradient :: forall i o t r. (Tensor t, RealFloat (ElemT t)) => TOp '[ '[o], '[o] ] '[ '[] ] -> t '[i] -> t '[o] -> Network t i o -> (forall ps. SingI ps => Prod t ps -> r) -> r networkGradient loss x y = \case N s o p -> \f -> f (tail' $ netGrad loss x y s o p) \\ s {-# INLINE networkGradient #-} netGrad :: forall i o ps t. (Tensor t, RealFloat (ElemT t)) => TOp '[ '[o], '[o] ] '[ '[] ] -> t '[i] -> t '[o] -> Sing ps -> TOp ('[i] ': ps) '[ '[o] ] -> Prod t ps -> Prod t ('[i] ': ps) netGrad loss x y s o p = (\\ appendSnoc lO (Proxy @'[o])) $ (\\ lO ) $ takeProd @'[ '[o] ] (LS lO) $ gradTOp o' inp where lO :: Length ps lO = singLength s o' :: ((ps ++ '[ '[o] ]) ~ (ps >: '[o]), Known Length ps) => TOp ('[i] ': ps >: '[o]) '[ '[]] o' = o *>> loss inp :: Prod t ('[i] ': ps >: '[o]) inp = x :< p >: y {-# INLINE netGrad #-} ffLayer :: forall i o m t. (SingI i, SingI o, PrimMonad m, Tensor t) => Gen (PrimState m) -> m (Network t i o) ffLayer g = (\w b -> buildNet ffLayer' (w :< b :< Ø)) <$> genRand (normalDistr 0 0.5) g <*> genRand (normalDistr 0 0.5) g where ffLayer' :: TOp '[ '[i], '[o,i], '[o]] '[ '[o] ] ffLayer' = firstOp (TO.swap >>> TO.matVec) >>> TO.add {-# INLINE ffLayer' #-} {-# INLINE ffLayer #-} genNet :: forall k o i m (t :: [k] -> Type). (SingI o, SingI i, PrimMonad m, Tensor t) => [(Integer, Activation k)] -> Activation k -> Gen (PrimState m) -> m (Network t i o) genNet xs0 f g = go sing xs0 where go :: forall (j :: k). () => Sing j -> [(Integer, Activation k)] -> m (Network t j o) go sj = (\\ sj) $ \case [] -> (*~ getAct f) <$> ffLayer g (x,f'):xs -> withNatKind x $ \sl -> (\\ sl) $ do n <- go sl xs l <- ffLayer g return $ l *~ getAct f' ~*~ n {-# INLINE go #-} {-# INLINE genNet #-}