Exemplo n.º 1
0
    def testEstimatorInitManualRegistration(self):
        """Checks variable-registration validation of the round-robin estimator.

        Building an estimator for exactly the manually registered variables
        (``self.weights``) must not raise. Asking for a variable that was not
        registered (``self.bias``), or leaving out a registered one, must raise
        ``ValueError`` by the time ``make_ops_and_vars`` is called.
        """
        with self._graph.as_default():
            shared_kwargs = dict(
                cov_ema_decay=0.1,
                damping=0.2,
                layer_collection=self.layer_collection)

            # Only the registered variables: construction must succeed.
            # NOTE(review): make_ops_and_vars() is not called for this case, so
            # any validation deferred to that call is not exercised here.
            estimator.FisherEstimatorRoundRobin(
                variables=[self.weights], **shared_kwargs)

            # Both lists are invalid. The first adds an unregistered variable
            # (self.bias). The second omits the registered self.weights.
            invalid_variable_lists = ([self.weights, self.bias], [])
            for bad_variables in invalid_variable_lists:
                with self.assertRaises(ValueError):
                    est = estimator.FisherEstimatorRoundRobin(
                        variables=bad_variables, **shared_kwargs)
                    est.make_ops_and_vars()
Exemplo n.º 2
0
    def test_round_robin_placement(self):
        """Verifies that update ops and covariance variables follow the
        round-robin device assignment across two CPU devices."""
        expected_devices = ("/device:CPU:0", "/device:CPU:1")
        with self._graph.as_default():
            fisher_estimator = estimator.FisherEstimatorRoundRobin(
                variables=[self.weights],
                layer_collection=self.layer_collection,
                damping=0.2,
                cov_ema_decay=0.0,
                cov_devices=["/cpu:{}".format(idx) for idx in range(2)],
                inv_devices=["/cpu:{}".format(idx) for idx in range(2)])

            # Build all update ops under the "test" scope. Only the covariance
            # and inverse update ops are inspected.
            (cov_update_ops, _, inv_update_ops, _, _,
             _) = fisher_estimator.make_ops_and_vars(scope="test")

            # Update ops alternate between the two CPUs, for both kinds.
            for op_list in (cov_update_ops, inv_update_ops):
                for position, device in enumerate(expected_devices):
                    self.assertEqual(op_list[position].device, device)

            cov_matrices = [
                factor.get_cov()
                for factor in self.layer_collection.get_factors()
            ]
            inv_matrices = []
            for factor in self.layer_collection.get_factors():
                inv_matrices.extend(
                    factor._matpower_by_exp_and_damping.values())

            # Covariance variables follow the same round-robin placement.
            for position, device in enumerate(expected_devices):
                self.assertEqual(cov_matrices[position].device, device)

            # Inverse matrices get no device from the round-robin assignment.
            # They would have to be placed explicitly, so their device is unset.
            for position in range(2):
                self.assertEqual(inv_matrices[position].device, "")
Exemplo n.º 3
0
 def testVariableWrongNumberOfUses(self, mock_uses):
     """Building ops must raise ValueError when a variable's use count is wrong.

     Args:
       mock_uses: not read by this body. It is presumably injected by a
         ``mock.patch`` decorator that forces a bad use count; that decorator
         is not visible here (TODO confirm).
     """
     # Build inside self._graph, as the sibling tests do. self.weights
     # presumably lives in that graph. Building in the default graph could
     # raise a cross-graph ValueError, which would satisfy assertRaises for the
     # wrong reason.
     with self._graph.as_default():
         with self.assertRaises(ValueError):
             est = estimator.FisherEstimatorRoundRobin(
                 variables=[self.weights],
                 cov_ema_decay=0.1,
                 damping=0.2,
                 layer_collection=self.layer_collection)
             est.make_ops_and_vars()
Exemplo n.º 4
0
 def testInvalidEstimationMode(self):
     """An unknown ``estimation_mode`` must be rejected with ValueError."""
     # Build inside self._graph, as the sibling tests do. This ensures the
     # ValueError comes from the bogus estimation mode, not from using
     # self.weights (presumably created in self._graph) in a different graph.
     with self._graph.as_default():
         with self.assertRaises(ValueError):
             est = estimator.FisherEstimatorRoundRobin(
                 variables=[self.weights],
                 cov_ema_decay=0.1,
                 damping=0.2,
                 layer_collection=self.layer_collection,
                 estimation_mode="not_a_real_mode")
             est.make_ops_and_vars()
Exemplo n.º 5
0
 def testExactModeBuild(self):
     """Building ops and variables in 'exact' estimation mode must not fail."""
     with self._graph.as_default():
         exact_estimator = estimator.FisherEstimatorRoundRobin(
             layer_collection=self.layer_collection,
             variables=[self.weights],
             estimation_mode="exact",
             damping=0.2,
             cov_ema_decay=0.1)
         exact_estimator.make_ops_and_vars()
Exemplo n.º 6
0
 def testCurvaturePropModeBuild(self):
   """Creating variables and op thunks in 'curvature_prop' mode must not fail.

   NOTE(review): the sibling build tests call make_ops_and_vars(), while this
   one calls make_vars_and_create_op_thunks(). Confirm which API the estimator
   under test actually exposes.
   """
   with self._graph.as_default():
     curvature_prop_estimator = estimator.FisherEstimatorRoundRobin(
         layer_collection=self.layer_collection,
         variables=[self.weights],
         estimation_mode="curvature_prop",
         damping=0.2,
         cov_ema_decay=0.1)
     curvature_prop_estimator.make_vars_and_create_op_thunks()
Exemplo n.º 7
0
    def test_cov_update_thunks(self):
        """Ensures covariance update ops run once per global_step.

        Builds a ``tf.case`` op that runs the i-th covariance update thunk when
        ``global_step == i``. After each step, exactly one more covariance
        matrix must differ from its initial value. After the last step, every
        covariance matrix must have changed.
        """
        with self._graph.as_default(), self.test_session() as sess:
            fisher_estimator = estimator.FisherEstimatorRoundRobin(
                variables=[self.weights],
                layer_collection=self.layer_collection,
                damping=0.2,
                cov_ema_decay=0.0)

            # Construct an op that executes one covariance update per step.
            global_step = tf.train.get_or_create_global_step()
            (cov_variable_thunks, cov_update_op_thunks, _,
             _) = fisher_estimator.create_ops_and_vars_thunks()
            for thunk in cov_variable_thunks:
                thunk()
            cov_matrices = [
                fisher_factor.get_cov()
                for fisher_factor in self.layer_collection.get_factors()
            ]
            cov_update_op = tf.case([
                (tf.equal(global_step, i), thunk)
                for i, thunk in enumerate(cov_update_op_thunks)
            ])
            increment_global_step = global_step.assign_add(1)

            sess.run(tf.global_variables_initializer())
            initial_cov_values = sess.run(cov_matrices)

            # Ensure there's one update per covariance matrix.
            self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))

            # The round-robin check is vacuous with a single covariance matrix.
            # Use a unittest assertion, because a bare `assert` is stripped
            # under `python -O`.
            self.assertGreater(len(cov_matrices), 1)

            num_covs = len(cov_matrices)
            for step in range(num_covs):
                # Run the single covariance update scheduled for this
                # global_step, then advance global_step to select the next.
                sess.run(cov_update_op)
                sess.run(increment_global_step)

                # Compare new and old covariance values after the update.
                # Checking after (not before) each update also verifies the
                # final thunk, which the check-before form never observed.
                new_cov_values = sess.run(cov_matrices)
                num_cov_equal = sum(
                    np.allclose(initial_cov_value, new_cov_value)
                    for initial_cov_value, new_cov_value in zip(
                        initial_cov_values, new_cov_values))

                # Ensure exactly one more covariance matrix changed this step.
                self.assertEqual(num_cov_equal, num_covs - step - 1)
Exemplo n.º 8
0
    def test_inv_update_thunks(self):
        """Ensures inverse update ops run once per global_step.

        Builds a ``tf.case`` op that runs the i-th inverse update thunk when
        ``global_step == i``. Every covariance is first set to ``2 * I`` so
        that the inverse updates produce new values. After each step, exactly
        one more inverse matrix must differ from its initial value. After the
        last step, every inverse matrix must have changed.
        """
        with self._graph.as_default(), self.test_session() as sess:
            fisher_estimator = estimator.FisherEstimatorRoundRobin(
                variables=[self.weights],
                layer_collection=self.layer_collection,
                damping=0.2,
                cov_ema_decay=0.0)

            # Construct op that updates one inverse per global step.
            global_step = tf.train.get_or_create_global_step()
            (cov_variable_thunks, _, inv_variable_thunks, inv_update_op_thunks
             ) = fisher_estimator.create_ops_and_vars_thunks()
            for thunk in cov_variable_thunks:
                thunk()
            for thunk in inv_variable_thunks:
                thunk()
            inv_matrices = [
                matrix
                for fisher_factor in self.layer_collection.get_factors() for
                matrix in fisher_factor._matpower_by_exp_and_damping.values()
            ]
            inv_update_op = tf.case([
                (tf.equal(global_step, i), thunk)
                for i, thunk in enumerate(inv_update_op_thunks)
            ])
            increment_global_step = global_step.assign_add(1)

            sess.run(tf.global_variables_initializer())
            initial_inv_values = sess.run(inv_matrices)

            # Ensure there's one update per inverse matrix. This is true as long as
            # there's no fan-in/fan-out or parameter re-use.
            self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

            # The round-robin check is vacuous with a single inverse matrix.
            # Use a unittest assertion, because a bare `assert` is stripped
            # under `python -O`.
            self.assertGreater(len(inv_matrices), 1)

            # Assign each covariance matrix a value other than the identity. This
            # ensures that the inverse matrices are updated to something different as
            # well.
            cov_matrices = [
                fisher_factor.get_cov()
                for fisher_factor in self.layer_collection.get_factors()
            ]
            sess.run([
                cov_matrix.assign(2 * tf.eye(int(cov_matrix.shape[0])))
                for cov_matrix in cov_matrices
            ])

            num_invs = len(inv_matrices)
            for step in range(num_invs):
                # Run the single inverse update scheduled for this global_step,
                # then advance global_step to select the next.
                sess.run(inv_update_op)
                sess.run(increment_global_step)

                # Compare new and old inverse values after the update.
                # Checking after (not before) each update also verifies the
                # final thunk, which the check-before form never observed.
                new_inv_values = sess.run(inv_matrices)
                num_inv_equal = sum(
                    np.allclose(initial_inv_value, new_inv_value)
                    for initial_inv_value, new_inv_value in zip(
                        initial_inv_values, new_inv_values))

                # Ensure exactly one more inverse matrix changed this step.
                self.assertEqual(num_inv_equal, num_invs - step - 1)