def attention_lm_moe_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
    pad_remover (expert_utils.PadRemover): a utility object to remove padding
  """
  targets_pad_mask = common_attention.embedding_to_padding(targets)
  with tf.name_scope("pad_remover"):
    # Because of the shift_right, the <eos> token is considered padding.
    # In practice this doesn't really matter: due to the triangular mask,
    # this token is never attended to.
    pad_remover = expert_utils.PadRemover(targets_pad_mask)

  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepend_inputs_full_attention(
            targets_pad_mask))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias, pad_remover)
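
The two helpers above do most of the work. Below is a minimal NumPy sketch (not the tensor2tensor implementation) of what shift_right_3d and attention_bias_lower_triangle compute, assuming a toy [batch, length, hidden] tensor:

import numpy as np

# Toy decoder targets: batch=1, length=4, hidden=2.
targets = np.arange(8, dtype=np.float32).reshape(1, 4, 2)

# shift_right_3d: prepend a zero vector and drop the last position, so
# position t only sees embeddings of positions < t.
decoder_input = np.concatenate(
    [np.zeros_like(targets[:, :1, :]), targets[:, :-1, :]], axis=1)

# attention_bias_lower_triangle: large negative values above the diagonal,
# added to the attention logits to mask out future positions.
length = targets.shape[1]
bias = np.triu(np.full((length, length), -1e9, dtype=np.float32), k=1)
bias = bias[np.newaxis, np.newaxis, :, :]  # [1, 1, length, length]

print(decoder_input[0, :, 0])  # [0., 0., 2., 4.]
print(bias[0, 0])              # zeros on/below the diagonal, -1e9 above
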
Example #2
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
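
For the packed-dataset branch above, attention_bias_same_segment keeps examples packed into one row from attending to each other. A small NumPy sketch of the idea (illustrative only, not the library code):

import numpy as np

# Two packed examples in one row: segment ids 1,1,1 and 2,2.
segmentation = np.array([[1, 1, 1, 2, 2]], dtype=np.float32)

# Conceptually: positions may only attend to positions carrying the same
# segment id; every other pair gets a large negative bias.
same = (segmentation[:, :, None] == segmentation[:, None, :]).astype(np.float32)
segment_bias = (1.0 - same) * -1e9          # [batch, length, length]
segment_bias = segment_bias[:, None, :, :]  # [batch, 1, length, length]

print(segment_bias[0, 0].astype(int))
# Rows 0-2 can see columns 0-2 only; rows 3-4 can see columns 3-4 only.
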
def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len*channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len*channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the targets so
      # that the number of rows is a multiple of query_shape and the number of
      # columns is a multiple of hparams.img_len*channels.
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding to make sure the size of targets is a multiple of img_len
      # times the number of channels. This is needed for positional encodings
      # and for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1]*x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  x = common_layers.cast_like(x, targets)
  return x, x_shape[1], x_shape[2]
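
At inference the non-raster-scan branch pads the flat target sequence up to a whole number of image rows before reshaping. A NumPy sketch of that padding and reshape, with made-up toy sizes:

import numpy as np

# Toy inference-time targets: [batch, curr_infer_length, 1, 1].
img_len, channels = 4, 3
curr_infer_length = 7
targets = np.ones((2, curr_infer_length, 1, 1), dtype=np.int32)

# Pad the flat sequence up to a multiple of img_len * channels so it can be
# reshaped into whole image rows, mirroring the else branch above.
padding_factor = img_len * channels
pad = -curr_infer_length % padding_factor          # 5 extra positions here
targets = np.pad(targets, [(0, 0), (0, pad), (0, 0), (0, 0)])
targets = targets.reshape(targets.shape[0], -1, img_len, channels)

print(targets.shape)  # (2, 1, 4, 3): one full row of 4 pixels x 3 channels
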
Example #4
def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # During training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1].
  if (hparams.mode == tf.contrib.learn.ModeKeys.INFER and
      hparams.block_raster_scan):
    curr_infer_length = targets_shape[1]
    if hparams.block_raster_scan:
      assert hparams.img_len*channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len*channels
      # Decoding is in block raster scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in raster scan.
      # To make that compatible with our inference pipeline, pad the targets so
      # that the number of rows is a multiple of query_shape and the number of
      # columns is a multiple of hparams.img_len*channels.
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # Add padding to make sure the size of targets is a multiple of img_len
      # times the number of channels. This is needed for positional encodings
      # and for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  if (hparams.dec_attention_type == AttentionType.LOCAL_2D or
      hparams.dec_attention_type == AttentionType.LOCAL_BLOCK):
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1]*x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  return x, x_shape[1], x_shape[2]
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  if hparams.causal_decoder_self_attention:
    # Causal attention.
    if hparams.prepend_mode == "prepend_inputs_full_attention":
      decoder_self_attention_bias = (
          common_attention.attention_bias_prepend_inputs_full_attention(
              common_attention.embedding_to_padding(targets)))
    else:
      decoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(targets)[1]))
  else:
    # Full attention.
    decoder_padding = common_attention.embedding_to_padding(targets)
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(decoder_padding))

  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  elif hparams.pos == "emb":
    decoder_input = common_attention.add_positional_embedding(
        decoder_input, hparams.max_length, "targets_positional_embedding",
        targets_position)

  if hparams.activation_dtype == "bfloat16":
    decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                          tf.bfloat16)
  return (decoder_input, decoder_self_attention_bias)
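
The full-attention branch above builds its bias from the padding pattern of the target embeddings. A NumPy sketch of what embedding_to_padding and attention_bias_ignore_padding compute (conceptually; the library versions operate on tensors):

import numpy as np

# Toy target embeddings: [batch, length, hidden]; all-zero vectors mark padding.
targets = np.array([[[0.5, 0.1], [0.2, 0.3], [0.0, 0.0], [0.0, 0.0]]],
                   dtype=np.float32)

# embedding_to_padding: a position is padding if its embedding is entirely zero.
padding = (np.sum(np.abs(targets), axis=-1) == 0.0).astype(np.float32)

# attention_bias_ignore_padding: large negative bias on padded memory positions,
# broadcast over all query positions -> full (non-causal) attention elsewhere.
bias = (padding * -1e9)[:, np.newaxis, np.newaxis, :]  # [batch, 1, 1, length]

print(padding[0])      # [0., 0., 1., 1.]
print(bias[0, 0, 0])   # [0., 0., -1e9, -1e9]
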
def prepare_decoder(targets, target_space_emb):
    """Prepare decoder."""
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
    target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
    decoder_input = common_layers.shift_right_3d(targets,
                                                 pad_value=target_space_emb)
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
Example #7
def prepare_decoder(targets, target_space_emb):
  """Prepare decoder."""
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
  target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
  decoder_input = common_layers.shift_right_3d(
      targets, pad_value=target_space_emb)
  decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
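
Here shift_right_3d is given a pad_value, so the target-space embedding rather than a zero vector feeds the first decoding step. A NumPy sketch of that behaviour with toy values:

import numpy as np

# Toy targets and a per-batch "target space" embedding used at the first step.
targets = np.arange(12, dtype=np.float32).reshape(2, 3, 2)  # [batch, len, hid]
target_space_emb = np.full((2, 1, 2), 9.0, dtype=np.float32)

# shift_right_3d with pad_value (conceptually): prepend the target-space
# embedding and drop the last target position.
decoder_input = np.concatenate([target_space_emb, targets[:, :-1, :]], axis=1)

print(decoder_input[0, :, 0])  # [9., 0., 2.]
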
Example #8
def transformer_prepare_decoder(targets, hparams, features=None):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))
    if features and "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = features["targets_segmentation"]
        targets_position = features["targets_position"]
        decoder_self_attention_bias += common_attention.attention_bias_same_segment(
            targets_segmentation, targets_segmentation)
    else:
        targets_position = None
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    #if hparams.pos == "timing":
    #  if targets_position is not None:
    #    decoder_input = common_attention.add_timing_signal_1d_given_position(
    #        decoder_input, targets_position)
    #  else:
    #    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    raw_decoder_input = common_layers.shift_right(features['targets_raw'])
    terminal_decoder_bias, nonterminal_decoder_bias = _get_t_nt_bias(
        raw_decoder_input, hparams, decoder_self_attention_bias)
    pop_decoder_bias = _get_pop_bias(raw_decoder_input, hparams)
    raw_decoder_input = tf.squeeze(raw_decoder_input, axis=[-2, -1])
    pos_signals = generate_positional_signals(raw_decoder_input, hparams,
                                              terminal_decoder_bias,
                                              nonterminal_decoder_bias)
    pos_embeddings = generate_positional_embeddings(pos_signals,
                                                    hparams.decoder_pos,
                                                    hparams)
    if "sum" in hparams.decoder_pos_integration:
        decoder_input = decoder_input + pos_embeddings
    elif "ffn" in hparams.decoder_pos_integration:
        with tf.variable_scope("decoder_pos_ffn"):
            decoder_input = tf.concat([decoder_input, pos_embeddings], axis=2)
            decoder_input = transformer_ffn_layer(decoder_input,
                                                  hparams,
                                                  conv_padding="LEFT")
    return (decoder_input, decoder_self_attention_bias, terminal_decoder_bias,
            nonterminal_decoder_bias, pop_decoder_bias, pos_signals)
Example #9
def transformer_prepare_decoder(targets, hparams):
    """Copied from tensor2tensor.models.transformer."""
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
Example #10
    def model_fn_body(self, features):
        hparams = self._hparams
        if not hparams.attention or hparams.attention_architecture != "standard":
            raise ValueError(
                "Layer-by-layer version of TF-NMT only available for "
                "the 'standard' attention model architecture.")
        inputs, inputs_length = usr_utils.get_feature_with_length(
            features, "inputs")
        target_roots, target_roots_length = usr_utils.get_feature_with_length(
            features, "target_roots")
        targets, targets_length = usr_utils.get_feature_with_length(
            features, "targets")
        # We need to do +1 for inference since get_feature_with_length()
        # may not have direct access to sequence lengths and returns
        # a length of 0 for the first inference step.
        if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
            targets_length = targets_length + 1
        # input lengths of 0 breaks things
        inputs_length = tf.maximum(inputs_length, 1)
        target_roots_length = tf.maximum(target_roots_length, 1)

        # Shift targets right to use them as input
        targets = common_layers.shift_right_3d(targets)

        # Manage POP signals
        if hparams.target_root_attention == "pop":
            raw_targets = tf.squeeze(tf.squeeze(features["raw_targets"],
                                                axis=2),
                                     axis=2)
            targets_is_pop = tf.equal(raw_targets, hparams.pop_id)
        else:
            targets_is_pop = None

        iterator = TFNmtLayerbylayerInput(
            initializer=None,
            source=inputs,
            target_input=targets,
            target_input_is_pop=targets_is_pop,
            target_output=None,  # Loss is computed in T2T
            target_root=target_roots,
            source_sequence_length=inputs_length,
            target_sequence_length=targets_length,
            target_root_sequence_length=target_roots_length)
        tfnmt_model = TFNmtLayerbylayerModel(
            hparams_helper.convert_to_tfnmt_hparams(hparams),
            iterator=iterator,
            mode=tf.contrib.learn.ModeKeys.
            EVAL,  # We use eval graph for training
            source_vocab_table=FakeVocabTable(),
            target_vocab_table=FakeVocabTable())
        decoder_output = tfnmt_model.logits
        return tf.expand_dims(decoder_output, axis=2)
def transformer_edit_ops_layer(decoder_input,
                               hparams,
                               encoder_output,
                               features,
                               cache=None,
                               decode_loop_step=None,
                               nonpadding=None,
                               losses=None,
                               layer_collection=None):
  """Layer that conditions on the error tag and start and end token pointers."""
  if isinstance(encoder_output, list):  # Select forward encoder
    encoder_output = encoder_output[0]
  with tf.variable_scope("edit_ops_layer"):
    with tf.variable_scope("ffn"):
      x = decoder_input
      # Shorthand for layer preprocessing
      # pylint: disable=g-long-lambda
      preproc = lambda z: common_layers.layer_preprocess(
          z, hparams, layer_collection=layer_collection)
      # pylint: enable=g-long-lambda

      feedback_start_token = (hparams.use_start_token or
                              not hparams.feedback_end_token)
      if feedback_start_token:
        start_token = _pointer_feedback(
            features["targets_start_token"],
            encoder_output,
            shift=hparams.feedback_end_token)
      if hparams.feedback_end_token:
        end_token = _pointer_feedback(features["targets_end_token"],
                                      encoder_output)
      layer_inputs = [preproc(x)]
      if hparams.use_error_tags:
        error_tags = common_layers.shift_right_3d(
            common_layers.flatten4d3d(features["targets_error_tag"]))
        layer_inputs.append(preproc(error_tags))
      if feedback_start_token:
        layer_inputs.append(start_token)
      if hparams.feedback_end_token:
        layer_inputs.append(end_token)
      y = transformer_layers.transformer_ffn_layer(
          tf.concat(layer_inputs, axis=2),
          hparams,
          conv_padding="LEFT",
          nonpadding_mask=nonpadding,
          losses=losses,
          cache=cache,
          decode_loop_step=decode_loop_step,
          layer_collection=layer_collection)
      x = common_layers.layer_postprocess(x, y, hparams)
      return x
    def body(self, features):
        """
        Args:
            features["inputs"]:
            features["targets"]:
                tensors with shape [batch_size, ..., hidden_size]
        Return:
            decoder_outputs: pre-softmax activations of same size as inputs

        I assume that the input is a time series such that input size is
        [batch_size,sequence_length,hidden_size]
        """

        inputs = features["inputs"]
        targets = features["targets"]

        # tensor2tensor provides 4D tensors in which axis=2 is unused,
        # so we remove it for ease of handling.
        original_shape = common_layers.shape_list(inputs)
        squeeze_shape_inputs = [
            x for x in common_layers.shape_list(inputs) if x != 1]
        squeeze_shape_targets = [
            x for x in common_layers.shape_list(targets) if x != 1]

        # Squeeze unneeded dimensions.
        inputs = tf.reshape(inputs, squeeze_shape_inputs)
        targets = tf.reshape(targets, squeeze_shape_targets)
        decoder_inputs = common_layers.shift_right_3d(targets)

        # The encoder bias causes padding to be ignored.
        inputs_embedding_mask = common_attention.embedding_to_padding(inputs)
        self.encoder_attention_bias = (
            common_attention.attention_bias_ignore_padding(inputs_embedding_mask))
        # The decoder bias lets targets attend only to previous positions
        # (and themselves).
        self.decoder_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                common_layers.shape_list(targets)[1]))

        # Process the encoder and save the result for the decoder to use,
        # then process the decoder.
        self.encoder_outputs = self.adaptive_computation(inputs, self.encode)
        outputs = self.adaptive_computation(decoder_inputs, self.decode)
        # Reshape output back to 4D.
        outputs = tf.reshape(outputs, original_shape)
        return outputs
def transformer_fast_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.
  Args:
    targets: a Tensor.
    hparams: run hyperparameters
  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_position_forward_mask: mask Tensor for position-forward. [1, t, 1]
  """
    length = tf.shape(targets)[1]
    decoder_position_forward_mask = 1. / tf.expand_dims(
        tf.expand_dims(tf.to_float(tf.range(length)) + 1., 0), -1)  # [1, t, 1]

    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_position_forward_mask)
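
decoder_position_forward_mask holds 1/(t+1) per position. A NumPy sketch of its values and of one way such a mask can be used, turning a cumulative sum into a running mean; whether the surrounding model uses it exactly like this depends on the rest of the code:

import numpy as np

# 1 / (position index + 1), shaped [1, t, 1] as in the function above.
length = 5
mask = 1.0 / (np.arange(length, dtype=np.float32) + 1.0)[np.newaxis, :, np.newaxis]

x = np.ones((1, length, 2), dtype=np.float32)   # toy decoder states
running_mean = np.cumsum(x, axis=1) * mask      # cumulative sum -> running mean

print(mask[0, :, 0])          # [1., 0.5, 0.333..., 0.25, 0.2]
print(running_mean[0, :, 0])  # all ones: the mean of identical vectors
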
Example #14
def transformer_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    return (decoder_input, decoder_self_attention_bias)
Example #15
def transformer_edit_ops_layer(
    decoder_input,
    hparams,
    encoder_output,
    features,
    cache=None,
    decode_loop_step=None,
    nonpadding=None,
    losses=None,
    layer_collection=None,
):
    """Layer that conditions on the error tag and start and end token pointers."""
    if isinstance(encoder_output, list):  # Select forward encoder
        encoder_output = encoder_output[0]
    with tf.variable_scope('edit_ops_layer'):
        with tf.variable_scope('ffn'):
            x = decoder_input
            # Shorthand for layer preprocessing
            # pylint: disable=g-long-lambda
            preproc = lambda z: common_layers.layer_preprocess(
                z, hparams, layer_collection=layer_collection)
            # pylint: enable=g-long-lambda
            layer_inputs = [preproc(x)]
            error_tags = common_layers.shift_right_3d(
                common_layers.flatten4d3d(features['targets_error_tag']))
            layer_inputs.append(preproc(error_tags))
            y = transformer_layers.transformer_ffn_layer(
                tf.concat(layer_inputs, axis=2),
                hparams,
                conv_padding='LEFT',
                nonpadding_mask=nonpadding,
                losses=losses,
                cache=cache,
                decode_loop_step=decode_loop_step,
                layer_collection=layer_collection,
            )
            x = common_layers.layer_postprocess(x, y, hparams)
            return x
Example #16
def attention_lm_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
  """
  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepended(
            common_attention.embedding_to_padding(targets)))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def transformer_prepare_decoder(targets, hparams, features=None):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))
    if features and "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = features["targets_segmentation"]
        targets_position = features["targets_position"]
        decoder_self_attention_bias += common_attention.attention_bias_same_segment(
            targets_segmentation, targets_segmentation)
    else:
        targets_position = None
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        if targets_position is not None:
            decoder_input = common_attention.add_timing_signal_1d_given_position(
                decoder_input, targets_position)
        else:
            decoder_input = common_attention.add_timing_signal_1d(
                decoder_input)
    return (decoder_input, decoder_self_attention_bias)
Example #18
def attention_lm_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
  """
  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepended(
            common_attention.embedding_to_padding(targets)))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
Example #19
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  # decoder_input = tf.Print(decoder_input, [tf.shape(decoder_input)], 
  #     summarize=1000, message="decoder_input")
  # decoder_self_attention_bias = tf.Print(decoder_self_attention_bias, [tf.shape(decoder_self_attention_bias)], 
  #     summarize=1000, message="decoder_self_attention_bias")
  return (decoder_input, decoder_self_attention_bias)
Example #20
  def body(self, features):
    hparams = self.hparams
    is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
    # Run the basic autoencoder part first.
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if hparams.autoregressive_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses
    shape = common_layers.shape_list(basic_result)
    basic1d = tf.reshape(basic_result, [shape[0], -1, shape[3]])
    # During autoregressive inference, don't resample.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hasattr(hparams, "sampled_basic1d_tensor"):
        basic1d = hparams.sampled_basic1d_tensor
      else:
        hparams.sampled_basic1d_tensor = basic1d
    # Prepare inputs for autoregressive modes.
    if common_layers.shape_list(features["targets"])[1] == 1:
      # This happens on the first step of predictions.
      assert hparams.mode == tf.estimator.ModeKeys.PREDICT
      features["targets"] = tf.zeros_like(basic_result)
    targets_dropout = common_layers.mix(
        features["targets"],
        tf.zeros_like(basic_result),
        hparams.bottleneck_warmup_steps,
        is_training,
        max_prob=1.0 - hparams.autoregressive_dropout,
        broadcast_last=True)
    # Sometimes it's useful to look at non-autoregressive evals.
    if (hparams.mode == tf.estimator.ModeKeys.EVAL and
        hparams.autoregressive_eval_pure_autoencoder):
      targets_dropout = tf.zeros_like(basic_result)
    # Now combine the basic reconstruction with shifted targets.
    targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[3]])
    targets_shifted = common_layers.shift_right_3d(targets1d)
    concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
    # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
    if hparams.autoregressive_forget_base:
      concat1d = tf.reshape(features["targets"], [shape[0], -1, shape[3]])
      concat1d = common_layers.shift_right_3d(concat1d)
    # The autoregressive part depends on the mode.
    if hparams.autoregressive_mode == "conv3":
      res = common_layers.conv1d(
          concat1d,
          shape[3],
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv3")
      return tf.reshape(res, shape), losses
    if hparams.autoregressive_mode == "conv5":
      res = common_layers.conv1d(
          concat1d,
          shape[3],
          5,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv5")
      return tf.reshape(res, shape), losses
    if hparams.autoregressive_mode == "sru":
      res = common_layers.conv1d(
          concat1d,
          shape[3],
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_sru_conv3")
      res = common_layers.sru(res)
      return tf.reshape(res, shape), losses

    raise ValueError(
        "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
Example #21
    def _build_decoder_agreement_loss(self, central_lang_tag="<en>"):
        """Builds an agreement loss that enforces consistency of the decodings.

    Args:
      central_lang_tag: A string with the tag of the central language.
        A ``central'' language (usually English) is the one that has parallel
        data with all other languages. It is used to protect supervised
        directions from gradients coming from auxiliary losses.

    Returns:
      loss: <float32> [] for the agreement losses.
    """
        # Get target embeddings and vocab size.
        target_modality = self._problem_hparams.modality["targets"]
        target_modality_scope = self._variable_scopes[target_modality.name]
        target_embeddings = model_utils.get_embeddings(
            modality=target_modality,
            outer_scope=target_modality_scope,
            inner_scope="shared")
        target_vocab_size = target_modality._vocab_size  # pylint: disable=protected-access

        # Build auxiliary sequences (if necessary).
        aux_keys = self._build_aux_sequences(target_embeddings,
                                             target_vocab_size,
                                             central_lang_tag=central_lang_tag)

        # Build loss.
        aux_loss = 0.
        with tf.name_scope("dec_agreement_loss"):
            for key1, key2 in zip(aux_keys, aux_keys[::-1]):
                # Prepare for decoding.
                targets = self.dec_outputs[key2]["rnn_output"]
                targets_length = self.dec_outputs[key2]["length"]
                shifted_targets = common_layers.shift_right_3d(targets)
                hiddens = self.enc_outputs[key1].outputs
                hiddens_length = self.inputs[key1][1]
                enc_state = self.enc_outputs[key1].final_state
                # Decode.
                decode_func = self.get_decode_func(
                    target_embeddings,
                    shifted_targets,
                    targets_length,
                    hiddens,
                    hiddens_length,
                    enc_state,
                    mode=tf.estimator.ModeKeys.PREDICT,
                    decoder_iterations=self._hparams.aux_decode_length)
                aux_dec_outputs = decode_func()
                # Compute logits (protect central directions from the gradients).
                aux_logits_1 = model_utils.build_logits(
                    sequences=tf.expand_dims(aux_dec_outputs["rnn_output"],
                                             axis=2),
                    embeddings=target_embeddings,
                    vocab_size=target_vocab_size)
                aux_logits_1 = tf.where(self._is_central[key1],
                                        tf.stop_gradient(aux_logits_1),
                                        aux_logits_1)
                # Compute KL loss.
                logits = tf.squeeze(aux_logits_1, axis=2)
                if self._hparams.dec_agreement_loss_sparse:
                    target_ids = self.dec_outputs[key2]["sample_id"]
                    aux_loss = aux_loss + losses.CrossEntropyLoss(sparse=True)(
                        logits, target_ids, targets_length)
                else:
                    aux_logits_2 = tf.squeeze(self.dec_outputs[key2]["logits"],
                                              axis=2)
                    target_probs = tf.nn.softmax(aux_logits_2, axis=-1)
                    aux_loss = aux_loss + losses.CrossEntropyLoss(
                        sparse=False)(logits, target_probs, targets_length)

        aux_loss = self._hparams.dec_agreement_coeff * aux_loss

        return aux_loss
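
The dense branch matches one decoder's predicted distribution against another's with a cross-entropy over soft targets. A NumPy sketch of that quantity (not the losses.CrossEntropyLoss class itself):

import numpy as np

def softmax(z):
  e = np.exp(z - z.max(axis=-1, keepdims=True))
  return e / e.sum(axis=-1, keepdims=True)

logits_1 = np.array([[2.0, 0.5, -1.0]])   # distribution from decoder 1
logits_2 = np.array([[1.5, 1.0, -0.5]])   # distribution from decoder 2

# Cross-entropy of decoder 1's predictions against decoder 2's soft targets.
target_probs = softmax(logits_2)
log_probs = np.log(softmax(logits_1))
xent = -np.sum(target_probs * log_probs, axis=-1)

print(xent)  # per-position cross-entropy between the two decoders
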
Example #22
  def body(self, features):
    hparams = self.hparams
    # Run the basic autoencoder part first.
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if hparams.autoregressive_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses
    if "training" in losses:
      plain_training_loss = losses.pop("training")
      losses["plain"] = plain_training_loss
    res_shape = common_layers.shape_list(basic_result)
    vocab_size = self._problem_hparams.vocab_size["targets"]
    if hasattr(self._hparams, "vocab_divisor"):
      vocab_size += (-vocab_size) % self._hparams.vocab_divisor
    targets = tf.one_hot(features["targets_raw"], vocab_size)
    # Prepare inputs for autoregressive modes.
    if common_layers.shape_list(features["targets"])[1] == 1:
      # This happens on the first step of predictions.
      assert hparams.mode == tf.estimator.ModeKeys.PREDICT
      targets = tf.zeros_like(basic_result)
    targets = self.embed(targets)
    if hparams.autoregressive_gumbel_sample:
      basic_hot = self.gumbel_sample(basic_result)
    else:
      basic_hot = basic_result
    basic_result = self.embed(basic_hot)
    shape = common_layers.shape_list(basic_result)
    basic1d = tf.reshape(basic_result, [shape[0], -1, shape[-1]])
    targets = tf.reshape(targets, common_layers.shape_list(basic_result))
    # During autoregressive inference, don't resample.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hasattr(hparams, "sampled_basic1d_tensor"):
        basic1d = hparams.sampled_basic1d_tensor
      else:
        hparams.sampled_basic1d_tensor = basic1d
    # Sometimes it's useful to look at non-autoregressive evals.
    targets_dropout = targets
    if (hparams.mode == tf.estimator.ModeKeys.EVAL and
        hparams.autoregressive_eval_pure_autoencoder):
      targets_dropout = tf.zeros_like(basic_result)
    # Now combine the basic reconstruction with shifted targets.
    targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[-1]])
    targets_shifted = common_layers.shift_right_3d(targets1d)
    concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
    # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
    if hparams.autoregressive_forget_base:
      concat1d = tf.reshape(targets, [shape[0], -1, shape[-1]])
      concat1d = common_layers.shift_right_3d(concat1d)
    # The autoregressive part depends on the mode.
    if hparams.autoregressive_mode == "conv3":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv3")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "conv5":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          5,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv5")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "sru":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_sru_conv3")
      res = common_layers.sru(res)
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses

    raise ValueError(
        "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)
Example #23
  def body(self, features):
    hparams = self.hparams
    # Run the basic autoencoder part first.
    basic_result, losses = super(AutoencoderAutoregressive, self).body(features)
    if hparams.autoregressive_mode == "none":
      assert not hparams.autoregressive_forget_base
      return basic_result, losses
    if "training" in losses:
      plain_training_loss = losses.pop("training")
      losses["plain"] = plain_training_loss
    res_shape = common_layers.shape_list(basic_result)
    vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
    targets = tf.one_hot(features["targets_raw"], vocab_size)
    # Prepare inputs for autoregressive modes.
    if common_layers.shape_list(features["targets"])[1] == 1:
      # This happens on the first step of predictions.
      assert hparams.mode == tf.estimator.ModeKeys.PREDICT
      targets = tf.zeros_like(basic_result)
    targets = self.embed(targets)
    if hparams.autoregressive_gumbel_sample:
      basic_hot = self.gumbel_sample(basic_result)
    else:
      basic_hot = basic_result
    basic_result = self.embed(basic_hot)
    shape = common_layers.shape_list(basic_result)
    basic1d = tf.reshape(basic_result, [shape[0], -1, shape[-1]])
    targets = tf.reshape(targets, common_layers.shape_list(basic_result))
    # During autoregressive inference, don't resample.
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
      if hasattr(hparams, "sampled_basic1d_tensor"):
        basic1d = hparams.sampled_basic1d_tensor
      else:
        hparams.sampled_basic1d_tensor = basic1d
    # Sometimes it's useful to look at non-autoregressive evals.
    targets_dropout = targets
    if (hparams.mode == tf.estimator.ModeKeys.EVAL and
        hparams.autoregressive_eval_pure_autoencoder):
      targets_dropout = tf.zeros_like(basic_result)
    # Now combine the basic reconstruction with shifted targets.
    targets1d = tf.reshape(targets_dropout, [shape[0], -1, shape[-1]])
    targets_shifted = common_layers.shift_right_3d(targets1d)
    concat1d = tf.concat([basic1d, targets_shifted], axis=-1)
    # The forget_base hparam sets purely-autoregressive mode, no autoencoder.
    if hparams.autoregressive_forget_base:
      concat1d = tf.reshape(targets, [shape[0], -1, shape[-1]])
      concat1d = common_layers.shift_right_3d(concat1d)
    # The autoregressive part depends on the mode.
    if hparams.autoregressive_mode == "conv3":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv3")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "conv5":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          5,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_conv5")
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses
    if hparams.autoregressive_mode == "sru":
      res = common_layers.conv1d(
          concat1d,
          hparams.hidden_size,
          3,
          padding="LEFT",
          activation=common_layers.belu,
          name="autoregressive_sru_conv3")
      res = common_layers.sru(res)
      res = tf.layers.dense(res, vocab_size, name="autoregressive_final")
      return tf.reshape(res, res_shape), losses

    raise ValueError(
        "Unsupported autoregressive mode: %s" % hparams.autoregressive_mode)