{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Numeric.Opto.Core (
Diff, Grad, Opto(..)
, mapSample, mapOpto
, fromCopying, fromStateless
, pureGrad
, nonSampling, pureNonSampling
) where
import Data.Kind
import Data.Type.Equality
import Numeric.Opto.Ref
import Numeric.Opto.Update
-- | Marks a value that is meant as a change to a parameter of type @a@
-- (for example, the second component returned by 'oUpdate'), as opposed
-- to a parameter value itself.
--
-- This is a plain synonym for @a@. It documents intent only and adds no
-- type safety: a @'Diff' a@ and an @a@ are fully interchangeable.
type Diff a = a
-- | A (possibly effectful) gradient-style function.
--
-- Given a sample of type @r@ and a value of type @a@, it produces a
-- @'Diff' a@ inside the monad @m@.
--
-- Helpers in this module build one from simpler functions:
--
-- * 'pureGrad' from a pure function.
-- * 'nonSampling' from a function that ignores the sample.
-- * 'pureNonSampling' from a function that is both pure and ignores the
--   sample.
type Grad m r a = r -> a -> m (Diff a)
-- | An optimizer over values of type @a@ that consumes samples of type
-- @r@ and runs its updates in the monad @m@.
--
-- The state type @s@ and the extra type @c@ are existential. They are
-- chosen when the optimizer is built (e.g. by 'fromCopying' or
-- 'fromStateless') and do not appear in the resulting 'Opto' type.
-- Pattern matching on 'MkOpto' brings the @'LinearInPlace' m c a@ and
-- @'Mutable' m s@ constraints back into scope.
--
-- NOTE: the order of the constructor's @forall s m r a c@ is relied upon
-- by explicit type applications (@MkOpto \@s \@n \@r \@a \@d@). Do not
-- reorder it.
data Opto :: (Type -> Type) -> Type -> Type -> Type where
    MkOpto :: forall s m r a c. (LinearInPlace m c a, Mutable m s)
           => { oInit   :: !s
                -- ^ Initial value of the optimizer's internal state.
              , oUpdate :: !( Ref m s
                           -> r
                           -> a
                           -> m (c, Diff a)
                           )
                -- ^ One update step. It receives:
                --
                -- * a mutable reference to the state, which it may write
                --   through (see 'fromCopying');
                -- * a sample;
                -- * the current value of type @a@.
                --
                -- It returns a @c@ paired with a @'Diff' a@.
                -- NOTE(review): how the caller combines the @c@ with the
                -- diff is governed by 'LinearInPlace', not by this type.
                -- Presumably the diff is scaled by @c@ and applied in
                -- place; confirm.
              }
           -> Opto m r a
-- | Move an 'Opto' from the monad @m@ into the monad @n@.
--
-- Arguments:
--
-- * A natural transformation that lifts each action of the original
--   update step from @m@ into @n@.
-- * A conversion from @n@-side references back to @m@-side references,
--   so the original step can still read and write its state.
--
-- The state type and its initial value are carried over unchanged.
--
-- The wrapped optimizer's @c@-like type @d@ is existential. A type
-- equality from 'linearWit' relates it to the caller's @c@ before the new
-- constructor is built. 'reMutable' is used to make the state usable
-- under @n@; presumably it provides the @'Mutable' n s@ instance that
-- 'MkOpto' requires.
mapOpto
    :: forall m n r a c. (LinearInPlace n c a)
    => (forall x. m x -> n x)
    -> (forall x. Ref n x -> Ref m x)
    -> Opto m r a
    -> Opto n r a
mapOpto hoist pullRef (MkOpto s0 (step :: Ref m s -> r -> a -> m (d, a))) =
    reMutable @m @n @s hoist $
      case linearWit @a @c @d of
        Refl ->
          let hoisted ref sample x = hoist (step (pullRef ref) sample x)
          in  MkOpto @s @n @r @a @d s0 hoisted
-- | Adapt an optimizer to a new sample type. Every incoming sample is
-- converted with the given function before it is handed to the original
-- update step.
--
-- The state, its initial value, and the update logic are otherwise
-- untouched.
mapSample
    :: (r -> s)
    -> Opto m s a
    -> Opto m r a
mapSample convert (MkOpto s0 step) =
    MkOpto s0 (\ref -> step ref . convert)
-- | Build an 'Opto' from an update function that works on an immutable
-- state value.
--
-- Each step does the following:
--
-- 1. Read the current state out of its mutable reference with
--    'freezeRef'.
-- 2. Pass it, together with the sample and the current value, to the
--    update function.
-- 3. Write the state the function returns back into the reference with
--    'copyRef'.
fromCopying
    :: (LinearInPlace m c a, Mutable m s)
    => s
       -- ^ initial state
    -> (r -> a -> s -> m (c, Diff a, s))
       -- ^ update: sample, current value, old state
       -- -> (the @c@ value, the diff, new state)
    -> Opto m r a
fromCopying s0 update = MkOpto
    { oInit   = s0
    , oUpdate = \ref sample x -> do
        sOld <- freezeRef ref
        (c, dx, sNew) <- update sample x sOld
        copyRef ref sNew
        pure (c, dx)
    }
-- | Build an 'Opto' that carries no state of its own.
--
-- The state is @()@ and the reference to it is never inspected. Every
-- step simply runs the given function on the sample and the current
-- value.
fromStateless
    :: (LinearInPlace m c a)
    => (r -> a -> m (c, Diff a))
    -> Opto m r a
fromStateless update = MkOpto
    { oInit   = ()
    , oUpdate = const update
    }
-- | Lift a pure function of the sample and the value into a 'Grad' by
-- wrapping its result with 'pure'.
pureGrad
    :: Applicative m
    => (r -> a -> Diff a)
    -> Grad m r a
pureGrad f sample x = pure (f sample x)
-- | Turn a function that does not depend on the sample into a 'Grad'. The
-- resulting 'Grad' discards its sample argument.
nonSampling
    :: (a -> m (Diff a))
    -> Grad m r a
nonSampling = const
-- | Turn a pure function that does not depend on the sample into a
-- 'Grad'. The resulting 'Grad' discards its sample argument and wraps the
-- function's result with 'pure'.
pureNonSampling
    :: Applicative m
    => (a -> Diff a)
    -> Grad m r a
pureNonSampling f = const (pure . f)