Example #1
    def create_positional_emb_2d(self, targets, max_length_dim, model_dim):
        """Learned 2d positional embedding for images."""
        mesh = targets.mesh
        hparams = self._hparams
        activation_dtype = self.set_activation_type()

        rows_dim = mtf.Dimension("rows", hparams.img_len)
        cols_dim = mtf.Dimension("cols",
                                 hparams.img_len * hparams.num_channels)

        positional_emb_rows_var = mtf.get_variable(
            mesh,
            "positional_emb_rows",
            mtf.Shape([max_length_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)
        positional_emb_cols_var = mtf.get_variable(
            mesh,
            "positional_emb_cols",
            mtf.Shape([max_length_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        targets_position_x = mtf.range(mesh, rows_dim, dtype=tf.int32)
        targets_position_y = mtf.range(mesh, cols_dim, dtype=tf.int32)
        position_x = mtf.broadcast(
            mtf.gather(positional_emb_rows_var, targets_position_x,
                       max_length_dim),
            mtf.Shape([rows_dim, cols_dim, model_dim]))

        position_y = mtf.broadcast(
            mtf.gather(positional_emb_cols_var, targets_position_y,
                       max_length_dim),
            mtf.Shape([rows_dim, cols_dim, model_dim]))
        return position_x + position_y
Example #2
    def create_positional_emb_2d(self, targets):
        """Learned 2d positional embedding for images."""
        mesh = targets.mesh

        positional_emb_rows_var = mtf.get_variable(
            mesh,
            "positional_emb_rows",
            mtf.Shape([self.max_length_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=self.activation_type)
        positional_emb_cols_var = mtf.get_variable(
            mesh,
            "positional_emb_cols",
            mtf.Shape([self.max_length_dim, self.model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=self.activation_type)

        targets_position_x = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
        targets_position_y = mtf.range(mesh, self.cols_dim, dtype=tf.int32)
        position_x = mtf.broadcast(
            mtf.gather(positional_emb_rows_var, targets_position_x,
                       self.max_length_dim),
            mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))

        position_y = mtf.broadcast(
            mtf.gather(positional_emb_cols_var, targets_position_y,
                       self.max_length_dim),
            mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim]))
        return position_x + position_y
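
Both versions above build the 2d embedding the same way: one learned table per axis, a gather of the first rows/cols entries from each, then a broadcast so the two partial embeddings share the shape [rows, cols, model_dim] before being summed. Below is a minimal NumPy sketch of that pattern; the sizes and names are illustrative, not taken from the snippet.

import numpy as np

# Illustrative sizes, standing in for hparams.img_len etc.
rows, cols, model_dim, max_length = 4, 12, 8, 32

rng = np.random.default_rng(0)
emb_rows = rng.normal(size=(max_length, model_dim))  # positional_emb_rows_var
emb_cols = rng.normal(size=(max_length, model_dim))  # positional_emb_cols_var

# mtf.gather(var, mtf.range(...), max_length_dim) picks the first `rows`
# (resp. `cols`) entries; mtf.broadcast aligns both to [rows, cols, model_dim].
position_x = emb_rows[:rows][:, None, :]  # [rows, 1, model_dim]
position_y = emb_cols[:cols][None, :, :]  # [1, cols, model_dim]
pos_emb_2d = position_x + position_y      # broadcasts to [rows, cols, model_dim]
print(pos_emb_2d.shape)                   # (4, 12, 8)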
Example #3
 def my_gather(tensor):
     return mtf.gather(tensor,
                       top_beam_index,
                       beam_dim,
                       output_shape=mtf.Shape([
                           double_beam if d == beam_dim else d
                           for d in tensor.shape.dims
                       ]))
Example #4
 def gather(tensor, name):
     with tf.name_scope(prefix + name):
         output_shape = mtf.Shape([
             beam_dim if d == old_beam_dim else d for d in tensor.shape.dims
         ])
         return mtf.gather(tensor,
                           topk_indices,
                           old_beam_dim,
                           output_shape=output_shape)
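
Examples #3 and #4 both use the optional output_shape argument of mtf.gather: the gathered dimension in the result may have a different size than the one it replaces (double_beam instead of beam_dim in #3, beam_dim instead of old_beam_dim in #4). Here is a rough dense equivalent in NumPy, assuming hypothetical shapes:

import numpy as np

# Hypothetical layout: [batch, beam, length], gathered along the beam axis.
batch, beam, double_beam, length = 2, 4, 8, 5
tensor = np.arange(batch * beam * length).reshape(batch, beam, length)

# topk_indices stands in for top_beam_index: per batch row, which of the
# old beams to copy; note there are 2*beam of them.
topk_indices = np.random.default_rng(1).integers(0, beam, size=(batch, double_beam))

# Dense analogue of mtf.gather(tensor, top_beam_index, beam_dim,
# output_shape=...): the beam axis of the result has size double_beam.
gathered = np.take_along_axis(tensor, topk_indices[:, :, None], axis=1)
print(gathered.shape)  # (2, 8, 5)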
Example #5
 def logits_fn(step_num, ids, states):
   """Produce logits for this step, and new states."""
   self_attention_k = states[:hparams.num_decoder_layers]
   self_attention_v = states[hparams.num_decoder_layers:]
   ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
   x = (mtf.gather(targets_embedding_var, ids_this_step,
                   self.targets_vocab_dim) +
        mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
   with tf.variable_scope("decoder"):
     x, new_self_attention_k, new_self_attention_v = (
         self._decoder_layer_stack_incremental(
             x,
             step_num,
             encdec_tensors,
             self_attention_k,
             self_attention_v,
             encdec_attention_mask=encoder_attention_mask))
   logits = mtf.matmul(x, softmax_var)
   return logits, new_self_attention_k + new_self_attention_v
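
The two gathers at the top of logits_fn are the core of incremental decoding: fetch the token sampled at the previous step, embed it, and add the positional embedding for the current step. A small NumPy sketch of that step with made-up sizes (the real code keeps batch and beam dimensions that are omitted here):

import numpy as np

vocab, max_length, model_dim = 10, 16, 4
rng = np.random.default_rng(2)
targets_embedding_var = rng.normal(size=(vocab, model_dim))
positional_embedding_var = rng.normal(size=(max_length, model_dim))

step_num = 3
ids = np.array([5, 7, 1, 0])  # tokens decoded so far, along length_dim

# mtf.gather(ids, step_num - 1, length_dim): the token emitted last step.
ids_this_step = ids[step_num - 1]
# Token embedding plus the positional embedding for the current step.
x = targets_embedding_var[ids_this_step] + positional_embedding_var[step_num]
print(x.shape)  # (4,)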
Example #6
    def _is_finished(i, unused_alive_seq, alive_log_probs, unused_finished_seq,
                     finished_scores, finished_in_finished, *unused_states):
        """Checking termination condition.

    We terminate when we decoded up to decode_length or the lowest scoring item
    in finished has a greater score that the highest prob item in alive divided
    by the max length penalty

    Args:
      i: loop index
      alive_log_probs: probabilities of the beams. [batch_size, beam_size]
      finished_scores: scores for each of these sequences.
        [batch_size, beam_size]
      finished_in_finished: finished bools for each of these sequences.
        [batch_size, beam_size]

    Returns:
      Bool.
    """
        # TODO(noam): support a different decode length...
        # decode_length = mtf.constant(mesh, length_dim.size, dtype=tf.int32)

        # del alive_log_probs, finished_scores, finished_in_finished
        # return mtf.less(i, length_dim.size)
        if not stop_early:
            return mtf.less(i, decode_length)
        max_length_penalty = mtf.pow(((5. + mtf.to_float(decode_length)) / 6.),
                                     alpha)
        # The best possible score of the most likely alive sequence.
        lower_bound_alive_scores = mtf.gather(
            alive_log_probs, mtf.constant(mesh, 0, dtype=tf.int32),
            beam_dim) / max_length_penalty

        # Now compute the lowest score of a finished sequence in finished.
        # If a sequence isn't finished, we multiply its score by 0. Since
        # scores are all negative, taking the min gives us the score of the
        # lowest finished item.
        lowest_score_of_finished_in_finished = mtf.reduce_min(
            finished_scores * mtf.to_float(finished_in_finished),
            reduced_dim=beam_dim)

        # If none of the sequences have finished, the min will be 0; in that
        # case we replace it with -INF. The score of any sequence in alive
        # will be much higher than -INF, so the termination condition will
        # not be met.
        lowest_score_of_finished_in_finished += ((1. - mtf.to_float(
            mtf.reduce_any(finished_in_finished, reduced_dim=beam_dim))) *
                                                 -INF)

        bound_is_met = mtf.reduce_all(
            mtf.greater(lowest_score_of_finished_in_finished,
                        lower_bound_alive_scores))
        return mtf.logical_and(mtf.less(i, decode_length),
                               mtf.logical_not(bound_is_met))
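
The early-stopping test compares the worst finished score against the best score any alive beam could still reach after the length penalty. Here is a NumPy walk-through of the same arithmetic; INF and the toy scores are invented for illustration:

import numpy as np

INF = 1e9  # stands in for the module-level INF constant
alpha, decode_length = 0.6, 20

# Toy scores: [batch, beam]; log-probs are negative, beams sorted best-first.
alive_log_probs = np.array([[-2.0, -3.5], [-1.0, -4.0]])
finished_scores = np.array([[-1.8, -2.2], [0.0, 0.0]])
finished_in_finished = np.array([[True, True], [False, False]])

max_length_penalty = ((5.0 + decode_length) / 6.0) ** alpha
# Best possible score of the most likely alive sequence (beam index 0).
lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty

# Worst finished score per row; rows with nothing finished fall to -INF.
lowest_finished = (finished_scores * finished_in_finished).min(axis=1)
lowest_finished += (1.0 - finished_in_finished.any(axis=1)) * -INF

bound_is_met = np.all(lowest_finished > lower_bound_alive_scores)
print(bound_is_met)  # False: batch row 1 has no finished sequence yet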
Example #7
    def body_fn(step_num, ids, *states):
        """Body function for greedy decoding.

        Args:
          step_num: a mtf.Tensor
          ids: a mtf.Tensor
          *states: additional mtf.Tensors
        Returns:
          new_step_num, new_ids, *new_states
        """
        logits, new_states = logits_fn(step_num, ids, states)
        vocab_dim = logits.shape.dims[-1]
        new_ids = mtf.sample_with_temperature(logits, vocab_dim, temperature)
        if forced_ids is not None:
            # force the new ids to equal the partial targets where specified
            # (positions where partial_targets contain nonzero values)
            forced = mtf.gather(forced_ids, step_num, length_dim)
            new_ids = forced + new_ids * mtf.to_int32(mtf.equal(forced, 0))
        ids += new_ids * mtf.one_hot(step_num, length_dim, dtype=tf.int32)
        new_step_num = step_num + 1
        return [new_step_num, ids] + new_states
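
The forced-decoding trick in body_fn: wherever forced_ids is nonzero, the sampled token is replaced by the forced one, and the chosen token is written into position step_num via a one-hot mask. The same logic in plain NumPy, with a single sequence and made-up values:

import numpy as np

length, step_num = 6, 2
forced_ids = np.array([4, 9, 0, 0, 3, 0])  # 0 means "not forced"
ids = np.zeros(length, dtype=np.int32)
sampled_id = 7  # pretend mtf.sample_with_temperature returned this

# Keep the forced token where one is specified, otherwise keep the sample.
forced = forced_ids[step_num]
new_id = forced + sampled_id * int(forced == 0)

# ids += new_ids * mtf.one_hot(step_num, length_dim): write into step_num.
one_hot = np.eye(length, dtype=np.int32)[step_num]
ids += new_id * one_hot
print(ids)  # [0 0 7 0 0 0]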
Example #8
  def _sample(self, features, mesh):
    hparams = self._hparams
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
      inputs = features["inputs"]
      while len(inputs.shape.as_list()) > 2:
        inputs = tf.squeeze(inputs, axis=2)
      actual_batch_size = tf.shape(inputs)[0]
      actual_length = tf.shape(inputs)[1]
      inputs = tf.pad(
          inputs, [[0, hparams.batch_size - actual_batch_size],
                   [0, hparams.max_length - actual_length]])
      inputs = self._import_to_batch_by_length(
          inputs, "inputs", mesh, hparams)
      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.reshape(positional_embedding_var,
                       mtf.Shape([self.length_dim, self.model_dim])))
      encoder_attention_mask = (
          mtf_layers.attention_mask_ignore_padding(
              inputs, dtype=self.activation_dtype))
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.num_encoder_layers,
                              self_attention_mask=encoder_attention_mask)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
      encdec_tensors = []
      for layer_num in xrange(hparams.num_decoder_layers):
        with tf.variable_scope("decoder/layer_%d/encdec_attention" % layer_num):
          q_var, k_var, v_var, o_var = mtf_layers.multihead_attention_vars(
              mesh, self.heads_dim, self.model_dim,
              self.kv_dim, self.activation_dtype)
          k = mtf.einsum(
              [encoder_output, k_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
          v = mtf.einsum(
              [encoder_output, v_var],
              mtf.Shape(
                  [self.batch_dim, self.heads_dim,
                   self.memory_length_dim, self.kv_dim]))
        encdec_tensors.append((q_var, o_var, k, v))
      partial_targets = None
    else:
      encdec_tensors = None
      encoder_output = None
      encoder_attention_mask = None
      # Prepare partial targets.
      # In either features["inputs"] or features["targets"].
      # We force the outputs to begin with these sequences.
      partial_targets = features.get("inputs", None)
      if partial_targets is None:
        partial_targets = features.get("targets", None)
      if partial_targets is not None:
        partial_targets = common_layers.expand_squeeze_to_nd(partial_targets, 2)
        partial_targets = tf.to_int32(partial_targets)
        partial_targets_batch = tf.shape(partial_targets)[0]
        partial_targets_length = tf.shape(partial_targets)[1]
        partial_targets = tf.pad(
            partial_targets, [[0, hparams.batch_size - partial_targets_batch],
                              [0, hparams.max_length - partial_targets_length]])
        partial_targets = self._import_to_batch_by_length(
            partial_targets, "partial_targets", mesh, hparams)

    if hparams.beam_size == 1:
      ids_shape = mtf.Shape([self.batch_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])
    else:
      beam_dim = mtf.Dimension("beam", hparams.beam_size)
      ids_shape = mtf.Shape([self.batch_dim, beam_dim, self.length_dim])
      kv_shape = mtf.Shape([self.batch_dim, beam_dim, self.heads_dim,
                            self.memory_length_dim, self.kv_dim])

    initial_ids = mtf.constant(mesh, 0, ids_shape, dtype=tf.int32)
    initial_kv_states = (
        [mtf.zeros(mesh, kv_shape, dtype=self.activation_dtype)]
        * (2 * hparams.num_decoder_layers))
    def logits_fn(step_num, ids, states):
      """Produce logits for this step, and new states."""
      self_attention_k = states[:hparams.num_decoder_layers]
      self_attention_v = states[hparams.num_decoder_layers:]
      ids_this_step = mtf.gather(ids, step_num - 1, self.length_dim)
      x = (mtf.gather(targets_embedding_var, ids_this_step,
                      self.targets_vocab_dim) +
           mtf.gather(positional_embedding_var, step_num, self.max_length_dim))
      with tf.variable_scope("decoder"):
        x, new_self_attention_k, new_self_attention_v = (
            self._decoder_layer_stack_incremental(
                x,
                step_num,
                encdec_tensors,
                self_attention_k,
                self_attention_v,
                encdec_attention_mask=encoder_attention_mask))
      logits = mtf.matmul(x, softmax_var)
      return logits, new_self_attention_k + new_self_attention_v

    if hparams.beam_size == 1:
      temperature = (0.0 if hparams.sampling_method == "argmax"
                     else hparams.sampling_temp)
      return mtf_beam_search.greedy_decode(
          logits_fn,
          initial_ids,
          temperature=temperature,
          initial_states=initial_kv_states,
          forced_ids=partial_targets,
          use_tpu=hparams.use_tpu)
    else:
      if self.has_input:
        input_length = mtf.reduce_sum(
            mtf.to_float(mtf.cast(inputs, tf.bool)),
            reduced_dim=self.length_dim)
        max_input_length = mtf.reduce_max(input_length)
        decode_length = mtf.cast(
            max_input_length * hparams.decode_length_multiplier
            + hparams.decode_length_constant, tf.int32)
      else:
        decode_length = None
      beams, unused_scores = mtf_beam_search.beam_search(
          logits_fn,
          initial_ids,
          hparams.alpha,
          states=initial_kv_states,
          decode_length=decode_length,
          use_tpu=hparams.use_tpu)
      return mtf.gather(beams, mtf.constant(mesh, 0, dtype=tf.int32), beam_dim)
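
The final line of _sample extracts the best hypothesis: gathering index 0 along beam_dim drops that dimension, leaving [batch, length]. In dense terms, with toy shapes:

import numpy as np

# beams: [batch, beam, length]; beam search returns them sorted best-first.
beams = np.arange(2 * 3 * 4).reshape(2, 3, 4)

# Analogue of mtf.gather(beams, mtf.constant(mesh, 0), beam_dim):
# the beam axis is reduced away, leaving [batch, length].
top_beam = np.take(beams, 0, axis=1)
print(top_beam.shape)  # (2, 4)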
Example #9
  def _mtf_model_fn(self, features, mesh):
    features = copy.copy(features)
    hparams = self._hparams
    targets = tf.to_int32(features["targets"])
    if len(targets.get_shape()) > 2:
      tf.logging.info("targets = %s" % targets)
      targets = tf.squeeze(targets, [2, 3])
    # pad targets to max_length
    def pad_to_max_length(x):
      extra_length = hparams.max_length - tf.shape(x)[1]
      x = tf.pad(x, [[0, 0], [0, extra_length]])
      x = tf.reshape(x, [hparams.batch_size, hparams.max_length])
      return x
    targets = pad_to_max_length(targets)
    for key in ["targets_segmentation", "targets_position",
                "inputs_segmentation", "inputs_position"]:
      if key in features:
        features[key] = pad_to_max_length(features[key])
    shifted_targets = common_layers.shift_right_2d(targets)

    targets = self._import_to_batch_by_length(targets, "targets", mesh, hparams)
    shifted_targets = self._import_to_batch_by_length(
        shifted_targets, "shifted_targets", mesh, hparams)

    if "targets_segmentation" in features:
      # "Packed" dataset - keep the examples from seeing each other.
      targets_segmentation = self._import_to_batch_by_length(
          features["targets_segmentation"], "targets_segmentation",
          mesh, hparams)
      targets_position = self._import_to_batch_by_length(
          features["targets_position"], "targets_position",
          mesh, hparams)
      decoder_self_attention_mask = (
          mtf_layers.attention_mask_autoregressive(
              targets_position, dtype=self.activation_dtype) +
          mtf_layers.attention_mask_same_segment(
              targets_segmentation, dtype=self.activation_dtype))
    else:
      targets_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
      decoder_self_attention_mask = mtf_layers.attention_mask_autoregressive(
          targets_position, dtype=self.activation_dtype)

    def layer_prepostprocess_dropout(x):
      return mtf.dropout(
          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
          noise_shape=mtf.Shape([self.batch_dim, self.model_dim]))

    extra_losses = []
    (inputs_embedding_var,
     targets_embedding_var,
     softmax_var,
     positional_embedding_var) = self._embedding_and_softmax_vars(mesh)
    if self.has_input:
      inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
      inputs = pad_to_max_length(inputs)
      inputs = self._import_to_batch_by_length(inputs, "inputs", mesh, hparams)
      if "inputs_segmentation" in features:
        # "Packed" dataset - keep the examples from seeing each other.
        inputs_segmentation = self._import_to_batch_by_length(
            features["inputs_segmentation"], "inputs_segmentation",
            mesh, hparams)
        inputs_position = self._import_to_batch_by_length(
            features["inputs_position"], "inputs_position",
            mesh, hparams)
        encoder_self_attention_mask = (
            mtf_layers.attention_mask_same_segment(
                inputs_segmentation, dtype=self.activation_dtype))
        encoder_decoder_attention_mask = (
            mtf_layers.attention_mask_same_segment(
                targets_segmentation, inputs_segmentation,
                dtype=self.activation_dtype))
      else:
        inputs_position = mtf.range(mesh, self.length_dim, dtype=tf.int32)
        encoder_self_attention_mask = (
            mtf_layers.attention_mask_ignore_padding(
                inputs, dtype=self.activation_dtype))
        encoder_decoder_attention_mask = encoder_self_attention_mask

      x = (mtf.gather(inputs_embedding_var, inputs, self.inputs_vocab_dim) +
           mtf.gather(positional_embedding_var, inputs_position,
                      self.max_length_dim))
      x = layer_prepostprocess_dropout(x)
      with tf.variable_scope("encoder"):
        x = self._layer_stack(x,
                              hparams.num_encoder_layers,
                              self_attention_mask=encoder_self_attention_mask,
                              losses=extra_losses)
      encoder_output = mtf.rename_dimension(
          x, self.length_dim.name, self.memory_length_dim.name)
    else:
      encoder_output = None
      encoder_decoder_attention_mask = None

    # DECODER
    x = (mtf.gather(
        targets_embedding_var, shifted_targets, self.targets_vocab_dim) +
         mtf.gather(
             positional_embedding_var, targets_position, self.max_length_dim))
    x = layer_prepostprocess_dropout(x)

    # Decoder
    with tf.variable_scope("decoder"):
      x = self._layer_stack(
          x,
          hparams.num_decoder_layers,
          encoder_output=encoder_output,
          self_attention_mask=decoder_self_attention_mask,
          encdec_attention_mask=encoder_decoder_attention_mask,
          losses=extra_losses)
    logits = mtf.matmul(x, softmax_var)
    off_value = hparams.label_smoothing / self._targets_vocab_size
    on_value = 1.0 - hparams.label_smoothing + off_value
    soft_targets = mtf.one_hot(
        targets, self.targets_vocab_dim, on_value=on_value, off_value=off_value,
        dtype=self.activation_dtype)
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, soft_targets, self.targets_vocab_dim)
    weights = mtf_layers.weights_nonzero(
        targets, dtype=self.activation_dtype)
    loss = mtf.reduce_mean(loss * weights)
    for l in extra_losses:
      loss += l
    return logits, loss
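
The label-smoothing arithmetic at the end of _mtf_model_fn spreads label_smoothing of probability mass uniformly over the vocabulary, so each soft-target row still sums to 1. That can be verified in a few lines of NumPy (the vocab size here is made up):

import numpy as np

label_smoothing, vocab_size = 0.1, 8
off_value = label_smoothing / vocab_size
on_value = 1.0 - label_smoothing + off_value

targets = np.array([3, 0, 5])
soft_targets = np.full((len(targets), vocab_size), off_value)
soft_targets[np.arange(len(targets)), targets] = on_value
print(soft_targets.sum(axis=1))  # [1. 1. 1.]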
Example #10
    def mtf_model_fn(self, features, mesh):
        features = copy.copy(features)
        tf.logging.info("features = %s" % features)
        hparams = self._hparams
        activation_dtype = self.set_activation_type()

        # We assume fixed vocab size for targets
        targets_vocab_size = self._problem_hparams.target_modality._vocab_size  # pylint: disable=protected-access
        targets = tf.to_int32(features["targets"])

        # Image preprocessing, reshape into a 1D sequence and shift right.
        length = hparams.img_len * hparams.img_len * hparams.num_channels
        targets = tf.reshape(targets, [hparams.batch_size, length])
        shifted_targets = common_layers.shift_right_2d(targets)

        # Declare all the dimensions
        model_dim = mtf.Dimension("d_model", hparams.hidden_size)
        batch_dim = mtf.Dimension("batch", hparams.batch_size)
        length_dim = mtf.Dimension("length", length)
        max_length_dim = mtf.Dimension("max_length", hparams.max_length)
        filter_dim = mtf.Dimension("d_ff", hparams.d_ff)
        kv_channels = mtf.Dimension("kv_channels", hparams.d_kv)
        heads = mtf.Dimension("heads", hparams.num_heads)

        def import_to_batch_by_length(x, name):
            return mtf.import_tf_tensor(mesh,
                                        x,
                                        mtf.Shape([batch_dim, length_dim]),
                                        name=name)

        def layer_prepostprocess_dropout(x):
            return mtf.dropout(x,
                               keep_prob=1.0 -
                               hparams.layer_prepostprocess_dropout,
                               noise_shape=mtf.Shape([batch_dim, model_dim]))

        targets = import_to_batch_by_length(targets, "targets")
        shifted_targets = import_to_batch_by_length(shifted_targets,
                                                    "shifted_targets")

        extra_losses = []

        # Create targets content and position embeddings.
        targets_vocab_size = 256 * hparams.num_channels
        targets_vocab_dim = mtf.Dimension("vocab", targets_vocab_size)
        outputs_vocab_dim = mtf.Dimension("output_vocab", 256)

        # Create embedding var for targets and positions and do a gather.
        targets_embedding_var = mtf.get_variable(
            mesh,
            "targets_embedding",
            mtf.Shape([targets_vocab_dim, model_dim]),
            initializer=tf.random_normal_initializer(),
            activation_dtype=activation_dtype)

        x = mtf.gather(targets_embedding_var, shifted_targets,
                       targets_vocab_dim)
        # Add positional embeddings
        x += mtf.reshape(
            self.create_positional_emb_2d(targets, max_length_dim, model_dim),
            [length_dim, model_dim])

        # If conditional and input is given, add the input embedding to the target.
        # TODO(nikip): Verify conditional.
        if self.has_input and not hparams.unconditional:
            vocab_size = hparams.num_classes
            inputs_vocab_dim = mtf.Dimension("vocab", vocab_size)
            inputs = tf.squeeze(tf.to_int32(features["inputs"]), [2, 3])
            inputs = import_to_batch_by_length(inputs, "inputs")

            # Input embeddings
            inputs_embedding_var = mtf_layers.embedding(
                mesh,
                "input_embedding",
                mtf.Shape([inputs_vocab_dim, model_dim]),
                activation_dtype=activation_dtype)
            inputs_emb = mtf.gather(inputs_embedding_var, inputs,
                                    inputs_vocab_dim)
            x += inputs_emb

        # Image Transformer Decoder
        # [ self attention - ffn - residual + dropout] x n
        for layer in range(hparams.num_decoder_layers):
            layer_name = "decoder_layer_%d" % layer
            with tf.variable_scope(layer_name):
                # Self attention layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.masked_local_attention_1d(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_self_att"),
                        None,
                        kv_channels,
                        heads,
                        block_length=hparams.block_length,
                        name="self_att"))
                # ffn layer
                x += layer_prepostprocess_dropout(
                    mtf_layers.dense_relu_dense(
                        mtf_layers.layer_norm(x,
                                              model_dim,
                                              name="layer_norm_ffn"),
                        filter_dim,
                        hparams.dropout,
                        dropout_broadcast_dims=[length_dim]))

        x = mtf_layers.layer_norm(x,
                                  model_dim,
                                  name="decoder_final_layer_norm")

        # Calculate the logits and loss.
        logits = mtf_layers.dense(x, outputs_vocab_dim, name="logits")
        soft_targets = mtf.one_hot(targets,
                                   outputs_vocab_dim,
                                   dtype=activation_dtype)
        loss = mtf_layers.softmax_cross_entropy_with_logits(
            logits, soft_targets, outputs_vocab_dim)

        loss = mtf.reduce_mean(loss)
        for l in extra_losses:
            loss += l
        return logits, loss
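
The image preprocessing at the top of this example flattens each image into a single autoregressive sequence of length img_len * img_len * num_channels and then shifts it right by one, so each position is predicted from its predecessors. A quick NumPy rendering of that shift, assuming common_layers.shift_right_2d pads a zero at the front and drops the last column:

import numpy as np

batch_size, img_len, num_channels = 2, 4, 3
length = img_len * img_len * num_channels
targets = np.arange(batch_size * length).reshape(batch_size, length)

# shift_right_2d: prepend a zero column, drop the last position.
shifted_targets = np.pad(targets, [[0, 0], [1, 0]])[:, :-1]
print(shifted_targets[0, :4])  # [0 0 1 2]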