def _get_transformed_random_signs(self):
  """Returns fisher-factor products of random sign vectors, one per loss.

  For each registered loss, draws a fresh batch of random +/-1 values shaped
  like that loss's fisher factor inner shape and multiplies it by the loss's
  fisher factor.  Used for curvature propagation estimation.
  """
  return [
      loss.multiply_fisher_factor(
          utils.generate_random_signs(loss.fisher_factor_inner_shape))
      for loss in self._layers.losses
  ]
def _get_transformed_random_signs(self):
  """Returns fisher-factor products of random sign vectors, one per loss.

  Each product is built under colocation with the ops registered for that
  loss, so the random-sign sampling and the factor multiplication are placed
  on the same device as the loss itself.
  """
  products = []
  for loss in self._layers.losses:
    colocation_op = self._layers.loss_colocation_ops[loss]
    with tf_ops.colocate_with(colocation_op):
      random_signs = utils.generate_random_signs(
          loss.fisher_factor_inner_shape)
      products.append(loss.multiply_fisher_factor(random_signs))
  return products
def _setup(self, cov_ema_decay):
  """Sets up the various operations.

  Args:
    cov_ema_decay: The decay factor used when calculating the covariance
      estimate moving averages.

  Returns:
    A triple (covs_update_op, invs_update_op, inv_updates_dict), where
    covs_update_op is the grouped Op to update all the covariance estimates,
    invs_update_op is the grouped Op to update all the inverses, and
    inv_updates_dict is a dict mapping Op names to individual inverse updates.

  Raises:
    ValueError: If estimation_mode was improperly specified at construction.
  """
  damping = self.damping
  fisher_blocks_list = self._layers.get_blocks()
  # One (possibly nested) structure of tensors per block; flattened so a
  # single gradients() call can differentiate w.r.t. all of them at once,
  # then re-packed into the per-block nesting afterwards.
  tensors_to_compute_grads = [
      fb.tensors_to_compute_grads() for fb in fisher_blocks_list
  ]
  tensors_to_compute_grads_flat = nest.flatten(tensors_to_compute_grads)
  if self._estimation_mode == "gradients":
    # One gradient sample of the *sampled* (model-distribution) loss.
    grads_flat = gradients_impl.gradients(
        self._layers.total_sampled_loss(), tensors_to_compute_grads_flat)
    grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
    # Each block receives a 1-tuple: a single gradient sample.
    grads_lists = tuple((grad, ) for grad in grads_all)
  elif self._estimation_mode == "empirical":
    # Same as above but using the actual (data-distribution) loss.
    grads_flat = gradients_impl.gradients(
        self._layers.total_loss(), tensors_to_compute_grads_flat)
    grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
    grads_lists = tuple((grad, ) for grad in grads_all)
  elif self._estimation_mode == "curvature_prop":
    # Curvature propagation: backprop fisher-factor-transformed random sign
    # vectors through the loss inputs (passed via grad_ys).
    loss_inputs = list(loss.inputs for loss in self._layers.losses)
    loss_inputs_flat = nest.flatten(loss_inputs)
    transformed_random_signs = list(
        loss.multiply_fisher_factor(
            utils.generate_random_signs(
                loss.fisher_factor_inner_shape))
        for loss in self._layers.losses)
    transformed_random_signs_flat = nest.flatten(
        transformed_random_signs)
    grads_flat = gradients_impl.gradients(
        loss_inputs_flat,
        tensors_to_compute_grads_flat,
        grad_ys=transformed_random_signs_flat)
    grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
    grads_lists = tuple((grad, ) for grad in grads_all)
  elif self._estimation_mode == "exact":
    # Loop over all coordinates of all losses: one backward pass per
    # coordinate of each loss's fisher factor inner (static) shape, skipping
    # the leading (batch) dimension.  NOTE(review): this is exact but costs
    # one gradients() call per coordinate — presumably only viable for small
    # output dimensionality.
    grads_all = []
    for loss in self._layers.losses:
      for index in np.ndindex(
          *loss.fisher_factor_inner_static_shape[1:]):
        transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
            index)
        grads_flat = gradients_impl.gradients(
            loss.inputs,
            tensors_to_compute_grads_flat,
            grad_ys=transformed_one_hot)
        grads_all.append(
            nest.pack_sequence_as(tensors_to_compute_grads, grads_flat))
    # Transpose from per-coordinate lists to per-block lists.
    grads_lists = zip(*grads_all)
  else:
    raise ValueError(
        "Unrecognized value {} for estimation_mode.".format(
            self._estimation_mode))
  for grads_list, fb in zip(grads_lists, fisher_blocks_list):
    fb.instantiate_factors(grads_list, damping)
  cov_updates = [
      factor.make_covariance_update_op(cov_ema_decay)
      for factor in self._layers.get_factors()
  ]
  # Keyed by Op name so callers can run individual inverse updates.
  inv_updates = {
      op.name: op
      for factor in self._layers.get_factors()
      for op in factor.make_inverse_update_ops()
  }
  return control_flow_ops.group(*cov_updates), control_flow_ops.group(
      *inv_updates.values()), inv_updates
def _setup(self, cov_ema_decay):
  """Sets up the various operations.

  Args:
    cov_ema_decay: The decay factor used when calculating the covariance
      estimate moving averages.

  Returns:
    A triple (covs_update_op, invs_update_op, inv_updates_dict), where
    covs_update_op is the grouped Op to update all the covariance estimates,
    invs_update_op is the grouped Op to update all the inverses, and
    inv_updates_dict is a dict mapping Op names to individual inverse updates.

  Raises:
    ValueError: If estimation_mode was improperly specified at construction.
  """
  damping = self.damping
  fisher_blocks_list = self._layers.get_blocks()
  # One (possibly nested) structure of tensors per block; flattened so a
  # single gradients() call can differentiate w.r.t. all of them at once,
  # then re-packed into the per-block nesting afterwards.
  tensors_to_compute_grads = [
      fb.tensors_to_compute_grads() for fb in fisher_blocks_list
  ]
  tensors_to_compute_grads_flat = nest.flatten(tensors_to_compute_grads)
  if self._estimation_mode == "gradients":
    # One gradient sample of the *sampled* (model-distribution) loss.
    grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(),
                                          tensors_to_compute_grads_flat)
    grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
    # Each block receives a 1-tuple: a single gradient sample.
    grads_lists = tuple((grad,) for grad in grads_all)
  elif self._estimation_mode == "empirical":
    # Same as above but using the actual (data-distribution) loss.
    grads_flat = gradients_impl.gradients(self._layers.total_loss(),
                                          tensors_to_compute_grads_flat)
    grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
    grads_lists = tuple((grad,) for grad in grads_all)
  elif self._estimation_mode == "curvature_prop":
    # Curvature propagation: backprop fisher-factor-transformed random sign
    # vectors through the loss inputs (passed via grad_ys).
    loss_inputs = list(loss.inputs for loss in self._layers.losses)
    loss_inputs_flat = nest.flatten(loss_inputs)
    transformed_random_signs = list(loss.multiply_fisher_factor(
        utils.generate_random_signs(loss.fisher_factor_inner_shape))
                                    for loss in self._layers.losses)
    transformed_random_signs_flat = nest.flatten(transformed_random_signs)
    grads_flat = gradients_impl.gradients(
        loss_inputs_flat,
        tensors_to_compute_grads_flat,
        grad_ys=transformed_random_signs_flat)
    grads_all = nest.pack_sequence_as(tensors_to_compute_grads, grads_flat)
    grads_lists = tuple((grad,) for grad in grads_all)
  elif self._estimation_mode == "exact":
    # Loop over all coordinates of all losses: one backward pass per
    # coordinate of each loss's fisher factor inner (static) shape, skipping
    # the leading (batch) dimension.  NOTE(review): exact but expensive —
    # one gradients() call per coordinate.
    grads_all = []
    for loss in self._layers.losses:
      for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]):
        transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot(
            index)
        grads_flat = gradients_impl.gradients(loss.inputs,
                                              tensors_to_compute_grads_flat,
                                              grad_ys=transformed_one_hot)
        grads_all.append(nest.pack_sequence_as(tensors_to_compute_grads,
                                               grads_flat))
    # Transpose from per-coordinate lists to per-block lists.
    grads_lists = zip(*grads_all)
  else:
    raise ValueError("Unrecognized value {} for estimation_mode.".format(
        self._estimation_mode))
  for grads_list, fb in zip(grads_lists, fisher_blocks_list):
    fb.instantiate_factors(grads_list, damping)
  cov_updates = [
      factor.make_covariance_update_op(cov_ema_decay)
      for factor in self._layers.get_factors()
  ]
  # Keyed by Op name so callers can run individual inverse updates.
  inv_updates = {
      op.name: op
      for factor in self._layers.get_factors()
      for op in factor.make_inverse_update_ops()
  }
  return control_flow_ops.group(*cov_updates), control_flow_ops.group(
      *inv_updates.values()), inv_updates