Example #1
def train_cifar(model,
                config,
                num_steps=400,
                average_dp_losses=True,
                fp16=True,
                seed=123):
    with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
        ds_utils.set_random_seed(seed)

        # disable dropout
        model.eval()

        trainset = cifar_trainset(fp16=fp16)
        config['local_rank'] = dist.get_rank()

        engine, _, _, _ = deepspeed.initialize(
            config=config,
            model=model,
            model_parameters=[p for p in model.parameters()],
            training_data=trainset)

        losses = []
        for step in range(num_steps):
            loss = engine.train_batch()
            losses.append(loss.item())
            if step % 50 == 0 and dist.get_rank() == 0:
                print(f'STEP={step} LOSS={loss.item()}')

        if average_dp_losses:
            loss_tensor = torch.tensor(losses).cuda()
            dist.all_reduce(loss_tensor)
            loss_tensor /= dist.get_world_size()
            losses = loss_tensor.tolist()

    return losses
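A minimal invocation sketch (illustrative only; the model, variable names, and config values below are assumptions, not from the source). Note that engine.train_batch() is the pipeline-engine training API, so the model passed in is expected to be a deepspeed.pipe.PipelineModule.

# Hypothetical usage of train_cifar; `net` and the config values are made up for illustration.
# train_batch_size must equal micro_batch * gradient_accumulation_steps * data-parallel world size.
ds_config = {
    "train_batch_size": 16,
    "train_micro_batch_size_per_gpu": 4,
    "steps_per_print": 50,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "fp16": {"enabled": True},
}
losses = train_cifar(net, ds_config, num_steps=100)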
Example #2
    def _create_ep_parallel_group(self, moe_experts):
        # Build expert-parallel and expert model-parallel process groups
        self.ep_group = {}
        self.expert_mp_group = {}
        moe_experts = moe_experts if type(moe_experts) is list else [moe_experts]
        for e in moe_experts:
            self.ep_group.update({e: None})
            self.expert_mp_group.update({e: None})
        for moe_ep_size in self.ep_group.keys():
            num_ep_groups = dist.get_world_size() // moe_ep_size
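            # Each expert-parallel (EP) group covers a block of consecutive ranks of
            # size moe_ep_size (capped at the world size); only the group that
            # contains this rank is stored in self.ep_group.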
            for i in range(num_ep_groups):
                ep_cnt = i * moe_ep_size
                size = min(moe_ep_size, dist.get_world_size())
                ranks = list(range(ep_cnt, ep_cnt + size))
                _ep_group = dist.new_group(ranks)
                if dist.get_rank() in ranks:
                    self.ep_group.update({moe_ep_size: _ep_group})

            if dist.get_world_size() > moe_ep_size:
                num_expert_mp_groups = dist.get_world_size() // num_ep_groups
                expert_mp_size = dist.get_world_size() // moe_ep_size
                for i in range(num_expert_mp_groups):
                    expert_mp_comm_ranks = [
                        i + nr * moe_ep_size for nr in range(expert_mp_size)
                    ]
                    _expert_mp_group = dist.new_group(expert_mp_comm_ranks)
                    if dist.get_rank() in expert_mp_comm_ranks:
                        self.expert_mp_group.update({moe_ep_size: _expert_mp_group})
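For intuition, a worked layout under assumed sizes (illustration only, not from the source): with dist.get_world_size() == 4 and a single moe_ep_size of 2, the loops above build

# num_ep_groups        = 4 // 2 = 2  -> EP groups over consecutive ranks: [0, 1] and [2, 3]
# expert_mp_size       = 4 // 2 = 2
# num_expert_mp_groups = 4 // 2 = 2  -> expert model-parallel groups over strided ranks: [0, 2] and [1, 3]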
Example #3
 def load_state_dict(self, state_dict):
     """
     Overrides load_state_dict() to add special handling when loading checkpoints
     """
     # Because exp_avg_mask may change at different training stages (e.g.,
     # BERT pre-training with seqlen 128 and then 512), we don't use the
     # exp_avg_mask stored in checkpoints but always use the one the user
     # provides in the training script.
     # (See example in DeepSpeedExamples/bing_bert/deepspeed_train.py.)
     # Thus we keep exp_avg_mask unchanged when loading a checkpoint.
     for i, group in enumerate(self.param_groups):
         if 'exp_avg_mask' in group:
             state_dict['param_groups'][i]['exp_avg_mask'] = group[
                 'exp_avg_mask']
         elif 'exp_avg_mask' not in group and 'exp_avg_mask' in state_dict[
                 'param_groups'][i]:
             state_dict['param_groups'][i].pop('exp_avg_mask')
     super().load_state_dict(state_dict)
     if self.state[self.param_groups[0]['params'][0]]['step'] < self.freeze_step:
         if dist.get_rank() == 0:
             print(
                 "Checkpoint loaded and OnebitAdam warmup stage starts/continues."
             )
         if self.adam_freeze_key is True:
             self.adam_freeze_key = False
             if self.using_pipeline:
                 self.deepspeed.pipeline_enable_backward_allreduce = True
             else:
                 self.deepspeed.enable_backward_allreduce = True
     else:
         if dist.get_rank() == 0:
             print(
                 "Checkpoint loaded and OnebitAdam compression stage starts/continues."
             )
         if self.adam_freeze_key is False:
             self.adam_freeze_key = True
             if self.using_pipeline:
                 self.deepspeed.pipeline_enable_backward_allreduce = False
             else:
                 self.deepspeed.enable_backward_allreduce = False
     # We reset the compression errors when loading checkpoints for 3 reasons:
     # 1) The worker and server errors on each GPU are distinct, and in the current
     # implementation only rank 0's errors are saved in the checkpoint, so we have to
     # reset the errors. Saving them correctly would need O(num_gpu * model_size) memory
     # to gather all the errors, which is a very large memory requirement. It's possible
     # to save them in a distributed way, but that would make checkpoint saving/loading
     # much more complicated.
     # 2) Even if we were able to save the compression errors correctly, loading them
     # would require the exact same number of GPUs.
     # 3) We verified on BERT pre-training that occasionally resetting the compression
     # errors at checkpoint loading does not affect convergence.
     # However, please avoid frequent checkpoint loading, which could break the error
     # compensation mechanism and thus affect convergence.
     for group in self.param_groups:
         for p in group['params']:
             if 'worker_error' in self.state[p]:
                 self.state[p].pop('worker_error')
             if 'server_error' in self.state[p]:
                 self.state[p].pop('server_error')
Example #4
def test_reduce_scatter_coalesced_tensor_smaller_than_world_sz():
    input = torch.zeros((1, ),
                        dtype=torch.half,
                        device=torch.cuda.current_device())

    (output, ) = reduce_scatter_coalesced([input], dist.get_world_group())
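    # This test assumes a 2-rank world: the single input element ends up in
    # rank 0's shard, while rank 1 is left with an empty (0-element) shard.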

    if dist.get_rank() == 0:
        assert output.shape == (1, )
        assert torch.allclose(output, torch.zeros_like(output))
    elif dist.get_rank() == 1:
        assert output.shape == (0, )
Example #5
    def _initialize_parameters(self, parameters, src_tensors, aio_handle):
        assert len(parameters) == len(src_tensors)

        swap_paths = self._get_swap_paths(
            parameters=parameters,
            num_elems=[src.numel() for src in src_tensors])

        SWAP_INIT_TIMER = "swap_init_write"
        self._start_timer(SWAP_INIT_TIMER)

        pinned_buffers = self.swap_buffer_manager.allocate_all(
            num_elems=self.largest_numel, dtype=self.dtype)
        assert pinned_buffers is not None

        self._swap_out_unpinned_tensors(aio_handle=aio_handle,
                                        unpinned_tensors=src_tensors,
                                        dest_paths=swap_paths,
                                        pinned_buffers=pinned_buffers)

        if dist.get_rank() == 0 and SWAPPER_DEBUG_MODE:
            for i, tensor in enumerate(src_tensors):
                logger.info(
                    f'copy_in_fp16_param: fp32_id = {id(parameters[i])} index = {i}, swap_num_elem = {src_tensors[i].numel()}'
                )

        self.swap_buffer_manager.free(pinned_buffers)

        self._stop_timer(SWAP_INIT_TIMER)
        self._log_timers([SWAP_INIT_TIMER])
Example #6
 def write_events(self, event_list, flush=True):
     if self.enabled and self.summary_writer is not None and dist.get_rank() == 0:
         for event in event_list:
             self.summary_writer.add_scalar(*event)
         if flush:
             self.summary_writer.flush()
Example #7
    def __init__(self, swap_config, aio_config, base_folder, optimizer,
                 largest_numel, device, dtype, timers):
        super(PartitionedOptimizerSwapper,
              self).__init__(swap_config, aio_config, base_folder, optimizer,
                             largest_numel, device, dtype, timers)

        aio_op = AsyncIOBuilder().load()
        self.aio_handle = aio_op.aio_handle(aio_config[AIO_BLOCK_SIZE],
                                            aio_config[AIO_QUEUE_DEPTH],
                                            aio_config[AIO_SINGLE_SUBMIT],
                                            aio_config[AIO_OVERLAP_EVENTS],
                                            aio_config[AIO_THREAD_COUNT])

        # Overlap swapping out
        self.gradient_swapper = AsyncTensorSwapper(
            aio_handle=self.aio_handle,
            numel_alignment=self.numel_alignment,
            timers=self.timers)

        self.print_exclude_list += [
            'aio_handle', 'gradient_swapper', 'print_exclude_list'
        ]

        if dist.get_rank() == 0:
            print_object(obj=self,
                         name='PartitionedOptimizerSwapper',
                         exclude_list=self.print_exclude_list)
Example #8
    def _swap_in_parameter(self, aio_handle, parameter, dest_buffers):
        swap_info = self._get_param_swap_info(parameter)
        if swap_info is None:
            return

        assert len(swap_info.tensors) <= len(dest_buffers)

        swap_lengths = [self._io_aligned_numel(swap_info.numel())] * len(
            swap_info.tensors)
        swap_buffers = get_sized_buffers(dest_buffers, swap_lengths)

        READ_TIMER = 'swap_submit_read_param'
        WAIT_TIMER = 'swap_wait_read_param'

        self._start_timer(READ_TIMER)
        swap_in_tensors(aio_handle, swap_buffers, swap_info.swap_paths)
        self._stop_timer(READ_TIMER)

        swap_bytes = sum([
            buffer.numel() * buffer.element_size() for buffer in swap_buffers
        ])

        self._start_timer(WAIT_TIMER)
        aio_handle.wait()
        self._stop_timer(WAIT_TIMER)

        compute_lengths = [swap_info.numel()] * len(swap_info.tensors)
        compute_buffers = get_sized_buffers(dest_buffers, compute_lengths)
        for t, buffer in zip(swap_info.tensors, compute_buffers):
            t.data = buffer.data

        self._log_timers([READ_TIMER, WAIT_TIMER])
        if DEBUG_MODE and dist.get_rank() == 0:
            logger.info(
                f'optimizer_param_swap_in: {(swap_bytes/(1024**3)):5.2f} GB')
Example #9
 def _report_statistics(self, message):
     if dist.get_rank() == 0:
         element_size = torch.tensor([], dtype=self.dtype).element_size()
         swapped_GB = (self.num_elements_swapped * element_size) / (1024**3)
         logger.debug(
             f'{message} num_elems = {self.num_elements_swapped}, {swapped_GB:5.2f} GB'
         )
Example #10
    def test(self, check_using_norm):
        groups._create_expert_and_data_parallel(2)

        param1 = torch.nn.Parameter(torch.Tensor([0]))
        param1.grad = torch.Tensor([1])
        param2 = torch.nn.Parameter(torch.Tensor([0]))
        if dist.get_rank() == 0:
            param2.grad = torch.Tensor([1])
        else:
            param2.grad = torch.Tensor([float("inf")])
        param2.allreduce = False
        # param2 is now an MoE parameter
        parameters = [param1, param2]
        if check_using_norm:
            grads_group_flat = [
                _flatten_dense_tensors([p.grad for p in parameters])
            ]
            norm = ds_utils.get_weight_norm(grads_group_flat)
            overflow_checker = ds_utils.CheckOverflow([parameters])
            overflow = overflow_checker.check_using_norm([norm],
                                                         reduce_overflow=False)
        else:
            overflow_checker = ds_utils.CheckOverflow([parameters])
            overflow = overflow_checker.check()
        assert overflow
Example #11
def partition_activations_in_checkpoint(partition_activation):
    global PARTITION_ACTIVATIONS
    PARTITION_ACTIVATIONS = partition_activation
    if dist.get_rank() == 0:
        logger.info(
            f"**************Partition Activations {PARTITION_ACTIVATIONS}************"
        )
Example #12
    def _load_checkpoint(self, load_dir, load_module_strict=True, tag=None):
        is_pipe_parallel = isinstance(self.module, PipelineModule)
        if is_pipe_parallel:
            raise RuntimeError(
                'pipeline parallelism is currently not supported in inference.')
        if os.path.isdir(load_dir):
            if tag is None:
                latest_path = os.path.join(load_dir, "latest")
                if os.path.isfile(latest_path):
                    with open(latest_path, "r") as fd:
                        tag = fd.read().strip()

            ckpt_list = self._get_all_ckpt_names(load_dir, tag)
            sd_loader = SDLoaderFactory.get_sd_loader(ckpt_list, self.checkpoint_engine)
        else:
            sd_loader = SDLoaderFactory.get_sd_loader_json(load_dir)

        if type(sd_loader) is list:
            self.sd = torch.load(sd_loader[0], map_location='cpu')
            self.key_list = list(self.sd.keys())

            self.load_model_with_checkpoint(self.module)

            for i in range(1, len(sd_loader)):
                if not dist.is_initialized() or dist.get_rank() == 0:
                    print(f"loading checkpoint ({i})")
                self.sd = torch.load(sd_loader[i], map_location='cuda')
                self.key_list = list(self.sd.keys())
                self.load_model_with_checkpoint(self.module)
        else:
            mp_rank = 0 if self.mpu is None else self.mpu.get_model_parallel_rank()

            load_path, checkpoint, quantize_config = sd_loader.load(
                self.mp_world_size,
                mp_rank,
                is_pipe_parallel=is_pipe_parallel,
                quantize=(self.dtype is torch.int8),
                quantize_groups=self.quantize_groups,
                mlp_extra_grouping=self.mlp_extra_grouping)

            self.quantization_scales, self.quantize_merge_count = quantize_config

            moe, _ = has_moe_layers(self.module)
            if moe:
                from deepspeed.runtime.engine import DeepSpeedEngine
                old_moe_load = False
                if not isinstance(checkpoint['num_experts'], list):
                    old_moe_load = True
                DeepSpeedEngine.load_moe_state_dict(
                    load_dir,
                    tag,
                    state_dict=checkpoint[self._choose_module_key(checkpoint)],
                    old_moe_load=old_moe_load,
                    model=self.module,
                    mpu=self.mpu,
                    checkpoint_engine=self.checkpoint_engine)

            self.module.load_state_dict(
                state_dict=checkpoint[self._choose_module_key(checkpoint)],
                checkpoint_engine=self.checkpoint_engine,
                strict=load_module_strict)
Example #13
def _apply_to_tensors_only(module, functional, backward_function, outputs):
    if isinstance(outputs, (tuple, list)):
        touched_outputs = []
        for output in outputs:
            touched_output = _apply_to_tensors_only(module, functional,
                                                    backward_function, output)
            touched_outputs.append(touched_output)
        return outputs.__class__(touched_outputs)
    elif isinstance(outputs, dict):
        # apply inplace to avoid recreating dict inherited objects
        for key in outputs.keys():
            outputs[key] = _apply_to_tensors_only(module, functional,
                                                  backward_function,
                                                  outputs[key])
        return outputs

    elif type(outputs) is torch.Tensor:
        return functional.apply(module, backward_function, outputs)
    else:
        if not is_builtin_type(outputs):
            global warned
            if not warned and dist.get_rank() == 0:
                logger.warning(
                    f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
                    "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
                    "output tensors and therefore may not get triggered properly."
                )
                warned = True
        return outputs
Example #14
def all_gather_dp_groups(partitioned_param_groups, dp_process_group,
                         start_alignment_factor, allgather_bucket_size):
    for group_id, partitioned_params in enumerate(partitioned_param_groups):
        # Sequential AllGather Best of both worlds
        partition_id = dist.get_rank(group=dp_process_group[group_id])
        dp_world_size = dist.get_world_size(group=dp_process_group[group_id])

        num_shards = max(
            1, partitioned_params[partition_id].numel() * dp_world_size //
            allgather_bucket_size)
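        # Gather the flat partition in num_shards pieces so that each bucketed
        # all_gather moves roughly allgather_bucket_size elements across the
        # data-parallel group.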

        shard_size = partitioned_params[partition_id].numel() // num_shards

        # Enforce nccl/rccl alignment of start location of each shard
        shard_size = shard_size - (shard_size % start_alignment_factor)

        num_elements = shard_size

        assert shard_size * num_shards <= partitioned_params[
            partition_id].numel()

        for shard_id in range(num_shards):

            if shard_id == (num_shards - 1):
                num_elements = partitioned_params[partition_id].numel() - shard_id * shard_size

            shard_list = []
            for dp_id in range(dp_world_size):
                curr_shard = partitioned_params[dp_id].narrow(
                    0, shard_id * shard_size, num_elements).detach()
                shard_list.append(curr_shard)

            dist.all_gather(shard_list, shard_list[partition_id],
                            dp_process_group[group_id])
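A worked sharding example under assumed sizes (illustration only): if the local partition has 10 elements, dp_world_size is 4, allgather_bucket_size is 16, and start_alignment_factor is 2, then

# num_shards = max(1, 10 * 4 // 16) = 2
# shard_size = 10 // 2 = 5, aligned down to 5 - (5 % 2) = 4
# shard 0 gathers elements [0, 4); the last shard gathers the remainder [4, 10), i.e. 6 elements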
Example #15
 def test(self):
     x = torch.ones(1, 3).cuda() * (dist.get_rank() + 1)
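     # all_reduce sums (rank + 1) across ranks: 1 + 2 + ... + world_size = world_size * (world_size + 1) / 2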
     sum_of_ranks = (dist.get_world_size() *
                     (dist.get_world_size() + 1)) // 2
     result = torch.ones(1, 3).cuda() * sum_of_ranks
     dist.all_reduce(x)
     assert torch.all(x == result)
Example #16
def see_memory_usage(message, force=False):
    if not force:
        return
    if dist.is_initialized() and not dist.get_rank() == 0:
        return

    # python doesn't do real-time garbage collection so do it explicitly to get the correct RAM reports
    gc.collect()

    # Print message except when distributed but not rank 0
    logger.info(message)
    logger.info(
        f"MA {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB \
        Max_MA {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB \
        CA {round(torch_memory_reserved() / (1024 * 1024 * 1024),2)} GB \
        Max_CA {round(torch_max_memory_reserved() / (1024 * 1024 * 1024))} GB "
    )

    vm_stats = psutil.virtual_memory()
    used_GB = round(((vm_stats.total - vm_stats.available) / (1024**3)), 2)
    logger.info(
        f'CPU Virtual Memory:  used = {used_GB} GB, percent = {vm_stats.percent}%'
    )

    # Reset the peak-memory counter so the next call reports correct data
    if hasattr(torch.cuda, "reset_peak_memory_stats"):  # pytorch 1.4+
        torch.cuda.reset_peak_memory_stats()
Example #17
    def write_events(self, event_list):
        if self.enabled and dist.get_rank() == 0:
            import csv
            # We assume each event_list element is a tensorboard-style tuple in the format: (log_name: String, value, step: Int)
            for event in event_list:
                log_name = event[0]
                value = event[1]
                step = event[2]

                # Set the header to the log_name
                # Need this check because the deepspeed engine currently formats log strings to separate with '/'
                if '/' in log_name:
                    record_splits = log_name.split('/')
                    header = record_splits[len(record_splits) - 1]
                else:
                    header = log_name

                # sanitize common naming conventions into filename
                filename = log_name.replace('/', '_').replace(' ', '_')
                fname = self.log_dir + '/' + filename + '.csv'

                # Open file and record event. Insert header if this is the first time writing
                with open(fname, 'a+') as csv_monitor_file:
                    csv_monitor_writer = csv.writer(csv_monitor_file)
                    if filename not in self.filenames:
                        self.filenames.append(filename)
                        csv_monitor_writer.writerow(['step', header])
                    csv_monitor_writer.writerow([step, value])
Example #18
def print_json_dist(message, ranks=None, path=None):
    """Write `message` as JSON to `path` when one of the following conditions is met:

    + dist is not initialized
    + ranks[0] == -1
    + dist.get_rank() is in ranks

    Args:
        message (dict)
        ranks (list)
        path (str)
    """
    from deepspeed import comm as dist
    should_log = not dist.is_initialized()
    ranks = ranks or []
    my_rank = dist.get_rank() if dist.is_initialized() else -1
    if ranks and not should_log:
        should_log = ranks[0] == -1
        should_log = should_log or (my_rank in set(ranks))
    if should_log:
        message['rank'] = my_rank
        import json
        with open(path, 'w') as outfile:
            json.dump(message, outfile)
            os.fsync(outfile)
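A small usage sketch (the path and message contents are hypothetical, for illustration only):

# Written by rank 0 when dist is initialized, or unconditionally otherwise.
print_json_dist({"event": "checkpoint_saved", "step": 100}, ranks=[0], path="/tmp/ds_log_rank0.json")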
Example #19
 def _restore_from_bit16_weights(self):
     for i, group in enumerate(self.bf16_groups):
         partition_id = dist.get_rank(group=self.real_dp_process_group[i])
         for bf16_partitions, fp32_partition in zip(
                 self.bf16_partitioned_groups,
                 self.fp32_groups_flat_partition):
             fp32_partition.data.copy_(bf16_partitions[partition_id].data)
Example #20
    def _load_legacy_checkpoint(self,
                                state_dict_list,
                                load_optimizer_states=True,
                                load_from_fp32_weights=False):

        dp_rank = dist.get_rank(group=self.dp_process_group)
        current_rank_sd = state_dict_list[dp_rank]

        ckpt_version = current_rank_sd.get(DS_VERSION, False)
        assert ckpt_version, f"Empty ds_version in checkpoint, not clear how to proceed"
        ckpt_version = pkg_version.parse(ckpt_version)

        self.clip_grad = current_rank_sd.get(CLIP_GRAD, self.clip_grad)

        if load_optimizer_states:
            self.optimizer.load_state_dict(
                current_rank_sd[BASE_OPTIMIZER_STATE])

        if load_from_fp32_weights:
            for current, saved in zip(
                    self.fp32_groups_flat_partition,
                    current_rank_sd[SINGLE_PARTITION_OF_FP32_GROUPS]):
                src_tensor = _get_padded_tensor(saved, current.numel())
                current.data.copy_(src_tensor.data)

        if load_optimizer_states:
            self._link_all_hp_params()
Example #21
 def write_events(self, event_list):
     if self.enabled and dist.get_rank() == 0:
         for event in event_list:
             label = event[0]
             value = event[1]
             step = event[2]
             self.log({label: value}, step=step)
Example #22
 def load(module, state_dict, prefix):
     args = (state_dict, prefix, {}, True, [], [], error_msgs)
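     # Parameters whose numel() is 0 are placeholders partitioned by ZeRO-3 / zero.Init:
     # gather them and let rank 0 populate them from the state_dict. Otherwise copy the
     # (possibly tensor-parallel sliced) weights directly via mp_replace.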
     if len(list(module.parameters())) > 0 and list(
             module.parameters())[0].numel() == 0:
         with GatheredParameters(list(module.parameters(recurse=False)),
                                 modifier_rank=0):
             if dist.get_rank() == 0:
                 module._load_from_state_dict(*args)
     else:
         if hasattr(module, 'weight'):
             if 'query_key_value' in prefix:
                 module.weight = self.mp_replace.qkv_copy(
                     module.weight.data,
                     state_dict[prefix + 'weight'])
             else:
                 module.weight = self.mp_replace.copy(
                     module.weight.data,
                     state_dict[prefix + 'weight'])
         else:
             module.norm.weight = self.mp_replace.copy(
                 module.norm.weight.data,
                 state_dict[prefix + 'weight'])
         if prefix + 'bias' in self.key_list:
             if hasattr(module, 'norm'):
                 module.norm.bias = self.mp_replace.copy(
                     module.norm.bias,
                     state_dict[prefix + 'bias'])
             else:
                 data = state_dict[prefix + 'bias']
                 data = data.to(torch.cuda.current_device())
                 module.bias = self.mp_replace.copy(module.bias, data)
Example #23
    def _distributed_test():
        ds_cfg = {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": {
                "stage": 3,
                "stage3_max_reuse_distance": 0,
                "contiguous_gradients": True,
                "overlap_comm": True,
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1.
                }
            },
            "fp16": {
                "enabled": True,
                "loss_scale": 1.,
            }
        }

        with deepspeed.zero.Init(config=ds_cfg,
                                 mem_efficient_linear=False,
                                 enabled=init_context_manager):
            model = ManyParamModel()

        ds_engine = _ds_initialize_for_param_partitioning_testing(
            model, ds_cfg)

        for _ in range(3):  # test multiple iterations to cover prefetching
            activations: List[Tensor] = ds_engine(
                torch.ones((param_sz, ),
                           dtype=torch.float16,
                           device=ds_engine.device))
            assert len(activations) == n_layers

            partition_sz = math.ceil(param_sz / world_sz)
            expected_activations = torch.empty(param_sz,
                                               dtype=torch.float16,
                                               device=ds_engine.device)
            for start_idx in range(0, param_sz, partition_sz):
                expected_activations[start_idx:start_idx +
                                     partition_sz] = dist.get_rank()

            for layer_num, activation in enumerate(activations):
                expected_activations *= 2 * layer_num
                assert torch.allclose(activation, expected_activations)

            # TODO. finish writing this test
            ds_engine.backward(activations[-1].sum())

            avgd_gradients = ds_engine.optimizer.averaged_gradients
            assert set(avgd_gradients.keys()) == {
                0
            }, "should only have one parameter group"
            weight_gradients: List[Tensor] = avgd_gradients[0]

            for layer_num, activation in enumerate(weight_gradients):
                pass
Example #24
 def write_events(self, event_list):
     if dist.get_rank() == 0:
         if self.tb_monitor is not None:
             self.tb_monitor.write_events(event_list)
         if self.wandb_monitor is not None:
             self.wandb_monitor.write_events(event_list)
         if self.csv_monitor is not None:
             self.csv_monitor.write_events(event_list)
Example #25
 def __init__(self, mp_group=None, mp_size=1, out_dim=1, in_dim=0):
     if mp_group is not None:
         self.gpu_index = dist.get_rank(group=mp_group)
     else:
         self.gpu_index = 0
     self.out_dim = out_dim
     self.in_dim = in_dim
     self.mp_size = mp_size
Example #26
def create_deepspeed_args():
    parser = argparse.ArgumentParser()
    args = parser.parse_args(args='')
    args.deepspeed = True
    if dist.is_initialized():
        # We assume up to one full node executing unit tests
        assert dist.get_world_size() <= torch.cuda.device_count()
        args.local_rank = dist.get_rank()
    return args
Example #27
    def __init__(self, tensor, group, partition_meta=None):
        super().__init__()

        self.group = group
        self.num_parts = dist.get_world_size(group=self.group)
        self.rank = dist.get_rank(group=self.group)

        self.orig_size = list(tensor.size())
        self.orig_device = tensor.device
        self.local_data, self.partition = self._partition_tensor(tensor)
Example #28
def test_reduce_scatter_coalesced_single_input():
    input = torch.full((6, ),
                       dist.get_rank(),
                       dtype=torch.half,
                       device=torch.cuda.current_device())

    (output, ) = reduce_scatter_coalesced([input], dist.get_world_group())
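    # With the 2-rank world this test assumes, each rank receives a 3-element
    # shard of the 6-element input, and the averaged value is (0 + 1) / 2 = 0.5.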

    assert output.shape == (3, )
    assert torch.allclose(output, torch.full_like(output, 0.5))
Example #29
 def _link_all_hp_params(self):
     dp_world_size = dist.get_world_size(group=self.dp_process_group)
     for i, param_group in enumerate(self.optimizer.param_groups):
         # Link bf16 and fp32 params in partition
         partition_id = dist.get_rank(group=self.real_dp_process_group[i])
         partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
         self._link_hp_params(self.bf16_groups[i],
                              self.fp32_groups_flat_partition[i],
                              partition_id * partition_size, partition_size,
                              self.real_dp_process_group[i])
Example #30
    def __init__(self,
                 init_optimizer,
                 param_names,
                 mpu=None,
                 clip_grad=0.0,
                 norm_type=2,
                 allgather_bucket_size=5000000000,
                 dp_process_group=None,
                 timers=None):
        super().__init__()
        see_memory_usage('begin bf16_optimizer', force=True)
        self.timers = timers
        self.optimizer = init_optimizer
        self.param_names = param_names
        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

        self.clip_grad = clip_grad
        self.norm_type = norm_type
        self.mpu = mpu
        self.allgather_bucket_size = int(allgather_bucket_size)
        self.dp_process_group = dp_process_group
        self.dp_rank = dist.get_rank(group=self.dp_process_group)
        self.real_dp_process_group = [
            dp_process_group for i in range(len(self.optimizer.param_groups))
        ]

        # Load pre-built or JIT compile (un)flatten ops
        util_ops = UtilsBuilder().load()
        self.flatten = util_ops.flatten
        self.unflatten = util_ops.unflatten

        # Align nccl all-gather send buffers to a 4-byte boundary
        self.nccl_start_alignment_factor = 2  # 4-byte alignment/sizeof(fp16) = 2

        # Build BF16/FP32 groups
        self.bf16_groups = []
        self.bf16_groups_flat = []
        self.bf16_partitioned_groups = []

        self.fp32_groups_flat_partition = []

        # Maintain different fp32 gradients views for convenience
        self.fp32_groups_gradients = []
        self.fp32_groups_gradients_flat = []
        self.fp32_groups_actual_gradients_flat = []
        self.fp32_groups_gradient_flat_partition = []
        self.fp32_groups_has_gradients = []

        self.step_count = 0
        self.group_paddings = []

        if self.using_real_optimizer:
            self._setup_for_real_optimizer()

        see_memory_usage('end bf16_optimizer', force=True)