def allgather_fn(output_tensor, input_tensor, group=None, async_op=False, debug=get_caller_func()):
    global cdb
    global has_warned_all_gather
    assert cdb is not None and cdb.is_initialized(), \
        'DeepSpeed backend not set, please initialize it using init_process_group()'
    if cdb.has_allgather_base:
        return all_gather_base(output_tensor, input_tensor, group=group, async_op=async_op, debug=debug)
    else:
        if not has_warned_all_gather and get_rank() == 0:
            utils.logger.warning("unable to find torch.distributed._all_gather_base. will fall back to "
                                 "torch.distributed.all_gather which will result in suboptimal performance. "
                                 "please consider upgrading your pytorch installation.")
            has_warned_all_gather = True
        # Fallback path: view the flat output buffer as world_size chunks and all_gather into them.
        output_tensors = list(torch.chunk(output_tensor, cdb.get_world_size(group)))
        return all_gather(output_tensors, input_tensor, group=group, async_op=async_op, debug=debug)

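
# Usage sketch (not part of the original module): gather each rank's flat tensor into a single
# contiguous buffer of size world_size * numel via allgather_fn above. Assumes the backend has
# been initialized (e.g. deepspeed.init_distributed()), one process per GPU with the default CUDA
# device already set, and the module-level get_rank()/get_world_size() helpers from this package.
def _example_allgather_fn():
    world_size = get_world_size()
    local = torch.full((4, ), float(get_rank()), device='cuda')
    # Flat output buffer; allgather_fn uses _all_gather_base when available,
    # otherwise it chunks this buffer and falls back to all_gather.
    flat_out = torch.empty(world_size * local.numel(), dtype=local.dtype, device='cuda')
    allgather_fn(flat_out, local)
    return flat_out.view(world_size, -1)
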
def irecv(tensor, src=None, group=None, tag=0, prof=False, log_name='irecv', debug=get_caller_func()):
    global cdb
    # Non-blocking receive: return the backend's request handle so the caller can wait() on it.
    return cdb.irecv(tensor=tensor, src=src, group=group, tag=tag)

def isend(tensor, dst, group=None, tag=0, prof=False, log_name='isend', debug=get_caller_func()):
    global cdb
    # Non-blocking send: return the backend's request handle so the caller can wait() on it.
    return cdb.isend(tensor=tensor, dst=dst, group=group, tag=tag)

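
# Usage sketch (not part of the original module): a simple ring exchange using the non-blocking
# point-to-point wrappers above. Assumes an initialized backend, one process per GPU with the
# default CUDA device set; the returned handles must be waited on before the buffers are reused.
def _example_isend_irecv():
    rank, world_size = get_rank(), get_world_size()
    send_buf = torch.full((8, ), float(rank), device='cuda')
    recv_buf = torch.empty_like(send_buf)
    send_req = isend(send_buf, dst=(rank + 1) % world_size)
    recv_req = irecv(recv_buf, src=(rank - 1) % world_size)
    send_req.wait()
    recv_req.wait()
    return recv_buf
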
def all_gather_base(output_tensor, tensor, group=None, async_op=False, prof=False, log_name='all_gather_base', debug=get_caller_func()):
    global cdb
    return cdb.all_gather_base(output_tensor=output_tensor, input_tensor=tensor, group=group, async_op=async_op)

def all_gather(tensor_list, tensor, group=None, async_op=False, prof=False, log_name='all_gather', debug=get_caller_func()):
    global cdb
    return cdb.all_gather(tensor_list=tensor_list, tensor=tensor, group=group, async_op=async_op)

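
# Usage sketch (not part of the original module): the list form of all-gather, where each rank
# contributes one tensor and receives a list with one entry per rank. Assumes an initialized
# backend and equally sized inputs on every rank.
def _example_all_gather():
    world_size = get_world_size()
    local = torch.full((4, ), float(get_rank()), device='cuda')
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    all_gather(gathered, local)
    return gathered
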
def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='broadcast', debug=get_caller_func()):
    global cdb
    return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

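
# Usage sketch (not part of the original module): replicate rank 0's tensor onto every rank in
# place. Assumes an initialized backend and the default CUDA device set per process.
def _example_broadcast():
    t = torch.arange(4.0, device='cuda') if get_rank() == 0 else torch.empty(4, device='cuda')
    broadcast(t, src=0)
    return t  # identical on every rank after the call
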
def all_reduce(tensor, op=ReduceOp.SUM, group=None, async_op=False, prof=False, log_name='all_reduce', debug=get_caller_func()):
    global cdb
    return cdb.all_reduce(tensor, op, group, async_op)

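
# Usage sketch (not part of the original module): in-place sum across all ranks; after the call
# every rank holds the same reduced values. Assumes an initialized backend.
def _example_all_reduce():
    local = torch.full((4, ), float(get_rank()), device='cuda')
    all_reduce(local, op=ReduceOp.SUM)
    # local now equals sum(range(world_size)) in every element, on every rank.
    return local
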
def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=None, async_op=False, prof=False, log_name='reduce_scatter', debug=get_caller_func()):
    global cdb
    return cdb.reduce_scatter(output=output, input_list=input_list, op=op, group=group, async_op=async_op)

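
# Usage sketch (not part of the original module): reduce a list of world_size chunks and leave
# rank r holding the reduction of chunk r from every rank. Assumes an initialized backend and
# that every rank supplies a list of equally sized chunks, one per rank.
def _example_reduce_scatter():
    world_size = get_world_size()
    chunks = [torch.full((4, ), float(get_rank()), device='cuda') for _ in range(world_size)]
    out = torch.empty(4, device='cuda')
    reduce_scatter(out, chunks, op=ReduceOp.SUM)
    return out  # rank r holds the sum over all ranks of their chunks[r]
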
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, async_op=False, prof=False, log_name='reduce', debug=get_caller_func()):
    global cdb
    return cdb.reduce(tensor=tensor, dst=dst, op=op, group=group, async_op=async_op)

def scatter(tensor, scatter_list=None, src=0, group=None, async_op=False, prof=False, log_name='scatter', debug=get_caller_func()):
    global cdb
    return cdb.scatter(tensor=tensor, scatter_list=scatter_list, src=src, group=group, async_op=async_op)

def gather(tensor, gather_list=None, dst=0, group=None, async_op=False, prof=False, log_name='gather', debug=get_caller_func()):
    global cdb
    return cdb.gather(tensor=tensor, gather_list=gather_list, dst=dst, group=group, async_op=async_op)

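
# Usage sketch (not part of the original module): collect one tensor from every rank onto dst;
# scatter() above is the mirror operation. Only the destination rank provides gather_list, the
# other ranks pass None. Assumes an initialized backend whose gather/scatter support matches the
# underlying torch.distributed backend in use.
def _example_gather(dst=0):
    world_size = get_world_size()
    local = torch.full((4, ), float(get_rank()), device='cuda')
    gather_list = [torch.empty_like(local) for _ in range(world_size)] if get_rank() == dst else None
    gather(local, gather_list=gather_list, dst=dst)
    return gather_list  # populated on dst, None elsewhere
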
def all_to_all_single(output, tensor, output_split_sizes=None, input_split_sizes=None, group=None, async_op=False, prof=False, log_name='all_to_all_single', debug=get_caller_func()):
    global cdb
    return cdb.all_to_all_single(output=output,
                                 input=tensor,
                                 output_split_sizes=output_split_sizes,
                                 input_split_sizes=input_split_sizes,
                                 group=group,
                                 async_op=async_op)

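
# Usage sketch (not part of the original module): each rank sends one equally sized slice of a
# flat buffer to every other rank and receives one slice from each. Assumes an initialized
# backend; uneven splits would pass output_split_sizes/input_split_sizes instead.
def _example_all_to_all_single():
    world_size = get_world_size()
    # Slice i of `inp` is sent to rank i; slice j of `out` arrives from rank j.
    inp = torch.full((world_size * 4, ), float(get_rank()), device='cuda')
    out = torch.empty_like(inp)
    all_to_all_single(out, inp)
    return out.view(world_size, 4)
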
def barrier(group=None, prof=False, log_name='barrier', debug=get_caller_func()):
    global cdb
    # Note: `group` is accepted for API symmetry but is not forwarded to the backend here;
    # the barrier synchronizes the default process group.
    return cdb.barrier()

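
# Usage sketch (not part of the original module): let rank 0 finish a side effect (here a
# checkpoint write to a hypothetical path) before any other rank reads it. Assumes an
# initialized backend and a filesystem visible to all ranks.
def _example_barrier(path='/tmp/example_ckpt.pt'):
    if get_rank() == 0:
        torch.save({'step': 0}, path)
    barrier()  # every rank waits here until all ranks have arrived
    return torch.load(path, map_location='cpu')
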