示例#1
0
    def start(self):
        """Open a TF session, load the meta-graph and set up training tensors.

        Steps performed:

        1. Create a fresh ``tf.Graph`` (``self.graph``) and a ``tf.Session``
           bound to it (``self.session``), connected to the target returned
           by ``LocalServer.get_target()``.
        2. Call ``__read_meta_graph`` with ``self.graph`` as the default graph.
        3. If ``self.summary`` is not None, create a ``tf.summary.FileWriter``
           for ``self.log_dir`` and store it in ``self.summary_saver``.
        4. If ``self.optimizer_func`` is None, resolve ``self.optimizer`` (an
           operation) and ``self.loss`` (a tensor) from the names in
           ``self.optimizer_loss_names``.
        5. For every tensor name in ``self.gradients``, store the symbolic
           gradient of ``self.loss`` w.r.t. that tensor in
           ``self.tf_gradient[tensor_name]``.
        """

        target = LocalServer.get_target()
        logger.info("Initializing tf session, connecting to %s...", target)

        self.graph = tf.Graph()
        self.session = tf.Session(target=target, graph=self.graph)

        with self.graph.as_default():
            self.__read_meta_graph()

        if self.summary is not None:
            self.summary_saver = tf.summary.FileWriter(self.log_dir,
                                                       self.graph)

        if self.optimizer_func is None:

            # get actual operations/tensors from names
            self.optimizer = self.graph.get_operation_by_name(
                self.optimizer_loss_names[0])
            self.loss = self.graph.get_tensor_by_name(
                self.optimizer_loss_names[1])
        # NOTE(review): when optimizer_func is given, self.loss must already
        # have been set elsewhere (presumably by __read_meta_graph) — confirm.

        # Add symbolic gradients. Build them with self.graph as the default
        # graph, like the rest of the graph construction above. In TF1,
        # tf.gradients uses the default graph as its source graph and takes
        # that graph's mutation lock, so this keeps it working on this
        # session's graph rather than on whatever graph is the process default.
        with self.graph.as_default():
            for tensor_name in self.gradients:
                tensor = self.graph.get_tensor_by_name(tensor_name)
                # tf.gradients returns one entry per tensor in its xs list;
                # we pass a single tensor, so take element 0
                self.tf_gradient[tensor_name] = tf.gradients(
                    self.loss, [tensor])[0]
示例#2
0
    def start(self):
        """Connect a TF session to the local server and load the checkpoint.

        A new ``tf.Graph`` is stored in ``self.graph`` and a ``tf.Session``
        bound to that graph in ``self.session``. ``__read_checkpoint`` is then
        called with the new graph installed as the default graph.
        """
        server_target = LocalServer.get_target()
        logger.info("Initializing tf session, connecting to %s...",
                    server_target)

        graph = tf.Graph()
        self.graph = graph
        self.session = tf.Session(graph=graph, target=server_target)

        # anything __read_checkpoint adds to the default graph goes into
        # this session's graph
        with graph.as_default():
            self.__read_checkpoint()
示例#3
0
    def __predict(self):
        '''The background predict process.

        Opens a TF session on the target returned by
        ``LocalServer.get_target()`` and calls ``__read_checkpoint`` inside a
        fresh graph. It then sets up the shared output arrays if they do not
        exist yet and sets ``predict_process_initialized``.

        After that it loops forever:

        - wait for ``worker_sent_inputs``;
        - read the inputs from the shared arrays and set
          ``predict_received_inputs``;
        - run the session to fetch ``self.outputs``;
        - write the results to the shared arrays and set
          ``predict_sent_outputs``.

        On any ``Exception`` it sets ``predict_process_crashed.value`` to True
        and sets the initialized/received/sent events. It also clears
        ``worker_sent_inputs``, presumably so that a waiting worker is not
        left blocked. Finally it re-raises the original exception.
        '''

        try:
            # TODO: is the server still needed?
            target = LocalServer.get_target()
            logger.info("Initializing tf session, connecting to %s...", target)

            self.graph = tf.Graph()
            self.session = tf.Session(target=target, graph=self.graph)

            with self.graph.as_default():
                self.__read_checkpoint()

            if not self.shared_output_arrays:
                if not self.shared_output_array_config:
                    self.__create_shared_output_array_config()
                self.__init_shared_output_arrays()

            # from now on it is safe to access the shared array configuration
            self.predict_process_initialized.set()

            # loop predict
            while True:

                # wait for inputs
                self.worker_sent_inputs.wait()
                self.worker_sent_inputs.clear()

                if not self.shared_input_arrays:
                    self.__init_shared_input_arrays()

                # read inputs
                input_data = self.__read_inputs_from_shared()
                self.predict_received_inputs.set()

                # compute outputs; the fetch dict maps each key of
                # self.outputs (presumably tensor names) to itself, so
                # session.run returns a dict with those same keys
                output_data = self.session.run(
                    {t: t for t in self.outputs}, feed_dict=input_data)

                # write outputs
                self.__write_outputs_to_shared(output_data)
                self.predict_sent_outputs.set()

        except Exception:

            self.predict_process_crashed.value = True

            # release locks and events
            self.predict_process_initialized.set()
            self.worker_sent_inputs.clear()
            self.predict_received_inputs.set()
            self.predict_sent_outputs.set()

            # A bare raise re-raises the active exception with its original
            # traceback. `raise e` drops it on Python 2 and adds an extra
            # frame on Python 3.
            raise