Example #1
def _top_k_sample(logits, ignore_ids=None, num_samples=1, k=10):
    """
    Does top-k sampling. If ignore_ids is given, those logits are masked out and never sampled.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param num_samples: number of samples to draw
    :param k: top-k cutoff to use, either an int or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_k_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)
        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')

        # mask out everything past the top k. careful we don't want to cut off everything!
        # result will be [batch_size, vocab_perm]
        k_expanded = k if isinstance(k, int) else k[:, None]
        exclude_mask = tf.range(vocab_size)[None] >= k_expanded

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

    return {
        'probs': probs,
        'sample': sample,
    }
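For reference, a minimal standalone sketch of the sort / mask / sample / unsort pattern used above, runnable eagerly in TF 2.x; tf.gather with batch_dims=1 stands in for tf.batch_gather, and the toy logits and k are illustrative assumptions, not part of the original code:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, 1.0, -1.0],
                      [0.1, 3.0, 0.2, 0.3]])  # [batch_size=2, vocab_size=4]
k = 2  # invented cutoff for illustration

indices = tf.argsort(logits, direction='DESCENDING')       # [2, 4]
sorted_logits = tf.gather(logits, indices, batch_dims=1)   # work in sorted space
exclude_mask = tf.range(tf.shape(logits)[1])[None] >= k    # True past the top k
masked = sorted_logits - tf.cast(exclude_mask, tf.float32) * 1e10
sample_perm = tf.random.categorical(masked, num_samples=1)  # index in sorted space
sample = tf.gather(indices, sample_perm, batch_dims=1)      # unsort back to vocab ids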
Example #2
def _top_p_sample(logits, ignore_ids=None, num_samples=1, p=0.9):
    """
    Does top-p (nucleus) sampling. If ignore_ids is given, those logits are masked out and never sampled.
    :param logits: [batch_size, vocab_size] tensor
    :param ignore_ids: [vocab_size] one-hot representation of the indices we'd like to ignore and never predict,
                        like padding maybe
    :param num_samples: number of samples to draw
    :param p: top-p threshold to use, either a float or a [batch_size] vector
    :return: [batch_size, num_samples] samples

    # TODO FIGURE OUT HOW TO DO THIS ON TPUS. IT'S HELLA SLOW RIGHT NOW, DUE TO ARGSORT I THINK
    """
    with tf.variable_scope('top_p_sample'):
        batch_size, vocab_size = get_shape_list(logits, expected_rank=2)

        probs = tf.nn.softmax(logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                              axis=-1)

        if isinstance(p, float) and p > 0.999999:
            # Don't do top-p sampling in this case
            print("Top-p sampling DISABLED", flush=True)
            return {
                'probs': probs,
                'sample': tf.random.categorical(
                    logits=logits if ignore_ids is None else logits - tf.cast(ignore_ids[None], tf.float32) * 1e10,
                    num_samples=num_samples, dtype=tf.int32),
            }

        # [batch_size, vocab_perm]
        indices = tf.argsort(probs, direction='DESCENDING')
        cumulative_probabilities = tf.math.cumsum(tf.batch_gather(probs, indices), axis=-1, exclusive=False)

        # find the top-p cutoff index. careful we don't want to cut off everything!
        # result will be [batch_size, vocab_perm]
        p_expanded = p if isinstance(p, float) else p[:, None]
        exclude_mask = tf.logical_not(
            tf.logical_or(cumulative_probabilities < p_expanded, tf.range(vocab_size)[None] < 1))

        # OPTION A - sample in the sorted space, then unsort.
        logits_to_use = tf.batch_gather(logits, indices) - tf.cast(exclude_mask, tf.float32) * 1e10
        sample_perm = tf.random.categorical(logits=logits_to_use, num_samples=num_samples)
        sample = tf.batch_gather(indices, sample_perm)

        # OPTION B - unsort first - Indices need to go back to 0 -> N-1 -- then sample
        # unperm_indices = tf.argsort(indices, direction='ASCENDING')
        # include_mask_unperm = tf.batch_gather(include_mask, unperm_indices)
        # logits_to_use = logits - (1 - tf.cast(include_mask_unperm, tf.float32)) * 1e10
        # sample = tf.random.categorical(logits=logits_to_use, num_samples=num_samples, dtype=tf.int32)

    return {
        'probs': probs,
        'sample': sample,
    }
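A small sketch of how the cumulative-probability mask above behaves, with tf.gather(batch_dims=1) in place of tf.batch_gather; the probabilities and p=0.9 are made-up values:

import tensorflow as tf

probs = tf.constant([[0.5, 0.3, 0.15, 0.05]])
indices = tf.argsort(probs, direction='DESCENDING')
cum = tf.math.cumsum(tf.gather(probs, indices, batch_dims=1), axis=-1)
keep = tf.logical_or(cum < 0.9, tf.range(4)[None] < 1)  # always keep the best token
# keep == [[True, True, False, False]]: tokens 0 and 1 survive at p=0.9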
Example #3
  def test_posterior_mode_invariance_states(self):
    observation_probs_data = tf.constant([[0.12, 0.48, 0.5, 0.1],
                                          [0.4, 0.1, 0.5, 0.0],
                                          [0.1, 0.2, 0.3, 0.4]],
                                         dtype=self.dtype)
    transition_matrix_data = tf.constant([[0.21, 0.49, 0.3],
                                          [0.18, 0.12, 0.7],
                                          [0.75, 0.15, 0.1]],
                                         dtype=self.dtype)
    initial_prob_data = tf.constant([0.8, 0.13, 0.07],
                                    dtype=self.dtype)

    (initial_prob, transition_matrix,
     observation_probs) = self.make_placeholders([
         initial_prob_data, transition_matrix_data,
         observation_probs_data])

    permutations = tf.identity(np.array([np.random.permutation(3)
                                         for _ in range(8)]))
    inverse_permutations = tf.argsort(permutations)

    initial_prob_permuted = tf.gather(initial_prob, inverse_permutations)

    # Permute rows of observation matrix
    observation_probs_permuted = tf.gather(observation_probs,
                                           inverse_permutations)

    # Permute both rows and columns of transition matrix
    transition_matrix_permuted = tf.transpose(
        a=tf.gather(tf.transpose(a=transition_matrix),
                    inverse_permutations),
        perm=[0, 2, 1])
    transition_matrix_permuted = tf1.batch_gather(transition_matrix_permuted,
                                                  inverse_permutations)

    observations = tf.constant([1, 0, 3, 1, 3, 0, 2, 1, 2, 1, 3, 0, 0, 1, 1, 2])

    [num_steps] = self.make_placeholders([16])
    model = tfd.HiddenMarkovModel(
        tfd.Categorical(probs=initial_prob_permuted),
        tfd.Categorical(probs=transition_matrix_permuted),
        tfd.Categorical(probs=observation_probs_permuted),
        num_steps=num_steps)

    inferred_states = model.posterior_mode(observations)
    expected_states = [0, 1, 2, 0, 2, 1, 2, 0, 2, 0, 2, 0, 1, 2, 0, 1]
    expected_states_permuted = tf.transpose(
        a=tf1.batch_gather(
            tf.expand_dims(tf.transpose(
                a=permutations), axis=-1), expected_states)[..., 0])
    self.assertAllEqual(inferred_states, expected_states_permuted)
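The test relies on tf.argsort to invert a permutation. A minimal sketch of that identity, with invented values, assuming eager TF:

import tensorflow as tf

perm = tf.constant([2, 0, 1])
inv = tf.argsort(perm)                         # [1, 2, 0]
x = tf.constant([10., 20., 30.])
restored = tf.gather(tf.gather(x, perm), inv)  # == x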
Example #4
    def sample(self, num_samples=1):
        """Sample from the rejection sampling distribution.

    For ease of implementation, draw the maximum number of proposal samples.

    Args:
      num_samples: integer, number of samples to draw.

    Returns:
      samples: Tensor of samples from the distribution, [num_samples] + data_dim
    """
        flat_proposal_samples = self.proposal.sample(num_samples * self.T)
        proposal_samples = tf.reshape(flat_proposal_samples,
                                      [num_samples, self.T] + self.data_dim)
        flat_logit_accept = self.logit_accept_fn(flat_proposal_samples)
        logit_accept = tf.reshape(flat_logit_accept, [num_samples, self.T])
        accept_samples = tfd.Bernoulli(logits=logit_accept[:, :-1]).sample()

        # Add forced accept to last sample to ensure truncation
        accept_samples = tf.concat([
            accept_samples,
            tf.ones([num_samples, 1], dtype=accept_samples.dtype)
        ],
                                   axis=-1)

        # For each of sample_shape, find the first nonzero accept
        def get_first_nonzero_index(t):
            # t is batch_dims + [T], t is binary.
            _, indices = tf.math.top_k(t, k=1, sorted=False)
            return indices

        accept_indices = get_first_nonzero_index(
            accept_samples)  # sample_shape
        samples = tf.batch_gather(proposal_samples, accept_indices)
        return tf.squeeze(samples, axis=1)  # Squeeze the selected dim
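get_first_nonzero_index works because tf.math.top_k breaks ties in favor of the lower index, so k=1 on a binary tensor returns the first accept. A toy sketch (the accept patterns are invented):

import tensorflow as tf

accepts = tf.constant([[0, 0, 1, 0, 1],
                       [1, 0, 0, 0, 0]])
_, first_accept = tf.math.top_k(accepts, k=1)  # ties broken by lowest index
# first_accept == [[2], [0]]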
Example #5
def gaussian_mixture_approximate_mode(gm):
    """Returns the mean of the most probable mixture component."""
    # Find the most likely mixture component.
    mode_alpha = gm.mixture_distribution.mode()[Ellipsis, None]
    mus = gm.components_distribution.mean()
    # Gather the mean of the most likely component.
    return tf.squeeze(tf.batch_gather(mus, mode_alpha), axis=-2)
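A hypothetical end-to-end usage sketch, assuming TensorFlow Probability is available; tf.gather with batch_dims=1 replaces tf.batch_gather, and the two-component mixture parameters are invented:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
gm = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(probs=[[0.2, 0.8]]),      # batch of 1
    components_distribution=tfd.MultivariateNormalDiag(
        loc=[[[0., 0.], [5., 5.]]]))

mode_alpha = gm.mixture_distribution.mode()[..., None]             # [1, 1]
mus = gm.components_distribution.mean()                            # [1, 2, 2]
approx_mode = tf.squeeze(tf.gather(mus, mode_alpha, batch_dims=1), axis=-2)
# approx_mode == [[5., 5.]], the mean of the more probable component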
Example #6
def initialize_from_context(initial_context,
                            ignore_ids,
                            news_config,
                            p_for_topp=0.95,
                            k_for_topk=100,
                            do_topk=False):
    """ same signature as sample_step"""
    batch_size, _ = get_shape_list(initial_context, expected_rank=2)

    context_output = sample_step(tokens=initial_context,
                                 ignore_ids=ignore_ids,
                                 news_config=news_config,
                                 batch_size=batch_size,
                                 p_for_topp=p_for_topp,
                                 k_for_topk=k_for_topk,
                                 cache=None,
                                 do_topk=do_topk)
    model = context_output['model']

    gt_logprobs = tf.squeeze(tf.batch_gather(model.log_probs[:, :-1],
                                             model.input_ids[:, 1:, None]),
                             axis=2)

    return {
        'tokens':
        tf.concat([initial_context, context_output['new_tokens'][:, None]], 1),
        'cache':
        context_output['new_cache'],
        'probs':
        tf.concat([tf.exp(gt_logprobs), context_output['new_probs'][:, None]],
                  axis=1)
    }
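The gt_logprobs gather above picks, for each position, the log-probability assigned to the observed next token. A toy sketch of that indexing pattern, with tf.gather(batch_dims=2) standing in for tf.batch_gather and invented values:

import tensorflow as tf

log_probs = tf.math.log([[[0.7, 0.3], [0.2, 0.8]]])   # [batch=1, seq=2, vocab=2]
targets = tf.constant([[0, 1]])                        # [batch, seq]
gt_logprobs = tf.squeeze(
    tf.gather(log_probs, targets[:, :, None], batch_dims=2), axis=2)
# gt_logprobs == log([[0.7, 0.8]])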
Example #7
def sample_step(tokens,
                ignore_ids,
                news_config,
                batch_size=1,
                p_for_topp=0.95,
                k_for_topk=100,
                cache=None,
                do_topk=False):
    """
    Helper function that samples from grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p threshold, used when do_topk is False
    :param k_for_topk: top-k cutoff, used when do_topk is True
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :return: new_tokens, size [batch_size]
             new_probs, also size [batch_size]
             new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
                   news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )

    # Extract the FINAL SEQ LENGTH
    batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat,
                                                             expected_rank=2)
    next_logits = tf.reshape(model.logits_flat,
                             [batch_size, -1, vocab_size])[:, -1]

    if do_topk:
        sample_info = _top_k_sample(next_logits,
                                    num_samples=1,
                                    k=tf.cast(k_for_topk, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits,
                                    ignore_ids=ignore_ids,
                                    num_samples=1,
                                    p=p_for_topp)

    new_tokens = tf.squeeze(sample_info['sample'], 1)
    new_probs = tf.squeeze(
        tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)
    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
Example #8
def fast_tpu_gather(params, indices, name=None):
    """Fast gather implementation for models running on TPU.

  This function uses one_hot and batch matmul to do the gather, which is
  faster than gather_nd on TPU. For params with an integer dtype (e.g.
  sequences to gather from), batch_gather is used instead to preserve accuracy.

  Args:
    params: A tensor from which to gather values.
      [batch_size, original_size, ...]
    indices: A tensor used as the index to gather values.
      [batch_size, selected_size].
    name: A string, name of the operation (optional).

  Returns:
    gather_result: A tensor that has the same rank as params.
      [batch_size, selected_size, ...]
  """
    with tf.name_scope(name):
        dtype = params.dtype

        def _gather(params, indices):
            """Fast gather using one_hot and batch matmul."""
            if dtype != tf.float32:
                params = tf.to_float(params)
            shape = common_layers.shape_list(params)
            indices_shape = common_layers.shape_list(indices)
            ndims = params.shape.ndims
            # Adjust the shape of params to match one-hot indices, which is the
            # requirement of Batch MatMul.
            if ndims == 2:
                params = tf.expand_dims(params, axis=-1)
            if ndims > 3:
                params = tf.reshape(params, [shape[0], shape[1], -1])
            gather_result = tf.matmul(
                tf.one_hot(indices, shape[1], dtype=params.dtype), params)
            if ndims == 2:
                gather_result = tf.squeeze(gather_result, axis=-1)
            if ndims > 3:
                shape[1] = indices_shape[1]
                gather_result = tf.reshape(gather_result, shape)
            if dtype != tf.float32:
                gather_result = tf.cast(gather_result, dtype)
            return gather_result

        # If the dtype is an integer, use gather instead of the one_hot matmul to
        # avoid precision loss. The max int value that can be represented exactly
        # by bfloat16 in the MXU is 256, which is smaller than the possible id
        # values. Encoding/decoding could potentially be used to make it work,
        # but the benefit is small right now.
        if dtype.is_integer:
            gather_result = tf.batch_gather(params, indices)
        else:
            gather_result = _gather(params, indices)

        return gather_result
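A quick sketch verifying the one_hot-matmul trick against a plain batched gather, assuming eager TF 2.x; the shapes are arbitrary:

import tensorflow as tf

params = tf.random.uniform([2, 5, 3])        # [batch, original_size, depth]
indices = tf.constant([[4, 0], [1, 1]])      # [batch, selected_size]

via_matmul = tf.matmul(tf.one_hot(indices, 5), params)   # [2, 2, 3]
via_gather = tf.gather(params, indices, batch_dims=1)    # [2, 2, 3]
tf.debugging.assert_near(via_matmul, via_gather)         # both select the same rows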
Example #9
 def gather_neighbour(pc, neighbor_idx):
     # gather the coordinates or features of neighboring points
     batch_size = tf.shape(pc)[0]
     num_points = tf.shape(pc)[1]
     d = pc.get_shape()[2].value
     index_input = tf.reshape(neighbor_idx, shape=[batch_size, -1])
     features = tf.batch_gather(pc, index_input)
     features = tf.reshape(
         features, [batch_size, num_points,
                    tf.shape(neighbor_idx)[-1], d])
     return features
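A hypothetical usage sketch with invented shapes, using tf.gather with batch_dims=1 in place of tf.batch_gather:

import tensorflow as tf

pc = tf.random.uniform([1, 4, 3])                               # [B, N, d]
neighbor_idx = tf.constant([[[0, 1], [1, 2], [2, 3], [3, 0]]])  # [B, N, K]
flat_idx = tf.reshape(neighbor_idx, [1, -1])                    # [B, N*K]
features = tf.gather(pc, flat_idx, batch_dims=1)                # [B, N*K, d]
features = tf.reshape(features, [1, 4, 2, 3])                   # [B, N, K, d]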
Example #10
def _compute_calibration_bin_statistics(num_bins,
                                        logits=None,
                                        labels_true=None,
                                        labels_predicted=None):
    """Compute binning statistics required for calibration measures.

  Args:
    num_bins: int, number of probability bins, e.g. 10.
    logits: Tensor, (n,nlabels), with logits for n instances and nlabels.
    labels_true: Tensor, (n,), with tf.int32 or tf.int64 elements containing
      ground truth class labels in the range [0, nlabels).
    labels_predicted: Tensor, (n,), with tf.int32 or tf.int64 elements
      containing decisions of the predictive system.  If `None`, we will use
      the argmax decision using the `logits`.

  Returns:
    bz: Tensor, shape (2,num_bins), tf.int32, counts of incorrect (row 0) and
      correct (row 1) predictions in each of the `num_bins` probability bins.
    pmean_observed: Tensor, shape (num_bins,), tf.float32, the mean predictive
      probabilities in each probability bin.
  """

    if labels_predicted is None:
        # If no labels are provided, we take the label with the maximum probability
        # decision.  This corresponds to the optimal expected minimum loss decision
        # under 0/1 loss.
        pred_y = tf.argmax(logits, axis=1, output_type=labels_true.dtype)
    else:
        pred_y = labels_predicted

    correct = tf.cast(tf.equal(pred_y, labels_true), tf.int32)

    # Collect predicted probabilities of decisions
    pred = tf.nn.softmax(logits, axis=1)
    prob_y = tf1.batch_gather(pred, pred_y[:, tf.newaxis])  # p(pred_y | x)
    prob_y = tf.reshape(prob_y, (tf.size(prob_y), ))

    # Compute b/z histogram statistics:
    # bz[0,bin] contains counts of incorrect predictions in the probability bin.
    # bz[1,bin] contains counts of correct predictions in the probability bin.
    bins = tf.histogram_fixed_width_bins(prob_y, [0.0, 1.0], nbins=num_bins)
    event_bin_counts = tf.math.bincount(correct * num_bins + bins,
                                        minlength=2 * num_bins,
                                        maxlength=2 * num_bins)
    event_bin_counts = tf.reshape(event_bin_counts, (2, num_bins))

    # Compute mean predicted probability value in each of the `num_bins` bins
    pmean_observed = tf.math.unsorted_segment_sum(prob_y, bins, num_bins)
    tiny = np.finfo(dtype_util.as_numpy_dtype(logits.dtype)).tiny
    pmean_observed = pmean_observed / (
        tf.cast(tf.reduce_sum(event_bin_counts, axis=0), logits.dtype) + tiny)

    return event_bin_counts, pmean_observed
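The two-row histogram above comes from a single bincount over the combined key correct * num_bins + bin. A toy, self-contained sketch of that trick (values invented):

import tensorflow as tf

num_bins = 4
correct = tf.constant([1, 0, 1, 1, 0])   # 1 = correct prediction
bins = tf.constant([0, 0, 3, 2, 3])      # probability bin of each prediction
counts = tf.math.bincount(correct * num_bins + bins,
                          minlength=2 * num_bins, maxlength=2 * num_bins)
counts = tf.reshape(counts, (2, num_bins))
# counts[0] = incorrect per bin: [1, 0, 0, 1]
# counts[1] = correct per bin:   [1, 0, 1, 1]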
Example #11
 def sample(self, num_samples=1):
     """Sample from the model."""
     flat_proposal_samples = self.proposal.sample(num_samples * self.K)
     proposal_samples = tf.reshape(flat_proposal_samples,
                                   [num_samples, self.K] + self.data_dim)
     log_energy = tf.reshape(
         tf.squeeze(self.energy_fn(flat_proposal_samples), axis=-1),
         [num_samples, self.K])
     indexes = tfd.Categorical(logits=log_energy).sample()  # [num_samples]
     samples = tf.batch_gather(proposal_samples,
                               tf.expand_dims(indexes, axis=-1))
     return tf.squeeze(samples, axis=1)  # Squeeze the selected dim
Example #12
 def nearest_interpolation(feature, interp_idx):
     """
     :param feature: [B, N, 1, d] input features matrix (the dummy axis is squeezed below)
     :param interp_idx: [B, up_num_points, 1] nearest neighbour index
     :return: [B, up_num_points, 1, d] interpolated features matrix
     """
     feature = tf.squeeze(feature, axis=2)
     batch_size = tf.shape(interp_idx)[0]
     up_num_points = tf.shape(interp_idx)[1]
     interp_idx = tf.reshape(interp_idx, [batch_size, up_num_points])
     interpolated_features = tf.batch_gather(feature, interp_idx)
     interpolated_features = tf.expand_dims(interpolated_features, axis=2)
     return interpolated_features
Example #13
  def _build_verified_loss(self, labels):
    """Build verified loss using an upper bound on specification."""
    if not self._specification:
      self._verified_loss = tf.constant(0.)
      self._interval_bounds_accuracy = tf.constant(0.)
      return
    # Interval bounds.
    bounds = self._get_specification_bounds()
    # Select specifications.
    if self._interval_bounds_loss_mode == 'all':
      pass  # Keep bounds the way it is.
    elif self._interval_bounds_loss_mode == 'most':
      bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
    elif self._interval_bounds_loss_mode == 'random':
      idx = tf.random.uniform(
          [tf.shape(bounds)[0], self._interval_bounds_loss_n],
          0, tf.shape(bounds)[1], dtype=tf.int32)
      bounds = tf.batch_gather(bounds, idx)
    else:
      assert self._interval_bounds_loss_mode == 'least'
      # This picks the least violated constraint.
      mask = tf.cast(bounds < 0., tf.float32)
      smallest_violation = tf.reduce_min(
          bounds + mask * _BIG_NUMBER, axis=1, keepdims=True)
      has_violations = tf.less(
          tf.reduce_sum(mask, axis=1, keepdims=True) + .5,
          tf.cast(tf.shape(bounds)[1], tf.float32))
      largest_bounds = tf.reduce_max(bounds, axis=1, keepdims=True)
      bounds = tf.where(has_violations, smallest_violation, largest_bounds)

    if self._interval_bounds_loss_type == 'xent':
      v = tf.concat(
          [bounds, tf.zeros([tf.shape(bounds)[0], 1], dtype=bounds.dtype)],
          axis=1)
      l = tf.concat(
          [tf.zeros_like(bounds),
           tf.ones([tf.shape(bounds)[0], 1], dtype=bounds.dtype)],
          axis=1)
      self._verified_loss = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits_v2(
              labels=tf.stop_gradient(l), logits=v))
    elif self._interval_bounds_loss_type == 'softplus':
      self._verified_loss = tf.reduce_mean(
          tf.nn.softplus(bounds + self._interval_bounds_hinge_margin))
    else:
      assert self._interval_bounds_loss_type == 'hinge'
      self._verified_loss = tf.reduce_mean(
          tf.maximum(bounds, -self._interval_bounds_hinge_margin))
Example #14
def compute_joint_mlp_logits(sequence, max_span_length):
    """Computes joint span (start, end) logits from sequence input."""
    batch_size, seq_length, hidden_size = modeling.get_shape_list(
        sequence, expected_rank=3)

    projection_size = hidden_size  # This seems to be a reasonable setting.

    with tf.variable_scope("joint_span"):
        projection = tf.layers.dense(
            sequence,
            projection_size * 2,
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
            name="projection")
        start_projection, end_projection = tf.split(projection, 2, axis=-1)

        # 1. The start representations are tiled max_answer_length times.
        # TODO(danielandor): Use the mask to compute an optimal span list.
        starts = tf.reshape(start_projection,
                            [batch_size * seq_length, 1, projection_size])
        starts = tf.tile(starts, [1, max_span_length, 1])
        starts = tf.reshape(
            starts,
            [batch_size, seq_length * max_span_length, projection_size])

        # 2. To make the end representations, we compute band diagonal indices and
        #    perform a batched gather.
        seqs = tf.expand_dims(tf.range(seq_length), 1)
        offsets = tf.expand_dims(tf.range(max_span_length), 0)
        indices = seqs + offsets  # uses broadcasting
        indices.shape.assert_is_compatible_with((seq_length, max_span_length))
        indices = tf.reshape(indices, [1, seq_length * max_span_length])
        indices = tf.tile(indices, [batch_size, 1])
        indices = tf.minimum(indices, seq_length - 1)  # clips indices
        ends = tf.batch_gather(end_projection, indices)

        # 3. The final step adds the starts and ends.
        ends.shape.assert_is_compatible_with(starts.shape)
        inputs = starts + ends
        inputs = modeling.gelu(inputs)  # Bias is already in the projection.
        inputs = contrib_layers.layer_norm(inputs)
        start_logits = tf.layers.dense(
            inputs,
            1,
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
            name="logits")
    return tf.reshape(start_logits, [batch_size, seq_length, max_span_length])
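A small sketch of the band-diagonal index construction above, with invented seq_length and max_span_length, showing how each start position pairs with its candidate end positions:

import tensorflow as tf

seq_length, max_span_length = 4, 2
seqs = tf.expand_dims(tf.range(seq_length), 1)           # [[0], [1], [2], [3]]
offsets = tf.expand_dims(tf.range(max_span_length), 0)   # [[0, 1]]
indices = tf.minimum(seqs + offsets, seq_length - 1)     # clip to the last token
# indices == [[0, 1], [1, 2], [2, 3], [3, 3]]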
Example #15
 def random_sample(feature, pool_idx):
     """
     :param feature: [B, N, 1, d] input features matrix (the dummy axis is squeezed below)
     :param pool_idx: [B, N', max_num] N' < N, N' is the selected position after pooling
     :return: pool_features = [B, N', 1, d] pooled features matrix
     """
     feature = tf.squeeze(feature, axis=2)
     num_neigh = tf.shape(pool_idx)[-1]
     d = feature.get_shape()[-1]
     batch_size = tf.shape(pool_idx)[0]
     pool_idx = tf.reshape(pool_idx, [batch_size, -1])
     pool_features = tf.batch_gather(feature, pool_idx)
     pool_features = tf.reshape(pool_features,
                                [batch_size, -1, num_neigh, d])
     pool_features = tf.reduce_max(pool_features, axis=2, keepdims=True)
     return pool_features
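A hypothetical usage sketch with invented shapes, mirroring the squeeze / gather / reshape / reduce_max pipeline above (tf.gather with batch_dims=1 in place of tf.batch_gather):

import tensorflow as tf

feature = tf.random.uniform([1, 6, 1, 8])              # [B, N, 1, d]
pool_idx = tf.constant([[[0, 1, 2], [3, 4, 5]]])       # [B, N'=2, max_num=3]
squeezed = tf.squeeze(feature, axis=2)                 # [B, N, d]
flat_idx = tf.reshape(pool_idx, [1, -1])               # [B, N'*max_num]
pooled = tf.gather(squeezed, flat_idx, batch_dims=1)   # [B, N'*max_num, d]
pooled = tf.reshape(pooled, [1, 2, 3, 8])              # [B, N', max_num, d]
pooled = tf.reduce_max(pooled, axis=2, keepdims=True)  # [B, N', 1, d]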
Example #16
    def positional_sampling(self,
                            layer_input,
                            feature_dimension,
                            name='positional_sampling'):
        featuremap = layer_input[0]
        batch_indices = layer_input[1]
        grid = layer_input[2]

        shape_grid = tf.shape(grid)

        featuremap_flat = tf.reshape(featuremap,
                                     [shape_grid[0], -1, feature_dimension])
        batch_indices_flat = tf.reshape(batch_indices, [shape_grid[0], -1])
        batch_ps_flat = tf.batch_gather(featuremap_flat, batch_indices_flat)

        b, h, w, c = shape_grid[0], shape_grid[1], shape_grid[
            2], feature_dimension
        return tf.reshape(batch_ps_flat, [b, h, w, c])
Example #17
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        input_ids = features["input_ids"]

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = GroverModel(
            config=config,
            is_training=is_training,
            input_ids=input_ids,
            pad_token_id=config.pad_token_id,
            chop_off_last_token=True,
        )

        total_loss = model.lm_loss()

        if is_training:
            train_op, train_metrics = optimization_adafactor.create_optimizer(
                total_loss, learning_rate, num_train_steps, num_warmup_steps,
                use_tpu)
            tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        else:
            train_op = None
            train_metrics = {}
            tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map,
             initialized_variable_names) = get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                            init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            if use_tpu:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    host_call=construct_scalar_host_call(
                        metric_dict=train_metrics,
                        model_dir=params['model_dir'],
                        prefix='training/'),
                    scaffold_fn=scaffold_fn)
            else:
                output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    training_hooks=[
                        tf.train.LoggingTensorHook(
                            {'loss': tf.metrics.mean(total_loss)[1]},
                            every_n_iter=100)
                    ],
                    scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(total_loss):
                loss = tf.metrics.mean(values=total_loss)
                return {
                    "eval_loss": loss,
                }

            eval_metrics = (metric_fn, [total_loss])
            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            gt_logprobs = tf.squeeze(tf.batch_gather(
                model.log_probs, model.target_ids[:, :, None]),
                                     axis=2)

            # Compute the top-p threshold that would be needed to sample the
            # ground-truth token.
            better_than_gt = model.log_probs > gt_logprobs[:, :, None]
            top_p_required = tf.reduce_sum(
                tf.cast(better_than_gt, tf.float32) * tf.exp(model.log_probs),
                axis=2)

            # No top-p sampling for now, since this seems to be too slow on TPUs
            if use_tpu:
                predictions = tf.reshape(
                    tf.random.categorical(logits=model.logits_flat,
                                          num_samples=1),
                    get_shape_list(model.target_ids),
                )
            else:
                # Argmax
                # predictions = tf.math.argmax(model.log_probs, axis=-1, output_type=tf.int32)
                predictions = tf.reshape(
                    _top_p_sample(model.logits_flat, num_samples=1,
                                  p=0.99)['sample'],
                    get_shape_list(model.target_ids),
                )
            pred_logprobs = tf.squeeze(tf.batch_gather(model.log_probs,
                                                       predictions[:, :,
                                                                   None]),
                                       axis=2)

            output_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions={
                    'gt_logprobs': gt_logprobs,
                    'top_p_required': top_p_required,
                    'predictions': predictions,
                    'pred_logprobs': pred_logprobs,
                    'labels': input_ids
                },
                scaffold_fn=scaffold_fn)
        return output_spec
Example #18
def define_ppo_step(data_points, hparams, action_space, lr, epoch=-1,
                    distributional_size=1, distributional_subscale=0.04):
  """Define ppo step."""
  del distributional_subscale
  (observation, action, discounted_reward, discounted_reward_probs,
   norm_advantage, old_pdf) = data_points

  obs_shape = common_layers.shape_list(observation)
  observation = tf.reshape(
      observation, [obs_shape[0] * obs_shape[1]] + obs_shape[2:]
  )
  (logits, new_value) = get_policy(observation, hparams, action_space,
                                   epoch=epoch,
                                   distributional_size=distributional_size)
  logits = tf.reshape(logits, obs_shape[:2] + [action_space.n])
  new_policy_dist = tfp.distributions.Categorical(logits=logits)

  new_pdf = new_policy_dist.prob(action)

  ratio = new_pdf / old_pdf
  clipped_ratio = tf.clip_by_value(ratio, 1 - hparams.clipping_coef,
                                   1 + hparams.clipping_coef)

  surrogate_objective = tf.minimum(clipped_ratio * norm_advantage,
                                   ratio * norm_advantage)
  policy_loss = -tf.reduce_mean(surrogate_objective)

  if distributional_size > 1:
    new_value = tf.reshape(new_value, obs_shape[:2] + [distributional_size])
    new_value = tf.nn.log_softmax(new_value, axis=-1)
    value_shape = common_layers.shape_list(new_value)
    # The above is the new value distribution. We are also given, as the
    # discounted reward, the target value distribution and its probabilities.
    # The given discounted reward is already rounded to integers, but with the
    # range scaled up 2x for greater fidelity. Increase the range of new_value
    # here to match.
    new_value_shifted = tf.concat([new_value[1:], new_value[-1:]], axis=0)
    new_value_mean = (new_value + new_value_shifted) / 2
    new_value = tf.concat([tf.expand_dims(new_value, axis=-1),
                           tf.expand_dims(new_value_mean, axis=-1)], -1)
    new_value = tf.reshape(new_value, value_shape[:-1] + [2 * value_shape[-1]])
    # Cast discounted reward to integers and gather the new log-probs for them.
    discounted_reward = tf.cast(discounted_reward, tf.int32)
    value_loss = tf.batch_gather(new_value, discounted_reward)
    # Weight the gathered (new) log-probs by the old probabilities.
    discounted_reward_probs = tf.expand_dims(discounted_reward_probs, axis=1)
    value_loss = - tf.reduce_sum(value_loss * discounted_reward_probs, axis=-1)
    # Take the mean over batch and time as final loss, multiply by coefficient.
    value_loss = hparams.value_loss_coef * tf.reduce_mean(value_loss)
  else:
    new_value = tf.reshape(new_value, obs_shape[:2])
    value_error = new_value - discounted_reward
    value_loss = hparams.value_loss_coef * tf.reduce_mean(value_error ** 2)

  entropy = new_policy_dist.entropy()
  entropy_loss = -hparams.entropy_loss_coef * tf.reduce_mean(entropy)

  losses = [policy_loss, value_loss, entropy_loss]
  loss = sum(losses)
  variables = tf.global_variables(hparams.policy_network + "/.*")
  train_op = optimize.optimize(loss, lr, hparams, variables=variables)

  with tf.control_dependencies([train_op]):
    return [tf.identity(x) for x in losses]
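The distributional branch gathers the log-probabilities of the integer discounted returns from the categorical value head. A toy sketch of that gather, with invented shapes and tf.gather(batch_dims=1) in place of tf.batch_gather:

import tensorflow as tf

value_log_probs = tf.nn.log_softmax(tf.random.uniform([4, 10]), axis=-1)  # [batch, atoms]
returns = tf.constant([[2, 3], [0, 1], [5, 6], [9, 9]])                   # [batch, 2]
picked = tf.gather(value_log_probs, returns, batch_dims=1)                # [batch, 2]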
Example #19
def model_fn(features, labels, mode, params):
  """Model function."""
  del labels

  # ==============================
  # Input features
  # ==============================
  # [batch_size, query_seq_len]
  query_inputs = features["query_inputs"]

  # [batch_size, num_candidates, candidate_seq_len]
  candidate_inputs = features["candidate_inputs"]

  # [batch_size, num_candidates, query_seq_len + candidate_seq_len]
  joint_inputs = features["joint_inputs"]

  # [batch_size, num_masks]
  mlm_targets = features["mlm_targets"]
  mlm_positions = features["mlm_positions"]
  mlm_mask = features["mlm_mask"]

  # ==============================
  # Create modules.
  # ==============================
  bert_module = hub.Module(
      spec=params["bert_hub_module_handle"],
      name="locbert",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(bert_module, "locbert")

  embedder_module = hub.Module(
      spec=params["embedder_hub_module_handle"],
      name="embedder",
      tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
      trainable=True)
  hub.register_module_for_export(embedder_module, "embedder")

  if params["share_embedders"]:
    query_embedder_module = embedder_module
  else:
    query_embedder_module = hub.Module(
        spec=params["embedder_hub_module_handle"],
        name="embedder",
        tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
        trainable=True)
    hub.register_module_for_export(embedder_module, "query_embedder")

  # ==============================
  # Retrieve.
  # ==============================
  # [batch_size, projected_size]
  query_emb = query_embedder_module(
      inputs=dict(
          input_ids=query_inputs.token_ids,
          input_mask=query_inputs.mask,
          segment_ids=query_inputs.segment_ids),
      signature="projected")

  # [batch_size * num_candidates, candidate_seq_len]
  flat_candidate_inputs, unflatten = flatten_bert_inputs(
      candidate_inputs)

  # [batch_size * num_candidates, projected_size]
  flat_candidate_emb = embedder_module(
      inputs=dict(
          input_ids=flat_candidate_inputs.token_ids,
          input_mask=flat_candidate_inputs.mask,
          segment_ids=flat_candidate_inputs.segment_ids),
      signature="projected")

  # [batch_size, num_candidates, projected_size]
  unflattened_candidate_emb = unflatten(flat_candidate_emb)

  # [batch_size, num_candidates]
  retrieval_score = tf.einsum("BD,BND->BN", query_emb,
                              unflattened_candidate_emb)

  # ==============================
  # Read.
  # ==============================
  # [batch_size * num_candidates, query_seq_len + candidate_seq_len]
  flat_joint_inputs, unflatten = flatten_bert_inputs(joint_inputs)

  # [batch_size * num_candidates, num_masks]
  flat_mlm_positions, _ = tensor_utils.flatten(
      tf.tile(
          tf.expand_dims(mlm_positions, 1), [1, params["num_candidates"], 1]))

  batch_size, num_masks = tensor_utils.shape(mlm_targets)

  # [batch_size * num_candidates, query_seq_len + candidates_seq_len]
  flat_joint_bert_outputs = bert_module(
      inputs=dict(
          input_ids=flat_joint_inputs.token_ids,
          input_mask=flat_joint_inputs.mask,
          segment_ids=flat_joint_inputs.segment_ids,
          mlm_positions=flat_mlm_positions),
      signature="mlm",
      as_dict=True)

  # [batch_size, num_candidates]
  candidate_score = retrieval_score

  # [batch_size, num_candidates]
  candidate_log_probs = tf.math.log_softmax(candidate_score)

  # ==============================
  # Compute marginal log-likelihood.
  # ==============================
  # [batch_size * num_candidates, num_masks]
  flat_mlm_logits = flat_joint_bert_outputs["mlm_logits"]

  # [batch_size, num_candidates, num_masks, vocab_size]
  mlm_logits = tf.reshape(
      flat_mlm_logits, [batch_size, params["num_candidates"], num_masks, -1])
  mlm_log_probs = tf.math.log_softmax(mlm_logits)

  # [batch_size, num_candidates, num_masks]
  tiled_mlm_targets = tf.tile(
      tf.expand_dims(mlm_targets, 1), [1, params["num_candidates"], 1])

  # [batch_size, num_candidates, num_masks, 1]
  tiled_mlm_targets = tf.expand_dims(tiled_mlm_targets, -1)

  # [batch_size, num_candidates, num_masks, 1]
  gold_log_probs = tf.batch_gather(mlm_log_probs, tiled_mlm_targets)

  # [batch_size, num_candidates, num_masks]
  gold_log_probs = tf.squeeze(gold_log_probs, -1)

  # [batch_size, num_candidates, num_masks]
  joint_gold_log_probs = (
      tf.expand_dims(candidate_log_probs, -1) + gold_log_probs)

  # [batch_size, num_masks]
  marginal_gold_log_probs = tf.reduce_logsumexp(joint_gold_log_probs, 1)

  # [batch_size, num_masks]
  float_mlm_mask = tf.cast(mlm_mask, tf.float32)

  # []
  loss = -tf.div_no_nan(
      tf.reduce_sum(marginal_gold_log_probs * float_mlm_mask),
      tf.reduce_sum(float_mlm_mask))

  # ==============================
  # Optimization
  # ==============================
  num_warmup_steps = min(10000, max(100, int(params["num_train_steps"] / 10)))
  train_op = optimization.create_optimizer(
      loss=loss,
      init_lr=params["learning_rate"],
      num_train_steps=params["num_train_steps"],
      num_warmup_steps=num_warmup_steps,
      use_tpu=params["use_tpu"])

  # ==============================
  # Evaluation
  # ==============================
  eval_metric_ops = None if params["use_tpu"] else dict()
  if mode != tf.estimator.ModeKeys.PREDICT:
    # [batch_size, num_masks]
    retrieval_utility = marginal_gold_log_probs - gold_log_probs[:, 0]
    retrieval_utility *= tf.cast(features["mlm_mask"], tf.float32)

    # []
    retrieval_utility = tf.div_no_nan(
        tf.reduce_sum(retrieval_utility), tf.reduce_sum(float_mlm_mask))
    add_mean_metric("retrieval_utility", retrieval_utility, eval_metric_ops)

    has_timestamp = tf.cast(
        tf.greater(features["export_timestamp"], 0), tf.float64)
    off_policy_delay_secs = (
        tf.timestamp() - tf.cast(features["export_timestamp"], tf.float64))
    off_policy_delay_mins = off_policy_delay_secs / 60.0
    off_policy_delay_mins *= tf.cast(has_timestamp, tf.float64)

    add_mean_metric("off_policy_delay_mins", off_policy_delay_mins,
                    eval_metric_ops)

  # Create empty predictions to avoid errors when running in prediction mode.
  predictions = dict()

  if params["use_tpu"]:
    return tf.estimator.tpu.TPUEstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        predictions=predictions)
  else:
    if eval_metric_ops is not None:
      # Make sure the eval metrics are updated during training so that we get
      # quick feedback from tensorboard summaries when debugging locally.
      with tf.control_dependencies([u for _, u in eval_metric_ops.values()]):
        loss = tf.identity(loss)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops,
        predictions=predictions)
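The gold_log_probs gather above indexes a rank-4 log-prob tensor with per-mask target ids. A minimal sketch of the same pattern with invented shapes (tf.gather with batch_dims=3 in place of tf.batch_gather):

import tensorflow as tf

mlm_log_probs = tf.nn.log_softmax(
    tf.random.uniform([2, 3, 4, 11]), axis=-1)           # [B, C, M, V]
targets = tf.random.uniform([2, 3, 4, 1], maxval=11, dtype=tf.int32)
gold = tf.squeeze(tf.gather(mlm_log_probs, targets, batch_dims=3), -1)  # [B, C, M]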
Example #20
def non_max_suppression(scores_in,
                        boxes_in,
                        top_k_indices,
                        labels,
                        num_detections=ssd_constants.MAX_NUM_EVAL_BOXES):
    """Implement Non-maximum suppression.

  Args:
    scores_in: a Tensor with shape [batch_size,
      ssd_constants.MAX_NUM_EVAL_BOXES, num_classes]. The top
      ssd_constants.MAX_NUM_EVAL_BOXES box scores for each class.
    boxes_in: a Tensor with shape [batch_size, N, 4], which stacks box
      regression outputs on all feature levels. The N is the number of total
      anchors on all levels.
    top_k_indices: a Tensor with shape [batch_size,
      ssd_constants.MAX_NUM_EVAL_BOXES, num_classes]. The indices for these top
      boxes for each class.
    labels: labels tensor.
    num_detections: maximum output length.

  Returns:
    A tensor size of [batch_size, num_detections, 6] represents boxes, labels
    and scores after NMS.
  """

    _, _, num_classes = scores_in.get_shape().as_list()
    source_id = tf.cast(
        tf.tile(tf.expand_dims(labels[ssd_constants.SOURCE_ID], 1),
                [1, num_detections]), scores_in.dtype)
    raw_shape = tf.cast(
        tf.tile(tf.expand_dims(labels[ssd_constants.RAW_SHAPE], 1),
                [1, num_detections, 1]), scores_in.dtype)

    list_of_all_boxes = []
    list_of_all_scores = []
    list_of_all_classes = []
    # Skip background class.
    for class_i in range(1, num_classes, 1):
        boxes = tf.batch_gather(boxes_in, top_k_indices[:, :, class_i])
        class_i_scores = scores_in[:, :, class_i]
        class_i_scores, boxes = _filter_scores(class_i_scores, boxes)
        (class_i_post_scores,
         class_i_post_boxes) = ssd_architecture.non_max_suppression_padded(
             scores=tf.cast(class_i_scores, scores_in.dtype),
             boxes=tf.cast(boxes, scores_in.dtype),
             max_output_size=num_detections,
             iou_threshold=ssd_constants.OVERLAP_CRITERIA)
        class_i_classes = tf.fill(tf.shape(class_i_post_scores),
                                  ssd_constants.CLASS_INV_MAP[class_i])
        list_of_all_boxes.append(class_i_post_boxes)
        list_of_all_scores.append(class_i_post_scores)
        list_of_all_classes.append(class_i_classes)

    post_nms_boxes = tf.concat(list_of_all_boxes, axis=1)
    post_nms_scores = tf.concat(list_of_all_scores, axis=1)
    post_nms_classes = tf.concat(list_of_all_classes, axis=1)

    # sort all results.
    post_nms_scores, sorted_indices = tf.nn.top_k(tf.cast(
        post_nms_scores, scores_in.dtype),
                                                  k=num_detections,
                                                  sorted=True)

    post_nms_boxes = tf.gather(post_nms_boxes, sorted_indices, batch_dims=1)
    post_nms_classes = tf.gather(post_nms_classes,
                                 sorted_indices,
                                 batch_dims=1)
    detections_result = tf.stack([
        source_id,
        post_nms_boxes[:, :, 1] * raw_shape[:, :, 1],
        post_nms_boxes[:, :, 0] * raw_shape[:, :, 0],
        (post_nms_boxes[:, :, 3] - post_nms_boxes[:, :, 1]) *
        raw_shape[:, :, 1],
        (post_nms_boxes[:, :, 2] - post_nms_boxes[:, :, 0]) *
        raw_shape[:, :, 0],
        post_nms_scores,
        tf.cast(post_nms_classes, scores_in.dtype),
    ],
                                 axis=2)

    return detections_result
Example #21
def read_from_memory(read_keys, read_strengths, mem_state, top_k):
    """Function for cosine similarity content based reading from memory matrix.

  In the args list, we have the following conventions:
    B: batch size
    M: number of slots in a row of the memory matrix
    R: number of rows in the memory matrix
    H: number of read heads (of the controller or the policy)
    K: top_k if top_k>0

  Args:
    read_keys: the read keys of shape [B, H, M].
    read_strengths: the coefficients used to compute the normalised weighting
      vector of shape [B, H].
    mem_state: the primary memory tensor. Of shape [B, R, M].
    top_k: only use top k read matches, other reads do not go into softmax and
      are zeroed out in the output. top_k=0 (default) means use dense reads.

  Returns:
    The memory reads [B, H, M], read weights [B, H, top k], read indices
      [B, H, top k], and read strengths [B, H, 1].
  """
    _assert_compatible_read_from_memory_inputs(read_keys, read_strengths,
                                               mem_state)
    batch_size = read_keys.shape[0]
    num_read_heads = read_keys.shape[1]

    with tf.name_scope('memory_reading'):
        # Scale such that all rows are L2-unit vectors, for memory and read query.
        scaled_read_keys = tf.math.l2_normalize(read_keys,
                                                axis=-1)  # [B, H, M]
        scaled_mem = tf.math.l2_normalize(mem_state, axis=-1)  # [B, R, M]

        # The cosine distance is then their dot product.
        # Find the cosine distance between each read head and each row of memory.
        cosine_distances = tf.matmul(scaled_read_keys,
                                     scaled_mem,
                                     transpose_b=True)  # [B, H, R]

        # The rank must match cosine_distances for broadcasting to work.
        read_strengths = tf.expand_dims(read_strengths, axis=-1)  # [B, H, 1]
        weighted_distances = read_strengths * cosine_distances  # [B, H, R]

        if top_k:
            # Get top k indices (row indices with top k largest weighted distances).
            top_k_output = tf.nn.top_k(weighted_distances, top_k, sorted=False)
            read_indices = top_k_output.indices  # [B, H, K]

            # Create a sub-memory for each read head with only the top k rows.
            # Each batch_gather is [B, K, M] and the list stacks to [B, H, K, M].
            topk_mem_per_head = [
                tf.batch_gather(mem_state, ri_this_head)
                for ri_this_head in tf.unstack(read_indices, axis=1)
            ]
            topk_mem = tf.stack(topk_mem_per_head, axis=1)  # [B, H, K, M]
            topk_scaled_mem = tf.math.l2_normalize(topk_mem,
                                                   axis=-1)  # [B, H, K, M]

            # Calculate read weights for each head's top k sub-memory.
            expanded_scaled_read_keys = tf.expand_dims(scaled_read_keys,
                                                       axis=2)  # [B, H, 1, M]
            topk_cosine_distances = tf.reduce_sum(expanded_scaled_read_keys *
                                                  topk_scaled_mem,
                                                  axis=-1)  # [B, H, K]
            topk_weighted_distances = (read_strengths * topk_cosine_distances
                                       )  # [B, H, K]
            read_weights = tf.nn.softmax(topk_weighted_distances,
                                         axis=-1)  # [B, H, K]

            # For each head, read using the sub-memories and corresponding weights.
            expanded_weights = tf.expand_dims(read_weights,
                                              axis=-1)  # [B, H, K, 1]
            memory_reads = tf.reduce_sum(expanded_weights * topk_mem,
                                         axis=2)  # [B, H, M]
        else:
            read_weights = tf.nn.softmax(weighted_distances, axis=-1)

            num_rows_memory = mem_state.shape[1]
            all_indices = tf.range(num_rows_memory, dtype=tf.int32)
            all_indices = tf.reshape(all_indices, [1, 1, num_rows_memory])
            read_indices = tf.tile(all_indices,
                                   [batch_size, num_read_heads, 1])

            # This is the actual memory access.
            # Note that matmul automatically batch applies for us.
            memory_reads = tf.matmul(read_weights, mem_state)

        read_keys.shape.assert_is_compatible_with(memory_reads.shape)

        read_strengths = tf.squeeze(read_strengths,
                                    axis=-1)  # [B, H, 1] -> [B, H]

        return memory_reads, read_weights, read_indices, read_strengths
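Each per-head batch_gather above selects the top-k memory rows for that head. A toy sketch for a single head, with invented shapes:

import tensorflow as tf

mem_state = tf.random.uniform([2, 8, 16])                    # [B, R, M]
read_indices = tf.constant([[1, 7, 3], [0, 2, 5]])           # [B, K] for one head
topk_mem = tf.gather(mem_state, read_indices, batch_dims=1)  # [B, K, M]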
Example #22
def generate_detections_per_image_op(cls_outputs,
                                     box_outputs,
                                     anchor_boxes,
                                     image_id,
                                     image_info,
                                     num_detections=100,
                                     pre_nms_num_detections=1000,
                                     nms_threshold=0.3,
                                     bbox_reg_weights=(10., 10., 5., 5.)):
    """Generates detections with model outputs and anchors.

  Args:
    cls_outputs: a Tensor with shape [N, num_classes], which stacks class
      logit outputs on all feature levels. The N is the number of total anchors
      on all levels. The num_classes is the number of classes predicted by the
      model. Note that the cls_outputs should be the output of softmax().
    box_outputs: a Tensor with shape [N, num_classes*4], which stacks
      box regression outputs on all feature levels. The N is the number of total
      anchors on all levels.
    anchor_boxes: a Tensor with shape [N, 4], which stacks anchors on all
      feature levels. The N is the number of total anchors on all levels.
    image_id: an integer number to specify the image id.
    image_info: a tensor of shape [5] which encodes the input image's [height,
      width, scale, original_height, original_width]
    num_detections: Number of detections after NMS.
    pre_nms_num_detections: Number of candidates before NMS.
    nms_threshold: a float number to specify the threshold of NMS.
    bbox_reg_weights: a list of 4 float scalars, which are default weights on
      (dx, dy, dw, dh) for normalizing bbox regression targets.
  Returns:
    detections: detection results in a tensor with each row representing
      [image_id, ymin, xmin, ymax, xmax, score, class]
  """
    num_boxes, num_classes = cls_outputs.get_shape().as_list()

    # Removes background class scores.
    cls_outputs = cls_outputs[:, 1:num_classes]
    top_k_scores, top_k_indices_with_classes = tf.nn.top_k(
        tf.reshape(cls_outputs, [-1]), k=pre_nms_num_detections, sorted=True)
    classes = tf.mod(top_k_indices_with_classes, num_classes - 1)
    top_k_indices = tf.floordiv(top_k_indices_with_classes, num_classes - 1)

    anchor_boxes = tf.gather(anchor_boxes, top_k_indices)
    box_outputs = tf.reshape(box_outputs,
                             [num_boxes, num_classes, 4])[:, 1:num_classes, :]
    box_outputs = tf.gather_nd(box_outputs,
                               tf.stack([top_k_indices, classes], axis=1))

    # Applies bounding box regression to anchors.
    boxes = box_utils.batch_decode_box_outputs_op(
        tf.expand_dims(anchor_boxes, axis=0),
        tf.expand_dims(box_outputs, axis=0), bbox_reg_weights)[0]
    boxes = box_utils.clip_boxes(tf.expand_dims(boxes, axis=0),
                                 tf.expand_dims(image_info[:2], axis=0))[0]

    classes = tf.tile(tf.reshape(classes, [1, pre_nms_num_detections]),
                      [num_classes - 1, 1])
    scores = tf.tile(tf.reshape(top_k_scores, [1, pre_nms_num_detections]),
                     [num_classes - 1, 1])
    boxes = tf.tile(tf.reshape(boxes, [1, pre_nms_num_detections, 4]),
                    [num_classes - 1, 1, 1])

    class_bitmask = tf.tile(
        tf.reshape(tf.range(num_classes - 1), [num_classes - 1, 1]),
        [1, pre_nms_num_detections])
    scores = tf.where(tf.equal(classes, class_bitmask), scores,
                      tf.zeros_like(scores))
    scores = tf.where(tf.greater(scores, 0.05), scores, tf.zeros_like(scores))
    # Reshape classes to be compatible with the top_k function.
    classes = tf.reshape(classes, [num_classes - 1, pre_nms_num_detections, 1])
    scores, sorted_tensors = box_utils.top_k(scores,
                                             k=pre_nms_num_detections,
                                             tensors=[boxes, classes])
    boxes = sorted_tensors[0]
    classes = tf.reshape(sorted_tensors[1],
                         [num_classes - 1, pre_nms_num_detections])

    idx, num_valid = non_max_suppression.non_max_suppression_padded(
        scores,
        boxes,
        max_output_size=num_detections,
        iou_threshold=nms_threshold,
        level=0)

    post_nms_boxes = non_max_suppression.gather_boxes_by_indices(
        boxes, num_detections, idx, num_valid)
    post_nms_scores = non_max_suppression.gather_scores_by_indices(
        scores, num_detections, idx, num_valid)

    # Sorts all results.
    sorted_scores, sorted_indices = tf.nn.top_k(tf.to_float(
        tf.reshape(post_nms_scores, [-1])),
                                                k=num_detections,
                                                sorted=True)
    post_nms_boxes = tf.gather(tf.reshape(post_nms_boxes, [-1, 4]),
                               sorted_indices)
    classes = tf.batch_gather(classes, idx)
    post_nms_classes = tf.gather(tf.reshape(classes, [-1]), sorted_indices) + 1

    if isinstance(image_id, int):
        image_id = tf.constant(image_id)
    image_id = tf.reshape(image_id, [])
    detections_result = tf.stack([
        tf.to_float(tf.fill(tf.shape(sorted_scores), image_id)),
        post_nms_boxes[:, 0],
        post_nms_boxes[:, 1],
        post_nms_boxes[:, 2],
        post_nms_boxes[:, 3],
        sorted_scores,
        tf.to_float(post_nms_classes),
    ],
                                 axis=1)
    return detections_result
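A closing portability note: tf.batch_gather was removed from the TF 2.x API surface. To the best of my knowledge, the drop-in equivalents for the calls in the examples above are:

# tf.batch_gather(params, indices)            # TF 1.x
# tf.compat.v1.batch_gather(params, indices)  # TF 2.x compatibility shim
# tf.gather(params, indices, batch_dims=-1)   # native TF 2.x equivalent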