def register_normal_predictive_distribution(self,
                                            mean,
                                            var=0.5,
                                            seed=None,
                                            targets=None,
                                            name=None,
                                            reuse=VARIABLE_SCOPE):
    """Register a Normal predictive distribution with this collection.

    The distribution is wrapped in an ``lf.NormalMeanNegativeLogProbLoss``
    and handed to ``self.register_loss_function`` under the
    "normal_predictive_distribution" type tag.

    Args:
      mean: Mean vector of the distribution.
      var: Scalar variance. The default of 0.5 matches a plain squared error
        loss ``(target - prediction)**2``; for a loss written as
        ``0.5*(target - prediction)**2`` pass ``var=1.0``. (Default: 0.5)
      seed: RNG seed, intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss. Needed only to call
        total_loss() rather than total_sampled_loss(); total_loss() is what,
        for example, an "empirical Fisher" estimate relies on (as opposed to
        the true Fisher). (Default: None)
      name: (OPTIONAL) str or None. Unique name for this loss function; a new
        one is generated when None. (Default: None)
      reuse: (OPTIONAL) bool or str. True reuses an existing FisherBlock,
        False creates a new one, and VARIABLE_SCOPE defers to
        tf.get_variable_scope().reuse.
    """
    # Build the loss object first, then hand it off together with the mean.
    distribution_loss = lf.NormalMeanNegativeLogProbLoss(
        mean, var, targets=targets, seed=seed)
    self.register_loss_function(
        distribution_loss,
        mean,
        "normal_predictive_distribution",
        name=name,
        reuse=reuse,
    )
def register_normal_predictive_distribution(self,
                                            mean,
                                            var=0.5,
                                            seed=None,
                                            targets=None,
                                            name=None):
    """Registers a normal predictive distribution.

    The resulting ``lf.NormalMeanNegativeLogProbLoss`` is stored in
    ``self._loss_dict`` under ``name``.

    Args:
      mean: The mean vector defining the distribution.
      var: The variance (must be a scalar). Note that the default value of
        0.5 corresponds to a standard squared error loss
        (target - prediction)**2. If your squared error loss is of the form
        0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5)
      seed: The seed for the RNG (for debugging) (Default: None)
      targets: (OPTIONAL) The targets for the loss function. Only required if
        one wants to call total_loss() instead of total_sampled_loss().
        total_loss() is required, for example, to estimate the
        "empirical Fisher" (instead of the true Fisher). (Default: None)
      name: (OPTIONAL) str or None. Unique name for this loss function. If
        falsy (None or ""), a name is generated with
        ``self._graph.unique_name("register_normal_predictive_distribution")``.
        (Default: None)

    Raises:
      NotImplementedError: if a loss function is already registered under
        ``name``.
    """
    name = name or self._graph.unique_name(
        "register_normal_predictive_distribution")
    # Merging another mean into an already-registered loss is not supported,
    # so each name may be registered at most once.
    if name in self._loss_dict:
        raise NotImplementedError(
            "Adding a mean to an existing LossFunction ({!r}) not yet "
            "supported.".format(name))
    loss = lf.NormalMeanNegativeLogProbLoss(
        mean, var, targets=targets, seed=seed)
    self._loss_dict[name] = loss
def register_normal_predictive_distribution(self, mean, var=0.5, seed=None,
                                            targets=None):
    """Register a Normal predictive distribution by recording its loss.

    An ``lf.NormalMeanNegativeLogProbLoss`` is built from the arguments and
    appended to ``self.losses``.

    Args:
      mean: Mean vector of the distribution.
      var: Scalar variance. The default of 0.5 matches a plain squared error
        loss ``(target - prediction)**2``; for a loss written as
        ``0.5*(target - prediction)**2`` pass ``var=1.0``. (Default: 0.5)
      seed: RNG seed, intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss. Needed only to call
        total_loss() rather than total_sampled_loss(); total_loss() is what,
        for example, an "empirical Fisher" estimate relies on (as opposed to
        the true Fisher). (Default: None)
    """
    # Construct before touching self.losses so the evaluation order matches
    # a construct-then-append sequence.
    new_loss = lf.NormalMeanNegativeLogProbLoss(
        mean, var, targets=targets, seed=seed)
    self.losses.append(new_loss)