Example #1
    def prepare(self, pin_memory: bool = False):
        """
        Prepares the task:
        - dataloaders
        - model
        - copy model to correct device
        - meters
        - loss
        - optimizer
        - LR schedulers
        - AMP state
        - resume from a checkpoint if available
        """
        self.dataloaders = self.build_dataloaders(pin_memory=pin_memory)
        self.phases = self._build_phases()
        train_phases = [phase for phase in self.phases if phase["train"]]
        num_train_phases = len(train_phases)
        self.base_model = self._build_model()
        self._set_ddp_options()
        self.base_loss = self._build_loss()
        self.meters = self._build_meters()
        self.optimizer = self._build_optimizer()
        self.optimizer_schedulers = self._build_optimizer_schedulers()
        self.num_train_phases = num_train_phases

        self.base_loss = self.base_loss.to(self.device)
        if self.device.type == "cuda":
            self.base_model = copy_model_to_gpu(self.base_model)

        # Initialize the PyTorch optimizer now that the model has been moved to
        # the appropriate device.
        self.prepare_optimizer()

        # Enable mixed precision grad scalers
        if self.amp_type == AmpType.APEX:
            # Allow Apex Amp to perform casts as specified by the amp_args.
            # This updates the model and the PyTorch optimizer (which is wrapped
            # by the ClassyOptimizer in self.optimizer).
            # NOTE: this must happen before loading the checkpoint. See
            # https://nvidia.github.io/apex/amp.html#checkpointing for more details.
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args)

        # Restore a checkpoint, if one is available
        vissl_state_dict = None
        if self.checkpoint_path is not None:
            self.checkpoint = CheckpointLoader.load_and_broadcast_checkpoint(
                checkpoint_folder=self.checkpoint_folder,
                checkpoint_path=self.checkpoint_path,
                device=torch.device("cpu"),
            )

            # Guard against a failed load before indexing into the checkpoint.
            if self.checkpoint is None:
                raise ValueError(
                    f"Could not load checkpoint: {self.checkpoint_path}"
                )

            self.iteration = self.checkpoint["iteration"]
            self.local_iteration_num = self.checkpoint["iteration_num"]
            vissl_state_dict = self.checkpoint.get("classy_state_dict")
            if "loss" in self.checkpoint:
                self.base_loss.load_state_dict(self.checkpoint["loss"])
                logging.info("======Loaded loss state from checkpoint======")

        return self._update_classy_state(vissl_state_dict)
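
The NOTE in this example refers to the Apex AMP checkpointing docs: apex.amp.initialize must run before any checkpoint is restored so that the saved states match the AMP-patched model and optimizer. Below is a minimal standalone sketch of that ordering, following the pattern at https://nvidia.github.io/apex/amp.html#checkpointing; the model, optimizer, and checkpoint file used here are placeholders, not objects from the task above.

import torch
import torch.nn as nn
from apex import amp

# Placeholder model and optimizer; in the task above these come from
# self._build_model() and self._build_optimizer().
model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 1) Initialize AMP first. This patches the model and optimizer in place,
#    just like the apex.amp.initialize call inside prepare().
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# 2) Only then restore the checkpoint, so the saved states line up with the
#    AMP-initialized objects (the file name is illustrative).
checkpoint = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
if "amp" in checkpoint:
    amp.load_state_dict(checkpoint["amp"])
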
Example #2
    def prepare(self, pin_memory: bool = False):
        """
        Prepares the task:
        - dataloaders
        - model
        - copy model to correct device
        - meters
        - loss
        - optimizer
        - LR schedulers
        - AMP state
        - resume from a checkpoint if available
        """
        self.phases = self._build_phases()
        self.num_phases = len(self.phases)
        self.base_model = self._build_model()
        self._set_ddp_options()
        self.meters = self._build_meters()
        self.optimizer = self._build_optimizer()
        self.optimizer_schedulers = self._build_optimizer_schedulers()

        if self.device.type == "cuda":
            self.base_model = copy_model_to_gpu(self.base_model)

        # Initialize the PyTorch optimizer now that the model has been moved to
        # the appropriate device.
        self.prepare_optimizer()

        # Enable mixed precision grad scalers
        if self.amp_type == AmpType.APEX:
            # Allow Apex Amp to perform casts as specified by the amp_args.
            # This updates the model and the PyTorch optimizer (which is wrapped
            # by the ClassyOptimizer in self.optimizer).
            # NOTE: this must happen before loading the checkpoint. See
            # https://nvidia.github.io/apex/amp.html#checkpointing for more details.
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args
            )

        # Create EMA average of the model if hook is specified.
        ema_config = self.config["HOOKS"]["EMA_MODEL"]
        if ema_config["ENABLE_EMA_METERS"] or ema_config["SAVE_EMA_MODEL"]:
            self._create_ema_model()

        # Restore a checkpoint, if one is available
        vissl_state_dict = None
        if self.checkpoint_path is not None:
            self.checkpoint = CheckpointLoader.load_and_broadcast_checkpoint(
                checkpoint_folder=self.checkpoint_folder,
                checkpoint_path=self.checkpoint_path,
                device=torch.device("cpu"),
            )
            if self.checkpoint is not None:
                self.iteration = self.checkpoint["iteration"]
                self.local_iteration_num = self.checkpoint["iteration_num"]
                vissl_state_dict = self.checkpoint.get("classy_state_dict")
            else:
                raise ValueError(f"Could not load checkpoint: {self.checkpoint_path}")

        current_train_phase_idx = (
            vissl_state_dict["train_phase_idx"] + 1 if vissl_state_dict else 0
        )

        self.datasets, self.data_and_label_keys = self.build_datasets(
            current_train_phase_idx
        )

        # Set the dataset state before building the dataloaders, so that the
        # iterator position saved in the checkpoint is restored.
        if vissl_state_dict and "train" in self.datasets:
            self.datasets["train"].set_classy_state(
                vissl_state_dict.get("train_dataset_iterator")
            )

        self.dataloaders = self.build_dataloaders(
            pin_memory=pin_memory, current_train_phase_idx=current_train_phase_idx
        )

        # Build base loss, move to device, and load from checkpoint if applicable
        self.base_loss = self._build_loss()
        self.base_loss = self.base_loss.to(self.device)
        if self.checkpoint and "loss" in self.checkpoint:
            self.base_loss.load_state_dict(self.checkpoint["loss"])
            logging.info("======Loaded loss state from checkpoint======")

        return self._update_classy_state(vissl_state_dict)
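
Example #2 also shows the resume bookkeeping performed once the checkpoint is loaded: the next train phase index is derived from the saved classy_state_dict, and the train dataset's iterator state is restored before the dataloaders are rebuilt. A minimal sketch of that derivation is shown below, assuming the checkpoint is a plain dict with the keys used above; the helper name resume_bookkeeping is illustrative and not part of the API.

def resume_bookkeeping(checkpoint):
    """Extract resume information from a loaded checkpoint dict (sketch only)."""
    vissl_state_dict = checkpoint.get("classy_state_dict")

    # Phase indices are 0-based: resume on the phase after the last one
    # recorded in the checkpoint, or start from phase 0 when there is none.
    current_train_phase_idx = (
        vissl_state_dict["train_phase_idx"] + 1 if vissl_state_dict else 0
    )
    return {
        "iteration": checkpoint["iteration"],
        "local_iteration_num": checkpoint["iteration_num"],
        "current_train_phase_idx": current_train_phase_idx,
        "train_dataset_iterator": (
            vissl_state_dict.get("train_dataset_iterator")
            if vissl_state_dict
            else None
        ),
    }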