Ejemplo n.º 1
0
    def basic_federated_select_args(self):
        """Builds the standard argument tuple shared by federated_select tests.

        Returns:
          A tuple `(client_keys, max_key, server_val, select_fn)` where
            * `client_keys` are three random int32 keys in `[0, max_key)`,
              evaluated at CLIENTS;
            * `max_key` is the number of server values, placed at SERVER;
            * `server_val` is `['first', 'second', 'third']` placed at SERVER;
            * `select_fn` is a TF computation on `<state, int32 key>` returning
              `tf.gather(state, key)`.
        """
        ctx_stack = context_stack_impl.context_stack
        server_strings = ['first', 'second', 'third']
        server_val = intrinsics.federated_value(server_strings,
                                                placements.SERVER)

        key_upper_bound = len(server_strings)
        max_key = intrinsics.federated_value(key_upper_bound, placements.SERVER)

        def get_three_random_keys_fn():
            # Three keys drawn uniformly from [0, key_upper_bound).
            return tf.random.uniform(shape=[3],
                                     minval=0,
                                     maxval=key_upper_bound,
                                     dtype=tf.int32)

        keys_proto, _ = (
            tensorflow_computation_factory.create_computation_for_py_fn(
                get_three_random_keys_fn, None))
        client_keys = intrinsics.federated_eval(
            computation_impl.ConcreteComputation(keys_proto, ctx_stack),
            placements.CLIENTS)

        member_type = server_val.type_signature.member

        def _select_fn(arg):
            state = type_conversions.type_to_py_container(arg[0], member_type)
            return tf.gather(state, arg[1])

        select_proto, _ = (
            tensorflow_computation_factory.create_computation_for_py_fn(
                _select_fn,
                computation_types.StructType([member_type, tf.int32])))
        select_fn = computation_impl.ConcreteComputation(select_proto,
                                                         ctx_stack)

        return client_keys, max_key, server_val, select_fn
Ejemplo n.º 2
0
    def test_federated_aggregate_with_unknown_dimension(self):
        """Checks federated_aggregate with an accumulator of unknown length.

        The accumulator holds an int32 vector typed `int32[None]` that grows as
        client values are appended. The aggregated result is therefore expected
        to have type `<samples=int32[?]>@SERVER`.
        """
        Accumulator = collections.namedtuple('Accumulator', ['samples'])  # pylint: disable=invalid-name
        accumulator_type = computation_types.to_type(
            Accumulator(samples=computation_types.TensorType(dtype=tf.int32,
                                                             shape=[None])))

        x = _mock_data_of_type(computation_types.at_clients(tf.int32))

        # The initial accumulator holds an empty int32 vector.
        def initialize_fn():
            return Accumulator(samples=tf.zeros(shape=[0], dtype=tf.int32))

        initialize_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
            initialize_fn, None)
        initialize = computation_impl.ConcreteComputation(
            initialize_proto, context_stack_impl.context_stack)
        zero = initialize()

        # The operator to use during the first stage appends a single client
        # element to the `samples` tensor, increasing its size by one.
        def _accumulate(arg):
            return Accumulator(samples=tf.concat(
                [arg[0].samples,
                 tf.expand_dims(arg[1], axis=0)], axis=0))

        accumulate_type = computation_types.StructType(
            [accumulator_type, tf.int32])
        accumulate_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
            _accumulate, accumulate_type)
        accumulate = computation_impl.ConcreteComputation(
            accumulate_proto, context_stack_impl.context_stack)

        # The operator to use during the second stage concatenates the `samples`
        # of two partial accumulators.
        def _merge(arg):
            return Accumulator(
                samples=tf.concat([arg[0].samples, arg[1].samples], axis=0))

        merge_type = computation_types.StructType(
            [accumulator_type, accumulator_type])
        merge_proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
            _merge, merge_type)
        merge = computation_impl.ConcreteComputation(
            merge_proto, context_stack_impl.context_stack)

        # The operator to use during the final stage returns the accumulator
        # unchanged (identity over `accumulator_type`).
        report_proto, _ = tensorflow_computation_factory.create_identity(
            accumulator_type)
        report = computation_impl.ConcreteComputation(
            report_proto, context_stack_impl.context_stack)

        value = intrinsics.federated_aggregate(x, zero, accumulate, merge,
                                               report)
        self.assert_value(value, '<samples=int32[?]>@SERVER')
Ejemplo n.º 3
0
    def test_federated_aggregate_with_client_int(self):
        """Aggregates client int32 values into a float32 mean at the server.

        The accumulator is a named tuple `(total, count)`. `total` sums the
        client elements and `count` tallies them. The report stage divides the
        two in float32.
        """
        ctx_stack = context_stack_impl.context_stack
        Accumulator = collections.namedtuple('Accumulator', 'total count')  # pylint: disable=invalid-name
        accumulator_type = computation_types.to_type(
            Accumulator(total=computation_types.TensorType(dtype=tf.int32),
                        count=computation_types.TensorType(dtype=tf.int32)))

        x = _mock_data_of_type(computation_types.at_clients(tf.int32))
        zero = Accumulator(0, 0)

        def _accumulate(arg):
            # Stage 1: fold one client element into (total, count).
            acc, element = arg[0], arg[1]
            return Accumulator(acc.total + element, acc.count + 1)

        accumulate_proto, _ = (
            tensorflow_computation_factory.create_computation_for_py_fn(
                _accumulate,
                computation_types.StructType([accumulator_type, tf.int32])))
        accumulate = computation_impl.ConcreteComputation(accumulate_proto,
                                                          ctx_stack)

        def _merge(arg):
            # Stage 2: combine two partial accumulators field by field.
            left, right = arg[0], arg[1]
            return Accumulator(left.total + right.total,
                               left.count + right.count)

        merge_proto, _ = (
            tensorflow_computation_factory.create_computation_for_py_fn(
                _merge,
                computation_types.StructType(
                    [accumulator_type, accumulator_type])))
        merge = computation_impl.ConcreteComputation(merge_proto, ctx_stack)

        def _report(arg):
            # Stage 3: mean = total / count, computed in float32.
            numerator = tf.cast(arg.total, tf.float32)
            denominator = tf.cast(arg.count, tf.float32)
            return numerator / denominator

        report_proto, _ = (
            tensorflow_computation_factory.create_computation_for_py_fn(
                _report, accumulator_type))
        report = computation_impl.ConcreteComputation(report_proto, ctx_stack)

        value = intrinsics.federated_aggregate(x, zero, accumulate, merge,
                                               report)
        self.assert_value(value, 'float32@SERVER')
Ejemplo n.º 4
0
    def test_federated_select_autozips_server_val(self, federated_select):
        """Checks that federated_select accepts a dict of SERVER values.

        Two entries of the same SERVER-placed string list are passed as an
        `OrderedDict`. The select function gathers the keyed element from each
        entry, so the result is a sequence of string pairs at CLIENTS.
        """
        client_keys, max_key, unused_server_val, unused_select_fn = (
            self.basic_federated_select_args())
        del unused_server_val, unused_select_fn

        element = intrinsics.federated_value(['first', 'second', 'third'],
                                             placements.SERVER)
        server_val_dict = collections.OrderedDict(e1=element, e2=element)

        member_type = element.type_signature.member
        state_type = computation_types.StructType([
            ('e1', member_type),
            ('e2', member_type),
        ])

        def _select_fn_dict(arg):
            state = type_conversions.type_to_py_container(arg[0], state_type)
            key = arg[1]
            return tuple(tf.gather(state[name], key) for name in ('e1', 'e2'))

        select_proto, _ = (
            tensorflow_computation_factory.create_computation_for_py_fn(
                _select_fn_dict,
                computation_types.StructType([state_type, tf.int32])))
        select_fn_dict = computation_impl.ConcreteComputation(
            select_proto, context_stack_impl.context_stack)

        result = federated_select(client_keys, max_key, server_val_dict,
                                  select_fn_dict)
        self.assert_value(result, '{<string,string>*}@CLIENTS')
Ejemplo n.º 5
0
 def call_concrete(*args):
     """Invokes `comp` as a ConcreteComputation and normalizes struct results.

     `comp` is a free variable from an enclosing scope that is not shown here.
     It is assumed to expose `.proto` and `.type_signature`.

     Returns:
       The raw call result, or, when the computation's result type is a
       struct, that result converted with `structure.from_container(...,
       recursive=True)`.
     """
     result = computation_impl.ConcreteComputation(
         comp.proto, context_stack_impl.context_stack)(*args)
     if not comp.type_signature.result.is_struct():
         return result
     return structure.from_container(result, recursive=True)
  def test_returns_value_with_computation_impl(self, proto, type_signature):
    """create_value on a ConcreteComputation yields an ExecutorValue of the type."""
    executor = create_test_executor()
    comp = computation_impl.ConcreteComputation(
        proto, context_stack_impl.context_stack)

    created = self.run_sync(executor.create_value(comp, type_signature))

    self.assertIsInstance(created, executor_value_base.ExecutorValue)
    expected = type_signature.compact_representation()
    self.assertEqual(created.type_signature.compact_representation(), expected)
Ejemplo n.º 7
0
def _create_computation_greater_than_10_with_unused_parameter(
) -> computation_base.Computation:
    """Returns a TF computation on `<int32, int32>` evaluating `arg[0] > 10`.

    The second element of the parameter is deliberately left unused.
    """
    def _first_greater_than_10(pair):
        return pair[0] > 10

    pair_type = computation_types.StructType([
        computation_types.TensorType(tf.int32),
        computation_types.TensorType(tf.int32),
    ])
    proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
        _first_greater_than_10, pair_type)
    return computation_impl.ConcreteComputation(
        proto, context_stack_impl.context_stack)
 def test_serialize_deserialize_round_trip(self):
     """Serializing then deserializing a computation yields an equal one."""
     int32 = computation_types.TensorType(tf.int32)
     add_proto, _ = tensorflow_computation_factory.create_binary_operator(
         tf.add, int32, int32)
     original = computation_impl.ConcreteComputation(
         add_proto, context_stack_impl.context_stack)

     round_tripped = computation_serialization.deserialize_computation(
         computation_serialization.serialize_computation(original))

     self.assertIsInstance(round_tripped, computation_base.Computation)
     self.assertEqual(round_tripped, original)
Ejemplo n.º 9
0
  def test_to_value_for_computations(self):
    """to_value on a no-arg constant computation yields a `( -> int32)` Value."""
    constant_proto, _ = tensorflow_computation_factory.create_constant(
        10, computation_types.TensorType(tf.int32))
    constant_comp = computation_impl.ConcreteComputation(
        constant_proto, context_stack_impl.context_stack)

    value = value_impl.to_value(constant_comp, None)

    self.assertIsInstance(value, value_impl.Value)
    self.assertEqual(value.type_signature.compact_representation(),
                     '( -> int32)')
Ejemplo n.º 10
0
    def test_intrinsic_construction_raises_outside_symbol_binding_context(
            self):
        """federated_eval raises ContextError under a RuntimeErrorContext."""
        ctx_stack = context_stack_impl.context_stack
        constant_proto, _ = tensorflow_computation_factory.create_constant(
            2, computation_types.TensorType(tf.int32))
        return_2 = computation_impl.ConcreteComputation(constant_proto,
                                                        ctx_stack)

        with ctx_stack.install(runtime_error_context.RuntimeErrorContext()), \
                self.assertRaises(context_base.ContextError):
            intrinsics.federated_eval(return_2, placements.SERVER)
Ejemplo n.º 11
0
    def test_infers_accumulate_return_as_merge_arg_merge_return_as_report_arg(
            self):
        """Checks that `int64[?]` is threaded through every aggregate stage."""
        ctx_stack = context_stack_impl.context_stack
        vector_type = computation_types.TensorType(dtype=tf.int64,
                                                   shape=[None])
        x = _mock_data_of_type(computation_types.at_clients(tf.int64))

        def _concrete(py_fn, parameter_type):
            # Serialize a Python TF function and wrap it as a ConcreteComputation.
            proto, _ = (
                tensorflow_computation_factory.create_computation_for_py_fn(
                    py_fn, parameter_type))
            return computation_impl.ConcreteComputation(proto, ctx_stack)

        def initialize_fn():
            return tf.constant([], dtype=tf.int64, shape=[0])

        zero = _concrete(initialize_fn, None)()

        def _accumulate(arg):
            # Append one client element to the running vector.
            return tf.concat([arg[0], [arg[1]]], 0)

        accumulate = _concrete(
            _accumulate, computation_types.StructType([vector_type, tf.int64]))

        def _merge(arg):
            # Concatenate two partial vectors.
            return tf.concat([arg[0], arg[1]], 0)

        merge = _concrete(
            _merge, computation_types.StructType([vector_type, vector_type]))

        identity_proto, _ = tensorflow_computation_factory.create_identity(
            vector_type)
        report = computation_impl.ConcreteComputation(identity_proto,
                                                      ctx_stack)

        value = intrinsics.federated_aggregate(x, zero, accumulate, merge,
                                               report)
        self.assert_value(value, 'int64[?]@SERVER')
    def test_invoke_returns_value_with_correct_type(self):
        """Invoking a constant computation in a federated context yields int32."""
        ctx_stack = context_stack_impl.context_stack
        constant_proto, _ = tensorflow_computation_factory.create_constant(
            10, computation_types.TensorType(tf.int32))
        constant_comp = computation_impl.ConcreteComputation(constant_proto,
                                                             ctx_stack)
        context = federated_computation_context.FederatedComputationContext(
            ctx_stack)

        with ctx_stack.install(context):
            result = context.invoke(constant_comp, None)

        self.assertIsInstance(result, value_impl.Value)
        self.assertEqual(str(result.type_signature), 'int32')
Ejemplo n.º 13
0
    def test_federated_select_keys_must_be_fixed_length(
            self, federated_select):
        """Client keys typed with an unknown length (`int32[?]`) are rejected."""
        client_keys, max_key, server_val, select_fn = (
            self.basic_federated_select_args())

        # Mapping an identity typed on int32[None] over the keys re-types them
        # with an unknown length.
        erase_shape_proto, _ = tensorflow_computation_factory.create_identity(
            computation_types.TensorType(tf.int32, [None]))
        erase_shape = computation_impl.ConcreteComputation(
            erase_shape_proto, context_stack_impl.context_stack)
        unsized_client_keys = intrinsics.federated_map(erase_shape,
                                                       client_keys)

        with self.assertRaises(TypeError):
            federated_select(unsized_client_keys, max_key, server_val,
                             select_fn)
Ejemplo n.º 14
0
 def test_set_local_python_execution_context_and_run_simple_xla_computation(
         self):
     """Runs an XLA computation returning int32 10 under the local context."""
     builder = xla_client.XlaBuilder('comp')
     # Parameter 0 has an empty-tuple shape, i.e. no real inputs.
     xla_client.ops.Parameter(builder, 0,
                              xla_client.shape_from_pyval(tuple()))
     xla_client.ops.Constant(builder, np.int32(10))
     comp_pb = xla_serialization.create_xla_tff_computation(
         builder.build(), [], computation_types.FunctionType(None, np.int32))
     comp = computation_impl.ConcreteComputation(
         comp_pb, context_stack_impl.context_stack)

     execution_contexts.set_local_python_execution_context()
     self.assertEqual(comp(), 10)
Ejemplo n.º 15
0
  def test_with_type_raises_non_assignable_type(self):
    """with_type rejects a list-struct result for an int32-returning comp."""
    int_return_type = computation_types.FunctionType(tf.int32, tf.int32)
    intrinsic_proto = pb.Computation(
        type=type_serialization.serialize_type(int_return_type),
        intrinsic=pb.Intrinsic(uri='whatever'))
    original_comp = computation_impl.ConcreteComputation(
        intrinsic_proto, context_stack_impl.context_stack)

    list_struct = computation_types.StructWithPythonType([(None, tf.int32)],
                                                         list)
    list_return_type = computation_types.FunctionType(tf.int32, list_struct)

    with self.assertRaises(computation_types.TypeNotAssignableError):
      computation_impl.ConcreteComputation.with_type(original_comp,
                                                     list_return_type)
Ejemplo n.º 16
0
    def test_federated_select_server_val_must_be_server_placed(
            self, federated_select):
        """An unplaced server_val (not @SERVER) is rejected with TypeError."""
        client_keys, max_key, unused_server_val, select_fn = (
            self.basic_federated_select_args())
        del unused_server_val

        constant_proto, _ = tensorflow_computation_factory.create_constant(
            tf.constant(['first', 'second', 'third']),
            computation_types.TensorType(dtype=tf.string, shape=[3]))
        # Calling the constant computation yields a value with no placement.
        unplaced_value = computation_impl.ConcreteComputation(
            constant_proto, context_stack_impl.context_stack)()

        with self.assertRaises(TypeError):
            federated_select(client_keys, max_key, unplaced_value, select_fn)
  def test_invoke_raises_value_error_with_federated_computation(self):
    """A TF computation context refuses to invoke a non-TF computation."""
    function_type = computation_types.to_type(
        computation_types.FunctionType(tf.int32, tf.int32))
    reference_proto = pb.Computation(
        type=type_serialization.serialize_type(function_type),
        reference=pb.Reference(name='boogledy'))
    non_tf_computation = computation_impl.ConcreteComputation(
        reference_proto, context_stack_impl.context_stack)

    tf_context = tensorflow_computation_context.TensorFlowComputationContext(
        tf.compat.v1.get_default_graph(), tf.constant('bogus_token'))

    expected_message = ('Can only invoke TensorFlow in the body of '
                        'a TensorFlow computation')
    with self.assertRaisesRegex(ValueError, expected_message):
      tf_context.invoke(non_tf_computation, None)
Ejemplo n.º 18
0
def deserialize_computation(
    computation_proto: pb.Computation) -> computation_base.Computation:
  """Deserializes a `pb.Computation` into a `tff.Computation`.

  Args:
    computation_proto: An instance of `pb.Computation`.

  Returns:
    The corresponding instance of `tff.Computation`, constructed as a
    `ConcreteComputation` with `context_stack_impl.context_stack`.

  Raises:
    TypeError: If `computation_proto` is not a `pb.Computation`.
  """
  py_typecheck.check_type(computation_proto, pb.Computation)
  return computation_impl.ConcreteComputation(computation_proto,
                                              context_stack_impl.context_stack)
Ejemplo n.º 19
0
  def test_with_type_preserves_python_container(self):
    """with_type can attach a `list` Python container to a struct result."""
    struct_return_type = computation_types.FunctionType(
        tf.int32, computation_types.StructType([(None, tf.int32)]))
    intrinsic_proto = pb.Computation(
        type=type_serialization.serialize_type(struct_return_type),
        intrinsic=pb.Intrinsic(uri='whatever'))
    original_comp = computation_impl.ConcreteComputation(
        intrinsic_proto, context_stack_impl.context_stack)

    list_struct = computation_types.StructWithPythonType([(None, tf.int32)],
                                                         list)
    list_return_type = computation_types.FunctionType(tf.int32, list_struct)
    annotated = computation_impl.ConcreteComputation.with_type(
        original_comp, list_return_type)

    type_test_utils.assert_types_identical(list_return_type,
                                           annotated.type_signature)
Ejemplo n.º 20
0
    def test_federated_aggregate_with_federated_zero_fails(self):
        """Checks that a SERVER-placed `zero` is rejected by federated_aggregate.

        Per the asserted error message, `zero` must be assignable to the
        accumulator type (`int32` here). Passing `int32@SERVER` instead must
        raise a `TypeError`.
        """
        x = _mock_data_of_type(computation_types.at_clients(tf.int32))
        zero = intrinsics.federated_value(0, placements.SERVER)

        # The operator to use during the first stage adds a client element to
        # the running int32 sum.
        accumulate = _create_computation_add()

        # The operator to use during the second stage adds two partial sums.
        merge = _create_computation_add()

        # The operator to use during the final stage returns the int32 sum
        # unchanged (identity).
        report_type = computation_types.TensorType(tf.int32)
        report_proto, _ = tensorflow_computation_factory.create_identity(
            report_type)
        report = computation_impl.ConcreteComputation(
            report_proto, context_stack_impl.context_stack)

        with self.assertRaisesRegex(
                TypeError, 'Expected `zero` to be assignable to type int32, '
                'but was of incompatible type int32@SERVER'):
            intrinsics.federated_aggregate(x, zero, accumulate, merge, report)
Ejemplo n.º 21
0
    def test_federated_select_keys_must_be_int32(self, federated_select):
        """Client keys of dtype int64 are rejected with TypeError."""
        unused_client_keys, max_key, server_val, select_fn = (
            self.basic_federated_select_args())
        del unused_client_keys

        def get_three_random_keys_fn():
            # Same shape and range as the valid keys, but with the wrong dtype.
            return tf.random.uniform(shape=[3],
                                     minval=0,
                                     maxval=3,
                                     dtype=tf.int64)

        keys_proto, _ = (
            tensorflow_computation_factory.create_computation_for_py_fn(
                get_three_random_keys_fn, None))
        int64_client_keys = intrinsics.federated_eval(
            computation_impl.ConcreteComputation(
                keys_proto, context_stack_impl.context_stack),
            placements.CLIENTS)

        with self.assertRaises(TypeError):
            federated_select(int64_client_keys, max_key, server_val, select_fn)
Ejemplo n.º 22
0
    def test_federated_select_fn_must_take_int32_keys(self, federated_select):
        """A select_fn whose key parameter is int64 is rejected with TypeError."""
        client_keys, max_key, server_val, unused_select_fn = (
            self.basic_federated_select_args())
        del unused_select_fn

        member_type = server_val.type_signature.member

        def _bad_select_fn(arg):
            state = type_conversions.type_to_py_container(arg[0], member_type)
            return tf.gather(state, arg[1])

        bad_proto, _ = (
            tensorflow_computation_factory.create_computation_for_py_fn(
                _bad_select_fn,
                computation_types.StructType([member_type, tf.int64])))
        bad_select_fn = computation_impl.ConcreteComputation(
            bad_proto, context_stack_impl.context_stack)

        with self.assertRaises(TypeError):
            federated_select(client_keys, max_key, server_val, bad_select_fn)
Ejemplo n.º 23
0
def _tf_wrapper_fn(parameter_type, name):
    """Wrapper function to plug Tensorflow logic into the TFF framework.

    This is a two-step generator driven by the caller:
      1. It yields the argument obtained from the TF computation serializer for
         `parameter_type`. The caller then `send()`s back the result of the
         wrapped function, or `throw()`s an exception raised while computing it.
      2. It yields a `computation_impl.ConcreteComputation` built from the
         proto and extra type spec that the serializer returns for that result.

    Args:
      parameter_type: The TFF type of the computation's parameter. It must
        pass `type_analysis.is_tensorflow_compatible_type`.
      name: Ignored.

    Yields:
      First, the argument to pass to the wrapped function. Then the resulting
      `ConcreteComputation`.

    Raises:
      TypeError: If `parameter_type` is not TensorFlow-compatible.
    """
    del name  # Unused.
    if not type_analysis.is_tensorflow_compatible_type(parameter_type):
        raise TypeError(
            '`tf_computation`s can accept only parameter types with '
            'constituents `SequenceType`, `StructType` '
            'and `TensorType`; you have attempted to create one '
            'with the type {}.'.format(parameter_type))
    ctx_stack = context_stack_impl.context_stack
    tf_serializer = tensorflow_serialization.tf_computation_serializer(
        parameter_type, ctx_stack)
    # Prime the serializer to obtain the argument for the wrapped function.
    arg = next(tf_serializer)
    try:
        # The caller sends the wrapped function's result back in here.
        result = yield arg
    except Exception as e:  # pylint: disable=broad-except
        # Forward the failure into the serializer, presumably so it can clean
        # up and re-raise in its own frame.
        # NOTE(review): if the serializer catches `e` instead of re-raising,
        # `result` is unbound below and an UnboundLocalError follows. Confirm
        # that the serializer always re-raises.
        tf_serializer.throw(e)
    # Hand the result to the serializer, which returns the computation proto
    # and an extra type spec.
    comp_pb, extra_type_spec = tf_serializer.send(result)
    tf_serializer.close()
    yield computation_impl.ConcreteComputation(comp_pb, ctx_stack,
                                               extra_type_spec)
Ejemplo n.º 24
0
  def test_something(self):
    """Smoke-tests argument validation in the ConcreteComputation constructor."""
    # TODO(b/113112108): Revise these tests after a more complete implementation
    # is in place.
    ctx_stack = context_stack_impl.context_stack

    # At the moment this succeeds, since both the computation body and its
    # type are well-formed.
    well_formed = pb.Computation(
        type=type_serialization.serialize_type(
            computation_types.FunctionType(tf.int32, tf.int32)),
        intrinsic=pb.Intrinsic(uri='whatever'))
    computation_impl.ConcreteComputation(well_formed, ctx_stack)

    # An empty proto is not well-formed.
    with self.assertRaises(TypeError):
      computation_impl.ConcreteComputation(pb.Computation(), ctx_stack)

    # 10 is not an instance of pb.Computation.
    with self.assertRaises(TypeError):
      computation_impl.ConcreteComputation(10, ctx_stack)
Ejemplo n.º 25
0
def _jax_strategy_fn(fn_to_wrap, fn_name, parameter_type, unpack):
  """Serializes a Python function containing JAX code as a TFF computation.

  Args:
    fn_to_wrap: The Python function with JAX code to serialize as a
      computation containing XLA.
    fn_name: The name for the constructed computation. Currently ignored.
    parameter_type: A `computation_types.Type` describing the computation's
      parameter, or `None` if it takes no parameter.
    unpack: Forwarded as `unpack` to
      `function_utils.create_argument_unpacking_fn`.

  Returns:
    A `computation_impl.ConcreteComputation` wrapping the serialized
    computation.
  """
  del fn_name  # Unused.
  argument_unpacker = function_utils.create_argument_unpacking_fn(
      fn_to_wrap, parameter_type, unpack=unpack)
  ctx_stack = context_stack_impl.context_stack
  return computation_impl.ConcreteComputation(
      jax_serialization.serialize_jax_computation(fn_to_wrap,
                                                  argument_unpacker,
                                                  parameter_type, ctx_stack),
      ctx_stack)
Ejemplo n.º 26
0
def _create_computation_add() -> computation_base.Computation:
    """Returns a TF computation adding two int32 tensors via `tf.add`."""
    ctx_stack = context_stack_impl.context_stack
    int32 = computation_types.TensorType(tf.int32)
    add_proto, _ = tensorflow_computation_factory.create_binary_operator(
        tf.add, int32, int32)
    return computation_impl.ConcreteComputation(add_proto, ctx_stack)
Ejemplo n.º 27
0
def _create_computation_greater_than_10() -> computation_base.Computation:
    """Returns a TF computation on an int32 tensor evaluating `x > 10`."""
    def _greater_than_10(value):
        return value > 10

    proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
        _greater_than_10, computation_types.TensorType(tf.int32))
    return computation_impl.ConcreteComputation(
        proto, context_stack_impl.context_stack)
Ejemplo n.º 28
0
def _create_computation_random() -> computation_base.Computation:
    """Returns a no-argument TF computation sampling `tf.random.normal([])`."""
    def _sample_scalar():
        return tf.random.normal([])

    proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
        _sample_scalar, None)
    return computation_impl.ConcreteComputation(
        proto, context_stack_impl.context_stack)
Ejemplo n.º 29
0
def _create_computation_reduce() -> computation_base.Computation:
    """Returns a TF computation summing an int32 sequence, starting from 0."""
    def _sum_sequence(ds):
        # Fold the dataset with an int32 zero as the initial state.
        return ds.reduce(np.int32(0), lambda total, element: total + element)

    proto, _ = tensorflow_computation_factory.create_computation_for_py_fn(
        _sum_sequence, computation_types.SequenceType(tf.int32))
    return computation_impl.ConcreteComputation(
        proto, context_stack_impl.context_stack)