示例#1
0
文件: exp.py 项目: yunshengb/GraphSim
def exp5():
    """Render query-visualization figures for a fixed dataset/model pair.

    For each of the first 10 test graphs, draws the query next to 7 training
    graphs taken from row ``i`` of the model's sort-id matrix: the first 3,
    the one at position ``int(n / 2)`` and the last 3. Each graph gets a text
    label from ``get_text_label``. PNG and EPS save paths are placed in
    ``info_dict`` before calling ``vis``. This is repeated for
    ``norm`` in (True, False).
    """
    dataset = 'imdbmulti'
    model = 'astar'
    concise = True
    norms = [True, False]
    out_dir = get_result_path() + '/{}/query_vis/{}'.format(dataset, model)
    create_dir_if_not_exists(out_dir)

    is_linux = dataset == 'linux'
    info_dict = {
        # Node drawing.
        'draw_node_size': 10 if is_linux else 150,
        'draw_node_label_enable': True,
        'node_label_name': None if is_linux else 'type',
        'draw_node_label_font_size': 6,
        'draw_node_color_map': TYPE_COLOR_MAP,
        # Edge drawing.
        'draw_edge_label_enable': False,
        'edge_label_name': 'valence',
        'draw_edge_label_font_size': 6,
        # Per-graph caption text (filled in per query below).
        'each_graph_text_list': [],
        'each_graph_text_font_size': 8,
        'each_graph_text_pos': [0.5, 1.05],
        # Padding, as fractions in [0, 1].
        'top_space': 0.20 if concise else 0.26,  # of the whole figure
        'bottom_space': 0.05,
        'hbetween_space': 0.6 if concise else 1,  # of one subplot
        'wbetween_space': 0,
        # Output.
        'plot_dpi': 200,
        'plot_save_path_eps': '',
        'plot_save_path_png': ''
    }

    train_data = load_data(dataset, train=True)
    test_data = load_data(dataset, train=False)
    row_graphs = test_data.graphs
    col_graphs = train_data.graphs
    r = load_result(dataset, model,
                    row_graphs=row_graphs, col_graphs=col_graphs)
    tr = load_result(dataset, TRUE_MODEL,
                     row_graphs=row_graphs, col_graphs=col_graphs)

    num_vis = 10
    save_targets = (('plot_save_path_png', 'png'), ('plot_save_path_eps', 'eps'))
    for norm in norms:
        sort_ids = r.get_sort_id_mat(norm)
        _, n = r.m_n()
        for qid in range(num_vis):
            query = test_data.graphs[qid]
            ranked = sort_ids[qid]
            gids = np.concatenate(
                [ranked[:3], [ranked[int(n / 2)]], ranked[-3:]])
            gs = [train_data.graphs[gid] for gid in gids]

            # The query's own label comes first, then one per selected graph.
            labels = [get_text_label(dataset, r, tr, qid, qid, query,
                                     model, norm, True, concise)]
            labels.extend(
                get_text_label(dataset, r, tr, qid, gid, train_data.graphs[gid],
                               model, norm, False, concise)
                for gid in gids)
            info_dict['each_graph_text_list'] = labels

            for key, ext in save_targets:
                info_dict[key] = '{}/query_vis_{}_{}_{}{}.{}'.format(
                    out_dir, dataset, model, qid, get_norm_str(norm), ext)
            vis(query, gs, info_dict)
示例#2
0
def get_gs_ds_mat(gs1, gs2, dist_sim_calculator, tvt1, tvt2,
                  dataset, dist_metric, dist_algo, norm, dec_gsize, return_neg1=False):
    """Return the pairwise distance matrix between ``gs1`` and ``gs2``.

    The result is cached under ``<save_path>/ds_mat``. The cache file name
    encodes the split labels and sizes, the dataset, the metric, the norm flag
    and the algorithm. On a cache hit the loaded object is returned directly.
    Otherwise every (g1, g2) pair is scored with
    ``dist_sim_calculator.calculate_dist_sim``, and the matrix is saved and
    returned.

    Entry ``[i][j]`` holds the normalized distance when ``norm`` is truthy,
    and the raw distance otherwise.
    """
    mat_str = '{}({})_{}({})'.format(tvt1, len(gs1), tvt2, len(gs2))
    cache_dir = '{}/ds_mat'.format(get_save_path())
    create_dir_if_not_exists(cache_dir)
    sfn = '{}/{}_{}_ds_mat_{}{}_{}'.format(
        cache_dir, dataset, mat_str, dist_metric, get_norm_str(norm), dist_algo)

    cached = load(sfn)
    if cached is not None:
        print('Loaded from {}'.format(sfn))
        return cached

    dist_mat = np.zeros((len(gs1), len(gs2)))
    for row, g1 in enumerate(gs1):
        for col, g2 in enumerate(gs2):
            raw_d, normed_d = dist_sim_calculator.calculate_dist_sim(
                g1, g2, dec_gsize=dec_gsize, return_neg1=return_neg1)
            dist_mat[row][col] = normed_d if norm else raw_d

    save(sfn, dist_mat)
    print('Saved to {}'.format(sfn))
    return dist_mat
示例#3
0
文件: exp.py 项目: yunshengb/GraphSim
def plot_preck_helper(dataset, dsmetric, models, rs, true_result, metric, norm, ks,
                      logscale, plot_results, extra_dir):
    """Compute precision@k for each model and optionally plot the curves.

    Args:
        metric: Must start with 'prec@k'. If it is longer, the text after the
            first '_' is parsed as a float ``rm``, which is forwarded to
            ``prec_at_ks``. Otherwise ``rm`` is 0.
        ks: The k values to evaluate; they are also the x axis of the plot.
        logscale: If truthy, draw lines with ``plt.semilogx`` instead of
            ``plt.plot``.
        plot_results: If falsy, return the numbers without plotting.
        extra_dir: Optional second directory the figure is also saved to.

    Returns:
        ``{'preck<norm_str>_<rm>': {model: {'ks': ks, 'precs': precs}}}``.
    """
    assert (metric[0:6] == 'prec@k')
    rm = float(metric.split('_')[1]) if len(metric) > 6 else 0

    # The same print_ids list is handed to every prec_at_ks call.
    print_ids = []
    numbers = {
        model: {'ks': ks,
                'precs': prec_at_ks(true_result, rs[model], norm, ks, rm,
                                    print_ids)}
        for model in models
    }
    rtn = {'preck{}_{}'.format(get_norm_str(norm), rm): numbers}
    if not plot_results:
        return rtn

    plt.figure(figsize=(16, 10))
    draw_line = plt.semilogx if logscale else plt.plot
    for model in models:
        curve = numbers[model]
        xs, ys = curve['ks'], curve['precs']
        draw_line(xs, ys, **get_plotting_arg(args1, model))
        plt.scatter(xs, ys, s=200, label=shorten_name(model),
                    **get_plotting_arg(args2, model))
    plt.xlabel('k')
    plt.ylabel(metric)
    plt.ylim([-0.06, 1.06])
    plt.legend(loc='best', ncol=2)
    plt.grid(linestyle='dashed')
    plt.tight_layout()

    kss = 'k_{}_{}'.format(min(ks), max(ks))
    bfn = '{}_{}_{}_{}_{}{}_{}'.format(
        dsmetric, metric, dataset, '_'.join(models), kss, get_norm_str(norm), rm)
    out_dir = '{}/{}/{}'.format(get_result_path(), dataset, metric)
    save_fig(plt, out_dir, bfn)
    if extra_dir:
        save_fig(plt, extra_dir, bfn)
    print(metric, 'plotted')
    return rtn
def get_gs_ds_mat(gs1,
                  gs2,
                  dist_sim_calculator,
                  tvt1,
                  tvt2,
                  dataset,
                  dist_metric,
                  dist_algo,
                  norm,
                  dec_gsize,
                  return_neg1=False):
    """Return the ``len(gs1) x len(gs2)`` distance matrix, cached on disk.

    The cache file lives in a ``ds_mat`` sub-directory of ``get_save_path()``.
    Its name encodes the split labels and sizes, the dataset, the metric, the
    norm flag and the algorithm. On a cache hit the loaded object is returned
    as-is. Otherwise every pair is scored with
    ``dist_sim_calculator.calculate_dist_sim`` and the matrix is saved.

    Args:
        gs1, gs2: Row and column graph lists.
        dist_sim_calculator: Object exposing ``gidpair_ds_map``,
            ``initial_dist_sim_pairs_with_netcomp`` and ``calculate_dist_sim``.
        tvt1, tvt2: Split labels for ``gs1`` and ``gs2``; used only in the
            file name.
        dataset, dist_metric, dist_algo: Used only in the cache file name.
        norm: If truthy, store the normalized distance; otherwise the raw one.
        dec_gsize, return_neg1: Forwarded to ``calculate_dist_sim``.

    Returns:
        The cached object, or a freshly computed numpy array of shape
        ``(len(gs1), len(gs2))``.
    """
    import os  # local import keeps this helper self-contained

    if logging_enabled:
        print("- Entered dist_sim_calculator::get_gs_ds_mat Global Method")

    mat_str = '{}({})_{}({})'.format(tvt1, len(gs1), tvt2, len(gs2))
    # The separator used to be hard-coded as '\\'. That only made a real
    # ds_mat sub-directory on Windows; on POSIX it produced one path component
    # containing a literal backslash. os.path.join yields the same path on
    # Windows. Caches written under the old name on POSIX hosts will be
    # recomputed once.
    cache_dir = os.path.join(get_save_path(), 'ds_mat')
    create_dir_if_not_exists(cache_dir)
    sfn = os.path.join(
        cache_dir,
        '{}_{}_ds_mat_{}{}_{}'.format(dataset, mat_str, dist_metric,
                                       get_norm_str(norm), dist_algo))

    cached = load(sfn)
    if cached is not None:
        print('Loaded from {}'.format(sfn))
        return cached

    if not dist_sim_calculator.gidpair_ds_map:
        # Bulk-precompute pair distances on first use. Presumably this fills
        # gidpair_ds_map so calculate_dist_sim below becomes a lookup
        # (NOTE(review): confirm against the calculator implementation).
        dist_sim_calculator.initial_dist_sim_pairs_with_netcomp(gs1, gs2)

    dist_mat = np.zeros((len(gs1), len(gs2)))
    for i, g1 in enumerate(gs1):
        for j, g2 in enumerate(gs2):
            d, normed_d = dist_sim_calculator.calculate_dist_sim(
                g1, g2, dec_gsize=dec_gsize, return_neg1=return_neg1)
            dist_mat[i][j] = normed_d if norm else d
            # This per-pair trace was printed unconditionally (m*n lines), and
            # only in the normalized branch. It now follows the same logging
            # flag as the entry message above.
            if logging_enabled:
                print("i: ", i, ", j: ", j, ", d: ", d, ", normed_d: ",
                      normed_d)

    save(sfn, dist_mat)
    print('Saved to {}'.format(sfn))
    return dist_mat
示例#5
0
文件: exp.py 项目: yunshengb/GraphSim
def plot_heatmap(gs1_str, gs2_str, dist_mat, thresh_pos, thresh_neg,
                 dataset, dist_metric, norm):
    """Save two heatmaps derived from a pairwise distance matrix.

    The first heatmap shows the classification labels returned by
    ``get_classification_labels_from_dist_mat``, with each row sorted
    descending. Its title gives the positive and negative pair counts as
    fractions of all ``m * n`` pairs. The second heatmap shows the distances,
    with each row sorted ascending. Both are written by
    ``plot_heatmap_helper`` into ``<result_path>/<dataset>/classif_labels``.
    """
    num_rows, num_cols = dist_mat.shape
    total_pairs = num_rows * num_cols
    label_mat, num_poses, num_negs, _, _ = get_classification_labels_from_dist_mat(
        dist_mat, thresh_pos, thresh_neg)
    title = '{} pos pairs ({:.2%})\n{} neg pairs ({:.2%})'.format(
        num_poses, num_poses / total_pairs, num_negs, num_negs / total_pairs)
    norm_str = get_norm_str(norm)
    size_str = '{}({})_{}({})'.format(gs1_str, num_rows, gs2_str, num_cols)

    out_dir = '{}/{}/classif_labels'.format(get_result_path(), dataset)
    create_dir_if_not_exists(out_dir)

    # Label heatmap: rows sorted in descending order.
    labels_tag = '{}_{}_{}'.format(size_str, thresh_pos, thresh_neg)
    labels_fn = '{}_acc_{}_labels_heatmap_{}{}'.format(
        dist_metric, labels_tag, dataset, norm_str)
    plot_heatmap_helper(np.sort(label_mat, axis=1)[:, ::-1], title,
                        out_dir, labels_fn, cmap='bwr')

    # Distance heatmap: rows sorted in ascending order.
    dist_fn = '{}_acc_{}_dist_heatmap_{}{}'.format(
        dist_metric, size_str, dataset, norm_str)
    plot_heatmap_helper(np.sort(dist_mat, axis=1), '', out_dir, dist_fn,
                        cmap='tab20')
示例#6
0
    col_graphs = train_data.graphs
    pred_r = load_result(dataset,
                         'siamese',
                         sim_mat=emb_data['sim_mat'],
                         time_mat=emb_data['time_li'])
    # r = load_result(dataset, model, row_graphs=row_graphs, col_graphs=col_graphs)
    tr = load_result(dataset,
                     TRUE_MODEL,
                     row_graphs=row_graphs,
                     col_graphs=col_graphs)
    for norm in norms:
        ids = pred_r.sort_id_mat_
        num_vis = 10
        for i in range(len(row_graphs)):
            q = test_data.graphs[i]
            # gids = ids[i][:7]
            gids = np.concatenate(
                [ids[i][:5], [ids[i][int(len(col_graphs) / 2)]], ids[i][-1:]])
            gs = [train_data.graphs[j] for j in gids]
            weight_query = []
            weight_query.append(weight[len(col_graphs) + i])
            text = ['\n\n']
            for j in gids:
                weight_query.append(weight[j])
                rtn = '\n {}: {:.2f}'.format('sim', pred_r.sim_mat_[i][j])
                text.append(rtn)
            info_dict['each_graph_text_list'] = text
            info_dict['plot_save_path'] = '{}/query_vis_{}_{}_{}{}.{}'.format(
                dir, dataset, model, i, get_norm_str(norm), ext)
            vis(q, gs, info_dict, weight_query, weight_max, weight_min)
示例#7
0
文件: exp.py 项目: yunshengb/GraphSim
def draw_ranking(dataset, ds_metric, true_r, pred_r, model_name, node_feat_name,
                 plot_node_ids=False, plot_gids=False, ds_norm=True,
                 existing_mappings=None,
                 extra_dir=None, plot_max_num=np.inf):
    """Draw one query-demo figure per row (query) graph of ``true_r``.

    Top row: graphs picked from the ground-truth ranking at ranks 1-6, the
    middle rank and the last rank, labelled via
    ``get_text_label_for_ranking``.

    Bottom row: the query followed by the graphs at the same positions of
    ``pred_r``'s ranking. Each is labelled with its rank and
    ``format_ds(pred_r.pred_ds(...))``.

    Save paths come from ``set_save_paths_for_vis``, which also updates the
    plot counter. Drawing stops once that counter exceeds ``plot_max_num``.
    """
    plot_what = 'query_demo'
    concise = True
    out_dir = get_result_path() + '/{}/{}/{}'.format(
        dataset, plot_what, true_r.get_model())
    info_dict = {
        # Node drawing.
        'draw_node_size': 20,
        'draw_node_label_enable': True,
        'show_labels': plot_node_ids,
        'node_label_type': 'label' if plot_node_ids else 'type',
        'node_label_name': 'type',
        'draw_node_label_font_size': 6,
        'draw_node_color_map': get_color_map(true_r.get_all_gs()),
        # Edge drawing.
        'draw_edge_label_enable': False,
        'draw_edge_label_font_size': 6,
        # Per-graph caption text (filled in per query below).
        'each_graph_text_list': [],
        'each_graph_text_font_size': 10,
        'each_graph_text_pos': [0.5, 1.05],
        # Padding, as fractions in [0, 1].
        'top_space': 0.20 if concise else 0.26,  # of the whole figure
        'bottom_space': 0.05,
        'hbetween_space': 0.6 if concise else 1,  # of one subplot
        'wbetween_space': 0,
        # Output.
        'plot_dpi': 200,
        'plot_save_path_eps': '',
        'plot_save_path_png': ''
    }

    test_gs = true_r.get_row_gs()
    train_gs = None
    if true_r.has_single_col_gs():
        train_gs = true_r.get_single_col_gs()
        if plot_node_ids and existing_mappings:
            # existing_mappings is laid out as [train + val ... test].
            split = len(train_gs)
            test_gs = reorder_gs_based_on_exsiting_mappings(
                test_gs, existing_mappings[split:], node_feat_name)
            train_gs = reorder_gs_based_on_exsiting_mappings(
                train_gs, existing_mappings[:split], node_feat_name)

    ids_groundtruth = true_r.get_sort_id_mat(ds_norm)
    ids_rank = pred_r.get_sort_id_mat(ds_norm)
    plt_cnt = 0
    for i, q in enumerate(test_gs):
        if not true_r.has_single_col_gs():
            train_gs = true_r.get_col_gs(i)
        num_train = len(train_gs)
        middle_idx = num_train // 2
        # Ranking positions shown: the top 6, the middle one and the last one.
        selected_ids = list(range(6)) + [middle_idx, -1]
        gids_groundtruth = np.array(ids_groundtruth[i][selected_ids])
        gids_rank = np.array(ids_rank[i][selected_ids])
        # Top row: ground-truth picks. Bottom row: the query, then model picks.
        gs = ([train_gs[gid] for gid in gids_groundtruth]
              + [q]
              + [train_gs[gid] for gid in gids_rank])

        # Top-row labels: the ground-truth header, then one score per graph.
        text = [get_text_label_for_ranking(
            ds_metric, true_r, i, i, ds_norm, True, dataset,
            gids_groundtruth, plot_gids)]
        text.extend(
            get_text_label_for_ranking(
                ds_metric, true_r, i, gid, ds_norm, False, dataset,
                gids_groundtruth, plot_gids)
            for gid in gids_groundtruth)
        # Bottom-row labels: the model header, then rank number plus score.
        text.append("Rank by\n{}".format(model_name))
        last_pos = len(gids_rank) - 1
        for pos, gid in enumerate(gids_rank):
            ds = format_ds(pred_r.pred_ds(i, gid, ds_norm))
            if pos == last_pos:
                text.append('\n {}\n{}'.format(num_train, ds))
            elif pos == last_pos - 1:
                text.append('\n ...   {}   ...\n{}'.format(middle_idx, ds))
            else:
                text.append('\n {}\n{}'.format(pos + 1, ds))

        info_dict['each_graph_text_list'] = text
        fn = '{}_{}_{}_{}{}'.format(
            plot_what, dataset, true_r.get_model(), i, get_norm_str(ds_norm))
        info_dict, plt_cnt = set_save_paths_for_vis(
            info_dict, out_dir, extra_dir, fn, plt_cnt)
        vis_small(q, gs, info_dict)
        if plt_cnt > plot_max_num:
            break
    print('Saved {} query demo plots'.format(plt_cnt))
示例#8
0
文件: exp.py 项目: yunshengb/GraphSim
def plot_single_number_metric_helper(dataset, dsmetric, models, rs, true_result,
                                     metric, norm,
                                     ds_kernel, thresh_pos, thresh_neg,
                                     thresh_pos_sim, thresh_neg_sim,
                                     plot_results, extra_dir):
    """Evaluate one scalar metric for every model and optionally bar-plot it.

    Args:
        dsmetric: Distance/similarity metric name, e.g. ged or mcs. Used in
            the output file name.
        metric: Evaluation metric. One of 'mrr', 'mse', 'dev', 'time',
            'kendalls_tau', 'spearmans_rho', or an accuracy variant
            ('pos_acc', 'neg_acc', 'acc', 'accall').
        plot_results: If falsy, return the values without plotting.
        extra_dir: Optional second directory the figure is also saved to.

    Returns:
        ``{'<metric><norm_str>': {model: value}}``. For 'accall' the value is
        the full result of ``accuracy``.

    Raises:
        RuntimeError: For an unknown metric. Raised on the first model, so not
            at all when ``models`` is empty.
        AssertionError: For an accuracy-like metric other than those above.
    """
    print_ids = []
    # Single-call evaluators, wrapped in lambdas so only the requested one is
    # ever looked up and invoked.
    scorers = {
        'mrr': lambda pred: mean_reciprocal_rank(
            true_result, pred, norm, print_ids),
        'mse': lambda pred: mean_squared_error(
            true_result, pred, ds_kernel, norm),
        'dev': lambda pred: mean_deviation(
            true_result, pred, ds_kernel, norm),
        'time': lambda pred: average_time(pred),
        'kendalls_tau': lambda pred: kendalls_tau(true_result, pred, norm),
        'spearmans_rho': lambda pred: spearmans_rho(true_result, pred, norm),
    }
    per_model = {}
    val_list = []
    for model in models:
        if metric in scorers:
            val = scorers[metric](rs[model])
        elif 'acc' in metric:
            full = accuracy(true_result, rs[model], thresh_pos, thresh_neg,
                            thresh_pos_sim, thresh_neg_sim, norm)
            pos_acc, neg_acc, acc = full
            picked = {'pos_acc': pos_acc, 'neg_acc': neg_acc, 'acc': acc}
            if metric in picked:
                val = picked[metric]
            else:
                assert (metric == 'accall')
                val = full  # keep all three accuracies
        else:
            raise RuntimeError('Unknown {}'.format(metric))
        per_model[model] = val
        val_list.append(val)
    rtn = {'{}{}'.format(metric, get_norm_str(norm)): per_model}
    if not plot_results:
        return rtn

    plt = plot_multiple_bars(val_list, models, metric)
    ylabels = {
        'time': 'time (msec)',
        'pos_acc': 'pos_recall',
        'neg_acc': 'neg_recall',
        'kendalls_tau': 'Kendall\'s $\\tau$',
        'spearmans_rho': 'Spearman\'s $\\rho$',
    }
    plt.ylabel(ylabels.get(metric, metric))
    if metric == 'time':
        plt.yscale('log')
        # For timing plots the file-name suffix is built from norm=None.
        norm = None
    metric_addi_info = ''
    bfn = '{}_{}{}_{}_{}{}'.format(
        dsmetric, metric, metric_addi_info,
        dataset, '_'.join(models),
        get_norm_str(norm))
    sp = get_result_path() + '/{}/{}/'.format(dataset, metric)
    save_fig(plt, sp, bfn)
    if extra_dir:
        save_fig(plt, extra_dir, bfn)
    print(metric, 'plotted')
    return rtn
    # print("min:", weight_min)
    # weight_max = 0.85
    # weight_min = 0.70
    train_data = load_data(dataset, train=True)
    test_data = load_data(dataset, train=False)
    row_graphs = test_data.graphs
    col_graphs = train_data.graphs
    r = load_result(dataset, model, row_graphs=row_graphs, col_graphs=col_graphs)
    tr = load_result(dataset, TRUE_MODEL, row_graphs=row_graphs, col_graphs=col_graphs)
    ids = r.sort_id_mat(True)
    q = train_data.graphs[ids[28][0]]
    gs = [test_data.graphs[47],
          train_data.graphs[ids[56][3]], train_data.graphs[ids[61][4]],
          test_data.graphs[13], test_data.graphs[67], train_data.graphs[ids[67][0]],
          test_data.graphs[99]]
    weight_query = [weight[ids[28][0]], weight[len(col_graphs) + 47],
                    weight[ids[56][3]], weight[ids[61][4]], weight[len(col_graphs) + 13],
                    weight[len(col_graphs) + 67], weight[ids[67][0]],
                    weight[len(col_graphs) + 99]]
    info_dict['each_graph_text_list'] = ['(a)', '(b)', '(c)', '(d)',
                                         '(e)', '(f)', '(g)', '(h)']
    info_dict['plot_save_path_png'] = '{}/{}_{}_{}_{}{}.{}'.format(
        dir, plot_what, dataset, model, 'hard', get_norm_str(True), 'png')
    info_dict['plot_save_path_eps'] = '{}/{}_{}_{}_{}{}.{}'.format(
        dir, plot_what, dataset, model, 'hard', get_norm_str(True), 'eps')

    vis_attention(q, gs, info_dict, weight_query, weight_max, weight_min)