def _classifier_logits(self, theta: Tensor, x: Tensor, num_atoms: int) -> Tensor:
    """Evaluate the classifier on atomic sets of (theta, x) pairs.

    Each observation ``x[i]`` is paired with ``num_atoms`` parameters: its own
    ``theta[i]`` first, followed by ``num_atoms - 1`` rows of ``theta`` drawn
    uniformly without replacement from the *other* rows of the batch.

    Args:
        theta: Batch of parameters, indexed as ``theta[:, None, :]`` (so at
            least 2-D), one row per batch item.
        x: Batch of observations; presumably one row per row of ``theta``
            (TODO confirm against callers).
        num_atoms: Size of each atom set, including the original parameter.

    Returns:
        The output of ``self._neural_net`` on the ``batch_size * num_atoms``
        concatenated (theta, x) rows. Each group of ``num_atoms`` rows starts
        with the original pair (assuming ``utils.repeat_rows`` repeats each row
        consecutively, e.g. ``[1, 2] -> [1, 1, 2, 2]``).
    """
    n_rows = theta.shape[0]
    x_per_atom = utils.repeat_rows(x, num_atoms)

    # Uniform weights over every row except the diagonal, so an item never
    # draws itself as a contrasting parameter.
    weights = ones(n_rows, n_rows) * (1 - eye(n_rows)) / (n_rows - 1)
    drawn = torch.multinomial(weights, num_samples=num_atoms - 1, replacement=False)

    # Slot 0 of every group is the item's own theta; the rest are contrasts.
    grouped = torch.cat((theta[:, None, :], theta[drawn]), dim=1)
    flat_theta = grouped.reshape(n_rows * num_atoms, -1)

    classifier_input = torch.cat((flat_theta, x_per_atom), dim=1)
    return self._neural_net(classifier_input)
def _get_loss(parameters, observations):
    """Return the classifier training loss for one batch.

    This is a closure: it reads ``self``, ``clipped_batch_size`` and
    ``criterion`` from the enclosing scope.

    Every row of ``parameters`` forms an atom set with its own observation
    (a joint sample) plus ``num_atoms - 1`` contrasting parameters drawn
    without replacement from the other rows (marginal samples).

    Args:
        parameters: Batch of parameters, indexed as ``parameters[:, None, :]``;
            presumably has ``clipped_batch_size`` rows (TODO confirm caller).
        observations: Batch of observations, row-aligned with ``parameters``.

    Returns:
        For ``self._classifier_loss == "aalr"``: ``criterion`` applied to the
        sigmoid of the classifier output against joint (1) / marginal (0)
        labels. Otherwise: the negative mean log-softmax of the joint atom
        (index 0) within each atom set.
    """
    num_atoms = self._num_atoms if self._num_atoms > 0 else clipped_batch_size

    if self._classifier_loss == "aalr":
        assert num_atoms == 2, "aalr allows only two atoms, i.e. num_atoms=2."

    # Each observation is repeated once per atom: [1, 2] -> [1, 1, 2, 2].
    repeated_observations = utils.repeat_rows(observations, num_atoms)

    # Choose between 1 and num_atoms - 1 parameters from the rest
    # of the batch for each observation.
    assert 0 < num_atoms - 1 < clipped_batch_size
    probs = (
        (1 / (clipped_batch_size - 1))
        * torch.ones(clipped_batch_size, clipped_batch_size)
        * (1 - torch.eye(clipped_batch_size))
    )
    choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
    contrasting_parameters = parameters[choices]

    # Row-major reshape of (B, num_atoms, D) keeps atoms grouped per
    # observation: [joint_0, contrast_0, ..., joint_1, contrast_1, ...].
    atomic_parameters = torch.cat(
        (parameters[:, None, :], contrasting_parameters), dim=1
    ).reshape(clipped_batch_size * num_atoms, -1)

    inputs = torch.cat((atomic_parameters, repeated_observations), dim=1)

    if self._classifier_loss == "aalr":
        network_outputs = self._neural_posterior.neural_net(inputs)
        likelihood = torch.squeeze(torch.sigmoid(network_outputs))
        # With num_atoms == 2 the rows alternate: even rows pair theta with
        # its own x (joint, label 1); odd rows pair it with a contrasting
        # theta (marginals, label 0). A "first half ones, second half zeros"
        # layout would mislabel half of the batch.
        labels = torch.ones(
            clipped_batch_size * num_atoms,
            dtype=likelihood.dtype,
            device=likelihood.device,
        )
        labels[1::2] = 0.0
        # binary cross entropy to learn the likelihood
        loss = criterion(likelihood, labels)
    else:
        logits = self._neural_posterior.neural_net(inputs).reshape(
            clipped_batch_size, num_atoms
        )
        # index 0 is the parameter set sampled from the joint
        log_prob = logits[:, 0] - torch.logsumexp(logits, dim=-1)
        loss = -torch.mean(log_prob)

    return loss
def _log_prob_proposal_posterior_atomic(
    self, theta: Tensor, x: Tensor, masks: Tensor
):
    """Log-probability of the atomic proposal posterior for each batch item.

    The atoms for item ``i`` are ``theta[i]`` followed by ``num_atoms - 1``
    parameters drawn without replacement from the remaining rows of ``theta``.
    The posterior/prior log-ratio is normalized over each atom set. Atoms could
    in principle come from the proposal prior or from a more targeted
    distribution such as the latest posterior. The latter is best avoided in
    the first round, where the density estimator is still randomly initialized.

    Args:
        theta: Batch of parameters θ.
        x: Batch of data, row-aligned with ``theta``.
        masks: True for prior samples in the batch, which additionally get the
            non-atomic posterior log-prob when ``self._use_combined_loss``.

    Returns:
        One log-probability per row of ``theta`` (the combined-loss branch
        presumes ``self._neural_net.log_prob(theta, x)`` is also per-row).
    """
    n = theta.shape[0]
    atoms_per_set = clamp_and_warn(
        "num_atoms", self._num_atoms, min_val=2, max_val=n
    )

    # All atoms of one set share that set's x: [1, 2] -> [1, 1, 2, 2].
    x_per_atom = repeat_rows(x, atoms_per_set)

    # Uniform over every other row; the zeroed diagonal stops an item from
    # contrasting against itself.
    off_diag = ones(n, n) * (1 - eye(n)) / (n - 1)
    contrast_idx = torch.multinomial(
        off_diag, num_samples=atoms_per_set - 1, replacement=False
    )

    # Slot 0 of each set is the true theta; flatten to one atom per row.
    theta_atoms = torch.cat((theta[:, None, :], theta[contrast_idx]), dim=1)
    theta_atoms = theta_atoms.reshape(n * atoms_per_set, -1)

    posterior_lp = self._neural_net.log_prob(theta_atoms, x_per_atom)
    self._assert_all_finite(posterior_lp, "posterior eval")
    posterior_lp = posterior_lp.reshape(n, atoms_per_set)

    prior_lp = self._prior.log_prob(theta_atoms).reshape(n, atoms_per_set)
    self._assert_all_finite(prior_lp, "prior eval")

    # Log-softmax of the true atom over its set of posterior/prior log-ratios.
    log_ratio = posterior_lp - prior_lp
    atomic_lp = log_ratio[:, 0] - torch.logsumexp(log_ratio, dim=-1)
    self._assert_all_finite(atomic_lp, "proposal posterior eval")

    if not self._use_combined_loss:
        return atomic_lp

    # XXX This evaluates the posterior on _all_ prior samples
    non_atomic_lp = self._neural_net.log_prob(theta, x)
    return masks.reshape(-1) * non_atomic_lp + atomic_lp
def _get_log_prob_proposal_posterior(self, inputs, context, masks):
    """Per-sample log-probability of the atomic proposal posterior.

    There are two main options for generating atoms when evaluating the
    proposal posterior. (1) Generate them from the proposal prior. (2) Generate
    them from a more targeted distribution, such as the most recent posterior.
    With option (2) it is likely beneficial to skip the first round, since
    atoms would then be sampled from a randomly initialized neural density
    estimator.

    Args:
        inputs: torch.Tensor
            Batch of parameters, indexed as ``inputs[:, None, :]``.
        context: torch.Tensor
            Batch of observations, row-aligned with ``inputs``.
        masks: torch.Tensor
            binary, whether or not to retrain with prior loss on specific
            prior sample (used only when ``self._use_combined_loss``).

    Returns:
        torch.Tensor with one calibration-weighted log-probability per batch
        item. This assumes ``self.calibration_kernel(context)`` returns one
        weight per row (TODO confirm).
    """
    batch_size = inputs.shape[0]
    num_atoms = self._num_atoms if self._num_atoms > 0 else batch_size

    # Each set of parameter atoms is evaluated using the same observation,
    # so we repeat rows of the context, e.g. [1, 2] -> [1, 1, 2, 2].
    repeated_context = utils.repeat_rows(context, num_atoms)

    # To generate the full set of atoms for a given item in the batch, we
    # sample without replacement num_atoms - 1 times from the rest of the
    # parameters in the batch (the zeroed diagonal excludes the item itself).
    assert 0 < num_atoms - 1 < batch_size
    probs = (
        (1 / (batch_size - 1))
        * torch.ones(batch_size, batch_size)
        * (1 - torch.eye(batch_size))
    )
    choices = torch.multinomial(probs, num_samples=num_atoms - 1, replacement=False)
    contrasting_inputs = inputs[choices]

    # Atom sets: slot 0 is the true parameter, then the contrasting ones.
    atomic_inputs = torch.cat(
        (inputs[:, None, :], contrasting_inputs), dim=1
    ).reshape(batch_size * num_atoms, -1)

    # Evaluate large batch giving (batch_size * num_atoms) log prob posterior evals.
    log_prob_posterior = self._neural_posterior.log_prob(
        atomic_inputs, repeated_context, normalize_snpe_density=False
    )
    assert torch.isfinite(
        log_prob_posterior
    ).all(), "NaN/inf detected in posterior eval."
    log_prob_posterior = log_prob_posterior.reshape(batch_size, num_atoms)

    # Get (batch_size * num_atoms) log prob prior evals.
    log_prob_prior = self._prior.log_prob(atomic_inputs)
    log_prob_prior = log_prob_prior.reshape(batch_size, num_atoms)
    assert torch.isfinite(log_prob_prior).all(), "NaN/inf detected in prior eval."

    # Compute unnormalized proposal posterior.
    unnormalized_log_prob_proposal_posterior = log_prob_posterior - log_prob_prior

    # Normalize proposal posterior across the discrete set of atoms
    # (log-softmax of the true atom, index 0).
    normalized_log_prob = unnormalized_log_prob_proposal_posterior[
        :, 0
    ] - torch.logsumexp(unnormalized_log_prob_proposal_posterior, dim=-1)

    # The calibration kernel weights each sample's whole normalized
    # log-probability. Without the grouping, `kernel * u[:, 0] - logsumexp(u)`
    # would scale only the numerator because of operator precedence.
    log_prob_proposal_posterior = self.calibration_kernel(context) * normalized_log_prob
    assert torch.isfinite(
        log_prob_proposal_posterior
    ).all(), "NaN/inf detected in proposal posterior eval."

    # todo: this implementation is not perfect: it evaluates the posterior
    # todo: at all prior samples
    if self._use_combined_loss:
        log_prob_posterior_non_atomic = self._neural_posterior.log_prob(
            inputs, context, normalize_snpe_density=False
        )
        masks = masks.reshape(-1)
        log_prob_proposal_posterior = (
            masks * log_prob_posterior_non_atomic + log_prob_proposal_posterior
        )

    return log_prob_proposal_posterior