{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Type.Tuple (
T0(..)
, (:#)(..), pattern (:##)
, t2Tup, tupT2
, uncurryT2, curryT2
, t2_1, t2_2
, T, TF(..), pattern (:&&)
, indexT
, tOnly, onlyT
, tHead
, tTail
, TMaybe
, PMaybe(..,TJust), fromPJust
, type (:#?), pattern (:#?), splitTupMaybe, tupMaybe
, pattern MaybB
) where
import Control.DeepSeq
import Data.Bifunctor
import Data.Coerce
import Data.Data
import Data.Kind
import Data.Profunctor hiding (rmap)
import Data.Semigroup
import Data.Type.Functor.Product
import Data.Vinyl.Core
import GHC.Generics ((:*:)(..))
import GHC.Generics (Generic)
import Lens.Micro
import Lens.Micro.Internal hiding (Index)
import Numeric.Backprop hiding (T2)
import Numeric.Opto.Ref
import Numeric.Opto.Update
import qualified Data.Binary as Bi
import qualified Data.Functor.Contravariant as Co
import qualified Data.Vinyl.Functor as V
-- | A type with exactly one inhabitant, 'T0'.
data T0 = T0
  deriving
    ( Show
    , Read
    , Eq
    , Ord
    , Generic
    , Data
    )
-- | A pair whose two fields are both strict.  The constructor is written
-- infix and associates to the right at precedence 1.
data a :# b = !a :# !b
  deriving
    ( Show
    , Read
    , Eq
    , Ord
    , Generic
    , Functor
    , Data
    , Typeable
    )

infixr 1 :#
-- | A 'Rec' in which every element is wrapped in 'TF'.
type T = Rec TF
-- | Newtype wrapper around a single value; 'getTF' unwraps it.  It is the
-- element wrapper used by 'T' and 'TMaybe'.
newtype TF a = TF { getTF :: a }
  deriving
    ( Show
    , Eq
    , Ord
    , Num
    , Fractional
    , Floating
    , Backprop
    , Generic
    , Typeable
    , NFData
    )
-- | 'PMaybe' specialised to the 'TF' wrapper.
type TMaybe = PMaybe TF
-- | Adapt an optic on @a@ into one on @'TF' a@: the input is unwrapped
-- before use and the result is re-wrapped inside the functor.
_TF :: (Profunctor p, Functor f) => p a (f b) -> p (TF a) (f (TF b))
_TF = dimap (\(TF x) -> x) (fmap TF)
-- | Cons pattern for 'T' that hides the 'TF' wrapper around the head
-- element: @x ':&&' xs@ stands for @'TF' x ':&' xs@, both when matching and
-- when constructing.
pattern (:&&) :: a -> T as -> T (a ': as)
pattern x :&& xs = TF x :& xs
{-# COMPLETE (:&&) #-}
infixr 5 :&&
-- Standalone 'Typeable' instances for the two tuple types.
-- NOTE(review): ':#' already lists 'Typeable' in its own deriving clause, so
-- the second declaration looks redundant -- confirm whether it is still
-- needed for the GHC versions this module supports.
deriving instance Typeable T0
deriving instance Typeable (a :# b)
-- Instances declared without method bodies, so each one uses the class's
-- default method implementations (presumably the 'Generic'-based defaults,
-- since 'T0', ':#' and 'TF' all derive 'Generic' -- confirm).
instance NFData T0
instance (NFData a, NFData b) => NFData (a :# b)
instance Bi.Binary T0
instance (Bi.Binary a, Bi.Binary b) => Bi.Binary (a :# b)
instance Bi.Binary a => Bi.Binary (TF a)
-- | Map over the left and right components independently.
instance Bifunctor (:#) where
  bimap onLeft onRight pair = case pair of
    l :# r -> onLeft l :# onRight r
-- | Convert a strict pair into an ordinary tuple.
t2Tup :: (a :# b) -> (a, b)
t2Tup pair = case pair of
  l :# r -> (l, r)
-- | Convert an ordinary tuple into a strict pair.
tupT2 :: (a, b) -> a :# b
tupT2 tup = case tup of
  (l, r) -> l :# r
-- | Build a one-element 'T' from a single value.
onlyT :: a -> T '[a]
onlyT x = TF x :& RNil
-- | Extract the value from a one-element 'T'.  The tail is not inspected.
tOnly :: T '[a] -> a
tOnly (x :&& _) = x
-- | Turn a two-argument function into one that takes a strict pair.
uncurryT2 :: (a -> b -> c) -> a :# b -> c
uncurryT2 f pair = case pair of
  l :# r -> f l r
-- | Turn a function on strict pairs into a two-argument function.
curryT2 :: (a :# b -> c) -> a -> b -> c
curryT2 f l = f . (l :#)
-- | '_1' focuses on the left component of a strict pair and may change its
-- type.
instance Field1 (a :# b) (a' :# b) a a' where
  _1 f (l :# r) = fmap (:# r) (f l)

-- | '_2' focuses on the right component of a strict pair and may change its
-- type.
instance Field2 (a :# b) (a :# b') b b' where
  _2 f (l :# r) = fmap (l :#) (f r)
-- | '_1' focuses on the first element of a 'T' (type-preserving).
instance Field1 (T (a ': as)) (T (a ': as)) a a where
  _1 = recIx IZ . _TF

-- | '_2' focuses on the second element of a 'T' (type-preserving).
instance Field2 (T (a ': b ': as)) (T (a ': b ': as)) b b where
  _2 = recIx (IS IZ) . _TF

-- | '_3' focuses on the third element of a 'T' (type-preserving).
instance Field3 (T (a ': b ': c ': as)) (T (a ': b ': c ': as)) c c where
  _3 = recIx (IS (IS IZ)) . _TF
-- | Type-changing lens onto the left component of a strict pair.
t2_1 :: Lens (a :# b) (a' :# b) a a'
t2_1 f (l :# r) = fmap (:# r) (f l)
-- | Type-changing lens onto the right component of a strict pair.
t2_2 :: Lens (a :# b) (a :# b') b b'
t2_2 f (l :# r) = fmap (l :#) (f r)
-- | Look up the element of a 'T' at the given index.
indexT :: Index as a -> T as -> a
indexT i xs = xs ^. tIx i
-- | Lens onto the (still wrapped) element of a 'Rec' at the given index.
-- 'IZ' focuses on the head; 'IS' skips the head and recurses into the tail.
recIx :: forall f as a. Index as a -> Lens' (Rec f as) (f a)
recIx ix g rec = case ix of
  IZ -> case rec of
    x :& xs -> fmap (:& xs) (g x)
  IS ix' -> case rec of
    x :& xs -> fmap (x :&) (recIx ix' g xs)
-- | Lens onto the element of a 'T' at the given index, with the 'TF'
-- wrapper removed.
tIx :: Index as a -> Lens' (T as) a
tIx i g = recIx i (_TF g)
-- | Type-changing lens onto the first element of a 'T'.
tHead :: Lens (T (a ': as)) (T (b ': as)) a b
tHead g (x :&& xs) = fmap (:&& xs) (g x)
-- | Lens onto everything after the first element of a 'T'; the tail's
-- element types may change.
tTail :: Lens (T (a ': as)) (T (a ': bs)) (T as) (T bs)
tTail g (hd :& tl) = fmap (hd :&) (g tl)
-- | Every operation ignores its arguments and returns 'T0'.
instance Num T0 where
  (+)         = \_ _ -> T0
  (-)         = \_ _ -> T0
  (*)         = \_ _ -> T0
  negate      = const T0
  abs         = const T0
  signum      = const T0
  fromInteger = const T0
-- | Every operation ignores its arguments and returns 'T0'.
instance Fractional T0 where
  (/)          = \_ _ -> T0
  recip        = const T0
  fromRational = const T0
-- | Every operation ignores its arguments and returns 'T0'.
instance Floating T0 where
  pi      = T0
  (**)    = \_ _ -> T0
  logBase = \_ _ -> T0
  exp     = const T0
  log     = const T0
  sqrt    = const T0
  sin     = const T0
  cos     = const T0
  asin    = const T0
  acos    = const T0
  atan    = const T0
  sinh    = const T0
  cosh    = const T0
  asinh   = const T0
  acosh   = const T0
  atanh   = const T0
-- | Combining two 'T0' values always yields 'T0'.
instance Semigroup T0 where
  (<>) = \_ _ -> T0

-- | The identity element is 'T0' itself.
instance Monoid T0 where
  mappend = (<>)
  mempty  = T0
-- | Component-wise arithmetic.  An integer literal is injected into both
-- components.
instance (Num a, Num b) => Num (a :# b) where
  (l1 :# r1) + (l2 :# r2) = (l1 + l2) :# (r1 + r2)
  (l1 :# r1) - (l2 :# r2) = (l1 - l2) :# (r1 - r2)
  (l1 :# r1) * (l2 :# r2) = (l1 * l2) :# (r1 * r2)
  negate        = bimap negate negate
  abs           = bimap abs abs
  signum        = bimap signum signum
  fromInteger n = fromInteger n :# fromInteger n
-- | Component-wise division and reciprocal.  A rational literal is injected
-- into both components.
instance (Fractional a, Fractional b) => Fractional (a :# b) where
  (l1 :# r1) / (l2 :# r2) = (l1 / l2) :# (r1 / r2)
  recip          = bimap recip recip
  fromRational q = fromRational q :# fromRational q
-- | Component-wise transcendental functions.
--
-- 'tan' and 'tanh' are defined explicitly so that each component's own
-- implementation is used.  Without them the class defaults
-- (@sin x / cos x@ and @sinh x / cosh x@) would be evaluated on the pair,
-- which bypasses the components' methods and, for example, turns 'tanh' of
-- a large 'Double' into @Infinity / Infinity@ (NaN) instead of 1.
instance (Floating a, Floating b) => Floating (a :# b) where
  pi                            = pi :# pi
  (x1 :# y1) ** (x2 :# y2)      = x1 ** x2 :# y1 ** y2
  logBase (x1 :# y1) (x2 :# y2) = logBase x1 x2 :# logBase y1 y2
  exp   (x :# y) = exp x   :# exp y
  log   (x :# y) = log x   :# log y
  sqrt  (x :# y) = sqrt x  :# sqrt y
  sin   (x :# y) = sin x   :# sin y
  cos   (x :# y) = cos x   :# cos y
  tan   (x :# y) = tan x   :# tan y
  asin  (x :# y) = asin x  :# asin y
  acos  (x :# y) = acos x  :# acos y
  atan  (x :# y) = atan x  :# atan y
  sinh  (x :# y) = sinh x  :# sinh y
  cosh  (x :# y) = cosh x  :# cosh y
  tanh  (x :# y) = tanh x  :# tanh y
  asinh (x :# y) = asinh x :# asinh y
  acosh (x :# y) = acosh x :# acosh y
  atanh (x :# y) = atanh x :# atanh y
-- | Combine the left components together and the right components together.
instance (Semigroup a, Semigroup b) => Semigroup (a :# b) where
  (l1 :# r1) <> (l2 :# r2) = (l1 <> l2) :# (r1 <> r2)

-- | The identity pairs the identities of both components.
instance (Semigroup a, Semigroup b, Monoid a, Monoid b) => Monoid (a :# b) where
  mempty  = mempty :# mempty
  mappend = (<>)
-- | Build a 'T' by placing the same constrained-polymorphic value in every
-- slot.
constT
    :: forall c as. RPureConstrained c as
    => (forall a. c a => a)
    -> T as
constT x = rpureConstrained @c @as @TF (TF x)
-- | Apply a constrained-polymorphic endofunction to every element of a 'T'.
--
-- A record holding the function once per slot (as an 'Endo') is built first
-- and then zipped against the input; 'coerce' applies each 'Endo' through
-- the 'TF' wrapper.
mapT
    :: forall c as. (RPureConstrained c as, RMap as, RApply as)
    => (forall a. c a => a -> a)
    -> T as
    -> T as
mapT f xs = rzipWith coerce (rpureConstrained @c @as @Endo (Endo f)) xs
-- | Combine two 'T's element by element with a constrained-polymorphic
-- binary function.
--
-- The function is replicated into every slot as a 'TripleEndo', partially
-- applied to the first record via 'rzipWith', and the resulting record of
-- functions is applied to the second record with 'rapply'.
zipT
    :: forall c as. (RPureConstrained c as, RMap as, RApply as)
    => (forall a. c a => a -> a -> a)
    -> T as
    -> T as
    -> T as
zipT f xs = rapply (rzipWith coerce (rpureConstrained @c @as @TripleEndo (TE f)) xs)
-- | Newtype wrapper for a binary operation on a single type, the
-- two-argument counterpart of 'Endo'.
newtype TripleEndo a = TE (a -> a -> a)
-- | Element-wise arithmetic; an integer literal fills every slot.
instance (RPureConstrained Num as, RMap as, RApply as) => Num (Rec TF as) where
  xs + ys       = zipT @Num (+) xs ys
  xs - ys       = zipT @Num (-) xs ys
  xs * ys       = zipT @Num (*) xs ys
  negate xs     = mapT @Num negate xs
  abs xs        = mapT @Num abs xs
  signum xs     = mapT @Num signum xs
  fromInteger n = constT @Num (fromInteger n)
-- | Element-wise division and reciprocal; a rational literal fills every
-- slot.
instance (RPureConstrained Num as, RPureConstrained Fractional as, RMap as, RApply as) => Fractional (T as) where
  xs / ys        = zipT @Fractional (/) xs ys
  recip xs       = mapT @Fractional recip xs
  fromRational q = constT @Fractional (fromRational q)
-- | Element-wise transcendental functions.
--
-- 'tan' and 'tanh' are defined explicitly so that each element's own
-- implementation is used.  Without them the class defaults
-- (@sin x / cos x@ and @sinh x / cosh x@) would be evaluated on the whole
-- record, which bypasses the elements' methods and, for example, turns
-- 'tanh' of a large 'Double' into @Infinity / Infinity@ (NaN) instead of 1.
instance (RPureConstrained Num as, RPureConstrained Fractional as, RPureConstrained Floating as, RMap as, RApply as) => Floating (T as) where
  pi      = constT @Floating pi
  (**)    = zipT @Floating (**)
  logBase = zipT @Floating logBase
  exp     = mapT @Floating exp
  log     = mapT @Floating log
  sqrt    = mapT @Floating sqrt
  sin     = mapT @Floating sin
  cos     = mapT @Floating cos
  tan     = mapT @Floating tan
  asin    = mapT @Floating asin
  acos    = mapT @Floating acos
  atan    = mapT @Floating atan
  sinh    = mapT @Floating sinh
  cosh    = mapT @Floating cosh
  tanh    = mapT @Floating tanh
  asinh   = mapT @Floating asinh
  acosh   = mapT @Floating acosh
  atanh   = mapT @Floating atanh
-- | 'zero', 'add' and 'one' are applied to each component separately.
instance (Backprop a, Backprop b) => Backprop (a :# b) where
  zero = bimap zero zero
  add (l1 :# r1) (l2 :# r2) = add l1 l2 :# add r1 r2
  one  = bimap one one
-- | Pattern for a 'BVar' holding a strict pair.  Matching projects out both
-- components as 'BVar's through the pair lenses; constructing combines two
-- 'BVar's with 'isoVar2'.
pattern (:##)
    :: (Backprop a, Backprop b, Reifies s W)
    => BVar s a
    -> BVar s b
    -> BVar s (a :# b)
pattern x :## y <- (\pair -> (pair ^^. t2_1, pair ^^. t2_2) -> (x, y))
  where
    (:##) = isoVar2 (:#) t2Tup
{-# COMPLETE (:##) #-}
-- | Mutable references to strict pairs use the 'GRef' representation, and
-- all three operations are delegated to the corresponding @g*@ helpers.
instance (Mutable m a, Mutable m b) => Mutable m (a :# b) where
  type Ref m (a :# b) = GRef m (a :# b)
  thawRef             = gThawRef
  freezeRef           = gFreezeRef
  copyRef             = gCopyRef
-- Instances declared without method bodies, so each one uses the class's
-- default method implementations (presumably Generic-based, since ':#'
-- derives 'Generic' -- confirm against the class definitions).
instance (Linear c a, Linear c b) => Linear c (a :# b)
instance (Floating c, Ord c, Metric c a, Metric c b) => Metric c (a :# b)
instance (LinearInPlace m c a, LinearInPlace m c b) => LinearInPlace m c (a :# b)
-- | Force every element of the record, front to back.
--
-- A record holding each slot's 'rnf' (wrapped in 'Co.Op') is built first and
-- then walked in lock-step with the value record.
instance RPureConstrained NFData as => NFData (Rec TF as) where
  rnf = go (rpureConstrained @NFData (Co.Op rnf))
    where
      go :: Rec (Co.Op ()) bs -> Rec TF bs -> ()
      go RNil               RNil         = ()
      go (Co.Op force :& fs) (TF x :& xs) = force x `seq` go fs xs
-- | Serialise the elements one after another in record order, and
-- deserialise them in the same order.
instance (RPureConstrained Bi.Binary as, RMap as, RApply as, RecordToList as) => Bi.Binary (Rec TF as) where
  put xs =
    sequence_ @[] . recordToList $
      rzipWith
        (\(Co.Op putElem) (TF x) -> V.Const (putElem x))
        (rpureConstrained @Bi.Binary (Co.Op Bi.put))
        xs
  get = rtraverse coerce (rpureConstrained @Bi.Binary Bi.get)
-- | A reference to a @'TF' a@ is the underlying reference to @a@ wrapped in
-- 'TF'; each operation coerces through the wrapper and delegates to the
-- instance for @a@.
instance Mutable m a => Mutable m (TF a) where
  type Ref m (TF a) = TF (Ref m a)
  thawRef x     = coerce <$> thawRef @m @a (coerce x)
  freezeRef ref = coerce <$> freezeRef @m @a (coerce ref)
  copyRef ref x = copyRef @m @a (coerce ref) (coerce x)
-- Instances declared without method bodies, so each one uses the class's
-- default method implementations (presumably Generic-based, since 'TF'
-- derives 'Generic' -- confirm against the class definitions).
instance Linear c a => Linear c (TF a)
instance (Ord c, Floating c, Metric c a) => Metric c (TF a)
-- | In-place updates unwrap the 'TF' around both the reference and the
-- operand, then delegate to the instance for @a@.
instance (Mutable m a, LinearInPlace m c a) => LinearInPlace m c (TF a) where
  (.+.=) (TF ref) (TF x)    = ref .+.= x
  (.*=)  (TF ref) k         = ref .*= k
  (.*+=) (TF ref) (k, TF x) = ref .*+= (k, x)
-- | Extract the payload of a 'PMaybe' whose index is statically known to be
-- @'Just@.
fromPJust :: PMaybe f ('Just a) -> f a
fromPJust = \case
  PJust x -> x
-- | Pattern for the @'Just@ case of a 'TMaybe' that hides the 'TF' wrapper:
-- @'TJust' x@ stands for @'PJust' ('TF' x)@.
pattern TJust :: a -> TMaybe ('Just a)
pattern TJust x = PJust (TF x)
-- | Type-level combination of two optional types.  If both are present the
-- result is their strict pair; if only one is present it is returned
-- unchanged; if neither is present the result is @'Nothing@.
type family (a :: Maybe Type) :#? (b :: Maybe Type) :: Maybe Type where
  'Nothing :#? 'Nothing = 'Nothing
  'Nothing :#? 'Just b  = 'Just b
  'Just a  :#? 'Nothing = 'Just a
  'Just a  :#? 'Just b  = 'Just (a :# b)

infixr 1 :#?
-- | Pattern for an optional 'BVar' whose index was produced by ':#?'.
-- Matching splits it back into the two optional 'BVar's with
-- 'splitTupMaybe'; constructing merges two optional 'BVar's with 'tupMaybe'.
--
-- The matcher binds both sides with plain variables and therefore always
-- succeeds, so the pattern is declared COMPLETE, in line with the other
-- pattern synonyms in this module.
pattern (:#?)
    :: (AllConstrainedProd Backprop a, AllConstrainedProd Backprop b, Reifies s W, PureProd Maybe a, PureProd Maybe b)
    => PMaybe (BVar s) a
    -> PMaybe (BVar s) b
    -> PMaybe (BVar s) (a :#? b)
pattern x :#? y <- (splitTupMaybe (\(v :## u) -> (v, u)) -> (x, y))
  where
    (:#?) = tupMaybe (:##)
{-# COMPLETE (:#?) #-}
-- | Merge two optional values.  The combining function is only used when
-- both are present; if exactly one is present it is passed through
-- unchanged, and if neither is present the result is 'PNothing'.
tupMaybe
    :: forall f a b. ()
    => (forall a' b'. (a ~ 'Just a', b ~ 'Just b') => f a' -> f b' -> f (a' :# b'))
    -> PMaybe f a
    -> PMaybe f b
    -> PMaybe f (a :#? b)
tupMaybe _       PNothing  PNothing  = PNothing
tupMaybe _       PNothing  (PJust y) = PJust y
tupMaybe _       (PJust x) PNothing  = PJust x
tupMaybe combine (PJust x) (PJust y) = PJust (combine x y)
-- | Inverse of 'tupMaybe': recover the two optional halves of a merged
-- value.  Which halves exist is decided from the shapes of @a@ and @b@
-- ('pureShape'); the splitting function is only used when both are present.
splitTupMaybe
    :: forall f a b. (PureProd Maybe a, PureProd Maybe b)
    => (forall a' b'. (a ~ 'Just a', b ~ 'Just b') => f (a' :# b') -> (f a', f b'))
    -> PMaybe f (a :#? b)
    -> (PMaybe f a, PMaybe f b)
splitTupMaybe split = case pureShape @_ @a of
  PNothing -> case pureShape @_ @b of
    PNothing -> \PNothing   -> (PNothing, PNothing)
    PJust _  -> \(PJust y)  -> (PNothing, PJust y)
  PJust _  -> case pureShape @_ @b of
    PNothing -> \(PJust x)  -> (PJust x, PNothing)
    PJust _  -> \(PJust xy) -> bimap PJust PJust (split xy)
-- | 'zero', 'add' and 'one' act on the wrapped value when there is one.
-- Each method builds a product holding the corresponding 'Backprop'
-- operation and zips it against the argument(s).
instance PureProdC Maybe Backprop as => Backprop (TMaybe as) where
  zero = zipWithProd
           (\(Endo g) (TF x) -> TF (g x))
           (pureProdC @_ @Backprop (Endo zero))
  {-# INLINE zero #-}
  add xs = zipWithProd
             (\(TE g) (TF x :*: TF y) -> TF (g x y))
             (pureProdC @_ @Backprop (TE add))
         . zipProd xs
  {-# INLINE add #-}
  one = zipWithProd
          (\(Endo g) (TF x) -> TF (g x))
          (pureProdC @_ @Backprop (Endo one))
  {-# INLINE one #-}
-- | Pattern relating a 'BVar' of a 'TMaybe' to an optional 'BVar'.
-- Matching goes through '_mb'; constructing lifts 'PNothing' with 'auto'
-- and wraps a present 'BVar' with 'isoVar'.
pattern MaybB
    :: (Reifies s W, AllConstrainedProd Backprop a, PureProd Maybe a)
    => PMaybe (BVar s) a
    -> BVar s (TMaybe a)
pattern MaybB v <- (_mb -> v)
  where
    MaybB PNothing  = auto PNothing
    MaybB (PJust x) = isoVar (PJust . TF) (getTF . fromPJust) x
{-# COMPLETE MaybB #-}
-- | Turn a 'BVar' of a 'TMaybe' into an optional 'BVar'.  The shape of @a@
-- ('pureShape') decides the constructor; in the @'Just@ case the payload is
-- unwrapped with 'isoVar'.
_mb :: forall a s. (AllConstrainedProd Backprop a, PureProd Maybe a, Reifies s W)
    => BVar s (TMaybe a)
    -> PMaybe (BVar s) a
_mb v = case pureShape @_ @a of
  PNothing -> PNothing
  PJust _  -> PJust (isoVar (getTF . fromPJust) (PJust . TF) v)
{-# INLINE _mb #-}