Beispiel #1
0
def test_refine_stage(args):
    """Run the refine-stage flow network over the test set and save results.

    Builds a ``FlowRefine.FlowSeq`` test dataset and loads the pretrained
    network from ``args.PRETRAINED_MODEL``. For every sample it writes two
    completed flow fields (``_f`` and ``_r`` outputs) under
    ``args.output_root``. Where the mask is 1 the network prediction is used;
    where it is 0 the input (masked) flow is kept.

    Args:
        args: Config namespace. This function reads ``batch_size``,
            ``n_threads``, ``ResNet101``, ``PRETRAINED_MODEL`` and
            ``output_root``. ``FlowRefine.FlowSeq`` may read more.
    """
    torch.manual_seed(777)
    torch.cuda.manual_seed(777)

    eval_dataset = FlowRefine.FlowSeq(args, isTest=True)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=args.n_threads)

    # 66 input channels -> 4 output channels. The first two are used for the
    # `_f` output and the last two for the `_r` output.
    if args.ResNet101:
        dfc_resnet = nn.DataParallel(resnet_models.Flow_Branch(66, 4)).cuda()
    else:
        dfc_resnet = nn.DataParallel(
            resnet_models.Flow_Branch_Multi(input_chanels=66,
                                            NoLabels=4)).cuda()
    dfc_resnet.eval()

    load_ckpt(args.PRETRAINED_MODEL, [('model', dfc_resnet)], strict=True)
    print('Load Pretrained Model from', args.PRETRAINED_MODEL)

    with torch.no_grad():
        # tqdm takes its total from len(eval_dataloader), i.e. the number of
        # batches actually iterated.
        for item in tqdm(eval_dataloader):
            input_x = item[0].cuda()
            flow_masked = item[1].cuda()
            mask = item[3].cuda()
            output_dirs = item[4]  # one "fwd_path,rev_path" string per sample

            res_flow = dfc_resnet(input_x)
            res_flow_f = res_flow[:, :2, :, :]
            res_flow_r = res_flow[:, 2:, :, :]

            # Blend prediction (inside the mask) with the given flow (outside).
            # The channel offsets presumably select the target frame's
            # forward (10:12) and reverse (32:34) flow and mask -- TODO confirm
            # against FlowRefine.FlowSeq.
            mask_f = mask[:, 10:11, :, :]
            mask_r = mask[:, 32:34, :, :]
            res_complete_f = (res_flow_f * mask_f +
                              flow_masked[:, 10:12, :, :] * (1. - mask_f))
            res_complete_r = (res_flow_r * mask_r +
                              flow_masked[:, 32:34, :, :] * (1. - mask_r))

            # Write every sample of the batch, not only the first one.
            for b, output_dir in enumerate(output_dirs):
                output_dir_split = output_dir.split(',')
                output_file_f = os.path.join(args.output_root,
                                             output_dir_split[0])
                output_file_r = os.path.join(args.output_root,
                                             output_dir_split[1])
                # Create both parent dirs; exist_ok avoids the check/create race.
                for output_file in (output_file_f, output_file_r):
                    output_basedir = os.path.dirname(output_file)
                    if output_basedir:
                        os.makedirs(output_basedir, exist_ok=True)

                res_save_f = res_complete_f[b].permute(
                    1, 2, 0).contiguous().cpu().numpy()
                cvb.write_flow(res_save_f, output_file_f)
                res_save_r = res_complete_r[b].permute(
                    1, 2, 0).contiguous().cpu().numpy()
                cvb.write_flow(res_save_r, output_file_r)

    # Drop the model reference so empty_cache() can hand its memory back.
    del dfc_resnet
    torch.cuda.empty_cache()
    print('Refined Results Saved in', args.output_root)
def test_initial_stage(args):
    """Run the initial-stage flow network over the test set and save results.

    Forces ``args.INITIAL_HOLE`` and ``args.get_mask`` to True. Note that this
    mutates the caller's ``args``. It then builds a ``FlowInitial.FlowSeq``
    test dataset and loads the pretrained network from
    ``args.PRETRAINED_MODEL``. For every sample it writes one completed flow
    field under ``args.output_root``. Where the mask is 1 the network
    prediction is used; elsewhere the input (masked) flow is kept.

    Args:
        args: Config namespace. This function reads ``batch_size``,
            ``n_threads``, ``ResNet101``, ``PRETRAINED_MODEL`` and
            ``output_root``. ``FlowInitial.FlowSeq`` may read more.
    """
    torch.manual_seed(777)
    torch.cuda.manual_seed(777)

    args.INITIAL_HOLE = True
    args.get_mask = True

    eval_dataset = FlowInitial.FlowSeq(args, isTest=True)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 drop_last=False,
                                 num_workers=args.n_threads)

    # 33 input channels -> 2 output channels (a single flow field).
    if args.ResNet101:
        dfc_resnet = nn.DataParallel(resnet_models.Flow_Branch(33, 2)).cuda()
    else:
        dfc_resnet = nn.DataParallel(
            resnet_models.Flow_Branch_Multi(input_chanels=33,
                                            NoLabels=2)).cuda()
    dfc_resnet.eval()
    load_ckpt(args.PRETRAINED_MODEL, [('model', dfc_resnet)], strict=True)
    print('Load Pretrained Model from', args.PRETRAINED_MODEL)

    # The bar is sized in samples, so it is advanced once per written sample.
    task_bar = ProgressBar(len(eval_dataset))
    with torch.no_grad():
        for item in eval_dataloader:
            input_x = item[0].cuda()
            flow_masked = item[1].cuda()
            mask = item[3].cuda()

            res_flow = dfc_resnet(input_x)
            # Channel offsets presumably select the target frame's mask/flow
            # -- TODO confirm against FlowInitial.FlowSeq.
            hole = mask[:, 10:11, :, :]
            res_complete = (res_flow * hole +
                            flow_masked[:, 10:12, :, :] * (1. - hole))

            # Write every sample of the batch, not only the first one.
            for b, output_dir in enumerate(item[4]):
                output_file = os.path.join(args.output_root,
                                           output_dir.split(',')[0])
                output_basedir = os.path.dirname(output_file)
                if output_basedir:
                    os.makedirs(output_basedir, exist_ok=True)
                res_save = res_complete[b].permute(
                    1, 2, 0).contiguous().cpu().numpy()
                cvb.write_flow(res_save, output_file)
                task_bar.update()

    # Release the model like test_refine_stage does, so the next stage can
    # reuse the GPU memory.
    del dfc_resnet
    torch.cuda.empty_cache()
    print('Initial Results Saved in', args.output_root)
    def flow_completion(self):
        """Advance the initial flow-completion stage by one batch per call.

        ``self.i`` tracks progress.

        * On the first call (``self.i == -1``) it does the setup:
          - generates the initial-stage test list under
            ``<dataset_root>/data``;
          - points ``self.args`` at that list, the ``Flow_res/initial_res``
            output dir and ``PRETRAINED_MODEL_1``;
          - builds an iterator over the dataloader;
          - loads the flow network onto ``self.args.device``.
        * Every call pulls one batch and writes the completed flow of each
          sample in it via ``cvb.write_flow``.
        * After the last batch, or once the iterator is exhausted, it:
          - sets ``self.args.flow_root`` to the output dir;
          - deletes the iterator and model;
          - resets ``self.i`` to -1;
          - increments ``self.state``.
        """
        if self.i == -1:
            data_list_dir = os.path.join(self.args.dataset_root, 'data')
            os.makedirs(data_list_dir, exist_ok=True)
            initial_data_list = os.path.join(data_list_dir,
                                             'initial_test_list.txt')
            print('Generate datalist for initial step')
            data_list.gen_flow_initial_test_mask_list(
                flow_root=self.args.DATA_ROOT,
                output_txt_path=initial_data_list)
            # Presumably consumed by FlowInitial.FlowSeq -- verify.
            self.args.EVAL_LIST = initial_data_list

            self.args.output_root = os.path.join(self.args.dataset_root,
                                                 'Flow_res', 'initial_res')
            self.args.PRETRAINED_MODEL = self.args.PRETRAINED_MODEL_1

            if self.args.img_size is not None:
                # IMAGE_SHAPE is half of img_size. RES_SHAPE aliases the
                # same list object.
                self.args.IMAGE_SHAPE = [
                    self.args.img_size[0] // 2, self.args.img_size[1] // 2
                ]
                self.args.RES_SHAPE = self.args.IMAGE_SHAPE

            print('Flow Completion in First Step')
            self.args.MASK_ROOT = self.args.mask_root
            eval_dataset = FlowInitial.FlowSeq(self.args, isTest=True)
            self.flow_refinement_dataloader = iter(
                DataLoader(eval_dataset,
                           batch_size=self.settings.batch_size,
                           shuffle=False,
                           drop_last=False,
                           num_workers=self.args.n_threads))
            if self.args.ResNet101:
                backbone = resnet_models.Flow_Branch(33, 2)
            else:
                backbone = resnet_models.Flow_Branch_Multi(
                    input_chanels=33, NoLabels=2)
            self.dfc_resnet = nn.DataParallel(backbone).to(self.args.device)
            self.dfc_resnet.eval()
            io.load_ckpt(self.args.PRETRAINED_MODEL,
                         [('model', self.dfc_resnet)],
                         strict=True)
            print('Load Pretrained Model from', self.args.PRETRAINED_MODEL)

        self.i += 1
        complete = False
        # Only next() may legitimately signal exhaustion. Keep the try narrow
        # so a StopIteration from elsewhere is not mistaken for "done".
        try:
            item = next(self.flow_refinement_dataloader)
        except StopIteration:
            complete = True
        else:
            with torch.no_grad():
                input_x = item[0].to(self.args.device)
                flow_masked = item[1].to(self.args.device)
                mask = item[3].to(self.args.device)

                res_flow = self.dfc_resnet(input_x)
                # Prediction where the mask is 1, given flow elsewhere. The
                # channel offsets presumably select the target frame
                # -- TODO confirm against FlowInitial.FlowSeq.
                hole = mask[:, 10:11, :, :]
                res_complete = (res_flow * hole +
                                flow_masked[:, 10:12, :, :] * (1. - hole))

            # Write every sample of the batch, not only the first one.
            for b, output_dir in enumerate(item[4]):
                output_file = os.path.join(self.args.output_root,
                                           output_dir.split(',')[0])
                output_basedir = os.path.dirname(output_file)
                if output_basedir:
                    os.makedirs(output_basedir, exist_ok=True)
                res_save = res_complete[b].permute(
                    1, 2, 0).contiguous().cpu().numpy()
                cvb.write_flow(res_save, output_file)

        # len() of the DataLoader iterator is its number of batches, so this
        # fires right after the last batch has been written.
        if self.i == len(self.flow_refinement_dataloader) - 1 or complete:
            self.args.flow_root = self.args.output_root
            del self.flow_refinement_dataloader, self.dfc_resnet
            self.i = -1
            self.state += 1