def main():
    """
    Smoke test: fit a sparse kernel ridge regression model on bundled data.

    Takes no arguments. Descriptors and target values are read from
    'small_molecules-SOAP.xyz', which is looked up in the directory that
    contains this file. A linear-kernel sparse KRR model is fitted, a
    learning curve is computed, the design-matrix state is saved under the
    "test-skrr" prefix and plt.show() is called at the end.
    """
    # fixed settings of this test
    xyz_path = os.path.join(os.path.dirname(__file__), 'small_molecules-SOAP.xyz')
    descriptor_tags = ['SOAP-n4-l3-c1.9-g0.23']
    target_name = 'dft_formation_energy_per_atom_in_eV'
    prefix = "test-skrr"
    test_ratio = 0.05
    lc_points = 8
    lc_repeats = 8

    # load descriptors and target values from the xyz file
    dataset = ASAPXYZ(xyz_path)
    descriptors, _ = dataset.get_descriptors(descriptor_tags, False)
    targets = dataset.get_property(target_name)

    design = Design_Matrix(
        X=descriptors,
        y=targets,
        whiten=True,
        test_ratio=test_ratio,
    )

    # linear kernel; a polynomial alternative would look like
    # { 'k1': {"type": "polynomial", "d": power}}
    kernel_spec = {'k0': {"type": "linear"}}

    # noise level: 0.1% of the standard deviation of the target values
    sigma = 0.001 * np.std(targets)

    # arguments noted by the original author as:
    # kernel, jitter, delta, sigma, sparse_mode="fps", n_sparse=None
    sparse_krr = KRRSparse(0., None, sigma)
    model = SPARSE_KRR_Wrapper(
        kernel_spec,
        sparse_krr,
        sparse_mode="fps",
        n_sparse=-1,
    )

    # fit the model
    design.compute_fit(model, 'skrr', store_results=True, plot=True)

    # learning curve
    # NOTE(review): the curve is labelled 'ridge_regression' although the
    # model is sparse KRR (the fit above uses 'skrr') -- confirm the label.
    if lc_points > 1:
        design.compute_learning_curve(
            model,
            'ridge_regression',
            lc_points=lc_points,
            lc_repeats=lc_repeats,
            randomseed=42,
            verbose=False,
        )

    design.save_state(prefix)
    plt.show()
def main():
    """
    Smoke test: fit a ridge regression model on bundled data.

    Takes no arguments. Descriptors and target values are read from
    'small_molecules-SOAP.xyz', which is looked up in the directory that
    contains this file. A ridge regression model is fitted, a learning
    curve is computed, the design-matrix state is saved under the
    "test-rr" prefix and plt.show() is called at the end.
    """
    # fixed settings of this test
    xyz_path = os.path.join(os.path.dirname(__file__), 'small_molecules-SOAP.xyz')
    descriptor_tags = ['SOAP-n4-l3-c1.9-g0.23']
    target_name = 'dft_formation_energy_per_atom_in_eV'
    prefix = "test-rr"
    test_ratio = 0.05
    lc_points = 8
    lc_repeats = 8

    # load descriptors and target values from the xyz file
    dataset = ASAPXYZ(xyz_path)
    descriptors, _ = dataset.get_descriptors(descriptor_tags, False)
    targets = dataset.get_property(target_name)

    design = Design_Matrix(
        X=descriptors,
        y=targets,
        whiten=True,
        test_ratio=test_ratio,
    )

    # noise level: 0.1% of the standard deviation of the target values
    noise = 0.001 * np.std(targets)
    model = RidgeRegression(noise)

    # fit the model
    design.compute_fit(model, 'ridge_regression', store_results=True, plot=True)

    # learning curve
    if lc_points > 1:
        design.compute_learning_curve(
            model,
            'ridge_regression',
            lc_points=lc_points,
            lc_repeats=lc_repeats,
            randomseed=42,
            verbose=False,
        )

    design.save_state(prefix)
    plt.show()
def main(fmat, fxyz, fy, prefix, scale, test_ratio, sigma, lc_points, lc_repeats):
    """
    Fit a ridge regression model and, optionally, plot its learning curve.

    Parameters
    ----------
    fmat: sequence of strings. Names of the descriptor tags in the ase xyz
          file; if the first entry is an existing file, the descriptor
          matrix is loaded from that file instead.
          You can use gen_descriptors.py to compute it.
    fxyz: Location of xyz file for reading the properties,
          or 'none' when no xyz file is used.
    fy: Location of property list (1D-array of floats). If it cannot be read
        as a file, it is used as the name of a property in the xyz file.
    prefix: filename prefix for learning curve figure
            (currently not used inside this function).
    scale: Scale the coordinates (True/False). Scaling highly recommended.
           Passed on as the `whiten` option of the design matrix.
    test_ratio: train/test ratio
    sigma: noise level in kernel ridge regression. A negative value selects
           the default, 0.1% of the standard deviation of the data.
    lc_points : number of points on the learning curve;
                the curve is only computed when this is larger than 1.
    lc_repeats : number of sub-sampling when compute the learning curve

    Returns
    -------
    None. The fit and the learning curve are plotted.

    Raises
    ------
    ValueError
        If no descriptor matrix could be obtained, if the descriptor matrix
        file cannot be parsed, or if the property can be read neither from
        a file nor from an xyz file.
    """
    scale = bool(scale)

    # Start from well-defined "nothing loaded" values so that a missing input
    # ends in the ValueError below rather than in a NameError.
    asapxyz = None
    desc = []

    # try to read the xyz file
    if fxyz != 'none':
        asapxyz = ASAPXYZ(fxyz)
        desc, _ = asapxyz.get_descriptors(fmat)

    # we can also load the descriptor matrix from a standalone file
    if os.path.isfile(fmat[0]):
        try:
            desc = np.genfromtxt(fmat[0], dtype=float)
        except (OSError, ValueError) as err:
            raise ValueError('Cannot load the descriptor matrix from file') from err
        print("loaded the descriptor matrix from file: ", fmat)

    if len(desc) == 0:
        raise ValueError(
            'Please supply descriptor in a xyz file or a standlone descriptor matrix'
        )
    print("shape of the descriptor matrix: ", np.shape(desc),
          "number of descriptors: ", np.shape(desc[0]))

    # read in the properties to be predicted: first as a standalone file,
    # then as a named property of the xyz file
    try:
        y_all = np.genfromtxt(fy, dtype=float)
    except (OSError, ValueError):
        if asapxyz is None:
            raise ValueError(
                'Cannot read the property list from file and no xyz file was supplied'
            ) from None
        y_all = asapxyz.get_property(fy)

    # honour the caller's scaling choice (it was previously hard-coded to True)
    dm = Design_Matrix(X=desc, y=y_all, whiten=scale, test_ratio=test_ratio)

    # if sigma is not set...
    if sigma < 0:
        sigma = 0.001 * np.std(y_all)
    rr = RidgeRegression(sigma)

    # fit the model
    dm.compute_fit(rr, 'ridge_regression', store_results=True, plot=True)

    # learning curve
    if lc_points > 1:
        lc_scores = dm.compute_learning_curve(
            rr,
            'ridge_regression',
            lc_points=lc_points,
            lc_repeats=lc_repeats,
            randomseed=42,
            verbose=False,
        )
        # make plot
        lc_scores.plot_learning_curve()

    plt.show()