Ejemplo n.º 1
0
def test_model():
    """Smoke test: build a tiny graph and train ``GnnPytorchDGLModel`` on it.

    The graph has three nodes (kinds "a", "b", "c") chained by two "dummy"
    edges. The same single-sample list is passed to ``train`` twice. The
    test only checks that training completes without raising.
    """
    graph = nx.MultiDiGraph()
    for name, kind in (("n1", "a"), ("n2", "b"), ("n3", "c")):
        graph.add_node(name, attr=kind)
    for source, target in (("n1", "n2"), ("n2", "n3")):
        graph.add_edge(source, target, attr="dummy")

    model = GnnPytorchDGLModel(config={
        "num_timesteps": 2,
        # One entry per node of the toy graph.
        "hidden_size_orig": len(graph),
        "gnn_h_size": 4,
        "gnn_m_size": 2,
        "num_edge_types": 1,
        "learning_rate": 0.001,
        "batch_size": 4,
        "num_epochs": 1,
    })

    sample = {
        "x": {
            "code_rep": Graph(graph, ["a", "b", "c"], ["dummy"]),
            "aux_in": [0, 0],
        },
        "y": 0,
    }
    samples = [sample]
    model.train(samples, samples)
Ejemplo n.º 2
0
def test_train_model():
    """Smoke test: train ``Tf2SandwichModel`` on a one-sample dataset.

    The sample wraps a three-node graph (only "n1" carries a ``seq_order``)
    converted through ``SequenceGraph.from_graph``. The same list is passed
    to ``train`` twice. The test only checks that training completes
    without raising.
    """
    graph = nx.MultiDiGraph()
    graph.add_node("n1", attr="a", seq_order=0)
    for name, kind in (("n2", "b"), ("n3", "c")):
        graph.add_node(name, attr=kind)
    graph.add_edge("n1", "n2", attr="dummy")

    code_rep = SequenceGraph.from_graph(
        Graph(graph, ["a", "b", "c"], ["dummy"]))
    samples = [{
        "x": {"code_rep": code_rep, "aux_in": []},
        "y": 0,
    }]

    model_config = {
        'num_epochs': 4,
        'layers': ['rnn', 'ggnn', 'rnn'],
        'base': {'hidden_dim': 4},
    }
    model = Tf2SandwichModel(model_config, num_types=3)
    model.train(samples, samples)
Ejemplo n.º 3
0
def test_train_model():
    """Smoke test: train ``GnnTfModel`` and exercise its state helpers.

    Builds a three-node, edge-less graph, trains on a single sample (the
    same list is passed to ``train`` twice), then round-trips the weights
    through ``backup_best_weights`` / ``restore_best_weights`` and checks
    that the reported number of trainable parameters is truthy.
    """
    dummy_graph = nx.MultiDiGraph()
    dummy_graph.add_node("n1", attr="a")
    dummy_graph.add_node("n2", attr="b")
    dummy_graph.add_node("n3", attr="c")

    # Work on a copy: assigning into the module-level CONFIG directly would
    # leak "hidden_size_orig" into every other user of CONFIG.
    # NOTE: this is a shallow copy, so nested values are still shared.
    config = dict(CONFIG)
    config["hidden_size_orig"] = len(dummy_graph)

    state = GnnTfModelState(config)
    model = GnnTfModel(config, state)

    data = [{
        "x": {
            "code_rep": Graph(dummy_graph, ["a", "b", "c"], ["dummy"]),
            "aux_in": [0, 0],
        },
        "y": 0,
    }]
    model.train(data, data)

    state.backup_best_weights()
    state.restore_best_weights()

    num_params = state.count_number_trainable_params()
    assert num_params
Ejemplo n.º 4
0
    def to_graph(self, vocab: Vocabulary) -> Graph:
        """Rebuild a ``Graph`` from this object's node and edge arrays.

        Each entry of ``self.nodes`` becomes a node keyed by its position,
        with ``attr`` looked up in ``vocab.node_kinds``. Nodes whose
        position is below ``self.seq_len`` also get ``seq_order`` set to
        that position. ``self.edges`` is unpacked into parallel
        (kind, source, target) sequences, and each edge's ``attr`` is looked
        up in ``vocab.edge_kinds``.
        """
        nx_graph = nx.MultiDiGraph()

        for position, kind_id in enumerate(self.nodes):
            node_attrs = {"attr": vocab.node_kinds[kind_id]}
            if position < self.seq_len:
                node_attrs["seq_order"] = position
            nx_graph.add_node(position, **node_attrs)

        for edge_kind, source, target in zip(*self.edges):
            nx_graph.add_edge(source, target, attr=vocab.edge_kinds[edge_kind])

        node_types = list(vocab.node_kinds)
        edge_types = list(vocab.edge_kinds)
        return Graph(nx_graph, node_types, edge_types)
Ejemplo n.º 5
0
    def from_graph(graph: Graph) -> 'SequenceGraph':
        """Convert ``graph`` into a ``SequenceGraph`` of integer arrays.

        Nodes are ordered by their ``seq_order`` attribute, with nodes
        lacking it placed last in their original relative order. Each node
        is stored as the index of its ``attr`` in ``graph.get_node_types()``.
        ``seq_len`` is one past the position of the last node that carries
        ``seq_order`` (0 if none does).

        Edges are stored as three parallel rows (kind, source, target),
        where the kind indexes ``graph.get_edge_types()`` and the endpoints
        are the new node positions.

        NOTE(review): for a graph with no edges the edge array has shape
        ``(0,)`` rather than ``(3, 0)``; confirm consumers tolerate that.
        """
        node_kind_ids = {
            kind: idx for idx, kind in enumerate(graph.get_node_types())
        }
        edge_kind_ids = {
            kind: idx for idx, kind in enumerate(graph.get_edge_types())
        }

        def sequence_position(entry):
            # Nodes without a sequence position sort after all that have one.
            return entry[1].get('seq_order', float('inf'))

        ordered_nodes = sorted(graph.G.nodes(data=True), key=sequence_position)

        new_index = {}
        kinds = []
        seq_len = 0
        for position, (name, attrs) in enumerate(ordered_nodes):
            new_index[name] = position
            kinds.append(node_kind_ids[attrs['attr']])
            if 'seq_order' in attrs:
                seq_len = position + 1

        triples = [
            (edge_kind_ids[attrs['attr']], new_index[src], new_index[dst])
            for src, dst, attrs in graph.G.edges(data=True)
        ]

        node_array = np.array(kinds, dtype=np.int32)
        # Transpose the list of triples into three parallel rows.
        edge_array = np.array(list(zip(*triples)), dtype=np.int32)
        return SequenceGraph(node_array, edge_array, seq_len)
Ejemplo n.º 6
0
 def from_graph(graph: Graph) -> 'Vocabulary':
     """Build a ``Vocabulary`` from the node and edge types of ``graph``.

     Both type lists are converted to NumPy arrays before being stored.
     """
     node_kinds = np.array(graph.get_node_types())
     edge_kinds = np.array(graph.get_edge_types())
     return Vocabulary(node_kinds=node_kinds, edge_kinds=edge_kinds)