def __init__(self, cont, logits=None, probs=None, validate_args=None):
    """
    cont: a (properly normalised) distribution over (0, 1),
        e.g. RightTruncatedExponential, Uniform(0, 1)
    logits: [..., 3]
    probs: [..., 3]
    """
    if logits is None and probs is None:
        raise ValueError("You must specify either logits or probs")
    if logits is not None and probs is not None:
        raise ValueError("You cannot specify both logits and probs")
    shape = cont.batch_shape
    super(MixtureD01C01, self).__init__(batch_shape=shape,
                                        validate_args=validate_args)
    if logits is None:
        self.logits = probs_to_logits(probs, is_binary=False)
        self.probs = probs
    else:
        self.logits = logits
        self.probs = logits_to_probs(logits, is_binary=False)
    self.logprobs = F.log_softmax(self.logits, dim=-1)
    self.cont = cont
    self.p0, self.p1, self.pc = [
        t.squeeze(-1) for t in torch.split(self.probs, 1, dim=-1)
    ]
    self.log_p0, self.log_p1, self.log_pc = [
        t.squeeze(-1) for t in torch.split(self.logprobs, 1, dim=-1)
    ]
    self.uniform = Uniform(
        torch.zeros(shape).to(self.logits.device),
        torch.ones(shape).to(self.logits.device))

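# A minimal standalone sketch (not part of the class above; the names `mix_probs` and
# `mix_logits` are illustrative) of the 3-way probs <-> logits conversion the constructor
# relies on: `probs_to_logits(..., is_binary=False)` takes the (clamped) log of normalised
# probabilities, and `logits_to_probs(..., is_binary=False)` applies a softmax back onto
# the simplex.
import torch
from torch.distributions.utils import logits_to_probs, probs_to_logits

mix_probs = torch.tensor([0.1, 0.2, 0.7])                   # P(x=0), P(x=1), P(x in (0, 1))
mix_logits = probs_to_logits(mix_probs, is_binary=False)    # log-probabilities
recovered = logits_to_probs(mix_logits, is_binary=False)    # softmax back to probabilities
assert torch.allclose(mix_probs, recovered)
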
def _convert_logits_to_ps(self, dist_params):
    if 'logits' in dist_params:
        logits = torch.tensor(dist_params.pop('logits'))
        is_multidimensional = self.get_test_distribution_name() != 'Bernoulli'
        probs = logits_to_probs(logits, is_binary=not is_multidimensional)
        dist_params['probs'] = list(probs.detach().cpu().numpy())
    return dist_params

def aggregate_predictions(self, predictions, dim=0):
    probs = dist_utils.logits_to_probs(
        predictions, is_binary=self.is_binary
    ) if self.logit_predictions else predictions
    avg_probs = probs.mean(dim)
    return dist_utils.probs_to_logits(
        avg_probs, is_binary=self.is_binary) if self.logit_predictions else avg_probs

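# Hedged, self-contained illustration (hypothetical values, not from the module above) of
# why the aggregation converts to probabilities first: averaging ensemble members in
# probability space and mapping back to logits is not the same as averaging raw logits.
import torch
from torch.distributions.utils import logits_to_probs, probs_to_logits

member_logits = torch.tensor([4.0, 0.0])                    # one confident member, one uncertain
probs = logits_to_probs(member_logits, is_binary=True)      # sigmoid -> [~0.982, 0.500]
avg_in_prob_space = probs_to_logits(probs.mean(dim=0), is_binary=True)   # logit(0.741) ~= 1.05
avg_in_logit_space = member_logits.mean(dim=0)                            # 2.0 -> sigmoid ~= 0.88
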
def _convert_logits_to_ps(self, dist_params):
    if "logits" in dist_params:
        logits = torch.tensor(dist_params.pop("logits"))
        is_multidimensional = self.get_test_distribution_name() not in [
            "Bernoulli",
            "Geometric",
        ]
        probs = logits_to_probs(logits, is_binary=not is_multidimensional)
        dist_params["probs"] = list(probs.detach().cpu().numpy())
    return dist_params

def predict_next_q_values(self, next_observations: Dict[Union[str, int], Dict[str, torch.Tensor]],
                          next_actions: Dict[Union[str, int], Dict[str, torch.Tensor]],
                          next_actions_logits: Dict[Union[str, int], Dict[str, torch.Tensor]],
                          next_actions_log_probs: Dict[Union[str, int], Dict[str, torch.Tensor]],
                          alpha: Dict[Union[str, int], torch.Tensor]) \
        -> Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """implementation of
    :class:`~maze.core.agent.torch_state_action_critic.TorchStateActionCritic`
    """
    flattened_next_observations = flatten_spaces(next_observations.values())
    flattened_next_actions = flatten_spaces(next_actions.values())
    flattened_next_actions_logits = flatten_spaces(next_actions_logits.values())
    flattened_next_action_log_probs = flatten_spaces(next_actions_log_probs.values())

    assert len(self.step_critic_keys) == 1
    step_id = self.step_critic_keys[0]
    alpha = sum(alpha.values())

    if all(self.only_discrete_spaces.values()):
        next_q_values = self.compute_state_action_values_step(flattened_next_observations,
                                                              critic_id=(step_id, self.target_key))
        transpose_next_q_value = {k: [dic[k] for dic in next_q_values] for k in next_q_values[0]}
        next_q_value = dict()
        for q_action_head, q_values in transpose_next_q_value.items():
            action_key = q_action_head.replace('_q_values', '')
            tmp_q_value = torch.stack(q_values).min(dim=0).values
            next_action_probs = logits_to_probs(flattened_next_actions_logits[action_key])
            next_action_log_probs = torch.log(next_action_probs
                                              + (next_action_probs == 0.0).float() * 1e-8)
            # output shape of V(st) is (rollout_length, batch_dim)
            next_q_value[action_key] = torch.matmul(
                next_action_probs.unsqueeze(-2),
                (tmp_q_value - alpha * next_action_log_probs).unsqueeze(-1)).squeeze(-1).squeeze(-1)
    else:
        next_q_value = self.compute_state_action_value_step(flattened_next_observations,
                                                            flattened_next_actions,
                                                            (step_id, self.target_key))
        next_q_value = torch.stack(next_q_value).min(dim=0).values - alpha * \
            torch.stack(list(flattened_next_action_log_probs.values())).mean(dim=0)

    return {step_id: next_q_value}

def predict_next_q_values(self, next_observations: Dict[Union[str, int], Dict[str, torch.Tensor]],
                          next_actions: Dict[Union[str, int], Dict[str, torch.Tensor]],
                          next_actions_logits: Dict[Union[str, int], Dict[str, torch.Tensor]],
                          next_actions_log_probs: Dict[Union[str, int], Dict[str, torch.Tensor]],
                          alpha: Dict[Union[str, int], torch.Tensor]) \
        -> Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]]:
    """implementation of
    :class:`~maze.core.agent.torch_state_action_critic.TorchStateActionCritic`
    """
    next_q_values = dict()
    for step_id in next_observations.keys():
        if self.only_discrete_spaces[step_id]:
            next_q_value = self.compute_state_action_values_step(next_observations[step_id],
                                                                 critic_id=(step_id, self.target_key))
            transpose_next_q_value = {k: [dic[k] for dic in next_q_value] for k in next_q_value[0]}
            next_q_values[step_id] = dict()
            for q_action_head, q_values in transpose_next_q_value.items():
                action_key = q_action_head.replace('_q_values', '')
                tmp_q_value = torch.stack(q_values).min(dim=0).values
                next_action_probs = logits_to_probs(next_actions_logits[step_id][action_key])
                next_action_log_probs = torch.log(next_action_probs
                                                  + (next_action_probs == 0.0).float() * 1e-8)
                # output shape of V(st) is (rollout_length, batch_dim)
                next_q_values[step_id][action_key] = torch.matmul(
                    next_action_probs.unsqueeze(-2),
                    (tmp_q_value - alpha[step_id] * next_action_log_probs).unsqueeze(-1)
                ).squeeze(-1).squeeze(-1)
        else:
            next_q_value = self.compute_state_action_value_step(next_observations[step_id],
                                                                next_actions[step_id],
                                                                (step_id, self.target_key))
            # output shape of V(st) is (rollout_length, batch_size)
            next_q_values[step_id] = torch.stack(next_q_value).min(dim=0).values - alpha[step_id] * \
                torch.stack(list(next_actions_log_probs[step_id].values())).mean(dim=0)
    return next_q_values

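# Standalone sketch (assumed tensor shapes and names, not the Maze API) of the discrete
# soft-value target computed in the two methods above:
# V(s) = sum_a pi(a|s) * (Q(s, a) - alpha * log pi(a|s)), evaluated as a batched dot
# product over the action dimension via matmul.
import torch
from torch.distributions.utils import logits_to_probs

batch, num_actions = 4, 3
action_logits = torch.randn(batch, num_actions)
q_values = torch.randn(batch, num_actions)      # in practice the min over an ensemble of critics
alpha = 0.2

probs = logits_to_probs(action_logits)          # softmax over the action dimension
log_probs = torch.log(probs + (probs == 0.0).float() * 1e-8)
soft_v = torch.matmul(probs.unsqueeze(-2),
                      (q_values - alpha * log_probs).unsqueeze(-1)).squeeze(-1).squeeze(-1)

# equivalent elementwise formulation of the same expectation
assert torch.allclose(soft_v, (probs * (q_values - alpha * log_probs)).sum(dim=-1), atol=1e-6)
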
def forward(self, model, target_model, sample, reduce=True):
    """Compute the loss for the given sample.

    Returns a tuple with three elements:
    1) the loss
    2) the sample size, which is used as the denominator for the gradient
    3) logging outputs to display while training
    """
    nsentences, ntokens = sample["nsentences"], sample["ntokens"]

    # B x T
    src_tokens, src_lengths, prev_output_tokens = (
        sample["net_input"]["src_tokens"],
        sample["net_input"]["src_lengths"],
        sample["net_input"]["prev_output_tokens"],
    )
    tgt_tokens, nat_prev_output_tokens = sample["target"], sample["prev_target"]

    # forward target_model
    with torch.no_grad():
        target_model_outputs = target_model(
            src_tokens, src_lengths,
            nat_prev_output_tokens if isinstance(target_model, NATransformerModel)
            else prev_output_tokens,
            tgt_tokens)
        target_model_logits, target_model_masks = (
            target_model_outputs["word_ins"]["out"],
            target_model_outputs["word_ins"].get("mask", None),
        )

    # forward model
    outputs = model(
        src_tokens, src_lengths,
        nat_prev_output_tokens if isinstance(model, NATransformerModel)
        else prev_output_tokens,
        tgt_tokens)
    model_logits, model_masks, smoothing = (
        outputs["word_ins"]["out"],
        outputs["word_ins"].get("mask", None),
        outputs["word_ins"].get("ls", 0.0),
    )

    # model loss
    #   1. label smoothed ground-truth loss (label loss)
    #   2. kd loss
    lb_losses = self._compute_loss(
        model_logits,
        tgt_tokens,
        model_masks,
        smoothing,
        name='label-loss',
        factor=1.  # 1. - kd_factor
    )
    kd_losses = self._compute_loss_ctrl(
        model_logits,
        logits_to_probs(target_model_logits).detach(),
        torch.logical_and(model_masks, target_model_masks),
        name='kd-loss',
        factor=model.kd_factor,
        controller=model.controller if model.use_control_kd_factor else None,
    )
    losses = [lb_losses, kd_losses]

    # length prediction module: length prediction loss
    if "length" in outputs:
        length_losses = self._compute_loss(outputs["length"].get("out"),
                                           outputs["length"].get("tgt"),
                                           name="length-loss",
                                           factor=outputs["length"].get("factor", 1.0))
        losses += [length_losses]

    loss = sum(l["loss"] for l in losses)
    nll_loss = loss.new_tensor(0)

    # NOTE:
    # we don't need to use sample_size as denominator for the gradient;
    # here sample_size is just used for logging
    sample_size = 1
    logging_output = {
        "loss": loss.data,
        "nll_loss": nll_loss.data,
        "ntokens": ntokens,
        "nsentences": nsentences,
        "sample_size": sample_size,
    }
    for l in losses:
        logging_output[l["name"]] = (utils.item(l["loss"].data / l["factor"])
                                     if reduce else l["loss"].data / l["factor"])
    return loss, sample_size, logging_output

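# Minimal sketch (hypothetical shapes and names, not the criterion's actual helpers) of the
# soft-target distillation term used above: the teacher's logits are turned into detached
# probabilities with logits_to_probs and used as the target distribution for the student's
# log-probabilities.
import torch
import torch.nn.functional as F
from torch.distributions.utils import logits_to_probs

batch, length, vocab = 2, 5, 8
student_logits = torch.randn(batch, length, vocab, requires_grad=True)
teacher_logits = torch.randn(batch, length, vocab)

teacher_probs = logits_to_probs(teacher_logits).detach()    # softmax, no gradient to the teacher
kd_loss = F.kl_div(F.log_softmax(student_logits, dim=-1),
                   teacher_probs, reduction='batchmean')
kd_loss.backward()
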
def gate(self):
    return logits_to_probs(self.gate_logits)

def probs(self):
    return logits_to_probs(self.logits, is_binary=True)

def mixture_probs(self) -> torch.Tensor:
    return logits_to_probs(self.mixture_logits, is_binary=True)

def zi_probs(self) -> torch.Tensor:
    return logits_to_probs(self.zi_logits, is_binary=True)

def probs(self):
    return logits_to_probs(self.logits)

def _perplexity_class_test(
    rank: int,
    worldsize: int,
    probs: Optional[torch.Tensor],
    logits: Optional[torch.Tensor],
    dist_sync_on_step: bool,
    metric_args: dict = {},
    check_dist_sync_on_step: bool = True,
    check_batch: bool = True,
    atol: float = 1e-8,
):
    """Utility function doing the actual comparison between the lightning class metric
    and the reference metric.

    Args:
        rank: rank of current process
        worldsize: number of processes
        probs: torch tensor with probabilities
        logits: torch tensor with logits. The function checks that ``probs`` and
            ``logits`` are mutually exclusive for the ``Perplexity`` metric.
        dist_sync_on_step: bool, if true will synchronize metric state across
            processes at each ``forward()``
        metric_args: dict with additional arguments used for class initialization
        check_dist_sync_on_step: bool, if true will check if the metric is also
            correctly calculated per batch per device (and not just at the end)
        check_batch: bool, if true will check if the metric is also correctly
            calculated across devices for each batch (and not just at the end)
    """
    # Instantiate the lightning metric
    perplexity = Perplexity(compute_on_step=True,
                            dist_sync_on_step=dist_sync_on_step,
                            **metric_args)

    if (probs is None) == (logits is None):
        with pytest.raises(ValueError):
            perplexity(probs, logits)
        return

    # verify perplexity works after being loaded from pickled state
    pickled_metric = pickle.dumps(perplexity)
    perplexity = pickle.loads(pickled_metric)

    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = perplexity(None if probs is None else probs[i],
                                  None if logits is None else logits[i])

        if perplexity.dist_sync_on_step:
            if rank == 0:
                if probs is not None:
                    ddp_probs = torch.stack([probs[i + r] for r in range(worldsize)])
                else:
                    ddp_logits = torch.stack([logits[i + r] for r in range(worldsize)])
                    ddp_probs = logits_to_probs(ddp_logits, is_binary=False)
                sk_batch_result = reference_perplexity_func(ddp_probs)
                # assert for dist_sync_on_step
                if check_dist_sync_on_step:
                    assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol)
        else:
            if probs is None:
                p = logits_to_probs(logits[i], is_binary=False)
            else:
                p = probs[i]
            sk_batch_result = reference_perplexity_func(p)
            # assert for batch
            if check_batch:
                assert np.allclose(batch_result.numpy(), sk_batch_result, atol=atol)

    assert (probs is None) != (logits is None)

    # check on all batches on all ranks
    result = perplexity.compute()
    assert isinstance(result, torch.Tensor)

    if probs is None:
        probs = logits_to_probs(logits, is_binary=False)
    sk_result = reference_perplexity_func(probs)

    # assert after aggregation
    assert np.allclose(result.numpy(), sk_result, atol=atol)

def _probsfn(self):
    return lambda conds: tcdu.logits_to_probs(self._logitsfn(conds), is_binary=True)

def _probsfn(self):
    return lambda conds: tcdu.logits_to_probs(self._logitsfn(conds))

def _compute_policy_loss(self, worker_output: StructuredSpacesRecord) -> \
        Tuple[Dict[Union[str, int], torch.Tensor],
              Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]],
              Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]],
              Dict[Union[str, int], torch.Tensor]]:
    """Compute the policy losses.

    :param worker_output: The batched output of the workers.
    :return: The policy losses as well as a few other metrics needed for the entropy loss
        computation and stats.
    """
    # Sample actions and compute action log probabilities (continuous steps) /
    # action probabilities (discrete steps)
    policy_losses, action_entropies, action_log_probs, actions_sampled = dict(), dict(), dict(), dict()
    action_probs = dict()
    for step_key in self.sub_step_keys:
        step_obs = worker_output.observations_dict[step_key]
        learner_policy_out = self.learner_model.policy.compute_substep_policy_output(
            step_obs, ActorID(step_key, 0))
        learner_action = learner_policy_out.prob_dist.sample()

        # Average the logp_policy of all actions in this step (all steps if shared critic)
        if self.learner_model.critic.only_discrete_spaces[step_key]:
            probs_policy = {action_key: logits_to_probs(x)
                            for action_key, x in learner_policy_out.action_logits.items()}
            logp_policy = {action_key: torch.log(x + (x == 0.0).float() * 1e-8)
                           for action_key, x in probs_policy.items()}
        else:
            probs_policy = None
            logp_policy = torch.stack(
                list(learner_policy_out.prob_dist.log_prob(learner_action).values())).mean(dim=0)

        action_probs[step_key] = probs_policy
        action_log_probs[step_key] = logp_policy
        actions_sampled[step_key] = learner_action
        action_entropies[step_key] = learner_policy_out.entropy

    # Predict Q values
    q_values = self.learner_model.critic.predict_q_values(
        worker_output.observations_dict, actions_sampled, gather_output=False)
    if len(q_values) < len(self.sub_step_keys):
        assert len(q_values) == 1
        critic_key = list(q_values.keys())[0]
        q_values = {step_key: q_values[critic_key] for step_key in self.sub_step_keys}

    # Compute loss
    for step_key in self.sub_step_keys:
        action_log_probs_step = action_log_probs[step_key]
        q_values_step = q_values[step_key]

        if self.learner_model.critic.only_discrete_spaces[step_key]:
            action_probs_step = action_probs[step_key]
            policy_losses_per_action = list()
            # Compute the policy loss for each individual action
            for action_key in action_log_probs_step.keys():
                q_action_key = action_key + '_q_values'
                action_q_values = torch.stack(
                    [q_values_sub_critic[q_action_key]
                     for q_values_sub_critic in q_values_step]).min(dim=0).values
                q_term = (self.curr_entropy_coef[step_key] * action_log_probs_step[action_key]
                          - action_q_values)
                action_policy_loss = torch.matmul(
                    action_probs_step[action_key].unsqueeze(-2),
                    q_term.unsqueeze(-1)).squeeze(-1).squeeze(-1)
                policy_losses_per_action.append(action_policy_loss)

            # Sum the losses of all actions together
            policy_losses_per_step = torch.stack(policy_losses_per_action).sum(dim=0)
            # Average the losses w.r.t. the batch
            policy_losses[step_key] = policy_losses_per_step.mean()
        else:
            # Do not detach q_values in discrete setting
            q_value_per_step = torch.stack(q_values_step).min(dim=0).values
            # Average the losses w.r.t. the batch
            policy_losses[step_key] = torch.mean(
                self.curr_entropy_coef[step_key] * action_log_probs_step - q_value_per_step)

    return policy_losses, action_probs, action_log_probs, action_entropies

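# Standalone sketch (hypothetical tensors, not the Maze API) of the closed-form discrete
# soft actor-critic policy loss computed per action head above:
# L = mean_batch( sum_a pi(a|s) * (alpha * log pi(a|s) - Q(s, a)) ), with gradients flowing
# through the policy probabilities only.
import torch
from torch.distributions.utils import logits_to_probs

batch, num_actions = 4, 3
action_logits = torch.randn(batch, num_actions, requires_grad=True)
q_values = torch.randn(batch, num_actions)      # in practice the min over the critic ensemble
alpha = 0.2

probs = logits_to_probs(action_logits)
log_probs = torch.log(probs + (probs == 0.0).float() * 1e-8)
policy_loss = (probs * (alpha * log_probs - q_values)).sum(dim=-1).mean()
policy_loss.backward()
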