# Shared imports assumed by every snippet in this file: each function is a
# LightningModule method, and TrainResult is the PyTorch Lightning 0.9-era
# result object (since removed from the library).
import torch
import torch.nn.functional as F
from pytorch_lightning import TrainResult


def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
    """
    Full loop flow train step (result obj + dp)
    """
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self(x.to(self.device))
    loss_val = y_hat.sum()
    result = TrainResult(minimize=loss_val)
    result.log('train_step_metric', loss_val + 1)
    self.training_step_called = True
    return result

def training_step_result_obj(self, batch, batch_idx, optimizer_idx=None):
    # forward pass
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self(x)

    # calculate loss
    loss_val = self.loss(y, y_hat)
    log_val = loss_val

    # alternate between tensors and scalars for "log" and "progress_bar"
    if batch_idx % 2 == 0:
        log_val = log_val.item()

    result = TrainResult(loss_val)
    result.log('some_val', log_val * log_val, prog_bar=True, logger=False)
    result.log('train_some_val', log_val * log_val)
    return result

def training_step__using_metrics(self, batch, batch_idx, optimizer_idx=None):
    """Lightning calls this inside the training loop"""
    # forward pass
    x, y = batch
    x = x.view(x.size(0), -1)
    y_hat = self(x)

    # calculate loss
    loss_val = self.loss(y, y_hat)

    # call metric
    val = self.metric(x, y)

    result = TrainResult(minimize=loss_val)
    result.log('metric_val', val)
    return result

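# The three steps above assume a host LightningModule providing forward(),
# self.loss, self.metric, and the bookkeeping flag self.training_step_called.
# A minimal sketch of such a module follows; the class name, layer sizes,
# placeholder metric, and optimizer are illustrative assumptions, not taken
# from the original code.
from pytorch_lightning import LightningModule


class HostModuleSketch(LightningModule):
    def __init__(self, in_features=784, num_classes=10):
        super().__init__()
        self.layer = torch.nn.Linear(in_features, num_classes)
        self.training_step_called = False

    def forward(self, x):
        return self.layer(x)

    def loss(self, labels, logits):
        # Note the (labels, logits) argument order used by the steps above.
        return F.cross_entropy(logits, labels)

    def metric(self, x, y):
        # Placeholder for any callable metric; the step above passes the
        # flattened inputs and the targets straight through.
        with torch.no_grad():
            return (self(x).argmax(dim=-1) == y).float().mean()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)
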
def training_step(self, batch, batch_idx):
    if self.hparams.data.augment:
        # Using only first augmentation
        batch = batch[0]
    logits = self(batch)
    labels = batch[3]
    if self.hparams.exp.tsa:
        # Training signal annealing: only back-propagate examples the model
        # is not yet confidently correct on (confidence <= current threshold,
        # or prediction still wrong).
        scores = torch.softmax(logits.detach(), dim=-1)
        max_probs, preds = torch.max(scores, dim=-1)
        mask = (max_probs.le(self.tsa.threshold) | (labels != preds)).float()
        loss = (F.cross_entropy(
            logits,
            labels,
            reduction='none',
            weight=self.cross_entropy_weights.type_as(logits)) * mask).mean()
    else:
        loss = F.cross_entropy(
            logits,
            labels,
            weight=self.cross_entropy_weights.type_as(logits))

    result = TrainResult(minimize=loss)
    result.log('train_loss',
               loss.detach(),
               on_epoch=True,
               on_step=False,
               sync_dist=True)
    result.log_dict(self.calculate_metrics(logits.detach(),
                                           labels,
                                           prefix='train'),
                    on_epoch=True,
                    on_step=False,
                    sync_dist=True)
    if self.hparams.exp.tsa:
        result.log('train_l_mask',
                   mask.float().mean(),
                   on_epoch=True,
                   on_step=False,
                   sync_dist=True)
    return result

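# The step above only assumes a self.tsa object exposing a `threshold` that
# grows over training. Below is a hypothetical schedule in the spirit of
# training signal annealing from the UDA paper (Xie et al., 2019); the class
# name, constructor arguments, and the scale factor 5 are assumptions, not
# part of the original code.
import math


class TSAScheduleSketch:
    def __init__(self, total_steps, num_classes, mode='linear'):
        self.total_steps = total_steps
        self.num_classes = num_classes
        self.mode = mode
        self.step = 0  # caller is expected to advance this once per batch

    @property
    def threshold(self):
        t = self.step / max(1, self.total_steps)
        if self.mode == 'linear':
            alpha = t
        elif self.mode == 'log':
            alpha = 1 - math.exp(-t * 5)
        else:  # 'exp'
            alpha = math.exp((t - 1) * 5)
        # Anneal from 1/K (chance level) up to 1.0 over training.
        return alpha * (1 - 1 / self.num_classes) + 1 / self.num_classes
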
def training_step_result_log_step_only(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc)

    # step only metrics
    result.log(f'step_log_and_pbar_acc1_b{batch_idx}',
               torch.tensor(11).type_as(acc),
               prog_bar=True)
    result.log(f'step_log_acc2_b{batch_idx}', torch.tensor(12).type_as(acc))
    result.log(f'step_pbar_acc3_b{batch_idx}',
               torch.tensor(13).type_as(acc),
               logger=False,
               prog_bar=True)

    self.training_step_called = True
    return result

def training_step_result_log_epoch_only(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc)

    # epoch only metrics
    result.log(f'epoch_log_and_pbar_acc1_e{self.current_epoch}',
               torch.tensor(14).type_as(acc),
               on_epoch=True,
               prog_bar=True,
               on_step=False)
    result.log(f'epoch_log_acc2_e{self.current_epoch}',
               torch.tensor(15).type_as(acc),
               on_epoch=True,
               on_step=False)
    result.log(f'epoch_pbar_acc3_e{self.current_epoch}',
               torch.tensor(16).type_as(acc),
               on_epoch=True,
               logger=False,
               prog_bar=True,
               on_step=False)

    self.training_step_called = True
    return result

def training_step_result_log_epoch_and_step(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc)

    val_1 = (5 + batch_idx) * (self.current_epoch + 1)
    val_2 = (6 + batch_idx) * (self.current_epoch + 1)
    val_3 = (7 + batch_idx) * (self.current_epoch + 1)

    result.log('step_epoch_log_and_pbar_acc1',
               torch.tensor(val_1).type_as(acc),
               on_epoch=True,
               prog_bar=True)
    result.log('step_epoch_log_acc2',
               torch.tensor(val_2).type_as(acc),
               on_epoch=True)
    result.log('step_epoch_pbar_acc3',
               torch.tensor(val_3).type_as(acc),
               on_epoch=True,
               logger=False,
               prog_bar=True)

    self.training_step_called = True
    return result

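# The three variants above differ only in the on_step/on_epoch flags passed
# to result.log. The recap below condenses them into one step; it assumes
# the PL 0.9-era defaults for TrainResult.log (on_step=True, on_epoch=False),
# and the metric names are illustrative.
def training_step_flag_recap(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc)
    result.log('step_only_metric', acc)  # logged every batch
    result.log('epoch_only_metric', acc, on_step=False, on_epoch=True)  # aggregated over the epoch
    result.log('step_and_epoch_metric', acc, on_epoch=True)  # per-step value plus an epoch aggregate
    return result
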
def training_step(self, composite_batch, batch_idx):
    # Batch collation: peel the labelled sub-batch out of the first slot and
    # concatenate the unlabelled sub-batches per augmentation branch.
    l_batch = composite_batch[0][0][0]
    ul_branches = [[torch.cat(sub_item) for sub_item in zip(*item)]
                   for item in zip(*composite_batch[1])]

    # Supervised loss
    l_logits = self(l_batch)
    l_labels = l_batch[3]
    if self.hparams.exp.tsa:
        l_scores = torch.softmax(l_logits.detach(), dim=-1)
        l_max_probs, l_preds = torch.max(l_scores, dim=-1)
        l_mask = (l_max_probs.le(self.tsa.threshold) |
                  (l_labels != l_preds)).float()
        l_loss = (F.cross_entropy(
            l_logits,
            l_labels,
            reduction='none',
            weight=self.cross_entropy_weights.type_as(l_logits)) *
                  l_mask).mean()
    else:
        l_loss = F.cross_entropy(
            l_logits,
            l_labels,
            reduction='mean',
            weight=self.cross_entropy_weights.type_as(l_logits))

    # Unsupervised loss
    # Choosing pseudo-labels and branches to back-propagate
    u_max_probs_2d = torch.empty(
        (len(ul_branches), len(ul_branches[0][0]))).type_as(l_loss)
    u_targets_2d = torch.empty(
        (len(ul_branches), len(ul_branches[0][0]))).type_as(l_loss)
    with torch.no_grad():
        for i, ul_branch in enumerate(ul_branches):
            u_logits = self(ul_branch)
            pseudo_labels = torch.softmax(u_logits.detach(), dim=-1)
            u_max_probs_2d[i], u_targets_2d[i] = torch.max(pseudo_labels,
                                                           dim=-1)

    if self.hparams.model.tsa_as_threshold:
        u_mask_2d = u_max_probs_2d.ge(
            torch.tensor(self.tsa.threshold)**(1 / 3)).int()
    else:
        u_mask_2d = u_max_probs_2d.ge(self.hparams.model.threshold).int()
    # Threshold u_mask per instance: at least one branch should pass the
    # threshold.
    u_mask = (u_mask_2d.sum(dim=0) > 0)
    u_max_probs, u_best_branches = torch.max(u_max_probs_2d, dim=0)

    u_loss = torch.tensor(0.0).type_as(l_loss)
    u_batch = []
    u_targets = []
    if u_mask.int().sum() > 0:
        # Creating one batch for unlabelled loss
        for i in range(len(ul_branches[0][0])):
            if u_mask[i]:
                nonmax_branches = [
                    ul_branch
                    for (ind, ul_branch) in enumerate(ul_branches)
                    if ind != u_best_branches[i] and
                    (u_targets_2d[ind, i] != u_targets_2d[u_best_branches[i],
                                                          i] or
                     not self.hparams.model.choose_only_wrongly_predicted_branches)
                ]
                if len(nonmax_branches) > 0:
                    u_batch.extend([[item[i] for item in branch]
                                    for branch in nonmax_branches])
                    u_targets.extend(
                        u_targets_2d[u_best_branches[i]][i].repeat(
                            len(nonmax_branches)))

        # Cutting huge batches
        if self.hparams.model.max_ul_batch_size_per_gpu is not None:
            u_batch = u_batch[:self.hparams.model.max_ul_batch_size_per_gpu]
            u_targets = u_targets[:self.hparams.model.max_ul_batch_size_per_gpu]

    if len(u_batch) > 0:
        u_batch = [torch.stack(item) for item in zip(*u_batch)]
        u_targets = torch.stack(u_targets).long()

        # Unlabelled loss
        u_logits = self(u_batch)
        u_loss = F.cross_entropy(
            u_logits,
            u_targets,
            reduction='mean',
            weight=self.cross_entropy_weights.type_as(u_logits))

    # Train loss / labelled accuracy
    loss = l_loss + self.hparams.model.lambda_u * u_loss

    result = TrainResult(minimize=loss)
    result.log('train_loss', loss.detach(),
               on_epoch=True, on_step=False, sync_dist=True)
    result.log('train_loss_l', l_loss.detach(),
               on_epoch=True, on_step=False, sync_dist=True)
    result.log('train_loss_ul', u_loss.detach(),
               on_epoch=True, on_step=False, sync_dist=True)
    result.log('train_u_mask', u_mask.float().mean(),
               on_epoch=True, on_step=False, sync_dist=True)
    result.log('train_u_batch_size', torch.tensor(len(u_targets)).float(),
               on_epoch=True, on_step=False, sync_dist=True)
    if self.hparams.exp.tsa:
        result.log('tsa_threshold', torch.tensor(self.tsa.threshold),
                   on_epoch=True, on_step=False, sync_dist=True)
        result.log('train_l_mask', l_mask.float().mean(),
                   on_epoch=True, on_step=False, sync_dist=True)
    result.log_dict(self.calculate_metrics(l_logits.detach(),
                                           l_labels,
                                           prefix='train'),
                    on_epoch=True, on_step=False, sync_dist=True)
    return result

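# The step above assumes a tightly nested composite batch:
# composite_batch[0][0][0] is the labelled sub-batch (a list of per-field
# tensors with labels at index 3), and composite_batch[1] is a list of
# unlabelled sub-batches, each a list of augmentation branches, each a list
# of per-field tensors. A dummy construction under assumed shapes and field
# meanings (all hypothetical) is sketched below.
def make_field_batch(n, num_classes=10):
    return [torch.randn(n, 32),                    # field 0: features (assumed)
            torch.ones(n, 32),                     # field 1: mask (assumed)
            torch.zeros(n, 32),                    # field 2: extra input (assumed)
            torch.randint(0, num_classes, (n,))]   # field 3: labels (index used above)


l_part = [[make_field_batch(8)]]                       # -> composite_batch[0][0][0]
ul_part = [[make_field_batch(8), make_field_batch(8)]  # two branches per sub-batch
           for _ in range(2)]                          # two sub-batches to concatenate
composite_batch = [l_part, ul_part]
# The collation at the top of training_step then yields two unlabelled
# branches of 16 examples each (fields concatenated across sub-batches).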