def testWeightsNonzero(self):
  inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

  # Build a single-device mesh and import the TF tensor into it.
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
  channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])
  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf.layers.weights_nonzero(mtf_inputs)

  # Lower the Mesh TensorFlow graph back to plain TF for evaluation.
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  # Compare against the reference tensor2tensor implementation.
  expected_outputs = common_layers.weights_nonzero(inputs)
  tf_group = lowering.copy_masters_to_slices()
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertAllEqual(actual, expected)
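# For reference, mtf.layers.weights_nonzero and common_layers.weights_nonzero
# both compute a padding mask that is 1.0 where the labels are nonzero and 0.0
# elsewhere. A minimal sketch of the expected behavior (not the library
# implementation):
#
#   mask = tf.cast(tf.not_equal(labels, 0), tf.float32)
#
# For the inputs above, [[3, 1, 0], [1, 0, 0]], the expected mask is
# [[1., 1., 0.], [1., 0., 0.]].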
def body(self, features):
  """Seq2Edits main model_fn.

  Args:
    features: Feature dictionary. Should contain the following fields:
      "inputs": [batch_size, input_length, 1, hidden_dim] float tensor with
        input token embeddings.
      "targets": [batch_size, target_length, 1, hidden_dim] float tensor with
        target token embeddings.
      "targets_error_tag": [batch_size, target_length, 1, hidden_dim] float
        tensor with target error tag embeddings.
      "target_space_id": A scalar int from data_generators.problem.SpaceID.

  Returns:
    Final decoder representation. Dictionary containing the following fields:
      "targets": [batch_size, target_length, hidden_dim] float tensor with
        decoder outputs.
      "targets_error_tag": [batch_size, target_length, hidden_dim] float
        tensor with decoder outputs.
  """
  hparams = self._hparams
  losses = []

  if self.has_input:
    target_space = features['target_space_id']
    encoder_output, encoder_decoder_attention_bias = self.encode(
        features['inputs'],
        target_space,
        hparams,
        features=features,
        losses=losses,
    )
  else:
    encoder_output, encoder_decoder_attention_bias = (None, None)

  targets = features['targets']
  targets_shape = common_layers.shape_list(targets)
  targets = common_layers.flatten4d3d(targets)
  decoder_input, decoder_self_attention_bias = self._prepare_decoder_fn(
      targets, hparams, features=features)
  nonpadding = features_to_nonpadding(features, 'targets')

  # Add edit ops layer to condition on start_token, end_token, and error_tag
  decoder_input = transformer_edit_ops_layer(
      decoder_input,
      hparams,
      encoder_output,
      features,
      nonpadding=nonpadding,
      losses=losses,
  )
  if hparams.middle_prediction:
    num_decoder_layers = (hparams.num_decoder_layers or
                          hparams.num_hidden_layers)
    hparams.num_decoder_layers = int(
        num_decoder_layers / hparams.middle_prediction_layer_factor)

  decode_kwargs = {}
  decoder_output = self.decode(
      decoder_input,
      encoder_output,
      encoder_decoder_attention_bias,
      decoder_self_attention_bias,
      hparams,
      nonpadding=nonpadding,
      losses=losses,
      **decode_kwargs)

  loss_mask = common_layers.weights_nonzero(
      maybe_flatten4d2d(features['targets_raw']))
  self.loss_den = tf.reduce_sum(loss_mask)
  decoder_output = self._prediction_cascade(
      hparams=hparams,
      features=features,
      losses=losses,
      loss_mask=loss_mask,
      nonpadding=nonpadding,
      encoder_decoder_attention_bias=encoder_decoder_attention_bias,
      encoder_output=encoder_output,
      decoder_output=decoder_output,
  )

  if hparams.middle_prediction:
    with tf.variable_scope('after_prediction'):
      decoder_output = self.decode(
          decoder_input + decoder_output,
          encoder_output,
          encoder_decoder_attention_bias,
          decoder_self_attention_bias,
          hparams,
          nonpadding=nonpadding,
          losses=losses,
          **decode_kwargs)

  ret = {'targets': tf.reshape(decoder_output, targets_shape)}
  ret.update(self.logits)
  if losses:
    return ret, {'extra_loss': tf.add_n(losses)}
  else:
    return ret
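# For reference, features_to_nonpadding (used above to mask padding positions)
# derives a 0/1 indicator from the "*_segmentation" features when they are
# present. A minimal sketch of the expected behavior, assuming that feature
# convention (not necessarily the exact library code):
#
#   def features_to_nonpadding(features, inputs_or_targets='inputs'):
#     key = inputs_or_targets + '_segmentation'
#     if features and key in features:
#       return tf.minimum(tf.cast(features[key], tf.float32), 1.0)
#     return None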