def __init__(self, config, *args, **kwargs):
        """Build ``self.transform`` as a composition of the configured transforms.

        ``config.transforms`` must be an OmegaConf dict (a single transform) or
        an OmegaConf list of transforms. Each entry is either a dict carrying a
        ``type`` and optional ``params``, or a plain string naming the
        transform, in which case it is constructed without arguments.

        Every transform name is resolved, in order, against the ``transforms``
        module in scope (torchvision, going by the assertion message), then
        ``torchaudio.transforms``, and finally the processor registry.

        Args:
            config: Configuration object exposing a ``transforms`` attribute.
            *args: Accepted but not used by this method.
            **kwargs: Accepted but not used by this method.

        Raises:
            AssertionError: If ``config.transforms`` is neither an OmegaConf
                dict nor list, if an entry is neither a dict nor a ``str``, or
                if a transform name cannot be resolved anywhere.
        """
        specs = config.transforms
        single_spec = OmegaConf.is_dict(specs)
        assert single_spec or OmegaConf.is_list(specs)
        if single_spec:
            # Normalise a lone transform into a one-element sequence
            specs = [specs]

        pipeline = []

        for spec in specs:
            if not OmegaConf.is_dict(spec):
                assert isinstance(spec, str), (
                    "Each transform should either be str or dict containing " +
                    "type and params")
                transform_type = spec
                # A bare name means "construct with no arguments"
                ctor_args = OmegaConf.create([])
            else:
                # Attribute access is expected to raise a config error when
                # ``type`` is missing
                transform_type = spec.type
                ctor_args = spec.get("params", OmegaConf.create({}))

            transform_cls = getattr(transforms, transform_type, None)
            if transform_cls is None:
                # torchaudio is only imported once the first lookup has failed
                from mmf.utils.env import setup_torchaudio

                setup_torchaudio()
                from torchaudio import transforms as torchaudio_transforms

                transform_cls = getattr(torchaudio_transforms, transform_type,
                                        None)
            # If neither torchvision nor torchaudio contains it, check whether
            # a custom transform was registered as a processor
            if transform_cls is None:
                transform_cls = registry.get_processor_class(transform_type)
            assert transform_cls is not None, (
                f"transform {transform_type} is not present in torchvision, " +
                "torchaudio or processor registry")

            # https://github.com/omry/omegaconf/issues/248
            ctor_args = OmegaConf.to_container(ctor_args)
            # A mapping is expanded as **kwargs, anything else as *args
            pipeline.append(
                transform_cls(**ctor_args)
                if isinstance(ctor_args, collections.abc.Mapping)
                else transform_cls(*ctor_args)
            )

        self.transform = transforms.Compose(pipeline)
Beispiel #2
0
    def __init__(self, config: ProcessorConfigType, *args, **kwargs):
        """Instantiate the processor registered under ``config.type``.

        Args:
            config: Processor configuration. Must have a ``type`` attribute
                that is looked up in the processor registry. If it has a
                ``params`` attribute, that value is passed to the processor
                class as its first argument; otherwise an empty dict is passed
                and a warning is logged.
            *args: Extra positional arguments forwarded to the processor class.
            **kwargs: Extra keyword arguments forwarded to the processor class.

        Raises:
            AttributeError: If ``config`` has no ``type`` attribute.
        """
        has_type = hasattr(config, "type")
        if not has_type:
            raise AttributeError(
                "Config must have 'type' attribute to specify type of processor"
            )

        registered_cls = registry.get_processor_class(config.type)

        if hasattr(config, "params"):
            processor_params = config.params
        else:
            processor_params = {}
            logger.warning("Config doesn't have 'params' attribute to "
                           "specify parameters of the processor "
                           f"of type {config.type}. Setting to default {{}}")

        self.processor = registered_cls(processor_params, *args, **kwargs)

        # Snapshot of this wrapper's attribute names, taken after
        # ``self.processor`` is set; its consumer is not visible in this block.
        self._dir_representation = dir(self)
Beispiel #3
0
    def __init__(self, config, *args, **kwargs):
        """Instantiate the processor registered under ``config.type``.

        Also stores the registry's ``"writer"`` entry on ``self.writer``.

        Args:
            config: Processor configuration. Must have a ``type`` attribute
                that is looked up in the processor registry. If it has a
                ``params`` attribute, that value is passed to the processor
                class as its first argument; otherwise an empty dict is passed
                and a warning is issued via ``warnings.warn``.
            *args: Extra positional arguments forwarded to the processor class.
            **kwargs: Extra keyword arguments forwarded to the processor class.

        Raises:
            AttributeError: If ``config`` has no ``type`` attribute.
        """
        self.writer = registry.get("writer")

        has_type = hasattr(config, "type")
        if not has_type:
            raise AttributeError(
                "Config must have 'type' attribute to specify type of processor"
            )

        registered_cls = registry.get_processor_class(config.type)

        if hasattr(config, "params"):
            processor_params = config.params
        else:
            processor_params = {}
            warnings.warn("Config doesn't have 'params' attribute to "
                          "specify parameters of the processor "
                          "of type {}. Setting to default {{}}".format(
                              config.type))

        self.processor = registered_cls(processor_params, *args, **kwargs)

        # Snapshot of this wrapper's attribute names, taken after
        # ``self.processor`` is set; its consumer is not visible in this block.
        self._dir_representation = dir(self)
    def get_transform_object(self, transform_type, transform_params):
        """Resolve ``transform_type`` to a transform class and instantiate it.

        The name is looked up in ``pytorchvideo.transforms`` first, then in the
        processor registry. If either yields a class, it is instantiated via
        ``self.instantiate_transform`` and returned directly. Otherwise the
        name is looked up in ``img_transforms`` (torchvision transforms, per
        the assertion message) and the resulting image transform is wrapped
        between two axis permutations so it can run on video input.

        Args:
            transform_type: Name of the transform to look up.
            transform_params: Parameters handed to
                ``self.instantiate_transform`` together with the resolved class.

        Returns:
            The instantiated transform, or a ``Compose`` of
            permute / image transform / permute for the torchvision fallback.

        Raises:
            AssertionError: If the name is not found in any of the three
                sources.
        """
        from pytorchvideo import transforms as ptv_transforms

        # 1) pytorchvideo.transforms, then 2) processor registry
        video_cls = getattr(ptv_transforms, transform_type, None)
        if video_cls is None:
            video_cls = registry.get_processor_class(transform_type)
        if video_cls is not None:
            return self.instantiate_transform(video_cls, transform_params)

        # 3) torchvision.transforms
        image_cls = getattr(img_transforms, transform_type, None)
        assert image_cls is not None, (
            f"transform {transform_type} is not found in pytorchvideo "
            "transforms, processor registry, or torchvision transforms")

        # An image transform applied to a video needs the axes permuted to
        # (T,C,H,W) beforehand and permuted back afterwards.
        permute_in = ptv_transforms.Permute((1, 0, 2, 3))
        image_op = self.instantiate_transform(image_cls, transform_params)
        permute_out = ptv_transforms.Permute((1, 0, 2, 3))
        return img_transforms.Compose([permute_in, image_op, permute_out])