Beispiel #1
0
def allgather_fn(output_tensor,
                 input_tensor,
                 group=None,
                 async_op=False,
                 debug=get_caller_func()):
    """All-gather ``input_tensor`` from every rank into ``output_tensor``.

    Uses ``all_gather_base`` when the backend reports
    ``has_allgather_base``. Otherwise it splits ``output_tensor`` with
    ``torch.chunk`` into ``cdb.get_world_size(group)`` pieces and calls the
    list-based ``all_gather`` on them. In that case a performance warning is
    logged once, on rank 0 only.

    Args:
        output_tensor: destination tensor for the gathered data.
        input_tensor: this rank's contribution.
        group: process group forwarded to the underlying collective.
        async_op: forwarded to the underlying collective.
        debug: forwarded to the underlying collective.

    Returns:
        Whatever ``all_gather_base`` / ``all_gather`` returns.

    Raises:
        AssertionError: if the communication backend is missing or not
            initialized.
    """
    global cdb
    global has_warned_all_gather
    assert cdb is not None and cdb.is_initialized(
    ), 'DeepSpeed backend not set, please initialize it using init_process_group()'

    # Fast path: the backend gathers straight into the flat output tensor.
    if cdb.has_allgather_base:
        return all_gather_base(output_tensor,
                               input_tensor,
                               group=group,
                               async_op=async_op,
                               debug=debug)

    # Fallback path. ``and`` short-circuits, so get_rank() is only queried
    # while the warning has not been emitted yet.
    warn_now = not has_warned_all_gather and get_rank() == 0
    if warn_now:
        utils.logger.warning(
            "unable to find torch.distributed._all_gather_base. will fall back to "
            "torch.distributed.all_gather which will result in suboptimal performance. "
            "please consider upgrading your pytorch installation.")
        has_warned_all_gather = True

    num_chunks = cdb.get_world_size(group)
    # torch.chunk returns views, so writes into each chunk land in output_tensor.
    chunk_views = list(torch.chunk(output_tensor, num_chunks))
    return all_gather(chunk_views,
                      input_tensor,
                      group=group,
                      async_op=async_op,
                      debug=debug)
Beispiel #2
0
def irecv(tensor,
          src=None,
          group=None,
          tag=0,
          prof=False,
          log_name='irecv',
          debug=get_caller_func()):
    """Post a non-blocking receive into ``tensor``.

    Args:
        tensor: buffer that receives the incoming data.
        src: source rank, forwarded to the backend (``None`` by default).
        group: process group, forwarded to the backend.
        tag: message tag, forwarded to the backend.
        prof, log_name, debug: not read in this body. Presumably they are
            consumed by a profiling/logging wrapper outside this view (TODO
            confirm).

    Returns:
        The backend's ``irecv`` result. Presumably this is a request/work
        handle to ``wait()`` on, as with ``torch.distributed.irecv``
        (confirm against the active backend).
    """
    global cdb
    # Must use the non-blocking primitive. ``cdb.recv`` would block until
    # the message arrives and would not return a handle for the caller.
    return cdb.irecv(tensor=tensor, src=src, group=group, tag=tag)
Beispiel #3
0
def isend(tensor,
          dst,
          group=None,
          tag=0,
          prof=False,
          log_name='isend',
          debug=get_caller_func()):
    """Post a non-blocking send of ``tensor`` to rank ``dst``.

    Args:
        tensor: data to send.
        dst: destination rank, forwarded to the backend.
        group: process group, forwarded to the backend.
        tag: message tag, forwarded to the backend.
        prof, log_name, debug: not read in this body. Presumably they are
            consumed by a profiling/logging wrapper outside this view (TODO
            confirm).

    Returns:
        The backend's ``isend`` result. Presumably this is a request/work
        handle to ``wait()`` on, as with ``torch.distributed.isend``
        (confirm against the active backend).
    """
    global cdb
    # Must use the non-blocking primitive. ``cdb.send`` would block until
    # the transfer completes and would not return a handle for the caller.
    return cdb.isend(tensor=tensor, dst=dst, group=group, tag=tag)
Beispiel #4
0
def all_gather_base(output_tensor,
                    tensor,
                    group=None,
                    async_op=False,
                    prof=False,
                    log_name='all_gather_base',
                    debug=get_caller_func()):
    """Delegate a flat all-gather to the active backend's ``all_gather_base``.

    The local ``tensor`` is passed to the backend under the keyword
    ``input_tensor``. ``output_tensor``, ``group`` and ``async_op`` are
    forwarded unchanged. ``prof``, ``log_name`` and ``debug`` are not read
    here; presumably a profiling wrapper outside this view uses them.

    Returns:
        Whatever the backend's ``all_gather_base`` returns.
    """
    backend = cdb
    return backend.all_gather_base(
        output_tensor=output_tensor,
        input_tensor=tensor,
        group=group,
        async_op=async_op,
    )
Beispiel #5
0
def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False,
               prof=False,
               log_name='all_gather',
               debug=get_caller_func()):
    """Delegate a list-based all-gather to the active backend.

    ``tensor_list``, ``tensor``, ``group`` and ``async_op`` are forwarded to
    ``cdb.all_gather`` as keyword arguments. ``prof``, ``log_name`` and
    ``debug`` are not read here; presumably a profiling wrapper outside this
    view uses them.

    Returns:
        Whatever the backend's ``all_gather`` returns.
    """
    forwarded = {
        'tensor_list': tensor_list,
        'tensor': tensor,
        'group': group,
        'async_op': async_op,
    }
    return cdb.all_gather(**forwarded)
Beispiel #6
0
def broadcast(tensor,
              src,
              group=None,
              async_op=False,
              prof=False,
              log_name='broadcast',
              debug=get_caller_func()):
    """Delegate a broadcast from rank ``src`` to the active backend.

    ``tensor``, ``src``, ``group`` and ``async_op`` are forwarded to
    ``cdb.broadcast`` as keyword arguments. ``prof``, ``log_name`` and
    ``debug`` are not read here; presumably a profiling wrapper outside this
    view uses them.

    Returns:
        Whatever the backend's ``broadcast`` returns.
    """
    backend = cdb
    return backend.broadcast(
        tensor=tensor,
        src=src,
        group=group,
        async_op=async_op,
    )
Beispiel #7
0
def all_reduce(tensor,
               op=ReduceOp.SUM,
               group=None,
               async_op=False,
               prof=False,
               log_name='all_reduce',
               debug=get_caller_func()):
    """Delegate an all-reduce of ``tensor`` to the active backend.

    Args:
        tensor: tensor to reduce, forwarded to the backend.
        op: reduction operator (``ReduceOp.SUM`` by default).
        group: process group, forwarded to the backend.
        async_op: forwarded to the backend.
        prof, log_name, debug: not read in this body. Presumably they are
            consumed by a profiling/logging wrapper outside this view (TODO
            confirm).

    Returns:
        Whatever the backend's ``all_reduce`` returns.
    """
    # TODO: decide whether comm-call timing and TensorBoard logging belong here.
    global cdb
    # Forwarded positionally in the order (tensor, op, group, async_op).
    return cdb.all_reduce(tensor, op, group, async_op)
Beispiel #8
0
def reduce_scatter(output,
                   input_list,
                   op=ReduceOp.SUM,
                   group=None,
                   async_op=False,
                   prof=False,
                   log_name='reduce_scatter',
                   debug=get_caller_func()):
    """Delegate a reduce-scatter to the active backend.

    ``output``, ``input_list``, ``op`` (``ReduceOp.SUM`` by default),
    ``group`` and ``async_op`` are forwarded to ``cdb.reduce_scatter`` as
    keyword arguments. ``prof``, ``log_name`` and ``debug`` are not read
    here; presumably a profiling wrapper outside this view uses them.

    Returns:
        Whatever the backend's ``reduce_scatter`` returns.
    """
    forwarded = {
        'output': output,
        'input_list': input_list,
        'op': op,
        'group': group,
        'async_op': async_op,
    }
    return cdb.reduce_scatter(**forwarded)
Beispiel #9
0
def reduce(tensor,
           dst,
           op=ReduceOp.SUM,
           group=None,
           async_op=False,
           prof=False,
           log_name='reduce',
           debug=get_caller_func()):
    """Delegate a reduce toward rank ``dst`` to the active backend.

    ``tensor``, ``dst``, ``op`` (``ReduceOp.SUM`` by default), ``group`` and
    ``async_op`` are forwarded to ``cdb.reduce`` as keyword arguments.
    ``prof``, ``log_name`` and ``debug`` are not read here; presumably a
    profiling wrapper outside this view uses them.

    Returns:
        Whatever the backend's ``reduce`` returns.
    """
    backend = cdb
    return backend.reduce(
        tensor=tensor,
        dst=dst,
        op=op,
        group=group,
        async_op=async_op,
    )
Beispiel #10
0
def scatter(tensor,
            scatter_list=None,
            src=0,
            group=None,
            async_op=False,
            prof=False,
            log_name='scatter',
            debug=get_caller_func()):
    """Delegate a scatter from rank ``src`` (default 0) to the active backend.

    ``tensor``, ``scatter_list`` (default ``None``), ``src``, ``group`` and
    ``async_op`` are forwarded to ``cdb.scatter`` as keyword arguments.
    ``prof``, ``log_name`` and ``debug`` are not read here; presumably a
    profiling wrapper outside this view uses them.

    Returns:
        Whatever the backend's ``scatter`` returns.
    """
    forwarded = {
        'tensor': tensor,
        'scatter_list': scatter_list,
        'src': src,
        'group': group,
        'async_op': async_op,
    }
    return cdb.scatter(**forwarded)
Beispiel #11
0
def gather(tensor,
           gather_list=None,
           dst=0,
           group=None,
           async_op=False,
           prof=False,
           log_name='gather',
           debug=get_caller_func()):
    """Delegate a gather toward rank ``dst`` (default 0) to the active backend.

    ``tensor``, ``gather_list`` (default ``None``), ``dst``, ``group`` and
    ``async_op`` are forwarded to ``cdb.gather`` as keyword arguments.
    ``prof``, ``log_name`` and ``debug`` are not read here; presumably a
    profiling wrapper outside this view uses them.

    Returns:
        Whatever the backend's ``gather`` returns.
    """
    backend = cdb
    return backend.gather(
        tensor=tensor,
        gather_list=gather_list,
        dst=dst,
        group=group,
        async_op=async_op,
    )
Beispiel #12
0
def all_to_all_single(output,
                      tensor,
                      output_split_sizes=None,
                      input_split_sizes=None,
                      group=None,
                      async_op=False,
                      prof=False,
                      log_name='all_to_all_single',
                      debug=get_caller_func()):
    """Delegate a single-tensor all-to-all exchange to the active backend.

    The local ``tensor`` is passed to the backend under the keyword
    ``input``. ``output``, both split-size lists (default ``None``),
    ``group`` and ``async_op`` are forwarded unchanged. ``prof``,
    ``log_name`` and ``debug`` are not read here; presumably a profiling
    wrapper outside this view uses them.

    Returns:
        Whatever the backend's ``all_to_all_single`` returns.
    """
    forwarded = {
        'output': output,
        'input': tensor,
        'output_split_sizes': output_split_sizes,
        'input_split_sizes': input_split_sizes,
        'group': group,
        'async_op': async_op,
    }
    return cdb.all_to_all_single(**forwarded)
Beispiel #13
0
def barrier(group=None,
            prof=False,
            log_name='barrier',
            debug=get_caller_func()):
    """Synchronize ranks through the active backend's ``barrier``.

    Args:
        group: process group to synchronize. With ``None`` (the default), the
            backend's argument-free ``barrier()`` is called, as before.
            Otherwise the group is forwarded as ``group=``.
        prof, log_name, debug: not read in this body. Presumably they are
            consumed by a profiling/logging wrapper outside this view (TODO
            confirm).

    Returns:
        Whatever the backend's ``barrier`` returns.

    Raises:
        TypeError: if a ``group`` is given but the backend's ``barrier`` does
            not accept a ``group`` keyword. This replaces silently
            synchronizing the wrong set of ranks.
    """
    global cdb
    if group is None:
        # Keep the argument-free call so backends whose barrier() takes no
        # parameters continue to work for the default case.
        return cdb.barrier()
    # An explicit group must be honored. Dropping it turns a sub-group
    # barrier into a default-group barrier, which can deadlock when ranks
    # outside the group never reach this call.
    return cdb.barrier(group=group)