def _load_student_model(self, checkpoint: Any) -> _IncompatibleKeys:  # pyre-ignore
    """Load the ``"model"`` weights of ``checkpoint`` into ``self.model.modelStudent``.

    The ``"model"`` entry is popped from ``checkpoint``, so the caller's dict
    is mutated. Checkpoint entries whose shape differs from the student's
    matching parameter are dropped before loading and reported back. All
    other mismatches are reported by a non-strict ``load_state_dict``.

    Args:
        checkpoint: mapping that holds the state dict under the ``"model"`` key.

    Returns:
        _IncompatibleKeys holding the missing and unexpected keys from
        ``load_state_dict``, plus ``(key, checkpoint_shape, model_shape)``
        triples for the entries skipped because of a shape mismatch.
    """
    weights = checkpoint.pop("model")
    self._convert_ndarray_to_tensor(weights)
    # Weights saved from a DataParallel / DistributedDataParallel wrapper have
    # every key prefixed with "module."; strip it so names match the bare model.
    _strip_prefix_if_present(weights, "module.")

    student = self.model.modelStudent
    # Work around https://github.com/pytorch/pytorch/issues/24139: drop
    # entries whose shape disagrees with the student so the non-strict load
    # below does not trip over them. They are reported via incorrect_shapes.
    reference = student.state_dict()
    skipped = []
    for name in list(weights):
        if name not in reference:
            continue
        expected_shape = tuple(reference[name].shape)
        found_shape = tuple(weights[name].shape)
        if found_shape == expected_shape:
            continue
        skipped.append((name, found_shape, expected_shape))
        del weights[name]

    # pyre-ignore
    result = student.load_state_dict(weights, strict=False)
    return _IncompatibleKeys(
        missing_keys=result.missing_keys,
        unexpected_keys=result.unexpected_keys,
        incorrect_shapes=skipped,
    )
def _load_model_slimmable(self, checkpoint, prefix="T_backbone.bottom_up."):
    """Load a bare backbone checkpoint into a prefixed submodule of ``self.model``.

    The checkpoint's ``"model"`` entry is a state dict whose keys are relative
    to the backbone. Each key is re-rooted under ``prefix`` so it lines up with
    the full ``self.model.state_dict()``, and the result is then loaded
    non-strictly.

    Args:
        checkpoint: mapping that holds the state dict under the ``"model"`` key.
            That entry is popped, so the caller's dict is mutated.
        prefix: string prepended to every checkpoint key. It defaults to
            ``"T_backbone.bottom_up."``, the value that was previously
            hard-coded.

    Returns:
        _IncompatibleKeys holding the missing and unexpected keys from
        ``self.model.load_state_dict``, plus ``(key, checkpoint_shape,
        model_shape)`` triples for entries skipped because of a shape
        mismatch. Because the whole ``self.model`` is loaded, ``missing_keys``
        also lists every parameter that lives outside ``prefix``.
    """
    checkpoint_state_dict = checkpoint.pop("model")
    # Normalize value types the same way _load_student_model does, so
    # numpy-valued checkpoints can be copied by load_state_dict.
    self._convert_ndarray_to_tensor(checkpoint_state_dict)
    # Weights saved from a DataParallel / DistributedDataParallel wrapper have
    # every key prefixed with "module."; strip it before re-rooting.
    _strip_prefix_if_present(checkpoint_state_dict, "module.")

    # Re-root every backbone-relative key under the target submodule.
    new_checkpoint_state_dict = {
        f"{prefix}{k}": v for k, v in checkpoint_state_dict.items()
    }

    # Work around https://github.com/pytorch/pytorch/issues/24139: drop
    # entries whose shape disagrees with the model so the non-strict load
    # below does not trip over them. They are reported via incorrect_shapes.
    model_state_dict = self.model.state_dict()
    incorrect_shapes = []
    for k in list(new_checkpoint_state_dict):
        if k not in model_state_dict:
            continue
        shape_model = tuple(model_state_dict[k].shape)
        shape_checkpoint = tuple(new_checkpoint_state_dict[k].shape)
        if shape_model != shape_checkpoint:
            incorrect_shapes.append((k, shape_checkpoint, shape_model))
            new_checkpoint_state_dict.pop(k)

    # pyre-ignore
    incompatible = self.model.load_state_dict(new_checkpoint_state_dict, strict=False)
    return _IncompatibleKeys(
        missing_keys=incompatible.missing_keys,
        unexpected_keys=incompatible.unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )