Ejemplo n.º 1
0
def test_make_splits_order():
    """Check that an explicit ``order`` determines the contents of each split.

    The order tensor counts down from 100 to 1. With ratios 0.7/0.2/0.1 the
    train, validation and test splits are expected to equal the first 70,
    the next 20 and the last 10 values of that countdown, respectively.
    """
    descending = torch.arange(100, 0, -1, dtype=torch.int)
    train, val, test = make_splits(100, 0.7, 0.2, 0.1, 1234, order=descending)

    # (split, first value, exclusive end value) of the expected countdown.
    expectations = (
        (train, 100, 30),
        (val, 30, 10),
        (test, 10, 0),
    )
    for split, start, stop in expectations:
        expected = torch.arange(start, stop, -1, dtype=torch.int)
        assert (split == expected).all()
Ejemplo n.º 2
0
    def setup(self, stage):
        """Create the dataset if needed, split it, and wrap each split.

        When ``self.dataset`` is ``None`` it is instantiated from the
        hyperparameters: the ``"Custom"`` dataset is built from the
        coordinate, embedding, energy and force file settings; any other
        name is looked up as an attribute of ``datasets`` and constructed
        from ``dataset_root`` and ``dataset_arg``.

        The split indices returned by ``make_splits`` are stored on
        ``idx_train`` / ``idx_val`` / ``idx_test``, their sizes are printed,
        and each index set is wrapped in a ``Subset`` of the dataset.
        Finally ``self._standardize()`` is called when the ``standardize``
        hyperparameter is truthy.

        Args:
            stage: Not used by this method.
        """
        hparams = self.hparams

        # Leave an already assigned dataset untouched.
        if self.dataset is None:
            dataset_name = hparams["dataset"]
            if dataset_name != "Custom":
                # Resolve the dataset class by name on the datasets module.
                dataset_cls = getattr(datasets, dataset_name)
                self.dataset = dataset_cls(
                    hparams["dataset_root"],
                    dataset_arg=hparams["dataset_arg"],
                )
            else:
                self.dataset = datasets.Custom(
                    hparams["coord_files"],
                    hparams["embed_files"],
                    hparams["energy_files"],
                    hparams["force_files"],
                )

        # The last two arguments are a path inside log_dir and the "splits"
        # hyperparameter; how make_splits uses them is not visible here.
        split_indices = make_splits(
            len(self.dataset),
            hparams["train_size"],
            hparams["val_size"],
            hparams["test_size"],
            hparams["seed"],
            join(hparams["log_dir"], "splits.npz"),
            hparams["splits"],
        )
        self.idx_train, self.idx_val, self.idx_test = split_indices
        print(
            f"train {len(self.idx_train)}, val {len(self.idx_val)}, test {len(self.idx_test)}"
        )

        self.train_dataset, self.val_dataset, self.test_dataset = (
            Subset(self.dataset, indices)
            for indices in (self.idx_train, self.idx_val, self.idx_test)
        )

        if hparams["standardize"]:
            self._standardize()
Ejemplo n.º 3
0
def test_make_splits_errors():
    """Check that invalid size combinations raise ``AssertionError``.

    Every (train, val, test) triple below is passed to ``make_splits`` with
    a dataset length of 100 and is expected to be rejected.
    """
    invalid_sizes = (
        (0.5, 0.5, 0.5),  # ratios add up to 1.5
        (50, 50, 50),  # absolute counts add up to 150
        (None, None, 5),  # two of the three sizes are left unspecified
        (60, 60, None),  # train + val alone already exceed 100
    )
    for train_size, val_size, test_size in invalid_sizes:
        with raises(AssertionError):
            make_splits(100, train_size, val_size, test_size, 1234)
Ejemplo n.º 4
0
def test_make_splits_outputs():
    """Check the type, sizes and index coverage of ``make_splits`` output.

    For 100 samples split 0.7/0.2/0.1 the result must consist of three
    tensors of length 70, 20 and 10 whose entries are pairwise distinct and
    span the index range 0..99.
    """
    result = make_splits(100, 0.7, 0.2, 0.1, 1234)
    assert len(result) == 3

    for split in result:
        assert isinstance(split, torch.Tensor)

    for split, expected_len in zip(result, (70, 20, 10)):
        assert len(split) == expected_len

    merged = torch.cat(result)
    # No index may appear twice, neither within a split nor across splits.
    assert sum_lengths(*result) == len(torch.unique(merged))
    assert merged.max() == 99
    assert merged.min() == 0
Ejemplo n.º 5
0
def test_make_splits_sizes():
    """Check the total number of samples returned for various size specs.

    Each case pairs a (train, val, test) size triple with the expected sum
    of the three split lengths for a dataset of 100 samples.
    """
    cases = (
        ((70, 20, 10), 100),
        # A single size left as None still yields all 100 samples.
        ((70, 20, None), 100),
        ((70, None, 10), 100),
        ((None, 20, 10), 100),
        # Absolute counts mixed with a fractional test size.
        ((70, 20, 0.1), 100),
        ((70, 20, 0.05), 95),
    )
    for sizes, expected_total in cases:
        splits = make_splits(100, *sizes, 1234)
        assert sum_lengths(*splits) == expected_total
Ejemplo n.º 6
0
def test_make_splits_ratios(dset_len, ratio1, ratio2, ratio3):
    """Check split lengths when all three sizes are given as ratios.

    Args:
        dset_len: Number of samples passed to ``make_splits``.
        ratio1: Train ratio.
        ratio2: Validation ratio.
        ratio3: Test ratio.
    """
    train, val, test = make_splits(dset_len, ratio1, ratio2, ratio3, 1234)
    assert sum_lengths(train, val, test) <= dset_len

    n_train, n_val, n_test = (
        round(ratio * dset_len) for ratio in (ratio1, ratio2, ratio3)
    )
    assert len(train) == n_train
    assert len(val) == n_val

    # Rounding each ratio independently can add up to more than dset_len;
    # make_splits is expected to compensate by dropping one sample from the
    # test set in that case.
    overshoots = n_train + n_val + n_test > dset_len
    assert len(test) == (n_test - 1 if overshoots else n_test)
Ejemplo n.º 7
0
def test_make_splits_save_load(tmpdir):
    """Check that splits written to a file can be loaded back unchanged.

    Args:
        tmpdir: Temporary directory in which ``splits.npz`` is created.
    """
    path = join(tmpdir, "splits.npz")

    # Passing ``filename`` is expected to create the file at ``path``.
    train, val, test = make_splits(100, 0.7, 0.2, 0.1, 1234, filename=path)
    assert exists(path)

    # With ``splits`` given, every other argument is passed as None.
    trainl, vall, testl = make_splits(None, None, None, None, None, splits=path)

    pairs = ((train, trainl), (val, vall), (test, testl))
    assert all((saved == loaded).all() for saved, loaded in pairs)