def load_model(model, use_cuda=True, nic=len(common.atoms.atoms)):
    """Load a trained seqPred classifier from a checkpoint path."""
    classifier = models.seqPred(nic=nic)
    if use_cuda:
        classifier.cuda()
    # Load the checkpoint once, mapping to CPU when CUDA is unavailable.
    state = torch.load(model, map_location=None if use_cuda else "cpu")
    # Checkpoints saved from nn.DataParallel have keys prefixed with "module.";
    # wrap the classifier so the state-dict keys line up.
    if any("module" in k for k in state.keys()):
        classifier = nn.DataParallel(classifier)
    classifier.load_state_dict(state)
    return classifier
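
# Example usage of load_model (a sketch; the concrete checkpoint path is
# hypothetical and assumes weights saved by the training loop below to
# log.log_path):
#
#   classifier = load_model("logs/seq_chi_pred_curr_weights.pt", use_cuda=False)
#   classifier.eval()  # inference mode: freeze batch-norm statistics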
def main():
    manager = common.run_manager.RunManager()
    manager.parse_args()
    args = manager.args
    log = manager.log
    use_cuda = torch.cuda.is_available() and args.cuda

    # set up model
    model = models.seqPred(nic=len(common.atoms.atoms) + 1 + 21, nf=args.nf,
                           momentum=0.01)
    model.apply(models.init_ortho_weights)
    if use_cuda:
        model.cuda()
    else:
        print("Training model on CPU")

    if args.model != "":  # load pretrained model weights
        model.load_state_dict(torch.load(args.model))
        print("loaded pretrained model")

    # parallelize over available GPUs
    if torch.cuda.device_count() > 1 and args.cuda:
        print("using", torch.cuda.device_count(), "GPUs")
        model = nn.DataParallel(model)

    optimizer = optim.Adam(model.parameters(), lr=args.lr,
                           betas=(args.beta1, 0.999), weight_decay=args.reg)
    if args.optimizer != "":  # load pretrained optimizer state
        optimizer.load_state_dict(torch.load(args.optimizer))
        print("loaded pretrained optimizer")

    # per-chi losses ignore masked angles (target -1)
    chi_1_criterion = nn.CrossEntropyLoss(ignore_index=-1)
    chi_2_criterion = nn.CrossEntropyLoss(ignore_index=-1)
    chi_3_criterion = nn.CrossEntropyLoss(ignore_index=-1)
    chi_4_criterion = nn.CrossEntropyLoss(ignore_index=-1)
    criterion = nn.CrossEntropyLoss()
    if use_cuda:
        criterion.cuda()
        chi_1_criterion.cuda()
        chi_2_criterion.cuda()
        chi_3_criterion.cuda()
        chi_4_criterion.cuda()

    train_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/train_s95_chi")
    train_dataset.len = 8145448  # NOTE -- need to update this if underlying data changes
    test_dataset = datasets.PDB_data_spitter(data_dir=args.data_dir + "/test_s95_chi")
    test_dataset.len = 574267  # NOTE -- need to update this if underlying data changes

    train_dataloader = data.DataLoader(train_dataset, batch_size=args.batchSize,
                                       shuffle=False, num_workers=args.workers,
                                       pin_memory=True,
                                       collate_fn=datasets.collate_wrapper)
    test_dataloader = data.DataLoader(test_dataset, batch_size=args.batchSize,
                                      shuffle=False, num_workers=args.workers,
                                      pin_memory=True,
                                      collate_fn=datasets.collate_wrapper)

    # training params
    validation_frequency = args.validation_frequency
    save_frequency = args.save_frequency

    """ TRAIN """
    model.train()
    gen = iter(train_dataloader)
    test_gen = iter(test_dataloader)
    bs = args.batchSize
    # dense voxel buffers; c (atom channels) and n (grid size) are
    # module-level constants
    output_atom = torch.zeros((bs, c + 1, n + 2, n + 2, n + 2))
    output_bb = torch.zeros((bs, 2, n + 2, n + 2, n + 2))
    output_res = torch.zeros((bs, 22, n + 2, n + 2, n + 2))
    y_onehot = torch.FloatTensor(bs, 20)
    chi_1_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
    chi_2_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
    chi_3_onehot = torch.FloatTensor(bs, len(datasets.CHI_BINS))
    if use_cuda:
        output_atom, output_bb, output_res, y_onehot, chi_1_onehot, chi_2_onehot, chi_3_onehot = map(
            lambda x: x.cuda(),
            [output_atom, output_bb, output_res, y_onehot, chi_1_onehot,
             chi_2_onehot, chi_3_onehot])

    for epoch in range(args.epochs):
        for it in tqdm(range(len(train_dataloader)),
                       desc="training epoch %0.2d" % epoch):
            gen, batch = step_iter(gen, train_dataloader)
            bs_idx, x_atom, x_bb, x_b, y_b, z_b, x_res_type, y, chi_angles_real, chi_angles = batch
            bs_i = len(bs_idx)

            # scatter the sparse batch into dense one-hot voxel grids
            output_atom.zero_()
            output_atom[bs_idx, x_atom, x_b, y_b, z_b] = 1  # atom type
            output_bb.zero_()
            output_bb[bs_idx, x_bb, x_b, y_b, z_b] = 1  # BB indicator
            output_res.zero_()
            output_res[bs_idx, x_res_type, x_b, y_b, z_b] = 1  # res type
            # channels: c atom types + 1 BB indicator + 21 residue types (= nic)
            output = torch.cat(
                [output_atom[:, :c], output_bb[:, :1], output_res[:, :21]], 1)
            X = output[:, :, 1:-1, 1:-1, 1:-1]  # trim the padding border
            X, y = X.float(), y.long()
            chi_angles = chi_angles.long()
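
            # A worked illustration of the scatter-based one-hot encoding used
            # below (not executed): given targets t = [2, 0] and a (2, 4) zero
            # buffer, buf.scatter_(1, t[:, None], 1) yields
            # [[0, 0, 1, 0], [1, 0, 0, 0]] -- a 1 at column t[i] of row i.
            # For the chi angles, bin 0 means "masked/undefined", which is why
            # column 0 is dropped from the model inputs and the loss targets
            # are shifted by -1 to hit ignore_index=-1.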
            chi_1 = chi_angles[:, 0]
            chi_2 = chi_angles[:, 1]
            chi_3 = chi_angles[:, 2]
            chi_4 = chi_angles[:, 3]
            if use_cuda:
                y, y_onehot, chi_1, chi_2, chi_3, chi_4 = map(
                    lambda x: x.cuda(), [y, y_onehot, chi_1, chi_2, chi_3, chi_4])
            # pad short final batches so the fixed-size one-hot buffers line up
            if bs_i < bs:
                y = F.pad(y, (0, bs - bs_i))
                chi_1 = F.pad(chi_1, (0, bs - bs_i))
                chi_2 = F.pad(chi_2, (0, bs - bs_i))
                chi_3 = F.pad(chi_3, (0, bs - bs_i))
            y_onehot.zero_()
            y_onehot.scatter_(1, y[:, None], 1)
            chi_1_onehot.zero_()
            chi_1_onehot.scatter_(1, chi_1[:, None], 1)
            chi_2_onehot.zero_()
            chi_2_onehot.scatter_(1, chi_2[:, None], 1)
            chi_3_onehot.zero_()
            chi_3_onehot.scatter_(1, chi_3[:, None], 1)

            # 0 index for chi indicates that chi is masked, so column 0 is
            # dropped from the conditioning one-hots passed to the model
            out, chi_1_pred, chi_2_pred, chi_3_pred, chi_4_pred = model(
                X[:bs_i], y_onehot[:bs_i], chi_1_onehot[:bs_i, 1:],
                chi_2_onehot[:bs_i, 1:], chi_3_onehot[:bs_i, 1:])

            res_loss = criterion(out, y[:bs_i])
            # shift chi targets by -1 so the masked bin 0 maps to ignore_index=-1
            chi_1_loss = chi_1_criterion(chi_1_pred, chi_1[:bs_i] - 1)
            chi_2_loss = chi_2_criterion(chi_2_pred, chi_2[:bs_i] - 1)
            chi_3_loss = chi_3_criterion(chi_3_pred, chi_3[:bs_i] - 1)
            chi_4_loss = chi_4_criterion(chi_4_pred, chi_4[:bs_i] - 1)
            train_loss = res_loss + chi_1_loss + chi_2_loss + chi_3_loss + chi_4_loss

            optimizer.zero_grad()  # clear stale gradients before the backward pass
            train_loss.backward()
            optimizer.step()

            # acc
            train_acc, _ = acc_util.get_acc(out, y[:bs_i], cm=None)
            train_top_k_acc = acc_util.get_top_k_acc(out, y[:bs_i], k=3)
            train_coarse_acc, _ = acc_util.get_acc(
                out, y[:bs_i], label_dict=acc_util.label_coarse)
            train_polar_acc, _ = acc_util.get_acc(
                out, y[:bs_i], label_dict=acc_util.label_polar)
            chi_1_acc, _ = acc_util.get_acc(chi_1_pred, chi_1[:bs_i] - 1, ignore_idx=-1)
            chi_2_acc, _ = acc_util.get_acc(chi_2_pred, chi_2[:bs_i] - 1, ignore_idx=-1)
            chi_3_acc, _ = acc_util.get_acc(chi_3_pred, chi_3[:bs_i] - 1, ignore_idx=-1)
            chi_4_acc, _ = acc_util.get_acc(chi_4_pred, chi_4[:bs_i] - 1, ignore_idx=-1)

            # tensorboard logging -- an explicit loop, since map() is lazy in
            # Python 3 and its side effects would never run
            for name, val in zip(
                    ["res_loss", "chi_1_loss", "chi_2_loss", "chi_3_loss",
                     "chi_4_loss", "train_acc", "chi_1_acc", "chi_2_acc",
                     "chi_3_acc", "chi_4_acc", "train_top3_acc",
                     "train_coarse_acc", "train_polar_acc"],
                    [res_loss.item(), chi_1_loss.item(), chi_2_loss.item(),
                     chi_3_loss.item(), chi_4_loss.item(), train_acc,
                     chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc,
                     train_top_k_acc, train_coarse_acc, train_polar_acc]):
                log.log_scalar("seq_chi_pred/%s" % name, val)

            if it % validation_frequency == 0 or it == len(train_dataloader) - 1:
                if it > 0:
                    if torch.cuda.device_count() > 1 and args.cuda:
                        torch.save(model.module.state_dict(),
                                   log.log_path + "/seq_chi_pred_curr_weights.pt")
                    else:
                        torch.save(model.state_dict(),
                                   log.log_path + "/seq_chi_pred_curr_weights.pt")
                    torch.save(optimizer.state_dict(),
                               log.log_path + "/seq_chi_pred_curr_optimizer.pt")

                # NOTE -- saving models for each validation step
                if it > 0 and (it % save_frequency == 0
                               or it == len(train_dataloader) - 1):
                    if torch.cuda.device_count() > 1 and args.cuda:
                        torch.save(model.module.state_dict(),
                                   log.log_path + "/seq_chi_pred_epoch_%0.3d_%s_weights.pt" % (epoch, it))
                    else:
                        torch.save(model.state_dict(),
                                   log.log_path + "/seq_chi_pred_epoch_%0.3d_%s_weights.pt" % (epoch, it))
                    torch.save(optimizer.state_dict(),
                               log.log_path + "/seq_chi_pred_epoch_%0.3d_%s_optimizer.pt" % (epoch, it))

                # NOTE -- switch to eval mode for the test pass
                model.eval()
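                # `test` is defined elsewhere in this file; the call below
                # assumes it averages losses/accuracies over n_iters test
                # batches and returns the (possibly restarted) iterator first.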
                # eval on the test set
                (test_gen, curr_test_loss, test_chi_1_loss, test_chi_2_loss,
                 test_chi_3_loss, test_chi_4_loss, curr_test_acc,
                 curr_test_top_k_acc, coarse_acc, polar_acc,
                 chi_1_acc, chi_2_acc, chi_3_acc, chi_4_acc) = test(
                    model, test_gen, test_dataloader, criterion,
                    chi_1_criterion, chi_2_criterion, chi_3_criterion,
                    chi_4_criterion,
                    max_it=len(test_dataloader),
                    n_iters=min(10, len(test_dataloader)),
                    desc="test",
                    batch_size=args.batchSize,
                    use_cuda=use_cuda,
                )
                # log the test-set chi losses (not the stale training-batch ones)
                for name, val in zip(
                        ["test_loss", "test_chi_1_loss", "test_chi_2_loss",
                         "test_chi_3_loss", "test_chi_4_loss", "test_acc",
                         "test_chi_1_acc", "test_chi_2_acc", "test_chi_3_acc",
                         "test_chi_4_acc", "test_acc_top3", "test_coarse_acc",
                         "test_polar_acc"],
                        [curr_test_loss.item(), test_chi_1_loss.item(),
                         test_chi_2_loss.item(), test_chi_3_loss.item(),
                         test_chi_4_loss.item(), curr_test_acc.item(),
                         chi_1_acc.item(), chi_2_acc.item(), chi_3_acc.item(),
                         chi_4_acc.item(), curr_test_top_k_acc.item(),
                         coarse_acc.item(), polar_acc.item()]):
                    log.log_scalar("seq_chi_pred/%s" % name, val)
                model.train()

            log.advance_iteration()
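
# Assumed standard entry point for running the training script directly:
if __name__ == "__main__":
    main()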