def testEstimatorInitManualRegistration(self):
    """Checks estimator construction against the manually registered vars.

    Constructing an estimator for exactly the registered variable
    (``self.weights``) must not raise. Passing an extra variable
    (``self.bias``) or omitting the registered one is expected to raise
    ``ValueError`` during construction or ``make_ops_and_vars()``.
    """
    # Arguments shared by every estimator built in this test.
    shared_kwargs = dict(
        cov_ema_decay=0.1,
        damping=0.2,
        layer_collection=self.layer_collection)

    with self._graph.as_default():
        # Exactly the registered variable: construction must succeed.
        estimator.FisherEstimatorRoundRobin(
            variables=[self.weights], **shared_kwargs)

        # An extra variable that was not manually registered must be
        # rejected.
        with self.assertRaises(ValueError):
            extra_var_est = estimator.FisherEstimatorRoundRobin(
                variables=[self.weights, self.bias], **shared_kwargs)
            extra_var_est.make_ops_and_vars()

        # Leaving out the registered variable (self.weights) must be
        # rejected as well.
        with self.assertRaises(ValueError):
            no_var_est = estimator.FisherEstimatorRoundRobin(
                variables=[], **shared_kwargs)
            no_var_est.make_ops_and_vars()
def test_round_robin_placement(self):
    """Verifies device placement of round-robin update ops and variables.

    Two CPU devices are requested for both covariance and inverse updates.
    The first two covariance/inverse update ops and the first two
    covariance matrices must land on CPU:0 and CPU:1 respectively, while
    the first two inverse matrices must carry no device assignment.
    """
    expected_devices = ["/device:CPU:0", "/device:CPU:1"]

    with self._graph.as_default():
        requested_devices = ["/cpu:{}".format(i) for i in range(2)]
        fisher_estimator = estimator.FisherEstimatorRoundRobin(
            variables=[self.weights],
            layer_collection=self.layer_collection,
            damping=0.2,
            cov_ema_decay=0.0,
            # Hand each argument its own list object, as the original did.
            cov_devices=list(requested_devices),
            inv_devices=list(requested_devices))

        # Build the update ops and variables under the "test" scope; only
        # the per-factor covariance and inverse update ops are needed here.
        ops_and_vars = fisher_estimator.make_ops_and_vars(scope="test")
        cov_update_ops, _, inv_update_ops, _, _, _ = ops_and_vars

        for update_ops in (cov_update_ops, inv_update_ops):
            for position, device in enumerate(expected_devices):
                self.assertEqual(update_ops[position].device, device)

        cov_matrices = []
        for factor in self.layer_collection.get_factors():
            cov_matrices.append(factor.get_cov())

        inv_matrices = []
        for factor in self.layer_collection.get_factors():
            inv_matrices.extend(
                factor._matpower_by_exp_and_damping.values())

        for position, device in enumerate(expected_devices):
            self.assertEqual(cov_matrices[position].device, device)

        # Inverse matrices need to be explicitly placed, so here they are
        # expected to report an empty device string.
        for position in range(2):
            self.assertEqual(inv_matrices[position].device, "")
def testVariableWrongNumberOfUses(self, mock_uses):
    """Expects ``ValueError`` when building the estimator in this test.

    Args:
        mock_uses: Not used in the body. NOTE(review): presumably injected
            by a mock-patch decorator outside this view that fakes the
            number of uses of a variable -- confirm.
    """
    # Build inside the test graph, as the other build tests in this file
    # do. Otherwise a ValueError caused by mixing graphs could satisfy
    # assertRaises instead of the error this test is meant to provoke.
    # (Assumes self.weights lives in self._graph -- confirm in setUp.)
    with self._graph.as_default():
        with self.assertRaises(ValueError):
            est = estimator.FisherEstimatorRoundRobin(
                variables=[self.weights],
                cov_ema_decay=0.1,
                damping=0.2,
                layer_collection=self.layer_collection)
            est.make_ops_and_vars()
def testInvalidEstimationMode(self):
    """Expects ``ValueError`` for an unrecognized ``estimation_mode``.

    The estimator is constructed with ``estimation_mode="not_a_real_mode"``
    and either construction or ``make_ops_and_vars()`` must raise.
    """
    # Build inside the test graph, as the other build tests in this file
    # do. Otherwise a ValueError caused by mixing graphs could satisfy
    # assertRaises instead of the invalid-mode error under test.
    # (Assumes self.weights lives in self._graph -- confirm in setUp.)
    with self._graph.as_default():
        with self.assertRaises(ValueError):
            est = estimator.FisherEstimatorRoundRobin(
                variables=[self.weights],
                cov_ema_decay=0.1,
                damping=0.2,
                layer_collection=self.layer_collection,
                estimation_mode="not_a_real_mode")
            est.make_ops_and_vars()
def testExactModeBuild(self):
    """Builds ops and variables with ``estimation_mode="exact"``.

    There are no explicit assertions; the test passes when construction
    and ``make_ops_and_vars()`` complete without raising.
    """
    with self._graph.as_default():
        exact_estimator = estimator.FisherEstimatorRoundRobin(
            estimation_mode="exact",
            layer_collection=self.layer_collection,
            variables=[self.weights],
            damping=0.2,
            cov_ema_decay=0.1)
        exact_estimator.make_ops_and_vars()
def testCurvaturePropModeBuild(self):
    """Builds variables and op thunks with ``"curvature_prop"`` mode.

    There are no explicit assertions; the test passes when construction
    and ``make_vars_and_create_op_thunks()`` complete without raising.
    """
    with self._graph.as_default():
        curvature_prop_estimator = estimator.FisherEstimatorRoundRobin(
            estimation_mode="curvature_prop",
            layer_collection=self.layer_collection,
            variables=[self.weights],
            damping=0.2,
            cov_ema_decay=0.1)
        curvature_prop_estimator.make_vars_and_create_op_thunks()
def test_cov_update_thunks(self):
    """Ensures covariance update ops run once per global_step.

    A ``tf.case`` op dispatches on ``global_step`` so that step ``i`` runs
    the ``i``-th covariance update thunk. After ``i`` steps, exactly ``i``
    covariance matrices must differ from their initial values.
    """
    with self._graph.as_default(), self.test_session() as sess:
        fisher_estimator = estimator.FisherEstimatorRoundRobin(
            variables=[self.weights],
            layer_collection=self.layer_collection,
            damping=0.2,
            cov_ema_decay=0.0)

        # Construct an op that executes one covariance update per step.
        global_step = tf.train.get_or_create_global_step()
        (cov_variable_thunks, cov_update_op_thunks, _,
         _) = fisher_estimator.create_ops_and_vars_thunks()
        # The covariance variables must exist before get_cov() is called.
        for thunk in cov_variable_thunks:
            thunk()
        cov_matrices = [
            fisher_factor.get_cov()
            for fisher_factor in self.layer_collection.get_factors()
        ]
        cov_update_op = tf.case([
            (tf.equal(global_step, i), thunk)
            for i, thunk in enumerate(cov_update_op_thunks)
        ])
        increment_global_step = global_step.assign_add(1)

        sess.run(tf.global_variables_initializer())
        initial_cov_values = sess.run(cov_matrices)

        # Ensure there's one update per covariance matrix.
        self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))

        # The test is vacuous with a single covariance matrix. Use a
        # TestCase assertion rather than a bare `assert`, which is
        # stripped when Python runs with -O.
        self.assertGreater(len(cov_matrices), 1)

        for i in range(len(cov_matrices)):
            # Compare new and old covariance values.
            new_cov_values = sess.run(cov_matrices)
            is_cov_equal = [
                np.allclose(initial_cov_value, new_cov_value)
                for initial_cov_value, new_cov_value in zip(
                    initial_cov_values, new_cov_values)
            ]
            num_cov_equal = sum(is_cov_equal)

            # Ensure exactly one covariance matrix changes per step:
            # after i steps, i matrices differ from their initial value.
            self.assertEqual(num_cov_equal, len(cov_matrices) - i)

            # Run the covariance update selected by the current
            # global_step, then advance to the next step.
            sess.run(cov_update_op)
            sess.run(increment_global_step)
def test_inv_update_thunks(self):
    """Ensures inverse update ops run once per global_step.

    A ``tf.case`` op dispatches on ``global_step`` so that step ``i`` runs
    the ``i``-th inverse update thunk. After ``i`` steps, exactly ``i``
    inverse matrices must differ from their initial values.
    """
    with self._graph.as_default(), self.test_session() as sess:
        fisher_estimator = estimator.FisherEstimatorRoundRobin(
            variables=[self.weights],
            layer_collection=self.layer_collection,
            damping=0.2,
            cov_ema_decay=0.0)

        # Construct op that updates one inverse per global step.
        global_step = tf.train.get_or_create_global_step()
        (cov_variable_thunks, _, inv_variable_thunks,
         inv_update_op_thunks) = (
             fisher_estimator.create_ops_and_vars_thunks())
        # Create the covariance variables first, then the inverse ones.
        for thunk in cov_variable_thunks:
            thunk()
        for thunk in inv_variable_thunks:
            thunk()
        inv_matrices = [
            matrix
            for fisher_factor in self.layer_collection.get_factors()
            for matrix in
            fisher_factor._matpower_by_exp_and_damping.values()
        ]
        inv_update_op = tf.case([
            (tf.equal(global_step, i), thunk)
            for i, thunk in enumerate(inv_update_op_thunks)
        ])
        increment_global_step = global_step.assign_add(1)

        sess.run(tf.global_variables_initializer())
        initial_inv_values = sess.run(inv_matrices)

        # Ensure there's one update per inverse matrix. This is true as
        # long as there's no fan-in/fan-out or parameter re-use.
        self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

        # The test is vacuous with a single inverse matrix. Use a
        # TestCase assertion rather than a bare `assert`, which is
        # stripped when Python runs with -O.
        self.assertGreater(len(inv_matrices), 1)

        # Assign each covariance matrix a value other than the identity.
        # This ensures that the inverse matrices are updated to something
        # different as well.
        cov_matrices = [
            fisher_factor.get_cov()
            for fisher_factor in self.layer_collection.get_factors()
        ]
        sess.run([
            cov_matrix.assign(2 * tf.eye(int(cov_matrix.shape[0])))
            for cov_matrix in cov_matrices
        ])

        for i in range(len(inv_matrices)):
            # Compare new and old inverse values.
            new_inv_values = sess.run(inv_matrices)
            is_inv_equal = [
                np.allclose(initial_inv_value, new_inv_value)
                for initial_inv_value, new_inv_value in zip(
                    initial_inv_values, new_inv_values)
            ]
            num_inv_equal = sum(is_inv_equal)

            # Ensure exactly one inverse matrix changes per step:
            # after i steps, i matrices differ from their initial value.
            self.assertEqual(num_inv_equal, len(inv_matrices) - i)

            # Run the inverse update selected by the current global_step,
            # then advance to the next step.
            sess.run(inv_update_op)
            sess.run(increment_global_step)