    def __init__(
        self,
        root: Union[str, Path],
        record_set: RecordSet,
        sampler: FrameSampler = _default_sampler(),
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        frame_counter: Optional[Callable[[Path], int]] = None,
    ) -> None:
        self.root = root
        self.sampler = sampler
        self.record_set = record_set

        # Fall back to decoding each video file to count its frames.
        if frame_counter is None:
            frame_counter = _get_videofile_frame_count
        self.frame_counter = frame_counter

        # Default transform converts the list of PIL frames into a tensor.
        if transform is None:
            transform = PILVideoToTensor()
        self.transform = transform

        # Labels default to being cast to int.
        if target_transform is None:
            target_transform = int
        self.target_transform = target_transform
        # Cache of video lengths.
        self.video_lens = {}
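
The frame_counter hook above lets you skip decoding a video just to count its frames. A minimal, self-contained sketch of a precomputed counter (the lookup table and helper name are hypothetical):

from pathlib import Path

# Hypothetical precomputed table mapping video filenames to frame counts.
_FRAME_COUNTS = {"video_a.mp4": 120, "video_b.mp4": 96}

def precomputed_frame_counter(video_path: Path) -> int:
    # Look the count up by filename instead of decoding the file.
    return _FRAME_COUNTS[video_path.name]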
    def __init__(
        self,
        root_path: Union[str, Path],
        filter: Optional[Callable[[Path], bool]] = None,
        label_set: Optional[LabelSet] = None,
        sampler: FrameSampler = _default_sampler(),
        transform: Optional[PILVideoTransform] = None,
        frame_counter: Optional[Callable[[Path], int]] = None,
    ) -> None:
        """
        Args:
            root_path: Path to dataset folder on disk. The contents of this folder
                should be video files.
            filter: Optional filter callable that decides whether a given example video
                is to be included in the dataset or not.
            label_set: Optional label set for labelling examples.
            sampler: Optional sampler for drawing frames from each video.
            transform: Optional transform over the list of frames.
            frame_counter: Optional callable used to determine the number of frames
                each video contains. The callable will be passed the path to a video and
                should return a positive integer representing the number of frames.
                This tends to be useful if you've precomputed the number of frames in a
                dataset.
        """
        if transform is None:
            transform = PILVideoToTensor()
        super().__init__(root_path,
                         label_set=label_set,
                         sampler=sampler,
                         transform=transform)
        self._video_paths = self._get_video_paths(self.root_path, filter)
        self.labels = self._label_examples(self._video_paths, label_set)
        self.video_lengths = self._measure_video_lengths(
            self._video_paths, frame_counter)
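
The signature above appears to match torchvideo's VideoFolderDataset, so a hedged usage sketch looks like this (class and import path assumed, data paths illustrative):

from torchvideo.datasets import VideoFolderDataset  # import path assumed

# Keep only .mp4 files; everything else falls back to the documented defaults,
# including the PILVideoToTensor transform installed above when transform is None.
dataset = VideoFolderDataset(
    root_path="data/videos",
    filter=lambda path: path.suffix == ".mp4",
)
clip = dataset[0]  # the default PILVideoToTensor yields a CTHW float tensor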
Example #3
    def test_disabled_rescale(self):
        transform = PILVideoToTensor(rescale=False)
        frame_arr = 255 * np.ones(shape=(10, 20, 3), dtype=np.uint8)
        frame_arr[0:5, 0:10, :] = 0
        video = [PIL.Image.fromarray(frame_arr)]
        tensor = transform(video)

        assert tensor.min().item() == 0
        assert tensor.max().item() == 255
Example #4
    def test_transform(self, video):
        transform = PILVideoToTensor()
        tensor = transform(video)
        width, height = video[0].size
        n_channels = 3 if video[0].mode == "RGB" else 1
        # Default ordering is CTHW: (channels, time, height, width).
        assert tensor.size(0) == n_channels
        assert tensor.size(1) == len(video)
        assert tensor.size(2) == height
        assert tensor.size(3) == width
Example #5
    def test_rescales_between_0_and_1(self):
        transform = PILVideoToTensor()
        frame_arr = 255 * np.ones(shape=(10, 20, 3), dtype=np.uint8)
        frame_arr[0:5, 0:10, :] = 0
        video = [PIL.Image.fromarray(frame_arr)]
        tensor = transform(video)

        assert tensor.min().item() == 0
        assert tensor.max().item() == 1
    def test_raises_exception_if_ordering_isnt_tchw_or_cthw(self):
        invalid_orderings = [
            "".join(order) for order in permutations(list("TCHW"))
            if "".join(order) not in ["TCHW", "CTHW"]
        ]

        for invalid_ordering in invalid_orderings:
            with pytest.raises(ValueError):
                PILVideoToTensor(ordering=invalid_ordering)
    def test_mapping_to_cthw_ordering(self):
        transform = PILVideoToTensor(ordering="CTHW", rescale=False)
        frames = [
            Image.fromarray(frame) for frame in np.random.randint(
                low=0, high=255, size=(5, 4, 4, 3), dtype=np.uint8)
        ]
        transformed_frames = transform(frames)
        for frame_index, frame in enumerate(frames):
            assert (frame.getpixel((0, 0)) ==
                    transformed_frames[:, frame_index, 0, 0].numpy()).all()
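
A complementary sketch for the non-default "TCHW" ordering, assuming its indexing convention mirrors the CTHW test above (frame index first, channels second):

import numpy as np
from PIL import Image
from torchvideo.transforms import PILVideoToTensor  # import path assumed

transform = PILVideoToTensor(ordering="TCHW", rescale=False)
frames = [
    Image.fromarray(frame) for frame in np.random.randint(
        low=0, high=255, size=(5, 4, 4, 3), dtype=np.uint8)
]
tensor = transform(frames)
for frame_index, frame in enumerate(frames):
    # With TCHW the frame index selects dim 0 and channels live in dim 1.
    assert (frame.getpixel((0, 0)) == tensor[frame_index, :, 0, 0].numpy()).all()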
Example #8
    def __init__(self,
                 record_set,
                 filter=None,
                 label_set=None,
                 sampler=_default_sampler(),
                 loader=default_loader,
                 transform=PILVideoToTensor(),
                 target_transform=int):
        self.filter = filter
        # Drop any records the filter rejects.
        self.record_set = record_set if filter is None else [
            r for r in record_set if filter(r)
        ]
        self.label_set = label_set
        self.labels = self._label_examples(self.record_set, self.label_set)
        self.sampler = sampler
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform
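
Unlike the path-based datasets earlier, this filter receives whole records. A hedged sketch of a record filter (the label attribute is an assumption about the record type):

# Hypothetical record filter: keep only records belonging to a label whitelist.
# The `label` attribute is assumed; substitute whatever your record type exposes.
KEEP_LABELS = {"running", "walking"}

def label_whitelist_filter(record) -> bool:
    return getattr(record, "label", None) in KEEP_LABELS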
    def __init__(
        self,
        root_path: Union[str, Path],
        filename_template: str,
        filter: Optional[Callable[[Path], bool]] = None,
        label_set: Optional[LabelSet] = None,
        sampler: FrameSampler = _default_sampler(),
        transform: Optional[PILVideoTransform] = None,
        frame_counter: Optional[Callable[[Path], int]] = None,
    ):
        """

        Args:
            root_path: Path to dataset on disk. Contents of this folder should be
                example folders, each with frames named according to the
                ``filename_template`` argument.
            filename_template: Python 3 style formatting string describing frame
                filenames: e.g. ``"frame_{:06d}.jpg"`` for the example dataset in the
                class docstring.
            filter: Optional filter callable that decides whether a given example folder
                is to be included in the dataset or not.
            label_set: Optional label set for labelling examples.
            sampler: Optional sampler for drawing frames from each video.
            transform: Optional transform performed over the loaded clip.
            frame_counter: Optional callable used to determine the number of frames
                each video contains. The callable will be passed the path to a video
                folder and should return a positive integer representing the number of
                frames. This tends to be useful if you've precomputed the number of
                frames in a dataset.
        """
        super().__init__(root_path,
                         label_set,
                         sampler=sampler,
                         transform=transform)
        self._video_dirs = sorted([
            d for d in self.root_path.iterdir() if filter is None or filter(d)
        ])
        self.labels = self._label_examples(self._video_dirs, label_set)
        self.video_lengths = self._measure_video_lengths(
            self._video_dirs, frame_counter)
        self.filename_template = filename_template
        if self.transform is None:
            self.transform = PILVideoToTensor()
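
This signature appears to match torchvideo's ImageFolderVideoDataset; a hedged usage sketch (import path assumed, directory layout illustrative):

from torchvideo.datasets import ImageFolderVideoDataset  # import path assumed

# Each subfolder of root_path holds frames named according to the template,
# e.g. frame_000042.jpg.
dataset = ImageFolderVideoDataset(
    root_path="data/frame_folders",
    filename_template="frame_{:06d}.jpg",
)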
Example #10
def main(args) -> None:
    sampler = make_sampler(args)
    dataset = make_dataset(
        args,
        sampler=sampler,
        transform=Compose([CenterCropVideo(100), CollectFrames(), PILVideoToTensor()]),
    )
    loader = DataLoader(
        dataset,
        num_workers=args.workers,
        batch_size=args.batch_size,
        shuffle=args.shuffle,
        pin_memory=args.pin_memory,
    )
    benchmark_dataloader(
        loader,
        max_iterations=args.max_iterations,
        profile=args.profile,
        profile_callgrind=args.profile_callgrind,
    )
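
One detail worth noting in the transform chain above: CollectFrames materialises the lazily sampled frame iterator into a list, which is the input PILVideoToTensor expects. A minimal sketch of just the chain (import paths assumed):

from torchvideo.transforms import (  # import paths assumed
    CenterCropVideo, CollectFrames, Compose, PILVideoToTensor)

transform = Compose([
    CenterCropVideo(100),  # crop every PIL frame to 100x100
    CollectFrames(),       # realise the lazy frame iterator as a list
    PILVideoToTensor(),    # list of PIL frames -> CTHW float tensor
])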
Example #11
IMAGE_TEST_TRANSFORMS = T.Compose([
    # image_rescale_zero_to_1_transform(),
    T.ToPILImage(),
    T.Resize(cfg.RESIZE),
    T.CenterCrop(cfg.CROP_SIZE),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
Y_TRANSFORMS = ApplyToKeysTransform(torch.FloatTensor)
# Y_TRANSFORMS = None
VIDEO_TRAIN_TRANSFORMS = T.Compose([
    ResizeVideo(cfg.RESIZE),
    RandomCropVideo(cfg.CROP_SIZE),
    RandomHorizontalFlipVideo(),
    CollectFrames(),
    PILVideoToTensor(rescale=True),
    RescaleInRange(-1, 1)
])
VIDEO_TEST_TRANSFORMS = T.Compose([
    ResizeVideo((cfg.RESIZE, cfg.RESIZE)),
    CenterCropVideo((cfg.CROP_SIZE, cfg.CROP_SIZE)),
    CollectFrames(),
    PILVideoToTensor(rescale=True),
    RescaleInRange(-1, 1)
])
FRAMES_TRAIN_TRANSFORMS = T.Compose([VIDEO_TRAIN_TRANSFORMS, TimeToChannel()])
FRAMES_TEST_TRANSFORMS = T.Compose([VIDEO_TEST_TRANSFORMS, TimeToChannel()])
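
TimeToChannel, appended to the video pipelines above, folds the time dimension into channels so a stack of frames can be fed to a 2D CNN. A shape sketch (the exact output layout is an assumption about TimeToChannel):

import torch
from torchvideo.transforms import TimeToChannel  # import path assumed

clip = torch.rand(3, 8, 112, 112)  # CTHW, as produced by PILVideoToTensor
flat = TimeToChannel()(clip)
print(flat.shape)  # assumed result: torch.Size([24, 112, 112]), i.e. (C*T, H, W)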


def get_dataset(dset_name,
                train_transforms=None,
Example #12
    def test_propagates_label_unchanged(self):
        video = pil_video(min_width=1, min_height=1).example()
        transform = PILVideoToTensor()

        assert_preserves_label(transform, video)
Example #13
    def test_repr(self):
        assert repr(PILVideoToTensor()) == "PILVideoToTensor()"
    def test_repr(self):
        assert (repr(PILVideoToTensor()) ==
                "PILVideoToTensor(rescale=True, ordering='CTHW')")
Example #15
def get_transforms(
    args, model_settings: RGB2DModelSettings
) -> Tuple[Callable[[Any], torch.Tensor], Callable[[Any], torch.Tensor]]:
    train_transforms = []

    # model_settings.input_size is to be interpreted based on model_settings.input_order
    input_order = model_settings.input_order.lower()
    if input_order.endswith("hw"):
        model_input_size = model_settings.input_size[-2:]
        input_height, input_width = model_input_size
    else:
        raise NotImplementedError(
            "Unsupported input ordering: {}".format(input_order))

    if args.augment_hflip:
        LOG.info("Using horizontal flipping")
        train_transforms.append(RandomHorizontalFlipVideo())
    if args.preserve_aspect_ratio:
        LOG.info(f"Preserving aspect ratio of videos")
        rescaled_size: Union[int, Tuple[int, int]] = int(
            input_height * args.image_scale_factor)
    else:
        rescaled_size = (
            int(input_height * args.image_scale_factor),
            int(input_width * args.image_scale_factor),
        )
        LOG.info(f"Squashing videos to {rescaled_size}")
    train_transforms.append(ResizeVideo(rescaled_size))
    LOG.info(f"Resizing videos to {rescaled_size}")
    if args.augment_crop:
        LOG.info(f"Using multiscale cropping "
                 f"(scales: {args.augment_crop_scales}, "
                 f"fixed_crops: {args.augment_crop_fixed_crops}, "
                 f"more_fixed_crops: {args.augment_crop_more_fixed_crops}"
                 f")")
        train_transforms.append(
            MultiScaleCropVideo(
                model_input_size,
                scales=args.augment_crop_scales,
                fixed_crops=args.augment_crop_fixed_crops,
                more_fixed_crops=args.augment_crop_more_fixed_crops,
            ))
    else:
        LOG.info(f"Cropping videos to {model_input_size}")
        train_transforms.append(RandomCropVideo(model_input_size))

    channel_dim = input_order.find("c")
    if channel_dim == -1:
        raise ValueError(
            f"Could not determine channel position in input_order {input_order!r}"
        )
    if model_settings.input_space == "BGR":
        LOG.info(f"Flipping channels from RGB to BGR")
        channel_transform = FlipChannels(channel_dim)
    else:
        assert model_settings.input_space == "RGB"
        channel_transform = IdentityTransform()
    common_transforms = [
        PILVideoToTensor(
            rescale=model_settings.input_range[-1] != 255,
            ordering=input_order,
        ),
        channel_transform,
        NormalizeVideo(mean=model_settings.mean,
                       std=model_settings.std,
                       channel_dim=channel_dim),
    ]
    train_transform = Compose(train_transforms + common_transforms)
    LOG.info(f"Training transform: {train_transform!r}")
    validation_transform = Compose(
        [ResizeVideo(rescaled_size),
         CenterCropVideo(model_input_size)] + common_transforms)
    LOG.info(f"Validation transform: {validation_transform!r}")
    return train_transform, validation_transform
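
The rescale flag passed to PILVideoToTensor above encodes a simple rule: models whose input_range tops out at 255 expect raw pixel values, so rescaling to [0, 1] is disabled for them. A runnable illustration (the ranges are examples only):

for input_range in ([0, 1], [0, 255]):
    rescale = input_range[-1] != 255
    print(input_range, "-> rescale =", rescale)
# [0, 1] -> rescale = True    (model wants unit-range inputs)
# [0, 255] -> rescale = False (model wants raw pixel values)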