# Example 1
        return _img, _target

    def __str__(self):
        """Return a human-readable identifier that includes the dataset split."""
        return f"ISPRS(split={self.split})"


if __name__ == '__main__':
    import custom_transforms as tr
    from utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    composed_transforms_tr = transforms.Compose([tr.RandomHorizontalFlip(),
                                                 tr.RandomCrop(512),
                                                 tr.RandomRotate(15),
                                                 tr.ToTensor()]
                                                )

    isprs_train = ISPRSSegmentation(split='train', transform=composed_transforms_tr)

    dataloader = DataLoader(isprs_train, batch_size=5, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            tmp = np.squeeze(tmp, axis=0)
            segmap = decode_segmap(tmp, dataset='ISPRS')
# Example 2
		sample['label_name'] = img_name[:-4] + '.png'
		return sample


if __name__ == '__main__':
	print(os.getcwd())
	# from dataloaders import custom_transforms as tr
	import custom_transforms as tr
	from torch.utils.data import DataLoader
	from torchvision import transforms
	import matplotlib.pyplot as plt

	composed_transforms_tr = transforms.Compose([
		tr.RandomHorizontalFlip(),
		tr.RandomScale((0.5, 0.75)),
		tr.RandomCrop((512, 1024)),
		tr.RandomRotate(5),
		tr.ToTensor()])

	msrab_train= MSRAB(split='train',transform=composed_transforms_tr)

	dataloader = DataLoader(msrab_train, batch_size=2, shuffle=True, num_workers=2)

	for ii, sample in enumerate(dataloader):
		print(ii, sample["image"].size(), sample["label"].size(), type(sample["image"]), type(sample["label"]))
		for jj in range(sample["image"].size()[0]):
			img = sample['image'].numpy()
			gt = sample['label'].numpy()
			tmp = np.array(gt[jj]*255.0).astype(np.uint8)
			tmp = np.squeeze(tmp, axis=0)
			tmp = np.expand_dims(tmp, axis=2)
# Example 3
def main():
    """Train (or evaluate) an MVDNet depth-estimation model.

    Parses CLI args, builds train/validation data pipelines, constructs and
    optionally warm-starts the model, then runs the epoch loop: train,
    periodically validate, log errors to TensorBoard and tab-separated CSV
    files, and save checkpoints.

    NOTE(review): relies on many module-level names not visible here
    (`parser`, `save_path_formatter`, `SummaryWriter`, `custom_transforms`,
    `SequenceFolder`, `MVDNet`, `train`, `validate_with_gt`,
    `save_checkpoint`, `adjust_learning_rate`, `n_iter`).
    """
    # n_iter is a module-level step counter — presumably advanced inside
    # train(); declared global so that update is shared. TODO confirm.
    global n_iter
    args = parser.parse_args()
    output_dir = Path(args.output_dir)
    save_path = save_path_formatter(args, parser)
    # 'checkpoints' / Path and str + Path only work if save_path is a
    # path.py-style Path (it also exposes makedirs_p below) — not pathlib.
    args.save_path = 'checkpoints' / (args.exp + '_' + save_path)
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    # One writer for training scalars, three for per-sample validation output.
    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    for i in range(3):
        output_writers.append(SummaryWriter(args.save_path / 'valid' / str(i)))

    # Data loading code
    # Inputs are normalized to roughly [-1, 1] (mean/std 0.5 per channel).
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    if args.dataset == 'sceneflow':
        # SceneFlow: fixed-size crop controlled by CLI scale/crop options.
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomCrop(scale=args.scale,
                                         h=args.crop_h,
                                         w=args.crop_w),
            custom_transforms.ArrayToTensor(), normalize
        ])
    else:
        # Other datasets: default random scale-and-crop augmentation.
        train_transform = custom_transforms.Compose([
            custom_transforms.RandomScaleCrop(),
            custom_transforms.ArrayToTensor(), normalize
        ])

    # Validation gets no augmentation, only tensor conversion + normalize.
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    # Validation uses a separate split file (ttype2) and an optional index.
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             ttype=args.ttype2,
                             dataset=args.dataset,
                             index=args.index)

    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               ttype=args.ttype,
                               dataset=args.dataset)

    # Drop the trailing partial batch so every training batch is full-sized.
    train_set.samples = train_set.samples[:len(train_set) -
                                          len(train_set) % args.batch_size]

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # Validation runs one sample at a time, in dataset order.
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # epoch_size == 0 means "use the whole training set each epoch".
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    # SceneFlow overrides the minimum depth with a dataset-specific constant.
    if args.dataset == 'sceneflow':
        args.mindepth = 5.45
    mvdnet = MVDNet(args.nlabel, args.mindepth)
    #mvdnet = convert_model(mvdnet)

    if args.pretrained_mvdn:
        # Full warm start from a previous MVDNet checkpoint (strict load).
        print("=> using pre-trained weights for MVDNet")
        weights = torch.load(args.pretrained_mvdn)
        mvdnet.init_weights()
        mvdnet.load_state_dict(weights['state_dict'])
    elif args.pretrained_dps:
        # Partial warm start from DPSNet weights; strict=False tolerates
        # layers that do not exist in the DPS checkpoint.
        print("=> using pre-trained DPS weights for MVDNet")
        weights = torch.load('pretrained/dpsnet_updated.pth.tar')
        mvdnet.init_weights()
        mvdnet.load_state_dict(weights['state_dict'], strict=False)
    else:
        mvdnet.init_weights()

    print('=> setting adam solver')

    # NOTE(review): args.momentum is reused as Adam's beta1 — confirm the
    # CLI defaults are sensible Adam betas, not SGD-style momentum values.
    optimizer = torch.optim.Adam(mvdnet.parameters(),
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    cudnn.benchmark = True
    mvdnet = torch.nn.DataParallel(mvdnet)
    mvdnet = mvdnet.cuda()

    print(' ==> setting log files')
    # Two tab-separated logs: per-epoch summary and per-iteration train loss.
    # Opened with 'w' here to write headers (truncates any previous run's log).
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'validation_abs_rel', 'validation_abs_diff',
            'validation_sq_rel', 'validation_a1', 'validation_a2',
            'validation_a3', 'mean_angle_error'
        ])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    print(' ==> main Loop')
    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch
        if args.evaluate:
            # Evaluation-only mode: skip training, report loss as 0.
            train_loss = 0
        else:
            train_loss = train(args, train_loader, mvdnet, optimizer,
                               args.epoch_size, training_writer, epoch)
        # Validate only every 3rd epoch (and always when evaluating);
        # otherwise log zeros as placeholders for the error metrics.
        if not args.evaluate and (args.skip_v or (epoch + 1) % 3 != 0):
            error_names = [
                'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'angle'
            ]
            errors = [0] * 7
        else:
            errors, error_names = validate_with_gt(args, val_loader, mvdnet,
                                                   epoch, output_writers)

        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to chose the most relevant error to measure your model's performance, careful some measures are to maximize (such as a1,a2,a3)
        # errors[0] is abs_rel per the error_names ordering above.
        decisive_error = errors[0]
        # Append this epoch's row to the summary log ('a' mode).
        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                train_loss, decisive_error, errors[1], errors[2], errors[3],
                errors[4], errors[5], errors[6]
            ])
        if args.evaluate:
            break
        # Unwrap DataParallel (.module) so the checkpoint loads without it.
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': mvdnet.module.state_dict()
        },
                        epoch,
                        file_prefixes=['mvdnet'])