Example #1
    def training_step(self, batch, batch_idx):
        # Run everything on the same device as the model parameters.
        device = next(self.model.parameters()).device
        # Move the batch there (relies on the batch's .to() updating it in place).
        batch.to(device)

        # Forward pass, loss, and performance metrics for this batch.
        outputs = self.model(batch)
        loss = self._compute_loss(outputs, batch)
        logs = {**compute_perform_metrics(outputs, batch), 'loss': loss}

        # Namespace every metric under '/train' for the logger.
        log = {f'{key}/train': val for key, val in logs.items()}
        return {'loss': loss, 'log': log}
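
The 'log' key in the returned dict follows the pre-1.0 PyTorch Lightning convention. A minimal sketch of the same step under Lightning >= 1.0, where the 'log' key is no longer honored and metrics are reported through self.log_dict instead, assuming the same self.model, self._compute_loss and compute_perform_metrics as above:

    def training_step(self, batch, batch_idx):
        device = next(self.model.parameters()).device
        batch.to(device)

        outputs = self.model(batch)
        loss = self._compute_loss(outputs, batch)
        # Report every metric under a '/train' suffix via the modern logging API.
        logs = {**compute_perform_metrics(outputs, batch), 'loss': loss}
        self.log_dict({f'{key}/train': val for key, val in logs.items()})
        return loss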
Example #2
    def _train_val_step(self, batch, batch_idx, train_val):
        # Shared body for training_step and validation_step; `train_val` is
        # either 'train' or 'val' and only affects the metric names and the
        # return format.
        device = next(self.model.parameters()).device
        batch.to(device)

        outputs = self.model(batch)
        loss = self._compute_loss(outputs, batch)
        logs = {**compute_perform_metrics(outputs, batch), 'loss': loss}
        log = {f'{key}/{train_val}': val for key, val in logs.items()}

        if train_val == 'train':
            # Lightning expects the loss (plus an optional 'log' dict) from training_step.
            return {'loss': loss, 'log': log}
        else:
            # For validation, return the metrics dict so it can be aggregated later.
            return log
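
A sketch of how this shared helper is usually wired into the standard Lightning hooks; the two thin wrappers below are an assumption about the surrounding class, not part of the original snippet:

    def training_step(self, batch, batch_idx):
        return self._train_val_step(batch, batch_idx, 'train')

    def validation_step(self, batch, batch_idx):
        return self._train_val_step(batch, batch_idx, 'val')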
Example #3
    def validation_step(self, batch, batch_idx):
        device = next(self.model.parameters()).device
        batch.to(device)

        outputs = self.model(batch)
        loss = self._compute_loss(outputs, batch)
        logs = {**compute_perform_metrics(outputs, batch), 'loss': loss}
        log = {f'{key}/val': val for key, val in logs.items()}
        val_outputs = log

        if 'mask' in outputs:
            # Collect pruning statistics: the first len(outputs['mask']) slots hold,
            # per pruning round, the number of positive (label == 1) edges that the
            # mask discarded; the last 8 slots hold summary statistics of the final
            # edge classification.
            accumulated_fn = torch.zeros(len(outputs['mask']) + 8).to(device)
            mask_fn = torch.zeros(len(outputs['mask'])).to(device)
            for i in range(len(outputs['mask'])):
                mask = outputs['mask'][i]
                # Sum of labels over pruned edges = positive edges lost at round i.
                accumulated_fn[i] = torch.sum(batch.edge_labels.view(-1)[~mask])
                # Total number of edges pruned at round i.
                mask_fn[i] = torch.sum(~mask)
            accumulated_fn[-8] = torch.sum(batch.edge_labels.view(-1))  # total positive edges
            accumulated_fn[-7] = len(batch.edge_labels)                 # total edges

            # Probabilities predicted for the edges that survive the final mask,
            # split by their ground-truth label.
            final_mask = outputs['mask'][-1]
            final_pros = outputs['classified_edges'][-1][final_mask]
            final_act_mask = (batch.edge_labels.view(-1)[final_mask] == 1)
            act_pro = final_pros[final_act_mask]      # surviving active (positive) edges
            inact_pro = final_pros[~final_act_mask]   # surviving inactive (negative) edges

            if len(act_pro) > 0:
                accumulated_fn[-6] = torch.min(act_pro)   # lowest score given to an active edge
                accumulated_fn[-4] = torch.mean(act_pro)  # mean score over active edges
                # Fraction of surviving edges that are active but scored below 0.5.
                accumulated_fn[-2] = torch.sum(act_pro < 0.5).float() / len(final_pros)
            else:
                # No active edges survived: fall back to neutral values.
                accumulated_fn[-6] = 0.5
                accumulated_fn[-4] = 0.5
                accumulated_fn[-2] = 0

            if len(inact_pro) > 0:
                accumulated_fn[-5] = torch.max(inact_pro)   # highest score given to an inactive edge
                accumulated_fn[-3] = torch.mean(inact_pro)  # mean score over inactive edges
                # Fraction of surviving edges that are inactive but scored above 0.5.
                accumulated_fn[-1] = torch.sum(inact_pro > 0.5).float() / len(final_pros)
            else:
                # No inactive edges survived: fall back to neutral values.
                accumulated_fn[-5] = 0.5
                accumulated_fn[-3] = 0.5
                accumulated_fn[-1] = 0

            val_outputs['dynamic'] = accumulated_fn
            val_outputs['mask'] = mask_fn
        return val_outputs
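
Since this validation_step returns the per-batch 'dynamic' and 'mask' vectors alongside the metrics, they have to be reduced somewhere. One possible aggregation in a pre-1.0-style validation_epoch_end hook; the element-wise averaging and the metric names below are assumptions for illustration, not taken from the original code:

    def validation_epoch_end(self, outputs):
        # 'loss/val' is the key produced by validation_step above.
        avg_loss = torch.stack([out['loss/val'] for out in outputs]).mean()
        log = {'val_loss': avg_loss}
        if outputs and 'dynamic' in outputs[0]:
            # Element-wise average of the per-batch statistics vectors.
            dynamic = torch.stack([out['dynamic'] for out in outputs]).mean(dim=0)
            log['val/min_active_prob'] = dynamic[-6]    # slot set from torch.min(act_pro)
            log['val/max_inactive_prob'] = dynamic[-5]  # slot set from torch.max(inact_pro)
        return {'val_loss': avg_loss, 'log': log}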