Example #1
    def train(
        self,
        cfg: dict,
        records_train: api.InputRecords,
        records_validation: api.InputRecords,
    ) -> tf.keras.Model:
        """Train the network.

        Args:
            cfg: dict, config.
            records_train: InputRecords, training records.
            records_validation: InputRecords, validation records.

        Returns:
            tf.keras.Model, trained network.
        """
        logger.info("Starting training")
        tf_utils.reset()
        cfg = config.prepare_config(cfg)

        logger.info(f"Creating artifact directory: {self.artifact_dir}")
        services.make_artifact_dir(self.artifact_dir)
        io_utils.save_json(cfg, "config.json", self.artifact_dir)
        io_utils.save_pickle(cfg, "config.pkl", self.artifact_dir)

        logger.info("Creating datasets")
        ds_train = dataset.RecordDataset(
            artifact_dir=self.artifact_dir,
            cfg_dataset=cfg["dataset"],
            records=records_train,
            mode=api.RecordMode.TRAIN,
            batch_size=cfg["solver"]["batch_size"],
        )
        ds_validation = dataset.RecordDataset(
            artifact_dir=self.artifact_dir,
            cfg_dataset=cfg["dataset"],
            records=records_validation,
            mode=api.RecordMode.VALIDATION,
            batch_size=cfg["solver"]["batch_size"],
        )
        network_params = ds_train.transformer.network_params
        io_utils.save_json(network_params, "network_params.json", self.artifact_dir)
        io_utils.save_pickle(network_params, "network_params.pkl", self.artifact_dir)

        logger.info("Building network")
        net = model.build_network(cfg["model"], network_params)
        model.check_output_names(cfg["model"], net)

        logger.info("Compiling network")
        opt = solver.build_optimizer(cfg["solver"])
        objective = model.build_objective(cfg["model"])
        net.compile(optimizer=opt, **objective)

        logger.info("Creating services")
        callbacks = services.create_all_services(self.artifact_dir, cfg["services"])

        if "learning_rate_reducer" in cfg["solver"]:
            logger.info("Creating learning rate reducer")
            callbacks.append(solver.create_learning_rate_reducer(cfg["solver"]))

        logger.info("Training network")
        net.summary()
        net.fit(
            ds_train,
            validation_data=ds_validation,
            epochs=cfg["solver"]["epochs"],
            steps_per_epoch=cfg["solver"].get("steps"),
            callbacks=callbacks,
            verbose=1,
        )

        return net
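
A minimal usage sketch for this method. The Trainer wrapper class, its constructor signature, and the placeholder config values below are assumptions for illustration; the top-level keys (dataset, solver, model, services) are exactly those read by train above.

# Hypothetical wrapper exposing the train() method above; the
# constructor signature is assumed for illustration.
trainer = Trainer(artifact_dir="artifacts/run_01")

# Only the top-level keys read by train() are shown; the inner
# values are placeholders validated by config.prepare_config().
cfg = {
    "dataset": {},  # loader / transformer / augmentor settings
    "solver": {"batch_size": 32, "epochs": 10},
    "model": {},
    "services": {},
}

# records_train / records_validation are api.InputRecords built elsewhere.
net = trainer.train(cfg, records_train, records_validation)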
Example #2
    def __init__(
        self,
        artifact_dir: str,
        cfg_dataset: dict,
        records: Union[pd.DataFrame, core.Records],
        mode: core.RecordMode,
        batch_size: int,
    ):
        if not isinstance(mode, core.RecordMode):
            raise TypeError("mode must be of type RecordMode")

        if isinstance(records, pd.DataFrame):
            records.reset_index(drop=True, inplace=True)
            self.records = records.to_dict(orient="records")
        elif all(isinstance(record, dict) for record in records):
            self.records = records
        else:
            raise TypeError(
                "records must be a list of dicts or a pandas DataFrame")

        self.num_records = len(records)
        logger.info(f"Building {mode} dataset with {self.num_records} records")
        self.mode = mode
        self.batch_size = batch_size

        self.seed = cfg_dataset.get("seed")
        np.random.seed(self.seed)

        sample_count = cfg_dataset.get("sample_count")
        if self.mode == core.RecordMode.TRAIN and sample_count is not None:
            self._sample_inds = convert_sample_count_to_inds(
                [record[sample_count] for record in self.records])
        else:
            self._sample_inds = list(range(self.num_records))
        self.shuffle()

        logger.info(f"Creating record loader")
        loader_cls = import_utils.import_obj_with_search_modules(
            cfg_dataset["loader"]["import"], search_modules=SEARCH_MODULES)
        self.loader = loader_cls(mode=mode,
                                 params=cfg_dataset["loader"].get(
                                     "params", {}))
        if not isinstance(self.loader, RecordLoader):
            raise TypeError(
                f"loader {self.loader} is not of type RecordLoader")

        logger.info(f"Creating record transformer")
        transformer_cls = import_utils.import_obj_with_search_modules(
            cfg_dataset["transformer"]["import"],
            search_modules=SEARCH_MODULES)
        self.transformer = transformer_cls(
            mode=self.mode,
            loader=self.loader,
            params=cfg_dataset["transformer"].get("params", {}),
        )
        if not isinstance(self.transformer, RecordTransformer):
            raise TypeError(
                f"transformer {self.transformer} is not of type RecordTransformer"
            )

        dataset_dir = os.path.join(artifact_dir, "dataset")
        if self.mode == core.RecordMode.TRAIN:
            logger.info("Creating record augmentor")
            self.augmentor = RecordAugmentor(cfg_dataset["augmentor"])
            logger.info(
                f"Fitting transform: {self.transformer.__class__.__name__}")
            self.transformer.fit(copy.deepcopy(self.records))
            logger.info(
                f"Transformer network params: {self.transformer.network_params}"
            )
            logger.info("Saving transformer")
            self.transformer.save(dataset_dir)
        else:
            logger.info(
                f"Loading transform: {self.transformer.__class__.__name__}")
            self.transformer.load(dataset_dir)
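
A sketch of building a RecordDataset directly. The cfg_dataset keys mirror those read in __init__ above (seed, loader, transformer, augmentor, and optionally sample_count); the dotted import paths and the DataFrame columns are hypothetical placeholders.

import pandas as pd

cfg_dataset = {
    "seed": 1234,
    "loader": {"import": "my_pkg.loaders.CsvLoader", "params": {}},
    "transformer": {"import": "my_pkg.transforms.Standardize", "params": {}},
    "augmentor": [],  # consumed by RecordAugmentor in TRAIN mode only
}

records = pd.DataFrame({"path": ["a.csv", "b.csv"], "label": [0, 1]})

ds = RecordDataset(
    artifact_dir="artifacts/run_01",
    cfg_dataset=cfg_dataset,
    records=records,
    mode=core.RecordMode.TRAIN,  # TRAIN fits and saves the transformer
    batch_size=2,
)

Note that only TRAIN mode fits and saves the transformer; VALIDATION mode loads it from artifact_dir/dataset, so the training dataset must be constructed first.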
Example #3
    def train(
        self,
        cfg: dict,
        records_train: Union[pd.DataFrame, api.Records],
        records_validation: Union[pd.DataFrame, api.Records],
        workers: int = 10,
        max_queue_size: int = 10,
    ) -> tf.keras.Model:
        """Train the network.

        Args:
            cfg: dict, config.
            records_train: Union[pd.DataFrame, Records], training records.
            records_validation: Union[pd.DataFrame, Records], validation records.
            workers: int (OPTIONAL = 10), number of worker processes or threads used to enqueue batches from the sequence.
            max_queue_size: int (OPTIONAL = 10), maximum size of the batch queue for the sequence.

        Returns:
            tf.keras.Model, trained network.
        """
        logger.info("Starting training")
        tf_utils.reset()

        logger.info("Validating config schema and applying defaults")
        cfg = config.prepare_config(cfg)

        logger.info(f"Making artifact directory: {self._artifact_dir}")
        services.make_artifact_dir(self._artifact_dir)

        logger.info("Saving config")
        io_utils.save_json(cfg, "config.json", self._artifact_dir)
        io_utils.save_pickle(cfg, "config.pkl", self._artifact_dir)

        logger.info("Building datasets")
        ds_train = dataset.RecordDataset(
            artifact_dir=self._artifact_dir,
            cfg_dataset=cfg["dataset"],
            records=records_train,
            mode=api.RecordMode.TRAIN,
            batch_size=cfg["solver"]["batch_size"],
        )
        ds_validation = dataset.RecordDataset(
            artifact_dir=self._artifact_dir,
            cfg_dataset=cfg["dataset"],
            records=records_validation,
            mode=api.RecordMode.VALIDATION,
            batch_size=cfg["solver"]["batch_size"],
        )
        network_params = ds_train.transformer.network_params
        io_utils.save_json(network_params, "network_params.json", self._artifact_dir)
        io_utils.save_pickle(network_params, "network_params.pkl", self._artifact_dir)

        logger.info("Building network")
        net = model.build_network(cfg["model"], network_params)

        logger.info("Checking network output names match config output names")
        model.check_output_names(cfg["model"], net)

        logger.info("Building optimizer")
        opt = solver.build_optimizer(cfg["solver"])

        logger.info("Building objective")
        objective = model.build_objective(cfg["model"])

        logger.info("Compiling network")
        net.compile(optimizer=opt, **objective)
        metrics_names = net.metrics_names

        logger.info("Creating services")
        callbacks = services.create_all_services(
            self._artifact_dir, cfg["services"], metrics_names
        )

        if "learning_rate_reducer" in cfg["solver"]:
            logger.info("Creating learning rate reducer")
            callbacks.append(
                solver.create_learning_rate_reducer(cfg["solver"], metrics_names)
            )

        logger.info("Training network")
        net.summary(print_fn=logger.info)
        # Note: fit_generator is deprecated in TF >= 2.1; Model.fit accepts
        # the same sequence, workers, and max_queue_size arguments.
        net.fit_generator(
            ds_train,
            validation_data=ds_validation,
            epochs=cfg["solver"]["epochs"],
            steps_per_epoch=cfg["solver"].get("steps"),
            callbacks=callbacks,
            use_multiprocessing=(workers > 1),
            max_queue_size=max_queue_size,
            workers=workers,
            verbose=1,
        )

        return net
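
Usage differs from Example #1 mainly in accepting DataFrames and in exposing the worker settings. The Trainer wrapper, file paths, and cfg below are hypothetical placeholders, as in the earlier sketch.

import pandas as pd

df_train = pd.read_csv("data/train.csv")  # hypothetical paths
df_validation = pd.read_csv("data/validation.csv")

trainer = Trainer(artifact_dir="artifacts/run_02")

# workers > 1 turns on multiprocessing for the sequence
# (use_multiprocessing=(workers > 1) above); max_queue_size
# bounds the number of prefetched batches.
net = trainer.train(
    cfg,  # same shape as the config sketched in Example #1
    df_train,
    df_validation,
    workers=4,
    max_queue_size=20,
)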