示例#1
0
 def _init_model(self):
     """Load model desc from save path and parse to model.

     The model description comes from ``self.trainer.model_desc`` for
     detection trainers, otherwise from ``self._get_model_desc()``. When a
     description is available it is published to ``ModelConfig.model_desc``.
     If the trainer does not already hold a model, one is built with
     ``ModelZoo.get_model(model_desc, pretrained_model_file)``.

     Backend-specific preparation:
       * torch: move to CUDA when ``self.trainer.use_cuda`` is set, and wrap
         in ``torch.nn.DataParallel`` when ``General._parallel`` is enabled
         with more than one device per trainer.
       * tf: copy the folder containing the pretrained model file (if any)
         into the local worker path.

     :return: the prepared model (possibly wrapped in ``DataParallel``), or
         whatever falsy value ``ModelZoo.get_model`` produced.
     :raises Exception: if the trainer has no model and no model description
         can be obtained.
     """
     model = self.trainer.model
     if self.trainer.config.is_detection_trainer:
         model_desc = self.trainer.model_desc
     else:
         model_desc = self._get_model_desc()
     if model_desc:
         ModelConfig.model_desc = model_desc
     pretrained_model_file = self._get_pretrained_model_file()
     if not model:
         if not model_desc:
             raise Exception(
                 "Failed to Init model, can not get model description.")
         model = ModelZoo.get_model(model_desc, pretrained_model_file)
     if model:
         if zeus.is_torch_backend():
             import torch
             if self.trainer.use_cuda:
                 model = model.cuda()
             if General._parallel and General.devices_per_trainer > 1:
                 # Wrap the local model, which may have just been built or
                 # moved to CUDA. Wrapping self.trainer.model instead would
                 # lose the .cuda() result and wrap None when the trainer
                 # had no model of its own.
                 model = torch.nn.DataParallel(model)
         if zeus.is_tf_backend():
             if pretrained_model_file:
                 model_folder = os.path.dirname(pretrained_model_file)
                 FileOps.copy_folder(model_folder,
                                     self.trainer.get_local_worker_path())
     return model
示例#2
0
文件: inference.py 项目: ylfzr/vega
def _get_model(args):
    """Build the model described by ``args`` and prepare it for inference.

    ``args.model_desc`` and ``args.model`` are forwarded to
    ``ModelZoo.get_model``. On the torch backend the model is moved to the
    GPU when ``args.device`` equals ``"GPU"`` and is then put in eval mode.
    Other backends get the model exactly as ``ModelZoo`` returned it.
    """
    from zeus.model_zoo import ModelZoo
    net = ModelZoo.get_model(args.model_desc, args.model)
    if not vega.is_torch_backend():
        return net
    if args.device == "GPU":
        net = net.cuda()
    net.eval()
    return net
示例#3
0
 def load_model(self):
     """Load the model for this worker and store it in ``self.model``.

     If neither ``self.model_desc`` nor ``self.weights_file`` is set, both
     default to files in the local worker folder for
     ``(self.step_name, self.worker_id)``: ``desc_<id>.json`` for the
     description and ``model_<id>.pth`` for the weights. A description that
     is empty/None, or that does not contain ``'modules'``, is replaced by
     ``ModelConfig.model_desc``. The model is then built with
     ``ModelZoo.get_model(self.model_desc, self.weights_file)``.
     """
     if not self.model_desc and not self.weights_file:
         saved_folder = self.get_local_worker_path(self.step_name,
                                                   self.worker_id)
         self.weights_file = FileOps.join_path(
             saved_folder, 'model_{}.pth'.format(self.worker_id))
         self.model_desc = FileOps.join_path(
             saved_folder, 'desc_{}.json'.format(self.worker_id))
     # If only weights_file was supplied, model_desc is still unset, and
     # `'modules' in None` would raise TypeError. Fall back to the config.
     # NOTE(review): when model_desc is the path string built above, this is
     # a substring test on the path, so the path is normally replaced by
     # ModelConfig.model_desc. Confirm that this is intended.
     if not self.model_desc or 'modules' not in self.model_desc:
         self.model_desc = ModelConfig.model_desc
     self.model = ModelZoo.get_model(self.model_desc, self.weights_file)
示例#4
0
    def load_model(self):
        """Fill in a missing description/weights from the worker folder, then build the model.

        The worker folder for ``(self.step_name, self.worker_id)`` is stored in
        ``self.saved_folder``. A missing description defaults to
        ``desc_<id>.json`` in that folder. Missing weights default to
        ``model_<id>.pth`` on the torch backend. On the ms backend they default
        to the last ``CKP*.ckpt`` entry in ``os.listdir`` order. A description
        without ``'modules'`` is replaced by ``ModelConfig.model_desc``. The
        result of ``ModelZoo.get_model`` is stored in ``self.model``.
        """
        folder = self.get_local_worker_path(self.step_name, self.worker_id)
        self.saved_folder = folder
        if not self.model_desc:
            desc_name = 'desc_{}.json'.format(self.worker_id)
            self.model_desc = FileOps.join_path(folder, desc_name)
        if not self.weights_file:
            if zeus.is_torch_backend():
                weights_name = 'model_{}.pth'.format(self.worker_id)
                self.weights_file = FileOps.join_path(folder, weights_name)
            elif zeus.is_ms_backend():
                checkpoints = [name for name in os.listdir(folder)
                               if name.startswith("CKP") and name.endswith(".ckpt")]
                if checkpoints:
                    # The last matching entry wins, as in a reassigning loop.
                    self.weights_file = FileOps.join_path(folder, checkpoints[-1])

        if 'modules' not in self.model_desc:
            self.model_desc = ModelConfig.model_desc
        self.model = ModelZoo.get_model(self.model_desc, self.weights_file)
示例#5
0
 def _init_model(self):
     """Resolve a model description, build the model if needed, and prepare it.

     Detection trainers use ``self.trainer.model_desc`` and fall back to
     ``self._get_model_desc()`` when that is falsy. Other trainers always use
     ``self._get_model_desc()``. If the trainer holds no model, one is built
     via ``ModelZoo.get_model(desc, pretrained_file, ModelConfig.head)``.
     The model's ``desc`` is written back to ``self.trainer.model_desc``. On
     the torch backend the model is moved to CUDA when requested and wrapped
     in ``DataParallel`` for multi-device parallel runs.

     :raises Exception: if there is no model and no description to build one.
     """
     model = self.trainer.model
     if not self.trainer.config.is_detection_trainer:
         model_desc = self._get_model_desc()
     else:
         model_desc = self.trainer.model_desc or self._get_model_desc()
     pretrained_model_file = self._get_pretrained_model_file()
     if not model:
         if not model_desc:
             raise Exception(
                 "Failed to Init model, can not get model description.")
         model = ModelZoo.get_model(model_desc, pretrained_model_file,
                                    ModelConfig.head)
     if not model:
         return model
     self.trainer.model_desc = model.desc
     if not zeus.is_torch_backend():
         return model
     import torch
     if self.trainer.use_cuda:
         model = model.cuda()
     if General._parallel and General.devices_per_trainer > 1:
         model = torch.nn.DataParallel(model)
     return model
示例#6
0
 def load_model(self):
     """Resolve the description and weights for this worker, then build the model.

     The worker folder for ``(self.step_name, self.worker_id)`` is stored in
     ``self.saved_folder``. A missing description is loaded with ``Config``
     from ``desc_<id>.json``. If the loaded config has neither ``"type"`` nor
     ``"modules"``, ``ModelConfig.model_desc`` is used instead. Missing weights
     default to a backend-specific path:
       * torch: ``model_<id>.pth``
       * ms: the last ``*.ckpt`` entry in ``os.listdir`` order (if any)
       * tf: ``model_<id>``
     The result of ``ModelZoo.get_model`` is stored in ``self.model``.
     """
     self.saved_folder = self.get_local_worker_path(self.step_name,
                                                    self.worker_id)
     if not self.model_desc:
         loaded = Config(
             FileOps.join_path(self.saved_folder,
                               'desc_{}.json'.format(self.worker_id)))
         has_structure = "type" in loaded or "modules" in loaded
         self.model_desc = loaded if has_structure else ModelConfig.model_desc
     if not self.weights_file:
         folder = self.saved_folder
         if zeus.is_torch_backend():
             self.weights_file = FileOps.join_path(
                 folder, 'model_{}.pth'.format(self.worker_id))
         elif zeus.is_ms_backend():
             checkpoints = [name for name in os.listdir(folder)
                            if name.endswith(".ckpt")]
             if checkpoints:
                 # The last matching entry wins, as in a reassigning loop.
                 self.weights_file = FileOps.join_path(folder, checkpoints[-1])
         elif zeus.is_tf_backend():
             self.weights_file = FileOps.join_path(
                 folder, 'model_{}'.format(self.worker_id))
     self.model = ModelZoo.get_model(self.model_desc, self.weights_file)
示例#7
0
文件: cam.py 项目: ylfzr/vega
def _get_model(args):
    """Build the model from ``args`` and return it on the GPU in eval mode.

    ``args.model_desc_file`` and ``args.model_weights_file`` are passed to
    ``ModelZoo.get_model``. The resulting model is always moved with
    ``.cuda()``, with no device check.
    """
    from zeus.model_zoo import ModelZoo
    desc_file, weights_file = args.model_desc_file, args.model_weights_file
    net = ModelZoo.get_model(desc_file, weights_file).cuda()
    net.eval()
    return net