コード例 #1
0
    def reconstruct(self, data, method='greedy', log=True, return_mems=True, return_str=True):
        """
        Method for encoding input smiles into memory and decoding back
        into smiles

        Arguments:
            data (np.array, required): Input array consisting of smiles and property
            method (str): Method for decoding. Greedy decoding is currently the only
                          method implemented. May implement beam search, top_p or top_k
                          in future versions.
            log (bool): If true, appends the index of each processed batch to
                        'calcs/<model name>_progress.txt' (directory created if missing)
            return_mems (bool): If true, returns memory vectors in addition to decoded SMILES
            return_str (bool): If true, translates decoded vectors into SMILES strings. If false
                               returns a list holding one tensor of token ids per batch chunk
        Returns:
            decoded_smiles (list): Decoded smiles data - either decoded SMILES strings or tensors of
                                   token ids
            mems (np.array): Array of model memory vectors (only when return_mems is True).
                             The data loader drops the last incomplete batch, so this has
                             len(batches) * BATCH_SIZE rows - the same rows that were decoded.
        Raises:
            ValueError: If method is not 'greedy'
        """
        # Validate before doing any (potentially expensive) data preparation
        if method != 'greedy':
            raise ValueError("Unsupported decoding method '{}'; only 'greedy' is "
                             "implemented".format(method))

        data = vae_data_gen(data, props=None, char_dict=self.params['CHAR_DICT'])

        data_iter = torch.utils.data.DataLoader(data,
                                                batch_size=self.params['BATCH_SIZE'],
                                                shuffle=False, num_workers=0,
                                                pin_memory=False, drop_last=True)
        self.batch_size = self.params['BATCH_SIZE']
        self.chunk_size = self.batch_size // self.params['BATCH_CHUNKS']

        # drop_last=True means only full batches are encoded. Size the output to the rows
        # that actually get written so no uninitialized torch.empty rows are returned.
        save_shape = len(data_iter) * self.batch_size
        mems = torch.empty((save_shape, self.params['d_latent'])).cpu()

        if log:
            os.makedirs('calcs', exist_ok=True)

        self.model.eval()
        decoded_smiles = []
        # Pure inference: no autograd graph is needed
        with torch.no_grad():
            for j, batch in enumerate(data_iter):
                if log:
                    with open('calcs/{}_progress.txt'.format(self.name), 'a') as log_file:
                        log_file.write('{}\n'.format(j))
                for i in range(self.params['BATCH_CHUNKS']):
                    batch_data = batch[i*self.chunk_size:(i+1)*self.chunk_size, :]
                    # The last column is the property slot (unused here); the rest are tokens
                    mols_data = batch_data[:, :-1]
                    if self.use_gpu:
                        mols_data = mols_data.cuda()

                    src = Variable(mols_data).long()
                    src_mask = (src != self.pad_idx).unsqueeze(-2)

                    ### Run through encoder to get memory
                    # Use the second encoder output (named `mu` in calc_mems), not the
                    # reparameterized sample
                    if self.model_type == 'transformer':
                        _, mem, _, _ = self.model.encode(src, src_mask)
                    else:
                        _, mem, _ = self.model.encode(src)
                    start = j*self.batch_size + i*self.chunk_size
                    stop = j*self.batch_size + (i+1)*self.chunk_size
                    mems[start:stop, :] = mem.detach().cpu()

                    ### Decode logic (method already validated as 'greedy')
                    decoded = self.greedy_decode(mem, src_mask=src_mask)

                    if return_str:
                        decoded_smiles += decode_mols(decoded, self.params['ORG_DICT'])
                    else:
                        decoded_smiles.append(decoded)

        if return_mems:
            return decoded_smiles, mems.detach().numpy()
        return decoded_smiles
コード例 #2
0
    def calc_mems(self, data, log=True, save_dir='memory', save_fn='model_name', save=True):
        """
        Method for calculating and saving the memory of each neural net

        Arguments:
            data (np.array, req): Input array containing SMILES strings
            log (bool): If true, appends the index of each processed batch to
                        'memory/<model name>_progress.txt' (directory created if missing)
            save_dir (str): Directory to store output memory arrays (created if missing)
            save_fn (str): File name prefix for the output arrays. The default value
                           'model_name' is replaced by self.name
            save (bool): If true, saves the arrays to
                         <save_dir>/<save_fn>_{mems,mus,logvars}.npy and returns None.
                         If false, returns the arrays instead of saving them.
        Returns (only when save is False):
            mems(np.array): Reparameterized memory array
            mus(np.array): Mean memory array (prior to reparameterization)
            logvars(np.array): Log variance array (prior to reparameterization)

        The data loader drops the last incomplete batch, so each array has
        len(batches) * BATCH_SIZE rows.
        """
        data = vae_data_gen(data, props=None, char_dict=self.params['CHAR_DICT'])

        data_iter = torch.utils.data.DataLoader(data,
                                                batch_size=self.params['BATCH_SIZE'],
                                                shuffle=False, num_workers=0,
                                                pin_memory=False, drop_last=True)
        save_shape = len(data_iter)*self.params['BATCH_SIZE']
        self.batch_size = self.params['BATCH_SIZE']
        self.chunk_size = self.batch_size // self.params['BATCH_CHUNKS']
        mems = torch.empty((save_shape, self.params['d_latent'])).cpu()
        mus = torch.empty((save_shape, self.params['d_latent'])).cpu()
        logvars = torch.empty((save_shape, self.params['d_latent'])).cpu()

        if log:
            # The progress log lives in a fixed 'memory/' directory, independent of save_dir
            os.makedirs('memory', exist_ok=True)

        self.model.eval()
        # Pure inference: no autograd graph is needed
        with torch.no_grad():
            for j, batch in enumerate(data_iter):
                if log:
                    with open('memory/{}_progress.txt'.format(self.name), 'a') as log_file:
                        log_file.write('{}\n'.format(j))
                for i in range(self.params['BATCH_CHUNKS']):
                    batch_data = batch[i*self.chunk_size:(i+1)*self.chunk_size, :]
                    # The last column is the property slot (unused here); the rest are tokens
                    mols_data = batch_data[:, :-1]
                    if self.use_gpu:
                        mols_data = mols_data.cuda()

                    src = Variable(mols_data).long()
                    src_mask = (src != self.pad_idx).unsqueeze(-2)

                    ### Run through encoder to get memory
                    if self.model_type == 'transformer':
                        mem, mu, logvar, _ = self.model.encode(src, src_mask)
                    else:
                        mem, mu, logvar = self.model.encode(src)
                    start = j*self.batch_size + i*self.chunk_size
                    stop = j*self.batch_size + (i+1)*self.chunk_size
                    mems[start:stop, :] = mem.detach().cpu()
                    mus[start:stop, :] = mu.detach().cpu()
                    logvars[start:stop, :] = logvar.detach().cpu()

        mems = mems.detach().numpy()
        mus = mus.detach().numpy()
        logvars = logvars.detach().numpy()

        if not save:
            return mems, mus, logvars

        if save_fn == 'model_name':
            save_fn = self.name
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, save_fn)
        np.save('{}_mems.npy'.format(save_path), mems)
        np.save('{}_mus.npy'.format(save_path), mus)
        np.save('{}_logvars.npy'.format(save_path), logvars)
コード例 #3
0
def calc_attention(args):
    """
    Calculate the attention weights of a trained TransVAE or RNNAttn model on a
    set of SMILES strings and save them as .npy files.

    Arguments:
        args: namespace (e.g. from argparse) providing
            model (str): 'transvae' or 'rnnattn'
            model_ckpt (str): checkpoint file passed to the model's load_fn
            smiles (str): path of a CSV file read with pandas
            shuffle (bool): if true, randomly sample n_samples rows; else take the first n_samples
            n_samples (int): number of rows to use
            batch_size (int): data loader batch size (incomplete last batch is dropped)
            batch_chunks (int): number of chunks each batch is split into
            save_path (str or None): output prefix; None means 'attn_wts/<model name>'
    Outputs:
        transvae: <save_path>_self_attn.npy and <save_path>_src_attn.npy
        rnnattn:  <save_path>.npy
    Raises:
        ValueError: if args.model is 'rnn' (the plain RNN has no attention weights)
                    or any other unrecognized model name
    """
    # Validate before loading a checkpoint or reading data
    if args.model == 'rnn':
        raise ValueError("The plain 'rnn' model has no attention weights to calculate")
    if args.model not in ('transvae', 'rnnattn'):
        raise ValueError("Unknown model type '{}'; expected 'transvae' or "
                         "'rnnattn'".format(args.model))

    ### Load model
    ckpt_fn = args.model_ckpt
    if args.model == 'transvae':
        vae = TransVAE(load_fn=ckpt_fn)
    else:
        vae = RNNAttn(load_fn=ckpt_fn)

    smiles = pd.read_csv(args.smiles)
    if args.shuffle:
        data = smiles.sample(args.n_samples).to_numpy()
    else:
        data = smiles.to_numpy()[:args.n_samples, :]

    ### Load data and prepare for iteration
    data = vae_data_gen(data, char_dict=vae.params['CHAR_DICT'])
    data_iter = torch.utils.data.DataLoader(data,
                                            batch_size=args.batch_size,
                                            shuffle=False, num_workers=0,
                                            pin_memory=False, drop_last=True)
    save_shape = len(data_iter)*args.batch_size
    chunk_size = args.batch_size // args.batch_chunks

    ### Prepare save path
    if args.save_path is None:
        os.makedirs('attn_wts', exist_ok=True)
        save_path = 'attn_wts/{}'.format(vae.name)
    else:
        save_path = args.save_path

    ### Calculate attention weights
    vae.model.eval()
    # no_grad is required, not just an optimization. Writing grad-tracking weights into
    # the output buffers would make the buffers require grad, and .numpy() would then raise.
    with torch.no_grad():
        if args.model == 'transvae':
            # NOTE(review): buffer dimensions are hard-coded (number of attention maps,
            # 4 per map - presumably heads - and sequence lengths 127/126). Confirm they
            # match the loaded model's params.
            self_attn = torch.empty((save_shape, 4, 4, 127, 127))
            src_attn = torch.empty((save_shape, 3, 4, 126, 127))
            for j, batch in enumerate(data_iter):
                for i in range(args.batch_chunks):
                    batch_data = batch[i*chunk_size:(i+1)*chunk_size, :]
                    if vae.use_gpu:
                        batch_data = batch_data.cuda()

                    src = Variable(batch_data).long()
                    src_mask = (src != vae.pad_idx).unsqueeze(-2)
                    tgt = Variable(batch_data[:, :-1]).long()
                    tgt_mask = make_std_mask(tgt, vae.pad_idx)

                    # Run samples through model to calculate weights
                    mem, _, _, _, enc_attn_wts = vae.model.encoder.forward_w_attn(
                        vae.model.src_embed(src), src_mask)
                    _, deconv_wts, src_attn_wts = vae.model.decoder.forward_w_attn(
                        vae.model.tgt_embed(tgt), mem, src_mask, tgt_mask)

                    # Encoder self-attention maps followed by the decoder's deconv maps
                    self_attn_wts = enc_attn_wts + deconv_wts

                    # Save weights to torch tensors
                    start = j*args.batch_size + i*chunk_size
                    stop = j*args.batch_size + (i+1)*chunk_size
                    for k, wts in enumerate(self_attn_wts):
                        self_attn[start:stop, k, :, :, :] = wts
                    for k, wts in enumerate(src_attn_wts):
                        src_attn[start:stop, k, :, :, :] = wts

            np.save(save_path+'_self_attn.npy', self_attn.numpy())
            np.save(save_path+'_src_attn.npy', src_attn.numpy())

        else:  # 'rnnattn'
            # NOTE(review): sequence length 127 is hard-coded; confirm against model params
            attn = torch.empty((save_shape, 1, 1, 127, 127))
            for j, batch in enumerate(data_iter):
                for i in range(args.batch_chunks):
                    batch_data = batch[i*chunk_size:(i+1)*chunk_size, :]
                    if vae.use_gpu:
                        batch_data = batch_data.cuda()

                    src = Variable(batch_data).long()

                    # Run samples through model to calculate weights
                    _, _, _, attn_wts = vae.model.encoder(vae.model.src_embed(src),
                                                          return_attn=True)
                    start = j*args.batch_size + i*chunk_size
                    stop = j*args.batch_size + (i+1)*chunk_size
                    attn[start:stop, 0, 0, :, :] = attn_wts

            np.save(save_path+'.npy', attn.numpy())