Example #1
    def _init_pytorch_grad_scaler(self):
        assert is_fairscale_sharded_available(), (
            "To use FSDP with PyTorch AMP, ShardedGradScaler() "
            "from fairscale is needed. Please upgrade fairscale")
        from fairscale.optim.grad_scaler import ShardedGradScaler

        self.amp_grad_scaler = ShardedGradScaler()
        logging.info("Setting AMP: using ShardedGradScaler")
Example #2
    def set_amp_args(self):
        """
        Two automatic mixed precision implementations are available: Apex's and PyTorch's.

        - If Apex's AMP is enabled, amp_args is a dictionary containing arguments
        to be passed to amp.initialize. Set to None to disable amp.
        To enable mixed precision training, pass amp_args={"opt_level": "O1"} here.
        See https://nvidia.github.io/apex/amp.html for more info.

        - If PyTorch's AMP is enabled, no arguments are needed.
        """

        if self.config.MODEL.AMP_PARAMS.USE_AMP:
            assert (
                self.device.type == "cuda"
            ), "Mixed precision is only available on CUDA devices for now"

            # This will rightly fail if the setting is not correct
            self.amp_type = AmpType[self.config.MODEL.AMP_PARAMS.AMP_TYPE.upper()]

            # Check Apex availability
            if self.amp_type == AmpType.APEX:
                if not is_apex_available():
                    raise RuntimeError(
                        "Apex is not available. Can't use mixed precision"
                    )

                # "amp_args" are actually Apex Amp args
                self.amp_args = self.config.MODEL.AMP_PARAMS.AMP_ARGS
                logging.info(f"Setting AMP: using apex, args {self.amp_args}")

            elif self.amp_type == AmpType.PYTORCH:
                # if the optimizer is sharded or FSDP data parallel is used, then the GradScaler
                # needs to be shard-aware.
                if (
                    self.config["TRAINER"]["TASK_NAME"] == "self_supervision_fsdp_task"
                    or self.config["OPTIMIZER"]["name"] == "zero"
                ):
                    assert is_fairscale_sharded_available(), (
                        "To use ZeRO with PyTorch AMP, ShardedGradScaler() "
                        "from fairscale is needed. Please upgrade fairscale"
                    )
                    from fairscale.optim.grad_scaler import ShardedGradScaler

                    self.amp_grad_scaler = ShardedGradScaler()
                    logging.info("Setting AMP: using sharded grad scaler")
                else:
                    self.amp_grad_scaler = TorchGradScaler()
                    logging.info("Setting AMP: using pytorch grad scaler")
            logging.info(f"Setting AMP: {self.amp_type} - args: {self.amp_args}")

        else:
            self.amp_args, self.amp_type = None, None
            logging.info("Not using Automatic Mixed Precision")
Example #3
    def _init_pytorch_grad_scaler(self):
        if self.config["OPTIMIZER"]["name"] == "zero":
            assert is_fairscale_sharded_available(), (
                "To use ZeRO with PyTorch AMP, ShardedGradScaler() "
                "from fairscale is needed. Please upgrade fairscale")
            from fairscale.optim.grad_scaler import ShardedGradScaler

            self.amp_grad_scaler = ShardedGradScaler()
            logging.info("Setting AMP: using sharded grad scaler")
        else:
            self.amp_grad_scaler = TorchGradScaler()
            logging.info("Setting AMP: using pytorch grad scaler")