Example #1
0
def test_ProfileHMM_smoke(jit):
    """Fit a ProfileHMM to a tiny two-letter dataset and sanity-check it.

    This is a smoke test: it only checks that the final training loss is
    not NaN and that the evaluation metrics have the expected signs.
    ``jit`` is forwarded unchanged to ``fit_svi`` and ``evaluate``.
    """
    # Three short sequences over the alphabet "AB".
    sequences = ["BABBA", "BAAB", "BABBB"]
    data = BiosequenceDataset(sequences, "list", "AB")

    # Adam whose learning rate is step-decayed (gamma=0.5) at the milestones.
    schedule_config = dict(
        optimizer=Adam,
        optim_args=dict(lr=0.1),
        milestones=[20, 100, 1000, 2000],
        gamma=0.5,
    )
    lr_schedule = MultiStepLR(schedule_config)

    # Latent length: dataset.max_length scaled by 1.1, truncated to an int.
    latent_length = int(data.max_length * 1.1)
    hmm = ProfileHMM(latent_length, data.alphabet_length)
    epochs, batch = 5, 2
    loss_trace = hmm.fit_svi(data, epochs, batch, lr_schedule, jit)

    assert not np.isnan(loss_trace[-1])

    # The training set doubles as the held-out set.
    metrics = hmm.evaluate(data, data, jit)
    train_lp, test_lp, train_perplex, test_perplex = metrics
    for log_prob in (train_lp, test_lp):
        assert log_prob < 0.0
    for perplexity in (train_perplex, test_perplex):
        assert perplexity > 0.0
Example #2
0
def test_write():
    """Round-trip a BiosequenceDataset through ``write`` and a FASTA reload.

    Writes "test_seqs.fasta" in the current working directory twice. The
    first pass truncates at the stop symbol and appends to a freshly emptied
    file. The second pass writes untruncated sequences with append=False.
    Each reload is checked against the original one-hot encoding.
    """
    # Define dataset. "*" is part of the "ACGT*" alphabet here, so it is
    # encoded like any other letter.
    seqs = ["AATC*C", "CA*", "T**"]
    dataset = BiosequenceDataset(seqs, "list", "ACGT*", include_stop=False)

    # With truncation at stop symbol.
    # Write: start from an empty file so that append=True is exercised.
    with open("test_seqs.fasta", "w") as fw:
        fw.write("")
    write(
        dataset.seq_data,
        dataset.alphabet,
        "test_seqs.fasta",
        truncate_stop=True,
        append=True,
    )

    # Reload.
    dataset2 = BiosequenceDataset("test_seqs.fasta",
                                  "fasta",
                                  "dna",
                                  include_stop=True)
    # Number of letters before the first "*" in each entry of seqs.
    to_stop_lens = [4, 2, 1]
    for j, to_stop_len in enumerate(to_stop_lens):
        # Letters before the stop symbol survive the round trip...
        assert torch.allclose(dataset.seq_data[j, :to_stop_len],
                              dataset2.seq_data[j, :to_stop_len])
        # ...and every position after index to_stop_len is zero padding.
        # Index to_stop_len is presumably the stop symbol added by
        # include_stop=True.
        assert torch.allclose(dataset2.seq_data[j, (to_stop_len + 1):],
                              torch.tensor(0.0))

    # Without truncation at stop symbol.
    # Write with append=False (presumably overwriting the file).
    write(
        dataset.seq_data,
        dataset.alphabet,
        "test_seqs.fasta",
        truncate_stop=False,
        append=False,
    )

    # Reload with the original alphabet. A single whole-tensor comparison
    # covers every sequence at once.
    dataset2 = BiosequenceDataset("test_seqs.fasta",
                                  "fasta",
                                  "ACGT*",
                                  include_stop=False)
    assert torch.allclose(dataset.seq_data, dataset2.seq_data)
Example #3
0
def test_FactorMuE_smoke(indel_factor_dependence, z_prior_distribution,
                         ARD_prior, substitution_matrix, jit):
    """Fit a FactorMuE on a tiny dataset and exercise its helpers.

    Covers SVI fitting, regressor-sequence reconstruction, two fixed points
    of ``_beta_anneal``, evaluation and the latent embedding. It checks
    shapes and signs rather than exact fitted values.
    """
    # Three short sequences over the alphabet "AB".
    sequences = ["BABBA", "BAAB", "BABBB"]
    letters = "AB"
    data = BiosequenceDataset(sequences, "list", letters)

    # Model and optimizer configuration.
    latent_dim = 2
    lr_schedule = MultiStepLR(
        dict(
            optimizer=Adam,
            optim_args=dict(lr=0.1),
            milestones=[20, 100, 1000, 2000],
            gamma=0.5,
        ))
    model = FactorMuE(
        data.max_length,
        data.alphabet_length,
        latent_dim,
        indel_factor_dependence=indel_factor_dependence,
        z_prior_distribution=z_prior_distribution,
        ARD_prior=ARD_prior,
        substitution_matrix=substitution_matrix,
    )
    epochs, anneal_epochs, batch = 5, 2, 2
    loss_trace = model.fit_svi(data, epochs, anneal_epochs, batch,
                               lr_schedule, jit)

    # Reconstruct one regressor sequence using pyro.param as the accessor.
    reconstruction = model._reconstruct_regressor_seq(data, 1, pyro.param)

    assert not np.isnan(loss_trace[-1])
    longest = max(map(len, sequences))
    assert reconstruction.shape == (1, longest, len(letters))

    # Expected annealing weights at two sample points.
    assert torch.allclose(model._beta_anneal(3, 2, 6, 2), torch.tensor(0.5))
    assert torch.allclose(model._beta_anneal(100, 2, 6, 2), torch.tensor(1.0))

    # Evaluate, reusing the training set as the held-out set.
    metrics = model.evaluate(data, data, jit)
    train_lp, test_lp, train_perplex, test_perplex = metrics
    for log_prob in (train_lp, test_lp):
        assert log_prob < 0.0
    for perplexity in (train_perplex, test_perplex):
        assert perplexity > 0.0

    # One latent location/scale row per sequence; scales strictly positive.
    z_locs, z_scales = model.embed(data)
    expected_shape = (len(data), latent_dim)
    assert z_locs.shape == expected_shape
    assert z_scales.shape == expected_shape
    assert torch.all(z_scales > 0.0)
Example #4
0
def generate_data(small_test, include_stop, device):
    """Return a toy "AB" BiosequenceDataset for quick runs.

    Each of three base sequences is repeated once when ``small_test`` is
    truthy and ten times otherwise. Repeats are grouped by sequence, in
    order. ``include_stop`` and ``device`` are passed through to
    ``BiosequenceDataset``.
    """
    copies = 1 if small_test else 10
    seqs = [
        seq for seq in ("BABBA", "BAAB", "BABBB") for _ in range(copies)
    ]
    return BiosequenceDataset(
        seqs,
        "list",
        "AB",
        include_stop=include_stop,
        device=device,
    )
Example #5
0
def test_biosequencedataset(source_type, alphabet, include_stop):
    """Compare BiosequenceDataset against a hand-built one-hot encoding.

    Args:
        source_type: "fasta" (sequences written to "test_seqs.fasta" in the
            working directory) or "list" (sequences passed directly).
        alphabet: a key of ``alphabets`` or a literal string of letters.
        include_stop: whether a "*" stop symbol is appended to each
            sequence and added to the alphabet.

    Raises:
        ValueError: if ``source_type`` is neither "fasta" nor "list".
    """
    # Define dataset.
    seqs = ["AATC", "CA", "T"]
    stop = include_stop * "*"

    # Encode dataset, alternate approach.
    letters = alphabets[alphabet] if alphabet in alphabets else alphabet
    alphabet_list = list(letters) + include_stop * ["*"]
    L_data_check = [len(seq) + include_stop for seq in seqs]
    max_length_check = max(L_data_check)
    data_size_check = len(seqs)
    seq_data_check = torch.zeros(
        [len(seqs), max_length_check,
         len(alphabet_list)])
    for i, seq in enumerate(seqs):
        for j, s in enumerate(seq + stop):
            seq_data_check[i, j, alphabet_list.index(s)] = 1

    # Setup data source.
    if source_type == "fasta":
        # Save as external file. "AATC" is split across two lines, so this
        # also covers multi-line FASTA records.
        source = "test_seqs.fasta"
        with open(source, "w") as fw:
            text = """>one
AAT
C
>two
CA
>three
T
"""
            fw.write(text)
    elif source_type == "list":
        source = seqs
    else:
        # Fail clearly instead of with an UnboundLocalError on `source`.
        raise ValueError("unknown source_type: {!r}".format(source_type))

    # Load dataset.
    dataset = BiosequenceDataset(source,
                                 source_type,
                                 alphabet,
                                 include_stop=include_stop)

    # Check.
    assert torch.allclose(dataset.L_data,
                          torch.tensor(L_data_check, dtype=torch.float64))
    assert dataset.max_length == max_length_check
    assert len(dataset) == data_size_check
    assert dataset.data_size == data_size_check
    assert dataset.alphabet_length == len(alphabet_list)
    assert torch.allclose(dataset.seq_data, seq_data_check)
    # Indexing with a tensor of indices returns (seq_data, L_data) slices.
    ind = torch.tensor([0, 2])
    assert torch.allclose(
        dataset[ind][0],
        torch.cat(
            [seq_data_check[0, None, :, :], seq_data_check[2, None, :, :]]),
    )
    assert torch.allclose(
        dataset[ind][1], torch.tensor([4.0 + include_stop,
                                       1.0 + include_stop]))
    # The dataset also works with a standard DataLoader.
    dataload = torch.utils.data.DataLoader(dataset, batch_size=2)
    for seq_data, L_data in dataload:
        assert seq_data.shape[0] == L_data.shape[0]
Example #6
0
def main(args):
    """Fit, evaluate, plot and save a FactorMuE model from CLI options.

    Steps:
    - Load the data: synthetic data from ``generate_data`` when
      ``args.test`` is set, otherwise the FASTA file ``args.file``.
    - Optionally split off a held-out set.
    - Fit a FactorMuE with SVI, using Adam under a MultiStepLR schedule.
    - Print train/test log probability and perplexity.
    - Embed the full dataset in latent space.
    - Unless disabled, plot and write results into ``args.out_folder``.

    Args:
        args: parsed command-line options. This is presumably an
            ``argparse.Namespace``: attributes are read directly and
            ``args.__dict__`` is dumped. TODO confirm against the parser.

    Side effects:
        - Mutates ``args.batch_size``, clamping it to the dataset size.
        - When saving, also mutates ``args.latent_seq_length`` and
          ``args.latent_alphabet``.
        - Reseeds Pyro's RNG.
        - Creates matplotlib figures.
        - Writes files into ``args.out_folder`` unless ``args.no_save``
          is set.
    """

    # Load dataset.
    # Data stays on the CPU when requested or when CUDA is disabled.
    if args.cpu_data or not args.cuda:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
    if args.test:
        dataset = generate_data(args.small, args.include_stop, device)
    else:
        dataset = BiosequenceDataset(
            args.file,
            "fasta",
            args.alphabet,
            include_stop=args.include_stop,
            device=device,
        )
    # Never use a batch larger than the whole dataset.
    args.batch_size = min([dataset.data_size, args.batch_size])
    if args.split > 0.0:
        # Train test split.
        # Hold out ceil(split * N) sequences; the rest are used for training.
        heldout_num = int(np.ceil(args.split * len(dataset)))
        data_lengths = [len(dataset) - heldout_num, heldout_num]
        # Specific data split seed, for comparability across models and
        # parameter initializations.
        pyro.set_rng_seed(args.rng_data_seed)
        indices = torch.randperm(sum(data_lengths), device=device).tolist()
        # Assuming _accumulate yields running totals of data_lengths, each
        # Subset takes a consecutive slice of the shuffled indices.
        # NOTE(review): torch._utils._accumulate is a private PyTorch helper
        # (presumably equivalent to itertools.accumulate); confirm it exists
        # in the supported torch versions.
        dataset_train, dataset_test = [
            torch.utils.data.Subset(dataset, indices[(offset - length):offset])
            for offset, length in zip(torch._utils._accumulate(data_lengths),
                                      data_lengths)
        ]
    else:
        # No held-out set: evaluate() below receives None as the test data.
        dataset_train = dataset
        dataset_test = None

    # Training seed.
    # Reseeded after the split, so training randomness is governed by
    # args.rng_seed independently of args.rng_data_seed.
    pyro.set_rng_seed(args.rng_seed)

    # Construct model.
    # The CLI exposes an opt-out flag for the substitution matrix.
    model = FactorMuE(
        dataset.max_length,
        dataset.alphabet_length,
        args.z_dim,
        batch_size=args.batch_size,
        latent_seq_length=args.latent_seq_length,
        indel_factor_dependence=args.indel_factor,
        indel_prior_scale=args.indel_prior_scale,
        indel_prior_bias=args.indel_prior_bias,
        inverse_temp_prior=args.inverse_temp_prior,
        weights_prior_scale=args.weights_prior_scale,
        offset_prior_scale=args.offset_prior_scale,
        z_prior_distribution=args.z_prior,
        ARD_prior=args.ARD_prior,
        substitution_matrix=(not args.no_substitution_matrix),
        substitution_prior_scale=args.substitution_prior_scale,
        latent_alphabet_length=args.latent_alphabet,
        cuda=args.cuda,
        pin_memory=args.pin_mem,
    )

    # Infer with SVI.
    # args.milestones is a JSON string that is parsed here.
    scheduler = MultiStepLR({
        "optimizer": Adam,
        "optim_args": {
            "lr": args.learning_rate
        },
        "milestones": json.loads(args.milestones),
        "gamma": args.learning_gamma,
    })
    n_epochs = args.n_epochs
    losses = model.fit_svi(
        dataset_train,
        n_epochs,
        args.anneal,
        args.batch_size,
        scheduler,
        args.jit,
    )

    # Evaluate.
    # dataset_test is None when no split was requested.
    train_lp, test_lp, train_perplex, test_perplex = model.evaluate(
        dataset_train, dataset_test, args.jit)
    print("train logp: {} perplex: {}".format(train_lp, train_perplex))
    print("test logp: {} perplex: {}".format(test_lp, test_perplex))

    # Get latent space embedding.
    # Embeds the full dataset (train and held-out), not just dataset_train.
    z_locs, z_scales = model.embed(dataset)

    # Plot and save.
    # One timestamp is shared by every output file of this run.
    time_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    if not args.no_plots:
        # Training loss curve.
        plt.figure(figsize=(6, 6))
        plt.plot(losses)
        plt.xlabel("step")
        plt.ylabel("loss")
        if not args.no_save:
            plt.savefig(
                os.path.join(args.out_folder,
                             "FactorMuE_plot.loss_{}.pdf".format(time_stamp)))

        # Scatter of the first two latent dimensions (assumes args.z_dim >= 2).
        # NOTE(review): unlike the np.savetxt calls below, z_locs is not
        # moved with .cpu() here; confirm embed() returns CPU tensors.
        plt.figure(figsize=(6, 6))
        plt.scatter(z_locs[:, 0], z_locs[:, 1])
        plt.xlabel(r"$z_1$")
        plt.ylabel(r"$z_2$")
        if not args.no_save:
            plt.savefig(
                os.path.join(
                    args.out_folder,
                    "FactorMuE_plot.latent_{}.pdf".format(time_stamp)))

        if not args.indel_factor:
            # Plot indel parameters. See statearrangers.py for details on the
            # r and u parameters.
            # exp(x - logsumexp(x)) normalizes over the last axis (a softmax);
            # column 1 of that axis is plotted. The legend assumes the middle
            # axis has size 3 (r_0..r_2 / u_0..u_2) -- TODO confirm.
            plt.figure(figsize=(6, 6))
            insert = pyro.param("insert_q_mn").detach()
            insert_expect = torch.exp(insert - insert.logsumexp(-1, True))
            plt.plot(insert_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of insert")
            plt.legend([r"$r_0$", r"$r_1$", r"$r_2$"])
            if not args.no_save:
                plt.savefig(
                    os.path.join(
                        args.out_folder,
                        "FactorMuE_plot.insert_prob_{}.pdf".format(time_stamp),
                    ))
            plt.figure(figsize=(6, 6))
            delete = pyro.param("delete_q_mn").detach()
            delete_expect = torch.exp(delete - delete.logsumexp(-1, True))
            plt.plot(delete_expect[:, :, 1].cpu().numpy())
            plt.xlabel("position")
            plt.ylabel("probability of delete")
            plt.legend([r"$u_0$", r"$u_1$", r"$u_2$"])
            if not args.no_save:
                plt.savefig(
                    os.path.join(
                        args.out_folder,
                        "FactorMuE_plot.delete_prob_{}.pdf".format(time_stamp),
                    ))

    if not args.no_save:
        # Pyro parameter store.
        pyro.get_param_store().save(
            os.path.join(args.out_folder,
                         "FactorMuE_results.params_{}.out".format(time_stamp)))
        # Evaluation metrics: a header line plus one comma-separated row.
        with open(
                os.path.join(
                    args.out_folder,
                    "FactorMuE_results.evaluation_{}.txt".format(time_stamp),
                ),
                "w",
        ) as ow:
            ow.write("train_lp,test_lp,train_perplex,test_perplex\n")
            ow.write("{},{},{},{}\n".format(train_lp, test_lp, train_perplex,
                                            test_perplex))
        # Latent embedding locations and scales, one row per sequence.
        np.savetxt(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.embed_loc_{}.txt".format(time_stamp)),
            z_locs.cpu().numpy(),
        )
        np.savetxt(
            os.path.join(
                args.out_folder,
                "FactorMuE_results.embed_scale_{}.txt".format(time_stamp),
            ),
            z_scales.cpu().numpy(),
        )
        # Run configuration as "name = value" lines under an [args] header.
        with open(
                os.path.join(
                    args.out_folder,
                    "FactorMuE_results.input_{}.txt".format(time_stamp),
                ),
                "w",
        ) as ow:
            ow.write("[args]\n")
            # Overwrite these args with the model's resolved values so the
            # dump reflects what was actually used.
            args.latent_seq_length = model.latent_seq_length
            args.latent_alphabet = model.latent_alphabet_length
            for elem in list(args.__dict__.keys()):
                ow.write("{} = {}\n".format(elem, args.__getattribute__(elem)))
            ow.write("alphabet_str = {}\n".format("".join(dataset.alphabet)))
            ow.write("max_length = {}\n".format(dataset.max_length))