def main():
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    SEED = 42
    utils.set_global_seed(SEED)
    utils.prepare_cudnn(deterministic=True)
    num_classes = 14

    # define datasets
    train_dataset = ChestXrayDataSet(
        data_dir=args.path_to_images,
        image_list_file=args.train_list,
        transform=transforms_train,
    )
    val_dataset = ChestXrayDataSet(
        data_dir=args.path_to_images,
        image_list_file=args.val_list,
        transform=transforms_val,
    )

    loaders = {
        'train': DataLoader(train_dataset, batch_size=args.batch_size,
                            shuffle=True, num_workers=args.num_workers),
        'valid': DataLoader(val_dataset, batch_size=2,
                            shuffle=False, num_workers=args.num_workers),
    }

    logdir = args.log_dir  # where model weights and logs are stored

    # define model
    model = DenseNet121(num_classes)
    # args.gpus is a comma-separated string of device ids, so count the ids,
    # not the characters
    if len(args.gpus.split(',')) > 1:
        model = nn.DataParallel(model)
    device = utils.get_device()
    runner = SupervisedRunner(device=device)

    optimizer = RAdam(model.parameters(), lr=args.lr, weight_decay=0.0003)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.25,
                                                     patience=2)

    # per-class positive weights to counter the heavy label imbalance
    weights = torch.Tensor(
        [10, 100, 30, 8, 40, 40, 330, 140, 35, 155, 110, 250, 155, 200]).to(device)
    criterion = BCEWithLogitsLoss(pos_weight=weights)

    class_names = [
        'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
        'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
        'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
    ]

    runner.train(
        model=model,
        logdir=logdir,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        num_epochs=args.epochs,
        # We can specify the callbacks list for the experiment;
        # for this task, we check AUC and accuracy
        callbacks=[
            AUCCallback(
                input_key="targets",
                output_key="logits",
                prefix="auc",
                class_names=class_names,
                num_classes=num_classes,
                activation="Sigmoid",
            ),
            AccuracyCallback(
                input_key="targets",
                output_key="logits",
                prefix="accuracy",
                accuracy_args=[1],
                num_classes=num_classes,
                threshold=0.5,
                activation="Sigmoid",
            ),
        ],
        main_metric='auc/_mean',
        minimize_metric=False,
        verbose=True,
    )
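# `get_args()` and the `transforms_train` / `transforms_val` pipelines are
# referenced above but defined elsewhere in the repo. A minimal sketch of
# plausible definitions, inferred only from the attributes main() actually
# uses and from the ImageNet preprocessing the other scripts in this repo
# apply; all names and defaults here are illustrative assumptions:
import argparse

transforms_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
transforms_val = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

def get_args():
    # hypothetical parser covering exactly the args main() reads
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpus', type=str, default='0')
    parser.add_argument('--path_to_images', type=str, required=True)
    parser.add_argument('--train_list', type=str, required=True)
    parser.add_argument('--val_list', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--log_dir', type=str, default='./logs')
    return parser.parse_args()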
def main():
    cudnn.benchmark = True
    model = DenseNet121(N_CLASSES).cuda()
    model = torch.nn.DataParallel(model).cuda()

    # data preprocessing
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        # crop ten images from the original
        transforms.TenCrop(224),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(
            lambda crops: torch.stack([normalize(crop) for crop in crops])),
    ])

    # load data
    if params["aug_dataset"]:
        train_dataset = ChestXrayDataSetWithAugmentation(
            data_dir=DATA_DIR,
            image_list_file=TRAIN_IMAGE_LIST,
            transform=train_transform)
    else:
        train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                         image_list_file=TRAIN_IMAGE_LIST,
                                         transform=train_transform)
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=params["train_batch_size"],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)

    train_evaluation_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TRAIN_IMAGE_LIST,
        transform=test_transform)
    train_evaluation_loader = DataLoader(dataset=train_evaluation_dataset,
                                         batch_size=params["test_batch_size"],
                                         shuffle=False,
                                         num_workers=8,
                                         pin_memory=True)

    dev_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                   image_list_file=DEV_IMAGE_LIST,
                                   transform=test_transform)
    dev_loader = DataLoader(dataset=dev_dataset,
                            batch_size=params["test_batch_size"],
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)

    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=test_transform)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=params["test_batch_size"],
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=params["lr"],
                                 weight_decay=params["beta"])

    def train(epoch):
        print("start training epoch %d" % epoch)
        start_time = time()
        running_loss = 0
        running_loss_list = []
        model.train()
        if not params["aug_dataset"]:
            pos_weight = torch.tensor(train_dataset.pos_weight,
                                      dtype=torch.float32).cuda()
            neg_weight = torch.tensor(train_dataset.neg_weight,
                                      dtype=torch.float32).cuda()
            class_weight = torch.tensor(train_dataset.class_weight,
                                        dtype=torch.float32).cuda()
        for i, (inp, target) in enumerate(train_loader):
            inp = inp.cuda()
            target = target.cuda()
            optimizer.zero_grad()
            output = model(inp)
            if params["loss_function"] == "unweighted":
                local_loss = F.binary_cross_entropy(output, target)
            elif params["loss_function"] == "weighted1":
                local_loss = weighted_binary_cross_entropy(
                    output, target, pos_weight, neg_weight)
            elif params["loss_function"] == "weighted2":
                local_loss = weighted_binary_cross_entropy2(
                    output, target, class_weight)
            else:
                assert False, "unknown loss_function: %s" % params["loss_function"]
            running_loss += local_loss.item()
            local_loss.backward()
            optimizer.step()
            if (i + 1) % PRINT_FREQ == 0:
                running_loss /= PRINT_FREQ
                print("epoch %d, batch %d/%d, loss: %.5f" %
                      (epoch, i + 1, len(train_loader), running_loss))
                running_loss_list.append(running_loss)
                running_loss = 0
        print("end training epoch %d, time elapsed: %.2fmin" %
              (epoch, (time() - start_time) / 60))
        return dict(running_loss_list=running_loss_list)

    def evaluate(epoch, dataset_loader, pytorch_dataset, dataset_name):
        print("start evaluating epoch %d on %s" % (epoch, dataset_name))
        gt = torch.tensor([], dtype=torch.float32, device="cuda")
        pred = torch.tensor([], dtype=torch.float32, device="cuda")
        loss = 0.
        model.eval()
        if not params["aug_dataset"]:
            pos_weight = torch.tensor(pytorch_dataset.pos_weight,
                                      dtype=torch.float32).cuda()
            neg_weight = torch.tensor(pytorch_dataset.neg_weight,
                                      dtype=torch.float32).cuda()
            # use the weights of the dataset being evaluated
            # (the original mistakenly read train_dataset.class_weight here)
            class_weight = torch.tensor(pytorch_dataset.class_weight,
                                        dtype=torch.float32).cuda()
        with torch.no_grad():
            for i, (inp, target) in enumerate(dataset_loader):
                target = target.cuda()
                gt = torch.cat((gt, target), 0)
                # each sample yields ten crops; average predictions over crops
                bs, n_crops, c, h, w = inp.size()
                inp_reshaped = inp.view(-1, c, h, w).cuda()
                output = model(inp_reshaped)
                output_mean = output.view(bs, n_crops, -1).mean(1)
                pred = torch.cat((pred, output_mean), 0)
                if params["loss_function"] == "unweighted":
                    local_loss = F.binary_cross_entropy(output_mean, target)
                elif params["loss_function"] == "weighted1":
                    local_loss = weighted_binary_cross_entropy(
                        output_mean, target, pos_weight, neg_weight)
                elif params["loss_function"] == "weighted2":
                    local_loss = weighted_binary_cross_entropy2(
                        output_mean, target, class_weight)
                else:
                    assert False, "unknown loss_function: %s" % params["loss_function"]
                loss += local_loss * len(target) / len(pytorch_dataset)
        AUROCs = compute_AUCs(gt, pred, N_CLASSES)
        AUROC_avg = np.array(AUROCs).mean()
        print("epoch %d, %s, loss: %.5f, avg_AUC: %.5f" %
              (epoch, dataset_name, loss, AUROC_avg))
        print("epoch %d, %s, individual class AUC" % (epoch, dataset_name))
        for i in range(N_CLASSES):
            print('\tthe AUROC of %s is %.5f' % (CLASS_NAMES[i], AUROCs[i]))
        return dict(auroc=dict(zip(CLASS_NAMES, AUROCs)),
                    auroc_avg=AUROC_avg,
                    loss=loss.item())

    def init_history():
        return dict(epoch=0,
                    train_eval_vals_list=[],
                    dev_eval_vals_list=[],
                    best_dev_eval_vals=dict(auroc_avg=-np.inf, loss=np.inf),
                    best_dev_eval_vals_epoch=-1)

    def update_history(history, epoch, train_eval_vals, dev_eval_vals):
        history["epoch"] = epoch
        history["train_eval_vals_list"].append(train_eval_vals)
        history["dev_eval_vals_list"].append(dev_eval_vals)
        if dev_eval_vals["auroc_avg"] > history["best_dev_eval_vals"]["auroc_avg"]:
            history["best_dev_eval_vals"] = dev_eval_vals
            history["best_dev_eval_vals_epoch"] = epoch
            if epoch >= 1:
                print("saving model...")
                state_dict = model.state_dict()
                torch.save(state_dict, params["best_model_file_path"])
        if epoch >= 1:
            state_dict = model.state_dict()
            torch.save(state_dict, params["model_file_path"] % epoch)
        with open(params["history_file_path"], 'wb') as f:
            pickle.dump(history, f)

    def train_initialization():
        if not os.path.exists(params["base_dir"]):
            os.mkdir(params["base_dir"])
        with open(params["params_file_path"], 'w') as f:
            yaml.dump(params, f, default_flow_style=False)
        if os.path.exists(params["history_file_path"]):
            with open(params["history_file_path"], 'rb') as f:
                old_history = pickle.load(f)
            last_epoch = old_history["epoch"]
            if last_epoch > params["epochs"]:
                print("training completed")
                exit(0)
            model_file = params["model_file_path"] % last_epoch
            if os.path.exists(model_file):
                model.load_state_dict(torch.load(model_file))
                print("training resumed from epoch %d" % last_epoch)
                return old_history, last_epoch
        return init_history(), 0

    history, last_epoch = train_initialization()
    for epoch in range(last_epoch + 1, params["epochs"] + 1):
        train_eval_vals = train(epoch)
        train_eval_vals2 = evaluate(epoch, train_evaluation_loader,
                                    train_evaluation_dataset, "train set")
        dev_eval_vals = evaluate(epoch, dev_loader, dev_dataset, "dev set")
        update_history(history, epoch, train_eval_vals, dev_eval_vals)
    print("training completed")
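# The helpers called above (`weighted_binary_cross_entropy`,
# `weighted_binary_cross_entropy2`, `compute_AUCs`) are defined elsewhere in
# the repo. A minimal sketch of plausible implementations, assuming the model
# outputs per-class probabilities (sigmoid already applied) and multi-hot
# targets; the exact weighting in the original code may differ:
def weighted_binary_cross_entropy(output, target, pos_weight, neg_weight):
    # separate per-class weights on the positive and negative BCE terms
    eps = 1e-7
    output = output.clamp(eps, 1 - eps)
    loss = -(pos_weight * target * torch.log(output) +
             neg_weight * (1 - target) * torch.log(1 - output))
    return loss.mean()

def weighted_binary_cross_entropy2(output, target, class_weight):
    # a single weight per class applied to the whole BCE term
    eps = 1e-7
    output = output.clamp(eps, 1 - eps)
    loss = -(target * torch.log(output) +
             (1 - target) * torch.log(1 - output)) * class_weight
    return loss.mean()

def compute_AUCs(gt, pred, n_classes):
    # per-class ROC AUC over the accumulated ground truth / predictions
    from sklearn.metrics import roc_auc_score
    gt_np = gt.cpu().numpy()
    pred_np = pred.cpu().numpy()
    return [roc_auc_score(gt_np[:, i], pred_np[:, i]) for i in range(n_classes)]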
def main():
    N_CLASSES = 14
    CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                   'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
                   'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
                   'Pleural_Thickening', 'Hernia']

    # initialize model
    device = utils.get_device()
    model = DenseNet121(N_CLASSES).to(device)
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])

    # initialize test loader
    test_dataset = ChestXrayDataSet(data_dir=args.path_to_images,
                                    image_list_file=args.test_list,
                                    transform=transforms_test)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             num_workers=args.num_workers,
                             pin_memory=True)

    # initialize the ground truth and prediction tensors
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        for i, (inp, target) in enumerate(test_loader):
            target = target.cuda()
            gt = torch.cat((gt, target), 0)
            # torch.autograd.Variable is deprecated; a plain tensor suffices
            output = model(inp.cuda())  # (batch, N_CLASSES) logits
            pred = torch.cat((pred, output), 0)

    gt_np = gt.cpu().numpy()
    pred_np = sigmoid(pred.cpu().numpy())

    # labels and predictions for each anomaly, one column per class
    Y_t = [[x[i] for x in gt_np] for i in range(N_CLASSES)]
    Y_pred = [[y[j] for y in pred_np] for j in range(N_CLASSES)]

    # AUC for each class
    AUCs = [roc_auc_score(Y_t[i], Y_pred[i]) for i in range(N_CLASSES)]

    # confusion matrix for each class at a 0.6 decision threshold
    matrices = []
    for i in range(N_CLASSES):
        matrix = confusion_matrix(Y_t[i], np.asarray(Y_pred[i]) > 0.6)
        matrices.append(matrix)

    class_names = ['no disease', 'disease']
    fig = plt.figure(figsize=(20, 20))
    for i in range(N_CLASSES):
        plt.subplot(4, 4, i + 1)
        df_cm = pd.DataFrame(matrices[i], index=class_names,
                             columns=class_names)
        sns.heatmap(df_cm, annot=True, fmt="d").set_title(CLASS_NAMES[i])
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
    plt.show()
    fig.savefig(os.path.join(args.test_outdir, 'confusion_matrix.pdf'))

    fig, axes2d = plt.subplots(nrows=2, ncols=7, sharex=True, sharey=True,
                               figsize=(12, 4))
    for i, row in enumerate(axes2d):
        for j, cell in enumerate(row):
            # first row plots classes 0..6, second row classes 13 down to 7
            if i == 0:
                x = i + j
            else:
                x = 13 - i * j
            fpr, tpr, threshold = roc_curve(Y_t[x], Y_pred[x])
            roc_auc = auc(fpr, tpr)
            cell.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
            cell.legend(loc='lower right', handlelength=0, handletextpad=0,
                        frameon=False, prop={'size': 8})
            cell.plot([0, 1], [0, 1], 'r--')
            cell.set_xlim([0, 1])
            cell.set_ylim([0, 1])
            cell.set_title(CLASS_NAMES[x], fontsize=10)
            if i == len(axes2d) - 1:
                cell.set_xlabel('False positive rate')
            if j == 0:
                # y axis shows the true positive rate
                # (the original mislabeled it "True negative rate")
                cell.set_ylabel('True positive rate')
    fig.tight_layout(pad=1.0)
    plt.show()
    fig.savefig(os.path.join(args.test_outdir, 'roc_auc.pdf'))
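# `sigmoid` above is applied to a NumPy array of logits but is not defined in
# this section; a minimal sketch of the assumed helper (np already imported):
def sigmoid(x):
    # elementwise logistic function; maps logits to probabilities in (0, 1)
    return 1.0 / (1.0 + np.exp(-x))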
DATA_DIR = '/home/lrh/dataset/ChestXray-NIHCC/images_v1_small'
testTXTFile = '/home/lrh/git/CheXNet/ChestX-ray14/labels/test.txt'
trainTXTFile = '/home/lrh/git/CheXNet/ChestX-ray14/labels/train.txt'

normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                 image_list_file=trainTXTFile,
                                 transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=16,
                                           shuffle=True,
                                           num_workers=8,
                                           pin_memory=True)
test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                image_list_file=testTXTFile,
                                transform=test_transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=16,
                                          shuffle=False,
                                          num_workers=8,
                                          pin_memory=True)

print('\ndone')
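# A quick sanity check on the loaders (illustrative, not part of the original
# script): pull one batch and confirm the expected shapes, assuming
# ChestXrayDataSet yields (image, multi-hot label) pairs over 14 classes.
images, labels = next(iter(train_loader))
print(images.shape)  # expected: torch.Size([16, 3, 224, 224])
print(labels.shape)  # expected: torch.Size([16, 14])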
def train(self, TRAIN_IMAGE_LIST, VAL_IMAGE_LIST, NUM_EPOCHS=10, LR=0.001,
          BATCH_SIZE=64, start_epoch=0, logging=True, save_path=None,
          freeze_feature_layers=True):
    """
    Train the CovidAID
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    train_dataset = ChestXrayDataSet(
        image_list_file=TRAIN_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack(
                [normalize(crop) for crop in crops]))
        ]),
        combine_pneumonia=self.combine_pneumonia)

    if self.distributed:
        sampler = DistributedSampler(train_dataset)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=False,
                                  num_workers=8,
                                  pin_memory=True,
                                  sampler=sampler)
    else:
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)

    val_dataset = ChestXrayDataSet(
        image_list_file=VAL_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack(
                [normalize(crop) for crop in crops]))
        ]),
        combine_pneumonia=self.combine_pneumonia)

    if self.distributed:
        sampler = DistributedSampler(val_dataset)
        val_loader = DataLoader(dataset=val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8,
                                pin_memory=True,
                                sampler=sampler)
    else:
        # no need to shuffle the validation set
        val_loader = DataLoader(dataset=val_dataset,
                                batch_size=BATCH_SIZE,
                                shuffle=False,
                                num_workers=8,
                                pin_memory=True)

    # Freeze feature layers and create optimizer
    if freeze_feature_layers:
        print("Freezing feature layers")
        for param in self.net.densenet121.features.parameters():
            param.requires_grad = False

    # optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.net.parameters()),
    #                       lr=LR, momentum=0.9)
    optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                  self.net.parameters()),
                           lr=LR)

    for epoch in range(start_epoch, NUM_EPOCHS):
        # switch to train mode
        self.net.train()
        tot_loss = 0.0
        for i, (inputs, target) in tqdm(enumerate(train_loader),
                                        total=len(train_loader)):
            inputs = inputs.cuda()
            target = target.cuda()

            # shape of inputs:
            # [BATCH_SIZE, NUM_CROPS=10, CHANNELS=3, HEIGHT=224, WIDTH=224]
            # flatten the crops into the batch dimension, then average the
            # per-crop predictions back per sample
            bs, n_crops, c, h, w = inputs.size()
            inputs = inputs.view(-1, c, h, w)
            preds = self.net(inputs).view(bs, n_crops, -1).mean(dim=1)

            loss = train_dataset.loss(preds, target)
            tot_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tot_loss /= len(train_dataset)

        # clear cache
        torch.cuda.empty_cache()

        # run on the validation set; torch.no_grad() replaces the
        # deprecated volatile=True Variables of the original
        self.net.eval()
        val_loss = 0.0
        with torch.no_grad():
            for i, (inputs, target) in tqdm(enumerate(val_loader),
                                            total=len(val_loader)):
                inputs = inputs.cuda()
                target = target.cuda()

                bs, n_crops, c, h, w = inputs.size()
                inputs = inputs.view(-1, c, h, w)
                preds = self.net(inputs).view(bs, n_crops, -1).mean(dim=1)

                loss = val_dataset.loss(preds, target)
                val_loss += loss.item()
        val_loss /= len(val_dataset)

        # clear cache
        torch.cuda.empty_cache()

        # logging statistics
        timestamp = str(datetime.datetime.now()).split('.')[0]
        log = json.dumps({
            'timestamp': timestamp,
            'epoch': epoch + 1,
            'train_loss': float('%.5f' % tot_loss),
            'val_loss': float('%.5f' % val_loss),
            'lr': float('%.6f' % LR)
        })
        if logging:
            print(log)
            log_file = os.path.join(save_path, 'train.log')
            with open(log_file, 'a') as f:
                f.write("{}\n".format(log))

        model_path = os.path.join(save_path, 'epoch_%d.pth' % (epoch + 1))
        self.save_model(model_path)

    print('Finished Training')
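# `train_dataset.loss(preds, target)` is defined on the dataset elsewhere in
# the repo. A minimal sketch, assuming it is a plain multi-label binary
# cross-entropy over probabilities; the real implementation may apply
# per-class weights or operate on a different output head:
import torch.nn.functional as F

def multilabel_bce_loss(preds, target):
    # preds: (batch, n_classes) probabilities in (0, 1); target: multi-hot labels
    return F.binary_cross_entropy(preds, target)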