Example #1
    def backward(ctx, *grads):
        global timers
        see_memory_usage("In backward", force=False)
        # removing pointers to the contiguous buffer memory
        # so that they can be garbage collected once the checkpoints
        # have been used
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        if PROFILE_TIME:
            timers('backward').start()

        if CONTIGUOUS_CHECKPOINTING:
            global data_offsets, size_offsets
            global contiguous_data_buffers, contiguous_size_buffers

            for buffers in contiguous_data_buffers:
                buffers.clear()

            # Frees up all the pointers to the checkpoints except for the ones
            # stored by save_for_backward.
            contiguous_data_buffers = []
            contiguous_size_buffers = []
            data_offsets = []
            size_offsets = []

        see_memory_usage("In backward checkpointing code", force=False)
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS

        if PARTITION_ACTIVATIONS:
            # with torch.cuda.stream(transport_stream):
            inputs = gather_partitioned_activations(
                ctx.saved_tensors,
                device=cuda_device if CPU_CHECKPOINT else None)
            detached_inputs = detach_variable(inputs)
        elif CPU_CHECKPOINT:
            inputs = move_to_device(ctx.saved_tensors, cuda_device,
                                    is_activation_to_checkpoint)
            detached_inputs = detach_variable(inputs)
        else:
            inputs = ctx.saved_tensors
            detached_inputs = detach_variable(inputs)

        # Add back the non-tensor input args
        detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                        non_tensor_objects=ctx.non_tensor_args,
                                        tensor_flags=ctx.tensor_flags)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what they were before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # if PARTITION_ACTIVATIONS:
        #     current_stream=torch.cuda.current_stream()
        #     current_stream.wait_stream(transport_stream)

        see_memory_usage("In backward checkpointing code before forward",
                         force=False)

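        # Re-run the checkpointed forward with grad enabled so a local graph is built.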
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        see_memory_usage("In backward checkpointing code after forward",
                         force=False)
        # Set the states back to what they were at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs, )

        # Filter out non-tensor outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)

        # Construct arguments to autograd.backward().
        # This is usually just outputs and grads, but forward() can return tensors that
        # are not differentiable.
        output_tensors = []
        grad_tensors = []
        for out, grad in zip(outputs, grads):
            if out.requires_grad:
                output_tensors.append(out)
                grad_tensors.append(grad)

        see_memory_usage("In backward checkpointing code before backward",
                         force=False)

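        # Backprop through the recomputed graph; gradients accumulate on the detached inputs.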
        torch.autograd.backward(output_tensors, grad_tensors)

        see_memory_usage("After backward checkpointing code after backward",
                         force=False)

        if PROFILE_TIME:
            timers('backward').stop()
            timers.log(['backward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()
        ret_list = [None, None]  # Nones for the run_function and all_outputs args of forward()
        for inp in detached_inputs:
            if torch.is_tensor(inp):
                ret_list.append(inp.grad)
            else:
                ret_list.append(None)

        return tuple(ret_list)
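
Stripped of DeepSpeed's partitioning, CPU offload and contiguous-buffer bookkeeping, the method above is the standard checkpoint recompute pattern: stash the inputs and RNG state during forward, replay the forward pass under torch.enable_grad() during backward, then backpropagate through the recomputed graph. The sketch below shows only that core, assuming every input is a tensor and only the CPU RNG state matters; SimpleCheckpoint is an illustrative name, not part of the code above.

import torch


class SimpleCheckpoint(torch.autograd.Function):
    """Minimal recompute-based activation checkpointing (illustrative sketch)."""

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        # Stash the RNG state so dropout masks match during the recompute.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.save_for_backward(*args)
        with torch.no_grad():
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grads):
        inputs = [t.detach().requires_grad_(t.requires_grad)
                  for t in ctx.saved_tensors]

        # Replay the forward pass with the forward-time RNG state.
        bwd_cpu_rng_state = torch.get_rng_state()
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)
        torch.set_rng_state(bwd_cpu_rng_state)

        if torch.is_tensor(outputs):
            outputs = (outputs, )

        # Only differentiable outputs participate in the recomputed backward.
        output_tensors, grad_tensors = [], []
        for out, grad in zip(outputs, grads):
            if torch.is_tensor(out) and out.requires_grad:
                output_tensors.append(out)
                grad_tensors.append(grad)
        torch.autograd.backward(output_tensors, grad_tensors)

        # One gradient per forward() argument: None for run_function itself.
        return (None, ) + tuple(inp.grad for inp in inputs)

SimpleCheckpoint.apply(block, x) then behaves like block(x) in the forward pass but keeps only x, not the intermediate activations of block, alive until backward.
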
Example #2
    def forward(ctx, run_function, all_outputs, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(
                all_objects=all_args)
            ctx.save_for_backward(*tensor_args)
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            torch.cuda.synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers('forward').start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        if mp_rank is None:
            if mpu is not None:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        if cuda_device is None:
            see_memory_usage("First Forward Begining", force=False)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(
                    f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU Checkpointing {PA_TO_CPU}"
                )
                logger.info(
                    f"----Contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
                )
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(f"----Profiling {PROFILE_TIME}")

            cuda_device = torch.cuda.current_device()
            transport_stream = torch.cuda.Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            #inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
            # inputs.append(args[-1])

            inputs = []
            for i, item in enumerate(args[:-1]):
                if not torch.is_tensor(item):
                    inputs.append(item)
                    continue

                partition_size = get_partition_size(item)
                partition = item.detach().contiguous().view(-1).narrow(
                    0, get_partition_start(item), partition_size).clone()

                if CONTIGUOUS_CHECKPOINTING:
                    buffer_device = torch.device(
                        'cpu') if PA_TO_CPU else partition.device

                    if i >= len(contiguous_data_buffers):
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for _ in range(num_layers)
                        ]
                        contiguous_data_buffers.append(tensor_list)
                        data_offsets.append(0)
                    elif contiguous_data_buffers[i] is None:
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for _ in range(num_layers)
                        ]
                        contiguous_data_buffers[i] = tensor_list
                        data_offsets[i] = 0

                    # Because 'new_empty' returns uninitialized pages, the pages
                    # would otherwise be faulted in during the cudaMemcpy, which
                    # increases the copy time. To avoid this, we pre-populate
                    # them by writing a 0 to one element per page ahead of the
                    # actual cudaMemcpy. Thanks to the previously launched GPU
                    # kernels, there is a small window here for the CPU to
                    # populate these pages asynchronously.
                    contiguous_data_buffers[i][data_offsets[i]].data[range(
                        0, contiguous_data_buffers[i][
                            data_offsets[i]].data.shape[0],
                        int(mmap.PAGESIZE / contiguous_data_buffers[i][
                            data_offsets[i]].data.element_size()))] = 0

                    contiguous_partition = contiguous_data_buffers[i][
                        data_offsets[i]].data.copy_(partition.data)
                    data_offsets[i] = data_offsets[i] + 1
                    inputs.append(contiguous_partition)
                else:
                    partition = partition.cpu() if PA_TO_CPU else partition
                    inputs.append(partition)

            inputs.append(args[-1])

        # Just in case something funky is happening, such as reuse of inputs.
        inputs_cuda = move_to_device(args, cuda_device)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
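        # Run the layer without tracking gradients; the graph is rebuilt in backward().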
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda

        # with torch.cuda.stream(transport_stream):
        # if PARTITION_ACTIVATIONS:
        #    new_args = []
        #    for arg, inp in zip(args,inputs):
        #        size= torch.tensor(arg.size())
        #        arg.data = inp.data
        #        new_args.append(arg)
        #        new_args.append(size)
        #    ctx.save_for_backward(*new_args)

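        # Re-bind each tensor arg to its partitioned data and save it together with
        # its original size, so backward can re-assemble the full activations.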
        if PARTITION_ACTIVATIONS:
            new_args = []
            for i, (arg, inp) in enumerate(zip(args, inputs)):
                if not torch.is_tensor(arg):
                    new_args.append(arg)
                    continue

                size = torch.tensor(arg.size())

                arg.data = inp.data
                new_args.append(arg)

                if CONTIGUOUS_CHECKPOINTING:
                    numel = size.numel()
                    if i >= len(contiguous_size_buffers):
                        tmp = torch.tensor(())
                        contiguous_size_buffers.append(
                            tmp.new_empty([numel * num_layers],
                                          dtype=size.dtype,
                                          device=size.device))
                        size_offsets.append(0)
                    elif contiguous_size_buffers[i] is None:
                        tmp = torch.tensor(())
                        contiguous_size_buffers[i] = tmp.new_empty(
                            [numel * num_layers],
                            dtype=size.dtype,
                            device=size.device)
                        size_offsets[i] = 0

                    contiguous_size = contiguous_size_buffers[i].narrow(
                        0, size_offsets[i], numel).data.copy_(size.data)
                    contiguous_size = contiguous_size.view_as(size)
                    size_offsets[i] = size_offsets[i] + numel
                    new_args.append(contiguous_size)
                else:
                    new_args.append(size)
                # if dist.get_rank() == 0:
                #    logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ")

            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers('forward').stop()
            timers.log(['forward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()

        # Tensors returned from forward() may not be differentiable; mark the
        # non-floating-point ones as non-differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [
                o for o in outputs
                if torch.is_tensor(o) and not o.is_floating_point()
            ]
        ctx.mark_non_differentiable(*non_grad_outputs)

        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)
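
The PARTITION_ACTIVATIONS branch above keeps only this rank's 1/mp_size slice of every flattened activation; gather_partitioned_activations in backward later re-assembles the full tensors before the recompute. The sketch below is a small, self-contained illustration of that partitioning arithmetic under simplifying assumptions (numel divisible by mp_size, no CPU offload, torch.cat standing in for the distributed all-gather); the helper names mirror the ones used above but are re-implemented here purely for illustration.

import torch

mp_size = 4  # model-parallel world size (illustrative value)


def get_partition_size(item):
    # Assumes item.numel() is divisible by mp_size.
    return item.numel() // mp_size


def get_partition_start(item, rank):
    return rank * get_partition_size(item)


activation = torch.randn(8, 16)

# Forward on each rank: keep only the local slice of the flattened activation.
partitions = [
    activation.detach().contiguous().view(-1).narrow(
        0, get_partition_start(activation, rank),
        get_partition_size(activation)).clone()
    for rank in range(mp_size)
]

# Backward: gather the slices (an all-gather across ranks in the real code)
# and restore the original shape before re-running the layer.
restored = torch.cat(partitions).view_as(activation)
assert torch.equal(restored, activation)

In the surrounding module, this forward/backward pair is normally driven through a thin checkpoint(function, *args) wrapper that calls CheckpointFunction.apply(function, all_outputs, *args) and returns the collected all_outputs, rather than being invoked directly.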