Example #1
0
def test_dataset_sample_inds(dataset_dir, base_cfg_dataset, records, mode,
                             sample_count):
    """Check which record indices a ``RecordDataset`` exposes as ``sample_inds``.

    The expected indices depend on the mode and on ``sample_count``:

    * TRAIN mode with a non-None ``sample_count``: the dataset should match
      ``RecordDataset.convert_sample_count_to_inds``.
    * Any other combination: the dataset should cover every record index,
      ``0 .. len(records) - 1``.
    """
    base_cfg_dataset["sample_count"] = sample_count
    dataset = RecordDataset(
        artifact_dir=dataset_dir,
        cfg_dataset=base_cfg_dataset,
        records=records,
        mode=mode,
        batch_size=32,
    )

    subsampled = mode == RecordMode.TRAIN and sample_count is not None
    if not subsampled:
        expected = list(range(len(records)))
    else:
        # NOTE(review): this indexes ``records`` with ``sample_count`` before
        # converting, whereas test_convert_sample_count_to_inds feeds the
        # sample-count spec straight into the converter — confirm intended.
        expected = RecordDataset.convert_sample_count_to_inds(
            records[sample_count])
    assert dataset.sample_inds == expected
Example #2
0
def test_convert_sample_count_to_inds(s, result):
    """Check that ``RecordDataset.convert_sample_count_to_inds(s)`` equals ``result``."""
    converted = RecordDataset.convert_sample_count_to_inds(s)
    assert converted == result