{-# LANGUAGE ApplicativeDo #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module TensorOps.Backend.BTensor
  ( BTensor
  , BTensorL, BTensorV
  , HMat
  , HMatD
  ) where

import Control.Applicative
import Control.DeepSeq
import Data.Distributive
import Data.Kind
import Data.List.Util
import Data.Monoid
import Data.Nested hiding (unScalar, unVector, gmul')
import Data.Singletons
import Data.Singletons.Prelude hiding (Reverse, Head, sReverse, (:-))
import Data.Type.Combinator
import Data.Type.Combinator.Util
import Data.Type.Length as TCL
import Data.Type.Length.Util as TCL
import Data.Type.Nat
import Data.Type.Product as TCP
import Data.Type.Product.Util as TCP
import Data.Type.Sing
import Data.Type.Uniform
import Statistics.Distribution
import TensorOps.BLAS
import TensorOps.BLAS.HMat
import TensorOps.NatKind
import TensorOps.Types
import Type.Class.Higher
import Type.Class.Higher.Util
import Type.Class.Witness
import Type.Family.List
import Type.Family.List.Util
import Type.Family.Nat
import qualified Data.Type.Vector as TCV
import qualified Data.Type.Vector.Util as TCV
import qualified Data.Vector.Sized as VS

-- | A tensor whose rank-0, rank-1 and rank-2 slices live directly in a
-- BLAS backend @b@, with higher ranks nested inside the container @v@.
--
-- * 'BTS' holds a single 'ElemB' element (rank 0).
-- * 'BTV' holds a backend vector @b ('BV n)@ (rank 1).
-- * 'BTM' holds a backend matrix @b ('BM n m)@ (rank 2).
-- * 'BTN' holds a @v n@ of sub-tensors of rank at least 2 (rank >= 3).
data BTensor :: (k -> Type -> Type) -> (BShape k -> Type) -> [k] -> Type where
    BTS :: { unScalar :: !(ElemB b) } -> BTensor v b '[]
    BTV :: { unVector :: !(b ('BV n)) } -> BTensor v b '[n]
    BTM :: { unMatrix :: !(b ('BM n m)) } -> BTensor v b '[n,m]
    BTN :: { unNested :: !(v n (BTensor v b (o ': m ': ns))) }
        -> BTensor v b (n ': o ': m ': ns)

-- | 'BTensor' nested with type-combinators' 'TCV.VecT'.
type BTensorL = BTensor (Flip2 TCV.VecT I)

-- | 'BTensor' nested with sized vectors ('VS.VectorT').
type BTensorV = BTensor (Flip2 VS.VectorT I)

instance (Nesting Proxy Show v, Show1 b, Show (ElemB b)) => Show (BTensor v b s) where
    showsPrec p = \case
        BTS x  -> showParen (p > 10) $ showString "BTS " . showsPrec 11 x
        BTV xs -> showParen (p > 10) $ showString "BTV " . showsPrec1 11 xs
        BTM xs -> showParen (p > 10) $ showString "BTM " . showsPrec1 11 xs
        BTN xs -> showParen (p > 10) $ showString "BTN " . showsPrec' 11 xs
      where
        -- The 'Show' instance for the nested container is only available
        -- through the 'Nesting' witness, so it is brought into scope here.
        showsPrec' :: forall n s'. Int -> v n (BTensor v b s') -> ShowS
        showsPrec' p' xs = showsPrec p' xs
            \\ (nesting Proxy :: Show (BTensor v b s') :- Show (v n (BTensor v b s')))

instance (Nesting Proxy Show v, Show1 b, Show (ElemB b)) => Show1 (BTensor v b)

instance (NFData (ElemB b), NFData1 b, Nesting Proxy NFData v) => Nesting1 Proxy NFData (BTensor v b) where
    nesting1 _ = Wit

instance (NFData (ElemB b), NFData1 b, Nesting Proxy NFData v) => NFData (BTensor v b js) where
    rnf = \case
        BTS x  -> rnf x
        BTV xs -> rnf1 xs
        BTM xs -> rnf1 xs
        BTN (xs :: v n (BTensor v b (o ': m ': ns))) ->
          rnf xs
            \\ (nesting Proxy :: NFData (BTensor v b (o ': m ': ns))
                              :- NFData (v n (BTensor v b (o ': m ': ns)))
               )

instance (NFData (ElemB b), NFData1 b, Nesting Proxy NFData v) => NFData1 (BTensor v b)

-- | Element-wise numeric operations.
--
-- Addition and subtraction use 'axpy' on vectors and 'gemm' against the
-- identity matrix (with a @beta@ accumulator) on matrices; multiplication,
-- 'abs' and 'signum' are element-wise; 'fromInteger' fills every element.
instance ( BLAS b
         , Vec v
         , Nesting1 Proxy Functor v
         , Nesting1 Sing Applicative v
         , SingI ns
         , Num (ElemB b)
         ) => Num (BTensor v b ns) where
    (+) = zipBase sing (+)
            (\_ xs ys -> axpy 1 xs (Just ys))
            (\(SBM _ sM) xs ys -> gemm 1 xs (eye sM) (Just (1, ys)))
    {-# INLINE (+) #-}
    (-) = zipBase sing (-)
            (\_ xs ys -> axpy (-1) ys (Just xs))
            (\(SBM _ sM) xs ys -> gemm 1 xs (eye sM) (Just (-1, ys)))
    {-# INLINE (-) #-}
    (*) = zipBTensorElems sing (*)
    {-# INLINE (*) #-}
    -- Scale by -1 in both base cases.  (The previous matrix branch,
    -- @gemm 1 xs (eye sM) Nothing@, computed @xs * I = xs@ and so left
    -- matrices un-negated.)
    negate = mapBase sing negate
               (\_ -> scaleB (-1))
               (\_ -> scaleB (-1))
    {-# INLINE negate #-}
    abs = mapBTensorElems abs
    {-# INLINE abs #-}
    signum = mapBTensorElems signum
    {-# INLINE signum #-}
    fromInteger i = genBTensor sing $ \_ -> fromInteger i
    {-# INLINE fromInteger #-}

-- | Multiply two tensors by handing the contraction to a single BLAS
-- routine.  Each of @ms@, @os@ and @ns@ must have length 0 or 1 (witnessed
-- by the @'MaxLength' 'N1'@ arguments), so both operands are at most
-- rank 2.  Each of the eight combinations maps onto one of: a plain scalar
-- product, 'axpy' with no accumulator, 'dot', 'ger', 'gemv' or 'gemm'.
--
-- TODO: add RULES pragmas so that this can be done without checking
-- lengths at runtime in the common case that the lengths are known at
-- compile-time.
--
-- Matrix-scalar multiplication is not handled here; supporting it would
-- require a lot of empty cases, so it is handled one level up, in
-- 'gmulBLAS'.
dispatchBLAS
    :: forall b ms os ns v. (RealFloat (ElemB b), BLAS b)
    => MaxLength N1 ms
    -> MaxLength N1 os
    -> MaxLength N1 ns
    -> BTensor v b (ms ++ os)
    -> BTensor v b (Reverse os ++ ns)
    -> BTensor v b (ms ++ ns)
dispatchBLAS lM lO lN v r = case (lM, lO, lN) of
    (MLZ    , MLZ    , MLZ    ) -> case (v, r) of
      -- scalar-scalar
      (BTS x, BTS y) -> BTS $ x * y
    (MLZ    , MLZ    , MLS MLZ) -> case (v, r) of
      -- scalar-vector
      (BTS x, BTV y) -> BTV $ axpy x y Nothing
    (MLZ    , MLS MLZ, MLZ    ) -> case (v, r) of
      -- dot product
      (BTV x, BTV y) -> BTS $ x `dot` y
    (MLZ    , MLS MLZ, MLS MLZ) -> case (v, r) of
      -- vector-matrix: the row-vector product x Y is computed as Y^T x
      (BTV x, BTM y) -> BTV $ gemv 1 (transpB y) x Nothing
    (MLS MLZ, MLZ    , MLZ    ) -> case (v, r) of
      -- vector-scalar
      (BTV x, BTS y) -> BTV $ axpy y x Nothing
    (MLS MLZ, MLZ    , MLS MLZ) -> case (v, r) of
      -- outer product (vector-vector -> matrix)
      (BTV x, BTV y) -> BTM $ ger x y
    (MLS MLZ, MLS MLZ, MLZ    ) -> case (v, r) of
      -- matrix-vector
      (BTM x, BTV y) -> BTV $ gemv 1 x y Nothing
    (MLS MLZ, MLS MLZ, MLS MLZ) -> case (v, r) of
      -- matrix-matrix
      (BTM x, BTM y) -> BTM $ gemm 1 x y Nothing
{-# INLINE dispatchBLAS #-}

-- | Replace every sub-tensor found under the leading indices @ns@ using a
-- pure function (the 'I'-specialised form of 'bRows').
mapRowsBTensor
    :: forall k (v :: k -> Type -> Type) ns ms os b. (Vec v, BLAS b)
    => Sing ns
    -> Length os
    -> (BTensor v b ms -> BTensor v b os)
    -> BTensor v b (ns ++ ms)
    -> BTensor v b (ns ++ os)
mapRowsBTensor sN lO f = getI . bRows sN lO (I . f)
{-# INLINE mapRowsBTensor #-}

-- | Traverse the sub-tensors under the leading indices @ns@, ignoring the
-- index ('bIxRows' without the index argument).
bRows
    :: forall k (v :: k -> Type -> Type) ns ms os b f. (Applicative f, Vec v, BLAS b)
    => Sing ns
    -> Length os
    -> (BTensor v b ms -> f (BTensor v b os))
    -> BTensor v b (ns ++ ms)
    -> f (BTensor v b (ns ++ os))
bRows sN lO f = bIxRows sN lO (\_ -> f)
{-# INLINE bRows #-}

-- | Index-aware pure map over the sub-tensors under the leading indices
-- @ns@.
mapIxRows
    :: forall k (v :: k -> Type -> Type) ns ms os b. (Vec v, BLAS b)
    => Sing ns
    -> Length os
    -> (Prod (IndexN k) ns -> BTensor v b ms -> BTensor v b os)
    -> BTensor v b (ns ++ ms)
    -> BTensor v b (ns ++ os)
mapIxRows sN lO f = getI . bIxRows sN lO (\i -> I . f i)
{-# INLINE mapIxRows #-}

-- | Index-aware monoidal fold over the sub-tensors under the leading
-- indices @ns@ (implemented as a 'Const' traversal).
foldMapIxRows
    :: forall k (v :: k -> Type -> Type) ns ms m b. (Vec v, Monoid m, BLAS b)
    => Sing ns
    -> (Prod (IndexN k) ns -> BTensor v b ms -> m)
    -> BTensor v b (ns ++ ms)
    -> m
foldMapIxRows s f = getConst . bIxRows s LZ (\i -> Const . f i)
{-# INLINE foldMapIxRows #-}

-- | Index-aware traversal of the sub-tensors found under the leading
-- indices @ns@, replacing each sub-tensor of shape @ms@ with one of shape
-- @os@.  The rank-1 and rank-2 base cases are rebuilt directly in the
-- backend (via 'iElemsB', 'iRowsB', 'bgenA', 'bgenRowsA'); nested tensors
-- recurse into each slice with 'vITraverse' and are reassembled with 'btn'.
bIxRows
    :: forall k (v :: k -> Type -> Type) ns ms os b f. (Applicative f, Vec v, BLAS b)
    => Sing ns
    -> Length os
    -> (Prod (IndexN k) ns -> BTensor v b ms -> f (BTensor v b os))
    -> BTensor v b (ns ++ ms)
    -> f (BTensor v b (ns ++ os))
bIxRows = \case
    SNil -> \_ f -> f Ø
    s `SCons` ss -> \lO f -> \case
      BTV xs -> case ss of
        -- ns ~ '[n]
        -- ms ~ '[]
        SNil -> case lO of
          -- ns ++ os ~ '[n]
          LZ -> BTV <$> iElemsB (\i -> fmap unScalar . f (pbvProd i) . BTS) xs
          -- ns ++ os ~ '[n,m]
          LS LZ -> BTM <$> bgenRowsA (\i -> unVector <$> f (i :< Ø) (BTS $ indexB (PBV i) xs))
                     \\ s
          LS (LS _) -> BTN <$> vGenA s (\i -> f (i :< Ø) (BTS $ indexB (PBV i) xs))
      BTM xs -> case ss of
        -- ns ~ '[n]
        -- ms ~ '[m]
        SNil -> case lO of
          -- ns ++ os ~ '[n]
          LZ -> BTV <$> bgenA (SBV s) (\(PBV i) -> unScalar <$> f (i :< Ø) (BTV (indexRowB i xs)))
          -- ns ++ os ~ '[n,o]
          LS LZ -> BTM <$> iRowsB (\i -> fmap unVector . f (i :< Ø) . BTV) xs
          LS (LS _) -> BTN <$> vGenA s (\i -> f (i :< Ø) (BTV (indexRowB i xs)))
        -- ns ~ '[n,m]
        -- ms ~ '[]
        s' `SCons` ss' -> (\\ s') $ case ss' of
          SNil -> case lO of
            LZ -> BTM <$> iElemsB (\i -> fmap unScalar . f (pbmProd i) . BTS) xs
            LS _ -> BTN <$> vGenA s (\i ->
                      btn lO <$> vGenA s' (\j ->
                        f (i :< j :< Ø) (BTS (indexB (PBM i j) xs))))
      BTN xs -> (\\ s) $ fmap (btn (singLength ss `TCL.append'` lO))
                . vITraverse (\i -> bIxRows ss lO (\is -> f (i :< is)))
                $ xs

-- | Extract the sub-tensor located at the leading index @ns@.
indexRowBTensor
    :: forall k (b :: BShape k -> Type) v ns ms. (BLAS b, Vec v)
    => Prod (IndexN k) ns
    -> BTensor v b (ns ++ ms)
    -> BTensor v b ms
indexRowBTensor = \case
    Ø -> id
    i :< is -> \case
      BTV xs -> case is of
        Ø -> BTS $ indexB (PBV i) xs
      BTM xs -> case is of
        Ø -> BTV $ indexRowB i xs
        j :< Ø -> BTS $ indexB (PBM i j) xs
      BTN xs -> indexRowBTensor is (vIndex i xs)
{-# INLINE indexRowBTensor #-}

-- | Apply a pure function to every element of the tensor.
mapBTensorElems
    :: (Vec v, BLAS b)
    => (ElemB b -> ElemB b)
    -> BTensor v b ns
    -> BTensor v b ns
mapBTensorElems f = getI . bTensorElems (I . f)
{-# INLINE mapBTensorElems #-}

-- | Traverse every element of the tensor, keeping its shape.
bTensorElems
    :: forall k (v :: k -> Type -> Type) ns b f. (Applicative f, Vec v, BLAS b)
    => (ElemB b -> f (ElemB b))
    -> BTensor v b ns
    -> f (BTensor v b ns)
bTensorElems f = \case
    BTS x  -> BTS <$> f x
    BTV xs -> BTV <$> elemsB f xs
    BTM xs -> BTM <$> elemsB f xs
    BTN xs -> BTN <$> vITraverse (\_ x -> bTensorElems f x) xs
{-# INLINE bTensorElems #-}

-- | Index-aware monoidal fold over every element of the tensor.
ifoldMapBTensor
    :: forall k (v :: k -> Type -> Type) ns m b. (Monoid m, Vec v, BLAS b)
    => (Prod (IndexN k) ns -> ElemB b -> m)
    -> BTensor v b ns
    -> m
ifoldMapBTensor f = getConst . bTensorIxElems (\i -> Const . f i)
{-# INLINE ifoldMapBTensor #-}

-- | Index-aware traversal of every element of the tensor.
bTensorIxElems
    :: forall k (v :: k -> Type -> Type) ns b f. (Applicative f, Vec v, BLAS b)
    => (Prod (IndexN k) ns -> ElemB b -> f (ElemB b))
    -> BTensor v b ns
    -> f (BTensor v b ns)
bTensorIxElems f = \case
    BTS x  -> BTS <$> f Ø x
    BTV xs -> BTV <$> iElemsB (f . pbvProd) xs
    BTM xs -> BTM <$> iElemsB (f . pbmProd) xs
    BTN xs -> BTN <$> vITraverse (\i -> bTensorIxElems (\is -> f (i :< is))) xs
{-# INLINE bTensorIxElems #-}

-- | Combine two tensors of the same shape element by element.
zipBTensorElems
    :: forall v b ns. (BLAS b, Nesting1 Sing Applicative v)
    => Sing ns
    -> (ElemB b -> ElemB b -> ElemB b)
    -> BTensor v b ns
    -> BTensor v b ns
    -> BTensor v b ns
zipBTensorElems = \case
    SNil -> \f -> \case
      BTS x -> \case
        BTS y -> BTS (f x y)
    sN `SCons` SNil -> \f -> \case
      BTV xs -> \case
        BTV ys -> BTV (zipB (SBV sN) f xs ys)
    sN `SCons` (sM `SCons` SNil) -> \f -> \case
      BTM xs -> \case
        BTM ys -> BTM (zipB (SBM sN sM) f xs ys)
    (s :: Sing k) `SCons` ss@(_ `SCons` (_ `SCons` _)) -> \f -> \case
      BTN xs -> \case
        BTN ys -> BTN (zipBTensorElems ss f <$> xs <*> ys)
                    \\ (nesting1 s :: Wit (Applicative (v k)))
{-# INLINE zipBTensorElems #-}

-- | Combine a length-@n@ vector of same-shaped tensors element by element
-- with a function of @n@ elements.
liftBTensor
    :: forall v b ns n.
       ( BLAS b
       , Nesting1 Proxy Functor v
       , Nesting1 Sing Distributive v
       )
    => Sing ns
    -> (TCV.Vec n (ElemB b) -> ElemB b)
    -> TCV.Vec n (BTensor v b ns)
    -> BTensor v b ns
liftBTensor = \case
    SNil -> \f xs ->
      let xs' = unScalar <$> xs
      in  BTS $ f xs'
    sN `SCons` SNil -> \f xs ->
      let xs' = unVector <$> xs
      in  BTV $ liftB (SBV sN) f xs'
    sN `SCons` (sM `SCons` SNil) -> \f xs ->
      let xs' = unMatrix <$> xs
      in  BTM $ liftB (SBM sN sM) f xs'
    (s :: Sing k) `SCons` ss@(_ `SCons` (_ `SCons` _)) -> \f xs ->
      let xs' = unNested <$> xs
      in  BTN $ TCV.liftVecD (liftBTensor ss f) xs'
            \\ (nesting1 s :: Wit (Distributive (v k)))
{-# INLINE liftBTensor #-}

-- | Replace every trailing matrix (the last two indices) with a tensor of
-- shape @ms@, using a pure function.
mapBTM
    :: forall k (v :: k -> Type -> Type) ns n m ms b. (Vec v, BLAS b)
    => Sing ns
    -> Length ms
    -> (b ('BM n m) -> BTensor v b ms)
    -> BTensor v b (ns ++ [n,m])
    -> BTensor v b (ns ++ ms)
mapBTM sN lM f = getI . traverseBTM sN lM (I . f)
{-# INLINE mapBTM #-}

-- | Monoidal fold over every trailing matrix (the last two indices).
foldMapBTM
    :: (Monoid a, Vec v, BLAS b)
    => Length ns
    -> (b ('BM n m) -> a)
    -> BTensor v b (ns ++ [n,m])
    -> a
foldMapBTM l f = ifoldMapBTM l (\_ -> f)
{-# INLINE foldMapBTM #-}

-- | Traverse every trailing matrix (the last two indices), replacing each
-- with a tensor of shape @ms@.
traverseBTM
    :: forall k (v :: k -> Type -> Type) ns n m ms b f. (Applicative f, Vec v, BLAS b)
    => Sing ns
    -> Length ms
    -> (b ('BM n m) -> f (BTensor v b ms))
    -> BTensor v b (ns ++ [n,m])
    -> f (BTensor v b (ns ++ ms))
traverseBTM = \case
    SNil -> \_ f -> \case
      BTM x -> f x
    s `SCons` ss -> \lM f -> \case
      -- A non-empty prefix plus two matrix indices has rank >= 3, so the
      -- BTV and BTM cases cannot occur.
      BTV _ -> case ss of {}
      BTM _ -> case ss of {}
      BTN xs -> (\\ s) $ fmap (btn (singLength ss `TCL.append'` lM))
                . vITraverse (\_ -> traverseBTM ss lM f)
                $ xs
{-# INLINE traverseBTM #-}

-- | Index-aware pure map over every trailing matrix.
imapBTM
    :: forall k (v :: k -> Type -> Type) ns n m ms b. (Vec v, BLAS b)
    => Sing ns
    -> Length ms
    -> (Prod (IndexN k) ns -> b ('BM n m) -> BTensor v b ms)
    -> BTensor v b (ns ++ [n,m])
    -> BTensor v b (ns ++ ms)
imapBTM sN lM f = getI . itraverseBTM sN lM (\i -> I . f i)
{-# INLINE imapBTM #-}

-- | Index-aware monoidal fold over every trailing matrix.
ifoldMapBTM
    :: (Vec v, Monoid a, BLAS b)
    => Length ns
    -> (Prod (IndexN k) ns -> b ('BM n m) -> a)
    -> BTensor v b (ns ++ [n,m])
    -> a
ifoldMapBTM = \case
    LZ -> \f -> \case
      BTM xs -> f Ø xs
    LS l -> \f -> \case
      -- Unreachable: a non-empty prefix makes the tensor rank >= 3.
      BTV _ -> case l of {}
      BTM _ -> case l of {}
      BTN xs -> vIFoldMap (\i -> ifoldMapBTM l (\is -> f (i :< is))) xs
{-# INLINE ifoldMapBTM #-}

-- | Index-aware traversal of every trailing matrix, replacing each with a
-- tensor of shape @ms@.
itraverseBTM
    :: forall k (v :: k -> Type -> Type) ns n m ms b f. (Applicative f, Vec v, BLAS b)
    => Sing ns
    -> Length ms
    -> (Prod (IndexN k) ns -> b ('BM n m) -> f (BTensor v b ms))
    -> BTensor v b (ns ++ [n,m])
    -> f (BTensor v b (ns ++ ms))
itraverseBTM = \case
    SNil -> \_ f -> \case
      BTM x -> f Ø x
    s `SCons` ss -> \lM f -> \case
      -- Unreachable: a non-empty prefix makes the tensor rank >= 3.
      BTV _ -> case ss of {}
      BTM _ -> case ss of {}
      BTN xs -> (\\ s) $ fmap (btn (singLength ss `TCL.append'` lM))
                . vITraverse (\i -> itraverseBTM ss lM (\is ys -> f (i :< is) ys))
                $ xs
{-# INLINE itraverseBTM #-}

-- | Map over a tensor with a separate function for each base case.  The
-- scalar function is used for rank 0, the vector function for rank 1,
-- and the matrix function for rank 2.  Higher ranks are mapped
-- recursively over the nested container.
mapBase
    :: forall v b ns. (Nesting1 Proxy Functor v)
    => Sing ns
    -> (ElemB b -> ElemB b)
    -> (forall n. Sing n -> b ('BV n) -> b ('BV n))
    -> (forall n m. Sing ('BM n m) -> b ('BM n m) -> b ('BM n m))
    -> BTensor v b ns
    -> BTensor v b ns
mapBase = \case
    SNil -> \f _ _ -> \case
      BTS x -> BTS (f x)
    sN `SCons` SNil -> \_ g _ -> \case
      BTV xs -> BTV (g sN xs)
    sN `SCons` (sM `SCons` SNil) -> \_ _ h -> \case
      BTM xs -> BTM (h (SBM sN sM) xs)
    (_ :: Sing k) `SCons` ss@(_ `SCons` (_ `SCons` _)) -> \f g h -> \case
      BTN xs -> BTN (mapBase ss f g h <$> xs)
                  \\ (nesting1 Proxy :: Wit (Functor (v k)))
{-# INLINE mapBase #-}

-- | Zip two same-shaped tensors with a separate function for each base
-- case (scalar, vector, matrix).  Higher ranks are zipped recursively via
-- the nested container's 'Applicative'.
zipBase
    :: forall v b ns. (Nesting1 Sing Applicative v)
    => Sing ns
    -> (ElemB b -> ElemB b -> ElemB b)
    -> (forall n. Sing n -> b ('BV n) -> b ('BV n) -> b ('BV n))
    -> (forall n m. Sing ('BM n m) -> b ('BM n m) -> b ('BM n m) -> b ('BM n m))
    -> BTensor v b ns
    -> BTensor v b ns
    -> BTensor v b ns
zipBase = \case
    SNil -> \f _ _ -> \case
      BTS x -> \case
        BTS y -> BTS (f x y)
    sN `SCons` SNil -> \_ g _ -> \case
      BTV xs -> \case
        BTV ys -> BTV (g sN xs ys)
    sN `SCons` (sM `SCons` SNil) -> \_ _ h -> \case
      BTM xs -> \case
        BTM ys -> BTM (h (SBM sN sM) xs ys)
    (s :: Sing k) `SCons` ss@(_ `SCons` (_ `SCons` _)) -> \f g h -> \case
      BTN xs -> \case
        BTN ys -> BTN $ zipBase ss f g h <$> xs <*> ys
                    \\ (nesting1 s :: Wit (Applicative (v k)))
{-# INLINE zipBase #-}

-- | Build a tensor of shape @ns@ by running an effectful function at every
-- index.
genBTensorA
    :: forall k (b :: BShape k -> Type) v (ns :: [k]) f. (Applicative f, BLAS b, Vec v)
    => Sing ns
    -> (Prod (IndexN k) ns -> f (ElemB b))
    -> f (BTensor v b ns)
genBTensorA = \case
    SNil -> \f -> BTS <$> f Ø
    sN `SCons` SNil -> \f -> BTV <$> bgenA (SBV sN) (f . pbvProd)
    sN `SCons` (sM `SCons` SNil) -> \f -> BTM <$> bgenA (SBM sN sM) (f . pbmProd)
    s `SCons` ss@(_ `SCons` (_ `SCons` _)) -> \f ->
      BTN <$> vGenA s (\i -> genBTensorA ss (\is -> f (i :< is)))
{-# INLINE genBTensorA #-}

-- | Build a tensor of shape @ns@ from a pure function of the index.
genBTensor
    :: forall k (b :: BShape k -> Type) v (ns :: [k]). (BLAS b, Vec v)
    => Sing ns
    -> (Prod (IndexN k) ns -> ElemB b)
    -> BTensor v b ns
genBTensor s f = getI $ genBTensorA s (I . f)
{-# INLINE genBTensor #-}

-- | Read the element at a full index.
indexBTensor
    :: forall k (b :: BShape k -> Type) v ns. (BLAS b, Vec v)
    => Prod (IndexN k) ns
    -> BTensor v b ns
    -> ElemB b
indexBTensor = \case
    Ø -> \case
      BTS x -> x
    i :< Ø -> \case
      BTV xs -> indexB (PBV i) xs
    i :< j :< Ø -> \case
      BTM xs -> indexB (PBM i j) xs
    i :< js@(_ :< _ :< _) -> \case
      BTN xs -> indexBTensor js (vIndex i xs)
{-# INLINE indexBTensor #-}

-- | Assemble a container of rank-@ns@ slices into a single tensor of one
-- higher rank.  Scalars are repacked into a backend vector and vectors
-- into a backend matrix.  Higher ranks are wrapped in 'BTN' as-is.
btn :: (BLAS b, Vec v, SingI n)
    => Length ns
    -> v n (BTensor v b ns)
    -> BTensor v b (n ': ns)
btn = \case
    LZ        -> \xs -> BTV $ bgen sing (unScalar . (`vIndex` xs) . unPBV)
    LS LZ     -> \xs -> BTM $ bgenRows (unVector . (`vIndex` xs))
    LS (LS _) -> BTN
{-# INLINE btn #-}

-- | Generalised tensor contraction.  @os@ is contracted against the
-- reversed leading indices of the second tensor.  It splits the result
-- shape singleton into @ms@ and @ns@ and defers to 'gmulB'.  The RULES
-- below rewrite statically-known low-rank cases straight to the
-- specialised @dispatch*@ functions.
gmul'
    :: forall v b ms os ns.
       ( SingI (ms ++ ns)
       , RealFloat (ElemB b)
       , Vec v
       , Nesting1 Proxy Functor v
       , Nesting1 Sing Applicative v
       , BLAS b
       )
    => Length ms
    -> Length os
    -> Length ns
    -> BTensor v b (ms ++ os)
    -> BTensor v b (Reverse os ++ ns)
    -> BTensor v b (ms ++ ns)
gmul' lM lO lN = gmulB sM lO lN \\ sN
  where
    sM :: Sing ms
    sN :: Sing ns
    (sM, sN) = splitSing lM sing
{-# INLINE[0] gmul' #-}

-- Rules inside a RULES pragma are separated by layout, so each one must
-- start in the same column as top-level definitions.
{-# RULES
"gmul'/SS"  gmul' = dispatchSS
"gmul'/SV"  gmul' = dispatchSV
"gmul'/dot" gmul' = dispatchDot
"gmul'/VM"  gmul' = dispatchVM
"gmul'/VS"  gmul' = dispatchVS
"gmul'/out" gmul' = dispatchOut
"gmul'/MV"  gmul' = dispatchMV
"gmul'/MM"  gmul' = dispatchMM
  #-}

-- | General strategy:
--
-- * We can only outsource to BLAS (using 'gmulBLAS' and 'dispatchBLAS')
--   when @os@ and @ns@ each have length 0 or 1.  The one extra special
--   case is when @os@ has length 2 and @ns@ is empty.  That is handled as
--   the trace of a matrix product for every trailing matrix.  Anything
--   else falls back to the basic reverse-indexing method in 'naiveGMul'.
-- * If @ms@ is length 2 or higher, "traverse down" to the length 0 or
--   1 tail...and then sum them up.
gmulB
    :: forall k (b :: BShape k -> Type) v ms os ns.
       ( RealFloat (ElemB b)
       , SingI ns
       , BLAS b
       , Vec v
       , Nesting1 Proxy Functor v
       , Nesting1 Sing Applicative v
       )
    => Sing ms
    -> Length os
    -> Length ns
    -> BTensor v b (ms ++ os)
    -> BTensor v b (Reverse os ++ ns)
    -> BTensor v b (ms ++ ns)
gmulB sM lO lN v r = case splitting (S_ Z_) (lengthProd lN) of
    Fewer mlN _ -> case splittingEnd (S_ (S_ Z_)) (lengthProd lO) of
      FewerEnd MLZ _             -> gmulBLAS sM MLZ mlN v r
      FewerEnd (MLS MLZ) _       -> gmulBLAS sM (MLS MLZ) mlN v r
      FewerEnd (MLS (MLS MLZ)) _ -> case mlN of
        -- sum_{i,j} x_{..ij} y_{ji} = trace (x y) for each trailing matrix x
        MLZ -> case r of
          BTM ys -> mapBTM sM LZ (\xs -> BTS $ traceB (gemm 1 xs ys Nothing)) v
        MLS MLZ -> naiveGMul sM lO lN v r
      SplitEnd _ _ _ -> naiveGMul sM lO lN v r
    Split _ _ _ -> naiveGMul sM lO lN v r
{-# INLINE[0] gmulB #-}

-- | Naive implementation of 'gmul' (based on the implementation for
-- 'NTensor') that does not utilize any BLAS capabilities.
--
-- For every leading index of @ms@, each element @x@ at contracted index
-- @is@ scales the row of the second tensor at @reverse is@.  The scaled
-- rows are then summed with the tensor 'Num' instance.
naiveGMul
    :: forall k (b :: BShape k -> Type) v ms os ns.
       ( BLAS b
       , Vec v
       , Num (ElemB b)
       , Nesting1 Proxy Functor v
       , Nesting1 Sing Applicative v
       , SingI ns
       )
    => Sing ms
    -> Length os
    -> Length ns
    -> BTensor v b (ms ++ os)
    -> BTensor v b (Reverse os ++ ns)
    -> BTensor v b (ms ++ ns)
naiveGMul sM _ lN v r =
    mapRowsBTensor sM lN (getSum . ifoldMapBTensor (\i -> Sum . f i)) v
  where
    f :: Prod (IndexN k) os -> ElemB b -> BTensor v b ns
    f is x = mapBase sing (x *) (\_ ys -> scaleB x ys) (\_ ys -> scaleB x ys)
               (indexRowBTensor (TCP.reverse' is) r)

-- | A 'gmul' that dispatches to BLAS routines when it can.
-- Contains the type-level constraint that @os@ and @ns@ have to have
-- either length 0 or 1.
--
-- When @ms@ has length 0 or 1 (or 2 with an empty @os@ and @ns@), the
-- work goes to 'dispatchBLAS' or a single 'scaleB'.  When @ms@ is longer,
-- its trailing one or two indices are split off and handled per trailing
-- matrix with 'mapBTM'.  @sM@ is still needed both for that split and for
-- the 'naiveGMul' fallback.
gmulBLAS
    :: forall b ms os ns v.
       ( RealFloat (ElemB b)
       , BLAS b
       , Vec v
       , SingI ns
       , Nesting1 Proxy Functor v
       , Nesting1 Sing Applicative v
       )
    => Sing ms
    -> MaxLength N1 os
    -> MaxLength N1 ns
    -> BTensor v b (ms ++ os)
    -> BTensor v b (Reverse os ++ ns)
    -> BTensor v b (ms ++ ns)
gmulBLAS sM mlO mlN v r = case mlO of
    MLZ -> case splittingEnd (S_ (S_ Z_)) spM of
      FewerEnd MLZ _             -> dispatchBLAS MLZ mlO mlN v r
      FewerEnd (MLS MLZ) _       -> dispatchBLAS (MLS MLZ) mlO mlN v r
      FewerEnd (MLS (MLS MLZ)) _ -> case v of
        BTM xs -> case mlN of
          -- matrix-scalar
          MLZ -> case r of
            BTS y -> BTM $ scaleB y xs
          -- TODO: can this be made non-naive?
          -- ms ~ '[m1,m2]
          -- os ~ '[]
          -- ns ~ '[n]
          MLS MLZ -> naiveGMul sM LZ (fromMaxLength mlN) v r
      SplitEnd (ELS (ELS ELZ)) spM0 spM1 -> case mlN of
        -- scale every trailing matrix
        MLZ -> case r of
          BTS y -> mapBTM (prodSing spM0) (prodLength spM1) (\xs -> BTM $ scaleB y xs) v
                     \\ appendNil lM
        -- TODO: can this be made non-naive?
        -- ms ~ (ms0 ++ '[m1,m2])
        -- os ~ '[]
        -- ns ~ '[n]
        MLS MLZ -> naiveGMul sM LZ (fromMaxLength mlN) v r
    MLS MLZ -> case splittingEnd (S_ Z_) spM of
      FewerEnd mlM _ -> dispatchBLAS mlM mlO mlN v r
      SplitEnd (ELS ELZ) spM0 spM1 ->
        let sM0 = prodSing spM0
            lM0 = prodLength spM0
            lM1 = prodLength spM1
        in  (\\ appendAssoc (TCL.tail' lM0) lM1 (LS LZ :: Length os)) $ case mlN of
              -- matrix-vector on every trailing matrix
              MLZ -> case r of
                BTV ys -> mapBTM sM0 lM1 (\xs -> BTV $ gemv 1 xs ys Nothing) v
                            \\ appendNil lM
              -- matrix-matrix on every trailing matrix
              MLS MLZ -> case r of
                BTM ys -> mapBTM sM0 (lM1 `TCL.append'` (LS LZ :: Length ns))
                                 (\xs -> BTM $ gemm 1 xs ys Nothing) v
                            \\ appendAssoc (TCL.tail' lM0) lM1 (LS LZ :: Length ns)
  where
    spM = singProd sM
    lM = singLength sM

-- | Embed a vector as the diagonal of a tensor whose indices all share
-- the same extent (witnessed by 'Uniform').  Off-diagonal entries are 0.
-- The rank-2 case uses the backend's 'diagB' directly.
diagBTensor
    :: forall k (b :: BShape k -> Type) v n ns.
       ( SingI (n ': ns)
       , BLAS b
       , Vec v
       , Num (ElemB b)
       , Eq (IndexN k n)
       )
    => Uniform n ns
    -> BTensor v b '[n]
    -> BTensor v b (n ': ns)
diagBTensor = \case
    UØ -> id
    US UØ -> \case
      BTV xs -> BTM $ diagB xs
    u@(US (US _)) -> \(BTV xs) -> genBTensor sing $ \i ->
      -- 'TCV.uniformVec' yields 'Just' only when all indices are equal
      case TCV.uniformVec (prodToVec I (US u) i) of
        Nothing     -> 0
        Just (I i') -> indexB (PBV i') xs
{-# INLINE diagBTensor #-}

-- | Reverse the order of all indices.  Scalars and vectors are unchanged.
-- Matrices use 'transpB', and higher ranks are regenerated by indexing
-- with the reversed index.
transpBTensor
    :: (BLAS b, Vec v)
    => Sing ns
    -> BTensor v b ns
    -> BTensor v b (Reverse ns)
transpBTensor s = \case
    BTS x -> BTS x
    BTV xs -> BTV xs
    BTM xs -> BTM $ transpB xs
    xs@(BTN _) -> (\\ reverseReverse (singLength s)) $
      genBTensor (sReverse s) $ \i ->
        indexBTensor (TCP.reverse' i) xs
{-# INLINE transpBTensor #-}

-- | Sum over the leading index.  Vectors use 'sumB'.  Matrices compute
-- @X^T 1@ with 'gemv' (the column sums).  Higher ranks add the slices
-- with the tensor 'Num' instance.
sumBTensor
    :: forall v b n ns.
       ( BLAS b
       , Vec v
       , Num (ElemB b)
       , Foldable (v n)
       , SingI ns
       , SingI n
       , Nesting1 Proxy Functor v
       , Nesting1 Sing Applicative v
       )
    => BTensor v b (n ': ns)
    -> BTensor v b ns
sumBTensor = \case
    BTV xs -> BTS $ sumB xs
    BTM (xs :: b ('BM n m)) ->
      BTV $ gemv 1 (transpB xs) (bgen (SBV (sing :: Sing n)) (\_ -> 1)) Nothing
    BTN xs -> sum xs

instance ( Vec (v :: k -> Type -> Type)
         , BLAS b
         , RealFloat (ElemB b)
         , Nesting1 Proxy Functor v
         , Nesting1 Proxy Foldable v
         , Nesting1 Sing Applicative v
         , Nesting1 Sing Distributive v
         , Eq1 (IndexN k)
         ) => Tensor (BTensor v b) where
    type ElemT (BTensor v b) = ElemB b

    liftT
        :: SingI ns
        => (TCV.Vec n (ElemB b) -> ElemB b)
        -> TCV.Vec n (BTensor v b ns)
        -> BTensor v b ns
    liftT = liftBTensor sing
    {-# INLINE liftT #-}

    sumT = sum'
    {-# INLINE sumT #-}

    scaleT α = mapBase sing (α*) (\_ -> scaleB α) (\_ -> scaleB α)
    {-# INLINE scaleT #-}

    gmul
        :: forall ms os ns. SingI (ms ++ ns)
        => Length ms
        -> Length os
        -> Length ns
        -> BTensor v b (ms ++ os)
        -> BTensor v b (Reverse os ++ ns)
        -> BTensor v b (ms ++ ns)
    gmul = gmul'
    {-# INLINE gmul #-}

    diag
        :: forall n ns. SingI (n ': ns)
        => Uniform n ns
        -> BTensor v b '[n]
        -> BTensor v b (n ': ns)
    diag = diagBTensor
        \\ (produceEq1 :: Eq1 (IndexN k) :- Eq (IndexN k n))
    {-# INLINE diag #-}

    getDiag
        :: SingI n
        => Uniform n ns
        -> BTensor v b (n ': n ': ns)
        -> BTensor v b '[n]
    getDiag = \case
      UØ -> \case
        BTM xs -> BTV $ getDiagB xs
      u@(US _) -> \xs -> genBTensor sing $ \(i :< Ø) ->
        indexBTensor (TCP.replicate i (US (US u))) xs
    {-# INLINE getDiag #-}

    transp = transpBTensor sing
    {-# INLINE transp #-}

    generateA = genBTensorA sing
    {-# INLINE generateA #-}

    genRand d g = generateA (\_ -> realToFrac <$> genContVar d g)
    {-# INLINE genRand #-}

    ixRows
        :: forall f ms os ns. (Applicative f, SingI (ms ++ os))
        => Length ms
        -> Length os
        -> (Prod (IndexN k) ms -> BTensor v b ns -> f (BTensor v b os))
        -> BTensor v b (ms ++ ns)
        -> f (BTensor v b (ms ++ os))
    ixRows lM lO = bIxRows sM lO
      where
        sM :: Sing ms
        sM = takeSing lM lO (sing :: Sing (ms ++ os))
    {-# INLINE ixRows #-}

    (!) = flip indexBTensor
    {-# INLINE (!) #-}

    sumRows
        :: forall n ns. (SingI (n ': ns), SingI ns)
        => BTensor v b (n ': ns)
        -> BTensor v b ns
    sumRows = sumBTensor
        \\ (nesting1 Proxy :: Wit (Foldable (v n)))
        \\ sHead (sing :: Sing (n ': ns))
    {-# INLINE sumRows #-}

    mapRows
        :: forall ns ms. SingI (ns ++ ms)
        => Length ns
        -> (BTensor v b ms -> BTensor v b ms)
        -> BTensor v b (ns ++ ms)
        -> BTensor v b (ns ++ ms)
    mapRows l f = mapRowsBTensor sN (singLength sM) f
      where
        sN :: Sing ns
        sM :: Sing ms
        (sN, sM) = splitSing l (sing :: Sing (ns ++ ms))
    {-# INLINE mapRows #-}

-- * Boring dispatches
--
-- Monomorphic-shape specialisations of 'gmul'', targeted by the RULES
-- pragma above.  Each performs exactly the same BLAS call as the matching
-- case of 'dispatchBLAS'.

-- | scalar * scalar
dispatchSS
    :: Num (ElemB b)
    => Length '[]
    -> Length '[]
    -> Length '[]
    -> BTensor v b '[]
    -> BTensor v b '[]
    -> BTensor v b '[]
dispatchSS _ _ _ (BTS x) (BTS y) = BTS (x * y)
{-# INLINE dispatchSS #-}

-- | scalar * vector
dispatchSV
    :: BLAS b
    => Length '[]
    -> Length '[]
    -> Length '[n]
    -> BTensor v b '[]
    -> BTensor v b '[n]
    -> BTensor v b '[n]
dispatchSV _ _ _ (BTS x) (BTV y) = BTV $ axpy x y Nothing
{-# INLINE dispatchSV #-}

-- | vector . vector (dot product)
dispatchDot
    :: BLAS b
    => Length '[]
    -> Length '[n]
    -> Length '[]
    -> BTensor v b '[n]
    -> BTensor v b '[n]
    -> BTensor v b '[]
dispatchDot _ _ _ (BTV x) (BTV y) = BTS $ x `dot` y
{-# INLINE dispatchDot #-}

-- | vector * matrix, computed as @Y^T x@
dispatchVM
    :: (Num (ElemB b), BLAS b)
    => Length '[]
    -> Length '[n]
    -> Length '[m]
    -> BTensor v b '[n]
    -> BTensor v b '[n,m]
    -> BTensor v b '[m]
dispatchVM _ _ _ (BTV x) (BTM y) = BTV $ gemv 1 (transpB y) x Nothing
{-# INLINE dispatchVM #-}

-- | vector * scalar
dispatchVS
    :: BLAS b
    => Length '[n]
    -> Length '[]
    -> Length '[]
    -> BTensor v b '[n]
    -> BTensor v b '[]
    -> BTensor v b '[n]
dispatchVS _ _ _ (BTV x) (BTS y) = BTV $ axpy y x Nothing
{-# INLINE dispatchVS #-}

-- | outer product of two vectors
dispatchOut
    :: BLAS b
    => Length '[n]
    -> Length '[]
    -> Length '[m]
    -> BTensor v b '[n]
    -> BTensor v b '[m]
    -> BTensor v b '[n,m]
dispatchOut _ _ _ (BTV x) (BTV y) = BTM $ ger x y
{-# INLINE dispatchOut #-}

-- | matrix * vector
dispatchMV
    :: (Num (ElemB b), BLAS b)
    => Length '[n]
    -> Length '[m]
    -> Length '[]
    -> BTensor v b '[n,m]
    -> BTensor v b '[m]
    -> BTensor v b '[n]
dispatchMV _ _ _ (BTM x) (BTV y) = BTV $ gemv 1 x y Nothing
{-# INLINE dispatchMV #-}

-- | matrix * matrix
dispatchMM
    :: (Num (ElemB b), BLAS b)
    => Length '[m]
    -> Length '[o]
    -> Length '[n]
    -> BTensor v b '[m,o]
    -> BTensor v b '[o,n]
    -> BTensor v b '[m,n]
dispatchMM _ _ _ (BTM x) (BTM y) = BTM $ gemm 1 x y Nothing
{-# INLINE dispatchMM #-}