def setUp(self) -> None:
    """Prepare an NSMCell and the random inputs shared by the tests.

    Builds the model, a collated batch of ``batch_size`` random graphs,
    a random instruction batch, a per-node distribution, and
    softmax-normalised similarity scores split into a node-property
    part and a relation part.
    """
    self.batch_size = 8
    self.input_size = 100
    self.n_node_properties = 20
    self.model = NSMCell(self.input_size, self.n_node_properties)

    # Take exactly ``batch_size`` graphs off the generator.
    graph_source = infinite_graphs(
        self.input_size,
        self.n_node_properties,
        node_distribution=(16.4, 8.2),
        density_distribution=(0.2, 0.4),
    )
    sampled_graphs = list(islice(graph_source, self.batch_size))
    self.graph_batch = collate_graphs(sampled_graphs)

    self.instruction_batch = torch.rand(self.batch_size, self.input_size)

    # Reciprocal of each graph's node count, looked up through
    # ``node_indices``.
    inverse_node_counts = 1.0 / self.graph_batch.nodes_per_graph
    self.distribution = inverse_node_counts[self.graph_batch.node_indices]

    # Project the instructions onto ``n_node_properties + 1`` columns;
    # the last column becomes the relation similarity.
    projection = torch.rand(self.input_size, self.n_node_properties + 1)
    all_prop_similarities = F.softmax(
        self.instruction_batch @ projection,
        dim=1,
    )
    self.node_prop_similarities = all_prop_similarities[:, :-1]
    self.relation_similarity = all_prop_similarities[:, -1]
Example #2
0
 def setUp(self) -> None:
     """Collate the first eight graphs of a random generator into a batch."""
     sampled_graphs = list(
         islice(
             infinite_graphs(
                 300,
                 77,
                 node_distribution=(16.4, 8.2),
                 density_distribution=(0.2, 0.4),
             ),
             8,
         ))
     self.batch = collate_graphs(sampled_graphs)
Example #3
0
 def setUp(self) -> None:
     """Store the fixture sizes and a random graph generator built from them."""
     self.hidden_size, self.n_properties = 300, 77
     # Keyword settings handed to ``infinite_graphs`` unchanged.
     distributions = {
         "node_distribution": (16.4, 8.2),
         "density_distribution": (0.2, 0.4),
     }
     self.graph_gen = infinite_graphs(
         self.hidden_size,
         self.n_properties,
         **distributions,
     )
Example #4
0
    def test_collate_graphs_inverse(self) -> None:
        """Check that ``split_batch`` undoes ``collate_graphs``.

        Seven distinct batch sizes are drawn from 1..20; for each one the
        graphs recovered from the collated batch must match the originals
        field by field.
        """
        sizes = random.sample(range(1, 21), 7)
        graph_source = infinite_graphs()
        for size in sizes:
            originals = list(islice(graph_source, size))
            collated = collate_graphs(originals)
            restored = split_batch(collated)
            self.assertEqual(len(originals), len(restored))

            for before, after in zip(originals, restored):
                for field in before._fields:
                    with self.subTest(batch_len=size, graph_attr=field):
                        # Element-wise comparison; every entry must agree.
                        matches = getattr(before, field) == getattr(
                            after, field)
                        self.assertTrue(matches.all())
Example #5
0
    def __init__(
        self,
        size: int,
        hidden_size: int,
        n_token_distribution: Tuple[float, float],
        n_properties: int,
        node_distribution: Tuple[float, float],
        density_distribution: Tuple[float, float],
        answer_vocab_size: int,
    ) -> None:
        """Generate ``size`` random graphs, questions and answers.

        Args:
            size: Number of graphs, questions and answers to create.
            hidden_size: Passed to ``infinite_graphs``; also the second
                dimension of every question tensor.
            n_token_distribution: Arguments unpacked into ``random.gauss``
                to draw each question's first dimension (at least 1).
            n_properties: Passed to ``infinite_graphs``.
            node_distribution: Passed to ``infinite_graphs``.
            density_distribution: Passed to ``infinite_graphs``.
            answer_vocab_size: Answers are drawn from
                ``range(answer_vocab_size)``.
        """
        graph_source = infinite_graphs(
            hidden_size,
            n_properties,
            node_distribution,
            density_distribution,
        )
        self.graphs = list(islice(graph_source, size))

        questions = []
        for _ in range(size):
            # Truncate the Gaussian draw to an int and keep at least one row.
            n_tokens = max(1, int(random.gauss(*n_token_distribution)))
            questions.append(torch.rand(n_tokens, hidden_size))
        self.questions = questions

        self.answers = random.choices(range(answer_vocab_size), k=size)
    def setUp(self) -> None:
        """Create an NSM model and random vocab, graph and question inputs."""
        # These tests should be using expected dimensions
        input_size = 300
        n_node_properties = 78
        computation_steps = 8
        self.batch_size = 64
        self.output_size = 2000

        self.vocab = torch.rand(1335, input_size)
        self.prop_embeds = torch.rand(n_node_properties + 1, input_size)
        self.model = NSM(
            input_size,
            n_node_properties,
            computation_steps,
            self.output_size,
        )

        # Take exactly ``batch_size`` graphs off the generator.
        graph_source = infinite_graphs(
            input_size,
            n_node_properties,
            node_distribution=(16.4, 8.2),
            density_distribution=(0.2, 0.4),
        )
        self.graph_batch = collate_graphs(
            list(islice(graph_source, self.batch_size)))

        # Random questions with 10 to 20 rows each, ordered longest first
        # before packing.
        questions = [
            torch.rand(random.randint(10, 20), input_size)
            for _ in range(self.batch_size)
        ]
        questions.sort(key=len, reverse=True)
        self.question_batch = pack_sequence(questions)
from tqdm import tqdm
from itertools import islice
import h5py
from nsm.utils import infinite_graphs

# Write the first ``n_graphs`` generated graphs to a scratch HDF5 file:
# one group per graph, holding every entry of ``graph._asdict()`` under its
# own key.
exponent = 4
n_graphs = 10 ** exponent
gen = infinite_graphs()

with h5py.File("data/deleteme.h5", "w") as f:
    numbered_graphs = enumerate(islice(gen, n_graphs))
    for i, graph in tqdm(numbered_graphs, total=n_graphs):
        # Zero-pad the index to ``exponent`` digits.
        name = f"graph{i:0{exponent}}"
        grp = f.create_group(name)
        for key, value in graph._asdict().items():
            grp[key] = value