Example #1
 def _get_model_desc(self):
     model_desc = self.trainer.model_desc
     if not model_desc:
         if ModelConfig.model_desc_file is not None:
             desc_file = ModelConfig.model_desc_file
             desc_file = desc_file.replace("{local_base_path}",
                                           self.trainer.local_base_path)
             if ":" not in desc_file:
                 desc_file = os.path.abspath(desc_file)
             if ":" in desc_file:
                 local_desc_file = FileOps.join_path(
                     self.trainer.local_output_path,
                     os.path.basename(desc_file))
                 FileOps.copy_file(desc_file, local_desc_file)
                 desc_file = local_desc_file
             model_desc = Config(desc_file)
             logger.info("net_desc:{}".format(model_desc))
         elif ModelConfig.model_desc is not None:
             model_desc = ModelConfig.model_desc
         elif ModelConfig.models_folder is not None:
             folder = ModelConfig.models_folder.replace(
                 "{local_base_path}", self.trainer.local_base_path)
             pattern = FileOps.join_path(folder, "desc_*.json")
             desc_file = glob.glob(pattern)[0]
             model_desc = Config(desc_file)
     return model_desc
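The "{local_base_path}" token above is a placeholder substituted at run time, and the `":" in desc_file` test appears to treat colon-bearing paths (e.g. s3://...) as remote ones that must be copied locally first. A minimal standalone sketch of the substitution, with hypothetical paths:

desc_file = "{local_base_path}/output/desc_0.json"    # hypothetical template
local_base_path = "/home/user/tasks/my_task"          # hypothetical base path
resolved = desc_file.replace("{local_base_path}", local_base_path)
# resolved == "/home/user/tasks/my_task/output/desc_0.json"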
Example #2
File: div2k.py  Project: huawei-noah/vega
    def dataset_init(self):
        """Costruct method, which will load some dataset information."""
        self.args.root_HR = FileOps.download_dataset(self.args.root_HR)
        self.args.root_LR = FileOps.download_dataset(self.args.root_LR)
        if self.args.subfile is not None:
            with open(self.args.subfile) as f:  # lmdb format has no self.args.subfile
                file_names = sorted([line.rstrip('\n') for line in f])
                self.datatype = util.get_files_datatype(file_names)
                self.paths_HR = [
                    os.path.join(self.args.root_HR, file_name)
                    for file_name in file_names
                ]
                self.paths_LR = [
                    os.path.join(self.args.root_LR, file_name)
                    for file_name in file_names
                ]
        else:
            self.datatype = util.get_datatype(self.args.root_LR)
            self.paths_LR = util.get_paths_from_dir(self.args.root_LR)
            self.paths_HR = util.get_paths_from_dir(self.args.root_HR)

        if self.args.save_in_memory:
            self.imgs_LR = [self._read_img(path) for path in self.paths_LR]
            self.imgs_HR = [self._read_img(path) for path in self.paths_HR]
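A minimal standalone sketch of the subfile branch above, assuming the subfile is a newline-separated list of image names shared by the HR and LR roots (all names hypothetical):

import os

def build_pairs(subfile, root_hr, root_lr):
    # build aligned HR/LR path lists from a newline-separated name list
    with open(subfile) as f:
        names = sorted(line.rstrip('\n') for line in f)
    paths_hr = [os.path.join(root_hr, name) for name in names]
    paths_lr = [os.path.join(root_lr, name) for name in names]
    return paths_hr, paths_lr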
Example #3
    def before_train(self, logs=None):
        """Call before_train of the managed callbacks; called before the training process."""
        super().before_train(logs)
        hpo_result = FileOps.load_pickle(
            FileOps.join_path(self.trainer.local_output_path,
                              'best_config.pickle'))
        logging.info("loading stage1_hpo_result \n{}".format(hpo_result))

        feature_interaction_score = hpo_result['feature_interaction_score']
        print('feature_interaction_score:', feature_interaction_score)
        sorted_pairs = sorted(feature_interaction_score.items(),
                              key=lambda x: abs(x[1]),
                              reverse=True)

        if ModelConfig.model_desc:
            fis_ratio = ModelConfig.model_desc["custom"]["fis_ratio"]
        else:
            fis_ratio = 1.0
        top_k = int(len(feature_interaction_score) * min(1.0, fis_ratio))
        self.selected_pairs = list(map(lambda x: x[0], sorted_pairs[:top_k]))

        # add selected_pairs
        setattr(ModelConfig.model_desc['custom'], 'selected_pairs',
                self.selected_pairs)
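A minimal standalone sketch of the selection above: pairs are ranked by absolute interaction score and the top fis_ratio fraction is kept (scores hypothetical):

scores = {('a', 'b'): 0.9, ('a', 'c'): -0.5, ('b', 'c'): 0.1}    # hypothetical
sorted_pairs = sorted(scores.items(), key=lambda x: abs(x[1]), reverse=True)
top_k = int(len(scores) * min(1.0, 0.67))                        # fis_ratio = 0.67
selected_pairs = [pair for pair, _ in sorted_pairs[:top_k]]
# selected_pairs == [('a', 'b'), ('a', 'c')]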
Example #4
    def dataset_init(self):
        """Construct method.

        If both data_dir and label_dir are provided, use data_dir and label_dir.
        Otherwise, use data_path and list_file.
        """
        if "data_dir" in self.args and "label_dir" in self.args:
            self.args.data_dir = FileOps.download_dataset(self.args.data_dir)
            self.args.label_dir = FileOps.download_dataset(self.args.label_dir)
            self.data_files = sorted(glob.glob(osp.join(self.args.data_dir, "*")))
            self.label_files = sorted(glob.glob(osp.join(self.args.label_dir, "*")))
        else:
            if "data_path" not in self.args or "list_file" not in self.args:
                raise Exception("You must provide a data_path and a list_file!")
            self.args.data_path = FileOps.download_dataset(self.args.data_path)
            with open(osp.join(self.args.data_path, self.args.list_file)) as f:
                lines = f.readlines()
            self.data_files = [None] * len(lines)
            self.label_files = [None] * len(lines)
            for i, line in enumerate(lines):
                data_file_name, label_file_name = line.strip().split()
                self.data_files[i] = osp.join(self.args.data_path, data_file_name)
                self.label_files[i] = osp.join(self.args.data_path, label_file_name)

        datatype = self._get_datatype()
        if datatype == "image":
            self.read_fn = self._read_item_image
        else:
            self.read_fn = self._read_item_pickle
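The split() above implies that each line of list_file holds a data file name and a label file name separated by whitespace. A hedged illustration with hypothetical names:

line = "images/0001.png labels/0001.png\n"    # hypothetical list_file line
data_file_name, label_file_name = line.strip().split()
# data_file_name == "images/0001.png", label_file_name == "labels/0001.png"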
Example #5
    def _save_checkpoint(self, epoch, best=False):
        """Save model weights.

        :param epoch: current epoch
        :type epoch: int
        """
        save_dir = os.path.join(self.worker_path, str(epoch))
        FileOps.make_dir(save_dir)
        for name in self.model.model_names:
            if isinstance(name, str):
                save_filename = '%s_net_%s.pth' % (epoch, name)
                save_path = FileOps.join_path(save_dir, save_filename)
                net = getattr(self.model, 'net' + name)
                best_file = FileOps.join_path(self.worker_path,
                                              "model_{}.pth".format(name))
                if vega.is_gpu_device() and torch.cuda.is_available():
                    # torch.save(net.module.cpu().state_dict(), save_path)
                    torch.save(net.module.state_dict(), save_path)
                    # net.cuda()
                    if best:
                        torch.save(net.module.state_dict(), best_file)
                elif vega.is_npu_device():
                    torch.save(net.state_dict(), save_path)
                    if best:
                        torch.save(net.state_dict(), best_file)
                else:
                    torch.save(net.cpu().state_dict(), save_path)
                    if best:
                        torch.save(net.cpu().state_dict(), best_file)
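The three branches above follow one pattern: on GPU the net is wrapped (e.g. by DataParallel), so net.module is saved; on NPU and CPU the net itself is saved. A minimal sketch of that unwrapping, assuming a module attribute marks a wrapped net:

def unwrap(net, on_gpu):
    # return the object whose state_dict should be saved (sketch only)
    return net.module if on_gpu and hasattr(net, 'module') else net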
Example #6
    def save_report(self, records):
        """Save report to `reports.json`."""
        try:
            _file = FileOps.join_path(TaskOps().local_output_path,
                                      "reports.json")
            FileOps.make_base_dir(_file)
            data = {"_steps_": []}

            for step in self.step_names:
                if step in self.steps:
                    data["_steps_"].append(self.steps[step])
                else:
                    data["_steps_"].append({
                        "step_name": step,
                        "status": Status.unstarted
                    })

            for record in records:
                if record.step_name in data:
                    data[record.step_name].append(record.to_dict())
                else:
                    data[record.step_name] = [record.to_dict()]
            with open(_file, "w") as f:
                json.dump(data, f, indent=4, cls=JsonEncoder)
        except Exception:
            logging.warning(traceback.format_exc())
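A hypothetical shape of the resulting reports.json, with illustrative step names and record fields:

example_report = {
    "_steps_": [
        {"step_name": "nas", "status": "finished"},
        {"step_name": "fullytrain", "status": "unstarted"},
    ],
    "nas": [  # one record.to_dict() per worker; fields illustrative
        {"worker_id": "1", "performance": {"accuracy": 0.93}},
    ],
}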
Example #7
    def after_valid(self, logs=None):
        """Call after_valid of the managed callbacks."""
        self.model = self.trainer.model
        feature_interaction_score = self.model.get_feature_interaction_score()
        print('get feature_interaction_score', feature_interaction_score)
        feature_interaction = []
        for feature in feature_interaction_score:
            if abs(feature_interaction_score[feature]) > 0:
                feature_interaction.append(feature)
        print('get feature_interaction', feature_interaction)

        curr_auc = float(self.trainer.valid_metrics.results['auc'])
        if curr_auc > self.best_score:
            best_config = {
                'score': curr_auc,
                'feature_interaction': feature_interaction
            }

            logging.info("BEST CONFIG IS\n{}".format(best_config))
            pickle_result_file = FileOps.join_path(
                self.trainer.local_output_path, 'best_config.pickle')
            logging.info("Saved to {}".format(pickle_result_file))
            FileOps.dump_pickle(best_config, pickle_result_file)

            self.best_score = curr_auc
Example #8
 def _backup(self):
     """Backup result worker folder."""
     if self.need_backup and self.backup_base_path is not None:
         backup_worker_path = FileOps.join_path(self.backup_base_path,
                                                self.get_worker_subpath())
         FileOps.copy_folder(
             self.get_local_worker_path(self.step_name, self.worker_id),
             backup_worker_path)
Example #9
 def __init__(self, **kwargs):
     """Construct the Imagenet class."""
     Dataset.__init__(self, **kwargs)
     self.args.data_path = FileOps.download_dataset(self.args.data_path)
     split = 'train' if self.mode == 'train' else 'val'
     local_data_path = FileOps.join_path(self.args.data_path, split)
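     # remove the 'loader' attribute so ImageFolder.__init__ can install its default loader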
     delattr(self, 'loader')
     ImageFolder.__init__(self,
                          root=local_data_path,
                          transform=Compose(self.transforms.__transform__))
Example #10
    def before_train(self, logs=None):
        """Be called before the whole train process."""
        self.trainer.config.call_metrics_on_train = False
        self.cfg = self.trainer.config
        self.worker_id = self.trainer.worker_id
        self.local_base_path = self.trainer.local_base_path
        self.local_output_path = self.trainer.local_output_path

        self.result_path = FileOps.join_path(self.trainer.local_base_path,
                                             "result")
        FileOps.make_dir(self.result_path)
        self.logger_patch()
Example #11
    def before_train(self, logs=None):
        """Call before_train of the managed callbacks; called before the training process."""
        super().before_train(logs)
        hpo_result = FileOps.load_pickle(FileOps.join_path(
            self.trainer.local_output_path, 'best_config.pickle'))
        logging.info("loading stage1_hpo_result \n{}".format(hpo_result))

        self.selected_pairs = hpo_result['feature_interaction']
        logging.info('feature_interaction: {}'.format(self.selected_pairs))

        # add selected_pairs
        setattr(ModelConfig.model_desc['custom'], 'selected_pairs', self.selected_pairs)
Example #12
    def save_results(self):
        """Save the results of evolution contains the information of pupulation and elitism."""
        _path = FileOps.join_path(self.local_output_path, General.step_name)
        FileOps.make_dir(_path)
        arch_file = FileOps.join_path(_path, 'arch.txt')
        arch_child = FileOps.join_path(_path, 'arch_child.txt')
        sel_arch_file = FileOps.join_path(_path, 'selected_arch.npy')
        sel_arch = []
        with open(arch_file, 'a') as fw_a, open(arch_child, 'a') as fw_ac:
            writer_a = csv.writer(fw_a, lineterminator='\n')
            writer_ac = csv.writer(fw_ac, lineterminator='\n')
            writer_ac.writerow(
                ['Population Iteration: ' + str(self.evolution_count + 1)])
            for c in range(self.individual_num):
                writer_ac.writerow(
                    self._log_data(net_info_type='active_only',
                                   pop=self.pop[c],
                                   value=self.pop[c].fitness))

            writer_a.writerow(
                ['Population Iteration: ' + str(self.evolution_count + 1)])
            for c in range(self.elitism_num):
                writer_a.writerow(
                    self._log_data(net_info_type='active_only',
                                   pop=self.elitism[c],
                                   value=self.elit_fitness[c]))
                sel_arch.append(self.elitism[c].gene)
        sel_arch = np.stack(sel_arch)
        np.save(sel_arch_file, sel_arch)
        if self.backup_base_path is not None:
            FileOps.copy_folder(self.local_output_path, self.backup_base_path)
Example #13
 def _init_dataloader(self):
     """Init dataloader from timm."""
     if self.distributed and hvd.local_rank() == 0 \
             and 'remote_data_dir' in self.config.dataset:
         FileOps.copy_folder(self.config.dataset.remote_data_dir,
                             self.config.dataset.data_dir)
     if self.distributed:
         hvd.join()
     args = self.config.dataset
     train_dir = os.path.join(self.config.dataset.data_dir, 'train')
     dataset_train = Dataset(train_dir)
     world_size, rank = None, None
     if self.distributed:
         world_size, rank = hvd.size(), hvd.rank()
     self.trainer.train_loader = create_loader(
         dataset_train,
         input_size=tuple(args.input_size),
         batch_size=args.batch_size,
         is_training=True,
         use_prefetcher=self.config.prefetcher,
         rand_erase_prob=args.reprob,
         rand_erase_mode=args.remode,
         rand_erase_count=args.recount,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
         interpolation='random',
         mean=tuple(args.mean),
         std=tuple(args.std),
         num_workers=args.workers,
         distributed=self.distributed,
         world_size=world_size,
         rank=rank)
     valid_dir = os.path.join(self.config.dataset.data_dir, 'val')
     dataset_eval = Dataset(valid_dir)
     self.trainer.valid_loader = create_loader(
         dataset_eval,
         input_size=tuple(args.input_size),
         batch_size=4 * args.batch_size,
         is_training=False,
         use_prefetcher=self.config.prefetcher,
         interpolation=args.interpolation,
         mean=tuple(args.mean),
         std=tuple(args.std),
         num_workers=args.workers,
         distributed=self.distributed,
         world_size=world_size,
         rank=rank)
     self.trainer.batch_num_train = len(self.trainer.train_loader)
     self.trainer.batch_num_valid = len(self.trainer.valid_loader)
Example #14
 def _get_current_step_records(self):
     step_name = General.step_name
     models_folder = PipeStepConfig.pipe_step.get("models_folder")
     cur_index = PipelineConfig.steps.index(step_name)
     if cur_index >= 1 or models_folder:
         if not models_folder:
             models_folder = FileOps.join_path(
                 TaskOps().local_output_path,
                 PipelineConfig.steps[cur_index - 1])
         models_folder = models_folder.replace("{local_base_path}",
                                               TaskOps().local_base_path)
         records = ReportServer().load_records_from_model_folder(
             models_folder)
     else:
         records = self._load_single_model_records()
     final_records = []
     for record in records:
         if not record.weights_file:
             logger.error("Model file is not existed, id={}".format(
                 record.worker_id))
         else:
             record.step_name = General.step_name
             final_records.append(record)
     logging.debug("Records: {}".format(final_records))
     return final_records
Example #15
File: utils.py  Project: huawei-noah/vega
def load_master_ip():
    """Get the ip and port that write in a system path.

    here will not download anything from S3.
    """
    temp_folder = TaskOps().temp_path
    FileOps.make_dir(temp_folder)
    file_path = os.path.join(temp_folder, 'ip_address.txt')
    if os.path.isfile(file_path):
        with open(file_path, 'r') as f:
            ip = f.readline().strip()
            port = f.readline().strip()
            logging.info("get write ip, ip={}, port={}".format(ip, port))
            return ip, port
    else:
        return None, None
Example #16
File: utils.py  Project: huawei-noah/vega
def save_master_ip(ip_address, port, args):
    """Write the ip and port in a system path.

    :param str ip_address: the `ip_address` to write.
    :param str port: the `port` to write.
    :param argparse.ArgumentParser args: an argparse namespace that should
         contain `init_method`, `rank` and `world_size`.

    """
    temp_folder = TaskOps().temp_path
    FileOps.make_dir(temp_folder)
    file_path = os.path.join(temp_folder, 'ip_address.txt')
    logging.info("write ip, file path={}".format(file_path))
    with open(file_path, 'w') as f:
        f.write(ip_address + "\n")
        f.write(port + "\n")
Example #17
 def _save_pb_model(self, weight_file, model_id):
     from tensorflow.python.framework import graph_util
     valid_data = self.trainer.valid_loader.input_fn()
     iterator = valid_data.make_one_shot_iterator()
     one_element = iterator.get_next()
     with tf.Session() as sess:
         batch = sess.run(one_element)
     input_shape = batch[0].shape
     with tf.Graph().as_default():
         input_holder_shape = (None, ) + tuple(input_shape[1:])
         input_holder = tf.placeholder(dtype=tf.float32,
                                       shape=input_holder_shape)
         self.trainer.model.training = False
         output = self.trainer.model(input_holder)
         if isinstance(output, tuple):
             output_name = [output[0].name.split(":")[0]]
         else:
             output_name = [output.name.split(":")[0]]
         with tf.Session() as sess:
             sess.run(tf.global_variables_initializer())
             if weight_file is not None:
                 saver = tf.train.Saver()
                 last_weight_file = tf.train.latest_checkpoint(weight_file)
                 if last_weight_file:
                     saver.restore(sess, last_weight_file)
             constant_graph = graph_util.convert_variables_to_constants(
                 sess, sess.graph_def, output_name)
             output_graph = FileOps.join_path(weight_file,
                                              '{}.pb'.format(model_id))
             with tf.gfile.FastGFile(output_graph, mode='wb') as f:
                 f.write(constant_graph.SerializeToString())
Example #18
    def __init__(self, **kwargs):
        """Construct the dataset."""
        super().__init__(**kwargs)
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        dataset_pairs = dict(train=create_train_subset(self.args.data_path),
                             test=create_test_subset(self.args.data_path),
                             val=create_test_subset(self.args.data_path))

        if self.mode not in dataset_pairs.keys():
            raise NotImplementedError(
                f'mode should be one of {dataset_pairs.keys()}')
        self.image_annot_path_pairs = dataset_pairs.get(self.mode)

        self.codec_obj = PointLaneCodec(input_width=512,
                                        input_height=288,
                                        anchor_stride=16,
                                        points_per_line=72,
                                        class_num=2)
        self.encode_lane = self.codec_obj.encode_lane
        read_funcs = dict(
            CULane=_read_culane_type_annot,
            CurveLane=_read_curvelane_type_annot,
        )
        if self.args.dataset_format not in read_funcs:
            raise NotImplementedError(
                f'dataset_format should be one of {read_funcs.keys()}')
        self.read_annot = read_funcs.get(self.args.dataset_format)
        self.with_aug = self.args.get('with_aug', False)
Example #19
    def __init__(self, **kwargs):
        """Construct the Cifar10 class."""
        Dataset.__init__(self, **kwargs)
        self.args.data_path = FileOps.download_dataset(self.args.data_path)
        is_train = self.mode == 'train' or (
            self.mode == 'val' and self.args.train_portion < 1)
        self.base_folder = 'cifar-100-python'
        if is_train:
            files_list = ["train"]
        else:
            files_list = ['test']

        self.data = []
        self.targets = []

        # now load the pickled numpy arrays
        for file_name in files_list:
            file_path = os.path.join(self.args.data_path, self.base_folder,
                                     file_name)
            with open(file_path, 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                if 'labels' in entry:
                    self.targets.extend(entry['labels'])
                else:
                    self.targets.extend(entry['fine_labels'])

        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = self.data.transpose((0, 2, 3, 1))  # convert to HWC
Example #20
    def update(self, record):
        """Update current performance into hpo score board.

        :param hps: hyper parameters need to update
        :param performance:  trainer performance
        """
        super().update(record)
        config_id = str(record.get('worker_id'))
        step_name = record.get('step_name')
        worker_result_path = self.get_local_worker_path(step_name, config_id)
        new_worker_result_path = FileOps.join_path(self.local_base_path, 'cache', config_id, 'checkpoint')
        FileOps.make_dir(worker_result_path)
        FileOps.make_dir(new_worker_result_path)
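        # shutil.copytree requires the destination not to exist, so clear it first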
        if os.path.exists(new_worker_result_path):
            shutil.rmtree(new_worker_result_path)
        shutil.copytree(worker_result_path, new_worker_result_path)
Example #21
    def load_records_from_model_folder(cls, model_folder):
        """Transfer json_file to records."""
        if not model_folder or not os.path.exists(model_folder):
            logging.error("Failed to load records from model folder, folder={}".format(model_folder))
            return []
        records = []
        pattern = FileOps.join_path(model_folder, "desc_*.json")
        files = glob.glob(pattern)
        for _file in files:
            try:
                with open(_file) as f:
                    worker_id = _file.split(".")[-2].split("_")[-1]
                    weights_file = os.path.join(os.path.dirname(_file), "model_{}".format(worker_id))
                    if vega.is_torch_backend():
                        weights_file = '{}.pth'.format(weights_file)
                    elif vega.is_ms_backend():
                        weights_file = '{}.ckpt'.format(weights_file)
                    if not os.path.exists(weights_file):
                        weights_file = None

                    sample = dict(worker_id=worker_id, desc=json.load(f), weights_file=weights_file)
                    record = ReportRecord().load_dict(sample)
                    records.append(record)
            except Exception as ex:
                logging.info('Cannot read records from json because {}'.format(ex))
        return records
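A minimal sketch of the file-name convention assumed above: for a descriptor named desc_12.json, the worker id is the segment after the last underscore (path hypothetical):

_file = "/output/nas/desc_12.json"                  # hypothetical path
worker_id = _file.split(".")[-2].split("_")[-1]     # "12"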
Example #22
    def _new_model_init(self):
        """Init new model.

        :return: initial model after loading pretrained model
        :rtype: torch.nn.Module
        """
        init_model_file = self.config.init_model_file
        if ":" in init_model_file:
            local_path = FileOps.join_path(
                self.trainer.get_local_worker_path(),
                os.path.basename(init_model_file))
            FileOps.copy_file(init_model_file, local_path)
            self.config.init_model_file = local_path
        network_desc = copy.deepcopy(self.base_net_desc)
        network_desc.backbone.cfgs = network_desc.backbone.base_cfgs
        model_init = NetworkDesc(network_desc).to_model()
        return model_init
Example #23
 def get_pareto_list_size(self):
     """Get the number of pareto list."""
     pareto_list_size = 0
     pareto_file_locate = FileOps.join_path(self.local_base_path, "result",
                                            "pareto_front.csv")
     if os.path.exists(pareto_file_locate):
         pareto_front_df = pd.read_csv(pareto_file_locate)
         pareto_list_size = pareto_front_df.size
     return pareto_list_size
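Note that pandas.DataFrame.size counts cells (rows x columns), not rows; a hedged sketch of the distinction, in case the number of pareto points is what is wanted:

import pandas as pd

df = pd.DataFrame({"accuracy": [0.9, 0.8], "flops": [1.0, 2.0]})
df.size    # 4 == rows * columns
len(df)    # 2 == number of rows, i.e. pareto points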
Example #24
 def _saved_multi_checkpoint(self, epoch):
     """Save multi tasks checkpoint."""
     FileOps.make_dir(self.trainer.get_local_worker_path(),
                      self.trainer.multi_task)
     checkpoint_file = FileOps.join_path(
         self.trainer.get_local_worker_path(), self.trainer.multi_task,
         self.trainer.checkpoint_file_name)
     logging.debug("Start Save Multi Task Model, model_file=%s",
                   self.trainer.model_pickle_file_name)
     if vega.is_torch_backend():
         ckpt = {
             'epoch': epoch,
             'weight': self.trainer.model.state_dict(),
             'optimizer': self.trainer.optimizer.state_dict(),
             'lr_scheduler': self.trainer.lr_scheduler.state_dict(),
         }
         torch.save(ckpt, checkpoint_file)
     self.trainer.checkpoint_file = checkpoint_file
Example #25
 def _save_descript(self):
     """Save result descript."""
     template_file = self.config.darts_template_file
     genotypes = self.search_alg.codec.calc_genotype(
         self._get_arch_weights())
     if template_file == "{default_darts_cifar10_template}":
         template = DartsNetworkTemplateConfig.cifar10
     elif template_file == "{default_darts_cifar100_template}":
         template = DartsNetworkTemplateConfig.cifar100
     elif template_file == "{default_darts_imagenet_template}":
         template = DartsNetworkTemplateConfig.imagenet
     else:
         dst = FileOps.join_path(self.trainer.get_local_worker_path(),
                                 os.path.basename(template_file))
         FileOps.copy_file(template_file, dst)
         template = Config(dst)
     model_desc = self._gen_model_desc(genotypes, template)
     self.trainer.config.codec = model_desc
Example #26
File: mnist.py  Project: huawei-noah/vega
 def __init__(self, **kwargs):
     """Construct the Mnist class."""
     Dataset.__init__(self, **kwargs)
     self.args.data_path = FileOps.download_dataset(self.args.data_path)
     MNIST.__init__(self,
                    root=self.args.data_path,
                    train=self.train,
                    transform=self.transforms,
                    download=self.args.download)
Example #27
    def dataset_init(self):
        """Initialize dataset."""
        self.args.HR_dir = FileOps.download_dataset(self.args.HR_dir)
        self.args.LR_dir = FileOps.download_dataset(self.args.LR_dir)
        self.Y_paths = sorted(self.make_dataset(
            self.args.LR_dir,
            float("inf"))) if self.args.LR_dir is not None else None
        self.HR_paths = sorted(
            self.make_dataset(
                self.args.HR_dir,
                float("inf"))) if self.args.HR_dir is not None else None

        self.trans_norm = transforms.Compose(
            [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        for i in range(len(self.HR_paths)):
            file_name = os.path.basename(self.HR_paths[i])
            if (file_name.find("0401") >= 0):
                logging.info(
                    "We find the possion of NO. 401 in the HR patch NO. {}".
                    format(i))
                self.HR_paths = self.HR_paths[:i]
                break
        for i in range(len(self.Y_paths)):
            file_name = os.path.basename(self.Y_paths[i])
            if (file_name.find("0401") >= 0):
                logging.info(
                    "We find the possion of NO. 401 in the LR patch NO. {}".
                    format(i))
                self.Y_paths = self.Y_paths[i:]
                break

        self.Y_size = len(self.Y_paths)
        if self.train:
            self.load_size = self.args.load_size
            self.crop_size = self.args.crop_size
            self.upscale = self.args.upscale
            self.augment_transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip()
            ])
            self.HR_transform = transforms.RandomCrop(
                int(self.crop_size * self.upscale))
            self.LR_transform = transforms.RandomCrop(self.crop_size)
Example #28
 def _load_checkpoint(self):
     """Load checkpoint."""
     if vega.is_torch_backend():
         if hasattr(self.trainer.config, "checkpoint_path"):
             checkpoint_path = self.trainer.config.checkpoint_path
         else:
             checkpoint_path = self.trainer.get_local_worker_path()
         checkpoint_file = FileOps.join_path(
             checkpoint_path, self.trainer.checkpoint_file_name)
         if os.path.exists(checkpoint_file):
             try:
                 logging.info("Load checkpoint file, file={}".format(
                     checkpoint_file))
                 checkpoint = torch.load(checkpoint_file)
                 if self.trainer.multi_task:
                     self.trainer.model.load_state_dict(
                         checkpoint["weight"], strict=False)
                     multi_task_checkpoint = torch.load(
                         FileOps.join_path(
                             checkpoint_path, self.trainer.multi_task,
                             self.trainer.checkpoint_file_name))
                     self.trainer.optimizer.load_state_dict(
                         multi_task_checkpoint["optimizer"])
                     self.trainer.lr_scheduler.load_state_dict(
                         multi_task_checkpoint["lr_scheduler"])
                 else:
                     self.trainer.model.load_state_dict(
                         checkpoint["weight"])
                     self.trainer.optimizer.load_state_dict(
                         checkpoint["optimizer"])
                 self.trainer.lr_scheduler.load_state_dict(
                     checkpoint["lr_scheduler"])
                 if self.trainer._resume_training:
                     # epoch = checkpoint["epoch"]
                     self.trainer._start_epoch = checkpoint["epoch"]
                     logging.info(
                         "Resume fully train, change start epoch to {}".
                         format(self.trainer._start_epoch))
             except Exception as e:
                 logging.info("Load checkpoint failed {}".format(e))
         else:
             logging.info(
                 "Skip loading checkpoint file that does not exist, {}".
                 format(checkpoint_file))
Example #29
 def set_trainer(self, trainer):
     """Set trainer object for current callback."""
     self.trainer = trainer
     self.trainer._train_loop = self._train_loop
     self.cfg = self.trainer.config
     self._worker_id = self.trainer._worker_id
     self.worker_path = self.trainer.get_local_worker_path()
     self.output_path = self.trainer.local_output_path
     self.best_model_name = "model_best"
     self.best_model_file = FileOps.join_path(
         self.worker_path, "model_{}.pth".format(self.trainer.worker_id))
Example #30
 def __init__(self, **kwargs):
     """Init Cifar10."""
     super(Imagenet, self).__init__(**kwargs)
     self.data_path = FileOps.download_dataset(self.args.data_path)
     self.fp16 = self.args.fp16
     self.num_parallel_batches = self.args.num_parallel_batches
     self.image_size = self.args.image_size
     self.drop_remainder = self.args.drop_last
     if self.data_path == 'null' or not self.data_path:
         self.data_path = None
     self.num_parallel_calls = self.args.num_parallel_calls