def get_config(self):
    return MaskFormerConfig.from_backbone_and_decoder_configs(
        backbone_config=SwinConfig(
            depths=[1, 1, 1, 1],
        ),
        decoder_config=DetrConfig(
            decoder_ffn_dim=128,
            num_queries=self.num_queries,
            decoder_attention_heads=2,
            d_model=self.mask_feature_size,
        ),
        mask_feature_size=self.mask_feature_size,
        fpn_feature_size=self.mask_feature_size,
        num_channels=self.num_channels,
        num_labels=self.num_labels,
    )
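# The helper above composes a deliberately tiny MaskFormer config so tests run fast.
# Below is a minimal standalone sketch of the same composition; the literal values
# are assumptions standing in for the tester attributes (self.num_queries,
# self.mask_feature_size, ...), chosen only to keep the model small.
from transformers import DetrConfig, MaskFormerConfig, MaskFormerModel, SwinConfig

tiny_config = MaskFormerConfig.from_backbone_and_decoder_configs(
    backbone_config=SwinConfig(depths=[1, 1, 1, 1]),  # one block per stage instead of the default [2, 2, 6, 2]
    decoder_config=DetrConfig(decoder_ffn_dim=128, num_queries=10, decoder_attention_heads=2, d_model=32),
    mask_feature_size=32,  # the decoder's d_model must match mask_feature_size, as in get_config above
    fpn_feature_size=32,
)
tiny_model = MaskFormerModel(tiny_config)  # small enough to instantiate in a unit test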
def get_config(self):
    return DetrConfig(
        d_model=self.hidden_size,
        encoder_layers=self.num_hidden_layers,
        decoder_layers=self.num_hidden_layers,
        encoder_attention_heads=self.num_attention_heads,
        decoder_attention_heads=self.num_attention_heads,
        encoder_ffn_dim=self.intermediate_size,
        decoder_ffn_dim=self.intermediate_size,
        dropout=self.hidden_dropout_prob,
        attention_dropout=self.attention_probs_dropout_prob,
        num_queries=self.num_queries,
        num_labels=self.num_labels,
    )
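# Hedged usage sketch: a config like the one returned above shrinks the DETR
# transformer enough to build the full model in a test. The concrete numbers are
# illustrative assumptions for the tester attributes. Note that DetrConfig keeps
# its default (timm ResNet-50) backbone, so instantiation still requires timm.
from transformers import DetrConfig, DetrModel

config = DetrConfig(
    d_model=32,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=64,
    decoder_ffn_dim=64,
    num_queries=10,
)
model = DetrModel(config)
# The decoder output (last_hidden_state) has shape (batch_size, num_queries, d_model).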
def prepare_config_and_inputs(self):
    pixel_values = floats_tensor([self.batch_size, self.num_channels, self.min_size, self.max_size])

    pixel_mask = torch.ones([self.batch_size, self.min_size, self.max_size], device=torch_device)

    labels = None
    if self.use_labels:
        # labels is a list of Dict (each Dict being the labels for a given example in the batch)
        labels = []
        for i in range(self.batch_size):
            target = {}
            target["class_labels"] = torch.randint(
                high=self.num_labels, size=(self.n_targets,), device=torch_device
            )
            target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
            target["masks"] = torch.rand(self.n_targets, self.min_size, self.max_size, device=torch_device)
            labels.append(target)

    config = self.get_config()
    return config, pixel_values, pixel_mask, labels
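# A minimal sketch of how a test case might consume prepare_config_and_inputs().
# DetrModelTester and the parent=None argument are assumptions for illustration;
# in the real test suite the tester is constructed with the test case as parent.
from transformers import DetrForObjectDetection
from transformers.testing_utils import torch_device

tester = DetrModelTester(parent=None)  # hypothetical standalone construction
config, pixel_values, pixel_mask, labels = tester.prepare_config_and_inputs()

model = DetrForObjectDetection(config).to(torch_device).eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask)
# logits: (batch_size, num_queries, num_labels + 1)  -- DETR adds a "no object" class
# pred_boxes: (batch_size, num_queries, 4)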
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our DETR structure.
    """

    # load default config
    config = DetrConfig()
    # set backbone and dilation attributes
    if "resnet101" in model_name:
        config.backbone = "resnet101"
    if "dc5" in model_name:
        config.dilation = True
    is_panoptic = "panoptic" in model_name
    if is_panoptic:
        config.num_labels = 250
    else:
        config.num_labels = 91
        config.id2label = id2label
        config.label2id = {v: k for k, v in id2label.items()}

    # load feature extractor
    format = "coco_panoptic" if is_panoptic else "coco_detection"
    feature_extractor = DetrFeatureExtractor(format=format)

    # prepare image
    img = prepare_img()
    encoding = feature_extractor(images=img, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    logger.info(f"Converting model {model_name}...")

    # load original model from torch hub
    detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
    state_dict = detr.state_dict()
    # rename keys
    for src, dest in rename_keys:
        if is_panoptic:
            src = "detr." + src
        rename_key(state_dict, src, dest)
    state_dict = rename_backbone_keys(state_dict)
    # query, key and value matrices need special treatment
    read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
    prefix = "detr.model." if is_panoptic else "model."
    for key in state_dict.copy().keys():
        if is_panoptic:
            if (
                key.startswith("detr")
                and not key.startswith("class_labels_classifier")
                and not key.startswith("bbox_predictor")
            ):
                val = state_dict.pop(key)
                state_dict["detr.model" + key[4:]] = val
            elif "class_labels_classifier" in key or "bbox_predictor" in key:
                val = state_dict.pop(key)
                state_dict["detr." + key] = val
            elif key.startswith("bbox_attention") or key.startswith("mask_head"):
                continue
            else:
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
        else:
            if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
    # finally, create HuggingFace model and load state dict
    model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
    model.load_state_dict(state_dict)
    model.eval()
    # verify our conversion
    original_outputs = detr(pixel_values)
    outputs = model(pixel_values)
    assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
    assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
    if is_panoptic:
        assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)

    # Save model and feature extractor
    logger.info(f"Saving PyTorch model and feature extractor to {pytorch_dump_folder_path}...")
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
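# For completeness, a hedged sketch of the command-line entry point that would
# typically drive this converter. The flag names mirror the function's parameters;
# the exact parser in the original script may differ.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name", default="detr_resnet50", type=str, help="Name of the DETR model on torch hub to convert."
    )
    parser.add_argument(
        "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output the PyTorch model."
    )
    args = parser.parse_args()
    convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)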