class FP16_Optimizer(object):
    """
    FP16 Optimizer for training fp16 models. Handles loss scaling.

    For each param group of the wrapped optimizer this class keeps:
      * a flat fp16 buffer whose slices back the model's fp16 parameters
        (``fp16_groups_flat``), and
      * a flat fp32 master copy (``fp32_groups_flat``) that replaces the
        group's ``params`` in the wrapped optimizer, so the optimizer updates
        fp32 weights which are then copied back into the fp16 model weights.

    Loss scaling is either static (``static_loss_scale``) or dynamic
    (halved on overflow, periodically doubled after ``scale_window``
    overflow-free iterations).

    For usage example please see, TODO: DeepSpeed V2 Tutorial
    """
    def __init__(self,
                 init_optimizer,
                 deepspeed=None,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 initial_dynamic_scale=2**32,
                 dynamic_loss_args=None,
                 verbose=True,
                 mpu=None,
                 clip_grad=0.0,
                 fused_adam_legacy=False,
                 timers=None):
        """
        Wrap ``init_optimizer`` for fp16 training.

        Args:
            init_optimizer: optimizer over the fp16 model params. Each of its
                ``param_groups[i]['params']`` is replaced in place by a
                one-element list holding that group's flat fp32 master tensor.
            deepspeed: forwarded to ``CheckOverflow``.
            static_loss_scale: loss scale used when ``dynamic_loss_scale``
                is False.
            dynamic_loss_scale: enable dynamic loss scaling.
            initial_dynamic_scale: starting scale for dynamic scaling when
                ``dynamic_loss_args`` is None.
            dynamic_loss_args: optional mapping read with the
                ``INITIAL_LOSS_SCALE``, ``SCALE_WINDOW`` and
                ``MIN_LOSS_SCALE`` keys.
            verbose: log overflow / loss-scale changes.
            mpu: model-parallel object forwarded to norm helpers and
                ``CheckOverflow``.
            clip_grad: maximum total gradient norm; ``<= 0`` disables clipping.
            fused_adam_legacy: route ``step()`` to ``step_fused_adam()``.
            timers: optional timer factory, used as ``timers(name).start()``,
                ``timers(name).stop()`` and ``timers.log(names)``.

        Raises:
            SystemError: if CUDA is not available.
        """
        self.fused_adam_legacy = fused_adam_legacy
        self.timers = timers

        # ``is_available`` must be *called*: the bare function object is always
        # truthy, so the previous ``not torch.cuda.is_available`` never fired.
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")
        self.optimizer = init_optimizer

        # param flattened by groups
        self.fp16_groups = []
        self.fp16_groups_flat = []
        self.fp32_groups_flat = []

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # push this group to list before modify
            self.fp16_groups.append(param_group['params'])
            # init fp16 weight buffer, flattened
            self.fp16_groups_flat.append(
                _flatten_dense_tensors(
                    [p.clone().detach() for p in self.fp16_groups[i]]))
            # set model fp16 weight to slices of flattened buffer
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
            # init master weight, flattened
            self.fp32_groups_flat.append(
                self.fp16_groups_flat[i].clone().float().detach())
            # modify optimizer to have flat master weight
            # (requires_grad kept in case the internal optimizer uses it)
            self.fp32_groups_flat[i].requires_grad = True
            param_group['params'] = [self.fp32_groups_flat[i]]

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2

            if dynamic_loss_args is None:
                self.cur_scale = initial_dynamic_scale
                self.scale_window = 1000
                self.min_loss_scale = 1
            else:
                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale
        self.verbose = verbose

        self.clip_grad = clip_grad
        self.norm_type = 2

        # torch <= 0.4 only provides the non-underscore clip_grad_norm.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        # model parallel object
        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups,
                                              mpu=self.mpu,
                                              deepspeed=deepspeed)

        self.initialize_optimizer_states()

    def initialize_optimizer_states(self):
        """Run one optimizer step on all-zero fp32 grads, then drop the grads.

        This lets the wrapped optimizer create its per-parameter state up
        front (presumably so checkpoint loading finds it; not verified here).
        """
        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = torch.zeros(
                self.fp32_groups_flat[i].size(),
                device=self.fp32_groups_flat[i].device)

        self.optimizer.step()

        for i, group in enumerate(self.fp16_groups):
            self.fp32_groups_flat[i].grad = None

        return

    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.

        Args:
            set_grads_to_None: if True (default, faster) set ``.grad`` to
                None; otherwise detach and zero existing grads in place.
        """
        # For speed, set model fp16 grad to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def step_fused_adam(self, closure=None):
        """
        Not supporting closure.

        Step path for a legacy fused Adam. Assumes (not verifiable here) that
        its ``step`` accepts ``grads``, ``output_params``, ``scale`` and
        ``grad_norms`` keyword arguments and writes results into the fp16 flat
        buffers passed as ``output_params``.

        Returns:
            True if a gradient overflow was detected and the step skipped.
        """
        # First compute norm for all group so we know if there is overflow
        grads_groups_flat = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(),
                                dtype=p.dtype,
                                device=p.device) if p.grad is None else p.grad
                    for p in group
                ]))
            norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)

        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        # The fused kernel applies the unscale itself, so only compute it here.
        combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
                                                     norm_groups,
                                                     apply_scale=False)
        # norm is in fact norm*cur_scale
        self.optimizer.step(grads=[[g] for g in grads_groups_flat],
                            output_params=[[p] for p in self.fp16_groups_flat],
                            scale=combined_scale,
                            grad_norms=norm_groups)
        # TODO: we probably don't need this? just to be safe
        for i in range(len(norm_groups)):
            updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data = q.data
        return self.overflow

    def start_timers(self, name_list):
        """Start each named timer, if timers were supplied."""
        if self.timers is not None:
            for name in name_list:
                self.timers(name).start()

    def stop_timers(self, name_list):
        """Stop each named timer, if timers were supplied."""
        if self.timers is not None:
            for name in name_list:
                self.timers(name).stop()

    def log_timers(self, name_list):
        """Log the named timers, if timers were supplied."""
        if self.timers is not None:
            self.timers.log(name_list)

    def step(self, closure=None):
        """
        Not supporting closure.

        Check the fp16 grads for overflow and update the loss scale. On
        overflow, clear the fp16 grads and skip the update. Otherwise, copy
        the grads into the fp32 master buffers, unscale/clip them, step the
        wrapped optimizer, and copy the fp32 weights back into the fp16 params.

        Returns:
            True if a gradient overflow was detected and the step skipped.
        """
        if self.fused_adam_legacy:
            return self.step_fused_adam()

        COMPUTE_NORM = "compute_norm"
        OVERFLOW_CHECK = 'overflow_check'
        OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
        UNSCALE_AND_CLIP = 'unscale_and_clip'
        BASIC_STEP = 'basic_step'
        UPDATE_FP16 = 'update_fp16'
        STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]

        # First determine if there is overflow.
        self.start_timers([OVERFLOW_CHECK])
        fp16_params = []
        for i, group in enumerate(self.fp16_groups):
            fp16_params.extend([p for p in group if p.grad is not None])
        self.overflow = self.overflow_checker.has_overflow(fp16_params)
        self.stop_timers([OVERFLOW_CHECK])
        prev_scale = self.cur_scale
        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                log_dist(
                    "Overflow detected. Skipping step. Attempted loss "
                    f"scale: {prev_scale}, reducing to {self.cur_scale}",
                    ranks=[0])
            # Clear gradients
            for i, group in enumerate(self.fp16_groups):
                for p in group:
                    p.grad = None

            self.log_timers(OVERFLOW_TIMERS)
            return self.overflow

        grads_groups_flat = []
        for i, group in enumerate(self.fp16_groups):
            data_type = self.fp32_groups_flat[i].dtype

            # Missing grads are filled with zeros so the flat layout matches
            # the fp32 master buffer.
            grads_groups_flat.append(
                _flatten_dense_tensors([
                    torch.zeros(p.size(),
                                dtype=data_type,
                                device=p.device)
                    if p.grad is None else p.grad.to(data_type) for p in group
                ]))

            for p in group:
                p.grad = None

            self.fp32_groups_flat[i].grad = grads_groups_flat[i]

        self.start_timers([COMPUTE_NORM])
        all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
        self.stop_timers([COMPUTE_NORM])

        self.start_timers([UNSCALE_AND_CLIP])
        self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
        self.stop_timers([UNSCALE_AND_CLIP])

        self.start_timers([BASIC_STEP])
        self.optimizer.step()
        self.stop_timers([BASIC_STEP])

        # get rid of the fp32 gradients. Not needed anymore
        for group in self.fp32_groups_flat:
            group.grad = None

        self.start_timers([UPDATE_FP16])
        for i in range(len(self.fp16_groups)):
            updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
                                                      self.fp16_groups[i])
            for p, q in zip(self.fp16_groups[i], updated_params):
                p.data.copy_(q.data)
        self.stop_timers([UPDATE_FP16])

        self.log_timers(STEP_TIMERS)

        return self.overflow

    def unscale_and_clip_grads(self, grad_groups_flat, norm_groups, apply_scale=True):
        """Compute (and optionally apply) the combined unscale/clip divisor.

        Args:
            grad_groups_flat: gradient tensors to divide in place.
            norm_groups: norms of the *scaled* gradients; they are combined as
                sqrt(sum(norm**2)).
            apply_scale: if True, divide every tensor in ``grad_groups_flat``
                by the combined scale.

        Returns:
            ``cur_scale``, multiplied by the clip ratio when clipping is active
            (``clip_grad > 0``) and the unscaled norm exceeds ``clip_grad``.
        """
        total_norm = 0.0
        for norm in norm_groups:
            total_norm += norm**2.0
        total_norm = math.sqrt(total_norm)

        # compute combined scale factor for this group
        combined_scale = self.cur_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.cur_scale

        if apply_scale:
            for grad in grad_groups_flat:
                grad.data.mul_(1. / combined_scale)

        return combined_scale

    def backward(self, loss):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        scaled_loss = (loss.float()) * self.cur_scale
        scaled_loss.backward()

    def _update_scale(self, skip):
        """Update the loss scale after a step and advance ``cur_iter``.

        Dynamic mode: on ``skip`` (overflow), divide the scale by
        ``scale_factor`` (floored at ``min_loss_scale``) and record the
        iteration. Otherwise multiply it by ``scale_factor`` each time the
        number of iterations since the last overflow is a positive multiple
        of ``scale_window``. Static mode only logs overflows.
        """
        if self.dynamic_loss_scale:
            prev_scale = self.cur_scale
            if skip:
                self.cur_scale = max(self.cur_scale / self.scale_factor,
                                     self.min_loss_scale)
                self.last_overflow_iter = self.cur_iter
                if self.verbose:
                    logger.info(f"\nGrad overflow on iteration {self.cur_iter}")
                    logger.info(
                        f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}"
                    )
            else:
                # Ensure self.scale_window updates since last overflow
                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                    self.cur_scale *= self.scale_factor
                    if self.verbose:
                        logger.info(
                            f"No Grad overflow for {self.scale_window} iterations")
                        logger.info(
                            f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}"
                        )
        else:
            if skip:
                logger.info("Grad overflow on iteration: %s", self.cur_iter)
                logger.info("Using static loss scale of: %s", self.cur_scale)
        self.cur_iter += 1
        return

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
            # Saved so a dynamic-scale checkpoint is self-contained; without
            # it, restoring into a static-scale instance leaves
            # min_loss_scale unset and _update_scale fails on overflow.
            if hasattr(self, 'min_loss_scale'):
                state_dict['min_loss_scale'] = self.min_loss_scale
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_groups_flat'] = self.fp32_groups_flat
        state_dict['clip_grad'] = self.clip_grad
        return state_dict

    # Refresh fp32 master params from fp16 copies
    def refresh_fp32_params(self):
        """Copy the flat fp16 buffers into the flat fp32 master buffers."""
        for current, saved in zip(self.fp32_groups_flat, self.fp16_groups_flat):
            current.data.copy_(saved.data)

    def load_state_dict(self, state_dict, load_optimizer_states=True):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Args:
            state_dict: dict produced by :meth:`state_dict`. Checkpoints
                without ``min_loss_scale`` are still accepted.
            load_optimizer_states: also restore the wrapped optimizer's state.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
            # Older checkpoints did not record min_loss_scale; keep the
            # instance's current value for those.
            if 'min_loss_scale' in state_dict:
                self.min_loss_scale = state_dict['min_loss_scale']
        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        self.clip_grad = state_dict['clip_grad']
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date.  There are two options.
        # 1:  Refresh the master params from the model's fp16 params.
        # This requires less storage but incurs precision loss.
        # 2:  Save and restore the fp32 master copies separately.
        # We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance.  In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
            current.data.copy_(saved.data)

    def __repr__(self):
        return repr(self.optimizer)
class FP16_UnfusedOptimizer(object):
    """
    FP16 Optimizer without weight fusion to support LAMB optimizer

    Keeps one fp32 copy per fp16 parameter (``fp32_groups``). These copies
    replace each group's ``params`` in the wrapped optimizer, so the optimizer
    updates fp32 weights that are then copied back into the fp16 model
    weights. Loss scaling is static or dynamic.

    For usage example please see, TODO: DeepSpeed V2 Tutorial
    """
    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=True,
                 mpu=None,
                 clip_grad=0.0,
                 fused_lamb_legacy=False):
        """
        Wrap ``init_optimizer`` for fp16 training.

        Args:
            init_optimizer: optimizer over the fp16 model params. Each of its
                ``param_groups[i]['params']`` is replaced in place by the
                corresponding list of fp32 copies.
            static_loss_scale: loss scale used when ``dynamic_loss_scale``
                is False.
            dynamic_loss_scale: enable dynamic loss scaling (defaults: start
                2**16, window 1000, minimum 0.25 when ``dynamic_loss_args`` is
                None).
            dynamic_loss_args: optional mapping read with the
                ``INITIAL_LOSS_SCALE``, ``SCALE_WINDOW`` and
                ``MIN_LOSS_SCALE`` keys.
            verbose: log overflow / loss-scale changes.
            mpu: model-parallel object forwarded to norm helpers and
                ``CheckOverflow``.
            clip_grad: maximum total gradient norm; ``<= 0`` disables clipping.
            fused_lamb_legacy: route ``step()`` to ``step_fused_lamb()``.

        Raises:
            SystemError: if CUDA is not available.
        """
        # Validate the CUDA prerequisite first. ``is_available`` must be
        # *called*: the bare function object is always truthy, so the
        # previous ``not torch.cuda.is_available`` check never fired.
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")

        self.fused_lamb_legacy = fused_lamb_legacy

        # Log once: on rank 0 of an initialized process group, or always in a
        # single-process run (get_rank() raises when no default group exists).
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')

        self.optimizer = init_optimizer

        # param groups
        self.fp16_groups = []
        self.fp32_groups = []

        # loop to deal with groups
        for i, param_group in enumerate(self.optimizer.param_groups):
            # fp16 weights that represents the actual model weights
            self.fp16_groups.append(param_group['params'])

            # creating a fp32 copy of the weights that will be updated first then
            # copied to fp16 weights
            fp32_group = [p.clone().float().detach() for p in param_group['params']]

            # in case the internal optimizer needs it
            for p in fp32_group:
                p.requires_grad = True

            # setting the param groups in the optimizer to point to fp32
            # note these are not the weights used by the model
            # the model uses the fp16 version that we added to fp16_group
            self.fp32_groups.append(fp32_group)
            param_group['params'] = self.fp32_groups[i]

        # we may have a way of fusing dynamic scale. Do not support for now
        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.cur_iter = 0
            self.last_overflow_iter = -1
            self.scale_factor = 2.0
            if dynamic_loss_args is None:
                self.cur_scale = 1.0 * 2**16
                self.scale_window = 1000
                self.min_loss_scale = 0.25
            else:
                self.cur_scale = dynamic_loss_args[INITIAL_LOSS_SCALE]
                self.scale_window = dynamic_loss_args[SCALE_WINDOW]
                self.min_loss_scale = dynamic_loss_args[MIN_LOSS_SCALE]
        else:
            self.dynamic_loss_scale = False
            self.cur_iter = 0
            self.cur_scale = static_loss_scale

        self.verbose = verbose

        self.clip_grad = clip_grad
        self.norm_type = 2

        # torch <= 0.4 only provides the non-underscore clip_grad_norm.
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm
        else:
            self.clip_grad_norm = torch.nn.utils.clip_grad_norm_

        self.mpu = mpu

        self.overflow = False
        self.overflow_checker = CheckOverflow(self.fp16_groups, mpu=self.mpu)

        self.initialize_optimizer_states()

    def zero_grad(self, set_grads_to_None=True):
        """
        Zero FP16 parameter grads.

        Args:
            set_grads_to_None: if True (default, faster) set ``.grad`` to
                None; otherwise detach and zero existing grads in place.
        """
        # FP32 grad should never exist outside of the step function
        # For speed, set model fp16 grad to None by default
        for group in self.fp16_groups:
            for p in group:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

    def step_fused_lamb(self, closure=None):
        """
        Not supporting closure.

        Step path for a legacy fused LAMB. Assumes (not verifiable here) that
        its ``step`` accepts ``grads``, ``output_params`` and ``scale`` keyword
        arguments and writes results into the fp16 params given as
        ``output_params``.

        Returns:
            True if a gradient overflow was detected and the step skipped.
        """
        # First compute norm for all group so we know if there is overflow
        grads_groups_flat = []
        grads_groups = []
        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            grads = [
                torch.zeros(p.size(),
                            dtype=p.dtype,
                            device=p.device) if p.grad is None else p.grad
                for p in group
            ]
            grads_groups.append(grads)
            grads_groups_flat.append(_flatten_dense_tensors(grads))
            norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))

        self.overflow = self.overflow_checker.check_using_norm(norm_groups)
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        # The fused kernel applies the unscale itself, so only compute it here.
        combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
        self.optimizer.step(grads=grads_groups,
                            output_params=self.fp16_groups,
                            scale=combined_scale)

        return self.overflow

    def step(self, closure=None):
        """
        Not supporting closure.

        Check for overflow and update the loss scale. Without overflow, copy
        the fp16 grads into fp32 grads, unscale/clip them, step the wrapped
        optimizer, then copy the fp32 weights back into the fp16 params and
        drop the fp32 grads.

        Returns:
            True if a gradient overflow was detected and the step skipped.
        """
        if self.fused_lamb_legacy:
            return self.step_fused_lamb()

        self.overflow = self.overflow_checker.check()
        prev_scale = self.cur_scale

        self._update_scale(self.overflow)
        if self.overflow:
            if self.verbose:
                logger.info(
                    "[deepspeed] fp16 dynamic loss scale overflow! Skipping step. Attempted loss "
                    "scale: {}, reducing to {}".format(prev_scale,
                                                       self.cur_scale))
            return self.overflow

        norm_groups = []
        for i, group in enumerate(self.fp16_groups):
            norm_groups.append(get_grad_norm(group, mpu=self.mpu))

            # copying gradients to fp32 to work with fp32 parameters
            for fp32_param, fp16_param in zip(self.fp32_groups[i], self.fp16_groups[i]):
                if fp16_param.grad is None:
                    fp32_param.grad = torch.zeros(fp16_param.size(),
                                                  dtype=fp32_param.dtype,
                                                  device=fp32_param.device)
                else:
                    fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

        self.unscale_and_clip_grads(norm_groups)

        self.optimizer.step()

        for fp32_group, fp16_group in zip(self.fp32_groups, self.fp16_groups):
            for fp32_param, fp16_param in zip(fp32_group, fp16_group):

                # remove the fp32 grad
                fp32_param.grad = None

                # copy data from fp32 to fp16
                fp16_param.data.copy_(fp32_param.data)

        return self.overflow

    def unscale_and_clip_grads(self, norm_groups, apply_scale=True):
        """Compute (and optionally apply) the combined unscale/clip divisor.

        Args:
            norm_groups: norms of the *scaled* gradients; they are combined as
                sqrt(sum(norm**2)).
            apply_scale: if True, divide every non-None fp32 grad in
                ``fp32_groups`` by the combined scale in place.

        Returns:
            ``cur_scale``, multiplied by the clip ratio when clipping is active
            (``clip_grad > 0``) and the unscaled norm exceeds ``clip_grad``.
        """
        total_norm = 0.0
        for norm in norm_groups:
            total_norm += norm**2.0
        total_norm = math.sqrt(total_norm)

        # compute combined scale factor for this group
        combined_scale = self.cur_scale
        if self.clip_grad > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
            if clip > 1:
                combined_scale = clip * self.cur_scale

        if apply_scale:
            for group in self.fp32_groups:
                for param in group:
                    if param.grad is not None:
                        param.grad.data.mul_(1. / combined_scale)

        return combined_scale

    def backward(self, loss):
        """
        :attr:`backward` performs the following steps:

        1. fp32_loss = loss.float()
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's fp16 leaves
        """
        scaled_loss = (loss.float()) * self.cur_scale
        scaled_loss.backward()

    def _update_scale(self, skip):
        """Update the loss scale after a step and advance ``cur_iter``.

        Dynamic mode: on ``skip`` (overflow), divide the scale by
        ``scale_factor`` (floored at ``min_loss_scale``) and record the
        iteration. Otherwise multiply it by ``scale_factor`` each time the
        number of iterations since the last overflow is a positive multiple
        of ``scale_window``. Static mode only logs overflows.
        """
        if self.dynamic_loss_scale:
            prev_scale = self.cur_scale
            if skip:
                self.cur_scale = max(self.cur_scale / self.scale_factor,
                                     self.min_loss_scale)
                self.last_overflow_iter = self.cur_iter
                if self.verbose:
                    logger.info("Grad overflow on iteration: %s", self.cur_iter)
                    logger.info(
                        f"Reducing dynamic loss scale from {prev_scale} to {self.cur_scale}"
                    )
            else:
                # Ensure self.scale_window updates since last overflow
                stable_interval = (self.cur_iter - self.last_overflow_iter) - 1
                if (stable_interval > 0) and (stable_interval % self.scale_window == 0):
                    self.cur_scale *= self.scale_factor
                    if self.verbose:
                        logger.info(
                            f"No Grad overflow for {self.scale_window} iterations")
                        logger.info(
                            f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}"
                        )
        else:
            if skip:
                logger.info("Grad overflow on iteration %s", self.cur_iter)
                logger.info("Using static loss scale of %s", self.cur_scale)
        self.cur_iter += 1
        return

    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.

        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['cur_scale'] = self.cur_scale
        state_dict['cur_iter'] = self.cur_iter
        if state_dict['dynamic_loss_scale']:
            state_dict['last_overflow_iter'] = self.last_overflow_iter
            state_dict['scale_factor'] = self.scale_factor
            state_dict['scale_window'] = self.scale_window
            # Saved so a dynamic-scale checkpoint is self-contained; without
            # it, restoring into a static-scale instance leaves
            # min_loss_scale unset and _update_scale fails on overflow.
            if hasattr(self, 'min_loss_scale'):
                state_dict['min_loss_scale'] = self.min_loss_scale
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_groups'] = self.fp32_groups
        return state_dict

    def load_state_dict(self, state_dict, load_optimizer_states=True):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Args:
            state_dict: dict produced by :meth:`state_dict`. Checkpoints
                without ``min_loss_scale`` are still accepted.
            load_optimizer_states: also restore the wrapped optimizer's state.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.cur_scale = state_dict['cur_scale']
        self.cur_iter = state_dict['cur_iter']
        if state_dict['dynamic_loss_scale']:
            self.last_overflow_iter = state_dict['last_overflow_iter']
            self.scale_factor = state_dict['scale_factor']
            self.scale_window = state_dict['scale_window']
            # Older checkpoints did not record min_loss_scale; keep the
            # instance's current value for those.
            if 'min_loss_scale' in state_dict:
                self.min_loss_scale = state_dict['min_loss_scale']

        if load_optimizer_states:
            self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date.  There are two options.
        # 1:  Refresh the master params from the model's fp16 params.
        # This requires less storage but incurs precision loss.
        # 2:  Save and restore the fp32 master copies separately.
        # We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance.  In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current_group, saved_group in zip(self.fp32_groups, state_dict['fp32_groups']):
            for current, saved in zip(current_group, saved_group):
                current.data.copy_(saved.data)

    def __repr__(self):
        return repr(self.optimizer)

    def initialize_optimizer_states(self):
        """Run one optimizer step on all-zero grads, then drop the grads.

        Zero grads are attached to both the fp16 and fp32 params (allocated
        on the current CUDA device) so the wrapped optimizer creates its
        per-parameter state up front.
        """
        for i, group in enumerate(self.fp16_groups):
            for param in group:
                param.grad = torch.zeros(param.size(),
                                         dtype=param.dtype,
                                         device=torch.cuda.current_device())

        for i, group in enumerate(self.fp32_groups):
            for param in group:
                param.grad = torch.zeros(param.size(),
                                         dtype=param.dtype,
                                         device=torch.cuda.current_device())

        self.optimizer.step()

        for i, group in enumerate(self.fp16_groups):
            for param in group:
                param.grad = None

        for i, group in enumerate(self.fp32_groups):
            for param in group:
                param.grad = None