Example #1
def Test():
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'],
                                          shuffle=False,
                                          num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model = CXRNet(num_classes=N_CLASSES, is_pre_trained=True).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'best_model_CXRNet.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model.load_state_dict(checkpoint)  #strict=False
        print("=> loaded Image model checkpoint: " + CKPT_PATH)
        torch.backends.cudnn.benchmark = True  # improve train speed slightly

        model_unet = UNet(n_channels=3, n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet.load_state_dict(checkpoint)  #strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet.eval()
    else:
        print('No required model')
        return  #over
    print('******************** load model succeed!********************')

    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()
    with torch.autograd.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            gt = torch.cat((gt, label.cuda()), 0)
            var_image = torch.autograd.Variable(image).cuda()
            var_label = torch.autograd.Variable(label).cuda()
            var_mask = model_unet(var_image)
            var_output = model(var_image, var_mask)  #forward
            pred = torch.cat((pred, var_output.data), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx + 1))
            sys.stdout.flush()

    #for evaluation
    AUROC_all = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROC_all).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_all[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))
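
(Note: compute_AUCs is called in every example but defined elsewhere; the sketch below is only an assumption of its behavior, computing one AUROC per label column with scikit-learn and the same module-level N_CLASSES the examples rely on.)

from sklearn.metrics import roc_auc_score

def compute_AUCs(gt, pred):
    # assumed helper: per-class AUROC over the N_CLASSES label columns
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    return [roc_auc_score(gt_np[:, i], pred_np[:, i]) for i in range(N_CLASSES)]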
Example #2
def Train():
    print('********************load data********************')
    dataloader_train = get_train_dataloader(batch_size=config['BATCH_SIZE'],
                                            shuffle=True,
                                            num_workers=8)
    dataloader_val = get_validation_dataloader(batch_size=config['BATCH_SIZE'],
                                               shuffle=True,
                                               num_workers=8)
    #dataloader_train, dataloader_val = get_train_val_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8, split_ratio=0.1)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':

        model = CXRClassifier(num_classes=N_CLASSES,
                              is_pre_trained=True).cuda()  #initialize model
        optimizer_model = optim.Adam(model.parameters(),
                                     lr=1e-3,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-5)
        lr_scheduler_model = lr_scheduler.StepLR(optimizer_model,
                                                 step_size=10,
                                                 gamma=1)

        #for left_lung
        model_unet_left = UNet(n_channels=3,
                               n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_left.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_left.load_state_dict(checkpoint)  #strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_left.eval()
        #for right lung
        model_unet_right = UNet(n_channels=3,
                                n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_right.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_right.load_state_dict(checkpoint)  #strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_right.eval()
        #for heart
        model_unet_heart = UNet(n_channels=3,
                                n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_heart.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_heart.load_state_dict(checkpoint)  #strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_heart.eval()
    else:
        print('No required model')
        return  #over

    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    bce_criterion = nn.BCELoss()  #define binary cross-entropy loss
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch + 1, config['MAX_EPOCHS']))
        print('-' * 10)
        model.train()  #set model to training mode
        train_loss = []
        with torch.autograd.enable_grad():
            for batch_idx, (image, label) in enumerate(dataloader_train):
                optimizer_model.zero_grad()
                #pathological regions generation
                var_image = torch.autograd.Variable(image).cuda()
                mask_left = model_unet_left(var_image)  #for left lung
                mask_right = model_unet_right(var_image)  #for right lung
                mask_heart = model_unet_heart(var_image)  #for heart
                patchs, patch_labels, globals, global_labels = ROIGeneration(
                    image, [mask_left, mask_right, mask_heart], label)
                #training
                loss_patch, loss_global = torch.FloatTensor(
                    [0.0]).cuda(), torch.FloatTensor([0.0]).cuda()
                if len(patchs) > 0:
                    var_patchs = torch.autograd.Variable(patchs).cuda()
                    var_patch_labels = torch.autograd.Variable(
                        patch_labels).cuda()
                    out_patch = model(var_patchs, is_patch=True)  #forward
                    loss_patch = bce_criterion(out_patch, var_patch_labels)
                """
                if len(globals)>0:
                    var_globals = torch.autograd.Variable(globals).cuda()
                    var_global_labels = torch.autograd.Variable(global_labels).cuda()
                    out_global = model(var_globals, is_patch = False)#forward
                    loss_global = bce_criterion(out_global, var_global_labels)
                """
                loss_tensor = loss_patch + loss_global
                loss_tensor.backward()
                optimizer_model.step()
                train_loss.append(loss_tensor.item())
                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : train loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        lr_scheduler_model.step()  #about lr and gamma
        print("\r Eopch: %5d train loss = %.6f" %
              (epoch + 1, np.mean(train_loss)))

        model.eval()  #turn to test mode
        val_loss = []
        gt = torch.FloatTensor().cuda()
        pred = torch.FloatTensor().cuda()
        with torch.autograd.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                #pathological regions generation
                var_image = torch.autograd.Variable(image).cuda()
                mask_left = model_unet_left(var_image)  #for left lung
                mask_right = model_unet_right(var_image)  #for right lung
                mask_heart = model_unet_heart(var_image)  #for heart
                patchs, patch_labels, globals, global_labels = ROIGeneration(
                    image, [mask_left, mask_right, mask_heart], label)
                #validation
                loss_patch, loss_global = torch.FloatTensor(
                    [0.0]).cuda(), torch.FloatTensor([0.0]).cuda()
                if len(patchs) > 0:
                    var_patchs = torch.autograd.Variable(patchs).cuda()
                    var_patch_labels = torch.autograd.Variable(
                        patch_labels).cuda()
                    out_patch = model(var_patchs, is_patch=True)  #forward
                    loss_patch = bce_criterion(out_patch, var_patch_labels)
                    gt = torch.cat((gt, patch_labels.cuda()), 0)
                    pred = torch.cat((pred, out_patch.data), 0)
                """
                if len(globals)>0:
                    var_globals = torch.autograd.Variable(globals).cuda()
                    var_global_labels = torch.autograd.Variable(global_labels).cuda()
                    out_global = model(var_globals, is_patch = False)#forward
                    loss_global = bce_criterion(out_global, var_global_labels)
                    gt = torch.cat((gt, global_labels.cuda()), 0)
                    pred = torch.cat((pred, out_global.data), 0)
                """
                loss_tensor = loss_patch + loss_global
                val_loss.append(loss_tensor.item())
                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : validation loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        #evaluation
        AUROCs_avg = np.array(compute_AUCs(gt, pred)).mean()
        print("\r Eopch: %5d validation loss = %.6f, average AUROC=%.4f" %
              (epoch + 1, np.mean(val_loss), AUROCs_avg))

        #save checkpoint
        if AUROC_best < AUROCs_avg:
            AUROC_best = AUROCs_avg
            torch.save(model.state_dict(),
                       config['CKPT_PATH'] + 'best_model_CXRNet.pkl')
            print(' Epoch: {} model has been saved!'.format(epoch + 1))

        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(
            epoch + 1, time_elapsed // 60, time_elapsed % 60))
Example #3
def Test():
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'],
                                          shuffle=False,
                                          num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model = CXRClassifier(num_classes=N_CLASSES,
                              is_pre_trained=True).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'best_model_CXRNet.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model.load_state_dict(checkpoint)  #strict=False
        print("=> loaded left model checkpoint: " + CKPT_PATH)
        model.eval()
        #for left lung
        model_unet_left = UNet(n_channels=3,
                               n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_left.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_unet_left.load_state_dict(checkpoint)  #strict=False
        print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_left.eval()
        #for right lung
        model_unet_right = UNet(n_channels=3,
                                n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_right.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_unet_right.load_state_dict(checkpoint)  #strict=False
        print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_right.eval()
        #for heart
        model_unet_heart = UNet(n_channels=3,
                                n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_heart.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_unet_heart.load_state_dict(checkpoint)  #strict=False
        print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_heart.eval()

    else:
        print('No required model')
        return  #over
    torch.backends.cudnn.benchmark = True  # improve train speed slightly

    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()
    with torch.autograd.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            #pathological regions generation
            var_image = torch.autograd.Variable(image).cuda()
            mask_left = model_unet_left(var_image)  #for left lung
            mask_right = model_unet_right(var_image)  #for right lung
            mask_heart = model_unet_heart(var_image)  #for heart
            patchs, patch_labels, globals, global_labels = ROIGeneration(
                image, [mask_left, mask_right, mask_heart], label)
            #testing
            if len(patchs) > 0:
                var_patchs = torch.autograd.Variable(patchs).cuda()
                var_patch_labels = torch.autograd.Variable(patch_labels).cuda()
                out_patch = model(var_patchs, is_patch=True)  #forward
                gt = torch.cat((gt, patch_labels.cuda()), 0)
                pred = torch.cat((pred, out_patch.data), 0)
            if len(globals) > 0:
                var_globals = torch.autograd.Variable(globals).cuda()
                var_global_labels = torch.autograd.Variable(
                    global_labels).cuda()
                out_global = model(var_globals, is_patch=False)  #forward
                gt = torch.cat((gt, global_labels.cuda()), 0)
                pred = torch.cat((pred, out_global.data), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx + 1))
            sys.stdout.flush()

    #for evaluation
    AUROC_img = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROC_img).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_img[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))
Example #4
def Train():
    print('********************load data********************')
    dataloader_train = get_train_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    dataloader_val = get_validation_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    #dataloader_train, dataloader_val = get_train_val_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8, split_ratio=0.1)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model_img = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=False).cuda()#initialize model 
        #model_img = nn.DataParallel(model_img).cuda()  # make model available multi GPU cores training
        optimizer_img = optim.Adam(model_img.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_img = lr_scheduler.StepLR(optimizer_img, step_size = 10, gamma = 1)

        roigen = ROIGenerator()

        model_roi = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=True).cuda()
        #model_roi = nn.DataParallel(model_roi).cuda()
        optimizer_roi = optim.Adam(model_roi.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_roi = lr_scheduler.StepLR(optimizer_roi, step_size = 10, gamma = 1)

        model_fusion = FusionClassifier(input_size=2048, output_size=N_CLASSES).cuda()
        #model_fusion = nn.DataParallel(model_fusion).cuda()
        optimizer_fusion = optim.Adam(model_fusion.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_fusion = lr_scheduler.StepLR(optimizer_fusion, step_size = 10, gamma = 1)
    else: 
        print('No required model')
        return #over

    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    bce_criterion = nn.BCELoss() #define binary cross-entropy loss
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch+1 , config['MAX_EPOCHS']))
        print('-' * 10)
        model_img.train()  #set model to training mode
        model_roi.train()
        model_fusion.train()
        train_loss = []
        with torch.autograd.enable_grad():
            for batch_idx, (image, label) in enumerate(dataloader_train):
                optimizer_img.zero_grad()
                optimizer_roi.zero_grad() 
                optimizer_fusion.zero_grad() 
                #image-level
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                conv_fea_img, fc_fea_img, out_img = model_img(var_image)#forward
                loss_img = bce_criterion(out_img, var_label)
                #ROI-level
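                # cls_weights[-5] is assumed to index the final FC-layer weight of model_img;
                # it is used as CAM-style class weights to localize class-discriminative regions for ROI cropping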
                cls_weights = list(model_img.parameters())
                weight_softmax = np.squeeze(cls_weights[-5].data.cpu().numpy())
                roi = roigen.ROIGeneration(image, conv_fea_img, weight_softmax, label.numpy())
                var_roi = torch.autograd.Variable(roi).cuda()
                _, fc_fea_roi, out_roi = model_roi(var_roi)
                loss_roi = bce_criterion(out_roi, var_label) 
                #Fusion
                fc_fea_fusion = torch.cat((fc_fea_img,fc_fea_roi), 1)
                var_fusion = torch.autograd.Variable(fc_fea_fusion).cuda()
                out_fusion = model_fusion(var_fusion)
                loss_fusion = bce_criterion(out_fusion, var_label) 
                #backward and update parameters 
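                # fixed loss weights: 0.7 image branch, 0.2 ROI branch, 0.1 fusion branch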
                loss_tensor = 0.7*loss_img + 0.2*loss_roi + 0.1*loss_fusion
                loss_tensor.backward() 
                optimizer_img.step() 
                optimizer_roi.step()
                optimizer_fusion.step() 
                train_loss.append(loss_tensor.item())
                #print([x.grad for x in optimizer.param_groups[0]['params']])
                sys.stdout.write('\r Epoch: {} / Step: {} : image loss ={}, roi loss ={}, fusion loss = {}, train loss = {}'
                                .format(epoch+1, batch_idx+1, float('%0.6f'%loss_img.item()), float('%0.6f'%loss_roi.item()),
                                float('%0.6f'%loss_fusion.item()), float('%0.6f'%loss_tensor.item()) ))
                sys.stdout.flush()        
        lr_scheduler_img.step()  #about lr and gamma
        lr_scheduler_roi.step()
        lr_scheduler_fusion.step()
        print("\r Eopch: %5d train loss = %.6f" % (epoch + 1, np.mean(train_loss))) 

        model_img.eval() #turn to test mode
        model_roi.eval()
        model_fusion.eval()
        loss_img_all, loss_roi_all, loss_fusion_all = [], [], []
        val_loss = []
        gt = torch.FloatTensor().cuda()
        pred_img = torch.FloatTensor().cuda()
        pred_roi = torch.FloatTensor().cuda()
        pred_fusion = torch.FloatTensor().cuda()
        with torch.autograd.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                gt = torch.cat((gt, label.cuda()), 0)
                #image-level
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                conv_fea_img, fc_fea_img, out_img = model_img(var_image)#forward
                loss_img = bce_criterion(out_img, var_label) 
                pred_img = torch.cat((pred_img, out_img.data), 0)
                #ROI-level
                cls_weights = list(model_img.parameters())
                weight_softmax = np.squeeze(cls_weights[-5].data.cpu().numpy())
                roi = roigen.ROIGeneration(image, conv_fea_img, weight_softmax, label.numpy())
                var_roi = torch.autograd.Variable(roi).cuda()
                _, fc_fea_roi, out_roi = model_roi(var_roi)
                loss_roi = bce_criterion(out_roi, var_label) 
                pred_roi = torch.cat((pred_roi, out_roi.data), 0)
                #Fusion
                fc_fea_fusion = torch.cat((fc_fea_img,fc_fea_roi), 1)
                var_fusion = torch.autograd.Variable(fc_fea_fusion).cuda()
                out_fusion = model_fusion(var_fusion)
                loss_fusion = bce_criterion(out_fusion, var_label) 
                pred_fusion = torch.cat((pred_fusion, out_fusion.data), 0)
                #loss
                loss_tensor = 0.7*loss_img + 0.2*loss_roi + 0.1*loss_fusion
                val_loss.append(loss_tensor.item())
                sys.stdout.write('\r Epoch: {} / Step: {} : image loss ={}, roi loss ={}, fusion loss = {}, validation loss = {}'
                                .format(epoch+1, batch_idx+1, float('%0.6f'%loss_img.item()), float('%0.6f'%loss_roi.item()),
                                float('%0.6f'%loss_fusion.item()), float('%0.6f'%loss_tensor.item()) ))
                sys.stdout.flush()
                
                loss_img_all.append(loss_img.item())
                loss_roi_all.append(loss_roi.item())
                loss_fusion_all.append(loss_fusion.item())
        #evaluation       
        AUROCs_img = np.array(compute_AUCs(gt, pred_img)).mean()
        AUROCs_roi = np.array(compute_AUCs(gt, pred_roi)).mean()
        AUROCs_fusion = np.array(compute_AUCs(gt, pred_fusion)).mean()
        print("\r Eopch: %5d validation loss = %.6f, Validataion AUROC image=%.4f roi=%.4f fusion=%.4f" 
              % (epoch + 1, np.mean(val_loss), AUROCs_img, AUROCs_roi, AUROCs_fusion)) 

        logger.info("\r Eopch: %5d validation loss = %.4f, image loss = %.4f,  roi loss =%.4f fusion loss =%.4f" 
                     % (epoch + 1, np.mean(val_loss), np.mean(loss_img_all), np.mean(loss_roi_all), np.mean(loss_fusion_all))) 
        #save checkpoint
        if AUROC_best < AUROCs_fusion:
            AUROC_best = AUROCs_fusion
            #torch.save(model.module.state_dict(), CKPT_PATH)
            torch.save(model_img.state_dict(), config['CKPT_PATH'] + 'img_model.pkl')
            torch.save(model_roi.state_dict(), config['CKPT_PATH'] + 'roi_model.pkl')
            torch.save(model_fusion.state_dict(), config['CKPT_PATH'] + 'fusion_model.pkl')
            print(' Epoch: {} model has been saved!'.format(epoch+1))
    
        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(epoch+1, time_elapsed // 60 , time_elapsed % 60))
Example #5
def Train():
    print('********************load data********************')
    dataloader_train, dataloader_val = get_train_val_dataloader(
        batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model = CXRNet(num_classes=N_CLASSES,
                       is_pre_trained=True).cuda()  #initialize model
        optimizer_model = optim.Adam(model.parameters(),
                                     lr=1e-3,
                                     betas=(0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-5)
        lr_scheduler_model = lr_scheduler.StepLR(optimizer_model,
                                                 step_size=10,
                                                 gamma=1)
        torch.backends.cudnn.benchmark = True  # improve train speed slightly
        bce_criterion = nn.BCELoss()  #define binary cross-entropy loss
        #mse_criterion = nn.MSELoss() #define regression loss

        model_unet = UNet(n_channels=3, n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet.load_state_dict(checkpoint)  #strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet.eval()
    else:
        print('No required model')
        return  #over
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch + 1, config['MAX_EPOCHS']))
        print('-' * 10)
        train_loss = []
        model.train()  #set model to training mode
        with torch.autograd.enable_grad():
            for batch_idx, (image, label) in enumerate(dataloader_train):
                optimizer_model.zero_grad()
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                var_mask = model_unet(var_image)
                var_output = model(var_image, var_mask)  #forward
                loss_tensor = bce_criterion(var_output, var_label)
                loss_tensor.backward()
                optimizer_model.step()
                train_loss.append(loss_tensor.item())
                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : train BCE loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        lr_scheduler_model.step()  #about lr and gamma
        print("\r Eopch: %5d train loss = %.6f" %
              (epoch + 1, np.mean(train_loss)))

        model.eval()  #turn to test mode
        val_loss = []
        gt = torch.FloatTensor().cuda()
        pred = torch.FloatTensor().cuda()
        with torch.autograd.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                gt = torch.cat((gt, label.cuda()), 0)
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                var_mask = model_unet(var_image)
                var_output = model(var_image, var_mask)  #forward
                loss_tensor = bce_criterion(var_output, var_label)
                pred = torch.cat((pred, var_output.data), 0)
                val_loss.append(loss_tensor.item())
                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : validation loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        #evaluation
        AUROCs_avg = np.array(compute_AUCs(gt, pred)).mean()
        logger.info(
            "\r Epoch: %5d validation loss = %.6f, Validation AUROC image=%.4f"
            % (epoch + 1, np.mean(val_loss), AUROCs_avg))

        #save checkpoint
        if AUROC_best < AUROCs_avg:
            AUROC_best = AUROCs_avg
            torch.save(model.state_dict(),
                       config['CKPT_PATH'] + 'best_model_CXRNet.pkl')
            print(' Epoch: {} model has been saved!'.format(epoch + 1))

        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(
            epoch + 1, time_elapsed // 60, time_elapsed % 60))
Example #6
def Test():
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model_img = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=False).cuda()
        CKPT_PATH = config['CKPT_PATH']  +'img_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_img.load_state_dict(checkpoint) #strict=False
        print("=> loaded Image model checkpoint: "+CKPT_PATH)

        model_roi = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=True).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'roi_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_roi.load_state_dict(checkpoint) #strict=False
        print("=> loaded ROI model checkpoint: "+CKPT_PATH)

        model_fusion = FusionClassifier(input_size=2048, output_size=N_CLASSES).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'fusion_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_fusion.load_state_dict(checkpoint) #strict=False
        print("=> loaded Fusion model checkpoint: "+CKPT_PATH)

        roigen = ROIGenerator() #region generator

    else: 
        print('No required model')
        return #over
    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    print('******************** load model succeed!********************')

    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    pred_img = torch.FloatTensor().cuda()
    pred_roi = torch.FloatTensor().cuda()
    pred_fusion = torch.FloatTensor().cuda()
    # switch to evaluate mode
    model_img.eval() #turn to test mode
    model_roi.eval()
    model_fusion.eval()
    cudnn.benchmark = True
    with torch.autograd.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            gt = torch.cat((gt, label.cuda()), 0)
            #image-level
            var_image = torch.autograd.Variable(image).cuda()
            #var_label = torch.autograd.Variable(label).cuda()
            conv_fea_img, fc_fea_img, out_img = model_img(var_image)#forward
            pred_img = torch.cat((pred_img, out_img.data), 0)
            #ROI-level
            #-----predicted label---------------
            shape_l, shape_c = out_img.size()[0], out_img.size()[1]
            pdlabel = torch.FloatTensor(shape_l, shape_c).zero_()
            for i in range(shape_l):
                for j in range(shape_c):
                    if out_img.data[i, j] > classes_threshold_common[j]:
                        pdlabel[i, j] = 1.0
            #-----predicted label---------------
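            # classes_threshold_common is assumed to be a per-class probability threshold list defined elsewhere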
            cls_weights = list(model_img.parameters())
            weight_softmax = np.squeeze(cls_weights[-5].data.cpu().numpy())
            roi = roigen.ROIGeneration(image, conv_fea_img, weight_softmax, pdlabel.numpy())
            var_roi = torch.autograd.Variable(roi).cuda()
            _, fc_fea_roi, out_roi = model_roi(var_roi)
            pred_roi = torch.cat((pred_roi, out_roi.data), 0)
            #Fusion
            fc_fea_fusion = torch.cat((fc_fea_img,fc_fea_roi), 1)
            var_fusion = torch.autograd.Variable(fc_fea_fusion).cuda()
            out_fusion = model_fusion(var_fusion)
            pred_fusion = torch.cat((pred_fusion, out_fusion.data), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx+1))
            sys.stdout.flush()

    #for evaluation
    AUROC_img = compute_AUCs(gt, pred_img)
    AUROC_avg = np.array(AUROC_img).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_img[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))

    AUROC_roi = compute_AUCs(gt, pred_roi)
    AUROC_avg = np.array(AUROC_roi).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_roi[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))

    AUROC_fusion = compute_AUCs(gt, pred_fusion)
    AUROC_avg = np.array(AUROC_fusion).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_fusion[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))

    #Evaluating the threshold of prediction
    thresholds = compute_ROCCurve(gt, pred_fusion)
    print(thresholds)
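
(Note: compute_ROCCurve is likewise defined outside these examples; a minimal sketch, assuming it returns one decision threshold per class chosen by maximizing Youden's J statistic on the ROC curve.)

import numpy as np
from sklearn.metrics import roc_curve

def compute_ROCCurve(gt, pred):
    # assumed helper: per-class threshold that maximizes TPR - FPR
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    thresholds = []
    for i in range(N_CLASSES):
        fpr, tpr, thr = roc_curve(gt_np[:, i], pred_np[:, i])
        thresholds.append(thr[np.argmax(tpr - fpr)])
    return thresholds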
Example #7
def Test():
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'],
                                          shuffle=False,
                                          num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model_img = ImageClassifier(num_classes=N_CLASSES,
                                    is_pre_trained=True).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'img_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_img.load_state_dict(checkpoint)  #strict=False
        print("=> loaded Image model checkpoint: " + CKPT_PATH)

        model_roi = RegionClassifier(num_classes=N_CLASSES,
                                     is_pre_trained=True).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'roi_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_roi.load_state_dict(checkpoint)  #strict=False
        print("=> loaded ROI model checkpoint: " + CKPT_PATH)

        model_fusion = FusionClassifier(input_size=2048,
                                        output_size=N_CLASSES).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'fusion_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_fusion.load_state_dict(checkpoint)  #strict=False
        print("=> loaded Fusion model checkpoint: " + CKPT_PATH)

    else:
        print('No required model')
        return  #over
    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    print('******************** load model succeed!********************')

    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    pred_img = torch.FloatTensor().cuda()
    pred_roi = torch.FloatTensor().cuda()
    pred_fusion = torch.FloatTensor().cuda()
    # switch to evaluate mode
    model_img.eval()  #turn to test mode
    model_roi.eval()
    model_fusion.eval()
    cudnn.benchmark = True
    with torch.autograd.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            gt = torch.cat((gt, label.cuda()), 0)
            var_image = torch.autograd.Variable(image).cuda()
            #image-level
            fc_fea_img, out_img = model_img(var_image)  #forward
            pred_img = torch.cat((pred_img, out_img.data), 0)
            #ROI-level
            fc_fea_roi, out_roi = model_roi(var_image)
            pred_roi = torch.cat((pred_roi, out_roi.data), 0)
            #Fusion
            fc_fea_fusion = torch.cat((fc_fea_img, fc_fea_roi), 1)
            var_fusion = torch.autograd.Variable(fc_fea_fusion).cuda()
            out_fusion = model_fusion(var_fusion)
            pred_fusion = torch.cat((pred_fusion, out_fusion.data), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx + 1))
            sys.stdout.flush()

    #for evaluation
    AUROC_img = compute_AUCs(gt, pred_img)
    AUROC_avg = np.array(AUROC_img).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_img[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))

    AUROC_roi = compute_AUCs(gt, pred_roi)
    AUROC_avg = np.array(AUROC_roi).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_roi[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))

    AUROC_fusion = compute_AUCs(gt, pred_fusion)
    AUROC_avg = np.array(AUROC_fusion).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i],
                                                 AUROC_fusion[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))
Example #8
def Train():
    print('********************load data********************')
    dataloader_train = get_train_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    dataloader_val = get_validation_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    #dataloader_train, dataloader_val = get_train_val_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8, split_ratio=0.1)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        #for left_lung
        model_unet_left = UNet(n_channels=3, n_classes=1).cuda()#initialize model 
        CKPT_PATH = config['CKPT_PATH'] +  'best_unet_left.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_left.load_state_dict(checkpoint) #strict=False
            print("=> loaded well-trained unet model checkpoint: "+CKPT_PATH)
        model_unet_left.eval()

        model_left = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()#initialize model 
        optimizer_left = optim.Adam(model_left.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_left = lr_scheduler.StepLR(optimizer_left, step_size = 10, gamma = 1)
        #for right lung
        model_unet_right = UNet(n_channels=3, n_classes=1).cuda()#initialize model 
        CKPT_PATH = config['CKPT_PATH'] +  'best_unet_right.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_right.load_state_dict(checkpoint) #strict=False
            print("=> loaded well-trained unet model checkpoint: "+CKPT_PATH)
        model_unet_right.eval()

        model_right = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()#initialize model 
        optimizer_right = optim.Adam(model_right.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_right = lr_scheduler.StepLR(optimizer_right, step_size = 10, gamma = 1)
        #for heart
        model_unet_heart = UNet(n_channels=3, n_classes=1).cuda()#initialize model 
        CKPT_PATH = config['CKPT_PATH'] +  'best_unet_heart.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_heart.load_state_dict(checkpoint) #strict=False
            print("=> loaded well-trained unet model checkpoint: "+CKPT_PATH)
        model_unet_heart.eval()

        model_heart = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()#initialize model 
        optimizer_heart = optim.Adam(model_heart.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_heart = lr_scheduler.StepLR(optimizer_heart, step_size = 10, gamma = 1)
    else: 
        print('No required model')
        return #over

    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    bce_criterion = nn.BCELoss() #define binary cross-entropy loss
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch+1 , config['MAX_EPOCHS']))
        print('-' * 10)
        model_left.train()  #set model to training mode
        model_right.train()
        model_heart.train()
        train_loss = []
        with torch.autograd.enable_grad():
            for batch_idx, (image, label) in enumerate(dataloader_train):
                optimizer_left.zero_grad()
                optimizer_right.zero_grad() 
                optimizer_heart.zero_grad() 
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                #for left lung
                mask = model_unet_left(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_left = model_left(var_roi)#forward
                loss_left = bce_criterion(out_left, var_label)
                loss_left.backward()
                optimizer_left.step()
                #for right lung
                mask = model_unet_right(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_right = model_right(var_roi)#forward
                loss_right = bce_criterion(out_right, var_label)
                loss_right.backward()
                optimizer_right.step()
                #for heart
                mask = model_unet_heart(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_heart = model_heart(var_roi)#forward
                loss_heart = bce_criterion(out_heart, var_label)
                loss_heart.backward()
                optimizer_heart.step()
                #loss sum 
                loss_tensor = loss_left + loss_right + loss_heart
                train_loss.append(loss_tensor.item())
                #print([x.grad for x in optimizer.param_groups[0]['params']])
                sys.stdout.write('\r Epoch: {} / Step: {} : train loss = {}'.format(epoch+1, batch_idx+1, float('%0.6f'%loss_tensor.item()) ))
                sys.stdout.flush()        
        lr_scheduler_left.step()  #about lr and gamma
        lr_scheduler_right.step()
        lr_scheduler_heart.step()
        print("\r Eopch: %5d train loss = %.6f" % (epoch + 1, np.mean(train_loss))) 

        model_left.eval() #turn to test mode
        model_right.eval()
        model_heart.eval()
        val_loss = []
        gt = torch.FloatTensor().cuda()
        preds = torch.FloatTensor().cuda()
        with torch.autograd.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                pred = torch.FloatTensor().cuda()
                gt = torch.cat((gt, label.cuda()), 0)
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                #for left lung
                mask = model_unet_left(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_left = model_left(var_roi)#forward
                pred = torch.cat((pred, out_left.data.unsqueeze(0)), 0)
                #for right lung
                mask = model_unet_right(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_right = model_right(var_roi)#forward
                pred = torch.cat((pred, out_right.data.unsqueeze(0)), 0)
                #for heart
                mask = model_unet_heart(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_heart = model_heart(var_roi)#forward
                pred = torch.cat((pred, out_heart.data.unsqueeze(0)), 0)
                #prediction
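                # element-wise max over the left-lung/right-lung/heart branch outputs is the ensembled prediction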
                pred = torch.max(pred, 0)[0] #torch.mean
                preds = torch.cat((preds, pred.data), 0)
                loss_tensor = bce_criterion(pred, var_label)
                val_loss.append(loss_tensor.item())
                sys.stdout.write('\r Epoch: {} / Step: {} : validation loss = {}'.format(epoch+1, batch_idx+1, float('%0.6f'%loss_tensor.item()) ))
                sys.stdout.flush()
        #evaluation       
        AUROCs_avg = np.array(compute_AUCs(gt, preds)).mean()
        print("\r Eopch: %5d validation loss = %.6f, average AUROC=%.4f"% (epoch + 1, np.mean(val_loss), AUROCs_avg)) 

        #save checkpoint
        if AUROC_best < AUROCs_avg:
            AUROC_best = AUROCs_avg
            torch.save(model_left.state_dict(), config['CKPT_PATH'] + 'left_model.pkl')
            torch.save(model_right.state_dict(), config['CKPT_PATH'] + 'right_model.pkl')
            torch.save(model_heart.state_dict(), config['CKPT_PATH'] + 'heart_model.pkl')
            print(' Epoch: {} model has been saved!'.format(epoch+1))
    
        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(epoch+1, time_elapsed // 60 , time_elapsed % 60))
Example #9
def Test():
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        #for left
        model_left = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()
        CKPT_PATH = config['CKPT_PATH']  +'left_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_left.load_state_dict(checkpoint) #strict=False
        print("=> loaded left model checkpoint: "+CKPT_PATH)
        model_left.eval()

        model_unet_left = UNet(n_channels=3, n_classes=1).cuda()#initialize model 
        CKPT_PATH = config['CKPT_PATH'] +  'best_unet_left.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_unet_left.load_state_dict(checkpoint) #strict=False
        print("=> loaded well-trained unet model checkpoint: "+CKPT_PATH)
        model_unet_left.eval()

        #for right
        model_right = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()
        CKPT_PATH = config['CKPT_PATH']  +'right_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_right.load_state_dict(checkpoint) #strict=False
        print("=> loaded right model checkpoint: "+CKPT_PATH)
        model_right.eval()

        model_unet_right = UNet(n_channels=3, n_classes=1).cuda()#initialize model 
        CKPT_PATH = config['CKPT_PATH'] +  'best_unet_right.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_unet_right.load_state_dict(checkpoint) #strict=False
        print("=> loaded well-trained unet model checkpoint: "+CKPT_PATH)
        model_unet_right.eval()

        #for heart
        model_heart = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()
        CKPT_PATH = config['CKPT_PATH']  +'heart_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_heart.load_state_dict(checkpoint) #strict=False
        print("=> loaded heart model checkpoint: "+CKPT_PATH)
        model_heart.eval()

        model_unet_heart = UNet(n_channels=3, n_classes=1).cuda()#initialize model 
        CKPT_PATH = config['CKPT_PATH'] +  'best_unet_heart.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_unet_heart.load_state_dict(checkpoint) #strict=False
        print("=> loaded well-trained unet model checkpoint: "+CKPT_PATH)
        model_unet_heart.eval()
        
    else: 
        print('No required model')
        return #over
    torch.backends.cudnn.benchmark = True  # improve train speed slightly
        
    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    preds = torch.FloatTensor().cuda()
    # models were already switched to evaluation mode when loaded above
    with torch.autograd.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            gt = torch.cat((gt, label.cuda()), 0)
            pred = torch.FloatTensor().cuda()
            var_image = torch.autograd.Variable(image).cuda()
            var_label = torch.autograd.Variable(label).cuda()
            #for left lung
            mask = model_unet_left(var_image)
            roi = ROIGeneration(image, mask)
            var_roi = torch.autograd.Variable(roi).cuda()
            out_left = model_left(var_roi)#forward
            pred = torch.cat((pred, out_left.data.unsqueeze(0)), 0)
            #for right lung
            mask = model_unet_right(var_image)
            roi = ROIGeneration(image, mask)
            var_roi = torch.autograd.Variable(roi).cuda()
            out_right = model_right(var_roi)#forward
            pred = torch.cat((pred, out_right.data.unsqueeze(0)), 0)
            #for heart
            mask = model_unet_heart(var_image)
            roi = ROIGeneration(image, mask)
            var_roi = torch.autograd.Variable(roi).cuda()
            out_heart = model_heart(var_roi)#forward
            pred = torch.cat((pred, out_heart.data.unsqueeze(0)), 0)
            #prediction
            pred = torch.max(pred, 0)[0] #torch.mean
            preds = torch.cat((preds, pred.data), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx+1))
            sys.stdout.flush()

    #for evaluation
    AUROC_img = compute_AUCs(gt, preds)
    AUROC_avg = np.array(AUROC_img).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_img[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))
Example #10
def Test():
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'],
                                          shuffle=False,
                                          num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model = CXRNet(num_classes=N_CLASSES, is_pre_trained=True).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'best_model_CXRNet.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model.load_state_dict(checkpoint)  #strict=False
        print("=> loaded Image model checkpoint: " + CKPT_PATH)
        torch.backends.cudnn.benchmark = True  # improve train speed slightly

        model_unet = UNet(n_channels=3, n_classes=1).cuda()  #initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet.load_state_dict(checkpoint)  #strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet.eval()
    else:
        print('No required model')
        return  #over
    print('******************** load model succeed!********************')

    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()
    with torch.autograd.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            gt = torch.cat((gt, label.cuda()), 0)
            var_image = torch.autograd.Variable(image).cuda()
            var_label = torch.autograd.Variable(label).cuda()

            var_mask = model_unet(var_image)
            var_mask = var_mask.ge(0.5).float()  #0,1 binarization
            mask_np = var_mask.squeeze().cpu().numpy()  #bz*224*224
            patchs = torch.FloatTensor()
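            # crop each image to the bounding box of its predicted lung mask, then resize to the network input size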
            for i in range(0, mask_np.shape[0]):
                mask = mask_np[i]
                ind = np.argwhere(mask != 0)
                if len(ind) > 0:
                    minh = min(ind[:, 0])
                    minw = min(ind[:, 1])
                    maxh = max(ind[:, 0])
                    maxw = max(ind[:, 1])

                    image_crop = image[i].permute(
                        1, 2, 0).squeeze().numpy()  #224*224*3
                    image_crop = image_crop[minh:maxh, minw:maxw, :]
                    image_crop = cv2.resize(
                        image_crop, (config['TRAN_CROP'], config['TRAN_CROP']))
                    image_crop = torch.FloatTensor(image_crop).permute(
                        2, 0, 1).unsqueeze(0)  #back to 1*3*224*224 (C,H,W)
                    patchs = torch.cat((patchs, image_crop), 0)
                else:
                    image_crop = image[i].unsqueeze(0)
                    patchs = torch.cat((patchs, image_crop), 0)

            var_patchs = torch.autograd.Variable(patchs).cuda()
            var_output = model(var_patchs)  #forward
            pred = torch.cat((pred, var_output.data), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx + 1))
            sys.stdout.flush()

    #for evaluation
    AUROC_all = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROC_all).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_all[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))