def _visit_type(type_signature, function):
  """Invokes `function` on each node visited while traversing a type.

  The traversal is driven by `type_transformations.transform_type_postorder`.
  Every node is handed back unchanged and flagged as not mutated, so
  `function` is used purely for its side effects. The (type, mutated) pair
  produced by the traversal is discarded and this helper returns `None`.

  Args:
    type_signature: The type to traverse.
    function: A one-argument callable invoked with each visited type.
  """

  def _observe(node):
    function(node)
    unchanged = False
    return node, unchanged

  type_transformations.transform_type_postorder(type_signature, _observe)
def test_transforms_tensor(self):
  # A bare int32 tensor is rewritten to float32 by the float conversion; the
  # abstract-type conversion is expected to leave it untouched.
  int_type = computation_types.to_type(tf.int32)
  float_type = computation_types.to_type(tf.float32)
  converted, was_mutated = type_transformations.transform_type_postorder(
      int_type, _convert_tensor_to_float)
  untouched, was_not_mutated = type_transformations.transform_type_postorder(
      int_type, _convert_abstract_type_to_tensor)
  self.assertTrue(was_mutated)
  self.assertEqual(converted, float_type)
  self.assertFalse(was_not_mutated)
  self.assertEqual(untouched, int_type)
def test_recurses_under_sequence(self):
  # The float conversion must reach the tensor nested inside the sequence;
  # the abstract-type conversion is expected to report no mutation.
  sequence_of_ints = computation_types.SequenceType([tf.int32])
  sequence_of_floats = computation_types.SequenceType([tf.float32])
  converted, was_mutated = type_transformations.transform_type_postorder(
      sequence_of_ints, _convert_tensor_to_float)
  untouched, was_not_mutated = type_transformations.transform_type_postorder(
      sequence_of_ints, _convert_abstract_type_to_tensor)
  self.assertTrue(was_mutated)
  self.assertEqual(converted, sequence_of_floats)
  self.assertFalse(was_not_mutated)
  self.assertEqual(untouched, sequence_of_ints)
def test_transforms_federated_type(self):
  # The member type of a CLIENTS-placed value is converted while the
  # placement itself is kept; the abstract-type conversion changes nothing.
  clients = placement_literals.CLIENTS
  federated_ints = computation_types.FederatedType(tf.int32, clients)
  federated_floats = computation_types.FederatedType(tf.float32, clients)
  converted, was_mutated = type_transformations.transform_type_postorder(
      federated_ints, _convert_tensor_to_float)
  untouched, was_not_mutated = type_transformations.transform_type_postorder(
      federated_ints, _convert_abstract_type_to_tensor)
  self.assertTrue(was_mutated)
  self.assertEqual(converted, federated_floats)
  self.assertFalse(was_not_mutated)
  self.assertEqual(untouched, federated_ints)
def test_transforms_named_tuple_type_preserving_tuple_container(self):
  # Both tensor elements of a dict-backed named tuple are converted to
  # float32, and the rebuilt tuple must still carry the `dict` container.
  orig_type = computation_types.NamedTupleTypeWithPyContainerType(
      [('a', tf.int32), ('b', tf.float64)], dict)
  expected_type = computation_types.NamedTupleTypeWithPyContainerType(
      [('a', tf.float32), ('b', tf.float32)], dict)
  result_type, mutated = type_transformations.transform_type_postorder(
      orig_type, _convert_tensor_to_float)
  noop_type, not_mutated = type_transformations.transform_type_postorder(
      orig_type, _convert_abstract_type_to_tensor)
  self.assertEqual(result_type, expected_type)
  self.assertEqual(noop_type, orig_type)
  self.assertTrue(mutated)
  self.assertFalse(not_mutated)
  # NOTE(review): type equality may compare only the named elements, so the
  # container preservation this test is named after is checked explicitly.
  self.assertIsInstance(
      result_type, computation_types.NamedTupleTypeWithPyContainerType)
  self.assertIs(
      computation_types.NamedTupleTypeWithPyContainerType.get_container_type(
          result_type), dict)
def test_recurses_under_named_tuple_type(self):
  # The conversion must descend through an outer tuple into the inner named
  # tuple's tensors; the abstract-type conversion changes nothing.
  nested_source = computation_types.to_type(
      [[('a', tf.int32), ('b', tf.float64)]])
  nested_target = computation_types.to_type(
      [[('a', tf.float32), ('b', tf.float32)]])
  converted, was_mutated = type_transformations.transform_type_postorder(
      nested_source, _convert_tensor_to_float)
  untouched, was_not_mutated = type_transformations.transform_type_postorder(
      nested_source, _convert_abstract_type_to_tensor)
  self.assertTrue(was_mutated)
  self.assertEqual(converted, nested_target)
  self.assertFalse(was_not_mutated)
  self.assertEqual(untouched, nested_source)
def test_raises_on_non_type_first_arg(self):
  # A raw `tf.DType` is not a `computation_types.Type`. A well-formed
  # transform is passed so that the first argument is the only invalid one;
  # with a `None` function the TypeError could come from either check.
  with self.assertRaises(TypeError):
    type_transformations.transform_type_postorder(
        tf.int32, lambda x: (x, False))
def test_raises_on_none_function(self):
  # A valid type paired with a missing transform function must be rejected.
  with self.assertRaises(TypeError):
    int_type = computation_types.to_type(tf.int32)
    missing_function = None
    type_transformations.transform_type_postorder(int_type, missing_function)
def test_raises_on_none_type(self):
  # A missing type signature must be rejected even with a callable transform.

  def identity(x):
    return x

  with self.assertRaises(TypeError):
    type_transformations.transform_type_postorder(None, identity)
def test_updates_mutated_bit_at_tuple(self):
  # A transform aimed at the tuple node (`_convert_tuple_to_tensor`) must set
  # the mutated flag.
  pair_type = computation_types.to_type([tf.int32, tf.float64])
  _, was_mutated = type_transformations.transform_type_postorder(
      pair_type, _convert_tuple_to_tensor)
  self.assertTrue(was_mutated)
def test_updates_mutated_bit_at_function(self):
  # A transform aimed at the function node (`_convert_function_to_tensor`)
  # must set the mutated flag.
  int_to_long = computation_types.FunctionType(tf.int32, tf.int64)
  _, was_mutated = type_transformations.transform_type_postorder(
      int_to_long, _convert_function_to_tensor)
  self.assertTrue(was_mutated)
def test_updates_mutated_bit_at_sequence(self):
  # A transform aimed at the sequence node (`_convert_sequence_to_tensor`)
  # must set the mutated flag.
  int_sequence = computation_types.SequenceType(tf.int32)
  _, was_mutated = type_transformations.transform_type_postorder(
      int_sequence, _convert_sequence_to_tensor)
  self.assertTrue(was_mutated)
def test_updates_mutated_bit_at_federated(self):
  # A transform aimed at the federated node (`_convert_federated_to_tensor`)
  # must set the mutated flag.
  clients_ints = computation_types.FederatedType(
      tf.int32, placement_literals.CLIENTS)
  _, was_mutated = type_transformations.transform_type_postorder(
      clients_ints, _convert_federated_to_tensor)
  self.assertTrue(was_mutated)
def create_constant(scalar_value, type_spec) -> pb.Computation:
  """Returns a tensorflow computation returning a constant `scalar_value`.

  The returned computation has the type signature `( -> T)`, where `T` is
  `type_spec`.

  `scalar_value` must be a scalar. If any tensor leaf of `type_spec` has an
  integer dtype, `scalar_value` must also be an integer (a float, for
  example, is rejected). `type_spec` must contain only named tuples and
  tensor types, which may be arbitrarily nested, and every tensor leaf must
  have a fully-defined shape.

  Args:
    scalar_value: A scalar value to place in all the tensor leaves of
      `type_spec`.
    type_spec: A type convertible to instance of `computation_types.Type` via
      `computation_types.to_type` and whose resulting type tree can only
      contain named tuples and tensors.

  Returns:
    A `pb.Computation` with no parameter. Its TensorFlow graph produces, for
    every tensor leaf of `type_spec`, a constant of that leaf's dtype and
    shape filled with `scalar_value`.

  Raises:
    TypeError: If `type_spec` contains anything other than named tuples and
      tensors, if `scalar_value` is not a scalar, or if `scalar_value` is not
      an integer while `type_spec` contains an integer tensor.
    ValueError: If a tensor leaf of `type_spec` does not have a fully-defined
      shape.
  """
  type_spec = computation_types.to_type(type_spec)
  if not type_utils.is_generic_op_compatible_type(type_spec):
    raise TypeError(
        'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
        'only nested tuples and tensors are permitted.'.format(type_spec))

  inferred_scalar_value_type = type_utils.infer_type(scalar_value)
  if (not isinstance(inferred_scalar_value_type, computation_types.TensorType)
      or inferred_scalar_value_type.shape != tf.TensorShape(())):
    raise TypeError(
        'Must pass a scalar value to `create_constant`; encountered '
        'a value {}'.format(scalar_value))

  tensor_dtypes_in_type_spec = []

  def _pack_dtypes(type_signature):
    """Appends dtype of `type_signature` to nonlocal variable."""
    if isinstance(type_signature, computation_types.TensorType):
      tensor_dtypes_in_type_spec.append(type_signature.dtype)
    return type_signature, False

  # The traversal runs only to collect leaf dtypes; `_pack_dtypes` never
  # changes the type, so the returned (type, mutated) pair is ignored.
  type_transformations.transform_type_postorder(type_spec, _pack_dtypes)

  if (any(dtype.is_integer for dtype in tensor_dtypes_in_type_spec) and
      not inferred_scalar_value_type.dtype.is_integer):
    raise TypeError(
        'Only integers can be used as scalar values if our desired constant '
        'type spec contains any integer tensors; passed scalar {} of dtype {} '
        'for type spec {}.'.format(scalar_value,
                                   inferred_scalar_value_type.dtype, type_spec))

  def _create_result_tensor(spec, value):
    """Packs `value` into a structure of constants mirroring `spec`."""
    if isinstance(spec, computation_types.TensorType):
      # Raises ValueError for partially-known shapes, so `tf.constant` always
      # receives a concrete shape to fill.
      spec.shape.assert_is_fully_defined()
      return tf.constant(value, dtype=spec.dtype, shape=spec.shape)
    # Named-tuple node: recurse into each element. Element names are not kept
    # here; the declared result type below is `type_spec` itself.
    return [
        _create_result_tensor(element_type, value)
        for _, element_type in anonymous_tuple.iter_elements(spec)
    ]

  with tf.Graph().as_default() as graph:
    result = _create_result_tensor(type_spec, scalar_value)
  _, result_binding = tensorflow_utils.capture_result_from_graph(
      result, graph)

  type_signature = computation_types.FunctionType(None, type_spec)
  tensorflow = pb.TensorFlow(
      graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
      parameter=None,
      result=result_binding)
  return pb.Computation(
      type=type_serialization.serialize_type(type_signature),
      tensorflow=tensorflow)