示例#1
0
def dataflow_test():
    """Smoke-test the data pipeline.

    Builds the training and validation loaders for the hard-coded KITTI
    directory, prints how many batches each one yields, then pulls a single
    training batch and prints the number of images and the intrinsics shape.
    """
    from DataFlow.sequence_folders import SequenceFolder
    import custom_transforms
    from torch.utils.data import DataLoader
    from DataFlow.validation_folders import ValidationSet

    normalizer = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    augmenting_pipeline = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalizer,
    ])
    plain_pipeline = custom_transforms.Compose([custom_transforms.ArrayToTensor(), normalizer])

    kitti_root = 'G:/data/KITTI/KittiRaw_formatted'
    # Options shared by both loaders; only `shuffle` differs between them.
    shared_loader_options = dict(batch_size=4, num_workers=4, pin_memory=True)

    training_data = SequenceFolder(kitti_root,
                                   transform=augmenting_pipeline,
                                   seed=8964,
                                   train=True,
                                   sequence_length=3)
    training_batches = DataLoader(training_data, shuffle=True, **shared_loader_options)

    validation_data = ValidationSet(kitti_root, transform=plain_pipeline)
    print("length of train loader is %d" % len(training_batches))
    validation_batches = DataLoader(validation_data, shuffle=False, **shared_loader_options)
    print("length of val loader is %d" % len(validation_batches))

    # Fetch exactly one batch to check that samples can actually be produced.
    imgs, intrinsics = next(iter(training_batches))
    print(len(imgs))
    print(intrinsics.shape)
def main():
    """Run the pose network over one sequence and save the predicted poses.

    Reads the command line (``args.config``, ``args.seq``,
    ``args.use_latest_not_best``), loads either the best or the latest
    PoseNet checkpoint from ``config.save_path``, feeds every sample of the
    selected sequence through the network and writes the collected poses to
    ``<save_path>/results/<seq>[-latest]/poses.npy``.

    Raises:
        FileExistsError: if the output directory already exists
            (``os.makedirs`` is called without ``exist_ok``).
    """
    args = parser.parse_args()
    config = load_config(args.config)

    # directory for storing output
    output_dir = os.path.join(config.save_path, "results", args.seq)

    # pick the checkpoint to evaluate
    if args.use_latest_not_best:
        config.posenet = os.path.join(config.save_path,
                                      "posenet_checkpoint.pth.tar")
        output_dir = output_dir + "-latest"
    else:
        config.posenet = os.path.join(config.save_path, "posenet_best.pth.tar")

    # Kept strict on purpose: this fails instead of silently overwriting the
    # results of a previous run.
    os.makedirs(output_dir)

    # define transformations
    transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # load dataset
    dataset = SequenceFolder(
        config.data,
        cameras=config.cameras,
        gray=config.gray,
        sequence_length=config.sequence_length,
        shuffle=False,
        train=False,
        transform=transform,
        sequence=args.seq,
    )

    # Element 1 of a sample is the tensor fed to the network below (tgt_lf);
    # its first dimension is used as the input channel count.
    input_channels = dataset[0][1].shape[0]

    # define posenet
    pose_net = PoseNet(in_channels=input_channels,
                       nb_ref_imgs=2,
                       output_exp=False).to(device)
    # map_location lets a checkpoint saved on another device be loaded here.
    weights = torch.load(config.posenet, map_location=device)
    pose_net.load_state_dict(weights['state_dict'])
    pose_net.eval()

    # prediction
    poses = []
    n_samples = len(dataset)
    # Inference only: without no_grad() the outputs track gradients, which
    # wastes memory and makes `.numpy()` raise on the result.
    with torch.no_grad():
        for i, (_tgt, tgt_lf, _ref, ref_lf, _k, _kinv, _pose_gt) in enumerate(dataset):
            print("{:03d}/{:03d}".format(i + 1, n_samples), end="\r")
            # Only the *_lf inputs are consumed by the network, so only they
            # are batched and moved to the device.
            tgt_lf = tgt_lf.unsqueeze(0).to(device)
            ref_lf = [r.unsqueeze(0).to(device) for r in ref_lf]
            _, pose = pose_net(tgt_lf, ref_lf)
            poses.append(pose[0, 1, :].cpu().numpy())

    # save poses
    poses_file = os.path.join(output_dir, "poses.npy")
    np.save(poses_file, poses)
    print("\nok")
示例#3
0
def main():
    """Parse the command line, build the augmented training loader and print its length."""
    global args, best_error, n_iter
    args = parser.parse_args()

    normalizer = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                             std=[0.5, 0.5, 0.5])

    # Augmentations run on the raw arrays; tensor conversion and
    # normalisation come last.
    pipeline_steps = [
        custom_transforms.RandomRotate(),
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalizer,
    ]
    augmenting_pipeline = custom_transforms.Compose(pipeline_steps)

    training_data = SequenceFolder(
        args.data,
        transform=augmenting_pipeline,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,
    )
    training_batches = torch.utils.data.DataLoader(
        training_data,
        batch_size=4,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    print(len(training_batches))
 def __init__(self, ds_path, seed=0, sequence_length=3, num_scales=4):
     """Seed the RNGs, index the samples and precompute the camera intrinsics.

     Args:
         ds_path: dataset location handed to ``get_lists``.
         seed: seed applied to both ``numpy.random`` and ``random``.
         sequence_length: forwarded to ``get_samples``.
         num_scales: number of scales passed to the multi-scale intrinsics helpers.
     """
     random.seed(seed)
     np.random.seed(seed)

     self.num_scales = num_scales

     self.samples = get_samples(get_lists(ds_path), sequence_length)

     # Calibration matrix obtained from calib.py.
     calibration = np.array([
         [1.14183754e+03, 0.00000000e+00, 6.28283670e+02],
         [0.00000000e+00, 1.13869492e+03, 3.56277189e+02],
         [0.00000000e+00, 0.00000000e+00, 1.00000000e+00],
     ]).astype(np.float32)

     # Images are resized 1280 x 720 -> 416 x 128, so the first row is scaled
     # by the width ratio and the second row by the height ratio.
     width_ratio = 416.0 / 1280.0
     height_ratio = 128.0 / 720.0
     calibration[0] = calibration[0] * width_ratio
     calibration[1] = calibration[1] * height_ratio
     self.intrinsics = calibration

     self.ms_k = get_multi_scale_intrinsics(self.intrinsics, self.num_scales)
     self.ms_inv_k = get_multi_scale_inv_intrinsics(self.intrinsics, self.num_scales)

     # Two conversion pipelines: plain tensor conversion, and tensor
     # conversion followed by per-channel normalisation.
     self.to_tensor = custom_transforms.Compose([custom_transforms.ArrayToTensor()])
     self.to_tensor_norm = custom_transforms.Compose([
         custom_transforms.ArrayToTensor(),
         custom_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
     ])
 def transforms_val(self, sample):
     """Run ``sample`` through the validation pipeline: resize, normalise, convert to tensor."""
     # Per-channel statistics are given on a 0-255 scale and divided by 255 here.
     channel_mean = [value / 255.0 for value in (125.3, 123.0, 113.9)]
     channel_std = [value / 255.0 for value in (63.0, 62.1, 66.7)]
     pipeline = transforms.Compose([
         tr.FixedResize(size=self.input_size),
         tr.Normalize(mean=channel_mean, std=channel_std),
         tr.ToTensor(),
     ])
     return pipeline(sample)
示例#6
0
def main():
    """Build the demo data loader and the three networks, then run the visualisation.

    Steps: report the PyTorch version and visible GPUs, parse the command
    line, optionally create a timestamped output directory, build the demo
    ``SequenceFolder`` loader, create DispResNet / EgoPoseNet / ObjPoseNet,
    load pre-trained weights for each network when a checkpoint path was
    given (otherwise call its ``init_weights``), wrap the networks in
    ``DataParallel`` and hand everything to ``demo_visualize``.
    """
    global device
    # .get() instead of [] so that an unset CUDA_VISIBLE_DEVICES does not
    # abort the script with a KeyError before anything has run.
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>")
    print('=> PyTorch version: ' + torch.__version__ + ' || CUDA_VISIBLE_DEVICES: ' + visible_devices)

    args = parser.parse_args()
    if args.save_fig:
        timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
        args.save_path = 'outputs'/Path(args.name)/timestamp
        print('=> will save everything to {}'.format(args.save_path))
        args.save_path.makedirs_p()

    print("=> fetching scenes in '{}'".format(args.data))
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    demo_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        normalize
    ])
    demo_set = SequenceFolder(
        args.data,
        transform=demo_transform,
        max_num_instances=args.mni
    )
    print('=> {} samples found'.format(len(demo_set)))
    demo_loader = torch.utils.data.DataLoader(demo_set, batch_size=args.batch_size, shuffle=False,
                                              num_workers=args.workers, pin_memory=True)

    # create model
    print("=> creating model")
    disp_net = models.DispResNet().to(device)
    ego_pose_net = models.EgoPoseNet().to(device)
    obj_pose_net = models.ObjPoseNet().to(device)

    # Same load-or-initialise rule for every network, applied in the order
    # ego pose, object pose, disparity.
    networks = (
        (ego_pose_net, args.pretrained_ego_pose, "EgoPoseNet"),
        (obj_pose_net, args.pretrained_obj_pose, "ObjPoseNet"),
        (disp_net, args.pretrained_disp, "DispResNet"),
    )
    for net, checkpoint_path, label in networks:
        if checkpoint_path:
            print("=> using pre-trained weights for {}".format(label))
            weights = torch.load(checkpoint_path)
            # strict=False: keys missing from / unexpected in the checkpoint are tolerated.
            net.load_state_dict(weights['state_dict'], strict=False)
        else:
            net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    ego_pose_net = torch.nn.DataParallel(ego_pose_net)
    obj_pose_net = torch.nn.DataParallel(obj_pose_net)

    demo_visualize(args, demo_loader, disp_net, ego_pose_net, obj_pose_net)
示例#7
0
    def transform_val(self, sample):
        """Apply the validation pipeline (fixed-scale crop, normalise, to tensor) to ``sample``."""
        crop = tr.FixScaleCrop(crop_size=self.crop_size)
        normalise = tr.Normalize(mean=(0.485, 0.456, 0.406),
                                 std=(0.229, 0.224, 0.225))
        pipeline = transforms.Compose([crop, normalise, tr.ToTensor()])
        return pipeline(sample)
    def transform_val(self, sample):
        """Resize ``sample`` to the configured crop size, normalise it and convert it to a tensor."""
        steps = [
            tr.FixedResize(size=self.args['crop_size']),
            # Statistics come from the instance rather than being hard-coded.
            tr.Normalize(mean=self.mean, std=self.std),
            tr.ToTensor(),
        ]
        return transforms.Compose(steps)(sample)
示例#9
0
    def transform_val(self, sample):
        """Apply the validation pipeline: 400 fixed-scale crop, source-statistics normalisation, to tensor."""
        crop = tr.FixScaleCrop(400)
        normalise = tr.Normalize(mean=self.source_dist['mean'],
                                 std=self.source_dist['std'])
        to_tensor = tr.ToTensor()
        pipeline = transforms.Compose([crop, normalise, to_tensor])
        return pipeline(sample)
示例#10
0
    def transform_ts(self, sample):
        """Apply the test pipeline: resize to 400, normalise with the source statistics, to tensor."""
        steps = [
            tr.FixedResize(size=400),
            tr.Normalize(mean=self.source_dist['mean'],
                         std=self.source_dist['std']),
            tr.ToTensor(),
        ]
        return transforms.Compose(steps)(sample)
示例#11
0
 def transform_tr(self, sample):
     """Apply the training pipeline to ``sample``.

     The active steps are a fixed-scale crop, normalisation and tensor
     conversion. The RandomHorizontalFlip, RandomScaleCrop and
     RandomGaussianBlur augmentations are disabled in this pipeline.
     """
     steps = [
         tr.FixScaleCrop(crop_size=self.crop_size),
         tr.Normalize(mean=(0.485, 0.456, 0.406),
                      std=(0.229, 0.224, 0.225)),
         tr.ToTensor(),
     ]
     return transforms.Compose(steps)(sample)
 def transforms_train_esp(self, sample):
     """Apply the training pipeline: flips, affine and colour jitter, resize, normalise, to tensor."""
     # Per-channel statistics are given on a 0-255 scale and divided by 255 here.
     channel_mean = [value / 255.0 for value in (125.3, 123.0, 113.9)]
     channel_std = [value / 255.0 for value in (63.0, 62.1, 66.7)]

     augmentations = [
         tr.RandomVerticalFlip(),
         tr.RandomHorizontalFlip(),
         tr.RandomAffine(degrees=40, scale=(.9, 1.1), shear=30),
         tr.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
     ]
     finishing_steps = [
         tr.FixedResize(size=self.input_size),
         tr.Normalize(mean=channel_mean, std=channel_std),
         tr.ToTensor(),
     ]
     pipeline = transforms.Compose(augmentations + finishing_steps)
     return pipeline(sample)
示例#13
0
    def transform_tr(self, sample):
        """Apply the training pipeline: random flip, scale-crop and blur, then normalise and convert to tensor."""
        steps = [
            # random horizontal flip
            tr.RandomHorizontalFlip(),
            # random scale + crop, sized from the parsed arguments
            tr.RandomScaleCrop(base_size=self.args.base_size,
                               crop_size=self.args.crop_size),
            # random Gaussian blur
            tr.RandomGaussianBlur(),
            # normalisation
            tr.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
            tr.ToTensor(),
        ]
        pipeline = transforms.Compose(steps)
        return pipeline(sample)
示例#14
0
    def transform_pair_val(self, sample):
        """Apply the paired validation pipeline to ``sample``.

        Crop to 400, flip, blur, then normalise and convert to tensor with
        ``if_pair=True``.
        """
        steps = [
            tr.FixScaleCrop(400),
            tr.HorizontalFlip(),
            tr.GaussianBlur(),
            tr.Normalize(mean=self.source_dist['mean'],
                         std=self.source_dist['std'],
                         if_pair=True),
            tr.ToTensor(if_pair=True),
        ]
        return transforms.Compose(steps)(sample)
示例#15
0
 def transform_pair_train(self, sample):
     """Apply the paired training pipeline to ``sample``.

     Random flip and random scale-crop come first, followed by the same
     flip / blur / normalise / to-tensor tail used for paired samples
     (``if_pair=True`` on the last two steps).
     """
     random_steps = [
         tr.RandomHorizontalFlip(),
         tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
     ]
     paired_steps = [
         tr.HorizontalFlip(),
         tr.GaussianBlur(),
         tr.Normalize(mean=self.source_dist['mean'],
                      std=self.source_dist['std'],
                      if_pair=True),
         tr.ToTensor(if_pair=True),
     ]
     pipeline = transforms.Compose(random_steps + paired_steps)
     return pipeline(sample)
示例#16
0
 def transform_tr(self, sample):
     """Apply the training pipeline to ``sample``.

     When ``self.random_match`` is truthy, histogram matching runs first;
     the remaining steps (random flip, random scale-crop, random blur,
     normalise, to tensor) are the same in both modes.
     """
     # Built first so the transforms are constructed in the same order as
     # they are applied.
     steps = [tr.HistogramMatching()] if self.random_match else []
     steps += [
         tr.RandomHorizontalFlip(),
         tr.RandomScaleCrop(base_size=400, crop_size=400, fill=0),
         tr.RandomGaussianBlur(),
         tr.Normalize(mean=self.source_dist['mean'],
                      std=self.source_dist['std']),
         tr.ToTensor(),
     ]
     return transforms.Compose(steps)(sample)
示例#17
0
    def get_mean_std(self, ratio=0.1):
        """Estimate per-channel mean and std from one random sample of the training set.

        A single shuffled batch holding ``ratio`` of the dataset (at least
        one image) is drawn and the statistics are computed over axes
        (0, 2, 3) of its ``'image'`` entry.

        Earlier versions iterated over every batch and returned the
        statistics of whichever batch came last, which could be a short
        remainder batch; only one batch is loaded now.

        Args:
            ratio: fraction of the dataset to sample.

        Returns:
            Tuple ``(mean, std)`` of NumPy arrays.

        Raises:
            ValueError: if the dataset yields no batch.
        """
        trs = tf.Compose(
            [tr.FixedResize(512),
             tr.Normalize(mean=0, std=1),
             tr.ToTensor()])
        dataset = LungDataset(root_dir=r'D:\code\U-net',
                              transforms=trs,
                              train=True)
        print(dataset)

        # A small dataset or ratio would round down to 0, which DataLoader rejects.
        sample_size = max(1, int(len(dataset) * ratio))
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=sample_size,
                                                 shuffle=True,
                                                 num_workers=4)

        # The first shuffled batch is already a random sample of the requested size.
        item = next(iter(dataloader), None)
        if item is None:
            raise ValueError("cannot compute mean/std: the dataset is empty")

        train = item['image']
        print(train.shape)
        print('sample {} images to calculate'.format(train.shape[0]))
        pixels = train.numpy()
        mean = np.mean(pixels, axis=(0, 2, 3))
        std = np.std(pixels, axis=(0, 2, 3))
        return mean, std
    def init(self,
             ds_path,
             AoI,
             seed=0,
             sequence_length=3,
             num_scales=4,
             orig_size=(1280, 720),
             target_size=(416, 128)):
        """Seed the RNGs, index the samples and precompute the camera intrinsics.

        Args:
            ds_path: dataset location handed to ``get_lists``.
            AoI: forwarded unchanged to ``get_samples``.
            seed: seed applied to both ``numpy.random`` and ``random``.
            sequence_length: forwarded to ``get_samples``.
            num_scales: number of scales passed to the multi-scale intrinsics helpers.
            orig_size: ``(width, height)`` of the images the calibration
                matrix was measured on. Defaults to 1280 x 720.
            target_size: ``(width, height)`` the images are resized to.
                Defaults to 416 x 128.

        Raises:
            ValueError: if a dimension of ``orig_size`` is not positive.
        """
        orig_w, orig_h = orig_size
        target_w, target_h = target_size
        if orig_w <= 0 or orig_h <= 0:
            raise ValueError("orig_size dimensions must be positive, got {}".format(orig_size))

        np.random.seed(seed)
        random.seed(seed)

        self.num_scales = num_scales

        seq_list = get_lists(ds_path)
        self.samples = get_samples(seq_list, sequence_length, AoI)

        # get by calib.py
        self.intrinsics = np.array([
            [1.14183754e+03, 0.00000000e+00, 6.28283670e+02],
            [0.00000000e+00, 1.13869492e+03, 3.56277189e+02],
            [0.00000000e+00, 0.00000000e+00, 1.00000000e+00],
        ]).astype(np.float32)

        # The calibration was measured at orig_size; resizing to target_size
        # scales the first row by the width ratio and the second row by the
        # height ratio. Pass a different orig_size if your pictures are not
        # 1280 x 720.
        self.intrinsics[0] = self.intrinsics[0] * (float(target_w) / float(orig_w))
        self.intrinsics[1] = self.intrinsics[1] * (float(target_h) / float(orig_h))

        self.ms_k = get_multi_scale_intrinsics(self.intrinsics, self.num_scales)
        self.ms_inv_k = get_multi_scale_inv_intrinsics(self.intrinsics, self.num_scales)

        # Two conversion pipelines: plain tensor conversion, and tensor
        # conversion followed by per-channel normalisation.
        self.to_tensor = custom_transforms.Compose([custom_transforms.ArrayToTensor()])
        self.to_tensor_norm = custom_transforms.Compose([
            custom_transforms.ArrayToTensor(),
            custom_transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225]),
        ])
示例#19
0
文件: pose_test.py 项目: zuru/DeepSFM
def main():
    """Evaluate the pose network and report rotation / translation errors.

    For every batch the network refines the noisy initial poses. The
    refined relative poses are compared with the ground truth, next to the
    unrefined noisy poses as a baseline, and the mean errors of both are
    printed. When ``args.save`` is set, the refined poses are written to
    ``<scene dir>/<args.save>_poses.txt`` unless that file already exists.

    Per-batch errors are averaged over every (sample, reference) pair of the
    batch. Previously the running sum was divided by the number of reference
    images once per sample, which under-weighted earlier samples whenever
    the batch size was greater than one.
    """
    global n_iter
    args = parser.parse_args()

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        ttype=args.ttype,
        add_geo=args.geo,
        depth_source=args.depth_init,
        sequence_length=args.sequence_length,
        gt_source='g',
        std=args.std_tr,
        pose_init=args.pose_init,
        dataset="",
        get_path=True
    )

    print('{} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    val_loader = torch.utils.data.DataLoader(
        train_set, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # create model
    print("=> creating model")
    pose_net = PoseNet(args.nlabel, args.std_tr, args.std_rot, add_geo_cost=args.geo, depth_augment=False).cuda()

    if args.pretrained_dps:
        print("=> using pre-trained weights for DPSNet")
        model_dict = pose_net.state_dict()
        weights = torch.load(args.pretrained_dps)['state_dict']
        # Only take checkpoint entries that exist in the model with the same shape.
        pretrained_dict = {k: v for k, v in weights.items() if
                           k in model_dict and weights[k].shape == model_dict[k].shape}

        model_dict.update(pretrained_dict)

        pose_net.load_state_dict(model_dict)

    else:
        pose_net.init_weights()

    cudnn.benchmark = True
    pose_net = torch.nn.DataParallel(pose_net)

    data_time = AverageMeter()

    pose_net.eval()
    end = time.time()

    # errors[k, m, i]: k = 0 noisy initial pose / 1 refined pose,
    # m = 0 rotation / 1 translation, i = batch index.
    errors = np.zeros((2, 2, len(val_loader)), np.float32)
    with torch.no_grad():
        for i, (tgt_img, ref_imgs, ref_poses, intrinsics, intrinsics_inv, tgt_depth, ref_depths,
                ref_noise_poses, initial_pose, tgt_path, ref_paths) in enumerate(val_loader):

            data_time.update(time.time() - end)
            tgt_img_var = Variable(tgt_img.cuda())
            ref_imgs_var = [Variable(img.cuda()) for img in ref_imgs]
            ref_poses_var = [Variable(pose.cuda()) for pose in ref_poses]
            ref_noise_poses_var = [Variable(pose.cuda()) for pose in ref_noise_poses]
            initial_pose_var = Variable(initial_pose.cuda())

            ref_depths_var = [Variable(dep.cuda()) for dep in ref_depths]
            intrinsics_var = Variable(intrinsics.cuda())
            intrinsics_inv_var = Variable(intrinsics_inv.cuda())
            tgt_depth_var = Variable(tgt_depth.cuda())

            pose = torch.cat(ref_poses_var, 1)

            noise_pose = torch.cat(ref_noise_poses_var, 1)

            pose_norm = torch.norm(noise_pose[:, :, :3, 3], dim=-1, keepdim=True)  # b * n* 1

            p_angle, p_trans, rot_c, trans_c = pose_net(tgt_img_var, ref_imgs_var, initial_pose_var, noise_pose,
                                                        intrinsics_var,
                                                        intrinsics_inv_var,
                                                        tgt_depth_var,
                                                        ref_depths_var, trans_norm=pose_norm)

            # Softmax-weighted sum of the candidate rotations / translations.
            batch_size = p_angle.shape[0]
            p_angle_v = torch.sum(F.softmax(p_angle, dim=1).view(batch_size, -1, 1) * rot_c, dim=1)
            p_trans_v = torch.sum(F.softmax(p_trans, dim=1).view(batch_size, -1, 1) * trans_c, dim=1)
            p_matrix = Variable(torch.zeros((batch_size, 4, 4)).float()).cuda()
            p_matrix[:, 3, 3] = 1
            p_matrix[:, :3, :] = torch.cat([angle2matrix(p_angle_v), p_trans_v.unsqueeze(-1)], dim=-1)  # 2*3*4

            # The refined relative pose is the same for the whole batch, so it
            # is computed once here instead of once per sample.
            n_refs = len(ref_imgs)
            inv_p_matrix = inv(p_matrix)
            p_rel_pose = torch.ones_like(noise_pose)
            for j in range(n_refs):
                p_rel_pose[:, j] = torch.matmul(noise_pose[:, j], inv_p_matrix)

            for bat in range(batch_size):
                path = tgt_path[bat]
                dirname = Path.dirname(path)

                orig_poses = np.genfromtxt(Path.join(dirname, args.pose_init + "_poses.txt"))
                for j in range(n_refs):
                    # The reference file name (minus its 4-character
                    # extension) is the row index in the pose file.
                    seq_num = int(Path.basename(ref_paths[bat][j])[:-4])
                    orig_poses[seq_num] = p_rel_pose[bat, j, :3, :].data.cpu().numpy().reshape(12, )

                    p_aa = mat2axangle(p_rel_pose[bat, j, :3, :3].data.cpu().numpy())
                    gt_aa = mat2axangle(pose[bat, j, :3, :3].data.cpu().numpy(), unit_thresh=1e-2)

                    n_aa = mat2axangle(noise_pose[bat, j, :3, :3].data.cpu().numpy(), unit_thresh=1e-2)
                    p_t = p_rel_pose[bat, j, :3, 3].data.cpu().numpy()
                    gt_t = pose[bat, j, :3, 3].data.cpu().numpy()
                    n_t = noise_pose[bat, j, :3, 3].data.cpu().numpy()
                    # (axis, angle) -> axis * angle
                    p_aa = p_aa[0] * p_aa[1]
                    n_aa = n_aa[0] * n_aa[1]
                    gt_aa = gt_aa[0] * gt_aa[1]
                    error = compute_motion_errors(np.concatenate([n_aa, n_t]), np.concatenate([gt_aa, gt_t]), True)
                    error_p = compute_motion_errors(np.concatenate([p_aa, p_t]), np.concatenate([gt_aa, gt_t]), True)
                    print("%d n r%.6f, t%.6f" % (i, error[0], error[2]))
                    print("%d p r%.6f, t%.6f" % (i, error_p[0], error_p[2]))
                    errors[0, 0, i] += error[0]
                    errors[0, 1, i] += error[2]
                    errors[1, 0, i] += error_p[0]
                    errors[1, 1, i] += error_p[2]
                if args.save and not Path.exists(Path.join(dirname, args.save + "_poses.txt")):
                    np.savetxt(Path.join(dirname, args.save + "_poses.txt"), orig_poses)

            # Average once per batch over every (sample, reference) pair.
            errors[:, :, i] /= n_refs * batch_size

        mean_error = errors.mean(2)
        error_names = ['rot', 'trans']
        print("%s Results : " % args.pose_init)
        print(
            "{:>10}, {:>10}".format(
                *error_names))
        print("{:10.4f}, {:10.4f}".format(*mean_error[0]))

        print("new Results : ")
        print(
            "{:>10}, {:>10}".format(
                *error_names))
        print("{:10.4f}, {:10.4f}".format(*mean_error[1]))
def main():
    """Train the light-field disparity and pose networks (multi-warp training).

    Parses the training arguments, prepares the output directory and
    loggers, builds the train/validation loaders for the selected
    light-field format, creates the networks (with optional EPI encoders and
    optional pre-trained weights), then runs the epoch loop: train,
    validate, step the learning-rate schedule, log to TensorBoard and CSV,
    and save checkpoints.

    Reads and updates the module-level ``best_error``; reads ``device``.
    """
    global best_error, n_iter, device
    args = parse_multiwarp_training_args()
    # Some non-optional parameters for training
    args.training_output_freq = 100
    args.tilesize = 8

    save_path = make_save_path(args)
    args.save_path = save_path

    print("Using device: {}".format(device))

    dump_config(save_path, args)
    print('\n\n=> Saving checkpoints to {}'.format(save_path))

    torch.manual_seed(args.seed)                # setting a manual seed for reproducibility
    tb_writer = SummaryWriter(save_path)        # tensorboard summary writer

    # Data pre-processing - convert arrays to tensors and normalize with mean 0.5 / std 0.5.
    # The same transform object is shared by training and validation.
    # NOTE(review): if ArrayToTensor yields values in [0, 1], this maps them to roughly
    # [-1, 1], not [0, 1] as the earlier comment said - confirm.
    train_transform = valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=0.5, std=0.5)
    ])

    # Create data loader based on the format of the light field
    print("=> Fetching scenes in '{}'".format(args.data))
    # NOTE(review): both stay None when args.lfformat matches none of the branches below,
    # in which case the len(train_set) call further down raises TypeError.
    train_set, val_set = None, None
    if args.lfformat == 'focalstack':
        train_set, val_set = get_focal_stack_loaders(args, train_transform, valid_transform)
    elif args.lfformat == 'stack':
        # Monocular mode: exactly one camera, with index 8, stacked as "input".
        is_monocular = False
        if len(args.cameras) == 1 and args.cameras[0] == 8 and args.cameras_stacked == "input":
                is_monocular = True
        train_set, val_set = get_stacked_lf_loaders(args, train_transform, valid_transform, is_monocular=is_monocular)
    elif args.lfformat == 'epi':
        train_set, val_set = get_epi_loaders(args, train_transform, valid_transform)

    print('=> {} samples found in {} train scenes'.format(len(train_set), len(train_set.scenes)))
    print('=> {} samples found in {} validation scenes'.format(len(val_set), len(val_set.scenes)))

    print('=> Multi-warp training, warping {} sub-apertures'.format(len(args.cameras)))

    # Create batch loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    output_channels = len(args.cameras) # for multi-warp photometric loss, we request as many depth values as the cameras used
    # The epoch size is always overwritten with the full loader length.
    args.epoch_size = len(train_loader)
    
    # Create models
    print("=> Creating models")

    if args.lfformat == "epi":
        print("=> Using EPI encoders")
        if args.cameras_epi == "vertical":
            disp_encoder = models.EpiEncoder('vertical', args.tilesize).to(device)
            pose_encoder = models.RelativeEpiEncoder('vertical', args.tilesize).to(device)
            dispnet_input_channels = 16 + len(args.cameras)     # 16 is the number of output channels of the encoder
            posenet_input_channels = 16 + len(args.cameras)     # 16 is the number of output channels of the encoder
        elif args.cameras_epi == "horizontal":
            disp_encoder = models.EpiEncoder('horizontal', args.tilesize).to(device)
            pose_encoder = models.RelativeEpiEncoder('horizontal', args.tilesize).to(device)
            dispnet_input_channels = 16 + len(args.cameras)  # 16 is the number of output channels of the encoder
            posenet_input_channels = 16 + len(args.cameras)  # 16 is the number of output channels of the encoder
        elif args.cameras_epi == "full":
            disp_encoder = models.EpiEncoder('full', args.tilesize).to(device)
            pose_encoder = models.RelativeEpiEncoder('full', args.tilesize).to(device)
            if args.without_disp_stack:
                dispnet_input_channels = 32  # 16 is the number of output channels of each encoder
            else:
                dispnet_input_channels = 32 + 5  # 16 is the number of output channels of each encoder, 5 from stack
            posenet_input_channels = 32 + 5  # 16 is the number of output channels of each encoder
        else:
            raise ValueError("Incorrect cameras epi format")
    else:
        disp_encoder = None
        pose_encoder = None
        # for stack lfformat channels = num_cameras * num_colour_channels
        # for focalstack lfformat channels = num_focal_planes * num_colour_channels
        # The channel count is read off the first training sample.
        dispnet_input_channels = posenet_input_channels = train_set[0]['tgt_lf_formatted'].shape[0]
    
    disp_net = models.LFDispNet(in_channels=dispnet_input_channels,
                                out_channels=output_channels, encoder=disp_encoder).to(device)
    pose_net = models.LFPoseNet(in_channels=posenet_input_channels,
                                nb_ref_imgs=args.sequence_length - 1, encoder=pose_encoder).to(device)

    print("=> [DispNet] Using {} input channels, {} output channels".format(dispnet_input_channels, output_channels))
    print("=> [PoseNet] Using {} input channels".format(posenet_input_channels))

    if args.pretrained_exp_pose:
        print("=> [PoseNet] Using pre-trained weights for pose net")
        weights = torch.load(args.pretrained_exp_pose)
        # strict=False: the pose checkpoint may lack or carry extra keys.
        pose_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        print("=> [PoseNet] training from scratch")
        pose_net.init_weights()

    if args.pretrained_disp:
        print("=> [DispNet] Using pre-trained weights for DispNet")
        weights = torch.load(args.pretrained_disp)
        # Loaded strictly, unlike the pose net above.
        disp_net.load_state_dict(weights['state_dict'])
    else:
        print("=> [DispNet] training from scratch")
        disp_net.init_weights()

    # this flag tells CUDNN to find the optimal set of algorithms for this specific input data size, which improves
    # runtime efficiency, but takes a while to load in the beginning.
    cudnn.benchmark = True
    # disp_net = torch.nn.DataParallel(disp_net)
    # pose_net = torch.nn.DataParallel(pose_net)

    print('=> Setting adam solver')

    # One parameter group per network, both with the same learning rate.
    optim_params = [
        {'params': disp_net.parameters(), 'lr': args.lr}, 
        {'params': pose_net.parameters(), 'lr': args.lr}
    ]

    optimizer = torch.optim.Adam(optim_params, betas=(args.momentum, args.beta), weight_decay=args.weight_decay)
    # Halve the learning rate every 50 scheduler steps (one step per epoch below).
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    # Start both CSV logs from scratch (mode 'w') with a header row.
    with open(save_path + "/" + args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(save_path + "/" + args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'photo_loss', 'smooth_loss', 'pose_loss'])

    # args.epoch_size was set to len(train_loader) above, so the min() is len(train_loader).
    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    # Loss weights as tensors with requires_grad=True. They are not handed to the
    # optimizer in this function (its parameter groups only hold the two networks).
    w1 = torch.tensor(args.photo_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
    # w2 = torch.tensor(args.mask_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
    w3 = torch.tensor(args.smooth_loss_weight, dtype=torch.float32, device=device, requires_grad=True)
    # w4 = torch.tensor(args.gt_pose_loss_weight, dtype=torch.float32, device=device, requires_grad=True)

    # add some constant parameters to the log for easy visualization
    tb_writer.add_scalar(tag="batch_size", scalar_value=args.batch_size)

    # tb_writer.add_scalar(tag="mask_loss_weight", scalar_value=args.mask_loss_weight)    # this is not used

    # tb_writer.add_scalar(tag="gt_pose_loss_weight", scalar_value=args.gt_pose_loss_weight)

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_net, optimizer, args.epoch_size, logger, tb_writer, w1, w3)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        errors, error_names = validate_without_gt(args, val_loader, disp_net, pose_net, epoch, logger, tb_writer, w1, w3)
        error_string = ', '.join('{} : {:.3f}'.format(name, error) for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        # update the learning rate (annealing)
        lr_scheduler.step()

        # add the learning rate to the tensorboard logging
        tb_writer.add_scalar(tag="learning_rate", scalar_value=lr_scheduler.get_last_lr()[0], global_step=epoch)

        tb_writer.add_scalar(tag="photometric_loss_weight", scalar_value=w1, global_step=epoch)
        tb_writer.add_scalar(tag="smooth_loss_weight", scalar_value=w3, global_step=epoch)

        # add validation errors to the tensorboard logging
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(tag=name, scalar_value=error, global_step=epoch)

        # The third validation error decides which checkpoint counts as best.
        decisive_error = errors[2]
        # A negative best_error is treated as "not set yet"
        # (presumably the module-level initial value - confirm).
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(save_path, {'epoch': epoch + 1, 'state_dict': disp_net.state_dict()},
                        {'epoch': epoch + 1, 'state_dict': pose_net.state_dict()}, is_best)

        # save a checkpoint every 20 epochs anyway
        if epoch % 20 == 0:
            save_checkpoint_current(save_path, {'epoch': epoch + 1, 'state_dict': disp_net.state_dict()},
                                    {'epoch': epoch + 1, 'state_dict': pose_net.state_dict()}, epoch)

        # Append this epoch's train loss and decisive validation error to the summary CSV.
        with open(save_path + "/" + args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
def main():
    """Train DispNetS and PoseExpNet on a shifted-sequence dataset.

    Parses the command line, builds the train / adjust / validation loaders,
    creates both networks (optionally from pre-trained weights) and runs
    ``args.epochs`` epochs.  Every fifth epoch the dataset shifts are
    re-estimated through ``adjust_shifts``.  After each epoch the networks
    are validated, both checkpoints are saved and one row is appended to the
    summary log.

    Relies on module-level names defined outside this function (``parser``,
    ``device``, ``best_error``, ``train``, ``adjust_shifts``, ...).
    """
    global args, best_error, n_iter, device
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints_shifted' / save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    output_writers = []
    if args.log_output:
        output_writers = [
            SummaryWriter(args.save_path / 'valid' / str(i)) for i in range(3)
        ]

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize,
    ])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = ShiftedSequenceFolder(
        args.data,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,
        target_displacement=args.target_displacement)

    # If no ground truth is available, the validation set has the same type
    # as the training set so that the photometric warping loss can be measured.
    if args.with_gt:
        from datasets.validation_folders import ValidationSet
        val_set = ValidationSet(args.data, transform=valid_transform)
    else:
        val_set = SequenceFolder(
            args.data,
            transform=valid_transform,
            seed=args.seed,
            train=False,
            sequence_length=args.sequence_length,
        )
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # Workers are set to 0 so that a single dataset instance is modified
    # while the shifts are being adjusted.
    adjust_loader = torch.utils.data.DataLoader(train_set,
                                                batch_size=args.batch_size,
                                                shuffle=False,
                                                num_workers=0,
                                                pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # An epoch size of 0 means "one full pass over the training loader".
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    train.args = args
    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=output_exp).to(device)

    # map_location lets weights saved on another device (e.g. a different GPU
    # index) be loaded; load_state_dict then copies them into the model.
    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose, map_location=device)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp, map_location=device)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # Start both tab-separated logs with their header rows (truncating any
    # previous content).
    with open(args.save_path / args.log_summary, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, training_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # Every fifth epoch, re-estimate the dataset shifts.
        if (epoch + 1) % 5 == 0:
            train_set.adjust = True
            logger.reset_train_bar(len(adjust_loader))
            average_shifts = adjust_shifts(args, train_set, adjust_loader,
                                           pose_exp_net, epoch, logger,
                                           training_writer)
            shifts_string = ' '.join(
                '{:.3f}'.format(s) for s in average_shifts)
            logger.train_writer.write(
                ' * adjusted shifts, average shifts are now : {}'.format(
                    shifts_string))
            for i, shift in enumerate(average_shifts):
                training_writer.add_scalar('shifts{}'.format(i), shift, epoch)
            train_set.adjust = False

        # evaluate on validation set
        logger.reset_valid_bar()
        if args.with_gt:
            errors, error_names = validate_with_gt(args, val_loader, disp_net,
                                                   epoch, logger,
                                                   output_writers)
        else:
            errors, error_names = validate_without_gt(args, val_loader,
                                                      disp_net, pose_exp_net,
                                                      epoch, logger,
                                                      output_writers)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # The first reported error decides which model is "best".  Pick the
        # most relevant one for your use case; note that some measures (such
        # as a1, a2, a3) are to be maximized, not minimized.
        decisive_error = errors[0]
        if best_error < 0:
            # A negative value marks "no best error recorded yet".
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        # Save the wrapped modules' weights, not the DataParallel wrappers'.
        save_checkpoint(
            args.save_path,
            {'epoch': epoch + 1,
             'state_dict': disp_net.module.state_dict()},
            {'epoch': epoch + 1,
             'state_dict': pose_exp_net.module.state_dict()},
            is_best)

        with open(args.save_path / args.log_summary, 'a') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
示例#22
0
def main(args):
  """Train, evaluate or attack a 100-class image classifier.

  Depending on ``args`` this either
    * evaluates a resumed checkpoint (``args.resume`` and ``args.evaluate``),
    * runs an adversarial attack on it (``args.resume`` and ``args.attack``),
    * or trains the model, validating and checkpointing after every epoch.

  Args:
    args: parsed command-line namespace.  The attributes read here include
      ``gpu``, ``lr``, ``momentum``, ``weight_decay``, ``resume``,
      ``evaluate``, ``attack``, ``vis``, ``data_folder``, ``batch_size``,
      ``workers``, ``start_epoch``, ``epochs`` and ``warmup_epochs``.
  """
  best_acc1 = 0.0

  if args.gpu >= 0:
    torch.cuda.set_device(args.gpu)
    print("Use GPU: {}".format(args.gpu))
  else:
    print('You are using CPU for computing!',
          'Yet we assume you are using a GPU.',
          'You will NOT be able to switch between CPU and GPU training!')

  # fix the random seeds (the best we can)
  fixed_random_seed = 2019
  torch.manual_seed(fixed_random_seed)
  np.random.seed(fixed_random_seed)
  random.seed(fixed_random_seed)

  # set up the model + loss
  if args.use_custom_conv:
    print("Using custom convolutions in the network")
    model = default_model(conv_op=CustomConv2d, num_classes=100)
  elif args.use_resnet18:
    model = torchvision.models.resnet18(pretrained=True)
    model.fc = nn.Linear(512, 100)
  elif args.use_adv_training:
    model = AdvSimpleNet(num_classes=100)
  else:
    model = default_model(num_classes=100)
  # NOTE(review): the architecture tag is "simplenet" for every branch above,
  # including resnet18 -- confirm whether that is intended.
  model_arch = "simplenet"
  criterion = nn.CrossEntropyLoss()
  # put everything on the gpu
  if args.gpu >= 0:
    model = model.cuda(args.gpu)
    criterion = criterion.cuda(args.gpu)

  # setup the optimizer
  optimizer = torch.optim.SGD(model.parameters(), args.lr,
                momentum=args.momentum,
                weight_decay=args.weight_decay)

  # resume from a checkpoint?
  if args.resume:
    if os.path.isfile(args.resume):
      print("=> loading checkpoint '{}'".format(args.resume))
      # Remap stored tensors onto the device in use, so that a checkpoint
      # written on a GPU can still be opened on a CPU-only host (and the
      # other way round).
      if args.gpu < 0:
        map_location = 'cpu'
      else:
        map_location = 'cuda:{}'.format(args.gpu)
      checkpoint = torch.load(args.resume, map_location=map_location)
      args.start_epoch = checkpoint['epoch']
      best_acc1 = checkpoint['best_acc1']
      model.load_state_dict(checkpoint['state_dict'])
      if args.gpu < 0:
        model = model.cpu()
      else:
        model = model.cuda(args.gpu)
      # only load the optimizer if necessary
      if (not args.evaluate) and (not args.attack):
        optimizer.load_state_dict(checkpoint['optimizer'])
      print("=> loaded checkpoint '{}' (epoch {}, acc1 {})"
          .format(args.resume, checkpoint['epoch'], best_acc1))
    else:
      print("=> no checkpoint found at '{}'".format(args.resume))

  # set up transforms for data augmentation
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])
  train_transforms = get_train_transforms(normalize)
  # val transforms
  val_transforms = transforms.Compose([
    transforms.Scale(160, interpolations=None),
    transforms.ToTensor(),
    normalize,
  ])
  if (not args.evaluate) and (not args.attack):
    print("Training time data augmentations:")
    print(train_transforms)

  # setup dataset and dataloader
  train_dataset = MiniPlacesLoader(args.data_folder,
                  split='train', transforms=train_transforms)
  val_dataset = MiniPlacesLoader(args.data_folder,
                  split='val', transforms=val_transforms)

  train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=True)
  val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=100, shuffle=False,
    num_workers=args.workers, pin_memory=True, sampler=None, drop_last=False)

  # evaluate and attack are mutually exclusive modes
  if (args.evaluate==args.attack) and args.evaluate:
    print("Cann't set evaluate and attack to True at the same time!")
    return

  # set up visualizer
  if args.vis:
    visualizer = default_attention(criterion)
  else:
    visualizer = None

  # evaluation
  if args.resume and args.evaluate:
    print("Testing the model ...")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args, visualizer=visualizer)
    return

  # attack
  if args.resume and args.attack:
    print("Generating adversarial samples for the model ..")
    cudnn.deterministic = True
    validate(val_loader, model, -1, args,
             attacker=default_attack(criterion),
             visualizer=visualizer)
    return

  # enable cudnn benchmark
  cudnn.enabled = True
  cudnn.benchmark = True

  # warm up the training (only when starting from epoch 0)
  if (args.start_epoch == 0) and (args.warmup_epochs > 0):
    print("Warmup the training ...")
    for epoch in range(0, args.warmup_epochs):
      train(train_loader, model, criterion, optimizer, epoch, "warmup", args)

  # start the training
  print("Training the model ...")
  for epoch in range(args.start_epoch, args.epochs):
    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch, "train", args)

    # evaluate on validation set
    acc1 = validate(val_loader, model, epoch, args)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)
    save_checkpoint({
      'epoch': epoch + 1,
      'model_arch': model_arch,
      'state_dict': model.state_dict(),
      'best_acc1': best_acc1,
      'optimizer' : optimizer.state_dict(),
    }, is_best)
示例#23
0
from torchvision import datasets, transforms
import torch.utils.data as data
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import os
import custom_transforms as trans

# Module-level preprocessing pipeline built from the project's
# `custom_transforms` module: random horizontal flip, random Gaussian blur,
# random scale-crop, normalization and tensor conversion, listed in that order.
# NOTE(review): the meaning of the (700, 512) arguments to RandomScaleCrop is
# defined in custom_transforms -- presumably base size and crop size; confirm.
img_transform = transforms.Compose([
    trans.RandomHorizontalFlip(),
    trans.RandomGaussianBlur(),
    trans.RandomScaleCrop(700, 512),
    trans.Normalize(),
    trans.ToTensor()
])


class TrainImageFolder(data.Dataset):
    def __init__(self, data_dir):
        self.f = open(os.path.join(data_dir, 'train_id.txt'))
        self.file_list = self.f.readlines()
        self.data_dir = data_dir

    def __getitem__(self, index):
        img = Image.open(
            os.path.join(self.data_dir, 'train_images',
                         self.file_list[index][:-1] + '.jpg')).convert('RGB')
        parse = Image.open(
            os.path.join(self.data_dir, 'train_segmentations',
def main():
    """Run the explainability-mask network over a sequence and save images.

    For every sample of the inference set, two PNG files are written into
    ``args.save_path``: ``<idx>multi.png`` (first reference image multiplied
    by the predicted mask) and ``<idx>mask.png`` (the mask itself), where
    ``<idx>`` is the zero-padded sample index.

    Inference runs under ``torch.no_grad()``.  The former
    ``Variable(..., volatile=True)`` wrappers no longer disable autograd in
    PyTorch 0.4 and later, so they have been dropped.
    """
    global args
    args = parser.parse_args()
    save_path = Path(args.name)
    args.save_path = 'checkpoints'/save_path
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # No random augmentation: only tensor conversion and normalization.
    train_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        normalize,
    ])

    training_writer = SummaryWriter(args.save_path)

    # Hard-coded 3x3 camera intrinsics matrix handed to the dataset.
    intrinsics = np.array([542.822841, 0, 315.593520, 0, 542.576870, 237.756098, 0, 0, 1]).astype(np.float32).reshape((3, 3))

    inference_set = SequenceFolder(
        root=args.dataset_dir,
        intrinsics=intrinsics,
        transform=train_transform,
        train=False,
        sequence_length=args.sequence_length
    )

    print('{} samples found in {} train scenes'.format(len(inference_set), len(inference_set.scenes)))
    inference_loader = torch.utils.data.DataLoader(
        inference_set, batch_size=1, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)

    print("=> creating model")
    mask_net = MaskResNet6.MaskResNet6().cuda()
    pose_net = PoseNetB6.PoseNetB6().cuda()
    mask_net = torch.nn.DataParallel(mask_net)

    masknet_weights = torch.load(args.pretrained_mask)
    # NOTE(review): the pose weights are read from disk but never loaded into
    # pose_net, and pose_net is not used below -- confirm whether both can go.
    posenet_weights = torch.load(args.pretrained_pose)
    mask_net.load_state_dict(masknet_weights['state_dict'])
    pose_net.eval()
    mask_net.eval()

    # inference
    with torch.no_grad():
        # The per-sample intrinsics yielded by the loader are not needed here.
        for i, (rgb_tgt_img, rgb_ref_imgs, _, _) in enumerate(tqdm(inference_loader)):
            tgt_img = rgb_tgt_img.cuda()
            ref_imgs = [img.cuda() for img in rgb_ref_imgs]

            explainability_mask = mask_net(tgt_img, ref_imgs)

            # Common "<save_path>/<zero-padded index>" prefix of both files.
            prefix = args.save_path/str(i).zfill(3)

            after_mask = tensor2array(ref_imgs[0][0]*explainability_mask[0,0]).transpose(1,2,0)
            masked_image = Image.fromarray(np.uint8(after_mask*255))
            masked_image.save(prefix+'multi.png')

            mask_array = (explainability_mask[0,0].detach().cpu()).numpy()
            mask_image = Image.fromarray(np.uint8(mask_array*255))
            mask_image.save(prefix+'mask.png')
示例#25
0
def main():
    """Train a disparity network supervised by ground-truth depth.

    Reads the parsed arguments from ``global_vars_dict['args']``, builds the
    training and validation loaders, creates (or resumes) the disparity
    network and its Adam optimizer, then for every epoch trains, validates,
    logs to TensorBoard and saves a checkpoint.  The model with the lowest
    average total training loss is flagged as best.

    Raises:
        ValueError: if ``args.data_normalization`` is neither ``'global'``
            nor ``'local'``.
    """
    global global_vars_dict
    args = global_vars_dict['args']
    best_error = -1  # negative means "no best value recorded yet"

    # output directory
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")

    args.save_path = Path('checkpoints') / Path(args.data_dir).stem / timestamp

    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)
    if args.alternating:
        args.alternating_flags = np.array([False, False, True])
    # tensorboard writer
    tb_writer = SummaryWriter(args.save_path)

    # Data loading code and transforms
    if args.data_normalization == 'global':
        normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                                std=[0.5, 0.5, 0.5])
    elif args.data_normalization == 'local':
        normalize = custom_transforms.NormalizeLocally()
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            "unsupported data_normalization: {!r} "
            "(expected 'global' or 'local')".format(args.data_normalization))

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data_dir))

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(),
        normalize
    ])

    # training set and loader: only one is built
    from datasets.sequence_mc import SequenceFolder
    train_set = SequenceFolder(  # mc data folder
        args.data_dir,
        transform=train_transform,
        seed=args.seed,
        train=True,
        sequence_length=args.sequence_length,  # 5
        target_transform=None,
        depth_format='png')

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    # An epoch size of 0 means "one full pass over the training loader".
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # validation sets and loaders are built one by one
    from datasets.validation_folders2 import ValidationSet

    val_set_with_depth_gt = ValidationSet(args.data_dir,
                                          transform=valid_transform,
                                          depth_format='png')

    val_loader_depth = torch.utils.data.DataLoader(val_set_with_depth_gt,
                                                   batch_size=args.batch_size,
                                                   shuffle=False,
                                                   num_workers=args.workers,
                                                   pin_memory=True,
                                                   drop_last=True)

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))

    # 1 create model
    print("=> creating model")
    # 1.1 disp_net
    disp_net = getattr(models, args.dispnet)().cuda()

    if args.pretrained_disp:
        print("=> using pre-trained weights from {}".format(
            args.pretrained_disp))
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    if args.resume:
        print("=> resuming from checkpoint")
        dispnet_weights = torch.load(args.save_path /
                                     'dispnet_checkpoint.pth.tar')
        disp_net.load_state_dict(dispnet_weights['state_dict'])

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)

    print('=> setting adam solver')

    parameters = chain(disp_net.parameters())

    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    if args.resume and (args.save_path /
                        'optimizer_checkpoint.pth.tar').exists():
        print("=> loading optimizer from checkpoint")
        optimizer_weights = torch.load(args.save_path /
                                       'optimizer_checkpoint.pth.tar')
        optimizer.load_state_dict(optimizer_weights['state_dict'])

    # terminal logger (optional)
    if args.log_terminal:
        logger = TermLogger(n_epochs=args.epochs,
                            train_size=min(len(train_loader), args.epoch_size),
                            valid_size=len(val_loader_depth))
        logger.reset_epoch_bar()
    else:
        logger = None

    # loss used for training (L1 loss is easy to optimize) and error metrics
    # used for validation
    criterion_train = MaskedL1Loss().to(device)
    criterion_val = ComputeErrors().to(device)

    epoch_time = AverageMeter()
    end = time.time()
    # 3. main cycle
    # Epochs start at 1; the original note says epoch 0 is evaluated before
    # entering the loop.
    for epoch in range(1, args.epochs):

        # The terminal logger only exists when args.log_terminal is set.
        if logger is not None:
            logger.reset_train_bar()
            logger.reset_valid_bar()

        # 3.2 train for one epoch
        # NOTE(review): logger may be None here; confirm that train_depth_gt
        # and validate_depth_with_gt accept that.
        loss_names, losses = train_depth_gt(train_loader=train_loader,
                                            disp_net=disp_net,
                                            optimizer=optimizer,
                                            criterion=criterion_train,
                                            logger=logger,
                                            train_writer=tb_writer,
                                            global_vars_dict=global_vars_dict)

        # 3.3 evaluate on validation set
        depth_error_names, depth_errors = validate_depth_with_gt(
            val_loader=val_loader_depth,
            disp_net=disp_net,
            criterion=criterion_val,
            epoch=epoch,
            logger=logger,
            tb_writer=tb_writer,
            global_vars_dict=global_vars_dict)

        epoch_time.update(time.time() - end)
        end = time.time()

        # 3.5 log to terminal
        if args.log_terminal:
            logger.epoch_logger_update(epoch=epoch,
                                       time=epoch_time,
                                       names=depth_error_names,
                                       values=depth_errors)

        # tensorboard scalars: training losses
        for loss_name, loss in zip(loss_names, losses.avg):
            tb_writer.add_scalar('train/' + loss_name, loss, epoch)

        # tensorboard scalars: validation errors against ground truth
        for name, error in zip(depth_error_names, depth_errors.avg):
            tb_writer.add_scalar('val/' + name, error, epoch)

        # 3.6 remember the lowest total training loss and save a checkpoint
        total_loss = losses.avg[0]
        if best_error < 0:
            best_error = total_loss

        is_best = total_loss <= best_error
        best_error = min(best_error, total_loss)
        # Only the disparity network is trained here; the other three state
        # slots passed to save_checkpoint are left empty.
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, {
            'epoch': epoch + 1,
            'state_dict': None
        }, is_best)

    if args.log_terminal:
        logger.epoch_bar.finish()
示例#26
0
def main():
    """Evaluate a pre-trained PSNet on a validation set and report errors.

    Every ``args.print_freq``-th batch of the validation loader is run
    through the network.  Depth and disparity errors are computed on the
    pixels whose ground-truth depth lies in
    ``[args.mindepth, args.maxdepth]`` and is not NaN.  The per-metric means
    are printed and written to ``errors.csv`` in ``args.output_dir``; when
    ``args.output_print`` is set, the predicted disparity is also saved as
    ``.npy`` and as a colour-mapped ``.png``.
    """
    args = parser.parse_args()

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             sequence_length=args.sequence_length)

    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    dpsnet = PSNet(args.nlabel, args.mindepth).cuda()
    weights = torch.load(args.pretrained_dps)
    dpsnet.load_state_dict(weights['state_dict'])
    dpsnet.eval()

    output_dir = Path(args.output_dir)
    # Creates missing parent directories too and tolerates an existing one.
    os.makedirs(output_dir, exist_ok=True)

    # Batches 0, print_freq, 2 * print_freq, ... are evaluated, which is
    # ceil(len / print_freq) batches.  Allocating exactly that many slots
    # keeps an untouched all-zero column out of the means when the loader
    # length is a multiple of print_freq.  At least one slot is kept so that
    # an empty loader still reports zeros as before.
    num_evaluated = max(
        (len(val_loader) + args.print_freq - 1) // args.print_freq, 1)
    errors = np.zeros((2, 8, num_evaluated), np.float32)
    with torch.no_grad():
        for ii, (tgt_img, ref_imgs, ref_poses, intrinsics, intrinsics_inv,
                 tgt_depth, scale_) in enumerate(val_loader):
            if ii % args.print_freq != 0:
                continue
            i = ii // args.print_freq

            tgt_img_cuda = tgt_img.cuda()
            ref_imgs_cuda = [img.cuda() for img in ref_imgs]
            ref_poses_cuda = [pose.cuda() for pose in ref_poses]
            intrinsics_cuda = intrinsics.cuda()
            intrinsics_inv_cuda = intrinsics_inv.cuda()
            scale = scale_.numpy()[0]

            # compute output
            pose = torch.cat(ref_poses_cuda, 1)
            start = time.time()
            output_depth = dpsnet(tgt_img_cuda, ref_imgs_cuda, pose,
                                  intrinsics_cuda, intrinsics_inv_cuda)
            elps = time.time() - start

            # Valid pixels: depth within range and not NaN (NaN != NaN).
            mask = (tgt_depth <= args.maxdepth) & (
                tgt_depth >= args.mindepth) & (tgt_depth == tgt_depth)

            tgt_disp = args.mindepth * args.nlabel / tgt_depth
            output_disp = args.mindepth * args.nlabel / output_depth

            output_disp_ = torch.squeeze(output_disp.data.cpu(), 1)
            output_depth_ = torch.squeeze(output_depth.data.cpu(), 1)

            errors[0, :, i] = compute_errors_test(tgt_depth[mask] / scale,
                                                  output_depth_[mask] / scale)
            errors[1, :, i] = compute_errors_test(tgt_disp[mask] / scale,
                                                  output_disp_[mask] / scale)

            print('Elapsed Time {} Abs Error {:.4f}'.format(
                elps, errors[0, 0, i]))

            if args.output_print:
                output_disp_n = (output_disp_).numpy()[0]
                np.save(output_dir / '{:04d}{}'.format(i, '.npy'),
                        output_disp_n)
                disp = (255 * tensor2array(torch.from_numpy(output_disp_n),
                                           max_value=args.nlabel,
                                           colormap='bone')).astype(np.uint8)
                imsave(output_dir / '{:04d}_disp{}'.format(i, '.png'), disp)

    mean_errors = errors.mean(2)
    error_names = [
        'abs_rel', 'abs_diff', 'sq_rel', 'rms', 'log_rms', 'a1', 'a2', 'a3'
    ]
    # One right-aligned 10-character column per metric.
    header_format = ', '.join(['{:>10}'] * len(error_names))
    values_format = ', '.join(['{:10.4f}'] * len(error_names))

    print("{}".format(args.output_dir))
    print("Depth Results : ")
    print(header_format.format(*error_names))
    print(values_format.format(*mean_errors[0]))

    print("Disparity Results : ")
    print(header_format.format(*error_names))
    print(values_format.format(*mean_errors[1]))

    np.savetxt(output_dir / 'errors.csv',
               mean_errors,
               fmt='%1.4f',
               delimiter=',')
示例#27
0
def main():
    """Train MVDNet or the depth-consistency network and log per-epoch metrics.

    Arguments come from the module-level ``parser``. The function builds the
    train/validation ``SequenceFolder`` loaders, creates ``MVDNet`` and
    ``DepthCons`` (optionally loading pre-trained weights), and then, for each
    epoch, trains, validates, logs to TensorBoard and to a tab-separated
    summary file, and saves a checkpoint.

    * With ``args.train_cons`` the optimiser covers only ``depth_cons`` and
      ``mvdnet`` is put in eval mode; otherwise the optimiser covers
      ``mvdnet``.
    * With ``args.evaluate`` training is skipped, one validation pass is
      logged, and the loop stops before any checkpoint is written.
    * With ``args.skip_v`` (and not ``args.evaluate``) validation is skipped
      and zeros are logged in place of the metrics.
    """
    args = parser.parse_args()
    save_path = save_path_formatter(args, parser)
    args.save_path = 'checkpoints' / (args.exp + '_' + save_path)
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    training_writer = SummaryWriter(args.save_path)
    # Three extra writers under <save_path>/valid/0..2, handed to
    # validate_with_gt below.
    output_writers = [
        SummaryWriter(args.save_path / 'valid' / str(i)) for i in range(3)
    ]

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])

    train_transform = custom_transforms.Compose([
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    valid_transform = custom_transforms.Compose(
        [custom_transforms.ArrayToTensor(), normalize])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=train_transform,
                               seed=args.seed,
                               ttype=args.ttype,
                               dataset=args.dataset)
    val_set = SequenceFolder(args.data,
                             transform=valid_transform,
                             seed=args.seed,
                             ttype=args.ttype2,
                             dataset=args.dataset)

    # Drop the trailing samples so the training set size is an exact multiple
    # of the batch size.
    usable = len(train_set) - len(train_set) % args.batch_size
    train_set.samples = train_set.samples[:usable]

    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # An epoch size of 0 means "one full pass over the training loader".
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    mvdnet = MVDNet(args.nlabel, args.mindepth, no_pool=args.no_pool).cuda()
    mvdnet.init_weights()
    if args.pretrained_mvdn:
        print("=> using pre-trained weights for MVDNet")
        weights = torch.load(args.pretrained_mvdn)
        mvdnet.load_state_dict(weights['state_dict'])

    depth_cons = DepthCons().cuda()
    depth_cons.init_weights()

    if args.pretrained_cons:
        print("=> using pre-trained weights for ConsNet")
        weights = torch.load(args.pretrained_cons)
        depth_cons.load_state_dict(weights['state_dict'])

    cons_loss_ = ConsLoss().cuda()
    print('=> setting adam solver')

    # Only one of the two networks is optimised per run.
    if args.train_cons:
        optimizer = torch.optim.Adam(depth_cons.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)
        mvdnet.eval()
    else:
        optimizer = torch.optim.Adam(mvdnet.parameters(),
                                     args.lr,
                                     betas=(args.momentum, args.beta),
                                     weight_decay=args.weight_decay)

    cudnn.benchmark = True
    mvdnet = torch.nn.DataParallel(mvdnet)
    depth_cons = torch.nn.DataParallel(depth_cons)

    print(' ==> setting log files')
    # newline='' is what the csv module asks for; without it rows end in
    # \r\r\n on platforms that translate newlines.
    with open(args.save_path / args.log_summary, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow([
            'train_loss', 'validation_abs_rel', 'validation_abs_diff',
            'validation_sq_rel', 'validation_a1', 'validation_a2',
            'validation_a3', 'mean_angle_error'
        ])

    with open(args.save_path / args.log_full, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss'])

    print(' ==> main Loop')
    for epoch in range(args.epochs):
        adjust_learning_rate(args, optimizer, epoch)

        # train for one epoch (skipped entirely in evaluate mode)
        if args.evaluate:
            train_loss = 0
        else:
            train_loss = train(args, train_loader, mvdnet, depth_cons,
                               cons_loss_, optimizer, args.epoch_size,
                               training_writer, epoch)

        if not args.evaluate and args.skip_v:
            # Validation skipped: log zeros so the summary row keeps its shape.
            error_names = [
                'abs_rel', 'abs_diff', 'sq_rel', 'a1', 'a2', 'a3', 'angle'
            ]
            errors = [0] * 7
        else:
            errors, error_names = validate_with_gt(args, val_loader, mvdnet,
                                                   depth_cons, epoch,
                                                   output_writers)

        # Report the validation metrics; the string used to be built and then
        # silently discarded.
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        print(' * Avg {}'.format(error_string))

        for error, name in zip(errors, error_names):
            training_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures are to be maximised (a1, a2, a3).
        decisive_error = errors[0]
        with open(args.save_path / args.log_summary, 'a',
                  newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([
                train_loss, decisive_error, errors[1], errors[2], errors[3],
                errors[4], errors[5], errors[6]
            ])

        if args.evaluate:
            break

        # Checkpoint whichever network this run is optimising.
        if args.train_cons:
            trained_net, prefix = depth_cons, 'cons'
        else:
            trained_net, prefix = mvdnet, 'mvdnet'
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': trained_net.module.state_dict()
        },
                        epoch,
                        file_prefixes=[prefix])
示例#28
0
def main():
    """Train the light-field disparity and pose networks.

    Arguments come from ``parse_training_args()``. The function creates the
    output folder, builds the datasets selected by ``args.lfformat``
    (``'focalstack'`` or ``'stack'``), constructs ``LFDispNet`` and
    ``LFPoseNet`` on CUDA when available (CPU otherwise), and then trains and
    validates for ``args.epochs`` epochs. Each epoch saves a checkpoint of
    both networks and appends a row to the tab-separated summary log.

    Raises:
        ValueError: if ``args.lfformat`` is not one of the supported formats.
    """
    best_error = -1  # negative means "no validation result seen yet"
    # NOTE(review): n_iter is never updated in this function, so train()
    # receives 0 on every epoch -- confirm that train() does not rely on it
    # as a running global step.
    n_iter = 0
    torch_device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")

    # parse training arguments
    args = parse_training_args()
    args.training_output_freq = 100  # resetting the training output frequency here.

    # create a folder to save the output of training
    save_path = make_save_path(args)
    args.save_path = save_path
    # save the current configuration to a pickle file
    dump_config(save_path, args)

    print('=> Saving checkpoints to {}'.format(save_path))
    # seed torch's random number generator from the command-line seed
    torch.manual_seed(args.seed)
    # tensorboard summary
    tb_writer = SummaryWriter(save_path)

    # Data preprocessing (same pipeline for training and validation)
    train_transform = valid_transform = custom_transforms.Compose([
        custom_transforms.ArrayToTensor(),
        custom_transforms.Normalize(mean=0.5, std=0.5)
    ])

    # Load datasets
    print("=> Fetching scenes in '{}'".format(args.data))

    # Compare by value: `is` tests object identity and is not reliable for
    # strings that come from the command line.
    if args.lfformat == 'focalstack':
        train_set, val_set = get_focal_stack_loaders(args, train_transform,
                                                     valid_transform)
    elif args.lfformat == 'stack':
        train_set, val_set = get_stacked_lf_loaders(args, train_transform,
                                                    valid_transform)
    else:
        raise ValueError(
            "Unsupported lfformat {!r}; expected 'focalstack' or 'stack'".
            format(args.lfformat))

    print('=> {} samples found in {} train scenes'.format(
        len(train_set), len(train_set.scenes)))
    print('=> {} samples found in {} valid scenes'.format(
        len(val_set), len(val_set.scenes)))

    # Create batch loader
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # Pull first example from dataset to check number of channels
    input_channels = train_set[0][1].shape[0]
    args.epoch_size = len(train_loader)
    print("=> Using {} input channels, {} total batches".format(
        input_channels, args.epoch_size))

    # create model
    print("=> Creating models")
    disp_net = models.LFDispNet(in_channels=input_channels).to(torch_device)
    pose_exp_net = models.LFPoseNet(in_channels=input_channels,
                                    nb_ref_imgs=args.sequence_length -
                                    1).to(torch_device)

    # Load or initialize weights. map_location keeps checkpoint loading
    # working on the CPU fallback selected above.
    if args.pretrained_exp_pose:
        print("=> Using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose,
                             map_location=torch_device)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> Using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp, map_location=torch_device)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    # Set some torch flags
    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    # Define optimizer: one parameter group per network, same learning rate
    print('=> Setting adam solver')
    optim_params = [{
        'params': disp_net.parameters(),
        'lr': args.lr
    }, {
        'params': pose_exp_net.parameters(),
        'lr': args.lr
    }]
    optimizer = torch.optim.Adam(optim_params,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # Logging. newline='' is what the csv module asks for; without it rows
    # end in \r\r\n on platforms that translate newlines.
    with open(os.path.join(save_path, args.log_summary), 'w',
              newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(os.path.join(save_path, args.log_full), 'w',
              newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    # train the network
    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(args, train_loader, disp_net, pose_exp_net,
                           optimizer, args.epoch_size, logger, tb_writer,
                           n_iter, torch_device)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        errors, error_names = validate_without_gt(args, val_loader, disp_net,
                                                  pose_exp_net, epoch, logger,
                                                  tb_writer, torch_device)
        error_string = ', '.join('{} : {:.3f}'.format(name, error)
                                 for name, error in zip(error_names, errors))
        logger.valid_writer.write(' * Avg {}'.format(error_string))

        # tensorboard logging
        for error, name in zip(errors, error_names):
            tb_writer.add_scalar(name, error, epoch)

        # Up to you to choose the most relevant error to measure your model's
        # performance; careful, some measures are to be maximised (a1, a2, a3).
        decisive_error = errors[1]
        if best_error < 0:
            best_error = decisive_error

        # remember lowest error and save checkpoint
        is_best = decisive_error < best_error
        best_error = min(best_error, decisive_error)
        save_checkpoint(save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(os.path.join(save_path, args.log_summary), 'a',
                  newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, decisive_error])
    logger.epoch_bar.finish()
示例#29
0
def main():
    """Evaluate predicted rigidity masks against ground truth and print IoU.

    Loads four pre-trained networks (disparity, pose, explainability mask and
    optical flow), runs them over the ``ValidationMask`` set one sample at a
    time, and derives three binary masks per sample:

    * "bare": from the explainability mask alone,
    * "census": from the disagreement between the pose-induced flow and the
      flow network's output, thresholded at ``args.THRESH``,
    * "combined": the union of the two.

    Each mask is scored with ``mask_error``, whose six outputs are
    accumulated as tp/fp/fn counts for class 0 and class 1. The function
    finally prints background, foreground and mean IoU for every variant.
    When ``args.output_dir`` is given, per-sample arrays, visualisation PNGs
    and TensorBoard images are written there as well.
    """
    global args
    args = parser.parse_args()

    args.pretrained_disp = Path(args.pretrained_disp)
    args.pretrained_pose = Path(args.pretrained_pose)
    args.pretrained_mask = Path(args.pretrained_mask)
    args.pretrained_flow = Path(args.pretrained_flow)

    if args.output_dir is not None:
        args.output_dir = Path(args.output_dir)
        args.output_dir.makedirs_p()

        image_dir = args.output_dir / 'images'
        gt_dir = args.output_dir / 'gt'
        mask_dir = args.output_dir / 'mask'
        viz_dir = args.output_dir / 'viz'
        rigidity_mask_dir = args.output_dir / 'rigidity'
        rigidity_census_mask_dir = args.output_dir / 'rigidity_census'
        explainability_mask_dir = args.output_dir / 'explainability'

        for directory in (image_dir, gt_dir, mask_dir, viz_dir,
                          rigidity_mask_dir, rigidity_census_mask_dir,
                          explainability_mask_dir):
            directory.makedirs_p()

        output_writer = SummaryWriter(args.output_dir)

    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    flow_loader_h, flow_loader_w = 256, 832
    valid_flow_transform = custom_transforms.Compose([
        custom_transforms.Scale(h=flow_loader_h, w=flow_loader_w),
        custom_transforms.ArrayToTensor(), normalize
    ])
    val_flow_set = ValidationMask(root=args.kitti_dir,
                                  sequence_length=5,
                                  transform=valid_flow_transform)

    val_loader = torch.utils.data.DataLoader(val_flow_set,
                                             batch_size=1,
                                             shuffle=False,
                                             num_workers=2,
                                             pin_memory=True,
                                             drop_last=True)

    # Network classes are looked up by name in the `models` module.
    disp_net = getattr(models, args.dispnet)().cuda()
    pose_net = getattr(models, args.posenet)(nb_ref_imgs=4).cuda()
    mask_net = getattr(models, args.masknet)(nb_ref_imgs=4).cuda()
    flow_net = getattr(models, args.flownet)(nlevels=args.nlevels).cuda()

    dispnet_weights = torch.load(args.pretrained_disp)
    posenet_weights = torch.load(args.pretrained_pose)
    masknet_weights = torch.load(args.pretrained_mask)
    flownet_weights = torch.load(args.pretrained_flow)
    disp_net.load_state_dict(dispnet_weights['state_dict'])
    pose_net.load_state_dict(posenet_weights['state_dict'])
    flow_net.load_state_dict(flownet_weights['state_dict'])
    mask_net.load_state_dict(masknet_weights['state_dict'])

    disp_net.eval()
    pose_net.eval()
    mask_net.eval()
    flow_net.eval()

    error_names = ['tp_0', 'fp_0', 'fn_0', 'tp_1', 'fp_1', 'fn_1']
    errors = AverageMeter(i=len(error_names))
    errors_census = AverageMeter(i=len(error_names))
    errors_bare = AverageMeter(i=len(error_names))

    # flow_gt is part of each sample but is not needed for mask evaluation.
    for i, (tgt_img, ref_imgs, intrinsics, intrinsics_inv, flow_gt, obj_map_gt,
            semantic_map_gt) in enumerate(tqdm(val_loader)):
        # NOTE(review): `volatile=True` is the legacy way to disable autograd;
        # on newer PyTorch releases it is ignored and `torch.no_grad()` is the
        # replacement -- confirm the targeted torch version before changing.
        tgt_img_var = Variable(tgt_img.cuda(), volatile=True)
        ref_imgs_var = [
            Variable(img.cuda(), volatile=True) for img in ref_imgs
        ]
        intrinsics_var = Variable(intrinsics.cuda(), volatile=True)
        intrinsics_inv_var = Variable(intrinsics_inv.cuda(), volatile=True)

        disp = disp_net(tgt_img_var)
        depth = 1 / disp
        pose = pose_net(tgt_img_var, ref_imgs_var)
        explainability_mask = mask_net(tgt_img_var, ref_imgs_var)
        if args.flownet in ['Back2Future']:
            flow_fwd, flow_bwd, _ = flow_net(tgt_img_var, ref_imgs_var[1:3])
        else:
            flow_fwd = flow_net(tgt_img_var, ref_imgs_var[2])
        # Flow induced by depth and the pose at index 2.
        flow_cam = pose2flow(depth.squeeze(1), pose[:, 2], intrinsics_var,
                             intrinsics_inv_var)

        # "Bare" mask from explainability channels 1 and 2, thresholded at 0.5.
        # NOTE(review): `.unsqueeze(1)` binds only to the second factor, so the
        # product relies on broadcasting; with batch_size=1 this presumably
        # yields a (1, 1, H, W) mask -- confirm before raising the batch size.
        rigidity_mask = 1 - (1 - explainability_mask[:, 1]) * (
            1 - explainability_mask[:, 2]).unsqueeze(1) > 0.5
        # "Census" mask: per-pixel L2 distance between the two flows,
        # normalised by its maximum and inverted, then thresholded.
        rigidity_mask_census_soft = (flow_cam - flow_fwd).pow(2).sum(
            dim=1).unsqueeze(1).sqrt()  #.normalize()
        rigidity_mask_census_soft = 1 - rigidity_mask_census_soft / rigidity_mask_census_soft.max(
        )
        rigidity_mask_census = rigidity_mask_census_soft > args.THRESH

        # Union of the two binary masks: 1 - (1 - a) * (1 - b).
        rigidity_mask_combined = 1 - (
            1 - rigidity_mask.type_as(explainability_mask)) * (
                1 - rigidity_mask_census.type_as(explainability_mask))

        # Take the pose-induced flow where the combined mask is set and the
        # flow network's output elsewhere.
        flow_fwd_non_rigid = (1 - rigidity_mask_combined).type_as(
            flow_fwd).expand_as(flow_fwd) * flow_fwd
        flow_fwd_rigid = rigidity_mask_combined.type_as(flow_fwd).expand_as(
            flow_fwd) * flow_cam
        total_flow = flow_fwd_rigid + flow_fwd_non_rigid

        tgt_img_np = tgt_img[0].numpy()
        rigidity_mask_combined_np = rigidity_mask_combined.cpu().data[0].numpy(
        )
        rigidity_mask_census_np = rigidity_mask_census.cpu().data[0].numpy()
        rigidity_mask_bare_np = rigidity_mask.cpu().data[0].numpy()

        gt_mask_np = obj_map_gt[0].numpy()
        semantic_map_np = semantic_map_gt[0].numpy()

        _errors = mask_error(gt_mask_np, semantic_map_np,
                             rigidity_mask_combined_np[0])
        _errors_census = mask_error(gt_mask_np, semantic_map_np,
                                    rigidity_mask_census_np[0])
        _errors_bare = mask_error(gt_mask_np, semantic_map_np,
                                  rigidity_mask_bare_np[0])

        errors.update(_errors)
        errors_census.update(_errors_census)
        errors_bare.update(_errors_bare)

        if args.output_dir is not None:
            stem = str(i).zfill(3)
            np.save(image_dir / stem, tgt_img_np)
            np.save(gt_dir / stem, gt_mask_np)
            np.save(mask_dir / stem, rigidity_mask_combined_np)
            np.save(rigidity_mask_dir / stem, rigidity_mask_bare_np)
            np.save(rigidity_census_mask_dir / stem, rigidity_mask_census_np)
            np.save(explainability_mask_dir / stem,
                    explainability_mask[:, 1].cpu().data[0].numpy())

        # Every tenth sample is also sent to TensorBoard, all at the same step.
        if (args.output_dir is not None) and i % 10 == 0:
            ind = int(i // 10)
            output_writer.add_image(
                'val Dispnet Output Normalized',
                tensor2array(disp.data[0].cpu(),
                             max_value=None,
                             colormap='bone'), ind)
            # Logged at `ind` (previously `i`) so the input image lines up
            # with the other images of the same sample.
            output_writer.add_image('val Input',
                                    tensor2array(tgt_img[0].cpu()), ind)
            output_writer.add_image(
                'val Total Flow Output',
                flow_to_image(tensor2array(total_flow.data[0].cpu())), ind)
            output_writer.add_image(
                'val Rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_rigid.data[0].cpu())), ind)
            output_writer.add_image(
                'val Non-rigid Flow Output',
                flow_to_image(tensor2array(flow_fwd_non_rigid.data[0].cpu())),
                ind)
            output_writer.add_image(
                'val Rigidity Mask',
                tensor2array(rigidity_mask.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Census',
                tensor2array(rigidity_mask_census.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)
            output_writer.add_image(
                'val Rigidity Mask Combined',
                tensor2array(rigidity_mask_combined.data[0].cpu(),
                             max_value=1,
                             colormap='bone'), ind)

        if args.output_dir is not None:
            tgt_img_viz = tensor2array(tgt_img[0].cpu())
            depth_viz = tensor2array(disp.data[0].cpu(),
                                     max_value=None,
                                     colormap='magma')
            mask_viz = tensor2array(rigidity_mask_census_soft.data[0].cpu(),
                                    max_value=1,
                                    colormap='bone')
            row2_viz = flow_to_image(
                np.hstack((tensor2array(flow_cam.data[0].cpu()),
                           tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                           tensor2array(total_flow.data[0].cpu()))))

            row1_viz = np.hstack((tgt_img_viz, depth_viz, mask_viz))
            # Local modification: the two stacking calls below were changed
            # from vstack to hstack.
            viz3 = np.hstack(
                (255 * tgt_img_viz, 255 * depth_viz, 255 * mask_viz,
                 flow_to_image(
                     np.hstack((tensor2array(flow_fwd_non_rigid.data[0].cpu()),
                                tensor2array(total_flow.data[0].cpu()))))))
            # Local addition: reorder axes to (1, 2, 0) before handing the
            # arrays to PIL -- presumably channel-first to channel-last.
            row1_viz = np.transpose(row1_viz, (1, 2, 0))
            row2_viz = np.transpose(row2_viz, (1, 2, 0))
            viz3 = np.transpose(viz3, (1, 2, 0))

            # row1 is scaled by 255 here; row2 and viz3 are saved unscaled.
            row1_viz_im = Image.fromarray((255 * row1_viz).astype('uint8'))
            row2_viz_im = Image.fromarray((row2_viz).astype('uint8'))
            viz3_im = Image.fromarray(viz3.astype('uint8'))

            row1_viz_im.save(viz_dir / str(i).zfill(3) + '01.png')
            row2_viz_im.save(viz_dir / str(i).zfill(3) + '02.png')
            viz3_im.save(viz_dir / str(i).zfill(3) + '03.png')

    # IoU = tp / (tp + fp + fn), from the accumulated sums. Indices 0-2 are
    # reported as background, 3-5 as foreground.
    for title, meter in (("Results Full Model", errors),
                         ("Results Census only", errors_census),
                         ("Results Bare", errors_bare)):
        sums = meter.sum
        bg_iou = sums[0] / (sums[0] + sums[1] + sums[2])
        fg_iou = sums[3] / (sums[3] + sums[4] + sums[5])
        avg_iou = (bg_iou + fg_iou) / 2

        print(title)
        print("\t {:>10}, {:>10}, {:>10} ".format('iou', 'bg_iou', 'fg_iou'))
        print("Errors \t {:10.4f}, {:10.4f} {:10.4f}".format(
            avg_iou, bg_iou, fg_iou))
示例#30
0
def main():
    """Train DispNetS and PoseExpNet and keep the best checkpoint.

    Arguments come from the module-level ``parser``. ``args.dataset_format``
    selects which ``SequenceFolder`` implementation is imported (``'stacked'``
    or ``'sequential'``). Outputs go to a timestamped folder under
    ``checkpoints/``. Each epoch trains, validates, writes the validation
    losses to TensorBoard, saves a checkpoint of both networks (flagged as
    best when the validation photometric loss improves), and appends a row to
    the tab-separated summary log.

    Raises:
        ValueError: if ``args.dataset_format`` is not a supported value.
    """
    global args, best_photo_loss, n_iter
    args = parser.parse_args()
    if args.dataset_format == 'stacked':
        from datasets.stacked_sequence_folders import SequenceFolder
    elif args.dataset_format == 'sequential':
        from datasets.sequence_folders import SequenceFolder
    else:
        # Fail here with a clear message instead of hitting an unbound
        # `SequenceFolder` further down.
        raise ValueError(
            "Unsupported dataset_format {!r}; expected 'stacked' or "
            "'sequential'".format(args.dataset_format))
    save_path = Path('{}epochs{},seq{},b{},lr{},p{},m{},s{}'.format(
        args.epochs,
        ',epochSize' + str(args.epoch_size) if args.epoch_size > 0 else '',
        args.sequence_length, args.batch_size, args.lr, args.photo_loss_weight,
        args.mask_loss_weight, args.smooth_loss_weight))
    timestamp = datetime.datetime.now().strftime("%m-%d-%H:%M")
    args.save_path = 'checkpoints' / save_path / timestamp
    print('=> will save everything to {}'.format(args.save_path))
    args.save_path.makedirs_p()
    torch.manual_seed(args.seed)

    train_writer = SummaryWriter(args.save_path / 'train')
    valid_writer = SummaryWriter(args.save_path / 'valid')
    # Extra writers under <save_path>/valid/0..2, created only when
    # args.log_output is set and handed to validate() below.
    output_writers = []
    if args.log_output:
        output_writers = [
            SummaryWriter(args.save_path / 'valid' / str(i)) for i in range(3)
        ]

    # Data loading code
    normalize = custom_transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                            std=[0.5, 0.5, 0.5])
    input_transform = custom_transforms.Compose([
        custom_transforms.RandomHorizontalFlip(),
        custom_transforms.RandomScaleCrop(),
        custom_transforms.ArrayToTensor(), normalize
    ])

    print("=> fetching scenes in '{}'".format(args.data))
    train_set = SequenceFolder(args.data,
                               transform=input_transform,
                               seed=args.seed,
                               train=True,
                               sequence_length=args.sequence_length)
    # Validation uses no random augmentation, only tensor conversion and
    # normalisation.
    val_set = SequenceFolder(args.data,
                             transform=custom_transforms.Compose([
                                 custom_transforms.ArrayToTensor(), normalize
                             ]),
                             seed=args.seed,
                             train=False,
                             sequence_length=args.sequence_length)
    print('{} samples found in {} train scenes'.format(len(train_set),
                                                       len(train_set.scenes)))
    print('{} samples found in {} valid scenes'.format(len(val_set),
                                                       len(val_set.scenes)))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    # An epoch size of 0 means "one full pass over the training loader".
    if args.epoch_size == 0:
        args.epoch_size = len(train_loader)

    # create model
    print("=> creating model")

    disp_net = models.DispNetS().cuda()
    output_exp = args.mask_loss_weight > 0
    if not output_exp:
        print("=> no mask loss, PoseExpnet will only output pose")
    pose_exp_net = models.PoseExpNet(
        nb_ref_imgs=args.sequence_length - 1,
        output_exp=output_exp).cuda()

    if args.pretrained_exp_pose:
        print("=> using pre-trained weights for explainabilty and pose net")
        weights = torch.load(args.pretrained_exp_pose)
        pose_exp_net.load_state_dict(weights['state_dict'], strict=False)
    else:
        pose_exp_net.init_weights()

    if args.pretrained_disp:
        print("=> using pre-trained weights for Dispnet")
        weights = torch.load(args.pretrained_disp)
        disp_net.load_state_dict(weights['state_dict'])
    else:
        disp_net.init_weights()

    cudnn.benchmark = True
    disp_net = torch.nn.DataParallel(disp_net)
    pose_exp_net = torch.nn.DataParallel(pose_exp_net)

    print('=> setting adam solver')

    # A single optimiser over the parameters of both networks.
    parameters = chain(disp_net.parameters(), pose_exp_net.parameters())
    optimizer = torch.optim.Adam(parameters,
                                 args.lr,
                                 betas=(args.momentum, args.beta),
                                 weight_decay=args.weight_decay)

    # newline='' is what the csv module asks for; without it rows end in
    # \r\r\n on platforms that translate newlines.
    with open(args.save_path / args.log_summary, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(['train_loss', 'validation_loss'])

    with open(args.save_path / args.log_full, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile, delimiter='\t')
        writer.writerow(
            ['train_loss', 'photo_loss', 'explainability_loss', 'smooth_loss'])

    logger = TermLogger(n_epochs=args.epochs,
                        train_size=min(len(train_loader), args.epoch_size),
                        valid_size=len(val_loader))
    logger.epoch_bar.start()

    for epoch in range(args.epochs):
        logger.epoch_bar.update(epoch)

        # train for one epoch
        logger.reset_train_bar()
        train_loss = train(train_loader, disp_net, pose_exp_net, optimizer,
                           args.epoch_size, logger, train_writer)
        logger.train_writer.write(' * Avg Loss : {:.3f}'.format(train_loss))

        # evaluate on validation set
        logger.reset_valid_bar()
        valid_photo_loss, valid_exp_loss, valid_total_loss = validate(
            val_loader, disp_net, pose_exp_net, epoch, logger, output_writers)
        logger.valid_writer.write(
            ' * Avg Photo Loss : {:.3f}, Valid Loss : {:.3f}, Total Loss : {:.3f}'
            .format(valid_photo_loss, valid_exp_loss, valid_total_loss))
        # Loss is multiplied by 4 because it's only one scale, instead of 4
        # during training. The x-axis is the module-level n_iter counter.
        valid_writer.add_scalar('photometric_error', valid_photo_loss * 4,
                                n_iter)
        valid_writer.add_scalar('explanability_loss', valid_exp_loss * 4,
                                n_iter)
        valid_writer.add_scalar('total_loss', valid_total_loss * 4, n_iter)

        # A negative best_photo_loss means no validation result exists yet.
        if best_photo_loss < 0:
            best_photo_loss = valid_photo_loss

        # remember lowest error and save checkpoint
        is_best = valid_photo_loss < best_photo_loss
        best_photo_loss = min(valid_photo_loss, best_photo_loss)
        save_checkpoint(args.save_path, {
            'epoch': epoch + 1,
            'state_dict': disp_net.module.state_dict()
        }, {
            'epoch': epoch + 1,
            'state_dict': pose_exp_net.module.state_dict()
        }, is_best)

        with open(args.save_path / args.log_summary, 'a',
                  newline='') as csvfile:
            writer = csv.writer(csvfile, delimiter='\t')
            writer.writerow([train_loss, valid_total_loss])
    logger.epoch_bar.finish()