def all_benchmark():
    # Benchmark the 'all' multiplication (plain matmul) locally on the main process
    if is_main_process():
        xlarge = torch.rand(1,
                            75000 // args.scale,
                            75000 // args.scale,
                            device=device)
        y = torch.rand(1, 75000 // args.scale, 768, device=device)
        input_memory = torch.cuda.memory_allocated()
        print(f'Memory allocated by xlarge/y: '
              f'{humanize.naturalsize(input_memory)}')
        result, op_time, peak_memory = measure(torch.matmul, xlarge, y)
        del xlarge
        del y
        torch.cuda.empty_cache()
        output_memory = torch.cuda.memory_allocated()
        print(f'matmul_all - Output memory consumption: '
              f'{humanize.naturalsize(output_memory)}')
        del result
        torch.cuda.empty_cache()

    # Benchmark the 'all' multiplication with the distributed implementation
    xsmall = torch.rand(1,
                        75000 // (3 * args.scale),
                        75000 // args.scale,
                        device=device)
    ysmall = torch.rand(1, 75000 // (3 * args.scale), 768, device=device)
    dist_input_size = torch.cuda.memory_allocated()
    print(f'Memory allocated by xsmall/ysmall: '
          f'{humanize.naturalsize(dist_input_size)}')
    synchronize()
    result, dop_time, dpeak_memory = measure(distributed_matmul_all,
                                             xsmall,
                                             ysmall,
                                             offset=args.offset)
    del xsmall
    del ysmall

    torch.cuda.empty_cache()
    doutput_memory = torch.cuda.memory_allocated()
    print(f'distributed_matmul_all - Output memory consumption: '
          f'{humanize.naturalsize(doutput_memory)}')
    del result
    torch.cuda.empty_cache()

    all_input_size = comm.gather(dist_input_size, root=0)
    all_op_time = comm.gather(dop_time, root=0)
    all_peak_memory = comm.gather(dpeak_memory, root=0)
    all_output_memory = comm.gather(doutput_memory, root=0)

    if is_main_process():
        avg_input_size = sum(all_input_size) / len(all_input_size)
        avg_op_time = sum(all_op_time) / len(all_op_time)
        avg_peak_memory = sum(all_peak_memory) / len(all_peak_memory)
        avg_output_memory = sum(all_output_memory) / len(all_output_memory)

        return (input_memory, output_memory, op_time, peak_memory,
                avg_input_size, avg_op_time, avg_peak_memory,
                avg_output_memory)
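

# Hedged sketch (assumption, not the module's actual implementation): the
# `measure` helper used above is defined elsewhere in this file and returns
# `(result, op_time, peak_memory)`. A minimal version of such a helper,
# assuming a CUDA device, might look roughly like `_measure_sketch` below
# (the name is hypothetical).
def _measure_sketch(fn, *tensors, **kwargs):
    import time  # local import to keep the sketch self-contained
    torch.cuda.reset_peak_memory_stats()  # restart peak-memory tracking
    torch.cuda.synchronize()              # flush pending kernels before timing
    start = time.time()
    result = fn(*tensors, **kwargs)       # run the benchmarked operation
    torch.cuda.synchronize()              # wait for async kernels to finish
    elapsed = time.time() - start         # wall-clock seconds
    peak_memory = torch.cuda.max_memory_allocated()
    return result, elapsed, peak_memory

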
def distributed_matmul_nt(left: Tensor, right: Tensor, offset=32) -> Tensor:
    """
    Multiply two sequence tensors to obtain the result of :math:`AB^T`.

    Left and right inputs can be N-dimensional tensors of size
    :math:`* \times \frac{T}{N} \times D`, where :math:`T` is the total length,
    :math:`N` is the total number of processes available and :math:`D`, the
    dimension of the sequence. The result of this function is a tensor of size
    :math:`* \times \frac{T}{N} \times T`, that contain the result chunk for
    each process of the resulting operation.

    Inputs
    ------
    left: Tensor
        :math:`A` in :math:`AB^T`, must be of size
        :math:`* \times \frac{T}{N} \times D`.
    right: Tensor
        :math:`B` in :math:`AB^T`, must be of size
        :math:`* \times \frac{T}{N} \times D`.
    offset: int
        Number of chunks to communicate during each distributed step, it must
        be a factor of :math:`\frac{T}{N}`. This factor should be modified in
        order to reduce the total computing time at the expense of the memory
        used.

    Returns
    -------
    result: Tensor
        For each process, this function computes the corresponding segment
        of the operation :math:`A^T B`, of size
        :math:`* \times \frac{T}{N} \times T`.
    """
    synchronize()
    rows = left.size(-2)
    world_size = get_world_size()
    total_rows = rows * world_size

    prefix_size = tuple(left.size())[:-2]
    size = (left.size(-2), right.size(-2))
    size = (world_size,) + prefix_size + size
    # (world_size, ...dims, T/N, T/N)
    result = torch.empty(size, device=left.device)
    final_size = prefix_size + (left.size(-2), total_rows)

    for row in range(0, rows, offset):
        end_bound = row + offset
        current_row = right[..., row:end_bound, :].contiguous()
        # [r0[row:end_bound], r1[row:end_bound], ..., rworld[row:end_bound]]
        # all_rows: world_size x ... x offset x dim
        current_row = current_row.unsqueeze(0)
        all_rows = hvd.allgather(current_row, name=f'scatter_rows_{row}')
        partial_results = left.matmul(all_rows.transpose(-1, -2))
        result[..., row:end_bound] = partial_results
    result = result.unsqueeze(-2).transpose(0, -2).reshape(*final_size)
    return result
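

# Hedged sanity check (illustrative, hypothetical helper name): when run
# under a single Horovod process (hvd.init() already called by the script),
# `distributed_matmul_nt` should agree with a plain local A @ B^T. The sizes
# below are arbitrary small values chosen so that `offset` divides T/N.
def _check_matmul_nt_single_process():
    if get_world_size() != 1:
        return
    a = torch.rand(1, 64, 32, device=device)  # (*, T/N, D)
    b = torch.rand(1, 64, 32, device=device)  # (*, T/N, D)
    expected = torch.matmul(a, b.transpose(-1, -2))
    actual = distributed_matmul_nt(a, b, offset=16)
    assert torch.allclose(expected, actual, atol=1e-5)

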
def distributed_matmul_all(left: Tensor, right: Tensor, offset=32) -> Tensor:
    """
    Multiply two sequence tensors to obtain the result of :math:`AB`.

    Left and right inputs can be N-dimensional tensors, where the first one
    must be of size :math:`* \times \frac{T}{N} \times T` and the second one of
    size , where :math:`* \times \frac{T}{N} \times D`, where :math:`T` is the
    total length,  :math:`N` is the total number of processes available and
    :math:`D`, the dimension of the sequence. The result of this function is a
    tensor of size :math:`* \times \frac{T}{N} \times D`, that contain the
    result chunk for each process of the resulting operation.

    Inputs
    ------
    left: Tensor
        :math:`A` in :math:`AB`, must be of size
        :math:`* \times \frac{T}{N} \times T`
    right: Tensor
        :math:`B` in :math:`AB`, must be of size
        :math:`* \times \frac{T}{N} \times D`

    Returns
    -------
    result: Tensor
        For each process, this function computes the corresponding segment
        of the operation :math:`AB`, of size
        :math:`1 \times \frac{T}{N} \times D`
    """
    cols = left.size(-1)
    world_size = get_world_size()

    total_cols = right.size(-1)
    split_size = cols // world_size
    splits = torch.stack(left.split(split_size, -1), dim=0)
    left_sizes = tuple(left.size())
    size = (world_size,) + left_sizes[:-2] + (left.size(-2), total_cols)
    rank_block = torch.empty(*size, device=left.device)

    synchronize()
    for current_col in range(0, total_cols, offset):
        end_bound = current_col + offset
        col = right[..., current_col:end_bound]
        col = col.contiguous()
        all_cols = hvd.allgather(col.unsqueeze(0),
                                 name=f'matmul_all_{current_col}')
        # all_cols: torch.size([world_size, right.size(1), offset])
        block_result = torch.matmul(splits, all_cols)
        rank_block[..., current_col:end_bound] = block_result
    result = rank_block.sum(dim=0)
    return result
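

# Hedged sanity check (illustrative, hypothetical helper name): with a single
# Horovod process, T/N == T, so `distributed_matmul_all` should reduce to a
# plain local matmul A @ B.
def _check_matmul_all_single_process():
    if get_world_size() != 1:
        return
    a = torch.rand(1, 64, 64, device=device)  # (*, T/N, T) with T == T/N
    b = torch.rand(1, 64, 32, device=device)  # (*, T/N, D)
    expected = torch.matmul(a, b)
    actual = distributed_matmul_all(a, b, offset=16)
    assert torch.allclose(expected, actual, atol=1e-4)

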
def distributed_matmul_tn(left: Tensor, right: Tensor) -> Tensor:
    """
    Multiply two sequence tensors to obtain the result of :math:`A^{T} B`.

    Left and right inputs can be N-dimensional tensors, where the first one
    must be of size :math:`* \times \frac{T}{N} \times T` and the second one of
    size , where :math:`* \times \frac{T}{N} \times D`, where :math:`T` is the
    total length,  :math:`N` is the total number of processes available and
    :math:`D`, the dimension of the sequence. The result of this function is a
    tensor of size :math:`* \times \frac{T}{N} \times D`, that contain the
    result chunk for each process of the resulting operation.

    Inputs
    ------
    left: Tensor
        :math:`A` in :math:`A^T B`, must be of size
        :math:`* \times \frac{T}{N} \times T`
    right: Tensor
        :math:`B` in :math:`A^T B`, must be of size
        :math:`* \times \frac{T}{N} \times D`

    Returns
    -------
    result: Tensor
        For each process, this function computes the corresponding segment
        of the operation :math:`A^T B`, of size
        :math:`* \times \frac{T}{N} \times D`
    """
    cols = left.size(-1)
    world_size = get_world_size()
    rank = get_rank()

    split_size = cols // world_size
    splits = left.split(split_size, -1)
    rank_block = None

    synchronize()
    for r in range(world_size):
        rank_split = splits[r]
        rank_multiplication = torch.matmul(rank_split.transpose(-1, -2), right)
        handle = hvd.allreduce_async(rank_multiplication,
                                     name=f'matmul_tn_{r}',
                                     op=hvd.Sum)
        if r == rank:
            rank_block = hvd.synchronize(handle)
    return rank_block.contiguous()
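

# Hedged sanity check (illustrative, hypothetical helper name): with a single
# Horovod process, `distributed_matmul_tn` should reduce to a plain local
# A^T @ B.
def _check_matmul_tn_single_process():
    if get_world_size() != 1:
        return
    a = torch.rand(1, 64, 64, device=device)  # (*, T/N, T) with T == T/N
    b = torch.rand(1, 64, 32, device=device)  # (*, T/N, D)
    expected = torch.matmul(a.transpose(-1, -2), b)
    actual = distributed_matmul_tn(a, b)
    assert torch.allclose(expected, actual, atol=1e-4)

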
def main():
    test_funcs = {'nt': nt_benchmark, 'all': all_benchmark, 'tn': tn_benchmark}
    output = test_funcs[args.mode]()
    if is_main_process():
        (input_memory, output_memory, op_time, peak_memory, avg_input_size,
         avg_op_time, avg_peak_memory, avg_output_memory) = output

        output = {
            'input_memory': input_memory,
            'total_time': op_time,
            'peak_memory': peak_memory,
            'output_memory': output_memory,
            'distributed_input_memory': avg_input_size,
            'distributed_time': avg_op_time,
            'distributed_peak_memory': avg_peak_memory,
            'distributed_output_memory': avg_output_memory
        }

        values.append(output)
        json.dump(values, open(args.file, 'w'))
    synchronize()
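

# Hedged usage note (assumptions: the script's file name, argparse flags that
# mirror the `args` attributes used above, and a 3-process Horovod launch).
# Something along these lines would run the distributed benchmark:
#
#   horovodrun -np 3 python benchmark.py --mode all --scale 3 --offset 32 \
#       --file results.json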