Example #1
0
    def set_classy_state(self, state):
        """
        Initialize the model trunk and heads from a checkpoint state dictionary.

        We call this on the state.base_model which is not wrapped with DDP.
        FSDP-wrapped modules are restored with ``load_local_state_dict`` (and
        their lazy-init state is reset afterwards); all other modules use the
        regular ``load_state_dict``. When neither the trunk nor any head is
        FSDP-wrapped, rank 0 additionally prints which layers were loaded.

        Args:
            state: checkpoint mapping. ``state["model"]["trunk"]`` holds the
                trunk state dict and ``state["model"]["heads"]`` the heads
                state. When any head is FSDP-wrapped, the heads entry is
                indexed per head (``state["model"]["heads"][i]``).
        """
        from vissl.utils.checkpoint import print_loaded_dict_info

        model_state = state["model"]

        # Loading the trunk
        logging.info(f"Rank {self.local_rank}: Loading Trunk state dict....")
        if isinstance(self.trunk, FSDP):
            self.trunk.load_local_state_dict(model_state["trunk"])
            fsdp_recursive_reset_lazy_init(self.trunk)
        else:
            self.trunk.load_state_dict(model_state["trunk"])

        # Loading the heads
        logging.info(f"Rank {self.local_rank}: Loading Heads state dict....")
        # Loading does not replace the head modules, so this can be computed
        # once and reused for the debug-print decision below.
        has_fsdp_head = any(isinstance(head, FSDP) for head in self.heads)
        if has_fsdp_head:
            # Each head is restored from its own entry of the checkpoint.
            for i, head in enumerate(self.heads):
                head_state = model_state["heads"][i]
                if isinstance(head, FSDP):
                    head.load_local_state_dict(head_state)
                    fsdp_recursive_reset_lazy_init(head)
                else:
                    # Load into the loop's own head module, i.e. self.heads[i].
                    # NOTE(review): this uses strict=True (the default) while the
                    # all-non-FSDP path below uses strict=False — confirm intended.
                    head.load_state_dict(head_state)
        else:
            # sometimes, we want to load the partial head only, so strict=False
            self.heads.load_state_dict(model_state["heads"], strict=False)

        # Print debug information about layers loaded.
        #
        # Get the model state dict original (if FSDP, calling it on all ranks.)
        logging.info(f"Rank {self.local_rank}: Model state dict loaded!")
        if isinstance(self.trunk, FSDP) or has_fsdp_head:
            return  # TODO (Quentin) - log the weights of the loaded shard

        if self.local_rank != 0:
            return

        # Flatten the live model's trunk + heads weights into one mapping.
        model_state_dict = dict(self.trunk.state_dict())
        model_state_dict.update(self.heads.state_dict())

        # Flatten the checkpoint's trunk + heads weights the same way.
        checkpoint_state_dict = dict(model_state["trunk"])
        checkpoint_state_dict.update(model_state["heads"])

        skip_layers = self.model_config["WEIGHTS_INIT"].get("SKIP_LAYERS", [])
        print_loaded_dict_info(
            model_state_dict,
            checkpoint_state_dict,
            skip_layers=skip_layers,
            model_config=self.model_config,
        )
Example #2
0
    def set_classy_state(self, state):
        """
        Restore the trunk and heads weights from a checkpoint state.

        Intended to be invoked on ``state.base_model`` (the module that is not
        wrapped with DDP). After loading, rank 0 prints which layers were
        matched against the checkpoint.

        Args:
            state: mapping whose ``state["model"]["trunk"]`` and
                ``state["model"]["heads"]`` entries are the state dicts to load.
        """
        from vissl.utils.checkpoint import print_loaded_dict_info

        model_state = state["model"]

        logging.info(f"Rank {self.local_rank}: Loading Trunk state dict....")
        self.trunk.load_state_dict(model_state["trunk"])
        logging.info(f"Rank {self.local_rank}: Loading Heads state dict....")
        # Partial head checkpoints are allowed, hence strict=False.
        self.heads.load_state_dict(model_state["heads"], strict=False)
        logging.info(f"Rank {self.local_rank}: Model state dict loaded!")

        # Debug printing of loaded layers. When the trunk is FSDP-wrapped, the
        # model state dict is gathered on every rank; otherwise only rank 0
        # needs it.
        is_primary = self.local_rank == 0
        if not (is_primary or isinstance(self.trunk, FSDP)):
            return
        live_trunk = self.trunk.state_dict()
        live_heads = self.heads.state_dict()
        if not is_primary:
            return

        merged_model = dict(live_trunk)
        merged_model.update(live_heads)

        merged_checkpoint = dict(model_state["trunk"])
        merged_checkpoint.update(model_state["heads"])

        init_params = self.model_config["WEIGHTS_INIT"]
        print_loaded_dict_info(
            merged_model,
            merged_checkpoint,
            skip_layers=init_params.get("SKIP_LAYERS", []),
            model_config=self.model_config,
        )