如何用TensorFlow构建RNN?这里有一份极简的教程(2)
2017-04-30 编辑:
在这段代码中,我们通过计算current_input Wa + current_state Wbin,得到两个仿射变换的总和input_and_state_concatenated。在连接这两个张量后,只用了一个矩阵乘法即可在每个批次中添加所有样本的偏置b。
图5:第8行代码的矩阵计算示意图,省略了非线性变换arctan。
你可能会想知道变量truncated_backprop_lengthis的作用。在训练时,RNN被看做是一种在每一层都有冗余权重的深层神经网络。在训练开始时,这些层由于展开后占据了太多的计算资源,因此要在有限的时间步内截断。在每个批次训练时,网络误差反向传播了三次。
计算Loss
这是计算图的最后一部分,我们建立了一个从状态到输出的全连接层,用于softmax分类,标签采用One-hot编码,用于计算每个批次的Loss。
logits_series = [tf.matmul(state, W2) + b2 forstate instates_series] #Broadcasted additionpredictions_series = [tf.nn.softmax(logits) forlogits inlogits_series]losses = [tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels) forlogits, labels inzip(logits_series,labels_series)]total_loss = tf.reduce_mean(losses)train_step = tf.train.AdagradOptimizer( 0.3).minimize(total_loss)
最后一行是添加训练函数,TensorFlow将自动执行反向传播函数:对每批数据执行一次计算图,并逐步更新网络权重。
这里调用的tosparse_softmax_cross_entropy_with_logits函数,能在内部算得softmax函数值后,继续计算交叉熵。在示例中,各类是互斥的,非0即1,这也是将要采用稀疏自编码的原因。标签的格式为[batch_size,num_classes]。
可视化结果
我们利用可视化功能tensorboard,在训练过程中观察网络训练情况。它将会在时间维度上绘制Loss值,显示在训练批次中数据输入、数据输出和网络结构对不同样本的实时预测效果。
defplot(loss_list, predictions_series, batchX, batchY):plt.subplot( 2, 3, 1) plt.cla() plt.plot(loss_list) forbatch_series_idx inrange( 5): one_hot_output_series = np.array(predictions_series)[:, batch_series_idx, :] single_output_series = np.array([( 1ifout[ 0] < 0.5else0) forout inone_hot_output_series]) plt.subplot( 2, 3, batch_series_idx + 2) plt.cla() plt.axis([ 0, truncated_backprop_length, 0, 2]) left_offset = range(truncated_backprop_length) plt.bar(left_offset, batchX[batch_series_idx, :], width= 1, color= "blue") plt.bar(left_offset, batchY[batch_series_idx, :] * 0.5, width= 1, color= "red") plt.bar(left_offset, single_output_series * 0.3, width= 1, color= "green") plt.draw() plt.pause( 0.0001) 建立训练会话
已经完成构建网络的工作,开始训练网络。在TensorFlow中,该计算图会在一个会话中执行。在每一步开始时,都会随机生成新的数据。
withtf.Session() assess: sess.run(tf.initialize_all_variables()) plt.ion() plt.figure() plt.show() loss_list = [] forepoch_idx inrange(num_epochs): x,y = generateData() _current_state = np.zeros((batch_size, state_size)) print( "New data, epoch", epoch_idx) forbatch_idx inrange(num_batches): start_idx = batch_idx * truncated_backprop_length end_idx = start_idx + truncated_backprop_length batchX = x[:,start_idx:end_idx] batchY = y[:,start_idx:end_idx] _total_loss, _train_step, _current_state, _predictions_series = sess.run( [total_loss, train_step, current_state, predictions_series], feed_dict={ batchX_placeholder:batchX, batchY_placeholder:batchY, init_state:_current_state }) loss_list.append(_total_loss) ifbatch_idx% 100== 0: print( "Step",batch_idx, "Loss", _total_loss) plot(loss_list, _predictions_series, batchX, batchY)plt.ioff()plt.show()
从第15-19行可以看出,在每次迭代中往前移动truncated_backprop_length步,但可能有不同的stride值。这样做的缺点是,为了封装相关的训练数据,truncated_backprop_length的值要显著大于时间依赖值(本文中为3步),否则可能会丢失很多有效信息,如图6所示。
图6:数据示意图
我们用多个正方形来代表时间序列,上升的黑色方块表示回波输出,由输入回波(黑色方块)经过三次激活后得到。滑动批处理窗口在每次运行时也滑动了三次,在示例中之前没有任何批数据,用来封装依赖关系,因此它不能进行训练。
请注意,本文只是用一个简单示例解释了RNN如何工作,可以轻松地用几行代码中来实现此网络。此网络将能够准确地了解回声行为,因此不需要任何测试数据。
在训练过程中,该程序实时更新图表,如图7所示。蓝色条表示用于训练的输入信号,红色条表示训练得到的输出回波,绿色条是RNN网络产生的预测回波。不同的条形图显示了在当前批次中多个批数据的预测回波。
我们的算法能很快地完成训练任务。左上角的图表输出了损失函数,但为什么曲线上有尖峰?答案就在下面。
图7:各图分别为Loss,训练的输入和输出数据(蓝色和红色)以及预测回波(绿色)。
尖峰的产生原因是在新的迭代开始时,会产生新的数据。由于矩阵重构,每行上的第一个元素与上一行中的最后一个元素会相邻。但是所有行中的前几个元素(第一个除外)都具有不包含在该状态中的依赖关系,因此在最开始的批处理中,网络的预测功能不良。
整个程序
这是完整实现RNN网络的程序,只需复制粘贴即可运行。如果对文章有什么疑问,欢迎加量子位小助手qbitbot,注明“加入门群”并做个自我介绍,小助手将带你和更多小伙伴交流讨论。