def set_state(self, state):
    """
    Restore the experiment from the state returned by `get_state`.

    :param state: dictionary with optional "model", "optimizer",
                  "lr_scheduler", and "amp" serialized states
    """
    def _load(key):
        # Deserialize one component's state dict from its pickled bytes.
        with io.BytesIO(state[key]) as buffer:
            return deserialize_state_dict(buffer, self.device)

    if "model" in state:
        self.model.module.load_state_dict(_load("model"))

    if "optimizer" in state:
        self.optimizer.load_state_dict(_load("optimizer"))

    if "lr_scheduler" in state:
        self.lr_scheduler.load_state_dict(_load("lr_scheduler"))

    if "amp" in state and amp is not None:
        amp.load_state_dict(_load("amp"))
def repair(in_path):
    """
    Remove the "anneal_func" entry from a checkpoint's LR scheduler state
    and, when a fix was needed, save the repaired checkpoint next to the
    original as "<in_path>.repaired".

    :param in_path: path to the pickled checkpoint file
    """
    with open(in_path, "rb") as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)

    fix_needed = False
    if "lr_scheduler" in checkpoint:
        print(
            "Loading LR scheduler state dict (this might take a few minutes)")
        with io.BytesIO(checkpoint["lr_scheduler"]) as stream:
            scheduler_state = deserialize_state_dict(stream)

        fix_needed = "anneal_func" in scheduler_state
        if fix_needed:
            # Strip the offending entry and re-serialize the scheduler state
            scheduler_state.pop("anneal_func")

            with io.BytesIO() as stream:
                serialize_state_dict(stream, scheduler_state)
                checkpoint["lr_scheduler"] = stream.getvalue()

            out_path = f"{in_path}.repaired"
            print(f"Saving {out_path}")
            with open(out_path, "wb") as out_file:
                pickle.dump(checkpoint, out_file)

    if not fix_needed:
        print("This checkpoint does not need repair")
    def set_state(self, state):
        """
        Restore the experiment from the state returned by `get_state`.

        :param state: dictionary with optional "model", "optimizer",
                      "lr_scheduler", "amp", "current_epoch", and
                      "current_timestep" states
        """
        def _deserialize(blob):
            # Decode one component's state dict from its serialized bytes.
            with io.BytesIO(blob) as buffer:
                return deserialize_state_dict(buffer, self.device)

        if "model" in state:
            model_state = _deserialize(state["model"])
            # presumably adapts checkpoint keys to this model's layout —
            # see get_compatible_state_dict
            model_state = get_compatible_state_dict(self.model.module,
                                                    model_state)
            self.model.module.load_state_dict(model_state)

        if "optimizer" in state:
            self.optimizer.load_state_dict(_deserialize(state["optimizer"]))

        if "lr_scheduler" in state:
            self.lr_scheduler.load_state_dict(
                _deserialize(state["lr_scheduler"]))

        if "amp" in state and amp is not None:
            amp.load_state_dict(_deserialize(state["amp"]))

        if "current_epoch" in state:
            self.current_epoch = state["current_epoch"]
        else:
            # Recover the epoch from the LR scheduler's step counter
            last_epoch = self.lr_scheduler.last_epoch + 1
            if isinstance(self.lr_scheduler, ComposedLRScheduler):
                self.current_epoch = last_epoch // self.lr_scheduler.steps_per_epoch
            elif isinstance(self.lr_scheduler, OneCycleLR):
                steps_per_epoch = self.lr_scheduler.total_steps // self.epochs
                self.current_epoch = last_epoch // steps_per_epoch
            else:
                self.current_epoch = last_epoch

        if "current_timestep" in state:
            self.current_timestep = state["current_timestep"]
        else:
            self.current_timestep = self.total_batches * self.current_epoch
def get_state_dict(checkpoint_path):
    """
    Load the serialized model state dict from a pickled checkpoint file.

    :param checkpoint_path: path to the checkpoint; "~" is expanded
    :return: the deserialized model state dict, or None when the
             checkpoint has no "model" entry
    """
    checkpoint_path = os.path.expanduser(checkpoint_path)
    with open(checkpoint_path, "rb") as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)

    if "model" not in checkpoint:
        return None

    with io.BytesIO(checkpoint["model"]) as buffer:
        return deserialize_state_dict(buffer)
 def set_state(self, state):
     """
     Restore the experiment from the state returned by `get_state`.

     :param state: dictionary with optional "algorithm", "current_epoch",
                   and "total_steps" states
     """
     if "algorithm" in state:
         # BUG FIX: the guard checks "algorithm" but the original read
         # state["model"], which raises KeyError for checkpoints carrying
         # only an "algorithm" entry. Read the key the guard tested.
         # NOTE(review): if checkpoints actually store this under "model",
         # the guard should be changed instead — confirm against get_state.
         with io.BytesIO(state["algorithm"]) as buffer:
             state_dict = deserialize_state_dict(buffer, self.device)
         self.algorithm.load_state_dict(state_dict)
     if "current_epoch" in state:
         self.current_epoch = state["current_epoch"]
         # assumes "total_steps" is always present alongside
         # "current_epoch" — TODO confirm against the saved state format
         self.total_steps = state["total_steps"]
# Example #6
# 0
 def setup_experiment(self, config):
     """
     Set up the experiment and optionally load model weights from a
     pickled checkpoint.

     :param config: experiment config; reads "model_path" (optional path
                    to a pickled checkpoint with a "model" entry)
     """
     super(LoadBlockModelExperiment, self).setup_experiment(config)
     self.load_file = config.get("model_path", None)
     if self.load_file is not None:
         with open(self.load_file, mode="rb") as f:
             state = pickle.load(f)
         if "model" in state:
             with io.BytesIO(state["model"]) as buffer:
                 state_dict = deserialize_state_dict(buffer, self.device)
             model = self.model
             if hasattr(model, "module"):
                 # DistributedDataParallel wraps the real model
                 model = model.module
             # BUG FIX: arguments were passed as (state_dict, model), but
             # every other call site in this file uses (model, state_dict)
             # — verify against get_compatible_state_dict's signature.
             state_dict = get_compatible_state_dict(model, state_dict)
             model.load_state_dict(state_dict)
# Example #7
# 0
def create_model(model_class,
                 model_args,
                 init_batch_norm,
                 device,
                 checkpoint_file=None,
                 resize_buffers_for_checkpoint=False):
    """
    Build an imagenet experiment model, optionally restoring its weights
    from a checkpoint file.

    :param model_class: model class; must inherit from torch.nn.Module
    :param model_args: keyword arguments for the model constructor
    :param init_batch_norm: whether to initialize batch norm modules
    :param device: device the model is moved to
    :param checkpoint_file: optional pickled checkpoint with a "model" entry
    :param resize_buffers_for_checkpoint: when True (and a checkpoint is
        given), resize the model's buffers to match the checkpoint before
        loading it
    :return: configured model
    """
    net = model_class(**model_args)
    if init_batch_norm:
        init_resnet50_batch_norm(net)
    net.to(device)

    if checkpoint_file is None:
        return net

    # Restore parameters from the serialized checkpoint
    with open(checkpoint_file, "rb") as pickle_file:
        checkpoint = pickle.load(pickle_file)
    with io.BytesIO(checkpoint["model"]) as buffer:
        restored = deserialize_state_dict(buffer, device)

    restored = get_compatible_state_dict(net, restored)

    if resize_buffers_for_checkpoint:
        resize_model_buffers(net, restored)

    net.load_state_dict(restored)
    return net
# Example #8
# 0
    def test_serialization(self):
        """
        Round-trip a state dict through serialize/deserialize and check
        the restored model matches the source model.
        """
        source = simple_linear_net()
        target = simple_linear_net()

        def fill_weights(module):
            # Make the target visibly different before restoring
            if getattr(module, "weight", None) is not None:
                module.weight.data.fill_(42.0)

        target.apply(fill_weights)

        with io.BytesIO() as buffer:
            serialize_state_dict(buffer, source.state_dict())
            buffer.seek(0)
            target.load_state_dict(deserialize_state_dict(buffer))

        self.assertTrue(compare_models(source, target, (32, )))
def create_model(model_class,
                 model_args,
                 init_batch_norm,
                 device,
                 checkpoint_file=None,
                 init_hooks=None):
    """
    Build an imagenet experiment model, loading weights from a checkpoint
    when one is given, or applying init hooks otherwise.

    :param model_class: model class; must inherit from torch.nn.Module
    :param model_args: keyword arguments for the model constructor
    :param init_batch_norm: whether to initialize batch norm modules
    :param device: device the model is moved to
    :param checkpoint_file: optional pickled checkpoint with a "model" entry
    :param init_hooks: optional iterable of (hook, kwargs) pairs; each hook
        is called as hook(model, **kwargs) and may return a replacement
        model. Only applied when no checkpoint is given.
    :return: configured model
    """
    net = model_class(**model_args)
    if init_batch_norm:
        init_resnet50_batch_norm(net)
    net.to(device)

    if checkpoint_file is not None:
        # Restore parameters from the serialized checkpoint
        with open(checkpoint_file, "rb") as pickle_file:
            checkpoint = pickle.load(pickle_file)
        with io.BytesIO(checkpoint["model"]) as buffer:
            net.load_state_dict(deserialize_state_dict(buffer, device))
    elif init_hooks:
        # No checkpoint: let hooks adjust the freshly initialized model
        for hook, kwargs in init_hooks:
            net = hook(net, **kwargs) or net

    return net
def create_model(model_class, model_args, init_batch_norm, device,
                 checkpoint_file=None, init_hooks=None):
    """
    Build an imagenet experiment model with the option to restore its
    weights from a checkpoint file.

    :param model_class: model class; must inherit from torch.nn.Module
    :param model_args: keyword arguments for the model constructor
    :param init_batch_norm: whether to initialize batch norm modules
    :param device: device the model is moved to
    :param checkpoint_file: optional pickled checkpoint with a "model" entry
    :param init_hooks: accepted for interface compatibility; not used here
    :return: configured model
    """
    net = model_class(**model_args)
    if init_batch_norm:
        init_resnet50_batch_norm(net)
    net.to(device)

    if checkpoint_file is not None:
        with open(checkpoint_file, "rb") as pickle_file:
            checkpoint = pickle.load(pickle_file)
        with io.BytesIO(checkpoint["model"]) as buffer:
            restored = deserialize_state_dict(buffer, device)

        # When key names differ, remap by position. NOTE(review): this
        # assumes both state dicts list the same parameters in the same
        # order — confirm for each checkpoint format it must accept.
        if net.state_dict().keys() != restored.keys():
            restored = OrderedDict(
                zip(net.state_dict().keys(), restored.values()))

        net.load_state_dict(restored)

    return net