{-|
Module : HIO
Description : Hierarchical IO

This is a re-implementation of the @HIO@ module from Galois's @orc@ package,
which I am interested in experimenting with. I wanted to re-implement two of
the three modules in the original @orc@ package, so rather than depend on the
whole package only to discard two thirds of it, I have reproduced its @HIO@
module here.
-}

{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveFunctor #-}

{-# LANGUAGE CPP #-}

module HIO
  ( -- * Hierarchical IO
    HIO(..)
  , runHIO
  , unHIO
    -- * Thread groups
  , Group(..)
  , newGroup
  , local
  , close
  , finished
  , register
    -- * Auxiliary types
  , Entry(..)
  , Inhabitants(..)
    -- * Profiling HIO
  , countingThreads
  , threadCount
  , incrementThreadCount
  , printThreadReport
  )
where

import System.IO.Unsafe (unsafePerformIO)
import Control.Monad
  ( ap
  , when
  , join )
import Control.Exception (mask, finally)
import Control.Concurrent.STM.MonadIO
  ( TVar(..)
  , readTVar
  , modifyTVar
  , modifyTVar_
  , newTVar
  , writeTVarSTM
  , writeTVar
  , check
  , readTVarSTM
  , atomically )
import Control.Concurrent.MonadIO
  ( HasFork(..)
  , MonadIO(..)
  , ThreadId
  , myThreadId
  , killThread)

-- * Preliminary: HIO, Hierarchical I/O

-- | A thread 'Group' accounts for its inhabitants, which may be threads or
-- other 'Group's.
--
-- * The first component counts threads forked into the group that have not
--   yet exited. The 'HIO' 'fork' increments it before starting a child, and
--   the child decrements it in a @finally@ handler. 'finished' and 'runHIO'
--   wait for it to reach zero.
--
-- * The second component holds the group's registered 'Inhabitants'.
type Group = (TVar Int, TVar Inhabitants)

-- | The population of a thread group.
data Inhabitants
    = Closed
      -- ^ The group has been shut down. It holds nothing, and any later
      -- attempt to 'register' into it makes the registering thread kill
      -- itself.
    | Open [Entry]
      -- ^ The group is live. New 'Thread's and 'Group's may be registered.
      -- The list holds every entry registered so far; nothing in this module
      -- removes an entry when its thread finishes.

-- | A single member of a thread group: either a forked thread or a nested
-- sub-group. Killing a group recursively kills its sub-groups.
data Entry
    = Thread ThreadId
    | Group Group

-- | 'HIO' is plain 'IO' run with an extra environment: the thread 'Group'
-- that is currently in effect.
--
-- Threads forked from 'HIO' are recorded in that group. Closing the group, or
-- any ancestor group, kills them all at once.
--
-- Arbitrary 'IO' can be embedded through the 'MonadIO' instance. Such
-- actions should be interruptible, so that they can actually be killed when
-- their group is closed.
newtype HIO a = HIO
    { inGroup :: Group -> IO a
      -- ^ Run the computation with the given group as its environment.
    }

instance Functor HIO where
    -- Run the underlying action in the same group, then apply @f@ to its
    -- result.
    fmap f action = HIO (\grp -> f <$> inGroup action grp)

instance Applicative HIO where
    -- A pure value ignores the group it is run in.
    pure x = HIO (const (pure x))

    -- Sequencing is inherited from the 'Monad' instance.
    (<*>) = ap

instance Monad HIO where
    -- Both the first action and the continuation are run in the same group:
    -- the group environment is threaded unchanged through every bind.
    (HIO run) >>= k = HIO $ \grp -> run grp >>= \x -> inGroup (k x) grp

instance MonadIO HIO where
    -- An embedded IO action simply ignores the current group.
    liftIO io = HIO (\_ -> io)

instance HasFork HIO where
    -- Fork a child thread into the current group @w@.
    --
    -- Asynchronous exceptions are masked while the parent bumps the group's
    -- counter, so the child is accounted for before it can possibly run.
    --
    -- The child inherits the masked state. It first registers itself in @w@;
    -- 'register' kills the child if @w@ is already closed. Only then does it
    -- run the user's action, with the caller's original masking state
    -- restored.
    --
    -- The counter is decremented in a 'finally' handler, so every increment
    -- is balanced however the child exits.
    --
    -- This is implemented with 'mask' unconditionally. The pre-'mask'
    -- 'block'/'unblock' API is gone from base and is not imported here.
    fork hio = HIO $ \w -> mask $ \restore -> do
        when countingThreads incrementThreadCount
        increment w
        let child = do
                tid <- myThreadId
                register (Thread tid) w
                restore (hio `inGroup` w)
        fork (child `finally` decrement w)

-- | Creates a new, empty thread group and registers it as an inhabitant of
-- the current environment's group.
--
-- Because the new group is an entry of the current one, closing the current
-- group also closes the new group and kills its inhabitants. If the current
-- group is already closed, the calling thread is killed instead; see
-- 'register'.
--
-- The new group is returned, not installed. Use 'local' to run computations
-- inside it.
newGroup :: HIO Group
newGroup = HIO $ \parent -> do
    child <- newPrimGroup
    register (Group child) parent
    return child

-- | Runs an 'HIO' computation with @grp@ as its current 'Group'.
--
-- Any threads the computation forks are counted and registered in @grp@
-- rather than in the caller's group.
local :: Group -> HIO a -> HIO a
local grp action = liftIO (inGroup action grp)

-- | Closes a 'Group'.
--
-- Closing marks the group 'Closed', so nothing new can register in it. It
-- then kills every thread registered in the group, recursing into any
-- registered sub-groups, and finally resets the group's thread counter to
-- zero.
--
-- The work happens in a separately forked thread, so 'close' returns
-- immediately without waiting for the kills to finish.
-- NOTE(review): presumably this lets a thread close its own group; confirm.
--
-- Closing an already-closed group kills nothing, but the counter is still
-- reset to zero.
close :: Group -> IO ()
close grp@(counter, _) = do
    _ <- fork (kill (Group grp) >> writeTVar counter 0)
    return ()

-- | Blocks until the thread counter of the 'Group' @w@ reaches zero.
--
-- This means every thread forked directly into @w@ has exited. Threads
-- forked inside a different group, such as a sub-group entered with
-- 'local', are tracked by that group's own counter.
finished :: Group -> HIO ()
finished = liftIO . isZero

-- | Runs a 'HIO' computation inside a fresh, parentless thread group.
--
-- It discards the computation's result, then blocks until every thread
-- forked directly into that root group has exited. Threads forked inside
-- sub-groups entered with 'local' are not waited for.
--
-- If 'countingThreads' is 'True', it then prints a report of how many 'HIO'
-- threads were forked.
runHIO :: HIO b -> IO ()
runHIO hio = do
    root <- newPrimGroup
    _ <- inGroup hio root
    isZero root
    when countingThreads printThreadReport

-- | Unsafely extracts the result of a 'HIO' computation as a pure value.
--
-- The computation runs like 'runHIO': in a fresh root group, waiting for
-- that group's directly-forked threads, and printing the thread report when
-- 'countingThreads' is set. Its result is then returned.
--
-- This goes through 'unsafePerformIO'. When, and how many times, the
-- effects run is up to the evaluator, so use it only where that cannot
-- matter.
unHIO :: HIO a -> a
unHIO hio = unsafePerformIO (newPrimGroup >>= runIn)
  where
    runIn root = do
        result <- hio `inGroup` root
        isZero root
        when countingThreads printThreadReport
        pure result

-- | Creates a new, empty, unregistered thread group.
--
-- The group starts with a zero thread counter and an 'Open' inhabitant
-- list containing no entries.
newPrimGroup :: IO Group
newPrimGroup = (,) <$> newTVar 0 <*> newTVar (Open [])

-- | Records @entry@ as an inhabitant of a 'Group'.
--
-- The group's state is inspected and updated in a single STM transaction:
--
-- * If the group is 'Open', @entry@ is prepended to its inhabitant list.
-- * If the group is 'Closed', nothing is recorded and the calling thread
--   kills itself.
--
-- The self-kill is returned from the transaction and only executed after it
-- commits.
register :: Entry -> Group -> IO ()
register entry (_, inhabitants) = join (atomically decide)
  where
    decide = do
        current <- readTVarSTM inhabitants
        case current of
            Closed       -> return suicide
            Open entries -> do
                writeTVarSTM inhabitants (Open (entry : entries))
                return (return ())

    suicide :: IO ()
    suicide = myThreadId >>= killThread

-- | Kills a thread, or recursively tears down a group.
--
-- * A 'Thread' entry is sent 'killThread'.
-- * A 'Group' entry is atomically swapped to 'Closed'. Every entry it
--   previously held is then killed in turn.
--
-- A group that was already 'Closed' has nothing left to kill.
kill :: Entry -> IO ()
kill (Thread tid) = killThread tid
kill (Group (_, inhabitants)) = do
    (previous, _) <- modifyTVar inhabitants (const Closed)
    case previous of
        Closed       -> return ()
        Open members -> mapM_ kill members

-- | Bookkeeping on a group's thread counter, the first component of 'Group'.
--
-- * 'increment' is called by the 'HIO' 'fork' in the parent, before the
--   child starts.
-- * 'decrement' is run by the child in a @finally@ handler when it exits.
-- * 'isZero' blocks (STM 'check' retries) until the counter reads zero.
--
-- 'close' resets the counter to zero from a separate thread after killing
-- the group's threads. Those killed threads may still run their @finally@
-- 'decrement' afterwards.
--
-- 'decrement' therefore never goes below zero. Otherwise the counter of a
-- closed group could end up negative, and 'isZero' (and so 'finished') would
-- block forever. In an open group every decrement follows its matching
-- increment, so the clamp never triggers there.
increment, decrement, isZero :: Group -> IO ()
increment (c, _) = modifyTVar_ c (+ 1)
decrement (c, _) = modifyTVar_ c (\x -> max 0 (x - 1))
isZero (c, _) = atomically $ readTVarSTM c >>= check . (== 0)

-- * Profiling HIO

-- | Profiling switch.
--
-- When 'True':
--
-- * every 'HIO' 'fork' bumps the global 'threadCount';
-- * 'runHIO' and 'unHIO' print a report of that count after they finish.
countingThreads :: Bool
countingThreads = True

-- | Global count of threads forked through the 'HIO' 'fork'.
--
-- It is bumped by 'incrementThreadCount' when 'countingThreads' is set, and
-- printed by 'printThreadReport'.
--
-- This is a top-level mutable variable created with 'unsafePerformIO'. The
-- NOINLINE pragma is required: without it GHC may inline the definition and
-- run 'newTVar' more than once, silently splitting the count across several
-- independent TVars.
threadCount :: TVar Integer
threadCount = unsafePerformIO $ newTVar 0
{-# NOINLINE threadCount #-}

-- | Adds one to the global 'threadCount'.
incrementThreadCount :: IO ()
incrementThreadCount = modifyTVar_ threadCount succ

-- | Prints a separator line, then the number of 'HIO' threads forked so far
-- as recorded in 'threadCount'.
printThreadReport :: IO ()
printThreadReport = do
    forked <- readTVar threadCount
    mapM_ putStrLn
        [ "----------"
        , show forked ++ " HIO threads were forked."
        ]