Ejemplo n.º 1
0
def test_go_dataset(filenames: List[str], length: int, transform: bool):
    """Check a Go ``Dataset``'s length and spot-check one random sample.

    Args:
        filenames: Files passed straight through to ``Dataset``.
        length: Expected ``len()`` of the dataset built from ``filenames``.
        transform: Flag passed straight through to ``Dataset`` (presumably
            toggles sample transformation/augmentation -- confirm against
            ``Dataset``).
    """
    view = Dataset(filenames, transform)
    assert len(view) == length
    if length == 0:
        # Nothing to sample: random.randrange(0, 0) would raise ValueError
        # and turn a legitimately empty dataset into a test error.
        return

    # The index is random, so report it in every failure message; otherwise
    # a failing run cannot be reproduced.
    random_idx = random.randrange(0, len(view))
    where = f"sample index {random_idx}"

    planes, moves, outcome = view[random_idx]
    assert planes.size() == (18, 19, 19), where
    # 19 * 19 board points plus one extra index (presumably "pass" --
    # confirm). ``range`` membership is O(1) for ints; no list is needed.
    assert moves.item() in range(19 * 19 + 1), where
    assert moves.dtype == torch.int64, where
    # Presumably a win/loss label -- confirm against ``Dataset``.
    assert outcome.item() in (-1, 1), where
Ejemplo n.º 2
0
def test_go_dataset(filenames: List[str], length: int, transform: bool):
    """Check a Go ``Dataset``'s length and validate every sample in it.

    Args:
        filenames: Files passed straight through to ``Dataset``.
        length: Expected ``len()`` of the dataset built from ``filenames``.
        transform: Flag passed straight through to ``Dataset`` (presumably
            toggles sample transformation/augmentation -- confirm against
            ``Dataset``).
    """
    dataset = Dataset(filenames, transform)
    assert len(dataset) == length

    # Loop-invariant: the legal move indices, built once. 19 * 19 board
    # points plus one extra index (presumably "pass" -- confirm). ``range``
    # supports O(1) integer membership, so no list needs to be materialized.
    valid_moves = range(19 * 19 + 1)

    # Index explicitly: only ``len`` and ``__getitem__`` are relied on here.
    for i in range(len(dataset)):
        # Name the failing sample in every assertion message.
        where = f"sample index {i}"
        planes, moves, outcome = dataset[i]
        assert planes.size() == (18, 19, 19), where
        assert moves.item() in valid_moves, where
        assert moves.dtype == torch.int64, where
        # Presumably a win/loss label -- confirm against ``Dataset``.
        assert outcome.item() in (-1, 1), where
        assert outcome.dtype == torch.float32, where