def train_cifar(model, config, num_steps=400, average_dp_losses=True, fp16=True, seed=123):
    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        ds_utils.set_random_seed(seed)

        # disable dropout
        model.eval()

        trainset = cifar_trainset(fp16=fp16)
        config['local_rank'] = dist.get_rank()

        engine, _, _, _ = deepspeed.initialize(config=config,
                                               model=model,
                                               model_parameters=[p for p in model.parameters()],
                                               training_data=trainset)

        losses = []
        for step in range(num_steps):
            loss = engine.train_batch()
            losses.append(loss.item())
            if step % 50 == 0 and dist.get_rank() == 0:
                print(f'STEP={step} LOSS={loss.item()}')

    if average_dp_losses:
        loss_tensor = torch.tensor(losses).cuda()
        dist.all_reduce(loss_tensor)
        loss_tensor /= dist.get_world_size()
        losses = loss_tensor.tolist()

    return losses
def _create_ep_parallel_group(self, moe_experts):
    # Call the init process
    self.ep_group = {}
    self.expert_mp_group = {}
    moe_experts = moe_experts if type(moe_experts) is list else [moe_experts]
    for e in moe_experts:
        self.ep_group.update({e: None})
        self.expert_mp_group.update({e: None})

    for moe_ep_size in self.ep_group.keys():
        num_ep_groups = dist.get_world_size() // moe_ep_size
        for i in range(num_ep_groups):
            ep_cnt = i * moe_ep_size
            size = dist.get_world_size() if moe_ep_size > dist.get_world_size() else moe_ep_size
            ranks = list(range(ep_cnt, ep_cnt + size))
            _ep_group = dist.new_group(ranks)
            if dist.get_rank() in ranks:
                self.ep_group.update({moe_ep_size: _ep_group})

        if dist.get_world_size() > moe_ep_size:
            num_expert_mp_groups = dist.get_world_size() // num_ep_groups
            expert_mp_size = dist.get_world_size() // moe_ep_size
            for i in range(num_expert_mp_groups):
                expert_mp_comm_ranks = [i + nr * moe_ep_size for nr in range(expert_mp_size)]
                _expert_mp_group = dist.new_group(expert_mp_comm_ranks)
                if dist.get_rank() in expert_mp_comm_ranks:
                    self.expert_mp_group.update({moe_ep_size: _expert_mp_group})
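# A minimal, non-distributed sketch of the grouping arithmetic used in
# _create_ep_parallel_group above. The world size and expert-parallel size
# below are hypothetical, chosen only for illustration; no process groups
# are created. Expert-parallel groups take consecutive ranks, while expert
# model-parallel groups stride by moe_ep_size.
def _sketch_ep_group_layout(world_size=8, moe_ep_size=2):
    num_ep_groups = world_size // moe_ep_size
    ep_groups = [list(range(i * moe_ep_size, (i + 1) * moe_ep_size))
                 for i in range(num_ep_groups)]
    # e.g. [[0, 1], [2, 3], [4, 5], [6, 7]]

    num_expert_mp_groups = world_size // num_ep_groups
    expert_mp_size = world_size // moe_ep_size
    expert_mp_groups = [[i + nr * moe_ep_size for nr in range(expert_mp_size)]
                        for i in range(num_expert_mp_groups)]
    # e.g. [[0, 2, 4, 6], [1, 3, 5, 7]]
    return ep_groups, expert_mp_groups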
def load_state_dict(self, state_dict):
    """
    Overrides load_state_dict() to add special handling when loading checkpoints
    """
    # Because the exp_avg_mask may change at different stages (e.g., BERT
    # pre-training with seqlen 128 and then 512), we don't use the exp_avg_mask
    # stored in checkpoints but always use the one provided in the training script.
    # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
    # Thus we keep the exp_avg_mask unchanged when loading a checkpoint.
    for i, group in enumerate(self.param_groups):
        if 'exp_avg_mask' in group:
            state_dict['param_groups'][i]['exp_avg_mask'] = group['exp_avg_mask']
        elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict['param_groups'][i]:
            state_dict['param_groups'][i].pop('exp_avg_mask')
    super().load_state_dict(state_dict)

    if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
        if dist.get_rank() == 0:
            print("Checkpoint loaded and OnebitAdam warmup stage starts/continues.")
        if self.adam_freeze_key is True:
            self.adam_freeze_key = False
            if self.using_pipeline:
                self.deepspeed.pipeline_enable_backward_allreduce = True
            else:
                self.deepspeed.enable_backward_allreduce = True
    else:
        if dist.get_rank() == 0:
            print("Checkpoint loaded and OnebitAdam compression stage starts/continues.")
        if self.adam_freeze_key is False:
            self.adam_freeze_key = True
            if self.using_pipeline:
                self.deepspeed.pipeline_enable_backward_allreduce = False
            else:
                self.deepspeed.enable_backward_allreduce = False

    # We reset the compression errors when loading checkpoints for 3 reasons:
    # 1) The worker and server errors at each GPU are distinct, so in the current
    #    implementation only rank 0's errors are saved in the checkpoint. Thus we
    #    have to reset the errors. If we wanted to save them correctly we would need
    #    O(num_gpu*model_size) memory to gather all the errors, which is a very large
    #    memory requirement. It is possible to save them in a distributed way, but it
    #    would make checkpoint saving/loading much more complicated.
    # 2) Even if we were able to save the compression errors correctly, we would need
    #    the exact same number of GPUs in order to load them correctly.
    # 3) We verified on BERT pre-training that occasionally resetting the compression
    #    error at checkpoint loading does not affect convergence.
    # However, please avoid frequent checkpoint loading, which could break the error
    # compensation mechanism and thus affect convergence.
    for group in self.param_groups:
        for p in group['params']:
            if 'worker_error' in self.state[p]:
                self.state[p].pop('worker_error')
            if 'server_error' in self.state[p]:
                self.state[p].pop('server_error')
def test_reduce_scatter_coalesced_tensor_smaller_than_world_sz():
    input = torch.zeros((1, ), dtype=torch.half, device=torch.cuda.current_device())

    (output, ) = reduce_scatter_coalesced([input], dist.get_world_group())

    if dist.get_rank() == 0:
        assert output.shape == (1, )
        assert torch.allclose(output, torch.zeros_like(output))
    elif dist.get_rank() == 1:
        assert output.shape == (0, )
def _initialize_parameters(self, parameters, src_tensors, aio_handle):
    assert len(parameters) == len(src_tensors)

    swap_paths = self._get_swap_paths(parameters=parameters,
                                      num_elems=[src.numel() for src in src_tensors])

    SWAP_INIT_TIMER = "swap_init_write"
    self._start_timer(SWAP_INIT_TIMER)

    pinned_buffers = self.swap_buffer_manager.allocate_all(num_elems=self.largest_numel,
                                                           dtype=self.dtype)
    assert pinned_buffers is not None

    self._swap_out_unpinned_tensors(aio_handle=aio_handle,
                                    unpinned_tensors=src_tensors,
                                    dest_paths=swap_paths,
                                    pinned_buffers=pinned_buffers)

    if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE:
        for i, tensor in enumerate(src_tensors):
            logger.info(
                f'copy_in_fp16_param: fp32_id = {id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}'
            )

    self.swap_buffer_manager.free(pinned_buffers)

    self._stop_timer(SWAP_INIT_TIMER)
    self._log_timers([SWAP_INIT_TIMER])
def write_events(self, event_list, flush=True):
    if self.enabled and self.summary_writer is not None and dist.get_rank() == 0:
        for event in event_list:
            self.summary_writer.add_scalar(*event)
        if flush:
            self.summary_writer.flush()
def __init__(self, swap_config, aio_config, base_folder, optimizer, largest_numel, device,
             dtype, timers):
    super(PartitionedOptimizerSwapper, self).__init__(swap_config, aio_config, base_folder,
                                                      optimizer, largest_numel, device, dtype,
                                                      timers)

    aio_op = AsyncIOBuilder().load()
    self.aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE],
                                        aio_config[AIO_QUEUE_DEPTH],
                                        aio_config[AIO_SINGLE_SUBMIT],
                                        aio_config[AIO_OVERLAP_EVENTS],
                                        aio_config[AIO_THREAD_COUNT])

    # Overlap swapping out
    self.gradient_swapper = AsyncTensorSwapper(aio_handle=self.aio_handle,
                                               numel_alignment=self.numel_alignment,
                                               timers=self.timers)

    self.print_exclude_list += ['aio_handle', 'gradient_swapper', 'print_exclude_list']

    if dist.get_rank() == 0:
        print_object(obj=self,
                     name='PartitionedOptimizerSwapper',
                     exclude_list=self.print_exclude_list)
def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
    swap_info = self._get_param_swap_info(parameter)
    if swap_info is None:
        return

    assert len(swap_info.tensors) <= len(dest_buffers)

    swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(swap_info.tensors)
    swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

    READ_TIMER = 'swap_submit_read_param'
    WAIT_TIMER = 'swap_wait_read_param'

    self._start_timer(READ_TIMER)
    swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
    self._stop_timer(READ_TIMER)

    swap_bytes = sum([buffer.numel() * buffer.element_size() for buffer in swap_buffers])

    self._start_timer(WAIT_TIMER)
    aio_handle.wait()
    self._stop_timer(WAIT_TIMER)

    compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
    compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
    for t, buffer in zip(swap_info.tensors, compute_buffers):
        t.data = buffer.data

    self._log_timers([READ_TIMER, WAIT_TIMER])

    if DEBUG_MODE and dist.get_rank() == 0:
        logger.info(f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')
def _report_statistics(self, message):
    if dist.get_rank() == 0:
        element_size = torch.tensor([], dtype=self.dtype).element_size()
        swapped_GB = (self.num_elements_swapped * element_size) / (1024**3)
        logger.debug(f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB')
def test(self, check_using_norm):
    groups._create_expert_and_data_parallel(2)

    param1 = torch.nn.Parameter(torch.Tensor([0]))
    param1.grad = torch.Tensor([1])
    param2 = torch.nn.Parameter(torch.Tensor([0]))
    if dist.get_rank() == 0:
        param2.grad = torch.Tensor([1])
    else:
        param2.grad = torch.Tensor([float("inf")])
    param2.allreduce = False
    # param2 is now a MoE parameter
    parameters = [param1, param2]
    if check_using_norm:
        grads_group_flat = [_flatten_dense_tensors([p.grad for p in parameters])]
        norm = ds_utils.get_weight_norm(grads_group_flat)
        overflow_checker = ds_utils.CheckOverflow([parameters])
        overflow = overflow_checker.check_using_norm([norm], reduce_overflow=False)
    else:
        overflow_checker = ds_utils.CheckOverflow([parameters])
        overflow = overflow_checker.check()
    assert overflow
def partition_activations_in_checkpoint(partition_activation):
    global PARTITION_ACTIVATIONS
    PARTITION_ACTIVATIONS = partition_activation
    if dist.get_rank() == 0:
        logger.info(f"**************Partition Activations {PARTITION_ACTIVATIONS}************")
def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
    is_pipe_parallel = isinstance(self.module, PipelineModule)
    if is_pipe_parallel:
        raise RuntimeError('pipeline parallelism is currently not supported in inference.')

    if os.path.isdir(load_dir):
        if tag is None:
            latest_path = os.path.join(load_dir, "latest")
            if os.path.isfile(latest_path):
                with open(latest_path, "r") as fd:
                    tag = fd.read().strip()
        ckpt_list = self._get_all_ckpt_names(load_dir, tag)
        sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
    else:
        sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)

    if type(sd_loader) is list:
        self.sd = torch.load(sd_loader[0], map_location='cpu')
        self.key_list = list(self.sd.keys())

        self.load_model_with_checkpoint(self.module)

        for i in range(1, len(sd_loader)):
            if not dist.is_initialized() or dist.get_rank() == 0:
                print(f"loading checkpoint ({i})")
            self.sd = torch.load(sd_loader[i], map_location='cuda')
            self.key_list = list(self.sd.keys())
            self.load_model_with_checkpoint(self.module)
    else:
        mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()

        load_path, checkpoint, quantize_config = sd_loader.load(
            self.mp_world_size,
            mp_rank,
            is_pipe_parallel=is_pipe_parallel,
            quantize=(self.dtype is torch.int8),
            quantize_groups=self.quantize_groups,
            mlp_extra_grouping=self.mlp_extra_grouping)

        self.quantization_scales, self.quantize_merge_count = quantize_config

        moe, _ = has_moe_layers(self.module)
        if moe:
            from deepspeed.runtime.engine import DeepSpeedEngine
            old_moe_load = False
            if not isinstance(checkpoint['num_experts'], list):
                old_moe_load = True
            DeepSpeedEngine.load_moe_state_dict(
                load_dir,
                tag,
                state_dict=checkpoint[self._choose_module_key(checkpoint)],
                old_moe_load=old_moe_load,
                model=self.module,
                mpu=self.mpu,
                checkpoint_engine=self.checkpoint_engine)

        self.module.load_state_dict(
            state_dict=checkpoint[self._choose_module_key(checkpoint)],
            checkpoint_engine=self.checkpoint_engine,
            strict=load_module_strict)
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if isinstance(outputs, (tuple, list)):
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module, functional, backward_function, output)
            touched_outputs.append(touched_output)
        return outputs.__class__(touched_outputs)
    elif isinstance(outputs, dict):
        # apply in place to avoid recreating dict-inherited objects
        for key in outputs.keys():
            outputs[key] = _apply_to_tensors_only(module, functional, backward_function,
                                                  outputs[key])
        return outputs
    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        if not is_builtin_type(outputs):
            global warned
            if not warned and dist.get_rank() == 0:
                logger.warning(
                    f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
                    "The ZeRO-3 hooks designed to trigger before or after the backward pass of the module rely on knowing the input and "
                    "output tensors and therefore may not get triggered properly.")
                warned = True
        return outputs
def all_gather_dp_groups(partitioned_param_groups, dp_process_group, start_alignment_factor,
                         allgather_bucket_size):
    for group_id, partitioned_params in enumerate(partitioned_param_groups):
        # Sequential AllGather: best of both worlds
        partition_id = dist.get_rank(group=dp_process_group[group_id])
        dp_world_size = dist.get_world_size(group=dp_process_group[group_id])

        num_shards = max(
            1, partitioned_params[partition_id].numel() * dp_world_size // allgather_bucket_size)

        shard_size = partitioned_params[partition_id].numel() // num_shards

        # Enforce nccl/rccl alignment of start location of each shard
        shard_size = shard_size - (shard_size % start_alignment_factor)

        num_elements = shard_size

        assert shard_size * num_shards <= partitioned_params[partition_id].numel()

        for shard_id in range(num_shards):
            if shard_id == (num_shards - 1):
                num_elements = partitioned_params[partition_id].numel() - shard_id * shard_size

            shard_list = []
            for dp_id in range(dp_world_size):
                curr_shard = partitioned_params[dp_id].narrow(0, shard_id * shard_size,
                                                              num_elements).detach()
                shard_list.append(curr_shard)

            dist.all_gather(shard_list, shard_list[partition_id], dp_process_group[group_id])
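# A minimal, non-distributed sketch of the sharding arithmetic used in
# all_gather_dp_groups above. All sizes below are hypothetical, chosen only
# to illustrate how the bucket size and alignment factor determine the
# number and size of shards; the last shard absorbs the remainder.
def _sketch_shard_sizes(partition_numel=1000003,
                        dp_world_size=8,
                        allgather_bucket_size=500000,
                        start_alignment_factor=2):
    num_shards = max(1, partition_numel * dp_world_size // allgather_bucket_size)
    shard_size = partition_numel // num_shards
    # enforce alignment of each shard's start offset
    shard_size -= shard_size % start_alignment_factor
    assert shard_size * num_shards <= partition_numel
    # every shard holds shard_size elements except the last, which picks up the rest
    last_shard_numel = partition_numel - (num_shards - 1) * shard_size
    return num_shards, shard_size, last_shard_numel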
def test(self):
    x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
    sum_of_ranks = (dist.get_world_size() * (dist.get_world_size() + 1)) // 2
    result = torch.ones(1, 3).cuda() * sum_of_ranks
    dist.all_reduce(x)
    assert torch.all(x == result)
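# A quick, non-distributed check of the sum-of-ranks arithmetic the test
# above relies on: if each rank contributes (rank + 1), the all-reduced
# value is the triangular number world_size * (world_size + 1) // 2.
def _sketch_sum_of_ranks():
    for world_size in (1, 2, 4, 8):
        contributions = [rank + 1 for rank in range(world_size)]
        assert sum(contributions) == world_size * (world_size + 1) // 2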
def see_memory_usage(message, force=False):
    if not force:
        return
    if dist.is_initialized() and not dist.get_rank() == 0:
        return

    # python doesn't do real-time garbage collection, so do it explicitly to get correct RAM reports
    gc.collect()

    # Print message except when distributed but not rank 0
    logger.info(message)
    logger.info(
        f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024), 2)} GB \
        CA {round(torch_memory_reserved() / (1024 * 1024 * 1024), 2)} GB \
        Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024))} GB ")

    vm_stats = psutil.virtual_memory()
    used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
    logger.info(f'CPU Virtual Memory: used = {used_GB} GB, percent = {vm_stats.percent}%')

    # reset the peak memory counter so the next call reports correct data
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()
def write_events(self, event_list):
    if self.enabled and dist.get_rank() == 0:
        import csv
        # We assume each event_list element is a tensorboard-style tuple in the
        # format: (log_name: String, value, step: Int)
        for event in event_list:
            log_name, value, step = event[0], event[1], event[2]

            # Set the header to the final component of log_name.
            # Needed because the deepspeed engine currently formats log strings
            # with '/' separators.
            if '/' in log_name:
                header = log_name.split('/')[-1]
            else:
                header = log_name

            # sanitize common naming conventions into a filename
            filename = log_name.replace('/', '_').replace(' ', '_')
            fname = self.log_dir + '/' + filename + '.csv'

            # Open file and record event. Insert header if this is the first time writing.
            with open(fname, 'a+') as csv_monitor_file:
                csv_monitor_writer = csv.writer(csv_monitor_file)
                if filename not in self.filenames:
                    self.filenames.append(filename)
                    csv_monitor_writer.writerow(['step', header])
                csv_monitor_writer.writerow([step, value])
def print_json_dist(message, ranks=None, path=None):
    """Write message to a JSON file when one of the following conditions is met:

    + dist is not initialized
    + dist.get_rank() is in ranks, or ranks == [-1]

    Args:
        message (dict)
        ranks (list)
        path (str)
    """
    from deepspeed import comm as dist
    should_log = not dist.is_initialized()
    ranks = ranks or []
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    if should_log:
        message['rank'] = my_rank
        import json
        with open(path, 'w') as outfile:
            json.dump(message, outfile)
            os.fsync(outfile)
def _restore_from_bit16_weights(self):
    for i, group in enumerate(self.bf16_groups):
        partition_id = dist.get_rank(group=self.real_dp_process_group[i])
        for bf16_partitions, fp32_partition in zip(self.bf16_partitioned_groups,
                                                   self.fp32_groups_flat_partition):
            fp32_partition.data.copy_(bf16_partitions[partition_id].data)
def _load_legacy_checkpoint(self, state_dict_list, load_optimizer_states=True,
                            load_from_fp32_weights=False):
    dp_rank = dist.get_rank(group=self.dp_process_group)
    current_rank_sd = state_dict_list[dp_rank]

    ckpt_version = current_rank_sd.get(DS_VERSION, False)
    assert ckpt_version, "Empty ds_version in checkpoint, not clear how to proceed"
    ckpt_version = pkg_version.parse(ckpt_version)

    self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

    if load_optimizer_states:
        self.optimizer.load_state_dict(current_rank_sd[BASE_OPTIMIZER_STATE])

    if load_from_fp32_weights:
        for current, saved in zip(self.fp32_groups_flat_partition,
                                  current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
            src_tensor = _get_padded_tensor(saved, current.numel())
            current.data.copy_(src_tensor.data)

    if load_optimizer_states:
        self._link_all_hp_params()
def write_events(self, event_list):
    if self.enabled and dist.get_rank() == 0:
        for event in event_list:
            label, value, step = event[0], event[1], event[2]
            self.log({label: value}, step=step)
def load(module, state_dict, prefix):
    args = (state_dict, prefix, {}, True, [], [], error_msgs)
    if len(list(module.parameters())) > 0 and list(module.parameters())[0].numel() == 0:
        with GatheredParameters(list(module.parameters(recurse=False)), modifier_rank=0):
            if dist.get_rank() == 0:
                module._load_from_state_dict(*args)
    else:
        if hasattr(module, 'weight'):
            if 'query_key_value' in prefix:
                module.weight = self.mp_replace.qkv_copy(module.weight.data,
                                                         state_dict[prefix + 'weight'])
            else:
                module.weight = self.mp_replace.copy(module.weight.data,
                                                     state_dict[prefix + 'weight'])
        else:
            module.norm.weight = self.mp_replace.copy(module.norm.weight.data,
                                                      state_dict[prefix + 'weight'])
        if prefix + 'bias' in self.key_list:
            if hasattr(module, 'norm'):
                module.norm.bias = self.mp_replace.copy(module.norm.bias,
                                                        state_dict[prefix + 'bias'])
            else:
                data = state_dict[prefix + 'bias']
                data = data.to(torch.cuda.current_device())
                module.bias = self.mp_replace.copy(module.bias, data)
def _distributed_test():
    ds_cfg = {
        "train_micro_batch_size_per_gpu": 1,
        "zero_optimization": {
            "stage": 3,
            "stage3_max_reuse_distance": 0,
            "contiguous_gradients": True,
            "overlap_comm": True,
        },
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 1.
            }
        },
        "fp16": {
            "enabled": True,
            "loss_scale": 1.,
        }
    }

    with deepspeed.zero.Init(config=ds_cfg,
                             mem_efficient_linear=False,
                             enabled=init_context_manager):
        model = ManyParamModel()

    ds_engine = _ds_initialize_for_param_partitioning_testing(model, ds_cfg)

    for _ in range(3):  # test multiple iterations to cover prefetching
        activations: List[Tensor] = ds_engine(
            torch.ones((param_sz, ), dtype=torch.float16, device=ds_engine.device))
        assert len(activations) == n_layers

        partition_sz = math.ceil(param_sz / world_sz)
        expected_activations = torch.empty(param_sz,
                                           dtype=torch.float16,
                                           device=ds_engine.device)
        for start_idx in range(0, param_sz, partition_sz):
            expected_activations[start_idx:start_idx + partition_sz] = dist.get_rank()

        for layer_num, activation in enumerate(activations):
            expected_activations *= 2 * layer_num
            assert torch.allclose(activation, expected_activations)

        # TODO. finish writing this test
        ds_engine.backward(activations[-1].sum())

        avgd_gradients = ds_engine.optimizer.averaged_gradients
        assert set(avgd_gradients.keys()) == {0}, "should only have one parameter group"
        weight_gradients: List[Tensor] = avgd_gradients[0]

        for layer_num, activation in enumerate(weight_gradients):
            pass
def write_events(self, event_list):
    if dist.get_rank() == 0:
        if self.tb_monitor is not None:
            self.tb_monitor.write_events(event_list)
        if self.wandb_monitor is not None:
            self.wandb_monitor.write_events(event_list)
        if self.csv_monitor is not None:
            self.csv_monitor.write_events(event_list)
def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
    if mp_group is not None:
        self.gpu_index = dist.get_rank(group=mp_group)
    else:
        self.gpu_index = 0
    self.out_dim = out_dim
    self.in_dim = in_dim
    self.mp_size = mp_size
def create_deepspeed_args():
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepspeed = True
    if dist.is_initialized():
        # We assume up to one full node executing unit tests
        assert dist.get_world_size() <= torch.cuda.device_count()
        args.local_rank = dist.get_rank()
    return args
def __init__(self, tensor, group, partition_meta=None):
    super().__init__()

    self.group = group
    self.num_parts = dist.get_world_size(group=self.group)
    self.rank = dist.get_rank(group=self.group)

    self.orig_size = list(tensor.size())
    self.orig_device = tensor.device
    self.local_data, self.partition = self._partition_tensor(tensor)
def test_reduce_scatter_coalesced_single_input():
    input = torch.full((6, ),
                       dist.get_rank(),
                       dtype=torch.half,
                       device=torch.cuda.current_device())

    (output, ) = reduce_scatter_coalesced([input], dist.get_world_group())

    assert output.shape == (3, )
    assert torch.allclose(output, torch.full_like(output, 0.5))
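# A quick arithmetic check of the expectations in the test above, assuming
# the two-rank world that the asserted shapes imply and that the reduction
# averages across ranks (as the 0.5 expectation suggests): each surviving
# element is the mean of the per-rank inputs 0 and 1, and the 6-element
# input is scattered into 3-element chunks.
def _sketch_reduce_scatter_expectation(world_size=2, input_numel=6):
    per_rank_value = [float(rank) for rank in range(world_size)]   # [0.0, 1.0]
    expected_value = sum(per_rank_value) / world_size              # 0.5
    chunk_numel = input_numel // world_size                        # 3
    return expected_value, chunk_numel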
def _link_all_hp_params(self):
    dp_world_size = dist.get_world_size(group=self.dp_process_group)
    for i, param_group in enumerate(self.optimizer.param_groups):
        # Link bf16 and fp32 params in partition
        partition_id = dist.get_rank(group=self.real_dp_process_group[i])
        partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
        self._link_hp_params(self.bf16_groups[i],
                             self.fp32_groups_flat_partition[i],
                             partition_id * partition_size,
                             partition_size,
                             self.real_dp_process_group[i])
def __init__(self,
             init_optimizer,
             param_names,
             mpu=None,
             clip_grad=0.0,
             norm_type=2,
             allgather_bucket_size=5000000000,
             dp_process_group=None,
             timers=None):
    super().__init__()
    see_memory_usage('begin bf16_optimizer', force=True)
    self.timers = timers
    self.optimizer = init_optimizer
    self.param_names = param_names
    self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

    self.clip_grad = clip_grad
    self.norm_type = norm_type
    self.mpu = mpu
    self.allgather_bucket_size = int(allgather_bucket_size)
    self.dp_process_group = dp_process_group
    self.dp_rank = dist.get_rank(group=self.dp_process_group)
    self.real_dp_process_group = [dp_process_group for i in range(len(self.optimizer.param_groups))]

    # Load pre-built or JIT compile (un)flatten ops
    util_ops = UtilsBuilder().load()
    self.flatten = util_ops.flatten
    self.unflatten = util_ops.unflatten

    # align nccl all-gather send buffers to 4-byte boundary
    self.nccl_start_alignment_factor = 2  # 4-byte alignment / sizeof(fp16) = 2

    # Build BF16/FP32 groups
    self.bf16_groups = []
    self.bf16_groups_flat = []
    self.bf16_partitioned_groups = []

    self.fp32_groups_flat_partition = []

    # Maintain different fp32 gradient views for convenience
    self.fp32_groups_gradients = []
    self.fp32_groups_gradients_flat = []
    self.fp32_groups_actual_gradients_flat = []
    self.fp32_groups_gradient_flat_partition = []
    self.fp32_groups_has_gradients = []

    self.step_count = 0
    self.group_paddings = []

    if self.using_real_optimizer:
        self._setup_for_real_optimizer()

    see_memory_usage('end bf16_optimizer', force=True)