Example #1
def model_init(args, params={}):
    ### Model Name
    if args.save_name is None:
        if args.model == 'transvae':
            save_name = 'trans{}x-{}_{}'.format(
                args.d_feedforward // args.d_model, args.d_model,
                args.data_source)
        else:
            save_name = '{}-{}_{}'.format(args.model, args.d_model,
                                          args.data_source)
    else:
        save_name = args.save_name

    ### Load Model
    if args.model == 'transvae':
        vae = TransVAE(params=params,
                       name=save_name,
                       d_model=args.d_model,
                       d_ff=args.d_feedforward,
                       d_latent=args.d_latent)
    elif args.model == 'rnnattn':
        vae = RNNAttn(params=params,
                      name=save_name,
                      d_model=args.d_model,
                      d_latent=args.d_latent)
    elif args.model == 'rnn':
        vae = RNN(params=params,
                  name=save_name,
                  d_model=args.d_model,
                  d_latent=args.d_latent)
    else:
        raise ValueError('Unrecognized model type: {}'.format(args.model))

    return vae
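The sketch below shows one hypothetical way to call model_init. The attribute names mirror the args fields read above; the concrete values (and the 'zinc' data_source label) are illustrative assumptions, and the model classes themselves come from the surrounding project.

# Hypothetical invocation of model_init(); values are illustrative only.
from argparse import Namespace

args = Namespace(save_name=None,
                 model='transvae',
                 d_model=128,
                 d_feedforward=512,
                 d_latent=128,
                 data_source='zinc')

vae = model_init(args)  # auto-named 'trans4x-128_zinc' (512 // 128 = 4)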
Example #2
def sample(args):
    ### Load model
    ckpt_fn = args.model_ckpt
    if args.model == 'transvae':
        vae = TransVAE(load_fn=ckpt_fn)
    elif args.model == 'rnnattn':
        vae = RNNAttn(load_fn=ckpt_fn)
    elif args.model == 'rnn':
        vae = RNN(load_fn=ckpt_fn)
    else:
        raise ValueError('Unrecognized model type: {}'.format(args.model))

    ### Parse conditional string
    if args.condition == '':
        condition = []
    else:
        condition = args.condition.split(',')

    ### Calculate entropy depending on sampling mode
    if args.sample_mode == 'rand':
        sample_mode = 'rand'
        sample_dims = None
    else:
        entropy_data = pd.read_csv(args.mols).to_numpy()
        _, mus, _ = vae.calc_mems(entropy_data, log=False, save=False)
        vae_entropy = calc_entropy(mus)
        entropy_idxs = np.where(np.array(vae_entropy) > args.entropy_cutoff)[0]
        sample_dims = entropy_idxs
        if args.sample_mode == 'high_entropy':
            sample_mode = 'top_dims'
        elif args.sample_mode == 'k_high_entropy':
            sample_mode = 'k_dims'

    ### Generate samples (batched, so the final count may slightly exceed n_samples)
    samples = []
    n_gen = args.n_samples
    while n_gen > 0:
        current_samples = vae.sample(args.n_samples_per_batch,
                                     sample_mode=sample_mode,
                                     sample_dims=sample_dims,
                                     k=args.k,
                                     condition=condition)
        samples.extend(current_samples)
        n_gen -= len(current_samples)

    samples = pd.DataFrame(samples, columns=['mol'])
    if args.save_path is None:
        os.makedirs('generated', exist_ok=True)
        save_path = 'generated/{}_{}.csv'.format(vae.name, args.sample_mode)
    else:
        save_path = args.save_path
    samples.to_csv(save_path, index=False)
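A hypothetical driver for sample() in 'rand' mode is sketched below. The attribute names mirror the args fields read above; the checkpoint path and numeric values are placeholders, and the entropy-based modes would additionally need mols and entropy_cutoff attributes.

# Hypothetical driver for sample(); paths and values are placeholders.
from argparse import Namespace

sample_args = Namespace(model='transvae',
                        model_ckpt='checkpoints/trans4x-128_zinc.ckpt',
                        condition='',
                        sample_mode='rand',  # entropy modes also read args.mols and args.entropy_cutoff
                        k=5,
                        n_samples=1000,
                        n_samples_per_batch=100,
                        save_path=None)

sample(sample_args)  # writes generated/<checkpoint model name>_rand.csv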
Example #3
def calc_attention(args):
    ### Load model
    ckpt_fn = args.model_ckpt
    if args.model == 'transvae':
        vae = TransVAE(load_fn=ckpt_fn)
    elif args.model == 'rnnattn':
        vae = RNNAttn(load_fn=ckpt_fn)
    elif args.model == 'rnn':
        vae = RNN(load_fn=ckpt_fn)
    else:
        raise ValueError('Unrecognized model type: {}'.format(args.model))

    if args.shuffle:
        data = pd.read_csv(args.smiles).sample(args.n_samples).to_numpy()
    else:
        data = pd.read_csv(args.smiles).to_numpy()
        data = data[:args.n_samples,:]

    ### Encode data and prepare for iteration
    data = vae_data_gen(data, char_dict=vae.params['CHAR_DICT'])
    data_iter = torch.utils.data.DataLoader(data,
                                            batch_size=args.batch_size,
                                            shuffle=False, num_workers=0,
                                            pin_memory=False, drop_last=True)
    save_shape = len(data_iter)*args.batch_size
    chunk_size = args.batch_size // args.batch_chunks

    ### Prepare save path
    if args.save_path is None:
        os.makedirs('attn_wts', exist_ok=True)
        save_path = 'attn_wts/{}'.format(vae.name)
    else:
        save_path = args.save_path

    ### Calculate attention weights
    vae.model.eval()
    if args.model == 'transvae':
        self_attn = torch.empty((save_shape, 4, 4, 127, 127))
        src_attn = torch.empty((save_shape, 3, 4, 126, 127))
        for j, data in enumerate(data_iter):
            for i in range(args.batch_chunks):
                batch_data = data[i*chunk_size:(i+1)*chunk_size,:]
                if vae.use_gpu:
                    batch_data = batch_data.cuda()

                src = Variable(batch_data).long()
                src_mask = (src != vae.pad_idx).unsqueeze(-2)
                tgt = Variable(batch_data[:,:-1]).long()
                tgt_mask = make_std_mask(tgt, vae.pad_idx)

                # Run samples through model to calculate weights
                mem, mu, logvar, pred_len, self_attn_wts = vae.model.encoder.forward_w_attn(
                    vae.model.src_embed(src), src_mask)
                probs, deconv_wts, src_attn_wts = vae.model.decoder.forward_w_attn(
                    vae.model.tgt_embed(tgt), mem, src_mask, tgt_mask)

                # Collect the deconvolution weights alongside the encoder
                # self-attention weights, then write them into the
                # preallocated tensors
                self_attn_wts += deconv_wts
                start = j*args.batch_size+i*chunk_size
                stop = j*args.batch_size+(i+1)*chunk_size
                for k in range(len(self_attn_wts)):
                    self_attn[start:stop,k,:,:,:] = self_attn_wts[k]
                for k in range(len(src_attn_wts)):
                    src_attn[start:stop,k,:,:,:] = src_attn_wts[k]

        np.save(save_path+'_self_attn.npy', self_attn.numpy())
        np.save(save_path+'_src_attn.npy', src_attn.numpy())

    elif args.model == 'rnnattn':
        attn = torch.empty((save_shape, 1, 1, 127, 127))
        for j, data in enumerate(data_iter):
            for i in range(args.batch_chunks):
                batch_data = data[i*chunk_size:(i+1)*chunk_size,:]
                if vae.use_gpu:
                    batch_data = batch_data.cuda()

                src = Variable(batch_data).long()

                # Run samples through model to calculate weights
                mem, mu, logvar, attn_wts = vae.model.encoder(vae.model.src_embed(src), return_attn=True)
                start = j*args.batch_size+i*chunk_size
                stop = j*args.batch_size+(i+1)*chunk_size
                attn[start:stop,0,0,:,:] = attn_wts

        np.save(save_path+'.npy', attn.numpy())
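Once calc_attention has run for a Transformer model, the saved arrays can be inspected with NumPy. The sketch below assumes the default 'attn_wts/' save convention above and a model named 'trans4x-128_zinc'; both names are placeholders.

# Hypothetical inspection of the saved attention weights; the file names
# depend on the model's name and the default 'attn_wts/' convention above.
import numpy as np

self_attn = np.load('attn_wts/trans4x-128_zinc_self_attn.npy')
src_attn = np.load('attn_wts/trans4x-128_zinc_src_attn.npy')

print(self_attn.shape)  # (n_samples, 4, 4, 127, 127)
print(src_attn.shape)   # (n_samples, 3, 4, 126, 127)

# Average the first set of self-attention weights over samples and heads
mean_self_attn = self_attn[:, 0].mean(axis=(0, 1))  # -> (127, 127)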