def _run_sampler_threads(self, session=None):
    """
    Get samplers from the application and try to run the sampler threads.

    Note: override the app.get_sampler() method to return None
    in order to bypass this step.

    :param session: TF session used to fill tf.placeholders with sampled data
    :return:
    """
    if session is None:
        return
    if self._coord is None:
        return
    if self.num_threads <= 0:
        return
    try:
        samplers = self.app.get_sampler()
        for sampler in traverse_nested(samplers):
            if sampler is None:
                continue
            sampler.run_threads(session, self._coord, self.num_threads)
        tf.logging.info('Filling queues (this can take a few minutes)')
    except (NameError, TypeError, AttributeError, IndexError):
        tf.logging.fatal(
            "samplers not running, pop_batch_op operations are blocked.")
        raise
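# --- illustrative example (not part of the engine) -------------------------
# _run_sampler_threads iterates over a possibly nested collection of samplers
# via traverse_nested. The sketch below is a minimal, hypothetical flattener
# with the behaviour the loop above relies on; the name and implementation are
# assumptions for illustration only (the real helper lives in the utilities
# module, see util_common.traverse_nested used elsewhere in this file).
def _traverse_nested_example(items):
    """Yield leaf elements from arbitrarily nested lists/tuples."""
    if isinstance(items, (list, tuple)):
        for item in items:
            # recurse into sub-collections, yielding each leaf in order
            for leaf in _traverse_nested_example(item):
                yield leaf
    else:
        yield items
# e.g. list(_traverse_nested_example([['s1', 's2'], ['s3']])) returns
# ['s1', 's2', 's3'], so every leaf sampler gets its threads started.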
def stop(self):
    """
    Stop the sampling threads, if there are any.

    :return:
    """
    for sampler in util_common.traverse_nested(self.get_sampler()):
        if sampler is None:
            continue
        sampler.close_all()
def stop_sampler_threads(self, sender, **_unused_msg):
    """
    Stop the sampler threads.

    :param sender: an instance of niftynet.application
    :param _unused_msg:
    :return:
    """
    try:
        tf.logging.info('stopping sampling threads')
        for sampler in traverse_nested(sender.get_sampler()):
            if sampler is None:
                continue
            sampler.close_all()
    except (AttributeError, TypeError):
        pass
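# --- illustrative example (not part of the engine) -------------------------
# stop_sampler_threads is written as a signal handler: its first positional
# argument is the sender (the application instance). The engine signals are
# blinker-based; a hedged sketch of wiring such a handler to a hypothetical
# "session finished" signal (signal name, 'handler' and 'application' objects
# are assumptions for illustration, not the exact driver API):
#
#     from blinker import signal
#     session_finished = signal('example_session_finished')
#     session_finished.connect(handler.stop_sampler_threads)
#     ...
#     # fired by the driver when the session ends:
#     session_finished.send(application)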
def create_graph(application, num_gpus=1, num_threads=1,
                 is_training_action=False):
    """
    Create a TF graph based on the application's properties
    and the engine parameters.

    :return: the constructed TF graph
    """
    graph = tf.Graph()
    main_device = device_string(num_gpus, 0, False, is_training_action)
    outputs_collector = OutputsCollector(n_devices=max(num_gpus, 1))
    gradients_collector = GradientsCollector(n_devices=max(num_gpus, 1))

    # start constructing the graph, handling training and inference cases
    with graph.as_default(), tf.device(main_device):

        # initialise sampler
        with tf.name_scope('Sampler'):
            application.initialise_sampler()
            for sampler in traverse_nested(application.get_sampler()):
                sampler.set_num_threads(num_threads)

        # initialise network, these are connected in
        # the context of multiple GPUs
        application.initialise_network()
        application.add_validation_flag()

        # for data parallelism --
        # defining and collecting variables from multiple devices
        for gpu_id in range(0, max(num_gpus, 1)):
            worker_device = device_string(
                num_gpus, gpu_id, True, is_training_action)
            scope_string = 'worker_{}'.format(gpu_id)
            with tf.name_scope(scope_string), tf.device(worker_device):
                # setup network for each of the multiple devices
                application.connect_data_and_network(
                    outputs_collector, gradients_collector)

        with tf.name_scope('MergeOutputs'):
            outputs_collector.finalise_output_op()

        application.outputs_collector = outputs_collector
        application.gradients_collector = gradients_collector
        GRAPH_CREATED.send(application, iter_msg=None)
    return graph
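# --- illustrative example (not part of the engine) -------------------------
# A hedged sketch of how a driver might use create_graph together with the
# sampler-thread helpers above; 'app' stands for an initialised application
# object and 'driver' for the object owning _run_sampler_threads()/stop().
# These names and the surrounding loop are assumptions for illustration,
# not the exact driver API:
#
#     graph = create_graph(app, num_gpus=2, num_threads=4,
#                          is_training_action=True)
#     with tf.Session(graph=graph) as sess:
#         sess.run(tf.global_variables_initializer())
#         driver._run_sampler_threads(session=sess)  # start filling queues
#         ...                                        # training/inference loop
#         driver.stop()                              # close sampler threads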