コード例 #1
    def fit_transform(self, X, y=None):
        """Return the symmetric local kernel of the stored training points.

        The kernel is built from ``self.train_points`` (not from ``X``); only
        the first matrix returned by ``get_local_symmetric_kernels`` is kept.
        When ``self.average`` is truthy, ``_compute_average`` is called with
        the kernel and with ``X`` for both remaining arguments. Its return
        value is discarded.

        NOTE(review): presumably ``_compute_average`` normalises the kernel
        in place -- confirm against its definition.

        ``y`` is accepted but never read.
        """
        # The symmetric helper is used so that only one triangle of the
        # matrix has to be evaluated (per the original author's note).
        first_kernel = get_local_symmetric_kernels(self.train_points)[0]

        if not self.average:
            return first_kernel

        # Averaging was requested.
        _compute_average(first_kernel, X, X)
        return first_kernel
コード例 #2
ファイル: test_fchl.py プロジェクト: syyunn/qml
def test_krr_fchl_atomic():
    """Check the FCHL atomic kernels against the local symmetric kernel.

    For the first ten molecules (sorted by file name) the test verifies that

    * summing the atomic kernel of a molecule pair reproduces the matching
      element of the local symmetric kernel,
    * the symmetric atomic kernel equals the general atomic kernel when a
      molecule is compared with itself,
    * none of the atomic kernels contain NaN.
    """
    test_dir = os.path.dirname(os.path.realpath(__file__))

    # Parse file containing PBE0/def2-TZVP heats of formation and xyz filenames
    data = get_energies(test_dir + "/data/hof_qm7.txt")

    # Generate a list of qml.Compound() objects
    mols = []

    for xyz_file in sorted(data)[:10]:

        # Initialize the qml.Compound() object
        mol = qml.Compound(xyz=test_dir + "/qm7/" + xyz_file)

        # Associate a property (heat of formation) with the object
        mol.properties = data[xyz_file]

        # Representation consumed by the FCHL kernel functions below
        mol.representation = generate_representation(
            mol.coordinates, mol.nuclear_charges, cut_distance=1e6)

        # Without this append the list stays empty and every check below
        # would run on zero molecules.
        mols.append(mol)

    X = np.array([mol.representation for mol in mols])

    # Set hyper-parameters
    sigma = 2.5

    K = get_local_symmetric_kernels(X, [sigma])[0]

    K_test = np.zeros((len(mols), len(mols)))

    for i, Xi in enumerate(X):
        for j, Xj in enumerate(X):

            # Only the first ``natoms`` rows are passed on.
            # NOTE(review): presumably the remaining rows are padding.
            K_atomic = get_atomic_kernels(Xi[:mols[i].natoms],
                                          Xj[:mols[j].natoms], [sigma])[0]

            # Fail on *any* NaN, as the message promises.
            assert not np.any(
                np.isnan(K_atomic)), "FCHL atomic kernel contains NaN"

            K_test[i, j] = np.sum(K_atomic)

            if i == j:
                K_atomic_symmetric = get_atomic_symmetric_kernels(
                    Xi[:mols[i].natoms], [sigma])[0]
                assert np.allclose(K_atomic, K_atomic_symmetric
                                   ), "Error in FCHL symmetric atomic kernels"
                assert not np.any(np.isnan(K_atomic_symmetric)
                                  ), "FCHL atomic symmetric kernel contains NaN"

    assert np.allclose(K, K_test), "Error in FCHL atomic kernels"
コード例 #3
ファイル: test_fchl.py プロジェクト: syyunn/qml
def test_krr_fchl_local():
    """Kernel-ridge regression smoke test for the FCHL local kernels.

    Verifies that the symmetric and the general local kernel agree when
    every kernel keyword argument is supplied, that neither contains NaN,
    and that a model trained on two thirds of the molecules predicts the
    remaining third with a mean absolute error within 1.0 of 2.
    """
    # Test that all kernel arguments work
    kernel_args = {
        "cut_distance": 1e6,
        "cut_start": 0.5,
        "two_body_width": 0.1,
        "two_body_scaling": 2.0,
        "two_body_power": 6.0,
        "three_body_width": 3.0,
        "three_body_scaling": 2.0,
        "three_body_power": 3.0,
        "alchemy": "periodic-table",
        "alchemy_period_width": 1.0,
        "alchemy_group_width": 1.0,
        "fourier_order": 2,
    }

    test_dir = os.path.dirname(os.path.realpath(__file__))

    # Parse file containing PBE0/def2-TZVP heats of formation and xyz filenames
    data = get_energies(test_dir + "/data/hof_qm7.txt")

    # Generate a list of qml.Compound() objects
    mols = []

    for xyz_file in sorted(data)[:100]:

        # Initialize the qml.Compound() object
        mol = qml.Compound(xyz=test_dir + "/qm7/" + xyz_file)

        # Associate a property (heat of formation) with the object
        mol.properties = data[xyz_file]

        # Representation consumed by the FCHL kernel functions below.
        # NOTE(review): this call mirrors the one in test_krr_fchl_atomic;
        # the listing had lost it here -- confirm the arguments.
        mol.representation = generate_representation(
            mol.coordinates, mol.nuclear_charges, cut_distance=1e6)

        mols.append(mol)

    # Shuffle molecules with a private, seeded generator so the split is
    # reproducible and the global NumPy RNG state is left alone.
    # NOTE(review): the seed is an assumption; the MAE window asserted at
    # the end depends on the resulting split.
    np.random.RandomState(666).shuffle(mols)

    # Make training and test sets
    n_test = len(mols) // 3
    n_train = len(mols) - n_test

    training = mols[:n_train]
    test = mols[-n_test:]

    X = np.array([mol.representation for mol in training])
    Xs = np.array([mol.representation for mol in test])

    # List of properties
    Y = np.array([mol.properties for mol in training])
    Ys = np.array([mol.properties for mol in test])

    # Set hyper-parameters
    sigma = 2.5
    llambda = 1e-8

    K_symmetric = get_local_symmetric_kernels(X, [sigma], **kernel_args)[0]
    K = get_local_kernels(X, X, [sigma], **kernel_args)[0]

    assert np.allclose(K, K_symmetric), "Error in FCHL symmetric local kernels"
    # Fail on *any* NaN, as the messages promise.
    assert not np.any(
        np.isnan(K_symmetric)), "FCHL local symmetric kernel contains NaN"
    assert not np.any(np.isnan(K)), "FCHL local kernel contains NaN"

    # Solve alpha; llambda is added to the diagonal of K first
    K[np.diag_indices_from(K)] += llambda
    alpha = cho_solve(K, Y)

    # Calculate prediction kernel
    Ks = get_local_kernels(Xs, X, [sigma], **kernel_args)[0]
    assert not np.any(
        np.isnan(Ks)), "FCHL local testkernel contains NaN"

    Yss = np.dot(Ks, alpha)

    mae = np.mean(np.abs(Ys - Yss))
    assert abs(2 - mae) < 1.0, "Error in FCHL local kernel-ridge regression"
コード例 #4
ファイル: test_fchl.py プロジェクト: syyunn/qml
def test_fchl_local_periodic():
    """Compare the periodic FCHL local kernel with stored reference values.

    Five periodic structures are described by their nuclear charges, unit
    cells and fractional coordinates. Their representations are built with
    a cut-off of 7.0 and the resulting symmetric local kernel must match
    ``K_ref``.
    """
    # One charge array per structure; the lengths (24, 20, 9, 36, 12) match
    # the number of coordinate rows of the corresponding structure below.
    nuclear_charges = [
        np.array([
            13, 13, 58, 58, 58, 58, 58, 58, 16, 16, 16, 16, 16, 16, 16, 16, 16,
            16, 16, 16, 16, 16, 23, 23
        ]),
        np.array([
            34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 73, 73, 73, 73, 81,
            81, 81, 81
        ]),
        np.array([48, 8, 8, 8, 8, 8, 8, 51, 51]),
        np.array([
            16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
            16, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
            30, 30
        ]),
        np.array([58, 58, 8, 8, 8, 8, 8, 8, 8, 8, 23, 23])
    ]

    # One 3x3 cell matrix per structure
    cells = np.array([[[1.01113290e+01, -1.85000000e-04, 0.00000000e+00],
                       [-5.05582400e+00, 8.75745400e+00, 0.00000000e+00],
                       [0.00000000e+00, 0.00000000e+00, 6.15100100e+00]],
                      [[9.672168, 0., 0.], [0., 3.643786, 0.],
                       [0., 0., 14.961818]],
                      [[5.28208000e+00, -1.20000000e-05, 1.50000000e-05],
                       [-2.64105000e+00, 4.57443400e+00, -3.00000000e-05],
                       [1.40000000e-05, -2.40000000e-05, 4.77522000e+00]],
                      [[-1.917912, 3.321921, 0.], [3.835824, 0., 0.],
                       [1.917912, -1.107307, -56.423542]],
                      [[3.699168, 3.699168, -3.255938],
                       [3.699168, -3.699168, 3.255938],
                       [-3.699168, -3.699168, -3.255938]]])

    fractional_coordinates = [
        [[0.6666706, 0.33333356, 0.15253127],
         [0.33332896, 0.66666655, 0.65253119],
         [0.14802736, 0.375795, 0.23063888],
         [0.62422269, 0.77225019, 0.23063888],
         [0.22775607, 0.85196133, 0.23063888],
         [0.77224448, 0.14803879, 0.7306388],
         [0.37577687, 0.22774993, 0.7306388],
         [0.8519722, 0.62420512, 0.7306388],
         [0.57731818, 0.47954083, 0.0102715],
         [0.52043884, 0.09777799, 0.01028288],
         [0.90211803, 0.4226259, 0.01032677],
         [0.33335216, 0.6666734, 0.01482035],
         [0.90959766, 0.77585201, 0.30637615],
         [0.86626106, 0.09040873, 0.3063794],
         [0.22415808, 0.13374794, 0.30638265],
         [0.42268138, 0.52045928, 0.51027142],
         [0.47956114, 0.90222098, 0.5102828],
         [0.09788153, 0.57737421, 0.51032669],
         [0.6666474, 0.33332671, 0.51482027],
         [0.09040133, 0.22414696, 0.80637769],
         [0.13373793, 0.90959025, 0.80637932],
         [0.77584247, 0.86625217, 0.80638257],
         [0., 0., 0.05471142],
         [0., 0., 0.55471134]],
        [[0.81615001, 0.75000014, 0.00116296],
         [0.52728096, 0.25000096, 0.0993275],
         [0.24582596, 0.75000014, 0.2198563],
         [0.74582658, 0.75000014, 0.28014376],
         [0.02728137, 0.25000096, 0.40067257],
         [0.31615042, 0.75000014, 0.49883711],
         [0.68384978, 0.25000096, 0.50116236],
         [0.97271884, 0.75000014, 0.59932757],
         [0.25417362, 0.25000096, 0.71985637],
         [0.75417321, 0.25000096, 0.78014383],
         [0.47271925, 0.75000014, 0.90067263],
         [0.1838502, 0.25000096, 0.99883717],
         [0.33804831, 0.75000014, 0.07120258],
         [0.83804789, 0.75000014, 0.42879749],
         [0.16195232, 0.25000096, 0.57120198],
         [0.6619519, 0.25000096, 0.92879756],
         [0.98245812, 0.25000096, 0.17113829],
         [0.48245853, 0.25000096, 0.32886177],
         [0.51754167, 0.75000014, 0.67113836],
         [0.01754209, 0.75000014, 0.82886184]],
        [[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
         [3.66334233e-01, 1.96300000e-07, 2.70922493e-01],
         [6.33665197e-01, 6.33666177e-01, 2.70923540e-01],
         [3.62000000e-08, 3.66333081e-01, 2.70923851e-01],
         [6.70000000e-09, 6.33664733e-01, 7.29076149e-01],
         [6.33664135e-01, 9.99998055e-01, 7.29076460e-01],
         [3.66336157e-01, 3.66334260e-01, 7.29077507e-01],
         [3.33333635e-01, 6.66667395e-01, 4.99998953e-01],
         [6.66667720e-01, 3.33333042e-01, 5.00000000e-01]],
        [[0.3379644, 0.66203644, 0.01389048],
         [0.02316309, 0.97683587, 0.06948926],
         [0.70833843, 0.29165976, 0.12501892],
         [0.39352259, 0.60647824, 0.18056506],
         [0.74538243, 0.25461577, 0.2361509],
         [0.09722803, 0.90277092, 0.2916841],
         [0.44907919, 0.55092165, 0.34723485],
         [0.8009281, 0.1990701, 0.4027879],
         [0.15278103, 0.84721793, 0.45834308],
         [0.83797345, 0.16202475, 0.51392396],
         [0.52315813, 0.4768427, 0.56947169],
         [0.20833916, 0.7916598, 0.62501748],
         [0.89352691, 0.10647128, 0.68058436],
         [0.57870427, 0.42129656, 0.73611012],
         [0.93056329, 0.06943491, 0.79169347],
         [0.28241704, 0.71758191, 0.84725114],
         [0.63426956, 0.36573128, 0.90280596],
         [0.98611817, 0.01388002, 0.95835813],
         [0., 0., 0.],
         [0.35185434, 0.64814649, 0.05556032],
         [0.03704151, 0.96295744, 0.11112454],
         [0.72221887, 0.27777932, 0.16666022],
         [0.40741437, 0.59258647, 0.22224039],
         [0.75926009, 0.24073811, 0.27778387],
         [0.11111195, 0.888887, 0.33333586],
         [0.46296234, 0.53703849, 0.38888431],
         [0.81480954, 0.18518866, 0.44443222],
         [0.16667233, 0.83332662, 0.500017],
         [0.85185117, 0.14814703, 0.55555711],
         [0.53704217, 0.46295866, 0.61112381],
         [0.22222196, 0.777777, 0.66666587],
         [0.90740847, 0.09258972, 0.72222903],
         [0.59259328, 0.40740756, 0.77777712],
         [0.94444213, 0.05555607, 0.83333],
         [0.29630132, 0.70369764, 0.88890396],
         [0.64815247, 0.35184836, 0.94445471]],
        [[0., 0., 0.],
         [0.75000042, 0.50000027, 0.25000015],
         [0.15115386, 0.81961403, 0.33154037],
         [0.51192691, 0.18038651, 0.3315404],
         [0.08154025, 0.31961376, 0.40115401],
         [0.66846017, 0.81961403, 0.48807366],
         [0.08154025, 0.68038678, 0.76192703],
         [0.66846021, 0.18038651, 0.84884672],
         [0.23807355, 0.31961376, 0.91846033],
         [0.59884657, 0.68038678, 0.91846033],
         [0.50000031, 0., 0.50000031],
         [0.25000015, 0.50000027, 0.75000042]]
    ]

    # Number of structures defined above
    n = 5

    # NOTE(review): only ``cut_distance=7.0`` and the ``for i in range(5)``
    # tail of this expression survived in the listing. ``max_size=36`` is
    # the atom count of the largest structure above; the ``cell`` keyword
    # and ``neighbors=200`` are assumptions -- confirm against the
    # generate_representation signature and the data behind K_ref.
    X = np.array([
        generate_representation(fractional_coordinates[i], nuclear_charges[i],
                                cell=cells[i], max_size=36, neighbors=200,
                                cut_distance=7.0) for i in range(n)
    ])

    sigmas = [2.5]
    K = get_local_symmetric_kernels(X, sigmas, cut_distance=7.0, cut_start=0.7)

    # Reference kernel. The matrix is symmetric, so the last entry of the
    # fourth row is the same value as the fourth entry of the fifth row.
    K_ref = np.array([
        [530.03184304, 435.65196293, 198.61245535, 782.49428327, 263.53562172],
        [435.65196293, 371.35281119, 163.83766549, 643.99777576, 215.04338938],
        [198.61245535, 163.83766549, 76.12134823, 295.02739281, 99.89595704],
        [
            782.49428327, 643.99777576, 295.02739281, 1199.61736141,
            389.31169487
        ],
        [263.53562172, 215.04338938, 99.89595704, 389.31169487, 133.36920188]
    ])

    assert np.allclose(K, K_ref), "Error in periodic FCHL"
コード例 #5
def test_krr_fchl_alchemy():
    """Check that custom alchemy matrices reproduce the named alchemy modes.

    Two comparisons are made on the same set of 20 molecules:

    * the kernel built with the explicit 20x20 ``overlap`` matrix must equal
      the kernel built with a named alchemy mode,
    * the kernel built with a 100x100 identity matrix must equal the kernel
      built with ``alchemy="off"``.
    """
    test_dir = os.path.dirname(os.path.realpath(__file__))

    # Parse file containing PBE0/def2-TZVP heats of formation and xyz filenames
    data = get_energies(test_dir + "/data/hof_qm7.txt")

    # Generate a list of qml.Compound() objects
    mols = []

    for xyz_file in sorted(data)[:20]:

        # Initialize the qml.Compound() object
        mol = qml.Compound(xyz=test_dir + "/qm7/" + xyz_file)

        # Associate a property (heat of formation) with the object
        mol.properties = data[xyz_file]

        # Representation consumed by the FCHL kernel functions below.
        # NOTE(review): the listing had lost this call; the arguments are
        # an assumption -- confirm against the project's test-suite.
        mol.representation = generate_representation(
            mol.coordinates, mol.nuclear_charges, cut_distance=1e6)

        mols.append(mol)

    # Molecule order does not matter here: every kernel below is built from
    # the same X, so no shuffling is performed.
    X = np.array([mol.representation for mol in mols])

    sigma = 2.5

    # Changes NumPy's global print options (kept from the original).
    np.set_printoptions(edgeitems=16, linewidth=6666)

    # Explicit 20x20 alchemy matrix (symmetric, unit diagonal)
    overlap = np.array([
        [1., 0.00835282, 0.90696062, 0.82257756, 0.61368025, 0.37660345,
         0.19010927, 0.07894037, 0.02696323, 0.00757568, 0.67663385,
         0.61368025, 0.45783336, 0.28096329, 0.14183016, 0.05889311,
         0.02011579, 0.0056518, 0.41523683, 0.37660345],
        [0.00835282, 1., 0.00757568, 0.02696323, 0.07894037, 0.19010927,
         0.37660345, 0.61368025, 0.82257756, 0.90696062, 0.0056518,
         0.02011579, 0.05889311, 0.14183016, 0.28096329, 0.45783336,
         0.61368025, 0.67663385, 0.0034684, 0.01234467],
        [0.90696062, 0.00757568, 1., 0.90696062, 0.67663385, 0.41523683,
         0.20961139, 0.08703837, 0.02972922, 0.00835282, 0.90696062,
         0.82257756, 0.61368025, 0.37660345, 0.19010927, 0.07894037,
         0.02696323, 0.00757568, 0.67663385, 0.61368025],
        [0.82257756, 0.02696323, 0.90696062, 1., 0.90696062, 0.67663385,
         0.41523683, 0.20961139, 0.08703837, 0.02972922, 0.82257756,
         0.90696062, 0.82257756, 0.61368025, 0.37660345, 0.19010927,
         0.07894037, 0.02696323, 0.61368025, 0.67663385],
        [0.61368025, 0.07894037, 0.67663385, 0.90696062, 1., 0.90696062,
         0.67663385, 0.41523683, 0.20961139, 0.08703837, 0.61368025,
         0.82257756, 0.90696062, 0.82257756, 0.61368025, 0.37660345,
         0.19010927, 0.07894037, 0.45783336, 0.61368025],
        [0.37660345, 0.19010927, 0.41523683, 0.67663385, 0.90696062, 1.,
         0.90696062, 0.67663385, 0.41523683, 0.20961139, 0.37660345,
         0.61368025, 0.82257756, 0.90696062, 0.82257756, 0.61368025,
         0.37660345, 0.19010927, 0.28096329, 0.45783336],
        [0.19010927, 0.37660345, 0.20961139, 0.41523683, 0.67663385,
         0.90696062, 1., 0.90696062, 0.67663385, 0.41523683, 0.19010927,
         0.37660345, 0.61368025, 0.82257756, 0.90696062, 0.82257756,
         0.61368025, 0.37660345, 0.14183016, 0.28096329],
        [0.07894037, 0.61368025, 0.08703837, 0.20961139, 0.41523683,
         0.67663385, 0.90696062, 1., 0.90696062, 0.67663385, 0.07894037,
         0.19010927, 0.37660345, 0.61368025, 0.82257756, 0.90696062,
         0.82257756, 0.61368025, 0.05889311, 0.14183016],
        [0.02696323, 0.82257756, 0.02972922, 0.08703837, 0.20961139,
         0.41523683, 0.67663385, 0.90696062, 1., 0.90696062, 0.02696323,
         0.07894037, 0.19010927, 0.37660345, 0.61368025, 0.82257756,
         0.90696062, 0.82257756, 0.02011579, 0.05889311],
        [0.00757568, 0.90696062, 0.00835282, 0.02972922, 0.08703837,
         0.20961139, 0.41523683, 0.67663385, 0.90696062, 1., 0.00757568,
         0.02696323, 0.07894037, 0.19010927, 0.37660345, 0.61368025,
         0.82257756, 0.90696062, 0.0056518, 0.02011579],
        [0.67663385, 0.0056518, 0.90696062, 0.82257756, 0.61368025,
         0.37660345, 0.19010927, 0.07894037, 0.02696323, 0.00757568, 1.,
         0.90696062, 0.67663385, 0.41523683, 0.20961139, 0.08703837,
         0.02972922, 0.00835282, 0.90696062, 0.82257756],
        [0.61368025, 0.02011579, 0.82257756, 0.90696062, 0.82257756,
         0.61368025, 0.37660345, 0.19010927, 0.07894037, 0.02696323,
         0.90696062, 1., 0.90696062, 0.67663385, 0.41523683, 0.20961139,
         0.08703837, 0.02972922, 0.82257756, 0.90696062],
        [0.45783336, 0.05889311, 0.61368025, 0.82257756, 0.90696062,
         0.82257756, 0.61368025, 0.37660345, 0.19010927, 0.07894037,
         0.67663385, 0.90696062, 1., 0.90696062, 0.67663385, 0.41523683,
         0.20961139, 0.08703837, 0.61368025, 0.82257756],
        [0.28096329, 0.14183016, 0.37660345, 0.61368025, 0.82257756,
         0.90696062, 0.82257756, 0.61368025, 0.37660345, 0.19010927,
         0.41523683, 0.67663385, 0.90696062, 1., 0.90696062, 0.67663385,
         0.41523683, 0.20961139, 0.37660345, 0.61368025],
        [0.14183016, 0.28096329, 0.19010927, 0.37660345, 0.61368025,
         0.82257756, 0.90696062, 0.82257756, 0.61368025, 0.37660345,
         0.20961139, 0.41523683, 0.67663385, 0.90696062, 1., 0.90696062,
         0.67663385, 0.41523683, 0.19010927, 0.37660345],
        [0.05889311, 0.45783336, 0.07894037, 0.19010927, 0.37660345,
         0.61368025, 0.82257756, 0.90696062, 0.82257756, 0.61368025,
         0.08703837, 0.20961139, 0.41523683, 0.67663385, 0.90696062, 1.,
         0.90696062, 0.67663385, 0.07894037, 0.19010927],
        [0.02011579, 0.61368025, 0.02696323, 0.07894037, 0.19010927,
         0.37660345, 0.61368025, 0.82257756, 0.90696062, 0.82257756,
         0.02972922, 0.08703837, 0.20961139, 0.41523683, 0.67663385,
         0.90696062, 1., 0.90696062, 0.02696323, 0.07894037],
        [0.0056518, 0.67663385, 0.00757568, 0.02696323, 0.07894037,
         0.19010927, 0.37660345, 0.61368025, 0.82257756, 0.90696062,
         0.00835282, 0.02972922, 0.08703837, 0.20961139, 0.41523683,
         0.67663385, 0.90696062, 1., 0.00757568, 0.02696323],
        [0.41523683, 0.0034684, 0.67663385, 0.61368025, 0.45783336,
         0.28096329, 0.14183016, 0.05889311, 0.02011579, 0.0056518,
         0.90696062, 0.82257756, 0.61368025, 0.37660345, 0.19010927,
         0.07894037, 0.02696323, 0.00757568, 1., 0.90696062],
        [0.37660345, 0.01234467, 0.61368025, 0.67663385, 0.61368025,
         0.45783336, 0.28096329, 0.14183016, 0.05889311, 0.02011579,
         0.82257756, 0.90696062, 0.82257756, 0.61368025, 0.37660345,
         0.19010927, 0.07894037, 0.02696323, 0.90696062, 1.],
    ])

    # NOTE(review): the tail of this call was missing from the listing.
    # "periodic-table" is the only other named alchemy mode visible in the
    # file -- confirm it is the mode ``overlap`` was generated from.
    K_alchemy = get_local_symmetric_kernels(X, [sigma],
                                            alchemy="periodic-table")[0]
    K_custom = get_local_symmetric_kernels(X, [sigma], alchemy=overlap)[0]

    assert np.allclose(K_alchemy, K_custom), "Error in alchemy"

    # An identity alchemy matrix must behave like alchemy switched off
    nooverlap = np.eye(100)
    K_noalchemy = get_local_symmetric_kernels(X, [sigma], alchemy="off")[0]
    K_custom = get_local_symmetric_kernels(X, [sigma], alchemy=nooverlap)[0]

    assert np.allclose(K_noalchemy, K_custom), "Error in no-alchemy"
コード例 #6
 def fit_transform(self, X, y=None):
     """Return the symmetric local kernel of ``self.train_points``.

     The result of ``get_local_symmetric_kernels`` is passed through
     ``np.squeeze``, so axes of length one are removed. ``X`` and ``y``
     are accepted but never read.
     """
     kernels = get_local_symmetric_kernels(self.train_points)
     return np.squeeze(kernels)