コード例 #1
0
 def test_build_dataset_reduce_fn(self, simulation, reduce_fn):
   """Sums the int32 range 0..9 through the selected dataset reduce function.

   Args:
     simulation: Passed to `dp_fedavg._build_dataset_reduce_fn` to select an
       implementation (presumably supplied by a parameterized decorator
       outside this view).
     reduce_fn: The function the builder is expected to return for
       `simulation`; checked by identity.
   """
   built_fn = dp_fedavg._build_dataset_reduce_fn(simulation)
   self.assertIs(built_fn, reduce_fn)

   def add(accumulator, element):
     return accumulator + element

   def zero():
     return 0

   int_range = tf.data.Dataset.range(10, output_type=tf.int32)
   result = built_fn(dataset=int_range, initial_state_fn=zero, reduce_fn=add)
   # 0 + 1 + ... + 9 == 45; integer arithmetic, so exact equality is safe.
   self.assertEqual(result, np.int32(45))
コード例 #2
0
 def test_build_dataset_reduce_fn_tuple(self, simulation, reduce_fn):
   """Reduces a float dataset into a (count, sum) tuple state.

   The dataset yields 0.1 * x for x in 0..9. The reduction increments a count
   and adds each element to a running sum that starts at 0.1, so the expected
   result is a count of 10 and a sum of 0.1 + 4.5 = 4.6.

   Args:
     simulation: Passed to `dp_fedavg._build_dataset_reduce_fn` to select an
       implementation (presumably supplied by a parameterized decorator
       outside this view).
     reduce_fn: The function the builder is expected to return for
       `simulation`; checked by identity.
   """
   dataset_reduce_fn = dp_fedavg._build_dataset_reduce_fn(simulation)
   self.assertIs(dataset_reduce_fn, reduce_fn)
   ds = tf.data.Dataset.range(
       10, output_type=tf.float32).map(lambda x: 0.1 * x)
   total_cnt, total_sum = dataset_reduce_fn(
       reduce_fn=lambda x, y: (x[0] + 1, x[1] + y),
       dataset=ds,
       initial_state_fn=lambda: (tf.constant(0), tf.constant(0.1)))
   # The count starts from the integer constant 0 and is incremented by 1, so
   # compare it against an integer rather than a float32.
   self.assertEqual(int(total_cnt), 10)
   # The sum is accumulated in float32 from inexact 0.1 multiples, so compare
   # with a tolerance instead of requiring a bit-exact float32 match.
   self.assertAlmostEqual(float(total_sum), 4.6, delta=1e-5)