Пример #1
0
class TrainHandler(object):
    """Training pipeline entry point: download, convert, split, then train.

    Only the download stage does any work; the other stages are placeholders.
    """

    # Data-set names read from the ``train`` config section at import time.
    train_sets = Config.get('train').get('train_sets', [])
    test_sets = Config.get('train').get('test_sets', [])

    @classmethod
    def handle(cls):
        """Run every pipeline stage in order."""
        cls._download_data()
        cls._convert_data()
        cls._split_data()
        cls._train()

    @classmethod
    def _download_data(cls):
        """Download every configured train and test data set via ``Processor``."""
        # Log both lists: the test sets are fetched here as well.
        logger.debug('Fetching data sets: train={}, test={}'.format(
            cls.train_sets, cls.test_sets))
        for names in (cls.train_sets, cls.test_sets):
            for name in names:
                Processor.download(name)

    @classmethod
    def _convert_data(cls):
        """Placeholder for data conversion; currently a no-op."""
        pass

    @classmethod
    def _split_data(cls):
        """Placeholder for the train/validation split; currently a no-op."""
        pass

    @classmethod
    def _train(cls):
        """Placeholder for model training; currently a no-op."""
        pass
Пример #2
0
    def __init__(self, model_name):
        """Prepare the asset directory and asset URL map for ``model_name``.

        :param model_name: name of an entry in the ``models`` config list.

        Creates ``<models_dir>/<model_name>`` if needed and fills
        ``self.asset_url_map`` with ``{asset name: asset url}`` from every
        ``models`` entry whose ``name`` matches (later entries override
        earlier ones).
        """
        self.asset_dir = os.path.join(Config.get('models_dir'), model_name)
        # Create the directory in-process rather than via `mkdir -p` in a
        # shell, which breaks on paths containing spaces or metacharacters.
        os.makedirs(self.asset_dir, exist_ok=True)
        self.asset_url_map = {}

        # Treat a missing `models` / `asset_urls` entry as empty instead of
        # failing with a TypeError on iterating None.
        for conf in Config.get('models') or []:
            if conf.get('name') == model_name:
                for asset in conf.get('asset_urls') or []:
                    self.asset_url_map[asset['name']] = asset['url']
Пример #3
0
    def handle(cls):
        """Serve the configured video through ``VideoProcessor``.

        Looks up ``serve.video`` in the ``videos`` config and downloads the
        video if it is not on disk. If the entry's precomputed-labels file
        exists, scores are loaded from it and ``process_precomputed`` is used;
        otherwise ``process`` is used and the scores gathered in ``cls.scores``
        are written to that file afterwards.

        :raises Exception: if the video is not listed in the ``videos`` config.
        """
        import subprocess

        if Config.get('model') == 'ssd':
            cls.model = SSDModel()

        logger.debug('Start serving ...')
        videos_dir = Config.get('videos_dir')
        video_name = Config.get('serve').get('video')
        full_video_path = os.path.join(videos_dir, video_name)

        video_conf = None
        for conf in Config.get('videos') or []:
            if conf.get('name') == video_name:
                video_conf = conf
                break
        if video_conf is None:
            # Fail clearly instead of a TypeError later on
            # (os.path.join with a None file name, or `wget None`).
            raise Exception(
                'Video {} not found in base.yaml'.format(video_name))

        url = video_conf.get('url')
        precomputed_labels = video_conf.get('precomputed_labels')
        full_annotated_path = os.path.join(videos_dir,
                                           video_conf.get('annotated_name'))

        # download video if necessary
        if os.path.exists(full_video_path):
            logger.debug('video already exists, skip downloading')
        else:
            os.makedirs(videos_dir, exist_ok=True)
            # Argument list instead of a shell string, so a URL containing
            # '&' or a path with spaces is passed through verbatim.
            subprocess.call(['wget', url, '-O', full_video_path,
                             '--force-directories'])

        logger.debug('Processing video at {}'.format(full_video_path))
        logger.debug(
            'Producing annotated video to {}'.format(full_annotated_path))

        # load precomputed labels if possible; the config entry may omit them
        precomputed_labels_path = None
        if precomputed_labels:
            precomputed_labels_path = os.path.join(videos_dir,
                                                   precomputed_labels)
        if precomputed_labels_path and os.path.exists(precomputed_labels_path):
            cls.use_precomputed = True
            with open(precomputed_labels_path, 'r') as f:
                for line in f:
                    if line.strip():  # tolerate blank lines
                        cls.scores.append(ujson.loads(line))
            logger.debug(
                'precomputed labels file exists, skip real time prediction')

        score_fn = cls.process_precomputed if cls.use_precomputed else cls.process
        fps = 50 if cls.use_precomputed else 1000

        video_processor = VideoProcessor(full_video_path, score_fn,
                                         full_annotated_path)
        video_processor.start(
            max_frame_num=Config.get('serve').get('max_frame_num'), fps=fps)

        # Cache freshly computed scores so the next run can skip prediction.
        if precomputed_labels_path and not cls.use_precomputed and cls.scores:
            # NOTE(review): scores are written with str() but read back with
            # ujson.loads(); this only round-trips when str(score) is valid
            # single-line JSON (e.g. nested lists of numbers) — confirm.
            with open(precomputed_labels_path, 'w') as f:
                f.writelines(str(score) + '\n' for score in cls.scores)
Пример #4
0
class TrainHandler(object):
    """Download the configured data sets, split them, and train the model."""

    # Data-set names from the ``train.data_sets`` config entry.
    data_sets = Config.get('train').get('data_sets', [])
    # Fraction of samples held out for validation (config default 0.1).
    holdout_percentage = Config.get('holdout_percentage', 0.1)

    @classmethod
    def handle(cls):
        """Run download -> load/split -> train."""
        cls._download()
        train_set, val_set = cls._process()
        cls._train(train_set, val_set)

    @classmethod
    def _download(cls):
        """Download every configured data set via ``RawProcessor``."""
        logger.debug('Fetching data sets: {}'.format(cls.data_sets))
        for name in cls.data_sets:
            RawProcessor.download(name)

    @classmethod
    def _process(cls):
        '''
        Load raw data and labels, split them into training and validation sets.
        :return: tuple ``(train_set, val_set)``; each is a list of
                 ``(file_name, raw_data, labels)`` tuples.
        :raises KeyError: if a loaded data file has no matching label entry.
        '''
        import random

        raw_data_map = RawProcessor.load_raw_data(cls.data_sets)
        raw_label_map = RawProcessor.load_raw_labels(cls.data_sets)

        # Sort first so the split depends only on the seed, not on the order
        # os.walk listed the files; a private RNG leaves the global
        # `random` state untouched.
        shuffled_keys = sorted(raw_data_map)
        random.Random(0).shuffle(shuffled_keys)

        split_index = int(
            round(len(shuffled_keys) * (1 - cls.holdout_percentage)))
        train_keys = shuffled_keys[:split_index]
        val_keys = shuffled_keys[split_index:]

        train_set = [(k, raw_data_map[k], raw_label_map[k])
                     for k in train_keys]
        val_set = [(k, raw_data_map[k], raw_label_map[k]) for k in val_keys]

        return train_set, val_set

    @classmethod
    def _train(cls, train_set, val_set):
        """Train the configured model on ``train_set``, validating on ``val_set``.

        :raises ValueError: if the configured model is not supported.
        """
        model_name = Config.get('model')
        if model_name != 'ssd':
            # Previously this fell through to calling .train on None.
            raise ValueError('Unsupported model: {}'.format(model_name))
        SSDModel().train(train_set, val_set)
Пример #5
0
class ModelConstants(object):
    """File names and the asset directory used by the ``ssd`` model."""

    MODEL_NAME = 'ssd'

    # Checkpoint name for the trained model.
    CHECKPOINT_TRAINED = MODEL_NAME + '_trained.ckpt'

    # Pretrained checkpoint name, and the .zip file name it is stored under.
    CHECKPOINT_PRETRAINED = MODEL_NAME + '_300_vgg.ckpt'
    CHECKPOINT_PRETRAINED_FILE = CHECKPOINT_PRETRAINED + '.zip'

    # <models_dir>/<MODEL_NAME>, resolved from the config at import time.
    FULL_ASSET_PATH = os.path.join(Config.get('models_dir'), MODEL_NAME)
Пример #6
0
    def _train(cls, train_set, val_set):
        """Train the configured model on ``train_set``, validating on ``val_set``.

        :raises ValueError: if the configured model is not supported.
        """
        model_name = Config.get('model')
        if model_name != 'ssd':
            # Without this check the call below would hit .train on None.
            raise ValueError('Unsupported model: {}'.format(model_name))
        SSDModel().train(train_set, val_set)
Пример #7
0
    def _test(cls, test_set):
        """Run the configured model over ``test_set`` and write one line per instance.

        :param test_set: sequence of tuples whose first element is the key
                         passed to ``cls._serialize``.
        :raises ValueError: if the configured model is not supported.
        """
        model_name = Config.get('model')
        if model_name != 'ssd':
            # Without this check model.test would be called on None.
            raise ValueError('Unsupported model: {}'.format(model_name))
        model = SSDModel()

        test_conf = Config.get('test')
        output_path = test_conf.get('output_path')  # a file path, not a dir
        slide_show = test_conf.get('slide_show')

        results = model.test(test_set, show=slide_show)

        # Serialize everything before opening the file so a serialization
        # error does not leave a truncated output file behind.
        json_lines = [cls._serialize(instance[0], result)
                      for instance, result in zip(test_set, results)]

        with open(output_path, 'w') as f:
            f.writelines(json_lines)
Пример #8
0
class TestHandler(object):
    """Download the configured test data sets, run the model, write results."""

    # Data-set names from the ``test.data_sets`` config entry.
    data_sets = Config.get('test').get('data_sets', [])

    @classmethod
    def handle(cls):
        """Run download -> load -> predict for the test data sets."""
        cls._download()
        test_set = cls._process()
        cls._test(test_set)

    @classmethod
    def _download(cls):
        """Download every configured test data set via ``RawProcessor``."""
        logger.debug('Fetching data sets: {}'.format(cls.data_sets))
        for name in cls.data_sets:
            RawProcessor.download(name)

    @classmethod
    def _process(cls):
        '''
        Load raw data as list of tuples.
        :return: list of ``(file_name, raw_data, None)`` tuples; no labels
                 are loaded, so the third slot is always None.
        '''
        raw_data_map = RawProcessor.load_raw_data(cls.data_sets)
        return [(k, raw_data_map[k], None) for k in raw_data_map]

    @classmethod
    def _test(cls, test_set):
        """Run the configured model over ``test_set`` and write one line per instance.

        :raises ValueError: if the configured model is not supported.
        """
        model_name = Config.get('model')
        if model_name != 'ssd':
            # Without this check model.test would be called on None.
            raise ValueError('Unsupported model: {}'.format(model_name))
        model = SSDModel()

        test_conf = Config.get('test')
        output_path = test_conf.get('output_path')  # a file path, not a dir
        slide_show = test_conf.get('slide_show')

        results = model.test(test_set, show=slide_show)

        # Serialize everything before opening the file so a serialization
        # error does not leave a truncated output file behind.
        json_lines = [cls._serialize(instance[0], result)
                      for instance, result in zip(test_set, results)]

        with open(output_path, 'w') as f:
            f.writelines(json_lines)

    @classmethod
    def _serialize(cls, key, result):
        """
        Format one ``{"<key>": <result>}`` output line.

        Neither json nor ujson can serialize ``result`` here, so it is
        embedded via ``str()``. The key is JSON-encoded so that quotes or
        backslashes in a file name still produce a valid JSON key.
        :return: the formatted line, terminated by a newline.
        """
        import json

        encoded_key = json.dumps(str(key), ensure_ascii=False)
        return '{{{}: {}}}\n'.format(encoded_key, str(result))
Пример #9
0
class Processor(object):
    """Download data sets listed under the ``datasets`` config entry."""

    dataset_conf = Config.get('datasets')

    @classmethod
    def download(cls, name):
        """Download data set ``name`` into its raw directory.

        :raises Exception: if no ``datasets`` entry has that name.
        """
        # A missing `datasets` entry is treated as empty -> "not found".
        for conf in cls.dataset_conf or []:
            if conf.get('name') == name:
                cls._download(name, conf.get('url'))
                return

        raise Exception('Data set {} not found in base.yaml'.format(name))

    @classmethod
    def get_raw_dataset_dir(cls, name):
        """Return ``<data_raw_dir>/raw/<name>``."""
        return '{}/raw/{}'.format(Config.get('data_raw_dir'), name)

    @classmethod
    def _download(cls, name, url):
        """Fetch ``url`` into the raw directory of ``name`` with wget."""
        import subprocess

        dataset_dir = cls.get_raw_dataset_dir(name)
        os.makedirs(dataset_dir, exist_ok=True)
        # Argument list instead of a shell string: URLs often contain '&',
        # which a shell would treat as "run in background". The exit status
        # is ignored, as it was with os.system.
        subprocess.call(['wget', url, '-P', dataset_dir])
Пример #10
0
 def get_raw_dataset_dir(cls, name):
     """Return the raw-data directory for data set ``name``.

     The path is ``<data_raw_dir>/raw/<name>``, joined with a literal '/'.
     """
     raw_root = Config.get('data_raw_dir')
     return '{root}/raw/{data_set}'.format(root=raw_root, data_set=name)
Пример #11
0
class TestConfig(unittest.TestCase):
    """Unit tests for ``Config``, backed by an INI fixture written in setUp."""

    def setUp(self):
        # Write a fresh fixture file next to this module for every test.
        self.current_path = os.path.abspath(os.path.dirname(__file__))
        fullpath = os.path.join(self.current_path, "test_config")

        data = [
            "[Ignore]",
            "list = user1,user2,user3,user4,user5,user6",
            "list3 = ",
            "list4 = user10",
            "",
            "[Channels]",
            "general = 000000000000000000001",
            "test = 000000000000000000002",
            "awesome = 000000000000000000003",
            "asdf = 000000000000000000004",
            "cool = 000000000000000000005",
            "voice = 000000000000000000006",
        ]

        with open(fullpath, 'w') as f:
            f.write("\n".join(data))

        self.channel_config = Config(fullpath, "Channels")
        self.ignore_config = Config(fullpath, "Ignore")

    def tearDown(self):
        # Remove every file a test may have created, including the fixture
        # written by setUp, so nothing is left behind in the source tree.
        for file_name in ("new_file", "test_config"):
            path = os.path.join(self.current_path, file_name)
            if os.path.isfile(path):
                os.remove(path)

    def test_create_file(self):
        new_file = os.path.join(self.current_path, "new_file")
        self.assertFalse(os.path.isfile(new_file))
        Config(new_file, "MySection")
        self.assertTrue(os.path.isfile(new_file))
        os.remove(new_file)

    def test_get(self):
        channel_id = self.channel_config.get("test")
        self.assertEqual(channel_id, "000000000000000000002")

    def test_get_as_list(self):
        users = self.ignore_config.get_as_list("list")
        user_list = ["user1", "user2", "user3", "user4", "user5", "user6"]
        self.assertListEqual(users, user_list)

    def test_get_all(self):
        channels = self.channel_config.get_all()
        channel_list = ["general", "test", "awesome", "asdf", "cool", "voice"]
        self.assertListEqual(channels, channel_list)

    def test_save(self):
        self.channel_config.save("newchannel", "111111111111111111")
        channel_id = self.channel_config.get("newchannel")
        self.assertEqual(channel_id, "111111111111111111")

    def test_append(self):
        self.ignore_config.append("list", "user9001")
        users = self.ignore_config.get("list")
        self.assertIn(",user9001", users)
        # Try adding again for code coverage
        self.ignore_config.append("list", "user9001")
        users = self.ignore_config.get("list")
        self.assertIn(",user9001", users)
        # Try adding it to a new option to assert option is created
        self.ignore_config.append("list2", "user9001")
        users = self.ignore_config.get("list2")
        self.assertIn("user9001", users)
        # Try adding it to an empty option
        self.ignore_config.append("list3", "user9001")
        users = self.ignore_config.get("list3")
        self.assertIn("user9001", users)

    def test_truncate(self):
        # Try an option that doesn't exist
        self.ignore_config.truncate("list2", "user3")
        does_have = self.ignore_config.has("list2")
        self.assertFalse(does_have)
        # user0 is absent from the fixture list; user1..user5 are removed one
        # at a time, leaving only user6.
        for x in range(0, 6):
            user = "user{}".format(x)
            self.ignore_config.truncate("list", user)
            users = self.ignore_config.get("list")
            self.assertNotIn(user, users, "Running " + user)

        # Test deleting the last one deletes the entire option
        self.ignore_config.truncate("list", "user6")
        with self.assertRaises(configparser.NoOptionError):
            # get raises an exception if it can't find the option. This is what we are testing.
            self.ignore_config.get("list")

    def test_delete(self):
        self.channel_config.delete("asdf")
        channels = self.channel_config.get_all()
        self.assertNotIn("asdf", channels)

    def test_has(self):
        does_contain = self.channel_config.has("awesome")
        self.assertTrue(does_contain)

    def test_contains(self):
        # Test list contains
        does_contain = self.ignore_config.contains("list", "user5")
        self.assertTrue(does_contain)
        # Test missing option doesn't contain
        does_contain = self.ignore_config.contains("list2", "user5")
        self.assertFalse(does_contain)
        # Test empty option doesn't contain
        does_contain = self.ignore_config.contains("list3", "user5")
        self.assertFalse(does_contain)
        # Test non-list contains
        does_contain = self.ignore_config.contains("list4", "user10")
        self.assertTrue(does_contain)
        # Test non-list doesn't contain
        does_contain = self.ignore_config.contains("list4", "user5")
        self.assertFalse(does_contain)
Пример #12
0
 def _get_raw_data_set_dir(cls, name):
     """Return ``<data_raw_dir>/<name>``, the directory holding data set ``name``."""
     raw_root = Config.get('data_raw_dir')
     return os.path.join(raw_root, name)
Пример #13
0
class RawProcessor(object):
    """Download raw data sets and load their data files and labels from disk.

    Data sets are described by entries of the ``data_sets`` config list and
    are looked up by their ``name`` key.
    """

    data_set_conf = Config.get('data_sets')

    @classmethod
    def download(cls, name):
        '''
        Download data set ``name`` into its raw directory unless it is cached.
        :param name: name of a data set listed in the config
        :raises Exception: if no data set with that name is configured
        '''
        import subprocess
        import zipfile

        conf = cls._get_conf(name)
        if conf is None:
            raise Exception('Data set {} not found in base.yaml'.format(name))

        data_set_dir = cls._get_raw_data_set_dir(name)
        url, compression_format = conf.get('url'), conf.get(
            'compression_format')

        logger.debug('Downloading data set: {}'.format(name))
        # skip download if data is present
        if os.path.exists(data_set_dir) and os.listdir(data_set_dir):
            logger.debug('Skip downloading, use cached files instead.')
            return

        os.makedirs(data_set_dir, exist_ok=True)
        # Argument list instead of a shell string: URLs often contain '&',
        # which a shell would treat as "run in background". The exit status
        # is ignored, as it was with os.system.
        subprocess.call(['wget', url, '-P', data_set_dir])

        if compression_format == 'zip':
            # Extract each archive on its own: `unzip a.zip b.zip` would treat
            # b.zip as a member name inside a.zip, not as a second archive.
            archives = sorted(f for f in os.listdir(data_set_dir)
                              if f.endswith('.zip'))
            for archive in archives:
                archive_path = os.path.join(data_set_dir, archive)
                with zipfile.ZipFile(archive_path) as zf:
                    zf.extractall(data_set_dir)
                # Reached only if extraction did not raise.
                os.remove(archive_path)

    @classmethod
    def load_raw_data(cls, names):
        '''
        Load raw data files with cv2.imread.
        :param names: names of data sets; unknown names are skipped
        :return: map of {file name: cv2.imread result}
        '''
        data_map = {}
        for name in names:
            conf = cls._get_conf(name)
            if conf is None:
                continue
            folder = os.path.join(cls._get_raw_data_set_dir(name),
                                  conf.get('folder_name'))
            for file_name, full_file_name in cls._get_files_generator(
                    folder, conf.get('data_format')):
                data_map[file_name] = cv2.imread(full_file_name)

        return data_map

    @classmethod
    def load_raw_labels(cls, names):
        '''
        Load raw labels into lists of [x, y, w, h, category].
        Only data sets whose ``label_format`` is 'idl' contribute labels.
        :param names: names of data sets; unknown names are skipped
        :return: map of {file name: label_list}
        '''
        label_map = {}
        for name in names:
            conf = cls._get_conf(name)
            if conf is None or conf.get('label_format') != 'idl':
                continue
            folder = os.path.join(cls._get_raw_data_set_dir(name),
                                  conf.get('folder_name'))
            # format {"60094.jpg": [[171.33312, 188.49996000000002, 243.8336, 240.66647999999998, 1]]}
            for _, full_file_name in cls._get_files_generator(folder, 'idl'):
                with open(full_file_name) as f:
                    for line in f:
                        if not line.strip():
                            continue  # tolerate blank lines
                        label_map.update(ujson.loads(line))
        return label_map

    @classmethod
    def _get_conf(cls, name):
        """Return the first ``data_sets`` config entry named ``name``, or None."""
        # A missing `data_sets` entry is treated as an empty list.
        for conf in cls.data_set_conf or []:
            if conf.get('name') == name:
                return conf
        return None

    @classmethod
    def _get_raw_data_set_dir(cls, name):
        """Return ``<data_raw_dir>/<name>``."""
        return os.path.join(Config.get('data_raw_dir'), name)

    @classmethod
    def _get_files_generator(cls, directory, extension):
        """
        Recursively yield files under ``directory`` whose name ends with ``extension``.
        :param directory: root directory to walk
        :param extension: required file-name suffix, e.g. 'idl'
        :return: a generator of tuples (file_name, full_file_name)
        """
        for dir_path, _, files in os.walk(directory):
            for f in files:
                if f.endswith(extension):
                    yield f, os.path.join(dir_path, f)