コード例 #1
0
 def _init_dataloader(self, mode, loader=None, transforms=None):
     """Init dataloader."""
     if loader is not None:
         return loader
     if mode == "train" and self.hps is not None and self.hps.get(
             "dataset") is not None:
         if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
             dataset_cls = ClassFactory.get_cls(
                 ClassType.DATASET,
                 self.hps.get("dataset").get('type'))
         else:
             dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
         dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
     elif self.hps:
         if self.hps.get("dataset") and self.hps.get("dataset").get('type'):
             dataset_cls = ClassFactory.get_cls(
                 ClassType.DATASET,
                 self.hps.get("dataset").get('type'))
             dataset = dataset_cls(mode=mode, hps=self.hps.get("dataset"))
         else:
             dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
             dataset = dataset_cls(mode=mode)
     else:
         dataset_cls = ClassFactory.get_cls(ClassType.DATASET)
         dataset = dataset_cls(mode=mode)
     if transforms is not None:
         dataset.transforms = transforms
     if self.distributed and mode == "train":
         dataset.set_distributed(self._world_size, self._rank_id)
     # adapt the dataset to specific backend
     dataloader = Adapter(dataset).loader
     return dataloader
コード例 #2
0
ファイル: dqn_zeus.py プロジェクト: shishouyuan/xingtian
    def create_model(self, model_info):
        """Build the Deep-Q network and wrap it in a zeus ``Trainer``.

        ``model_info['model_config']`` must provide ``zeus_model_name`` (the
        registered network class), ``loss``, ``optim`` and ``lr``. The loss
        and optimizer choices are written onto the class-level ``LossConfig``
        and ``OptimConfig`` objects rather than passed to the Trainer.

        :param model_info: dict holding a ``model_config`` sub-dict.
        :return: a ``Trainer`` on the tensorflow backend, device ``'GPU'``,
            built eagerly (``lazy_build=False``).
        """
        config = model_info['model_config']
        network_name = config['zeus_model_name']

        network_cls = ClassFactory.get_cls(ClassType.NETWORK, network_name)
        network = network_cls(state_dim=self.state_dim,
                              action_dim=self.action_dim)

        # Class-level configuration, shared with every other consumer.
        LossConfig.type = config['loss']
        OptimConfig.type = config['optim']
        OptimConfig.params.update({'lr': config['lr']})

        # Input / label signatures of the loss: state in, Q-values out.
        loss_input = {
            'inputs': [{
                "name": "input_state",
                "type": "float32",
                "shape": self.state_dim,
            }],
            'labels': [{
                "name": "target_value",
                "type": "float32",
                "shape": self.action_dim,
            }],
        }

        return Trainer(model=network,
                       backend='tensorflow',
                       device='GPU',
                       loss_input=loss_input,
                       lazy_build=False)
コード例 #3
0
ファイル: train_pipe_step.py プロジェクト: vineetrao25/vega
 def _train_single_model(self,
                         model_desc=None,
                         model_id=None,
                         weights_file=None):
     """Create a trainer for one model, announce it, and run fully-train.

     With ``model_desc`` given, the trainer is built from it (optionally
     initialised from ``weights_file``); otherwise a trainer is created with
     ``(None, 0)`` and the record is derived from that trainer's attributes.
     The record is broadcast, the trainer is registered as a watched
     variable, resume flags are set on torch when ``General._resume`` is
     on, and training is dispatched to the distributed or single path.

     :param model_desc: network description, or None for the default trainer.
     :param model_id: worker id used when ``model_desc`` is given.
     :param weights_file: pretrained weights passed to the trainer.
     """
     trainer_cls = ClassFactory.get_cls('trainer')
     step_name = self.task.step_name

     if model_desc is None:
         trainer = trainer_cls(None, 0)
         record = ReportRecord(trainer.step_name,
                               trainer.worker_id,
                               desc=trainer.model_desc)
     else:
         record = ReportRecord().load_dict(
             dict(worker_id=model_id, desc=model_desc, step_name=step_name))
         logging.debug("Broadcast Record=%s", str(record))
         trainer = trainer_cls(model_desc=model_desc,
                               id=model_id,
                               pretrained_model_file=weights_file)

     ReportClient.broadcast(record)
     ReportServer.add_watched_var(trainer.step_name, trainer.worker_id)

     # Resume is only wired up for the torch backend.
     if vega.is_torch_backend() and General._resume:
         trainer.load_checkpoint = True
         trainer._resume_training = True

     run_fully_train = (self._do_distributed_fully_train
                        if self._distributed_training
                        else self._do_single_fully_train)
     run_fully_train(trainer)
コード例 #4
0
ファイル: train_pipe_step.py プロジェクト: vineetrao25/vega
    def _do_hccl_fully_train(self, trainer):
        """Re-run fully-train of ``trainer``'s model once per HCCL rank.

        Only ``trainer.worker_id`` and ``trainer.model_desc`` are reused. For
        each rank ``i`` in ``range(int(os.environ['RANK_SIZE']))`` a new
        trainer with worker id ``"<origin_worker_id>-<i>"`` (plus an optional
        evaluator) is dispatched through a freshly created master, which is
        then joined and shut down.

        ``General.parallel_fully_train``, ``General.dft`` and
        ``General._parallel`` are forced to True for the run and restored to
        their previous values afterwards, also when the run raises. The master
        is shut down even if dispatching or joining fails.

        :param trainer: trainer whose ``worker_id`` / ``model_desc`` are reused.
        :raises KeyError: if the ``RANK_SIZE`` environment variable is unset.
        :raises ValueError: if ``RANK_SIZE`` is not an integer.
        """
        origin_worker_id = trainer.worker_id
        model_desc = trainer.model_desc
        del trainer

        # Validate the rank count before touching global state or creating a
        # master, so a bad environment fails without side effects.
        workers_num = int(os.environ['RANK_SIZE'])

        origin_parallel_fully_train = General.parallel_fully_train
        origin_parallel = General._parallel
        # Previously dft was reset to False unconditionally; default to that
        # value if the attribute does not exist yet.
        origin_dft = getattr(General, 'dft', False)
        General.parallel_fully_train = True
        General.dft = True
        General._parallel = True
        try:
            cls_trainer = ClassFactory.get_cls('trainer')
            self.master = create_master()
            try:
                for rank in range(workers_num):
                    worker_id = "{}-{}".format(origin_worker_id, rank)
                    rank_trainer = cls_trainer(model_desc, id=worker_id)
                    evaluator = self._get_evaluator(worker_id)
                    self.master.run(rank_trainer, evaluator)
                self.master.join()
            finally:
                # Release the master's workers even on failure.
                self.master.shutdown()
        finally:
            General.parallel_fully_train = origin_parallel_fully_train
            General.dft = origin_dft
            General._parallel = origin_parallel
コード例 #5
0
ファイル: train_pipe_step.py プロジェクト: vineetrao25/vega
 def _get_evaluator(self, worker_id):
     """Return an ``Evaluator`` for ``worker_id``, or None when disabled.

     :param worker_id: id of the worker the evaluator is created for.
     :return: the registered "Evaluator" class instantiated with the current
         step name and ``worker_id``; None if
         ``PipeStepConfig.evaluator_enable`` is falsy.
     """
     if PipeStepConfig.evaluator_enable:
         evaluator_cls = ClassFactory.get_cls('evaluator', "Evaluator")
         evaluator_config = dict(step_name=self.task.step_name,
                                 worker_id=worker_id)
         return evaluator_cls(evaluator_config)
     return None
コード例 #6
0
 def to_model(self):
     """Instantiate a concrete network from this description.

     The registered network class named "Module" builds the model via
     ``from_desc(self._desc)``. Unless ``self.is_deformation`` is set, the
     description is attached to the result as ``model.desc``.

     :return: the constructed model.
     :raises Exception: if ``from_desc`` returns a falsy value.
     """
     logging.debug("Start to Create a Network.")
     network_base = ClassFactory.get_cls(ClassType.NETWORK, "Module")
     network = network_base.from_desc(self._desc)
     if not network:
         error_text = "Failed to create model, model desc={}".format(
             self._desc)
         raise Exception(error_text)
     if not self.is_deformation:
         network.desc = self._desc
     return network
コード例 #7
0
ファイル: search_space.py プロジェクト: vineetrao25/vega
 def __init__(self, desc=None):
     """Init SearchSpace from ``desc`` or from the global configuration.

     When ``desc`` is None, ``SearchSpaceConfig().to_dict()`` is used; if that
     config has a non-None ``type``, the search-space class registered under
     that type expands it via ``get_space``. Every top-level entry of the
     resulting desc is exposed both as an attribute and as an item, the
     internal bookkeeping containers are reset, and the desc is parsed with
     ``form_desc``.

     :param desc: mapping describing the search space, or None to use the
         configured default.
     """
     super(SearchSpace, self).__init__()
     if desc is None:
         desc = SearchSpaceConfig().to_dict()
         if desc.type is not None:
             desc = ClassFactory.get_cls(ClassType.SEARCHSPACE,
                                         desc.type).get_space(desc)
     for name, item in desc.items():
         self.__setattr__(name, item)
         self.__setitem__(name, item)
     self._params = OrderedDict()
     self._condition_dict = OrderedDict()
     self._forbidden_list = []
     self._hp_count = 0
     self._dag = DAG()
     # desc is guaranteed non-None here (it was iterated via .items()
     # above), so it is parsed unconditionally.
     self.form_desc(desc)