def test_basic(self):
    """Round-trip a minimal config through the checkpoint metadata provider.

    The metadata returned by ``ConfigMetadataProvider`` must be a dict whose
    only key is ``config/json``. That entry's JSON text must parse back into
    a ``ConfigSchema`` equal to the one that was provided.
    """
    entity_types = {"e": EntitySchema(num_partitions=1)}
    relation_types = [RelationSchema(name="r", lhs="e", rhs="e")]
    config = ConfigSchema(
        entities=entity_types,
        relations=relation_types,
        dimension=1,
        entity_path="foo",
        edge_paths=["bar"],
        checkpoint_path="baz",
    )

    provider = ConfigMetadataProvider(config)
    metadata = provider.get_checkpoint_metadata()

    self.assertIsInstance(metadata, dict)
    self.assertCountEqual(metadata.keys(), ["config/json"])

    # Rebuild the config from the serialized payload and compare it with
    # the original.
    serialized = metadata["config/json"]
    restored = ConfigSchema.from_dict(json.loads(serialized))
    self.assertEqual(config, restored)
def read_config(self) -> ConfigSchema:
    """Fetch the stored config from the storage backend and parse it.

    Returns:
        A ``ConfigSchema`` built from the JSON text returned by
        ``self.storage.load_config()``.
    """
    parsed = json.loads(self.storage.load_config())
    return ConfigSchema.from_dict(parsed)
lr=0.1, num_uniform_negs=50, eval_fraction= 0, # to reproduce results, we need to use all training data workers=1, distributed_init_method="tpc://localhost:30050", ) for num_part in args.num_parts: datadir = "{}_big_{}".format(args.dataset, num_part) config_dict['entity_path'] = os.path.join(args.root_output, datadir) config_dict['entities']['all']['num_partitions'] = num_part config_dict['edge_paths'] = [ os.path.join(args.root_output, datadir, datadir) ] config = ConfigSchema.from_dict(config_dict) convert_input_data( config.entities, config.relations, config.entity_path, config.edge_paths, [ Path( os.path.join( args.root_output, "{}_text/edgelist_pybig.txt".format(args.dataset))) ], lhs_col=0, rhs_col=2, rel_col=1,
def read_config(self) -> ConfigSchema:
    """Read this checkpoint's JSON config file and parse it into a ConfigSchema.

    The file read is ``CONFIG_FILE`` inside the directory ``self.path``.

    Returns:
        A ``ConfigSchema`` built from the file's decoded JSON content.

    Raises:
        OSError: if the config file cannot be opened (e.g. it does not exist).
        UnicodeDecodeError: if the file is not valid UTF-8.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    config_path = os.path.join(self.path, CONFIG_FILE)
    # Decode explicitly as UTF-8, the JSON interchange encoding, instead of
    # relying on the platform's locale-dependent default text encoding.
    with open(config_path, "rt", encoding="utf-8") as tf:
        return ConfigSchema.from_dict(json.load(tf))