{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeInType          #-}
{-# LANGUAGE TypeOperators       #-}

module TensorOps.Learn.NeuralNet.Recurrent
    ( Network
    , buildNet
    , netParams
    , runNetwork
    , runNetworkSt
    , genNet
    , fullyConnected
    , ffLayer
    , stateless
    , (~*~)
    , (*~)
    , (~*)
    , nmap
    , trainNetwork
    , trainNetwork'
    , networkGradient
    ) where

import           Control.Category
import           Control.DeepSeq
import           Control.Monad.Primitive
import           Control.Monad.State
import           Data.Kind
import           Data.Proxy
import           Data.Singletons
import           Data.Singletons.Prelude        (Sing(..))
import           Data.Type.Combinator
import           Data.Type.Conjunction
import           Data.Type.Length               as TCL
import           Data.Type.Length.Util          as TCL
import           Data.Type.Nat
import           Data.Type.Product              as TCP
import           Data.Type.Product.Util         as TCP
import           Data.Type.Sing
import           Data.Type.Uniform
import           Data.Type.Vector               as TCV
import           Data.Type.Vector.Util          as TCV
import           Prelude hiding                 ((.), id)
import           Statistics.Distribution.Normal
import           System.Random.MWC
import           TensorOps.Learn.NeuralNet
import           TensorOps.NatKind
import           TensorOps.TOp                  as TO
import           TensorOps.Tensor               as TT
import           TensorOps.Types
import           Type.Class.Higher
import           Type.Class.Higher.Util
import           Type.Class.Known
import           Type.Class.Witness
import           Type.Family.List
import           Type.Family.List.Util
import           Type.Family.Nat
import           Unsafe.Coerce
import qualified TensorOps.Learn.NeuralNet.FeedForward as FF

-- | A recurrent network from inputs of shape @'[i]@ to outputs of shape
-- @'[o]@, over tensor backend @t@.  The shapes of its hidden states
-- (@ss@) and of its parameters (@ps@) are existentially hidden.
data Network :: ([k] -> Type) -> k -> k -> Type where
    N :: { _nsSs    :: !(Sing ss)
         , _nsPs    :: !(Sing ps)
         , _nOp     :: !(TOp ('[i] ': ss ++ ps) ('[o] ': ss))
         , _nState  :: !(Prod t ss)
         , _nParams :: !(Prod t ps)
         }
      -> Network t i o

instance NFData1 t => NFData (Network t i o) where
    rnf = \case
      N _ _ o s p -> o `seq` s `deepseq1` p `deepseq1` ()
    {-# INLINE rnf #-}

-- | Continuation-style accessor for a network's hidden state and
-- parameters, recovering the existential shapes as 'SingI' constraints.
netParams
    :: Network t i o
    -> (forall ss ps. (SingI ss, SingI ps) => Prod t ss -> Prod t ps -> r)
    -> r
netParams = \case
    N sS sP _ s p -> \f -> f s p \\ sS \\ sP

buildNet
    :: forall ss ps i o t. (SingI ss, SingI ps)
    => TOp ('[i] ': ss ++ ps) ('[o] ': ss)  -- ^ tensor operation for network
    -> Prod t ss                            -- ^ initial states
    -> Prod t ps                            -- ^ network parameters
    -> Network t i o
buildNet = N sing sing

-- | A fully-connected recurrent layer, with an input-to-output weight
-- matrix, a state-to-state weight matrix, and a bias, all initialized
-- with normally distributed entries.
fullyConnected
    :: forall k (i :: k) (o :: k) (m :: Type -> Type) (t :: [k] -> Type).
       ( SingI i
       , SingI o
       , PrimMonad m
       , Tensor t
       )
    => Activation k         -- ^ activation function for internal state
    -> Gen (PrimState m)
    -> m (Network t i o)
fullyConnected a g =
    (\s w w' b -> buildNet @'[ '[o] ] @'[ '[o,o], '[o,i], '[o] ]
                    fc (s :< Ø) (w' :< w :< b :< Ø)
    )
      <$> genRand (normalDistr 0 0.5) g     -- initial state  s  :: t '[o]
      <*> genRand (normalDistr 0 0.5) g     -- input weights  w  :: t '[o,i]
      <*> genRand (normalDistr 0 0.5) g     -- state weights  w' :: t '[o,o]
      <*> genRand (normalDistr 0 0.5) g     -- bias           b  :: t '[o]
  where
    fc :: TOp '[ '[i], '[o], '[o,o], '[o,i], '[o] ] '[ '[o], '[o] ]
    fc = secondOp @'[ '[i] ] ( firstOp @'[ '[o,i], '[o] ] (TO.swap >>> TO.matVec)
                           >>> firstOp @'[ '[o] ] TO.swap
                             )
     >>> firstOp @'[ '[o], '[o] ] (TO.swap >>> TO.matVec)
     >>> TO.add3
     >>> TO.duplicate
     >>> secondOp @'[ '[o] ] (getAct a)
    {-# INLINE fc #-}
{-# INLINE fullyConnected #-}
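-- A worked sketch of the arithmetic encoded in @fc@ above (my reading of
-- the op pipeline; with w :: t '[o,i], w' :: t '[o,o], b :: t '[o]):
--
-- >   y_t = w · x_t + w' · s_{t-1} + b     -- emitted output (pre-activation)
-- >   s_t = σ y_t                          -- new internal state
--
-- where σ is the 'Activation' given to 'fullyConnected'.  Note that the
-- emitted output is left unactivated; output activations are composed on
-- externally with '*~' (as 'genNet' does).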
-- | Convert a neural network from "TensorOps.Learn.NeuralNet.FeedForward"
-- to a stateless 'Network'.
--
-- Can be thought of as a functor from the category of stateless neural
-- networks to the category of stateful (recurrent) ones.
stateless :: FF.Network t i o -> Network t i o
stateless = \case
    FF.N sP o p -> N SNil sP o Ø p
{-# INLINE stateless #-}

-- | A random (stateless) fully-connected feed-forward layer, lifted into
-- a recurrent 'Network'.
ffLayer
    :: forall i o m t. (SingI i, SingI o, PrimMonad m, Tensor t)
    => Gen (PrimState m)
    -> m (Network t i o)
ffLayer g = stateless <$> FF.ffLayer g
{-# INLINE ffLayer #-}

-- | Generate a random network from a list of hidden layer sizes, each
-- with its output activation and an optional state activation: a 'Just'
-- produces a recurrent ('fullyConnected') layer, a 'Nothing' a
-- feed-forward ('ffLayer') one.
genNet
    :: forall k o i m (t :: [k] -> Type).
       (SingI o, SingI i, PrimMonad m, Tensor t)
    => [(Integer, (Activation k, Maybe (Activation k)))]    -- ^ hidden layers
    -> Activation k             -- ^ output activation
    -> Maybe (Activation k)     -- ^ output state activation, if recurrent
    -> Gen (PrimState m)
    -> m (Network t i o)
genNet xs0 f fS g = go sing xs0
  where
    go  :: forall (j :: k). ()
        => Sing j
        -> [(Integer, (Activation k, Maybe (Activation k)))]
        -> m (Network t j o)
    go sj = (\\ sj) $ \case
      [] -> fmap (*~ getAct f) $ case fS of
        Just fS' -> fullyConnected fS' g
        Nothing  -> ffLayer g
      (x,(f',fS')) : xs -> withNatKind x $ \sl -> (\\ sl) $ do
        n <- go sl xs
        l <- case fS' of
          Nothing   -> ffLayer g
          Just fS'' -> fullyConnected fS'' g
        return $ l *~ getAct f' ~*~ n
    {-# INLINE go #-}
{-# INLINE genNet #-}

instance Category (Network t) where
    id  = N SNil SNil idOp Ø Ø
    (.) = flip (~*~)

-- | Compose two networks end to end, concatenating their hidden states
-- and parameters.
(~*~)
    :: forall k (t :: [k] -> Type) a b c. ()
    => Network t a b
    -> Network t b c
    -> Network t a c
(~*~) = \case
    N (sSs1 :: Sing ss1)
      (sPs1 :: Sing ps1)
      (o1   :: TOp ('[a] ': ss1 ++ ps1) ('[b] ': ss1))
      (s1   :: Prod t ss1)
      (p1   :: Prod t ps1) -> \case
      N (sSs2 :: Sing ss2)
        (sPs2 :: Sing ps2)
        (o2   :: TOp ('[b] ': ss2 ++ ps2) ('[c] ': ss2))
        (s2   :: Prod t ss2)
        (p2   :: Prod t ps2) ->
        let lSs1 :: Length ss1
            lSs1 = singLength sSs1
            lSs2 :: Length ss2
            lSs2 = singLength sSs2
            lPs1 :: Length ps1
            lPs1 = singLength sPs1
            lPs2 :: Length ps2
            lPs2 = singLength sPs2
            o :: TOp ('[a] ': (ss2 ++ ss1) ++ (ps1 ++ ps2)) ('[c] ': ss2 ++ ss1)
            -- all these proofs lol
            o = (\\ (unsafeCoerce Refl :: ((ss2 ++ ss1) ++ ps1) ++ ps2 :~: (ss2 ++ ss1) ++ (ps1 ++ ps2))) $
                (\\ (unsafeCoerce Refl :: ((ss1 ++ ps1) ++ ss2) ++ ps2 :~: (ss1 ++ ps1) ++ (ss2 ++ ps2))) $
                (\\ appendAssoc lSs1 lSs2 lPs2) $
                (\\ appendAssoc lSs2 lSs1 lPs1) $
                (\\ (lSs2 `TCL.append'` lSs1) `TCL.append'` lPs1) $
                (\\ (lSs1 `TCL.append'` lPs1) `TCL.append'` lSs2) $
                (\\ lSs1) $
                (\\ lSs2) $
                (\\ lSs1 `TCL.append'` lPs1) $
                (\\ lSs2 `TCL.append'` lPs2) $
                    secondOp @'[ '[a] ] (firstOp @ps2 (TO.swap' lSs2 (lSs1 `TCL.append'` lPs1)))
                >>> firstOp @(ss2 ++ ps2) o1
                >>> secondOp @'[ '[b] ] (TO.swap' lSs1 (lSs2 `TCL.append'` lPs2))
                >>> firstOp @ss1 o2
        in  N (sSs2 %:++ sSs1)
              (sPs1 %:++ sPs2)
              o
              (s2 `TCP.append'` s1)
              (p1 `TCP.append'` p2)
infixr 4 ~*~
{-# INLINE (~*~) #-}

-- | Run a network on an input, returning the output together with the
-- network carrying its updated hidden state.
runNetwork
    :: (RealFloat (ElemT t), Tensor t)
    => Network t i o
    -> t '[i]
    -> (t '[o], Network t i o)
runNetwork (N sS sP o s p) x =
      (\case y :< s' -> (y, N sS sP o s' p))
    . runTOp o
    $ x :< (s `TCP.append'` p)
{-# INLINE runNetwork #-}

-- | Run a network on an input inside a 'MonadState' carrying the network
-- itself, so the updated hidden state is threaded automatically.
runNetworkSt
    :: (RealFloat (ElemT t), Tensor t, MonadState (Network t i o) m)
    => t '[i]
    -> m (t '[o])
runNetworkSt x = state $ flip runNetwork x

-- | Precompose a pure tensor operation onto a network's input.
(~*)
    :: TOp '[ '[a] ] '[ '[b] ]
    -> Network t b c
    -> Network t a c
f ~* N sS sP o s p = N sS sP (f *>> o) s p
infixr 4 ~*
{-# INLINE (~*) #-}

-- | Postcompose a pure tensor operation onto a network's output.
(*~)
    :: Network t a b
    -> TOp '[ '[b] ] '[ '[c] ]
    -> Network t a c
N sS sP o s p *~ f = N sS sP (o >>> firstOp f) s p
infixl 5 *~
{-# INLINE (*~) #-}

-- | Map a differentiable function over every element of a network's
-- output.
nmap
    :: SingI o
    => (forall a. RealFloat a => a -> a)
    -> Network t i o
    -> Network t i o
nmap f n = n *~ TO.map f
{-# INLINE nmap #-}
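-- A usage sketch (illustrative, not part of the original module): run a
-- network over an entire input sequence by threading the evolving hidden
-- state through the 'State' monad, collecting every output.
--
-- > runSequence
-- >     :: (RealFloat (ElemT t), Tensor t)
-- >     => Network t i o
-- >     -> [t '[i]]
-- >     -> ([t '[o]], Network t i o)
-- > runSequence n0 = flip runState n0 . traverse runNetworkSt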
-- | Gradient of the loss with respect to the inputs, the initial hidden
-- state, and the parameters: the network is unrolled over the full input
-- sequence and differentiated once (backpropagation through time).
netGrad
    :: forall k n (t :: [k] -> Type) (i :: k) (o :: k) ss ps.
       ( Tensor t
       , RealFloat (ElemT t)
       , SingI i
       , SingI o
       )
    => TOp '[ '[o], '[o] ] '[ '[] ]     -- ^ loss function (output, target)
    -> Vec n (t '[i])                   -- ^ inputs
    -> Vec n (t '[o])                   -- ^ targets
    -> Sing ss
    -> Sing ps
    -> TOp ('[i] ': ss ++ ps) ('[o] ': ss)
    -> Prod t ss
    -> Prod t ps
    -> (Vec n (t '[i]), (Prod t ss, Prod t ps))
netGrad loss xs ys sS sP o s p = (prodToVec' I n grI, splitProd @ps lS grSP)
  where
    n  :: Nat n
    n  = known \\ xs
    lS :: Length ss
    lS = singLength sS
    lP :: Length ps
    lP = singLength sP
    sO :: Sing (Replicate n '[o])
    sO = replicateSing (sing :: Sing '[o]) n
    lI :: Length (Replicate n '[i])
    lI = replicateLength @'[i] n
    lO :: Length (Replicate n '[o])
    lO = replicateLength @'[o] n
    -- the network unrolled over n steps, discarding the final hidden state
    unrolled :: TOp (Replicate n '[i] ++ ss ++ ps) (Replicate n '[o])
    unrolled = (\\ sS %:++ sO) $
               (\\ sS %:++ sP) $
          unroll sS sP o n
      >>> TO.drop @(Replicate n '[o]) lS
    -- unrolled network followed by the per-step losses, summed
    o' :: TOp (Replicate n '[i] ++ ss ++ ps ++ Replicate n '[o]) '[ '[] ]
    o' = (\\ appendAssoc lI (lS `TCL.append'` lP) lO) $
         (\\ appendAssoc lS lP lO) $
         (\\ lI `TCL.append'` lS `TCL.append'` lP) $
         (\\ lO) $
             firstOp @(Replicate n '[o]) unrolled
         >>> rollup @k lS lP loss n
    xs' :: Prod t (Replicate n '[i] ++ ss ++ ps ++ Replicate n '[o])
    xs' = vecToProd' getI (TCV.reverse' xs)
            `TCP.append'` (s `TCP.append'` (p `TCP.append'` vecToProd' getI ys))
    grad :: Prod t (Replicate n '[i] ++ ss ++ ps)
    grad = (\\ (unsafeCoerce Refl
                  ::  Replicate n '[i] ++ ss ++ ps ++ Replicate n '[o]
                  :~: (Replicate n '[i] ++ ss ++ ps) ++ Replicate n '[o]
               )
           ) $
        takeProd @(Replicate n '[o]) (lI `TCL.append'` lS `TCL.append'` lP) $
          gradTOp o' xs'
    grI  :: Prod t (Replicate n '[i])
    grSP :: Prod t (ss ++ ps)
    (grI, grSP) = splitProd @(ss ++ ps) lI grad
{-# INLINE netGrad #-}

-- | Train a network on one sequence of inputs and targets, taking a
-- single gradient descent step on both the initial state and the
-- parameters.
trainNetwork'
    :: forall k n (t :: [k] -> Type) (i :: k) (o :: k).
       ( Tensor t
       , RealFloat (ElemT t)
       , SingI i
       , SingI o
       )
    => TOp '[ '[o], '[o] ] '[ '[] ]     -- ^ loss function (output, target)
    -> ElemT t                          -- ^ train rate for initial state
    -> ElemT t                          -- ^ train rate for network parameters
    -> Vec n (t '[i])                   -- ^ inputs
    -> Vec n (t '[o])                   -- ^ targets
    -> Network t i o
    -> Network t i o
trainNetwork' loss rS rP xs ys = \case
    N sS sP o s p ->
      let (gS, gP) = snd $ netGrad loss xs ys sS sP o s p
          s'       = map1 (f rS) $ zipProd3 (singProd sS) s gS
          p'       = map1 (f rP) $ zipProd3 (singProd sP) p gP
      in  N sS sP o s' p'
  where
    -- gradient descent step: x' = x - r * g
    f   :: forall ns. ()
        => ElemT t
        -> (Sing :&: t :&: t) ns
        -> t ns
    f r !(s1 :&: p1 :&: g1) = TT.zip (\(!p2) (!g2) -> p2 - r * g2) p1 g1
                                \\ s1
    {-# INLINE f #-}
{-# INLINE trainNetwork' #-}

-- | Like 'trainNetwork'', but taking the inputs and targets as a list of
-- pairs.
trainNetwork
    :: forall k (t :: [k] -> Type) (i :: k) (o :: k).
       ( Tensor t
       , RealFloat (ElemT t)
       , SingI i
       , SingI o
       )
    => TOp '[ '[o], '[o] ] '[ '[] ]     -- ^ loss function (output, target)
    -> ElemT t                          -- ^ train rate for initial state
    -> ElemT t                          -- ^ train rate for network parameters
    -> [(t '[i], t '[o])]               -- ^ inputs and targets
    -> Network t i o
    -> Network t i o
trainNetwork loss rS rP xsys n = withV xsys $ \v ->
    let (xs, ys) = TCV.unzip' v
    in  trainNetwork' loss rS rP xs ys n

-- | Continuation-style accessor for the gradients computed by 'netGrad':
-- with respect to the inputs, the initial hidden state, and the
-- parameters.
networkGradient
    :: forall k n (t :: [k] -> Type) (i :: k) (o :: k) r.
       ( Tensor t
       , RealFloat (ElemT t)
       , SingI i
       , SingI o
       )
    => TOp '[ '[o], '[o] ] '[ '[] ]     -- ^ loss function (output, target)
    -> Vec n (t '[i])                   -- ^ inputs
    -> Vec n (t '[o])                   -- ^ targets
    -> Network t i o
    -> (forall ss ps. (SingI ss, SingI ps) => Vec n (t '[i]) -> Prod t ss -> Prod t ps -> r)
    -> r
networkGradient loss xs ys = \case
    N sS sP o s p -> \f ->
      case netGrad loss xs ys sS sP o s p of
        (gI, (gS, gP)) -> f gI gS gP \\ sS \\ sP
{-# INLINE networkGradient #-}
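-- An illustrative training-loop sketch (not part of the original module;
-- @someLoss@ stands in for any op of type @TOp '[ '[o], '[o] ] '[ '[] ]@,
-- and the 0.02 rates are arbitrary):
--
-- > trainEpochs
-- >     :: (Tensor t, RealFloat (ElemT t), SingI i, SingI o)
-- >     => TOp '[ '[o], '[o] ] '[ '[] ]     -- someLoss
-- >     -> Int                              -- number of epochs
-- >     -> [(t '[i], t '[o])]               -- training sequence
-- >     -> Network t i o
-- >     -> Network t i o
-- > trainEpochs someLoss e samples n0 =
-- >     iterate (trainNetwork someLoss 0.02 0.02 samples) n0 !! e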
-- | Unroll a single-step network op over @n@ time steps: the unrolled op
-- takes @n@ inputs (laid out newest-first, matching the 'TCV.reverse''
-- in 'netGrad') along with the initial state and parameters, and returns
-- the final state followed by the @n@ outputs.
unroll
    :: forall ss ps i o n. (SingI (ss ++ ps), SingI i)
    => Sing ss
    -> Sing ps
    -> TOp ('[i] ': ss ++ ps) ('[o] ': ss)
    -> Nat n
    -> TOp (Replicate n '[i] ++ ss ++ ps) (ss ++ Replicate n '[o])
unroll sS sP o = \case
    Z_ -> TO.take lS lP
            \\ appendNil lS
    S_ (m :: Nat m) ->
        (\\ (unsafeCoerce Refl
               ::  Replicate m '[i] ++ '[i] ': ss ++ ps
               :~: '[i] ': Replicate m '[i] ++ ss ++ ps
            )
        ) $
        (\\ (unsafeCoerce Refl
               ::  (Replicate m '[i] ++ ss ++ ps) ++ '[ '[o] ]
               :~: Replicate m '[i] ++ ss ++ (ps >: '[o])
            )
        ) $
        (\\ (unsafeCoerce Refl
               ::  (ss ++ Replicate m '[o]) ++ '[ '[o] ]
               :~: ss ++ '[o] ': Replicate m '[o]
            )
        ) $
        (\\ appendAssoc lS lP (LS LZ :: Length '[ '[o] ])) $
        (\\ appendSnoc lP (Proxy @'[o])) $
        (\\ replicateLength @'[i] m `TCL.append'` lS `TCL.append'` lP) $
        (\\ lS `TCL.append'` replicateLength @'[o] m) $
        (\\ replicateLength @'[i] m) $
        (\\ lS) $
            -- run one step on the deepest input, tucking its output
            -- underneath the parameters
            secondOp @(Replicate m '[i]) @('[i] ': ss ++ ps) @(ss ++ ps >: '[o])
                (    (o &&& TO.drop @ps @('[i] ': ss) (LS lS))
                 >>> TO.swap' (LS LZ) (lS `TCL.append'` lP)
                )
            -- then recurse on the remaining m inputs
        >>> firstOp @'[ '[o] ] @(Replicate m '[i] ++ ss ++ ps) @(ss ++ Replicate m '[o])
                (unroll sS sP o m)
  where
    lS :: Length ss
    lS = singLength sS
    lP :: Length ps
    lP = singLength sP
{-# INLINE unroll #-}

-- | Collapse @n@ unrolled outputs and @n@ targets into a single scalar
-- by applying the loss at each step and summing the results.
rollup
    :: forall k (ss :: [[k]]) (ps :: [[k]]) (o :: k) (n :: N). ()
    => Length ss
    -> Length ps
    -> TOp '[ '[o], '[o] ] '[ '[] ]
    -> Nat n
    -> TOp (Replicate n '[o] ++ Replicate n '[o]) '[ '[] ]
rollup lS lP loss = \case
    Z_    -> TO.konst (US UØ) 0
    S_ Z_ -> loss
    S_ (m@(S_ _) :: Nat m) ->
      let lO :: Length (Replicate m '[o])
          lO = replicateLength @'[o] m
      in  (\\ (unsafeCoerce Refl
                 ::  Replicate m '[o] ++ '[o] ': '[o] ': Replicate m '[o]
                 :~: '[o] ': Replicate m '[o] ++ '[o] ': Replicate m '[o]
              )
          ) $
          (\\ appendSnoc lO (Proxy @'[])) $
          (\\ lO) $
          (\\ lO `TCL.append'` lO) $
          (\\ appendAssoc lO lO (LS LZ :: Length '[ '[] ])) $
              -- apply the loss to one output/target pair, tucking the
              -- resulting scalar underneath the remaining outputs
              secondOp @(Replicate m '[o]) @('[o] ': '[o] ': Replicate m '[o]) @(Replicate m '[o] >: '[])
                  (    firstOp @(Replicate m '[o]) @('[ '[o], '[o] ]) @('[ '[] ]) loss
                   >>> TO.swap' @'[ '[] ] @(Replicate m '[o]) (LS LZ) lO
                  )
              -- recurse on the remaining m pairs, then add the two scalars
          >>> firstOp @'[ '[] ] @(Replicate m '[o] ++ Replicate m '[o]) @'[ '[] ]
                  (rollup lS lP loss m)
          >>> TO.add
{-# INLINE rollup #-}
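-- A shape sketch (derived from the types above) for n = 2: 'unroll'
-- turns the single-step op
--
-- >   '[ '[i] ] ++ ss ++ ps            →  '[ '[o] ] ++ ss
--
-- into the two-step op
--
-- >   '[ '[i], '[i] ] ++ ss ++ ps      →  ss ++ '[ '[o], '[o] ]
--
-- and 'rollup' then consumes those two outputs together with the two
-- targets,
--
-- >   '[ '[o], '[o], '[o], '[o] ]      →  '[ '[] ]
--
-- applying the loss per step and summing the scalars with 'TO.add'.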