Example #1
    def init_model_from_weights_params_file(self, config: AttrDict,
                                            checkpoint: Dict[str, Any]):
        """
        We initialize the weights from this checkpoint. However, we don't care
        about the other metadata like iteration number etc.
        So the method only reads the state_dict
        """

        # TODO (Quentin) - support: different number of nodes + different checkpoint
        # formats + fine-tuning
        # Special cases in which we want to evaluate a model trained with FSDP:
        # - it must be benchmarked in FSDP mode as well, with the same number of
        #   workers
        # - it must have been trained with VISSL (no support for other checkpoint
        #   types for now)
        if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
                self.trunk.base_model, FSDP):
            CheckpointLoader.init_fsdp_model_from_weights(
                self.trunk.base_model,
                checkpoint,
                weights_path=[
                    "classy_state_dict", "base_model", "model", "trunk"
                ],
            )
            fsdp_recursive_reset_lazy_init(self.trunk.base_model)
        elif isinstance(self.trunk, FSDP):
            CheckpointLoader.init_fsdp_model_from_weights(
                self.trunk,
                checkpoint,
                weights_path=[
                    "classy_state_dict", "base_model", "model", "trunk"
                ],
            )
            fsdp_recursive_reset_lazy_init(self.trunk)

        # General case: support for multiple checkpoint formats
        else:
            params_from_file = config["MODEL"]["WEIGHTS_INIT"]
            skip_layers = params_from_file.get("SKIP_LAYERS", [])
            replace_prefix = params_from_file.get("REMOVE_PREFIX", None)
            append_prefix = params_from_file.get("APPEND_PREFIX", None)
            state_dict_key_name = params_from_file.get("STATE_DICT_KEY_NAME",
                                                       None)
            init_model_from_consolidated_weights(
                config,
                self,
                checkpoint,
                state_dict_key_name=state_dict_key_name,
                skip_layers=skip_layers,
                replace_prefix=replace_prefix,
                append_prefix=append_prefix,
            )
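
The WEIGHTS_INIT options read in the general branch above (SKIP_LAYERS, REMOVE_PREFIX, APPEND_PREFIX) exist to reconcile key names between an arbitrary checkpoint and the VISSL model. Below is a minimal, self-contained sketch of that kind of key remapping; the helper remap_state_dict is hypothetical, for illustration only, not VISSL's actual implementation.

from typing import Any, Dict, List, Optional


def remap_state_dict(
    state_dict: Dict[str, Any],
    skip_layers: Optional[List[str]] = None,
    remove_prefix: Optional[str] = None,
    append_prefix: Optional[str] = None,
) -> Dict[str, Any]:
    remapped = {}
    for name, param in state_dict.items():
        # Drop any parameter whose name contains a skipped pattern.
        if skip_layers and any(skip in name for skip in skip_layers):
            continue
        # Strip the training-time prefix first ...
        if remove_prefix and name.startswith(remove_prefix):
            name = name[len(remove_prefix):]
        # ... then prepend the prefix the target model expects.
        if append_prefix:
            name = append_prefix + name
        remapped[name] = param
    return remapped


# "module.conv1.weight" (saved under a DDP-style wrapper) becomes
# "trunk.conv1.weight": prints {'trunk.conv1.weight': 0}
print(remap_state_dict({"module.conv1.weight": 0},
                       remove_prefix="module.", append_prefix="trunk."))
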
Example #2
    def _init_fsdp_model_heads_from_weights_params_file(
            self, checkpoint: Dict[str, Any]):
        for i, head in enumerate(self.heads):
            logging.info(f"Loading FSDP head {i}")
            if isinstance(head, FSDP):
                CheckpointLoader.init_fsdp_model_from_weights(
                    head,
                    checkpoint,
                    weights_path=[
                        "classy_state_dict", "base_model", "model", "heads"
                    ],
                    strict=False,
                    head_index=i,
                )
                fsdp_recursive_reset_lazy_init(head)
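
In the FSDP branches, weights_path lists the chain of keys walked from the root of the checkpoint dictionary down to the weights. A sketch of that lookup under that assumption; get_nested is an illustrative helper, not part of CheckpointLoader.

from typing import Any, Dict, List


def get_nested(checkpoint: Dict[str, Any], weights_path: List[str]) -> Any:
    # Follow the key chain, e.g.
    # checkpoint["classy_state_dict"]["base_model"]["model"]["heads"]
    node = checkpoint
    for key in weights_path:
        node = node[key]
    return node


checkpoint = {
    "classy_state_dict": {"base_model": {"model": {"heads": ["head0-weights"]}}}
}
print(get_nested(checkpoint,
                 ["classy_state_dict", "base_model", "model", "heads"]))
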
Example #3
    def prepare(self, pin_memory: bool = False):
        """
        Prepares the task:
        - dataloaders
        - model
        - copy model to correct device
        - meters
        - loss
        - optimizer
        - LR schedulers
        - AMP state
        - resume from a checkpoint if available
        """
        self.dataloaders = self.build_dataloaders(pin_memory=pin_memory)
        self.phases = self._build_phases()
        train_phases = [phase for phase in self.phases if phase["train"]]
        num_train_phases = len(train_phases)
        self.base_model = self._build_model()
        self._set_ddp_options()
        self.base_loss = self._build_loss()
        self.meters = self._build_meters()
        self.optimizer = self._build_optimizer()
        self.optimizer_schedulers = self._build_optimizer_schedulers()
        self.num_train_phases = num_train_phases

        self.base_loss = self.base_loss.to(self.device)
        if self.device.type == "cuda":
            self.base_model = copy_model_to_gpu(self.base_model)

        # Initialize the PyTorch optimizer now that the model has been moved to
        # the appropriate device.
        self.prepare_optimizer()

        # Enable mixed precision grad scalers
        if self.amp_type == AmpType.APEX:
            # Allow Apex Amp to perform casts as specified by the amp_args.
            # This updates the model and the PyTorch optimizer (which is wrapped
            # by the ClassyOptimizer in self.optimizer).
            # NOTE: this must happen before loading the checkpoint. See
            # https://nvidia.github.io/apex/amp.html#checkpointing for more details.
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args)

        # Restore a checkpoint, if one is available
        vissl_state_dict = None
        if self.checkpoint_path is not None:
            self.checkpoint = CheckpointLoader.load_and_broadcast_checkpoint(
                checkpoint_folder=self.checkpoint_folder,
                checkpoint_path=self.checkpoint_path,
                device=torch.device("cpu"),
            )

            # load_and_broadcast_checkpoint may return None on failure
            if self.checkpoint is None:
                raise ValueError(
                    f"Could not load checkpoint: {self.checkpoint_path}")

            self.iteration = self.checkpoint["iteration"]
            self.local_iteration_num = self.checkpoint["iteration_num"]
            vissl_state_dict = self.checkpoint.get("classy_state_dict")
            if "loss" in self.checkpoint:
                self.base_loss.load_state_dict(self.checkpoint["loss"])
                logging.info("======Loaded loss state from checkpoint======")

        return self._update_classy_state(vissl_state_dict)
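
The NOTE about Apex ordering above is load-bearing: amp.initialize() patches the model and the optimizer in place, so it has to run before checkpoint state is restored. A sketch of the save/load pattern from the linked Apex documentation; it requires a CUDA build of NVIDIA Apex, and the checkpoint keys are illustrative.

import torch
from apex import amp  # requires a CUDA build of NVIDIA Apex

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# amp.initialize() patches the model and optimizer, which is why prepare()
# above calls it before any checkpoint state is restored.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Save: include AMP's own state (loss scale) next to model/optimizer state.
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "amp": amp.state_dict(),
}, "amp_checkpoint.pt")

# Load: initialize AMP first (done above), only then restore all three.
checkpoint = torch.load("amp_checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
amp.load_state_dict(checkpoint["amp"])
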
Example #4
    def _restore_model_weights(self, model):
        """
        If using a weights file to initialize the model, we load the weights
        and initialize the model. Since the weights file specified
        by user might not be VISSL trained weights, we expose several config
        options like APPEND_PREFIX, etc to allow successful loading of the weights.
        See MODEL.WEIGHTS_INIT description in vissl/config/defaults.yaml for details.
        """
        params_from_file = self.config["MODEL"]["WEIGHTS_INIT"]
        init_weights_path = params_from_file["PARAMS_FILE"]
        assert init_weights_path, "Shouldn't call this when init_weights_path is empty"
        logging.info(f"Initializing model from: {init_weights_path}")

        if PathManager.exists(init_weights_path):
            checkpoint = CheckpointLoader.load_and_broadcast_init_weights(
                checkpoint_path=init_weights_path, device=torch.device("cpu"))
            model.init_model_from_weights_params_file(self.config, checkpoint)
        return model
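
For reference, a sketch of the MODEL.WEIGHTS_INIT block that _restore_model_weights and init_model_from_weights_params_file read, written as a plain dict; the values are illustrative, and vissl/config/defaults.yaml is the authoritative source.

weights_init_config = {
    "PARAMS_FILE": "/path/to/weights.torch",     # must be non-empty here
    "STATE_DICT_KEY_NAME": "classy_state_dict",  # key that holds the state dict
    "SKIP_LAYERS": ["num_batches_tracked"],      # layer-name substrings to skip
    "REMOVE_PREFIX": "module.",                  # prefix stripped from keys
    "APPEND_PREFIX": "trunk.",                   # prefix prepended to keys
}
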
Example #5
    def prepare(self, pin_memory: bool = False):
        """
        Prepares the task:
        - dataloaders
        - model
        - copy model to correct device
        - meters
        - loss
        - optimizer
        - LR schedulers
        - AMP state
        - resume from a checkpoint if available
        """
        self.phases = self._build_phases()
        self.num_phases = len(self.phases)
        self.base_model = self._build_model()
        self._set_ddp_options()
        self.meters = self._build_meters()
        self.optimizer = self._build_optimizer()
        self.optimizer_schedulers = self._build_optimizer_schedulers()

        if self.device.type == "cuda":
            self.base_model = copy_model_to_gpu(self.base_model)

        # Initialize the PyTorch optimizer now that the model has been moved to
        # the appropriate device.
        self.prepare_optimizer()

        # Enable mixed precision grad scalers
        if self.amp_type == AmpType.APEX:
            # Allow Apex Amp to perform casts as specified by the amp_args.
            # This updates the model and the PyTorch optimizer (which is wrapped
            # by the ClassyOptimizer in self.optimizer).
            # NOTE: this must happen before loading the checkpoint. See
            # https://nvidia.github.io/apex/amp.html#checkpointing for more details.
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args
            )

        # Create EMA average of the model if hook is specified.
        ema_config = self.config["HOOKS"]["EMA_MODEL"]
        if ema_config["ENABLE_EMA_METERS"] or ema_config["SAVE_EMA_MODEL"]:
            self._create_ema_model()

        # Restore a checkpoint, if one is available
        vissl_state_dict = None
        if self.checkpoint_path is not None:
            self.checkpoint = CheckpointLoader.load_and_broadcast_checkpoint(
                checkpoint_folder=self.checkpoint_folder,
                checkpoint_path=self.checkpoint_path,
                device=torch.device("cpu"),
            )
            if self.checkpoint is not None:
                self.iteration = self.checkpoint["iteration"]
                self.local_iteration_num = self.checkpoint["iteration_num"]
                vissl_state_dict = self.checkpoint.get("classy_state_dict")
            else:
                raise ValueError(f"Could not load checkpoint: {self.checkpoint_path}")

        current_train_phase_idx = (
            vissl_state_dict["train_phase_idx"] + 1 if vissl_state_dict else 0
        )

        self.datasets, self.data_and_label_keys = self.build_datasets(
            current_train_phase_idx
        )

        # Set the dataset state before building the dataloader, in order to capture checkpoint info.
        if vissl_state_dict and "train" in self.datasets:
            self.datasets["train"].set_classy_state(
                vissl_state_dict.get("train_dataset_iterator")
            )

        self.dataloaders = self.build_dataloaders(
            pin_memory=pin_memory, current_train_phase_idx=current_train_phase_idx
        )

        # Build base loss, move to device, and load from checkpoint if applicable
        self.base_loss = self._build_loss()
        self.base_loss = self.base_loss.to(self.device)
        if self.checkpoint and "loss" in self.checkpoint:
            self.base_loss.load_state_dict(self.checkpoint["loss"])
            logging.info("======Loaded loss state from checkpoint======")

        return self._update_classy_state(vissl_state_dict)
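
The EMA hook enabled above maintains an exponential moving average of the model weights alongside the trained model. A minimal, framework-free sketch of the standard update rule such a hook applies; the decay value is illustrative.

import copy

import torch

model = torch.nn.Linear(4, 4)
ema_model = copy.deepcopy(model)
decay = 0.999


@torch.no_grad()
def ema_update(model, ema_model, decay):
    # new_ema = decay * old_ema + (1 - decay) * current_weights
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)


ema_update(model, ema_model, decay)
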
Example #6
    def init_model_from_weights_params_file(self,
                                            config: AttrDict,
                                            checkpoint: Dict[str, Any],
                                            strict: bool = False):
        """
        We initialize the weights from this checkpoint. However, we don't care
        about the other metadata like iteration number etc.
        So the method only reads the state_dict
        """

        # Specific case for FSDP trunks:
        # - models have to be created with VISSL
        # - checkpoints have to be created with VISSL
        if isinstance(self.trunk, FeatureExtractorModel) and isinstance(
                self.trunk.base_model, FSDP):
            # Linear evaluation / extraction from FSDP models:
            # - load the trunk (complete strict load)
            # - load the head (optional and partial load supported)
            logging.info("Loading FSDP trunk in extraction mode")
            CheckpointLoader.init_fsdp_model_from_weights(
                self.trunk.base_model,
                checkpoint,
                weights_path=[
                    "classy_state_dict", "base_model", "model", "trunk"
                ],
            )
            fsdp_recursive_reset_lazy_init(self.trunk.base_model)
            if should_init_head_weights(config.MODEL):
                self._init_fsdp_model_heads_from_weights_params_file(
                    checkpoint)

        elif isinstance(self.trunk, FSDP):
            # Fine-tuning of FSDP models:
            # - load the trunk (complete strict load)
            # - load the head (optional and partial load supported)
            logging.info("Loading FSDP trunk")
            CheckpointLoader.init_fsdp_model_from_weights(
                self.trunk,
                checkpoint,
                weights_path=[
                    "classy_state_dict", "base_model", "model", "trunk"
                ],
            )
            fsdp_recursive_reset_lazy_init(self.trunk)
            if should_init_head_weights(config.MODEL):
                self._init_fsdp_model_heads_from_weights_params_file(
                    checkpoint)

        # General case: support for multiple checkpoint formats
        else:
            params_from_file = config["MODEL"]["WEIGHTS_INIT"]
            skip_layers = params_from_file.get("SKIP_LAYERS", [])
            replace_prefix = params_from_file.get("REMOVE_PREFIX", None)
            append_prefix = params_from_file.get("APPEND_PREFIX", None)
            state_dict_key_name = params_from_file.get("STATE_DICT_KEY_NAME",
                                                       None)
            init_model_from_consolidated_weights(
                config,
                self,
                checkpoint,
                state_dict_key_name=state_dict_key_name,
                skip_layers=skip_layers,
                replace_prefix=replace_prefix,
                append_prefix=append_prefix,
                strict=strict,
            )

    def _worker(gpu_id: int, sync_file: str, world_size: int):
        torch.manual_seed(0)
        os.environ["RANK"] = str(gpu_id)
        init_distributed_on_file(world_size=world_size,
                                 gpu_id=gpu_id,
                                 sync_file=sync_file)
        torch.backends.cudnn.deterministic = True

        config = TestCheckpointConversion._create_fsdp_model_config(
            with_fsdp=True)
        model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        model = fsdp_wrapper(model, **config.MODEL.FSDP_CONFIG)
        optimizer = optim.SGD(model.parameters(), lr=1e-4)

        # Fake inputs
        num_iterations = 5
        batch_size = 3
        torch.manual_seed(gpu_id)
        fake_inputs = torch.randn(size=(num_iterations, batch_size, 3, 96, 96))
        fake_targets = torch.randn(size=(num_iterations, batch_size))

        # Fake training loop
        criterion = nn.MSELoss()
        for iteration in range(num_iterations):
            fake_input = fake_inputs[iteration].cuda(gpu_id)
            fake_target = fake_targets[iteration].cuda(gpu_id)
            output1, output2 = model(fake_input)[0]
            loss = criterion(output1.sum(axis=-1), fake_target) + criterion(
                output2.sum(axis=-1), fake_target)
            if gpu_id == 0:
                print(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Save a set of checkpoints, one per shard
        checkpoint_writer = CheckpointWriter(
            checkpoint_folder=".",
            is_final_train_phase=True,
            mode="iteration",
            mode_num=0,
            backend="disk",
        )
        content = {
            "classy_state_dict": {
                "base_model": {
                    "model": {
                        "trunk": model.trunk.local_state_dict()
                    },
                    "meta": {
                        "trunk": model.trunk.local_metadata_dict()
                    },
                }
            }
        }
        checkpoint_writer.save_sharded_checkpoint(content,
                                                  shard_rank=gpu_id,
                                                  world_size=world_size)
        dist.barrier()
        print(os.listdir("."))

        # Convert the checkpoint to consolidated and sliced checkpoints
        if gpu_id == 0:
            CheckpointFormatConverter.sharded_to_consolidated_checkpoint(
                "checkpoint.torch", "checkpoint_conso.torch")
            CheckpointFormatConverter.sharded_to_sliced_checkpoint(
                "checkpoint.torch", "checkpoint_sliced.torch")
        dist.barrier()
        print(os.listdir("."))

        # Now create models initialized from the previous checkpoint and compare them
        fake_test_input = torch.randn(size=(1, 3, 96, 96)).cuda(gpu_id)

        shard_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint.torch", device=torch.device("cpu"))
        shard_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        shard_model = fsdp_wrapper(shard_model, **config.MODEL.FSDP_CONFIG)
        shard_model.init_model_from_weights_params_file(config, shard_cp)

        conso_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_conso.torch", device=torch.device("cpu"))
        conso_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        conso_model = fsdp_wrapper(conso_model, **config.MODEL.FSDP_CONFIG)
        conso_model.init_model_from_weights_params_file(config, conso_cp)

        slice_cp = CheckpointLoader.load_and_broadcast_init_weights(
            "checkpoint_sliced.torch", device=torch.device("cpu"))
        slice_model = build_model(config.MODEL, config.OPTIMIZER).cuda(gpu_id)
        slice_model = fsdp_wrapper(slice_model, **config.MODEL.FSDP_CONFIG)
        slice_model.init_model_from_weights_params_file(config, slice_cp)

        # Verifying that the models are equivalent
        if gpu_id == 0:
            slice_state_dict = slice_model.local_state_dict()
            conso_state_dict = conso_model.local_state_dict()
            assert set(slice_state_dict.keys()) == set(conso_state_dict.keys())
            for k in slice_state_dict.keys():
                slice_val = slice_state_dict[k]
                conso_val = conso_state_dict[k]
                assert torch.allclose(
                    slice_val, conso_val
                ), f"Difference for key {k}: {slice_val} VS {conso_val}"
        dist.barrier()

        with torch.no_grad():
            ref_out = model.trunk(fake_test_input)[0]
            shard_out = shard_model.trunk(fake_test_input)[0]
            conso_out = conso_model.trunk(fake_test_input)[0]
            slice_out = slice_model.trunk(fake_test_input)[0]
            assert torch.allclose(
                ref_out, shard_out), f"{ref_out.sum()} vs {shard_out.sum()}"
            assert torch.allclose(
                ref_out, conso_out), f"{ref_out.sum()} vs {conso_out.sum()}"
            assert torch.allclose(
                ref_out, slice_out), f"{ref_out.sum()} vs {slice_out.sum()}"
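
_worker above runs once per process, one process per GPU. A sketch of the driver that would launch it, assuming _worker is a staticmethod of TestCheckpointConversion; the sync-file handling and world size are illustrative.

import tempfile

import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2
    with tempfile.NamedTemporaryFile() as sync_file:
        # mp.spawn passes the process index as the first argument (gpu_id).
        mp.spawn(
            TestCheckpointConversion._worker,
            args=(sync_file.name, world_size),
            nprocs=world_size,
        )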