def evaluate_model(model, gen_data, rnn_args, sim_data_node=None, num_smiles=1000):
    start = time.time()
    model.eval()

    # Sample SMILES strings from the generator in batches of `step`.
    samples = []
    step = 100
    count = 0
    for _ in range(int(num_smiles / step)):
        samples.extend(generate_smiles(generator=model,
                                       gen_data=gen_data,
                                       init_args=rnn_args,
                                       num_samples=step,
                                       is_train=False,
                                       verbose=True,
                                       max_len=smiles_max_len))
        count += step
    # Sample any remainder that did not fit into a full batch.
    res = num_smiles - count
    if res > 0:
        samples.extend(generate_smiles(generator=model,
                                       gen_data=gen_data,
                                       init_args=rnn_args,
                                       num_samples=res,
                                       is_train=False,
                                       verbose=True,
                                       max_len=smiles_max_len))

    # Canonicalize the samples; invalid SMILES come back as empty strings.
    smiles, _ = canonical_smiles(samples)
    valid_smiles = []
    invalid_smiles = []
    for idx, sm in enumerate(smiles):
        if len(sm) > 0:
            valid_smiles.append(sm)
        else:
            invalid_smiles.append(samples[idx])
    v = len(valid_smiles)
    valid_smiles = list(set(valid_smiles))
    print(f'Fraction of unique valid SMILES = {float(len(valid_smiles)) / float(len(samples)):.2f}, '
          f'Num. samples = {len(samples)}, Num. valid = {len(valid_smiles)}, '
          f'Num. requested = {num_smiles}, Num. dups = {v - len(valid_smiles)}')

    # Sub-nodes of the sim data resource.
    smiles_node = DataNode(label='valid_smiles', data=valid_smiles)
    invalid_smiles_node = DataNode(label='invalid_smiles', data=invalid_smiles)

    # Add sim data nodes to the parent node.
    if sim_data_node:
        sim_data_node.data = [smiles_node, invalid_smiles_node]

    duration = time.time() - start
    print('\nModel evaluation duration: {:.0f}m {:.0f}s'.format(duration // 60, duration % 60))
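
# A minimal usage sketch for evaluate_model, assuming a trained generator and its data wrapper
# are in scope; `generator`, `gen_data` and `rnn_args` below are placeholder names, not objects
# defined in this module.
#
#   eval_node = DataNode(label='generator_evaluation')
#   evaluate_model(generator, gen_data, rnn_args, sim_data_node=eval_node, num_smiles=500)
#   # eval_node.data now holds the 'valid_smiles' and 'invalid_smiles' sub-nodes.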

def smiles_to_tensor(smiles):
    """Converts a list of SMILES strings into padded token tensors and a validity vector."""
    smiles = list(smiles)
    # Validity flags: 1 for SMILES that can be canonicalized, 0 otherwise.
    _, valid_vec = canonical_smiles(smiles)
    valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(device)
    # Pad all sequences to a common length and encode them as token indices.
    smiles, _ = pad_sequences(smiles)
    inp, _ = seq2tensor(smiles, tokens=get_default_tokens())
    inp = torch.from_numpy(inp).long().to(device)
    return inp, valid_vec
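
# A small sketch of smiles_to_tensor, assuming the module-level `device` and tokenizer helpers
# are available; the SMILES strings below are arbitrary examples.
#
#   inp, valid_vec = smiles_to_tensor(['CCO', 'c1ccccc1'])
#   # inp: (batch, max_len) LongTensor of token indices; valid_vec: (batch, 1) validity flags.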

def fit(self, trajectories):
    """Train the reward function / model using the GRL algorithm."""
    if self.use_buffer:
        extra_trajs = self.replay_buffer.sample(self.batch_size)
        trajectories.extend(extra_trajs)
        self.replay_buffer.populate(trajectories)

    # Convert the sampled trajectories (D_samp) into SMILES strings and encode them.
    d_traj, d_traj_probs = [], []
    for traj in trajectories:
        d_traj.append(''.join(list(traj.terminal_state.state)) + traj.terminal_state.action)
        d_traj_probs.append(traj.traj_prob)  # collected but currently unused
    _, valid_vec_samp = canonical_smiles(d_traj)
    valid_vec_samp = torch.tensor(valid_vec_samp).view(-1, 1).float().to(self.device)
    d_traj, _ = pad_sequences(d_traj)
    d_samp, _ = seq2tensor(d_traj, tokens=get_default_tokens())
    d_samp = torch.from_numpy(d_samp).long().to(self.device)

    losses = []
    for i in trange(self.k, desc='IRL optimization...'):
        # D_demo processing: a random batch of expert demonstrations.
        demo_states, demo_actions = self.demo_gen_data.random_training_set()
        d_demo = torch.cat([demo_states, demo_actions[:, -1].reshape(-1, 1)], dim=1).to(self.device)
        valid_vec_demo = torch.ones(d_demo.shape[0]).view(-1, 1).float().to(self.device)
        d_demo_out = self.model([d_demo, valid_vec_demo])

        # D_samp processing: reward of generated samples; pad with demo outputs if the batch is small.
        d_samp_out = self.model([d_samp, valid_vec_samp])
        if d_samp_out.shape[0] < 1000:
            d_samp_out = torch.cat([d_samp_out, d_demo_out], dim=0)
        z = torch.ones(d_samp_out.shape[0]).float().to(self.device)  # dummy importance weights TODO: replace this
        d_samp_out = z.view(-1, 1) * torch.exp(d_samp_out)

        # Objective: mean reward on demonstrations minus the log of the mean exponentiated
        # reward on samples; it is maximized, so the optimizer minimizes its negation.
        loss = torch.mean(d_demo_out) - torch.log(torch.mean(d_samp_out))
        losses.append(loss.item())
        loss = -loss  # for maximization

        # Update parameters.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # self.lr_sch.step()
    return np.mean(losses)
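
# A hedged sketch of how fit() could sit inside the outer IRL loop, assuming the agent exposes
# a `rollout` method returning objects with `terminal_state` and `traj_prob` attributes; the
# names `agent`, `rollout`, `irl_model` and `n_irl_steps` are placeholders, not part of this module.
#
#   for step in range(n_irl_steps):
#       trajectories = agent.rollout(num_episodes=batch_size)
#       mean_objective = irl_model.fit(trajectories)
#       print(f'IRL step {step}: mean objective = {mean_objective:.4f}')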

def __call__(self, x, use_mc):
    """
    Calculates the reward for a given state.

    :param x: The state to be used in calculating the reward.
    :param use_mc: Whether Monte Carlo Tree Search or the parameterized reward function should be used.
    :return: float
        A scalar value representing the reward w.r.t. the given state x.
    """
    if use_mc:
        # Estimate the reward of a (possibly partial) state with Monte Carlo Tree Search.
        if self.mc_enabled:
            mc_node = MoleculeMonteCarloTreeSearchNode(x, self, self.mc_policy, self.actions,
                                                       self.max_len, end_char=self.end_char)
            mcts = MonteCarloTreeSearch(mc_node)
            reward = mcts(simulations_number=self.mc_max_sims)
            return reward
        else:
            return self.no_mc_fill_val
    else:
        # Get the reward of a completed string using the reward net or a given reward function.
        state = ''.join(x.tolist())
        if self.use_true_reward:
            state = state[1:-1].replace('\n', '-')
            reward = self.true_reward_func(state, self.expert_func)
        else:
            _, valid_vec = canonical_smiles([state])
            valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(self.device)
            inp, _ = seq2tensor([state], tokens=self.actions)
            inp = torch.from_numpy(inp).long().to(self.device)
            reward = self.model([inp, valid_vec]).squeeze().item()
        return self.reward_wrapper(reward)
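
# A usage sketch of the reward callable, assuming `reward_func` is an instance of this class and
# `state` is the sequence of generated characters for one molecule; both names are illustrative.
#
#   r_partial = reward_func(state, use_mc=True)    # MCTS rollout estimate for a partial state
#   r_final = reward_func(state, use_mc=False)     # reward-net / true-reward score of a full SMILES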