コード例 #1
0
    def fit_transform(self, X, y=None):
        """Fit on ``X`` and return the symmetric local kernel of the fitted points.

        ``y`` is accepted but never read by this method.  The kernel is built
        from ``self.train_points`` (not from ``X`` directly), and only the
        first element of what ``get_local_symmetric_kernels`` returns is kept.
        """
        self.fit(X)

        # Symmetric routine: a single set of points is supplied for both sides.
        gram = get_local_symmetric_kernels(self.train_points)[0]

        if not self.average:
            return gram

        # NOTE(review): the return value of _compute_average is discarded, so
        # this presumably modifies ``gram`` in place -- confirm against the
        # helper's definition.
        _compute_average(gram, X, X)
        return gram
コード例 #2
0
ファイル: test_fchl.py プロジェクト: syyunn/qml
def test_krr_fchl_atomic():
    """Check the FCHL atomic kernels against the symmetric local kernel.

    For the first ten entries (sorted by file name) of the heats-of-formation
    file, every pairwise atomic kernel is summed and compared with the
    corresponding element of the local symmetric kernel.  On the diagonal the
    symmetric atomic kernel is additionally compared with the general one.
    All kernels are required to be completely free of NaN.
    """

    test_dir = os.path.dirname(os.path.realpath(__file__))

    # Parse file containing PBE0/def2-TZVP heats of formation and xyz filenames
    data = get_energies(test_dir + "/data/hof_qm7.txt")

    # Generate a list of qml.Compound() objects
    mols = []

    for xyz_file in sorted(data.keys())[:10]:

        # Initialize the qml.Compound() objects
        mol = qml.Compound(xyz=test_dir + "/qm7/" + xyz_file)

        # Associate a property (heat of formation) with the object
        mol.properties = data[xyz_file]

        # Representation built from the coordinates and nuclear charges
        mol.representation = generate_representation(
            mol.coordinates, mol.nuclear_charges, cut_distance=1e6)
        mols.append(mol)

    X = np.array([mol.representation for mol in mols])

    # Set hyper-parameters
    sigma = 2.5

    K = get_local_symmetric_kernels(X, [sigma])[0]

    K_test = np.zeros((len(mols), len(mols)))

    for i, Xi in enumerate(X):
        for j, Xj in enumerate(X):

            # Only the first natoms rows of each representation are passed on.
            K_atomic = get_atomic_kernels(Xi[:mols[i].natoms],
                                          Xj[:mols[j].natoms], [sigma])[0]
            K_test[i, j] = np.sum(K_atomic)

            # Fail on ANY NaN entry (the previous all()-based check only
            # tripped when every single entry was NaN).
            assert not np.any(
                np.isnan(K_atomic)), "FCHL atomic kernel contains NaN"

            if i == j:
                K_atomic_symmetric = get_atomic_symmetric_kernels(
                    Xi[:mols[i].natoms], [sigma])[0]
                assert np.allclose(K_atomic, K_atomic_symmetric
                                   ), "Error in FCHL symmetric atomic kernels"
                assert not np.any(np.isnan(K_atomic_symmetric)
                                  ), "FCHL atomic symmetric kernel contains NaN"

    assert np.allclose(K, K_test), "Error in FCHL atomic kernels"
コード例 #3
0
ファイル: test_fchl.py プロジェクト: syyunn/qml
def test_krr_fchl_local():
    """Kernel-ridge regression smoke test for the FCHL local kernels.

    Builds FCHL representations for the first 100 entries (sorted by file
    name) of the heats-of-formation file, checks that the symmetric and the
    general local kernel agree on the training set, solves for the regression
    coefficients and asserts that the mean absolute error on the held-out
    third is within 1.0 of 2.  All kernels must be completely free of NaN.
    """

    # Test that all kernel arguments work
    kernel_args = {
        "cut_distance": 1e6,
        "cut_start": 0.5,
        "two_body_width": 0.1,
        "two_body_scaling": 2.0,
        "two_body_power": 6.0,
        "three_body_width": 3.0,
        "three_body_scaling": 2.0,
        "three_body_power": 3.0,
        "alchemy": "periodic-table",
        "alchemy_period_width": 1.0,
        "alchemy_group_width": 1.0,
        "fourier_order": 2,
    }

    test_dir = os.path.dirname(os.path.realpath(__file__))

    # Parse file containing PBE0/def2-TZVP heats of formation and xyz filenames
    data = get_energies(test_dir + "/data/hof_qm7.txt")

    # Generate a list of qml.Compound() objects
    mols = []

    for xyz_file in sorted(data.keys())[:100]:

        # Initialize the qml.Compound() objects
        mol = qml.Compound(xyz=test_dir + "/qm7/" + xyz_file)

        # Associate a property (heat of formation) with the object
        mol.properties = data[xyz_file]

        # Attach the FCHL representation to the compound
        mol.generate_fchl_representation(cut_distance=1e6)
        mols.append(mol)

    # Shuffle molecules (fixed seed keeps the train/test split reproducible)
    np.random.seed(666)
    np.random.shuffle(mols)

    # Make training and test sets
    n_test = len(mols) // 3
    n_train = len(mols) - n_test

    training = mols[:n_train]
    test = mols[-n_test:]

    X = np.array([mol.representation for mol in training])
    Xs = np.array([mol.representation for mol in test])

    # List of properties
    Y = np.array([mol.properties for mol in training])
    Ys = np.array([mol.properties for mol in test])

    # Set hyper-parameters
    sigma = 2.5
    llambda = 1e-8

    K_symmetric = get_local_symmetric_kernels(X, [sigma], **kernel_args)[0]
    K = get_local_kernels(X, X, [sigma], **kernel_args)[0]

    assert np.allclose(K, K_symmetric), "Error in FCHL symmetric local kernels"

    # Fail on ANY NaN entry (the previous all()-based checks only tripped
    # when every single entry was NaN).
    assert not np.any(
        np.isnan(K_symmetric)), "FCHL local symmetric kernel contains NaN"
    assert not np.any(np.isnan(K)), "FCHL local kernel contains NaN"

    # Solve alpha (llambda is added to the diagonal of K first)
    K[np.diag_indices_from(K)] += llambda
    alpha = cho_solve(K, Y)

    # Calculate prediction kernel
    Ks = get_local_kernels(Xs, X, [sigma], **kernel_args)[0]
    assert not np.any(np.isnan(Ks)), "FCHL local testkernel contains NaN"

    Yss = np.dot(Ks, alpha)

    mae = np.mean(np.abs(Ys - Yss))
    assert abs(2 - mae) < 1.0, "Error in FCHL local kernel-ridge regression"
コード例 #4
0
ファイル: test_fchl.py プロジェクト: syyunn/qml
def test_fchl_local_periodic():
    """Compare the periodic FCHL local kernel with stored reference values.

    Five periodic structures are described by their nuclear charges, unit
    cells and fractional coordinates.  A representation is generated for each
    of them and the resulting symmetric local kernel is compared with the
    hard-coded 5x5 reference matrix ``K_ref``.
    """

    nuclear_charges = [
        np.array([
            13, 13, 58, 58, 58, 58, 58, 58, 16, 16, 16, 16, 16, 16, 16, 16, 16,
            16, 16, 16, 16, 16, 23, 23
        ]),
        np.array([
            34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 73, 73, 73, 73, 81,
            81, 81, 81
        ]),
        np.array([48, 8, 8, 8, 8, 8, 8, 51, 51]),
        np.array([
            16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16,
            16, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30, 30,
            30, 30
        ]),
        np.array([58, 58, 8, 8, 8, 8, 8, 8, 8, 8, 23, 23])
    ]

    cells = np.array([[[1.01113290e+01, -1.85000000e-04, 0.00000000e+00],
                       [-5.05582400e+00, 8.75745400e+00, 0.00000000e+00],
                       [0.00000000e+00, 0.00000000e+00, 6.15100100e+00]],
                      [[9.672168, 0., 0.], [0., 3.643786, 0.],
                       [0., 0., 14.961818]],
                      [[5.28208000e+00, -1.20000000e-05, 1.50000000e-05],
                       [-2.64105000e+00, 4.57443400e+00, -3.00000000e-05],
                       [1.40000000e-05, -2.40000000e-05, 4.77522000e+00]],
                      [[-1.917912, 3.321921, 0.], [3.835824, 0., 0.],
                       [1.917912, -1.107307, -56.423542]],
                      [[3.699168, 3.699168, -3.255938],
                       [3.699168, -3.699168, 3.255938],
                       [-3.699168, -3.699168, -3.255938]]])

    fractional_coordinates = [
        [[0.6666706, 0.33333356, 0.15253127],
         [0.33332896, 0.66666655, 0.65253119],
         [0.14802736, 0.375795, 0.23063888],
         [0.62422269, 0.77225019, 0.23063888],
         [0.22775607, 0.85196133, 0.23063888],
         [0.77224448, 0.14803879, 0.7306388],
         [0.37577687, 0.22774993,
          0.7306388], [0.8519722, 0.62420512, 0.7306388],
         [0.57731818, 0.47954083, 0.0102715],
         [0.52043884, 0.09777799, 0.01028288],
         [0.90211803, 0.4226259, 0.01032677],
         [0.33335216, 0.6666734, 0.01482035],
         [0.90959766, 0.77585201, 0.30637615],
         [0.86626106, 0.09040873, 0.3063794],
         [0.22415808, 0.13374794, 0.30638265],
         [0.42268138, 0.52045928, 0.51027142],
         [0.47956114, 0.90222098, 0.5102828],
         [0.09788153, 0.57737421, 0.51032669],
         [0.6666474, 0.33332671, 0.51482027],
         [0.09040133, 0.22414696, 0.80637769],
         [0.13373793, 0.90959025, 0.80637932],
         [0.77584247, 0.86625217, 0.80638257], [0., 0., 0.05471142],
         [0., 0., 0.55471134]],
        [[0.81615001, 0.75000014, 0.00116296],
         [0.52728096, 0.25000096, 0.0993275],
         [0.24582596, 0.75000014, 0.2198563],
         [0.74582658, 0.75000014, 0.28014376],
         [0.02728137, 0.25000096, 0.40067257],
         [0.31615042, 0.75000014, 0.49883711],
         [0.68384978, 0.25000096, 0.50116236],
         [0.97271884, 0.75000014, 0.59932757],
         [0.25417362, 0.25000096, 0.71985637],
         [0.75417321, 0.25000096, 0.78014383],
         [0.47271925, 0.75000014, 0.90067263],
         [0.1838502, 0.25000096, 0.99883717],
         [0.33804831, 0.75000014, 0.07120258],
         [0.83804789, 0.75000014, 0.42879749],
         [0.16195232, 0.25000096, 0.57120198],
         [0.6619519, 0.25000096, 0.92879756],
         [0.98245812, 0.25000096, 0.17113829],
         [0.48245853, 0.25000096, 0.32886177],
         [0.51754167, 0.75000014, 0.67113836],
         [0.01754209, 0.75000014, 0.82886184]],
        [[0.00000000e+00, 0.00000000e+00, 0.00000000e+00],
         [3.66334233e-01, 1.96300000e-07, 2.70922493e-01],
         [6.33665197e-01, 6.33666177e-01, 2.70923540e-01],
         [3.62000000e-08, 3.66333081e-01, 2.70923851e-01],
         [6.70000000e-09, 6.33664733e-01, 7.29076149e-01],
         [6.33664135e-01, 9.99998055e-01, 7.29076460e-01],
         [3.66336157e-01, 3.66334260e-01, 7.29077507e-01],
         [3.33333635e-01, 6.66667395e-01, 4.99998953e-01],
         [6.66667720e-01, 3.33333042e-01, 5.00000000e-01]],
        [[0.3379644, 0.66203644, 0.01389048],
         [0.02316309, 0.97683587, 0.06948926],
         [0.70833843, 0.29165976, 0.12501892],
         [0.39352259, 0.60647824, 0.18056506],
         [0.74538243, 0.25461577, 0.2361509],
         [0.09722803, 0.90277092, 0.2916841],
         [0.44907919, 0.55092165,
          0.34723485], [0.8009281, 0.1990701, 0.4027879],
         [0.15278103, 0.84721793, 0.45834308],
         [0.83797345, 0.16202475, 0.51392396],
         [0.52315813, 0.4768427, 0.56947169],
         [0.20833916, 0.7916598, 0.62501748],
         [0.89352691, 0.10647128, 0.68058436],
         [0.57870427, 0.42129656, 0.73611012],
         [0.93056329, 0.06943491, 0.79169347],
         [0.28241704, 0.71758191, 0.84725114],
         [0.63426956, 0.36573128, 0.90280596],
         [0.98611817, 0.01388002, 0.95835813], [0., 0., 0.],
         [0.35185434, 0.64814649, 0.05556032],
         [0.03704151, 0.96295744, 0.11112454],
         [0.72221887, 0.27777932, 0.16666022],
         [0.40741437, 0.59258647, 0.22224039],
         [0.75926009, 0.24073811, 0.27778387],
         [0.11111195, 0.888887, 0.33333586],
         [0.46296234, 0.53703849, 0.38888431],
         [0.81480954, 0.18518866, 0.44443222],
         [0.16667233, 0.83332662, 0.500017],
         [0.85185117, 0.14814703, 0.55555711],
         [0.53704217, 0.46295866, 0.61112381],
         [0.22222196, 0.777777, 0.66666587],
         [0.90740847, 0.09258972, 0.72222903],
         [0.59259328, 0.40740756,
          0.77777712], [0.94444213, 0.05555607, 0.83333],
         [0.29630132, 0.70369764, 0.88890396],
         [0.64815247, 0.35184836, 0.94445471]],
        [[0., 0., 0.], [0.75000042, 0.50000027, 0.25000015],
         [0.15115386, 0.81961403, 0.33154037],
         [0.51192691, 0.18038651, 0.3315404],
         [0.08154025, 0.31961376, 0.40115401],
         [0.66846017, 0.81961403, 0.48807366],
         [0.08154025, 0.68038678, 0.76192703],
         [0.66846021, 0.18038651, 0.84884672],
         [0.23807355, 0.31961376, 0.91846033],
         [0.59884657, 0.68038678, 0.91846033], [0.50000031, 0., 0.50000031],
         [0.25000015, 0.50000027, 0.75000042]]
    ]

    # One representation per structure.  The three sequences above are
    # parallel (five entries each), so they are walked in lockstep instead of
    # indexing with a hard-coded count.
    X = np.array([
        generate_representation(coordinates,
                                charges,
                                cell=cell,
                                max_size=36,
                                neighbors=200,
                                cut_distance=7.0)
        for coordinates, charges, cell in zip(fractional_coordinates,
                                              nuclear_charges, cells)
    ])

    sigmas = [2.5]

    # NOTE(review): unlike the other tests, the result is not indexed with
    # [0] here; the comparison below presumably relies on numpy broadcasting
    # against the 5x5 reference -- confirm.
    K = get_local_symmetric_kernels(X, sigmas, cut_distance=7.0, cut_start=0.7)

    K_ref = np.array([
        [530.03184304, 435.65196293, 198.61245535, 782.49428327, 263.53562172],
        [435.65196293, 371.35281119, 163.83766549, 643.99777576, 215.04338938],
        [198.61245535, 163.83766549, 76.12134823, 295.02739281, 99.89595704],
        [
            782.49428327, 643.99777576, 295.02739281, 1199.61736141,
            389.31169487
        ],
        [263.53562172, 215.04338938, 99.89595704, 389.31169487, 133.36920188]
    ])

    assert np.allclose(K, K_ref), "Error in periodic FCHL"
コード例 #5
0
def test_krr_fchl_alchemy():
    """Check that custom alchemy matrices reproduce the built-in alchemy modes.

    Two comparisons are made on the symmetric local kernel of 20 shuffled
    compounds: the explicit 20x20 ``overlap`` matrix must give the same
    kernel as ``alchemy="periodic-table"``, and a 100x100 identity matrix
    must give the same kernel as ``alchemy="off"``.
    """

    here = os.path.dirname(os.path.realpath(__file__))

    # Heats of formation keyed by xyz file name (PBE0/def2-TZVP)
    energies = get_energies(here + "/data/hof_qm7.txt")

    # Build one qml.Compound per file, for the first 20 file names in sorted
    # order, each carrying its property and an FCHL representation.
    compounds = []
    for filename in sorted(energies.keys())[:20]:
        compound = qml.Compound(xyz=here + "/qm7/" + filename)
        compound.properties = energies[filename]
        compound.generate_fchl_representation(cut_distance=1e6)
        compounds.append(compound)

    # Deterministic shuffle
    np.random.seed(666)
    np.random.shuffle(compounds)

    X = np.array([compound.representation for compound in compounds])

    sigma = 2.5

    # NOTE(review): this changes numpy's global print options and nothing in
    # this test prints -- looks like a debugging leftover; confirm before
    # removing.
    np.set_printoptions(edgeitems=16, linewidth=6666)

    # Explicit element-overlap matrix handed in as a custom ``alchemy``
    overlap = np.array([
        [
            1., 0.00835282, 0.90696062, 0.82257756, 0.61368025, 0.37660345,
            0.19010927, 0.07894037, 0.02696323, 0.00757568, 0.67663385,
            0.61368025, 0.45783336, 0.28096329, 0.14183016, 0.05889311,
            0.02011579, 0.0056518, 0.41523683, 0.37660345
        ],
        [
            0.00835282, 1., 0.00757568, 0.02696323, 0.07894037, 0.19010927,
            0.37660345, 0.61368025, 0.82257756, 0.90696062, 0.0056518,
            0.02011579, 0.05889311, 0.14183016, 0.28096329, 0.45783336,
            0.61368025, 0.67663385, 0.0034684, 0.01234467
        ],
        [
            0.90696062, 0.00757568, 1., 0.90696062, 0.67663385, 0.41523683,
            0.20961139, 0.08703837, 0.02972922, 0.00835282, 0.90696062,
            0.82257756, 0.61368025, 0.37660345, 0.19010927, 0.07894037,
            0.02696323, 0.00757568, 0.67663385, 0.61368025
        ],
        [
            0.82257756, 0.02696323, 0.90696062, 1., 0.90696062, 0.67663385,
            0.41523683, 0.20961139, 0.08703837, 0.02972922, 0.82257756,
            0.90696062, 0.82257756, 0.61368025, 0.37660345, 0.19010927,
            0.07894037, 0.02696323, 0.61368025, 0.67663385
        ],
        [
            0.61368025, 0.07894037, 0.67663385, 0.90696062, 1., 0.90696062,
            0.67663385, 0.41523683, 0.20961139, 0.08703837, 0.61368025,
            0.82257756, 0.90696062, 0.82257756, 0.61368025, 0.37660345,
            0.19010927, 0.07894037, 0.45783336, 0.61368025
        ],
        [
            0.37660345, 0.19010927, 0.41523683, 0.67663385, 0.90696062, 1.,
            0.90696062, 0.67663385, 0.41523683, 0.20961139, 0.37660345,
            0.61368025, 0.82257756, 0.90696062, 0.82257756, 0.61368025,
            0.37660345, 0.19010927, 0.28096329, 0.45783336
        ],
        [
            0.19010927, 0.37660345, 0.20961139, 0.41523683, 0.67663385,
            0.90696062, 1., 0.90696062, 0.67663385, 0.41523683, 0.19010927,
            0.37660345, 0.61368025, 0.82257756, 0.90696062, 0.82257756,
            0.61368025, 0.37660345, 0.14183016, 0.28096329
        ],
        [
            0.07894037, 0.61368025, 0.08703837, 0.20961139, 0.41523683,
            0.67663385, 0.90696062, 1., 0.90696062, 0.67663385, 0.07894037,
            0.19010927, 0.37660345, 0.61368025, 0.82257756, 0.90696062,
            0.82257756, 0.61368025, 0.05889311, 0.14183016
        ],
        [
            0.02696323, 0.82257756, 0.02972922, 0.08703837, 0.20961139,
            0.41523683, 0.67663385, 0.90696062, 1., 0.90696062, 0.02696323,
            0.07894037, 0.19010927, 0.37660345, 0.61368025, 0.82257756,
            0.90696062, 0.82257756, 0.02011579, 0.05889311
        ],
        [
            0.00757568, 0.90696062, 0.00835282, 0.02972922, 0.08703837,
            0.20961139, 0.41523683, 0.67663385, 0.90696062, 1., 0.00757568,
            0.02696323, 0.07894037, 0.19010927, 0.37660345, 0.61368025,
            0.82257756, 0.90696062, 0.0056518, 0.02011579
        ],
        [
            0.67663385, 0.0056518, 0.90696062, 0.82257756, 0.61368025,
            0.37660345, 0.19010927, 0.07894037, 0.02696323, 0.00757568, 1.,
            0.90696062, 0.67663385, 0.41523683, 0.20961139, 0.08703837,
            0.02972922, 0.00835282, 0.90696062, 0.82257756
        ],
        [
            0.61368025, 0.02011579, 0.82257756, 0.90696062, 0.82257756,
            0.61368025, 0.37660345, 0.19010927, 0.07894037, 0.02696323,
            0.90696062, 1., 0.90696062, 0.67663385, 0.41523683, 0.20961139,
            0.08703837, 0.02972922, 0.82257756, 0.90696062
        ],
        [
            0.45783336, 0.05889311, 0.61368025, 0.82257756, 0.90696062,
            0.82257756, 0.61368025, 0.37660345, 0.19010927, 0.07894037,
            0.67663385, 0.90696062, 1., 0.90696062, 0.67663385, 0.41523683,
            0.20961139, 0.08703837, 0.61368025, 0.82257756
        ],
        [
            0.28096329, 0.14183016, 0.37660345, 0.61368025, 0.82257756,
            0.90696062, 0.82257756, 0.61368025, 0.37660345, 0.19010927,
            0.41523683, 0.67663385, 0.90696062, 1., 0.90696062, 0.67663385,
            0.41523683, 0.20961139, 0.37660345, 0.61368025
        ],
        [
            0.14183016, 0.28096329, 0.19010927, 0.37660345, 0.61368025,
            0.82257756, 0.90696062, 0.82257756, 0.61368025, 0.37660345,
            0.20961139, 0.41523683, 0.67663385, 0.90696062, 1., 0.90696062,
            0.67663385, 0.41523683, 0.19010927, 0.37660345
        ],
        [
            0.05889311, 0.45783336, 0.07894037, 0.19010927, 0.37660345,
            0.61368025, 0.82257756, 0.90696062, 0.82257756, 0.61368025,
            0.08703837, 0.20961139, 0.41523683, 0.67663385, 0.90696062, 1.,
            0.90696062, 0.67663385, 0.07894037, 0.19010927
        ],
        [
            0.02011579, 0.61368025, 0.02696323, 0.07894037, 0.19010927,
            0.37660345, 0.61368025, 0.82257756, 0.90696062, 0.82257756,
            0.02972922, 0.08703837, 0.20961139, 0.41523683, 0.67663385,
            0.90696062, 1., 0.90696062, 0.02696323, 0.07894037
        ],
        [
            0.0056518, 0.67663385, 0.00757568, 0.02696323, 0.07894037,
            0.19010927, 0.37660345, 0.61368025, 0.82257756, 0.90696062,
            0.00835282, 0.02972922, 0.08703837, 0.20961139, 0.41523683,
            0.67663385, 0.90696062, 1., 0.00757568, 0.02696323
        ],
        [
            0.41523683, 0.0034684, 0.67663385, 0.61368025, 0.45783336,
            0.28096329, 0.14183016, 0.05889311, 0.02011579, 0.0056518,
            0.90696062, 0.82257756, 0.61368025, 0.37660345, 0.19010927,
            0.07894037, 0.02696323, 0.00757568, 1., 0.90696062
        ],
        [
            0.37660345, 0.01234467, 0.61368025, 0.67663385, 0.61368025,
            0.45783336, 0.28096329, 0.14183016, 0.05889311, 0.02011579,
            0.82257756, 0.90696062, 0.82257756, 0.61368025, 0.37660345,
            0.19010927, 0.07894037, 0.02696323, 0.90696062, 1.
        ],
    ])

    # Built-in periodic-table alchemy vs. the explicit overlap matrix
    K_periodic_table = get_local_symmetric_kernels(
        X, [sigma], alchemy="periodic-table")[0]
    K_from_overlap = get_local_symmetric_kernels(
        X, [sigma], alchemy=overlap)[0]

    assert np.allclose(K_periodic_table, K_from_overlap), "Error in alchemy"

    # Alchemy switched off vs. an identity matrix passed as custom alchemy
    identity = np.eye(100)
    K_off = get_local_symmetric_kernels(X, [sigma], alchemy="off")[0]
    K_from_identity = get_local_symmetric_kernels(
        X, [sigma], alchemy=identity)[0]

    assert np.allclose(K_off, K_from_identity), "Error in no-alchemy"
コード例 #6
0
 def fit_transform(self, X, y=None):
     """Fit on ``X`` and return the squeezed symmetric local kernels.

     ``y`` is accepted but never read.  The kernels are computed from
     ``self.train_points`` and passed through ``np.squeeze``, which removes
     every axis of length one from the result.
     """
     self.fit(X)
     kernels = get_local_symmetric_kernels(self.train_points)
     return np.squeeze(kernels)