예제 #1
0
    def _init_model(self):
        """Build the GAN network and everything needed to train it.

        The network comes from ``self.model_manager.gan_model()`` and is then
        passed through ``RunnerHelper.load_net``. The optimizer and scheduler
        are created by ``Trainer.init`` from ``self._get_parameters()`` and
        the ``'solver'`` config section. Finally, the train and validation
        loaders are fetched.
        """
        self.gan_net = raw_net = self.model_manager.gan_model()
        self.gan_net = RunnerHelper.load_net(self, raw_net)

        params = self._get_parameters()
        solver_cfg = self.configer.get('solver')
        self.optimizer, self.scheduler = Trainer.init(params, solver_cfg)

        # NOTE(review): loaders are taken from ``seg_data_loader`` although
        # this method builds a GAN model -- confirm this is the intended source.
        loaders = self.seg_data_loader
        self.train_loader = loaders.get_trainloader()
        self.val_loader = loaders.get_valloader()
예제 #2
0
 def _init_model(self):
     """Build the detector plus its optimizer, scheduler, loaders and loss.

     The network comes from ``self.det_model_manager.object_detector()`` and
     is then passed through ``RunnerHelper.load_net``. ``Trainer.init`` builds
     the optimizer and scheduler from ``self._get_parameters()`` and the
     ``'solver'`` config section. The loss comes from the model manager's
     ``get_det_loss()``.
     """
     # NOTE(review): a disabled call to
     # torch.multiprocessing.set_sharing_strategy('file_system') used to sit
     # here; re-enable it only if the data loaders turn out to need it.
     self.det_net = raw_net = self.det_model_manager.object_detector()
     self.det_net = RunnerHelper.load_net(self, raw_net)

     params = self._get_parameters()
     solver_cfg = self.configer.get('solver')
     self.optimizer, self.scheduler = Trainer.init(params, solver_cfg)

     loaders = self.det_data_loader
     self.train_loader = loaders.get_trainloader()
     self.val_loader = loaders.get_valloader()
     self.det_loss = self.det_model_manager.get_det_loss()
예제 #3
0
    def __init__(self, configer):
        """Set up a classification runner from ``configer``.

        The setup proceeds in this order:

        1. Create the timing and loss meters.
        2. Create the model manager, the data loader and the running-score
           tracker.
        3. Build the classification network and pass it through
           ``RunnerHelper.load_net``.
        4. Create the optimizer and scheduler from the ``'solver'`` config
           section.
        5. Fetch the train/val loaders and the classification loss.

        Args:
            configer: configuration object; it must support ``get('solver')``.
        """
        self.configer = configer
        self.runner_state = {}

        # Timing and loss bookkeeping.
        self.batch_time = AverageMeter()
        self.data_time = AverageMeter()
        self.train_losses = DictAverageMeter()
        self.val_losses = DictAverageMeter()

        # Helpers, all constructed from the same configer.
        model_manager = self.cls_model_manager = ModelManager(configer)
        data_loader = self.cls_data_loader = DataLoader(configer)
        self.running_score = ClsRunningScore(configer)

        # Network, then optimizer/scheduler over its parameters.
        self.cls_net = model_manager.get_cls_model()
        self.solver_dict = configer.get('solver')
        self.cls_net = RunnerHelper.load_net(self, self.cls_net)
        params = self._get_parameters()
        self.optimizer, self.scheduler = Trainer.init(params, self.solver_dict)

        self.train_loader = data_loader.get_trainloader()
        self.val_loader = data_loader.get_valloader()
        self.loss = model_manager.get_cls_loss()
예제 #4
0
 def _init_model(self):
     """Create the pose network for inference.

     The model comes from ``self.pose_model_manager.get_pose_model()``. It is
     passed through ``RunnerHelper.load_net``, and ``eval()`` is called on
     the result.
     """
     self.pose_net = self.pose_model_manager.get_pose_model()
     net = RunnerHelper.load_net(self, self.pose_net)
     self.pose_net = net
     # Presumably a torch nn.Module, where eval() disables training-only
     # behaviour (dropout, batch-norm updates) -- confirm.
     net.eval()
예제 #5
0
 def _init_model(self):
     """Create the GAN network for inference.

     The model comes from ``self.model_manager.gan_model()``. It is passed
     through ``RunnerHelper.load_net``, and ``eval()`` is called on the
     result.
     """
     self.gan_net = self.model_manager.gan_model()
     net = RunnerHelper.load_net(self, self.gan_net)
     self.gan_net = net
     # Presumably a torch nn.Module, where eval() disables training-only
     # behaviour (dropout, batch-norm updates) -- confirm.
     net.eval()
예제 #6
0
 def _init_model(self):
     """Create the object detector for inference.

     The model comes from ``self.det_model_manager.object_detector()``. It is
     passed through ``RunnerHelper.load_net``, and ``eval()`` is called on
     the result.
     """
     self.det_net = self.det_model_manager.object_detector()
     net = RunnerHelper.load_net(self, self.det_net)
     self.det_net = net
     # Presumably a torch nn.Module, where eval() disables training-only
     # behaviour (dropout, batch-norm updates) -- confirm.
     net.eval()