Example #1
0
def test_iqspr_1(data):
    """Run IQSPR end to end and check the per-step outputs it yields.

    A fresh ``GaussianLogLikelihood`` is built on the ECFP descriptor taken
    from ``data``, fitted on the ``'pg'`` dataset, and given new
    ``bandgap`` / ``density`` targets. An ``NGram`` modifier is fitted on the
    first 20 entries. The sampler is then driven with the first 5 entries
    and a 10-step ``beta`` schedule.

    For every yielded step the third item must sum to 1 (within 1e-5) and
    the fourth item must sum to 5.
    NOTE(review): the third/fourth items are presumably probabilities and
    frequencies (``yield_lpf``) -- confirm against the IQSPR documentation.

    Parameters
    ----------
    data : dict
        Shared test objects; this test reads the ``'ecfp'`` and ``'pg'``
        entries.
    """
    np.random.seed(0)
    smiles, props = data['pg']

    likelihood = GaussianLogLikelihood(descriptor=data['ecfp'])
    modifier = NGram()
    sampler = IQSPR(estimator=likelihood, modifier=modifier)

    likelihood.fit(smiles, props)
    likelihood.update_targets(
        reset=True,
        bandgap=(0.1, 0.2),
        density=(0.9, 1.2),
    )
    modifier.fit(smiles[0:20], train_order=10)

    schedule = np.linspace(0.05, 1, 10)
    steps = sampler(smiles[:5], schedule, yield_lpf=True)
    for _samples, _loglike, p_step, f_step in steps:
        p_total = np.sum(p_step)
        f_total = np.sum(f_step)
        assert np.abs(p_total - 1.0) < 1e-5
        assert f_total == 5
Example #2
0
def data():
    """Build the shared test objects and yield them as a single dict.

    Steps performed:

    * silence two NumPy "size changed" warnings;
    * load ``polymer_test_data.csv`` from the directory of this file;
    * fit two ``GaussianLogLikelihood`` models (one on ECFP, one on RDKitFP
      descriptors) on different property columns and set their targets;
    * build a ``BaseLogLikelihoodSet`` subclass instance holding both models;
    * fit an ``NGram`` on the first 20 SMILES and create an ``IQSPR``.

    Yields
    ------
    dict
        Keys: ``ecfp``, ``rdkitfp``, ``bre``, ``bre2``, ``like_mdl``,
        ``ngram``, ``iqspr`` and ``pg`` (a ``(X, y)`` tuple of the SMILES
        column and the remaining property columns).

    After the consumer is done, execution resumes past the ``yield`` and
    prints ``'test over'``.
    """
    # ignore numpy warning
    import warnings
    print('ignore NumPy RuntimeWarning\n')
    for noisy_message in ("numpy.dtype size changed",
                          "numpy.ndarray size changed"):
        warnings.filterwarnings("ignore", message=noisy_message)

    # The CSV is expected to sit next to this file.
    csv_path = Path(__file__).parent / 'polymer_test_data.csv'
    frame = pd.read_csv(str(csv_path))

    smiles = frame['smiles']
    props = frame.drop(['smiles', 'Unnamed: 0'], axis=1)

    ecfp = ECFP(n_jobs=1, input_type='smiles', target_col=0)
    rdkitfp = RDKitFP(n_jobs=1, input_type='smiles', target_col=0)

    # One likelihood model per descriptor, each trained on its own pair of
    # property columns.
    bre = GaussianLogLikelihood(descriptor=ecfp)
    bre2 = GaussianLogLikelihood(descriptor=rdkitfp)
    ecfp_targets = props[['bandgap', 'glass_transition_temperature']]
    rdkit_targets = props[['density', 'refractive_index']]
    bre.fit(smiles, ecfp_targets)
    bre2.fit(smiles, rdkit_targets)
    bre.update_targets(bandgap=(1, 2), glass_transition_temperature=(200, 300))
    bre2.update_targets(refractive_index=(2, 3), density=(0.9, 1.2))

    class MyLogLikelihood(BaseLogLikelihoodSet):
        def __init__(self):
            super().__init__()

            # NOTE(review): both assignments target the same attribute name;
            # presumably BaseLogLikelihoodSet collects every assigned
            # likelihood instead of overwriting -- confirm in the base class.
            self.loglike = bre
            self.loglike = bre2

    like_mdl = MyLogLikelihood()

    ngram = NGram()
    ngram.fit(smiles[0:20], train_order=5)
    iqspr = IQSPR(estimator=bre, modifier=ngram)

    # prepare test data
    bundle = {
        'ecfp': ecfp,
        'rdkitfp': rdkitfp,
        'bre': bre,
        'bre2': bre2,
        'like_mdl': like_mdl,
        'ngram': ngram,
        'iqspr': iqspr,
        'pg': (smiles, props),
    }
    yield bundle

    print('test over')