Ejemplo n.º 1
0
def test_add_data_set(data_sets_dict):
    """Each data set registered via ``add_data_set`` is readable as an attribute of the same name."""
    bunch = DataBunch()

    # Register every data set from the fixture on an initially empty DataBunch.
    for name, data_set_obj in data_sets_dict.items():
        bunch.add_data_set(name, data_set_obj)

    # Every registered name must resolve to the data set that was added under it.
    for name, expected in data_sets_dict.items():
        assert getattr(bunch, name) == expected
Ejemplo n.º 2
0
def read_test_data_bunch(read_data_set, test_params):
    """
    Read a single "test" DataSet with the provider "read_data_set" and wrap it in a new DataBunch

    :param callable read_data_set: A function that reads source data and returns a DataSet
    :param dict test_params: Keyword arguments passed to "read_data_set" to read the "test" DataSet
    :return: a DataBunch containing only the "test" DataSet
    """
    test_data_set = read_data_set(**test_params)
    return DataBunch({"test": test_data_set})
Ejemplo n.º 3
0
def read_train_valid_test_data_bunch(read_data_set, train_params, valid_params,
                                     test_params):
    """
    Read "train", "valid" and "test" DataSets with the provider "read_data_set" and wrap them in a
    new DataBunch

    The DataSets are read in the order "train", "valid", "test".

    :param callable read_data_set: A function that reads source data and returns a DataSet
    :param dict train_params: Keyword arguments passed to "read_data_set" to read the "train" DataSet
    :param dict valid_params: Keyword arguments passed to "read_data_set" to read the "valid" DataSet
    :param dict test_params: Keyword arguments passed to "read_data_set" to read the "test" DataSet
    :return: a DataBunch containing the "train", "valid" and "test" DataSets
    """
    splits = (
        ("train", train_params),
        ("valid", valid_params),
        ("test", test_params),
    )

    # The loop visits splits in tuple order, so read order and key order match the literal form.
    data_sets = {}
    for split_name, split_params in splits:
        data_sets[split_name] = read_data_set(**split_params)

    return DataBunch(data_sets)
Ejemplo n.º 4
0
    def transform(self, transform_then_slice, transformation_params):
        """Return a new MockDataSet whose value is this one's times 100.

        The arguments are stored on the returned object unchanged, so tests
        can inspect what the mock was called with.
        """
        transformed = MockDataSet(self.value * 100)
        # Record the call arguments verbatim on the new mock.
        transformed.transform_then_slice = transform_then_slice
        transformed.transformation_params = transformation_params
        return transformed


# Three mock data sets, one per split, each holding a distinct value.
data_sets_dict = dict(
    train=MockDataSet(123),
    valid=MockDataSet(456),
    test=MockDataSet(789),
)

# Shared DataBunch over the mock data sets, used as the parametrized input of test_transform.
data_bunch = DataBunch(data_sets_dict)


@pytest.mark.parametrize(
    "data_bunch, data_set_names, params, transform_then_slice",
    [
        (data_bunch, ["train", "valid"], {"blah": "456"}, True),
    ],
)
def test_transform(data_bunch, data_set_names, params, transform_then_slice):
    """Transforming selected data sets yields new data sets whose values are scaled by 100."""
    transformed_bunch = data_bunch.transform(
        data_set_names, params, transform_then_slice
    )

    for name in data_set_names:
        # Read the transformed value first, then the original one.
        transformed_value = getattr(transformed_bunch, name).value
        expected_value = getattr(data_bunch, name).value * 100
        assert transformed_value == expected_value