{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# OPTIONS_GHC -Wno-redundant-constraints #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
-- | Type classes for values that can be added and scaled ('Linear'),
-- measured with dot products and norms ('Metric'), and updated in place
-- through a mutable reference ('LinearInPlace'), together with generic
-- helpers for deriving the first two.
module Numeric.Opto.Update (
-- * Addition and scaling
Linear(..), sumLinear, gAdd, gZeroL, gScale
-- * Dot products and norms
, Metric(..), gDot, gNorm_inf, gNorm_0, gNorm_1, gNorm_2, gQuadrance
-- * In-place updates
, LinearInPlace(..), sumLinearInPlace
-- * Scalar-type witness
, linearWit
) where
import Control.DeepSeq
import Control.Monad.Primitive
import Data.Coerce
import Data.Complex
import Data.Data
import Data.Finite
import Data.Foldable
import Data.Function
import Data.Maybe
import Data.Semigroup
import Data.Vinyl hiding ((:~:))
import GHC.Generics (Generic)
import GHC.TypeLits
import Generics.OneLiner
import Numeric.Opto.Ref
import Unsafe.Coerce
import qualified Data.Vector.Generic as VG
import qualified Data.Vector.Generic.Mutable.Sized as SVGM
import qualified Data.Vector.Generic.Sized as SVG
import qualified Numeric.LinearAlgebra as UH
import qualified Numeric.LinearAlgebra.Static as H
import qualified Numeric.LinearAlgebra.Static.Vector as H
-- | Values of type @a@ that can be added together and scaled by a scalar
-- of type @c@.
--
-- The functional dependency @a -> c@ means every @a@ has exactly one
-- scalar type.
--
-- Each method has a default implementation built from the generic helpers
-- 'gAdd', 'gZeroL' and 'gScale'; the defaults apply to any 'ADTRecord'
-- whose components all satisfy @'Linear' c@.
class Num c => Linear c a | a -> c where
    -- | Add two values.
    (.+.) :: a -> a -> a
    default (.+.) :: (ADTRecord a, Constraints a (Linear c)) => a -> a -> a
    x .+. y = gAdd @c x y

    -- | The zero value; 'sumLinear' uses it as the starting accumulator.
    zeroL :: a
    default zeroL :: (ADTRecord a, Constraints a (Linear c)) => a
    zeroL = gZeroL @c

    -- | Scale a value by a scalar.
    (.*) :: c -> a -> a
    default (.*) :: (ADTRecord a, Constraints a (Linear c)) => c -> a -> a
    k .* x = gScale k x

    infixl 6 .+.
    infixl 7 .*
-- | Add up every element of a container with '.+.', starting from 'zeroL'.
--
-- An empty container yields 'zeroL'.  The fold is a strict left fold
-- ('foldl'').
sumLinear :: (Linear c a, Foldable t) => t a -> a
sumLinear xs = foldl' (\acc x -> acc .+. x) zeroL xs
-- | Generic implementation of '.+.': applies each component's own '.+.'
-- through one-liner's 'binaryOp'.
--
-- The scalar type @c@ does not appear in the argument or result types, so
-- callers select it with a type application (@gAdd \@c@).
gAdd :: forall c a. (ADTRecord a, Constraints a (Linear c)) => a -> a -> a
gAdd x y = binaryOp @(Linear c) (\l r -> l .+. r) x y
-- | Generic implementation of 'zeroL': builds the value through
-- one-liner's 'nullaryOp', using each component's own 'zeroL'.
--
-- The scalar type @c@ does not appear in the result type, so callers
-- select it with a type application (@gZeroL \@c@).
gZeroL :: forall c a. (ADTRecord a, Constraints a (Linear c)) => a
gZeroL = nullaryOp @(Linear c) componentZero
  where
    -- The zero of one component; @c@ is the scalar type bound above.
    componentZero :: forall s. Linear c s => s
    componentZero = zeroL
-- | Generic implementation of '.*': scales every component by the same
-- scalar through one-liner's 'unaryOp'.
gScale :: forall c a. (ADTRecord a, Constraints a (Linear c)) => c -> a -> a
gScale k x = unaryOp @(Linear c) (\component -> k .* component) x
-- | 'Linear' values that additionally support a dot product and a family
-- of norms, all returning the scalar type @c@.
--
-- Every method has a default.  All but 'norm_2' are generic and delegate
-- to the matching @g*@ helper ('gDot', 'gNorm_inf', 'gNorm_0', 'gNorm_1',
-- 'gQuadrance'); 'norm_2' defaults to the square root of 'quadrance'.
class Linear c a => Metric c a where
    -- | Dot product.
    (<.>) :: a -> a -> c
    default (<.>) :: (ADT a, Constraints a (Metric c)) => a -> a -> c
    x <.> y = gDot x y

    -- | Infinity norm.  The generic default takes the maximum over the
    -- components' own 'norm_inf'.
    norm_inf :: a -> c
    default norm_inf :: (ADT a, Constraints a (Metric c), Ord c) => a -> c
    norm_inf x = gNorm_inf x

    -- | \"Zero norm\".  The generic default sums the components' own
    -- 'norm_0'.
    norm_0 :: a -> c
    default norm_0 :: (ADT a, Constraints a (Metric c)) => a -> c
    norm_0 x = gNorm_0 x

    -- | One norm.  The generic default sums the components' own 'norm_1'.
    norm_1 :: a -> c
    default norm_1 :: (ADT a, Constraints a (Metric c)) => a -> c
    norm_1 x = gNorm_1 x

    -- | Two norm.  Defaults to the square root of 'quadrance'.
    norm_2 :: a -> c
    default norm_2 :: Floating c => a -> c
    norm_2 x = sqrt (quadrance x)

    -- | Squared two norm.  The generic default sums the components' own
    -- 'quadrance'.
    quadrance :: a -> c
    default quadrance :: (ADT a, Constraints a (Metric c)) => a -> c
    quadrance x = gQuadrance x

    infixl 7 <.>
-- | Generic implementation of '<.>': pairs up the components of both
-- arguments with one-liner's 'mzipWith', takes each pair's own dot
-- product, and adds the results in the 'Sum' monoid.
gDot :: forall c a. (ADT a, Constraints a (Metric c), Num c) => a -> a -> c
gDot x y = getSum (mzipWith @(Metric c) (\l r -> Sum (l <.> r)) x y)
-- | Generic implementation of 'norm_inf': the largest absolute value of
-- the components' own 'norm_inf'.
--
-- The per-component results are combined in the @'Maybe' ('Max' c)@
-- monoid, so a value with no components produces 'Nothing' and this
-- function calls 'error'.
--
-- 'Maybe' is used instead of @Data.Semigroup.Option@, which is deprecated
-- and absent from newer versions of @base@; the two monoids combine values
-- identically.
gNorm_inf :: forall c a. (ADT a, Constraints a (Metric c), Ord c) => a -> c
gNorm_inf x =
    case gfoldMap @(Metric c) (\component -> Just (Max (abs (norm_inf component)))) x of
      Just (Max largest) -> largest
      Nothing            -> error "norm_inf: Divergent infinity norm"
-- | Generic implementation of 'norm_0': the sum of the components' own
-- 'norm_0'.
gNorm_0 :: forall c a. (ADT a, Constraints a (Metric c), Num c) => a -> c
gNorm_0 x = getSum (gfoldMap @(Metric c) (\component -> Sum (norm_0 component)) x)
-- | Generic implementation of 'norm_1': the sum of the components' own
-- 'norm_1'.
gNorm_1 :: forall c a. (ADT a, Constraints a (Metric c), Num c) => a -> c
gNorm_1 x = getSum (gfoldMap @(Metric c) (\component -> Sum (norm_1 component)) x)
-- | Generic implementation of 'norm_2': the square root of 'gQuadrance'.
gNorm_2 :: forall c a. (ADT a, Constraints a (Metric c), Floating c) => a -> c
gNorm_2 x = sqrt (gQuadrance x)
-- | Generic implementation of 'quadrance': the sum of the components' own
-- 'quadrance'.
gQuadrance :: forall c a. (ADT a, Constraints a (Metric c), Num c) => a -> c
gQuadrance x = getSum (gfoldMap @(Metric c) (\component -> Sum (quadrance component)) x)
-- | 'Linear' values that can be updated through a mutable reference
-- (@'Ref' m a@) in the monad @m@.
--
-- The default methods go through 'modifyRef'' and apply the corresponding
-- pure 'Linear' operation to the stored value.  Instances may override
-- them.
class (Mutable m a, Linear c a) => LinearInPlace m c a where
    -- | @ref .+.= x@ replaces the stored value @v@ with @v .+. x@.
    (.+.=) :: Ref m a -> a -> m ()
    ref .+.= x = modifyRef' ref (\current -> current .+. x)

    -- | @ref .*= k@ replaces the stored value @v@ with @k .* v@.
    (.*=) :: Ref m a -> c -> m ()
    ref .*= k = modifyRef' ref (\current -> k .* current)

    -- | @ref .*+= (k, x)@ replaces the stored value @v@ with
    -- @(k .* x) .+. v@.
    (.*+=) :: Ref m a -> (c, a) -> m ()
    ref .*+= (k, x) =
        let scaled = k .* x
        in  modifyRef' ref (\current -> scaled .+. current)

    infix 4 .+.=, .*=, .*+=
-- | Add every element of a container into the reference, one at a time and
-- in container order, using '.+.='.
sumLinearInPlace :: (LinearInPlace m c a, Foldable t) => Ref m a -> t a -> m ()
sumLinearInPlace ref xs = forM_ xs $ \x -> ref .+.= x
-- | Newtype wrapper that gives any 'Num' type 'Linear' and 'Metric'
-- instances based on its ordinary arithmetic.  It is the @via@ type for
-- the scalar instances derived in this module.
newtype LinearNum a = LinearNum { getLinearNum :: a }
  deriving
    ( Show, Eq, Ord
    , Enum, Bounded
    , Num, Real, Integral
    , Fractional, Floating, RealFrac, RealFloat
    , Functor, Foldable, Traversable
    , Generic, Typeable, Data
    )

-- | No methods are given, so the class's default implementation is used.
instance NFData a => NFData (LinearNum a)
-- | A wrapped number is its own scalar: addition is '+', zero is @0@ and
-- scaling is '*'.
instance Num a => Linear a (LinearNum a) where
    LinearNum x .+. LinearNum y = LinearNum (x + y)
    zeroL                       = LinearNum 0
    k .* LinearNum x            = LinearNum (k * x)
-- | Norms of a single wrapped number.
--
-- * '<.>' is plain multiplication.
-- * 'norm_inf', 'norm_1' and 'norm_2' are all 'abs'.
-- * 'norm_0' is @abs . signum@.
-- * 'quadrance' is the square.
instance Num a => Metric a (LinearNum a) where
    LinearNum x <.> LinearNum y = x * y
    norm_inf  (LinearNum x)     = abs x
    norm_0    (LinearNum x)     = abs (signum x)
    norm_1    (LinearNum x)     = abs x
    norm_2    (LinearNum x)     = abs x
    quadrance (LinearNum x)     = x ^ (2 :: Int)
-- 'Linear' instances for the standard numeric types.  Each type is its own
-- scalar, and the implementation is borrowed from 'LinearNum' via
-- DerivingVia.
deriving via (LinearNum Int) instance Linear Int Int
deriving via (LinearNum Integer) instance Linear Integer Integer
deriving via (LinearNum Rational) instance Linear Rational Rational
deriving via (LinearNum Float) instance Linear Float Float
deriving via (LinearNum Double) instance Linear Double Double
deriving via (LinearNum (Complex a)) instance RealFloat a => Linear (Complex a) (Complex a)
-- 'Metric' instances for the standard numeric types, borrowed from
-- 'LinearNum' via DerivingVia.
-- NOTE(review): for 'Complex' this makes '<.>' plain multiplication with no
-- conjugation of either argument -- confirm that is intended.
deriving via (LinearNum Int) instance Metric Int Int
deriving via (LinearNum Integer) instance Metric Integer Integer
deriving via (LinearNum Rational) instance Metric Rational Rational
deriving via (LinearNum Float) instance Metric Float Float
deriving via (LinearNum Double) instance Metric Double Double
deriving via (LinearNum (Complex a)) instance RealFloat a => Metric (Complex a) (Complex a)
-- 'LinearInPlace' instances for the standard numeric types.  No methods are
-- overridden, so every update uses the class defaults, which go through
-- 'modifyRef''.
instance Mutable m Int => LinearInPlace m Int Int
instance Mutable m Integer => LinearInPlace m Integer Integer
instance Mutable m Rational => LinearInPlace m Rational Rational
instance Mutable m Float => LinearInPlace m Float Float
instance Mutable m Double => LinearInPlace m Double Double
instance (Mutable m (Complex a), RealFloat a) => LinearInPlace m (Complex a) (Complex a)
-- | Sized vectors over their element type: addition and zero come from the
-- vector's 'Num' instance, and scaling multiplies every element by the
-- scalar.
instance (Num a, VG.Vector v a, KnownNat n) => Linear a (SVG.Vector v n a) where
    xs .+. ys = xs + ys
    zeroL     = 0
    (.*) k    = SVG.map (k *)
-- | Norms of a sized vector, computed element-wise.
--
-- * '<.>' sums the element-wise products.
-- * 'norm_inf' is the largest absolute value of any element (@0@ for an
--   empty vector).  The fold takes 'abs' of each /element/, not of the
--   running maximum, so negative entries are counted correctly.
-- * 'norm_0' sums @abs (signum x)@ over the elements, i.e. it counts the
--   non-zero entries.  This agrees with the scalar 'LinearNum' instance,
--   whose 'norm_0' is @abs . signum@.
-- * 'norm_1' sums the absolute values.
-- * 'quadrance' sums the squares.
-- * 'norm_2' uses the class default, @sqrt . quadrance@.
instance (Floating a, Ord a, VG.Vector v a, KnownNat n) => Metric a (SVG.Vector v n a) where
    xs <.> ys = SVG.sum (xs * ys)
    norm_inf  = SVG.foldl' (\largest x -> max largest (abs x)) 0
    norm_0    = SVG.sum . SVG.map (abs . signum)
    norm_1    = SVG.sum . abs
    quadrance = SVG.sum . (^ (2 :: Int))
-- | In-place updates for sized vectors.  These override the class defaults
-- and modify the referenced vector one index at a time with 'SVGM.modify'.
instance (PrimMonad m, PrimState m ~ s, Num a, mv ~ VG.Mutable v, VG.Vector v a, KnownNat n)
      => LinearInPlace m a (SVG.Vector v n a) where
    -- Add the matching element of @xs@ to every slot.
    ref .+.= xs =
        SVG.imapM_ (\i x -> SVGM.modify ref (\current -> current + x) i) xs

    -- Multiply every slot by @k@, visiting each index from 'finites'.
    ref .*= k =
        mapM_ (\i -> SVGM.modify ref (\current -> k * current) i) finites

    -- Add @k@ times the matching element of @xs@ to every slot.
    ref .*+= (k, xs) =
        SVG.imapM_ (\i x -> SVGM.modify ref (\current -> current + k * x) i) xs
-- | hmatrix static vectors over 'Double': addition and zero come from the
-- 'Num' instance, and scaling multiplies by a constant vector ('H.konst').
instance KnownNat n => Linear Double (H.R n) where
    xs .+. ys = xs + ys
    zeroL     = 0
    (.*) k    = (H.konst k *)
-- | Norms of an hmatrix static vector.  The dot product and the four norms
-- delegate to the hmatrix functions of the same name; 'quadrance' is
-- 'H.norm_2' raised to the power 2.
instance KnownNat n => Metric Double (H.R n) where
    xs <.> ys    = xs H.<.> ys
    norm_inf xs  = H.norm_Inf xs
    norm_0 xs    = H.norm_0 xs
    norm_1 xs    = H.norm_1 xs
    norm_2 xs    = H.norm_2 xs
    quadrance xs = H.norm_2 xs ** 2
-- | In-place updates for hmatrix static vectors.  The 'MR' reference is
-- unwrapped and each operation is forwarded to the reference it contains,
-- with the pure argument converted by 'H.rVec'.
instance (PrimMonad m, KnownNat n) => LinearInPlace m Double (H.R n) where
    (.+.=) (MR ref) x      = ref .+.= H.rVec x
    (.*=)  (MR ref) k      = ref .*= k
    (.*+=) (MR ref) (k, x) = ref .*+= (k, H.rVec x)
-- | hmatrix static matrices over 'Double': addition and zero come from the
-- 'Num' instance, and scaling multiplies by a constant matrix ('H.konst').
instance (KnownNat n, KnownNat m) => Linear Double (H.L n m) where
    xs .+. ys = xs + ys
    zeroL     = 0
    (.*) k    = (H.konst k *)
-- | Norms of an hmatrix static matrix, treating the matrix as a flat
-- collection of entries.
--
-- * '<.>' is the dot product of the two flattened matrices.
-- * 'norm_inf' is the largest entry of the element-wise absolute value.
-- * 'norm_0' sums the 'norm_0' of every row.
-- * 'norm_1' sums the absolute values of all entries.  'abs' is applied
--   before summing so that negative entries do not cancel positive ones.
-- * 'norm_2' is the two norm of the flattened matrix.
-- * 'quadrance' is 'norm_2' raised to the power 2.
instance (KnownNat n, KnownNat m) => Metric Double (H.L n m) where
    (<.>)     = (UH.<.>) `on` UH.flatten . H.extract
    norm_inf  = UH.maxElement . H.extract . abs
    norm_0    = sum . map norm_0 . H.toRows
    norm_1    = UH.sumElements . H.extract . abs
    norm_2    = UH.norm_2 . UH.flatten . H.extract
    quadrance = (**2) . norm_2
-- | In-place updates for hmatrix static matrices.  The 'ML' reference is
-- unwrapped and each operation is forwarded to the reference it contains,
-- with the pure argument converted by 'H.lVec'.
instance (PrimMonad m, KnownNat n, KnownNat k) => LinearInPlace m Double (H.L n k) where
    (.+.=) (ML ref) x      = ref .+.= H.lVec x
    (.*=)  (ML ref) c      = ref .*= c
    (.*+=) (ML ref) (c, x) = ref .*+= (c, H.lVec x)
-- 'Linear' instances for tuples of two to five components that share one
-- scalar type.  No methods are given, so each instance uses the class's
-- generic defaults.

instance (Linear c a, Linear c b)
      => Linear c (a, b)

instance (Linear c a, Linear c b, Linear c d)
      => Linear c (a, b, d)

instance (Linear c a, Linear c b, Linear c d, Linear c e)
      => Linear c (a, b, d, e)

instance (Linear c a, Linear c b, Linear c d, Linear c e, Linear c f)
      => Linear c (a, b, d, e, f)
-- 'Metric' instances for tuples of two to five components.  No methods are
-- given, so each instance uses the class defaults; @Ord c@ is required by
-- the default 'norm_inf' and @Floating c@ by the default 'norm_2'.

instance (Metric c a, Metric c b, Ord c, Floating c)
      => Metric c (a, b)

instance (Metric c a, Metric c b, Metric c d, Ord c, Floating c)
      => Metric c (a, b, d)

instance (Metric c a, Metric c b, Metric c d, Metric c e, Ord c, Floating c)
      => Metric c (a, b, d, e)

instance (Metric c a, Metric c b, Metric c d, Metric c e, Metric c f, Ord c, Floating c)
      => Metric c (a, b, d, e, f)
-- 'LinearInPlace' instances for tuples of two to five components.  No
-- methods are given, so every update uses the class defaults, which go
-- through 'modifyRef'' on the whole tuple.

instance (Mutable m (a, b), Linear c a, Linear c b)
      => LinearInPlace m c (a, b)

instance (Mutable m (a, b, d), Linear c a, Linear c b, Linear c d)
      => LinearInPlace m c (a, b, d)

instance (Mutable m (a, b, d, e), Linear c a, Linear c b, Linear c d, Linear c e)
      => LinearInPlace m c (a, b, d, e)

instance (Mutable m (a, b, d, e, f), Linear c a, Linear c b, Linear c d, Linear c e, Linear c f)
      => LinearInPlace m c (a, b, d, e, f)
-- | Base case for vinyl records: a record with exactly one field behaves
-- like that field.
instance Linear c (f a) => Linear c (Rec f '[a]) where
    (x :& RNil) .+. (y :& RNil) = (x .+. y) :& RNil
    zeroL                       = zeroL :& RNil
    k .* (x :& RNil)            = (k .* x) :& RNil
-- | Inductive case for vinyl records with two or more fields: operate on
-- the head field and recurse on the tail.
instance (Linear c (f a), Linear c (Rec f (b ': bs))) => Linear c (Rec f (a ': b ': bs)) where
    (x :& xs) .+. (y :& ys) = (x .+. y) :& (xs .+. ys)
    zeroL                   = zeroL :& zeroL
    k .* (x :& xs)          = (k .* x) :& (k .* xs)
-- | Base case for vinyl records: every measurement of a one-field record
-- is the measurement of that field.
instance Metric c (f a) => Metric c (Rec f '[a]) where
    (<.>) (x :& RNil) (y :& RNil) = x <.> y
    norm_inf  (x :& RNil)         = norm_inf x
    norm_0    (x :& RNil)         = norm_0 x
    norm_1    (x :& RNil)         = norm_1 x
    norm_2    (x :& RNil)         = norm_2 x
    quadrance (x :& RNil)         = quadrance x
-- | Inductive case for vinyl records with two or more fields: combine the
-- head field's measurement with the tail's.
--
-- * '<.>' adds the head fields' dot product to the tails' dot product.
-- * 'norm_inf' is the larger of the head's and the tail's 'norm_inf'.
-- * 'norm_0', 'norm_1' and 'quadrance' add the head's value to the tail's.
-- * 'norm_2' is the square root of 'quadrance'.
instance (Floating c, Ord c, Metric c (f a), Metric c (Rec f (b ': bs))) => Metric c (Rec f (a ': b ': bs)) where
    (x :& xs) <.> (y :& ys) = (x <.> y) + (xs <.> ys)
    norm_inf (x :& xs)      = norm_inf x `max` norm_inf xs
    norm_0 (x :& xs)        = norm_0 x + norm_0 xs
    norm_1 (x :& xs)        = norm_1 x + norm_1 xs
    norm_2                  = sqrt . quadrance
    quadrance (x :& xs)     = quadrance x + quadrance xs
-- | Base case for vinyl records: updating a one-field record updates the
-- reference held in its single 'RecRef'.
instance LinearInPlace m c (f a) => LinearInPlace m c (Rec f '[a]) where
    (.+.=) (RecRef ref :& RNil) (x :& RNil)      = ref .+.= x
    (.*=)  (RecRef ref :& RNil) k                = ref .*= k
    (.*+=) (RecRef ref :& RNil) (k, x :& RNil)   = ref .*+= (k, x)
-- | Inductive case for vinyl records with two or more fields: update the
-- head field's reference first, then recurse on the remaining references.
instance ( LinearInPlace m c (f a)
         , LinearInPlace m c (Rec f (b ': bs))
         )
      => LinearInPlace m c (Rec f (a ': b ': bs)) where
    (RecRef ref :& refs) .+.= (x :& xs) =
        (ref .+.= x) >> (refs .+.= xs)

    (RecRef ref :& refs) .*= k =
        (ref .*= k) >> (refs .*= k)

    (RecRef ref :& refs) .*+= (k, x :& xs) =
        (ref .*+= (k, x)) >> (refs .*+= (k, xs))
-- | Evidence that two scalar types used with the same @a@ are equal.
--
-- The 'Linear' class has the functional dependency @a -> c@, so @a@ cannot
-- have 'Linear' instances at two different scalar types; given both
-- @'Linear' c a@ and @'Linear' d a@, @c@ and @d@ must be the same type.
-- The type checker does not derive that equality from the two constraints,
-- so the proof is produced with 'unsafeCoerce'.
--
-- @a@ appears only in the constraints, so callers pick it with a type
-- application.
linearWit :: forall a c d. (Linear c a, Linear d a) => c :~: d
linearWit = unsafeCoerce (Refl :: c :~: c)