def generate_iou_data(queries,
                      path,
                      if_data,
                      use_geometric=False,
                      iqd_params=None,
                      pred_weight=None):
    """Write one CSV of model-box IoU values per query.

    For every query, each image listed for it by
    ``data_utils.get_partial_query_matches`` is run through
    ``ImageQueryData.compute_potential_data``. The subject and object box
    ids chosen by the model, and their IoUs, are written to
    ``<path>/qNNN_iou_values.csv`` with header ``image_ix, obj_ix, iou``.

    Args:
        queries: sequence of query structures.
        path: output directory for the CSV files.
        if_data: dataset handle; its ``vg_data`` is used to find matches.
        use_geometric: forwarded to ``ImageQueryData``.
        iqd_params: optional dict of extra keyword arguments for
            ``ImageQueryData``; ``None`` means no extras.
        pred_weight: forwarded to ``ImageQueryData``.
    """
    # The match table depends only on the dataset and the full query list,
    # so compute it once rather than once per query.
    tp_simple = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    if iqd_params is None:
        iqd_params = {}

    for query_id, query in tqdm(enumerate(queries), total=len(queries)):
        object_id_list = []
        iou_list = []
        for image_id in tqdm(tp_simple[query_id]):
            iqd = ImageQueryData(query,
                                 query_id,
                                 image_id,
                                 if_data,
                                 use_geometric=use_geometric,
                                 pred_weight=pred_weight,
                                 **iqd_params)
            iqd.compute_potential_data()
            # Two rows per image: subject first, then object.
            object_id_list.extend(
                [iqd.model_sub_bbox_id, iqd.model_obj_bbox_id])
            iou_list.extend([iqd.sub_model_iou, iqd.obj_model_iou])

        iou_name = 'q{:03d}_iou_values.csv'.format(query_id)
        iou_path = os.path.join(path, iou_name)
        with _open_csv_for_write(iou_path) as f:
            csv_writer = csv.writer(f)
            csv_writer.writerow(('image_ix', 'obj_ix', 'iou'))
            # NOTE(review): the first column holds the running row index
            # (two rows per image), not the image id, despite the header
            # name -- confirm downstream readers expect this.
            for row_ix, (object_id,
                         iou) in enumerate(zip(object_id_list, iou_list)):
                csv_writer.writerow((row_ix, object_id, iou))


def _open_csv_for_write(csv_path):
    """Open *csv_path* for ``csv.writer`` on both Python 2 and Python 3.

    Python 2's csv module needs a binary file. Python 3's needs a text file
    opened with ``newline=''``; passing a binary file raises ``TypeError``.
    """
    import sys

    if sys.version_info[0] >= 3:
        return open(csv_path, 'w', newline='')
    return open(csv_path, 'wb')
# Example #2
def save_query_texts(queries, if_data, path):
    """Write a numbered listing of the queries and their match counts.

    For query ``i``, one line ``[i] (count: n): <text>`` is written to
    *path*. Here ``n`` is the length of the ``i``-th entry returned by
    ``data_utils.get_partial_query_matches``, and ``<text>`` comes from
    ``ImageQueryData.get_query_text`` (built with image index 0).
    """
    matches_per_query = data_utils.get_partial_query_matches(
        if_data.vg_data, queries)
    line_template = '[{}] (count: {}): {}\n'

    with open(path, 'w') as out_file:
        pairs = zip(queries, matches_per_query)
        for query_ix, (query, matched) in enumerate(pairs):
            query_data = query_viz.ImageQueryData(query, query_ix, 0, if_data)
            match_count = len(matched)
            out_file.write(line_template.format(
                query_ix, match_count, query_data.get_query_text()))
# Example #3
def iou_check(queries, if_data):
    """Plot IoU recall for the factor-graph and geometric-mean IoU outputs.

    The per-query IoU directories are resolved under the module-level
    ``out_path`` and passed, with their plot labels, to
    ``data_utils.get_iou_recall_values`` with plotting enabled.
    """
    matches = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    variants = (('query_ious/', 'factor graph'),
                ('query_ious_geom/', 'geometric mean'))
    sources = [(os.path.join(out_path, subdir), label)
               for subdir, label in variants]
    image_count = len(if_data.vg_data)
    data_utils.get_iou_recall_values(sources, matches, image_count,
                                     show_plot=True)
# Example #4
def export_images(queries, if_data, false_negs, path, json_data):
    """Copy images and their JSON annotations into per-query folders.

    Every image matched to a query (from
    ``data_utils.get_partial_query_matches`` plus the extra
    ``(query_index, image_index)`` pairs in *false_negs*) is copied into
    ``<path>/<query text with spaces as underscores>/``. Images that match
    no query at all are copied into ``<path>/negative_images/``. Next to
    each copied image, ``json_data[image_index]`` is dumped to a file with
    the image's base name and a ``.json`` extension.

    Args:
        queries: sequence of query structures.
        if_data: dataset handle; ``configure(image_index, None)`` is called
            before reading ``image_filename`` for that image.
        false_negs: iterable of ``(query_index, image_index)`` pairs to add
            as positives.
        path: output root directory.
        json_data: indexable by image index; each entry must be
            JSON-serialisable.
    """
    tp_simple = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    for query_index, image_index in false_negs:
        tp_simple[query_index].append(image_index)

    pairs = enumerate(zip(queries, tp_simple))
    for index, (query, image_indexes) in tqdm(pairs,
                                              desc='queries',
                                              total=len(queries)):
        iqd = ImageQueryData(query, index, 0, if_data)
        query_dir = iqd.get_query_text().replace(' ', '_')
        query_path = os.path.join(path, query_dir)
        if not os.path.exists(query_path):
            os.makedirs(query_path)
        for image_index in image_indexes:
            _copy_image_and_json(if_data, image_index, query_path, json_data)

    neg_path = os.path.join(path, 'negative_images')
    if not os.path.exists(neg_path):
        os.makedirs(neg_path)

    # Build the union of positives once so each membership test is O(1),
    # instead of rescanning every query's list for every image.
    positives = set()
    for image_indexes in tp_simple:
        positives.update(image_indexes)
    full_negs = [
        index for index in range(len(if_data.vg_data))
        if index not in positives
    ]

    for image_index in full_negs:
        _copy_image_and_json(if_data, image_index, neg_path, json_data)


def _copy_image_and_json(if_data, image_index, dest_dir, json_data):
    """Copy one image into *dest_dir* and write its JSON entry beside it."""
    # Presumably selects the current image so image_filename refers to it.
    if_data.configure(image_index, None)
    shutil.copy(if_data.image_filename, dest_dir)
    image_name = os.path.basename(if_data.image_filename)
    json_name = os.path.splitext(image_name)[0] + '.json'
    with open(os.path.join(dest_dir, json_name), 'w') as f:
        json.dump(json_data[image_index], f)
# Example #5
def recall_check(queries, if_data, false_negs, situate_recalls):
    """Compute recalls for a given set of queries and plot.

    The positive image lists from ``data_utils.get_partial_query_matches``
    are extended with the ``(query_index, image_index)`` pairs in
    *false_negs*. Five energy-output directories under the module-level
    ``out_path`` are then passed, with plot labels, to
    ``data_utils.get_single_image_recall_values`` with plotting enabled.
    """
    matches = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    for q_ix, img_ix in false_negs:
        matches[q_ix].append(img_ix)

    variants = (
        ('query_energies_psu/', 'vanilla IRSG'),
        ('query_energies_geom_psu/', 'geometric mean on potentials'),
        ('query_energies_true_geom_psu/', 'true geometric mean'),
        ('query_energies_weighted_psu/', 'weighted IRSG'),
        ('query_energies_rcnn_psu/', 'RCNN-weighted IRSG'),
    )
    sources = [(os.path.join(out_path, subdir), label)
               for subdir, label in variants]

    image_count = len(if_data.vg_data)
    data_utils.get_single_image_recall_values(sources,
                                              matches,
                                              image_count,
                                              show_plot=True,
                                              situate_recalls=situate_recalls)
# Example #6
def save_query_counts(queries, if_data, false_negs, path):
    """Save the counts of the given queries to a file.

    For each query, this writes its text with positive/negative counts, the
    positive image list, and any image indexes listed more than once. The
    positives come from ``data_utils.get_partial_query_matches`` plus the
    ``(query_index, image_index)`` pairs in *false_negs*. A final line gives
    the number of images that are positive for no query.

    Args:
        queries: sequence of query structures.
        if_data: dataset handle; ``len(if_data.vg_data)`` is the image count.
        false_negs: iterable of ``(query_index, image_index)`` pairs to add
            as positives.
        path: output text file path.
    """
    import collections

    tp_simple = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    for query_index, image_index in false_negs:
        tp_simple[query_index].append(image_index)

    num_images = len(if_data.vg_data)

    # Computed once, before the loop. Previously this was computed only
    # inside it, so an empty `queries` raised NameError on the final write.
    # A set union keeps membership tests O(1).
    positives = set()
    for image_indexes in tp_simple:
        positives.update(image_indexes)
    full_negs = [
        image_id for image_id in range(num_images)
        if image_id not in positives
    ]

    with open(path, 'w') as f:
        for query_id, (query, image_indexes) in enumerate(
                zip(queries, tp_simple)):
            iqd = ImageQueryData(query, query_id, 0, if_data)
            f.write('{} ({} pos, {} neg)\n'.format(
                iqd.get_query_text(), len(image_indexes),
                num_images - len(image_indexes)))
            f.write('positives: {}\n'.format(image_indexes))
            # Counter is O(K); list.count in a comprehension was O(K^2).
            counts = collections.Counter(image_indexes)
            duplicates = list(
                {image_id for image_id, count in counts.items() if count > 1})
            f.write('duplicates: {}\n\n'.format(duplicates))
        f.write('negatives for ALL queries: {}'.format(len(full_negs)))
# Example #7
            parts = [part.replace('_', ' ') for part in line.split()]
            text_parts, gen_func = parts[:-1], gen_dict[parts[-1]]
            query_struct.annotations = gen_func(
                *text_parts, use_attrs=use_attrs)
            queries.append(query_struct)
    return queries


if __name__ == '__main__':
    # Only the dataset handle (last of six returned values) is needed here.
    _, _, _, _, _, if_data = dp.get_all_data(
        'stanford', split='test', use_csv=True)

    path = os.path.join(data_path, 'queries.txt')
    queries = generate_queries_from_file(path)

    false_neg_path = os.path.join(data_path, 'false_negs.csv')
    false_negs = data_utils.get_false_negs(false_neg_path)

    # Positives per query, extended with the hand-labelled false negatives.
    tp_data_pos = data_utils.get_partial_query_matches(
        if_data.vg_data, queries)
    for fn_query_ix, fn_image_ix in false_negs:
        tp_data_pos[fn_query_ix].append(fn_image_ix)

    tp_data_neg = generate_tp_neg(tp_data_pos, len(queries),
                                  len(if_data.vg_data), NEGS_PER_QUERY)

    # One plotting pass per energy variant: default, true geometric mean,
    # and the alternative geometric mean.
    plot_variants = (
        {},
        {'use_geometric': True, 'suffix': 'true_geom'},
        {'use_alt_geom': True, 'suffix': 'geom'},
    )
    for variant_kwargs in plot_variants:
        generate_all_query_plots(queries, tp_data_pos, tp_data_neg,
                                 if_data, condition_gmm=True,
                                 visualize_gmm=False, false_negs=false_negs,
                                 **variant_kwargs)