示例#1
0
import os
import ujson
from src.model import SSDModel
from src.utils import Config, Logger, VideoProcessor

logger = Logger.get_logger('ServeHandler')


class ServeHandler(object):
    # Class-level state: handle() is a classmethod and assigns cls.model, so
    # this state lives on the class and is shared by every user of it.
    model = None
    # NOTE(review): mutable class attribute; the same list object is shared
    # across all accesses to ServeHandler.scores.
    scores = []
    frame_cnt = 0
    use_precomputed = False

    @classmethod
    def handle(cls):

        if Config.get('model') == 'ssd':
            cls.model = SSDModel()

        logger.debug('Start serving ...')
        full_video_path = os.path.join(Config.get('videos_dir'),
                                       Config.get('serve').get('video'))

        url = None
        precomputed_labels = None
        full_annotated_path = None
        confs = Config.get('videos')
        for conf in confs:
            if conf.get('name') == Config.get('serve').get('video'):
                url = conf.get('url')
示例#2
0
from src.data import Processor
from src.utils import Config, Logger
import urllib

logger = Logger.get_logger('TrainHandler')


class TrainHandler(object):

    # Data-set names from the 'train' config section (empty list if absent).
    # Evaluated once, when the class body runs (module import time).
    train_sets = Config.get('train').get('train_sets', [])
    test_sets = Config.get('train').get('test_sets', [])

    @classmethod
    def handle(cls):
        """Run the training pipeline: download, convert, split, then train.

        Steps are resolved by name one at a time, so each step is looked up
        only after the previous one has finished.
        """
        pipeline = ('_download_data', '_convert_data', '_split_data', '_train')
        for step_name in pipeline:
            getattr(cls, step_name)()

    @classmethod
    def _download_data(cls):
        """Download every configured train and test data set.

        Each name in ``cls.train_sets`` and then ``cls.test_sets`` is passed,
        in order, to ``Processor.download``.
        """
        # Log both lists: previously only the train sets were reported even
        # though the test sets are fetched here as well.
        logger.debug('Fetching data sets: train={} test={}'.format(
            cls.train_sets, cls.test_sets))
        for name in list(cls.train_sets) + list(cls.test_sets):
            Processor.download(name)

    @classmethod
    def _convert_data(cls):
        """Data-conversion step of the pipeline; currently a no-op placeholder."""
示例#3
0
# SSD detector class id -> raw data-set class id. Only the classes listed
# here are mapped.
SSD_TO_RAW_CLASS_MAPPING = {
    7: 1,  # vehicle
    15: 2,  # pedestrian
    2: 3,  # cyclist
    # 21: 20,  # traffic lights (currently disabled)
}

# Raw data-set class id -> SSD detector class id. Derived from the table
# above so the two directions can never drift apart; this relies on
# SSD_TO_RAW_CLASS_MAPPING being one-to-one.
RAW_TO_SSD_CLASS_MAPPING = {
    raw_id: ssd_id for ssd_id, raw_id in SSD_TO_RAW_CLASS_MAPPING.items()
}

logger = Logger.get_logger('SSD')


class SSDModel(BaseModel):
    """ SSD Model """
    def __init__(self):
        """Initialise the SSD model wrapper.

        Delegates to ``BaseModel.__init__`` with ``ModelConstants.MODEL_NAME``,
        then resets the inference handles below to ``None``. They are
        presumably populated later (e.g. when the network is built); that code
        is not visible here.
        """
        BaseModel.__init__(self, ModelConstants.MODEL_NAME)

        self.session = None
        self.image_4d = None
        self.predictions = None
        self.localisations = None
        self.img_input = None  # tf placeholder
        self.bbox_img = None
        # Network input size; presumably (height, width) in pixels — confirm.
        self.net_shape = (300, 300)
        self.ssd_anchors = None
示例#4
0
import os

from src.utils import Config, Logger

logger = Logger.get_logger('BaseModel')


class BaseModel(object):
    def __init__(self, model_name):
        """Prepare the on-disk asset directory and asset URL table for a model.

        :param model_name: model identifier; used as the sub-directory of
            ``Config.get('models_dir')`` and to select the entries of
            ``Config.get('models')`` whose ``'name'`` equals it.
        """
        self.asset_dir = os.path.join(Config.get('models_dir'), model_name)
        # Create the directory directly rather than via ``mkdir -p`` in a
        # shell string, which broke on paths containing spaces or shell
        # metacharacters.
        if not os.path.isdir(self.asset_dir):
            os.makedirs(self.asset_dir)
        self.asset_url_map = {}

        for conf in Config.get('models') or []:
            if conf.get('name') != model_name:
                continue
            # A matching entry without 'asset_urls' contributes no assets
            # instead of raising TypeError on iteration over None.
            for asset in conf.get('asset_urls') or []:
                self.asset_url_map[asset['name']] = asset['url']

    def _download_asset(self, asset_name):
        """Download one named asset into ``self.asset_dir`` unless it is cached.

        An existing file at the target path is treated as a valid cached copy
        and left untouched.

        :param asset_name: key into ``self.asset_url_map``; also the file name
            inside ``self.asset_dir``.
        :raises IOError: if no URL is configured for ``asset_name`` or the
            download fails.
        """
        import subprocess

        logger.debug('Downloading asset: {}'.format(asset_name))
        full_asset_name = os.path.join(self.asset_dir, asset_name)

        if os.path.exists(full_asset_name):
            logger.debug('Skip downloading, use cached files instead.')
            return

        url = self.asset_url_map.get(asset_name)
        if url is None:
            # Previously this ran ``wget None -O <path>``, which fails but
            # still leaves an empty file behind.
            raise IOError('no URL configured for asset {}'.format(asset_name))

        # Pass an argument list (no shell) so URL and path are used verbatim.
        return_code = subprocess.call(['wget', url, '-O', full_asset_name])
        if return_code != 0:
            # ``wget -O`` creates the output file before transferring, so a
            # failed download would otherwise pass the cache check next time.
            if os.path.exists(full_asset_name):
                os.remove(full_asset_name)
            raise IOError('failed to download asset {} from {} '
                          '(wget exit code {})'.format(asset_name, url,
                                                      return_code))
示例#5
0
import os

import cv2

from src.utils import Logger, Visualizer

logger = Logger.get_logger('VideoProcessor')


class VideoProcessor(object):
    def __init__(self, path, score_fn, annotated_path):
        """Open an input video and prepare the annotated-output writer.

        :param path: path of the input video; must exist.
        :param score_fn: callable stored as ``self.score_fn``; presumably used
            by ``start()`` (not visible here).
        :param annotated_path: output path for the annotated video. Any
            existing file at this path is deleted first.
        :raises IOError: if ``path`` does not exist.
        """
        # Validate the input before constructing any other resources. The
        # message previously mixed a ``%s`` placeholder with ``str.format``,
        # so the path was never shown.
        if not os.path.exists(path):
            raise IOError('file {} does not exist'.format(path))

        self.score_fn = score_fn
        self.annotated_path = annotated_path
        self.visualizer = Visualizer()

        self.capture = cv2.VideoCapture(path)
        if os.path.exists(annotated_path):
            os.remove(annotated_path)
        # XVID-encoded output at 50 fps with frame size (640, 360).
        self.writer = cv2.VideoWriter(annotated_path,
                                      cv2.VideoWriter_fourcc(*'XVID'), 50.0,
                                      (640, 360))

        # NOTE(review): this loop never terminates if OpenCV cannot open
        # ``path`` (e.g. unsupported codec); consider adding a retry limit.
        while not self.capture.isOpened():
            cv2.waitKey(1000)
            logger.debug('Wait for header')

    def start(self, max_frame_num=2 << 32, fps=1000):
        num_frames = min(int(self.capture.get(cv2.CAP_PROP_FRAME_COUNT)),
示例#6
0
from src.utils import Config, Logger
import cv2
import os
import ujson

logger = Logger.get_logger('Processor')


class RawProcessor(object):

    # Data-set entries from the 'data_sets' config section; download() scans
    # them for the entry whose 'name' matches. Read once, at import time.
    data_set_conf = Config.get('data_sets')

    @classmethod
    def download(cls, name):
        for conf in cls.data_set_conf:
            if conf.get('name') == name:

                data_set_dir = cls._get_raw_data_set_dir(name)
                url, compression_format = conf.get('url'), conf.get(
                    'compression_format')

                logger.debug('Downloading data set: {}'.format(name))
                # skip download if data is present
                if os.path.exists(data_set_dir) and len(
                        os.listdir(data_set_dir)) > 0:
                    logger.debug('Skip downloading, use cached files instead.')
                    return

                os.system('mkdir -p {}'.format(data_set_dir))
                os.system('wget {} -P {}'.format(url, data_set_dir))
示例#7
0
from src.data import RawProcessor
from src.model import SSDModel
from src.utils import Config, Logger

logger = Logger.get_logger('TestHandler')


class TestHandler(object):
    # Data-set names from the 'test' config section (empty list if absent).
    # Evaluated once, when the class body runs (module import time).
    data_sets = Config.get('test').get('data_sets', [])

    @classmethod
    def handle(cls):
        """Fetch the configured data sets, build the test set, then evaluate it."""
        cls._download()
        samples = cls._process()
        cls._test(samples)

    @classmethod
    def _download(cls):
        """Make sure every data set named in ``cls.data_sets`` is available."""
        names = cls.data_sets
        logger.debug('Fetching data sets: {}'.format(names))
        for data_set_name in names:
            RawProcessor.download(data_set_name)

    @classmethod
    def _process(cls):
        """Load the raw data for ``cls.data_sets`` as a list of tuples.

        :return: list of ``(key, raw_data, None)`` triples, one per entry of
            the mapping returned by ``RawProcessor.load_raw_data``, in that
            mapping's iteration order. The third slot is always ``None`` here.
        """
        raw_data_map = RawProcessor.load_raw_data(cls.data_sets)
        # Assumes load_raw_data returns a dict-like mapping (it is iterated
        # and indexed by key); iterate its items instead of re-indexing.
        return [(key, raw_data, None) for key, raw_data in raw_data_map.items()]