def Train():
    """Train the CXRNet patch classifier on ROIs cut from lung/heart masks.

    Three frozen, pre-trained UNets (left lung, right lung, heart) produce
    segmentation masks; ROIGeneration turns them into pathological-region
    patches, and CXRClassifier is trained on those patches with BCE loss.
    The checkpoint with the best mean validation AUROC is written to
    config['CKPT_PATH'] + 'best_model_CXRNet.pkl'.

    Depends on module-level names defined elsewhere in this file/package:
    config, args, N_CLASSES, get_train_dataloader, get_validation_dataloader,
    CXRClassifier, UNet, ROIGeneration, compute_AUCs.
    """
    print('********************load data********************')
    dataloader_train = get_train_dataloader(batch_size=config['BATCH_SIZE'],
                                            shuffle=True, num_workers=8)
    dataloader_val = get_validation_dataloader(batch_size=config['BATCH_SIZE'],
                                               shuffle=True, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True).cuda()
        optimizer_model = optim.Adam(model.parameters(), lr=1e-3,
                                     betas=(0.9, 0.999), eps=1e-08,
                                     weight_decay=1e-5)
        # NOTE(review): gamma=1 means StepLR never actually decays the LR —
        # confirm this is intended.
        lr_scheduler_model = lr_scheduler.StepLR(optimizer_model,
                                                 step_size=10, gamma=1)

        def _load_frozen_unet(ckpt_name):
            """Build a UNet, load its checkpoint if present, freeze in eval mode."""
            unet = UNet(n_channels=3, n_classes=1).cuda()
            ckpt_path = config['CKPT_PATH'] + ckpt_name
            if os.path.exists(ckpt_path):
                unet.load_state_dict(torch.load(ckpt_path))
                print("=> loaded well-trained unet model checkpoint: " + ckpt_path)
            unet.eval()
            return unet

        model_unet_left = _load_frozen_unet('best_unet_left.pkl')    # left lung
        model_unet_right = _load_frozen_unet('best_unet_right.pkl')  # right lung
        model_unet_heart = _load_frozen_unet('best_unet_heart.pkl')  # heart
    else:
        print('No required model')
        return  # over

    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    bce_criterion = nn.BCELoss()  # define binary cross-entropy loss
    print('********************load model succeed!********************')

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch + 1, config['MAX_EPOCHS']))
        print('-' * 10)

        model.train()  # set model to training mode
        train_loss = []
        for batch_idx, (image, label) in enumerate(dataloader_train):
            optimizer_model.zero_grad()
            # Pathological region generation: the UNets are frozen, so the
            # masks are computed without building an autograd graph.
            var_image = image.cuda()
            with torch.no_grad():
                mask_left = model_unet_left(var_image)    # for left lung
                mask_right = model_unet_right(var_image)  # for right lung
                mask_heart = model_unet_heart(var_image)  # for heart
            patchs, patch_labels, global_imgs, global_labels = ROIGeneration(
                image, [mask_left, mask_right, mask_heart], label)

            # training
            loss_tensor = torch.zeros(1).cuda()
            if len(patchs) > 0:
                out_patch = model(patchs.cuda(), is_patch=True)  # forward
                loss_tensor = loss_tensor + bce_criterion(out_patch,
                                                          patch_labels.cuda())
            # Global-image branch intentionally disabled; kept for reference:
            # if len(global_imgs) > 0:
            #     out_global = model(global_imgs.cuda(), is_patch=False)
            #     loss_tensor = loss_tensor + bce_criterion(out_global,
            #                                               global_labels.cuda())
            if loss_tensor.requires_grad:  # skip batches that produced no ROI
                loss_tensor.backward()
                optimizer_model.step()
            train_loss.append(loss_tensor.item())
            sys.stdout.write(
                '\r Epoch: {} / Step: {} : train loss = {}'.format(
                    epoch + 1, batch_idx + 1,
                    float('%0.6f' % loss_tensor.item())))
            sys.stdout.flush()
        lr_scheduler_model.step()  # about lr and gamma
        print("\r Epoch: %5d train loss = %.6f" % (epoch + 1, np.mean(train_loss)))

        model.eval()  # turn to test mode
        val_loss = []
        gt = torch.FloatTensor().cuda()
        pred = torch.FloatTensor().cuda()
        with torch.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                # pathological regions generation
                var_image = image.cuda()
                mask_left = model_unet_left(var_image)    # for left lung
                mask_right = model_unet_right(var_image)  # for right lung
                mask_heart = model_unet_heart(var_image)  # for heart
                patchs, patch_labels, global_imgs, global_labels = ROIGeneration(
                    image, [mask_left, mask_right, mask_heart], label)

                loss_tensor = torch.zeros(1).cuda()
                if len(patchs) > 0:
                    out_patch = model(patchs.cuda(), is_patch=True)  # forward
                    loss_tensor = loss_tensor + bce_criterion(out_patch,
                                                              patch_labels.cuda())
                    gt = torch.cat((gt, patch_labels.cuda()), 0)
                    pred = torch.cat((pred, out_patch.data), 0)
                val_loss.append(loss_tensor.item())
                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : validation loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()

        # evaluation
        AUROCs_avg = np.array(compute_AUCs(gt, pred)).mean()
        print("\r Epoch: %5d validation loss = %.6f, average AUROC=%.4f" %
              (epoch + 1, np.mean(val_loss), AUROCs_avg))

        # save checkpoint only when the mean validation AUROC improves
        if AUROC_best < AUROCs_avg:
            AUROC_best = AUROCs_avg
            torch.save(model.state_dict(),
                       config['CKPT_PATH'] + 'best_model_CXRNet.pkl')
            print(' Epoch: {} model has been saved!'.format(epoch + 1))

        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(
            epoch + 1, time_elapsed // 60, time_elapsed % 60))
def Train():
    """Jointly train image-level, ROI-level and fusion CXRNet classifiers.

    Pipeline per batch: the image classifier produces conv features and a
    prediction; its classification-layer weights drive CAM-based ROI cropping
    (ROIGenerator); the ROI classifier scores the crop; a FusionClassifier
    scores the concatenated FC features. The three BCE losses are combined
    as 0.7*img + 0.2*roi + 0.1*fusion. Best fusion AUROC checkpoints all
    three models under config['CKPT_PATH'].

    NOTE(review): this redefines Train() declared earlier in the file; the
    later definition wins at import time — confirm that is intended.

    Depends on module-level names defined elsewhere: config, args, N_CLASSES,
    get_train_dataloader, get_validation_dataloader, CXRClassifier,
    ROIGenerator, FusionClassifier, compute_AUCs, logger.
    """
    print('********************load data********************')
    dataloader_train = get_train_dataloader(batch_size=config['BATCH_SIZE'],
                                            shuffle=True, num_workers=8)
    dataloader_val = get_validation_dataloader(batch_size=config['BATCH_SIZE'],
                                               shuffle=True, num_workers=8)
    print('********************load data succeed!********************')

    print('********************load model********************')
    # initialize and load the model
    if args.model == 'CXRNet':
        model_img = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True,
                                  is_roi=False).cuda()
        optimizer_img = optim.Adam(model_img.parameters(), lr=1e-3,
                                   betas=(0.9, 0.999), eps=1e-08,
                                   weight_decay=1e-5)
        # NOTE(review): gamma=1 means StepLR never actually decays the LR —
        # confirm this is intended (applies to all three schedulers).
        lr_scheduler_img = lr_scheduler.StepLR(optimizer_img,
                                               step_size=10, gamma=1)

        roigen = ROIGenerator()
        model_roi = CXRClassifier(num_classes=N_CLASSES, is_pre_trained=True,
                                  is_roi=True).cuda()
        optimizer_roi = optim.Adam(model_roi.parameters(), lr=1e-3,
                                   betas=(0.9, 0.999), eps=1e-08,
                                   weight_decay=1e-5)
        lr_scheduler_roi = lr_scheduler.StepLR(optimizer_roi,
                                               step_size=10, gamma=1)

        model_fusion = FusionClassifier(input_size=2048,
                                        output_size=N_CLASSES).cuda()
        optimizer_fusion = optim.Adam(model_fusion.parameters(), lr=1e-3,
                                      betas=(0.9, 0.999), eps=1e-08,
                                      weight_decay=1e-5)
        lr_scheduler_fusion = lr_scheduler.StepLR(optimizer_fusion,
                                                  step_size=10, gamma=1)
    else:
        print('No required model')
        return  # over

    torch.backends.cudnn.benchmark = True  # improve train speed slightly
    bce_criterion = nn.BCELoss()  # define binary cross-entropy loss
    print('********************load model succeed!********************')

    def _forward(image, label):
        """Run the image -> ROI -> fusion pass for one batch.

        Returns (loss_img, loss_roi, loss_fusion, out_img, out_roi, out_fusion).
        """
        var_image = image.cuda()
        var_label = label.cuda()
        # image-level branch
        conv_fea_img, fc_fea_img, out_img = model_img(var_image)
        loss_img = bce_criterion(out_img, var_label)
        # ROI-level branch: the image model's classification weights act as
        # CAM weights for ROI cropping.
        # NOTE(review): parameters()[-5] is assumed to be the classifier
        # weight matrix — fragile if the model architecture changes; verify.
        cls_weights = list(model_img.parameters())
        weight_softmax = np.squeeze(cls_weights[-5].data.cpu().numpy())
        roi = roigen.ROIGeneration(image, conv_fea_img, weight_softmax,
                                   label.numpy())
        _, fc_fea_roi, out_roi = model_roi(roi.cuda())
        loss_roi = bce_criterion(out_roi, var_label)
        # fusion branch over concatenated FC features
        out_fusion = model_fusion(torch.cat((fc_fea_img, fc_fea_roi), 1))
        loss_fusion = bce_criterion(out_fusion, var_label)
        return loss_img, loss_roi, loss_fusion, out_img, out_roi, out_fusion

    print('********************begin training!********************')
    AUROC_best = 0.50
    for epoch in range(config['MAX_EPOCHS']):
        since = time.time()
        print('Epoch {}/{}'.format(epoch + 1, config['MAX_EPOCHS']))
        print('-' * 10)

        model_img.train()  # set models to training mode
        model_roi.train()
        model_fusion.train()
        train_loss = []
        for batch_idx, (image, label) in enumerate(dataloader_train):
            optimizer_img.zero_grad()
            optimizer_roi.zero_grad()
            optimizer_fusion.zero_grad()
            loss_img, loss_roi, loss_fusion, _, _, _ = _forward(image, label)
            # backward and update parameters
            loss_tensor = 0.7 * loss_img + 0.2 * loss_roi + 0.1 * loss_fusion
            loss_tensor.backward()
            optimizer_img.step()
            optimizer_roi.step()
            optimizer_fusion.step()
            train_loss.append(loss_tensor.item())
            sys.stdout.write(
                '\r Epoch: {} / Step: {} : image loss ={}, roi loss ={}, '
                'fusion loss = {}, train loss = {}'.format(
                    epoch + 1, batch_idx + 1,
                    float('%0.6f' % loss_img.item()),
                    float('%0.6f' % loss_roi.item()),
                    float('%0.6f' % loss_fusion.item()),
                    float('%0.6f' % loss_tensor.item())))
            sys.stdout.flush()
        lr_scheduler_img.step()  # about lr and gamma
        lr_scheduler_roi.step()
        lr_scheduler_fusion.step()
        print("\r Epoch: %5d train loss = %.6f" % (epoch + 1, np.mean(train_loss)))

        model_img.eval()  # turn to test mode
        model_roi.eval()
        model_fusion.eval()
        loss_img_all, loss_roi_all, loss_fusion_all = [], [], []
        val_loss = []
        gt = torch.FloatTensor().cuda()
        pred_img = torch.FloatTensor().cuda()
        pred_roi = torch.FloatTensor().cuda()
        pred_fusion = torch.FloatTensor().cuda()
        with torch.no_grad():
            for batch_idx, (image, label) in enumerate(dataloader_val):
                gt = torch.cat((gt, label.cuda()), 0)
                loss_img, loss_roi, loss_fusion, out_img, out_roi, out_fusion = \
                    _forward(image, label)
                pred_img = torch.cat((pred_img, out_img.data), 0)
                pred_roi = torch.cat((pred_roi, out_roi.data), 0)
                pred_fusion = torch.cat((pred_fusion, out_fusion.data), 0)
                # loss
                loss_tensor = 0.7 * loss_img + 0.2 * loss_roi + 0.1 * loss_fusion
                val_loss.append(loss_tensor.item())
                sys.stdout.write(
                    '\r Epoch: {} / Step: {} : image loss ={}, roi loss ={}, '
                    'fusion loss = {}, validation loss = {}'.format(
                        epoch + 1, batch_idx + 1,
                        float('%0.6f' % loss_img.item()),
                        float('%0.6f' % loss_roi.item()),
                        float('%0.6f' % loss_fusion.item()),
                        float('%0.6f' % loss_tensor.item())))
                sys.stdout.flush()
                loss_img_all.append(loss_img.item())
                loss_roi_all.append(loss_roi.item())
                loss_fusion_all.append(loss_fusion.item())

        # evaluation
        AUROCs_img = np.array(compute_AUCs(gt, pred_img)).mean()
        AUROCs_roi = np.array(compute_AUCs(gt, pred_roi)).mean()
        AUROCs_fusion = np.array(compute_AUCs(gt, pred_fusion)).mean()
        print("\r Epoch: %5d validation loss = %.6f, "
              "Validation AUROC image=%.4f roi=%.4f fusion=%.4f" %
              (epoch + 1, np.mean(val_loss),
               AUROCs_img, AUROCs_roi, AUROCs_fusion))
        logger.info("\r Epoch: %5d validation loss = %.4f, image loss = %.4f, "
                    "roi loss =%.4f fusion loss =%.4f" %
                    (epoch + 1, np.mean(val_loss), np.mean(loss_img_all),
                     np.mean(loss_roi_all), np.mean(loss_fusion_all)))

        # save checkpoint only when the fusion AUROC improves
        if AUROC_best < AUROCs_fusion:
            AUROC_best = AUROCs_fusion
            torch.save(model_img.state_dict(),
                       config['CKPT_PATH'] + 'img_model.pkl')
            torch.save(model_roi.state_dict(),
                       config['CKPT_PATH'] + 'roi_model.pkl')
            torch.save(model_fusion.state_dict(),
                       config['CKPT_PATH'] + 'fusion_model.pkl')
            print(' Epoch: {} model has been saved!'.format(epoch + 1))

        time_elapsed = time.time() - since
        print('Training epoch: {} completed in {:.0f}m {:.0f}s'.format(
            epoch + 1, time_elapsed // 60, time_elapsed % 60))