def generic_loss(top_out, targets, model_hparams,
                 weights_fn=common_layers.weights_nonzero, gaussian=False):
  """Compute the padded, label-smoothed cross-entropy sum and weight sum."""
  logits = top_out
  labels = targets
  # padded_cross_entropy
  confidence = 1.0 - model_hparams.label_smoothing
  logits_shape = common_layers.shape_list(logits)
  vocab_size = logits_shape[-1]
  logits, labels = common_layers.pad_with_zeros(logits, labels)
  logits = tf.reshape(
      logits,
      common_layers.shape_list(labels) + [vocab_size],
      name="padded_cross_entropy_size_check")
  logits = tf.cast(logits, tf.float32)
  xent = common_layers.smoothing_cross_entropy(
      logits, labels, vocab_size, confidence, gaussian=gaussian)
  weights = weights_fn(labels)
  return tf.reduce_sum(xent * weights), tf.reduce_sum(weights)
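

# Hedged example (mine, not from the source; assumes `tf` is TensorFlow and a
# toy vocab): generic_loss returns the weighted cross-entropy *sum* and the
# weight sum separately so callers can aggregate across shards before
# dividing, e.g. loss = numerator / tf.maximum(denominator, 1.0). The helper
# below shows, with plain TF ops, the target distribution that
# smoothing_cross_entropy is built on: the gold class gets `confidence` and
# the remaining label_smoothing mass is spread uniformly over the other
# vocab_size - 1 classes.
def _example_smoothed_targets(labels, vocab_size, label_smoothing=0.1):
  """Hypothetical helper: smoothed one-hot targets, shape [..., vocab_size]."""
  confidence = 1.0 - label_smoothing
  low_confidence = label_smoothing / float(vocab_size - 1)
  return tf.one_hot(labels, depth=vocab_size,
                    on_value=confidence, off_value=low_confidence)
# e.g. vocab_size=4, label_smoothing=0.1 gives [0.9, 0.0333, 0.0333, 0.0333]
# for label 0.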
def combine_last_two_dimensions(x):
  """Reshape x so that the last two dimensions become one.

  Args:
    x: a Tensor with shape [..., a, b]

  Returns:
    a Tensor with shape [..., ab]
  """
  x_shape = common_layers.shape_list(x)
  a, b = x_shape[-2:]
  return tf.reshape(x, x_shape[:-2] + [a * b])
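

# Note (my reading, not from the source): shape_list returns Python ints for
# statically known dimensions and scalar tensors for dynamic ones, so `a * b`
# and the reshape above still work when, e.g., the batch or length dimension
# is unknown at graph-construction time.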
def add_position_timing_signal(x, hparams):
  """Add an n-dimensional embedding as the position (horizontal) timing signal.

  Args:
    x: a Tensor with shape [batch, length, depth]
    hparams: model hyperparameters

  Returns:
    a Tensor with the same shape as x.
  """
  length = common_layers.shape_list(x)[1]
  channels = common_layers.shape_list(x)[2]
  signal = common_attention.get_timing_signal_1d(length, channels)

  if hparams.add_or_concat_timing_signal == "add":
    x_with_timing = x + common_layers.cast_like(signal, x)
    return x_with_timing
  else:
    raise ValueError("Unknown timing signal add or concat type: %s"
                     % hparams.add_or_concat_timing_signal)
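

# Usage sketch (assumption: hparams only needs the one attribute the function
# reads, and get_timing_signal_1d returns a [1, length, channels] tensor that
# broadcasts over the batch dimension):
#   import types
#   hparams = types.SimpleNamespace(add_or_concat_timing_signal="add")
#   x = tf.zeros([3, 10, 64])
#   y = add_position_timing_signal(x, hparams)  # same shape as x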
def split_last_dimension(x, n):
  """Reshape x so that the last dimension becomes two dimensions.

  The first of these two dimensions is n.

  Args:
    x: a Tensor with shape [..., m]
    n: an integer.

  Returns:
    a Tensor with shape [..., n, m/n]
  """
  x_shape = common_layers.shape_list(x)
  m = x_shape[-1]
  if isinstance(m, int) and isinstance(n, int):
    assert m % n == 0
  return tf.reshape(x, x_shape[:-1] + [n, m // n])
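

# Round-trip sketch (assumption, with made-up sizes): split_last_dimension and
# combine_last_two_dimensions above are inverses. Multi-head attention splits
# [batch, length, depth] into [batch, length, heads, depth/heads] and folds it
# back after the per-head attention:
#   h = split_last_dimension(tf.zeros([2, 5, 128]), 8)   # -> [2, 5, 8, 16]
#   assert combine_last_two_dimensions(h).shape.as_list() == [2, 5, 128]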
def transformer_prepare_decoder(targets):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder
      self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  return (decoder_input, decoder_self_attention_bias)
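

# Equivalent construction (a sketch, not the tensor2tensor implementation):
# attention_bias_lower_triangle(n) amounts to a [1, 1, n, n] bias that is 0 on
# and below the diagonal and a large negative value above it, so after the
# softmax a position cannot attend to future positions. shift_right_3d pads a
# zero vector at position 0 and drops the last target, so the decoder input at
# step i only sees gold targets < i.
def _example_causal_bias(n):
  lower = tf.linalg.band_part(tf.ones([n, n]), -1, 0)  # lower-triangular ones
  return tf.reshape(-1e9 * (1.0 - lower), [1, 1, n, n])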
def symbol_top(body_output, target_embedding_space):
  """Project body outputs onto the target embedding space to get logits."""
  body_output_shape = common_layers.shape_list(body_output)
  body_output = tf.reshape(body_output, [-1, body_output_shape[-1]])
  logits = tf.matmul(body_output, target_embedding_space, transpose_b=True)
  return tf.reshape(
      logits, body_output_shape[:-1] + [1, target_embedding_space.shape[0]])
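

# Note (my reading, not confirmed by the source): multiplying by the target
# embedding matrix with transpose_b=True is the weight-tied softmax --
# logits[..., v] is the dot product of the decoder state with embedding row v.
# With made-up sizes:
#   body_output: [2, 7, 1, 64], target_embedding_space: [1000, 64]
#   symbol_top(body_output, target_embedding_space)  ->  [2, 7, 1, 1, 1000]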
def _greedy_infer(self, features):
  features["inputs"] = tf.expand_dims(features["inputs"], 2)
  batch_size = common_layers.shape_list(features["inputs"])[0]
  initial_output = tf.zeros((batch_size, 0, 1, 1), dtype=tf.int64)
  prefix_length = common_layers.shape_list(features["inputs"])[1]
  decode_length = prefix_length + self._hparams.decode_length
  result = initial_output
  vocab_size = self._problem_hparams.vocab_size["targets"]
  logits = tf.zeros((batch_size, 0, 1, 1, vocab_size))
  logits_shape_inv = [None, None, None, None, None]
  loss = 0.0

  def infer_step(recent_output, recent_logits, unused_loss):
    """Run one step of greedy decoding and append the newest sample."""
    padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
    features["targets"] = padded
    samples, logits, loss = self.sample(features)
    cur_sample = samples[:, -1, :, :]
    cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis=1))
    samples = tf.concat([recent_output, cur_sample], axis=1)
    logits = tf.concat([recent_logits, logits[:, -1:]], 1)
    return samples, logits, loss

  def while_exit_cond(result, logits, loss):  # pylint: disable=unused-argument
    """Exit the loop once we reach decode_length or emit EOS."""
    length = common_layers.shape_list(result)[1]
    not_overflow = length < decode_length

    if self._problem_hparams.stop_at_eos:

      def fn_not_eos():
        # Check whether the last predicted element is an EOS.
        return tf.not_equal(
            tf.squeeze(result[:, -1, :, :]), text_encoder.EOS_ID)

      not_eos = tf.cond(
          # Only check for early stopping if there is at least one element
          # (otherwise fn_not_eos would crash).
          tf.not_equal(length, 0),
          fn_not_eos,
          lambda: True,
      )
      return tf.cond(
          tf.equal(batch_size, 1),
          # If batch_size == 1, check EOS for early stopping.
          lambda: tf.logical_and(not_overflow, not_eos),
          # Otherwise, just wait for the maximum length.
          lambda: not_overflow)
    return not_overflow

  result, logits, loss = tf.while_loop(
      while_exit_cond,
      infer_step,
      [result, logits, loss],
      shape_invariants=[
          tf.TensorShape([None, None, None, None]),
          tf.TensorShape(logits_shape_inv),
          tf.TensorShape([]),
      ],
      back_prop=False,
      parallel_iterations=1,
      name="",
  )
  return result
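

# Minimal sketch of the same grow-the-prefix loop (my toy, not the
# tensor2tensor code; the "model" just emits the step index, and the EOS check
# is simplified to cover the whole batch instead of only batch_size == 1).
# Worth noting: infer_step above re-runs the model on the entire prefix each
# iteration -- there is no incremental decoding cache -- so greedy inference
# is quadratic in output length, and back_prop=False keeps the loop from
# storing activations it would only need for gradients.
def _toy_greedy_decode(batch_size, decode_length, eos_id=1):
  result = tf.zeros([batch_size, 0], dtype=tf.int64)
  step = tf.constant(0)

  def cond(step, result):
    not_overflow = step < decode_length
    not_eos = tf.cond(
        tf.not_equal(step, 0),
        lambda: tf.reduce_any(tf.not_equal(result[:, -1], eos_id)),
        lambda: tf.constant(True))
    return tf.logical_and(not_overflow, not_eos)

  def body(step, result):
    # Toy "model": sample the current step index as the next token.
    sample = tf.fill([batch_size, 1], tf.cast(step, tf.int64))
    return step + 1, tf.concat([result, sample], axis=1)

  # The [None, None] invariant lets `result` grow along the time axis, just as
  # logits_shape_inv does for the logits in _greedy_infer.
  _, result = tf.while_loop(
      cond, body, [step, result],
      shape_invariants=[step.get_shape(), tf.TensorShape([None, None])])
  return result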