Example #1
def transformer_prepare_encoder(inputs, hparams):
    """Prepare one shard of the model for the encoder.

    Args:
        inputs: [batch_size, input_length, hidden_dim]
        hparams: hyperparameters

    Returns:
        encoder_input: a Tensor, bottom of encoder stack
            [batch_size, input_length, hidden_dim]
        encoder_self_attention_bias: a bias tensor for use in encoder
            self-attention [batch_size, 1, 1, input_length]
        top_layer_attention_bias: a bias tensor for use in top layer
            classification [batch_size, 1, 1, input_length]
    """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    top_layer_attention_bias = ignore_padding
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(inputs)[1])
    if hparams.pos == "timing":
        encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    return (encoder_input, encoder_self_attention_bias,
            top_layer_attention_bias)
Example #2
def transformer_prepare_encoder(inputs, target_space, hparams):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  encoder_decoder_attention_bias = ignore_padding
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        tf.shape(inputs)[1])
  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      target_space, 32, ishape_static[-1], name="target_space_embedding")
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input += emb_target_space
  if hparams.pos == "timing":
    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
Example #3
def prepare_image_question_encoder(image_feat, question, hparams):
    """Prepare encoder.

  Args:
    image_feat: a Tensor.
    question: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
  """

    encoder_input = tf.concat([image_feat, question], axis=1)
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    # Usual case - not a packed dataset.
    if hparams.pos == "timing":
        question = common_attention.add_timing_signal_1d(question)
    elif hparams.pos == "emb":
        question = common_attention.add_positional_embedding(
            question, hparams.max_length, "inputs_positional_embedding", None)
    encoder_input = tf.concat([image_feat, question], axis=1)

    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
Example #4
def prepare_image_question_encoder(image_feat, question, hparams):
  """Prepare encoder.

  Args:
    image_feat: a Tensor.
    question: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
  """

  encoder_input = tf.concat([image_feat, question], axis=1)
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  encoder_decoder_attention_bias = ignore_padding
  # Usual case - not a packed dataset.
  if hparams.pos == "timing":
    question = common_attention.add_timing_signal_1d(question)
  elif hparams.pos == "emb":
    question = common_attention.add_positional_embedding(
        question, hparams.max_length, "inputs_positional_embedding",
        None)
  encoder_input = tf.concat([image_feat, question], axis=1)

  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
Example #5
def transformer_prepare_encoder2(encoder_input, target_space, hparams,
                                 emb_name):
    """Same as transformer_prepare_encoder, except that the target-space embedding can be given a custom name via emb_name."""
    # compute bias
    ishape_static = encoder_input.shape.as_list()
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(encoder_input)[1])

    # Append target_space_id embedding to encoder_input
    id_values = [
        value for attr, value in vars(problem.SpaceID).items()
        if not attr.startswith("__")
    ]
    id_cur = int(max(id_values) + 1)
    emb_target_space = common_layers.embedding(target_space,
                                               id_cur,
                                               ishape_static[-1],
                                               name=emb_name)
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space

    # position embedding
    if hparams.pos == "timing":
        encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    return encoder_input, encoder_self_attention_bias, encoder_decoder_attention_bias
Example #6
    def sample_p(self,
                 targets_length,
                 temp,
                 check_invertibility=False,
                 targets_mask=None,
                 **kwargs):
        hparams = self._hparams
        if targets_mask is None:
            targets_mask = ops.sequence_mask(targets_length, hparams)
        decoder_self_attention_bias = (
            common_attention.attention_bias_ignore_padding(1.0 - targets_mask))
        batch_size, targets_max_length = (
            common_layers.shape_list(targets_mask)[:2])
        prior_shape = [batch_size, targets_max_length, hparams.latent_size]
        noise = tf.random.normal(prior_shape, stddev=temp)
        p_dist = None

        if hparams.prior_type == "standard_normal":
            z_p = noise
        elif hparams.prior_type == "diagonal_normal":
            diag_prior_params = ops.cond_prior("diag_prior", hparams,
                                               tf.zeros(prior_shape),
                                               targets_mask,
                                               hparams.latent_size * 2,
                                               decoder_self_attention_bias,
                                               **kwargs)
            p_dist = gops.diagonal_normal(diag_prior_params, "diag_prior")
            z_p = p_dist.loc + p_dist.scale * noise
        elif hparams.prior_type in ["affine", "additive", "rq"]:
            n_levels = len(hparams.depths.split("/"))
            divi = max(1, hparams.factor**(n_levels - 1))
            flow_prior_shape = [
                batch_size, targets_max_length // divi, hparams.latent_size
            ]
            noise = tf.random_normal(flow_prior_shape, stddev=temp)
            z_p, _, _, _ = glow.glow("glow",
                                     noise,
                                     targets_mask,
                                     decoder_self_attention_bias,
                                     inverse=True,
                                     init=False,
                                     hparams=self._fparams,
                                     disable_dropout=True,
                                     temp=temp,
                                     **kwargs)
            if self.is_evaluating and check_invertibility:
                noise_inv, _, _, _ = glow.glow("glow",
                                               z_p,
                                               targets_mask,
                                               decoder_self_attention_bias,
                                               inverse=False,
                                               init=False,
                                               hparams=self._fparams,
                                               disable_dropout=True,
                                               **kwargs)
                z_diff = noise - noise_inv
                tf.summary.scalar("flow_recon_inverse",
                                  tf.reduce_max(tf.abs(z_diff)))
        return z_p, p_dist
Example #7
def get_attention_bias(sequence_length):
  """Create attention bias so attention is not applied at padding position."""
  # attention_bias: [batch, 1, 1, memory_length]
  invert_sequence_mask = tf.to_float(tf.logical_not(tf.sequence_mask(
      sequence_length)))
  attention_bias = common_attention.attention_bias_ignore_padding(
      invert_sequence_mask)
  return attention_bias
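A minimal, self-contained sketch of the same computation without tensor2tensor, assuming plain TensorFlow and a hypothetical helper name attention_bias_from_lengths: positions beyond sequence_length receive a large negative bias so the softmax assigns them near-zero weight.

import tensorflow as tf

def attention_bias_from_lengths(sequence_length, maxlen=None):
  """Illustrative stand-in for get_attention_bias above (not tensor2tensor code)."""
  # mask is 1.0 at real positions and 0.0 at padding: [batch, memory_length].
  mask = tf.cast(tf.sequence_mask(sequence_length, maxlen=maxlen), tf.float32)
  # Padding positions get -1e9; result shape is [batch, 1, 1, memory_length].
  return tf.expand_dims(tf.expand_dims((1.0 - mask) * -1e9, 1), 1)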
Example #8
def transformer_prepare_decoder_right(targets, hparams, features=None):
    """Prepare one shard of the model for the decoder.

  Args:
    targets: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    decoder_input: a Tensor, bottom of decoder stack
    decoder_self_attention_bias: a bias tensor for use in decoder self-attention
  """
    if hparams.causal_decoder_self_attention:
        # Causal attention.
        if hparams.prepend_mode == "prepend_inputs_full_attention":
            decoder_self_attention_bias = (
                common_attention.attention_bias_prepend_inputs_full_attention(
                    common_attention.embedding_to_padding(targets)))
        else:
            decoder_self_attention_bias = (
                common_attention.attention_bias_local(
                    common_layers.shape_list(targets)[1], 0, -1))
    else:
        # Full attention.
        decoder_padding = common_attention.embedding_to_padding(targets)
        decoder_self_attention_bias = (
            common_attention.attention_bias_ignore_padding(decoder_padding))

    if features and "targets_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        targets_segmentation = features["targets_segmentation"]
        targets_position = features["targets_position"]
        decoder_self_attention_bias += common_attention.attention_bias_same_segment(
            targets_segmentation, targets_segmentation)
    else:
        targets_position = None
    if hparams.proximity_bias:
        decoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(targets)[1])
    decoder_input = shift_left_3d(targets)
    if hparams.pos == "timing":
        if targets_position is not None:
            decoder_input = common_attention.add_timing_signal_1d_given_position(
                decoder_input, targets_position)
        else:
            decoder_input = common_attention.add_timing_signal_1d(
                decoder_input)
    elif hparams.pos == "emb":
        decoder_input = common_attention.add_positional_embedding(
            decoder_input, hparams.max_length, "targets_positional_embedding",
            targets_position)

    if hparams.activation_dtype == "bfloat16":
        decoder_self_attention_bias = tf.cast(decoder_self_attention_bias,
                                              tf.bfloat16)
    return (decoder_input, decoder_self_attention_bias)
Example #9
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    if features and "inputs_segmentation" in features:
        # Packed dataset.  Keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        targets_segmentation = features["targets_segmentation"]
        encoder_self_attention_bias = common_attention.attention_bias_same_segment(
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = (
            common_attention.attention_bias_same_segment(
                targets_segmentation, inputs_segmentation))
    else:
        # Usual case - not a packed dataset.
        encoder_padding = common_attention.embedding_to_padding(encoder_input)
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(inputs)[1])
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(target_space,
                                               32,
                                               ishape_static[-1],
                                               name="target_space_embedding")
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
    if hparams.pos == "timing":
        if inputs_position is not None:
            encoder_input = common_attention.add_timing_signal_1d_given_position(
                encoder_input, inputs_position)
        else:
            encoder_input = common_attention.add_timing_signal_1d(
                encoder_input)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
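The packed-dataset branch above keeps packed examples from attending to one another. Below is a minimal sketch of that idea in plain TensorFlow, as an illustrative emulation rather than the tensor2tensor implementation of attention_bias_same_segment.

import tensorflow as tf

def same_segment_bias(query_segment_ids, memory_segment_ids):
  """Bias that blocks attention across different packed examples (illustrative)."""
  # same[b, q, m] is 1.0 when query position q and memory position m belong to
  # the same packed example, otherwise 0.0.
  same = tf.cast(
      tf.equal(query_segment_ids[:, :, None], memory_segment_ids[:, None, :]),
      tf.float32)
  # Cross-segment pairs get a large negative bias: [batch, 1, query_len, memory_len].
  return tf.expand_dims((1.0 - same) * -1e9, 1)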
Example #10
def get_ignore_padding(inputs):
    """
    Args:
        inputs: Tensor with shape [batch, memory_length, depth]
    """
    # Extract which individual embedding vectors are identically zero.
    # encoder_padding has shape [batch, memory_length].
    padding = comm_attn.embedding_to_padding(inputs)
    # ignore_padding has shape [batch, 1, 1, memory_length].
    # Each 1 in `padding` is replaced with -1e9 so those positions are
    # effectively masked out once the bias is added to the attention logits.
    ignore_padding = comm_attn.attention_bias_ignore_padding(padding)
    return ignore_padding
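To make the comments above concrete, here is a small self-contained sketch (not taken from any of the listed repositories) of how such a bias is typically consumed: it is broadcast-added to the attention logits before the softmax, so padded memory positions receive near-zero attention weight.

import tensorflow as tf

batch, heads, query_len, memory_len = 2, 4, 5, 7
logits = tf.random.normal([batch, heads, query_len, memory_len])
# 1.0 marks a padded memory position, 0.0 a real one (emulating embedding_to_padding).
padding = tf.constant([[0., 0., 0., 1., 1., 1., 1.],
                       [0., 0., 0., 0., 0., 1., 1.]])
# Emulate attention_bias_ignore_padding: [batch, 1, 1, memory_len] with -1e9 at padding.
bias = tf.expand_dims(tf.expand_dims(padding * -1e9, 1), 1)
weights = tf.nn.softmax(logits + bias)  # padded positions get ~0 attention weight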
Example #11
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  if features and "inputs_segmentation" in features:
    # Packed dataset.  Keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
    encoder_self_attention_bias = common_attention.attention_bias_same_segment(
        inputs_segmentation, inputs_segmentation)
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(
            targets_segmentation, inputs_segmentation))
  else:
    # Usual case - not a packed dataset.
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    inputs_position = None
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])
  # Append target_space_id embedding to inputs.
  emb_target_space = common_layers.embedding(
      target_space, 32, ishape_static[-1], name="target_space_embedding")
  emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
  encoder_input += emb_target_space
  if hparams.pos == "timing":
    if inputs_position is not None:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
    else:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
Example #12
def transformer_prepare_encoder(inputs_emb_var,
                                inputs,
                                hparams,
                                features=None):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs_emb_var: a Tensor
    inputs: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    encoder_input = tf.gather(inputs_emb_var, inputs)

    if features and "inputs_segmentation" in features:
        # Packed dataset.  Keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        targets_segmentation = features["targets_segmentation"]
        encoder_self_attention_bias = common_attention.attention_bias_same_segment(
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = (
            common_attention.attention_bias_same_segment(
                targets_segmentation, inputs_segmentation))
    else:
        # Usual case - not a packed dataset.
        encoder_padding = tf.to_float(tf.equal(inputs, 0))
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        encoder_input = common_attention.add_positional_embedding(
            encoder_input, hparams.max_length, "positional_embedding",
            inputs_position)
    if hparams.activation_dtype == "bfloat16":
        encoder_self_attention_bias = tf.cast(encoder_self_attention_bias,
                                              tf.bfloat16)
        encoder_decoder_attention_bias = tf.cast(
            encoder_decoder_attention_bias, tf.bfloat16)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
Example #13
  def model_fn_body(self, features):
   
    hparams = self._hparams
    inputs = features.get("inputs")
    firstP = features.get("firstP")
    firstP = common_layers.flatten4d3d(firstP)
    targets = features["targets"]
    targets = common_layers.flatten4d3d(targets)
    #JI: set image dimensions (img_dim is a constant defined elsewhere in the original code)
    imageP = features.get("imageP")
    imageP.set_shape([None, 1, 19600])
    imageP = tf.reshape(imageP, [-1, img_dim, 100])
    
    encoder_output, encoder_decoder_attention_bias = (None, None)
    if inputs is not None:
      target_space = features["target_space_id"]
      #JI: if needed pass images to encoder
      encoder_output, encoder_decoder_attention_bias = self.encode(inputs, target_space, hparams, imageP=None)

    # used to extract hidden states
    (decoder_input, decoder_self_attention_bias) = transformer_prepare_decoder(firstP, hparams)
    # the conventional `targets` used for the second-pass decoder, i.e., delib-decoder
    (delibdecoder_input, delibdecoder_self_attention_bias) = transformer_prepare_decoder(targets, hparams)
    # the `delibctx` used for the second-pass decoder
    firstP_input, firstP_self_attention_bias = self.transformer_prepare_delibdecoder(firstP, hparams)
    
    # add dropout to the two decoders
    decoder_input = tf.nn.dropout(decoder_input, 1.0 - hparams.layer_prepostprocess_dropout)
    delibdecoder_input = tf.nn.dropout(delibdecoder_input, 1.0 - hparams.layer_prepostprocess_dropout)

    decoder_output = transformer_decoder(decoder_input,
                                         encoder_output,
                                         decoder_self_attention_bias,
                                         encoder_decoder_attention_bias,
                                         hparams,
                                         cache=None)
    
    firstP_input = tf.concat(values=[firstP_input, decoder_output], axis=-1)
    
    #JI: get biases for image attention
    img_encoder_padding = common_attention.embedding_to_padding(imageP)
    imageP_self_attention_bias = common_attention.attention_bias_ignore_padding(img_encoder_padding)

    #JI: pass images to the decoder
    delibdecoder_output = transformer_delibdecoder(
        delibdecoder_input, encoder_output, firstP_input, imageP,
        delibdecoder_self_attention_bias, encoder_decoder_attention_bias, firstP_self_attention_bias, imageP_self_attention_bias,
        hparams, cache=None, name="delib_decoder")
    return delibdecoder_output
Example #14
  def transformer_prepare_delibdecoder(self, inputs, hparams):
    """Prepare one shard of the model for the encoder.
    Args:
    inputs: a Tensor.
    hparams: run hyperparameters
    Returns:
    """
    firstPdecoder_input = inputs
    firstPdecoder_padding = common_attention.embedding_to_padding(firstPdecoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(firstPdecoder_padding)
    firstP_delib_attention_bias = ignore_padding
    if hparams.pos == "timing":
      firstPdecoder_input = common_attention.add_timing_signal_1d(firstPdecoder_input)

    return (firstPdecoder_input, firstP_delib_attention_bias)
Example #15
    def encode(self, inputs, hparams, features=None):
        train = hparams.mode == tf.estimator.ModeKeys.TRAIN
        inputs_length = common_layers.length_from_embedding(inputs)

        # Flatten inputs.
        inputs = common_layers.flatten4d3d(inputs)

        encoder_padding = common_attention.embedding_to_padding(inputs)
        encoder_decoder_attention_bias = common_attention.attention_bias_ignore_padding(
            encoder_padding)

        # LSTM encoder.
        encoder_outputs, final_encoder_state = lstm_bid_encoder(
            inputs, inputs_length, self._hparams, train, "encoder")

        return encoder_outputs, final_encoder_state, encoder_decoder_attention_bias, inputs_length
Example #16
  def compute_iw_marginal(
      self, targets, targets_mask, decoder_self_attention_bias, features,
      n_samples, reduce_mean=True, **kwargs):
    hparams = self._hparams
    z_q, log_q_z, _ = self.sample_q(
        targets, targets_mask, decoder_self_attention_bias,
        n_samples=n_samples, temp=1.0, **kwargs)  # [K*B, L, C]
    iw_kwargs = {key: ops.prepare_for_iw(value, n_samples) for (
        key, value) in kwargs.items()}
    iw_targets_mask = ops.prepare_for_iw(targets_mask, n_samples)
    iw_decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(1.0 - iw_targets_mask))
    iw_features = copy.copy(features)
    iw_features["targets"] = ops.prepare_for_iw(
        features["targets"], n_samples)

    log_p_z_base, log_abs_det = self.compute_prior_log_prob(
        z_q, iw_targets_mask, iw_decoder_self_attention_bias,
        check_invertibility=False, **iw_kwargs)
    log_p_z = log_p_z_base + log_abs_det

    body_output = ops.decoder(
        "decoder", z_q, hparams, iw_decoder_self_attention_bias, **iw_kwargs)
    logits = self.top(body_output, iw_features)
    numerator, denominator = self.loss_iw(logits, iw_features)
    numerator = tf.reduce_sum(numerator[..., 0, 0], 1)  # [K*B]
    denominator = tf.reduce_sum(denominator[..., 0, 0], 1)  # [K*B]
    log_p_x = -1 * numerator / denominator
    log_q_z = gops.reduce_mean_over_l_sum_over_c(log_q_z, iw_targets_mask)
    log_p_z = log_p_z / tf.reduce_sum(iw_targets_mask, 1)

    log_p_x, log_q_z, log_p_z = [ops.unprepare_for_iw(ii, n_samples) for ii in [
        log_p_x, log_q_z, log_p_z]]

    log_w_n = log_p_z - log_q_z
    log_w_n = tf.nn.log_softmax(log_w_n, axis=0)  # [K, B]

    iw_marginal = log_p_x + log_w_n
    iw_marginal = tf.reduce_logsumexp(iw_marginal, 0)  # [B]

    if reduce_mean:
      iw_marginal = tf.cast(tf.reduce_mean(iw_marginal, 0), tf.float32)  # [1]
    else:
      iw_marginal = tf.cast(iw_marginal, tf.float32)  # [1]
    return iw_marginal
Example #17
  def model_fn_body(self, features):
    """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "tragets": Target decoder outputs.
              [batch_size, decoder_length, hidden_dim]
          "target_space_id"

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
    #JI: set image shapes (img_dim is a constant defined elsewhere in the original code)
    imageP = features.get("imageP")
    imageP.set_shape([None, 1, 19600])
    imageP = tf.reshape(imageP, [-1, img_dim, 100])

    hparams = self._hparams

    inputs = features.get("inputs")
    encoder_output, encoder_decoder_attention_bias = (None, None)
    if inputs is not None:
      target_space = features["target_space_id"]
      # JI: send images to encoder if needed
      encoder_output, encoder_decoder_attention_bias = self.encode(
          inputs, target_space, hparams,imageP=None)

    targets = features["targets"]
    targets = common_layers.flatten4d3d(targets)

    decoder_input, decoder_self_attention_bias = transformer_prepare_decoder(
        targets, hparams)
    
    #JI: compute attention bias for images for decoder
    img_encoder_padding = common_attention.embedding_to_padding(imageP)
    imageP_decoder_self_attention_bias = common_attention.attention_bias_ignore_padding(img_encoder_padding)
    
    #JI: send images for decoder if needed
    return self.decode(decoder_input, encoder_output,
                       encoder_decoder_attention_bias,
                       decoder_self_attention_bias, hparams, imageP=imageP, imageP_decoder_self_attention_bias=imageP_decoder_self_attention_bias)
Example #18
    def forward(self, contexts_emb, contexts, abbr_inp_emb, longform_emb=None):
        """
        :param contexts_emb: [batch_size, context_len, emb_dim]
        :param contexts: a list of tensors of words, [batch_size] * context_len
        :param abbr_inp_emb: [batch_size, 1, emb_dim]
        :param longform_emb: [batch_size, longform_len, emb_dim]
        :return:
               decoder_output: predicted abbr embedding, [batch_size, 1, emb_dim]
        """
        saved_weights = {}
        extra_loss = None

        contexts_bias = common_attention.attention_bias_ignore_padding(
            tf.to_float(
                tf.equal(tf.stack(contexts, axis=1),
                         self.voc.encode(constant.PAD))))

        contexts_emb = tf.nn.dropout(
            contexts_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)
        abbr_inp_emb = tf.nn.dropout(
            abbr_inp_emb, 1.0 - self.hparams.layer_prepostprocess_dropout)

        # [batch_size, context_len, emb_dim]
        encoder_output = transformer.transformer_encoder(
            contexts_emb,
            contexts_bias,
            hparams=self.hparams,
            save_weights_to=saved_weights)

        # [batch_size, 1, emb_dim]
        decoder_output = transformer.transformer_decoder(
            abbr_inp_emb,
            encoder_output,
            decoder_self_attention_bias=tf.zeros(
                [self.model_config.batch_size, 1, 1, 1]),
            encoder_decoder_attention_bias=contexts_bias,
            hparams=self.hparams,
            save_weights_to=saved_weights)

        return decoder_output, saved_weights, extra_loss
Example #19
def transformer_prepare_encoder(inputs, target_space, hparams):
    """Copied from tensor2tensor.models.transformer."""
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            tf.shape(inputs)[1])
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(target_space,
                                               32,
                                               ishape_static[-1],
                                               name="target_space_embedding")
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
    if hparams.pos == "timing":
        encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
Example #20
def prepare_question_encoder(inputs, hparams):
  """Prepare question encoder.

  Args:
    inputs: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
  """
  encoder_input = inputs
  # Usual case - not a packed dataset.
  encoder_padding = common_attention.embedding_to_padding(encoder_input)
  ignore_padding = common_attention.attention_bias_ignore_padding(
      encoder_padding)
  encoder_self_attention_bias = ignore_padding
  if hparams.pos == "timing":
    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  elif hparams.pos == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        None)
  return (encoder_input, encoder_self_attention_bias)
Example #21
def prepare_question_encoder(inputs, hparams):
    """Prepare question encoder.

  Args:
    inputs: a Tensor.
    hparams: run hyperparameters

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
  """
    encoder_input = inputs
    # Usual case - not a packed dataset.
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    encoder_self_attention_bias = ignore_padding
    if hparams.pos == "timing":
        encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    elif hparams.pos == "emb":
        encoder_input = common_attention.add_positional_embedding(
            encoder_input, hparams.max_length, "inputs_positional_embedding",
            None)
    return (encoder_input, encoder_self_attention_bias)
Example #22
  def body(self, features):
    hparams = self._hparams
    ps_devices = self._ps_devices
    single_device = (len(ps_devices) == 1)
    assert hparams.num_model_shards % len(ps_devices) == 0
    shards_per_device = hparams.num_model_shards // len(ps_devices)
    model_devices = [ps_devices[i // shards_per_device]
                     for i in range(hparams.num_model_shards)]
    print("model_devices = %s" % model_devices)
    mp = expert_utils.Parallelism(model_devices, reuse=False)
    targets_vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
    # squeeze out channels, heights
    targets = tf.squeeze(features["targets_raw"], [2, 3])
    targets_embedding_var = mp(
        tf.get_variable, "embedding",
        [[targets_vocab_size, hparams.hidden_size]] * mp.n,
        initializer=tf.random_normal_initializer(
            0.0, hparams.hidden_size**-0.5))
    shifted_targets = common_layers.shift_right_2d(targets)
    # Bypass the symbol modality and use a different embedding on each shard.
    if single_device:
      targets_embedding_var_combined = tf.concat(targets_embedding_var, 1)
      decoder_input_combined = common_layers.embedding(
          shifted_targets, targets_vocab_size,
          hparams.hidden_size * mp.n,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=targets_embedding_var_combined,
      )
      decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
    else:
      targets_embedding_var_combined = None
      decoder_input = mp(
          common_layers.embedding, shifted_targets, targets_vocab_size,
          hparams.hidden_size,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=targets_embedding_var,
      )
    decoder_self_attention_bias = mp(
        common_attention.attention_bias_lower_triangle,
        tf.shape(targets)[1])
    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = features["targets_segmentation"]
      targets_position = features["targets_position"]
      decoder_self_attention_bias = mp(
          tf.add, decoder_self_attention_bias,
          mp(common_attention.attention_bias_same_segment,
             targets_segmentation, targets_segmentation))
      decoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          decoder_input, targets_position)
    else:
      targets_position = None
      decoder_self_attention_bias = mp(
          common_attention.attention_bias_lower_triangle,
          tf.shape(targets)[1])
      decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)

    if self.has_input:
      inputs = tf.squeeze(features["inputs_raw"], [2, 3])
      inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size
      # share everything for now
      share_inputs_and_targets_embedding = True
      if share_inputs_and_targets_embedding:
        assert inputs_vocab_size == targets_vocab_size
        inputs_embedding_var = targets_embedding_var
        inputs_embedding_var_combined = targets_embedding_var_combined
      if single_device:
        encoder_input_combined = common_layers.embedding(
            inputs, inputs_vocab_size,
            hparams.hidden_size * mp.n,
            multiplier=hparams.hidden_size**0.5,
            embedding_var=inputs_embedding_var_combined,
        )
        encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
      else:
        encoder_input = mp(
            common_layers.embedding, inputs, inputs_vocab_size,
            hparams.hidden_size,
            multiplier=hparams.hidden_size**0.5,
            embedding_var=inputs_embedding_var,
        )
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        encoder_self_attention_bias = mp(
            common_attention.attention_bias_same_segment,
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = mp(
            common_attention.attention_bias_same_segment,
            targets_segmentation, inputs_segmentation)
        encoder_input = mp(
            common_attention.add_timing_signal_1d_given_position,
            encoder_input, inputs_position)
      else:
        encoder_padding = tf.to_float(tf.equal(inputs, 0))
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
        encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input)

      # encoder stack here
      with tf.variable_scope("encoder"):
        encoder_input = mp(
            tf.nn.dropout, encoder_input,
            1.0 - hparams.layer_prepostprocess_dropout)
        encoder_output = _layer_stack(
            mp,
            encoder_input,
            encoder_self_attention_bias,
            hparams.encoder_layers,
            hparams)
    else:
      encoder_decoder_attention_bias = None
      encoder_output = None

    with tf.variable_scope("decoder"):
      decoder_input = mp(
          tf.nn.dropout, decoder_input,
          1.0 - hparams.layer_prepostprocess_dropout)
      decoder_output = _layer_stack(
          mp,
          decoder_input,
          decoder_self_attention_bias,
          layers=hparams.decoder_layers,
          hparams=hparams,
          encoder_output=encoder_output,
          encoder_decoder_attention_bias=encoder_decoder_attention_bias)

    # Bypass the symbol modality and compute logits directly.
    # We compute a different set of logits on each shard, and sum them.
    # Share the weights with the target embedding.
    output_var = targets_embedding_var
    output_var_combined = targets_embedding_var_combined
    if single_device:
      decoder_output = tf.concat(decoder_output, 2)
      logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]])
      num, denom = common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)
      training_loss = num / denom
    else:
      logits = mp(
          tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n)
      logits = expert_utils.all_reduce_ring(logits, mp)
      # On each device, we compute the loss for a part of the batch.
      # This is faster than computing the whole loss on one shard.
      mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])
      def _loss_for_shard(logits, targets, shard):
        logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
        targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
        return common_layers.padded_cross_entropy(
            logits, targets, hparams.label_smoothing)
      num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
      training_loss = tf.add_n(num) / tf.add_n(denom)
      logits = logits[0]
    logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
    # override training loss so that it is not computed externally.
    losses = {"training": training_loss}
    return logits, losses
Example #23
    def transformer_fn(self,
                       sentence_complex_input_placeholder, emb_complex,
                       sentence_simple_input_placeholder, emb_simple,
                       w, b,
                       rule_id_input_placeholder, rule_target_input_placeholder,
                       mem_contexts, mem_outputs,
                       global_step, score, comp_features, obj):
        encoder_mask = tf.to_float(
            tf.equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                     self.data.vocab_complex.encode(constant.SYMBOL_PAD)))
        encoder_attn_bias = common_attention.attention_bias_ignore_padding(encoder_mask)

        obj_tensors = {}

        train_mode = self.model_config.train_mode
        if self.model_config.bert_mode:
            # Leave space for decoder when static seq
            gpu_id = 0 if train_mode == 'static_seq' or train_mode == 'static_self-critical' or 'direct' in self.model_config.memory else 1
            with tf.device('/device:GPU:%s' % gpu_id):
                sentence_complex_input = tf.stack(sentence_complex_input_placeholder, axis=1)
                bert_model = BertModel(
                    BertConfig.from_json_file(self.model_config.bert_config),
                    self.is_train, sentence_complex_input,
                    input_mask=1.0-encoder_mask, token_type_ids=None, use_one_hot_embeddings=False)
                encoder_embed_inputs = bert_model.embedding_output
                encoder_outputs = bert_model.sequence_output
                emb_complex = bert_model.embedding_table # update emb complex
                if (self.model_config.tie_embedding == 'all' or
                        self.model_config.tie_embedding == 'enc_dec'):
                    emb_simple = bert_model.embedding_table
                if (self.model_config.tie_embedding == 'all' or
                        self.model_config.tie_embedding == 'dec_out'):
                    emb_w_proj = tf.get_variable(
                        'emb_w_proj', shape=[self.model_config.dimension, self.model_config.dimension],
                        initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                    w = tf.matmul(bert_model.embedding_table, emb_w_proj)

                if 'direct' in self.model_config.memory:
                    with tf.device('/device:GPU:1'):
                        direct_mask = tf.to_float(
                            tf.equal(tf.stack(rule_target_input_placeholder, axis=1),
                                     self.data.vocab_complex.encode(constant.SYMBOL_PAD)))
                        direct_bert_model = BertModel(
                            BertConfig.from_json_file(self.model_config.bert_config),
                            self.is_train, tf.stack(rule_target_input_placeholder, axis=1),
                            input_mask=1.0 - direct_mask, token_type_ids=None, use_one_hot_embeddings=False,
                            embedding_table=emb_simple,
                            scope='direct')
                        direct_bert_output = direct_bert_model.sequence_output
                        obj_tensors['direct_bert_bias'] = common_attention.attention_bias_ignore_padding(direct_mask)
                        obj_tensors['direct_bert_output'] = direct_bert_output
        else:
            encoder_embed_inputs = tf.stack(
                self.embedding_fn(sentence_complex_input_placeholder, emb_complex), axis=1)
            if self.hparams.pos == 'timing':
                encoder_embed_inputs = common_attention.add_timing_signal_1d(encoder_embed_inputs)
                print('Use positional encoding in encoder text.')

            if self.model_config.subword_vocab_size and self.model_config.seg_mode:
                encoder_embed_inputs = common_attention.add_positional_embedding(
                    encoder_embed_inputs, 100, 'seg_embedding',
                    positions=obj['line_comp_segids'])
                print('Add segment embedding.')

            with tf.variable_scope('transformer_encoder'):
                encoder_embed_inputs = tf.nn.dropout(encoder_embed_inputs,
                                                     1.0 - self.hparams.layer_prepostprocess_dropout)

                if self.model_config.architecture == 'ut2t':
                    encoder_outputs, encoder_extra_output = universal_transformer_util.universal_transformer_encoder(
                        encoder_embed_inputs, encoder_attn_bias, self.hparams)
                    enc_ponder_times, enc_remainders = encoder_extra_output
                    extra_encoder_loss = (
                            self.hparams.act_loss_weight *
                            tf.reduce_mean(enc_ponder_times + enc_remainders))
                    if self.is_train:
                        obj_tensors['extra_encoder_loss'] = extra_encoder_loss
                else:
                    encoder_outputs = transformer.transformer_encoder(
                        encoder_embed_inputs, encoder_attn_bias, self.hparams)

                # Update score based on multiplier
                score, pred_score_tuple = self.update_score(
                    score, encoder_outputs=encoder_outputs, encoder_mask=tf.to_float(
                        tf.not_equal(tf.stack(sentence_complex_input_placeholder, axis=1),
                                     self.data.vocab_complex.encode(constant.SYMBOL_PAD))),
                    comp_features=comp_features)

                encoder_outputs = self.update_encoder_embedding(encoder_outputs, score)

        encoder_embed_inputs_list = tf.unstack(encoder_embed_inputs, axis=1)

        with tf.variable_scope('transformer_decoder', reuse=tf.AUTO_REUSE):
            if self.model_config.subword_vocab_size or 'bert_token' in self.model_config.bert_mode:
                go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)[0]
            else:
                go_id = self.data.vocab_simple.encode(constant.SYMBOL_GO)
            batch_go = tf.tile(
                tf.expand_dims(self.embedding_fn(go_id, emb_simple), axis=0),
                [self.model_config.batch_size, 1])

            # For static_seq train_mode
            if self.model_config.npad_mode == 'static_seq':
                with tf.variable_scope('npad'):
                    npad_w = tf.get_variable(
                        'npad_w', shape=[1, self.model_config.dimension, self.model_config.dimension],
                        initializer=tf.contrib.layers.xavier_initializer(), dtype=tf.float32)
                    obj_tensors['npad_w'] = npad_w

            if self.is_train and (train_mode == 'teacher' or
                                  train_mode == 'teachercritical' or
                                  train_mode == 'teachercriticalv2'):
                # General train
                print('Use Generally Process.')
                decoder_embed_inputs_list = self.embedding_fn(
                    sentence_simple_input_placeholder[:-1], emb_simple)
                final_output, decoder_output, cur_context = self.decode_step(
                    decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias,
                    rule_id_input_placeholder, mem_contexts, mem_outputs, global_step, score, batch_go,
                    obj_tensors)

                decoder_logit = (
                        tf.nn.conv1d(final_output, tf.expand_dims(tf.transpose(w), axis=0), 1, 'SAME') +
                        tf.expand_dims(tf.expand_dims(b, axis=0), axis=0))
                decoder_target_list = []
                decoder_logit_list = tf.unstack(decoder_logit, axis=1)
                for logit in decoder_logit_list:
                    decoder_target_list.append(tf.argmax(logit, output_type=tf.int32, axis=-1))

                decoder_output_list = [
                    tf.squeeze(d, 1)
                    for d in tf.split(decoder_output, self.model_config.max_simple_sentence, axis=1)]
                final_output_list = [
                    tf.squeeze(d, 1)
                    for d in tf.split(final_output, self.model_config.max_simple_sentence, axis=1)]

                if self.model_config.pointer_mode:
                    segment_mask = None
                    if 'line_comp_segids' in obj:
                        segment_mask = obj['line_comp_segids']
                    decoder_logit_list = word_distribution(
                        decoder_logit_list, decoder_output_list, encoder_outputs, encoder_embed_inputs,
                        sentence_complex_input_placeholder, obj_tensors, self.model_config, self.data, segment_mask)
            elif self.is_train and (train_mode == 'static_seq' or train_mode == 'static_self-critical'):
                decoder_target_list = []
                decoder_logit_list = []
                decoder_embed_inputs_list = []
                # Will Override for following 3 lists
                final_output_list = []
                decoder_output_list = []
                contexts = []
                sample_target_list = []
                sample_logit_list = []

                gpu_assign_interval = int(self.model_config.max_simple_sentence / 3)
                for step in range(self.model_config.max_simple_sentence):
                    gpu_id = int(step / gpu_assign_interval)
                    if gpu_id > 3:
                        gpu_id = 3
                    gpu_id += 1
                    with tf.device('/device:GPU:%s' % gpu_id):
                        print('Step%s with GPU%s' % (step, gpu_id))
                        final_outputs, _, cur_context = self.decode_step(
                            decoder_embed_inputs_list, encoder_outputs, encoder_attn_bias,
                            rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
                            score, batch_go, obj_tensors)

                        final_output_list = [
                            tf.squeeze(d, 1)
                            for d in tf.split(final_outputs, step+1, axis=1)]
                        final_output = final_output_list[-1]

                        # if self.model_config.npad_mode == 'static_seq':
                        #     final_output = tf.matmul(final_output, npad_w)

                        last_logit_list = self.output_to_logit(final_output, w, b)
                        last_target_list = tf.argmax(last_logit_list, output_type=tf.int32, axis=-1)
                        decoder_logit_list.append(last_logit_list)
                        decoder_target_list.append(last_target_list)
                        decoder_embed_inputs_list.append(
                            tf.stop_gradient(self.embedding_fn(last_target_list, emb_simple)))
                        if train_mode == 'static_self-critical':
                            last_sample_list = tf.multinomial(last_logit_list, 1)
                            sample_target_list.append(last_sample_list)
                            indices = tf.stack(
                                [tf.range(0, self.model_config.batch_size, dtype=tf.int64),
                                 tf.squeeze(last_sample_list)],
                                axis=-1)
                            sample_logit_list.append(tf.gather_nd(tf.nn.softmax(last_logit_list), indices))
            else:
                # Beam Search
                print('Use Beam Search with Beam Search Size %d.' % self.model_config.beam_search_size)
                return self.transformer_beam_search(encoder_outputs, encoder_attn_bias, encoder_embed_inputs_list,
                                                    sentence_complex_input_placeholder, emb_simple, w, b,
                                                    rule_id_input_placeholder, mem_contexts, mem_outputs, global_step,
                                                    score, obj, obj_tensors)

        gt_target_list = sentence_simple_input_placeholder
        output = ModelOutput(
            contexts=cur_context if 'rule' in self.model_config.memory else None,
            encoder_outputs=encoder_outputs,
            decoder_outputs_list=final_output_list if train_mode != 'dynamic_self-critical' else None,
            final_outputs_list=final_output_list if train_mode != 'dynamic_self-critical' else None,
            decoder_logit_list=decoder_logit_list if train_mode != 'dynamic_self-critical' else None,
            gt_target_list=gt_target_list,
            encoder_embed_inputs_list=tf.unstack(encoder_embed_inputs, axis=1),
            decoder_target_list=decoder_target_list,
            sample_logit_list=sampled_logit_list if train_mode == 'dynamic_self-critical' else None,
            sample_target_list=sampled_target_list if train_mode == 'dynamic_self-critical' else None,
            pred_score_tuple=pred_score_tuple if 'pred' in self.model_config.tune_mode else None,
            obj_tensors=obj_tensors,
        )
        return output
Example #24
def transformer_prepare_encoder(inputs,
                                target_space,
                                hparams,
                                features=None,
                                type_ids=None,
                                num_types=None,
                                reuse_target_embedding=tf.AUTO_REUSE):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.
    type_ids: optional, an int64 Tensor of shape [batch, length] that allows
      for adding type embeddings, similar to positional embeddings.
    num_types: optional, an int that decides the number of types in type_ids.
    reuse_target_embedding: option to reuse variable name in the case that
      symbol modalities are reused between inputs/targets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    if features and "inputs_segmentation" in features:
        # Packed dataset.  Keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        targets_segmentation = features["targets_segmentation"]
        if (hasattr(hparams, "unidirectional_encoder")
                and hparams.unidirectional_encoder):
            tf.logging.info("Using unidirectional encoder")
            encoder_self_attention_bias = (
                common_attention.attention_bias_lower_triangle(
                    common_layers.shape_list(inputs)[1]))
        else:
            encoder_self_attention_bias = (
                common_attention.attention_bias_same_segment(
                    inputs_segmentation, inputs_segmentation))
        encoder_decoder_attention_bias = (
            common_attention.attention_bias_same_segment(
                targets_segmentation, inputs_segmentation))
    else:
        encoder_padding = common_attention.embedding_to_padding(encoder_input)
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        if (hasattr(hparams, "unidirectional_encoder")
                and hparams.unidirectional_encoder):
            tf.logging.info("Using unidirectional encoder")
            encoder_self_attention_bias = (
                common_attention.attention_bias_lower_triangle(
                    common_layers.shape_list(inputs)[1]))
        else:
            # Usual case - not a packed dataset.
            encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(inputs)[1])
    if target_space is not None and hparams.get("use_target_space_embedding",
                                                True):
        # Append target_space_id embedding to inputs.
        emb_target_space = common_layers.embedding(
            target_space,
            32,
            ishape_static[-1],
            name="target_space_embedding",
            dtype=hparams.get("activation_dtype", "float32"),
            reuse=reuse_target_embedding)
        emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
        encoder_input += emb_target_space
    if hparams.pos == "timing":
        if inputs_position is not None:
            encoder_input = common_attention.add_timing_signal_1d_given_position(
                encoder_input, inputs_position)
        else:
            encoder_input = common_attention.add_timing_signal_1d(
                encoder_input)
    elif hparams.pos == "timing_from_features":
        encoder_input = common_attention.add_timing_signals_from_features(
            encoder_input, features, hparams.position_features)
    elif hparams.pos == "emb":
        encoder_input = common_attention.add_positional_embedding(
            encoder_input, hparams.max_length, "inputs_positional_embedding",
            inputs_position)

    # Add type embeddings
    if type_ids is not None:
        if not num_types:
            raise ValueError("Need to set num_types as well.")
        encoder_input = common_attention.add_positional_embedding(
            encoder_input, num_types, "inputs_type_embedding", type_ids)

    encoder_self_attention_bias = common_layers.cast_like(
        encoder_self_attention_bias, encoder_input)
    encoder_decoder_attention_bias = common_layers.cast_like(
        encoder_decoder_attention_bias, encoder_input)
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
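
# The function above selects among three additive attention-bias regimes:
# padding-based masking (the usual, non-packed case), same-segment masking for
# "packed" datasets, and a lower-triangular mask for a unidirectional encoder.
# A minimal standalone NumPy sketch of how such bias masks are typically built;
# the helper names and the -1e9 masking constant are illustrative assumptions,
# not the tensor2tensor implementation.
import numpy as np

NEG_INF = -1e9  # large negative value added to masked attention logits

def padding_bias(ids):
    # ids: [batch, length]; id 0 is assumed to be padding.
    pad = (ids == 0).astype(np.float32)
    return pad[:, None, None, :] * NEG_INF            # [batch, 1, 1, length]

def lower_triangle_bias(length):
    # Forbid attending to future positions (unidirectional encoder).
    band = np.tril(np.ones((length, length), dtype=np.float32))
    return (1.0 - band)[None, None, :, :] * NEG_INF   # [1, 1, length, length]

def same_segment_bias(query_seg, memory_seg):
    # Packed datasets: only attend within the same example segment.
    same = (query_seg[:, :, None] == memory_seg[:, None, :]).astype(np.float32)
    return (1.0 - same)[:, None, :, :] * NEG_INF      # [batch, 1, q_len, m_len]

print(padding_bias(np.array([[5, 7, 2, 0, 0]]))[0, 0, 0])  # last two positions masked
print(lower_triangle_bias(3)[0, 0])
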
  def body(self, features):
    hparams = self._hparams
    ps_devices = self._ps_devices
    single_device = (len(ps_devices) == 1)
    assert hparams.num_model_shards % len(ps_devices) == 0
    shards_per_device = hparams.num_model_shards // len(ps_devices)
    model_devices = [ps_devices[i // shards_per_device]
                     for i in range(hparams.num_model_shards)]
    print("model_devices = %s" % model_devices)
    mp = expert_utils.Parallelism(model_devices, reuse=False)
    targets_vocab_size = self._problem_hparams.vocabulary["targets"].vocab_size
    # squeeze out channels, heights
    targets = tf.squeeze(features["targets_raw"], [2, 3])
    targets_embedding_var = mp(
        tf.get_variable, "embedding",
        [[targets_vocab_size, hparams.hidden_size]] * mp.n,
        initializer=tf.random_normal_initializer(
            0.0, hparams.hidden_size**-0.5))
    shifted_targets = common_layers.shift_right_2d(targets)
    # Bypass the symbol modality and use a different embedding on each shard.
    if single_device:
      targets_embedding_var_combined = tf.concat(targets_embedding_var, 1)
      decoder_input_combined = common_layers.embedding(
          shifted_targets, targets_vocab_size,
          hparams.hidden_size * mp.n,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=targets_embedding_var_combined,
      )
      decoder_input = tf.split(decoder_input_combined, mp.n, axis=2)
    else:
      targets_embedding_var_combined = None
      decoder_input = mp(
          common_layers.embedding, shifted_targets, targets_vocab_size,
          hparams.hidden_size,
          multiplier=hparams.hidden_size**0.5,
          embedding_var=targets_embedding_var,
      )
    decoder_self_attention_bias = mp(
        common_attention.attention_bias_lower_triangle,
        tf.shape(targets)[1])
    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = features["targets_segmentation"]
      targets_position = features["targets_position"]
      decoder_self_attention_bias = mp(
          tf.add, decoder_self_attention_bias,
          mp(common_attention.attention_bias_same_segment,
             targets_segmentation, targets_segmentation))
      decoder_input = mp(
          common_attention.add_timing_signal_1d_given_position,
          decoder_input, targets_position)
    else:
      targets_position = None
      decoder_self_attention_bias = mp(
          common_attention.attention_bias_lower_triangle,
          tf.shape(targets)[1])
      decoder_input = mp(common_attention.add_timing_signal_1d, decoder_input)

    if self.has_input:
      inputs = tf.squeeze(features["inputs_raw"], [2, 3])
      inputs_vocab_size = self._problem_hparams.vocabulary["inputs"].vocab_size
      # share everything for now
      share_inputs_and_targets_embedding = True
      if share_inputs_and_targets_embedding:
        assert inputs_vocab_size == targets_vocab_size
        inputs_embedding_var = targets_embedding_var
        inputs_embedding_var_combined = targets_embedding_var_combined
      if single_device:
        encoder_input_combined = common_layers.embedding(
            inputs, inputs_vocab_size,
            hparams.hidden_size * mp.n,
            multiplier=hparams.hidden_size**0.5,
            embedding_var=inputs_embedding_var_combined,
        )
        encoder_input = tf.split(encoder_input_combined, mp.n, axis=2)
      else:
        encoder_input = mp(
            common_layers.embedding, inputs, inputs_vocab_size,
            hparams.hidden_size,
            multiplier=hparams.hidden_size**0.5,
            embedding_var=inputs_embedding_var,
        )
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        encoder_self_attention_bias = mp(
            common_attention.attention_bias_same_segment,
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = mp(
            common_attention.attention_bias_same_segment,
            targets_segmentation, inputs_segmentation)
        encoder_input = mp(
            common_attention.add_timing_signal_1d_given_position,
            encoder_input, inputs_position)
      else:
        encoder_padding = tf.to_float(tf.equal(inputs, 0))
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
        encoder_input = mp(common_attention.add_timing_signal_1d, encoder_input)

      # encoder stack here
      with tf.variable_scope("encoder"):
        encoder_input = mp(
            tf.nn.dropout, encoder_input,
            1.0 - hparams.layer_prepostprocess_dropout)
        encoder_output = _layer_stack(
            mp,
            encoder_input,
            encoder_self_attention_bias,
            hparams.encoder_layers,
            hparams)
    else:
      encoder_decoder_attention_bias = None
      encoder_output = None

    with tf.variable_scope("decoder"):
      decoder_input = mp(
          tf.nn.dropout, decoder_input,
          1.0 - hparams.layer_prepostprocess_dropout)
      decoder_output = _layer_stack(
          mp,
          decoder_input,
          decoder_self_attention_bias,
          layers=hparams.decoder_layers,
          hparams=hparams,
          encoder_output=encoder_output,
          encoder_decoder_attention_bias=encoder_decoder_attention_bias)

    # Bypass the symbol modality and compute logits directly.
    # We compute a different set of logits on each shard, and sum them.
    # Share the weights with the target embedding.
    output_var = targets_embedding_var
    output_var_combined = targets_embedding_var_combined
    if single_device:
      decoder_output = tf.concat(decoder_output, 2)
      logits = tf.tensordot(decoder_output, output_var_combined, [[2], [1]])
      num, denom = common_layers.padded_cross_entropy(
          logits, targets, hparams.label_smoothing)
      training_loss = num / denom
    else:
      logits = mp(
          tf.tensordot, decoder_output, output_var, [[[2], [1]]] * mp.n)
      logits = expert_utils.all_reduce_ring(logits, mp)
      # On each device, we compute the loss for a part of the batch.
      # This is faster than computing the whole loss on one shard.
      mp, logits = expert_utils.reduce_by_device(mp, logits, lambda l: l[0])
      def _loss_for_shard(logits, targets, shard):
        logits = common_layers.approximate_split(logits, mp.n, 0)[shard]
        targets = common_layers.approximate_split(targets, mp.n, 0)[shard]
        return common_layers.padded_cross_entropy(
            logits, targets, hparams.label_smoothing)
      num, denom = mp(_loss_for_shard, logits, targets, range(mp.n))
      training_loss = tf.add_n(num) / tf.add_n(denom)
      logits = logits[0]
    logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)
    # override training loss so that it is not computed externally.
    losses = {"training": training_loss}
    return logits, losses
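
# body() above shards the target embedding across devices, computes a separate
# set of logits on each shard by contracting the decoder output with the tied
# embedding matrix, and averages a padded, label-smoothed cross-entropy.
# A standalone NumPy sketch of those two pieces; the function names, the
# padding id 0, and the omitted normalizing constant are assumptions.
import numpy as np

def tied_logits(decoder_output, embedding_var):
    # decoder_output: [batch, length, hidden]; embedding_var: [vocab, hidden].
    # Output weights are shared with the target embedding: logits = h . E^T.
    return np.tensordot(decoder_output, embedding_var, axes=[[2], [1]])

def padded_smoothed_xent(logits, targets, label_smoothing=0.1):
    # targets: [batch, length] int ids; id 0 is treated as padding.
    vocab = logits.shape[-1]
    log_probs = logits - np.log(np.sum(np.exp(logits), axis=-1, keepdims=True))
    soft = np.full(logits.shape, label_smoothing / (vocab - 1))
    batch_idx, pos_idx = np.indices(targets.shape)
    soft[batch_idx, pos_idx, targets] = 1.0 - label_smoothing
    xent = -np.sum(soft * log_probs, axis=-1)         # [batch, length]
    weights = (targets != 0).astype(np.float32)       # ignore padded positions
    return np.sum(xent * weights), np.sum(weights)    # (numerator, denominator)

num, denom = padded_smoothed_xent(
    tied_logits(np.random.randn(2, 4, 8), np.random.randn(10, 8)),
    np.array([[1, 2, 3, 0], [4, 5, 0, 0]]))
print(num / denom)
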
    def create_model(self):
        with tf.variable_scope('variables'):
            abstr_ph = []
            for _ in range(self.model_config.max_abstr_len):
                abstr_ph.append(tf.zeros(self.model_config.batch_size, tf.int32, name='abstract_input'))

            kwords_ph = []
            for _ in range(self.model_config.max_cnt_kword):
                kword = []
                for _ in range(self.model_config.max_kword_len):
                    kword.append(tf.zeros(self.model_config.batch_size, tf.int32, name='kword_input'))
                kwords_ph.append(kword)

            # Train for length control
            if self.is_train:
                kword_occupies_ph = []
                for _ in range(self.model_config.max_cnt_kword):
                    kword_occupies_ph.append(
                        tf.zeros(self.model_config.batch_size, tf.float32, name='kword_occupy_input'))

            emb_abstr, emb_kword, proj_w, proj_b = self.get_embedding()
            abstr = tf.stack(self.embedding_fn(abstr_ph, emb_abstr), axis=1)
            kwords = []
            for kword_idx in range(self.model_config.max_cnt_kword):
                kwords.append(self.embedding_fn(kwords_ph[kword_idx], emb_kword))

        with tf.variable_scope('model_encoder'):
            if self.hparams.pos == 'timing':
                abstr = common_attention.add_timing_signal_1d(abstr)
            encoder_embed_inputs = tf.nn.dropout(abstr,
                                                 1.0 - self.hparams.layer_prepostprocess_dropout)
            abstr_bias = common_attention.attention_bias_ignore_padding(
                tf.to_float(tf.equal(tf.stack(abstr_ph, axis=1),
                                     self.voc_kword.encode(constant.SYMBOL_PAD))))
            abstr_outputs = transformer.transformer_encoder(
                encoder_embed_inputs, abstr_bias, self.hparams)

        losses = []
        targets = []
        pred_occupies = []
        obj = {}

        hist_vector = None
        if 'kp_attn' in self.model_config.cov_mode:
            hist_vector = tf.zeros(
                [self.model_config.batch_size, 1, self.model_config.dimension,])

        with tf.variable_scope('model_decoder'):
            if self.model_config.subword_vocab_size:
                go_id = self.voc_kword.encode(constant.SYMBOL_GO)[0]
            else:
                go_id = self.voc_kword.encode(constant.SYMBOL_GO)
            batch_go = tf.tile(
                tf.expand_dims(self.embedding_fn(go_id, emb_kword), axis=0),
                [self.model_config.batch_size, 1])

            for kword_idx in range(self.model_config.max_cnt_kword):
                if self.is_train:
                    kword = kwords[kword_idx][:-1]
                    kword_ph = kwords_ph[kword_idx]
                    kword_output, kword_output_list = self.decode_step(
                        kword, abstr_outputs, abstr_bias, batch_go, hist_vector=hist_vector)
                    kword_logit_list = [self.output_to_logit(o, proj_w, proj_b) for o in kword_output_list]
                    kword_target_list = [tf.argmax(o, output_type=tf.int32, axis=-1)
                                         for o in kword_logit_list]

                    kword_lossbias = [
                        tf.to_float(tf.not_equal(d, self.voc_kword.encode(constant.SYMBOL_PAD)))
                        for d in kword_ph]
                    kword_lossbias = tf.stack(kword_lossbias, axis=1)
                    if self.model_config.number_samples > 0:
                        loss_fn = tf.nn.sampled_softmax_loss
                    else:
                        loss_fn = None
                    loss = sequence_loss(logits=tf.stack(kword_logit_list, axis=1),
                                         targets=tf.stack(kword_ph, axis=1),
                                         weights=kword_lossbias,
                                         softmax_loss_function=loss_fn,
                                         w=proj_w,
                                         b=proj_b,
                                         decoder_outputs=tf.stack(kword_output_list, axis=1),
                                         number_samples=self.model_config.number_samples
                                         )
                    kword_target = tf.stack(kword_target_list, axis=1)
                    targets.append(kword_target)

                    if 'kp_attn' in self.model_config.cov_mode:
                        kword_embed = self.embedding_fn(kword_ph, emb_kword)
                        hist_vector += tf.expand_dims(tf.reduce_mean(
                            tf.stack(kword_embed, axis=1), axis=1), axis=1)

                    # Train for length control
                    pred_occupy = self.get_pred_occupy_logit(hist_vector, abstr_outputs)
                    occupy_loss = tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=pred_occupy, labels=kword_occupies_ph[kword_idx])
                    loss += tf.reduce_mean(occupy_loss)
                    pred_occupies.append(pred_occupy)

                    losses.append(loss)
                else:
                    loss, kword_target = self.transformer_beam_search(
                        abstr_outputs, abstr_bias, emb_kword, proj_w, proj_b, hist_vector=hist_vector)

                    targets.append(kword_target)
                    losses = loss

                    if 'kp_attn' in self.model_config.cov_mode:
                        kword_embed = self.embedding_fn(kword_target, emb_kword)
                        hist_vector += tf.expand_dims(tf.reduce_mean(kword_embed, axis=1), axis=1)

                    pred_occupy = tf.round(tf.sigmoid(self.get_pred_occupy_logit(hist_vector, abstr_outputs)))
                    pred_occupies.append(pred_occupy)

                tf.get_variable_scope().reuse_variables()
        if targets:
            obj['targets'] = tf.stack(targets, axis=1)
        obj['abstr_ph'] = abstr_ph
        obj['kwords_ph'] = kwords_ph
        if self.is_train:
            obj['kword_occupies_ph'] = kword_occupies_ph
        pred_occupies = tf.stack(pred_occupies, axis=1)
        obj['pred_occupies'] = pred_occupies

        if type(losses) is list:
            losses = tf.add_n(losses)
        return losses, obj
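
# create_model() adds a "length control" head: for every keyword slot it
# predicts an occupancy logit and adds a sigmoid cross-entropy term to the
# per-keyword sequence loss, so the model learns how many keywords to emit.
# A minimal NumPy sketch of that extra term; the numbers are illustrative and
# the stable formula mirrors tf.nn.sigmoid_cross_entropy_with_logits.
import numpy as np

def sigmoid_xent(logits, labels):
    # max(x, 0) - x*z + log(1 + exp(-|x|)), numerically stable.
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

pred_occupy_logits = np.array([2.1, 0.3, -1.7, -3.0])   # one logit per keyword slot
occupy_labels = np.array([1.0, 1.0, 0.0, 0.0])          # 1.0 = slot should be filled
print(sigmoid_xent(pred_occupy_logits, occupy_labels).mean())  # added to the sequence loss
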
def hierarchical_attention_network_encoder(
        encoder_input,
        encoder_self_attention_bias,
        contexts,
        context_self_attention_biases,
        features,
        hparams,
        name="hierarchical_attention_network_encoder",
        save_weights_to=None,
        make_image_summary=True,
        losses=None):
    input_x = encoder_input
    context_xs = {}
    for context_name in contexts:
        context_xs[context_name] = contexts[context_name]
    context_paddings = {}
    context_nonpaddings = {}
    context_pad_removers = {}

    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        input_padding = common_attention.attention_bias_to_padding(
            encoder_self_attention_bias)
        input_nonpadding = 1.0 - input_padding
        for context_name in context_self_attention_biases:
            context_paddings[
                context_name] = common_attention.attention_bias_to_padding(
                    context_self_attention_biases[context_name])
            context_nonpaddings[
                context_name] = 1.0 - context_paddings[context_name]

        input_pad_remover = None
        for context_name in context_paddings:
            context_pad_removers[context_name] = None
        if hparams.use_pad_remover and not common_layers.is_xla_compiled():
            input_pad_remover = expert_utils.PadRemover(input_padding)
            for context_name in context_paddings:
                context_pad_removers[context_name] = expert_utils.PadRemover(
                    context_paddings[context_name])

        temp_hparam = tf.contrib.training.HParams(
        )  # copy hparams except num_hidden_layers -> num_hidden_layers - 1
        for key, val in hparams.values().items():
            temp_hparam.add_hparam(key, val)
        temp_hparam.set_hparam("num_hidden_layers",
                               hparams.num_hidden_layers - 1)
        encoder_output = transformer_with_contexts_layers.transformer_encoder(
            input_x,
            encoder_self_attention_bias,
            temp_hparam,
            nonpadding=features_to_nonpadding(features, "inputs"),
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary)

        context_encoded_outputs = {}
        for context_name in context_xs:
            context_encoded_outputs[
                context_name] = transformer_with_contexts_layers.transformer_encoder(
                    context_xs[context_name],
                    context_self_attention_biases[context_name],
                    hparams,
                    nonpadding=features_to_nonpadding(features, context_name),
                    save_weights_to=save_weights_to,
                    make_image_summary=make_image_summary)

        with tf.variable_scope('word_abstraction', reuse=tf.AUTO_REUSE):
            encoder_word_level_query = common_layers.dense(
                encoder_output, hparams.hidden_size)  # q_w = f_w(h_t)
            encoder_word_level_abstraction = {}
            for context_name in context_encoded_outputs:
                encoder_word_level_abstraction[
                    context_name] = transformer_with_contexts_layers.multihead_attention(
                        common_layers.layer_preprocess(
                            encoder_word_level_query, hparams),
                        context_encoded_outputs[context_name],
                        context_self_attention_biases[context_name],
                        hparams.attention_key_channels or hparams.hidden_size,
                        hparams.attention_value_channels
                        or hparams.hidden_size,
                        hparams.hidden_size,
                        hparams.num_heads,
                        hparams.attention_dropout,
                        attention_type=hparams.self_attention_type,
                        save_weights_to=save_weights_to,
                        make_image_summary=make_image_summary,
                        max_relative_position=hparams.max_relative_position,
                        dropout_broadcast_dims=attention_dropout_broadcast_dims,
                        max_length=hparams.get("max_length"),
                        vars_3d=hparams.get("attention_variables_3d"))  # s^j,

            sentence_information = tf.concat(
                [encoder_word_level_abstraction[context_name]
                 for context_name in encoder_word_level_abstraction],
                axis=1)

        with tf.variable_scope('sentence_abstraction', reuse=tf.AUTO_REUSE):
            encoder_sentence_level_query = common_layers.dense(
                encoder_output, hparams.hidden_size)  # q_s = f_s(h_t)
            context_padding = common_attention.embedding_to_padding(
                sentence_information)
            ignore_padding = common_attention.attention_bias_ignore_padding(
                context_padding)
            contextual_information = transformer_with_contexts_layers.multihead_attention(
                common_layers.layer_preprocess(encoder_sentence_level_query,
                                               hparams),
                sentence_information,
                ignore_padding,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                attention_type=hparams.self_attention_type,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                max_relative_position=hparams.max_relative_position,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d")
            )  # MultiHead(q_s, s^j), [batch, encoder_length, hidden_dim]

            contextual_information = common_layers.dense_relu_dense(
                contextual_information, hparams.filter_size,
                hparams.hidden_size)

        with tf.variable_scope('context_gating', reuse=tf.AUTO_REUSE):
            gate_lambda = tf.nn.sigmoid(
                common_layers.dense(contextual_information,
                                    hparams.hidden_size) +
                common_layers.dense(encoder_output, hparams.hidden_size))
            encoder_output = gate_lambda * encoder_output + (
                1 - gate_lambda) * contextual_information

    return common_layers.layer_preprocess(encoder_output, hparams)
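
# The encoder above ends with context gating: gate_lambda = sigmoid(dense(c) +
# dense(h)), followed by a convex combination of the token representation h and
# the contextual summary c. A standalone NumPy sketch of that gate; the weight
# matrices are illustrative stand-ins for common_layers.dense (biases omitted).
import numpy as np

def context_gate(h, c, w_h, w_c):
    gate = 1.0 / (1.0 + np.exp(-(c @ w_c + h @ w_h)))  # element-wise gate in (0, 1)
    return gate * h + (1.0 - gate) * c

hidden = 4
h = np.random.randn(2, 5, hidden)      # encoder_output
c = np.random.randn(2, 5, hidden)      # contextual_information
print(context_gate(h, c, np.random.randn(hidden, hidden),
                   np.random.randn(hidden, hidden)).shape)  # (2, 5, 4)
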
def hierarchical_context_encoder(encoder_input,
                                 encoder_self_attention_bias,
                                 contexts,
                                 context_self_attention_biases,
                                 features,
                                 hparams,
                                 name="discourse_aware_encoder",
                                 save_weights_to=None,
                                 make_image_summary=True,
                                 losses=None):
    input_x = encoder_input
    context_xs = {}
    for context_name in contexts:
        context_xs[context_name] = contexts[context_name]
    context_paddings = {}
    context_nonpaddings = {}
    context_pad_removers = {}

    attention_dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        input_padding = common_attention.attention_bias_to_padding(
            encoder_self_attention_bias)
        input_nonpadding = 1.0 - input_padding
        for context_name in context_self_attention_biases:
            context_paddings[
                context_name] = common_attention.attention_bias_to_padding(
                    context_self_attention_biases[context_name])
            context_nonpaddings[
                context_name] = 1.0 - context_paddings[context_name]

        input_pad_remover = None
        for context_name in context_paddings:
            context_pad_removers[context_name] = None
        if hparams.use_pad_remover and not common_layers.is_xla_compiled():
            input_pad_remover = expert_utils.PadRemover(input_padding)
            for context_name in context_paddings:
                context_pad_removers[context_name] = expert_utils.PadRemover(
                    context_paddings[context_name])

        temp_hparam = tf.contrib.training.HParams(
        )  # copy hparams except num_hidden_layers -> num_hidden_layers - 1
        for key, val in hparams.values().items():
            temp_hparam.add_hparam(key, val)
        temp_hparam.set_hparam("num_hidden_layers",
                               hparams.num_hidden_layers - 1)
        encoder_output = transformer_with_contexts_layers.transformer_encoder(
            input_x,
            encoder_self_attention_bias,
            temp_hparam,
            nonpadding=features_to_nonpadding(features, "inputs"),
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary)

        context_encoded_outputs = {}
        for context_name in context_xs:
            context_encoded_outputs[
                context_name] = transformer_with_contexts_layers.transformer_encoder(
                    context_xs[context_name],
                    context_self_attention_biases[context_name],
                    temp_hparam,
                    nonpadding=features_to_nonpadding(features, context_name),
                    save_weights_to=save_weights_to,
                    make_image_summary=make_image_summary)

        with tf.variable_scope("hierarchical_context_encoder",
                               reuse=tf.AUTO_REUSE):
            for context_name in context_encoded_outputs:
                # self attention feed-forward
                _y = ffn_self_attention_layer(
                    context_encoded_outputs[context_name],
                    hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    save_weights_to=save_weights_to,
                    name="attentive_sum")
                # mean over sequence length
                context_encoded_outputs[context_name] = tf.reduce_mean(
                    _y, axis=1, keep_dims=True)

            encoded_contexts = [
                context_encoded_outputs[context_name]
                for context_name in context_encoded_outputs
            ]
            encoded_contexts = tf.concat(encoded_contexts, axis=1)

            temp_hparam = tf.contrib.training.HParams(
            )  # copy hparams except num_hidden_layers -> 1
            for key, val in hparams.values().items():
                temp_hparam.add_hparam(key, val)
            temp_hparam.set_hparam("num_hidden_layers", 1)
            context_padding = common_attention.embedding_to_padding(
                encoded_contexts)
            ignore_padding = common_attention.attention_bias_ignore_padding(
                context_padding)

            encoded_contexts = transformer_encoder(encoded_contexts,
                                                   ignore_padding, temp_hparam)

        with tf.variable_scope("encoder/layer_%d" % hparams.num_hidden_layers,
                               reuse=tf.AUTO_REUSE):
            with tf.variable_scope("context_input_attention"):
                context_padding = common_attention.embedding_to_padding(
                    encoded_contexts)
                ignore_padding = common_attention.attention_bias_ignore_padding(
                    context_padding)
                _y = common_attention.multihead_attention(
                    common_layers.layer_preprocess(encoder_output, hparams),
                    encoded_contexts,
                    ignore_padding,
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    attention_type=hparams.self_attention_type,
                    save_weights_to=save_weights_to,
                    make_image_summary=make_image_summary,
                    max_relative_position=hparams.max_relative_position,
                    dropout_broadcast_dims=attention_dropout_broadcast_dims,
                    max_length=hparams.get("max_length"),
                    vars_3d=hparams.get("attention_variables_3d"))
                encoded_contexts = common_layers.layer_postprocess(
                    encoder_output, _y, hparams)

            with tf.variable_scope("input_self_attention"):
                _y = common_attention.multihead_attention(
                    common_layers.layer_preprocess(encoder_output, hparams),
                    None,
                    encoder_self_attention_bias,
                    hparams.attention_key_channels or hparams.hidden_size,
                    hparams.attention_value_channels or hparams.hidden_size,
                    hparams.hidden_size,
                    hparams.num_heads,
                    hparams.attention_dropout,
                    attention_type=hparams.self_attention_type,
                    save_weights_to=save_weights_to,
                    max_relative_position=hparams.max_relative_position,
                    make_image_summary=make_image_summary,
                    dropout_broadcast_dims=attention_dropout_broadcast_dims,
                    max_length=hparams.get("max_length"),
                    vars_3d=hparams.get("attention_variables_3d"))
                encoder_output = common_layers.layer_postprocess(
                    encoder_output, _y, hparams)

            with tf.variable_scope("gated_sum"):
                _depth = common_layers.shape_list(encoder_output)[-1]
                gate = tf.layers.dense(tf.concat(
                    [encoded_contexts, encoder_output], axis=-1),
                                       _depth,
                                       activation=tf.nn.sigmoid)
                if save_weights_to:
                    save_weights_to["gated_sum"] = gate
                encoder_output = gate * encoder_output + (
                    1. - gate) * encoded_contexts

            with tf.variable_scope("ffn"):
                _y = transformer_ffn_layer(common_layers.layer_preprocess(
                    encoder_output, hparams),
                                           hparams,
                                           input_pad_remover,
                                           conv_padding="SAME",
                                           nonpadding_mask=input_nonpadding,
                                           losses=losses)
                encoder_output = common_layers.layer_postprocess(
                    encoder_output, _y, hparams)

    return common_layers.layer_preprocess(encoder_output, hparams)
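
# hierarchical_context_encoder summarizes each context to a single vector (an
# attentive feed-forward followed by a mean over length), concatenates the
# per-context summaries along a "sentence" axis, and encodes that short
# sequence with a one-layer transformer before the gated sum. A shapes-only
# NumPy sketch of the pooling step; names and sizes are illustrative.
import numpy as np

contexts = {
    "context_a": np.random.randn(2, 7, 4),   # [batch, length_a, hidden]
    "context_b": np.random.randn(2, 3, 4),   # [batch, length_b, hidden]
}
summaries = [ctx.mean(axis=1, keepdims=True) for ctx in contexts.values()]
encoded_contexts = np.concatenate(summaries, axis=1)  # [batch, num_contexts, hidden]
print(encoded_contexts.shape)  # (2, 2, 4)
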
Exemple #29
0
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
    """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
    ishape_static = inputs.shape.as_list()
    encoder_input = inputs
    if features and "inputs_segmentation" in features:
        # Packed dataset.  Keep the examples from seeing each other.
        inputs_segmentation = features["inputs_segmentation"]
        inputs_position = features["inputs_position"]
        targets_segmentation = features["targets_segmentation"]
        encoder_self_attention_bias = common_attention.attention_bias_same_segment(
            inputs_segmentation, inputs_segmentation)
        encoder_decoder_attention_bias = (
            common_attention.attention_bias_same_segment(
                targets_segmentation, inputs_segmentation))
    else:
        # Usual case - not a packed dataset.
        encoder_padding = common_attention.embedding_to_padding(encoder_input)
        ignore_padding = common_attention.attention_bias_ignore_padding(
            encoder_padding)
        encoder_self_attention_bias = ignore_padding
        encoder_decoder_attention_bias = ignore_padding
        inputs_position = None
    if hparams.proximity_bias:
        encoder_self_attention_bias += common_attention.attention_bias_proximal(
            common_layers.shape_list(inputs)[1])
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(target_space,
                                               32,
                                               ishape_static[-1],
                                               name="target_space_embedding")
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
    #if hparams.pos == "timing":
    #  if inputs_position is not None:
    #    encoder_input = common_attention.add_timing_signal_1d_given_position(
    #        encoder_input, inputs_position)
    #  else:
    #    encoder_input = common_attention.add_timing_signal_1d(encoder_input)
    raw_encoder_input = tf.squeeze(features['inputs_raw'], axis=[-2, -1])
    pos_signals = generate_positional_signals(raw_encoder_input, hparams)
    pos_embeddings = generate_positional_embeddings(pos_signals,
                                                    hparams.encoder_pos,
                                                    hparams)
    if "sum" in hparams.encoder_pos_integration:
        encoder_input = encoder_input + pos_embeddings
    elif "ffn" in hparams.encoder_pos_integration:
        with tf.variable_scope("encoder_pos_ffn"):
            encoder_input = tf.concat([encoder_input, pos_embeddings], axis=2)
            encoder_input = transformer_ffn_layer(encoder_input,
                                                  hparams,
                                                  conv_padding="SAME")
    return (encoder_input, encoder_self_attention_bias,
            encoder_decoder_attention_bias)
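
# This variant replaces the built-in timing signal with externally generated
# positional embeddings, merged either by summation ("sum") or by channel-wise
# concatenation followed by a feed-forward projection ("ffn"). A standalone
# NumPy sketch of the two integration modes; the ReLU projection stands in for
# transformer_ffn_layer and its weights are illustrative.
import numpy as np

def integrate_positional_signal(encoder_input, pos_embeddings, mode, w1=None, w2=None):
    if mode == "sum":
        return encoder_input + pos_embeddings
    if mode == "ffn":
        x = np.concatenate([encoder_input, pos_embeddings], axis=2)
        return np.maximum(x @ w1, 0.0) @ w2
    raise ValueError("unknown integration mode: %s" % mode)

x = np.random.randn(2, 5, 4)
p = np.random.randn(2, 5, 4)
print(integrate_positional_signal(x, p, "sum").shape)   # (2, 5, 4)
print(integrate_positional_signal(x, p, "ffn",
                                  np.random.randn(8, 16),
                                  np.random.randn(16, 4)).shape)  # (2, 5, 4)
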
Exemple #30
0
def transformer_prepare_encoder(inputs, target_space, hparams, features=None):
  """Prepare one shard of the model for the encoder.

  Args:
    inputs: a Tensor.
    target_space: a Tensor.
    hparams: run hyperparameters
    features: optionally pass the entire features dictionary as well.
      This is needed now for "packed" datasets.

  Returns:
    encoder_input: a Tensor, bottom of encoder stack
    encoder_self_attention_bias: a bias tensor for use in encoder self-attention
    encoder_decoder_attention_bias: a bias tensor for use in encoder-decoder
      attention
  """
  ishape_static = inputs.shape.as_list()
  encoder_input = inputs
  if features and "inputs_segmentation" in features:
    # Packed dataset.  Keep the examples from seeing each other.
    inputs_segmentation = features["inputs_segmentation"]
    inputs_position = features["inputs_position"]
    targets_segmentation = features["targets_segmentation"]
    if (hasattr(hparams, "unidirectional_encoder") and
        hparams.unidirectional_encoder):
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      encoder_self_attention_bias = (
          common_attention.attention_bias_same_segment(
              inputs_segmentation, inputs_segmentation))
    encoder_decoder_attention_bias = (
        common_attention.attention_bias_same_segment(targets_segmentation,
                                                     inputs_segmentation))
  else:
    encoder_padding = common_attention.embedding_to_padding(encoder_input)
    ignore_padding = common_attention.attention_bias_ignore_padding(
        encoder_padding)
    if (hasattr(hparams, "unidirectional_encoder") and
        hparams.unidirectional_encoder):
      tf.logging.info("Using unidirectional encoder")
      encoder_self_attention_bias = (
          common_attention.attention_bias_lower_triangle(
              common_layers.shape_list(inputs)[1]))
    else:
      # Usual case - not a packed dataset.
      encoder_self_attention_bias = ignore_padding
    encoder_decoder_attention_bias = ignore_padding
    inputs_position = None
  if hparams.proximity_bias:
    encoder_self_attention_bias += common_attention.attention_bias_proximal(
        common_layers.shape_list(inputs)[1])
  if hparams.get("use_target_space_embedding", True):
    # Append target_space_id embedding to inputs.
    emb_target_space = common_layers.embedding(
        target_space,
        32,
        ishape_static[-1],
        name="target_space_embedding",
        dtype=tf.bfloat16
        if hparams.activation_dtype == "bfloat16" else tf.float32)
    emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
    encoder_input += emb_target_space
  if hparams.pos == "timing":
    if inputs_position is not None:
      encoder_input = common_attention.add_timing_signal_1d_given_position(
          encoder_input, inputs_position)
    else:
      encoder_input = common_attention.add_timing_signal_1d(encoder_input)
  elif hparams.pos == "emb":
    encoder_input = common_attention.add_positional_embedding(
        encoder_input, hparams.max_length, "inputs_positional_embedding",
        inputs_position)
  if hparams.activation_dtype == "bfloat16":
    encoder_self_attention_bias = tf.cast(encoder_self_attention_bias,
                                          tf.bfloat16)
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             tf.bfloat16)
  return (encoder_input, encoder_self_attention_bias,
          encoder_decoder_attention_bias)
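
# When hparams.pos == "timing", the encoder input gets the standard sinusoidal
# positional signal. A standalone NumPy sketch of that signal (the same form
# produced by add_timing_signal_1d, even channel counts only); the constants
# follow the usual Transformer defaults.
import numpy as np

def sinusoidal_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = np.arange(length, dtype=np.float32)
    num_timescales = channels // 2
    log_increment = np.log(max_timescale / min_timescale) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * np.exp(-np.arange(num_timescales) * log_increment)
    scaled = position[:, None] * inv_timescales[None, :]
    return np.concatenate([np.sin(scaled), np.cos(scaled)], axis=1)[None, :, :]

print(sinusoidal_timing_signal(length=6, channels=8).shape)  # (1, 6, 8), broadcast-added to inputs
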
Exemple #31
0
    def encode_lex(self, encoder_input, target_space, hparams):
        '''
        encoder_input: [batch_size, input_len, lex_cap, hidden_dim]
        return:
            encoder_output: [batch_size, input_len, hidden_dim]
            encoder_decoder_attention_bias: [batch_size, input_len]
        '''
        encoder_output_slices = []
        for i in range(encoder_input.get_shape()[2].value):
            encoder_input_slice = encoder_input[:, :, i, :]

            # bias
            encoder_padding = common_attention.embedding_to_padding(
                encoder_input_slice)
            print(encoder_padding.shape.as_list()
                  )  # ==> [None, None] (None, None, 4)
            ignore_padding = common_attention.attention_bias_ignore_padding(
                encoder_padding)
            encoder_self_attention_bias = ignore_padding
            encoder_decoder_attention_bias = ignore_padding
            print(ignore_padding.shape.as_list()
                  )  # ==> [None, 1, 1, None] (None, 1, 1, None, 4)

            # add target space to encoder input?
            ishape_static = encoder_input_slice.shape.as_list()
            print(ishape_static)  # ==> [None, None, 300] (None, None, 4, 300)
            emb_target_space = common_layers.embedding(
                target_space,
                32,
                ishape_static[-1],
                name="target_space_embedding")
            print(emb_target_space.shape.as_list())  # ==> [300]
            emb_target_space = tf.reshape(emb_target_space, [1, 1, -1])
            print(emb_target_space.shape.as_list())  # ==> [1, 1, 300]
            encoder_input_slice += emb_target_space
            print(encoder_input_slice.shape.as_list()
                  )  # ==> [None, None, 300] (None, None, 4, 300)

            # add timing signals to encoder input
            if hparams.pos == "timing":
                encoder_input_slice = common_attention.add_timing_signal_1d(
                    encoder_input_slice)

            # dropout
            encoder_input_slice = tf.nn.dropout(
                encoder_input_slice,
                1.0 - hparams.layer_prepostprocess_dropout)

            # encoder
            '''
            multihead_attention(
            query_antecedent: [batch, length_q, channels], -- x, x
            memory_antecedent: [batch, length_m, channels], -- None, encoder_output
            bias: bias tensor, -- encoder_self_attention_bias
            total_key_depth: int, -- hparams.attention_key_channels or hparams.hidden_size
            total_value_depth: int, -- hparams.attention_value_channels or hparams.hidden_size
            output_depth: integer, -- hparams.hidden_size
            num_heads: integer dividing total_key_depth and total_value_depth, -- hparams.num_heads (8)
            dropout_rate: float, -- hparams.attention_dropout
            ...
            cache=None: dict containing tensors which are the results of previous attentions, used for fast decoding in decoder self-attention, e.g. {'k': [batch_size, 0, key_channels], 'v': [batch_size, 0, value_channels]}
            '''
            x = encoder_input_slice
            with tf.variable_scope("encoder" + str(i)):
                # remove pad
                pad_remover = None
                if hparams.use_pad_remover:
                    pad_remover = expert_utils.PadRemover(
                        common_attention.attention_bias_to_padding(
                            encoder_self_attention_bias))

                # self-attention along the sentence dimension
                for layer in xrange(hparams.num_encoder_layers
                                    or hparams.num_hidden_layers):
                    with tf.variable_scope("layer_%d" % layer):
                        with tf.variable_scope("self_attention"):
                            query_antecedent = common_layers.layer_preprocess(
                                x, hparams)
                            y = common_attention.multihead_attention(
                                query_antecedent=query_antecedent,
                                memory_antecedent=None,
                                bias=encoder_self_attention_bias,
                                total_key_depth=(hparams.attention_key_channels
                                                 or hparams.hidden_size),
                                total_value_depth=(hparams.attention_value_channels
                                                   or hparams.hidden_size),
                                output_depth=hparams.hidden_size,
                                num_heads=hparams.num_heads,
                                dropout_rate=hparams.attention_dropout,
                                attention_type=hparams.self_attention_type,
                                max_relative_position=hparams.max_relative_position)
                            x = common_layers.layer_postprocess(x, y, hparams)
                        with tf.variable_scope("ffn"):
                            y = transformer.transformer_ffn_layer(
                                common_layers.layer_preprocess(x, hparams),
                                hparams, pad_remover)
                            x = common_layers.layer_postprocess(x, y, hparams)
                encoder_output_slice = common_layers.layer_preprocess(
                    x, hparams)
                print(encoder_output_slice.shape.as_list()
                      )  # ==> [None, None, 300] (None, None, 4, 300)

            encoder_output_slices.append(encoder_output_slice)
        encoder_output = tf.stack(encoder_output_slices, 2)
        print(encoder_output.shape.as_list())  # ==> [None, None, 4, 300]

        # --------

        encoder_output_slices = []
        #hparams2 = copy.deepcopy(hparams)
        #hparams2.hidden_size = hparams.lex_cap
        num_heads = int(hparams.lex_cap / 2)
        hparams2 = tf.contrib.training.HParams(
            layer_preprocess_sequence=hparams.layer_preprocess_sequence,
            layer_postprocess_sequence=hparams.layer_postprocess_sequence,
            layer_prepostprocess_dropout=hparams.layer_prepostprocess_dropout,
            norm_type=hparams.norm_type,
            hidden_size=hparams.lex_cap,
            norm_epsilon=hparams.norm_epsilon,
            ffn_layer=hparams.ffn_layer,
            filter_size=hparams.filter_size,
            relu_dropout=hparams.relu_dropout,
            num_heads=num_heads,
            attention_dropout=hparams.attention_dropout,
            parameter_attention_key_channels=hparams.parameter_attention_key_channels,
            parameter_attention_value_channels=hparams.parameter_attention_value_channels)

        for i in range(encoder_output.get_shape()[3].value):
            encoder_input_slice = encoder_output[:, :, :, i]
            #print(encoder_input_slice.shape.as_list()) # ==> [None, None, 4]

            encoder_padding = common_attention.embedding_to_padding(
                encoder_input_slice)
            ignore_padding = common_attention.attention_bias_ignore_padding(
                encoder_padding)
            encoder_self_attention_bias = ignore_padding
            #print(encoder_self_attention_bias.shape.as_list()) # ==> [None, 1, 1, None]

            # encoder
            '''
            multihead_attention(
            query_antecedent: [batch, length_q, channels], -- x, x
            memory_antecedent: [batch, length_m, channels], -- None, encoder_output
            bias: bias tensor, -- encoder_self_attention_bias
            total_key_depth: int, -- hparams.attention_key_channels or hparams.hidden_size
            total_value_depth: int, -- hparams.attention_value_channels or hparams.hidden_size
            output_depth: integer, -- hparams.hidden_size
            num_heads: integer dividing total_key_depth and total_value_depth, -- hparams.num_heads (8)
            dropout_rate: float, -- hparams.attention_dropout
            ...
            cache=None: dict containing tensors which are the results of previous attentions, used for fast decoding in decoder self-attention, e.g. {'k': [batch_size, 0, key_channels], 'v': [batch_size, 0, value_channels]}
            '''
            x = encoder_input_slice
            with tf.variable_scope("encoder_extra" + str(i)):
                # remove pad
                pad_remover = None
                if hparams.use_pad_remover:
                    pad_remover = expert_utils.PadRemover(
                        common_attention.attention_bias_to_padding(
                            encoder_self_attention_bias))

                # self-attention along the lexicon dimension
                with tf.variable_scope("layer_extra"):
                    with tf.variable_scope("self_attention"):
                        #query_antecedent = layer_preprocess2(x, hparams, hparams.lex_cap)
                        query_antecedent = common_layers.layer_preprocess(
                            x, hparams2)

                        y = common_attention.multihead_attention(
                            query_antecedent=query_antecedent,
                            memory_antecedent=None,
                            bias=encoder_self_attention_bias,
                            total_key_depth=hparams.attention_key_channels
                            or hparams.lex_cap,
                            total_value_depth=hparams.attention_value_channels
                            or hparams.lex_cap,
                            output_depth=hparams.lex_cap,
                            num_heads=num_heads,
                            dropout_rate=hparams.attention_dropout,
                            attention_type=hparams.self_attention_type,
                            max_relative_position=hparams.max_relative_position
                        )
                        #x = layer_postprocess2(x, y, hparams, hparams.lex_cap)
                        x = common_layers.layer_postprocess(x, y, hparams2)
                    with tf.variable_scope("ffn"):
                        y = transformer.transformer_ffn_layer(
                            common_layers.layer_preprocess(x, hparams2),
                            hparams2, pad_remover)
                        #x = layer_postprocess2(x, y, hparams, hparams.lex_cap)
                        x = common_layers.layer_postprocess(x, y, hparams2)
                #encoder_output_slice = layer_preprocess2(x, hparams, hparams.lex_cap)
                encoder_output_slice = common_layers.layer_preprocess(
                    x, hparams2)
                #print(encoder_output_slice.shape.as_list()) # ==> [None, None, 4] (None, None, 4, 300)

            encoder_output_slices.append(encoder_output_slice)
        encoder_output = tf.stack(encoder_output_slices, 3)
        print(encoder_output.shape.as_list())  # ==> [None, None, 4, 300]

        # --------

        lex_cap = encoder_output.get_shape()[2].value
        embed_len = encoder_output.get_shape()[3].value
        assert (lex_cap == hparams.lex_cap)
        aggregate_layer = tf.get_variable(
            name="Aggregate",
            shape=[embed_len, embed_len, lex_cap],
            initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1))
        # Contract the (lex_cap, embed) axes of encoder_output against the
        # matching (lex_cap, embed) axes of the aggregation tensor.
        encoder_output = tf.tensordot(encoder_output,
                                      aggregate_layer,
                                      axes=[[2, 3], [2, 1]])
        print(encoder_output.shape.as_list())  # ==> [None, None, 300]

        return encoder_output, encoder_decoder_attention_bias
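
# encode_lex finishes by contracting the stacked [batch, length, lex_cap, embed]
# encoder output against a learned [embed, embed, lex_cap] aggregation tensor,
# yielding one embed-sized vector per position. A shapes-only NumPy sketch of
# that contraction; the sizes and random tensors are illustrative.
import numpy as np

batch, length, lex_cap, embed = 2, 5, 4, 300
encoder_output = np.random.randn(batch, length, lex_cap, embed)
aggregate_layer = np.random.randn(embed, embed, lex_cap)
aggregated = np.tensordot(encoder_output, aggregate_layer, axes=[[2, 3], [2, 1]])
print(aggregated.shape)  # (2, 5, 300)
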
    def body(self, features):
        """Transformer main model_fn.

    Args:
      features: Map of features to the model. Should contain the following:
          "inputs": Transformer inputs [batch_size, input_length, hidden_dim]
          "targets": Target decoder outputs. [batch_size, decoder_length,
            hidden_dim]
          "target_space_id": A scalar int from data_generators.problem.SpaceID.

    Returns:
      Final decoder representation. [batch_size, decoder_length, hidden_dim]
    """
        tf.logging.info("Using PgScratch BODY function.")
        hparams = self._hparams

        losses = {}
        inputs = features["inputs"]
        target_space = features["target_space_id"]
        # encoder_output: <tf.float32>[batch_size, input_length, hidden_dim]
        # encoder_decoder_attention_bias: <tf.float32>[batch_size, input_length]
        encoder_output, encoder_decoder_attention_bias = self.encode(
            inputs, target_space, hparams, features=features, losses=losses)

        with tf.variable_scope("knowledge"):
            with tf.name_scope("knowledge_encoding"):
                # Encode knowledge.
                # <tf.float32>[batch_size, triple_num, emb_dim]
                fact_embedding, fact_lengths = self.encode_knowledge_bottom(
                    features)
                tf.logging.info("Encoded knowledge")

            with tf.name_scope("knowledge_selection_and_loss"):
                # Compute knowledge selection and loss.
                triple_logits, avg_triple_selection_loss, knowledge_encoder_output, transe_loss = self.compute_knowledge_selection_and_loss(
                    features, encoder_output, fact_embedding, fact_lengths,
                    hparams.margin, hparams.num_negative_samples)
                losses["kb_loss"] = avg_triple_selection_loss
                losses["transe_loss"] = transe_loss

        if hparams.attend_kb:
            tf.logging.info("ATTEND_KB is ACTIVE")
            with tf.name_scope("knowledge_attention"):

                knowledge_padding = tf.zeros_like(triple_logits,
                                                  dtype=tf.float32)
                knowledge_attention_bias = common_attention.attention_bias_ignore_padding(
                    knowledge_padding)
                encoder_output = tf.concat(
                    [knowledge_encoder_output, encoder_output], 1)
                encoder_decoder_attention_bias = tf.concat(
                    [knowledge_attention_bias, encoder_decoder_attention_bias],
                    -1)

        else:
            tf.logging.info("ATTEND_KB is INACTIVE")

        targets = features["targets"]
        targets_shape = common_layers.shape_list(targets)
        targets = common_layers.flatten4d3d(targets)

        (decoder_input, decoder_self_attention_bias
         ) = transformer.transformer_prepare_decoder(targets,
                                                     hparams,
                                                     features=features)

        decode_kwargs = {}
        decoder_output = self.decode(
            decoder_input,
            encoder_output,
            encoder_decoder_attention_bias,
            decoder_self_attention_bias,
            hparams,
            nonpadding=transformer.features_to_nonpadding(features, "targets"),
            losses=losses,
            **decode_kwargs)

        expected_attentions = features.get("expected_attentions")
        if expected_attentions is not None:
            attention_loss = common_attention.encoder_decoder_attention_loss(
                expected_attentions, self.attention_weights,
                hparams.expected_attention_loss_type,
                hparams.expected_attention_loss_multiplier)
            return decoder_output, {"attention_loss": attention_loss}

        ret = tf.reshape(decoder_output, targets_shape)
        if losses:
            return ret, losses
        else:
            return ret
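
# When attend_kb is active, body() prepends the encoded knowledge memory to the
# encoder output and widens the encoder-decoder attention bias with zeros, so
# the decoder may attend to every knowledge slot. A shapes-only NumPy sketch of
# that concatenation; sizes are illustrative, and 0 / -1e9 denote visible /
# masked positions in the additive bias.
import numpy as np

batch, input_len, triple_num, hidden = 2, 6, 3, 4
encoder_output = np.random.randn(batch, input_len, hidden)
knowledge_encoder_output = np.random.randn(batch, triple_num, hidden)
encoder_decoder_attention_bias = np.zeros((batch, 1, 1, input_len))
knowledge_attention_bias = np.zeros((batch, 1, 1, triple_num))  # never masked

encoder_output = np.concatenate([knowledge_encoder_output, encoder_output], axis=1)
encoder_decoder_attention_bias = np.concatenate(
    [knowledge_attention_bias, encoder_decoder_attention_bias], axis=-1)
print(encoder_output.shape, encoder_decoder_attention_bias.shape)  # (2, 9, 4) (2, 1, 1, 9)
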
Exemple #33
0
  def _fast_decode(self,
                   features,
                   decode_length,
                   beam_size=1,
                   top_beams=1,
                   alpha=1.0):
    """Fast decoding.

    Implements both greedy and beam search decoding, uses beam search iff
    beam_size > 1, otherwise beam search related arguments are ignored.

    Args:
      features: a map of string to model  features.
      decode_length: an integer.  How many additional timesteps to decode.
      beam_size: number of beams.
      top_beams: an integer. How many of the beams to return.
      alpha: Float that controls the length penalty. The larger the alpha, the
        stronger the preference for longer translations.

    Returns:
       samples: an integer `Tensor`. Top samples from the beam search

    Raises:
      NotImplementedError: If there are multiple data shards.
    """
    #JI: set images shapes
    imageP = features["imageP"]
    imageP.set_shape([None,19600])
    imageP=tf.reshape(imageP,[-1, img_dim, 100])

    if self._num_datashards != 1:
      raise NotImplementedError("Fast decoding only supports a single shard.")
    dp = self._data_parallelism
    hparams = self._hparams

    inputs = features["inputs"]
    batch_size = tf.shape(inputs)[0]
    target_modality = self._problem_hparams.target_modality
    if t2t_model.is_class_modality(target_modality):
      decode_length = 1
    else:
      decode_length = tf.shape(inputs)[1] + decode_length

    # TODO(llion): Clean up this reshaping logic.
    inputs = tf.expand_dims(inputs, axis=1)
    if len(inputs.shape) < 5:
      inputs = tf.expand_dims(inputs, axis=4)
    s = tf.shape(inputs)
    inputs = tf.reshape(inputs, [s[0] * s[1], s[2], s[3], s[4]])
    # _shard_features called to ensure that the variable names match
    inputs = self._shard_features({"inputs": inputs})["inputs"]
    input_modality = self._problem_hparams.input_modality["inputs"]
    with tf.variable_scope(input_modality.name):
      inputs = input_modality.bottom_sharded(inputs, dp)
    #JI: send images to encoder if needed
    with tf.variable_scope("body"):
      encoder_output, encoder_decoder_attention_bias = dp(
          self.encode, inputs, features["target_space_id"], hparams, imageP=None)
    encoder_output = encoder_output[0]
    encoder_decoder_attention_bias = encoder_decoder_attention_bias[0]

    if hparams.pos == "timing":
      timing_signal = common_attention.get_timing_signal_1d(
          decode_length + 1, hparams.hidden_size)

    def preprocess_targets(targets, i):
      """Performs preprocessing steps on the targets to prepare for the decoder.

      This includes:
        - Embedding the ids.
        - Flattening to 3D tensor.
        - Optionally adding timing signals.

      Args:
        targets: inputs ids to the decoder. [batch_size, 1]
        i: scalar, Step number of the decoding loop.

      Returns:
        Processed targets [batch_size, 1, hidden_dim]
      """
      # _shard_features called to ensure that the variable names match
      targets = self._shard_features({"targets": targets})["targets"]
      with tf.variable_scope(target_modality.name):
        targets = target_modality.targets_bottom_sharded(targets, dp)[0]
      targets = common_layers.flatten4d3d(targets)

      # TODO(llion): Explain! Is this even needed?
      targets = tf.cond(
          tf.equal(i, 0), lambda: tf.zeros_like(targets), lambda: targets)

      if hparams.pos == "timing":
        targets += timing_signal[:, i:i + 1]
      return targets

    decoder_self_attention_bias = (
        common_attention.attention_bias_lower_triangle(decode_length))
    if hparams.proximity_bias:
      decoder_self_attention_bias += common_attention.attention_bias_proximal(
          decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Go from ids to logits for next symbol."""
      ids = ids[:, -1:]
      targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
      targets = preprocess_targets(targets, i)

      bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]

      with tf.variable_scope("body"):
        #JI: send images to decoder if needed
        body_outputs = dp(
            self.decode, targets, cache["encoder_output"],
            cache["encoder_decoder_attention_bias"], bias, hparams, cache,
            imageP=cache["imageP"],
            imageP_decoder_self_attention_bias=cache[
                "imageP_decoder_self_attention_bias"])

      with tf.variable_scope(target_modality.name):
        logits = target_modality.top_sharded(body_outputs, None, dp)[0]

      return tf.squeeze(logits, axis=[1, 2, 3]), cache

    key_channels = hparams.attention_key_channels or hparams.hidden_size
    value_channels = hparams.attention_value_channels or hparams.hidden_size
    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers

    # Per-layer cache of decoder self-attention keys/values for incremental
    # decoding; each entry starts empty along the time axis and grows by one
    # position per generated symbol.
    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, key_channels]),
            "v": tf.zeros([batch_size, 0, value_channels]),
        }
        for layer in range(num_layers)
    }

    # Set 2nd dim to None since it's not invariant in the tf.while_loop
    # Note: Tensor.set_shape() does not work here since it merges shape info.
    # TODO(llion): Find a more robust solution.
    # pylint: disable=protected-access
    for layer in cache:
      cache[layer]["k"]._shape = tf.TensorShape([None, None, key_channels])
      cache[layer]["v"]._shape = tf.TensorShape([None, None, value_channels])
    # pylint: enable=protected-access
    
    # Get the image attention bias for the decoder.
    img_encoder_padding = common_attention.embedding_to_padding(imageP)
    imageP_decoder_self_attention_bias = (
        common_attention.attention_bias_ignore_padding(img_encoder_padding))
    
    cache["encoder_output"] = encoder_output
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias
     
    #JI: cache the image features so the decoder can read them at every step
    cache["imageP"] = imageP
    cache["imageP_decoder_self_attention_bias"] = imageP_decoder_self_attention_bias

    if beam_size > 1:  # Beam Search
      target_modality = (
          self._hparams.problems[self._problem_idx].target_modality)
      vocab_size = target_modality.top_dimensionality
      initial_ids = tf.zeros([batch_size], dtype=tf.int32)
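      # Note on alpha: tensor2tensor's beam_search divides each hypothesis'
      # log-probability by a length penalty of roughly ((5 + length) / 6)**alpha,
      # so alpha = 0 disables length normalization and larger values favor
      # longer outputs. beam_search returns ids of shape
      # [batch, beam_size, length + 1]; position 0 holds the initial dummy id,
      # which is stripped below.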
      decoded_ids, scores = beam_search.beam_search(
          symbols_to_logits_fn, initial_ids, beam_size, decode_length,
          vocab_size, alpha, states=cache, stop_early=(top_beams == 1))

      if top_beams == 1:
        decoded_ids = decoded_ids[:, 0, 1:]
      else:
        decoded_ids = decoded_ids[:, :top_beams, 1:]
    else:  # Greedy

      def inner_loop(i, next_id, decoded_ids, cache):
        logits, cache = symbols_to_logits_fn(next_id, i, cache)
        temperature = (0.0 if hparams.sampling_method == "argmax"
                       else hparams.sampling_temp)
        next_id = tf.expand_dims(
            common_layers.sample_with_temperature(logits, temperature), axis=1)
        decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
        return i + 1, next_id, decoded_ids, cache

      decoded_ids = tf.zeros([batch_size, 0], dtype=tf.int64)
      scores = None
      next_id = tf.zeros([batch_size, 1], dtype=tf.int64)
      _, _, decoded_ids, _ = tf.while_loop(
          # TODO(llion): Early stopping.
          lambda i, *_: tf.less(i, decode_length),
          inner_loop,
          [tf.constant(0), next_id, decoded_ids, cache],
          shape_invariants=[
              tf.TensorShape([]),
              tf.TensorShape([None, None]),
              tf.TensorShape([None, None]),
              nest.map_structure(lambda t: tf.TensorShape(t.shape), cache),
          ])

    return decoded_ids, scores
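
For reference, the greedy branch above draws each next symbol with `common_layers.sample_with_temperature`, where a temperature of 0 falls back to argmax. A minimal, self-contained NumPy sketch of that idea (an illustration only, not tensor2tensor's implementation):

import numpy as np

def sample_with_temperature(logits, temperature, rng=np.random):
    """Pick one id per row: argmax if temperature == 0, else softmax sampling."""
    if temperature == 0.0:
        return np.argmax(logits, axis=-1)
    # Temperature-scaled, numerically stable softmax over the vocabulary axis.
    scaled = logits / temperature
    scaled = scaled - scaled.max(axis=-1, keepdims=True)
    probs = np.exp(scaled)
    probs /= probs.sum(axis=-1, keepdims=True)
    return np.array([rng.choice(len(p), p=p) for p in probs])

logits = np.array([[2.0, 0.5, 0.1], [0.1, 0.2, 3.0]])
print(sample_with_temperature(logits, 0.0))   # argmax per row -> [0 2]
print(sample_with_temperature(logits, 1.0))   # stochastic draw per row
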
  def test_aaa_glow_training(self, depths, split_plans, prior_type):
    with tf.Graph().as_default():
      _, x_mask, _ = self.get_data()
      x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                           mean=10.0, stddev=3.0, dtype=DTYPE)
      bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask)
      hparams = self.get_hparams()
      hparams.prior_type = prior_type
      hparams.depths = depths
      hparams.split_plans = split_plans
      n_levels = len(hparams.depths.split("/"))
      kwargs = self.get_kwargs(x_mask, hparams)
      _ = kwargs.pop("decoder_self_attention_bias")

      # Forward pass with init=True performs the data-dependent initialization;
      # the variables are then saved and restored in the second graph below.
      x_inv, _, _, _ = glow.glow(
          "glow", x, x_mask, bias, inverse=False, init=True,
          disable_dropout=True, **kwargs)
      curr_dir = tempfile.mkdtemp()
      model_path = os.path.join(curr_dir, "model")

      with tf.Session() as session:
        saver = tf.train.Saver()
        session.run(tf.global_variables_initializer())
        session.run(x_inv)
        saver.save(session, model_path)

    # Second graph: rebuild the flow with init=False, restore the checkpoint,
    # and check that the inverse pass recovers the input.
    with tf.Graph().as_default():
      _, x_mask, _ = self.get_data()
      x = tf.random_normal((BATCH_SIZE, TARGET_LENGTH, N_CHANNELS),
                           mean=10.0, stddev=3.0, dtype=DTYPE)
      bias = common_attention.attention_bias_ignore_padding(1.0 - x_mask)
      hparams = self.get_hparams()
      hparams.prior_type = prior_type
      hparams.depths = depths
      hparams.split_plans = split_plans
      kwargs = self.get_kwargs(x_mask, hparams)
      _ = kwargs.pop("decoder_self_attention_bias")
      log_q_z = gops.standard_normal_density(x, x_mask)
      log_q_z = tf.reduce_sum(log_q_z) / tf.reduce_sum(x_mask)

      x_inv, logabsdets, log_ps, zs = glow.glow(
          "glow", x, x_mask, bias, inverse=False, init=False,
          disable_dropout=True, **kwargs)
      x_inv_inv, logabsdets_inv, log_ps_inv, _ = glow.glow(
          "glow", x_inv, x_mask, bias, inverse=True, split_zs=zs, init=False,
          disable_dropout=True, **kwargs)
      logabsdets = tf.reduce_sum(
          logabsdets, axis=0) / tf.reduce_sum(x_mask)
      logabsdets_inv = tf.reduce_sum(
          logabsdets_inv, axis=0) / tf.reduce_sum(x_mask)
      log_ps = tf.reduce_sum(log_ps, axis=0) / tf.reduce_sum(x_mask)
      log_ps_inv = tf.reduce_sum(log_ps_inv, axis=0) / tf.reduce_sum(x_mask)

      with tf.Session() as session:
        saver = tf.train.Saver()
        saver.restore(session, model_path)
        (x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps,
         logabsdets_inv, log_ps_inv) = session.run([
             x, x_inv, x_inv_inv, log_q_z, logabsdets, log_ps,
             logabsdets_inv, log_ps_inv])
        diff = x - x_inv_inv
        log_ps_diff = log_ps - log_ps_inv
        logabsdets_sum = logabsdets + logabsdets_inv
        self.assertEqual(
            x_inv.shape,
            (BATCH_SIZE, TARGET_LENGTH//(2**(n_levels-1)), N_CHANNELS))
        print(np.max(np.abs(diff)))
        print(np.max(np.abs(log_ps_diff)))
        print(np.max(np.abs(logabsdets_sum)))
        self.assertTrue(np.allclose(diff, 0.0, atol=1e-4),
                        msg=np.max(np.abs(diff)))
        self.assertTrue(np.allclose(log_ps_diff, 0.0, atol=1e-4),
                        msg=np.max(np.abs(log_ps_diff)))
        self.assertTrue(np.allclose(logabsdets_sum, 0.0, atol=1e-4),
                        msg=np.max(np.abs(logabsdets_sum)))
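
The assertions above amount to a generic invertibility check for a normalizing flow: the inverse pass must recover the input, and the forward and inverse log-determinants must cancel. A minimal NumPy sketch of that property on a toy affine (scale-and-shift) flow, not the Glow model itself:

import numpy as np

def affine_forward(x, scale, shift):
    # z = scale * x + shift; log|det J| is sum(log|scale|) for every example.
    z = scale * x + shift
    logdet = np.full(x.shape[0], np.sum(np.log(np.abs(scale))))
    return z, logdet

def affine_inverse(z, scale, shift):
    # Exact inverse of the forward transform, with the negated log-determinant.
    x = (z - shift) / scale
    logdet = np.full(z.shape[0], -np.sum(np.log(np.abs(scale))))
    return x, logdet

rng = np.random.RandomState(0)
x = rng.normal(size=(4, 8))
scale = rng.uniform(0.5, 2.0, size=8)
shift = rng.normal(size=8)

z, logdet_fwd = affine_forward(x, scale, shift)
x_round_trip, logdet_inv = affine_inverse(z, scale, shift)

assert np.allclose(x, x_round_trip, atol=1e-6)               # inverse recovers x
assert np.allclose(logdet_fwd + logdet_inv, 0.0, atol=1e-6)  # log-dets cancel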