def set_state(self, state):
    """
    Restore the experiment from the state returned by `get_state`

    :param state: dictionary with "model", "optimizer", "lr_scheduler",
                  and "amp" states
    """
    def _load(key):
        # Deserialize the bytes stored under ``key`` onto this
        # experiment's device.
        with io.BytesIO(state[key]) as buffer:
            return deserialize_state_dict(buffer, self.device)

    if "model" in state:
        self.model.module.load_state_dict(_load("model"))
    if "optimizer" in state:
        self.optimizer.load_state_dict(_load("optimizer"))
    if "lr_scheduler" in state:
        self.lr_scheduler.load_state_dict(_load("lr_scheduler"))
    # ``amp`` may be None when mixed precision is unavailable; skip then.
    if "amp" in state and amp is not None:
        amp.load_state_dict(_load("amp"))
def repair(in_path):
    """
    Remove the un-serializable ``anneal_func`` entry from a checkpoint's LR
    scheduler state and save the result as ``<in_path>.repaired``.

    The repaired copy is written even when no fix was applied; a message is
    printed in that case.

    :param in_path: path to the pickled checkpoint file
    """
    # NOTE(review): ``pickle.load`` can execute arbitrary code — only run
    # this on trusted checkpoint files.
    with open(in_path, "rb") as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)

    fix_needed = False
    if "lr_scheduler" in checkpoint:
        print("Loading LR scheduler state dict (this might take a few minutes)")
        with io.BytesIO(checkpoint["lr_scheduler"]) as buf:
            lr_sched_state_dict = deserialize_state_dict(buf)
        if "anneal_func" in lr_sched_state_dict:
            # Drop the offending entry and re-serialize the remaining state.
            fix_needed = True
            del lr_sched_state_dict["anneal_func"]
            with io.BytesIO() as buf:
                serialize_state_dict(buf, lr_sched_state_dict)
                checkpoint["lr_scheduler"] = buf.getvalue()

    out_path = f"{in_path}.repaired"
    print(f"Saving {out_path}")
    with open(out_path, "wb") as out_file:
        pickle.dump(checkpoint, out_file)

    if not fix_needed:
        print("This checkpoint does not need repair")
def set_state(self, state):
    """
    Restore the experiment from the state returned by `get_state`.
    Besides the serialized module states, this also restores (or, when
    absent, reconstructs) the experiment's epoch/timestep counters.

    :param state: dictionary with "model", "optimizer", "lr_scheduler",
        and "amp" states; may also carry "current_epoch" and
        "current_timestep"
    """
    if "model" in state:
        with io.BytesIO(state["model"]) as buffer:
            state_dict = deserialize_state_dict(buffer, self.device)
            # Adapt checkpoint keys/entries to the current model before
            # loading.
            state_dict = get_compatible_state_dict(self.model.module, state_dict)
            self.model.module.load_state_dict(state_dict)
    if "optimizer" in state:
        with io.BytesIO(state["optimizer"]) as buffer:
            state_dict = deserialize_state_dict(buffer, self.device)
            self.optimizer.load_state_dict(state_dict)
    if "lr_scheduler" in state:
        with io.BytesIO(state["lr_scheduler"]) as buffer:
            state_dict = deserialize_state_dict(buffer, self.device)
            self.lr_scheduler.load_state_dict(state_dict)
    # ``amp`` may be None when mixed precision is unavailable; skip then.
    if "amp" in state and amp is not None:
        with io.BytesIO(state["amp"]) as buffer:
            state_dict = deserialize_state_dict(buffer, self.device)
            amp.load_state_dict(state_dict)
    if "current_epoch" in state:
        self.current_epoch = state["current_epoch"]
    else:
        # Try to recover current epoch from LR Scheduler state.
        # ``last_epoch`` counts completed scheduler steps; +1 converts it
        # to a step total.
        last_epoch = self.lr_scheduler.last_epoch + 1
        if isinstance(self.lr_scheduler, ComposedLRScheduler):
            self.current_epoch = last_epoch // self.lr_scheduler.steps_per_epoch
        elif isinstance(self.lr_scheduler, OneCycleLR):
            # OneCycleLR steps once per batch; derive steps-per-epoch from
            # its configured total.
            steps_per_epoch = self.lr_scheduler.total_steps // self.epochs
            self.current_epoch = last_epoch // steps_per_epoch
        else:
            # Assume a per-epoch scheduler: one step == one epoch.
            self.current_epoch = last_epoch
    if "current_timestep" in state:
        self.current_timestep = state["current_timestep"]
    else:
        # Reconstruct the global batch counter from the recovered epoch.
        self.current_timestep = self.total_batches * self.current_epoch
def get_state_dict(checkpoint_path):
    """
    Load the serialized model state dict from a checkpoint file.

    :param checkpoint_path: path to the pickled checkpoint; ``~`` is
        expanded to the user's home directory
    :return: the deserialized model state dict, or None when the checkpoint
        contains no "model" entry
    """
    checkpoint_path = os.path.expanduser(checkpoint_path)
    with open(checkpoint_path, "rb") as loaded_state:
        checkpoint_dict = pickle.load(loaded_state)
    if "model" not in checkpoint_dict:
        return None
    with io.BytesIO(checkpoint_dict["model"]) as buffer:
        return deserialize_state_dict(buffer)
def set_state(self, state):
    """
    Restore the experiment from the state returned by `get_state`.

    :param state: dictionary holding the serialized algorithm state plus
        the "current_epoch" and "total_steps" counters
    """
    if "algorithm" in state:
        # NOTE(review): the guard checks for "algorithm" but the payload is
        # read from state["model"] — confirm that `get_state` stores the
        # algorithm's weights under the "model" key whenever "algorithm" is
        # present; otherwise this raises KeyError.
        with io.BytesIO(state["model"]) as buffer:
            state_dict = deserialize_state_dict(buffer, self.device)
            self.algorithm.load_state_dict(state_dict)
    if "current_epoch" in state:
        self.current_epoch = state["current_epoch"]
        # Assumes checkpoints carrying "current_epoch" also carry
        # "total_steps" — TODO confirm against `get_state`.
        self.total_steps = state["total_steps"]
def setup_experiment(self, config):
    """
    Configure the experiment, then optionally restore model weights from
    the checkpoint named by config["model_path"].

    :param config: experiment configuration dictionary; recognizes the
        optional "model_path" key
    """
    super(LoadBlockModelExperiment, self).setup_experiment(config)
    self.load_file = config.get("model_path", None)
    if self.load_file is None:
        return

    with open(self.load_file, mode="rb") as f:
        state = pickle.load(f)
    if "model" not in state:
        return

    with io.BytesIO(state["model"]) as buffer:
        state_dict = deserialize_state_dict(buffer, self.device)
    target = self.model
    if hasattr(target, "module"):
        # DistributedDataParallel wraps the real model in ``.module``
        target = target.module
    state_dict = get_compatible_state_dict(state_dict, target)
    target.load_state_dict(state_dict)
def create_model(model_class, model_args, init_batch_norm, device,
                 checkpoint_file=None, resize_buffers_for_checkpoint=False):
    """
    Create imagenet experiment model with option to load state from checkpoint

    :param model_class: The model class. Must inherit from torch.nn.Module
    :param model_args: The model constructor arguments
    :param init_batch_norm: Whether or not to initialize batch norm modules
    :param device: Model device
    :param checkpoint_file: Optional checkpoint file to load model state
    :param resize_buffers_for_checkpoint: Optional param with
        `checkpoint_file`. If True, this resizes the models buffers to match
        those of the checkpoint before loading it.
    :return: Configured model
    """
    model = model_class(**model_args)
    if init_batch_norm:
        init_resnet50_batch_norm(model)
    model.to(device)

    if checkpoint_file is None:
        return model

    # Restore model parameters from the checkpoint
    with open(checkpoint_file, "rb") as pickle_file:
        state = pickle.load(pickle_file)
    with io.BytesIO(state["model"]) as buffer:
        state_dict = deserialize_state_dict(buffer, device)
    state_dict = get_compatible_state_dict(model, state_dict)
    if resize_buffers_for_checkpoint:
        # Match buffer shapes (e.g. sparsity masks) to the checkpoint first.
        resize_model_buffers(model, state_dict)
    model.load_state_dict(state_dict)
    return model
def test_serialization(self):
    """Round-trip a state dict through serialize/deserialize and verify the
    restored model matches the source model."""
    source = simple_linear_net()
    target = simple_linear_net()

    # Start ``target`` from clearly different weights so a successful
    # comparison can only come from loading the serialized state.
    def fill_weights(module):
        if getattr(module, "weight", None) is not None:
            module.weight.data.fill_(42.0)

    target.apply(fill_weights)

    with io.BytesIO() as buffer:
        serialize_state_dict(buffer, source.state_dict())
        buffer.seek(0)
        restored = deserialize_state_dict(buffer)
    target.load_state_dict(restored)

    self.assertTrue(compare_models(source, target, (32, )))
def create_model(model_class, model_args, init_batch_norm, device,
                 checkpoint_file=None, init_hooks=None):
    """
    Create imagenet experiment model with option to load state from checkpoint

    :param model_class: The model class. Must inherit from torch.nn.Module
    :param model_args: The model constructor arguments
    :param init_batch_norm: Whether or not to initialize batch norm modules
    :param device: Model device
    :param checkpoint_file: Optional checkpoint file to load model state
    :param init_hooks: Optional list of (hook, kwargs) pairs applied only
        when no checkpoint is loaded
    :return: Configured model
    """
    model = model_class(**model_args)
    if init_batch_norm:
        init_resnet50_batch_norm(model)
    model.to(device)

    if checkpoint_file is not None:
        # Restore model parameters from the checkpoint
        with open(checkpoint_file, "rb") as pickle_file:
            state = pickle.load(pickle_file)
        with io.BytesIO(state["model"]) as buffer:
            model.load_state_dict(deserialize_state_dict(buffer, device))
    elif init_hooks:
        # A hook may modify the model in place (returning None) or return a
        # replacement model.
        for hook, kwargs in init_hooks:
            model = hook(model, **kwargs) or model
    return model
def create_model(model_class, model_args, init_batch_norm, device,
                 checkpoint_file=None, init_hooks=None):
    """
    Create imagenet experiment model with option to load state from checkpoint

    :param model_class: The model class. Must inherit from torch.nn.Module
    :param model_args: The model constructor arguments
    :param init_batch_norm: Whether or not to initialize batch norm modules
    :param device: Model device
    :param checkpoint_file: Optional checkpoint file to load model state
    :param init_hooks: Unused in this variant; kept for signature parity
    :return: Configured model
    """
    model = model_class(**model_args)
    if init_batch_norm:
        init_resnet50_batch_norm(model)
    model.to(device)

    if checkpoint_file is None:
        return model

    # Restore model parameters from the checkpoint
    with open(checkpoint_file, "rb") as pickle_file:
        state = pickle.load(pickle_file)
    with io.BytesIO(state["model"]) as buffer:
        state_dict = deserialize_state_dict(buffer, device)

    # When key names differ (e.g. wrapped vs. unwrapped model), remap the
    # checkpoint values onto the model's own key names. NOTE: this relies
    # on both state dicts listing parameters in the same positional order.
    model_keys = model.state_dict().keys()
    if model_keys != state_dict.keys():
        state_dict = OrderedDict(zip(model_keys, state_dict.values()))
    model.load_state_dict(state_dict)
    return model