def test_serialize_tensors():
    """Round-trip a dict of tensors through save_tensors / load_tensors.

    Checks both load modes: the default one, whose values are converted
    with ``F.asnumpy``, and ``return_dgl_ndarray=True``, whose values must
    be ``nd.NDArray`` instances. The temporary file is removed even when
    an assertion fails, so failed runs do not leave files behind.
    """
    # create a temporary file and immediately release it so DGL can open it.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        path = f.name

    try:
        tensor_dict = {
            "a": F.tensor([1, 3, -1, 0], dtype=F.int64),
            "1@1": F.tensor([1.5, 2], dtype=F.float32),
        }

        save_tensors(path, tensor_dict)

        # Default load: values are compared after conversion via F.asnumpy.
        load_tensor_dict = load_tensors(path)
        for key in tensor_dict:
            assert key in load_tensor_dict
            assert np.array_equal(
                F.asnumpy(load_tensor_dict[key]), F.asnumpy(tensor_dict[key])
            )

        # NDArray load: values must be DGL NDArrays with identical content.
        load_nd_dict = load_tensors(path, return_dgl_ndarray=True)
        for key in tensor_dict:
            assert key in load_nd_dict
            assert isinstance(load_nd_dict[key], nd.NDArray)
            assert np.array_equal(
                load_nd_dict[key].asnumpy(), F.asnumpy(tensor_dict[key])
            )
    finally:
        # delete=False means nothing else cleans this file up; do it here
        # so a failing assertion above cannot leak it.
        os.unlink(path)
def test_serialize_empty_dict():
    """Saving an empty dict and loading it back must yield an empty dict.

    The temporary file is removed even when an assertion fails, so failed
    runs do not leave files behind.
    """
    # create a temporary file and immediately release it so DGL can open it.
    with tempfile.NamedTemporaryFile(delete=False) as f:
        path = f.name

    try:
        tensor_dict = {}

        save_tensors(path, tensor_dict)

        load_tensor_dict = load_tensors(path)
        assert isinstance(load_tensor_dict, dict)
        assert len(load_tensor_dict) == 0
    finally:
        # delete=False means nothing else cleans this file up; do it here
        # so a failing assertion above cannot leak it.
        os.unlink(path)
elif args.dataset == 'pubmed': part_config = os.path.join(part_config, "Libra_result_pubmed", str(nc) + "Communities", "pubmed.json") else: print("Error: Dataset not found !!!") sys.exit(1) print("Dataset/partition location: ", part_config) with open(part_config) as conf_f: part_metadata = json.load(conf_f) part_files = part_metadata['part-{}'.format(args.rank)] assert 'node_feats' in part_files, "the partition does not contain node features." assert 'edge_feats' in part_files, "the partition does not contain edge feature." assert 'part_graph' in part_files, "the partition does not contain graph structure." node_feats = load_tensors(part_files['node_feats']) # edge_feats = load_tensors(part_files['edge_feats']) graph = load_graphs(part_files['part_graph'])[0][0] num_parts = part_metadata['num_parts'] node_map = part_metadata['node_map'] graph.ndata['feat'] = node_feats['feat'] graph.ndata['lf'] = node_feats['lf'] graph.ndata['label'] = node_feats['label'] graph.ndata['train_mask'] = node_feats['train_mask'] graph.ndata['test_mask'] = node_feats['test_mask'] graph.ndata['val_mask'] = node_feats['val_mask'] g_orig = load_data(args)[0]