{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Diff.Internal.Generics (
SumDiff(..)
, CtrDiff(..)
, diffSOP
, diffSOP'
, patchSOP
, undiffSOP
, showSOP
, showProd
) where
import Control.Monad
import Data.Bifunctor
import Data.Diff.Pretty
import Data.Kind
import Data.Maybe
import Data.Type.Combinator
import Data.Type.Combinator.Util
import Data.Type.Conjunction
import Data.Type.Equality
import Data.Type.Index
import Data.Type.Length
import Data.Type.Product
import Data.Type.Sum
import Type.Class.Higher
import Type.Class.Known
import Type.Class.Witness
import Type.Reflection
import qualified Data.Type.Product as TCP
import qualified Data.Type.Sum as TCS
import qualified Generics.SOP as SOP
import qualified Text.PrettyPrint.ANSI.Leijen as PP
-- | A diff between two values of a sum over the alternatives @as@.
--
-- The single constructor pairs an 'Index' into @as@ with a 'CtrDiff' for
-- the alternative at that index; the element @a@ itself is existentially
-- hidden.  In this module the index is always the alternative occupied
-- by the /original/ value.
data SumDiff :: (k -> Type) -> (k -> Type) -> [k] -> Type where
    SD :: (Index as :&: CtrDiff f g as) a -> SumDiff f g as
-- | What happened to the alternative @a@ of a sum over @as@.
--
-- @f@ is the payload functor of the sum being diffed and @g@ is the
-- functor holding a diff of one payload.
data CtrDiff :: (k -> Type) -> (k -> Type) -> [k] -> k -> Type where
    -- | The alternative is kept; only a payload diff is carried.
    CDEdit :: g a -> CtrDiff f g as a
    -- | The value moves to another index that has the same element @a@;
    -- carries that new index together with a payload diff.
    CDName :: (Index as :&: g) a -> CtrDiff f g as a
    -- | The value moves to an alternative with a (possibly) different
    -- element @b@; carries the original payload and the new index with
    -- its complete payload.
    CDDiff :: f a -> (Index as :&: f) b -> CtrDiff f g as a
-- | Diff two sums by comparing the /indices/ of their occupied
-- alternatives with 'testEquality'.
--
-- When both values sit at the same index the supplied function is used
-- to diff the two payloads and the result is wrapped in 'CDEdit';
-- otherwise the result is a 'CDDiff' holding the original payload and
-- the new index with its payload.  This variant never produces 'CDName'.
sumDiff'
    :: forall f g as. ()
    => (forall a. (Index as :&: f :&: f) a -> g a)
    -> Sum f as
    -> Sum f as
    -> SumDiff f g as
sumDiff' f old new = case (sumIx old, sumIx new) of
    (Some (i :&: x), Some (j :&: y)) -> case testEquality i j of
      Nothing   -> SD (i :&: CDDiff x (j :&: y))
      Just Refl -> SD (i :&: CDEdit (f (i :&: x :&: y)))
-- | Diff two sums by comparing the /types/ of their occupied
-- alternatives (via 'TypeRep') rather than just their indices.
--
-- * Same type and same index: the payload diff is wrapped in 'CDEdit'.
-- * Same type but different index: the payload diff is wrapped in
--   'CDName' together with the new index.
-- * Different types: 'CDDiff' with the original payload and the new
--   index with its payload.
sumDiff
    :: forall f g as. Every Typeable as
    => (forall a. Typeable a => ((Index as :&: f) :&: (Index as :&: f)) a -> g a)
    -> Sum f as
    -> Sum f as
    -> SumDiff f g as
sumDiff f old new = case (sumIx old, sumIx new) of
    (Some (i :&: x), Some (j :&: y)) ->
      -- Bring the 'Typeable' instances for both element types into scope.
      every @_ @Typeable i //
      every @_ @Typeable j //
        case testEquality (repOf i) (repOf j) of
          Nothing   -> SD (i :&: CDDiff x (j :&: y))
          Just Refl ->
            let edit = f ((i :&: x) :&: (j :&: y))
            in  if i == j
                  then SD (i :&: CDEdit edit)
                  else SD (i :&: CDName (j :&: edit))
  where
    -- 'TypeRep' of the element a proxy-like value points at.
    repOf :: Typeable a => p a -> TypeRep a
    repOf _ = typeRep
-- | Diff two sum-of-products values using the index-based 'sumDiff''.
--
-- The callback receives the constructor index, the field index and the
-- old and new field values, and produces the diff for that field.  It is
-- only invoked when both values use the same constructor index.
diffSOP'
    :: forall f ass. ()
    => (forall as a. Index ass as -> Index as a -> a -> a -> f a)
    -> Sum Tuple ass
    -> Sum Tuple ass
    -> SumDiff Tuple (Prod f) ass
diffSOP' f = sumDiff' zipFields
  where
    -- Diff the two field tuples of one constructor position by position.
    zipFields :: (Index ass :&: Tuple :&: Tuple) as -> Prod f as
    zipFields (ctr :&: old :&: new) =
        izipProdWith (\fld (I x) (I y) -> f ctr fld x y) old new
-- | Diff two sum-of-products values using the type-based 'sumDiff'.
--
-- The callback receives the constructor index, the field index and the
-- old and new field values, and produces the diff for that field.  The
-- constructor index handed to the callback is always the one of the
-- /original/ value; the index of the new value is ignored here.
diffSOP
    :: forall f ass. (Every Typeable ass)
    => (forall as a. Index ass as -> Index as a -> a -> a -> f a)
    -> Sum Tuple ass
    -> Sum Tuple ass
    -> SumDiff Tuple (Prod f) ass
diffSOP f = sumDiff zipFields
  where
    -- Diff the two field tuples position by position.
    zipFields
        :: ((Index ass :&: Tuple) :&: (Index ass :&: Tuple)) as
        -> Prod f as
    zipFields ((ctr :&: old) :&: (_ :&: new)) =
        izipProdWith (\fld (I x) (I y) -> f ctr fld x y) old new
-- | Apply a 'SumDiff' to a sum-of-products value.
--
-- The first callback applies a single field diff to a field value and
-- may fail.  The second callback compares two field values; it is used
-- only for 'CDDiff', where every field of the value being patched must
-- compare equal (according to the callback) to the recorded original
-- field before the replacement value is produced.
--
-- The result is whatever the 'Maybe' computation yields: extracting the
-- expected alternative with 'TCS.index', patching every field, and (for
-- 'CDDiff') the equality guard must all succeed.
patchSOP
    :: forall f ass. ()
    => (forall as a. Index ass as -> Index as a -> f a -> a -> Maybe a)
    -> (forall as a. Index ass as -> Index as a -> a -> a -> Bool)
    -> SumDiff Tuple (Prod f) ass
    -> Sum Tuple ass
    -> Maybe (Sum Tuple ass)
patchSOP f g (SD (i :&: cd)) = case cd of
    CDEdit es            -> fmap (injectSum i) . applyEdits i es
    CDName (j :&: es)    -> fmap (injectSum j) . applyEdits i es
    CDDiff xs (j :&: ys) -> \xss -> do
      xs' <- TCS.index i xss
      izipProdWithA_ (\k (I x') (I x) -> guard (g i k x' x)) xs' xs
      return (injectSum j ys)
  where
    -- Pull the fields at the given constructor out of the sum and patch
    -- each of them with its corresponding field diff.
    applyEdits
        :: Index ass as
        -> Prod f as
        -> Sum Tuple ass
        -> Maybe (Tuple as)
    applyEdits ctr es xss = do
        xs <- TCS.index ctr xss
        itraverse1 (\k (e :&: I x) -> fmap I (f ctr k e x)) (zipProd es xs)
-- | Recover the pair of sum-of-products values that a 'SumDiff'
-- describes.
--
-- The callback splits a single field diff into its two field values.
-- For 'CDEdit' both results are injected at the same constructor index;
-- for 'CDName' the second result is injected at the new index; for
-- 'CDDiff' the two stored payloads are injected directly.
undiffSOP
    :: (forall as a. Index ass as -> Index as a -> f a -> (a, a))
    -> SumDiff Tuple (Prod f) ass
    -> (Sum Tuple :&: Sum Tuple) ass
undiffSOP f (SD (i :&: cd)) = case cd of
    CDEdit es            -> rebuild i es
    CDName (j :&: es)    -> rebuild j es
    CDDiff xs (j :&: ys) -> injectSum i xs :&: injectSum j ys
  where
    -- Split every field diff (always interpreted at the original index
    -- @i@), then inject the first components at @i@ and the second
    -- components at the given target index.
    rebuild target es =
        (injectSum i .&. injectSum target)
      . unzipProd
      . imap1 (\k e -> case f i k e of
                         -- lazy match: do not force the callback's result
                         ~(x, y) -> I x :&: I y
              )
      $ es
-- | Pretty-print a 'SumDiff' of a sum-of-products value.
--
-- The callback renders a single field diff, returning 'Nothing' for a
-- field it has nothing to show for.  Constructor and field names are
-- looked up in the supplied 'SOP.DatatypeInfo'.
--
-- * 'CDEdit': rendered exactly like 'showProd' for that constructor.
-- * 'CDName': the old constructor name is shown with 'ppDel' and the new
--   one with 'ppAdd', followed by the field diffs.  If either
--   constructor is a record, each shown field is labelled
--   @old ~> new@; otherwise the fields are listed with 'docList'.
-- * 'CDDiff': the old constructor name is shown with 'ppDel' and the new
--   one with 'ppAdd', followed by the new fields rendered with 'show'.
showSOP
    :: forall f ass. (Every (Every Show) ass)
    => (forall as a. Index ass as -> Index as a -> f a -> Maybe PP.Doc)
    -> SOP.DatatypeInfo ass
    -> SumDiff Tuple (Prod f) ass
    -> PP.Doc
showSOP f di (SD ((i :: Index ass as) :&: cd)) = case cd of
    -- Same constructor on both sides: this is precisely what 'showProd'
    -- renders, so delegate instead of duplicating its logic.
    CDEdit es -> showProd (f i) iCtr es
    CDName ((j :: Index ass bs) :&: es) ->
      let ds :: [Maybe PP.Doc]
          ds = ifoldMap1 (\k -> (:[]) . f i k) es
          jCtr :: SOP.ConstructorInfo bs
          jCtr = ctrInfo j
          (jNm, jFs) = ctrDocs jCtr
          body :: PP.Doc
          body
            -- 'ds' has one entry per field, so this is the nullary case.
            | null ds = mempty
            | isRecord jCtr || isRecord iCtr =
                recDoc $ catMaybes (zipWith3 editsCR iFs jFs ds)
            | otherwise = docList ds
      in  ppDel iNm PP.<$> (ppAdd jNm PP.<+> PP.align body)
    CDDiff _ ((j :: Index ass bs) :&: ys) ->
      let jCtr :: SOP.ConstructorInfo bs
          jCtr = ctrInfo j
          (jNm, jFs) = ctrDocs jCtr
          -- Render every new field with its 'Show' instance, obtained
          -- from the @Every (Every Show) ass@ constraint.
          ds :: [PP.Doc]
          ds = ifoldMap1 (\k (I x) -> [PP.text (show x)]
                                        \\ every @_ @Show k
                                        \\ every @_ @(Every Show) j
                         )
                         ys
          body :: PP.Doc
          body
            | isRecord jCtr = recDoc $ zipWith recField jFs ds
            | otherwise     = PP.vcat ds
      in  ppDel iNm PP.<$> (ppAdd jNm PP.<+> PP.align body)
  where
    -- Metadata of the constructor at the given index.
    ctrInfo :: Index ass bs -> SOP.ConstructorInfo bs
    ctrInfo j = TCP.index j . sopProd . SOP.constructorInfo $ di
    -- Metadata of the original value's constructor.
    iCtr :: SOP.ConstructorInfo as
    iCtr = ctrInfo i
    -- Display name and field names of a constructor (see 'ctrNames').
    ctrDocs :: SOP.ConstructorInfo xs -> (PP.Doc, [SOP.FieldName])
    ctrDocs = bimap PP.text (TCP.toList SOP.fieldName) . ctrNames
    (iNm, iFs) = ctrDocs iCtr
    -- Whether the constructor was declared with record syntax.
    isRecord :: SOP.ConstructorInfo xs -> Bool
    isRecord = \case
      SOP.Record{} -> True
      _            -> False
    -- Label a field diff with its old and new field name.
    editsCR :: SOP.FieldName -> SOP.FieldName -> Maybe PP.Doc -> Maybe PP.Doc
    editsCR f1 f2 = fmap $ \e -> PP.text f1
                          PP.<+> PP.yellow (PP.text "~>")
                          PP.<+> PP.text f2
                          PP.<+> e
-- | Pretty-print the field diffs of a single constructor.
--
-- The callback renders one field diff, returning 'Nothing' for a field
-- it has nothing to show for.  If no field produces a document the
-- result is 'ppNoChange'.  Otherwise the constructor name is followed by
-- the fields: for record constructors only the fields that produced a
-- document are shown, labelled with their field name; for all other
-- constructors every field is listed via 'docList'.
showProd
    :: (forall a. Index as a -> f a -> Maybe PP.Doc)
    -> SOP.ConstructorInfo as
    -> Prod f as
    -> PP.Doc
showProd f ci es =
    if null (catMaybes ds)
      then ppNoChange
      else ctrDoc PP.<+> PP.align body
  where
    -- One entry per field, in field order.
    ds = ifoldMap1 (\i e -> [f i e]) es
    (ctrDoc, fieldNames) =
        bimap PP.text (TCP.toList SOP.fieldName) (ctrNames ci)
    body :: PP.Doc
    body = case ci of
      SOP.Record{} ->
        recDoc (catMaybes (zipWith (\fn -> fmap (recField fn)) fieldNames ds))
      _ -> docList ds
-- | Lay out field documents as a brace-enclosed, comma-separated record.
recDoc :: [PP.Doc] -> PP.Doc
recDoc fields = PP.encloseSep PP.lbrace PP.rbrace PP.comma fields
-- | Render a single record field as @name = document@.
recField :: SOP.FieldName -> PP.Doc -> PP.Doc
recField fn d = PP.hsep [PP.text fn, PP.char '=', d]
-- | Stack the field documents vertically, substituting 'ppNoChange' for
-- every field that has no document.
docList :: [Maybe PP.Doc] -> PP.Doc
docList ds = PP.vcat [ fromMaybe ppNoChange d | d <- ds ]
-- | Display name and per-field 'SOP.FieldInfo' of a constructor.
--
-- Record constructors report their declared field names.  Plain and
-- infix constructors have no field names, so each field is labelled with
-- the 'show'n result of 'ixNum' on its index.  Infix constructors are
-- displayed as @(_ op _)@.
ctrNames
    :: forall as. ()
    => SOP.ConstructorInfo as
    -> (String, Prod SOP.FieldInfo as)
ctrNames ci = case ci of
    -- 'sListLength' provides the @Known Length as@ instance that
    -- 'positional' needs.
    SOP.Constructor n -> (n, positional) \\ sListLength (SOP.sList @_ @as)
    SOP.Infix n _ _   -> ("(_ " ++ n ++ " _)", positional)
    SOP.Record n fs   -> (n, sopProd fs)
  where
    -- Field infos named after the position of each field.
    positional :: forall bs. Known Length bs => Prod SOP.FieldInfo bs
    positional = map1 (SOP.FieldInfo . show . ixNum) indices