Code Example #1
def pointcloud_to_voxel_grid(points,
                             features,
                             grid_cell_size,
                             start_location,
                             end_location,
                             segment_func=tf.math.unsorted_segment_mean):
  """Converts a pointcloud into a voxel grid.

  Args:
    points: A tf.float32 tensor of size [N, 3].
    features: A tf.float32 tensor of size [N, F].
    grid_cell_size: A tf.float32 tensor of size [3].
    start_location: A tf.float32 tensor of size [3].
    end_location: A tf.float32 tensor of size [3].
    segment_func: A tensorflow function that operates on segments. Expected to
      be one of tf.math.unsorted_segment_{min/max/mean/prod/sum}. Defaults to
      tf.math.unsorted_segment_mean.

  Returns:
    voxel_features: A tf.float32 tensor of
      size [grid_x_len, grid_y_len, grid_z_len, F].
    segment_ids: A tf.int32 tensor of IDs for each point indicating
      which (flattened) voxel cell its data was mapped to.
    point_indices: A tf.int32 tensor of size [num_points, 3] containing the
      location of each point in the 3d voxel grid.
  """
  grid_cell_size = tf.convert_to_tensor(grid_cell_size, dtype=tf.float32)
  start_location = tf.convert_to_tensor(start_location, dtype=tf.float32)
  end_location = tf.convert_to_tensor(end_location, dtype=tf.float32)
  point_indices = tf.cast(
      (points - tf.expand_dims(start_location, axis=0)) /
      tf.expand_dims(grid_cell_size, axis=0),
      dtype=tf.int32)
  grid_size = tf.cast(
      tf.math.ceil((end_location - start_location) / grid_cell_size),
      dtype=tf.int32)
  # Note: all points outside the grid are added to the edges
  # Cap index at grid_size - 1 (so a 10x10x10 grid's max cell is (9,9,9))
  point_indices = tf.minimum(point_indices, tf.expand_dims(grid_size - 1,
                                                           axis=0))
  # Don't allow any points below index (0, 0, 0)
  point_indices = tf.maximum(point_indices, 0)
  segment_ids = tf.reduce_sum(
      point_indices * tf.stack(
          [grid_size[1] * grid_size[2], grid_size[2], 1], axis=0),
      axis=1)
  voxel_features = segment_func(
      data=features,
      segment_ids=segment_ids,
      num_segments=(grid_size[0] * grid_size[1] * grid_size[2]))
  return (tf.reshape(voxel_features,
                     [grid_size[0],
                      grid_size[1],
                      grid_size[2],
                      features.get_shape().as_list()[1]]),
          segment_ids,
          point_indices)
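
A quick usage sketch (not from the original project; the point and feature values are made up) showing a toy point cloud voxelized into a 10x10x10 grid with the default mean aggregation:

import tensorflow as tf

# Hypothetical toy data: four points with 2-dimensional features.
points = tf.constant([[0.5, 0.5, 0.5],
                      [0.6, 0.4, 0.5],
                      [9.9, 9.9, 9.9],
                      [12.0, -1.0, 3.0]])  # this point lies outside the grid
features = tf.constant([[1.0, 0.0],
                        [3.0, 0.0],
                        [0.0, 2.0],
                        [0.0, 4.0]])

voxel_features, segment_ids, point_indices = pointcloud_to_voxel_grid(
    points=points,
    features=features,
    grid_cell_size=[1.0, 1.0, 1.0],
    start_location=[0.0, 0.0, 0.0],
    end_location=[10.0, 10.0, 10.0])
# voxel_features has shape (10, 10, 10, 2); cell (0, 0, 0) holds the mean
# [2.0, 0.0] of the first two points, and the out-of-grid point is clamped
# to cell (9, 0, 3).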
Code Example #2
def crop_and_pad_voxels(voxels, start_coordinates, end_coordinates):
    """Crops a voxel region and pads past the boundaries with zeros.

  This accepts start and end coordinates past the limits of the voxel grid,
  and uses it to calculate how much top/left/right/bottom padding to add.

  Args:
    voxels: A tf.float32 tensor of shape [x, y, z, f] to crop
    start_coordinates: A list of len 4 with the [x, y, z, f] starting location
      of our crop. This can be negative, which indicates left/top padding.
    end_coordinates: A list of len 4 with the [x, y, z, f] ending location of
      our crop. This can be beyond the size of the voxel tensor, which indicates
      padding.

  Returns:
    cropped_and_padded_voxels: A voxel grid with shape
      [end_coordinates[0] - start_coordinates[0],
       end_coordinates[1] - start_coordinates[1],
       end_coordinates[2] - start_coordinates[2],
       end_coordinates[3] - start_coordinates[3]]
  Raises:
    ValueError: If requested crop and pad is outside the bounds of what the
      function supports.
  """
    if len(start_coordinates) != 4:
        raise ValueError('start_coordinates should be of length 4')
    if len(end_coordinates) != 4:
        raise ValueError('end_coordinates should be of length 4')
    if any([coord <= 0 for coord in end_coordinates]):
        raise ValueError('Requested end coordinates should be > 0')

    start_coordinates = tf.convert_to_tensor(start_coordinates, tf.int32)
    end_coordinates = tf.convert_to_tensor(end_coordinates, tf.int32)

    # Clip the coordinates to within the voxel grid
    clipped_start_coordinates = tf.maximum(0, start_coordinates)
    clipped_end_coordinates = tf.minimum(voxels.shape, end_coordinates)

    cropped_voxels = tf.slice(voxels,
                              begin=clipped_start_coordinates,
                              size=(clipped_end_coordinates -
                                    clipped_start_coordinates))

    top_and_left_padding = tf.maximum(0, -start_coordinates)
    bottom_and_right_padding = tf.maximum(0, end_coordinates - voxels.shape)

    padding = tf.stack([top_and_left_padding, bottom_and_right_padding],
                       axis=1)
    return tf.pad(cropped_voxels, padding)
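
A hedged usage sketch (toy values): requesting a crop that extends one cell past each spatial boundary, so the result is zero-padded on all sides:

import tensorflow as tf

voxels = tf.ones([4, 4, 4, 2])  # hypothetical 4x4x4 grid with 2 features
out = crop_and_pad_voxels(voxels,
                          start_coordinates=[-1, -1, -1, 0],
                          end_coordinates=[5, 5, 5, 2])
# out has shape (6, 6, 6, 2); the outer one-voxel shell is zero padding.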
Code Example #3
def pointcloud_to_sparse_voxel_grid_unbatched(points, features, grid_cell_size,
                                              segment_func):
  """Converts a pointcloud into a voxel grid.

  This function does not handle batch size and only works for a single batch
  of points. The function `pointcloud_to_sparse_voxel_grid` below calls this
  function in a while loop to map a batch of points to a batch of voxels.

  A sparse voxel grid is represented by only keeping the voxels that
  have points in them in memory. Assuming that N' voxels have points in them,
  we represent a sparse voxel grid by
    (a) voxel_features, a [N', F] or [N', G, F] tensor containing the feature
          vector for each voxel.
    (b) voxel_indices, a [N', 3] tensor containing the x, y, z index of each
          voxel.

  Args:
    points: A tf.float32 tensor of size [N, 3].
    features: A tf.float32 tensor of size [N, F].
    grid_cell_size: The size of the grid cells in x, y, z dimensions in the
      voxel grid. It should be either a tf.float32 tensor, a numpy array or a
      list of size [3].
    segment_func: A tensorflow function that operates on segments. Examples are
      one of tf.math.unsorted_segment_{min/max/mean/prod/sum}.

  Returns:
    voxel_features: A tf.float32 tensor of size [N', F] or [N', G, F] where G is
      the number of points sampled per voxel.
    voxel_indices: A tf.int32 tensor of size [N', 3].
    segment_ids: A size [N] tf.int32 tensor of IDs for each point indicating
      which (flattened) voxel cell its data was mapped to.
    voxel_start_location: A tf.float32 tensor of size [3] containing the start
      location of the voxels.

  Raises:
    ValueError: If pooling method is unknown.
  """
  grid_cell_size = tf.convert_to_tensor(grid_cell_size, dtype=tf.float32)
  voxel_xyz_indices, voxel_single_number_indices, voxel_start_location = (
      _points_to_voxel_indices(points=points, grid_cell_size=grid_cell_size))
  voxel_features, segment_ids, num_segments = pool_features_given_indices(
      features=features,
      indices=voxel_single_number_indices,
      segment_func=segment_func)
  voxel_xyz_indices = tf.math.unsorted_segment_max(
      data=voxel_xyz_indices,
      segment_ids=segment_ids,
      num_segments=num_segments)
  return voxel_features, voxel_xyz_indices, segment_ids, voxel_start_location
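
`_points_to_voxel_indices` and `pool_features_given_indices` are project helpers that are not shown here. Based on the docstring, a minimal sketch of what `_points_to_voxel_indices` plausibly computes is below; treat it as an assumption, not the project's actual implementation.

import tensorflow as tf

def _points_to_voxel_indices(points, grid_cell_size):
  """Hypothetical sketch: per-axis and flattened voxel indices for points."""
  # Use the minimum corner of the point cloud as the voxel grid origin.
  voxel_start_location = tf.reduce_min(points, axis=0)
  voxel_xyz_indices = tf.cast(
      (points - voxel_start_location) / grid_cell_size, dtype=tf.int32)
  grid_size = tf.reduce_max(voxel_xyz_indices, axis=0) + 1
  # Flatten (x, y, z) into one index, same row-major scheme as example #1.
  voxel_single_number_indices = tf.reduce_sum(
      voxel_xyz_indices * tf.stack(
          [grid_size[1] * grid_size[2], grid_size[2], 1], axis=0),
      axis=1)
  return voxel_xyz_indices, voxel_single_number_indices, voxel_start_location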
Code Example #4
File: tf_utils.py  Project: suhridbuddha/ntsa
def sequence_loss(y_hat, y, weights, loss_fn, avg_time=True, avg_batch=True):
    loss = loss_fn(y_hat, y)
    total_size = tf.convert_to_tensor(1e-12)
    if avg_batch and avg_time:
        loss = tf.reduce_sum(loss)
        total_size += tf.reduce_sum(weights)
    elif avg_batch and not avg_time:
        loss = tf.reduce_sum(loss, axis=0)
        total_size += tf.reduce_sum(weights, axis=0)  # normalize by weights, not by the loss
    else:
        loss = tf.reduce_sum(loss, axis=1)
        total_size += tf.reduce_sum(weights, axis=1)  # += keeps the 1e-12 epsilon

    loss = tf.divide(loss, total_size, name="seq_loss")
    tf.losses.add_loss(loss)
    return loss
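
A hedged usage sketch (made-up shapes): per-step squared error over a [batch, time] sequence, averaged over both batch and time. `tf.losses.add_loss` is TF1 API, so this assumes a TF1 graph context.

import tensorflow as tf

y_hat = tf.random.normal([8, 20])  # hypothetical [batch, time] predictions
y = tf.random.normal([8, 20])
weights = tf.ones([8, 20])         # 1.0 for real steps, 0.0 for padding

loss = sequence_loss(y_hat, y, weights,
                     loss_fn=lambda a, b: tf.square(a - b),
                     avg_time=True, avg_batch=True)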
Code Example #5
def classification_loss_fn(logits, labels, num_valid_voxels=None, weights=1.0):
    """Semantic segmentation cross entropy loss."""
    logits_rank = len(logits.get_shape().as_list())
    labels_rank = len(labels.get_shape().as_list())
    if logits_rank != labels_rank:
        raise ValueError('Logits and labels should have the same rank.')
    if logits_rank != 2 and logits_rank != 3:
        raise ValueError(
            'Logits and labels should have either 2 or 3 dimensions.')
    if logits_rank == 2:
        if num_valid_voxels is not None:
            raise ValueError(
                '`num_valid_voxels` should be None if not using batched logits.'
            )
    elif logits_rank == 3:
        if num_valid_voxels is None:
            raise ValueError(
                '`num_valid_voxels` cannot be None if using batched logits.')
    if logits_rank == 3:
        if (isinstance(weights, tf.Tensor)
                and len(weights.get_shape().as_list()) == 3):
            use_weights = True
        else:
            use_weights = False
        batch_size = logits.get_shape().as_list()[0]
        logits_list = []
        labels_list = []
        weights_list = []
        for i in range(batch_size):
            num_valid_voxels_i = num_valid_voxels[i]
            logits_list.append(logits[i, 0:num_valid_voxels_i, :])
            labels_list.append(labels[i, 0:num_valid_voxels_i, :])
            if use_weights:
                weights_list.append(weights[i, 0:num_valid_voxels_i, :])
        logits = tf.concat(logits_list, axis=0)
        labels = tf.concat(labels_list, axis=0)
        if use_weights:
            weights = tf.concat(weights_list, axis=0)
    weights = tf.convert_to_tensor(weights, dtype=tf.float32)
    if labels.get_shape().as_list()[-1] == 1:
        num_classes = logits.get_shape().as_list()[-1]
        labels = tf.one_hot(tf.reshape(labels, shape=[-1]), num_classes)
    losses = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.stop_gradient(labels), logits=logits)
    return tf.reduce_mean(losses * tf.reshape(weights, [-1]))
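
A hedged usage sketch (toy shapes and random values) for the batched, rank-3 case with per-example valid-voxel counts:

import tensorflow as tf

batch_size, max_voxels, num_classes = 2, 100, 5
logits = tf.random.normal([batch_size, max_voxels, num_classes])
labels = tf.random.uniform([batch_size, max_voxels, 1],
                           maxval=num_classes, dtype=tf.int32)
num_valid_voxels = tf.constant([80, 60])  # valid voxels per example

loss = classification_loss_fn(logits=logits,
                              labels=labels,
                              num_valid_voxels=num_valid_voxels)
# Only the first 80 and 60 voxels of the two examples contribute to the loss.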
Code Example #6
def preprocess(inputs,
               output_keys=None,
               is_training=False,
               input_field_mapping_fn=None,
               image_preprocess_fn_dic=None,
               images_points_correspondence_fn=None,
               points_pad_or_clip_size=None,
               voxels_pad_or_clip_size=None,
               voxel_grid_cell_size=(0.1, 0.1, 0.1),
               num_offset_bins_x=4,
               num_offset_bins_y=4,
               num_offset_bins_z=4,
               point_feature_keys=('point_offset_bins',),
               point_to_voxel_segment_func=tf.math.unsorted_segment_mean,
               x_min_degree_rotation=None,
               x_max_degree_rotation=None,
               y_min_degree_rotation=None,
               y_max_degree_rotation=None,
               z_min_degree_rotation=None,
               z_max_degree_rotation=None,
               rotation_center=(0.0, 0.0, 0.0),
               min_scale_ratio=None,
               max_scale_ratio=None,
               translation_range=None,
               points_within_box_margin=0.0,
               num_points_to_randomly_sample=None,
               crop_points_around_random_seed_point=False,
               crop_num_points=None,
               crop_radius=None,
               crop_num_background_points=None,
               make_objects_axis_aligned=False,
               min_num_points_in_objects=0,
               fit_objects_to_instance_id_points=False,
               voxel_density_threshold=None,
               voxel_density_grid_cell_size=None):
  """Preprocesses data before running 3D object detection.

  Args:
    inputs: A dictionary of inputs. Each value must be a `Tensor`.
    output_keys: Either None, or a list of strings giving the keys to keep in
      the dictionary returned by the preprocess function.
    is_training: Whether at training stage or not.
    input_field_mapping_fn: A function that maps the input fields to the
      fields expected by object detection pipeline.
    image_preprocess_fn_dic: A dictionary mapping view names to their image
      preprocessing functions. Set it to None if there are no images to
      preprocess or you do not want images preprocessed.
    images_points_correspondence_fn: The function that computes correspondence
      between images and points.
    points_pad_or_clip_size: Number of target points to pad or clip to. If None,
      it will not perform the padding.
    voxels_pad_or_clip_size: Number of target voxels to pad or clip to. If None,
      it will not perform the voxel padding.
    voxel_grid_cell_size: A three dimensional tuple determining the voxel grid
      size.
    num_offset_bins_x: Number of bins for point offsets in x direction.
    num_offset_bins_y: Number of bins for point offsets in y direction.
    num_offset_bins_z: Number of bins for point offsets in z direction.
    point_feature_keys: The keys used to form the voxel features.
    point_to_voxel_segment_func: The function used to aggregate the features
      of the points that fall in the same voxel.
    x_min_degree_rotation: Min degree of rotation around the x axis.
    x_max_degree_rotation: Max degree of rotation around the x axis.
    y_min_degree_rotation: Min degree of rotation around the y axis.
    y_max_degree_rotation: Max degree of rotation around the y axis.
    z_min_degree_rotation: Min degree of rotation around the z axis.
    z_max_degree_rotation: Max degree of rotation around the z axis.
    rotation_center: Center of rotation.
    min_scale_ratio: Minimum scale ratio.
    max_scale_ratio: Maximum scale ratio.
    translation_range: A float value corresponding to the range of random
      translation in x, y, z directions. If None, no translation is applied.
    points_within_box_margin: A margin to add to box radius when deciding which
      points fall inside each box.
    num_points_to_randomly_sample: Number of points to randomly sample. If None,
      the original points are kept and no sampling is performed.
    crop_points_around_random_seed_point: If True, randomly samples a seed
      point and crops the closest `points_pad_or_clip_size` points to the seed
      point. The random seed point selection is based on the following
      procedure. First an object box is randomly selected. Then a random point
      from the random box is selected. Note that the random seed point could be
      sampled from background as well.
    crop_num_points: Number of points to crop.
    crop_radius: The maximum distance of the cropped points from the randomly
      sampled point. If None, it won't be used.
    crop_num_background_points: Minimum number of background points in crop. If
      None, it won't get applied.
    make_objects_axis_aligned: If True, the objects will become axis aligned,
      meaning that they will have identity rotation matrix.
    min_num_points_in_objects: Remove objects that contain fewer points than
      this value.
    fit_objects_to_instance_id_points: If True, it will fit objects to points
      based on their instance ids.
    voxel_density_threshold: Points that belong to a voxel with a density lower
      than this will be removed.
    voxel_density_grid_cell_size: Voxel grid size for removing noise based on
      voxel density threshold.

  Returns:
    inputs: The inputs processed according to our configuration.

  Raises:
    ValueError: If input dictionary is missing any of the required keys.
  """
  inputs = dict(inputs)

  # Convert all float64 to float32 and all int64 to int32.
  for key in sorted(inputs):
    if isinstance(inputs[key], tf.Tensor):
      if inputs[key].dtype == tf.float64:
        inputs[key] = tf.cast(inputs[key], dtype=tf.float32)
      if inputs[key].dtype == tf.int64:
        if key == 'timestamp':
          continue
        else:
          inputs[key] = tf.cast(inputs[key], dtype=tf.int32)

  (view_image_inputs, view_indices_2d_inputs, mesh_inputs, object_inputs,
   non_tensor_inputs) = split_inputs(
       inputs=inputs,
       input_field_mapping_fn=input_field_mapping_fn,
       image_preprocess_fn_dic=image_preprocess_fn_dic,
       images_points_correspondence_fn=images_points_correspondence_fn)

  if standard_fields.InputDataFields.point_positions not in mesh_inputs:
    raise ValueError('Key %s is missing' %
                     standard_fields.InputDataFields.point_positions)

  # Randomly sample points (optional)
  preprocessor_utils.randomly_sample_points(
      mesh_inputs=mesh_inputs,
      view_indices_2d_inputs=view_indices_2d_inputs,
      target_num_points=num_points_to_randomly_sample)

  # Remove low density points
  if voxel_density_threshold is not None:
    preprocessor_utils.remove_pointcloud_noise(
        mesh_inputs=mesh_inputs,
        view_indices_2d_inputs=view_indices_2d_inputs,
        voxel_grid_cell_size=voxel_density_grid_cell_size,
        voxel_density_threshold=voxel_density_threshold)

  rotation_center = tf.convert_to_tensor(rotation_center, dtype=tf.float32)

  # Remove objects that do not have 3d info.
  _filter_valid_objects(inputs=object_inputs)

  # Cast the objects_class to tf.int32.
  _cast_objects_class(inputs=object_inputs)

  # Remove objects that have fewer than a certain number of points.
  if min_num_points_in_objects > 0:
    preprocessor_utils.remove_objects_by_num_points(
        mesh_inputs=mesh_inputs,
        object_inputs=object_inputs,
        min_num_points_in_objects=min_num_points_in_objects)

  # Set point box ids.
  preprocessor_utils.set_point_instance_ids(
      mesh_inputs=mesh_inputs,
      object_inputs=object_inputs,
      points_within_box_margin=points_within_box_margin)

  # Process images.
  preprocessor_utils.preprocess_images(
      view_image_inputs=view_image_inputs,
      view_indices_2d_inputs=view_indices_2d_inputs,
      image_preprocess_fn_dic=image_preprocess_fn_dic,
      is_training=is_training)

  # Randomly transform points and boxes.
  _randomly_transform_points_boxes(
      mesh_inputs=mesh_inputs,
      object_inputs=object_inputs,
      x_min_degree_rotation=x_min_degree_rotation,
      x_max_degree_rotation=x_max_degree_rotation,
      y_min_degree_rotation=y_min_degree_rotation,
      y_max_degree_rotation=y_max_degree_rotation,
      z_min_degree_rotation=z_min_degree_rotation,
      z_max_degree_rotation=z_max_degree_rotation,
      rotation_center=rotation_center,
      min_scale_ratio=min_scale_ratio,
      max_scale_ratio=max_scale_ratio,
      translation_range=translation_range)

  # Randomly crop points around a random seed point.
  if crop_points_around_random_seed_point:
    preprocessor_utils.crop_points_around_random_seed_point(
        mesh_inputs=mesh_inputs,
        view_indices_2d_inputs=view_indices_2d_inputs,
        num_closest_points=crop_num_points,
        max_distance=crop_radius,
        num_background_points=crop_num_background_points)

  if fit_objects_to_instance_id_points:
    preprocessor_utils.fit_objects_to_instance_id_points(
        mesh_inputs=mesh_inputs, object_inputs=object_inputs)

  if make_objects_axis_aligned:
    preprocessor_utils.make_objects_axis_aligned(object_inputs=object_inputs)

  # Put the dictionaries back together.
  inputs = mesh_inputs.copy()
  inputs.update(object_inputs)
  inputs.update(non_tensor_inputs)
  for key in sorted(view_image_inputs):
    inputs[('%s/features' % key)] = view_image_inputs[key]
  for key in sorted(view_indices_2d_inputs):
    inputs[('%s/indices_2d' % key)] = view_indices_2d_inputs[key]

  # Transfer object properties to points, and randomly rotate the points around
  # y axis at training time.
  _transfer_object_properties_to_points(inputs=inputs)

  # Pad or clip points and their properties.
  _pad_or_clip_point_properties(
      inputs=inputs, pad_or_clip_size=points_pad_or_clip_size)

  # Create features that do not exist
  preprocessor_utils.add_point_offsets(
      inputs=inputs, voxel_grid_cell_size=voxel_grid_cell_size)
  preprocessor_utils.add_point_offset_bins(
      inputs=inputs,
      voxel_grid_cell_size=voxel_grid_cell_size,
      num_bins_x=num_offset_bins_x,
      num_bins_y=num_offset_bins_y,
      num_bins_z=num_offset_bins_z)

  # Voxelize point features
  preprocessor_utils.voxelize_point_features(
      inputs=inputs,
      voxels_pad_or_clip_size=voxels_pad_or_clip_size,
      voxel_grid_cell_size=voxel_grid_cell_size,
      point_feature_keys=point_feature_keys,
      point_to_voxel_segment_func=point_to_voxel_segment_func)

  # Voxelizing the semantic labels
  preprocessor_utils.voxelize_semantic_labels(
      inputs=inputs,
      voxels_pad_or_clip_size=voxels_pad_or_clip_size,
      voxel_grid_cell_size=voxel_grid_cell_size)

  # Voxelizing the instance labels
  preprocessor_utils.voxelize_instance_labels(
      inputs=inputs,
      voxels_pad_or_clip_size=voxels_pad_or_clip_size,
      voxel_grid_cell_size=voxel_grid_cell_size)

  # Voxelize the object properties
  preprocessor_utils.voxelize_object_properties(
      inputs=inputs,
      voxels_pad_or_clip_size=voxels_pad_or_clip_size,
      voxel_grid_cell_size=voxel_grid_cell_size)

  # Filter inputs by output_keys if it is not None.
  if output_keys is not None:
    for key in list(inputs):
      if key not in output_keys:
        inputs.pop(key, None)

  return inputs
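
The dtype-normalization step at the top of `preprocess` is self-contained and easy to exercise on its own; a minimal sketch with made-up keys (only `timestamp` is treated specially, keeping its int64 dtype):

import tensorflow as tf

inputs = {
    'point_positions': tf.zeros([100, 3], dtype=tf.float64),
    'object_class': tf.zeros([10], dtype=tf.int64),
    'timestamp': tf.constant(1234567890, dtype=tf.int64),  # stays int64
}
for key in sorted(inputs):
  if isinstance(inputs[key], tf.Tensor):
    if inputs[key].dtype == tf.float64:
      inputs[key] = tf.cast(inputs[key], dtype=tf.float32)
    if inputs[key].dtype == tf.int64 and key != 'timestamp':
      inputs[key] = tf.cast(inputs[key], dtype=tf.int32)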
Code Example #7
def diff_distance(states,
                  actions,
                  rewards,
                  next_states,
                  contexts,
                  state_scales=1.0,
                  goal_scales=1.0,
                  reward_scales=1.0,
                  weight_index=None,
                  weight_vector=None,
                  summarize=False,
                  termination_epsilon=1e-4,
                  state_indices=None,
                  goal_indices=None,
                  norm='L2',
                  epsilon=1e-10):
  """Returns the difference in euclidean distance between states/next_states and contexts.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
        of states.
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    state_scales: multiplicative scale for (next) states. A scalar or 1D tensor,
      must be broadcastable to number of state dimensions.
    goal_scales: multiplicative scale for goals. A scalar or 1D tensor,
      must be broadcastable to number of goal dimensions.
    reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
      must be broadcastable to number of reward dimensions.
    weight_index: (integer) The context list index that specifies weight.
    weight_vector: (a number or a list or Numpy array) The weighting vector,
      broadcastable to `next_states`.
    summarize: (boolean) enable summary ops.
    termination_epsilon: terminate if dist is less than this quantity.
    state_indices: (a list of integers) list of state indices to select.
    goal_indices: (a list of integers) list of goal indices to select.
    norm: L1 or L2.
    epsilon: small offset to ensure non-negative/zero distance.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del actions, rewards  # Unused
  stats = {}
  record_tensor(next_states, state_indices, stats, 'next_states')
  next_states = index_states(next_states, state_indices)
  states = index_states(states, state_indices)
  goals = index_states(contexts[0], goal_indices)
  next_sq_dists = tf.squared_difference(next_states * state_scales,
                                        goals * goal_scales)
  sq_dists = tf.squared_difference(states * state_scales,
                                   goals * goal_scales)
  record_tensor(sq_dists, None, stats, 'sq_dists')
  if weight_vector is not None:
    next_sq_dists *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
    sq_dists *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
  if weight_index is not None:
    next_sq_dists *= contexts[weight_index]
    sq_dists *= contexts[weight_index]
  if norm == 'L1':
    next_dist = tf.sqrt(next_sq_dists + epsilon)
    dist = tf.sqrt(sq_dists + epsilon)
    next_dist = tf.reduce_sum(next_dist, -1)
    dist = tf.reduce_sum(dist, -1)
  elif norm == 'L2':
    next_dist = tf.reduce_sum(next_sq_dists, -1)
    next_dist = tf.sqrt(next_dist + epsilon)  # tf.gradients fails when tf.sqrt(-0.0)
    dist = tf.reduce_sum(sq_dists, -1)
    dist = tf.sqrt(dist + epsilon)  # tf.gradients fails when tf.sqrt(-0.0)
  else:
    raise NotImplementedError(norm)
  discounts = next_dist > termination_epsilon
  if summarize:
    with tf.name_scope('RewardFn/'):
      tf.summary.scalar('mean_dist', tf.reduce_mean(dist))
      tf.summary.histogram('dist', dist)
      summarize_stats(stats)
  diff = dist - next_dist
  diff *= reward_scales
  return tf.to_float(diff), tf.to_float(discounts)
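
The reward shaping itself is easy to verify on toy tensors. This sketch reproduces only the L2 branch (it is not a call to `diff_distance`, which also needs the project's `record_tensor` and `index_states` helpers):

import tensorflow as tf

states = tf.constant([[0.0, 0.0]])
next_states = tf.constant([[1.0, 0.0]])
goals = tf.constant([[3.0, 0.0]])
epsilon = 1e-10

dist = tf.sqrt(
    tf.reduce_sum(tf.math.squared_difference(states, goals), -1) + epsilon)
next_dist = tf.sqrt(
    tf.reduce_sum(tf.math.squared_difference(next_states, goals), -1) + epsilon)
reward = dist - next_dist  # ~1.0: the agent moved one unit closer to the goal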
Code Example #8
def cosine_similarity(states,
                      starting_states,
                      actions,
                      rewards,
                      next_states,
                      contexts,
                      state_scales=1.0,
                      goal_scales=1.0,
                      reward_scales=1.0,
                      normalize_states=True,
                      normalize_goals=True,
                      weight_index=None,
                      weight_vector=None,
                      summarize=False,
                      state_indices=None,
                      goal_indices=None,
                      offset=0.0):
  """Returns the cosine similarity between next_states - states and contexts.

  Args:
    states: A [batch_size, num_state_dims] Tensor representing a batch
      of states.
    starting_states: A [batch_size, num_state_dims] Tensor representing a batch
      of starting states (unused by this function).
    actions: A [batch_size, num_action_dims] Tensor representing a batch
      of actions.
    rewards: A [batch_size] Tensor representing a batch of rewards.
    next_states: A [batch_size, num_state_dims] Tensor representing a batch
      of next states.
    contexts: A list of [batch_size, num_context_dims] Tensor representing
      a batch of contexts.
    state_scales: multiplicative scale for (next) states. A scalar or 1D tensor,
      must be broadcastable to number of state dimensions.
    goal_scales: multiplicative scale for goals. A scalar or 1D tensor,
      must be broadcastable to number of goal dimensions.
    reward_scales: multiplicative scale for rewards. A scalar or 1D tensor,
      must be broadcastable to number of reward dimensions.
    normalize_states: (boolean) whether to l2-normalize the state direction
      vector.
    normalize_goals: (boolean) whether to l2-normalize the goal vector.
    weight_index: (integer) The context list index that specifies weight.
    weight_vector: (a number or a list or Numpy array) The weighting vector,
      broadcastable to `next_states`.
    summarize: (boolean) enable summary ops.
    state_indices: (a list of integers) list of state indices to select.
    goal_indices: (a list of integers) list of goal indices to select.
    offset: (float) additive offset applied to the returned similarity.

  Returns:
    A new tf.float32 [batch_size] rewards Tensor, and
      tf.float32 [batch_size] discounts tensor.
  """
  del actions, rewards  # Unused
  stats = {}
  record_tensor(next_states, state_indices, stats, 'next_states')
  states = index_states(states, state_indices)
  next_states = index_states(next_states, state_indices)
  goals = index_states(contexts[0], goal_indices)

  if weight_vector is not None:
    goals *= tf.convert_to_tensor(weight_vector, dtype=next_states.dtype)
  if weight_index is not None:
    weights = tf.abs(index_states(contexts[0], weight_index))
    goals *= weights

  direction_vec = next_states - states
  if normalize_states:
    direction_vec = tf.nn.l2_normalize(direction_vec, -1)
  goal_vec = goals
  if normalize_goals:
    goal_vec = tf.nn.l2_normalize(goal_vec, -1)

  similarity = tf.reduce_sum(goal_vec * direction_vec, -1)
  discounts = tf.ones_like(similarity)
  return offset + tf.to_float(similarity), tf.to_float(discounts)
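
Stripped of the indexing and weighting helpers, the core of this reward is a normalized dot product; a toy check (not a call to `cosine_similarity` itself):

import tensorflow as tf

direction_vec = tf.nn.l2_normalize(tf.constant([[1.0, 1.0]]), -1)
goal_vec = tf.nn.l2_normalize(tf.constant([[1.0, 0.0]]), -1)
similarity = tf.reduce_sum(goal_vec * direction_vec, -1)  # ~0.707 = cos(45°)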