Esempio n. 1
0
    def _save_trial_log(self, i, trial) -> None:
        os.makedirs(os.path.join(self.output().path, "plot_history"),
                    exist_ok=True)

        if trial:
            history_df = pd.read_csv(get_history_path(self.output().path))
            plot_history(history_df).savefig(
                os.path.join(self.output().path, "plot_history",
                             "history_{}.jpg".format(i)))
            self._save_score_log(i, trial)
Esempio n. 2
0
    def train(self):
        """Run the complete training pipeline for this task.

        In order:

        1. Select the CUDA device when ``self.device`` is ``"cuda"``.
        2. Build the train/validation loaders and the module.
        3. Print ``describe()`` statistics of ``self.train_data_frame``.
        4. Write the model summary to ``<output>/summary.txt`` and echo it to
           stdout.
        5. Dump a 100-row sample (drawn with replacement) of the training
           frame to ``sample_train.csv``.
        6. Run the trial for ``self.epochs`` epochs.
        7. Save a plot of the recorded history to ``history.jpg``.
        8. Call ``after_fit``, ``evaluate`` and ``cache_cleanup``.

        A ``KeyboardInterrupt`` during training ends training early but still
        lets the remaining steps run.
        """
        if self.device == "cuda":
            torch.cuda.set_device(self.device_id)

        train_loader = self.get_train_generator()
        val_loader = self.get_val_generator()
        module = self.create_module()

        print("train_data_frame:")
        print(self.train_data_frame.describe())

        output_dir = self.output().path

        # First pass is captured into summary.txt (together with anything
        # get_sample_batch prints); second pass goes to the console.
        summary_path = os.path.join(output_dir, "summary.txt")
        with open(summary_path, "w") as fh, redirect_stdout(fh):
            sample_input = self.get_sample_batch()
            summary(module, sample_input)
        summary(module, sample_input)

        sample_rows = self.train_data_frame.sample(100, replace=True)
        sample_rows.to_csv(os.path.join(output_dir, "sample_train.csv"))

        trial = self.create_trial(module)

        try:
            configured = trial.with_generators(train_generator=train_loader,
                                               val_generator=val_loader)
            configured.run(epochs=self.epochs)
        except KeyboardInterrupt:
            # Ctrl+C stops training only; results so far are still processed.
            print("Finishing the training at the request of the user...")

        history_path = get_history_path(output_dir)
        # NOTE(review): the figure returned by plot_history is never closed;
        # if it is a matplotlib Figure it stays open after saving.
        history_figure = plot_history(pd.read_csv(history_path))
        history_figure.savefig(os.path.join(output_dir, "history.jpg"))

        self.after_fit()
        self.evaluate()
        self.cache_cleanup()