Пример #1
0
 def setUp(self) -> None:
     """Create an ``NSMCell`` together with a random graph batch and inputs."""
     self.batch_size = 8
     self.input_size = 100
     self.n_node_properties = 20
     self.model = NSMCell(self.input_size, self.n_node_properties)

     # Take exactly ``batch_size`` graphs from the generator and collate them.
     graph_source = infinite_graphs(
         self.input_size,
         self.n_node_properties,
         node_distribution=(16.4, 8.2),
         density_distribution=(0.2, 0.4),
     )
     sampled_graphs = list(islice(graph_source, self.batch_size))
     self.graph_batch = collate_graphs(sampled_graphs)

     self.instruction_batch = torch.rand(self.batch_size, self.input_size)

     # Reciprocal of every graph's node count, gathered through node_indices.
     inverse_node_counts = 1.0 / self.graph_batch.nodes_per_graph
     self.distribution = inverse_node_counts[self.graph_batch.node_indices]

     # A single softmax over ``n_node_properties + 1`` columns; the final
     # column is stored separately as the relation similarity.
     projection = torch.rand(self.input_size, self.n_node_properties + 1)
     similarities = F.softmax(self.instruction_batch @ projection, dim=1)
     self.node_prop_similarities = similarities[:, :-1]
     self.relation_similarity = similarities[:, -1]
Пример #2
0
 def setUp(self) -> None:
     """Collate eight generated graphs and keep the result as ``self.batch``."""
     sampled_graphs = list(
         islice(
             infinite_graphs(
                 300,
                 77,
                 node_distribution=(16.4, 8.2),
                 density_distribution=(0.2, 0.4),
             ),
             8,
         ))
     self.batch = collate_graphs(sampled_graphs)
Пример #3
0
 def collate_fn(batch):
     """Combine 4-field samples into a single 6-element tuple.

     Every item of ``batch`` unpacks as (graph, question, target,
     instructions). ``vocab`` is not defined here; it is resolved from the
     surrounding scope.
     """
     graphs, questions, targets, gold_instructions = zip(*batch)

     graph_batch = collate_graphs(graphs)
     # enforce_sorted=False: questions need not arrive ordered by length.
     question_batch = torch.nn.utils.rnn.pack_sequence(
         questions, enforce_sorted=False)
     concept_embeddings = vocab.concept_embeddings
     property_embeddings = vocab.property_embeddings
     target_tensor = torch.tensor(targets)
     instruction_tensor = torch.stack(gold_instructions)

     return (
         graph_batch,
         question_batch,
         concept_embeddings,
         property_embeddings,
         target_tensor,
         instruction_tensor,
     )
Пример #4
0
    def test_collate_graphs_inverse(self) -> None:
        """Verify ``split_batch`` reverses ``collate_graphs`` field by field."""
        # Seven distinct batch sizes drawn from 1..20.
        sizes = random.sample(range(1, 21), 7)
        graph_source = infinite_graphs()
        for size in sizes:
            originals = list(islice(graph_source, size))
            collated = collate_graphs(originals)
            restored = split_batch(collated)
            self.assertEqual(len(originals), len(restored))

            for before, after in zip(originals, restored):
                for field in before._fields:
                    with self.subTest(batch_len=size, graph_attr=field):
                        matches = getattr(before, field) == getattr(after, field)
                        self.assertTrue(matches.all())
Пример #5
0
        def collate_fn(batch):
            """Combine 3-field samples into a 6-element tuple.

            Every item of ``batch`` unpacks as (graph, question, target).
            ``vocab`` is not defined here; it is resolved from the
            surrounding scope.
            """
            graphs, questions, targets = zip(*batch)

            graph_batch = collate_graphs(graphs)
            # enforce_sorted=False: questions need not arrive ordered by length.
            question_batch = torch.nn.utils.rnn.pack_sequence(
                questions, enforce_sorted=False)
            concept_embeddings = vocab.concept_embeddings
            property_embeddings = vocab.property_embeddings
            target_tensor = torch.tensor(targets)
            # Dummy tensor: gold instructions are not used on this path.
            unused_instructions = torch.empty(1)

            # The result is a 6-tuple because that is what the latest
            # iteration of NSMLightningModule expects.
            return (
                graph_batch,
                question_batch,
                concept_embeddings,
                property_embeddings,
                target_tensor,
                unused_instructions,
            )
Пример #6
0
 def test_edge_indices_shift_edge_attrs(self) -> None:
     """Compare edge attrs of random nodes before and after collation.

     For five randomly picked nodes, the edge attributes selected by the
     node's index in its own graph must equal those selected from the batch
     by the node's shifted index.
     """
     graphs = list(islice(self.graph_gen, 8))
     batch = collate_graphs(graphs)
     for _ in range(5):
         graph_pos = random.randrange(len(graphs))
         source_graph = graphs[graph_pos]
         local_node = random.randrange(source_graph.node_attrs.size(0))

         # Edges whose first index row equals the chosen node.
         local_mask = source_graph.edge_indices[0] == local_node
         expected_attrs = source_graph.edge_attrs[local_mask]

         # Shift the node index by the node counts of all preceding graphs.
         preceding_counts = batch.nodes_per_graph.tolist()[:graph_pos]
         batched_node = sum(preceding_counts) + local_node
         batched_mask = batch.edge_indices[0] == batched_node
         actual_attrs = batch.edge_attrs[batched_mask]

         self.assertEqual(actual_attrs.size(), expected_attrs.size())
         self.assertTrue(actual_attrs.eq(expected_attrs).all())
Пример #7
0
 def test_output_shapes(self) -> None:
     """Check the sizes of every tensor in a collated batch of 8 graphs."""
     batch_size = 8
     graphs = list(islice(self.graph_gen, batch_size))
     batch = collate_graphs(graphs)

     with self.subTest("node_attrs"):
         node_count = sum(graph.node_attrs.size(0) for graph in graphs)
         expected_shape = (node_count, self.n_properties, self.hidden_size)
         self.assertEqual(batch.node_attrs.size(), expected_shape)

     with self.subTest("edges"):
         edge_count = sum(graph.edge_attrs.size(0) for graph in graphs)
         with self.subTest("edge_attrs"):
             expected_attr_shape = (edge_count, self.hidden_size)
             self.assertEqual(batch.edge_attrs.size(), expected_attr_shape)
         with self.subTest("edge_indices"):
             expected_index_shape = (2, edge_count)
             self.assertEqual(batch.edge_indices.size(), expected_index_shape)

     with self.subTest("nodes_per_graph"):
         self.assertEqual(batch.nodes_per_graph.size(0), batch_size)
Пример #8
0
    def setUp(self) -> None:
        """Create an ``NSM`` model with random embeddings, graphs and questions."""
        # These tests should be using expected dimensions
        self.batch_size = 64
        self.output_size = 2000
        n_node_properties = 78
        input_size = 300
        computation_steps = 8

        self.vocab = torch.rand(1335, input_size)
        # One extra row on top of the node properties.
        self.prop_embeds = torch.rand(n_node_properties + 1, input_size)
        self.model = NSM(
            input_size,
            n_node_properties,
            computation_steps,
            self.output_size,
        )

        graph_source = infinite_graphs(
            input_size,
            n_node_properties,
            node_distribution=(16.4, 8.2),
            density_distribution=(0.2, 0.4),
        )
        sampled_graphs = list(islice(graph_source, self.batch_size))
        self.graph_batch = collate_graphs(sampled_graphs)

        # Random questions of 10 to 20 rows each, ordered longest first
        # because pack_sequence is called with its default arguments.
        questions = [
            torch.rand(random.randint(10, 20), input_size)
            for _ in range(self.batch_size)
        ]
        questions.sort(key=len, reverse=True)
        self.question_batch = pack_sequence(questions)
Пример #9
0
 def test_nodes_per_graph(self) -> None:
     """The batch must record each graph's node count and their total."""
     graphs = list(islice(self.graph_gen, 8))
     expected_counts = [graph.node_attrs.size(0) for graph in graphs]
     collated = collate_graphs(graphs)
     self.assertEqual(collated.nodes_per_graph.tolist(), expected_counts)
     self.assertEqual(collated.node_attrs.size(0), sum(expected_counts))
Пример #10
0
 def test_output_type(self) -> None:
     """``collate_graphs`` must return a ``Batch`` instance."""
     graphs = list(islice(self.graph_gen, 8))
     collated = collate_graphs(graphs)
     self.assertIsInstance(collated, Batch)