import torch

try:
    # Module-level availability flag for DeepSpeed's CPU Adam kernel.
    import deepspeed  # noqa: F401

    has_deepspeed_cpu_adam = True
except ImportError:
    has_deepspeed_cpu_adam = False


def _get_cpu_adam():
    try:
        from deepspeed.ops.op_builder import CPUAdamBuilder

        return CPUAdamBuilder().load()
    except ImportError:
        # fbcode
        from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam

        return ds_opt_adam
class CPUAdam(torch.optim.Optimizer):

    optimizer_id = 0

    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        use_fp16_stats=False,
    ):
        defaults = {
            "lr": lr,
            "bias_correction": bias_correction,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

        self.use_fp16_stats = use_fp16_stats
        self.FLOAT16_MAX = 65504.0

        if not has_deepspeed_cpu_adam:
            raise ImportError("Please install DeepSpeed: pip install deepspeed")

        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1

        self.ds_opt_adam = _get_cpu_adam()
        adamw_mode = True
        self.ds_opt_adam.create_adam(
            self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode
        )

    @property
    def supports_flat_params(self):
        return True

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    dtype = torch.float16 if self.use_fp16_stats else p.data.dtype
                    # gradient momentums
                    state["exp_avg"] = torch.zeros_like(
                        p.data, dtype=dtype, device="cpu"
                    )
                    # gradient variances
                    state["exp_avg_sq"] = torch.zeros_like(
                        p.data, dtype=dtype, device="cpu"
                    )
                    if self.use_fp16_stats:
                        assert torch.is_floating_point(p.data)
                        state["exp_avg_scale"] = 1.0
                        state["exp_avg_sq_scale"] = 1.0

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                p_data_bak = p.data  # backup of the original data pointer

                # the DeepSpeed CPU kernel operates on fp32 CPU tensors
                p.data = p.data.to(dtype=torch.float32, device="cpu")
                p.grad.data = p.grad.data.to(dtype=torch.float32, device="cpu")

                if self.use_fp16_stats:
                    # dequantize fp16 optimizer state before the update
                    exp_avg = exp_avg.float() * state["exp_avg_scale"]
                    exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"]

                state["step"] += 1
                beta1, beta2 = group["betas"]

                self.ds_opt_adam.adam_update(
                    self.opt_id,
                    state["step"],
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    group["weight_decay"],
                    group["bias_correction"],
                    p.data,
                    p.grad.data,
                    exp_avg,
                    exp_avg_sq,
                )

                # copy the updated values back if the parameter was moved/cast
                if p_data_bak.data_ptr() != p.data.data_ptr():
                    p_data_bak.copy_(p.data)
                    p.data = p_data_bak

                if self.use_fp16_stats:

                    def inf_norm(t):
                        return torch.norm(t, float("inf"))

                    # from github.com/openai/jukebox/blob/master/jukebox/utils/fp16.py
                    state["exp_avg_scale"], state["exp_avg_sq_scale"] = (
                        1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX,
                        1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX,
                    )
                    state["exp_avg"], state["exp_avg_sq"] = (
                        (exp_avg / state["exp_avg_scale"]).half(),
                        (exp_avg_sq / state["exp_avg_sq_scale"]).half(),
                    )

        return loss
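

# Minimal usage sketch (not part of the original module): exercises CPUAdam on a
# toy CPU model. It assumes DeepSpeed is installed; the model, shapes, and
# hyperparameters below are illustrative only.
if __name__ == "__main__":
    model = torch.nn.Linear(16, 4)  # parameters live on CPU
    optimizer = CPUAdam(model.parameters(), lr=1e-3, weight_decay=0.01)
    loss = model(torch.randn(8, 16)).sum()
    loss.backward()
    optimizer.step()  # Adam/AdamW update runs in DeepSpeed's CPU kernel
    optimizer.zero_grad()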