{-# LANGUAGE ApplicativeDo                   #-}
{-# LANGUAGE DataKinds                       #-}
{-# LANGUAGE FlexibleContexts                #-}
{-# LANGUAGE PartialTypeSignatures           #-}
{-# LANGUAGE RankNTypes                      #-}
{-# LANGUAGE ScopedTypeVariables             #-}
{-# LANGUAGE TypeApplications                #-}
{-# LANGUAGE TypeFamilies                    #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}
module Backprop.Learn.Test (
  
    Test
  , maxIxTest, rmseTest
  , squaredErrorTest, absErrorTest, totalSquaredErrorTest, squaredErrorTestV
  , crossEntropyTest, crossEntropyTest1
  , boolTest
  
  , lossTest, lmapTest
  
  , testModel, testModelStoch, testModelAll, testModelStochAll
  
  , testModelCov, testModelCorr
  , testModelStochCov, testModelStochCorr
  ) where
import           Backprop.Learn.Loss
import           Backprop.Learn.Model
import           Control.Monad.Primitive
import           Data.Bifunctor
import           Data.Bitraversable
import           Data.Function
import           Data.Profunctor
import           Data.Proxy
import           GHC.TypeNats
import           Numeric.Backprop
import qualified Control.Foldl                as L
import qualified Numeric.LinearAlgebra        as HU
import qualified Numeric.LinearAlgebra.Static as H
import qualified System.Random.MWC            as MWC
-- | A test (metric) that compares two values of the same type and produces
-- a 'Double' score.
--
-- Functions in this module that run a model (e.g. 'testModel') pass the
-- expected/target value as the first argument and the model's output as
-- the second.
type Test a = a -> a -> Double
-- | Use a 'Loss' as a 'Test' by running it forwards with 'evalBP'.
--
-- The first argument is handed to the loss unchanged; the second is the
-- argument that 'evalBP' lifts into a 'BVar'. Only the resulting value is
-- kept.
lossTest :: Loss a -> Test a
lossTest l targ res = evalBP (l targ) res
-- | Score @1@ when both values round to the same whole number and @0@
-- otherwise.
--
-- Rounding uses 'round', so a value exactly halfway between two integers
-- goes to the even one.
--
-- Values are rounded to 'Integer' rather than 'Int'. Converting an
-- out-of-range value (very large magnitude) to a fixed-width 'Int' is not
-- well-defined, and could make two clearly different values compare equal.
-- For values inside the 'Int' range the result is unchanged.
boolTest :: forall a. RealFrac a => Test a
boolTest x y
    | ri x == ri y = 1
    | otherwise    = 0
  where
    ri :: a -> Integer
    ri = round
-- | Score @1@ when both vectors reach their maximum at the same index (as
-- reported by 'HU.maxIndex' on the underlying vectors) and @0@ otherwise.
--
-- NOTE(review): behaviour for @n ~ 0@ depends on 'HU.maxIndex' of an empty
-- vector; presumably an error, so confirm before relying on it.
maxIxTest :: KnownNat n => Test (H.R n)
maxIxTest x y = if sameArgmax then 1 else 0
  where
    sameArgmax = HU.maxIndex (H.extract x) == HU.maxIndex (H.extract y)
-- | Root-mean-square error between two vectors.
--
-- Computed as the Euclidean norm of the difference divided by @sqrt n@,
-- which equals the square root of the mean squared componentwise
-- difference. For @n ~ 0@ this is @0 / 0@.
rmseTest :: forall n. KnownNat n => Test (H.R n)
rmseTest x y = H.norm_2 diff / sqrt dim
  where
    diff = x - y
    dim  = fromIntegral (natVal (Proxy :: Proxy n))
-- | Squared difference of two scalars.
--
-- The subtraction is done in the original type @a@; the difference is then
-- converted with 'realToFrac' and squared as a 'Double'.
squaredErrorTest :: Real a => Test a
squaredErrorTest x y = square (realToFrac (x - y))
  where
    square d = d * d
-- | Absolute difference of two scalars, computed in the original type @a@
-- and then converted to 'Double' with 'realToFrac'.
absErrorTest :: Real a => Test a
absErrorTest x y = realToFrac (abs (x - y))
-- | Sum of squared differences between two containers, converted to
-- 'Double'.
--
-- The two containers are combined with their 'Applicative' instance
-- (@sqDiff <$> x <*> y@), so "which pairs are compared" is decided by that
-- instance. For zip-like applicatives ('Control.Applicative.ZipList', and
-- presumably fixed-size vector types — confirm for the caller's type) this
-- is elementwise. For ordinary lists it compares every element of @x@ with
-- every element of @y@, which is almost certainly not what is wanted.
totalSquaredErrorTest :: (Applicative t, Foldable t, Real a) => Test (t a)
totalSquaredErrorTest x y = realToFrac . sum $ sqDiff <$> x <*> y
  where
    sqDiff a b = (a - b) ^ (2 :: Int)
-- | Squared Euclidean distance between two vectors: the dot product of
-- their difference with itself.
squaredErrorTestV :: KnownNat n => Test (H.R n)
squaredErrorTestV x y = let diff = x - y in H.dot diff diff
-- | Cross-entropy of a predicted vector against a target vector:
-- the negated dot product of @log res@ with @targ@.
--
-- The first argument is the target, the second the prediction.
--
-- NOTE(review): if a prediction component is exactly 0 its @log@ is
-- -Infinity. That presumably makes the whole dot product NaN even where the
-- matching target component is 0 — confirm whether callers need a guard.
crossEntropyTest :: KnownNat n => Test (H.R n)
crossEntropyTest targ res = negate (log res H.<.> targ)
-- | Binary cross-entropy of a predicted probability against a target
-- probability:
--
-- @-(targ * log res + (1 - targ) * log (1 - res))@
--
-- The first argument is the target, the second the prediction.
--
-- A term whose coefficient is exactly zero contributes 0 (the usual
-- @0 * log 0 = 0@ convention). Without this, a perfect prediction such as
-- @crossEntropyTest1 1 1@ would evaluate @log 0 * 0@, which is NaN in IEEE
-- arithmetic. With the convention it scores zero, and a mean taken over
-- many samples is not poisoned.
crossEntropyTest1 :: Test Double
crossEntropyTest1 targ res = -(xlogy targ res + xlogy (1 - targ) (1 - res))
  where
    -- @a * log b@, except that a zero coefficient yields 0 even when
    -- @log b@ is infinite.
    xlogy :: Double -> Double -> Double
    xlogy 0 _ = 0
    xlogy a b = log b * a
-- | Adapt a 'Test' on @b@ into a 'Test' on @a@ by applying the same
-- projection to both arguments before comparing them.
lmapTest
    :: (a -> b)
    -> Test b
    -> Test a
lmapTest f t = t `on` f
-- | Run a stateless, deterministic model on one input and score its output
-- against the expected value.
--
-- The 'Test' receives the expected value first and the model's output
-- second.
testModel
    :: Test b
    -> Model p 'Nothing a b
    -> TMaybe p
    -> a
    -> b
    -> Double
testModel t f mp x y = t y prediction
  where
    prediction = runModelStateless f mp x
-- | Like 'testModel', but runs the model in its stochastic mode, drawing
-- randomness from the supplied MWC generator.
--
-- The 'Test' receives the expected value first and the sampled output
-- second.
testModelStoch
    :: PrimMonad m
    => Test b
    -> Model p 'Nothing a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p
    -> a
    -> b
    -> m Double
testModelStoch t f g mp x y = fmap (t y) (runModelStochStateless f g mp x)
-- | Single-pass covariance of a stream of pairs, normalised by the number of
-- pairs: @mean (x * y) - mean x * mean y@, built from running sums.
--
-- NOTE: this raw-sums formula can lose precision to cancellation when the
-- values are large relative to their spread.
cov :: Fractional a => L.Fold (a, a) a
cov = covariance
    <$> lmap fst           L.sum
    <*> lmap snd           L.sum
    <*> lmap (uncurry (*)) L.sum
    <*> (fromIntegral <$> L.length)
  where
    covariance sx sy sxy n = sxy / n - (sx * sy) / n / n
-- | Single-pass Pearson correlation of a stream of pairs.
--
-- The covariance (normalised by the number of pairs) is divided by the
-- product of the two population standard deviations. Everything is built
-- from running sums of @x@, @x^2@, @y@, @y^2@ and @x*y@.
--
-- NOTE: if either component is constant the denominator is zero, and this
-- is not guarded. The raw-sums formula can also lose precision to
-- cancellation for large values with small spread.
corr :: Floating a => L.Fold (a, a) a
corr = pearson
    <$> lmap fst              L.sum
    <*> lmap ((** 2) . fst)   L.sum
    <*> lmap snd              L.sum
    <*> lmap ((** 2) . snd)   L.sum
    <*> lmap (uncurry (*))    L.sum
    <*> (fromIntegral <$> L.length)
  where
    pearson sx sxx sy syy sxy n =
        (sxy / n - (sx * sy) / n / n)
          / sqrt (sxx / n - (sx / n) ** 2)
          / sqrt (syy / n - (sy / n) ** 2)
-- | Covariance (normalised by the number of samples) between a stateless
-- model's outputs and the targets, over a collection of
-- @(input, target)@ pairs.
--
-- Each input is replaced by the model's output before the pairs are
-- folded.
testModelCov
    :: (Foldable t, Fractional b)
    => Model p 'Nothing a b
    -> TMaybe p
    -> t (a, b)
    -> b
testModelCov f p = L.fold (lmap (first (runModelStateless f p)) cov)
-- | Pearson correlation between a stateless model's outputs and the targets,
-- over a collection of @(input, target)@ pairs.
--
-- Each input is replaced by the model's output before the pairs are
-- folded.
testModelCorr
    :: (Foldable t, Floating b)
    => Model p 'Nothing a b
    -> TMaybe p
    -> t (a, b)
    -> b
testModelCorr f p = L.fold (lmap (first (runModelStateless f p)) corr)
-- | Mean 'Test' score of a stateless model over a collection of
-- @(input, target)@ pairs.
--
-- Each pair is scored with 'testModel', and the scores are averaged with
-- 'L.mean' in a single pass.
testModelAll
    :: Foldable t
    => Test b
    -> Model p 'Nothing a b
    -> TMaybe p
    -> t (a, b)
    -> Double
testModelAll t f p = L.fold (lmap scorePair L.mean)
  where
    scorePair = uncurry (testModel t f p)
-- | Mean 'Test' score of a model run in stochastic mode over a collection of
-- @(input, target)@ pairs.
--
-- Each pair is scored with 'testModelStoch' using the shared generator, in
-- the order the 'Foldable' presents them. The scores are averaged with a
-- monadic single-pass fold.
testModelStochAll
    :: (Foldable t, PrimMonad m)
    => Test b
    -> Model p 'Nothing a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p
    -> t (a, b)
    -> m Double
testModelStochAll t f g p = L.foldM (L.premapM scorePair meanM)
  where
    scorePair = uncurry (testModelStoch t f g p)
    meanM     = L.generalize L.mean
-- | Covariance (normalised by the number of samples) between a model's
-- stochastic outputs and the targets, over a collection of
-- @(input, target)@ pairs.
--
-- The first component of each pair is replaced by a sampled model output
-- (drawn from the shared generator); the target is passed through
-- unchanged.
testModelStochCov
    :: (Foldable t, PrimMonad m, Fractional b)
    => Model p 'Nothing a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p
    -> t (a, b)
    -> m b
testModelStochCov f g p = L.foldM (L.premapM predictFst (L.generalize cov))
  where
    predictFst = bitraverse (runModelStochStateless f g p) pure
-- | Pearson correlation between a model's stochastic outputs and the
-- targets, over a collection of @(input, target)@ pairs.
--
-- The first component of each pair is replaced by a sampled model output
-- (drawn from the shared generator); the target is passed through
-- unchanged.
testModelStochCorr
    :: (Foldable t, PrimMonad m, Floating b)
    => Model p 'Nothing a b
    -> MWC.Gen (PrimState m)
    -> TMaybe p
    -> t (a, b)
    -> m b
testModelStochCorr f g p = L.foldM (L.premapM predictFst (L.generalize corr))
  where
    predictFst = bitraverse (runModelStochStateless f g p) pure