コード例 #1
0
ファイル: test_serialize.py プロジェクト: simco19/dgl
def test_serialize_tensors():
    """Round-trip a dict of tensors through ``save_tensors``/``load_tensors``.

    Checks that every saved key is present after loading and that its values
    are unchanged, both for the default return type and for the
    ``return_dgl_ndarray=True`` path, which must yield ``nd.NDArray`` objects.
    The temporary file is removed even if an assertion fails.
    """
    # Create a temporary file and immediately release the handle so DGL can
    # open the path itself (reopening a still-open NamedTemporaryFile by name
    # is not portable across platforms).
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()

    try:
        tensor_dict = {
            "a": F.tensor([1, 3, -1, 0], dtype=F.int64),
            # NOTE(review): presumably a non-identifier key, to check that
            # keys containing '@' survive serialization — confirm intent.
            "1@1": F.tensor([1.5, 2], dtype=F.float32),
        }

        save_tensors(path, tensor_dict)

        # Default load path: values come back as framework tensors.
        load_tensor_dict = load_tensors(path)
        for key, expected in tensor_dict.items():
            assert key in load_tensor_dict
            assert np.array_equal(F.asnumpy(load_tensor_dict[key]),
                                  F.asnumpy(expected))

        # DGL NDArray load path.
        load_nd_dict = load_tensors(path, return_dgl_ndarray=True)
        for key, expected in tensor_dict.items():
            assert key in load_nd_dict
            assert isinstance(load_nd_dict[key], nd.NDArray)
            assert np.array_equal(load_nd_dict[key].asnumpy(),
                                  F.asnumpy(expected))
    finally:
        # Clean up the temp file regardless of whether the checks passed.
        os.unlink(path)
コード例 #2
0
ファイル: test_serialize.py プロジェクト: simco19/dgl
def test_serialize_empty_dict():
    """Check that saving an empty tensor dict loads back as an empty ``dict``.

    The temporary file is removed even if an assertion fails.
    """
    # Create a temporary file and immediately release the handle so DGL can
    # open the path itself (reopening a still-open NamedTemporaryFile by name
    # is not portable across platforms).
    f = tempfile.NamedTemporaryFile(delete=False)
    path = f.name
    f.close()

    try:
        tensor_dict = {}

        save_tensors(path, tensor_dict)

        load_tensor_dict = load_tensors(path)
        assert isinstance(load_tensor_dict, dict)
        assert len(load_tensor_dict) == 0
    finally:
        # Clean up the temp file regardless of whether the checks passed.
        os.unlink(path)
コード例 #3
0
ファイル: train_dist_sym.py プロジェクト: yuk12/dgl
    elif args.dataset == 'pubmed':
        part_config = os.path.join(part_config, "Libra_result_pubmed",
                                   str(nc) + "Communities", "pubmed.json")
    else:
        print("Error: Dataset not found !!!")
        sys.exit(1)

    print("Dataset/partition location: ", part_config)
    with open(part_config) as conf_f:
        part_metadata = json.load(conf_f)

    part_files = part_metadata['part-{}'.format(args.rank)]
    assert 'node_feats' in part_files, "the partition does not contain node features."
    assert 'edge_feats' in part_files, "the partition does not contain edge feature."
    assert 'part_graph' in part_files, "the partition does not contain graph structure."
    node_feats = load_tensors(part_files['node_feats'])
    # edge_feats = load_tensors(part_files['edge_feats'])
    graph = load_graphs(part_files['part_graph'])[0][0]

    num_parts = part_metadata['num_parts']
    node_map = part_metadata['node_map']

    graph.ndata['feat'] = node_feats['feat']
    graph.ndata['lf'] = node_feats['lf']
    graph.ndata['label'] = node_feats['label']
    graph.ndata['train_mask'] = node_feats['train_mask']
    graph.ndata['test_mask'] = node_feats['test_mask']
    graph.ndata['val_mask'] = node_feats['val_mask']

    g_orig = load_data(args)[0]