def _preprocess_states_actions(actions, states, device):
    # Process states and actions
    states = [''.join(list(state)) for state in states]
    states, states_len = pad_sequences(states)
    states, _ = seq2tensor(states, get_default_tokens())
    states = torch.from_numpy(states).long().to(device)
    states_len = torch.tensor(states_len).long().to(device)
    actions, _ = seq2tensor(actions, get_default_tokens())
    actions = torch.from_numpy(actions.reshape(-1)).long().to(device)
    return (states, states_len), actions
def forward(self, x):
    """
    Performs forward propagation to get predictions.

    Arguments:
    ----------
    :param x: list
        A list of SMILES strings as input to the model.
    :return: tensor
        Predictions corresponding to x.
    """
    batch_size = len(x)
    x, states_len = pad_sequences(x)
    x, _ = seq2tensor(x, self.tokens)
    x = torch.from_numpy(x).long().to(self.device)
    x = self.encoder(x)
    x = x.permute(1, 0, 2)
    h0 = torch.zeros(self.num_layers * self.num_directions, batch_size,
                     self.d_model).to(self.device)
    if self.unit_type == 'lstm':
        c0 = torch.zeros(self.num_layers * self.num_directions, batch_size,
                         self.d_model).to(self.device)
        h0 = (h0, c0)
    x, hidden = self.rnn(x, h0)
    x = x[-1, :, :].reshape(batch_size, -1)
    x = self.read_out(x)
    return x
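# A minimal sketch of the (num_layers * num_directions, batch, hidden) initial
# state convention and the last-time-step readout used in forward() above.
# All dimensions below are made-up illustration values, not project settings,
# and a single-direction LSTM is assumed.
import torch
import torch.nn as nn

seq_len, batch_size, d_model, num_layers = 12, 4, 32, 2
rnn = nn.LSTM(input_size=d_model, hidden_size=d_model, num_layers=num_layers)
x = torch.randn(seq_len, batch_size, d_model)      # (T, B, d_model)
h0 = torch.zeros(num_layers, batch_size, d_model)  # one direction assumed
c0 = torch.zeros(num_layers, batch_size, d_model)
out, hidden = rnn(x, (h0, c0))                     # out: (T, B, d_model)
last_step = out[-1, :, :].reshape(batch_size, -1)  # features fed to read_out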
def calc_adv_ref(self, trajectory):
    states, actions, _ = unpack_batch([trajectory], self.gamma)
    last_state = ''.join(list(states[-1]))
    inp, _ = seq2tensor([last_state], tokens=get_default_tokens())
    inp = torch.from_numpy(inp).long().to(self.device)
    values_v = self.critic(inp)
    values = values_v.view(-1, ).data.cpu().numpy()
    last_gae = 0.0
    result_adv = []
    result_ref = []
    for val, next_val, exp in zip(reversed(values[:-1]), reversed(values[1:]),
                                  reversed(trajectory[:-1])):
        if exp.last_state is None:  # for terminal state
            delta = exp.reward - val
            last_gae = delta
        else:
            delta = exp.reward + self.gamma * next_val - val
            last_gae = delta + self.gamma * self.gae_lambda * last_gae
        result_adv.append(last_gae)
        result_ref.append(last_gae + val)
    adv_v = torch.FloatTensor(list(reversed(result_adv))).to(self.device)
    ref_v = torch.FloatTensor(list(reversed(result_ref))).to(self.device)
    return states[:-1], actions[:-1], adv_v, ref_v
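# A minimal, self-contained sketch of the generalized advantage estimation
# (GAE) recursion that calc_adv_ref implements above:
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#   A_t     = delta_t + gamma * lambda * A_{t+1}
#   ref_t   = A_t + V(s_t)   (value target for the critic)
# The rewards/values below are made-up toy numbers, and the terminal-state
# special case handled above is omitted for brevity.
import numpy as np

def gae_sketch(rewards, values, gamma=0.99, lam=0.95):
    advantages, last_gae = [], 0.0
    for r, v, next_v in zip(reversed(rewards), reversed(values[:-1]),
                            reversed(values[1:])):
        delta = r + gamma * next_v - v
        last_gae = delta + gamma * lam * last_gae
        advantages.append(last_gae)
    advantages = list(reversed(advantages))
    refs = [a + v for a, v in zip(advantages, values[:-1])]
    return np.array(advantages), np.array(refs)

# Example: 3 transitions, 4 value estimates (the last one bootstraps the tail).
adv, ref = gae_sketch(rewards=[0.0, 0.0, 1.0], values=[0.1, 0.2, 0.3, 0.0])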
def smiles_to_tensor(smiles):
    smiles = list(smiles)
    _, valid_vec = canonical_smiles(smiles)
    valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(device)
    smiles, _ = pad_sequences(smiles)
    inp, _ = seq2tensor(smiles, tokens=get_default_tokens())
    inp = torch.from_numpy(inp).long().to(device)
    return inp, valid_vec
def random_training_set(self, batch_size=None, return_seq_len=False):
    if batch_size is None:
        batch_size = self.batch_size
    assert (batch_size > 0)
    inp, target = self.random_chunk(batch_size)
    inp_padded, inp_seq_len = pad_sequences(inp)
    inp_tensor, self.all_characters = seq2tensor(inp_padded,
                                                 tokens=self.all_characters,
                                                 flip=False)
    target_padded, target_seq_len = pad_sequences(target)
    target_tensor, self.all_characters = seq2tensor(target_padded,
                                                    tokens=self.all_characters,
                                                    flip=False)
    self.n_characters = len(self.all_characters)
    inp_tensor = torch.tensor(inp_tensor).long()
    target_tensor = torch.tensor(target_tensor).long()
    if self.use_cuda:
        inp_tensor = inp_tensor.cuda()
        target_tensor = target_tensor.cuda()
    if return_seq_len:
        return inp_tensor, target_tensor, (inp_seq_len, target_seq_len)
    return inp_tensor, target_tensor
def fit(self, trajectories):
    """Train the reward function / model using the GRL algorithm."""
    if self.use_buffer:
        extra_trajs = self.replay_buffer.sample(self.batch_size)
        trajectories.extend(extra_trajs)
        self.replay_buffer.populate(trajectories)

    # Assemble the generated (sampled) SMILES batch D_samp.
    d_traj, d_traj_probs = [], []
    for traj in trajectories:
        d_traj.append(''.join(list(traj.terminal_state.state)) +
                      traj.terminal_state.action)
        d_traj_probs.append(traj.traj_prob)
    _, valid_vec_samp = canonical_smiles(d_traj)
    valid_vec_samp = torch.tensor(valid_vec_samp).view(-1, 1).float().to(
        self.device)
    d_traj, _ = pad_sequences(d_traj)
    d_samp, _ = seq2tensor(d_traj, tokens=get_default_tokens())
    d_samp = torch.from_numpy(d_samp).long().to(self.device)

    losses = []
    for i in trange(self.k, desc='IRL optimization...'):
        # D_demo processing
        demo_states, demo_actions = self.demo_gen_data.random_training_set()
        d_demo = torch.cat([demo_states, demo_actions[:, -1].reshape(-1, 1)],
                           dim=1).to(self.device)
        valid_vec_demo = torch.ones(d_demo.shape[0]).view(-1, 1).float().to(
            self.device)
        d_demo_out = self.model([d_demo, valid_vec_demo])

        # D_samp processing
        d_samp_out = self.model([d_samp, valid_vec_samp])
        d_out_combined = torch.cat([d_samp_out, d_demo_out], dim=0)
        if d_samp_out.shape[0] < 1000:
            d_samp_out = torch.cat([d_samp_out, d_demo_out], dim=0)
        z = torch.ones(d_samp_out.shape[0]).float().to(
            self.device)  # dummy importance weights TODO: replace this
        d_samp_out = z.view(-1, 1) * torch.exp(d_samp_out)

        # Objective: E_demo[r] - log E_samp[w * exp(r)]
        loss = torch.mean(d_demo_out) - torch.log(torch.mean(d_samp_out))
        losses.append(loss.item())
        loss = -loss  # negate for maximization via gradient descent

        # Update params
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        # self.lr_sch.step()
    return np.mean(losses)
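# A standalone sketch of the sample-based objective maximized in fit() above
# (guided-cost-learning style):
#   J(theta) = E_demo[ r_theta(x) ] - log E_samp[ w * exp(r_theta(x)) ]
# The reward net, features, and weights below are hypothetical stand-ins for
# illustration, not the project's model or data.
import torch

reward_net = torch.nn.Linear(8, 1)                  # hypothetical reward model
opt = torch.optim.Adam(reward_net.parameters(), lr=1e-3)

d_demo = torch.randn(16, 8)                         # toy "expert" features
d_samp = torch.randn(64, 8)                         # toy generated samples
w = torch.ones(64, 1)                               # dummy importance weights

r_demo = reward_net(d_demo)
r_samp = reward_net(d_samp)
objective = r_demo.mean() - torch.log((w * torch.exp(r_samp)).mean())
loss = -objective                                   # minimize the negation
opt.zero_grad()
loss.backward()
opt.step()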
def __call__(self, x, use_mc):
    """
    Calculates the reward function of a given state.

    :param x: The state to be used in calculating the reward.
    :param use_mc: Whether Monte Carlo Tree Search or the parameterized
        reward function should be used.
    :return: float
        A scalar value representing the reward w.r.t. the given state x.
    """
    if use_mc:
        if self.mc_enabled:
            mc_node = MoleculeMonteCarloTreeSearchNode(x, self, self.mc_policy,
                                                       self.actions,
                                                       self.max_len,
                                                       end_char=self.end_char)
            mcts = MonteCarloTreeSearch(mc_node)
            reward = mcts(simulations_number=self.mc_max_sims)
            return reward
        else:
            return self.no_mc_fill_val
    else:
        # Get the reward of a completed string using either the true reward
        # function or the learned reward net.
        state = ''.join(x.tolist())
        if self.use_true_reward:
            state = state[1:-1].replace('\n', '-')
            reward = self.true_reward_func(state, self.expert_func)
        else:
            smiles, valid_vec = canonical_smiles([state])
            valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(
                self.device)
            inp, _ = seq2tensor([state], tokens=self.actions)
            inp = torch.from_numpy(inp).long().to(self.device)
            reward = self.model([inp, valid_vec]).squeeze().item()
        return self.reward_wrapper(reward)
def fit(self, trajectories):
    sq2ten = lambda x: torch.from_numpy(
        seq2tensor(x, get_default_tokens())[0]).long().to(self.device)

    # Collect per-trajectory states, actions, advantages, value targets, and
    # the old policy's log-probabilities (no gradients needed here).
    t_states, t_actions, t_adv, t_ref = [], [], [], []
    t_old_probs = []
    for traj in trajectories:
        states, actions, adv_v, ref_v = self.calc_adv_ref(traj)
        if len(states) == 0:
            continue
        t_states.append(states)
        t_actions.append(actions)
        t_adv.append(adv_v)
        t_ref.append(ref_v)
        with torch.set_grad_enabled(False):
            hidden_states = self.initial_states_func(
                batch_size=1, **self.initial_states_args)
            trajectory_input = sq2ten(states[-1])
            actions = sq2ten(actions)
            old_probs = []
            for p in range(len(trajectory_input)):
                outputs = self.model([trajectory_input[p].reshape(1, 1)] +
                                     hidden_states)
                output, hidden_states = outputs[0], outputs[1:]
                log_prob = torch.log_softmax(output.view(1, -1), dim=1)
                old_probs.append(log_prob[0, actions[p]].item())
            t_old_probs.append(old_probs)
    if len(t_states) == 0:
        return 0., 0.

    for epoch in trange(self.ppo_epochs, desc='PPO optimization...'):
        cr_loss = 0.
        ac_loss = 0.
        for i in range(len(t_states)):
            traj_last_state = t_states[i][-1]
            traj_actions = t_actions[i]
            traj_adv = t_adv[i]
            traj_ref = t_ref[i]
            traj_old_probs = t_old_probs[i]
            hidden_states = self.initial_states_func(
                1, **self.initial_states_args)
            for p in range(len(traj_last_state)):
                state, action, adv = (traj_last_state[p], traj_actions[p],
                                      traj_adv[p])
                old_log_prob = traj_old_probs[p]
                state, action = sq2ten(state), sq2ten(action)

                # Critic
                pred = self.critic(state)
                cr_loss = cr_loss + F.mse_loss(pred.reshape(-1, 1),
                                               traj_ref[p].reshape(-1, 1))

                # Actor: PPO clipped surrogate objective
                outputs = self.actor([state] + hidden_states)
                output, hidden_states = outputs[0], outputs[1:]
                logprob_pi_v = torch.log_softmax(output.view(1, -1), dim=-1)
                logprob_pi_v = logprob_pi_v[0, action]
                ratio_v = torch.exp(logprob_pi_v - old_log_prob)
                surr_obj_v = adv * ratio_v
                clipped_surr_v = adv * torch.clamp(ratio_v,
                                                   1.0 - self.ppo_eps,
                                                   1.0 + self.ppo_eps)
                loss_policy_v = torch.min(surr_obj_v, clipped_surr_v)

                # Maximize entropy
                prob = torch.softmax(output.view(1, -1), dim=1)
                prob = prob[0, action]
                entropy = prob * logprob_pi_v
                entropy_loss = self.entropy_beta * entropy
                ac_loss = ac_loss - (loss_policy_v + entropy_loss)

        # Update weights once per PPO epoch
        self.critic_opt.zero_grad()
        self.actor_opt.zero_grad()
        cr_loss = cr_loss / len(trajectories)
        ac_loss = ac_loss / len(trajectories)
        cr_loss.backward()
        ac_loss.backward()
        self.critic_opt.step()
        self.actor_opt.step()
    return cr_loss.item(), -ac_loss.item()
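# A standalone sketch of the PPO clipped surrogate used per step in fit() above:
#   ratio = exp(log pi_new(a|s) - log pi_old(a|s))
#   L     = min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A)
# The log-probabilities and advantage below are toy values for illustration.
import torch

ppo_eps = 0.2
new_log_prob = torch.tensor(-1.1, requires_grad=True)  # log pi_new(a|s)
old_log_prob = torch.tensor(-1.4)                      # log pi_old(a|s), fixed
adv = torch.tensor(0.7)                                # advantage estimate

ratio = torch.exp(new_log_prob - old_log_prob)
surrogate = ratio * adv
clipped = torch.clamp(ratio, 1.0 - ppo_eps, 1.0 + ppo_eps) * adv
policy_objective = torch.min(surrogate, clipped)       # maximized w.r.t. pi_new
(-policy_objective).backward()                         # gradient ascent step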