def BoxTest():
    print('********************load data********************')
    dataloader_bbox = get_bbox_dataloader(batch_size=1, shuffle=False, num_workers=0)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model_img = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=False).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'img_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_img.load_state_dict(checkpoint)  # strict=False
        print("=> loaded Image model checkpoint: " + CKPT_PATH)
    else:
        print('No required model')
        return  # over
    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    print('******************** load model succeed!********************')

    print('******* begin bounding box testing!*********')
    #np.set_printoptions(suppress=True)  # to float
    #for name, layer in model.named_modules():
    #    print(name, layer)
    cls_weights = list(model_img.parameters())
    weight_softmax = np.squeeze(cls_weights[-5].data.cpu().numpy())
    cam = CAM()
    IoUs = []
    IoU_dict = {0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: []}
    with torch.autograd.no_grad():
        for batch_idx, (_, gtbox, image, label) in enumerate(dataloader_bbox):
            #if batch_idx != 963: continue
            var_image = torch.autograd.Variable(image).cuda()
            conv_fea_img, fc_fea_img, out_img = model_img(var_image)  # get feature maps
            """
            logit = out_img.cpu().data.numpy().squeeze()  # predict
            idxs = []
            for i in range(N_CLASSES):
                if logit[i] > thresholds[i]:  # different diseases vary in threshold
                    idxs.append(i)
            """
            idx = torch.where(label[0] == 1)[0]  # true label
            cam_img = cam.returnCAM(conv_fea_img.cpu().data.numpy(), weight_softmax, idx)
            pdbox = cam.returnBox(cam_img, gtbox[0].numpy())
            iou = compute_IoUs(pdbox, gtbox[0].numpy())  # compute IoU
            IoU_dict[idx.item()].append(iou)
            IoUs.append(iou)
            if iou > 0.99:
                cam.visHeatmap(batch_idx, CLASS_NAMES[idx], image, cam_img, pdbox, gtbox[0].numpy(), iou)  # visualization
            sys.stdout.write('\r box process: = {}'.format(batch_idx + 1))
            sys.stdout.flush()
    print('The average IoU is {:.4f}'.format(np.array(IoUs).mean()))
    for i in range(len(IoU_dict)):
        print('The average IoU of {} is {:.4f}'.format(CLASS_NAMES[i], np.array(IoU_dict[i]).mean()))
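# --------------------------------------------------------------------------
# Hedged sketch (not part of the original file): BoxTest() above relies on a
# compute_IoUs(pdbox, gtbox) helper defined elsewhere in the repo. The sketch
# below shows one plausible implementation, assuming both boxes are given as
# [x, y, w, h] in pixel coordinates; the actual helper and box format may differ.
def compute_IoUs_sketch(pdbox, gtbox):
    # convert [x, y, w, h] to corner coordinates
    px1, py1, px2, py2 = pdbox[0], pdbox[1], pdbox[0] + pdbox[2], pdbox[1] + pdbox[3]
    gx1, gy1, gx2, gy2 = gtbox[0], gtbox[1], gtbox[0] + gtbox[2], gtbox[1] + gtbox[3]
    # intersection rectangle
    ix1, iy1 = max(px1, gx1), max(py1, gy1)
    ix2, iy2 = min(px2, gx2), min(py2, gy2)
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)
    # union = sum of both areas minus the intersection
    union = pdbox[2] * pdbox[3] + gtbox[2] * gtbox[3] - inter
    return inter / union if union > 0 else 0.0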
def Train():
    print('********************load data********************')
    dataloader_train = get_train_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    dataloader_val = get_validation_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    #dataloader_train, dataloader_val = get_train_val_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8, split_ratio=0.1)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()  # initialize model
        optimizer_model = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_model = lr_scheduler.StepLR(optimizer_model, step_size=10, gamma=1)

        # for left lung
        model_unet_left = UNet(n_channels=3, n_classes=1).cuda()  # initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_left.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_left.load_state_dict(checkpoint)  # strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_left.eval()

        # for right lung
        model_unet_right = UNet(n_channels=3, n_classes=1).cuda()  # initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_right.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_right.load_state_dict(checkpoint)  # strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_right.eval()

        # for heart
        model_unet_heart = UNet(n_channels=3, n_classes=1).cuda()  # initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_heart.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_heart.load_state_dict(checkpoint)  # strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_heart.eval()
    else:
        print('No required model')
        return  # over
    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    bce_criterion = nn.BCELoss()  # define binary cross-entropy loss
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch + 1, config['MAX_EPOCHS']))
        print('-' * 10)

        model.train()  # set model to training mode
        train_loss = []
        with torch.autograd.enable_grad():
            for batch_idx, (image, label) in enumerate(dataloader_train):
                optimizer_model.zero_grad()
                # pathological region generation
                var_image = torch.autograd.Variable(image).cuda()
                mask_left = model_unet_left(var_image)    # for left lung
                mask_right = model_unet_right(var_image)  # for right lung
                mask_heart = model_unet_heart(var_image)  # for heart
                patchs, patch_labels, globals, global_labels = ROIGeneration(image, [mask_left, mask_right, mask_heart], label)
                # training
                loss_patch, loss_global = torch.FloatTensor([0.0]).cuda(), torch.FloatTensor([0.0]).cuda()
                if len(patchs) > 0:
                    var_patchs = torch.autograd.Variable(patchs).cuda()
                    var_patch_labels = torch.autograd.Variable(patch_labels).cuda()
                    out_patch = model(var_patchs, is_patch=True)  # forward
                    loss_patch = bce_criterion(out_patch, var_patch_labels)
                """
                if len(globals) > 0:
                    var_globals = torch.autograd.Variable(globals).cuda()
                    var_global_labels = torch.autograd.Variable(global_labels).cuda()
                    out_global = model(var_globals, is_patch=False)  # forward
                    loss_global = bce_criterion(out_global, var_global_labels)
                """
                loss_tensor = loss_patch + loss_global
                loss_tensor.backward()
                optimizer_model.step()
                train_loss.append(loss_tensor.item())
                sys.stdout.write('\r Epoch: {} / Step: {} : train loss = {}'.format(epoch + 1, batch_idx + 1, float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        lr_scheduler_model.step()  # about lr and gamma
        print("\r Epoch: %5d train loss = %.6f" % (epoch + 1, np.mean(train_loss)))

        model.eval()  # turn to test mode
        val_loss = []
        gt = torch.FloatTensor().cuda()
        pred = torch.FloatTensor().cuda()
        with torch.autograd.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                # pathological region generation
                var_image = torch.autograd.Variable(image).cuda()
                mask_left = model_unet_left(var_image)    # for left lung
                mask_right = model_unet_right(var_image)  # for right lung
                mask_heart = model_unet_heart(var_image)  # for heart
                patchs, patch_labels, globals, global_labels = ROIGeneration(image, [mask_left, mask_right, mask_heart], label)
                # validation
                loss_patch, loss_global = torch.FloatTensor([0.0]).cuda(), torch.FloatTensor([0.0]).cuda()
                if len(patchs) > 0:
                    var_patchs = torch.autograd.Variable(patchs).cuda()
                    var_patch_labels = torch.autograd.Variable(patch_labels).cuda()
                    out_patch = model(var_patchs, is_patch=True)  # forward
                    loss_patch = bce_criterion(out_patch, var_patch_labels)
                    gt = torch.cat((gt, patch_labels.cuda()), 0)
                    pred = torch.cat((pred, out_patch.data), 0)
                """
                if len(globals) > 0:
                    var_globals = torch.autograd.Variable(globals).cuda()
                    var_global_labels = torch.autograd.Variable(global_labels).cuda()
                    out_global = model(var_globals, is_patch=False)  # forward
                    loss_global = bce_criterion(out_global, var_global_labels)
                    gt = torch.cat((gt, global_labels.cuda()), 0)
                    pred = torch.cat((pred, out_global.data), 0)
                """
                loss_tensor = loss_patch + loss_global
                val_loss.append(loss_tensor.item())
                sys.stdout.write('\r Epoch: {} / Step: {} : validation loss = {}'.format(epoch + 1, batch_idx + 1, float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        # evaluation
        AUROCs_avg = np.array(compute_AUCs(gt, pred)).mean()
        print("\r Epoch: %5d validation loss = %.6f, average AUROC=%.4f" % (epoch + 1, np.mean(val_loss), AUROCs_avg))

        # save checkpoint
        if AUROC_best < AUROCs_avg:
            AUROC_best = AUROCs_avg
            torch.save(model.state_dict(), config['CKPT_PATH'] + 'best_model_CXRNet.pkl')
            print(' Epoch: {} model has been saved!'.format(epoch + 1))

        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(epoch + 1, time_elapsed // 60, time_elapsed % 60))
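# --------------------------------------------------------------------------
# Hedged sketch (not part of the original file): the Train()/Test() routines
# call compute_AUCs(gt, pred), which is defined elsewhere in the repo. A
# minimal per-class implementation might look like the following, assuming
# gt and pred are CUDA tensors of shape (N, N_CLASSES) with multi-hot labels.
from sklearn.metrics import roc_auc_score

def compute_AUCs_sketch(gt, pred):
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    AUROCs = []
    for i in range(gt_np.shape[1]):
        # roc_auc_score raises ValueError if a class has only one label value
        AUROCs.append(roc_auc_score(gt_np[:, i], pred_np[:, i]))
    return AUROCs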
def Train():
    print('********************load data********************')
    dataloader_train = get_train_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    dataloader_val = get_validation_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    #dataloader_train, dataloader_val = get_train_val_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8, split_ratio=0.1)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model_img = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=False).cuda()  # initialize model
        #model_img = nn.DataParallel(model_img).cuda()  # make model available for multi-GPU training
        optimizer_img = optim.Adam(model_img.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_img = lr_scheduler.StepLR(optimizer_img, step_size=10, gamma=1)

        roigen = ROIGenerator()

        model_roi = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=True).cuda()
        #model_roi = nn.DataParallel(model_roi).cuda()
        optimizer_roi = optim.Adam(model_roi.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_roi = lr_scheduler.StepLR(optimizer_roi, step_size=10, gamma=1)

        model_fusion = FusionClassifier(input_size=2048, output_size=N_CLASSES).cuda()
        #model_fusion = nn.DataParallel(model_fusion).cuda()
        optimizer_fusion = optim.Adam(model_fusion.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_fusion = lr_scheduler.StepLR(optimizer_fusion, step_size=10, gamma=1)
    else:
        print('No required model')
        return  # over
    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    bce_criterion = nn.BCELoss()  # define binary cross-entropy loss
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch + 1, config['MAX_EPOCHS']))
        print('-' * 10)

        model_img.train()  # set models to training mode
        model_roi.train()
        model_fusion.train()
        train_loss = []
        with torch.autograd.enable_grad():
            for batch_idx, (image, label) in enumerate(dataloader_train):
                optimizer_img.zero_grad()
                optimizer_roi.zero_grad()
                optimizer_fusion.zero_grad()
                # image-level
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                conv_fea_img, fc_fea_img, out_img = model_img(var_image)  # forward
                loss_img = bce_criterion(out_img, var_label)
                # ROI-level
                cls_weights = list(model_img.parameters())
                weight_softmax = np.squeeze(cls_weights[-5].data.cpu().numpy())
                roi = roigen.ROIGeneration(image, conv_fea_img, weight_softmax, label.numpy())
                var_roi = torch.autograd.Variable(roi).cuda()
                _, fc_fea_roi, out_roi = model_roi(var_roi)
                loss_roi = bce_criterion(out_roi, var_label)
                # fusion
                fc_fea_fusion = torch.cat((fc_fea_img, fc_fea_roi), 1)
                var_fusion = torch.autograd.Variable(fc_fea_fusion).cuda()
                out_fusion = model_fusion(var_fusion)
                loss_fusion = bce_criterion(out_fusion, var_label)
                # backward and update parameters
                loss_tensor = 0.7 * loss_img + 0.2 * loss_roi + 0.1 * loss_fusion
                loss_tensor.backward()
                optimizer_img.step()
                optimizer_roi.step()
                optimizer_fusion.step()
                train_loss.append(loss_tensor.item())
                #print([x.grad for x in optimizer.param_groups[0]['params']])
                sys.stdout.write('\r Epoch: {} / Step: {} : image loss = {}, roi loss = {}, fusion loss = {}, train loss = {}'
                                 .format(epoch + 1, batch_idx + 1,
                                         float('%0.6f' % loss_img.item()), float('%0.6f' % loss_roi.item()),
                                         float('%0.6f' % loss_fusion.item()), float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        lr_scheduler_img.step()  # about lr and gamma
        lr_scheduler_roi.step()
        lr_scheduler_fusion.step()
        print("\r Epoch: %5d train loss = %.6f" % (epoch + 1, np.mean(train_loss)))

        model_img.eval()  # turn to test mode
        model_roi.eval()
        model_fusion.eval()
        loss_img_all, loss_roi_all, loss_fusion_all = [], [], []
        val_loss = []
        gt = torch.FloatTensor().cuda()
        pred_img = torch.FloatTensor().cuda()
        pred_roi = torch.FloatTensor().cuda()
        pred_fusion = torch.FloatTensor().cuda()
        with torch.autograd.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                gt = torch.cat((gt, label.cuda()), 0)
                # image-level
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                conv_fea_img, fc_fea_img, out_img = model_img(var_image)  # forward
                loss_img = bce_criterion(out_img, var_label)
                pred_img = torch.cat((pred_img, out_img.data), 0)
                # ROI-level
                cls_weights = list(model_img.parameters())
                weight_softmax = np.squeeze(cls_weights[-5].data.cpu().numpy())
                roi = roigen.ROIGeneration(image, conv_fea_img, weight_softmax, label.numpy())
                var_roi = torch.autograd.Variable(roi).cuda()
                _, fc_fea_roi, out_roi = model_roi(var_roi)
                loss_roi = bce_criterion(out_roi, var_label)
                pred_roi = torch.cat((pred_roi, out_roi.data), 0)
                # fusion
                fc_fea_fusion = torch.cat((fc_fea_img, fc_fea_roi), 1)
                var_fusion = torch.autograd.Variable(fc_fea_fusion).cuda()
                out_fusion = model_fusion(var_fusion)
                loss_fusion = bce_criterion(out_fusion, var_label)
                pred_fusion = torch.cat((pred_fusion, out_fusion.data), 0)
                # loss
                loss_tensor = 0.7 * loss_img + 0.2 * loss_roi + 0.1 * loss_fusion
                val_loss.append(loss_tensor.item())
                sys.stdout.write('\r Epoch: {} / Step: {} : image loss = {}, roi loss = {}, fusion loss = {}, val loss = {}'
                                 .format(epoch + 1, batch_idx + 1,
                                         float('%0.6f' % loss_img.item()), float('%0.6f' % loss_roi.item()),
                                         float('%0.6f' % loss_fusion.item()), float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
                loss_img_all.append(loss_img.item())
                loss_roi_all.append(loss_roi.item())
                loss_fusion_all.append(loss_fusion.item())
        # evaluation
        AUROCs_img = np.array(compute_AUCs(gt, pred_img)).mean()
        AUROCs_roi = np.array(compute_AUCs(gt, pred_roi)).mean()
        AUROCs_fusion = np.array(compute_AUCs(gt, pred_fusion)).mean()
        print("\r Epoch: %5d validation loss = %.6f, validation AUROC image=%.4f roi=%.4f fusion=%.4f"
              % (epoch + 1, np.mean(val_loss), AUROCs_img, AUROCs_roi, AUROCs_fusion))
        logger.info("\r Epoch: %5d validation loss = %.4f, image loss = %.4f, roi loss = %.4f, fusion loss = %.4f"
                    % (epoch + 1, np.mean(val_loss), np.mean(loss_img_all), np.mean(loss_roi_all), np.mean(loss_fusion_all)))

        # save checkpoint
        if AUROC_best < AUROCs_fusion:
            AUROC_best = AUROCs_fusion
            #torch.save(model.module.state_dict(), CKPT_PATH)  # saving torch.nn.DataParallel models
            torch.save(model_img.state_dict(), config['CKPT_PATH'] + 'img_model.pkl')
            torch.save(model_roi.state_dict(), config['CKPT_PATH'] + 'roi_model.pkl')
            torch.save(model_fusion.state_dict(), config['CKPT_PATH'] + 'fusion_model.pkl')
            print(' Epoch: {} model has been saved!'.format(epoch + 1))

        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(epoch + 1, time_elapsed // 60, time_elapsed % 60))
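# --------------------------------------------------------------------------
# Hedged sketch (not part of the original file): Train() above, BoxTest(), and
# the ROIGenerator all build class activation maps (CAM) from the last conv
# feature maps and the classifier weights (cls_weights[-5] in this repo's
# CXRClassifier). A minimal CAM computation under those assumptions:
import cv2  # assumed available; used here only to resize the activation map

def return_cam_sketch(conv_fea, weight_softmax, class_idx, size=(256, 256)):
    # conv_fea: numpy array of shape (1, C, H, W); weight_softmax: (N_CLASSES, C)
    _, C, H, W = conv_fea.shape
    cam = weight_softmax[class_idx].dot(conv_fea.reshape(C, H * W))  # weighted sum -> (H*W,)
    cam = cam.reshape(H, W)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]
    return cv2.resize(cam.astype('float32'), size)  # upsample to the input resolution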
def Test():
    print('********************load data********************')
    dataloader_test = get_test_dataloader(batch_size=config['BATCH_SIZE'], shuffle=False, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model_img = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=False).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'img_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_img.load_state_dict(checkpoint)  # strict=False
        print("=> loaded Image model checkpoint: " + CKPT_PATH)

        model_roi = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True, is_roi=True).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'roi_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_roi.load_state_dict(checkpoint)  # strict=False
        print("=> loaded ROI model checkpoint: " + CKPT_PATH)

        model_fusion = FusionClassifier(input_size=2048, output_size=N_CLASSES).cuda()
        CKPT_PATH = config['CKPT_PATH'] + 'fusion_model.pkl'
        checkpoint = torch.load(CKPT_PATH)
        model_fusion.load_state_dict(checkpoint)  # strict=False
        print("=> loaded Fusion model checkpoint: " + CKPT_PATH)

        roigen = ROIGenerator()  # region generator
    else:
        print('No required model')
        return  # over
    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    print('******************** load model succeed!********************')

    print('******* begin testing!*********')
    gt = torch.FloatTensor().cuda()
    pred_img = torch.FloatTensor().cuda()
    pred_roi = torch.FloatTensor().cuda()
    pred_fusion = torch.FloatTensor().cuda()
    # switch to evaluate mode
    model_img.eval()  # turn to test mode
    model_roi.eval()
    model_fusion.eval()
    cudnn.benchmark = True
    with torch.autograd.no_grad():
        for batch_idx, (image, label) in enumerate(dataloader_test):
            gt = torch.cat((gt, label.cuda()), 0)
            # image-level
            var_image = torch.autograd.Variable(image).cuda()
            #var_label = torch.autograd.Variable(label).cuda()
            conv_fea_img, fc_fea_img, out_img = model_img(var_image)  # forward
            pred_img = torch.cat((pred_img, out_img.data), 0)
            # ROI-level
            #-----predicted label---------------
            # threshold the image-level predictions into binary labels per class
            shape_l, shape_c = out_img.size()[0], out_img.size()[1]
            pdlabel = torch.FloatTensor(shape_l, shape_c).zero_()
            for i in range(shape_l):
                for j in range(shape_c):
                    if out_img[i, j] > classes_threshold_common[j]:
                        pdlabel[i, j] = 1.0
            #-----predicted label---------------
            cls_weights = list(model_img.parameters())
            weight_softmax = np.squeeze(cls_weights[-5].data.cpu().numpy())
            roi = roigen.ROIGeneration(image, conv_fea_img, weight_softmax, pdlabel.numpy())
            var_roi = torch.autograd.Variable(roi).cuda()
            _, fc_fea_roi, out_roi = model_roi(var_roi)
            pred_roi = torch.cat((pred_roi, out_roi.data), 0)
            # fusion
            fc_fea_fusion = torch.cat((fc_fea_img, fc_fea_roi), 1)
            var_fusion = torch.autograd.Variable(fc_fea_fusion).cuda()
            out_fusion = model_fusion(var_fusion)
            pred_fusion = torch.cat((pred_fusion, out_fusion.data), 0)
            sys.stdout.write('\r testing process: = {}'.format(batch_idx + 1))
            sys.stdout.flush()

    # evaluation
    AUROC_img = compute_AUCs(gt, pred_img)
    AUROC_avg = np.array(AUROC_img).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_img[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))

    AUROC_roi = compute_AUCs(gt, pred_roi)
    AUROC_avg = np.array(AUROC_roi).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_roi[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))

    AUROC_fusion = compute_AUCs(gt, pred_fusion)
    AUROC_avg = np.array(AUROC_fusion).mean()
    for i in range(N_CLASSES):
        print('The AUROC of {} is {:.4f}'.format(CLASS_NAMES[i], AUROC_fusion[i]))
    print('The average AUROC is {:.4f}'.format(AUROC_avg))

    # evaluate the per-class prediction thresholds
    thresholds = compute_ROCCurve(gt, pred_fusion)
    print(thresholds)
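# --------------------------------------------------------------------------
# Hedged sketch (not part of the original file): Test() ends by printing the
# per-class thresholds returned by compute_ROCCurve(gt, pred_fusion), which is
# defined elsewhere in the repo. One common way to pick such thresholds is
# Youden's J statistic (maximizing TPR - FPR) on each class's ROC curve:
from sklearn.metrics import roc_curve

def compute_thresholds_sketch(gt, pred):
    gt_np, pred_np = gt.cpu().numpy(), pred.cpu().numpy()
    thresholds = []
    for i in range(gt_np.shape[1]):
        fpr, tpr, thr = roc_curve(gt_np[:, i], pred_np[:, i])
        thresholds.append(thr[np.argmax(tpr - fpr)])  # Youden's J per class
    return thresholds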
def Train():
    print('********************load data********************')
    dataloader_train = get_train_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    dataloader_val = get_validation_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8)
    #dataloader_train, dataloader_val = get_train_val_dataloader(batch_size=config['BATCH_SIZE'], shuffle=True, num_workers=8, split_ratio=0.1)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        # for left lung
        model_unet_left = UNet(n_channels=3, n_classes=1).cuda()  # initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_left.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_left.load_state_dict(checkpoint)  # strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_left.eval()
        model_left = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()  # initialize model
        optimizer_left = optim.Adam(model_left.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_left = lr_scheduler.StepLR(optimizer_left, step_size=10, gamma=1)

        # for right lung
        model_unet_right = UNet(n_channels=3, n_classes=1).cuda()  # initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_right.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_right.load_state_dict(checkpoint)  # strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_right.eval()
        model_right = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()  # initialize model
        optimizer_right = optim.Adam(model_right.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_right = lr_scheduler.StepLR(optimizer_right, step_size=10, gamma=1)

        # for heart
        model_unet_heart = UNet(n_channels=3, n_classes=1).cuda()  # initialize model
        CKPT_PATH = config['CKPT_PATH'] + 'best_unet_heart.pkl'
        if os.path.exists(CKPT_PATH):
            checkpoint = torch.load(CKPT_PATH)
            model_unet_heart.load_state_dict(checkpoint)  # strict=False
            print("=> loaded well-trained unet model checkpoint: " + CKPT_PATH)
        model_unet_heart.eval()
        model_heart = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()  # initialize model
        optimizer_heart = optim.Adam(model_heart.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-5)
        lr_scheduler_heart = lr_scheduler.StepLR(optimizer_heart, step_size=10, gamma=1)
    else:
        print('No required model')
        return  # over
    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    bce_criterion = nn.BCELoss()  # define binary cross-entropy loss
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch + 1, config['MAX_EPOCHS']))
        print('-' * 10)

        model_left.train()  # set models to training mode
        model_right.train()
        model_heart.train()
        train_loss = []
        with torch.autograd.enable_grad():
            for batch_idx, (image, label) in enumerate(dataloader_train):
                optimizer_left.zero_grad()
                optimizer_right.zero_grad()
                optimizer_heart.zero_grad()
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                # for left lung
                mask = model_unet_left(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_left = model_left(var_roi)  # forward
                loss_left = bce_criterion(out_left, var_label)
                loss_left.backward()
                optimizer_left.step()
                # for right lung
                mask = model_unet_right(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_right = model_right(var_roi)  # forward
                loss_right = bce_criterion(out_right, var_label)
                loss_right.backward()
                optimizer_right.step()
                # for heart
                mask = model_unet_heart(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_heart = model_heart(var_roi)  # forward
                loss_heart = bce_criterion(out_heart, var_label)
                loss_heart.backward()
                optimizer_heart.step()
                # loss sum
                loss_tensor = loss_left + loss_right + loss_heart
                train_loss.append(loss_tensor.item())
                #print([x.grad for x in optimizer.param_groups[0]['params']])
                sys.stdout.write('\r Epoch: {} / Step: {} : train loss = {}'.format(epoch + 1, batch_idx + 1, float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        lr_scheduler_left.step()  # about lr and gamma
        lr_scheduler_right.step()
        lr_scheduler_heart.step()
        print("\r Epoch: %5d train loss = %.6f" % (epoch + 1, np.mean(train_loss)))

        model_left.eval()  # turn to test mode
        model_right.eval()
        model_heart.eval()
        val_loss = []
        gt = torch.FloatTensor().cuda()
        preds = torch.FloatTensor().cuda()
        with torch.autograd.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                pred = torch.FloatTensor().cuda()
                gt = torch.cat((gt, label.cuda()), 0)
                var_image = torch.autograd.Variable(image).cuda()
                var_label = torch.autograd.Variable(label).cuda()
                # for left lung
                mask = model_unet_left(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_left = model_left(var_roi)  # forward
                pred = torch.cat((pred, out_left.data.unsqueeze(0)), 0)
                # for right lung
                mask = model_unet_right(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_right = model_right(var_roi)  # forward
                pred = torch.cat((pred, out_right.data.unsqueeze(0)), 0)
                # for heart
                mask = model_unet_heart(var_image)
                roi = ROIGeneration(image, mask)
                var_roi = torch.autograd.Variable(roi).cuda()
                out_heart = model_heart(var_roi)  # forward
                pred = torch.cat((pred, out_heart.data.unsqueeze(0)), 0)
                # prediction: element-wise maximum over the three region branches
                pred = torch.max(pred, 0)[0]  # torch.mean
                preds = torch.cat((preds, pred.data), 0)
                loss_tensor = bce_criterion(pred, var_label)
                val_loss.append(loss_tensor.item())
                sys.stdout.write('\r Epoch: {} / Step: {} : validation loss = {}'.format(epoch + 1, batch_idx + 1, float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
        # evaluation
        AUROCs_avg = np.array(compute_AUCs(gt, preds)).mean()
        print("\r Epoch: %5d validation loss = %.6f, average AUROC=%.4f" % (epoch + 1, np.mean(val_loss), AUROCs_avg))

        # save checkpoint
        if AUROC_best < AUROCs_avg:
            AUROC_best = AUROCs_avg
            torch.save(model_left.state_dict(), config['CKPT_PATH'] + 'left_model.pkl')
            torch.save(model_right.state_dict(), config['CKPT_PATH'] + 'right_model.pkl')
            torch.save(model_heart.state_dict(), config['CKPT_PATH'] + 'heart_model.pkl')
            print(' Epoch: {} model has been saved!'.format(epoch + 1))

        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(epoch + 1, time_elapsed // 60, time_elapsed % 60))
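# --------------------------------------------------------------------------
# Hedged sketch (not part of the original file): this Train() variant calls
# ROIGeneration(image, mask) with a single UNet mask per anatomical region. A
# plausible implementation crops each image to the bounding box of its
# binarized mask and resizes the crop back to the network input size; the
# real helper in the repo may differ (threshold value, padding, and the
# resize target of 256 are assumptions).
import torch.nn.functional as F

def roi_from_mask_sketch(image, mask, threshold=0.5, out_size=256):
    # image: (B, 3, H, W) CPU tensor; mask: (B, 1, H, W) UNet output, assumed in [0, 1]
    rois = []
    binary = (mask.detach().cpu() > threshold).float()
    for b in range(image.size(0)):
        ys, xs = torch.nonzero(binary[b, 0], as_tuple=True)
        if ys.numel() == 0:
            crop = image[b:b + 1]  # empty mask: fall back to the whole image
        else:
            y0, y1 = ys.min().item(), ys.max().item() + 1
            x0, x1 = xs.min().item(), xs.max().item() + 1
            crop = image[b:b + 1, :, y0:y1, x0:x1]
        rois.append(F.interpolate(crop, size=(out_size, out_size), mode='bilinear', align_corners=False))
    return torch.cat(rois, 0)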