Example #1
def lstm_seq2seq_internal_bid_encoder(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model with bidirectional encoder."""
  with tf.variable_scope("lstm_seq2seq_bid_encoder"):
    if inputs is not None:
      inputs_length = common_layers.length_from_embedding(inputs)
      # Flatten inputs.
      inputs = common_layers.flatten4d3d(inputs)
      # LSTM encoder.
      _, final_encoder_state = lstm_bid_encoder(
          inputs, inputs_length, hparams, train, "encoder")
    else:
      inputs_length = None
      final_encoder_state = None
    # LSTM decoder.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    hparams_decoder = copy.copy(hparams)
    hparams_decoder.hidden_size = 2 * hparams.hidden_size
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        targets_length,
        hparams_decoder,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)
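The helpers used above come from tensor2tensor's common_layers module. As a rough sketch of what they are assumed to do (an assumption, not the library source): flatten4d3d merges the two middle axes of a [batch, length, height, depth] tensor, and length_from_embedding counts the positions whose embedding is not all-zero padding.

import tensorflow as tf

def flatten4d3d_sketch(x):
  # [batch, length, height, depth] -> [batch, length * height, depth].
  shape = tf.shape(x)
  return tf.reshape(x, [shape[0], shape[1] * shape[2], shape[3]])

def length_from_embedding_sketch(emb):
  # emb: [batch, length, 1, depth]. A position counts as padding when its
  # embedding vector is entirely zero, so the sequence length is the number
  # of positions with a non-zero embedding.
  nonpadding = tf.to_float(
      tf.not_equal(tf.reduce_sum(tf.abs(emb), axis=-1), 0.0))
  return tf.cast(tf.reduce_sum(nonpadding, axis=[1, 2]), tf.int32)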
Example #2
def lstm_seq2seq_internal(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model, main step used for training."""
  with tf.variable_scope("lstm_seq2seq"):
    if inputs is not None:
      inputs_length = common_layers.length_from_embedding(inputs)
      # Flatten inputs.
      inputs = common_layers.flatten4d3d(inputs)

      # LSTM encoder.
      inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
      _, final_encoder_state = lstm(inputs, inputs_length, hparams, train,
                                    "encoder")
    else:
      final_encoder_state = None

    # LSTM decoder.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        targets_length,
        hparams,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)
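In both examples the decoder input is the target sequence shifted right, and the comment notes that one is added back to the target length. A minimal sketch of the assumed behaviour of common_layers.shift_right: pad one all-zero step on the left and drop the final step, so a length computed from the shifted embeddings undercounts by one, hence the `+ 1` above.

import tensorflow as tf

def shift_right_sketch(x):
  # x: [batch, length, 1, depth]. Prepend a zero "start" step and drop the
  # last step, so position i of the output holds target i - 1.
  return tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])[:, :-1, :, :]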
Example #3
  def body(self, features):
    hparams = self._hparams
    targets = features["targets"]
    inputs = features["inputs"]
    target_space = features["target_space_id"]

    inputs = common_layers.flatten4d3d(inputs)
    targets = common_layers.flatten4d3d(targets)

    (encoder_input, encoder_self_attention_bias,
     encoder_decoder_attention_bias) = (transformer.transformer_prepare_encoder(
         inputs, target_space, hparams))
    (decoder_input,
     decoder_self_attention_bias) = transformer.transformer_prepare_decoder(
         targets, hparams)

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    decoder_input = tf.nn.dropout(decoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    encoder_output = transformer_revnet_encoder(
        encoder_input, encoder_self_attention_bias, hparams)

    decoder_output = transformer_revnet_decoder(
        decoder_input, encoder_output, decoder_self_attention_bias,
        encoder_decoder_attention_bias, hparams)
    decoder_output = tf.expand_dims(decoder_output, 2)

    return decoder_output
Example #4
  def body(self, features):
    hp = self.hparams
    # pylint: disable=eval-used
    if hp.image_input_type == "image":
      image_feat = vqa_layers.image_embedding(
          features["inputs"],
          model_fn=eval(hp.image_model_fn),
          trainable=hp.train_resnet,
          is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
    else:
      image_feat = features["inputs"]

    image_feat = common_layers.flatten4d3d(image_feat)
    image_feat = common_layers.dense(image_feat, hp.hidden_size)
    utils.collect_named_outputs("norms", "image_feat_after_proj",
                                tf.norm(image_feat, axis=-1))

    question = common_layers.flatten4d3d(features["question"])
    utils.collect_named_outputs("norms", "question_embedding",
                                tf.norm(question, axis=-1))
    (encoder_input, encoder_self_attention_bias,
     encoder_decoder_attention_bias) = prepare_image_question_encoder(
         image_feat, question, hp)

    encoder_input = tf.nn.dropout(
        encoder_input, keep_prob=1.-hp.layer_prepostprocess_dropout)

    encoder_output, _ = recurrent_transformer_decoder(
        encoder_input, None, encoder_self_attention_bias, None,
        hp, name="encoder")
    utils.collect_named_outputs(
        "norms", "encoder_output", tf.norm(encoder_output, axis=-1))

    # scale query by sqrt(hidden_size)
    query = tf.get_variable("query", [hp.hidden_size]) * hp.hidden_size **0.5
    query = tf.expand_dims(tf.expand_dims(query, axis=0), axis=0)
    batch_size = common_layers.shape_list(encoder_input)[0]
    query = tf.tile(query, [batch_size, 1, 1])
    query = tf.nn.dropout(
        query, keep_prob=1.-hp.layer_prepostprocess_dropout)

    decoder_output, _ = recurrent_transformer_decoder(
        query, encoder_output, None, encoder_decoder_attention_bias,
        hp, name="decoder")
    utils.collect_named_outputs("norms", "decoder_output",
                                tf.norm(decoder_output, axis=-1))

    norm_tensors = utils.convert_collection_to_dict("norms")
    vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

    # Expand dimension 1
    return tf.expand_dims(decoder_output, axis=1)
Example #5
def lstm_seq2seq_internal_attention(inputs, targets, hparams, train):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention"):
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    encoder_outputs, final_encoder_state = lstm(
        tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    # LSTM decoder with attention
    shifted_targets = common_layers.shift_right(targets)
    decoder_outputs, _ = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs)
    return tf.expand_dims(decoder_outputs, axis=2)
Example #6
def bytenet_internal(inputs, targets, hparams):
  """ByteNet, main step used for training."""
  with tf.variable_scope("bytenet"):
    # Flatten inputs and extend length by 50%.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    extend_length = tf.to_int32(0.5 * tf.to_float(tf.shape(inputs)[1]))
    inputs_shape = inputs.shape.as_list()
    inputs = tf.pad(inputs, [[0, 0], [0, extend_length], [0, 0], [0, 0]])
    inputs_shape[1] = None
    inputs.set_shape(inputs_shape)  # Don't lose the other shapes when padding.
    # Pad inputs and targets to be the same length, divisible by 50.
    inputs, targets = common_layers.pad_to_same_length(
        inputs, targets, final_length_divisible_by=50)
    final_encoder = residual_dilated_conv(inputs, hparams.num_block_repeat,
                                          "SAME", "encoder", hparams)

    shifted_targets = common_layers.shift_right(targets)
    kernel = (hparams.kernel_height, hparams.kernel_width)
    decoder_start = common_layers.conv_block(
        tf.concat([final_encoder, shifted_targets], axis=3),
        hparams.hidden_size, [((1, 1), kernel)],
        padding="LEFT")

    return residual_dilated_conv(decoder_start, hparams.num_block_repeat,
                                 "LEFT", "decoder", hparams)
Example #7
def slicenet_internal(inputs, targets, target_space, hparams, run_decoder=True):
  """The slicenet model, main step used for training."""
  with tf.variable_scope("slicenet"):
    # Project to hidden size if necessary
    if inputs.get_shape().as_list()[-1] != hparams.hidden_size:
      inputs = common_layers.conv_block(
          inputs,
          hparams.hidden_size, [((1, 1), (3, 3))],
          first_relu=False,
          padding="SAME",
          force2d=True)

    # Flatten inputs and encode.
    inputs = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
    inputs_mask = 1.0 - embedding_to_padding(inputs)
    inputs = common_layers.add_timing_signal(inputs)  # Add position info.
    target_space_emb = embed_target_space(target_space, hparams.hidden_size)
    extra_layers = int(hparams.num_hidden_layers * 1.5)
    inputs_encoded = multi_conv_res(
        inputs, "SAME", "encoder", extra_layers, hparams, mask=inputs_mask)
    if not run_decoder:
      return inputs_encoded
    # Do the middle part.
    decoder_start, similarity_loss = slicenet_middle(
        inputs_encoded, targets, target_space_emb, inputs_mask, hparams)
    # Decode.
    decoder_final = multi_conv_res(
        decoder_start,
        "LEFT",
        "decoder",
        hparams.num_hidden_layers,
        hparams,
        mask=inputs_mask,
        source=inputs_encoded)
    return decoder_final, tf.reduce_mean(similarity_loss)
Example #8
  def encode(self, inputs, target_space, hparams, features=None):
    """Encode transformer inputs.

    Args:
      inputs: Transformer inputs [batch_size, input_length, input_height,
        hidden_dim] which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for the model.
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
    """
    inputs = common_layers.flatten4d3d(inputs)

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer_prepare_encoder(
            inputs, target_space, hparams, features=features))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    encoder_output = transformer_encoder(
        encoder_input, self_attention_bias,
        hparams, nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=self.attention_weights)

    return encoder_output, encoder_decoder_attention_bias
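For "packed" datasets the nonpadding mask is taken from the features dictionary rather than inferred from the embeddings. A sketch of the assumed behaviour of features_to_nonpadding: clamp the segmentation feature to a 0/1 mask when it is present, otherwise return None.

import tensorflow as tf

def features_to_nonpadding_sketch(features, inputs_or_targets="inputs"):
  # In packed examples the "<key>_segmentation" feature is non-zero exactly
  # at real token positions, so clipping it at 1.0 gives a nonpadding mask.
  key = inputs_or_targets + "_segmentation"
  if features and key in features:
    return tf.minimum(tf.to_float(features[key]), 1.0)
  return None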
Example #9
    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: input ids to the decoder. [batch_size, 1]
        i: scalar, step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets
Example #10
def transformer_text_encoder(inputs,
                             target_space,
                             hparams,
                             name=None):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    inputs: Tensor of shape [batch, length, 1, hparams.hidden_size].
    target_space: int. Used for encoding inputs under a target space id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name, default_name="transformer_text_encoder"):
    inputs = common_layers.flatten4d3d(inputs)
    [
        encoder_input,
        encoder_self_attention_bias,
        ed,
    ] = transformer_layers.transformer_prepare_encoder(
        inputs, target_space=target_space, hparams=hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    encoder_output = transformer_layers.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams)
    return encoder_output, ed
Example #11
  def model_fn_body(self, features):
    """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "tragets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id"

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
    hparams = self._hparams

    inputs = features.get("inputs")
    encoder_output, encoder_decoder_attention_bias = (None, None)
    if inputs is not None:
      target_space = features["target_space_id"]
      encoder_output, encoder_decoder_attention_bias = self.encode(
          inputs, target_space, hparams, features=features)

    targets = features["targets"]
    targets = common_layers.flatten4d3d(targets)

    decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
        targets, hparams, features=features)

    return self.decode(decoder_input, encoder_output,
                       encoder_decoder_attention_bias,
                       decoder_self_attention_bias, hparams,
                       nonpadding=_features_to_nonpadding(features, "targets"))
Example #12
def transformer_text_encoder(x,
                             space_id,
                             hparams,
                             name="transformer_text_encoder"):
  """Transformer text encoder over inputs with unmasked full attention.

  Args:
    x: Tensor of shape [batch, length, 1, hparams.hidden_size].
    space_id: int, id.
    hparams: tf.contrib.training.HParams.
    name: string, variable scope.

  Returns:
    encoder_output: Tensor of shape [batch, length, hparams.hidden_size].
    ed: Tensor of shape [batch, 1, 1, length]. Encoder-decoder attention bias
      for any padded tokens.
  """
  with tf.variable_scope(name):
    x = common_layers.flatten4d3d(x)
    (encoder_input, encoder_self_attention_bias,
     ed) = transformer.transformer_prepare_encoder(x, space_id, hparams)
    encoder_input = tf.nn.dropout(encoder_input, 1.0 - hparams.dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams)
    return encoder_output, ed
Example #13
  def body(self, features):
    hp = self.hparams
    # pylint: disable=eval-used
    if hp.image_input_type == "image":
      image_feat = vqa_layers.image_embedding(
          features["inputs"],
          model_fn=eval(hp.image_model_fn),
          trainable=hp.train_resnet,
          is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
    else:
      image_feat = features["inputs"]

    image_feat = common_layers.flatten4d3d(image_feat)
    # image feature self attention
    # image_feat = tf.nn.dropout(
    #     image_feat, keep_prob=1.-hp.layer_prepostprocess_dropout)

    # image_feat = image_feat - tf.reduce_mean(
    #     image_feat, axis=-1, keepdims=True)
    # image_feat = tf.nn.l2_normalize(image_feat, -1)
    # utils.collect_named_outputs("norms", "image_feat_after_l2",
    #                             tf.norm(image_feat, axis=-1))

    image_feat = tf.nn.dropout(image_feat, keep_prob=1.-hp.dropout)

    image_feat = image_encoder(image_feat, hp)
    utils.collect_named_outputs("norms", "image_feat_encoded",
                                tf.norm(image_feat, axis=-1))
    image_feat = common_layers.l2_norm(image_feat)
    utils.collect_named_outputs("norms", "image_feat_encoded_l2",
                                tf.norm(image_feat, axis=-1))

    query = question_encoder(features["question"], hp)
    utils.collect_named_outputs("norms", "query",
                                tf.norm(query, axis=-1))

    image_ave = attn(image_feat, query, hp)
    utils.collect_named_outputs("norms", "image_ave",
                                tf.norm(image_ave, axis=-1))

    image_question = tf.concat([image_ave, query], axis=1)
    utils.collect_named_outputs("norms", "image_question",
                                tf.norm(image_question, axis=-1))

    image_question = tf.nn.dropout(image_question, 1. - hp.dropout)

    output = mlp(image_question, hp)
    utils.collect_named_outputs("norms", "output",
                                tf.norm(output, axis=-1))

    norm_tensors = utils.convert_collection_to_dict("norms")
    vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

    # Expand dimension 1 and 2
    return tf.expand_dims(tf.expand_dims(output, axis=1), axis=2)
Example #14
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None):
  """Original Transformer decoder."""
  with tf.variable_scope(name):
    if task is None:
      task = hparams.task
    if task == "translate":
      targets = common_layers.flatten4d3d(targets)

      decoder_input, decoder_self_bias = (
          transformer.transformer_prepare_decoder(targets, hparams))

      decoder_input = tf.nn.dropout(decoder_input,
                                    1.0 - hparams.layer_prepostprocess_dropout)

      decoder_output = transformer.transformer_decoder(
          decoder_input,
          encoder_output,
          decoder_self_bias,
          encoder_decoder_attention_bias,
          hparams)
      decoder_output = tf.expand_dims(decoder_output, axis=2)
    else:
      assert task == "image"
      inputs = None
      # Have to reshape targets to [batch, 32, 32, 3 * hidden_size] because
      # otherwise prepare_image will choke.
      targets = tf.reshape(targets, [tf.shape(targets)[0], hparams.img_len,
                                     hparams.img_len,
                                     hparams.num_channels*hparams.hidden_size])

      # Prepare decoder inputs and bias.
      decoder_input, _, _, bias = cia.prepare_decoder(targets, hparams)
      # Add class label to decoder input.
      if not hparams.drop_inputs:
        decoder_input += tf.reshape(
            inputs,
            [common_layers.shape_list(targets)[0], 1, 1, hparams.hidden_size])
      decoder_output = cia.transformer_decoder_layers(
          decoder_input,
          None,
          bias,
          hparams.num_decoder_layers or hparams.num_hidden_layers,
          hparams,
          attention_type=hparams.dec_attention_type,
          name="decoder")
    decoder_output_shape = common_layers.shape_list(decoder_output)
    decoder_output = tf.reshape(decoder_output, [decoder_output_shape[0], -1, 1,
                                                 hparams.hidden_size])
    # Expand since t2t expects 4d tensors.
    return decoder_output
Example #15
  def body(self, features):
    if self._hparams.initializer == "orthogonal":
      raise ValueError("LSTM models fail with orthogonal initializer.")
    train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN
    inputs = features.get("inputs")
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    encoder_output, _ = lstm(
        tf.reverse(inputs, axis=[1]), self._hparams, train, "encoder")
    return tf.expand_dims(encoder_output, axis=2)
Example #16
def lstm_seq2seq_internal(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model, main step used for training."""
  with tf.variable_scope("lstm_seq2seq"):
    if inputs is not None:
      # Flatten inputs.
      inputs = common_layers.flatten4d3d(inputs)
      # LSTM encoder.
      _, final_encoder_state = lstm(
          tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    else:
      final_encoder_state = None
    # LSTM decoder.
    shifted_targets = common_layers.shift_right(targets)
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        hparams,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)
Example #17
def lstm_seq2seq_internal_attention(inputs, targets, hparams, train,
                                    inputs_length, targets_length):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention"):
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)

    # LSTM encoder.
    inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
    encoder_outputs, final_encoder_state = lstm(
        inputs, inputs_length, hparams, train, "encoder")

    # LSTM decoder with attention.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right
    targets_length = targets_length + 1
    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs, inputs_length, targets_length)
    return tf.expand_dims(decoder_outputs, axis=2)
Example #18
def lstm_seq2seq_internal_attention_bid_encoder(inputs, targets, hparams,
                                                train):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention_bid_encoder"):
    inputs_length = common_layers.length_from_embedding(inputs)
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    encoder_outputs, final_encoder_state = lstm_bid_encoder(
        inputs, inputs_length, hparams, train, "encoder")
    # LSTM decoder with attention
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    hparams_decoder = copy.copy(hparams)
    hparams_decoder.hidden_size = 2 * hparams.hidden_size
    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams_decoder, train,
        "decoder", final_encoder_state, encoder_outputs,
        inputs_length, targets_length)
    return tf.expand_dims(decoder_outputs, axis=2)
Example #19
  def _prepare_decoder(self, targets):
    """Process the transformer decoder input."""
    targets = common_layers.flatten4d3d(targets)

    output = transformer.transformer_prepare_decoder(
        targets, self._hparams, features=None,
    )
    deco_input, deco_self_attention_bias = output

    deco_input = tf.nn.dropout(
        deco_input, 1.0 - self._hparams.layer_prepostprocess_dropout
    )
    return deco_input, deco_self_attention_bias
Example #20
def lstm_seq2seq_internal_bid_encoder(inputs, targets, hparams, train):
  """The basic LSTM seq2seq model with bidirectional encoder."""
  with tf.variable_scope("lstm_seq2seq_bid_encoder"):
    if inputs is not None:
      # Flatten inputs.
      inputs = common_layers.flatten4d3d(inputs)
      # LSTM encoder.
      _, final_encoder_state = lstm_bid_encoder(
          tf.reverse(inputs, axis=[1]), hparams, train, "encoder")
    else:
      final_encoder_state = None
    # LSTM decoder.
    shifted_targets = common_layers.shift_right(targets)
    hparams_decoder = copy.copy(hparams)
    hparams_decoder.hidden_size = 2 * hparams.hidden_size
    decoder_outputs, _ = lstm(
        common_layers.flatten4d3d(shifted_targets),
        hparams_decoder,
        train,
        "decoder",
        initial_state=final_encoder_state)
    return tf.expand_dims(decoder_outputs, axis=2)
Example #21
  def body(self, features):
    if self._hparams.initializer == "orthogonal":
      raise ValueError("LSTM models fail with orthogonal initializer.")
    train = self._hparams.mode == tf.estimator.ModeKeys.TRAIN
    inputs = features.get("inputs")
    inputs_length = common_layers.length_from_embedding(inputs)
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)
    # LSTM encoder.
    inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
    encoder_output, _ = lstm(inputs, inputs_length, self._hparams, train,
                             "encoder")
    return tf.expand_dims(encoder_output, axis=2)
Example #22
def question_encoder(question, hparams, name="encoder"):
  """Question encoder, run LSTM encoder and get the last output as encoding."""
  with tf.variable_scope(name, "encoder", values=[question]):
    question = common_layers.flatten4d3d(question)
    padding = common_attention.embedding_to_padding(question)
    length = common_attention.padding_to_length(padding)

    max_question_length = hparams.max_question_length
    question = question[:, :max_question_length, :]
    actual_question_length = common_layers.shape_list(question)[1]
    length = tf.minimum(length, max_question_length)
    padding = [[0, 0],
               [0, max_question_length-actual_question_length],
               [0, 0]]
    question = tf.pad(question, padding)
    question_shape = question.get_shape().as_list()
    question_shape[1] = max_question_length
    question.set_shape(question_shape)

    # apply tanh dropout on question embedding
    question = tf.tanh(question)
    question = tf.nn.dropout(question, keep_prob=1.-hparams.dropout)

    question = [question[:, i, :] for i in range(max_question_length)]

    # rnn_layers = [_get_rnn_cell(hparams)
    #               for _ in range(hparams.num_rnn_layers)]
    # rnn_multi_cell = tf.contrib.rnn.MultiRNNCell(rnn_layers)
    rnn_cell = _get_rnn_cell(hparams)
    # outputs, _ = tf.nn.dynamic_rnn(
    #     rnn_cell, question, length, dtype=tf.float32)
    _, state = tf.nn.static_rnn(rnn_cell, question, sequence_length=length,
                                dtype=tf.float32)
    # outputs = [tf.expand_dims(output, axis=1) for output in outputs]
    # outputs = tf.concat(outputs, axis=1)

    # utils.collect_named_outputs("vqa_attention_debug", "question_output",
    #                             outputs)
    # utils.collect_named_outputs("vqa_attention_debug", "question_state",
    #                             state.h)

    # batch_size = common_layers.shape_list(outputs)[0]
    # row_indices = tf.range(batch_size)
    # # length - 1 as index
    # indices = tf.transpose([row_indices, tf.maximum(length-1, 0)])
    # last_output = tf.gather_nd(outputs, indices)

    # utils.collect_named_outputs("vqa_attention_debug",
    #                             "question_final_output", last_output)

  return state.h
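question_encoder leans on two small helpers from common_attention. Sketches of their assumed behaviour (assumptions about the library, not its source): a position is padding when its embedding is all zeros, and the length is the number of non-padding positions.

import tensorflow as tf

def embedding_to_padding_sketch(emb):
  # 1.0 where the embedding vector is entirely zero (padding), else 0.0.
  emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
  return tf.to_float(tf.equal(emb_sum, 0.0))

def padding_to_length_sketch(padding):
  # Number of non-padding positions per example.
  return tf.to_int32(tf.reduce_sum(1.0 - padding, axis=-1))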
Example #23
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
  """Middle part of slicenet, connecting encoder and decoder."""

  def norm_fn(x, name):
    with tf.variable_scope(name, default_name="norm"):
      return common_layers.apply_norm(x, hparams.norm_type, hparams.hidden_size,
                                      hparams.norm_epsilon)

  # Flatten targets and embed target_space_id.
  targets_flat = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
  target_space_emb = tf.tile(target_space_emb,
                             [tf.shape(targets_flat)[0], 1, 1, 1])

  # Calculate similarity loss (but don't run if not needed).
  if len(hparams.problems) > 1 and hparams.sim_loss_mult > 0.00001:
    targets_timed = common_layers.add_timing_signal(targets_flat)
    extra_layers = int(hparams.num_hidden_layers * 1.5)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      targets_encoded = multi_conv_res(targets_timed, "SAME", "encoder",
                                       extra_layers, hparams)
    with tf.variable_scope("similarity_loss"):
      similarity_loss = similarity_cost(inputs_encoded, targets_encoded)
      similarity_loss *= hparams.sim_loss_mult
  else:
    similarity_loss = 0.0

  # Use attention from each target to look at input and retrieve.
  targets_shifted = common_layers.shift_right(
      targets_flat, pad_value=target_space_emb)
  if hparams.attention_type == "none":
    targets_with_attention = tf.zeros_like(targets_shifted)
  else:
    inputs_padding_bias = (1.0 - mask) * -1e9  # Bias to not attend to padding.
    targets_with_attention = attention(
        targets_shifted,
        inputs_encoded,
        norm_fn,
        hparams,
        bias=inputs_padding_bias)

  # Positional targets: merge attention and raw.
  kernel = (hparams.kernel_height, hparams.kernel_width)
  targets_merged = common_layers.subseparable_conv_block(
      tf.concat([targets_with_attention, targets_shifted], axis=3),
      hparams.hidden_size, [((1, 1), kernel)],
      normalizer_fn=norm_fn,
      padding="LEFT",
      separability=4,
      name="targets_merge")

  return targets_merged, similarity_loss
Example #24
def lstm_seq2seq_internal_attention(inputs, targets, hparams, train):
  """LSTM seq2seq model with attention, main step used for training."""
  with tf.variable_scope("lstm_seq2seq_attention"):
    # This is a temporary fix for varying-length sequences within a batch.
    # A more complete fix should pass a length tensor from outside so that
    # all the lstm variants can use it.
    inputs_length = common_layers.length_from_embedding(inputs)
    # Flatten inputs.
    inputs = common_layers.flatten4d3d(inputs)

    # LSTM encoder.
    inputs = tf.reverse_sequence(inputs, inputs_length, seq_axis=1)
    encoder_outputs, final_encoder_state = lstm(
        inputs, inputs_length, hparams, train, "encoder")

    # LSTM decoder with attention.
    shifted_targets = common_layers.shift_right(targets)
    # Add 1 to account for the padding added to the left from shift_right
    targets_length = common_layers.length_from_embedding(shifted_targets) + 1
    decoder_outputs = lstm_attention_decoder(
        common_layers.flatten4d3d(shifted_targets), hparams, train, "decoder",
        final_encoder_state, encoder_outputs, inputs_length, targets_length)
    return tf.expand_dims(decoder_outputs, axis=2)
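Unlike the tf.reverse(inputs, axis=[1]) calls in the earlier LSTM examples, this variant reverses only the real tokens of each padded sequence. A small illustration of the difference (the tensor values below are made up for the example):

import tensorflow as tf

# Two embedded sequences of depth 1; row 0 has 3 real steps, row 1 has 2.
x = tf.constant([[[1.], [2.], [3.], [0.]],
                 [[4.], [5.], [0.], [0.]]])
lengths = tf.constant([3, 2])

# Only the first lengths[i] steps of row i are reversed; padding stays put:
# row 0 -> [3, 2, 1, 0], row 1 -> [5, 4, 0, 0].
per_sequence = tf.reverse_sequence(x, lengths, seq_axis=1)

# tf.reverse flips the whole axis, so padding moves to the front:
# row 0 -> [0, 3, 2, 1], row 1 -> [0, 0, 5, 4].
whole_axis = tf.reverse(x, axis=[1])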
Example #25
  def targets_bottom(self, inputs):
    with tf.variable_scope(self.name):
      # Reshape inputs to 2-d tensor and embed the RGB pixel values.
      ret = common_layers.embedding(
          tf.to_int32(common_layers.flatten4d3d(inputs)),
          self.top_dimensionality,
          self._body_input_depth,
          name="input_rgb_embedding")
      if self._model_hparams.multiply_embedding_mode == "sqrt_depth":
        ret *= self._body_input_depth**0.5

      reshape_shape = common_layers.shape_list(inputs)[:3]
      reshape_shape.append(self._body_input_depth * 3)
      ret = tf.reshape(ret, reshape_shape)
      return tf.layers.dense(ret, self._body_input_depth)
Example #26
  def _prepare_encoder(self, inputs, target_space):
    """Process the transformer encoder inputs."""
    inputs = common_layers.flatten4d3d(inputs)

    output = transformer.transformer_prepare_encoder(
        inputs,
        target_space,
        self._hparams,
        features=None,
    )
    enco_input, enco_self_att_bias, enco_deco_att_bias = output

    enco_input = tf.nn.dropout(
        enco_input, 1.0 - self._hparams.layer_prepostprocess_dropout)

    return enco_input, enco_self_att_bias, enco_deco_att_bias
Example #27
  def encode(self, inputs, target_space, hparams, features=None, losses=None):
    """Encode Universal Transformer inputs.

    It is similar to "transformer.encode", but it uses
    "universal_transformer_util.universal_transformer_encoder" instead of
    "transformer.transformer_encoder".

    Args:
      inputs: Transformer inputs [batch_size, input_length, input_height,
        hidden_dim] which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for the model.
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.
      losses: Unused.

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
          encoder_extra_output: which is extra encoder output used in some
            variants of the model (e.g. in ACT, to pass the ponder-time to body)
    """
    del losses

    inputs = common_layers.flatten4d3d(inputs)

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer.transformer_prepare_encoder(
            inputs, target_space, hparams, features=features))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    (encoder_output, encoder_extra_output) = (
        universal_transformer_util.universal_transformer_encoder(
            encoder_input,
            self_attention_bias,
            hparams,
            nonpadding=transformer.features_to_nonpadding(features, "inputs"),
            save_weights_to=self.attention_weights))

    return encoder_output, encoder_decoder_attention_bias, encoder_extra_output
Example #28
  def body(self, features):
    hparams = self._hparams
    inputs = features["inputs"]
    target_space = features["target_space_id"]

    inputs = common_layers.flatten4d3d(inputs)

    (encoder_input, encoder_self_attention_bias, _) = (
        transformer_prepare_encoder(inputs, target_space, hparams))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    encoder_output = transformer_encoder(
        encoder_input, encoder_self_attention_bias, hparams,
        nonpadding=features_to_nonpadding(features, "inputs"))
    encoder_output = tf.expand_dims(encoder_output, 2)

    return encoder_output
Example #29
  def body(self, features):
    """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "tragets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id"

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
    hparams = self._hparams

    if self.has_input:
      inputs = features["inputs"]
      target_space = features["target_space_id"]
      encoder_output, encoder_decoder_attention_bias = self.encode(
          inputs, target_space, hparams, features=features)
    else:
      encoder_output, encoder_decoder_attention_bias = (None, None)

    targets = features["targets"]
    targets = common_layers.flatten4d3d(targets)

    decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
        targets, hparams, features=features)

    decoder_output = self.decode(
        decoder_input,
        encoder_output,
        encoder_decoder_attention_bias,
        decoder_self_attention_bias,
        hparams,
        nonpadding=features_to_nonpadding(features, "targets"))

    expected_attentions = features.get("expected_attentions")
    if expected_attentions is not None:
      attention_loss = common_attention.encoder_decoder_attention_loss(
          expected_attentions, self.attention_weights)
      return decoder_output, {"attention_loss": attention_loss}

    return decoder_output
Example #30
  def encode(self, features, input_key):
    hparams = self._hparams
    inputs = common_layers.flatten4d3d(features[input_key])

    (encoder_input, encoder_self_attention_bias, _) = (
        transformer.transformer_prepare_encoder(inputs, problem.SpaceID.EN_TOK,
                                                hparams))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)
    encoder_output = transformer.transformer_encoder(
        encoder_input,
        encoder_self_attention_bias,
        hparams,
        nonpadding=transformer.features_to_nonpadding(features, input_key))

    encoder_output = tf.reduce_mean(encoder_output, axis=1)

    return encoder_output
Example #31
    def internal(self, features, real_features):
        """Main procedure for both training and inference."""
        inputs = common_layers.flatten4d3d(features["inputs"])
        targets = common_layers.flatten4d3d(features["targets"])
        target_space = features["target_space_id"]
        hparams = self._hparams
        inputs_mask = ops.embedding_to_non_padding(inputs)
        inputs_length = tf.reduce_sum(inputs_mask, axis=-1)

        encoder_output, encoder_decoder_attention_bias = (ops.encoder(
            "encoder", hparams, inputs, target_space))
        kwargs = {
            "encoder_output": encoder_output,
            "encoder_decoder_attention_bias": encoder_decoder_attention_bias
        }
        losses, monitor = {}, {}
        log_abs_det = tf.constant(0.0)

        if not self.is_predicting:
            # Training
            targets_mask = ops.embedding_to_non_padding(targets)
            targets_length = tf.reduce_sum(targets_mask, axis=-1)
            length_diff = targets_length - inputs_length
            decoder_self_attention_bias = (
                common_attention.attention_bias_ignore_padding(1.0 -
                                                               targets_mask))
            z_q, log_q_z, q_dist = self.sample_q(targets,
                                                 targets_mask,
                                                 decoder_self_attention_bias,
                                                 n_samples=1,
                                                 temp=1.0,
                                                 **kwargs)

            body_output = ops.decoder("decoder", z_q, hparams,
                                      decoder_self_attention_bias, **kwargs)
            logits = self.top(body_output, real_features)
            numerator, denominator = self.loss(logits, real_features)

            if not (self.is_evaluating and (hparams.compute_kl_refinement
                                            or hparams.compute_iw_marginal)):
                targets_length_pred, lenpred_loss = ops.predict_target_lengths(
                    encoder_output, inputs_mask, hparams, length_diff)
                log_p_z_base, log_abs_det = self.compute_prior_log_prob(
                    z_q,
                    targets_mask,
                    decoder_self_attention_bias,
                    check_invertibility=False,
                    **kwargs)
                losses, monitor = ops.save_log_loss(
                    hparams, targets_mask, numerator, denominator, log_q_z,
                    log_abs_det, log_p_z_base, z_q, lenpred_loss,
                    targets_length_pred, targets_length)

            if self.is_evaluating:
                if hparams.compute_kl_refinement:
                    z_p, _ = self.sample_p(targets_length,
                                           temp=self._decode_hparams.temp,
                                           check_invertibility=False,
                                           targets_mask=targets_mask,
                                           **kwargs)
                    z_dq = self.delta_posterior(
                        z_p, targets_mask, decoder_self_attention_bias,
                        self._decode_hparams.n_gibbs_steps, **kwargs)
                    log_q_z_ = q_dist.log_prob(z_dq)
                    log_q_z_ = gops.reduce_mean_over_bl_sum_over_c(
                        log_q_z_, targets_mask)
                    losses = {"training": log_q_z_}

                if hparams.compute_iw_marginal:
                    # if True:
                    log_p_y_x = self.compute_iw_marginal(
                        targets, targets_mask, decoder_self_attention_bias,
                        real_features, self._decode_hparams.n_samples,
                        **kwargs)
                    # real_features, 1, **kwargs)
                    losses = {"training": log_p_y_x}

            return logits, losses, monitor, targets_mask

        else:
            # Inference
            targets_length, _ = ops.predict_target_lengths(
                encoder_output, inputs_mask, hparams)
            targets_mask = ops.sequence_mask(targets_length, hparams)
            decoder_self_attention_bias = (
                common_attention.attention_bias_ignore_padding(1.0 -
                                                               targets_mask))
            z_p, _ = self.sample_p(targets_length,
                                   temp=self._decode_hparams.temp,
                                   check_invertibility=False,
                                   **kwargs)
            z_q = self.delta_posterior(z_p, targets_mask,
                                       decoder_self_attention_bias,
                                       self._decode_hparams.n_gibbs_steps,
                                       **kwargs)
            # 0, **kwargs)

            body_output = ops.decoder("decoder", z_q, hparams,
                                      decoder_self_attention_bias, **kwargs)
            return body_output, losses, monitor, targets_mask
Example #32
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = inputs, ed
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    losses = {
        "extra": tf.constant(0.0),
        "latent_pred": tf.constant(0.0),
        "neg_q_entropy": tf.constant(0.0)
    }
    if hparams.do_ae:
        # flatten here
        original_targets_shape = tf.shape(targets)
        if hparams.task == "image":
            cia.maybe_reshape_4d_to_3d(targets)
        if hparams.task == "translate":
            if inputs is not None:
                max_targets_len_from_inputs = tf.concat([inputs, inputs],
                                                        axis=1)
            else:
                max_targets_len_from_inputs = targets
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        if hparams.word_shuffle:
            tf.logging.info("Using word shuffle with rate = {}".format(
                hparams.word_shuffle))
            targets_idx = tf.range(start=0,
                                   limit=common_layers.shape_list(targets)[1],
                                   delta=1)
            targets_idx = tf.to_float(targets_idx)
            noise = tf.random_uniform(
                shape=common_layers.shape_list(targets_idx),
                minval=0,
                maxval=1 + hparams.word_shuffle)
            targets_idx += noise
            permutation = tf.contrib.framework.argsort(targets_idx)
            targets_permuted = tf.gather(targets, indices=permutation, axis=1)
            targets = targets_permuted
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        if hparams.word_dropout:
            mask = tf.random_uniform(shape=common_layers.shape_list(targets),
                                     minval=0.0,
                                     maxval=1.0)
            targets_noisy = tf.where(mask > hparams.word_dropout, targets,
                                     tf.zeros_like(targets))
        else:
            targets_noisy = targets
        targets_c = compress(targets_noisy, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed, neg_q_entropy = (
                hparams.bottleneck(inputs=targets_c,
                                   filter_size=hparams.compress_filter_size,
                                   mode=hparams.mode,
                                   name="vc"))
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  embed(latents_discrete),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)

                # Scale by latent dimension for summary so we can compare across
                # batches.
                if _DO_SUMMARIES:
                    tf.summary.scalar("latent_pred_loss_mean",
                                      tf.reduce_mean(latent_pred_loss))
                if hparams.sum_over_latents:
                    latent_pred_loss = tf.reduce_sum(latent_pred_loss, [1, 2])

                losses["latent_pred"] = tf.reduce_mean(
                    latent_pred_loss * tf.to_float(cond)) * hparams.prior_scale
                losses["neg_q_entropy"] = neg_q_entropy * hparams.entropy_scale
            else:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    (inputs_c - targets_c)**2) * 20

                def bn_inputs():
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _, _ = hparams.bottleneck(
                            inputs=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            mode=hparams.mode,
                            name="vc")
                    return bn

                inputs_c = bn_inputs()
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _, _ = hparams.bottleneck(
                    inputs=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    mode=hparams.mode,
                    name="vc")
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                _, _, _, embed, _ = hparams.bottleneck(
                    inputs=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        d = latents_dense
        latent_len = common_layers.shape_list(latents_dense)[1]
        if isinstance(latent_len, tf.Tensor):
            # TODO(trandustin): Fix this in a better manner.
            latent_len = max(1000, hparams.max_length)
        pos = tf.get_variable("pos",
                              [1, latent_len + 1, 1, hparams.hidden_size])
        pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
        latents_dense = tf.pad(latents_dense,
                               [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

        # decompressing the dense latents
        for i in range(hparams.num_compress_steps):
            j = hparams.num_compress_steps - i - 1
            d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
            if inputs is not None and hparams.do_attend_decompress:
                d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
            d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)

            # targets is always [batch, length, 1, depth]
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        losses["latent_pred"] *= tf.to_float(latent_time)

    # res was generated from padded targets, which means it has some extra
    # elements. These can cause shape problems when computing loss with respect to
    # the original (unpadded) targets. So we remove their extra elements here.
    res = res[:, :original_targets_shape[1], :, :]
    return res, losses, cache
Example #33
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0,
                            means=None,
                            ema_count=None,
                            ema_means=None):
  """AE Transformer, main step used for training."""
  # Summaries break with the do_refine cond, turn them off in that case.
  global _DO_SUMMARIES
  if hparams.do_refine:
    _DO_SUMMARIES = False

  # Prepare.
  batch_size = common_layers.shape_list(inputs)[0]
  targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

  # Encoder.
  if inputs is not None:
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")
  else:
    ed = None

  # Autoencoding.
  losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
  if hparams.do_ae:
    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets, max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, False, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
      # Compress and bottleneck.
      latents_dense, latents_discrete, extra_loss, _ = bottleneck(
          targets_c, hparams, 2 * 2048, "vc", means, ema_count, ema_means)
      if _DO_SUMMARIES:
        tf.summary.histogram("b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
      pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
      pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
      cond = tf.less(tf.random_uniform([batch_size]), pc)
      latents_dense = tf.where(cond, latents_dense, targets_c)
      # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
      losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
      # Extra loss predicting latent code from input. Discrete only.
      if hparams.bottleneck_kind not in ["dense", "vae"]:
        latents_pred = decode_transformer(
            tf.stop_gradient(inputs), tf.stop_gradient(ed),
            tf.stop_gradient(latents_dense), hparams, "extra")
        latents_pred = tf.layers.dense(latents_pred, 2**16, name="extra_logits")
        losses["latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=latents_discrete, logits=latents_pred)
        losses["latent_pred"] = tf.reduce_mean(
            losses["latent_pred"] * 0.5 * tf.to_float(cond))
      else:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        losses["latent_pred"] = tf.reduce_mean((inputs_c - targets_c)**2) * 20
        def bn_inputs():
          with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            bn, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc", means,
                                     ema_count, ema_means)
          return bn
        pbn = 0.8 if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        inputs_c = tf.cond(tf.less(tf.random_uniform([]), pbn),
                           bn_inputs, lambda: inputs_c)
        ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
        ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
        latents_dense = tf.where(tf.less(tf.random_uniform([batch_size]), ptc),
                                 latents_dense, inputs_c)
    else:
      if hparams.bottleneck_kind in ["dense", "vae"]:
        inputs_c = decode_transformer(inputs, ed, targets_c, hparams, "dec_c")
        latents_dense, _, _, _ = bottleneck(inputs_c, hparams, 2 * 2048, "vc",
                                            means, ema_count, ema_means)
      else:
        latent_len = common_layers.shape_list(targets_c)[1]
        _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc", means,
                                    ema_count, ema_means)
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
          cache = ae_latent_sample(latents_dense, inputs, ed, embed, 8, hparams)
        latents_dense = embed(cache)
    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Masking.
    if hparams.do_mask:
      masking = common_layers.inverse_lin_decay(100000)
      masking *= common_layers.inverse_exp_decay(25000)  # Not much at start.
      if not hparams.do_refine:
        masking -= tf.random_uniform([]) * 0.3
      masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = predict_mask
      mask = tf.less(masking, tf.random_uniform(
          common_layers.shape_list(targets)[:-1]))
      mask = tf.expand_dims(tf.to_float(mask), 3)
      for i in xrange(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, False, "decompress_%d" % j)
      targets = mask * targets + (1.0 - mask) * d
    targets = tf.concat([tf.reverse(latents_dense, [1]), targets], axis=1)

  res = decode_transformer(inputs, ed, targets, hparams, "decoder")
  if hparams.do_ae:
    res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
    if hparams.do_mask and hparams.do_refine:
      def refine_res():
        return residual_conv(res, 1, (5, 1), hparams, "refine")
      masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
      all_masked = tf.less(masked_batches, 0.1)
      res = tf.where(all_masked, refine_res(), res)
    # After 300K steps we train only the extra model of latents; from then on,
    # gradients to the other weights are scaled by (1 - decreased_lr), which
    # decays to zero by 400K steps.
    latent_time = tf.less(300000, tf.to_int32(tf.train.get_global_step()))
    decreased_lr = common_layers.inverse_lin_decay(400000)
    losses["latent_pred"] *= tf.to_float(latent_time)
    losses["extra"] *= 1.0 - tf.to_float(latent_time)
    decreased_lr_res = tf.stop_gradient(decreased_lr * res)
    decreased_lr_res += (1.0 - decreased_lr) * res
    res = tf.cond(latent_time, lambda: decreased_lr_res, lambda: res)
  return res, losses, cache
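
The masking schedule above multiplies two ramps, common_layers.inverse_lin_decay(100000) and inverse_exp_decay(25000), so almost nothing is masked early in training and the mask probability approaches 1.0 later. The NumPy sketch below re-implements both ramps under the assumption that they rise from a small floor (0.01) to 1.0 at max_step; it is an illustration of the schedule, not the library code.

import numpy as np

def inverse_lin_decay(step, max_step, min_value=0.01):
  # Rises linearly from min_value to 1.0, reached at max_step (assumed semantics).
  progress = min(step / float(max_step), 1.0)
  return progress * (1.0 - min_value) + min_value

def inverse_exp_decay(step, max_step, min_value=0.01):
  # Rises exponentially from min_value to 1.0, reached at max_step (assumed semantics).
  inv_base = np.exp(np.log(min_value) / float(max_step))
  return inv_base ** max(float(max_step) - step, 0.0)

for step in [0, 25000, 50000, 100000, 200000]:
  masking = inverse_lin_decay(step, 100000) * inverse_exp_decay(step, 25000)
  print(step, round(min(max(masking, 0.0), 1.0), 4))
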
示例#34
0
    def body(self, features):
        """Seq2Edits main model_fn.

    Args:
      features: Feature dictionary. Should contain the following fields:
          "inputs": [batch_size, input_length, 1, hidden_dim] float tensor with
            input token embeddings.
          "targets": [batch_size, target_length, 1, hidden_dim] float tensor
            with target token embeddings.
          "targets_error_tag": [batch_size, target_length, 1, hidden_dim] float
            tensor with target error tag embeddings.
          "target_space_id": A scalar int from data_generators.problem.SpaceID.

    Returns:
      Final decoder representation. Dictionary containing the following fields:
        "targets": [batch_size, target_length, hidden_dim] float tensor with
          decoder outputs
        "targets_error_tag": [batch_size, target_length, hidden_dim] float
          tensor with decoder outputs
    """
        hparams = self._hparams

        losses = []

        if self.has_input:
            target_space = features['target_space_id']
            encoder_output, encoder_decoder_attention_bias = self.encode(
                features['inputs'],
                target_space,
                hparams,
                features=features,
                losses=losses,
            )
        else:
            encoder_output, encoder_decoder_attention_bias = (None, None)

        targets = features['targets']
        targets_shape = common_layers.shape_list(targets)
        targets = common_layers.flatten4d3d(targets)
        decoder_input, decoder_self_attention_bias = self._prepare_decoder_fn(
            targets, hparams, features=features)

        nonpadding = features_to_nonpadding(features, 'targets')

        # Add edit ops layer to condition on start_token, end_token, and error_tag
        decoder_input = transformer_edit_ops_layer(
            decoder_input,
            hparams,
            encoder_output,
            features,
            nonpadding=nonpadding,
            losses=losses,
        )
        if hparams.middle_prediction:
            num_decoder_layers = (hparams.num_decoder_layers
                                  or hparams.num_hidden_layers)
            hparams.num_decoder_layers = int(
                num_decoder_layers / hparams.middle_prediction_layer_factor)

        decode_kwargs = {}
        decoder_output = self.decode(decoder_input,
                                     encoder_output,
                                     encoder_decoder_attention_bias,
                                     decoder_self_attention_bias,
                                     hparams,
                                     nonpadding=nonpadding,
                                     losses=losses,
                                     **decode_kwargs)

        loss_mask = common_layers.weights_nonzero(
            maybe_flatten4d2d(features['targets_raw']))
        self.loss_den = tf.reduce_sum(loss_mask)
        decoder_output = self._prediction_cascade(
            hparams=hparams,
            features=features,
            losses=losses,
            loss_mask=loss_mask,
            nonpadding=nonpadding,
            encoder_decoder_attention_bias=encoder_decoder_attention_bias,
            encoder_output=encoder_output,
            decoder_output=decoder_output,
        )

        if hparams.middle_prediction:
            with tf.variable_scope('after_prediction'):
                decoder_output = self.decode(decoder_input + decoder_output,
                                             encoder_output,
                                             encoder_decoder_attention_bias,
                                             decoder_self_attention_bias,
                                             hparams,
                                             nonpadding=nonpadding,
                                             losses=losses,
                                             **decode_kwargs)

        ret = {'targets': tf.reshape(decoder_output, targets_shape)}
        ret.update(self.logits)
        if losses:
            return ret, {'extra_loss': tf.add_n(losses)}
        else:
            return ret
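
The loss_mask above zeroes out padding positions: common_layers.weights_nonzero returns 1.0 wherever the raw target id is non-zero, and loss_den counts the real tokens used to normalize the losses. A minimal NumPy sketch of that masking, assuming id 0 is the padding symbol:

import numpy as np

def weights_nonzero(labels):
  # 1.0 where the token id is non-zero (real token), 0.0 at padding (assumed semantics).
  return (labels != 0).astype(np.float32)

targets_raw = np.array([[5, 9, 2, 0, 0],
                        [7, 1, 0, 0, 0]])
loss_mask = weights_nonzero(targets_raw)
loss_den = loss_mask.sum()            # 5 real tokens in this toy batch
per_token_loss = np.full(targets_raw.shape, 0.4, dtype=np.float32)
masked_loss = (per_token_loss * loss_mask).sum() / loss_den
print(loss_den, masked_loss)          # 5.0 0.4
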
示例#35
0
  def render2cmd_v3_internal(self, features, hparams, train):
    # inputs and targets are both sequences with
    # shape = [batch, seq_len, 1, hparams.problem.feature_dim]
    targets = features['targets']
    losses = {}

    sampled_bottleneck = self.pretrained_visual_encoder(features, hparams)
    if hparams.sg_bottleneck:
      sampled_bottleneck = tf.stop_gradient(sampled_bottleneck)

    with tf.variable_scope('render2cmd_v3_internal'):
      # override bottleneck, or return it, if requested
      if 'bottleneck' in features:
        if common_layers.shape_list(features['bottleneck'])[0] == 0:
          # return sampled_bottleneck,
          # set losses['training'] = 0 so self.top() doesn't get called on it
          return sampled_bottleneck, {'training': 0.0}
        else:
          # we want to use the given bottleneck
          sampled_bottleneck = features['bottleneck']

      # finalize bottleneck
      unbottleneck_dim = hparams.hidden_size * 2  # twice because using LSTM
      if hparams.twice_decoder:
        unbottleneck_dim = unbottleneck_dim * 2

      # unbottleneck back to LSTMStateTuple
      dec_initial_state = []
      for hi in range(hparams.num_hidden_layers):
        unbottleneck = self.unbottleneck(sampled_bottleneck, unbottleneck_dim,
                                         name_append='_{}'.format(hi))
        dec_initial_state.append(
            tf.nn.rnn_cell.LSTMStateTuple(
                c=unbottleneck[:, :unbottleneck_dim // 2],
                h=unbottleneck[:, unbottleneck_dim // 2:]))

      dec_initial_state = tuple(dec_initial_state)

      shifted_targets = common_layers.shift_right(targets)
      # Add 1 to account for the padding added to the left from shift_right
      targets_length = common_layers.length_from_embedding(shifted_targets) + 1

      # LSTM decoder
      hparams_decoder = copy.copy(hparams)
      if hparams.twice_decoder:
        hparams_decoder.hidden_size = 2 * hparams.hidden_size

      if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        decoder_outputs, _ = self.lstm_decoder_infer(
            common_layers.flatten4d3d(shifted_targets),
            targets_length, hparams_decoder, features['targets_cls'],
            train, initial_state=dec_initial_state,
            bottleneck=sampled_bottleneck)
      else:
        decoder_outputs, _ = self.lstm_decoder(
            common_layers.flatten4d3d(shifted_targets),
            targets_length, hparams_decoder, features['targets_cls'],
            train, initial_state=dec_initial_state,
            bottleneck=sampled_bottleneck)

      ret = tf.expand_dims(decoder_outputs, axis=2)

    return ret, losses
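
The loop above splits each unbottlenecked vector in half to build an LSTMStateTuple per decoder layer: the first unbottleneck_dim // 2 units become the cell state c and the rest become the hidden state h. A shape-only sketch of that split, with a namedtuple standing in for tf.nn.rnn_cell.LSTMStateTuple and hypothetical sizes:

import numpy as np
from collections import namedtuple

LSTMStateTuple = namedtuple('LSTMStateTuple', ['c', 'h'])  # stand-in for tf.nn.rnn_cell.LSTMStateTuple

batch_size, hidden_size, num_hidden_layers = 4, 8, 2
unbottleneck_dim = hidden_size * 2          # twice because c and h are packed together

dec_initial_state = []
for hi in range(num_hidden_layers):
  unbottleneck = np.random.randn(batch_size, unbottleneck_dim)
  dec_initial_state.append(
      LSTMStateTuple(c=unbottleneck[:, :unbottleneck_dim // 2],
                     h=unbottleneck[:, unbottleneck_dim // 2:]))
dec_initial_state = tuple(dec_initial_state)
print(dec_initial_state[0].c.shape, dec_initial_state[0].h.shape)  # (4, 8) (4, 8)
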
示例#36
0
def adv_transformer_internal(inputs, targets, target_space, hparams):
    """Adversarial Transformer, main step used for training."""
    with tf.variable_scope("adv_transformer"):
        batch_size = tf.shape(targets)[0]
        targets = tf.reshape(targets, [batch_size, -1, 1])
        intermediate = tf.constant(34 * 1024 - 1)
        intermediate += tf.zeros_like(targets)
        targets = tf.concat([targets, intermediate], axis=2)
        targets = tf.reshape(targets, [batch_size, -1, 1])
        embedding = tf.get_variable("embedding",
                                    [34 * 1024, hparams.hidden_size])
        targets_emb = tf.gather(embedding, targets)

        # Noisy embedded targets.
        targets_noisy = tf.one_hot(targets, 34 * 1024)
        noise_val = hparams.noise_val
        targets_noisy += tf.random_uniform(tf.shape(targets_noisy),
                                           minval=-noise_val,
                                           maxval=noise_val)
        targets_emb_noisy = softmax_embed(targets_noisy, embedding, batch_size,
                                          hparams)

        # Encoder.
        if inputs is not None:
            inputs_emb = common_layers.flatten4d3d(inputs)
            inputs, ed = encode(inputs_emb, target_space, hparams, "input_enc")
        else:
            ed = None

        # Masking.
        masking = common_layers.inverse_lin_decay(200000)
        masking *= common_layers.inverse_exp_decay(50000)  # Not much at start.
        masking -= tf.random_uniform([]) * 0.4
        masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
        mask = tf.less(masking, tf.random_uniform(tf.shape(targets)))
        mask = tf.expand_dims(tf.to_float(mask), 3)
        noise = tf.random_uniform(tf.shape(targets_emb))
        targets_emb = mask * targets_emb + (1.0 - mask) * noise

        # Decoder.
        res_dec = decode(inputs, ed, targets_emb, hparams, "decoder")
        res = tf.layers.dense(res_dec, 34 * 1024, name="res_sm")
        res_emb = softmax_embed(res, embedding, batch_size, hparams)

        # Extra steps.
        extra_step_prob = masking * 0.6 + 0.3
        if hparams.mode != tf.estimator.ModeKeys.TRAIN:
            extra_step_prob = 1.0
        for _ in xrange(hparams.extra_steps):

            def another_step(emb):
                res_dec = decode(inputs,
                                 ed,
                                 emb,
                                 hparams,
                                 "decoder",
                                 reuse=True)
                res = tf.layers.dense(res_dec,
                                      34 * 1024,
                                      name="res_sm",
                                      reuse=True)
                return softmax_embed(res, embedding, batch_size, hparams), res

            res_emb, res = tf.cond(tf.less(tf.random_uniform([]),
                                           extra_step_prob),
                                   lambda e=res_emb: another_step(e),
                                   lambda: (res_emb, res))

        # Adversary.
        delta = masking * hparams.delta_max
        true_logit = adversary(tf.stop_gradient(targets_emb_noisy),
                               tf.stop_gradient(inputs + inputs_emb), hparams,
                               "adversary")
        gen_logit = adversary(reverse_gradient(res_emb, delta),
                              tf.stop_gradient(inputs + inputs_emb),
                              hparams,
                              "adversary",
                              reuse=True)
        losses = {"adv": gen_logit - true_logit}
        res = tf.stop_gradient(masking * res) + (1.0 - masking) * res
        return res, losses
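
The adversary sees the generated embeddings through reverse_gradient(res_emb, delta), which in layers of this kind leaves the forward value unchanged while negating and scaling the gradient, pushing the generator away from what the adversary classifies easily. A hedged TF1-style sketch of one common way to write such a layer with tf.stop_gradient (an assumption about what reverse_gradient does here, not its verified source):

import tensorflow as tf  # TF1-style graph code, matching the snippets above

def reverse_gradient_sketch(x, scale=1.0):
  # Forward pass: returns x unchanged, since (1 + scale) * x - scale * x == x.
  # Backward pass: the stop_gradient term contributes no gradient, so only the
  # -scale * x term remains, i.e. the gradient is reversed and scaled by scale.
  return tf.stop_gradient((1.0 + scale) * x) - scale * x

x = tf.constant([[1.0, 2.0]])
y = reverse_gradient_sketch(x, scale=0.5)
grad = tf.gradients(tf.reduce_sum(y), x)[0]
with tf.Session() as sess:
  print(sess.run([y, grad]))  # y equals x; the gradient is -0.5 everywhere
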
示例#37
0
    def render2cmd_v3_internal(self, features, hparams, train):
        # inputs and targets are both sequences with
        # shape = [batch, seq_len, 1, hparams.problem.feature_dim]
        print('render2cmd_v3_internal')
        all_targets = features['targets']
        all_targets_cls = features['targets_cls']
        all_targets_font_cls = features['targets_fnt']
        all_targets_psr = features['targets_psr']
        all_batch_size = common_layers.shape_list(all_targets)[0]
        batch_size = all_batch_size // 2
        sources = all_targets[:batch_size, ...]
        sources_cls = all_targets_cls[:batch_size, ...]
        sources_fnt = all_targets_font_cls[:batch_size, ...]
        sources_psr = all_targets_psr[:batch_size, ...]
        targets = all_targets[batch_size:, ...]
        targets_cls = all_targets_cls[batch_size:, ...]
        targets_fnt = all_targets_font_cls[batch_size:, ...]
        targets_psr = all_targets_psr[batch_size:, ...]

        losses = {}
        # sampled_bottleneck = self.pretrained_visual_encoder(features, hparams)

        # if hparams.sg_bottleneck:
        #     sampled_bottleneck = tf.stop_gradient(sampled_bottleneck)
        # embd = self.cls_embedding(sources_cls, sources_fnt, targets_cls, targets_fnt)
        vis_embd = self.vis_encoder(sources_psr, targets_psr, targets_cls)
        # print("embd embd embd embd embd embd embd ", embd.shape)
        print('vis_embd shape:', vis_embd.shape)
        sampled_bottleneck = vis_embd

        with tf.variable_scope('render2cmd_v3_internal'):
            # override bottleneck, or return it, if requested
            # if 'bottleneck' in features:
            #     if common_layers.shape_list(features['bottleneck'])[0] == 0:
            #         # return sampled_bottleneck,
            #         # set losses['training'] = 0 so self.top() doesn't get called on it
            #         print("RETURNRETURNRETURNRETURNRETURNRETURNRETURNRETURNRETURNRETURNRETURN")
            #         return sampled_bottleneck, {'training': 0.0}
            #     else:
            #         # we want to use the given bottleneck
            #         sampled_bottleneck = features['bottleneck']

            # finalize bottleneck
            unbottleneck_dim = hparams.hidden_size * 2  # twice because using LSTM
            if hparams.twice_decoder:
                unbottleneck_dim = unbottleneck_dim * 2

            dec_initial_state = []

            # LSTM encoder
            _, encoder_output_states = self.lstm_encoder(
                common_layers.flatten4d3d(sources), hparams)

            print('targets shape:', targets.shape)
            print('run stacking...')
            print('sampled_bottleneck shape:', sampled_bottleneck.shape)
            print('sources shape:', sources.shape)
            # input()
            for hi in range(hparams.num_hidden_layers):
                unbottleneck = self.unbottleneck(sampled_bottleneck,
                                                 unbottleneck_dim,
                                                 name_append='_{}'.format(hi))
                c, h = encoder_output_states[hi]
                # print(unbottleneck.shape)
                # print(c.shape, h.shape)
                # first_dim = common_layers.shape_list(unbottleneck)[0]
                # print(first_dim)
                # c = tf.tile(c,[first_dim,1])
                # h = tf.tile(h,[first_dim,1])
                # input()
                dec_initial_state.append(
                    tf.nn.rnn_cell.LSTMStateTuple(
                        c=tf.concat(
                            [unbottleneck[:, :unbottleneck_dim // 2], c], 1),
                        h=tf.concat(
                            [unbottleneck[:, unbottleneck_dim // 2:], h], 1)))

            dec_initial_state = tuple(dec_initial_state)
            # print('checkshape dec_initial_state')
            # print(dec_initial_state)
            # input()
            shifted_targets = common_layers.shift_right(targets)
            # Add 1 to account for the padding added to the left from shift_right
            targets_length = common_layers.length_from_embedding(
                shifted_targets) + 1

            # LSTM decoder
            hparams_decoder = copy.copy(hparams)
            if hparams.twice_decoder:
                hparams_decoder.hidden_size = 2 * hparams.hidden_size

            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                decoder_outputs, _ = self.lstm_decoder_infer(
                    common_layers.flatten4d3d(shifted_targets),
                    targets_length,
                    hparams_decoder,
                    targets_cls,
                    train,
                    initial_state=dec_initial_state,
                    bottleneck=sampled_bottleneck)
            else:
                decoder_outputs, _ = self.lstm_decoder(
                    common_layers.flatten4d3d(shifted_targets),
                    targets_length,
                    hparams_decoder,
                    targets_cls,
                    train,
                    initial_state=dec_initial_state,
                    bottleneck=sampled_bottleneck)

            ret = tf.expand_dims(decoder_outputs, axis=2)
        return ret, losses
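
Unlike the earlier render2cmd_v3_internal, this variant concatenates half of the unbottlenecked vector with the encoder's LSTM state for each layer, so the decoder's initial state is twice as wide as hidden_size. A NumPy shape sketch with hypothetical sizes, just to make the concatenation concrete:

import numpy as np

batch_size, hidden_size = 4, 8
unbottleneck_dim = hidden_size * 2
unbottleneck = np.random.randn(batch_size, unbottleneck_dim)
c = np.random.randn(batch_size, hidden_size)   # encoder cell state for this layer
h = np.random.randn(batch_size, hidden_size)   # encoder hidden state for this layer

c_init = np.concatenate([unbottleneck[:, :unbottleneck_dim // 2], c], axis=1)
h_init = np.concatenate([unbottleneck[:, unbottleneck_dim // 2:], h], axis=1)
print(c_init.shape, h_init.shape)  # (4, 16) (4, 16): decoder state is twice hidden_size
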
示例#38
0
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Define losses
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

    # Reshape image targets as 4d tensor.
    original_targets_shape = common_layers.shape_list(targets)
    if len(original_targets_shape) == 4:
        compress_fn = compress_encoder_2d
        decompress_fn = decompress_decoder_2d
    else:
        compress_fn = compress_encoder_1d
        decompress_fn = decompress_decoder_1d

    # Encoder decoder attention bias.
    ed_attention_bias = None

    # Input Encoder if present.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed_attention_bias = transformer_text_encoder(
            inputs, target_space, hparams, "input_enc")

    # Encode targets to compute targets compressed.
    targets_c = compress_fn(targets, hparams, "compress")
    targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

    # Following code creates an exponentially decaying variable based on which
    # we rescale the loss values.
    batch_size = common_layers.shape_list(targets_c)[0]
    pc = common_layers.inverse_exp_decay(hparams.startup_steps)
    pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
    cond = tf.less(tf.random_uniform([batch_size]), pc)

    # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
    # Call bottleneck layer to get the latents.
    # Returns embedded latents, discrete latents, loss and the embedding function.
    latents_dense, latents_discrete, extra_loss, embed = (bottleneck_layer(
        targets_c, hparams))
    extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

    # Call the autoregressive latent prediction model.
    _, latents_pred_loss = latent_prediction_model(targets_c,
                                                   ed_attention_bias,
                                                   latents_discrete,
                                                   embed,
                                                   hparams,
                                                   name="latent_pred")
    latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(cond)

    # Assign latent loss
    losses["latent_pred"] = latents_pred_loss
    losses["extra_loss"] = extra_loss

    latents_decoder = latents_dense
    if len(original_targets_shape) == 4:
        cmp_img_len = hparams.img_len // (2**(hparams.num_compress_steps // 2))
        latents_decoder = tf.reshape(latents_decoder, [
            batch_size, cmp_img_len, cmp_img_len,
            hparams.num_latents * hparams.hidden_size
        ])

    # Decompress either using 1D or 2D upconvs.
    latents_decoder = decompress_fn(latents_decoder,
                                    hparams,
                                    name="decompress")
    # if we're operating in 2d space on images, then we're assuming that the
    # last dimension will not be a multiple of channels
    latents_decoder = tf.reshape(
        latents_decoder,
        shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

    if hparams.use_gold_targets:
        latents_decoder, _, _ = cia.maybe_reshape_4d_to_3d(latents_decoder)
        masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            masking = predict_mask
        mask = tf.less(
            masking, tf.random_uniform(common_layers.shape_list(targets)[:-1]))
        mask = tf.expand_dims(tf.to_float(mask), 2)
        targets = mask * targets + (1.0 - mask) * latents_decoder
    else:
        targets = latents_decoder
    # reshape back to 4d here
    targets = tf.reshape(targets, original_targets_shape)
    if hparams.decode_autoregressive:
        # Transformer decoder, that goes from inputs->targets
        res = transformer_image_decoder(targets, inputs, ed_attention_bias,
                                        hparams, "decoder")
    else:
        res = targets

    # We'll start training the extra model of latents after mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
    return res, losses, cache
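
The 2D branch above reshapes the latents to a compressed image grid, so the compressed side length must stay an integer (hence // rather than /). A quick worked example with hypothetical hyperparameter values, only to make the reshape concrete:

# Hypothetical values; not taken from any particular hparams set.
img_len, num_compress_steps, num_latents, hidden_size = 32, 4, 1, 512
cmp_img_len = img_len // (2 ** (num_compress_steps // 2))
print(cmp_img_len)                                             # 8
print([cmp_img_len, cmp_img_len, num_latents * hidden_size])   # [8, 8, 512]
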
示例#39
0
    def body(self, features):
        hp = self.hparams
        # pylint: disable=eval-used
        if hp.image_input_type == "image":
            image_feat = vqa_layers.image_embedding(
                features["inputs"],
                model_fn=eval(hp.image_model_fn),
                trainable=hp.train_resnet,
                is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
        else:
            image_feat = features["inputs"]

        image_feat = common_layers.flatten4d3d(image_feat)
        image_hidden_size = hp.image_hidden_size or hp.hidden_size
        if hp.image_feat_preprocess_proj:
            image_feat = common_layers.dense(image_feat, image_hidden_size)
            utils.collect_named_outputs("norms", "image_feat_after_proj",
                                        tf.norm(image_feat, axis=-1))
        else:
            assert image_hidden_size == 2048

        image_feat = tf.nn.dropout(image_feat,
                                   keep_prob=1. -
                                   hp.layer_prepostprocess_dropout)

        if hp.image_feat_encode:
            image_feat = image_encoder(image_feat, hp)
            utils.collect_named_outputs("norms", "image_feat_encoded",
                                        tf.norm(image_feat, axis=-1))
        else:
            image_feat = common_layers.layer_norm(image_feat)
            utils.collect_named_outputs("norms", "image_feat_after_layer",
                                        tf.norm(image_feat, axis=-1))

        question = common_layers.flatten4d3d(features["question"])
        utils.collect_named_outputs("norms", "question_embedding",
                                    tf.norm(question, axis=-1))
        question, question_self_attention_bias = prepare_question_encoder(
            question, hp)
        question = tf.nn.dropout(question,
                                 keep_prob=1. -
                                 hp.layer_prepostprocess_dropout)
        query = question_encoder(question, question_self_attention_bias, hp)
        utils.collect_named_outputs("norms", "query_encode",
                                    tf.norm(query, axis=-1))
        query = (query + tf.expand_dims(
            tf.squeeze(question_self_attention_bias, [1, 2]), axis=2))
        query = tf.reduce_max(query, axis=1)
        utils.collect_named_outputs("norms", "query_maxpool",
                                    tf.norm(query, axis=-1))

        # query = common_layers.l2_norm(query)
        # utils.collect_named_outputs("norms", "query_after_l2",
        #                             tf.norm(query, axis=-1))

        image_ave = attn(image_feat, query, hp)
        utils.collect_named_outputs("norms", "image_ave",
                                    tf.norm(image_ave, axis=-1))

        if hp.multimodal_combine == "concat":
            image_question = tf.concat([image_ave, query], axis=1)
        elif hp.multimodal_combine == "sum":
            image_question = image_ave + query
        elif hp.multimodal_combine == "product":
            image_question = image_ave * query

        utils.collect_named_outputs("norms", "image_question",
                                    tf.norm(image_question, axis=-1))

        image_question = tf.nn.dropout(image_question, 1. - hp.dropout)

        output = mlp(image_question, hp)
        utils.collect_named_outputs("norms", "output", tf.norm(output,
                                                               axis=-1))

        norm_tensors = utils.convert_collection_to_dict("norms")
        vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

        # Expand dimension 1 and 2
        return tf.expand_dims(tf.expand_dims(output, axis=1), axis=2)
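
Adding the (squeezed) self-attention bias before tf.reduce_max is what keeps padded question positions out of the pooled query: the bias is a large negative number at padding, so those positions can never win the max. A small NumPy sketch of that masked max-pooling, assuming the usual -1e9 padding bias:

import numpy as np

query = np.array([[[0.2, 0.9],    # token 1
                   [0.8, 0.1],    # token 2
                   [5.0, 5.0]]])  # padding position with deceptively large values
bias = np.array([[0.0, 0.0, -1e9]])          # -1e9 marks padding (assumed convention)
masked = query + bias[:, :, None]            # broadcast bias over the feature axis
pooled = masked.max(axis=1)
print(pooled)                                # padding never wins: [[0.8 0.9]]
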
示例#40
0
def baseline_perf_transformer_encode(encoder_function,
                                     inputs,
                                     target_space,
                                     hparams,
                                     attention_weights=None,
                                     features=None,
                                     losses=None,
                                     prepare_encoder_fn=None,
                                     **kwargs):
    """Encoding for baseline performance transformer, no mean-aggregation.

  Args:
    encoder_function: the encoder function
    inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim] which
      will be flattened along the two spatial dimensions.
    target_space: scalar, target space ID.
    hparams: hyperparameters for model.
    attention_weights: weight to store attention to.
    features: optionally pass the entire features dictionary as well. This is
      needed now for "packed" datasets.
    losses: optional list onto which to append extra training losses
    prepare_encoder_fn: optional, alternative to transformer_prepare_encoder.
    **kwargs: additional arguments to pass to encoder_function

  Returns:
    Tuple of:
        encoder_output: Encoder representation.
            [batch_size, input_length, hidden_dim]
        encoder_decoder_attention_bias: Bias and mask weights for
            encoder-decoder attention. [batch_size, input_length]
  """
    inputs = common_layers.flatten4d3d(inputs)

    if not prepare_encoder_fn:
        prepare_encoder_fn = transformer_prepare_encoder
    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        prepare_encoder_fn(inputs,
                           target_space,
                           hparams,
                           features=features,
                           reuse_target_embedding=tf.AUTO_REUSE))

    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
        value=hparams.layer_prepostprocess_dropout,
        hparams=hparams)

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    attn_bias_for_padding = None
    # Otherwise the encoder will just use encoder_self_attention_bias.
    if hparams.unidirectional_encoder:
        attn_bias_for_padding = encoder_decoder_attention_bias

    encoder_output = encoder_function(
        encoder_input,
        self_attention_bias,
        hparams,
        name="encoder",
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=attention_weights,
        make_image_summary=not common_layers.is_xla_compiled(),
        losses=losses,
        attn_bias_for_padding=attn_bias_for_padding,
        **kwargs)

    # no aggregation --> just return everything normally
    return encoder_output, encoder_decoder_attention_bias
示例#41
0
def transformer_autoencoder(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """Auto-encoder using transformer decoder and prior over latents."""
    # Define losses
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

    # Reshape image targets as 4d tensor.
    original_targets_shape = common_layers.shape_list(targets)
    batch_size = original_targets_shape[0]
    if len(original_targets_shape) == 4:
        compress_fn = compress_encoder_2d
        decompress_fn = decompress_decoder_2d
    else:
        compress_fn = compress_encoder_1d
        decompress_fn = decompress_decoder_1d

    # Encoder decoder attention bias.
    ed_attention_bias = None

    # Input Encoder if present.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed_attention_bias = transformer_text_encoder(
            inputs, target_space, hparams, "input_enc")

    # Encode targets to compute targets compressed.
    targets_c = compress_fn(targets, hparams, "compress")
    targets, _, _ = cia.maybe_reshape_4d_to_3d(targets)

    # Following code creates an exponentially decaying variable based on which
    # we rescale the loss values.
    pc = common_layers.inverse_exp_decay(hparams.startup_steps)
    pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
    cond = tf.less(tf.random_uniform([batch_size]), pc)

    # Call bottleneck layer, that takes encoder output and outputs the latents.
    # Returns embedded latents, discrete latent codes, loss.
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        latents_dense, latents_discrete, extra_loss = (bottleneck_layer(
            targets_c, hparams))
        extra_loss = tf.reduce_mean(extra_loss) * tf.to_float(cond)

        # Call the autoregressive latent prediction model.
        _, latents_pred_loss = latent_prediction_model(inputs,
                                                       ed_attention_bias,
                                                       latents_discrete,
                                                       latents_dense,
                                                       hparams,
                                                       name="latent_pred")
        latents_pred_loss = tf.reduce_mean(latents_pred_loss) * tf.to_float(
            cond)
        # Latent dropout.
        latents_shape = common_layers.shape_list(latents_dense)
        latents_dense = tf.nn.dropout(
            latents_dense,
            1 - hparams.latent_dropout,
            noise_shape=[latents_shape[0], latents_shape[1], 1])
        # Assign latent loss.
        losses["latent_pred"] = latents_pred_loss
        losses["extra_loss"] = extra_loss
    else:
        latent_len = (hparams.img_len * hparams.img_len *
                      hparams.num_latents) // 2**(hparams.num_compress_steps)
        embed = functools.partial(discretization.parametrized_unbottleneck,
                                  hparams=hparams)
        latents_dense = tf.zeros(
            [batch_size, latent_len, 1, hparams.hidden_size])
        if cache is None:
            cache = ae_latent_sample_beam(latents_dense, inputs,
                                          ed_attention_bias, embed, hparams)
        latents_dense = embed(
            tf.one_hot(cache, depth=2**hparams.bottleneck_bits),
            hparams.hidden_size)

    latents_decoder = latents_dense
    if len(original_targets_shape) == 4:
        cmp_img_len = hparams.img_len // (2**(hparams.num_compress_steps // 2))
        latents_decoder = tf.reshape(latents_decoder, [
            batch_size, cmp_img_len, cmp_img_len,
            hparams.num_latents * hparams.hidden_size
        ])

    # Decompress either using 1D or 2D upconvs.
    latents_decoder = decompress_fn(latents_decoder,
                                    hparams,
                                    name="decompress")
    # if we're operating in 2d space on images, then we're assuming that the
    # last dimension will not be a multiple of channels
    output = tf.reshape(
        latents_decoder,
        shape=[-1, hparams.img_len, hparams.img_len, hparams.hidden_size])

    if hparams.use_gold_targets:
        latents_decoder, _, _ = cia.maybe_reshape_4d_to_3d(latents_decoder)
        masking = common_layers.inverse_exp_decay(hparams.mask_startup_steps)
        if hparams.mode == tf.estimator.ModeKeys.PREDICT:
            masking = predict_mask
        mask = tf.less(
            masking, tf.random_uniform(common_layers.shape_list(targets)[:-1]))
        mask = tf.expand_dims(tf.to_float(mask), 2)
        output = mask * targets + (1.0 - mask) * output

    # reshape back to 4d here
    output = tf.reshape(output, original_targets_shape)
    if hparams.decode_autoregressive:
        # Transformer decoder, that goes from inputs->targets
        decoder_output = transformer_image_decoder(output, inputs,
                                                   ed_attention_bias, hparams,
                                                   "decoder")
    else:
        decoder_output = output

    # We'll start training the extra model of latents after mask_startup_steps.
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
    return decoder_output, losses, cache
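
The latent dropout above passes noise_shape=[batch, length, 1] so an entire latent vector is kept or dropped as a unit rather than unit-by-unit. A NumPy emulation of that broadcasted dropout mask (a sketch of the noise_shape behavior, not tf.nn.dropout itself):

import numpy as np

np.random.seed(0)
batch, length, hidden, keep_prob = 2, 4, 3, 0.5
latents = np.ones((batch, length, hidden))

# One keep/drop decision per (batch, position), broadcast across the hidden axis,
# mirroring noise_shape=[latents_shape[0], latents_shape[1], 1].
mask = (np.random.uniform(size=(batch, length, 1)) < keep_prob).astype(np.float32)
dropped = latents * mask / keep_prob   # inverted dropout scaling, as in tf.nn.dropout
print(dropped[:, :, 0])                # whole latent vectors are either 0 or 2
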
示例#42
0
    def body(self, features):
        """Transformer main model_fn.

        Args:
          features: Map of features to the model. Should contain the following:
              "inputs": Transformer inputs. [batch_size, input_length, 1,
                hidden_dim].
              "targets": Target decoder outputs. [batch_size, decoder_length, 1,
                hidden_dim]
              "target_space_id": A scalar int from data_generators.problem.SpaceID.

        Returns:
          Final decoder representation. [batch_size, decoder_length, hidden_dim]
        """
        hparams = self._hparams

        losses = []

        if self.has_input:
            # use melody-only as input features
            inputs = features["melody"]
            target_space = features["target_space_id"]
            encoder_output, encoder_decoder_attention_bias = self.encode(
                inputs,
                target_space,
                hparams,
                features=features,
                losses=losses)
        else:
            encoder_output, encoder_decoder_attention_bias = (None, None)

        targets = features["targets"]
        targets_shape = common_layers.shape_list(targets)
        targets = common_layers.flatten4d3d(targets)
        decoder_input, decoder_self_attention_bias = self._prepare_decoder_fn(
            targets, hparams, features=features)

        # Not all subclasses of Transformer support keyword arguments related to
        # recurrent memory, so only pass these arguments if memory is enabled.
        decode_kwargs = {}
        if self.recurrent_memory_by_layer is not None:
            # TODO(kitaev): The chunk_number feature currently has the same shape as
            # "targets", but this is only for the purposes of sharing sharding code.
            # In fact every token within an example must have the same chunk number.
            chunk_number_each_token = tf.squeeze(features["chunk_number"],
                                                 (-1, -2))
            chunk_number_each_example = chunk_number_each_token[:, 0]
            # Uncomment the code below to verify that tokens within a batch share the
            # same chunk number:
            # with tf.control_dependencies([
            #     tf.assert_equal(chunk_number_each_token,
            #                     chunk_number_each_example[:, None])
            # ]):
            #   chunk_number_each_example = tf.identity(chunk_number_each_example)
            decode_kwargs = dict(
                recurrent_memory_by_layer=self.recurrent_memory_by_layer,
                chunk_number=chunk_number_each_example,
            )
        decoder_output = self.decode(decoder_input,
                                     encoder_output,
                                     encoder_decoder_attention_bias,
                                     decoder_self_attention_bias,
                                     hparams,
                                     nonpadding=features_to_nonpadding(
                                         features, "targets"),
                                     losses=losses,
                                     **decode_kwargs)
        expected_attentions = features.get("expected_attentions")
        if expected_attentions is not None:
            attention_loss = common_attention.encoder_decoder_attention_loss(
                expected_attentions, self.attention_weights,
                hparams.expected_attention_loss_type,
                hparams.expected_attention_loss_multiplier)
            return decoder_output, {"attention_loss": attention_loss}

        ret = tf.reshape(decoder_output, targets_shape)
        if losses:
            return ret, {"extra_loss": tf.add_n(losses)}
        else:
            return ret
示例#43
0
def main():

  FLAGS = Args()

  # Enable TF Eager execution
  tfe = tf.contrib.eager
  tfe.enable_eager_execution()

  # sample sentence
  input_str = 'Twas brillig, and the slithy toves Did gyre and gimble in the wade; All mimsy were the borogoves, And the mome raths outgrabe.'

  # convert sentence into index in vocab
  wmt_problem = problems.problem(FLAGS.problem)
  encoders = wmt_problem.feature_encoders(FLAGS.data_dir)
  inputs = encoders["inputs"].encode(input_str) + [1]  # add EOS id
  batch_inputs = tf.reshape(inputs, [1, -1, 1])  # Make it 3D.
  features = {"inputs": batch_inputs}

  # initialize translation model
  hparams_set = FLAGS.hparams_set
  Modes = tf.estimator.ModeKeys
  hparams = trainer_lib.create_hparams(hparams_set, data_dir=FLAGS.data_dir, problem_name=FLAGS.problem)
  translate_model = registry.model(FLAGS.model)(hparams, Modes.EVAL)

  # recover parameters and conduct recurrent conduction
  ckpt_dir = tf.train.latest_checkpoint(FLAGS.model_dir)

  with tfe.restore_variables_on_create(ckpt_dir):
    with variable_scope.EagerVariableStore().as_default():
      with tf.variable_scope('universal_transformer'):
        # Convert word index to word embedding
        features = translate_model.bottom(features)

      with tf.variable_scope('universal_transformer/body'):
        input_tensor = tf.convert_to_tensor(features['inputs'])
        input_tensor = common_layers.flatten4d3d(input_tensor)
        encoder_input, self_attention_bias, _ = (
          transformer.transformer_prepare_encoder(
            input_tensor, tf.convert_to_tensor([0]), translate_model.hparams, features=None))

      with tf.variable_scope('universal_transformer/body/encoder'):

        ffn_unit = functools.partial(
          universal_transformer_util.transformer_encoder_ffn_unit,
          hparams=translate_model.hparams)

        attention_unit = functools.partial(
          universal_transformer_util.transformer_encoder_attention_unit,
          hparams=translate_model.hparams,
          encoder_self_attention_bias=None,
          attention_dropout_broadcast_dims=[],
          save_weights_to={},
          make_image_summary=True)

      storing_list = []
      transformed_state = encoder_input
      for step_index in range(1024):
        storing_list.append(transformed_state.numpy())

        with tf.variable_scope('universal_transformer/body/encoder/universal_transformer_{}'.format(FLAGS.ut_type)):
          transformed_state = universal_transformer_util.step_preprocess(
            transformed_state,
            tf.convert_to_tensor(step_index % FLAGS.step_num),
            translate_model.hparams
          )
        with tf.variable_scope('universal_transformer/body/encoder/universal_transformer_{}/rec_layer_0'.format(FLAGS.ut_type)):
          transformed_new_state = ffn_unit(attention_unit(transformed_state))
        with tf.variable_scope('universal_transformer/body/encoder'):
          if (step_index + 1) % FLAGS.step_num == 0:
            transformed_new_state = common_layers.layer_preprocess(transformed_new_state, translate_model.hparams)

            if step_index == 5:
              print(transformed_new_state)

        transformed_state = transformed_new_state
      storing_list = np.asarray(storing_list)
      np.save(FLAGS.save_dir, storing_list)
示例#44
0
    def body(self, features):
        hp = self.hparams
        # pylint: disable=eval-used
        if hp.image_input_type == "image":
            image_feat = vqa_layers.image_embedding(
                features["inputs"],
                model_fn=eval(hp.image_model_fn),
                trainable=hp.train_resnet,
                is_training=hp.mode == tf.estimator.ModeKeys.TRAIN)
        else:
            image_feat = features["inputs"]

        image_feat = common_layers.flatten4d3d(image_feat)
        image_feat = common_layers.dense(image_feat, hp.hidden_size)
        utils.collect_named_outputs("norms", "image_feat_after_proj",
                                    tf.norm(image_feat, axis=-1))

        question = common_layers.flatten4d3d(features["question"])
        utils.collect_named_outputs("norms", "question_embedding",
                                    tf.norm(question, axis=-1))
        (encoder_input, encoder_self_attention_bias,
         encoder_decoder_attention_bias) = prepare_image_question_encoder(
             image_feat, question, hp)

        encoder_input = tf.nn.dropout(encoder_input,
                                      keep_prob=1. -
                                      hp.layer_prepostprocess_dropout)

        encoder_output, _ = recurrent_transformer_decoder(
            encoder_input,
            None,
            encoder_self_attention_bias,
            None,
            hp,
            name="encoder")
        utils.collect_named_outputs("norms", "encoder_output",
                                    tf.norm(encoder_output, axis=-1))

        # scale query by sqrt(hidden_size)
        query = tf.get_variable("query",
                                [hp.hidden_size]) * hp.hidden_size**0.5
        query = tf.expand_dims(tf.expand_dims(query, axis=0), axis=0)
        batch_size = common_layers.shape_list(encoder_input)[0]
        query = tf.tile(query, [batch_size, 1, 1])
        query = tf.nn.dropout(query,
                              keep_prob=1. - hp.layer_prepostprocess_dropout)

        decoder_output, _ = recurrent_transformer_decoder(
            query,
            encoder_output,
            None,
            encoder_decoder_attention_bias,
            hp,
            name="decoder")
        utils.collect_named_outputs("norms", "decoder_output",
                                    tf.norm(decoder_output, axis=-1))

        norm_tensors = utils.convert_collection_to_dict("norms")
        vqa_layers.summarize_tensors(norm_tensors, tag="norms/")

        # Expand dimension 1 and 2
        return tf.expand_dims(decoder_output, axis=1)
示例#45
0
  def encode(self, inputs_context, inputs, target_space, hparams, features=None, losses=None):
    """Encode transformer inputs.

    Args:
      inputs_context: contextual input [batch_size, input_length, hidden_dim]
      inputs: Transformer inputs [batch_size, input_length, input_height,
        hidden_dim] which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for model.
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.
      losses: optional list onto which to append extra training losses

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_output_context: Contextual encoder representation
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
    """
    inputs = common_layers.flatten4d3d(inputs)

    encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
        transformer_prepare_encoder(
            inputs, target_space, hparams, features=features))

    encoder_input = tf.nn.dropout(encoder_input,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    encoder_output = transformer_encoder(
        encoder_input,
        self_attention_bias,
        hparams,
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=self.attention_weights,
        losses=losses)

    if inputs_context is None:
        return None, encoder_output, encoder_decoder_attention_bias

    inputs_context = common_layers.flatten4d3d(inputs_context)

    encoder_input_context, self_attention_bias_context, encoder_decoder_attention_bias_context = (
        transformer_prepare_encoder(
            inputs_context, target_space, hparams, features=features))

    encoder_input_context = tf.nn.dropout(encoder_input_context,
                                  1.0 - hparams.layer_prepostprocess_dropout)

    encoder_output_context_0 = transformer_encoder(
        encoder_input_context,
        self_attention_bias_context,
        hparams,
        name="encoder_context",
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=self.attention_weights,
        losses=losses)

    encoder_output_context = transformer_decoder(
        encoder_input,
        encoder_output_context_0,
        encoder_decoder_attention_bias,
        encoder_decoder_attention_bias_context,
        hparams,
        name="decoder_input_context")

    return encoder_output_context, encoder_output, encoder_decoder_attention_bias
示例#46
0
        def infer_step(logits_so_far, current_hidden):
            """Inference step of LSTM while loop."""
            # unflatten hidden:
            current_hidden = tuple(
                tf.nn.rnn_cell.LSTMStateTuple(c=s[0], h=s[1])
                for s in current_hidden)

            # put logits_so_far through top
            tm = self._problem_hparams.modality['targets']
            # need to reuse top params
            reset_scope = tf.variable_scope(tf.VariableScope(
                tf.AUTO_REUSE, ''),
                                            reuse=tf.AUTO_REUSE,
                                            auxiliary_name_scope=False)
            top_scope = tf.variable_scope('svg_decoder/{}_modality'.format(tm),
                                          reuse=tf.AUTO_REUSE)
            with reset_scope, top_scope:
                samples_so_far = self.hparams.top['targets'](
                    logits_so_far, None, self.hparams,
                    self.problem_hparams.vocab_size)
            # append a zero pad to the samples. this effectively shifts the samples
            # right, but, unlike shift_right, by not removing the last element, we
            # allow an empty samples_so_far to not be empty after padding
            samples_so_far = tf.concat([zero_pad, samples_so_far], axis=1)
            shifted_targets = common_layers.flatten4d3d(samples_so_far)
            # now take the very last one here, will be the actual input to the rnn
            shifted_targets = shifted_targets[:, -1:, :]

            # tile and append the bottleneck to inputs
            sln_offset = 0
            if hparams.condition_on_sln:
                sln_offset = 51
            pre_tile_y = tf.reshape(bottleneck, [
                common_layers.shape_list(bottleneck)[0], 1,
                hparams.bottleneck_bits + hparams.num_categories + sln_offset
            ])
            overlay_x = tf.tile(
                pre_tile_y,
                [1, common_layers.shape_list(shifted_targets)[1], 1])
            inputs = tf.concat([shifted_targets, overlay_x], -1)

            seq_len_batch = tf.ones([common_layers.shape_list(inputs)[0]])

            # RUN PRE-LSTM LAYER
            with tf.variable_scope('pre_decoder', reuse=tf.AUTO_REUSE):
                inputs = tf.layers.dense(inputs,
                                         hparams.hidden_size,
                                         name='bottom')
                inputs = tf.nn.tanh(inputs)

            # RUN LSTM
            with tf.variable_scope('lstm_decoder', reuse=tf.AUTO_REUSE):
                next_step, next_state = tf.nn.dynamic_rnn(
                    layers,
                    inputs,
                    seq_len_batch,
                    initial_state=current_hidden,
                    dtype=tf.float32,
                    time_major=False)

            next_step = tf.expand_dims(next_step, [1])
            logits_so_far = tf.concat([logits_so_far, next_step], 1)

            # flatten state
            next_state = tuple((s.c, s.h) for s in next_state)

            return logits_so_far, next_state
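
infer_step converts between the tuple-of-LSTMStateTuple form that dynamic_rnn expects and a flat tuple of (c, h) pairs that the surrounding while-loop can carry as loop variables. A minimal round-trip sketch of that conversion, with a namedtuple standing in for tf.nn.rnn_cell.LSTMStateTuple and NumPy arrays for the tensors:

import numpy as np
from collections import namedtuple

LSTMStateTuple = namedtuple('LSTMStateTuple', ['c', 'h'])

num_layers, batch, hidden = 2, 3, 4
state = tuple(LSTMStateTuple(c=np.zeros((batch, hidden)),
                             h=np.ones((batch, hidden)))
              for _ in range(num_layers))

flat = tuple((s.c, s.h) for s in state)          # flatten: what the loop carries
rebuilt = tuple(LSTMStateTuple(c=s[0], h=s[1])   # unflatten: what the RNN cell needs
                for s in flat)
print(rebuilt[0].c.shape, rebuilt[1].h.shape)    # (3, 4) (3, 4)
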
示例#47
0
  def body(self, features):
    """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "targets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id": A scalar int from data_generators.problem.SpaceID.

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
    hparams = self._hparams

    losses = []

    if self.has_input:
      inputs_context = features.get("inputs_context")
      inputs = features["inputs"]
      target_space = features["target_space_id"]
      encoder_output_context, encoder_output, encoder_decoder_attention_bias = self.encode(
          inputs_context, inputs, target_space, hparams, features=features, losses=losses)
    else:
        encoder_output_context, encoder_output, encoder_decoder_attention_bias = (None, None, None)

    targets = features["targets"]
    targets_shape = common_layers.shape_list(targets)
    targets = common_layers.flatten4d3d(targets)

    decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
        targets, hparams, features=features)

    decoder_output = self.decode(
        decoder_input,
        encoder_output,
        encoder_decoder_attention_bias,
        decoder_self_attention_bias,
        hparams,
        name="decoder_output_input",
        nonpadding=features_to_nonpadding(features, "targets"),
        losses=losses)

    if encoder_output_context is not None:
        decoder_output_context = self.decode(
            decoder_input,
            encoder_output_context,
            encoder_decoder_attention_bias,
            decoder_self_attention_bias,
            hparams,
            name="decoder_output_input_context",
            nonpadding=features_to_nonpadding(features, "targets"),
            losses=losses)
        decoder_output = self.cat_and_compress(decoder_output_context, decoder_output, hparams)

    expected_attentions = features.get("expected_attentions")
    if expected_attentions is not None:
      attention_loss = common_attention.encoder_decoder_attention_loss(
          expected_attentions, self.attention_weights,
          hparams.expected_attention_loss_type,
          hparams.expected_attention_loss_multiplier)
      return decoder_output, {"attention_loss": attention_loss}

    ret = tf.reshape(decoder_output, targets_shape)
    if losses:
      return ret, {"extra_loss": tf.add_n(losses)}
    else:
      return ret
示例#48
0
 def flatten(inputs):
   return tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
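
flatten4d3d merges the two spatial axes of a [batch, height, width, depth] tensor into one length axis, and the expand_dims here adds the width-1 axis back so downstream code sees 4D again. A NumPy reshape sketch of that round trip (illustration only, not the library implementation):

import numpy as np

x = np.arange(3 * 5 * 2 * 7).reshape(3, 5, 2, 7)      # [batch, height, width, depth]
flat = x.reshape(3, 5 * 2, 7)                          # flatten4d3d: [batch, length, depth]
back_to_4d = np.expand_dims(flat, axis=2)              # [batch, length, 1, depth]
print(flat.shape, back_to_4d.shape)                    # (3, 10, 7) (3, 10, 1, 7)
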
示例#49
0
    def encode(self,
               inputs,
               target_space,
               hparams,
               features=None,
               losses=None,
               **kwargs):
        """Encode Universal Transformer inputs.
    It is similar to "transformer.encode", but it uses
    "universal_transformer_util.universal_transformer_encoder" instead of
    "transformer.transformer_encoder".
    Args:
      inputs: Transformer inputs [batch_size, input_length, input_height,
        hidden_dim] which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for model.
      features: optionally pass the entire features dictionary as well.
        This is needed now for "packed" datasets.
      losses: Unused.
      **kwargs: additional arguments to pass to encoder_function
    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
          encoder_extra_output: which is extra encoder output used in some
            variants of the model (e.g. in ACT, to pass the ponder-time to body)
    """

        ####
        ## DEBUG
        ####
        # with open("invertible_UT_params.json", "w") as f:
        #   json.dump(dict(hparams.__dict__), f, default=lambda o: '<not serializable>', sort_keys=True,
        #             indent=4, separators=(',', ': '))
        # sys.exit()

        del losses

        inputs = common_layers.flatten4d3d(inputs)

        encoder_input, self_attention_bias, encoder_decoder_attention_bias = (
            transformer.transformer_prepare_encoder(inputs,
                                                    target_space,
                                                    hparams,
                                                    features=features))

        encoder_input = tf.nn.dropout(
            encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

        (encoder_output, encoder_extra_output) = (invertible_UT_encoder(
            encoder_input,
            self_attention_bias,
            hparams,
            nonpadding=transformer.features_to_nonpadding(features, "inputs"),
            save_weights_to=self.attention_weights))

        for var in tf.trainable_variables():
            print(var)

        return encoder_output, encoder_decoder_attention_bias, encoder_extra_output
示例#50
0
 def testFlatten4D3D(self):
     x = np.random.random_integers(1, high=8, size=(3, 5, 2))
     y = common_layers.flatten4d3d(common_layers.embedding(x, 10, 7))
     self.evaluate(tf.global_variables_initializer())
     res = self.evaluate(y)
     self.assertEqual(res.shape, (3, 5 * 2, 7))
示例#51
0
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Change hyperparameters for the latent prediction model.
    hparams_ex = copy.copy(hparams)
    hparams_ex.filter_size *= 2
    hparams_ex.hidden_size *= 2
    hparams_ex.dropout = 0.0
    hparams_ex.relu_dropout = 0.0
    hparams_ex.z_dropout = 0.0
    hparams_ex.layer_prepostprocess_dropout = 0.0
    hparams_ex.symbol_dropout = 0.0
    hparams.ex = hparams_ex

    # Prepare.
    if inputs is not None:
        batch_size = common_layers.shape_list(inputs)[0]
    else:
        batch_size = common_layers.shape_list(targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs_ex = tf.layers.dense(tf.stop_gradient(inputs),
                                    hparams_ex.hidden_size,
                                    name="extra_embed")
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
        inputs_ex, ed_ex = encode(inputs_ex, target_space, hparams_ex,
                                  "extra_ienc")
    else:
        ed, inputs_ex, ed_ex = None, None, None

    # Autoencoding.
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
    if hparams.do_ae:
        # flatten here
        original_targets_shape = tf.shape(targets)
        if hparams.task == "image":
            cia.maybe_reshape_4d_to_3d(targets)
        if hparams.task == "translate":
            max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
        else:
            assert hparams.task == "image"
            max_targets_len_from_inputs = targets
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        targets_c = compress(targets, inputs, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            latents_dense, latents_discrete, extra_loss, embed = hparams.bottleneck(
                x=targets_c,
                filter_size=hparams.compress_filter_size,
                name="vc",
                mode=hparams.mode)
            if _DO_SUMMARIES:
                tf.summary.histogram(
                    "b0", tf.reshape(latents_discrete[:, 0, :], [-1]))
            pc = common_layers.inverse_exp_decay(hparams.startup_steps)
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([batch_size]), pc)
            latents_dense = tf.where(cond, latents_dense, targets_c)
            # TODO(lukaszkaiser): return extra losses batchwise, multiply before mean.
            losses["extra"] = extra_loss * tf.reduce_mean(tf.to_float(cond))
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                latents_pred = decode_transformer(inputs_ex,
                                                  ed_ex,
                                                  tf.stop_gradient(
                                                      embed(latents_discrete)),
                                                  hparams,
                                                  "extra",
                                                  task="translate")
                _, latent_pred_loss = ae_latent_softmax(
                    latents_pred, tf.stop_gradient(latents_discrete), hparams)
                losses["latent_pred"] = tf.reduce_mean(latent_pred_loss *
                                                       tf.to_float(cond))
            else:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                losses["latent_pred"] = tf.reduce_mean(
                    (inputs_c - targets_c)**2) * 20

                def bn_inputs():
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=True):
                        bn, _, _, _ = hparams.bottleneck(
                            x=inputs_c,
                            filter_size=hparams.compress_filter_size,
                            name="vc",
                            mode=hparams.mode)
                    return bn

                inputs_c = bn_inputs()
                ptc = 1.0 - common_layers.inverse_lin_decay(200000) * 0.5
                ptc = ptc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
                latents_dense = tf.where(
                    tf.less(tf.random_uniform([batch_size]), ptc),
                    latents_dense, inputs_c)
        else:
            if hparams.bottleneck_kind in ["dense", "vae"]:
                inputs_c = decode_transformer(inputs, ed, targets_c, hparams,
                                              "dec_c")
                latents_dense, _, _, _ = hparams.bottleneck(
                    x=inputs_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc",
                    mode=hparams.mode)
            else:
                latent_len = common_layers.shape_list(targets_c)[1]
                _, _, _, embed = hparams.bottleneck(
                    x=targets_c,
                    filter_size=hparams.compress_filter_size,
                    name="vc")
                latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(latents_dense, inputs_ex, ed_ex,
                                             embed, 16, hparams)
                latents_dense = embed(cache)
        # Postprocess.
        d = latents_dense
        pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
        pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
        latents_dense = tf.pad(latents_dense,
                               [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(
                hparams.mask_startup_steps)
            masking *= common_layers.inverse_exp_decay(
                hparams.mask_startup_steps // 4)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * hparams.unmasked_percentage
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.use_predict_mask:
                masking = predict_mask
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            mask = tf.less(
                masking,
                tf.random_uniform(common_layers.shape_list(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)
            for i in xrange(hparams.num_compress_steps):
                j = hparams.num_compress_steps - i - 1
                d = residual_conv(d, 1, (3, 1), hparams,
                                  "decompress_rc_%d" % j)
                if hparams.do_attend_decompress:
                    d = attend(d, inputs, hparams, "decompress_attend_%d" % j)
                d = decompress_step(d, hparams, i > 0, False,
                                    "decompress_%d" % j)
            # targets is always [batch, length, 1, depth]
            targets = mask * targets + (1.0 - mask) * d
            # reshape back to 4d here
            if hparams.task == "image":
                targets = tf.reshape(targets, original_targets_shape)
        if hparams.task == "translate":
            targets = tf.concat([tf.reverse(latents_dense, [1]), targets],
                                axis=1)

    res = decode_transformer(inputs,
                             ed,
                             targets,
                             hparams,
                             "decoder",
                             causal=hparams.causal)
    if hparams.do_ae:
        if hparams.task == "translate":
            res = res[:, common_layers.shape_list(latents_dense)[1]:, :, :]
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                # return residual_conv(res, 1, (5, 1), hparams, "refine")
                r, _ = encode(tf.squeeze(res, axis=[2]), target_space, hparams,
                              "refine_enc")
                return tf.expand_dims(r, axis=2)

            masked_batches = tf.reduce_sum(mask, axis=[1, 2, 3])
            all_masked = tf.less(masked_batches, 0.1)
            res = tf.where(all_masked, refine_res(), res)
        # We'll start training the extra model of latents after mask_startup_steps.
        nonlatent_steps = hparams.mask_startup_steps
        latent_time = tf.less(nonlatent_steps,
                              tf.to_int32(tf.train.get_global_step()))
        # Learning rate warmup for the latent model for 20K steps.
        latent_warmup = tf.to_float(
            tf.train.get_global_step()) - nonlatent_steps
        latent_warmup = tf.maximum(0.0, tf.minimum(1.0,
                                                   latent_warmup / 20000.0))
        losses["latent_pred"] *= tf.to_float(latent_time) * latent_warmup
    return res, losses, cache
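The masking probability above is the product of a linear and an exponential ramp over the global step. A rough NumPy sketch of that schedule follows; the decay helpers and the step values here are approximations of what common_layers.inverse_lin_decay / inverse_exp_decay compute, not their actual implementations:

import numpy as np

def inverse_lin_decay(max_step, step, min_value=0.01):
    # Approximation: linear ramp from min_value up to 1.0, reached at max_step.
    progress = min(step / float(max_step), 1.0)
    return progress * (1.0 - min_value) + min_value

def inverse_exp_decay(max_step, step, min_value=0.01):
    # Approximation: exponential ramp from min_value up to 1.0 at max_step.
    inv_base = np.exp(np.log(min_value) / max_step)
    return inv_base ** max(float(max_step) - step, 0.0)

mask_startup_steps = 50000  # made-up value for illustration
for step in (0, 10000, 25000, 50000, 80000):
    masking = inverse_lin_decay(mask_startup_steps, step)
    masking *= inverse_exp_decay(mask_startup_steps // 4, step)  # small at start
    print(step, round(min(max(masking, 0.0), 1.0), 3))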
Example #52
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            cache=None):
    """Main step used for training."""
    # Encoder.
    inputs = common_layers.flatten4d3d(inputs)
    inputs, ed = encode(inputs, target_space, hparams, "input_enc")

    # Autoencoding.
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}

    max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
    targets, _ = common_layers.pad_to_same_length(
        targets,
        max_targets_len_from_inputs,
        final_length_divisible_by=2**hparams.num_compress_steps)
    targets_c = compress(targets, hparams, "compress")
    if hparams.mode != tf.estimator.ModeKeys.PREDICT:
        # Compress and bottleneck.
        latents_discrete_hot, extra_loss = vq_discrete_bottleneck(
            x=targets_c, hparams=hparams)
        latents_dense = vq_discrete_unbottleneck(latents_discrete_hot,
                                                 hparams=hparams)
        latents_dense = targets_c + tf.stop_gradient(latents_dense - targets_c)
        latents_discrete = tf.argmax(latents_discrete_hot, axis=-1)
        tf.summary.histogram("codes",
                             tf.reshape(latents_discrete[:, 0, :], [-1]))
        losses["extra"] = extra_loss

        # Extra loss predicting latent code from input.
        latents_pred = decode_transformer(inputs, ed, latents_dense, hparams,
                                          "extra")
        latent_pred_loss = get_latent_pred_loss(latents_pred,
                                                latents_discrete_hot, hparams)
        losses["latent_pred"] = tf.reduce_mean(latent_pred_loss)
    else:
        latent_len = common_layers.shape_list(targets_c)[1]
        embed = functools.partial(vq_discrete_unbottleneck, hparams=hparams)
        latents_dense = tf.zeros_like(targets_c[:, :latent_len, :, :])
        if cache is None:
            cache = ae_latent_sample_beam(latents_dense, inputs, ed, embed,
                                          hparams)
        cache_hot = tf.one_hot(cache, depth=2**hparams.bottleneck_bits)
        latents_dense = embed(cache_hot)

    # Postprocess.
    d = latents_dense
    pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
    pos = pos[:, :common_layers.shape_list(latents_dense)[1] + 1, :, :]
    latents_dense = tf.pad(latents_dense,
                           [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

    # Decompressing the dense latents
    for i in range(hparams.num_compress_steps):
        j = hparams.num_compress_steps - i - 1
        d = residual_conv(d, 1, (3, 1), hparams, "decompress_rc_%d" % j)
        d = decompress_step(d, hparams, i > 0, "decompress_%d" % j)

    masking = common_layers.inverse_lin_decay(hparams.mask_startup_steps)
    masking *= common_layers.inverse_exp_decay(hparams.mask_startup_steps //
                                               4)  # Not much at start.
    masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
    if hparams.mode == tf.estimator.ModeKeys.PREDICT:
        masking = 1.0
    mask = tf.less(masking,
                   tf.random_uniform(common_layers.shape_list(targets)[:-1]))
    mask = tf.expand_dims(tf.to_float(mask), 3)

    # targets is always [batch, length, 1, depth]
    targets = mask * targets + (1.0 - mask) * d

    res = decode_transformer(inputs, ed, targets, hparams, "decoder")
    latent_time = tf.less(hparams.mask_startup_steps,
                          tf.to_int32(tf.train.get_global_step()))
    losses["latent_pred"] *= tf.to_float(latent_time)
    return res, losses, cache
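The line latents_dense = targets_c + tf.stop_gradient(latents_dense - targets_c) above is the straight-through trick: the forward value is the quantized latent, while gradients bypass the quantizer and flow into targets_c. A minimal toy illustration (TF2 eager assumed, with tf.round standing in for the discrete bottleneck):

import tensorflow as tf

x = tf.Variable([0.2, 0.8, 1.7])            # stands in for targets_c
with tf.GradientTape() as tape:
    q = tf.round(x)                          # stands in for the quantized latents
    st = x + tf.stop_gradient(q - x)         # forward: q; backward: identity w.r.t. x
    loss = tf.reduce_sum(st * st)
print(st.numpy())                            # [0. 1. 2.]  -- quantized forward value
print(tape.gradient(loss, x).numpy())        # [0. 2. 4.]  -- 2 * st, passed straight to x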
Example #53
  def lstm_decoder_infer(self, inputs, sequence_length, hparams, clss, train,
                         initial_state=None, bottleneck=None):
    # IN PREDICT MODE, RUN tf.while RNN
    max_decode_length = 51
    batch_size = common_layers.shape_list(inputs)[0]
    zero_pad, logits_so_far = self.create_initial_input_for_decode(batch_size)

    layers = contrib_rnn.MultiRNNCell([
        self.lstm_cell(hparams, train) for _ in range(hparams.num_hidden_layers)
    ])

    if initial_state is None:
      raise Exception('initial state should be init from bottleneck!')

    # append one-hot class to bottleneck, which will be given per step
    clss = tf.reshape(clss, [-1])
    if not hparams.use_cls:
      clss = tf.zeros_like(clss)
    if hparams.condition_on_sln:
      sln = tf.reshape(sequence_length, [-1])
      bottleneck = tf.concat((bottleneck,
                              tf.one_hot(clss, hparams.num_categories),
                              tf.one_hot(sln, max_decode_length)), -1)
    else:
      bottleneck = tf.concat((bottleneck,
                              tf.one_hot(clss, hparams.num_categories)), -1)

    def infer_step(logits_so_far, current_hidden):
      """Inference step of LSTM while loop."""
      # unflatten hidden:
      current_hidden = tuple(tf.nn.rnn_cell.LSTMStateTuple(c=s[0], h=s[1])
                             for s in current_hidden)

      # put logits_so_far through top
      tm = self._problem_hparams.modality['targets']
      # need to reuse top params
      reset_scope = tf.variable_scope(tf.VariableScope(tf.AUTO_REUSE, ''),
                                      reuse=tf.AUTO_REUSE,
                                      auxiliary_name_scope=False)
      top_scope = tf.variable_scope('svg_decoder/{}_modality'.format(tm),
                                    reuse=tf.AUTO_REUSE)
      with reset_scope, top_scope:
        samples_so_far = self.hparams.top['targets'](
            logits_so_far, None, self.hparams, self.problem_hparams.vocab_size)
      # append a zero pad to the samples. this effectively shifts the samples
      # right, but, unlike shift_right, by not removing the last element, we
      # allow an empty samples_so_far to not be empty after padding
      samples_so_far = tf.concat([zero_pad, samples_so_far], axis=1)
      shifted_targets = common_layers.flatten4d3d(samples_so_far)
      # now take the very last one here, will be the actual input to the rnn
      shifted_targets = shifted_targets[:, -1:, :]

      # tile and append the bottleneck to inputs
      sln_offset = 0
      if hparams.condition_on_sln:
        sln_offset = 51
      pre_tile_y = tf.reshape(
          bottleneck,
          [common_layers.shape_list(bottleneck)[0], 1,
           hparams.bottleneck_bits + hparams.num_categories + sln_offset])
      overlay_x = tf.tile(pre_tile_y,
                          [1, common_layers.shape_list(shifted_targets)[1], 1])
      inputs = tf.concat([shifted_targets, overlay_x], -1)

      seq_len_batch = tf.ones([common_layers.shape_list(inputs)[0]])

      # RUN PRE-LSTM LAYER
      with tf.variable_scope('pre_decoder', reuse=tf.AUTO_REUSE):
        inputs = tf.layers.dense(inputs, hparams.hidden_size, name='bottom')
        inputs = tf.nn.tanh(inputs)

      # RUN LSTM
      with tf.variable_scope('lstm_decoder', reuse=tf.AUTO_REUSE):
        next_step, next_state = tf.nn.dynamic_rnn(
            layers, inputs, seq_len_batch, initial_state=current_hidden,
            dtype=tf.float32, time_major=False)

      next_step = tf.expand_dims(next_step, [1])
      logits_so_far = tf.concat([logits_so_far, next_step], 1)

      # flatten state
      next_state = tuple((s.c, s.h) for s in next_state)

      return logits_so_far, next_state

    def while_exit_cond(logits_so_far, unused_current_hidden):
      length = common_layers.shape_list(logits_so_far)[1]
      return length < max_decode_length

    # passing state must be flattened:
    initial_state = tuple([(s.c, s.h) for s in initial_state])

    # actually run tf.while:
    logits, final_state = tf.while_loop(
        while_exit_cond, infer_step,
        [logits_so_far, initial_state],
        shape_invariants=[
            tf.TensorShape([None, None, 1, hparams.hidden_size]),
            tuple([(s[0].get_shape(), s[1].get_shape())
                   for s in initial_state]),
        ],
        back_prop=False,
        parallel_iterations=1
    )

    # logits should be returned in 3d mode:
    logits = common_layers.flatten4d3d(logits)

    return logits, final_state
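The inference routine above grows logits_so_far one step per iteration inside tf.while_loop, which is why its length dimension is declared as None in shape_invariants. A stripped-down sketch of that pattern with a dummy step function (not the model's actual decoder):

import tensorflow as tf

# Grow a [batch, length, depth] tensor by one frame per loop iteration.
batch, depth, max_len = 2, 4, 5

def cond(i, seq):
    return i < max_len

def step(i, seq):
    next_frame = tf.ones([batch, 1, depth])      # dummy "decoder output"
    return i + 1, tf.concat([seq, next_frame], axis=1)

_, seq = tf.while_loop(
    cond, step,
    [tf.constant(1), tf.zeros([batch, 1, depth])],
    shape_invariants=[tf.TensorShape([]),
                      tf.TensorShape([batch, None, depth])])
print(seq.shape)   # (2, 5, 4) once the loop has run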
Example #54
    def body(self, features):
        """R-Transformer main model_fn.


    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "targets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id"

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
        hparams = self._hparams
        if hparams.add_position_timing_signal:
            # Turning off addition of positional embedding in the encoder/decoder
            # preparation as we do it in the beginning of each step.
            hparams.pos = None

        if self.has_input:
            inputs = features["inputs"]
            target_space = features["target_space_id"]
            (encoder_output, encoder_decoder_attention_bias,
             enc_extra_output) = self.encode(inputs,
                                             target_space,
                                             hparams,
                                             features=features)
        else:
            (encoder_output, encoder_decoder_attention_bias,
             enc_extra_output) = (None, None, (None, None))

        targets = features["targets"]
        targets = common_layers.flatten4d3d(targets)

        (decoder_input, decoder_self_attention_bias
         ) = transformer.transformer_prepare_decoder(targets,
                                                     hparams,
                                                     features=features)

        decoder_output, dec_extra_output = self.decode(
            decoder_input,
            encoder_output,
            encoder_decoder_attention_bias,
            decoder_self_attention_bias,
            hparams,
            nonpadding=transformer.features_to_nonpadding(features, "targets"))

        expected_attentions = features.get("expected_attentions")
        if expected_attentions is not None:
            attention_loss = common_attention.encoder_decoder_attention_loss(
                expected_attentions, self.attention_weights,
                hparams.expected_attention_loss_type,
                hparams.expected_attention_loss_multiplier)
            return decoder_output, {"attention_loss": attention_loss}

        if hparams.recurrence_type == "act" and hparams.act_loss_weight != 0:
            if self.has_input:
                enc_ponder_times, enc_remainders = enc_extra_output
                enc_act_loss = (
                    hparams.act_loss_weight *
                    tf.reduce_mean(enc_ponder_times + enc_remainders))
            else:
                enc_act_loss = 0.0

            (dec_ponder_times, dec_remainders) = dec_extra_output
            dec_act_loss = (hparams.act_loss_weight *
                            tf.reduce_mean(dec_ponder_times + dec_remainders))
            act_loss = enc_act_loss + dec_act_loss
            tf.contrib.summary.scalar("act_loss", act_loss)
            return decoder_output, {"act_loss": act_loss}

        return decoder_output
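The ACT regularizer above reduces to act_loss_weight * mean(ponder_times + remainders), computed separately for encoder and decoder and then summed. A toy computation with made-up tensors standing in for the ACT outputs:

import tensorflow as tf

act_loss_weight = 0.01
# Made-up ACT outputs: per-position ponder counts and halting remainders.
ponder_times = tf.constant([[2.0, 3.0, 1.0], [4.0, 2.0, 2.0]])
remainders = tf.constant([[0.3, 0.1, 0.7], [0.2, 0.5, 0.4]])

act_loss = act_loss_weight * tf.reduce_mean(ponder_times + remainders)
print(float(act_loss))   # ~0.027 == 0.01 * mean of the element-wise sums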
Example #55
def decode_transformer(encoder_output,
                       encoder_decoder_attention_bias,
                       targets,
                       hparams,
                       name,
                       task=None,
                       causal=True):
    """Original Transformer decoder."""
    orig_hparams = hparams
    with tf.variable_scope(name):
        if task is None:
            task = hparams.task
        if task == "translate":
            targets = common_layers.flatten4d3d(targets)

            decoder_input, decoder_self_bias = (
                transformer.transformer_prepare_decoder(targets, hparams))

            decoder_input = tf.nn.dropout(
                decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

            if not causal:
                decoder_self_bias *= 0.

            decoder_output = transformer.transformer_decoder(
                decoder_input, encoder_output, decoder_self_bias,
                encoder_decoder_attention_bias, hparams)
            decoder_output = tf.expand_dims(decoder_output, axis=2)
        else:
            assert task == "image"
            inputs = None
            # have to reshape targets as [b, 32, 32, 3 * hidden_size] because
            # otherwise prepare_image will choke
            targets = tf.reshape(targets, [
                tf.shape(targets)[0], hparams.img_len, hparams.img_len,
                hparams.num_channels * hparams.hidden_size
            ])

            # Prepare decoder inputs and bias.
            # TODO(nikip): Make prepare_decoder return bias
            decoder_input, _, _ = cia.prepare_decoder(targets, hparams)
            bias = None

            # Add class label to decoder input.
            if not hparams.drop_inputs:
                decoder_input += tf.reshape(inputs, [
                    common_layers.shape_list(targets)[0], 1, 1,
                    hparams.hidden_size
                ])
            decoder_output = cia.transformer_decoder_layers(
                decoder_input,
                encoder_output=None,
                num_layers=hparams.num_decoder_layers
                or hparams.num_hidden_layers,
                hparams=hparams,
                self_attention_bias=bias,
                attention_type=hparams.dec_attention_type,
                name="decoder")
        decoder_output_shape = common_layers.shape_list(decoder_output)
        decoder_output = tf.reshape(
            decoder_output,
            [decoder_output_shape[0], -1, 1, hparams.hidden_size])
        # Expand since t2t expects 4d tensors.
        hparams = orig_hparams
        return decoder_output
Example #56
    def body(self, features, original_features):
        """Transformer main model_fn.
    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "targets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id"
    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
        hparams = self._hparams

        snippets = features.get(searchqa_problem.FeatureNames.SNIPPETS)
        questions = features.get(searchqa_problem.FeatureNames.QUESTION)
        target_space = features["target_space_id"]

        with tf.variable_scope('input'):
            # [batch_size, search_results_len, embed_sz]
            encoded_snippets = self.inputs_encoding(
                input=snippets,
                original_input=original_features.get(
                    searchqa_problem.FeatureNames.SNIPPETS),
                initializer=tf.constant_initializer(1.0),
                scope='snippets_encoding')

            # [batch_size, 1, embed_sz]
            encoded_question = self.inputs_encoding(
                input=questions,
                original_input=original_features.get(
                    searchqa_problem.FeatureNames.QUESTION),
                initializer=tf.constant_initializer(1.0),
                scope='question_encoding')

        # Concat snippets and questions to create the inputs
        inputs = tf.concat([encoded_snippets, encoded_question], axis=1)
        # the input is 4D by default and it gets squeezed from 4D to 3D in the
        # encode function, so we need to make it 4D by inserting channel dim.
        inputs = tf.expand_dims(inputs, axis=2)

        losses = []
        encoder_output, encoder_decoder_attention_bias = self.encode(
            inputs, target_space, hparams, features=features, losses=losses)

        targets = features["targets"]
        targets_shape = common_layers.shape_list(targets)
        targets = common_layers.flatten4d3d(targets)

        decoder_input, decoder_self_attention_bias = transformer.transformer_prepare_decoder(
            targets, hparams, features=features)

        decoder_output = self.decode(decoder_input,
                                     encoder_output,
                                     encoder_decoder_attention_bias,
                                     decoder_self_attention_bias,
                                     hparams,
                                     nonpadding=features_to_nonpadding(
                                         features, "targets"),
                                     losses=losses)

        ret = tf.reshape(decoder_output, targets_shape)
        if losses:
            return ret, {"extra_loss": tf.add_n(losses)}
        else:
            return ret
Example #57
    def body(self, features):
        """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "targets": Target decoder outputs. [batch_size, decoder_length,
            hidden_dim]
          "target_space_id": A scalar int from data_generators.problem.SpaceID.

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
        tf.logging.info("Using PgScratch BODY function.")
        hparams = self._hparams

        losses = {}
        inputs = features["inputs"]
        target_space = features["target_space_id"]
        # encoder_output: <tf.float32>[batch_size, input_length, hidden_dim]
        # encoder_decoder_attention_bias: <tf.float32>[batch_size, input_length]
        encoder_output, encoder_decoder_attention_bias = self.encode(
            inputs, target_space, hparams, features=features, losses=losses)

        with tf.variable_scope("knowledge"):
            with tf.name_scope("knowledge_encoding"):
                # Encode knowledge.
                # <tf.float32>[batch_size, triple_num, emb_dim]
                fact_embedding, fact_lengths = self.encode_knowledge_bottom(
                    features)
                tf.logging.info("Encoded knowledge")

            with tf.name_scope("knowledge_selection_and_loss"):
                # Compute knowledge selection and loss.
                triple_logits, avg_triple_selection_loss, knowledge_encoder_output, transe_loss = self.compute_knowledge_selection_and_loss(
                    features, encoder_output, fact_embedding, fact_lengths,
                    hparams.margin, hparams.num_negative_samples)
                losses["kb_loss"] = avg_triple_selection_loss
                losses["transe_loss"] = transe_loss

        if hparams.attend_kb:
            tf.logging.info("ATTEND_KB is ACTIVE")
            with tf.name_scope("knowledge_attention"):

                knowledge_padding = tf.zeros_like(triple_logits,
                                                  dtype=tf.float32)
                knowledge_attention_bias = common_attention.attention_bias_ignore_padding(
                    knowledge_padding)
                encoder_output = tf.concat(
                    [knowledge_encoder_output, encoder_output], 1)
                encoder_decoder_attention_bias = tf.concat(
                    [knowledge_attention_bias, encoder_decoder_attention_bias],
                    -1)

        else:
            tf.logging.info("ATTEND_KB is INACTIVE")

        targets = features["targets"]
        targets_shape = common_layers.shape_list(targets)
        targets = common_layers.flatten4d3d(targets)

        (decoder_input, decoder_self_attention_bias
         ) = transformer.transformer_prepare_decoder(targets,
                                                     hparams,
                                                     features=features)

        decode_kwargs = {}
        decoder_output = self.decode(
            decoder_input,
            encoder_output,
            encoder_decoder_attention_bias,
            decoder_self_attention_bias,
            hparams,
            nonpadding=transformer.features_to_nonpadding(features, "targets"),
            losses=losses,
            **decode_kwargs)

        expected_attentions = features.get("expected_attentions")
        if expected_attentions is not None:
            attention_loss = common_attention.encoder_decoder_attention_loss(
                expected_attentions, self.attention_weights,
                hparams.expected_attention_loss_type,
                hparams.expected_attention_loss_multiplier)
            return decoder_output, {"attention_loss": attention_loss}

        ret = tf.reshape(decoder_output, targets_shape)
        if losses:
            return ret, losses
        else:
            return ret
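In the attend_kb branch above, knowledge_padding is all zeros, so the resulting bias is zero everywhere and every knowledge slot stays attendable. A hedged sketch of what a padding-to-bias conversion of this kind typically looks like; the exact shape handling in common_attention.attention_bias_ignore_padding may differ:

import tensorflow as tf

def attention_bias_from_padding(padding):
    # Sketch: 1.0 marks a padded memory slot; turn it into a large negative
    # bias broadcastable against [batch, heads, query_len, memory_len] logits.
    return tf.expand_dims(tf.expand_dims(padding * -1e9, axis=1), axis=1)

padding = tf.constant([[0.0, 0.0, 1.0]])      # last memory slot is padding
bias = attention_bias_from_padding(padding)   # shape: (1, 1, 1, 3)
print(bias.numpy())                           # real slots stay ~0, padded slot gets -1e9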
Example #58
def ae_transformer_internal(inputs,
                            targets,
                            target_space,
                            hparams,
                            beam_size,
                            cache=None,
                            predict_mask=1.0):
    """AE Transformer, main step used for training."""
    # Summaries break with the do_refine cond, turn them off in that case.
    global _DO_SUMMARIES
    if hparams.do_refine:
        _DO_SUMMARIES = False

    # Prepare.
    orig_targets = targets
    batch_size = tf.shape(orig_targets)[0]
    targets = tf.reshape(targets, [batch_size, -1, 1, hparams.hidden_size])

    # Encoder.
    if inputs is not None:
        inputs = common_layers.flatten4d3d(inputs)
        inputs, ed = encode(inputs, target_space, hparams, "input_enc")
    else:
        ed = None

    # Autoencoding.
    losses = {"extra": tf.constant(0.0), "latent_pred": tf.constant(0.0)}
    if hparams.do_ae:
        max_targets_len_from_inputs = tf.concat([inputs, inputs], axis=1)
        targets, _ = common_layers.pad_to_same_length(
            targets,
            max_targets_len_from_inputs,
            final_length_divisible_by=2**hparams.num_compress_steps)
        targets_c = compress(targets, False, hparams, "compress")
        if hparams.mode != tf.estimator.ModeKeys.PREDICT:
            # Compress and bottleneck.
            t_c, t_bit, vc_loss, _ = bottleneck(targets_c, hparams, 2 * 2048,
                                                "vc")
            if _DO_SUMMARIES:
                tf.summary.histogram("bit0", tf.reshape(t_bit[:, 0, :], [-1]))
            pc = common_layers.inverse_exp_decay(hparams.startup_steps) * 0.95
            pc = pc if hparams.mode == tf.estimator.ModeKeys.TRAIN else 1.0
            cond = tf.less(tf.random_uniform([]), pc)
            t_c = tf.cond(cond, lambda: t_c, lambda: targets_c)
            losses["extra"] = vc_loss * tf.to_float(cond)
            # Extra loss predicting latent code from input. Discrete only.
            if hparams.bottleneck_kind not in ["dense", "vae"]:
                t_pred = decode_transformer(inputs, ed, tf.stop_gradient(t_c),
                                            hparams, "extra")
                t_pred = tf.layers.dense(t_pred, 2**16, name="extra_logits")
                losses[
                    "latent_pred"] = tf.nn.sparse_softmax_cross_entropy_with_logits(
                        labels=t_bit, logits=t_pred)
                losses["latent_pred"] = tf.reduce_mean(
                    losses["latent_pred"]) * 0.5 * tf.to_float(cond)
        else:
            if hparams.bottleneck_kind in ["dense", "vae"]:
                targets_rand = tf.random_uniform(tf.shape(targets_c))
                t_c, _, _, _ = bottleneck(targets_rand, hparams, 2 * 2048,
                                          "vc")
            else:
                latent_len = tf.shape(targets_c)[1]
                _, _, _, embed = bottleneck(targets_c, hparams, 2 * 2048, "vc")
                t_c = tf.zeros_like(targets_c[:, :latent_len, :, :])
                if cache is None:
                    cache = ae_latent_sample(t_c, inputs, ed, embed, 8,
                                             hparams)
                    cache = cache[0, :, :]
                    cache = tf.reshape(cache, [1, latent_len, 1])
                    cache = tf.tile(cache, [beam_size, 1, 1])
                t_c = embed(cache)
        # Postprocess.
        d = t_c
        pos = tf.get_variable("pos", [1, 1000, 1, hparams.hidden_size])
        pos = pos[:, :tf.shape(t_c)[1] + 1, :, :]
        t_c = tf.pad(t_c, [[0, 0], [1, 0], [0, 0], [0, 0]]) + pos

        # Masking.
        if hparams.do_mask:
            masking = common_layers.inverse_lin_decay(100000)
            masking *= common_layers.inverse_exp_decay(
                25000)  # Not much at start.
            if not hparams.do_refine:
                masking -= tf.random_uniform([]) * 0.3
            masking = tf.minimum(tf.maximum(masking, 0.0), 1.0)
            if hparams.mode == tf.estimator.ModeKeys.PREDICT:
                masking = predict_mask
            mask = tf.less(masking, tf.random_uniform(tf.shape(targets)[:-1]))
            mask = tf.expand_dims(tf.to_float(mask), 3)
            for i in xrange(hparams.num_compress_steps):
                j = hparams.num_compress_steps - i - 1
                d = residual_conv(d, 1, (3, 1), hparams,
                                  "decompress_rc_%d" % j)
                d = decompress_step(d, hparams, i > 0, False,
                                    "decompress_%d" % j)
            targets = mask * targets + (1.0 - mask) * d
        targets = tf.concat([tf.reverse(t_c, [1]), targets], axis=1)

    res = decode_transformer(inputs, ed, targets, hparams, "decoder")
    if hparams.do_ae:
        res = res[:, tf.shape(t_c)[1]:, :, :]
        if hparams.do_mask and hparams.do_refine:

            def refine_res():
                return residual_conv(res, 1, (5, 1), hparams, "refine")

            all_masked = tf.less(tf.reduce_sum(mask), 0.1)
            res = tf.cond(all_masked, refine_res, lambda: res)
    return res, losses, cache
Example #59
def mel_perf_transformer_encode(encoder_function,
                                perf_inputs,
                                mel_inputs,
                                target_space,
                                hparams,
                                attention_weights=None,
                                features=None,
                                losses=None,
                                prepare_encoder_fn=None,
                                **kwargs):
    """Encode transformer inputs. Used for melody & performance autoencoder.

    Performance is mean-aggregated across time and combined with melody in a
    variety of different ways.

    Args:
      encoder_function: the encoder function
      perf_inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim]
        which will be flattened along the two spatial dimensions.
      mel_inputs: Transformer inputs [batch_size, input_length, 1, hidden_dim]
        which will be flattened along the two spatial dimensions.
      target_space: scalar, target space ID.
      hparams: hyperparameters for model.
      attention_weights: weight to store attention to.
      features: optionally pass the entire features dictionary as well. This is
        needed now for "packed" datasets.
      losses: optional list onto which to append extra training losses
      prepare_encoder_fn: optional, alternative to transformer_prepare_encoder.
      **kwargs: additional arguments to pass to encoder_function

    Returns:
      Tuple of:
          encoder_output: Encoder representation.
              [batch_size, input_length, hidden_dim]
          encoder_decoder_attention_bias: Bias and mask weights for
              encoder-decoder attention. [batch_size, input_length]
    """
    perf_inputs = common_layers.flatten4d3d(perf_inputs)
    mel_inputs = common_layers.flatten4d3d(mel_inputs)

    if not prepare_encoder_fn:
        prepare_encoder_fn = transformer_prepare_encoder
    perf_encoder_input, perf_self_attention_bias, perf_encdec_attention_bias = (
        prepare_encoder_fn(perf_inputs,
                           target_space,
                           hparams,
                           features=features,
                           reuse_target_embedding=tf.AUTO_REUSE))

    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
        value=hparams.layer_prepostprocess_dropout,
        hparams=hparams)

    perf_encoder_input = tf.nn.dropout(
        perf_encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

    perf_attn_bias_for_padding = None
    # Otherwise the encoder will just use encoder_self_attention_bias.
    if hparams.unidirectional_encoder:
        perf_attn_bias_for_padding = perf_encdec_attention_bias

    # do the same thing for melody
    mel_encoder_input, mel_self_attention_bias, mel_encdec_attention_bias = (
        prepare_encoder_fn(mel_inputs,
                           target_space,
                           hparams,
                           features=features,
                           reuse_target_embedding=tf.AUTO_REUSE))

    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
        value=hparams.layer_prepostprocess_dropout,
        hparams=hparams)

    mel_encoder_input = tf.nn.dropout(
        mel_encoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

    mel_attn_bias_for_padding = None
    # Otherwise the encoder will just use encoder_self_attention_bias.
    if hparams.unidirectional_encoder:
        mel_attn_bias_for_padding = mel_encdec_attention_bias

    # use the proper encoder function for perf/melody
    perf_encoder_output = encoder_function(
        perf_encoder_input,
        perf_self_attention_bias,
        hparams,
        name="perf_encoder",
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=attention_weights,
        make_image_summary=not common_layers.is_xla_compiled(),
        losses=losses,
        attn_bias_for_padding=perf_attn_bias_for_padding,
        **kwargs)
    # same thing for melody
    mel_encoder_output = encoder_function(
        mel_encoder_input,
        mel_self_attention_bias,
        hparams,
        name="mel_encoder",
        nonpadding=features_to_nonpadding(features, "inputs"),
        save_weights_to=attention_weights,
        make_image_summary=not common_layers.is_xla_compiled(),
        losses=losses,
        attn_bias_for_padding=mel_attn_bias_for_padding,
        **kwargs)

    # concatenate the global mean vector/bias term with the full melody encoding
    perf_mean_vector = tf.math.reduce_mean(perf_encoder_output,
                                           axis=1,
                                           keep_dims=True)

    # different methods of aggregating over the performance + melody vectors!
    if hparams.aggregation == "sum":
        # add both mean performance and melody vectors together
        perf_mean_bias = tf.math.reduce_mean(perf_encdec_attention_bias,
                                             axis=-1,
                                             keep_dims=True)
        encoder_output = mel_encoder_output + perf_mean_vector
        encoder_decoder_attention_bias = mel_encdec_attention_bias + perf_mean_bias
    elif hparams.aggregation == "concat":
        # concatenate melody with mean-aggregated performance embedding
        stop_token = tf.zeros((1, 1, 384))
        encoder_output = tf.concat(
            [mel_encoder_output, stop_token, perf_mean_vector], axis=1)
        perf_mean_bias = tf.math.reduce_mean(perf_encdec_attention_bias,
                                             axis=-1,
                                             keep_dims=True)
        stop_bias = tf.zeros((1, 1, 1, 1))
        encoder_decoder_attention_bias = tf.concat(
            [mel_encdec_attention_bias, stop_bias, perf_mean_bias], axis=-1)
    elif hparams.aggregation == "tile":
        # tile performance embedding across each dimension of melody embedding!
        dynamic_val = tf.shape(mel_encoder_output)[1]
        shp = tf.convert_to_tensor([1, dynamic_val, 1], dtype=tf.int32)
        tiled_mean = tf.tile(perf_mean_vector, shp)

        encoder_output = tf.concat([mel_encoder_output, tiled_mean], axis=-1)
        encoder_decoder_attention_bias = mel_encdec_attention_bias
    else:
        raise NotImplementedError(
            "aggregation method must be in [sum, concat, tile].")

    return encoder_output, encoder_decoder_attention_bias
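Of the aggregation modes above, "tile" broadcasts the time-averaged performance vector along the melody's time axis and concatenates it on the depth axis. A small standalone sketch of just that operation (toy shapes, TF2 eager assumed):

import tensorflow as tf

batch, mel_len, depth = 2, 6, 8
mel_encoder_output = tf.random.normal([batch, mel_len, depth])
perf_encoder_output = tf.random.normal([batch, 11, depth])

# Mean-pool the performance encoding over time, tile it across the melody
# length, then concatenate along the depth axis as in the "tile" branch.
perf_mean_vector = tf.reduce_mean(perf_encoder_output, axis=1, keepdims=True)
tiled_mean = tf.tile(perf_mean_vector, [1, mel_len, 1])
encoder_output = tf.concat([mel_encoder_output, tiled_mean], axis=-1)
print(encoder_output.shape)   # (2, 6, 16)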
Example #60
def maybe_flatten4d3d(x):
    xshape = common_layers.shape_list(x)
    return common_layers.flatten4d3d(x) if len(xshape) == 4 else x
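A quick shape check for the helper above, with a plain reshape standing in for common_layers.flatten4d3d (sketch only, TF2 eager assumed):

import tensorflow as tf

def maybe_flatten4d3d_sketch(x):
    # Stand-in for the helper above: merge the two middle axes only when the
    # input is 4-D, otherwise return it unchanged.
    shape = x.shape.as_list()
    if len(shape) != 4:
        return x
    return tf.reshape(x, [shape[0], shape[1] * shape[2], shape[3]])

print(maybe_flatten4d3d_sketch(tf.zeros([2, 5, 7])).shape)     # (2, 5, 7)
print(maybe_flatten4d3d_sketch(tf.zeros([2, 5, 3, 7])).shape)  # (2, 15, 7)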