维特比算法(viterbi)原理以及简单实现

参考文档: https://zhuanlan.zhihu.com/p/161436964

维基百科的解释,维特比算法(Viterbi algorithm)是一种动态规划算法。它用于寻找最有可能产生观测事件序列的维特比路径——隐含状态序列,特别是在马尔可夫信息源上下文和隐马尔可夫模型中。

如何通俗地讲解 viterbi 算法?

参考文档: https://www.zhihu.com/question/20136144

如下图,假如你从S和E之间找一条最短的路径,除了遍历完所有路径,还有什么更好的方法?

viterbi_introduction_1.webp

viterbi维特比算法解决的是篱笆型的图的最短路径问题,图的节点按列组织,每列的节点数量可以不一样,每一列的节点只能和相邻列的节点相连,不能跨列相连,节点之间有着不同的距离,距离的值就不在图上一一标注出来了,大家自行脑补

过程非常简单:

为了找出S到E之间的最短路径,我们先从S开始从左到右一列一列地来看。

首先起点是S,从S到A列的路径有三种可能:S-A1、S-A2、S-A3,如下图:

viterbi_introduction_2.webp

我们不能武断地说S-A1、S-A2、S-A3中的哪一段必定是全局最短路径中的一部分,目前为止任何一段都有可能是全局最短路径的备选项。

我们继续往右看,到了B列。按B列的B1、B2、B3逐个分析。

先看B1:

viterbi_introduction_3.webp

如上图,经过B1的所有路径只有3条:

S-A1-B1

S-A2-B1

S-A3-B1

以上这三条路径,各节点距离加起来对比一下,我们就可以知道其中哪一条是最短的。假设S-A3-B1是最短的,那么我们就知道了经过B1的所有路径当中S-A3-B1是最短的,其它两条路径路径S-A1-B1和S-A2-B1都比S-A3-B1长,绝对不是目标答案,可以大胆地删掉了。删掉了不可能是答案的路径,就是viterbi算法(维特比算法)的重点,因为后面我们再也不用考虑这些被删掉的路径了。现在经过B1的所有路径只剩一条路径了,如下图:

viterbi_introduction_4.webp

接下来,我们继续看B2:

viterbi_introduction_5.webp

同理,如上图,经过B2的路径有3条:

S-A1-B2

S-A2-B2

S-A3-B2

这三条路径中,各节点距离加起来对比一下,我们肯定也可以知道其中哪一条是最短的,假设S-A1-B2是最短的,那么我们就知道了经过B2的所有路径当中S-A1-B2是最短的,其它两条路径路径S-A2-B2和S-A3-B1也可以删掉了。经过B2所有路径只剩一条,如下图:

viterbi_introduction_6.webp

接下来我们继续看B3:

viterbi_introduction_7.webp

同理,如上图,经过B3的路径也有3条:

S-A1-B3

S-A2-B3

S-A3-B3

这三条路径中我们也肯定可以算出其中哪一条是最短的,假设S-A2-B3是最短的,那么我们就知道了经过B3的所有路径当中S-A2-B3是最短的,其它两条路径路径S-A1-B3和S-A3-B3也可以删掉了。经过B3的所有路径只剩一条,如下图:

viterbi_introduction_8.webp

现在对于B列的所有节点我们都过了一遍,B列的每个节点我们都删除了一些不可能是答案的路径,看看我们剩下哪些备选的最短路径,如下图:

viterbi_introduction_9.webp

上图是我们删掉了其它不可能是最短路径的情况,留下了三个有可能是最短的路径:S-A3-B1、S-A1-B2、S-A2-B3。现在我们将这三条备选的路径放在一起汇总到下图:

viterbi_introduction_10.webp

S-A3-B1、S-A1-B2、S-A2-B3都有可能是全局的最短路径的备选路径,我们还没有足够的信息判断哪一条一定是全局最短路径的子路径。

如果我们你认为没毛病就继续往下看C列,如果不理解,回头再看一遍,前面的步骤决定你是否能看懂viterbi算法(维特比算法)。

接下来讲到C列了,类似上面说的B列,我们从C1、C2、C3一个个节点分析。

经过C1节点的路径有:

S-A3-B1-C1、

S-A1-B2-C1、

S-A2-B3-C1

viterbi_introduction_11.webp

和B列的做法一样,从这三条路径中找到最短的那条(假定是S-A3-B1-C1),其它两条路径同样道理可以删掉了。那么经过C1的所有路径只剩一条,如下图:

viterbi_introduction_12.webp

同理,我们可以找到经过C2和C3节点的最短路径,汇总一下:

viterbi_introduction_13.webp

到达C列时最终也只剩3条备选的最短路径,我们仍然没有足够信息断定哪条才是全局最短。

最后,我们继续看E节点,才能得出最后的结论。

到E的路径也只有3种可能性:

viterbi_introduction_14.webp

E点已经是终点了,我们稍微对比一下这三条路径的总长度就能知道哪条是最短路径了。

viterbi_introduction_15.webp

在效率方面相对于粗暴地遍历所有路径,viterbi 维特比算法到达每一列的时候都会删除不符合最短路径要求的路径,大大降低时间复杂度。

算法原理

维特比算法就是求所有观测序列中的最优,如下图所示,我们要求从S到E的最优序列,中间有3个时刻,每个时刻都有对应的不同观察的概率,下图中每个时刻不同的观测标签有3个。

viterbi_introduction_16.webp

求所有路径中最优路径,最容易想到的就是暴力解法,直接把所有路径全部计算出来,然后找出最优的。这方法理论上是可行,但当序列很长时,时间复杂夫很高。而且进行了大量的重复计算,viterbi算法就是用动态规划的方法就减少这些重复计算。

viterbi算法是每次记录到当前时刻,每个观察标签的最优序列,如下图,假设在t时刻已经保存了从0到t时刻的最优路径,那么t+1时刻只需要计算从t到t+1的最优就可以了,图中红箭头表示从t时刻到t+1时刻,观测标签为1,2,3的最优。

viterbi_introduction_17.webp

每次只需要保存到当前位置最优路径,之后循环向后走。到结束时,从最后一个时刻的最优值回溯到开始位置,回溯完成后,这个从开始到结束的路径就是最优的。

viterbi_introduction_18.webp

代码实现

下面用python简单实现一下viterbi算法

import numpy as np

def viterbi_decode(nodes, trans):
    """
    Viterbi算法求最优路径
    其中 nodes.shape=[seq_len, num_labels],
        trans.shape=[num_labels, num_labels].
    """
    # 获得输入状态序列的长度,以及观察标签的个数
    seq_len, num_labels = len(nodes), len(trans)
    # 简单起见,先不考虑发射概率,直接用起始0时刻的分数
    scores = nodes[0].reshape((-1, 1))
  
    paths = []
    # 递推求解上一时刻t-1到当前时刻t的最优
    for t in range(1, seq_len):
        # scores 表示起始0到t-1时刻的每个标签的最优分数
        scores_repeat = np.repeat(scores, num_labels, axis=1)
        # observe当前时刻t的每个标签的观测分数
        observe = nodes[t].reshape((1, -1))
        observe_repeat = np.repeat(observe, num_labels, axis=0)
        # 从t-1时刻到t时刻最优分数的计算,这里需要考虑转移分数trans
        M = scores_repeat + trans + observe_repeat
        # 寻找到t时刻的最优路径
        scores = np.max(M, axis=0).reshape((-1, 1))
        idxs = np.argmax(M, axis=0)
        # 路径保存
        paths.append(idxs.tolist())
    
    best_path = [0] * seq_len
    best_path[-1] = np.argmax(scores)
    # 最优路径回溯
    for i in range(seq_len-2, -1, -1):
        idx = best_path[i+1]
        best_path[i] = paths[i][idx]
  
    return best_path

代码中针对scores和observe的repeat复制操作,是为了方便矩阵运算,减少循环的操作。

如果将M = scores_repeat + trans + observe_repeat,展开用for循环写,在t时刻M[i][j] = scores[i] + trans[i][j] + observe[j],M[i][j]表示从t-1时刻为i-1状态,t时刻为j-1状态的分数。

viterbi_introduction_19.webp

下面就是展开用for循环一步一步求解的伪码。

# 每个时刻scores更新的伪码
for t in range(1, seq_len):
	tmp_scores = scores
	for j in range(nums_labels):
		for i in range(nums_labels):
			M[i][j] = scores[i] + trans[i][j] + observe[t][j]
		tmp_scores[j] = max(M[i][j]) (0 <= i < nums_labels)
	scores = tmp_scores 

可以利用矩阵计算的原理,合并一些步骤。

for t in range(1, seq_len):
    scores_repeat = np.repeat(scores, num_labels, axis=1)
    observe = nodes[t].reshape((1, -1))
    observe_repeat = np.repeat(observe, num_labels, axis=0)
    M = scores_repeat + trans + observe_repeat
    scores = np.max(M, axis=0).reshape((-1, 1))

这里还有对scores和observe进行复制的操作,numpy运算中还可以在简化。

for t in range(1, seq_len):
     observe = nodes[t].reshape((1, -1))
     M = scores + trans + observe
     scores = np.max(M, axis=0).reshape((-1, 1))

numpy在相加时可以自动扩充维度,横向和纵向都可以。

viterbi_introduction_20.webp

经过简化的viterbi算法完整版。

def viterbi_decode_v2(nodes, trans):
    """
    Viterbi算法求最优路径v2
    其中 nodes.shape=[seq_len, num_labels],
        trans.shape=[num_labels, num_labels].
    """
    seq_len, num_labels = len(nodes), len(trans)
    scores = nodes[0].reshape((-1, 1))
    paths = []
    # # 递推求解上一时刻t-1到当前时刻t的最优
    for t in range(1, seq_len):
        observe = nodes[t].reshape((1, -1))
        M = scores + trans + observe
        scores = np.max(M, axis=0).reshape((-1, 1))
        idxs = np.argmax(M, axis=0)
        paths.append(idxs.tolist())

    best_path = [0] * seq_len
    best_path[-1] = np.argmax(scores)
    # 最优路径回溯
    for i in range(seq_len-2, -1, -1):
        idx = best_path[i+1]
        best_path[i] = paths[i][idx]
    
    return best_path

还有一种写法,最后不用回溯,每次把最优路径的索引都保存起来,并添加一个正常的路径,最后直接按索引找出最优路径,这个代码很少,但是不太好理解。

def viterbi_decode_v3(nodes, trans):
    """
    Viterbi算法求最优路径
    其中 nodes.shape=[seq_len, num_labels],
        trans.shape=[num_labels, num_labels].
    """
    seq_len, num_labels = len(nodes), len(trans)
    labels = np.arange(num_labels).reshape((1, -1))
    scores = nodes[0].reshape((-1, 1))
    paths = labels
    for t in range(1, seq_len):
        observe = nodes[t].reshape((1, -1))
        M = scores + trans + observe
        scores = np.max(M, axis=0).reshape((-1, 1))
        idxs = np.argmax(M, axis=0)
        paths = np.concatenate([paths[:, idxs], labels], 0)
    best_path = paths[:, scores.argmax()]
    return best_path