# Example 1
def embed_seqs_fb(model,
                  seqs,
                  repr_layers,
                  alphabet,
                  batch_size=4096,
                  use_cache=False,
                  verbose=True):
    """Embed protein sequences with a fairseq/ESM-style masked LM.

    Args:
        model: ESM-style model; called as
            ``model(toks, repr_layers=..., return_contacts=False)`` and
            expected to return a dict with a ``"representations"`` entry.
        seqs: iterable of sequence strings to embed.
        repr_layers: layer indices to extract; exactly ONE layer is
            supported (enforced below).
        alphabet: ESM alphabet providing ``get_batch_converter()``.
        batch_size: token budget per batch passed to
            ``get_batch_indices``.
        use_cache: accepted for interface compatibility; unused here.
        verbose: accepted for interface compatibility; unused here.

    Returns:
        dict mapping each sequence string to a list of
        ``{'embedding': ndarray}`` entries (one per occurrence of that
        sequence in ``seqs``); the ndarray is the per-residue embedding
        with BOS/EOS positions stripped.
    """
    # Label each sequence by its index so we can recover it after the
    # DataLoader reorders sequences into length-sorted batches.
    labels = ['seq' + str(i) for i in range(len(seqs))]

    dataset = FastaBatchedDataset(labels, seqs)
    batches = dataset.get_batch_indices(batch_size, extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(),
        batch_sampler=batches)

    embedded_seqs = {}
    with torch.no_grad():
        # NOTE: loop variable renamed from `labels` to avoid shadowing
        # the full label list built above.
        for batch_idx, (batch_labels, strs, toks) in enumerate(data_loader):
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            representations = {
                layer: t.to(device="cpu")
                for layer, t in out["representations"].items()
            }
            # Single-layer invariant is per-batch: check it once here
            # (was previously asserted inside the per-sequence loop).
            assert len(representations) == 1
            # Explicit unpacking replaces the old one-iteration loop,
            # which could leave the tensor unbound under `python -O`.
            (layer_repr,) = representations.values()

            for i, label in enumerate(batch_labels):
                seq_idx = int(label[3:])  # strip the 'seq' prefix
                seq = seqs[seq_idx]
                # Positions 1..len(seq) skip the BOS token; trailing
                # padding/EOS beyond the true length is excluded too.
                representation = layer_repr[i, 1:len(strs[i]) + 1]
                if seq not in embedded_seqs:
                    embedded_seqs[seq] = []
                embedded_seqs[seq].append(
                    {'embedding': representation.numpy()})

    return embedded_seqs
# Example 2
def predict_sequence_prob_fb(seq,
                             alphabet,
                             model,
                             repr_layers,
                             batch_size=4096,
                             verbose=False):
    """Compute the model's per-position token logits for one sequence.

    Args:
        seq: the single sequence string to score.
        alphabet: ESM alphabet providing ``get_batch_converter()``.
        model: ESM-style model; called as
            ``model(toks, repr_layers=..., return_contacts=False)``.
        repr_layers: layer indices forwarded to the model call.
        batch_size: token budget per batch passed to
            ``get_batch_indices``.
        verbose: accepted for interface compatibility; unused here.

    Returns:
        numpy array of logits for the (single) sequence, including
        special-token positions.
    """
    # Wrap the lone sequence in a one-entry dataset/loader so the same
    # batching machinery as multi-sequence embedding can be reused.
    dataset = FastaBatchedDataset(['seq0'], [seq])
    sampler = dataset.get_batch_indices(batch_size, extra_toks_per_seq=1)
    loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(),
        batch_sampler=sampler)

    with torch.no_grad():
        for _labels, _strs, toks in loader:
            if torch.cuda.is_available():
                toks = toks.to(device="cuda", non_blocking=True)
            out = model(toks, repr_layers=repr_layers, return_contacts=False)
            logits = out["logits"].to(device="cpu")

    # The loader yields exactly one batch of one sequence; index 0
    # selects that sequence's logits.
    return logits.numpy()[0]