Ejemplo n.º 1
0
    def save_metrics_value(self):
        """Append the trained model's metric values to the step's csv file.

        Writes one row (encoding, flops, parameters, metric) to the
        performance csv under the trainer's local output path, creating the
        file with a header on first write, then backs up the whole local
        output folder when a backup path is configured.  Returns nothing.
        """
        # NOTE(review): the "performace.csv" typo is kept on purpose —
        # other components may already read the file under this name.
        pd_path = FileOps.join_path(self.trainer.local_output_path,
                                    self.trainer.step_name, "performace.csv")
        FileOps.make_base_dir(pd_path)
        # Encoding is the concatenation of weight and activation bit widths.
        encoding = self.model.nbit_w_list + self.model.nbit_a_list
        df = pd.DataFrame(
            [[encoding, self.flops_count, self.params_count, self.metric]],
            columns=[
                "encoding", "flops", "parameters",
                self.cfg.get("valid_metric", "acc")
            ])
        # First write creates the file with a header; later writes append
        # header-less rows.  Single code path instead of duplicated branches.
        is_new = not os.path.exists(pd_path)
        with open(pd_path, "w" if is_new else "a") as file:
            df.to_csv(file, index=False, header=is_new)
        if self.trainer.backup_base_path is not None:
            FileOps.copy_folder(self.trainer.local_output_path,
                                self.trainer.backup_base_path)
Ejemplo n.º 2
0
    def update(self, worker_path):
        """Back up the local base path after a worker finishes.

        :param worker_path: the worker_path that saved `performance.txt`.
        :type worker_path: str
        """
        if self.backup_base_path is None:
            return
        FileOps.copy_folder(self.local_base_path, self.backup_base_path)
Ejemplo n.º 3
0
 def _save_best_model(self):
     """Persist the current best model weights, then back up the worker dir."""
     best_path = FileOps.join_path(self.trainer.get_local_worker_path(),
                                   self.trainer.step_name, "best_model.pth")
     FileOps.make_base_dir(best_path)
     torch.save(self.model.state_dict(), best_path)
     if self.trainer.backup_base_path is None:
         return
     backup_dst = FileOps.join_path(self.trainer.backup_base_path, "workers",
                                    str(self.trainer.worker_id))
     FileOps.copy_folder(self.trainer.get_local_worker_path(), backup_dst)
Ejemplo n.º 4
0
    def save_backup(self, performance):
        """Save checkpoints and performance file to backup path.

        Writes `performance.txt` into the local worker folder, then copies
        that folder to the backup location.  No-op when no backup path is
        configured.

        :param performance: validated performance
        :type performance: float, list or dict
        """
        if self.backup_base_path is None:
            return
        # Use FileOps.join_path for consistency with the rest of this file
        # (the other path joins here all go through FileOps).
        pfm_file = FileOps.join_path(self.get_local_worker_path(),
                                     'performance.txt')
        with open(pfm_file, 'w') as f:
            f.write("{}".format(performance))
        backup_worker_path = FileOps.join_path(self.backup_base_path,
                                               self.get_worker_subpath())
        FileOps.copy_folder(self.get_local_worker_path(), backup_worker_path)
Ejemplo n.º 5
0
 def update(self, worker_result_path):
     """Update the SpNas sampler with one finished worker's performance.

     :param worker_result_path: result directory of the finished worker.
     :type worker_result_path: str
     """
     performance_file = self.performance_path(worker_result_path)
     logging.info(
         "SpNas.update(), performance file={}".format(performance_file))
     info = FileOps.load_pickle(performance_file)
     if info is not None:
         self._total_list.append(info)
     else:
         # load_pickle returned None — the file is missing or unreadable.
         # (Message fixed: the original said "file is not exited".)
         logging.info("SpNas.update(), performance file does not exist, "
                      "performance file={}".format(performance_file))
     self.save_output(self.local_output_path)
     if self.backup_base_path is not None:
         FileOps.copy_folder(self.local_output_path, self.backup_base_path)
Ejemplo n.º 6
0
    def update(self, worker_result_path):
        """Use train and evaluate result to update algorithm.

        :param worker_result_path: current result path
        :type: str
        """
        # Layout is .../<step_name>/<config_id>; recover both from the path.
        parent_dir = os.path.dirname(worker_result_path)
        step_name = os.path.basename(parent_dir)
        config_id = int(os.path.basename(worker_result_path))
        performance = self._get_performance(step_name, config_id)
        logging.info("update performance={}".format(performance))
        self.pareto_front.add_pareto_score(config_id, performance)
        self.save_output(self.local_output_path)
        if self.backup_base_path is None:
            return
        FileOps.copy_folder(self.local_base_path, self.backup_base_path)
Ejemplo n.º 7
0
 def _init_dataloader(self):
     """Init dataloader from timm.

     Builds `self.trainer.train_loader` / `self.trainer.valid_loader` with
     timm-style `create_loader`, expecting `<data_dir>/train` and
     `<data_dir>/val` directory layouts.
     """
     # Under Horovod, only local rank 0 downloads the remote dataset;
     # every rank then waits at hvd.join() so data is present before use.
     if self.distributed and hvd.local_rank(
     ) == 0 and 'remote_data_dir' in self.config.dataset:
         FileOps.copy_folder(self.config.dataset.remote_data_dir,
                             self.config.dataset.data_dir)
     if self.distributed:
         hvd.join()
     args = self.config.dataset
     train_dir = os.path.join(self.config.dataset.data_dir, 'train')
     dataset_train = Dataset(train_dir)
     # world_size/rank stay None for single-process runs; create_loader
     # presumably treats None as "no sharding" — TODO confirm.
     world_size, rank = None, None
     if self.distributed:
         world_size, rank = hvd.size(), hvd.rank()
     # Training loader: augmentation (random erase, color jitter, auto
     # augment) enabled, random interpolation for resize.
     self.trainer.train_loader = create_loader(
         dataset_train,
         input_size=tuple(args.input_size),
         batch_size=args.batch_size,
         is_training=True,
         use_prefetcher=self.config.prefetcher,
         rand_erase_prob=args.reprob,
         rand_erase_mode=args.remode,
         rand_erase_count=args.recount,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
         interpolation='random',
         mean=tuple(args.mean),
         std=tuple(args.std),
         num_workers=args.workers,
         distributed=self.distributed,
         world_size=world_size,
         rank=rank)
     valid_dir = os.path.join(self.config.dataset.data_dir, 'val')
     dataset_eval = Dataset(valid_dir)
     # Validation loader: no augmentation, larger batch (4x) since there
     # is no backward pass.
     self.trainer.valid_loader = create_loader(
         dataset_eval,
         input_size=tuple(args.input_size),
         batch_size=4 * args.batch_size,
         is_training=False,
         use_prefetcher=self.config.prefetcher,
         interpolation=args.interpolation,
         mean=tuple(args.mean),
         std=tuple(args.std),
         num_workers=args.workers,
         distributed=self.distributed,
         world_size=world_size,
         rank=rank)
Ejemplo n.º 8
0
    def update(self, step_name, worker_id):
        """Record one worker's score on the HPO score board.

        :param step_name: step name in pipeline
        :param worker_id: worker id of worker
        """
        worker_id = str(worker_id)
        performance = self._get_performance(step_name, worker_id)
        if worker_id not in self._hps_cache:
            logger.error("worker_id not in hps_cache.")
            return
        hps = self._hps_cache[worker_id][0]
        self._hps_cache[worker_id][1] = copy.deepcopy(performance)
        logging.info("get hps need to update, worker_id=%s, hps=%s", worker_id, str(hps))
        self.update_performance(hps, performance)
        logging.info("hpo_id=%s, hps=%s, performance=%s", worker_id, str(hps), str(performance))
        self._save_hpo_cache()
        self._save_score_board()
        self._save_best()
        if self.need_backup and self.backup_base_path is not None:
            FileOps.copy_folder(self.local_output_path,
                                FileOps.join_path(self.backup_base_path, self.output_subpath))
        logger.info("Hpo update finished.")
Ejemplo n.º 9
0
 def update(self, worker_path):
     """Update QuantEA."""
     if self.backup_base_path is None:
         return
     FileOps.copy_folder(self.local_output_path, self.backup_base_path)