Example #1
0
def _visit_type(type_signature, function):
    """Runs `function` on types reached by a postorder type traversal.

    The traversal itself is delegated to
    `type_transformations.transform_type_postorder`; this wrapper only adapts
    `function` to the `(type, mutated)` return shape that call is given here.

    Args:
      type_signature: The type handed to `transform_type_postorder`.
      function: A callable invoked with each type the traversal passes to the
        callback. Its return value is ignored.
    """

    def _call_and_pass_through(visited_type):
        # Run the visitor purely for its side effects, then report the type
        # back unchanged with the mutated flag cleared.
        function(visited_type)
        return visited_type, False

    type_transformations.transform_type_postorder(type_signature,
                                                  _call_and_pass_through)
Example #2
0
 def test_transforms_tensor(self):
     # An int32 tensor is expected to come back as float32 with the mutated
     # flag set, while the second transform is expected to be a no-op.
     transform = type_transformations.transform_type_postorder
     int_tensor = computation_types.to_type(tf.int32)
     float_tensor = computation_types.to_type(tf.float32)

     converted, was_converted = transform(int_tensor,
                                          _convert_tensor_to_float)
     untouched, was_touched = transform(int_tensor,
                                        _convert_abstract_type_to_tensor)

     self.assertEqual(converted, float_tensor)
     self.assertEqual(untouched, int_tensor)
     self.assertTrue(was_converted)
     self.assertFalse(was_touched)
Example #3
0
 def test_recurses_under_sequence(self):
     # The tensor is nested inside the sequence's element type, so the
     # conversion only shows up if the transform descends into the sequence.
     transform = type_transformations.transform_type_postorder
     int_sequence = computation_types.SequenceType([tf.int32])
     float_sequence = computation_types.SequenceType([tf.float32])

     converted, was_converted = transform(int_sequence,
                                          _convert_tensor_to_float)
     untouched, was_touched = transform(int_sequence,
                                        _convert_abstract_type_to_tensor)

     self.assertEqual(converted, float_sequence)
     self.assertEqual(untouched, int_sequence)
     self.assertTrue(was_converted)
     self.assertFalse(was_touched)
Example #4
0
 def test_transforms_federated_type(self):
     # The member of a CLIENTS-placed federated type is expected to change
     # from int32 to float32 while the placement stays the same.
     transform = type_transformations.transform_type_postorder
     clients = placement_literals.CLIENTS
     federated_int = computation_types.FederatedType(tf.int32, clients)
     federated_float = computation_types.FederatedType(tf.float32, clients)

     converted, was_converted = transform(federated_int,
                                          _convert_tensor_to_float)
     untouched, was_touched = transform(federated_int,
                                        _convert_abstract_type_to_tensor)

     self.assertEqual(converted, federated_float)
     self.assertEqual(untouched, federated_int)
     self.assertTrue(was_converted)
     self.assertFalse(was_touched)
Example #5
0
 def test_transforms_named_tuple_type_preserving_tuple_container(self):
     # Both fields are expected to become float32, and the expected type is
     # still built with `dict` as its Python container.
     transform = type_transformations.transform_type_postorder
     make_dict_tuple = computation_types.NamedTupleTypeWithPyContainerType
     mixed_fields = make_dict_tuple([('a', tf.int32), ('b', tf.float64)],
                                    dict)
     float_fields = make_dict_tuple([('a', tf.float32), ('b', tf.float32)],
                                    dict)

     converted, was_converted = transform(mixed_fields,
                                          _convert_tensor_to_float)
     untouched, was_touched = transform(mixed_fields,
                                        _convert_abstract_type_to_tensor)

     self.assertEqual(converted, float_fields)
     self.assertEqual(untouched, mixed_fields)
     self.assertTrue(was_converted)
     self.assertFalse(was_touched)
Example #6
0
 def test_recurses_under_named_tuple_type(self):
     # The named fields live one tuple level down, so the conversion only
     # shows up if the transform descends into the outer tuple.
     transform = type_transformations.transform_type_postorder
     nested_mixed = computation_types.to_type(
         [[('a', tf.int32), ('b', tf.float64)]])
     nested_float = computation_types.to_type(
         [[('a', tf.float32), ('b', tf.float32)]])

     converted, was_converted = transform(nested_mixed,
                                          _convert_tensor_to_float)
     untouched, was_touched = transform(nested_mixed,
                                        _convert_abstract_type_to_tensor)

     self.assertEqual(converted, nested_float)
     self.assertEqual(untouched, nested_mixed)
     self.assertTrue(was_converted)
     self.assertFalse(was_touched)
Example #7
0
 def test_raises_on_non_type_first_arg(self):
     # A raw `tf.int32` dtype is passed where a type is expected.
     # NOTE(review): the function argument is `None` as well, so this test
     # alone cannot tell which argument triggers the TypeError -- confirm
     # against the implementation.
     transform = type_transformations.transform_type_postorder
     with self.assertRaises(TypeError):
         transform(tf.int32, None)
Example #8
0
 def test_raises_on_none_function(self):
     # A converted tensor type paired with `None` in place of the transform
     # function is expected to be rejected with a TypeError.
     transform = type_transformations.transform_type_postorder
     with self.assertRaises(TypeError):
         tensor_type = computation_types.to_type(tf.int32)
         transform(tensor_type, None)
Example #9
0
 def test_raises_on_none_type(self):
     # `None` in place of the type is expected to be rejected with a
     # TypeError even though a callable is supplied.
     def _identity(value):
         return value

     with self.assertRaises(TypeError):
         type_transformations.transform_type_postorder(None, _identity)
Example #10
0
 def test_updates_mutated_bit_at_tuple(self):
     # Only the mutated flag is checked; the transformed type is discarded.
     transform = type_transformations.transform_type_postorder
     tuple_type = computation_types.to_type([tf.int32, tf.float64])
     _, was_mutated = transform(tuple_type, _convert_tuple_to_tensor)
     self.assertTrue(was_mutated)
Example #11
0
 def test_updates_mutated_bit_at_function(self):
     # Only the mutated flag is checked; the transformed type is discarded.
     transform = type_transformations.transform_type_postorder
     function_type = computation_types.FunctionType(tf.int32, tf.int64)
     _, was_mutated = transform(function_type, _convert_function_to_tensor)
     self.assertTrue(was_mutated)
Example #12
0
 def test_updates_mutated_bit_at_sequence(self):
     # Only the mutated flag is checked; the transformed type is discarded.
     transform = type_transformations.transform_type_postorder
     sequence_type = computation_types.SequenceType(tf.int32)
     _, was_mutated = transform(sequence_type, _convert_sequence_to_tensor)
     self.assertTrue(was_mutated)
Example #13
0
 def test_updates_mutated_bit_at_federated(self):
     # Only the mutated flag is checked; the transformed type is discarded.
     transform = type_transformations.transform_type_postorder
     federated_type = computation_types.FederatedType(
         tf.int32, placement_literals.CLIENTS)
     _, was_mutated = transform(federated_type, _convert_federated_to_tensor)
     self.assertTrue(was_mutated)
def create_constant(scalar_value, type_spec) -> pb.Computation:
    """Builds a no-argument TensorFlow computation yielding a constant.

    Every tensor leaf of `type_spec` is filled with `scalar_value`, and the
    resulting computation is given the type signature `( -> T)` where `T` is
    `type_spec`.

    Args:
      scalar_value: The scalar written into each tensor leaf of `type_spec`.
        It must have an integer dtype whenever any leaf of `type_spec` has an
        integer dtype.
      type_spec: A `computation_types.Type`, or a value that
        `computation_types.to_type` converts into one. It must pass
        `type_utils.is_generic_op_compatible_type`, i.e. consist only of
        (possibly nested) tuples and tensors. Each tensor leaf's shape is
        checked with `assert_is_fully_defined()` while the graph is built.

    Returns:
      A `pb.Computation` holding the serialized function type together with
      the packed graph definition and the binding of its result.

    Raises:
      TypeError: If `type_spec` is not made up solely of tuples and tensors,
        if `scalar_value` is not inferred to be a scalar tensor, or if a
        non-integer `scalar_value` is combined with an integer tensor leaf.
    """
    type_spec = computation_types.to_type(type_spec)
    if not type_utils.is_generic_op_compatible_type(type_spec):
        raise TypeError(
            'Type spec {} cannot be constructed as a TensorFlow constant in TFF; '
            ' only nested tuples and tensors are permitted.'.format(type_spec))

    def _is_not_scalar_tensor(inferred_type):
        """Reports whether `inferred_type` is anything but a rank-0 tensor."""
        if not isinstance(inferred_type, computation_types.TensorType):
            return True
        return inferred_type.shape != tf.TensorShape(())

    scalar_type = type_utils.infer_type(scalar_value)
    if _is_not_scalar_tensor(scalar_type):
        raise TypeError(
            'Must pass a scalar value to `create_tensorflow_constant`; encountered '
            'a value {}'.format(scalar_value))

    # Gather the dtype of every tensor met while walking `type_spec`; the
    # walk is used only for this side effect, so each node is returned
    # unchanged with the mutated flag cleared.
    leaf_dtypes = []

    def _record_dtype(node):
        if isinstance(node, computation_types.TensorType):
            leaf_dtypes.append(node.dtype)
        return node, False

    type_transformations.transform_type_postorder(type_spec, _record_dtype)

    has_integer_leaf = any(dtype.is_integer for dtype in leaf_dtypes)
    if has_integer_leaf and not scalar_type.dtype.is_integer:
        raise TypeError(
            'Only integers can be used as scalar values if our desired constant '
            'type spec contains any integer tensors; passed scalar {} of dtype {} '
            'for type spec {}.'.format(scalar_value,
                                       scalar_type.dtype,
                                       type_spec))

    def _fill(spec):
        """Mirrors `spec` as nested lists of constants set to `scalar_value`."""
        if isinstance(spec, computation_types.TensorType):
            spec.shape.assert_is_fully_defined()
            return tf.constant(scalar_value,
                               dtype=spec.dtype,
                               shape=spec.shape)
        # Element names are dropped here; only the element types are used.
        return [
            _fill(element_type)
            for _, element_type in anonymous_tuple.iter_elements(spec)
        ]

    graph = tf.Graph()
    with graph.as_default():
        constants = _fill(type_spec)
        _, result_binding = tensorflow_utils.capture_result_from_graph(
            constants, graph)

    function_type = computation_types.FunctionType(None, type_spec)
    packed_graph_def = serialization_utils.pack_graph_def(
        graph.as_graph_def())
    tensorflow_proto = pb.TensorFlow(graph_def=packed_graph_def,
                                     parameter=None,
                                     result=result_binding)
    return pb.Computation(
        type=type_serialization.serialize_type(function_type),
        tensorflow=tensorflow_proto)