Example #1
    def validate_model_independence(self, labels, log_probs, task_parameters):
        """Partition gradients into those assumed active and inactive."""
        num_task_parameters = len(task_parameters)
        # pylint: disable=g-complex-comprehension
        on_gradients = [[
            tf.norm(tensor=on_gradient) for on_gradient in on_gradients
        ] for on_gradients in [
            tf.gradients(ys=tf.gather(log_probs,
                                      tf.compat.v1.where(tf.equal(labels, i))),
                         xs=task_parameters[i * num_task_parameters:(i + 1) *
                                            num_task_parameters])
            for i in range(1)
        ]]
        off_gradients = [[
            tf.norm(tensor=off_gradient) for off_gradient in off_gradients
        ] for off_gradients in [
            # Assumed: the 'inactive' partition takes gradients w.r.t. examples
            # *not* labeled i (otherwise it would duplicate on_gradients).
            tf.gradients(ys=tf.gather(log_probs,
                                      tf.compat.v1.where(
                                          tf.not_equal(labels, i))),
                         xs=task_parameters[i * num_task_parameters:(i + 1) *
                                            num_task_parameters])
            for i in range(1)
        ]]
        # pylint: enable=g-complex-comprehension

        return (list(itertools.chain.from_iterable(on_gradients)),
                list(itertools.chain.from_iterable(off_gradients)))
Example #2
  def update_state(self, inputs, outputs):
    """Function that updates the metric state at each example.

    Args:
      inputs: A dictionary containing input tensors.
      outputs: A dictionary containing output tensors.

    Returns:
      Update op.
    """
    detections_score = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_score], [-1])
    detections_class = tf.reshape(
        outputs[standard_fields.DetectionResultFields.objects_class], [-1])
    num_detections = tf.shape(detections_score)[0]
    detections_instance_mask = tf.reshape(
        outputs[
            standard_fields.DetectionResultFields.instance_segments_voxel_mask],
        [num_detections, -1])
    gt_class = tf.reshape(inputs[standard_fields.InputDataFields.objects_class],
                          [-1])
    num_gt = tf.shape(gt_class)[0]
    gt_voxel_instance_ids = tf.reshape(
        inputs[standard_fields.InputDataFields.object_instance_id_voxels], [-1])
    gt_instance_masks = tf.transpose(
        tf.one_hot(gt_voxel_instance_ids - 1, depth=num_gt, dtype=tf.float32))
    for c in self.class_range:
      gt_mask_c = tf.equal(gt_class, c)
      num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
      gt_instance_masks_c = tf.boolean_mask(gt_instance_masks, gt_mask_c)
      detections_mask_c = tf.equal(detections_class, c)
      num_detections_c = tf.math.reduce_sum(
          tf.cast(detections_mask_c, dtype=tf.int32))
      if num_detections_c == 0:
        continue
      det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
      det_instance_mask_c = tf.boolean_mask(detections_instance_mask,
                                            detections_mask_c)
      det_scores_c, sorted_indices = tf.math.top_k(
          det_scores_c, k=num_detections_c)
      det_instance_mask_c = tf.gather(det_instance_mask_c, sorted_indices)
      tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
      if num_gt_c > 0:
        ious_c = instance_segmentation_utils.points_mask_iou(
            masks1=gt_instance_masks_c, masks2=det_instance_mask_c)
        max_overlap_gt_ids = tf.cast(
            tf.math.argmax(ious_c, axis=0), dtype=tf.int32)
        is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
        for i in tf.range(num_detections_c):
          gt_id = max_overlap_gt_ids[i]
          if (ious_c[gt_id, i] > self.iou_threshold and
              is_gt_box_detected[gt_id] == 0):
            tp_c = tf.maximum(
                tf.one_hot(i, num_detections_c, dtype=tf.int32), tp_c)
            is_gt_box_detected = tf.maximum(
                tf.one_hot(gt_id, num_gt_c, dtype=tf.int32), is_gt_box_detected)
      self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
      self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
      self.num_gt[c] += num_gt_c
    return tf.no_op()
Example #3
    def convert_to_simclr_episode(support_images=None,
                                  support_labels=None,
                                  support_class_ids=None,
                                  query_images=None,
                                  query_labels=None,
                                  query_class_ids=None):
        """Convert a single episode into a SimCLR Episode."""

        # If there were k query examples of class c, keep the first k support
        # examples of class c as 'simclr' queries.  We do this by assigning an
        # id to each image in the query set, implemented as
        # label * 10000 + x + 1, where x is the number of images of the same
        # label with a lower index within the query set.  We do the same for
        # the support set, which gives us a mapping between query and support
        # images which is injective (as long as there are enough support-set
        # images of each class).
        #
        # Note: assumes fewer than 10000 images per class, so the within-class
        # index never spills over into the label part of the id.
        query_idx_within_class = tf.cast(
            tf.equal(query_labels[tf.newaxis, :], query_labels[:, tf.newaxis]),
            tf.int32)
        query_idx_within_class = tf.linalg.diag_part(
            tf.cumsum(query_idx_within_class, axis=1))
        query_uid = query_labels * 10000 + query_idx_within_class
        support_idx_within_class = tf.cast(
            tf.equal(support_labels[tf.newaxis, :],
                     support_labels[:, tf.newaxis]), tf.int32)
        support_idx_within_class = tf.linalg.diag_part(
            tf.cumsum(support_idx_within_class, axis=1))
        support_uid = support_labels * 10000 + support_idx_within_class

        # compute which support-set images have matches in the query set, and
        # discard the rest to produce the new query set.
        support_keep = tf.reduce_any(tf.equal(support_uid[:, tf.newaxis],
                                              query_uid[tf.newaxis, :]),
                                     axis=1)
        query_images = tf.boolean_mask(support_images, support_keep)

        support_labels = tf.range(tf.shape(support_labels)[0],
                                  dtype=support_labels.dtype)
        query_labels = tf.boolean_mask(support_labels, support_keep)
        query_class_ids = tf.boolean_mask(support_class_ids, support_keep)

        # Finally, apply SimCLR augmentation to all images.
        # Note simclr only blurs one image.
        query_images = simclr_augment(query_images, blur=True)
        support_images = simclr_augment(support_images)

        return (support_images, support_labels, support_class_ids,
                query_images, query_labels, query_class_ids)
Example #4
def select_slate_greedy(slate_size, s_no_click, s, q):
    """Selects the slate using the adaptive greedy algorithm.

  This algorithm corresponds to the method "GS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """
    def argmax(v, mask):
        return tf.argmax((v - tf.reduce_min(v) + 1) * mask, axis=0)

    numerator = tf.constant(0.)
    denominator = tf.constant(0.) + s_no_click
    mask = tf.ones(tf.shape(q)[0])

    def set_element(v, i, x):
        mask = tf.one_hot(i, tf.shape(v)[0])
        v_new = tf.ones_like(v) * x
        return tf.where(tf.equal(mask, 1), v_new, v)

    for _ in range(slate_size):
        k = argmax((numerator + s * q) / (denominator + s), mask)
        mask = set_element(mask, k, 0)
        numerator = numerator + tf.gather(s * q, k)
        denominator = denominator + tf.gather(s, k)

    output_slate = tf.where(tf.equal(mask, 0))
    return output_slate
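A minimal eager-mode sketch of calling this; the scores and q-values below are made-up toy numbers:

import tensorflow as tf

s_no_click = tf.constant(1.0)                # score of the no-click option
s = tf.constant([0.5, 1.0, 0.2, 0.8, 0.3])   # per-document click scores
q = tf.constant([0.1, 0.9, 0.4, 0.7, 0.2])   # per-document q-values

# Greedily selects two documents; the result holds their indices, shape [2, 1].
slate = select_slate_greedy(slate_size=2, s_no_click=s_no_click, s=s, q=q)
print(slate.numpy())
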
Example #5
def pick_labeled_image(mesh_inputs, view_image_inputs, view_indices_2d_inputs,
                       view_name):
    """Pick the image with most number of labeled points projecting to it."""
    if view_name not in view_image_inputs:
        return
    if view_name not in view_indices_2d_inputs:
        return
    if standard_fields.InputDataFields.point_loss_weights not in mesh_inputs:
        raise ValueError('The key `point_loss_weights` is missing from mesh_inputs.')
    height = tf.shape(view_image_inputs[view_name])[1]
    width = tf.shape(view_image_inputs[view_name])[2]
    valid_points_y = tf.logical_and(
        tf.greater_equal(view_indices_2d_inputs[view_name][:, :, 0], 0),
        tf.less(view_indices_2d_inputs[view_name][:, :, 0], height))
    valid_points_x = tf.logical_and(
        tf.greater_equal(view_indices_2d_inputs[view_name][:, :, 1], 0),
        tf.less(view_indices_2d_inputs[view_name][:, :, 1], width))
    valid_points = tf.logical_and(valid_points_y, valid_points_x)
    image_total_weights = tf.reduce_sum(
        tf.cast(valid_points, dtype=tf.float32) * tf.squeeze(
            mesh_inputs[standard_fields.InputDataFields.point_loss_weights],
            axis=1),
        axis=1)
    image_total_weights = tf.cond(
        tf.equal(tf.reduce_sum(image_total_weights), 0),
        lambda: tf.reduce_sum(tf.cast(valid_points, dtype=tf.float32), axis=1),
        lambda: image_total_weights)
    best_image = tf.math.argmax(image_total_weights)
    view_image_inputs[view_name] = view_image_inputs[view_name][
        best_image:best_image + 1, :, :, :]
    view_indices_2d_inputs[view_name] = view_indices_2d_inputs[view_name][
        best_image:best_image + 1, :, :]
Example #6
  def _set_up_staging(self, transition):
    """Sets up staging ops for prefetching the next transition.

    This allows us to hide the py_func latency. To do so we use a staging area
    to pre-fetch the next batch of transitions.

    Args:
      transition: tuple of tf.Tensors with shape
        memory.get_transition_elements().

    Returns:
      prefetched_transition: tuple of tf.Tensors with shape
        memory.get_transition_elements() that have been previously prefetched.
    """
    transition_type = self.memory.get_transition_elements()

    # Create the staging area in CPU.
    prefetch_area = tf.contrib.staging.StagingArea(
        [shape_with_type.type for shape_with_type in transition_type])

    # Store prefetch op for tests, but keep it private -- users should not be
    # calling _prefetch_batch.
    self._prefetch_batch = prefetch_area.put(transition)
    initial_prefetch = tf.cond(
        tf.equal(prefetch_area.size(), 0),
        lambda: prefetch_area.put(transition), tf.no_op)

    # Every time a transition is sampled self.prefetch_batch will be
    # called. If the staging area is empty, two put ops will be called.
    with tf.control_dependencies([self._prefetch_batch, initial_prefetch]):
      prefetched_transition = prefetch_area.get()

    return prefetched_transition
Example #7
def class_specific_data(onehot_labels, data, num_classes, axis=0):
    # TODO(eringrant): Deal with case of no data for a class in [1...num_classes].
    data_shape = [s for i, s in enumerate(data.shape) if i != axis]
    labels = tf.argmax(onehot_labels, axis=-1)
    class_idx = [tf.where(tf.equal(labels, i)) for i in range(num_classes)]
    return [
        tf.reshape(tf.gather(data, idx, axis=axis), [-1] + data_shape)
        for idx in class_idx
    ]
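A toy eager-mode check of class_specific_data; the labels and features are arbitrary:

import tensorflow as tf

onehot_labels = tf.one_hot([0, 1, 2, 0, 1, 0], depth=3)  # six examples
data = tf.random.uniform([6, 4])                         # 4-d features

# One tensor per class, holding just that class's rows of `data`.
per_class = class_specific_data(onehot_labels, data, num_classes=3)
for c, class_data in enumerate(per_class):
    print(c, class_data.shape)  # (3, 4), (2, 4), (1, 4)
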
Example #8
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
    """Builds an op used as a target for the Q-value.

  This algorithm corresponds to the method "OT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for the
      user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
    scores, score_no_click = _get_unnormalized_scores(next_states)

    # Obtain all possible slates given current docs in the candidate set.
    slate_size = next_actions.get_shape().as_list()[1]
    num_candidates = next_q_values.get_shape().as_list()[1]
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))
    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool)
    # [num_of_slates, slate_size]
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

    # [batch_size, num_of_slates, slate_size]
    next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
    # [batch_size, num_of_slates, slate_size]
    scores_slate = tf.gather(scores, slates, axis=1)
    # [batch_size, num_of_slates]
    batch_size = next_states.get_shape().as_list()[0]
    score_no_click_slate = tf.reshape(
        tf.tile(score_no_click,
                tf.shape(input=slates)[:1]), [batch_size, -1])

    # [batch_size, num_of_slates]
    next_q_target_slate = tf.reduce_sum(
        input_tensor=next_q_values_slate * scores_slate,
        axis=2) / (tf.reduce_sum(input_tensor=scores_slate, axis=2) +
                   score_no_click_slate)

    next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

    return reward + gamma * next_q_target_max * (
        1. - tf.cast(terminals, tf.float32))
Example #9
def filter_before_first_step(time_steps, actions=None):
  flat_time_steps = tf.nest.flatten(time_steps)
  flat_time_steps = [tf.unstack(time_step, axis=1) for time_step in
                     flat_time_steps]
  time_steps = [tf.nest.pack_sequence_as(time_steps, time_step) for time_step in
                zip(*flat_time_steps)]
  if actions is None:
    actions = [None] * len(time_steps)
  else:
    actions = tf.unstack(actions, axis=1)
  assert len(time_steps) == len(actions)

  time_steps = list(reversed(time_steps))
  actions = list(reversed(actions))
  filtered_time_steps = []
  filtered_actions = []
  for t, (time_step, action) in enumerate(zip(time_steps, actions)):
    if t == 0:
      reset_mask = tf.equal(time_step.step_type, ts.StepType.FIRST)
    else:
      time_step = tf.nest.map_structure(lambda x, y: tf.where(reset_mask, x, y),
                                        last_time_step, time_step)
      action = tf.where(reset_mask, tf.zeros_like(action),
                        action) if action is not None else None
    filtered_time_steps.append(time_step)
    filtered_actions.append(action)
    reset_mask = tf.logical_or(
        reset_mask,
        tf.equal(time_step.step_type, ts.StepType.FIRST))
    last_time_step = time_step
  filtered_time_steps = list(reversed(filtered_time_steps))
  filtered_actions = list(reversed(filtered_actions))

  filtered_flat_time_steps = [tf.nest.flatten(time_step) for time_step in
                              filtered_time_steps]
  filtered_flat_time_steps = [tf.stack(time_step, axis=1) for time_step in
                              zip(*filtered_flat_time_steps)]
  filtered_time_steps = tf.nest.pack_sequence_as(filtered_time_steps[0],
                                                 filtered_flat_time_steps)
  if action is None:  # True iff no `actions` were passed in.
    return filtered_time_steps
  else:
    actions = tf.stack(filtered_actions, axis=1)
    return filtered_time_steps, actions
Example #10
def classification_loss_using_mask_iou_func_unbatched(
        embeddings, instance_ids, sampled_embeddings, sampled_instance_ids,
        sampled_class_labels, sampled_logits, similarity_strategy,
        is_balanced):
    """Classification loss using mask iou.

  Args:
    embeddings: A tf.float32 tensor of size [n, f].
    instance_ids: A tf.int32 tensor of size [n].
    sampled_embeddings: A tf.float32 tensor of size [num_samples, f].
    sampled_instance_ids: A tf.int32 tensor of size [num_samples].
    sampled_class_labels: A tf.int32 tensor of size [num_samples, 1].
    sampled_logits: A tf.float32 tensor of size [num_samples, num_classes].
    similarity_strategy: Defines the method for computing similarity between
                         embedding vectors. Possible values are 'dotproduct' and
                         'distance'.
    is_balanced: If True, the per-voxel losses are re-weighted to have equal
      total weight for foreground vs. background voxels.

  Returns:
    A tf.float32 loss scalar tensor.
  """
    predicted_soft_masks = metric_learning_utils.embedding_centers_to_soft_masks(
        embedding=embeddings,
        centers=sampled_embeddings,
        similarity_strategy=similarity_strategy)
    predicted_masks = tf.cast(tf.greater(predicted_soft_masks, 0.5),
                              dtype=tf.float32)
    gt_masks = tf.cast(tf.equal(tf.expand_dims(sampled_instance_ids, axis=1),
                                tf.expand_dims(instance_ids, axis=0)),
                       dtype=tf.float32)
    pairwise_iou = instance_segmentation_utils.points_mask_pairwise_iou(
        masks1=predicted_masks, masks2=gt_masks)
    num_classes = sampled_logits.get_shape().as_list()[1]
    sampled_class_labels_one_hot = tf.one_hot(indices=tf.reshape(
        sampled_class_labels, [-1]),
                                              depth=num_classes)
    sampled_class_labels_one_hot_fg = sampled_class_labels_one_hot[:, 1:]
    iou_coefs = tf.tile(tf.reshape(pairwise_iou, [-1, 1]),
                        [1, num_classes - 1])
    sampled_class_labels_one_hot_fg *= iou_coefs
    sampled_class_labels_one_hot_bg = tf.maximum(
        1.0 - tf.math.reduce_sum(
            sampled_class_labels_one_hot_fg, axis=1, keepdims=True), 0.0)
    sampled_class_labels_one_hot = tf.concat(
        [sampled_class_labels_one_hot_bg, sampled_class_labels_one_hot_fg],
        axis=1)
    params = {}
    if is_balanced:
        weights = loss_utils.get_balanced_loss_weights_multiclass(
            labels=tf.expand_dims(sampled_instance_ids, axis=1))
        params['weights'] = weights
    return classification_loss_fn(logits=sampled_logits,
                                  labels=sampled_class_labels_one_hot,
                                  **params)
Example #11
def _get_class_labels_and_predictions(labels, logits, num_classes,
                                      multi_label):
    """Returns list of per-class-labels and list of per-class-predictions.

  Args:
    labels: A `Tensor` of size [n, k]. In the
      multi-label case, values are either 0 or 1 and k = num_classes. Otherwise,
      k = 1 and values are in [0, num_classes).
    logits: A `Tensor` of size [n, `num_classes`]
      representing the logits of each pixel and semantic class.
    num_classes: Number of classes.
    multi_label: Boolean which defines if we are in a multi_label setting, where
      pixels can have multiple labels, or not.

  Returns:
    class_labels: List of size num_classes, where each entry is a `Tensor` of
      size [batch_size, height, width] of type float with values of 0 or 1
      representing the ground truth labels.
    class_predictions: List of size num_classes, each entry is a `Tensor` of
      size [batch_size, height, width] of type float with values of 0 or 1
      representing the predicted labels.
  """
    class_predictions = [None] * num_classes
    if multi_label:
        class_labels = tf.split(labels, num_or_size_splits=num_classes, axis=1)
        class_logits = tf.split(logits, num_or_size_splits=num_classes, axis=1)
        for c in range(num_classes):
            class_predictions[c] = tf.cast(tf.greater(class_logits[c], 0),
                                           dtype=tf.float32)
    else:
        class_predictions_flat = tf.argmax(logits, 1)
        class_labels = [None] * num_classes
        for c in range(num_classes):
            class_labels[c] = tf.cast(tf.equal(labels, c), dtype=tf.float32)
            class_predictions[c] = tf.cast(tf.equal(class_predictions_flat, c),
                                           dtype=tf.float32)
    return class_labels, class_predictions
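A small single-label check in eager mode, with toy labels and logits for num_classes=3:

import tensorflow as tf

labels = tf.constant([[0], [2], [1], [2]])
logits = tf.constant([[2.0, 0.1, 0.3],
                      [0.2, 0.1, 3.0],
                      [0.5, 1.5, 0.2],
                      [1.0, 0.9, 0.8]])  # per-row argmax: 0, 2, 1, 0

class_labels, class_predictions = _get_class_labels_and_predictions(
    labels, logits, num_classes=3, multi_label=False)
print([p.numpy().tolist() for p in class_predictions])
# [[1.0, 0.0, 0.0, 1.0], [0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]]
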
Example #12
def _remove_second_return_lidar_points(mesh_inputs, view_indices_2d_inputs):
    """removes the points that are not lidar first-return ."""
    if standard_fields.InputDataFields.point_spin_coordinates not in mesh_inputs:
        raise ValueError('spin_coordinates not in mesh_inputs.')
    first_return_mask = tf.equal(
        tf.cast(mesh_inputs[
            standard_fields.InputDataFields.point_spin_coordinates][:, 2],
                dtype=tf.int32), 0)
    for key in sorted(mesh_inputs):
        mesh_inputs[key] = tf.boolean_mask(mesh_inputs[key], first_return_mask)
    for key in sorted(view_indices_2d_inputs):
        view_indices_2d_inputs[key] = tf.transpose(
            tf.boolean_mask(
                tf.transpose(view_indices_2d_inputs[key], [1, 0, 2]),
                first_return_mask), [1, 0, 2])
Example #13
    def _build_train_op(self):
        """Builds a training op.

    Returns:
      An op performing one step of training from replay data.
    """
        # click_indicator: [B, S]
        # q_values: [B, A]
        # actions: [B, S]
        # slate_q_values: [B, S]
        # replay_click_q: [B]
        click_indicator = self._replay.rewards[:, :,
                                               self._click_response_index]
        slate_q_values = tf.compat.v1.batch_gather(
            self._replay_net_outputs.q_values,
            tf.cast(self._replay.actions, dtype=tf.int32))
        # Only get the Q from the clicked document.
        replay_click_q = tf.reduce_sum(input_tensor=slate_q_values *
                                       click_indicator,
                                       axis=1,
                                       name='replay_click_q')

        target = tf.stop_gradient(self._build_target_q_op())

        clicked = tf.reduce_sum(input_tensor=click_indicator, axis=1)
        clicked_indices = tf.squeeze(tf.compat.v1.where(tf.equal(clicked, 1)),
                                     axis=1)
        # clicked_indices is a vector and tf.gather selects the batch dimension.
        q_clicked = tf.gather(replay_click_q, clicked_indices)
        target_clicked = tf.gather(target, clicked_indices)

        def get_train_op():
            loss = tf.reduce_mean(input_tensor=tf.square(q_clicked -
                                                         target_clicked))
            if self.summary_writer is not None:
                with tf.compat.v1.variable_scope('Losses'):
                    tf.compat.v1.summary.scalar('Loss', loss)

            return loss

        loss = tf.cond(pred=tf.greater(tf.reduce_sum(input_tensor=clicked), 0),
                       true_fn=get_train_op,
                       false_fn=lambda: tf.constant(0.),
                       name='')

        return self.optimizer.minimize(loss)
Example #14
  def compute_accuracy(self, onehot_labels, predictions):
    """Computes the accuracy of `predictions` with respect to `onehot_labels`.

    Args:
      onehot_labels: A `tf.Tensor` containing the class labels; each vector
        along the (last) class dimension is expected to contain only a single
        `1`.
      predictions: A `tf.Tensor` containing the class predictions represented
        as unnormalized log probabilities.

    Returns:
       A `tf.Tensor` of ones and zeros representing the correctness of
       individual predictions; use `tf.reduce_mean(...)` to obtain the average
       accuracy.
    """
    correct = tf.equal(tf.argmax(onehot_labels, -1), tf.argmax(predictions, -1))
    return tf.cast(correct, tf.float32)
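Since the method body is a single comparison, the same computation as a standalone eager sketch with toy tensors:

import tensorflow as tf

onehot_labels = tf.constant([[1., 0., 0.],
                             [0., 1., 0.]])
predictions = tf.constant([[2.0, 0.1, 0.3],   # argmax 0 -> correct
                           [0.2, 3.0, 0.1]])  # argmax 1 -> correct
correct = tf.cast(
    tf.equal(tf.argmax(onehot_labels, -1), tf.argmax(predictions, -1)),
    tf.float32)
print(tf.reduce_mean(correct).numpy())  # 1.0
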
Example #15
def _box_classification_loss_unbatched(inputs_1, outputs_1, is_intermediate,
                                       is_balanced, mine_hard_negatives,
                                       hard_negative_score_threshold):
    """Loss function for input and outputs of batch size 1."""
    valid_mask = _get_voxels_valid_mask(inputs_1=inputs_1)
    if is_intermediate:
        logits = outputs_1[standard_fields.DetectionResultFields.
                           intermediate_object_semantic_voxels]
    else:
        logits = outputs_1[
            standard_fields.DetectionResultFields.object_semantic_voxels]
    num_classes = logits.get_shape().as_list()[-1]
    if num_classes is None:
        raise ValueError('Number of classes is unknown.')
    logits = tf.boolean_mask(tf.reshape(logits, [-1, num_classes]), valid_mask)
    labels = tf.boolean_mask(
        tf.reshape(
            inputs_1[standard_fields.InputDataFields.object_class_voxels],
            [-1, 1]), valid_mask)
    if mine_hard_negatives or is_balanced:
        instances = tf.boolean_mask(
            tf.reshape(
                inputs_1[
                    standard_fields.InputDataFields.object_instance_id_voxels],
                [-1]), valid_mask)
    params = {}
    if mine_hard_negatives:
        negative_scores = tf.reshape(tf.nn.softmax(logits)[:, 0], [-1])
        hard_negative_mask = tf.logical_and(
            tf.less(negative_scores, hard_negative_score_threshold),
            tf.equal(tf.reshape(labels, [-1]), 0))
        hard_negative_labels = tf.boolean_mask(labels, hard_negative_mask)
        hard_negative_logits = tf.boolean_mask(logits, hard_negative_mask)
        hard_negative_instances = tf.boolean_mask(
            tf.ones_like(instances) * (tf.reduce_max(instances) + 1),
            hard_negative_mask)
        logits = tf.concat([logits, hard_negative_logits], axis=0)
        instances = tf.concat([instances, hard_negative_instances], axis=0)
        labels = tf.concat([labels, hard_negative_labels], axis=0)
    if is_balanced:
        weights = loss_utils.get_balanced_loss_weights_multiclass(
            labels=tf.expand_dims(instances, axis=1))
        params['weights'] = weights
    return classification_loss_fn(logits=logits, labels=labels, **params)
Example #16
    def host_call_fn(**kwargs):
        """Host_call_fn.

    Args:
      **kwargs: dict of summary name to tf.Tensor mapping. The value we see here
        is the tensor across all cores, concatenated along axis 0. This function
        makes a scalar summary that is the mean of the whole tensor (all the
        values are the same, namely the mean, a trait of
        tpu.CrossShardOptimizer).

    Returns:
      A merged summary op.
    """
        gs = kwargs.pop('global_step')[0]
        with tf_summary.create_file_writer(model_dir).as_default():
            with tf_summary.record_if(tf.equal(gs % 10, 0)):
                for name, tensor in kwargs.items():
                    # Take the mean across cores.
                    tensor = tf.reduce_mean(tensor)
                    tf_summary.scalar(name, tensor, step=gs)
                return tf.summary.all_v2_summary_ops()
Example #17
def random_flip_up_down(images, flow=None, mask=None):
  """Performs a random up/down flip."""
  # 50/50 chance
  perform_flip = tf.equal(tf.random.uniform([], maxval=2, dtype=tf.int32), 1)
  # apply flip
  images = tf.cond(pred=perform_flip,
                   true_fn=lambda: tf.reverse(images, axis=[-3]),
                   false_fn=lambda: images)
  if flow is not None:
    flow = tf.cond(pred=perform_flip,
                   true_fn=lambda: tf.reverse(flow, axis=[-3]),
                   false_fn=lambda: flow)
    mask = tf.cond(pred=perform_flip,
                   true_fn=lambda: tf.reverse(mask, axis=[-3]),
                   false_fn=lambda: mask)
    # correct sign of flow
    sign_correction = tf.reshape([-1.0, 1.0], [1, 1, 2])
    flow = tf.cond(pred=perform_flip,
                   true_fn=lambda: flow * sign_correction,
                   false_fn=lambda: flow)
  return images, flow, mask
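A quick eager call with arbitrary shapes; note that mask must be provided whenever flow is, since the mask branch is nested under the flow check:

import tensorflow as tf

images = tf.random.uniform([8, 8, 3])  # [height, width, channels]
flow = tf.random.uniform([8, 8, 2])    # [height, width, 2]
mask = tf.ones([8, 8, 1])

images, flow, mask = random_flip_up_down(images, flow=flow, mask=mask)
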
Example #18
def select_slate_optimal(slate_size, s_no_click, s, q):
    """Selects the slate using exhaustive search.

  This algorithm corresponds to the method "OS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """

    num_candidates = s.shape.as_list()[0]

    # Obtain all possible slates given current docs in the candidate set.
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))

    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool)
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

    slate_q_values = tf.gather(s * q, slates)
    slate_scores = tf.gather(s, slates)
    slate_normalizer = tf.reduce_sum(input_tensor=slate_scores,
                                     axis=1) + s_no_click

    slate_q_values = slate_q_values / tf.expand_dims(slate_normalizer, 1)
    slate_sum_q_values = tf.reduce_sum(input_tensor=slate_q_values, axis=1)
    max_q_slate_index = tf.argmax(input=slate_sum_q_values)
    return tf.gather(slates, max_q_slate_index, axis=0)
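A toy eager-mode check with four made-up candidate documents:

import tensorflow as tf

s_no_click = tf.constant(1.0)
s = tf.constant([0.5, 1.0, 0.2, 0.8])
q = tf.constant([0.1, 0.9, 0.4, 0.7])

slate = select_slate_optimal(slate_size=2, s_no_click=s_no_click, s=s, q=q)
print(slate.numpy())  # [3 1] -- the best pair; a tie decides the ordering
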
Example #19
def _voxel_hard_negative_classification_loss_unbatched(inputs_1, outputs_1,
                                                       is_intermediate, gamma):
    """Loss function for input and outputs of batch size 1."""
    inputs_1, outputs_1 = _get_voxels_valid_inputs_outputs(inputs_1=inputs_1,
                                                           outputs_1=outputs_1)
    if is_intermediate:
        logits = outputs_1[standard_fields.DetectionResultFields.
                           intermediate_object_semantic_voxels]
    else:
        logits = outputs_1[
            standard_fields.DetectionResultFields.object_semantic_voxels]
    labels = tf.reshape(
        inputs_1[standard_fields.InputDataFields.object_class_voxels], [-1])
    background_mask = tf.equal(labels, 0)
    num_background_points = tf.reduce_sum(
        tf.cast(background_mask, dtype=tf.int32))

    def loss_fn():
        """Loss function."""
        num_classes = logits.get_shape().as_list()[-1]
        if num_classes is None:
            raise ValueError('Number of classes is unknown.')
        masked_logits = tf.boolean_mask(logits, background_mask)
        masked_weights = tf.pow(
            1.0 - tf.reshape(tf.nn.softmax(masked_logits)[:, 0], [-1, 1]),
            gamma)
        num_points = tf.shape(masked_logits)[0]
        masked_weights = masked_weights * tf.cast(
            num_points, dtype=tf.float32) / tf.reduce_sum(masked_weights)
        masked_labels_one_hot = tf.one_hot(indices=tf.boolean_mask(
            labels, background_mask),
                                           depth=num_classes)
        loss = classification_loss_fn(logits=masked_logits,
                                      labels=masked_labels_one_hot,
                                      weights=masked_weights)
        return loss

    cond = tf.logical_and(tf.greater(num_background_points, 0),
                          tf.greater(tf.shape(labels)[0], 0))
    return tf.cond(cond, loss_fn, lambda: tf.constant(0.0, dtype=tf.float32))
Example #20
def compute_episode_stats(episode):
    """Computes various episode stats: way, shots, and class IDs.

  Args:
    episode: An EpisodeDataset.

  Returns:
    way: An int constant tensor. The number of classes in the episode.
    shots: An int 1D tensor: The number of support examples per class.
    class_ids: An int 1D tensor: (absolute) class IDs.
  """
    # The train labels of the next episode.
    train_labels = episode.train_labels
    # Compute way.
    episode_classes, _ = tf.unique(train_labels)
    way = tf.size(episode_classes)
    # Compute shots.
    class_ids = tf.reshape(tf.range(way), [way, 1])
    class_labels = tf.reshape(train_labels, [1, -1])
    is_equal = tf.equal(class_labels, class_ids)
    shots = tf.reduce_sum(tf.cast(is_equal, tf.int32), axis=1)
    # Compute class_ids.
    class_ids, _ = tf.unique(episode.train_class_ids)
    return way, shots, class_ids
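A sketch of calling this in eager mode; EpisodeDataset is stubbed here with a hypothetical namedtuple carrying only the two fields the function reads:

import collections
import tensorflow as tf

EpisodeDataset = collections.namedtuple(
    'EpisodeDataset', ['train_labels', 'train_class_ids'])
episode = EpisodeDataset(
    train_labels=tf.constant([0, 0, 1, 1, 1, 2]),
    train_class_ids=tf.constant([7, 7, 4, 4, 4, 9]))

way, shots, class_ids = compute_episode_stats(episode)
print(way.numpy(), shots.numpy(), class_ids.numpy())  # 3 [2 3 1] [7 4 9]
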
Example #21
def set_element(v, i, x):
    mask = tf.one_hot(i, tf.shape(input=v)[0])
    v_new = tf.ones_like(v) * x
    return tf.where(tf.equal(mask, 1), v_new, v)
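A quick eager check of this helper:

import tensorflow as tf

v = tf.constant([1., 2., 3., 4.])
print(set_element(v, 2, 9.).numpy())  # [1. 2. 9. 4.]
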
Example #22
def false_n(y, y_hat, name='fn'):
    return tf.logical_and(
        tf.equal(y, True), tf.equal(y_hat, False), name=name)
Example #23
def acc(y, y_hat, name='accuracy'):
    return tf.cast(tf.equal(y, y_hat, name=name), dtype=tf.int32)
Example #24
def prepare_lidar_images_and_correspondences(
    inputs,
    resized_image_height,
    resized_image_width,
    camera_names=('front', 'front_left', 'front_right', 'side_left',
                  'side_right'),
    lidar_names=('top', 'front', 'side_left', 'side_right', 'rear')):
  """Integrates and returns the lidars, cameras and their correspondences.

  Args:
    inputs: A dictionary containing the images and point / pixel
      correspondences.
    resized_image_height: Target height of the images.
    resized_image_width: Target width of the images.
    camera_names: List of cameras to include images from.
    lidar_names: List of lidars to include point clouds from.

  Returns:
    A tf.float32 tensor of size [num_points, 3] containing point positions.
    A tf.float32 tensor of size [num_points, 1] containing point intensities.
    A tf.float32 tensor of size [num_points, 1] containing point elongations.
    A tf.float32 tensor of size [num_points, 3] containing point normals.
    A tf.float32 tensor of size [num_images, resized_image_height,
      resized_image_width, 3].
    A tf.int32 tensor of size [num_images, num_points, 2].

  Raises:
    ValueError: If camera_names or lidar_names are empty lists.
  """
  if not camera_names:
    raise ValueError('camera_names should contain at least one name.')
  if not lidar_names:
    raise ValueError('lidar_names should contain at least one name.')

  (points_position, points_intensity, points_elongation, points_normal,
   points_in_image_frame_yx, points_in_image_frame_id) = _prepare_lidar_points(
       inputs=inputs, lidar_names=lidar_names)

  images = []
  points_in_image_frame = []

  for camera_name in camera_names:
    image_key = ('cameras/%s/image' % camera_name)
    image_height = tf.shape(inputs[image_key])[0]
    image_width = tf.shape(inputs[image_key])[1]
    height_ratio = tf.cast(
        resized_image_height, dtype=tf.float32) / tf.cast(
            image_height, dtype=tf.float32)
    width_ratio = tf.cast(
        resized_image_width, dtype=tf.float32) / tf.cast(
            image_width, dtype=tf.float32)
    if tf.executing_eagerly():
      resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    else:
      resize_method = tf.image.ResizeMethod.BILINEAR
      if inputs[image_key].dtype in [
          tf.int8, tf.uint8, tf.int16, tf.uint16, tf.int32, tf.int64
      ]:
        resize_method = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    images.append(
        tf.image.resize(
            images=inputs[image_key],
            size=[resized_image_height, resized_image_width],
            method=resize_method,
            antialias=True))
    camera_id = tf.cast(inputs[('cameras/%s/id' % camera_name)], dtype=tf.int32)
    valid_points = tf.equal(points_in_image_frame_id, camera_id)
    valid_points = tf.tile(valid_points, [1, 2])
    point_coords = tf.cast(
        tf.cast(points_in_image_frame_yx, dtype=tf.float32) *
        tf.stack([height_ratio, width_ratio]),
        dtype=tf.int32)
    points_in_image_frame_camera = tf.where(
        valid_points, point_coords, -tf.ones_like(valid_points, dtype=tf.int32))
    points_in_image_frame.append(points_in_image_frame_camera)
  num_images = len(images)
  images = tf.stack(images, axis=0)
  images.set_shape([num_images, resized_image_height, resized_image_width, 3])
  points_in_image_frame = tf.stack(points_in_image_frame, axis=0)
  return {
      'points_position': points_position,
      'points_intensity': points_intensity,
      'points_elongation': points_elongation,
      'points_normal': points_normal,
      'view_images': {'rgb_view': images},
      'view_indices_2d': {'rgb_view': points_in_image_frame}
  }
Example #25
def linear_classifier(embeddings, num_classes, cosine_classifier,
                      cosine_logits_multiplier, use_weight_norm, weight_decay):
    """Forward pass through a linear classifier, or possibly a cosine classifier.

  Args:
    embeddings: A Tensor of size [batch size, embedding dim].
    num_classes: An integer; the dimension of the classification.
    cosine_classifier: A bool. If true, a cosine classifier is used, which does
      not require a bias.
    cosine_logits_multiplier: A float. Only used if cosine_classifier is True,
      and multiplies the resulting logits.
    use_weight_norm: A bool. Whether to use weight normalization. If so, when
      using the cosine classifier, only the embeddings are normalized, not the
      weights.
    weight_decay: A float; the scalar multiple on the L2 regularization of the
      weight matrix.

  Returns:
    logits: A Tensor of size [batch size, num outputs].
  """

    embedding_dims = embeddings.get_shape().as_list()[-1]

    if use_weight_norm:
        # A variable to keep track of whether the initialization has already
        # happened.
        data_dependent_init_done = tf.get_variable('data_dependent_init_done',
                                                   initializer=0,
                                                   dtype=tf.int32,
                                                   trainable=False)

        w_fc = tf.get_variable('w_fc', [embedding_dims, num_classes],
                               initializer=tf.random_normal_initializer(
                                   0, 0.05),
                               trainable=True)
        # This init is temporary as it needs to be done in a data-dependent way.
        # It will be overwritten during the first forward pass through this layer.
        g = tf.get_variable('g',
                            dtype=tf.float32,
                            initializer=tf.ones([num_classes]),
                            trainable=True)
        b_fc = None
        if not cosine_classifier:
            # Also initialize a bias.
            b_fc = tf.get_variable('b_fc',
                                   initializer=tf.zeros([num_classes]),
                                   trainable=True)

        def _do_data_dependent_init():
            """Returns ops for the data-dependent init of g and maybe b_fc."""
            w_fc_normalized = tf.nn.l2_normalize(w_fc.read_value(), [0])
            output_init = tf.matmul(embeddings, w_fc_normalized)
            mean_init, var_init = tf.nn.moments(output_init, [0])
            # Data-dependent init values.
            g_init_value = 1. / tf.sqrt(var_init + 1e-10)
            ops = [tf.assign(g, g_init_value)]
            if not cosine_classifier:
                # Also initialize a bias in a data-dependent way.
                b_fc_init_value = -mean_init * g_init_value
                ops.append(tf.assign(b_fc, b_fc_init_value))
            # Mark that the data-dependent initialization is done to prevent it from
            # happening again in the future.
            ops.append(tf.assign(data_dependent_init_done, 1))
            return tf.group(*ops)

        # Possibly perform data-dependent init (if it hasn't been done already).
        init_op = tf.cond(tf.equal(data_dependent_init_done, 0),
                          _do_data_dependent_init, tf.no_op)

        with tf.control_dependencies([init_op]):
            # Apply weight normalization.
            w_fc *= g / tf.sqrt(tf.reduce_sum(tf.square(w_fc), [0]))
            # Forward pass through the layer defined by w_fc and b_fc.
            logits = linear_classifier_forward_pass(embeddings, w_fc, b_fc,
                                                    cosine_classifier,
                                                    cosine_logits_multiplier,
                                                    True)

    else:
        # No weight norm.
        w_fc = functional_backbones.weight_variable(
            [embedding_dims, num_classes], weight_decay=weight_decay)
        b_fc = None
        if not cosine_classifier:
            # Also initialize a bias.
            b_fc = functional_backbones.bias_variable([num_classes])
        # Forward pass through the layer defined by w_fc and b_fc.
        logits = linear_classifier_forward_pass(embeddings, w_fc, b_fc,
                                                cosine_classifier,
                                                cosine_logits_multiplier,
                                                False)
    return logits
Example #26
    def update_state(self, inputs, outputs):
        """Function that updates the metric state at each example.

    Args:
      inputs: A dictionary containing input tensors.
      outputs: A dictionary containing output tensors.

    Returns:
      Update op.
    """
        detections_score = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_score], [-1])
        detections_class = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_class], [-1])
        detections_length = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_length],
            [-1])
        detections_height = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_height],
            [-1])
        detections_width = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_width], [-1])
        detections_center = tf.reshape(
            outputs[standard_fields.DetectionResultFields.objects_center],
            [-1, 3])
        detections_rotation_matrix = tf.reshape(
            outputs[
                standard_fields.DetectionResultFields.objects_rotation_matrix],
            [-1, 3, 3])
        gt_class = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_class], [-1])
        gt_length = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_length], [-1])
        gt_height = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_height], [-1])
        gt_width = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_width], [-1])
        gt_center = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_center], [-1, 3])
        gt_rotation_matrix = tf.reshape(
            inputs[standard_fields.InputDataFields.objects_rotation_matrix],
            [-1, 3, 3])
        for c in self.class_range:
            gt_mask_c = tf.equal(gt_class, c)
            num_gt_c = tf.math.reduce_sum(tf.cast(gt_mask_c, dtype=tf.int32))
            gt_length_c = tf.boolean_mask(gt_length, gt_mask_c)
            gt_height_c = tf.boolean_mask(gt_height, gt_mask_c)
            gt_width_c = tf.boolean_mask(gt_width, gt_mask_c)
            gt_center_c = tf.boolean_mask(gt_center, gt_mask_c)
            gt_rotation_matrix_c = tf.boolean_mask(gt_rotation_matrix,
                                                   gt_mask_c)
            detections_mask_c = tf.equal(detections_class, c)
            num_detections_c = tf.math.reduce_sum(
                tf.cast(detections_mask_c, dtype=tf.int32))
            if num_detections_c == 0:
                continue
            det_length_c = tf.boolean_mask(detections_length,
                                           detections_mask_c)
            det_height_c = tf.boolean_mask(detections_height,
                                           detections_mask_c)
            det_width_c = tf.boolean_mask(detections_width, detections_mask_c)
            det_center_c = tf.boolean_mask(detections_center,
                                           detections_mask_c)
            det_rotation_matrix_c = tf.boolean_mask(detections_rotation_matrix,
                                                    detections_mask_c)
            det_scores_c = tf.boolean_mask(detections_score, detections_mask_c)
            det_scores_c, sorted_indices = tf.math.top_k(det_scores_c,
                                                         k=num_detections_c)
            det_length_c = tf.gather(det_length_c, sorted_indices)
            det_height_c = tf.gather(det_height_c, sorted_indices)
            det_width_c = tf.gather(det_width_c, sorted_indices)
            det_center_c = tf.gather(det_center_c, sorted_indices)
            det_rotation_matrix_c = tf.gather(det_rotation_matrix_c,
                                              sorted_indices)
            tp_c = tf.zeros([num_detections_c], dtype=tf.int32)
            if num_gt_c > 0:
                ious_c = box_ops.iou3d(
                    boxes1_length=gt_length_c,
                    boxes1_height=gt_height_c,
                    boxes1_width=gt_width_c,
                    boxes1_center=gt_center_c,
                    boxes1_rotation_matrix=gt_rotation_matrix_c,
                    boxes2_length=det_length_c,
                    boxes2_height=det_height_c,
                    boxes2_width=det_width_c,
                    boxes2_center=det_center_c,
                    boxes2_rotation_matrix=det_rotation_matrix_c)
                max_overlap_gt_ids = tf.cast(tf.math.argmax(ious_c, axis=0),
                                             dtype=tf.int32)
                is_gt_box_detected = tf.zeros([num_gt_c], dtype=tf.int32)
                for i in tf.range(num_detections_c):
                    gt_id = max_overlap_gt_ids[i]
                    if (ious_c[gt_id, i] > self.iou_threshold
                            and is_gt_box_detected[gt_id] == 0):
                        tp_c = tf.maximum(
                            tf.one_hot(i, num_detections_c, dtype=tf.int32),
                            tp_c)
                        is_gt_box_detected = tf.maximum(
                            tf.one_hot(gt_id, num_gt_c, dtype=tf.int32),
                            is_gt_box_detected)
            self.tp[c] = tf.concat([self.tp[c], tp_c], axis=0)
            self.scores[c] = tf.concat([self.scores[c], det_scores_c], axis=0)
            self.num_gt[c] += num_gt_c
        return tf.no_op()
Example #27
def true_n(y, y_hat, name='tn'):
    return tf.logical_and(
        tf.equal(y, False), tf.equal(y_hat, False), name=name)
Example #28
def true_p(y, y_hat, name='tp'):
    return tf.logical_and(
        tf.equal(y, True), tf.equal(y_hat, True), name=name)
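Examples #22, #27, and #28 are three cells of a boolean confusion matrix; a quick eager check follows (a matching false_p would presumably follow the same pattern, but it does not appear in this listing):

import tensorflow as tf

y = tf.constant([True, True, False, False])      # ground truth
y_hat = tf.constant([True, False, True, False])  # predictions

print(true_p(y, y_hat).numpy())   # [ True False False False]
print(true_n(y, y_hat).numpy())   # [False False False  True]
print(false_n(y, y_hat).numpy())  # [False  True False False]
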
Example #29
def true_fn(images):
  r = tf.random.uniform([], maxval=3, dtype=tf.int32)
  images = tf.roll(images, r, axis=-1)
  r = tf.equal(tf.random.uniform([], maxval=2, dtype=tf.int32), 1)
  # Assumed intent: reverse only when the coin flip `r` is true (as extracted,
  # `r` was unused); together with the roll above this makes all six channel
  # permutations reachable.
  return tf.cond(r, lambda: tf.reverse(images, axis=[-1]), lambda: images)
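Called on an HxWx3 image, this returns the same shape with channels permuted:

import tensorflow as tf

images = tf.random.uniform([4, 4, 3])
print(true_fn(images).shape)  # (4, 4, 3)
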
Example #30
def project_distribution(supports, weights, target_support,
                         validate_args=False):
  """Projects a batch of (support, weights) onto target_support.

  Based on equation (7) in (Bellemare et al., 2017):
    https://arxiv.org/abs/1707.06887
  In the rest of the comments we will refer to this equation simply as Eq7.

  This code is not easy to digest, so we will use a running example to clarify
  what is going on, with the following sample inputs:
    * supports =       [[0, 2, 4, 6, 8],
                        [1, 3, 4, 5, 6]]
    * weights =        [[0.1, 0.6, 0.1, 0.1, 0.1],
                        [0.1, 0.2, 0.5, 0.1, 0.1]]
    * target_support = [4, 5, 6, 7, 8]
  In the code below, comments preceded with 'Ex:' will be referencing the above
  values.

  Args:
    supports: Tensor of shape (batch_size, num_dims) defining supports for the
      distribution.
    weights: Tensor of shape (batch_size, num_dims) defining weights on the
      original support points. Although for the CategoricalDQN agent these
      weights are probabilities, it is not required that they are.
    target_support: Tensor of shape (num_dims) defining support of the projected
      distribution. The values must be monotonically increasing. Vmin and Vmax
      will be inferred from the first and last elements of this tensor,
      respectively. The values in this tensor must be equally spaced.
    validate_args: Whether we will verify the contents of the
      target_support parameter.

  Returns:
    A Tensor of shape (batch_size, num_dims) with the projection of a batch of
    (support, weights) onto target_support.

  Raises:
    ValueError: If target_support has no dimensions, or if shapes of supports,
      weights, and target_support are incompatible.
  """
  target_support_deltas = target_support[1:] - target_support[:-1]
  # delta_z = `\Delta z` in Eq7.
  delta_z = target_support_deltas[0]
  validate_deps = []
  supports.shape.assert_is_compatible_with(weights.shape)
  supports[0].shape.assert_is_compatible_with(target_support.shape)
  target_support.shape.assert_has_rank(1)
  if validate_args:
    # Assert that supports and weights have the same shapes.
    validate_deps.append(
        tf.Assert(
            tf.reduce_all(tf.equal(tf.shape(supports), tf.shape(weights))),
            [supports, weights]))
    # Assert that elements of supports and target_support have the same shape.
    validate_deps.append(
        tf.Assert(
            tf.reduce_all(
                tf.equal(tf.shape(supports)[1], tf.shape(target_support))),
            [supports, target_support]))
    # Assert that target_support has a single dimension.
    validate_deps.append(
        tf.Assert(
            tf.equal(tf.size(tf.shape(target_support)), 1), [target_support]))
    # Assert that the target_support is monotonically increasing.
    validate_deps.append(
        tf.Assert(tf.reduce_all(target_support_deltas > 0), [target_support]))
    # Assert that the values in target_support are equally spaced.
    validate_deps.append(
        tf.Assert(
            tf.reduce_all(tf.equal(target_support_deltas, delta_z)),
            [target_support]))

  with tf.control_dependencies(validate_deps):
    # Ex: `v_min, v_max = 4, 8`.
    v_min, v_max = target_support[0], target_support[-1]
    # Ex: `batch_size = 2`.
    batch_size = tf.shape(supports)[0]
    # `N` in Eq7.
    # Ex: `num_dims = 5`.
    num_dims = tf.shape(target_support)[0]
    # clipped_support = `[\hat{T}_{z_j}]^{V_max}_{V_min}` in Eq7.
    # Ex: `clipped_support = [[[ 4.  4.  4.  6.  8.]]
    #                         [[ 4.  4.  4.  5.  6.]]]`.
    clipped_support = tf.clip_by_value(supports, v_min, v_max)[:, None, :]
    # Ex: `tiled_support = [[[[ 4.  4.  4.  6.  8.]
    #                         [ 4.  4.  4.  6.  8.]
    #                         [ 4.  4.  4.  6.  8.]
    #                         [ 4.  4.  4.  6.  8.]
    #                         [ 4.  4.  4.  6.  8.]]
    #                        [[ 4.  4.  4.  5.  6.]
    #                         [ 4.  4.  4.  5.  6.]
    #                         [ 4.  4.  4.  5.  6.]
    #                         [ 4.  4.  4.  5.  6.]
    #                         [ 4.  4.  4.  5.  6.]]]]`.
    tiled_support = tf.tile([clipped_support], [1, 1, num_dims, 1])
    # Ex: `reshaped_target_support = [[[ 4.]
    #                                  [ 5.]
    #                                  [ 6.]
    #                                  [ 7.]
    #                                  [ 8.]]
    #                                 [[ 4.]
    #                                  [ 5.]
    #                                  [ 6.]
    #                                  [ 7.]
    #                                  [ 8.]]]`.
    reshaped_target_support = tf.tile(target_support[:, None], [batch_size, 1])
    reshaped_target_support = tf.reshape(reshaped_target_support,
                                         [batch_size, num_dims, 1])
    # numerator = `|clipped_support - z_i|` in Eq7.
    # Ex: `numerator = [[[[ 0.  0.  0.  2.  4.]
    #                     [ 1.  1.  1.  1.  3.]
    #                     [ 2.  2.  2.  0.  2.]
    #                     [ 3.  3.  3.  1.  1.]
    #                     [ 4.  4.  4.  2.  0.]]
    #                    [[ 0.  0.  0.  1.  2.]
    #                     [ 1.  1.  1.  0.  1.]
    #                     [ 2.  2.  2.  1.  0.]
    #                     [ 3.  3.  3.  2.  1.]
    #                     [ 4.  4.  4.  3.  2.]]]]`.
    numerator = tf.abs(tiled_support - reshaped_target_support)
    quotient = 1 - (numerator / delta_z)
    # clipped_quotient = `[1 - numerator / (\Delta z)]_0^1` in Eq7.
    # Ex: `clipped_quotient = [[[[ 1.  1.  1.  0.  0.]
    #                            [ 0.  0.  0.  0.  0.]
    #                            [ 0.  0.  0.  1.  0.]
    #                            [ 0.  0.  0.  0.  0.]
    #                            [ 0.  0.  0.  0.  1.]]
    #                           [[ 1.  1.  1.  0.  0.]
    #                            [ 0.  0.  0.  1.  0.]
    #                            [ 0.  0.  0.  0.  1.]
    #                            [ 0.  0.  0.  0.  0.]
    #                            [ 0.  0.  0.  0.  0.]]]]`.
    clipped_quotient = tf.clip_by_value(quotient, 0, 1)
    # Ex: `weights = [[ 0.1  0.6  0.1  0.1  0.1]
    #                 [ 0.1  0.2  0.5  0.1  0.1]]`.
    weights = weights[:, None, :]
    # inner_prod = `\sum_{j=0}^{N-1} clipped_quotient * p_j(x', \pi(x'))`
    # in Eq7.
    # Ex: `inner_prod = [[[[ 0.1  0.6  0.1  0.  0. ]
    #                      [ 0.   0.   0.   0.  0. ]
    #                      [ 0.   0.   0.   0.1 0. ]
    #                      [ 0.   0.   0.   0.  0. ]
    #                      [ 0.   0.   0.   0.  0.1]]
    #                     [[ 0.1  0.2  0.5  0.  0. ]
    #                      [ 0.   0.   0.   0.1 0. ]
    #                      [ 0.   0.   0.   0.  0.1]
    #                      [ 0.   0.   0.   0.  0. ]
    #                      [ 0.   0.   0.   0.  0. ]]]]`.
    inner_prod = clipped_quotient * weights
    # Ex: `projection = [[ 0.8 0.0 0.1 0.0 0.1]
    #                    [ 0.8 0.1 0.1 0.0 0.0]]`.
    projection = tf.reduce_sum(inner_prod, 3)
    projection = tf.reshape(projection, [batch_size, num_dims])
    return projection
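The docstring's running example can be reproduced directly in eager mode:

import tensorflow as tf

supports = tf.constant([[0., 2., 4., 6., 8.],
                        [1., 3., 4., 5., 6.]])
weights = tf.constant([[0.1, 0.6, 0.1, 0.1, 0.1],
                       [0.1, 0.2, 0.5, 0.1, 0.1]])
target_support = tf.constant([4., 5., 6., 7., 8.])

print(project_distribution(supports, weights, target_support).numpy())
# [[0.8 0.  0.1 0.  0.1]
#  [0.8 0.1 0.1 0.  0. ]]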