Example #1
    def __init__(
            self,
            params,
            lr=1e-3,
            bias_correction=True,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0,
            use_fp16_stats=False,
    ):
        defaults = {
            "lr": lr,
            "bias_correction": bias_correction,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

        self.use_fp16_stats = use_fp16_stats
        self.FLOAT16_MAX = 65504.0

        if not has_deepspeed_cpu_adam:
            raise ImportError(
                "Please install DeepSpeed: pip install deepspeed")

        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1

        self.ds_opt_adam = CPUAdamBuilder().load()
        adamw_mode = True
        self.ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps,
                                     weight_decay, adamw_mode)
Example #2
def _get_cpu_adam():
    """Return the DeepSpeed CPU Adam implementation, preferring the JIT-built op."""
    try:
        from deepspeed.ops.op_builder import CPUAdamBuilder
        return CPUAdamBuilder().load()
    except ImportError:
        # fbcode path: fall back to the DeepSpeedCPUAdam class shipped in deepspeed.ops.adam
        from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam
        return ds_opt_adam
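A hedged usage sketch for the helper above (not part of the original snippet): it assumes the op-builder branch succeeded, so the returned object is the compiled extension exposing create_adam/adam_update; the values passed are placeholders mirroring the create_adam call in Example #3.

ds_opt_adam = _get_cpu_adam()

# Register one optimizer instance with the extension; the argument order
# mirrors Example #3: (opt_id, lr, beta1, beta2, eps, weight_decay, adamw_mode).
ds_opt_adam.create_adam(0, 1e-3, 0.9, 0.999, 1e-8, 0.0, True)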
Example #3
import torch

try:
    # Probe for the DeepSpeed CPU Adam op builder; the flag is checked in __init__.
    from deepspeed.ops.op_builder import CPUAdamBuilder
    has_deepspeed_cpu_adam = True
except ImportError:
    has_deepspeed_cpu_adam = False


class CPUAdam(torch.optim.Optimizer):

    optimizer_id = 0

    def __init__(
            self,
            params,
            lr=1e-3,
            bias_correction=True,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0,
            use_fp16_stats=False,
    ):
        defaults = {
            "lr": lr,
            "bias_correction": bias_correction,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

        self.use_fp16_stats = use_fp16_stats
        # Largest finite float16 value; used below to rescale fp16 optimizer state.
        self.FLOAT16_MAX = 65504.0

        if not has_deepspeed_cpu_adam:
            raise ImportError(
                "Please install DeepSpeed: pip install deepspeed")

        # Each instance gets a unique id so the DeepSpeed extension can keep
        # separate state for it.
        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1

        # Load (JIT-building if necessary) the DeepSpeed CPU Adam extension and
        # register this instance with it; adamw_mode=True selects decoupled
        # (AdamW-style) weight decay.
        self.ds_opt_adam = CPUAdamBuilder().load()
        adamw_mode = True
        self.ds_opt_adam.create_adam(self.opt_id, lr, betas[0], betas[1], eps,
                                     weight_decay, adamw_mode)

    @property
    def supports_flat_params(self):
        return True

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    dtype = torch.float16 if self.use_fp16_stats else p.data.dtype
                    # gradient momentums
                    state["exp_avg"] = torch.zeros_like(p.data,
                                                        dtype=dtype,
                                                        device="cpu")
                    # gradient variances
                    state["exp_avg_sq"] = torch.zeros_like(p.data,
                                                           dtype=dtype,
                                                           device="cpu")
                    if self.use_fp16_stats:
                        assert torch.is_floating_point(p.data)
                        state["exp_avg_scale"] = 1.0
                        state["exp_avg_sq_scale"] = 1.0

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                p_data_bak = p.data  # backup of the original data pointer

                # Run the fused kernel on fp32 CPU tensors; .to() is a no-op if
                # the tensor is already fp32 on the CPU.
                p.data = p.data.to(dtype=torch.float32, device="cpu")
                p.grad.data = p.grad.data.to(dtype=torch.float32, device="cpu")

                if self.use_fp16_stats:
                    # Dequantize the fp16 optimizer state to fp32 for the update.
                    exp_avg = exp_avg.float() * state["exp_avg_scale"]
                    exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"]

                state["step"] += 1
                beta1, beta2 = group["betas"]

                self.ds_opt_adam.adam_update(
                    self.opt_id,
                    state["step"],
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    group["weight_decay"],
                    group["bias_correction"],
                    p.data,
                    p.grad.data,
                    exp_avg,
                    exp_avg_sq,
                )

                # If the fp32 cast produced a new tensor, copy the updated values
                # back into the original parameter storage.
                if p_data_bak.data_ptr() != p.data.data_ptr():
                    p_data_bak.copy_(p.data)
                    p.data = p_data_bak

                if self.use_fp16_stats:

                    def inf_norm(t):
                        return torch.norm(t, float("inf"))

                    # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py
                    state["exp_avg_scale"], state["exp_avg_sq_scale"] = (
                        1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX,
                        1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX,
                    )
                    state["exp_avg"], state["exp_avg_sq"] = (
                        (exp_avg / state["exp_avg_scale"]).half(),
                        (exp_avg_sq / state["exp_avg_sq_scale"]).half(),
                    )

        return loss
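A minimal end-to-end sketch of driving the class above, assuming DeepSpeed is installed and the module-level guard at the top of Example #3 has run; the toy model, shapes, and hyperparameters are illustrative placeholders.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy CPU model; CPUAdam keeps its exp_avg / exp_avg_sq state on the CPU.
model = nn.Linear(16, 4)
optimizer = CPUAdam(model.parameters(), lr=1e-3, weight_decay=0.01)

x = torch.randn(8, 16)
target = torch.randn(8, 4)

loss = F.mse_loss(model(x), target)

optimizer.zero_grad()
loss.backward()
optimizer.step()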