Example #1
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
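
# The snippet above builds two ingredients: a causal self-attention bias and a
# right-shifted decoder input. A minimal NumPy sketch of both; the -1e9 constant
# and the [1, 1, length, length] bias shape follow what
# attention_bias_lower_triangle and shift_right_3d produce, but this is only an
# illustration, not the tensor2tensor implementation.
import numpy as np

targets_toy = np.arange(8, dtype=np.float32).reshape(1, 4, 2)  # [batch, length, hidden]
length_toy = targets_toy.shape[1]

# 0 on and below the diagonal, a large negative number above it, so that
# softmax(logits + bias) assigns ~zero weight to future positions.
bias_toy = -1e9 * np.triu(np.ones((length_toy, length_toy), dtype=np.float32), k=1)
bias_toy = bias_toy.reshape(1, 1, length_toy, length_toy)  # broadcast over batch and heads

# Shift right: position i of the decoder input holds the embedding of token i-1,
# with zeros at position 0.
decoder_input_toy = np.concatenate(
    [np.zeros_like(targets_toy[:, :1, :]), targets_toy[:, :-1, :]], axis=1)

print(bias_toy[0, 0])
print(decoder_input_toy[0])
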
def attention_lm_moe_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
    pad_remover (expert_utils.PadRemover): a util object to remove padding
  """
  targets_pad_mask = common_attention.embedding_to_padding(targets)
  with tf.name_scope("pad_remover"):
    # Because of the shift_right, the <eos> token will be considered as
    # padding. In practice it doesn't really matter: due to the triangular
    # mask, this token should never be attended to.
    pad_remover = expert_utils.PadRemover(targets_pad_mask)

  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepend_inputs_full_attention(
            targets_pad_mask))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias, pad_remover)
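
# The PadRemover above is driven by a padding mask derived from the target
# embeddings: a position counts as padding when its embedding vector is all
# zeros. A rough NumPy sketch of that mask (embedding_to_padding in
# common_attention is the real implementation; this only illustrates the idea).
import numpy as np

emb = np.array([[[0.3, -0.1, 0.5],
                 [0.2, 0.2, 0.1],
                 [-0.4, 0.0, 0.6],
                 [0.0, 0.0, 0.0]]], dtype=np.float32)  # last position is padding

pad_mask = (np.abs(emb).sum(axis=-1) == 0.0).astype(np.float32)
print(pad_mask)  # [[0. 0. 0. 1.]]
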
    def decode_inputs_to_outputs(self, decoder_embed_inputs, encoder_outputs,
                                 encoder_attn_bias, rule_id_input_placeholder,
                                 mem_contexts, mem_outputs, global_step):
        if self.hparams.pos == 'timing':
            decoder_embed_inputs = common_attention.add_timing_signal_1d(
                decoder_embed_inputs)
            print('Use positional encoding in decoder text.')

        decoder_attn_bias = common_attention.attention_bias_lower_triangle(
            tf.shape(decoder_embed_inputs)[1])
        decoder_embed_inputs = tf.nn.dropout(
            decoder_embed_inputs,
            1.0 - self.hparams.layer_prepostprocess_dropout)

        if 'rule' in self.model_config.memory:
            decoder_output, contexts = transformer.transformer_decoder2(
                decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
                encoder_attn_bias, self.hparams)

            # encoder_gate_w = tf.get_variable('encoder_gate_w', shape=(
            #     1, self.model_config.dimension, 1))
            # encoder_gate_b = tf.get_variable('encoder_gate_b', shape=(1, 1, 1))
            # encoder_gate = tf.tanh(encoder_gate_b + tf.nn.conv1d(encoder_outputs, encoder_gate_w, 1, 'SAME'))
            # encoder_context_outputs = tf.expand_dims(tf.reduce_mean(encoder_outputs * encoder_gate, axis=1), axis=1)
            cur_context = contexts[0]  #tf.concat(contexts, axis=-1)
            cur_mem_contexts = tf.stack(self.embedding_fn(
                rule_id_input_placeholder, mem_contexts),
                                        axis=1)
            cur_mem_outputs = tf.stack(self.embedding_fn(
                rule_id_input_placeholder, mem_outputs),
                                       axis=1)

            bias = tf.expand_dims(-1e9 * tf.to_float(
                tf.equal(tf.stack(rule_id_input_placeholder, axis=1), 0)),
                                  axis=1)
            weights = tf.nn.softmax(
                bias +
                tf.matmul(cur_context, cur_mem_contexts, transpose_b=True))
            mem_output = tf.matmul(weights, cur_mem_outputs)

            temp_output = tf.concat((decoder_output, mem_output), axis=-1)
            w = tf.get_variable('w_ffn',
                                shape=(1, self.model_config.dimension * 2,
                                       self.model_config.dimension))
            # b = tf.get_variable('b_ffn', shape=(
            #     1, 1, self.model_config.dimension))
            mem_output = tf.nn.conv1d(temp_output, w, 1, 'SAME')
            g = tf.greater(
                global_step,
                tf.constant(2 * self.model_config.memory_prepare_step,
                            dtype=tf.int64))
            final_output = tf.cond(g, lambda: mem_output,
                                   lambda: decoder_output)
            return final_output, decoder_output, cur_context
        else:
            decoder_output = transformer.transformer_decoder(
                decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
                encoder_attn_bias, self.hparams)
            final_output = decoder_output
            return final_output, decoder_output, None
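
# The rule-memory branch above is a single attention step over retrieved memory
# entries: the decoder context queries the stored contexts, a bias masks out
# empty slots (rule id 0), and the softmax weights mix the stored outputs. A
# small NumPy sketch of that computation; shapes and values are illustrative,
# not taken from the model config.
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

batch, q_len, mem_size, dim = 1, 1, 3, 4
cur_context = np.random.rand(batch, q_len, dim).astype(np.float32)
cur_mem_contexts = np.random.rand(batch, mem_size, dim).astype(np.float32)
cur_mem_outputs = np.random.rand(batch, mem_size, dim).astype(np.float32)

rule_ids = np.array([[5, 9, 0]])  # the last memory slot is empty (id 0)
bias = -1e9 * (rule_ids == 0).astype(np.float32)[:, None, :]  # [batch, 1, mem]

scores = cur_context @ cur_mem_contexts.transpose(0, 2, 1) + bias
weights = softmax(scores)               # the empty slot gets ~zero weight
mem_output = weights @ cur_mem_outputs  # [batch, q_len, dim]
print(weights, mem_output.shape)
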
Example #4
def _apply_decoder_layer(translation_layer, input_tensor, output_depth,
                         encoder_depth):
    """Applies an decoder layer with basic arguments."""

    residual_tensor_values = np.random.rand(
        *[_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, output_depth]) - .5
    residual_tensor = tf.constant(residual_tensor_values, dtype=tf.float32)
    encoder_output_values = np.random.rand(
        *[_BATCH_SIZE, _TOTAL_SEQUENCE_LENGTH, encoder_depth]) - .5
    encoder_output = tf.constant(encoder_output_values, dtype=tf.float32)
    encoder_block_outputs = [encoder_output] * _NUM_BLOCKS
    hparams = transformer.transformer_base()
    hparams.attention_dropout = 0
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(_TOTAL_SEQUENCE_LENGTH))

    output_tensor = translation_layer.apply_layer(
        input_tensor,
        residual_tensor,
        output_depth,
        None,
        hparams,
        "",
        nonpadding=None,
        mask_future=True,
        layer_preprocess_fn=None,
        postprocess_dropout=False,
        decoder_self_attention_bias=decoder_self_attention_bias,
        encoder_decoder_attention_bias=None,
        encoder_block_outputs=encoder_block_outputs,
        block_number=_BLOCK_NUMBER)

    return output_tensor
def attention_lm_moe_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
    pad_remover (expert_utils.PadRemover): a util object to remove padding
  """
    targets_pad_mask = common_attention.embedding_to_padding(targets)
    with tf.name_scope("pad_remover"):
        # Because of the shift_right, the <eos> token will be considered as
        # padding. In practice it doesn't really matter: due to the triangular
        # mask, this token should never be attended to.
        pad_remover = expert_utils.PadRemover(targets_pad_mask)

    if hparams.prepend_mode == "prepend_inputs_full_attention":
        decoder_self_attention_bias = (
            common_attention.attention_bias_prepend_inputs_full_attention(
                targets_pad_mask))
    else:
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                tf.shape(targets)[1]))
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias, pad_remover)
Example #6
    def test_nas_decoder_resizing_output(self):
        hparams, wrong_size = self._get_wrong_output_dim_decoder_hparams()
        hparams.enforce_output_size = False
        input_tensor = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
        with tf.variable_scope("wrong"):
            wrong_size_decoder_output = translation_nas_net.nas_decoder(
                decoder_input=input_tensor,
                encoder_cell_outputs=[input_tensor] *
                hparams.encoder_num_cells,
                decoder_self_attention_bias=decoder_self_attention_bias,
                encoder_decoder_attention_bias=None,
                hparams=hparams)

        # Now add the correction.
        hparams.enforce_output_size = True
        with tf.variable_scope("correct"):
            correct_size_decoder_output = translation_nas_net.nas_decoder(
                decoder_input=input_tensor,
                encoder_cell_outputs=[input_tensor] *
                hparams.encoder_num_cells,
                decoder_self_attention_bias=decoder_self_attention_bias,
                encoder_decoder_attention_bias=None,
                hparams=hparams)

        with self.test_session() as session:
            session.run(tf.global_variables_initializer())
            wrong_output, correct_output = session.run(
                [wrong_size_decoder_output, correct_size_decoder_output])
        self.assertEqual(wrong_output.shape,
                         (_BATCH_SIZE, _INPUT_LENGTH, wrong_size))
        self.assertEqual(correct_output.shape,
                         (_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH))
Example #7
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(
          common_layers.shape_list(targets)[1]))
  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
def attention_lm_moe_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
    pad_remover (expert_utils.PadRemover): a util object to remove padding
  """
    targets_pad_mask = common_attention.embedding_to_padding(targets)
    with tf.name_scope("pad_remover"):
        pad_remover = expert_utils.PadRemover(targets_pad_mask)

    if hparams.prepend_mode == "prepend_inputs_full_attention":
        decoder_self_attention_bias = (
            common_attention.attention_bias_prepended(targets_pad_mask))
    else:
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(
                tf.shape(targets)[1]))
    decoder_input = common_layers.shift_left_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias, pad_remover)
def prepare_decoder(targets, hparams):
  """Prepare decoder for images."""
  targets_shape = common_layers.shape_list(targets)
  channels = hparams.num_channels
  curr_infer_length = None

  # during training, images are [batch, IMG_LEN, IMG_LEN, 3].
  # At inference, they are [batch, curr_infer_length, 1, 1]
  if hparams.mode == tf.contrib.learn.ModeKeys.INFER:
    curr_infer_length = targets_shape[1]
    if hparams.block_rastor_scan:
      assert hparams.img_len*channels % hparams.query_shape[1] == 0
      assert hparams.img_len % hparams.query_shape[0] == 0
      total_block_width = hparams.img_len*channels
      # Decoding is in block rastor scan order. We divide the image into
      # hparams.query_shape blocks and then decode each block in rastor scan.
      # To make that compatible with our inference pipeline, pad the target so
      # that rows is a multiple of query_shape and columns is a multiple of
      # hparams.img_len*channels
      curr_infer_length = targets_shape[1]
      block_padding_factor = total_block_width * hparams.query_shape[0]
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % block_padding_factor],
          [0, 0], [0, 0]])

      num_blocks = total_block_width // hparams.query_shape[1]
      # Reshape the image to represent blocks
      target_blocks = tf.reshape(
          targets, [targets_shape[0], -1, num_blocks, hparams.query_shape[0],
                    hparams.query_shape[1]])
      # Transpose to read the image in 2D fashion.
      targets = tf.transpose(target_blocks, [0, 1, 3, 2, 4])
    else:
      # add padding to make sure the size of targets is a multiple of img_height
      # times number of channels. This is  needed for positional encodings and
      # for doing the RGB lookup.
      padding_factor = channels * hparams.img_len
      targets = tf.pad(targets, [
          [0, 0], [0, -curr_infer_length % padding_factor], [0, 0], [0, 0]])
    targets = tf.reshape(targets,
                         [targets_shape[0], -1, hparams.img_len, channels])
  # Preprocess image
  x = prepare_image(targets, hparams, name="dec_channels")
  x_shape = common_layers.shape_list(x)
  # mask out upper triangle to avoid looking into the future.
  bias = common_attention.attention_bias_lower_triangle(x_shape[1]*x_shape[2])
  if hparams.dec_attention_type == AttentionType.LOCAL_2D:
    x = common_attention.right_shift_blockwise(x, hparams.query_shape)
    x = add_pos_signals(x, hparams, "dec_pos")
  else:
    # Add position signals
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1]*x_shape[2], hparams.hidden_size])
    x = common_layers.shift_right_3d(x)
    x = tf.reshape(x, [targets_shape[0],
                       x_shape[1], x_shape[2], hparams.hidden_size])
    x = add_pos_signals(x, hparams, "dec_pos")
  return x, x_shape[1], x_shape[2], bias
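
# The padding above relies on the negative modulo trick: -length % factor is
# exactly the amount needed to round length up to the next multiple of factor
# (Python's % and TensorFlow's floormod agree here). A quick check with
# illustrative sizes, not the real hparams values:
img_len, channels, query_rows = 32, 3, 4
padding_factor = img_len * channels                 # non-block case: 96
block_padding_factor = padding_factor * query_rows  # block raster-scan case: 384

for curr_infer_length in (1, 95, 96, 97):
    pad = -curr_infer_length % padding_factor
    block_pad = -curr_infer_length % block_padding_factor
    assert (curr_infer_length + pad) % padding_factor == 0
    assert (curr_infer_length + block_pad) % block_padding_factor == 0
    print(curr_infer_length, pad, block_pad)
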
def transformer_prepare_decoder(targets, hparams, features=None):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  if hparams.causal_decoder_self_attention:
    # Causal attention.
    if hparams.prepend_mode == "prepend_inputs_full_attention":
      decoder_self_attention_bias = (
          common_attention.attention_bias_prepend_inputs_full_attention(
              common_attention.embedding_to_padding(targets)))
    else:
      decoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(targets)[1]))
  else:
    # Full attention.
    decoder_padding = common_attention.embedding_to_padding(targets)
    decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(decoder_padding))

  if features and "targets_segmentation" in features:
    # "Packed" dataset - keep the examples from seeing each other.
    targets_segmentation = features["targets_segmentation"]
    targets_position = features["targets_position"]
    decoder_self_attention_bias += common_attention.attention_bias_same_segment(
        targets_segmentation, targets_segmentation)
  else:
    targets_position = None
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    if targets_position is not None:
      decoder_input = common_attention.add_timing_signal_1d_given_position(
          decoder_input, targets_position)
    else:
      decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  elif hparams.pos == "emb":
    decoder_input = common_attention.add_positional_embedding(
        decoder_input, hparams.max_length, "targets_positional_embedding",
        targets_position)

  if hparams.activation_dtype == "bfloat16":
    decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                          tf.bfloat16)
  return (decoder_input, decoder_self_attention_bias)
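
# With causal_decoder_self_attention disabled, the bias only hides padded key
# positions instead of the future; both variants are additive masks broadcast
# against the attention logits. A NumPy sketch of the two patterns with
# illustrative values (the library versions differ in exact shape/broadcasting):
import numpy as np

length = 4
pad = np.array([0., 0., 0., 1.])  # last position is padding

causal_bias = -1e9 * np.triu(np.ones((length, length)), k=1)     # mask the future
ignore_padding_bias = -1e9 * np.tile(pad[None, :], (length, 1))  # mask padded keys

print(causal_bias)
print(ignore_padding_bias)
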
def prepare_decoder(targets, target_space_emb):
    """Prepare decoder."""
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
    target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
    decoder_input = common_layers.shift_right_3d(targets,
                                                 pad_value=target_space_emb)
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
Example #12
def prepare_decoder(targets, target_space_emb):
  """Prepare decoder."""
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  target_space_emb = tf.reshape(target_space_emb, [1, 1, -1])
  target_space_emb = tf.tile(target_space_emb, [tf.shape(targets)[0], 1, 1])
  decoder_input = common_layers.shift_right_3d(
      targets, pad_value=target_space_emb)
  decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
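
# add_timing_signal_1d adds the standard sinusoidal position encoding to the
# shifted decoder input. A compact NumPy version of that signal, following the
# usual Transformer formulation; the library version has additional options
# (start index, timescale bounds, odd channel counts) not shown here.
import numpy as np

def timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Sinusoidal position signal of shape [1, length, channels] (channels even)."""
    position = np.arange(length, dtype=np.float32)
    num_timescales = channels // 2
    log_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * np.exp(
        -np.arange(num_timescales, dtype=np.float32) * log_increment)
    scaled_time = position[:, None] * inv_timescales[None, :]
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    return signal[None, :, :]

x = np.zeros((1, 6, 8), dtype=np.float32)   # [batch, length, hidden]
print((x + timing_signal_1d(6, 8)).shape)   # (1, 6, 8)
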
Example #13
def transformer_prepare_decoder(targets, hparams, features=None):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))
    if features and "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = features["targets_segmentation"]
        targets_position = features["targets_position"]
        decoder_self_attention_bias += common_attention.attention_bias_same_segment(
            targets_segmentation, targets_segmentation)
    else:
        targets_position = None
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    #if hparams.pos == "timing":
    #  if targets_position is not None:
    #    decoder_input = common_attention.add_timing_signal_1d_given_position(
    #        decoder_input, targets_position)
    #  else:
    #    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    raw_decoder_input = common_layers.shift_right(features['targets_raw'])
    terminal_decoder_bias, nonterminal_decoder_bias = _get_t_nt_bias(
        raw_decoder_input, hparams, decoder_self_attention_bias)
    pop_decoder_bias = _get_pop_bias(raw_decoder_input, hparams)
    raw_decoder_input = tf.squeeze(raw_decoder_input, axis=[-2, -1])
    pos_signals = generate_positional_signals(raw_decoder_input, hparams,
                                              terminal_decoder_bias,
                                              nonterminal_decoder_bias)
    pos_embeddings = generate_positional_embeddings(pos_signals,
                                                    hparams.decoder_pos,
                                                    hparams)
    if "sum" in hparams.decoder_pos_integration:
        decoder_input = decoder_input + pos_embeddings
    elif "ffn" in hparams.decoder_pos_integration:
        with tf.variable_scope("decoder_pos_ffn"):
            decoder_input = tf.concat([decoder_input, pos_embeddings], axis=2)
            decoder_input = transformer_ffn_layer(decoder_input,
                                                  hparams,
                                                  conv_padding="LEFT")
    return (decoder_input, decoder_self_attention_bias, terminal_decoder_bias,
            nonterminal_decoder_bias, pop_decoder_bias, pos_signals)
Example #14
def transformer_prepare_decoder(targets, hparams):
    """Copied from tensor2tensor.models.transformer."""
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
Example #15
    def testMultiheadSelfAttentionMemoryEfficient(self):
        if tf.executing_eagerly():
            return  # don't run test in Eager mode

        num_heads = 4
        io_size = 16
        batch = 2
        length = 7
        head_size = 5
        x = np.random.rand(batch, length, io_size)
        dy = np.random.rand(batch, length, io_size)
        with self.session() as session:
            x = tf.to_float(x)
            dy = tf.to_float(dy)
            bias = common_attention.attention_bias_lower_triangle(length)
            wqkv = tf.get_variable(
                "wqkv", [num_heads, 1, io_size, 3 * head_size],
                initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
            wo = tf.get_variable("wo", [num_heads, 1, head_size, io_size],
                                 initializer=tf.random_normal_initializer(
                                     stddev=(head_size * num_heads)**-0.5))
            norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
            y = common_attention.multihead_self_attention_memory_efficient(
                x,
                bias,
                num_heads,
                head_size=head_size,
                forget=False,
                test_vars=(wqkv, wo, norm_scale, norm_bias))
            y_forget = common_attention.multihead_self_attention_memory_efficient(
                x,
                bias,
                num_heads,
                head_size=head_size,
                forget=True,
                test_vars=(wqkv, wo, norm_scale, norm_bias))
            dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
                ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
            dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
                ys=[y_forget],
                xs=[x, wqkv, wo, norm_scale, norm_bias],
                grad_ys=[dy])
            session.run(tf.global_variables_initializer())
            (y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f,
             dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run([
                 y, y_forget, dx, dwqkv, dwo, dnorm_scale, dnorm_bias, dx_f,
                 dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f
             ])
        self.assertAllClose(y, y_forget)
        self.assertAllClose(dwo, dwo_f)
        self.assertAllClose(dwqkv, dwqkv_f)
        self.assertAllClose(dnorm_scale, dnorm_scale_f)
        self.assertAllClose(dnorm_bias, dnorm_bias_f)
        self.assertAllClose(dx, dx_f)
 def decode_inputs_to_outputs(self, kword_input, abstr_outputs, abstr_bias, hist_vector=None):
     if self.hparams.pos == 'timing':
         kword_input = common_attention.add_timing_signal_1d(kword_input)
     kword_tribias = common_attention.attention_bias_lower_triangle(tf.shape(kword_input)[1])
     kword_input = tf.nn.dropout(
         kword_input, 1.0 - self.hparams.layer_prepostprocess_dropout)
     kword_output = transformer.transformer_decoder(
         kword_input, abstr_outputs, kword_tribias,
         abstr_bias, self.hparams,
         hist_vector=hist_vector)
     return kword_output
def decode(cond_vec, cond_add, gold, c, ed, hparams):
    """Transformer decoder."""
    drop_gold = tf.nn.dropout(gold, 1.0 - hparams.layer_prepostprocess_dropout)
    decoder_input = common_layers.shift_right(drop_gold, pad_value=cond_vec)
    if cond_add is not None:
        decoder_input += cond_add
    decoder_input = tf.squeeze(decoder_input, axis=2)
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    bias = common_attention.attention_bias_lower_triangle(tf.shape(gold)[1])
    if c is not None and len(c.get_shape()) > 3:
        c = tf.squeeze(c, axis=2)
    return transformer.transformer_decoder(decoder_input, c, bias, ed, hparams)
def get_self_attention_bias(x):
  """Creates masked self attention bias.

  Args:
    x: A tensor of shape [batch, length, depth]

  Returns:
    self_attention_bias: A tensor of shape [1, 1, length, length]
  """

  x_shape = common_layers.shape_list(x)
  self_attention_bias = common_attention.attention_bias_lower_triangle(
      x_shape[1])
  return self_attention_bias
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
  """Complete attention layer with preprocessing."""
  separabilities = [hparams.separability, hparams.separability]
  if hparams.separability < 0:
    separabilities = [hparams.separability - 1, hparams.separability]
  targets_timed = common_layers.subseparable_conv_block(
      common_layers.add_timing_signal(targets_shifted),
      hparams.model_d, [((1, 1), (5, 1)), ((4, 1), (5, 1))],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separabilities=separabilities,
      name="targets_time")
  if hparams.attention_type == "transformer":
    targets_timed = tf.squeeze(targets_timed, 2)
    target_shape = tf.shape(targets_timed)
    targets_segment = tf.zeros([target_shape[0], target_shape[1]])
    target_attention_bias = common_attention.attention_bias_lower_triangle(
        target_shape[1])
    inputs_encoded = common_layers.flatten4d3d(inputs_encoded)
    # TODO(jbaccash): use input bias parameter. This code seems to assume fixed
    # size inputs.
    inputs_attention_bias = tf.zeros([
        tf.shape(inputs_encoded)[0], hparams.num_heads,
        tf.shape(targets_segment)[1],
        tf.shape(inputs_encoded)[1]
    ])

    qv = common_attention.multihead_attention(
        targets_timed,
        None,
        target_attention_bias,
        hparams.model_d,
        hparams.model_d,
        hparams.model_d,
        hparams.num_heads,
        hparams.attention_dropout,
        name="self_attention")
    qv = common_attention.multihead_attention(
        qv,
        inputs_encoded,
        inputs_attention_bias,
        hparams.model_d,
        hparams.model_d,
        hparams.model_d,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
    return tf.expand_dims(qv, 2)
  else:
    raise ValueError("Unsupported attention_type: %s" % hparams.attention_type)
def get_self_attention_bias(x):
    """Creates masked self attention bias.

  Args:
    x: A tensor of shape [batch, length, depth]

  Returns:
    self_attention_bias: A tensor of shape [1, 1, length, length]
  """

    x_shape = common_layers.shape_list(x)
    self_attention_bias = common_attention.attention_bias_lower_triangle(
        x_shape[1])
    return self_attention_bias
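
# The returned bias is added to the raw attention logits before the softmax,
# which is why it holds large negative values rather than a boolean mask. A
# minimal NumPy illustration of that addition (illustrative only, not the
# library code):
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

length = 3
logits = np.random.rand(length, length)
bias = -1e9 * np.triu(np.ones((length, length)), k=1)

weights = softmax(logits + bias)
print(np.round(weights, 3))  # upper triangle ~0: no attention to future positions
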
Example #21
def transformer_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_right_3d(targets)
    return (decoder_input, decoder_self_attention_bias)
 def decode_syntax_template(self, trg_syntax_emb):
     with tf.variable_scope('syntax_decoder', reuse=tf.AUTO_REUSE):
         trg_syntax_emb = common_attention.add_timing_signal_1d(
             trg_syntax_emb)
         trg_syntax_emb = self.update_embedding(trg_syntax_emb)
         trg_syntax_length = tf.shape(trg_syntax_emb)[1]
         trg_self_attention_bias = common_attention.attention_bias_lower_triangle(
             trg_syntax_length)
         trg_syntax_outputs = transformer.transformer_decoder(
             decoder_input=trg_syntax_emb,
             decoder_self_attention_bias=trg_self_attention_bias,
             encoder_output=self.shared_tensors['src_outputs'],
             encoder_decoder_attention_bias=self.shared_tensors['src_bias'],
             hparams=self.hparams,
             external_output=self.
             shared_tensors['template_prev_simp_outputs'],
             external_bias=self.shared_tensors['template_simp_bias'])
     return trg_syntax_outputs
Example #23
def attention_lm_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
    decoder_input = common_layers.shift_left_3d(targets)
    if hparams.pos == "timing":
        decoder_input = common_attention.add_timing_signal_1d(decoder_input)
    return (decoder_input, decoder_self_attention_bias)
Example #24
    def test_calculate_branching_model_parameters_decoder_resize(
            self, enforce_output_size):
        tf.reset_default_graph()

        hparams, _ = self._get_wrong_output_dim_decoder_hparams()
        hparams.enforce_output_size = enforce_output_size
        hparams.decoder_left_norms = [translation_nas_net.NO_NORM_KEY] * 5
        hparams.decoder_right_norms = [translation_nas_net.NO_NORM_KEY] * 5

        # Get predicted number of parameters.
        (predicted_num_params, _, _,
         _) = translation_nas_net.calculate_branching_model_parameters(
             encoding_depth=_EMBEDDING_DEPTH,
             left_inputs=hparams.decoder_left_inputs,
             left_layers=hparams.decoder_left_layers,
             left_output_dims=hparams.decoder_left_output_dims,
             right_inputs=hparams.decoder_right_inputs,
             right_layers=hparams.decoder_right_layers,
             right_output_dims=hparams.decoder_right_output_dims,
             combiner_functions=hparams.decoder_combiner_functions,
             final_combiner_function=hparams.decoder_final_combiner_function,
             layer_registry=layers.DECODER_LAYERS,
             num_cells=hparams.decoder_num_cells,
             encoder_depth=_EMBEDDING_DEPTH,
             enforce_output_size=enforce_output_size)

        # Count graph variables.
        input_tensor = tf.zeros([_BATCH_SIZE, _INPUT_LENGTH, _EMBEDDING_DEPTH])
        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
        _ = translation_nas_net.nas_decoder(
            decoder_input=input_tensor,
            encoder_cell_outputs=[input_tensor] * hparams.encoder_num_cells,
            decoder_self_attention_bias=decoder_self_attention_bias,
            encoder_decoder_attention_bias=None,
            hparams=hparams,
            final_layer_norm=False)
        trainable_variables_list = tf.trainable_variables()
        empirical_num_params = 0
        for variable_tensor in trainable_variables_list:
            empirical_num_params += _list_product(
                variable_tensor.shape.as_list())

        self.assertEqual(empirical_num_params, predicted_num_params)
 def testMultiheadSelfAttentionMemoryEfficient(self):
   num_heads = 4
   io_size = 16
   batch = 2
   length = 7
   head_size = 5
   x = np.random.rand(batch, length, io_size)
   dy = np.random.rand(batch, length, io_size)
   with self.test_session() as session:
     x = tf.to_float(x)
     dy = tf.to_float(dy)
     bias = common_attention.attention_bias_lower_triangle(length)
     wqkv = tf.get_variable(
         "wqkv", [num_heads, 1, io_size, 3 * head_size],
         initializer=tf.random_normal_initializer(stddev=io_size**-0.5))
     wo = tf.get_variable(
         "wo", [num_heads, 1, head_size, io_size],
         initializer=tf.random_normal_initializer(
             stddev=(head_size * num_heads)**-0.5))
     norm_scale, norm_bias = common_layers.layer_norm_vars(io_size)
     y = common_attention.multihead_self_attention_memory_efficient(
         x, bias, num_heads, head_size=head_size, forget=False,
         test_vars=(wqkv, wo, norm_scale, norm_bias))
     y_forget = common_attention.multihead_self_attention_memory_efficient(
         x, bias, num_heads, head_size=head_size, forget=True,
         test_vars=(wqkv, wo, norm_scale, norm_bias))
     dx, dwqkv, dwo, dnorm_scale, dnorm_bias = tf.gradients(
         ys=[y], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
     dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f = tf.gradients(
         ys=[y_forget], xs=[x, wqkv, wo, norm_scale, norm_bias], grad_ys=[dy])
     session.run(tf.global_variables_initializer())
     (y, y_forget,
      dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
      dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f) = session.run(
          [y, y_forget,
           dx, dwqkv, dwo, dnorm_scale, dnorm_bias,
           dx_f, dwqkv_f, dwo_f, dnorm_scale_f, dnorm_bias_f])
   self.assertAllClose(y, y_forget)
   self.assertAllClose(dwo, dwo_f)
   self.assertAllClose(dwqkv, dwqkv_f)
   self.assertAllClose(dnorm_scale, dnorm_scale_f)
   self.assertAllClose(dnorm_bias, dnorm_bias_f)
   self.assertAllClose(dx, dx_f)
def transformer_prepare_decoder(targets_emb_var,
                                targets,
                                hparams,
                                features=None):
    """Prepare one shard of the model for the decoder.

  Args:
    targets_emb_var: a Tensor
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))

    if features and "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = features["targets_segmentation"]
        targets_position = features["targets_position"]
        decoder_self_attention_bias += common_attention.attention_bias_same_segment(
            targets_segmentation, targets_segmentation)
    else:
        targets_position = None
    decoder_input = tf.gather(targets_emb_var,
                              common_layers.shift_right_2d(targets))
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        decoder_input = common_attention.add_positional_embedding(
            decoder_input, hparams.max_length, "positional_embedding",
            targets_position)

    if hparams.activation_dtype == "bfloat16":
        decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                              tf.bfloat16)
    return (decoder_input, decoder_self_attention_bias)
Example #27
def transformer_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
  decoder_self_attention_bias = (
      common_attention.attention_bias_lower_triangle(tf.shape(targets)[1]))
  if hparams.proximity_bias:
    decoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(targets)[1])
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  # decoder_input = tf.Print(decoder_input, [tf.shape(decoder_input)], 
  #     summarize=1000, message="decoder_input")
  # decoder_self_attention_bias = tf.Print(decoder_self_attention_bias, [tf.shape(decoder_self_attention_bias)], 
  #     summarize=1000, message="decoder_self_attention_bias")
  return (decoder_input, decoder_self_attention_bias)
Example #28
def transformer_prepare_decoder(targets, hparams):
    """Prepare one shard of the model for the decoder.
  
    Args:
      targets: a Tensor.
      hparams: run hyperparameters
  
    Returns:
      decoder_input: a Tensor, bottom of decoder stack
      decoder_self_attention_bias: a bias tensor for use in decoder self-attention
    """
    decoder_self_attention_bias = (comm_attn.attention_bias_lower_triangle(
        tf.shape(targets)[1]))
    if hparams.proximity_bias:
        decoder_self_attention_bias += comm_attn.attention_bias_proximal(
            tf.shape(targets)[1])
    decoder_input = common_layers.shift_left_3d(targets)
    if hparams.pos == 'timing':
        decoder_input = comm_attn.add_timing_signal_1d(decoder_input)
    # Putting this here since always called immediately after...
    decoder_input = with_dropout(decoder_input, hparams)

    return DecoderState(input=decoder_input,
                        self_attn_bias=decoder_self_attention_bias)
Example #29
def attention_lm_prepare_decoder(targets, hparams):
  """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a Tensor, containing large negative values
    to implement masked attention and possibly biases for diagonal alignments
  """
  if hparams.prepend_mode == "prepend_inputs_full_attention":
    decoder_self_attention_bias = (
        common_attention.attention_bias_prepended(
            common_attention.embedding_to_padding(targets)))
  else:
    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(
            common_layers.shape_list(targets)[1]))
  decoder_input = common_layers.shift_right_3d(targets)
  if hparams.pos == "timing":
    decoder_input = common_attention.add_timing_signal_1d(decoder_input)
  return (decoder_input, decoder_self_attention_bias)
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.modality["targets"]
        if "targets_segmentation" in features:
            raise NotImplementedError(
                "Decoding not supported on packed datasets "
                " If you want to decode from a dataset, use the non-packed version"
                " of the dataset when decoding.")
        if self.has_input:
            inputs = features["inputs"]
            if target_modality.is_class_modality:
                decode_length = 1
            else:
                decode_length = (common_layers.shape_list(inputs)[1] +
                                 features.get("decode_length", decode_length))

            contexts = {}
            for feature_name in features:
                if 'context' in feature_name and 'raw' not in feature_name:
                    contexts[feature_name] = features[feature_name]

            inputs = tf.expand_dims(inputs, axis=1)
            if len(inputs.shape) < 5:
                inputs = tf.expand_dims(inputs, axis=4)
            s = common_layers.shape_list(inputs)
            batch_size = s[0]
            inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
            # _shard_features called to ensure that the variable names match
            inputs = self._shard_features({"inputs": inputs})["inputs"]
            input_modality = self._problem_hparams.modality["inputs"]

            context_modality = {}
            for context_name in contexts:
                if context_name in self._problem_hparams.modality:
                    context_modality[
                        context_name] = self._problem_hparams.modality[
                            context_name]
                else:
                    context_modality[context_name] = input_modality

            with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
                inputs = input_modality.bottom_sharded(inputs, dp)

            for feature_name in contexts:
                with tf.variable_scope(context_modality[feature_name].name,
                                       reuse=tf.AUTO_REUSE):
                    contexts[feature_name] = context_modality[
                        feature_name].bottom_sharded(contexts[feature_name],
                                                     dp)

            contexts_list = [
                contexts[feature_name][0] for feature_name in contexts
            ]
            contexts = tf.concat(contexts_list, axis=1)
            inputs = [tf.concat([contexts, inputs[0]], axis=1)]

            with tf.variable_scope("body"):
                encoder_output, encoder_decoder_attention_bias = dp(
                    self.encode,
                    inputs,
                    features["target_space_id"],
                    hparams,
                    features=features)
            encoder_output = encoder_output[0]
            encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
            partial_targets = None
        else:
            # The problem has no inputs.
            encoder_output = None
            encoder_decoder_attention_bias = None

            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get("inputs")
            if partial_targets is None:
                partial_targets = features["targets"]
            assert partial_targets is not None
            partial_targets = common_layers.expand_squeeze_to_nd(
                partial_targets, 2)
            partial_targets = tf.to_int64(partial_targets)
            partial_targets_shape = common_layers.shape_list(partial_targets)
            partial_targets_length = partial_targets_shape[1]
            decode_length = (partial_targets_length +
                             features.get("decode_length", decode_length))
            batch_size = partial_targets_shape[0]

        if hparams.pos == "timing":
            positional_encoding = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)
        elif hparams.pos == "emb":
            positional_encoding = common_attention.add_positional_embedding(
                tf.zeros([1, decode_length, hparams.hidden_size]),
                hparams.max_length, "body/targets_positional_embedding", None)
        else:
            positional_encoding = None

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.
            This includes:
              - Embedding the ids.
              - Flattening to 3D tensor.
              - Optionally adding timing signals.
            Args:
              targets: inputs ids to the decoder. [batch_size, 1]
              i: scalar, Step number of the decoding loop.
            Returns:
              Processed targets [batch_size, 1, hidden_dim]
            """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if positional_encoding is not None:
                targets += positional_encoding[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode,
                                  targets,
                                  cache.get("encoder_output"),
                                  cache.get("encoder_decoder_attention_bias"),
                                  bias,
                                  hparams,
                                  cache,
                                  nonpadding=features_to_nonpadding(
                                      features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            ret = tf.squeeze(logits, axis=[1, 2, 3])
            if partial_targets is not None:
                # If the position is within the given partial targets, we alter the
                # logits to always return those values.
                # A faster approach would be to process the partial targets in one
                # iteration in order to fill the corresponding parts of the cache.
                # This would require broader changes, though.
                vocab_size = tf.shape(ret)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size, 0.0, -1e9)

                ret = tf.cond(tf.less(i, partial_targets_length),
                              forced_logits, lambda: ret)
            return ret, cache

        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size,
            force_decode_length=self._decode_hparams.force_decode_length)
        if partial_targets is not None:
            if beam_size <= 1 or top_beams <= 1:
                ret["outputs"] = ret["outputs"][:, partial_targets_length:]
            else:
                ret["outputs"] = ret["outputs"][:, :, partial_targets_length:]
        return ret
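
# The partial-target handling in symbols_to_logits_fn forces the next symbol by
# replacing the model logits with a one-hot distribution: the forced id gets
# logit 0 and every other id gets -1e9, so greedy and beam search both select
# it. A small NumPy sketch of that trick with made-up sizes:
import numpy as np

vocab_size, forced_id = 6, 3
forced_logits = np.full((1, vocab_size), -1e9, dtype=np.float32)
forced_logits[0, forced_id] = 0.0

model_logits = np.random.rand(1, vocab_size).astype(np.float32)
step, partial_targets_length = 1, 4
logits = forced_logits if step < partial_targets_length else model_logits
print(logits.argmax(axis=-1))  # -> [3] while still inside the partial targets
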
Example #31
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model  features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams

    inputs = features["inputs"]
    batch_size = common_layers.shape_list(inputs)[0]
    target_modality = self._problem_hparams.target_modality
    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = common_layers.shape_list(inputs)[1] + decode_length

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = common_layers.shape_list(inputs)
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, inputs, features["target_space_id"], hparams,
          features=features)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode, targets, cache["encoder_output"],
            cache["encoder_decoder_attention_bias"], bias, hparams, cache,
            nonpadding=_features_to_nonpadding(features, "targets"))

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      return tf.squeeze(logits, axis=[1, 2, 3]), cache

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        }
        for layer in range(num_layers)
    }

    # Set 2nd dim to None since it's not invariant in the tf.while_loop
    # Note: Tensor.set_shape() does not work here since it merges shape info.
    # TODO(llion): Find a more robust solution.
    # pylint: disable=protected-access
    if not context.in_eager_mode():
      for layer in cache:
        cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
        cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
    # pylint: enable=protected-access
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    if beam_size > 1:  # Beam Search
      target_modality = (
          self._hparams.problems[self._problem_idx].target_modality)
      vocab_size = target_modality.top_dimensionality
      initial_ids = tf.zeros([batch_size], dtype=tf.int32)
      decoded_ids, scores = beam_search.beam_search(
          symbols_to_logits_fn,
          initial_ids,
          beam_size,
          decode_length,
          vocab_size,
          alpha,
          states=cache,
          stop_early=(top_beams == 1))

      if top_beams == 1:
        decoded_ids = decoded_ids[:, 0, 1:]
      else:
        decoded_ids = decoded_ids[:, :top_beams, 1:]
    else:  # Greedy

      def inner_loop(i, next_id, decoded_ids, cache):
        logits, cache = symbols_to_logits_fn(next_id, i, cache)
        temperature = (0.0 if hparams.sampling_method == "argmax" else
                       hparams.sampling_temp)
        next_id = tf.expand_dims(
            common_layers.sample_with_temperature(logits, temperature), axis=1)
        decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
        return i + 1, next_id, decoded_ids, cache

      decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
      scores = None
      next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
      _, _, decoded_ids, _ = tf.while_loop(
          # TODO(llion): Early stopping.
          lambda i, *_: tf.less(i, decode_length),
          inner_loop,
          [tf.constant(0), next_id, decoded_ids, cache],
          shape_invariants=[
              tf.TensorShape([]),
              tf.TensorShape([None, None]),
              tf.TensorShape([None, None]),
              nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
          ])

    return decoded_ids, scores
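
A detail worth calling out in the snippet above: the per-layer `k` and `v` cache entries start with a zero-length time dimension, and the shape invariants relax that dimension to `None` because each decoding step is expected to append one new key/value slice. The NumPy sketch below illustrates that growth under this assumption; the shapes and names are illustrative, not the library's internals.

import numpy as np

batch_size, key_channels = 2, 4
cache_k = np.zeros((batch_size, 0, key_channels))          # starts empty, as above
for step in range(3):
    new_k = np.random.randn(batch_size, 1, key_channels)   # this step's key
    cache_k = np.concatenate([cache_k, new_k], axis=1)     # time axis grows by one
print(cache_k.shape)  # (2, 3, 4): one cached key per decoded position
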
Example #32
0
def _bias(x):
  return common_attention.attention_bias_lower_triangle(
      tf.shape(x)[1])
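
The helper above is just a thin wrapper around `attention_bias_lower_triangle`. Conceptually, that bias is a `[1, 1, length, length]` tensor with zeros on and below the diagonal and a large negative value above it; the NumPy sketch below mirrors that structure (it is an illustration, not the tensor2tensor implementation).

import numpy as np

def lower_triangle_bias(length, neg=-1e9):
    # Row i is 0 for columns <= i and `neg` for columns > i, so adding it to
    # the attention logits blocks attention to future positions.
    mask = np.triu(np.ones((length, length), dtype=np.float32), k=1)
    return (mask * neg)[None, None, :, :]

print(lower_triangle_bias(4)[0, 0])  # zeros on/below the diagonal, -1e9 above
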
Example #33
0
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        """
		Fast decoding.
		Implements both greedy and beam search decoding, uses beam search iff
		beam_size > 1, otherwise beam search related arguments are ignored.
		Args:
			features: a map of string to model  features.
			decode_length: an integer.  How many additional timesteps to decode.
			beam_size: number of beams.
			top_beams: an integer. How many of the beams to return.
			alpha: Float that controls the length penalty. larger the alpha, stronger
			the preference for slonger translations.
		Returns:
			 samples: an integer `Tensor`. Top samples from the beam search
		Raises:
			NotImplementedError: If there are multiple data shards.
		"""
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams

        inputs = features["inputs"]
        batch_size = common_layers.shape_list(inputs)[0]
        target_modality = self._problem_hparams.target_modality
        if target_modality.is_class_modality:
            decode_length = 1
        else:
            decode_length = common_layers.shape_list(inputs)[1] + decode_length

        # TODO(llion): Clean up this reshaping logic.
        inputs = tf.expand_dims(inputs, axis=1)
        if len(inputs.shape) < 5:
            inputs = tf.expand_dims(inputs, axis=4)
        s = common_layers.shape_list(inputs)
        inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.input_modality["inputs"]
        with tf.variable_scope(input_modality.name):
            inputs = input_modality.bottom_sharded(inputs, dp)
        with tf.variable_scope("body"):
            encoder_output, encoder_decoder_attention_bias = dp(
                self.encode,
                inputs,
                features["target_space_id"],
                hparams,
                features=features)
        encoder_output = encoder_output[0]
        encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.
			This includes:
			- Embedding the ids.
			- Flattening to 3D tensor.
			- Optionally adding timing signals.
			Args:
			targets: inputs ids to the decoder. [batch_size, 1]
			i: scalar, Step number of the decoding loop.
			Returns:
			Processed targets [batch_size, 1, hidden_dim]
			"""
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(
                    self.decode,
                    targets,
                    cache["encoder_output"],
                    cache["encoder_decoder_attention_bias"],
                    bias,
                    hparams,
                    cache,
                    nonpadding=transformer._features_to_nonpadding(
                        features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            return tf.squeeze(logits, axis=[1, 2, 3]), cache

        key_channels = hparams.attention_key_channels or hparams.hidden_size
        value_channels = hparams.attention_value_channels or hparams.hidden_size
        num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

        cache = {
            "layer_%d" % layer: {
                "k": tf.zeros([batch_size, 0, key_channels]),
                "v": tf.zeros([batch_size, 0, value_channels]),
            }
            for layer in range(num_layers)
        }

        # Set 2nd dim to None since it's not invariant in the tf.while_loop
        # Note: Tensor.set_shape() does not work here since it merges shape info.
        # TODO(llion): Find a more robust solution.
        # pylint: disable=protected-access
        if not context.in_eager_mode():
            for layer in cache:
                cache[layer]["k"]._shape = tf.TensorShape(
                    [None, None, key_channels])
                cache[layer]["v"]._shape = tf.TensorShape(
                    [None, None, value_channels])
        # pylint: enable=protected-access
        cache["encoder_output"] = encoder_output
        cache["encoder_decoder_attention_bias"] = (
            encoder_decoder_attention_bias)

        if beam_size > 1:  # Beam Search
            target_modality = (
                self._hparams.problems[self._problem_idx].target_modality)
            vocab_size = target_modality.top_dimensionality
            initial_ids = tf.zeros([batch_size], dtype=tf.int32)
            decoded_ids, scores = beam_search.beam_search(
                symbols_to_logits_fn,
                initial_ids,
                beam_size,
                decode_length,
                vocab_size,
                alpha,
                states=cache,
                stop_early=(top_beams == 1))

            decoded_ids = decoded_ids[:, :, 1:]

            # do roulette wheel selection or inverse roulette wheel selection
            if self._hparams.roulette == "Normal" or self._hparams.roulette == "Inverse":
                if self._hparams.roulette == "Normal":
                    probabilities = tf.pow(tf.constant(2.0), scores)
                    start = 0
                else:
                    probabilities = tf.subtract(
                        tf.constant(1.0), tf.pow(tf.constant(2.0), scores))
                    start = beam_size - self._hparams.roulette_beam_size

                summ = tf.reduce_sum(probabilities)
                ex_probs = tf.divide(probabilities, summ)
                #ex_probs=tf.nn.softmax(probabilities)

                # sample a number between 0 and 1
                wheel = tf.random_uniform([1])
                upper_bound = tf.constant(0.0)

                # change this as well if using inverse
                for i in range(start, self._hparams.roulette_beam_size):
                    upper_bound = tf.add(ex_probs[:, i], upper_bound)
                    truthValue = tf.squeeze(
                        tf.logical_and(wheel >= upper_bound - ex_probs[:, i],
                                       wheel <= upper_bound))
                    decoded_ids, scores, i = tf.cond(
                        truthValue, lambda:
                        (decoded_ids[:, i, :], scores[:, i], beam_size),
                        lambda: (decoded_ids, scores, i))

        else:  # Greedy

            def inner_loop(i, next_id, decoded_ids, cache):
                logits, cache = symbols_to_logits_fn(next_id, i, cache)
                temperature = (0.0 if hparams.sampling_method == "argmax" else
                               hparams.sampling_temp)
                next_id = tf.expand_dims(
                    common_layers.sample_with_temperature(logits, temperature),
                    axis=1)
                decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
                return i + 1, next_id, decoded_ids, cache

            decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
            scores = None
            next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
            _, _, decoded_ids, _ = tf.while_loop(
                # TODO(llion): Early stopping.
                lambda i, *_: tf.less(i, decode_length),
                inner_loop,
                [tf.constant(0), next_id, decoded_ids, cache],
                shape_invariants=[
                    tf.TensorShape([]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                    nest.map_structure(lambda t: tf.TensorShape(t.shape),
                                       cache),
                ])

        return decoded_ids, scores
    def decode_inputs_to_outputs(self, decoder_embed_inputs, encoder_outputs, encoder_attn_bias,
                                 rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
                                 score, obj_tensors=None):
        if self.hparams.pos == 'timing':
            decoder_embed_inputs = common_attention.add_timing_signal_1d(decoder_embed_inputs)
            print('Use positional encoding in decoder text.')
        decoder_embed_inputs = self.update_decoder_embedding(decoder_embed_inputs, score, self.model_config.beam_search_size)

        decoder_attn_bias = common_attention.attention_bias_lower_triangle(tf.shape(decoder_embed_inputs)[1])
        decoder_embed_inputs = tf.nn.dropout(decoder_embed_inputs,
                                             1.0 - self.hparams.layer_prepostprocess_dropout)
        if 'direct' in self.model_config.memory:
            assert 'direct_bert_output' in obj_tensors
            decoder_output = transformer.transformer_multi_decoder(
                decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
                encoder_attn_bias, obj_tensors['direct_bert_output'], obj_tensors['direct_bert_bias'],
                self.hparams, save_weights_to=obj_tensors,
                direct_mode=self.model_config.direct_mode)

            if self.model_config.npad_mode == 'static_seq':
                decoder_output = tf.nn.conv1d(decoder_output, obj_tensors['npad_w'], 1, 'SAME')

            return decoder_output, decoder_output, None
        elif 'rule' in self.model_config.memory:
            decoder_output, contexts = transformer.transformer_decoder_contexts(
                decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
                encoder_attn_bias, self.hparams)

            # encoder_gate_w = tf.get_variable('encoder_gate_w', shape=(
            #     1, self.model_config.dimension, 1))
            # encoder_gate_b = tf.get_variable('encoder_gate_b', shape=(1, 1, 1))
            # encoder_gate = tf.tanh(encoder_gate_b + tf.nn.conv1d(encoder_outputs, encoder_gate_w, 1, 'SAME'))
            # encoder_context_outputs = tf.expand_dims(tf.reduce_mean(encoder_outputs * encoder_gate, axis=1), axis=1)
            cur_context = contexts[0] #tf.concat(contexts, axis=-1)
            cur_mem_contexts = tf.stack(self.embedding_fn(rule_id_input_placeholder, mem_contexts), axis=1)
            cur_mem_outputs = tf.stack(self.embedding_fn(rule_id_input_placeholder, mem_outputs), axis=1)
            cur_mem_contexts = tf.reshape(cur_mem_contexts,
                                          [self.model_config.batch_size,
                                           self.model_config.max_target_rule_sublen*self.model_config.max_cand_rules,
                                           self.model_config.dimension])
            cur_mem_outputs = tf.reshape(cur_mem_outputs,
                                         [self.model_config.batch_size,
                                          self.model_config.max_target_rule_sublen*self.model_config.max_cand_rules,
                                          self.model_config.dimension])

            # bias = tf.expand_dims(
            #     -1e9 * tf.to_float(tf.equal(tf.stack(rule_id_input_placeholder, axis=1), 0)),
            #     axis=1)
            # weights = tf.nn.softmax(bias + tf.matmul(cur_context, cur_mem_contexts, transpose_b=True))
            weights = tf.nn.softmax(tf.matmul(cur_context, cur_mem_contexts, transpose_b=True))
            mem_output = tf.matmul(weights, cur_mem_outputs)

            # trainable_mem = 'stopgrad' not in self.model_config.rl_configs
            temp_output = tf.concat((decoder_output, mem_output), axis=-1)
            # w_u = tf.get_variable('w_ffn', shape=(
            #     1, self.model_config.dimension*2, self.model_config.dimension), trainable=trainable_mem)
            # b_u = tf.get_variable('b_ffn', shape=(
            #     1, 1, self.model_config.dimension), trainable=trainable_mem)
            # w_u.reuse_variables()
            # b_u.reuse_variables()
            # tf.get_variable_scope().reuse_variables()
            w_t = tf.get_variable('w_ffn', shape=(
                1, self.model_config.dimension*2, self.model_config.dimension), trainable=True)
            b_t = tf.get_variable('b_ffn', shape=(
                1, 1, self.model_config.dimension), trainable=True)
            # w = tf.cond(tf.equal(tf.mod(self.global_step, 2), 0), lambda: w_t, lambda: w_u)
            # b = tf.cond(tf.equal(tf.mod(self.global_step, 2), 0), lambda: b_t, lambda: b_u)

            mem_output = tf.nn.conv1d(temp_output, w_t, 1, 'SAME') + b_t
            g = tf.greater(global_step, tf.constant(self.model_config.memory_prepare_step, dtype=tf.int64))
            final_output = tf.cond(g, lambda: mem_output, lambda: decoder_output)
            return final_output, decoder_output, cur_context
        else:
            if self.model_config.architecture == 'ut2t':
                (decoder_output, decoder_extra_output) = universal_transformer_util.universal_transformer_decoder(
                    decoder_embed_inputs, encoder_outputs,
                    decoder_attn_bias, encoder_attn_bias, self.hparams,
                    save_weights_to=obj_tensors)
                dec_ponder_times, dec_remainders = decoder_extra_output
                extra_dec_loss = (
                        self.hparams.act_loss_weight *
                        tf.reduce_mean(dec_ponder_times + dec_remainders))
                if self.is_train:
                    obj_tensors['extra_decoder_loss'] = extra_dec_loss
            else:
                decoder_output = transformer.transformer_decoder(
                    decoder_embed_inputs, encoder_outputs, decoder_attn_bias,
                    encoder_attn_bias, self.hparams, save_weights_to=obj_tensors,
                    npad_mode=self.model_config.npad_mode)
            final_output = decoder_output
            return final_output, decoder_output, None
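
The beam-search branch in the `_fast_decode` above does not simply keep the top beam: when `hparams.roulette` is set, it samples one beam with probability proportional to a transform of its score (or the inverse of it). The NumPy sketch below restates that roulette-wheel (fitness-proportionate) selection under the assumption that scores are per-beam log probabilities; the helper name is made up for illustration.

import numpy as np

def roulette_select(scores, rng=np.random):
    # Turn per-beam scores into positive weights (the example uses 2**scores),
    # normalize them, then walk the cumulative distribution until the random
    # "wheel" value falls inside a beam's slice.
    weights = np.power(2.0, np.asarray(scores, dtype=np.float64))
    probs = weights / weights.sum()
    wheel = rng.uniform()
    return int(np.searchsorted(np.cumsum(probs), wheel))

beam_index = roulette_select([-0.5, -1.2, -3.0])  # most often returns beam 0
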
Example #35
0
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer. How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams
    target_modality = self._problem_hparams.target_modality

    if self.has_input:
      inputs = features["inputs"]
      if target_modality.is_class_modality:
        decode_length = 1
      else:
        decode_length = common_layers.shape_list(inputs)[1] + decode_length

      # TODO(llion): Clean up this reshaping logic.
      inputs = tf.expand_dims(inputs, axis=1)
      if len(inputs.shape) < 5:
        inputs = tf.expand_dims(inputs, axis=4)
      s = common_layers.shape_list(inputs)
      batch_size = s[0]
      inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
      # _shard_features called to ensure that the variable names match
      inputs = self._shard_features({"inputs": inputs})["inputs"]
      input_modality = self._problem_hparams.input_modality["inputs"]
      with tf.variable_scope(input_modality.name):
        inputs = input_modality.bottom_sharded(inputs, dp)
      with tf.variable_scope("body"):
        encoder_output, encoder_decoder_attention_bias = dp(
            self.encode, inputs, features["target_space_id"], hparams,
            features=features)
      encoder_output = encoder_output[0]
      encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
      partial_targets = None
    else:
      # The problem has no inputs.
      # In this case, features["inputs"] contains partial targets.
      # We force the outputs to begin with these sequences.
      encoder_output = None
      encoder_decoder_attention_bias = None
      partial_targets = tf.squeeze(tf.to_int64(features["inputs"]), [2, 3])
      partial_targets_length = common_layers.shape_list(partial_targets)[1]
      decode_length += partial_targets_length
      batch_size = tf.shape(partial_targets)[0]

    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        body_outputs = dp(
            self.decode, targets, cache.get("encoder_output"),
            cache.get("encoder_decoder_attention_bias"),
            bias, hparams, cache,
            nonpadding=features_to_nonpadding(features, "targets"))

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      ret = tf.squeeze(logits, axis=[1, 2, 3])
      if partial_targets is not None:
        # If the position is within the given partial targets, we alter the
        # logits to always return those values.
        # A faster approach would be to process the partial targets in one
        # iteration in order to fill the corresponding parts of the cache.
        # This would require broader changes, though.
        vocab_size = tf.shape(ret)[1]
        def forced_logits():
          return tf.one_hot(tf.tile(partial_targets[:, i], [beam_size]),
                            vocab_size, 0.0, -1e9)
        ret = tf.cond(
            tf.less(i, partial_targets_length), forced_logits, lambda: ret)
      return ret, cache

    ret = fast_decode(
        encoder_output=encoder_output,
        encoder_decoder_attention_bias=encoder_decoder_attention_bias,
        symbols_to_logits_fn=symbols_to_logits_fn,
        hparams=hparams,
        decode_length=decode_length,
        vocab_size=target_modality.top_dimensionality,
        beam_size=beam_size,
        top_beams=top_beams,
        alpha=alpha,
        batch_size=batch_size)
    if partial_targets is not None:
      ret["outputs"] = ret["outputs"][:, partial_targets_length:]
    return ret
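
When the problem has no inputs, the example above forces the first `partial_targets_length` outputs by swapping in `tf.one_hot(..., 0.0, -1e9)` logits via `tf.cond`. The short NumPy restatement below shows what those forced logits look like; it is a sketch of the idea, not the library call.

import numpy as np

def forced_logits(token_id, vocab_size, neg=-1e9):
    # 0 at the forced token, a large negative value everywhere else: after the
    # softmax (or argmax) the decoder can only emit `token_id` at this step.
    logits = np.full((vocab_size,), neg, dtype=np.float32)
    logits[token_id] = 0.0
    return logits

print(forced_logits(3, 6))  # only index 3 is 0; every other entry is -1e9
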
Example #36
0
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0,
                     sentence_cache=None):
        """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model features.
      decode_length: an integer. How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. larger the alpha, stronger
        the preference for longer translations.

    Returns:
      A dict of decoding results {
          "outputs": integer `Tensor` of decoded ids of shape
              [batch_size, <= decode_length] if beam_size == 1 or
              [batch_size, top_beams, <= decode_length]
          "scores": decoding log probs from the beam search,
              None if using greedy decoding (beam_size=1)
      }

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
        if self._num_datashards != 1:
            raise NotImplementedError(
                "Fast decoding only supports a single shard.")
        dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.target_modality

        if self.has_input:
            inputs = features["inputs"]
            if target_modality.is_class_modality:
                decode_length = 1
            else:
                decode_length = common_layers.shape_list(
                    inputs)[1] + decode_length

            # TODO(llion): Clean up this reshaping logic.
            inputs = tf.expand_dims(inputs, axis=1)
            if len(inputs.shape) < 5:
                inputs = tf.expand_dims(inputs, axis=4)
            s = common_layers.shape_list(inputs)
            batch_size = s[0]
            inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
            # _shard_features called to ensure that the variable names match
            inputs = self._shard_features({"inputs": inputs})["inputs"]
            input_modality = self._problem_hparams.input_modality["inputs"]
            with tf.variable_scope(input_modality.name):
                inputs = input_modality.bottom_sharded(inputs, dp)
            with tf.variable_scope("body"):
                encoder_output, encoder_decoder_attention_bias = dp(
                    self.encode,
                    inputs,
                    features["target_space_id"],
                    hparams,
                    features=features)
            encoder_output = encoder_output[0]
            encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
            partial_targets = None
        else:
            # The problem has no inputs.
            # In this case, features["inputs"] contains partial targets.
            # We force the outputs to begin with these sequences.
            encoder_output = None
            encoder_decoder_attention_bias = None
            partial_targets = tf.squeeze(tf.to_int64(features["inputs"]),
                                         [2, 3])
            partial_targets_length = common_layers.shape_list(
                partial_targets)[1]
            decode_length += partial_targets_length
            batch_size = tf.shape(partial_targets)[0]

        if hparams.pos == "timing":
            timing_signal = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
            # _shard_features called to ensure that the variable names match
            targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom_sharded(targets,
                                                                 dp)[0]
            targets = common_layers.flatten4d3d(targets)

            # TODO(llion): Explain! Is this even needed?
            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if hparams.pos == "timing":
                targets += timing_signal[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = dp(self.decode,
                                  targets,
                                  cache.get("encoder_output"),
                                  cache.get("encoder_decoder_attention_bias"),
                                  bias,
                                  hparams,
                                  cache,
                                  nonpadding=features_to_nonpadding(
                                      features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top_sharded(body_outputs, None, dp)[0]

            ret = tf.squeeze(logits, axis=[1, 2, 3])
            if partial_targets is not None:
                # If the position is within the given partial targets, we alter the
                # logits to always return those values.
                # A faster approach would be to process the partial targets in one
                # iteration in order to fill the corresponding parts of the cache.
                # This would require broader changes, though.
                vocab_size = tf.shape(ret)[1]

                def forced_logits():
                    return tf.one_hot(
                        tf.tile(partial_targets[:, i], [beam_size]),
                        vocab_size, 0.0, -1e9)

                ret = tf.cond(tf.less(i, partial_targets_length),
                              forced_logits, lambda: ret)
            return ret, cache, body_outputs

        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size,
            sentence_cache=self.sentence_cache,
            cache_flag=self.cache_flag)
        if partial_targets is not None:
            ret["outputs"] = ret["outputs"][:, partial_targets_length:]
        return ret
Example #37
0
    def test_calculate_branching_model_parameters_transformer(
            self, get_config, expected_hidden_depths):
        tf.reset_default_graph()

        (num_cells, left_inputs, left_layers, left_output_dims, right_inputs,
         right_layers, right_output_dims, combiner_functions,
         final_combiner_function, dummy_activations, dummy_norms,
         layer_registry, is_decoder) = get_config()

        # Get predicted number of parameters.
        (predicted_num_params, output_size, hidden_depths,
         _) = translation_nas_net.calculate_branching_model_parameters(
             encoding_depth=_EMBEDDING_DEPTH,
             left_inputs=left_inputs,
             left_layers=left_layers,
             left_output_dims=left_output_dims,
             right_inputs=right_inputs,
             right_layers=right_layers,
             right_output_dims=right_output_dims,
             combiner_functions=combiner_functions,
             final_combiner_function=final_combiner_function,
             layer_registry=layer_registry,
             num_cells=num_cells,
             encoder_depth=_EMBEDDING_DEPTH)

        # Create model graph.
        input_tensor = tf.zeros([32, _INPUT_LENGTH, _EMBEDDING_DEPTH])
        hparams = transformer.transformer_small()

        if is_decoder:
            nonpadding = None
            mask_future = True
            decoder_self_attention_bias = (
                common_attention.attention_bias_lower_triangle(_INPUT_LENGTH))
            encoder_cell_outputs = [input_tensor] * 6
        else:
            nonpadding = tf.ones([32, _INPUT_LENGTH])
            mask_future = False
            decoder_self_attention_bias = None
            encoder_cell_outputs = None

        translation_nas_net.apply_nas_layers(
            input_tensor=input_tensor,
            left_inputs=left_inputs,
            left_layers=left_layers,
            left_activations=dummy_activations,
            left_output_dims=left_output_dims,
            left_norms=dummy_norms,
            right_inputs=right_inputs,
            right_layers=right_layers,
            right_activations=dummy_activations,
            right_output_dims=right_output_dims,
            right_norms=dummy_norms,
            combiner_functions=combiner_functions,
            final_combiner_function=final_combiner_function,
            num_cells=num_cells,
            nonpadding=nonpadding,
            layer_registry=layer_registry,
            mask_future=mask_future,
            hparams=hparams,
            var_scope="test",
            encoder_decoder_attention_bias=None,
            encoder_cell_outputs=encoder_cell_outputs,
            decoder_self_attention_bias=decoder_self_attention_bias,
            final_layer_norm=False)

        # Count graph variables.
        trainable_variables_list = tf.trainable_variables()
        empirical_num_params = 0
        for variable_tensor in trainable_variables_list:
            empirical_num_params += _list_product(
                variable_tensor.shape.as_list())

        # Compare.
        self.assertEqual(empirical_num_params, predicted_num_params)
        self.assertEqual(output_size, _EMBEDDING_DEPTH)
        self.assertEqual(hidden_depths, expected_hidden_depths)
    def _fast_decode(self,
                     features,
                     decode_length,
                     beam_size=1,
                     top_beams=1,
                     alpha=1.0):
        #dp = self._data_parallelism
        hparams = self._hparams
        target_modality = self._problem_hparams.modality["targets"]

        inputs = features["inputs"]

        decode_length = (common_layers.shape_list(inputs)[1] +
                         features.get("decode_length", decode_length))

        #inputs = tf.expand_dims(inputs, axis=1)
        #if len(inputs.shape) < 5:
        #    inputs = tf.expand_dims(inputs, axis=4)

        s = common_layers.shape_list(inputs)
        batch_size = s[0]
        #inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
        # _shard_features called to ensure that the variable names match
        #inputs = self._shard_features({"inputs": inputs})["inputs"]
        input_modality = self._problem_hparams.modality["inputs"]
        context_modality = {}

        contexts = {}
        for feature_name in features:
            if 'context' in feature_name and 'raw' not in feature_name:
                contexts[feature_name] = features[feature_name]

        for context_name in contexts:
            if context_name in self._problem_hparams.modality:
                context_modality[
                    context_name] = self._problem_hparams.modality[
                        context_name]
            else:
                context_modality[context_name] = input_modality

        with tf.variable_scope(input_modality.name, reuse=tf.AUTO_REUSE):
            inputs = input_modality.bottom(inputs)
            for context_name in contexts:
                contexts[context_name] = context_modality[context_name].bottom(
                    contexts[context_name])

        with tf.variable_scope("body", reuse=tf.AUTO_REUSE):
            encoder_output, encoder_decoder_attention_bias = self.encode(
                inputs,
                contexts,
                features["target_space_id"],
                hparams,
                features=features)
        #encoder_output = encoder_output[0]
        #encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]
        partial_targets = None

        if hparams.pos == "timing":
            positional_encoding = common_attention.get_timing_signal_1d(
                decode_length + 1, hparams.hidden_size)
        elif hparams.pos == "emb":
            positional_encoding = common_attention.add_positional_embedding(
                tf.zeros([1, decode_length + 1, hparams.hidden_size]),
                hparams.max_length, "targets_positional_embedding", None)
        else:
            positional_encoding = None

        def preprocess_targets(targets, i):
            """Performs preprocessing steps on the targets to prepare for the decoder.
            This includes:
              - Embedding the ids.
              - Flattening to 3D tensor.
              - Optionally adding timing signals.
            Args:
              targets: inputs ids to the decoder. [batch_size, 1]
              i: scalar, Step number of the decoding loop.
            Returns:
              Processed targets [batch_size, 1, hidden_dim]
            """
            # _shard_features called to ensure that the variable names match
            #targets = self._shard_features({"targets": targets})["targets"]
            with tf.variable_scope(target_modality.name):
                targets = target_modality.targets_bottom(targets)
            targets = common_layers.flatten4d3d(targets)

            targets = tf.cond(tf.equal(i, 0), lambda: tf.zeros_like(targets),
                              lambda: targets)

            if positional_encoding is not None:
                targets += positional_encoding[:, i:i + 1]
            return targets

        decoder_self_attention_bias = (
            common_attention.attention_bias_lower_triangle(decode_length))
        if hparams.proximity_bias:
            decoder_self_attention_bias += common_attention.attention_bias_proximal(
                decode_length)

        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]
            targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = preprocess_targets(targets, i)

            bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

            with tf.variable_scope("body"):
                body_outputs = self.decode(
                    targets,
                    cache.get("encoder_output"),
                    cache.get("encoder_decoder_attention_bias"),
                    bias,
                    hparams,
                    cache,
                    nonpadding=features_to_nonpadding(features, "targets"))

            with tf.variable_scope(target_modality.name):
                logits = target_modality.top(body_outputs, None)

            ret = tf.squeeze(logits, axis=[1, 2, 3])
            return ret, cache

        ret = fast_decode(
            encoder_output=encoder_output,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            symbols_to_logits_fn=symbols_to_logits_fn,
            hparams=hparams,
            decode_length=decode_length,
            vocab_size=target_modality.top_dimensionality,
            beam_size=beam_size,
            top_beams=top_beams,
            alpha=alpha,
            batch_size=batch_size,
            force_decode_length=self._decode_hparams.force_decode_length)

        return ret
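
This variant chooses between the sinusoidal timing signal (`hparams.pos == "timing"`) and a learned positional embedding (`"emb"`). For reference, here is a NumPy sketch of the standard sinusoidal signal in the spirit of `get_timing_signal_1d`; the actual tensor2tensor implementation differs in minor details (e.g. handling of odd channel counts).

import numpy as np

def timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    # Half the channels are sines and half are cosines over geometrically
    # spaced timescales; shape [1, length, channels], ready to add to embeddings.
    position = np.arange(length, dtype=np.float32)
    num_timescales = channels // 2
    log_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * np.exp(np.arange(num_timescales) * -log_increment)
    scaled_time = position[:, None] * inv_timescales[None, :]
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    return signal[None, :, :]

print(timing_signal_1d(8, 16).shape)  # (1, 8, 16)
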
Example #39
0
def _bias(x):
  return common_attention.attention_bias_lower_triangle(
      common_layers.shape_list(x)[1])
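
All of the examples in this collection feed the lower-triangle bias into scaled dot-product attention, where it is simply added to the query-key logits before the softmax. The self-contained NumPy sketch below illustrates that step; it is a conceptual sketch, not the tensor2tensor attention code.

import numpy as np

def masked_attention_weights(q, k, bias):
    # q, k: [batch, heads, length, depth]; bias broadcasts to
    # [batch, heads, length, length]. Positions biased by -1e9 end up with
    # (numerically) zero attention weight after the softmax.
    depth = q.shape[-1]
    logits = np.matmul(q, np.swapaxes(k, -1, -2)) / np.sqrt(depth) + bias
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    return exp / exp.sum(axis=-1, keepdims=True)

length, depth = 4, 8
q = np.random.randn(1, 1, length, depth)
k = np.random.randn(1, 1, length, depth)
bias = np.triu(np.full((length, length), -1e9), k=1)[None, None]
weights = masked_attention_weights(q, k, bias)
print(np.round(weights[0, 0], 3))  # the upper triangle is ~0
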