コード例 #1
0
 def test_computation_context_get_cardinality(self):
   """Cardinalities resolve locally and are inherited by child contexts."""
   cardinalities = {placements.CLIENTS: 10}
   parent = reference_executor.ComputationContext(None, None, cardinalities)
   self.assertEqual(parent.get_cardinality(placements.CLIENTS), 10)
   # Only CLIENTS was registered above, so a SERVER lookup must be rejected.
   self.assertRaises(ValueError, parent.get_cardinality, placements.SERVER)
   # A child created with no cardinalities of its own sees the parent's.
   child = reference_executor.ComputationContext(parent)
   self.assertEqual(child.get_cardinality(placements.CLIENTS), 10)
コード例 #2
0
 def test_fit_argument(self):
   """An all-equal CLIENTS value is fitted to a per-client CLIENTS type."""
   source_value = anonymous_tuple.AnonymousTuple([('A', 10)])
   all_equal_type = type_constructors.at_clients(tf.int32, all_equal=True)
   old_arg = reference_executor.ComputedValue(source_value,
                                              [('A', all_equal_type)])
   per_client_type = type_constructors.at_clients(tf.int32)
   # Three clients, so the single value 10 should appear once per client.
   context = reference_executor.ComputationContext(
       cardinalities={placements.CLIENTS: 3})
   new_arg = reference_executor.fit_argument(
       old_arg, [('A', per_client_type)], context)
   self.assertEqual(str(new_arg.type_signature), '<A={int32}@CLIENTS>')
   self.assertEqual(new_arg.value.A, [10] * 3)
コード例 #3
0
 def test_computation_context_resolve_reference(self):
   """Lookups walk up the parent chain; the nearest binding of a name wins."""
   root = reference_executor.ComputationContext()
   with_foo = reference_executor.ComputationContext(
       root, {'foo': reference_executor.ComputedValue(10, tf.int32)})
   with_bar = reference_executor.ComputationContext(
       with_foo, {'bar': reference_executor.ComputedValue(11, tf.int32)})
   passthrough = reference_executor.ComputationContext(with_bar)
   shadowing = reference_executor.ComputationContext(
       passthrough, {'foo': reference_executor.ComputedValue(12, tf.int32)})
   chain = [root, with_foo, with_bar, passthrough, shadowing]
   # Expected resolved value for each context in `chain`, in the same order;
   # None marks a context in which the name is unbound (ValueError expected).
   expectations = (
       ('foo', [None, 10, 10, 10, 12]),
       ('bar', [None, None, 11, 11, 11]),
   )
   for name, expected_values in expectations:
     for context, expected in zip(chain, expected_values):
       if expected is None:
         self.assertRaises(ValueError, context.resolve_reference, name)
       else:
         self.assertEqual(context.resolve_reference(name).value, expected)