コード例 #1
0
    def prepare_model(self):
        """
        first step prepare model
        needs to be called by subclass in re-write process

        Depending on config.MODEL_TYPE:
        'od' -> download model, load frozen graph, load category index
        'dl' -> download model, load frozen graph
        Any other value skips the loading steps.
        Afterwards self.fps and self._visualizer are created and started.

        Necessary: subclass needs to init
        self._input_stream

        Returns self, so the call can be chained.
        """
        # Compare by value: `is` checks object identity and only worked by
        # accident for interned literals; a type string built at runtime
        # (e.g. parsed from a config file) would silently skip loading.
        model_type = self.config.MODEL_TYPE
        if model_type == 'od':
            self.download_model()
            self.load_frozen_graph()
            self.load_category_index()
        elif model_type == 'dl':
            self.download_model()
            self.load_frozen_graph()
        self.fps = FPS(self.config.FPS_INTERVAL).start()
        self._visualizer = Visualizer(self.config).start()
        return self
コード例 #2
0
class Model(object):
    """
    Base Tensorflow Inference Model Class

    Bundles model download, frozen graph loading (optionally split into a
    GPU and a CPU part), label map loading and the generic
    detect -> visualize loop.
    Subclasses need to provide self._input_stream and override detect().
    """
    def __init__(self,config):
        """
        stores the config and sets up an empty graph plus the default
        tf session and run options
        """
        self.config = config
        self.detection_graph = tf.Graph()
        self.category_index = None
        self.masks = None
        self._tf_config = tf.ConfigProto(allow_soft_placement=True)
        self._tf_config.gpu_options.allow_growth=True
        self._run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
        # replaced by a tf.RunMetadata instance in prepare_timeliner()
        self._run_metadata = False
        # while True, run() skips visualization and the fps update
        self._wait_thread = False
        self._is_imageD = False
        self._is_videoD = False
        self._is_rosD = False
        print ('> Model: {}'.format(self.config.MODEL_PATH))

    def download_model(self):
        """
        downloads model from model_zoo

        Nothing is downloaded when config.MODEL_PATH already exists or when
        config.DOWNLOAD_MODEL is not set. Otherwise <MODEL_NAME>.tar.gz is
        fetched into the working directory, every member whose file name
        contains 'frozen_inference_graph.pb' is extracted below ./models/
        and the archive is removed again.

        Raises ValueError if a download is required but config.MODEL_TYPE
        is neither 'dl' nor 'od'.
        """
        if os.path.isfile(self.config.MODEL_PATH):
            print('> Model found. Proceed.')
            return
        if not self.config.DOWNLOAD_MODEL:
            # previously reported as "Model found", which was misleading
            print('> Model not found and DOWNLOAD_MODEL is disabled. Proceed.')
            return

        download_bases = {
            'dl': 'http://download.tensorflow.org/models/',
            'od': 'http://download.tensorflow.org/models/object_detection/',
        }
        model_type = self.config.MODEL_TYPE
        if model_type not in download_bases:
            # fail with a clear message instead of an unbound local variable
            raise ValueError(
                "cannot download model: unsupported MODEL_TYPE %r" % (model_type,))

        model_file = self.config.MODEL_NAME + '.tar.gz'
        print('> Model not found. Downloading it now.')
        # URLopener is deprecated; urlretrieve performs the same download
        urllib.request.urlretrieve(download_bases[model_type] + model_file, model_file)
        try:
            # context manager closes the archive even if extraction fails
            with tarfile.open(model_file) as tar_file:
                for member in tar_file.getmembers():
                    file_name = os.path.basename(member.name)
                    if 'frozen_inference_graph.pb' in file_name:
                        tar_file.extract(member, os.getcwd() + '/models/')
        finally:
            os.remove(os.getcwd() + '/' + model_file)

    def node_name(self,n):
        """
        returns the plain node name for a graph input reference:
        strips a leading '^' or everything from the first ':' on
        """
        if n.startswith("^"):
            return n[1:]
        return n.split(":")[0]

    def load_frozen_graph(self):
        """
        loads graph from frozen model file

        Default: the GraphDef read from config.MODEL_PATH is imported
        unchanged into self.detection_graph.

        MODEL_TYPE 'od' with SPLIT_MODEL: the graph is cut at
        config.SPLIT_NODES. The split nodes and everything they depend on
        are imported under '/gpu:0'; all remaining nodes plus two
        placeholders named like the first two split nodes are imported
        under '/cpu:0'.
        """
        print('> Loading frozen model into memory')
        if (self.config.MODEL_TYPE == 'od' and self.config.SPLIT_MODEL):
            # load a frozen Model and split it into GPU and CPU graphs
            # Hardcoded split points for ssd_mobilenet
            tf.reset_default_graph()
            if self.config.SSD_SHAPE == 600:
                shape = 7326
            else:
                shape = 1917
            self.score = tf.placeholder(
                tf.float32,
                shape=(None, shape, self.config.NUM_CLASSES),
                name=self.config.SPLIT_NODES[0])
            self.expand = tf.placeholder(
                tf.float32,
                shape=(None, shape, 1, 4),
                name=self.config.SPLIT_NODES[1])
            # pick up the NodeDefs of the two placeholders just created
            for node in tf.get_default_graph().as_graph_def().node:
                if node.name == self.config.SPLIT_NODES[0]:
                    score_def = node
                if node.name == self.config.SPLIT_NODES[1]:
                    expand_def = node

            with self.detection_graph.as_default():
                graph_def = tf.GraphDef()
                # the file is only needed for parsing; close it right after
                with tf.gfile.GFile(self.config.MODEL_PATH, 'rb') as fid:
                    graph_def.ParseFromString(fid.read())

                edges = {}
                name_to_node_map = {}
                node_seq = {}
                for seq, node in enumerate(graph_def.node):
                    n = self.node_name(node.name)
                    name_to_node_map[n] = node
                    edges[n] = [self.node_name(x) for x in node.input]
                    node_seq[n] = seq
                for d in self.config.SPLIT_NODES:
                    assert d in name_to_node_map, "%s is not in graph" % d

                # collect the split nodes and all of their transitive inputs;
                # visiting order is irrelevant because the result is a set
                nodes_to_keep = set()
                next_to_visit = list(self.config.SPLIT_NODES)
                while next_to_visit:
                    n = next_to_visit.pop()
                    if n in nodes_to_keep:
                        continue
                    nodes_to_keep.add(n)
                    next_to_visit += edges[n]

                # keep the original node order of the frozen graph
                nodes_to_keep_list = sorted(nodes_to_keep, key=lambda n: node_seq[n])
                # membership is tested against the set, not the sorted list
                nodes_to_remove_list = sorted(
                    (n for n in node_seq if n not in nodes_to_keep),
                    key=lambda n: node_seq[n])

                keep = graph_pb2.GraphDef()
                for n in nodes_to_keep_list:
                    keep.node.extend([copy.deepcopy(name_to_node_map[n])])

                remove = graph_pb2.GraphDef()
                remove.node.extend([score_def])
                remove.node.extend([expand_def])
                for n in nodes_to_remove_list:
                    remove.node.extend([copy.deepcopy(name_to_node_map[n])])

                with tf.device('/gpu:0'):
                    tf.import_graph_def(keep, name='')
                with tf.device('/cpu:0'):
                    tf.import_graph_def(remove, name='')
        else:
            # default model loading procedure
            with self.detection_graph.as_default():
                graph_def = tf.GraphDef()
                with tf.gfile.GFile(self.config.MODEL_PATH, 'rb') as fid:
                    graph_def.ParseFromString(fid.read())
                tf.import_graph_def(graph_def, name='')

    def load_category_index(self):
        """
        creates category_index from label_map

        reads config.LABEL_PATH and stores the result in
        self.category_index
        """
        print('> Loading label map')
        label_map = tf_utils.load_labelmap(self.config.LABEL_PATH)
        categories = tf_utils.convert_label_map_to_categories(
            label_map,
            max_num_classes=self.config.NUM_CLASSES,
            use_display_name=True)
        self.category_index = tf_utils.create_category_index(categories)

    def get_tensor_dict(self, outputs):
        """
        returns tensordict for given tensornames list

        For every key in outputs the tensor '<key>:0' is looked up in
        self.detection_graph; keys without such a tensor are left out.
        The dict is also stored as self.tensor_dict.
        """
        ops = self.detection_graph.get_operations()
        all_tensor_names = {output.name for op in ops for output in op.outputs}
        self.tensor_dict = {}
        for key in outputs:
            tensor_name = key + ':0'
            if tensor_name in all_tensor_names:
                self.tensor_dict[key] = self.detection_graph.get_tensor_by_name(tensor_name)
        return self.tensor_dict

    def prepare_model(self):
        """
        first step prepare model
        needs to be called by subclass in re-write process

        'od' -> download model, load frozen graph, load category index
        'dl' -> download model, load frozen graph
        Any other MODEL_TYPE skips the loading steps.
        Afterwards self.fps and self._visualizer are created and started.

        Necessary: subclass needs to init
        self._input_stream

        Returns self, so the call can be chained.
        """
        # compare by value; `is` only matched interned string objects
        model_type = self.config.MODEL_TYPE
        if model_type == 'od':
            self.download_model()
            self.load_frozen_graph()
            self.load_category_index()
        elif model_type == 'dl':
            self.download_model()
            self.load_frozen_graph()
        self.fps = FPS(self.config.FPS_INTERVAL).start()
        self._visualizer = Visualizer(self.config).start()
        return self

    def isActive(self):
        """
        checks if stream and visualizer are active
        """
        return self._input_stream.isActive() and self._visualizer.isActive()

    def stop(self):
        """
        stops all Model sub classes

        stops input stream, visualizer and fps counter; for a split 'od'
        model the gpu and cpu workers are stopped as well
        """
        self._input_stream.stop()
        self._visualizer.stop()
        self.fps.stop()
        # compare by value; with `is` the workers could be left running
        if self.config.SPLIT_MODEL and self.config.MODEL_TYPE == 'od':
            self._gpu_worker.stop()
            self._cpu_worker.stop()

    def detect(self):
        """
        needs to be written by subclass
        """
        self.detection = None

    def run(self):
        """
        runs detection loop on video or image
        listens on isActive()

        calls stop() once the loop ends
        """
        print("> starting detection")
        self.start()
        while self.isActive():
            # detection
            self.detect()
            # Visualization
            if not self._wait_thread:
                self.visualize_detection()
                self.fps.update()
        self.stop()

    def start(self):
        """
        starts fps and visualizer class
        """
        self.fps.start()
        self._visualizer = Visualizer(self.config).start()

    def visualize_detection(self):
        """
        passes the current frame and detection results to the visualizer
        and stores its return value in self.detection
        """
        self.detection = self._visualizer.visualize_detection(self.frame,self.boxes,
                                                            self.classes,self.scores,
                                                            self.masks,self.fps.fps_local(),
                                                            self.category_index,self._is_imageD)

    def prepare_ros(self,node):
        """
        prepares ros Node and ROSInputstream
        only usable in the ros branch due to ROS related packages

        node: 'detection_node' or 'deeplab_node'
        Blocks until the input stream delivers a first frame and then
        stores its height and width.
        """
        assert node in ['detection_node','deeplab_node'], "only 'detection_node' and 'deeplab_node' supported"
        import rospy
        from ros import ROSStream, DetectionPublisher, SegmentationPublisher
        self._is_rosD = True
        rospy.init_node(node)
        self._input_stream = ROSStream(self.config.ROS_INPUT)
        # compare by value; `is` on strings depends on interning
        if node == 'detection_node':
            self._ros_publisher = DetectionPublisher()
        if node == 'deeplab_node':
            self._ros_publisher = SegmentationPublisher()
        # check for frame
        while True:
            self.frame = self._input_stream.read()
            time.sleep(1)
            print("...waiting for ROS image")
            if self.frame is not None:
                self.stream_height,self.stream_width = self.frame.shape[0:2]
                break

    def prepare_timeliner(self):
        """
        prepares timeliner and sets tf Run options
        """
        self._run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        self._run_metadata = tf.RunMetadata()
        self.timeliner = TimeLiner()

    def prepare_tracker(self):
        """
        prepares KCF tracker

        KCF is imported from <cwd>/rod/kcf, which is appended to sys.path
        """
        sys.path.append(os.getcwd()+'/rod/kcf')
        import KCF
        self._tracker = KCF.kcftracker(False, True, False, False)
        self._tracker_counter = 0
        self._track = False

    def run_tracker(self):
        """
        runs KCF tracker on videoStream frame
        !does not work on images, obviously!

        On the first call after activate_tracker() the tracker is
        initialised from the non-zero rows of self.boxes. Every call then
        updates the tracked boxes on the newly read frame and writes them
        to self.boxes. After config.TRACKER_FRAMES calls tracking is
        switched off again.
        """
        self.frame = self._input_stream.read()
        if self._first_track:
            self._trackers = []
            self._tracker_boxes = self.boxes
            num_tracked = 0
            # only rows that are not all zero are treated as detections
            for box in self.boxes[~np.all(self.boxes == 0, axis=1)]:
                # NOTE(review): the same self._tracker instance is
                # re-initialised and appended for every box, so the list
                # never holds independent trackers - confirm intent
                self._tracker.init(conv_detect2track(box,self._input_stream.real_width,
                                    self._input_stream.real_height),self.tracker_frame)
                self._trackers.append(self._tracker)
                num_tracked += 1
                # NOTE(review): `<=` leaves the loop after the first box
                # whenever NUM_TRACKERS >= 1; `>=` looks intended, but only
                # makes sense together with one tracker per box (see above)
                if num_tracked <= self.config.NUM_TRACKERS:
                    break
            self._first_track = False

        for idx, tracker in enumerate(self._trackers):
            tracker_box = tracker.update(self.frame)
            self._tracker_boxes[idx,:] = conv_track2detect(tracker_box,
                                                    self._input_stream.real_width,
                                                    self._input_stream.real_height)
        self._tracker_counter += 1
        self.boxes = self._tracker_boxes
        # Deactivate Tracker
        if self._tracker_counter >= self.config.TRACKER_FRAMES:
            self._track = False
            self._tracker_counter = 0

    def activate_tracker(self):
        """
        activates KCF tracker

        remembers the current frame as self.tracker_frame and flags the
        next run_tracker() call as the first one.
        self.masks is left untouched (the reset is currently disabled).
        """
        self.tracker_frame = self.frame
        self._track = True
        self._first_track = True
コード例 #3
0
 def start(self):
     """
     start self.fps, then build a new Visualizer from self.config,
     start it and keep the result as self._visualizer
     """
     self.fps.start()
     visualizer = Visualizer(self.config)
     self._visualizer = visualizer.start()
コード例 #4
0
class Model(object):
    """
    Base Tensorflow Inference Model Class

    Bundles model download, frozen graph loading (optionally split into a
    GPU and a CPU part), label map loading and the generic
    detect -> visualize loop.
    Subclasses need to provide self._input_stream and override detect().
    """
    def __init__(self, config):
        """
        stores the config and sets up an empty graph plus the default
        tf session and run options
        """
        self.config = config
        self.detection_graph = tf.Graph()
        self.category_index = None
        self.masks = None
        self._tf_config = tf.ConfigProto(allow_soft_placement=True)
        self._tf_config.gpu_options.allow_growth = True
        self._run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE)
        self._run_metadata = False
        # while True, run() skips visualization and the fps update
        self._wait_thread = False
        print('> Model: {}'.format(self.config.MODEL_PATH))

    def download_model(self):
        """
        downloads model from model_zoo

        Nothing is downloaded when config.MODEL_PATH already exists or when
        config.DOWNLOAD_MODEL is not set. Otherwise <MODEL_NAME>.tar.gz is
        fetched into the working directory, every member whose file name
        contains 'frozen_inference_graph.pb' is extracted below ./models/
        and the archive is removed again.

        Raises ValueError if a download is required but config.MODEL_TYPE
        is neither 'dl' nor 'od'.
        """
        if os.path.isfile(self.config.MODEL_PATH):
            print('> Model found. Proceed.')
            return
        if not self.config.DOWNLOAD_MODEL:
            # previously reported as "Model found", which was misleading
            print('> Model not found and DOWNLOAD_MODEL is disabled. Proceed.')
            return

        download_bases = {
            'dl': 'http://download.tensorflow.org/models/',
            'od': 'http://download.tensorflow.org/models/object_detection/',
        }
        model_type = self.config.MODEL_TYPE
        if model_type not in download_bases:
            # fail with a clear message instead of an unbound local variable
            raise ValueError(
                "cannot download model: unsupported MODEL_TYPE %r"
                % (model_type,))

        model_file = self.config.MODEL_NAME + '.tar.gz'
        print('> Model not found. Downloading it now.')
        # URLopener is deprecated; urlretrieve performs the same download
        urllib.request.urlretrieve(download_bases[model_type] + model_file,
                                   model_file)
        try:
            # context manager closes the archive even if extraction fails
            with tarfile.open(model_file) as tar_file:
                for member in tar_file.getmembers():
                    file_name = os.path.basename(member.name)
                    if 'frozen_inference_graph.pb' in file_name:
                        tar_file.extract(member, os.getcwd() + '/models/')
        finally:
            os.remove(os.getcwd() + '/' + model_file)

    def _node_name(self, n):
        """
        returns the plain node name for a graph input reference:
        strips a leading '^' or everything from the first ':' on
        """
        if n.startswith("^"):
            return n[1:]
        return n.split(":")[0]

    def load_frozenmodel(self):
        """
        loads graph from frozen model file

        Default: the GraphDef read from config.MODEL_PATH is imported
        unchanged into self.detection_graph.

        MODEL_TYPE 'od' with SPLIT_MODEL: the graph is cut at
        config.SPLIT_NODES. The split nodes and everything they depend on
        are imported under '/gpu:0'; all remaining nodes plus two
        placeholders named like the first two split nodes are imported
        under '/cpu:0'.
        """
        print('> Loading frozen model into memory')
        if (self.config.MODEL_TYPE == 'od' and self.config.SPLIT_MODEL):
            # load a frozen Model and split it into GPU and CPU graphs
            # Hardcoded split points for ssd_mobilenet
            input_graph = tf.Graph()
            # NOTE(review): relies on the Session context making
            # input_graph the default graph for the placeholders - confirm
            with tf.Session(graph=input_graph, config=self._tf_config):
                if self.config.SSD_SHAPE == 600:
                    shape = 7326
                else:
                    shape = 1917
                self.score = tf.placeholder(tf.float32,
                                            shape=(None, shape,
                                                   self.config.NUM_CLASSES),
                                            name=self.config.SPLIT_NODES[0])
                self.expand = tf.placeholder(tf.float32,
                                             shape=(None, shape, 1, 4),
                                             name=self.config.SPLIT_NODES[1])
                # pick up the NodeDefs of the two placeholders just created
                for node in input_graph.as_graph_def().node:
                    if node.name == self.config.SPLIT_NODES[0]:
                        score_def = node
                    if node.name == self.config.SPLIT_NODES[1]:
                        expand_def = node

            with self.detection_graph.as_default():
                od_graph_def = tf.GraphDef()
                # the file is only needed for parsing; close it right after
                with tf.gfile.GFile(self.config.MODEL_PATH, 'rb') as fid:
                    od_graph_def.ParseFromString(fid.read())

                edges = {}
                name_to_node_map = {}
                node_seq = {}
                for seq, node in enumerate(od_graph_def.node):
                    n = self._node_name(node.name)
                    name_to_node_map[n] = node
                    edges[n] = [self._node_name(x) for x in node.input]
                    node_seq[n] = seq
                for d in self.config.SPLIT_NODES:
                    assert d in name_to_node_map, "%s is not in graph" % d

                # collect the split nodes and all of their transitive inputs;
                # visiting order is irrelevant because the result is a set
                nodes_to_keep = set()
                next_to_visit = list(self.config.SPLIT_NODES)
                while next_to_visit:
                    n = next_to_visit.pop()
                    if n in nodes_to_keep:
                        continue
                    nodes_to_keep.add(n)
                    next_to_visit += edges[n]

                # keep the original node order of the frozen graph
                nodes_to_keep_list = sorted(nodes_to_keep,
                                            key=lambda n: node_seq[n])
                # membership is tested against the set, not the sorted list
                nodes_to_remove_list = sorted(
                    (n for n in node_seq if n not in nodes_to_keep),
                    key=lambda n: node_seq[n])

                keep = graph_pb2.GraphDef()
                for n in nodes_to_keep_list:
                    keep.node.extend([copy.deepcopy(name_to_node_map[n])])

                remove = graph_pb2.GraphDef()
                remove.node.extend([score_def])
                remove.node.extend([expand_def])
                for n in nodes_to_remove_list:
                    remove.node.extend(
                        [copy.deepcopy(name_to_node_map[n])])

                with tf.device('/gpu:0'):
                    tf.import_graph_def(keep, name='')
                with tf.device('/cpu:0'):
                    tf.import_graph_def(remove, name='')
        else:
            # default model loading procedure
            with self.detection_graph.as_default():
                od_graph_def = tf.GraphDef()
                with tf.gfile.GFile(self.config.MODEL_PATH, 'rb') as fid:
                    od_graph_def.ParseFromString(fid.read())
                tf.import_graph_def(od_graph_def, name='')

    def load_labelmap(self):
        """
        creates category_index from label_map

        reads config.LABEL_PATH and stores the result in
        self.category_index
        """
        print('> Loading label map')
        label_map = tf_utils.load_labelmap(self.config.LABEL_PATH)
        categories = tf_utils.convert_label_map_to_categories(
            label_map,
            max_num_classes=self.config.NUM_CLASSES,
            use_display_name=True)
        self.category_index = tf_utils.create_category_index(categories)

    def get_tensordict(self, outputs):
        """
        returns tensordict for given tensornames list

        For every key in outputs the tensor '<key>:0' is looked up in
        self.detection_graph; keys without such a tensor are left out.
        The dict is also stored as self.tensor_dict.
        """
        ops = self.detection_graph.get_operations()
        all_tensor_names = {output.name for op in ops for output in op.outputs}
        self.tensor_dict = {}
        for key in outputs:
            tensor_name = key + ':0'
            if tensor_name in all_tensor_names:
                self.tensor_dict[
                    key] = self.detection_graph.get_tensor_by_name(tensor_name)
        return self.tensor_dict

    def prepare_model(self):
        """
        first step prepare model
        needs to be called by subclass in re-write process

        'od' -> download model, load frozen model, load label map
        'dl' -> download model, load frozen model
        Any other MODEL_TYPE skips the loading steps.
        Afterwards self.fps and self._visualizer are created and started.

        Necessary: subclass needs to init
        self._input_stream

        Returns self, so the call can be chained.
        """
        # compare by value; `is` only matched interned string objects
        model_type = self.config.MODEL_TYPE
        if model_type == 'od':
            self.download_model()
            self.load_frozenmodel()
            self.load_labelmap()
        elif model_type == 'dl':
            self.download_model()
            self.load_frozenmodel()
        self.fps = FPS(self.config.FPS_INTERVAL).start()
        self._visualizer = Visualizer(self.config).start()
        return self

    def isActive(self):
        """
        checks if stream and visualizer are active
        """
        return self._input_stream.isActive() and self._visualizer.isActive()

    def stop(self):
        """
        stops all sub classes

        stops input stream, visualizer and fps counter; for a split 'od'
        model the gpu and cpu workers are stopped as well
        """
        self._input_stream.stop()
        self._visualizer.stop()
        self.fps.stop()
        # compare by value; with `is` the workers could be left running
        if self.config.SPLIT_MODEL and self.config.MODEL_TYPE == 'od':
            self._gpu_worker.stop()
            self._cpu_worker.stop()

    def detect(self):
        """
        needs to be written by subclass
        """
        self.detection = None

    def run(self):
        """
        runs detection loop on video or image
        listens on isActive()

        calls stop() once the loop ends
        """
        print("> starting detection")
        self.start()
        while self.isActive():
            # detection
            self.detect()
            # Visualization
            if not self._wait_thread:
                self.visualize_detection()
                self.fps.update()
        self.stop()

    def start(self):
        """
        starts fps and visualizer class
        """
        self.fps.start()
        self._visualizer = Visualizer(self.config).start()

    def visualize_detection(self):
        """
        passes the current frame and detection results to the visualizer
        and stores its return value in self.detection
        """
        self.detection = self._visualizer.visualize_detection(
            self.frame, self.boxes, self.classes, self.scores, self.masks,
            self.fps.fps_local(), self.category_index)