Example #1
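    # Assumes the test module's usual imports, e.g.:
    #   import mesh_tensorflow as mtf
    #   import tensorflow.compat.v1 as tf
    #   from tensor2tensor.layers import common_layers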
    def testWeightsNonzero(self):
        inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

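        # Build a Mesh TensorFlow graph whose dimensions mirror the input's shape.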
        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh")
        batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
        channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

        mtf_inputs = mtf.import_tf_tensor(mesh,
                                          inputs,
                                          shape=mtf.Shape(
                                              [batch_dim, channels_dim]))
        mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)
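        # Lower onto a single-device placement mesh and export back to a tf.Tensor.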
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(shape=[],
                                                              layout={},
                                                              devices=[""])
        lowering = mtf.Lowering(graph, {mesh: mesh_impl})
        actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

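        # Reference output from the plain TF implementation of weights_nonzero.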
        expected_outputs = common_layers.weights_nonzero(inputs)
        tf_group = lowering.copy_masters_to_slices()
        self.evaluate(tf_group)
        actual, expected = self.evaluate([actual_outputs, expected_outputs])

        self.assertAllEqual(actual, expected)
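
Both mtf.layers.weights_nonzero and common_layers.weights_nonzero build a padding mask that is 1.0 at positions whose token id is nonzero and 0.0 elsewhere (padding is assumed to use id 0). A minimal sketch of the same idea in plain TensorFlow, using a hypothetical helper name rather than the library code:

    def weights_nonzero_sketch(labels):
        # 1.0 where labels != 0, 0.0 at padding positions (padding id assumed 0).
        return tf.cast(tf.not_equal(labels, 0), tf.float32)
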
Example #2
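    # Method of a Transformer-style model class (Seq2Edits); assumes the
    # surrounding module provides tf, common_layers, and helpers such as
    # features_to_nonpadding, transformer_edit_ops_layer, and maybe_flatten4d2d.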
    def body(self, features):
        """Seq2Edits main model_fn.

        Args:
          features: Feature dictionary. Should contain the following fields:
            "inputs": [batch_size, input_length, 1, hidden_dim] float tensor with
              input token embeddings.
            "targets": [batch_size, target_length, 1, hidden_dim] float tensor
              with target token embeddings.
            "targets_error_tag": [batch_size, target_length, 1, hidden_dim] float
              tensor with target error tag embeddings.
            "target_space_id": A scalar int from data_generators.problem.SpaceID.

        Returns:
          Final decoder representation. Dictionary containing the following fields:
            "targets": [batch_size, target_length, hidden_dim] float tensor with
              decoder outputs.
            "targets_error_tag": [batch_size, target_length, hidden_dim] float
              tensor with decoder outputs.
        """
        hparams = self._hparams

        losses = []

        if self.has_input:
            target_space = features['target_space_id']
            encoder_output, encoder_decoder_attention_bias = self.encode(
                features['inputs'],
                target_space,
                hparams,
                features=features,
                losses=losses,
            )
        else:
            encoder_output, encoder_decoder_attention_bias = (None, None)

        targets = features['targets']
        targets_shape = common_layers.shape_list(targets)
        targets = common_layers.flatten4d3d(targets)
        decoder_input, decoder_self_attention_bias = self._prepare_decoder_fn(
            targets, hparams, features=features)

        nonpadding = features_to_nonpadding(features, 'targets')

        # Add edit ops layer to condition on start_token, end_token, and error_tag
        decoder_input = transformer_edit_ops_layer(
            decoder_input,
            hparams,
            encoder_output,
            features,
            nonpadding=nonpadding,
            losses=losses,
        )
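        # With middle_prediction, run a reduced decoder stack (num_decoder_layers
        # divided by middle_prediction_layer_factor) before the prediction cascade;
        # a second decoder pass follows below under 'after_prediction'.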
        if hparams.middle_prediction:
            num_decoder_layers = (hparams.num_decoder_layers
                                  or hparams.num_hidden_layers)
            hparams.num_decoder_layers = int(
                num_decoder_layers / hparams.middle_prediction_layer_factor)

        decode_kwargs = {}
        decoder_output = self.decode(decoder_input,
                                     encoder_output,
                                     encoder_decoder_attention_bias,
                                     decoder_self_attention_bias,
                                     hparams,
                                     nonpadding=nonpadding,
                                     losses=losses,
                                     **decode_kwargs)

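        # Loss mask is 1.0 at non-padding target positions; its sum is used as the
        # loss denominator.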
        loss_mask = common_layers.weights_nonzero(
            maybe_flatten4d2d(features['targets_raw']))
        self.loss_den = tf.reduce_sum(loss_mask)
        decoder_output = self._prediction_cascade(
            hparams=hparams,
            features=features,
            losses=losses,
            loss_mask=loss_mask,
            nonpadding=nonpadding,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            encoder_output=encoder_output,
            decoder_output=decoder_output,
        )

        if hparams.middle_prediction:
            with tf.variable_scope('after_prediction'):
                decoder_output = self.decode(decoder_input + decoder_output,
                                             encoder_output,
                                             encoder_decoder_attention_bias,
                                             decoder_self_attention_bias,
                                             hparams,
                                             nonpadding=nonpadding,
                                             losses=losses,
                                             **decode_kwargs)

        ret = {'targets': tf.reshape(decoder_output, targets_shape)}
        ret.update(self.logits)
        if losses:
            return ret, {'extra_loss': tf.add_n(losses)}
        else:
            return ret