动态RNN的小试验

In [36]:
import tensorflow as tf
import numpy as np

1)tensorflow rnn的输入至少为三维变量

In [37]:
X1=np.random.randn(2,10,1)
X8=np.random.randn(2,10,8)

X_lengths=[10,10]
In [38]:
print("X1 shape:",X1.shape)
print("X1=")
print(X1)

print("===============================================================")
print("X8 shape:",X8.shape)
print("X8=")
print(X8)
X1 shape: (2, 10, 1)
X1=
[[[ 0.21863186]
  [-1.65295463]
  [ 1.56655155]
  [ 0.5222816 ]
  [ 0.40690083]
  [ 0.26172606]
  [ 0.3209287 ]
  [ 1.13719127]
  [-0.44942696]
  [-1.79406094]]

 [[-0.4786805 ]
  [-0.31242633]
  [-0.80590554]
  [-1.18652789]
  [-0.56742282]
  [-0.1177505 ]
  [-0.3252891 ]
  [ 0.6774859 ]
  [-0.08408594]
  [ 0.1428919 ]]]
===============================================================
X8 shape: (2, 10, 8)
X8=
[[[ 0.6993111  -1.03759525  0.96471712  0.4055671   0.06724576 -0.73234351
   -1.65259831 -0.33445168]
  [-0.76294107  0.27792604  1.80677533 -0.54898853 -0.15724918 -0.99214043
   -0.28294209 -0.1492098 ]
  [-1.08052042  0.7962418   0.79272978 -0.12839983  0.52012404 -0.56452167
    0.25166734  0.38963217]
  [ 0.1729335   0.15945848 -1.65587935  0.32703981  0.14747156  0.35771881
   -0.11315079 -1.87676416]
  [-0.89611976 -0.31080159 -0.2799201  -1.10999613 -1.91726575  0.28003694
    0.02116266 -0.8215498 ]
  [ 1.16148319 -0.19632122  0.93964647 -1.46916089  0.7555181   0.54017628
   -0.8490647  -0.33395773]
  [ 0.37462454  1.0941274   1.95651376  0.5086241  -0.57443902  0.02805234
    0.2186624   0.0224864 ]
  [-2.2021759   0.95334633 -0.98244747  1.09360796  1.47853456  0.59936369
    1.79625866  0.41004887]
  [ 1.11163804  0.67475061  0.94647743  0.41051424 -0.16879005 -0.45637061
   -3.44572894 -1.12267569]
  [-0.34956491 -1.29691414  0.54098089 -1.1670317   0.39357543  0.68294119
    0.4732811   0.30467678]]

 [[-0.227536    1.42238399  0.55889734  0.95329206  0.53208999  1.05176877
    0.16059461  0.99112068]
  [-1.13586647  1.77323792  0.98764734  0.00603953  0.25806347  0.95925656
    0.20635804 -0.4331286 ]
  [ 0.46437875  0.4359869   0.27884278 -1.41788015 -0.24709798 -0.23222798
    0.12371822  0.41427335]
  [-1.10779625 -0.22371667  0.34149574 -0.90742223  1.93596188 -0.68002773
   -0.95494143  0.19367702]
  [-0.08433352 -1.42185208  0.25863714 -0.26926121 -0.12735508 -1.33432228
   -0.84560071 -0.84693066]
  [-0.70293446  0.23442055 -1.42384767 -1.52857024 -2.10605749  0.11991908
    1.61639467 -0.11030921]
  [ 0.77732251  0.02204307 -0.15399274  0.19140287 -2.04614395  0.09435536
   -1.72564799 -0.22644221]
  [-1.7679457   0.30533081  0.43704707  0.97786031 -0.66457351  1.56123384
    0.0213741  -0.7309844 ]
  [ 0.95032232 -1.43640063  0.21010823  0.25951039 -0.65137272 -1.80788741
   -1.93689171  0.69941073]
  [-0.96610284  0.19763549 -1.98497871  0.80398599 -0.05450722  0.99892764
    0.88444381 -0.43868934]]]

2)定义动态RNN模型

In [39]:
def models(inputs,seq_len):
    tf.reset_default_graph()
    #定义LSTM基本单元
    cell=tf.contrib.rnn.BasicLSTMCell(num_units=64,state_is_tuple=True)
    #定义动态RNN
    outputs,last_states=tf.nn.dynamic_rnn(cell=cell,dtype=tf.float64,sequence_length=seq_len,inputs=inputs)
    return outputs,last_states

2.1)分析模型的输出

outputs的输出第一维和第二维跟输入一样,第三维是隐藏单元的维数 last_states输出LSTMStateTuple,LSTMStateTuple包含RNN的状态单元c和最后的输出单元h,h和outputs的最后的输出一样

In [40]:
outputs,last_states=models(X1,X_lengths)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    o,l_s=sess.run([outputs,last_states])
    
    print("outputs shape:",o.shape)
    print("last state shape is LSTMStateTuple")
    
    print(o)
    print(l_s)
outputs shape: (2, 10, 64)
last state shape is LSTMStateTuple
[[[ 0.00626649  0.0009611   0.00534746 ...,  0.00208865 -0.00649064
    0.00357219]
  [-0.04391281 -0.0073861  -0.0437953  ..., -0.01259916  0.03978827
   -0.02345949]
  [ 0.01417123  0.00119387  0.0075347  ...,  0.01026246 -0.0066327
    0.00558009]
  ..., 
  [ 0.05204537  0.00746746  0.04042266 ...,  0.01179987 -0.06793874
    0.03744727]
  [ 0.02566657  0.00302749  0.02012766 ..., -0.00151092 -0.03868935
    0.02285022]
  [-0.03452437 -0.00499901 -0.03565577 ..., -0.01916861  0.02248779
   -0.01072146]]

 [[-0.01392682 -0.00217219 -0.01255407 ..., -0.00432957  0.01379682
   -0.00775427]
  [-0.01909943 -0.00270576 -0.01685329 ..., -0.00491802  0.02104834
   -0.01147927]
  [-0.03757103 -0.00548016 -0.03429469 ..., -0.00885125  0.03959536
   -0.02239637]
  ..., 
  [-0.01144903 -0.00637778 -0.01224874 ...,  0.01448941  0.02699594
   -0.01665722]
  [-0.01015557 -0.00796074 -0.01272366 ...,  0.01262847  0.02023091
   -0.01360148]
  [-0.00242926 -0.00788853 -0.00628072 ...,  0.01278608  0.01046499
   -0.00716096]]]
LSTMStateTuple(c=array([[-0.0652685 , -0.01030961, -0.06694325, -0.04646335, -0.06420023,
         0.01402679,  0.00767705, -0.07669409, -0.02672263, -0.07835473,
        -0.04611041, -0.00898107,  0.05946486, -0.02540907, -0.02328935,
         0.00620828,  0.02873059,  0.01682825,  0.02212377,  0.02181499,
        -0.0392472 ,  0.02476714,  0.07351189,  0.02137753, -0.04988622,
         0.01139808,  0.05024492, -0.04926311,  0.06373495, -0.06927173,
        -0.01070558,  0.07093576,  0.0391095 ,  0.08635727,  0.0195561 ,
         0.00593428,  0.064782  ,  0.04987822, -0.03134515, -0.07507392,
         0.01878588, -0.06646294,  0.07636915, -0.00520318, -0.02878028,
         0.03327519,  0.05448371,  0.02251081,  0.01642887, -0.06545384,
        -0.00447417, -0.03576494, -0.07626746,  0.02912251,  0.0634632 ,
         0.0163746 ,  0.10320128,  0.01024154, -0.05213877, -0.03146701,
         0.02218021, -0.04005124,  0.04889002, -0.02135456],
       [-0.00491765, -0.01570205, -0.0125662 , -0.02518407,  0.01415194,
         0.01469218, -0.01008905, -0.00752436, -0.03140909,  0.00477025,
         0.02931672,  0.02952346,  0.00639036,  0.02130706, -0.07305194,
         0.0305481 ,  0.01031087, -0.01953439, -0.04294066,  0.01615684,
        -0.00436108,  0.02384418, -0.03033282, -0.00848928,  0.0180987 ,
         0.02033322, -0.00631905, -0.02345694, -0.0199582 , -0.01320774,
        -0.00347512, -0.01415094,  0.03584482, -0.00589683, -0.01736072,
        -0.00710504, -0.02238686,  0.01384733,  0.00138563,  0.03839992,
         0.02976193,  0.04380168, -0.01957645, -0.04878634, -0.01660236,
         0.02167042,  0.03429299, -0.02857568,  0.03398872,  0.02104907,
        -0.02021614, -0.01701037,  0.01056921, -0.00673481, -0.02695598,
        -0.06012242, -0.01838759,  0.00632867,  0.00585587, -0.01305196,
         0.02342279,  0.02583127,  0.02080014, -0.01439244]]), h=array([[-0.03452437, -0.00499901, -0.03565577, -0.02120707, -0.03008585,
         0.00714468,  0.00391893, -0.04070923, -0.01483002, -0.03875778,
        -0.02450672, -0.0047048 ,  0.03246713, -0.01138599, -0.01256368,
         0.00282007,  0.01513757,  0.0087677 ,  0.01230584,  0.0113465 ,
        -0.01776578,  0.01243137,  0.03420777,  0.01177373, -0.02599306,
         0.00542281,  0.02352464, -0.02353532,  0.0327239 , -0.03403247,
        -0.00512684,  0.03802466,  0.01828889,  0.04549279,  0.00904329,
         0.00286779,  0.03548155,  0.02792373, -0.01695042, -0.04235431,
         0.00843343, -0.031168  ,  0.03949962, -0.00278506, -0.01446888,
         0.0159831 ,  0.02754722,  0.01230363,  0.00923642, -0.02948304,
        -0.00237604, -0.01684213, -0.03930688,  0.01568881,  0.03492925,
         0.00765215,  0.05713425,  0.00561899, -0.02499399, -0.01659339,
         0.01182197, -0.01916861,  0.02248779, -0.01072146],
       [-0.00242926, -0.00788853, -0.00628072, -0.01272494,  0.00713226,
         0.00737854, -0.00505453, -0.00375309, -0.01560966,  0.00236942,
         0.01456126,  0.01471894,  0.00318282,  0.01074783, -0.03632483,
         0.01550408,  0.00513769, -0.00972087, -0.02115354,  0.00810704,
        -0.00220623,  0.0120104 , -0.01513617, -0.00423673,  0.00900583,
         0.0102075 , -0.00317651, -0.01170487, -0.00999592, -0.00667419,
        -0.00174175, -0.00706933,  0.01787125, -0.00294105, -0.00868059,
        -0.00357683, -0.01108803,  0.00683414,  0.00069027,  0.01889004,
         0.01484512,  0.02202374, -0.00974516, -0.02429408, -0.00825032,
         0.01085691,  0.01707769, -0.01425047,  0.01676392,  0.01057519,
        -0.01003287, -0.00858159,  0.00531016, -0.00334643, -0.01342008,
        -0.03029744, -0.00904571,  0.00312976,  0.00291355, -0.00646035,
         0.0116739 ,  0.01278608,  0.01046499, -0.00716096]]))
In [41]:
print(l_s[1].shape)
print(l_s[1])
(2, 64)
[[-0.03452437 -0.00499901 -0.03565577 -0.02120707 -0.03008585  0.00714468
   0.00391893 -0.04070923 -0.01483002 -0.03875778 -0.02450672 -0.0047048
   0.03246713 -0.01138599 -0.01256368  0.00282007  0.01513757  0.0087677
   0.01230584  0.0113465  -0.01776578  0.01243137  0.03420777  0.01177373
  -0.02599306  0.00542281  0.02352464 -0.02353532  0.0327239  -0.03403247
  -0.00512684  0.03802466  0.01828889  0.04549279  0.00904329  0.00286779
   0.03548155  0.02792373 -0.01695042 -0.04235431  0.00843343 -0.031168
   0.03949962 -0.00278506 -0.01446888  0.0159831   0.02754722  0.01230363
   0.00923642 -0.02948304 -0.00237604 -0.01684213 -0.03930688  0.01568881
   0.03492925  0.00765215  0.05713425  0.00561899 -0.02499399 -0.01659339
   0.01182197 -0.01916861  0.02248779 -0.01072146]
 [-0.00242926 -0.00788853 -0.00628072 -0.01272494  0.00713226  0.00737854
  -0.00505453 -0.00375309 -0.01560966  0.00236942  0.01456126  0.01471894
   0.00318282  0.01074783 -0.03632483  0.01550408  0.00513769 -0.00972087
  -0.02115354  0.00810704 -0.00220623  0.0120104  -0.01513617 -0.00423673
   0.00900583  0.0102075  -0.00317651 -0.01170487 -0.00999592 -0.00667419
  -0.00174175 -0.00706933  0.01787125 -0.00294105 -0.00868059 -0.00357683
  -0.01108803  0.00683414  0.00069027  0.01889004  0.01484512  0.02202374
  -0.00974516 -0.02429408 -0.00825032  0.01085691  0.01707769 -0.01425047
   0.01676392  0.01057519 -0.01003287 -0.00858159  0.00531016 -0.00334643
  -0.01342008 -0.03029744 -0.00904571  0.00312976  0.00291355 -0.00646035
   0.0116739   0.01278608  0.01046499 -0.00716096]]
In [42]:
print(o[0,9].shape)
print(o[0,9])
(64,)
[-0.03452437 -0.00499901 -0.03565577 -0.02120707 -0.03008585  0.00714468
  0.00391893 -0.04070923 -0.01483002 -0.03875778 -0.02450672 -0.0047048
  0.03246713 -0.01138599 -0.01256368  0.00282007  0.01513757  0.0087677
  0.01230584  0.0113465  -0.01776578  0.01243137  0.03420777  0.01177373
 -0.02599306  0.00542281  0.02352464 -0.02353532  0.0327239  -0.03403247
 -0.00512684  0.03802466  0.01828889  0.04549279  0.00904329  0.00286779
  0.03548155  0.02792373 -0.01695042 -0.04235431  0.00843343 -0.031168
  0.03949962 -0.00278506 -0.01446888  0.0159831   0.02754722  0.01230363
  0.00923642 -0.02948304 -0.00237604 -0.01684213 -0.03930688  0.01568881
  0.03492925  0.00765215  0.05713425  0.00561899 -0.02499399 -0.01659339
  0.01182197 -0.01916861  0.02248779 -0.01072146]
In [43]:
print(o[1,9].shape)
print(o[0,9])
(64,)
[-0.03452437 -0.00499901 -0.03565577 -0.02120707 -0.03008585  0.00714468
  0.00391893 -0.04070923 -0.01483002 -0.03875778 -0.02450672 -0.0047048
  0.03246713 -0.01138599 -0.01256368  0.00282007  0.01513757  0.0087677
  0.01230584  0.0113465  -0.01776578  0.01243137  0.03420777  0.01177373
 -0.02599306  0.00542281  0.02352464 -0.02353532  0.0327239  -0.03403247
 -0.00512684  0.03802466  0.01828889  0.04549279  0.00904329  0.00286779
  0.03548155  0.02792373 -0.01695042 -0.04235431  0.00843343 -0.031168
  0.03949962 -0.00278506 -0.01446888  0.0159831   0.02754722  0.01230363
  0.00923642 -0.02948304 -0.00237604 -0.01684213 -0.03930688  0.01568881
  0.03492925  0.00765215  0.05713425  0.00561899 -0.02499399 -0.01659339
  0.01182197 -0.01916861  0.02248779 -0.01072146]

3)对输出进行flatten,会变成[batch_size,维数]二维矩阵

In [44]:
#outputs,last_states=models(X1,X_lengths)
#logits=tf.contrib.layers.fully_connected(outputs,1)
logits=tf.contrib.layers.flatten(outputs)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    l=sess.run(logits)
    print(l.shape)
    print(l)
(2, 640)
[[-0.00351074  0.00475258  0.00540732 ..., -0.02896367 -0.01072794
  -0.03872284]
 [ 0.00766588 -0.00999981 -0.01216001 ...,  0.03220017 -0.01098931
   0.00163596]]

4) 对输入进行full_connected连接,连接后最后一维的维数由full_connected的num_outputs参数指定

In [45]:
#outputs,last_states=models(X1,X_lengths)
logits=tf.contrib.layers.fully_connected(outputs,num_outputs=2)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    l=sess.run(logits)
    print(l.shape)
    print(l)
(2, 10, 2)
[[[ 0.00668481  0.        ]
  [ 0.          0.        ]
  [ 0.          0.        ]
  [ 0.01440622  0.        ]
  [ 0.02320003  0.        ]
  [ 0.02492307  0.        ]
  [ 0.02750982  0.        ]
  [ 0.05138999  0.        ]
  [ 0.01907409  0.        ]
  [ 0.          0.        ]]

 [[ 0.          0.00012225]
  [ 0.          0.00208919]
  [ 0.          0.00504782]
  [ 0.          0.01052497]
  [ 0.          0.01591998]
  [ 0.          0.01979976]
  [ 0.          0.02253833]
  [ 0.          0.0207775 ]
  [ 0.          0.01808157]
  [ 0.00128833  0.01482664]]]

5) 对logits进行softmax处理

In [46]:
#outputs,last_states=models(X1,X_lengths)
logits=tf.contrib.layers.fully_connected(outputs,num_outputs=2)
probs=tf.nn.softmax(logits)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    p=sess.run(probs)
    print(p.shape)
    print(p)
(2, 10, 2)
[[[ 0.50183033  0.49816967]
  [ 0.49730752  0.50269248]
  [ 0.50426361  0.49573639]
  [ 0.50654158  0.49345842]
  [ 0.50721053  0.49278947]
  [ 0.50661469  0.49338531]
  [ 0.50698204  0.49301796]
  [ 0.5141027   0.4858973 ]
  [ 0.50572122  0.49427878]
  [ 0.5         0.5       ]]

 [[ 0.49905257  0.50094743]
  [ 0.49875216  0.50124784]
  [ 0.49743443  0.50256557]
  [ 0.49560732  0.50439268]
  [ 0.49534884  0.50465116]
  [ 0.49568637  0.50431363]
  [ 0.49489893  0.50510107]
  [ 0.49653682  0.50346318]
  [ 0.49545958  0.50454042]
  [ 0.49541822  0.50458178]]]