def test_tangible(self):
    """Round-trip a handful of sample equations through the grammar model.

    Loads the model from a fixed weights path, encodes the sample
    equations, decodes the resulting latent points, and prints each input
    next to its reconstruction.  Nothing is asserted; the output is meant
    for eyeballing.
    """
    weights_path = "../save_model/grammar_ae_model.pt"
    print(weights_path)
    # The latent size is hard-coded to 25 here.
    model = equation_vae.EquationGrammarModel(weights_path, latent_rep_size=25)

    equations = ['sin(x*2)', 'exp(x)+x', 'x/3', '3*exp(2/x)']

    # encode() is unpacked into two values; only the first is used below.
    latent, _unused = model.encode(equations)

    reconstructions = model.decode(latent)
    for index, decoded in enumerate(reconstructions):
        print(equations[index] + " --> " + decoded)
from __future__ import division
import sys
import equation_vae
import numpy as np
from numpy import sin, exp, cos
from matplotlib import pyplot as plt
import pdb

# 1. Load the grammar VAE.
# The weights file is given as a bare filename, so it is looked up relative
# to the current working directory.  An earlier assignment pointing at
# "pretrained/eq_vae_grammar_h100_c234_L25_E50_batchB.hdf5" was overwritten
# before it was ever read and has been dropped; switch the path below if the
# weights live under `pretrained/` instead.
grammar_weights = "eq_vae_grammar_h100_c234_L25_E50_batchB.hdf5"
print(grammar_weights)
grammar_model = equation_vae.EquationGrammarModel(grammar_weights, latent_rep_size=25)

# 2. Encode some example equations (all written in terms of the variable x).
eq = ['sin(x*2)', 'exp(x)+x', 'x/3', '3*exp(2/x)']

# z: encoded latent points
# NOTE (from the original author): this operation returns the mean of the
# encoding distribution.  To sample from that distribution instead, replace
# line 62 in equation_vae.py with:
#     return self.vae.encoder.predict(one_hot)
z = grammar_model.encode(eq)

# NOTE (from the original author): decoding is stochastic, so decoding the
# same latent point many times will return different answers.

# x-axis used to plot how well the true functions match the decoded ones:
# evenly spaced points on [-10, 10] (np.linspace's default of 50 samples).
domain = np.linspace(-10, 10)
import sys
# Put the directory three levels up (relative to the working directory) at
# the front of the module search path before the project imports below.
# NOTE(review): presumably this is where `equation_vae` lives -- confirm.
sys.path.insert(0, '../../../')
import equation_vae
from numpy import *  # need this for evaluating equations
from sparse_gp import SparseGP
import scipy.stats as sps
import numpy as np
import os.path
import os
import copy
import time

# here we load the grammar VAE in order to see if the equations parse correctly
# diff_model = equation_vae.EquationGrammarModel("../eq_vae_grammar_h100_c234_L25_E50_batchB.hdf5",25)
diff_model = equation_vae.EquationGrammarModel(
    "../../../pretrained/eq_vae_grammar_h100_c234_L25_E50_batchB.hdf5")


def decode_from_latent_space(latent_points, grammar_model):
    """Decode `latent_points` repeatedly with `grammar_model`.

    NOTE(review): the return value and any post-processing are not
    described here -- confirm against the full definition before relying
    on this summary.
    """
    # Number of times the same latent points are passed to decode().
    decode_attempts = 25

    # decoded_molecules[k] holds whatever decode() returned on attempt k.
    decoded_molecules = []
    for i in range(decode_attempts):
        current_decoded_molecules = grammar_model.decode(latent_points)
        decoded_molecules.append(current_decoded_molecules)

    # We see which ones are decoded by rdkit
    # NOTE(review): the "rdkit"/"molecules" wording looks inherited from a
    # molecule variant of this script; the other comments in this file talk
    # about equations -- confirm and rename if so.
    x = 0  # make x a dummy variable in order to see if equations parse correctly

    # One (initially empty) result list per decode attempt.
    rdkit_molecules = []
    for i in range(decode_attempts):
        rdkit_molecules.append([])