예제 #1
0
    def __init__(self, cont, logits=None, probs=None, validate_args=None):
        """Build a three-component mixture with a continuous part on (0, 1).

        The last dimension of ``logits``/``probs`` holds three mixture weights,
        unpacked in order into ``p0``, ``p1`` and ``pc`` (and ``log_p0``,
        ``log_p1``, ``log_pc`` from a log-softmax of the logits). Presumably
        ``p0``/``p1`` weight point masses at 0 and 1 and ``pc`` weights ``cont``
        -- confirm against the rest of the class.

        cont: a (properly normalised) distribution over (0, 1)
            e.g. RightTruncatedExponential, Uniform(0, 1); its ``batch_shape``
            becomes this distribution's batch shape.
        logits: [..., 3]
        probs: [..., 3]
            Exactly one of ``logits`` / ``probs`` must be given, otherwise a
            ValueError is raised. The other one is derived with
            ``is_binary=False``.
        validate_args: forwarded to the base-class constructor.

        Also stores ``self.uniform``, a Uniform(0, 1) with the same batch shape,
        placed on the device of the logits.
        """
        has_logits = logits is not None
        has_probs = probs is not None
        if not (has_logits or has_probs):
            raise ValueError("You must specify either logits or probs")
        if has_logits and has_probs:
            raise ValueError("You cannot specify both logits and probs")

        batch_shape = cont.batch_shape
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

        if has_logits:
            self.logits = logits
            self.probs = logits_to_probs(logits, is_binary=False)
        else:
            self.logits = probs_to_logits(probs, is_binary=False)
            self.probs = probs

        # Normalised log-weights, computed from the logits whichever input was given.
        self.logprobs = F.log_softmax(self.logits, dim=-1)
        self.cont = cont
        # unbind over the last dim == split into size-1 chunks + squeeze(-1).
        self.p0, self.p1, self.pc = torch.unbind(self.probs, dim=-1)
        self.log_p0, self.log_p1, self.log_pc = torch.unbind(self.logprobs, dim=-1)

        device = self.logits.device
        self.uniform = Uniform(torch.zeros(batch_shape, device=device),
                               torch.ones(batch_shape, device=device))
예제 #2
0
 def _convert_logits_to_ps(self, dist_params):
     if 'logits' in dist_params:
         logits = torch.tensor(dist_params.pop('logits'))
         is_multidimensional = self.get_test_distribution_name() != 'Bernoulli'
         probs = logits_to_probs(logits, is_binary=not is_multidimensional)
         dist_params['probs'] = list(probs.detach().cpu().numpy())
     return dist_params
예제 #3
0
 def _convert_logits_to_ps(self, dist_params):
     if 'logits' in dist_params:
         logits = torch.tensor(dist_params.pop('logits'))
         is_multidimensional = self.get_test_distribution_name() != 'Bernoulli'
         probs = logits_to_probs(logits, is_binary=not is_multidimensional)
         dist_params['probs'] = list(probs.detach().cpu().numpy())
     return dist_params
예제 #4
0
 def aggregate_predictions(self, predictions, dim=0):
     """Average ``predictions`` along ``dim``.

     When ``self.logit_predictions`` is false the predictions are probabilities
     and are averaged directly. Otherwise they are logits: they are converted to
     probabilities, averaged, and the mean is converted back to logits, all with
     ``is_binary=self.is_binary``.
     """
     if not self.logit_predictions:
         return predictions.mean(dim)
     as_probs = dist_utils.logits_to_probs(predictions, is_binary=self.is_binary)
     mean_probs = as_probs.mean(dim)
     return dist_utils.probs_to_logits(mean_probs, is_binary=self.is_binary)
예제 #5
0
 def _convert_logits_to_ps(self, dist_params):
     if "logits" in dist_params:
         logits = torch.tensor(dist_params.pop("logits"))
         is_multidimensional = self.get_test_distribution_name() not in [
             "Bernoulli",
             "Geometric",
         ]
         probs = logits_to_probs(logits, is_binary=not is_multidimensional)
         dist_params["probs"] = list(probs.detach().cpu().numpy())
     return dist_params
예제 #6
0
    def predict_next_q_values(self, next_observations: Dict[Union[str, int], Dict[str, torch.Tensor]],
                              next_actions: Dict[Union[str, int], Dict[str, torch.Tensor]],
                              next_actions_logits: Dict[Union[str, int], Dict[str, torch.Tensor]],
                              next_actions_log_probs: Dict[Union[str, int], Dict[str, torch.Tensor]],
                              alpha: Dict[Union[str, int], torch.Tensor]) \
            -> Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """implementation of
        :class:`~maze.core.agent.torch_state_action_critic.TorchStateActionCritic`

        Shared-critic variant:
        - The per-step input dicts are flattened with ``flatten_spaces``.
        - A single critic key is expected (asserted).
        - The entropy coefficients of all steps are summed.
        - The result is ``{step_id: value}``.

        If every step has only discrete spaces, ``value`` is a dict keyed by
        action name (critic head name with ``'_q_values'`` stripped). Each entry
        is ``sum_a pi(a) * (min_critics Q(a) - alpha * log pi(a))``, with zero
        probabilities replaced by 1e-8 before the log. Otherwise ``value`` is
        ``min_critics Q(s', a') - alpha * mean(log pi(a'))``, where the mean is
        over the flattened action log-probabilities.
        """
        obs_flat = flatten_spaces(next_observations.values())
        actions_flat = flatten_spaces(next_actions.values())
        logits_flat = flatten_spaces(next_actions_logits.values())
        log_probs_flat = flatten_spaces(next_actions_log_probs.values())

        assert len(self.step_critic_keys) == 1
        step_id = self.step_critic_keys[0]
        total_alpha = sum(alpha.values())
        critic_id = (step_id, self.target_key)

        if not all(self.only_discrete_spaces.values()):
            per_critic_q = self.compute_state_action_value_step(obs_flat, actions_flat, critic_id)
            min_q = torch.stack(per_critic_q).min(dim=0).values
            mean_log_prob = torch.stack(list(log_probs_flat.values())).mean(dim=0)
            return {step_id: min_q - total_alpha * mean_log_prob}

        per_critic_q = self.compute_state_action_values_step(obs_flat, critic_id=critic_id)
        # Regroup [critic -> {head: q}] into {head: [q per critic]}.
        q_by_head = {head: [critic_out[head] for critic_out in per_critic_q] for head in per_critic_q[0]}
        soft_values = dict()
        for head, head_q_values in q_by_head.items():
            action_key = head.replace('_q_values', '')
            min_q = torch.stack(head_q_values).min(dim=0).values
            pi = logits_to_probs(logits_flat[action_key])
            log_pi = torch.log(pi + (pi == 0.0).float() * 1e-8)
            # output shape of V(st) is (rollout_length, batch_dim)
            weighted = torch.matmul(pi.unsqueeze(-2), (min_q - total_alpha * log_pi).unsqueeze(-1))
            soft_values[action_key] = weighted.squeeze(-1).squeeze(-1)

        return {step_id: soft_values}
예제 #7
0
    def predict_next_q_values(self, next_observations: Dict[Union[str, int], Dict[str, torch.Tensor]],
                              next_actions: Dict[Union[str, int], Dict[str, torch.Tensor]],
                              next_actions_logits: Dict[Union[str, int], Dict[str, torch.Tensor]],
                              next_actions_log_probs: Dict[Union[str, int], Dict[str, torch.Tensor]],
                              alpha: Dict[Union[str, int], torch.Tensor]) \
            -> Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]]:
        """implementation of
        :class:`~maze.core.agent.torch_state_action_critic.TorchStateActionCritic`

        Per-step variant: each key of ``next_observations`` is evaluated with the
        target critic ``(step_id, self.target_key)`` and its own ``alpha[step_id]``.

        - Discrete-only steps map to a dict keyed by action name (critic head name
          with ``'_q_values'`` stripped). Each entry is
          ``sum_a pi(a) * (min_critics Q(a) - alpha * log pi(a))``, with zero
          probabilities replaced by 1e-8 before the log.
        - Other steps map to ``min_critics Q(s', a') - alpha * mean(log pi(a'))``,
          where the mean is over the step's action log-probabilities.
        """
        result = dict()
        for step_id, step_obs in next_observations.items():
            critic_id = (step_id, self.target_key)

            if not self.only_discrete_spaces[step_id]:
                per_critic_q = self.compute_state_action_value_step(step_obs, next_actions[step_id], critic_id)
                # output shape of V(st) is (rollout_length, batch_size)
                min_q = torch.stack(per_critic_q).min(dim=0).values
                mean_log_prob = torch.stack(list(next_actions_log_probs[step_id].values())).mean(dim=0)
                result[step_id] = min_q - alpha[step_id] * mean_log_prob
                continue

            per_critic_q = self.compute_state_action_values_step(step_obs, critic_id=critic_id)
            # Regroup [critic -> {head: q}] into {head: [q per critic]}.
            q_by_head = {head: [critic_out[head] for critic_out in per_critic_q] for head in per_critic_q[0]}
            step_values = dict()
            result[step_id] = step_values
            for head, head_q_values in q_by_head.items():
                action_key = head.replace('_q_values', '')
                min_q = torch.stack(head_q_values).min(dim=0).values
                pi = logits_to_probs(next_actions_logits[step_id][action_key])
                log_pi = torch.log(pi + (pi == 0.0).float() * 1e-8)
                # output shape of V(st) is (rollout_length, batch_dim)
                weighted = torch.matmul(pi.unsqueeze(-2), (min_q - alpha[step_id] * log_pi).unsqueeze(-1))
                step_values[action_key] = weighted.squeeze(-1).squeeze(-1)

        return result
    def forward(self, model, target_model, sample, reduce=True):
        """Compute the loss for the given sample.

        The loss has up to three terms:
        - a label-smoothed ground-truth loss (``'label-loss'``) on ``model``'s logits;
        - a distillation loss (``'kd-loss'``) against the probabilities of
          ``target_model``, whose forward pass runs under ``torch.no_grad()``;
        - a length-prediction loss (``'length-loss'``), added when ``model``
          outputs a ``"length"`` entry.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size -- fixed to 1 here and only used for logging
        3) logging outputs to display while training; per-term entries are
           divided by the term's factor, and are Python scalars when ``reduce``
           is true and tensors otherwise.
        """
        nsentences, ntokens = sample["nsentences"], sample["ntokens"]
        # B x T
        src_tokens, src_lengths, prev_output_tokens = (
            sample["net_input"]["src_tokens"],
            sample["net_input"]["src_lengths"],
            sample["net_input"]["prev_output_tokens"],
        )
        tgt_tokens, nat_prev_output_tokens = sample["target"], sample["prev_target"]

        # Forward target_model without gradients. NATransformerModel instances are
        # fed ``prev_target``; other models get ``prev_output_tokens``.
        with torch.no_grad():
            target_model_outputs = target_model(
                src_tokens,
                src_lengths,
                nat_prev_output_tokens
                if isinstance(target_model, NATransformerModel) else prev_output_tokens,
                tgt_tokens,
            )
            target_model_logits, target_model_masks = (
                target_model_outputs["word_ins"]["out"],
                target_model_outputs["word_ins"].get("mask", None),
            )

        # Forward the trained model (same input selection rule as above).
        outputs = model(
            src_tokens,
            src_lengths,
            nat_prev_output_tokens
            if isinstance(model, NATransformerModel) else prev_output_tokens,
            tgt_tokens,
        )
        model_logits, model_masks, smoothing = (
            outputs["word_ins"]["out"],
            outputs["word_ins"].get("mask", None),
            outputs["word_ins"].get("ls", 0.0),
        )

        # Model loss:
        # 1. label smoothed ground-truth loss (label loss)
        # 2. kd loss against the target model's (detached) probabilities
        lb_losses = self._compute_loss(
            model_logits,
            tgt_tokens,
            model_masks,
            smoothing,
            name='label-loss',
            factor=1.  # 1. - kd_factor
        )

        # NOTE(review): both masks come from ``.get("mask", None)``; logical_and
        # will fail if either is None -- confirm masks are always present.
        kd_losses = self._compute_loss_ctrl(
            model_logits,
            logits_to_probs(target_model_logits).detach(),
            torch.logical_and(model_masks, target_model_masks),
            name='kd-loss',
            factor=model.kd_factor,
            controller=model.controller if model.use_control_kd_factor else None,
        )

        losses = [
            lb_losses,
            kd_losses,
        ]

        # Length prediction module: length prediction loss.
        if "length" in outputs:
            length_losses = self._compute_loss(
                outputs["length"].get("out"),
                outputs["length"].get("tgt"),
                name="length-loss",
                factor=outputs["length"].get("factor", 1.0),
            )
            losses += [length_losses]

        loss = sum(loss_info["loss"] for loss_info in losses)
        nll_loss = loss.new_tensor(0)

        # NOTE:
        # we don't need to use sample_size as denominator for the gradient
        # here sample_size is just used for logging
        sample_size = 1
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        for loss_info in losses:
            # Log each term divided by its factor. Indexing must use the "loss"
            # key itself: a list index (``[["loss"]]``) raises TypeError on a dict.
            # NOTE(review): a factor of 0 (e.g. kd_factor=0) yields inf/nan here.
            scaled = loss_info["loss"].data / loss_info["factor"]
            logging_output[loss_info["name"]] = utils.item(scaled) if reduce else scaled

        return loss, sample_size, logging_output
예제 #9
0
 def gate(self):
     """Return the gate probabilities derived from ``self.gate_logits``.

     Uses ``logits_to_probs`` with its default (non-binary) conversion.
     """
     gate_logits = self.gate_logits
     gate_probs = logits_to_probs(gate_logits)
     return gate_probs
예제 #10
0
 def probs(self):
     """Return per-element (binary) probabilities computed from ``self.logits``."""
     binary_logits = self.logits
     return logits_to_probs(binary_logits, is_binary=True)
예제 #11
0
 def mixture_probs(self) -> torch.Tensor:
     """Return the mixture weights as per-element (binary) probabilities of ``self.mixture_logits``."""
     weights_logits = self.mixture_logits
     return logits_to_probs(weights_logits, is_binary=True)
예제 #12
0
 def zi_probs(self) -> torch.Tensor:
     """Return per-element (binary) probabilities computed from ``self.zi_logits``."""
     inflation_logits = self.zi_logits
     return logits_to_probs(inflation_logits, is_binary=True)
예제 #13
0
 def probs(self):
     """Return probabilities from ``self.logits`` using ``logits_to_probs``'s default conversion."""
     raw_logits = self.logits
     return logits_to_probs(raw_logits)
예제 #14
0
def _perplexity_class_test(
    rank: int,
    worldsize: int,
    probs: Optional[torch.Tensor],
    logits: Optional[torch.Tensor],
    dist_sync_on_step: bool,
    metric_args: Optional[dict] = None,
    check_dist_sync_on_step: bool = True,
    check_batch: bool = True,
    atol: float = 1e-8,
):
    """ Utility function doing the actual comparison between lightning class metric
        and reference metric.
        Args:
            rank: rank of current process
            worldsize: number of processes
            probs: torch tensor with probabilities
            logits: torch tensor with logits. ``probs`` and ``logits`` are mutually
                exclusive for the ``Perplexity`` metric: if both or neither are given,
                the function only checks that calling the metric raises ``ValueError``.
            dist_sync_on_step: bool, if true will synchronize metric state across
                processes at each ``forward()``
            metric_args: dict with additional arguments used for class initialization;
                ``None`` (the default) means no additional arguments
            check_dist_sync_on_step: bool, if true will check if the metric is also correctly
                calculated per batch per device (and not just at the end)
            check_batch: bool, if true will check if the metric is also correctly
                calculated across devices for each batch (and not just at the end)
            atol: absolute tolerance used by every ``np.allclose`` comparison
    """
    # A fresh dict per call instead of a mutable default shared across calls.
    metric_args = {} if metric_args is None else metric_args

    # Instantiate lightning metric
    perplexity = Perplexity(compute_on_step=True,
                            dist_sync_on_step=dist_sync_on_step,
                            **metric_args)
    if (probs is None) == (logits is None):
        with pytest.raises(ValueError):
            perplexity(probs, logits)
        return

    # verify perplexity works after being loaded from pickled state
    pickled_metric = pickle.dumps(perplexity)
    perplexity = pickle.loads(pickled_metric)

    for i in range(rank, NUM_BATCHES, worldsize):
        batch_result = perplexity(None if probs is None else probs[i],
                                  None if logits is None else logits[i])

        if perplexity.dist_sync_on_step:
            if rank == 0:
                # Rebuild the batch every rank contributed at this step.
                if probs is not None:
                    ddp_probs = torch.stack(
                        [probs[i + r] for r in range(worldsize)])
                else:
                    ddp_logits = torch.stack(
                        [logits[i + r] for r in range(worldsize)])
                    ddp_probs = logits_to_probs(ddp_logits, is_binary=False)
                sk_batch_result = reference_perplexity_func(ddp_probs)
                # assert for dist_sync_on_step
                if check_dist_sync_on_step:
                    assert np.allclose(batch_result.numpy(),
                                       sk_batch_result,
                                       atol=atol)
        else:
            if probs is None:
                p = logits_to_probs(logits[i], is_binary=False)
            else:
                p = probs[i]
            sk_batch_result = reference_perplexity_func(p)
            # assert for batch
            if check_batch:
                assert np.allclose(batch_result.numpy(),
                                   sk_batch_result,
                                   atol=atol)

    assert (probs is None) != (logits is None)
    # check on all batches on all ranks
    result = perplexity.compute()
    assert isinstance(result, torch.Tensor)

    if probs is None:
        probs = logits_to_probs(logits, is_binary=False)
    sk_result = reference_perplexity_func(probs)

    # assert after aggregation
    assert np.allclose(result.numpy(), sk_result, atol=atol)
예제 #15
0
 def probs(self):
     """Return probabilities for ``self.logits`` (default ``logits_to_probs`` conversion)."""
     source_logits = self.logits
     converted = logits_to_probs(source_logits)
     return converted
예제 #16
0
 def _probsfn(self):
     return lambda conds: tcdu.logits_to_probs(self._logitsfn(conds),
                                               is_binary=True)
예제 #17
0
 def _probsfn(self):
     return lambda conds: tcdu.logits_to_probs(self._logitsfn(conds))
예제 #18
0
 def probs(self):
     """Return per-element (binary) probabilities for ``self.logits``."""
     elementwise_logits = self.logits
     success = logits_to_probs(elementwise_logits, is_binary=True)
     return success
예제 #19
0
    def _compute_policy_loss(self, worker_output: StructuredSpacesRecord) -> \
            Tuple[Dict[Union[str, int], torch.Tensor],
                  Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]],
                  Dict[Union[str, int], Union[torch.Tensor, Dict[str, torch.Tensor]]],
                  Dict[Union[str, int], torch.Tensor]]:
        """Compute the policy losses.

        For every sub-step, the learner policy is evaluated on the step's observations and an
        action is sampled from its distribution.
        - Steps whose critic has only discrete spaces keep, per action head,
          ``logits_to_probs`` of the action logits and their log. Exact zeros are replaced by
          1e-8 before taking the log.
        - Other steps keep the mean log-probability of the sampled action over all action heads.

        The critic is then queried for the Q-values of the sampled actions. A single shared
        critic output is reused for every sub-step. The loss of a step is:
        - discrete: per action head, the probability-weighted sum of
          ``entropy_coef * log_pi - min_critics Q``, summed over heads and then averaged;
        - otherwise: ``mean(entropy_coef * log_pi - min_critics Q)``.

        :param worker_output: The batched output of the workers.
        :return: The tuple ``(policy_losses, action_probs, action_log_probs, action_entropies)``,
            each keyed by sub-step key: the policy losses as well a few other metrics needed for
            the entropy loss computation and stats. ``action_probs`` holds ``None`` for
            non-discrete steps.
        """

        # Sample actions and compute action log probabilities (continuous steps)/ action probabilities (discrete steps)
        policy_losses, action_entropies, action_log_probs, actions_sampled = dict(
        ), dict(), dict(), dict()
        action_probs = dict()

        for step_key in self.sub_step_keys:
            step_obs = worker_output.observations_dict[step_key]
            learner_policy_out = self.learner_model.policy.compute_substep_policy_output(
                step_obs, ActorID(step_key, 0))
            learner_action = learner_policy_out.prob_dist.sample()

            # Average the logp_policy of all actions in this step (all steps if shared critic)
            if self.learner_model.critic.only_discrete_spaces[step_key]:
                # Discrete: full probability vector per action head.
                probs_policy = {
                    action_key: logits_to_probs(x)
                    for action_key, x in
                    learner_policy_out.action_logits.items()
                }
                # Replace exact zeros by 1e-8 so the log stays finite.
                logp_policy = {
                    action_key: torch.log(x + (x == 0.0).float() * 1e-8)
                    for action_key, x in probs_policy.items()
                }
            else:
                # Non-discrete: mean log-prob of the sampled action across action heads.
                probs_policy = None
                logp_policy = torch.stack(
                    list(
                        learner_policy_out.prob_dist.log_prob(
                            learner_action).values())).mean(dim=0)

            action_probs[step_key] = probs_policy
            action_log_probs[step_key] = logp_policy
            actions_sampled[step_key] = learner_action
            action_entropies[step_key] = learner_policy_out.entropy

        # Predict Q values
        q_values = self.learner_model.critic.predict_q_values(
            worker_output.observations_dict,
            actions_sampled,
            gather_output=False)
        # Shared critic: a single output is reused for every sub-step key.
        if len(q_values) < len(self.sub_step_keys):
            assert len(q_values) == 1
            critic_key = list(q_values.keys())[0]
            q_values = {
                step_key: q_values[critic_key]
                for step_key in self.sub_step_keys
            }

        # Compute loss
        for step_key in self.sub_step_keys:
            action_log_probs_step = action_log_probs[step_key]
            q_values_step = q_values[step_key]

            if self.learner_model.critic.only_discrete_spaces[step_key]:
                action_probs_step = action_probs[step_key]

                policy_losses_per_action = list()
                # Compute the policy loss for each individual action
                for action_key in action_log_probs_step.keys():
                    q_action_key = action_key + '_q_values'
                    # Minimum over the sub-critics (one entry per critic in q_values_step).
                    action_q_values = torch.stack([
                        q_values_sub_critic[q_action_key]
                        for q_values_sub_critic in q_values_step
                    ]).min(dim=0).values
                    q_term = (self.curr_entropy_coef[step_key] *
                              action_log_probs_step[action_key] -
                              action_q_values)
                    # Expectation over actions: dot product of probs and q_term on the last dim.
                    action_policy_loss = torch.matmul(
                        action_probs_step[action_key].unsqueeze(-2),
                        q_term.unsqueeze(-1)).squeeze(-1).squeeze(-1)
                    policy_losses_per_action.append(action_policy_loss)
                # Sum the losses of all action together
                policy_losses_per_step = torch.stack(
                    policy_losses_per_action).sum(dim=0)
                # Average the losses w.r.t. to the batch
                policy_losses[step_key] = policy_losses_per_step.mean()
            else:
                # Non-discrete step: q_values are deliberately not detached here (presumably so
                # gradients reach the policy through the sampled actions -- confirm).
                q_value_per_step = torch.stack(q_values_step).min(dim=0).values
                # Average the losses w.r.t. to the batch
                policy_losses[step_key] = torch.mean(
                    (self.curr_entropy_coef[step_key] * action_log_probs_step -
                     q_value_per_step))

        return policy_losses, action_probs, action_log_probs, action_entropies