Esempio n. 1
0
def Test():
    """Evaluate the three-region (left lung, right lung, heart) ensemble on the test set.

    For every region a UNet produces a mask from the input image, ``ROIGeneration``
    builds the region input from the image and that mask, and a region-specific
    ``CXRClassifier`` scores it. The ensemble score for each class is the
    element-wise maximum over the three region classifiers' outputs. Per-class
    and mean AUROC on the test set are printed.

    Depends on module-level names defined elsewhere in the project: ``args``,
    ``config``, ``N_CLASSES``, ``CLASS_NAMES``, ``get_test_dataloader``,
    ``CXRClassifier``, ``UNet``, ``ROIGeneration`` and ``compute_AUCs``.
    Returns ``None``; returns early if ``args.model`` is not ``'CXRNet'``.
    """
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    if args.model != 'CXRNet':
        print('No required model')
        return

    def _load(model, filename, desc):
        """Load ``config['CKPT_PATH'] + filename`` into *model*, switch it to eval mode and return it."""
        ckpt_path = config['CKPT_PATH'] + filename
        model.load_state_dict(torch.load(ckpt_path))
        print("=> loaded " + desc + " checkpoint: " + ckpt_path)
        model.eval()
        return model

    # One (segmentation UNet, region classifier) pair per anatomical region.
    # Each UNet loads its own region checkpoint 'best_unet_<region>.pkl'.
    regions = []
    for region in ('left', 'right', 'heart'):
        classifier = _load(CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda(),
                           region + '_model.pkl', region + ' model')
        unet = _load(UNet(n_channels=3, n_classes=1).cuda(),
                     'best_unet_' + region + '.pkl', 'well-trained unet model')
        regions.append((unet, classifier))
    torch.backends.cudnn.benchmark = True  # improve speed slightly

    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    preds = torch.FloatTensor().cuda()
    # Every model was already switched to eval mode by _load above.
    with torch.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            gt = torch.cat((gt, label.cuda()), 0)
            var_image = image.cuda()
            region_outs = []
            for unet, classifier in regions:
                mask = unet(var_image)
                # ROIGeneration receives the original (CPU-side) image batch.
                roi = ROIGeneration(image, mask)
                region_outs.append(classifier(roi.cuda()))
            # Stack the per-region outputs on a new leading axis and keep the
            # maximum score per sample/class (assumes each output is
            # (batch, N_CLASSES) -- TODO confirm).
            pred = torch.stack(region_outs, 0).max(0)[0]
            preds = torch.cat((preds, pred), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx+1))
            sys.stdout.flush()

    # for evaluation
    AUROC_img = compute_AUCs(gt, preds)
    AUROC_avg = np.array(AUROC_img).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_img[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))
Esempio n. 2
0
def Test():
    """Evaluate the image-level, ROI-level and fusion classifiers on the test set.

    The image-level ``CXRClassifier`` scores each image. Its scores are binarized
    with the per-class thresholds in ``classes_threshold_common`` and, together
    with its conv features and classifier weights, are passed to
    ``ROIGenerator.ROIGeneration`` to build ROI inputs. The ROI-level classifier
    scores those. The fusion classifier scores the concatenated fc features of
    both. Per-class and mean AUROC are printed for each of the three models,
    followed by the thresholds returned by ``compute_ROCCurve`` on the fusion
    predictions.

    Depends on module-level names defined elsewhere in the project (``args``,
    ``config``, ``N_CLASSES``, ``CLASS_NAMES``, ``classes_threshold_common``,
    ``get_test_dataloader``, ``CXRClassifier``, ``FusionClassifier``,
    ``ROIGenerator``, ``compute_AUCs``, ``compute_ROCCurve``).
    Returns ``None``; returns early if ``args.model`` is not ``'CXRNet'``.
    """
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    if args.model != 'CXRNet':
        print('No required model')
        return

    def _load(model, filename, desc):
        """Load ``config['CKPT_PATH'] + filename`` into *model*, switch it to eval mode and return it."""
        ckpt_path = config['CKPT_PATH'] + filename
        model.load_state_dict(torch.load(ckpt_path))
        print("=> loaded " + desc + " model checkpoint: " + ckpt_path)
        model.eval()
        return model

    model_img = _load(CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=False).cuda(),
                      'img_model.pkl', 'Image')
    model_roi = _load(CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=True).cuda(),
                      'roi_model.pkl', 'ROI')
    model_fusion = _load(FusionClassifier(input_size=2048, output_size=N_CLASSES).cuda(),
                         'fusion_model.pkl', 'Fusion')
    roigen = ROIGenerator()  # region generator

    torch.backends.cudnn.benchmark = True  # improve speed slightly
    print('******************** load model succeed!********************')

    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    pred_img = torch.FloatTensor().cuda()
    pred_roi = torch.FloatTensor().cuda()
    pred_fusion = torch.FloatTensor().cuda()

    # Loop-invariant inputs to the ROI generator, computed once.
    # NOTE(review): index -5 depends on CXRClassifier's parameter order --
    # confirm it selects the intended classifier weights.
    weight_softmax = np.squeeze(list(model_img.parameters())[-5].data.cpu().numpy())
    # Per-class decision thresholds; assumes one entry per output class, in
    # class order (sliced to N_CLASSES as the original only indexed that many).
    class_thresholds = np.asarray(classes_threshold_common)[:N_CLASSES]

    with torch.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            gt = torch.cat((gt, label.cuda()), 0)
            # image-level
            conv_fea_img, fc_fea_img, out_img = model_img(image.cuda())
            pred_img = torch.cat((pred_img, out_img), 0)
            # ROI-level: binarize the image-level scores with the per-class
            # thresholds (1.0 where score > threshold, else 0.0) as float32.
            pdlabel = (out_img.cpu().numpy() > class_thresholds).astype(np.float32)
            roi = roigen.ROIGeneration(image, conv_fea_img, weight_softmax, pdlabel)
            _, fc_fea_roi, out_roi = model_roi(roi.cuda())
            pred_roi = torch.cat((pred_roi, out_roi), 0)
            # fusion of image-level and ROI-level fc features
            out_fusion = model_fusion(torch.cat((fc_fea_img, fc_fea_roi), 1))
            pred_fusion = torch.cat((pred_fusion, out_fusion), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx+1))
            sys.stdout.flush()

    # for evaluation: one labelled report per model
    for branch, branch_preds in (('Image', pred_img), ('ROI', pred_roi), ('Fusion', pred_fusion)):
        aurocs = compute_AUCs(gt, branch_preds)
        print('=== {} model ==='.format(branch))
        for i in range(N_CLASSES):
            print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], aurocs[i]))
        print('The average AUROC is {:.4f}'.format(np.array(aurocs).mean()))

    # Evaluating the threshold of prediction
    thresholds = compute_ROCCurve(gt, pred_fusion)
    print(thresholds)
Esempio n. 3
0
def Test():
    """Evaluate the image-level, ROI-level and fusion classifiers on the test set.

    The image-level ``CXRClassifier`` scores each image. Its conv features are
    passed to ``ROIGenerator.ROIGeneration`` to build ROI inputs, which the
    ROI-level classifier scores. The fusion classifier scores the concatenated
    fc features of both. Per-class and mean AUROC are printed for each model.

    Depends on module-level names defined elsewhere in the project (``args``,
    ``config``, ``N_CLASSES``, ``CLASS_NAMES``, ``get_test_dataloader``,
    ``CXRClassifier``, ``FusionClassifier``, ``ROIGenerator``, ``compute_AUCs``).
    Returns ``None``; returns early if ``args.model`` is not ``'CXRNet'``.
    """
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    if args.model != 'CXRNet':
        print('No required model')
        return

    def _load(model, filename, desc):
        """Load ``config['CKPT_PATH'] + filename`` into *model*, switch it to eval mode and return it."""
        ckpt_path = config['CKPT_PATH'] + filename
        model.load_state_dict(torch.load(ckpt_path))
        print("=> loaded " + desc + " model checkpoint: " + ckpt_path)
        model.eval()
        return model

    model_img = _load(CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=False).cuda(),
                      'img_model.pkl', 'Image')
    model_roi = _load(CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=True).cuda(),
                      'roi_model.pkl', 'ROI')
    model_fusion = _load(FusionClassifier(input_size=2048, output_size=N_CLASSES).cuda(),
                         'fusion_model.pkl', 'Fusion')
    roigen = ROIGenerator()  # region generator

    torch.backends.cudnn.benchmark = True  # improve speed slightly
    print('******************** load model succeed!********************')

    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    pred_img = torch.FloatTensor().cuda()
    pred_roi = torch.FloatTensor().cuda()
    pred_fusion = torch.FloatTensor().cuda()
    with torch.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            gt = torch.cat((gt, label.cuda()), 0)
            # image-level
            conv_fea_img, fc_fea_img, out_img = model_img(image.cuda())
            pred_img = torch.cat((pred_img, out_img), 0)
            # ROI-level. NOTE(review): this ROIGeneration call takes only the
            # conv features -- confirm it matches ROIGenerator's signature.
            roi = roigen.ROIGeneration(conv_fea_img)
            # CXRClassifier returns (conv features, fc features, output).
            _, fc_fea_roi, out_roi = model_roi(roi.cuda())
            pred_roi = torch.cat((pred_roi, out_roi), 0)
            # fusion of image-level and ROI-level fc features
            out_fusion = model_fusion(torch.cat((fc_fea_img, fc_fea_roi), 1))
            pred_fusion = torch.cat((pred_fusion, out_fusion), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx+1))
            sys.stdout.flush()

    # for evaluation: one labelled report per model
    for branch, branch_preds in (('Image', pred_img), ('ROI', pred_roi), ('Fusion', pred_fusion)):
        aurocs = compute_AUCs(gt, branch_preds)
        print('=== {} model ==='.format(branch))
        for i in range(N_CLASSES):
            print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], aurocs[i]))
        print('The average AUROC is {:.4f}'.format(np.array(aurocs).mean()))