Example #1
    def register_categorical_predictive_distribution(self,
                                                     logits,
                                                     seed=None,
                                                     targets=None,
                                                     name=None):
        """Registers a categorical predictive distribution.

        Args:
          logits: The logits of the distribution (i.e. its parameters).
          seed: The seed for the RNG (for debugging) (Default: None)
          targets: (OPTIONAL) The targets for the loss function.  Only required if
            one wants to call total_loss() instead of total_sampled_loss().
            total_loss() is required, for example, to estimate the
            "empirical Fisher" (instead of the true Fisher).
            (Default: None)
          name: (OPTIONAL) str or None. Unique name for this loss function. If None,
            a new name is generated. (Default: None)
        """
        name = name or self._graph.unique_name(
            "register_categorical_predictive_distribution")
        if name in self._loss_dict:
            raise NotImplementedError(
                "Adding logits to an existing LossFunction not yet supported.")
        loss = lf.CategoricalLogitsNegativeLogProbLoss(logits,
                                                       targets=targets,
                                                       seed=seed)
        self._loss_dict[name] = loss
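A minimal usage sketch of the registration call above (assuming the TF 1.x tf.contrib.kfac import path; the dense layer and the label tensor are made-up placeholders). Passing targets is what makes total_loss(), and hence the "empirical Fisher", available later:

import tensorflow as tf
from tensorflow.contrib.kfac.python.ops import layer_collection as lc

# Hypothetical model outputs: [batch_size, num_classes] logits and integer labels.
logits = tf.layers.dense(tf.ones([8, 16]), 3)
labels = tf.zeros([8], dtype=tf.int32)

layers = lc.LayerCollection()
# With targets=... both total_loss() and total_sampled_loss() can be built;
# without it, only the sampled (true-Fisher) variant is usable.
layers.register_categorical_predictive_distribution(logits, targets=labels)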
Example #2
    def register_categorical_predictive_distribution(self,
                                                     logits,
                                                     seed=None,
                                                     targets=None,
                                                     name=None,
                                                     reuse=VARIABLE_SCOPE):
        """Registers a categorical predictive distribution.

        Args:
          logits: The logits of the distribution (i.e. its parameters).
          seed: The seed for the RNG (for debugging) (Default: None)
          targets: (OPTIONAL) The targets for the loss function.  Only required if
            one wants to call total_loss() instead of total_sampled_loss().
            total_loss() is required, for example, to estimate the
            "empirical Fisher" (instead of the true Fisher).
            (Default: None)
          name: (OPTIONAL) str or None. Unique name for this loss function. If None,
            a new name is generated. (Default: None)
          reuse: (OPTIONAL) bool or str.  If True, reuse an existing FisherBlock.
            If False, create a new FisherBlock.  If VARIABLE_SCOPE, use
            tf.get_variable_scope().reuse.

        Raises:
          ValueError: If reuse == True and name == None.
          ValueError: If reuse == True and seed != None.
          KeyError: If reuse == True and no existing LossFunction with 'name' found.
          KeyError: If reuse == False and existing LossFunction with 'name' found.
        """
        name = name or self._graph.unique_name(
            "register_categorical_predictive_distribution")

        if reuse == VARIABLE_SCOPE:
            reuse = variable_scope.get_variable_scope().reuse

        if reuse:
            if name is None:
                raise ValueError(
                    "If reuse is enabled, loss function's name must be set.")
            if seed is not None:
                raise ValueError(
                    "Seed can only be specified at LossFunction instantiation."
                )

            loss = self._loss_dict.get(name, None)

            if loss is None:
                raise KeyError(
                    "Unable to find loss function named {}. Create a new LossFunction "
                    "with reuse=False.".format(name))

            loss.register_additional_minibatch(logits, targets=targets)
        else:
            if name in self._loss_dict:
                raise KeyError(
                    "Loss function named {} already exists. Set reuse=True to append "
                    "another minibatch.".format(name))
            loss = lf.CategoricalLogitsNegativeLogProbLoss(logits,
                                                           targets=targets,
                                                           seed=seed)
            self._loss_dict[name] = loss
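A minimal multi-tower sketch of the reuse path above (hypothetical variable names): the first call creates the LossFunction under a fixed name, and later calls with reuse=True append their logits to it via register_additional_minibatch. Note that seed may only be given on the creating call:

layers = lc.LayerCollection()
for tower_id, tower_logits in enumerate(all_tower_logits):  # hypothetical per-tower logits
    layers.register_categorical_predictive_distribution(
        tower_logits,
        name="softmax_loss",      # shared name across towers
        reuse=(tower_id > 0))     # create on tower 0, append on the rest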
Example #3
  def register_categorical_predictive_distribution(self,
                                                   logits,
                                                   seed=None,
                                                   targets=None,
                                                   name=None,
                                                   reuse=VARIABLE_SCOPE):
    """Registers a categorical predictive distribution.

    Args:
      logits: The logits of the distribution (i.e. its parameters).
      seed: The seed for the RNG (for debugging) (Default: None)
      targets: (OPTIONAL) The targets for the loss function.  Only required if
        one wants to call total_loss() instead of total_sampled_loss().
        total_loss() is required, for example, to estimate the
        "empirical Fisher" (instead of the true Fisher).
        (Default: None)
      name: (OPTIONAL) str or None. Unique name for this loss function. If None,
        a new name is generated. (Default: None)
      reuse: (OPTIONAL) bool or str.  If True, reuse an existing FisherBlock.
        If False, create a new FisherBlock.  If VARIABLE_SCOPE, use
        tf.get_variable_scope().reuse.
    """
    loss = lf.CategoricalLogitsNegativeLogProbLoss(logits, targets=targets,
                                                   seed=seed)
    self.register_loss_function(loss, logits,
                                "categorical_predictive_distribution",
                                name=name, reuse=reuse)
Example #4
  def testSample(self):
    """Ensure samples can be drawn."""
    with ops.Graph().as_default(), self.test_session() as sess:
      logits = np.asarray([
          [0., 0., 0.],  #
          [1., -1., 0.]
      ]).astype(np.float32)
      loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
          array_ops.constant(logits))
      sample = loss.sample(42)
      sample = sess.run(sample)
      self.assertEqual(sample.shape, (2,))
Example #5
  def testEvaluateOnSample(self):
    """Ensure log probability of a sample can be drawn."""
    with ops.Graph().as_default(), self.test_session() as sess:
      logits = np.asarray([
          [0., 0., 0.],  #
          [1., -1., 0.]
      ]).astype(np.float32)
      loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
          array_ops.constant(logits))
      neg_log_prob = loss.evaluate_on_sample(42)

      # Simply ensure this doesn't crash. As the output is random, it's
      # difficult to say if the output is correct or not...
      neg_log_prob = sess.run(neg_log_prob)
Example #6
  def testMultiMinibatchRegistration(self):
    """Ensure this loss function supports registering multiple minibatches."""
    with ops.Graph().as_default():
      tower_logits = []
      loss = None
      num_towers = 5
      for _ in range(num_towers):
        logits = random_ops.random_uniform(shape=[2, 3])
        tower_logits.append(logits)
        if loss is None:
          loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)
        else:
          loss.register_additional_minibatch(logits)
      self.assertListEqual(loss.input_minibatches, tower_logits)
      self.assertEqual(loss.num_registered_minibatches, num_towers)
Example #7
    def testMultiplyFisherBatch(self):
        with ops.Graph().as_default(), self.test_session() as sess:
            logits = np.array([[1., 2., 3.], [4., 6., 8.]])
            loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)

            vector = np.array([[1., 2., 3.], [5., 3., 1.]])

            na = np.newaxis
            probs = np.exp(logits -
                           np.logaddexp.reduce(logits, axis=-1, keepdims=True))
            fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[
                ..., na, :]

            result = loss.multiply_fisher(vector)
            expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :]
            self.assertEqual(sess.run(result).shape, logits.shape)
            self.assertAllClose(expected_result, sess.run(result))
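For reference, the quantity the test constructs with broadcasting is the standard Fisher of a categorical distribution with respect to its logits, applied row by row:

F(z) = \operatorname{diag}(p) - p\,p^{\top}, \qquad p = \operatorname{softmax}(z)

multiply_fisher(vector) is then checked against the per-example matrix-vector product of vector with F.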
Example #8
    def testUpdateVelocities(self):
        with ops.Graph().as_default(), self.test_session() as sess:
            layers = lc.LayerCollection()
            layers.losses = [
                lf.CategoricalLogitsNegativeLogProbLoss(
                    array_ops.constant([1.0]))
            ]
            opt = optimizer.KfacOptimizer(0.1,
                                          0.2,
                                          0.3,
                                          layers,
                                          momentum=0.5,
                                          momentum_type='regular')
            x = variable_scope.get_variable('x',
                                            initializer=array_ops.ones((2, 2)))
            y = variable_scope.get_variable('y',
                                            initializer=array_ops.ones(
                                                (2, 2)) * 2)
            vec1 = array_ops.ones((2, 2)) * 3
            vec2 = array_ops.ones((2, 2)) * 4

            model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
            update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5)
            opt_vars = [
                v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
                if v not in model_vars
            ]

            sess.run(tf_variables.global_variables_initializer())
            old_opt_vars = sess.run(opt_vars)

            # Optimizer vars start out at 0.
            for opt_var in old_opt_vars:
                self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)),
                                    opt_var)

            sess.run(update_op)
            new_opt_vars = sess.run(opt_vars)
            # After one update, the velocities are equal to the vectors.
            for vec, opt_var in zip([vec1, vec2], new_opt_vars):
                self.assertAllEqual(sess.run(vec), opt_var)

            sess.run(update_op)
            final_opt_vars = sess.run(opt_vars)
            for first, second in zip(new_opt_vars, final_opt_vars):
                self.assertFalse(np.equal(first, second).all())
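The assertions follow from the decayed-accumulation update that _update_velocities is assumed to apply here, v_new = decay * v + vec, with decay = 0.5 and velocities initialized to zero:

v_1 = 0.5 \cdot 0 + \Delta = \Delta, \qquad v_2 = 0.5 \cdot \Delta + \Delta = 1.5\,\Delta \neq v_1

so the first update leaves the velocities equal to the vectors, and the second update changes them, which is exactly what the test checks.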
Example #9
    def testMultiplyFisherSingleVector(self):
        with ops.Graph().as_default(), self.test_session() as sess:
            logits = np.array([1., 2., 3.])
            loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits)

            # the LossFunction.multiply_fisher docstring only says it supports the
            # case where the vector is the same shape as the input natural parameters
            # (i.e. the logits here), but here we also test leading dimensions
            vector = np.array([1., 2., 3.])
            vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)]

            probs = np.exp(logits - np.logaddexp.reduce(logits))
            fisher = np.diag(probs) - np.outer(probs, probs)

            for vector in vectors:
                result = loss.multiply_fisher(vector)
                expected_result = np.dot(vector, fisher)
                self.assertAllClose(expected_result, sess.run(result))
Example #10
    def register_categorical_predictive_distribution(self,
                                                     logits,
                                                     seed=None,
                                                     targets=None):
        """Registers a categorical predictive distribution.

        Args:
          logits: The logits of the distribution (i.e. its parameters).
          seed: The seed for the RNG (for debugging) (Default: None)
          targets: (OPTIONAL) The targets for the loss function.  Only required if
            one wants to call total_loss() instead of total_sampled_loss().
            total_loss() is required, for example, to estimate the
            "empirical Fisher" (instead of the true Fisher).
            (Default: None)
        """
        loss = lf.CategoricalLogitsNegativeLogProbLoss(logits,
                                                       targets=targets,
                                                       seed=seed)
        self.losses.append(loss)
Example #11
  def testEvaluateOnTargets(self):
    """Ensure log probability can be evaluated correctly."""
    with ops.Graph().as_default(), self.test_session() as sess:
      logits = np.asarray([
          [0., 0., 0.],  #
          [1., -1., 0.]
      ]).astype(np.float32)
      targets = np.asarray([2, 1]).astype(np.int32)
      loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(
          array_ops.constant(logits), targets=array_ops.constant(targets))
      neg_log_prob = loss.evaluate()
      neg_log_prob = sess.run(neg_log_prob)

      # Calculate explicit log probability of targets.
      probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
      log_probs = np.log([
          probs[0, targets[0]],  #
          probs[1, targets[1]]
      ])
      expected_log_prob = np.sum(log_probs)

      self.assertAllClose(neg_log_prob, -expected_log_prob)
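The explicit check above is the standard categorical negative log-likelihood, which in log-space reads:

-\log p(t \mid z) = -\sum_i \Bigl( z_{i, t_i} - \log \sum_j e^{z_{ij}} \Bigr)

loss.evaluate() returns this value summed over the minibatch, which the test compares against the softmax-based computation.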