{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE InstanceSigs         #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE LambdaCase           #-}
{-# LANGUAGE RankNTypes           #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeApplications     #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE UndecidableInstances #-}

module TensorOps.BLAS.HMat
  ( HMat
  , HMatD
  ) where

import           Control.DeepSeq
import           Data.Kind
import           Data.Singletons
import           Data.Singletons.TypeLits
import           Data.Type.Combinator
import           Data.Type.Vector            (Vec, VecT(..))
import           Data.Type.Vector.Util       (curryV2', curryV3')
import           Numeric.LinearAlgebra
import           Numeric.LinearAlgebra.Data  as LA
import           Numeric.LinearAlgebra.Devel
import           TensorOps.BLAS
import           Type.Class.Higher
import           Type.Class.Higher.Util
import qualified Data.Finite                 as DF
import qualified Data.Finite.Internal        as DF
import qualified Data.Vector.Storable        as VS

-- | 'HMat' specialised to 'Double' elements.
type HMatD = HMat Double

-- | A shape-indexed wrapper around a 'Vector' or a 'Matrix'.
--
-- The first index is the element type; the second is a type-level
-- 'BShape' that selects the constructor.  The size indices @n@ and @m@ are
-- phantom: they do not occur in the wrapped field, so nothing in this
-- declaration checks them against the runtime dimensions of the payload.
data HMat :: Type -> BShape Nat -> Type where
    -- Vector-shaped value: a strict 'Vector' tagged with shape @'BV n@.
    HMV :: { unHMV :: !(Vector a) } -> HMat a ('BV n)
    -- Matrix-shaped value: a strict 'Matrix' tagged with shape @'BM n m@.
    HMM :: { unHMM :: !(Matrix a) } -> HMat a ('BM n m)

-- | Shows the constructor name followed by the wrapped vector or matrix,
-- parenthesising the whole thing when the surrounding precedence is above
-- that of function application (10).
instance (VS.Storable a, Show a, Element a) => Show (HMat a s) where
    showsPrec d h = case h of
        HMV vec -> showParen needsParens (showString "HMV " . showsPrec 11 vec)
        HMM mat -> showParen needsParens (showString "HMM " . showsPrec 11 mat)
      where
        needsParens = d > 10

-- | Shape-polymorphic show instance for 'HMat'.  No methods are defined
-- here, so the instance relies entirely on the class's default methods.
-- NOTE(review): presumably this is the @Show1@ from "Type.Class.Higher" and
-- its defaults defer to the 'Show' instance for @HMat a s@ -- confirm.
instance (VS.Storable a, Show a, Element a) => Show1 (HMat a)

-- | Forcing an 'HMat' to normal form forces the vector or matrix it wraps.
instance (VS.Storable a, NFData a) => NFData (HMat a s) where
    rnf h = case h of
        HMV vec -> rnf vec
        HMM mat -> rnf mat
    {-# INLINE rnf #-}

-- | Shape-polymorphic deep-evaluation instance for 'HMat'.  No methods are
-- defined here, so the instance relies entirely on the class's defaults.
-- NOTE(review): presumably the defaults defer to the 'NFData' instance for
-- @HMat a s@ -- confirm against the definition of @NFData1@.
instance (VS.Storable a, NFData a) => NFData1 (HMat a)

-- | Element-wise numeric operations on same-shaped values.
--
-- Note that '(*)' multiplies corresponding elements; it is /not/ matrix
-- multiplication.  'fromInteger' produces a value filled with the given
-- constant, with its dimensions read from the shape singleton for @s@.
instance (SingI s, Container Vector a, Container Matrix a, Num a) => Num (HMat a s) where
    l + r = unsafeZipH add add l r
    l * r = unsafeZipH (VS.zipWith (*)) (liftMatrix2 (VS.zipWith (*))) l r
    l - r = unsafeZipH (VS.zipWith (-)) (liftMatrix2 (VS.zipWith (-))) l r
    negate h = unsafeMapH (scale (-1)) (scale (-1)) h
    abs h    = unsafeMapH (cmap abs) (cmap abs) h
    signum h = unsafeMapH (cmap signum) (cmap signum) h
    fromInteger k = case (sing :: Sing s) of
        -- Vector shape: constant vector of length n.
        SBV sN ->
          HMV (konst (fromInteger k) (fromIntegral (fromSing sN)))
        -- Matrix shape: constant matrix with n rows and m columns.
        SBM sN sM ->
          HMM (konst (fromInteger k)
                     ( fromIntegral (fromSing sN)
                     , fromIntegral (fromSing sM)
                     ))


-- | Combine two same-shaped 'HMat's, using the first function when both
-- are vectors and the second when both are matrices.
--
-- WARNING: the supplied functions are trusted to accept equally sized
-- inputs and to return a result of that same size.  Nothing here checks
-- this, so a misbehaving function yields a value whose runtime size
-- disagrees with its type-level shape.
unsafeZipH
    :: (Vector a -> Vector a -> Vector a)
    -> (Matrix a -> Matrix a -> Matrix a)
    -> HMat a s -> HMat a s -> HMat a s
unsafeZipH onVec onMat lhs rhs = case lhs of
    -- Matching the left operand fixes the shape index, which leaves only
    -- one possible constructor for the right operand.
    HMV u -> case rhs of
      HMV v -> HMV (onVec u v)
    HMM p -> case rhs of
      HMM q -> HMM (onMat p q)

-- | Transform the payload of an 'HMat', using the first function for a
-- vector and the second for a matrix.
--
-- WARNING: the supplied functions are trusted to return a result of the
-- same size as their input.  Nothing here checks this.
unsafeMapH
    :: (Vector a -> Vector a)
    -> (Matrix a -> Matrix a)
    -> HMat a s -> HMat a s
unsafeMapH onVec onMat h = case h of
    HMV vec -> HMV (onVec vec)
    HMM mat -> HMM (onMat mat)

-- | Generic element-wise lifting of an n-ary scalar function.
--
-- The result is built with 'bgen': for every index of the requested
-- shape, the element at that index is read from each input with 'indexB'
-- and the collected elements are passed to the scalar function.
liftB'
    :: (Numeric a)
    => Sing s
    -> (Vec n a -> a)
    -> Vec n (HMat a s)
    -> HMat a s
liftB' shape combine inputs = bgen shape elemAt
  where
    -- Value of the result at a single index.
    elemAt ix = combine (fmap (indexB ix) inputs)
{-# INLINE liftB' #-}

-- | 'BLAS' operations for 'HMat', implemented with the hmatrix functions
-- imported at the top of this module.  Every method unwraps the 'HMV' /
-- 'HMM' payload, delegates to a vector or matrix routine, and re-wraps the
-- result; runtime sizes are never checked against the type-level shape.
instance (Container Vector a, Numeric a) => BLAS (HMat a) where
    -- The scalar type associated with this backend is the element type.
    type ElemB (HMat a) = a

    -- Lift an n-ary scalar function element-wise over n same-shaped inputs.
    -- Small arities get direct implementations; everything else goes
    -- through liftB', which rebuilds the result one index at a time.
    --
    -- TODO: add rewrite rules
    -- TODO: consider a parallel implementation?
    liftB
        :: forall n s. ()
        => Sing s
        -> (Vec n a -> a)
        -> Vec n (HMat a s)
        -> HMat a s
    liftB s f = \case
        -- No inputs: fill the requested shape with the constant `f ØV`.
        ØV -> case s of
          SBV sN    -> HMV $ konst (f ØV) ( fromIntegral (fromSing sN) )
          SBM sN sM -> HMM $ konst (f ØV) ( fromIntegral (fromSing sN)
                                          , fromIntegral (fromSing sM)
                                          )
        -- One input: map over the elements, wrapping each one in a
        -- singleton Vec before handing it to f.
        I x :* ØV -> case x of
          HMV x' -> HMV (cmap (f . (:* ØV) . I) x')
          HMM x' -> HMM (cmap (f . (:* ØV) . I) x')
        -- Two inputs: zip the payloads; matching x fixes the shape, so y
        -- has only one possible constructor.
        I x :* I y :* ØV -> case x of
          HMV x' -> case y of
            HMV y' -> HMV $ VS.zipWith (curryV2' f) x' y'
          HMM x' -> case y of
            HMM y' -> HMM $ liftMatrix2 (VS.zipWith (curryV2' f)) x' y'
        -- Three inputs: vectors are zipped directly; the remaining
        -- constructor (matrices) takes the generic liftB' path.
        xs@(I x :* I y :* I z :* ØV) -> case x of
          HMV x' -> case y of
            HMV y' -> case z of
              HMV z' -> HMV $ VS.zipWith3 (curryV3' f) x' y' z'
          _ -> liftB' s f xs
        -- Four or more inputs: always the generic path.
        xs@(_ :* _ :* _ :* _ :* _) -> liftB' s f xs

    -- Scale x by α and, when `my` is `Just y`, add y's vector to the
    -- scaled result; with `Nothing` the scaled vector is returned as is.
    axpy α (HMV x) my
        = HMV
        . maybe id (add . unHMV) my
        . scale α
        $ x
    {-# INLINE axpy #-}
    -- Dot product of two vectors via hmatrix's (<.>).
    dot (HMV x) (HMV y)
        = x <.> y
    {-# INLINE dot #-}
    -- Outer product of two vectors, producing a matrix.
    ger (HMV x) (HMV y)
        = HMM $ x `outer` y
    {-# INLINE ger #-}
    -- Matrix-vector product: x is scaled by α before being multiplied by
    -- a; when `mβy` is `Just (β, y)`, β-scaled y is added to the product.
    gemv α (HMM a) (HMV x) mβy
        = HMV
        . maybe id (\(β, HMV y) -> add (scale β y)) mβy
        . (a #>)
        . scale α
        $ x
    {-# INLINE gemv #-}
    -- Matrix-matrix product: b is scaled by α before being multiplied by
    -- a; when `mβc` is `Just (β, c)`, β-scaled c is added to the product.
    gemm α (HMM a) (HMM b) mβc
        = HMM
        . maybe id (\(β, HMM c) -> add (scale β c)) mβc
        . (a <>)
        . scale α
        $ b
    {-# INLINE gemm #-}
    -- Multiply every element by α, for either shape.
    scaleB α = unsafeMapH (scale α) (scale α)
    {-# INLINE scaleB #-}
    -- Add two same-shaped values with hmatrix's `add`.
    addB = unsafeZipH add add
    {-# INLINE addB #-}
    -- Read a single element.  The Integer held by each Finite index is
    -- converted with fromInteger to the index type `atIndex` expects.
    indexB = \case
        PBV i -> \case
          HMV x -> x `atIndex` fromInteger (DF.getFinite i)
        PBM i j -> \case
          HMM x -> x `atIndex` ( fromInteger (DF.getFinite i)
                               , fromInteger (DF.getFinite j)
                               )
    {-# INLINE indexB #-}
    -- Extract row i of a matrix as a vector.
    indexRowB i (HMM x) = HMV (x ! fromInteger (DF.getFinite i))
    {-# INLINE indexRowB #-}
    -- Transpose a matrix with hmatrix's `tr`.
    transpB (HMM x) = HMM (tr x)
    {-# INLINE transpB #-}
    -- Effectful indexed traversal over rows: split the matrix with toRows,
    -- pair each row with its 0-based position (wrapped with the raw
    -- DF.Finite constructor), apply f, and reassemble with fromRows.
    -- NOTE(review): the result's dimensions come from whatever fromRows
    -- builds out of the returned rows, not from the type-level shape; check
    -- the zero-row case against hmatrix's fromRows.
    iRowsB f (HMM x) = fmap (HMM . fromRows)
                     . traverse (\(i,r) -> unHMV <$> f (DF.Finite i) (HMV r))
                     . zip [0..]
                     . toRows
                     $ x
    {-# INLINE iRowsB #-}
    -- Effectful indexed traversal over elements.  The payload is converted
    -- to a list (or list of lists), each element is paired with its
    -- 0-based position, f is applied in order (row by row for matrices),
    -- and the results are packed back with fromList / fromLists.
    iElemsB f = \case
        HMV x -> fmap (HMV . fromList)
               . traverse (\(i,e) -> f (PBV (DF.Finite i)) e)
               . zip [0..]
               . LA.toList
               $ x
        HMM x -> fmap (HMM . fromLists)
               . traverse (\(i,rs) ->
                     traverse (\(j, e) -> f (PBM (DF.Finite i) (DF.Finite j)) e)
                   . zip [0..]
                   $ rs
                 )
               . zip [0..]
               . toLists
               $ x
    {-# INLINE iElemsB #-}
    -- Effectfully generate a value of the given shape from an index
    -- function.  Indices run over [0 .. size - 1] as read from the shape
    -- singleton(s), and effects run in index order (row by row for
    -- matrices).
    -- NOTE(review): when the row count is 0 the list handed to fromLists is
    -- empty, so the column count of the result is decided by fromLists
    -- rather than by sM -- confirm this is acceptable.
    -- TODO: could this be implemented in parallel?
    bgenA = \case
      SBV sN -> \f -> fmap (HMV . fromList)
                    . traverse (\i -> f (PBV (DF.Finite i)))
                    $ [0 .. fromSing sN - 1]
      SBM sN sM -> \f -> fmap (HMM . fromLists)
                       . traverse (\(i, js) ->
                           traverse (\j -> f (PBM (DF.Finite i) (DF.Finite j))) js
                         )
                       . zip [0 .. fromSing sN - 1]
                       $ repeat [0 .. fromSing sM - 1]
    {-# INLINE bgenA #-}
    -- Effectfully generate a matrix one row at a time.  The row count is
    -- read from the implicit singleton for n; each generated vector is
    -- unwrapped and the rows are stacked with fromRows.
    bgenRowsA
        :: forall f n m. (Applicative f, SingI n)
        => (DF.Finite n -> f (HMat a ('BV m)))
        -> f (HMat a ('BM n m))
    bgenRowsA f = fmap (HMM . fromRows)
                . traverse (fmap unHMV . f . DF.Finite)
                $ [0 .. fromSing (sing @Nat @n) - 1]
    {-# INLINE bgenRowsA #-}

    -- Identity matrix whose size is read from the given singleton.
    eye = HMM . ident . fromIntegral . fromSing
    {-# INLINE eye #-}
    -- Matrix built from a vector with hmatrix's `diag`.
    diagB = HMM . diag . unHMV
    {-# INLINE diagB #-}
    -- Vector obtained from a matrix with hmatrix's `takeDiag`.
    getDiagB = HMV . takeDiag . unHMM
    {-# INLINE getDiagB #-}
    -- Sum of the elements returned by `takeDiag`.
    traceB = sumElements . takeDiag . unHMM
    {-# INLINE traceB #-}
    -- Sum of all elements, for either shape.
    sumB = \case
      HMV xs -> sumElements xs
      HMM xs -> sumElements xs
    {-# INLINE sumB #-}