def forward(ctx, run_function, all_outputs, *args):
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    def save_args_for_backward(*all_args):
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
        ctx.save_for_backward(*tensor_args)
        ctx.non_tensor_args = non_tensor_args
        ctx.tensor_flags = tensor_flags

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    if mp_rank is None:
        if mpu is not None:
            if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                mp_rank = mpu.get_tensor_model_parallel_rank()
                mp_size = mpu.get_tensor_model_parallel_world_size()
                mp_group = mpu.get_tensor_model_parallel_group()
            else:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=False)
        if dist.get_rank() == 0:
            logger.info("Activation Checkpointing Information")
            logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}")
            logger.info(f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling time in checkpointing {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        inputs = partition_activations(args, CPU_CHECKPOINT, CONTIGUOUS_CHECKPOINTING)
    elif CPU_CHECKPOINT:
        inputs = copy_to_device(args,
                                device=torch.device('cpu'),
                                criterion_func=is_activation_to_checkpoint)

    # Just in case something funky is happening such as reuse of inputs.
    inputs_cuda = copy_to_device(args,
                                 device=cuda_device,
                                 criterion_func=is_activation_to_checkpoint)

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    see_memory_usage("Before running forward on the layer", force=False)
    # ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    see_memory_usage("After running forward on the layer", force=False)
    del inputs_cuda

    if PARTITION_ACTIVATIONS:
        new_args = get_partitioned_activations_for_backward(args, inputs, CONTIGUOUS_CHECKPOINTING)
        assert len(new_args) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
        save_args_for_backward(*new_args)
    elif CPU_CHECKPOINT:
        new_args = get_cpu_activations_for_backward(args, inputs)
        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # Tensors returned from forward() may not be differentiable.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
    else:
        non_grad_outputs = [
            o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()
        ]
    ctx.mark_non_differentiable(*non_grad_outputs)

    if torch.is_tensor(outputs):
        all_outputs += [outputs]
        return outputs
    else:
        all_outputs += outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)
        return tuple(outputs)
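# --- Hedged sketch (not part of the original source) ---
# The forward above is one half of an activation-checkpointing
# torch.autograd.Function: it runs `run_function` under torch.no_grad(),
# stashes the (possibly partitioned / CPU-offloaded) inputs, and relies on
# the matching backward to recompute the forward pass when gradients are
# needed. A minimal, self-contained version of that core pattern, without
# partitioning, CPU offload, or RNG-state handling, might look like the
# code below. The names `SimpleCheckpoint` and `simple_checkpoint` are
# illustrative assumptions, not part of the code above.

import torch


class SimpleCheckpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, *args):
        # Assumes all *args are tensors, unlike the forward above.
        ctx.run_function = run_function
        ctx.save_for_backward(*args)
        with torch.no_grad():
            # Activations created inside run_function are not kept.
            return run_function(*args)

    @staticmethod
    def backward(ctx, *grad_outputs):
        inputs = [t.detach().requires_grad_(t.requires_grad) for t in ctx.saved_tensors]
        with torch.enable_grad():
            outputs = ctx.run_function(*inputs)  # recompute the forward pass
        if torch.is_tensor(outputs):
            outputs = (outputs,)
        torch.autograd.backward(outputs, grad_outputs)
        # None for run_function, then one gradient per input tensor.
        return (None,) + tuple(t.grad for t in inputs)


def simple_checkpoint(run_function, *args):
    return SimpleCheckpoint.apply(run_function, *args)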
def forward(ctx, run_function, *args):
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    if mp_rank is None:
        if mpu is not None:
            mp_rank = mpu.get_model_parallel_rank()
            mp_size = mpu.get_model_parallel_world_size()
            mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=True)
        if dist.get_rank() == 0:
            logger.info("Activation Checkpointing Information")
            logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}")
            logger.info(f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        # inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
        # inputs.append(args[-1])

        inputs = []
        for i, item in enumerate(args[:-1]):
            partition_size = get_partition_size(item)
            partition = item.detach().contiguous().view(-1).narrow(
                0, get_partition_start(item), partition_size).clone()

            if CONTIGUOUS_CHECKPOINTING:
                buffer_device = torch.device('cpu') if PA_TO_CPU else partition.device

                if i >= len(contiguous_data_buffers):
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for i in range(num_layers)
                    ]
                    contiguous_data_buffers.append(tensor_list)
                    data_offsets.append(0)
                elif contiguous_data_buffers[i] is None:
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for i in range(num_layers)
                    ]
                    contiguous_data_buffers[i] = tensor_list
                    data_offsets[i] = 0

                contiguous_partition = contiguous_data_buffers[i][
                    data_offsets[i]].data.copy_(partition.data)
                data_offsets[i] = data_offsets[i] + 1
                inputs.append(contiguous_partition)
            else:
                partition = partition.cpu() if PA_TO_CPU else partition
                inputs.append(partition)

        inputs.append(args[-1])

    # Just in case something funky is happening such as reuse of inputs.
    inputs_cuda = [item.to(cuda_device) for item in args]

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    # ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    del inputs_cuda

    # with torch.cuda.stream(transport_stream):
    # if PARTITION_ACTIVATIONS:
    #     new_args = []
    #     for arg, inp in zip(args, inputs):
    #         size = torch.tensor(arg.size())
    #         arg.data = inp.data
    #         new_args.append(arg)
    #         new_args.append(size)
    #     ctx.save_for_backward(*new_args)

    if PARTITION_ACTIVATIONS:
        new_args = []
        for i, (arg, inp) in enumerate(zip(args, inputs)):
            size = torch.tensor(arg.size())
            arg.data = inp.data
            new_args.append(arg)

            if CONTIGUOUS_CHECKPOINTING:
                numel = size.numel()
                if i >= len(contiguous_size_buffers):
                    tmp = torch.tensor(())
                    contiguous_size_buffers.append(
                        tmp.new_empty([numel * num_layers],
                                      dtype=size.dtype,
                                      device=size.device))
                    size_offsets.append(0)
                elif contiguous_size_buffers[i] is None:
                    tmp = torch.tensor(())
                    contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
                                                               dtype=size.dtype,
                                                               device=size.device)
                    size_offsets[i] = 0

                contiguous_size = contiguous_size_buffers[i].narrow(
                    0, size_offsets[i], numel).data.copy_(size.data)
                contiguous_size = contiguous_size.view_as(size)
                size_offsets[i] = size_offsets[i] + numel
                new_args.append(contiguous_size)
            else:
                new_args.append(size)

            # if dist.get_rank() == 0:
            #     logger.info(f"The stored tensor is {contiguous_size} and original one is {size}")

        ctx.save_for_backward(*new_args)
    else:
        ctx.save_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # Tensors returned from forward() may not be differentiable.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
    else:
        non_grad_outputs = [o for o in outputs if not o.is_floating_point()]
    ctx.mark_non_differentiable(*non_grad_outputs)
    return outputs
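# --- Hedged sketch (not part of the original source) ---
# The PARTITION_ACTIVATIONS branch above flattens each activation and keeps
# only this rank's 1/mp_size slice, computed via get_partition_start() and
# get_partition_size(); backward later gathers the slices across the model-
# parallel group to rebuild the full activation. The helpers below sketch
# that slicing arithmetic under the assumption of an evenly divisible numel;
# they are illustrative stand-ins, not the original get_partition_* functions.

import torch


def partition_size(item: torch.Tensor, mp_size: int) -> int:
    # Each rank stores an equal share; assumes numel is divisible by mp_size.
    return item.numel() // mp_size


def partition_start(item: torch.Tensor, mp_rank: int, mp_size: int) -> int:
    return mp_rank * partition_size(item, mp_size)


def partition_activation(item: torch.Tensor, mp_rank: int, mp_size: int) -> torch.Tensor:
    # Same pattern as the forward above: detach, flatten, narrow, clone.
    size = partition_size(item, mp_size)
    start = partition_start(item, mp_rank, mp_size)
    return item.detach().contiguous().view(-1).narrow(0, start, size).clone()


# Example: a (4, 6) activation split across mp_size=2 ranks -> 12 elements each.
x = torch.arange(24.0).reshape(4, 6)
shard0 = partition_activation(x, mp_rank=0, mp_size=2)
shard1 = partition_activation(x, mp_rank=1, mp_size=2)
assert torch.equal(torch.cat([shard0, shard1]).view(4, 6), x)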
def forward(ctx, run_function, all_outputs, *args):
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    def save_args_for_backward(*all_args):
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(all_objects=all_args)
        ctx.save_for_backward(*tensor_args)
        ctx.non_tensor_args = non_tensor_args
        ctx.tensor_flags = tensor_flags

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    if mp_rank is None:
        if mpu is not None:
            mp_rank = mpu.get_model_parallel_rank()
            mp_size = mpu.get_model_parallel_world_size()
            mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=False)
        if dist.get_rank() == 0:
            logger.info("Activation Checkpointing Information")
            logger.info(f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}")
            logger.info(f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers")
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        # inputs = [item.detach().contiguous().view(-1).narrow(0, get_partition_start(item), get_partition_size(item)).clone() for item in args[:-1]]
        # inputs.append(args[-1])

        inputs = []
        for i, item in enumerate(args[:-1]):
            if not torch.is_tensor(item):
                inputs.append(item)
                continue

            partition_size = get_partition_size(item)
            partition = item.detach().contiguous().view(-1).narrow(
                0, get_partition_start(item), partition_size).clone()

            if CONTIGUOUS_CHECKPOINTING:
                buffer_device = torch.device('cpu') if PA_TO_CPU else partition.device

                if i >= len(contiguous_data_buffers):
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for i in range(num_layers)
                    ]
                    contiguous_data_buffers.append(tensor_list)
                    data_offsets.append(0)
                elif contiguous_data_buffers[i] is None:
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for i in range(num_layers)
                    ]
                    contiguous_data_buffers[i] = tensor_list
                    data_offsets[i] = 0

                # Because 'new_empty' returns uninitialized pages, the pages
                # would otherwise be populated during the cudaMemcpy, which
                # increases the data copy time. To avoid this, we pre-populate
                # them by writing 0 to one element per page ahead of the actual
                # cudaMemcpy. Thanks to the previously launched GPU kernels,
                # there is a small window of time here for the CPU to populate
                # the pages asynchronously.
                contiguous_data_buffers[i][data_offsets[i]].data[range(
                    0,
                    contiguous_data_buffers[i][data_offsets[i]].data.shape[0],
                    int(mmap.PAGESIZE /
                        contiguous_data_buffers[i][data_offsets[i]].data.element_size()))] = 0

                contiguous_partition = contiguous_data_buffers[i][
                    data_offsets[i]].data.copy_(partition.data)
                data_offsets[i] = data_offsets[i] + 1
                inputs.append(contiguous_partition)
            else:
                partition = partition.cpu() if PA_TO_CPU else partition
                inputs.append(partition)

        inputs.append(args[-1])

    # Just in case something funky is happening such as reuse of inputs.
    inputs_cuda = move_to_device(args, cuda_device)

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    see_memory_usage("Before running forward on the layer", force=False)
    # ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    see_memory_usage("After running forward on the layer", force=False)
    del inputs_cuda

    # with torch.cuda.stream(transport_stream):
    # if PARTITION_ACTIVATIONS:
    #     new_args = []
    #     for arg, inp in zip(args, inputs):
    #         size = torch.tensor(arg.size())
    #         arg.data = inp.data
    #         new_args.append(arg)
    #         new_args.append(size)
    #     ctx.save_for_backward(*new_args)

    if PARTITION_ACTIVATIONS:
        new_args = []
        for i, (arg, inp) in enumerate(zip(args, inputs)):
            if not torch.is_tensor(arg):
                new_args.append(arg)
                continue

            size = torch.tensor(arg.size())
            arg.data = inp.data
            new_args.append(arg)

            if CONTIGUOUS_CHECKPOINTING:
                numel = size.numel()
                if i >= len(contiguous_size_buffers):
                    tmp = torch.tensor(())
                    contiguous_size_buffers.append(
                        tmp.new_empty([numel * num_layers],
                                      dtype=size.dtype,
                                      device=size.device))
                    size_offsets.append(0)
                elif contiguous_size_buffers[i] is None:
                    tmp = torch.tensor(())
                    contiguous_size_buffers[i] = tmp.new_empty([numel * num_layers],
                                                               dtype=size.dtype,
                                                               device=size.device)
                    size_offsets[i] = 0

                contiguous_size = contiguous_size_buffers[i].narrow(
                    0, size_offsets[i], numel).data.copy_(size.data)
                contiguous_size = contiguous_size.view_as(size)
                size_offsets[i] = size_offsets[i] + numel
                new_args.append(contiguous_size)
            else:
                new_args.append(size)

            # if dist.get_rank() == 0:
            #     logger.info(f"The stored tensor is {contiguous_size} and original one is {size}")

        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # Tensors returned from forward() may not be differentiable.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs] if not outputs.is_floating_point() else []
    else:
        non_grad_outputs = [
            o for o in outputs if torch.is_tensor(o) and not o.is_floating_point()
        ]
    ctx.mark_non_differentiable(*non_grad_outputs)

    if torch.is_tensor(outputs):
        all_outputs += [outputs]
        return outputs
    else:
        all_outputs += outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)
        return tuple(outputs)
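# --- Hedged sketch (not part of the original source) ---
# The CONTIGUOUS_CHECKPOINTING branch above pre-touches the freshly allocated
# checkpoint buffer (typically a CPU buffer when PA_TO_CPU is set) by writing
# 0 to one element per OS page, using a stride of mmap.PAGESIZE / element_size,
# so page faults happen before the copy rather than during it. The standalone
# helper below illustrates only that indexing trick; `pretouch_pages` and the
# staging buffer are hypothetical names for this example.

import mmap
import torch


def pretouch_pages(buf: torch.Tensor) -> None:
    """Write one element per OS page to force page allocation up front."""
    flat = buf.view(-1)
    stride = max(1, int(mmap.PAGESIZE / flat.element_size()))
    flat[range(0, flat.shape[0], stride)] = 0


# Example: an uninitialized CPU staging buffer for a 1M-element fp16 partition.
staging = torch.empty(1_000_000, dtype=torch.float16, device='cpu')
pretouch_pages(staging)                    # pages are now resident
staging.copy_(torch.randn(1_000_000))      # the copy no longer pays page-fault cost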