Esempio n. 1
0
    def init_training(self):
        """Prepare the model (and optimizer) right before training starts.

        The steps run in this order:

        1. On a fresh start (``start_epoch <= 0``) and when
           ``utils.is_main_training()`` is true, the model blueprint and the
           model creation string are written to ``<model_dir>/config/nnet.config``.
        2. Weights are restored: from ``<model_dir>/<start_epoch>.<suffix>`` when
           ``start_epoch > 0``; otherwise from ``exist_model`` (through
           ``load_transform_state_dict``) if that path exists; otherwise the
           model is left untouched.
        3. When ``utils.use_horovod()`` is true, the model state and the
           optimizer state are broadcast from rank 0 and the optimizer is
           replaced by a ``hvd.DistributedOptimizer``.
        4. ``self.select_device()`` is called; if it returns a
           ``DistributedDataParallel`` wrapper, ``self.elements["model"]`` is
           pointed at the wrapped module and ``self.elements["model_forward"]``
           at the wrapper.

        Raises:
            TypeError: horovod is in use and the optimizer exposes a non-None
                ``optimizer`` attribute (i.e. it is a wrapper such as lookahead).
        """
        net = self.elements["model"]
        params = self.params
        start_epoch = params["start_epoch"]
        exist_model = params["exist_model"]
        model_dir = params["model_dir"]
        model_blueprint = params["model_blueprint"]
        suffix = params["suffix"]

        # Fresh start: dump the network config once.
        if start_epoch <= 0:
            if utils.is_main_training():
                creation = net.get_model_creation()
                config_path = "{0}/config/nnet.config".format(model_dir)
                utils.write_nnet_config(model_blueprint, creation, config_path)

        # Recover a checkpoint | transform learning | keep the raw initial model.
        if start_epoch > 0:
            # start_epoch equals the number of epochs already completed.
            if utils.is_main_training():
                logger.info("Recover training from {0} epoch.".format(start_epoch))
            checkpoint_path = '{0}/{1}.{2}'.format(model_dir, start_epoch, suffix)
            net.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        elif os.path.exists(exist_model):
            if utils.is_main_training():
                logger.info("Use {0} as the initial model to start transform-training.".format(exist_model))
            net.load_transform_state_dict(torch.load(exist_model, map_location="cpu"))
        # Otherwise nothing is loaded: the raw initial model is used as it is.
        # Extra initialization functions could be applied at this point.

        if utils.use_horovod():
            import horovod.torch as hvd

            # Every process starts from the parameters held by rank 0.
            hvd.broadcast_parameters(net.state_dict(), root_rank=0)

            optimizer = self.elements["optimizer"]
            # An optimizer wrapper such as lookahead keeps its inner optimizer
            # in an `optimizer` attribute; that combination is rejected.
            if getattr(optimizer, "optimizer", None) is not None:
                raise TypeError("Do not support using lookahead with horovod now.")

            # Broadcast the optimizer state, then wrap the optimizer for horovod.
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)
            self.elements["optimizer"] = hvd.DistributedOptimizer(
                optimizer, named_parameters=net.named_parameters())

        # Select device.
        placed_model = self.select_device()

        # The original model is built on libs.nnet.framework.TopVirtualNnet, whose
        # functions are not reachable through a DistributedDataParallel wrapper.
        # Keep the bare module under "model" so those functions stay convenient
        # to call, and store the wrapper itself under "model_forward".
        if isinstance(placed_model, torch.nn.parallel.DistributedDataParallel):
            self.elements["model"] = placed_model.module
            self.elements["model_forward"] = placed_model
Esempio n. 2
0
    # Run a batch extracting process.
    try:
        for position in to_extracted_positions:
            # Generate the extracting config from the nnet config; the position
            # to extract is determined by the 'extracted_embedding' parameter of model_creation (by design).
            model_blueprint, model_creation = utils.read_nnet_config(
                "{0}/config/nnet.config".format(model_dir))
            model_creation = model_creation.replace(
                "training=True", "training=False"
            )  # To save memory without loading some independent components.
            model_creation = model_creation.replace(
                model_params["extracted_embedding"], position)
            extract_config = "{0}.extract.config".format(position)
            utils.write_nnet_config(
                model_blueprint, model_creation,
                "{0}/config/{1}".format(model_dir, extract_config))
            for epoch in to_extracted_epochs:
                model_file = "{0}.{1}".format(epoch, suffix)
                point_name = "{0}_epoch_{1}".format(position, epoch)

                # If the trainer runs in a background thread (not supported now) or this launcher is run separately with stage=4
                # (i.e. in another process), this polling loop lets extracting start as soon as the model file appears (but it requires more GPU memory).
                model_path = "{0}/{1}".format(model_dir, model_file)
                while True:
                    if os.path.exists(model_path):
                        break
                    else:
                        time.sleep(sleep_time)

                for data in to_extracted_data: