Example #1
  def infer(self,
            features=None,
            decode_length=50,
            beam_size=1,
            top_beams=1,
            alpha=0.0,
            use_tpu=False):
    """Returns the targets and their log probabilities."""
    del decode_length, beam_size, top_beams, alpha, use_tpu
    assert features is not None

    self._fill_problem_hparams_features(features)

    # Run the model
    self.hparams.force_full_predict = True
    with tf.variable_scope(self.name):
      logits, _ = self.model_fn(features)
    assert len(logits.shape) == 5  # [batch, time, 1, 1, vocab]
    logits = tf.squeeze(logits, [2, 3])

    # Compute the log probabilities
    log_probs = common_layers.log_prob_from_logits(logits)

    targets = features["targets"]
    assert len(targets.shape) == 4  # [batch, time, 1, 1]
    targets = tf.squeeze(targets, [2, 3])

    # Slice out the log_probs of the targets
    log_probs = common_layers.index_last_dim_with_indices(log_probs, targets)

    # return log-probs instead of beam-score
    return {"outputs": targets, "scores": log_probs}
Example #2
        def inner_loop(
            i,
            hit_eos,
            next_id,
            next_id_tag,
            decoded_ids,
            decoded_ids_tag,
            cache,
            log_prob,
        ):
            """One step of greedy decoding."""
            logits, logits_tag, cache = symbols_to_logits_fn(
                next_id, next_id_tag, i, cache)
            log_probs = common_layers.log_prob_from_logits(logits)
            temperature = sampling_temperature
            if hparams.sampling_method == 'random_per_example':
                next_id = common_layers.sample_temperature_per_example(
                    logits, temperature, top_k)
            else:
                if hparams.sampling_method == 'argmax':
                    temperature = 0.0
                next_id = common_layers.sample_with_temperature(
                    logits, temperature, top_k)

            if hparams.sampling_method == 'random_per_example':
                next_id_tag = common_layers.sample_temperature_per_example(
                    logits_tag, temperature, top_k)
            else:
                if hparams.sampling_method == 'argmax':
                    temperature = 0.0
                next_id_tag = common_layers.sample_with_temperature(
                    logits_tag, temperature, top_k)

            log_prob_indices = tf.stack(
                [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
            log_prob += tf.gather_nd(
                log_probs, log_prob_indices) * (1 - tf.to_float(hit_eos))
            hit_eos |= tf.equal(next_id, eos_id)

            next_id = tf.expand_dims(next_id, axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            next_id_tag = tf.expand_dims(next_id_tag, axis=1)
            decoded_ids_tag = tf.concat([decoded_ids_tag, next_id_tag], axis=1)

            return (
                i + 1,
                hit_eos,
                next_id,
                next_id_tag,
                decoded_ids,
                decoded_ids_tag,
                cache,
                log_prob,
            )
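
The sampling branches above follow the usual temperature-sampling recipe: temperature 0 reduces to argmax, otherwise ids are drawn from softmax(logits / temperature). A rough sketch of that idea (the helper name is hypothetical, and the top_k filtering done by the real common_layers helpers is omitted):

import tensorflow as tf

def sample_with_temperature_sketch(logits, temperature):
    # logits: [batch, vocab]. Returns one sampled id per example, shape [batch].
    if temperature == 0.0:
        return tf.argmax(logits, axis=-1)              # greedy / argmax decoding
    scaled = logits / temperature                      # <1 sharpens, >1 flattens the distribution
    samples = tf.multinomial(scaled, num_samples=1)    # [batch, 1], int64 ids
    return tf.squeeze(samples, axis=1)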
Example #3
    def infer(self,
              features=None,
              decode_length=50,
              beam_size=1,
              top_beams=1,
              alpha=0.0,
              use_tpu=False):
        """Returns the targets and their log probabilities."""
        del decode_length, beam_size, top_beams, alpha
        assert features is not None

        # Run the model
        self.hparams.force_full_predict = True
        with tf.variable_scope(self.name):
            logits, _ = self.model_fn(features)
        # [batch, time, 1, 1, vocab]. Note: here logits can be [1, 1, 245],
        # which differs from the EVAL-mode logits of shape [1, 1, step, 1, 245].
        assert len(logits.shape) == 5
        logits = tf.squeeze(logits, [2, 3])

        # Compute the log probabilities

        log_probs = common_layers.log_prob_from_logits(logits)

        # Slice out the log_probs of the targets
        targets = features["targets"]
        assert len(targets.shape) == 4  # [batch, time, 1, 1]
        targets = tf.squeeze(targets, [2, 3])
        batch_size, timesteps = common_layers.shape_list(targets)
        vocab_size = common_layers.shape_list(log_probs)[-1]
        flat_targets = tf.reshape(targets, [batch_size * timesteps])
        flat_log_probs = tf.reshape(log_probs,
                                    [batch_size * timesteps, vocab_size])
        flat_indices = tf.stack([
            tf.range(tf.to_int64(batch_size) * tf.to_int64(timesteps)),
            tf.to_int64(flat_targets)
        ],
                                axis=1)

        # log_probs = tf.reshape(
        #     tf.gather_nd(flat_log_probs, flat_indices),
        #     [batch_size, timesteps])

        # Sum over time to get the log_prob of the sequence

        #scores = tf.reduce_sum(log_probs, axis=1)  #[batch,step]

        #return {"outputs": targets, "scores": scores} #origin
        return {"outputs": targets, "scores": log_probs}
Example #4
def compute_uncertainty_reward(logits, predictions):
    """Uncertainty reward based on logits."""
    # TODO(rsepassi): Add support for L1/L2 loss models. Current code only
    # works for softmax models.
    vocab_size = logits.shape[-1]
    assert vocab_size > 1
    log_probs = common_layers.log_prob_from_logits(logits)
    max_log_probs = common_layers.index_last_dim_with_indices(
        log_probs, predictions)
    # Threshold
    neg_log_prob = tf.nn.relu(-max_log_probs - 0.02)
    # Sum across all but the batch dimension
    reduce_dims = list(range(len(neg_log_prob.shape)))[1:]
    summed = tf.reduce_sum(neg_log_prob, axis=reduce_dims)
    return summed / 10
Example #5
def compute_uncertainty_reward(logits, predictions):
  """Uncertainty reward based on logits."""
  # TODO(rsepassi): Add support for L1/L2 loss models. Current code only
  # works for softmax models.
  vocab_size = logits.shape[-1]
  assert vocab_size > 1
  log_probs = common_layers.log_prob_from_logits(logits)
  max_log_probs = common_layers.index_last_dim_with_indices(log_probs,
                                                            predictions)
  # Threshold
  neg_log_prob = tf.nn.relu(-max_log_probs - 0.02)
  # Sum across all but the batch dimension
  reduce_dims = list(range(len(neg_log_prob.shape)))[1:]
  summed = tf.reduce_sum(neg_log_prob, axis=reduce_dims)
  return summed / 10
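
The reward above is zero whenever the model assigns its chosen token a probability of at least exp(-0.02) ≈ 0.98, so only genuinely uncertain steps contribute. A small numeric check of that thresholding with made-up values:

import numpy as np

log_p = np.log(np.array([0.99, 0.60, 0.30]))    # made-up log probs of the chosen tokens
neg_log_prob = np.maximum(-log_p - 0.02, 0.0)   # the relu threshold from the snippet
print(neg_log_prob)                             # [0.   0.49 1.18] -- the confident step adds nothing
print(neg_log_prob.sum() / 10)                  # ~0.17, the scaled uncertainty reward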
Example #6
    def inner_loop(i, hit_eos, next_id, decoded_ids, cache, log_prob):
      """One step of greedy decoding."""
      logits, cache = symbols_to_logits_fn(next_id, i, cache)
      log_probs = common_layers.log_prob_from_logits(logits)
      temperature = (0.0 if hparams.sampling_method == "argmax" else
                     hparams.sampling_temp)
      next_id = common_layers.sample_with_temperature(logits, temperature)
      hit_eos |= tf.equal(next_id, eos_id)

      log_prob_indices = tf.stack(
          [tf.range(tf.to_int64(batch_size)), next_id], axis=1)
      log_prob += tf.gather_nd(log_probs, log_prob_indices)

      next_id = tf.expand_dims(next_id, axis=1)
      decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
      return i + 1, hit_eos, next_id, decoded_ids, cache, log_prob
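
Note a small difference from Example #2: there the per-step log prob is multiplied by (1 - hit_eos), so tokens emitted after <EOS> do not change the sequence score, while here every step keeps accumulating. A tiny sketch of that masking idea with made-up numbers:

import numpy as np

step_log_probs = np.array([-0.1, -0.2, -3.0])    # made-up per-step log probs
hit_eos = np.array([0.0, 0.0, 1.0])              # 1.0 once <EOS> has already been produced
print(np.sum(step_log_probs * (1.0 - hit_eos)))  # -0.3, ignoring the post-EOS step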
Example #7
 def infer(self,
           features=None,
           decode_length=1,
           beam_size=1,
           top_beams=1,
           alpha=0.0,
           use_tpu=False):
     """Predict."""
     features["targets"] = tf.identity(features["inputs"])
     logits, _ = self(features)
     log_probs = common_layers.log_prob_from_logits(logits)
     predictions, scores = common_layers.argmax_with_score(log_probs)
     return {
         "outputs": predictions,
         "scores": scores,
     }
Example #8
 def infer(self,
           features=None,
           decode_length=1,
           beam_size=1,
           top_beams=1,
           alpha=0.0,
           use_tpu=False):
     """Predict."""
     del decode_length, beam_size, top_beams, alpha, use_tpu
     assert features is not None
     logits, _ = self(features)
     assert len(logits.get_shape()) == 5
     logits = tf.squeeze(logits, [1, 2, 3])
     log_probs = common_layers.log_prob_from_logits(logits)
     predictions, scores = common_layers.argmax_with_score(log_probs)
     return {
         "outputs": predictions,
         "scores": scores,
     }
Example #9
 def infer(self,
           features=None,
           decode_length=50,
           beam_size=1,
           top_beams=1,
           alpha=0.0,
           use_tpu=False):
   """Predict."""
   del decode_length, beam_size, top_beams, alpha, use_tpu
   assert features is not None
   logits, _ = self(features)  # pylint: disable=not-callable
   assert len(logits.get_shape()) == 5
   logits = tf.squeeze(logits, [1, 2, 3])
   log_probs = common_layers.log_prob_from_logits(logits)
   predictions, scores = common_layers.argmax_with_score(log_probs)
   return {
       "outputs": predictions,
       "scores": scores,
   }
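
common_layers.argmax_with_score, used in the last three examples, returns each argmax id together with its own log probability. A minimal equivalent sketch (the function name is illustrative, not the library implementation):

import tensorflow as tf

def argmax_with_score_sketch(log_probs):
    # log_probs: [..., vocab]. Returns (ids, scores) with the leading shape of log_probs.
    ids = tf.argmax(log_probs, axis=-1)          # most likely id at every position
    scores = tf.reduce_max(log_probs, axis=-1)   # its log probability
    return ids, scores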
Example #10
    def grow_topk(i, alive_seq, alive_log_probs, states):
        r"""Inner beam search loop.

    This function takes the current alive sequences, and grows them to topk
    sequences where k = 2*beam. We use 2*beam because, we could have beam_size
    number of sequences that might hit <EOS> and there will be no alive
    sequences to continue. With 2*beam_size, this will not happen. This relies
    on the assumption the vocab size is > beam size. If this is true, we'll
    have at least beam_size non <EOS> extensions if we extract the next top
    2*beam words.
    Length penalty is given by = (5+len(decode)/6) ^ -\alpha. Pls refer to
    https://arxiv.org/abs/1609.08144.

    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far [batch_size, beam_size, i+1]
      alive_log_probs: probabilities of these sequences. [batch_size, beam_size]
      states: dict (possibly nested) of decoding states.
    Returns:
      Tuple of
        (Topk sequences extended by the next word,
         The log probs of these sequences,
         The scores with length penalty of these sequences,
         Flags indicating which of these sequences have finished decoding,
         dict of transformed decoding states)
    """
        # Get the logits for all the possible next symbols
        if use_tpu:
            flat_ids = tf.reshape(
                tf.slice(alive_seq, [0, 0, i], [batch_size, beam_size, 1]),
                [batch_size * beam_size, -1])
        else:
            flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

        # (batch_size * beam_size, decoded_length)
        if states:
            flat_states = nest.map_structure(_merge_beam_dim, states)
            flat_logits, flat_states = symbols_to_logits_fn(
                flat_ids, i, flat_states)
            states = nest.map_structure(
                lambda t: _unmerge_beam_dim(t, batch_size, beam_size),
                flat_states)
        else:
            flat_logits = symbols_to_logits_fn(flat_ids)

        logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])

        # Convert logits to normalized log probs
        candidate_log_probs = common_layers.log_prob_from_logits(logits)

        # Multiply the probabilities by the current probabilities of the beam.
        # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
        log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs,
                                                         axis=2)

        length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)

        curr_scores = log_probs / length_penalty
        # Flatten out (beam_size, vocab_size) probs in to a list of possibilities
        flat_curr_scores = tf.reshape(curr_scores,
                                      [-1, beam_size * vocab_size])

        topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, k=beam_size * 2)

        # Recovering the log probs because we will need to send them back
        topk_log_probs = topk_scores * length_penalty

        # Work out what beam the top probs are in.
        topk_beam_index = topk_ids // vocab_size
        topk_ids %= vocab_size  # Unflatten the ids

        if not use_tpu:
            # The next three steps are to create coordinates for tf.gather_nd to pull
            # out the correct sequences from id's that we need to grow.
            # We will also use the coordinates to gather the booleans of the beam
            # items that survived.
            batch_pos = compute_batch_indices(batch_size, beam_size * 2)

            # top beams will give us the actual coordinates to do the gather.
            # stacking will create a tensor of dimension batch * beam * 2, where the
            # last dimension contains the i,j gathering coordinates.
            topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

            # Gather up the most probable 2*beams both for the ids and
            # finished_in_alive bools
            topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
            if states:
                states = nest.map_structure(
                    lambda state: tf.gather_nd(state, topk_coordinates),
                    states)

            # Append the most probable alive
            topk_seq = tf.concat(
                [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
        else:
            # Gather up the most probable 2*beams both for the ids and
            # finished_in_alive bools
            topk_seq = fast_tpu_gather(alive_seq, topk_beam_index)

            if states:
                states = nest.map_structure(
                    lambda state: fast_tpu_gather(state, topk_beam_index),
                    states)

            # Update the most probable alive
            topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
            topk_seq = inplace_ops.alias_inplace_update(
                topk_seq, i + 1, topk_ids)
            topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])

        topk_finished = tf.equal(topk_ids, eos_id)

        return topk_seq, topk_log_probs, topk_scores, topk_finished, states
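
The division by length_penalty above implements the GNMT length normalization: the penalty grows with the number of decoded tokens, which offsets the bias toward shorter sequences (fewer negative log probs summed). A quick numeric sketch of how it scales, with an arbitrarily chosen alpha:

import numpy as np

alpha = 0.6
for length in (1, 5, 10, 20):
    penalty = ((5.0 + length) / 6.0) ** alpha  # same formula as tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)
    print(length, round(penalty, 2))           # 1 -> 1.0, 5 -> 1.36, 10 -> 1.73, 20 -> 2.35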
Example #11
    def _init_env(self):
        FLAGS.use_tpu = False
        #tf.logging.set_verbosity(tf.logging.DEBUG)
        tf.logging.info("Import usr dir from %s", self._usr_dir)
        if self._usr_dir is not None:
            #usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
            usr_dir.import_usr_dir(self._usr_dir)
        tf.logging.info("Start to create hparams,for %s of %s", self._problem, self._hparams_set)

        self._hparams = create_hparams()

        self._hparams_decode = create_decode_hparams(extra_length=self._extra_length,
                                                     batch_size=self._batch_size,
                                                     beam_size=self._beam_size,
                                                     alpha=self._alpha,
                                                     return_beams=self._return_beams,
                                                     write_beam_scores=self._write_beam_scores,
                                                     force_decode_length=self._force_decode_length)



        self.estimator = trainer_lib.create_estimator(
            FLAGS.model,
            self._hparams,
            t2t_trainer.create_run_config(self._hparams),
            decode_hparams=self._hparams_decode,
            use_tpu=False)

        tf.logging.info("Finish intialize environment")

        ####### problem type: classification output, sequence, or language model
        #self.problem_type = self._hparams.problem_hparams[0].target_modality[0] #class? symble
        self.problem_type = self._hparams.problem_hparams.target_modality[0]
        #self._whether_has_inputs = self._hparams.problem[0].has_inputs
        self._whether_has_inputs = self._hparams.problem.has_inputs
        self._beam_size = 1 if self._customer_problem_type == 'classification' else self._beam_size



        ### make input placeholder
        #self._inputs_ph = tf.placeholder(dtype=tf.int32)  # shape not specified,any shape

        # x=tf.placeholder(dtype=tf.int32)
        # x.set_shape([None, None]) # ? -> (?,?)
        # x = tf.expand_dims(x, axis=[2])# -> (?,?,1)
        # x = tf.to_int32(x)
        #self._inputs_ph=x

        #batch_inputs = tf.reshape(self._inputs_ph, [self._batch_size, -1, 1, 1])
        #batch_inputs=x

        # batch_inputs = tf.reshape(self._inputs_ph, [-1, -1, 1, 1])

        batch_inputs, self._targets_ph, self.input_extra_length_ph = get_ph(x_dim_3=True)

        #targets_ph = tf.placeholder(dtype=tf.int32)
        #batch_targets = tf.reshape(targets_ph, [1, -1, 1, 1])
        self._features = {"inputs": batch_inputs,
                    "problem_choice": 0,  # We run on the first problem here.
                    "input_space_id": self._hparams.problem_hparams.input_space_id,
                    "target_space_id": self._hparams.problem_hparams.target_space_id}
        ### Add a variable decode length
        #self.input_extra_length_ph = tf.placeholder(dtype=tf.int32,shape=[])
        self._features['decode_length'] = self.input_extra_length_ph  # total_decode = input_len + extra_len; extra for chunkProblem = 0
        # real_decode_length=len(input)+extra_length
        ##
        #self._features['decode_length_decide_end'] = True

        #### If using transformer_relative hparams
        if self._hparams_set == "transformer_relative":
            del self._features['problem_choice']
            del self._features['input_space_id']
            del self._features['target_space_id']

        if self._customer_problem_type == 'languageModel_pp':
            del self._features['problem_choice']
            del self._features['input_space_id']
            del self._features['target_space_id']
        if self._model_name in ['slice_net', 'transformer_encoder']:
            del self._features['problem_choice']
            del self._features['input_space_id']
            del self._features['target_space_id']
        if self._model_name == 'transformer' and self._customer_problem_type == 'classification':
            del self._features['problem_choice']
            del self._features['input_space_id']
            del self._features['target_space_id']




        ###### target if transformer_scorer
        if self._customer_problem_type == 'classification':
            self._targets_ph = tf.placeholder(tf.int32, shape=(None, None, None, None), name='targets')
            self._features['targets'] = self._targets_ph  # batch targets

        if self._customer_problem_type == 'languageModel_pp':
            self._targets_ph = tf.placeholder(tf.int32, shape=(None, None, None, None), name='targets')
            self._features['targets'] = self._targets_ph


        #### mode
        mode = tf.estimator.ModeKeys.PREDICT
        if self._customer_problem_type == 'languageModel_pp':
            mode = tf.estimator.ModeKeys.EVAL
        elif self._customer_problem_type == 'classification' and 'score' not in self._model_name:
            mode = tf.estimator.ModeKeys.EVAL
        # estimator_spec = model_builder.model_fn(self._model_name, features, mode, self._hparams,
        #                                         problem_names=[self._problem], decode_hparams=self._hparams_dc)
        predictions_dict = self.estimator._call_model_fn(
            self._features, None, mode, t2t_trainer.create_run_config(self._hparams))
        self._predictions_dict = predictions_dict.predictions
        # score -> score_yr
        if self._customer_problem_type == 'classification' and 'score' in self._model_name:
            self._score = predictions_dict.predictions.get('scores')
            if self._score is not None:  # [batch, beam] or [batch]
                self._predictions_dict['scores_class'] = tf.exp(common_layers.log_prob_from_logits(self._score))
        elif self._customer_problem_type == 'classification' and 'score' not in self._model_name:
            self._score = predictions_dict.predictions.get('predictions')
            if self._score is not None:  # [batch, beam] or [batch]
                self._predictions_dict['scores_class'] = tf.exp(common_layers.log_prob_from_logits(self._score))
        #self._predictions = self._predictions_dict["outputs"]
        # self._scores=predictions_dict['scores'] not return when greedy search
        tf.logging.info("Start to init tf session")
        if self._isGpu:
            print('Using GPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self._fraction)
            self._sess = tf.Session(
                config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, gpu_options=gpu_options))
        else:
            print('Using CPU in Decoder')
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0)
            config = tf.ConfigProto(gpu_options=gpu_options)
            config.allow_soft_placement = True
            config.log_device_placement = False
            self._sess = tf.Session(config=config)
        with self._sess.as_default():
            ckpt = saver_mod.get_checkpoint_state(self._model_dir)
            saver = tf.train.Saver(allow_empty=True)
            tf.logging.info("Start to restore the parameters from %s", ckpt.model_checkpoint_path)
            saver.restore(self._sess, ckpt.model_checkpoint_path)
        tf.logging.info("Finish intialize environment")
Example #12
def _advance(step, beam_log_probs, previous_refs,
             area_logits, areas, batch_size, beam_size, append_refs=True,
             condition=None):
  """Advance one element in the tuple for a decoding step.

  Args:
    step: the current decoding step.
    beam_log_probs: [batch_size * beam_size]
    previous_refs: [batch_size * beam_size, input_length - 1, 2]
    area_logits: [batch_size * beam_size, num_areas]
    areas: the areas.
    batch_size: the batch size.
    beam_size: the beam_size.
    append_refs: returning references or ids.
    condition: conditional probability mask in shape [batch_size * beam_size].
  Returns:
    beam_log_probs: [batch_size * beam_size]
    references in shape of [batch_size * beam_size, input_length, 2] or
        ids in shape of [batch_size * beam_size]
  """
  with tf.control_dependencies([
      tf.equal(tf.shape(beam_log_probs), (batch_size * beam_size,))]):
    num_expansions = tf.minimum(beam_size, tf.shape(area_logits)[-1])
    # [batch_size * beam_size, num_expansions]
    area_log_probs = common_layers.log_prob_from_logits(area_logits)
    if condition is not None:
      area_log_probs = area_log_probs * tf.to_float(
          tf.expand_dims(condition, 1))
    top_area_log_probs, top_area_ids = tf.nn.top_k(
        area_log_probs, k=num_expansions)
  if append_refs:
    # [batch_size * beam_size, num_expansions, 2]
    refs = area_utils.area_to_refs(areas["starts"], areas["ends"],
                                   top_area_ids)
    if condition is not None:
      refs = refs * tf.expand_dims(tf.expand_dims(condition, 1), 2)
    refs = tf.reshape(refs, [batch_size, beam_size, num_expansions, 1, 2])
    if step > 0:
      previous_refs = tf.reshape(
          previous_refs, [batch_size, beam_size, 1, step, 2])
      previous_refs = tf.tile(previous_refs, [1, 1, num_expansions, 1, 1])
      new_refs = tf.concat([previous_refs, refs], axis=3)
    else:
      new_refs = refs
    new_refs = tf.reshape(
        new_refs, [batch_size * beam_size * num_expansions, step + 1, 2])
  # [batch_size, beam_size * num_expansions]
  log_probs = tf.reshape(tf.expand_dims(beam_log_probs, 1) + top_area_log_probs,
                         [batch_size, beam_size * num_expansions])
  # [batch_size, beam_size]
  beam_log_probs, beam_indices = tf.nn.top_k(log_probs, k=beam_size)
  beam_indices = tf.reshape(beam_indices, [-1])
  beam_log_probs = tf.reshape(beam_log_probs, [batch_size * beam_size])
  indices = tf.reshape(
      tf.tile(tf.expand_dims(tf.range(batch_size) * beam_size * num_expansions,
                             axis=1), [1, beam_size]), [-1]) + beam_indices
  if append_refs:
    new_refs = tf.gather(new_refs, indices=indices)
  else:
    new_refs = tf.gather(tf.reshape(top_area_ids, [-1]), indices=indices)
  return beam_log_probs, new_refs
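
The index arithmetic at the end of _advance maps each batch row's top_k winners back into the flattened [batch_size * beam_size * num_expansions] layout before gathering. A small sketch of that mapping with toy sizes (all values made up):

import numpy as np

batch_size, beam_size, num_expansions = 2, 3, 3
# Winner positions within each batch row of beam_size * num_expansions candidates,
# as tf.nn.top_k would return them, flattened like tf.reshape(beam_indices, [-1]).
beam_indices = np.array([[0, 4, 7], [2, 3, 8]]).reshape(-1)
row_offsets = np.repeat(np.arange(batch_size) * beam_size * num_expansions, beam_size)
print(row_offsets + beam_indices)  # [ 0  4  7 11 12 17] -- indices into the flat candidate axis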
Example #13
  def grow_topk(i, alive_seq, alive_log_probs, states):
    r"""Inner beam search loop.

    This function takes the current alive sequences and grows them to topk
    sequences where k = 2 * beam_size. We use 2 * beam_size because we could
    have beam_size sequences hit <EOS>, leaving no alive sequences to continue.
    With 2 * beam_size this will not happen. This relies on the assumption that
    the vocab size is > beam size; if that holds, we'll have at least beam_size
    non-<EOS> extensions when we extract the next top 2 * beam_size words.
    The length penalty is ((5 + len(decode)) / 6) ^ alpha, and the score of a
    sequence is its log prob divided by that penalty. Please refer to
    https://arxiv.org/abs/1609.08144.

    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far [batch_size, beam_size, i+1]
      alive_log_probs: probabilities of these sequences. [batch_size, beam_size]
      states: dict (possibly nested) of decoding states.
    Returns:
      Tuple of
        (Topk sequences extended by the next word,
         The log probs of these sequences,
         The scores with length penalty of these sequences,
         Flags indicating which of these sequences have finished decoding,
         dict of transformed decoding states)
    """
    # Get the logits for all the possible next symbols
    flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

    # (batch_size * beam_size, decoded_length)
    if states:
      flat_states = nest.map_structure(_merge_beam_dim, states)
      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i, flat_states)
      states = nest.map_structure(
          lambda t: _unmerge_beam_dim(t, batch_size, beam_size), flat_states)
    else:
      flat_logits = symbols_to_logits_fn(flat_ids)

    logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])

    # Convert logits to normalized log probs
    candidate_log_probs = common_layers.log_prob_from_logits(logits)

    # Multiply the probabilities by the current probabilities of the beam.
    # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
    log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)

    curr_scores = log_probs / length_penalty
    # Flatten out (beam_size, vocab_size) probs in to a list of possibilities
    flat_curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])

    topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, k=beam_size * 2)

    # Recovering the log probs because we will need to send them back
    topk_log_probs = topk_scores * length_penalty

    # Work out what beam the top probs are in.
    topk_beam_index = topk_ids // vocab_size
    topk_ids %= vocab_size  # Unflatten the ids

    # The next three steps are to create coordinates for tf.gather_nd to pull
    # out the correct sequences from id's that we need to grow.
    # We will also use the coordinates to gather the booleans of the beam items
    # that survived.
    batch_pos = compute_batch_indices(batch_size, beam_size * 2)

    # top beams will give us the actual coordinates to do the gather.
    # stacking will create a tensor of dimension batch * beam * 2, where the
    # last dimension contains the i,j gathering coordinates.
    topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

    # Gather up the most probable 2*beams both for the ids and finished_in_alive
    # bools
    topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
    if states:
      states = nest.map_structure(
          lambda state: tf.gather_nd(state, topk_coordinates), states)

    # Append the most probable alive
    topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

    topk_finished = tf.equal(topk_ids, eos_id)

    return topk_seq, topk_log_probs, topk_scores, topk_finished, states
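
In both grow_topk variants the (beam_size, vocab_size) axes are flattened before tf.nn.top_k, so each returned id simultaneously encodes which beam it extends and which vocabulary symbol it appends; the // and % at the end undo that flattening. A short sketch with toy sizes:

import numpy as np

beam_size, vocab_size = 2, 5
flat_scores = np.random.randn(1, beam_size * vocab_size)                 # [batch, beam * vocab]
topk_ids = np.argsort(flat_scores, axis=-1)[:, ::-1][:, :2 * beam_size]  # best 2*beam flat ids
print(topk_ids // vocab_size)  # which alive beam each candidate extends
print(topk_ids % vocab_size)   # which vocabulary symbol it appends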