def _open_csv_for_write(csv_path):
    """Open *csv_path* for writing in the mode the csv module expects.

    Python 2's csv module requires a binary-mode file, while Python 3's
    requires a text-mode file opened with newline='' (a binary file makes
    csv.writer raise TypeError on the first writerow).
    """
    import sys
    if sys.version_info[0] >= 3:
        return open(csv_path, 'w', newline='')
    return open(csv_path, 'wb')


def generate_iou_data(queries, path, if_data, use_geometric=False,
                      iqd_params=None, pred_weight=None):
    """Write one CSV of model-box IoU values per query into *path*.

    For every query, each image in that query's partial-match list (from
    data_utils.get_partial_query_matches) is wrapped in an ImageQueryData
    and iqd.compute_potential_data() is called. The model-selected subject
    and object box ids and their IoUs are then recorded. Each query's rows
    go to 'q{query_id:03d}_iou_values.csv' with header
    (image_ix, obj_ix, iou). Every image contributes two rows: the subject
    row first, then the object row.

    Args:
        queries: sequence of query structures; the position in the sequence
            is used as the query id.
        path: output directory for the CSV files.
        if_data: data object exposing ``vg_data``; forwarded to ImageQueryData.
        use_geometric: forwarded to ImageQueryData.
        iqd_params: optional dict of extra keyword arguments for
            ImageQueryData (defaults to none).
        pred_weight: forwarded to ImageQueryData.
    """
    if iqd_params is None:
        iqd_params = {}
    # The match lists depend only on the full query set, so compute them
    # once instead of once per query.
    tp_simple = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    for query_id, query in tqdm(enumerate(queries), total=len(queries)):
        rows = []
        for image_id in tqdm(tp_simple[query_id]):
            iqd = ImageQueryData(query, query_id, image_id, if_data,
                                 use_geometric=use_geometric,
                                 pred_weight=pred_weight, **iqd_params)
            iqd.compute_potential_data()
            # Record the real image index (not a running row counter) so the
            # 'image_ix' column identifies the image each row came from.
            rows.append((image_id, iqd.model_sub_bbox_id, iqd.sub_model_iou))
            rows.append((image_id, iqd.model_obj_bbox_id, iqd.obj_model_iou))
        iou_name = 'q{:03d}_iou_values.csv'.format(query_id)
        iou_path = os.path.join(path, iou_name)
        with _open_csv_for_write(iou_path) as f:
            csv_writer = csv.writer(f)
            csv_writer.writerow(('image_ix', 'obj_ix', 'iou'))
            csv_writer.writerows(rows)
def save_query_texts(queries, if_data, path):
    """Write a human-readable line for each query to the file at *path*.

    Each line has the form ``[<query index>] (count: <n>): <query text>``.
    Here ``n`` is the length of the query's list from
    data_utils.get_partial_query_matches, and the text comes from
    ImageQueryData.get_query_text().
    """
    matches = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    with open(path, 'w') as out_file:
        for query_ix, (query, matched_images) in enumerate(zip(queries, matches)):
            query_data = query_viz.ImageQueryData(query, query_ix, 0, if_data)
            text = query_data.get_query_text()
            line = '[{}] (count: {}): {}\n'.format(
                query_ix, len(matched_images), text)
            out_file.write(line)
def iou_check(queries, if_data):
    """Run data_utils.get_iou_recall_values on the two IoU output directories.

    Compares the 'query_ious/' directory (labelled 'factor graph') with the
    'query_ious_geom/' directory (labelled 'geometric mean'), both under the
    module-level ``out_path``. The call is made with show_plot=True.
    """
    matches = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    labelled_dirs = [
        (os.path.join(out_path, subdir), label)
        for subdir, label in (('query_ious/', 'factor graph'),
                              ('query_ious_geom/', 'geometric mean'))
    ]
    num_images = len(if_data.vg_data)
    data_utils.get_iou_recall_values(labelled_dirs, matches, num_images,
                                     show_plot=True)
def _export_image_and_json(if_data, image_index, dest_dir, json_data):
    """Copy image *image_index* and its JSON record into *dest_dir*."""
    # configure() is called before reading image_filename; presumably it
    # points if_data at this image -- confirm against if_data's class.
    if_data.configure(image_index, None)
    shutil.copy(if_data.image_filename, dest_dir)
    image_name = os.path.basename(if_data.image_filename)
    json_name = os.path.splitext(image_name)[0] + '.json'
    with open(os.path.join(dest_dir, json_name), 'w') as f:
        json.dump(json_data[image_index], f)


def export_images(queries, if_data, false_negs, path, json_data):
    """Export each query's positive images and the all-negative images.

    The positives for each query are its partial matches from
    data_utils.get_partial_query_matches, plus the (query_index,
    image_index) pairs in *false_negs*.
    - Each query's positives are copied into ``path/<query text with spaces
      replaced by underscores>/``.
    - Images that are positive for no query are copied into
      ``path/negative_images/``.
    Next to every copied image, ``json_data[image_index]`` is written as a
    ``.json`` file with the same base name.

    Args:
        queries: sequence of query structures.
        if_data: data object exposing ``vg_data``, ``configure`` and
            ``image_filename``.
        false_negs: iterable of (query_index, image_index) pairs to add as
            extra positives.
        path: root output directory.
        json_data: indexable by image index; each entry must be
            JSON-serializable.
    """
    tp_simple = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    for query_index, image_index in false_negs:
        tp_simple[query_index].append(image_index)

    pairs = enumerate(zip(queries, tp_simple))
    for index, (query, image_indexes) in tqdm(pairs, desc='queries',
                                               total=len(queries)):
        iqd = ImageQueryData(query, index, 0, if_data)
        query_dir = iqd.get_query_text().replace(' ', '_')
        query_path = os.path.join(path, query_dir)
        if not os.path.exists(query_path):
            os.makedirs(query_path)
        for image_index in image_indexes:
            _export_image_and_json(if_data, image_index, query_path, json_data)

    neg_path = os.path.join(path, 'negative_images')
    if not os.path.exists(neg_path):
        os.makedirs(neg_path)
    # Union of all positives, built once, so each image's membership test is
    # O(1) instead of scanning every query's list.
    all_positives = set()
    for image_indexes in tp_simple:
        all_positives.update(image_indexes)
    full_negs = [index for index in range(len(if_data.vg_data))
                 if index not in all_positives]
    for image_index in full_negs:
        _export_image_and_json(if_data, image_index, neg_path, json_data)
def recall_check(queries, if_data, false_negs, situate_recalls):
    """Compute and plot recall for each energy variant over the given queries.

    The positives for each query are its partial matches plus the
    (query_index, image_index) pairs in *false_negs*. Five energy
    directories under the module-level ``out_path`` are passed, each with a
    label, to data_utils.get_single_image_recall_values. The call is made
    with show_plot=True, and *situate_recalls* is forwarded unchanged.
    """
    matches = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    for q_ix, img_ix in false_negs:
        matches[q_ix].append(img_ix)

    # (subdirectory under out_path, label) for each energy variant compared.
    variants = (
        ('query_energies_psu/', 'vanilla IRSG'),
        ('query_energies_geom_psu/', 'geometric mean on potentials'),
        ('query_energies_true_geom_psu/', 'true geometric mean'),
        ('query_energies_weighted_psu/', 'weighted IRSG'),
        ('query_energies_rcnn_psu/', 'RCNN-weighted IRSG'),
    )
    labelled_dirs = [(os.path.join(out_path, subdir), label)
                     for subdir, label in variants]
    data_utils.get_single_image_recall_values(
        labelled_dirs, matches, len(if_data.vg_data),
        show_plot=True, situate_recalls=situate_recalls)
def save_query_counts(queries, if_data, false_negs, path):
    """Write per-query positive/negative counts and positive lists to *path*.

    The positives for each query are its partial matches from
    data_utils.get_partial_query_matches, plus the (query_index,
    image_index) pairs in *false_negs*. Each query produces:
        '<query text> (<pos> pos, <neg> neg)'
        'positives: <raw positive list>'
        'duplicates: <indexes that appear more than once>'
    followed by a blank line. <pos> counts unique images, and
    <neg> = len(if_data.vg_data) - <pos>. The last line gives the number of
    images that are positive for no query.
    """
    import collections

    tp_simple = data_utils.get_partial_query_matches(if_data.vg_data, queries)
    for query_index, image_index in false_negs:
        tp_simple[query_index].append(image_index)
    num_images = len(if_data.vg_data)

    with open(path, 'w') as f:
        for index, (query, image_indexes) in enumerate(zip(queries, tp_simple)):
            iqd = ImageQueryData(query, index, 0, if_data)
            # Positive lists can contain repeats (e.g. a false_negs entry for
            # an image that was already matched). Count unique images so
            # that pos + neg equals the total number of images.
            num_pos = len(set(image_indexes))
            f.write('{} ({} pos, {} neg)\n'.format(
                iqd.get_query_text(), num_pos, num_images - num_pos))
            f.write('positives: {}\n'.format(image_indexes))
            counts = collections.Counter(image_indexes)
            # Report each repeated index once, in first-occurrence order.
            duplicates = [
                ix for ix in collections.OrderedDict.fromkeys(image_indexes)
                if counts[ix] > 1
            ]
            f.write('duplicates: {}\n\n'.format(duplicates))

        # Union of all positives, built once, replaces a per-image scan of
        # every query's list.
        all_positives = set()
        for image_indexes in tp_simple:
            all_positives.update(image_indexes)
        num_full_negs = sum(1 for ix in range(num_images)
                            if ix not in all_positives)
        f.write('negatives for ALL queries: {}'.format(num_full_negs))
parts = [part.replace('_', ' ') for part in line.split()] text_parts, gen_func = parts[:-1], gen_dict[parts[-1]] query_struct.annotations = gen_func( *text_parts, use_attrs=use_attrs) queries.append(query_struct) return queries if __name__ == '__main__': _, _, _, _, _, if_data = dp.get_all_data( 'stanford', split='test', use_csv=True) path = os.path.join(data_path, 'queries.txt') queries = generate_queries_from_file(path) false_neg_path = os.path.join(data_path, 'false_negs.csv') false_negs = data_utils.get_false_negs(false_neg_path) tp_data_pos = data_utils.get_partial_query_matches( if_data.vg_data, queries) for query_index, image_index in false_negs: tp_data_pos[query_index].append(image_index) tp_data_neg = generate_tp_neg(tp_data_pos, len(queries), len(if_data.vg_data), NEGS_PER_QUERY) generate_all_query_plots(queries, tp_data_pos, tp_data_neg, if_data, condition_gmm=True, visualize_gmm=False, false_negs=false_negs) generate_all_query_plots(queries, tp_data_pos, tp_data_neg, if_data, condition_gmm=True, visualize_gmm=False, use_geometric=True, suffix='true_geom', false_negs=false_negs) generate_all_query_plots(queries, tp_data_pos, tp_data_neg, if_data, condition_gmm=True, visualize_gmm=False, use_alt_geom=True, suffix='geom', false_negs=false_negs)