def mask_tokens(
    self, in_batch: torch.Tensor, lab_batch: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Prepare corrupted inputs and labels for masked language modeling."""
    true_inputs = in_batch.clone()
    # Every token is a masking candidate with probability `mlm_probability`,
    # except special tokens and padding.
    probability_matrix = torch.full(lab_batch.shape, self.mlm_probability)
    special_tokens_mask = [
        self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True)
        for val in lab_batch.tolist()
    ]
    probability_matrix.masked_fill_(
        torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
    if self.tokenizer._pad_token is not None:
        padding_mask = lab_batch.eq(self.tokenizer.pad_token_id)
        lab_batch[padding_mask] = -100  # ignored by the loss
        probability_matrix.masked_fill_(padding_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    # Keep the original ids only at masked positions; -100 everywhere else.
    true_inputs[~masked_indices] = -100
    # 50% of the masked positions become the [MASK] token.
    indices_replaced = torch.bernoulli(
        torch.full(lab_batch.shape, 0.5)).bool() & masked_indices
    in_batch[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
        self.tokenizer.mask_token)
    # 60% of the remaining masked positions (30% overall) go through replace_sim1.
    indices_sim = torch.bernoulli(
        torch.full(lab_batch.shape, 0.6)).bool() & masked_indices & ~indices_replaced
    in_batch[indices_sim] = self.replace_sim1(in_batch[indices_sim])
    # The rest (20% overall) go through replace_sim2.
    indices_random = masked_indices & ~indices_replaced & ~indices_sim
    in_batch[indices_random] = self.replace_sim2(in_batch[indices_random])
    return in_batch, lab_batch, true_inputs
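# Hedged sketch: the splitting logic used inside mask_tokens, isolated so the
# 50% / 30% / 20% breakdown of masked positions is easy to verify. The tensor
# names and shapes below are illustrative only and not part of the class above.
import torch

torch.manual_seed(0)
shape = (1000, 64)
masked = torch.bernoulli(torch.full(shape, 0.15)).bool()                    # MLM candidates
replaced = torch.bernoulli(torch.full(shape, 0.5)).bool() & masked          # -> [MASK] token
sim = torch.bernoulli(torch.full(shape, 0.6)).bool() & masked & ~replaced   # -> replace_sim1
rand = masked & ~replaced & ~sim                                            # -> replace_sim2
total = masked.sum().item()
print(replaced.sum().item() / total,   # ~0.5
      sim.sum().item() / total,        # ~0.3
      rand.sum().item() / total)       # ~0.2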
def init_prototypes(self, input_data: torch.Tensor, labels: torch.Tensor):
    """Initialize class prototypes from the provided support data and labels."""
    input_data = self._get_embeddings(input_data)
    # Unique classes present in the support set.
    self._unique_classes = torch.unique(labels)
    # Indexes of the support examples belonging to each class.
    classes_indexes = [
        labels.eq(c).nonzero().squeeze(1) for c in self._unique_classes
    ]
    # Each prototype is the mean embedding of its class.
    self._prototypes = torch.stack(
        [input_data[idxs].mean(0) for idxs in classes_indexes])
    self._has_prototypes = True
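# Hedged sketch of the prototype computation above, applied to raw embeddings
# instead of self._get_embeddings. The toy tensors are illustrative only.
import torch

embeddings = torch.randn(6, 4)                 # 6 support examples, 4-dim embeddings
labels = torch.tensor([0, 0, 1, 1, 2, 2])
unique_classes = torch.unique(labels)
class_indexes = [labels.eq(c).nonzero().squeeze(1) for c in unique_classes]
prototypes = torch.stack([embeddings[idxs].mean(0) for idxs in class_indexes])
print(prototypes.shape)                        # torch.Size([3, 4]): one prototype per class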
def mask_fill(
    fill_value: float,
    tokens: torch.Tensor,
    embeddings: torch.Tensor,
    padding_index: int,
) -> torch.Tensor:
    """
    Function that masks embeddings representing padded elements.

    :param fill_value: the value to fill the embeddings belonging to padded tokens.
    :param tokens: The input sequences [bsz x seq_len].
    :param embeddings: word embeddings [bsz x seq_len x hiddens].
    :param padding_index: Index of the padding token.
    """
    padding_mask = tokens.eq(padding_index).unsqueeze(-1)
    return embeddings.float().masked_fill_(
        padding_mask, fill_value).type_as(embeddings)
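# Hedged usage sketch for mask_fill above: zero out the embedding rows that
# correspond to padding tokens. Toy tensors only; padding_index=0 is an assumption.
import torch

tokens = torch.tensor([[5, 7, 0, 0]])          # one sequence, last two positions padded
embeddings = torch.randn(1, 4, 3)              # [bsz x seq_len x hiddens]
masked = mask_fill(0.0, tokens, embeddings, padding_index=0)
print(masked[0, 2:])                           # rows for the padded positions are all zeros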
def __call__(self, predicted: torch.Tensor, gt: torch.Tensor, gamma=2.0):
    # Clamp the sigmoid output away from 0 and 1 so log() stays finite.
    # Note: sigmoid_() modifies `predicted` in place.
    pred = torch.clamp(predicted.sigmoid_(), min=1e-4, max=1 - 1e-4)
    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
    # Negative locations with gt close to 1 are down-weighted.
    neg_weights = torch.pow(1 - gt, 4)
    pos_loss = -torch.log(pred) * torch.pow(1 - pred, gamma) * pos_inds
    neg_loss = -torch.log(1 - pred) * torch.pow(pred, gamma) * neg_inds * neg_weights
    num_pos = pos_inds.float().sum()
    pos_loss = pos_loss.sum()
    neg_loss = neg_loss.sum()
    if num_pos == 0:
        return neg_loss
    # Normalize by the number of positive locations.
    return (pos_loss + neg_loss) / num_pos
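# Hedged numeric sketch of the focusing term used in the loss above: with
# gamma = 2, a confident correct prediction contributes far less loss than an
# uncertain one. The probabilities below are illustrative values only.
import torch

gamma = 2.0
for p in (0.6, 0.9, 0.99):                     # predicted probability at a positive (gt = 1)
    pred = torch.tensor(p)
    pos_loss = -torch.log(pred) * torch.pow(1 - pred, gamma)
    print(f"pred={p:.2f}  pos_loss={pos_loss.item():.6f}")
# pred=0.60 -> ~0.0817, pred=0.90 -> ~0.00105, pred=0.99 -> ~0.000001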