{-# LANGUAGE DataKinds              #-}
{-# LANGUAGE DeriveFunctor          #-}
{-# LANGUAGE EmptyCase              #-}
{-# LANGUAGE FlexibleContexts       #-}
{-# LANGUAGE GADTs                  #-}
{-# LANGUAGE InstanceSigs           #-}
{-# LANGUAGE KindSignatures         #-}
{-# LANGUAGE LambdaCase             #-}
{-# LANGUAGE PolyKinds              #-}
{-# LANGUAGE RankNTypes             #-}
{-# LANGUAGE ScopedTypeVariables    #-}
{-# LANGUAGE TemplateHaskell        #-}
{-# LANGUAGE TypeApplications       #-}
{-# LANGUAGE TypeFamilies           #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeInType             #-}
{-# LANGUAGE TypeOperators          #-}
{-# LANGUAGE UndecidableInstances   #-}

-- | Vector/matrix shapes ('BShape'), per-dimension products over a shape
-- ('BShapeP'), and the 'BLAS' class: an abstract interface for backends
-- offering BLAS-style operations ('axpy', 'dot', 'ger', 'gemv', 'gemm')
-- together with a few generic helpers defined on top of it.
module TensorOps.BLAS
  ( BShape(..)
  , BShapeDims, bShapeDims
  , BShapeP(..)
  , bShapeProd, pbvProd, pbmProd
  , BLAS(..)
  , Sing(..)
  , SBShape
  , elemsB
  , zipB
  , bgen
  , bgenRows
  ) where

import           Data.Kind
import           Data.Singletons
import           Data.Singletons.Prelude hiding (Reverse, Head, sReverse, (:-))
import           Data.Singletons.TH
import           Data.Type.Combinator
import           Data.Type.Product       as TCP
import           Data.Type.Sing
import           Data.Type.Vector        (VecT(..), Vec)
import           TensorOps.NatKind

-- The shape of a BLAS operand, generated together with its singletons
-- (exported as 'SBShape').  @'BV' n@ is a vector of length @n@;
-- @'BM' n m@ is a matrix with @n@ rows and @m@ columns (compare the
-- argument and result shapes of 'gemv').
$(singletons [d|
  data BShape a = BV !a | BM !a !a
    deriving (Show, Eq, Ord, Functor)
  |])

-- | The type-level list of dimensions of a shape.  Declared injective,
-- so the shape can be recovered from its dimension list.
type family BShapeDims (s :: BShape k) = (ks :: [k]) | ks -> s where
    BShapeDims ('BV x  ) = '[x]
    BShapeDims ('BM x y) = '[x,y]

-- Defunctionalization symbols for 'BShapeDims', so it can be passed to
-- singletons' type-level higher-order functions.
genDefunSymbols [''BShapeDims]

-- | One @f@-value per dimension of a shape: a single one for a vector,
-- two (first dimension, then second) for a matrix.
data BShapeP :: (k -> Type) -> BShape k -> Type where
    PBV :: { unPBV  :: !(f a)
           } -> BShapeP f ('BV a)
    PBM :: { unPBMn :: !(f a)
           , unPBMm :: !(f b)
           } -> BShapeP f ('BM a b)

-- | Flatten a 'BShapeP' into a 'Prod' over the shape's dimension list.
--
-- Inlining is held back until phase 0 so that the rewrite rules below get
-- a chance to replace it with 'pbvProd' / 'pbmProd' wherever the shape is
-- statically known.
bShapeProd :: BShapeP f s -> Prod f (BShapeDims s)
bShapeProd = \case
    PBV v   -> v :< Ø
    PBM r c -> r :< c :< Ø
{-# INLINE[0] bShapeProd #-}

{-# RULES
"bShapeProd/BV" bShapeProd = pbvProd
"bShapeProd/BM" bShapeProd = pbmProd
  #-}

-- | 'bShapeProd' specialised to vector shapes.
pbvProd :: BShapeP f ('BV a) -> Prod f '[a]
pbvProd (PBV v) = v :< Ø
{-# INLINE pbvProd #-}

-- | 'bShapeProd' specialised to matrix shapes.
pbmProd :: BShapeP f ('BM a b) -> Prod f '[a,b]
pbmProd (PBM r c) = r :< c :< Ø
{-# INLINE pbmProd #-}

-- | Term-level counterpart of 'BShapeDims': the dimensions of a shape,
-- in order.
bShapeDims :: BShape a -> [a]
bShapeDims = \case
    BV n   -> [n]
    BM n m -> [n, m]

-- | A BLAS-style backend.  A value of type @b s@ holds elements of type
-- @'ElemB' b@ arranged according to the shape @s@ (see 'indexB').
class NatKind k => BLAS (b :: BShape k -> Type) where
    -- | Element (scalar) type of the backend.
    type ElemB b :: Type

    -- | Combine @n@ same-shaped values using a function of @n@ elements.
    -- NOTE(review): presumably applied at every position of the shape
    -- (this is how 'zipB' uses it); confirm against the instances.
    liftB
        :: Sing s
        -> (Vec n (ElemB b) -> ElemB b)
        -> Vec n (b s)
        -> b s

    -- | Scaled vector addition, @α x + y@, where @y@ may be omitted.
    axpy
        :: ElemB b              -- ^ α
        -> b ('BV n)            -- ^ x
        -> Maybe (b ('BV n))    -- ^ y
        -> b ('BV n)            -- ^ α x + y

    -- | Inner product of two equal-length vectors.
    dot :: b ('BV n)            -- ^ x
        -> b ('BV n)            -- ^ y
        -> ElemB b              -- ^ x' y

    -- norm
    --     :: b ('BV n)         -- ^ x
    --     -> ElemB b           -- ^ ||x||

    -- | Outer product of two vectors.
    ger :: b ('BV n)            -- ^ x
        -> b ('BV m)            -- ^ y
        -> b ('BM n m)          -- ^ x y'

    -- | Matrix-vector product, optionally accumulated onto a scaled vector.
    gemv
        :: ElemB b                      -- ^ α
        -> b ('BM n m)                  -- ^ A
        -> b ('BV m)                    -- ^ x
        -> Maybe (ElemB b, b ('BV n))   -- ^ β, y
        -> b ('BV n)                    -- ^ α A x + β y

    -- TODO: better way to scale matrices
    -- | Matrix-matrix product, optionally accumulated onto a scaled matrix.
    gemm
        :: ElemB b                      -- ^ α
        -> b ('BM n o)                  -- ^ A
        -> b ('BM o m)                  -- ^ B
        -> Maybe (ElemB b, b ('BM n m)) -- ^ β, C
        -> b ('BM n m)                  -- ^ α A B + β C

    -- | NOTE(review): presumably multiplies every element by the scalar.
    scaleB :: ElemB b -> b s -> b s

    -- | NOTE(review): presumably element-wise addition of same-shaped values.
    addB :: b s -> b s -> b s

    -- | Look up the element at the given per-dimension index.
    indexB :: BShapeP (IndexN k) s -> b s -> ElemB b

    -- | Select a row (an index into the first dimension) of a matrix.
    indexRowB :: IndexN k n -> b ('BM n m) -> b ('BV m)

    -- | Swap the two dimensions of a matrix (transpose).
    transpB :: b ('BM n m) -> b ('BM m n)

    -- | Effectfully rebuild a matrix row by row; each row is given with its
    -- index and may be replaced by a row of a different length.
    iRowsB
        :: Applicative f
        => (IndexN k n -> b ('BV m) -> f (b ('BV o)))
        -> b ('BM n m)
        -> f (b ('BM n o))

    -- | Effectfully rebuild every element, each given with its index.
    iElemsB
        :: Applicative f
        => (BShapeP (IndexN k) s -> ElemB b -> f (ElemB b))
        -> b s
        -> f (b s)

    -- TODO: can we merge bgen and bgenRowsA?
    -- | Build a value of shape @s@ from an effectful function of the index.
    bgenA
        :: Applicative f
        => Sing s
        -> (BShapeP (IndexN k) s -> f (ElemB b))
        -> f (b s)

    -- | Build a matrix from an effectful function of the row index.
    bgenRowsA
        :: (Applicative f, SingI n)
        => (IndexN k n -> f (b ('BV m)))
        -> f (b ('BM n m))

    -- | NOTE(review): presumably the @n × n@ identity matrix.
    eye :: Sing n -> b ('BM n n)

    -- | NOTE(review): presumably the sum of the diagonal of a square matrix.
    traceB :: b ('BM n n) -> ElemB b

    -- | NOTE(review): presumably a square matrix with the vector on its
    -- diagonal.
    diagB :: b ('BV n) -> b ('BM n n)

    -- | NOTE(review): presumably extracts the diagonal of a square matrix.
    getDiagB :: b ('BM n n) -> b ('BV n)

    -- | NOTE(review): presumably the sum of all elements.
    sumB :: b s -> ElemB b

    -- zero :: b s

-- | 'iElemsB' with the index argument discarded.
elemsB
    :: (Applicative f, BLAS b)
    => (ElemB b -> f (ElemB b))
    -> b s
    -> f (b s)
elemsB g = iElemsB (\_ e -> g e)
{-# INLINE elemsB #-}

-- | Pure version of 'bgenA', using the 'I' wrapper as the applicative.
bgen
    :: forall k (b :: BShape k -> Type) (s :: BShape k). BLAS b
    => Sing s
    -> (BShapeP (IndexN k) s -> ElemB b)
    -> b s
bgen s mk = getI (bgenA s (I . mk))
{-# INLINE bgen #-}

-- | Pure version of 'bgenRowsA', using the 'I' wrapper as the applicative.
bgenRows
    :: (BLAS b, SingI n)
    => (IndexN k n -> b ('BV m))
    -> b ('BM n m)
bgenRows mkRow = getI (bgenRowsA (I . mkRow))
{-# INLINE bgenRows #-}

-- | Combine two same-shaped values with a binary element function, by
-- handing 'liftB' a two-element 'Vec'.
zipB
    :: BLAS b
    => Sing s
    -> (ElemB b -> ElemB b -> ElemB b)
    -> b s
    -> b s
    -> b s
zipB s op u v =
    liftB s
          (\case I a :* I c :* ØV -> op a c)
          (I u :* I v :* ØV)
{-# INLINE zipB #-}