Example #1
def main(args):
    # -------------------------------------------------------------------
    # Load Model
    # -------------------------------------------------------------------
    model_name = args.model  # e.g. ModelNames.T6_43M_UR50S or ModelNames.DEFAULT
    logger.info(f'Attempting to load ESM Model {model_name}...')
    model, alphabet = load_model_and_alphabet(model_name.value)
    logger.debug(model)
    logger.info(f'Loaded ESM Model {model_name}')

    # -------------------------------------------------------------------
    # Parse arguments
    # -------------------------------------------------------------------
    # Merge the CLI arguments into the model's stored args, then mirror the
    # union back so both namespaces see the same final configuration.
    vars(model.args).update(vars(args))
    vars(args).update(vars(model.args))
    logger.info(f"Training Arguments: {args}")
    device = "cpu" if args.nogpu or not torch.cuda.is_available() else "cuda"

    # ------------
    # system setup
    # ------------
    output_path = (pathlib.Path(args.output_dir) / model_name.value /
                   datetime.today().strftime('%Y-%m-%d-%H-%M-%S'))
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info(f'Output logging to: {output_path}')

    # -------------------------------------------------------------------
    # Load Data
    # -------------------------------------------------------------------
    tokenizer = CharacterTokenizer()
    batch_converter = MaskingBatchConverter(alphabet, tokenizer)
    datacontainer = load_train_splits_from_fasta(args.datapath,
                                                 batch_converter, args)

    # -------------------------------------------------------------------
    # Train
    # -------------------------------------------------------------------
    trainer = ESMTrainer(model_name, model, datacontainer, args, output_path,
                         alphabet, device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=args.lr[0],
        betas=args.adam_betas,
        eps=args.adam_eps)  # weight_decay=args.weight_decay left disabled
    # ignore_index=-1 matters: labels for unmasked tokens are set to -1 so
    # that they are excluded from the loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)

    res = trainer.train(optimizer=optimizer, criterion=criterion)

    logger.info(f'Finished all {args.epochs} epochs.')
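A minimal, hypothetical entry point for the training script above. The flag names and types are inferred from the attribute accesses in main() (args.model, args.lr[0], args.adam_betas, args.adam_eps, args.epochs, args.nogpu); the real parser may differ.

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Train an ESM model.')
    # Assumes ModelNames values are the pretrained-model name strings.
    parser.add_argument('--model', type=ModelNames, default=ModelNames.DEFAULT)
    parser.add_argument('--datapath', required=True,
                        help='Path to the training FASTA file')
    parser.add_argument('--output-dir', default='output')
    parser.add_argument('--lr', type=float, nargs='+', default=[1e-4],
                        help='Learning rate(s); main() uses args.lr[0]')
    parser.add_argument('--adam-betas', type=float, nargs=2, default=[0.9, 0.999])
    parser.add_argument('--adam-eps', type=float, default=1e-8)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--nogpu', action='store_true',
                        help='Force CPU even if CUDA is available')
    return parser.parse_args()

if __name__ == '__main__':
    main(parse_args())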
Example #2
    def __init__(self, name, repr_layer=(-1,)):
        self.name_ = name
        self.repr_layer_ = repr_layer

        model, alphabet = pretrained.load_model_and_alphabet(name)
        model.eval()
        if torch.cuda.is_available():
            model = model.cuda()
        self.model_ = model
        self.alphabet_ = alphabet

        # Validate the requested layers, then map negative indices (e.g. -1
        # for the final layer) onto absolute layer numbers in [0, num_layers].
        assert all(-(model.num_layers + 1) <= i <= model.num_layers
                   for i in repr_layer)
        self.repr_layers_ = [
            (i + model.num_layers + 1) % (model.num_layers + 1)
            for i in repr_layer
        ]
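The (i + model.num_layers + 1) % (model.num_layers + 1) idiom above recurs in Examples #3 through #5: it validates Python-style layer indices and maps negative ones onto absolute layer numbers in [0, num_layers], where layer 0 is the embedding output. A standalone sketch (the helper name normalize_repr_layers is hypothetical):

def normalize_repr_layers(repr_layers, num_layers):
    """Map indices in [-(num_layers + 1), num_layers] to [0, num_layers]."""
    assert all(-(num_layers + 1) <= i <= num_layers for i in repr_layers)
    return [(i + num_layers + 1) % (num_layers + 1) for i in repr_layers]

# For a 34-layer model, -1 denotes the final layer:
assert normalize_repr_layers([-1], num_layers=34) == [34]
assert normalize_repr_layers([0, 17, -1], num_layers=34) == [0, 17, 34]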
Example #3
def main(args):
    model, alphabet = pretrained.load_model_and_alphabet(args.model_location)
    model.eval()
    if isinstance(model, MSATransformer):
        raise ValueError(
            "This script currently does not handle models with MSA input (MSA Transformer)."
        )
    if torch.cuda.is_available() and not args.nogpu:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(args.fasta_file)
    batches = dataset.get_batch_indices(args.toks_per_batch,
                                        extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(),
        batch_sampler=batches)
    print(f"Read {args.fasta_file} with {len(dataset)} sequences")

    args.output_dir.mkdir(parents=True, exist_ok=True)
    return_contacts = "contacts" in args.include

    assert all(-(model.num_layers + 1) <= i <= model.num_layers
               for i in args.repr_layers)
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1)
                   for i in args.repr_layers]

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available() and not args.nogpu:
                toks = toks.to(device="cuda", non_blocking=True)

            # The model was trained on truncated sequences, and passing longer
            # ones at inference time will cause an error.
            # See https://github.com/facebookresearch/esm/issues/21
            if args.truncate:
                toks = toks[:, :1022]

            out = model(toks,
                        repr_layers=repr_layers,
                        return_contacts=return_contacts)

            logits = out["logits"].to(device="cpu")
            representations = {
                layer: t.to(device="cpu")
                for layer, t in out["representations"].items()
            }
            if return_contacts:
                contacts = out["contacts"].to(device="cpu")

            for i, label in enumerate(labels):
                args.output_file = args.output_dir / f"{label}.pt"
                args.output_file.parent.mkdir(parents=True, exist_ok=True)
                result = {"label": label}
                # Call clone on tensors to ensure tensors are not views into a larger representation
                # See https://github.com/pytorch/pytorch/issues/1995
                if "per_tok" in args.include:
                    result["representations"] = {
                        layer: t[i, 1:len(strs[i]) + 1].clone()
                        for layer, t in representations.items()
                    }
                if "mean" in args.include:
                    result["mean_representations"] = {
                        layer: t[i, 1:len(strs[i]) + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }
                if "bos" in args.include:
                    result["bos_representations"] = {
                        layer: t[i, 0].clone()
                        for layer, t in representations.items()
                    }
                if return_contacts:
                    result["contacts"] = contacts[
                        i, :len(strs[i]), :len(strs[i])].clone()

                torch.save(
                    result,
                    args.output_file,
                )
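The loop above writes one .pt file per FASTA record. A short sketch of reading one back; the path and the layer index 33 are illustrative assumptions, and which keys exist depends on the include options the extraction was run with:

import torch

result = torch.load("output/MySequence.pt")
label = result["label"]
per_tok = result["representations"][33]         # (seq_len, embed_dim), from "per_tok"
mean_repr = result["mean_representations"][33]  # (embed_dim,), from "mean"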
Example #4
def main(args):
    model, alphabet = pretrained.load_model_and_alphabet(args.model_location)
    model.eval()
    if torch.cuda.is_available() and not args.nogpu:
        model = model.cuda()
        print("Transferred model to GPU")

    dataset = FastaBatchedDataset.from_file(args.fasta_file)
    batches = dataset.get_batch_indices(args.toks_per_batch,
                                        extra_toks_per_seq=1)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(),
        batch_sampler=batches)
    print(f"Read {args.fasta_file} with {len(dataset)} sequences")

    args.output_dir.mkdir(parents=True, exist_ok=True)

    assert all(-(model.num_layers + 1) <= i <= model.num_layers
               for i in args.repr_layers)
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1)
                   for i in args.repr_layers]

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            print(
                f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
            )
            if torch.cuda.is_available() and not args.nogpu:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers)
            logits = out["logits"].to(device="cpu")
            representations = {
                layer: t.to(device="cpu")
                for layer, t in out["representations"].items()
            }

            for i, label in enumerate(labels):
                args.output_file = args.output_dir / f"{label}.pt"
                args.output_file.parent.mkdir(parents=True, exist_ok=True)
                result = {"label": label}
                if args.include_per_tok:
                    result["representations"] = {
                        layer: t[i, 1:len(strs[i]) + 1]
                        for layer, t in representations.items()
                    }
                if args.include_mean:
                    result["mean_representations"] = {
                        layer: t[i, 1:len(strs[i]) + 1].mean(0)
                        for layer, t in representations.items()
                    }
                if args.include_bos:
                    result["bos_representations"] = {
                        layer: t[i, 0]
                        for layer, t in representations.items()
                    }
                torch.save(
                    result,
                    args.output_file,
                )
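A note on the slicing above: t[i, 1:len(strs[i]) + 1] drops the prepended BOS token (residue j of sequence i sits at token position j + 1) and any trailing EOS or padding. A toy illustration with a dummy tensor:

import torch

seq = "MKT"  # 3 residues
embed_dim = 4
# [batch, BOS + residues + EOS/padding, embed_dim]
t = torch.randn(1, len(seq) + 2, embed_dim)
per_residue = t[0, 1:len(seq) + 1]  # drop BOS, keep exactly len(seq) rows
assert per_residue.shape == (len(seq), embed_dim)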
Example #5
            changes.append(change)

    if plot_acquisition:
        from cached_semantics import cached_escape
        cached_escape(cache_fname,
                      beta,
                      plot=plot_acquisition,
                      namespace=plot_namespace)

    return seqs, np.array(probs), np.array(changes)


if __name__ == '__main__':
    name = 'esm1_t34_670M_UR50S'

    model, alphabet = pretrained.load_model_and_alphabet(name)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    assert all(-(model.num_layers + 1) <= i <= model.num_layers
               for i in [-1])
    repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1)
                   for i in [-1]]

    from escape import *

    tprint('Lee et al. 2018...')
    seq_to_mutate, escape_seqs = load_doud2018()
    fb_semantics(model,
                 repr_layers,
Example #6
    def __init__(self, checkpoint_path):
        self.model, self.alphabet = load_model_and_alphabet(checkpoint_path)
        self.batch_converter = self.alphabet.get_batch_converter()
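A hypothetical usage sketch for the wrapper above. The class name ESMEmbedder is assumed, and the batch-converter convention (a list of (label, sequence) pairs returning labels, the raw strings, and a token tensor) matches Examples #3 and #4.

import torch

embedder = ESMEmbedder("esm1_t34_670M_UR50S")  # model name or checkpoint path
embedder.model.eval()
data = [("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQ"),
        ("protein2", "KALTARQQEVFDLIRD")]
labels, strs, toks = embedder.batch_converter(data)
with torch.no_grad():
    out = embedder.model(toks, repr_layers=[embedder.model.num_layers])
final_repr = out["representations"][embedder.model.num_layers]  # [2, tokens, dim]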