import pickle

import numpy as np
import pandas as pd
from stellargraph import IndexedArray, StellarDiGraph
from tqdm import tqdm


def load_graph_data(dataframe,
                    embeddings,
                    name="default",
                    testing=False,
                    num_test=100,
                    using_start=False):

    actor_indices = []
    actor_features = []
    utterance_indices = []
    utterance_features = []
    source_edges = []
    target_edges = []

    if testing:
        num_dialogues = num_test
    else:
        num_dialogues = len(dataframe['Dialogue ID'].unique())

    print("Building graph, 1 dialogue at a time...")
    for dialogueID in tqdm(dataframe['Dialogue ID'].unique()[0:num_dialogues]):
        dialogue = dataframe[dataframe["Dialogue ID"] == dialogueID]

        # Loop through all utterances of the dialogue
        for rowidx in range(len(dialogue)):
            row = dialogue.iloc[rowidx]

            # 0. Add the actor node (index + feature) if it does not already
            # exist. Each actor gets a random 1024-dim feature vector,
            # matching the dimensionality of the ELMo utterance embeddings.
            actor_idx = f"{row.Actor}_{dialogueID}"
            if actor_idx not in actor_indices:
                actor_indices.append(actor_idx)
                actor_features.append(np.random.normal(0.0, 1.0, [1, 1024]))
            # 1. Add the utterance node (index + ELMo embedding as feature).
            utt_idx = f"u_dID{dialogueID}_#{rowidx}"
            utterance_indices.append(utt_idx)
            # Look up the utterance's ELMo embedding via the original dataset
            # index of the dialogue row currently being parsed.
            utterance_features.append(
                np.array([embeddings[dialogue.index[rowidx]]]))

            # 2. Build edges. If this is the first row of a dialogue,
            # begin by drawing an edge from the "START-Node" (source)
            # to the current utterance index (target)
            if using_start and rowidx == 0:
                source_edges.append("START-Node")
                target_edges.append(utt_idx)

            # 3. Construct remaining edges.
            # 3.1 Actor to the utterance
            source_edges.append(actor_idx)
            target_edges.append(utt_idx)
            # 3.2 Utterance to the next utterance
            if (rowidx + 1) != len(dialogue):
                source_edges.append(utt_idx)
                target_edges.append(f"u_dID{dialogueID}_#{rowidx + 1}")
            # 3.3 Utterance to all actors
            for actor in dialogue['Actor'].unique():
                all_actor_idx = f"{actor}_{dialogueID}"
                source_edges.append(utt_idx)
                target_edges.append(all_actor_idx)

    # GraphSAGE does not support modelling nodes of different kinds, so actor
    # and utterance nodes are stacked into one homogeneous feature matrix.

    if using_start:
        start_features = np.random.normal(0.0, 1.0, [1, 1024])
        start_index = "START-Node"
        node_features = np.concatenate(
            actor_features + utterance_features + [start_features], axis=0)
        node_indices = actor_indices + utterance_indices + [start_index]
    else:
        node_features = np.concatenate(actor_features + utterance_features,
                                       axis=0)
        node_indices = actor_indices + utterance_indices

    nodes = IndexedArray(node_features, node_indices)

    edges = pd.DataFrame({"source": source_edges, "target": target_edges})

    # Build the full directed graph for GraphSAGE.
    full_graph = StellarDiGraph(nodes, edges)

    # Targets for the utterance nodes; this assumes dataframe rows are ordered
    # by dialogue, so 'Dialogue Act' labels line up with utterance_indices.
    targets = pd.Series(
        dataframe['Dialogue Act'].tolist()[0:len(utterance_indices)],
        index=utterance_indices)

    print("Check if graph has all properties required for ML/Inference...")
    full_graph.check_graph_for_ml(expensive_check=True)
    print("Check successful.")
    print(full_graph.info())
    print("---- Graph Creation Finished ----")

    netx_graph = full_graph.to_networkx(feature_attr='utterance_embedding')
    # Save both graph versions for later use.
    prefix = "test_" if testing else ""
    with open(f"visualizeGraph/{prefix}{name}_netx.pickle", "wb") as f:
        pickle.dump((netx_graph, targets), f)
    with open(f"createdGraphs/{prefix}{name}_graph.pickle", "wb") as f:
        pickle.dump((full_graph, targets), f)

    return full_graph, targets
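
A minimal usage sketch (not part of the original example): it assumes `dataframe` has 'Dialogue ID', 'Actor' and 'Dialogue Act' columns and that `embeddings[i]` is the 1024-dimensional ELMo vector for dataset row `i`; the file paths below are hypothetical.

if __name__ == "__main__":
    df = pd.read_pickle("data/dialogues.pickle")   # hypothetical path
    elmo = np.load("data/elmo_embeddings.npy")     # hypothetical path
    graph, targets = load_graph_data(df, elmo, name="demo",
                                     testing=True, num_test=10)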
Example #2
# Add duplicated feature rows for the renamed ('_gm' and '_k') gene nodes
feature_rename = feature_df.loc[gene_list]
feature_rename_gm = feature_rename.copy(deep=True)
feature_rename_k = feature_rename.copy(deep=True)

feature_rename_gm.index = feature_rename_gm.index.map(lambda name: name + '_gm')
feature_rename_k.index = feature_rename_k.index.map(lambda name: name + '_k')
feature_df = pd.concat([feature_df, feature_rename_gm, feature_rename_k], axis=0)

# %% [markdown]
# ## Read graph

# %%
G = StellarDiGraph(edges=df[['source', 'target']], nodes=feature_df)
print(G.info())

# %% [markdown]
# ## Data Generators
# 
# Now we create the data generators using `CorruptedGenerator`. `CorruptedGenerator` yields shuffled (corrupted) node features alongside the regular node features, and we train the model to discriminate between the two.
# 
# Note that:
# 
# - We typically pass all nodes to `corrupted_generator.flow` because this is an unsupervised task
# - We don't pass `targets` to `corrupted_generator.flow` because these are binary labels (true nodes, false nodes) that are created by `CorruptedGenerator`

# %%
# Directed GraphSAGE node generator; samples 30 then 5 neighbours along the
# in-edges and out-edges of each node
graphsage_generator = DirectedGraphSAGENodeGenerator(
    G, batch_size=50, in_samples=[30, 5], out_samples=[30, 5], seed=0
)
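
# %% [markdown]
# Below is a minimal sketch of the corruption step (assuming StellarGraph's
# `CorruptedGenerator` supports the directed GraphSAGE generator, as in the
# Deep Graph Infomax demos): the base generator is wrapped, and all nodes of
# the graph are passed to `flow` since the task is unsupervised.

# %%
from stellargraph.mapper import CorruptedGenerator

corrupted_generator = CorruptedGenerator(graphsage_generator)
gen = corrupted_generator.flow(G.nodes())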