def output(self, rnn_output):
  return melt.dense(rnn_output, self.w, self.v)
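#__init__ below builds the shared output projection (w_t, its transpose w, and biases v,
#used by output() above and by output_fn) and configures one of four decode modes based
#on FLAGS: gen only (plain vocab softmax, optionally sampled), gen_copy
#(gen_logits + copy_logits), gen_copy_switch (a gen_probability-weighted mixture of the
#two softmaxed distributions), or copy only (attention alignments scattered over the vocab).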
def __init__(self, is_training=True, is_predict=False):
  self.scope = 'rnn'
  self.is_training = is_training
  self.is_predict = is_predict

  assert not (FLAGS.decode_copy and FLAGS.decode_use_alignment)

  vocabulary.init()
  vocab_size = vocabulary.get_vocab_size()
  self.vocab_size = vocab_size

  self.end_id = vocabulary.end_id()
  self.start_id = None
  self.get_start_id()
  assert self.end_id != vocabulary.vocab.unk_id(), 'input vocab generated without end id'

  self.emb_dim = emb_dim = FLAGS.emb_dim

  #--- for the perf problem here, exchange w_t and w: https://github.com/tensorflow/tensorflow/issues/4138
  self.num_units = num_units = FLAGS.rnn_hidden_size
  with tf.variable_scope('output_projection'):
    self.w_t = melt.variable.get_weights_truncated(
        'w_t', [vocab_size, num_units], stddev=FLAGS.weight_stddev)
    #weights
    self.w = tf.transpose(self.w_t)
    #biases
    self.v = melt.variable.get_weights_truncated(
        'v', [vocab_size], stddev=FLAGS.weight_stddev)

  #TODO https://github.com/tensorflow/tensorflow/issues/6761 tf 1.0 will fail if scope is not 'rnn', the same as when using self.cell...
  self.cell = melt.create_rnn_cell(num_units=num_units,
                                   is_training=is_training,
                                   keep_prob=FLAGS.keep_prob,
                                   num_layers=FLAGS.num_layers,
                                   cell_type=FLAGS.cell)

  self.num_sampled = num_sampled = FLAGS.num_sampled if not (
      is_predict and FLAGS.predict_no_sample) else 0
  #self.softmax_loss_function being None means no sampling is needed
  self.softmax_loss_function = None
  if FLAGS.gen_only:
    self.softmax_loss_function = melt.seq2seq.gen_sampled_softmax_loss_function(
        num_sampled,
        self.vocab_size,
        weights=self.w_t,
        biases=self.v,
        log_uniform_sample=FLAGS.log_uniform_sample,
        is_predict=self.is_predict,
        sample_seed=FLAGS.predict_sample_seed,
        vocabulary=vocabulary)

  if FLAGS.use_attention:
    print('----attention_option:', FLAGS.attention_option)

  if FLAGS.gen_copy_switch or FLAGS.gen_copy or FLAGS.copy_only:
    assert FLAGS.use_attention is True, 'must use attention if not in gen_only mode seq2seq'
    FLAGS.gen_only = False
    if FLAGS.gen_copy_switch:
      print('-------gen copy switch mode!')
      FLAGS.gen_copy = False
      FLAGS.copy_only = False
    elif FLAGS.gen_copy:
      print('-------gen copy mode!')
      FLAGS.copy_only = False
    else:
      print('-------copy only mode!')
  else:
    print('--------gen only mode')

  #if using copy mode, use the score as alignment (no softmax)
  self.score_as_alignment = False if FLAGS.gen_only else True

  #gen only output_fn
  self.output_fn = lambda cell_output: melt.dense(cell_output, self.w, self.v)

  def copy_output(indices, batch_size, cell_output, cell_state):
    alignments = cell_state.alignments
    updates = alignments
    return tf.scatter_nd(indices, updates, shape=[batch_size, self.vocab_size])

  self.copy_output_fn = copy_output

  #one problem is big memory for large vocabulary
  def gen_copy_output(indices, batch_size, cell_output, cell_state):
    gen_logits = self.output_fn(cell_output)
    copy_logits = copy_output(indices, batch_size, cell_output, cell_state)

    if FLAGS.gen_copy_switch:
      gen_probability = cell_state.gen_probability
      #[batch_size, 1] * [batch_size, vocab_size]
      return gen_probability * tf.nn.softmax(gen_logits) + (
          1 - gen_probability) * tf.nn.softmax(copy_logits)
    else:
      return gen_logits + copy_logits

  self.gen_copy_output_fn = gen_copy_output

  def gen_copy_output_train(time, indices, targets, sampled_values,
                            batch_size, cell_output, cell_state):
    if self.softmax_loss_function is not None:
      labels = tf.slice(targets, [0, time], [-1, 1])

      sampled, true_expected_count, sampled_expected_count = sampled_values
      sampled_values = \
        sampled, tf.slice(tf.reshape(true_expected_count, [batch_size, -1]),
                          [0, time], [-1, 1]), sampled_expected_count
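      #labels and true_expected_count above are sliced to the current time step;
      #below, logits are computed only for the sampled ids and then scattered back
      #into a dense [batch_size, vocab_size] tensor so they can be combined with
      #copy_logits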
      sampled_ids, sampled_logits = melt.nn.compute_sampled_ids_and_logits(
          weights=self.w_t,
          biases=self.v,
          labels=labels,
          inputs=cell_output,
          num_sampled=self.num_sampled,
          num_classes=self.vocab_size,
          sampled_values=sampled_values,
          remove_accidental_hits=False)
      gen_indices = melt.batch_values_to_indices(tf.to_int32(sampled_ids))
      gen_logits = tf.scatter_nd(gen_indices, sampled_logits,
                                 shape=[batch_size, self.vocab_size])
    else:
      gen_logits = self.output_fn(cell_output)

    copy_logits = copy_output(indices, batch_size, cell_output, cell_state)

    if FLAGS.gen_copy_switch:
      #gen_copy_switch == True.
      gen_probability = cell_state.gen_probability
      return gen_probability * tf.nn.softmax(gen_logits) + (
          1 - gen_probability) * tf.nn.softmax(copy_logits)
    else:
      return gen_logits + copy_logits

  self.gen_copy_output_train_fn = gen_copy_output_train
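  #Note on the gen_copy_switch mixture used by gen_copy_output and
  #gen_copy_output_train above (illustrative numbers only, hypothetical):
  #  gen_probability       : [batch_size, 1],          e.g. [[0.8]]
  #  softmax(gen_logits)   : [batch_size, vocab_size], e.g. [[0.7, 0.2, 0.1]]
  #  softmax(copy_logits)  : [batch_size, vocab_size], e.g. [[0.0, 0.5, 0.5]]
  #  mixture = 0.8 * [0.7, 0.2, 0.1] + 0.2 * [0.0, 0.5, 0.5]
  #          = [0.56, 0.26, 0.18]
  #The result still sums to 1, i.e. it is a probability distribution rather than
  #raw logits, unlike the gen_logits + copy_logits branch.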