def validation_step(self, batch, batch_idx):
    """Evaluate one validation batch and log its loss and regression metrics.

    The prediction is a weighted sum over the output of ``self.cal_km``
    (presumably a kernel matrix between the compressed fitting points and
    the batch inputs -- confirm against ``cal_km``), where the weights are
    the min-max rescaled magnitudes of ``self.untreated_coef`` multiplied
    by ``self.label``.

    Args:
        batch: A pair ``(X, fx)`` of inputs and regression targets.
        batch_idx: Index of the batch; not used here.

    Returns:
        A dict with the keys ``'val_var'``, ``'val_mae'``, ``'val_mse'``
        and ``'val_loss'``. The same dict is passed to ``self.log_dict``.
    """
    inputs, targets = batch

    compressed_fit = self.conv(self.X_fit)
    kernel = self.cal_km(self.params, compressed_fit, inputs)

    # Min-max rescale the coefficient magnitudes.
    # NOTE(review): the denominator is zero when every magnitude is equal,
    # which would yield NaN weights -- confirm this cannot happen.
    magnitudes = torch.abs(self.untreated_coef)
    lowest = torch.min(magnitudes)
    highest = torch.max(magnitudes)
    scaled_magnitudes = (magnitudes - lowest) / (highest - lowest)
    weights = scaled_magnitudes * self.label

    predictions = torch.sum(weights * kernel.t(), dim=1)

    smooth_l1 = F.smooth_l1_loss(predictions, targets)
    explained = FM.explained_variance(predictions, targets)
    abs_error = FM.mean_absolute_error(predictions, targets)
    sq_error = FM.mean_squared_error(predictions, targets)

    val_metrics = {
        'val_var': explained,
        'val_mae': abs_error,
        'val_mse': sq_error,
        'val_loss': smooth_l1,
    }
    self.log_dict(val_metrics)
    return val_metrics
# Example #2
0
def test_v1_5_metric_regress():
    """Check that the deprecated regression metrics warn and still compute.

    Each metric class is instantiated and each functional metric is called
    inside ``pytest.deprecated_call``, after resetting its ``_warned``
    attribute (presumably a warn-once guard -- resetting it lets the
    warning fire again within this test). The functional results are also
    compared against fixed reference values.
    """
    removal_notice = 'It will be removed in v1.5.0'

    def _construct_deprecated(metric_cls):
        # The flag for a class lives on its ``__init__``.
        metric_cls.__init__._warned = False
        with pytest.deprecated_call(match=removal_notice):
            metric_cls()

    def _call_deprecated(metric_fn, *inputs):
        # The flag for a functional metric lives on the function itself.
        metric_fn._warned = False
        with pytest.deprecated_call(match=removal_notice):
            return metric_fn(*inputs)

    def _assert_close(actual, expected):
        assert torch.allclose(actual, torch.tensor(expected), atol=1e-4)

    # Class-based regression metrics.
    for metric_cls in (
        ExplainedVariance,
        MeanAbsoluteError,
        MeanSquaredError,
        MeanSquaredLogError,
    ):
        _construct_deprecated(metric_cls)

    # Functional regression metrics.
    reg_target = torch.tensor([3, -0.5, 2, 7])
    reg_preds = torch.tensor([2.5, 0.0, 2, 8])
    _assert_close(
        _call_deprecated(explained_variance, reg_preds, reg_target),
        0.9572,
    )

    x = torch.tensor([0., 1, 2, 3])
    y = torch.tensor([0., 1, 2, 2])
    assert _call_deprecated(mean_absolute_error, x, y) == 0.25
    assert _call_deprecated(mean_relative_error, x, y) == 0.125
    assert _call_deprecated(mean_squared_error, x, y) == 0.25
    _assert_close(_call_deprecated(mean_squared_log_error, x, y), 0.0207)

    # Remaining class-based metrics.
    for metric_cls in (PSNR, R2Score, SSIM):
        _construct_deprecated(metric_cls)

    # Remaining functional metrics.
    grid_preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
    grid_target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
    _assert_close(_call_deprecated(psnr, grid_preds, grid_target), 2.5527)

    _assert_close(_call_deprecated(r2score, reg_preds, reg_target), 0.9486)

    # Random images; the target is a scaled copy of the prediction.
    image_preds = torch.rand([16, 1, 16, 16])
    image_target = image_preds * 0.75
    _assert_close(_call_deprecated(ssim, image_preds, image_target), 0.9219)