コード例 #1
0
ファイル: supplementary_clsex.py プロジェクト: KarmaGa/CTW
def main(models, n):
    """Print LaTeX table rows comparing model predictions on sampled crops.

    Loads the pickled ground-truth crops from ``settings.TEST_CLS_CROPPED``
    (calling ``predictions2html.create_pkl()`` first when
    ``common_tools.exists_and_newer`` returns a falsy value), reads every
    model's JSON-lines predictions file, draws ``n`` samples with a seed
    derived from ``n``, saves each sampled crop as a PNG under
    ``<settings.PRODUCTS_ROOT>/cls_examples`` and prints one LaTeX table row
    per sample to stdout.

    Args:
        models: iterable of dicts, each providing 'predictions_file_path'
            (a file with one JSON object per line holding 'predictions' and
            'probabilities' lists of equal length).
        n: number of samples to draw; also offsets the random seed
            (``n + 2018``), so different ``n`` give different shuffles.

    Raises:
        AssertionError: when not running on Python 3, or when a model's
            prediction count differs from the number of ground truths.
        ValueError: when a ground-truth or predicted text is missing from
            the frequency table in ``settings.STAT_FREQUENCY``.
    """
    assert six.PY3

    if not common_tools.exists_and_newer(settings.TEST_CLS_CROPPED, settings.TEST_CLASSIFICATION):
        print('creating', settings.TEST_CLS_CROPPED)
        predictions2html.create_pkl()

    # NOTE(review): unpickling can execute arbitrary code; this presumably
    # only ever reads the file produced locally by create_pkl() -- confirm.
    with open(settings.TEST_CLS_CROPPED, 'rb') as f:
        gts = cPickle.load(f)
    with open(settings.STAT_FREQUENCY) as f:
        stat_freq = json.load(f)

    preds = []
    for model in models:
        model_preds = []
        with open(model['predictions_file_path']) as f:
            for line in f:
                obj = json.loads(line.strip())
                model_preds.extend(zip(obj['predictions'], obj['probabilities']))
        # Predictions are matched to ground truths purely by position.
        assert len(gts) == len(model_preds)
        preds.append(model_preds)

    np.random.seed(n + 2018)
    sampled = np.arange(len(gts))
    np.random.shuffle(sampled)
    sampled = sampled[:n]

    dir_name = 'cls_examples'
    root = os.path.join(settings.PRODUCTS_ROOT, dir_name)
    os.makedirs(root, exist_ok=True)

    # Build the text -> index table once instead of rescanning the whole
    # frequency list on every lookup. setdefault keeps the FIRST index of a
    # duplicated text, matching list.index().
    text2idx = {}
    for idx, entry in enumerate(stat_freq):
        text2idx.setdefault(entry['text'], idx)

    def text2minipage(text):
        """Return the LaTeX minipage that includes the glyph image of *text*."""
        try:
            i = text2idx[text]
        except KeyError:
            # Keep the exception type list.index() raised for a missing text.
            raise ValueError('{!r} is not in list'.format(text))
        return r'\begin{minipage}{3.5mm} \includegraphics[width=\linewidth]{figure/texts/' + '0_{}.png'.format(i) + r'} \end{minipage}'

    last_model = len(preds) - 1
    for no, i in enumerate(sampled):
        file_name = '{}.png'.format(i)
        image, gt = gts[i]
        # NOTE(review): scipy.misc.toimage was deprecated in SciPy 1.0 and
        # removed in 1.2; this requires an older SciPy.
        image = misc.toimage(image)
        image.save(os.path.join(root, file_name), format='png')

        cells = [
            '{} & '.format(no + 1) + r'\begin{minipage}{6.0mm} \includegraphics[width=\linewidth]{figure/cls_examples/' + file_name + r'} \end{minipage} &',
            '{} &'.format(text2minipage(gt['text'])),
        ]

        for j, preds_model in enumerate(preds):
            texts, probs = preds_model[i]
            # First probability as a percentage with one decimal, padded to
            # width 5; padding spaces become LaTeX thin spaces for alignment.
            prob_text = '{:5.1f}'.format(round(probs[0] * 1000) / 10.)
            prob_text = prob_text.replace(' ', r'\,\,\,')
            col = '{} {}'.format(text2minipage(texts[0]), prob_text)
            if texts[0] == gt['text']:
                # Highlight cells whose first prediction matches the ground truth.
                col = r'\multicolumn{1}{>{\columncolor{cls_correct}}r}{' + col + '}'
            # The last column ends the table row; the others end the cell.
            col += r' \\' if j == last_model else ' &'
            cells.append(col)
        print(' '.join(cells).strip())
コード例 #2
0
def main():
    """Write the darknet training files and crop the training images.

    Calls the three ``write_darknet_*`` writers in order, asserts that the
    pretrained weights file at ``settings.DARKNET_PRETRAIN`` exists and has
    the expected byte size, then runs ``crop_train_images()`` unless
    ``common_tools.exists_and_newer(settings.DARKNET_TRAIN_LIST,
    settings.CATES)`` returns a truthy value.
    """
    write_darknet_data()
    write_darknet_cfg()
    write_darknet_names()

    # The weights are validated only by presence and exact file size.
    assert (
        os.path.isfile(settings.DARKNET_PRETRAIN)
        and 79327120 == os.path.getsize(settings.DARKNET_PRETRAIN)
    ), 'please download {} to {}'.format(
        'https://pjreddie.com/media/files/darknet19_448.conv.23',
        settings.DARKNET_PRETRAIN)

    train_list_ready = common_tools.exists_and_newer(
        settings.DARKNET_TRAIN_LIST, settings.CATES)
    if train_list_ready:
        return
    crop_train_images()
コード例 #3
0
def main():
    """Write the darknet test cfg and crop the validation images if needed.

    ``crop_test_images`` is run on ``settings.DARKNET_VALID_LIST`` unless
    ``common_tools.exists_and_newer`` returns a truthy value for that list
    relative to ``settings.CATES``.
    """
    write_darknet_test_cfg()

    valid_list_ready = common_tools.exists_and_newer(
        settings.DARKNET_VALID_LIST, settings.CATES)
    if valid_list_ready:
        return
    crop_test_images(settings.DARKNET_VALID_LIST)
コード例 #4
0
def main():
    """Crop the test images using the detection package's preparation script.

    Nothing is done when ``common_tools.exists_and_newer(settings.TEST_LIST,
    settings.CATES)`` returns a truthy value. Otherwise the script at
    ``../detection/prepare_test_data.py`` is loaded as module ``dn_prepare``
    and its ``crop_test_images`` is called with ``settings.TEST_LIST``.
    """
    if common_tools.exists_and_newer(settings.TEST_LIST, settings.CATES):
        return

    # NOTE(review): the path is relative, so it is resolved against the
    # current working directory. `imp` is deprecated since Python 3.4 and
    # removed in 3.12.
    script_path = '../detection/prepare_test_data.py'
    dn_prepare = imp.load_source('dn_prepare', script_path)
    dn_prepare.crop_test_images(settings.TEST_LIST)
コード例 #5
0
def main(models, n):
    """Render sampled classification predictions as HTML and build LaTeX rows.

    Loads the pickled ground-truth crops from ``settings.TEST_CLS_CROPPED``
    (calling ``create_pkl()`` first when ``common_tools.exists_and_newer``
    returns a falsy value), reads every model's JSON-lines predictions file,
    picks ``n`` samples from a fixed-seed shuffle, renders them through
    ``predictions_compare.template.html`` into ``settings.PREDICTIONS_HTML``,
    then saves each sampled crop as a PNG under
    ``<settings.PRODUCTS_ROOT>/cls_examples`` while assembling a LaTeX table
    row for it.

    Args:
        models: sequence of dicts, each providing 'predictions_file_path'
            and 'display_name'. It is iterated twice, so it must be
            re-iterable.
        n: number of samples to take from the shuffled indices.
    """
    # This function is Python 3 only.
    assert six.PY3

    if not common_tools.exists_and_newer(settings.TEST_CLS_CROPPED,
                                         settings.TEST_CLASSIFICATION):
        print('creating', settings.TEST_CLS_CROPPED)
        create_pkl()

    # NOTE(review): unpickling can execute arbitrary code; presumably this
    # only reads the file produced by create_pkl() -- confirm.
    with open(settings.TEST_CLS_CROPPED, 'rb') as f:
        gts = cPickle.load(f)
    preds = []
    for model in models:
        # NOTE(review): `all` shadows the builtin of the same name here.
        all = []
        with open(model['predictions_file_path']) as f:
            for line in f:
                # One JSON object per line with parallel 'predictions' and
                # 'probabilities' lists.
                obj = json.loads(line.strip())
                all += list(zip(obj['predictions'], obj['probabilities']))
        # Predictions are matched to ground truths purely by position.
        assert len(gts) == len(all)
        preds.append(all)

    # Fixed seed, so the same indices are sampled on every run.
    np.random.seed(0)
    sampled = np.array(range(len(gts)))
    np.random.shuffle(sampled)
    # Take n indices starting at offset 105 of the shuffled order.
    # NOTE(review): the reason for the 105 offset is not visible here.
    sampled = sampled[105:105 + n]

    with open('predictions_compare.template.html') as f:
        template = Template(f.read())
    rows = []
    for i in sampled:
        row = dict()
        image, gt = gts[i]
        image = misc.toimage(image)
        # Encode the crop as an in-memory PNG, then base64 it so the image
        # data can be passed to the template as ASCII text.
        bytesio = six.BytesIO()
        image.save(bytesio, format='png')
        png_base64 = base64.b64encode(bytesio.getvalue())

        row['png_base64'] = png_base64.decode('ascii')
        row['ground_truth'] = gt['text']
        row_models = []
        for preds_model in preds:
            texts, probs = preds_model[i]
            # First probability as a percentage with one decimal, padded to
            # width 5; double spaces from the padding are collapsed.
            prob_text = '{:5.1f}%'.format(round(probs[0] * 1000) / 10.)
            prob_text = prob_text.replace('  ', ' ')
            row_models.append({
                'text': texts[0],
                'prob': prob_text,
            })
        row['models'] = row_models
        rows.append(row)
    with open(settings.PREDICTIONS_HTML, 'w') as f:
        f.write(
            template.render({
                'models':
                list(map(operator.itemgetter('display_name'), models)),
                'rows':
                rows,
            }))

    with open(settings.STAT_FREQUENCY) as f:
        freq = json.load(f)
    # Map each text to its position in the frequency list; for a duplicated
    # text the last position wins.
    text2idx = {o['text']: i for i, o in enumerate(freq)}

    dir_name = 'cls_examples'
    root = os.path.join(settings.PRODUCTS_ROOT, dir_name)
    if not os.path.isdir(root):
        os.makedirs(root)

    def text2minipage(text):
        """Return the LaTeX minipage that includes the image for *text*."""
        return r'\begin{minipage}{3mm} \includegraphics[width=\linewidth]{figure/texts/' \
            + '0_{}.png'.format(text2idx[text]) + r'} \end{minipage}'

    for no, i in enumerate(sampled):
        file_name = '{}.png'.format(i)
        image, gt = gts[i]
        image = misc.toimage(image)
        image.save(os.path.join(root, file_name), format='png')

        # NOTE(review): `s` (and `no`) are not used within the visible part
        # of this function -- confirm whether the row is printed or written
        # after this point.
        s = r'\begin{minipage}{6.0mm} \includegraphics[width=\linewidth]{figure/cls_examples/' + '{}.png'.format(
            i) + r'} \end{minipage} &' + '\n'
        s += '{} &\n'.format(text2minipage(gt['text']))

        for j, preds_model in enumerate(preds):
            texts, probs = preds_model[i]
            # Padding spaces become LaTeX thin spaces.
            prob_text = '{:5.1f}'.format(round(probs[0] * 1000) / 10.)
            prob_text = prob_text.replace(' ', r'\,\,\,')
            col = '{} {}'.format(text2minipage(texts[0]), prob_text)
            if texts[0] == gt['text']:
                # Highlight cells whose first prediction equals the ground truth.
                col = r'\multicolumn{1}{>{\columncolor{cls_correct}}r}{' + col + '}'
            # The last model's column ends the table row; others end the cell.
            if j == len(preds) - 1:
                col += r' \\'
            else:
                col += ' &'
            s += col + '\n'