def test_multiply_by_scalar_with_float_and_float(self):
  """multiply_by_scalar scales a float32 scalar by a Python float."""
  halved = reference_executor.multiply_by_scalar(
      reference_executor.ComputedValue(10.0, tf.float32), 0.5)
  self.assertEqual(halved.value, 5.0)
  # The product is computed on a float32 operand, so it is only compared to
  # three decimal places rather than exactly.
  third = reference_executor.multiply_by_scalar(
      reference_executor.ComputedValue(np.float32(1.0), tf.float32),
      0.333333333333)
  self.assertAlmostEqual(third.value, 0.3333333, places=3)
def test_stamp_computed_value_into_graph_with_undefined_tensor_dims(self):
  """A tensor typed with an undefined dimension round-trips through a graph.

  The value is typed as int32 with shape [None], stamped into a fresh graph,
  fetched back through a session, and compared element-wise with the data it
  was built from.
  """
  v_type = computation_types.TensorType(tf.int32, [None])
  v_value = np.array([1, 2, 3], dtype=np.int32)
  v = reference_executor.ComputedValue(v_value, v_type)
  with tf.Graph().as_default() as graph:
    stamped_v = reference_executor.stamp_computed_value_into_graph(v, graph)
    with tf.Session(graph=graph) as sess:
      v_result = graph_utils.fetch_value_in_session(sess, stamped_v)
  # assert_array_equal reports mismatching shapes/elements on failure, whereas
  # assertTrue(np.array_equal(...)) only reports "False is not true".
  np.testing.assert_array_equal(v_result, np.array([1, 2, 3]))
def test_fit_argument(self):
  """fit_argument expands an all-equal CLIENTS member to one value per client."""
  source_value = anonymous_tuple.AnonymousTuple([('A', 10)])
  source_type = [('A', type_constructors.at_clients(tf.int32, all_equal=True))]
  original = reference_executor.ComputedValue(source_value, source_type)
  target_type = [('A', type_constructors.at_clients(tf.int32))]
  three_clients = reference_executor.ComputationContext(
      cardinalities={placements.CLIENTS: 3})
  fitted = reference_executor.fit_argument(original, target_type,
                                           three_clients)
  self.assertEqual(str(fitted.type_signature), '<A={int32}@CLIENTS>')
  self.assertEqual(fitted.value.A, [10, 10, 10])
def test_get_cardinalities_failure(self):
  """get_cardinalities raises ValueError on inconsistent CLIENTS sizes.

  Member 'A' carries three client values and member 'B' carries two, both
  placed at CLIENTS; the test expects get_cardinalities to reject this.
  """
  # Build the fixture outside assertRaises so that a ValueError raised while
  # constructing it cannot be mistaken for the failure under test.
  mismatched = reference_executor.ComputedValue(
      anonymous_tuple.AnonymousTuple([('A', [1, 2, 3]), ('B', [1, 2])]),
      [('A', computation_types.FederatedType(tf.int32, placements.CLIENTS)),
       ('B', computation_types.FederatedType(tf.int32, placements.CLIENTS))])
  with self.assertRaises(ValueError):
    reference_executor.get_cardinalities(mismatched)
def test_multiply_by_scalar_with_tuple_and_float(self):
  """multiply_by_scalar scales every float leaf of a nested named tuple."""
  nested = anonymous_tuple.AnonymousTuple([
      ('A', 10.0),
      ('B', anonymous_tuple.AnonymousTuple([('C', 20.0)])),
  ])
  nested_type = [('A', tf.float32), ('B', [('C', tf.float32)])]
  scaled = reference_executor.multiply_by_scalar(
      reference_executor.ComputedValue(nested, nested_type), 0.5)
  self.assertEqual(str(scaled.value), '<A=5.0,B=<C=10.0>>')
def test_computation_context_resolve_reference(self):
  """Names resolve through the parent chain, and the nearest binding wins."""
  root = reference_executor.ComputationContext()
  with_foo = reference_executor.ComputationContext(
      root, {'foo': reference_executor.ComputedValue(10, tf.int32)})
  with_bar = reference_executor.ComputationContext(
      with_foo, {'bar': reference_executor.ComputedValue(11, tf.int32)})
  no_bindings = reference_executor.ComputationContext(with_bar)
  shadowing = reference_executor.ComputationContext(
      no_bindings, {'foo': reference_executor.ComputedValue(12, tf.int32)})
  chain = [root, with_foo, with_bar, no_bindings, shadowing]
  # Expected resolved value for each context in `chain`; None means the name
  # is unbound there and resolve_reference must raise ValueError.
  expectations = [
      ('foo', [None, 10, 10, 10, 12]),
      ('bar', [None, None, 11, 11, 11]),
  ]
  for name, expected_per_context in expectations:
    for context, expected in zip(chain, expected_per_context):
      if expected is None:
        self.assertRaises(ValueError, context.resolve_reference, name)
      else:
        self.assertEqual(context.resolve_reference(name).value, expected)
def test_stamp_computed_value_into_graph_with_tuples_of_tensors(self):
  """A nested named tuple of tensors round-trips through a graph."""
  raw_value = anonymous_tuple.AnonymousTuple([
      ('x', 10),
      ('y', anonymous_tuple.AnonymousTuple([('z', 0.6)])),
  ])
  value_type = [('x', tf.int32), ('y', [('z', tf.float32)])]
  computed = reference_executor.ComputedValue(
      reference_executor.to_representation_for_type(raw_value, value_type),
      value_type)
  with tf.Graph().as_default() as graph:
    stamped = reference_executor.stamp_computed_value_into_graph(
        computed, graph)
    with tf.Session(graph=graph) as sess:
      fetched = graph_utils.fetch_value_in_session(sess, stamped)
  self.assertEqual(str(fetched), '<x=10,y=<z=0.6>>')
def test_get_cardinalities_success(self):
  """get_cardinalities infers one CLIENTS size shared by all nested members."""
  # Every CLIENTS-placed member below holds three client values.
  nested_value = anonymous_tuple.AnonymousTuple([
      ('A', [1, 2, 3]),
      ('B', anonymous_tuple.AnonymousTuple([
          ('C', [[1, 2], [3, 4], [5, 6]]),
          ('D', [True, False, True]),
      ])),
  ])
  nested_type = [
      ('A', computation_types.FederatedType(tf.int32, placements.CLIENTS)),
      ('B', [
          ('C', computation_types.FederatedType(
              computation_types.SequenceType(tf.int32), placements.CLIENTS)),
          ('D', computation_types.FederatedType(tf.bool, placements.CLIENTS)),
      ]),
  ]
  cardinalities = reference_executor.get_cardinalities(
      reference_executor.ComputedValue(nested_value, nested_type))
  self.assertDictEqual(cardinalities, {placements.CLIENTS: 3})
def foo(x):
  """Wraps str(x.value) in a new tf.string ComputedValue.

  Asserts that `x` is a ComputedValue first.
  """
  # NOTE(review): `self` is a free variable here; it is presumably captured
  # from an enclosing test method that is not visible in this chunk.
  self.assertIsInstance(x, reference_executor.ComputedValue)
  as_text = str(x.value)
  return reference_executor.ComputedValue(as_text, tf.string)
def test_computed_value(self):
  """ComputedValue exposes the wrapped value and its type signature."""
  computed = reference_executor.ComputedValue(10, tf.int32)
  signature_text = str(computed.type_signature)
  self.assertEqual(signature_text, 'int32')
  self.assertEqual(computed.value, 10)