Esempio n. 1
0
def test_mapped_examples_iterable(generate_examples_fn, n, func, batch_size):
    """Test that MappedExamplesIterable yields ``(key, output)`` pairs.

    Without batching, every example is replaced by ``func(example)`` and keeps
    its original key. With batching, ``func`` receives a whole batch formatted
    as one column dict, and every example it produces is keyed by the input
    keys of that batch joined with ``"_"``.
    """
    source = ExamplesIterable(generate_examples_fn, {"n": n})
    mapped = MappedExamplesIterable(source, func, batched=batch_size is not None, batch_size=batch_size)
    keyed_examples = list(generate_examples_fn(n=n))
    if batch_size is not None:
        expected = []
        for start in range(0, len(keyed_examples), batch_size):
            chunk = keyed_examples[start : start + batch_size]
            # The function is applied to the chunk as a single dict of columns,
            # then its output is split back into individual examples
            outputs = list(_batch_to_examples(func(_examples_to_batch([example for _, example in chunk]))))
            # Every output of a batch shares the concatenation of the batch's input keys
            batch_key = "_".join(key for key, _ in chunk)
            expected.extend((batch_key, example) for example in outputs)
    else:
        expected = [(key, func(example)) for key, example in keyed_examples]
    assert next(iter(mapped)) == expected[0]
    assert list(mapped) == expected
Esempio n. 2
0
def test_mapped_examples_iterable_with_indices(generate_examples_fn, n, func,
                                               batch_size):
    """Test MappedExamplesIterable with ``with_indices=True``.

    ``func`` receives each example together with its position in the stream
    or, when batched, a batch dict together with the list of positions of the
    batch's examples. The columns it returns are merged on top of the
    original example's columns.
    """
    mapped = MappedExamplesIterable(ExamplesIterable(generate_examples_fn,
                                                     {"n": n}),
                                    func,
                                    batched=batch_size is not None,
                                    batch_size=batch_size,
                                    with_indices=True)
    source = [example for _, example in generate_examples_fn(n=n)]
    if batch_size is not None:
        # Rebuild the expected output chunk by chunk: each chunk is formatted
        # as a single column dict and passed to func with its global indices.
        outputs = []
        for start in range(0, len(source), batch_size):
            chunk = source[start:start + batch_size]
            chunk_indices = list(range(start, start + len(chunk)))
            outputs.extend(
                _batch_to_examples(
                    func(_examples_to_batch(chunk), chunk_indices)))
        merged = _examples_to_batch(source)
        # Columns produced by func override (or extend) the original ones
        merged.update(_examples_to_batch(outputs))
        expected = list(_batch_to_examples(merged))
    else:
        expected = [{
            **example,
            **func(example, position)
        } for position, example in enumerate(source)]
    assert next(iter(mapped))[1] == expected[0]
    assert [example for _, example in mapped] == expected
Esempio n. 3
0
def test_mapped_examples_iterable_input_columns(generate_examples_fn, n, func,
                                                batch_size, input_columns):
    """Test MappedExamplesIterable with ``input_columns``.

    Instead of the whole example (or batch), ``func`` receives only the values
    of the selected column(s) as positional arguments, in the given order.
    ``input_columns`` may be a single column name or a list of names. The
    columns returned by ``func`` are merged on top of the original columns.
    """
    mapped = MappedExamplesIterable(ExamplesIterable(generate_examples_fn,
                                                     {"n": n}),
                                    func,
                                    batched=batch_size is not None,
                                    batch_size=batch_size,
                                    input_columns=input_columns)
    source = [example for _, example in generate_examples_fn(n=n)]
    if isinstance(input_columns, list):
        selected = input_columns
    else:
        selected = [input_columns]
    if batch_size is not None:
        # Rebuild the expected output chunk by chunk: each chunk is formatted
        # as a single column dict and only the selected columns go to func.
        outputs = []
        for start in range(0, len(source), batch_size):
            column_batch = _examples_to_batch(source[start:start + batch_size])
            outputs.extend(
                _batch_to_examples(
                    func(*(column_batch[name] for name in selected))))
        merged = _examples_to_batch(source)
        # Columns produced by func override (or extend) the original ones
        merged.update(_examples_to_batch(outputs))
        expected = list(_batch_to_examples(merged))
    else:
        expected = [{
            **example,
            **func(*(example[name] for name in selected))
        } for example in source]
    assert next(iter(mapped))[1] == expected[0]
    assert [example for _, example in mapped] == expected
Esempio n. 4
0
def test_mapped_examples_iterable_drop_last_batch(generate_examples_fn, n,
                                                  func, batch_size):
    """Test MappedExamplesIterable with ``drop_last_batch=True``.

    In batched mode, a trailing batch with fewer than ``batch_size`` examples
    is expected to be dropped along with its source examples. Without
    batching, the option is expected to have no effect. When no example is
    expected at all, the iterable must be empty (``next`` raises
    ``StopIteration``).
    """
    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = MappedExamplesIterable(base_ex_iterable,
                                         func,
                                         batched=batch_size is not None,
                                         batch_size=batch_size,
                                         drop_last_batch=True)
    all_examples = [x for _, x in generate_examples_fn(n=n)]
    if batch_size is None:
        # `drop_last_batch` has no effect here
        expected = [{**x, **func(x)} for x in all_examples]
    else:
        # Only full batches survive: keep the largest prefix whose length is
        # a multiple of batch_size. This is computed from the examples actually
        # generated rather than assuming there are exactly `n` of them.
        n_kept = len(all_examples) // batch_size * batch_size
        kept_examples = all_examples[:n_kept]
        # For batched map we have to format the examples as a batch (i.e. in
        # one single dictionary) to pass the batch to the function
        all_transformed_examples = []
        for batch_offset in range(0, n_kept, batch_size):
            batch = _examples_to_batch(
                kept_examples[batch_offset:batch_offset + batch_size])
            all_transformed_examples.extend(_batch_to_examples(func(batch)))
        if kept_examples:
            expected = _examples_to_batch(kept_examples)
            expected.update(_examples_to_batch(all_transformed_examples))
            expected = list(_batch_to_examples(expected))
        else:
            # Nothing survives the drop: skip the batch helpers entirely,
            # since there is no example to convert.
            expected = []

    if expected:
        assert next(iter(ex_iterable))[1] == expected[0]
        assert [x for _, x in ex_iterable] == expected
    else:
        # Empty expectation (every example dropped, or no example at all):
        # the iterable must not yield anything.
        with pytest.raises(StopIteration):
            next(iter(ex_iterable))
def test_mapped_examples_iterable_remove_columns(generate_examples_fn, n, func, batch_size, remove_columns):
    """Test MappedExamplesIterable with ``remove_columns``.

    ``func`` is evaluated on the full source example (or batch). The removed
    columns are dropped from the original examples, and the columns returned
    by ``func`` are merged on top of what remains. ``remove_columns`` may be
    a single column name or a list of names.
    """
    # Use the same generator kwargs for the iterable under test and for the
    # expected examples. Every source column (including "extra_column" when it
    # is not removed) is then accounted for in `expected`, and `func` sees the
    # same inputs in both.
    gen_kwargs = {"n": n, "extra_column": "foo"}
    base_ex_iterable = ExamplesIterable(generate_examples_fn, dict(gen_kwargs))
    ex_iterable = MappedExamplesIterable(
        base_ex_iterable, func, batched=batch_size is not None, batch_size=batch_size, remove_columns=remove_columns
    )
    all_examples = [x for _, x in generate_examples_fn(**gen_kwargs)]
    columns_to_remove = remove_columns if isinstance(remove_columns, list) else [remove_columns]
    if batch_size is None:
        expected = [{**{k: v for k, v in x.items() if k not in columns_to_remove}, **func(x)} for x in all_examples]
    else:
        # For batched map we have to format the examples as a batch (i.e. in one single dictionary) to pass the batch to the function
        all_transformed_examples = []
        for batch_offset in range(0, len(all_examples), batch_size):
            examples = all_examples[batch_offset : batch_offset + batch_size]
            batch = _examples_to_batch(examples)
            transformed_batch = func(batch)
            all_transformed_examples.extend(_batch_to_examples(transformed_batch))
        expected = {k: v for k, v in _examples_to_batch(all_examples).items() if k not in columns_to_remove}
        expected.update(_examples_to_batch(all_transformed_examples))
        expected = list(_batch_to_examples(expected))
    assert next(iter(ex_iterable))[1] == expected[0]
    assert [x for _, x in ex_iterable] == expected