Example #1
 def _prepare_inputs_for_decode(self, features):
     """Prepare inputs for decoding.

     Args:
       features: A map of string to model features.

     Returns:
       Inputs after fixing shape and applying modality.
     """
     dp = self._data_parallelism
     hparams = self._hparams
     inputs = features['inputs']
     inputs = tf.expand_dims(inputs, axis=1)
     if len(inputs.shape) < 5:
         inputs = tf.expand_dims(inputs, axis=4)
     s = common_layers.shape_list(inputs)
     inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
     inputs = self._shard_features({'inputs': inputs})['inputs']
     input_modality = self._problem_hparams.modality['inputs']
     input_vocab_size = self._problem_hparams.vocab_size['inputs']
     if input_vocab_size is not None and hasattr(hparams, 'vocab_divisor'):
         input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
     modality_name = hparams.name.get('inputs',
                                      modalities.get_name(input_modality))(
                                          hparams, input_vocab_size)
     with tf.variable_scope(modality_name):
         bottom = hparams.bottom.get('inputs',
                                     modalities.get_bottom(input_modality))
         inputs = dp(bottom, inputs, hparams, input_vocab_size)
     return inputs
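
The `(-input_vocab_size) % hparams.vocab_divisor` expression above pads the vocabulary size up to the next multiple of `vocab_divisor`. A minimal plain-Python sketch of that arithmetic (my own illustration, not tensor2tensor code):

def round_up_vocab_size(vocab_size, divisor):
    """Pad vocab_size up to the next multiple of divisor (no-op if it already is one)."""
    return vocab_size + (-vocab_size) % divisor

assert round_up_vocab_size(100, 8) == 104   # 100 + (-100) % 8 == 100 + 4
assert round_up_vocab_size(104, 8) == 104   # already a multiple, unchanged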
Example #2
    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode,
            targets,
            cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias,
            hparams,
            cache,
            nonpadding=features_to_nonpadding(features, "targets"))

      update_decoder_attention_history(cache)
      cache["body_outputs"] = tf.concat(
          [cache["body_outputs"], body_outputs[0]], axis=2)

      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        top = hparams.top.get("targets", modalities.get_top(target_modality))
        logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]

        def forced_logits():
          return tf.one_hot(
              tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
              -1e9)

        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache
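
A note on the forced_logits trick above: a one-hot row with on_value 0.0 and off_value -1e9 leaves the forced token as the only one with non-negligible probability after the softmax. A small NumPy sketch of that idea (my own illustration, not part of the snippet):

import numpy as np

def forced_logits_row(token_id, vocab_size):
    # 0.0 at the forced position, -1e9 everywhere else, mirroring tf.one_hot(..., 0.0, -1e9).
    row = np.full(vocab_size, -1e9, dtype=np.float32)
    row[token_id] = 0.0
    return row

row = forced_logits_row(token_id=7, vocab_size=10)
probs = np.exp(row - row.max())
probs /= probs.sum()
assert probs.argmax() == 7 and probs[7] > 0.999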
Example #3
 def testGetForAllModalities(self):
   for modality in modalities.ModalityType.get_choices():
     bottom = modalities.get_bottom(modality)
     loss = modalities.get_loss(modality)
     name = modalities.get_name(modality)
     targets_bottom = modalities.get_targets_bottom(modality)
     top = modalities.get_top(modality)
     weights_fn = modalities.get_weights_fn(modality)
     self.assertIsNotNone(bottom,
                          msg="{} has no default bottom".format(modality))
     self.assertIsNotNone(loss, msg="{} has no default loss".format(modality))
     self.assertIsNotNone(name, msg="{} has no default name".format(modality))
     self.assertIsNotNone(
         targets_bottom,
         msg="{} has no default targets_bottom".format(modality))
     self.assertIsNotNone(top, msg="{} has no default top".format(modality))
     self.assertIsNotNone(weights_fn,
                          msg="{} has no default weights_fn".format(modality))
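
A rough sketch of the contract those getters retrieve (my own illustration with made-up names, not the tensor2tensor API): a modality's bottom maps ids to embeddings and its top maps hidden states back to logits, here with tied weights.

import numpy as np

rng = np.random.default_rng(0)
vocab_size, hidden_dim = 16, 8
embedding = rng.standard_normal((vocab_size, hidden_dim)).astype(np.float32)

def bottom(ids):
    return embedding[ids]            # [batch] ids -> [batch, hidden] embeddings

def top(hidden):
    return hidden @ embedding.T      # [batch, hidden] -> [batch, vocab] logits

logits = top(bottom(np.array([3, 5])))
assert logits.shape == (2, vocab_size)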
Example #4
        def preprocess_targets(targets, i):
            targets = self._shard_features({'targets': targets})['targets']
            modality_name = hparams.name.get(
                'targets',
                modalities.get_name(target_modality))(hparams,
                                                      target_vocab_size)
            with tf.variable_scope(modality_name + '/targets'):
                bottom = hparams.bottom.get(
                    'targets', modalities.get_targets_bottom(target_modality))
                targets = dp(bottom, targets, hparams, target_vocab_size)[0]
            targets = common_layers.flatten4d3d(targets)

            if not self.get_decode_start_id():
                targets = tf.cond(
                    tf.equal(i, 0),
                    lambda: tf.zeros_like(targets),
                    lambda: targets,
                )

            if positional_encoding is not None:
                targets += positional_encoding[:, i:i + 1]
            return targets
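
For reference, common_layers.flatten4d3d collapses the two middle axes of a 4-D tensor into one. A static-shape NumPy sketch of the same reshape (the real helper handles dynamic TF shapes):

import numpy as np

def flatten4d3d_np(x):
    batch, length, width, hidden = x.shape
    return x.reshape(batch, length * width, hidden)

x = np.zeros((2, 1, 1, 64), dtype=np.float32)   # [batch, 1, 1, hidden] as in the decode loop
assert flatten4d3d_np(x).shape == (2, 1, 64)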
Example #5
    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: input ids to the decoder. [batch_size, 1]
        i: scalar, step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        bottom = hparams.bottom.get(
            "targets", modalities.get_targets_bottom(target_modality))
        targets = dp(bottom, targets, hparams, target_vocab_size)[0]
      targets = common_layers.flatten4d3d(targets)

      # GO embeddings are all zero; this is because transformer_prepare_decoder
      # shifts the targets along by one for the input, which pads with zeros.
      # If the modality already maps GO to the zero embeddings this is not
      # needed.
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if positional_encoding is not None:
        targets += positional_encoding[:, i:i + 1]
      return targets
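
A minimal NumPy sketch (illustration only, not the snippet's code) of the per-step logic in preprocess_targets: at step 0 the embedding is replaced by zeros (the GO symbol), and every step adds only its own slice of the precomputed positional signal.

import numpy as np

def step_targets(embedded, i, positional_encoding):
    # embedded: [batch, 1, hidden]; positional_encoding: [1, max_len, hidden]
    if i == 0:                       # GO step: embedding forced to zeros
        embedded = np.zeros_like(embedded)
    return embedded + positional_encoding[:, i:i + 1]

batch, hidden, max_len = 2, 4, 8
pos = np.arange(max_len * hidden, dtype=np.float32).reshape(1, max_len, hidden)
emb = np.ones((batch, 1, hidden), dtype=np.float32)
assert np.array_equal(step_targets(emb, 0, pos), np.tile(pos[:, 0:1], (batch, 1, 1)))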
Example #6
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding; uses beam search iff
    beam_size > 1, otherwise beam-search-related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.modality["targets"]
    target_vocab_size = self._problem_hparams.vocab_size["targets"]
    if target_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
      target_vocab_size += (-target_vocab_size) % hparams.vocab_divisor
    if "targets_segmentation" in features:
      raise NotImplementedError(
          "Decoding not supported on packed datasets. If you want to decode "
          "from a dataset, use the non-packed version of the dataset when "
          "decoding.")
    if self.has_input:
      inputs = features["inputs"]
      if target_modality == modalities.ModalityType.CLASS_LABEL:
        decode_length = 1
      else:
        decode_length = (
            common_layers.shape_list(inputs)[1] + features.get(
                "decode_length", decode_length))

      # TODO(llion): Clean up this reshaping logic.
      inputs = tf.expand_dims(inputs, axis=1)
      if len(inputs.shape) < 5:
        inputs = tf.expand_dims(inputs, axis=4)
      s = common_layers.shape_list(inputs)
      batch_size = s[0]
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
      # _shard_features called to ensure that the variable names match
      inputs = self._shard_features({"inputs": inputs})["inputs"]
      input_modality = self._problem_hparams.modality["inputs"]
      input_vocab_size = self._problem_hparams.vocab_size["inputs"]
      if input_vocab_size is not None and hasattr(hparams, "vocab_divisor"):
        input_vocab_size += (-input_vocab_size) % hparams.vocab_divisor
      modality_name = hparams.name.get(
          "inputs",
          modalities.get_name(input_modality))(hparams, input_vocab_size)
      with tf.variable_scope(modality_name):
        bottom = hparams.bottom.get("inputs",
                                    modalities.get_bottom(input_modality))
        inputs = dp(bottom, inputs, hparams, input_vocab_size)
      with tf.variable_scope("body"):
        encoder_output, encoder_decoder_attention_bias = dp(
            self.encode,
            inputs,
            features["target_space_id"],
            hparams,
            features=features)
      encoder_output = encoder_output[0]
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
      partial_targets = features.get("partial_targets")
    else:
      # The problem has no inputs.
      encoder_output = None
      encoder_decoder_attention_bias = None

      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs")
      if partial_targets is None:
        partial_targets = features["targets"]
      assert partial_targets is not None

    if partial_targets is not None:
      partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
      partial_targets = tf.to_int64(partial_targets)
      partial_targets_shape = common_layers.shape_list(partial_targets)
      partial_targets_length = partial_targets_shape[1]
      decode_length = (
          partial_targets_length + features.get("decode_length", decode_length))
      batch_size = partial_targets_shape[0]

    if hparams.pos == "timing":
      positional_encoding = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)
    elif hparams.pos == "emb":
      positional_encoding = common_attention.add_positional_embedding(
          tf.zeros([1, decode_length, hparams.hidden_size]), hparams.max_length,
          "body/targets_positional_embedding", None)
    else:
      positional_encoding = None

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: input ids to the decoder. [batch_size, 1]
        i: scalar, step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        bottom = hparams.bottom.get(
            "targets", modalities.get_targets_bottom(target_modality))
        targets = dp(bottom, targets, hparams, target_vocab_size)[0]
      targets = common_layers.flatten4d3d(targets)

      # GO embeddings are all zero; this is because transformer_prepare_decoder
      # shifts the targets along by one for the input, which pads with zeros.
      # If the modality already maps GO to the zero embeddings this is not
      # needed.
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if positional_encoding is not None:
        targets += positional_encoding[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    # Create tensors for encoder-decoder attention history.
    att_cache = {"attention_history": {}}
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    att_batch_size, enc_seq_length = common_layers.shape_list(
        encoder_output)[0:2]
    for layer in range(num_layers):
      att_cache["attention_history"]["layer_%d" % layer] = tf.zeros(
          [att_batch_size, hparams.num_heads, 0, enc_seq_length])
    att_cache["body_outputs"] = tf.zeros(
        [att_batch_size, 1, 0, hparams.hidden_size])

    def update_decoder_attention_history(cache):
      """Append this step's encoder-decoder attention weights to the cache."""
      for k in filter(
          lambda x: "decoder" in x and "self" not in x and "logits" not in x,
          self.attention_weights.keys()):
        m = re.search(r"(layer_\d+)", k)
        if m is None:
          continue
        cache["attention_history"][m.group(0)] = tf.concat(
            [cache["attention_history"][m.group(0)],
             self.attention_weights[k]], axis=2)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode,
            targets,
            cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias,
            hparams,
            cache,
            nonpadding=features_to_nonpadding(features, "targets"))

      update_decoder_attention_history(cache)
      cache["body_outputs"] = tf.concat(
          [cache["body_outputs"], body_outputs[0]], axis=2)

      modality_name = hparams.name.get(
          "targets",
          modalities.get_name(target_modality))(hparams, target_vocab_size)
      with tf.variable_scope(modality_name):
        top = hparams.top.get("targets", modalities.get_top(target_modality))
        logits = dp(top, body_outputs, None, hparams, target_vocab_size)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]

        def forced_logits():
          return tf.one_hot(
              tf.tile(partial_targets[:, i], [beam_size]), vocab_size, 0.0,
              -1e9)

        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_vocab_size,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size,
        force_decode_length=self._decode_hparams.force_decode_length,
        cache=att_cache)

    if partial_targets is not None:
      if beam_size <= 1 or top_beams <= 1:
        ret["outputs"] = ret["outputs"][:, partial_targets_length:]
      else:
        ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
    return ret
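
A NumPy sketch (my own illustration) of the additive causal bias used above: a lower-triangular 0 / -1e9 mask of shape [1, 1, length, length], from which decode step i reads the slice [:, :, i:i + 1, :i + 1] so the new position can attend only to positions 0..i.

import numpy as np

def attention_bias_lower_triangle(length):
    mask = np.tril(np.ones((length, length), dtype=np.float32))
    return ((1.0 - mask) * -1e9)[None, None, :, :]   # [1, 1, length, length]

bias = attention_bias_lower_triangle(5)
i = 2
step_bias = bias[:, :, i:i + 1, :i + 1]
assert step_bias.shape == (1, 1, 1, 3) and (step_bias == 0.0).all()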
Example #7
        def symbols_to_logits_fn(ids, ids_tag, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets_method(targets, i)

            ids_tag = ids_tag[:, -1:]
            targets_tag = tf.expand_dims(tf.expand_dims(ids_tag, axis=2),
                                         axis=3)
            targets_tag = preprocess_targets_tag_method(targets_tag, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope('body'):
                with tf.variable_scope('edit_ops_layer'):
                    with tf.variable_scope('ffn'):
                        x = targets
                        preproc = lambda z: common_layers.layer_preprocess(
                            z, hparams, layer_collection=None)
                        layer_inputs = [
                            tf.concat(preproc(x), axis=0),
                            tf.concat(preproc(targets_tag), axis=0),
                        ]
                        y = transformer_layers.transformer_ffn_layer(
                            tf.concat(layer_inputs, axis=2),
                            hparams,
                            conv_padding='LEFT',
                            nonpadding_mask=features_to_nonpadding(
                                features, 'targets'),
                            losses=None,
                            cache=cache,
                            decode_loop_step=None,
                            layer_collection=None,
                        )
                        targets = common_layers.layer_postprocess(
                            x, y, hparams)

                if hparams.middle_prediction:
                    num_decoder_layers = (hparams.num_decoder_layers
                                          or hparams.num_hidden_layers)
                    hparams.num_decoder_layers = int(
                        num_decoder_layers /
                        hparams.middle_prediction_layer_factor)

                body_outputs = dp(
                    self.decode,
                    targets,
                    cache.get('encoder_output'),
                    cache.get('encoder_decoder_attention_bias'),
                    bias,
                    hparams,
                    cache,
                    nonpadding=features_to_nonpadding(features, 'targets'),
                )[0]

                body_outputs, logits_tag = dp(
                    self._prediction_cascade_predict,
                    hparams,
                    features_to_nonpadding(features, 'targets'),
                    cache.get('encoder_decoder_attention_bias'),
                    cache.get('encoder_output'),
                    body_outputs,
                )
                logits_tag = logits_tag[0]['targets_error_tag']
                if hparams.middle_prediction:
                    with tf.variable_scope('after_prediction'):
                        body_outputs = dp(
                            self.decode,
                            targets + body_outputs[0],
                            cache.get('encoder_output'),
                            cache.get('encoder_decoder_attention_bias'),
                            bias,
                            hparams,
                            cache,
                            nonpadding=features_to_nonpadding(
                                features, 'targets'),
                        )

            update_decoder_attention_history(cache)

            modality_name = hparams.name.get(
                'targets',
                modalities.get_name(target_modality))(hparams,
                                                      target_vocab_size)
            with tf.variable_scope('targets/' + modality_name):
                top = hparams.top.get('targets',
                                      modalities.get_top(target_modality))
                logits = dp(top, body_outputs, None, hparams,
                            target_vocab_size)[0]

            ret = tf.squeeze(logits, axis=[1, 2])
            if partial_targets is not None:
                vocab_size = tf.shape(ret)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size,
                        0.0,
                        -1e9,
                    )

                ret = tf.cond(
                    tf.less(i, partial_targets_length),
                    forced_logits,
                    lambda: ret,
                )
            logits_tag = tf.squeeze(logits_tag, axis=[1])
            return ret, logits_tag, cache
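
Finally, a NumPy sketch (illustration only) of the cache-growth pattern these symbols_to_logits_fn variants rely on: tensors such as attention_history and body_outputs start with a zero-sized time axis and gain one slice per decode step via concatenation along axis=2.

import numpy as np

batch, heads, enc_len = 2, 4, 7
history = np.zeros((batch, heads, 0, enc_len), dtype=np.float32)
for step in range(3):
    step_weights = np.random.rand(batch, heads, 1, enc_len).astype(np.float32)
    history = np.concatenate([history, step_weights], axis=2)   # grow time axis by 1
assert history.shape == (batch, heads, 3, enc_len)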