Пример #1
0
    def _load_student_model(self, checkpoint: Any) -> _IncompatibleKeys:  # pyre-ignore
        """Load ``checkpoint["model"]`` into ``self.model.modelStudent``.

        Parameters whose shape differs between the checkpoint and the student
        network are removed before loading and reported separately instead of
        being loaded. This works around
        https://github.com/pytorch/pytorch/issues/24139 (shape mismatches are
        not tolerated by ``load_state_dict`` even with ``strict=False``).

        Args:
            checkpoint: mapping holding the serialized weights under the
                ``"model"`` key. That entry is popped, so the caller's mapping
                is mutated.

        Returns:
            _IncompatibleKeys with the missing / unexpected keys reported by
            ``load_state_dict(strict=False)`` and a list of
            ``(key, checkpoint_shape, model_shape)`` tuples for every
            parameter that was skipped because of a shape mismatch.
        """
        weights = checkpoint.pop("model")
        self._convert_ndarray_to_tensor(weights)

        # Checkpoints saved from a DataParallel / DistributedDataParallel
        # wrapper carry a "module." prefix on their keys; drop it so the
        # names line up with the bare student module.
        _strip_prefix_if_present(weights, "module.")

        student = self.model.modelStudent
        expected = student.state_dict()

        # Filter shape mismatches ourselves; iterate over a snapshot of the
        # keys because entries are deleted from `weights` inside the loop.
        mismatched = []
        for name in list(weights):
            if name not in expected:
                continue
            model_shape = tuple(expected[name].shape)
            ckpt_shape = tuple(weights[name].shape)
            if model_shape == ckpt_shape:
                continue
            mismatched.append((name, ckpt_shape, model_shape))
            del weights[name]

        # pyre-ignore
        result = student.load_state_dict(weights, strict=False)
        return _IncompatibleKeys(
            missing_keys=result.missing_keys,
            unexpected_keys=result.unexpected_keys,
            incorrect_shapes=mismatched,
        )
Пример #2
0
    def _load_model_slimmable(self, checkpoint, prefix="T_backbone.bottom_up."):
        """Load a backbone-only checkpoint into a sub-module of ``self.model``.

        Every key of ``checkpoint["model"]`` is re-rooted under ``prefix`` so it
        lines up with ``self.model.state_dict()``. NOTE(review): this presumes
        the checkpoint keys are relative to the backbone itself, as implied by
        the prefixing — confirm against the checkpoints actually loaded here.
        Entries whose shape disagrees with the model are skipped and reported
        instead of loaded.

        Args:
            checkpoint: mapping holding the serialized weights under the
                ``"model"`` key. That entry is popped, so the caller's mapping
                is mutated.
            prefix: string prepended to every checkpoint key. Defaults to
                ``"T_backbone.bottom_up."``, the value previously hard-coded.

        Returns:
            _IncompatibleKeys with the missing / unexpected keys reported by
            ``load_state_dict(strict=False)`` and a list of
            ``(key, checkpoint_shape, model_shape)`` tuples for every
            parameter that was skipped because of a shape mismatch.
        """
        checkpoint_state_dict = checkpoint.pop("model")
        # Apply the same normalisation as the student loader. By its name,
        # this converts numpy arrays (e.g. from .pkl checkpoints) to tensors,
        # which load_state_dict requires.
        self._convert_ndarray_to_tensor(checkpoint_state_dict)
        # Remove the "module." prefix added when the model was wrapped in
        # DataParallel / DistributedDataParallel during serialization.
        _strip_prefix_if_present(checkpoint_state_dict, "module.")

        # Re-root every backbone key under the target sub-module's prefix.
        new_checkpoint_state_dict = {
            "{}{}".format(prefix, k): v for k, v in checkpoint_state_dict.items()
        }

        # Work around https://github.com/pytorch/pytorch/issues/24139:
        # load_state_dict rejects shape mismatches even with strict=False, so
        # drop those entries here and report them. Iterate over a snapshot of
        # the keys because entries are popped inside the loop.
        model_state_dict = self.model.state_dict()
        incorrect_shapes = []
        for k in list(new_checkpoint_state_dict.keys()):
            if k in model_state_dict:
                shape_model = tuple(model_state_dict[k].shape)
                shape_checkpoint = tuple(new_checkpoint_state_dict[k].shape)
                if shape_model != shape_checkpoint:
                    incorrect_shapes.append((k, shape_checkpoint, shape_model))
                    new_checkpoint_state_dict.pop(k)
        # pyre-ignore
        incompatible = self.model.load_state_dict(
            new_checkpoint_state_dict, strict=False
        )
        return _IncompatibleKeys(
            missing_keys=incompatible.missing_keys,
            unexpected_keys=incompatible.unexpected_keys,
            incorrect_shapes=incorrect_shapes,
        )