    def test_mmft_from_params(self):
        modalities_config = [
            MMFTransformerModalityConfig(
                type="image",
                key="image",
                embedding_dim=256,
                position_dim=1,
                segment_id=0,
                encoder=IdentityEncoder.Config(),
            ),
            MMFTransformerModalityConfig(
                type="text",
                key="text",
                embedding_dim=768,
                position_dim=512,
                segment_id=1,
                encoder=IdentityEncoder.Config(),
            ),
        ]
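        # from_params builds the model directly from keyword arguments; the
        # resulting config is expected to equal a structured OmegaConf config
        # built from the same arguments (compared below).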
        mmft = MMFTransformer.from_params(modalities=modalities_config,
                                          num_labels=2)
        mmft.build()

        config = OmegaConf.structured(
            MMFTransformer.Config(modalities=modalities_config, num_labels=2))
        self.assertIsNotNone(mmft)
        self.assertEqual(mmft.config, config)

    def test_modality_key_preprocessing(self):
        self._text_modality_config.key = "body"
        second_text_modality_config = MMFTransformerModalityConfig(
            type="text",
            key="ocr",
            embedding_dim=756,
            position_dim=128,
            segment_id=2,
            encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
        )

        modalities_config = [
            self._image_modality_config,
            self._text_modality_config,
            second_text_modality_config,
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

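        # The modality "key" controls which SampleList field feeds each
        # modality: "body" and "ocr" are both text modalities here, while
        # lm_label_ids supplies the masked-LM labels.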
        sample_list = SampleList()
        sample_list.image = torch.rand(2, 256)
        sample_list.body = torch.randint(0, 512, (2, 128))
        sample_list.ocr = torch.randint(0, 512, (2, 128))
        sample_list.lm_label_ids = torch.randint(-1, 30522, (2, 128))
        lm_labels_sum = sample_list.lm_label_ids.sum().item() * 2

        transformer_input = mmft.preprocess_sample(sample_list)
        self._compare_processed_for_multimodality(transformer_input,
                                                  lm_labels_sum)

    def test_mmft_from_build_model(self):
        modalities_config = [
            MMFTransformerModalityConfig(
                type="image",
                key="image",
                embedding_dim=256,
                position_dim=1,
                segment_id=0,
                encoder=ImageEncoderFactory.Config(
                    type=ImageEncoderTypes.resnet152,
                    params=ResNet152ImageEncoder.Config(pretrained=False),
                ),
            ),
            MMFTransformerModalityConfig(
                type="text",
                key="text",
                embedding_dim=756,
                position_dim=512,
                segment_id=1,
                encoder=TextEncoderFactory.Config(
                    type=TextEncoderTypes.identity),
            ),
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)
        self.assertIsNotNone(mmft)

    def test_tie_mlm_head_weight_to_encoder(self):
        self._text_modality_config = MMFTransformerModalityConfig(
            type="text",
            key="text",
            embedding_dim=768,
            position_dim=128,
            segment_id=0,
            encoder=TextEncoderFactory.Config(
                type=TextEncoderTypes.transformer),
        )
        heads = [MLM.Config()]
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(
            heads=heads,
            modalities=modalities_config,
            num_labels=2,
            tie_weight_to_encoder="text",
        )
        mmft = build_model(config)

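        # With tie_weight_to_encoder="text", the MLM head decoder weight is
        # expected to share the text encoder's word embedding weight.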
        test_utils.compare_tensors(
            mmft.heads[0].cls.predictions.decoder.weight,
            mmft.encoders["text"].embeddings.word_embeddings.weight,
        )
    def test_preprocessing_with_resnet_encoder(self):
        self._image_modality_config = MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=2048,
            position_dim=1,
            segment_id=0,
            encoder=ImageEncoderFactory.Config(
                type=ImageEncoderTypes.resnet152,
                params=ResNet152ImageEncoder.Config(pretrained=False),
            ),
        )
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand(2, 3, 224, 224)
        sample_list.text = torch.randint(0, 512, (2, 128))

        transformer_input = mmft.preprocess_sample(sample_list)

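        # The ResNet152 encoder pools each image into a single 2048-dim
        # feature, so the image modality contributes one token per sample,
        # while the text ids keep their (2, 128) shape.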
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 2048])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))

        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())

    def test_one_dim_feature_preprocessing(self):
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

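        # A one-dimensional (already pooled) image feature per sample is
        # expanded to a single-token sequence, i.e. (2, 256) -> (2, 1, 256).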
        sample_list = SampleList()
        sample_list.image = torch.rand(2, 256)
        sample_list.text = torch.randint(0, 512, (2, 128))

        transformer_input = mmft.preprocess_sample(sample_list)
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 256])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))

        masks = transformer_input["masks"]
        masks = mmft._infer_masks(sample_list, input_ids)
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())

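        # No lm_label_ids were provided, so the combined MLM labels cover the
        # 1 image token + 128 text tokens (129 positions) and are all -1
        # (ignored).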
        mlm_labels = transformer_input["mlm_labels"]
        test_utils.compare_tensors(
            mlm_labels["combined_labels"],
            torch.full((2, 129), dtype=torch.long, fill_value=-1),
        )

    def test_custom_feature_and_mask_preprocessing(self):
        extra_modality = MMFTransformerModalityConfig(
            type="my_random_feature",
            key="my_random_feature",
            embedding_dim=128,
            position_dim=4,
            segment_id=3,
            encoder=EncoderFactory.Config(type="identity"),
        )

        modalities_config = [
            self._image_modality_config,
            self._text_modality_config,
            extra_modality,
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand(2, 256)
        sample_list.text = torch.randint(0, 512, (2, 128))
        sample_list.text_mask = torch.ones(2, 128)
        sample_list.text_mask[:, 70:] = 0
        sample_list.my_random_feature = torch.rand(2, 4, 128)
        sample_list.my_random_feature_mask = torch.ones(2, 4)
        sample_list.my_random_feature_mask[:, 3:] = 0

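        # The provided <key>_mask fields are picked up during preprocessing:
        # text is masked past position 70 (2 * 70 = 140 ones) and
        # my_random_feature past position 3 (2 * 3 = 6 ones).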
        transformer_input = mmft.preprocess_sample(sample_list)
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 256])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        self.assertEqual(input_ids["my_random_feature"].dim(), 3)
        self.assertEqual(list(input_ids["my_random_feature"].size()),
                         [2, 4, 128])

        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))
        test_utils.compare_tensors(
            position_ids["my_random_feature"],
            torch.arange(0, 4).unsqueeze(0).expand((2, 4)),
        )

        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        self.assertEqual(masks["text"].sum().item(), 140)
        self.assertEqual(masks["my_random_feature"].sum().item(), 6)

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())
        test_utils.compare_tensors(
            segment_ids["my_random_feature"],
            torch.full((2, 4), dtype=torch.long, fill_value=3).long(),
        )

    def test_mmft_pretrained(self):
        mmft = MMFTransformer.from_params(num_labels=2)
        self.assertIsNotNone(mmft)

    def test_preprocessing_with_mvit_encoder(self):
        encoder_config = OmegaConf.create({
            "name": "pytorchvideo",
            "model_name": "mvit_base_32x3",
            "random_init": True,
            "drop_last_n_layers": 0,
            "pooler_name": "cls",
            "spatial_size": 224,
            "temporal_size": 8,
            "head": None,
            "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            "pool_kv_stride_adaptive": [1, 8, 8],
            "pool_kvq_kernel": [3, 3, 3],
        })
        self._image_modality_config = MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=768,
            position_dim=1,
            segment_id=0,
            encoder=encoder_config,
        )
        modalities_config = [
            self._image_modality_config, self._text_modality_config
        ]
        config = MMFTransformer.Config(modalities=modalities_config,
                                       num_labels=2)
        mmft = build_model(config)

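        # MViT consumes a 5-dim video clip (batch, channels, time, height,
        # width) matching temporal_size=8 and spatial_size=224; the "cls"
        # pooler reduces it to a single 768-dim token per sample.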
        sample_list = SampleList()
        sample_list.image = torch.rand((2, 3, 8, 224, 224))
        sample_list.text = torch.randint(0, 512, (2, 128))

        transformer_input = mmft.preprocess_sample(sample_list)
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 768])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"],
            torch.arange(0, 128).unsqueeze(0).expand((2, 128)))

        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"],
                                   torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"],
                                   torch.ones((2, 128)).long())
class MMFHandler(BaseHandler):
    """
    Transformers handler class for the MMFTransformerWithVideoAudio model.
    """
    def __init__(self):
        super(MMFHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        self.map_location = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(
            self.map_location + ":" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else self.map_location
        )

        # Read the CSV file that contains all labels in the dataset to build
        # the class/index mapping and to match the model output with the
        # number of labels in the dataset.
        df = pd.read_csv('./charades_action_lables.csv')
        label_set = set()
        df['action_labels'] = df['action_labels'].str.replace('"', '')
        labels_initial = df['action_labels'].tolist()
        labels = []
        for sublist in labels_initial:
            new_sublist = ast.literal_eval(sublist)
            labels.append(new_sublist)
            for item in new_sublist:
                label_set.add(item)
        classes = sorted(list(label_set))
        self.class_to_idx = {classes[i]: i for i in range(len(classes))}
        self.classes = classes
        self.labels = labels
        self.idx_to_class = classes
        config = OmegaConf.load('config.yaml')
        print("*********** config keys **********", config.keys())
        setup_very_basic_config()
        setup_imports()
        self.model = MMFTransformer(config.model_config.mmf_transformer)
        self.model.build()
        self.model.init_losses()
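        # Build the dataset processors declared in the packaged config, then
        # load the trained weights into the freshly constructed model.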
        self.processor = build_processors(
            config.dataset_config["charades"].processors)
        state_dict = torch.load(model_pt_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        self.initialized = True
        print(
            "********* files in the temp directory where the .mar file was extracted *********",
            os.listdir(model_dir))

    def preprocess(self, requests):
        """Preprocessing, based on the processors defined for the MMF model."""
        def create_sample(video_transformed, audio_transformed, text_tensor,
                          video_label):

            label = [self.class_to_idx[l] for l in video_label]

            one_hot_label = torch.zeros(len(self.class_to_idx))
            one_hot_label[label] = 1

            current_sample = Sample()
            current_sample.video = video_transformed
            current_sample.audio = audio_transformed
            current_sample.update(text_tensor)
            current_sample.targets = one_hot_label
            current_sample.dataset_type = 'test'
            current_sample.dataset_name = 'charades'
            return SampleList([current_sample]).to(self.device)

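        # Each request carries raw video bytes ('data'), a transcript
        # ('script'), and labels; note that only the sample built from the
        # last request in the batch is returned below.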
        for idx, data in enumerate(requests):
            raw_script = data.get('script')
            script = raw_script.decode('utf-8')
            raw_label = data.get('labels')
            video_label = raw_label.decode('utf-8')
            video_label = [video_label]

            video = io.BytesIO(data['data'])
            video_tensor, audio_tensor, info = torchvision.io.read_video(video)
            text_tensor = self.processor["text_processor"]({"text": script})
            video_transformed = self.processor["video_test_processor"](
                video_tensor)
            audio_transformed = self.processor["audio_processor"](audio_tensor)
            samples = create_sample(video_transformed, audio_transformed,
                                    text_tensor, video_label)

        return samples

    def inference(self, samples):
        """ Predict the class (or classes) of the received text using the serialized transformers checkpoint.
        """
        if torch.cuda.is_available():
            with torch.cuda.device(samples.get_device()):
                output = self.model(samples)
        else:
            output = self.model(samples)

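        # The model emits multi-label logits; sigmoid + rounding yields binary
        # per-class decisions, and the nonzero indices map back to class names.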
        sigmoid_scores = torch.sigmoid(output["scores"])
        binary_scores = torch.round(sigmoid_scores)
        score = binary_scores[0]
        score = score.nonzero()

        predictions = []
        for item in score:
            predictions.append(self.idx_to_class[item.item()])
        print("************** predictions *********", predictions)
        return predictions

    def postprocess(self, inference_output):
        # TODO: Add any needed post-processing of the model predictions here
        return [inference_output]