Example #1
0
    def start(self):
        '''Set up the TensorFlow graph and session and resolve training ops.

        Creates a fresh graph and a session connected to the target reported
        by ``LocalServer.get_target()``, reads the meta graph into that graph,
        optionally creates a summary writer, and builds one symbolic gradient
        of the loss per tensor name listed in ``self.gradients``.
        '''

        server_target = LocalServer.get_target()
        logger.info(
            "Initializing tf session, connecting to %s...", server_target)

        graph = tf.Graph()
        self.graph = graph
        self.session = tf.Session(target=server_target, graph=graph)

        # populate the new graph from the stored meta graph
        with graph.as_default():
            self.__read_meta_graph()

        # a summary writer is only needed if a summary was requested
        if self.summary is not None:
            self.summary_saver = tf.summary.FileWriter(
                self.log_dir, self.graph)

        # without an optimizer function, optimizer and loss are given by
        # name and have to be looked up in the graph
        if self.optimizer_func is None:
            names = self.optimizer_loss_names
            self.optimizer = self.graph.get_operation_by_name(names[0])
            self.loss = self.graph.get_tensor_by_name(names[1])

        # symbolic gradient of the loss w.r.t. each requested tensor
        for name in self.gradients:
            wrt = self.graph.get_tensor_by_name(name)
            loss_gradients = tf.gradients(self.loss, [wrt])
            self.tf_gradient[name] = loss_gradients[0]
Example #2
0
    def __init__(self, weight_graph_basename, inference_graph_basename,
                 input_keys, output_keys):
        """Check the graph files, store the configuration and open a session.

        Args:
            weight_graph_basename: Path prefix of the weights; the file
                ``<prefix>.index`` must exist.
            inference_graph_basename: Path prefix of the inference graph; the
                file ``<prefix>.meta`` must exist.
            input_keys: A single key or a tuple/list of keys. A single key is
                wrapped into a one-element list.
            output_keys: Stored unchanged.

        Raises:
            AssertionError: If one of the required files does not exist.
        """
        assert os.path.exists(
            weight_graph_basename + '.index'), weight_graph_basename
        # NOTE: the '.data-00000-of-00001' file is deliberately not checked;
        # it is unclear whether that name is stable across tf models.
        self.weight_graph_basename = weight_graph_basename

        assert os.path.exists(
            inference_graph_basename + '.meta'), inference_graph_basename
        self.inference_graph_basename = inference_graph_basename

        # always store the input keys as a sequence
        if not isinstance(input_keys, (tuple, list)):
            input_keys = [input_keys]
        self.input_keys = input_keys
        self.output_keys = output_keys

        # dedicated graph and session for this instance
        graph = tf.Graph()
        self.graph = graph
        self.session = tf.Session(graph=graph)

        with graph.as_default():
            self._read_meta_graph()

        self.lock = threading.Lock()
Example #3
0
    def start(self):
        '''Create a graph and a session, then read the checkpoint into it.

        The session connects to the target reported by
        ``LocalServer.get_target()``.
        '''

        server_target = LocalServer.get_target()
        logger.info(
            "Initializing tf session, connecting to %s...", server_target)

        graph = tf.Graph()
        self.graph = graph
        self.session = tf.Session(target=server_target, graph=graph)

        # restore the checkpoint with the new graph as the default graph
        with graph.as_default():
            self.__read_checkpoint()
Example #4
0
    def __predict(self):
        '''The background predict process.

        Sets up a TensorFlow graph and session, reads the checkpoint, makes
        sure the shared output arrays exist, and then loops forever: wait for
        inputs, read them from the shared arrays, run the session, and write
        the outputs back to the shared arrays.

        The loop never returns on its own. If anything raises, the crash flag
        is set, the events used for the handshake are set/cleared, and the
        exception is re-raised.
        '''

        try:
            # TODO: is the server still needed?
            target = LocalServer.get_target()
            logger.info("Initializing tf session, connecting to %s...", target)

            self.graph = tf.Graph()
            self.session = tf.Session(target=target, graph=self.graph)

            with self.graph.as_default():
                self.__read_checkpoint()

            # create the shared output arrays (and their configuration) only
            # if they do not exist yet
            if not self.shared_output_arrays:
                if not self.shared_output_array_config:
                    self.__create_shared_output_array_config()
                self.__init_shared_output_arrays()

            # from now on it is safe to access the shared array configuration
            self.predict_process_initialized.set()

            # loop predict
            while True:

                # wait for inputs
                self.worker_sent_inputs.wait()
                self.worker_sent_inputs.clear()

                # shared input arrays are set up on first use
                if not self.shared_input_arrays:
                    self.__init_shared_input_arrays()

                # read inputs, then signal that they have been picked up
                input_data = self.__read_inputs_from_shared()
                self.predict_received_inputs.set()

                # compute outputs; the fetch dict maps each output key to
                # itself, so the result is keyed the same way
                output_data = self.session.run(
                    {t: t
                     for t in self.outputs.keys()}, feed_dict=input_data)

                # write outputs, then signal that they are available
                self.__write_outputs_to_shared(output_data)
                self.predict_sent_outputs.set()

        except Exception:

            self.predict_process_crashed.value = True

            # release locks and events
            self.predict_process_initialized.set()
            self.worker_sent_inputs.clear()
            self.predict_received_inputs.set()
            self.predict_sent_outputs.set()

            # bare raise keeps the original traceback intact
            raise
Example #5
0
 def __init__(self, graph_name):
     """Store the graph name, open a session and read the meta graph.

     A new graph is created and a session is bound to it; the meta graph is
     read with that graph as the default graph.
     """
     self.graph_name = graph_name

     graph = tf.Graph()
     self.graph = graph
     self.session = tf.Session(graph=graph)

     with graph.as_default():
         self.read_meta_graph()