def test_raises_type_error_with_bad_type(self):
  """Checks that a float value cannot be packed into an `int32` constant."""
  float_value = 10.0
  int32_dtype = tf.int32
  with self.assertRaises(TypeError):
    tensorflow_computation_factory.create_constant(float_value, int32_dtype)
def test_raises_type_error_with_non_scalar_value(self):
  """Checks that a shape-`[1]` array is rejected for a `tf.int32` type."""
  array_value = np.zeros([1])
  scalar_dtype = tf.int32
  with self.assertRaises(TypeError):
    tensorflow_computation_factory.create_constant(array_value, scalar_dtype)
def create_dummy_computation_tensorflow_constant():
  """Builds a constant TensorFlow computation that yields `10.0`.

  Returns:
    A `(computation, type_signature)` pair where the type is `( -> float32)`.
  """
  float32_type = computation_types.TensorType(tf.float32)
  proto, function_type = tensorflow_computation_factory.create_constant(
      10.0, float32_type)
  return proto, function_type
def create_dummy_computation_tensorflow_constant():
  """Builds a constant TensorFlow computation that yields `10.0`.

  Returns:
    A `(computation, type_signature)` pair where the type is `( -> float32)`.
  """
  result_dtype = tf.float32
  proto = tensorflow_computation_factory.create_constant(10.0, result_dtype)
  # The factory returns only the proto here, so the function type is built
  # by hand: no parameter, returning `result_dtype`.
  return proto, computation_types.FunctionType(None, result_dtype)
def test_to_value_for_computations(self):
  """`to_value` wraps a concrete constant computation as a `( -> int32)` Value."""
  proto, _ = tensorflow_computation_factory.create_constant(
      10, computation_types.TensorType(tf.int32))
  concrete = computation_impl.ConcreteComputation(
      proto, context_stack_impl.context_stack)

  wrapped = value_impl.to_value(concrete, None)

  self.assertIsInstance(wrapped, value_impl.Value)
  compact_type = wrapped.type_signature.compact_representation()
  self.assertEqual(compact_type, '( -> int32)')
def test_intrinsic_construction_raises_outside_symbol_binding_context(
    self):
  """`federated_eval` raises `ContextError` under a `RuntimeErrorContext`."""
  int_type = computation_types.TensorType(tf.int32)
  proto, _ = tensorflow_computation_factory.create_constant(2, int_type)
  return_2 = computation_impl.ConcreteComputation(
      proto, context_stack_impl.context_stack)
  error_context = runtime_error_context.RuntimeErrorContext()
  with context_stack_impl.context_stack.install(error_context):
    with self.assertRaises(context_base.ContextError):
      intrinsics.federated_eval(return_2, placements.SERVER)
def test_returns_computation(self, value, type_signature, expected_result):
  """Checks the proto, type and result of a parameterized constant."""
  proto, _ = tensorflow_computation_factory.create_constant(
      value, type_signature)
  self.assertIsInstance(proto, pb.Computation)

  deserialized_type = type_serialization.deserialize_type(proto.type)
  computation_types.FunctionType(None, type_signature).check_assignable_from(
      deserialized_type)

  result = test_utils.run_tensorflow(proto)
  # List expectations are compared element-wise irrespective of order.
  compare = (
      self.assertCountEqual
      if isinstance(expected_result, list) else self.assertEqual)
  compare(result, expected_result)
def test_returns_computation_with_tuple_unnamed(self):
  """A scalar `value` is repeated into every element of an unnamed 3-tuple."""
  value = 10
  type_signature = computation_types.NamedTupleType([tf.int32] * 3)
  proto = tensorflow_computation_factory.create_constant(
      value, type_signature)
  self.assertIsInstance(proto, pb.Computation)
  actual_type = type_serialization.deserialize_type(proto.type)
  expected_type = computation_types.FunctionType(None, type_signature)
  self.assertEqual(actual_type, expected_type)
  expected_value = [value] * 3
  # The type check above pins the computation as parameterless, so it is run
  # without an argument; passing `expected_value` in would risk comparing the
  # expected value against itself.
  actual_value = test_utils.run_tensorflow(proto)
  self.assertCountEqual(actual_value, expected_value)
def test_returns_computation_with_tensor_float(self):
  """A scalar float `value` is broadcast into a `float32[3]` constant."""
  value = 10.0
  type_signature = computation_types.TensorType(tf.float32, [3])
  proto = tensorflow_computation_factory.create_constant(
      value, type_signature)
  self.assertIsInstance(proto, pb.Computation)
  actual_type = type_serialization.deserialize_type(proto.type)
  expected_type = computation_types.FunctionType(None, type_signature)
  self.assertEqual(actual_type, expected_type)
  expected_value = [value] * 3
  # The type check above pins the computation as parameterless, so it is run
  # without an argument; passing `expected_value` in would risk comparing the
  # expected value against itself.
  actual_value = test_utils.run_tensorflow(proto)
  self.assertCountEqual(actual_value, expected_value)
def test_invoke_returns_value_with_correct_type(self):
  """Invoking a parameterless `int32` constant yields a Value typed `int32`."""
  proto, _ = tensorflow_computation_factory.create_constant(
      10, computation_types.TensorType(tf.int32))
  stack = context_stack_impl.context_stack
  constant_fn = computation_impl.ConcreteComputation(proto, stack)
  ctx = federated_computation_context.FederatedComputationContext(stack)

  with stack.install(ctx):
    invoked = ctx.invoke(constant_fn, None)

  self.assertIsInstance(invoked, value_impl.Value)
  self.assertEqual(str(invoked.type_signature), 'int32')
def test_federated_select_server_val_must_be_server_placed(
    self, federated_select):
  """`federated_select` rejects a `server_val` that is not server-placed."""
  client_keys, max_key, _, select_fn = self.basic_federated_select_args()
  strings = tf.constant(['first', 'second', 'third'])
  string_type = computation_types.TensorType(dtype=tf.string, shape=[3])
  unplaced_proto, _ = tensorflow_computation_factory.create_constant(
      strings, string_type)
  # Invoking the constant computation presumably produces a non-federated
  # value, given its plain `TensorType`, which stands in for a bad server_val.
  unplaced_val = computation_impl.ConcreteComputation(
      unplaced_proto, context_stack_impl.context_stack)()
  with self.assertRaises(TypeError):
    federated_select(client_keys, max_key, unplaced_val, select_fn)
async def embed_tf_scalar_constant(executor, type_spec, value):
  """Embeds a constant `value` of TFF type `type_spec` in `executor`.

  `executor` is validated with `py_typecheck.check_type` against
  `executor_base.Executor` before anything is embedded.

  Args:
    executor: An instance of `tff.framework.Executor`.
    type_spec: An instance of `tff.Type`.
    value: A scalar value.

  Returns:
    An instance of `tff.framework.ExecutorValue` containing an embedded value.
  """
  py_typecheck.check_type(executor, executor_base.Executor)
  proto = tensorflow_computation_factory.create_constant(value, type_spec)
  # The factory returns only the proto, so the computation's type is recovered
  # from the serialized type stored in the proto itself.
  type_signature = type_serialization.deserialize_type(proto.type)
  result = await executor.create_value(proto, type_signature)
  # Invoke the embedded computation with no argument to produce the constant.
  return await executor.create_call(result)
async def embed_tf_constant(executor, type_spec, value):
  """Embeds a constant `value` of TFF type `type_spec` in `executor`.

  `executor` is validated with `py_typecheck.check_type` against
  `executor_base.Executor` before anything is embedded.

  Args:
    executor: An instance of `tff.framework.Executor`.
    type_spec: An instance of `tff.Type`.
    value: A value, must be a tensor or nested structure of tensors with the
      structure matching `type_spec`.

  Returns:
    An instance of `tff.framework.ExecutorValue` containing an embedded value.
  """
  py_typecheck.check_type(executor, executor_base.Executor)
  # The factory returns both the computation proto and its function type.
  proto, type_signature = tensorflow_computation_factory.create_constant(
      value, type_spec)
  result = await executor.create_value(proto, type_signature)
  # Invoke the embedded computation with no argument to produce the constant.
  return await executor.create_call(result)
def test_raises_type_error(self, value, type_signature):
  """Invalid `(value, type_signature)` combinations raise `TypeError`."""
  create = tensorflow_computation_factory.create_constant
  with self.assertRaises(TypeError):
    create(value, type_signature)
def create_dummy_computation_tensorflow_constant(value=10.0):
  """Builds a constant TensorFlow computation that yields `value`.

  The TFF type of the constant is inferred from `value`.

  Returns:
    A `(computation, type_signature)` pair with type `( -> T)`, where `T` is
    the type inferred for `value`.
  """
  inferred_type = type_utils.infer_type(value)
  proto = tensorflow_computation_factory.create_constant(value, inferred_type)
  function_type = computation_types.FunctionType(None, inferred_type)
  return proto, function_type