Example #1
    def _restore_model_weights(self, model):
        """
        If using a weights file to initialize the model, we load the weights
        and initialize the model with them. Since the weights file specified
        by the user might not contain VISSL-trained weights, we expose several
        config options (APPEND_PREFIX, etc.) to allow the weights to be loaded
        successfully. See the MODEL.WEIGHTS_INIT description in
        vissl/config/defaults.yaml for details.
        """
        params_from_file = self.config["MODEL"]["WEIGHTS_INIT"]
        init_weights_path = params_from_file["PARAMS_FILE"]
        assert init_weights_path, "Shouldn't call this when init_weights_path is empty"
        logging.info(f"Initializing model from: {init_weights_path}")

        if PathManager.exists(init_weights_path):
            weights = load_and_broadcast_checkpoint(init_weights_path,
                                                    device=torch.device("cpu"))
            skip_layers = params_from_file.get("SKIP_LAYERS", [])
            replace_prefix = params_from_file.get("REMOVE_PREFIX", None)
            append_prefix = params_from_file.get("APPEND_PREFIX", None)
            state_dict_key_name = params_from_file.get("STATE_DICT_KEY_NAME",
                                                       None)

            # we initialize the weights from this checkpoint. However, we
            # don't care about the other metadata like iteration number etc.
            # So the method only reads the state_dict
            init_model_from_weights(
                self.config,
                model,
                weights,
                state_dict_key_name=state_dict_key_name,
                skip_layers=skip_layers,
                replace_prefix=replace_prefix,
                append_prefix=append_prefix,
            )
        return model
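
The options read above correspond one-to-one to the MODEL.WEIGHTS_INIT section of the config. As a rough illustration (only the key names are taken from the code above; every value below is a made-up placeholder, see vissl/config/defaults.yaml for the real defaults), the section maps onto something like:

# Illustrative only: key names mirror what _restore_model_weights reads;
# all values are hypothetical placeholders.
weights_init_params = {
    "PARAMS_FILE": "/path/to/pretrained_weights.torch",  # hypothetical path
    "STATE_DICT_KEY_NAME": "classy_state_dict",          # key under which the state dict is stored, if any
    "SKIP_LAYERS": ["heads"],                            # layers to leave untouched during init
    "REMOVE_PREFIX": "module.",                          # prefix stripped from checkpoint keys
    "APPEND_PREFIX": "trunk.",                           # prefix prepended to checkpoint keys
}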
Example #2
    def prepare(self):
        """Prepares task for training, populates all derived attributes """

        pin_memory = self.use_gpu and torch.cuda.device_count() > 1

        self.phases = self._build_phases()
        self.train = False if self.test_only else self.train
        self.dataloaders = self.build_dataloaders(
            pin_memory=pin_memory,
            multiprocessing_context=mp.get_context(self.dataloader_mp_context),
        )

        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.base_model)
        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
            self.base_model = apex.parallel.convert_syncbn_model(
                self.base_model)

        # move the model and loss to the right device
        if self.use_gpu:
            self.base_model, self.loss = copy_model_to_gpu(
                self.base_model, self.loss)
        else:
            self.loss.cpu()
            self.base_model.cpu()

        if self.optimizer is not None:
            # initialize the pytorch optimizer now since the model has been moved to
            # the appropriate device
            self.optimizer.init_pytorch_optimizer(self.base_model,
                                                  loss=self.loss)

        if self.amp_args is not None:
            # Initialize apex.amp. This updates the model and, if training, the
            # PyTorch optimizer (which is wrapped by the ClassyOptimizer in
            # self.optimizer). Note that this must happen before loading the
            # checkpoint because there is AMP state to be restored.

            if self.optimizer is None:
                self.base_model = apex.amp.initialize(self.base_model,
                                                      optimizers=None,
                                                      **self.amp_args)
            else:
                self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                    self.base_model, self.optimizer.optimizer, **self.amp_args)

        if self.checkpoint_path:
            self.checkpoint_dict = load_and_broadcast_checkpoint(
                self.checkpoint_path)

        classy_state_dict = (None if self.checkpoint_dict is None else
                             self.checkpoint_dict["classy_state_dict"])

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (state_load_success
                    ), "Update classy state from checkpoint was unsuccessful."

        self.init_distributed_data_parallel_model()
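
The pin_memory flag and the multiprocessing context computed at the top of prepare() are simply forwarded to the underlying PyTorch dataloaders. A minimal plain-PyTorch sketch of those two settings, using a throwaway TensorDataset instead of the task's real dataset:

import multiprocessing as mp
import torch
from torch.utils.data import DataLoader, TensorDataset

# stand-in dataset; the real task builds its own datasets per phase
dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))
use_gpu = torch.cuda.is_available()

loader = DataLoader(
    dataset,
    batch_size=4,
    num_workers=2,
    # pin host memory only when batches will be copied to a GPU
    pin_memory=use_gpu and torch.cuda.device_count() > 1,
    # choose the worker start method explicitly ("spawn", "fork", ...)
    multiprocessing_context=mp.get_context("spawn"),
)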
Example #3
    def prepare(self, pin_memory: bool = False):
        """
        Prepares the task:
        - dataloaders
        - model
        - copy model to correct device
        - meters
        - loss
        - optimizer
        - LR schedulers
        - AMP state
        - resume from a checkpoint if available
        """
        self.dataloaders = self.build_dataloaders(pin_memory=pin_memory)
        self.phases = self._build_phases()
        train_phases = [phase for phase in self.phases if phase["train"]]
        num_train_phases = len(train_phases)
        self.base_model = self._build_model()
        self._set_ddp_options()
        self.base_loss = self._build_loss()
        self.meters = self._build_meters()
        self.optimizer = self._build_optimizer()
        self.optimizer_schedulers = self._build_optimizer_schedulers()
        self.num_train_phases = num_train_phases

        self.base_loss = self.base_loss.to(self.device)
        if self.device.type == "cuda":
            self.base_model = copy_model_to_gpu(self.base_model)

        # initialize the pytorch optimizer now since the model has been moved to
        # the appropriate device.
        self.prepare_optimizer()

        # Enable mixed precision grad scalers
        if self.amp_type == AmpType.APEX:
            # Allow Apex Amp to perform casts as specified by the amp_args.
            # This updates the model and the PyTorch optimizer (which is wrapped
            # by the ClassyOptimizer in self.optimizer).
            # NOTE: this must happen before loading the checkpoint. See
            # https://nvidia.github.io/apex/amp.html#checkpointing for more details.
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args
            )

        # Restore a checkpoint, if one is available
        vissl_state_dict = None
        if self.checkpoint_path is not None:
            self.checkpoint = load_and_broadcast_checkpoint(
                checkpoint_path=self.checkpoint_path, device=torch.device("cpu")
            )
            self.iteration = self.checkpoint["iteration"]
            self.local_iteration_num = self.checkpoint["iteration_num"]
            vissl_state_dict = self.checkpoint.get("classy_state_dict")
            if "loss" in self.checkpoint:
                self.base_loss.load_state_dict(self.checkpoint["loss"])
                logging.info("======Loaded loss state from checkpoint======")

        return self._update_classy_state(vissl_state_dict)
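
The NOTE above points at the Apex checkpointing recipe: apex.amp.initialize has to run before any AMP state is saved or restored. A condensed sketch of that ordering, assuming apex and a GPU are available (the checkpoint file name is hypothetical):

import torch
from apex import amp

model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# amp.initialize must run before AMP state is saved or restored
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# saving: keep the AMP state next to the model/optimizer state
torch.save(
    {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "amp": amp.state_dict(),
    },
    "amp_checkpoint.pt",  # hypothetical file name
)

# restoring: only valid after amp.initialize has been called again
checkpoint = torch.load("amp_checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
amp.load_state_dict(checkpoint["amp"])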
Example #4
    def prepare(self):
        """Prepares task for training, populates all derived attributes """

        self.phases = self._build_phases()
        self.train = False if self.test_only else self.train

        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.base_model)
        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
            sync_bn_process_group = apex.parallel.create_syncbn_process_group(
                self.batch_norm_sync_group_size)
            self.base_model = apex.parallel.convert_syncbn_model(
                self.base_model, process_group=sync_bn_process_group)

        # move the model and loss to the right device
        if self.use_gpu:
            self.base_model, self.base_loss = copy_model_to_gpu(
                self.base_model, self.base_loss)
        else:
            self.base_loss.cpu()
            self.base_model.cpu()

        if self.optimizer is not None:
            self.prepare_optimizer(optimizer=self.optimizer,
                                   model=self.base_model,
                                   loss=self.base_loss)

        if self.amp_args is not None:
            # Initialize apex.amp. This updates the model and, if training, the
            # PyTorch optimizer (which is wrapped by the ClassyOptimizer in
            # self.optimizer). Note that this must happen before loading the
            # checkpoint because there is AMP state to be restored.

            if self.optimizer is None:
                self.base_model = apex.amp.initialize(self.base_model,
                                                      optimizers=None,
                                                      **self.amp_args)
            else:
                self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                    self.base_model, self.optimizer.optimizer, **self.amp_args)

        if self.checkpoint_path:
            self.checkpoint_dict = load_and_broadcast_checkpoint(
                self.checkpoint_path)

        classy_state_dict = (None if self.checkpoint_dict is None else
                             self.checkpoint_dict["classy_state_dict"])

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (state_load_success
                    ), "Update classy state from checkpoint was unsuccessful."

        self.init_distributed_data_parallel_model()
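
Compared to Example #2, this variant restricts SyncBatchNorm statistics to process groups of batch_norm_sync_group_size ranks. A rough plain-PyTorch equivalent of that group-limited conversion (assuming torch.distributed is already initialized, e.g. via torchrun, and using a made-up group size):

import torch.distributed as dist
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

group_size = 8  # hypothetical: sync BN statistics only within blocks of 8 ranks
world_size = dist.get_world_size()
rank = dist.get_rank()

# every rank must create every group, in the same order
groups = [
    dist.new_group(ranks=list(range(start, min(start + group_size, world_size))))
    for start in range(0, world_size, group_size)
]
my_group = groups[rank // group_size]

model = nn.SyncBatchNorm.convert_sync_batchnorm(model, process_group=my_group)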
Example #5
    @classmethod
    def load_and_broadcast_checkpoint(
        cls, checkpoint_folder: str, checkpoint_path: str, device
    ):
        """
        Load the checkpoint at the provided path, dealing with the
        potential indirection due to the notion of sharded checkpoints
        """
        checkpoint = load_and_broadcast_checkpoint(checkpoint_path, device)
        if cls._is_shard_aggregator_checkpoint(checkpoint):
            _, global_rank = get_machine_local_and_dist_rank()
            shard_name = checkpoint["shards"][global_rank]
            shard_path = os.path.join(checkpoint_folder, shard_name)
            checkpoint = load_checkpoint(shard_path, device)
        return checkpoint
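
The indirection handled here means the file at checkpoint_path may only be a small "aggregator" that points at per-rank shard files. The on-disk layout is not shown in the snippet; the sketch below is an assumption that relies only on the "shards" mapping indexed by global rank, which is what the code accesses:

import os

# Hypothetical aggregator checkpoint; only the "shards" key is implied by the code above.
aggregator_checkpoint = {
    "type": "shard_list",  # hypothetical marker that _is_shard_aggregator_checkpoint might check
    "shards": {
        0: "checkpoint_shard_rank0.torch",  # hypothetical shard file names
        1: "checkpoint_shard_rank1.torch",
    },
}

global_rank = 1  # in the real code this comes from get_machine_local_and_dist_rank()
shard_path = os.path.join(
    "/path/to/checkpoint_folder",  # hypothetical checkpoint_folder
    aggregator_checkpoint["shards"][global_rank],
)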
Example #6
    def prepare(self) -> None:
        super().prepare()
        if self.checkpoint_dict is None:
            # no checkpoint exists, load the model's state from the pretrained
            # checkpoint

            if self.pretrained_checkpoint_path:
                self.pretrained_checkpoint_dict = load_and_broadcast_checkpoint(
                    self.pretrained_checkpoint_path)

            assert (
                self.pretrained_checkpoint_dict
                is not None), "Need a pretrained checkpoint for fine tuning"

            state = self.pretrained_checkpoint_dict["classy_state_dict"]

            state_load_success = update_classy_model(
                self.base_model,
                state["base_model"],
                self.reset_heads,
                self.pretrained_checkpoint_load_strict,
            )
            assert (
                state_load_success
            ), "Update classy state from pretrained checkpoint was unsuccessful."

            self._load_hooks_from_pretrained_checkpoint(state)

        if self.freeze_trunk:
            # do not track gradients for all the parameters in the model except
            # for the parameters in the heads
            for param in self.base_model.parameters():
                param.requires_grad = False
            for heads in self.base_model.get_heads().values():
                for h in heads:
                    for param in h.parameters():
                        param.requires_grad = True
            # re-create ddp model
            self.distributed_model = None
            self.init_distributed_data_parallel_model()
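
The freeze_trunk branch follows the standard PyTorch pattern: disable gradients for every parameter, re-enable them only for the heads, and rebuild the DDP wrapper afterwards, since DDP decides which gradients to synchronize when it is constructed. A self-contained sketch with a stand-in model rather than the ClassyVision model used here:

import torch.nn as nn

model = nn.Sequential()
model.add_module("trunk", nn.Sequential(nn.Linear(16, 16), nn.ReLU()))
model.add_module("head", nn.Linear(16, 4))

# freeze everything, then unfreeze only the head
for param in model.parameters():
    param.requires_grad = False
for param in model.head.parameters():
    param.requires_grad = True

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
assert trainable == ["head.weight", "head.bias"]
# any DistributedDataParallel wrapper must be re-created after this change,
# which is why the example resets self.distributed_model before re-initializing it.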