Example #1
0
            # evaluation
            self.seq_eval_engine.train_eval_seq(self.valid_data,
                                                self.test_data, engine, epoch)

    def train(self):
        """Train NARM and then evaluate it on the test data.

        The sequence is:

        1. Construct a ``Monitor`` logging to ``config["run_dir"]``.
        2. Obtain the training data from ``self.load_train_data``.
        3. Build a ``NARMEngine``.
        4. Run ``self._train`` with the save directory
           ``<model_save_dir>/<save_name>``.
        5. Store the monitor's ``stop()`` result in ``config["run_time"]``.
        6. Run the sequential test evaluation.

        Returns:
            None
        """
        # The monitor is constructed before data loading and training start.
        self.monitor = Monitor(
            log_dir=self.config["run_dir"],
            delay=1,
            gpu_id=self.gpu_id,
        )
        # NOTE(review): accessed without a call, so ``load_train_data`` is
        # presumably a property (or ``_train`` accepts a callable); confirm.
        loader = self.load_train_data
        engine = self.engine = NARMEngine(self.config)
        save_dir = self.narm_save_dir = os.path.join(
            self.config["model_save_dir"], self.config["save_name"]
        )
        self._train(engine, loader, save_dir)

        # Read the attributes again after training rather than reusing locals,
        # in case _train rebinds them.
        self.config["run_time"] = self.monitor.stop()
        self.seq_eval_engine.test_eval_seq(self.test_data, self.engine)


if __name__ == "__main__":
    # Start from an empty config and fill it from the command-line arguments.
    config = {}
    args = parse_args()
    update_args(config, args)
    narm = NARM_train(config)
    # No separate narm.test() call is needed: NARM_train.train() already runs
    # the test-set evaluation once training finishes.
    narm.train()
Example #2
0
    def prepare_env(self):
        """Prepare the running environment and return the resolved config.

        Steps, in order:

        * Load the JSON config named by ``self.args.config_file`` and merge the
          command-line arguments into it with ``update_args``.
        * Make ``config["system"]["root_dir"]`` absolute and build a unique
          ``config["system"]["model_run_id"]`` of the form
          ``<model>_<config_id>_<YYYYmmdd_HHMMSS>_<6 random lowercase letters>``.
        * Call ``set_seed`` with ``config["system"]["seed"]`` (2020 when the
          key is absent).
        * Call ``self.initialize_folders(config)``, then resolve the process,
          log, run, tune, checkpoint and result paths against ``root_dir``.
        * Initialize logging with ``logger.init_std_logger`` and redirect ray's
          temp dir to ``<root_dir>/tmp``.

        Returns:
            dict: The resolved configuration.

        Raises:
            OSError: If the config file cannot be opened.
            json.JSONDecodeError: If the config file is not valid JSON.
            KeyError: If a required ``model``/``system`` key is missing.
        """
        # Load the config file. JSON text is UTF-8 by specification, so decode
        # it explicitly rather than with the platform's locale encoding.
        config_file = self.args.config_file
        with open(config_file, encoding="utf-8") as config_params:
            print(f"loading config file {config_file}")
            config = json.load(config_params)

        # Update configs based on the received args from the command line.
        update_args(config, self.args)

        # Obtain abspath for the project.
        config["system"]["root_dir"] = os.path.abspath(
            config["system"]["root_dir"])

        # Construct a unique model run id from the model name, config id, a
        # timestamp and a random lowercase suffix. The suffix is drawn before
        # set_seed() is called below, so that seeding does not determine it.
        timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
        random_str = "".join(
            [random.choice(string.ascii_lowercase) for _ in range(6)])
        # f-string formatting (instead of "+") also accepts a non-string
        # config_id, e.g. a bare number in the JSON file.
        config["system"]["model_run_id"] = (
            f"{config['model']['model']}_{config['model']['config_id']}"
            f"_{timestamp_str}_{random_str}")

        # Initialize random seeds; 2020 is the fallback when no seed is set.
        set_seed(config["system"].get("seed", 2020))

        # Initialize working folders.
        self.initialize_folders(config)

        config["system"]["process_dir"] = os.path.join(
            config["system"]["root_dir"], config["system"]["process_dir"])

        # Initialize log file: <root_dir>/<log_dir>/<model_run_id>.
        config["system"]["log_file"] = os.path.join(
            config["system"]["root_dir"],
            config["system"]["log_dir"],
            config["system"]["model_run_id"],
        )
        logger.init_std_logger(config["system"]["log_file"])

        print("Python version:", sys.version)
        print("pytorch version:", torch.__version__)

        # File paths to be saved. The model and system sections share the
        # same per-run directory.
        config["model"]["run_dir"] = os.path.join(
            config["system"]["root_dir"],
            config["system"]["run_dir"],
            config["system"]["model_run_id"],
        )
        config["system"]["run_dir"] = config["model"]["run_dir"]
        print(
            "The intermediate running statuses will be reported in folder:",
            config["system"]["run_dir"],
        )

        config["system"]["tune_dir"] = os.path.join(
            config["system"]["root_dir"], config["system"]["tune_dir"])

        def get_user_temp_dir():
            """Return ``<root_dir>/tmp`` for ray, printing it when called."""
            tempdir = os.path.join(config["system"]["root_dir"], "tmp")
            print(f"ray temp dir {tempdir}")
            return tempdir

        # NOTE(review): this replaces an attribute on the imported ray module,
        # so it affects the whole process. It only works if this ray version
        # still looks up the temp dir via ``ray.utils.get_user_temp_dir``;
        # confirm when upgrading ray.
        ray.utils.get_user_temp_dir = get_user_temp_dir

        # Model checkpoint paths to be saved.
        config["system"]["model_save_dir"] = os.path.join(
            config["system"]["root_dir"],
            config["system"]["checkpoint_dir"],
            config["system"]["model_run_id"],
        )
        ensureDir(config["system"]["model_save_dir"])
        print("Model checkpoint will save in file:",
              config["system"]["model_save_dir"])

        config["system"]["result_file"] = os.path.join(
            config["system"]["root_dir"],
            config["system"]["result_dir"],
            config["system"]["result_file"],
        )
        print("Performance result will save in file:",
              config["system"]["result_file"])

        print_dict_as_table(config["system"], "System configs")
        return config