Example #1
 def _init_model(self):
     """Load model desc from save path and parse to model."""
     model = self.trainer.model
     if self.trainer.config.is_detection_trainer:
         model_desc = self.trainer.model_desc
     else:
         model_desc = self._get_model_desc()
     if model_desc:
         ModelConfig.model_desc = model_desc
     pretrained_model_file = self._get_pretrained_model_file()
     if not model:
         if not model_desc:
             raise Exception(
                 "Failed to init model: cannot get model description.")
         model = ModelZoo.get_model(model_desc, pretrained_model_file)
     if model:
         if zeus.is_torch_backend():
             import torch
             if self.trainer.use_cuda:
                 model = model.cuda()
             if General._parallel and General.devices_per_trainer > 1:
                 # wrap the model built above rather than the stale
                 # self.trainer.model reference
                 model = torch.nn.DataParallel(model)
         if zeus.is_tf_backend():
             if pretrained_model_file:
                 model_folder = os.path.dirname(pretrained_model_file)
                 FileOps.copy_folder(model_folder,
                                     self.trainer.get_local_worker_path())
     return model
Example #2
 def _do_horovod_fully_train(self):
     pwd_dir = os.path.dirname(os.path.abspath(__file__))
     cf_file = os.path.join(pwd_dir, 'cf.pickle')
     cf_content = {
         'registry': ClassFactory.__registry__,
         'general_config': General().to_json(),
         'pipe_step_config': PipeStepConfig().to_json()
     }
     with open(cf_file, 'wb') as f:
         pickle.dump(cf_content, f)
     cf_file_remote = os.path.join(self.task.local_base_path, 'cf.pickle')
     FileOps.copy_file(cf_file, cf_file_remote)
     if os.environ.get('DLS_TASK_NUMBER') is None:
         # local cluster
         worker_ips = '127.0.0.1'
         if General.cluster.master_ip is not None and General.cluster.master_ip != '127.0.0.1':
             worker_ips = General.cluster.master_ip
             for ip in General.cluster.slaves:
                 worker_ips = worker_ips + ',' + ip
         cmd = [
             'bash',
             '{}/horovod/run_cluster_horovod_train.sh'.format(pwd_dir),
             str(self.world_device_size), cf_file_remote, worker_ips
         ]
     else:
         # Roma
         cmd = [
             'bash', '{}/horovod/run_horovod_train.sh'.format(pwd_dir),
             str(self.world_device_size), cf_file_remote
         ]
     proc = subprocess.Popen(cmd, env=os.environ)
     proc.wait()
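
For reference, a minimal sketch of how the spawned horovod worker might restore this pickled context on its side; the file layout follows the snippet above, but the exact restore hook inside the training script (and the ClassFactory import path) is an assumption:

 import pickle

 from zeus.common.class_factory import ClassFactory  # import path assumed

 # Load the context written by _do_horovod_fully_train.
 with open('cf.pickle', 'rb') as f:
     cf_content = pickle.load(f)
 ClassFactory.__registry__ = cf_content['registry']   # re-register classes
 general_config = cf_content['general_config']        # serialized General() config
 pipe_step_config = cf_content['pipe_step_config']    # serialized PipeStepConfig()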
Example #3
    def dataset_init(self):
        """Construct method.

        If both data_dir and label_dir are provided, then use data_dir and label_dir
        Otherwise use root_dir and list_file.
        """
        if "data_dir" in self.args and "label_dir" in self.args:
            self.args.data_dir = FileOps.download_dataset(self.args.data_dir)
            self.args.label_dir = FileOps.download_dataset(self.args.label_dir)
            self.data_files = sorted(glob.glob(osp.join(self.args.data_dir, "*")))
            self.label_files = sorted(glob.glob(osp.join(self.args.label_dir, "*")))
        else:
            if "root_dir" not in self.args or "list_file" not in self.args:
                raise Exception("You must provide a root_dir and a list_file!")
            self.args.root_dir = FileOps.download_dataset(self.args.root_dir)
            with open(osp.join(self.args.root_dir, self.args.list_file)) as f:
                lines = f.readlines()
            self.data_files = [None] * len(lines)
            self.label_files = [None] * len(lines)
            for i, line in enumerate(lines):
                data_file_name, label_file_name = line.strip().split()
                self.data_files[i] = osp.join(self.args.root_dir, data_file_name)
                self.label_files[i] = osp.join(self.args.root_dir, label_file_name)

        datatype = self._get_datatype()
        if datatype == "image":
            self.read_fn = self._read_item_image
        else:
            self.read_fn = self._read_item_pickle
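
Each line of list_file is expected to carry two whitespace-separated relative paths, the data file first and the label file second; this follows directly from the line.strip().split() parsing above. A hypothetical list_file:

    images/0001.png labels/0001.png
    images/0002.png labels/0002.png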
Example #4
 def _save_worker_record(cls, record):
     step_name = record.get('step_name')
     worker_id = record.get('worker_id')
     _path = TaskOps().get_local_worker_path(step_name, worker_id)
     for record_name in ["desc", "performance"]:
         _file_name = None
         _file = None
         record_value = record.get(record_name)
         if not record_value:
             continue
         _file = None
         try:
             # for cars/darts save multi-desc
             if isinstance(record_value, list) and record_name == "desc":
                 for idx, value in enumerate(record_value):
                     _file_name = "desc_{}.json".format(idx)
                     _file = FileOps.join_path(_path, _file_name)
                     with open(_file, "w") as f:
                         json.dump(value, f)
             else:
                 if record_name == "desc":
                     _file_name = "desc_{}.json".format(worker_id)
                 if record_name == "performance":
                     _file_name = "performance_{}.json".format(worker_id)
                 _file = FileOps.join_path(_path, _file_name)
                 with open(_file, "w") as f:
                     json.dump(record_value, f)
         except Exception as ex:
             logging.error(
                 "Failed to save {}, file={}, desc={}, msg={}".format(
                     record_name, _file, record_value, str(ex)))
Example #5
 def _append_record_to_csv(self,
                           record_name=None,
                           step_name=None,
                           record=None,
                           mode='a'):
     """Transfer record to csv file."""
     local_output_path = os.path.join(TaskOps().local_output_path,
                                      step_name)
     logging.debug(
         "record to csv, local_output_path={}".format(local_output_path))
     if not record_name and os.path.exists(local_output_path):
         return
     file_path = os.path.join(local_output_path,
                              "{}.csv".format(record_name))
     FileOps.make_base_dir(file_path)
     try:
         for key in record:
             if isinstance(record[key], dict) or isinstance(
                     record[key], list):
                 record[key] = str(record[key])
         data = pd.DataFrame([record])
         if not os.path.exists(file_path):
             data.to_csv(file_path, index=False)
         elif os.path.getsize(file_path) and mode == 'a':
             # appending to a non-empty file: suppress the header row
             data.to_csv(file_path, index=False, mode=mode, header=False)
         else:
             data.to_csv(file_path, index=False, mode=mode)
     except Exception as ex:
         logging.info(
             'Cannot transfer record to csv file. Error: {}'.format(ex))
Example #6
    def dataset_init(self):
        """Costruct method, which will load some dateset information."""
        self.args.root_HR = FileOps.download_dataset(self.args.root_HR)
        self.args.root_LR = FileOps.download_dataset(self.args.root_LR)
        if self.args.subfile is not None:
            # the lmdb format has no self.args.subfile
            with open(self.args.subfile) as f:
                file_names = sorted([line.rstrip('\n') for line in f])
                self.datatype = util.get_files_datatype(file_names)
                self.paths_HR = [
                    os.path.join(self.args.root_HR, file_name)
                    for file_name in file_names
                ]
                self.paths_LR = [
                    os.path.join(self.args.root_LR, file_name)
                    for file_name in file_names
                ]
        else:
            self.datatype = util.get_datatype(self.args.root_LR)
            self.paths_LR = util.get_paths_from_dir(self.args.root_LR)
            self.paths_HR = util.get_paths_from_dir(self.args.root_HR)

        if self.args.save_in_memory:
            self.imgs_LR = [self._read_img(path) for path in self.paths_LR]
            self.imgs_HR = [self._read_img(path) for path in self.paths_HR]
Example #7
def dump_model_visual_info(trainer, epoch, model, inputs):
    """Dump model to tensorboard event files.

    :param trainer: trainer.
    :type trainer: object inherited from DistributedWorker.
    :param epoch: current epoch.
    :type epoch: int
    :param model: model.
    :type model: model.
    :param inputs: input data.
    :type inputs: data.

    """
    (_, visual, interval, _, worker_id, output_path) = _get_trainer_info(trainer)
    if visual is not True:
        return
    if epoch % interval != 0:
        return
    title = str(worker_id)
    _path = FileOps.join_path(output_path, title)
    FileOps.make_dir(_path)
    try:
        with SummaryWriter(_path) as writer:
            writer.add_graph(model, (inputs,))
    except Exception as e:
        logging.error("Failed to dump model visual info, worker id: {}, epoch: {}, error: {}".format(
            worker_id, epoch, str(e)
        ))
Example #8
 def _get_model_desc(self):
     model_desc = self.trainer.model_desc
     if not model_desc or 'modules' not in model_desc:
         if ModelConfig.model_desc_file is not None:
             desc_file = ModelConfig.model_desc_file
             desc_file = desc_file.replace("{local_base_path}",
                                           self.trainer.local_base_path)
             if ":" not in desc_file:
                 desc_file = os.path.abspath(desc_file)
             if ":" in desc_file:
                 local_desc_file = FileOps.join_path(
                     self.trainer.local_output_path,
                     os.path.basename(desc_file))
                 FileOps.copy_file(desc_file, local_desc_file)
                 desc_file = local_desc_file
             model_desc = Config(desc_file)
             logger.info("net_desc:{}".format(model_desc))
         elif ModelConfig.model_desc is not None:
             model_desc = ModelConfig.model_desc
         elif ModelConfig.models_folder is not None:
             folder = ModelConfig.models_folder.replace(
                 "{local_base_path}", self.trainer.local_base_path)
             pattern = FileOps.join_path(folder, "desc_*.json")
             desc_file = glob.glob(pattern)[0]
             model_desc = Config(desc_file)
         else:
             return None
     return model_desc
Example #9
    def dump(self):
        """Dump report to file."""
        try:
            _file = FileOps.join_path(TaskOps().step_path, "reports.csv")
            FileOps.make_base_dir(_file)
            data = self.all_records
            data_dict = {}
            for step in data:
                step_data = step.serialize().items()
                for k, v in step_data:
                    if k in data_dict:
                        data_dict[k].append(v)
                    else:
                        data_dict[k] = [v]

            data = pd.DataFrame(data_dict)
            data.to_csv(_file, index=False)
            _file = os.path.join(TaskOps().step_path, ".reports")
            _dump_data = [
                ReportServer._hist_records, ReportServer.__instances__
            ]
            with open(_file, "wb") as f:
                pickle.dump(_dump_data, f, protocol=pickle.HIGHEST_PROTOCOL)

            self.backup_output_path()
        except Exception:
            logging.warning(traceback.format_exc())
Example #10
 def _backup(self):
     """Backup result worker folder."""
     if self.need_backup is True and self.backup_base_path is not None:
         backup_worker_path = FileOps.join_path(
             self.backup_base_path, self.get_worker_subpath())
         FileOps.copy_folder(
             self.get_local_worker_path(self.step_name, self.worker_id), backup_worker_path)
Example #11
    def _save_checkpoint(self, epoch, best=False):
        """Save model weights.

        :param epoch: current epoch
        :type epoch: int
        :param best: whether the current model is the best so far
        :type best: bool
        """
        save_dir = os.path.join(self.worker_path, str(epoch))
        FileOps.make_dir(save_dir)
        for name in self.model.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = FileOps.join_path(save_dir, save_filename)
                net = getattr(self.model, 'net' + name)
                best_file = FileOps.join_path(self.worker_path,
                                              "model_{}.pth".format(name))
                if self.cfg.cuda and torch.cuda.is_available():
                    # torch.save(net.module.cpu().state_dict(), save_path)
                    torch.save(net.module.state_dict(), save_path)
                    # net.cuda()
                    if best:
                        torch.save(net.module.state_dict(), best_file)
                else:
                    torch.save(net.cpu().state_dict(), save_path)
                    if best:
                        torch.save(net.cpu().state_dict(), best_file)
Example #12
 def search(self):
     """Search an id and hps from hpo."""
     sample = self.hpo.propose()
     if sample is None:
         return None
     re_hps = {}
     sample = copy.deepcopy(sample)
     sample_id = sample.get('config_id')
     trans_para = sample.get('configs')
     rung_id = sample.get('rung_id')
     all_para = sample.get('all_configs')
     re_hps['dataset.transforms'] = [{
         'type': 'PBATransformer',
         'para_array': trans_para,
         'all_para': all_para,
         'operation_names': self.operation_names
     }]
     checkpoint_path = FileOps.join_path(self.local_base_path, 'worker',
                                         'cache', str(sample_id),
                                         'checkpoint')
     FileOps.make_dir(checkpoint_path)
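     # make_dir has just created the folder, so the check below is only a
     # safeguard in case the directory could not be created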
     if os.path.exists(checkpoint_path):
         re_hps['trainer.checkpoint_path'] = checkpoint_path
     if 'epoch' in sample:
         re_hps['trainer.epochs'] = sample.get('epoch')
     return dict(worker_id=sample_id, encoded_desc=re_hps, rung_id=rung_id)
Example #13
 def _save_best_model(self):
     """Save best model."""
     if zeus.is_torch_backend():
         torch.save(self.trainer.model.state_dict(),
                    self.trainer.weights_file)
     elif zeus.is_tf_backend():
         worker_path = self.trainer.get_local_worker_path()
         model_id = "model_{}".format(self.trainer.worker_id)
         weights_folder = FileOps.join_path(worker_path, model_id)
         FileOps.make_dir(weights_folder)
         checkpoint_file = tf.train.latest_checkpoint(worker_path)
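         # latest_checkpoint returns a path prefix; the glob below collects its
         # companion files (typically the .index, .meta and .data-* shards)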
         ckpt_globs = glob.glob("{}.*".format(checkpoint_file))
         for _file in ckpt_globs:
             dst_file = model_id + os.path.splitext(_file)[-1]
             FileOps.copy_file(_file,
                               FileOps.join_path(weights_folder, dst_file))
         FileOps.copy_file(FileOps.join_path(worker_path, 'checkpoint'),
                           weights_folder)
     elif zeus.is_ms_backend():
         worker_path = self.trainer.get_local_worker_path()
         save_path = os.path.join(
             worker_path, "model_{}.ckpt".format(self.trainer.worker_id))
         for file in os.listdir(worker_path):
             if file.startswith("CKP") and file.endswith(".ckpt"):
                 self.weights_file = FileOps.join_path(worker_path, file)
                 os.rename(self.weights_file, save_path)
Example #14
    def before_train(self, logs=None):
        """Be called before the training process."""
        super().before_train(logs)
        hpo_result = FileOps.load_pickle(
            FileOps.join_path(self.trainer.local_output_path,
                              'best_config.pickle'))
        logging.info("loading stage1_hpo_result \n{}".format(hpo_result))

        feature_interaction_score = hpo_result['feature_interaction_score']
        print('feature_interaction_score:', feature_interaction_score)
        sorted_pairs = sorted(feature_interaction_score.items(),
                              key=lambda x: abs(x[1]),
                              reverse=True)

        if ModelConfig.model_desc:
            fis_ratio = ModelConfig.model_desc["custom"]["fis_ratio"]
        else:
            fis_ratio = 1.0
        top_k = int(len(feature_interaction_score) * min(1.0, fis_ratio))
        self.selected_pairs = list(map(lambda x: x[0], sorted_pairs[:top_k]))

        # add selected_pairs
        setattr(ModelConfig.model_desc['custom'], 'selected_pairs',
                self.selected_pairs)
Example #15
    def after_valid(self, logs=None):
        """Call after_valid of the managed callbacks."""
        self.model = self.trainer.model
        feature_interaction_score = self.model.get_feature_interaction_score()
        print('get feature_interaction_score', feature_interaction_score)
        feature_interaction = []
        for feature in feature_interaction_score:
            if abs(feature_interaction_score[feature]) > 0:
                feature_interaction.append(feature)
        print('get feature_interaction', feature_interaction)

        curr_auc = float(self.trainer.valid_metrics.results['auc'])
        if curr_auc > self.best_score:
            best_config = {
                'score': curr_auc,
                'feature_interaction': feature_interaction
            }

            logging.info("BEST CONFIG IS\n{}".format(best_config))
            pickle_result_file = FileOps.join_path(
                self.trainer.local_output_path, 'best_config.pickle')
            logging.info("Saved to {}".format(pickle_result_file))
            FileOps.dump_pickle(best_config, pickle_result_file)

            self.best_score = curr_auc
Example #16
 def __init__(self, **kwargs):
     """Construct the Imagenet class."""
     Dataset.__init__(self, **kwargs)
     self.args.data_path = FileOps.download_dataset(self.args.data_path)
     split = 'train' if self.mode == 'train' else 'val'
     local_data_path = FileOps.join_path(self.args.data_path, split)
     ImageFolder.__init__(self,
                          root=local_data_path,
                          transform=Compose(self.transforms.__transform__))
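
A hypothetical construction sketch; the kwargs mirror the attributes the snippet reads (self.mode and self.args.data_path), while the concrete values are illustrative and would normally come from the pipeline configuration:

 dataset = Imagenet(mode='train', data_path='/cache/datasets/ILSVRC')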
Example #17
    def before_train(self, logs=None):
        """Be called before the whole train process."""
        self.trainer.config.call_metrics_on_train = False
        self.cfg = self.trainer.config
        self.worker_id = self.trainer.worker_id
        self.local_base_path = self.trainer.local_base_path
        self.local_output_path = self.trainer.local_output_path

        self.result_path = FileOps.join_path(self.trainer.local_base_path, "result")
        FileOps.make_dir(self.result_path)
        self.logger_patch()
Example #18
 def load_model(self):
     """Load model."""
     if not self.model_desc and not self.weights_file:
         saved_folder = self.get_local_worker_path(self.step_name,
                                                   self.worker_id)
         self.weights_file = FileOps.join_path(
             saved_folder, 'model_{}.pth'.format(self.worker_id))
         self.model_desc = FileOps.join_path(
             saved_folder, 'desc_{}.json'.format(self.worker_id))
     if 'modules' not in self.model_desc:
         self.model_desc = ModelConfig.model_desc
     self.model = ModelZoo.get_model(self.model_desc, self.weights_file)
Example #19
    def save_results(self):
        """Save the results of evolution, containing the information of population and elitism."""
        _path = FileOps.join_path(self.local_output_path, General.step_name)
        FileOps.make_dir(_path)
        arch_file = FileOps.join_path(_path, 'arch.txt')
        arch_child = FileOps.join_path(_path, 'arch_child.txt')
        sel_arch_file = FileOps.join_path(_path, 'selected_arch.npy')
        sel_arch = []
        with open(arch_file, 'a') as fw_a, open(arch_child, 'a') as fw_ac:
            writer_a = csv.writer(fw_a, lineterminator='\n')
            writer_ac = csv.writer(fw_ac, lineterminator='\n')
            writer_ac.writerow(
                ['Population Iteration: ' + str(self.evolution_count + 1)])
            for c in range(self.individual_num):
                writer_ac.writerow(
                    self._log_data(net_info_type='active_only',
                                   pop=self.pop[c],
                                   value=self.pop[c].fitness))

            writer_a.writerow(
                ['Population Iteration: ' + str(self.evolution_count + 1)])
            for c in range(self.elitism_num):
                writer_a.writerow(
                    self._log_data(net_info_type='active_only',
                                   pop=self.elitism[c],
                                   value=self.elit_fitness[c]))
                sel_arch.append(self.elitism[c].gene)
        sel_arch = np.stack(sel_arch)
        np.save(sel_arch_file, sel_arch)
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_output_path, self.backup_base_path)
Example #20
 def copy_pareto_output(self, step_name=None, worker_ids=None):
     """Copy files related to pareto from worker to output."""
     worker_ids = worker_ids or []  # avoid a mutable default argument
     taskops = TaskOps()
     local_output_path = os.path.join(taskops.local_output_path, step_name)
     if not (step_name and os.path.exists(local_output_path)):
         return
     for worker_id in worker_ids:
         dest_dir = os.path.join(local_output_path, str(worker_id))
         FileOps.make_dir(dest_dir)
         local_worker_path = taskops.get_worker_subpath(
             step_name, str(worker_id))
         src_dir = FileOps.join_path(taskops.local_base_path,
                                     local_worker_path)
         copy_search_file(src_dir, dest_dir)
Example #21
    def before_train(self, logs=None):
        """Be called before the training process."""
        super().before_train(logs)

        hpo_result = FileOps.load_pickle(FileOps.join_path(
            self.trainer.local_output_path, 'best_config.pickle'))
        logging.info("loading stage1_hpo_result \n{}".format(hpo_result))

        self.selected_pairs = hpo_result['feature_interaction']
        logging.info('feature_interaction: {}'.format(self.selected_pairs))

        # add selected_pairs
        setattr(ModelConfig.model_desc['custom'], 'selected_pairs', self.selected_pairs)
Example #22
 def _init_dataloader(self):
     """Init dataloader from timm."""
     if (self.distributed and hvd.local_rank() == 0
             and 'remote_data_dir' in self.config.dataset):
         FileOps.copy_folder(self.config.dataset.remote_data_dir,
                             self.config.dataset.data_dir)
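     # synchronize the ranks so that training starts only after rank 0
     # has finished copying the dataset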
     if self.distributed:
         hvd.join()
     args = self.config.dataset
     train_dir = os.path.join(self.config.dataset.data_dir, 'train')
     dataset_train = Dataset(train_dir)
     world_size, rank = None, None
     if self.distributed:
         world_size, rank = hvd.size(), hvd.rank()
     self.trainer.train_loader = create_loader(
         dataset_train,
         input_size=tuple(args.input_size),
         batch_size=args.batch_size,
         is_training=True,
         use_prefetcher=self.config.prefetcher,
         rand_erase_prob=args.reprob,
         rand_erase_mode=args.remode,
         rand_erase_count=args.recount,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
         interpolation='random',
         mean=tuple(args.mean),
         std=tuple(args.std),
         num_workers=args.workers,
         distributed=self.distributed,
         world_size=world_size,
         rank=rank)
     valid_dir = os.path.join(self.config.dataset.data_dir, 'val')
     dataset_eval = Dataset(valid_dir)
     self.trainer.valid_loader = create_loader(
         dataset_eval,
         input_size=tuple(args.input_size),
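          # evaluation keeps no gradients or optimizer state, so a larger
          # batch size fits in memory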
         batch_size=4 * args.batch_size,
         is_training=False,
         use_prefetcher=self.config.prefetcher,
         interpolation=args.interpolation,
         mean=tuple(args.mean),
         std=tuple(args.std),
         num_workers=args.workers,
         distributed=self.distributed,
         world_size=world_size,
         rank=rank)
     self.trainer.batch_num_train = len(self.trainer.train_loader)
     self.trainer.batch_num_valid = len(self.trainer.valid_loader)
Example #23
 def _get_pretrained_model_file(self):
     if ModelConfig.pretrained_model_file:
         model_file = ModelConfig.pretrained_model_file
         model_file = model_file.replace("{local_base_path}", self.trainer.local_base_path)
         model_file = model_file.replace("{worker_id}", str(self.trainer.worker_id))
         if ":" not in model_file:
             model_file = os.path.abspath(model_file)
         if ":" in model_file:
             local_model_file = FileOps.join_path(
                 self.trainer.local_output_path, os.path.basename(model_file))
             FileOps.copy_file(model_file, local_model_file)
             model_file = local_model_file
         return model_file
     else:
         return None
Example #24
def save_master_ip(ip_address, port, args):
    """Write the ip and port in a system path.

    :param str ip_address: the ip address to write.
    :param str port: the port to write.
    :param args: parsed arguments, which should contain
         `init_method`, `rank` and `world_size`.

    """
    temp_folder = TaskOps().temp_path
    FileOps.make_dir(temp_folder)
    file_path = os.path.join(temp_folder, 'ip_address.txt')
    logging.info("write ip, file path={}".format(file_path))
    with open(file_path, 'w') as f:
        f.write(ip_address + "\n")
        f.write(port + "\n")
Example #25
 def _load_checkpoint(self):
     """Load checkpoint."""
     if zeus.is_torch_backend():
         checkpoint_file = FileOps.join_path(
             self.trainer.get_local_worker_path(),
             self.trainer.checkpoint_file_name)
         if os.path.exists(checkpoint_file):
             try:
                 logging.info("Load checkpoint file, file={}".format(
                     checkpoint_file))
                 checkpoint = torch.load(checkpoint_file)
                 self.trainer.model.load_state_dict(checkpoint["weight"])
                 self.trainer.optimizer.load_state_dict(
                     checkpoint["optimizer"])
                 self.trainer.lr_scheduler.load_state_dict(
                     checkpoint["lr_scheduler"])
                 if self.trainer._resume_training:
                     self.trainer._start_epoch = checkpoint["epoch"]
                     logging.info(
                         "Resume fully train, change start epoch to {}".
                         format(self.trainer._start_epoch))
             except Exception as e:
                 logging.info("Load checkpoint failed {}".format(e))
         else:
             logging.info('Use default model')
Example #26
 def load_records_from_model_folder(cls, model_folder):
     """Transfer json_file to records."""
     if not model_folder or not os.path.exists(model_folder):
         logging.error(
             "Failed to load records from model folder, folder={}".format(
                 model_folder))
         return []
     records = []
     pattern = FileOps.join_path(model_folder, "desc_*.json")
     files = glob.glob(pattern)
     for _file in files:
         try:
             with open(_file) as f:
                 worker_id = _file.split(".")[-2].split("_")[-1]
                 weights_file = os.path.join(
                     os.path.dirname(_file),
                     "model_{}.pth".format(worker_id))
                 if os.path.exists(weights_file):
                     sample = dict(worker_id=worker_id,
                                   desc=json.load(f),
                                   weights_file=weights_file)
                 else:
                     sample = dict(worker_id=worker_id, desc=json.load(f))
                 record = ReportRecord().load_dict(sample)
                 records.append(record)
         except Exception as ex:
             logging.info(
                 'Cannot read records from json because {}'.format(ex))
     return records
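
As a concrete illustration of the worker_id parsing above (the path is hypothetical):

 >>> _file = "/cache/models/desc_12.json"
 >>> _file.split(".")[-2].split("_")[-1]
 '12'

Note that this parsing assumes the only dot in the path is the .json extension.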
Example #27
    def _generate_init_model(self):
        """Generate init model by loading pretrained model.

        :return: initial model after loading pretrained model
        :rtype: torch.nn.Module
        """
        model_init = self._new_model_init()
        chn_mask = self._init_chn_node_mask()
        if vega.is_torch_backend():
            checkpoint = torch.load(self.config.init_model_file + '.pth')
            model_init.load_state_dict(checkpoint)
            model = PruneMobileNet(model_init).apply(chn_mask)
            model.to(self.device)
        elif vega.is_tf_backend():
            model = model_init
            with tf.compat.v1.Session(
                    config=self.trainer._init_session_config()) as sess:
                saver = tf.compat.v1.train.import_meta_graph("{}.meta".format(
                    self.config.init_model_file))
                saver.restore(sess, self.config.init_model_file)
                all_weight = tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.VARIABLES)
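                # drop the optimizer slot variables (Momentum accumulators)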
                all_weight = [
                    t for t in all_weight if not t.name.endswith('Momentum:0')
                ]
                PruneMobileNet(all_weight).apply(chn_mask)
                save_file = FileOps.join_path(
                    self.trainer.get_local_worker_path(), 'prune_model')
                saver.save(sess, save_file)
        elif vega.is_ms_backend():
            parameter_dict = load_checkpoint(self.config.init_model_file)
            load_param_into_net(model_init, parameter_dict)
            model = PruneMobileNet(model_init).apply(chn_mask)
        return model
Example #28
    def __init__(self, **kwargs):
        """Construct the Cifar10 class."""
        Dataset.__init__(self, **kwargs)
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        is_train = self.mode == 'train' or (self.mode == 'val' and self.args.train_portion < 1)
        self.base_folder = 'cifar-10-batches-py'
        self.transform = Compose(self.transforms.__transform__)
        if is_train:
            files_list = [
                "data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4",
                "data_batch_5"
            ]
        else:
            files_list = ['test_batch']

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name in files_list:
            file_path = os.path.join(self.args.data_path, self.base_folder,
                                     file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
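
A quick sanity check of the resulting array shape; each CIFAR-10 row is a flat 3072-value vector, so the reshape recovers CHW images and the transpose converts them to HWC:

 >>> self.data.shape  # for the five training batches
 (50000, 32, 32, 3)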
Example #29
 def _get_current_step_records(self):
     step_name = General.step_name
     models_folder = PipeStepConfig.pipe_step.get("models_folder")
     cur_index = PipelineConfig.steps.index(step_name)
     if cur_index >= 1 or models_folder:
         # records = Report().get_step_records(PipelineConfig.steps[cur_index - 1])
         if not models_folder:
             models_folder = FileOps.join_path(
                 TaskOps().local_output_path,
                 PipelineConfig.steps[cur_index - 1])
         models_folder = models_folder.replace("{local_base_path}",
                                               TaskOps().local_base_path)
         records = Report().load_records_from_model_folder(models_folder)
     else:
         records = self._load_single_model_records()
     final_records = []
     for record in records:
         if not record.weights_file:
             logger.error("Model file does not exist, id={}".format(
                 record.worker_id))
         else:
             record.step_name = General.step_name
             final_records.append(record)
     logging.debug("Records: {}".format(final_records))
     return final_records
Example #30
    def load_model(self):
        """Load model."""
        self.saved_folder = self.get_local_worker_path(self.step_name, self.worker_id)
        if not self.model_desc:
            self.model_desc = FileOps.join_path(self.saved_folder, 'desc_{}.json'.format(self.worker_id))
        if not self.weights_file:
            if zeus.is_torch_backend():
                self.weights_file = FileOps.join_path(self.saved_folder, 'model_{}.pth'.format(self.worker_id))
            elif zeus.is_ms_backend():
                for file in os.listdir(self.saved_folder):
                    if file.startswith("CKP") and file.endswith(".ckpt"):
                        self.weights_file = FileOps.join_path(self.saved_folder, file)

        if 'modules' not in self.model_desc:
            self.model_desc = ModelConfig.model_desc
        self.model = ModelZoo.get_model(self.model_desc, self.weights_file)