2011年01月03日

Wikipediaのビタビアルゴリズムを写経

今年の冬休みは写経をけっこうやってる。今日はWikipediaのビタビアルゴリズムのソースをScalaにして写経(?)した。

頭では分かっているつもりでも、自分で書いて使ってみて分かることも多い。


import scala.collection.JavaConversions._

/**
 * Scala port of a combined "sum over all paths / best single path" recursion
 * over a hidden Markov model, plus a small runnable example.
 */
object Viterbi {

  /**
   * Bookkeeping kept for one state:
   * (sum of the probabilities of all paths reaching it,
   *  the most probable single path reaching it, if any has probability > 0,
   *  the probability of that single path).
   */
  private type Cell = (Double, Option[List[String]], Double)

  /**
   * Collapses candidate cells into one: the probabilities are summed in
   * iteration order, and the path with the strictly greatest path probability
   * is kept (the first one wins on ties; no path is kept if none is > 0).
   */
  private def sumAndMax(cells: Iterator[Cell]): Cell =
    cells.foldLeft[Cell]((0.0, None, 0.0)) {
      case ((total, best, bestProb), (prob, path, pathProb)) =>
        if (pathProb > bestProb) (total + prob, path, pathProb)
        else (total + prob, best, bestProb)
    }

  /**
   * Processes `observation` from left to right and prints the tuple
   * `(total, argmax, valmax)` to standard output, where `total` is the summed
   * probability over all state paths, `argmax` is the most probable state path
   * (printed as `null` when no path has probability > 0) and `valmax` is the
   * probability of that path.
   *
   * The printed path holds one state more than there are observations: it
   * starts with an initial state and gains one state per observation.
   *
   * @param observation observed symbols, in order
   * @param states      all state names; the iteration order decides ties
   * @param sp          start probability of each state
   * @param tp          transition probabilities, looked up as `tp(from)(to)`
   * @param ep          emission probabilities, looked up as `ep(state)(symbol)`
   * @throws NoSuchElementException if a needed state or symbol is missing from
   *                                `sp`, `tp` or `ep`
   */
  def forward_viterbi(
      observation: Array[String],
      states: Array[String],
      sp: Map[String, Double],
      tp: Map[String, Map[String, Double]],
      ep: Map[String, Map[String, Double]]): Unit = {

    val initial: Map[String, Cell] =
      states.iterator
        .map(state => state -> ((sp(state), Option(List(state)), sp(state))))
        .toMap

    val finalCells = observation.foldLeft(initial) { (t, output) =>
      states.iterator.map { next_state =>
        // The sum and the maximum are computed from scratch for every
        // next_state; sharing them across destination states would leak one
        // state's totals and best path into the next one.
        val candidates = states.iterator.map { source_state =>
          val (prob, path, pathProb) = t(source_state)
          // The symbol is emitted by the state being left, then the move is made.
          val p = ep(source_state)(output) * tp(source_state)(next_state)
          (prob * p, path, pathProb * p)
        }
        val (total, best, valmax) = sumAndMax(candidates)
        next_state -> ((total, best.map(_ :+ next_state), valmax))
      }.toMap
    }

    // Apply the same sum / max over the cells of the final step.
    val (total, argmax, valmax) = sumAndMax(states.iterator.map(finalCells))

    println((total, argmax.orNull, valmax))
  }

  /** Runs the Rainy/Sunny example and prints its result. */
  def main(args: Array[String]): Unit = {
    val states = Array("Rainy", "Sunny")
    val observations = Array("walk", "shop", "clean")
    val start_probability = Map("Rainy" -> 0.6, "Sunny" -> 0.4)

    val transition_probability = Map(
      "Rainy" -> Map("Rainy" -> 0.7, "Sunny" -> 0.3),
      "Sunny" -> Map("Rainy" -> 0.4, "Sunny" -> 0.6))

    val emission_probability = Map(
      "Rainy" -> Map("walk" -> 0.1, "shop" -> 0.4, "clean" -> 0.5),
      "Sunny" -> Map("walk" -> 0.6, "shop" -> 0.3, "clean" -> 0.1))

    forward_viterbi(observations, states,
      start_probability, transition_probability, emission_probability)
  }

}