def setUp(self) -> None:
    """Create an ``NSMCell`` and random inputs shared by the tests.

    Sets ``batch_size``, ``input_size`` and ``n_node_properties``, builds
    the model, collates ``batch_size`` generated graphs into
    ``graph_batch``, and draws a random ``instruction_batch`` of shape
    ``(batch_size, input_size)``. Also stores ``distribution``,
    ``node_prop_similarities`` and ``relation_similarity``.
    """
    self.batch_size = 8
    self.input_size = 100
    self.n_node_properties = 20
    self.model = NSMCell(self.input_size, self.n_node_properties)

    # Take the first ``batch_size`` graphs from the endless generator.
    graph_source = infinite_graphs(
        self.input_size,
        self.n_node_properties,
        node_distribution=(16.4, 8.2),
        density_distribution=(0.2, 0.4),
    )
    sampled_graphs = list(islice(graph_source, self.batch_size))
    self.graph_batch = collate_graphs(sampled_graphs)

    self.instruction_batch = torch.rand(self.batch_size, self.input_size)

    # Reciprocal of each graph's node count, gathered through node_indices
    # (presumably one entry per node, i.e. uniform within each graph --
    # confirm against the Batch definition).
    inverse_node_counts = 1.0 / self.graph_batch.nodes_per_graph
    self.distribution = inverse_node_counts[self.graph_batch.node_indices]

    # Softmax over ``n_node_properties + 1`` columns: every column but the
    # last goes to the node properties, the last one to the relation.
    random_projection = torch.rand(
        self.input_size, self.n_node_properties + 1
    )
    similarity_scores = self.instruction_batch @ random_projection
    all_prop_similarities = F.softmax(similarity_scores, dim=1)
    self.node_prop_similarities = all_prop_similarities[:, :-1]
    self.relation_similarity = all_prop_similarities[:, -1]
def setUp(self) -> None:
    """Collate the first eight generated graphs into ``self.batch``."""
    graph_stream = infinite_graphs(
        300,
        77,
        node_distribution=(16.4, 8.2),
        density_distribution=(0.2, 0.4),
    )
    first_eight = list(islice(graph_stream, 8))
    self.batch = collate_graphs(first_eight)
def collate_fn(batch):
    """Merge samples into one 6-tuple of batched inputs.

    Each element of ``batch`` is unpacked as
    ``(graph, question, target, instructions)``. ``vocab`` is read from
    the enclosing scope.

    Returns:
        A tuple of the collated graphs, the packed question sequences,
        ``vocab.concept_embeddings``, ``vocab.property_embeddings``, a
        tensor of the targets, and the stacked instructions.
    """
    graph_seq, question_seq, target_seq, instruction_seq = zip(*batch)

    graph_batch = collate_graphs(graph_seq)
    # enforce_sorted=False lets questions arrive in any length order.
    packed_questions = torch.nn.utils.rnn.pack_sequence(
        question_seq, enforce_sorted=False
    )
    concept_embeddings = vocab.concept_embeddings
    property_embeddings = vocab.property_embeddings
    target_tensor = torch.tensor(target_seq)
    stacked_instructions = torch.stack(instruction_seq)

    return (
        graph_batch,
        packed_questions,
        concept_embeddings,
        property_embeddings,
        target_tensor,
        stacked_instructions,
    )
def test_collate_graphs_inverse(self) -> None:
    """Check that ``split_batch`` undoes ``collate_graphs``.

    Seven distinct batch sizes are drawn from 1..20. For each size the
    graphs are collated and split again, and every field of every
    restored graph must equal the original element-wise.
    """
    sizes = random.sample(range(1, 21), 7)
    graph_stream = infinite_graphs()
    for size in sizes:
        originals = list(islice(graph_stream, size))
        collated = collate_graphs(originals)
        restored = split_batch(collated)
        self.assertEqual(len(originals), len(restored))
        for before, after in zip(originals, restored):
            for field in before._fields:
                with self.subTest(batch_len=size, graph_attr=field):
                    matches = getattr(before, field) == getattr(after, field)
                    self.assertTrue(matches.all())
def collate_fn(batch):
    """Merge samples into a 6-tuple of batched inputs.

    Each element of ``batch`` is unpacked as ``(graph, question, target)``.
    ``vocab`` is read from the enclosing scope.

    Returns:
        A tuple of the collated graphs, the packed question sequences,
        ``vocab.concept_embeddings``, ``vocab.property_embeddings``, a
        tensor of the targets, and a one-element placeholder tensor.
    """
    graph_seq, question_seq, target_seq = zip(*batch)

    graph_batch = collate_graphs(graph_seq)
    packed_questions = torch.nn.utils.rnn.pack_sequence(
        question_seq, enforce_sorted=False
    )
    concept_embeddings = vocab.concept_embeddings
    property_embeddings = vocab.property_embeddings
    target_tensor = torch.tensor(target_seq)
    # The latest NSMLightningModule iteration expects a 6-tuple. Gold
    # instructions are not used here, so a dummy tensor fills that slot.
    unused_instructions = torch.empty(1)

    return (
        graph_batch,
        packed_questions,
        concept_embeddings,
        property_embeddings,
        target_tensor,
        unused_instructions,
    )
def test_edge_indices_shift_edge_attrs(self) -> None:
    """Check that shifted edge indices still select the same edge attrs.

    Five times, a random node of a random graph is picked. The edge
    attributes selected by that node in the original graph must match
    those selected in the batch once the node index is offset by the
    node counts of all preceding graphs.
    """
    graphs = list(islice(self.graph_gen, 8))
    batch = collate_graphs(graphs)
    for _ in range(5):
        # Pick a random graph, then a random node inside it.
        g_idx = random.randrange(len(graphs))
        graph = graphs[g_idx]
        node_idx = random.randrange(graph.node_attrs.size(0))

        local_mask = graph.edge_indices[0] == node_idx
        expected_attrs = graph.edge_attrs[local_mask]

        # Nodes of earlier graphs come first, so shift by their total count.
        node_offset = sum(batch.nodes_per_graph.tolist()[:g_idx])
        batched_node_idx = node_offset + node_idx
        batched_mask = batch.edge_indices[0] == batched_node_idx
        actual_attrs = batch.edge_attrs[batched_mask]

        self.assertEqual(actual_attrs.size(), expected_attrs.size())
        self.assertTrue(actual_attrs.eq(expected_attrs).all())
def test_output_shapes(self) -> None:
    """Check the sizes of the collated batch's tensors.

    Node attributes, edge attributes and edge indices must have sizes
    derived from the summed node and edge counts of the input graphs, and
    ``nodes_per_graph`` must have one entry per graph.
    """
    batch_size = 8
    graphs = list(islice(self.graph_gen, batch_size))
    batch = collate_graphs(graphs)

    with self.subTest("node_attrs"):
        node_counts = [g.node_attrs.size(0) for g in graphs]
        expected_node_shape = (
            sum(node_counts),
            self.n_properties,
            self.hidden_size,
        )
        self.assertEqual(batch.node_attrs.size(), expected_node_shape)

    with self.subTest("edges"):
        edge_counts = [g.edge_attrs.size(0) for g in graphs]
        total_edges = sum(edge_counts)
        with self.subTest("edge_attrs"):
            expected_attr_shape = (total_edges, self.hidden_size)
            self.assertEqual(batch.edge_attrs.size(), expected_attr_shape)
        with self.subTest("edge_indices"):
            expected_index_shape = (2, total_edges)
            self.assertEqual(batch.edge_indices.size(), expected_index_shape)

    with self.subTest("nodes_per_graph"):
        self.assertEqual(batch.nodes_per_graph.size(0), batch_size)
def setUp(self) -> None:
    """Create an ``NSM`` model and random inputs for the tests.

    Stores ``batch_size``, ``output_size``, random ``vocab`` and
    ``prop_embeds`` tensors, the model, a collated ``graph_batch`` and a
    packed ``question_batch`` of random variable-length sequences.
    """
    # These tests should be using expected dimensions
    self.batch_size = 64
    self.output_size = 2000
    n_node_properties = 78
    input_size = 300
    computation_steps = 8

    self.vocab = torch.rand(1335, input_size)
    self.prop_embeds = torch.rand(n_node_properties + 1, input_size)
    self.model = NSM(
        input_size,
        n_node_properties,
        computation_steps,
        self.output_size,
    )

    graph_source = infinite_graphs(
        input_size,
        n_node_properties,
        node_distribution=(16.4, 8.2),
        density_distribution=(0.2, 0.4),
    )
    sampled_graphs = list(islice(graph_source, self.batch_size))
    self.graph_batch = collate_graphs(sampled_graphs)

    # One random question per batch item, 10 to 20 rows long.
    questions = []
    for _ in range(self.batch_size):
        question_len = random.randint(10, 20)
        questions.append(torch.rand(question_len, input_size))
    # Longest first, since pack_sequence is called with its defaults.
    questions.sort(key=len, reverse=True)
    self.question_batch = pack_sequence(questions)
def test_nodes_per_graph(self) -> None:
    """Check that the batch records each graph's node count.

    ``nodes_per_graph`` must list the node count of every input graph in
    order, and the batched node attributes must have as many rows as the
    counts sum to.
    """
    graphs = list(islice(self.graph_gen, 8))
    expected_counts = []
    for graph in graphs:
        expected_counts.append(graph.node_attrs.size(0))

    batch = collate_graphs(graphs)

    self.assertEqual(batch.nodes_per_graph.tolist(), expected_counts)
    self.assertEqual(batch.node_attrs.size(0), sum(expected_counts))
def test_output_type(self) -> None:
    """Check that ``collate_graphs`` returns a ``Batch`` instance."""
    eight_graphs = list(islice(self.graph_gen, 8))
    collated = collate_graphs(eight_graphs)
    self.assertIsInstance(collated, Batch)