def register_multi_bernoulli_predictive_distribution(
        self,
        logits,
        seed=None,
        targets=None,
        name=None,
        reuse=VARIABLE_SCOPE):
    """Register a multi-Bernoulli predictive distribution over `logits`.

    A `MultiBernoulliNegativeLogProbLoss` is built from the arguments and
    handed to `self.register_loss_function` under the base name
    "multi_bernoulli_predictive_distribution".

    Args:
      logits: Parameters (logits) of the distribution.
      seed: Seed for the RNG; intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. Only needed when
        total_loss() will be called rather than total_sampled_loss();
        total_loss() is what is used, for example, to estimate the
        "empirical Fisher" as opposed to the true Fisher. (Default: None)
      name: (OPTIONAL) str or None. Unique name for this loss function. A new
        name is generated when None. (Default: None)
      reuse: (OPTIONAL) bool or str. True reuses an existing FisherBlock,
        False creates a new FisherBlock, and VARIABLE_SCOPE defers to
        tf.get_variable_scope().reuse.
    """
    # Build the loss first, then register it; kept as two statements so the
    # loss is constructed before the registration method is looked up.
    bernoulli_loss = lf.MultiBernoulliNegativeLogProbLoss(
        logits, targets=targets, seed=seed)
    self.register_loss_function(
        bernoulli_loss,
        logits,
        "multi_bernoulli_predictive_distribution",
        name=name,
        reuse=reuse)
def register_multi_bernoulli_predictive_distribution(
        self, logits, seed=None, targets=None, name=None):
    """Register a multi-Bernoulli predictive distribution over `logits`.

    The resulting `MultiBernoulliNegativeLogProbLoss` is stored in
    `self._loss_dict` under `name`.

    Args:
      logits: Parameters (logits) of the distribution.
      seed: Seed for the RNG; intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. Only needed when
        total_loss() will be called rather than total_sampled_loss();
        total_loss() is what is used, for example, to estimate the
        "empirical Fisher" as opposed to the true Fisher. (Default: None)
      name: (OPTIONAL) str or None. Unique name for this loss function. When
        falsy, a fresh name is requested from `self._graph.unique_name`.
        (Default: None)

    Raises:
      NotImplementedError: If `name` is already a key of `self._loss_dict`.
    """
    if not name:
        name = self._graph.unique_name(
            "register_multi_bernoulli_predictive_distribution")

    # Reject duplicates before the loss object is constructed.
    if name in self._loss_dict:
        raise NotImplementedError(
            "Adding logits to an existing LossFunction not yet supported.")

    self._loss_dict[name] = lf.MultiBernoulliNegativeLogProbLoss(
        logits, targets=targets, seed=seed)
def register_multi_bernoulli_predictive_distribution(
        self, logits, seed=None, targets=None):
    """Register a multi-Bernoulli predictive distribution over `logits`.

    A `MultiBernoulliNegativeLogProbLoss` is built from the arguments and
    appended to `self.losses`.

    Args:
      logits: Parameters (logits) of the distribution.
      seed: Seed for the RNG; intended for debugging. (Default: None)
      targets: (OPTIONAL) Targets for the loss function. Only needed when
        total_loss() will be called rather than total_sampled_loss();
        total_loss() is what is used, for example, to estimate the
        "empirical Fisher" as opposed to the true Fisher. (Default: None)
    """
    # Construct the loss before touching `self.losses`, as the original does.
    bernoulli_loss = lf.MultiBernoulliNegativeLogProbLoss(
        logits,
        targets=targets,
        seed=seed,
    )
    self.losses.append(bernoulli_loss)