Exemplo n.º 1
0
 def test_superdiag_range(self):
     """Check ``contact_auc`` over cutoffs 1-3 for two ``superdiag`` values.

     ``self.pred`` and ``self.meas`` are fixtures prepared outside this
     method. With ``superdiag=1`` the expected AUC is 8/9; with
     ``superdiag=2`` it is exactly 1.0.
     """

     def auc_at(offset):
         # Build a fresh cutoff list per call, exactly like a literal
         # argument at each call site would.
         return contact_auc(
             self.pred,
             self.meas,
             superdiag=offset,
             cutoff_range=[1, 2, 3],
         )

     first_offset_auc = auc_at(1)
     second_offset_auc = auc_at(2)

     self.assertEqual(first_offset_auc, 8.0 / 9)
     self.assertEqual(second_offset_auc, 1.0)
def train():
    """Train a contact-prediction model for one PDB entry and log to W&B.

    Command-line options are parsed in two stages. ``--model`` and
    ``--train_unaligned`` are read first with ``parse_known_args`` because
    they decide which data module and which model class get to register
    further options on the parser. The full command line is parsed afterwards.

    The input file is derived from the required ``--pdb`` option as
    ``data/npz/<pdb>.npz`` and stored in ``args.data``, replacing whatever
    value that attribute held.

    After fitting, the function prints the AUC of the raw and APC-corrected
    contact predictions, logs five figures through the logger and, when
    ``--save_model_s3`` is set, uploads the model state dict to S3.

    Relies on names defined outside this function (``models``,
    ``MSDataModule``, ``MSADataModule``, ``WandbLoggerFrozenVal``,
    ``s3_client``, ``s3_bucket``, the plotting helpers, ...).
    """
    # --- Stage 1: options that determine which other options exist. ---
    parser = ArgumentParser()
    parser.add_argument(
        "--model",
        default="gremlin",
        choices=models.MODELS.keys(),
        help="Which model to train.",
    )
    parser.add_argument(
        "--train_unaligned",
        action="store_true",
        help="Whether to train unaligned instead.",
    )
    # Parse once and read both values from the same namespace.
    known_args, _ = parser.parse_known_args()
    model_name = known_args.model
    train_unaligned = known_args.train_unaligned

    # --- Stage 2: remaining script options. ---
    parser.add_argument(
        "--save_model_s3",
        action="store_true",
        help="Whether to save the model state dict.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="iclr2021-rebuttal",
        help="W&B project used for logging.",
    )
    # Required: the data path below is built from it, so a missing value
    # would otherwise surface as a TypeError during string concatenation.
    parser.add_argument(
        "--pdb",
        type=str,
        required=True,
        help="PDB id for training",
    )

    # The same data module class is used for option registration and loading.
    datamodule_type = MSDataModule if train_unaligned else MSADataModule
    parser = datamodule_type.add_args(parser)

    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=1,
        min_steps=50,
        max_steps=1000,
        log_every_n_steps=10,
    )
    model_type = models.get(model_name)
    model_type.add_args(parser)
    args = parser.parse_args()

    # Point the data option at the npz file for the requested PDB id.
    pdb = args.pdb
    args.data = f"data/npz/{pdb}.npz"

    # Load ms(a)
    msa_dm = datamodule_type.from_args(args)
    msa_dm.setup()

    # Load contacts
    true_contacts = torch.from_numpy(read_contacts(args.data))

    # Initialize model
    num_seqs, msa_length, msa_counts = msa_dm.get_stats()
    model = model_type.from_args(
        args,
        num_seqs=num_seqs,
        msa_length=msa_length,
        msa_counts=msa_counts,
        vocab_size=len(FastaVocab),
        pad_idx=FastaVocab.pad_idx,
        true_contacts=true_contacts,
    )

    # Run name: <model>_<pdb>_<six random lowercase letters>.
    randstring = "".join(random.choice(string.ascii_lowercase) for _ in range(6))
    run_name = "_".join([args.model, pdb, randstring])
    logger = WandbLoggerFrozenVal(project=args.wandb_project, name=run_name)
    logger.log_hyperparams(args)
    logger.log_hyperparams(
        {
            "pdb": pdb,
            "num_seqs": num_seqs,
            "msa_length": msa_length,
        }
    )

    # Initialize Trainer
    trainer = pl.Trainer.from_argparse_args(
        args, checkpoint_callback=False, logger=logger
    )

    trainer.fit(model, msa_dm)

    # Log and print some metrics after training.
    contacts = model.get_contacts()
    apc_contacts = apc(contacts)

    auc = contact_auc(contacts, true_contacts).item()
    auc_apc = contact_auc(apc_contacts, true_contacts).item()
    print(f"AUC: {auc:0.3f}, AUC_APC: {auc_apc:0.3f}")

    def log_current_figure(filename):
        """Log the active matplotlib figure under ``filename``, then close it."""
        logger.log_metrics({filename: wandb.Image(plt)})
        plt.close()

    # (filename, predictions, plot cutoff, title prefix, get_precision kwargs)
    # The titles label cutoff=1 as "Top L" and cutoff=5 as "Top L/5".
    contact_plots = [
        ("top_L_contacts.png", contacts, 1, "Top L no APC", {"do_apc": False}),
        ("top_L_contacts_apc.png", apc_contacts, 1, "Top L APC", {"do_apc": True}),
        (
            "top_L_5_contacts.png",
            contacts,
            5,
            "Top L/5 no APC",
            {"do_apc": False, "cutoff": 5},
        ),
        (
            "top_L_5_contacts_apc.png",
            apc_contacts,
            5,
            "Top L/5 APC",
            {"do_apc": True, "cutoff": 5},
        ),
    ]
    for filename, preds, cutoff, label, precision_kwargs in contact_plots:
        plot_colored_preds_on_trues(
            preds, true_contacts, point_size=5, cutoff=cutoff
        )
        plt.title(f"{label} {model.get_precision(**precision_kwargs)}")
        log_current_figure(filename)

    plot_precision_vs_length(apc_contacts, true_contacts)
    log_current_figure("precision_vs_L.png")

    if args.save_model_s3:
        # Serialize into memory and rewind so the upload reads from the start.
        bytestream = io.BytesIO()
        torch.save(model.state_dict(), bytestream)
        bytestream.seek(0)
        # S3 keys always use "/" regardless of the local OS path separator.
        key = "/".join(
            ["iclr-2021-factored-attention", wandb.run.path, "model_state_dict.h5"]
        )
        s3_client.put_object(
            Bucket=s3_bucket, Body=bytestream, Key=key, ACL="public-read"
        )
        print(f"uploaded state dict to s3://{s3_bucket}/{key}")
Exemplo n.º 3
0
def train():
    """Train a contact-prediction model on an MSA, with optional evaluation.

    ``--model`` is read first with ``parse_known_args`` because the selected
    model class registers its own options on the parser before the full
    command line is parsed.

    Optional behaviour, each enabled by its command-line option:

    * ``--structure_file``: load true contacts, print AUC with and without
      APC and, if ``--wandb_project`` is also set, log three figures.
    * ``--wandb_project``: attach a W&B logger to the trainer.
    * ``--output_file``: save the model state dict with ``torch.save``.
    * ``--contacts_file``: save the strict upper triangle of the
      APC-corrected contacts with ``torch.save``.

    Raises:
        ImportError: re-raised with an installation hint when setting up the
            W&B logger fails with an ``ImportError``.
    """
    # Initialize parser
    parser = ArgumentParser()
    parser.add_argument(
        "--model",
        default="gremlin",
        choices=models.MODELS.keys(),
        help="Which model to train.",
    )
    model_name = parser.parse_known_args()[0].model
    parser.add_argument(
        "--structure_file",
        type=str,
        default=None,
        help=("Optional pdb or cf file containing protein structure. "
              "Used for evaluation."),
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="Optional file to output gremlin weights.",
    )
    parser.add_argument(
        "--contacts_file",
        type=str,
        default=None,
        help="Optional file to output gremlin contacts.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default=None,
        help="Optional wandb project to log to.",
    )
    parser = MSADataModule.add_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=1,
        min_steps=50,
        max_steps=1000,
    )
    model_type = models.get(model_name)
    model_type.add_args(parser)
    args = parser.parse_args()

    # Load msa
    msa_dm = MSADataModule.from_args(args)
    msa_dm.setup()

    # Load contacts (only when a structure file was supplied)
    true_contacts = (torch.from_numpy(read_contacts(args.structure_file))
                     if args.structure_file is not None else None)

    # Initialize model
    num_seqs, msa_length, msa_counts = msa_dm.get_stats()
    model = model_type.from_args(
        args,
        num_seqs=num_seqs,
        msa_length=msa_length,
        msa_counts=msa_counts,
        vocab_size=len(FastaVocab),
        pad_idx=FastaVocab.pad_idx,
        true_contacts=true_contacts,
    )

    kwargs = {}
    if args.wandb_project:
        try:
            # Requires wandb to be installed
            logger = pl.loggers.WandbLogger(project=args.wandb_project)
            logger.log_hyperparams(args)
            logger.log_hyperparams({
                "pdb": Path(args.data).stem,
                "num_seqs": num_seqs,
                "msa_length": msa_length,
            })
            kwargs["logger"] = logger
        except ImportError as err:
            # Chain the original error so the real import failure stays
            # visible in the traceback.
            raise ImportError(
                "Cannot use W&B logger w/o W&b install. Run `pip install wandb` first."
            ) from err

    # Initialize Trainer
    trainer = pl.Trainer.from_argparse_args(args,
                                            checkpoint_callback=False,
                                            **kwargs)

    trainer.fit(model, msa_dm)

    if true_contacts is not None:
        # Keep raw and APC-corrected predictions in separate variables so the
        # "no APC" figure really shows raw predictions and APC is applied
        # exactly once to the corrected ones.
        contacts = model.get_contacts()
        auc = contact_auc(contacts, true_contacts).item()
        apc_contacts = apc(contacts)
        auc_apc = contact_auc(apc_contacts, true_contacts).item()
        print(f"AUC: {auc:0.3f}, AUC_APC: {auc_apc:0.3f}")

        if args.wandb_project:
            # Imported lazily: only needed when logging figures to W&B.
            import matplotlib.pyplot as plt
            import wandb

            from mogwai.plotting import (
                plot_colored_preds_on_trues,
                plot_precision_vs_length,
            )

            filename = "top_L_contacts.png"
            plot_colored_preds_on_trues(contacts, true_contacts, point_size=5)
            logger.log_metrics({filename: wandb.Image(plt)})
            plt.close()

            filename = "top_L_contacts_apc.png"
            plot_colored_preds_on_trues(apc_contacts,
                                        true_contacts,
                                        point_size=5)
            logger.log_metrics({filename: wandb.Image(plt)})
            plt.close()

            filename = "precision_vs_L.png"
            plot_precision_vs_length(apc_contacts, true_contacts)
            logger.log_metrics({filename: wandb.Image(plt)})
            plt.close()

    if args.output_file is not None:
        torch.save(model.state_dict(), args.output_file)

    if args.contacts_file is not None:
        # Recomputed here because the evaluation branch above is skipped
        # when no structure file was given.
        exported = apc(model.get_contacts())
        # k=1 selects the entries strictly above the main diagonal.
        x_ind, y_ind = np.triu_indices_from(exported, 1)
        exported = exported[x_ind, y_ind]
        torch.save(exported, args.contacts_file)
Exemplo n.º 4
0
 def test_range(self):
     """Check ``contact_auc`` over cutoffs 1-3 with ``superdiag=0``.

     ``self.pred`` and ``self.meas`` are fixtures prepared outside this
     method; the expected AUC for them is 8/9.
     """
     options = {"superdiag": 0, "cutoff_range": [1, 2, 3]}
     observed = contact_auc(self.pred, self.meas, **options)
     self.assertEqual(observed, 8.0 / 9)