예제 #1
0
def test_compute_similarity() -> None:
    """Similarity of rule_0 to other rule worlds decays with rule-id distance."""
    gl = GraphLog()
    # Expected similarity score keyed by the comparison dataset name.
    expected = {
        "rule_0": 1.0,
        "rule_2": 0.9,
        "rule_10": 0.5,
        "rule_20": 0.0,
        "rule_50": 0.0,
    }
    for other, score in expected.items():
        assert gl.compute_similarity("rule_0", other) == score
예제 #2
0
 def __init__(self, config_id, load_checkpoint=True, seed=-1):
     # Bootstrap the run configuration and logging for this experiment id.
     # NOTE(review): `load_checkpoint` is unused in this body — presumably
     # consumed by the enclosing class or a base class; confirm.
     self.config = bootstrap_config(config_id, seed)
     self.logbook = LogBook(self.config)
     self.num_experiments = self.config.general.num_experiments
     # Cap torch intra-op threads at the number of parallel experiments.
     torch.set_num_threads(self.num_experiments)
     self.device = self.config.general.device
     # label -> integer-id mapping; populated later (starts empty).
     self.label2id = {}
     # Model is instantiated lazily elsewhere, not in the constructor.
     self.model = None
     self.gl = GraphLog()
예제 #3
0
def test_load_networkx_graphs() -> None:
    """Train split of rule_0 converts to 5000 DiGraphs with node-pair queries."""
    gl = GraphLog()
    dataset = gl.get_dataset_by_name("rule_0")
    nx_graphs, query_nodes = load_networkx_graphs(dataset.json_graphs["train"])
    # One graph and one query per training example.
    assert len(nx_graphs) == 5000
    first = nx_graphs[0]
    assert isinstance(first, DiGraph)
    assert len(first) == 35
    assert len(query_nodes) == 5000
    # Every query is a (source, sink) node pair.
    for pair in query_nodes:
        assert len(pair) == 2
예제 #4
0
    def test_dataloader(self):
        """Return the test-mode dataloader for the configured training world."""
        log.info("Test data loader called.")
        graphlog = GraphLog()
        world = graphlog.get_dataset_by_name(self.hparams.train_world)
        # when using multi-node (ddp) we need to add the  datasampler
        return graphlog.get_dataloader_by_mode(
            world,
            mode="test",
            batch_size=self.hparams.batch_size,
        )
예제 #5
0
def test_paper_data_ids() -> None:
    """Splits match the paper: rules 0-50 train, 51-53 valid, 54-56 test."""
    gl = GraphLog()
    split_ranges = {
        "train": range(0, 51),
        "valid": range(51, 54),
        "test": range(54, 57),
    }
    data_by_split = gl.get_dataset_names_by_split()
    for split, rule_ids in split_ranges.items():
        for d in rule_ids:
            assert f"rule_{d}" in data_by_split[split]
예제 #6
0
def test_get_most_similar_datasets() -> None:
    """Top-5 neighbours of rule_0 arrive in descending-similarity order."""
    gl = GraphLog()
    sim = gl.get_most_similar_datasets("rule_0", 5)
    true_sim = [
        ("rule_0", 1.0),
        ("rule_1", 0.95),
        ("rule_2", 0.9),
        ("rule_3", 0.85),
        ("rule_4", 0.8),
    ]
    for idx, (name, score) in enumerate(sim):
        assert name == true_sim[idx][0]
        assert score == true_sim[idx][1]
예제 #7
0
def test_stats() -> None:
    """Computed statistics for rule_0 match the published reference values."""
    gl = GraphLog()
    stats = gl.compute_stats_by_dataset(name="rule_0")
    expected_stats = {
        "num_class": 17,
        "num_des": 286,
        "avg_resolution_length": 4.485714285714286,
        "num_nodes": 15.487,
        "num_edges": 19.295,
        "split": "train",
    }
    # Same entry count, and every expected key maps to the expected value.
    assert len(stats) == len(expected_stats)
    for key, value in expected_stats.items():
        assert stats[key] == value
예제 #8
0
def test_label2id() -> None:
    """label2id mapping exists on disk and contains the expected entries."""
    gl = GraphLog()
    data_loc = os.path.join(gl.data_dir, gl.data_filename)
    label2id_loc = os.path.join(data_loc, "train", "label2id.json")
    assert os.path.exists(label2id_loc) & os.path.isfile(label2id_loc)
    # BUG FIX: this comparison was a bare expression (no `assert`), so the
    # length check was silently never enforced.
    assert len(gl.label2id.keys()) == 21
    # Index 0 is reserved for the unknown-relation label.
    assert gl.label2id["UNK_REL"] == 0
예제 #9
0
def test_download() -> None:
    """Downloaded data directory exists and contains all three split folders."""
    gl = GraphLog()
    data_loc = os.path.join(gl.data_dir, gl.data_filename)
    assert os.path.exists(data_loc) & os.path.isdir(data_loc)
    # Each split must be a directory directly under the data root.
    for split in ("train", "valid", "test"):
        split_loc = os.path.join(data_loc, split)
        assert os.path.exists(split_loc) & os.path.isdir(split_loc)
예제 #10
0
def test_single_dataloader() -> None:
    """Per-split dataloaders for rule_0 have expected sizes; batches move to device."""
    gl = GraphLog()
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    cpu_device = torch.device("cpu")
    dataset = gl.get_dataset_by_name(name="rule_0")
    batch_size = 32
    dataloader_size = {"train": 157, "valid": 32, "test": 32}
    for mode, size in dataloader_size.items():
        dataloader = gl.get_dataloader_by_mode(
            dataset=dataset,
            batch_size=batch_size,
            mode=mode,
        )
        # Inspect only the first batch of each split.
        for batch in dataloader:
            assert len(batch.targets) == batch_size
            assert len(batch.queries) == len(batch.targets)
            # Batches start on CPU and migrate only after an explicit .to().
            assert batch.targets.device == cpu_device
            batch.to(device)
            assert batch.targets.device == device
            break
        assert len(dataloader) == size
예제 #11
0
class CheckpointableTestTube(Checkpointable):
    """Checkpointable TestTube Class

    This class provides a mechanism to checkpoint the (otherwise stateless) TestTube
    """
    def __init__(self, config_id, load_checkpoint=True, seed=-1):
        # NOTE(review): `load_checkpoint` is unused in this constructor —
        # presumably consumed by the Checkpointable base; confirm.
        self.config = bootstrap_config(config_id, seed)
        self.logbook = LogBook(self.config)
        self.num_experiments = self.config.general.num_experiments
        # Cap torch intra-op threads at the number of parallel experiments.
        torch.set_num_threads(self.num_experiments)
        self.device = self.config.general.device
        # label -> id mapping; filled by load_label2id().
        self.label2id = {}
        # Model is built lazily via bootstrap_model(), not here.
        self.model = None
        self.gl = GraphLog()

    def bootstrap_model(self):
        """Method to instantiate the models that will be common to all
        the experiments."""
        model = choose_model(self.config)
        model.to(self.device)
        return model

    def load_label2id(self):
        # Refresh the label mapping from the GraphLog data on disk.
        self.label2id = self.gl.get_label2id()
        print("Found : {} labels".format(len(self.label2id)))

    def initialize_data(self, mode="train", override_mode=None) -> List[Any]:
        """
        Load and initialize data here
        :return:
        """
        # NOTE(review): `override_mode` is accepted but never used in this
        # body, even though prepare_evaluator() passes it — verify intent.
        datasets = self.gl.get_dataset_names_by_split()
        graphworld_list = []
        for rule_world in datasets[mode]:
            graphworld_list.append(self.gl.get_dataset_by_name(rule_world))

        # The classifier output size must match the label vocabulary.
        self.config.model.num_classes = len(self.gl.get_label2id())
        self.load_label2id()

        return graphworld_list

    def run(self):
        """Method to run the task"""

        write_message_logs("Starting Experiment at {}".format(
            time.asctime(time.localtime(time.time()))))
        write_config_log(self.config)
        write_message_logs("torch version = {}".format(torch.__version__))

        # Only the non-meta (multitask) path is implemented.
        if not self.config.general.is_meta:
            self.train_data = self.initialize_data(mode="train")
            self.valid_data = self.initialize_data(mode="valid")
            self.test_data = self.initialize_data(mode="test")
            self.experiment = MultitaskExperiment(
                config=self.config,
                model=self.model,
                data=[self.train_data, self.valid_data, self.test_data],
                logbook=self.logbook,
            )
        else:
            raise NotImplementedError("NA")
        # Restore any saved weights before running.
        self.experiment.load_model()
        self.experiment.run()

    def prepare_evaluator(
        self,
        epoch: Optional[int] = None,
        test_data=None,
        zero_init=False,
        override_mode=None,
        label2id=None,
    ):
        self.load_label2id()
        if test_data:
            # Caller supplies its own data and must also supply its label map.
            # NOTE(review): evaluate() passes test_data without label2id, so
            # that path would trip this assert — confirm expected usage.
            assert label2id is not None
            self.test_data = test_data
            self.num_graphworlds = len(test_data)
            self.config.model.num_classes = len(label2id)
        else:
            self.test_data = self.initialize_data(mode="test",
                                                  override_mode=override_mode)
        self.evaluator = InferenceExperiment(self.config, self.logbook,
                                             [self.test_data])
        self.evaluator.reset(epoch=epoch, zero_init=zero_init)

    def evaluate(self,
                 epoch: Optional[int] = None,
                 test_data=None,
                 ale_mode=False):
        # Build (or rebuild) the evaluator, then run a single evaluation pass.
        self.prepare_evaluator(epoch, test_data)
        return self.evaluator.run(ale_mode=ale_mode)
예제 #12
0
#
"""
import os

import torch
from networkx import DiGraph
from torch.utils.data import Dataset

from graphlog import GraphLog
from graphlog.utils import load_networkx_graphs
import pytest


@pytest.mark.parametrize(  # type: ignore
    "gl",
    [GraphLog(), GraphLog(data_key="graphlog_v1.1")],
)
def test_download(gl) -> None:
    """Both data versions download a root directory with all three splits."""
    data_loc = os.path.join(gl.data_dir, gl.data_filename)
    assert os.path.exists(data_loc) & os.path.isdir(data_loc)
    # Each split must be a directory directly under the data root.
    for split in ("train", "valid", "test"):
        split_loc = os.path.join(data_loc, split)
        assert os.path.exists(split_loc) & os.path.isdir(split_loc)


@pytest.mark.parametrize(  # type: ignore
    "gl",
    [GraphLog(), GraphLog(data_key="graphlog_v1.1")],
예제 #13
0
def test_all_dataset_loading() -> None:
    """load_datasets pulls in every rule world (51 train + 3 valid + 3 test)."""
    gl = GraphLog()
    gl.load_datasets()
    expected_total = 57
    assert len(gl.datasets) == expected_total
예제 #14
0
def test_single_dataset() -> None:
    """Fetching one world loads exactly that dataset and nothing else."""
    gl = GraphLog()
    loaded = gl.get_dataset_by_name("rule_0")
    assert isinstance(loaded, Dataset)
    # Only the requested world should be cached so far.
    assert len(gl.datasets) == 1