module AOC.Common.Intcode.Memory (
MonadMem(..)
, Memory(..)
, mRegLens
, MemRef(..)
, initMemRef
, freezeMemRef
) where
import Control.DeepSeq
import Control.Lens
import Control.Monad.Except
import Control.Monad.Primitive
import Control.Monad.Reader
import Control.Monad.State
import Data.Conduino
import Data.Generics.Labels ()
import Data.Map (Map)
import Data.Primitive.MutVar
import GHC.Generics
import Numeric.Natural (Natural)
import qualified Data.Map as M
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Storable.Mutable as MVS
-- | Monads with access to Intcode-style memory: a read position, a base
-- offset, and a table of 'Int' cells addressed by 'Natural'.
class (Monad m) => MonadMem m where
  -- | Read the cell at the current position, then advance the position by one.
  mRead :: m Int
  -- | The current read position.
  mCurr :: m Natural
  -- | Read the cell at an address without moving the read position.
  mPeek :: Natural -> m Int
  -- | Move the read position to the given address.
  mSeek :: Natural -> m ()
  -- | Store a value in the cell at an address.
  mWrite :: Natural -> Int -> m ()
  -- | Add an offset to the current base.
  mShiftBase :: Int -> m ()
  -- | The current base plus the given offset.
  mWithBase :: Int -> m Int
-- | Pure snapshot of Intcode memory: read position, base offset, and the
-- populated cells.  Cells absent from 'mRegs' are treated as @0@ by the
-- 'StateT' instance and by 'mRegLens'.
--
-- Fields are strict.  'mPos' is incremented on every 'mRead' and 'mBase' is
-- adjusted by 'mShiftBase'; with lazy fields these updates pile up as chains
-- of unevaluated @+ 1@ / @+ b@ thunks until the value is next demanded.
data Memory = Mem
  { mPos  :: !Natural
    -- ^ Current read position.
  , mBase :: !Int
    -- ^ Base offset.
  , mRegs :: !(Map Natural Int)
    -- ^ Populated cells, keyed by address.
  }
  deriving (Eq, Ord, Show, Generic)

instance NFData Memory
-- | Pure, 'Map'-backed memory threaded through 'StateT'.  Addresses that
-- have never been written read as @0@.
instance Monad m => MonadMem (StateT Memory m) where
  mRead = do
    mem@(Mem pos _ regs) <- get
    put mem { mPos = pos + 1 }
    pure (M.findWithDefault 0 pos regs)
  mCurr         = mPos <$> get
  mPeek addr    = M.findWithDefault 0 addr <$> gets mRegs
  mSeek addr    = modify (\mem -> mem { mPos = addr })
  mWrite addr x = modify (\mem -> mem { mRegs = M.insert addr x (mRegs mem) })
  mShiftBase d  = modify (\mem -> mem { mBase = mBase mem + d })
  mWithBase off = (+ off) <$> gets mBase
-- | Memory operations inside a conduino 'Pipe' delegate to the underlying
-- monad.
instance MonadMem m => MonadMem (Pipe i o u m) where
  mRead         = lift mRead
  mCurr         = lift mCurr
  mPeek addr    = lift (mPeek addr)
  mSeek addr    = lift (mSeek addr)
  mWrite addr x = lift (mWrite addr x)
  mShiftBase d  = lift (mShiftBase d)
  mWithBase off = lift (mWithBase off)
-- | Memory operations inside 'ExceptT' delegate to the underlying monad.
instance MonadMem m => MonadMem (ExceptT e m) where
  mRead         = lift mRead
  mCurr         = lift mCurr
  mPeek addr    = lift (mPeek addr)
  mSeek addr    = lift (mSeek addr)
  mWrite addr x = lift (mWrite addr x)
  mShiftBase d  = lift (mShiftBase d)
  mWithBase off = lift (mWithBase off)
-- | A lens onto the cell at address @i@.  A missing cell is viewed as @0@,
-- and (via 'non') setting a cell to @0@ removes it from the map.
mRegLens :: Natural -> Lens' Memory Int
mRegLens i = lens mRegs (\mem rs -> mem { mRegs = rs }) . at i . non 0
-- | Mutable counterpart of 'Memory' for use in a 'PrimMonad'.  Convert to
-- and from the pure form with 'initMemRef' and 'freezeMemRef'.
data MemRef s = MemRef
  { mrPos  :: MutVar s Natural
    -- ^ Current read position.
  , mrBase :: MutVar s Int
    -- ^ Base offset.
  , mrRegs :: MutVar s (VS.MVector s Int)
    -- ^ Register cells.  Kept behind a 'MutVar' because growing the vector
    -- on an out-of-range write replaces it with a new one.
  }
-- | Build a mutable 'MemRef' from a pure 'Memory' snapshot.
--
-- The register vector is sized to @headroom@ (10) times the highest
-- populated address, plus one.  This is presumably headroom so that writes
-- just past the loaded program do not immediately force a regrow.  Cells not
-- present in the map start as @0@.  An empty map yields an empty vector.
initMemRef :: (PrimMonad m, s ~ PrimState m) => Memory -> m (MemRef s)
initMemRef Mem{..} = do
    mrPos  <- newMutVar mPos
    mrBase <- newMutVar mBase
    regs   <- case M.lookupMax mRegs of
      Nothing -> MVS.new 0
      Just (maxAddr, _) -> do
        -- Zero-fill once and scatter only the populated cells: O(size + entries)
        -- rather than a Map lookup per cell plus a second full copy via thaw.
        v <- MVS.replicate (fromIntegral maxAddr * headroom + 1) 0
        forM_ (M.toList mRegs) $ \(addr, x) ->
          MVS.write v (fromIntegral addr) x
        pure v
    mrRegs <- newMutVar regs
    pure MemRef{..}
  where
    headroom :: Int
    headroom = 10
-- | Take a pure 'Memory' snapshot of a 'MemRef'.
--
-- The register vector is copied (the 'MemRef' stays usable afterwards).
-- Only non-zero cells are kept in the resulting map.
freezeMemRef :: (PrimMonad m, s ~ PrimState m) => MemRef s -> m Memory
freezeMemRef MemRef{..} = do
    mPos  <- readMutVar mrPos
    mBase <- readMutVar mrBase
    mRegs <- toRegs <$> (VS.freeze =<< readMutVar mrRegs)
    pure Mem{..}
  where
    -- Indices from 'zip [0 ..]' are strictly ascending, so the map can be
    -- built in linear time.  Zeros are dropped before building, not after.
    toRegs = M.fromDistinctAscList . filter ((/= 0) . snd) . zip [0 ..] . VS.toList
-- | Mutable-vector-backed memory for any 'PrimMonad'.
--
-- Addresses past the end of the register vector read as @0@.  A write past
-- the end grows the vector.  Bounds are compared in 'Natural' before
-- converting to 'Int'.  This stops an address too large for 'Int' from
-- wrapping to a negative index and reaching the unchecked
-- 'MVS.unsafeRead' / 'MVS.unsafeWrite'.
instance (PrimMonad m, s ~ PrimState m) => MonadMem (ReaderT (MemRef s) m) where
  mRead = ask >>= \MemRef{..} -> do
    -- Post-increment the position, then read at the old position.
    addr <- atomicModifyMutVar' mrPos (\p -> (p + 1, p))
    mPeek addr
  mCurr = readMutVar =<< asks mrPos
  mPeek addr = do
    regs <- readMutVar =<< asks mrRegs
    if addr < fromIntegral (MVS.length regs)
      -- In bounds, so 'addr' fits in an 'Int' and the unchecked read is safe.
      then MVS.unsafeRead regs (fromIntegral addr)
      else pure 0
  mSeek addr = (`writeMutVar` addr) =<< asks mrPos
  mWrite addr x = ask >>= \MemRef{..} -> do
    regs <- readMutVar mrRegs
    let len = MVS.length regs
    if addr < fromIntegral len
      then MVS.unsafeWrite regs (fromIntegral addr) x
      else do
        -- Grow to twice the required length so repeated out-of-range writes
        -- stay amortised O(1).  Addresses whose doubled size would overflow
        -- 'Int' are rejected instead of wrapping.
        let i = if addr < fromIntegral (maxBound `div` 2 :: Int)
                  then fromIntegral addr
                  else error ("AOC.Common.Intcode.Memory.mWrite: address too large: " ++ show addr)
            newLen = (i + 1) * 2
        regs' <- MVS.unsafeGrow regs (newLen - len)
        -- 'unsafeGrow' may leave the new tail uninitialised; zero it so that
        -- unwritten cells keep reading as 0.
        MVS.set (MVS.slice len (newLen - len) regs') 0
        MVS.unsafeWrite regs' i x
        writeMutVar mrRegs regs'
  mShiftBase b = (`modifyMutVar'` (+ b)) =<< asks mrBase
  mWithBase i = fmap (+ i) . readMutVar =<< asks mrBase