def maybe_step(accepted, diagnostics, iterand, solver_internal_state):
  """Takes a single step only if the outcome has a low enough error.

  Attempts one step of size `step_size` at the current `order`:
  1. Applies any step-size change requested by the previous attempt.
  2. Refreshes the factorization used by Newton's method (and, in lazy mode,
     possibly the Jacobian).
  3. Runs `bdf_util.newton`.

  The step is accepted only when Newton converged and the resulting error
  ratio is below 1. Proposals for the next attempt (step size, whether to
  refresh the Jacobian, possibly a new order) are returned in `iterand`.

  This is a closure. It reads the following names from the enclosing scope:
  `max_num_steps`, `self`, `jacobian_fn_mat`, `newton_coefficients_array`,
  `error_coefficients_array`, `atol`, `rtol`, `newton_tol_factor`,
  `max_num_newton_iters`, `ode_fn_vec`, `newton_step_size_factor`,
  `safety_factor`, `min_step_size_factor`, `max_step_size_factor` and
  `max_order`.

  Args:
    accepted: Unused input; it is recomputed by this step's own acceptance
      test. (Presumably kept so the signature matches the caller's loop
      variables.)
    diagnostics: 4-sequence `(num_jacobian_evaluations,
      num_matrix_factorizations, num_ode_fn_evaluations, status)`.
    iterand: 10-sequence `(jacobian_mat, jacobian_is_up_to_date,
      new_step_size, num_steps, num_steps_same_size, should_update_jacobian,
      should_update_step_size, time, unitary, upper)`.
    solver_internal_state: 3-sequence `(backward_differences, order,
      step_size)`.

  Returns:
    `(accepted, diagnostics, iterand, solver_internal_state)`:
    - `accepted` is `error_ratio < 1.`, which is False whenever Newton did
      not converge, because the ratio is NaN then.
    - The other three values are rebuilt as `_BDFDiagnostics`, `_BDFIterand`
      and `_BDFSolverInternalState`.
  """
  [
      num_jacobian_evaluations, num_matrix_factorizations,
      num_ode_fn_evaluations, status
  ] = diagnostics
  [
      jacobian_mat, jacobian_is_up_to_date, new_step_size, num_steps,
      num_steps_same_size, should_update_jacobian, should_update_step_size,
      time, unitary, upper
  ] = iterand
  [backward_differences, order, step_size] = solver_internal_state

  # status is -1 once the step counter reaches `max_num_steps`, else 0.
  if max_num_steps is not None:
    status = tf1.where(tf.equal(num_steps, max_num_steps), -1, 0)

  # Apply a step-size change requested by the previous attempt: rescale the
  # backward differences to the new step size, then adopt it.
  backward_differences = tf1.where(
      should_update_step_size,
      bdf_util.interpolate_backward_differences(
          backward_differences, order, new_step_size / step_size),
      backward_differences)
  step_size = tf1.where(should_update_step_size, new_step_size, step_size)
  # The factorization depends on `step_size`, so a new step size requires a
  # new factorization. It also restarts the count of same-size steps.
  should_update_factorization = should_update_step_size
  num_steps_same_size = tf1.where(should_update_step_size, 0,
                                  num_steps_same_size)

  def factorize(jacobian):
    # The Jacobian is passed explicitly. Reading `jacobian_mat` from this
    # closure inside a `tf.cond` branch would see the value from *before* the
    # cond returns, so a freshly evaluated Jacobian would be silently ignored.
    return bdf_util.newton_qr(
        jacobian, newton_coefficients_array.read(order), step_size)

  if self._evaluate_jacobian_lazily:

    def update_jacobian_and_factorization():
      new_jacobian_mat = jacobian_fn_mat(time, backward_differences[0])
      # Factorize the freshly evaluated Jacobian, not the stale one.
      new_unitary, new_upper = factorize(new_jacobian_mat)
      return [
          new_jacobian_mat, True, num_jacobian_evaluations + 1, new_unitary,
          new_upper
      ]

    def maybe_update_factorization():
      new_unitary, new_upper = tf.cond(
          should_update_factorization,
          lambda: factorize(jacobian_mat),
          lambda: [unitary, upper])
      return [
          jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations,
          new_unitary, new_upper
      ]

    # NOTE(review): on this lazy path `num_matrix_factorizations` is not
    # incremented, even though either branch may factorize. Only the eager
    # path below counts factorizations. Confirm whether that is intended.
    [
        jacobian_mat, jacobian_is_up_to_date, num_jacobian_evaluations,
        unitary, upper
    ] = tf.cond(should_update_jacobian, update_jacobian_and_factorization,
                maybe_update_factorization)
  else:
    unitary, upper = factorize(jacobian_mat)
    num_matrix_factorizations += 1

  tol = atol + rtol * tf.abs(backward_differences[0])
  newton_tol = newton_tol_factor * tf.norm(tol)

  [
      newton_converged, next_backward_difference, next_state_vec,
      newton_num_iters
  ] = bdf_util.newton(backward_differences, max_num_newton_iters,
                      newton_coefficients_array.read(order), ode_fn_vec,
                      order, step_size, time, newton_tol, unitary, upper)
  num_steps += 1
  num_ode_fn_evaluations += newton_num_iters

  # If Newton's method failed and the Jacobian was up to date, decrease the
  # step size.
  newton_failed = tf.logical_not(newton_converged)
  should_update_step_size = newton_failed & jacobian_is_up_to_date
  new_step_size = step_size * tf1.where(should_update_step_size,
                                        newton_step_size_factor, 1.)

  # If Newton's method failed and the Jacobian was NOT up to date, update
  # the Jacobian.
  should_update_jacobian = newton_failed & tf.logical_not(
      jacobian_is_up_to_date)

  # NaN when Newton failed, so the `< 1.` acceptance test below is False.
  error_ratio = tf1.where(
      newton_converged,
      bdf_util.error_ratio(next_backward_difference,
                           error_coefficients_array.read(order), tol),
      np.nan)
  accepted = error_ratio < 1.
  converged_and_rejected = newton_converged & tf.logical_not(accepted)

  # If Newton's method converged but the solution was NOT accepted, decrease
  # the step size.
  new_step_size = tf1.where(
      converged_and_rejected,
      util.next_step_size(step_size, order, error_ratio, safety_factor,
                          min_step_size_factor, max_step_size_factor),
      new_step_size)
  should_update_step_size = should_update_step_size | converged_and_rejected

  # If Newton's method converged and the solution was accepted, update the
  # matrix of backward differences.
  time = tf1.where(accepted, time + step_size, time)
  backward_differences = tf1.where(
      accepted,
      bdf_util.update_backward_differences(backward_differences,
                                           next_backward_difference,
                                           next_state_vec, order),
      backward_differences)
  jacobian_is_up_to_date = jacobian_is_up_to_date & tf.logical_not(accepted)
  num_steps_same_size = tf1.where(accepted, num_steps_same_size + 1,
                                  num_steps_same_size)

  # Order and step size are only updated if we have taken strictly more than
  # order + 1 steps of the same size. This is to prevent the order from
  # being throttled.
  should_update_order_and_step_size = accepted & (
      num_steps_same_size > order + 1)

  # Random-access view of the backward differences, so that the rows for the
  # neighboring orders can be read with a dynamic index.
  backward_differences_array = tf.TensorArray(
      backward_differences.dtype,
      size=bdf_util.MAX_ORDER + 3,
      clear_after_read=False,
      element_shape=next_backward_difference.get_shape()).unstack(
          backward_differences)
  new_order = order
  new_error_ratio = error_ratio
  # Try order - 1 and order + 1 (clipped to [1, max_order]). Keep whichever
  # candidate gives the lowest error ratio.
  for offset in [-1, +1]:
    proposed_order = tf.clip_by_value(order + offset, 1, max_order)
    proposed_error_ratio = bdf_util.error_ratio(
        backward_differences_array.read(proposed_order + 1),
        error_coefficients_array.read(proposed_order), tol)
    proposed_error_ratio_is_lower = proposed_error_ratio < new_error_ratio
    new_order = tf1.where(
        should_update_order_and_step_size & proposed_error_ratio_is_lower,
        proposed_order, new_order)
    new_error_ratio = tf1.where(
        should_update_order_and_step_size & proposed_error_ratio_is_lower,
        proposed_error_ratio, new_error_ratio)
  order = new_order
  error_ratio = new_error_ratio

  new_step_size = tf1.where(
      should_update_order_and_step_size,
      util.next_step_size(step_size, order, error_ratio, safety_factor,
                          min_step_size_factor, max_step_size_factor),
      new_step_size)
  should_update_step_size = (
      should_update_step_size | should_update_order_and_step_size)

  diagnostics = _BDFDiagnostics(num_jacobian_evaluations,
                                num_matrix_factorizations,
                                num_ode_fn_evaluations, status)
  iterand = _BDFIterand(jacobian_mat, jacobian_is_up_to_date, new_step_size,
                        num_steps, num_steps_same_size, should_update_jacobian,
                        should_update_step_size, time, unitary, upper)
  solver_internal_state = _BDFSolverInternalState(backward_differences, order,
                                                  step_size)
  return accepted, diagnostics, iterand, solver_internal_state
def update(self,
           expert_dataset_iter,
           policy_dataset_iter,
           discount,
           replay_regularization=0.05,
           nu_reg=10.0):
  """Runs one joint optimization step for the nu network and the actor.

  With a non-zero `replay_regularization` the estimated ratio is the mixture
    (d_pi * (1 - replay_regularization) + d_rb * replay_regularization) /
    (d_expert * (1 - replay_regularization) + d_rb * replay_regularization)
  instead of d_pi / d_expert.

  Args:
    expert_dataset_iter: A TensorFlow graph iterator over expert data, yielding
      `(states, actions, next_states)`.
    policy_dataset_iter: A TensorFlow graph iterator over training-policy
      (replay buffer) data, used for regularization.
    discount: The MDP discount factor.
    replay_regularization: Fraction of samples taken from the replay buffer.
    nu_reg: Coefficient of the gradient penalty on the nu network.
  """
  expert_s, expert_a, expert_s_next = expert_dataset_iter.get_next()
  # Expert states also serve as the initial states of the linear term.
  expert_s_init = expert_s

  rb_s, rb_a, rb_s_next, _, _ = policy_dataset_iter.get_next()[0]

  with tf.GradientTape(
      watch_accessed_variables=False, persistent=True) as outer_tape:
    outer_tape.watch(self.actor.variables)
    outer_tape.watch(self.nu_net.variables)

    # Actor calls keep their original order (next expert, next replay,
    # initial expert).
    _, expert_a_next, _ = self.actor(expert_s_next)
    _, rb_a_next, rb_log_prob = self.actor(rb_s_next)
    _, expert_a_init, _ = self.actor(expert_s_init)

    # Inputs for the linear part of DualDICE loss.
    expert_init_sa = tf.concat([expert_s_init, expert_a_init], 1)
    expert_sa = tf.concat([expert_s, expert_a], 1)
    expert_next_sa = tf.concat([expert_s_next, expert_a_next], 1)
    rb_sa = tf.concat([rb_s, rb_a], 1)
    rb_next_sa = tf.concat([rb_s_next, rb_a_next], 1)

    expert_nu_0 = self.nu_net(expert_init_sa)
    expert_nu = self.nu_net(expert_sa)
    expert_nu_next = self.nu_net(expert_next_sa)
    rb_nu = self.nu_net(rb_sa)
    rb_nu_next = self.nu_net(rb_next_sa)

    expert_td = expert_nu - discount * expert_nu_next
    rb_td = rb_nu - discount * rb_nu_next

    linear_loss_expert = tf.reduce_mean(expert_nu_0 * (1 - discount))
    linear_loss_rb = tf.reduce_mean(rb_td)

    # Non-linear term: weighted softmax over expert and replay samples, with
    # weights normalized to sum to one. The softmax weights are treated as
    # constants for differentiation.
    all_td = tf.concat([expert_td, rb_td], 0)
    mixture_weights = tf.concat([
        tf.ones(expert_td.shape) * (1 - replay_regularization),
        tf.ones(rb_td.shape) * replay_regularization
    ], 0)
    mixture_weights /= tf.reduce_sum(mixture_weights)
    softmax_weights = tf.stop_gradient(
        weighted_softmax(all_td, mixture_weights, axis=0))
    non_linear_loss = tf.reduce_sum(softmax_weights * all_td)

    linear_loss = (linear_loss_expert * (1 - replay_regularization) +
                   linear_loss_rb * replay_regularization)

    loss = non_linear_loss - linear_loss

    # Gradient penalty at random interpolations between expert and replay
    # inputs (both current and next state-action pairs).
    mix = tf.random.uniform(shape=(expert_sa.shape[0], 1))
    interp = tf.concat([
        mix * expert_sa + (1 - mix) * rb_sa,
        mix * expert_next_sa + (1 - mix) * rb_next_sa,
    ], 0)

    with tf.GradientTape(watch_accessed_variables=False) as inner_tape:
      inner_tape.watch(interp)
      nu_at_interp = self.nu_net(interp)
    nu_grad = inner_tape.gradient(nu_at_interp, [interp])[0] + EPS
    nu_grad_penalty = tf.reduce_mean(
        tf.square(tf.norm(nu_grad, axis=-1, keepdims=True) - 1))

    nu_loss = loss + nu_grad_penalty * nu_reg
    pi_loss = -loss + keras_utils.orthogonal_regularization(self.actor.trunk)

  nu_grads = outer_tape.gradient(nu_loss, self.nu_net.variables)
  pi_grads = outer_tape.gradient(pi_loss, self.actor.variables)

  self.nu_optimizer.apply_gradients(zip(nu_grads, self.nu_net.variables))
  self.actor_optimizer.apply_gradients(zip(pi_grads, self.actor.variables))

  del outer_tape

  self.avg_nu_expert(expert_nu)
  self.avg_nu_rb(rb_nu)
  self.nu_reg_metric(nu_grad_penalty)
  self.avg_loss(loss)
  self.avg_actor_loss(pi_loss)
  self.avg_actor_entropy(-rb_log_prob)

  def _flush(tag, metric, step):
    # Writes the metric's running value as a scalar summary, then resets it.
    tf.summary.scalar(tag, metric.result(), step=step)
    keras_utils.my_reset_states(metric)

  nu_step = self.nu_optimizer.iterations
  if tf.equal(nu_step % self.log_interval, 0):
    _flush('train dual dice/loss', self.avg_loss, nu_step)
    _flush('train dual dice/nu expert', self.avg_nu_expert, nu_step)
    _flush('train dual dice/nu rb', self.avg_nu_rb, nu_step)
    _flush('train dual dice/nu reg', self.nu_reg_metric, nu_step)

  actor_step = self.actor_optimizer.iterations
  if tf.equal(actor_step % self.log_interval, 0):
    _flush('train sac/actor_loss', self.avg_actor_loss, actor_step)
    _flush('train sac/actor entropy', self.avg_actor_entropy, actor_step)
def spherical_uniform(
    shape,
    dimension,
    dtype=tf.float32,
    seed=None,
    name=None):
  """Samples a `Tensor` uniformly from the unit sphere.

  Args:
    shape: Vector-shaped, `int` `Tensor` giving the batch shape of the output
      (the sphere coordinate is appended as a trailing dimension).
    dimension: Scalar `int` `Tensor`, the dimensionality of the space in which
      the sphere is embedded.
    dtype: (Optional) TF `dtype` of the output.
      Default value: `tf.float32`.
    seed: PRNG seed; see `tfp.random.sanitize_seed` for details.
      Default value: `None` (i.e., no seed).
    name: Python `str` name prefixed to Ops created by this function.
      Default value: `None` (i.e., 'spherical_uniform').

  Returns:
    spherical_uniform: `Tensor` of shape `concat([shape, [dimension]])` and the
      requested `dtype`, holding points drawn uniformly from the unit sphere.
  """
  with tf.name_scope(name or 'spherical_uniform'):
    seed = samplers.sanitize_seed(seed)
    dimension = ps.convert_to_shape_tensor(ps.cast(dimension, dtype=tf.int32))
    shape = ps.convert_to_shape_tensor(shape, dtype=tf.int32)
    static_dim = tf.get_static_value(dimension)
    full_shape = ps.convert_to_shape_tensor(
        ps.concat([shape, [dimension]], axis=0))

    # Normalizing Gaussian draws can divide by a zero norm in one and two
    # dimensions, so those sizes are sampled directly instead.
    if static_dim is not None:
      if static_dim == 1:
        # On the 0-sphere this is just a random sign (Rademacher).
        return rademacher(full_shape, dtype=dtype, seed=seed)
      if static_dim == 2:
        angle = samplers.uniform(
            shape, minval=0, maxval=2 * np.pi, dtype=dtype, seed=seed)
        return tf.stack([tf.math.cos(angle), tf.math.sin(angle)], axis=-1)
      gaussian = samplers.normal(
          shape=ps.concat([shape, [static_dim]], axis=0),
          seed=seed,
          dtype=dtype)
      return gaussian / tf.norm(gaussian, ord=2, axis=-1)[..., tf.newaxis]

    # The dimension is only known at run time. Build every candidate and pick
    # the right one with tf.where.
    rademacher_seed, angle_seed, gaussian_seed = samplers.split_seed(
        seed, n=3, salt='spherical_uniform_dynamic_shape')
    sign_samples = rademacher(full_shape, dtype=dtype, seed=rademacher_seed)
    angle = samplers.uniform(
        shape, minval=0, maxval=2 * np.pi, dtype=dtype, seed=angle_seed)
    # The sine column is multiplied by ones([dimension - 1]) so the circle
    # candidate has a trailing size of `dimension`, matching the others for
    # tf.where. It is only selected when dimension == 2.
    cos_column = tf.math.cos(angle)[..., tf.newaxis]
    sin_columns = (tf.math.sin(angle)[..., tf.newaxis] *
                   tf.ones([dimension - 1], dtype=dtype))
    circle_samples = tf.concat([cos_column, sin_columns], axis=-1)
    gaussian = samplers.normal(
        shape=ps.concat([shape, [dimension]], axis=0),
        seed=gaussian_seed,
        dtype=dtype)
    sphere_samples = gaussian / tf.norm(
        gaussian, ord=2, axis=-1)[..., tf.newaxis]
    is_two_dim = tf.math.equal(dimension, 2)
    is_one_dim = tf.math.equal(dimension, 1)
    return tf.where(
        is_one_dim,
        sign_samples,
        tf.where(is_two_dim, circle_samples, sphere_samples))
def minimize_one_step(gradient_unregularized_loss,
                      hessian_unregularized_loss_outer,
                      hessian_unregularized_loss_middle,
                      x_start,
                      tolerance,
                      l1_regularizer,
                      l2_regularizer=None,
                      maximum_full_sweeps=1,
                      learning_rate=None,
                      name=None):
  """One step of (the outer loop of) the minimization algorithm.

  This function returns a new value of `x`, equal to `x_start + x_update`. The
  increment `x_update in R^n` is computed by a coordinate descent method, that
  is, by a loop in which each iteration updates exactly one coordinate of
  `x_update`. (Some updates may leave the value of the coordinate unchanged.)

  The particular update method used is to apply an L1-based proximity operator,
  "soft threshold", whose fixed point `x_update_fix` is the desired minimum

  ```none
  x_update_fix = argmin{
      Loss(x_start + x_update')
        + l1_regularizer * ||x_start + x_update'||_1
        + l2_regularizer * ||x_start + x_update'||_2**2
      : x_update' }
  ```

  where in each iteration `x_update'` is constrained to have at most one nonzero
  coordinate.

  This update method preserves sparsity, i.e., tends to find sparse solutions if
  `x_start` is sparse. Additionally, the choice of step size is based on
  curvature (Hessian), which significantly speeds up convergence.

  This algorithm assumes that `Loss` is convex, at least in a region surrounding
  the optimum. (If `l2_regularizer > 0`, then only weak convexity is needed.)

  Args:
    gradient_unregularized_loss: (Batch of) `Tensor` with the same shape and
      dtype as `x_start` representing the gradient, evaluated at `x_start`, of
      the unregularized loss function (denoted `Loss` above). (In all current
      use cases, `Loss` is the negative log likelihood.)
    hessian_unregularized_loss_outer: (Batch of) `Tensor` or `SparseTensor`
      having the same dtype as `x_start`, and shape `[N, n]` where `x_start` has
      shape `[n]`, satisfying the property
      `Transpose(hessian_unregularized_loss_outer)
      @ diag(hessian_unregularized_loss_middle)
      @ hessian_unregularized_loss_outer
      = (approximation of) Hessian matrix of Loss, evaluated at x_start`.
    hessian_unregularized_loss_middle: (Batch of) vector-shaped `Tensor` having
      the same dtype as `x_start`, and shape `[N]` where
      `hessian_unregularized_loss_outer` has shape `[N, n]`, satisfying the
      property
      `Transpose(hessian_unregularized_loss_outer)
      @ diag(hessian_unregularized_loss_middle)
      @ hessian_unregularized_loss_outer
      = (approximation of) Hessian matrix of Loss, evaluated at x_start`.
    x_start: (Batch of) vector-shaped, `float` `Tensor` representing the current
      value of the argument to the Loss function.
    tolerance: scalar, `float` `Tensor` representing the convergence threshold.
      The optimization step will terminate early, returning its current value
      of `x_start + x_update`, once the following condition is met:
      `||x_update_end - x_update_start||_2 / (1 + ||x_start||_2)
      < sqrt(tolerance)`,
      where `x_update_end` is the value of `x_update` at the end of a sweep and
      `x_update_start` is the value of `x_update` at the beginning of that
      sweep.
    l1_regularizer: scalar, `float` `Tensor` representing the weight of the L1
      regularization term (see equation above). If L1 regularization is not
      required, then `tfp.glm.fit_one_step` is preferable.
    l2_regularizer: scalar, `float` `Tensor` representing the weight of the L2
      regularization term (see equation above).
      Default value: `None` (i.e., no L2 regularization).
    maximum_full_sweeps: Python integer specifying maximum number of sweeps to
      run. A "sweep" consists of an iteration of coordinate descent on each
      coordinate. After this many sweeps, the algorithm will terminate even if
      convergence has not been reached.
      Default value: `1`.
    learning_rate: scalar, `float` `Tensor` representing a multiplicative factor
      used to dampen the proximal gradient descent steps.
      Default value: `None` (i.e., factor is conceptually `1`).
    name: Python string representing the name of the TensorFlow operation.
      The default name is `"minimize_one_step"`.

  Returns:
    x: (Batch of) `Tensor` having the same shape and dtype as `x_start`,
      representing the updated value of `x`, that is, `x_start + x_update`.
    is_converged: scalar, `bool` `Tensor` indicating whether convergence
      occurred across all batches within the specified number of sweeps.
    iter: scalar `Tensor` equal to `num_updates / dims`, where `num_updates` is
      the actual number of coordinate updates made (before achieving
      convergence) and `dims` is the size of the last dimension of `x_start`.
      It is computed with true division, so it is the (possibly fractional)
      number of sweeps performed, not the raw update count. It is at most
      `maximum_full_sweeps`.

  #### References

  [1]: Jerome Friedman, Trevor Hastie and Rob Tibshirani. Regularization Paths
       for Generalized Linear Models via Coordinate Descent. _Journal of
       Statistical Software_, 33(1), 2010.
       https://www.jstatsoft.org/article/view/v033i01/v33i01.pdf

  [2]: Guo-Xun Yuan, Chia-Hua Ho and Chih-Jen Lin. An Improved GLMNET for
       L1-regularized Logistic Regression. _Journal of Machine Learning
       Research_, 13, 2012.
       http://www.jmlr.org/papers/volume13/yuan12a/yuan12a.pdf
  """
  with tf.name_scope(name or 'minimize_one_step'):
    x_shape = _get_shape(x_start)
    batch_shape = x_shape[:-1]
    dims = x_shape[-1]

    def _hessian_diag_elt_with_l2(coord):  # pylint: disable=missing-docstring
      # Returns the (coord, coord) entry of
      #
      #   Hessian(UnregularizedLoss(x) + l2_regularizer * ||x||_2**2)
      #
      # evaluated at x = x_start.
      #
      # With H = Transpose(outer) @ diag(middle) @ outer, the diagonal entry
      # is H[j, j] = sum_i middle[i] * outer[i, j]**2. `middle` has length N
      # (rows of `outer`), so it must weight the squared column row by row,
      # exactly as `hessian_column_with_l2` does in `_do_update` below. It
      # must not be indexed at `coord`, which is a column index in [0, n).
      outer_column = _sparse_or_dense_matmul_onehot(
          hessian_unregularized_loss_outer, coord)
      unregularized_component = tf.reduce_sum(
          hessian_unregularized_loss_middle * outer_column**2, axis=-1)
      l2_component = _mul_or_none(2., l2_regularizer)
      return _add_ignoring_nones(unregularized_component, l2_component)

    grad_loss_with_l2 = _add_ignoring_nones(
        gradient_unregularized_loss, _mul_or_none(2., l2_regularizer, x_start))

    # We define `x_update_diff_norm_sq_convergence_threshold` such that the
    # convergence condition
    #     ||x_update_end - x_update_start||_2 / (1 + ||x_start||_2)
    #     < sqrt(tolerance)
    # is equivalent to
    #     ||x_update_end - x_update_start||_2**2
    #     < x_update_diff_norm_sq_convergence_threshold.
    x_update_diff_norm_sq_convergence_threshold = (
        tolerance * (1. + tf.norm(tensor=x_start, ord=2, axis=-1))**2)

    # Reshape update vectors so that the coordinate sweeps happen along the
    # first dimension. This is so that we can use tensor_scatter_update to make
    # sparse updates along the first axis without copying the Tensor.
    # TODO(b/118789120): Switch to something like tf.tensor_scatter_nd_add if
    # or when it exists.
    update_shape = tf.concat([[dims], batch_shape], axis=-1)

    def _loop_cond(iter_, x_update_diff_norm_sq, x_update,
                   hess_matmul_x_update):
      del x_update
      del hess_matmul_x_update
      sweep_complete = (iter_ > 0) & tf.equal(iter_ % dims, 0)
      small_delta = (
          x_update_diff_norm_sq < x_update_diff_norm_sq_convergence_threshold)
      converged = sweep_complete & small_delta
      allowed_more_iterations = iter_ < maximum_full_sweeps * dims
      return allowed_more_iterations & tf.reduce_any(~converged)

    def _loop_body(  # pylint: disable=missing-docstring
        iter_, x_update_diff_norm_sq, x_update, hess_matmul_x_update):
      # Inner loop of the minimizer.
      #
      # This loop updates a single coordinate of x_update. Ideally, an
      # iteration of this loop would set
      #
      #   x_update[j] += argmin{ LocalLoss(x_update + z*e_j) : z in R }
      #
      # where
      #
      #   LocalLoss(x_update')
      #     = LocalLossSmoothComponent(x_update')
      #         + l1_regularizer * (||x_start + x_update'||_1 -
      #                             ||x_start + x_update||_1)
      #    := (UnregularizedLoss(x_start + x_update') -
      #        UnregularizedLoss(x_start + x_update)
      #         + l2_regularizer * (||x_start + x_update'||_2**2 -
      #                             ||x_start + x_update||_2**2)
      #         + l1_regularizer * (||x_start + x_update'||_1 -
      #                             ||x_start + x_update||_1)
      #
      # In this algorithm approximate the above argmin using (univariate)
      # proximal gradient descent:
      #
      # (*)  x_update[j] = prox_{t * l1_regularizer * L1}(
      #                 x_update[j] -
      #                 t * d/dz|z=0 UnivariateLocalLossSmoothComponent(z))
      #
      # where
      #
      #   UnivariateLocalLossSmoothComponent(z)
      #       := LocalLossSmoothComponent(x_update + z*e_j)
      #
      # and we approximate
      #
      #       d/dz UnivariateLocalLossSmoothComponent(z)
      #     = grad LocalLossSmoothComponent(x_update))[j]
      #    ~= (grad LossSmoothComponent(x_start)
      #         + x_update matmul HessianOfLossSmoothComponent(x_start))[j].
      #
      # To choose the parameter t, we squint and pretend that the inner term of
      # (*) is a Newton update as if we were using Newton's method to minimize
      # UnivariateLocalLossSmoothComponent. That is, we choose t such that
      #
      #   -t * d/dz ULLSC = -learning_rate * (d/dz ULLSC) / (d^2/dz^2 ULLSC)
      #
      # at z=0. Hence
      #
      #   t = learning_rate / (d^2/dz^2|z=0 ULLSC)
      #     = learning_rate / HessianOfLossSmoothComponent(
      #                           x_start + x_update)[j,j]
      #    ~= learning_rate / HessianOfLossSmoothComponent(
      #                           x_start)[j,j]
      #
      # The above approximation is equivalent to assuming that
      # HessianOfUnregularizedLoss is constant, i.e., ignoring third-order
      # effects.
      #
      # Note that because LossSmoothComponent is (assumed to be) convex, t is
      # positive.

      # In above notation, coord = j.
      coord = iter_ % dims
      # x_update_diff_norm_sq := ||x_update_end - x_update_start||_2**2,
      # computed incrementally, where x_update_end and x_update_start are as
      # defined in the convergence criteria. Accordingly, we reset
      # x_update_diff_norm_sq to zero at the beginning of each sweep.
      x_update_diff_norm_sq = tf.where(
          tf.equal(coord, 0),
          dtype_util.as_numpy_dtype(x_update_diff_norm_sq.dtype)(0.),
          x_update_diff_norm_sq)

      # Recall that x_update and hess_matmul_x_update has the rightmost
      # dimension transposed to the leftmost dimension.
      w_old = x_start[..., coord] + x_update[coord, ...]
      # This is the coordinatewise Newton update if no L1 regularization.
      # In above notation, newton_step = -t * (approximation of d/dz|z=0 ULLSC).
      second_deriv = _hessian_diag_elt_with_l2(coord)
      newton_step = -_mul_ignoring_nones(  # pylint: disable=invalid-unary-operand-type
          learning_rate, grad_loss_with_l2[..., coord] +
          hess_matmul_x_update[coord, ...]) / second_deriv

      # Applying the soft-threshold operator accounts for L1 regularization.
      # In above notation, delta =
      #     prox_{t*l1_regularizer*L1}(w_old + newton_step) - w_old.
      delta = (
          soft_threshold(
              w_old + newton_step,
              _mul_ignoring_nones(learning_rate, l1_regularizer) /
              second_deriv) - w_old)

      def _do_update(x_update_diff_norm_sq, x_update,
                     hess_matmul_x_update):  # pylint: disable=missing-docstring
        hessian_column_with_l2 = sparse_or_dense_matvecmul(
            hessian_unregularized_loss_outer,
            hessian_unregularized_loss_middle * _sparse_or_dense_matmul_onehot(
                hessian_unregularized_loss_outer, coord),
            adjoint_a=True)

        if l2_regularizer is not None:
          hessian_column_with_l2 += _one_hot_like(
              hessian_column_with_l2, coord, on_value=2. * l2_regularizer)

        # Move the batch dimensions of `hessian_column_with_l2` to rightmost in
        # order to conform to `hess_matmul_x_update`.
        n = tf.rank(hessian_column_with_l2)
        perm = tf.roll(tf.range(n), shift=1, axis=0)
        hessian_column_with_l2 = tf.transpose(
            a=hessian_column_with_l2, perm=perm)

        # Update the entire batch at `coord` even if `delta` may be 0 at some
        # batch coordinates. In those cases, adding `delta` is a no-op.
        x_update = tf.tensor_scatter_nd_add(x_update, [[coord]], [delta])

        with tf.control_dependencies([x_update]):
          x_update_diff_norm_sq_ = x_update_diff_norm_sq + delta**2
          hess_matmul_x_update_ = (
              hess_matmul_x_update + delta * hessian_column_with_l2)

          # Hint that loop vars retain the same shape.
          x_update_diff_norm_sq_.set_shape(
              x_update_diff_norm_sq_.shape.merge_with(
                  x_update_diff_norm_sq.shape))
          hess_matmul_x_update_.set_shape(
              hess_matmul_x_update_.shape.merge_with(
                  hess_matmul_x_update.shape))

          return [x_update_diff_norm_sq_, x_update, hess_matmul_x_update_]

      inputs_to_update = [x_update_diff_norm_sq, x_update, hess_matmul_x_update]
      return [iter_ + 1] + prefer_static.cond(
          # Note on why checking delta (a difference of floats) for equality to
          # zero is ok:
          #
          # First of all, x - x == 0 in floating point -- see
          # https://stackoverflow.com/a/2686671
          #
          # Delta will conceptually equal zero when one of the following holds:
          # (i)   |w_old + newton_step| <= threshold and w_old == 0
          # (ii)  |w_old + newton_step| > threshold and
          #       w_old + newton_step - sign(w_old + newton_step) * threshold
          #          == w_old
          #
          # In case (i) comparing delta to zero is fine.
          #
          # In case (ii), newton_step conceptually equals
          #     sign(w_old + newton_step) * threshold.
          # Also remember
          #     threshold = -newton_step / (approximation of d/dz|z=0 ULLSC).
          # So (ii) happens when
          #     (approximation of d/dz|z=0 ULLSC) == -sign(w_old + newton_step).
          # If we did not require LossSmoothComponent to be strictly convex,
          # then this could actually happen a non-negligible amount of the time,
          # e.g. if the loss function is piecewise linear and one of the pieces
          # has slope 1. But since LossSmoothComponent is strictly convex, (ii)
          # should not systematically happen.
          tf.reduce_all(tf.equal(delta, 0.)),
          lambda: inputs_to_update,
          lambda: _do_update(*inputs_to_update))

    base_dtype = x_start.dtype.base_dtype
    iter_, x_update_diff_norm_sq, x_update, _ = tf.while_loop(
        cond=_loop_cond,
        body=_loop_body,
        loop_vars=[
            tf.zeros([], dtype=np.int32, name='iter'),
            tf.zeros(
                batch_shape, dtype=base_dtype, name='x_update_diff_norm_sq'),
            tf.zeros(update_shape, dtype=base_dtype, name='x_update'),
            tf.zeros(
                update_shape, dtype=base_dtype, name='hess_matmul_x_update'),
        ])

    # Convert back x_update to the shape of x_start by transposing the leftmost
    # dimension to the rightmost.
    n = tf.rank(x_update)
    perm = tf.roll(tf.range(n), shift=-1, axis=0)
    x_update = tf.transpose(a=x_update, perm=perm)

    converged = tf.reduce_all(
        x_update_diff_norm_sq < x_update_diff_norm_sq_convergence_threshold)
    # Reported as sweeps (true division by `dims`), not raw coordinate updates.
    return x_start + x_update, converged, iter_ / dims
def main(unused_args):
  """Trains the model `u` and the sampler `q`, logging to `FLAGS.logdir`.

  `u` is updated by `train_p`/`train_p_mh` (selected by `FLAGS.p_loss`) and
  `q` by one of the `train_q_*` routines (selected by `FLAGS.q_loss`). Sample
  grids, stat plots and a text log are written periodically
  (`FLAGS.plot_steps`). A checkpoint is written every `FLAGS.save_steps`
  steps, and an existing checkpoint is restored at startup.

  Raises:
    ValueError: If `FLAGS.dataset`, `FLAGS.q_type`, `FLAGS.p_loss` or
      `FLAGS.q_loss` has an unsupported value.
  """
  del unused_args

  #
  # General setup.
  #

  ebm_util.init_tf2()

  ebm_util.set_seed(FLAGS.seed)

  output_dir = FLAGS.logdir
  checkpoint_dir = os.path.join(output_dir, 'checkpoint')
  samples_dir = os.path.join(output_dir, 'samples')

  tf.io.gfile.makedirs(samples_dir)
  tf.io.gfile.makedirs(checkpoint_dir)

  log_f = tf.io.gfile.GFile(os.path.join(output_dir, 'log.out'), mode='w')
  logger = ebm_util.setup_logging('main', log_f, console=False)
  logger.info({k: v._value for (k, v) in FLAGS._flags().items()})  # pylint: disable=protected-access

  #
  # Data
  #

  if FLAGS.dataset == 'mnist':
    x_train = ebm_util.mnist_dataset(N_CH)
  elif FLAGS.dataset == 'celeba':
    x_train = ebm_util.celeba_dataset()
  else:
    raise ValueError(f'Unknown dataset. {FLAGS.dataset}')
  train_ds = tf.data.Dataset.from_tensor_slices(x_train).shuffle(
      10000).batch(FLAGS.batch_size)

  #
  # Models
  #

  if FLAGS.q_type == 'mean_field_gaussian':
    q = MeanFieldGaussianQ()
  else:
    # Fail fast: any other value would leave `q` unbound and crash later.
    raise ValueError(f'Unknown Q type {FLAGS.q_type}')
  u = make_u()

  #
  # Optimizers
  #

  # Linear decay that reaches 1/3 of the base rate at `FLAGS.train_steps`.
  def lr_p(step):
    return FLAGS.p_learning_rate * (1. - (step / (1.5 * FLAGS.train_steps)))

  def lr_q(step):
    return FLAGS.q_learning_rate * (1. - (step / (1.5 * FLAGS.train_steps)))

  opt_q = tf.optimizers.Adam(learning_rate=ebm_util.LambdaLr(lr_q))
  opt_p = tf.optimizers.Adam(
      learning_rate=ebm_util.LambdaLr(lr_p), beta_1=FLAGS.p_adam_beta_1)

  #
  # Checkpointing
  #

  global_step_var = tf.Variable(0, trainable=False)
  checkpoint = tf.train.Checkpoint(
      opt_p=opt_p, opt_q=opt_q, u=u, q=q, global_step_var=global_step_var)

  checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint')
  if tf.io.gfile.exists(checkpoint_path + '.index'):
    print(f'Restoring from {checkpoint_path}')
    checkpoint.restore(checkpoint_path)

  #
  # Stats initialization
  #

  stat_i = []
  stat_keys = [
      'E_pos',  # Mean energy of the positive samples.
      'E_neg_q',  # Mean energy of the negative samples (pre-HMC).
      'E_neg_p',  # Mean energy of the negative samples (post-HMC).
      'H',  # Entropy of Q (if known).
      'pd_pos',  # Pairwise differences of the positive samples.
      'pd_neg_q',  # Pairwise differences of the negative samples (pre-HMC).
      'pd_neg_p',  # Pairwise differences of the negative samples (post-HMC).
      'hmc_disp',  # L2 distance between initial and final HMC samples.
      'hmc_p_accept',  # HMC P(accept).
      'hmc_step_size',  # HMC step size.
      'x_neg_p_min',  # Minimum value of the negative samples (post-HMC).
      'x_neg_p_max',  # Maximum value of the negative samples (post-HMC).
      'time',  # Time taken to do the training step.
  ]
  stat = {k: [] for k in stat_keys}

  def array_to_str(a, fmt='{:>8.4f}'):
    # Flatten first so that scalars (e.g. the `tf.reduce_mean` norms logged
    # below) format as one value instead of failing on iteration.
    return ' '.join([fmt.format(v) for v in tf.reshape(a, [-1]).numpy()])

  def stats_callback(step, entropy, pd_neg_q):
    del step, entropy, pd_neg_q

  step_size = FLAGS.mcmc_step_size

  # Two fixed batches used for latent-space interpolation plots.
  train_ds_iter = iter(train_ds)
  x_pos_1 = ebm_util.data_preprocess(next(train_ds_iter))
  x_pos_2 = ebm_util.data_preprocess(next(train_ds_iter))

  global_step = global_step_var.numpy()

  while global_step < (FLAGS.train_steps + 1):
    for x_pos in train_ds:
      # Stop as soon as the step budget is exhausted rather than finishing the
      # epoch (the `while` condition is only checked between epochs).
      if global_step >= FLAGS.train_steps + 1:
        break

      # Drop partial batches.
      if x_pos.shape[0] != FLAGS.batch_size:
        continue

      #
      # Update
      #

      start_time = time.time()

      x_pos = ebm_util.data_preprocess(x_pos)
      x_pos = ebm_util.data_discrete_noise(x_pos)

      if FLAGS.p_loss == 'neutra_hmc':
        (x_neg_q, x_neg_p, p_accept, step_size, pos_e, pos_e_updated, neg_e_q,
         neg_e_p, neg_e_p_updated) = train_p(q, u, x_pos, step_size, opt_p)
      elif FLAGS.p_loss == 'neutra_iid':
        (x_neg_q, x_neg_p, p_accept, step_size, pos_e, pos_e_updated, neg_e_q,
         neg_e_p, neg_e_p_updated) = train_p_mh(q, u, x_pos, step_size, opt_p)
      else:
        raise ValueError(f'Unknown P loss {FLAGS.p_loss}')

      # Defaults for quantities only some Q losses produce. All of them are
      # logged below regardless of which loss is selected.
      entropy = 0.0
      mle_loss = 0.0
      norm_grads_ebm = 0.0
      norm_grads_mle = 0.0
      if FLAGS.q_loss == 'forward_kl':
        train_q_fwd_kl(q, x_neg_p, opt_q)
      elif FLAGS.q_loss == 'reverse_kl':
        for _ in range(10):
          _, entropy = train_q_rev_kl(q, u, opt_q)
      elif FLAGS.q_loss == 'reverse_kl_mle':
        alpha = FLAGS.q_rkl_weight
        for _ in range(FLAGS.q_sub_steps):
          (_, entropy, _, mle_loss, norm_grads_ebm,
           norm_grads_mle) = train_q_rev_kl_mle(
               q, u, x_pos, tf.convert_to_tensor(alpha), opt_q)
      elif FLAGS.q_loss == 'mle':
        mle_loss = train_q_mle(q, x_pos, opt_q)
      else:
        raise ValueError(f'Unknown Q loss {FLAGS.q_loss}')

      end_time = time.time()

      #
      # Stats
      #

      # Mean per-sample L2 displacement caused by the P-side sampler.
      hmc_disp = tf.reduce_mean(
          tf.norm(
              tf.reshape(x_neg_q, [FLAGS.batch_size, -1]) -
              tf.reshape(x_neg_p, [FLAGS.batch_size, -1]),
              axis=1))

      if global_step % FLAGS.plot_steps == 0:

        # Positives + negatives.
        ebm_util.plot(
            tf.reshape(
                ebm_util.data_postprocess(x_neg_q),
                [FLAGS.batch_size, N_WH, N_WH, N_CH]),
            os.path.join(samples_dir, f'x_neg_q_{global_step}.png'))
        ebm_util.plot(
            tf.reshape(
                ebm_util.data_postprocess(x_neg_p),
                [FLAGS.batch_size, N_WH, N_WH, N_CH]),
            os.path.join(samples_dir, f'x_neg_p_{global_step}.png'))
        ebm_util.plot(
            tf.reshape(
                ebm_util.data_postprocess(x_pos),
                [FLAGS.batch_size, N_WH, N_WH, N_CH]),
            os.path.join(samples_dir, f'x_pos_{global_step}.png'))

        # Samples for various temperatures.
        for t in [0.1, 0.5, 1.0, 2.0, 4.0]:
          _, x_neg_q_t, _ = q.sample_with_log_prob(FLAGS.batch_size, temp=t)
          ebm_util.plot(
              tf.reshape(
                  ebm_util.data_postprocess(x_neg_q_t),
                  [FLAGS.batch_size, N_WH, N_WH, N_CH]),
              os.path.join(samples_dir, f'x_neg_t_{t}_{global_step}.png'))

        stats_callback(global_step, entropy,
                       ebm_util.nearby_difference(x_neg_q))

        stat_i.append(global_step)
        stat['E_pos'].append(pos_e_updated)
        stat['E_neg_q'].append(neg_e_q)
        stat['E_neg_p'].append(neg_e_p)
        stat['H'].append(entropy)
        stat['pd_neg_q'].append(ebm_util.nearby_difference(x_neg_q))
        stat['pd_neg_p'].append(ebm_util.nearby_difference(x_neg_p))
        stat['pd_pos'].append(ebm_util.nearby_difference(x_pos))
        stat['hmc_disp'].append(hmc_disp)
        stat['hmc_p_accept'].append(p_accept)
        stat['hmc_step_size'].append(step_size)
        stat['x_neg_p_min'].append(tf.reduce_min(x_neg_p))
        stat['x_neg_p_max'].append(tf.reduce_max(x_neg_p))
        stat['time'].append(end_time - start_time)

        ebm_util.plot_stat(stat_keys, stat, stat_i, output_dir)

        # Doing a linear interpolation in the latent space.
        z_pos_1 = q.forward(x_pos_1)[0]
        z_pos_2 = q.forward(x_pos_2)[0]

        x_alphas = []
        n_steps = 10
        for j in range(0, n_steps + 1):
          alpha = (j / n_steps)
          z_alpha = (1. - alpha) * z_pos_1 + (alpha) * z_pos_2
          x_alpha = q.reverse(z_alpha)[0]
          x_alphas.append(x_alpha)

        ebm_util.plot_n_by_m(
            ebm_util.data_postprocess(
                tf.reshape(
                    tf.stack(x_alphas, axis=1),
                    [(n_steps + 1) * FLAGS.batch_size, N_WH, N_WH, N_CH])),
            os.path.join(samples_dir, f'x_alpha_{global_step}.png'),
            FLAGS.batch_size, n_steps + 1)

        # Doing random perturbations in the latent space.
        for eps in [1e-4, 1e-3, 1e-2, 1e-1, 1e0, 2e0, 2.5e0, 3e0]:
          z_pos_2_eps = z_pos_2 + eps * tf.random.normal(z_pos_2.shape)
          x_alpha = q.reverse(z_pos_2_eps)[0]
          ebm_util.plot(
              tf.reshape(
                  ebm_util.data_postprocess(x_alpha),
                  [FLAGS.batch_size, N_WH, N_WH, N_CH]),
              os.path.join(samples_dir,
                           f'x_alpha_eps_{eps}_{global_step}.png'))

        # Checking the log-probabilities of positive and negative examples
        # under Q.
        z_neg_test, x_neg_test, _ = q.sample_with_log_prob(
            FLAGS.batch_size, temp=FLAGS.q_temperature)
        z_pos_test = q.forward(x_pos)[0]

        z_neg_test_pd = ebm_util.nearby_difference(z_neg_test)
        z_pos_test_pd = ebm_util.nearby_difference(z_pos_test)

        z_norms_neg = tf.reduce_mean(tf.norm(z_neg_test, axis=1))
        z_norms_pos = tf.reduce_mean(tf.norm(z_pos_test, axis=1))

        log_prob_neg = tf.reduce_mean(q.log_prob(x_neg_test))
        log_prob_pos = tf.reduce_mean(q.log_prob(x_pos))

        logger.info(' '.join([
            f'i={global_step:6d}',
            # Pre-update, post-update
            (f'E_pos=[{pos_e:10.4f} {pos_e_updated:10.4f} ' +
             f'{pos_e_updated - pos_e:10.4f}]'),
            # Pre-update pre-HMC, pre-update post-HMC, post-update post-HMC
            (f'E_neg=[{neg_e_q:10.4f} {neg_e_p:10.4f} ' +
             f'{neg_e_p_updated:10.4f} {neg_e_p_updated - neg_e_p:10.4f}]'),
            f'mle={tf.reduce_mean(mle_loss):8.4f}',
            f'H={entropy:8.4f}',
            f'norm_grads_ebm={norm_grads_ebm:8.4f}',
            f'norm_grads_mle={norm_grads_mle:8.4f}',
            f'pd(x_pos)={ebm_util.nearby_difference(x_pos):8.4f}',
            f'pd(x_neg_q)={ebm_util.nearby_difference(x_neg_q):8.4f}',
            f'pd(x_neg_p)={ebm_util.nearby_difference(x_neg_p):8.4f}',
            f'hmc_disp={hmc_disp:8.4f}',
            f'p(accept)={p_accept:8.4f}',
            f'step_size={step_size:8.4f}',
            # Min, max.
            (f'x_neg_q=[{tf.reduce_min(x_neg_q):8.4f} ' +
             f'{tf.reduce_max(x_neg_q):8.4f}]'),
            (f'x_neg_p=[{tf.reduce_min(x_neg_p):8.4f} ' +
             f'{tf.reduce_max(x_neg_p):8.4f}]'),
            f'z_neg_norm={array_to_str(z_norms_neg)}',
            f'z_pos_norm={array_to_str(z_norms_pos)}',
            f'z_neg_test_pd={z_neg_test_pd:>8.2f}',
            f'z_pos_test_pd={z_pos_test_pd:>8.2f}',
            f'log_prob_neg={log_prob_neg:12.2f}',
            f'log_prob_pos={log_prob_pos:12.2f}',
        ]))

      if global_step % FLAGS.save_steps == 0:
        global_step_var.assign(global_step)
        checkpoint.write(os.path.join(checkpoint_dir, 'checkpoint'))

      global_step += 1
def _sample_n(self, n, seed=None):
  """Draws `n` batches of unit-norm vectors.

  Normal draws (from `samplers.normal`) of shape
  `[n] + batch_shape + [dimension]` are divided by their L2 norm along the
  last axis, so every returned vector has unit length.
  """
  sample_shape = ps.concat([[n], self.batch_shape, [self.dimension]], axis=0)
  gaussian = samplers.normal(shape=sample_shape, seed=seed, dtype=self.dtype)
  lengths = tf.norm(gaussian, ord=2, axis=-1, keepdims=True)
  return gaussian / lengths
def train_model(model,
                ds_train,
                ds_test,
                logdir,
                total_steps=5000,
                batch_size=128,
                val_batch_size=1000,
                save_freq=5,
                log_freq=250,
                use_metainit=False,
                oneshot_prune_fraction=0.,
                gradient_regularization=0):
  """Training of the CNN on MNIST.

  Trains `model` on `ds_train` until `optimizer.iterations` reaches
  `total_steps`. Summaries go under `logdir/train_logs` and checkpoints
  under `logdir`. If a checkpoint exists in `logdir`, training resumes from
  the latest one.

  Args:
    model: Keras model to train; layers that are `utils.PRUNING_WRAPPER`
      instances contribute masks.
    ds_train: tf.data.Dataset of unbatched (x, y) examples. It is batched
      here, and its first `2 * val_batch_size` examples also form the two
      validation batches.
    ds_test: dataset passed to `test_model` for test loss/accuracy.
    logdir: directory for checkpoints and summaries.
    total_steps: number of optimizer steps to train for.
    batch_size: training batch size.
    val_batch_size: size of each of the two validation batches.
    save_freq: save a checkpoint every `save_freq` steps.
    log_freq: run the (expensive) `log_fn` every `log_freq` steps.
    use_metainit: if True, run `metainit.meta_init` before training.
    oneshot_prune_fraction: if > 0, prune this fraction once at the start.
      This requires a mask updater.
    gradient_regularization: weight of an L2 penalty on the masked gradient.
      0 disables it.

  Returns:
    The trained `model`.

  Raises:
    ValueError: if `oneshot_prune_fraction > 0` but no mask updater is
      configured.
  """
  logging.info('Writing training logs to %s', logdir)
  writer = tf.summary.create_file_writer(os.path.join(logdir, 'train_logs'))
  optimizer = utils.get_optimizer(total_steps)
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True)
  train_batch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_batch_accuracy')
  # Two disjoint validation batches, both taken from the head of `ds_train`.
  # They therefore overlap with the training data. (val_x, val_y) is used
  # for gradient logging. (val2_x, val2_y) goes to the mask updater, so the
  # two uses do not share examples.
  (val_x, val_y), (val2_x, val2_y) = list(
      ds_train.take(val_batch_size * 2).batch(val_batch_size))

  def loss_fn(x, y):
    """Training-mode loss including any model regularization losses."""
    predictions = model(x, training=True)
    reg_loss = tf.add_n(model.losses) if model.losses else 0
    return loss_object(y, predictions) + reg_loss

  mask_updater = mask_updaters.get_mask_updater(model, optimizer, loss_fn)
  if mask_updater:
    mask_updater.set_validation_data(val2_x, val2_y)
  update_prune_step(model, 0)
  if oneshot_prune_fraction > 0:
    logging.info('Running one shot prunning at the beginning.')
    if not mask_updater:
      raise ValueError(
          'mask_updater does not exists. Please set '
          'mask_updater.update_alg flag for one shot pruning.')
    mask_updater.prune(oneshot_prune_fraction)
  if use_metainit:
    # Count the unmasked parameters across all pruning-wrapped layers.
    n_params = 0
    for layer in model.layers:
      if isinstance(layer, utils.PRUNING_WRAPPER):
        for _, mask, _ in layer.pruning_vars:
          n_params += tf.reduce_sum(mask)
    metainit.meta_init(model, loss_object, (128, 28, 28, 1), (128, 10),
                       n_params, mask_gradient_fn=mask_gradients)
  # This is used to calculate some distances, would give incorrect results when
  # we restart the training.
  initial_params = [var.numpy() for var in model.trainable_variables]

  # Create the checkpoint object and restore if there is a checkpoint in the
  # folder.
  ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
  ckpt_manager = tf.train.CheckpointManager(
      checkpoint=ckpt, directory=logdir, max_to_keep=None)
  if ckpt_manager.latest_checkpoint:
    logging.info('Restored from %s', ckpt_manager.latest_checkpoint)
    ckpt.restore(ckpt_manager.latest_checkpoint)
    is_restored = True
  else:
    logging.info('Starting from scratch.')
    is_restored = False
  # Obtain global_step after loading checkpoint. This is the optimizer's own
  # iteration variable, so it advances on every `apply_gradients`.
  global_step = optimizer.iterations
  tf.summary.experimental.set_step(global_step)
  trainable_vars = model.trainable_variables

  def get_gradients(x, y, log_batch_gradient=False, is_regularized=True):
    """Gets sparse (masked) gradients and possibly logs some statistics."""
    is_grad_regularized = gradient_regularization != 0
    # The tape must be persistent when the gradient penalty is on, because
    # `tape.gradient` is then called twice.
    with tf.GradientTape(persistent=is_grad_regularized) as tape:
      predictions = model(x, training=True)
      batch_loss = loss_object(y, predictions)
      if is_regularized and is_grad_regularized:
        # Computed inside the tape so the penalty itself is differentiated.
        gradients = tape.gradient(batch_loss, trainable_vars)
        gradients = mask_gradients(model, gradients, trainable_vars)
        grad_vec = flatten_list_of_vars(gradients)
        batch_loss += tf.nn.l2_loss(grad_vec) * gradient_regularization
      # Regularization might have been disabled.
      reg_loss = tf.add_n(model.losses) if model.losses else 0
      if is_regularized:
        batch_loss += reg_loss
    gradients = tape.gradient(batch_loss, trainable_vars)
    # Gradients are dense, we should mask them to ensure updates are sparse;
    # So is the norm calculation.
    gradients = mask_gradients(model, gradients, trainable_vars)
    # If batch gradient log it.
    if log_batch_gradient:
      tf.summary.scalar('train_batch_loss', batch_loss)
      tf.summary.scalar('train_batch_reg_loss', reg_loss)
      train_batch_accuracy.update_state(y, predictions)
      tf.summary.scalar('train_batch_accuracy', train_batch_accuracy.result())
      train_batch_accuracy.reset_states()
    return gradients

  def log_fn():
    """Logs test metrics, gradient/momentum norms, distances and masks."""
    logging.info('Logging at iter: %d', global_step.numpy())
    log_sparsities(model)
    test_loss, test_acc = test_model(model, ds_test)
    tf.summary.scalar('test_loss', test_loss)
    tf.summary.scalar('test_acc', test_acc)
    # Log gradient norm.
    # We want to obtain/log gradients without regularization term.
    gradients = get_gradients(val_x, val_y, log_batch_gradient=False,
                              is_regularized=False)
    for var, grad in zip(trainable_vars, gradients):
      tf.summary.scalar(f'gradnorm/{var.name}', tf.norm(grad))
    # Log all gradients together
    all_norm = tf.norm(flatten_list_of_vars(gradients))
    tf.summary.scalar('.allparams/gradnorm', all_norm)
    # Log momentum values:
    for s_name in optimizer.get_slot_names():
      # Currently we only log momentum.
      if s_name not in ['momentum']:
        continue
      all_slots = [optimizer.get_slot(var, s_name) for var in trainable_vars]
      all_norm = tf.norm(flatten_list_of_vars(all_slots))
      tf.summary.scalar(f'.allparams/norm_{s_name}', all_norm)
    # Log distance to init.
    for initial_val, val in zip(initial_params, model.trainable_variables):
      tf.summary.scalar(f'dist_init_l2/{val.name}',
                        tf.norm(initial_val - val))
      cos_distance = cosine_distance(initial_val, val)
      tf.summary.scalar(f'dist_init_cosine/{val.name}', cos_distance)
    # Mask update logs:
    if mask_updater:
      tf.summary.scalar('drop_fraction', mask_updater.last_drop_fraction)
    # Log all distances together.
    flat_initial = flatten_list_of_vars(initial_params)
    flat_current = flatten_list_of_vars(model.trainable_variables)
    tf.summary.scalar('.allparams/dist_init_l2/',
                      tf.norm(flat_initial - flat_current))
    tf.summary.scalar('.allparams/dist_init_cosine/',
                      cosine_distance(flat_initial, flat_current))
    # Log masks
    for layer in model.layers:
      if isinstance(layer, utils.PRUNING_WRAPPER):
        for _, mask, _ in layer.pruning_vars:
          tf.summary.image('mask/%s' % mask.name, var_to_img(mask))
    writer.flush()

  def save_fn(step=None):
    """Saves a checkpoint numbered `step`, or the current global step."""
    # A falsy `step` (None, or a value equal to 0) falls back to global_step.
    save_step = step if step else global_step
    saved_ckpt = ckpt_manager.save(checkpoint_number=save_step)
    logging.info('Saved checkpoint: %s', saved_ckpt)

  with writer.as_default():
    for x, y in ds_train.repeat().shuffle(
        buffer_size=60000).batch(batch_size):
      if global_step >= total_steps:
        logging.info('Total steps: %d is completed', global_step.numpy())
        save_fn()
        break
      update_prune_step(model, global_step)
      if tf.equal(global_step, 0):
        logging.info('Seed: %s First 10 Label: %s', FLAGS.seed, y[:10])
      # If just loaded, don't save it again. Only the first iteration after a
      # restore is skipped. The flag is cleared unconditionally so that a
      # restore from a checkpoint written off the `save_freq` grid (e.g. the
      # negative-numbered mask-update ones below) cannot suppress a later
      # regular save.
      if global_step % save_freq == 0 and not is_restored:
        save_fn()
      is_restored = False
      if global_step % log_freq == 0:
        log_fn()
      gradients = get_gradients(x, y, log_batch_gradient=True)
      tf.summary.scalar('lr', optimizer.lr(global_step))
      optimizer.apply_gradients(zip(gradients, trainable_vars))

      if mask_updater and mask_updater.is_update_iter(global_step):
        # Save the network before mask_update, we want to use negative integers
        # for this. `apply_gradients` above has already advanced global_step,
        # so `-global_step + 1` is minus the step at the top of this iteration.
        save_fn(step=(-global_step + 1))
        # Gradient norm before.
        gradients = get_gradients(val_x, val_y, log_batch_gradient=False,
                                  is_regularized=False)
        norm_before = tf.norm(flatten_list_of_vars(gradients))
        results = mask_updater.update(global_step)
        # Save network again
        save_fn(step=-global_step)
        if results:
          for mask_name, drop_frac in results.items():
            tf.summary.scalar('drop_fraction/%s' % mask_name, drop_frac)
        # Gradient norm after mask update.
        gradients = get_gradients(val_x, val_y, log_batch_gradient=False,
                                  is_regularized=False)
        norm_after = tf.norm(flatten_list_of_vars(gradients))
        tf.summary.scalar('.allparams/gradnorm_mask_update_improvment',
                          norm_after - norm_before)
    logging.info('Performance after training:')
    log_fn()
  return model
def compute_norm(self, x):
  """Returns the sum over axis-1 L2 norms of `x`, each raised to the power 3."""
  axis_norms = tf.norm(x, ord=2, axis=1)
  return tf.reduce_sum(tf.pow(axis_norms, 3))
def _gradients_order2_norm(self, gradients):
  """Returns the L2 norm of all non-None gradients taken together.

  Each gradient is reduced to its own L2 norm first. The L2 norm of that
  vector of per-tensor norms equals the norm of all gradients concatenated.
  """
  per_tensor_norms = []
  for grad in gradients:
    if grad is None:
      continue
    per_tensor_norms.append(tf.norm(grad))
  return tf.norm(tf.stack(per_tensor_norms))
def prune_one_unit(self,
                   pruning_pool,
                   baselines=None,
                   normalized_scores=True,
                   pruning_method=None,
                   is_bp=None):
  """Picks a layer and prunes a single unit using the scoring function.

  Every layer in `pruning_pool` is scored. The layer holding the smallest
  (still unpruned) unit score is passed to `prune_model_with_scores` with one
  more unit to prune than it currently has pruned.

  Args:
    pruning_pool: list, of layers that are considered for pruning.
    baselines: dict, if exists, subtracts the given constant from the scores of
      individual layers. The keys should be a subset of pruning_pool.
    normalized_scores: bool, if True the scores are normalized with l2 norm.
    pruning_method: str, must be in ALL_SCORING_FUNCTIONS. 'rand' and 'norm'
      score the layer's first weight tensor directly. Names ending in 'mrs' or
      'rs' use the saved mean-replacement / removal saliency, and an 'abs'
      prefix requests absolute saliencies. If given, overrides
      `self.pruning_method`.
    is_bp: bool, if True Mean Replacement Pruning is used and bias propagation
      is made. If not None (including an explicit False), overrides
      `self.is_bp`.

  Raises:
    ValueError: if `pruning_method` is not one of ALL_SCORING_FUNCTIONS.
  """
  pruning_method = pruning_method if pruning_method else self.pruning_method
  # Use `is not None` so an explicit `is_bp=False` overrides `self.is_bp`. A
  # truthiness test would silently fall back to the default.
  is_bp = is_bp if is_bp is not None else self.is_bp
  if pruning_method not in ALL_SCORING_FUNCTIONS:
    raise ValueError('%s is not one of %s' % (pruning_method,
                                              ALL_SCORING_FUNCTIONS))
  if baselines is None:
    baselines = {}
  logging.info('Prunning with: %s, is_bp: %s', pruning_method, is_bp)
  # Calculating the scoring function/mean value.
  is_abs = pruning_method.startswith('abs')
  is_mrs = pruning_method.endswith('mrs')
  is_rs = pruning_method.endswith('rs') and not is_mrs
  is_grad = is_mrs or is_rs
  # Presumably this pass populates the per-layer saved values ('mean', 'mrs',
  # 'rs') read back via `get_saved_values` below. TODO confirm in train_utils.
  train_utils.cross_entropy_loss(
      self.model,
      self.subset_val,
      training=False,
      compute_mean_replacement_saliency=is_mrs,
      compute_removal_saliency=is_rs,
      is_abs=is_abs,
      aggregate_values=True,
      run_gradient=is_grad)
  scores = {}
  mean_values = {}
  smallest_score = None
  smallest_l_name = None
  smallest_nprune = None
  for l_name in pruning_pool:
    l_ts = getattr(self.model, l_name + '_ts')
    l = getattr(self.model, l_name)
    mean_values[l_name] = l_ts.get_saved_values('mean')
    # Make sure the masks are applied after last gradient update. Note
    # that this is necessary for `norm` functions, since it doesn't call the
    # model and therefore the masks are not applied.
    l.apply_masks()
    if pruning_method == 'rand':
      scores[l_name] = unitscorers.random_score(l.get_layer().weights[0])
    elif pruning_method == 'norm':
      scores[l_name] = unitscorers.norm_score(l.get_layer().weights[0])
    else:
      # mrs or rs.
      score_name = 'rs' if is_rs else 'mrs'
      scores[l_name] = l_ts.get_saved_values(score_name)
    if normalized_scores:
      scores[l_name] /= tf.norm(scores[l_name])
    baseline_score = baselines.get(l_name, 0)
    if baseline_score != 0:
      # Regularizing the scores with c_flop weights.
      scores[l_name] -= baseline_score
    # If there is an existing mask we have to make sure pruned connections
    # are indicated. They are set to a large-magnitude negative number (-1e10)
    # so they rank lowest. Presumably this keeps them among the
    # `smallest_nprune` units selected for pruning. Note that the elements of
    # `l.mask_bias` consist of zeros and ones only.
    if l.mask_bias is not None:
      # Setting the scores of the pruned units to zero.
      scores[l_name] = scores[l_name] * l.mask_bias
      # Setting the scores of the pruned units to -1e10.
      scores[l_name] += -1e10 * (1 - l.mask_bias)
      # Number of previously pruned units (entries where mask_bias == 0).
      n_pruned = tf.count_nonzero(l.mask_bias - 1).numpy()
      # Minimum over the still-active units only.
      layer_smallest_score = tf.reduce_min(
          tf.boolean_mask(scores[l_name], l.mask_bias)).numpy()
      # Do not prune the last unit.
      if tf.equal(n_pruned + 1, tf.size(l.mask_bias)):
        continue
    else:
      n_pruned = 0
      layer_smallest_score = tf.reduce_min(scores[l_name]).numpy()
    logging.info('Layer:%s, min:%f', l_name, layer_smallest_score)
    if smallest_score is None or (layer_smallest_score < smallest_score):
      smallest_score = layer_smallest_score
      smallest_l_name = l_name
      # We want to prune one more than before.
      smallest_nprune = n_pruned + 1
  # NOTE(review): if `pruning_pool` is empty or every layer was skipped above,
  # `smallest_l_name` is still None and the dict lookups below raise KeyError.
  logging.info('UNIT_PRUNED, layer:%s, n_pruned:%d', smallest_l_name,
               smallest_nprune)
  mean_values = {smallest_l_name: mean_values[smallest_l_name]}
  scores = {smallest_l_name: scores[smallest_l_name]}
  input_shapes = {
      smallest_l_name: getattr(self.model, smallest_l_name + '_ts').xshape
  }
  layers2prune = [smallest_l_name]
  prune_model_with_scores(self.model, scores, is_bp, layers2prune, None,
                          smallest_nprune, mean_values, input_shapes)
def train_step_black_box(data,
                         labels_one_hot,
                         samples,
                         weights,
                         _lambda,
                         trainable=tf.constant(False)):
  """Takes one optimizer step on `black_box` using a blend of two gradients.

  g1 is the gradient of a "share" loss: the KL divergence between
  `exp(model.compute_log_conditional_distribution(x))` and `black_box(x)`,
  averaged over `samples` with `tfp.monte_carlo.expectation`. g2 is the
  gradient of the categorical cross-entropy on (`data`, `labels_one_hot`).
  The update applied is `tau * g1 + (1 - tau) * g2`, where
  `tau = <g2, g2 - g1> / ||g1 - g2||^2` (the unconstrained minimizer of
  `||tau * g1 + (1 - tau) * g2||^2`), clipped to `[0, qiota]`.

  `model`, `black_box` and `optimizer_black_box` are read from the enclosing
  (module) scope.

  Args:
    data: inputs for the cross-entropy term.
    labels_one_hot: one-hot targets for `data`.
    samples: points over which the share loss is averaged.
    weights: forwarded to the share-loss helper, which does not use it.
    _lambda: sets the upper clip `qiota = 1 - 1 / (1 + _lambda)` on tau.
    trainable: forwarded to `black_box` inside the KL term.

  Returns:
    Tuple `(cross_entropy, share_loss, tau)`.
  """
  print("----Tracing--train_step_black_box")

  @tf.function
  def compute_share_loss(X, weights):
    print("----Tracing---share_loss")

    def kl_divergence(x_d):
      print("---Tracing the KL")
      kl = tf.keras.losses.KLDivergence()
      return kl(tf.exp(model.compute_log_conditional_distribution(x_d)),
                black_box(x_d, trainable=trainable))

    return tfp.monte_carlo.expectation(
        f=kl_divergence,
        samples=X,
        log_prob=model.log_pdf,
        use_reparametrization=False)

  with tf.GradientTape() as tape1:
    print("--tracing-gradient_persistent")
    share_loss = compute_share_loss(X=samples, weights=weights)
  with tf.GradientTape() as tape2:
    cross_ent = tf.keras.losses.CategoricalCrossentropy()
    logits = black_box(data, trainable=tf.constant(True))
    cross_entropy = cross_ent(labels_one_hot, logits)

  gradients1 = tape1.gradient(share_loss, black_box.trainable_variables)
  gradients2 = tape2.gradient(cross_entropy, black_box.trainable_variables)

  numerator = tf.constant(0.0)
  denominator = tf.constant(0.0)
  for grads1, grads2 in zip(gradients1, gradients2):
    numerator = numerator + tf.reduce_sum(grads2 * grads2 - grads1 * grads2)
    denominator = denominator + tf.norm(grads1 - grads2)**2
  qiota = 1. - 1. / (1. + _lambda)
  # When g1 == g2 the denominator is 0. A plain division would give NaN and
  # poison every applied gradient. Any tau is then equivalent, and
  # divide_no_nan yields 0.
  tau = tf.math.maximum(
      tf.math.minimum(tf.math.divide_no_nan(numerator, denominator), qiota),
      0.0)
  gradients = [
      tau * grads1 + (1 - tau) * grads2
      for grads1, grads2 in zip(gradients1, gradients2)
  ]
  tf.print("Tau param: ", tau)
  optimizer_black_box.apply_gradients(
      zip(gradients, black_box.trainable_variables))
  return cross_entropy, share_loss, tau
def error_ratio(backward_difference, error_coefficient, tol):
  """Computes the ratio of the error in the computed state to the tolerance.

  Args:
    backward_difference: Tensor; real or complex.
    error_coefficient: factor multiplied elementwise into
      `backward_difference`.
    tol: tolerance; cast to `backward_difference.dtype` before dividing.

  Returns:
    Scalar L2 norm of `error_coefficient * backward_difference / tol`, cast
    to the real dtype corresponding to `backward_difference.dtype`.
  """
  tol_cast = tf.cast(tol, backward_difference.dtype)
  error_ratio_ = tf.norm(error_coefficient * backward_difference / tol_cast)
  # For complex inputs the norm may carry a complex dtype, so cast to the
  # matching real dtype. `DType.real_dtype` is what `tf.abs(...).dtype` would
  # be, without running an elementwise `tf.abs` just to read its dtype.
  return tf.cast(error_ratio_, backward_difference.dtype.real_dtype)
def while_loop_condition(iteration, eigenvector, old_eigenvector):
  """Returns false if the while loop should terminate.

  The loop continues only while the iteration budget (`maximum_iterations`,
  taken from the enclosing scope) is not used up AND successive eigenvector
  estimates still differ by more than `epsilon` in L2 norm.
  """
  step_change = tf.norm(eigenvector - old_eigenvector)
  return tf.logical_and(iteration < maximum_iterations, step_change > epsilon)
def _maximal_eigenvector_power_method(matrix,
                                      epsilon=1e-6,
                                      maximum_iterations=100):
  """Returns a maximal right-eigenvector of "matrix" using the power method.

  Args:
    matrix: 2D Tensor, the matrix of which we will find a maximal
      right-eigenvector.
    epsilon: positive float, if two iterations of the power method differ (in
      L2 norm) by no more than epsilon, we will terminate.
    maximum_iterations: positive int, if we perform this many iterations, we
      will terminate.

  Returns:
    A maximal right-eigenvector of "matrix", scaled to unit L2 norm.

  Raises:
    TypeError: if the "matrix" `Tensor` is not floating-point.
    ValueError: if the "epsilon" or "maximum_iterations" parameters violate
      their bounds (both must be strictly positive).
  """
  if not matrix.dtype.is_floating:
    raise TypeError("matrix must have a floating-point dtype")
  if epsilon <= 0.0:
    raise ValueError("epsilon must be strictly positive")
  if maximum_iterations <= 0:
    raise ValueError("maximum_iterations must be strictly positive")

  def while_loop_condition(iteration, eigenvector, old_eigenvector):
    """Returns false if the while loop should terminate."""
    not_done = (iteration < maximum_iterations)
    not_converged = (tf.norm(eigenvector - old_eigenvector) > epsilon)
    return tf.logical_and(not_done, not_converged)

  def while_loop_body(iteration, eigenvector, old_eigenvector):
    """Performs one iteration of the power method."""
    del old_eigenvector  # Needed by the condition, but not the body (for lint).
    iteration += 1
    # We need to use tf.matmul() and tf.expand_dims(), instead of
    # tf.tensordot(), since the former will infer the shape of the result, while
    # the latter will not (tf.while_loop() needs the shapes).
    new_eigenvector = tf.matmul(matrix, tf.expand_dims(eigenvector, 1))[:, 0]
    new_eigenvector /= tf.norm(new_eigenvector)
    return (iteration, new_eigenvector, eigenvector)

  # Start from the normalized all-ones vector.
  iteration = tf.constant(0)
  eigenvector = tf.ones_like(matrix[:, 0])
  eigenvector /= tf.norm(eigenvector)

  # We actually want a do-while loop, so we explicitly call while_loop_body()
  # once before tf.while_loop().
  # NOTE(review): if the dominant eigenvalue is negative, successive iterates
  # flip sign and the epsilon test cannot fire, so the loop runs for
  # `maximum_iterations` steps.
  iteration, eigenvector, old_eigenvector = while_loop_body(
      iteration, eigenvector, eigenvector)
  _, eigenvector, _ = tf.while_loop(
      while_loop_condition,
      while_loop_body,
      loop_vars=(iteration, eigenvector, old_eigenvector),
      name="power_method")

  return eigenvector
def minimize(value_and_gradients_function,
             initial_position,
             tolerance=1e-8,
             x_tolerance=0,
             f_relative_tolerance=0,
             max_iterations=50,
             parallel_iterations=1,
             stopping_condition=None,
             params=None,
             name=None):
  """Minimizes a differentiable function.

  Implementation of algorithm described in [HZ2006]. Updated formula for next
  search direction were taken from [HZ2013].

  Supports batches with 1-dimensional batch shape.

  ### References:

  [HZ2006] Hager, William W., and Hongchao Zhang. "Algorithm 851: CG_DESCENT,
    a conjugate gradient method with guaranteed descent."
    http://users.clas.ufl.edu/hager/papers/CG/cg_compare.pdf

  [HZ2013] W. W. Hager and H. Zhang (2013) The limited memory conjugate
    gradient method.
    https://pdfs.semanticscholar.org/8769/69f3911777e0ff0663f21b67dff30518726b.pdf

  ### Usage:

  The following example demonstrates this optimizer attempting to find the
  minimum for a simple two dimensional quadratic objective function.

  ```python
    minimum = np.array([1.0, 1.0])  # The center of the quadratic bowl.
    scales = np.array([2.0, 3.0])  # The scales along the two axes.

    # The objective function and the gradient.
    def quadratic(x):
      value = tf.reduce_sum(scales * (x - minimum) ** 2)
      return value, tf.gradients(value, x)[0]

    start = tf.constant([0.6, 0.8])  # Starting point for the search.
    optim_results = conjugate_gradient.minimize(
        quadratic, initial_position=start, tolerance=1e-8)

    with tf.Session() as session:
      results = session.run(optim_results)
      # Check that the search converged
      assert(results.converged)
      # Check that the argmin is close to the actual value.
      np.testing.assert_allclose(results.position, minimum)
  ```

  Args:
    value_and_gradients_function:  A Python callable that accepts a point as a
      real `Tensor` and returns a tuple of `Tensor`s of real dtype containing
      the value of the function and its gradient at that point. The function
      to be minimized. The input should be of shape `[..., n]`, where `n` is
      the size of the domain of input points, and all others are batching
      dimensions. The first component of the return value should be a real
      `Tensor` of matching shape `[...]`. The second component (the gradient)
      should also be of shape `[..., n]` like the input value to the function.
    initial_position: Real `Tensor` of shape `[..., n]`. The starting point, or
      points when using batching dimensions, of the search procedure. At these
      points the function value and the gradient norm should be finite.
    tolerance: Scalar `Tensor` of real dtype. Specifies the gradient tolerance
      for the procedure. If the supremum norm of the gradient vector is at or
      below this number, the algorithm is stopped. The same test is applied
      at the initial position.
    x_tolerance: Scalar `Tensor` of real dtype. If the absolute change in the
      position between one iteration and the next is smaller than this number,
      the algorithm is stopped.
    f_relative_tolerance: Scalar `Tensor` of real dtype. If the relative change
      in the objective value between one iteration and the next is smaller
      than this value, the algorithm is stopped.
    max_iterations: Scalar positive int32 `Tensor`. The maximum number of
      iterations.
    parallel_iterations: Positive integer. The number of iterations allowed to
      run in parallel.
    stopping_condition: (Optional) A Python function that takes as input two
      Boolean tensors of shape `[...]`, and returns a Boolean scalar tensor.
      The input tensors are `converged` and `failed`, indicating the current
      status of each respective batch member; the return value states whether
      the algorithm should stop. The default is tfp.optimizer.converged_all
      which only stops when all batch members have either converged or failed.
      An alternative is tfp.optimizer.converged_any which stops as soon as one
      batch member has converged, or when all have failed.
    params: ConjugateGradientParams object with adjustable parameters of the
      algorithm. If not supplied, default parameters will be used.
    name: (Optional) Python str. The name prefixed to the ops created by this
      function. If not supplied, the default name 'minimize' is used.

  Returns:
    optimizer_results: A namedtuple containing the following items:
      converged: boolean tensor of shape `[...]` indicating for each batch
        member whether the minimum was found within tolerance. Once a member
        has converged it stays converged.
      failed:  boolean tensor of shape `[...]` indicating for each batch
        member whether a line search step failed to find a suitable step size
        satisfying Wolfe conditions. Converged or failed members keep their
        position for the rest of the search. A member that had neither
        converged nor failed when the loop stopped (e.g. because
        `max_iterations` was reached) has both flags False.
      num_iterations: The number of iterations of the main loop performed.
      num_objective_evaluations: The total number of objective
        evaluations performed.
      position: A tensor of shape `[..., n]` containing the last argument value
        found during the search from each starting point. If the search
        converged, then this value is the argmin of the objective function.
      objective_value: A tensor of shape `[...]` with the value of the
        objective function at the `position`. If the search converged, then
        this is the (local) minimum of the objective function.
      objective_gradient: A tensor of shape `[..., n]` containing the gradient
        of the objective function at the `position`. If the search converged
        the max-norm of this tensor should be below the tolerance.

  """
  with tf.compat.v1.name_scope(name, 'minimize', [initial_position, tolerance]):
    if params is None:
      params = ConjugateGradientParams()

    initial_position = tf.convert_to_tensor(
        value=initial_position, name='initial_position')
    dtype = initial_position.dtype
    tolerance = tf.convert_to_tensor(
        value=tolerance, dtype=dtype, name='grad_tolerance')
    f_relative_tolerance = tf.convert_to_tensor(
        value=f_relative_tolerance, dtype=dtype, name='f_relative_tolerance')
    x_tolerance = tf.convert_to_tensor(
        value=x_tolerance, dtype=dtype, name='x_tolerance')
    max_iterations = tf.convert_to_tensor(
        value=max_iterations, name='max_iterations')
    stopping_condition = stopping_condition or converged_all

    delta = tf.convert_to_tensor(
        params.sufficient_decrease_param, dtype=dtype, name='delta')
    sigma = tf.convert_to_tensor(
        params.curvature_param, dtype=dtype, name='sigma')
    eps = tf.convert_to_tensor(
        params.threshold_use_approximate_wolfe_condition,
        dtype=dtype,
        name='eps')
    eta = tf.convert_to_tensor(
        params.direction_update_param, dtype=dtype, name='eta')
    psi_1 = tf.convert_to_tensor(
        params.initial_guess_small_factor, dtype=dtype, name='psi_1')
    psi_2 = tf.convert_to_tensor(
        params.initial_guess_step_multiplier, dtype=dtype, name='psi_2')

    f0, df0 = value_and_gradients_function(initial_position)
    # Same sup-norm test as the in-loop `grad_converged` check (and the
    # docstring). The original used the L2 norm with a strict `<` here.
    converged = _norm_inf(df0) <= tolerance

    initial_state = _OptimizerState(
        converged=converged,
        failed=tf.zeros_like(converged),  # All false.
        num_iterations=tf.convert_to_tensor(value=0),
        num_objective_evaluations=tf.convert_to_tensor(value=1),
        position=initial_position,
        objective_value=f0,
        objective_gradient=df0,
        direction=-df0,
        prev_step=tf.ones_like(f0),
    )

    def _cond(state):
      """Continue if iterations remain and stopping condition is not met."""
      return ((state.num_iterations < max_iterations) &
              tf.logical_not(stopping_condition(state.converged, state.failed)))

    def _body(state):
      """Main optimization loop."""
      # We use notation of [HZ2006] for brevity.
      x_k = state.position
      d_k = state.direction
      f_k = state.objective_value
      g_k = state.objective_gradient
      a_km1 = state.prev_step  # Means a_{k-1}.

      # Define scalar function, which is objective restricted to direction.
      def ls_func(alpha):
        pt = x_k + tf.expand_dims(alpha, axis=-1) * d_k
        objective_value, gradient = value_and_gradients_function(pt)
        return ValueAndGradient(
            x=alpha,
            f=objective_value,
            df=_dot(gradient, d_k),
            full_gradient=gradient)

      # Generate initial guess for line search.
      # [HZ2006] suggests to generate first initial guess separately, but
      # [JuliaLineSearches] generates it as if previous step length was 1, and
      # we do the same.
      phi_0 = f_k
      dphi_0 = _dot(g_k, d_k)
      ls_val_0 = ValueAndGradient(
          x=tf.zeros_like(phi_0), f=phi_0, df=dphi_0, full_gradient=g_k)
      step_guess_result = _init_step(ls_val_0, a_km1, ls_func, psi_1, psi_2,
                                     params.quad_step)

      init_step = step_guess_result.step

      # Check if initial step size already satisfies Wolfe condition, and in
      # that case don't perform line search.
      c = init_step.x
      phi_lim = phi_0 + eps * tf.abs(phi_0)
      phi_c = init_step.f
      dphi_c = init_step.df
      # The curvature condition is shared by T1 and T2 in [HZ2006].
      curvature = dphi_c >= sigma * dphi_0
      # Original Wolfe conditions, T1 in [HZ2006].
      suff_decrease_1 = delta * dphi_0 >= (phi_c - phi_0) / c
      wolfe1 = suff_decrease_1 & curvature
      # Approximate Wolfe conditions, T2 in [HZ2006].
      suff_decrease_2 = (2 * delta - 1) * dphi_0 >= dphi_c
      wolfe2 = suff_decrease_2 & curvature & (phi_c <= phi_lim)
      wolfe = wolfe1 | wolfe2
      skip_line_search = ((step_guess_result.may_terminate & wolfe) |
                          state.failed | state.converged)

      # Call Hager-Zhang line search (L0-L3 in [HZ2006]).
      # Parameter theta from [HZ2006] is not adjustable, it's always 0.5.
      ls_result = linesearch.hager_zhang(
          ls_func,
          value_at_zero=ls_val_0,
          converged=skip_line_search,
          initial_step_size=init_step.x,
          value_at_initial_step=init_step,
          shrinkage_param=params.shrinkage_param,
          expansion_param=params.expansion_param,
          sufficient_decrease_param=delta,
          curvature_param=sigma,
          threshold_use_approximate_wolfe_condition=eps)

      # Moving to the next point, using step length from line search.
      # If line search was skipped, take step length from initial guess.
      # To save objective evaluation, use objective value and gradient returned
      # by line search or initial guess.
      a_k = tf.compat.v1.where(
          skip_line_search, init_step.x, ls_result.left.x)
      x_kp1 = state.position + tf.expand_dims(a_k, -1) * d_k
      f_kp1 = tf.compat.v1.where(
          skip_line_search, init_step.f, ls_result.left.f)
      g_kp1 = tf.compat.v1.where(skip_line_search, init_step.full_gradient,
                                 ls_result.left.full_gradient)

      # Evaluate next direction.
      # Use formulas (2.7)-(2.11) from [HZ2013] with P_k=I.
      y_k = g_kp1 - g_k
      d_dot_y = _dot(d_k, y_k)
      b_k = (_dot(y_k, g_kp1) -
             _norm_sq(y_k) * _dot(g_kp1, d_k) / d_dot_y) / d_dot_y
      eta_k = eta * _dot(d_k, g_k) / _norm_sq(d_k)
      b_k = tf.maximum(b_k, eta_k)
      d_kp1 = -g_kp1 + tf.expand_dims(b_k, -1) * d_k

      # Check convergence criteria.
      grad_converged = _norm_inf(g_kp1) <= tolerance
      x_converged = (_norm_inf(x_kp1 - x_k) <= x_tolerance)
      f_converged = (
          tf.math.abs(f_kp1 - f_k) <= f_relative_tolerance * tf.math.abs(f_k))
      # A member whose line search actually ran (was not skipped) and failed
      # is now marked failed. Previously `failed` was never updated.
      failed = state.failed | (~skip_line_search & ls_result.failed)
      # Convergence is sticky. Already-converged members are frozen below, so
      # re-testing them at the phantom step x_kp1 is meaningless and used to
      # flip the flag back to False.
      converged = state.converged | (
          ~failed & (grad_converged | x_converged | f_converged))
      # Members that had already converged or failed keep their point.
      frozen = state.converged | state.failed

      # Construct new state for next iteration.
      new_state = _OptimizerState(
          converged=converged,
          failed=failed,
          num_iterations=state.num_iterations + 1,
          num_objective_evaluations=(state.num_objective_evaluations +
                                     step_guess_result.func_evals +
                                     ls_result.func_evals),
          position=tf.compat.v1.where(frozen, x_k, x_kp1),
          objective_value=tf.compat.v1.where(frozen, f_k, f_kp1),
          objective_gradient=tf.compat.v1.where(frozen, g_k, g_kp1),
          direction=d_kp1,
          prev_step=a_k)
      return (new_state,)

    final_state = tf.while_loop(
        _cond, _body, (initial_state,),
        parallel_iterations=parallel_iterations)[0]
    return OptimizerResult(
        converged=final_state.converged,
        failed=final_state.failed,
        num_iterations=final_state.num_iterations,
        num_objective_evaluations=final_state.num_objective_evaluations,
        position=final_state.position,
        objective_value=final_state.objective_value,
        objective_gradient=final_state.objective_gradient)
def _assert_ops( self, previous_solver_internal_state, initial_state_vec, final_time, initial_time, solution_times, max_num_steps, max_num_newton_iters, atol, rtol, first_step_size, safety_factor, min_step_size_factor, max_step_size_factor, max_order, newton_tol_factor, newton_step_size_factor, solution_times_chosen_by_solver, ): """Creates a list of assert operations.""" if not self._validate_args: return [] assert_ops = [] if previous_solver_internal_state is not None: assert_initial_state_matches_previous_solver_internal_state = ( tf.debugging.assert_near( tf.norm( initial_state_vec - previous_solver_internal_state.backward_differences[0], np.inf), 0., message='`previous_solver_internal_state` does not match ' '`initial_state`.')) assert_ops.append( assert_initial_state_matches_previous_solver_internal_state) assert_ops.append( util.assert_positive(final_time - initial_time, 'final_time - initial_time')) if not solution_times_chosen_by_solver: assert_ops += [ util.assert_increasing(solution_times, 'solution_times'), util.assert_nonnegative(solution_times[0] - initial_time, 'solution_times[0] - initial_time'), ] if max_num_steps is not None: assert_ops.append( util.assert_positive(max_num_steps, 'max_num_steps')) if max_num_newton_iters is not None: assert_ops.append( util.assert_positive(max_num_newton_iters, 'max_num_newton_iters')) assert_ops += [ util.assert_positive(rtol, 'rtol'), util.assert_positive(atol, 'atol'), util.assert_positive(first_step_size, 'first_step_size'), util.assert_positive(safety_factor, 'safety_factor'), util.assert_positive(min_step_size_factor, 'min_step_size_factor'), util.assert_positive(max_step_size_factor, 'max_step_size_factor'), tf.Assert((max_order >= 1) & (max_order <= bdf_util.MAX_ORDER), [ '`max_order` must be between 1 and {}.'.format( bdf_util.MAX_ORDER) ]), util.assert_positive(newton_tol_factor, 'newton_tol_factor'), util.assert_positive(newton_step_size_factor, 'newton_step_size_factor'), ] return assert_ops
def _norm(x):
  """Evaluates L2 norm.

  The Euclidean norm is taken over the last axis only, so an input of shape
  `[..., n]` yields an output of shape `[...]`.
  """
  last_axis = -1
  return tf.norm(x, ord='euclidean', axis=last_axis)
def get_dice_pose_results(bounding_boxes, classes, scores, y_rotation_angles,
                          camera_matrix : np.ndarray,
                          distortion_coefficients : np.ndarray,
                          score_threshold : float = 0.5):
    """Estimate a pose for every detected die.

    Detections scoring strictly above ``score_threshold`` are split into dots
    (class 0) and dice (every other class, i.e. the top-face class). Dice are
    visited from the largest ``box[2]`` (lowest lower edge in the image)
    downwards. For each die, the still-unclaimed dots lying inside it are
    handed to the pose estimator. Dots that the resulting pose reports as
    comparison inliers are then removed, so a later die cannot reuse them.

    Args:
        bounding_boxes: Per-detection boxes; this function reads them as
            ``[top, left, bottom, right]``.
        classes: Per-detection class ids; 0 marks a dot.
        scores: Per-detection confidence scores.
        y_rotation_angles: Per-detection rotation angle around the vertical
            axis.
        camera_matrix: Camera intrinsics, forwarded to the pose helpers.
        distortion_coefficients: Lens distortion, forwarded to the pose
            helpers.
        score_threshold: Minimum (exclusive) score for a detection to be used.

    Returns:
        A list with one pose result per die, in visiting order.
    """
    def box_area(boxes):
        # Degenerate (negative-extent) boxes count as zero area.
        widths = tf.math.maximum(boxes[:, 3] - boxes[:, 1], 0)
        heights = tf.math.maximum(boxes[:, 2] - boxes[:, 0], 0)
        return widths * heights

    # Keep only confident detections.
    confident = tf.math.greater(scores, score_threshold)
    kept_classes = tf.boolean_mask(classes, confident)
    kept_boxes = tf.boolean_mask(bounding_boxes, confident)
    kept_y_angles = tf.boolean_mask(y_rotation_angles, confident)

    # Class 0 is a dot; anything else is a die.
    is_dot = tf.equal(kept_classes, 0)
    is_die = tf.logical_not(is_dot)

    dice_boxes = tf.boolean_mask(kept_boxes, is_die)
    dice_angles = tf.boolean_mask(kept_y_angles, is_die)
    dice_face_classes = tf.boolean_mask(kept_classes, is_die)

    dot_boxes = tf.boolean_mask(kept_boxes, is_dot)
    dot_centers = _get_dot_centers(dot_boxes)
    dot_sizes = _get_dot_sizes(dot_boxes)

    # box[2] is the lower edge, so sorting it descending visits the die that
    # sits lowest in the image first.
    processing_order = tf.argsort(dice_boxes[:, 2], axis=-1,
                                  direction='DESCENDING').numpy()

    box_poses = []
    for die_index in processing_order:
        box_poses.append(
            _get_die_image_bounding_box_pose(dice_boxes[die_index, :],
                                             camera_matrix,
                                             distortion_coefficients))
    up_vector_pyrender = _get_approximate_dice_up_vector(
        box_poses, in_pyrender_coords=True)

    pose_results = []
    for die_index, box_pose in zip(processing_order, box_poses):
        die_box = dice_boxes[die_index, :]
        die_top_left = die_box[0:2]
        die_extent = die_box[2:4] - die_top_left

        # Dot centres expressed relative to the die box (0..1 spans the box).
        relative_centers = (dot_centers - die_top_left) / die_extent
        # Rounded-rectangle distance test: the value is negative for centres
        # inside the unit box with corners rounded by the radius.
        # NOTE(review): `rounded_rectangle_radius` is not defined in this
        # function; presumably a module-level constant -- confirm.
        corner_offset = (tf.math.abs(relative_centers - 0.5) - 0.5
                         + rounded_rectangle_radius)
        rounded_distance = (tf.norm(tf.math.maximum(corner_offset, 0.0),
                                    axis=-1)
                            - rounded_rectangle_radius)
        inside_rounded_rect = rounded_distance < 0

        # Fraction of each dot box that overlaps the die box.
        clipped_boxes = tf.stack([
            tf.math.maximum(dot_boxes[:, 0], die_box[0]),
            tf.math.maximum(dot_boxes[:, 1], die_box[1]),
            tf.math.minimum(dot_boxes[:, 2], die_box[2]),
            tf.math.minimum(dot_boxes[:, 3], die_box[3]),
        ], axis=1)
        overlap_fraction = box_area(clipped_boxes) / box_area(dot_boxes)
        mostly_inside = tf.greater(overlap_fraction, 0.9)

        in_die = tf.logical_and(mostly_inside, inside_rounded_rect)

        centers_in_die_cv = _convert_tensorflow_points_to_opencv(
            tf.boolean_mask(dot_centers, in_die))
        die_pose = _get_die_pose(die_box, dice_face_classes[die_index],
                                 dice_angles[die_index], centers_in_die_cv,
                                 box_pose, up_vector_pyrender,
                                 camera_matrix, distortion_coefficients)
        die_pose.calculate_comparison(centers_in_die_cv, camera_matrix,
                                      distortion_coefficients)
        # NOTE(review): sizes for *all* remaining dots are passed here, while
        # the comparison above saw only the in-die subset -- confirm that
        # calculate_inliers does not index sizes by in-die positions.
        die_pose.calculate_inliers(
            _convert_tensorflow_points_to_opencv(dot_sizes))
        pose_results.append(die_pose)

        # Inlier indices are positions within the in-die subset; map them back
        # to positions in the full remaining-dot tensors, then drop those dots.
        claimed = tf.gather(tf.where(in_die),
                            die_pose.comparison_inlier_indices)
        dot_centers = _delete_tf(dot_centers, claimed)
        dot_sizes = _delete_tf(dot_sizes, claimed)
        dot_boxes = _delete_tf(dot_boxes, claimed)

    return pose_results