Code Example #1
    def __init__(self,
                 module,
                 timers,
                 ds_config,
                 overlap_comm=True,
                 prefetch_bucket_size=50000000,
                 max_reuse_distance=1000000000,
                 max_live_parameters=1000000000,
                 param_persistence_threshold=100000,
                 model_persistence_threshold=sys.maxsize,
                 offload_param_config=None,
                 mpu=None):

        see_memory_usage("DeepSpeedZeRoOffload initialize [begin]", force=True)

        print_rank_0(f"initialized {__class__.__name__} with args: {locals()}",
                     force=False)

        self.module = module
        self.dtype = list(module.parameters())[0].dtype
        self.offload_device = None
        self.offload_param_pin_memory = False

        if offload_param_config is not None and offload_param_config.device != OffloadDeviceEnum.none:
            self.offload_device = offload_param_config.device
            self.offload_param_pin_memory = offload_param_config.pin_memory

        self._convert_to_zero_parameters(ds_config, module, mpu)

        for m in module.modules():
            _init_external_params(m)

        _inject_parameters(module, ZeROOrderedDict)

        self.param_numel_persistence_threshold = int(
            param_persistence_threshold)
        self.model_persistence_threshold = int(model_persistence_threshold)
        self.persistent_parameters = self.mark_persistent_parameters(
            self.param_numel_persistence_threshold,
            self.model_persistence_threshold)

        self.param_coordinators = {}
        self._prefetch_bucket_sz = int(prefetch_bucket_size)
        self._max_reuse_distance_in_numel = int(max_reuse_distance)
        self._max_available_parameters_in_numel = int(max_live_parameters)
        self.__allgather_stream = Stream() if overlap_comm else torch.cuda.default_stream()

        self.forward_hooks = []
        self.backward_hooks = []
        self.setup_zero_stage3_hooks()
        print_rank_0(
            f'Created module hooks: forward = {len(self.forward_hooks)}, backward = {len(self.backward_hooks)}',
            force=False)

        see_memory_usage("DeepSpeedZeRoOffload initialize [end]", force=True)
Code Example #2
    def __init__(self,
                 init_optimizer,
                 param_names,
                 mpu=None,
                 clip_grad=0.0,
                 norm_type=2,
                 allgather_bucket_size=5000000000,
                 dp_process_group=None,
                 timers=None):
        super().__init__()
        see_memory_usage('begin bf16_optimizer', force=True)
        self.timers = timers
        self.optimizer = init_optimizer
        self.param_names = param_names
        self.using_real_optimizer = not isinstance(self.optimizer, DummyOptim)

        self.clip_grad = clip_grad
        self.norm_type = norm_type
        self.mpu = mpu
        self.allgather_bucket_size = int(allgather_bucket_size)
        self.dp_process_group = dp_process_group
        self.dp_rank = dist.get_rank(group=self.dp_process_group)
        self.real_dp_process_group = [
            dp_process_group for i in range(len(self.optimizer.param_groups))
        ]

        # Load pre-built or JIT compile (un)flatten ops
        util_ops = UtilsBuilder().load()
        self.flatten = util_ops.flatten
        self.unflatten = util_ops.unflatten

        # align nccl all-gather send buffers to a 4-byte boundary
        self.nccl_start_alignment_factor = 2  # 4-byte alignment / sizeof(bf16) = 2

        # Build BF16/FP32 groups
        self.bf16_groups = []
        self.bf16_groups_flat = []
        self.bf16_partitioned_groups = []

        self.fp32_groups_flat_partition = []

        # Maintain different fp32 gradients views for convenience
        self.fp32_groups_gradients = []
        self.fp32_groups_gradients_flat = []
        self.fp32_groups_actual_gradients_flat = []
        self.fp32_groups_gradient_flat_partition = []
        self.fp32_groups_has_gradients = []

        self.step_count = 0
        self.group_paddings = []

        if self.using_real_optimizer:
            self._setup_for_real_optimizer()

        see_memory_usage('end bf16_optimizer', force=True)
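The flatten/unflatten ops loaded via UtilsBuilder behave like torch's reference helpers, and nccl_start_alignment_factor exists so each flat buffer divides evenly into aligned per-rank partitions. A short sketch of both, using torch._utils (the usual fallback path) with illustrative sizes:

    import torch
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    # Round trip: flatten a parameter group, then view it back in original shapes.
    params = [torch.randn(3, 4), torch.randn(5)]
    flat = _flatten_dense_tensors(params)            # 1-D buffer, 17 elements
    views = _unflatten_dense_tensors(flat, params)   # same shapes as params
    assert all(torch.equal(p, v) for p, v in zip(params, views))

    # Alignment: pad numel to a multiple of alignment_factor * dp_world_size
    # so every rank can own an equal, aligned partition.
    alignment_factor, dp_world_size = 2, 4
    align = alignment_factor * dp_world_size         # 8
    remainder = flat.numel() % align                 # 17 % 8 = 1
    if remainder:
        flat = torch.cat([flat, flat.new_zeros(align - remainder)])  # 24 elements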
Code Example #3
    def post_sub_module_backward_function(self, sub_module):
        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
            force=False)

        self.get_param_coordinator(
            training=sub_module.training).release_sub_module(sub_module)

        see_memory_usage(
            f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} after release",
            force=False)
Code Example #4
    def pre_sub_module_forward_function(self, sub_module):
        see_memory_usage(
            f"Before sub module function {sub_module.__class__.__name__}",
            force=False)

        global FWD_MODULE_STACK
        FWD_MODULE_STACK.append(sub_module)

        param_coordinator = self.get_param_coordinator(
            training=sub_module.training)
        param_coordinator.trace_prologue(sub_module)
        if param_coordinator.is_record_trace():
            param_coordinator.record_module(sub_module)
        param_coordinator.fetch_sub_module(sub_module)

        see_memory_usage(
            f"Before sub module function {sub_module.__class__.__name__} after fetch",
            force=False)
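trace_prologue, record_module, and fetch_sub_module implement a record-then-replay scheme: the first iteration records the submodule execution order, and later iterations consult that trace to gather parameters before they are needed. A toy illustration of the idea (hypothetical class, not the DeepSpeed coordinator API):

    class ToyParamCoordinator:
        # Iteration 0 records module order; afterwards the trace tells us
        # which modules (and hence which parameters) are coming next.
        def __init__(self):
            self.trace = []
            self.recording = True

        def before_sub_module(self, module):
            if self.recording:
                self.trace.append(module)
                return []
            # Replay: a real coordinator would launch async all-gathers for
            # the parameters of the next few traced modules here, overlapping
            # communication with the current module's compute.
            idx = self.trace.index(module)
            return self.trace[idx + 1:idx + 3]

        def end_iteration(self):
            self.recording = False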
Code Example #5
    def _post_init_method(self, module):
        print_rank_0(f'Converting params in {module.__class__.__name__}', force=False)
        see_memory_usage(
            f"Before converting and partitioning params in {module.__class__.__name__}",
            force=False)

        global param_count
        for name, param in module.named_parameters(recurse=False):
            param_count += param.numel()
            if not is_zero_param(param):
                self._convert_to_deepspeed_param(param)
                print_rank_0(
                    f"Partitioning param with ds id {param.ds_id} and shape {param.data.shape}"
                )
                param.partition()
        see_memory_usage(
            f"Param count {param_count}. After converting and partitioning params in {module.__class__.__name__}",
            force=False)
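param.partition() leaves each rank holding only its shard of the flattened parameter (param.ds_tensor), with the tail zero-padded so every shard has the same size. A toy version of the slicing (hypothetical helper; the real ZeRO-3 path also frees the full tensor and handles offload):

    import torch

    def toy_partition(flat_param, world_size, rank):
        # Each rank keeps ceil(numel / world_size) elements; the final shard
        # is zero-padded so all shards share one size, like param.ds_tensor.
        partition_size = (flat_param.numel() + world_size - 1) // world_size
        shard = torch.zeros(partition_size, dtype=flat_param.dtype)
        start = rank * partition_size
        if start < flat_param.numel():
            elements = min(partition_size, flat_param.numel() - start)
            shard[:elements] = flat_param.view(-1)[start:start + elements]
        return shard

    p = torch.arange(10.0)                      # 10 elements across 4 ranks
    shards = [toy_partition(p, 4, r) for r in range(4)]
    # partition_size is 3; rank 3 stores one real element plus two zeros of padding.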
Code Example #6
def model_provider():
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    see_memory_usage(f"Before Building Model", force=True)
    with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(),
                             remote_device=get_args().remote_device,
                             deepspeed_config=get_args().deepspeed_config,
                             enabled=get_args().zero_stage == 3):
        model = GPT2Model(num_tokentypes=0, parallel_output=True)
    see_memory_usage(f"After Building Model", force=True)

    if mpu.get_data_parallel_rank() == 0:
        billion_params = get_parameters_in_billions(model)
        print(f' > number of parameters on model parallel rank '
              f'{mpu.get_model_parallel_rank()}: {round(billion_params, 3)} Billion',
              flush=True)

    return model
Code Example #7
File: training.py  Project: Chen-Chang/Megatron-LM
def train_step(forward_step_func, data_iterator, model, optimizer,
               lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    see_memory_usage(f'before forward {model.global_steps}', force=True)
    # Forward model for one step.
    timers('forward').start()
    loss, loss_reduced = forward_step_func(data_iterator, model)
    timers('forward').stop()

    see_memory_usage(f'before backward {model.global_steps}', force=True)
    # Calculate gradients, reduce across processes, and clip.
    timers('backward').start()
    backward_step(optimizer, model, loss)
    timers('backward').stop()

    see_memory_usage(f'before optimizer {model.global_steps}', force=True)
    # Update parameters.
    skipped_iter = 0
    timers('optimizer').start()
    if args.deepspeed:
        model.step()
    else:
        optimizer.step()
        # Update learning rate.
        if not (args.fp16 and optimizer.overflow):
            lr_scheduler.step()
        else:
            skipped_iter = 1
    timers('optimizer').stop()

    return loss_reduced, skipped_iter
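The step interleaves see_memory_usage probes with named timers around forward, backward, and optimizer. Megatron's get_timers() returns a registry along the lines of this minimal sketch (an approximation; the real Timers also synchronizes CUDA and can log across ranks):

    import time

    class SimpleTimers:
        # Minimal stand-in for Megatron-style named, accumulating timers.
        class _Timer:
            def __init__(self):
                self.elapsed = 0.0
                self._start = None

            def start(self):
                self._start = time.time()

            def stop(self):
                self.elapsed += time.time() - self._start
                self._start = None

        def __init__(self):
            self._timers = {}

        def __call__(self, name):
            return self._timers.setdefault(name, SimpleTimers._Timer())

    timers = SimpleTimers()
    timers('forward').start()
    # ... run the forward pass ...
    timers('forward').stop()
    print(f"forward took {timers('forward').elapsed:.4f}s")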
Code Example #8
def _go(hidden_dim):
    with deepspeed.zero.Init(enabled=zero_stage == 3,
                             config_dict_or_path=ds_config):
        model = SimpleModel(hidden_dim, nlayers=78)
    print('total number of parameters:',
          sum([p.numel() for p in model.parameters()]))
    see_memory_usage('pre-init', force=True)
    model, _, _, _ = deepspeed.initialize(model=model, config=ds_config)
    see_memory_usage('post-init', force=True)
    data_loader = random_dataloader(model=model,
                                    total_samples=50,
                                    hidden_dim=hidden_dim,
                                    device=model.device,
                                    dtype=torch.half)
    print(f"optimizer={model.optimizer}")
    for batch in data_loader:
        model(batch[0], batch[1])
    see_memory_usage('post-fwds', force=True)
Code Example #9
    def _partition_gradient(self,
                            param,
                            partition_buffer=None,
                            accumulate=False):
        print_rank_0(
            f"Partitioning param {id(param)} gradient of size {param.grad.numel()} type {param.grad.dtype} part_size {param.ds_tensor.numel()}"
        )
        see_memory_usage("Before partitioning gradients", force=False)
        partition_size = param.ds_tensor.numel()

        if partition_buffer is None:
            assert not accumulate, "No buffer to accumulate to"
            partition_buffer = torch.zeros(partition_size,
                                           dtype=param.dtype,
                                           device=param.device)
        else:
            assert partition_buffer.numel() >= partition_size, \
                f"The partition buffer size {partition_buffer.numel()} should be at least the param.ds_tensor size {partition_size}"

        rank = torch.distributed.get_rank(group=self.ds_process_group)
        start = partition_size * rank
        end = start + partition_size

        # View of this rank's full partition; bound before the guard below so
        # the final reassignment works even when this rank holds only padding.
        dest_tensor_full_buffer = partition_buffer.view(-1).narrow(0, 0, partition_size)

        if start < param.ds_numel:
            elements = min(param.ds_numel - start, partition_size)

            dest_tensor = dest_tensor_full_buffer.narrow(0, 0, elements)
            src_tensor = param.grad.view(-1).narrow(0, start, elements)

            # just copy the grad partition to the buffer
            if not accumulate:
                dest_tensor.copy_(src_tensor)

            # if source and destination are on the same device,
            # add into the provided buffer in place
            elif src_tensor.device == dest_tensor.device:
                dest_tensor.add_(src_tensor)

            # if source and destination are on different devices, copy the
            # destination into a staging tensor on the source device, add there,
            # and move the result back. This runs faster when src is gpu and
            # dest is cpu, since adding directly into cpu memory is very slow
            else:
                acc_tensor = torch.empty(src_tensor.numel(),
                                         dtype=param.dtype,
                                         device=param.device)

                acc_tensor.copy_(dest_tensor)
                acc_tensor.add_(src_tensor)
                dest_tensor.copy_(acc_tensor)


        #print("after partition gradients")
        param.grad.data = dest_tensor_full_buffer.data
        see_memory_usage("After partitioning gradients", force=False)
Code Example #10
    def backward(ctx, *grads):
        global timers
        see_memory_usage("In backward", force=False)
        # removing pointers to the contiguous buffer memory
        # so that they can be garbage collected once the checkpoints
        # have been used
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        if PROFILE_TIME:
            timers('backward').start()

        if CONTIGUOUS_CHECKPOINTING:
            global data_offsets, size_offsets
            global contiguous_data_buffers, contiguous_size_buffers

            for buffers in contiguous_data_buffers:
                buffers.clear()  # clear in place so aliased references also release tensors

            # frees up all the pointers to the checkpoints except for the ones
            # stored by save for backward
            contiguous_data_buffers = []
            contiguous_size_buffers = []
            data_offsets = []
            size_offsets = []

        see_memory_usage("In backward checkpointing code", force=False)
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS

        if PARTITION_ACTIVATIONS:
            # with torch.cuda.stream(transport_stream):
            inputs = gather_partitioned_activations(
                ctx.saved_tensors,
                device=cuda_device if CPU_CHECKPOINT else None)
            detached_inputs = detach_variable(inputs)
        elif CPU_CHECKPOINT:
            inputs = move_to_device(ctx.saved_tensors, cuda_device,
                                    is_activation_to_checkpoint)
            detached_inputs = detach_variable(inputs)
        else:
            inputs = ctx.saved_tensors
            detached_inputs = detach_variable(inputs)

        # Add non tensor input args
        detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                        non_tensor_objects=ctx.non_tensor_args,
                                        tensor_flags=ctx.tensor_flags)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # if PARTITION_ACTIVATIONS:
        #     current_stream=torch.cuda.current_stream()
        #     current_stream.wait_stream(transport_stream)

        see_memory_usage("In backward checkpointing code before forward",
                         force=False)

        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        see_memory_usage("In backward checkpointing code after forward",
                         force=False)
        # Set the states back to what it was at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # Filter out non tensor outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)

        # Construct arguments to autograd.backward().
        # This is usually just outputs and grads, but forward() can return tensors that
        # are not differentiable.
        output_tensors = []
        grad_tensors = []
        for out, grad in zip(outputs, grads):
            if out.requires_grad:
                output_tensors.append(out)
                grad_tensors.append(grad)

        see_memory_usage("In backward checkpointing code before backward",
                         force=False)

        torch.autograd.backward(output_tensors, grad_tensors)

        see_memory_usage("After backward checkpointing code after backward",
                         force=False)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        ret_list = [None, None]  # first None for ctx
        for inp in detached_inputs:
            if torch.is_tensor(inp):
                ret_list.append(inp.grad)
            else:
                ret_list.append(None)

        return tuple(ret_list)
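The RNG bookkeeping is what makes recomputation bit-for-bit identical to the original forward (the same dropout masks, for example). A standalone version of the save/rewind/restore pattern, CPU RNG only for brevity (the checkpointing code above also restores the CUDA RNG state and tracker):

    import torch

    def recompute_with_forward_rng(run_function, inputs, fwd_cpu_rng_state):
        # Save the current RNG state, rewind to the forward-time state,
        # recompute under grad, then restore so surrounding code is unaffected.
        bwd_cpu_rng_state = torch.get_rng_state()
        torch.set_rng_state(fwd_cpu_rng_state)
        try:
            with torch.enable_grad():
                return run_function(*inputs)
        finally:
            torch.set_rng_state(bwd_cpu_rng_state)

    drop = torch.nn.Dropout(p=0.5)
    x = torch.randn(4, requires_grad=True)
    fwd_state = torch.get_rng_state()
    y1 = drop(x)                                             # original forward
    y2 = recompute_with_forward_rng(drop, (x,), fwd_state)   # recomputation
    assert torch.equal(y1, y2)                               # identical dropout mask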
Code Example #11
    def forward(ctx, run_function, all_outputs, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(
                all_objects=all_args)
            ctx.save_for_backward(*tensor_args)
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            torch.cuda.synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers('forward').start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        if mp_rank is None:
            if mpu is not None:
                if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                    mp_rank = mpu.get_tensor_model_parallel_rank()
                    mp_size = mpu.get_tensor_model_parallel_world_size()
                    mp_group = mpu.get_tensor_model_parallel_group()
                else:
                    mp_rank = mpu.get_model_parallel_rank()
                    mp_size = mpu.get_model_parallel_world_size()
                    mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        if cuda_device is None:
            see_memory_usage("First Forward Beginning", force=False)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(
                    f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}"
                )
                logger.info(
                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
                )
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(
                    f"----Profiling time in checkpointing {PROFILE_TIME}")

            cuda_device = torch.cuda.current_device()
            transport_stream = torch.cuda.Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            inputs = partition_activations(args, CPU_CHECKPOINT,
                                           CONTIGUOUS_CHECKPOINTING)
        elif CPU_CHECKPOINT:
            inputs = copy_to_device(args,
                                    device=torch.device('cpu'),
                                    criterion_func=is_activation_to_checkpoint)

        # just in case something funky is happening such as reuse of inputs
        inputs_cuda = copy_to_device(
            args,
            device=cuda_device,
            criterion_func=is_activation_to_checkpoint)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda

        if PARTITION_ACTIVATIONS:
            new_args = get_partitioned_activations_for_backward(
                args, inputs, CONTIGUOUS_CHECKPOINTING)
            assert len(
                new_args
            ) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
            save_args_for_backward(*new_args)
        elif CPU_CHECKPOINT:
            new_args = get_cpu_activations_for_backward(args, inputs)
            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers('forward').stop()
            timers.log(['forward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [
                o for o in outputs
                if torch.is_tensor(o) and not o.is_floating_point()
            ]
        ctx.mark_non_differentiable(*non_grad_outputs)

        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)
Code Example #12
    def forward(ctx, run_function, all_outputs, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(
                all_objects=all_args)
            ctx.save_for_backward(*tensor_args)
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            torch.cuda.synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers('forward').start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        if mp_rank is None:
            if mpu is not None:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        if cuda_device is None:
            see_memory_usage("First Forward Begining", force=False)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(
                    f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
                )
                logger.info(
                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
                )
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(f"----Profiling {PROFILE_TIME}")

            cuda_device = torch.cuda.current_device()
            transport_stream = torch.cuda.Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            inputs = []
            for i, item in enumerate(args[:-1]):
                if not torch.is_tensor(item):
                    inputs.append(item)
                    continue

                partition_size = get_partition_size(item)
                partition = item.detach().contiguous().view(-1).narrow(
                    0, get_partition_start(item), partition_size).clone()

                if CONTIGUOUS_CHECKPOINTING:
                    buffer_device = torch.device(
                        'cpu') if PA_TO_CPU else partition.device

                    if i >= len(contiguous_data_buffers):
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for i in range(num_layers)
                        ]
                        contiguous_data_buffers.append(tensor_list)
                        data_offsets.append(0)
                    elif contiguous_data_buffers[i] is None:
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for i in range(num_layers)
                        ]
                        contiguous_data_buffers[i] = tensor_list
                        data_offsets[i] = 0

                    # Because 'new_empty' returns uninitialized pages, the pages
                    # would otherwise be populated during the cudaMemcpy itself,
                    # inflating the copy time. To avoid this, pre-populate the
                    # pages by writing 0 ahead of the actual cudaMemcpy. Thanks
                    # to the previously launched GPU kernels, there is a small
                    # window here for the CPU to populate pages asynchronously.
                    contiguous_data_buffers[i][data_offsets[i]].data[range(
                        0, contiguous_data_buffers[i][
                            data_offsets[i]].data.shape[0],
                        int(mmap.PAGESIZE / contiguous_data_buffers[i][
                            data_offsets[i]].data.element_size()))] = 0

                    contiguous_partition = contiguous_data_buffers[i][
                        data_offsets[i]].data.copy_(partition.data)
                    data_offsets[i] = data_offsets[i] + 1
                    inputs.append(contiguous_partition)
                else:
                    partition = partition.cpu() if PA_TO_CPU else partition
                    inputs.append(partition)

            inputs.append(args[-1])

        # just in case something funky is happening such as reuse of inputs
        inputs_cuda = move_to_device(args, cuda_device)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda


        if PARTITION_ACTIVATIONS:
            new_args = []
            for i, (arg, inp) in enumerate(zip(args, inputs)):
                if not torch.is_tensor(arg):
                    new_args.append(arg)
                    continue

                size = torch.tensor(arg.size())

                arg.data = inp.data
                new_args.append(arg)

                if CONTIGUOUS_CHECKPOINTING:
                    numel = size.numel()
                    if i >= len(contiguous_size_buffers):
                        tmp = torch.tensor(())
                        contiguous_size_buffers.append(
                            tmp.new_empty([numel * num_layers],
                                          dtype=size.dtype,
                                          device=size.device))
                        size_offsets.append(0)
                    elif contiguous_size_buffers[i] is None:
                        tmp = torch.tensor(())
                        contiguous_size_buffers[i] = tmp.new_empty(
                            [numel * num_layers],
                            dtype=size.dtype,
                            device=size.device)
                        size_offsets[i] = 0

                    contiguous_size = contiguous_size_buffers[i].narrow(
                        0, size_offsets[i], numel).data.copy_(size.data)
                    contiguous_size = contiguous_size.view_as(size)
                    size_offsets[i] = size_offsets[i] + numel
                    new_args.append(contiguous_size)
                else:
                    new_args.append(size)

            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers('forward').stop()
            timers.log(['forward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [
                o for o in outputs
                if torch.is_tensor(o) and not o.is_floating_point()
            ]
        ctx.mark_non_differentiable(*non_grad_outputs)

        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)
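The zero-writes into contiguous_data_buffers above are a page pre-touch: new_empty returns lazily backed pages, and touching one element per OS page populates them before the timed copy rather than during it. The trick in isolation (illustrative buffer size; the actual backing behavior depends on the allocator):

    import mmap
    import torch

    buf = torch.empty(1 << 20, dtype=torch.float32)  # CPU buffer, pages not yet touched
    stride = mmap.PAGESIZE // buf.element_size()     # elements per OS page
    buf[::stride] = 0                                # touch one element per page
    # A subsequent buf.copy_(some_gpu_tensor) then avoids paying page-fault
    # cost inside the copy itself.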
Code Example #13
    def _setup_for_real_optimizer(self):
        dp_world_size = dist.get_world_size(group=self.dp_process_group)
        self.partition_count = [
            dp_world_size for i in range(len(self.optimizer.param_groups))
        ]

        for i, param_group in enumerate(self.optimizer.param_groups):
            see_memory_usage(f'before initializing group {i}', force=True)

            partition_id = dist.get_rank(group=self.real_dp_process_group[i])

            # grab the original list
            self.bf16_groups.append(param_group['params'])

            # create flat bf16 params
            self.bf16_groups_flat.append(
                self._flatten_dense_tensors_aligned(
                    self.bf16_groups[i],
                    self.nccl_start_alignment_factor * dp_world_size))

            # Make bf16 params point to flat tensor storage
            self._update_storage_to_flattened_tensor(
                tensor_list=self.bf16_groups[i],
                flat_tensor=self.bf16_groups_flat[i])

            # divide flat weights into equal sized partitions
            partition_size = self.bf16_groups_flat[i].numel() // dp_world_size
            bf16_dp_partitions = [
                self.bf16_groups_flat[i].narrow(0, dp_index * partition_size,
                                                partition_size)
                for dp_index in range(dp_world_size)
            ]
            self.bf16_partitioned_groups.append(bf16_dp_partitions)

            # create fp32 params partition
            self.fp32_groups_flat_partition.append(
                bf16_dp_partitions[partition_id].clone().float().detach())
            self.fp32_groups_flat_partition[i].requires_grad = True

            num_elem_list = [t.numel() for t in self.bf16_groups[i]]

            # create fp32 gradients
            self.fp32_groups_gradients_flat.append(
                torch.zeros_like(self.bf16_groups_flat[i],
                                 dtype=torch.float32))

            # track individual fp32 gradients for entire model
            fp32_gradients = self._split_flat_tensor(
                flat_tensor=self.fp32_groups_gradients_flat[i],
                num_elem_list=num_elem_list)
            self.fp32_groups_gradients.append(fp32_gradients)

            # flat tensor corresponding to actual fp32 gradients (i.e., minus alignment padding)
            length_without_padding = sum(num_elem_list)
            self.fp32_groups_actual_gradients_flat.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0, 0,
                             length_without_padding))

            # flat tensor corresponding to gradient partition
            self.fp32_groups_gradient_flat_partition.append(
                torch.narrow(self.fp32_groups_gradients_flat[i], 0,
                             partition_id * partition_size, partition_size))

            # track fp32 gradient updates
            self.fp32_groups_has_gradients.append([False] *
                                                  len(self.bf16_groups[i]))

            # Record padding required for alignment
            if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
                padding = self.bf16_groups_flat[i].numel() - length_without_padding
            else:
                padding = 0

            self.group_paddings.append(padding)

            # update optimizer param groups to reference fp32 params partition
            param_group['params'] = [self.fp32_groups_flat_partition[i]]

            see_memory_usage(f'after initializing group {i}', force=True)

        see_memory_usage('before initialize_optimizer', force=True)
        self.initialize_optimizer_states()
        see_memory_usage('end initialize_optimizer', force=True)

        # Need optimizer states initialized before linking lp to optimizer state
        self._link_all_hp_params()
        self._param_slice_mappings = self._create_param_mapping()
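Each flat bf16 group is carved into dp_world_size equal narrow() views, and only the local partition gets an fp32 master copy for the wrapped optimizer to update. A toy version of the carving with no distributed setup (illustrative sizes):

    import torch

    dp_world_size, partition_id = 4, 1               # toy values for one rank
    flat = torch.randn(20, dtype=torch.bfloat16)     # already padded to a multiple of 4
    partition_size = flat.numel() // dp_world_size   # 5 elements per rank
    partitions = [flat.narrow(0, r * partition_size, partition_size)
                  for r in range(dp_world_size)]     # views into flat, no copies
    # fp32 master weights exist only for the local shard; the optimizer's
    # param group is pointed at this tensor instead of the bf16 params.
    fp32_local = partitions[partition_id].clone().float().detach()
    fp32_local.requires_grad = True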