示例#1
0
文件: main.py 项目: kuixu/PrismNet
def save_infers(out_dir, filename, predictions):
    """Dump inference probabilities, one per line, under ``<out_dir>/out/infer``.

    Args:
        out_dir: Root output directory; ``out/infer`` is created beneath it by
            ``datautils.make_directory``.
        filename: Base name of the output file; ``.probs`` is appended.
        predictions: 2-D indexable array (``predictions[i, 0]``); only column 0
            of each row is written, formatted with ``"{:f}"``.
    """
    target = os.path.join(datautils.make_directory(out_dir, "out/infer"),
                          filename + '.probs')
    with open(target, "w") as out_file:
        n_rows = len(predictions)
        for row_idx in range(n_rows):
            value = predictions[row_idx, 0]
            print("{:f}".format(value), file=out_file)
    print("Prediction file:", target)
示例#2
0
def compute_saliency(args, model, device, test_loader, identity):
    """Compute guided-backprop SmoothGrad saliency for every sample in a loader.

    Writes one line per sample to ``<args.out_dir>/out/saliency/<identity>.sal``:
    ``<sample index>\\t<sigmoid probability (6 dp)>\\t<saliency string>``, where
    the saliency string is produced by ``datautils.mat2str``.

    Args:
        args: Parsed CLI namespace; ``args.out_dir`` is read.
        model: Network whose outputs are passed through ``torch.sigmoid``.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(x, y)`` batches, consumed in order.
        identity: Base name of the output file.
    """
    from prismnet.model import GuidedBackpropSmoothGrad

    model.eval()

    saliency_dir = datautils.make_directory(args.out_dir, "out/saliency")
    saliency_path = os.path.join(saliency_dir, identity + '.sal')

    sgrad = GuidedBackpropSmoothGrad(model, device=device)
    rows = []
    # Running sample index. The loader's batch size is not necessarily
    # args.batch_size (main() builds the test loader with args.batch_size*8),
    # so deriving indices from batch_idx * args.batch_size would be wrong.
    offset = 0
    for x0, y0 in test_loader:
        X, Y = x0.float().to(device), y0.to(device).float()
        # The probability is only reported, so no autograd graph is kept here;
        # saliency is computed separately by sgrad.get_batch_gradients below.
        with torch.no_grad():
            prob = torch.sigmoid(model(X))
        # atleast_1d: squeeze() turns a single-sample batch into a 0-d array,
        # which cannot be indexed with p_np[i].
        p_np = np.atleast_1d(prob.to(device='cpu').detach().numpy().squeeze())
        guided_saliency = sgrad.get_batch_gradients(X, Y)
        n_samples = guided_saliency.shape[0]

        for i in range(n_samples):
            str_sal = datautils.mat2str(np.squeeze(guided_saliency[i]))
            rows.append("{}\t{:.6f}\t{}\n".format(offset + i, p_np[i], str_sal))
        offset += n_samples

    with open(saliency_path, "w") as f:
        f.write("".join(rows))
    print(saliency_path)
示例#3
0
def compute_high_attention_region(args, model, device, test_loader, identity):
    """Locate, per sample, the contiguous window with the highest total saliency.

    The guided-backprop SmoothGrad saliency is summed over its last axis to a
    per-position score. The window of length ``L`` (20, or the whole sequence
    if shorter) with the largest summed score is reported. One line per sample
    is written to ``<args.out_dir>/out/har/<identity>.har``:
    ``<sample index>\\t<sigmoid probability (6 dp)>\\t<start>\\t<end>``,
    with ``end = start + window length``.

    Args:
        args: Parsed CLI namespace; ``args.out_dir`` is read.
        model: Network whose outputs are passed through ``torch.sigmoid``.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(x, y)`` batches, consumed in order.
        identity: Base name of the output file.
    """
    from prismnet.model import GuidedBackpropSmoothGrad

    # Consistent with compute_saliency: score with the model in eval mode.
    model.eval()

    har_dir = datautils.make_directory(args.out_dir, "out/har")
    har_path = os.path.join(har_dir, identity + '.har')

    L = 20  # length of the reported high-attention window
    rows = []
    sgrad = GuidedBackpropSmoothGrad(model, device=device)
    # Running sample index; the loader's batch size need not equal
    # args.batch_size (main() uses args.batch_size*8 for the test loader).
    offset = 0
    for x0, y0 in test_loader:
        X, Y = x0.float().to(device), y0.to(device).float()
        with torch.no_grad():
            prob = torch.sigmoid(model(X))
        # atleast_1d keeps a single-sample batch indexable.
        p_np = np.atleast_1d(prob.to(device='cpu').detach().numpy().squeeze())
        guided_saliency = sgrad.get_batch_gradients(X, Y)

        # Sum over the last (channel) axis and take index 0 of axis 1,
        # giving one score per position: shape (N, NS).
        attention_region = guided_saliency.sum(dim=3)[:, 0, :].to(
            device='cpu').numpy()
        N, NS = attention_region.shape
        window = min(L, NS)  # avoid an empty score array for short inputs
        for i in range(N):
            iar = attention_region[i]
            # Summed score of every window start position j in [0, NS - window].
            ar_score = np.array(
                [iar[j:j + window].sum() for j in range(NS - window + 1)])
            highest_ind = int(np.argmax(ar_score))
            rows.append("{}\t{:.6f}\t{}\t{}\n".format(offset + i, p_np[i],
                                                      highest_ind,
                                                      highest_ind + window))
        offset += N

    with open(har_path, "w") as f:
        f.write("".join(rows))
    print(har_path)
示例#4
0
文件: main.py 项目: kuixu/PrismNet
def save_evals(out_dir, filename, dataname, predictions, label, met):
    """Persist evaluation metrics and per-sample predictions under ``out/evals``.

    Two files are written in ``<out_dir>/out/evals``:

    * ``<filename>.metrics``: a single tab-separated line with dataname,
      acc, auc, prc (3 dp) and tp, tn, fp, fn (``{:d}``). When ``"_reg"``
      occurs in ``filename``, ``met.avg[7]``, ``met.avg[8]`` and ``met.avg[9]``
      (3 dp) are appended.
    * ``<filename>.probs``: one line per sample with column 0 of
      ``predictions`` (3 dp), a tab, then column 0 of ``label``.

    Args:
        out_dir: Root output directory.
        filename: Base name of both output files.
        dataname: Label written as the first metrics field.
        predictions: 2-D indexable array of predicted values.
        label: 2-D indexable array of ground-truth labels.
        met: Metrics object exposing ``acc``, ``auc``, ``prc``, ``tp``, ``tn``,
            ``fp``, ``fn`` and, for ``_reg`` runs, an indexable ``avg``.
    """
    evals_dir = datautils.make_directory(out_dir, "out/evals")
    metrics_path = os.path.join(evals_dir, filename + '.metrics')
    probs_path = os.path.join(evals_dir, filename + '.probs')

    with open(metrics_path, "w") as metrics_file:
        is_reg = "_reg" in filename
        fields = [dataname, met.acc, met.auc, met.prc,
                  met.tp, met.tn, met.fp, met.fn]
        if is_reg:
            template = "{:s}\t{:.3f}\t{:.3f}\t{:.3f}\t{:d}\t{:d}\t{:d}\t{:d}\t{:.3f}\t{:.3f}\t{:.3f}"
            fields.extend((met.avg[7], met.avg[8], met.avg[9]))
        else:
            template = "{:s}\t{:.3f}\t{:.3f}\t{:.3f}\t{:d}\t{:d}\t{:d}\t{:d}"
        print(template.format(*fields), file=metrics_file)

    with open(probs_path, "w") as probs_file:
        for idx in range(len(predictions)):
            pred, truth = predictions[idx, 0], label[idx, 0]
            print("{:.3f}\t{}".format(pred, truth), file=probs_file)

    print("Evaluation file:", metrics_path)
    print("Prediction file:", probs_path)
示例#5
0
文件: main.py 项目: kuixu/PrismNet
def main():
    """Command-line entry point: train, evaluate, infer and interpret a model.

    Parses CLI options, then:

    1. Builds train/test ``DataLoader``s over ``<data_dir>/<p_name>.h5``.
    2. Instantiates ``arch.<args.arch>``. With ``--load_best`` it loads
       ``out/models/<identity>_best.pth``.
    3. With ``--train``, trains with Adam + ``GradualWarmupScheduler``. It saves
       the checkpoint with the best validation AUC and stops early after
       ``--early_stopping`` epochs without improvement, then reloads that
       best checkpoint.
    4. Optionally runs ``--eval``, ``--infer``, ``--saliency``,
       ``--saliency_img`` and ``--har``, in that order.

    Side effects: publishes ``writer`` and ``best_epoch`` as module globals
    (presumably read by other functions in this module; TODO confirm) and
    writes files under ``<out_dir>/out``.
    """
    global writer, best_epoch
    # Training settings
    parser = argparse.ArgumentParser(
        description='Official version of PrismNet')
    # Data options
    parser.add_argument('--data_dir',
                        type=str,
                        default="data",
                        help='data path')
    parser.add_argument('--exp_name',
                        type=str,
                        default="cnn",
                        metavar='N',
                        help='experiment name')
    parser.add_argument('--p_name',
                        type=str,
                        default="TIA1_Hela",
                        metavar='N',
                        help='protein name')
    parser.add_argument('--out_dir',
                        type=str,
                        default=".",
                        help='output directory')
    parser.add_argument('--mode', type=str, default="pu", help='data mode')
    parser.add_argument("--infer_file",
                        type=str,
                        help="infer file",
                        default="")
    # Training Hyper-parameter
    parser.add_argument('--arch',
                        default="PrismNet",
                        help='network architecture')
    # NOTE(review): --lr_scheduler is not read in main(); a
    # GradualWarmupScheduler is always built below. Confirm whether other
    # code (e.g. train(), which receives args) consumes it.
    parser.add_argument('--lr_scheduler',
                        default="warmup",
                        help=' lr scheduler: warmup/cosine')
    parser.add_argument('--lr',
                        type=float,
                        default=0.0001,
                        help='learning rate')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size')
    parser.add_argument('--nepochs',
                        type=int,
                        default=200,
                        help='number of epochs to train')
    parser.add_argument('--pos_weight',
                        type=int,
                        default=2,
                        help='positive class weight')
    parser.add_argument('--weight_decay',
                        type=float,
                        default=1e-6,
                        help='weight decay, default=1e-6')
    parser.add_argument('--early_stopping',
                        type=int,
                        default=20,
                        help='early stopping')
    # Training
    parser.add_argument('--load_best',
                        action='store_true',
                        help='load best model')
    parser.add_argument('--eval', action='store_true', help='eval mode')
    parser.add_argument('--train', action='store_true', help='train mode')
    parser.add_argument('--infer', action='store_true', help='infer mode')
    parser.add_argument('--infer_test',
                        action='store_true',
                        help='infer test from h5')
    parser.add_argument('--eval_test',
                        action='store_true',
                        help='eval test from h5')
    parser.add_argument('--saliency',
                        action='store_true',
                        help='compute saliency mode')
    parser.add_argument('--saliency_img',
                        action='store_true',
                        help='compute saliency and plot image mode')
    parser.add_argument('--har',
                        action='store_true',
                        help='compute highest attention region')
    # misc
    parser.add_argument('--tfboard', action='store_true', help='tf board')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--workers',
                        type=int,
                        help='number of data loading workers',
                        default=2)
    parser.add_argument('--log_interval',
                        type=int,
                        default=100,
                        help='log print interval')
    parser.add_argument('--seed', type=int, default=1024, help='manual seed')
    args = parser.parse_args()
    print(args)
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    # args.nstr is 1 for mode "pu" and 0 otherwise. It is not read in main();
    # presumably consumed by data/model code via args. TODO confirm.
    if args.mode == 'pu':
        args.nstr = 1
    else:
        args.nstr = 0

    # out dir
    data_path = args.data_dir + "/" + args.p_name + ".h5"
    # identity names every output file: <p_name>_<arch>_<mode>.
    identity = args.p_name + '_' + args.arch + "_" + args.mode
    datautils.make_directory(args.out_dir, "out/")
    model_dir = datautils.make_directory(args.out_dir, "out/models")
    # Checkpoint path template; "{}" is filled with "best" below.
    model_path = os.path.join(model_dir, identity + "_{}.pth")

    if args.tfboard:
        tfb_dir = datautils.make_directory(args.out_dir, "out/tfb")
        writer = SummaryWriter(tfb_dir)
    else:
        writer = None
    # fix random seed
    # Seeded before the loaders and the model are constructed, so that any
    # randomness they consume happens after seeding. Keep this order.
    fix_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")
    # Worker processes and pinned memory are only requested when using CUDA.
    kwargs = {
        'num_workers': args.workers,
        'pin_memory': True
    } if use_cuda else {}

    train_loader = torch.utils.data.DataLoader(SeqicSHAPE(data_path), \
        batch_size=args.batch_size, shuffle=True,  **kwargs)

    # The test loader uses a batch size 8x the training batch size.
    test_loader  = torch.utils.data.DataLoader(SeqicSHAPE(data_path, is_test=True), \
        batch_size=args.batch_size*8, shuffle=False, **kwargs)
    print("Train set:", len(train_loader.dataset))
    print("Test  set:", len(test_loader.dataset))

    print("Network Arch:", args.arch)
    model = getattr(arch, args.arch)(mode=args.mode)
    arch.param_num(model)
    # print(model)

    if args.load_best:
        filename = model_path.format("best")
        print("Loading model: {}".format(filename))
        # map_location='cpu' lets a GPU-saved checkpoint load on any machine;
        # the model is moved to `device` right after.
        model.load_state_dict(torch.load(filename, map_location='cpu'))

    model = model.to(device)
    # pos_weight scales the loss contribution of positive examples.
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(args.pos_weight))

    if args.train:
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=args.lr,
                                     betas=(0.9, 0.999),
                                     weight_decay=args.weight_decay)
        # Presumably warms the LR up to multiplier * args.lr over total_epoch
        # epochs, with no follow-up scheduler. Confirm against the
        # GradualWarmupScheduler implementation.
        scheduler = GradualWarmupScheduler(optimizer,
                                           multiplier=8,
                                           total_epoch=float(args.nepochs),
                                           after_scheduler=None)

        best_auc = 0
        best_acc = 0
        best_epoch = 0
        for epoch in range(1, args.nepochs + 1):
            t_met = train(args, model, device, train_loader, criterion,
                          optimizer)
            v_met, _, _ = validate(args, model, device, test_loader, criterion)
            scheduler.step(epoch)
            lr = scheduler.get_lr()[0]
            color_best = 'green'
            # Checkpoint whenever validation AUC strictly improves.
            if best_auc < v_met.auc:
                best_auc = v_met.auc
                best_acc = v_met.acc
                best_epoch = epoch
                color_best = 'red'
                filename = model_path.format("best")
                torch.save(model.state_dict(), filename)
            # Stop once more than args.early_stopping epochs have passed
            # since the last improvement.
            if epoch - best_epoch > args.early_stopping:
                print("Early stop at %d, %s " % (epoch, args.exp_name))
                break

            if args.tfboard and writer is not None:
                writer.add_scalar('loss/train', t_met.other[0], epoch)
                writer.add_scalar('acc/train', t_met.acc, epoch)
                writer.add_scalar('AUC/train', t_met.auc, epoch)
                writer.add_scalar('lr', lr, epoch)
                writer.add_scalar('loss/test', v_met.other[0], epoch)
                writer.add_scalar('acc/test', v_met.acc, epoch)
                writer.add_scalar('AUC/test', v_met.auc, epoch)
            line='{} \t Train Epoch: {}     avg.loss: {:.4f} Acc: {:.2f}%, AUC: {:.4f} lr: {:.6f}'.format(\
                args.p_name, epoch, t_met.other[0], t_met.acc, t_met.auc, lr)
            log_print(line, color='green', attrs=['bold'])

            line='{} \t Test  Epoch: {}     avg.loss: {:.4f} Acc: {:.2f}%, AUC: {:.4f} ({:.4f})'.format(\
                args.p_name, epoch, v_met.other[0], v_met.acc, v_met.auc, best_auc)
            log_print(line, color=color_best, attrs=['bold'])

        print("{} auc: {:.4f} acc: {:.4f}".format(args.p_name, best_auc,
                                                  best_acc))

        # Restore the best checkpoint so later steps use it, not the last epoch.
        # NOTE(review): if validation AUC never exceeds 0 (e.g. stays 0 or
        # NaN), no checkpoint was saved above and this load fails.
        filename = model_path.format("best")
        print("Loading model: {}".format(filename))
        model.load_state_dict(torch.load(filename))

    if args.eval:
        met, y_all, p_all = validate(args, model, device, test_loader,
                                     criterion)
        print("> eval {} auc: {:.4f} acc: {:.4f}".format(
            args.p_name, met.auc, met.acc))
        save_evals(args.out_dir, identity, args.p_name, p_all, y_all, met)

    # Inference is silently skipped when args.infer_file does not exist.
    # When it runs, test_loader and identity are REPLACED, so the saliency /
    # HAR steps below operate on the inference file, not the h5 test split.
    if args.infer and os.path.exists(args.infer_file):
        test_loader  = torch.utils.data.DataLoader(SeqicSHAPE(args.infer_file, is_infer=True), \
            batch_size=args.batch_size, shuffle=False, **kwargs)

        p_all = inference(args, model, device, test_loader)
        identity = identity + "_" + os.path.basename(args.infer_file).replace(
            ".txt", "")
        save_infers(args.out_dir, identity, p_all)

    if args.saliency:
        compute_saliency(args, model, device, test_loader, identity)

    if args.saliency_img:
        compute_saliency_img(args, model, device, test_loader, identity)

    if args.har:
        compute_high_attention_region(args, model, device, test_loader,
                                      identity)
示例#6
0
def compute_saliency_img(args, model, device, test_loader, identity):
    """Plot a saliency figure per sample and collect raw saliency into ``all.sal``.

    For every sample, a PDF named ``<index>_<probability>.pdf`` is written to
    ``<args.out_dir>/out/imgs/<identity>/``. The index is zero-padded to the
    width of the dataset size. A line
    ``<index>\\t<probability (6 dp)>\\t<saliency string>`` is also collected
    for every sample across ALL batches. The collected lines are written to
    ``all.sal`` in the same directory only if that file does not exist yet.

    Args:
        args: Parsed CLI namespace; ``args.out_dir`` is read.
        model: Network whose outputs are passed through ``torch.sigmoid``.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(x, y)`` batches, consumed in order; its
            ``dataset`` length sets the index padding width.
        identity: Name of the per-run image sub-directory.
    """
    from prismnet.model import GuidedBackpropSmoothGrad
    from prismnet.utils import visualize

    def saliency_img(X, mul_saliency, outdir="results"):
        """Generate the saliency image of a single sample.

        Args:
            X (np.ndarray): raw input of shape (L, C), C being 5 or 4. When
                C == 5, column 4 is treated as a structure channel in which
                -1 marks positions whose structure saliency is zeroed.
            mul_saliency (np.ndarray): saliency of the same shape as ``X``.
                Note that ``ss`` below is a view, so this array is modified
                in place.
            outdir (str, optional): output path passed to
                ``visualize.plot_saliency``. Defaults to "results".
        """
        if X.shape[-1] == 5:
            x_str = X[:, 4:]
            str_null = np.zeros_like(x_str)
            ind = np.where(x_str == -1)[0]
            str_null[ind, 0] = 1

            ss = mul_saliency[:, :]
            s_str = mul_saliency[:, 4:]
            # Min-max normalise the structure saliency. A constant column would
            # divide by zero (NaN), so it maps to zeros instead.
            s_min = s_str.min()
            s_span = s_str.max() - s_min
            if s_span > 0:
                s_str = (s_str - s_min) / s_span
            else:
                s_str = np.zeros_like(s_str)
            ss[:, 4:] = s_str * (1 - str_null)

            str_null = np.squeeze(str_null).T
        else:
            str_null = None
            ss = mul_saliency[:, :]

        visualize.plot_saliency(X.T,
                                ss.T,
                                nt_width=100,
                                norm_factor=3,
                                str_null=str_null,
                                outdir=outdir)

    prefix_n = len(str(len(test_loader.dataset)))
    datautils.make_directory(args.out_dir, "out/imgs/")
    imgs_dir = datautils.make_directory(args.out_dir, "out/imgs/" + identity)
    imgs_path = imgs_dir + '/{:0' + str(prefix_n) + 'd}_{:.3f}.pdf'
    saliency_path = os.path.join(imgs_dir, 'all.sal')

    # Consistent with compute_saliency: score with the model in eval mode.
    model.eval()
    sgrad = GuidedBackpropSmoothGrad(model, device=device, magnitude=1)
    # Accumulated across all batches (previously reset per batch, which left
    # only the last batch in all.sal).
    rows = []
    # Running sample index; the loader's batch size need not equal
    # args.batch_size (main() uses args.batch_size*8 for the test loader).
    offset = 0
    for x0, y0 in test_loader:
        X, Y = x0.float().to(device), y0.to(device).float()
        with torch.no_grad():
            prob = torch.sigmoid(model(X))
        # atleast_1d keeps a single-sample batch indexable.
        p_np = np.atleast_1d(prob.to(device='cpu').detach().numpy().squeeze())
        guided_saliency = sgrad.get_batch_gradients(X, Y)
        # Gradient x input for the first four channels; remaining channels keep
        # the raw gradient.
        mul_saliency = copy.deepcopy(guided_saliency)
        mul_saliency[:, :, :, :4] = guided_saliency[:, :, :, :4] * X[:, :, :, :4]
        n_samples = guided_saliency.shape[0]
        for i in tqdm(range(n_samples)):
            inr = offset + i
            str_sal = datautils.mat2str(np.squeeze(guided_saliency[i]))
            rows.append("{}\t{:.6f}\t{}\n".format(inr, p_np[i], str_sal))
            img_path = imgs_path.format(inr, p_np[i])
            saliency_img(X[i, 0].to(device='cpu').detach().numpy(),
                         mul_saliency[i, 0].to(device='cpu').numpy(),
                         outdir=img_path)
        offset += n_samples

    # An existing all.sal is never overwritten.
    if not os.path.exists(saliency_path):
        with open(saliency_path, "w") as f:
            f.write("".join(rows))
        print(saliency_path)