Example 1
    def forward(ctx, run_function, all_outputs, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(
                all_objects=all_args)
            ctx.save_for_backward(*tensor_args)
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            torch.cuda.synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers('forward').start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        if mp_rank is None:
            if mpu is not None:
                if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                    mp_rank = mpu.get_tensor_model_parallel_rank()
                    mp_size = mpu.get_tensor_model_parallel_world_size()
                    mp_group = mpu.get_tensor_model_parallel_group()
                else:
                    mp_rank = mpu.get_model_parallel_rank()
                    mp_size = mpu.get_model_parallel_world_size()
                    mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        if cuda_device is None:
            see_memory_usage("First Forward Beginning", force=False)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(
                    f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}"
                )
                logger.info(
                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
                )
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(
                    f"----Profiling time in checkpointing {PROFILE_TIME}")

            cuda_device = torch.cuda.current_device()
            transport_stream = torch.cuda.Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            inputs = partition_activations(args, CPU_CHECKPOINT,
                                           CONTIGUOUS_CHECKPOINTING)
        elif CPU_CHECKPOINT:
            inputs = copy_to_device(args,
                                    device=torch.device('cpu'),
                                    criterion_func=is_activation_to_checkpoint)

        # just in case something funky is happening such as reuse of inputs
        inputs_cuda = copy_to_device(
            args,
            device=cuda_device,
            criterion_func=is_activation_to_checkpoint)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda

        if PARTITION_ACTIVATIONS:
            new_args = get_partitioned_activations_for_backward(
                args, inputs, CONTIGUOUS_CHECKPOINTING)
            assert len(new_args) % 2 == 0, \
                f'save_for_backward called with odd number of args, {len(new_args)}'
            save_args_for_backward(*new_args)
        elif CPU_CHECKPOINT:
            new_args = get_cpu_activations_for_backward(args, inputs)
            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers('forward').stop()
            timers.log(['forward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [
                o for o in outputs
                if torch.is_tensor(o) and not o.is_floating_point()
            ]
        ctx.mark_non_differentiable(*non_grad_outputs)

        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)
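
ctx.save_for_backward only accepts tensors, so save_args_for_backward above splits its arguments with extract_tensors and stashes the non-tensor objects and per-position flags on ctx for the backward pass to re-interleave. A minimal sketch of what such a splitter could look like, assuming the flags are simply booleans recording which positions held tensors (names and return format here are illustrative, not DeepSpeed's actual implementation):

    # Illustrative sketch only -- not DeepSpeed's actual extract_tensors.
    # Assumption: tensor_flags records, per position, whether the original
    # object was a tensor, so backward can re-interleave the two lists.
    import torch

    def extract_tensors_sketch(all_objects):
        tensor_args = [o for o in all_objects if torch.is_tensor(o)]
        non_tensor_args = [o for o in all_objects if not torch.is_tensor(o)]
        tensor_flags = [torch.is_tensor(o) for o in all_objects]
        return tensor_args, non_tensor_args, tensor_flags

    def merge_tensors_sketch(tensor_args, non_tensor_args, tensor_flags):
        # Rebuild the original argument ordering from the split lists.
        tensors, others = iter(tensor_args), iter(non_tensor_args)
        return tuple(next(tensors) if flag else next(others) for flag in tensor_flags)
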
Example 2
    def forward(ctx, run_function, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        if SYNCHRONIZE:
            torch.cuda.synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers('forward').start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        if mp_rank is None:
            if mpu is not None:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        if cuda_device is None:
            see_memory_usage("First Forward Begining", force=True)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(
                    f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
                )
                logger.info(
                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
                )
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(f"----Profiling {PROFILE_TIME}")

            cuda_device = torch.cuda.current_device()
            transport_stream = torch.cuda.Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            #inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
            #inputs.append(args[-1])

            inputs = []
            for i, item in enumerate(args[:-1]):
                partition_size = get_partition_size(item)
                partition = item.detach().contiguous().view(-1).narrow(
                    0, get_partition_start(item), partition_size).clone()

                if CONTIGUOUS_CHECKPOINTING:
                    buffer_device = torch.device(
                        'cpu') if PA_TO_CPU else partition.device

                    if i >= len(contiguous_data_buffers):
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for i in range(num_layers)
                        ]
                        contiguous_data_buffers.append(tensor_list)
                        data_offsets.append(0)
                    elif contiguous_data_buffers[i] is None:
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for i in range(num_layers)
                        ]
                        contiguous_data_buffers[i] = tensor_list
                        data_offsets[i] = 0

                    contiguous_partition = contiguous_data_buffers[i][
                        data_offsets[i]].data.copy_(partition.data)
                    data_offsets[i] = data_offsets[i] + 1
                    inputs.append(contiguous_partition)
                else:
                    partition = partition.cpu() if PA_TO_CPU else partition
                    inputs.append(partition)

            inputs.append(args[-1])

        #just in case something funky is happening such as reuse of inputs
        inputs_cuda = [item.to(cuda_device) for item in args]

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        #ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        del inputs_cuda

        #with torch.cuda.stream(transport_stream):
        #if PARTITION_ACTIVATIONS:
        #    new_args = []
        #    for arg, inp in zip(args,inputs):
        #        size= torch.tensor(arg.size())
        #        arg.data = inp.data
        #        new_args.append(arg)
        #        new_args.append(size)
        #    ctx.save_for_backward(*new_args)

        if PARTITION_ACTIVATIONS:
            new_args = []
            for i, (arg, inp) in enumerate(zip(args, inputs)):
                size = torch.tensor(arg.size())

                arg.data = inp.data
                new_args.append(arg)

                if CONTIGUOUS_CHECKPOINTING:
                    numel = size.numel()
                    if i >= len(contiguous_size_buffers):
                        tmp = torch.tensor(())
                        contiguous_size_buffers.append(
                            tmp.new_empty([numel * num_layers],
                                          dtype=size.dtype,
                                          device=size.device))
                        size_offsets.append(0)
                    elif contiguous_size_buffers[i] is None:
                        tmp = torch.tensor(())
                        contiguous_size_buffers[i] = tmp.new_empty(
                            [numel * num_layers],
                            dtype=size.dtype,
                            device=size.device)
                        size_offsets[i] = 0

                    contiguous_size = contiguous_size_buffers[i].narrow(
                        0, size_offsets[i], numel).data.copy_(size.data)
                    contiguous_size = contiguous_size.view_as(size)
                    size_offsets[i] = size_offsets[i] + numel
                    new_args.append(contiguous_size)
                else:
                    new_args.append(size)
                #if dist.get_rank() == 0:
                #    logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ")

            ctx.save_for_backward(*new_args)
        else:
            ctx.save_for_backward(*args)
        if PROFILE_TIME:
            timers('forward').stop()
            timers.log(['forward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [
                o for o in outputs if not o.is_floating_point()
            ]
        ctx.mark_non_differentiable(*non_grad_outputs)
        return outputs
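
The partitioning loops in these examples rely on get_partition_start and get_partition_size to decide which slice of each flattened activation a model-parallel rank keeps (the corresponding backward presumably gathers the slices back across mp_group before recomputation). A rough sketch of the idea, assuming each rank keeps an equal 1/mp_size slice; the real helpers ship with this module and additionally handle padding and uneven sizes:

    # Illustrative sketch only -- not the module's real partitioning helpers.
    # It just shows each model-parallel rank keeping a 1/mp_size slice of the
    # flattened activation tensor.
    import torch

    def partition_size_sketch(item, mp_size):
        # Assumption: numel divides evenly; the real code handles remainders.
        return item.numel() // mp_size

    def partition_start_sketch(item, mp_rank, mp_size):
        return mp_rank * partition_size_sketch(item, mp_size)

    def partition_activation_sketch(item, mp_rank, mp_size):
        size = partition_size_sketch(item, mp_size)
        start = partition_start_sketch(item, mp_rank, mp_size)
        return item.detach().contiguous().view(-1).narrow(0, start, size).clone()
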
Example 3
    def forward(ctx, run_function, all_outputs, *args):
        global mpu, timers, SYNCHRONIZE, PROFILE_TIME

        def save_args_for_backward(*all_args):
            tensor_args, non_tensor_args, tensor_flags = extract_tensors(
                all_objects=all_args)
            ctx.save_for_backward(*tensor_args)
            ctx.non_tensor_args = non_tensor_args
            ctx.tensor_flags = tensor_flags

        if SYNCHRONIZE:
            torch.cuda.synchronize()

        if timers is None and PROFILE_TIME:
            timers = Timers()

        if PROFILE_TIME:
            timers('forward').start()

        ctx.run_function = run_function
        global num_layers
        global mp_rank, mp_size, mp_group
        global contiguous_data_buffers, contiguous_size_buffers
        global data_offsets, size_offsets
        if mp_rank is None:
            if mpu is not None:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
            else:
                mp_rank = 0
                mp_size = 1
                mp_group = None

        global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

        if cuda_device is None:
            see_memory_usage("First Forward Begining", force=False)
            if dist.get_rank() == 0:
                logger.info(f"Activation Checkpointing Information")
                logger.info(
                    f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
                )
                logger.info(
                    f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
                )
                logger.info(f"----Synchronization {SYNCHRONIZE}")
                logger.info(f"----Profiling {PROFILE_TIME}")

            cuda_device = torch.cuda.current_device()
            transport_stream = torch.cuda.Stream(device=cuda_device)

        if PARTITION_ACTIVATIONS:
            #inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
            # inputs.append(args[-1])

            inputs = []
            for i, item in enumerate(args[:-1]):
                if not torch.is_tensor(item):
                    inputs.append(item)
                    continue

                partition_size = get_partition_size(item)
                partition = item.detach().contiguous().view(-1).narrow(
                    0, get_partition_start(item), partition_size).clone()

                if CONTIGUOUS_CHECKPOINTING:
                    buffer_device = torch.device(
                        'cpu') if PA_TO_CPU else partition.device

                    if i >= len(contiguous_data_buffers):
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for i in range(num_layers)
                        ]
                        contiguous_data_buffers.append(tensor_list)
                        data_offsets.append(0)
                    elif contiguous_data_buffers[i] is None:
                        tensor_list = [
                            torch.tensor(()).new_empty([partition_size],
                                                       dtype=partition.dtype,
                                                       device=buffer_device)
                            for i in range(num_layers)
                        ]
                        contiguous_data_buffers[i] = tensor_list
                        data_offsets[i] = 0

                    # Because the 'new_empty' returns uninitialized pages,
                    # the pages need to be populated during the cudaMemcpy time
                    # which increases the data copy time. To avoid this, we
                    # pre-populate these pages by simply writing 0 ahead of
                    # the actual cudaMemcpy operation time. Due to the
                    # previously launched GPU kernels, there is a small
                    # window of time here for CPUs to populate pages asynchronously.
                    contiguous_data_buffers[i][data_offsets[i]].data[range(
                        0, contiguous_data_buffers[i][
                            data_offsets[i]].data.shape[0],
                        int(mmap.PAGESIZE / contiguous_data_buffers[i][
                            data_offsets[i]].data.element_size()))] = 0

                    contiguous_partition = contiguous_data_buffers[i][
                        data_offsets[i]].data.copy_(partition.data)
                    data_offsets[i] = data_offsets[i] + 1
                    inputs.append(contiguous_partition)
                else:
                    partition = partition.cpu() if PA_TO_CPU else partition
                    inputs.append(partition)

            inputs.append(args[-1])

        #just in case something funky is happening such as reuse of inputs
        inputs_cuda = move_to_device(args, cuda_device)

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        see_memory_usage("Before running forward on the layer", force=False)
        # ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*inputs_cuda)

        see_memory_usage("After running forward on the layer", force=False)
        del inputs_cuda

        # with torch.cuda.stream(transport_stream):
        # if PARTITION_ACTIVATIONS:
        #    new_args = []
        #    for arg, inp in zip(args,inputs):
        #        size= torch.tensor(arg.size())
        #        arg.data = inp.data
        #        new_args.append(arg)
        #        new_args.append(size)
        #    ctx.save_for_backward(*new_args)

        if PARTITION_ACTIVATIONS:
            new_args = []
            for i, (arg, inp) in enumerate(zip(args, inputs)):
                if not torch.is_tensor(arg):
                    new_args.append(arg)
                    continue

                size = torch.tensor(arg.size())

                arg.data = inp.data
                new_args.append(arg)

                if CONTIGUOUS_CHECKPOINTING:
                    numel = size.numel()
                    if i >= len(contiguous_size_buffers):
                        tmp = torch.tensor(())
                        contiguous_size_buffers.append(
                            tmp.new_empty([numel * num_layers],
                                          dtype=size.dtype,
                                          device=size.device))
                        size_offsets.append(0)
                    elif contiguous_size_buffers[i] is None:
                        tmp = torch.tensor(())
                        contiguous_size_buffers[i] = tmp.new_empty(
                            [numel * num_layers],
                            dtype=size.dtype,
                            device=size.device)
                        size_offsets[i] = 0

                    contiguous_size = contiguous_size_buffers[i].narrow(
                        0, size_offsets[i], numel).data.copy_(size.data)
                    contiguous_size = contiguous_size.view_as(size)
                    size_offsets[i] = size_offsets[i] + numel
                    new_args.append(contiguous_size)
                else:
                    new_args.append(size)
                # if dist.get_rank() == 0:
                #    logger.info(f"The stored tensor is {contiguous_size} and orginal one is {size} ")

            save_args_for_backward(*new_args)
        else:
            save_args_for_backward(*args)

        if PROFILE_TIME:
            timers('forward').stop()
            timers.log(['forward'])
        if SYNCHRONIZE:
            torch.cuda.synchronize()

        # Tensors returned from forward() may not be differentiable.
        if torch.is_tensor(outputs):
            non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
        else:
            non_grad_outputs = [
                o for o in outputs
                if torch.is_tensor(o) and not o.is_floating_point()
            ]
        ctx.mark_non_differentiable(*non_grad_outputs)

        if torch.is_tensor(outputs):
            all_outputs += [outputs]
            return outputs
        else:
            all_outputs += outputs
            outputs, _, _ = extract_tensors(all_objects=outputs)
            return tuple(outputs)
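
Example 3 also pre-touches the freshly allocated contiguous buffers before copying into them: new_empty returns uninitialized pages, so writing one element per OS page (mmap.PAGESIZE apart) forces the pages to be populated while earlier GPU kernels are still running, rather than during the copy itself. A standalone sketch of that trick, using a hypothetical staging buffer purely for illustration:

    # Illustrative sketch of the page pre-touch trick from Example 3:
    # touching one element per OS page maps the pages of an uninitialized
    # buffer up front, so a later bulk copy does not pay the page-fault cost.
    import mmap
    import torch

    def pretouch_pages_sketch(buffer):
        stride = max(1, int(mmap.PAGESIZE / buffer.element_size()))
        buffer.view(-1)[::stride] = 0  # one write per page is enough
        return buffer

    staging = torch.empty(1 << 20, dtype=torch.float16, device='cpu')  # hypothetical staging buffer
    pretouch_pages_sketch(staging)
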