def start(self, config: dict = None, **kwargs):
    """
    Start the evaluation session: load the weights saved by a previous
    training session and run a single validation pass over the validation set.

    :param config: evaluation specific configuration; missing values are
        filled in from the base configuration via ``fill_model_config``
    :param kwargs: additional arguments forwarded to ``print_summary``
    :raises ValueError: if the referenced session's weight file does not exist
        or ``eval_session_id`` does not name a training session
    """
    eval_session_path = os.path.join(self._base_config.out_path, self._base_config.eval_session_id,
                                     "checkpoints", f"{self._base_config.eval_session_id}_weights.pt")
    if not os.path.exists(eval_session_path) or not self._base_config.eval_session_id.startswith("train"):
        raise ValueError(f"Session path '{eval_session_path}' does not exist or is not a training session.")

    config = fill_model_config(config, self._base_config)
    batch_processor = config.get("batch_processor", None)
    if batch_processor is None:
        batch_processor = get_batch_processor_from_config(self._base_config, config)

    validation_data = self._load_data(config.get("test_batch_size", self._base_config.test_batch_size))
    # noinspection PyUnresolvedReferences
    data_shape = validation_data.dataset.get_input_shape()
    # noinspection PyUnresolvedReferences
    num_classes = validation_data.dataset.get_num_classes()

    model, loss_function, _, _ = self._build_model(config, data_shape, num_classes)
    # map_location="cpu" makes the checkpoint loadable on CPU-only machines even
    # if it was saved from a CUDA device; load_state_dict then copies the values
    # onto whatever device the model's parameters already live on.
    model.load_state_dict(torch.load(eval_session_path, map_location="cpu"))
    progress = self._build_logging(len(validation_data))

    # skeleton_joints, = import_dataset_constants(self._base_config.dataset, ["skeleton_joints"])
    # with open(os.path.join(self._base_config.input_data[0][0], "val_files.pkl"), "rb") as f:
    #     sample_labels = pickle.load(f)
    metrics = self.build_metrics(num_classes, class_labels=self._base_config.class_labels, additional_metrics=[
        # MisclassifiedSamplesList("validation-sample-list", sample_labels, self._base_config.class_labels),
        F1MeasureMetric("validation-f1-measure"),
        # GlobalDynamicAdjacency("validation_global_dynamic_adjacency", "adj_b", labels=skeleton_joints),
        # target_indices
        # 177 - 20_s2_t2_skeleton.mat - knock (18) / catch (19)
        # 345 - 4_s6_t4_skeleton.mat - arm_cross (5) / clap (3)
        # 386 - 7_s4_t1_skeleton.mat - tennis_serve (16) / basketball_shoot (6)
        # DataDependentAdjacency("validation_data_dependent_adjacency", labels=skeleton_joints,
        #                        target_indices=[177, 345, 386])
    ])

    if progress:
        self.print_summary(model, **kwargs)
        print("Training configuration:", config)
        progress.begin_session(self.session_type)
        self.save_base_configuration()

    if progress:
        progress.begin_epoch(0)
        progress.begin_epoch_mode(0)
    # Single validation pass, reported as epoch 0; validate_epoch itself
    # tolerates progress being None.
    Session.validate_epoch(batch_processor, model, loss_function, validation_data, progress, metrics, 0)

    # Save confusion matrix
    np.save(os.path.join(self.out_path, "validation-confusion.npy"),
            metrics["validation_confusion"].value.numpy())

    if progress:
        progress.end_epoch(metrics)
        progress.end_session()
def start(self, config: dict = None, **kwargs):
    """
    Start the training session.

    This function is not modifying object state in any way due to being
    called multiple times during hyperparameter tuning.

    :param config: training specific configuration; missing values are
        filled in from the base configuration via ``fill_model_config``
    :param kwargs: additional arguments (``reporter`` is consumed here,
        the rest is forwarded to ``print_summary``)
    """
    config = fill_model_config(config, self._base_config)
    # Pop "reporter" so it is not forwarded to print_summary via **kwargs.
    # NOTE(review): presumably a hyperparameter-tuning callback — confirm with callers.
    reporter = kwargs.pop("reporter", None)
    batch_processor = config.get("batch_processor", None)
    if batch_processor is None:
        batch_processor = get_batch_processor_from_config(
            self._base_config, config)

    # Per-run overrides in `config` take precedence over the base configuration.
    epochs = config.get("epochs", self._base_config.epochs)
    training_data, validation_data = self._load_data(
        config.get("batch_size", self._base_config.batch_size),
        config.get("test_batch_size", self._base_config.test_batch_size))
    # noinspection PyUnresolvedReferences
    data_shape = training_data.dataset.get_input_shape()
    # noinspection PyUnresolvedReferences
    num_classes = training_data.dataset.get_num_classes()
    # Scheduler arguments depend on the total step count (epochs x batches per epoch).
    config = session_helper.prepare_learning_rate_scheduler_args(
        config, epochs, len(training_data))
    model, loss_function, optimizer, lr_scheduler = self._build_model(
        config, data_shape, num_classes)
    # state_dict_objects are the objects whose state the checkpoint manager persists.
    progress, cp_manager = self._build_logging(batch_processor, epochs,
                                               len(training_data),
                                               len(validation_data),
                                               state_dict_objects={
                                                   "model": model,
                                                   "optimizer": optimizer,
                                                   "loss_function": loss_function,
                                                   "lr_scheduler": lr_scheduler
                                               })
    metrics = self.build_metrics(
        num_classes, class_labels=self._base_config.class_labels)

    if progress:
        self.print_summary(model, **kwargs)
        print("Training configuration:", config)
        progress.begin_session(self.session_type)
        self.save_base_configuration()

    for epoch in range(epochs):
        # Begin epoch: record the learning rate in effect for this epoch.
        lr = optimizer.param_groups[0]["lr"]
        metrics["lr"].update(lr)
        if progress:
            progress.begin_epoch(epoch)

        # Training for current epoch (epoch mode 0 = training)
        if progress:
            progress.begin_epoch_mode(0)
        Session.train_epoch(batch_processor, model, loss_function,
                            training_data, optimizer, progress, metrics)

        # Validation for current epoch (epoch mode 1 = validation)
        if progress:
            progress.begin_epoch_mode(1)
        Session.validate_epoch(batch_processor, model, loss_function,
                               validation_data, progress, metrics)

        # Finalize epoch: scheduler steps once per epoch, after validation.
        if progress:
            progress.end_epoch(metrics)
        if lr_scheduler:
            lr_scheduler.step()
        val_loss = metrics["validation_loss"].value
        val_acc = metrics["validation_accuracy"].value
        if reporter:
            reporter(mean_loss=val_loss, mean_accuracy=val_acc, lr=lr)
        # Checkpoint with validation accuracy as the selection score.
        if cp_manager:
            cp_manager.save_checkpoint(epoch, val_acc)
        # Metrics accumulate per epoch; reset before the next one.
        metrics.reset_all()

    # Save weights at the end of training
    if cp_manager:
        cp_manager.save_weights(model, self.session_id)
    if progress:
        progress.end_session()