def training_step(self, batch, batch_idx):
    """Run one training step on ``batch``.

    Delegates to ``_train_val_step`` with the ``'train'`` suffix. That helper:
    - moves the batch to the device of the model's first parameter;
    - runs ``self.model``;
    - computes ``self._compute_loss`` and the metrics returned by
      ``compute_perform_metrics``.

    Args:
        batch: Input batch; must support ``.to(device)``.
        batch_idx: Index of the batch (not used by the computation).

    Returns:
        dict: ``{'loss': loss, 'log': log}``. Every key of ``log`` is a metric
        name suffixed with ``'/train'``, including ``'loss/train'``.
    """
    # Reuse the shared helper instead of duplicating its body here, so the
    # training logging logic lives in exactly one place.
    return self._train_val_step(batch, batch_idx, 'train')
def _train_val_step(self, batch, batch_idx, train_val):
    """Shared forward / loss / metric computation for a train or val step.

    Args:
        batch: Input batch; moved to the device of the model's first
            parameter via ``batch.to(device)``.
        batch_idx: Index of the batch (not used by the computation).
        train_val: Suffix appended to every log key, e.g. ``'train'`` or
            ``'val'``.

    Returns:
        If ``train_val == 'train'``: ``{'loss': loss, 'log': log}``.
        Otherwise: ``log`` alone.
        ``log`` maps ``'<name>/<train_val>'`` to each value returned by
        ``compute_perform_metrics(outputs, batch)``, plus the loss under
        ``'loss/<train_val>'``.
    """
    device = next(self.model.parameters()).device
    # Rebind instead of relying on ``.to`` mutating in place. ``torch.Tensor.to``
    # (and other containers) return a new object. For objects whose ``.to``
    # returns ``self``, this is equivalent to the previous behaviour.
    batch = batch.to(device)
    outputs = self.model(batch)
    loss = self._compute_loss(outputs, batch)
    # 'loss' comes last, so it overrides any same-named metric (as before).
    logs = {**compute_perform_metrics(outputs, batch), 'loss': loss}
    log = {key + f'/{train_val}': val for key, val in logs.items()}
    if train_val == 'train':
        return {'loss': loss, 'log': log}
    return log
def validation_step(self, batch, batch_idx):
    """Run one validation step and collect edge-mask diagnostics.

    Computes the same ``'<metric>/val'`` log entries as the train/val helper.
    When the model output contains ``'mask'``, it also adds two tensors.

    ``'dynamic'`` has length ``len(outputs['mask']) + 8``:

    - ``[i]`` for each mask ``i``: sum of ``edge_labels`` over edges where
      mask ``i`` is False.
    - ``[-8]``: sum of all ``edge_labels``.
    - ``[-7]``: ``len(batch.edge_labels)``.

    Slots ``[-6]`` to ``[-1]`` are computed from the last entry of
    ``outputs['classified_edges']``, restricted to edges kept by the last
    mask. Within those, "act" edges have ``edge_labels == True`` and
    "inact" edges are the rest.

    - ``[-6]`` / ``[-4]``: min / mean score of act edges (0.5 if none).
    - ``[-5]`` / ``[-3]``: max / mean score of inact edges (0.5 if none).
    - ``[-2]``: number of act edges with score < 0.5, divided by the number
      of kept edges (0 if there are no act edges).
    - ``[-1]``: number of inact edges with score > 0.5, divided by the number
      of kept edges (0 if there are no inact edges).

    ``'mask'`` has length ``len(outputs['mask'])``. Entry ``i`` is
    ``torch.sum(~mask_i, dim=0)``, i.e. the number of False entries for a
    1-D mask. Masks are presumably 1-D boolean tensors; TODO confirm.

    NOTE(review): the scores are presumably probabilities in [0, 1], given
    the 0.5 threshold and the 0.5 defaults. Confirm against the model.

    Args:
        batch: Input batch exposing ``.to(device)`` and ``edge_labels``.
        batch_idx: Index of the batch (not used by the computation).

    Returns:
        dict: the ``'/val'`` log entries, plus ``'dynamic'`` and ``'mask'``
        when ``'mask' in outputs``.
    """
    device = next(self.model.parameters()).device
    # Rebind: ``.to`` may return a new object rather than mutate in place.
    batch = batch.to(device)
    outputs = self.model(batch)
    loss = self._compute_loss(outputs, batch)
    logs = {**compute_perform_metrics(outputs, batch), 'loss': loss}
    val_outputs = {key + '/val': val for key, val in logs.items()}

    if 'mask' in outputs:
        masks = outputs['mask']
        n_masks = len(masks)
        labels = batch.edge_labels.view(-1)  # flattened once, reused below

        # Allocate directly on the target device (no CPU allocation + copy).
        accumulated_fn = torch.zeros(n_masks + 8, device=device)
        mask_fn = torch.zeros(n_masks, device=device)
        for i, mask in enumerate(masks):
            accumulated_fn[i] = torch.sum(labels[~mask])
            mask_fn[i] = torch.sum(~mask, dim=0)

        accumulated_fn[-8] = torch.sum(labels)
        accumulated_fn[-7] = len(batch.edge_labels)

        final_mask = masks[-1]
        final_pros = outputs['classified_edges'][-1][final_mask]
        # Element-wise tensor comparison; ``is True`` would be wrong here.
        final_act_mask = labels[final_mask] == True  # noqa: E712
        act_pro = final_pros[final_act_mask]
        inact_pro = final_pros[~final_act_mask]
        n_final = len(final_pros)

        if len(act_pro) > 0:
            accumulated_fn[-6] = torch.min(act_pro)
            accumulated_fn[-4] = torch.mean(act_pro)
            accumulated_fn[-2] = torch.sum(act_pro < 0.5).float() / n_final
        else:
            accumulated_fn[-6] = 0.5
            accumulated_fn[-4] = 0.5
            accumulated_fn[-2] = 0

        if len(inact_pro) > 0:
            accumulated_fn[-5] = torch.max(inact_pro)
            accumulated_fn[-3] = torch.mean(inact_pro)
            accumulated_fn[-1] = torch.sum(inact_pro > 0.5).float() / n_final
        else:
            accumulated_fn[-5] = 0.5
            accumulated_fn[-3] = 0.5
            accumulated_fn[-1] = 0

        val_outputs['dynamic'] = accumulated_fn
        val_outputs["mask"] = mask_fn
    return val_outputs