def test_superdiag_range(self):
    """AUC over cutoff_range=[1, 2, 3] should depend on the superdiag offset."""
    cutoffs = [1, 2, 3]
    # (superdiag, expected AUC) pairs, checked in the original order.
    for superdiag, expected in ((1, 8.0 / 9), (2, 1.0)):
        auc = contact_auc(
            self.pred, self.meas, superdiag=superdiag, cutoff_range=cutoffs
        )
        self.assertEqual(auc, expected)
def train():
    """Train a contact-prediction model on one PDB and log results to W&B.

    Parses CLI arguments (model choice, aligned vs. unaligned data, W&B
    project, PDB id), fits the model with PyTorch Lightning, logs contact-map
    plots after training, and optionally uploads the state dict to S3.
    """
    # Initialize parser
    parser = ArgumentParser()
    parser.add_argument(
        "--model",
        default="gremlin",
        choices=models.MODELS.keys(),
        help="Which model to train.",
    )
    parser.add_argument(
        "--train_unaligned",
        action="store_true",
        help="Whether to train unaligned instead.",
    )
    # Peek at these two flags before the full parse: they decide which
    # data-module and model arguments get registered below.
    model_name = parser.parse_known_args()[0].model
    train_unaligned = parser.parse_known_args()[0].train_unaligned
    parser.add_argument(
        "--save_model_s3",
        action="store_true",
        help="Whether to save the model state dict.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default="iclr2021-rebuttal",
        help="W&B project used for logging.",
    )
    parser.add_argument(
        "--pdb",
        type=str,
        help="PDB id for training",
    )
    if train_unaligned:
        parser = MSDataModule.add_args(parser)
    else:
        parser = MSADataModule.add_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=1,
        min_steps=50,
        max_steps=1000,
        log_every_n_steps=10,
    )
    model_type = models.get(model_name)
    model_type.add_args(parser)
    args = parser.parse_args()

    # Modify name
    # Derive the npz data path from the PDB id; downstream code reads args.data.
    pdb = args.pdb
    args.data = "data/npz/" + pdb + ".npz"

    # Load ms(a)
    if train_unaligned:
        msa_dm = MSDataModule.from_args(args)
    else:
        msa_dm = MSADataModule.from_args(args)
    msa_dm.setup()

    # Load contacts
    true_contacts = torch.from_numpy(read_contacts(args.data))

    # Initialize model
    num_seqs, msa_length, msa_counts = msa_dm.get_stats()
    model = model_type.from_args(
        args,
        num_seqs=num_seqs,
        msa_length=msa_length,
        msa_counts=msa_counts,
        vocab_size=len(FastaVocab),
        pad_idx=FastaVocab.pad_idx,
        true_contacts=true_contacts,
    )

    kwargs = {}
    # Random lowercase suffix keeps W&B run names unique across repeated runs
    # of the same model/pdb configuration.
    randstring = "".join(random.choice(string.ascii_lowercase) for i in range(6))
    run_name = "_".join([args.model, pdb, randstring])
    logger = WandbLoggerFrozenVal(project=args.wandb_project, name=run_name)
    logger.log_hyperparams(args)
    logger.log_hyperparams(
        {
            "pdb": pdb,
            "num_seqs": num_seqs,
            "msa_length": msa_length,
        }
    )
    kwargs["logger"] = logger

    # Initialize Trainer
    # checkpoint_callback=False: no checkpoints written; only final metrics matter.
    trainer = pl.Trainer.from_argparse_args(args, checkpoint_callback=False, **kwargs)

    trainer.fit(model, msa_dm)

    # Log and print some metrics after training.
    # Raw contacts and their APC-corrected version are evaluated separately.
    contacts = model.get_contacts()
    apc_contacts = apc(contacts)
    auc = contact_auc(contacts, true_contacts).item()
    auc_apc = contact_auc(apc_contacts, true_contacts).item()
    print(f"AUC: {auc:0.3f}, AUC_APC: {auc_apc:0.3f}")

    # Top-L plots (cutoff=1) and top-L/5 plots (cutoff=5), with and without APC.
    filename = "top_L_contacts.png"
    plot_colored_preds_on_trues(contacts, true_contacts, point_size=5, cutoff=1)
    plt.title(f"Top L no APC {model.get_precision(do_apc=False)}")
    logger.log_metrics({filename: wandb.Image(plt)})
    plt.close()

    filename = "top_L_contacts_apc.png"
    plot_colored_preds_on_trues(apc_contacts, true_contacts, point_size=5, cutoff=1)
    plt.title(f"Top L APC {model.get_precision(do_apc=True)}")
    logger.log_metrics({filename: wandb.Image(plt)})
    plt.close()

    filename = "top_L_5_contacts.png"
    plot_colored_preds_on_trues(contacts, true_contacts, point_size=5, cutoff=5)
    plt.title(f"Top L/5 no APC {model.get_precision(do_apc=False, cutoff=5)}")
    logger.log_metrics({filename: wandb.Image(plt)})
    plt.close()

    filename = "top_L_5_contacts_apc.png"
    plot_colored_preds_on_trues(apc_contacts, true_contacts, point_size=5, cutoff=5)
    plt.title(f"Top L/5 APC {model.get_precision(do_apc=True, cutoff=5)}")
    logger.log_metrics({filename: wandb.Image(plt)})
    plt.close()

    filename = "precision_vs_L.png"
    plot_precision_vs_length(apc_contacts, true_contacts)
    logger.log_metrics({filename: wandb.Image(plt)})
    plt.close()

    if args.save_model_s3:
        # NOTE(review): s3_client and s3_bucket appear to be module-level
        # globals configured elsewhere in this file — confirm before reuse.
        bytestream = io.BytesIO()
        torch.save(model.state_dict(), bytestream)
        bytestream.seek(0)
        key = os.path.join(
            "iclr-2021-factored-attention", wandb.run.path, "model_state_dict.h5"
        )
        response = s3_client.put_object(
            Bucket=s3_bucket, Body=bytestream, Key=key, ACL="public-read"
        )
        print(f"uploaded state dict to s3://{s3_bucket}/{key}")
def train():
    """Train a contact-prediction model on an MSA, optionally evaluating/logging.

    Parses CLI arguments, fits the selected model with PyTorch Lightning,
    and then (all optional, driven by flags): evaluates AUC against true
    contacts, logs plots to W&B, saves the state dict, and writes the
    flattened upper-triangular APC contact map to disk.
    """
    # Initialize parser
    parser = ArgumentParser()
    parser.add_argument(
        "--model",
        default="gremlin",
        choices=models.MODELS.keys(),
        help="Which model to train.",
    )
    # Peek at --model before the full parse: it decides which model-specific
    # arguments get registered below.
    model_name = parser.parse_known_args()[0].model
    parser.add_argument(
        "--structure_file",
        type=str,
        default=None,
        help=(
            "Optional pdb or cf file containing protein structure. "
            "Used for evaluation."
        ),
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default=None,
        help="Optional file to output gremlin weights.",
    )
    parser.add_argument(
        "--contacts_file",
        type=str,
        default=None,
        help="Optional file to output gremlin contacts.",
    )
    parser.add_argument(
        "--wandb_project",
        type=str,
        default=None,
        help="Optional wandb project to log to.",
    )
    parser = MSADataModule.add_args(parser)
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=1,
        min_steps=50,
        max_steps=1000,
    )
    model_type = models.get(model_name)
    model_type.add_args(parser)
    args = parser.parse_args()

    # Load msa
    msa_dm = MSADataModule.from_args(args)
    msa_dm.setup()

    # Load contacts (only if a structure file was provided).
    true_contacts = (
        torch.from_numpy(read_contacts(args.structure_file))
        if args.structure_file is not None
        else None
    )

    # Initialize model
    num_seqs, msa_length, msa_counts = msa_dm.get_stats()
    model = model_type.from_args(
        args,
        num_seqs=num_seqs,
        msa_length=msa_length,
        msa_counts=msa_counts,
        vocab_size=len(FastaVocab),
        pad_idx=FastaVocab.pad_idx,
        true_contacts=true_contacts,
    )

    kwargs = {}
    if args.wandb_project:
        try:
            # Requires wandb to be installed
            logger = pl.loggers.WandbLogger(project=args.wandb_project)
            logger.log_hyperparams(args)
            logger.log_hyperparams({
                "pdb": Path(args.data).stem,
                "num_seqs": num_seqs,
                "msa_length": msa_length,
            })
            kwargs["logger"] = logger
        except ImportError:
            raise ImportError(
                "Cannot use W&B logger w/o W&b install. Run `pip install wandb` first."
            )

    # Initialize Trainer
    trainer = pl.Trainer.from_argparse_args(args, checkpoint_callback=False, **kwargs)

    trainer.fit(model, msa_dm)

    if true_contacts is not None:
        # FIX: keep raw and APC-corrected contacts in separate variables.
        # Previously `contacts` was rebound to apc(contacts), so the
        # "top_L_contacts.png" plot showed APC-corrected contacts and the
        # "_apc" plot applied APC twice (apc(apc(x))). This mirrors the
        # contacts/apc_contacts split used by the other training script.
        contacts = model.get_contacts()
        apc_contacts = apc(contacts)
        auc = contact_auc(contacts, true_contacts).item()
        auc_apc = contact_auc(apc_contacts, true_contacts).item()
        print(f"AUC: {auc:0.3f}, AUC_APC: {auc_apc:0.3f}")

        if args.wandb_project:
            import matplotlib.pyplot as plt
            import wandb
            from mogwai.plotting import (
                plot_colored_preds_on_trues,
                plot_precision_vs_length,
            )

            filename = "top_L_contacts.png"
            plot_colored_preds_on_trues(contacts, true_contacts, point_size=5)
            logger.log_metrics({filename: wandb.Image(plt)})
            plt.close()

            filename = "top_L_contacts_apc.png"
            plot_colored_preds_on_trues(apc_contacts, true_contacts, point_size=5)
            logger.log_metrics({filename: wandb.Image(plt)})
            plt.close()

            filename = "precision_vs_L.png"
            plot_precision_vs_length(apc_contacts, true_contacts)
            logger.log_metrics({filename: wandb.Image(plt)})
            plt.close()

    if args.output_file is not None:
        torch.save(model.state_dict(), args.output_file)

    if args.contacts_file is not None:
        # Save only the strict upper triangle of the APC contact map,
        # flattened to a 1-D tensor.
        contacts = model.get_contacts()
        contacts = apc(contacts)
        x_ind, y_ind = np.triu_indices_from(contacts, 1)
        contacts = contacts[x_ind, y_ind]
        torch.save(contacts, args.contacts_file)
def test_range(self):
    """With superdiag=0, the AUC over cutoff_range=[1, 2, 3] equals 8/9."""
    cutoffs = [1, 2, 3]
    result = contact_auc(self.pred, self.meas, superdiag=0, cutoff_range=cutoffs)
    expected = 8.0 / 9
    self.assertEqual(result, expected)