コード例 #1
0
 def test_multiply_by_scalar_with_float_and_float(self):
   """Multiplies float32 `ComputedValue`s by Python float scalars."""
   # 10.0 * 0.5 is exactly representable, so an exact comparison is used.
   halved = reference_executor.multiply_by_scalar(
       reference_executor.ComputedValue(10.0, tf.float32), 0.5)
   self.assertEqual(halved.value, 5.0)

   # 1/3 is not exactly representable in float32; compare to 3 places only.
   third = reference_executor.multiply_by_scalar(
       reference_executor.ComputedValue(np.float32(1.0), tf.float32),
       0.333333333333)
   self.assertAlmostEqual(third.value, 0.3333333, places=3)
コード例 #2
0
 def test_stamp_computed_value_into_graph_with_undefined_tensor_dims(self):
   """Stamps a tensor whose declared shape has an unknown dimension.

   The int32 value is typed with shape `[None]` and stamped into a fresh
   graph. It is then fetched back in a session and compared with the
   expected array.
   """
   v_type = computation_types.TensorType(tf.int32, [None])
   v_value = np.array([1, 2, 3], dtype=np.int32)
   v = reference_executor.ComputedValue(v_value, v_type)
   with tf.Graph().as_default() as graph:
     stamped_v = reference_executor.stamp_computed_value_into_graph(v, graph)
     with tf.Session(graph=graph) as sess:
       v_result = graph_utils.fetch_value_in_session(sess, stamped_v)
   # On failure, `assert_array_equal` reports the mismatching shapes and
   # elements. `assertTrue(np.array_equal(...))` would only report that
   # `False` is not true.
   np.testing.assert_array_equal(v_result, np.array([1, 2, 3]))
コード例 #3
0
 def test_fit_argument(self):
   """Fits an all-equal CLIENTS argument to a non-all-equal CLIENTS type.

   The context has 3 clients, so the single value 10 in member `A` is
   expected to be replicated into one entry per client.
   """
   original_arg = reference_executor.ComputedValue(
       anonymous_tuple.AnonymousTuple([('A', 10)]),
       [('A', type_constructors.at_clients(tf.int32, all_equal=True))])
   per_client_type = [('A', type_constructors.at_clients(tf.int32))]
   three_clients = reference_executor.ComputationContext(
       cardinalities={placements.CLIENTS: 3})

   fitted = reference_executor.fit_argument(
       original_arg, per_client_type, three_clients)

   self.assertEqual(str(fitted.type_signature), '<A={int32}@CLIENTS>')
   self.assertEqual(fitted.value.A, [10, 10, 10])
コード例 #4
0
 def test_get_cardinalities_failure(self):
   """`get_cardinalities` raises when CLIENTS members disagree in length.

   Member `A` holds 3 client entries while `B` holds 2, and a `ValueError`
   is expected.
   """
   clients_int32 = computation_types.FederatedType(
       tf.int32, placements.CLIENTS)
   # Build the input outside `assertRaises` so that only `get_cardinalities`
   # can produce the expected `ValueError`. Otherwise an error raised while
   # constructing the input would make the test pass spuriously.
   inconsistent_value = reference_executor.ComputedValue(
       anonymous_tuple.AnonymousTuple([('A', [1, 2, 3]), ('B', [1, 2])]),
       [('A', clients_int32), ('B', clients_int32)])
   with self.assertRaises(ValueError):
     reference_executor.get_cardinalities(inconsistent_value)
コード例 #5
0
 def test_multiply_by_scalar_with_tuple_and_float(self):
   """Scalar multiplication applies to every (nested) named tuple member."""
   nested_value = anonymous_tuple.AnonymousTuple([
       ('A', 10.0),
       ('B', anonymous_tuple.AnonymousTuple([('C', 20.0)])),
   ])
   nested_type = [('A', tf.float32), ('B', [('C', tf.float32)])]

   product = reference_executor.multiply_by_scalar(
       reference_executor.ComputedValue(nested_value, nested_type), 0.5)

   self.assertEqual(str(product.value), '<A=5.0,B=<C=10.0>>')
コード例 #6
0
 def test_computation_context_resolve_reference(self):
   """Names resolve through the chain of parent contexts.

   A context that does not bind a name defers to its parent. The binding
   of 'foo' in c5 shadows the one in c2. A name that is not bound anywhere
   up the chain raises `ValueError`.
   """
   c1 = reference_executor.ComputationContext()
   c2 = reference_executor.ComputationContext(
       c1, {'foo': reference_executor.ComputedValue(10, tf.int32)})
   c3 = reference_executor.ComputationContext(
       c2, {'bar': reference_executor.ComputedValue(11, tf.int32)})
   c4 = reference_executor.ComputationContext(c3)
   c5 = reference_executor.ComputationContext(
       c4, {'foo': reference_executor.ComputedValue(12, tf.int32)})

   # For each name, list the contexts where it must fail to resolve, then
   # the (context, expected value) pairs where it must resolve.
   expectations = (
       ('foo', (c1,), ((c2, 10), (c3, 10), (c4, 10), (c5, 12))),
       ('bar', (c1, c2), ((c3, 11), (c4, 11), (c5, 11))),
   )
   for name, unresolvable_in, resolvable_in in expectations:
     for context in unresolvable_in:
       self.assertRaises(ValueError, context.resolve_reference, name)
     for context, expected in resolvable_in:
       self.assertEqual(context.resolve_reference(name).value, expected)
コード例 #7
0
 def test_stamp_computed_value_into_graph_with_tuples_of_tensors(self):
   """Round-trips a nested named tuple of tensors through a TF graph."""
   raw_value = anonymous_tuple.AnonymousTuple([
       ('x', 10),
       ('y', anonymous_tuple.AnonymousTuple([('z', 0.6)])),
   ])
   tuple_type = [('x', tf.int32), ('y', [('z', tf.float32)])]
   computed = reference_executor.ComputedValue(
       reference_executor.to_representation_for_type(raw_value, tuple_type),
       tuple_type)

   with tf.Graph().as_default() as graph:
     stamped = reference_executor.stamp_computed_value_into_graph(
         computed, graph)
     with tf.Session(graph=graph) as sess:
       fetched = graph_utils.fetch_value_in_session(sess, stamped)

   self.assertEqual(str(fetched), '<x=10,y=<z=0.6>>')
コード例 #8
0
 def test_get_cardinalities_success(self):
   """Infers a single CLIENTS cardinality from a nested federated tuple.

   Every CLIENTS-placed member, including those nested under `B`, holds
   three per-client entries. The inferred cardinality is therefore 3.
   """

   def federated_at_clients(member_type):
     # Shorthand for the CLIENTS-placed federated type used by every member.
     return computation_types.FederatedType(member_type, placements.CLIENTS)

   value = anonymous_tuple.AnonymousTuple([
       ('A', [1, 2, 3]),
       ('B', anonymous_tuple.AnonymousTuple([
           ('C', [[1, 2], [3, 4], [5, 6]]),
           ('D', [True, False, True]),
       ])),
   ])
   value_type = [
       ('A', federated_at_clients(tf.int32)),
       ('B', [
           ('C', federated_at_clients(
               computation_types.SequenceType(tf.int32))),
           ('D', federated_at_clients(tf.bool)),
       ]),
   ]

   cardinalities = reference_executor.get_cardinalities(
       reference_executor.ComputedValue(value, value_type))

   self.assertDictEqual(cardinalities, {placements.CLIENTS: 3})
コード例 #9
0
 def foo(x):
   """Stand-in computation: returns `str(x.value)` as a `tf.string` value.

   Asserts that its argument is a `ComputedValue`.
   NOTE(review): `self` is a free variable here, so this function is
   presumably defined inside a test method of the enclosing test case.
   """
   self.assertIsInstance(x, reference_executor.ComputedValue)
   as_text = str(x.value)
   return reference_executor.ComputedValue(as_text, tf.string)
コード例 #10
0
 def test_computed_value(self):
   """A `ComputedValue` exposes its type signature and the wrapped value."""
   computed = reference_executor.ComputedValue(10, tf.int32)
   self.assertEqual(str(computed.type_signature), 'int32')
   self.assertEqual(computed.value, 10)