{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module TensorOps.Backend.NTensor
( NTensor
, NTensorL
, NTensorV
)
where
import Control.DeepSeq
import Control.Monad.Primitive
import Data.Distributive
import Data.Kind
import Data.List.Util
import Data.Nested
import Data.Proxy
import Data.Singletons
import Data.Singletons.Prelude.List hiding (Length, Reverse)
import Data.Type.Combinator
import Data.Type.Combinator.Util
import Data.Type.Length as TCL
import Data.Type.Length.Util as TCL
import Data.Type.Product as TCP
import Data.Type.Product.Util as TCP
import Data.Type.Sing
import Data.Type.Uniform
import GHC.Generics
import Statistics.Distribution
import System.Random.MWC
import TensorOps.NatKind
import TensorOps.Types
import Type.Class.Higher
import Type.Class.Higher.Util
import Type.Class.Witness
import Type.Family.List
import qualified Data.Type.Vector as TCV
import qualified Data.Type.Vector.Util as TCV
import qualified Data.Vector.Sized as VS
import qualified Type.Family.Nat as TCN
-- | A tensor represented as a 'Nested' structure of @v@ containers.
--
-- The type parameters are, in order: the container family @v@ (indexed by a
-- dimension of kind @k@), the element type, and the type-level list of
-- dimensions.  Note that the wrapped 'Nested' takes the dimension list
-- /before/ the element type, whereas 'NTensor' takes it /after/.
newtype NTensor :: (k -> Type -> Type) -> Type -> [k] -> Type where
    NTensor
        :: { getNVec :: Nested v js a }  -- ^ unwrap to the underlying 'Nested'
        -> NTensor v a js
-- | Standalone-derived 'Num' instance for tensors.
--
-- NOTE(review): 'Num' is not stock-derivable and GeneralizedNewtypeDeriving
-- is enabled, so this presumably reuses the 'Num' instance of the wrapped
-- 'Nested' value -- confirm against "Data.Nested".
deriving instance (Num a, SingI ns, Nesting1 Proxy Functor v, Nesting1 Sing Applicative v)
    => Num (NTensor v a ns)
-- | Standalone-derived 'Generic' instance.
deriving instance Generic (NTensor v a ns)
-- | Standalone-derived 'NFData' instance, available whenever the element type
-- is 'NFData' and @v@ satisfies @Nesting Proxy NFData@.
--
-- NOTE(review): DeriveAnyClass is not among the visible LANGUAGE pragmas, so
-- this presumably goes through GeneralizedNewtypeDeriving -- confirm.
deriving instance (NFData a, Nesting Proxy NFData v) => NFData (NTensor v a ns)
-- | 'Nesting1' witness for 'NFData' on @NTensor v a@, for any witness functor
-- @w@.  The argument is ignored and 'Wit' is returned, so the required
-- evidence is resolved from the instances in scope.
instance (NFData a, Nesting Proxy NFData v) => Nesting1 w NFData (NTensor v a) where
    nesting1 _ = Wit
-- | 'NFData1' instance with no explicit methods; it relies entirely on the
-- class's defaults (the class itself is defined outside this chunk).
instance (NFData a, Nesting Proxy NFData v) => NFData1 (NTensor v a)
-- | Build an 'NTensor' from a pure function of the multi-index.
--
-- This is 'genNested' with its result wrapped in the 'NTensor' constructor.
genNTensor
    :: forall k (v :: k -> Type -> Type) (ns :: [k]) (a :: Type). Vec v
    => Sing ns                      -- ^ singleton for the dimension list
    -> (Prod (IndexN k) ns -> a)    -- ^ element to place at each index
    -> NTensor v a ns
genNTensor dims mkElem = NTensor (genNested dims mkElem)
{-# INLINE genNTensor #-}
-- | Applicative variant of 'genNTensor': every element is produced by an
-- action in @f@.
--
-- This is 'genNestedA' with the 'NTensor' constructor mapped over its result.
genNTensorA
    :: forall k (v :: k -> Type -> Type) (ns :: [k]) (a :: Type) f. (Applicative f, Vec v)
    => Sing ns                        -- ^ singleton for the dimension list
    -> (Prod (IndexN k) ns -> f a)    -- ^ action producing the element at each index
    -> f (NTensor v a ns)
genNTensorA dims mkElemA = fmap NTensor (genNestedA dims mkElemA)
{-# INLINE genNTensorA #-}
-- | Look up the element at a multi-index.
--
-- Unwraps the tensor and delegates to 'indexNested'.
indexNTensor
    :: forall k (v :: k -> Type -> Type) (ns :: [k]) (a :: Type). Vec v
    => Prod (IndexN k) ns    -- ^ one index per dimension
    -> NTensor v a ns        -- ^ tensor to read from
    -> a
indexNTensor ix = \(NTensor xs) -> indexNested ix xs
{-# INLINE indexNTensor #-}
-- | Lift a binary function on 'Nested' values to one on 'NTensor's: both
-- arguments are unwrapped, combined, and the result is wrapped again.
overNVec2
    :: (Nested v ns a -> Nested w ms b -> Nested u os c)
    -> NTensor v a ns
    -> NTensor w b ms
    -> NTensor u c os
overNVec2 op (NTensor xs) (NTensor ys) = NTensor (op xs ys)
{-# INLINE overNVec2 #-}
-- | Run a functorial function on the underlying 'Nested' value of a tensor,
-- wrapping the result back into an 'NTensor' inside the functor.
ntNVec
    :: Functor f
    => (Nested v ns a -> f (Nested w ms b))
    -> NTensor v a ns
    -> f (NTensor w b ms)
ntNVec f = \(NTensor xs) -> NTensor <$> f xs
{-# INLINE ntNVec #-}
-- | The reverse direction of 'ntNVec': run a functorial function on
-- 'NTensor's against a bare 'Nested' value, unwrapping the result inside the
-- functor.
nvecNT
    :: Functor f
    => (NTensor v a ns -> f (NTensor w b ms))
    -> Nested v ns a
    -> f (Nested w ms b)
nvecNT f = \xs -> getNVec <$> f (NTensor xs)
{-# INLINE nvecNT #-}
-- | Lift a function on 'Nested' values to one on 'NTensor's: unwrap, apply,
-- wrap again.
overNVec
    :: (Nested v ns a -> Nested v ms a)
    -> NTensor v a ns
    -> NTensor v a ms
overNVec f = NTensor . f . getNVec
{-# INLINE overNVec #-}
-- | Lower a function on 'NTensor's to one on the underlying 'Nested' values:
-- wrap the argument, apply the function, unwrap the result.
--
-- Marked @INLINE@ like every other wrapper in this family ('overNVec',
-- 'overNVec2', 'ntNVec', 'nvecNT'); it is called from the @INLINE@d
-- 'mapRows' method.
underNVec
    :: (NTensor v a ns -> NTensor v a ms)
    -> Nested v ns a
    -> Nested v ms a
underNVec f = getNVec . f . NTensor
{-# INLINE underNVec #-}
-- | 'Tensor' backend for 'NTensor'.
--
-- Every method follows the same shape: unwrap to the underlying 'Nested'
-- value (via 'overNVec', 'overNVec2', 'ntNVec', ...), delegate to a helper
-- that works on 'Nested', and wrap the result again.  Where a helper needs
-- extra class evidence, it is supplied locally with '\\'.
instance
    ( Vec (v :: k -> Type -> Type)
    , RealFloat a
    , Nesting1 Proxy Functor     v
    , Nesting1 Sing  Applicative v
    , Nesting1 Proxy Foldable    v
    , Nesting1 Proxy Traversable v
    , Nesting1 Sing  Distributive v
    , Eq1 (IndexN k)
    ) => Tensor (NTensor v a) where
    -- The element type is the tensor's element parameter.
    type ElemT (NTensor v a) = a

    -- Unwrap each tensor in the input vector, combine them with
    -- 'liftNested', and wrap the result.
    liftT
        :: forall (n :: TCN.N) (o :: [k]). SingI o
        => (TCV.Vec n a -> a)
        -> TCV.Vec n (NTensor v a o)
        -> NTensor v a o
    liftT f = NTensor . liftNested f . fmap getNVec
    {-# INLINE liftT #-}

    -- Delegates to sum', which is not defined in this chunk
    -- (presumably imported from Data.List.Util -- confirm).
    sumT = sum'
    {-# INLINE sumT #-}

    -- Multiply every element by α (α on the left of '*').
    scaleT α = overNVec (fmap (α *))
    {-# INLINE scaleT #-}

    -- Delegates to transpose', passing a length witness obtained with
    -- 'singLength' and a singleton; the result has the reversed dimension
    -- list.
    transp
        :: forall ns. (SingI ns, SingI (Reverse ns))
        => NTensor v a ns
        -> NTensor v a (Reverse ns)
    transp = overNVec (transpose' (singLength sing) sing)
    {-# INLINE transp #-}

    -- Delegates to gmul' on the unwrapped values.  The singleton for
    -- @Reverse os ++ ns@ is split by 'splitSing' into singletons for
    -- @Reverse os@ and @ns@, and both are supplied as witnesses with '\\'
    -- (presumably to discharge SingI constraints of gmul' -- confirm).
    gmul
        :: forall ms os ns. SingI (Reverse os ++ ns)
        => Length ms
        -> Length os
        -> Length ns
        -> NTensor v a (ms ++ os)
        -> NTensor v a (Reverse os ++ ns)
        -> NTensor v a (ms ++ ns)
    gmul lM lO lN = overNVec2 (gmul' lM lO lN)
                      \\ sO'
                      \\ sN
      where
        sO' :: Sing (Reverse os)
        sN  :: Sing ns
        (sO', sN) = splitSing (TCL.reverse' lO)
                              (sing :: Sing (Reverse os ++ ns))
    {-# INLINE gmul #-}

    -- Generate the result index by index.  For each index, 'TCV.uniformVec'
    -- either yields 'Nothing', giving element 0, or @Just (I i')@, giving
    -- the element of @d@ at @i'@.  NOTE(review): 'Nothing' presumably means
    -- the index components are not all equal (off-diagonal) -- confirm.
    -- 'produceEq1' supplies @Eq (IndexN k n)@ from @Eq1 (IndexN k)@.
    diag
        :: forall n ns. SingI (n ': ns)
        => Uniform n ns
        -> NTensor v a '[n]
        -> NTensor v a (n ': ns)
    diag u d
        = genNTensor sing (\i -> case TCV.uniformVec (prodToVec I (US u) i) of
                                   Nothing     -> 0
                                   Just (I i') -> indexNTensor (i' :< Ø) d
                          )
            \\ (produceEq1 :: Eq1 (IndexN k) :- Eq (IndexN k n))
    {-# INLINE diag #-}

    -- Delegates to 'diagNV'; the head of the @'[n]@ singleton is supplied
    -- as a witness with '\\'.
    -- TODO: this could be implemented directly with genNTensor.
    getDiag
        :: forall n ns. SingI '[n]
        => Uniform n ns
        -> NTensor v a (n ': n ': ns)
        -> NTensor v a '[n]
    getDiag u = overNVec (diagNV sing u)
                  \\ sHead (sing :: Sing '[n])
    {-# INLINE getDiag #-}

    -- Fill every position with a draw from distribution @d@ using generator
    -- @g@ ('genContVar'), converted to the element type with 'realToFrac'.
    -- The index of the position is ignored.
    genRand
        :: forall m d (ns :: [k]). (ContGen d, PrimMonad m, SingI ns)
        => d
        -> Gen (PrimState m)
        -> m (NTensor v a ns)
    genRand d g = generateA (\_ -> realToFrac <$> genContVar d g)
    {-# INLINE genRand #-}

    -- 'genNTensorA' with the dimension singleton taken from 'SingI'.
    generateA
        :: forall f ns. (Applicative f, SingI ns)
        => (Prod (IndexN k) ns -> f a)
        -> f (NTensor v a ns)
    generateA = genNTensorA sing
    {-# INLINE generateA #-}

    -- 'indexNTensor' with its arguments flipped: tensor first, index second.
    (!) = flip indexNTensor
    {-# INLINE (!) #-}

    -- Delegates to 'nIxRows' over the unwrapped value, with @f@ adapted to
    -- 'Nested' by 'nvecNT', then flattens the result with 'joinNested'.
    -- The second 'Length' argument is not used by this implementation.
    ixRows
        :: Applicative f
        => Length ms
        -> Length os
        -> (Prod (IndexN k) ms -> NTensor v a ns -> f (NTensor v a os))
        -> NTensor v a (ms ++ ns)
        -> f (NTensor v a (ms ++ os))
    ixRows l _ f = ntNVec $ fmap joinNested . nIxRows l (\i -> nvecNT (f i))
    {-# INLINE ixRows #-}

    -- Delegates to 'sumRowsNested'.  The @Foldable (v n)@ evidence is
    -- obtained from the @Nesting1 Proxy Foldable v@ constraint via
    -- @nesting1 Proxy@.
    sumRows
        :: forall n ns. SingI ns
        => NTensor v a (n ': ns)
        -> NTensor v a ns
    sumRows = overNVec sumRowsNested
                \\ (nesting1 Proxy :: Wit (Foldable (v n)))
    {-# INLINE sumRows #-}

    -- Adapt @f@ to 'Nested' with 'underNVec', apply it through
    -- 'mapNVecSlices' for the leading @ns@ dimensions, and flatten the
    -- result with 'joinNested'.
    mapRows
        :: Length ns
        -> (NTensor v a ms -> NTensor v a ms)
        -> NTensor v a (ns ++ ms)
        -> NTensor v a (ns ++ ms)
    mapRows l f = overNVec $ joinNested . mapNVecSlices (underNVec f) l
    {-# INLINE mapRows #-}
-- | 'NTensor' of 'Double's whose container family is built from
-- 'TCV.VecT' ("Data.Type.Vector") over 'I'.
type NTensorL = NTensor (Flip2 TCV.VecT   I) Double
-- | 'NTensor' of 'Double's whose container family is built from
-- 'VS.VectorT' ("Data.Vector.Sized") over 'I'.
type NTensorV = NTensor (Flip2 VS.VectorT I) Double