def _positions_center_origin(height, width):
  """Returns per-pixel (row, col) coordinates with the origin at the image center.

  Args:
    height: scalar, the image height in pixels.
    width: scalar, the image width in pixels.

  Returns:
    A [height, width, 2] tensor of float coordinates, shifted so that the
    image center (between the middle pixels) maps to (0, 0).
  """
  # Half-extent offsets; the extra -0.5 puts the origin between pixels for
  # even sizes.
  center_row = tf.cast(height, tf.float32) / 2.0 - 0.5
  center_col = tf.cast(width, tf.float32) / 2.0 - 0.5
  rows = tf.range(0.0, height, 1) - center_row
  cols = tf.range(0.0, width, 1) - center_col
  grid = tf.meshgrid(rows, cols, indexing='ij')
  return tf.stack(grid, -1)
def compute_target_optimal_q(reward, gamma, next_actions, next_q_values,
                             next_states, terminals):
  """Builds an op used as a target for the Q-value.

  This algorithm corresponds to the method "OT" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    reward: [batch_size] tensor, the immediate reward.
    gamma: float, discount factor with the usual RL meaning.
    next_actions: [batch_size, slate_size] tensor, the next slate.
    next_q_values: [batch_size, num_of_documents] tensor, the q values of the
      documents in the next step.
    next_states: [batch_size, 1 + num_of_documents] tensor, the features for
      the user and the documents in the next step.
    terminals: [batch_size] tensor, indicating if this is a terminal step.

  Returns:
    [batch_size] tensor, the target q values.
  """
  scores, score_no_click = _get_unnormalized_scores(next_states)

  # Obtain all possible slates given current docs in the candidate set.
  slate_size = next_actions.get_shape().as_list()[1]
  num_candidates = next_q_values.get_shape().as_list()[1]
  mesh_args = [list(range(num_candidates))] * slate_size
  slates = tf.stack(tf.meshgrid(*mesh_args), axis=-1)
  slates = tf.reshape(slates, shape=(-1, slate_size))
  # Filter slates that include duplicates to ensure each document is picked
  # at most once: a slate passes iff its element count equals its unique
  # element count.
  unique_mask = tf.map_fn(
      lambda x: tf.equal(tf.size(input=x), tf.size(input=tf.unique(x)[0])),
      slates,
      dtype=tf.bool)
  # [num_of_slates, slate_size]
  slates = tf.boolean_mask(tensor=slates, mask=unique_mask)

  # Gather per-slate values for every example in the batch.
  # [batch_size, num_of_slates, slate_size]
  next_q_values_slate = tf.gather(next_q_values, slates, axis=1)
  # [batch_size, num_of_slates, slate_size]
  scores_slate = tf.gather(scores, slates, axis=1)
  # Broadcast the no-click score so each (example, slate) pair has one.
  # [batch_size, num_of_slates]
  batch_size = next_states.get_shape().as_list()[0]
  score_no_click_slate = tf.reshape(
      tf.tile(score_no_click, tf.shape(input=slates)[:1]), [batch_size, -1])
  # Expected q value of each slate: score-weighted q values, normalized by
  # the total slate score plus the no-click score.
  # [batch_size, num_of_slates]
  next_q_target_slate = tf.reduce_sum(
      input_tensor=next_q_values_slate * scores_slate, axis=2) / (
          tf.reduce_sum(input_tensor=scores_slate, axis=2) +
          score_no_click_slate)

  next_q_target_max = tf.reduce_max(input_tensor=next_q_target_slate, axis=1)

  # Standard one-step bootstrapped target; zero out the bootstrap term on
  # terminal steps.
  return reward + gamma * next_q_target_max * (
      1. - tf.cast(terminals, tf.float32))
def kaf(linear, name, kernel='rbf', D=None, gamma=None):
  """Applies a kernel activation function (KAF) to `linear`.

  A KAF forms the activation as a learned mixture over a fixed dictionary of
  kernel evaluations: the kernel responses `K` are combined with trainable
  mixing coefficients `alpha` and summed over the dictionary dimension.

  Args:
    linear: Input tensor (the pre-activation values); its last dimension is
      used to size the per-unit mixing coefficients.
    name: Name of the trainable mixing-coefficient variable (created or
      reused inside the 'kaf' variable scope).
    kernel: Either 'rbf' (1-D Gaussian kernel via `gauss_kernel`) or 'rbf2d'
      (2-D Gaussian kernel via `gauss_kernel2D` over a D x D grid).
    D: Optional 1-D tensor of dictionary positions. Defaults to 20 points
      evenly spaced in [-2, 2].
    gamma: Optional kernel bandwidth, forwarded to the kernel function.

  Returns:
    Tensor of activations with the dictionary dimension reduced away.

  Raises:
    NotImplementedError: If `kernel` is not 'rbf' or 'rbf2d'.
  """
  if D is None:
    D = tf.linspace(start=-2., stop=2., num=20)

  with tf.variable_scope('kaf', reuse=tf.AUTO_REUSE):
    if kernel == 'rbf':
      K = gauss_kernel(linear, D, gamma=gamma)
      alpha = tf.get_variable(
          name,
          shape=(1, linear.get_shape()[-1], D.get_shape()[0]),
          initializer=tf.random_normal_initializer(stddev=0.1))
    elif kernel == 'rbf2d':
      # 2-D dictionary: all pairs of positions in D, acting on pairs of
      # input units (hence the `// 2` on the unit dimension).
      Dx, Dy = tf.meshgrid(D, D)
      K = gauss_kernel2D(linear, Dx, Dy, gamma=gamma)
      alpha = tf.get_variable(
          name,
          shape=(1, linear.get_shape()[-1] // 2,
                 D.get_shape()[0] * D.get_shape()[0]),
          initializer=tf.random_normal_initializer(stddev=0.1))
    else:
      raise NotImplementedError(
          'Unsupported kernel "{}"; expected "rbf" or "rbf2d".'.format(kernel))
    # Mix the kernel responses with the learned coefficients.
    act = tf.reduce_sum(tf.multiply(K, alpha), axis=-1)
  return act
def select_slate_optimal(slate_size, s_no_click, s, q):
  """Selects the slate using exhaustive search.

  This algorithm corresponds to the method "OS" in
  Ie et al. https://arxiv.org/abs/1905.12767.

  Args:
    slate_size: int, the size of the recommendation slate.
    s_no_click: float tensor, the score for not clicking any document.
    s: [num_of_documents] tensor, the scores for clicking documents.
    q: [num_of_documents] tensor, the predicted q values for documents.

  Returns:
    [slate_size] tensor, the selected slate.
  """
  num_docs = s.shape.as_list()[0]

  # Enumerate every ordered slate of document indices in the candidate set.
  index_grids = tf.meshgrid(*([list(range(num_docs))] * slate_size))
  candidate_slates = tf.reshape(
      tf.stack(index_grids, axis=-1), shape=(-1, slate_size))

  # Drop slates containing duplicates so each document appears at most once:
  # a slate is kept iff its size equals the size of its unique elements.
  no_duplicates = tf.map_fn(
      lambda row: tf.equal(
          tf.size(input=row), tf.size(input=tf.unique(row)[0])),
      candidate_slates,
      dtype=tf.bool)
  candidate_slates = tf.boolean_mask(
      tensor=candidate_slates, mask=no_duplicates)

  # Value of each slate: click-score-weighted q values, normalized by the
  # slate's total click score plus the no-click score, summed over positions.
  weighted_q = tf.gather(s * q, candidate_slates)
  click_scores = tf.gather(s, candidate_slates)
  normalizer = tf.reduce_sum(input_tensor=click_scores, axis=1) + s_no_click
  slate_values = tf.reduce_sum(
      input_tensor=weighted_q / tf.expand_dims(normalizer, 1), axis=1)

  best_index = tf.argmax(input=slate_values)
  return tf.gather(candidate_slates, best_index, axis=0)
def compute_module_criticality(
    objective_fn,
    module_variables_init,
    module_variables_final,
    num_samples_per_iteration=10,
    alpha_grid_size=10,
    sigma_grid_size=10,
    sigma_ratio=1.0,
    loss_threshold_condition=relative_error_condition,
    normalize_error=False,
):
  """Compute the criticality of a module parameterized by `module_variables`.

  Args:
    objective_fn: A callable that takes in an iterable of the module-specific
      variables and produces the value of the objective function.
    module_variables_init: A list of tf.Tensors; the variables of the module
      at initialization.
    module_variables_final: A list of tf.Tensors; the variables of the module
      at convergence.
    num_samples_per_iteration: Number of perturbations to sample each
      iteration.
    alpha_grid_size: The number of values to test for alpha, the
      interpolation coefficient.
    sigma_grid_size: The number of values to test for sigma, the standard
      deviation of the perturbation.
    sigma_ratio: Positive scalar multiplier k for values of sigma, to enforce
      that the tested values of sigma lie in [k * 1e-16, k]; the default is
      1.0, implying that the tested values of sigma lie in the interval
      [1e-16, 1].
    loss_threshold_condition: A callable that takes in a reference objective
      value and a candidate objective value and produces a thresholding
      decision.
    normalize_error: Whether to normalize the error that is minimized over in
      the definition of criticality by the Frobenius norm of the distance
      between initial and final parameters.

  Returns:
    A `collections.NamedTuple` that contains the results of the criticality
    analysis.
  """
  initial_objective_value = objective_fn(module_variables_init)
  final_objective_value = objective_fn(module_variables_final)

  # Test a 2D grid of alpha and sigma values.
  # sigma starts at 1e-16 (not 0) so the perturbation stddev is never zero.
  float_zero = tf.cast(0, tf.float32)
  alphas, sigmas = tf.meshgrid(
      tf.linspace(float_zero, 1, alpha_grid_size + 1),
      tf.linspace(float_zero + 1e-16, 1, sigma_grid_size + 1) * sigma_ratio,
  )
  # Flatten the grid so each (alpha, sigma) pair is one element for map_fn.
  alphas, sigmas = tf.reshape(alphas, [-1]), tf.reshape(sigmas, [-1])

  def _evaluate_alpha_sigma(alpha_sigma):
    # Evaluate a single (alpha, sigma) grid point: interpolate between the
    # initial and final parameters and apply sampled perturbations.
    alpha, sigma = alpha_sigma
    return _interpolate_and_perturb(
        alpha=alpha,
        sigma=sigma,
        params_init=module_variables_init,
        params_final=module_variables_final,
        objective_fn=objective_fn,
        loss_threshold_condition=functools.partial(
            loss_threshold_condition, reference_error=final_objective_value),
        normalize_error=normalize_error,
        num_samples_per_iteration=num_samples_per_iteration,
    )

  (threshold_conditions, interpolated_and_perturbed_losses,
   interpolated_and_perturbed_norms) = tf.map_fn(
       _evaluate_alpha_sigma,
       elems=(alphas, sigmas),
       dtype=(tf.bool, tf.float32, tf.float32),
   )

  # Exclude grid points that failed the loss-threshold condition by setting
  # their norms to +inf, then select the point with the smallest norm.
  masked_interpolated_and_perturbed_norms = tf.where(
      threshold_conditions,
      interpolated_and_perturbed_norms,
      tf.ones_like(interpolated_and_perturbed_norms) * np.inf)
  idx_min = tf.math.argmin(masked_interpolated_and_perturbed_norms)
  (loss_final, norm_final, alpha_final, sigma_final) = (
      interpolated_and_perturbed_losses[idx_min],
      interpolated_and_perturbed_norms[idx_min],
      alphas[idx_min],
      sigmas[idx_min])

  return ModuleCriticalityAnalysis(
      criticality_score=norm_final,
      alpha=alpha_final,
      sigma=sigma_final,
      loss_value=loss_final,
      num_samples_per_iteration=num_samples_per_iteration,
      alpha_grid_size=alpha_grid_size,
      sigma_grid_size=sigma_grid_size,
      sigma_ratio=sigma_ratio,
      initial_objective_value=initial_objective_value,
      final_objective_value=final_objective_value,
  )
def prepare_scannet_frame_dataset(inputs,
                                  min_pixel_depth=0.3,
                                  max_pixel_depth=6.0,
                                  valid_object_classes=None):
  """Maps the fields from loaded input to standard fields.

  Unprojects every depth-image pixel to a 3D point (via the camera
  intrinsics/extrinsics), attaches per-point colors and, when present,
  semantic class and instance labels, and filters all point fields by a
  depth validity mask.

  Args:
    inputs: A dictionary of input tensors.
    min_pixel_depth: Pixels with depth values less than this are pruned.
    max_pixel_depth: Pixels with depth values more than this are pruned.
    valid_object_classes: List of valid object classes. if None, it is
      ignored.

  Returns:
    A dictionary of input tensors with standard field names.

  Raises:
    ValueError: If any required camera field is missing from `inputs`.
  """
  prepared_inputs = {}
  # Validate that all required camera fields are present before any use.
  if 'cameras/rgbd_camera/intrinsics/K' not in inputs:
    raise ValueError('Intrinsic matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/R' not in inputs:
    raise ValueError('Extrinsic rotation matrix is missing.')
  if 'cameras/rgbd_camera/extrinsics/t' not in inputs:
    raise ValueError('Extrinsics translation is missing.')
  if 'cameras/rgbd_camera/depth_image' not in inputs:
    raise ValueError('Depth image is missing.')
  if 'cameras/rgbd_camera/color_image' not in inputs:
    raise ValueError('Color image is missing.')
  if 'frame_name' in inputs:
    prepared_inputs[standard_fields.InputDataFields
                    .camera_image_name] = inputs['frame_name']
  camera_intrinsics = inputs['cameras/rgbd_camera/intrinsics/K']
  depth_image = inputs['cameras/rgbd_camera/depth_image']
  image_height = tf.shape(depth_image)[0]
  image_width = tf.shape(depth_image)[1]
  # Pixel-center coordinates (+0.5) for every pixel, flattened to [N, 1].
  x, y = tf.meshgrid(
      tf.range(image_width), tf.range(image_height), indexing='xy')
  x = tf.reshape(tf.cast(x, dtype=tf.float32) + 0.5, [-1, 1])
  y = tf.reshape(tf.cast(y, dtype=tf.float32) + 0.5, [-1, 1])
  point_positions = projections.image_frame_to_camera_frame(
      image_frame=tf.concat([x, y], axis=1),
      camera_intrinsics=camera_intrinsics)
  rotate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/R']
  translate_world_to_camera = inputs['cameras/rgbd_camera/extrinsics/t']
  point_positions = projections.to_world_frame(
      camera_frame_points=point_positions,
      rotate_world_to_camera=rotate_world_to_camera,
      translate_world_to_camera=translate_world_to_camera)
  # NOTE(review): depth scaling is applied AFTER the world-frame transform;
  # presumably `projections.to_world_frame` returns rays this is valid for —
  # confirm against the projections module.
  prepared_inputs[standard_fields.InputDataFields
                  .point_positions] = point_positions * tf.reshape(
                      depth_image, [-1, 1])
  depth_values = tf.reshape(depth_image, [-1])
  # Depth validity mask: keep points with depth in
  # [min_pixel_depth, max_pixel_depth].
  valid_depth_mask = tf.logical_and(
      tf.greater_equal(depth_values, min_pixel_depth),
      tf.less_equal(depth_values, max_pixel_depth))
  # Colors reshaped to [N, 3] and rescaled from [0, 255] to [-1, 1].
  prepared_inputs[standard_fields.InputDataFields.point_colors] = tf.reshape(
      tf.cast(inputs['cameras/rgbd_camera/color_image'], dtype=tf.float32),
      [-1, 3])
  prepared_inputs[standard_fields.InputDataFields.point_colors] *= (2.0 /
                                                                    255.0)
  prepared_inputs[standard_fields.InputDataFields.point_colors] -= 1.0
  # Drop points that fail the depth validity mask.
  prepared_inputs[
      standard_fields.InputDataFields.point_positions] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_positions],
          valid_depth_mask)
  prepared_inputs[
      standard_fields.InputDataFields.point_colors] = tf.boolean_mask(
          prepared_inputs[standard_fields.InputDataFields.point_colors],
          valid_depth_mask)
  # Optional per-point semantic class labels, filtered by the same mask.
  if 'cameras/rgbd_camera/semantic_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/semantic_image'], [-1, 1]),
            dtype=tf.int32)
    prepared_inputs[
        standard_fields.InputDataFields.object_class_points] = tf.boolean_mask(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            valid_depth_mask)
  # Optional per-point instance ids, filtered by the same mask.
  if 'cameras/rgbd_camera/instance_image' in inputs:
    prepared_inputs[
        standard_fields.InputDataFields.object_instance_id_points] = tf.cast(
            tf.reshape(inputs['cameras/rgbd_camera/instance_image'], [-1]),
            dtype=tf.int32)
    prepared_inputs[standard_fields.InputDataFields
                    .object_instance_id_points] = tf.boolean_mask(
                        prepared_inputs[standard_fields.InputDataFields
                                        .object_instance_id_points],
                        valid_depth_mask)
  if valid_object_classes is not None:
    # Zero out class labels that are not in valid_object_classes by building
    # a boolean "is any valid class" mask and multiplying into the labels.
    # NOTE(review): this assumes class id 0 means "invalid/background" —
    # confirm against the label map.
    valid_objects_mask = tf.cast(
        tf.zeros_like(
            prepared_inputs[
                standard_fields.InputDataFields.object_class_points],
            dtype=tf.int32),
        dtype=tf.bool)
    for object_class in valid_object_classes:
      valid_objects_mask = tf.logical_or(
          valid_objects_mask,
          tf.equal(
              prepared_inputs[
                  standard_fields.InputDataFields.object_class_points],
              object_class))
    valid_objects_mask = tf.cast(
        valid_objects_mask,
        dtype=prepared_inputs[
            standard_fields.InputDataFields.object_class_points].dtype)
    prepared_inputs[standard_fields.InputDataFields
                    .object_class_points] *= valid_objects_mask
  return prepared_inputs