示例#1
0
文件: qm9.py 项目: yaniv256/spektral
def load_data(nf_keys=None,
              ef_keys=None,
              auto_pad=True,
              self_loops=False,
              amount=None,
              return_type='numpy'):
    """
    Loads the QM9 chemical data set of small molecules.

    Nodes represent heavy atoms (hydrogens are discarded), edges represent
    chemical bonds.

    The node features represent the chemical properties of each atom, and are
    loaded according to the `nf_keys` argument.
    See `spektral.datasets.qm9.NODE_FEATURES` for possible node features, and
    see [this link](http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx)
    for the meaning of each property. Usually, it is sufficient to load the
    atomic number.

    The edge features represent the type and stereochemistry of each chemical
    bond between two atoms.
    See `spektral.datasets.qm9.EDGE_FEATURES` for possible edge features, and
    see [this link](http://www.nonlinear.com/progenesis/sdf-studio/v0.9/faq/sdf-file-format-guidance.aspx)
    for the meaning of each property. Usually, it is sufficient to load the
    type of bond.

    :param nf_keys: list or str, node features to return (see `qm9.NODE_FEATURES`
    for available features); if None, all of `NODE_FEATURES` are used;
    :param ef_keys: list or str, edge features to return (see `qm9.EDGE_FEATURES`
    for available features); if None, all of `EDGE_FEATURES` are used;
    :param auto_pad: if `return_type='numpy'`, zero pad graph matrices to have
    the same number of nodes;
    :param self_loops: if `return_type='numpy'`, add self loops to adjacency
    matrices;
    :param amount: the amount of molecules to return (in ascending order by
    number of atoms); the labels are truncated to the same amount.
    :param return_type: `'numpy'`, `'networkx'`, or `'sdf'`, data format to return;
    :return:
    - if `return_type='numpy'`, the adjacency matrix, node features,
    edge features, and a Pandas dataframe containing labels;
    - if `return_type='networkx'`, a list of graphs in Networkx format,
    and a dataframe containing labels;
    - if `return_type='sdf'`, a list of molecules in the internal SDF format and
    a dataframe containing labels.
    :raises ValueError: if `return_type` is not one of `RETURN_TYPES`.
    """
    if return_type not in RETURN_TYPES:
        raise ValueError('Possible return_type: {}'.format(RETURN_TYPES))

    if not os.path.exists(DATA_PATH):
        _download_data()  # Try to download dataset

    print('Loading QM9 dataset.')
    sdf_file = os.path.join(DATA_PATH, 'qm9.sdf')
    data = load_sdf(sdf_file, amount=amount)  # Internal SDF format

    # Load labels, keeping them aligned with the (possibly truncated) molecules
    labels_file = os.path.join(DATA_PATH, 'qm9.sdf.csv')
    labels = load_csv(labels_file)
    if amount is not None:
        labels = labels[:amount]

    # Compare by value (==), never by identity (`is`): whether two equal
    # strings are the same object is an interpreter implementation detail.
    if return_type == 'sdf':
        return data, labels

    # Both remaining formats are built from Networkx graphs
    data = [sdf_to_nx(mol) for mol in data]

    if return_type == 'networkx':
        return data, labels

    if return_type == 'numpy':
        # Accept a single key as a string; default to every available feature
        if nf_keys is None:
            nf_keys = NODE_FEATURES
        elif isinstance(nf_keys, str):
            nf_keys = [nf_keys]
        if ef_keys is None:
            ef_keys = EDGE_FEATURES
        elif isinstance(ef_keys, str):
            ef_keys = [ef_keys]

        adj, nf, ef = nx_to_numpy(data,
                                  auto_pad=auto_pad,
                                  self_loops=self_loops,
                                  nf_keys=nf_keys,
                                  ef_keys=ef_keys)
        return adj, nf, ef, labels

    # Should not get here: return_type was validated against RETURN_TYPES above
    raise RuntimeError('Unhandled return_type: {}'.format(return_type))
示例#2
0
def load_data(return_type='numpy',
              nf_keys=None,
              ef_keys=None,
              auto_pad=True,
              self_loops=False,
              amount=None):
    """
    Loads the QM9 molecules dataset (hydrogen atoms are kept as nodes).

    :param return_type: 'networkx', 'numpy', or 'sdf', data format to return;
    :param nf_keys: list or str, node features to return (see `qm9.NODE_FEATURES`
    for available features); if None, all of `NODE_FEATURES` are used;
    :param ef_keys: list or str, edge features to return (see `qm9.EDGE_FEATURES`
    for available features); if None, all of `EDGE_FEATURES` are used;
    :param auto_pad: if `return_type='numpy'`, zero pad graph matrices to have
    the same number of nodes;
    :param self_loops: if `return_type='numpy'`, add self loops to adjacency
    matrices;
    :param amount: the amount of molecules to return (in order); the labels
    are truncated to the same amount.
    :return: if `return_type='numpy'`, the adjacency matrix, node features,
    edge features, and a Pandas dataframe containing labels;
    if `return_type='networkx'`, a list of graphs in Networkx format,
    and a dataframe containing labels;
    if `return_type='sdf'`, a list of molecules in the internal SDF format and
    a dataframe containing labels.
    :raises ValueError: if `return_type` is not one of `RETURN_TYPES`.
    """
    if return_type not in RETURN_TYPES:
        raise ValueError('Possible return_type: {}'.format(RETURN_TYPES))

    if not os.path.exists(DATA_PATH):
        dataset_downloader()  # Try to download dataset

    print('Loading QM9 dataset.')
    sdf_file = os.path.join(DATA_PATH, 'qm9.sdf')
    data = load_sdf(sdf_file, amount=amount)  # Internal SDF format

    # Load labels, keeping them aligned with the (possibly truncated) molecules
    labels_file = os.path.join(DATA_PATH, 'qm9.sdf.csv')
    labels = load_csv(labels_file)
    if amount is not None:
        labels = labels[:amount]

    # Compare by value (==), never by identity (`is`): whether two equal
    # strings are the same object is an interpreter implementation detail.
    if return_type == 'sdf':
        return data, labels

    # Both remaining formats are built from Networkx graphs (with hydrogens)
    data = [sdf_to_nx(mol, keep_hydrogen=True) for mol in data]

    if return_type == 'networkx':
        return data, labels

    if return_type == 'numpy':
        # Accept a single key as a string; default to every available feature
        if nf_keys is None:
            nf_keys = NODE_FEATURES
        elif isinstance(nf_keys, str):
            nf_keys = [nf_keys]
        if ef_keys is None:
            ef_keys = EDGE_FEATURES
        elif isinstance(ef_keys, str):
            ef_keys = [ef_keys]

        adj, nf, ef = nx_to_numpy(data,
                                  auto_pad=auto_pad,
                                  self_loops=self_loops,
                                  nf_keys=nf_keys,
                                  ef_keys=ef_keys)
        return adj, nf, ef, labels

    # Should not get here: return_type was validated against RETURN_TYPES above
    raise RuntimeError('Unhandled return_type: {}'.format(return_type))