Example #1
    def save_model(self, output_model: ModelEntity):
        """
        Save the model after training is completed.
        """
        config = self.get_config()
        model_info = {
            "model": self.model.state_dict(),
            "config": config,
            "VERSION": 1,
        }
        buffer = io.BytesIO()
        torch.save(model_info, buffer)
        output_model.set_data("weights.pth", buffer.getvalue())
        output_model.set_data(
            "label_schema.json",
            label_schema_to_bytes(self.task_environment.label_schema))
        # store computed threshold
        output_model.set_data(
            "threshold", struct.pack("f", self.model.threshold.item()))

        f1_score = self.model.image_metrics.OptimalF1.compute().item()
        output_model.performance = Performance(
            score=ScoreMetric(name="F1 Score", value=f1_score))
        output_model.precision = [ModelPrecision.FP32]
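
The following is a minimal sketch of how the artifacts written by save_model could be read back. It assumes ModelEntity exposes a get_data counterpart to the set_data calls above; saved_model is a hypothetical, already-populated ModelEntity, and the keys mirror the ones used in the example.

import io
import struct

import torch

# `saved_model` is a hypothetical ModelEntity populated by save_model() above.
weights_buffer = io.BytesIO(saved_model.get_data("weights.pth"))
model_info = torch.load(weights_buffer, map_location="cpu")

state_dict = model_info["model"]    # weights for a freshly constructed model
config = model_info["config"]       # configuration stored alongside the weights
version = model_info["VERSION"]     # serialization version written above (1)

# The threshold was packed as a single 32-bit float.
(threshold,) = struct.unpack("f", saved_model.get_data("threshold"))
label_schema_bytes = saved_model.get_data("label_schema.json")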
Example #2
    def train(self,
              dataset: DatasetEntity,
              output_model: ModelEntity,
              train_parameters: Optional[TrainParameters] = None):
        """ Trains a model on a dataset """

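        # Train on a deep copy so self._model is only replaced by the best checkpoint found below.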
        train_model = deepcopy(self._model)

        if train_parameters is not None:
            update_progress_callback = train_parameters.update_progress
        else:
            update_progress_callback = default_progress_callback
        time_monitor = TrainingProgressCallback(
            update_progress_callback,
            num_epoch=self._cfg.train.max_epoch,
            num_train_steps=math.ceil(
                len(dataset.get_subset(Subset.TRAINING)) /
                self._cfg.train.batch_size),
            num_val_steps=0,
            num_test_steps=0)

        self.metrics_monitor = DefaultMetricsMonitor()
        self.stop_callback.reset()

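        # Fix the random seed and register the train/validation subsets with the torchreid data pipeline.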
        set_random_seed(self._cfg.train.seed)
        train_subset = dataset.get_subset(Subset.TRAINING)
        val_subset = dataset.get_subset(Subset.VALIDATION)
        self._cfg.custom_datasets.roots = [
            OTEClassificationDataset(
                train_subset,
                self._labels,
                self._multilabel,
                keep_empty_label=self._empty_label in self._labels),
            OTEClassificationDataset(
                val_subset,
                self._labels,
                self._multilabel,
                keep_empty_label=self._empty_label in self._labels)
        ]
        datamanager = torchreid.data.ImageDataManager(
            **imagedata_kwargs(self._cfg))

        num_aux_models = len(self._cfg.mutual_learning.aux_configs)

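        # On GPU, wrap the model in DataParallel; every auxiliary (mutual-learning) model reuses the same device list.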
        if self._cfg.use_gpu:
            main_device_ids = list(range(self.num_devices))
            extra_device_ids = [main_device_ids for _ in range(num_aux_models)]
            train_model = DataParallel(train_model,
                                       device_ids=main_device_ids,
                                       output_device=0).cuda(
                                           main_device_ids[0])
        else:
            extra_device_ids = [None for _ in range(num_aux_models)]

        optimizer = torchreid.optim.build_optimizer(
            train_model, **optimizer_kwargs(self._cfg))

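        # When the LR finder is enabled, the scheduler is not built here; run_lr_finder returns one together with the tuned model and optimizer.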
        if self._cfg.lr_finder.enable:
            scheduler = None
        else:
            scheduler = torchreid.optim.build_lr_scheduler(
                optimizer,
                num_iter=datamanager.num_iter,
                **lr_scheduler_kwargs(self._cfg))

        if self._cfg.lr_finder.enable:
            _, train_model, optimizer, scheduler = run_lr_finder(
                self._cfg,
                datamanager,
                train_model,
                optimizer,
                scheduler,
                None,
                rebuild_model=False,
                gpu_num=self.num_devices,
                split_models=False)

        _, final_acc = run_training(self._cfg,
                                    datamanager,
                                    train_model,
                                    optimizer,
                                    scheduler,
                                    extra_device_ids,
                                    self._cfg.train.lr,
                                    tb_writer=self.metrics_monitor,
                                    perf_monitor=time_monitor,
                                    stop_callback=self.stop_callback)

        training_metrics = self._generate_training_metrics_group()

        self.metrics_monitor.close()
        if self.stop_callback.check_stop():
            logger.info('Training cancelled.')
            return

        logger.info("Training finished.")

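        # Reload the best checkpoint into self._model and record the paths of the best auxiliary model snapshots.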
        best_snap_path = os.path.join(self._scratch_space, 'best.pth')
        if os.path.isfile(best_snap_path):
            load_pretrained_weights(self._model, best_snap_path)

        for filename in os.listdir(self._scratch_space):
            match = re.match(r'best_(aux_model_[0-9]+\.pth)', filename)
            if match:
                aux_model_name = match.group(1)
                best_aux_snap_path = os.path.join(self._scratch_space,
                                                  filename)
                self._aux_model_snap_paths[aux_model_name] = best_aux_snap_path

        self.save_model(output_model)
        performance = Performance(score=ScoreMetric(value=final_acc,
                                                    name="accuracy"),
                                  dashboard_metrics=training_metrics)
        logger.info(f'FINAL MODEL PERFORMANCE {performance}')
        output_model.performance = performance
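
As a usage illustration, the sketch below drives the train method with a custom progress callback. The objects task, dataset and environment are assumed to exist already (the training task shown above, a DatasetEntity with TRAINING/VALIDATION subsets, and the corresponding TaskEnvironment), and the ModelEntity construction follows the common OTE SDK pattern; treat the exact signatures as assumptions.

def print_progress(progress, score=None):
    # Receives progress updates (as a percentage) from TrainingProgressCallback.
    print(f"training progress: {progress:.1f}%")

train_parameters = TrainParameters()
train_parameters.update_progress = print_progress

# `dataset` and `environment` are pre-built; the output model collects the weights and metrics.
output_model = ModelEntity(dataset, environment.get_model_configuration())
task.train(dataset, output_model, train_parameters)

print(output_model.performance)  # accuracy score plus the training dashboard metrics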