Example #1
    def start(self, config: dict = None, **kwargs):
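        """
        Start the evaluation session: load the weights of a completed training
        session, run a single validation pass over the evaluation data and
        write the resulting metrics (including the confusion matrix) to the
        session output path.

        :param config: evaluation-specific configuration
        :param kwargs: additional arguments
        """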
        eval_session_path = os.path.join(self._base_config.out_path, self._base_config.eval_session_id, "checkpoints",
                                         f"{self._base_config.eval_session_id}_weights.pt")

        if not os.path.exists(eval_session_path) or not self._base_config.eval_session_id.startswith("train"):
            raise ValueError(f"Session path '{eval_session_path}' does not exist or is not a training session.")

        config = fill_model_config(config, self._base_config)

        batch_processor = config.get("batch_processor", None)
        if batch_processor is None:
            batch_processor = get_batch_processor_from_config(self._base_config, config)
        validation_data = self._load_data(config.get("test_batch_size", self._base_config.test_batch_size))

        # noinspection PyUnresolvedReferences
        data_shape = validation_data.dataset.get_input_shape()
        # noinspection PyUnresolvedReferences
        num_classes = validation_data.dataset.get_num_classes()

        model, loss_function, _, _ = self._build_model(config, data_shape, num_classes)
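        # Restore the weights produced by the referenced training session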
        model.load_state_dict(torch.load(eval_session_path))
        progress = self._build_logging(len(validation_data))

        # skeleton_joints, = import_dataset_constants(self._base_config.dataset, ["skeleton_joints"])
        # with open(os.path.join(self._base_config.input_data[0][0], "val_files.pkl"), "rb") as f:
        #     sample_labels = pickle.load(f)

        metrics = self.build_metrics(num_classes, class_labels=self._base_config.class_labels, additional_metrics=[
            # MisclassifiedSamplesList("validation-sample-list", sample_labels, self._base_config.class_labels),
            F1MeasureMetric("validation-f1-measure"),
            # GlobalDynamicAdjacency("validation_global_dynamic_adjacency", "adj_b", labels=skeleton_joints),
            # target_indices
            # 177 - 20_s2_t2_skeleton.mat - knock (18) / catch (19)
            # 345 - 4_s6_t4_skeleton.mat - arm_cross (5) / clap (3)
            # 386 - 7_s4_t1_skeleton.mat - tennis_serve (16) / basketball_shoot (6)
            # DataDependentAdjacency("validation_data_dependent_adjacency", labels=skeleton_joints,
            #                        target_indices=[177, 345, 386])
        ])

        if progress:
            self.print_summary(model, **kwargs)
            print("Evaluation configuration:", config)
            progress.begin_session(self.session_type)

        self.save_base_configuration()

        if progress:
            progress.begin_epoch(0)
            progress.begin_epoch_mode(0)

        Session.validate_epoch(batch_processor, model, loss_function, validation_data, progress, metrics, 0)

        # Save confusion matrix
        np.save(os.path.join(self.out_path, "validation-confusion.npy"), metrics["validation_confusion"].value.numpy())

        if progress:
            progress.end_epoch(metrics)
            progress.end_session()
Example #2
    def start(self, config: dict = None, **kwargs):
        """
        Start the training session. This function is not modifying object state in any way due to
        being called multiple times during hyperparameter tuning.

        :param config: training specific configuration
        :param kwargs: additional arguments
        """

        config = fill_model_config(config, self._base_config)

        reporter = kwargs.pop("reporter", None)
        batch_processor = config.get("batch_processor", None)
        if batch_processor is None:
            batch_processor = get_batch_processor_from_config(
                self._base_config, config)
        epochs = config.get("epochs", self._base_config.epochs)
        training_data, validation_data = self._load_data(
            config.get("batch_size", self._base_config.batch_size),
            config.get("test_batch_size", self._base_config.test_batch_size))

        # noinspection PyUnresolvedReferences
        data_shape = training_data.dataset.get_input_shape()
        # noinspection PyUnresolvedReferences
        num_classes = training_data.dataset.get_num_classes()

        # Fill in learning-rate scheduler arguments based on the epoch count
        # and the number of training batches per epoch.
        config = session_helper.prepare_learning_rate_scheduler_args(
            config, epochs, len(training_data))
        model, loss_function, optimizer, lr_scheduler = self._build_model(
            config, data_shape, num_classes)
        progress, cp_manager = self._build_logging(
            batch_processor,
            epochs,
            len(training_data),
            len(validation_data),
            state_dict_objects={
                "model": model,
                "optimizer": optimizer,
                "loss_function": loss_function,
                "lr_scheduler": lr_scheduler,
            })
        metrics = self.build_metrics(
            num_classes, class_labels=self._base_config.class_labels)

        if progress:
            self.print_summary(model, **kwargs)
            print("Training configuration:", config)
            progress.begin_session(self.session_type)

        self.save_base_configuration()

        for epoch in range(epochs):
            # Begin epoch
            lr = optimizer.param_groups[0]["lr"]
            metrics["lr"].update(lr)
            if progress:
                progress.begin_epoch(epoch)

            # Training for current epoch
            if progress:
                progress.begin_epoch_mode(0)
            Session.train_epoch(batch_processor, model, loss_function,
                                training_data, optimizer, progress, metrics)

            # Validation for current epoch
            if progress:
                progress.begin_epoch_mode(1)
            Session.validate_epoch(batch_processor, model, loss_function,
                                   validation_data, progress, metrics)

            # Finalize epoch
            if progress:
                progress.end_epoch(metrics)

            if lr_scheduler:
                lr_scheduler.step()

            val_loss = metrics["validation_loss"].value
            val_acc = metrics["validation_accuracy"].value

            if reporter:
                reporter(mean_loss=val_loss, mean_accuracy=val_acc, lr=lr)
            if cp_manager:
                cp_manager.save_checkpoint(epoch, val_acc)
            metrics.reset_all()

        # Save weights at the end of training
        if cp_manager:
            cp_manager.save_weights(model, self.session_id)

        if progress:
            progress.end_session()
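
The training variant in Example #2 accepts an optional `reporter` callback via `**kwargs` and invokes it once per epoch with the validation loss, validation accuracy, and current learning rate. A minimal sketch of such a callback follows; the concrete session class is not shown in this excerpt, so the invocation at the end is only indicated as a comment:

    # Collect the per-epoch validation metrics reported by the training loop above.
    history = []

    def reporter(mean_loss, mean_accuracy, lr):
        # Signature matches the call reporter(mean_loss=..., mean_accuracy=..., lr=...)
        history.append({"loss": mean_loss, "accuracy": mean_accuracy, "lr": lr})

    # Hypothetical invocation (session class and config keys are assumptions):
    # session.start(config={"epochs": 50}, reporter=reporter)
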