Exemplo n.º 1
0
def get_from_to_our_keys(model_name: str) -> Dict[str, str]:
    """
    Map each state-dict key of the original RegNet to the matching key in ours.

    Small-width instances of both networks are built, one shared random input is
    pushed through each while a ``Tracker`` records their modules, and the
    flattened ``state_dict`` keys of the traced modules are paired up by position.

    Args:
        model_name: name of the checkpoint being converted. If it contains
            ``"in1k"``, our model is built as ``RegNetForImageClassification`` and
            the two classification-head keys are added to the mapping by hand.

    Returns:
        Dict of ``{original_key: our_key}``.
    """
    has_head = "in1k" in model_name

    # our architecture with tiny hidden sizes (only the key names matter here)
    tiny_config = RegNetConfig(
        depths=[2, 7, 17, 1], hidden_sizes=[8, 8, 8, 8], groups_width=8
    )
    our_model = (RegNetForImageClassification if has_head else RegNetModel)(tiny_config)

    # the original architecture, built from these fixed params
    vissl_params = FakeRegNetParams(
        depth=27, group_width=1010, w_0=1744, w_a=620.83, w_m=2.52
    )
    from_model = FakeRegNetVisslWrapper(RegNet(vissl_params))

    with torch.no_grad():
        from_model = from_model.eval()
        our_model = our_model.eval()

        dummy_input = torch.randn((1, 3, 32, 32))
        # trace both networks on the same input; `parametrized` is consumed
        # below as a {name: module} mapping in the order the Tracker produced it
        dest_tracker = Tracker(our_model)
        dest_modules = dest_tracker(dummy_input).parametrized
        pprint(dest_tracker.name2module)
        src_modules = Tracker(from_model)(dummy_input).parametrized

    def _flatten_state(named_modules):
        # {module_name: module} -> {"module_name.state_key": tensor}, order preserved
        return OrderedDict(
            (f"{mod_name}.{entry}", value)
            for mod_name, module in named_modules.items()
            for entry, value in module.state_dict().items()
        )

    src_state = _flatten_state(src_modules)
    dest_state = _flatten_state(dest_modules)

    from_to_ours_keys = {}
    # Pair keys by position. zip stops at the shorter sequence, so surplus keys
    # on either side are silently left out.
    # NOTE(review): presumably this is how our extra classifier keys are skipped
    # when has_head is true — confirm the two traced orders really line up.
    for src_key, dest_key in zip(src_state, dest_state):
        from_to_ours_keys[src_key] = dest_key
        logger.info(f"{src_key} -> {dest_key}")

    if has_head:
        # an "in1k" name means a fine-tuned checkpoint with a classification head;
        # map its weight and bias explicitly
        from_to_ours_keys["0.clf.0.weight"] = "classifier.1.weight"
        from_to_ours_keys["0.clf.0.bias"] = "classifier.1.bias"

    return from_to_ours_keys
 def create_and_check_model(self, config, pixel_values, labels):
     """Run a ``RegNetModel`` built from ``config`` and check its output shape.

     The last hidden state is expected to be
     ``(batch_size, hidden_sizes[-1], image_size // 32, image_size // 32)``,
     i.e. the spatial size reduced by a factor of 32 in both dimensions.
     ``labels`` is part of the shared signature and is not used here.
     """
     net = RegNetModel(config=config)
     net.to(torch_device)
     net.eval()
     outputs = net(pixel_values)

     reduced = self.image_size // 32
     expected_shape = (self.batch_size, self.hidden_sizes[-1], reduced, reduced)
     self.parent.assertEqual(outputs.last_hidden_state.shape, expected_shape)
 def test_model_from_pretrained(self):
     """Load the first archived RegNet checkpoint and check a model comes back."""
     first_checkpoints = REGNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
     for checkpoint in first_checkpoints:
         loaded = RegNetModel.from_pretrained(checkpoint)
         self.assertIsNotNone(loaded)