Example #1
    def test_step_result_obj(self, batch, batch_idx, *args, **kwargs):
        """
        Default, baseline test_step
        :param batch:
        :return:
        """
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        loss_test = self.loss(y, y_hat)

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        test_acc = torch.tensor(test_acc).type_as(x)

        result = EvalResult()
        # alternate possible outputs to test; note that `batch_idx % 1 == 0`
        # is always true, so only this first branch is ever taken and the
        # branches below are effectively dead code
        if batch_idx % 1 == 0:
            result.log_dict({'test_loss': loss_test, 'test_acc': test_acc})
            return result
        if batch_idx % 2 == 0:
            return test_acc

        if batch_idx % 3 == 0:
            result.log_dict({'test_loss': loss_test, 'test_acc': test_acc})
            result.test_dic = {'test_loss_a': loss_test}
            return result
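The snippets on this page are methods of a LightningModule that is not shown. The sketch below is a hypothetical minimal module illustrating the pieces Example #1 relies on (the `self.loss` call with its target-first argument order, and a `forward` that accepts flattened inputs); the class and layer names are made up, and the legacy pytorch_lightning 0.9.x API is assumed.

# Hypothetical minimal module sketch for context; names are illustrative and
# the legacy pytorch_lightning 0.9.x (EvalResult era) API is assumed.
import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self, in_features=28 * 28, num_classes=10):
        super().__init__()
        self.layer = nn.Linear(in_features, num_classes)

    def forward(self, x):
        return self.layer(x)

    def loss(self, y, y_hat):
        # matches the (target, prediction) argument order used in the snippets
        return F.cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss(y, self(x.view(x.size(0), -1)))
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)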
Example #2
    def test_step(self, batch, batch_idx):
        logits = self(batch, model='best')
        labels = batch[3]
        loss = F.cross_entropy(logits, labels, reduction='mean')

        result = EvalResult()
        result.log('test_loss', loss, sync_dist=True)
        result.log_dict(self.calculate_metrics(logits, labels, prefix='test'),
                        sync_dist=True)
        return result
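This example (and Example #3 below) calls a `calculate_metrics` helper whose implementation is not shown. A plausible sketch, assuming it returns a dict of prefixed scalar tensors as `result.log_dict` expects, might look like:

    # Hypothetical sketch of the unshown `calculate_metrics` helper; the real
    # project may compute different or additional metrics.
    def calculate_metrics(self, logits, labels, prefix):
        preds = torch.argmax(logits, dim=1)
        acc = (preds == labels).float().mean()
        return {f'{prefix}_acc': acc}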
Example #3
    def validation_step(self, batch, batch_idx):
        logits = self(batch)
        labels = batch[3]
        loss = F.cross_entropy(logits, labels, reduction='mean')

        result = EvalResult(early_stop_on=loss, checkpoint_on=loss)
        result.log('val_loss', loss, sync_dist=True)
        result.log_dict(self.calculate_metrics(logits, labels, prefix='val'),
                        sync_dist=True)
        return result

    def validation_step_result_obj(self, batch, batch_idx, *args, **kwargs):
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self(x)

        loss_val = self.loss(y, y_hat)

        # acc
        labels_hat = torch.argmax(y_hat, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
        val_acc = torch.tensor(val_acc).type_as(x)

        result = EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)
        result.log_dict({
            'val_loss': loss_val,
            'val_acc': val_acc,
        })
        return result
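For context, here is a hedged wiring sketch of how a module with a validation step like the one above might be trained, assuming the legacy pytorch_lightning 0.9.x Trainer API; `LitClassifier` is the hypothetical module sketched under Example #1, the validation step is a condensed version of the one above, and the data is random and purely illustrative.

# Hedged wiring sketch: random data, the hypothetical LitClassifier from the
# Example #1 sketch, and a condensed validation_step so checkpoint_on /
# early_stop_on have a value to monitor (legacy 0.9.x structured results).
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class LitClassifierWithVal(LitClassifier):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss_val = self.loss(y, self(x.view(x.size(0), -1)))
        result = pl.EvalResult(checkpoint_on=loss_val, early_stop_on=loss_val)
        result.log('val_loss', loss_val)
        return result


# tiny random dataset standing in for real data
ds = TensorDataset(torch.randn(64, 28 * 28), torch.randint(0, 10, (64,)))
train_loader = DataLoader(ds, batch_size=8)
val_loader = DataLoader(ds, batch_size=8)

model = LitClassifierWithVal()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, train_loader, val_loader)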