Esempio n. 1
0
 def test_ill_add(self):
     """Check that invalid ``add_to_collection`` calls raise.

     On a collector built with ``n_devices=2`` this exercises:

     * ``name=None`` and ``var=None`` -> ``AssertionError``
     * ``name=None`` with a tensor ``var`` -> ``AssertionError``
     * three consecutive averaged additions under the same key
       -> ``ValueError``
     """
     collector = OutputsCollector(n_devices=2)
     foo = tf.zeros([2, 2])
     bar = tf.zeros([42])
     # An empty regex would match any message, so plain assertRaises
     # expresses the same check without relying on the assertRaisesRegexp
     # alias (deprecated since Python 3.2 and removed in Python 3.12).
     with self.assertRaises(AssertionError):
         collector.add_to_collection(name=None, var=None)
     with self.assertRaises(AssertionError):
         collector.add_to_collection(name=None, var=bar)
     with self.assertRaises(ValueError):
         # NOTE(review): presumably the repetition is meant to exceed the
         # two configured devices so that a later call raises -- confirm
         # against OutputsCollector. Note that the tensor `foo` (not a
         # string) is passed as `name`.
         for _ in range(3):
             collector.add_to_collection(name=foo, var=bar,
                                         average_over_devices=True)
Esempio n. 2
0
 def test_netout_single_device(self):
     """Single-device outputs are stored under the names they were given.

     Adds an ``image`` and a ``foo`` tensor to the ``NETORK_OUTPUT``
     collection without averaging and asserts that
     ``collector.output_vars`` maps each name to the tensor added.
     """
     device_count = 1
     collector = OutputsCollector(n_devices=device_count)
     for worker_id in range(device_count):
         with tf.name_scope('worker_%d' % worker_id):
             # Both tensors are created before either is registered.
             named_tensors = (
                 ('image', tf.ones([2, 32, 32, 32, 1])),
                 ('foo', tf.zeros([2, 2])),
             )
             for tensor_name, tensor in named_tensors:
                 collector.add_to_collection(
                     name=tensor_name,
                     var=tensor,
                     collection=NETORK_OUTPUT,
                     average_over_devices=False)
     # Compare against the tensors created in the last (only) iteration.
     expected = dict(named_tensors)
     self.assertDictEqual(collector.output_vars, expected)
Esempio n. 3
0
 def test_netout_mutiple_device(self):
     """Collect network outputs from four workers and finalise them.

     ``image`` and ``foo`` are added without averaging and ``bar`` with
     averaging, once per worker. The test asserts the resulting key set
     of ``collector.output_vars``, that ``bar`` holds one entry per
     device, and that ``bar`` is a ``tf.Tensor`` after
     ``finalise_output_op()``.
     """
     device_count = 4
     collector = OutputsCollector(n_devices=device_count)
     for worker_id in range(device_count):
         with tf.name_scope('worker_%d' % worker_id):
             # Each entry: (name, tensor, average_over_devices).
             entries = (
                 ('image', tf.ones([2, 32, 32, 32, 1]), False),
                 ('foo', tf.zeros([2, 2]), False),
                 ('bar', tf.zeros([42]), True),
             )
             for entry_name, tensor, averaged in entries:
                 collector.add_to_collection(
                     name=entry_name,
                     var=tensor,
                     collection=NETORK_OUTPUT,
                     average_over_devices=averaged)

     expected_names = {
         'image', 'image_1', 'image_2', 'image_3',
         'foo', 'foo_1', 'foo_2', 'foo_3',
         'bar',
     }
     self.assertEqual(set(collector.output_vars), expected_names)
     self.assertEqual(len(collector.output_vars['bar']), device_count)

     # output_vars is read again after finalising, not cached beforehand.
     collector.finalise_output_op()
     self.assertIsInstance(collector.output_vars['bar'], tf.Tensor)
Esempio n. 4
0
 def test_add_to_single_device(self):
     """Variables added on one device land in the expected collections.

     ``image`` and ``foo`` are added without a ``collection`` argument and
     ``bar`` with ``collection=NETWORK_OUTPUT``. The test asserts that the
     ``CONSOLE`` collection holds ``image`` and ``foo`` and that the
     ``NETWORK_OUTPUT`` collection holds ``bar``.
     """
     device_count = 1
     collector = OutputsCollector(n_devices=device_count)
     for worker_id in range(device_count):
         with tf.name_scope('worker_%d' % worker_id):
             image = tf.ones([2, 32, 32, 32, 1])
             foo = tf.zeros([2, 2])
             bar = tf.zeros([42])
             # No explicit collection for these two entries.
             for var_name, tensor in (('image', image), ('foo', foo)):
                 collector.add_to_collection(
                     name=var_name,
                     var=tensor,
                     average_over_devices=False)
             collector.add_to_collection(
                 name='bar',
                 var=bar,
                 collection=NETWORK_OUTPUT,
                 average_over_devices=False)

     console_vars = collector.variables(collection=CONSOLE)
     self.assertDictEqual(console_vars, {'image': image, 'foo': foo})
     network_vars = collector.variables(collection=NETWORK_OUTPUT)
     self.assertDictEqual(network_vars, {'bar': bar})