    def test_repeats(self):
        """Test that the setters on repeats work as expected."""
        # Test that attribute changes are visible through every reference to the same instance
        rbh_default_1_1 = RemoteBufferHandle()
        rbh_default_1_2 = RemoteBufferHandle(remote_buffer_id=1)
        assert rbh_default_1_1 == rbh_default_1_2
        rbh_default_1_1.repeats = 1
        assert rbh_default_1_1.repeats == rbh_default_1_2.repeats
        # Check that changes to repeats propagate to both references
        rbh_default_1_1.repeats = 2
        assert rbh_default_1_1.repeats == rbh_default_1_2.repeats
        rbh_default_1_2.repeats = 3
        assert rbh_default_1_1.repeats == rbh_default_1_2.repeats

        # Check that it can be set from the constructor
        rbh_default_1_3 = RemoteBufferHandle(remote_buffer_id=1, repeats=4)
        assert rbh_default_1_1.repeats == rbh_default_1_2.repeats
        assert rbh_default_1_3.repeats == rbh_default_1_2.repeats

        # Check that non-positive values are not allowed
        with pytest.raises(ValueError) as e_info:
            _ = RemoteBufferHandle(remote_buffer_id=2, repeats=0)
        assert e_info.value.args[0].startswith(
            "Repeats must be a non-zero, positive integer")
        with pytest.raises(ValueError) as e_info:
            rbh_default_1_2.repeats = 0
        assert e_info.value.args[0].startswith(
            "Repeats must be a non-zero, positive integer")

        # Clean-up so that the RemoteBufferHandle gets reset
        RemoteBufferHandle._buffers = {}
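
The validation exercised above implies a simple property setter on RemoteBufferHandle. The following is a hypothetical sketch, not the library's actual implementation; only the error-message prefix is taken from the asserts in the test:

class RemoteBufferHandle:  # hypothetical sketch of the repeats property only
    def __init__(self, repeats: int = 1) -> None:
        self.repeats = repeats  # route through the setter for validation

    @property
    def repeats(self) -> int:
        return self._repeats

    @repeats.setter
    def repeats(self, value: int) -> None:
        # The test above requires exactly this message prefix
        if value < 1:
            raise ValueError(
                f"Repeats must be a non-zero, positive integer. Got {value}.")
        self._repeats = value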
Example 2
def prepare_remote_buffer(t: Tensor,
                          remote_buffer_handle: Optional[RemoteBufferHandle],
                          g: Graph) -> RemoteBufferHandle:
    """Prepare the remote buffer.

    Args:
        t (Tensor): Input tensor to the op.
        remote_buffer_handle (Optional[RemoteBufferHandle]): If set, the
          remote buffer handle to use in the preparation.
        g (Graph): The graph to set the remote buffer info on.

    Raises:
        ValueError: If there is a shape or type mismatch between `t` and
          `remote_buffer_handle`

    Returns:
        RemoteBufferHandle: The remote buffer handle used in the preparation.
    """
    if remote_buffer_handle is None:
        shape = t._pb_tensor.info.shape()
        d_type = dtype.as_dtype(t._pb_tensor.info.data_type_lcase())
        # Check for existing buffer handles
        existing_buffers = RemoteBufferHandle._buffers
        for rbh in existing_buffers.values():
            if rbh.tensor_shape == shape and rbh.tensor_dtype == d_type:
                remote_buffer_handle = rbh
                break

        if remote_buffer_handle is None:
            # Create handle if not found
            remote_buffer_handle = RemoteBufferHandle(remote_buffer_id=None,
                                                      tensor_shape=shape,
                                                      tensor_dtype=d_type,
                                                      repeats=1)

    # Even when a handle was passed in, its shape and dtype may still be unset
    if remote_buffer_handle.tensor_shape is None:
        remote_buffer_handle.tensor_shape = t._pb_tensor.info.shape()
    if remote_buffer_handle.tensor_dtype is None:
        remote_buffer_handle.tensor_dtype = dtype.as_dtype(
            t._pb_tensor.info.data_type_lcase())

    info = _ir.TensorInfo(remote_buffer_handle.tensor_dtype._pb_dtype,
                          remote_buffer_handle.tensor_shape)
    if t._pb_tensor.info.dataType() != info.dataType():
        raise ValueError(
            f"DataType of {t.id} ({t._pb_tensor.info.dataType()}) "
            f"does not match that of the RemoteBufferHandle ({info.dataType()})"
        )
    if t._pb_tensor.info.shape() != info.shape():
        raise ValueError(
            f"DataType of {t.id} ({t._pb_tensor.info.shape()}) "
            f"does not match that of the RemoteBufferHandle ({info.shape()})")

    g._ir._pb_ir.setRemoteBufferInfo(
        remote_buffer_handle.remote_buffer_id,
        _ir.RemoteBufferInfo(info, remote_buffer_handle.repeats))

    return remote_buffer_handle
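
A minimal usage sketch, assuming the popart.ir imports (pir, np) used by the other examples on this page; the shape and dtype are arbitrary illustrations:

import numpy as np
import popart.ir as pir

ir = pir.Ir()
g = ir.main_graph()
with g:
    t = pir.variable(np.ones((2, 3), dtype=np.float32))
    # With remote_buffer_handle=None, an existing handle with matching shape
    # and dtype is reused, otherwise a new one is created
    rbh = prepare_remote_buffer(t, None, g)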
Example 3
    def test___new___(self):
        """Test that the remote buffer handles are correctly created."""
        rbh_default_1 = RemoteBufferHandle()
        rbh_default_2 = RemoteBufferHandle()

        # The second call to RemoteBufferHandle() should increase the remote_buffer_id
        assert rbh_default_1.remote_buffer_id == 1
        assert rbh_default_2.remote_buffer_id == 2
        assert rbh_default_1 != rbh_default_2

        # Test that no new instance is created
        rbh_3_1 = RemoteBufferHandle(remote_buffer_id=3)
        rbh_3_2 = RemoteBufferHandle(remote_buffer_id=3)
        assert rbh_3_1 == rbh_3_2

        # Clean-up so that the RemoteBufferHandle gets reset
        RemoteBufferHandle._buffers = {}
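
The behaviour verified here (equal ids yield the same object, omitted ids auto-increment from 1) points to a registry-backed __new__. A hypothetical sketch, not the actual implementation:

class RemoteBufferHandle:  # hypothetical sketch; __init__ omitted
    _buffers: dict = {}

    def __new__(cls, remote_buffer_id=None, **kwargs):
        if remote_buffer_id is None:
            # Auto-assign the next free id, starting at 1
            remote_buffer_id = max(cls._buffers, default=0) + 1
        if remote_buffer_id == -1:
            raise NotImplementedError(
                "remote_buffer_id = -1 (automatic RemoteSetup) not supported")
        # Cache one instance per id so that equal ids compare equal
        if remote_buffer_id not in cls._buffers:
            inst = super().__new__(cls)
            inst.remote_buffer_id = remote_buffer_id
            cls._buffers[remote_buffer_id] = inst
        return cls._buffers[remote_buffer_id]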
Example 4
    def test_raises(self):
        """Test that remote_buffer_id=-1 raises NotImplementedError."""
        with pytest.raises(NotImplementedError):
            _ = RemoteBufferHandle(remote_buffer_id=-1,
                                   tensor_shape=None,
                                   tensor_dtype=None,
                                   repeats=1)
        # Clean-up so that the RemoteBufferHandle gets reset
        RemoteBufferHandle._buffers = {}
Example 5
    def test___init__(self):
        """Test that the remote buffer handles are correctly initialized."""
        # Test that illegal values are captured
        with pytest.raises(NotImplementedError) as e_info:
            _ = RemoteBufferHandle(remote_buffer_id=-1)
        # The assert must sit outside the with-block, otherwise it never runs
        assert e_info.value.args[0] == (
            "remote_buffer_id = -1 (automatic RemoteSetup) "
            "not supported")

        # Clean-up so that the RemoteBufferHandle gets reset
        RemoteBufferHandle._buffers = {}
Example 6
    def test_remote_load_graph(self, use_offset: bool,
                               tensor_shape: Tuple[int, ...], repeats: int,
                               tensor_dtype: dtype, inplace: bool) -> None:
        """Test that the graph is correct when using the remote load op

        Args:
            use_offset (bool): Whether or not to use offset
            tensor_shape (Tuple[int, ...]): The shape of the tensor to be loaded
            repeats (int): The number of tensors potentially stored in the buffer
            tensor_dtype (dtype): The type of the tensors to be loaded
            inplace (bool): Whether or not to use the inplace version of the op
        """
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            t = pir.variable(
                np.random.rand(*tensor_shape).astype(tensor_dtype.as_numpy()))
            if use_offset:
                offset = pir.constant([1], name='offset')
                # With this option the graph should contain
                # 1. t
                # 2. offset
                # 3. out
                n_tensors = 3
            else:
                offset = None
                # With this option the graph should contain
                # 1. t
                # 2. out
                n_tensors = 2

            rbh = RemoteBufferHandle(remote_buffer_id=1,
                                     tensor_shape=tensor_shape,
                                     tensor_dtype=tensor_dtype,
                                     repeats=repeats)

            op = ops.remote_load if not inplace else ops.remote_load_
            op(t, offset, rbh)

        assert len(g.get_tensors()) == n_tensors
        # Only t is a variable
        assert len(g.get_variables()) == 1
        type_string = "RemoteLoad" if not inplace else "RemoteLoadInplace"
        pb_type = _ir.op.exchange.RemoteLoadOp if not inplace else _ir.op.exchange.RemoteLoadInplaceOp
        assert contains_op_of_type(type_string, pb_type, g)

        # Clean-up so that the RemoteBufferHandle gets reset
        RemoteBufferHandle._buffers = {}
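
The five parameters of this test are supplied by decorators that were not captured with the snippet. A plausible pytest.mark.parametrize setup (the concrete values and the class name are illustrative assumptions; the same pattern would drive test_remote_store_graph in Example 8 below):

import pytest
# float16/float32 assumed importable like int8/int16 elsewhere on this page
from popart.ir.dtypes import float16, float32

class TestRemoteLoad:  # enclosing class name assumed
    @pytest.mark.parametrize("use_offset", [True, False])
    @pytest.mark.parametrize("tensor_shape", [(7, 11), (3, 5, 7)])
    @pytest.mark.parametrize("repeats", [1, 3])
    @pytest.mark.parametrize("tensor_dtype", [float16, float32])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_remote_load_graph(self, use_offset, tensor_shape, repeats,
                               tensor_dtype, inplace) -> None:
        ...  # body as in the example above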
Example 7
def build_model(
        data: Dict[str, np.ndarray]
) -> Tuple[_ir.Ir, Dict[str, DeviceToHostStream]]:
    """Build a model for storing and loading tensors from the remote buffer.

    Args:
        data (Dict[str, np.ndarray]): Dict of the data to be stored in and loaded from the remote buffer

    Returns:
        (tuple): Tuple containing:

        ir._pb_ir (_ir.Ir): The underlying IR
        d2h_streams (Dict[str, DeviceToHostStream]): The output streams
    """
    ir = pir.Ir()
    main = ir.main_graph()

    with main:
        # Placeholder for tensor ids
        tensors = {}
        # Create variable tensors from the data
        for name in data.keys():
            tensors[name] = pir.variable(data[name], name=name)

        # Placeholder for device to host streams
        d2h_streams = {}

        # Store and load the first tensor without specifying the remote buffer handle or offset
        ops.remote_store(t=tensors["store_in_1"])
        tensors["load_out_1"] = ops.remote_load(t=tensors["load_in_1"])
        tensors["load_out_1_inplace"] = ops.remote_load_(
            t=tensors["load_in_1_inplace"])
        # Anchor the input tensors to the load operator
        d2h_streams = make_anchor(d2h_streams, tensors, "load_in_1")
        d2h_streams = make_anchor(d2h_streams, tensors, "load_in_1_inplace")
        # Anchor the output tensors of the load operator
        d2h_streams = make_anchor(d2h_streams, tensors, "load_out_1")
        d2h_streams = make_anchor(d2h_streams, tensors, "load_out_1_inplace")

        # Store and load the second and third tensors using the same buffer id.
        # Buffer 1 is already assigned implicitly above, so we choose a different id
        rbh = RemoteBufferHandle(
            remote_buffer_id=42,
            tensor_shape=tensors["store_in_2"]._pb_tensor.info.shape(),
            tensor_dtype=dtype.as_dtype(
                tensors["store_in_2"]._pb_tensor.info.data_type_lcase()),
            repeats=2)
        # Index starts at 0
        offset_tensor_2 = pir.constant(0, name="offset_2")
        offset_tensor_3 = pir.constant(1, name="offset_3")
        ops.remote_store(t=tensors["store_in_2"],
                         offset=offset_tensor_2,
                         remote_buffer_handle=rbh)
        ops.remote_store(t=tensors["store_in_3"],
                         offset=offset_tensor_3,
                         remote_buffer_handle=rbh)
        tensors["load_out_2"] = ops.remote_load(t=tensors["load_in_2"],
                                                offset=offset_tensor_2,
                                                remote_buffer_handle=rbh)
        tensors["load_out_3_inplace"] = ops.remote_load_(
            t=tensors["load_in_3_inplace"],
            offset=offset_tensor_3,
            remote_buffer_handle=rbh)

        # Anchor the input tensors to the load operator
        d2h_streams = make_anchor(d2h_streams, tensors, "load_in_2")
        d2h_streams = make_anchor(d2h_streams, tensors, "load_in_3_inplace")
        # Anchor the output tensors of the load operator
        d2h_streams = make_anchor(d2h_streams, tensors, "load_out_2")
        d2h_streams = make_anchor(d2h_streams, tensors, "load_out_3_inplace")

    return ir._pb_ir, d2h_streams
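
build_model calls a make_anchor helper that was not captured with the snippet. One plausible implementation, reconstructed from the call sites above and assuming the pir.d2h_stream / ops.host_store API:

def make_anchor(d2h_streams: Dict[str, DeviceToHostStream],
                tensors: Dict[str, pir.Tensor],
                name: str) -> Dict[str, DeviceToHostStream]:
    """Stream the named tensor to the host and record the stream under name."""
    t = tensors[name]
    stream = pir.d2h_stream(t.shape, t.dtype, name=f"{name}_stream")
    ops.host_store(stream, t)
    d2h_streams[name] = stream
    return d2h_streams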
Example 8
    def test_remote_store_graph(self, use_offset: bool,
                                use_remote_buffer_id: bool, use_rbh: bool,
                                tensor_shape: Tuple[int, ...], repeats: int,
                                tensor_dtype: dtype) -> None:
        """Test that the graph is correct when using the remote store op.

        Args:
            use_offset (bool): Whether or not to use offset
            use_remote_buffer_id (bool): Whether or not to set the remote buffer_id
            use_rbh (bool): Whether or not to specify the remote buffer handle
            tensor_shape (Tuple[int, ...]): The shape of the tensor to be stored
            repeats (int): The number of tensors to potentially store in the buffer
            tensor_dtype (dtype): The type of the tensors to be stored
        """
        ir = pir.Ir()
        g = ir.main_graph()

        with g:
            t = pir.variable(
                np.random.rand(*tensor_shape).astype(tensor_dtype.as_numpy()))
            if use_offset:
                offset = pir.constant([1], name='offset')
                # With this option the graph should contain
                # 1. t
                # 2. offset
                n_tensors = 2
            else:
                offset = None
                # With this option the graph should contain
                # 1. t
                n_tensors = 1

            remote_buffer_id = 1 if use_remote_buffer_id else -1

            if remote_buffer_id == -1:
                with pytest.raises(NotImplementedError):
                    _ = RemoteBufferHandle(remote_buffer_id=remote_buffer_id,
                                           tensor_shape=tensor_shape,
                                           tensor_dtype=tensor_dtype,
                                           repeats=repeats)
                # Clean-up so that the RemoteBufferHandle gets reset
                RemoteBufferHandle._buffers = {}
                return

            if use_rbh:
                rbh = RemoteBufferHandle(remote_buffer_id=remote_buffer_id,
                                         tensor_shape=tensor_shape,
                                         tensor_dtype=tensor_dtype,
                                         repeats=repeats)
            else:
                rbh = None

            ops.remote_store(t, offset, rbh)

        assert len(g.get_tensors()) == n_tensors
        # Only t is a variable
        assert len(g.get_variables()) == 1
        assert contains_op_of_type("RemoteStore",
                                   _ir.op.exchange.RemoteStoreOp, g)

        # Clean-up so that the RemoteBufferHandle gets reset
        RemoteBufferHandle._buffers = {}
Example 9
    def test_tensor_dtype(self):
        """Test that the setters on tensor_dtype work as expected."""
        # Test that attribute changes are visible through every reference to the same instance
        rbh_default_1_1 = RemoteBufferHandle()
        rbh_default_1_2 = RemoteBufferHandle(remote_buffer_id=1)
        assert rbh_default_1_1 == rbh_default_1_2
        rbh_default_1_1.tensor_dtype = int8
        assert rbh_default_1_1.tensor_dtype == rbh_default_1_2.tensor_dtype
        # Check that dtype cannot be written to twice
        rbh_default_1_1.tensor_dtype = int8  # This is ok since it's the same value
        rbh_default_1_2.tensor_dtype = int8  # This is ok since it's the same value
        with pytest.raises(ValueError) as e_info:
            rbh_default_1_1.tensor_dtype = int16  # This is not ok since it's a new value
        assert e_info.value.args[0].startswith("Cannot reset buffer dtype")
        with pytest.raises(ValueError) as e_info:
            rbh_default_1_2.tensor_dtype = int16  # This is not ok since it's a new value
        assert e_info.value.args[0].startswith("Cannot reset buffer dtype")

        # Test the same when dtype is set in the constructor
        rbh_default_2_1 = RemoteBufferHandle(remote_buffer_id=2,
                                             tensor_dtype=int8)
        rbh_default_2_2 = RemoteBufferHandle(remote_buffer_id=2,
                                             tensor_dtype=int8)
        assert rbh_default_2_1 == rbh_default_2_2
        assert rbh_default_2_1.tensor_dtype == rbh_default_2_2.tensor_dtype
        # Check that dtype cannot be written to twice
        rbh_default_2_1.tensor_dtype = int8  # This is ok since it's the same value
        rbh_default_2_2.tensor_dtype = int8  # This is ok since it's the same value
        with pytest.raises(ValueError) as e_info:
            rbh_default_2_1.tensor_dtype = int16  # This is not ok since it's a new value
        assert e_info.value.args[0].startswith("Cannot reset buffer dtype")
        with pytest.raises(ValueError) as e_info:
            rbh_default_2_2.tensor_dtype = int16  # This is not ok since it's a new value
        assert e_info.value.args[0].startswith("Cannot reset buffer dtype")
        # It should not be possible to edit the dtype from the constructor
        with pytest.raises(ValueError) as e_info:
            _ = RemoteBufferHandle(remote_buffer_id=2, tensor_dtype=int16)
        assert e_info.value.args[0].startswith("Cannot reset buffer dtype")

        # Clean-up so that the RemoteBufferHandle gets reset
        RemoteBufferHandle._buffers = {}
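
Analogous to the repeats setter sketched after test_repeats above, the write-once behaviour tested here suggests a tensor_dtype setter along these lines (hypothetical; only the error-message prefix comes from the asserts):

class RemoteBufferHandle:  # hypothetical sketch of the tensor_dtype property
    # (assumes __init__ initialises self._tensor_dtype to None)
    @property
    def tensor_dtype(self):
        return self._tensor_dtype

    @tensor_dtype.setter
    def tensor_dtype(self, value) -> None:
        # Re-assigning the same dtype is a no-op; changing it is an error
        if self._tensor_dtype is not None and value != self._tensor_dtype:
            raise ValueError(
                f"Cannot reset buffer dtype for buffer {self.remote_buffer_id}")
        self._tensor_dtype = value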