Example 1
0
  def register_normal_predictive_distribution(self,
                                              mean,
                                              var=0.5,
                                              seed=None,
                                              targets=None,
                                              name=None,
                                              reuse=VARIABLE_SCOPE):
    """Registers a normal predictive distribution with the given mean.

    Args:
      mean: Tensor holding the mean vector that defines the distribution.
      var: Scalar variance. The default of 0.5 corresponds to a plain
        squared-error loss (target - prediction)**2; if your loss has the
        form 0.5*(target - prediction)**2, pass var=1.0. (Default: 0.5)
      seed: Seed for the RNG, useful for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. Only needed when
        total_loss() will be called instead of total_sampled_loss(), e.g.
        to estimate the "empirical Fisher" rather than the true Fisher.
        (Default: None)
      name: (OPTIONAL) str or None. Unique name for this loss function;
        a fresh name is generated when None. (Default: None)
      reuse: (OPTIONAL) bool or str. True reuses an existing FisherBlock,
        False creates a new one, and VARIABLE_SCOPE defers to
        tf.get_variable_scope().reuse.
    """
    # Build the negative-log-prob loss, then hand it to the generic
    # registration path shared by all loss functions.
    loss_fn = lf.NormalMeanNegativeLogProbLoss(
        mean, var, targets=targets, seed=seed)
    self.register_loss_function(
        loss_fn,
        mean,
        "normal_predictive_distribution",
        name=name,
        reuse=reuse)
Example 2
0
    def register_normal_predictive_distribution(self,
                                                mean,
                                                var=0.5,
                                                seed=None,
                                                targets=None,
                                                name=None):
        """Registers a normal predictive distribution under a unique name.

        Args:
          mean: Tensor holding the mean vector that defines the distribution.
          var: Scalar variance. The default of 0.5 corresponds to a plain
            squared-error loss (target - prediction)**2; if your loss has
            the form 0.5*(target - prediction)**2, pass var=1.0.
            (Default: 0.5)
          seed: Seed for the RNG, useful for debugging. (Default: None)
          targets: (OPTIONAL) Targets for the loss function. Only needed
            when total_loss() will be called instead of
            total_sampled_loss(), e.g. to estimate the "empirical Fisher"
            rather than the true Fisher. (Default: None)
          name: (OPTIONAL) str or None. Unique name for this loss function;
            a fresh name is generated when falsy. (Default: None)

        Raises:
          NotImplementedError: if a loss is already registered under `name`.
        """
        # Generate a graph-unique name when the caller did not supply one.
        if not name:
            name = self._graph.unique_name(
                "register_normal_predictive_distribution")
        # Merging logits into an already-registered loss is unsupported.
        if name in self._loss_dict:
            raise NotImplementedError(
                "Adding logits to an existing LossFunction not yet supported.")
        self._loss_dict[name] = lf.NormalMeanNegativeLogProbLoss(
            mean, var, targets=targets, seed=seed)
Example 3
0
    def register_normal_predictive_distribution(self,
                                                mean,
                                                var=0.5,
                                                seed=None,
                                                targets=None):
        """Registers a normal predictive distribution with the given mean.

        Args:
          mean: Tensor holding the mean vector that defines the distribution.
          var: Scalar variance. The default of 0.5 corresponds to a plain
            squared-error loss (target - prediction)**2; if your loss has
            the form 0.5*(target - prediction)**2, pass var=1.0.
            (Default: 0.5)
          seed: Seed for the RNG, useful for debugging. (Default: None)
          targets: (OPTIONAL) Targets for the loss function. Only needed
            when total_loss() will be called instead of
            total_sampled_loss(), e.g. to estimate the "empirical Fisher"
            rather than the true Fisher. (Default: None)
        """
        # Build the negative-log-prob loss and record it on this collection.
        new_loss = lf.NormalMeanNegativeLogProbLoss(
            mean, var, targets=targets, seed=seed)
        self.losses.append(new_loss)