def process_dumped_episode(support_strings, query_strings, image_size,
                           support_decoder, query_decoder):
  """Processes a dumped episode.

  This function is almost identical to `process_episode()`, except:
  - It doesn't need to call flush_and_chunk_episode().
  - The labels are read from the tf.Example directly. We assume that labels
    are already mapped into [0, n_ways - 1].

  Args:
    support_strings: 1-D Tensor of dtype str, Example protocol buffers of the
      support set.
    query_strings: 1-D Tensor of dtype str, Example protocol buffers of the
      query set.
    image_size: int, desired image size used during decoding.
    support_decoder: ImageDecoder, used to decode support set images. If None,
      no decoding of support images is performed.
    query_decoder: ImageDecoder, used to decode query set images. If None, no
      decoding of query images is performed.

  Returns:
    support_images, support_labels, support_labels, query_images,
    query_labels, query_labels: Tensors, batches of images and labels for the
    support and query sets (respectively). Labels are returned twice since
    dumped datasets don't have (absolute) class IDs anymore. Example protocol
    buffers are returned in place of images, and None in place of labels, if
    the corresponding decoder is None.
  """
  if isinstance(support_decoder, decoder.ImageDecoder):
    log_data_augmentation(support_decoder.data_augmentation, 'support')
    support_decoder.image_size = image_size
  if isinstance(query_decoder, decoder.ImageDecoder):
    log_data_augmentation(query_decoder.data_augmentation, 'query')
    query_decoder.image_size = image_size

  support_images = support_strings
  query_images = query_strings
  support_labels = None
  query_labels = None
  if support_decoder:
    support_images, support_labels = tf.map_fn(
        support_decoder.decode_with_label,
        support_strings,
        dtype=(support_decoder.out_type, tf.int32),
        back_prop=False)
  if query_decoder:
    query_images, query_labels = tf.map_fn(
        query_decoder.decode_with_label,
        query_strings,
        dtype=(query_decoder.out_type, tf.int32),
        back_prop=False)

  return (support_images, support_labels, support_labels, query_images,
          query_labels, query_labels)

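# A standalone sketch (not part of the original module) of the multi-output
# tf.map_fn pattern used above: when fn returns an (image, label) pair, the
# `dtype` argument must name both output types. `_demo_map_fn_pair` and
# `parse_fn` are hypothetical stand-ins for `decode_with_label`; the sketch
# relies on the module-level `tf` import.
def _demo_map_fn_pair():
  def parse_fn(s):
    # Dummy (image, label) pair; a real decoder would parse the proto.
    return tf.zeros([4, 4, 3], tf.float32), tf.strings.length(s)

  strings = tf.constant(['a', 'bb', 'ccc'])
  images, labels = tf.map_fn(parse_fn, strings, dtype=(tf.float32, tf.int32))
  return images, labels  # shapes [3, 4, 4, 3] and [3]
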
def process_episode(example_strings, class_ids, chunk_sizes, image_size,
                    support_decoder=None, query_decoder=None):
  """Processes an episode.

  This function:

  1) splits the batch of examples into "flush", "support", and "query"
     chunks,
  2) throws away the "flush" chunk,
  3) removes the padded dummy examples from the "support" and "query"
     chunks,
  4) extracts and processes images out of the example strings, and
  5) builds support and query targets (numbers from 0 to K-1 where K is the
     number of classes in the episode) from the class IDs.

  Args:
    example_strings: 1-D Tensor of dtype str, tf.train.Example protocol
      buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    chunk_sizes: Tuple of ints representing the sizes of the flush and
      additional chunks.
    image_size: int, desired image size used during decoding.
    support_decoder: Decoder class instance for the support set.
    query_decoder: Decoder class instance for the query set.

  Returns:
    support_images, support_labels, support_class_ids, query_images,
      query_labels, query_class_ids: Tensors, batches of images, labels, and
      (absolute) class IDs, for the support and query sets (respectively).
  """
  # TODO(goroshin): Replace with `support_decoder.log_summary(name='support')`.
  # TODO(goroshin): Eventually remove setting the image size here and pass it
  # to the ImageDecoder constructor instead.
  if isinstance(support_decoder, decoder.ImageDecoder):
    log_data_augmentation(support_decoder.data_augmentation, 'support')
    support_decoder.image_size = image_size
  if isinstance(query_decoder, decoder.ImageDecoder):
    log_data_augmentation(query_decoder.data_augmentation, 'query')
    query_decoder.image_size = image_size

  (support_strings, support_class_ids), (query_strings, query_class_ids) = \
      flush_and_chunk_episode(example_strings, class_ids, chunk_sizes)

  support_images = tf.map_fn(
      support_decoder, support_strings, dtype=tf.float32, back_prop=False)
  query_images = tf.map_fn(
      query_decoder, query_strings, dtype=tf.float32, back_prop=False)

  # Convert class IDs into labels in [0, num_ways).
  _, support_labels = tf.unique(support_class_ids)
  _, query_labels = tf.unique(query_class_ids)

  return (support_images, support_labels, support_class_ids, query_images,
          query_labels, query_class_ids)

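# A minimal standalone sketch of the label-building step above: the second
# output of tf.unique indexes each element into the unique values, which
# relabels absolute class IDs into [0, num_ways). `_demo_unique_relabeling`
# is a hypothetical helper with made-up IDs.
def _demo_unique_relabeling():
  class_ids = tf.constant([17, 42, 17, 5, 42])
  _, labels = tf.unique(class_ids)
  return labels  # [0, 1, 0, 2, 1]: three ways, numbered by first appearance.
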
def simclr_augment(image_batch, blur=False):
  """Applies SimCLR-style augmentations to a single batch of images."""
  (h, w) = image_batch.shape.as_list()[1:3]
  # Map images from [-1, 1] to [0, 1], as expected by the SimCLR ops.
  image_batch = (image_batch + 1.0) / 2.0
  image_batch = tf.map_fn(
      lambda x: data_util.preprocess_for_train(x, h, w, impl='simclrv1'),
      image_batch)
  if blur:
    image_batch = tf.map_fn(lambda x: data_util.random_blur(x, h, w),
                            image_batch)
  # Map images back to [-1, 1].
  image_batch = image_batch * 2.0 - 1.0
  return image_batch

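# Hypothetical usage sketch for simclr_augment(): it expects images in
# [-1, 1] with static height/width (it rescales to [0, 1] for the SimCLR ops
# and back). The batch below is random data for illustration only.
def _demo_simclr_augment():
  images = tf.random.uniform([8, 84, 84, 3], minval=-1.0, maxval=1.0)
  return simclr_augment(images, blur=True)  # same shape, still in [-1, 1]
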
def process_batch(example_strings, class_ids, image_size,
                  batch_data_augmentation=None):
  """Processes a batch.

  This function:

  1) extracts and processes images out of the example strings, and
  2) builds targets from the class ID and offset.

  Args:
    example_strings: 1-D Tensor of dtype str, Example protocol buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    image_size: int, desired image size used during decoding.
    batch_data_augmentation: A DataAugmentation object with parameters for
      perturbing the batch images.

  Returns:
    images, labels: Tensors, a batch of images and labels.
  """
  _log_data_augmentation(batch_data_augmentation, 'batch')
  map_fn = functools.partial(
      process_example,
      image_size=image_size,
      data_augmentation=batch_data_augmentation)
  images = tf.map_fn(map_fn, example_strings, dtype=tf.float32,
                     back_prop=False)
  labels = class_ids
  return (images, labels)

def process_batch(example_strings, class_ids, image_size, batch_decoder):
  """Processes a batch.

  This function:

  1) extracts and processes images out of the example strings, and
  2) builds targets from the class ID and offset.

  Args:
    example_strings: 1-D Tensor of dtype str, Example protocol buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    image_size: int, desired image size used during decoding.
    batch_decoder: Decoder class instance for the batch.

  Returns:
    images, labels: Tensors, a batch of images and labels.
  """
  # TODO(goroshin): Replace with `batch_decoder.log_summary(name='batch')`.
  if isinstance(batch_decoder, decoder.ImageDecoder):
    log_data_augmentation(batch_decoder.data_augmentation, 'batch')
    batch_decoder.image_size = image_size
  images = tf.map_fn(
      batch_decoder,
      example_strings,
      dtype=batch_decoder.out_type,
      back_prop=False)
  labels = class_ids
  return (images, labels)

def knn_graph_from_points(points, num_valid_points, k,
                          distance_upper_bound, mask=None):
  """Returns the distances and indices of the neighbors of each point.

  Note that each point will have at least k neighbors unless the number of
  points is less than k. In that case, the python function that is wrapped in
  py_function will raise a value error.

  Args:
    points: A tf.float32 tensor of size [batch_size, N, D] where D is the
      point dimensions.
    num_valid_points: A tf.int32 tensor of size [batch_size] containing the
      number of valid points in each batch example.
    k: Number of neighbors for each point.
    distance_upper_bound: Only build the graph using points that are closer
      than this distance.
    mask: If None, it is ignored. If not None, a tf.bool tensor of size
      [batch_size, N]; knn will be applied only to points where the mask is
      True, and points where the mask is False will have themselves as their
      neighbors.

  Returns:
    distances: A tf.float32 tensor of size [batch_size, N, k].
    indices: A tf.int32 tensor of size [batch_size, N, k].

  Raises:
    ValueError: If batch_size is unknown.
  """
  if points.get_shape().as_list()[0] is None:
    raise ValueError('Batch size is unknown.')
  batch_size = points.get_shape().as_list()[0]
  num_points = tf.shape(points)[1]

  def fn_knn_graph_from_points_unbatched(i):
    """Computes the knn graph for example i in the batch."""
    num_valid_points_i = num_valid_points[i]
    points_i = points[i, :num_valid_points_i, :]
    if mask is None:
      mask_i = None
    else:
      mask_i = mask[i, :num_valid_points_i]
    distances_i, indices_i = knn_graph_from_points_unbatched(
        points=points_i,
        k=k,
        distance_upper_bound=distance_upper_bound,
        mask=mask_i)
    # Pad back up to the full number of points per example.
    distances_i = tf.pad(
        distances_i, paddings=[[0, num_points - num_valid_points_i], [0, 0]])
    indices_i = tf.pad(
        indices_i, paddings=[[0, num_points - num_valid_points_i], [0, 0]])
    return distances_i, indices_i

  distances, indices = tf.map_fn(
      fn=fn_knn_graph_from_points_unbatched,
      elems=tf.range(batch_size),
      dtype=(tf.float32, tf.int32))
  return distances, indices

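# Hypothetical usage sketch for the batched kNN graph above; shapes follow
# the docstring and values are made up. Rows beyond num_valid_points come
# back zero-filled by the tf.pad calls.
def _demo_knn_graph_from_points():
  points = tf.random.uniform([2, 100, 3])
  num_valid_points = tf.constant([100, 60], dtype=tf.int32)
  distances, indices = knn_graph_from_points(
      points, num_valid_points, k=8, distance_upper_bound=0.5)
  return distances, indices  # [2, 100, 8] float32 and [2, 100, 8] int32
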
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
  """Builds an op used as a target for the Q-value.

  This algorithm corresponds to the method "OT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for
      the user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
  scores, score_no_click = _get_unnormalized_scores(next_states)

  # Obtain all possible slates given current docs in the candidate set.
  slate_size = next_actions.get_shape().as_list()[1]
  num_candidates = next_q_values.get_shape().as_list()[1]
  mesh_args = [list(range(num_candidates))] * slate_size
  slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
  slates = tf.reshape(slates, shape=(-1, slate_size))
  # Filter slates that include duplicates to ensure each document is picked
  # at most once.
  unique_mask = tf.map_fn(
      lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
      slates,
      dtype=tf.bool)
  # [num_of_slates, slate_size]
  slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

  # [batch_size, num_of_slates, slate_size]
  next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
  # [batch_size, num_of_slates, slate_size]
  scores_slate = tf.gather(scores, slates, axis=1)
  # [batch_size, num_of_slates]
  batch_size = next_states.get_shape().as_list()[0]
  score_no_click_slate = tf.reshape(
      tf.tile(score_no_click, tf.shape(input=slates)[:1]), [batch_size, -1])

  # [batch_size, num_of_slates]
  next_q_target_slate = tf.reduce_sum(
      input_tensor=next_q_values_slate * scores_slate, axis=2) / (
          tf.reduce_sum(input_tensor=scores_slate, axis=2) +
          score_no_click_slate)

  next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

  return reward + gamma * next_q_target_max * (
      1. - tf.cast(terminals, tf.float32))

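# A standalone sketch of the slate-enumeration step above: tf.meshgrid
# enumerates all num_candidates**slate_size ordered slates, and the
# tf.unique-based mask keeps only slates without repeated documents.
# `_demo_enumerate_slates` is a hypothetical helper with toy sizes.
def _demo_enumerate_slates():
  slate_size, num_candidates = 2, 3
  mesh_args = [list(range(num_candidates))] * slate_size
  slates = tf.reshape(
      tf.stack(tf.meshgrid(*mesh_args), axis=-1), (-1, slate_size))
  unique_mask = tf.map_fn(
      lambda x: tf.equal(tf.size(x), tf.size(tf.unique(x)[0])),
      slates,
      dtype=tf.bool)
  return tf.boolean_mask(slates, unique_mask)  # the 6 permutations of 3 docs
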
def compute_loss(self, labels, scores):
  if (self._max_num_distractors != -1 and
      self._max_num_distractors <= scores.shape[1]):
    # Truncates the number of distractors and redefines labels and scores.
    # TODO(dei): Add gin config arg for choosing random num distractors.
    # max_num_dist = tf.random.uniform(
    #     [], 1, self.embedding_matrix.shape[0], dtype=tf.int32)
    max_num_dist = self._max_num_distractors

    def slice_to_max_num_distractors_fn(inputs):
      """Reduces the number of distractors to the max number."""
      label_for_ex, scores_for_ex = inputs

      scores_nocorrect = tf.concat([
          scores_for_ex[0:label_for_ex], scores_for_ex[(label_for_ex + 1):]
      ], axis=0)
      random_start_index = tf.random.uniform(
          shape=[],
          minval=0,
          maxval=scores_for_ex.shape[0] - max_num_dist,
          dtype=tf.int32)

      new_scores = scores_nocorrect[random_start_index:random_start_index +
                                    max_num_dist]

      # Put the groundtruth embedding in position 0 to make labels easy.
      new_scores = tf.concat(
          [tf.expand_dims(scores_for_ex[label_for_ex], 0), new_scores],
          axis=0)
      return new_scores

    # Truncates the number of distractors being scored to the max number.
    scores = tf.map_fn(
        slice_to_max_num_distractors_fn, [labels, scores], dtype=tf.float32)

    logging.warning('HERE: scores=%s, labels=%s', str(scores.shape),
                    str(labels.shape))

    # The correct embedding was moved to position 0, so zero out the labels.
    labels = tf.zeros_like(labels)

  main_loss = self._loss_object(labels, scores)
  return main_loss

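# A standalone sketch of mapping over paired tensors as done above: passing
# [labels, scores] as elems makes tf.map_fn zip them per example.
# `_demo_paired_map_fn` is a hypothetical toy that gathers each example's
# labeled score.
def _demo_paired_map_fn():
  labels = tf.constant([2, 0])
  scores = tf.constant([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
  picked = tf.map_fn(
      lambda inputs: inputs[1][inputs[0]], [labels, scores], dtype=tf.float32)
  return picked  # [0.3, 0.4]
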
def points_offset_in_voxels(points, grid_cell_size):
  """Converts points into offsets in the voxel grid.

  Args:
    points: A tf.float32 tensor of size [batch_size, N, 3].
    grid_cell_size: The size of the grid cells in x, y, z dimensions in the
      voxel grid. It should be either a tf.float32 tensor, a numpy array or a
      list of size [3].

  Returns:
    voxel_xyz_offsets: A tf.float32 tensor of size [batch_size, N, 3].
  """
  batch_size = points.get_shape().as_list()[0]

  def fn(i):
    return _points_offset_in_voxels_unbatched(
        points=points[i, :, :], grid_cell_size=grid_cell_size)

  return tf.map_fn(fn=fn, elems=tf.range(batch_size), dtype=tf.float32)

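# The map-over-batch pattern used above, in isolation: iterating example
# indices with tf.map_fn lets each example be sliced and processed
# unbatched. `_demo_map_over_batch` is a hypothetical standalone toy.
def _demo_map_over_batch():
  points = tf.random.uniform([4, 10, 3])

  def fn(i):
    return points[i] - tf.reduce_mean(points[i], axis=0)  # any unbatched op

  return tf.map_fn(fn=fn, elems=tf.range(4), dtype=tf.float32)
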
def select_slate_optimal(slate_size, s_no_click, s, q):
  """Selects the slate using exhaustive search.

  This algorithm corresponds to the method "OS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """
  num_candidates = s.shape.as_list()[0]

  # Obtain all possible slates given current docs in the candidate set.
  mesh_args = [list(range(num_candidates))] * slate_size
  slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
  slates = tf.reshape(slates, shape=(-1, slate_size))

  # Filter slates that include duplicates to ensure each document is picked
  # at most once.
  unique_mask = tf.map_fn(
      lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
      slates,
      dtype=tf.bool)
  slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

  slate_q_values = tf.gather(s * q, slates)
  slate_scores = tf.gather(s, slates)
  slate_normalizer = tf.reduce_sum(
      input_tensor=slate_scores, axis=1) + s_no_click

  slate_q_values = slate_q_values / tf.expand_dims(slate_normalizer, 1)
  slate_sum_q_values = tf.reduce_sum(input_tensor=slate_q_values, axis=1)
  max_q_slate_index = tf.argmax(input=slate_sum_q_values)
  return tf.gather(slates, max_q_slate_index, axis=0)

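# Hypothetical usage sketch for select_slate_optimal() with toy values:
# four candidate documents and a slate of size two.
def _demo_select_slate_optimal():
  s_no_click = tf.constant(1.0)
  s = tf.constant([0.5, 1.2, 0.3, 0.9])  # click scores
  q = tf.constant([1.0, 0.2, 0.7, 0.4])  # predicted q values
  return select_slate_optimal(2, s_no_click, s, q)  # [2] tensor of doc ids
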
def sparse_voxel_grid_to_pointcloud(voxel_features, segment_ids,
                                    num_valid_voxels, num_valid_points):
  """Convert voxel features back to points given their segment ids.

  Args:
    voxel_features: A tf.float32 tensor of size [batch_size, N', F].
    segment_ids: A size [batch_size, N] tf.int32 tensor of IDs for each point
      indicating which (flattened) voxel cell its data was mapped to.
    num_valid_voxels: A tf.int32 tensor of size [batch_size] containing the
      number of valid voxels in each batch example.
    num_valid_points: A tf.int32 tensor of size [batch_size] containing the
      number of valid points in each batch example.

  Returns:
    point_features: A tf.float32 tensor of size [batch_size, N, F].

  Raises:
    ValueError: If batch_size is unknown at graph construction time.
  """
  batch_size = voxel_features.shape[0]
  if batch_size is None:
    raise ValueError('batch_size is unknown at graph construction time.')
  num_points = tf.shape(segment_ids)[1]

  def fn(i):
    num_valid_voxels_i = num_valid_voxels[i]
    num_valid_points_i = num_valid_points[i]
    voxel_features_i = voxel_features[i, :num_valid_voxels_i, :]
    segment_ids_i = segment_ids[i, :num_valid_points_i]
    point_features = tf.gather(voxel_features_i, segment_ids_i)
    point_features_rank = len(point_features.get_shape().as_list())
    point_features_paddings = [[0, num_points - num_valid_points_i]]
    for _ in range(point_features_rank - 1):
      point_features_paddings.append([0, 0])
    point_features = tf.pad(point_features, paddings=point_features_paddings)
    return point_features

  return tf.map_fn(fn=fn, elems=tf.range(batch_size), dtype=tf.float32)

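# The core of fn() above, standalone: tf.gather with per-point segment ids
# copies each voxel's feature row back to every point that mapped to that
# voxel. `_demo_gather_by_segment_ids` is a hypothetical toy.
def _demo_gather_by_segment_ids():
  voxel_features = tf.constant([[1., 1.], [2., 2.]])  # 2 voxels, F = 2
  segment_ids = tf.constant([0, 0, 1, 0])  # 4 points
  return tf.gather(voxel_features, segment_ids)  # [[1,1],[1,1],[2,2],[1,1]]
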
def pointcloud_to_sparse_voxel_grid(points, features, num_valid_points,
                                    grid_cell_size, voxels_pad_or_clip_size,
                                    segment_func):
  """Converts a pointcloud into a voxel grid.

  This function calls the `pointcloud_to_sparse_voxel_grid_unbatched`
  function above in a while loop to map a batch of points to a batch of
  voxels.

  Args:
    points: A tf.float32 tensor of size [batch_size, N, 3].
    features: A tf.float32 tensor of size [batch_size, N, F].
    num_valid_points: A tf.int32 tensor of size [batch_size] containing the
      number of valid points in each batch example.
    grid_cell_size: A tf.float32 tensor of size [3].
    voxels_pad_or_clip_size: Number of target voxels to pad or clip to. If
      None, the padding is not performed.
    segment_func: A tensorflow function that operates on segments. Examples
      are one of tf.math.unsorted_segment_{min/max/mean/prod/sum}.

  Returns:
    voxel_features: A tf.float32 tensor of size [batch_size, N', F] or
      [batch_size, N', G, F] where G is the number of points sampled per
      voxel.
    voxel_indices: A tf.int32 tensor of size [batch_size, N', 3].
    num_valid_voxels: A tf.int32 tensor of size [batch_size].
    segment_ids: A size [batch_size, N] tf.int32 tensor of IDs for each point
      indicating which (flattened) voxel cell its data was mapped to.
    voxel_start_location: A size [batch_size, 3] tf.float32 tensor of voxel
      start locations.

  Raises:
    ValueError: If the pooling method is unknown.
  """
  batch_size = points.get_shape().as_list()[0]
  if batch_size is None:
    batch_size = tf.shape(points)[0]
  num_points = tf.shape(points)[1]

  def fn(i):
    """Map function."""
    num_valid_points_i = num_valid_points[i]
    points_i = points[i, :num_valid_points_i, :]
    features_i = features[i, :num_valid_points_i, :]
    voxel_features_i, voxel_indices_i, segment_ids_i, voxel_start_location_i = (
        pointcloud_to_sparse_voxel_grid_unbatched(
            points=points_i,
            features=features_i,
            grid_cell_size=grid_cell_size,
            segment_func=segment_func))
    num_valid_voxels_i = tf.shape(voxel_features_i)[0]
    (voxel_features_i, voxel_indices_i, num_valid_voxels_i,
     segment_ids_i) = _pad_or_clip_voxels(
         voxel_features=voxel_features_i,
         voxel_indices=voxel_indices_i,
         num_valid_voxels=num_valid_voxels_i,
         segment_ids=segment_ids_i,
         voxels_pad_or_clip_size=voxels_pad_or_clip_size)
    segment_ids_i = tf.pad(
        segment_ids_i, paddings=[[0, num_points - num_valid_points_i]])
    return (voxel_features_i, voxel_indices_i, num_valid_voxels_i,
            segment_ids_i, voxel_start_location_i)

  return tf.map_fn(
      fn=fn,
      elems=tf.range(batch_size),
      dtype=(tf.float32, tf.int32, tf.int32, tf.int32, tf.float32))

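# Hypothetical usage sketch; shapes follow the docstring, and
# tf.math.unsorted_segment_mean is one of the segment_func options the
# docstring names. All values below are made up.
def _demo_pointcloud_to_sparse_voxel_grid():
  points = tf.random.uniform([2, 1000, 3])
  features = tf.random.uniform([2, 1000, 8])
  num_valid_points = tf.constant([1000, 750], dtype=tf.int32)
  return pointcloud_to_sparse_voxel_grid(
      points=points,
      features=features,
      num_valid_points=num_valid_points,
      grid_cell_size=tf.constant([0.1, 0.1, 0.1]),
      voxels_pad_or_clip_size=500,
      segment_func=tf.math.unsorted_segment_mean)
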
def process_episode(example_strings, class_ids, chunk_sizes, image_size,
                    support_data_augmentation=None,
                    query_data_augmentation=None):
  """Processes an episode.

  This function:

  1) splits the batch of examples into "flush", "support", and "query"
     chunks,
  2) throws away the "flush" chunk,
  3) removes the padded dummy examples from the "support" and "query"
     chunks,
  4) extracts and processes images out of the example strings, and
  5) builds support and query targets (numbers from 0 to K-1 where K is the
     number of classes in the episode) from the class IDs.

  Args:
    example_strings: 1-D Tensor of dtype str, tf.train.Example protocol
      buffers.
    class_ids: 1-D Tensor of dtype int, class IDs (absolute wrt the original
      dataset).
    chunk_sizes: Tuple of 3 ints representing the sizes of (resp.) the flush,
      support, and query chunks.
    image_size: int, desired image size used during decoding.
    support_data_augmentation: A DataAugmentation object with parameters for
      perturbing the support set images.
    query_data_augmentation: A DataAugmentation object with parameters for
      perturbing the query set images.

  Returns:
    support_images, support_labels, support_class_ids, query_images,
      query_labels, query_class_ids: Tensors, batches of images, labels, and
      (absolute) class IDs, for the support and query sets (respectively).
  """
  _log_data_augmentation(support_data_augmentation, 'support')
  _log_data_augmentation(query_data_augmentation, 'query')
  flush_chunk_size, support_chunk_size, _ = chunk_sizes
  support_start = flush_chunk_size
  query_start = support_start + support_chunk_size
  support_map_fn = functools.partial(
      process_example,
      image_size=image_size,
      data_augmentation=support_data_augmentation)
  query_map_fn = functools.partial(
      process_example,
      image_size=image_size,
      data_augmentation=query_data_augmentation)

  support_strings = example_strings[support_start:query_start]
  support_class_ids = class_ids[support_start:query_start]
  (support_strings, support_class_ids) = filter_dummy_examples(
      support_strings, support_class_ids)
  support_images = tf.map_fn(
      support_map_fn, support_strings, dtype=tf.float32, back_prop=False)

  query_strings = example_strings[query_start:]
  query_class_ids = class_ids[query_start:]
  (query_strings, query_class_ids) = filter_dummy_examples(
      query_strings, query_class_ids)
  query_images = tf.map_fn(
      query_map_fn, query_strings, dtype=tf.float32, back_prop=False)

  # Convert class IDs into labels in [0, num_ways).
  _, support_labels = tf.unique(support_class_ids)
  _, query_labels = tf.unique(query_class_ids)

  return (support_images, support_labels, support_class_ids, query_images,
          query_labels, query_class_ids)

def compute_module_criticality(
    objective_fn,
    module_variables_init,
    module_variables_final,
    num_samples_per_iteration=10,
    alpha_grid_size=10,
    sigma_grid_size=10,
    sigma_ratio=1.0,
    loss_threshold_condition=relative_error_condition,
    normalize_error=False,
):
  """Compute the criticality of a module parameterized by `module_variables`.

  Args:
    objective_fn: A callable that takes in an iterable of the module-specific
      variables and produces the value of the objective function.
    module_variables_init: A list of tf.Tensors; the variables of the module
      at initialization.
    module_variables_final: A list of tf.Tensors; the variables of the module
      at convergence.
    num_samples_per_iteration: Number of perturbations to sample each
      iteration.
    alpha_grid_size: The number of values to test for alpha, the
      interpolation coefficient.
    sigma_grid_size: The number of values to test for sigma, the standard
      deviation of the perturbation.
    sigma_ratio: Positive scalar multiplier k for values of sigma, to enforce
      that the tested values of sigma lie in [k * 1e-16, k]; the default is
      1.0, implying that the tested values of sigma lie in the interval
      [1e-16, 1].
    loss_threshold_condition: A callable that takes in a reference objective
      value and a candidate objective value and produces a thresholding
      decision.
    normalize_error: Whether to normalize the error that is minimized over in
      the definition of criticality by the Frobenius norm of the distance
      between initial and final parameters.

  Returns:
    A `collections.NamedTuple` that contains the results of the criticality
    analysis.
  """
  initial_objective_value = objective_fn(module_variables_init)
  final_objective_value = objective_fn(module_variables_final)

  # Test a 2D grid of alpha and sigma values.
  float_zero = tf.cast(0, tf.float32)
  alphas, sigmas = tf.meshgrid(
      tf.linspace(float_zero, 1, alpha_grid_size + 1),
      tf.linspace(float_zero + 1e-16, 1, sigma_grid_size + 1) * sigma_ratio,
  )
  alphas, sigmas = tf.reshape(alphas, [-1]), tf.reshape(sigmas, [-1])

  def _evaluate_alpha_sigma(alpha_sigma):
    alpha, sigma = alpha_sigma
    return _interpolate_and_perturb(
        alpha=alpha,
        sigma=sigma,
        params_init=module_variables_init,
        params_final=module_variables_final,
        objective_fn=objective_fn,
        loss_threshold_condition=functools.partial(
            loss_threshold_condition, reference_error=final_objective_value),
        normalize_error=normalize_error,
        num_samples_per_iteration=num_samples_per_iteration,
    )

  (threshold_conditions, interpolated_and_perturbed_losses,
   interpolated_and_perturbed_norms) = tf.map_fn(
       _evaluate_alpha_sigma,
       elems=(alphas, sigmas),
       dtype=(tf.bool, tf.float32, tf.float32),
   )

  masked_interpolated_and_perturbed_norms = tf.where(
      threshold_conditions, interpolated_and_perturbed_norms,
      tf.ones_like(interpolated_and_perturbed_norms) * np.inf)
  idx_min = tf.math.argmin(masked_interpolated_and_perturbed_norms)
  (loss_final, norm_final, alpha_final,
   sigma_final) = (interpolated_and_perturbed_losses[idx_min],
                   interpolated_and_perturbed_norms[idx_min], alphas[idx_min],
                   sigmas[idx_min])

  return ModuleCriticalityAnalysis(
      criticality_score=norm_final,
      alpha=alpha_final,
      sigma=sigma_final,
      loss_value=loss_final,
      num_samples_per_iteration=num_samples_per_iteration,
      alpha_grid_size=alpha_grid_size,
      sigma_grid_size=sigma_grid_size,
      sigma_ratio=sigma_ratio,
      initial_objective_value=initial_objective_value,
      final_objective_value=final_objective_value,
  )

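# Hypothetical usage sketch for compute_module_criticality() with a toy
# quadratic objective over a single weight tensor; everything below is made
# up for illustration.
def _demo_compute_module_criticality():
  def objective_fn(variables):
    return tf.add_n([tf.reduce_sum(tf.square(v)) for v in variables])

  return compute_module_criticality(
      objective_fn=objective_fn,
      module_variables_init=[tf.random.normal([10, 10])],
      module_variables_final=[tf.zeros([10, 10])])
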