def backward(ctx, *grads):
    global timers
    see_memory_usage("In backward", force=False)
    if SYNCHRONIZE:
        torch.cuda.synchronize()
    if PROFILE_TIME:
        timers('backward').start()

    if CONTIGUOUS_CHECKPOINTING:
        # Drop the pointers to the contiguous buffer memory so the buffers can
        # be garbage collected once the checkpoints have been used; only the
        # tensors stored by save_for_backward remain referenced.
        global data_offsets, size_offsets
        global contiguous_data_buffers, contiguous_size_buffers

        contiguous_data_buffers = []
        contiguous_size_buffers = []
        data_offsets = []
        size_offsets = []

    see_memory_usage("In backward checkpointing code", force=False)
    if not torch.autograd._is_checkpoint_valid():
        raise RuntimeError("Checkpointing is not compatible with .grad(), "
                           "please use .backward() if possible")

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS

    if PARTITION_ACTIVATIONS:
        # with torch.cuda.stream(transport_stream):
        inputs = gather_partitioned_activations(
            ctx.saved_tensors,
            device=cuda_device if PA_TO_CPU else None)
        detached_inputs = detach_variable(inputs)
    elif PA_TO_CPU:
        inputs = move_to_device(ctx.saved_tensors, cuda_device, is_activation_to_checkpoint)
        detached_inputs = detach_variable(inputs)
    else:
        inputs = ctx.saved_tensors
        detached_inputs = detach_variable(inputs)

    # Add back the non-tensor input args.
    detached_inputs = merge_tensors(tensor_objects=detached_inputs,
                                    non_tensor_objects=ctx.non_tensor_args,
                                    tensor_flags=ctx.tensor_flags)

    # Store the current RNG states.
    bwd_cpu_rng_state = torch.get_rng_state()
    bwd_cuda_rng_state = torch.cuda.get_rng_state()
    bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    # Set the RNG states to what they were before the forward pass.
    torch.set_rng_state(ctx.fwd_cpu_rng_state)
    _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

    # if PARTITION_ACTIVATIONS:
    #     current_stream = torch.cuda.current_stream()
    #     current_stream.wait_stream(transport_stream)

    see_memory_usage("In backward checkpointing code before forward", force=False)

    with torch.enable_grad():
        outputs = ctx.run_function(*detached_inputs)

    see_memory_usage("In backward checkpointing code after forward", force=False)

    # Set the RNG states back to what they were at the start of this function.
    torch.set_rng_state(bwd_cpu_rng_state)
    _set_cuda_rng_state(bwd_cuda_rng_state)
    get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

    if isinstance(outputs, torch.Tensor):
        outputs = (outputs, )

    # Filter out the non-tensor outputs.
    outputs, _, _ = extract_tensors(all_objects=outputs)

    # Construct the arguments to autograd.backward(). These are usually just
    # the outputs and the incoming grads, but forward() can also return
    # tensors that are not differentiable.
    output_tensors = []
    grad_tensors = []
    for out, grad in zip(outputs, grads):
        if out.requires_grad:
            output_tensors.append(out)
            grad_tensors.append(grad)

    see_memory_usage("In backward checkpointing code before backward", force=False)

    torch.autograd.backward(output_tensors, grad_tensors)

    see_memory_usage("After backward checkpointing code after backward", force=False)

    if PROFILE_TIME:
        timers('backward').stop()
        timers.log(['backward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # The first two Nones are the gradients for the run_function and
    # all_outputs arguments of forward(), which do not require gradients.
    ret_list = [None, None]
    for inp in detached_inputs:
        if torch.is_tensor(inp):
            ret_list.append(inp.grad)
        else:
            ret_list.append(None)

    return tuple(ret_list)
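# Minimal sketch of the recompute-then-backprop pattern that backward() above
# implements. Illustrative only: it omits RNG state save/restore, activation
# partitioning/gathering, non-tensor arguments, and profiling. The names
# `run_function`, `saved_inputs`, and `grads` are stand-ins for the state the
# real autograd.Function keeps on `ctx`; `torch` is imported as elsewhere in
# this module.
def _checkpoint_backward_sketch(run_function, saved_inputs, grads):
    # Detach the saved inputs so the recomputed graph is independent of the
    # original forward graph, while preserving their requires_grad flags.
    detached = tuple(
        inp.detach().requires_grad_(inp.requires_grad) for inp in saved_inputs)

    # Re-run the forward pass with gradient tracking enabled to rebuild the
    # activations that were discarded after the checkpointed forward.
    with torch.enable_grad():
        outputs = run_function(*detached)
    if torch.is_tensor(outputs):
        outputs = (outputs, )

    # Backpropagate the incoming gradients through the recomputed graph,
    # skipping outputs that do not require gradients.
    output_tensors = [o for o, g in zip(outputs, grads) if o.requires_grad]
    grad_tensors = [g for o, g in zip(outputs, grads) if o.requires_grad]
    torch.autograd.backward(output_tensors, grad_tensors)

    # Gradients for the inputs are now available on the detached copies.
    return tuple(inp.grad for inp in detached)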
def forward(ctx, run_function, all_outputs, *args):
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    def save_args_for_backward(*all_args):
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(
            all_objects=all_args)
        ctx.save_for_backward(*tensor_args)
        ctx.non_tensor_args = non_tensor_args
        ctx.tensor_flags = tensor_flags

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    if mp_rank is None:
        if mpu is not None:
            mp_rank = mpu.get_model_parallel_rank()
            mp_size = mpu.get_model_parallel_world_size()
            mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=False)
        if dist.get_rank() == 0:
            logger.info("Activation Checkpointing Information")
            logger.info(
                f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}")
            logger.info(
                f"----Contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        # inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
        # inputs.append(args[-1])

        inputs = []
        for i, item in enumerate(args[:-1]):
            if not torch.is_tensor(item):
                inputs.append(item)
                continue

            partition_size = get_partition_size(item)
            partition = item.detach().contiguous().view(-1).narrow(
                0, get_partition_start(item), partition_size).clone()

            if CONTIGUOUS_CHECKPOINTING:
                buffer_device = torch.device('cpu') if PA_TO_CPU else partition.device

                if i >= len(contiguous_data_buffers):
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for _ in range(num_layers)
                    ]
                    contiguous_data_buffers.append(tensor_list)
                    data_offsets.append(0)
                elif contiguous_data_buffers[i] is None:
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for _ in range(num_layers)
                    ]
                    contiguous_data_buffers[i] = tensor_list
                    data_offsets[i] = 0

                # Because 'new_empty' returns uninitialized pages, the pages
                # would otherwise have to be populated during the cudaMemcpy,
                # which increases the data copy time. To avoid this, we
                # pre-populate these pages by writing a 0 to each of them
                # ahead of the actual cudaMemcpy. Thanks to the previously
                # launched GPU kernels, there is a small window of time here
                # for the CPU to populate the pages asynchronously.
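                # For example, with the common 4 KiB page size and fp16
                # activations (element_size() == 2), the stride below is
                # mmap.PAGESIZE / 2 = 2048 elements, i.e. exactly one element
                # per page is written, which is enough for the OS to back
                # every page before the copy starts.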
                contiguous_data_buffers[i][data_offsets[i]].data[range(
                    0,
                    contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
                    int(mmap.PAGESIZE /
                        contiguous_data_buffers[i][data_offsets[i]].data.element_size()))] = 0

                contiguous_partition = contiguous_data_buffers[i][
                    data_offsets[i]].data.copy_(partition.data)
                data_offsets[i] = data_offsets[i] + 1
                inputs.append(contiguous_partition)
            else:
                partition = partition.cpu() if PA_TO_CPU else partition
                inputs.append(partition)

        inputs.append(args[-1])

    # Just in case something funky is happening, such as reuse of inputs.
    inputs_cuda = move_to_device(args, cuda_device)

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    see_memory_usage("Before running forward on the layer", force=False)
    # ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    see_memory_usage("After running forward on the layer", force=False)
    del inputs_cuda

    # with torch.cuda.stream(transport_stream):
    #     if PARTITION_ACTIVATIONS:
    #         new_args = []
    #         for arg, inp in zip(args, inputs):
    #             size = torch.tensor(arg.size())
    #             arg.data = inp.data
    #             new_args.append(arg)
    #             new_args.append(size)
    #         ctx.save_for_backward(*new_args)

    if PARTITION_ACTIVATIONS:
        new_args = []
        for i, (arg, inp) in enumerate(zip(args, inputs)):
            if not torch.is_tensor(arg):
                new_args.append(arg)
                continue

            size = torch.tensor(arg.size())
            arg.data = inp.data
            new_args.append(arg)

            if CONTIGUOUS_CHECKPOINTING:
                numel = size.numel()
                if i >= len(contiguous_size_buffers):
                    tmp = torch.tensor(())
                    contiguous_size_buffers.append(
                        tmp.new_empty([numel * num_layers],
                                      dtype=size.dtype,
                                      device=size.device))
                    size_offsets.append(0)
                elif contiguous_size_buffers[i] is None:
                    tmp = torch.tensor(())
                    contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
                                                               dtype=size.dtype,
                                                               device=size.device)
                    size_offsets[i] = 0

                contiguous_size = contiguous_size_buffers[i].narrow(
                    0, size_offsets[i], numel).data.copy_(size.data)
                contiguous_size = contiguous_size.view_as(size)
                size_offsets[i] = size_offsets[i] + numel
                new_args.append(contiguous_size)
            else:
                new_args.append(size)

            # if dist.get_rank() == 0:
            #     logger.info(f"The stored tensor is {contiguous_size} and original one is {size}")

        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # Tensors returned from forward() may not be differentiable.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
    else:
        non_grad_outputs = [
            o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()
        ]
    ctx.mark_non_differentiable(*non_grad_outputs)

    if torch.is_tensor(outputs):
        all_outputs += [outputs]
        return outputs
    else:
        all_outputs += outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)
        return tuple(outputs)
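# Minimal sketch of the partitioning arithmetic that get_partition_start() and
# get_partition_size() are assumed to perform for forward() above: each
# model-parallel rank keeps an equal, contiguous slice of the flattened
# activation. This is an illustrative assumption, not the module's actual
# helpers, and it presumes the element count divides evenly across ranks.
def _partition_bounds_sketch(item, mp_rank, mp_size):
    # Flattened length of the activation to be partitioned.
    numel = item.numel()
    # Equal slice per rank; the real helpers must also handle remainders/padding.
    partition_size = numel // mp_size
    # This rank's slice starts immediately after the slices of lower ranks.
    partition_start = partition_size * mp_rank
    return partition_start, partition_size

# forward() then takes the slice with
#     item.detach().contiguous().view(-1).narrow(0, partition_start, partition_size).clone()
# and backward() re-assembles the full activations from the per-rank slices via
# gather_partitioned_activations() before recomputation.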