# Code example #1
# File: run.py — project: nanand2/protein_seq_des
def load_model(model, use_cuda=True, nic=len(common.atoms.atoms)):
    """Build a seqPred classifier and load pretrained weights.

    Args:
        model: path to a serialized state dict produced by ``torch.save``.
        use_cuda: if True, move the classifier to GPU and load the
            checkpoint tensors onto GPU; otherwise map everything to CPU.
        nic: number of input channels (defaults to the atom-type count,
            evaluated once at definition time).

    Returns:
        The classifier with the checkpoint weights loaded. When the
        checkpoint was saved from an ``nn.DataParallel`` model (its keys
        are prefixed with ``"module."``), the classifier is wrapped in
        ``nn.DataParallel`` first so the key names line up.
    """
    classifier = models.seqPred(nic=nic)
    if use_cuda:
        classifier.cuda()

    # Load the checkpoint ONCE and reuse it below (the original code
    # re-read the file from disk a second time for load_state_dict).
    map_location = None if use_cuda else "cpu"
    state = torch.load(model, map_location=map_location)

    # Only the first key needs inspecting: DataParallel checkpoints
    # prefix every key uniformly with "module.".
    first_key = next(iter(state), "")
    if "module" in first_key:
        print("MODULE")
        classifier = nn.DataParallel(classifier)

    classifier.load_state_dict(state)
    return classifier
# Code example #2
def main():
    """Train the seqPred residue/chi-angle classifier.

    Parses run arguments, builds the model and Adam optimizer (optionally
    resuming from pretrained weights / optimizer state), then trains for
    ``args.epochs`` epochs with periodic evaluation on the test set,
    scalar logging, and checkpointing of model/optimizer state.

    Fixes relative to the original:
      * ``optimizer.zero_grad()`` is now called each iteration — the
        original never cleared gradients, so they accumulated forever.
      * the two ``map(lambda ...)`` logging calls were lazy and never
        consumed in Python 3, so no scalars were ever logged; they are
        now explicit loops.
      * the test-set logging paired the "test_chi_N_loss" keys with the
        TRAINING chi losses; it now logs the test losses.
    """
    manager = common.run_manager.RunManager()

    manager.parse_args()
    args = manager.args
    log = manager.log

    use_cuda = torch.cuda.is_available() and args.cuda

    # set up model: input channels = atom types + BB indicator + 21 res types
    model = models.seqPred(nic=len(common.atoms.atoms) + 1 + 21,
                           nf=args.nf,
                           momentum=0.01)
    model.apply(models.init_ortho_weights)

    if use_cuda:
        model.cuda()
    else:
        print("Training model on CPU")

    if args.model != "":
        # load pretrained model
        model.load_state_dict(torch.load(args.model))
        print("loaded pretrained model")

    # parallelize over available GPUs
    if torch.cuda.device_count() > 1 and args.cuda:
        print("using", torch.cuda.device_count(), "GPUs")
        model = nn.DataParallel(model)

    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, 0.999),
                           weight_decay=args.reg)

    if args.optimizer != "":
        # load pretrained optimizer state
        optimizer.load_state_dict(torch.load(args.optimizer))
        print("loaded pretrained optimizer")

    # Losses: residue identity plus four chi-angle bin classifiers.
    # Label -1 marks a masked / undefined chi angle and is excluded.
    chi_1_criterion = nn.CrossEntropyLoss(ignore_index=-1)
    chi_2_criterion = nn.CrossEntropyLoss(ignore_index=-1)
    chi_3_criterion = nn.CrossEntropyLoss(ignore_index=-1)
    chi_4_criterion = nn.CrossEntropyLoss(ignore_index=-1)
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion.cuda()
        chi_1_criterion.cuda()
        chi_2_criterion.cuda()
        chi_3_criterion.cuda()
        chi_4_criterion.cuda()

    train_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir +
                                              "/train_s95_chi")
    train_dataset.len = 8145448  # NOTE -- need to update this if underlying data changes

    test_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir +
                                             "/test_s95_chi")
    test_dataset.len = 574267  # NOTE -- need to update this if underlying data changes

    train_dataloader = data.DataLoader(train_dataset,
                                       batch_size=args.batchSize,
                                       shuffle=False,
                                       num_workers=args.workers,
                                       pin_memory=True,
                                       collate_fn=datasets.collate_wrapper)
    test_dataloader = data.DataLoader(test_dataset,
                                      batch_size=args.batchSize,
                                      shuffle=False,
                                      num_workers=args.workers,
                                      pin_memory=True,
                                      collate_fn=datasets.collate_wrapper)

    # training params
    validation_frequency = args.validation_frequency
    save_frequency = args.save_frequency
    """ TRAIN """

    model.train()
    gen = iter(train_dataloader)
    test_gen = iter(test_dataloader)
    bs = args.batchSize
    # Pre-allocated dense voxel buffers, reused (zeroed) every iteration.
    # NOTE(review): `c` and `n` appear to be module-level grid constants
    # (channel count and grid edge length) — confirm against file header.
    output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2))
    output_bb = torch.zeros((bs, 2, n + 2, n + 2, n + 2))
    output_res = torch.zeros((bs, 22, n + 2, n + 2, n + 2))
    y_onehot = torch.FloatTensor(bs, 20)
    chi_1_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
    chi_2_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
    chi_3_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))

    if use_cuda:
        output_atom, output_bb, output_res, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot = map(
            lambda x: x.cuda(), [
                output_atom, output_bb, output_res, y_onehot, chi_1_onehot,
                chi_2_onehot, chi_3_onehot
            ])
    for epoch in range(args.epochs):
        for it in tqdm(range(len(train_dataloader)),
                       desc="training epoch %0.2d" % epoch):

            gen, out = step_iter(gen, train_dataloader)
            bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles = out
            bs_i = len(bs_idx)  # actual batch size (last batch may be short)
            output_atom.zero_()
            output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1  # atom type
            output_bb.zero_()
            output_bb[bs_idx, x_bb, x_b, y_b, z_b] = 1  # BB indicator
            output_res.zero_()
            output_res[bs_idx, x_res_type, x_b, y_b, z_b] = 1  # res type
            output = torch.cat(
                [output_atom[:, :c], output_bb[:, :1], output_res[:, :21]], 1)

            # drop the 1-voxel padding border on each side
            X = output[:, :, 1:-1, 1:-1, 1:-1]

            X, y = X.float(), y.long()
            chi_angles = chi_angles.long()

            chi_1 = chi_angles[:, 0]
            chi_2 = chi_angles[:, 1]
            chi_3 = chi_angles[:, 2]
            chi_4 = chi_angles[:, 3]

            if use_cuda:
                y, y_onehot, chi_1, chi_2, chi_3, chi_4 = map(
                    lambda x: x.cuda(),
                    [y, y_onehot, chi_1, chi_2, chi_3, chi_4])

            # pad short final batches up to bs so scatter_ targets fit the
            # pre-allocated one-hot buffers (chi_4 has no one-hot, so no pad)
            if bs_i < bs:
                y = F.pad(y, (0, bs - bs_i))
                chi_1 = F.pad(chi_1, (0, bs - bs_i))
                chi_2 = F.pad(chi_2, (0, bs - bs_i))
                chi_3 = F.pad(chi_3, (0, bs - bs_i))

            y_onehot.zero_()
            y_onehot.scatter_(1, y[:, None], 1)

            chi_1_onehot.zero_()
            chi_1_onehot.scatter_(1, chi_1[:, None], 1)

            chi_2_onehot.zero_()
            chi_2_onehot.scatter_(1, chi_2[:, None], 1)

            chi_3_onehot.zero_()
            chi_3_onehot.scatter_(1, chi_3[:, None], 1)

            # 0 index for chi indicates that chi is masked
            out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(
                X[:bs_i], y_onehot[:bs_i], chi_1_onehot[:bs_i, 1:],
                chi_2_onehot[:bs_i, 1:], chi_3_onehot[:bs_i, 1:])
            res_loss = criterion(out, y[:bs_i])
            # chi labels are shifted down by 1 so that masked (0) maps to
            # the criterion's ignore_index of -1
            chi_1_loss = chi_1_criterion(chi_1_pred, chi_1[:bs_i] - 1)
            chi_2_loss = chi_2_criterion(chi_2_pred, chi_2[:bs_i] - 1)
            chi_3_loss = chi_3_criterion(chi_3_pred, chi_3[:bs_i] - 1)
            chi_4_loss = chi_4_criterion(chi_4_pred, chi_4[:bs_i] - 1)

            train_loss = res_loss + chi_1_loss + chi_2_loss + chi_3_loss + chi_4_loss
            # BUGFIX: clear accumulated gradients before this step's
            # backward pass — the original never called zero_grad().
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            # acc
            train_acc, _ = acc_util.get_acc(out, y[:bs_i], cm=None)
            train_top_k_acc = acc_util.get_top_k_acc(out, y[:bs_i], k=3)
            train_coarse_acc, _ = acc_util.get_acc(
                out, y[:bs_i], label_dict=acc_util.label_coarse)
            train_polar_acc, _ = acc_util.get_acc(
                out, y[:bs_i], label_dict=acc_util.label_polar)

            chi_1_acc, _ = acc_util.get_acc(chi_1_pred,
                                            chi_1[:bs_i] - 1,
                                            ignore_idx=-1)
            chi_2_acc, _ = acc_util.get_acc(chi_2_pred,
                                            chi_2[:bs_i] - 1,
                                            ignore_idx=-1)
            chi_3_acc, _ = acc_util.get_acc(chi_3_pred,
                                            chi_3[:bs_i] - 1,
                                            ignore_idx=-1)
            chi_4_acc, _ = acc_util.get_acc(chi_4_pred,
                                            chi_4[:bs_i] - 1,
                                            ignore_idx=-1)

            # tensorboard logging
            # BUGFIX: the original wrapped this in a bare map(), which is
            # lazy in Python 3 and was never consumed — nothing was logged.
            train_scalars = zip(
                [
                    "res_loss", "chi_1_loss", "chi_2_loss", "chi_3_loss",
                    "chi_4_loss", "train_acc", "chi_1_acc", "chi_2_acc",
                    "chi_3_acc", "chi_4_acc", "train_top3_acc",
                    "train_coarse_acc", "train_polar_acc"
                ],
                [
                    res_loss.item(),
                    chi_1_loss.item(),
                    chi_2_loss.item(),
                    chi_3_loss.item(),
                    chi_4_loss.item(), train_acc, chi_1_acc, chi_2_acc,
                    chi_3_acc, chi_4_acc, train_top_k_acc,
                    train_coarse_acc, train_polar_acc
                ],
            )
            for scalar_name, scalar_value in train_scalars:
                log.log_scalar("seq_chi_pred/%s" % scalar_name, scalar_value)

            if it % validation_frequency == 0 or it == len(
                    train_dataloader) - 1:

                if it > 0:
                    # rolling "current" checkpoint; unwrap DataParallel so
                    # the saved keys are loadable without the wrapper
                    if torch.cuda.device_count() > 1 and args.cuda:
                        torch.save(
                            model.module.state_dict(),
                            log.log_path + "/seq_chi_pred_curr_weights.pt")
                    else:
                        torch.save(
                            model.state_dict(),
                            log.log_path + "/seq_chi_pred_curr_weights.pt")
                    torch.save(
                        optimizer.state_dict(),
                        log.log_path + "/seq_chi_pred_curr_optimizer.pt")

                # NOTE -- saving models for each validation step
                if it > 0 and (it % save_frequency == 0
                               or it == len(train_dataloader) - 1):
                    if torch.cuda.device_count() > 1 and args.cuda:
                        torch.save(
                            model.module.state_dict(), log.log_path +
                            "/seq_chi_pred_epoch_%0.3d_%s_weights.pt" %
                            (epoch, it))
                    else:
                        torch.save(
                            model.state_dict(), log.log_path +
                            "/seq_chi_pred_epoch_%0.3d_%s_weights.pt" %
                            (epoch, it))

                    torch.save(
                        optimizer.state_dict(), log.log_path +
                        "/seq_chi_pred_epoch_%0.3d_%s_optimizer.pt" %
                        (epoch, it))

                ##NOTE -- turning back on model.eval()
                model.eval()
                # eval on the test set
                test_gen, curr_test_loss, test_chi_1_loss, test_chi_2_loss, test_chi_3_loss, test_chi_4_loss, curr_test_acc, curr_test_top_k_acc, coarse_acc, polar_acc, chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc = test(
                    model,
                    test_gen,
                    test_dataloader,
                    criterion,
                    chi_1_criterion,
                    chi_2_criterion,
                    chi_3_criterion,
                    chi_4_criterion,
                    max_it=len(test_dataloader),
                    n_iters=min(10, len(test_dataloader)),
                    desc="test",
                    batch_size=args.batchSize,
                    use_cuda=use_cuda,
                )

                # BUGFIX: log the TEST chi losses here — the original
                # logged the training chi losses under the test keys, and
                # its bare map() meant nothing was logged at all.
                test_scalars = zip(
                    [
                        "test_loss",
                        "test_chi_1_loss",
                        "test_chi_2_loss",
                        "test_chi_3_loss",
                        "test_chi_4_loss",
                        "test_acc",
                        "test_chi_1_acc",
                        "test_chi_2_acc",
                        "test_chi_3_acc",
                        "test_chi_4_acc",
                        "test_acc_top3",
                        "test_coarse_acc",
                        "test_polar_acc",
                    ],
                    [
                        curr_test_loss.item(),
                        test_chi_1_loss.item(),
                        test_chi_2_loss.item(),
                        test_chi_3_loss.item(),
                        test_chi_4_loss.item(),
                        curr_test_acc.item(),
                        chi_1_acc.item(),
                        chi_2_acc.item(),
                        chi_3_acc.item(),
                        chi_4_acc.item(),
                        curr_test_top_k_acc.item(),
                        coarse_acc.item(),
                        polar_acc.item(),
                    ],
                )
                for scalar_name, scalar_value in test_scalars:
                    log.log_scalar("seq_chi_pred/%s" % scalar_name,
                                   scalar_value)

                model.train()

            log.advance_iteration()