Example #1
def process_dumped_episode(support_strings, query_strings, image_size,
                           support_decoder, query_decoder):
    """Processes a dumped episode.

  This function is almost identical to `process_episode()`, except:
  - It doesn't need to call flush_and_chunk_episode().
  - The labels are read directly from the tf.Example protos. We assume that
    labels are already mapped into [0, n_ways - 1].

  Args:
    support_strings: 1-D Tensor of dtype str, Example protocol buffers of
      support set.
    query_strings: 1-D Tensor of dtype str, Example protocol buffers of query
      set.
    image_size: int, desired image size used during decoding.
    support_decoder: ImageDecoder, used to decode support set images. If
      None, no decoding of support images is performed.
    query_decoder: ImageDecoder, used to decode query set images. If
      None, no decoding of query images is performed.

  Returns:
    support_images, support_labels, support_labels, query_images,
      query_labels, query_labels: Tensors, batches of images, labels, and
      labels again, for the support and query sets (respectively). Labels are
      returned twice because dumped datasets no longer carry (absolute) class
      IDs. If the corresponding decoder is None, Example protocol buffers are
      returned in place of images and None in place of labels.
  """
    if isinstance(support_decoder, decoder.ImageDecoder):
        log_data_augmentation(support_decoder.data_augmentation, 'support')
        support_decoder.image_size = image_size

    if isinstance(query_decoder, decoder.ImageDecoder):
        log_data_augmentation(query_decoder.data_augmentation, 'query')
        query_decoder.image_size = image_size

    support_images = support_strings
    query_images = query_strings
    support_labels = None
    query_labels = None

    if support_decoder:
        support_images, support_labels = tf.map_fn(
            support_decoder.decode_with_label,
            support_strings,
            dtype=(support_decoder.out_type, tf.int32),
            back_prop=False)

    if query_decoder:
        query_images, query_labels = tf.map_fn(query_decoder.decode_with_label,
                                               query_strings,
                                               dtype=(query_decoder.out_type,
                                                      tf.int32),
                                               back_prop=False)

    return (support_images, support_labels, support_labels, query_images,
            query_labels, query_labels)
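
Note on the pattern above: when the mapped function returns more than one output, `tf.map_fn` needs the `dtype` argument to spell out every output type. A minimal, self-contained sketch with toy data (not from the original codebase):

import tensorflow as tf

strings = tf.constant(['3', '14', '159'])
# fn returns a (float32, int32) pair, so dtype lists both output types.
values, lengths = tf.map_fn(
    lambda s: (tf.strings.to_number(s), tf.strings.length(s)),
    strings,
    dtype=(tf.float32, tf.int32))
# values -> [3., 14., 159.], lengths -> [1, 2, 3]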
Example #2
def process_episode(example_strings,
                    class_ids,
                    chunk_sizes,
                    image_size,
                    support_decoder=None,
                    query_decoder=None):
  """Processes an episode.

  This function:

  1) splits the batch of examples into "flush", "support", and "query" chunks,
  2) throws away the "flush" chunk,
  3) removes the padded dummy examples from the "support" and "query" chunks,
  4) extracts and processes images out of the example strings, and
  5) builds support and query targets (numbers from 0 to K-1 where K is the
     number of classes in the episode) from the class IDs.

  Args:
    example_strings: 1-D Tensor of dtype str, tf.train.Example protocol buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    chunk_sizes: Tuple of ints representing the sizes of the flush and
      additional chunks.
    image_size: int, desired image size used during decoding.
    support_decoder: Decoder class instance for support set.
    query_decoder: Decoder class instance for query set.

  Returns:
    support_images, support_labels, support_class_ids, query_images,
      query_labels, query_class_ids: Tensors, batches of images, labels, and
      (absolute) class IDs, for the support and query sets (respectively).
  """
  # TODO(goroshin): Replace with `support_decoder.log_summary(name='support')`.
  # TODO(goroshin): Eventually remove setting the image size here and pass it
  # to the ImageDecoder constructor instead.
  if isinstance(support_decoder, decoder.ImageDecoder):
    log_data_augmentation(support_decoder.data_augmentation, 'support')
    support_decoder.image_size = image_size
  if isinstance(query_decoder, decoder.ImageDecoder):
    log_data_augmentation(query_decoder.data_augmentation, 'query')
    query_decoder.image_size = image_size

  (support_strings, support_class_ids), (query_strings, query_class_ids) = \
      flush_and_chunk_episode(example_strings, class_ids, chunk_sizes)

  support_images = tf.map_fn(
      support_decoder, support_strings, dtype=tf.float32, back_prop=False)
  query_images = tf.map_fn(
      query_decoder, query_strings, dtype=tf.float32, back_prop=False)

  # Convert class IDs into labels in [0, num_ways).
  _, support_labels = tf.unique(support_class_ids)
  _, query_labels = tf.unique(query_class_ids)

  return (support_images, support_labels, support_class_ids, query_images,
          query_labels, query_class_ids)
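
The relabeling step at the end relies on the second output of `tf.unique`: the index of each element into the unique-values tensor, which is exactly a label in [0, num_ways) assigned in order of first appearance. A quick sketch with made-up class IDs:

import tensorflow as tf

class_ids = tf.constant([17, 42, 17, 5, 42, 5, 17])
_, labels = tf.unique(class_ids)
# labels -> [0, 1, 0, 2, 1, 2, 0]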
Example #3
def simclr_augment(image_batch, blur=False):
    """Apply simclr-style augmentations to a single set of images."""
    (h, w) = image_batch.shape.as_list()[1:3]
    image_batch = (image_batch + 1.0) / 2.0
    image_batch = tf.map_fn(
        lambda x: data_util.preprocess_for_train(x, h, w, impl='simclrv1'),
        image_batch)
    if blur:
        image_batch = tf.map_fn(lambda x: data_util.random_blur(x, h, w),
                                image_batch)
    image_batch = image_batch * 2.0 - 1.0
    return image_batch
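
A toy usage sketch, assuming the SimCLR `data_util` module (from the google-research/simclr repository) is importable and that inputs are scaled to [-1, 1], as the rescaling lines above imply:

import tensorflow as tf

# A batch of four 32x32 RGB images with pixel values in [-1, 1].
images = tf.random.uniform([4, 32, 32, 3], minval=-1.0, maxval=1.0)
augmented = simclr_augment(images, blur=True)  # same shape and range as input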
Example #4
def process_batch(example_strings,
                  class_ids,
                  image_size,
                  batch_data_augmentation=None):
  """Processes a batch.

  This function:

  1) extracts and processes images out of the example strings.
  2) builds targets from the class ID and offset.

  Args:
    example_strings: 1-D Tensor of dtype str, Example protocol buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    image_size: int, desired image size used during decoding.
    batch_data_augmentation: A DataAugmentation object with parameters for
      perturbing the batch images.

  Returns:
    images, labels: Tensors, a batch of image and labels.
  """
  _log_data_augmentation(batch_data_augmentation, 'batch')
  map_fn = functools.partial(
      process_example,
      image_size=image_size,
      data_augmentation=batch_data_augmentation)
  images = tf.map_fn(map_fn, example_strings, dtype=tf.float32, back_prop=False)
  labels = class_ids
  return (images, labels)
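
The `functools.partial` call above binds the keyword arguments so that `tf.map_fn` sees a one-argument function. A sketch of the same pattern; `decode_and_resize` below is a hypothetical stand-in for `process_example`, whose definition is not shown here, and the 'image' feature key is an assumption:

import functools

import tensorflow as tf

def decode_and_resize(example_string, image_size):
  # Hypothetical decoder: parse a serialized tf.train.Example holding a
  # JPEG-encoded 'image' feature, then resize it.
  features = tf.io.parse_single_example(
      example_string, {'image': tf.io.FixedLenFeature([], tf.string)})
  image = tf.image.decode_jpeg(features['image'], channels=3)
  return tf.image.resize(image, [image_size, image_size])

map_fn = functools.partial(decode_and_resize, image_size=84)
# images = tf.map_fn(map_fn, example_strings, dtype=tf.float32)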
Example #5
def process_batch(example_strings, class_ids, image_size, batch_decoder):
    """Processes a batch.

  This function:

  1) extracts and processes images out of the example strings.
  2) builds targets from the class ID and offset.

  Args:
    example_strings: 1-D Tensor of dtype str, Example protocol buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    image_size: int, desired image size used during decoding.
    batch_decoder: Decoder class instance for the batch.

  Returns:
    images, labels: Tensors, a batch of image and labels.
  """
    # TODO(goroshin): Replace with `batch_decoder.log_summary(name='batch')`.
    if isinstance(batch_decoder, decoder.ImageDecoder):
        log_data_augmentation(batch_decoder.data_augmentation, 'batch')
        batch_decoder.image_size = image_size
    images = tf.map_fn(batch_decoder,
                       example_strings,
                       dtype=batch_decoder.out_type,
                       back_prop=False)
    labels = class_ids
    return (images, labels)
Example #6
def knn_graph_from_points(points, num_valid_points, k,
                          distance_upper_bound, mask=None):
  """Returns the distances and indices of the neighbors of each point.

  Note that each point will have at least k neighbors unless the number of
  points is less than k. In that case, the python function that is wrapped in
  py_function will raise a ValueError.

  Args:
    points: A tf.float32 tensor of size [batch_size, N, D] where D is the point
      dimensions.
    num_valid_points: A tf.int32 tensor of size [batch_size] containing the
      number of valid points in each batch example.
    k: Number of neighbors for each point.
    distance_upper_bound: Only build the graph using points that are closer than
      this distance.
    mask: A tf.bool tensor of size [batch_size, N], or None. If None, it is
      ignored. Otherwise, knn is applied only to the points where the mask is
      True; the points where the mask is False will have themselves as their
      neighbors.

  Returns:
    distances: A tf.float32 tensor of size [batch_size, N, k].
    indices: A tf.int32 tensor of size [batch_size, N, k].

  Raises:
    ValueError: If batch_size is unknown.
  """
  batch_size = points.get_shape().as_list()[0]
  if batch_size is None:
    raise ValueError('Batch size is unknown.')
  num_points = tf.shape(points)[1]

  def fn_knn_graph_from_points_unbatched(i):
    """Computes knn graph for example i in the batch."""
    num_valid_points_i = num_valid_points[i]
    points_i = points[i, :num_valid_points_i, :]
    if mask is None:
      mask_i = None
    else:
      mask_i = mask[i, :num_valid_points_i]
    distances_i, indices_i = knn_graph_from_points_unbatched(
        points=points_i,
        k=k,
        distance_upper_bound=distance_upper_bound,
        mask=mask_i)
    distances_i = tf.pad(
        distances_i, paddings=[[0, num_points - num_valid_points_i], [0, 0]])
    indices_i = tf.pad(
        indices_i, paddings=[[0, num_points - num_valid_points_i], [0, 0]])
    return distances_i, indices_i

  distances, indices = tf.map_fn(
      fn=fn_knn_graph_from_points_unbatched,
      elems=tf.range(batch_size),
      dtype=(tf.float32, tf.int32))

  return distances, indices
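
The slice-compute-pad pattern above (keep each example's valid prefix, process it, pad back to full length so `tf.map_fn` can stack the results) in miniature, with toy data:

import tensorflow as tf

points = tf.random.uniform([2, 5, 3])   # [batch_size, N, D]
num_valid_points = tf.constant([3, 5])  # valid prefix length per example

def per_example(i):
  valid = points[i, :num_valid_points[i], :]
  centered = valid - tf.reduce_mean(valid, axis=0, keepdims=True)
  # Pad back to N rows so every mapped output has the same shape.
  return tf.pad(centered, [[0, 5 - num_valid_points[i]], [0, 0]])

centered_points = tf.map_fn(per_example, tf.range(2), dtype=tf.float32)
# centered_points has shape [2, 5, 3]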
Example #7
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
    """Builds an op used as a target for the Q-value.

  This algorithm corresponds to the method "OT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for the
      user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
    scores, score_no_click = _get_unnormalized_scores(next_states)

    # Obtain all possible slates given current docs in the candidate set.
    slate_size = next_actions.get_shape().as_list()[1]
    num_candidates = next_q_values.get_shape().as_list()[1]
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))
    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool)
    # [num_of_slates, slate_size]
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

    # [batch_size, num_of_slates, slate_size]
    next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
    # [batch_size, num_of_slates, slate_size]
    scores_slate = tf.gather(scores, slates, axis=1)
    # [batch_size, num_of_slates]
    batch_size = next_states.get_shape().as_list()[0]
    score_no_click_slate = tf.reshape(
        tf.tile(score_no_click,
                tf.shape(input=slates)[:1]), [batch_size, -1])

    # [batch_size, num_of_slates]
    next_q_target_slate = tf.reduce_sum(
        input_tensor=next_q_values_slate * scores_slate,
        axis=2) / (tf.reduce_sum(input_tensor=scores_slate, axis=2) +
                   score_no_click_slate)

    next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

    return reward + gamma * next_q_target_max * (
        1. - tf.cast(terminals, tf.float32))
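
The slate enumeration and de-duplication in the middle of this function, shown on toy sizes (num_candidates=3, slate_size=2):

import tensorflow as tf

mesh_args = [list(range(3))] * 2
slates = tf.reshape(tf.stack(tf.meshgrid(*mesh_args), axis=-1), (-1, 2))
# Keep only slates whose entries are all distinct.
unique_mask = tf.map_fn(
    lambda x: tf.equal(tf.size(x), tf.size(tf.unique(x)[0])),
    slates,
    dtype=tf.bool)
slates = tf.boolean_mask(slates, unique_mask)
# 9 candidate slates reduce to the 6 ordered pairs with distinct entries.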
Example #8
    def compute_loss(self, labels, scores):
        if (self._max_num_distractors != -1
                and self._max_num_distractors <= scores.shape[1]):
            # Truncates the number of distractors and redefines labels and scores.

            # TODO(dei): Add gin config arg for choosing random num distractors.
            # max_num_dist = tf.random.uniform(
            #     [], 1, self.embedding_matrix.shape[0], dtype=tf.int32)
            max_num_dist = self._max_num_distractors

            def slice_to_max_num_distractors_fn(inputs):
                """Reduces the number of distractors to the max number."""
                label_for_ex, scores_for_ex = inputs

                scores_nocorrect = tf.concat(
                    [scores_for_ex[0:label_for_ex],
                     scores_for_ex[(label_for_ex + 1):]],
                    axis=0)
                random_start_index = tf.random.uniform(
                    shape=[],
                    minval=0,
                    maxval=scores_for_ex.shape[0] - max_num_dist,
                    dtype=tf.int32)

                new_scores = scores_nocorrect[
                    random_start_index:random_start_index + max_num_dist]

                # Put the groundtruth embedding in position 0 to make labels easy.
                new_scores = tf.concat(
                    [tf.expand_dims(scores_for_ex[label_for_ex], 0),
                     new_scores],
                    axis=0)

                return new_scores

            # Truncates the number of distractors being scored to the max number.
            scores = tf.map_fn(slice_to_max_num_distractors_fn,
                               [labels, scores],
                               dtype=tf.float32)

            logging.warning('HERE: scores=%s, labels=%s', str(scores.shape),
                            str(labels.shape))
            # Since we moved the correct embedding to position 0.
            labels = tf.zeros_like(labels)

        main_loss = self._loss_object(labels, scores)
        return main_loss
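
The reshuffling inside `slice_to_max_num_distractors_fn` is easiest to see on a single example: drop the ground-truth score, take a window of distractors, then put the ground truth back at position 0 so the new label is 0. A toy walk-through with made-up numbers:

import tensorflow as tf

scores_for_ex = tf.constant([0.1, 0.9, 0.3, 0.4, 0.2])
label_for_ex = 1  # the ground truth sits at index 1 (score 0.9)
scores_nocorrect = tf.concat(
    [scores_for_ex[:label_for_ex], scores_for_ex[label_for_ex + 1:]],
    axis=0)                     # [0.1, 0.3, 0.4, 0.2]
window = scores_nocorrect[0:3]  # keep at most 3 distractors
new_scores = tf.concat(
    [tf.expand_dims(scores_for_ex[label_for_ex], 0), window],
    axis=0)                     # [0.9, 0.1, 0.3, 0.4], new label is 0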
Example #9
def points_offset_in_voxels(points, grid_cell_size):
  """Converts points into offsets in voxel grid.

  Args:
    points: A tf.float32 tensor of size [batch_size, N, 3].
    grid_cell_size: The size of the grid cells in x, y, z dimensions in the
      voxel grid. It should be either a tf.float32 tensor, a numpy array or a
      list of size [3].

  Returns:
    voxel_xyz_offsets: A tf.float32 tensor of size [batch_size, N, 3].
  """
  batch_size = points.get_shape().as_list()[0]

  def fn(i):
    return _points_offset_in_voxels_unbatched(
        points=points[i, :, :], grid_cell_size=grid_cell_size)

  return tf.map_fn(fn=fn, elems=tf.range(batch_size), dtype=tf.float32)
Example #10
def select_slate_optimal(slate_size, s_no_click, s, q):
    """Selects the slate using exhaustive search.

  This algorithm corresponds to the method "OS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """

    num_candidates = s.shape.as_list()[0]

    # Obtain all possible slates given current docs in the candidate set.
    mesh_args = [list(range(num_candidates))] * slate_size
    slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
    slates = tf.reshape(slates, shape=(-1, slate_size))

    # Filter slates that include duplicates to ensure each document is picked
    # at most once.
    unique_mask = tf.map_fn(
        lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
        slates,
        dtype=tf.bool)
    slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

    slate_q_values = tf.gather(s * q, slates)
    slate_scores = tf.gather(s, slates)
    slate_normalizer = tf.reduce_sum(input_tensor=slate_scores,
                                     axis=1) + s_no_click

    slate_q_values = slate_q_values / tf.expand_dims(slate_normalizer, 1)
    slate_sum_q_values = tf.reduce_sum(input_tensor=slate_q_values, axis=1)
    max_q_slate_index = tf.argmax(input=slate_sum_q_values)
    return tf.gather(slates, max_q_slate_index, axis=0)
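
A toy usage sketch for the function above; scores and q-values are made up, and the result is the ordered pair of document indices that maximizes the expected q-value under the conditional-click model:

import tensorflow as tf

s_no_click = tf.constant(1.0)
s = tf.constant([0.5, 2.0, 1.0])  # click scores, one per document
q = tf.constant([1.0, 0.2, 0.8])  # predicted q-values, one per document
slate = select_slate_optimal(slate_size=2, s_no_click=s_no_click, s=s, q=q)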
Example #11
def sparse_voxel_grid_to_pointcloud(voxel_features, segment_ids,
                                    num_valid_voxels, num_valid_points):
    """Convert voxel features back to points given their segment ids.

  Args:
    voxel_features: A tf.float32 tensor of size [batch_size, N', F].
    segment_ids: A size [batch_size, N] tf.int32 tensor of IDs for each point
      indicating which (flattened) voxel cell its data was mapped to.
    num_valid_voxels: A tf.int32 tensor of size [batch_size] containing the
      number of valid voxels in each batch example.
    num_valid_points: A tf.int32 tensor of size [batch_size] containing the
      number of valid points in each batch example.

  Returns:
    point_features: A tf.float32 tensor of size [batch_size, N, F].

  Raises:
    ValueError: If batch_size is unknown at graph construction time.
  """
    batch_size = voxel_features.shape[0]
    if batch_size is None:
        raise ValueError('batch_size is unknown at graph construction time.')
    num_points = tf.shape(segment_ids)[1]

    def fn(i):
        num_valid_voxels_i = num_valid_voxels[i]
        num_valid_points_i = num_valid_points[i]
        voxel_features_i = voxel_features[i, :num_valid_voxels_i, :]
        segment_ids_i = segment_ids[i, :num_valid_points_i]
        point_features = tf.gather(voxel_features_i, segment_ids_i)
        point_features_rank = len(point_features.get_shape().as_list())
        point_features_paddings = [[0, num_points - num_valid_points_i]]
        for _ in range(point_features_rank - 1):
            point_features_paddings.append([0, 0])
        point_features = tf.pad(point_features,
                                paddings=point_features_paddings)
        return point_features

    return tf.map_fn(fn=fn, elems=tf.range(batch_size), dtype=tf.float32)
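
The core of the per-example function above is a single `tf.gather`: every point looks up the feature row of the voxel it was assigned to. In isolation, with toy numbers:

import tensorflow as tf

voxel_features = tf.constant([[1.0, 1.0], [2.0, 2.0]])  # [N'=2, F=2]
segment_ids = tf.constant([0, 0, 1, 0])                 # [N=4]
point_features = tf.gather(voxel_features, segment_ids)
# point_features -> [[1., 1.], [1., 1.], [2., 2.], [1., 1.]]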
Example #12
def pointcloud_to_sparse_voxel_grid(points, features, num_valid_points,
                                    grid_cell_size, voxels_pad_or_clip_size,
                                    segment_func):
  """Converts a pointcloud into a voxel grid.

  This function calls the `pointcloud_to_sparse_voxel_grid_unbatched`
  function above in a while loop to map a batch of points to a batch of voxels.

  Args:
    points: A tf.float32 tensor of size [batch_size, N, 3].
    features: A tf.float32 tensor of size [batch_size, N, F].
    num_valid_points: A tf.int32 tensor of size [batch_size] containing the
      number of valid points in each batch example.
    grid_cell_size: A tf.float32 tensor of size [3].
    voxels_pad_or_clip_size: Number of target voxels to pad or clip to. If None,
      it will not perform the padding.
    segment_func: A tensorflow function that operates on segments. Examples are
      one of tf.math.unsorted_segment_{min/max/mean/prod/sum}.

  Returns:
    voxel_features: A tf.float32 tensor of size [batch_size, N', F]
      or [batch_size, N', G, F] where G is the number of points sampled per
      voxel.
    voxel_indices: A tf.int32 tensor of size [batch_size, N', 3].
    num_valid_voxels: A tf.int32 tensor of size [batch_size].
    segment_ids: A size [batch_size, N] tf.int32 tensor of IDs for each point
      indicating which (flattened) voxel cell its data was mapped to.
    voxel_start_location: A size [batch_size, 3] tf.float32 tensor of voxel
      start locations.

  Raises:
    ValueError: If pooling method is unknown.
  """
  batch_size = points.get_shape().as_list()[0]
  if batch_size is None:
    batch_size = tf.shape(points)[0]
  num_points = tf.shape(points)[1]

  def fn(i):
    """Map function."""
    num_valid_points_i = num_valid_points[i]
    points_i = points[i, :num_valid_points_i, :]
    features_i = features[i, :num_valid_points_i, :]
    voxel_features_i, voxel_indices_i, segment_ids_i, voxel_start_location_i = (
        pointcloud_to_sparse_voxel_grid_unbatched(
            points=points_i,
            features=features_i,
            grid_cell_size=grid_cell_size,
            segment_func=segment_func))
    num_valid_voxels_i = tf.shape(voxel_features_i)[0]
    (voxel_features_i, voxel_indices_i, num_valid_voxels_i,
     segment_ids_i) = _pad_or_clip_voxels(
         voxel_features=voxel_features_i,
         voxel_indices=voxel_indices_i,
         num_valid_voxels=num_valid_voxels_i,
         segment_ids=segment_ids_i,
         voxels_pad_or_clip_size=voxels_pad_or_clip_size)
    segment_ids_i = tf.pad(
        segment_ids_i, paddings=[[0, num_points - num_valid_points_i]])
    return (voxel_features_i, voxel_indices_i, num_valid_voxels_i,
            segment_ids_i, voxel_start_location_i)

  return tf.map_fn(
      fn=fn,
      elems=tf.range(batch_size),
      dtype=(tf.float32, tf.int32, tf.int32, tf.int32, tf.float32))
Example #13
def process_episode(example_strings,
                    class_ids,
                    chunk_sizes,
                    image_size,
                    support_data_augmentation=None,
                    query_data_augmentation=None):
  """Processes an episode.

  This function:

  1) splits the batch of examples into "flush", "support", and "query" chunks,
  2) throws away the "flush" chunk,
  3) removes the padded dummy examples from the "support" and "query" chunks,
  4) extracts and processes images out of the example strings, and
  5) builds support and query targets (numbers from 0 to K-1 where K is the
     number of classes in the episode) from the class IDs.

  Args:
    example_strings: 1-D Tensor of dtype str, tf.train.Example protocol buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    chunk_sizes: Tuple of 3 ints representing the sizes of (resp.) the flush,
      support, and query chunks.
    image_size: int, desired image size used during decoding.
    support_data_augmentation: A DataAugmentation object with parameters for
      perturbing the support set images.
    query_data_augmentation: A DataAugmentation object with parameters for
      perturbing the query set images.

  Returns:
    support_images, support_labels, support_class_ids, query_images,
      query_labels, query_class_ids: Tensors, batches of images, labels, and
      (absolute) class IDs, for the support and query sets (respectively).
  """
  _log_data_augmentation(support_data_augmentation, 'support')
  _log_data_augmentation(query_data_augmentation, 'query')
  flush_chunk_size, support_chunk_size, _ = chunk_sizes
  support_start = flush_chunk_size
  query_start = support_start + support_chunk_size
  support_map_fn = functools.partial(
      process_example,
      image_size=image_size,
      data_augmentation=support_data_augmentation)
  query_map_fn = functools.partial(
      process_example,
      image_size=image_size,
      data_augmentation=query_data_augmentation)

  support_strings = example_strings[support_start:query_start]
  support_class_ids = class_ids[support_start:query_start]
  (support_strings,
   support_class_ids) = filter_dummy_examples(support_strings,
                                              support_class_ids)
  support_images = tf.map_fn(
      support_map_fn, support_strings, dtype=tf.float32, back_prop=False)

  query_strings = example_strings[query_start:]
  query_class_ids = class_ids[query_start:]
  (query_strings,
   query_class_ids) = filter_dummy_examples(query_strings, query_class_ids)
  query_images = tf.map_fn(
      query_map_fn, query_strings, dtype=tf.float32, back_prop=False)

  # Convert class IDs into labels in [0, num_ways).
  _, support_labels = tf.unique(support_class_ids)
  _, query_labels = tf.unique(query_class_ids)

  return (support_images, support_labels, support_class_ids, query_images,
          query_labels, query_class_ids)
Example #14
def compute_module_criticality(
    objective_fn,
    module_variables_init,
    module_variables_final,
    num_samples_per_iteration=10,
    alpha_grid_size=10,
    sigma_grid_size=10,
    sigma_ratio=1.0,
    loss_threshold_condition=relative_error_condition,
    normalize_error=False,
):
    """Compute the criticality of a module parameterized by `module_variables`.

  Args:
    objective_fn: A callable that takes in an iterable of the module-specific
      variables and produces the value of the objective function.
    module_variables_init: A list of tf.Tensors; the variables of the module at
      initialization.
    module_variables_final: A list of tf.Tensors; the variables of the module at
      convergence.
    num_samples_per_iteration: Number of perturbations to sample each iteration.
    alpha_grid_size: The number of values to test for alpha, the interpolation
      coefficient.
    sigma_grid_size: The number of values to test for sigma, the standard
      deviation of the perturbation.
    sigma_ratio: Positive scalar multiplier k for values of sigma, to enforce
      that the tested values of sigma lie in [k * 1e-16, k]; the default is 1.0,
      implying that the tested values of sigma lie in the interval [1e-16, 1].
    loss_threshold_condition: A callable that takes in a reference objective
      value and a candidate objective value and produces a thresholding
      decision.
    normalize_error: Whether to normalize the error that is minimized over in
      the definition of criticality by the Frobenius norm of the distance
      between initial and final parameters.

  Returns:
    A `collections.namedtuple` instance that contains the results of the
    criticality analysis.
  """
    initial_objective_value = objective_fn(module_variables_init)
    final_objective_value = objective_fn(module_variables_final)

    # Test a 2D grid of alpha and sigma values.
    float_zero = tf.cast(0, tf.float32)
    alphas, sigmas = tf.meshgrid(
        tf.linspace(float_zero, 1, alpha_grid_size + 1),
        tf.linspace(float_zero + 1e-16, 1, sigma_grid_size + 1) * sigma_ratio,
    )
    alphas, sigmas = tf.reshape(alphas, [-1]), tf.reshape(sigmas, [-1])

    def _evaluate_alpha_sigma(alpha_sigma):
        alpha, sigma = alpha_sigma
        return _interpolate_and_perturb(
            alpha=alpha,
            sigma=sigma,
            params_init=module_variables_init,
            params_final=module_variables_final,
            objective_fn=objective_fn,
            loss_threshold_condition=functools.partial(
                loss_threshold_condition,
                reference_error=final_objective_value),
            normalize_error=normalize_error,
            num_samples_per_iteration=num_samples_per_iteration,
        )

    (threshold_conditions, interpolated_and_perturbed_losses,
     interpolated_and_perturbed_norms) = tf.map_fn(
         _evaluate_alpha_sigma,
         elems=(alphas, sigmas),
         dtype=(tf.bool, tf.float32, tf.float32),
     )

    masked_interpolated_and_perturbed_norms = tf.where(
        threshold_conditions, interpolated_and_perturbed_norms,
        tf.ones_like(interpolated_and_perturbed_norms) * np.inf)
    idx_min = tf.math.argmin(masked_interpolated_and_perturbed_norms)
    (loss_final, norm_final, alpha_final,
     sigma_final) = (interpolated_and_perturbed_losses[idx_min],
                     interpolated_and_perturbed_norms[idx_min],
                     alphas[idx_min], sigmas[idx_min])

    return ModuleCriticalityAnalysis(
        criticality_score=norm_final,
        alpha=alpha_final,
        sigma=sigma_final,
        loss_value=loss_final,
        num_samples_per_iteration=num_samples_per_iteration,
        alpha_grid_size=alpha_grid_size,
        sigma_grid_size=sigma_grid_size,
        sigma_ratio=sigma_ratio,
        initial_objective_value=initial_objective_value,
        final_objective_value=final_objective_value,
    )
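
When `elems` is a tuple of tensors, as with `(alphas, sigmas)` above, `tf.map_fn` zips their leading dimensions and passes one tuple of slices per call; a multi-output `fn` again needs `dtype` to list every output type. A minimal sketch:

import tensorflow as tf

alphas = tf.constant([0.0, 0.5, 1.0])
sigmas = tf.constant([0.1, 0.1, 0.1])
sums, prods = tf.map_fn(
    lambda pair: (pair[0] + pair[1], pair[0] * pair[1]),
    elems=(alphas, sigmas),
    dtype=(tf.float32, tf.float32))
# sums -> [0.1, 0.6, 1.1], prods -> [0., 0.05, 0.1]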