def test_sequence_map(self):
  """Checks the type signatures produced by `tff.sequence_map` per placement."""

  @tff.tf_computation(tf.int32)
  def over_threshold(x):
    return x > 10

  def build_mapper(parameter_type):
    # Builds a fresh federated computation mapping `over_threshold` over its
    # (possibly placed) sequence argument.
    @tff.federated_computation(parameter_type)
    def mapper(x):
      return tff.sequence_map(over_threshold, x)

    return mapper

  expectations = [
      (tff.SequenceType(tf.int32), '(int32* -> bool*)'),
      (tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER),
       '(int32*@SERVER -> bool*@SERVER)'),
      (tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS),
       '({int32*}@CLIENTS -> {bool*}@CLIENTS)'),
  ]
  for parameter_type, expected_signature in expectations:
    mapper = build_mapper(parameter_type)
    self.assertEqual(str(mapper.type_signature), expected_signature)
def test_consume_infinite_tf_dataset(self):
  """Checks that a bounded prefix of an unbounded dataset can be consumed."""

  @tff.tf_computation(tff.SequenceType(tf.int64))
  def consume(ds):
    # Only the first 10 elements are read, so the reduction terminates even
    # though the input repeats forever.
    first_ten = ds.take(10)
    return first_ten.reduce(np.int64(0), lambda acc, elem: acc + elem)

  endless = tf.data.Dataset.range(10).repeat()
  # 0 + 1 + ... + 9 == 45.
  self.assertEqual(consume(endless), 45)
def test_consume_infinite_tf_dataset(self):
  """Checks that a bounded prefix of an unbounded dataset can be consumed."""
  # TODO(b/131363314): The reference executor should support generating
  # and returning infinite datasets.
  self.skipTest('b/131363314')

  @tff.tf_computation(tff.SequenceType(tf.int64))
  def consume(ds):
    # Only the first 10 elements are read, so the reduction terminates.
    prefix = ds.take(10)
    return prefix.reduce(np.int64(0), lambda acc, elem: acc + elem)

  unbounded = tf.data.Dataset.range(10).repeat()
  self.assertEqual(consume(unbounded), 45)
def test_with_four_element_dataset_pipeline(self):
  """Chains four dataset-handling TF computations inside a fifth one."""

  # Stage 1: the integers 0..4.
  @tff.tf_computation
  def produce_range():
    return tf.data.Dataset.range(5)

  # Stage 2: shift by one and convert to float32, giving 1.0..5.0.
  @tff.tf_computation(tff.SequenceType(tf.int64))
  def shift_to_float(ds):
    return ds.map(lambda v: tf.cast(v + 1, tf.float32))

  # Stage 3: repeat the whole sequence five times.
  @tff.tf_computation(tff.SequenceType(tf.float32))
  def repeat_five(ds):
    return ds.repeat(5)

  # Stage 4: sum every element.
  @tff.tf_computation(tff.SequenceType(tf.float32))
  def total(ds):
    return ds.reduce(0.0, lambda acc, elem: acc + elem)

  @tff.tf_computation
  def pipeline():
    produced = produce_range()
    shifted = shift_to_float(produced)
    repeated = repeat_five(shifted)
    return total(repeated)

  # (1 + 2 + 3 + 4 + 5) * 5 == 75.
  self.assertEqual(pipeline(), 75.0)
def test_sequence_sum(self):
  """Checks the type signatures produced by `tff.sequence_sum` per placement."""

  def check_signature(parameter_type, expected):
    # Traces a summing computation over `parameter_type` and compares the
    # string form of its type signature with `expected`.
    @tff.federated_computation(parameter_type)
    def summed(x):
      return tff.sequence_sum(x)

    self.assertEqual(str(summed.type_signature), expected)

  check_signature(tff.SequenceType(tf.int32), '(int32* -> int32)')
  check_signature(
      tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER),
      '(int32*@SERVER -> int32@SERVER)')
  check_signature(
      tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS),
      '({int32*}@CLIENTS -> {int32}@CLIENTS)')
def test_with_tf_datasets(self):
  """Checks that TF computations can both produce and consume datasets."""

  @tff.tf_computation(tff.SequenceType(tf.int64))
  def consume(ds):
    initial = np.int64(0)
    return ds.reduce(initial, lambda total, item: total + item)

  self.assertEqual(str(consume.type_signature), '(int64* -> int64)')

  @tff.tf_computation
  def produce():
    return tf.data.Dataset.range(10)

  self.assertEqual(str(produce.type_signature), '( -> int64*)')

  # Feed the produced dataset straight into the consumer: 0 + ... + 9 == 45.
  dataset = produce()
  self.assertEqual(consume(dataset), 45)
def next_fn(server_state, client_data):
  """Runs one round: broadcast, per-client transform, sum, constant output.

  Args:
    server_state: Value passed to `tff.federated_broadcast`. Presumably
      `int32@SERVER`, given the `tf.int32` parameter of `some_transform`;
      TODO confirm against the caller's decorator, which is not visible here.
    client_data: Client-side data zipped with the broadcast state. Presumably
      a `float32` sequence placed at CLIENTS, matching `some_transform`'s
      second parameter type.

  Returns:
    A pair `(aggregate_update, server_output)`: the `tff.federated_sum` of
    the per-client results, and the constant 1234 placed at `tff.SERVER`.
  """
  broadcast_state = tff.federated_broadcast(server_state)

  @tff.tf_computation(tf.int32, tff.SequenceType(tf.float32))
  @tf.function
  def some_transform(x, y):
    """Returns `x + 1`; the sequence argument `y` is deliberately ignored."""
    del y  # Unused
    return x + 1

  # Apply `some_transform` to each (broadcast state, client data) pair.
  client_update = tff.federated_map(some_transform,
                                    (broadcast_state, client_data))
  aggregate_update = tff.federated_sum(client_update)
  # A constant server output that does not depend on either input.
  server_output = tff.federated_value(1234, tff.SERVER)
  return aggregate_update, server_output
def test_sequence_reduce(self):
  """Checks the type signatures produced by `tff.sequence_reduce` per placement."""
  add_numbers = tff.tf_computation(tf.add, [tf.int32, tf.int32])

  def reducer_for(parameter_type):
    # Folds the sequence argument with `add_numbers`, starting from 0.
    @tff.federated_computation(parameter_type)
    def reducer(x):
      return tff.sequence_reduce(x, 0, add_numbers)

    return reducer

  unplaced = reducer_for(tff.SequenceType(tf.int32))
  self.assertEqual(str(unplaced.type_signature), '(int32* -> int32)')

  at_server = reducer_for(
      tff.FederatedType(tff.SequenceType(tf.int32), tff.SERVER))
  self.assertEqual(
      str(at_server.type_signature), '(int32*@SERVER -> int32@SERVER)')

  at_clients = reducer_for(
      tff.FederatedType(tff.SequenceType(tf.int32), tff.CLIENTS))
  self.assertEqual(
      str(at_clients.type_signature), '({int32*}@CLIENTS -> {int32}@CLIENTS)')
def test_call_returned_directly_creates_canonical_form(self):
  """Checks canonical-form conversion when `next` only delegates a call.

  `nested_next_fn` returns the result of calling `next_fn` directly; the
  iterative process built from it must still be convertible into a
  `canonical_form.CanonicalForm`.
  """

  @tff.federated_computation
  def init_fn():
    return tff.federated_value(42, tff.SERVER)

  @tff.federated_computation(
      tff.FederatedType(tf.int32, tff.SERVER),
      tff.FederatedType(tff.SequenceType(tf.float32), tff.CLIENTS))
  def next_fn(server_state, client_data):
    broadcast_state = tff.federated_broadcast(server_state)

    @tff.tf_computation(tf.int32, tff.SequenceType(tf.float32))
    @tf.function
    def some_transform(x, y):
      del y  # Unused
      return x + 1

    client_update = tff.federated_map(some_transform,
                                      (broadcast_state, client_data))
    aggregate_update = tff.federated_sum(client_update)
    server_output = tff.federated_value(1234, tff.SERVER)
    return aggregate_update, server_output

  # Same parameter types as `next_fn`, spelled with the public
  # `tff.SequenceType` alias like the rest of this test.
  @tff.federated_computation(
      tff.FederatedType(tf.int32, tff.SERVER),
      tff.FederatedType(tff.SequenceType(tf.float32), tff.CLIENTS))
  def nested_next_fn(server_state, client_data):
    # The body is nothing but a direct call whose result is returned.
    return next_fn(server_state, client_data)

  iterative_process = computation_utils.IterativeProcess(
      init_fn, nested_next_fn)
  cf = canonical_form_utils.get_canonical_form_for_iterative_process(
      iterative_process)
  self.assertIsInstance(cf, canonical_form.CanonicalForm)
def test_tf_comp_with_sequence_inputs_and_outputs_does_not_fail(self):
  """Checks that a TF computation may both accept and return a sequence."""

  @tff.tf_computation(tff.SequenceType(tf.int32))
  def identity(x):
    return x

  # Beyond tracing without error, pin the resulting signature so that a
  # change to the sequence element type or sequence-ness is caught.
  self.assertEqual(str(identity.type_signature), '(int32* -> int32*)')