Example 1
def _preprocess_states_actions(actions, states, device):
    # Process states and actions
    states = [''.join(list(state)) for state in states]
    states, states_len = pad_sequences(states)
    states, _ = seq2tensor(states, get_default_tokens())
    states = torch.from_numpy(states).long().to(device)
    states_len = torch.tensor(states_len).long().to(device)
    actions, _ = seq2tensor(actions, get_default_tokens())
    actions = torch.from_numpy(actions.reshape(-1)).long().to(device)
    return (states, states_len), actions
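
A minimal, self-contained sketch of the padding and tokenization behaviour assumed for pad_sequences and seq2tensor above (the helper names and token set here are illustrative stand-ins, not the project's actual implementation):

import numpy as np
import torch

def pad_sequences_sketch(seqs, pad_char=' '):
    # Right-pad every string to the length of the longest one.
    lengths = [len(s) for s in seqs]
    max_len = max(lengths)
    padded = [s + pad_char * (max_len - len(s)) for s in seqs]
    return padded, lengths

def seq2tensor_sketch(seqs, tokens):
    # Map each character to its index in the token vocabulary.
    tok2idx = {t: i for i, t in enumerate(tokens)}
    arr = np.array([[tok2idx[c] for c in s] for s in seqs], dtype=np.int64)
    return arr, tokens

states = ['CCO', 'c1ccccc1']
tokens = sorted(set(''.join(states)) | {' '})
padded, lengths = pad_sequences_sketch(states)
arr, _ = seq2tensor_sketch(padded, tokens)
states_t = torch.from_numpy(arr).long()      # shape: (batch, max_len)
lengths_t = torch.tensor(lengths).long()     # original lengths before padding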
Example 2
    def forward(self, x):
        """
        Performs forward propagation to get predictions.

        Arguments:
        ----------
        :param x: list
            A list of SMILES strings as input to the model.
        :return: tensor
            Predictions corresponding to x.
        """
        batch_size = len(x)
        x, states_len = pad_sequences(x)
        x, _ = seq2tensor(x, self.tokens)
        x = torch.from_numpy(x).long().to(self.device)
        x = self.encoder(x)
        x = x.permute(1, 0, 2)
        h0 = torch.zeros(self.num_layers * self.num_directions, batch_size,
                         self.d_model).to(self.device)
        if self.unit_type == 'lstm':
            c0 = torch.zeros(self.num_layers * self.num_directions, batch_size,
                             self.d_model).to(self.device)
            h0 = (h0, c0)
        x, hidden = self.rnn(x, h0)
        x = x[-1, :, :].reshape(batch_size, -1)
        x = self.read_out(x)
        return x
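
A compact sketch of a model following the same shape conventions as the forward pass above: the embedding output is permuted to (seq_len, batch, d_model), fed to an RNN, and the last time step is passed to a read-out layer. The class name, layer sizes, and choice of GRU are assumptions made purely for illustration.

import torch
import torch.nn as nn

class RNNPredictorSketch(nn.Module):
    # Hypothetical stand-in mirroring the encoder -> rnn -> read_out flow above.
    def __init__(self, vocab_size=40, d_model=128, num_layers=1):
        super().__init__()
        self.encoder = nn.Embedding(vocab_size, d_model)
        self.rnn = nn.GRU(d_model, d_model, num_layers)   # expects (seq, batch, feat)
        self.read_out = nn.Linear(d_model, 1)

    def forward(self, x):                       # x: (batch, seq_len) token indices
        batch_size = x.size(0)
        x = self.encoder(x)                     # (batch, seq_len, d_model)
        x = x.permute(1, 0, 2)                  # (seq_len, batch, d_model)
        h0 = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
        out, _ = self.rnn(x, h0)
        return self.read_out(out[-1])           # read-out on the last time step

model = RNNPredictorSketch()
print(model(torch.randint(0, 40, (4, 20))).shape)   # torch.Size([4, 1])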
Example 3
    def calc_adv_ref(self, trajectory):
        states, actions, _ = unpack_batch([trajectory], self.gamma)
        last_state = ''.join(list(states[-1]))
        inp, _ = seq2tensor([last_state], tokens=get_default_tokens())
        inp = torch.from_numpy(inp).long().to(self.device)
        values_v = self.critic(inp)
        values = values_v.view(-1).data.cpu().numpy()
        last_gae = 0.0
        result_adv = []
        result_ref = []
        for val, next_val, exp in zip(reversed(values[:-1]),
                                      reversed(values[1:]),
                                      reversed(trajectory[:-1])):
            if exp.last_state is None:  # for terminal state
                delta = exp.reward - val
                last_gae = delta
            else:
                delta = exp.reward + self.gamma * next_val - val
                last_gae = delta + self.gamma * self.gae_lambda * last_gae
            result_adv.append(last_gae)
            result_ref.append(last_gae + val)

        adv_v = torch.FloatTensor(list(reversed(result_adv))).to(self.device)
        ref_v = torch.FloatTensor(list(reversed(result_ref))).to(self.device)
        return states[:-1], actions[:-1], adv_v, ref_v
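
The loop above is the standard GAE(lambda) recursion: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t) and A_t = delta_t + gamma * lambda * A_{t+1}, with no bootstrap at terminal states. A toy, self-contained version of the same recursion (the values, rewards, and done flags are made up for illustration):

import torch

gamma, lam = 0.99, 0.95
values  = [0.5, 0.6, 0.7, 0.0]        # V(s_0..s_3), illustrative numbers
rewards = [0.0, 0.0, 1.0]             # r_0..r_2
dones   = [False, False, True]        # whether the transition ends the episode

adv, ref, last_gae = [], [], 0.0
for val, next_val, r, done in zip(reversed(values[:-1]), reversed(values[1:]),
                                  reversed(rewards), reversed(dones)):
    if done:
        delta = r - val                # terminal state: no bootstrapped value
        last_gae = delta
    else:
        delta = r + gamma * next_val - val
        last_gae = delta + gamma * lam * last_gae
    adv.append(last_gae)
    ref.append(last_gae + val)         # reference value = advantage + V(s_t)

adv_v = torch.FloatTensor(list(reversed(adv)))
ref_v = torch.FloatTensor(list(reversed(ref)))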
Example 4
def smiles_to_tensor(smiles):
    smiles = list(smiles)
    _, valid_vec = canonical_smiles(smiles)
    valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(device)
    smiles, _ = pad_sequences(smiles)
    inp, _ = seq2tensor(smiles, tokens=get_default_tokens())
    inp = torch.from_numpy(inp).long().to(device)
    return inp, valid_vec
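
canonical_smiles is assumed to return, alongside the canonicalized strings, a 0/1 validity flag per molecule. A minimal RDKit-based sketch of that assumed behaviour (whether the project actually uses RDKit is an assumption; this is not its code):

from rdkit import Chem

def canonical_smiles_sketch(smiles_list):
    canonical, valid = [], []
    for smi in smiles_list:
        mol = Chem.MolFromSmiles(smi)          # None if the SMILES cannot be parsed
        if mol is None:
            canonical.append(smi)
            valid.append(0)
        else:
            canonical.append(Chem.MolToSmiles(mol))
            valid.append(1)
    return canonical, valid

print(canonical_smiles_sketch(['CCO', 'not_a_smiles']))   # (['CCO', 'not_a_smiles'], [1, 0])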
Example 5
    def random_training_set(self, batch_size=None, return_seq_len=False):
        if batch_size is None:
            batch_size = self.batch_size
        assert batch_size > 0
        inp, target = self.random_chunk(batch_size)
        inp_padded, inp_seq_len = pad_sequences(inp)
        inp_tensor, self.all_characters = seq2tensor(
            inp_padded, tokens=self.all_characters, flip=False)
        target_padded, target_seq_len = pad_sequences(target)
        target_tensor, self.all_characters = seq2tensor(
            target_padded, tokens=self.all_characters, flip=False)
        self.n_characters = len(self.all_characters)
        inp_tensor = torch.tensor(inp_tensor).long()
        target_tensor = torch.tensor(target_tensor).long()
        if self.use_cuda:
            inp_tensor = inp_tensor.cuda()
            target_tensor = target_tensor.cuda()
        if return_seq_len:
            return inp_tensor, target_tensor, (inp_seq_len, target_seq_len)
        return inp_tensor, target_tensor
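
For a character-level generator, random_chunk presumably yields teacher-forcing pairs in which the target is the input shifted by one character. A toy illustration of that assumed pairing (the start/end characters and helper are hypothetical and independent of the class above):

import random

def make_pair(sequence, start_char='<', end_char='>'):
    s = start_char + sequence + end_char
    return s[:-1], s[1:]                  # input, target shifted by one character

smiles = ['CCO', 'c1ccccc1', 'CC(=O)O']
batch = [make_pair(random.choice(smiles)) for _ in range(4)]
inp, target = zip(*batch)
print(inp[0], '->', target[0])            # e.g. '<CCO' -> 'CCO>'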
Example 6
    def fit(self, trajectories):
        """Train the reward function / model using the GRL algorithm."""
        """Train the reward function / model using the GRL algorithm."""
        if self.use_buffer:
            extra_trajs = self.replay_buffer.sample(self.batch_size)
            trajectories.extend(extra_trajs)
            self.replay_buffer.populate(trajectories)
        d_traj, d_traj_probs = [], []
        for traj in trajectories:
            d_traj.append(''.join(list(traj.terminal_state.state)) +
                          traj.terminal_state.action)
            d_traj_probs.append(traj.traj_prob)
        _, valid_vec_samp = canonical_smiles(d_traj)
        valid_vec_samp = torch.tensor(valid_vec_samp).view(-1, 1).float().to(
            self.device)
        d_traj, _ = pad_sequences(d_traj)
        d_samp, _ = seq2tensor(d_traj, tokens=get_default_tokens())
        d_samp = torch.from_numpy(d_samp).long().to(self.device)
        losses = []
        for i in trange(self.k, desc='IRL optimization...'):
            # D_demo processing
            demo_states, demo_actions = self.demo_gen_data.random_training_set()
            d_demo = torch.cat(
                [demo_states, demo_actions[:, -1].reshape(-1, 1)],
                dim=1).to(self.device)
            valid_vec_demo = torch.ones(d_demo.shape[0]).view(
                -1, 1).float().to(self.device)
            d_demo_out = self.model([d_demo, valid_vec_demo])

            # D_samp processing
            d_samp_out = self.model([d_samp, valid_vec_samp])
            d_out_combined = torch.cat([d_samp_out, d_demo_out], dim=0)
            if d_samp_out.shape[0] < 1000:
                d_samp_out = torch.cat([d_samp_out, d_demo_out], dim=0)
            z = torch.ones(d_samp_out.shape[0]).float().to(
                self.device)  # dummy importance weights TODO: replace this
            d_samp_out = z.view(-1, 1) * torch.exp(d_samp_out)

            # objective
            loss = torch.mean(d_demo_out) - torch.log(torch.mean(d_samp_out))
            losses.append(loss.item())
            loss = -loss  # for maximization

            # update params
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # self.lr_sch.step()
        return np.mean(losses)
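
The quantity maximized in the loop is the guided-cost-learning style estimate mean(D_demo) - log(mean(exp(D_samp))). A small self-contained check of that expression on random tensors (the shapes and inputs are illustrative only):

import torch

d_demo_out = torch.randn(32, 1)           # reward-model outputs on expert data
d_samp_out = torch.randn(64, 1)           # reward-model outputs on sampled data
z = torch.ones_like(d_samp_out)           # dummy importance weights, as above

objective = torch.mean(d_demo_out) - torch.log(torch.mean(z * torch.exp(d_samp_out)))
loss = -objective                         # negate so an optimizer can minimize it
print(loss.item())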
Example 7
    def __call__(self, x, use_mc):
        """
        Calculates the reward function of a given state.

        :param x:
            The state to be used in calculating the reward.
        :param use_mc:
            Whether to use Monte Carlo Tree Search or the parameterized reward function.
        :return: float
            A scalar value representing the reward w.r.t. the given state x.
        """
        if use_mc:
            if self.mc_enabled:
                mc_node = MoleculeMonteCarloTreeSearchNode(
                    x,
                    self,
                    self.mc_policy,
                    self.actions,
                    self.max_len,
                    end_char=self.end_char)
                mcts = MonteCarloTreeSearch(mc_node)
                reward = mcts(simulations_number=self.mc_max_sims)
                return reward
            else:
                return self.no_mc_fill_val
        else:
            # Get reward of completed string using the reward net or a given reward function.
            state = ''.join(x.tolist())
            if self.use_true_reward:
                state = state[1:-1].replace('\n', '-')
                reward = self.true_reward_func(state, self.expert_func)
            else:
                smiles, valid_vec = canonical_smiles([state])
                valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(
                    self.device)
                inp, _ = seq2tensor([state], tokens=self.actions)
                inp = torch.from_numpy(inp).long().to(self.device)
                reward = self.model([inp, valid_vec]).squeeze().item()
            return self.reward_wrapper(reward)
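
In the non-MCTS branch the reward model is called on the tokenized state plus a validity flag. A toy stand-in for such a model (the architecture is invented purely to show the expected input and output shapes, not the project's network):

import torch
import torch.nn as nn

class RewardNetSketch(nn.Module):
    # Hypothetical reward model taking [token indices, validity flag].
    def __init__(self, vocab_size=40, d_model=64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d_model)
        self.out = nn.Linear(d_model + 1, 1)

    def forward(self, inputs):
        inp, valid_vec = inputs                       # (1, seq_len), (1, 1)
        pooled = self.emb(inp).mean(dim=1)            # (1, d_model)
        return self.out(torch.cat([pooled, valid_vec], dim=1))  # (1, 1) reward

model = RewardNetSketch()
inp = torch.randint(0, 40, (1, 12))
valid_vec = torch.ones(1, 1)
print(model([inp, valid_vec]).squeeze().item())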
Example 8
    def fit(self, trajectories):
        sq2ten = lambda x: torch.from_numpy(
            seq2tensor(x, get_default_tokens())[0]).long().to(self.device)
        t_states, t_actions, t_adv, t_ref = [], [], [], []
        t_old_probs = []
        for traj in trajectories:
            states, actions, adv_v, ref_v = self.calc_adv_ref(traj)
            if len(states) == 0:
                continue
            t_states.append(states)
            t_actions.append(actions)
            t_adv.append(adv_v)
            t_ref.append(ref_v)

            with torch.set_grad_enabled(False):
                hidden_states = self.initial_states_func(
                    batch_size=1, **self.initial_states_args)
                trajectory_input = sq2ten(states[-1])
                actions = sq2ten(actions)
                old_probs = []
                for p in range(len(trajectory_input)):
                    outputs = self.model([trajectory_input[p].reshape(1, 1)] +
                                         hidden_states)
                    output, hidden_states = outputs[0], outputs[1:]
                    log_prob = torch.log_softmax(output.view(1, -1), dim=1)
                    old_probs.append(log_prob[0, actions[p]].item())
                t_old_probs.append(old_probs)

        if len(t_states) == 0:
            return 0., 0.

        for epoch in trange(self.ppo_epochs, desc='PPO optimization...'):
            cr_loss = 0.
            ac_loss = 0.
            for i in range(len(t_states)):
                traj_last_state = t_states[i][-1]
                traj_actions = t_actions[i]
                traj_adv = t_adv[i]
                traj_ref = t_ref[i]
                traj_old_probs = t_old_probs[i]
                hidden_states = self.initial_states_func(
                    1, **self.initial_states_args)
                for p in range(len(traj_last_state)):
                    state, action, adv = (traj_last_state[p], traj_actions[p],
                                          traj_adv[p])
                    old_log_prob = traj_old_probs[p]
                    state, action = sq2ten(state), sq2ten(action)

                    # Critic
                    pred = self.critic(state)
                    cr_loss = cr_loss + F.mse_loss(pred.reshape(-1, 1),
                                                   traj_ref[p].reshape(-1, 1))

                    # Actor
                    outputs = self.actor([state] + hidden_states)
                    output, hidden_states = outputs[0], outputs[1:]
                    logprob_pi_v = torch.log_softmax(output.view(1, -1),
                                                     dim=-1)
                    logprob_pi_v = logprob_pi_v[0, action]
                    ratio_v = torch.exp(logprob_pi_v - old_log_prob)
                    surr_obj_v = adv * ratio_v
                    clipped_surr_v = adv * torch.clamp(
                        ratio_v, 1.0 - self.ppo_eps, 1.0 + self.ppo_eps)
                    loss_policy_v = torch.min(surr_obj_v, clipped_surr_v)

                    # Maximize entropy
                    prob = torch.softmax(output.view(1, -1), dim=1)
                    prob = prob[0, action]
                    entropy = prob * logprob_pi_v
                    entropy_loss = self.entropy_beta * entropy
                    ac_loss = ac_loss - (loss_policy_v + entropy_loss)
            # Update weights
            self.critic_opt.zero_grad()
            self.actor_opt.zero_grad()
            cr_loss = cr_loss / len(trajectories)
            ac_loss = ac_loss / len(trajectories)
            cr_loss.backward()
            ac_loss.backward()
            self.critic_opt.step()
            self.actor_opt.step()
        return cr_loss.item(), -ac_loss.item()
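
The per-step actor update uses the standard PPO clipped surrogate: ratio = exp(log pi_new - log pi_old), objective = min(ratio * A, clip(ratio, 1 - eps, 1 + eps) * A). A toy numeric version of just that piece, with made-up log-probabilities and advantage:

import torch

ppo_eps = 0.2
adv = torch.tensor(1.5)                       # advantage for one (state, action)
old_log_prob = torch.tensor(-1.2)             # log pi_old(a|s), recorded before the update
new_log_prob = torch.tensor(-0.9)             # log pi_new(a|s), from the current actor

ratio = torch.exp(new_log_prob - old_log_prob)
surr = adv * ratio
clipped = adv * torch.clamp(ratio, 1.0 - ppo_eps, 1.0 + ppo_eps)
policy_objective = torch.min(surr, clipped)   # maximized, so its negative is the loss
print(policy_objective.item())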