def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):
    """Check the compatibility between tensor input and tensor list input.

    Args:
        tensors: a list holding exactly one tensor.
        tensor_lists: a list holding exactly one list of tensors.

    Raises:
        RuntimeError: if either argument is not a non-empty list of length
            one, or if a tensor in ``tensor_lists[0]`` has a dtype or shape
            (as reported by ``gloo_util``) different from ``tensors[0]``.
    """
    # Test the type before the truthiness: `not x` on a multi-element numpy
    # array raises ValueError, which would pre-empt the RuntimeError below.
    if not isinstance(tensors, list) or not tensors:
        raise RuntimeError("The first argument 'tensors' expects a list of tensors.")

    if len(tensors) != 1:
        raise RuntimeError(
            "Gloo only accept one tensor in the first argument 'tensors'."
            " Got {} != 1.".format(len(tensors))
        )

    # Same ordering as above, for the same reason.
    if not isinstance(tensor_lists, list) or not tensor_lists:
        raise RuntimeError(
            "The second argument 'tensor_lists' expects a list of tensor list."
        )

    if len(tensor_lists) != 1:
        raise RuntimeError(
            "Gloo only accept one tensor list "
            "in the second argument 'tensor_lists'."
            " Got {} != 1.".format(len(tensor_lists))
        )

    # The single tensor in `tensors` is the reference for dtype and shape.
    dtype = gloo_util.get_gloo_tensor_dtype(tensors[0])
    shape = gloo_util.get_tensor_shape(tensors[0])

    # Every tensor in the single inner list must match the reference.
    for t in tensor_lists[0]:
        dt = gloo_util.get_gloo_tensor_dtype(t)
        if dt != dtype:
            raise RuntimeError(
                "All tensor operands to scatter/gather must "
                "have the same dtype. Got '{}' and '{}'.".format(dt, dtype)
            )
        s = gloo_util.get_tensor_shape(t)
        if s != shape:
            raise RuntimeError(
                "All tensor operands to scatter/gather must "
                "have the same shape. Got '{}' and '{}'.".format(s, shape)
            )
# Example #2
# 0
def _flatten_for_scatter_gather(tensor_list, copy=False):
    """Allocate one numpy buffer able to hold every tensor in a list.

    The buffer gets one leading slot per entry of ``tensor_list``; the
    remaining dimensions and the dtype are taken from the first tensor.

    Args:
        tensor_list: the list of tensors to be scattered/gathered.
        copy: whether to copy the tensors in tensor_list into the buffer.
            When False the buffer is returned as allocated by
            ``numpy.empty`` (contents uninitialized).

    Returns:
        The flattened tensor buffer.

    Raises:
        RuntimeError: if ``tensor_list`` is empty.
    """
    if not tensor_list:
        raise RuntimeError("Received an empty list.")

    first = tensor_list[0]
    # numpy.empty is given this dtype, hence the numpy-flavoured lookup.
    np_dtype = gloo_util.get_numpy_tensor_dtype(first)
    stacked_shape = [len(tensor_list)] + gloo_util.get_tensor_shape(first)
    flat = numpy.empty(stacked_shape, dtype=np_dtype)

    if not copy:
        return flat

    # Fill slot i of the buffer from the i-th input tensor.
    for slot, source in enumerate(tensor_list):
        gloo_util.copy_tensor(flat[slot], source)
    return flat