示例#1
0
 def load_records_from_model_folder(cls, model_folder):
     """Build report records from the ``desc_*.json`` files of a folder.

     The worker id is parsed from each description file name. When a
     sibling ``model_<worker_id>.pth`` exists, its path is attached to the
     record as ``weights_file``. Files that cannot be read are logged and
     skipped.

     :param model_folder: folder that holds the description files
     :return: list of loaded records; empty if the folder is unset or
         does not exist
     """
     folder_missing = not model_folder or not os.path.exists(model_folder)
     if folder_missing:
         logging.error(
             "Failed to load records from model folder, folder={}".format(
                 model_folder))
         return []
     loaded = []
     desc_paths = glob.glob(FileOps.join_path(model_folder, "desc_*.json"))
     for desc_path in desc_paths:
         try:
             with open(desc_path) as handle:
                 # "…/desc_<id>.json" -> "<id>"
                 worker_id = desc_path.split(".")[-2].split("_")[-1]
                 weights_file = os.path.join(
                     os.path.dirname(desc_path),
                     "model_{}.pth".format(worker_id))
                 sample = dict(worker_id=worker_id, desc=json.load(handle))
                 # Weights are optional: only reference them when present.
                 if os.path.exists(weights_file):
                     sample["weights_file"] = weights_file
                 loaded.append(ReportRecord().load_dict(sample))
         except Exception as ex:
             logging.info(
                 'Can not read records from json because {}'.format(ex))
     return loaded
示例#2
0
 def _evaluate_esr_models(self, esr_models_file, models_folder):
     """Evaluate each architecture listed in an ESR models file.

     Both path arguments may contain the ``{local_base_path}`` placeholder,
     which is replaced with the task's local base path before the paths are
     made absolute. Architecture ``i`` is paired with the weights file
     ``model_<i>.pth`` inside ``models_folder``.

     A missing weights file is only logged; evaluation is still attempted.
     A failure to obtain the evaluator class or to evaluate a model is
     logged and aborts the remaining architectures.

     :param esr_models_file: file read with ``np.load`` to get the
         architectures
     :param models_folder: folder containing the ``model_<i>.pth`` files
     """
     placeholder = "{local_base_path}"
     models_folder = os.path.abspath(
         models_folder.replace(placeholder, self.task.local_base_path))
     esr_models_file = os.path.abspath(
         esr_models_file.replace(placeholder, self.task.local_base_path))
     for index, arch in enumerate(np.load(esr_models_file)):
         try:
             evaluator_cls = ClassFactory.get_cls(ClassType.GPU_EVALUATOR)
         except Exception:
             logger.error(
                 "Failed to create Evaluator, please check the config file")
             logger.error(traceback.format_exc())
             return
         weights_path = FileOps.join_path(models_folder,
                                          "model_{}.pth".format(index))
         if not os.path.exists(weights_path):
             logger.error(
                 "Failed to find model file, file={}".format(weights_path))
         # The evaluator is configured through its class-level cfg.
         evaluator_cls.cfg.model_arch = arch
         evaluator_cls.cfg.pretrained_model_file = weights_path
         try:
             evaluator = evaluator_cls()
             evaluator.train_process()
             evaluator.output_evaluate_result(index,
                                              evaluator.evaluate_result)
         except Exception:
             logger.error(
                 "Failed to evaluate model, id={}, pretrained_model={}".
                 format(index, weights_path))
             logger.error(traceback.format_exc())
             return
示例#3
0
文件: trainer.py 项目: zhwzhong/vega
 def _save_checkpoint(self, epoch):
     """Pickle the model and write a training checkpoint for ``epoch``.

     Two files are written into the local worker path: the pickled model
     object, and a ``torch.save`` archive holding the epoch number together
     with the model, optimizer and lr-scheduler state dicts.

     :param epoch: epoch number stored in the checkpoint
     """
     ckpt_target = FileOps.join_path(self.get_local_worker_path(),
                                     self.checkpoint_file_name)
     pickle_target = FileOps.join_path(self.get_local_worker_path(),
                                       self.model_pickle_file_name)
     # Step 1: serialise the whole model object.
     with open(pickle_target, 'wb') as out:
         pickle.dump(self.model, out, protocol=pickle.HIGHEST_PROTOCOL)
     # Step 2: store the training state needed to resume.
     state = dict(
         epoch=epoch,
         weight=self.model.state_dict(),
         optimizer=self.optimizer.state_dict(),
         lr_scheduler=self.lr_scheduler.state_dict(),
     )
     torch.save(state, ckpt_target)
示例#4
0
    def __init__(self, search_space=None, **kwargs):
        """Set up the SpNas search algorithm from its configuration.

        When ``config.last_search_result`` is set, the previous search
        results are loaded and the best architecture picked by
        ``select_from_remote`` becomes the initial code. If
        ``config.regnition`` is also set, the codec's config template is
        pointed at that architecture's ``<arch>_imagenet.pth`` file in the
        output path, and the previous worker id is reset to -1.

        :param search_space: search space, forwarded to the base class
        :param kwargs: extra keyword arguments forwarded to the base class
        """
        super(SpNas, self).__init__(search_space, **kwargs)
        placeholder = "{local_base_path}"
        cfg = self.config
        self.search_space = search_space
        # NOTE: self.codec is not assigned in this method, yet it is used
        # below when regnition is enabled.
        self.sample_level = cfg.sample_level
        self.max_sample = cfg.max_sample
        self.max_optimal = cfg.max_optimal
        self._total_list_file = cfg.total_list.replace(
            placeholder, TaskOps().local_base_path)
        self.serial_settings = cfg.serial_settings

        self._total_list = ListDict()
        self.sample_count = 0
        self.init_code = None
        self.output_path = TaskOps().local_output_path

        if cfg.last_search_result:
            last_search_file = cfg.last_search_result.replace(
                placeholder, TaskOps().local_base_path)
            assert FileOps.exists(
                last_search_file), "Not found serial results!"
            previous_results = ListDict.load_csv(last_search_file)
            pre_worker_id, pre_arch = self.select_from_remote(
                self.max_optimal, previous_results)
            if cfg.regnition:
                # Re-write the config template so training restarts from
                # the pretrained weights of the selected architecture.
                template = self.codec.config_template
                template['model']['backbone']['reignition'] = True
                pretrained_pth = os.path.join(self.output_path,
                                              pre_arch + '_imagenet.pth')
                assert FileOps.exists(pretrained_pth), \
                    "Not found {} pretrained .pth file!".format(pre_arch)
                self.codec.config_template['model'][
                    'pretrained'] = pretrained_pth
                pre_worker_id = -1
            self.init_code = dict(arch=pre_arch,
                                  pre_arch=pre_arch.split('_')[1],
                                  pre_worker_id=pre_worker_id)

        logging.info("inited SpNas {}-level search...".format(
            self.sample_level))
示例#5
0
    def _save_descript(self, descript):
        """Generate and output the model description of the search result.

        The genotypes are computed from the model's architecture weights
        and merged into a template: one of the two built-in defaults, or a
        user template file that is first copied into the local worker path.

        :param descript: darts search result descript (not referenced by
            this method)
        :type descript: dict or Config
        """
        builtin_templates = {
            "{default_darts_cifar10_template}":
                "default_darts_cifar10_template",
            "{default_darts_imagenet_template}":
                "default_darts_imagenet_template",
        }
        template_file = self.cfg.darts_template_file
        genotypes = self.search_alg.codec.calc_genotype(
            self.model.arch_weights)
        if template_file in builtin_templates:
            template = getattr(DefaultConfig().data,
                               builtin_templates[template_file])
        else:
            # Custom template: work on a copy inside the worker folder.
            local_copy = FileOps.join_path(
                self.trainer.get_local_worker_path(),
                os.path.basename(template_file))
            FileOps.copy_file(template_file, local_copy)
            template = Config(local_copy)
        model_desc = self._gen_model_desc(genotypes, template)
        self.trainer.output_model_desc(self.trainer.worker_id, model_desc)
示例#6
0
 def update(self, record):
     """Add a finished worker's performance to the sampler and save it.

     The worker's pickled performance file is loaded and appended to the
     total list when it yields data. The output is then saved and, if a
     backup base path is configured, the output folder is copied there.

     :param record: mapping providing ``step_name`` and ``worker_id``
     """
     step = record.get("step_name")
     worker = record.get("worker_id")
     result_dir = TaskOps().get_local_worker_path(step, worker)
     performance_file = self.performance_path(result_dir)
     logging.info(
         "SpNas.update(), performance file={}".format(performance_file))
     info = FileOps.load_pickle(performance_file)
     if info is None:
         logging.info("SpNas.update(), file is not exited, "
                      "performance file={}".format(performance_file))
     else:
         self._total_list.append(info)
     self.save_output(self.output_path)
     if self.backup_base_path is not None:
         FileOps.copy_folder(self.output_path, self.backup_base_path)
示例#7
0
 def _init_npu_estimator(self, sess_config):
     """Create the NPU estimator used by this trainer.

     The model directory is the local worker path; in distributed mode a
     sub-folder named after the local rank id is used instead.

     :param sess_config: session config passed to ``NPURunConfig``
     """
     model_dir = self.get_local_worker_path()
     if self.distributed:
         model_dir = FileOps.join_path(model_dir, str(self._local_rank_id))
     run_options = dict(
         model_dir=model_dir,
         save_checkpoints_steps=self.config.save_steps,
         log_step_count_steps=self.config.report_freq,
         session_config=sess_config,
         enable_data_pre_proc=True,
         iterations_per_loop=1,
     )
     run_config = NPURunConfig(**run_options)
     self.estimator = NPUEstimator(model_fn=self.model_fn, config=run_config)
示例#8
0
 def __init__(self, **kwargs):
     """Init Imagenet.

     Copies the dataset options from ``self.args`` onto the instance and
     resolves the data path through ``FileOps.download_dataset``. A result
     that is empty or the string ``'null'`` is stored as ``None``.

     :param kwargs: keyword arguments forwarded to the base class
     """
     super(Imagenet, self).__init__(**kwargs)
     args = self.args
     resolved_path = FileOps.download_dataset(args.data_path)
     self.data_path = resolved_path
     self.fp16 = args.fp16
     self.num_parallel_batches = args.num_parallel_batches
     self.image_size = args.image_size
     self.drop_remainder = args.drop_last
     # Treat the literal 'null' and any empty value as "no data path".
     if resolved_path == 'null' or not resolved_path:
         self.data_path = None
     self.num_parallel_calls = args.num_parallel_calls
    def after_valid(self, logs=None):
        """Track the best validation AUC and persist its configuration.

        Reads the feature interaction score from the trainer's model and
        the current AUC from the trainer's validation metrics. When the
        AUC beats ``self.best_score``, the score pair is pickled to
        ``best_config.pickle`` in the trainer's local output path and
        ``self.best_score`` is updated.

        :param logs: not referenced by this method
        """
        self.model = self.trainer.model
        feature_interaction_score = self.model.get_feature_interaction_score()
        print('get feature_interaction_score', feature_interaction_score)

        curr_auc = float(self.trainer.valid_metrics.results['auc'])
        # Written as "not >" so that a NaN AUC is skipped as well.
        if not curr_auc > self.best_score:
            return

        best_config = dict(
            score=curr_auc,
            feature_interaction_score=feature_interaction_score,
        )
        logging.info("BEST CONFIG IS\n{}".format(best_config))
        target = FileOps.join_path(self.trainer.local_output_path,
                                   'best_config.pickle')
        logging.info("Saved to {}".format(target))
        FileOps.dump_pickle(best_config, target)

        self.best_score = curr_auc
示例#10
0
 def before_train(self, logs=None):
     """Prepare the callback state before training starts.

     Disables the trainer's automatic checkpoint and performance saving,
     creates the step's result folder, measures the model's FLOPs and
     parameter count on a ``1x3x192x192`` CUDA input, and marks training
     to be skipped when the model exceeds 0.6 GFLOPs. Finally copies the
     files needed by the step.

     :param logs: not referenced by this method
     """
     trainer = self.trainer
     self.cfg = trainer.cfg
     trainer.auto_save_ckpt = False
     trainer.auto_save_perf = False
     self.worker_id = trainer.worker_id
     self.local_base_path = trainer.local_base_path
     self.local_output_path = trainer.local_output_path
     self.result_path = FileOps.join_path(self.local_output_path,
                                          self.cfg.step_name)
     FileOps.make_dir(self.result_path)
     probe = torch.FloatTensor(1, 3, 192, 192).cuda()
     flops_count, params_count = calc_model_flops_params(
         self.trainer.model, probe)
     gflops = flops_count * 1e-9
     kparams = params_count * 1e-3
     logger.info("Flops: {:.2f} G, Params: {:.1f} K".format(
         gflops, kparams))
     if gflops > 0.6:
         logger.info("Flop too large!")
         self.trainer.skip_train = True
     self._copy_needed_file()
示例#11
0
 def _init_model(self, model=None):
     """Return the model to train, building it from a description if needed.

     A model passed in explicitly is returned as is. Otherwise the model
     description is taken from the ``model`` config, in this order:
     ``model_desc_file``, ``model_desc``, then the first ``desc_*.json``
     found in ``models_folder``. In both cases the model is moved to CUDA
     when the torch backend is active and ``self.use_cuda`` is set.

     :param model: an already constructed model, or None
     :return: the model, or None when no description is configured
     """
     def _to_device(net):
         # Only the torch backend with CUDA enabled needs the move.
         if vega.is_torch_backend() and self.use_cuda:
             return net.cuda()
         return net

     if model is not None:
         return _to_device(model)

     placeholder = "{local_base_path}"
     model_cfg = Config(ClassFactory.__configs__.get('model'))
     if "model_desc_file" in model_cfg and model_cfg.model_desc_file is not None:
         desc_file = model_cfg.model_desc_file.replace(
             placeholder, self.local_base_path)
         if ":" not in desc_file:
             desc_file = os.path.abspath(desc_file)
         # A separate test rather than an `else`: the path is examined
         # again after the abspath call above.
         if ":" in desc_file:
             local_desc_file = FileOps.join_path(
                 self.local_output_path, os.path.basename(desc_file))
             FileOps.copy_file(desc_file, local_desc_file)
             desc_file = local_desc_file
         model_desc = Config(desc_file)
         logging.info("net_desc:{}".format(model_desc))
     elif "model_desc" in model_cfg and model_cfg.model_desc is not None:
         model_desc = model_cfg.model_desc
     elif "models_folder" in model_cfg and model_cfg.models_folder is not None:
         folder = model_cfg.models_folder.replace(placeholder,
                                                  self.local_base_path)
         # Only the first matching description file is used.
         desc_file = glob.glob(FileOps.join_path(folder, "desc_*.json"))[0]
         model_desc = Config(desc_file)
     else:
         return None

     if model_desc is None:
         return None
     self.model_desc = model_desc
     return _to_device(NetworkDesc(model_desc).to_model())
示例#12
0
    def _save_pareto_front(self, metric_x, metric_y):
        """Compute the pareto front of the searched models and save it.

        Rows of ``random.csv`` (plus ``mutate.csv`` when it exists) are
        walked in ascending ``metric_x`` order. A row is kept whenever its
        ``metric_y`` exceeds the best value seen so far, which starts at 0.
        The kept rows are written to ``pareto_front.csv``.

        :param metric_x: column used as the x axis of the pareto front
        :param metric_y: column used as the y axis of the pareto front
        """
        candidates = pd.read_csv(
            FileOps.join_path(self.result_path, "random.csv"))
        mutate_csv = FileOps.join_path(self.result_path, 'mutate.csv')
        if os.path.exists(mutate_csv):
            candidates = pd.concat([candidates, pd.read_csv(mutate_csv)],
                                   ignore_index=True)
        front = pd.DataFrame(columns=candidates.columns)
        best_so_far = 0
        for _, row in candidates.sort_values(by=metric_x).iterrows():
            if row[metric_y] > best_so_far:
                best_so_far = row[metric_y]
                front.loc[len(front)] = row
        result_file_name = FileOps.join_path(self.result_path,
                                             "pareto_front.csv")
        front.to_csv(result_file_name, index=False)
        logger.info("Pareto front updated to {}".format(result_file_name))
示例#13
0
 def set_trainer(self, trainer):
     """Attach the trainer and derive the callback's working paths.

     The trainer's training loop is replaced by this callback's
     ``train_process``. If the trainer config carries a ``spnas_sample``
     entry in its ``kwargs``, it is stored as ``self.sample_result``.

     :param trainer: trainer object that owns this callback
     """
     self.trainer = trainer
     self.trainer._train_loop = self.train_process
     self.cfg = self.trainer.config
     self._worker_id = self.trainer._worker_id
     has_sample = hasattr(self.cfg, "kwargs") and \
         "spnas_sample" in self.cfg.kwargs
     if has_sample:
         self.sample_result = self.cfg.kwargs["spnas_sample"]
     self.worker_path = self.trainer.get_local_worker_path()
     self.output_path = self.trainer.local_output_path
     self.best_model_name = "model_best"
     best_file_name = "model_{}.pth".format(self.trainer.worker_id)
     self.best_model_file = FileOps.join_path(self.worker_path,
                                              best_file_name)
示例#14
0
 def logger_patch(self):
     """Redirect logging to a per-worker log file plus the console.

     Every handler currently attached to this module's logger and to the
     root logger is removed. The module logger then receives a file
     handler writing ``current_worker.log`` in the trainer's local worker
     path and a stream handler, is set to INFO level, and finally replaces
     ``logging.root``.
     """
     worker_path = self.trainer.get_local_worker_path()
     worker_spec_log_file = FileOps.join_path(worker_path,
                                              'current_worker.log')
     logger = logging.getLogger(__name__)
     # Iterate over snapshots: removeHandler() mutates the very list being
     # walked, so looping over it directly skips every other handler and
     # leaves stale handlers attached.
     for hdlr in list(logger.handlers):
         logger.removeHandler(hdlr)
     for hdlr in list(logging.root.handlers):
         logging.root.removeHandler(hdlr)
     logger.addHandler(logging.FileHandler(worker_spec_log_file))
     logger.addHandler(logging.StreamHandler())
     logger.setLevel(logging.INFO)
     # Module-level `logging.info(...)` style calls go through logging.root.
     logging.root = logger
示例#15
0
文件: trainer.py 项目: zhwzhong/vega
    def output_model(self,
                     id=None,
                     model=None,
                     model_desc=None,
                     performance=None):
        """Save model weights, model description and performance.

        Missing arguments fall back to the trainer's own ``worker_id``,
        ``model`` and ``model_desc`` attributes. If the model or its
        description is available from neither source, an error is logged
        and nothing is saved. A failure while writing the weights file is
        logged, and the description is still output afterwards.

        :param id: model desc id, usually the worker id
        :type id: int or str
        :param model: model whose ``state_dict`` is saved
        :param model_desc: model description passed to ``output_model_desc``
        :param performance: performance passed to ``output_model_desc``
        """
        if id is None:
            id = self.worker_id
        if model is None:
            if hasattr(self, "model"):
                model = self.model
            else:
                logger.error(
                    "Failed to save model, param 'model' is not assigned.")
                return
        if model_desc is None:
            if hasattr(self, "model_desc"):
                model_desc = self.model_desc
            else:
                logger.error(
                    "Failed to save model, param 'model_desc' is not assigned."
                )
                return
        weights_name = "model_{}.pth".format(id)
        _pth_file = FileOps.join_path(self.local_output_path, self.step_name,
                                      weights_name)
        FileOps.make_base_dir(_pth_file)
        try:
            torch.save(model.state_dict(), _pth_file)
        except Exception as ex:
            logger.error("Failed to save model pth, file={}, msg={}".format(
                _pth_file, str(ex)))
        self.output_model_desc(id, model_desc, performance)
示例#16
0
 def _save_model_desc(self):
     """Output one model description per code on the saved pareto front.

     Each value in the ``Code`` column of ``pareto_front.csv`` is set on a
     fresh copy of the search space's ``custom`` section with method
     ``'full'``, passed to the codec's ``decode``, and output through the
     trainer under the row's position.
     """
     search_space = SearchSpace()
     codec = Codec(self.cfg.codec, search_space)
     front_csv = FileOps.join_path(self.result_path, "pareto_front.csv")
     codes = pd.read_csv(front_csv)['Code']
     for position in range(len(codes)):
         desc = Config()
         # Deep copies keep the shared search space untouched.
         desc.custom = deepcopy(search_space.search_space.custom)
         desc.modules = deepcopy(search_space.search_space.modules)
         desc.custom.code = codes.loc[position]
         desc.custom.method = 'full'
         # The return value of decode() is not used here.
         codec.decode(desc.custom)
         self.trainer.output_model_desc(position, desc)
示例#17
0
    def __init__(self, **kwargs):
        """Init Cityscapes.

        The config of the current mode is overlaid with ``self.args`` and
        the merged result replaces ``self.args``. The dataset options are
        read from it, the transforms are built, the root directory is
        resolved through ``FileOps.download_dataset``, and the data file
        lists are initialised.

        :param kwargs: keyword arguments forwarded to the base class
        """
        super(Cityscapes, self).__init__(**kwargs)
        merged = obj2config(getattr(self.config, self.mode))
        merged.update(self.args)
        self.args = merged

        self.root_dir = self.args['root_dir']
        self.image_size = self.args.Rescale.size
        self.list_file = self.args.list_file
        # Optional settings with their defaults.
        self.batch_size = self.args.get('batch_size', 1)
        self.num_parallel_batches = self.args.get('num_parallel_batches', 1)
        self.drop_remainder = self.args.get('drop_remainder', False)

        self.transforms = self._init_transforms()
        local_root = FileOps.download_dataset(self.root_dir)
        self.root_dir = local_root
        self._init_data_files()
示例#18
0
 def load_checkpoint(self,
                     worker_id=None,
                     step_name=None,
                     saved_folder=None):
     """Restore ``self.model`` from a pickled model plus checkpoint file.

     The files are looked up directly in ``saved_folder`` first, then in
     its ``<self.worker_id>`` sub-folder. Any failure while loading is
     logged and leaves ``self.model`` unchanged.

     :param worker_id: worker id used to derive the folder when
         ``saved_folder`` is None; defaults to ``self.worker_id``
     :param step_name: step name used to derive the folder when
         ``saved_folder`` is None; defaults to ``self.step_name``
     :param saved_folder: folder holding the checkpoint files
     """
     if saved_folder is None:
         worker_id = self.worker_id if worker_id is None else worker_id
         step_name = self.step_name if step_name is None else step_name
         saved_folder = self.get_local_worker_path(step_name, worker_id)

     def _locate(file_name):
         # Prefer the file in the folder itself, else the worker sub-folder.
         candidate = FileOps.join_path(saved_folder, file_name)
         if os.path.isfile(candidate):
             return candidate
         return FileOps.join_path(saved_folder, str(self.worker_id),
                                  file_name)

     checkpoint_file = _locate(self.checkpoint_file_name)
     model_pickle_file = _locate(self.model_pickle_file_name)
     try:
         # NOTE(review): pickle.load executes arbitrary code from the file;
         # only load checkpoints from trusted locations.
         with open(model_pickle_file, 'rb') as f:
             restored = pickle.load(f)
             ckpt = torch.load(checkpoint_file,
                               map_location=torch.device('cpu'))
             restored.load_state_dict(ckpt['weight'])
             if self.config.cuda:
                 restored = restored.cuda()
             self.model = restored
     except Exception:
         logging.info(
             'Checkpoint file is not existed, use default model now.')
示例#19
0
    def save_genotypes_to_json(self, genotypes, acc, obj, save_folder,
                               ga_epoch):
        """Output one model description per genotype.

        The template is one of the two built-in defaults or a custom
        template file. For a custom file, a ``<save_folder>_<ga_epoch>``
        folder is created under the local worker path and the template is
        copied into the trainer's local output path before being read.

        :param genotypes: genotypes of the models
        :type genotypes: sequence of Genotype namedtuples
        :param acc: accuracy (not referenced by this method)
        :type acc: ndarray
        :param obj: objectives, e.g. FLOPs or number of parameters
            (not referenced by this method)
        :type obj: ndarray
        :param save_folder: prefix of the per-epoch folder
        :type save_folder: string
        :param ga_epoch: epoch number appended to ``save_folder``
        """
        template_file = self.trainer.cfg.darts_template_file
        if template_file == "{default_darts_cifar10_template}":
            template = DefaultConfig().data.default_darts_cifar10_template
        elif template_file == "{default_darts_imagenet_template}":
            template = DefaultConfig().data.default_darts_imagenet_template
        else:
            epoch_dir = os.path.join(self.trainer.get_local_worker_path(),
                                     save_folder + '_{}'.format(ga_epoch))
            if not os.path.isdir(epoch_dir):
                os.makedirs(epoch_dir)
            local_template = FileOps.join_path(
                self.trainer.local_output_path,
                os.path.basename(self.trainer.cfg.darts_template_file))
            FileOps.copy_file(self.trainer.cfg.darts_template_file,
                              local_template)
            with open(local_template, 'r') as f:
                template = json.load(f)

        for idx, genotype in enumerate(genotypes):
            # A fresh Config per genotype so the descriptions stay separate.
            desc = Config(template)
            desc.super_network.normal.genotype = genotype.normal
            desc.super_network.reduce.genotype = genotype.reduce
            self.trainer.output_model_desc(idx, desc)
    def after_train(self, logs=None):
        """Record the validation AUC of the selected feature pairs.

        A row holding ``self.selected_pairs`` and the current AUC is added
        to ``self.sieve_board``, which is then written as a tab-separated
        file ``<worker_id>_result.csv`` in the trainer's local output path.

        :param logs: not referenced by this method
        """
        # Function-scope import so the block does not depend on how the
        # surrounding module imports pandas.
        import pandas as pd

        curr_auc = float(self.trainer.valid_metrics.results['auc'])

        # DataFrame.append was removed in pandas 2.0; pd.concat adds the
        # same single row and works on both old and new pandas versions.
        new_row = pd.DataFrame([{
            'selected_feature_pairs': self.selected_pairs,
            'score': curr_auc,
        }])
        self.sieve_board = pd.concat([self.sieve_board, new_row],
                                     ignore_index=True)
        result_file = FileOps.join_path(
            self.trainer.local_output_path,
            '{}_result.csv'.format(self.trainer.__worker_id__))

        self.sieve_board.to_csv(result_file, sep='\t')
示例#21
0
 def _init_hps(self, hps=None):
     """Set ``self.hps`` and apply its ``trainer`` section to the config.

     The hyper-parameters are taken from, in order: the ``hps`` argument,
     ``config.hps_file``, or the first ``desc_*.json`` found in
     ``config.hps_folder``. When none applies, ``self.hps`` is not
     assigned here.

     :param hps: hyper-parameters to use directly, or None
     """
     placeholder = "{local_base_path}"
     if hps is not None:
         self.hps = hps
     elif self.config.hps_file is not None:
         hps_path = self.config.hps_file.replace(placeholder,
                                                 self.local_base_path)
         self.hps = Config(hps_path)
     elif self.config.hps_folder is not None:
         hps_dir = self.config.hps_folder.replace(placeholder,
                                                  self.local_base_path)
         # Only the first matching description file is used.
         first_desc = glob.glob(FileOps.join_path(hps_dir, "desc_*.json"))[0]
         self.hps = Config(first_desc)
     trainer_hps = self.hps.get('trainer') if self.hps else None
     if trainer_hps:
         load_conf_from_desc(self.config, trainer_hps)
示例#22
0
文件: trainer.py 项目: zhwzhong/vega
    def _save_performance(self, performance):
        """Write the performance into the worker's performance file.

        The file is ``self.performance_file_name`` inside the local worker
        path. A list is written one item per line and a dict one value per
        line; any other value is written as is, without a trailing newline.

        :param performance: performance value
        """
        logging.debug("performance=%s", str(performance))
        self.performance_file = FileOps.join_path(self.get_local_worker_path(),
                                                  self.performance_file_name)
        with open(self.performance_file, 'w') as out:
            if isinstance(performance, list):
                out.writelines("{}\n".format(item) for item in performance)
            elif isinstance(performance, dict):
                out.writelines(
                    "{}\n".format(item) for item in performance.values())
            else:
                out.write("{}".format(performance))
示例#23
0
文件: report.py 项目: wjwangppt/vega
 def _output_records(self,
                     step_name,
                     records,
                     desc=True,
                     weights_file=False,
                     performance=False):
     """Dump records to ``output.csv`` and collect the workers' files.

     The worker id, performance and description of every record are
     written to ``output.csv`` in the step's output folder. For each
     record, the selected kinds of worker files are then copied from the
     worker folder into that step folder.

     :param step_name: name of the pipeline step
     :param records: records providing ``serialize()``
     :param desc: copy the ``desc_*.json`` files
     :param weights_file: copy the ``model_*.pth`` files
     :param performance: copy the ``performance_*.json`` files
     """
     columns = ["worker_id", "performance", "desc"]
     outputs = []
     for record in records:
         serialized = record.serialize()
         outputs.append(deepcopy({key: serialized[key] for key in columns}))
     data = pd.DataFrame(outputs)
     step_path = FileOps.join_path(TaskOps().local_output_path, step_name)
     FileOps.make_dir(step_path)
     csv_file = FileOps.join_path(step_path, "output.csv")
     try:
         data.to_csv(csv_file, index=False)
     except Exception:
         logging.error("Failed to save output file, file={}".format(csv_file))

     selectable = (
         (desc, "desc_*.json"),
         (weights_file, "model_*.pth"),
         (performance, "performance_*.json"),
     )
     patterns = [pattern for enabled, pattern in selectable if enabled]
     for row in outputs:
         worker_path = TaskOps().get_local_worker_path(step_name,
                                                       row["worker_id"])
         # Gather every match first, then copy.
         matched = []
         for pattern in patterns:
             matched += glob.glob(FileOps.join_path(worker_path, pattern))
         for found in matched:
             FileOps.copy_file(found, step_path)
示例#24
0
    def _save_checkpoint(self, performance=None, model_name="best.pth"):
        """Save the trained model's weights into the local worker path.

        The file written is the plain ``state_dict`` of the trainer's model.

        :param performance: dict of all the result needed; accepted for
            interface compatibility but not written (see the note below)
        :param model_name: name of the result file
        :return: the path of the saved file
        """
        model_save_path = FileOps.join_path(
            self.trainer.get_local_worker_path(), model_name)
        # This method used to write a {'model_state_dict': ..., **performance}
        # archive first and then immediately overwrite the same path with
        # the plain state_dict. The first write never survived, and
        # `**performance` raised TypeError for the default performance=None.
        # Only the surviving write is kept, so the file on disk is unchanged.
        # NOTE(review): if the performance values are meant to be stored in
        # the checkpoint, readers of this file must be updated as well.
        torch.save(self.trainer.model.state_dict(), model_save_path)
        logger.info("model saved to {}".format(model_save_path))
        return model_save_path
示例#25
0
文件: coco.py 项目: zeyefkey/vega
 def __init__(self,
              data_dir,
              batch_size,
              mode,
              num_parallel_batches=1,
              repeat_num=5,
              padding=8,
              fp16=False,
              drop_remainder=False):
     """Init CocoTF.

     :param data_dir: dataset location, resolved through
         ``FileOps.download_dataset``
     :param batch_size: batch size
     :param mode: dataset mode
     :param num_parallel_batches: number of parallel batches
     :param repeat_num: repeat count
     :param padding: not referenced by this constructor
     :param fp16: use ``tf.float16`` when exactly ``True``, otherwise
         ``tf.float32``
     :param drop_remainder: whether to drop the last incomplete batch
     """
     self.data_dir = FileOps.download_dataset(data_dir)
     self.batch_size = batch_size
     self.mode = mode
     self.num_parallel_batches = num_parallel_batches
     self.repeat_num = repeat_num
     if fp16 is True:
         self.dtype = tf.float16
     else:
         self.dtype = tf.float32
     self.drop_remainder = drop_remainder
     self._include_mask = False
     self._dataset_fn = tf.data.TFRecordDataset
示例#26
0
 def _get_current_step_records(self):
     """Collect the records the current pipeline step should work on.

     For the first step without a configured ``models_folder``, a single
     fresh record with worker id 0 is created. Otherwise records are
     loaded from ``models_folder``, which defaults to the previous step's
     output folder. Every returned record is tagged with the current
     step name.

     :return: list of records for the current step
     """
     step_name = self.task.step_name
     models_folder = PipeStepConfig.pipe_step.get("models_folder")
     cur_index = PipelineConfig.steps.index(step_name)
     if cur_index < 1 and not models_folder:
         records = [ReportRecord(step_name, 0)]
     else:
         if not models_folder:
             # Fall back to the output folder of the previous step.
             previous_step = PipelineConfig.steps[cur_index - 1]
             models_folder = FileOps.join_path(
                 TaskOps().local_output_path, previous_step)
         models_folder = models_folder.replace(
             "{local_base_path}", TaskOps().local_base_path)
         records = Report().load_records_from_model_folder(models_folder)
     logging.debug("Records: {}".format(records))
     for item in records:
         item.step_name = step_name
     return records
示例#27
0
 def _copy_needed_file(self):
     """Copy the configured pareto front and random files into place.

     ``cfg.pareto_front_file`` is copied to ``pareto_front.csv`` in the
     result path, and ``cfg.random_file`` to ``random.csv`` in the step's
     output folder. Each copy happens only when its option is present and
     not None. ``{local_base_path}`` in the source paths is expanded first.
     """
     placeholder = "{local_base_path}"
     cfg = self.cfg
     if "pareto_front_file" in cfg and cfg.pareto_front_file is not None:
         source = cfg.pareto_front_file.replace(placeholder,
                                                self.local_base_path)
         self.pareto_front_file = FileOps.join_path(self.result_path,
                                                    "pareto_front.csv")
         FileOps.copy_file(source, self.pareto_front_file)
     if "random_file" in cfg and cfg.random_file is not None:
         source = cfg.random_file.replace(placeholder, self.local_base_path)
         self.random_file = FileOps.join_path(self.local_output_path,
                                              cfg.step_name,
                                              "random.csv")
         FileOps.copy_file(source, self.random_file)
示例#28
0
    def _init_model(self):
        """Initialize the model architecture for full train step.

        When ``cfg.model_desc`` is set, it is first dumped to
        ``model_desc_<worker_id>.json`` in the worker path and then turned
        into a model. In distributed mode ``hvd.join()`` is called after
        the dump.

        :return: train model, or None when no model description is set
        :rtype: class
        """
        logging.info('Initializing model')
        if not self.cfg.model_desc:
            return None
        logging.debug("model_desc: {}".format(self.cfg.model_desc))
        desc_dump = FileOps.join_path(
            self.worker_path, "model_desc_{}.json".format(self._worker_id))
        with open(desc_dump, "w") as out:
            json.dump(self.cfg.model_desc, out)
        if self.cfg.distributed:
            hvd.join()
        return NetworkDesc(self.cfg.model_desc).to_model()
示例#29
0
 def __init__(self, **kwargs):
     """Init Cifar10.

     Reads the dataset options from ``self.args`` and fixes the image
     geometry to 3x32x32. When ``train_portion`` is not 1, the image count
     is scaled for the ``train`` mode and derived from
     ``num_images_train`` for the ``val`` mode. ``padding`` is only set in
     ``train`` mode.

     :param kwargs: keyword arguments forwarded to the base class
     """
     super(Cifar10, self).__init__(**kwargs)
     args = self.args
     self.data_path = FileOps.download_dataset(args.data_path)
     self.num_parallel_batches = args.num_parallel_batches
     self.train_portion = args.train_portion
     if args.fp16 is True:
         self.dtype = tf.float16
     else:
         self.dtype = tf.float32
     self.num_channels = 3
     self.height = 32
     self.width = 32
     # One extra byte per sample on top of the pixel bytes.
     pixel_bytes = self.height * self.width * self.num_channels
     self.single_data_bytes = pixel_bytes + 1
     self.num_images = args.num_images
     if self.train_portion != 1:
         if self.mode == 'train':
             self.num_images = int(self.num_images * self.train_portion)
         elif self.mode == 'val':
             held_out = 1 - self.train_portion
             self.num_images = int(self.args.num_images_train * held_out)
     self.drop_remainder = args.drop_last
     self.single_data_size = [self.num_channels, self.height, self.width]
     if self.mode == 'train':
         self.padding = self.args.padding
示例#30
0
文件: trainer.py 项目: zhwzhong/vega
    def get_performance(self,
                        worker_id=None,
                        step_name=None,
                        saved_folder=None):
        """Read the performance values from a worker's performance file.

        Every non-empty line of the file is parsed as JSON. When a line
        holds a list, only its first element is kept.

        :param worker_id: the worker's worker id; defaults to
            ``self.worker_id``. Used only when ``saved_folder`` is None.
        :type worker_id: str.
        :param step_name: step name in the pipeline; defaults to
            ``self.step_name``. Used only when ``saved_folder`` is None.
        :type step_name: str.
        :param saved_folder: folder holding the performance file
        :return: performance values; an empty list when the file is missing
        :rtype: list

        """
        if saved_folder is None:
            worker_id = self.worker_id if worker_id is None else worker_id
            step_name = self.step_name if step_name is None else step_name
            saved_folder = self.get_local_worker_path(step_name, worker_id)
        performance_file = FileOps.join_path(saved_folder,
                                             self.performance_file_name)
        if not os.path.isfile(performance_file):
            logging.info("Performance file is not exited, file={}".format(
                performance_file))
            return []
        with open(performance_file, 'r') as f:
            performance = []
            for raw_line in f:
                text = raw_line.strip()
                if not text:
                    continue
                value = json.loads(text)
                if isinstance(value, list):
                    value = value[0]
                performance.append(value)
            logging.info("performance={}".format(performance))
        return performance