def __init__(self, opts):
    """Set up the disparity-network evaluator from ``opts`` and start testing.

    Builds a non-shuffled ``DataLoader`` over the dataset named by
    ``opts.dataset`` ('kitti' or 'nyu'), instantiates ``DispResNet`` from the
    module named by ``opts.disp_module``, loads its weights from
    ``opts.model_path``, places it on the requested device(s), switches it to
    eval mode and finally calls ``self.start_test()``.

    Args:
        opts: Options object; the attributes read here are ``model_path``,
            ``output_dir``, ``gpu_id``, ``disp_module``, ``dataset``,
            ``batch_size`` and ``train``. ``gpu_id`` is a sequence of GPU
            ids whose first entry becomes the primary device; ``None`` or an
            empty sequence selects the CPU. An optional ``num_workers``
            attribute overrides the default of 8 data-loader workers.

    Raises:
        NameError: if ``opts.dataset`` is neither 'kitti' nor 'nyu'.
    """
    self.opts = opts
    self.model_path = opts.model_path
    self.output_dir = opts.output_dir
    self.gpu_id = opts.gpu_id
    self.disp_module = opts.disp_module
    self.dataset = opts.dataset
    self.batch_size = opts.batch_size
    self.train = opts.train

    # getting the dataloader ready
    if self.dataset == 'kitti':
        dataset = Datasets.KittiDataset(self.opts)
    elif self.dataset == 'nyu':
        dataset = Datasets.NYUDataset(self.opts)
    else:
        raise NameError('Dataset not found')
    # Evaluation keeps the dataset order fixed (no shuffling).
    self.DataLoader = data.DataLoader(dataset, batch_size=self.batch_size, shuffle=False,
                                      num_workers=getattr(opts, 'num_workers', 8))
    print('Data loader made')

    # loading the model
    disp_module = importlib.import_module(self.disp_module)
    self.DispNet = disp_module.DispResNet()
    # Load onto the CPU first: by default torch.load restores tensors onto the
    # device they were saved from, which fails for GPU checkpoints on a
    # CPU-only run. The network is moved to its target device below.
    self.DispNet.load_state_dict(torch.load(self.model_path, map_location='cpu'))

    # An empty/None gpu_id means "run on the CPU" (indexing [0] would fail).
    if self.gpu_id:
        self.device = torch.device('cuda:' + str(self.gpu_id[0]))
        self.DispNet = self.DispNet.to(self.device)
        # Only wrap in DataParallel when more than one GPU is requested.
        if len(self.gpu_id) > 1:
            self.DispNet = nn.DataParallel(self.DispNet, self.gpu_id)
    else:
        self.device = torch.device('cpu')

    self.DispNet.eval()
    print('Model Loaded')
    self.start_test()
def __init__(self, opts):
    """Set up joint disparity/pose training from ``opts`` and start training.

    Copies hyper-parameters and paths from ``opts`` onto ``self``, builds a
    ``DataLoader`` over the dataset named by ``opts.dataset`` ('kitti' or
    'nyu'), instantiates ``DispResNet`` and ``PoseResNet`` from the modules
    named by ``opts.disp_module`` / ``opts.pose_module``, optionally restores
    their weights, and creates the loss, a single Adam optimizer over both
    networks and a TensorBoard ``SummaryWriter``. Finally it calls
    ``self.start()``.

    Args:
        opts: Options object supplying every attribute read below.
            ``gpus`` is a sequence of GPU ids; ``None`` or an empty sequence
            selects the CPU. Otherwise both networks are placed on
            ``gpus[0]`` and wrapped in ``nn.DataParallel`` over ``gpus``.
            ``disp_model_path`` / ``pose_model_path`` may be ``None`` to skip
            loading weights. An optional ``num_workers`` attribute overrides
            the default of 8 data-loader workers.

    Raises:
        NameError: if ``opts.dataset`` is neither 'kitti' nor 'nyu'.
    """
    self.opts = opts
    self.epochs = opts.epochs
    self.batch_size = opts.batch_size
    self.shuffle = opts.shuffle
    self.lr = opts.lr
    self.beta1 = opts.beta1
    self.beta2 = opts.beta2
    self.console_out = opts.console_out
    self.save_disp = opts.save_disp
    self.tboard_out = opts.tboard_out
    self.log_tb = opts.log_tensorboard
    self.disp_module = opts.disp_module
    self.pose_module = opts.pose_module
    self.dataset = opts.dataset
    self.gpus = opts.gpus
    self.tboard_dir = opts.tboard_dir
    self.int_result_dir = opts.int_results_dir
    self.save_model_dir = opts.save_model_dir
    self.save_model_iter = opts.save_model_iter
    self.frame_size = opts.frame_size
    self.disp_model_path = opts.disp_model_path
    self.pose_model_path = opts.pose_model_path
    self.start_time = time.time()

    # getting the dataloader ready
    if self.dataset == 'kitti':
        dataset = Datasets.KittiDataset(self.opts)
    elif self.dataset == 'nyu':
        dataset = Datasets.NYUDataset(self.opts)
    else:
        raise NameError('Dataset not found')
    self.DataLoader = data.DataLoader(dataset, batch_size=self.batch_size, shuffle=self.shuffle,
                                      num_workers=getattr(opts, 'num_workers', 8))

    # loading the modules; an empty/None gpus list means "run on the CPU"
    if self.gpus:
        self.device = torch.device('cuda:' + str(self.gpus[0]))
    else:
        self.device = torch.device('cpu')
    disp_module = importlib.import_module(self.disp_module)
    pose_module = importlib.import_module(self.pose_module)
    self.DispModel = disp_module.DispResNet().to(self.device)
    self.PoseModel = pose_module.PoseResNet().to(self.device)
    # map_location='cpu' lets checkpoints saved on a GPU load on any machine;
    # load_state_dict then copies the tensors onto the models' device.
    if self.disp_model_path is not None:
        self.DispModel.load_state_dict(torch.load(self.disp_model_path, map_location='cpu'))
    if self.pose_model_path is not None:
        self.PoseModel.load_state_dict(torch.load(self.pose_model_path, map_location='cpu'))
    if self.gpus:
        # The models already live on gpus[0], so no extra .to() is needed.
        self.DispModel = nn.DataParallel(self.DispModel, self.gpus)
        self.PoseModel = nn.DataParallel(self.PoseModel, self.gpus)

    self.Loss = Loss(opts)

    # the optimizer: one Adam instance over both networks' parameters
    params_dict = [{'params': self.DispModel.parameters()},
                   {'params': self.PoseModel.parameters()}]
    self.optim = torch.optim.Adam(params_dict, lr=self.lr, betas=(self.beta1, self.beta2))

    # data logger and output
    # NOTE(review): one batch is kept in self.fix_batch, presumably for
    # logging/visualisation during training - confirm against start().
    self.fix_batch = next(iter(self.DataLoader))
    self.writer = tb.SummaryWriter(self.tboard_dir)
    self.start()