Example #1
    def init_distributed_data_parallel_model(self):
        """
        Initialize
        `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/
        docs/stable/nn.html#distributeddataparallel>`_.

        Needed for distributed training. This is where a model should be wrapped by DDP.
        """
        if not is_distributed_training_run():
            return
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS
        )

        if self.use_sharded_ddp:
            if not isinstance(self.optimizer, ZeRO):
                raise ValueError(
                    "ShardedDataParallel engine should only be used in conjunction with ZeRO optimizer"
                )
            from fairscale.nn.data_parallel import ShardedDataParallel

            # Replace the original DDP wrap by the shard-aware ShardedDDP
            self.distributed_model = ShardedDataParallel(
                module=self.base_model,
                sharded_optimizer=self.optimizer.optimizer,
                broadcast_buffers=broadcast_buffers,
            )
        else:
            self.distributed_model = init_distributed_data_parallel_model(
                self.base_model,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )
            if self.fp16_grad_compress:
                from torch.distributed.algorithms import ddp_comm_hooks

                # FP16 hook is stateless and only takes a process group as the state.
                # We use the default process group so we set the state to None.
                process_group = None
                self.distributed_model.register_comm_hook(
                    process_group,
                    ddp_comm_hooks.default_hooks.fp16_compress_hook,
                )
        if (
            isinstance(self.base_loss, ClassyLoss)
            and self.base_loss.has_learned_parameters()
        ):
            logging.info("Initializing distributed loss")
            self.distributed_loss = init_distributed_data_parallel_model(
                self.base_loss,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
                bucket_cap_mb=self.ddp_bucket_cap_mb,
            )
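
The FP16 gradient-compression branch above uses PyTorch's built-in DDP communication hooks. Below is a minimal, self-contained sketch of the same registration pattern; the toy model and its dimensions are illustrative, and it assumes torch.distributed.init_process_group() has already been called.

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.distributed.algorithms import ddp_comm_hooks

# Wrap a toy model in DDP (assumes the default process group is initialized).
model = nn.Linear(128, 10).cuda()
ddp_model = DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])

# The FP16 hook is stateless; passing None as the state uses the default
# process group, exactly as in the example above.
ddp_model.register_comm_hook(None, ddp_comm_hooks.default_hooks.fp16_compress_hook)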
Example #2
    def init_distributed_data_parallel_model(self):
        """
        Initialize
        `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/
        docs/stable/nn.html#distributeddataparallel>`_.

        Needed for distributed training. This is where a model should be wrapped by DDP.
        """
        if not is_distributed_training_run():
            return
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS)
        self.distributed_model = init_distributed_data_parallel_model(
            self.base_model,
            broadcast_buffers=broadcast_buffers,
            find_unused_parameters=self.find_unused_parameters,
        )
        if (isinstance(self.base_loss, ClassyLoss)
                and self.base_loss.has_learned_parameters()):
            logging.info("Initializing distributed loss")
            self.distributed_loss = init_distributed_data_parallel_model(
                self.base_loss,
                broadcast_buffers=broadcast_buffers,
                find_unused_parameters=self.find_unused_parameters,
            )
Example #3
    def _build_momentum_network(self, task: tasks.ClassyTask) -> None:
        """
        Create the teacher: it is an exponential moving average of the student.
        """
        logging.info("Building momentum encoder")

        # Same architecture as the student, but do not apply stochastic depth
        task.config["MODEL"]["TRUNK"]["VISION_TRANSFORMERS"][
            "DROP_PATH_RATE"] = 0
        task.loss.momentum_teacher = build_model(task.config["MODEL"],
                                                 task.config["OPTIMIZER"])
        task.loss.momentum_teacher = nn.SyncBatchNorm.convert_sync_batchnorm(
            task.loss.momentum_teacher)
        task.loss.momentum_teacher.to(task.device)

        if get_world_size() > 1:
            task.loss.momentum_teacher = init_distributed_data_parallel_model(
                task.loss.momentum_teacher)

        # Restore a checkpoint if one is available
        if task.loss.checkpoint is not None:
            task.loss.load_state_dict(task.loss.checkpoint)
        # Initialize from the model
        else:
            task.loss.momentum_teacher.load_state_dict(task.model.state_dict())
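
The momentum teacher built above is meant to track the student as an exponential moving average during training. A hedged sketch of that update step is shown below; the function name, the momentum value, and the student/teacher arguments are assumptions for illustration and are not part of the example above.

import torch

@torch.no_grad()
def ema_update(student, teacher, momentum=0.996):
    # teacher = momentum * teacher + (1 - momentum) * student, parameter by parameter
    for param_q, param_k in zip(student.parameters(), teacher.parameters()):
        param_k.data.mul_(momentum).add_(param_q.data, alpha=1.0 - momentum)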
Example #4
    def _build_momentum_network(self, task: tasks.ClassyTask) -> None:
        """
        Create the model replica called the encoder. This will slowly track
        the main model.
        """
        logging.info("Building momentum encoder - rank %s %s",
                     *get_machine_local_and_dist_rank())

        # Same architecture as the base model
        task.loss.momentum_encoder = build_model(task.config["MODEL"],
                                                 task.config["OPTIMIZER"])
        task.loss.momentum_encoder = nn.SyncBatchNorm.convert_sync_batchnorm(
            task.loss.momentum_encoder)
        task.loss.momentum_encoder.to(
            torch.device("cuda" if task.use_gpu else "cpu"))

        # Initialize from the model
        if task.loss.checkpoint is None:
            for param_q, param_k in zip(
                    task.base_model.parameters(),
                    task.loss.momentum_encoder.parameters()):
                param_k.data.copy_(param_q.data)
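            # Copy only the running statistics buffers (e.g. BatchNorm running mean/var)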
            for buff_q, buff_k in zip(
                    task.base_model.named_buffers(),
                    task.loss.momentum_encoder.named_buffers(),
            ):
                if "running_" not in buff_k[0]:
                    continue
                buff_k[1].data.copy_(buff_q[1].data)
        task.loss.momentum_encoder = init_distributed_data_parallel_model(
            task.loss.momentum_encoder)

        # Restore a checkpoint if one is available
        if task.loss.checkpoint is not None:
            task.loss.load_state_dict(task.loss.checkpoint)
Example #5
    def init_distributed_data_parallel_model(self):
        """Sets up distributed dataparallel and wraps model in DDP
        """
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        self.distributed_model = init_distributed_data_parallel_model(
            self.base_model)
Example #6
    def init_distributed_data_parallel_model(self):
        """Sets up distributed dataparallel and wraps model in DDP
        """
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS)
        self.distributed_model = init_distributed_data_parallel_model(
            self.base_model, broadcast_buffers=broadcast_buffers)
Example #7
    def init_distributed_data_parallel_model(self):
        """
        Initialize
        `torch.nn.parallel.distributed.DistributedDataParallel <https://pytorch.org/
        docs/stable/nn.html#distributeddataparallel>`_.

        Needed for distributed training. This is where a model should be wrapped by DDP.
        """
        assert (
            self.distributed_model is None
        ), "init_ddp_non_elastic must only be called once"

        broadcast_buffers = (
            self.broadcast_buffers_mode == BroadcastBuffersMode.FORWARD_PASS)
        self.distributed_model = init_distributed_data_parallel_model(
            self.base_model, broadcast_buffers=broadcast_buffers)
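
All of the examples above delegate the actual wrapping to an init_distributed_data_parallel_model helper, whose implementation is not shown here. A minimal sketch of what such a helper could look like on top of plain PyTorch DDP follows; the argument names mirror the call sites above, while the defaults and device handling are assumptions.

import torch
from torch.nn.parallel import DistributedDataParallel


def init_distributed_data_parallel_model(
    model,
    broadcast_buffers=False,
    find_unused_parameters=True,
    bucket_cap_mb=25,
):
    # Assumes the default process group is initialized and the model already
    # lives on the target device.
    device_ids = None
    if torch.cuda.is_available():
        device_ids = [torch.cuda.current_device()]
    return DistributedDataParallel(
        model,
        device_ids=device_ids,
        broadcast_buffers=broadcast_buffers,
        find_unused_parameters=find_unused_parameters,
        bucket_cap_mb=bucket_cap_mb,
    )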