def get_config(self):
    return MaskFormerConfig.from_backbone_and_decoder_configs(
        backbone_config=SwinConfig(depths=[1, 1, 1, 1]),
        decoder_config=DetrConfig(
            decoder_ffn_dim=128,
            num_queries=self.num_queries,
            decoder_attention_heads=2,
            d_model=self.mask_feature_size,
        ),
        mask_feature_size=self.mask_feature_size,
        fpn_feature_size=self.mask_feature_size,
        num_channels=self.num_channels,
        num_labels=self.num_labels,
    )
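
A minimal standalone sketch (not part of the original tester) of building the same kind of tiny MaskFormer config and instantiating a randomly initialised model from it; the literal sizes below stand in for the tester's attributes:

from transformers import DetrConfig, MaskFormerConfig, MaskFormerModel, SwinConfig

config = MaskFormerConfig.from_backbone_and_decoder_configs(
    backbone_config=SwinConfig(depths=[1, 1, 1, 1]),
    decoder_config=DetrConfig(
        decoder_ffn_dim=128,
        num_queries=10,           # illustrative value
        decoder_attention_heads=2,
        d_model=32,               # mirrors the tester, which ties d_model to mask_feature_size
    ),
    mask_feature_size=32,
    fpn_feature_size=32,
)
model = MaskFormerModel(config)  # randomly initialised, suitable for shape tests
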
def get_config(self):
    return DetrConfig(
        d_model=self.hidden_size,
        encoder_layers=self.num_hidden_layers,
        decoder_layers=self.num_hidden_layers,
        encoder_attention_heads=self.num_attention_heads,
        decoder_attention_heads=self.num_attention_heads,
        encoder_ffn_dim=self.intermediate_size,
        decoder_ffn_dim=self.intermediate_size,
        dropout=self.hidden_dropout_prob,
        attention_dropout=self.attention_probs_dropout_prob,
        num_queries=self.num_queries,
        num_labels=self.num_labels,
    )
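
A similar sketch (assumed, not from the original): such a tiny DetrConfig can drive a randomly initialised DetrModel. Note that DetrForObjectDetection adds one extra "no object" class on top of num_labels, and that instantiating the model builds a timm ResNet-50 backbone by default:

from transformers import DetrConfig, DetrModel

config = DetrConfig(
    d_model=32,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=64,
    decoder_ffn_dim=64,
    num_queries=12,
    num_labels=3,
)
model = DetrModel(config)  # builds the default ResNet-50 backbone via timm
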
def prepare_config_and_inputs(self):
    pixel_values = floats_tensor(
        [self.batch_size, self.num_channels, self.min_size, self.max_size])

    pixel_mask = torch.ones(
        [self.batch_size, self.min_size, self.max_size], device=torch_device)

    labels = None
    if self.use_labels:
        # labels is a list of Dict (each Dict being the labels for a given example in the batch)
        labels = []
        for i in range(self.batch_size):
            target = {}
            target["class_labels"] = torch.randint(
                high=self.num_labels, size=(self.n_targets,), device=torch_device)
            target["boxes"] = torch.rand(self.n_targets, 4, device=torch_device)
            target["masks"] = torch.rand(
                self.n_targets, self.min_size, self.max_size, device=torch_device)
            labels.append(target)

    config = DetrConfig(
        d_model=self.hidden_size,
        encoder_layers=self.num_hidden_layers,
        decoder_layers=self.num_hidden_layers,
        encoder_attention_heads=self.num_attention_heads,
        decoder_attention_heads=self.num_attention_heads,
        encoder_ffn_dim=self.intermediate_size,
        decoder_ffn_dim=self.intermediate_size,
        dropout=self.hidden_dropout_prob,
        attention_dropout=self.attention_probs_dropout_prob,
        num_queries=self.num_queries,
        num_labels=self.num_labels,
    )
    return config, pixel_values, pixel_mask, labels
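
A hedged sketch (not in the original) of how these inputs are typically consumed: feed them to DetrForObjectDetection, where passing labels makes the model return a loss. Here `tester` stands for an instance of the tester class above:

import torch
from transformers import DetrForObjectDetection

config, pixel_values, pixel_mask, labels = tester.prepare_config_and_inputs()
model = DetrForObjectDetection(config).to(torch_device).eval()
with torch.no_grad():
    outputs = model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)
# With labels supplied, `loss` combines the Hungarian-matched classification,
# L1 box and generalized-IoU terms; logits have shape (batch, num_queries, num_labels + 1).
print(outputs.loss, outputs.logits.shape, outputs.pred_boxes.shape)
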
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
    """
    Copy/paste/tweak model's weights to our DETR structure.
    """

    # load default config
    config = DetrConfig()
    # set backbone and dilation attributes
    if "resnet101" in model_name:
        config.backbone = "resnet101"
    if "dc5" in model_name:
        config.dilation = True
    is_panoptic = "panoptic" in model_name
    if is_panoptic:
        config.num_labels = 250
    else:
        config.num_labels = 91
        config.id2label = id2label  # module-level COCO id->name mapping, loaded elsewhere in the script
        config.label2id = {v: k for k, v in id2label.items()}

    # load feature extractor
    format = "coco_panoptic" if is_panoptic else "coco_detection"
    feature_extractor = DetrFeatureExtractor(format=format)

    # prepare image
    img = prepare_img()
    encoding = feature_extractor(images=img, return_tensors="pt")
    pixel_values = encoding["pixel_values"]

    logger.info(f"Converting model {model_name}...")

    # load original model from torch hub
    detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
    state_dict = detr.state_dict()
    # rename keys
    for src, dest in rename_keys:
        if is_panoptic:
            src = "detr." + src
        rename_key(state_dict, src, dest)
    state_dict = rename_backbone_keys(state_dict)
    # query, key and value matrices need special treatment
    read_in_q_k_v(state_dict, is_panoptic=is_panoptic)
    # important: we need to prepend a prefix to each of the base model keys as the head models use different attributes for them
    prefix = "detr.model." if is_panoptic else "model."
    for key in state_dict.copy().keys():
        if is_panoptic:
            if (
                key.startswith("detr")
                and not key.startswith("class_labels_classifier")
                and not key.startswith("bbox_predictor")
            ):
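                # key starts with "detr."; key[4:] keeps the leading dot,
                # so "detr.xyz" becomes "detr.model.xyz"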
                val = state_dict.pop(key)
                state_dict["detr.model" + key[4:]] = val
            elif "class_labels_classifier" in key or "bbox_predictor" in key:
                val = state_dict.pop(key)
                state_dict["detr." + key] = val
            elif key.startswith("bbox_attention") or key.startswith("mask_head"):
                continue
            else:
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
        else:
            if not key.startswith("class_labels_classifier") and not key.startswith("bbox_predictor"):
                val = state_dict.pop(key)
                state_dict[prefix + key] = val
    # finally, create HuggingFace model and load state dict
    model = DetrForSegmentation(config) if is_panoptic else DetrForObjectDetection(config)
    model.load_state_dict(state_dict)
    model.eval()
    # verify our conversion
    original_outputs = detr(pixel_values)
    outputs = model(pixel_values)
    assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-4)
    assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-4)
    if is_panoptic:
        assert torch.allclose(outputs.pred_masks, original_outputs["pred_masks"], atol=1e-4)

    # Save model and feature extractor
    logger.info(f"Saving PyTorch model and feature extractor to {pytorch_dump_folder_path}...")
    Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
    model.save_pretrained(pytorch_dump_folder_path)
    feature_extractor.save_pretrained(pytorch_dump_folder_path)
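
Such conversion functions are typically driven from the command line; a hedged sketch of the entry point (the flag names are assumptions, not necessarily the original script's):

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="detr_resnet50", type=str,
                        help="Name of the DETR model on torch hub to convert.")
    parser.add_argument("--pytorch_dump_folder_path", required=True, type=str,
                        help="Folder in which to save the converted model.")
    args = parser.parse_args()
    convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
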
Example #5

import pytorch_lightning as pl
from transformers import DetrConfig, DetrFeatureExtractor
from lib.DETR import DetrForObjectDetection, CocoDetection
import torch

feature_extractor = DetrFeatureExtractor.from_pretrained(
    "facebook/detr-resnet-50")
DATA_BASE = 'data/custom'
val_dataset = CocoDetection(img_folder=f'{DATA_BASE}/val',
                            feature_extractor=feature_extractor,
                            train=False)
cats = val_dataset.coco.cats
id2label = {k: v['name'] for k, v in cats.items()}
config = DetrConfig.from_pretrained("facebook/detr-resnet-50",
                                    num_labels=len(id2label))
model = DetrForObjectDetection(config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_ = model.to(device)


def load(model):
    # `args` is expected to be an argparse namespace defined elsewhere in the
    # script; `args.run_name` names the checkpoint file to restore.
    state_dict = torch.load(f'saved_models/{args.run_name}', map_location=device)
    model.load_state_dict(state_dict)


load(model)
print('Model loaded')
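
A hedged inference sketch (not part of the original snippet): run the restored model on one validation image and decode the predictions with the feature extractor's post_process, which rescales boxes to absolute (x_min, y_min, x_max, y_max) coordinates:

from PIL import Image

# pick the first image registered in the validation COCO annotations
image_id = val_dataset.coco.getImgIds()[0]
file_name = val_dataset.coco.loadImgs(image_id)[0]["file_name"]
image = Image.open(f'{DATA_BASE}/val/{file_name}')

model.eval()
inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(pixel_values=inputs["pixel_values"].to(device))

# (height, width) of the original image, as expected by post_process
target_sizes = torch.tensor([image.size[::-1]], device=device)
results = feature_extractor.post_process(outputs, target_sizes)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
    if score > 0.9:
        # assumes the dataset's category ids are contiguous from 0
        print(id2label[label.item()], [round(c, 1) for c in box.tolist()])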