{-|
Module : CPS
Description : Delimited continuation monad transformer
Maintainer : gatlin@niltag.net
Stability : experimental
-}

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE BangPatterns #-}

module CPS
  ( -- * Delimited Continuation Monad Transformer
    CPS(..)
  , shift
  , reset
    -- * Utilities
  , lift
  )
  where

import Data.Kind (Type)

import Control.Applicative (Alternative(..))
import Control.Monad (MonadPlus(..))
import Control.Concurrent.MonadIO

-- | @CPS result m answer@ is a computation that does not return its
-- @answer@ directly. Instead it is handed a continuation of type
-- @answer -> m result@ and produces the final @m result@ itself.
--
-- The kind of @result@ is polymorphic; it only has to agree with the
-- argument kind of @m@.
newtype CPS (result :: k) (m :: k -> Type) (answer :: Type) =
  CPS
    { (#) :: (answer -> m result) -> m result
      -- ^ Run a computation by supplying its continuation:
      -- @c # k@ runs @c@ with @k@ as the continuation.
    }

-- | Delimit a computation and run it in the base monad.
--
-- The computation receives 'return' as its final continuation, so its
-- answer becomes the @m r@ result. This is the boundary that 'shift'
-- captures the continuation up to.
reset :: Monad m => CPS r m r -> m r
reset computation = computation # return

-- | Capture the current continuation.
--
-- @shift e@ hands the continuation @k :: a -> m r@ that this computation
-- is run with to @e@. It then runs the resulting computation @e k@ under
-- its own 'reset'. Because @k@ is an ordinary function into @m@, @e@ may
-- call it any number of times, including zero.
shift :: Monad m => ((a -> m r) -> CPS r m r) -> CPS r m a
shift e = CPS (reset . e)

-- | Mapping transforms each answer just before it reaches the continuation.
instance Functor (CPS r m) where
  -- The bang forces the mapped computation as soon as the result of
  -- 'fmap' is evaluated.
  fmap f !c = CPS (\k -> c # (\a -> k (f a)))

-- | Applicative structure: run the function computation first, then the
-- argument computation, and pass the applied result on.
instance Applicative (CPS r m) where
  -- A pure value goes straight to the continuation.
  pure x = CPS (\k -> k x)

  -- For every function @g@ produced by @cf@, run @cx@ and feed @g x@ to
  -- the continuation. The bang forces the argument computation when the
  -- combined computation is evaluated.
  cf <*> !cx = CPS $ \k ->
    cf # \g ->
      cx # \x ->
        k (g x)

  -- Sequence two computations, keeping only the second answer.
  -- This goes through the 'Monad' instance's bind.
  ca *> cb = ca >>= const cb

-- | Monadic bind for continuation-passing computations.
instance Monad (CPS r m) where
  -- Each answer @a@ of @ca@ is fed to @f@, and the computation @f a@ is
  -- run with the outer continuation @k@. The left-hand computation is
  -- forced when the bind is evaluated. The continuation is forced when
  -- the resulting computation is run.
  !ca >>= f = CPS $ \(!k) ->
    ca # \a ->
      f a # k

-- | Choice implemented with threads. It requires the final result type to
-- be @()@ and the base monad to support 'fork'.
--
-- * 'empty' ignores its continuation and finishes with @()@, so it is a
--   branch that never delivers an answer.
-- * @p '<|>' q@ runs @p@ with the current continuation in a thread
--   started by 'fork', then runs @q@ with the same continuation in the
--   current thread. Answers from both branches reach the continuation.
instance (HasFork m) => Alternative (CPS () m) where
  empty = CPS $ \_ -> return ()

  p <|> q = CPS $ \k -> do
    -- The ThreadId is not needed. Binding it to '_' makes the discard
    -- explicit, which also keeps -Wunused-do-bind quiet.
    _ <- fork (p # k)
    -- q # k already has type m (), so it can serve as the final result.
    q # k

-- | 'MonadPlus' reuses the 'Alternative' structure: 'mzero' is 'empty'
-- and 'mplus' is '<|>'.
instance (HasFork m) => MonadPlus (CPS () m) where
  mzero = empty
  mplus p q = p <|> q

-- | Lift an 'IO' action into 'CPS'. The action is lifted into the base
-- monad with its own 'liftIO' and then embedded with 'lift'.
--
-- 'lift' places no constraint on the final result type, so this instance
-- holds for every @r@, not only @()@. Code that uses @CPS () m@ still
-- resolves to this instance.
instance (MonadIO m) => MonadIO (CPS r m) where
  liftIO = lift . liftIO

-- | Embed an action of the base monad. The action is run, and its result
-- is passed to the continuation with '>>='.
lift :: Monad m => m a -> CPS r m a
lift action = CPS $ \k -> action >>= k