예제 #1
0
def main():
    """Evaluate a pretrained PSMNet model on the test split.

    Configuration is read from the module-level ``args`` namespace and progress
    is reported through the module-level ``logger``. The network selected by
    ``args.model`` is built, moved to CUDA, initialised from
    ``args.pretrained_net`` and passed to ``testOnTestSet``. The call also
    receives the test loader and a TensorBoard ``SummaryWriter`` rooted at
    ``<args.checkpoint_dir>/tensorBoard``.

    Returns early (after logging) when ``args.pretrained_net`` is None.

    Raises:
        ValueError: if ``args.test_batch_size`` is not 1 or ``args.model`` is
            not a supported architecture.
    """
    # Validate the configuration before doing any expensive setup.
    if args.test_batch_size != 1:
        raise ValueError("test_batch_size must be 1.")

    if args.model == 'PSMNet_stackhourglass':
        net = PSMNet_stackhourglass(args.max_disp)
    elif args.model == 'PSMNet_basic':
        net = PSMNet_basic(args.max_disp)
    else:
        # Nothing below can run without ``net``; stop with a clear message.
        raise ValueError('no model: unsupported args.model %r' % (args.model,))

    if args.pretrained_net is None:
        logger.info('=>  args.pretrained_net is None! Please specify it!!!')
        return

    # Validation loader
    test_transform_list = [
        myTransforms.RandomCrop(args.test_img_height,
                                args.test_img_width,
                                validate=True),
        myTransforms.ToTensor(),
        myTransforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ]

    test_transform = myTransforms.Compose(test_transform_list)
    test_data = StereoDataset(data_dir=args.data_dir,
                              isDebug=args.isDebug,
                              dataset_name=args.dataset_name,
                              mode='test',
                              transform=test_transform)

    logger.info('=> {} test samples found in the test set'.format(
        len(test_data)))

    test_loader = DataLoader(dataset=test_data,
                             batch_size=args.test_batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=True,
                             drop_last=False)

    net.cuda()

    logger.info('=> Loading pretrained Net: %s' % args.pretrained_net)
    # Strictness is configurable so a partially pretrained model can be used.
    utils.load_pretrained_net(net,
                              args.pretrained_net,
                              strict=args.strict,
                              logger=logger)

    board_writer = SummaryWriter(
        os.path.join(args.checkpoint_dir, "tensorBoard"))
    try:
        logger.info('=> Start testing...')
        testOnTestSet(net,
                      test_loader,
                      args.dataset_name,
                      board_writer,
                      mode="test",
                      epoch=0)
    finally:
        # Flush pending summaries and release the event file.
        board_writer.close()
    def __init__(
        self,
        config,
        device,
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=StereoMatcherBase.IMAGENET_MEAN,
                    std=StereoMatcherBase.IMAGENET_STD,
                ),
            ]
        ),
    ):
        """Create the AANet matcher: build the network, load weights, set eval mode.

        Args:
            config: options object. ``max_disp`` is passed positionally to
                ``nets.AANet``. Each AANet option listed below is forwarded as
                the keyword argument of the same name. ``model_path`` is the
                checkpoint, loaded with ``no_strict=True``.
            device: passed to ``StereoMatcherBase`` and used as the target of
                ``.to()`` for the network.
            transform: input preprocessing handed to ``StereoMatcherBase``.
                NOTE(review): the default ``Compose`` is evaluated once, when
                the ``def`` runs. It is shared by every instance that does not
                pass its own transform.
        """
        StereoMatcherBase.__init__(self, config=config, device=device, transform=transform)

        max_disp = config.max_disp
        # AANet keyword options; each is read from the config attribute of the same name.
        aanet_option_names = (
            "num_downsample",
            "feature_type",
            "no_feature_mdconv",
            "feature_pyramid",
            "feature_pyramid_network",
            "feature_similarity",
            "aggregation_type",
            "num_scales",
            "num_fusions",
            "num_stage_blocks",
            "num_deform_blocks",
            "no_intermediate_supervision",
            "refinement_type",
            "mdconv_dilation",
            "deformable_groups",
        )
        aanet_options = {name: getattr(config, name) for name in aanet_option_names}

        self._model = nets.AANet(max_disp, **aanet_options)
        self._model = self._model.to(device)
        utils.load_pretrained_net(self._model, config.model_path, no_strict=True)

        self._model.eval()
예제 #3
0
def main():
    """Run AANet inference over a test split and save the predicted disparities.

    Configuration comes from the module-level ``args`` namespace.

    * Each batch is zero-padded on the top and right up to ``args.img_height``
      x ``args.img_width`` when it is smaller.
    * The prediction is bilinearly upsampled to the padded input size when the
      network returns a narrower map; values are scaled by the width ratio.
    * The padding is cropped off and the result is written under
      ``args.output_dir`` in the format selected by ``args.save_type``.

    With ``args.count_time`` set, only the first ``args.num_images`` batches
    are run (after a 10-pass warm-up), nothing is written, and only the mean
    per-image inference time is reported.
    """
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'

    # Test loader
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])
    test_data = dataloader.StereoDataset(data_dir=args.data_dir,
                                         dataset_name=args.dataset_name,
                                         mode=args.mode,
                                         save_filename=True,
                                         transform=test_transform)
    test_loader = DataLoader(dataset=test_data, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers, pin_memory=True, drop_last=False)

    aanet = nets.AANet(args.max_disp,
                       num_downsample=args.num_downsample,
                       feature_type=args.feature_type,
                       no_feature_mdconv=args.no_feature_mdconv,
                       feature_pyramid=args.feature_pyramid,
                       feature_pyramid_network=args.feature_pyramid_network,
                       feature_similarity=args.feature_similarity,
                       aggregation_type=args.aggregation_type,
                       num_scales=args.num_scales,
                       num_fusions=args.num_fusions,
                       num_stage_blocks=args.num_stage_blocks,
                       num_deform_blocks=args.num_deform_blocks,
                       no_intermediate_supervision=args.no_intermediate_supervision,
                       refinement_type=args.refinement_type,
                       mdconv_dilation=args.mdconv_dilation,
                       deformable_groups=args.deformable_groups).to(device)

    # os.path.exists(None) raises TypeError, so check for None first.
    if args.pretrained_aanet is not None and os.path.exists(args.pretrained_aanet):
        print('=> Loading pretrained AANet:', args.pretrained_aanet)
        utils.load_pretrained_net(aanet, args.pretrained_aanet, no_strict=True)
    else:
        print('=> Using random initialization')

    # Save parameters
    num_params = utils.count_parameters(aanet)
    print('=> Number of trainable parameters: %d' % num_params)

    if torch.cuda.device_count() > 1:
        print('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Inference
    aanet.eval()

    inference_time = 0
    num_imgs = 0

    # len(test_loader) counts batches; report the number of samples separately.
    num_batches = len(test_loader)
    print('=> %d samples found in the test set' % len(test_data))

    for i, sample in enumerate(test_loader):
        if args.count_time and i == args.num_images:  # testing time only
            break

        if i % 100 == 0:
            print('=> Inferencing %d/%d' % (i, num_batches))

        left = sample['left'].to(device)  # [B, 3, H, W]
        right = sample['right'].to(device)

        # Pad (top and right) up to the configured network input size
        ori_height, ori_width = left.size()[2:]
        needs_pad = ori_height < args.img_height or ori_width < args.img_width
        if needs_pad:
            top_pad = args.img_height - ori_height
            right_pad = args.img_width - ori_width

            # Pad size: (left_pad, right_pad, top_pad, bottom_pad)
            left = F.pad(left, (0, right_pad, top_pad, 0))
            right = F.pad(right, (0, right_pad, top_pad, 0))

        # Warmup
        if i == 0 and args.count_time:
            with torch.no_grad():
                for _ in range(10):
                    aanet(left, right)

        num_imgs += left.size(0)

        with torch.no_grad():
            # CUDA kernels run asynchronously: synchronize on both sides so the
            # timer measures the forward pass, not just the kernel launches.
            if use_cuda:
                torch.cuda.synchronize()
            time_start = time.perf_counter()
            pred_disp = aanet(left, right)[-1]  # [B, H, W]
            if use_cuda:
                torch.cuda.synchronize()
            inference_time += time.perf_counter() - time_start

        if pred_disp.size(-1) < left.size(-1):
            # Values are multiplied by the width ratio of the upsampling.
            scale = left.size(-1) / pred_disp.size(-1)
            # An explicit output size is given, so recompute_scale_factor must
            # not be set (recent PyTorch raises ValueError for that combination).
            pred_disp = F.interpolate(pred_disp.unsqueeze(1),  # [B, 1, H, W]
                                      (left.size(-2), left.size(-1)),
                                      mode='bilinear', align_corners=True) * scale
            pred_disp = pred_disp.squeeze(1)  # [B, H, W]

        # Crop the padding back off
        if needs_pad:
            if right_pad != 0:
                pred_disp = pred_disp[:, top_pad:, :-right_pad]
            else:
                pred_disp = pred_disp[:, top_pad:]

        if args.count_time:
            continue  # timing-only run: nothing is written

        for b in range(pred_disp.size(0)):
            disp = pred_disp[b].detach().cpu().numpy()  # [H, W]
            save_name = os.path.join(args.output_dir, sample['left_name'][b])
            utils.check_path(os.path.dirname(save_name))
            # Swap the extension regardless of its length.
            stem = os.path.splitext(save_name)[0]
            if args.save_type == 'pfm':
                if args.visualize:
                    skimage.io.imsave(save_name, (disp * 256.).astype(np.uint16))

                write_pfm(stem + '.pfm', disp)
            elif args.save_type == 'npy':
                np.save(stem + '.npy', disp)
            else:
                skimage.io.imsave(save_name, (disp * 256.).astype(np.uint16))

    if num_imgs == 0:
        print('=> No images were processed')
        return
    print('=> Mean inference time for %d images: %.3fs' % (num_imgs, inference_time / num_imgs))
예제 #4
0
파일: train.py 프로젝트: hx-Tang/SAN-stereo
def main():
    """Train (or only validate) AANet using the module-level ``args`` and ``logger``.

    This function does the following:

    * Builds the train and validation loaders.
    * Builds the AANet network, optionally initialised from
      ``args.pretrained_aanet`` and wrapped in ``DataParallel`` when several
      GPUs are visible.
    * Creates an Adam optimizer whose ``utils.filter_specific_params`` group
      uses 0.1x the base learning rate.
    * Sets up an optional ``MultiStepLR`` schedule.
    * Resumes from the latest checkpoint in ``args.checkpoint_dir`` when
      ``args.resume`` is set.

    Raises:
        ValueError: if ``args.evaluate_only`` is set and ``args.val_batch_size`` != 1.
        NotImplementedError: for an unsupported ``args.lr_scheduler_type``.
    """
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Train loader
    train_transform_list = [transforms.RandomCrop(args.img_height, args.img_width),
                            transforms.RandomColor(),
                            transforms.RandomVerticalFlip(),
                            transforms.ToTensor(),
                            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
                            ]
    train_transform = transforms.Compose(train_transform_list)

    train_data = dataloader.StereoDataset(data_dir=args.data_dir,
                                          dataset_name=args.dataset_name,
                                          mode='train' if args.mode != 'train_all' else 'train_all',
                                          load_pseudo_gt=args.load_pseudo_gt,
                                          transform=train_transform)

    logger.info('=> {} training samples found in the training set'.format(len(train_data)))

    train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, pin_memory=True, drop_last=True)

    # Validation loader
    val_transform_list = [transforms.RandomCrop(args.val_img_height, args.val_img_width, validate=True),
                          transforms.ToTensor(),
                          transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
                          ]
    val_transform = transforms.Compose(val_transform_list)
    val_data = dataloader.StereoDataset(data_dir=args.data_dir,
                                        dataset_name=args.dataset_name,
                                        mode=args.mode,
                                        transform=val_transform)

    val_loader = DataLoader(dataset=val_data, batch_size=args.val_batch_size, shuffle=False,
                            num_workers=args.num_workers, pin_memory=True, drop_last=False)

    # Network
    aanet = nets.AANet(args.max_disp,
                       num_downsample=args.num_downsample,
                       feature_type=args.feature_type,
                       no_feature_mdconv=args.no_feature_mdconv,
                       feature_pyramid=args.feature_pyramid,
                       feature_pyramid_network=args.feature_pyramid_network,
                       feature_similarity=args.feature_similarity,
                       aggregation_type=args.aggregation_type,
                       num_scales=args.num_scales,
                       num_fusions=args.num_fusions,
                       num_stage_blocks=args.num_stage_blocks,
                       num_deform_blocks=args.num_deform_blocks,
                       no_intermediate_supervision=args.no_intermediate_supervision,
                       refinement_type=args.refinement_type,
                       mdconv_dilation=args.mdconv_dilation,
                       deformable_groups=args.deformable_groups).to(device)

    logger.info('%s' % aanet)

    if args.pretrained_aanet is not None:
        logger.info('=> Loading pretrained AANet: %s' % args.pretrained_aanet)
        # Enable training from a partially pretrained model
        utils.load_pretrained_net(aanet, args.pretrained_aanet, no_strict=(not args.strict))

    if torch.cuda.device_count() > 1:
        logger.info('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Save parameters: record the parameter count as an empty marker file.
    num_params = utils.count_parameters(aanet)
    logger.info('=> Number of trainable parameters: %d' % num_params)
    save_name = '%d_parameters' % num_params
    with open(os.path.join(args.checkpoint_dir, save_name), 'a'):
        pass

    # Optimizer
    # Learning rate for offset learning is set 0.1 times those of existing layers
    specific_params = [param for _, param in filter(utils.filter_specific_params,
                                                    aanet.named_parameters())]
    base_params = [param for _, param in filter(utils.filter_base_params,
                                                aanet.named_parameters())]

    specific_lr = args.learning_rate * 0.1
    params_group = [
        {'params': base_params, 'lr': args.learning_rate},
        {'params': specific_params, 'lr': specific_lr},
    ]

    optimizer = torch.optim.Adam(params_group, weight_decay=args.weight_decay)

    # Resume training
    if args.resume:
        # AANet
        start_epoch, start_iter, best_epe, best_epoch = utils.resume_latest_ckpt(
            args.checkpoint_dir, aanet, 'aanet')

        # Optimizer
        utils.resume_latest_ckpt(args.checkpoint_dir, optimizer, 'optimizer')
    else:
        start_epoch = 0
        start_iter = 0
        best_epe = None
        best_epoch = None

    # LR scheduler (None when no schedule is configured)
    lr_scheduler = None
    if args.lr_scheduler_type is not None:
        last_epoch = start_epoch if args.resume else start_epoch - 1
        if args.lr_scheduler_type == 'MultiStepLR':
            milestones = [int(step) for step in args.milestones.split(',')]
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                                milestones=milestones,
                                                                gamma=args.lr_decay_gamma,
                                                                last_epoch=last_epoch)
        else:
            raise NotImplementedError(
                'Unsupported lr_scheduler_type: %s' % args.lr_scheduler_type)

    train_model = model.Model(args, logger, optimizer, aanet, device, start_iter, start_epoch,
                              best_epe=best_epe, best_epoch=best_epoch)

    logger.info('=> Start training...')

    if args.evaluate_only:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if args.val_batch_size != 1:
            raise ValueError(
                'evaluate_only requires val_batch_size == 1, got %s' % args.val_batch_size)
        train_model.validate(val_loader)
    else:
        for _ in range(start_epoch, args.max_epoch):
            train_model.train(train_loader)
            if not args.no_validate:
                train_model.validate(val_loader)
            if lr_scheduler is not None:
                lr_scheduler.step()

        logger.info('=> End training\n\n')
예제 #5
0
def main():
    """Predict disparity for every ``<args.data_dir>/left/*.png`` stereo pair.

    The matching right image is found by replacing ``left`` with ``right`` in
    the part of the path below ``args.data_dir``. Each pair is normalised and
    zero-padded (top and right) up to the next multiple of ``factor``. The
    prediction is upsampled if narrower than the input, and the padding is
    cropped off. The result is written to ``args.output_dir`` as
    ``<name>_<args.save_suffix>`` in the format chosen by ``args.save_type``.
    Module-level ``args`` is read but no longer modified.
    """
    # For reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)

    torch.backends.cudnn.benchmark = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Test loader
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
    ])

    aanet = nets.AANet(
        args.max_disp,
        num_downsample=args.num_downsample,
        feature_type=args.feature_type,
        no_feature_mdconv=args.no_feature_mdconv,
        feature_pyramid=args.feature_pyramid,
        feature_pyramid_network=args.feature_pyramid_network,
        feature_similarity=args.feature_similarity,
        aggregation_type=args.aggregation_type,
        num_scales=args.num_scales,
        num_fusions=args.num_fusions,
        num_stage_blocks=args.num_stage_blocks,
        num_deform_blocks=args.num_deform_blocks,
        no_intermediate_supervision=args.no_intermediate_supervision,
        refinement_type=args.refinement_type,
        mdconv_dilation=args.mdconv_dilation,
        deformable_groups=args.deformable_groups).to(device)

    # os.path.exists(None) raises TypeError, so check for None first.
    if args.pretrained_aanet is not None and os.path.exists(args.pretrained_aanet):
        print('=> Loading pretrained AANet:', args.pretrained_aanet)
        utils.load_pretrained_net(aanet, args.pretrained_aanet, no_strict=True)
    else:
        print('=> Using random initialization')

    if torch.cuda.device_count() > 1:
        print('=> Use %d GPUs' % torch.cuda.device_count())
        aanet = torch.nn.DataParallel(aanet)

    # Inference
    aanet.eval()

    data_dir = args.data_dir[:-1] if args.data_dir.endswith('/') else args.data_dir

    all_samples = sorted(glob(data_dir + '/left/*.png'))

    num_samples = len(all_samples)
    print('=> %d samples found in the data dir' % num_samples)

    # Make sure the output directory exists before the first write.
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    # Inputs are padded up to the next multiple of this factor
    # (presumably required by the network's downsampling; confirm).
    factor = 48

    for i, left_name in enumerate(all_samples):
        if i % 100 == 0:
            print('=> Inferencing %d/%d' % (i, num_samples))

        # Only rewrite the path below data_dir, so a 'left' substring inside
        # data_dir itself is not turned into 'right'.
        right_name = data_dir + left_name[len(data_dir):].replace('left', 'right')

        sample = {'left': read_img(left_name), 'right': read_img(right_name)}
        sample = test_transform(sample)  # to tensor and normalize

        left = sample['left'].to(device).unsqueeze(0)  # [1, 3, H, W]
        right = sample['right'].to(device).unsqueeze(0)

        # Pad (top and right) up to the next multiple of ``factor``
        ori_height, ori_width = left.size()[2:]
        img_height = math.ceil(ori_height / factor) * factor
        img_width = math.ceil(ori_width / factor) * factor
        top_pad = img_height - ori_height
        right_pad = img_width - ori_width
        needs_pad = top_pad > 0 or right_pad > 0

        if needs_pad:
            # Pad size: (left_pad, right_pad, top_pad, bottom_pad)
            left = F.pad(left, (0, right_pad, top_pad, 0))
            right = F.pad(right, (0, right_pad, top_pad, 0))

        with torch.no_grad():
            pred_disp = aanet(left, right)[-1]  # [B, H, W]

        if pred_disp.size(-1) < left.size(-1):
            # Values are multiplied by the width ratio of the upsampling.
            scale = left.size(-1) / pred_disp.size(-1)
            pred_disp = F.interpolate(
                pred_disp.unsqueeze(1), (left.size(-2), left.size(-1)),  # [B, 1, H, W]
                mode='bilinear', align_corners=False) * scale
            pred_disp = pred_disp.squeeze(1)  # [B, H, W]

        # Crop the padding back off
        if needs_pad:
            if right_pad != 0:
                pred_disp = pred_disp[:, top_pad:, :-right_pad]
            else:
                pred_disp = pred_disp[:, top_pad:]

        disp = pred_disp[0].detach().cpu().numpy()  # [H, W]

        base_name = os.path.splitext(os.path.basename(left_name))[0]
        stem = os.path.join(args.output_dir, base_name + '_' + args.save_suffix)
        png_name = stem + '.png'

        if args.save_type == 'pfm':
            if args.visualize:
                skimage.io.imsave(png_name, (disp * 256.).astype(np.uint16))

            write_pfm(stem + '.pfm', disp)
        elif args.save_type == 'npy':
            np.save(stem + '.npy', disp)
        else:
            skimage.io.imsave(png_name, (disp * 256.).astype(np.uint16))
예제 #6
0
import numpy as np
import matplotlib.pyplot as plt
import cv2
from utils import utils
from dataloader import transforms
import os

# Training-time pipeline: convert to a PIL image, then apply the project's
# RandomContrast and RandomBrightness augmentations, in that order.
# No tensor conversion or normalization step is part of this pipeline.
train_transform_list = [
    make_transform()
    for make_transform in (
        transforms.ToPILImage,
        transforms.RandomContrast,
        transforms.RandomBrightness,
    )
]
train_transform = transforms.Compose(train_transform_list)


def gen_error_colormap(sim, real):
    """Set up an RGB error visualization comparing two 2-D maps.

    Args:
        sim: 2-D array; its shape gives ``H, W`` of the output image.
            (Presumably a simulated/predicted disparity map -- confirm.)
        real: array compared element-wise against ``sim`` via
            ``np.abs(sim - real)``. (Presumably the reference map -- confirm.)

    NOTE(review): this body appears truncated. It computes ``error`` and
    allocates an all-zero float32 ``error_image`` of shape [H, W, 3], but it
    never fills or returns either. As written, the function returns None.
    """
    # One row per colour bin. Columns 2:5 hold 0-255 RGB values, rescaled to
    # [0, 1] below. Columns 0:2 look like ascending [lower, upper) error bounds
    # -- assumed from the values, confirm against the missing fill logic.
    cols = np.array(
        [[0, 0.5, 49, 54, 149], [0.5, 1, 69, 117, 180], [1, 2, 116, 173, 209],
         [2, 4, 171, 217, 233], [4, 8, 224, 243, 248], [8, 16, 254, 224, 144],
         [16, 32, 253, 174, 97], [32, 64, 244, 109, 67],
         [64, 128, 215, 48, 39], [128, 256, 165, 0, 38]],
        dtype=np.float32)
    cols[:, 2:5] /= 255.

    H, W = sim.shape
    # Per-pixel absolute difference between the two maps.
    error = np.abs(sim - real)

    # Output RGB image, initialised to black.
    error_image = np.zeros([H, W, 3], dtype=np.float32)
예제 #7
0
def getDataLoader(args, logger):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        args: parsed options. Reads image and crop sizes, ``data_dir``,
            ``isDebug``, ``dataset_name``, ``mode``, ``load_pseudo_gt``,
            ``distributed``, ``batch_size``, ``val_batch_size`` and
            ``num_workers``.
        logger: used to report the size of each split.

    Returns:
        ``(train_loader, val_loader)``.

    Distributed training notes:
        When ``args.distributed`` is set, the training loader draws indices
        from a ``DistributedSampler`` (which shuffles by default), and the
        loader's own shuffling is turned off. ``args.batch_size`` is then the
        per-process batch size, so the global batch is ``batch_size`` times
        the number of processes (world size).
    """
    # ---------------- training split ----------------
    train_transform = myTransforms.Compose([
        myTransforms.RandomCrop(args.img_height, args.img_width),
        myTransforms.RandomColor(),
        myTransforms.RandomVerticalFlip(),
        # Converts the image to a tensor, dividing by 255.0 so pixel values fall
        # in [0, 1], and reorders [H, W, C=3] -> [C=3, H, W].
        myTransforms.ToTensor(),
        # Then normalizes again using the ImageNet mean and std.
        myTransforms.Normalize(mean=IMAGENET_MEAN,
                               std=IMAGENET_STD),
    ])

    train_mode = 'train_all' if args.mode == 'train_all' else 'train'
    train_data = StereoDataset(
        data_dir=args.data_dir,
        isDebug=args.isDebug,
        dataset_name=args.dataset_name,
        mode=train_mode,
        load_pseudo_gt=args.load_pseudo_gt,
        transform=train_transform)

    logger.info('=> {} training samples found in the training set'.format(
        len(train_data)))

    # Distributed training: DistributedSampler shuffles by default, and PyTorch
    # forbids combining a custom sampler with shuffle=True, so the loader only
    # shuffles on its own in the non-distributed case.
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    shuffle_train = not args.distributed

    train_loader = DataLoader(dataset=train_data,
                              batch_size=args.batch_size,  # per-process batch size
                              num_workers=args.num_workers,
                              shuffle=shuffle_train,
                              pin_memory=True,
                              drop_last=True,
                              sampler=train_sampler)

    # ---------------- validation split ----------------
    val_transform = myTransforms.Compose([
        myTransforms.RandomCrop(args.val_img_height,
                                args.val_img_width,
                                validate=True),
        myTransforms.ToTensor(),
        myTransforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    val_data = StereoDataset(data_dir=args.data_dir,
                             isDebug=args.isDebug,
                             dataset_name=args.dataset_name,
                             mode='val',
                             transform=val_transform)
    logger.info('=> {} val samples found in the val set'.format(len(val_data)))

    val_loader = DataLoader(dataset=val_data,
                            batch_size=args.val_batch_size,
                            shuffle=False,
                            num_workers=args.num_workers,
                            pin_memory=True,
                            drop_last=False)

    return train_loader, val_loader