def reconstruct(self, data, method='greedy', log=True, return_mems=True, return_str=True):
    """
    Method for encoding input smiles into memory and decoding back
    into smiles

    Arguments:
        data (np.array, required): Input array consisting of smiles and property
        method (str): Method for decoding. Greedy decoding is currently the only
                      method implemented. May implement beam search, top_p or top_k
                      in future versions.
        log (bool): If true, tracks reconstruction progress in separate log file
        return_mems (bool): If true, returns memory vectors in addition to decoded SMILES
        return_str (bool): If true, translates decoded vectors into SMILES strings.
                           If false, returns tensor of token ids
    Returns:
        decoded_smiles (list): Decoded smiles data - either decoded SMILES strings
                               or tensor of token ids
        mems (np.array): Array of model memory vectors
    """
    data = vae_data_gen(data, props=None, char_dict=self.params['CHAR_DICT'])

    data_iter = torch.utils.data.DataLoader(data,
                                            batch_size=self.params['BATCH_SIZE'],
                                            shuffle=False, num_workers=0,
                                            pin_memory=False, drop_last=True)
    self.batch_size = self.params['BATCH_SIZE']
    self.chunk_size = self.batch_size // self.params['BATCH_CHUNKS']

    self.model.eval()
    decoded_smiles = []
    mems = torch.empty((data.shape[0], self.params['d_latent'])).cpu()
    for j, data in enumerate(data_iter):
        if log:
            log_file = open('calcs/{}_progress.txt'.format(self.name), 'a')
            log_file.write('{}\n'.format(j))
            log_file.close()
        for i in range(self.params['BATCH_CHUNKS']):
            batch_data = data[i*self.chunk_size:(i+1)*self.chunk_size, :]
            mols_data = batch_data[:, :-1]
            props_data = batch_data[:, -1]
            if self.use_gpu:
                mols_data = mols_data.cuda()
                props_data = props_data.cuda()

            src = Variable(mols_data).long()
            src_mask = (src != self.pad_idx).unsqueeze(-2)

            ### Run through encoder to get memory
            if self.model_type == 'transformer':
                _, mem, _, _ = self.model.encode(src, src_mask)
            else:
                _, mem, _ = self.model.encode(src)
            start = j*self.batch_size+i*self.chunk_size
            stop = j*self.batch_size+(i+1)*self.chunk_size
            mems[start:stop, :] = mem.detach().cpu()

            ### Decode logic
            if method == 'greedy':
                decoded = self.greedy_decode(mem, src_mask=src_mask)
            else:
                decoded = None

            if return_str:
                decoded = decode_mols(decoded, self.params['ORG_DICT'])
                decoded_smiles += decoded
            else:
                decoded_smiles.append(decoded)

    if return_mems:
        return decoded_smiles, mems.detach().numpy()
    else:
        return decoded_smiles
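# Usage sketch (not part of the original module): round-tripping a set of
# SMILES through a trained model. This is a hedged example - the checkpoint
# and CSV paths below are hypothetical placeholders, and the import path is
# an assumption; the `load_fn` constructor argument is grounded in
# calc_attention() below.
#
#   import pandas as pd
#   from transvae.trans_models import TransVAE   # assumed import path
#
#   vae = TransVAE(load_fn='checkpoints/trans_vae.ckpt')     # hypothetical path
#   smiles = pd.read_csv('data/test_smiles.csv').to_numpy()  # hypothetical path
#   recon, mems = vae.reconstruct(smiles, method='greedy', log=False,
#                                 return_mems=True, return_str=True)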
def calc_mems(self, data, log=True, save_dir='memory', save_fn='model_name', save=True):
    """
    Method for calculating and saving the latent memory of each input
    SMILES string

    Arguments:
        data (np.array, req): Input array containing SMILES strings
        log (bool): If true, tracks calculation progress in separate log file
        save_dir (str): Directory to store output memory array
        save_fn (str): File name to store output memory array
        save (bool): If true, saves memory to disk. If false, returns memory
    Returns (only when save=False):
        mems (np.array): Reparameterized memory array
        mus (np.array): Mean memory array (prior to reparameterization)
        logvars (np.array): Log variance array (prior to reparameterization)
    """
    data = vae_data_gen(data, props=None, char_dict=self.params['CHAR_DICT'])

    data_iter = torch.utils.data.DataLoader(data,
                                            batch_size=self.params['BATCH_SIZE'],
                                            shuffle=False, num_workers=0,
                                            pin_memory=False, drop_last=True)
    save_shape = len(data_iter)*self.params['BATCH_SIZE']
    self.batch_size = self.params['BATCH_SIZE']
    self.chunk_size = self.batch_size // self.params['BATCH_CHUNKS']
    mems = torch.empty((save_shape, self.params['d_latent'])).cpu()
    mus = torch.empty((save_shape, self.params['d_latent'])).cpu()
    logvars = torch.empty((save_shape, self.params['d_latent'])).cpu()

    self.model.eval()
    for j, data in enumerate(data_iter):
        if log:
            log_file = open('memory/{}_progress.txt'.format(self.name), 'a')
            log_file.write('{}\n'.format(j))
            log_file.close()
        for i in range(self.params['BATCH_CHUNKS']):
            batch_data = data[i*self.chunk_size:(i+1)*self.chunk_size, :]
            mols_data = batch_data[:, :-1]
            props_data = batch_data[:, -1]
            if self.use_gpu:
                mols_data = mols_data.cuda()
                props_data = props_data.cuda()

            src = Variable(mols_data).long()
            src_mask = (src != self.pad_idx).unsqueeze(-2)

            ### Run through encoder to get memory
            if self.model_type == 'transformer':
                mem, mu, logvar, _ = self.model.encode(src, src_mask)
            else:
                mem, mu, logvar = self.model.encode(src)
            start = j*self.batch_size+i*self.chunk_size
            stop = j*self.batch_size+(i+1)*self.chunk_size
            mems[start:stop, :] = mem.detach().cpu()
            mus[start:stop, :] = mu.detach().cpu()
            logvars[start:stop, :] = logvar.detach().cpu()

    if save:
        if save_fn == 'model_name':
            save_fn = self.name
        save_path = os.path.join(save_dir, save_fn)
        np.save('{}_mems.npy'.format(save_path), mems.detach().numpy())
        np.save('{}_mus.npy'.format(save_path), mus.detach().numpy())
        np.save('{}_logvars.npy'.format(save_path), logvars.detach().numpy())
    else:
        return mems.detach().numpy(), mus.detach().numpy(), logvars.detach().numpy()
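# Usage sketch (assumes `vae` is a trained model instance as in the example
# above). With save=True the arrays are written to
# '<save_dir>/<name>_mems.npy', '..._mus.npy' and '..._logvars.npy' and
# nothing is returned; with save=False the three arrays come back directly:
#
#   mems, mus, logvars = vae.calc_mems(smiles, log=False, save=False)
#
# Note that drop_last=True in the DataLoader skips any trailing partial
# batch, so the number of rows is always a multiple of BATCH_SIZE.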
def calc_attention(args):
    ### Load model
    ckpt_fn = args.model_ckpt
    if args.model == 'transvae':
        vae = TransVAE(load_fn=ckpt_fn)
    elif args.model == 'rnnattn':
        vae = RNNAttn(load_fn=ckpt_fn)
    elif args.model == 'rnn':
        vae = RNN(load_fn=ckpt_fn)

    if args.shuffle:
        data = pd.read_csv(args.smiles).sample(args.n_samples).to_numpy()
    else:
        data = pd.read_csv(args.smiles).to_numpy()
        data = data[:args.n_samples, :]

    ### Load data and prepare for iteration
    data = vae_data_gen(data, char_dict=vae.params['CHAR_DICT'])
    data_iter = torch.utils.data.DataLoader(data,
                                            batch_size=args.batch_size,
                                            shuffle=False, num_workers=0,
                                            pin_memory=False, drop_last=True)
    save_shape = len(data_iter)*args.batch_size
    chunk_size = args.batch_size // args.batch_chunks

    ### Prepare save path
    if args.save_path is None:
        os.makedirs('attn_wts', exist_ok=True)
        save_path = 'attn_wts/{}'.format(vae.name)
    else:
        save_path = args.save_path

    ### Calculate attention weights
    vae.model.eval()
    if args.model == 'transvae':
        # Output tensors are sized for the default architecture
        # (attention sets x heads x tgt length x src length)
        self_attn = torch.empty((save_shape, 4, 4, 127, 127))
        src_attn = torch.empty((save_shape, 3, 4, 126, 127))
        for j, data in enumerate(data_iter):
            for i in range(args.batch_chunks):
                batch_data = data[i*chunk_size:(i+1)*chunk_size, :]
                if vae.use_gpu:
                    batch_data = batch_data.cuda()

                src = Variable(batch_data).long()
                src_mask = (src != vae.pad_idx).unsqueeze(-2)
                tgt = Variable(batch_data[:, :-1]).long()
                tgt_mask = make_std_mask(tgt, vae.pad_idx)

                # Run samples through model to calculate weights
                mem, mu, logvar, pred_len, self_attn_wts = vae.model.encoder.forward_w_attn(vae.model.src_embed(src), src_mask)
                probs, deconv_wts, src_attn_wts = vae.model.decoder.forward_w_attn(vae.model.tgt_embed(tgt), mem, src_mask, tgt_mask)

                # Save weights to torch tensors (the decoder's deconvolution
                # weights are appended to the encoder self-attention list so
                # both are stored in the self_attn tensor)
                self_attn_wts += deconv_wts
                start = j*args.batch_size+i*chunk_size
                stop = j*args.batch_size+(i+1)*chunk_size
                for k in range(len(self_attn_wts)):
                    self_attn[start:stop, k, :, :, :] = self_attn_wts[k]
                for k in range(len(src_attn_wts)):
                    src_attn[start:stop, k, :, :, :] = src_attn_wts[k]

        np.save(save_path+'_self_attn.npy', self_attn.numpy())
        np.save(save_path+'_src_attn.npy', src_attn.numpy())

    elif args.model == 'rnnattn':
        attn = torch.empty((save_shape, 1, 1, 127, 127))
        for j, data in enumerate(data_iter):
            for i in range(args.batch_chunks):
                batch_data = data[i*chunk_size:(i+1)*chunk_size, :]
                if vae.use_gpu:
                    batch_data = batch_data.cuda()

                src = Variable(batch_data).long()

                # Run samples through model to calculate weights
                mem, mu, logvar, attn_wts = vae.model.encoder(vae.model.src_embed(src), return_attn=True)
                start = j*args.batch_size+i*chunk_size
                stop = j*args.batch_size+(i+1)*chunk_size
                attn[start:stop, 0, 0, :, :] = attn_wts

        np.save(save_path+'.npy', attn.numpy())
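# Usage sketch: calc_attention() expects an argparse-style namespace with
# the attributes read above (model, model_ckpt, smiles, shuffle, n_samples,
# batch_size, batch_chunks, save_path). The checkpoint and SMILES paths
# here are hypothetical placeholders; batch_size must be divisible by
# batch_chunks.
#
#   from argparse import Namespace
#
#   args = Namespace(model='transvae',
#                    model_ckpt='checkpoints/trans_vae.ckpt',  # hypothetical
#                    smiles='data/test_smiles.csv',            # hypothetical
#                    shuffle=True,
#                    n_samples=5000,
#                    batch_size=500,
#                    batch_chunks=5,
#                    save_path=None)
#   calc_attention(args)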