def model_init(args, params=None):
    """Construct a fresh (untrained) VAE model from parsed CLI arguments.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``model`` ('transvae' | 'rnnattn' | 'rnn'),
        ``save_name``, ``d_model``, ``d_feedforward``, ``d_latent`` and
        ``data_source``.
    params : dict, optional
        Extra hyperparameters forwarded to the model constructor.
        Defaults to an empty dict.

    Returns
    -------
    The instantiated model object (TransVAE, RNNAttn or RNN).

    Raises
    ------
    ValueError
        If ``args.model`` is not one of the recognized model types.
    """
    # A mutable default ({}) would be shared across calls; create per-call.
    if params is None:
        params = {}

    ### Model Name
    if args.save_name is None:
        if args.model == 'transvae':
            # Encode the feedforward expansion factor into the name,
            # e.g. d_ff=512, d_model=128 -> 'trans4x-128_<source>'
            save_name = 'trans{}x-{}_{}'.format(args.d_feedforward // args.d_model,
                                                args.d_model,
                                                args.data_source)
        else:
            save_name = '{}-{}_{}'.format(args.model, args.d_model, args.data_source)
    else:
        save_name = args.save_name

    ### Load Model
    if args.model == 'transvae':
        vae = TransVAE(params=params, name=save_name, d_model=args.d_model,
                       d_ff=args.d_feedforward, d_latent=args.d_latent)
    elif args.model == 'rnnattn':
        vae = RNNAttn(params=params, name=save_name, d_model=args.d_model,
                      d_latent=args.d_latent)
    elif args.model == 'rnn':
        vae = RNN(params=params, name=save_name, d_model=args.d_model,
                  d_latent=args.d_latent)
    else:
        # Previously an unknown model type fell through every branch and
        # surfaced as UnboundLocalError on `return vae`; fail fast instead.
        raise ValueError("Unrecognized model type: {}".format(args.model))
    return vae
def sample(args):
    """Generate molecules from a trained VAE checkpoint and save them to CSV.

    Loads the checkpointed model named by ``args.model``/``args.model_ckpt``,
    optionally restricts sampling to high-entropy latent dimensions, draws
    exactly ``args.n_samples`` samples in batches of
    ``args.n_samples_per_batch`` and writes them to ``args.save_path`` (or
    ``generated/<name>_<mode>.csv`` when no path is given).

    Raises
    ------
    ValueError
        If ``args.model`` is not a recognized model type.
    RuntimeError
        If the model returns an empty batch (would otherwise loop forever).
    """
    ### Load model
    ckpt_fn = args.model_ckpt
    if args.model == 'transvae':
        vae = TransVAE(load_fn=ckpt_fn)
    elif args.model == 'rnnattn':
        vae = RNNAttn(load_fn=ckpt_fn)
    elif args.model == 'rnn':
        vae = RNN(load_fn=ckpt_fn)
    else:
        # Previously `vae` was left unbound for unknown model types and the
        # function crashed later with UnboundLocalError; fail fast instead.
        raise ValueError("Unrecognized model type: {}".format(args.model))

    ### Parse conditional string (comma-separated tokens, empty -> no condition)
    if args.condition == '':
        condition = []
    else:
        condition = args.condition.split(',')

    ### Calculate entropy depending on sampling mode
    if args.sample_mode == 'rand':
        sample_mode = 'rand'
        sample_dims = None
    else:
        # Entropy-guided modes: rank latent dimensions by the entropy of
        # their means over a reference dataset and keep those above cutoff.
        entropy_data = pd.read_csv(args.mols).to_numpy()
        _, mus, _ = vae.calc_mems(entropy_data, log=False, save=False)
        vae_entropy = calc_entropy(mus)
        entropy_idxs = np.where(np.array(vae_entropy) > args.entropy_cutoff)[0]
        sample_dims = entropy_idxs
        if args.sample_mode == 'high_entropy':
            sample_mode = 'top_dims'
        elif args.sample_mode == 'k_high_entropy':
            sample_mode = 'k_dims'
        else:
            raise ValueError("Unrecognized sample mode: {}".format(args.sample_mode))

    ### Generate samples
    samples = []
    n_gen = args.n_samples
    while n_gen > 0:
        current_samples = vae.sample(args.n_samples_per_batch,
                                     sample_mode=sample_mode,
                                     sample_dims=sample_dims,
                                     k=args.k,
                                     condition=condition)
        if len(current_samples) == 0:
            # Guard against an infinite loop when the model yields nothing.
            raise RuntimeError('Model returned an empty sample batch')
        samples.extend(current_samples)
        n_gen -= len(current_samples)
    # The final batch may overshoot the request; trim to exactly n_samples.
    samples = samples[:args.n_samples]
    samples = pd.DataFrame(samples, columns=['mol'])

    if args.save_path is None:
        os.makedirs('generated', exist_ok=True)
        save_path = 'generated/{}_{}.csv'.format(vae.name, args.sample_mode)
    else:
        save_path = args.save_path
    samples.to_csv(save_path, index=False)
def calc_attention(args):
    """Compute attention weight tensors for a trained model over a SMILES set.

    Loads a checkpointed model, runs up to ``args.n_samples`` molecules from
    ``args.smiles`` through the encoder (and, for transvae, the decoder) with
    attention capture enabled, accumulates the per-layer attention maps into
    large torch tensors and saves them to ``.npy`` files under
    ``args.save_path`` or ``attn_wts/<model name>``.

    NOTE(review): for ``args.model == 'rnn'`` the model is loaded but neither
    attention branch runs, so nothing is computed or saved — confirm whether
    'rnn' should be rejected earlier. An unrecognized model type leaves ``vae``
    unbound and crashes below with UnboundLocalError.
    """
    ### Load model
    ckpt_fn = args.model_ckpt
    if args.model == 'transvae':
        vae = TransVAE(load_fn=ckpt_fn)
    elif args.model == 'rnnattn':
        vae = RNNAttn(load_fn=ckpt_fn)
    elif args.model == 'rnn':
        vae = RNN(load_fn=ckpt_fn)

    # Select n_samples rows, either randomly (sample) or the first rows.
    if args.shuffle:
        data = pd.read_csv(args.smiles).sample(args.n_samples).to_numpy()
    else:
        data = pd.read_csv(args.smiles).to_numpy()
        data = data[:args.n_samples,:]

    ### Load data and prepare for iteration
    data = vae_data_gen(data, char_dict=vae.params['CHAR_DICT'])
    # drop_last=True: only full batches are processed, so save_shape below
    # (len(data_iter) * batch_size) exactly matches the rows written.
    data_iter = torch.utils.data.DataLoader(data,
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=0,
                                            pin_memory=False,
                                            drop_last=True)
    save_shape = len(data_iter)*args.batch_size
    # Each batch is processed in batch_chunks sub-slices to bound memory use.
    chunk_size = args.batch_size // args.batch_chunks

    ### Prepare save path
    if args.save_path is None:
        os.makedirs('attn_wts', exist_ok=True)
        save_path = 'attn_wts/{}'.format(vae.name)
    else:
        save_path = args.save_path

    ### Calculate attention weights
    vae.model.eval()
    if args.model == 'transvae':
        # Hard-coded dims (layers, heads, seq, seq): presumably 4 encoder
        # self-attn slots (incl. deconv wts appended below), 3 decoder src-attn
        # layers, 4 heads, max sequence length 127 — TODO confirm against the
        # TransVAE architecture config.
        self_attn = torch.empty((save_shape, 4, 4, 127, 127))
        src_attn = torch.empty((save_shape, 3, 4, 126, 127))
        for j, data in enumerate(data_iter):
            for i in range(args.batch_chunks):
                batch_data = data[i*chunk_size:(i+1)*chunk_size,:]
                if vae.use_gpu:
                    batch_data = batch_data.cuda()
                src = Variable(batch_data).long()
                src_mask = (src != vae.pad_idx).unsqueeze(-2)
                # Decoder target: sequence shifted by dropping the last token.
                tgt = Variable(batch_data[:,:-1]).long()
                tgt_mask = make_std_mask(tgt, vae.pad_idx)

                # Run samples through model to calculate weights
                mem, mu, logvar, pred_len, self_attn_wts = vae.model.encoder.forward_w_attn(vae.model.src_embed(src), src_mask)
                probs, deconv_wts, src_attn_wts = vae.model.decoder.forward_w_attn(vae.model.tgt_embed(tgt), mem, src_mask, tgt_mask)

                # Save weights to torch tensors
                # List concatenation: deconv weights are appended after the
                # encoder self-attention weights and stored in the same tensor.
                self_attn_wts += deconv_wts
                # Absolute row range of this chunk within the full save tensor.
                start = j*args.batch_size+i*chunk_size
                stop = j*args.batch_size+(i+1)*chunk_size
                for k in range(len(self_attn_wts)):
                    self_attn[start:stop,k,:,:,:] = self_attn_wts[k]
                for k in range(len(src_attn_wts)):
                    src_attn[start:stop,k,:,:,:] = src_attn_wts[k]
        np.save(save_path+'_self_attn.npy', self_attn.numpy())
        np.save(save_path+'_src_attn.npy', src_attn.numpy())
    elif args.model == 'rnnattn':
        # Single attention map per sample (one layer, one head).
        attn = torch.empty((save_shape, 1, 1, 127, 127))
        for j, data in enumerate(data_iter):
            for i in range(args.batch_chunks):
                batch_data = data[i*chunk_size:(i+1)*chunk_size,:]
                if vae.use_gpu:
                    batch_data = batch_data.cuda()
                src = Variable(batch_data).long()
                # Run samples through model to calculate weights
                mem, mu, logvar, attn_wts = vae.model.encoder(vae.model.src_embed(src), return_attn=True)
                start = j*args.batch_size+i*chunk_size
                stop = j*args.batch_size+(i+1)*chunk_size
                attn[start:stop,0,0,:,:] = attn_wts
        np.save(save_path+'.npy', attn.numpy())