def evaluate(model, args):
    """Validate ``model`` on every dataset named in ``args.val_dataset``.

    Args:
        model: Model forwarded unchanged to ``_evaluate``.
        args: Namespace-like object. This function reads ``val_dataset``,
            ``weights``, ``test_only``, ``batch_size``, ``distributed``,
            ``rank`` and ``dataset_root`` from it, and forwards the whole
            object to ``_evaluate``.

    Recognised dataset names are ``"kitti"`` and ``"sintel"``. Any other name
    triggers a warning and is skipped. A missing or empty ``val_dataset``
    means nothing is evaluated.
    """
    val_datasets = args.val_dataset or []

    if args.weights and args.test_only:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(img1, img2, flow, valid_flow_mask):
            # The weights' transforms are only applied to the image pair.
            img1, img2 = trans(img1, img2)
            # Flow and mask are converted with torch.from_numpy when they are
            # not already tensors (presumably numpy arrays at that point).
            if flow is not None and not isinstance(flow, torch.Tensor):
                flow = torch.from_numpy(flow)
            if valid_flow_mask is not None and not isinstance(valid_flow_mask, torch.Tensor):
                valid_flow_mask = torch.from_numpy(valid_flow_mask)
            return img1, img2, flow, valid_flow_mask

    else:
        preprocessing = OpticalFlowPresetEval()

    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
            _evaluate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                )
                _evaluate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            # Report the requested name. `val_dataset` is unbound here when the
            # unknown name comes first, and otherwise still refers to the
            # previously evaluated dataset object.
            warnings.warn(f"Can't validate on {name}, skipping.")
def evaluate(model, args):
    """Validate ``model`` on every dataset named in ``args.val_dataset``.

    Args:
        model: Model forwarded unchanged to ``_evaluate``.
        args: Namespace-like object. This function reads ``val_dataset``,
            ``prototype``, ``weights``, ``batch_size``, ``distributed``,
            ``rank`` and ``dataset_root`` from it, and forwards the whole
            object to ``_evaluate``.

    Recognised dataset names are ``"kitti"`` and ``"sintel"``. Any other name
    triggers a warning and is skipped. A missing or empty ``val_dataset``
    means nothing is evaluated.
    """
    val_datasets = args.val_dataset or []

    # Pick the evaluation transforms: prototype weights' own transforms when
    # weights are given, the prototype default otherwise, and the reference
    # preset when the prototype API is not requested.
    if args.prototype:
        if args.weights:
            weights = prototype.models.get_weight(args.weights)
            preprocessing = weights.transforms()
        else:
            preprocessing = prototype.transforms.OpticalFlowEval()
    else:
        preprocessing = OpticalFlowPresetEval()

    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and (not args.distributed or args.rank == 0):
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root, split="train", transforms=preprocessing)
            _evaluate(
                model, args, val_dataset, num_flow_updates=24, padder_mode="kitti", header="Kitti val", batch_size=1
            )
        elif name == "sintel":
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(
                    root=args.dataset_root, split="train", pass_name=pass_name, transforms=preprocessing
                )
                _evaluate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            # Report the requested name. `val_dataset` is unbound here when the
            # unknown name comes first, and otherwise still refers to the
            # previously evaluated dataset object.
            warnings.warn(f"Can't validate on {name}, skipping.")
def get_train_dataset(stage, dataset_root):
    """Build the training dataset for the given training ``stage``.

    Args:
        stage: One of ``"chairs"``, ``"things"``, ``"sintel_SKH"`` or ``"kitti"``.
        dataset_root: Root directory passed as ``root`` to every dataset.

    Returns:
        The dataset for the stage. For ``"sintel_SKH"`` this is the weighted
        combination of Sintel, Kitti, HD1K and FlyingThings3D (clean pass).

    Raises:
        ValueError: If ``stage`` is not one of the supported names.
    """
    if stage == "chairs":
        chairs_preset = OpticalFlowPresetTrain(crop_size=(368, 496), min_scale=0.1, max_scale=1.0, do_flip=True)
        return FlyingChairs(root=dataset_root, split="train", transforms=chairs_preset)

    if stage == "things":
        things_preset = OpticalFlowPresetTrain(crop_size=(400, 720), min_scale=-0.4, max_scale=0.8, do_flip=True)
        return FlyingThings3D(root=dataset_root, split="train", pass_name="both", transforms=things_preset)

    if stage == "sintel_SKH":  # S + K + H as from paper
        crop = (368, 768)

        shared_preset = OpticalFlowPresetTrain(crop_size=crop, min_scale=-0.2, max_scale=0.6, do_flip=True)
        things_clean_ds = FlyingThings3D(
            root=dataset_root, split="train", pass_name="clean", transforms=shared_preset
        )
        sintel_ds = Sintel(root=dataset_root, split="train", pass_name="both", transforms=shared_preset)

        kitti_preset = OpticalFlowPresetTrain(crop_size=crop, min_scale=-0.3, max_scale=0.5, do_flip=True)
        kitti_ds = KittiFlow(root=dataset_root, split="train", transforms=kitti_preset)

        hd1k_preset = OpticalFlowPresetTrain(crop_size=crop, min_scale=-0.5, max_scale=0.2, do_flip=True)
        hd1k_ds = HD1K(root=dataset_root, split="train", transforms=hd1k_preset)

        # As future improvement, we could probably be using a distributed sampler here
        # The distribution is S(.71), T(.135), K(.135), H(.02)
        return 100 * sintel_ds + 200 * kitti_ds + 5 * hd1k_ds + things_clean_ds

    if stage == "kitti":
        kitti_kwargs = {
            # resize and crop params
            "crop_size": (288, 960),
            "min_scale": -0.2,
            "max_scale": 0.4,
            "stretch_prob": 0,
            # flip params
            "do_flip": False,
            # jitter params
            "brightness": 0.3,
            "contrast": 0.3,
            "saturation": 0.3,
            "hue": 0.3 / 3.14,
            "asymmetric_jitter_prob": 0,
        }
        return KittiFlow(root=dataset_root, split="train", transforms=OpticalFlowPresetTrain(**kitti_kwargs))

    raise ValueError(f"Unknown stage {stage}")