def test_mapped_examples_iterable_input_columns(generate_examples_fn, n, func, batch_size, input_columns):
    """Compare a map restricted to ``input_columns`` with a hand-built reference.

    The reference calls ``func`` with the values of the selected columns as
    positional arguments and overlays whatever it returns on the source
    examples. ``batch_size=None`` selects the per-example mode; any other
    value selects the batched mode.
    """
    is_batched = batch_size is not None
    source = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = MappedExamplesIterable(
        source, func, batched=is_batched, batch_size=batch_size, input_columns=input_columns
    )
    all_examples = [example for _, example in generate_examples_fn(n=n)]
    # A single column name is wrapped so both spellings are handled alike.
    columns_to_input = input_columns if isinstance(input_columns, list) else [input_columns]

    if not is_batched:
        expected = []
        for example in all_examples:
            expected.append({**example, **func(*[example[col] for col in columns_to_input])})
    else:
        # The batched function receives column-oriented data, so every slice of
        # examples is first folded into one dictionary of lists.
        transformed = []
        for start in range(0, len(all_examples), batch_size):
            batch = _examples_to_batch(all_examples[start : start + batch_size])
            outputs = func(*[batch[col] for col in columns_to_input])
            transformed.extend(_batch_to_examples(outputs))
        merged = _examples_to_batch(all_examples)
        merged.update(_examples_to_batch(transformed))
        expected = list(_batch_to_examples(merged))

    first_example = next(iter(ex_iterable))[1]
    assert first_example == expected[0]
    assert [example for _, example in ex_iterable] == expected
def test_mapped_examples_iterable_with_indices(generate_examples_fn, n, func, batch_size):
    """Compare a map using ``with_indices=True`` with a hand-built reference.

    In per-example mode the reference calls ``func(example, index)``; in
    batched mode it calls ``func(batch, indices)`` where ``indices`` are the
    positions of the batch's examples in the full example list. The output is
    overlaid on the source examples.
    """
    is_batched = batch_size is not None
    source = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = MappedExamplesIterable(
        source, func, batched=is_batched, batch_size=batch_size, with_indices=True
    )
    all_examples = [example for _, example in generate_examples_fn(n=n)]

    if not is_batched:
        expected = []
        for position, example in enumerate(all_examples):
            expected.append({**example, **func(example, position)})
    else:
        # The batched function receives column-oriented data, so every slice of
        # examples is first folded into one dictionary of lists.
        transformed = []
        for start in range(0, len(all_examples), batch_size):
            chunk = all_examples[start : start + batch_size]
            batch = _examples_to_batch(chunk)
            positions = list(range(start, start + len(chunk)))
            transformed.extend(_batch_to_examples(func(batch, positions)))
        merged = _examples_to_batch(all_examples)
        merged.update(_examples_to_batch(transformed))
        expected = list(_batch_to_examples(merged))

    first_example = next(iter(ex_iterable))[1]
    assert first_example == expected[0]
    assert [example for _, example in ex_iterable] == expected
def test_mapped_examples_iterable(generate_examples_fn, n, func, batch_size):
    """Compare the ``(key, example)`` pairs of a plain map with a reference.

    In per-example mode each expected pair keeps the source key and holds
    ``func(example)``. In batched mode every example produced from a batch is
    paired with the keys of that batch joined by ``"_"``.
    """
    is_batched = batch_size is not None
    source = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = MappedExamplesIterable(source, func, batched=is_batched, batch_size=batch_size)
    all_examples = list(generate_examples_fn(n=n))

    if not is_batched:
        expected = []
        for key, example in all_examples:
            expected.append((key, func(example)))
    else:
        expected = []
        for start in range(0, len(all_examples), batch_size):
            chunk = all_examples[start : start + batch_size]
            # The batched function receives column-oriented data, so the
            # chunk is folded into one dictionary of lists before the call.
            batch = _examples_to_batch([example for _, example in chunk])
            outputs = list(_batch_to_examples(func(batch)))
            # The new key is the concatenation of the keys of the batch.
            batch_key = "_".join(key for key, _ in chunk)
            expected.extend((batch_key, output) for output in outputs)

    first_item = next(iter(ex_iterable))
    assert first_item == expected[0]
    assert list(ex_iterable) == expected
def test_filtered_examples_iterable_input_columns(generate_examples_fn, n, func, batch_size, input_columns):
    """Compare a filter restricted to ``input_columns`` with a hand-built reference.

    The reference calls ``func`` with the values of the selected columns as
    positional arguments. In per-example mode the result is used as a keep
    flag; in batched mode it is used as a mask over the batch's examples.

    Args:
        generate_examples_fn: Callable yielding ``(key, example)`` pairs when
            called with ``n=n``.
        n: Value forwarded to ``generate_examples_fn``.
        func: Predicate under test.
        batch_size: Batch size, or ``None`` for per-example filtering.
        input_columns: One column name or a list of column names.
    """
    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = FilteredExamplesIterable(
        base_ex_iterable,
        func,
        batched=batch_size is not None,
        batch_size=batch_size,
        input_columns=input_columns,
    )
    all_examples = [x for _, x in generate_examples_fn(n=n)]
    # A single column name is wrapped so both spellings are handled alike.
    columns_to_input = input_columns if isinstance(input_columns, list) else [input_columns]
    if batch_size is None:
        expected = [x for x in all_examples if func(*[x[col] for col in columns_to_input])]
    else:
        # For batched filter we have to format the examples as a batch
        # (i.e. in one single dictionary) to pass the batch to the function
        expected = []
        for batch_offset in range(0, len(all_examples), batch_size):
            examples = all_examples[batch_offset : batch_offset + batch_size]
            batch = _examples_to_batch(examples)
            mask = func(*[batch[col] for col in columns_to_input])
            expected.extend(x for x, to_keep in zip(examples, mask) if to_keep)
    # A predicate may legitimately reject everything. In that case there is no
    # first item to compare, and the list comparison below already checks that
    # the iterable is empty.
    if expected:
        assert next(iter(ex_iterable))[1] == expected[0]
    assert [x for _, x in ex_iterable] == expected
def test_mapped_examples_iterable_drop_last_batch(generate_examples_fn, n, func, batch_size):
    """Compare a map using ``drop_last_batch=True`` with a hand-built reference.

    In batched mode the reference skips a trailing batch shorter than
    ``batch_size`` and truncates the source examples to a multiple of
    ``batch_size``. When nothing is left, the iterable is expected to raise
    ``StopIteration`` on the first ``next``.
    """
    is_batched = batch_size is not None
    source = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = MappedExamplesIterable(
        source, func, batched=is_batched, batch_size=batch_size, drop_last_batch=True
    )
    all_examples = [example for _, example in generate_examples_fn(n=n)]

    if not is_batched:
        # `drop_last_batch` has no effect here
        expected = [{**example, **func(example)} for example in all_examples]
    else:
        # The batched function receives column-oriented data, so every slice of
        # examples is first folded into one dictionary of lists.
        transformed = []
        for start in range(0, len(all_examples), batch_size):
            chunk = all_examples[start : start + batch_size]
            if len(chunk) < batch_size:
                # An incomplete batch can only be the last one: ignore it.
                break
            transformed.extend(_batch_to_examples(func(_examples_to_batch(chunk))))
        if n % batch_size != 0:
            all_examples = all_examples[: n // batch_size * batch_size]
        if all_examples:
            merged = _examples_to_batch(all_examples)
            merged.update(_examples_to_batch(transformed))
            expected = list(_batch_to_examples(merged))
        else:
            # `None` marks the case where no full batch exists at all.
            expected = None

    if expected is None:
        with pytest.raises(StopIteration):
            next(iter(ex_iterable))
    else:
        first_example = next(iter(ex_iterable))[1]
        assert first_example == expected[0]
        assert [example for _, example in ex_iterable] == expected
def test_mapped_examples_iterable_remove_columns(generate_examples_fn, n, func, batch_size, remove_columns):
    """Compare a map using ``remove_columns`` with a hand-built reference.

    The iterable under test is built with an additional ``extra_column``
    generation argument. The reference drops the listed columns from the
    source examples and then overlays the output of ``func``.
    """
    is_batched = batch_size is not None
    source = ExamplesIterable(generate_examples_fn, {"n": n, "extra_column": "foo"})
    ex_iterable = MappedExamplesIterable(
        source, func, batched=is_batched, batch_size=batch_size, remove_columns=remove_columns
    )
    all_examples = [example for _, example in generate_examples_fn(n=n)]
    # A single column name is wrapped so both spellings are handled alike.
    columns_to_remove = remove_columns if isinstance(remove_columns, list) else [remove_columns]

    if not is_batched:
        expected = []
        for example in all_examples:
            kept = {name: value for name, value in example.items() if name not in columns_to_remove}
            expected.append({**kept, **func(example)})
    else:
        # The batched function receives column-oriented data, so every slice of
        # examples is first folded into one dictionary of lists.
        transformed = []
        for start in range(0, len(all_examples), batch_size):
            batch = _examples_to_batch(all_examples[start : start + batch_size])
            transformed.extend(_batch_to_examples(func(batch)))
        merged = {
            name: values
            for name, values in _examples_to_batch(all_examples).items()
            if name not in columns_to_remove
        }
        merged.update(_examples_to_batch(transformed))
        expected = list(_batch_to_examples(merged))

    first_example = next(iter(ex_iterable))[1]
    assert first_example == expected[0]
    assert [example for _, example in ex_iterable] == expected
def test_filtered_examples_iterable_with_indices(generate_examples_fn, n, func, batch_size):
    """Compare a filter using ``with_indices=True`` with a hand-built reference.

    In per-example mode the reference keeps an example when
    ``func(example, index)`` is truthy. In batched mode ``func(batch, indices)``
    is used as a mask over the batch's examples, where ``indices`` are the
    positions of those examples in the full example list.

    Args:
        generate_examples_fn: Callable yielding ``(key, example)`` pairs when
            called with ``n=n``.
        n: Value forwarded to ``generate_examples_fn``.
        func: Predicate under test.
        batch_size: Batch size, or ``None`` for per-example filtering.
    """
    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = FilteredExamplesIterable(
        base_ex_iterable, func, batched=batch_size is not None, batch_size=batch_size, with_indices=True
    )
    all_examples = [x for _, x in generate_examples_fn(n=n)]
    if batch_size is None:
        expected = [x for idx, x in enumerate(all_examples) if func(x, idx)]
    else:
        # For batched filter we have to format the examples as a batch
        # (i.e. in one single dictionary) to pass the batch to the function
        expected = []
        for batch_offset in range(0, len(all_examples), batch_size):
            examples = all_examples[batch_offset : batch_offset + batch_size]
            batch = _examples_to_batch(examples)
            indices = list(range(batch_offset, batch_offset + len(examples)))
            mask = func(batch, indices)
            expected.extend(x for x, to_keep in zip(examples, mask) if to_keep)
    # A predicate may legitimately reject everything. In that case there is no
    # first item to compare, and the list comparison below already checks that
    # the iterable is empty.
    if expected:
        assert next(iter(ex_iterable))[1] == expected[0]
    assert [x for _, x in ex_iterable] == expected