Example #1
    def evaluate_model(model,
                       gen_data,
                       rnn_args,
                       sim_data_node=None,
                       num_smiles=1000):
        start = time.time()
        model.eval()

        # Sample SMILES in chunks of `step` (smiles_max_len comes from the enclosing scope)
        samples = []
        step = 100
        count = 0
        for _ in range(num_smiles // step):
            samples.extend(
                generate_smiles(generator=model,
                                gen_data=gen_data,
                                init_args=rnn_args,
                                num_samples=step,
                                is_train=False,
                                verbose=True,
                                max_len=smiles_max_len))
            count += step
        res = num_smiles - count
        if res > 0:
            samples.extend(
                generate_smiles(generator=model,
                                gen_data=gen_data,
                                init_args=rnn_args,
                                num_samples=res,
                                is_train=False,
                                verbose=True,
                                max_len=smiles_max_len))
        smiles, _ = canonical_smiles(samples)
        valid_smiles = []
        invalid_smiles = []
        for idx, sm in enumerate(smiles):
            if len(sm) > 0:
                valid_smiles.append(sm)
            else:
                invalid_smiles.append(samples[idx])
        v = len(valid_smiles)  # count of valid SMILES before de-duplication
        valid_smiles = list(set(valid_smiles))  # drop duplicates
        print(
            f'Percentage of valid SMILES = {float(len(valid_smiles)) / float(len(samples)):.2%}, '
            f'Num. samples = {len(samples)}, Num. valid = {len(valid_smiles)}, '
            f'Num. requested = {num_smiles}, Num. dups = {v - len(valid_smiles)}'
        )

        # sub-nodes of sim data resource
        smiles_node = DataNode(label="valid_smiles", data=valid_smiles)
        invalid_smiles_node = DataNode(label='invalid_smiles',
                                       data=invalid_smiles)

        # add sim data nodes to parent node
        if sim_data_node:
            sim_data_node.data = [smiles_node, invalid_smiles_node]

        duration = time.time() - start
        print('\nModel evaluation duration: {:.0f}m {:.0f}s'.format(
            duration // 60, duration % 60))
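A hedged call-site sketch for evaluate_model follows; trained_generator, gen_data, and rnn_args are assumed to come from the surrounding training pipeline, and DataNode is the same helper used inside the function.

# Hypothetical call site (all names other than evaluate_model/DataNode are assumptions):
eval_node = DataNode(label='generator_evaluation')
evaluate_model(trained_generator,
               gen_data,
               rnn_args,
               sim_data_node=eval_node,
               num_smiles=500)
# eval_node.data now holds the valid_smiles and invalid_smiles sub-nodes.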
Example #2
def smiles_to_tensor(smiles):
    """Convert SMILES strings into a padded token-index tensor and a validity mask."""
    smiles = list(smiles)
    _, valid_vec = canonical_smiles(smiles)
    valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(device)
    smiles, _ = pad_sequences(smiles)
    inp, _ = seq2tensor(smiles, tokens=get_default_tokens())
    inp = torch.from_numpy(inp).long().to(device)
    return inp, valid_vec
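A minimal usage sketch, assuming canonical_smiles, pad_sequences, seq2tensor, get_default_tokens, and device are already in scope as in the snippet; the batch contents are illustrative.

# Hypothetical usage: encode a small batch of SMILES for a downstream reward model.
batch = ['CCO', 'c1ccccc1', 'CC(=O)O']
inp, valid_vec = smiles_to_tensor(batch)
# inp:       (batch, max_len) long tensor of token indices, padded to equal length
# valid_vec: (batch, 1) float tensor, 1.0 where the SMILES could be canonicalized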
Example #3
    def fit(self, trajectories):
        """Train the reward function / model using the GRL algorithm."""
        """Train the reward function / model using the GRL algorithm."""
        if self.use_buffer:
            extra_trajs = self.replay_buffer.sample(self.batch_size)
            trajectories.extend(extra_trajs)
            self.replay_buffer.populate(trajectories)
        d_traj, d_traj_probs = [], []
        for traj in trajectories:
            d_traj.append(''.join(list(traj.terminal_state.state)) +
                          traj.terminal_state.action)
            d_traj_probs.append(traj.traj_prob)
        _, valid_vec_samp = canonical_smiles(d_traj)
        valid_vec_samp = torch.tensor(valid_vec_samp).view(-1, 1).float().to(
            self.device)
        d_traj, _ = pad_sequences(d_traj)
        d_samp, _ = seq2tensor(d_traj, tokens=get_default_tokens())
        d_samp = torch.from_numpy(d_samp).long().to(self.device)
        losses = []
        for i in trange(self.k, desc='IRL optimization...'):
            # D_demo processing
            demo_states, demo_actions = self.demo_gen_data.random_training_set(
            )
            d_demo = torch.cat(
                [demo_states, demo_actions[:, -1].reshape(-1, 1)],
                dim=1).to(self.device)
            valid_vec_demo = torch.ones(d_demo.shape[0]).view(
                -1, 1).float().to(self.device)
            d_demo_out = self.model([d_demo, valid_vec_demo])

            # D_samp processing
            d_samp_out = self.model([d_samp, valid_vec_samp])
            # Augment D_samp with the demo outputs when few samples are available.
            if d_samp_out.shape[0] < 1000:
                d_samp_out = torch.cat([d_samp_out, d_demo_out], dim=0)
            z = torch.ones(d_samp_out.shape[0]).float().to(
                self.device)  # dummy importance weights TODO: replace this
            d_samp_out = z.view(-1, 1) * torch.exp(d_samp_out)

            # objective: E_demo[r] - log E_samp[z * exp(r)]
            loss = torch.mean(d_demo_out) - torch.log(torch.mean(d_samp_out))
            losses.append(loss.item())
            loss = -loss  # for maximization

            # update params
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # self.lr_sch.step()
        return np.mean(losses)
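The quantity maximized above resembles a guided-cost-learning / maximum-entropy IRL objective: the mean reward on demonstration sequences minus the log of the mean importance-weighted exp(reward) on sampled sequences. A self-contained sketch of that objective on dummy reward outputs (not the repository's API) is shown below.

import torch

# Dummy reward-model outputs; requires_grad stands in for the reward network's parameters.
d_demo_out = torch.randn(8, 1, requires_grad=True)   # rewards on expert (demo) sequences
d_samp_out = torch.randn(16, 1, requires_grad=True)  # rewards on sampled sequences
z = torch.ones_like(d_samp_out)                      # placeholder importance weights, as in fit()

objective = torch.mean(d_demo_out) - torch.log(torch.mean(z * torch.exp(d_samp_out)))
(-objective).backward()  # maximize the objective by minimizing its negative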
Example #4
    def __call__(self, x, use_mc):
        """
        Calculates the reward function of a given state.

        :param x:
            The state to be used in calculating the reward.
        :param use_mc:
            Whether Monte Carlo Tree Search or the parameterized reward function should be used.
        :return: float
            A scalar value representing the reward w.r.t. the given state x.
        """
        if use_mc:
            if self.mc_enabled:
                mc_node = MoleculeMonteCarloTreeSearchNode(
                    x,
                    self,
                    self.mc_policy,
                    self.actions,
                    self.max_len,
                    end_char=self.end_char)
                mcts = MonteCarloTreeSearch(mc_node)
                reward = mcts(simulations_number=self.mc_max_sims)
                return reward
            else:
                return self.no_mc_fill_val
        else:
            # Get reward of completed string using the reward net or a given reward function.
            state = ''.join(x.tolist())
            if self.use_true_reward:
                state = state[1:-1].replace('\n', '-')
                reward = self.true_reward_func(state, self.expert_func)
            else:
                _, valid_vec = canonical_smiles([state])
                valid_vec = torch.tensor(valid_vec).view(-1, 1).float().to(
                    self.device)
                inp, _ = seq2tensor([state], tokens=self.actions)
                inp = torch.from_numpy(inp).long().to(self.device)
                reward = self.model([inp, valid_vec]).squeeze().item()
            return self.reward_wrapper(reward)
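A hedged query sketch follows; reward_fn is assumed to be an instance of the class this method belongs to, and the '<' / '>' start and end markers in the state strings are assumptions.

import numpy as np

# Hypothetical terminal-state query: score a completed sequence without MCTS roll-outs.
terminal_state = np.array(list('<CCO>'))
r_terminal = reward_fn(terminal_state, use_mc=False)

# Hypothetical intermediate-state query: estimate the reward of a partial sequence via MCTS.
r_partial = reward_fn(np.array(list('<CC')), use_mc=True)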