def testEstimatorInitManualRegistration(self):
    """Checks FisherEstimator construction against the registered variables.

    Construction must succeed for exactly the registered variable and raise
    ValueError when given an extra, unregistered variable or when a
    registered variable is left out.
    """
    with self._graph.as_default():
        layers = self.layer_collection

        # Only the registered variable: construction is expected to work.
        estimator.FisherEstimator([self.weights], 0.1, 0.2, layers)

        # self.bias was not manually registered, so including it must fail.
        with_unregistered = [self.weights, self.bias]
        with self.assertRaises(ValueError):
            estimator.FisherEstimator(with_unregistered, 0.1, 0.2, layers)

        # Omitting a registered variable (self.weights) must fail as well.
        with self.assertRaises(ValueError):
            estimator.FisherEstimator([], 0.1, 0.2, layers)
def testExactModeBuild(self):
    """Smoke test: the estimator can be constructed in "exact" mode."""
    with self._graph.as_default():
        tracked_vars = [self.weights]
        estimator.FisherEstimator(
            tracked_vars,
            0.1,
            0.2,
            self.layer_collection,
            estimation_mode="exact",
        )
def testCurvaturePropModeBuild(self):
    """Smoke test: the estimator can be constructed in "curvature_prop" mode."""
    with self._graph.as_default():
        tracked_vars = [self.weights]
        estimator.FisherEstimator(
            tracked_vars,
            0.1,
            0.2,
            self.layer_collection,
            estimation_mode="curvature_prop",
        )
def testInvalidEstimationMode(self):
    """An unrecognized estimation_mode must be rejected with ValueError."""
    bogus_mode = "not_a_real_mode"
    with self.assertRaises(ValueError):
        estimator.FisherEstimator(
            [self.weights],
            0.1,
            0.2,
            self.layer_collection,
            estimation_mode=bogus_mode,
        )
def test_inv_update_thunks(self):
    """Ensures inverse update ops run once per global_step."""
    with self._graph.as_default(), self.test_session() as sess:
        fisher_estimator = estimator.FisherEstimator(
            damping_fn=lambda: 0.2,
            variables=[self.weights],
            layer_collection=self.layer_collection,
            cov_ema_decay=0.0)

        # Construct op that updates one inverse per global step: the `case`
        # op picks the thunk whose index equals the current global_step.
        global_step = training_util.get_or_create_global_step()
        inv_matrices = [
            matrix
            for fisher_factor in self.layer_collection.get_factors()
            for matrix in fisher_factor._inverses_by_damping.values()
        ]
        inv_update_op_thunks = fisher_estimator.inv_update_thunks
        inv_update_op = control_flow_ops.case(
            [(math_ops.equal(global_step, i), thunk)
             for i, thunk in enumerate(inv_update_op_thunks)])
        increment_global_step = global_step.assign_add(1)

        sess.run(variables.global_variables_initializer())
        initial_inv_values = sess.run(inv_matrices)

        # Ensure there's one update per inverse matrix. This is true as long
        # as there's no fan-in/fan-out or parameter re-use.
        self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

        # The test is a no-op if there is only 1 inverse matrix. Use a
        # TestCase assertion rather than a bare `assert` so the precondition
        # is still enforced when Python runs with -O.
        self.assertGreater(len(inv_matrices), 1)

        # Assign each covariance matrix a value other than the identity. This
        # ensures that the inverse matrices are updated to something
        # different as well.
        cov_matrices = [
            fisher_factor.get_cov()
            for fisher_factor in self.layer_collection.get_factors()
        ]
        sess.run([
            cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
            for cov_matrix in cov_matrices
        ])

        for num_updates_run in range(len(inv_matrices)):
            # Compare new and old inverse values.
            new_inv_values = sess.run(inv_matrices)
            is_inv_equal = [
                np.allclose(initial_inv_value, new_inv_value)
                for (initial_inv_value, new_inv_value)
                in zip(initial_inv_values, new_inv_values)
            ]
            num_inv_equal = sum(is_inv_equal)

            # Ensure exactly one inverse matrix changes per step.
            self.assertEqual(num_inv_equal,
                             len(inv_matrices) - num_updates_run)

            # Run the single inverse update selected by the current
            # global_step, then advance the step.
            sess.run(inv_update_op)
            sess.run(increment_global_step)
def testEmpiricalModeBuild(self):
    """Smoke test: builds ops and vars for an "empirical"-mode estimator."""
    with self._graph.as_default():
        fisher_est = estimator.FisherEstimator(
            [self.weights],
            0.1,
            0.2,
            self.layer_collection,
            estimation_mode="empirical",
        )
        fisher_est.make_ops_and_vars()
def test_cov_update_thunks(self):
    """Ensures covariance update ops run once per global_step."""
    with self._graph.as_default(), self.test_session() as sess:
        fisher_estimator = estimator.FisherEstimator(
            variables=[self.weights],
            layer_collection=self.layer_collection,
            damping=0.2,
            cov_ema_decay=0.0)

        # Construct an op that executes one covariance update per step: the
        # `case` op picks the thunk whose index equals the current
        # global_step.
        global_step = training_util.get_or_create_global_step()
        (cov_variable_thunks, cov_update_op_thunks, _,
         _) = fisher_estimator.create_ops_and_vars_thunks()
        # The variable thunks are invoked before the covariance matrices are
        # looked up below.
        for thunk in cov_variable_thunks:
            thunk()
        cov_matrices = [
            fisher_factor.get_cov()
            for fisher_factor in self.layer_collection.get_factors()
        ]
        cov_update_op = control_flow_ops.case([
            (math_ops.equal(global_step, i), thunk)
            for i, thunk in enumerate(cov_update_op_thunks)
        ])
        increment_global_step = global_step.assign_add(1)

        sess.run(variables.global_variables_initializer())
        initial_cov_values = sess.run(cov_matrices)

        # Ensure there's one update per covariance matrix.
        self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))

        # The test is a no-op if there is only 1 covariance matrix. Use a
        # TestCase assertion rather than a bare `assert` so the precondition
        # is still enforced when Python runs with -O.
        self.assertGreater(len(cov_matrices), 1)

        for num_updates_run in range(len(cov_matrices)):
            # Compare new and old covariance values.
            new_cov_values = sess.run(cov_matrices)
            is_cov_equal = [
                np.allclose(initial_cov_value, new_cov_value)
                for (initial_cov_value, new_cov_value)
                in zip(initial_cov_values, new_cov_values)
            ]
            num_cov_equal = sum(is_cov_equal)

            # Ensure exactly one covariance matrix changes per step.
            self.assertEqual(num_cov_equal,
                             len(cov_matrices) - num_updates_run)

            # Run the single covariance update selected by the current
            # global_step, then advance the step.
            sess.run(cov_update_op)
            sess.run(increment_global_step)
def testEstimatorInitManualRegistration(self):
    """Checks estimator construction when only the weights are registered.

    Builds a small matmul-plus-bias graph in a fresh Graph, registers only
    the weight variable with the layer collection, and verifies that the
    estimator accepts the weights alone but raises ValueError when the
    unregistered bias is included.
    """
    with ops.Graph().as_default():
        registry = lc.LayerCollection()

        x = random_ops.random_normal((2, 2), dtype=dtypes.float32)
        w = variable_scope.get_variable(
            'w', shape=(2, 2), dtype=dtypes.float32)
        b = variable_scope.get_variable(
            'b', initializer=init_ops.zeros_initializer(), shape=(2, 1))
        pre_activation = math_ops.matmul(x, w) + b

        # Only the weights are registered; the bias is deliberately left out.
        registry.register_fully_connected((w,), x, pre_activation)
        activation = math_ops.tanh(pre_activation)
        registry.register_categorical_predictive_distribution(activation)

        # Building an estimator for just the registered variable must work.
        estimator.FisherEstimator([w], 0.1, 0.2, registry)

        # Including the bias, which was not manually registered, must fail.
        with self.assertRaises(ValueError):
            estimator.FisherEstimator([w, b], 0.1, 0.2, registry)
def __init__(self,
             learning_rate=0.001,
             decay=0.9,
             epsilon=1e-10,
             damping=0.001,
             cov_ema_decay=0.95,
             lrdecay=0.96,
             decay_interval=50,
             layer_collection=None,
             estimation_mode='gradients',
             colocate_gradient_with_ops=True,
             use_locking=False,
             name="kSGLDOpt"):
    """Stores the hyper-parameters and builds the Fisher estimator.

    Args:
      learning_rate: Stored as `self._lr`.
      decay: Stored as `self._decay`.
      epsilon: Stored as `self._epsilon`.
      damping: Value returned by `self.damping_fn`, the callable handed to
        the FisherEstimator.
      cov_ema_decay: Forwarded to the FisherEstimator.
      lrdecay: Stored as `self._lrdecay`.
      decay_interval: Stored as `self._decay_interval`.
      layer_collection: Forwarded to the FisherEstimator.
      estimation_mode: Forwarded to the FisherEstimator.
      colocate_gradient_with_ops: Forwarded to the FisherEstimator.
      use_locking: Passed to the base-class constructor.
      name: Passed to the base-class constructor.
    """
    super(kSGLDOpt, self).__init__(use_locking, name)

    # Plain Python copies of the scalar hyper-parameters.
    self._lr, self._decay, self._epsilon = learning_rate, decay, epsilon
    self._lrdecay, self._decay_interval = lrdecay, decay_interval

    # The estimator is built over every trainable variable currently
    # returned by tf_variables.trainable_variables().
    self._variables = tf_variables.trainable_variables()

    # Settings that are forwarded to the Fisher estimator below.
    self.damping_fn = lambda: damping
    self.cov_ema_decay = cov_ema_decay
    self.layer_collection = layer_collection
    self.estimation_mode = estimation_mode
    self.colocate_gradient_with_ops = colocate_gradient_with_ops

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = self._decay_t = self._epsilon_t = None

    # NOTE(review): arguments are passed positionally; confirm the order
    # matches the FisherEstimator signature in use.
    self._fisher_est = est.FisherEstimator(
        self.damping_fn,
        self._variables,
        self.cov_ema_decay,
        self.layer_collection,
        self.estimation_mode,
        self.colocate_gradient_with_ops,
    )
def testVariableWrongNumberOfUses(self, mock_uses):
    """Building or materializing the estimator must raise ValueError.

    NOTE(review): `mock_uses` is unused in the body; presumably it is
    injected by a patch decorator outside this view -- confirm.
    """
    with self.assertRaises(ValueError):
        fisher_est = estimator.FisherEstimator(
            [self.weights],
            0.1,
            0.2,
            self.layer_collection,
        )
        fisher_est.make_ops_and_vars()
def __init__(self,
             learning_rate,
             cov_ema_decay,
             damping,
             layer_collection,
             var_list=None,
             momentum=0.9,
             momentum_type="regular",
             norm_constraint=None,
             name="KFAC",
             estimation_mode="gradients",
             colocate_gradients_with_ops=True,
             batch_size=None,
             cov_devices=None,
             inv_devices=None):
    """Initializes the KFAC optimizer with the given settings.

    Args:
      learning_rate: The base learning rate for the optimizer.  Should
        probably be set to 1.0 when using momentum_type = 'qmodel', but can
        still be lowered if desired (effectively lowering the trust in the
        quadratic model.)
      cov_ema_decay: The decay factor used when calculating the covariance
        estimate moving averages.
      damping: The damping factor used to stabilize training due to errors
        in the local approximation with the Fisher information matrix, and
        to regularize the update direction by making it closer to the
        gradient. If damping is adapted during training then this value is
        used for initializing the damping variable.
        (Higher damping means the update looks more like a standard
        gradient update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the fisher
        blocks, kronecker factors, and losses associated with the
        graph.  The layer_collection cannot be modified after
        KfacOptimizer's initialization.
      var_list: Optional list or tuple of variables to train. Defaults to
        the list of variables collected in the graph under the key
        `GraphKeys.TRAINABLE_VARIABLES`.
      momentum: The momentum decay constant to use. Only applies when
        momentum_type is 'regular' or 'adam'. (Default: 0.9)
      momentum_type: The type of momentum to use in this optimizer, one of
        'regular', 'adam', or 'qmodel'. Matched case-insensitively.
        (Default: 'regular')
      norm_constraint: float or Tensor. If specified, the update is scaled
        down so that its approximate squared Fisher norm v^T F v is at most
        the specified value. May only be used with momentum type 'regular'.
        (Default: None)
      name: The name for this optimizer. (Default: 'KFAC')
      estimation_mode: The type of estimator to use for the Fishers.  Can be
        'gradients', 'empirical', 'curvature_prop', or 'exact'.
        (Default: 'gradients'). See the doc-string for FisherEstimator for
        a more detailed description of these options.
      colocate_gradients_with_ops: Whether we should request gradients we
        compute in the estimator be colocated with their respective ops.
        (Default: True)
      batch_size: The size of the mini-batch. Only needed when
        momentum_type == 'qmodel' or when automatic adjustment is used.
        (Default: None)
      cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
        computations will be placed on these devices in a round-robin
        fashion. Can be None, which means that no devices are specified.
        Only used with the (soon-to-be-deprecated) "convenience" properties.
      inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
        computations will be placed on these devices in a round-robin
        fashion. Can be None, which means that no devices are specified.
        Only used with the (soon-to-be-deprecated) "convenience" properties.

    Raises:
      ValueError: If the momentum type is unsupported.
      ValueError: If clipping is used with momentum type other than
        'regular'.
      ValueError: If no losses have been registered with layer_collection.
      ValueError: If momentum is non-zero and momentum_type is not
        'regular' or 'adam'.
    """
    variables = var_list
    if variables is None:
        variables = tf_variables.trainable_variables()

    # Parameters to be passed to the Fisher estimator:
    self._variables = variables
    self._cov_ema_decay = cov_ema_decay
    self._layers = layer_collection
    self._estimation_mode = estimation_mode
    self._colocate_gradients_with_ops = colocate_gradients_with_ops
    self._cov_devices = cov_devices
    self._inv_devices = inv_devices

    # The below parameters are required only if damping needs to be
    # adapted. These parameters can be set by calling
    # set_damping_adaptation_params() explicitly.
    self._damping_adaptation_decay = 0.95
    self._damping_adaptation_interval = 5
    # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval)
    self._omega = (
        self._damping_adaptation_decay**self._damping_adaptation_interval)
    self._adapt_damping = False
    self._min_damping = 1e-5
    self._prev_train_batch = None
    self._is_chief = False
    self._loss_fn = None
    self._damping_constant = damping
    self._damping = None
    self._rho = None
    self._prev_loss = None
    self._q_model_change = None
    self._update_damping_op = None

    momentum_type = momentum_type.lower()
    legal_momentum_types = ["regular", "adam", "qmodel"]

    if momentum_type not in legal_momentum_types:
        raise ValueError(
            "Unsupported momentum type {}. Must be one of {}.".format(
                momentum_type, legal_momentum_types))
    if momentum_type != "regular" and norm_constraint is not None:
        # The adjacent literals are concatenated as-is, so the separating
        # space must be part of the first literal.
        raise ValueError("Update clipping is only supported with momentum "
                         "type 'regular'.")
    if momentum_type not in ["regular", "adam"] and momentum != 0:
        raise ValueError(
            "Momentum must be unspecified if using a momentum_type "
            "other than 'regular' or 'adam'.")

    # Extra parameters of the optimizer
    self._momentum = momentum
    self._momentum_type = momentum_type
    self._norm_constraint = norm_constraint
    self._batch_size = batch_size

    # The estimator is constructed inside a variable scope named after the
    # optimizer.
    with variable_scope.variable_scope(name):
        self._fisher_est = est.FisherEstimator(
            self._variables,
            self._cov_ema_decay,
            self.damping,
            self._layers,
            exps=(-1,),
            estimation_mode=self._estimation_mode,
            colocate_gradients_with_ops=self._colocate_gradients_with_ops)

    super(KfacOptimizer, self).__init__(learning_rate, name=name)
def testAllModesBuild(self):
    """Smoke test: an estimator can be constructed for every listed mode."""
    for estimation_mode in _ALL_ESTIMATION_MODES:
        with self._graph.as_default():
            # The mode is passed positionally, as the fifth argument.
            estimator.FisherEstimator(
                [self.weights],
                0.1,
                0.2,
                self.layer_collection,
                estimation_mode,
            )
def testModeListCorrect(self):
    """The estimator's gradient-fn keys must match _ALL_ESTIMATION_MODES."""
    with self._graph.as_default():
        fisher_est = estimator.FisherEstimator(
            [self.weights],
            0.1,
            0.2,
            self.layer_collection,
        )
        supported_modes = fisher_est._gradient_fns.keys()
        self.assertItemsEqual(_ALL_ESTIMATION_MODES, supported_modes)
def __init__(self,
             learning_rate,
             cov_ema_decay,
             damping,
             layer_collection,
             var_list=None,
             momentum=0.9,
             momentum_type="regular",
             norm_constraint=None,
             name="KFAC",
             estimation_mode="gradients",
             colocate_gradients_with_ops=True,
             cov_devices=None,
             inv_devices=None):
    """Initializes the KFAC optimizer with the given settings.

    Args:
      learning_rate: The base learning rate for the optimizer.  Should
        probably be set to 1.0 when using momentum_type = 'qmodel', but can
        still be lowered if desired (effectively lowering the trust in the
        quadratic model.)
      cov_ema_decay: The decay factor used when calculating the covariance
        estimate moving averages.
      damping: The damping factor used to stabilize training due to errors
        in the local approximation with the Fisher information matrix, and
        to regularize the update direction by making it closer to the
        gradient. (Higher damping means the update looks more like a
        standard gradient update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the fisher
        blocks, kronecker factors, and losses associated with the
        graph.  The layer_collection cannot be modified after
        KfacOptimizer's initialization.
      var_list: Optional list or tuple of variables to train. Defaults to
        the list of variables collected in the graph under the key
        `GraphKeys.TRAINABLE_VARIABLES`.
      momentum: The momentum decay constant to use. Only applies when
        momentum_type is 'regular' or 'adam'. (Default: 0.9)
      momentum_type: The type of momentum to use in this optimizer, one of
        'regular', 'adam', or 'qmodel'. Matched case-insensitively.
        (Default: 'regular')
      norm_constraint: float or Tensor. If specified, the update is scaled
        down so that its approximate squared Fisher norm v^T F v is at most
        the specified value. May only be used with momentum type 'regular'.
        (Default: None)
      name: The name for this optimizer. (Default: 'KFAC')
      estimation_mode: The type of estimator to use for the Fishers.  Can be
        'gradients', 'empirical', 'curvature_prop', or 'exact'.
        (Default: 'gradients'). See the doc-string for FisherEstimator for
        a more detailed description of these options.
      colocate_gradients_with_ops: Whether we should request gradients we
        compute in the estimator be colocated with their respective ops.
        (Default: True)
      cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
        computations will be placed on these devices in a round-robin
        fashion. Can be None, which means that no devices are specified.
      inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
        computations will be placed on these devices in a round-robin
        fashion. Can be None, which means that no devices are specified.

    Raises:
      ValueError: If the momentum type is unsupported.
      ValueError: If clipping is used with momentum type other than
        'regular'.
      ValueError: If no losses have been registered with layer_collection.
      ValueError: If momentum is non-zero and momentum_type is not
        'regular' or 'adam'.
    """
    variables = var_list
    if variables is None:
        variables = tf_variables.trainable_variables()

    # The estimator is built before the momentum arguments are validated,
    # so any error it raises takes precedence over the checks below.
    self._fisher_est = est.FisherEstimator(
        variables,
        cov_ema_decay,
        damping,
        layer_collection,
        estimation_mode=estimation_mode,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        cov_devices=cov_devices,
        inv_devices=inv_devices)

    momentum_type = momentum_type.lower()
    legal_momentum_types = ["regular", "adam", "qmodel"]

    if momentum_type not in legal_momentum_types:
        raise ValueError(
            "Unsupported momentum type {}. Must be one of {}.".format(
                momentum_type, legal_momentum_types))
    if momentum_type != "regular" and norm_constraint is not None:
        # The adjacent literals are concatenated as-is, so the separating
        # space must be part of the first literal.
        raise ValueError("Update clipping is only supported with momentum "
                         "type 'regular'.")
    if momentum_type not in ["regular", "adam"] and momentum != 0:
        raise ValueError(
            "Momentum must be unspecified if using a momentum_type "
            "other than 'regular' or 'adam'.")

    self._momentum = momentum
    self._momentum_type = momentum_type
    self._norm_constraint = norm_constraint

    # this is a bit of a hack
    # TODO(duckworthd): Handle this in a better way (e.g. pass it in?)
    # The batch size is read from the leading dimension of the first
    # registered loss's inputs.
    self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0]
    self._losses = layer_collection.losses

    # Re-export the estimator's update ops as attributes of the optimizer.
    self.cov_update_op = self._fisher_est.cov_update_op
    self.inv_update_op = self._fisher_est.inv_update_op
    self.inv_updates_dict = self._fisher_est.inv_updates_dict

    super(KfacOptimizer, self).__init__(learning_rate, name=name)
def __init__(
        self,
        learning_rate,
        cov_ema_decay,
        damping,
        layer_collection,
        momentum=0.,
        momentum_type="regular",
        norm_constraint=None,
        name="KFAC",
):
    """Initializes the KFAC optimizer with the given settings.

    Args:
      learning_rate: The base learning rate for the optimizer.  Should
        probably be set to 1.0 when using momentum_type = 'qmodel', but can
        still be lowered if desired (effectively lowering the trust in the
        quadratic model.)
      cov_ema_decay: The decay factor used when calculating the covariance
        estimate moving averages.
      damping: The damping factor used to stabilize training due to errors
        in the local approximation with the Fisher information matrix, and
        to regularize the update direction by making it closer to the
        gradient. (Higher damping means the update looks more like a
        standard gradient update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the fisher
        blocks, kronecker factors, and losses associated with the
        graph.  The layer_collection cannot be modified after
        KfacOptimizer's initialization.
      momentum: The momentum value for this optimizer. Only applies when
        momentum_type is 'regular' or 'adam'. (Default: 0)
      momentum_type: The type of momentum to use in this optimizer, one of
        'regular', 'adam', or 'qmodel'. Matched case-insensitively.
        (Default: 'regular')
      norm_constraint: float or Tensor. If specified, the update is scaled
        down so that its approximate squared Fisher norm v^T F v is at most
        the specified value. May only be used with momentum type 'regular'.
        (Default: None)
      name: The name for this optimizer. (Default: 'KFAC')

    Raises:
      ValueError: If the momentum type is unsupported.
      ValueError: If clipping is used with momentum type other than
        'regular'.
      ValueError: If no losses have been registered with layer_collection.
      ValueError: If momentum is non-zero and momentum_type is not
        'regular' or 'adam'.
    """
    # We may consider determining the set of variables some other way, but
    # for now it's just all the trainable variables.
    variables = tf_variables.trainable_variables()

    # The estimator is built before the momentum arguments are validated,
    # so any error it raises takes precedence over the checks below.
    self._fisher_est = est.FisherEstimator(variables, cov_ema_decay, damping,
                                           layer_collection)

    momentum_type = momentum_type.lower()
    legal_momentum_types = ["regular", "adam", "qmodel"]

    if momentum_type not in legal_momentum_types:
        raise ValueError(
            "Unsupported momentum type {}. Must be one of {}.".format(
                momentum_type, legal_momentum_types))
    if momentum_type != "regular" and norm_constraint is not None:
        # The adjacent literals are concatenated as-is, so the separating
        # space must be part of the first literal.
        raise ValueError("Update clipping is only supported with momentum "
                         "type 'regular'.")
    if momentum_type not in ["regular", "adam"] and momentum != 0:
        raise ValueError(
            "Momentum must be unspecified if using a momentum_type "
            "other than 'regular' or 'adam'.")

    self._momentum = ops.convert_to_tensor(momentum, name="momentum")
    self._momentum_type = momentum_type
    self._norm_constraint = norm_constraint

    # this is a bit of a hack
    # TODO(duckworthd): Handle this in a better way (e.g. pass it in?)
    # The batch size is read from the leading dimension of the first
    # registered loss's inputs.
    self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0]
    self._losses = layer_collection.losses

    # Re-export the estimator's update ops as attributes of the optimizer.
    self.cov_update_op = self._fisher_est.cov_update_op
    self.inv_update_op = self._fisher_est.inv_update_op
    self.inv_updates_dict = self._fisher_est.inv_updates_dict

    super(KfacOptimizer, self).__init__(learning_rate, name=name)
def testVariableWrongNumberOfUses(self, mock_uses):
    """Constructing the estimator must raise ValueError.

    NOTE(review): `mock_uses` is unused in the body; presumably it is
    injected by a patch decorator outside this view -- confirm.
    """
    constant_damping = lambda: 0.2
    with self.assertRaises(ValueError):
        estimator.FisherEstimator(
            constant_damping,
            [self.weights],
            0.1,
            self.layer_collection,
        )