Exemple #1
0
    def __getitem__(self, index):
        """Build the sample dict for the entry at position ``index``.

        The entry ``self.samples[index]`` is a mapping of file paths. The
        returned dict always holds ``'left'`` and ``'right'`` (from
        ``read_img``); ``'left_name'`` is added when ``self.save_filename`` is
        truthy, and ``'disp'`` / ``'pseudo_disp'`` are added when the matching
        path is not ``None``. ``self.transform`` (if set) is applied before
        any custom-dataset metadata is attached, so the metadata fields are
        never passed through the transform.

        For the ``custom_dataset_*`` names, ``'intrinsic'`` and ``'baseline'``
        are read from the pickled meta file; every custom dataset except
        ``custom_dataset_full`` also gets ``'label'`` and ``'object_ids'``,
        and every one of those except ``custom_dataset_obj`` gets
        ``'extrinsic'``.
        """
        paths = self.samples[index]
        sample = {}

        if self.save_filename:
            sample['left_name'] = paths['left_name']

        sample['left'] = read_img(paths['left'])  # [H, W, 3]
        sample['right'] = read_img(paths['right'])

        # GT disparity of subset is negative, finalpass and cleanpass is positive
        subset = 'subset' in self.dataset_name
        for key in ('disp', 'pseudo_disp'):
            if paths[key] is not None:
                sample[key] = read_disp(paths[key], subset=subset)  # [H, W]

        if self.transform is not None:
            sample = self.transform(sample)

        custom_names = ('custom_dataset_full', 'custom_dataset_sim',
                        'custom_dataset_real', 'custom_dataset_obj')
        if self.dataset_name not in custom_names:
            return sample

        # NOTE(review): read_pickle unpickles arbitrary objects; the meta
        # files must come from a trusted source.
        meta = pd.read_pickle(paths['meta'])
        sample['intrinsic'] = meta['intrinsic']
        # presumably the two extrinsics are 4x4 camera matrices and entry
        # [0][3] is the x translation -- TODO confirm against the meta writer
        extrinsic_delta = meta['extrinsic_l'] - meta['extrinsic_r']
        sample['baseline'] = abs(extrinsic_delta[0][3])

        if self.dataset_name == 'custom_dataset_full':
            return sample

        # Nearest-neighbour resampling keeps the label values intact (no
        # blending between neighbouring ids) when resizing to 960x540.
        label_image = Image.open(paths['label'])
        label_image = label_image.resize((960, 540), resample=Image.NEAREST)
        sample['label'] = np.array(label_image)
        sample['object_ids'] = meta['object_ids']
        if self.dataset_name != 'custom_dataset_obj':
            sample['extrinsic'] = meta['extrinsic']

        return sample
Exemple #2
0
    def __getitem__(self, index):
        """Return the stereo sample for the entry at position ``index``.

        The result is a dict with ``'left'`` and ``'right'`` images, plus
        ``'left_name'`` when ``self.save_filename`` is truthy and ``'disp'`` /
        ``'pseudo_disp'`` when the entry provides a path for them. If
        ``self.transform`` is set, its return value is returned instead of
        the raw dict.
        """
        entry = self.samples[index]
        sample = {'left_name': entry['left_name']} if self.save_filename else {}

        sample['left'] = read_img(entry['left'])  # [H, W, 3]
        sample['right'] = read_img(entry['right'])

        # GT disparity of subset is negative, finalpass and cleanpass is positive
        is_subset = 'subset' in self.dataset_name

        disp_path = entry['disp']
        if disp_path is not None:
            sample['disp'] = read_disp(disp_path, subset=is_subset)  # [H, W]

        pseudo_disp_path = entry['pseudo_disp']
        if pseudo_disp_path is not None:
            sample['pseudo_disp'] = read_disp(pseudo_disp_path,
                                              subset=is_subset)  # [H, W]

        if self.transform is None:
            return sample
        return self.transform(sample)
    def __getitem__(self, index):
        """Return the stereo sample for the entry at position ``index``.

        The result is a dict with ``'left'`` and ``'right'`` images, plus
        ``'left_name'`` when ``self.save_filename`` is truthy and ``'disp'`` /
        ``'pseudo_disp'`` when the entry provides a path for them. If
        ``self.transform`` is set, the sample is passed through it before
        being returned.
        """
        sample = {}
        paths = self.samples[index]

        if self.save_filename:
            sample['left_name'] = paths['left_name']

        # NOTE(review): read_img is given the whole path record and its
        # result is indexed by view name; it is invoked once per view.
        for view in ('left', 'right'):
            sample[view] = read_img(paths)[view]

        # GT disparity of subset is negative, finalpass and cleanpass is positive
        # TODO: this is supposed to detect 'subset', but that word does not
        # seem to appear anywhere; the earlier assert also requires
        # dataset_name to be a key of a dict that has no 'subset' entry.
        subset = 'subset' in self.dataset_name
        for key in ('disp', 'pseudo_disp'):
            disp_path = paths[key]
            if disp_path is not None:
                sample[key] = read_disp(disp_path, subset=subset)  # [H, W]

        if self.transform is not None:
            sample = self.transform(sample)

        return sample
Exemple #4
0
def _right_image_path(left_path):
    """Return the right-view path that pairs with ``left_path``.

    Only the last directory component and the file name are rewritten, so a
    'left' substring further up the path (for example inside the data root)
    is not touched.
    """
    left_dir, file_name = os.path.split(left_path)
    parent_dir, view_dir = os.path.split(left_dir)
    return os.path.join(parent_dir,
                        view_dir.replace('left', 'right'),
                        file_name.replace('left', 'right'))


def main():
    """Run AANet on every stereo pair under ``args.data_dir`` and save results.

    Left images are collected from ``<data_dir>/left/*.png``; the matching
    right image is looked up by swapping 'left' for 'right' in the last
    directory and the file name. Each pair is padded (top and right) so both
    sides are multiples of 48, fed through the network, cropped back to the
    original size and written to ``args.output_dir`` as
    ``<left stem>_<save_suffix>`` with an extension chosen by
    ``args.save_type`` ('pfm', 'npy', anything else -> 16-bit PNG).

    All configuration is read from the module-level ``args``. As a side
    effect a single trailing '/' is stripped from ``args.data_dir`` and
    ``args.img_height`` / ``args.img_width`` are overwritten with the padded
    size of the most recently processed image.
    """
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Test loader
    # NOTE(review): `transforms` is applied to a {'left', 'right'} dict
    # below, so it is presumably the project's own transforms module rather
    # than torchvision's -- confirm.
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])

    aanet = nets.AANet(
        args.max_disp,
        num_downsample=args.num_downsample,
        feature_type=args.feature_type,
        no_feature_mdconv=args.no_feature_mdconv,
        feature_pyramid=args.feature_pyramid,
        feature_pyramid_network=args.feature_pyramid_network,
        feature_similarity=args.feature_similarity,
        aggregation_type=args.aggregation_type,
        num_scales=args.num_scales,
        num_fusions=args.num_fusions,
        num_stage_blocks=args.num_stage_blocks,
        num_deform_blocks=args.num_deform_blocks,
        no_intermediate_supervision=args.no_intermediate_supervision,
        refinement_type=args.refinement_type,
        mdconv_dilation=args.mdconv_dilation,
        deformable_groups=args.deformable_groups).to(device)

    if os.path.exists(args.pretrained_aanet):
        print('=> Loading pretrained AANet:', args.pretrained_aanet)
        utils.load_pretrained_net(aanet, args.pretrained_aanet, no_strict=True)
    else:
        print('=> Using random initialization')

    if torch.cuda.device_count() > 1:
        print('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Inference
    aanet.eval()

    if args.data_dir.endswith('/'):
        args.data_dir = args.data_dir[:-1]

    all_samples = sorted(glob(args.data_dir + '/left/*.png'))

    num_samples = len(all_samples)
    print('=> %d samples found in the data dir' % num_samples)

    # Create the output directory up front so the first save cannot fail
    # after inference has already been run.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # Both sides of the network input are rounded up to a multiple of this.
    factor = 48

    for i, left_name in enumerate(all_samples):
        if i % 100 == 0:
            print('=> Inferencing %d/%d' % (i, num_samples))

        right_name = _right_image_path(left_name)

        sample = {'left': read_img(left_name), 'right': read_img(right_name)}
        sample = test_transform(sample)  # to tensor and normalize

        left = sample['left'].to(device).unsqueeze(0)  # [1, 3, H, W]
        right = sample['right'].to(device).unsqueeze(0)

        # Pad
        ori_height, ori_width = left.size()[2:]
        args.img_height = math.ceil(ori_height / factor) * factor
        args.img_width = math.ceil(ori_width / factor) * factor
        top_pad = args.img_height - ori_height
        right_pad = args.img_width - ori_width
        needs_pad = top_pad > 0 or right_pad > 0

        if needs_pad:
            # Pad size: (left_pad, right_pad, top_pad, bottom_pad)
            left = F.pad(left, (0, right_pad, top_pad, 0))
            right = F.pad(right, (0, right_pad, top_pad, 0))

        with torch.no_grad():
            pred_disp = aanet(left, right)[-1]  # [B, H, W]

        if pred_disp.size(-1) < left.size(-1):
            # Disparity is measured in pixels, so after resizing to the input
            # resolution the values are scaled by the same width ratio.
            width_scale = left.size(-1) / pred_disp.size(-1)
            pred_disp = pred_disp.unsqueeze(1)  # [B, 1, H, W]
            # align_corners=False is the documented default; passing it
            # explicitly keeps the result identical.
            pred_disp = F.interpolate(
                pred_disp, (left.size(-2), left.size(-1)),
                mode='bilinear', align_corners=False) * width_scale
            pred_disp = pred_disp.squeeze(1)  # [B, H, W]

        # Crop away exactly the rows/columns that were added by padding
        if needs_pad:
            if right_pad != 0:
                pred_disp = pred_disp[:, top_pad:, :-right_pad]
            else:
                pred_disp = pred_disp[:, top_pad:]

        disp = pred_disp[0].detach().cpu().numpy()  # [H, W]

        left_stem = os.path.splitext(os.path.basename(left_name))[0]
        save_stem = os.path.join(args.output_dir,
                                 left_stem + '_' + args.save_suffix)
        png_name = save_stem + '.png'

        if args.save_type == 'pfm':
            if args.visualize:
                # PNG stores disparity * 256 as unsigned 16-bit integers
                skimage.io.imsave(png_name, (disp * 256.).astype(np.uint16))
            write_pfm(save_stem + '.pfm', disp)
        elif args.save_type == 'npy':
            np.save(save_stem + '.npy', disp)
        else:
            # PNG stores disparity * 256 as unsigned 16-bit integers
            skimage.io.imsave(png_name, (disp * 256.).astype(np.uint16))