Beispiel #1
0
  def test_complex_param_tf_computation(self):
    """Checks tff.tf_computation with namedtuple, dict and tuple parameters.

    The eager executor (inside the local executor stack) has an issue with
    tf2_computation and the v1 version of SavedModel, so the preceding test is
    replicated here.
    """
    MyType = collections.namedtuple('MyType', ['x', 'd'])  # pylint: disable=invalid-name

    @tf.function
    def foo(t, odict, unnamed_tuple):
      # The Python container structure of every argument must survive being
      # passed through the TFF computation.
      self.assertIsInstance(t, MyType)
      self.assertIsInstance(t.d, dict)
      self.assertIsInstance(odict, collections.OrderedDict)
      self.assertIsInstance(unnamed_tuple, tuple)
      partial_sum = t.x + t.d['y'] + t.d['z']
      return partial_sum + odict['o'] + unnamed_tuple[0]

    expected_total = 6
    call_args = [
        MyType(1, dict(y=2, z=3)),
        collections.OrderedDict([('o', 0)]),
        (0,),
    ]
    call_arg_types = [
        MyType(tf.int32, collections.OrderedDict(y=tf.int32, z=tf.int32)),
        collections.OrderedDict([('o', tf.int32)]),
        (tf.int32,),
    ]

    # With the parameter type spelled out explicitly.
    explicitly_typed = tff.tf_computation(foo, call_arg_types)
    self.assertEqual(explicitly_typed(*call_args), expected_total)

    # Polymorphic: the parameter type is inferred from the call arguments.
    polymorphic = tff.tf_computation(foo)
    self.assertEqual(polymorphic(*call_args), expected_total)
Beispiel #2
0
 def foo(temperatures, threshold):
   """Counts how many `temperatures` entries are greater than `threshold`."""
   # 1 when the first float32 operand exceeds the second, otherwise 0.
   exceeds = tff.tf_computation(
       lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
       [tf.float32, tf.float32])
   # Broadcast the threshold (presumably placed at SERVER) so it can be paired
   # with each temperature value.
   broadcast_threshold = tff.federated_broadcast(threshold)
   indicators = tff.federated_map(exceeds,
                                  [temperatures, broadcast_threshold])
   return tff.federated_sum(indicators)
  def test_tf_comp_first_mode_of_usage_as_non_polymorphic_wrapper(self):
    """Exercises tff.tf_computation called directly with a parameter type."""
    # A lambda that takes one int32 parameter.
    greater_than_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
    self.assertEqual(str(greater_than_ten.type_signature), '(int32 -> bool)')
    for arg, expected in ((9, False), (11, True)):
      self.assertEqual(greater_than_ten(arg), expected)

    # An existing Python function that takes a pair of int32 parameters.
    add_ints = tff.tf_computation(tf.add, (tf.int32, tf.int32))
    self.assertEqual(str(add_ints.type_signature), '(<int32,int32> -> int32)')

    # No-parameter callables: first a lambda, then a named Python function.
    def make_ten():
      return tf.constant(10)

    for no_arg_fn in (lambda: tf.constant(10), make_ten):
      no_arg_comp = tff.tf_computation(no_arg_fn)
      self.assertEqual(str(no_arg_comp.type_signature), '( -> int32)')
      self.assertEqual(no_arg_comp(), 10)
  def test_tf_comp_third_mode_of_usage_as_polymorphic_callable(self):
    """Exercises tff.tf_computation used without a type (polymorphic mode)."""
    # Wrapping a lambda; the parameter type is inferred at each call.
    is_positive = tff.tf_computation(lambda x: x > 0)
    for arg, expected in ((-1, False), (0, False), (1, True)):
      self.assertEqual(is_positive(arg), expected)

    # Using tff.tf_computation as a bare decorator on a Python function.
    @tff.tf_computation
    def greater(x, y):
      return x > y

    pair_cases = (
        ((0, 1), False),
        ((1, 0), True),
        ((0, 0), False),
    )
    for (lhs, rhs), expected in pair_cases:
      self.assertEqual(greater(lhs, rhs), expected)
Beispiel #5
0
  def test_sequence_reduce(self):
    """Checks type signatures of tff.sequence_reduce at several placements."""
    add_numbers = tff.tf_computation(tf.add, [tf.int32, tf.int32])

    # Each entry: (parameter type, expected type signature of the computation).
    signature_cases = (
        (tff.SequenceType(tf.int32), '(int32* -> int32)'),
        (tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER),
         '(int32*@SERVER -> int32@SERVER)'),
        (tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS),
         '({int32*}@CLIENTS -> {int32}@CLIENTS)'),
    )
    for param_type, expected_signature in signature_cases:

      @tff.federated_computation(param_type)
      def reduce_sequence(x):
        return tff.sequence_reduce(x, 0, add_numbers)

      self.assertEqual(str(reduce_sequence.type_signature), expected_signature)
Beispiel #6
0
 def _(x):
   """Maps an int32 `> 10` comparison over the federated value `x`."""
   greater_than_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
   return tff.federated_map(greater_than_ten, x)
Beispiel #7
0
 def foo(x, y):
   """Applies a two-parameter int32 computation to `[x, y]` with federated_apply."""
   # NOTE(review): the lambda ignores `y` and only tests `x > 10`; confirm this
   # is intentional for the test.
   comparison = tff.tf_computation(lambda x, y: x > 10, [tf.int32, tf.int32])
   operands = [x, y]
   return tff.federated_apply(comparison, operands)
Beispiel #8
0
 def foo(x):
   """Applies an int32 `> 10` comparison to `x` with tff.federated_apply."""
   greater_than_ten = tff.tf_computation(lambda x: x > 10, tf.int32)
   return tff.federated_apply(greater_than_ten, x)
Beispiel #9
0
 def foo(x):
   """Reduces the federated value `x` with int32 addition, starting from 0."""
   return tff.federated_reduce(
       x, 0, tff.tf_computation(tf.add, [tf.int32, tf.int32]))