Exemplo n.º 1
0
def test_load_networkx_graphs() -> None:
    """The rule_0 training split loads as 5000 NetworkX DiGraphs with 2-node queries."""
    graphlog = GraphLog()
    rule_dataset = graphlog.get_dataset_by_name("rule_0")
    graphs, queries = load_networkx_graphs(rule_dataset.json_graphs["train"])
    # The training split is expected to hold exactly 5000 graphs.
    assert len(graphs) == 5000
    assert isinstance(graphs[0], DiGraph)
    assert len(graphs[0]) == 35
    assert len(queries) == 5000
    # Each query is a (source, sink) node pair.
    for node_pair in queries:
        assert len(node_pair) == 2
Exemplo n.º 2
0
    def test_dataloader(self):
        """Return the test-mode dataloader for the configured training world."""
        log.info("Test data loader called.")
        graphlog = GraphLog()
        world = graphlog.get_dataset_by_name(self.hparams.train_world)
        # NOTE: for multi-node (ddp) runs a distributed data sampler would be needed here.
        return graphlog.get_dataloader_by_mode(
            world, mode="test", batch_size=self.hparams.batch_size)
Exemplo n.º 3
0
def test_single_dataloader() -> None:
    """Every split of rule_0 yields CPU batches of the right size that move to `device`."""
    graphlog = GraphLog()
    # Prefer the first CUDA device when available; otherwise stay on CPU.
    target_device = (torch.device("cuda:0")
                     if torch.cuda.is_available() else torch.device("cpu"))
    host_device = torch.device("cpu")
    dataset = graphlog.get_dataset_by_name(name="rule_0")
    batch_size = 32
    expected_num_batches = {"train": 157, "valid": 32, "test": 32}
    for split, num_batches in expected_num_batches.items():
        loader = graphlog.get_dataloader_by_mode(
            dataset=dataset, batch_size=batch_size, mode=split)
        # Only the first batch of each split is inspected.
        for batch in loader:
            assert len(batch.targets) == batch_size
            assert len(batch.queries) == len(batch.targets)
            # Batches are created on the CPU and relocate via .to(device).
            assert batch.targets.device == host_device
            batch.to(target_device)
            assert batch.targets.device == target_device
            break
        assert len(loader) == num_batches
Exemplo n.º 4
0
class CheckpointableTestTube(Checkpointable):
    """Checkpointable TestTube Class

    This class provides a mechanism to checkpoint the (otherwise stateless) TestTube
    """
    def __init__(self, config_id, load_checkpoint=True, seed=-1):
        # NOTE(review): load_checkpoint is accepted but never read in this body —
        # presumably consumed by the Checkpointable base; confirm.
        self.config = bootstrap_config(config_id, seed)
        self.logbook = LogBook(self.config)
        self.num_experiments = self.config.general.num_experiments
        # Cap torch intra-op threads at the number of parallel experiments.
        torch.set_num_threads(self.num_experiments)
        self.device = self.config.general.device
        self.label2id = {}  # populated lazily by load_label2id()
        self.model = None  # built on demand via bootstrap_model()
        self.gl = GraphLog()

    def bootstrap_model(self):
        """Method to instantiate the models that will be common to all
        the experiments."""
        model = choose_model(self.config)
        model.to(self.device)
        return model

    def load_label2id(self):
        """Fetch the label-to-id mapping from GraphLog and cache it on self."""
        self.label2id = self.gl.get_label2id()
        print("Found : {} labels".format(len(self.label2id)))

    def initialize_data(self, mode="train", override_mode=None) -> List[Any]:
        """Load every GraphLog dataset belonging to the given split.

        :param mode: split name ("train", "valid" or "test")
        :param override_mode: NOTE(review) accepted but never used in this
            body — confirm intended semantics before relying on it
        :return: list of dataset objects, one per rule world in the split
        """
        datasets = self.gl.get_dataset_names_by_split()
        graphworld_list = []
        for rule_world in datasets[mode]:
            graphworld_list.append(self.gl.get_dataset_by_name(rule_world))

        # Number of output classes is derived from the global label vocabulary.
        self.config.model.num_classes = len(self.gl.get_label2id())
        self.load_label2id()

        return graphworld_list

    def run(self):
        """Method to run the task"""

        write_message_logs("Starting Experiment at {}".format(
            time.asctime(time.localtime(time.time()))))
        write_config_log(self.config)
        write_message_logs("torch version = {}".format(torch.__version__))

        if not self.config.general.is_meta:
            self.train_data = self.initialize_data(mode="train")
            self.valid_data = self.initialize_data(mode="valid")
            self.test_data = self.initialize_data(mode="test")
            self.experiment = MultitaskExperiment(
                config=self.config,
                model=self.model,
                data=[self.train_data, self.valid_data, self.test_data],
                logbook=self.logbook,
            )
        else:
            # Meta-learning mode is not supported by this runner.
            raise NotImplementedError("NA")
        self.experiment.load_model()
        self.experiment.run()

    def prepare_evaluator(
        self,
        epoch: Optional[int] = None,
        test_data=None,
        zero_init=False,
        override_mode=None,
        label2id=None,
    ):
        """Build an InferenceExperiment over the test data and reset it.

        When ``test_data`` is supplied, a matching ``label2id`` mapping is
        required so the model's class count can be set accordingly.
        """
        self.load_label2id()
        if test_data:
            assert label2id is not None
            self.test_data = test_data
            self.num_graphworlds = len(test_data)
            self.config.model.num_classes = len(label2id)
        else:
            self.test_data = self.initialize_data(mode="test",
                                                  override_mode=override_mode)
        self.evaluator = InferenceExperiment(self.config, self.logbook,
                                             [self.test_data])
        self.evaluator.reset(epoch=epoch, zero_init=zero_init)

    def evaluate(self,
                 epoch: Optional[int] = None,
                 test_data=None,
                 ale_mode=False,
                 label2id=None):
        """Run inference and return its result.

        :param epoch: checkpoint epoch to restore, if any
        :param test_data: optional pre-built test datasets; when given,
            ``label2id`` must also be given (prepare_evaluator asserts it)
        :param ale_mode: forwarded to the evaluator's run()
        :param label2id: label mapping matching ``test_data``; previously this
            was never forwarded, so any call with ``test_data`` tripped the
            assertion in prepare_evaluator
        """
        # Forward label2id so the explicit-test_data path is actually usable.
        self.prepare_evaluator(epoch, test_data, label2id=label2id)
        return self.evaluator.run(ale_mode=ale_mode)
Exemplo n.º 5
0
def test_single_dataset() -> None:
    """Fetching one named dataset registers exactly one Dataset on the client."""
    client = GraphLog()
    rule_world = client.get_dataset_by_name("rule_0")
    assert isinstance(rule_world, Dataset)
    # Only the single requested dataset should be cached.
    assert len(client.datasets) == 1