def test_record_augmentor_no_funcs():
    funcs = []
    augmentor = RecordAugmentor(funcs)

    data_record_1 = ({"x1": 1, "x2": 1}, {"y": 0})
    expected = ({"x1": 1, "x2": 1}, {"y": 0})

    result = augmentor.augment(data_record_1)
    assert result == expected

    result = augmentor(data_record_1)
    assert result == expected
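# A minimal, hypothetical sketch of a RecordAugmentor consistent with the
# behaviour these tests exercise: funcs is a list of {"import": ..., "params": ...}
# dicts, an empty list is a no-op, and calling the instance delegates to
# .augment(). The importlib/functools machinery below is an assumption for
# illustration only, not necessarily the project's actual implementation.
import functools
import importlib


class RecordAugmentor:
    def __init__(self, funcs):
        bound_funcs = []
        for spec in funcs:
            module_path, func_name = spec["import"].rsplit(".", 1)
            func = getattr(importlib.import_module(module_path), func_name)
            # Bind the configured params; the record is passed positionally later.
            bound_funcs.append(functools.partial(func, **spec.get("params", {})))
        # Collapse all configured funcs into a single record -> record callable.
        self.func = self.reduce_compose(*bound_funcs)

    def augment(self, data_record):
        return self.func(data_record)

    def __call__(self, data_record):
        return self.augment(data_record)

    @staticmethod
    def reduce_compose(*funcs):
        # Compose left to right; with no funcs, fall back to the identity.
        return functools.reduce(lambda f, g: lambda x: g(f(x)), funcs, lambda x: x)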
def test_record_augmentor():
    funcs = [
        {
            "import": "tests.unit.dataset.test_augmentor.augment_add",
            "params": {"input_key": "x1", "ind": 0},
        },
        {
            "import": "tests.unit.dataset.test_augmentor.augment_add",
            "params": {"input_key": "x2", "ind": 0, "value": 2},
        },
        {
            "import": "tests.unit.dataset.test_augmentor.augment_mult",
            "params": {"input_key": "x1", "ind": 0, "value": 5},
        },
        {
            "import": "tests.unit.dataset.test_augmentor.augment_add",
            "params": {"input_key": "y", "ind": 1, "value": 4},
        },
    ]
    augmentor = RecordAugmentor(funcs)

    data_record_1 = ({"x1": 1, "x2": 1}, {"y": 0})
    expected = ({"x1": 5, "x2": 3}, {"y": 4})
    result = augmentor.augment(data_record_1)
    assert result == expected

    data_record_2 = ({"x1": 3, "x2": 2}, {"y": 1})
    expected = ({"x1": 15, "x2": 4}, {"y": 5})
    result = augmentor(data_record_2)
    assert result == expected
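# Hypothetical versions of the augment_add / augment_mult helpers referenced by
# the "import" paths above. The real helpers live in
# tests.unit.dataset.test_augmentor; the signatures and defaults below are
# assumptions inferred from the expected outputs (the first augment_add spec
# omits "value", yet x1 still ends up at 1 * 5 = 5, which implies a default of 0).
def augment_add(data_record, input_key, ind, value=0):
    """Add `value` to data_record[ind][input_key] and return the record."""
    data_record[ind][input_key] += value
    return data_record


def augment_mult(data_record, input_key, ind, value=1):
    """Multiply data_record[ind][input_key] by `value` and return the record."""
    data_record[ind][input_key] *= value
    return data_record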
def test_reduce_compose():
    def add1(x):
        return x + 1

    def mult2(x):
        return x * 2

    # Composition applies left to right: mult2 first, then add1, so (1 * 2) + 1 == 3.
    func = RecordAugmentor.reduce_compose(mult2, add1)
    assert func(1) == 3
    assert func(2) == 5
    assert func(13) == 27
def __init__(
    self,
    artifact_dir: str,
    cfg_dataset: dict,
    records: pd.DataFrame,
    mode: RecordMode,
    batch_size: int,
):
    if not isinstance(records, pd.DataFrame):
        raise TypeError("records must be type pd.DataFrame")
    records.reset_index(drop=True, inplace=True)
    if not isinstance(mode, RecordMode):
        raise TypeError("mode must be type RecordMode")

    self.num_records = len(records)
    logger.info(f"Building {mode} dataset with {self.num_records} records")
    self.records = records
    self.mode = mode
    self.batch_size = batch_size

    self.seed = cfg_dataset.get("seed")
    np.random.seed(self.seed)

    # In TRAIN mode, an optional sample-count column controls how often each
    # record is sampled; otherwise every record is sampled exactly once.
    sample_count = cfg_dataset.get("sample_count")
    if self.mode == RecordMode.TRAIN and sample_count is not None:
        self._sample_inds = convert_sample_count_to_inds(records[sample_count])
    else:
        self._sample_inds = list(range(self.num_records))
    self.shuffle()

    # Dynamically import and instantiate the configured loader.
    logger.info("Creating record loader")
    loader_cls = import_utils.import_obj_with_search_modules(
        cfg_dataset["loader"]["import"], search_modules=SEARCH_MODULES
    )
    self.loader = loader_cls(
        mode=mode, params=cfg_dataset["loader"].get("params", {})
    )
    if not isinstance(self.loader, RecordLoader):
        raise TypeError(f"loader {self.loader} is not of type RecordLoader")

    # Dynamically import and instantiate the configured transformer.
    logger.info("Creating record transformer")
    transformer_cls = import_utils.import_obj_with_search_modules(
        cfg_dataset["transformer"]["import"], search_modules=SEARCH_MODULES
    )
    self.transformer = transformer_cls(
        mode=self.mode,
        loader=self.loader,
        params=cfg_dataset["transformer"].get("params", {}),
    )
    if not isinstance(self.transformer, RecordTransformer):
        raise TypeError(
            f"transformer {self.transformer} is not of type RecordTransformer"
        )

    # TRAIN fits and saves the transformer (and builds the augmentor);
    # VALIDATION and SCORE load the transformer fit during training.
    dataset_dir = os.path.join(artifact_dir, "dataset")
    if self.mode == RecordMode.TRAIN:
        logger.info("Creating record augmentor")
        self.augmentor = RecordAugmentor(cfg_dataset["augmentor"])
        logger.info(f"Fitting transform: {self.transformer.__class__.__name__}")
        self.transformer.fit(self.records.copy(deep=True))
        logger.info(
            f"Transformer network params: {self.transformer.network_params}"
        )
        logger.info("Saving transformer")
        self.transformer.save(dataset_dir)
    elif self.mode in (RecordMode.VALIDATION, RecordMode.SCORE):
        logger.info(f"Loading transform: {self.transformer.__class__.__name__}")
        self.transformer.load(dataset_dir)
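# A hypothetical cfg_dataset illustrating the keys this constructor reads
# ("seed", "sample_count", "loader", "transformer", "augmentor"). The import
# paths and params below are placeholders, not real objects from this project.
example_cfg_dataset = {
    "seed": 42,
    "sample_count": "sample_count",  # optional column name, only used in TRAIN mode
    "loader": {
        "import": "my_project.dataset.MyLoader",  # hypothetical RecordLoader subclass
        "params": {"inputs": ["x1", "x2"], "outputs": ["y"]},
    },
    "transformer": {
        "import": "my_project.dataset.MyTransformer",  # hypothetical RecordTransformer subclass
        "params": {},
    },
    # List of {"import": ..., "params": ...} dicts, as in the augmentor tests above.
    "augmentor": [],
}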
def test_reduce_compose_no_funcs():
    func = RecordAugmentor.reduce_compose()
    assert func(1) == 1
    assert func(2) == 2
    assert func(13) == 13