def __init__(self,
             module,
             timers,
             ds_config,
             overlap_comm=True,
             prefetch_bucket_size=50000000,
             max_reuse_distance=1000000000,
             max_live_parameters=1000000000,
             param_persistence_threshold=100000,
             model_persistence_threshold=sys.maxsize,
             offload_param_config=None,
             mpu=None):

    see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)

    print_rank_0(f"initialized {__class__.__name__} with args: {locals()}", force=False)

    self.module = module
    self.dtype = list(module.parameters())[0].dtype
    self.offload_device = None
    self.offload_param_pin_memory = False

    if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
        self.offload_device = offload_param_config.device
        self.offload_param_pin_memory = offload_param_config.pin_memory

    self._convert_to_zero_parameters(ds_config, module, mpu)

    for m in module.modules():
        _init_external_params(m)

    _inject_parameters(module, ZeROOrderedDict)

    self.param_numel_persistence_threshold = int(param_persistence_threshold)
    self.model_persistence_threshold = int(model_persistence_threshold)
    self.persistent_parameters = self.mark_persistent_parameters(
        self.param_numel_persistence_threshold,
        self.model_persistence_threshold)

    self.param_coordinators = {}
    self._prefetch_bucket_sz = int(prefetch_bucket_size)
    self._max_reuse_distance_in_numel = int(max_reuse_distance)
    self._max_available_parameters_in_numel = int(max_live_parameters)
    self.__allgather_stream = Stream() if overlap_comm else torch.cuda.default_stream()

    self.forward_hooks = []
    self.backward_hooks = []
    self.setup_zero_stage3_hooks()
    print_rank_0(
        f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
        force=False)

    see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True)
def __init__(self,
             init_optimizer,
             param_names,
             mpu=None,
             clip_grad=0.0,
             norm_type=2,
             allgather_bucket_size=5000000000,
             dp_process_group=None,
             timers=None):
    super().__init__()
    see_memory_usage('begin bf16_optimizer', force=True)
    self.timers = timers
    self.optimizer = init_optimizer
    self.param_names = param_names
    self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

    self.clip_grad = clip_grad
    self.norm_type = norm_type
    self.mpu = mpu
    self.allgather_bucket_size = int(allgather_bucket_size)
    self.dp_process_group = dp_process_group
    self.dp_rank = dist.get_rank(group=self.dp_process_group)
    self.real_dp_process_group = [
        dp_process_group for i in range(len(self.optimizer.param_groups))
    ]

    # Load pre-built or JIT compile (un)flatten ops
    util_ops = UtilsBuilder().load()
    self.flatten = util_ops.flatten
    self.unflatten = util_ops.unflatten

    # align nccl all-gather send buffers to 4-byte boundary
    self.nccl_start_alignment_factor = 2  # 4-byte alignment / sizeof(fp16) = 2

    # Build BF16/FP32 groups
    self.bf16_groups = []
    self.bf16_groups_flat = []
    self.bf16_partitioned_groups = []

    self.fp32_groups_flat_partition = []

    # Maintain different fp32 gradient views for convenience
    self.fp32_groups_gradients = []
    self.fp32_groups_gradients_flat = []
    self.fp32_groups_actual_gradients_flat = []
    self.fp32_groups_gradient_flat_partition = []
    self.fp32_groups_has_gradients = []

    self.step_count = 0
    self.group_paddings = []

    if self.using_real_optimizer:
        self._setup_for_real_optimizer()

    see_memory_usage('end bf16_optimizer', force=True)
def post_sub_module_backward_function(self, sub_module):
    see_memory_usage(
        f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
        force=False)

    self.get_param_coordinator(
        training=sub_module.training).release_sub_module(sub_module)

    see_memory_usage(
        f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
        force=False)
def pre_sub_module_forward_function(self, sub_module):
    see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}",
                     force=False)

    global FWD_MODULE_STACK
    FWD_MODULE_STACK.append(sub_module)

    param_coordinator = self.get_param_coordinator(training=sub_module.training)
    param_coordinator.trace_prologue(sub_module)
    if param_coordinator.is_record_trace():
        param_coordinator.record_module(sub_module)
    param_coordinator.fetch_sub_module(sub_module)

    see_memory_usage(
        f"Before sub module function {sub_module.__class__.__name__} after fetch",
        force=False)
def _post_init_method(self, module):
    #see_memory_usage(f"Before converting params in {module.__class__.__name__}", force=False)
    print_rank_0(f'Converting Params in {module.__class__.__name__}', force=False)
    see_memory_usage(
        f"Before converting and partitioning params in {module.__class__.__name__}",
        force=False)

    global param_count
    for name, param in module.named_parameters(recurse=False):
        param_count += param.numel()
        if not is_zero_param(param):
            self._convert_to_deepspeed_param(param)
            print_rank_0(
                f"Partitioning param with ds id {param.ds_id} and shape {param.data.shape}")
            param.partition()

    see_memory_usage(
        f"Param count {param_count}. After converting and partitioning params in {module.__class__.__name__}",
        force=False)
def model_provider():
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    see_memory_usage("Before Building Model", force=True)
    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                             remote_device=get_args().remote_device,
                             deepspeed_config=get_args().deepspeed_config,
                             enabled=get_args().zero_stage == 3):
        model = GPT2Model(num_tokentypes=0, parallel_output=True)
    see_memory_usage("After Building Model", force=True)

    if mpu.get_data_parallel_rank() == 0:
        billion_params = get_parameters_in_billions(model)
        print(f' > number of parameters on model parallel rank {mpu.get_model_parallel_rank()} '
              f'{round(billion_params, 3)} Billion',
              flush=True)

    return model
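# Illustration (not part of the original source): model_provider() above only enters
# deepspeed.zero.Init when ZeRO stage 3 is configured, so the file referenced by
# get_args().deepspeed_config needs a matching "zero_optimization" section. A minimal
# hypothetical config is sketched below as a Python dict; the specific values
# (batch size, offload settings, fp16) are placeholders, not taken from the source.
example_zero3_config = {
    "train_micro_batch_size_per_gpu": 1,
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "cpu",       # offload partitioned params to host memory
            "pin_memory": True,    # use pinned buffers for faster H2D copies
        },
    },
    "fp16": {"enabled": True},
}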
def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    see_memory_usage(f'before forward {model.global_steps}', force=True)

    # Forward model for one step.
    timers('forward').start()
    loss, loss_reduced = forward_step_func(data_iterator, model)
    timers('forward').stop()

    see_memory_usage(f'before backward {model.global_steps}', force=True)

    # Calculate gradients, reduce across processes, and clip.
    timers('backward').start()
    backward_step(optimizer, model, loss)
    timers('backward').stop()

    see_memory_usage(f'before optimizer {model.global_steps}', force=True)

    # Update parameters.
    skipped_iter = 0
    timers('optimizer').start()
    if args.deepspeed:
        model.step()
    else:
        optimizer.step()

        # Update learning rate.
        if not (args.fp16 and optimizer.overflow):
            lr_scheduler.step()
        else:
            skipped_iter = 1
    timers('optimizer').stop()

    return loss_reduced, skipped_iter
def _go(hidden_dim):
    with deepspeed.zero.Init(enabled=zero_stage == 3, config_dict_or_path=ds_config):
        model = SimpleModel(hidden_dim, nlayers=78)

    print('total number of parameters:',
          sum([p.numel() for p in model.parameters()]))

    see_memory_usage('pre-init', force=True)
    model, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
    see_memory_usage('post-init', force=True)

    data_loader = random_dataloader(model=model,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=model.device,
                                    dtype=torch.half)
    print(f"optimizer={model.optimizer}")

    for batch in data_loader:
        model(batch[0], batch[1])

    see_memory_usage('post-fwds', force=True)
def _partition_gradient(self, param, partition_buffer=None, accumulate=False):
    #import pdb;pdb.set_trace()
    # param.grad=None
    # param.grad.test()
    print_rank_0(
        f"Partitioning param {id(param)} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.numel()}"
    )
    see_memory_usage("Before partitioning gradients", force=False)
    partition_size = param.ds_tensor.numel()

    if partition_buffer is None:
        assert not accumulate, "No buffer to accumulate to"
        partition_buffer = torch.zeros(partition_size,
                                       dtype=param.dtype,
                                       device=param.device)
    else:
        assert partition_buffer.numel() >= partition_size, \
            f"The partition buffer size {partition_buffer.numel()} should match the size of param.ds_tensor {partition_size}"

    rank = torch.distributed.get_rank(group=self.ds_process_group)
    start = partition_size * rank
    end = start + partition_size

    # Full view of this rank's slot in the buffer; defined up front so the final
    # grad reassignment below is valid even if this rank holds no real elements.
    dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size)
    dest_tensor = dest_tensor_full_buffer

    #print("before partition gradients")
    if start < param.ds_numel:
        elements = min(param.ds_numel - start, partition_size)

        dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
        src_tensor = param.grad.view(-1).narrow(0, start, elements)

        # just copy the grad partition to the buffer
        if not accumulate:
            dest_tensor.copy_(src_tensor)

        # if source and destination are on the same device,
        # add to the provided buffer
        elif src_tensor.device == dest_tensor.device:
            dest_tensor.add_(src_tensor)

        # if source and destination are on different devices, copy the destination
        # to the source device first, add there, then move the result back.
        # This seems to run faster when src is gpu and dest is cpu;
        # adding directly on cpu is very slow.
        else:
            acc_tensor = torch.empty(src_tensor.numel(),
                                     dtype=param.dtype,
                                     device=param.device)

            acc_tensor.copy_(dest_tensor)
            acc_tensor.add_(src_tensor)
            dest_tensor.copy_(acc_tensor)

        # partition_buffer.view(-1).narrow(0, 0, elements).copy_(
        #     param.grad.view(-1).narrow(0, start, elements))

    #print("after partition gradients")
    param.grad.data = dest_tensor_full_buffer.data
    see_memory_usage("After partitioning gradients", force=False)
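# Illustration (not from the original source): the slicing arithmetic used by
# _partition_gradient above. Each rank owns a contiguous slice of the flattened
# gradient starting at partition_size * rank; the last rank may own fewer "real"
# elements than partition_size once alignment padding is excluded. The numbers
# below are hypothetical.
import torch

ds_numel = 10        # hypothetical total gradient elements
partition_size = 4   # hypothetical per-rank partition size (includes padding)
grad = torch.arange(ds_numel, dtype=torch.float32)

for rank in range(3):
    start = partition_size * rank
    if start < ds_numel:
        elements = min(ds_numel - start, partition_size)
        part = grad.view(-1).narrow(0, start, elements)
        print(rank, part.tolist())  # rank 2 ends up with only 2 real elements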
def backward(ctx, *grads):
    global timers
    see_memory_usage("In backward", force=False)
    # removing pointers to the contiguous buffer memory
    # so that they can be garbage collected once the checkpoints
    # have been used
    if SYNCHRONIZE:
        torch.cuda.synchronize()
    if PROFILE_TIME:
        timers('backward').start()

    if CONTIGUOUS_CHECKPOINTING:
        global data_offsets, size_offsets
        global contiguous_data_buffers, contiguous_size_buffers

        for buffers in contiguous_data_buffers:
            buffers = []

        # frees up all the pointers to the checkpoints except for the ones
        # stored by save for backward
        contiguous_data_buffers = []
        contiguous_size_buffers = []
        data_offsets = []
        size_offsets = []

    see_memory_usage("In backward checkpointing code", force=False)
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError("Checkpointing is not compatible with .grad(), "
                           "please use .backward() if possible")

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS

    if PARTITION_ACTIVATIONS:
        # with torch.cuda.stream(transport_stream):
        inputs = gather_partitioned_activations(
            ctx.saved_tensors,
            device=cuda_device if CPU_CHECKPOINT else None)
        detached_inputs = detach_variable(inputs)
    elif CPU_CHECKPOINT:
        inputs = move_to_device(ctx.saved_tensors,
                                cuda_device,
                                is_activation_to_checkpoint)
        detached_inputs = detach_variable(inputs)
    else:
        inputs = ctx.saved_tensors
        detached_inputs = detach_variable(inputs)

    # Add non tensor input args
    detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                    non_tensor_objects=ctx.non_tensor_args,
                                    tensor_flags=ctx.tensor_flags)

    # Store the current states.
    bwd_cpu_rng_state = torch.get_rng_state()
    bwd_cuda_rng_state = torch.cuda.get_rng_state()
    bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    # Set the states to what it used to be before the forward pass.
    torch.set_rng_state(ctx.fwd_cpu_rng_state)
    _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

    # if PARTITION_ACTIVATIONS:
    #     current_stream = torch.cuda.current_stream()
    #     current_stream.wait_stream(transport_stream)

    see_memory_usage("In backward checkpointing code before forward", force=False)

    with torch.enable_grad():
        outputs = ctx.run_function(*detached_inputs)

    see_memory_usage("In backward checkpointing code after forward", force=False)

    # Set the states back to what it was at the start of this function.
    torch.set_rng_state(bwd_cpu_rng_state)
    _set_cuda_rng_state(bwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

    if isinstance(outputs, torch.Tensor):
        outputs = (outputs, )

    # Filter out non tensor outputs
    outputs, _, _ = extract_tensors(all_objects=outputs)

    # Construct arguments to autograd.backward().
    # This is usually just outputs and grads, but forward() can return tensors that
    # are not differentiable.
    output_tensors = []
    grad_tensors = []
    for out, grad in zip(outputs, grads):
        if out.requires_grad:
            output_tensors.append(out)
            grad_tensors.append(grad)

    see_memory_usage("In backward checkpointing code before backward", force=False)

    torch.autograd.backward(output_tensors, grad_tensors)

    see_memory_usage("After backward checkpointing code after backward", force=False)

    if PROFILE_TIME:
        timers('backward').stop()
        timers.log(['backward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    ret_list = [None, None]  # first None for ctx
    for inp in detached_inputs:
        if torch.is_tensor(inp):
            ret_list.append(inp.grad)
        else:
            ret_list.append(None)

    return tuple(ret_list)
def forward(ctx, run_function, all_outputs, *args):
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    def save_args_for_backward(*all_args):
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
        ctx.save_for_backward(*tensor_args)
        ctx.non_tensor_args = non_tensor_args
        ctx.tensor_flags = tensor_flags

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    if mp_rank is None:
        if mpu is not None:
            if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                mp_rank = mpu.get_tensor_model_parallel_rank()
                mp_size = mpu.get_tensor_model_parallel_world_size()
                mp_group = mpu.get_tensor_model_parallel_group()
            else:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=False)
        if dist.get_rank() == 0:
            logger.info("Activation Checkpointing Information")
            logger.info(
                f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
            logger.info(
                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
    elif CPU_CHECKPOINT:
        inputs = copy_to_device(args,
                                device=torch.device('cpu'),
                                criterion_func=is_activation_to_checkpoint)

    # just in case something funky is happening such as reuse of inputs
    inputs_cuda = copy_to_device(args,
                                 device=cuda_device,
                                 criterion_func=is_activation_to_checkpoint)

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    see_memory_usage("Before running forward on the layer", force=False)
    # ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    see_memory_usage("After running forward on the layer", force=False)
    del inputs_cuda

    if PARTITION_ACTIVATIONS:
        new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
        assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
        save_args_for_backward(*new_args)
    elif CPU_CHECKPOINT:
        new_args = get_cpu_activations_for_backward(args, inputs)
        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # Tensors returned from forward() may not be differentiable.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
    else:
        non_grad_outputs = [
            o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()
        ]
    ctx.mark_non_differentiable(*non_grad_outputs)

    if torch.is_tensor(outputs):
        all_outputs += [outputs]
        return outputs
    else:
        all_outputs += outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)
        return tuple(outputs)
def forward(ctx, run_function, all_outputs, *args):
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    def save_args_for_backward(*all_args):
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
        ctx.save_for_backward(*tensor_args)
        ctx.non_tensor_args = non_tensor_args
        ctx.tensor_flags = tensor_flags

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    if mp_rank is None:
        if mpu is not None:
            mp_rank = mpu.get_model_parallel_rank()
            mp_size = mpu.get_model_parallel_world_size()
            mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=False)
        if dist.get_rank() == 0:
            logger.info("Activation Checkpointing Information")
            logger.info(
                f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}")
            logger.info(
                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        #inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
        #inputs.append(args[-1])

        inputs = []
        for i, item in enumerate(args[:-1]):
            if not torch.is_tensor(item):
                inputs.append(item)
                continue

            partition_size = get_partition_size(item)
            partition = item.detach().contiguous().view(-1).narrow(
                0, get_partition_start(item), partition_size).clone()

            if CONTIGUOUS_CHECKPOINTING:
                buffer_device = torch.device('cpu') if PA_TO_CPU else partition.device

                if i >= len(contiguous_data_buffers):
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for _ in range(num_layers)
                    ]
                    contiguous_data_buffers.append(tensor_list)
                    data_offsets.append(0)
                elif contiguous_data_buffers[i] is None:
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for _ in range(num_layers)
                    ]
                    contiguous_data_buffers[i] = tensor_list
                    data_offsets[i] = 0

                # Because 'new_empty' returns uninitialized pages,
                # the pages need to be populated during the cudaMemcpy time
                # which increases the data copy time. To avoid this, we
                # pre-populate these pages by simply writing 0 ahead of
                # the actual cudaMemcpy operation time. Due to the
                # previously launched GPU kernels, there is a small
                # window of time here for CPUs to populate pages asynchronously.
                contiguous_data_buffers[i][data_offsets[i]].data[range(
                    0,
                    contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
                    int(mmap.PAGESIZE /
                        contiguous_data_buffers[i][data_offsets[i]].data.element_size()))] = 0

                contiguous_partition = contiguous_data_buffers[i][
                    data_offsets[i]].data.copy_(partition.data)
                data_offsets[i] = data_offsets[i] + 1
                inputs.append(contiguous_partition)
            else:
                partition = partition.cpu() if PA_TO_CPU else partition
                inputs.append(partition)

        inputs.append(args[-1])

    # just in case something funky is happening such as reuse of inputs
    inputs_cuda = move_to_device(args, cuda_device)

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    see_memory_usage("Before running forward on the layer", force=False)
    # ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    see_memory_usage("After running forward on the layer", force=False)
    del inputs_cuda

    # with torch.cuda.stream(transport_stream):
    # if PARTITION_ACTIVATIONS:
    #     new_args = []
    #     for arg, inp in zip(args, inputs):
    #         size = torch.tensor(arg.size())
    #         arg.data = inp.data
    #         new_args.append(arg)
    #         new_args.append(size)
    #     ctx.save_for_backward(*new_args)

    if PARTITION_ACTIVATIONS:
        new_args = []
        for i, (arg, inp) in enumerate(zip(args, inputs)):
            if not torch.is_tensor(arg):
                new_args.append(arg)
                continue

            size = torch.tensor(arg.size())

            arg.data = inp.data
            new_args.append(arg)

            if CONTIGUOUS_CHECKPOINTING:
                numel = size.numel()
                if i >= len(contiguous_size_buffers):
                    tmp = torch.tensor(())
                    contiguous_size_buffers.append(
                        tmp.new_empty([numel * num_layers],
                                      dtype=size.dtype,
                                      device=size.device))
                    size_offsets.append(0)
                elif contiguous_size_buffers[i] is None:
                    tmp = torch.tensor(())
                    contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
                                                               dtype=size.dtype,
                                                               device=size.device)
                    size_offsets[i] = 0

                contiguous_size = contiguous_size_buffers[i].narrow(
                    0, size_offsets[i], numel).data.copy_(size.data)
                contiguous_size = contiguous_size.view_as(size)
                size_offsets[i] = size_offsets[i] + numel
                new_args.append(contiguous_size)
            else:
                new_args.append(size)

            # if dist.get_rank() == 0:
            #     logger.info(f"The stored tensor is {contiguous_size} and original one is {size}")

        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # Tensors returned from forward() may not be differentiable.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
    else:
        non_grad_outputs = [
            o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()
        ]
    ctx.mark_non_differentiable(*non_grad_outputs)

    if torch.is_tensor(outputs):
        all_outputs += [outputs]
        return outputs
    else:
        all_outputs += outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)
        return tuple(outputs)
def _setup_for_real_optimizer(self):
    dp_world_size = dist.get_world_size(group=self.dp_process_group)
    self.partition_count = [
        dp_world_size for i in range(len(self.optimizer.param_groups))
    ]

    for i, param_group in enumerate(self.optimizer.param_groups):
        see_memory_usage(f'before initializing group {i}', force=True)

        partition_id = dist.get_rank(group=self.real_dp_process_group[i])

        # grab the original list
        self.bf16_groups.append(param_group['params'])

        # create flat bf16 params
        self.bf16_groups_flat.append(
            self._flatten_dense_tensors_aligned(
                self.bf16_groups[i],
                self.nccl_start_alignment_factor * dp_world_size))

        # Make bf16 params point to flat tensor storage
        self._update_storage_to_flattened_tensor(
            tensor_list=self.bf16_groups[i],
            flat_tensor=self.bf16_groups_flat[i])

        # divide flat weights into equal sized partitions
        partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
        bf16_dp_partitions = [
            self.bf16_groups_flat[i].narrow(0, dp_index * partition_size, partition_size)
            for dp_index in range(dp_world_size)
        ]
        self.bf16_partitioned_groups.append(bf16_dp_partitions)

        # create fp32 params partition
        self.fp32_groups_flat_partition.append(
            bf16_dp_partitions[partition_id].clone().float().detach())
        self.fp32_groups_flat_partition[i].requires_grad = True

        num_elem_list = [t.numel() for t in self.bf16_groups[i]]

        # create fp32 gradients
        self.fp32_groups_gradients_flat.append(
            torch.zeros_like(self.bf16_groups_flat[i], dtype=torch.float32))

        # track individual fp32 gradients for entire model
        fp32_gradients = self._split_flat_tensor(
            flat_tensor=self.fp32_groups_gradients_flat[i],
            num_elem_list=num_elem_list)
        self.fp32_groups_gradients.append(fp32_gradients)

        # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
        length_without_padding = sum(num_elem_list)
        self.fp32_groups_actual_gradients_flat.append(
            torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0, length_without_padding))

        # flat tensor corresponding to gradient partition
        self.fp32_groups_gradient_flat_partition.append(
            torch.narrow(self.fp32_groups_gradients_flat[i],
                         0,
                         partition_id * partition_size,
                         partition_size))

        # track fp32 gradient updates
        self.fp32_groups_has_gradients.append([False] * len(self.bf16_groups[i]))

        # Record padding required for alignment
        if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
            padding = self.bf16_groups_flat[i].numel() - length_without_padding
        else:
            padding = 0
        self.group_paddings.append(padding)

        # update optimizer param groups to reference fp32 params partition
        param_group['params'] = [self.fp32_groups_flat_partition[i]]

        see_memory_usage(f'after initializing group {i}', force=True)

    see_memory_usage('before initialize_optimizer', force=True)
    self.initialize_optimizer_states()
    see_memory_usage('end initialize_optimizer', force=True)

    # Need optimizer states initialized before linking lp to optimizer state
    self._link_all_hp_params()
    self._param_slice_mappings = self._create_param_mapping()
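# Illustration (not from the original source): the flatten-and-partition pattern used in
# _setup_for_real_optimizer above, sketched with plain torch ops on hypothetical sizes.
# A parameter group is flattened (padded so its length divides evenly by the data-parallel
# world size), each rank narrows out its own slice, and only the last rank's slice
# contains the alignment padding.
import torch
from torch._utils import _flatten_dense_tensors

dp_world_size = 4
params = [torch.randn(3), torch.randn(6)]            # hypothetical group: 9 elements
flat = _flatten_dense_tensors(params)
pad = (dp_world_size - flat.numel() % dp_world_size) % dp_world_size
flat = torch.cat([flat, flat.new_zeros(pad)])        # 9 + 3 padding = 12 elements

partition_size = flat.numel() // dp_world_size       # 3 elements per rank
partitions = [
    flat.narrow(0, r * partition_size, partition_size) for r in range(dp_world_size)
]
# On each rank, the fp32 master copy of its slice would be created as
# partitions[rank].clone().float().detach(), mirroring the code above.
print(partition_size, [p.numel() for p in partitions])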