Example #1
        def symbols_to_logits_fn(dec_BxT, context, i):
            """Decode loop: return next-token logits for step i."""
            # Nested inside the decoding routine: bias_1xTxT (causal bias over
            # all T positions) and T are defined in the enclosing scope.
            # Take the token generated at the previous step (clamped to 0 at i == 0).
            dec_Bx1 = tf.slice(
                dec_BxT,
                [0, tf.maximum(tf.cast(0, i.dtype), i - 1)],
                [tf.shape(dec_BxT)[0], 1],
            )

            # Causal attention bias for position i over the T decoder positions.
            bias_1x1xT = tf.slice(bias_1xTxT, [0, i, 0], [1, 1, T])
            dec_Bx1xD = self._embedding_layer(dec_Bx1, True)
            # Zero the embedding at the first step to match the zero vector fed
            # at t = 0 during training.
            dec_Bx1xD *= tf.cast(tf.greater(i, 0), self._dtype)
            dec_Bx1xD = timing.add_time_signal(dec_Bx1xD, start_index=i)
            with tf.variable_scope(
                self._decoder_scope_name, reuse=tf.AUTO_REUSE
            ):
                # One incremental decoder step; `context` carries the per-layer
                # cache, updated at index i.
                dec_Bx1xD = transformer_block.stack(
                    self._decoder_layers,
                    False,
                    dec_Bx1xD,
                    bias_1x1xT,
                    context['memory'],
                    context['memory_bias'],
                    context,
                    i,
                )
                dec_Bx1xD = contrib_layers.layer_norm(
                    dec_Bx1xD, begin_norm_axis=2
                )
            # Project back to the vocabulary with the tied embedding matrix.
            logits_Bx1xV = self._embedding_layer(dec_Bx1xD, False)
            logits_BxV = tf.squeeze(logits_Bx1xV, axis=1)
            return logits_BxV
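
A function with this signature is typically consumed by a greedy or beam-search decoding loop that fills the decoded buffer one position at a time. The sketch below is a minimal, self-contained illustration of that contract using NumPy; `toy_logits_fn`, the batch/length/vocab sizes, and the counting behavior are made up for the example and are not part of the model above.

    import numpy as np

    B, T, V = 2, 8, 16  # toy batch size, max length, vocab size (assumptions)

    def toy_logits_fn(dec_BxT, context, i):
        # Stand-in for symbols_to_logits_fn: favour token (previous + 1) mod V.
        prev = dec_BxT[:, max(i - 1, 0)]
        logits_BxV = np.zeros((B, V), dtype=np.float32)
        logits_BxV[np.arange(B), (prev + 1) % V] = 1.0
        return logits_BxV

    dec_BxT = np.zeros((B, T), dtype=np.int64)  # starts as all padding (id 0)
    context = {}  # the real loop would carry the decoder cache here
    for i in range(T):
        logits_BxV = toy_logits_fn(dec_BxT, context, i)
        dec_BxT[:, i] = np.argmax(logits_BxV, axis=-1)  # greedy pick for step i
    print(dec_BxT)  # each row counts 1, 2, 3, ... given the toy logits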
Example #2
    def __call__(self, features, training):
        """Create model.

    Args:
      features: dictionary of tensors including "inputs" [batch, input_len] and
        "targets" [batch, output_len]
      training: bool of whether the mode is training.

    Returns:
     Tuple of (loss, outputs): Loss is a scalar. Output is a dictionary of
       tensors, containing model's output logits.
    """
        if 'inputs' not in features or 'targets' not in features:
            raise ValueError('Require inputs and targets keys in features.')

        context = self._encode(features, training)
        self._context = context
        targets_BxT = features['targets']
        # Causal (upper-triangular) bias so position t only attends to positions <= t.
        bias_1xTxT = attention.upper_triangle_bias(
            tf.shape(targets_BxT)[1], self._dtype
        )
        states_BxTxD = self._embedding_layer(targets_BxT, True)
        # Shift the target embeddings right by one step (teacher forcing): a zero
        # vector is fed at t = 0 and the last position is dropped.
        states_BxTxD = tf.pad(states_BxTxD, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]
        states_BxTxD = timing.add_time_signal(states_BxTxD)
        states_BxTxD = self._dropout_fn(states_BxTxD, training)
        with tf.variable_scope(self._decoder_scope_name, reuse=tf.AUTO_REUSE):
            states_BxTxD = transformer_block.stack(
                self._decoder_layers,
                training,
                states_BxTxD,
                bias_1xTxT,
                context['memory'],
                context['memory_bias'],
            )
            states_BxTxD = contrib_layers.layer_norm(
                states_BxTxD, begin_norm_axis=2
            )
        # Project to vocabulary logits with the tied embedding matrix.
        logits_BxTxV = self._embedding_layer(states_BxTxD, False)
        # Mask out padding positions (id 0) so they do not contribute to the loss.
        targets_mask_BxT = tf.cast(tf.greater(targets_BxT, 0), self._dtype)
        loss = tf.losses.softmax_cross_entropy(
            tf.one_hot(targets_BxT, self._vocab_size),
            logits_BxTxV,
            label_smoothing=self._label_smoothing,
            weights=targets_mask_BxT,
        )
        return loss, {'logits': logits_BxTxV}
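
The shift-right line above (`tf.pad(...)[:, :-1, :]`) is easy to misread, so the small NumPy sketch below replays the same indexing on a toy tensor to make the effect visible. The array contents are made up; only the padding and slicing pattern mirrors the code above.

    import numpy as np

    # Toy "embeddings": batch 1, T = 4 steps, D = 2 features (values are arbitrary).
    states_BxTxD = np.arange(8, dtype=np.float32).reshape(1, 4, 2)

    # Same operation as tf.pad(states, [[0, 0], [1, 0], [0, 0]])[:, :-1, :]:
    # prepend a zero vector on the time axis, then drop the last step.
    shifted = np.pad(states_BxTxD, [(0, 0), (1, 0), (0, 0)])[:, :-1, :]

    print(states_BxTxD[0])  # [[0, 1], [2, 3], [4, 5], [6, 7]]
    print(shifted[0])       # [[0, 0], [0, 1], [2, 3], [4, 5]]
    # Position t now sees the embedding of target t - 1, with zeros at t = 0,
    # so the decoder predicts token t from tokens < t (teacher forcing).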
Example #3
    def _encode(self, features, training):
        """Run the transformer encoder and return its memory for the decoder."""
        inputs_BxI = features['inputs']
        # Padding bias: large negative values at padded input positions.
        inputs_bias_Bx1xI = attention.ids_to_bias(inputs_BxI, self._dtype)
        states_BxIxD = self._embedding_layer(inputs_BxI, True)
        states_BxIxD = self._dropout_fn(
            timing.add_time_signal(states_BxIxD), training
        )
        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            states_BxIxD = transformer_block.stack(
                self._encoder_layers,
                training,
                states_BxIxD,
                inputs_bias_Bx1xI,
                None,
                None,
            )
            states_BxIxD = contrib_layers.layer_norm(
                states_BxIxD, begin_norm_axis=2
            )
        return {'memory': states_BxIxD, 'memory_bias': inputs_bias_Bx1xI}
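
The body of `attention.ids_to_bias` is not shown here; the sketch below illustrates the usual convention such a helper follows, building a [batch, 1, input_len] bias that is a large negative number at padded positions so that, once added to attention scores before the softmax, padded tokens receive roughly zero attention weight. The helper name `padding_bias`, the choice of 0 as the padding id, and the -1e9 constant are assumptions for the illustration.

    import numpy as np

    def padding_bias(ids_BxI, neg=-1e9):
        # 0 is treated as the padding id; real positions get bias 0, pads get `neg`.
        is_pad_BxI = (ids_BxI == 0).astype(np.float32)
        return (is_pad_BxI * neg)[:, np.newaxis, :]  # shape [B, 1, I]

    ids_BxI = np.array([[5, 9, 3, 0, 0]])           # one sequence, two padded slots
    bias_Bx1xI = padding_bias(ids_BxI)

    scores_BxIxI = np.zeros((1, 5, 5), dtype=np.float32)   # toy attention scores
    masked = scores_BxIxI + bias_Bx1xI                      # broadcasts over queries
    weights = np.exp(masked) / np.exp(masked).sum(-1, keepdims=True)
    print(weights[0, 0])  # ~[1/3, 1/3, 1/3, 0, 0]: padded keys get no attention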