Example #1
  def testEstimatorInitManualRegistration(self):
    with self._graph.as_default():
      # We should be able to build an estimator for only the registered vars.
      estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection)

      # Check that we throw an error if we try to build an estimator for vars
      # that were not manually registered.
      with self.assertRaises(ValueError):
        estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2,
                                  self.layer_collection)

      # Check that we throw an error if we don't include registered variables,
      # i.e. self.weights
      with self.assertRaises(ValueError):
        estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection)
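A note on the positional arguments above: these snippets come from different snapshots of the library, but in the snapshot whose keyword call appears in Example #14 the order is (variables, cov_ema_decay, damping, layer_collection). Under that assumption, a keyword-argument equivalent of the call in this test would read roughly:

# Hedged keyword-argument reading of the positional call above, assuming the
# signature shown in Example #14; 0.1 and 0.2 are the test's own placeholders.
estimator.FisherEstimator(
    variables=[self.weights],
    cov_ema_decay=0.1,
    damping=0.2,
    layer_collection=self.layer_collection)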
Example #2
 def testExactModeBuild(self):
     with self._graph.as_default():
         estimator.FisherEstimator([self.weights],
                                   0.1,
                                   0.2,
                                   self.layer_collection,
                                   estimation_mode="exact")
Example #3
 def testCurvaturePropModeBuild(self):
     with self._graph.as_default():
         estimator.FisherEstimator([self.weights],
                                   0.1,
                                   0.2,
                                   self.layer_collection,
                                   estimation_mode="curvature_prop")
Example #4
 def testInvalidEstimationMode(self):
     with self.assertRaises(ValueError):
         estimator.FisherEstimator([self.weights],
                                   0.1,
                                   0.2,
                                   self.layer_collection,
                                   estimation_mode="not_a_real_mode")
Example #5
  def test_inv_update_thunks(self):
    """Ensures inverse update ops run once per global_step."""
    with self._graph.as_default(), self.test_session() as sess:
      fisher_estimator = estimator.FisherEstimator(
          damping_fn=lambda: 0.2,
          variables=[self.weights],
          layer_collection=self.layer_collection,
          cov_ema_decay=0.0)

      # Construct op that updates one inverse per global step.
      global_step = training_util.get_or_create_global_step()
      inv_matrices = [
          matrix
          for fisher_factor in self.layer_collection.get_factors()
          for matrix in fisher_factor._inverses_by_damping.values()
      ]
      inv_update_op_thunks = fisher_estimator.inv_update_thunks
      inv_update_op = control_flow_ops.case(
          [(math_ops.equal(global_step, i), thunk)
           for i, thunk in enumerate(inv_update_op_thunks)])
      increment_global_step = global_step.assign_add(1)

      sess.run(variables.global_variables_initializer())
      initial_inv_values = sess.run(inv_matrices)

      # Ensure there's one update per inverse matrix. This is true as long as
      # there's no fan-in/fan-out or parameter re-use.
      self.assertEqual(len(inv_matrices), len(inv_update_op_thunks))

      # Test is a no-op if there is only 1 inverse matrix.
      assert len(inv_matrices) > 1

      # Assign each covariance matrix a value other than the identity. This
      # ensures that the inverse matrices are updated to something different as
      # well.
      cov_matrices = [
          fisher_factor.get_cov()
          for fisher_factor in self.layer_collection.get_factors()
      ]
      sess.run([
          cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0])))
          for cov_matrix in cov_matrices
      ])

      for i in range(len(inv_matrices)):
        # Compare new and old inverse values
        new_inv_values = sess.run(inv_matrices)
        is_inv_equal = [
            np.allclose(initial_inv_value, new_inv_value)
            for (initial_inv_value,
                 new_inv_value) in zip(initial_inv_values, new_inv_values)
        ]
        num_inv_equal = sum(is_inv_equal)

        # Ensure exactly one inverse matrix changes per step.
        self.assertEqual(num_inv_equal, len(inv_matrices) - i)

        # Run all inverse update ops.
        sess.run(inv_update_op)
        sess.run(increment_global_step)
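The scheduling trick verified above (one inverse update per global step, selected with control_flow_ops.case) also works as a free-standing pattern. A minimal sketch, not part of the original test, assuming the same thunk list and TF1-style modules, and using a modulo so the schedule keeps cycling once global_step exceeds the number of thunks:

# Illustrative sketch: cycle through the inverse-update thunks,
# running exactly one per global step, indefinitely.
global_step = training_util.get_or_create_global_step()
num_thunks = len(inv_update_op_thunks)
round_robin_inv_update_op = control_flow_ops.case(
    [(math_ops.equal(global_step % num_thunks, i), thunk)
     for i, thunk in enumerate(inv_update_op_thunks)])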
Example #6
 def testEmpiricalModeBuild(self):
     with self._graph.as_default():
         est = estimator.FisherEstimator([self.weights],
                                         0.1,
                                         0.2,
                                         self.layer_collection,
                                         estimation_mode="empirical")
         est.make_ops_and_vars()
Example #7
    def test_cov_update_thunks(self):
        """Ensures covariance update ops run once per global_step."""
        with self._graph.as_default(), self.test_session() as sess:
            fisher_estimator = estimator.FisherEstimator(
                variables=[self.weights],
                layer_collection=self.layer_collection,
                damping=0.2,
                cov_ema_decay=0.0)

            # Construct an op that executes one covariance update per step.
            global_step = training_util.get_or_create_global_step()
            (cov_variable_thunks, cov_update_op_thunks, _,
             _) = fisher_estimator.create_ops_and_vars_thunks()
            for thunk in cov_variable_thunks:
                thunk()
            cov_matrices = [
                fisher_factor.get_cov()
                for fisher_factor in self.layer_collection.get_factors()
            ]
            cov_update_op = control_flow_ops.case([
                (math_ops.equal(global_step, i), thunk)
                for i, thunk in enumerate(cov_update_op_thunks)
            ])
            increment_global_step = global_step.assign_add(1)

            sess.run(variables.global_variables_initializer())
            initial_cov_values = sess.run(cov_matrices)

            # Ensure there's one update per covariance matrix.
            self.assertEqual(len(cov_matrices), len(cov_update_op_thunks))

            # Test is a no-op if there is only 1 covariance matrix.
            assert len(cov_matrices) > 1

            for i in range(len(cov_matrices)):
                # Compare new and old covariance values
                new_cov_values = sess.run(cov_matrices)
                is_cov_equal = [
                    np.allclose(initial_cov_value, new_cov_value)
                    for (initial_cov_value, new_cov_value
                         ) in zip(initial_cov_values, new_cov_values)
                ]
                num_cov_equal = sum(is_cov_equal)

                # Ensure exactly one covariance matrix changes per step.
                self.assertEqual(num_cov_equal, len(cov_matrices) - i)

                # Run all covariance update ops.
                sess.run(cov_update_op)
                sess.run(increment_global_step)
Example #8
  def testEstimatorInitManualRegistration(self):
    with ops.Graph().as_default():
      layer_collection = lc.LayerCollection()

      inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32)
      weights = variable_scope.get_variable(
          'w', shape=(2, 2), dtype=dtypes.float32)
      bias = variable_scope.get_variable(
          'b', initializer=init_ops.zeros_initializer(), shape=(2, 1))
      output = math_ops.matmul(inputs, weights) + bias

      # Only register the weights.
      layer_collection.register_fully_connected((weights,), inputs, output)

      outputs = math_ops.tanh(output)
      layer_collection.register_categorical_predictive_distribution(outputs)

      # We should be able to build an estimator for only the registered vars.
      estimator.FisherEstimator([weights], 0.1, 0.2, layer_collection)

      # Check that we throw an error if we try to build an estimator for vars
      # that were not manually registered.
      with self.assertRaises(ValueError):
        estimator.FisherEstimator([weights, bias], 0.1, 0.2, layer_collection)
Example #9
    def __init__(self,
                 learning_rate=0.001,
                 decay=0.9,
                 epsilon=1e-10,
                 damping=0.001,
                 cov_ema_decay=0.95,
                 lrdecay=0.96,
                 decay_interval=50,
                 layer_collection=None,
                 estimation_mode='gradients',
                 colocate_gradient_with_ops=True,
                 use_locking=False,
                 name="kSGLDOpt"):
        super(kSGLDOpt, self).__init__(use_locking, name)
        self._lr = learning_rate
        self._decay = decay
        self._epsilon = epsilon
        self._lrdecay = lrdecay
        self._decay_interval = decay_interval

        self._variables = tf_variables.trainable_variables()
        self.damping_fn = lambda: damping
        self.cov_ema_decay = cov_ema_decay
        self.layer_collection = layer_collection
        self.estimation_mode = estimation_mode
        self.colocate_gradient_with_ops = colocate_gradient_with_ops

        # Tensor versions of the constructor arguments, created in _prepare().
        self._lr_t = None
        self._decay_t = None
        self._epsilon_t = None

        self._fisher_est = est.FisherEstimator(self.damping_fn,
                                               self._variables,
                                               self.cov_ema_decay,
                                               self.layer_collection,
                                               self.estimation_mode,
                                               self.colocate_gradient_with_ops)
Example #10
 def testVariableWrongNumberOfUses(self, mock_uses):
     with self.assertRaises(ValueError):
         est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
                                         self.layer_collection)
         est.make_ops_and_vars()
Example #11
    def __init__(self,
                 learning_rate,
                 cov_ema_decay,
                 damping,
                 layer_collection,
                 var_list=None,
                 momentum=0.9,
                 momentum_type="regular",
                 norm_constraint=None,
                 name="KFAC",
                 estimation_mode="gradients",
                 colocate_gradients_with_ops=True,
                 batch_size=None,
                 cov_devices=None,
                 inv_devices=None):
        """Initializes the KFAC optimizer with the given settings.

    Args:
      learning_rate: The base learning rate for the optimizer.  Should probably
          be set to 1.0 when using momentum_type = 'qmodel', but can still be
          lowered if desired (effectively lowering the trust in the
          quadratic model).
      cov_ema_decay: The decay factor used when calculating the covariance
          estimate moving averages.
      damping: The damping factor used to stabilize training due to errors in
          the local approximation with the Fisher information matrix, and to
          regularize the update direction by making it closer to the gradient.
          If damping is adapted during training then this value is used for
          initializing the damping variable.
          (Higher damping means the update looks more like a standard gradient
          update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the fisher
          blocks, kronecker factors, and losses associated with the
          graph.  The layer_collection cannot be modified after KfacOptimizer's
          initialization.
      var_list: Optional list or tuple of variables to train. Defaults to the
          list of variables collected in the graph under the key
          `GraphKeys.TRAINABLE_VARIABLES`.
      momentum: The momentum decay constant to use. Only applies when
          momentum_type is 'regular' or 'adam'. (Default: 0.9)
      momentum_type: The type of momentum to use in this optimizer, one of
          'regular', 'adam', or 'qmodel'. (Default: 'regular')
      norm_constraint: float or Tensor. If specified, the update is scaled down
          so that its approximate squared Fisher norm v^T F v is at most the
          specified value. May only be used with momentum type 'regular'.
          (Default: None)
      name: The name for this optimizer. (Default: 'KFAC')
      estimation_mode: The type of estimator to use for the Fishers.  Can be
          'gradients', 'empirical', 'curvature_propagation', or 'exact'.
          (Default: 'gradients'). See the doc-string for FisherEstimator for
          a more detailed description of these options.
      colocate_gradients_with_ops: Whether we should request gradients we
          compute in the estimator be colocated with their respective ops.
          (Default: True)
      batch_size: The size of the mini-batch. Only needed when momentum_type
          == 'qmodel' or when automatic adjustment is used.  (Default: None)
      cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
          computations will be placed on these devices in a round-robin fashion.
          Can be None, which means that no devices are specified. Only used
          with the (soon-to-be-deprecated) "convenience" properties.
      inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
          computations will be placed on these devices in a round-robin fashion.
          Can be None, which means that no devices are specified. Only used
          with the (soon-to-be-deprecated) "convenience" properties.

    Raises:
      ValueError: If the momentum type is unsupported.
      ValueError: If clipping is used with momentum type other than 'regular'.
      ValueError: If no losses have been registered with layer_collection.
      ValueError: If momentum is non-zero and momentum_type is not 'regular'
          or 'adam'.
    """

        variables = var_list
        if variables is None:
            variables = tf_variables.trainable_variables()

        # Parameters to be passed to the Fisher estimator:
        self._variables = variables
        self._cov_ema_decay = cov_ema_decay
        self._layers = layer_collection
        self._estimation_mode = estimation_mode
        self._colocate_gradients_with_ops = colocate_gradients_with_ops
        self._cov_devices = cov_devices
        self._inv_devices = inv_devices

        # The parameters below are required only if damping needs to be adapted.
        # These parameters can be set by calling
        # set_damping_adaptation_params() explicitly.
        self._damping_adaptation_decay = 0.95
        self._damping_adaptation_interval = 5
        # See Section 6.5 of the KFAC paper: omega(1) = pow(damping decay, interval).
        self._omega = (
            self._damping_adaptation_decay**self._damping_adaptation_interval)
        self._adapt_damping = False
        self._min_damping = 1e-5
        self._prev_train_batch = None
        self._is_chief = False
        self._loss_fn = None
        self._damping_constant = damping
        self._damping = None
        self._rho = None
        self._prev_loss = None
        self._q_model_change = None
        self._update_damping_op = None

        momentum_type = momentum_type.lower()
        legal_momentum_types = ["regular", "adam", "qmodel"]

        if momentum_type not in legal_momentum_types:
            raise ValueError(
                "Unsupported momentum type {}. Must be one of {}.".format(
                    momentum_type, legal_momentum_types))
        if momentum_type != "regular" and norm_constraint is not None:
            raise ValueError("Update clipping is only supported with momentum"
                             "type 'regular'.")
        if momentum_type not in ["regular", "adam"] and momentum != 0:
            raise ValueError(
                "Momentum must be unspecified if using a momentum_type "
                "other than 'regular' or 'adam'.")

        # Extra parameters of the optimizer
        self._momentum = momentum
        self._momentum_type = momentum_type
        self._norm_constraint = norm_constraint
        self._batch_size = batch_size

        with variable_scope.variable_scope(name):
            self._fisher_est = est.FisherEstimator(
                self._variables,
                self._cov_ema_decay,
                self.damping,
                self._layers,
                exps=(-1, ),
                estimation_mode=self._estimation_mode,
                colocate_gradients_with_ops=self._colocate_gradients_with_ops)

        super(KfacOptimizer, self).__init__(learning_rate, name=name)
Example #12
 def testAllModesBuild(self):
   for mode in _ALL_ESTIMATION_MODES:
     with self._graph.as_default():
       estimator.FisherEstimator([self.weights], 0.1, 0.2,
                                 self.layer_collection, mode)
Example #13
 def testModeListCorrect(self):
   with self._graph.as_default():
     est = estimator.FisherEstimator([self.weights], 0.1, 0.2,
                                     self.layer_collection)
   self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys())
Example #14
    def __init__(self,
                 learning_rate,
                 cov_ema_decay,
                 damping,
                 layer_collection,
                 var_list=None,
                 momentum=0.9,
                 momentum_type="regular",
                 norm_constraint=None,
                 name="KFAC",
                 estimation_mode="gradients",
                 colocate_gradients_with_ops=True,
                 cov_devices=None,
                 inv_devices=None):
        """Initializes the KFAC optimizer with the given settings.

    Args:
      learning_rate: The base learning rate for the optimizer.  Should probably
          be set to 1.0 when using momentum_type = 'qmodel', but can still be
          lowered if desired (effectively lowering the trust in the
          quadratic model).
      cov_ema_decay: The decay factor used when calculating the covariance
          estimate moving averages.
      damping: The damping factor used to stabilize training due to errors in
          the local approximation with the Fisher information matrix, and to
          regularize the update direction by making it closer to the gradient.
          (Higher damping means the update looks more like a standard gradient
          update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the fisher
          blocks, kronecker factors, and losses associated with the
          graph.  The layer_collection cannot be modified after KfacOptimizer's
          initialization.
      var_list: Optional list or tuple of variables to train. Defaults to the
          list of variables collected in the graph under the key
          `GraphKeys.TRAINABLE_VARIABLES`.
      momentum: The momentum decay constant to use. Only applies when
          momentum_type is 'regular' or 'adam'. (Default: 0.9)
      momentum_type: The type of momentum to use in this optimizer, one of
          'regular', 'adam', or 'qmodel'. (Default: 'regular')
      norm_constraint: float or Tensor. If specified, the update is scaled down
          so that its approximate squared Fisher norm v^T F v is at most the
          specified value. May only be used with momentum type 'regular'.
          (Default: None)
      name: The name for this optimizer. (Default: 'KFAC')
      estimation_mode: The type of estimator to use for the Fishers.  Can be
          'gradients', 'empirical', 'curvature_propagation', or 'exact'.
          (Default: 'gradients'). See the doc-string for FisherEstimator for
          a more detailed description of these options.
      colocate_gradients_with_ops: Whether we should request gradients we
          compute in the estimator be colocated with their respective ops.
          (Default: True)
      cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance
          computations will be placed on these devices in a round-robin fashion.
          Can be None, which means that no devices are specified.
      inv_devices: Iterable of device strings (e.g. '/gpu:0'). Inversion
          computations will be placed on these devices in a round-robin fashion.
          Can be None, which means that no devices are specified.

    Raises:
      ValueError: If the momentum type is unsupported.
      ValueError: If clipping is used with momentum type other than 'regular'.
      ValueError: If no losses have been registered with layer_collection.
      ValueError: If momentum is non-zero and momentum_type is not 'regular'
          or 'adam'.
    """

        variables = var_list
        if variables is None:
            variables = tf_variables.trainable_variables()

        self._fisher_est = est.FisherEstimator(
            variables,
            cov_ema_decay,
            damping,
            layer_collection,
            estimation_mode=estimation_mode,
            colocate_gradients_with_ops=colocate_gradients_with_ops,
            cov_devices=cov_devices,
            inv_devices=inv_devices)

        momentum_type = momentum_type.lower()
        legal_momentum_types = ["regular", "adam", "qmodel"]

        if momentum_type not in legal_momentum_types:
            raise ValueError(
                "Unsupported momentum type {}. Must be one of {}.".format(
                    momentum_type, legal_momentum_types))
        if momentum_type != "regular" and norm_constraint is not None:
            raise ValueError("Update clipping is only supported with momentum"
                             "type 'regular'.")
        if momentum_type not in ["regular", "adam"] and momentum != 0:
            raise ValueError(
                "Momentum must be unspecified if using a momentum_type "
                "other than 'regular' or 'adam'.")

        self._momentum = momentum
        self._momentum_type = momentum_type
        self._norm_constraint = norm_constraint

        # this is a bit of a hack
        # TODO(duckworthd): Handle this in a better way (e.g. pass it in?)
        self._batch_size = array_ops.shape(
            layer_collection.losses[0].inputs)[0]
        self._losses = layer_collection.losses

        self.cov_update_op = self._fisher_est.cov_update_op
        self.inv_update_op = self._fisher_est.inv_update_op
        self.inv_updates_dict = self._fisher_est.inv_updates_dict

        super(KfacOptimizer, self).__init__(learning_rate, name=name)
Example #15
    def __init__(
        self,
        learning_rate,
        cov_ema_decay,
        damping,
        layer_collection,
        momentum=0.,
        momentum_type="regular",
        norm_constraint=None,
        name="KFAC",
    ):
        """Initializes the KFAC optimizer with the given settings.

    Args:
      learning_rate: The base learning rate for the optimizer.  Should probably
          be set to 1.0 when using momentum_type = 'qmodel', but can still be
          lowered if desired (effectively lowering the trust in the
          quadratic model).
      cov_ema_decay: The decay factor used when calculating the covariance
          estimate moving averages.
      damping: The damping factor used to stabilize training due to errors in
          the local approximation with the Fisher information matrix, and to
          regularize the update direction by making it closer to the gradient.
          (Higher damping means the update looks more like a standard gradient
          update - see Tikhonov regularization.)
      layer_collection: The layer collection object, which holds the fisher
          blocks, kronecker factors, and losses associated with the
          graph.  The layer_collection cannot be modified after KfacOptimizer's
          initialization.
      momentum: The momentum value for this optimizer. Only applies when
          momentum_type is 'regular' or 'adam'. (Default: 0)
      momentum_type: The type of momentum to use in this optimizer, one of
          'regular', 'adam', or 'qmodel'. (Default: 'regular')
      norm_constraint: float or Tensor. If specified, the update is scaled down
          so that its approximate squared Fisher norm v^T F v is at most the
          specified value. May only be used with momentum type 'regular'.
          (Default: None)
      name: The name for this optimizer. (Default: 'KFAC')

    Raises:
      ValueError: If the momentum type is unsupported.
      ValueError: If clipping is used with momentum type other than 'regular'.
      ValueError: If no losses have been registered with layer_collection.
      ValueError: If momentum is non-zero and momentum_type is not 'regular'
          or 'adam'.
    """

        # We may consider determining the set of variables some other way, but for
        # now it's just all the trainable variables.
        variables = tf_variables.trainable_variables()

        self._fisher_est = est.FisherEstimator(variables, cov_ema_decay,
                                               damping, layer_collection)

        momentum_type = momentum_type.lower()
        legal_momentum_types = ["regular", "adam", "qmodel"]

        if momentum_type not in legal_momentum_types:
            raise ValueError(
                "Unsupported momentum type {}. Must be one of {}.".format(
                    momentum_type, legal_momentum_types))
        if momentum_type != "regular" and norm_constraint is not None:
            raise ValueError("Update clipping is only supported with momentum"
                             "type 'regular'.")
        if momentum_type not in ["regular", "adam"] and momentum != 0:
            raise ValueError(
                "Momentum must be unspecified if using a momentum_type "
                "other than 'regular' or 'adam'.")

        self._momentum = ops.convert_to_tensor(momentum, name="momentum")
        self._momentum_type = momentum_type
        self._norm_constraint = norm_constraint

        # this is a bit of a hack
        # TODO(duckworthd): Handle this in a better way (e.g. pass it in?)
        self._batch_size = array_ops.shape(
            layer_collection.losses[0].inputs)[0]
        self._losses = layer_collection.losses

        self.cov_update_op = self._fisher_est.cov_update_op
        self.inv_update_op = self._fisher_est.inv_update_op
        self.inv_updates_dict = self._fisher_est.inv_updates_dict

        super(KfacOptimizer, self).__init__(learning_rate, name=name)
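Putting the pieces of this page together: below is a hedged end-to-end sketch of how a KfacOptimizer like the one above is typically wired up, combining the registration pattern from Example #8 with the cov_update_op / inv_update_op attributes exposed here. It assumes the tf.contrib.kfac module layout from the TF 1.x era of these snippets; the toy model, labels, and step count are illustrative and not taken from the original code.

import tensorflow as tf

kfac = tf.contrib.kfac  # assumed module layout for the TF 1.x era of these snippets

# Toy single-layer model (names are illustrative).
inputs = tf.random_normal((16, 2))
weights = tf.get_variable('w', shape=(2, 2))
bias = tf.get_variable('b', shape=(2,), initializer=tf.zeros_initializer())
logits = tf.matmul(inputs, weights) + bias
labels = tf.zeros((16,), dtype=tf.int32)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))

# Register the layer and the predictive distribution, as in Example #8.
layer_collection = kfac.layer_collection.LayerCollection()
layer_collection.register_fully_connected((weights, bias), inputs, logits)
layer_collection.register_categorical_predictive_distribution(logits)

# Construct the optimizer with the same arguments as this example's __init__.
optimizer = kfac.optimizer.KfacOptimizer(
    learning_rate=0.01,
    cov_ema_decay=0.95,
    damping=0.001,
    layer_collection=layer_collection,
    momentum=0.9)
train_op = optimizer.minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(100):
        # Refresh covariance and inverse estimates, then take a training step.
        sess.run([optimizer.cov_update_op, optimizer.inv_update_op])
        sess.run(train_op)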
Example #16
 def testVariableWrongNumberOfUses(self, mock_uses):
   with self.assertRaises(ValueError):
     estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1,
                               self.layer_collection)