Esempio n. 1
0
class GENTRL(nn.Module):
    '''
    GENTRL model.

    Bundles an encoder (``enc``), a decoder (``dec``) and a learnable
    prior ``LP`` defined over the concatenation of the latent variables
    and the conditioning features.  The class offers ELBO-based training
    (``train_as_vaelp``), reward-driven fine-tuning (``train_as_rl``),
    sampling, and saving/loading of the three sub-models.
    '''

    # Epoch indices at whose start a re-fit of the prior is scheduled.
    _REINIT_EPOCHS = (0, 1, 5)
    # Number of buffered rows required before the prior is re-fitted.
    _REINIT_BUFFER_ROWS = 5000
    # First positional argument handed to ``dec.sample``
    # (presumably the maximum generated length -- confirm with the decoder).
    _DEC_SAMPLE_LEN = 50
    # Share of every RL batch drawn from the widened exploration Gaussian.
    _EXPLORE_FRACTION = 0.3

    def __init__(self,
                 enc,
                 dec,
                 latent_descr,
                 feature_descr,
                 tt_int=40,
                 tt_type='usual',
                 beta=0.01,
                 gamma=0.1):
        '''
        Args:
            enc: encoder object; must provide ``encode`` and ``state_dict``.
            dec: decoder object; must provide ``weighted_forward``,
                ``sample``, ``latent_fc`` and ``state_dict``.
            latent_descr: per-variable description of the latent part,
                forwarded to ``LP``.
            feature_descr: per-variable description of the feature part,
                forwarded to ``LP``.
            tt_int: forwarded to ``LP`` as ``tt_int``.
            tt_type: forwarded to ``LP`` as ``tt_type``.
            beta: weight of the KL term in the ELBO.
            gamma: weight of the ``log p(y | z)`` term in the ELBO.
        '''
        super(GENTRL, self).__init__()

        self.enc = enc
        self.dec = dec

        self.num_latent = len(latent_descr)
        self.num_features = len(feature_descr)

        self.latent_descr = latent_descr
        self.feature_descr = feature_descr

        self.tt_int = tt_int
        self.tt_type = tt_type

        self.lp = LP(distr_descr=self.latent_descr + self.feature_descr,
                     tt_int=self.tt_int,
                     tt_type=self.tt_type)

        self.beta = beta
        self.gamma = gamma

    @staticmethod
    def _folder_prefix(folder):
        '''Return ``folder`` ending in '/' (an empty string stays empty).'''
        if folder and not folder.endswith('/'):
            return folder + '/'
        return folder

    def _sampling_descr(self):
        '''Descriptor for ``lp.sample``: 's' per latent, 'm' per feature.'''
        return self.num_latent * ['s'] + self.num_features * ['m']

    def _sample_latents(self, x):
        '''
        Encode ``x`` and draw one reparameterised latent sample per row.

        The encoder output is split along dim 1 into two chunks of
        ``len(self.latent_descr)`` columns: means and ``log_stds``.  The
        noise is scaled by ``exp(0.5 * log_stds)``.

        Returns:
            ``(log_stds, samples)``.
        '''
        means, log_stds = torch.split(self.enc.encode(x),
                                      len(self.latent_descr),
                                      dim=1)
        noise = torch.randn_like(log_stds)
        return log_stds, means + noise * torch.exp(0.5 * log_stds)

    def get_elbo(self, x, y):
        '''
        Compute the training objective for one batch.

        Args:
            x: batch passed to ``enc.encode`` and ``dec.weighted_forward``.
            y: feature tensor, concatenated to the latent samples on dim 1.

        Returns:
            ``(elbo, stats)`` where ``elbo`` is a differentiable scalar and
            ``stats`` maps 'loss', 'rec', 'kl', 'log_p_y_by_z' and
            'log_p_z_by_y' to detached numpy values.
        '''
        log_stds, latvar_samples = self._sample_latents(x)

        rec_part = self.dec.weighted_forward(x, latvar_samples).mean()

        normal_distr_hentropies = (log(2 * pi) + 1 + log_stds).sum(dim=1)

        latent_dim = len(self.latent_descr)
        condition_dim = len(self.feature_descr)

        zy = torch.cat([latvar_samples, y], dim=1)
        log_p_zy = self.lp.log_prob(zy)

        # Marginalise the latents to obtain log p(y) ...
        y_to_marg = latent_dim * [True] + condition_dim * [False]
        log_p_y = self.lp.log_prob(zy, marg=y_to_marg)

        # ... and the features to obtain log p(z).
        z_to_marg = latent_dim * [False] + condition_dim * [True]
        log_p_z = self.lp.log_prob(zy, marg=z_to_marg)

        log_p_z_by_y = log_p_zy - log_p_y
        log_p_y_by_z = log_p_zy - log_p_z

        kldiv_part = (-normal_distr_hentropies - log_p_zy).mean()

        elbo = rec_part - self.beta * kldiv_part
        elbo = elbo + self.gamma * log_p_y_by_z.mean()

        return elbo, {
            'loss': -elbo.detach().cpu().numpy(),
            'rec': rec_part.detach().cpu().numpy(),
            'kl': kldiv_part.detach().cpu().numpy(),
            'log_p_y_by_z': log_p_y_by_z.mean().detach().cpu().numpy(),
            'log_p_z_by_y': log_p_z_by_y.mean().detach().cpu().numpy()
        }

    def save(self, folder_to_save='./', version=""):
        '''
        Write encoder, decoder and prior state dicts plus ``lp.order``.

        Files are named 'enc.model', 'dec.model', 'lp.model' and
        'order.pkl', each followed by ``version``, inside
        ``folder_to_save`` (an empty string means the current directory).
        '''
        prefix = self._folder_prefix(folder_to_save)

        torch.save(self.enc.state_dict(), prefix + 'enc.model' + version)
        torch.save(self.dec.state_dict(), prefix + 'dec.model' + version)
        torch.save(self.lp.state_dict(), prefix + 'lp.model' + version)

        # Context manager so the handle is closed even if pickling fails.
        with open(prefix + 'order.pkl' + version, 'wb') as order_file:
            pickle.dump(self.lp.order, order_file)

    def load(self, folder_to_load='./', version=""):
        '''
        Restore the state written by ``save`` from ``folder_to_load``.

        ``self.lp`` is rebuilt with the stored ``order`` before its state
        dict is loaded.

        NOTE(security): 'order.pkl' is unpickled; only load checkpoints
        from trusted sources.
        '''
        prefix = self._folder_prefix(folder_to_load)

        with open(prefix + 'order.pkl' + version, 'rb') as order_file:
            order = pickle.load(order_file)

        self.lp = LP(distr_descr=self.latent_descr + self.feature_descr,
                     tt_int=self.tt_int,
                     tt_type=self.tt_type,
                     order=order)

        self.enc.load_state_dict(torch.load(prefix + 'enc.model' + version))
        self.dec.load_state_dict(torch.load(prefix + 'dec.model' + version))
        self.lp.load_state_dict(torch.load(prefix + 'lp.model' + version))

    def train_as_vaelp(self,
                       train_loader,
                       num_epochs=10,
                       verbose_step=50,
                       lr=1e-3,
                       save_path=None):
        '''
        Train all parameters with Adam by maximising ``get_elbo``.

        At the start of the epochs listed in ``_REINIT_EPOCHS`` the loop
        first buffers encoded ``[z, y]`` rows (no optimisation happens on
        those batches) and, once ``_REINIT_BUFFER_ROWS`` rows are
        available, re-fits the prior with ``lp.reinit_from_data``.

        Args:
            train_loader: iterable yielding ``(x_batch, y_batch)``.
            num_epochs: number of passes over ``train_loader``.
            verbose_step: print running stats every this many counted
                batches; a falsy value disables the per-step printing.
            lr: Adam learning rate.
            save_path: if not None, a checkpoint with version
                '_checkpoint_<epoch>' is saved there after every epoch.

        Returns:
            The ``TrainStats`` object accumulated over the whole run.
        '''
        optimizer = optim.Adam(self.parameters(), lr=lr)

        global_stats = TrainStats()
        local_stats = TrainStats()

        epoch_i = 0
        to_reinit = False
        buf = None
        while epoch_i < num_epochs:
            i = 0
            if verbose_step:
                print("Epoch", epoch_i + 1, ":", flush=True)

            if epoch_i in self._REINIT_EPOCHS:
                to_reinit = True

            for x_batch, y_batch in train_loader:

                i += 1

                device = self.lp.tt_cores[0].device
                y_batch = y_batch.float().to(device)
                if len(y_batch.shape) == 1:
                    y_batch = y_batch.view(-1, 1).contiguous()

                if to_reinit:
                    if (buf is None) or (
                            buf.shape[0] < self._REINIT_BUFFER_ROWS):
                        # The buffer only feeds the re-fit, so do not keep
                        # the encoder's autograd graph alive for it.
                        with torch.no_grad():
                            _, z_batch = self._sample_latents(x_batch)
                            cur_batch = torch.cat([z_batch, y_batch], dim=1)
                        if buf is None:
                            buf = cur_batch
                        else:
                            buf = torch.cat([buf, cur_batch])
                    else:
                        descr = len(self.latent_descr) * [0]
                        descr += len(self.feature_descr) * [1]
                        self.lp.reinit_from_data(buf, descr)
                        # Put the prior back on the device it was using
                        # instead of forcing CUDA.
                        self.lp.to(device)
                        buf = None
                        to_reinit = False

                    continue

                elbo, cur_stats = self.get_elbo(x_batch, y_batch)
                local_stats.update(cur_stats)
                global_stats.update(cur_stats)

                optimizer.zero_grad()
                loss = -elbo
                loss.backward()
                optimizer.step()

                if verbose_step and i % verbose_step == 0:
                    local_stats.print()
                    local_stats.reset()
                    i = 0

            epoch_i += 1

            # End-of-epoch report and checkpoint (the original guarded this
            # with ``i >= 0``, which always holds).
            local_stats.print()
            local_stats.reset()
            if save_path is not None:
                self.save(save_path, version="_checkpoint_%d" % epoch_i)

        return global_stats

    def resemblence_filter(self, score, sm, memory):
        '''
        Return 0 if ``sm`` occurs more than 5 times in ``memory``,
        otherwise return ``score`` unchanged.
        '''
        if len(memory) < 1:
            return score

        return 0 if memory.count(sm) > 5 else score

    def train_as_rl(self,
                    reward_fn,
                    num_iterations=100000,
                    verbose_step=50,
                    batch_size=200,
                    cond_lb=-2,
                    cond_rb=0,
                    lr_lp=1e-5,
                    lr_dec=1e-6):
        '''
        Fine-tune the prior and ``dec.latent_fc`` with a reward signal.

        Every iteration samples latents from the prior ("exploit") and
        from a Gaussian centred on their mean with twice their standard
        deviation ("explore"), decodes them, scores the results with
        ``reward_fn`` and takes one Adam step on
        ``-(log_prob * (reward - mean_reward)).mean()``.

        Args:
            reward_fn: callable applied to each decoded sample; must
                return a number.
            num_iterations: number of optimisation steps.
            verbose_step: print running stats every this many iterations;
                a falsy value disables printing of stats.
            batch_size: total samples per iteration (exploit + explore).
            cond_lb: accepted for interface compatibility; not used by
                this implementation.
            cond_rb: accepted for interface compatibility; not used by
                this implementation.
            lr_lp: Adam learning rate for the prior.
            lr_dec: Adam learning rate for ``dec.latent_fc``.

        Returns:
            The ``TrainStats`` object accumulated over the whole run.
        '''
        optimizer_lp = optim.Adam(self.lp.parameters(), lr=lr_lp)
        optimizer_dec = optim.Adam(self.dec.latent_fc.parameters(), lr=lr_dec)

        global_stats = TrainStats()
        local_stats = TrainStats()

        # Loop invariants, derived from the model's own dimensions rather
        # than assuming 50 latents and a single feature.
        sample_descr = self._sampling_descr()
        marg_features = (self.num_latent * [False] +
                         self.num_features * [True])
        exploit_size = int(batch_size * (1 - self._EXPLORE_FRACTION))
        expl_size = int(batch_size * self._EXPLORE_FRACTION)

        cur_iteration = 0
        while cur_iteration < num_iterations:
            print("!", end='')

            exploit_z = self.lp.sample(exploit_size, sample_descr)

            z_means = exploit_z.mean(dim=0)
            z_stds = exploit_z.std(dim=0)

            expl_z = torch.randn(expl_size, exploit_z.shape[1])
            expl_z = 2 * expl_z.to(exploit_z.device) * z_stds[None, :]
            expl_z += z_means[None, :]

            z = torch.cat([exploit_z, expl_z])
            smiles = self.dec.sample(self._DEC_SAMPLE_LEN, z, argmax=False)

            # Feature columns are placeholders: they are marginalised out.
            zc = torch.zeros(z.shape[0], self.num_features).to(z.device)
            conc_zy = torch.cat([z, zc], dim=1)
            log_probs = self.lp.log_prob(conc_zy, marg=marg_features)
            log_probs += self.dec.weighted_forward(smiles, z)
            r_list = [reward_fn(s) for s in smiles]

            rewards = torch.tensor(r_list).float().to(exploit_z.device)
            rewards_bl = rewards - rewards.mean()

            optimizer_dec.zero_grad()
            optimizer_lp.zero_grad()
            loss = -(log_probs * rewards_bl).mean()
            loss.backward()
            optimizer_dec.step()
            optimizer_lp.step()

            valid_sm = [s for s in smiles if get_mol(s) is not None]
            cur_stats = {
                'mean_reward': sum(r_list) / len(smiles),
                'valid_perc': len(valid_sm) / len(smiles),
                'max_reward': max(r_list)
            }

            local_stats.update(cur_stats)
            global_stats.update(cur_stats)

            cur_iteration += 1

            # ``cur_iteration`` already counts the finished step, so no
            # extra +1 (which made the first window one step short).
            if verbose_step and cur_iteration % verbose_step == 0:
                local_stats.print()
                local_stats.reset()

        return global_stats

    def sample(self, num_samples):
        '''
        Draw ``num_samples`` latent vectors from the prior (features
        marginalised) and decode them with ``dec.sample``.

        Returns:
            Whatever ``dec.sample`` returns for the sampled latents.
        '''
        z = self.lp.sample(num_samples, self._sampling_descr())
        smiles = self.dec.sample(self._DEC_SAMPLE_LEN, z, argmax=False)

        return smiles
class DIS_GENTRL(nn.Module):
    '''
    GENTRL model variant whose ``get_elbo`` additionally reports a GPU
    memory snapshot for the given host rank.

    Bundles an encoder (``enc``), a decoder (``dec``) and a learnable
    prior ``LP`` defined over the concatenation of the latent variables
    and the conditioning features.
    '''

    # First positional argument handed to ``dec.sample``
    # (presumably the maximum generated length -- confirm with the decoder).
    _DEC_SAMPLE_LEN = 50

    def __init__(self,
                 enc,
                 dec,
                 latent_descr,
                 feature_descr,
                 tt_int=40,
                 tt_type='usual',
                 beta=0.01,
                 gamma=0.1):
        '''
        Args:
            enc: encoder object; must provide ``encode`` and ``state_dict``.
            dec: decoder object; must provide ``weighted_forward``,
                ``sample`` and ``state_dict``.
            latent_descr: per-variable description of the latent part,
                forwarded to ``LP``.
            feature_descr: per-variable description of the feature part,
                forwarded to ``LP``.
            tt_int: forwarded to ``LP`` as ``tt_int``.
            tt_type: forwarded to ``LP`` as ``tt_type``.
            beta: weight of the KL term in the ELBO.
            gamma: weight of the ``log p(y | z)`` term in the ELBO.
        '''
        super(DIS_GENTRL, self).__init__()

        self.enc = enc
        self.dec = dec

        self.num_latent = len(latent_descr)
        self.num_features = len(feature_descr)

        self.latent_descr = latent_descr
        self.feature_descr = feature_descr

        self.tt_int = tt_int
        self.tt_type = tt_type

        self.lp = LP(distr_descr=self.latent_descr + self.feature_descr,
                     tt_int=self.tt_int,
                     tt_type=self.tt_type)

        self.beta = beta
        self.gamma = gamma

    @staticmethod
    def _folder_prefix(folder):
        '''Return ``folder`` ending in '/' (an empty string stays empty).'''
        if folder and not folder.endswith('/'):
            return folder + '/'
        return folder

    def get_elbo(self, x, y, host_rank):
        '''
        Compute the training objective for one batch and a GPU snapshot.

        Args:
            x: batch passed to ``enc.encode`` and ``dec.weighted_forward``.
            y: feature tensor, concatenated to the latent samples on dim 1.
            host_rank: forwarded to ``nvidia_measure``.

        Returns:
            ``(elbo, stats, gpu_perform)`` where ``elbo`` is a
            differentiable scalar, ``stats`` maps 'loss', 'rec', 'kl',
            'log_p_y_by_z' and 'log_p_z_by_y' to detached numpy values and
            ``gpu_perform`` is the list built by ``nvidia_measure``.
        '''
        # The encoder output is moved to CUDA explicitly, so this method
        # requires a CUDA device.
        means, log_stds = torch.split(
            self.enc.encode(x).cuda(non_blocking=True),
            len(self.latent_descr),
            dim=1)
        latvar_samples = (
            means + torch.randn_like(log_stds) * torch.exp(0.5 * log_stds))

        rec_part = self.dec.weighted_forward(x, latvar_samples).mean()

        normal_distr_hentropies = (log(2 * pi) + 1 + log_stds).sum(dim=1)

        latent_dim = len(self.latent_descr)
        condition_dim = len(self.feature_descr)

        zy = torch.cat([latvar_samples, y], dim=1)

        # GPU measure point: taken before the prior is evaluated.
        gpu_perform = self.nvidia_measure(host_rank)

        log_p_zy = self.lp.log_prob(zy)

        # Marginalise the latents to obtain log p(y) ...
        y_to_marg = latent_dim * [True] + condition_dim * [False]
        log_p_y = self.lp.log_prob(zy, marg=y_to_marg)

        # ... and the features to obtain log p(z).
        z_to_marg = latent_dim * [False] + condition_dim * [True]
        log_p_z = self.lp.log_prob(zy, marg=z_to_marg)

        log_p_z_by_y = log_p_zy - log_p_y
        log_p_y_by_z = log_p_zy - log_p_z

        kldiv_part = (-normal_distr_hentropies - log_p_zy).mean()

        elbo = rec_part - self.beta * kldiv_part
        elbo = elbo + self.gamma * log_p_y_by_z.mean()

        return elbo, {
            'loss': -elbo.detach().cpu().numpy(),
            'rec': rec_part.detach().cpu().numpy(),
            'kl': kldiv_part.detach().cpu().numpy(),
            'log_p_y_by_z': log_p_y_by_z.mean().detach().cpu().numpy(),
            'log_p_z_by_y': log_p_z_by_y.mean().detach().cpu().numpy()
        }, gpu_perform

    def nvidia_measure(self, host_rank):
        '''
        Take a memory snapshot of one GPU returned by ``GPU.getGPUs()``.

        With more than one visible GPU the device at index
        ``int(host_rank)`` is used.  Otherwise the first GPU is used and
        the reported host index is parsed from the 'SM_CURRENT_HOST'
        environment variable (second '-'-separated field minus one;
        presumably a SageMaker host name such as 'algo-1' -- confirm).

        Returns:
            ``[host_index, memoryFree, memoryUsed, memoryUtil * 100,
            memoryTotal]``.

        Raises:
            KeyError: single-GPU case without 'SM_CURRENT_HOST' set.
            IndexError: no GPU is visible, or ``host_rank`` is out of
                range.
        '''
        GPUs = GPU.getGPUs()
        if len(GPUs) > 1:
            gpu_host = int(host_rank)
            gpu = GPUs[gpu_host]
        else:
            gpu_host = int(os.environ['SM_CURRENT_HOST'].split('-')[1]) - 1
            gpu = GPUs[0]

        gpu_perform = [
            gpu_host, gpu.memoryFree, gpu.memoryUsed, gpu.memoryUtil * 100,
            gpu.memoryTotal
        ]
        return gpu_perform

    def save(self, folder_to_save='./', version=""):
        '''
        Write encoder, decoder and prior state dicts plus ``lp.order``.

        Files are named 'enc.model', 'dec.model', 'lp.model' and
        'order.pkl', each followed by ``version`` (empty by default),
        inside ``folder_to_save`` (an empty string means the current
        directory).
        '''
        prefix = self._folder_prefix(folder_to_save)

        torch.save(self.enc.state_dict(), prefix + 'enc.model' + version)
        torch.save(self.dec.state_dict(), prefix + 'dec.model' + version)
        torch.save(self.lp.state_dict(), prefix + 'lp.model' + version)

        # Context manager so the handle is closed even if pickling fails.
        with open(prefix + 'order.pkl' + version, 'wb') as order_file:
            pickle.dump(self.lp.order, order_file)

    def load(self, folder_to_load='./', version=""):
        '''
        Restore the state written by ``save`` from ``folder_to_load``.

        ``self.lp`` is rebuilt with the stored ``order`` before its state
        dict is loaded.

        NOTE(security): 'order.pkl' is unpickled; only load checkpoints
        from trusted sources.
        '''
        prefix = self._folder_prefix(folder_to_load)

        with open(prefix + 'order.pkl' + version, 'rb') as order_file:
            order = pickle.load(order_file)

        self.lp = LP(distr_descr=self.latent_descr + self.feature_descr,
                     tt_int=self.tt_int,
                     tt_type=self.tt_type,
                     order=order)

        self.enc.load_state_dict(torch.load(prefix + 'enc.model' + version))
        self.dec.load_state_dict(torch.load(prefix + 'dec.model' + version))
        self.lp.load_state_dict(torch.load(prefix + 'lp.model' + version))

    def sample(self, num_samples):
        '''
        Draw ``num_samples`` latent vectors from the prior (features
        marginalised) and decode them with ``dec.sample``.

        Returns:
            Whatever ``dec.sample`` returns for the sampled latents.
        '''
        # Build the descriptor from the model's own dimensions rather than
        # assuming 50 latents and a single feature.
        sample_descr = self.num_latent * ['s'] + self.num_features * ['m']
        z = self.lp.sample(num_samples, sample_descr)
        smiles = self.dec.sample(self._DEC_SAMPLE_LEN, z, argmax=False)

        return smiles