示例#1
0
    def __init__(self, opts):
        """Build the dataloader and the disparity network, then call ``start_test``.

        Args:
            opts: options object. This constructor reads ``model_path``,
                ``output_dir``, ``gpu_id`` (a sequence of GPU indices, or
                None / empty for CPU), ``disp_module``, ``dataset``,
                ``batch_size`` and ``train`` from it. The whole object is also
                forwarded to the dataset constructor.

        Raises:
            NameError: if ``opts.dataset`` is neither ``'kitti'`` nor ``'nyu'``.
        """
        self.opts = opts
        self.model_path = opts.model_path
        self.output_dir = opts.output_dir
        self.gpu_id = opts.gpu_id
        self.disp_module = opts.disp_module
        self.dataset = opts.dataset
        self.batch_size = opts.batch_size
        self.train = opts.train

        # Build the dataset for the requested benchmark.
        if self.dataset == 'kitti':
            dataset = Datasets.KittiDataset(self.opts)
        elif self.dataset == 'nyu':
            dataset = Datasets.NYUDataset(self.opts)
        else:
            raise NameError('Dataset not found')
        # shuffle=False so batches come out in dataset order.
        self.DataLoader = data.DataLoader(dataset,
                                          batch_size=self.batch_size,
                                          shuffle=False,
                                          num_workers=8)
        print('Data loader made')

        # Choose the device before loading weights so the checkpoint can be
        # mapped straight onto it. None or an empty GPU list means CPU.
        if self.gpu_id:
            self.device = torch.device('cuda:' + str(self.gpu_id[0]))
        else:
            self.device = torch.device('cpu')

        # Dynamically import the module that defines DispResNet.
        disp_module = importlib.import_module(self.disp_module)
        self.DispNet = disp_module.DispResNet()
        # map_location lets a checkpoint saved on a GPU load on a CPU-only
        # machine (or onto a different GPU index).
        self.DispNet.load_state_dict(
            torch.load(self.model_path, map_location=self.device))
        self.DispNet = self.DispNet.to(self.device)
        # Replicate across all listed GPUs only when there is more than one.
        if self.gpu_id and len(self.gpu_id) > 1:
            self.DispNet = nn.DataParallel(self.DispNet, self.gpu_id)
        self.DispNet.eval()
        print('Model Loaded')

        self.start_test()
示例#2
0
    def __init__(self, opts):
        """Set up data, disparity/pose networks, loss, optimizer and logging, then call ``start``.

        Args:
            opts: options object. Every hyper-parameter and path below is read
                from it. ``gpus`` is a sequence of GPU indices (None or empty
                means CPU). ``disp_model_path`` / ``pose_model_path`` are
                optional checkpoints (None to skip). ``opts`` is also forwarded
                to the dataset constructor and to ``Loss``.

        Raises:
            NameError: if ``opts.dataset`` is neither ``'kitti'`` nor ``'nyu'``.
        """
        self.opts = opts
        self.epochs = opts.epochs
        self.batch_size = opts.batch_size
        self.shuffle = opts.shuffle
        self.lr = opts.lr
        self.beta1 = opts.beta1
        self.beta2 = opts.beta2
        self.console_out = opts.console_out
        self.save_disp = opts.save_disp
        self.tboard_out = opts.tboard_out
        self.log_tb = opts.log_tensorboard
        self.disp_module = opts.disp_module
        self.pose_module = opts.pose_module
        self.dataset = opts.dataset
        self.gpus = opts.gpus
        self.tboard_dir = opts.tboard_dir
        self.int_result_dir = opts.int_results_dir
        self.save_model_dir = opts.save_model_dir
        self.save_model_iter = opts.save_model_iter
        self.frame_size = opts.frame_size
        self.disp_model_path = opts.disp_model_path
        self.pose_model_path = opts.pose_model_path
        self.start_time = time.time()

        # Build the dataset for the requested benchmark.
        if self.dataset == 'kitti':
            dataset = Datasets.KittiDataset(self.opts)
        elif self.dataset == 'nyu':
            dataset = Datasets.NYUDataset(self.opts)
        else:
            raise NameError('Dataset not found')
        self.DataLoader = data.DataLoader(dataset, batch_size=self.batch_size,
                                          shuffle=self.shuffle, num_workers=8)

        # The first listed GPU is the primary device. None or an empty list
        # means CPU (truthiness avoids len(None) raising TypeError).
        if self.gpus:
            self.device = torch.device('cuda:' + str(self.gpus[0]))
        else:
            self.device = torch.device('cpu')

        # Dynamically import the modules defining the two networks.
        disp_module = importlib.import_module(self.disp_module)
        pose_module = importlib.import_module(self.pose_module)
        self.DispModel = disp_module.DispResNet().to(self.device)
        self.PoseModel = pose_module.PoseResNet().to(self.device)
        # Optional warm start. map_location lets GPU-saved checkpoints load on
        # CPU or on a different GPU index. load_state_dict copies into the
        # existing parameters, so the models stay on self.device.
        if self.disp_model_path is not None:
            self.DispModel.load_state_dict(
                torch.load(self.disp_model_path, map_location=self.device))
        if self.pose_model_path is not None:
            self.PoseModel.load_state_dict(
                torch.load(self.pose_model_path, map_location=self.device))
        if self.gpus:
            # Modules are already on gpus[0], as DataParallel requires.
            self.DispModel = nn.DataParallel(self.DispModel, self.gpus)
            self.PoseModel = nn.DataParallel(self.PoseModel, self.gpus)
        self.Loss = Loss(opts)

        # One Adam optimizer over both networks, as two parameter groups.
        params_dict = [{'params': self.DispModel.parameters()},
                       {'params': self.PoseModel.parameters()}]
        self.optim = torch.optim.Adam(params_dict, lr=self.lr,
                                      betas=(self.beta1, self.beta2))

        # A fixed batch pulled once up front (presumably for consistent
        # progress visualisation; confirm against start()), plus a summary
        # writer logging to tboard_dir.
        self.fix_batch = next(iter(self.DataLoader))
        self.writer = tb.SummaryWriter(self.tboard_dir)

        self.start()