Example #1
    def test_update_classy_state(self):
        """
        Tests that update_classy_state successfully updates a task's state from a
        checkpoint
        """
        config = get_fast_test_task_config()
        task = build_task(config)
        task_2 = build_task(config)
        task_2.prepare()
        trainer = LocalTrainer()
        trainer.train(task)
        update_classy_state(task_2, task.get_classy_state(deep_copy=True))
        self._compare_states(task.get_classy_state(), task_2.get_classy_state())
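
The test above trains one task end to end, then copies its full state into a second, freshly prepared task with update_classy_state and checks that both states match. Below is a minimal sketch of the same pattern outside a test harness; it assumes update_classy_state is importable from classy_vision.generic.util and uses a hypothetical load_my_task_config() helper in place of a real task config.

    # Minimal sketch: transfer a trained task's state into a fresh task.
    from classy_vision.generic.util import update_classy_state  # assumed import path
    from classy_vision.tasks import build_task
    from classy_vision.trainer import LocalTrainer

    config = load_my_task_config()  # hypothetical helper returning a task config dict

    source_task = build_task(config)
    LocalTrainer().train(source_task)  # train the source task locally

    target_task = build_task(config)
    target_task.prepare()  # populate derived attributes before loading state
    update_classy_state(target_task, source_task.get_classy_state(deep_copy=True))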
Example #2
    def prepare(self):
        """Prepares task for training, populates all derived attributes """

        pin_memory = self.use_gpu and torch.cuda.device_count() > 1

        self.phases = self._build_phases()
        self.train = False if self.test_only else self.train
        self.dataloaders = self.build_dataloaders(
            pin_memory=pin_memory,
            multiprocessing_context=mp.get_context(self.dataloader_mp_context),
        )

        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.base_model)
        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
            self.base_model = apex.parallel.convert_syncbn_model(
                self.base_model)

        # move the model and loss to the right device
        if self.use_gpu:
            self.base_model, self.loss = copy_model_to_gpu(
                self.base_model, self.loss)
        else:
            self.loss.cpu()
            self.base_model.cpu()

        if self.optimizer is not None:
            # initialize the pytorch optimizer now since the model has been moved to
            # the appropriate device
            self.optimizer.init_pytorch_optimizer(self.base_model,
                                                  loss=self.loss)

        if self.amp_args is not None:
            # Initialize apex.amp. This updates the model and, if training, the
            # PyTorch optimizer (which is wrapped by the ClassyOptimizer in
            # self.optimizer). Note that this must happen before loading the
            # checkpoint, because there is amp state to be restored.

            if self.optimizer is None:
                self.base_model = apex.amp.initialize(self.base_model,
                                                      optimizers=None,
                                                      **self.amp_args)
            else:
                self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                    self.base_model, self.optimizer.optimizer, **self.amp_args)

        if self.checkpoint_path:
            self.checkpoint_dict = load_and_broadcast_checkpoint(
                self.checkpoint_path)

        classy_state_dict = (None if self.checkpoint_dict is None else
                             self.checkpoint_dict["classy_state_dict"])

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (state_load_success
                    ), "Update classy state from checkpoint was unsuccessful."

        self.init_distributed_data_parallel_model()
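
This variant of prepare wires the whole task together: it builds the phases and dataloaders, optionally converts BatchNorm layers for synchronized training, moves the model and loss onto the chosen device, constructs the PyTorch optimizer, initializes apex.amp, and finally restores state from a checkpoint if a path was given. As a point of reference, here is the PyTorch sync-BatchNorm conversion used above shown on its own; this is plain torch.nn API, not ClassyVision-specific code.

    # Standalone sketch of the BatchNormSyncMode.PYTORCH branch.
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())
    # Replaces every BatchNorm*d module with nn.SyncBatchNorm; meaningful only when
    # running under torch.distributed, where stats are synchronized across processes.
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)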
Example #3
    def prepare(self, num_dataloader_workers=0, dataloader_mp_context=None):
        """Prepares task for training, populates all derived attributes

        Args:
            num_dataloader_workers: Number of dataloading processes. If 0,
                dataloading is done on main process
            dataloader_mp_context: Determines how processes are spawned.
                Value must be one of None, "spawn", "fork", "forkserver".
                If None, then context is inherited from parent process
        """

        pin_memory = self.use_gpu and torch.cuda.device_count() > 1

        self.phases = self._build_phases()
        self.dataloaders = self.build_dataloaders(
            num_workers=num_dataloader_workers,
            pin_memory=pin_memory,
            multiprocessing_context=dataloader_mp_context,
        )

        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.base_model)
        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
            self.base_model = apex.parallel.convert_syncbn_model(
                self.base_model)

        # move the model and loss to the right device
        if self.use_gpu:
            self.base_model, self.loss = copy_model_to_gpu(
                self.base_model, self.loss)
        else:
            self.loss.cpu()
            self.base_model.cpu()

        # initialize the pytorch optimizer now since the model has been moved to
        # the appropriate device
        self.optimizer.init_pytorch_optimizer(self.base_model, loss=self.loss)

        if self.amp_args is not None:
            # Initialize apex.amp. This updates the model and the PyTorch optimizer
            # (which is wrapped by the ClassyOptimizer in self.optimizer). Note that
            # this must happen before loading the checkpoint, because there is amp
            # state to be restored.
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, **self.amp_args)

        classy_state_dict = (None if self.checkpoint is None else
                             self.checkpoint.get("classy_state_dict"))

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (state_load_success
                    ), "Update classy state from checkpoint was unsuccessful."

        self.init_distributed_data_parallel_model()
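
In this variant the dataloader settings are not stored on the task but passed to prepare directly, so the caller picks the worker count and multiprocessing start method at prepare time. A hedged example of such a call, assuming task is an already-built ClassyVision task:

    # Four dataloading worker processes, started with "spawn" (generally safer
    # than "fork" when CUDA is in use).
    task.prepare(num_dataloader_workers=4, dataloader_mp_context="spawn")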
Example #4
    def prepare(self):
        """Prepares task for training, populates all derived attributes """

        self.phases = self._build_phases()
        self.train = False if self.test_only else self.train

        if self.batch_norm_sync_mode == BatchNormSyncMode.PYTORCH:
            self.base_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                self.base_model)
        elif self.batch_norm_sync_mode == BatchNormSyncMode.APEX:
            sync_bn_process_group = apex.parallel.create_syncbn_process_group(
                self.batch_norm_sync_group_size)
            self.base_model = apex.parallel.convert_syncbn_model(
                self.base_model, process_group=sync_bn_process_group)

        # move the model and loss to the right device
        if self.use_gpu:
            self.base_model, self.base_loss = copy_model_to_gpu(
                self.base_model, self.base_loss)
        else:
            self.base_loss.cpu()
            self.base_model.cpu()

        if self.optimizer is not None:
            self.prepare_optimizer(optimizer=self.optimizer,
                                   model=self.base_model,
                                   loss=self.base_loss)

        if self.amp_args is not None:
            # Initialize apex.amp. This updates the model and, if training, the
            # PyTorch optimizer (which is wrapped by the ClassyOptimizer in
            # self.optimizer). Note that this must happen before loading the
            # checkpoint, because there is amp state to be restored.

            if self.optimizer is None:
                self.base_model = apex.amp.initialize(self.base_model,
                                                      optimizers=None,
                                                      **self.amp_args)
            else:
                self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                    self.base_model, self.optimizer.optimizer, **self.amp_args)

        if self.checkpoint_path:
            self.checkpoint_dict = load_and_broadcast_checkpoint(
                self.checkpoint_path)

        classy_state_dict = (None if self.checkpoint_dict is None else
                             self.checkpoint_dict["classy_state_dict"])

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (state_load_success
                    ), "Update classy state from checkpoint was unsuccessful."

        self.init_distributed_data_parallel_model()
    def prepare(
        self,
        num_dataloader_workers=0,
        pin_memory=False,
        use_gpu=False,
        dataloader_mp_context=None,
    ):
        """Prepares task for training, populates all derived attributes

        Args:
            num_dataloader_workers: Number of dataloading processes. If 0,
                dataloading is done on main process
            pin_memory: if true pin memory on GPU
            use_gpu: if true, load model, optimizer, loss, etc on GPU
            dataloader_mp_context: Determines how processes are spawned.
                Value must be one of None, "spawn", "fork", "forkserver".
                If None, then context is inherited from parent process
        """
        self.phases = self._build_phases()
        self.dataloaders = self.build_dataloaders(
            num_workers=num_dataloader_workers,
            pin_memory=pin_memory,
            multiprocessing_context=dataloader_mp_context,
        )

        # move the model and loss to the right device
        if use_gpu:
            self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss)
        else:
            self.loss.cpu()
            self.base_model.cpu()

        # initialize the pytorch optimizer now since the model has been moved to
        # the appropriate device
        self.optimizer.init_pytorch_optimizer(self.base_model, loss=self.loss)

        classy_state_dict = (
            None
            if self.checkpoint is None
            else self.checkpoint.get("classy_state_dict")
        )

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (
                state_load_success
            ), "Update classy state from checkpoint was unsuccessful."

        if self.amp_opt_level is not None:
            # Initialize apex.amp. This updates the model and the PyTorch optimizer
            # (which is wrapped by the ClassyOptimizer in self.optimizer).
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model, self.optimizer.optimizer, opt_level=self.amp_opt_level
            )
        self.init_distributed_data_parallel_model()
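
This last variant takes pin_memory and use_gpu as arguments and configures mixed precision through a single amp_opt_level rather than an amp_args dict; it also initializes apex.amp after restoring the checkpoint, whereas the other variants above do it before loading, since (as their comments note) there is amp state to restore. For reference, a standalone sketch of the Apex opt-level pattern it relies on; this requires NVIDIA Apex and a CUDA device, and the model and optimizer here are purely illustrative.

    import torch
    import torch.nn as nn
    from apex import amp  # NVIDIA Apex is installed separately from PyTorch

    model = nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # "O1" patches selected ops to run in fp16 while keeping fp32 master weights.
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")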