Example #1
0
def visualize(model, graphs, res_dir, data_name, class_values, num=5):
    """Plot the ``num`` highest- and lowest-scoring graphs side by side.

    Runs ``model`` over ``graphs`` (batches of 50, original order) and ranks
    the graphs by predicted score. It then draws a 2 x ``num`` grid:
    - the top row shows the highest-scoring graphs, in descending order;
    - the bottom row shows the lowest-scoring graphs, in ascending order.

    Each panel draws one graph with a bipartite layout. Nodes are coloured by
    their ``type`` attribute. Edges are coloured by the class value of their
    ``type`` attribute on a shared "rainbow" colorbar. The panel title is the
    predicted score. The figure is saved to
    ``<res_dir>/visualization_<data_name>.pdf`` and then closed.

    Args:
        model: Module called as ``model(batch)``. Its output is flattened with
            ``view(-1)`` and assumed to hold one score per graph.
        graphs: Indexable collection accepted by ``DataLoader`` and
            ``PyGGraph_to_nx`` (presumably PyG ``Data`` objects — confirm).
        res_dir: Directory the PDF is written to.
        data_name: Dataset name embedded in the output file name.
        class_values: Sequence of class (rating) values. An edge whose
            ``type`` is ``k`` is coloured by ``class_values[k]``.
        num: Number of graphs per row; must be >= 1.

    Raises:
        ValueError: If ``num`` is smaller than 1.
    """
    import torch  # local import: needed for no_grad; the model is a torch module

    if num < 1:
        raise ValueError("num must be a positive integer, got {!r}".format(num))

    model.eval()
    model.to(device)
    preds = []
    graph_loader = DataLoader(graphs, 50, shuffle=False)
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for data in tqdm(graph_loader):
            data = data.to(device)
            preds.extend(model(data).view(-1).tolist())

    order = np.argsort(preds).tolist()
    top_idx = order[-num:][::-1]  # highest score first
    bottom_idx = order[:num]  # lowest score first
    # Index the grid by (row, col) explicitly so the bottom row always holds
    # the lowest-scoring graphs, even when there are fewer than `num` graphs.
    rows = [
        [(PyGGraph_to_nx(graphs[i]), preds[i]) for i in top_idx],
        [(PyGGraph_to_nx(graphs[i]), preds[i]) for i in bottom_idx],
    ]

    type_to_color = {0: 'xkcd:red', 1: 'xkcd:blue', 2: 'xkcd:orange', 3: 'xkcd:lightblue', 4: 'y', 5: 'g'}
    f = plt.figure(figsize=(20, 10))
    # squeeze=False keeps `axs` 2-D even when num == 1.
    axs = f.subplots(2, num, squeeze=False)
    for ax in axs.flat:
        ax.axis('off')
    # plt.get_cmap works on old and new matplotlib; plt.cm.get_cmap is gone in 3.9+.
    cmap = plt.get_cmap('rainbow')
    vmin, vmax = min(class_values), max(class_values)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
    sm.set_array([])

    for row, panels in enumerate(rows):
        for col, (g, score) in enumerate(panels):
            ax = axs[row, col]
            # Even node types are drawn on the "u" side of the bipartite layout.
            u_nodes = [x for x, y in g.nodes(data=True) if y['type'] % 2 == 0]
            # NOTE(review): assumes node 0 is u0 and node len(u_nodes) is v0,
            # i.e. PyGGraph_to_nx numbers all u-side nodes first — confirm.
            u0, v0 = 0, len(u_nodes)
            pos = nx.drawing.layout.bipartite_layout(g, u_nodes)
            bottom_u_node = min(pos, key=lambda x: (pos[x][0], pos[x][1]))
            bottom_v_node = min(pos, key=lambda x: (-pos[x][0], pos[x][1]))
            # swap u0 and v0 with bottom nodes if they are not already
            if u0 != bottom_u_node:
                pos[u0], pos[bottom_u_node] = pos[bottom_u_node], pos[u0]
            if v0 != bottom_v_node:
                pos[v0], pos[bottom_v_node] = pos[bottom_v_node], pos[v0]
            node_colors = [type_to_color[t] for t in nx.get_node_attributes(g, 'type').values()]
            edge_type_of = nx.get_edge_attributes(g, 'type')
            # Map each edge-type index to its class value so edge colours match
            # the colorbar, which is normalised over [min, max] of class_values.
            edge_colors = [class_values[edge_type_of[e]] for e in g.edges()]
            nx.draw_networkx(
                g, pos,
                with_labels=False,
                node_size=150,
                node_color=node_colors, edge_color=edge_colors,
                ax=ax, edge_cmap=cmap, edge_vmin=vmin, edge_vmax=vmax,
            )
            # make u0 v0 on top of other nodes
            nx.draw_networkx_nodes(g, {u0: pos[u0]}, nodelist=[u0], node_size=150,
                    node_color='xkcd:red', ax=ax)
            nx.draw_networkx_nodes(g, {v0: pos[v0]}, nodelist=[v0], node_size=150,
                    node_color='xkcd:blue', ax=ax)
            ax.set_title('{:.4f}'.format(score), x=0.5, y=-0.05, fontsize=20)

    f.subplots_adjust(right=0.85)
    cbar_ax = f.add_axes([0.88, 0.15, 0.02, 0.7])
    ticks = class_values
    if len(class_values) > 20:
        # Too many classes for legible labels: show 20 evenly spaced ticks.
        ticks = np.linspace(min(class_values), max(class_values), 20, dtype=int).tolist()
    cbar = f.colorbar(sm, cax=cbar_ax, ticks=ticks)
    cbar.ax.tick_params(labelsize=22)
    # `interpolation` is not a savefig option (newer matplotlib raises TypeError).
    f.savefig(os.path.join(res_dir, "visualization_{}.pdf".format(data_name)),
            bbox_inches='tight')
    plt.close(f)  # release the figure; repeated calls would otherwise leak
Example #2
0
def visualize(
    model, graphs, res_dir, data_name, class_values, num=5, sort_by="prediction"
):
    """Plot 2 x ``num`` graphs chosen from the two ends of a ranking.

    Runs ``model`` over ``graphs`` (batches of 50, original order). It
    collects the predicted score and ``data.y`` of every graph, then orders
    the graphs according to ``sort_by``:

    - ``"true"``: by ``data.y``;
    - ``"prediction"``: by the model output;
    - ``"random"``: by a random permutation (``np.random``).

    The top row shows the last ``num`` graphs of that order, reversed. The
    bottom row shows the first ``num`` graphs.

    Each panel draws one graph with a bipartite layout. Nodes are coloured by
    their ``type`` attribute. Edges are coloured by ``class_values[type]`` on
    a shared "rainbow" colorbar. The panel title is
    ``"<prediction:.4f> (<y>)"``. The figure is saved to
    ``<res_dir>/visualization_<data_name>_<sort_by>.pdf`` and then closed.

    Args:
        model: Module called as ``model(batch)``. Its output is flattened with
            ``view(-1)`` and assumed to hold one score per graph.
        graphs: Indexable collection accepted by ``DataLoader`` and
            ``PyGGraph_to_nx`` (presumably PyG ``Data`` objects with ``y``).
        res_dir: Directory the PDF is written to.
        data_name: Dataset name embedded in the output file name.
        class_values: Sequence of class (rating) values, indexed by edge type.
        num: Number of graphs per row; must be >= 1.
        sort_by: One of ``"true"``, ``"prediction"`` or ``"random"``.

    Raises:
        ValueError: If ``num`` < 1 or ``sort_by`` is not a supported option.
    """
    import torch  # local import: needed for no_grad; the model is a torch module

    if num < 1:
        raise ValueError("num must be a positive integer, got {!r}".format(num))
    # Validate before the (possibly long) inference pass; previously an unknown
    # value left `order` unbound and failed with UnboundLocalError afterwards.
    if sort_by not in ("true", "prediction", "random"):
        raise ValueError(
            "sort_by must be 'true', 'prediction' or 'random', got {!r}".format(sort_by)
        )

    model.eval()
    model.to(device)
    R = []  # predicted scores, one per graph
    Y = []  # targets read from data.y, one per graph
    graph_loader = DataLoader(graphs, 50, shuffle=False)
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for data in tqdm(graph_loader):
            data = data.to(device)
            R.extend(model(data).view(-1).tolist())
            Y.extend(data.y.view(-1).tolist())

    if sort_by == "true":  # sort graphs by their true ratings
        order = np.argsort(Y).tolist()
    elif sort_by == "prediction":
        order = np.argsort(R).tolist()
    else:  # "random": randomly select graphs to visualize
        order = np.random.permutation(range(len(R))).tolist()
    top_idx = order[-num:][::-1]
    bottom_idx = order[:num]
    # Index the grid by (row, col) explicitly so the bottom row always holds
    # the low end of the ranking, even when there are fewer than `num` graphs.
    rows = [
        [(PyGGraph_to_nx(graphs[i]), R[i], Y[i]) for i in top_idx],
        [(PyGGraph_to_nx(graphs[i]), R[i], Y[i]) for i in bottom_idx],
    ]

    type_to_color = {
        0: "xkcd:red",
        1: "xkcd:blue",
        2: "xkcd:orange",
        3: "xkcd:lightblue",
        4: "y",
        5: "g",
    }
    f = plt.figure(figsize=(20, 10))
    # squeeze=False keeps `axs` 2-D even when num == 1.
    axs = f.subplots(2, num, squeeze=False)
    for ax in axs.flat:
        ax.axis("off")
    # plt.get_cmap works on old and new matplotlib; plt.cm.get_cmap is gone in 3.9+.
    cmap = plt.get_cmap("rainbow")
    vmin, vmax = min(class_values), max(class_values)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
    sm.set_array([])

    for row, panels in enumerate(rows):
        for col, (g, score, y_true) in enumerate(panels):
            ax = axs[row, col]
            # Even node types are drawn on the "u" side of the bipartite layout.
            u_nodes = [x for x, y in g.nodes(data=True) if y["type"] % 2 == 0]
            # NOTE(review): assumes node 0 is u0 and node len(u_nodes) is v0,
            # i.e. PyGGraph_to_nx numbers all u-side nodes first — confirm.
            u0, v0 = 0, len(u_nodes)
            pos = nx.drawing.layout.bipartite_layout(g, u_nodes)
            bottom_u_node = min(pos, key=lambda x: (pos[x][0], pos[x][1]))
            bottom_v_node = min(pos, key=lambda x: (-pos[x][0], pos[x][1]))
            # swap u0 and v0 with bottom nodes if they are not already
            if u0 != bottom_u_node:
                pos[u0], pos[bottom_u_node] = pos[bottom_u_node], pos[u0]
            if v0 != bottom_v_node:
                pos[v0], pos[bottom_v_node] = pos[bottom_v_node], pos[v0]
            node_colors = [
                type_to_color[t] for t in nx.get_node_attributes(g, "type").values()
            ]
            edge_type_of = nx.get_edge_attributes(g, "type")
            # Map each edge-type index to its class value so edge colours match
            # the colorbar, which is normalised over [min, max] of class_values.
            edge_colors = [class_values[edge_type_of[e]] for e in g.edges()]
            nx.draw_networkx(
                g,
                pos,
                with_labels=False,
                node_size=150,
                node_color=node_colors,
                edge_color=edge_colors,
                ax=ax,
                edge_cmap=cmap,
                edge_vmin=vmin,
                edge_vmax=vmax,
            )
            # make u0 v0 on top of other nodes
            nx.draw_networkx_nodes(
                g,
                {u0: pos[u0]},
                nodelist=[u0],
                node_size=150,
                node_color="xkcd:red",
                ax=ax,
            )
            nx.draw_networkx_nodes(
                g,
                {v0: pos[v0]},
                nodelist=[v0],
                node_size=150,
                node_color="xkcd:blue",
                ax=ax,
            )
            ax.set_title(
                "{:.4f} ({:})".format(score, y_true), x=0.5, y=-0.05, fontsize=20
            )

    f.subplots_adjust(right=0.85)
    cbar_ax = f.add_axes([0.88, 0.15, 0.02, 0.7])
    ticks = class_values
    if len(class_values) > 20:
        # Too many classes for legible labels: show 20 evenly spaced ticks.
        ticks = np.linspace(
            min(class_values), max(class_values), 20, dtype=int
        ).tolist()
    cbar = f.colorbar(sm, cax=cbar_ax, ticks=ticks)
    cbar.ax.tick_params(labelsize=22)
    # `interpolation` is not a savefig option (newer matplotlib raises TypeError).
    f.savefig(
        os.path.join(res_dir, "visualization_{}_{}.pdf".format(data_name, sort_by)),
        bbox_inches="tight",
    )
    plt.close(f)  # release the figure; repeated calls would otherwise leak