Esempio n. 1
0
def test_MLTask_set_batch_parameters_batch_size(_, mltask_kwargs):
    """Check ``set_batch_parameters()`` when the task is built with ``n_batches``.

    ``_`` is unused here; presumably it receives a mock injected by a patch
    decorator that is not visible in this excerpt -- TODO confirm.
    ``mltask_kwargs`` supplies the remaining ``MLTask`` constructor arguments.

    NOTE(review): despite the ``_batch_size`` suffix this test drives
    ``n_batches``. A later test in this file uses the same function name; if
    both live in one module the later definition would shadow this one.
    """
    # 100 requested batches: the call reports 1000 as its return value and
    # the task ends up with 10 items per batch.
    task = MLTask(n_batches=100, **mltask_kwargs)
    total = task.set_batch_parameters()
    assert total == 1000
    assert task.batch_size == 10
    assert task.n_batches == 100

    # Requesting exactly 1000 batches, or more than that, must both end up
    # with 1000 batches of a single item each.
    for requested in (1000, 10000):
        task = MLTask(n_batches=requested, **mltask_kwargs)
        task.set_batch_parameters()
        assert task.batch_size == 1
        assert task.n_batches == 1000
Esempio n. 2
0
def test_MLTask_calculate_batch_indices(_, mltask_kwargs):
    """Check the index pairs returned by ``MLTask.calculate_batch_indices``.

    ``_`` is unused here; presumably it receives a mock injected by a patch
    decorator that is not visible in this excerpt -- TODO confirm.
    ``mltask_kwargs`` supplies the remaining ``MLTask`` constructor arguments.
    """
    task = MLTask(n_batches=100, **mltask_kwargs)

    # This call happens before set_batch_parameters() and must be rejected.
    # NOTE(review): presumably the error is due to the batch parameters not
    # being initialised yet -- confirm against MLTask.
    with pytest.raises(ValueError):
        task.calculate_batch_indices(1, 10)

    n_items = task.set_batch_parameters()

    # A batch in the middle of the range spans exactly batch_size items.
    start, stop = task.calculate_batch_indices(3, n_items)
    assert start < stop
    assert stop - start == task.batch_size

    # The last valid batch (index 99 of 100) ends at the total.
    start, stop = task.calculate_batch_indices(99, n_items)
    assert stop == n_items

    # Index 100 is one past the last batch and must be rejected.
    with pytest.raises(ValueError):
        start, stop = task.calculate_batch_indices(100, n_items)
Esempio n. 3
0
def test_MLTask_set_batch_parameters_batch_size(_, mltask_kwargs):
    """Check ``set_batch_parameters()`` when the task is built with ``batch_size``.

    ``_`` is unused here; presumably it receives a mock injected by a patch
    decorator that is not visible in this excerpt -- TODO confirm.
    ``mltask_kwargs`` supplies the remaining ``MLTask`` constructor arguments.

    The expected values assume a total of 1000 items, which is what the first
    case asserts ``set_batch_parameters()`` returns.
    """
    # Every batch-size assertion reads the attribute from the task; a bare
    # ``batch_size`` name is not defined in this scope and raises NameError.

    # Batch size divides the total evenly: 1000 / 100 -> 10 batches.
    mltask = MLTask(batch_size=100, **mltask_kwargs)
    assert mltask.set_batch_parameters() == 1000
    assert mltask.batch_size == 100
    assert mltask.n_batches == 10

    # Batch size does not divide the total: a second, partial batch is needed.
    mltask = MLTask(batch_size=900, **mltask_kwargs)
    mltask.set_batch_parameters()
    assert mltask.batch_size == 900
    assert mltask.n_batches == 2

    # Batch size equal to the total: a single batch.
    mltask = MLTask(batch_size=1000, **mltask_kwargs)
    mltask.set_batch_parameters()
    assert mltask.batch_size == 1000
    assert mltask.n_batches == 1

    # Batch size larger than the total: expected to be reduced to the total.
    mltask = MLTask(batch_size=1001, **mltask_kwargs)
    mltask.set_batch_parameters()
    assert mltask.batch_size == 1000
    assert mltask.n_batches == 1