Example #1
    def mask_tokens(
        self, in_batch: torch.Tensor, lab_batch: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        true_inputs = in_batch.clone()

        # Sample masking candidates with probability mlm_probability,
        # excluding special tokens and padding.
        probability_matrix = torch.full(lab_batch.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(
                val, already_has_special_tokens=True)
            for val in lab_batch.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask,
                                                     dtype=torch.bool),
                                        value=0.0)
        if self.tokenizer.pad_token is not None:
            padding_mask = lab_batch.eq(self.tokenizer.pad_token_id)
            lab_batch[padding_mask] = -100  # ignored by the loss
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # Keep original tokens only at masked positions; -100 everywhere else.
        true_inputs[~masked_indices] = -100

        # 50% of the masked positions become the [MASK] token.
        indices_replaced = torch.bernoulli(torch.full(
            lab_batch.shape, 0.5)).bool() & masked_indices
        in_batch[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token)

        # 60% of the remainder (30% overall) are replaced via replace_sim1.
        indices_sim = torch.bernoulli(torch.full(
            lab_batch.shape, 0.6)).bool() & masked_indices & ~indices_replaced
        in_batch[indices_sim] = self.replace_sim1(in_batch[indices_sim])

        # The rest (20% overall) are replaced via replace_sim2.
        indices_random = masked_indices & ~indices_replaced & ~indices_sim
        in_batch[indices_random] = self.replace_sim2(in_batch[indices_random])

        return in_batch, lab_batch, true_inputs
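The three chained Bernoulli draws are the subtle part: 50% of the masked positions get the [MASK] token, 60% of what remains (30% overall) go through replace_sim1, and the final 20% go through replace_sim2. A self-contained sketch of just that index bookkeeping (no tokenizer; the shape and masking rate are chosen arbitrarily) verifies that the three subsets partition the masked positions:

import torch

torch.manual_seed(0)
shape = (4, 16)
masked_indices = torch.bernoulli(torch.full(shape, 0.15)).bool()

# Same three draws as in mask_tokens above.
indices_replaced = torch.bernoulli(
    torch.full(shape, 0.5)).bool() & masked_indices
indices_sim = torch.bernoulli(
    torch.full(shape, 0.6)).bool() & masked_indices & ~indices_replaced
indices_random = masked_indices & ~indices_replaced & ~indices_sim

# The subsets are pairwise disjoint and together cover every masked position.
assert ((indices_replaced | indices_sim | indices_random)
        == masked_indices).all()
assert not (indices_replaced & indices_sim).any()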
Example #2
    def init_prototypes(self, input_data: torch.Tensor, labels: torch.Tensor):
        """
        Initialize prototypes from the provided data and labels.
        """
        input_data = self._get_embeddings(input_data)
        self._unique_classes = torch.unique(
            labels)  # unique classes from support
        classes_indexes = [
            labels.eq(c).nonzero().squeeze(1) for c in self._unique_classes
        ]  # indexes of each class in the support set
        # One prototype per class: the mean embedding of its support examples.
        self._prototypes = torch.stack(
            [input_data[idxs].mean(0) for idxs in classes_indexes])
        self._has_prototypes = True
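The prototype computation itself does not depend on the surrounding class, so it can be sanity-checked in isolation. A minimal sketch, assuming pre-computed 4-dimensional embeddings in place of self._get_embeddings:

import torch

embeddings = torch.randn(6, 4)              # 6 support examples, 4-dim each
labels = torch.tensor([0, 0, 1, 1, 2, 2])   # 3 classes, 2 examples per class

unique_classes = torch.unique(labels)
classes_indexes = [labels.eq(c).nonzero().squeeze(1) for c in unique_classes]
prototypes = torch.stack(
    [embeddings[idxs].mean(0) for idxs in classes_indexes])

print(prototypes.shape)  # torch.Size([3, 4]): one mean embedding per class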
Example #3
def mask_fill(
    fill_value: float,
    tokens: torch.Tensor,
    embeddings: torch.Tensor,
    padding_index: int,
) -> torch.Tensor:
    """
    Masks the embeddings that represent padded elements.

    :param fill_value: Value used to fill the embeddings of padded tokens.
    :param tokens: The input sequences [bsz x seq_len].
    :param embeddings: Word embeddings [bsz x seq_len x hiddens].
    :param padding_index: Index of the padding token.
    """
    padding_mask = tokens.eq(padding_index).unsqueeze(-1)
    # masked_fill_ mutates the float() copy, so the caller's tensor is untouched.
    return embeddings.float().masked_fill_(padding_mask, fill_value).type_as(embeddings)
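A quick usage sketch with toy tensors (the shapes and the padding index of 0 are assumptions for illustration):

import torch

tokens = torch.tensor([[5, 6, 7, 0],
                       [8, 9, 0, 0]])   # 0 plays the role of the pad token
embeddings = torch.randn(2, 4, 3)

masked = mask_fill(0.0, tokens, embeddings, padding_index=0)
print(masked[0, 3])   # tensor([0., 0., 0.]): the padded position is zeroed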
Example #4
    def __call__(self, predicted: torch.Tensor, gt: torch.Tensor, gamma=2.0):
        # Out-of-place sigmoid so the caller's logits are not mutated;
        # clamping keeps the log() terms finite.
        pred = torch.clamp(predicted.sigmoid(), min=1e-4, max=1 - 1e-4)
        pos_inds = gt.eq(1).float()
        neg_inds = gt.lt(1).float()

        # Penalty-reduced focal loss: negatives near a Gaussian peak are
        # down-weighted by (1 - gt)^4, as in CornerNet/CenterNet.
        neg_weights = torch.pow(1 - gt, 4)
        pos_loss = -torch.log(pred) * torch.pow(1 - pred, gamma) * pos_inds
        neg_loss = -torch.log(1 - pred) * torch.pow(
            pred, gamma) * neg_inds * neg_weights

        num_pos = pos_inds.sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        if num_pos == 0:
            return neg_loss
        return (pos_loss + neg_loss) / num_pos
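A minimal smoke test. The snippet only shows __call__, so HeatmapFocalLoss below is an assumed name for the class that owns it, not one given in the source:

import torch

loss_fn = HeatmapFocalLoss()        # assumed class name wrapping __call__

logits = torch.randn(2, 1, 8, 8)    # raw network outputs (pre-sigmoid)
gt = torch.zeros(2, 1, 8, 8)        # Gaussian heatmap targets in [0, 1]
gt[0, 0, 3, 3] = 1.0                # one positive center

loss = loss_fn(logits, gt)
print(loss.item())                  # scalar, normalized by the positive count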