def start(self):
    """Open a TensorFlow session and wire up the training ops.

    Creates a fresh ``tf.Graph`` and a ``tf.Session`` bound to it (connected
    to the target returned by ``LocalServer.get_target()``), then calls
    ``self.__read_meta_graph()`` with that graph installed as the default
    graph. Afterwards:

    * if ``self.summary`` is not ``None``, a ``tf.summary.FileWriter`` that
      writes to ``self.log_dir`` is stored in ``self.summary_saver``;
    * if ``self.optimizer_func`` is ``None``, ``self.optimizer`` (an
      operation) and ``self.loss`` (a tensor) are looked up by name, using
      ``self.optimizer_loss_names[0]`` and ``self.optimizer_loss_names[1]``;
    * for every tensor name in ``self.gradients``, the symbolic gradient of
      ``self.loss`` with respect to that tensor is stored in
      ``self.tf_gradient`` under the same name.

    NOTE(review): when ``optimizer_func`` is given, ``self.loss`` is
    presumably set by ``__read_meta_graph`` — confirm before relying on it.
    """
    server_target = LocalServer.get_target()
    logger.info("Initializing tf session, connecting to %s...", server_target)

    graph = tf.Graph()
    self.graph = graph
    self.session = tf.Session(target=server_target, graph=graph)

    # Install the new graph as default while the meta graph is read, so any
    # ops created during the read are added to it.
    with graph.as_default():
        self.__read_meta_graph()

    if self.summary is not None:
        self.summary_saver = tf.summary.FileWriter(self.log_dir, self.graph)

    if self.optimizer_func is None:
        # Resolve the optimizer operation and the loss tensor from their names.
        optimizer_op_name = self.optimizer_loss_names[0]
        self.optimizer = self.graph.get_operation_by_name(optimizer_op_name)
        loss_tensor_name = self.optimizer_loss_names[1]
        self.loss = self.graph.get_tensor_by_name(loss_tensor_name)

    # Build d(loss)/d(tensor) for every tensor whose gradient was requested.
    for name in self.gradients:
        wrt = self.graph.get_tensor_by_name(name)
        grads = tf.gradients(self.loss, [wrt])
        self.tf_gradient[name] = grads[0]
def start(self):
    """Open a TensorFlow session on a fresh graph and restore the checkpoint.

    Connects to the target returned by ``LocalServer.get_target()``, stores
    the new graph and session on ``self.graph`` / ``self.session``, and then
    calls ``self.__read_checkpoint()`` with that graph installed as the
    default graph.
    """
    server_target = LocalServer.get_target()
    logger.info("Initializing tf session, connecting to %s...", server_target)

    graph = tf.Graph()
    self.graph = graph
    self.session = tf.Session(graph=graph, target=server_target)

    # Make the new graph the default so any ops created while reading the
    # checkpoint are added to it rather than to the global default graph.
    with graph.as_default():
        self.__read_checkpoint()
def __predict(self):
    '''The background predict process.

    Setup phase:

    * builds a fresh ``tf.Graph`` and a ``tf.Session`` connected to
      ``LocalServer.get_target()``;
    * restores the checkpoint with that graph as the default;
    * creates the shared output array configuration and arrays if they do
      not exist yet;
    * sets ``predict_process_initialized``.

    Each iteration of the endless loop:

    * waits for ``worker_sent_inputs`` (and clears it);
    * lazily creates the shared input arrays on first use;
    * reads the inputs from shared memory and sets ``predict_received_inputs``;
    * evaluates every key of ``self.outputs`` with ``self.session.run``;
    * writes the results back to shared memory and sets
      ``predict_sent_outputs``.

    On any ``Exception``, the following happens before the exception is
    re-raised unchanged (original type and traceback):

    * ``predict_process_crashed.value`` is set to ``True``;
    * ``predict_process_initialized``, ``predict_received_inputs`` and
      ``predict_sent_outputs`` are set, so nothing stays blocked on them;
    * ``worker_sent_inputs`` is cleared.

    Raises:
        Exception: whatever the failing step raised.
    '''
    try:
        # TODO: is the server still needed?
        target = LocalServer.get_target()
        logger.info("Initializing tf session, connecting to %s...", target)

        self.graph = tf.Graph()
        self.session = tf.Session(target=target, graph=self.graph)

        with self.graph.as_default():
            self.__read_checkpoint()

        if not self.shared_output_arrays:
            if not self.shared_output_array_config:
                self.__create_shared_output_array_config()
            self.__init_shared_output_arrays()

        # from now on it is safe to access the shared array configuration
        self.predict_process_initialized.set()

        # loop predict
        while True:

            # wait for inputs
            self.worker_sent_inputs.wait()
            self.worker_sent_inputs.clear()

            # shared input arrays are only created once the first request
            # has arrived
            if not self.shared_input_arrays:
                self.__init_shared_input_arrays()

            # read inputs
            input_data = self.__read_inputs_from_shared()
            self.predict_received_inputs.set()

            # compute outputs
            output_data = self.session.run(
                {t: t for t in self.outputs.keys()},
                feed_dict=input_data)

            # write outputs
            self.__write_outputs_to_shared(output_data)
            self.predict_sent_outputs.set()

    except Exception:

        # Set the crash flag before releasing any event. Presumably the
        # waiting side checks this flag after waking — TODO confirm.
        self.predict_process_crashed.value = True

        # release locks and events
        self.predict_process_initialized.set()
        self.worker_sent_inputs.clear()
        self.predict_received_inputs.set()
        self.predict_sent_outputs.set()

        # Bare ``raise`` re-raises the active exception with its original
        # traceback (``raise e`` drops it on Python 2).
        raise