def test_result_map(tmpdir):
    """``rename_keys`` must move logged entries from their old to their new keys.

    Two tensors are logged under ``x1``/``x2``, renamed to ``y1``/``y2``, and
    the result is then checked to contain only the new names.
    """
    renames = {'x1': 'y1', 'x2': 'y2'}

    result = TrainResult()
    result.log_dict({'x1': torch.tensor(1), 'x2': torch.tensor(2)})
    result.rename_keys(renames)

    # Old names must be gone...
    for old_key in renames:
        assert old_key not in result
    # ...and new names must be present.
    for new_key in renames.values():
        assert new_key in result
def training_step(self, batch, batch_idx):
    """Compute the training loss for one batch and log it.

    Args:
        batch: A 3-element sequence ``(input_ids, attention_mask, labels)``.
        batch_idx: Index of the batch within the epoch (unused here).

    Returns:
        A ``TrainResult`` that minimises the loss and logs it as
        ``train_loss`` to the progress bar and to the configured logger.
    """
    input_ids, attention_mask, labels = batch
    outputs = self.forward(input_ids, attention_mask)
    loss = self.loss_function(predictions=outputs,
                              targets=labels,
                              weight=self.hparams.loss_weight)
    result = TrainResult(loss)
    metrics = {"train_loss": loss}
    # ``logger`` is a boolean flag of ``log_dict`` (whether to send the values
    # to the logger). Passing ``self.logger.experiment`` only worked through
    # truthiness and raised AttributeError when the Trainer had no logger
    # (``self.logger is None``).
    result.log_dict(metrics, prog_bar=True, logger=True)
    return result
def training_step(self, batch, batch_idx):
    """Compute the supervised, class-weighted cross-entropy loss for one batch.

    If ``hparams.data.augment`` is set, only the first augmentation
    (``batch[0]``) is used. Labels are taken from element 3 of the batch.
    With ``hparams.exp.tsa`` enabled, a sample contributes to the loss only if
    its max softmax probability is ``<= self.tsa.threshold`` or it is
    misclassified; the masked per-sample losses are averaged over the batch.

    Returns:
        A ``TrainResult`` minimising the loss, with epoch-level, dist-synced
        logs of the loss, of ``calculate_metrics`` output and (under TSA) of
        the fraction of samples kept.
    """
    if self.hparams.data.augment:
        batch = batch[0]  # only the first augmentation is used

    logits = self(batch)
    labels = batch[3]
    class_weights = self.cross_entropy_weights.type_as(logits)

    if self.hparams.exp.tsa:
        confidence, predicted = torch.softmax(logits.detach(), dim=-1).max(dim=-1)
        keep = (confidence.le(self.tsa.threshold) | (predicted != labels)).float()
        per_sample = F.cross_entropy(logits, labels, reduction='none',
                                     weight=class_weights)
        loss = (per_sample * keep).mean()
    else:
        loss = F.cross_entropy(logits, labels, weight=class_weights)

    epoch_opts = {'on_epoch': True, 'on_step': False, 'sync_dist': True}
    result = TrainResult(minimize=loss)
    result.log('train_loss', loss.detach(), **epoch_opts)
    result.log_dict(
        self.calculate_metrics(logits.detach(), labels, prefix='train'),
        **epoch_opts)
    if self.hparams.exp.tsa:
        # ``keep`` is already float, so its mean is the kept fraction.
        result.log('train_l_mask', keep.mean(), **epoch_opts)
    return result
def training_step(self, composite_batch, batch_idx):
    """Run one semi-supervised step on a labelled + unlabelled composite batch.

    Labelled part: class-weighted cross-entropy, optionally masked so that
    (under ``hparams.exp.tsa``) only samples with max softmax probability
    ``<= self.tsa.threshold`` or wrong predictions contribute.

    Unlabelled part: every instance is run through every branch without
    gradients. The most confident branch supplies the pseudo-label. If the
    instance passes the confidence threshold, the *other* branches are trained
    on that pseudo-label. When ``choose_only_wrongly_predicted_branches`` is
    set, only branches whose own prediction disagrees are trained.

    Args:
        composite_batch: ``composite_batch[0][0][0]`` is the labelled batch
            (element 3 holds the labels). ``composite_batch[1]`` is regrouped
            by the zips below into one list of field tensors per branch
            (exact nesting assumed from that code -- TODO confirm against the
            data module).
        batch_idx: Index of the batch (unused here).

    Returns:
        A ``TrainResult`` minimising ``l_loss + lambda_u * u_loss`` with
        epoch-level, dist-synced logs.
    """
    # Batch collation
    l_batch = composite_batch[0][0][0]
    ul_branches = [[torch.cat(sub_item) for sub_item in zip(*item)]
                   for item in zip(*composite_batch[1])]
    n_branches = len(ul_branches)
    ul_size = len(ul_branches[0][0])

    # Supervised loss
    l_logits = self(l_batch)
    l_labels = l_batch[3]
    if self.hparams.exp.tsa:
        l_scores = torch.softmax(l_logits.detach(), dim=-1)
        l_max_probs, l_preds = torch.max(l_scores, dim=-1)
        l_mask = (l_max_probs.le(self.tsa.threshold) | (l_labels != l_preds)).float()
        l_loss = (F.cross_entropy(
            l_logits, l_labels, reduction='none',
            weight=self.cross_entropy_weights.type_as(l_logits)) * l_mask).mean()
    else:
        l_loss = F.cross_entropy(
            l_logits, l_labels, reduction='mean',
            weight=self.cross_entropy_weights.type_as(l_logits))

    # Unsupervised loss
    # Choosing pseudo-labels and branches to back-propagate.
    # Confidences share the loss dtype; pseudo-labels are kept as integer
    # class ids (a float/half buffer could round large class ids).
    u_max_probs_2d = torch.empty((n_branches, ul_size),
                                 dtype=l_loss.dtype, device=l_loss.device)
    u_targets_2d = torch.empty((n_branches, ul_size),
                               dtype=torch.long, device=l_loss.device)
    with torch.no_grad():
        for i, ul_branch in enumerate(ul_branches):
            u_logits = self(ul_branch)
            pseudo_labels = torch.softmax(u_logits.detach(), dim=-1)
            u_max_probs_2d[i], u_targets_2d[i] = torch.max(pseudo_labels, dim=-1)

    if self.hparams.model.tsa_as_threshold:
        u_mask_2d = u_max_probs_2d.ge(
            torch.tensor(self.tsa.threshold)**(1 / 3)).int()
    else:
        u_mask_2d = u_max_probs_2d.ge(self.hparams.model.threshold).int()
    # Threshold per instance: at least one branch must pass the threshold.
    u_mask = u_mask_2d.sum(dim=0) > 0
    # Only the index of the most confident branch is needed.
    _, u_best_branches = torch.max(u_max_probs_2d, dim=0)

    u_loss = torch.tensor(0.0).type_as(l_loss)
    u_batch = []
    u_targets = []
    # Bring the per-instance decisions to the host once, instead of indexing
    # device tensors element by element (one sync each) inside the loop.
    mask_list = u_mask.tolist()
    if any(mask_list):
        best_list = u_best_branches.tolist()
        targets_list = u_targets_2d.tolist()
        # Creating one batch for unlabelled loss
        for i in range(ul_size):
            if not mask_list[i]:
                continue
            best = best_list[i]
            target = targets_list[best][i]
            nonmax_branches = [
                ul_branch for ind, ul_branch in enumerate(ul_branches)
                if ind != best and (
                    targets_list[ind][i] != target
                    or not self.hparams.model.choose_only_wrongly_predicted_branches)
            ]
            u_batch.extend([item[i] for item in branch] for branch in nonmax_branches)
            u_targets.extend([target] * len(nonmax_branches))

        # Cutting huge batches
        if self.hparams.model.max_ul_batch_size_per_gpu is not None:
            u_batch = u_batch[:self.hparams.model.max_ul_batch_size_per_gpu]
            u_targets = u_targets[:self.hparams.model.max_ul_batch_size_per_gpu]

        if u_batch:
            # Regroup per field and stack each field into one tensor.
            u_batch = [torch.stack(item) for item in zip(*u_batch)]
            u_targets = torch.tensor(u_targets, dtype=torch.long,
                                     device=l_loss.device)

            # Unlabelled loss
            u_logits = self(u_batch)
            u_loss = F.cross_entropy(
                u_logits, u_targets, reduction='mean',
                weight=self.cross_entropy_weights.type_as(u_logits))

    # Train loss / labelled accuracy
    loss = l_loss + self.hparams.model.lambda_u * u_loss

    epoch_opts = dict(on_epoch=True, on_step=False, sync_dist=True)
    result = TrainResult(minimize=loss)
    result.log('train_loss', loss.detach(), **epoch_opts)
    result.log('train_loss_l', l_loss.detach(), **epoch_opts)
    result.log('train_loss_ul', u_loss.detach(), **epoch_opts)
    result.log('train_u_mask', u_mask.float().mean(), **epoch_opts)
    result.log('train_u_batch_size', torch.tensor(len(u_targets)).float(), **epoch_opts)
    if self.hparams.exp.tsa:
        result.log('tsa_threshold', torch.tensor(self.tsa.threshold), **epoch_opts)
        result.log('train_l_mask', l_mask.mean(), **epoch_opts)
    result.log_dict(
        self.calculate_metrics(l_logits.detach(), l_labels, prefix='train'),
        **epoch_opts)
    return result