{-# LANGUAGE DeriveFoldable      #-}
{-# LANGUAGE DeriveFunctor       #-}
{-# LANGUAGE DeriveGeneric       #-}
{-# LANGUAGE DeriveTraversable   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE KindSignatures      #-}
{-# LANGUAGE LambdaCase          #-}
{-# LANGUAGE PatternSynonyms     #-}
{-# LANGUAGE PolyKinds           #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving  #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeInType          #-}
{-# LANGUAGE TypeOperators       #-}

module Data.Vector.Sized where

import           Control.DeepSeq
import           Data.Finite
import           Data.Finite.Internal
import           Data.Kind
import           Data.Proxy
import           Data.Singletons
import           Data.Singletons.TypeLits
import           Data.Distributive
import           Data.Type.Combinator
import           GHC.Generics
import           GHC.TypeLits
import           GHC.TypeLits.Util
import           Prelude hiding           ((!!))
import           Text.Printf
import qualified Data.Vector              as V


-- | 'VectorT' specialised to the 'I' wrapper: a length-@n@ vector whose
-- elements are stored as @I a@ (and read back with 'getI', see '!!').
type Vector n = VectorT n I

-- | A boxed 'V.Vector' of @f a@ values tagged with a type-level length @n@.
--
-- The length index is phantom: it does not occur in the wrapped field, so
-- the 'UnsafeV' constructor itself performs no length check.  Functions in
-- this module that use unchecked vector operations (for example '!') rely
-- on the wrapped vector really having @n@ elements.
newtype VectorT :: Nat -> (Type -> Type) -> Type -> Type where
    UnsafeV :: { getV :: V.Vector (f a) }
            -> VectorT n f a

-- Standalone-derived instances for 'VectorT'.
--
-- 'Generic' needs no context; 'Show' needs to show the stored @f a@ values.
-- 'Functor', 'Foldable' and 'Traversable' act on the innermost @a@, so each
-- requires the corresponding instance for the element wrapper @f@.
deriving instance                  Generic (VectorT n f a)
deriving instance Show (f a)    => Show (VectorT n f a)
deriving instance Functor f     => Functor (VectorT n f)
deriving instance Traversable f => Traversable (VectorT n f)
deriving instance Foldable f    => Foldable (VectorT n f)

-- | Push an outer functor inside a fixed-length vector: slot @i@ of the
-- result is obtained by picking entry @i@ out of every vector held in the
-- outer structure and then distributing through the element wrapper @f@.
instance (KnownNat n, Distributive f) => Distributive (VectorT n f) where
    distribute outer = generate slot
      where
        -- Project position i from each vector, then distribute over f.
        slot i = distribute ((! i) <$> outer)
    {-# INLINE distribute #-}

-- | Fully evaluating a 'VectorT' is fully evaluating the vector it wraps.
instance NFData (f a) => NFData (VectorT n f a) where
    rnf (UnsafeV v) = rnf v

-- | Position-wise applicative: 'pure' fills all @n@ slots with the same
-- @pure x@, and '<*>' combines the two vectors slot by slot using the
-- '<*>' of the element wrapper @f@.
instance (KnownNat n, Applicative f) => Applicative (VectorT n f) where
    pure x = UnsafeV (V.replicate len (pure x))
      where
        -- Runtime value of the type-level length.
        len :: Int
        len = fromIntegral (natVal (Proxy @n))
    {-# INLINE pure #-}
    fs <*> xs = UnsafeV (V.zipWith (<*>) (getV fs) (getV xs))
    {-# INLINE (<*>) #-}

-- | Checked conversion from an ordinary vector: returns 'Just' the wrapped
-- vector when its runtime length equals the type-level length @n@, and
-- 'Nothing' otherwise.
mkVectorT
    :: forall n f a. KnownNat n
    => V.Vector (f a)
    -> Maybe (VectorT n f a)
mkVectorT v
    -- Compare at the unbounded type returned by 'natVal' instead of
    -- narrowing @n@ to 'Int': a type-level length outside the 'Int' range
    -- would otherwise wrap around and could spuriously equal a real length,
    -- breaking the invariant that the unchecked operations depend on.
    | fromIntegral (V.length v) == natVal (Proxy @n) = Just (UnsafeV v)
    | otherwise                                      = Nothing
{-# INLINE mkVectorT #-}

-- | Like 'mkVectorT', but for a vector of plain values: every element is
-- wrapped in 'I' before the length check.
mkVector
    :: forall n a. KnownNat n
    => V.Vector a
    -> Maybe (Vector n a)
mkVector v = mkVectorT (V.map I v)
{-# INLINE mkVector #-}

-- | Partial variant of 'mkVectorT': wraps the vector when its runtime
-- length equals the type-level length @n@, and calls 'error' with a message
-- reporting the actual and expected lengths otherwise.
mkVectorT'
    :: forall n f a. KnownNat n
    => V.Vector (f a)
    -> VectorT n f a
mkVectorT' v
    -- The comparison is done at the unbounded type returned by 'natVal', so
    -- a type-level length outside the 'Int' range cannot wrap around and
    -- accidentally match the runtime length.
    | fromIntegral actual == expected = UnsafeV v
    | otherwise                       = error
        $ printf "mkVectorT': Mismatched vector length. %d, expected %d" actual expected
  where
    actual   = V.length v
    expected = natVal (Proxy @n)
{-# INLINE mkVectorT' #-}

-- | Like 'mkVectorT'', but for a vector of plain values: every element is
-- wrapped in 'I' first.  Inherits the partiality of 'mkVectorT''.
mkVector'
    :: forall n a. KnownNat n
    => V.Vector a
    -> Vector n a
mkVector' v = mkVectorT' (V.map I v)
{-# INLINE mkVector' #-}

-- | Build a vector of the type-level length @n@ by calling the supplied
-- function once per position, passing the position as a 'Finite' index.
generate
    :: forall n f a. KnownNat n
    => (Finite n -> f a)
    -> VectorT n f a
generate f = UnsafeV (V.generate len mkElem)
  where
    -- Runtime value of the type-level length.
    len :: Int
    len = fromIntegral (natVal (Proxy @n))
    -- Re-tag the raw position handed out by 'V.generate' as a 'Finite'.
    mkElem ix = f (Finite (fromIntegral ix))
{-# INLINE generate #-}

-- | Effectful counterpart of 'generate': one action is produced per
-- position, the actions are sequenced with 'sequenceA', and the collected
-- results are wrapped as a vector of length @n@.
generateA
    :: forall n f a g. (KnownNat n, Applicative g)
    => (Finite n -> g (f a))
    -> g (VectorT n f a)
generateA f = fmap UnsafeV (sequenceA actions)
  where
    -- Runtime value of the type-level length.
    len :: Int
    len = fromIntegral (natVal (Proxy @n))
    -- One (not yet run) action per position.
    actions = V.generate len (\ix -> f (Finite (fromIntegral ix)))
{-# INLINE generateA #-}

-- | A vector of length @n@ holding the same value in every position.
replicate
    :: KnownNat n
    => f a
    -> VectorT n f a
replicate x = generate (const x)
{-# INLINE replicate #-}

-- | Look up the element at a 'Finite' index.
--
-- Uses 'V.unsafeIndex', i.e. no runtime bounds check is performed here.
(!) :: VectorT n f a
    -> Finite n
    -> f a
(!) (UnsafeV v) i = V.unsafeIndex v ix
  where
    -- Plain integer position carried by the 'Finite' index.
    ix = fromIntegral (getFinite i)
{-# INLINE (!) #-}

-- | Index into a 'Vector' and strip the 'I' wrapper from the result.
(!!)
    :: Vector n a
    -> Finite n
    -> a
(!!) v i = getI (v ! i)
{-# INLINE (!!) #-}

-- | Run a continuation on an ordinary vector viewed as a 'VectorT' whose
-- type-level length is only known inside the continuation.
--
-- The runtime length of the vector is handed to 'withSomeSing' to obtain a
-- singleton for some type-level @n@; the vector is then wrapped at exactly
-- that @n@ and passed to the continuation.
withVectorT
    :: forall f a r. ()
    => V.Vector (f a)
    -> (forall n. KnownNat n => VectorT n f a -> r)
    -> r
withVectorT v f =
    -- Matching on SNat binds the type variable n; it is also the only place
    -- the KnownNat n evidence demanded by the continuation can come from.
    withSomeSing n $ \(SNat :: Sing (n :: Nat)) ->
      f (UnsafeV v :: VectorT n f a)
  where
    -- Runtime length of the input, as an Integer for 'withSomeSing'.
    n :: Integer
    n = fromIntegral (V.length v)
{-# INLINE withVectorT #-}

-- | 'withVectorT' for a vector of plain values: each element is wrapped in
-- 'I' before the continuation is run.
withVector
    :: forall a r. ()
    => V.Vector a
    -> (forall n. KnownNat n => Vector n a -> r)
    -> r
withVector v k = withVectorT (V.map I v) k
{-# INLINE withVector #-}

-- | Lift a function that consumes a whole container so that it works on a
-- container of applicative values: the effects are sequenced with
-- 'sequenceA' and the function is applied to the collected results.
liftVec
    :: (Applicative f, Traversable g)
    => (g a -> b)
    -> g (f a)
    -> f b
liftVec f xs = fmap f (sequenceA xs)
{-# INLINE liftVec #-}

-- | Turn a function that returns a vector into a vector of functions:
-- the function stored at position @i@ runs the original function and then
-- reads component @i@ of its result.
vecFunc
    :: KnownNat n
    => (a -> Vector n b)
    -> Vector n (a -> b)
vecFunc f = generate (\i -> I (\x -> f x !! i))
{-# INLINE vecFunc #-}

-- | Combine two vectors of the same type-level length position by position
-- with a function on the wrapped elements.
vap :: (f a -> g b -> h c)
    -> VectorT n f a
    -> VectorT n g b
    -> VectorT n h c
vap f xs ys = UnsafeV (V.zipWith f (getV xs) (getV ys))
{-# INLINE vap #-}

-- | Apply a function to every wrapped element, possibly changing both the
-- wrapper and the element type; the length index is unchanged.
vmap
    :: (f a -> g b)
    -> VectorT n f a
    -> VectorT n g b
vmap f xs = UnsafeV (V.map f (getV xs))
{-# INLINE vmap #-}

-- | One-step view of a 'VectorT', indexed by the same length: either the
-- empty vector (length @0@), or a first element together with a vector of
-- @n@ further elements (length @n + 1@).
--
-- Both fields of 'VCons' are strict, and the constructor carries the
-- 'KnownNat' evidence for the length of the remainder.
data Uncons :: Nat -> (Type -> Type) -> Type -> Type where
    VNil  :: Uncons 0 f a
    VCons :: KnownNat n => !(f a) -> !(VectorT n f a) -> Uncons (n + 1) f a

-- | Split a vector into its 'Uncons' view.
--
-- The branch is chosen from the type-level length via 'inductive' rather
-- than by inspecting the runtime vector; in the successor case the head and
-- tail are taken with the unchecked 'V.unsafeHead' / 'V.unsafeTail'.
uncons
    :: forall n f a. KnownNat n
    => VectorT n f a
    -> Uncons n f a
uncons vec = case inductive (Proxy @n) of
    NatZ   -> VNil
    NatS _ -> VCons (V.unsafeHead raw) (UnsafeV (V.unsafeTail raw))
  where
    -- The underlying unsized vector.
    raw = getV vec
{-# INLINE uncons #-}

-- | Rebuild a vector from its 'Uncons' view: 'VNil' becomes the empty
-- vector and 'VCons' prepends the first element to the remainder.
fromUncons
    :: Uncons n f a
    -> VectorT n f a
fromUncons VNil         = UnsafeV V.empty
fromUncons (VCons x xs) = UnsafeV (V.cons x (getV xs))
{-# INLINE fromUncons #-}

-- | Prepend one element, growing the type-level length by one.
cons
    :: f a
    -> VectorT n f a
    -> VectorT (n + 1) f a
cons x xs = UnsafeV (V.cons x (getV xs))
{-# INLINE cons #-}

-- | The vector of type-level length zero.
empty :: VectorT 0 f a
empty = UnsafeV V.empty
{-# INLINE empty #-}

-- | First element of a vector whose type-level length is @m + 1@.
--
-- Implemented with the unchecked 'V.unsafeHead'.
head
    :: VectorT (m + 1) f a
    -> f a
head v = V.unsafeHead (getV v)
{-# INLINE head #-}

-- | Everything but the first element of a vector whose type-level length
-- is @m + 1@; the result has length @m@.
--
-- Implemented with the unchecked 'V.unsafeTail'.
tail
    :: VectorT (m + 1) f a
    -> VectorT m f a
tail v = UnsafeV (V.unsafeTail (getV v))
{-# INLINE tail #-}

-- | Indexed traversal: run an effectful function on every element together
-- with its position (as a 'Finite' index), sequence the effects with
-- 'sequenceA', and collect the results into a vector of the same length.
itraverse
    :: Applicative h
    => (Finite n -> f a -> h (g b))
    -> VectorT n f a
    -> h (VectorT n g b)
itraverse f xs = fmap UnsafeV (sequenceA (V.imap step (getV xs)))
  where
    -- Re-tag the raw position supplied by 'V.imap' as a 'Finite' index.
    step ix x = f (Finite (fromIntegral ix)) x
{-# INLINE itraverse #-}