Ejemplo n.º 1
0
 def initialise_sampler(self):
     """Build ``self.sampler``: a list holding one list of per-reader samplers.

     * Training, and ``'encode'`` / ``'encode-decode'`` inference: one
       ``ResizeSampler`` per reader, one window per image. Only training
       shuffles the buffer.
     * ``'linear_interpolation'`` inference: one ``LinearInterpolateSampler``
       per reader, using ``self.autoencoder_param.n_interpolations``.
     * Any other inference type: nothing is created and ``self.sampler``
       stays an empty list.

     Returns None in every case.
     """
     self.sampler = []
     training = self.is_training
     # ``_infer_type`` is only consulted when not training (short-circuit).
     if training or self._infer_type in ('encode', 'encode-decode'):
         samplers = [
             ResizeSampler(
                 reader=reader,
                 data_param=self.data_param,
                 batch_size=self.net_param.batch_size,
                 windows_per_image=1,
                 # Shuffle only while training; inference keeps order.
                 shuffle_buffer=bool(training),
                 queue_length=self.net_param.queue_length)
             for reader in self.readers
         ]
     elif self._infer_type == 'linear_interpolation':
         samplers = [
             LinearInterpolateSampler(
                 reader=reader,
                 data_param=self.data_param,
                 batch_size=self.net_param.batch_size,
                 n_interpolations=self.autoencoder_param.n_interpolations,
                 queue_length=self.net_param.queue_length)
             for reader in self.readers
         ]
     else:
         # No sampler is set up here for other inference types.
         return
     self.sampler.append(samplers)
 def test_init(self):
     """Pop one batch from a LinearInterpolateSampler and check its shape.

     NOTE(review): another method named ``test_init`` appears in this file;
     if both are in the same class only the later definition runs --
     consider renaming one of them.
     """
     sampler = LinearInterpolateSampler(reader=get_3d_reader(),
                                        data_param=MULTI_MOD_DATA,
                                        batch_size=1,
                                        n_interpolations=8,
                                        queue_length=1)
     try:
         with self.test_session() as sess:
             coordinator = tf.train.Coordinator()
             sampler.run_threads(sess, coordinator, num_threads=2)
             out = sess.run(sampler.pop_batch_op())
             # Shapes are integer tuples: compare exactly rather than with
             # a floating-point tolerance.
             self.assertEqual(tuple(out['image'].shape),
                              (1, 256, 168, 256, 2))
     finally:
         # Release the sampler even when the session work or the assertion
         # raises; previously a failure skipped close_all(). (Presumably
         # this stops the threads started by run_threads -- confirm.)
         sampler.close_all()
 def test_init(self):
     """Pop one batch from a LinearInterpolateSampler and check its shape.

     NOTE(review): another method named ``test_init`` appears in this file;
     if both are in the same class only the later definition runs --
     consider renaming one of them.
     """
     sampler = LinearInterpolateSampler(
         reader=get_3d_reader(),
         data_param=MULTI_MOD_DATA,
         batch_size=1,
         n_interpolations=8,
         queue_length=1)
     try:
         with self.test_session() as sess:
             coordinator = tf.train.Coordinator()
             sampler.run_threads(sess, coordinator, num_threads=2)
             out = sess.run(sampler.pop_batch_op())
             # Shapes are integer tuples: compare exactly rather than with
             # a floating-point tolerance.
             self.assertEqual(tuple(out['image'].shape),
                              (1, 256, 168, 256, 2))
     finally:
         # Release the sampler even when the session work or the assertion
         # raises; previously a failure skipped close_all(). (Presumably
         # this stops the threads started by run_threads -- confirm.)
         sampler.close_all()