def __init__(self,
             gate: Module,
             experts: Module,
             num_local_experts: int,
             group: Optional[Any] = None,
             use_tutel: bool = False) -> None:
    """Wrap a gate and this rank's experts into an MoE layer.

    Args:
        gate: gating module that routes tokens to experts. If it exposes a
            ``k`` attribute, Tutel is only enabled when ``k == 1``.
        experts: module holding the experts hosted on this rank.
        num_local_experts: number of experts hosted on this rank.
        group: process group handed to ``dist.get_world_size`` to compute
            ``self.world_size``; ``None`` is passed through unchanged.
        use_tutel: request Tutel optimizations. Honoured only when Tutel is
            installed and the gate is top-1; otherwise a warning explains
            why the layer proceeds without Tutel.
    """
    super().__init__()
    self.gate = gate
    self.experts = experts
    self.group = group
    self.world_size = dist.get_world_size(group)
    self.num_local_experts = num_local_experts

    # Timing counters, zeroed here; presumably accumulated by forward()
    # when wall_clock_breakdown is enabled -- confirm against forward().
    self.time_falltoall = 0.0
    self.time_salltoall = 0.0
    self.time_moe = 0.0
    self.timers = SynchronizedWallClockTimer()
    self.wall_clock_breakdown = False

    # The Tutel path only supports top-1 gating, so a top-2 gate must fall
    # back to the regular path. Gates without a `k` attribute are treated
    # as top-1, which keeps the previous behaviour for such gates.
    gate_is_top1 = getattr(gate, 'k', 1) == 1
    self.use_tutel = use_tutel and TUTEL_INSTALLED and gate_is_top1
    if self.use_tutel:
        logger.info('Using Tutel optimizations.')
    elif use_tutel and not TUTEL_INSTALLED:
        logger.warning("Tutel optimization requested but not installed. "
                       "Proceeding without Tutel.")
    elif use_tutel and TUTEL_INSTALLED and not gate_is_top1:
        logger.warning(
            "To enable Tutel optimization, use top-1 instead of top-2 gate. "
            "Proceeding without Tutel.")
def __init__(self,
             gate: Module,
             experts: Module,
             ep_group_name,
             ep_size,
             num_local_experts: int,
             use_tutel: bool = False) -> None:
    """Set up an MoE layer around ``gate`` and this rank's ``experts``.

    Args:
        gate: gating module. Its ``k`` attribute is read only when Tutel is
            both requested and installed.
        experts: module holding the experts hosted on this rank.
        ep_group_name: name of the expert-parallel group, stored as given.
        ep_size: expert-parallel size, stored as given.
        num_local_experts: number of experts hosted on this rank.
        use_tutel: request Tutel optimizations. Enabled only when Tutel is
            installed and ``gate.k == 1``; otherwise a warning says why not.
    """
    super().__init__()

    # Sub-modules.
    self.gate = gate
    self.experts = experts

    # Expert-parallel configuration. The group handle starts unset
    # (presumably resolved later from ep_group_name -- confirm).
    self.ep_group = None
    self.ep_size = ep_size
    self.ep_group_name = ep_group_name
    self.num_local_experts = num_local_experts

    # Profiling state, zeroed at construction.
    self.time_falltoall = 0.0
    self.time_salltoall = 0.0
    self.time_moe = 0.0
    self.timers = SynchronizedWallClockTimer()
    self.wall_clock_breakdown = False

    self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1
    if self.use_tutel:
        logger.info('Using Tutel optimizations.')
    elif use_tutel:
        # Tutel was requested but cannot be used: report the reason.
        if not TUTEL_INSTALLED:
            logger.warning("Tutel optimization requested but not installed. "
                           "Proceeding without Tutel.")
        elif gate.k != 1:
            logger.warning(
                "To enable Tutel optimization, use top-1 instead of top-2 gate. "
                "Proceeding without Tutel.")
def __init__(self,
             model_dim: int,
             num_experts: int,
             k: int = 1,
             capacity_factor: float = 1.0,
             eval_capacity_factor: float = 1.0,
             min_capacity: int = 4,
             noisy_gate_policy: Optional[str] = None,
             drop_tokens: bool = True,
             use_rts: bool = True) -> None:
    """Build a top-k gate over ``num_experts`` experts.

    Creates a bias-free ``model_dim -> num_experts`` linear projection in
    float32 and stores the routing hyper-parameters as attributes. How they
    are used is decided by the gate's forward pass (not shown here).

    Raises:
        ValueError: if ``k`` is neither 1 nor 2.
    """
    super().__init__()

    # Only top-1 and top-2 gating are implemented.
    if k not in (1, 2):
        raise ValueError('Only top-1 and top-2 gatings are supported.')

    # Routing projection: one logit per expert, converted to float32.
    self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
    self.k = k

    # Capacity settings.
    self.capacity_factor = capacity_factor
    self.eval_capacity_factor = eval_capacity_factor
    self.min_capacity = min_capacity

    # Routing-policy switches.
    self.noisy_gate_policy = noisy_gate_policy
    self.drop_tokens = drop_tokens
    self.use_rts = use_rts

    # Profiling state.
    self.timers = SynchronizedWallClockTimer()
    self.wall_clock_breakdown = False
    self.gate_time = 0.0
def forward(ctx, run_function, *args):
    """Run ``run_function(*args)`` under ``torch.no_grad()`` and save what
    the backward pass needs in ``ctx``.

    Behaviour is driven by module-level flags read here:

    * ``PARTITION_ACTIVATIONS``: every positional arg except the last is
      flattened and narrowed to ``get_partition_size(item)`` elements
      starting at ``get_partition_start(item)``. Each arg's ``.data`` is then
      replaced by that slice, and a tensor holding its original ``size()``
      is saved next to it.
    * ``CONTIGUOUS_CHECKPOINTING``: slices and sizes are copied into buffers
      preallocated for ``num_layers`` layers instead of separate allocations.
    * ``PA_TO_CPU``: partitioned slices (or their buffers) are kept on CPU.
    * ``SYNCHRONIZE`` / ``PROFILE_TIME``: CUDA syncs and a ``forward`` timer
      around the call.

    The CPU RNG state, CUDA RNG state and CUDA RNG tracker states are stored
    on ``ctx`` before ``run_function`` runs (presumably so backward can
    replay the same random ops when recomputing).

    Returns:
        Whatever ``run_function`` returned. Non-floating-point tensor outputs
        are marked non-differentiable; non-tensor outputs are left alone.
    """
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    # Resolve the model-parallel topology once and cache it in module globals.
    if mp_rank is None:
        if mpu is not None:
            mp_rank = mpu.get_model_parallel_rank()
            mp_size = mpu.get_model_parallel_world_size()
            mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    # First call only: report the configuration and bind the CUDA device.
    if cuda_device is None:
        see_memory_usage("First Forward Begining", force=True)
        if dist.get_rank() == 0:
            logger.info(f"Activation Checkpointing Information")
            logger.info(
                f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
            )
            logger.info(
                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
            )
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        inputs = []
        for i, item in enumerate(args[:-1]):
            partition_size = get_partition_size(item)
            partition = item.detach().contiguous().view(-1).narrow(
                0, get_partition_start(item), partition_size).clone()

            if CONTIGUOUS_CHECKPOINTING:
                buffer_device = torch.device(
                    'cpu') if PA_TO_CPU else partition.device

                # One list of `num_layers` partition-sized slots per arg
                # position, allocated when missing or reset to None.
                if i >= len(contiguous_data_buffers):
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for _ in range(num_layers)
                    ]
                    contiguous_data_buffers.append(tensor_list)
                    data_offsets.append(0)
                elif contiguous_data_buffers[i] is None:
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for _ in range(num_layers)
                    ]
                    contiguous_data_buffers[i] = tensor_list
                    data_offsets[i] = 0

                contiguous_partition = contiguous_data_buffers[i][
                    data_offsets[i]].data.copy_(partition.data)
                data_offsets[i] = data_offsets[i] + 1
                inputs.append(contiguous_partition)
            else:
                partition = partition.cpu() if PA_TO_CPU else partition
                inputs.append(partition)

        # The last positional arg is never partitioned.
        inputs.append(args[-1])

    # Just in case something funky is happening, such as reuse of inputs,
    # run the function on copies of the original args moved to cuda_device.
    inputs_cuda = [item.to(cuda_device) for item in args]

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    del inputs_cuda

    if PARTITION_ACTIVATIONS:
        # Saved layout: (partitioned arg, original size) for each arg.
        new_args = []
        for i, (arg, inp) in enumerate(zip(args, inputs)):
            # Record the full shape before arg.data is swapped for its slice.
            size = torch.tensor(arg.size())

            arg.data = inp.data
            new_args.append(arg)

            if CONTIGUOUS_CHECKPOINTING:
                numel = size.numel()
                if i >= len(contiguous_size_buffers):
                    tmp = torch.tensor(())
                    contiguous_size_buffers.append(
                        tmp.new_empty([numel * num_layers],
                                      dtype=size.dtype,
                                      device=size.device))
                    size_offsets.append(0)
                elif contiguous_size_buffers[i] is None:
                    tmp = torch.tensor(())
                    contiguous_size_buffers[i] = tmp.new_empty(
                        [numel * num_layers],
                        dtype=size.dtype,
                        device=size.device)
                    size_offsets[i] = 0

                contiguous_size = contiguous_size_buffers[i].narrow(
                    0, size_offsets[i], numel).data.copy_(size.data)
                contiguous_size = contiguous_size.view_as(size)
                size_offsets[i] = size_offsets[i] + numel
                new_args.append(contiguous_size)
            else:
                new_args.append(size)

        ctx.save_for_backward(*new_args)
    else:
        ctx.save_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # Tensors returned from forward() may not be differentiable.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs
                            ] if not outputs.is_floating_point() else []
    else:
        # Skip non-tensor entries (e.g. None): they have no
        # is_floating_point() and cannot be marked non-differentiable.
        non_grad_outputs = [
            o for o in outputs
            if torch.is_tensor(o) and not o.is_floating_point()
        ]
    ctx.mark_non_differentiable(*non_grad_outputs)
    return outputs
def forward(ctx, run_function, all_outputs, *args):
    """Run ``run_function(*args)`` under ``torch.no_grad()`` and save what
    the backward pass needs in ``ctx``.

    Depending on module-level flags, the args kept for backward are:
    partitioned via ``partition_activations`` (``PARTITION_ACTIVATIONS``),
    copied to CPU via ``copy_to_device`` (``CPU_CHECKPOINT``), or the
    original args. Tensors are stored with ``ctx.save_for_backward``;
    non-tensor args and the flags from ``extract_tensors`` are stored as
    ``ctx.non_tensor_args`` / ``ctx.tensor_flags``. The CPU RNG state, CUDA
    RNG state and CUDA RNG tracker states are stored on ``ctx`` before
    ``run_function`` runs.

    Args:
        ctx: autograd context object.
        run_function: callable being checkpointed.
        all_outputs: extended in place with every output of ``run_function``
            (presumably a list owned by the caller -- confirm).
        *args: positional arguments for ``run_function``.

    Returns:
        The output tensor itself, or, for a non-tensor return value, a tuple
        built from the first list ``extract_tensors`` returns (the tensor
        entries, judging by how ``save_args_for_backward`` uses it).
    """
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    def save_args_for_backward(*all_args):
        # save_for_backward takes tensors only: save those through ctx and
        # keep the non-tensor args and their flags as plain ctx attributes.
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(
            all_objects=all_args)
        ctx.save_for_backward(*tensor_args)
        ctx.non_tensor_args = non_tensor_args
        ctx.tensor_flags = tensor_flags

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    # Resolve the model-parallel topology once and cache it in module
    # globals; prefer the tensor-model-parallel accessors when mpu has them.
    if mp_rank is None:
        if mpu is not None:
            if hasattr(mpu, 'get_tensor_model_parallel_rank'):
                mp_rank = mpu.get_tensor_model_parallel_rank()
                mp_size = mpu.get_tensor_model_parallel_world_size()
                mp_group = mpu.get_tensor_model_parallel_group()
            else:
                mp_rank = mpu.get_model_parallel_rank()
                mp_size = mpu.get_model_parallel_world_size()
                mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    # First call only: report the configuration and bind the CUDA device.
    if cuda_device is None:
        see_memory_usage("First Forward Beginning", force=False)
        if dist.get_rank() == 0:
            logger.info(f"Activation Checkpointing Information")
            logger.info(
                f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {CPU_CHECKPOINT}"
            )
            logger.info(
                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
            )
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(
                f"----Profiling time in checkpointing {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    # `inputs` is only defined (and later only read) in these two modes.
    if PARTITION_ACTIVATIONS:
        inputs = partition_activations(args,
                                       CPU_CHECKPOINT,
                                       CONTIGUOUS_CHECKPOINTING)
    elif CPU_CHECKPOINT:
        inputs = copy_to_device(args,
                                device=torch.device('cpu'),
                                criterion_func=is_activation_to_checkpoint)

    # just in case something funky is happening such as reuse of inputs
    inputs_cuda = copy_to_device(args,
                                 device=cuda_device,
                                 criterion_func=is_activation_to_checkpoint)

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    see_memory_usage("Before running forward on the layer", force=False)
    # ctx.save_for_backward(*args)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    see_memory_usage("After running forward on the layer", force=False)
    del inputs_cuda

    if PARTITION_ACTIVATIONS:
        new_args = get_partitioned_activations_for_backward(
            args, inputs, CONTIGUOUS_CHECKPOINTING)
        # The helper is expected to return pairs (presumably arg + size);
        # an odd count means that layout was broken.
        assert len(
            new_args
        ) % 2 == 0, f'save_for_backward called with odd number of args, {len(new_args)}'
        save_args_for_backward(*new_args)
    elif CPU_CHECKPOINT:
        new_args = get_cpu_activations_for_backward(args, inputs)
        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # Tensors returned from forward() may not be differentiable.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs
                            ] if not outputs.is_floating_point() else []
    else:
        non_grad_outputs = [
            o for o in outputs
            if torch.is_tensor(o) and not o.is_floating_point()
        ]
    ctx.mark_non_differentiable(*non_grad_outputs)

    # Hand every raw output to the caller through all_outputs, but return
    # only tensors from this function.
    if torch.is_tensor(outputs):
        all_outputs += [outputs]
        return outputs
    else:
        all_outputs += outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)
        return tuple(outputs)
# Setup for a script exercising deepspeed's NcclBackend. Only the setup is
# visible here; the rest of the script is not shown.
import torch
import deepspeed.comm as dist
import numpy as np
import argparse
import deepspeed
import os
from deepspeed.runtime.comm.nccl import NcclBackend
from deepspeed.utils.timer import SynchronizedWallClockTimer
from statistics import mean

timers = SynchronizedWallClockTimer()

parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=-1)
args = parser.parse_args()

deepspeed.init_distributed(dist_backend='nccl')
# The --local_rank CLI value is overwritten by the LOCAL_RANK environment
# variable; this raises KeyError if LOCAL_RANK is not set.
args.local_rank = int(os.environ['LOCAL_RANK'])

# Bind this process to its local GPU.
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)

# Must run after init_distributed above.
size = dist.get_world_size()
rank = dist.get_rank()

backend = NcclBackend()
local_rank = args.local_rank

# Setting tensor_size (BERT-Large)
# 300 * 2**20 = 314,572,800 elements.
tensor_size = 300 * 2**20
def forward(ctx, run_function, all_outputs, *args):
    """Run ``run_function(*args)`` under ``torch.no_grad()`` and save what
    the backward pass needs in ``ctx``.

    With ``PARTITION_ACTIVATIONS`` set, every tensor arg except the last is
    flattened and narrowed to ``get_partition_size(item)`` elements starting
    at ``get_partition_start(item)``. The slice is optionally copied into
    preallocated contiguous buffers (``CONTIGUOUS_CHECKPOINTING``) and/or
    kept on CPU (``PA_TO_CPU``). Each tensor arg's ``.data`` is then replaced
    by its slice and saved next to a tensor holding its original ``size()``.
    Non-tensor args are passed through untouched.

    Saved args are split by ``extract_tensors``: tensors go through
    ``ctx.save_for_backward``, and the rest is kept as
    ``ctx.non_tensor_args`` / ``ctx.tensor_flags``. The CPU RNG state, CUDA
    RNG state and CUDA RNG tracker states are stored on ``ctx`` before
    ``run_function`` runs.

    Args:
        ctx: autograd context object.
        run_function: callable being checkpointed.
        all_outputs: extended in place with every output of ``run_function``
            (presumably a list owned by the caller -- confirm).
        *args: positional arguments for ``run_function``.

    Returns:
        The output tensor itself, or, for a non-tensor return value, a tuple
        built from the first list ``extract_tensors`` returns.
    """
    global mpu, timers, SYNCHRONIZE, PROFILE_TIME

    def save_args_for_backward(*all_args):
        # save_for_backward takes tensors only: save those through ctx and
        # keep the non-tensor args and their flags as plain ctx attributes.
        tensor_args, non_tensor_args, tensor_flags = extract_tensors(
            all_objects=all_args)
        ctx.save_for_backward(*tensor_args)
        ctx.non_tensor_args = non_tensor_args
        ctx.tensor_flags = tensor_flags

    if SYNCHRONIZE:
        torch.cuda.synchronize()

    if timers is None and PROFILE_TIME:
        timers = Timers()

    if PROFILE_TIME:
        timers('forward').start()

    ctx.run_function = run_function
    global num_layers
    global mp_rank, mp_size, mp_group
    global contiguous_data_buffers, contiguous_size_buffers
    global data_offsets, size_offsets
    # Resolve the model-parallel topology once and cache it in module globals.
    if mp_rank is None:
        if mpu is not None:
            mp_rank = mpu.get_model_parallel_rank()
            mp_size = mpu.get_model_parallel_world_size()
            mp_group = mpu.get_model_parallel_group()
        else:
            mp_rank = 0
            mp_size = 1
            mp_group = None

    global cuda_device, transport_stream, PARTITION_ACTIVATIONS, buffer_0, buffer_1, buffer_0_offset, buffer_1_offset

    # First call only: report the configuration and bind the CUDA device.
    if cuda_device is None:
        see_memory_usage("First Forward Begining", force=False)
        if dist.get_rank() == 0:
            logger.info(f"Activation Checkpointing Information")
            logger.info(
                f"----Partition Activations {PARTITION_ACTIVATIONS}, CPU CHECKPOINTING {PA_TO_CPU}"
            )
            logger.info(
                f"----contiguous Memory Checkpointing {CONTIGUOUS_CHECKPOINTING} with {num_layers} total layers"
            )
            logger.info(f"----Synchronization {SYNCHRONIZE}")
            logger.info(f"----Profiling {PROFILE_TIME}")

        cuda_device = torch.cuda.current_device()
        transport_stream = torch.cuda.Stream(device=cuda_device)

    if PARTITION_ACTIVATIONS:
        inputs = []
        for i, item in enumerate(args[:-1]):
            if not torch.is_tensor(item):
                inputs.append(item)
                continue

            partition_size = get_partition_size(item)
            partition = item.detach().contiguous().view(-1).narrow(
                0, get_partition_start(item), partition_size).clone()

            if CONTIGUOUS_CHECKPOINTING:
                buffer_device = torch.device(
                    'cpu') if PA_TO_CPU else partition.device

                if i >= len(contiguous_data_buffers):
                    # Non-tensor args are skipped above, so `i` can run ahead
                    # of the buffer list. Pad with empty slots so that buffer
                    # `i` always belongs to positional arg `i`. This is a
                    # no-op when every arg is a tensor.
                    contiguous_data_buffers.extend(
                        [None] * (i - len(contiguous_data_buffers)))
                    data_offsets.extend([0] * (i - len(data_offsets)))
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for _ in range(num_layers)
                    ]
                    contiguous_data_buffers.append(tensor_list)
                    data_offsets.append(0)
                elif contiguous_data_buffers[i] is None:
                    tensor_list = [
                        torch.tensor(()).new_empty([partition_size],
                                                   dtype=partition.dtype,
                                                   device=buffer_device)
                        for _ in range(num_layers)
                    ]
                    contiguous_data_buffers[i] = tensor_list
                    data_offsets[i] = 0

                slot = contiguous_data_buffers[i][data_offsets[i]].data
                # 'new_empty' returns uninitialized pages, which would
                # otherwise be populated during the copy and slow it down.
                # Write 0 to one element per page first, so that cost is paid
                # here while previously launched GPU kernels are still running.
                page_stride = int(mmap.PAGESIZE / slot.element_size())
                slot[range(0, slot.shape[0], page_stride)] = 0

                contiguous_partition = slot.copy_(partition.data)
                data_offsets[i] = data_offsets[i] + 1
                inputs.append(contiguous_partition)
            else:
                partition = partition.cpu() if PA_TO_CPU else partition
                inputs.append(partition)

        # The last positional arg is never partitioned.
        inputs.append(args[-1])

    # Just in case something funky is happening, such as reuse of inputs.
    inputs_cuda = move_to_device(args, cuda_device)

    # Copy the rng states.
    ctx.fwd_cpu_rng_state = torch.get_rng_state()
    ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
    ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

    see_memory_usage("Before running forward on the layer", force=False)
    with torch.no_grad():
        outputs = run_function(*inputs_cuda)

    see_memory_usage("After running forward on the layer", force=False)
    del inputs_cuda

    if PARTITION_ACTIVATIONS:
        new_args = []
        for i, (arg, inp) in enumerate(zip(args, inputs)):
            if not torch.is_tensor(arg):
                # NOTE(review): non-tensor args contribute one entry while
                # tensor args contribute an (arg, size) pair -- confirm the
                # backward pass expects this layout.
                new_args.append(arg)
                continue

            # Record the full shape before arg.data is swapped for its slice.
            size = torch.tensor(arg.size())

            arg.data = inp.data
            new_args.append(arg)

            if CONTIGUOUS_CHECKPOINTING:
                numel = size.numel()
                if i >= len(contiguous_size_buffers):
                    # Same positional padding as for the data buffers above.
                    contiguous_size_buffers.extend(
                        [None] * (i - len(contiguous_size_buffers)))
                    size_offsets.extend([0] * (i - len(size_offsets)))
                    tmp = torch.tensor(())
                    contiguous_size_buffers.append(
                        tmp.new_empty([numel * num_layers],
                                      dtype=size.dtype,
                                      device=size.device))
                    size_offsets.append(0)
                elif contiguous_size_buffers[i] is None:
                    tmp = torch.tensor(())
                    contiguous_size_buffers[i] = tmp.new_empty(
                        [numel * num_layers],
                        dtype=size.dtype,
                        device=size.device)
                    size_offsets[i] = 0

                contiguous_size = contiguous_size_buffers[i].narrow(
                    0, size_offsets[i], numel).data.copy_(size.data)
                contiguous_size = contiguous_size.view_as(size)
                size_offsets[i] = size_offsets[i] + numel
                new_args.append(contiguous_size)
            else:
                new_args.append(size)

        save_args_for_backward(*new_args)
    else:
        save_args_for_backward(*args)

    if PROFILE_TIME:
        timers('forward').stop()
        timers.log(['forward'])
    if SYNCHRONIZE:
        torch.cuda.synchronize()

    # Tensors returned from forward() may not be differentiable.
    if torch.is_tensor(outputs):
        non_grad_outputs = [outputs
                            ] if not outputs.is_floating_point() else []
    else:
        non_grad_outputs = [
            o for o in outputs
            if torch.is_tensor(o) and not o.is_floating_point()
        ]
    ctx.mark_non_differentiable(*non_grad_outputs)

    # Hand every raw output to the caller through all_outputs, but return
    # only tensors from this function.
    if torch.is_tensor(outputs):
        all_outputs += [outputs]
        return outputs
    else:
        all_outputs += outputs
        outputs, _, _ = extract_tensors(all_objects=outputs)
        return tuple(outputs)
from mpi4py import MPI import time import torch import torch.distributed as dist import numpy as np import deepspeed from deepspeed.runtime.comm.mpi import MpiBackend # Configure wall clock timer from deepspeed.utils.timer import SynchronizedWallClockTimer from statistics import mean timers = SynchronizedWallClockTimer() comm = MPI.COMM_WORLD size = comm.Get_size() rank = comm.Get_rank() deepspeed.init_distributed(dist_backend='nccl') # Change cuda_aware to True to test out CUDA-Aware MPI communication backend = MpiBackend(cuda_aware=False) device = torch.device('cuda', rank % torch.cuda.device_count()) tensor_size = 300 * 2**20 server_size = int(tensor_size / size) if tensor_size % (8 * size) != 0: right_tensor_size = tensor_size + (8 * size - (tensor_size % (8 * size))) else: