def test_ProfileHMM_smoke(jit):
    # Setup dataset.
    seqs = ["BABBA", "BAAB", "BABBB"]
    alph = "AB"
    dataset = BiosequenceDataset(seqs, "list", alph)

    # Infer.
    scheduler = MultiStepLR(
        {
            "optimizer": Adam,
            "optim_args": {"lr": 0.1},
            "milestones": [20, 100, 1000, 2000],
            "gamma": 0.5,
        }
    )
    # Set the latent sequence length ~10% longer than the longest observed
    # sequence, leaving room for insertions.
    model = ProfileHMM(int(dataset.max_length * 1.1), dataset.alphabet_length)
    n_epochs = 5
    batch_size = 2
    losses = model.fit_svi(dataset, n_epochs, batch_size, scheduler, jit)

    assert not np.isnan(losses[-1])

    # Evaluate.
    train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
        dataset, dataset, jit
    )
    assert train_lp < 0.0
    assert test_lp < 0.0
    assert train_perplex > 0.0
    assert test_perplex > 0.0
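# Note: the "jit" argument is expected to be supplied by pytest
# parametrization; a minimal sketch of the decorator this test assumes
# (the value list is an assumption, not taken from this file):
#
# @pytest.mark.parametrize("jit", [False, True])
# def test_ProfileHMM_smoke(jit):
#     ...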
def test_write():
    # Define dataset.
    seqs = ["AATC*C", "CA*", "T**"]
    dataset = BiosequenceDataset(seqs, "list", "ACGT*", include_stop=False)

    # With truncation at stop symbol.
    # Write. (Empty the file first, since write() is called in append mode.)
    with open("test_seqs.fasta", "w") as fw:
        fw.write("")
    write(
        dataset.seq_data,
        dataset.alphabet,
        "test_seqs.fasta",
        truncate_stop=True,
        append=True,
    )
    # Reload.
    dataset2 = BiosequenceDataset(
        "test_seqs.fasta", "fasta", "dna", include_stop=True
    )
    to_stop_lens = [4, 2, 1]
    for j, to_stop_len in enumerate(to_stop_lens):
        # Residues before the stop symbol must round-trip unchanged...
        assert torch.allclose(
            dataset.seq_data[j, :to_stop_len], dataset2.seq_data[j, :to_stop_len]
        )
        # ...and everything after the re-appended stop symbol is zero padding.
        assert torch.allclose(
            dataset2.seq_data[j, (to_stop_len + 1):], torch.tensor(0.0)
        )

    # Without truncation at stop symbol.
    # Write.
    write(
        dataset.seq_data,
        dataset.alphabet,
        "test_seqs.fasta",
        truncate_stop=False,
        append=False,
    )
    # Reload.
    dataset2 = BiosequenceDataset(
        "test_seqs.fasta", "fasta", "ACGT*", include_stop=False
    )
    # The full sequences, stop symbols included, must round-trip unchanged.
    assert torch.allclose(dataset.seq_data, dataset2.seq_data)
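# A standalone sketch of what truncation at the stop symbol means for the
# sequences above (plain Python, independent of the write() helper):
def _truncate_at_stop(seq, stop="*"):
    """Drop the first stop symbol and everything after it."""
    return seq.split(stop, 1)[0]

# ["AATC*C", "CA*", "T**"] -> ["AATC", "CA", "T"], whose lengths give the
# to_stop_lens = [4, 2, 1] checked in test_write above.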
def test_FactorMuE_smoke(
    indel_factor_dependence, z_prior_distribution, ARD_prior, substitution_matrix, jit
):
    # Setup dataset.
    seqs = ["BABBA", "BAAB", "BABBB"]
    alph = "AB"
    dataset = BiosequenceDataset(seqs, "list", alph)

    # Infer.
    z_dim = 2
    scheduler = MultiStepLR(
        {
            "optimizer": Adam,
            "optim_args": {"lr": 0.1},
            "milestones": [20, 100, 1000, 2000],
            "gamma": 0.5,
        }
    )
    model = FactorMuE(
        dataset.max_length,
        dataset.alphabet_length,
        z_dim,
        indel_factor_dependence=indel_factor_dependence,
        z_prior_distribution=z_prior_distribution,
        ARD_prior=ARD_prior,
        substitution_matrix=substitution_matrix,
    )
    n_epochs = 5
    anneal_length = 2
    batch_size = 2
    losses = model.fit_svi(
        dataset, n_epochs, anneal_length, batch_size, scheduler, jit
    )

    # Reconstruct.
    recon = model._reconstruct_regressor_seq(dataset, 1, pyro.param)

    assert not np.isnan(losses[-1])
    assert recon.shape == (1, max(len(seq) for seq in seqs), len(alph))

    # Check the KL annealing schedule: halfway through the ramp beta is 0.5,
    # and well past it beta saturates at 1.
    assert torch.allclose(model._beta_anneal(3, 2, 6, 2), torch.tensor(0.5))
    assert torch.allclose(model._beta_anneal(100, 2, 6, 2), torch.tensor(1.0))

    # Evaluate.
    train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
        dataset, dataset, jit
    )
    assert train_lp < 0.0
    assert test_lp < 0.0
    assert train_perplex > 0.0
    assert test_perplex > 0.0

    # Embedding.
    z_locs, z_scales = model.embed(dataset)
    assert z_locs.shape == (len(dataset), z_dim)
    assert z_scales.shape == (len(dataset), z_dim)
    assert torch.all(z_scales > 0.0)
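# A minimal sketch of the linear KL-annealing schedule implied by the two
# _beta_anneal assertions above (the argument names here are assumptions):
# beta ramps linearly from 0 to 1 over anneal_epochs epochs, where an epoch
# is data_size / batch_size steps, then saturates at 1.
def _linear_beta_anneal(step, batch_size, data_size, anneal_epochs):
    steps_per_epoch = data_size / batch_size
    return min(step / (anneal_epochs * steps_per_epoch), 1.0)

# With batch_size=2, data_size=6, anneal_epochs=2 the ramp lasts 6 steps,
# so step 3 gives 0.5 and step 100 gives 1.0, matching the test.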
def generate_data(small_test, include_stop, device):
    """Generate mini example dataset."""
    if small_test:
        mult_dat = 1
    else:
        mult_dat = 10

    seqs = ["BABBA"] * mult_dat + ["BAAB"] * mult_dat + ["BABBB"] * mult_dat
    dataset = BiosequenceDataset(
        seqs, "list", "AB", include_stop=include_stop, device=device
    )
    return dataset
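# Shape sketch for generate_data, following the one-hot layout
# (num_seqs, max_length, alphabet_length) that test_biosequencedataset
# below verifies for BiosequenceDataset:
def _example_generate_data_shapes():
    ds = generate_data(small_test=True, include_stop=False, device=torch.device("cpu"))
    # 3 distinct sequences over alphabet "AB"; the longest, "BABBA", has length 5.
    assert ds.seq_data.shape == (3, 5, 2)
    ds_stop = generate_data(small_test=True, include_stop=True, device=torch.device("cpu"))
    # A stop symbol "*" is appended to each sequence and to the alphabet.
    assert ds_stop.seq_data.shape == (3, 6, 3)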
def test_biosequencedataset(source_type, alphabet, include_stop):
    # Define dataset.
    seqs = ["AATC", "CA", "T"]

    # Encode dataset, alternate approach.
    if alphabet in alphabets:
        alphabet_list = list(alphabets[alphabet]) + include_stop * ["*"]
    else:
        alphabet_list = list(alphabet) + include_stop * ["*"]
    L_data_check = [len(seq) + include_stop for seq in seqs]
    max_length_check = max(L_data_check)
    data_size_check = len(seqs)
    seq_data_check = torch.zeros([len(seqs), max_length_check, len(alphabet_list)])
    for i in range(len(seqs)):
        for j, s in enumerate(seqs[i] + include_stop * "*"):
            seq_data_check[i, j, alphabet_list.index(s)] = 1

    # Setup data source.
    if source_type == "fasta":
        # Save as external file.
        source = "test_seqs.fasta"
        with open(source, "w") as fw:
            text = """>one
AAT
C
>two
CA
>three
T
"""
            fw.write(text)
    elif source_type == "list":
        source = seqs

    # Load dataset.
    dataset = BiosequenceDataset(
        source, source_type, alphabet, include_stop=include_stop
    )

    # Check.
    assert torch.allclose(
        dataset.L_data, torch.tensor(L_data_check, dtype=torch.float64)
    )
    assert dataset.max_length == max_length_check
    assert len(dataset) == data_size_check
    assert dataset.data_size == data_size_check
    assert dataset.alphabet_length == len(alphabet_list)
    assert torch.allclose(dataset.seq_data, seq_data_check)
    ind = torch.tensor([0, 2])
    assert torch.allclose(
        dataset[ind][0],
        torch.cat([seq_data_check[0, None, :, :], seq_data_check[2, None, :, :]]),
    )
    assert torch.allclose(
        dataset[ind][1], torch.tensor([4.0 + include_stop, 1.0 + include_stop])
    )
    dataload = torch.utils.data.DataLoader(dataset, batch_size=2)
    for seq_data, L_data in dataload:
        assert seq_data.shape[0] == L_data.shape[0]
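# A standalone worked example of the encoding convention checked above,
# using only the one-hot layout the test verifies:
def _example_one_hot_encoding():
    ds = BiosequenceDataset(["CA"], "list", "ACGT", include_stop=False)
    # "C" is index 1 and "A" is index 0 in the alphabet "ACGT".
    assert torch.allclose(ds.seq_data[0, 0], torch.tensor([0.0, 1.0, 0.0, 0.0]))
    assert torch.allclose(ds.seq_data[0, 1], torch.tensor([1.0, 0.0, 0.0, 0.0]))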
def main(args):
    # Load dataset.
    if args.cpu_data or not args.cuda:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
    if args.test:
        dataset = generate_data(args.small, args.include_stop, device)
    else:
        dataset = BiosequenceDataset(
            args.file,
            "fasta",
            args.alphabet,
            include_stop=args.include_stop,
            device=device,
        )
    args.batch_size = min([dataset.data_size, args.batch_size])
    if args.split > 0.0:
        # Train test split.
        heldout_num = int(np.ceil(args.split * len(dataset)))
        data_lengths = [len(dataset) - heldout_num, heldout_num]
        # Specific data split seed, for comparability across models and
        # parameter initializations.
        pyro.set_rng_seed(args.rng_data_seed)
        indices = torch.randperm(sum(data_lengths), device=device).tolist()
        # Note: torch._utils._accumulate is a private helper; a public-API
        # equivalent is sketched at the end of this file.
        dataset_train, dataset_test = [
            torch.utils.data.Subset(dataset, indices[(offset - length):offset])
            for offset, length in zip(
                torch._utils._accumulate(data_lengths), data_lengths
            )
        ]
    else:
        dataset_train = dataset
        dataset_test = None

    # Training seed.
    pyro.set_rng_seed(args.rng_seed)

    # Construct model.
    model = FactorMuE(
        dataset.max_length,
        dataset.alphabet_length,
        args.z_dim,
        batch_size=args.batch_size,
        latent_seq_length=args.latent_seq_length,
        indel_factor_dependence=args.indel_factor,
        indel_prior_scale=args.indel_prior_scale,
        indel_prior_bias=args.indel_prior_bias,
        inverse_temp_prior=args.inverse_temp_prior,
        weights_prior_scale=args.weights_prior_scale,
        offset_prior_scale=args.offset_prior_scale,
        z_prior_distribution=args.z_prior,
        ARD_prior=args.ARD_prior,
        substitution_matrix=(not args.no_substitution_matrix),
        substitution_prior_scale=args.substitution_prior_scale,
        latent_alphabet_length=args.latent_alphabet,
        cuda=args.cuda,
        pin_memory=args.pin_mem,
    )

    # Infer with SVI.
    scheduler = MultiStepLR(
        {
            "optimizer": Adam,
            "optim_args": {"lr": args.learning_rate},
            "milestones": json.loads(args.milestones),
            "gamma": args.learning_gamma,
        }
    )
    n_epochs = args.n_epochs
    losses = model.fit_svi(
        dataset_train,
        n_epochs,
        args.anneal,
        args.batch_size,
        scheduler,
        args.jit,
    )

    # Evaluate.
    train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
        dataset_train, dataset_test, args.jit
    )
    print("train logp: {} perplex: {}".format(train_lp, train_perplex))
    print("test logp: {} perplex: {}".format(test_lp, test_perplex))

    # Get latent space embedding.
    z_locs, z_scales = model.embed(dataset)

    # Plot and save.
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    if not args.no_plots:
        plt.figure(figsize=(6, 6))
        plt.plot(losses)
        plt.xlabel("step")
        plt.ylabel("loss")
        if not args.no_save:
            plt.savefig(
                os.path.join(
                    args.out_folder, "FactorMuE_plot.loss_{}.pdf".format(time_stamp)
                )
            )

        plt.figure(figsize=(6, 6))
        plt.scatter(z_locs[:, 0], z_locs[:, 1])
        plt.xlabel(r"$z_1$")
        plt.ylabel(r"$z_2$")
        if not args.no_save:
            plt.savefig(
                os.path.join(
                    args.out_folder, "FactorMuE_plot.latent_{}.pdf".format(time_stamp)
                )
            )

        if not args.indel_factor:
            # Plot indel parameters. See statearrangers.py for details on the
            # r and u parameters.
            plt.figure(figsize=(6, 6))
            insert = pyro.param("insert_q_mn").detach()
            insert_expect = torch.exp(insert - insert.logsumexp(-1, keepdim=True))
            plt.plot(insert_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of insert")
            plt.legend([r"$r_0$", r"$r_1$", r"$r_2$"])
            if not args.no_save:
                plt.savefig(
                    os.path.join(
                        args.out_folder,
                        "FactorMuE_plot.insert_prob_{}.pdf".format(time_stamp),
                    )
                )
            plt.figure(figsize=(6, 6))
            delete = pyro.param("delete_q_mn").detach()
            delete_expect = torch.exp(delete - delete.logsumexp(-1, keepdim=True))
            plt.plot(delete_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of delete")
            plt.legend([r"$u_0$", r"$u_1$", r"$u_2$"])
            if not args.no_save:
                plt.savefig(
                    os.path.join(
                        args.out_folder,
                        "FactorMuE_plot.delete_prob_{}.pdf".format(time_stamp),
                    )
                )
    if not args.no_save:
        pyro.get_param_store().save(
            os.path.join(
                args.out_folder, "FactorMuE_results.params_{}.out".format(time_stamp)
            )
        )
        with open(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.evaluation_{}.txt".format(time_stamp),
            ),
            "w",
        ) as ow:
            ow.write("train_lp,test_lp,train_perplex,test_perplex\n")
            ow.write(
                "{},{},{},{}\n".format(train_lp, test_lp, train_perplex, test_perplex)
            )
        np.savetxt(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.embed_loc_{}.txt".format(time_stamp),
            ),
            z_locs.cpu().numpy(),
        )
        np.savetxt(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.embed_scale_{}.txt".format(time_stamp),
            ),
            z_scales.cpu().numpy(),
        )
        with open(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.input_{}.txt".format(time_stamp),
            ),
            "w",
        ) as ow:
            ow.write("[args]\n")
            args.latent_seq_length = model.latent_seq_length
            args.latent_alphabet = model.latent_alphabet_length
            for elem in list(args.__dict__.keys()):
                ow.write("{} = {}\n".format(elem, args.__getattribute__(elem)))
            ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
            ow.write("max_length = {}\n".format(dataset.max_length))
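# A public-API equivalent of the index split used in main(), replacing the
# private torch._utils._accumulate with itertools.accumulate (sketch only):
import itertools

def _split_indices(indices, lengths):
    """Partition a permutation of indices into consecutive chunks of the given lengths."""
    return [
        indices[offset - length:offset]
        for offset, length in zip(itertools.accumulate(lengths), lengths)
    ]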