Example #1
0
def test_record_augmentor_no_funcs():
    """An augmentor built from an empty spec list returns records unchanged.

    Both entry points -- the explicit ``augment`` method and calling the
    augmentor object directly -- are checked against the same input record.
    """
    augmentor = RecordAugmentor([])

    record = ({"x1": 1, "x2": 1}, {"y": 0})
    unchanged = ({"x1": 1, "x2": 1}, {"y": 0})
    for apply in (augmentor.augment, augmentor):
        assert apply(record) == unchanged
Example #2
0
def test_record_augmentor():
    """Chained augmentations are applied in spec order to inputs and targets.

    The specs are:
    - ``augment_add`` on ``x1`` with no explicit value;
    - ``augment_add`` on ``x2`` with value 2;
    - ``augment_mult`` on ``x1`` with value 5, applied after the add;
    - ``augment_add`` on ``y`` with value 4, where ``ind`` 1 selects the
      second dict of the record.

    One record is pushed through ``augment`` and another through ``__call__``.
    """
    add = "tests.unit.dataset.test_augmentor.augment_add"
    mult = "tests.unit.dataset.test_augmentor.augment_mult"
    funcs = [
        {"import": add, "params": {"input_key": "x1", "ind": 0}},
        {"import": add, "params": {"input_key": "x2", "ind": 0, "value": 2}},
        {"import": mult, "params": {"input_key": "x1", "ind": 0, "value": 5}},
        {"import": add, "params": {"input_key": "y", "ind": 1, "value": 4}},
    ]
    augmentor = RecordAugmentor(funcs)

    # Each case: (entry point, input record, expected augmented record).
    cases = [
        (
            augmentor.augment,
            ({"x1": 1, "x2": 1}, {"y": 0}),
            ({"x1": 5, "x2": 3}, {"y": 4}),
        ),
        (
            augmentor,
            ({"x1": 3, "x2": 2}, {"y": 1}),
            ({"x1": 15, "x2": 4}, {"y": 5}),
        ),
    ]
    for apply, record, expected in cases:
        assert apply(record) == expected
Example #3
0
def test_reduce_compose():
    """``reduce_compose(f, g)`` applies ``f`` first and then ``g``, i.e. g(f(x))."""

    def plus_one(value):
        return value + 1

    def times_two(value):
        return value * 2

    composed = RecordAugmentor.reduce_compose(times_two, plus_one)
    for given, expected in ((1, 3), (2, 5), (13, 27)):
        assert composed(given) == expected
Example #4
0
    def __init__(
        self,
        artifact_dir: str,
        cfg_dataset: dict,
        records: pd.DataFrame,
        mode: RecordMode,
        batch_size: int,
    ):
        """Build a dataset over ``records`` for the given ``mode``.

        Validates the inputs, seeds NumPy's global RNG, computes and shuffles
        the sample indices, and instantiates the configured loader and
        transformer. In TRAIN mode it also builds the augmentor, then fits the
        transformer and saves it. In VALIDATION / SCORE mode it loads a
        previously saved transformer instead.

        Args:
            artifact_dir: Root artifact directory. Transformer state is saved
                to / loaded from ``os.path.join(artifact_dir, "dataset")``.
            cfg_dataset: Dataset config. Keys read here:
                - ``"seed"`` (optional);
                - ``"sample_count"`` (optional column name, used in TRAIN
                  mode only);
                - ``"loader"`` and ``"transformer"``, each with ``"import"``
                  and optional ``"params"``;
                - ``"augmentor"`` (optional list of augmentation specs, read
                  in TRAIN mode only; a missing entry means no augmentation).
            records: Records backing the dataset. Its index is reset in place.
            mode: Which dataset mode to build.
            batch_size: Stored as ``self.batch_size``.

        Raises:
            TypeError: If ``records`` is not a ``pd.DataFrame``, if ``mode``
                is not a ``RecordMode``, or if the configured loader or
                transformer is not a ``RecordLoader`` / ``RecordTransformer``.
        """
        # Validate every argument before touching ``records``, so that a bad
        # call does not leave the caller's DataFrame half-modified.
        if not isinstance(records, pd.DataFrame):
            raise TypeError("records must be type pd.DataFrame")
        if not isinstance(mode, RecordMode):
            raise TypeError("mode must be type RecordMode")
        # NOTE: this mutates the caller's DataFrame in place, giving it a
        # 0..n-1 index. It is presumably done so positional sample indices
        # match row labels -- confirm before changing to a copy.
        records.reset_index(drop=True, inplace=True)

        self.num_records = len(records)
        logger.info(f"Building {mode} dataset with {self.num_records} records")
        self.records = records
        self.mode = mode
        self.batch_size = batch_size

        # Seeds NumPy's *global* RNG, which affects every np.random user in
        # this process.
        self.seed = cfg_dataset.get("seed")
        np.random.seed(self.seed)

        # In TRAIN mode an optional column (named by "sample_count") is
        # expanded into sample indices. Otherwise every record index appears
        # exactly once.
        sample_count = cfg_dataset.get("sample_count")
        if self.mode == RecordMode.TRAIN and sample_count is not None:
            self._sample_inds = convert_sample_count_to_inds(records[sample_count])
        else:
            self._sample_inds = list(range(self.num_records))
        # Called after the seeding above. If shuffle() draws from np.random,
        # the order depends on the configured seed (not verified here).
        self.shuffle()

        logger.info("Creating record loader")
        loader_cfg = cfg_dataset["loader"]
        loader_cls = import_utils.import_obj_with_search_modules(
            loader_cfg["import"], search_modules=SEARCH_MODULES
        )
        self.loader = loader_cls(mode=mode, params=loader_cfg.get("params", {}))
        if not isinstance(self.loader, RecordLoader):
            raise TypeError(f"loader {self.loader} is not of type RecordLoader")

        logger.info("Creating record transformer")
        transformer_cfg = cfg_dataset["transformer"]
        transformer_cls = import_utils.import_obj_with_search_modules(
            transformer_cfg["import"], search_modules=SEARCH_MODULES
        )
        self.transformer = transformer_cls(
            mode=self.mode,
            loader=self.loader,
            params=transformer_cfg.get("params", {}),
        )
        if not isinstance(self.transformer, RecordTransformer):
            raise TypeError(
                f"transformer {self.transformer} is not of type RecordTransformer"
            )

        dataset_dir = os.path.join(artifact_dir, "dataset")
        if self.mode == RecordMode.TRAIN:
            logger.info("Creating record augmentor")
            # An empty spec list yields a pass-through augmentor, so a config
            # without an "augmentor" entry simply disables augmentation.
            # NOTE(review): self.augmentor is only set in TRAIN mode; other
            # modes have no such attribute.
            self.augmentor = RecordAugmentor(cfg_dataset.get("augmentor", []))
            logger.info(f"Fitting transform: {self.transformer.__class__.__name__}")
            # Fit on a deep copy so the transformer cannot mutate self.records.
            self.transformer.fit(self.records.copy(deep=True))
            logger.info(
                f"Transformer network params: {self.transformer.network_params}"
            )
            logger.info("Saving transformer")
            self.transformer.save(dataset_dir)
        elif self.mode == RecordMode.VALIDATION or self.mode == RecordMode.SCORE:
            logger.info(f"Loading transform: {self.transformer.__class__.__name__}")
            self.transformer.load(dataset_dir)
        # NOTE(review): any other RecordMode member (if one exists) leaves the
        # transformer neither fitted nor loaded -- confirm this is intended.
Example #5
0
def test_reduce_compose_no_funcs():
    """With no functions supplied, ``reduce_compose`` yields the identity."""
    identity = RecordAugmentor.reduce_compose()
    for value in (1, 2, 13):
        assert identity(value) == value