def train(
    self,
    cfg: dict,
    records_train: api.InputRecords,
    records_validation: api.InputRecords,
) -> tf.keras.Model:
    """Train the network.

    Args:
        cfg: dict, config.
        records_train: InputRecords, training records.
        records_validation: InputRecords, validation records.

    Returns:
        tf.keras.Model, trained network.
    """
    logger.info("Starting training")
    tf_utils.reset()

    cfg = config.prepare_config(cfg)
    logger.info(f"Creating artifact directory: {self.artifact_dir}")
    services.make_artifact_dir(self.artifact_dir)
    io_utils.save_json(cfg, "config.json", self.artifact_dir)
    io_utils.save_pickle(cfg, "config.pkl", self.artifact_dir)

    logger.info("Creating datasets")
    ds_train = dataset.RecordDataset(
        artifact_dir=self.artifact_dir,
        cfg_dataset=cfg["dataset"],
        records=records_train,
        mode=api.RecordMode.TRAIN,
        batch_size=cfg["solver"]["batch_size"],
    )
    ds_validation = dataset.RecordDataset(
        artifact_dir=self.artifact_dir,
        cfg_dataset=cfg["dataset"],
        records=records_validation,
        mode=api.RecordMode.VALIDATION,
        batch_size=cfg["solver"]["batch_size"],
    )
    network_params = ds_train.transformer.network_params
    io_utils.save_json(network_params, "network_params.json", self.artifact_dir)
    io_utils.save_pickle(network_params, "network_params.pkl", self.artifact_dir)

    logger.info("Building network")
    net = model.build_network(cfg["model"], network_params)
    model.check_output_names(cfg["model"], net)

    logger.info("Compiling network")
    opt = solver.build_optimizer(cfg["solver"])
    objective = model.build_objective(cfg["model"])
    net.compile(optimizer=opt, **objective)

    logger.info("Creating services")
    callbacks = services.create_all_services(self.artifact_dir, cfg["services"])
    if "learning_rate_reducer" in cfg["solver"]:
        logger.info("Creating learning rate reducer")
        callbacks.append(solver.create_learning_rate_reducer(cfg["solver"]))

    logger.info("Training network")
    net.summary()
    net.fit(
        ds_train,
        validation_data=ds_validation,
        epochs=cfg["solver"]["epochs"],
        steps_per_epoch=cfg["solver"].get("steps"),
        callbacks=callbacks,
        verbose=1,
    )
    return net
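
# Usage sketch (hypothetical): driving train() end to end. The top-level
# config keys mirror what the method reads ("dataset", "model", "solver",
# "services"), but the exact schema is enforced by config.prepare_config,
# so the values below are illustrative assumptions, not the real defaults.
import pandas as pd

records_train = pd.DataFrame({"feature": [0.1, 0.2, 0.3], "label": [0, 1, 0]})
records_validation = pd.DataFrame({"feature": [0.4, 0.5], "label": [1, 0]})

cfg = {
    "dataset": {},   # loader/transformer/augmentor configuration
    "model": {},     # network builder import and objective (losses/metrics)
    "solver": {"batch_size": 32, "epochs": 10},  # optimizer, steps, etc.
    "services": {},  # checkpointing, logging, early stopping, etc.
}

# `trainer` stands in for an instance of the enclosing class, constructed
# with an artifact directory; the class name and constructor signature are
# assumptions here, not taken from the source.
# net = trainer.train(cfg, records_train, records_validation)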
def __init__(
    self,
    artifact_dir: str,
    cfg_dataset: dict,
    records: Union[pd.DataFrame, core.Records],
    mode: core.RecordMode,
    batch_size: int,
):
    if not isinstance(mode, core.RecordMode):
        raise TypeError("mode must be type RecordMode")

    if isinstance(records, pd.DataFrame):
        records.reset_index(drop=True, inplace=True)
        self.records = records.to_dict(orient="records")
    elif all(isinstance(record, dict) for record in records):
        self.records = records
    else:
        raise TypeError("records must be a list of dicts or a pandas DataFrame")

    self.num_records = len(self.records)
    logger.info(f"Building {mode} dataset with {self.num_records} records")

    self.mode = mode
    self.batch_size = batch_size
    self.seed = cfg_dataset.get("seed")
    np.random.seed(self.seed)

    sample_count = cfg_dataset.get("sample_count")
    if self.mode == core.RecordMode.TRAIN and sample_count is not None:
        self._sample_inds = convert_sample_count_to_inds(
            [record[sample_count] for record in self.records]
        )
    else:
        self._sample_inds = list(range(self.num_records))
    self.shuffle()

    logger.info("Creating record loader")
    loader_cls = import_utils.import_obj_with_search_modules(
        cfg_dataset["loader"]["import"], search_modules=SEARCH_MODULES
    )
    self.loader = loader_cls(mode=mode, params=cfg_dataset["loader"].get("params", {}))
    if not isinstance(self.loader, RecordLoader):
        raise TypeError(f"loader {self.loader} is not of type RecordLoader")

    logger.info("Creating record transformer")
    transformer_cls = import_utils.import_obj_with_search_modules(
        cfg_dataset["transformer"]["import"], search_modules=SEARCH_MODULES
    )
    self.transformer = transformer_cls(
        mode=self.mode,
        loader=self.loader,
        params=cfg_dataset["transformer"].get("params", {}),
    )
    if not isinstance(self.transformer, RecordTransformer):
        raise TypeError(
            f"transformer {self.transformer} is not of type RecordTransformer"
        )

    dataset_dir = os.path.join(artifact_dir, "dataset")
    if self.mode == core.RecordMode.TRAIN:
        logger.info("Creating record augmentor")
        self.augmentor = RecordAugmentor(cfg_dataset["augmentor"])
        logger.info(f"Fitting transform: {self.transformer.__class__.__name__}")
        self.transformer.fit(copy.deepcopy(self.records))
        logger.info(f"Transformer network params: {self.transformer.network_params}")
        logger.info("Saving transformer")
        self.transformer.save(dataset_dir)
    else:
        logger.info(f"Loading transform: {self.transformer.__class__.__name__}")
        self.transformer.load(dataset_dir)
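
# Shape of the cfg_dataset dict consumed by __init__ above (hypothetical
# values). The "import" strings are resolved through
# import_utils.import_obj_with_search_modules against SEARCH_MODULES, so they
# must point at a RecordLoader subclass and a RecordTransformer subclass;
# "MyLoader" and "MyTransformer" are placeholder names, not real classes.
cfg_dataset = {
    "seed": 42,            # seeds np.random for shuffling and sampling
    "sample_count": None,  # optional record key holding per-record sample counts
    "loader": {
        "import": "MyLoader",
        "params": {"inputs": {"x": ["feature"]}, "outputs": {"y": ["label"]}},
    },
    "transformer": {
        "import": "MyTransformer",
        "params": {},
    },
    "augmentor": [],       # augmentation configs, applied only in TRAIN mode
}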
def train(
    self,
    cfg: dict,
    records_train: Union[pd.DataFrame, api.Records],
    records_validation: Union[pd.DataFrame, api.Records],
    workers: int = 10,
    max_queue_size: int = 10,
) -> tf.keras.Model:
    """Train the network.

    Args:
        cfg: dict, config.
        records_train: Union[pd.DataFrame, Records], training records.
        records_validation: Union[pd.DataFrame, Records], validation records.
        workers: int (OPTIONAL = 10), number of worker processes for the sequence.
        max_queue_size: int (OPTIONAL = 10), batch queue size for the sequence.

    Returns:
        tf.keras.Model, trained network.
    """
    logger.info("Starting training")
    tf_utils.reset()

    logger.info("Validating config schema and applying defaults")
    cfg = config.prepare_config(cfg)

    logger.info(f"Making artifact directory: {self._artifact_dir}")
    services.make_artifact_dir(self._artifact_dir)

    logger.info("Saving config")
    io_utils.save_json(cfg, "config.json", self._artifact_dir)
    io_utils.save_pickle(cfg, "config.pkl", self._artifact_dir)

    logger.info("Building datasets")
    ds_train = dataset.RecordDataset(
        artifact_dir=self._artifact_dir,
        cfg_dataset=cfg["dataset"],
        records=records_train,
        mode=api.RecordMode.TRAIN,
        batch_size=cfg["solver"]["batch_size"],
    )
    ds_validation = dataset.RecordDataset(
        artifact_dir=self._artifact_dir,
        cfg_dataset=cfg["dataset"],
        records=records_validation,
        mode=api.RecordMode.VALIDATION,
        batch_size=cfg["solver"]["batch_size"],
    )
    network_params = ds_train.transformer.network_params
    io_utils.save_json(network_params, "network_params.json", self._artifact_dir)
    io_utils.save_pickle(network_params, "network_params.pkl", self._artifact_dir)

    logger.info("Building network")
    net = model.build_network(cfg["model"], network_params)

    logger.info("Checking network output names match config output names")
    model.check_output_names(cfg["model"], net)

    logger.info("Building optimizer")
    opt = solver.build_optimizer(cfg["solver"])

    logger.info("Building objective")
    objective = model.build_objective(cfg["model"])

    logger.info("Compiling network")
    net.compile(optimizer=opt, **objective)
    metrics_names = net.metrics_names

    logger.info("Creating services")
    callbacks = services.create_all_services(
        self._artifact_dir, cfg["services"], metrics_names
    )
    if "learning_rate_reducer" in cfg["solver"]:
        logger.info("Creating learning rate reducer")
        callbacks.append(
            solver.create_learning_rate_reducer(cfg["solver"], metrics_names)
        )

    logger.info("Training network")
    net.summary(print_fn=logger.info)
    net.fit_generator(
        ds_train,
        validation_data=ds_validation,
        epochs=cfg["solver"]["epochs"],
        steps_per_epoch=cfg["solver"].get("steps"),
        callbacks=callbacks,
        use_multiprocessing=(workers > 1),
        max_queue_size=max_queue_size,
        workers=workers,
        verbose=1,
    )
    return net
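
# Usage sketch (hypothetical caller): workers and max_queue_size pass straight
# through to fit_generator. workers=1 disables multiprocessing (the method sets
# use_multiprocessing to workers > 1), which is the safer choice when the
# loader holds non-picklable state such as open file handles or database
# connections; `trainer` is again a placeholder for the enclosing class.
# net = trainer.train(cfg, records_train, records_validation, workers=1, max_queue_size=4)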