Debugging TensorFlow's seq2seq attention_decoder method (debug mode)




Here we write a test case for attention_decoder and step through it in debug mode to see how the attention mechanism is implemented.

import tensorflow as tf
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
from tensorflow.contrib.legacy_seq2seq.python.ops import seq2seq as seq2seq_lib

with tf.Session() as sess:
    batch_size = 16
    step1 = 20       # encoder time steps
    step2 = 10       # decoder time steps
    input_size = 50
    output_size = 40
    gru_hidden = 30

    cell_fn = lambda: rnn_cell.GRUCell(gru_hidden)
    cell = cell_fn()

    # Encoder: a list of step1 constant inputs fed through a static GRU.
    inp = [tf.constant(0.8, shape=[batch_size, input_size])] * step1
    enc_outputs, enc_state = rnn.static_rnn(cell, inp, dtype=tf.float32)

    # Stack the encoder outputs into [batch_size, step1, gru_hidden];
    # these are the states the decoder attends over.
    attn_states = tf.concat([
        tf.reshape(e, [-1, 1, cell.output_size]) for e in enc_outputs
    ], 1)

    # Decoder: step2 constant inputs, with attention over attn_states.
    dec_inp = [tf.constant(0.3, shape=[batch_size, output_size])] * step2
    dec, mem = seq2seq_lib.attention_decoder(
        dec_inp, enc_state, attn_states, cell_fn(), output_size=7)

    sess.run([tf.global_variables_initializer()])
    res = sess.run(dec)
    print(len(res))        # 10 decoder steps
    print(res[0].shape)    # (16, 7)
    res = sess.run([mem])
    print(len(res))        # 1
    print(res[0].shape)    # (16, 30) -- final GRU state
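Running this prints the decoder output and state shapes: dec is a list of step2 = 10 tensors, each of shape (16, 7) (batch_size by the output_size passed to attention_decoder), and mem is the final GRU state of shape (16, 30).

To make the debugging session easier to follow, here is a minimal NumPy sketch of the additive (Bahdanau-style) attention computed inside attention_decoder at each decoder step. The names W1, W2, and v are illustrative stand-ins, not the library's variable names; the real implementation builds the equivalent projections with conv2d and linear layers.

import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def attention(hidden, query, W1, W2, v):
    # hidden: [batch, steps, units] -- the encoder outputs (attn_states)
    # query:  [batch, units]        -- the current decoder state
    features = hidden @ W1                     # [batch, steps, attn]
    y = (query @ W2)[:, None, :]               # [batch, 1, attn]
    scores = np.tanh(features + y) @ v         # [batch, steps]
    a = softmax(scores)                        # attention weights
    context = (a[:, :, None] * hidden).sum(1)  # [batch, units]
    return context, a

batch, steps, units, attn = 16, 20, 30, 30
rng = np.random.default_rng(0)
hidden = rng.standard_normal((batch, steps, units))
query = rng.standard_normal((batch, units))
W1 = rng.standard_normal((units, attn))
W2 = rng.standard_normal((units, attn))
v = rng.standard_normal(attn)
context, a = attention(hidden, query, W1, W2, v)
print(context.shape, a.shape)  # (16, 30) (16, 20)

When stepping through attention_decoder in the debugger, the tensors corresponding to a (the softmax weights over the 20 encoder steps) and context are what the attention mechanism contributes at each decoder step.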

Adapted from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/seq2seq_test.py

