Example #1
0
def evaluate(model, args):
    """Evaluate ``model`` on every validation dataset named in ``args.val_dataset``.

    Recognised dataset names are ``"kitti"`` and ``"sintel"``; any other name
    is skipped with a warning.

    Preprocessing: when both ``args.weights`` and ``args.test_only`` are set,
    the transforms bundled with the named torchvision weights are used, with
    non-tensor flow / valid-mask inputs converted via ``torch.from_numpy``.
    Otherwise ``OpticalFlowPresetEval()`` is used.

    Args:
        model: optical-flow model, forwarded unchanged to ``_evaluate``.
        args: parsed command-line namespace. Attributes read directly here:
            ``val_dataset``, ``weights``, ``test_only``, ``batch_size``,
            ``distributed``, ``rank`` and ``dataset_root`` (``args`` is also
            forwarded to ``_evaluate``).
    """
    val_datasets = args.val_dataset or []

    if args.weights and args.test_only:
        weights = torchvision.models.get_weight(args.weights)
        trans = weights.transforms()

        def preprocessing(img1, img2, flow, valid_flow_mask):
            # The weights' transforms are only applied to the image pair; the
            # ground-truth flow and validity mask are converted to tensors here
            # when they are not tensors already (from_numpy implies numpy input).
            img1, img2 = trans(img1, img2)
            if flow is not None and not isinstance(flow, torch.Tensor):
                flow = torch.from_numpy(flow)
            if valid_flow_mask is not None and not isinstance(
                    valid_flow_mask, torch.Tensor):
                valid_flow_mask = torch.from_numpy(valid_flow_mask)
            return img1, img2, flow, valid_flow_mask

    else:
        preprocessing = OpticalFlowPresetEval()

    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and (not args.distributed
                                         or args.rank == 0):
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root,
                                    split="train",
                                    transforms=preprocessing)
            _evaluate(model,
                      args,
                      val_dataset,
                      num_flow_updates=24,
                      padder_mode="kitti",
                      header="Kitti val",
                      batch_size=1)
        elif name == "sintel":
            # Both Sintel passes are evaluated, each as a separate run.
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(root=args.dataset_root,
                                     split="train",
                                     pass_name=pass_name,
                                     transforms=preprocessing)
                _evaluate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            # Report the rejected *name*. ``val_dataset`` may be unbound here
            # (UnboundLocalError when the first entry is unknown) or may still
            # refer to a previously evaluated dataset.
            warnings.warn(f"Can't validate on {name}, skipping.")
Example #2
0
def evaluate(model, args):
    """Evaluate ``model`` on every validation dataset named in ``args.val_dataset``.

    Recognised dataset names are ``"kitti"`` and ``"sintel"``; any other name
    is skipped with a warning.

    Preprocessing: with ``args.prototype`` set, use the transforms of the
    ``prototype`` weights named by ``args.weights`` when given, otherwise
    ``prototype.transforms.OpticalFlowEval()``. Without ``args.prototype``,
    use ``OpticalFlowPresetEval()``.

    Args:
        model: optical-flow model, forwarded unchanged to ``_evaluate``.
        args: parsed command-line namespace. Attributes read directly here:
            ``val_dataset``, ``prototype``, ``weights``, ``batch_size``,
            ``distributed``, ``rank`` and ``dataset_root`` (``args`` is also
            forwarded to ``_evaluate``).
    """
    val_datasets = args.val_dataset or []

    if args.prototype:
        if args.weights:
            weights = prototype.models.get_weight(args.weights)
            preprocessing = weights.transforms()
        else:
            preprocessing = prototype.transforms.OpticalFlowEval()
    else:
        preprocessing = OpticalFlowPresetEval()

    for name in val_datasets:
        if name == "kitti":
            # Kitti has different image sizes so we need to individually pad them, we can't batch.
            # see comment in InputPadder
            if args.batch_size != 1 and (not args.distributed
                                         or args.rank == 0):
                warnings.warn(
                    f"Batch-size={args.batch_size} was passed. For technical reasons, evaluating on Kitti can only be done with a batch-size of 1."
                )

            val_dataset = KittiFlow(root=args.dataset_root,
                                    split="train",
                                    transforms=preprocessing)
            _evaluate(model,
                      args,
                      val_dataset,
                      num_flow_updates=24,
                      padder_mode="kitti",
                      header="Kitti val",
                      batch_size=1)
        elif name == "sintel":
            # Both Sintel passes are evaluated, each as a separate run.
            for pass_name in ("clean", "final"):
                val_dataset = Sintel(root=args.dataset_root,
                                     split="train",
                                     pass_name=pass_name,
                                     transforms=preprocessing)
                _evaluate(
                    model,
                    args,
                    val_dataset,
                    num_flow_updates=32,
                    padder_mode="sintel",
                    header=f"Sintel val {pass_name}",
                )
        else:
            # Report the rejected *name*. ``val_dataset`` may be unbound here
            # (UnboundLocalError when the first entry is unknown) or may still
            # refer to a previously evaluated dataset.
            warnings.warn(f"Can't validate on {name}, skipping.")
Example #3
0
def get_train_dataset(stage, dataset_root):
    """Construct the training dataset for one named training stage.

    Args:
        stage: one of ``"chairs"``, ``"things"``, ``"sintel_SKH"`` or ``"kitti"``.
        dataset_root: root directory passed to every dataset constructor.

    Returns:
        The dataset for ``stage``. For ``"sintel_SKH"`` this is the combination
        ``100 * sintel + 200 * kitti + 5 * hd1k + things_clean``.

    Raises:
        ValueError: if ``stage`` is not one of the names listed above.
    """

    def _chairs():
        chairs_tf = OpticalFlowPresetTrain(crop_size=(368, 496),
                                           min_scale=0.1,
                                           max_scale=1.0,
                                           do_flip=True)
        return FlyingChairs(root=dataset_root,
                            split="train",
                            transforms=chairs_tf)

    def _things():
        things_tf = OpticalFlowPresetTrain(crop_size=(400, 720),
                                           min_scale=-0.4,
                                           max_scale=0.8,
                                           do_flip=True)
        return FlyingThings3D(root=dataset_root,
                              split="train",
                              pass_name="both",
                              transforms=things_tf)

    def _sintel_skh():
        # S + K + H as from paper; all four components share one crop size.
        crop = (368, 768)
        shared_tf = OpticalFlowPresetTrain(crop_size=crop,
                                           min_scale=-0.2,
                                           max_scale=0.6,
                                           do_flip=True)
        things_clean = FlyingThings3D(root=dataset_root,
                                      split="train",
                                      pass_name="clean",
                                      transforms=shared_tf)
        sintel = Sintel(root=dataset_root,
                        split="train",
                        pass_name="both",
                        transforms=shared_tf)
        # Kitti and HD1K each get their own scale range.
        kitti = KittiFlow(
            root=dataset_root,
            split="train",
            transforms=OpticalFlowPresetTrain(crop_size=crop,
                                              min_scale=-0.3,
                                              max_scale=0.5,
                                              do_flip=True),
        )
        hd1k = HD1K(
            root=dataset_root,
            split="train",
            transforms=OpticalFlowPresetTrain(crop_size=crop,
                                              min_scale=-0.5,
                                              max_scale=0.2,
                                              do_flip=True),
        )
        # As future improvement, we could probably be using a distributed sampler here
        # The distribution is S(.71), T(.135), K(.135), H(.02)
        return 100 * sintel + 200 * kitti + 5 * hd1k + things_clean

    def _kitti():
        kitti_tf = OpticalFlowPresetTrain(
            # resize and crop params
            crop_size=(288, 960),
            min_scale=-0.2,
            max_scale=0.4,
            stretch_prob=0,
            # flip params
            do_flip=False,
            # jitter params
            brightness=0.3,
            contrast=0.3,
            saturation=0.3,
            hue=0.3 / 3.14,
            asymmetric_jitter_prob=0,
        )
        return KittiFlow(root=dataset_root,
                         split="train",
                         transforms=kitti_tf)

    # Scan in declaration order using ``stage == key``. This matches the
    # equality semantics of a plain if/elif chain, so stage does not have to
    # be hashable.
    builders = (
        ("chairs", _chairs),
        ("things", _things),
        ("sintel_SKH", _sintel_skh),
        ("kitti", _kitti),
    )
    for key, build in builders:
        if stage == key:
            return build()
    raise ValueError(f"Unknown stage {stage}")