Example #1
    def init_distributed_data_parallel_model(self):
        """
        This method overrides the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        assert (
            get_cuda_device_index() > -1
        ), "Distributed training not setup correctly"

        # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
        # flag to True. We need to set it back to None here.
        # Also, right now, the head's weights are only partially loaded from the checkpoint
        # because we dump the checkpoint after the head is wrapped, but load it before
        # it is wrapped.
        # For very big models, we need to rework the checkpoint logic because we don't have
        # enough memory to load the entire model on one node. We need to use the
        # local_state_dict() API to load checkpoint shards.
        for module in self.base_model.trunk.modules():
            if isinstance(module, FSDP):
                module._is_root = None

        # Then, wrap the whole model. We replace the base_model since it is used
        # when a checkpoint is taken.
        fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
        self.base_model = FSDP(module=self.base_model, **fsdp_config)
        self.distributed_model = self.base_model
Example #2
    def on_forward(self, task: tasks.ClassyTask) -> None:
        """
        - Update the momentum encoder.
        - Compute the key by reusing the updated moco encoder. If batch
          shuffling is enabled, we perform a global shuffle of the batch
          and then run the moco encoder to compute the features.
          We unshuffle the computed features and use them as the "key"
          when computing the moco loss.
        """

        # Update the momentum encoder
        if task.loss.moco_encoder is None:
            self._build_moco_encoder(task)
            self.is_distributed = is_distributed_training_run()
            logging.info("MoCo: Distributed setup, shuffling batches")
        else:
            self._update_momentum_encoder(task)

        # Compute key features. We do not backpropagate in this codepath
        im_k = task.last_batch.sample["data_momentum"][0]

        if self.is_distributed and self.shuffle_batch:
            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k, task)

        k = task.loss.moco_encoder(im_k)[0]
        k = torch.nn.functional.normalize(k, dim=1)

        if self.is_distributed and self.shuffle_batch:
            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        task.loss.key = k
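
The hook above combines two ideas: an exponential-moving-average update of the key encoder, and shuffling the batch before BatchNorm so that per-GPU statistics are not computed on an ordered batch. The following is a minimal single-process sketch of both ideas in plain PyTorch; the helper names are hypothetical, and the distributed gather used by `_batch_shuffle_ddp` is replaced by a local permutation.

import torch

def momentum_update(encoder_q, encoder_k, m=0.999):
    # EMA update: k_param <- m * k_param + (1 - m) * q_param
    with torch.no_grad():
        for p_q, p_k in zip(encoder_q.parameters(), encoder_k.parameters()):
            p_k.mul_(m).add_(p_q, alpha=1 - m)

def shuffle_unshuffle_demo(x):
    # Shuffle rows, remember the inverse permutation, then undo the shuffle.
    idx_shuffle = torch.randperm(x.shape[0])
    idx_unshuffle = torch.argsort(idx_shuffle)
    shuffled = x[idx_shuffle]
    restored = shuffled[idx_unshuffle]
    assert torch.equal(restored, x)
    return shuffled, idx_unshuffle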
Example #3
    def init_distributed_data_parallel_model(self):
        """
        Initialize
        `torch.nn.parallel.distributed.DistributedDataParallel
        <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_.

        Needed for distributed training. This is where a model should be wrapped by DDP.
        """
        if not is_distributed_training_run():
            return
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS)
        self.distributed_model = init_distributed_data_parallel_model(
            self.base_model,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=self.find_unused_parameters,
        )
        if (isinstance(self.base_loss, ClassyLoss)
                and self.base_loss.has_learned_parameters()):
            logging.info("Initializing distributed loss")
            self.distributed_loss = init_distributed_data_parallel_model(
                self.base_loss,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
            )
Example #4
    def init_distributed_data_parallel_model(self):
        """
        Initialize
        `torch.nn.parallel.distributed.DistributedDataParallel
        <https://pytorch.org/docs/stable/nn.html#distributeddataparallel>`_.

        Needed for distributed training. This is where a model should be wrapped by DDP.
        """
        if not is_distributed_training_run():
            return
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS
        )

        if self.use_sharded_ddp:
            if not isinstance(self.optimizer, ZeRO):
                raise ValueError(
                    "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer"
                )
            from fairscale.nn.data_parallel import ShardedDataParallel

            # Replace the original DDP wrap by the shard-aware ShardedDDP
            self.distributed_model = ShardedDataParallel(
                module=self.base_model,
                sharded_optimizer=self.optimizer.optimizer,
                broadcast_buffers=broadcast_buffers,
            )
        else:
            self.distributed_model = init_distributed_data_parallel_model(
                self.base_model,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )
            if self.fp16_grad_compress:

                from torch.distributed.algorithms import ddp_comm_hooks

                # FP16 hook is stateless and only takes a process group as the state.
                # We use the default process group so we set the state to None.
                process_group = None
                self.distributed_model.register_comm_hook(
                    process_group,
                    ddp_comm_hooks.default_hooks.fp16_compress_hook,
                )
        if (
            isinstance(self.base_loss, ClassyLoss)
            and self.base_loss.has_learned_parameters()
        ):
            logging.info("Initializing distributed loss")
            self.distributed_loss = init_distributed_data_parallel_model(
                self.base_loss,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )
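
For reference, a minimal sketch (outside ClassyVision/VISSL, with hypothetical names) of how the FP16 compression hook above is attached to a plain DDP model. It assumes an already-initialized NCCL process group and one CUDA device per rank, e.g. a script launched with `torchrun --nproc_per_node=2`.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms import ddp_comm_hooks

def wrap_with_fp16_compression(model: torch.nn.Module, device_id: int) -> DDP:
    ddp_model = DDP(model.cuda(device_id), device_ids=[device_id])
    # The FP16 hook is stateless; passing None as the state means "use the default process group".
    ddp_model.register_comm_hook(
        None, ddp_comm_hooks.default_hooks.fp16_compress_hook
    )
    return ddp_model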
Example #5
    def train(self, task: ClassyTask):
        """Runs training phases; phases are generated from the config.

        Args:
            task: Task to be used in training. It should contain
                everything that is needed for training
        """

        pin_memory = self.use_gpu and torch.cuda.device_count() > 1
        task.prepare(
            num_dataloader_workers=self.num_dataloader_workers,
            pin_memory=pin_memory,
            use_gpu=self.use_gpu,
            dataloader_mp_context=self.dataloader_mp_context,
        )
        assert isinstance(task, ClassyTask)

        if is_distributed_training_run():
            task.init_distributed_data_parallel_model()

        local_variables = {}
        task.run_hooks(local_variables, ClassyHookFunctions.on_start.name)
        best_acc = {
            'top1_acc': 0,
            'top1_epoch': 0,
            'top5_acc': 0,
            'top5_epoch': 0
        }
        epoch = 0
        while not task.done_training():
            task.advance_phase()

            # Start phase hooks
            task.run_hooks(local_variables,
                           ClassyHookFunctions.on_phase_start.name)
            while True:
                # Process next sample
                try:
                    task.train_step(self.use_gpu, local_variables)
                except StopIteration:
                    break

            logging.info("Syncing meters on phase end...")
            for meter in task.meters:
                meter.sync_state()
            logging.info("...meters synced")
            barrier()
            meter = task.run_hooks(local_variables,
                                   ClassyHookFunctions.on_phase_end.name)
            if meter is not None:
                if meter[0].value['top_1'] > best_acc['top1_acc']:
                    best_acc['top1_acc'] = meter[0].value['top_1']
                    best_acc['top5_acc'] = meter[0].value['top_5']
                    best_acc['top1_epoch'] = epoch
                    best_acc['top5_epoch'] = epoch
            epoch += 1

        task.run_hooks(local_variables, ClassyHookFunctions.on_end.name)
        return best_acc
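
The inner loop above follows a simple pattern: keep calling train_step until the phase's data iterator is exhausted, which surfaces as StopIteration. A stripped-down, self-contained sketch of that pattern (with a hypothetical step function):

def run_phase(step, data_iter):
    # Process samples until the iterator is exhausted.
    while True:
        try:
            step(next(data_iter))
        except StopIteration:
            break

run_phase(lambda batch: None, iter(range(3)))  # consumes 3 "batches", then returns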
Example #6
    def __init__(self, loss_config: AttrDict):
        super().__init__()
        self.loss_config = loss_config
        self.momentum_teacher = None
        self.checkpoint = None
        self.teacher_output = None
        self.teacher_temp = None
        self.is_distributed = is_distributed_training_run()
        self.use_gpu = get_cuda_device_index() > -1
        self.center = None
Example #7
    def prepare(
        self,
        num_dataloader_workers=0,
        pin_memory=False,
        use_gpu=False,
        dataloader_mp_context=None,
    ):
        """Prepares task for training, populates all derived attributes

        Args:
            num_dataloader_workers: Number of dataloading processes. If 0,
                dataloading is done on main process
            pin_memory: if true pin memory on GPU
            use_gpu: if true, load model, optimizer, loss, etc on GPU
            dataloader_mp_context: Determines how processes are spawned.
                Value must be one of None, "spawn", "fork", "forkserver".
                If None, then context is inherited from parent process
        """
        self.phases = self._build_phases()
        self.dataloaders = self.build_dataloaders(
            num_workers=num_dataloader_workers,
            pin_memory=pin_memory,
            multiprocessing_context=dataloader_mp_context,
        )

        # move the model and loss to the right device
        if use_gpu:
            self.loss.cuda()
            self.base_model = copy_model_to_gpu(self.base_model)
        else:
            self.loss.cpu()
            self.base_model.cpu()

        # initialize the pytorch optimizer now since the model has been moved to
        # the appropriate device
        self.optimizer.init_pytorch_optimizer(self.base_model)

        classy_state_dict = (None if self.checkpoint is None else
                             self.checkpoint.get("classy_state_dict"))

        if classy_state_dict is not None:
            state_load_success = update_classy_state(self, classy_state_dict)
            assert (state_load_success
                    ), "Update classy state from checkpoint was unsuccessful."

        if self.amp_opt_level is not None:
            # Initialize apex.amp. This updates the model and the PyTorch optimizer (
            # which is wrapped by the ClassyOptimizer in self.optimizer)
            self.base_model, self.optimizer.optimizer = apex.amp.initialize(
                self.base_model,
                self.optimizer.optimizer,
                opt_level=self.amp_opt_level)

        if is_distributed_training_run():
            self.init_distributed_data_parallel_model()
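
The ordering in prepare() matters: the model is moved to its device before the PyTorch optimizer is constructed, so that the optimizer state refers to the device-resident parameters, and DDP wrapping happens last. A minimal plain-PyTorch sketch of that ordering (hypothetical model, no ClassyVision involved):

import torch
import torch.nn as nn

model = nn.Linear(16, 4)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)                                  # 1) move the model first
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # 2) then build the optimizer
# 3) (in a distributed run) wrap with DistributedDataParallel only after the steps above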
Example #8
    def init_distributed_data_parallel_model(self):
        """
        This method overrides the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        for module in self.base_model.modules():
            if isinstance(module, FullyShardedDataParallel):
                raise ValueError(
                    "DistributedDataParallel should not be used "
                    "with a FullyShardedDataParallel model.\n"
                    "Please set config.TRAINER.TASK_NAME='self_supervision_fsdp_task'"
                )

        super().init_distributed_data_parallel_model()
Example #9
    def init_distributed_data_parallel_model(self):
        """
        Initialize FSDP if needed.

        This method overrides the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        # Make sure default cuda device is set. TODO (Min): we should ensure FSDP can
        # be enabled for 1-GPU as well, but the use case there is likely different.
        # I.e. perhaps we use it for cpu_offloading.
        assert (
            get_cuda_device_index() > -1
        ), "Distributed training not setup correctly"

        # The model might already be wrapped by FSDP internally. Check regnet_fsdp.py.
        # Here, we wrap it at the outermost level.
        fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
        if is_primary():
            logging.info(f"Using FSDP, config: {fsdp_config}")

        # First, wrap the head's prototype_i layers if it is SWAV.
        # TODO (Min): make this more general for different models, which may have multiple
        #             heads.
        head0 = self.base_model.heads[0]
        if isinstance(head0, SwAVPrototypesHead):
            for j in range(head0.nmb_heads):
                module = getattr(head0, "prototypes" + str(j))
                module = FSDP(module=module, **fsdp_config)
                setattr(head0, "prototypes" + str(j), module)

        # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
        # flag to True. We need to set it back to None here.
        # Also, right now, the head's weights are only partially loaded from the checkpoint
        # because we dump the checkpoint after the head is wrapped, but load it before
        # it is wrapped.
        # For very big models, we need to rework the checkpoint logic because we don't have
        # enough memory to load the entire model on one node. We need to use the
        # local_state_dict() API to load checkpoint shards.
        for module in self.base_model.trunk.modules():
            if isinstance(module, FSDP):
                module._is_root = None

        # Then, wrap the whole model. We replace the base_model since it is used
        # when a checkpoint is taken.
        self.base_model = FSDP(module=self.base_model, **fsdp_config)
        self.distributed_model = self.base_model
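
A condensed sketch of the wrap-inner-modules-then-root pattern used above, in plain fairscale outside VISSL. The config values are hypothetical, and like the method above it requires torch.distributed to be initialized (e.g. via torchrun):

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

fsdp_config = {"flatten_parameters": True, "mixed_precision": False}

trunk = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
head = nn.Linear(32, 10)

# Wrap the head first so its parameters are sharded separately ...
head = FSDP(module=head, **fsdp_config)
# ... then wrap the whole model at the outermost level.
model = FSDP(module=nn.Sequential(trunk, head), **fsdp_config)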
Example #10
    def __init__(self, loss_config: AttrDict):
        super().__init__()
        self.loss_config = loss_config

        self.momentum_encoder = None
        self.checkpoint = None
        self.momentum_scores = None
        self.momentum_embeddings = None
        self.is_distributed = is_distributed_training_run()
        self.use_gpu = get_cuda_device_index() > -1
        self.softmax = nn.Softmax(dim=1)

        # keep track of number of iterations
        self.register_buffer("num_iteration", torch.zeros(1, dtype=int))

        # for queue
        self.use_queue = False
        if self.loss_config.queue.local_queue_length > 0:
            self.initialize_queue()
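
The iteration counter above is a buffer rather than a plain Python int so that it is saved in the state_dict and moved together with the module. A minimal standalone sketch of that behaviour:

import torch
import torch.nn as nn

class Counter(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers are persisted in state_dict() and follow .to()/.cuda(), but are not trained.
        self.register_buffer("num_iteration", torch.zeros(1, dtype=torch.long))

    def forward(self):
        self.num_iteration += 1

c = Counter()
c()
print(c.state_dict())  # contains num_iteration=tensor([1]), so it survives checkpointing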
Example #11
def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all
    except that it does not cut the gradients
    """
    if tensor.ndim == 0:
        # 0 dim tensors cannot be gathered. so unsqueeze
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = GatherLayer.apply(tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
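
Unlike torch.distributed.all_gather, the GatherLayer used here keeps the autograd graph, so gradients flow back to the local shard. A usage sketch, assuming the function above is in scope; outside a distributed run it simply returns the (possibly unsqueezed) input, so it is safe to call unconditionally:

import torch

local_embeddings = torch.randn(8, 128, requires_grad=True)
all_embeddings = gather_from_all(local_embeddings)  # (8, 128) on a single process
loss = all_embeddings.pow(2).mean()
loss.backward()  # gradients reach local_embeddings, which a plain all_gather would not allow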
Example #12
    def init_distributed_data_parallel_model(self):
        """
        Initialize FSDP if needed.

        This method overrides the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        # Make sure default cuda device is set. TODO (Min): we should ensure FSDP can
        # be enabled for 1-GPU as well, but the use case there is likely different.
        # I.e. perhaps we use it for cpu_offloading.
        assert (
            get_cuda_device_index() > -1
        ), "Distributed training not setup correctly"

        # The model might already be wrapped by FSDP internally. Check regnet_fsdp.py.
        # Here, we wrap it at the outermost level.
        fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
        if is_primary():
            logging.info(f"Using FSDP, config: {fsdp_config}")

        # First, wrap the head's prototype_i layers if it is SWAV.
        # TODO (Min): make this more general for different models, which may have multiple
        #             heads.
        if len(self.base_model.heads) != 1:
            raise ValueError(
                f"FSDP only supports 1 head, not {len(self.base_model.heads)} heads"
            )
        head0 = self.base_model.heads[0]
        if isinstance(head0, SwAVPrototypesHead):
            # This is important for convergence!
            #
            # Since we "normalize" this layer in the update hook, we need to keep its
            # weights in full precision. Its output goes into the loss and is used
            # for clustering, so we need that in full precision as well.
            fp_fsdp_config = fsdp_config.copy()
            fp_fsdp_config["flatten_parameters"] = False
            fp_fsdp_config["mixed_precision"] = False
            fp_fsdp_config["fp32_reduce_scatter"] = False
            for j in range(head0.nmb_heads):
                module = getattr(head0, "prototypes" + str(j))
                module = FSDP(module=module, **fp_fsdp_config)
                setattr(head0, "prototypes" + str(j), module)
        head0 = FSDP(module=head0, **fsdp_config)
        self.base_model.heads[0] = head0

        # Init the head properly since the weights are potentially initialized on different
        # ranks with different seeds. We first summon the full params from all workers.
        # Then, within that context, we set a fixed random seed so that all workers init the
        # weights the same way. Finally, we reset the layer's weights using reset_parameters().
        #
        # TODO (Min): This will go away once we have a way to sync from rank 0.
        with head0.summon_full_params():
            with set_torch_seed(self.config["SEED_VALUE"]):
                for m in head0.modules():
                    if isinstance(m, Linear):
                        m.reset_parameters()
        head0._reset_lazy_init()
        head0.prototypes0._reset_lazy_init()

        # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
        # flag to True. We need to set it back to None here.
        # Also, right now, the head's weights are only partially loaded from the checkpoint
        # because we dump the checkpoint after the head is wrapped, but load it before
        # it is wrapped.
        # For very big models, we need to rework the checkpoint logic because we don't have
        # enough memory to load the entire model on one node. We need to use the
        # local_state_dict() API to load checkpoint shards.
        for module in self.base_model.trunk.modules():
            if isinstance(module, FSDP):
                module._is_root = None

        # Then, wrap the whole model. We replace the base_model since it is used
        # when a checkpoint is taken.
        self.base_model = FSDP(module=self.base_model, **fsdp_config)
        self.distributed_model = self.base_model
Example #13
    def model(self):
        """Returns model used in training (can be wrapped with DDP)"""
        return (self.distributed_model
                if is_distributed_training_run() else self.base_model)
Example #14
class TestClassificationTask(unittest.TestCase):
    def _compare_model_state(self,
                             model_state_1,
                             model_state_2,
                             check_heads=True):
        compare_model_state(self, model_state_1, model_state_2, check_heads)

    def _compare_samples(self, sample_1, sample_2):
        compare_samples(self, sample_1, sample_2)

    def _compare_states(self, state_1, state_2, check_heads=True):
        compare_states(self, state_1, state_2)

    def setUp(self):
        # create a base directory to write checkpoints to
        self.base_dir = tempfile.mkdtemp()

    def tearDown(self):
        # delete all the temporary data created
        shutil.rmtree(self.base_dir)

    def test_build_task(self):
        config = get_test_task_config()
        task = build_task(config)
        self.assertTrue(isinstance(task, ClassificationTask))

    def test_hooks_config_builds_correctly(self):
        config = get_test_task_config()
        config["hooks"] = [{"name": "loss_lr_meter_logging"}]
        task = build_task(config)
        self.assertTrue(len(task.hooks) == 1)
        self.assertTrue(isinstance(task.hooks[0], LossLrMeterLoggingHook))

    def test_get_state(self):
        config = get_test_task_config()
        loss = build_loss(config["loss"])
        task = (
            ClassificationTask().set_num_epochs(1).set_loss(loss).set_model(
                build_model(config["model"])).set_optimizer(
                    build_optimizer(config["optimizer"])))
        for phase_type in ["train", "test"]:
            dataset = build_dataset(config["dataset"][phase_type])
            task.set_dataset(dataset, phase_type)

        task.prepare()

        task = build_task(config)
        task.prepare()

    def test_synchronize_losses_non_distributed(self):
        """
        Tests that synchronize losses has no side effects in a non-distributed setting.
        """
        test_config = get_fast_test_task_config()
        task = build_task(test_config)
        task.prepare()

        old_losses = copy.deepcopy(task.losses)
        task.synchronize_losses()
        self.assertEqual(old_losses, task.losses)

    def test_synchronize_losses_when_losses_empty(self):
        config = get_fast_test_task_config()
        task = build_task(config)
        task.prepare()

        task.set_use_gpu(torch.cuda.is_available())

        # Losses should be empty when creating task
        self.assertEqual(len(task.losses), 0)

        task.synchronize_losses()

    def test_checkpointing(self):
        """
        Tests checkpointing by running train_steps to make sure the train_steps
        run the same way after loading from a checkpoint.
        """
        config = get_fast_test_task_config()
        task = build_task(config).set_hooks([LossLrMeterLoggingHook()])
        task_2 = build_task(config).set_hooks([LossLrMeterLoggingHook()])

        task.set_use_gpu(torch.cuda.is_available())

        # only train 1 phase at a time
        trainer = LimitedPhaseTrainer(num_phases=1)

        while not task.done_training():
            # set task's state as task_2's checkpoint
            task_2._set_checkpoint_dict(
                get_checkpoint_dict(task, {}, deep_copy=True))

            # task 2 should have the same state before training
            self._compare_states(task.get_classy_state(),
                                 task_2.get_classy_state())

            # train for one phase
            trainer.train(task)
            trainer.train(task_2)

            # task 2 should have the same state after training
            self._compare_states(task.get_classy_state(),
                                 task_2.get_classy_state())

    def test_final_train_checkpoint(self):
        """Test that a train phase checkpoint with a where of 1.0 can be loaded"""

        config = get_fast_test_task_config()
        task = build_task(config).set_hooks(
            [CheckpointHook(self.base_dir, {}, phase_types=["train"])])
        task_2 = build_task(config)

        task.set_use_gpu(torch.cuda.is_available())

        trainer = LocalTrainer()
        trainer.train(task)

        self.assertAlmostEqual(task.where, 1.0, delta=1e-3)

        # set task_2's state as task's final train checkpoint
        task_2.set_checkpoint(self.base_dir)
        task_2.prepare()

        # we should be able to train the task
        trainer.train(task_2)

    def test_test_only_checkpointing(self):
        """
        Tests checkpointing by running train_steps to make sure the
        train_steps run the same way after loading from a training
        task checkpoint on a test_only task.
        """
        train_config = get_fast_test_task_config()
        train_config["num_epochs"] = 10
        test_config = get_fast_test_task_config()
        test_config["test_only"] = True
        train_task = build_task(train_config).set_hooks(
            [LossLrMeterLoggingHook()])
        test_only_task = build_task(test_config).set_hooks(
            [LossLrMeterLoggingHook()])

        # prepare the tasks for the right device
        train_task.prepare()

        # test in both train and test mode
        trainer = LocalTrainer()
        trainer.train(train_task)

        # set task's state as task_2's checkpoint
        test_only_task._set_checkpoint_dict(
            get_checkpoint_dict(train_task, {}, deep_copy=True))
        test_only_task.prepare()
        test_state = test_only_task.get_classy_state()

        # We expect the phase idx to be different for a test only task
        self.assertEqual(test_state["phase_idx"], -1)

        # We expect the test-only state to be test, no matter what the train state is
        self.assertFalse(test_state["train"])

        # Num updates should be 0
        self.assertEqual(test_state["num_updates"], 0)

        # train_phase_idx should be -1
        self.assertEqual(test_state["train_phase_idx"], -1)

        # Verify task will run
        trainer = LocalTrainer()
        trainer.train(test_only_task)

    def test_test_only_task(self):
        """
        Tests the task in test mode by running train_steps
        to make sure the train_steps run as expected on a
        test_only task
        """
        test_config = get_fast_test_task_config()
        test_config["test_only"] = True

        # delete train dataset
        del test_config["dataset"]["train"]

        test_only_task = build_task(test_config).set_hooks(
            [LossLrMeterLoggingHook()])

        test_only_task.prepare()
        test_state = test_only_task.get_classy_state()

        # We expect the test-only state to be test, no matter what the train state is
        self.assertFalse(test_state["train"])

        # Num updates should be 0
        self.assertEqual(test_state["num_updates"], 0)

        # Verify task will run
        trainer = LocalTrainer()
        trainer.train(test_only_task)

    def test_train_only_task(self):
        """
        Tests that the task runs when only a train dataset is specified.
        """
        test_config = get_fast_test_task_config()

        # delete the test dataset from the config
        del test_config["dataset"]["test"]

        task = build_task(test_config).set_hooks([LossLrMeterLoggingHook()])
        task.prepare()

        # verify that the task can still be trained
        trainer = LocalTrainer()
        trainer.train(task)

    @unittest.skipUnless(torch.cuda.is_available(),
                         "This test needs a gpu to run")
    def test_checkpointing_different_device(self):
        config = get_fast_test_task_config()
        task = build_task(config)
        task_2 = build_task(config)

        for use_gpu in [True, False]:
            task.set_use_gpu(use_gpu)
            task.prepare()

            # set task's state as task_2's checkpoint
            task_2._set_checkpoint_dict(
                get_checkpoint_dict(task, {}, deep_copy=True))

            # we should be able to run the trainer using state from a different device
            trainer = LocalTrainer()
            task_2.set_use_gpu(not use_gpu)
            trainer.train(task_2)

    @unittest.skipUnless(is_distributed_training_run(),
                         "This test needs a distributed run")
    def test_get_classy_state_on_loss(self):
        config = get_fast_test_task_config()
        config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
        task = build_task(config)
        task.prepare()
        self.assertIn("alpha", task.get_classy_state()["loss"])
Example #15
class TestClassificationTask(unittest.TestCase):
    def _compare_model_state(self,
                             model_state_1,
                             model_state_2,
                             check_heads=True):
        compare_model_state(self, model_state_1, model_state_2, check_heads)

    def _compare_samples(self, sample_1, sample_2):
        compare_samples(self, sample_1, sample_2)

    def _compare_states(self, state_1, state_2, check_heads=True):
        compare_states(self, state_1, state_2)

    def setUp(self):
        # create a base directory to write checkpoints to
        self.base_dir = tempfile.mkdtemp()

    def tearDown(self):
        # delete all the temporary data created
        shutil.rmtree(self.base_dir)

    def test_build_task(self):
        config = get_test_task_config()
        task = build_task(config)
        self.assertTrue(isinstance(task, ClassificationTask))

    def test_hooks_config_builds_correctly(self):
        config = get_test_task_config()
        config["hooks"] = [{"name": "loss_lr_meter_logging"}]
        task = build_task(config)
        self.assertTrue(len(task.hooks) == 1)
        self.assertTrue(isinstance(task.hooks[0], LossLrMeterLoggingHook))

    def test_get_state(self):
        config = get_test_task_config()
        loss = build_loss(config["loss"])
        task = (
            ClassificationTask().set_num_epochs(1).set_loss(loss).set_model(
                build_model(config["model"])).set_optimizer(
                    build_optimizer(config["optimizer"])))
        for phase_type in ["train", "test"]:
            dataset = build_dataset(config["dataset"][phase_type])
            task.set_dataset(dataset, phase_type)

        task.prepare()

        task = build_task(config)
        task.prepare()

    def test_synchronize_losses_non_distributed(self):
        """
        Tests that synchronize losses has no side effects in a non-distributed setting.
        """
        test_config = get_fast_test_task_config()
        task = build_task(test_config)
        task.prepare()

        old_losses = copy.deepcopy(task.losses)
        task.synchronize_losses()
        self.assertEqual(old_losses, task.losses)

    def test_synchronize_losses_when_losses_empty(self):
        config = get_fast_test_task_config()
        task = build_task(config)
        task.prepare()

        task.set_use_gpu(torch.cuda.is_available())

        # Losses should be empty when creating task
        self.assertEqual(len(task.losses), 0)

        task.synchronize_losses()

    def test_checkpointing(self):
        """
        Tests checkpointing by running train_steps to make sure the train_steps
        run the same way after loading from a checkpoint.
        """
        config = get_fast_test_task_config()
        task = build_task(config).set_hooks([LossLrMeterLoggingHook()])
        task_2 = build_task(config).set_hooks([LossLrMeterLoggingHook()])

        task.set_use_gpu(torch.cuda.is_available())

        # only train 1 phase at a time
        trainer = LimitedPhaseTrainer(num_phases=1)

        while not task.done_training():
            # set task's state as task_2's checkpoint
            task_2._set_checkpoint_dict(
                get_checkpoint_dict(task, {}, deep_copy=True))

            # task 2 should have the same state before training
            self._compare_states(task.get_classy_state(),
                                 task_2.get_classy_state())

            # train for one phase
            trainer.train(task)
            trainer.train(task_2)

            # task 2 should have the same state after training
            self._compare_states(task.get_classy_state(),
                                 task_2.get_classy_state())

    def test_final_train_checkpoint(self):
        """Test that a train phase checkpoint with a where of 1.0 can be loaded"""

        config = get_fast_test_task_config()
        task = build_task(config).set_hooks(
            [CheckpointHook(self.base_dir, {}, phase_types=["train"])])
        task_2 = build_task(config)

        task.set_use_gpu(torch.cuda.is_available())

        trainer = LocalTrainer()
        trainer.train(task)

        self.assertAlmostEqual(task.where, 1.0, delta=1e-3)

        # set task_2's state as task's final train checkpoint
        task_2.set_checkpoint(self.base_dir)
        task_2.prepare()

        # we should be able to train the task
        trainer.train(task_2)

    def test_test_only_checkpointing(self):
        """
        Tests checkpointing by running train_steps to make sure the
        train_steps run the same way after loading from a training
        task checkpoint on a test_only task.
        """
        train_config = get_fast_test_task_config()
        train_config["num_epochs"] = 10
        test_config = get_fast_test_task_config()
        test_config["test_only"] = True
        train_task = build_task(train_config).set_hooks(
            [LossLrMeterLoggingHook()])
        test_only_task = build_task(test_config).set_hooks(
            [LossLrMeterLoggingHook()])

        # prepare the tasks for the right device
        train_task.prepare()

        # test in both train and test mode
        trainer = LocalTrainer()
        trainer.train(train_task)

        # set task's state as task_2's checkpoint
        test_only_task._set_checkpoint_dict(
            get_checkpoint_dict(train_task, {}, deep_copy=True))
        test_only_task.prepare()
        test_state = test_only_task.get_classy_state()

        # We expect the phase idx to be different for a test only task
        self.assertEqual(test_state["phase_idx"], -1)

        # We expect the test-only state to be test, no matter what the train state is
        self.assertFalse(test_state["train"])

        # Num updates should be 0
        self.assertEqual(test_state["num_updates"], 0)

        # train_phase_idx should be -1
        self.assertEqual(test_state["train_phase_idx"], -1)

        # Verify task will run
        trainer = LocalTrainer()
        trainer.train(test_only_task)

    def test_test_only_task(self):
        """
        Tests the task in test mode by running train_steps
        to make sure the train_steps run as expected on a
        test_only task
        """
        test_config = get_fast_test_task_config()
        test_config["test_only"] = True

        # delete train dataset
        del test_config["dataset"]["train"]

        test_only_task = build_task(test_config).set_hooks(
            [LossLrMeterLoggingHook()])

        test_only_task.prepare()
        test_state = test_only_task.get_classy_state()

        # We expect the test-only state to be test, no matter what the train state is
        self.assertFalse(test_state["train"])

        # Num updates should be 0
        self.assertEqual(test_state["num_updates"], 0)

        # Verify task will run
        trainer = LocalTrainer()
        trainer.train(test_only_task)

    def test_train_only_task(self):
        """
        Tests that the task runs when only a train dataset is specified.
        """
        test_config = get_fast_test_task_config()

        # delete the test dataset from the config
        del test_config["dataset"]["test"]

        task = build_task(test_config).set_hooks([LossLrMeterLoggingHook()])
        task.prepare()

        # verify that the task can still be trained
        trainer = LocalTrainer()
        trainer.train(task)

    @unittest.skipUnless(torch.cuda.is_available(),
                         "This test needs a gpu to run")
    def test_checkpointing_different_device(self):
        config = get_fast_test_task_config()
        task = build_task(config)
        task_2 = build_task(config)

        for use_gpu in [True, False]:
            task.set_use_gpu(use_gpu)
            task.prepare()

            # set task's state as task_2's checkpoint
            task_2._set_checkpoint_dict(
                get_checkpoint_dict(task, {}, deep_copy=True))

            # we should be able to run the trainer using state from a different device
            trainer = LocalTrainer()
            task_2.set_use_gpu(not use_gpu)
            trainer.train(task_2)

    @unittest.skipUnless(is_distributed_training_run(),
                         "This test needs a distributed run")
    def test_get_classy_state_on_loss(self):
        config = get_fast_test_task_config()
        config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
        task = build_task(config)
        task.prepare()
        self.assertIn("alpha", task.get_classy_state()["loss"])

    def test_gradient_clipping(self):
        apex_available = True
        try:
            import apex  # noqa F401
        except ImportError:
            apex_available = False

        def train_with_clipped_gradients(amp_args=None):
            task = build_task(get_fast_test_task_config())
            task.set_num_epochs(1)
            task.set_model(SimpleModel())
            task.set_loss(SimpleLoss())
            task.set_meters([])
            task.set_use_gpu(torch.cuda.is_available())
            task.set_clip_grad_norm(0.5)
            task.set_amp_args(amp_args)

            task.set_optimizer(SGD(lr=1))

            trainer = LocalTrainer()
            trainer.train(task)

            return task.model.param.grad.norm()

        grad_norm = train_with_clipped_gradients(None)
        self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)

        if apex_available and torch.cuda.is_available():
            grad_norm = train_with_clipped_gradients({"opt_level": "O2"})
            self.assertAlmostEqual(grad_norm, 0.5, delta=1e-2)

    def test_clip_stateful_loss(self):
        config = get_fast_test_task_config()
        config["loss"] = {"name": "test_stateful_loss", "in_plane": 256}
        config["grad_norm_clip"] = grad_norm_clip = 1
        task = build_task(config)
        task.set_use_gpu(False)
        task.prepare()

        # set fake gradients with norm > grad_norm_clip
        for param in itertools.chain(task.base_model.parameters(),
                                     task.base_loss.parameters()):
            param.grad = 1.1 + torch.rand(param.shape)
            self.assertGreater(param.grad.norm(), grad_norm_clip)

        task._clip_gradients(grad_norm_clip)

        for param in itertools.chain(task.base_model.parameters(),
                                     task.base_loss.parameters()):
            self.assertLessEqual(param.grad.norm(), grad_norm_clip)

    # helper used by gradient accumulation tests
    def train_with_batch(self, simulated_bs, actual_bs, clip_grad_norm=None):
        config = copy.deepcopy(get_fast_test_task_config())
        config["dataset"]["train"]["num_samples"] = 12
        config["dataset"]["train"]["batchsize_per_replica"] = actual_bs
        del config["dataset"]["test"]

        task = build_task(config)
        task.set_num_epochs(1)
        task.set_model(SimpleModel())
        task.set_loss(SimpleLoss())
        task.set_meters([])
        task.set_use_gpu(torch.cuda.is_available())
        if simulated_bs is not None:
            task.set_simulated_global_batchsize(simulated_bs)
        if clip_grad_norm is not None:
            task.set_clip_grad_norm(clip_grad_norm)

        task.set_optimizer(SGD(lr=1))

        trainer = LocalTrainer()
        trainer.train(task)

        return task.model.param

    def test_gradient_accumulation(self):
        param_with_accumulation = self.train_with_batch(simulated_bs=4,
                                                        actual_bs=2)
        param = self.train_with_batch(simulated_bs=4, actual_bs=4)

        self.assertAlmostEqual(param_with_accumulation, param, delta=1e-5)

    def test_gradient_accumulation_and_clipping(self):
        param = self.train_with_batch(simulated_bs=6,
                                      actual_bs=2,
                                      clip_grad_norm=0.1)

        # param starts at 5, it has to decrease, LR = 1
        # clipping the grad to 0.1 means we drop 0.1 per update. num_samples =
        # 12 and the simulated batch size is 6, so we should do 2 updates: 5 ->
        # 4.9 -> 4.8
        self.assertAlmostEqual(param, 4.8, delta=1e-5)

    @unittest.skipIf(
        get_torch_version() < [1, 8],
        "FP16 Grad compression is only available from PyTorch 1.8",
    )
    def test_fp16_grad_compression(self):
        # there is no API defined to check that a DDP hook has been enabled, so we just
        # test that we set the right variables
        config = copy.deepcopy(get_fast_test_task_config())
        task = build_task(config)
        self.assertFalse(task.fp16_grad_compress)

        config.setdefault("distributed", {})
        config["distributed"]["fp16_grad_compress"] = True

        task = build_task(config)
        self.assertTrue(task.fp16_grad_compress)
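
The gradient clipping tests above check that the gradient norm ends up at the configured threshold; in plain PyTorch the underlying mechanism is torch.nn.utils.clip_grad_norm_, which rescales gradients so their total norm does not exceed the limit. A standalone sketch of that behaviour:

import torch

param = torch.nn.Parameter(torch.ones(4))
param.grad = torch.full((4,), 2.0)                # gradient norm is 4.0 before clipping
torch.nn.utils.clip_grad_norm_([param], max_norm=0.5)
print(param.grad.norm())                          # ~0.5 after clipping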