はわわーっ

はわわわわっ

Haskell で確率計算

すごいH本の最後の方にあるやつ。

まず型を定義する。

import Data.Ratio

newtype Prob a = Prob [(a, Rational)] deriving Show

確率は Data.Ratio を使って表すことにする。使い方はこんな感じ。

data Coin = Heads | Tails deriving (Show, Eq)
coin :: Prob Coin
coin = Prob [(Heads, 1%2), (Tails, 1%2)]

dice :: Prob Int
dice = Prob $ map (\x -> (x, 1%6)) [1..6]
> coin
Prob [(Heads,1 % 2),(Tails,1 % 2)]
> dice
Prob [(1,1 % 6),(2,1 % 6),(3,1 % 6),(4,1 % 6),(5,1 % 6),(6,1 % 6)]

で、これを Functor にしてみる。Functor にするには fmap を実装すればいい。この場合は確率はそのままで各要素に関数を適用するようにする。

import Control.Arrow (first)

instance Functor Prob where
  fmap f (Prob xs) = Prob $ map (first f) xs
> fmap (== Heads) coin
Prob [(True,1 % 2),(False,1 % 2)]
> fmap (* 10) dice
Prob [(10,1 % 6),(20,1 % 6),(30,1 % 6),(40,1 % 6),(50,1 % 6),(60,1 % 6)]

次に Monad にする。return は要素が x で確率が 1 のものを返せばいい。
>>= は join (fmap f m) と同じなので、まず join に相当する関数を作る。名前は joinProb にしておく。
やることは

Prob [(Prob [('a', 1%3), ('b', 2%3)], 1%2),
      (Prob [('c', 1%4), ('d', 3%4)], 1%2)]

みたいなやつを

Prob [('a', 1%6), ('b', 1%3), ('c', 1%8), ('d', 3%8)]

にすればいい。

import Control.Arrow (second)

joinProb :: Prob (Prob a) -> Prob a
joinProb (Prob ps) = Prob $ concat $ map multiplyProb ps
  where
    multiplyProb (Prob xs, p) = map (second (*p)) xs

これを使って Monadインスタンスにする。

instance Monad Prob where
  return x = Prob [(x, 1)]
  m >>= f  = joinProb (fmap f m)

モナドを使ってみる。

coin3 :: Prob [Coin]
coin3 = do
  a <- coin
  b <- coin
  c <- coin
  return [a,b,c]

dice2 :: Prob Int
dice2 = do
  a <- dice
  b <- dice
  return (a + b)
> coin3
Prob [([Heads,Heads,Heads],1 % 8),
      ([Heads,Heads,Tails],1 % 8),
      ([Heads,Tails,Heads],1 % 8),
      ([Heads,Tails,Tails],1 % 8),
      ([Tails,Heads,Heads],1 % 8),
      ([Tails,Heads,Tails],1 % 8),
      ([Tails,Tails,Heads],1 % 8),
      ([Tails,Tails,Tails],1 % 8)]
> dice2
Prob [(2,1 % 36),(3,1 % 36),(4,1 % 36),(5,1 % 36),(6,1 % 36),(7,1 % 36),
      (3,1 % 36),(4,1 % 36),(5,1 % 36),(6,1 % 36),(7,1 % 36),(8,1 % 36),
      (4,1 % 36),(5,1 % 36),(6,1 % 36),(7,1 % 36),(8,1 % 36),(9,1 % 36),
      (5,1 % 36),(6,1 % 36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),
      (6,1 % 36),(7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),(11,1 % 36),
      (7,1 % 36),(8,1 % 36),(9,1 % 36),(10,1 % 36),(11,1 % 36),(12,1 % 36)]

このままだと dice2 に同じ要素のものが複数でてくるので、それをまとめる関数を作る。

integrateProb :: Eq a => Prob a -> Prob a
integrateProb (Prob ps) = Prob $ integrate ps
  where
    integrate []       = []
    integrate ps@(p:_) = (fst p, sum $ map snd pr) : integrate pr'
      where
        pr  = filter ((== fst p) . fst) ps
        pr' = filter ((/= fst p) . fst) ps


dice2 :: Prob Int
dice2 = integrateProb $ do
  a <- dice
  b <- dice
  return (a + b)
> dice2
Prob [(2,1 % 36),(3,1 % 18),(4,1 % 12),(5,1 % 9),(6,5 % 36),(7,1 % 6),
      (8,5 % 36),(9,1 % 9),(10,1 % 12),(11,1 % 18),(12,1 % 36)]

コインを3回投げて表が n 枚でる確率

> integrateProb $ liftM (length . filter (== Heads)) coin3
Prob [(3,1 % 8),(2,3 % 8),(1,3 % 8),(0,1 % 8)]

サイコロを2回投げて和が3の倍数になる確率

> integrateProb $ liftM ((== 0) . (`mod` 3)) dice2
Prob [(False,2 % 3),(True,1 % 3)]

最後に Prob をアプリカティブファンクターにする。もう Monadインスタンスになっているので pure が return で、<*> が ap になるようにすればいい。

import Control.Applicative
import Control.Monad

instance Applicative Prob where
  pure  = return
  (<*>) = ap

アプリカティブファンクターを使ってみる。

> integrateProb $ (+) <$> dice <*> dice
Prob [(2,1 % 36),(3,1 % 18),(4,1 % 12),(5,1 % 9),(6,5 % 36),(7,1 % 6),
      (8,5 % 36),(9,1 % 9),(10,1 % 12),(11,1 % 18),(12,1 % 36)]

ソース全体。

import Control.Applicative
import Control.Arrow (first, second)
import Control.Monad
import Data.Ratio

newtype Prob a = Prob [(a, Rational)] deriving Show

instance Functor Prob where
  fmap f (Prob xs) = Prob $ map (first f) xs

instance Monad Prob where
  return x = Prob [(x, 1)]
  m >>= f  = joinProb (fmap f m)

instance Applicative Prob where
  pure  = return
  (<*>) = ap


joinProb :: Prob (Prob a) -> Prob a
joinProb (Prob ps) = Prob $ concat $ map multiplyProb ps
  where
    multiplyProb (Prob xs, p) = map (second (*p)) xs

integrateProb :: Eq a => Prob a -> Prob a
integrateProb (Prob ps) = Prob $ integrate ps
  where
    integrate []       = []
    integrate ps@(p:_) = (fst p, sum $ map snd pr) : integrate pr'
      where
        pr  = filter ((== fst p) . fst) ps
        pr' = filter ((/= fst p) . fst) ps


data Coin = Heads | Tails deriving (Show, Eq)
coin :: Prob Coin
coin = Prob [(Heads, 1%2), (Tails, 1%2)]

dice :: Prob Int
dice = Prob $ map (\x -> (x, 1%6)) [1..6]


coin3 :: Prob [Coin]
coin3 = do
  a <- coin
  b <- coin
  c <- coin
  return [a,b,c]

dice2 :: Prob Int
dice2 = integrateProb $ do
  a <- dice
  b <- dice
  return (a + b)