def test_mmft_from_params(self):
    modalities_config = [
        MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=256,
            position_dim=1,
            segment_id=0,
            encoder=IdentityEncoder.Config(),
        ),
        MMFTransformerModalityConfig(
            type="text",
            key="text",
            embedding_dim=768,
            position_dim=512,
            segment_id=1,
            encoder=IdentityEncoder.Config(),
        ),
    ]
    mmft = MMFTransformer.from_params(modalities=modalities_config, num_labels=2)
    mmft.build()

    config = OmegaConf.structured(
        MMFTransformer.Config(modalities=modalities_config, num_labels=2)
    )
    self.assertIsNotNone(mmft)
    self.assertEqual(mmft.config, config)
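# The tests below reference self._image_modality_config and
# self._text_modality_config, fixtures that are not shown in this excerpt.
# A minimal sketch of a setUp that would provide them; the identity encoders
# and the exact dimensions are assumptions inferred from the shapes and
# segment ids asserted in the tests that follow.
def setUp(self):
    setup_imports()
    self._image_modality_config = MMFTransformerModalityConfig(
        type="image",
        key="image",
        embedding_dim=256,
        position_dim=1,
        segment_id=0,
        encoder=ImageEncoderFactory.Config(type=ImageEncoderTypes.identity),
    )
    self._text_modality_config = MMFTransformerModalityConfig(
        type="text",
        key="text",
        embedding_dim=756,
        position_dim=128,
        segment_id=1,
        encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
    )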
def test_modality_key_preprocessing(self):
    self._text_modality_config.key = "body"
    second_text_modality_config = MMFTransformerModalityConfig(
        type="text",
        key="ocr",
        embedding_dim=756,
        position_dim=128,
        segment_id=2,
        encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
    )
    modalities_config = [
        self._image_modality_config,
        self._text_modality_config,
        second_text_modality_config,
    ]
    config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
    mmft = build_model(config)

    sample_list = SampleList()
    sample_list.image = torch.rand(2, 256)
    sample_list.body = torch.randint(0, 512, (2, 128))
    sample_list.ocr = torch.randint(0, 512, (2, 128))
    sample_list.lm_label_ids = torch.randint(-1, 30522, (2, 128))

    lm_labels_sum = sample_list.lm_label_ids.sum().item() * 2

    transformer_input = mmft.preprocess_sample(sample_list)
    self._compare_processed_for_multimodality(transformer_input, lm_labels_sum)
def test_mmft_from_build_model(self):
    modalities_config = [
        MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=256,
            position_dim=1,
            segment_id=0,
            encoder=ImageEncoderFactory.Config(
                type=ImageEncoderTypes.resnet152,
                params=ResNet152ImageEncoder.Config(pretrained=False),
            ),
        ),
        MMFTransformerModalityConfig(
            type="text",
            key="text",
            embedding_dim=756,
            position_dim=512,
            segment_id=1,
            encoder=TextEncoderFactory.Config(type=TextEncoderTypes.identity),
        ),
    ]
    config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
    mmft = build_model(config)
    self.assertIsNotNone(mmft)
def test_tie_mlm_head_weight_to_encoder(self):
    self._text_modality_config = MMFTransformerModalityConfig(
        type="text",
        key="text",
        embedding_dim=768,
        position_dim=128,
        segment_id=0,
        encoder=TextEncoderFactory.Config(type=TextEncoderTypes.transformer),
    )
    heads = [MLM.Config()]
    modalities_config = [self._image_modality_config, self._text_modality_config]
    config = MMFTransformer.Config(
        heads=heads,
        modalities=modalities_config,
        num_labels=2,
        tie_weight_to_encoder="text",
    )
    mmft = build_model(config)
    test_utils.compare_tensors(
        mmft.heads[0].cls.predictions.decoder.weight,
        mmft.encoders["text"].embeddings.word_embeddings.weight,
    )
def test_preprocessing_with_resnet_encoder(self):
    self._image_modality_config = MMFTransformerModalityConfig(
        type="image",
        key="image",
        embedding_dim=2048,
        position_dim=1,
        segment_id=0,
        encoder=ImageEncoderFactory.Config(
            type=ImageEncoderTypes.resnet152,
            params=ResNet152ImageEncoder.Config(pretrained=False),
        ),
    )
    modalities_config = [self._image_modality_config, self._text_modality_config]
    config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
    mmft = build_model(config)

    sample_list = SampleList()
    sample_list.image = torch.rand(2, 3, 224, 224)
    sample_list.text = torch.randint(0, 512, (2, 128))

    transformer_input = mmft.preprocess_sample(sample_list)

    input_ids = transformer_input["input_ids"]
    self.assertEqual(input_ids["image"].dim(), 3)
    self.assertEqual(list(input_ids["image"].size()), [2, 1, 2048])
    self.assertEqual(input_ids["text"].dim(), 2)
    self.assertEqual(list(input_ids["text"].size()), [2, 128])

    position_ids = transformer_input["position_ids"]
    test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
    test_utils.compare_tensors(
        position_ids["text"],
        torch.arange(0, 128).unsqueeze(0).expand((2, 128)),
    )

    masks = transformer_input["masks"]
    test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
    test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

    segment_ids = transformer_input["segment_ids"]
    test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
    test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
def test_one_dim_feature_preprocessing(self):
    modalities_config = [self._image_modality_config, self._text_modality_config]
    config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
    mmft = build_model(config)

    sample_list = SampleList()
    sample_list.image = torch.rand(2, 256)
    sample_list.text = torch.randint(0, 512, (2, 128))

    transformer_input = mmft.preprocess_sample(sample_list)

    input_ids = transformer_input["input_ids"]
    self.assertEqual(input_ids["image"].dim(), 3)
    self.assertEqual(list(input_ids["image"].size()), [2, 1, 256])
    self.assertEqual(input_ids["text"].dim(), 2)
    self.assertEqual(list(input_ids["text"].size()), [2, 128])

    position_ids = transformer_input["position_ids"]
    test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
    test_utils.compare_tensors(
        position_ids["text"],
        torch.arange(0, 128).unsqueeze(0).expand((2, 128)),
    )

    masks = mmft._infer_masks(sample_list, input_ids)
    test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
    test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

    segment_ids = transformer_input["segment_ids"]
    test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
    test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())

    # 129 = 1 image token + 128 text tokens; -1 marks positions without MLM labels.
    mlm_labels = transformer_input["mlm_labels"]
    test_utils.compare_tensors(
        mlm_labels["combined_labels"],
        torch.full((2, 129), dtype=torch.long, fill_value=-1),
    )
def test_custom_feature_and_mask_preprocessing(self):
    extra_modality = MMFTransformerModalityConfig(
        type="my_random_feature",
        key="my_random_feature",
        embedding_dim=128,
        position_dim=4,
        segment_id=3,
        encoder=EncoderFactory.Config(type="identity"),
    )
    modalities_config = [
        self._image_modality_config,
        self._text_modality_config,
        extra_modality,
    ]
    config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
    mmft = build_model(config)

    sample_list = SampleList()
    sample_list.image = torch.rand(2, 256)
    sample_list.text = torch.randint(0, 512, (2, 128))
    sample_list.text_mask = torch.ones(2, 128)
    sample_list.text_mask[:, 70:] = 0
    sample_list.my_random_feature = torch.rand(2, 4, 128)
    sample_list.my_random_feature_mask = torch.ones(2, 4)
    sample_list.my_random_feature_mask[:, 3:] = 0

    transformer_input = mmft.preprocess_sample(sample_list)

    input_ids = transformer_input["input_ids"]
    self.assertEqual(input_ids["image"].dim(), 3)
    self.assertEqual(list(input_ids["image"].size()), [2, 1, 256])
    self.assertEqual(input_ids["text"].dim(), 2)
    self.assertEqual(list(input_ids["text"].size()), [2, 128])
    self.assertEqual(input_ids["my_random_feature"].dim(), 3)
    self.assertEqual(list(input_ids["my_random_feature"].size()), [2, 4, 128])

    position_ids = transformer_input["position_ids"]
    test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
    test_utils.compare_tensors(
        position_ids["text"],
        torch.arange(0, 128).unsqueeze(0).expand((2, 128)),
    )
    test_utils.compare_tensors(
        position_ids["my_random_feature"],
        torch.arange(0, 4).unsqueeze(0).expand((2, 4)),
    )

    # The custom masks keep 70 of 128 text positions and 3 of 4 feature
    # positions per example: 2 * 70 = 140 and 2 * 3 = 6 remain set.
    masks = transformer_input["masks"]
    test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
    self.assertEqual(masks["text"].sum().item(), 140)
    self.assertEqual(masks["my_random_feature"].sum().item(), 6)

    segment_ids = transformer_input["segment_ids"]
    test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
    test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
    test_utils.compare_tensors(
        segment_ids["my_random_feature"],
        torch.full((2, 4), dtype=torch.long, fill_value=3),
    )
def test_mmft_pretrained(self):
    mmft = MMFTransformer.from_params(num_labels=2)
    self.assertIsNotNone(mmft)
def test_preprocessing_with_mvit_encoder(self):
    encoder_config = OmegaConf.create(
        {
            "name": "pytorchvideo",
            "model_name": "mvit_base_32x3",
            "random_init": True,
            "drop_last_n_layers": 0,
            "pooler_name": "cls",
            "spatial_size": 224,
            "temporal_size": 8,
            "head": None,
            "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            "pool_kv_stride_adaptive": [1, 8, 8],
            "pool_kvq_kernel": [3, 3, 3],
        }
    )
    self._image_modality_config = MMFTransformerModalityConfig(
        type="image",
        key="image",
        embedding_dim=768,
        position_dim=1,
        segment_id=0,
        encoder=encoder_config,
    )
    modalities_config = [self._image_modality_config, self._text_modality_config]
    config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
    mmft = build_model(config)

    sample_list = SampleList()
    sample_list.image = torch.rand((2, 3, 8, 224, 224))
    sample_list.text = torch.randint(0, 512, (2, 128))

    transformer_input = mmft.preprocess_sample(sample_list)

    input_ids = transformer_input["input_ids"]
    self.assertEqual(input_ids["image"].dim(), 3)
    self.assertEqual(list(input_ids["image"].size()), [2, 1, 768])
    self.assertEqual(input_ids["text"].dim(), 2)
    self.assertEqual(list(input_ids["text"].size()), [2, 128])

    position_ids = transformer_input["position_ids"]
    test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
    test_utils.compare_tensors(
        position_ids["text"],
        torch.arange(0, 128).unsqueeze(0).expand((2, 128)),
    )

    masks = transformer_input["masks"]
    test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
    test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

    segment_ids = transformer_input["segment_ids"]
    test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
    test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
class MMFHandler(BaseHandler):
    """
    Transformers handler class for the MMFTransformerWithVideoAudio model.
    """

    def __init__(self):
        super(MMFHandler, self).__init__()
        self.initialized = False

    def initialize(self, ctx):
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest['model']['serializedFile']
        model_pt_path = os.path.join(model_dir, serialized_file)
        self.map_location = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(
            self.map_location + ":" + str(properties.get("gpu_id"))
            if torch.cuda.is_available()
            else self.map_location
        )

        # Read the CSV file that includes all labels in the dataset to build
        # the class/index mapping and match the model output against the
        # dataset's label count.
        df = pd.read_csv('./charades_action_lables.csv')
        label_set = set()
        df['action_labels'] = df['action_labels'].str.replace('"', '')
        labels_initial = df['action_labels'].tolist()
        labels = []
        for sublist in labels_initial:
            new_sublist = ast.literal_eval(sublist)
            labels.append(new_sublist)
            for item in new_sublist:
                label_set.add(item)

        classes = sorted(list(label_set))
        self.class_to_idx = {classes[i]: i for i in range(len(classes))}
        self.classes = classes
        self.labels = labels
        self.idx_to_class = classes

        config = OmegaConf.load('config.yaml')
        print("Config keys:", config.keys())

        setup_very_basic_config()
        setup_imports()
        self.model = MMFTransformer(config.model_config.mmf_transformer)
        self.model.build()
        self.model.init_losses()
        self.processor = build_processors(
            config.dataset_config["charades"].processors)

        # Load the serialized checkpoint from the extracted .mar archive.
        state_dict = torch.load(model_pt_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.to(self.device)
        self.model.eval()
        self.initialized = True
        print("Files in the temp directory where the .mar file was extracted:",
              os.listdir(model_dir))

    def preprocess(self, requests):
        """Preprocessing, based on the processors defined for the MMF model."""

        def create_sample(video_transformed, audio_transformed, text_tensor,
                          video_label):
            label = [self.class_to_idx[l] for l in video_label]
            one_hot_label = torch.zeros(len(self.class_to_idx))
            one_hot_label[label] = 1

            current_sample = Sample()
            current_sample.video = video_transformed
            current_sample.audio = audio_transformed
            current_sample.update(text_tensor)
            current_sample.targets = one_hot_label
            current_sample.dataset_type = 'test'
            current_sample.dataset_name = 'charades'
            return SampleList([current_sample]).to(self.device)

        # Only the sample built from the last request is returned, so this
        # handler effectively assumes a batch size of one.
        for idx, data in enumerate(requests):
            raw_script = data.get('script')
            script = raw_script.decode('utf-8')
            raw_label = data.get('labels')
            video_label = raw_label.decode('utf-8')
            video_label = [video_label]
            video = io.BytesIO(data['data'])
            video_tensor, audio_tensor, info = torchvision.io.read_video(video)
            text_tensor = self.processor["text_processor"]({"text": script})
            video_transformed = self.processor["video_test_processor"](video_tensor)
            audio_transformed = self.processor["audio_processor"](audio_tensor)
            samples = create_sample(video_transformed, audio_transformed,
                                    text_tensor, video_label)

        return samples

    def inference(self, samples):
        """Predict the class (or classes) of the received input using the
        serialized transformers checkpoint.
        """
        if torch.cuda.is_available():
            with torch.cuda.device(samples.get_device()):
                output = self.model(samples)
        else:
            output = self.model(samples)

        # Multi-label prediction: threshold the sigmoid scores at 0.5 and
        # map the active indices back to class names.
        sigmoid_scores = torch.sigmoid(output["scores"])
        binary_scores = torch.round(sigmoid_scores)
        score = binary_scores[0]
        score = score.nonzero()
        predictions = []
        for item in score:
            predictions.append(self.idx_to_class[item.item()])

        print("Predictions:", predictions)
        return predictions

    def postprocess(self, inference_output):
        # TODO: Add any needed post-processing of the model predictions here
        return [inference_output]