def test_model():
    """Smoke-test one training run of GnnPytorchDGLModel.

    Builds a three-node chain n1 -> n2 -> n3 (node kinds "a", "b", "c", every
    edge of kind "dummy"), wraps it in a single labelled sample, and calls
    ``train`` with that sample as both the training and validation data.
    """
    chain = nx.MultiDiGraph()
    for name, kind in (("n1", "a"), ("n2", "b"), ("n3", "c")):
        chain.add_node(name, attr=kind)
    for source, target in (("n1", "n2"), ("n2", "n3")):
        chain.add_edge(source, target, attr="dummy")

    config = dict(
        num_timesteps=2,
        hidden_size_orig=len(chain),  # one entry per node of the graph (3)
        gnn_h_size=4,
        gnn_m_size=2,
        num_edge_types=1,
        learning_rate=0.001,
        batch_size=4,
        num_epochs=1,
    )
    model = GnnPytorchDGLModel(config=config)

    sample = {
        "x": {
            "code_rep": Graph(chain, ["a", "b", "c"], ["dummy"]),
            "aux_in": [0, 0],
        },
        "y": 0,
    }
    dataset = [sample]
    model.train(dataset, dataset)
def test_train_model():
    """Smoke-test Tf2SandwichModel with an rnn -> ggnn -> rnn layer stack.

    The input is a three-node graph where only n1 carries a ``seq_order``
    (so it is the only sequence node) and a single n1 -> n2 edge. The graph
    is converted with ``SequenceGraph.from_graph`` and used as both the
    training and validation set.
    """
    graph = nx.MultiDiGraph()
    graph.add_nodes_from([
        ("n1", {"attr": "a", "seq_order": 0}),
        ("n2", {"attr": "b"}),
        ("n3", {"attr": "c"}),
    ])
    graph.add_edges_from([("n1", "n2", {"attr": "dummy"})])

    sequence_graph = SequenceGraph.from_graph(
        Graph(graph, ["a", "b", "c"], ["dummy"]))
    samples = [{"x": {"code_rep": sequence_graph, "aux_in": []}, "y": 0}]

    model_config = {
        'num_epochs': 4,
        'layers': ['rnn', 'ggnn', 'rnn'],
        'base': {'hidden_dim': 4},
    }
    model = Tf2SandwichModel(model_config, num_types=3)
    model.train(samples, samples)
def test_train_model():
    """Smoke-test GnnTfModel training plus best-weight backup/restore.

    Trains on a single sample built from an edge-less three-node graph,
    backs up and restores the best weights through the model state, and
    asserts that the state reports a non-zero number of trainable parameters.
    """
    import copy

    dummy_graph = nx.MultiDiGraph()
    dummy_graph.add_node("n1", attr="a")
    dummy_graph.add_node("n2", attr="b")
    dummy_graph.add_node("n3", attr="c")

    # Work on a copy: assigning into the shared module-level CONFIG would leak
    # this test's override into every other test or module that reads CONFIG.
    config = copy.copy(CONFIG)
    config["hidden_size_orig"] = len(dummy_graph)

    state = GnnTfModelState(config)
    model = GnnTfModel(config, state)
    data = [{
        "x": {
            "code_rep": Graph(dummy_graph, ["a", "b", "c"], ["dummy"]),
            "aux_in": [0, 0],
        },
        "y": 0,
    }]
    model.train(data, data)

    state.backup_best_weights()
    state.restore_best_weights()

    num_params = state.count_number_trainable_params()
    assert num_params
def to_graph(self, vocab: Vocabulary) -> Graph:
    """Rebuild a ``Graph`` (backed by an ``nx.MultiDiGraph``) from this object.

    Node ``i`` is keyed by its position and gets ``attr`` set to
    ``vocab.node_kinds[self.nodes[i]]``. The first ``self.seq_len`` nodes
    additionally get ``seq_order=i``. ``self.edges`` is unpacked with
    ``zip(*...)`` into ``(kind, source, target)`` triples, so it is expected to
    have three rows. Each triple becomes an edge ``source -> target`` with
    ``attr`` set to ``vocab.edge_kinds[kind]``.

    Args:
        vocab: maps node/edge kind indices back to their labels.

    Returns:
        A ``Graph`` whose type lists are the vocabulary's node and edge kinds.
    """
    graph = nx.MultiDiGraph()
    for position, kind_index in enumerate(self.nodes):
        node_attrs = {"attr": vocab.node_kinds[kind_index]}
        if position < self.seq_len:
            # Only the leading sequence nodes carry an explicit ordering.
            node_attrs["seq_order"] = position
        graph.add_node(position, **node_attrs)

    for kind_index, source, target in zip(*self.edges):
        graph.add_edge(source, target, attr=vocab.edge_kinds[kind_index])

    node_labels = list(vocab.node_kinds)
    edge_labels = list(vocab.edge_kinds)
    return Graph(graph, node_labels, edge_labels)
def from_graph(graph: Graph) -> 'SequenceGraph':
    """Encode a ``Graph`` as a ``SequenceGraph`` of integer indices.

    Nodes that have a ``seq_order`` attribute are placed first, sorted by that
    value. The remaining nodes follow in their ``graph.G.nodes`` iteration
    order, because ``sorted`` is stable and they all share the key ``inf``.
    Node and edge kinds are encoded as their index in
    ``graph.get_node_types()`` / ``graph.get_edge_types()``.

    Args:
        graph: source graph. Every node and edge of ``graph.G`` must have an
            ``attr`` value present in the corresponding type list; otherwise
            the dictionary lookup raises ``KeyError``.

    Returns:
        A ``SequenceGraph`` built from:
          * an int32 array with one kind index per node, in the order above;
          * an int32 edge array of shape ``(3, num_edges)`` whose rows are
            (edge kind, source node index, target node index);
          * ``seq_len``, the number of nodes that carried ``seq_order``.
    """
    node_to_int = {n: i for i, n in enumerate(graph.get_node_types())}
    edge_to_int = {e: i for i, e in enumerate(graph.get_edge_types())}

    nodes = []
    node_mapping = {}  # original node id -> position in `nodes`
    seq_len = 0
    for node, data in sorted(
            graph.G.nodes(data=True),
            key=lambda item: item[1].get('seq_order', float('inf'))):
        node_mapping[node] = len(nodes)
        nodes.append(node_to_int[data['attr']])
        if 'seq_order' in data:
            # Sequence nodes sort first, so after the loop this equals their count.
            seq_len = len(nodes)

    edges = [
        (edge_to_int[data['attr']], node_mapping[u], node_mapping[v])
        for u, v, data in graph.G.edges(data=True)
    ]
    # Transposing via zip(*edges) yields shape (3, E) for E >= 1, but an empty
    # list collapses to shape (0,). Reshape so an edge-less graph still gets a
    # (3, 0) array with the same rank as every other graph.
    edge_array = np.array(list(zip(*edges)), dtype=np.int32).reshape(3, -1)
    return SequenceGraph(np.array(nodes, dtype=np.int32), edge_array, seq_len)
def from_graph(graph: Graph) -> 'Vocabulary':
    """Create a ``Vocabulary`` from the type lists of ``graph``.

    Args:
        graph: graph whose ``get_node_types()`` and ``get_edge_types()``
            provide the node and edge kinds.

    Returns:
        A ``Vocabulary`` whose ``node_kinds`` and ``edge_kinds`` are numpy
        arrays of those type lists, in the same order.
    """
    node_kinds = np.array(graph.get_node_types())
    edge_kinds = np.array(graph.get_edge_types())
    return Vocabulary(node_kinds=node_kinds, edge_kinds=edge_kinds)