隐马尔可夫模型(Hidden Markov Model,HMM)

HMM模型定义

隐马尔可夫模型是关于时序的概率模型,描述由一个隐藏的马尔可夫链生成不可观测的状态随机序列,再由各个状态生成一个观测而产生观测随机序列的过程。隐藏的马尔可夫链随机生成的状态的序列,称为状态序列(state sequence);每个状态生成的一个观测,而由此产生的观测的随机序列,称为观测序列(observation sequence)。序列的每一个位置可以看作是一个时刻。

隐马尔可夫模型由初始状态概率分布向量( $\pi$),状态转移概率矩阵($A$)和状态产生的观测概率矩阵($B$)决定,$\pi$和$A$决定状态序列,$B$决定状态产生的观测序列,因此假设马尔可夫模型为$\lambda$,则: $$\lambda=(A,B,\pi)$$

假设

$Q$是所有可能隐藏状态的集合 ,其中N为状态个数,则 $$Q=\{q_1,q_2,...,q_N\}$$ $V$是所有可能观测的集合,其中M为可能观测数,则 $$V=\{v_1,v_2,...,v_M\}$$ $I$是长度为T的状态序列,$O$是对应的观测序列,则 $$I=\{i_1,i_2,...,i_T\},O=\{o_1,o_2,...,o_T\}$$

那么隐马尔可夫的参数$A$、$B$、$\pi$的定义如下: $$A=[a_{ij}]_{NxN}$$ 其中$a_{ij}=P(i_{t+1}=q_j|i_t=q_i)$,i=1,2,...,N;j=1,2,...,N,表示时刻t处于状态$q_i$的条件下在时刻t+1转移到$q_j$的概率 $$B=[b_j(k)]_{NxM}$$ 其中$b_j(k)=P(o_t=v_k|i_t=q_j)$,k=1,2,...,M;j=1,2,...,N,表示时刻t处于状态$q_j$的条件下生成观测$v_k$的概率 $$\pi=(\pi_i)$$ 其中$\pi_i=P(i_1=q_i)$,i=1,2,...,N,是时刻t=1处于状态$q_i$的概率

隐马尔可夫模型的3个基本问题

  • (1)概率计算问题:给定模型 $\lambda=(A,B,\pi)$ 和观测序列$O=(o_1,o_2,...,o_t)$,计算在模型$\lambda$下观测序列O出现的概率$p(O|\lambda)$,求解算法有前向算法和后向算法
  • (2)学习问题:已知观测序列$O=(o_1,o_2,...,o_t)$,估计模型$\lambda=(A,B,\pi)$的参数,使得在该模型下观测序列$P(O|\lambda)$最大,即用极大似然估计的方法估计参数,求解算法有监督学习算法和EM算法
  • (3)预测问题:也称为解码(decoding)问题,已知模型$\lambda=(A,B,\pi)$和观测序列$O=(o_1,o_2,...,o_t)$,求对给定观测序列条件概率$P(I|O)$最大的状态序列$I=(i_1,i_2,...,i_t)$,求解算法有维特比算法

具体请参见斯坦福的材料

3)预测问题:维特比算法实现

预测问题就是寻找最有可能产生可观擦序列的隐藏状态序列,维特比算法和forward算法比较相似,维特比算法在每一步取最大的转移概率,而foward在每一步对概率求和。 假设: $$\delta_t(i)=\underset{q_1,q_2,...,q_{t-1}}{\max}P(q_1q_2...q_t=s_i,o_1,o_2,...o_t|\lambda)$$

维特比(Viterbi)算法:

1.初始化 $$\delta_1(i)=\pi_ib_i(o_1),1\leq i \leq N,\psi_1(i)=0$$

2.递归 $$\delta_t(j)=\underset{1 \leq i \leq N}{\max}[\delta_{t-1}(i)a_{ij}]b_j(o_t),2 \leq t \leq T,1 \leq j \leq N$$

$$\psi_t(j)=\arg\underset{1 \leq i \leq N}{\max}[\delta_{t-1}(i)a_{ij}],2 \leq t \leq T,1 \leq j \leq N$$

3.中止 $$P=\underset{1 \leq i \leq N}{\max}[\delta_T(i)]$$ $$q_T=\arg \underset{1 \leq i \leq N}{\max}[\delta_T(i)]$$

4.回溯(backtracking) $$q_t=\psi_{t+1}(q_{t+1}),t=T-1,T-2,...,1$$

In [38]:
import numpy as np
def Viterbi(pi,a,b,obs):
    """
    pi:初始概率
    a:状态转移矩阵
    b:发射概率矩阵
    obs:可观察系列
    """
    #隐藏状态
    nStates = np.shape(b)[0]
    #时间步长
    T = np.shape(obs)[0]
    #隐藏状态序列
    path = np.zeros(T)
    delta = np.zeros((nStates,T))
    phi = np.zeros((nStates,T))

    #obs[0]表示第一个可观擦的状态
    #b[:,obs[0]]表示每种隐藏状态转为第一可观察状态的概率
    delta[:,0] = pi * b[:,obs[0]]
    phi[:,0] = 0

    for t in range(1,T):
        for s in range(nStates):
            #取每步状态的概率值,然后转到下一步的可观察状态
            delta[s,t] = np.max(delta[:,t-1]*a[:,s])*b[s,obs[t]]
            #保留每步最大概率的状态
            phi[s,t] = np.argmax(delta[:,t-1]*a[:,s])

    path[T-1] = np.argmax(delta[:,T-1])
    for t in range(T-2,-1,-1):
        path[t] = phi[int(path[t+1]),t+1]

    return path,delta, phi
In [39]:
def test():
    np.random.seed(4)
    pi = np.array([0.25,0.25,0.25,0.25])
    aLast = np.array([0.25,0.25,0.25,0.25])
    #a = np.array([[.7,.3],[.4,.6]] )
    a = np.array([[.4,.3,.1,.2],[.6,.05,.1,.25],[.7,.05,.05,.2],[.3,.4,.25,.05]])
    #b = np.array([[.2,.4,.4],[.5,.4,.1]] )
    b = np.array([[.2,.1,.2,.5],[.4,.2,.1,.3],[.3,.4,.2,.1],[.3,.05,.3,.35]])
    obs = np.array([0,0,3,1,1,2,1,3])
    #obs = np.array([2,0,2])
    path,delta,phi=Viterbi(pi,a,b,obs)
    print("path=",path)
    print("delta=",delta)
    print("phi=",phi)
    
test()
path= [3. 1. 0. 1. 0. 3. 2. 0.]
delta= [[5.00000e-02 1.20000e-02 3.60000e-03 1.44000e-04 1.29600e-05 1.20960e-06
  4.83840e-08 2.72160e-08]
 [1.00000e-01 1.20000e-02 1.08000e-03 2.16000e-04 8.64000e-06 3.88800e-07
  7.25760e-08 4.35456e-09]
 [7.50000e-02 5.62500e-03 1.87500e-04 1.44000e-04 8.64000e-06 2.59200e-07
  7.77600e-08 7.25760e-10]
 [7.50000e-02 7.50000e-03 1.05000e-03 3.60000e-05 2.70000e-06 7.77600e-07
  1.20960e-08 6.35040e-09]]
phi= [[0. 1. 1. 0. 1. 2. 0. 2.]
 [0. 3. 0. 0. 0. 0. 0. 0.]
 [0. 3. 3. 0. 1. 0. 3. 1.]
 [0. 1. 1. 0. 1. 0. 0. 1.]]

1)概率计算问题

forward算法计算过程:

1.初始化 $$\alpha_1(i)=\pi_ib_i(o_1),1 \leq i \leq N$$

2.归纳 $$\alpha_{t+1}(j)=[\sum_{i=1}^N\alpha_t(i)a_{ij}]b_j(o_{t+1}),1 \leq t \leq T-1,1 \leq j \leq N$$

3.中止 $$P(O|\lambda)=\sum_{i=1}^N\alpha_T(i)$$

In [40]:
scaling = False

def HMMfwd(pi,a,b,obs):
    
    nStates = np.shape(b)[0]
    T = np.shape(obs)[0]

    alpha = np.zeros((nStates,T))

    alpha[:,0] = pi*b[:,obs[0]]

    for t in range(1,T):
        for s in range(nStates):
            alpha[s,t] = b[s,obs[t]] * np.sum(alpha[:,t-1] * a[:,s])

    c = np.ones((T))
    if scaling:
        for t in range(T):
            c[t] = np.sum(alpha[:,t])
            alpha[:,t] /= c[t]
    return alpha,c
In [41]:
def HMMbwd(a,b,obs,c):

    nStates = np.shape(b)[0]
    T = np.shape(obs)[0]

    beta = np.zeros((nStates,T))

    beta[:,T-1] = 1.0 #aLast

    for t in range(T-2,-1,-1):
        for s in range(nStates):
            beta[s,t] = np.sum(b[:,obs[t+1]] * beta[:,t+1] * a[s,:])

    for t in range(T):
        beta[:,t] /= c[t]
    #beta[:,0] = b[:,obs[0]] * np.sum(beta[:,1] * pi)
    return beta

2)学习问题,也就是参数估计问题

分监督学习方法和非监督学习方法,如果样本包含隐藏状态序列和对应的可观察序列,可以通过监督方法估计;如果只有可观察序列,可以通过非监督方法估计,就是EM算法。对于监督方法很简单,具体请参见斯坦福的材料,这里忽略,下面实现的是EM算法。

In [42]:
def BaumWelch(obs,nStates):
    """
    obs:可观察序列
    nStates:状态序列
    """

    T = np.shape(obs)[0]
    xi = np.zeros((nStates,nStates,T))

    # Initialise pi, a, b randomly
    pi = 1./nStates*np.ones((nStates))
    a = np.random.rand(nStates,nStates)
    b = np.random.rand(nStates,np.max(obs)+1)

    tol = 1e-5
    error = tol+1
    maxits = 100
    nits = 0
    while ((error > tol) & (nits < maxits)):
        nits += 1
        oldpi = pi.copy()
        olda = a.copy()
        oldb = b.copy()

        # E step
        alpha,c = HMMfwd(pi,a,b,obs)
        beta = HMMbwd(a,b,obs,c) 

        for t in range(T-1):
            for i in range(nStates):
                for j in range(nStates):
                    xi[i,j,t] = alpha[i,t]*a[i,j]*b[j,obs[t+1]]*beta[j,t+1]
            xi[:,:,t] /= np.sum(xi[:,:,t])

        # The last step has no b, beta in
        for i in range(nStates):
            for j in range(nStates):
                xi[i,j,T-1] = alpha[i,T-1]*a[i,j]
        xi[:,:,T-1] /= np.sum(xi[:,:,T-1])
        
        # M step
        for i in range(nStates):
            pi[i] = np.sum(xi[i,:,0])
            for j in range(nStates):
                a[i,j] = np.sum(xi[i,j,:T-1])/np.sum(xi[i,:,:T-1])
            for k in range(max(obs)):
                found = (obs==k).nonzero()
                b[i,k] = np.sum(xi[i,:,found])/np.sum(xi[i,:,:])

        error = (np.abs(a-olda)).max() + (np.abs(b-oldb)).max() 
        print(nits, error, 1./np.sum(1./c), np.sum(alpha[:,T-1]))

    return pi, a, b
In [43]:
def biased_coins():
    a = np.array([[0.4,0.6],[0.9,0.1]])
    b = np.array([[0.49,0.51],[0.85,0.15]])
    pi = np.array([0.5,0.5])

    obs = np.array([0,1,1,0,1,1,0,0,1,1,0,1,1,1,0,0,1,0,0,1,1,0,1,1,1,1,0,1,0,0,1,0,1,0,0,1,1,1,0])
    print( Viterbi(pi,a,b,obs)[0])

    print (BaumWelch(obs,2))
    
    
biased_coins()
[1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1. 0. 0. 0. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0.
 0. 0. 1. 0. 1. 0. 0. 1. 0. 0. 1. 0. 0. 0. 1.]
1 0.9386213433467061 0.02564102564102564 4.008756231134437e-14
2 0.07852934255224686 0.02564102564102564 1.6314992099443197e-23
3 0.030341608093822803 0.02564102564102564 2.413081098560643e-23
4 0.019048428459106814 0.02564102564102564 2.4521774405500433e-23
5 0.01264642368071034 0.02564102564102564 2.2875609671881943e-23
6 0.010957837136204046 0.02564102564102564 2.084224285270683e-23
7 0.009454681037457946 0.02564102564102564 1.895277414353039e-23
8 0.0081981851965241 0.02564102564102564 1.7324381626240722e-23
9 0.007192945738328571 0.02564102564102564 1.5945489008984826e-23
10 0.006417882805379638 0.02564102564102564 1.477249441224474e-23
11 0.0058412241444092405 0.02564102564102564 1.3760212439518782e-23
12 0.005428555660587198 0.02564102564102564 1.287049494360456e-23
13 0.005147163770812591 0.02564102564102564 1.2073667064836976e-23
14 0.004968154958024523 0.02564102564102564 1.1347667440450871e-23
15 0.004867222993423903 0.02564102564102564 1.0676595812292863e-23
16 0.00495493100074304 0.02564102564102564 1.004927005805197e-23
17 0.005103773023962754 0.02564102564102564 9.45798962147311e-24
18 0.005235488897563173 0.02564102564102564 8.897547429924369e-24
19 0.005353950289643097 0.02564102564102564 8.36447375404165e-24
20 0.005462152850170737 0.02564102564102564 7.85647583484045e-24
21 0.0055624027666997256 0.02564102564102564 7.372033724054916e-24
22 0.005656453619530696 0.02564102564102564 6.91011605890383e-24
23 0.0057456150360361835 0.02564102564102564 6.4699851571260684e-24
24 0.005830842775322517 0.02564102564102564 6.051066906288771e-24
25 0.005912814323112442 0.02564102564102564 5.652866536206514e-24
26 0.005991991826585982 0.02564102564102564 5.274916141178338e-24
27 0.006068673528574542 0.02564102564102564 4.916743673080954e-24
28 0.006143034833729133 0.02564102564102564 4.577856112143838e-24
29 0.0062151602852589805 0.02564102564102564 4.257731759029469e-24
30 0.006285067843288639 0.02564102564102564 3.9558182192968276e-24
31 0.006352726869753494 0.02564102564102564 3.671533804506986e-24
32 0.006418071142474252 0.02564102564102564 3.404270871479124e-24
33 0.006481008071172301 0.02564102564102564 3.1534001596730247e-24
34 0.006541425103885223 0.02564102564102564 2.9182755423210383e-24
35 0.006599194121277574 0.02564102564102564 2.6982388368434634e-24
36 0.006654174438082722 0.02564102564102564 2.492624465681289e-24
37 0.0067062148763138985 0.02564102564102564 2.300763849134354e-24
38 0.006755155248373806 0.02564102564102564 2.1219894670183887e-24
39 0.006800827489449046 0.02564102564102564 1.955638559188448e-24
40 0.006843056604670106 0.02564102564102564 1.8010564547862244e-24
41 0.0068816615431608885 0.02564102564102564 1.657599531841905e-24
42 0.00691645607383784 0.02564102564102564 1.5246378159063494e-24
43 0.006947249712445536 0.02564102564102564 1.4015572306832856e-24
44 0.006973848732427568 0.02564102564102564 1.2877615163228911e-24
45 0.00699605728101696 0.02564102564102564 1.1826738327823372e-24
46 0.007013678614396307 0.02564102564102564 1.0857380668160683e-24
47 0.007026516460447246 0.02564102564102564 9.964198619265353e-25
48 0.007034376513475021 0.02564102564102564 9.14207391088066e-25
49 0.00703706806174062 0.02564102564102564 8.386118923122033e-25
50 0.00703440574522024 0.02564102564102564 7.6916798718090555e-25
51 0.007026211437498597 0.02564102564102564 7.054338023577147e-25
52 0.007012316242007266 0.02564102564102564 6.469909138096301e-25
53 0.006992562588834711 0.02564102564102564 5.934441330480132e-25
54 0.00696680641412642 0.02564102564102564 5.4442115413859585e-25
55 0.006934919399688794 0.02564102564102564 4.995720795521771e-25
56 0.006896791245923142 0.02564102564102564 4.585688421424515e-25
57 0.006852331946754858 0.02564102564102564 4.211045396618352e-25
58 0.006801474030960972 0.02564102564102564 3.8689269727376134e-25
59 0.006744174730399025 0.02564102564102564 3.556664725076551e-25
60 0.006680418032323374 0.02564102564102564 3.271778160451249e-25
61 0.006610216570412275 0.02564102564102564 3.0119660063835562e-25
62 0.006533613307578637 0.02564102564102564 2.7750972935903374e-25
63 0.006450682963267743 0.02564102564102564 2.559202332720151e-25
64 0.006361533138931438 0.02564102564102564 2.36246367535366e-25
65 0.006272296823248466 0.02564102564102564 2.1832071385947603e-25
66 0.006179147218761005 0.02564102564102564 2.0198929622281034e-25
67 0.006079809311762524 0.02564102564102564 1.8711071575036666e-25
68 0.005974517059374199 0.02564102564102564 1.7355530972023638e-25
69 0.005863539519322461 0.02564102564102564 1.6120433878066015e-25
70 0.0057471795747110616 0.02564102564102564 1.4994920563950716e-25
71 0.005625772156587647 0.02564102564102564 1.396907077332769e-25
72 0.00549968197646003 0.02564102564102564 1.3033832569625354e-25
73 0.005369300793738874 0.02564102564102564 1.218095488327963e-25
74 0.00523504425573923 0.02564102564102564 1.1402923824662667e-25
75 0.005097348359875381 0.02564102564102564 1.0692902779934817e-25
76 0.004956665598478682 0.02564102564102564 1.0044676265372616e-25
77 0.004813460855825381 0.02564102564102564 9.452597480258866e-26
78 0.004668207134017263 0.02564102564102564 8.911539468810399e-26
79 0.004521381189028409 0.02564102564102564 8.416849777427182e-26
80 0.00437345916029671 0.02564102564102564 7.964308474363015e-26
81 0.00422491227658886 0.02564102564102564 7.550089384256886e-26
82 0.004076202717613407 0.02564102564102564 7.170724379366129e-26
83 0.003927779705098333 0.02564102564102564 6.823070562338373e-26
84 0.003780075889144724 0.02564102564102564 6.50428017148511e-26
85 0.0036335040859663617 0.02564102564102564 6.211773038342186e-26
86 0.0034884544120881433 0.02564102564102564 5.943211428399657e-26
87 0.0033452918482076294 0.02564102564102564 5.696477098869846e-26
88 0.003204354253731688 0.02564102564102564 5.469650411885383e-26
89 0.003065950840937137 0.02564102564102564 5.260991347247951e-26
90 0.0029303611062775937 0.02564102564102564 5.0689222655033413e-26
91 0.002797834205860234 0.02564102564102564 4.892012279441813e-26
92 0.002668588752933583 0.02564102564102564 4.7289630998894946e-26
93 0.0025428130074957693 0.02564102564102564 4.5785962296876255e-26
94 0.00242066542202864 0.02564102564102564 4.439841387879782e-26
95 0.00230227550290385 0.02564102564102564 4.311726054223744e-26
96 0.002187744944162795 0.02564102564102564 4.193366032093164e-26
97 0.002077148989047529 0.02564102564102564 4.083956935562395e-26
98 0.00197053797469926 0.02564102564102564 3.982766513904873e-26
99 0.0018679390166790877 0.02564102564102564 3.8891277338273446e-26
100 0.0017693577921808826 0.02564102564102564 3.8024325464894023e-26
(array([3.14614908e-105, 1.00000000e+000]), array([[0.49573803, 0.50426197],
       [0.68708636, 0.31291364]]), array([[0.01620177, 0.2160895 ],
       [0.97444184, 0.00623026]]))

参考资料

Machine Learning: An Algorithmic Perspective

斯坦福的材料

李航《统计机器学习》