def testSymbolModalityTargets(self):
    """Check SYMBOL-modality top and loss output shapes on sharded inputs.

    Random body outputs and targets are split into `num_datashards` shards,
    run through the SYMBOL modality's top and loss functions under a
    `Parallelism` made of CPU devices, and the concatenated logits and the
    normalized training loss are checked for their expected shapes.
    """
    batch_size, num_datashards = 10, 5
    length, height = 6, 7
    hidden_size, vocab_size = 9, 11

    model_hparams = common_hparams.basic_params1()
    model_hparams.hidden_size = hidden_size
    model_hparams.mode = tf.estimator.ModeKeys.TRAIN

    # Body output is drawn before the targets, so the NumPy RNG is consumed
    # in a fixed order.
    leading_dims = (batch_size, length, height)
    body_output = np.random.randint(100, size=leading_dims + (hidden_size,))
    targets = np.random.randint(vocab_size, size=leading_dims + (1,))

    devices = ["/device:CPU:0"] * num_datashards
    data_parallelism = expert_utils.Parallelism(devices)
    sharded_body_output = tf.split(tf.to_float(body_output), num_datashards)
    sharded_targets = tf.split(targets, num_datashards)

    symbol = modalities.ModalityType.SYMBOL
    sharded_logits = data_parallelism(
        modalities.get_top(symbol),
        sharded_body_output,
        sharded_targets,
        model_hparams,
        vocab_size,
    )
    sharded_loss_num, sharded_loss_den = data_parallelism(
        modalities.get_loss(symbol),
        sharded_logits,
        sharded_targets,
        model_hparams,
        vocab_size,
        modalities.get_weights_fn(symbol),
    )

    # Sum the per-shard numerators and denominators; the denominator is
    # floored at 1.0 before dividing.
    total_num = tf.add_n(sharded_loss_num)
    total_den = tf.add_n(sharded_loss_den)
    train_loss = total_num / tf.maximum(1.0, total_den)
    logits = tf.concat(sharded_logits, 0)

    self.evaluate(tf.global_variables_initializer())
    logits_val, loss_val = self.evaluate((logits, train_loss))

    expected_logits_shape = (batch_size, length, height, 1, vocab_size)
    self.assertEqual(logits_val.shape, expected_logits_shape)
    self.assertEqual(loss_val.shape, ())
示例#2
0
    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol.

      Closure over names from the enclosing scope (`dp`, `hparams`,
      `features`, `decoder_self_attention_bias`, `partial_targets`, ...).

      Args:
        ids: ids decoded so far; only the last slice along axis 1
          (`ids[:, -1:]`) is used.
        i: current decoding step. Used to slice the self-attention bias, is
          passed to `preprocess_targets`, and selects the forced target when
          partial targets are given.
        cache: dict read for "encoder_output" and
          "encoder_decoder_attention_bias" and updated in place
          ("body_outputs" is extended along axis 2).

      Returns:
        A `(ret, cache)` tuple: the logits with axes 1, 2 and 3 squeezed
        out (possibly overridden by the partial targets), and the cache.
      """
      # Keep only the most recent id and add two axes (positions 2 and 3).
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      # Bias slice for step i: index i on axis 2, indices 0..i on axis 3.
      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode,
            targets,
            cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias,
            hparams,
            cache,
            nonpadding=features_to_nonpadding(features, "targets"))

      # NOTE(review): helper is defined outside this block; presumably it
      # mutates `cache` in place -- confirm against its definition.
      update_decoder_attention_history(cache)
      # Append this step's first body output to the running history kept in
      # the cache. Requires cache["body_outputs"] to exist already.
      cache["body_outputs"] = tf.concat([cache["body_outputs"], body_outputs[0]], axis=2)

      # An hparams override for the "targets" name function wins over the
      # modality default; the chosen function is called to get the scope name.
      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        # Same override-or-default lookup for the top function.
        top = hparams.top.get("targets", modalities.get_top(target_modality))
        # Only the first element of the dp result is used.
        logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]

        # One-hot with on_value 0.0 and off_value -1e9: the forced id gets
        # logit 0.0 and every other id gets -1e9. The partial targets at
        # step i are tiled `beam_size` times.
        def forced_logits():
          return tf.one_hot(
              tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
              -1e9)

        # Use the forced logits while i < partial_targets_length; otherwise
        # keep the model's logits.
        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache
 def testGetForAllModalities(self):
   """Every modality choice must resolve to non-None default functions."""
   for modality in modalities.ModalityType.get_choices():
     # Resolve all six defaults up front, then assert on them in the same
     # order: bottom, loss, name, targets_bottom, top, weights_fn.
     resolved = [
         (modalities.get_bottom(modality), "{} has no default bottom"),
         (modalities.get_loss(modality), "{} has no default loss"),
         (modalities.get_name(modality), "{} has no default name"),
         (modalities.get_targets_bottom(modality),
          "{} has no default targets_bottom"),
         (modalities.get_top(modality), "{} has no default top"),
         (modalities.get_weights_fn(modality),
          "{} has no default weights_fn"),
     ]
     for value, template in resolved:
       self.assertIsNotNone(value, msg=template.format(modality))
        def infer_step(recent_output, recent_logits, unused_loss):
            """Run one inference step and extend the outputs by one position.

            Pads `recent_output` by one position along axis 1, stores it as
            `features["targets"]`, samples from the model, and appends one
            timestep of the new samples to `recent_output`.

            Args:
                recent_output: outputs generated so far.
                recent_logits: logits accumulated so far.
                unused_loss: ignored.

            Returns:
                A `(samples, logits, loss)` tuple: the extended outputs, the
                logits extended with the last slice of the new logits, and
                the sum of all non-None entries of the sampled losses.
            """
            if not tf.executing_eagerly():
                # Pin the static last dimension: 1 by default, or the
                # targets vocab size for real-valued targets (rounded up to
                # a multiple of `vocab_divisor` when that hparam exists).
                last_dim = 1
                if self._target_modality_is_real:
                    last_dim = self._problem_hparams.vocab_size["targets"]
                    round_up = (last_dim is not None
                                and hasattr(self._hparams, "vocab_divisor"))
                    if round_up:
                        last_dim += (-last_dim) % self._hparams.vocab_divisor
                recent_output.set_shape([None, None, None, last_dim])

            features["targets"] = tf.pad(
                recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
            # This is inefficient in that it generates samples at all timesteps,
            # not just the last one, except if target_modality is pointwise.
            samples, logits, losses = self.sample(features)

            # Concatenate the already-generated recent_output with last timestep
            # of the newly-generated samples.
            top = self._hparams.top.get("targets",
                                        modalities.get_top(target_modality))
            if getattr(top, "pointwise", False):
                step_index = -1
            else:
                step_index = common_layers.shape_list(recent_output)[1]
            newest = samples[:, step_index, :, :]

            is_real = self._target_modality_is_real
            newest = tf.expand_dims(newest, axis=1)
            if not is_real:
                newest = tf.to_int64(newest)
            samples = tf.concat([recent_output, newest], axis=1)
            if not is_real and not tf.executing_eagerly():
                samples.set_shape([None, None, None, 1])

            # Assuming we have one shard for logits.
            logits = tf.concat([recent_logits, logits[:, -1:]], 1)
            loss = sum(
                value for value in losses.values() if value is not None)
            return samples, logits, loss
示例#5
0
        def symbols_to_logits_fn(ids, ids_tag, i, cache):
            """Go from ids to logits for next symbol.

            Variant that also consumes tag ids and returns tag logits.
            Closure over names from the enclosing scope (`dp`, `hparams`,
            `features`, `decoder_self_attention_bias`, `partial_targets`,
            ...).

            Args:
                ids: symbol ids decoded so far; only the last slice along
                    axis 1 is used.
                ids_tag: tag ids decoded so far; only the last slice along
                    axis 1 is used.
                i: current decoding step. Used to slice the self-attention
                    bias, is passed to both preprocess methods, and selects
                    the forced target when partial targets are given.
                cache: dict read for 'encoder_output' and
                    'encoder_decoder_attention_bias'; also passed to the
                    ffn layer, to `self.decode` and to
                    `update_decoder_attention_history`.

            Returns:
                A `(ret, logits_tag, cache)` tuple: the symbol logits with
                axes 1 and 2 squeezed out (possibly overridden by the
                partial targets), the tag logits with axis 1 squeezed out,
                and the cache.
            """
            # Keep only the most recent symbol id and add two axes
            # (positions 2 and 3).
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets_method(targets, i)

            # Same treatment for the most recent tag id.
            ids_tag = ids_tag[:, -1:]
            targets_tag = tf.expand_dims(tf.expand_dims(ids_tag, axis=2),
                                         axis=3)
            targets_tag = preprocess_targets_tag_method(targets_tag, i)

            # Bias slice for step i: index i on axis 2, indices 0..i on
            # axis 3.
            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope('body'):
                with tf.variable_scope('edit_ops_layer'):
                    with tf.variable_scope('ffn'):
                        x = targets
                        preproc = lambda z: common_layers.layer_preprocess(
                            z, hparams, layer_collection=None)
                        # NOTE(review): the result of preproc(...) is handed
                        # directly to tf.concat as `values` -- confirm
                        # whether it is a list or a single tensor.
                        layer_inputs = [
                            tf.concat(preproc(x), axis=0),
                            tf.concat(preproc(targets_tag), axis=0),
                        ]
                        # Symbol and tag inputs are joined along axis 2
                        # before the feed-forward layer.
                        y = transformer_layers.transformer_ffn_layer(
                            tf.concat(layer_inputs, axis=2),
                            hparams,
                            conv_padding='LEFT',
                            nonpadding_mask=features_to_nonpadding(
                                features, 'targets'),
                            losses=None,
                            cache=cache,
                            decode_loop_step=None,
                            layer_collection=None,
                        )
                        # Postprocess combines the pre-ffn symbol targets
                        # `x` with the ffn output `y`.
                        targets = common_layers.layer_postprocess(
                            x, y, hparams)

                if hparams.middle_prediction:
                    # NOTE(review): this rewrites the shared `hparams`
                    # object in place, so every invocation of this function
                    # divides `num_decoder_layers` again -- confirm this
                    # function is only traced once or that this is intended.
                    num_decoder_layers = (hparams.num_decoder_layers
                                          or hparams.num_hidden_layers)
                    hparams.num_decoder_layers = int(
                        num_decoder_layers /
                        hparams.middle_prediction_layer_factor)

                # First decoder pass; only the first element of the dp
                # result is kept.
                body_outputs = dp(
                    self.decode,
                    targets,
                    cache.get('encoder_output'),
                    cache.get('encoder_decoder_attention_bias'),
                    bias,
                    hparams,
                    cache,
                    nonpadding=features_to_nonpadding(features, 'targets'),
                )[0]

                # The cascade call yields new body outputs and the tag
                # logits.
                body_outputs, logits_tag = dp(
                    self._prediction_cascade_predict,
                    hparams,
                    features_to_nonpadding(features, 'targets'),
                    cache.get('encoder_decoder_attention_bias'),
                    cache.get('encoder_output'),
                    body_outputs,
                )
                # Tag logits come from the first element, under the
                # 'targets_error_tag' key.
                logits_tag = logits_tag[0]['targets_error_tag']
                if hparams.middle_prediction:
                    # Second decoder pass on the targets plus the first
                    # cascade output, under its own variable scope.
                    with tf.variable_scope('after_prediction'):
                        body_outputs = dp(
                            self.decode,
                            targets + body_outputs[0],
                            cache.get('encoder_output'),
                            cache.get('encoder_decoder_attention_bias'),
                            bias,
                            hparams,
                            cache,
                            nonpadding=features_to_nonpadding(
                                features, 'targets'),
                        )

            # NOTE(review): helper is defined outside this block; presumably
            # it mutates `cache` in place -- confirm against its definition.
            update_decoder_attention_history(cache)

            # An hparams override for the 'targets' name function wins over
            # the modality default; the chosen function is called to get the
            # scope name.
            modality_name = hparams.name.get(
                'targets',
                modalities.get_name(target_modality))(hparams,
                                                      target_vocab_size)
            with tf.variable_scope('targets/' + modality_name):
                # Same override-or-default lookup for the top function.
                top = hparams.top.get('targets',
                                      modalities.get_top(target_modality))
                # Only the first element of the dp result is used.
                logits = dp(top, body_outputs, None, hparams,
                            target_vocab_size)[0]

            ret = tf.squeeze(logits, axis=[1, 2])
            if partial_targets is not None:
                vocab_size = tf.shape(ret)[1]

                # One-hot with on_value 0.0 and off_value -1e9: the forced
                # id gets logit 0.0 and every other id gets -1e9. The
                # partial targets at step i are tiled `beam_size` times.
                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size,
                        0.0,
                        -1e9,
                    )

                # Use the forced logits while i < partial_targets_length;
                # otherwise keep the model's logits.
                ret = tf.cond(
                    tf.less(i, partial_targets_length),
                    forced_logits,
                    lambda: ret,
                )
            logits_tag = tf.squeeze(logits_tag, axis=[1])
            return ret, logits_tag, cache