Пример #1
0
 def _get_callbacks(self):
     """Build the list of callbacks used while training.

     The returned (new) list contains, in this order:
       * every callback produced by ``self._get_extra_callbacks()``;
       * a ``ModelCheckpoint`` writing to ``get_weights_path(<output>)`` that
         keeps only the best weights according to ``self.monitor_metric`` /
         ``self.monitor_mode``;
       * an ``EarlyStopping`` on that same metric/mode, configured with
         ``self.early_stopping_patience`` and ``self.early_stopping_min_delta``;
       * a ``CSVLogger`` writing to ``get_history_path(<output>)``;
       * a ``TensorBoard`` logger under ``get_tensorboard_logdir(self.task_id)``
         with graph writing disabled;
       * finally, only when ``self.gradient_norm_clipping`` is truthy, a
         ``GradientNormClipping`` callback built from
         ``self.gradient_norm_clipping`` and ``self.gradient_norm_clipping_type``.
     """
     extra_callbacks = list(self._get_extra_callbacks())

     checkpoint = ModelCheckpoint(
         get_weights_path(self.output().path),
         save_best_only=True,
         monitor=self.monitor_metric,
         mode=self.monitor_mode,
     )
     early_stopping = EarlyStopping(
         patience=self.early_stopping_patience,
         min_delta=self.early_stopping_min_delta,
         monitor=self.monitor_metric,
         mode=self.monitor_mode,
     )
     csv_logger = CSVLogger(get_history_path(self.output().path))
     tensorboard = TensorBoard(
         get_tensorboard_logdir(self.task_id), write_graph=False
     )

     result = extra_callbacks + [checkpoint, early_stopping, csv_logger, tensorboard]

     # Gradient clipping is optional; a falsy setting disables it entirely.
     if not self.gradient_norm_clipping:
         return result

     clipper = GradientNormClipping(
         self.gradient_norm_clipping, self.gradient_norm_clipping_type
     )
     return result + [clipper]
Пример #2
0
    def _save_trial_log(self, i, trial) -> None:
        os.makedirs(os.path.join(self.output().path, "plot_history"),
                    exist_ok=True)

        if trial:
            history_df = pd.read_csv(get_history_path(self.output().path))
            plot_history(history_df).savefig(
                os.path.join(self.output().path, "plot_history",
                             "history_{}.jpg".format(i)))
            self._save_score_log(i, trial)
Пример #3
0
    def train(self):
        """Run the full training pipeline for this task.

        Steps, in order:
          1. When ``self.device == "cuda"``, select ``self.device_id`` as the
             current CUDA device.
          2. Build the train and validation loaders, then the module.
          3. Print ``self.train_data_frame.describe()``.
          4. Write the model summary into ``<output>/summary.txt`` (stdout is
             redirected to the file), then call ``summary`` a second time
             outside the redirect so it also reaches the regular stdout.
          5. Save 100 rows of the training frame, sampled with replacement, to
             ``<output>/sample_train.csv``.
          6. Create the trial and run it for ``self.epochs`` epochs; a
             ``KeyboardInterrupt`` stops training early but the remaining
             steps still run.
          7. Plot the CSV history to ``<output>/history.jpg``.
          8. Call ``after_fit``, ``evaluate`` and ``cache_cleanup``.
        """
        if self.device == "cuda":
            torch.cuda.set_device(self.device_id)

        # Keyword arguments are evaluated left to right, so the train loader
        # is still created before the validation loader.
        loaders = dict(
            train_generator=self.get_train_generator(),
            val_generator=self.get_val_generator(),
        )
        module = self.create_module()

        print("train_data_frame:")
        print(self.train_data_frame.describe())

        with open(os.path.join(self.output().path, "summary.txt"), "w") as summary_file:
            with redirect_stdout(summary_file):
                sample_input = self.get_sample_batch()
                summary(module, sample_input)
            # Same summary again, this time to the un-redirected stdout.
            summary(module, sample_input)

        self.train_data_frame.sample(100, replace=True).to_csv(
            os.path.join(self.output().path, "sample_train.csv")
        )

        trial = self.create_trial(module)
        try:
            trial.with_generators(**loaders).run(epochs=self.epochs)
        except KeyboardInterrupt:
            # Ctrl-C ends training early; keep going with plotting/evaluation.
            print("Finishing the training at the request of the user...")

        history_df = pd.read_csv(get_history_path(self.output().path))
        figure = plot_history(history_df)
        figure.savefig(os.path.join(self.output().path, "history.jpg"))

        self.after_fit()
        self.evaluate()
        self.cache_cleanup()