def test_load_networkx_graphs() -> None:
    """Convert the ``rule_0`` training split to networkx graphs and sanity-check it.

    Expects 5000 graphs, the first being a ``DiGraph`` with 35 nodes, plus
    5000 query-node entries that each contain exactly two nodes.
    """
    graphlog = GraphLog()
    rule_world = graphlog.get_dataset_by_name("rule_0")
    graphs, queries = load_networkx_graphs(rule_world.json_graphs["train"])

    expected_count = 5000
    assert len(graphs) == expected_count

    first_graph = graphs[0]
    assert isinstance(first_graph, DiGraph)
    assert len(first_graph) == 35

    assert len(queries) == expected_count
    # Check each entry individually so a failure points at the offending query.
    for query in queries:
        assert len(query) == 2
def test_dataloader(self):
    """Return the test-split dataloader for the configured training world.

    The world is looked up by ``self.hparams.train_world`` and batched with
    ``self.hparams.batch_size``.
    """
    log.info("Test data loader called.")
    graphlog = GraphLog()
    world = graphlog.get_dataset_by_name(self.hparams.train_world)
    # NOTE(review): an earlier comment said multi-node (ddp) runs need a data
    # sampler added here, but no sampler is attached — confirm whether one is
    # required.
    return graphlog.get_dataloader_by_mode(
        world, mode="test", batch_size=self.hparams.batch_size
    )
def test_single_dataloader() -> None:
    """Check batching and device transfer for every split of ``rule_0``.

    For each split only the first batch is inspected. It must be full
    (``batch_size`` targets), have as many queries as targets, start on the
    CPU, and report the selected device after ``batch.to(device)``. The
    number of batches per split is then compared with the expected count.
    """
    graphlog = GraphLog()
    target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cpu = torch.device("cpu")
    rule_world = graphlog.get_dataset_by_name(name="rule_0")
    batch_size = 32
    expected_batches = {"train": 157, "valid": 32, "test": 32}

    for split_name, num_batches in expected_batches.items():
        loader = graphlog.get_dataloader_by_mode(
            dataset=rule_world, batch_size=batch_size, mode=split_name
        )
        # Only the first batch is examined; the loop exits immediately after.
        for first_batch in loader:
            assert len(first_batch.targets) == batch_size
            assert len(first_batch.queries) == len(first_batch.targets)
            assert first_batch.targets.device == cpu
            first_batch.to(target_device)
            assert first_batch.targets.device == target_device
            break
        assert len(loader) == num_batches
class CheckpointableTestTube(Checkpointable):
    """Checkpointable TestTube class.

    Provides a mechanism to checkpoint the (otherwise stateless) TestTube.
    It bootstraps the config and logbook, loads GraphLog datasets per split,
    and drives either a ``MultitaskExperiment`` (``run``) or an
    ``InferenceExperiment`` (``prepare_evaluator`` / ``evaluate``).
    """

    def __init__(self, config_id, load_checkpoint=True, seed=-1):
        """Bootstrap the config, logbook, torch threading and GraphLog handle.

        Args:
            config_id: identifier passed to ``bootstrap_config``.
            load_checkpoint: accepted for interface compatibility; not read by
                this constructor.
            seed: forwarded to ``bootstrap_config``.
        """
        self.config = bootstrap_config(config_id, seed)
        self.logbook = LogBook(self.config)
        self.num_experiments = self.config.general.num_experiments
        # Torch's thread count is set to the number of experiments.
        torch.set_num_threads(self.num_experiments)
        self.device = self.config.general.device
        self.label2id = {}
        # NOTE(review): nothing in this class assigns a model, so ``run``
        # passes ``None`` unless a caller sets ``self.model`` (e.g. from
        # ``bootstrap_model``) beforehand — confirm this is intended.
        self.model = None
        self.gl = GraphLog()

    def bootstrap_model(self):
        """Instantiate the model shared by all experiments and move it to ``self.device``."""
        model = choose_model(self.config)
        model.to(self.device)
        return model

    def load_label2id(self):
        """Populate ``self.label2id`` from GraphLog and report how many labels exist."""
        self.label2id = self.gl.get_label2id()
        print("Found : {} labels".format(len(self.label2id)))

    def initialize_data(self, mode="train", override_mode=None) -> List[Any]:
        """Load every GraphLog dataset belonging to one split.

        Side effects: refreshes ``self.label2id`` and sets
        ``config.model.num_classes`` to the number of labels.

        Args:
            mode: split key into ``get_dataset_names_by_split()``
                (e.g. ``"train"``, ``"valid"``, ``"test"``).
            override_mode: currently unused; kept for interface compatibility.

        Returns:
            The datasets of the split, in the order GraphLog lists their names.
        """
        names_by_split = self.gl.get_dataset_names_by_split()
        graphworld_list = [
            self.gl.get_dataset_by_name(rule_world) for rule_world in names_by_split[mode]
        ]
        # Load the label map once and derive the class count from it.
        self.load_label2id()
        self.config.model.num_classes = len(self.label2id)
        return graphworld_list

    def run(self):
        """Train with a ``MultitaskExperiment`` over the train/valid/test splits.

        Raises:
            NotImplementedError: if ``config.general.is_meta`` is set.
        """
        write_message_logs("Starting Experiment at {}".format(time.asctime()))
        write_config_log(self.config)
        write_message_logs("torch version = {}".format(torch.__version__))
        if self.config.general.is_meta:
            raise NotImplementedError("NA")

        self.train_data = self.initialize_data(mode="train")
        self.valid_data = self.initialize_data(mode="valid")
        self.test_data = self.initialize_data(mode="test")
        self.experiment = MultitaskExperiment(
            config=self.config,
            model=self.model,
            data=[self.train_data, self.valid_data, self.test_data],
            logbook=self.logbook,
        )
        self.experiment.load_model()
        self.experiment.run()

    def prepare_evaluator(
        self,
        epoch: Optional[int] = None,
        test_data=None,
        zero_init=False,
        override_mode=None,
        label2id=None,
    ):
        """Build and reset the ``InferenceExperiment`` used for evaluation.

        Args:
            epoch: forwarded to ``InferenceExperiment.reset``.
            test_data: optional pre-loaded list of datasets. When falsy, the
                test split is loaded through ``initialize_data``.
            zero_init: forwarded to ``InferenceExperiment.reset``.
            override_mode: forwarded to ``initialize_data``.
            label2id: label map whose size sets ``config.model.num_classes``
                when ``test_data`` is given. Defaults to the GraphLog label
                map loaded into ``self.label2id``.
        """
        self.load_label2id()
        if test_data:
            if label2id is None:
                # Callers such as ``evaluate`` may not supply a label map;
                # fall back to the one just loaded from GraphLog.
                label2id = self.label2id
            self.test_data = test_data
            self.num_graphworlds = len(test_data)
            self.config.model.num_classes = len(label2id)
        else:
            self.test_data = self.initialize_data(mode="test", override_mode=override_mode)
        self.evaluator = InferenceExperiment(self.config, self.logbook, [self.test_data])
        self.evaluator.reset(epoch=epoch, zero_init=zero_init)

    def evaluate(
        self,
        epoch: Optional[int] = None,
        test_data=None,
        ale_mode=False,
        label2id=None,
    ):
        """Prepare the evaluator and return the result of running it.

        Args:
            epoch: forwarded to ``prepare_evaluator``.
            test_data: optional pre-loaded datasets, forwarded to
                ``prepare_evaluator``.
            ale_mode: forwarded to ``InferenceExperiment.run``.
            label2id: optional label map, forwarded to ``prepare_evaluator``.
        """
        self.prepare_evaluator(epoch, test_data, label2id=label2id)
        return self.evaluator.run(ale_mode=ale_mode)
def test_single_dataset() -> None:
    """Loading one dataset by name yields a ``Dataset`` and one entry in ``datasets``."""
    graphlog = GraphLog()
    rule_world = graphlog.get_dataset_by_name("rule_0")
    assert isinstance(rule_world, Dataset)
    assert len(graphlog.datasets) == 1