def test_computation_context_get_cardinality(self):
  """Checks cardinality lookup on a context and on a child of that context."""
  parent = reference_executor.ComputationContext(
      None, None, {placements.CLIENTS: 10})
  self.assertEqual(parent.get_cardinality(placements.CLIENTS), 10)
  # Only CLIENTS was given a cardinality, so asking for SERVER must fail.
  self.assertRaises(ValueError, parent.get_cardinality, placements.SERVER)
  # A context built only from a parent still reports the parent's CLIENTS
  # cardinality.
  child = reference_executor.ComputationContext(parent)
  self.assertEqual(child.get_cardinality(placements.CLIENTS), 10)
def test_fit_argument(self):
  """Fits an all-equal CLIENTS-placed int to a non-all-equal CLIENTS type.

  With a CLIENTS cardinality of 3, the fitted value is expected to carry the
  non-all-equal type and hold one copy of the original int per client.
  """
  all_equal_at_clients = type_constructors.at_clients(tf.int32, all_equal=True)
  source_value = reference_executor.ComputedValue(
      anonymous_tuple.AnonymousTuple([('A', 10)]),
      [('A', all_equal_at_clients)])
  target_type = [('A', type_constructors.at_clients(tf.int32))]
  three_clients = reference_executor.ComputationContext(
      cardinalities={placements.CLIENTS: 3})

  fitted = reference_executor.fit_argument(
      source_value, target_type, three_clients)

  self.assertEqual(str(fitted.type_signature), '<A={int32}@CLIENTS>')
  self.assertEqual(fitted.value.A, [10, 10, 10])
def test_computation_context_resolve_reference(self):
  """Checks name resolution along a five-deep chain of nested contexts.

  The chain is: an empty root; a child binding 'foo'; a grandchild binding
  'bar'; a context that only forwards to its parent; and an innermost
  context that rebinds 'foo'. The expectation table below records, for each
  name and each context, either the resolved value or None when resolution
  must raise ValueError.
  """

  def int_value(v):
    return reference_executor.ComputedValue(v, tf.int32)

  root = reference_executor.ComputationContext()
  binds_foo = reference_executor.ComputationContext(
      root, {'foo': int_value(10)})
  binds_bar = reference_executor.ComputationContext(
      binds_foo, {'bar': int_value(11)})
  forwards = reference_executor.ComputationContext(binds_bar)
  rebinds_foo = reference_executor.ComputationContext(
      forwards, {'foo': int_value(12)})

  chain = (root, binds_foo, binds_bar, forwards, rebinds_foo)
  # Checked in this order: every context for 'foo', then every one for 'bar'.
  expectations = (
      ('foo', (None, 10, 10, 10, 12)),
      ('bar', (None, None, 11, 11, 11)),
  )
  for name, per_context in expectations:
    for context, want in zip(chain, per_context):
      if want is None:
        self.assertRaises(ValueError, context.resolve_reference, name)
      else:
        self.assertEqual(context.resolve_reference(name).value, want)