def start(self):
    """Open a TensorFlow session and prepare the training graph for use.

    Steps, in order:
      1. Connect a new ``tf.Session`` on a fresh ``tf.Graph`` to the target
         reported by ``LocalServer.get_target()``.
      2. Call ``__read_meta_graph()`` with that graph as the default graph.
      3. If ``self.summary`` is set, attach a ``tf.summary.FileWriter``
         writing to ``self.log_dir``.
      4. If no ``optimizer_func`` was given, look up the optimizer operation
         and loss tensor by the names in ``self.optimizer_loss_names``.
      5. For every tensor name in ``self.gradients``, store the symbolic
         gradient of ``self.loss`` w.r.t. that tensor in ``self.tf_gradient``.
    """
    server_target = LocalServer.get_target()
    logger.info("Initializing tf session, connecting to %s...", server_target)

    fresh_graph = tf.Graph()
    self.graph = fresh_graph
    self.session = tf.Session(graph=fresh_graph, target=server_target)

    with fresh_graph.as_default():
        self.__read_meta_graph()

    # From here on go through self.graph, in case the reader rebound it.
    if self.summary is not None:
        self.summary_saver = tf.summary.FileWriter(self.log_dir, self.graph)

    if self.optimizer_func is None:
        # Only the names are known at this point; resolve the graph objects.
        names = self.optimizer_loss_names
        self.optimizer = self.graph.get_operation_by_name(names[0])
        self.loss = self.graph.get_tensor_by_name(names[1])

    # Register d(loss)/d(tensor) for each requested tensor name.
    for name in self.gradients:
        wrt = self.graph.get_tensor_by_name(name)
        grads = tf.gradients(self.loss, [wrt])
        self.tf_gradient[name] = grads[0]
def __init__(self, weight_graph_basename, inference_graph_basename,
             input_keys, output_keys):
    """Validate the model files and load the inference meta graph.

    Args:
        weight_graph_basename: Checkpoint prefix; ``<prefix>.index`` must exist.
        inference_graph_basename: Meta-graph prefix; ``<prefix>.meta`` must
            exist.
        input_keys: A single key or a list/tuple of keys. Anything that is
            not a list or tuple is wrapped in a one-element list.
        output_keys: Stored unchanged.

    Raises:
        AssertionError: if a required file is missing. The check is explicit
            rather than an ``assert`` statement so it still runs under
            ``python -O``. The exception type is kept for callers that
            already rely on it.
    """
    if not os.path.exists(weight_graph_basename + '.index'):
        raise AssertionError(weight_graph_basename)
    # NOTE: the '.data-00000-of-00001' shard is deliberately not checked;
    # the original author was unsure whether that suffix is the same for
    # every TF model.
    self.weight_graph_basename = weight_graph_basename

    if not os.path.exists(inference_graph_basename + '.meta'):
        raise AssertionError(inference_graph_basename)
    self.inference_graph_basename = inference_graph_basename

    if not isinstance(input_keys, (tuple, list)):
        input_keys = [input_keys]
    self.input_keys = input_keys
    self.output_keys = output_keys

    self.graph = tf.Graph()
    self.session = tf.Session(graph=self.graph)
    try:
        with self.graph.as_default():
            self._read_meta_graph()
    except BaseException:
        # Release the session now rather than waiting for garbage collection
        # of the half-constructed object; the error is still propagated.
        self.session.close()
        raise

    self.lock = threading.Lock()
def start(self):
    """Create a fresh graph and session, then call ``__read_checkpoint()``.

    The session connects to the target reported by
    ``LocalServer.get_target()``. The checkpoint reader runs with the new
    graph installed as the default graph.
    """
    server_target = LocalServer.get_target()
    logger.info("Initializing tf session, connecting to %s...", server_target)

    fresh_graph = tf.Graph()
    self.graph = fresh_graph
    self.session = tf.Session(graph=fresh_graph, target=server_target)

    with fresh_graph.as_default():
        self.__read_checkpoint()
def __predict(self):
    '''The background predict process.

    Builds a graph and session and calls ``__read_checkpoint()``. It then
    prepares the shared output arrays (creating their config first if
    needed) and sets ``predict_process_initialized``. After that it loops
    forever:

      * wait for ``worker_sent_inputs``, then clear it;
      * lazily set up the shared input arrays;
      * read the inputs and set ``predict_received_inputs``;
      * run the session for every key of ``self.outputs``;
      * write the results to shared memory and set ``predict_sent_outputs``.

    On any ``Exception``, ``predict_process_crashed.value`` is set to True
    and the events are set/cleared as below before re-raising. Per the
    original "release locks and events" comment, the apparent intent is that
    a peer waiting on these events is not left blocked. That peer is
    presumably the worker, which is not visible here; verify before relying
    on it.
    '''
    try:
        # TODO: is the server still needed?
        target = LocalServer.get_target()
        logger.info("Initializing tf session, connecting to %s...", target)

        self.graph = tf.Graph()
        self.session = tf.Session(target=target, graph=self.graph)

        with self.graph.as_default():
            self.__read_checkpoint()

        if not self.shared_output_arrays:
            if not self.shared_output_array_config:
                self.__create_shared_output_array_config()
            self.__init_shared_output_arrays()

        # from now on it is safe to access the shared array configuration
        self.predict_process_initialized.set()

        # loop predict
        while True:

            # wait for inputs
            self.worker_sent_inputs.wait()
            self.worker_sent_inputs.clear()

            if not self.shared_input_arrays:
                self.__init_shared_input_arrays()

            # read inputs
            input_data = self.__read_inputs_from_shared()
            self.predict_received_inputs.set()

            # compute outputs
            output_data = self.session.run(
                {t: t for t in self.outputs.keys()},
                feed_dict=input_data)

            # write outputs
            self.__write_outputs_to_shared(output_data)
            self.predict_sent_outputs.set()

    except Exception:
        self.predict_process_crashed.value = True

        # release locks and events
        self.predict_process_initialized.set()
        self.worker_sent_inputs.clear()
        self.predict_received_inputs.set()
        self.predict_sent_outputs.set()

        # Bare ``raise`` re-raises the active exception without adding
        # this handler's frame to the traceback (unlike ``raise e``).
        raise
def __init__(self, graph_name):
    """Remember *graph_name*, then call ``read_meta_graph()`` on a private graph.

    A new ``tf.Graph`` and a ``tf.Session`` bound to it are created first.
    The meta-graph reader then runs with that graph installed as the
    default graph.
    """
    self.graph_name = graph_name

    private_graph = tf.Graph()
    self.graph = private_graph
    self.session = tf.Session(graph=private_graph)

    with private_graph.as_default():
        self.read_meta_graph()