Exemple #1
0
    def test_create_struct_reraises_type_error(self, mock_stub):
        """Expects a `TypeError` from the stubbed `CreateStruct` to surface.

        `mock_stub.return_value.CreateStruct` is replaced by a mock whose side
        effect is `TypeError`; running `executor.create_struct` on two remote
        values must then raise `TypeError`.
        """
        stub = mock_stub.return_value
        stub.CreateStruct = mock.Mock(side_effect=TypeError)
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        int_type = computation_types.TensorType(tf.int32)
        # Two independent remote values, each wrapping its own fresh ValueRef.
        elements = [
            remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                        executor) for _ in range(2)
        ]

        with self.assertRaises(TypeError):
            pending = executor.create_struct(elements)
            event_loop.run_until_complete(pending)
    def test_create_struct_reraises_grpc_error(self, mock_stub):
        """Expects the stub's `grpc.RpcError` to propagate with code ABORTED.

        `mock_stub.create_struct` is mocked with
        `_raise_non_retryable_grpc_error` as its side effect, and the error
        raised from `executor.create_struct` is inspected for its status code.
        """
        failing_rpc = mock.Mock(side_effect=_raise_non_retryable_grpc_error)
        mock_stub.create_struct = failing_rpc
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        int_type = computation_types.TensorType(tf.int32)
        # Two independent remote values, each wrapping its own fresh ValueRef.
        elements = [
            remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                        executor) for _ in range(2)
        ]

        with self.assertRaises(grpc.RpcError) as raised:
            pending = executor.create_struct(elements)
            asyncio.run(pending)

        status = raised.exception.code()
        self.assertEqual(status, grpc.StatusCode.ABORTED)
    def test_create_struct_returns_remote_value(self, mock_stub):
        """Expects `create_struct` to call the stub once and give a RemoteValue.

        `mock_stub.create_struct` is set to return an empty
        `CreateStructResponse`.
        """
        empty_response = executor_pb2.CreateStructResponse()
        mock_stub.create_struct.return_value = empty_response
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        int_type = computation_types.TensorType(tf.int32)
        # Two independent remote values, each wrapping its own fresh ValueRef.
        elements = [
            remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                        executor) for _ in range(2)
        ]

        pending = executor.create_struct(elements)
        result = asyncio.run(pending)

        mock_stub.create_struct.assert_called_once()
        self.assertIsInstance(result, remote_executor.RemoteValue)
    def test_create_tuple_raises_retryable_error_on_grpc_error_unavailable(
            self, mock_stub):
        """Expects `RetryableError` when the stubbed `CreateTuple` fails.

        `mock_stub.return_value.CreateTuple` is mocked with
        `_raise_grpc_error_unavailable` as its side effect; running
        `executor.create_tuple` must raise `execution_context.RetryableError`.
        """
        stub = mock_stub.return_value
        unavailable_rpc = mock.Mock(side_effect=_raise_grpc_error_unavailable)
        stub.CreateTuple = unavailable_rpc
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        int_type = computation_types.TensorType(tf.int32)
        # Two independent remote values, each wrapping its own fresh ValueRef.
        elements = [
            remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                        executor) for _ in range(2)
        ]

        with self.assertRaises(execution_context.RetryableError):
            pending = executor.create_tuple(elements)
            event_loop.run_until_complete(pending)
Exemple #5
0
  def test_create_struct_returns_remote_value(self, mock_stub):
    """Expects `create_struct` to hit `CreateStruct` once and give a RemoteValue.

    `mock_stub.return_value.CreateStruct` is mocked to return a single empty
    `CreateStructResponse`.
    """
    canned_reply = executor_pb2.CreateStructResponse()
    stub = mock_stub.return_value
    stub.CreateStruct = mock.Mock(side_effect=[canned_reply])
    event_loop = asyncio.get_event_loop()
    executor = create_remote_executor()
    int_type = computation_types.TensorType(tf.int32)
    # Two independent remote values, each wrapping its own fresh ValueRef.
    elements = [
        remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                    executor) for _ in range(2)
    ]

    pending = executor.create_struct(elements)
    result = event_loop.run_until_complete(pending)

    stub.CreateStruct.assert_called_once()
    self.assertIsInstance(result, remote_executor.RemoteValue)
Exemple #6
0
    def test_create_struct_reraises_grpc_error(self, mock_stub):
        """Expects the stub's `grpc.RpcError` to propagate with code ABORTED.

        `mock_stub.return_value.CreateStruct` is mocked with
        `_raise_non_retryable_grpc_error` as its side effect, and the error
        raised from `executor.create_struct` is inspected for its status code.
        """
        stub = mock_stub.return_value
        failing_rpc = mock.Mock(side_effect=_raise_non_retryable_grpc_error)
        stub.CreateStruct = failing_rpc
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        int_type = computation_types.TensorType(tf.int32)
        # Two independent remote values, each wrapping its own fresh ValueRef.
        elements = [
            remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                        executor) for _ in range(2)
        ]

        with self.assertRaises(grpc.RpcError) as raised:
            pending = executor.create_struct(elements)
            event_loop.run_until_complete(pending)

        status = raised.exception.code()
        self.assertEqual(status, grpc.StatusCode.ABORTED)
    def test_create_struct_returns_remote_value(self, mock_stub):
        """Expects `create_struct` to produce a `RemoteValue`.

        The executor is built by `_setup_mock_streaming_executor` from
        `mock_stub` and an `ExecuteResponse` whose `create_struct` field is an
        empty `CreateStructResponse`.
        """
        struct_reply = executor_pb2.CreateStructResponse()
        reply = executor_pb2.ExecuteResponse(create_struct=struct_reply)
        executor = _setup_mock_streaming_executor(mock_stub, reply)
        event_loop = asyncio.get_event_loop()

        int_type = computation_types.TensorType(tf.int32)
        # Two independent remote values, each wrapping its own fresh ValueRef.
        elements = [
            remote_executor.RemoteValue(executor_pb2.ValueRef(), int_type,
                                        executor) for _ in range(2)
        ]

        pending = executor.create_struct(elements)
        result = event_loop.run_until_complete(pending)

        self.assertIsInstance(result, remote_executor.RemoteValue)
    def test_create_selection_reraises_type_error(self, mock_stub):
        """Expects a `TypeError` from the stubbed `create_selection` to surface.

        `mock_stub.create_selection` is replaced by a mock whose side effect
        is `TypeError`; selecting index 0 from a two-element struct value must
        then raise `TypeError`.
        """
        failing_rpc = mock.Mock(side_effect=TypeError)
        mock_stub.create_selection = failing_rpc
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        pair_type = computation_types.StructType([tf.int32, tf.int32])
        value_ref = executor_pb2.ValueRef()
        source = remote_executor.RemoteValue(value_ref, pair_type, executor)

        with self.assertRaises(TypeError):
            pending = executor.create_selection(source, 0)
            asyncio.run(pending)
    def test_create_call_reraises_type_error(self, mock_stub):
        """Expects a `TypeError` from the stubbed `create_call` to surface.

        `mock_stub.create_call` is replaced by a mock whose side effect is
        `TypeError`; calling a no-argument function value must then raise
        `TypeError`.
        """
        failing_rpc = mock.Mock(side_effect=TypeError)
        mock_stub.create_call = failing_rpc
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        fn_type = computation_types.FunctionType(None, tf.int32)
        value_ref = executor_pb2.ValueRef()
        comp = remote_executor.RemoteValue(value_ref, fn_type, executor)

        with self.assertRaises(TypeError):
            pending = executor.create_call(comp)
            asyncio.run(pending)
Exemple #10
0
    def test_create_call_reraises_type_error(self, mock_stub):
        """Expects a `TypeError` from the stubbed `CreateCall` to surface.

        `mock_stub.return_value.CreateCall` is replaced by a mock whose side
        effect is `TypeError`; calling a no-argument function value must then
        raise `TypeError`.
        """
        stub = mock_stub.return_value
        stub.CreateCall = mock.Mock(side_effect=TypeError)
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        fn_type = computation_types.FunctionType(None, tf.int32)
        value_ref = executor_pb2.ValueRef()
        comp = remote_executor.RemoteValue(value_ref, fn_type, executor)

        with self.assertRaises(TypeError):
            pending = executor.create_call(comp)
            event_loop.run_until_complete(pending)
Exemple #11
0
    def test_create_selection_reraises_type_error(self, mock_stub):
        """Expects a `TypeError` from the stubbed `CreateSelection` to surface.

        `mock_stub.return_value.CreateSelection` is replaced by a mock whose
        side effect is `TypeError`; selecting `index=0` from a two-element
        struct value must then raise `TypeError`.
        """
        stub = mock_stub.return_value
        stub.CreateSelection = mock.Mock(side_effect=TypeError)
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        pair_type = computation_types.StructType([tf.int32, tf.int32])
        value_ref = executor_pb2.ValueRef()
        source = remote_executor.RemoteValue(value_ref, pair_type, executor)

        with self.assertRaises(TypeError):
            pending = executor.create_selection(source, index=0)
            event_loop.run_until_complete(pending)
    def test_create_call_returns_remote_value(self, mock_stub):
        """Expects `create_call` to produce a `RemoteValue`.

        The executor is built by `_setup_mock_streaming_executor` from
        `mock_stub` and an `ExecuteResponse` whose `create_call` field is an
        empty `CreateCallResponse`.
        """
        call_reply = executor_pb2.CreateCallResponse()
        reply = executor_pb2.ExecuteResponse(create_call=call_reply)
        executor = _setup_mock_streaming_executor(mock_stub, reply)
        event_loop = asyncio.get_event_loop()
        fn_type = computation_types.FunctionType(None, tf.int32)
        value_ref = executor_pb2.ValueRef()
        fn = remote_executor.RemoteValue(value_ref, fn_type, executor)

        pending = executor.create_call(fn, None)
        result = event_loop.run_until_complete(pending)

        self.assertIsInstance(result, remote_executor.RemoteValue)
    def test_create_call_returns_remote_value(self, mock_stub):
        """Expects `create_call` to call the stub once and give a RemoteValue.

        `mock_stub.create_call` is set to return an empty
        `CreateCallResponse`.
        """
        empty_response = executor_pb2.CreateCallResponse()
        mock_stub.create_call.return_value = empty_response
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        fn_type = computation_types.FunctionType(None, tf.int32)
        value_ref = executor_pb2.ValueRef()
        fn = remote_executor.RemoteValue(value_ref, fn_type, executor)

        pending = executor.create_call(fn, None)
        result = asyncio.run(pending)

        mock_stub.create_call.assert_called_once()
        self.assertIsInstance(result, remote_executor.RemoteValue)
Exemple #14
0
    def test_compute_raises_retryable_error_on_grpc_error_unavailable(
            self, mock_stub):
        """Expects `RetryableError` when the stubbed `Compute` fails.

        `mock_stub.return_value.Compute` is mocked with
        `_raise_grpc_error_unavailable` as its side effect; running
        `comp.compute()` must raise `execution_context.RetryableError`.
        """
        stub = mock_stub.return_value
        unavailable_rpc = mock.Mock(side_effect=_raise_grpc_error_unavailable)
        stub.Compute = unavailable_rpc
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        fn_type = computation_types.FunctionType(None, tf.int32)
        value_ref = executor_pb2.ValueRef()
        comp = remote_executor.RemoteValue(value_ref, fn_type, executor)

        with self.assertRaises(execution_context.RetryableError):
            pending = comp.compute()
            event_loop.run_until_complete(pending)
    def test_create_selection_raises_retryable_error_on_grpc_error_unavailable(
            self, mock_stub):
        """Expects `RetryableError` when the stubbed `CreateSelection` fails.

        `mock_stub.return_value.CreateSelection` is mocked with
        `_raise_grpc_error_unavailable` as its side effect; selecting
        `index=0` must raise `execution_context.RetryableError`.
        """
        stub = mock_stub.return_value
        unavailable_rpc = mock.Mock(side_effect=_raise_grpc_error_unavailable)
        stub.CreateSelection = unavailable_rpc
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
        value_ref = executor_pb2.ValueRef()
        source = remote_executor.RemoteValue(value_ref, pair_type, executor)

        with self.assertRaises(execution_context.RetryableError):
            pending = executor.create_selection(source, index=0)
            event_loop.run_until_complete(pending)
    def test_compute_reraises_grpc_error(self, mock_stub):
        """Expects the stub's `grpc.RpcError` to propagate with code ABORTED.

        `mock_stub.compute` is mocked with `_raise_non_retryable_grpc_error`
        as its side effect, and the error raised from `comp.compute()` is
        inspected for its status code.
        """
        failing_rpc = mock.Mock(side_effect=_raise_non_retryable_grpc_error)
        mock_stub.compute = failing_rpc
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        fn_type = computation_types.FunctionType(None, tf.int32)
        value_ref = executor_pb2.ValueRef()
        comp = remote_executor.RemoteValue(value_ref, fn_type, executor)

        with self.assertRaises(grpc.RpcError) as raised:
            pending = comp.compute()
            asyncio.run(pending)

        status = raised.exception.code()
        self.assertEqual(status, grpc.StatusCode.ABORTED)
    def test_create_selection_returns_remote_value(self, mock_stub):
        """Expects `create_selection` to produce a `RemoteValue`.

        The executor is built by `_setup_mock_streaming_executor` from
        `mock_stub` and an `ExecuteResponse` whose `create_selection` field is
        an empty `CreateSelectionResponse`.
        """
        selection_reply = executor_pb2.CreateSelectionResponse()
        reply = executor_pb2.ExecuteResponse(create_selection=selection_reply)
        executor = _setup_mock_streaming_executor(mock_stub, reply)
        event_loop = asyncio.get_event_loop()
        pair_type = computation_types.StructType([tf.int32, tf.int32])
        value_ref = executor_pb2.ValueRef()
        source = remote_executor.RemoteValue(value_ref, pair_type, executor)

        pending = executor.create_selection(source, index=0)
        result = event_loop.run_until_complete(pending)

        self.assertIsInstance(result, remote_executor.RemoteValue)
    def test_create_selection_returns_remote_value(self, mock_stub):
        """Expects `create_selection` to call the stub once and give a RemoteValue.

        `mock_stub.create_selection` is set to return an empty
        `CreateSelectionResponse`.
        """
        empty_response = executor_pb2.CreateSelectionResponse()
        mock_stub.create_selection.return_value = empty_response
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        pair_type = computation_types.StructType([tf.int32, tf.int32])
        value_ref = executor_pb2.ValueRef()
        source = remote_executor.RemoteValue(value_ref, pair_type, executor)

        pending = executor.create_selection(source, 0)
        result = asyncio.run(pending)

        mock_stub.create_selection.assert_called_once()
        self.assertIsInstance(result, remote_executor.RemoteValue)
Exemple #19
0
    def test_create_selection_returns_remote_value(self, mock_stub):
        """Expects `create_selection` to hit `CreateSelection` once.

        `mock_stub.return_value.CreateSelection` is mocked to return a single
        empty `CreateSelectionResponse`; the awaited result must be a
        `RemoteValue`.
        """
        canned_reply = executor_pb2.CreateSelectionResponse()
        stub = mock_stub.return_value
        stub.CreateSelection = mock.Mock(side_effect=[canned_reply])
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        pair_type = computation_types.StructType([tf.int32, tf.int32])
        value_ref = executor_pb2.ValueRef()
        source = remote_executor.RemoteValue(value_ref, pair_type, executor)

        pending = executor.create_selection(source, 0)
        result = event_loop.run_until_complete(pending)

        stub.CreateSelection.assert_called_once()
        self.assertIsInstance(result, remote_executor.RemoteValue)
Exemple #20
0
    def test_create_call_returns_remote_value(self, mock_stub):
        """Expects `create_call` to hit `CreateCall` once and give a RemoteValue.

        `mock_stub.return_value.CreateCall` is mocked to return a single empty
        `CreateCallResponse`.
        """
        canned_reply = executor_pb2.CreateCallResponse()
        stub = mock_stub.return_value
        stub.CreateCall = mock.Mock(side_effect=[canned_reply])
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        fn_type = computation_types.FunctionType(None, tf.int32)
        value_ref = executor_pb2.ValueRef()
        fn = remote_executor.RemoteValue(value_ref, fn_type, executor)

        pending = executor.create_call(fn, None)
        result = event_loop.run_until_complete(pending)

        stub.CreateCall.assert_called_once()
        self.assertIsInstance(result, remote_executor.RemoteValue)
    def test_create_selection_reraises_non_retryable_grpc_error(
            self, mock_stub):
        """Expects the stub's `grpc.RpcError` to propagate with code ABORTED.

        `mock_stub.return_value.CreateSelection` is mocked with
        `_raise_non_retryable_grpc_error` as its side effect, and the error
        raised from `executor.create_selection` is inspected for its status
        code.
        """
        stub = mock_stub.return_value
        failing_rpc = mock.Mock(side_effect=_raise_non_retryable_grpc_error)
        stub.CreateSelection = failing_rpc
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        pair_type = computation_types.NamedTupleType([tf.int32, tf.int32])
        value_ref = executor_pb2.ValueRef()
        source = remote_executor.RemoteValue(value_ref, pair_type, executor)

        with self.assertRaises(grpc.RpcError) as raised:
            pending = executor.create_selection(source, index=0)
            event_loop.run_until_complete(pending)

        status = raised.exception.code()
        self.assertEqual(status, grpc.StatusCode.ABORTED)
    def test_compute_returns_result(self, mock_stub):
        """Expects `compute()` to call the stub once and return 1.

        `mock_stub.compute` is set to return a `ComputeResponse` whose value
        carries the tensor proto for the scalar 1, packed into an `Any`.
        """
        packed_tensor = any_pb2.Any()
        packed_tensor.Pack(tf.make_tensor_proto(1))
        payload = executor_pb2.Value(tensor=packed_tensor)
        canned_reply = executor_pb2.ComputeResponse(value=payload)
        mock_stub.compute.return_value = canned_reply
        executor = remote_executor.RemoteExecutor(mock_stub)
        _set_cardinalities_with_mock(executor, mock_stub)
        executor.set_cardinalities({placements.CLIENTS: 3})
        fn_type = computation_types.FunctionType(None, tf.int32)
        value_ref = executor_pb2.ValueRef()
        comp = remote_executor.RemoteValue(value_ref, fn_type, executor)

        pending = comp.compute()
        result = asyncio.run(pending)

        mock_stub.compute.assert_called_once()
        self.assertEqual(result, 1)
Exemple #23
0
    def test_compute_returns_result(self, mock_stub):
        """Expects `compute()` to hit `Compute` once and return 1.

        `mock_stub.return_value.Compute` is mocked to return a single
        `ComputeResponse` whose value carries the tensor proto for the scalar
        1, packed into an `Any`.
        """
        packed_tensor = any_pb2.Any()
        packed_tensor.Pack(tf.make_tensor_proto(1))
        payload = executor_pb2.Value(tensor=packed_tensor)
        canned_reply = executor_pb2.ComputeResponse(value=payload)
        stub = mock_stub.return_value
        stub.Compute = mock.Mock(side_effect=[canned_reply])
        event_loop = asyncio.get_event_loop()
        executor = create_remote_executor()
        fn_type = computation_types.FunctionType(None, tf.int32)
        value_ref = executor_pb2.ValueRef()
        comp = remote_executor.RemoteValue(value_ref, fn_type, executor)

        pending = comp.compute()
        result = event_loop.run_until_complete(pending)

        stub.Compute.assert_called_once()
        self.assertEqual(result, 1)