def _symbols_to_logits_fn(self, embedding, vocab_size, mode, output_layer=None, dtype=None):
  embedding_fn = get_embedding_fn(embedding)
  if output_layer is None:
    output_layer = build_output_layer(self.num_units, vocab_size, dtype=dtype)

  def _impl(ids, step, cache):
    # Only feed the last generated token; scale embeddings as in the Transformer.
    inputs = embedding_fn(ids[:, -1:])
    inputs *= self.num_units**0.5
    inputs = self.position_encoder.apply_one(inputs, step + 1)
    outputs = self._self_attention_stack(
        inputs,
        mode=mode,
        cache=cache,
        memory=cache["memory"],
        memory_sequence_length=None)
    outputs = outputs[:, -1:, :]
    logits = output_layer(outputs)
    return logits, cache

  return _impl
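# Usage sketch (not part of the original code): how the `_impl(ids, step, cache)`
# callable returned above is typically driven one token at a time. The dummy
# logits function, the 32-token vocabulary, and the start ids are illustrative
# assumptions, not the decoder's real implementation.
import tensorflow as tf

def _greedy_decode_sketch(symbols_to_logits_fn, start_ids, num_steps=3):
  """Greedy step-by-step decoding around an (ids, step, cache) callable."""
  ids = start_ids  # [batch, 1] int32 start tokens.
  cache = {}       # The real decoder pre-fills this with memory and layer states.
  for step in range(num_steps):
    logits, cache = symbols_to_logits_fn(ids, step, cache)  # [batch, 1, vocab]
    next_ids = tf.cast(tf.argmax(logits[:, -1], axis=-1), tf.int32)
    ids = tf.concat([ids, tf.expand_dims(next_ids, 1)], axis=1)
  return ids

def _dummy_symbols_to_logits_fn(ids, step, cache):
  # Stand-in that returns uniform logits over a 32-token vocabulary.
  return tf.zeros([tf.shape(ids)[0], 1, 32]), cache

# Example: decode 3 extra tokens for a batch of 2 sequences.
_greedy_decode_sketch(_dummy_symbols_to_logits_fn, tf.constant([[1], [1]]))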
def _symbols_to_logits_fn(self, embedding, vocab_size, mode, output_layer=None, dtype=None):
  embedding_fn = get_embedding_fn(embedding)
  if self.share_embedding:
    # Tie the output projection to the target embedding weights.
    w_embs = reuse_variable("w_embs")
    output_layer = build_linear_shared_weights(
        vocab_size, w_embs, scope="proj_to_vocab_size")
  elif output_layer is None:
    output_layer = build_linear_weight_norm(
        self.out_embedding_dim,
        vocab_size,
        dropout=self.dropout,
        dtype=dtype,
        scope="proj_to_vocab_size")

  def _impl(ids, step, cache):
    inputs = embedding_fn(ids[:, -1:])
    if self.position_encoder is not None:
      inputs = self.position_encoder.apply_one(inputs, step + 1)
    outputs = self._cnn_stack(
        inputs,
        memory=cache["memory"],
        mode=mode,
        cache=cache)
    outputs = outputs[:, -1:, :]
    logits = output_layer(outputs)
    return logits, cache

  return _impl
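# Sketch (assumption, not the project's build_linear_shared_weights): a
# weight-tied output layer typically reuses the target embedding matrix,
# transposed, as the projection to vocabulary logits.
import tensorflow as tf

def _shared_weights_output_layer_sketch(outputs, w_embs):
  """Projects [batch, time, depth] outputs to [batch, time, vocab] logits
  using the [vocab, depth] embedding matrix as the projection weights."""
  return tf.einsum("btd,vd->btv", outputs, w_embs)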
def dynamic_decode_and_search(self,
                              embedding,
                              start_tokens,
                              end_token,
                              vocab_size,
                              initial_state=None,
                              beam_width=5,
                              length_penalty=0.0,
                              maximum_iterations=250,
                              mode=tf.estimator.ModeKeys.PREDICT,
                              memory=None,
                              memory_sequence_length=None):
  # Tile the encoder outputs so each batch entry is repeated beam_width times.
  if initial_state is not None:
    initial_state = tf.contrib.seq2seq.tile_batch(
        initial_state, multiplier=beam_width)
  if memory is not None:
    memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
  if memory_sequence_length is not None:
    memory_sequence_length = tf.contrib.seq2seq.tile_batch(
        memory_sequence_length, multiplier=beam_width)

  embedding_fn = get_embedding_fn(embedding)

  def _symbols_to_logits_fn(symbols):
    batch_size = tf.shape(symbols)[0]
    step = tf.shape(symbols)[1]
    sequence_length = tf.fill([batch_size], step)
    outputs = self._self_attention_stack(
        embedding_fn(symbols),
        sequence_length,
        mode=mode,
        memory=memory,
        memory_sequence_length=memory_sequence_length)
    # Only sample the last timestep.
    last_output = tf.slice(outputs, [0, step - 1, 0], [-1, 1, -1])
    logits = tf.layers.dense(last_output, vocab_size)
    return logits

  outputs, log_probs = beam_search(
      _symbols_to_logits_fn,
      start_tokens,
      beam_width,
      maximum_iterations,
      vocab_size,
      length_penalty,
      eos_id=end_token)
  outputs = tf.slice(outputs, [0, 0, 1], [-1, -1, -1])  # Ignore <s>.

  # Sequence lengths are the number of non-padding (non-zero) ids.
  lengths = tf.not_equal(outputs, 0)
  lengths = tf.cast(lengths, tf.int32)
  lengths = tf.reduce_sum(lengths, axis=-1)

  return (outputs, None, lengths, log_probs)
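# Sketch (assumption): the effect of tf.contrib.seq2seq.tile_batch used above,
# written with plain TensorFlow ops. Each batch entry is repeated beam_width
# times in a row so memory lines up with the flattened [batch * beam_width]
# layout of the beam search hypotheses.
import tensorflow as tf

def _tile_batch_sketch(tensor, multiplier):
  """Repeats each batch entry `multiplier` times; requires a static rank."""
  shape = tf.shape(tensor)
  rank = tensor.shape.ndims
  tiled = tf.expand_dims(tensor, 1)                           # [batch, 1, ...]
  tiled = tf.tile(tiled, [1, multiplier] + [1] * (rank - 1))  # [batch, multiplier, ...]
  new_shape = tf.concat([[shape[0] * multiplier], shape[1:]], axis=0)
  return tf.reshape(tiled, new_shape)                         # [batch * multiplier, ...]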
def _symbols_to_logits_fn(self, embedding, vocab_size, mode):
  embedding_fn = get_embedding_fn(embedding)

  def _impl(ids, step, cache):
    inputs = embedding_fn(ids[:, -1:])
    inputs = self.position_encoder.apply_one(inputs, step + 1)
    outputs = self._self_attention_stack(
        inputs,
        mode=mode,
        cache=cache,
        memory=cache["memory"],
        memory_sequence_length=cache["memory_sequence_length"])
    outputs = outputs[:, -1:, :]
    logits = tf.layers.dense(outputs, vocab_size)
    return logits, cache

  return _impl
def _symbols_to_logits_fn(self, embedding, vocab_size, mode):
  embedding_fn = get_embedding_fn(embedding)

  def _impl(ids, step, cache):
    inputs = embedding_fn(ids[:, -1:])
    inputs *= self.num_units**0.5
    inputs = self.position_encoder.apply_one(inputs, step + 1)
    outputs = self._self_attention_stack(
        inputs,
        mode=mode,
        cache=cache,
        memory=cache["memory"],
        memory_sequence_length=None)
    outputs = outputs[:, -1:, :]
    logits = tf.layers.dense(outputs, vocab_size)
    return logits, cache

  return _impl
def dynamic_decode(self,
                   embedding,
                   start_tokens,
                   end_token,
                   vocab_size,
                   initial_state=None,
                   maximum_iterations=250,
                   mode=tf.estimator.ModeKeys.PREDICT,
                   memory=None,
                   memory_sequence_length=None):
  batch_size = tf.shape(start_tokens)[0]
  finished = tf.tile([False], [batch_size])
  step = tf.constant(0)
  inputs = tf.expand_dims(start_tokens, 1)
  lengths = tf.zeros([batch_size], dtype=tf.int32)
  log_probs = tf.zeros([batch_size])

  embedding_fn = get_embedding_fn(embedding)

  def _condition(unused_step, finished, unused_inputs,
                 unused_lengths, unused_log_probs):
    return tf.logical_not(tf.reduce_all(finished))

  def _body(step, finished, inputs, lengths, log_probs):
    inputs_lengths = tf.add(lengths, 1 - tf.cast(finished, tf.int32))

    # Decode inputs.
    outputs = self._self_attention_stack(
        embedding_fn(inputs),
        inputs_lengths,
        mode=mode,
        memory=memory,
        memory_sequence_length=memory_sequence_length)

    # Only sample the last timestep.
    last_output = tf.slice(outputs, [0, step, 0], [-1, 1, -1])
    logits = tf.layers.dense(last_output, vocab_size)
    probs = tf.nn.log_softmax(logits)
    sample_ids = tf.argmax(probs, axis=-1)

    # Accumulate log probabilities, ignoring already finished sequences.
    sample_probs = tf.reduce_max(probs, axis=-1)
    masked_probs = tf.squeeze(sample_probs, -1) * (1.0 - tf.cast(finished, tf.float32))
    log_probs = tf.add(log_probs, masked_probs)

    next_inputs = tf.concat([inputs, tf.cast(sample_ids, tf.int32)], -1)
    next_lengths = inputs_lengths
    next_finished = tf.logical_or(
        finished,
        tf.equal(tf.squeeze(sample_ids, axis=[1]), end_token))
    step = step + 1

    if maximum_iterations is not None:
      next_finished = tf.logical_or(next_finished, step >= maximum_iterations)

    return step, next_finished, next_inputs, next_lengths, log_probs

  step, _, outputs, lengths, log_probs = tf.while_loop(
      _condition,
      _body,
      loop_vars=(step, finished, inputs, lengths, log_probs),
      shape_invariants=(
          tf.TensorShape([]),
          finished.get_shape(),
          tf.TensorShape([None, None]),
          lengths.get_shape(),
          log_probs.get_shape()),
      parallel_iterations=1)

  outputs = tf.slice(outputs, [0, 1], [-1, -1])  # Ignore <s>.

  # Make shape consistent with beam search.
  outputs = tf.expand_dims(outputs, 1)
  lengths = tf.expand_dims(lengths, 1)
  log_probs = tf.expand_dims(log_probs, 1)

  return (outputs, None, lengths, log_probs)
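# Sketch (illustrative only): the masking used in _body above, showing how a
# finished hypothesis stops accumulating log probability and stops growing in
# length while unfinished ones continue. The values are made up.
import tensorflow as tf

finished = tf.constant([False, True])
log_probs = tf.zeros([2])
lengths = tf.zeros([2], dtype=tf.int32)

step_log_probs = tf.constant([-0.1, -2.3])  # Best log prob at this step.
log_probs += step_log_probs * (1.0 - tf.cast(finished, tf.float32))
lengths += 1 - tf.cast(finished, tf.int32)
# Only the unfinished hypothesis (index 0) is updated:
# log_probs == [-0.1, 0.0], lengths == [1, 0]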