def save_infers(out_dir, filename, predictions):
    """Write inference probabilities, one per line, to out/infer/<filename>.probs."""
    evals_dir = datautils.make_directory(out_dir, "out/infer")
    probs_path = os.path.join(evals_dir, filename + '.probs')
    with open(probs_path, "w") as f:
        for i in range(len(predictions)):
            print("{:f}".format(predictions[i, 0]), file=f)
    print("Prediction file:", probs_path)
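# Illustrative only: the .probs file written above holds one "{:f}"-formatted
# probability per input row, e.g.
#   0.913724
#   0.007518
# (values made up for illustration).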
def compute_saliency(args, model, device, test_loader, identity):
    """Compute guided-backprop SmoothGrad saliency for every test sample and
    write one line per sample to out/saliency/<identity>.sal."""
    from prismnet.model import GuidedBackpropSmoothGrad

    model.eval()
    saliency_dir = datautils.make_directory(args.out_dir, "out/saliency")
    saliency_path = os.path.join(saliency_dir, identity + '.sal')

    # sgrad = SmoothGrad(model, device=device)
    sgrad = GuidedBackpropSmoothGrad(model, device=device)
    sal = ""
    for batch_idx, (x0, y0) in enumerate(test_loader):
        X, Y = x0.float().to(device), y0.to(device).float()
        output = model(X)
        prob = torch.sigmoid(output)
        p_np = prob.to(device='cpu').detach().numpy().squeeze()
        guided_saliency = sgrad.get_batch_gradients(X, Y)
        N, NS, _, _ = guided_saliency.shape  # (N, 101, 1, 5)
        for i in range(N):
            inr = batch_idx * args.batch_size + i
            str_sal = datautils.mat2str(np.squeeze(guided_saliency[i]))
            sal += "{}\t{:.6f}\t{}\n".format(inr, p_np[i], str_sal)

    with open(saliency_path, "w") as f:
        f.write(sal)
    print(saliency_path)
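# Illustrative only: each line of the .sal file is
#   <sample index> \t <predicted probability> \t <saliency matrix>
# e.g. "42\t0.873201\t<mat2str output>", where the matrix encoding is whatever
# datautils.mat2str produces for the squeezed per-sample saliency.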
def compute_high_attention_region(args, model, device, test_loader, identity):
    """Find, for every test sample, the length-L window with the largest summed
    saliency and write (index, probability, start, end) to out/har/<identity>.har."""
    from prismnet.model import GuidedBackpropSmoothGrad

    har_dir = datautils.make_directory(args.out_dir, "out/har")
    har_path = os.path.join(har_dir, identity + '.har')
    L = 20
    har = ""
    # sgrad = SmoothGrad(model, device=device)
    sgrad = GuidedBackpropSmoothGrad(model, device=device)
    for batch_idx, (x0, y0) in enumerate(test_loader):
        X, Y = x0.float().to(device), y0.to(device).float()
        output = model(X)
        prob = torch.sigmoid(output)
        p_np = prob.to(device='cpu').detach().numpy().squeeze()
        guided_saliency = sgrad.get_batch_gradients(X, Y)
        # Collapse the channel dimension to a per-position attention score.
        attention_region = guided_saliency.sum(dim=3)[:, 0, :].to(device='cpu').numpy()
        N, NS = attention_region.shape  # (N, 101)
        for i in range(N):
            inr = batch_idx * args.batch_size + i
            iar = attention_region[i]
            # Summed saliency of every length-L window along the sequence.
            ar_score = np.array([iar[j:j + L].sum() for j in range(NS - L + 1)])
            # Start position of the highest-scoring window.
            highest_ind = np.argmax(ar_score)
            har += "{}\t{:.6f}\t{}\t{}\n".format(inr, p_np[i], highest_ind, highest_ind + L)

    with open(har_path, "w") as f:
        f.write(har)
    print(har_path)
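# Illustrative only (made-up numbers): with L = 3 and per-position scores
#   iar = [0.1, 0.9, 0.8, 0.2, 0.0]
# the window sums are ar_score = [1.8, 1.9, 1.0], so highest_ind = 1 and the
# reported region is [1, 4). The real code uses L = 20 over NS positions.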
def save_evals(out_dir, filename, dataname, predictions, label, met):
    """Write evaluation metrics and per-sample predictions to out/evals/."""
    evals_dir = datautils.make_directory(out_dir, "out/evals")
    metrics_path = os.path.join(evals_dir, filename + '.metrics')
    probs_path = os.path.join(evals_dir, filename + '.probs')
    with open(metrics_path, "w") as f:
        if "_reg" in filename:
            print("{:s}\t{:.3f}\t{:.3f}\t{:.3f}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}".format(
                dataname,
                met.acc, met.auc, met.prc,
                met.tp, met.tn, met.fp, met.fn,
                met.avg[7], met.avg[8], met.avg[9],
            ), file=f)
        else:
            print("{:s}\t{:.3f}\t{:.3f}\t{:.3f}\t{:d}\t{:d}\t{:d}\t{:d}".format(
                dataname,
                met.acc, met.auc, met.prc,
                met.tp, met.tn, met.fp, met.fn,
            ), file=f)
    with open(probs_path, "w") as f:
        for i in range(len(predictions)):
            print("{:.3f}\t{}".format(predictions[i, 0], label[i, 0]), file=f)
    print("Evaluation file:", metrics_path)
    print("Prediction file:", probs_path)
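# Illustrative only: the .metrics file is a single tab-separated line
#   <dataname> <acc> <auc> <prc> <tp> <tn> <fp> <fn>
# with three additional met.avg columns appended for "_reg" runs; the .probs
# file has one "<prob>\t<label>" row per sample.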
def main():
    global writer, best_epoch
    # Training settings
    parser = argparse.ArgumentParser(description='Official version of PrismNet')
    # Data options
    parser.add_argument('--data_dir', type=str, default="data", help='data path')
    parser.add_argument('--exp_name', type=str, default="cnn", metavar='N', help='experiment name')
    parser.add_argument('--p_name', type=str, default="TIA1_Hela", metavar='N', help='protein name')
    parser.add_argument('--out_dir', type=str, default=".", help='output directory')
    parser.add_argument('--mode', type=str, default="pu", help='data mode')
    parser.add_argument('--infer_file', type=str, default="", help='infer file')
    # Training hyper-parameters
    parser.add_argument('--arch', default="PrismNet", help='network architecture')
    parser.add_argument('--lr_scheduler', default="warmup", help='lr scheduler: warmup/cosine')
    parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
    parser.add_argument('--nepochs', type=int, default=200, help='number of epochs to train')
    parser.add_argument('--pos_weight', type=int, default=2, help='positive class weight')
    parser.add_argument('--weight_decay', type=float, default=1e-6, help='weight decay, default=1e-6')
    parser.add_argument('--early_stopping', type=int, default=20, help='early stopping')
    # Run modes
    parser.add_argument('--load_best', action='store_true', help='load best model')
    parser.add_argument('--eval', action='store_true', help='eval mode')
    parser.add_argument('--train', action='store_true', help='train mode')
    parser.add_argument('--infer', action='store_true', help='infer mode')
    parser.add_argument('--infer_test', action='store_true', help='infer test from h5')
    parser.add_argument('--eval_test', action='store_true', help='eval test from h5')
    parser.add_argument('--saliency', action='store_true', help='compute saliency mode')
    parser.add_argument('--saliency_img', action='store_true', help='compute saliency and plot image mode')
    parser.add_argument('--har', action='store_true', help='compute highest attention region')
    # Misc
    parser.add_argument('--tfboard', action='store_true', help='tf board')
    parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
    parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
    parser.add_argument('--log_interval', type=int, default=100, help='log print interval')
    parser.add_argument('--seed', type=int, default=1024, help='manual seed')
    args = parser.parse_args()
    print(args)

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    args.nstr = 1 if args.mode == 'pu' else 0

    # Output directories and model path
    data_path = args.data_dir + "/" + args.p_name + ".h5"
    identity = args.p_name + '_' + args.arch + "_" + args.mode
    datautils.make_directory(args.out_dir, "out/")
    model_dir = datautils.make_directory(args.out_dir, "out/models")
    model_path = os.path.join(model_dir, identity + "_{}.pth")

    if args.tfboard:
        tfb_dir = datautils.make_directory(args.out_dir, "out/tfb")
        writer = SummaryWriter(tfb_dir)
    else:
        writer = None

    # Fix random seed
    fix_seed(args.seed)
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': args.workers, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(SeqicSHAPE(data_path),
                                               batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(SeqicSHAPE(data_path, is_test=True),
                                              batch_size=args.batch_size * 8, shuffle=False, **kwargs)
    print("Train set:", len(train_loader.dataset))
    print("Test set:", len(test_loader.dataset))

    print("Network Arch:", args.arch)
    model = getattr(arch, args.arch)(mode=args.mode)
    arch.param_num(model)
    # print(model)

    if args.load_best:
        filename = model_path.format("best")
        print("Loading model: {}".format(filename))
        model.load_state_dict(torch.load(filename, map_location='cpu'))
    model = model.to(device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(args.pos_weight))

    if args.train:
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                     betas=(0.9, 0.999), weight_decay=args.weight_decay)
        scheduler = GradualWarmupScheduler(optimizer, multiplier=8,
                                           total_epoch=float(args.nepochs), after_scheduler=None)

        best_auc = 0
        best_acc = 0
        best_epoch = 0
        for epoch in range(1, args.nepochs + 1):
            t_met = train(args, model, device, train_loader, criterion, optimizer)
            v_met, _, _ = validate(args, model, device, test_loader, criterion)
            scheduler.step(epoch)
            lr = scheduler.get_lr()[0]
            color_best = 'green'
            if best_auc < v_met.auc:
                best_auc = v_met.auc
                best_acc = v_met.acc
                best_epoch = epoch
                color_best = 'red'
                filename = model_path.format("best")
                torch.save(model.state_dict(), filename)
            if epoch - best_epoch > args.early_stopping:
                print("Early stop at %d, %s " % (epoch, args.exp_name))
                break
            if args.tfboard and writer is not None:
                writer.add_scalar('loss/train', t_met.other[0], epoch)
                writer.add_scalar('acc/train', t_met.acc, epoch)
                writer.add_scalar('AUC/train', t_met.auc, epoch)
                writer.add_scalar('lr', lr, epoch)
                writer.add_scalar('loss/test', v_met.other[0], epoch)
                writer.add_scalar('acc/test', v_met.acc, epoch)
                writer.add_scalar('AUC/test', v_met.auc, epoch)
            line = '{} \t Train Epoch: {} avg.loss: {:.4f} Acc: {:.2f}%, AUC: {:.4f} lr: {:.6f}'.format(
                args.p_name, epoch, t_met.other[0], t_met.acc, t_met.auc, lr)
            log_print(line, color='green', attrs=['bold'])
            line = '{} \t Test Epoch: {} avg.loss: {:.4f} Acc: {:.2f}%, AUC: {:.4f} ({:.4f})'.format(
                args.p_name, epoch, v_met.other[0], v_met.acc, v_met.auc, best_auc)
            log_print(line, color=color_best, attrs=['bold'])
        print("{} auc: {:.4f} acc: {:.4f}".format(args.p_name, best_auc, best_acc))

        filename = model_path.format("best")
        print("Loading model: {}".format(filename))
        model.load_state_dict(torch.load(filename))

    if args.eval:
        met, y_all, p_all = validate(args, model, device, test_loader, criterion)
        print("> eval {} auc: {:.4f} acc: {:.4f}".format(args.p_name, met.auc, met.acc))
        save_evals(args.out_dir, identity, args.p_name, p_all, y_all, met)

    if args.infer and os.path.exists(args.infer_file):
        test_loader = torch.utils.data.DataLoader(SeqicSHAPE(args.infer_file, is_infer=True),
                                                  batch_size=args.batch_size, shuffle=False, **kwargs)
        p_all = inference(args, model, device, test_loader)
        identity = identity + "_" + os.path.basename(args.infer_file).replace(".txt", "")
        save_infers(args.out_dir, identity, p_all)

    if args.saliency:
        compute_saliency(args, model, device, test_loader, identity)

    if args.saliency_img:
        compute_saliency_img(args, model, device, test_loader, identity)

    if args.har:
        compute_high_attention_region(args, model, device, test_loader, identity)
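# Illustrative usage only (the script filename, protein name, and file paths
# below are placeholders, not guaranteed defaults):
#   python main.py --train --eval --p_name TIA1_Hela --mode pu --out_dir .
#   python main.py --load_best --infer --infer_file data/some_sites.txt
#   python main.py --load_best --saliency --har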
def compute_saliency_img(args, model, device, test_loader, identity):
    """Compute guided-backprop SmoothGrad saliency and render a per-sample saliency image."""
    from prismnet.model import GuidedBackpropSmoothGrad
    from prismnet.utils import visualize

    def saliency_img(X, mul_saliency, outdir="results"):
        """Generate a saliency image.

        Args:
            X (np.ndarray): raw input (L x 5/4)
            mul_saliency (np.ndarray): saliency matrix, same shape as X
            outdir (str, optional): output image path. Defaults to "results".
        """
        if X.shape[-1] == 5:
            # Mask the structure channel wherever icSHAPE is missing (-1).
            x_str = X[:, 4:]
            str_null = np.zeros_like(x_str)
            ind = np.where(x_str == -1)[0]
            str_null[ind, 0] = 1
            ss = mul_saliency[:, :]
            s_str = mul_saliency[:, 4:]
            s_str = (s_str - s_str.min()) / (s_str.max() - s_str.min())
            ss[:, 4:] = s_str * (1 - str_null)
            str_null = np.squeeze(str_null).T
        else:
            str_null = None
            ss = mul_saliency[:, :]
        visualize.plot_saliency(X.T, ss.T, nt_width=100, norm_factor=3,
                                str_null=str_null, outdir=outdir)

    prefix_n = len(str(len(test_loader.dataset)))
    datautils.make_directory(args.out_dir, "out/imgs/")
    imgs_dir = datautils.make_directory(args.out_dir, "out/imgs/" + identity)
    imgs_path = imgs_dir + '/{:0' + str(prefix_n) + 'd}_{:.3f}.pdf'
    saliency_path = os.path.join(imgs_dir, 'all.sal')

    # sgrad = SmoothGrad(model, device=device)
    sgrad = GuidedBackpropSmoothGrad(model, device=device, magnitude=1)
    for batch_idx, (x0, y0) in enumerate(test_loader):
        X, Y = x0.float().to(device), y0.to(device).float()
        output = model(X)
        prob = torch.sigmoid(output)
        p_np = prob.to(device='cpu').detach().numpy().squeeze()
        guided_saliency = sgrad.get_batch_gradients(X, Y)
        # Weight the sequence channels of the saliency by the one-hot input.
        mul_saliency = copy.deepcopy(guided_saliency)
        mul_saliency[:, :, :, :4] = guided_saliency[:, :, :, :4] * X[:, :, :, :4]
        N, NS, _, _ = guided_saliency.shape  # (N, 101, 1, 5)
        sal = ""
        for i in tqdm(range(N)):
            inr = batch_idx * args.batch_size + i
            str_sal = datautils.mat2str(np.squeeze(guided_saliency[i]))
            sal += "{}\t{:.6f}\t{}\n".format(inr, p_np[i], str_sal)
            img_path = imgs_path.format(inr, p_np[i])
            saliency_img(X[i, 0].to(device='cpu').detach().numpy(),
                         mul_saliency[i, 0].to(device='cpu').numpy(),
                         outdir=img_path)
        if not os.path.exists(saliency_path):
            with open(saliency_path, "w") as f:
                f.write(sal)
    print(saliency_path)