def test_MLTask_yield_batch_use_intermediate(mocked_ds, mocked_boto3, mocked_s3_path, mltask_kwargs):
    """With use_intermediate_inputs=True, yield_batch enumerates S3 keys:
    only '.json' keys under the task's subbucket are yielded, indexes are
    dummies, and output keys are unique."""
    # Candidate keys: two good, one outside the subbucket, one non-json
    fake_keys = [
        Key('a/first_key.json'),   # Good
        Key('b/second_key.json'),  # Not in subbucket
        Key('a/third_key.json'),   # Good
        Key('a/third_key.other'),  # Not json
    ]
    # Patch the iterable returned when listing every object in the bucket
    bucket = mocked_boto3.resource.return_value.Bucket.return_value
    bucket.objects.all.return_value = fake_keys

    mltask = MLTask(input_task=SomeTask,
                    use_intermediate_inputs=True,
                    **mltask_kwargs)
    collected = []
    for lo, hi, in_key, out_key in mltask.yield_batch():
        # Indexes are always dummies in use_intermediate_inputs
        assert lo == 0
        assert hi == -1
        # Both keys must be json, and the input must not be the raw s3 path
        assert in_key.endswith('.json')
        assert out_key.endswith('.json')
        assert in_key != mocked_s3_path
        collected.append(out_key)

    # Exactly the two good keys, with no duplicated outputs
    assert len(collected) == len(set(collected))
    assert len(collected) == 2
def test_MLTask_prepare(mocked_s3, mocked_yield_batch, mltask_kwargs):
    """prepare() returns one parameter dict per batch, flags already-done
    outputs via S3 existence checks, and forwards hyperparameters."""
    # Existence pattern repeats [True, False, True]: 2/3 of outputs "exist"
    mocked_s3.S3FS.return_value.exists.side_effect = (
        [True, False, True] * int(126 / 3))

    mltask = MLTask(hyperparameters={'a': 20, 'b': 30}, **mltask_kwargs)
    job_params = mltask.prepare()

    # One entry per mocked batch, with 2/3 of them marked done
    assert len(job_params) == 126
    assert sum(p['done'] for p in job_params) == int(126 * 2 / 3)

    # Every batch carries the hyperparameters
    for params in job_params:
        assert 'a' in params.keys()
        assert 'b' in params.keys()
def test_MLTask_yield_batch_not_use_intermediate(mocked_len, mocked_s3_path, mltask_kwargs):
    """With use_intermediate_inputs=False, yield_batch produces strictly
    increasing, non-overlapping index windows over a single input path."""
    mltask = MLTask(input_task=SomeTask,
                    n_batches=100,
                    use_intermediate_inputs=False,
                    **mltask_kwargs)

    collected = []
    prev_lo, prev_hi = -1, -1
    for lo, hi, in_key, out_key in mltask.yield_batch():
        # Windows are well-formed and monotonically advancing
        assert lo < hi
        assert lo > prev_lo
        assert hi > prev_hi
        # Every batch reads from the same (mocked) s3 input path
        assert in_key == mocked_s3_path()
        collected.append(out_key)
        prev_lo, prev_hi = lo, hi

    # One output key per requested batch
    assert len(collected) == mltask.n_batches
def test_MLTask_calculate_batch_indices(_, mltask_kwargs):
    """calculate_batch_indices maps a batch number to an index window and
    raises ValueError before set_batch_parameters or past the last batch."""
    mltask = MLTask(n_batches=100, **mltask_kwargs)

    # Calling before set_batch_parameters() is an error
    with pytest.raises(ValueError):
        mltask.calculate_batch_indices(1, 10)

    total = mltask.set_batch_parameters()

    # An interior batch spans exactly one batch_size worth of rows
    lo, hi = mltask.calculate_batch_indices(3, total)
    assert lo < hi
    assert hi - lo == mltask.batch_size

    # The final batch ends exactly at the total row count
    lo, hi = mltask.calculate_batch_indices(99, total)
    assert hi == total

    # Batch numbers past the end are rejected
    with pytest.raises(ValueError):
        lo, hi = mltask.calculate_batch_indices(100, total)
def mltask(mltask_kwargs):
    """Fixture: a default MLTask built from the shared kwargs fixture."""
    task = MLTask(**mltask_kwargs)
    return task
def test_MLTask_set_batch_parameters_n_batches(_, mltask_kwargs):
    """set_batch_parameters derives batch_size from a given n_batches.

    Renamed from test_MLTask_set_batch_parameters_batch_size: this test
    exercises the n_batches path, and the old name collided with the
    batch_size test below it, so pytest collection shadowed this one and
    it never ran.
    """
    mltask = MLTask(n_batches=100, **mltask_kwargs)
    # Mocked dataset reports 1000 rows in total
    assert mltask.set_batch_parameters() == 1000
    assert mltask.batch_size == 10
    assert mltask.n_batches == 100

    mltask = MLTask(n_batches=1000, **mltask_kwargs)
    mltask.set_batch_parameters()
    assert mltask.batch_size == 1
    assert mltask.n_batches == 1000

    # Requesting more batches than rows clamps n_batches to the row count
    mltask = MLTask(n_batches=10000, **mltask_kwargs)
    mltask.set_batch_parameters()
    assert mltask.batch_size == 1
    assert mltask.n_batches == 1000
def test_MLTask_set_batch_parameters_batch_size(_, mltask_kwargs):
    """set_batch_parameters derives n_batches from a given batch_size,
    capping batch_size at the total row count (1000 in the mocked data).

    Fix: the assertions referenced a bare ``batch_size`` name that was
    never defined (NameError at runtime); they must check the attribute
    ``mltask.batch_size``.
    """
    mltask = MLTask(batch_size=100, **mltask_kwargs)
    # Mocked dataset reports 1000 rows in total
    assert mltask.set_batch_parameters() == 1000
    assert mltask.batch_size == 100
    assert mltask.n_batches == 10

    mltask = MLTask(batch_size=900, **mltask_kwargs)
    mltask.set_batch_parameters()
    assert mltask.batch_size == 900
    assert mltask.n_batches == 2

    mltask = MLTask(batch_size=1000, **mltask_kwargs)
    mltask.set_batch_parameters()
    assert mltask.batch_size == 1000
    assert mltask.n_batches == 1

    # A batch_size larger than the dataset is clamped to the row count
    mltask = MLTask(batch_size=1001, **mltask_kwargs)
    mltask.set_batch_parameters()
    assert mltask.batch_size == 1000
    assert mltask.n_batches == 1
def test_MLTask_no_combine_output(mocked_target, mltask_kwargs):
    """Without output combination, the task's target path ends in '.length'."""
    task = MLTask(combine_outputs=False, **mltask_kwargs)
    task.output()
    call_args, _kwargs = mocked_target._mock_call_args
    target_path = call_args[0]
    assert target_path.endswith('.length')
def test_MLTask_combine_output(mocked_target, mltask_kwargs):
    """With output combination, the task's target path ends in '.json'."""
    task = MLTask(combine_outputs=True, **mltask_kwargs)
    task.output()
    call_args, _kwargs = mocked_target._mock_call_args
    target_path = call_args[0]
    assert target_path.endswith('.json')
def test_MLTask_requires_child(mocked_task, mltask_kwargs):
    """When a child job is configured, requires() builds exactly one
    (mocked) upstream task and returns it."""
    parent = MLTask(child={'job_name': 'something'}, **mltask_kwargs)
    upstream = parent.requires()
    # The mocked task class was instantiated once, and its instance returned
    assert mocked_task.call_count == 1
    assert upstream == mocked_task._mock_return_value
def test_MLTask_requires_no_child_yes_input(mltask_kwargs):
    """Without a child but with an input_task, requires() yields an
    instance of that input task class."""
    parent = MLTask(input_task=SomeTask, **mltask_kwargs)
    upstream = parent.requires()
    assert type(upstream) == SomeTask