Example #1
def main():
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    SEED = 42
    utils.set_global_seed(SEED)
    utils.prepare_cudnn(deterministic=True)
    num_classes = 14

    # define datasets
    train_dataset = ChestXrayDataSet(
        data_dir=args.path_to_images,
        image_list_file=args.train_list,
        transform=transforms_train,
    )

    val_dataset = ChestXrayDataSet(
        data_dir=args.path_to_images,
        image_list_file=args.val_list,
        transform=transforms_val,
    )

    loaders = {
        'train':
        DataLoader(train_dataset,
                   batch_size=args.batch_size,
                   shuffle=True,
                   num_workers=args.num_workers),
        'valid':
        DataLoader(val_dataset,
                   batch_size=2,
                   shuffle=False,
                   num_workers=args.num_workers)
    }

    logdir = args.log_dir  # where model weights and logs are stored

    # define model
    model = DenseNet121(num_classes)
    # args.gpus is a comma-separated device string, e.g. "0,1"
    if len(args.gpus) > 1:
        model = nn.DataParallel(model)
    device = utils.get_device()
    runner = SupervisedRunner(device=device)

    optimizer = RAdam(model.parameters(), lr=args.lr, weight_decay=0.0003)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=0.25,
                                                     patience=2)

    weights = torch.Tensor(
        [10, 100, 30, 8, 40, 40, 330, 140, 35, 155, 110, 250, 155,
         200]).to(device)
    criterion = BCEWithLogitsLoss(pos_weight=weights)

    class_names = [
        'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
        'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
        'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
    ]

    runner.train(
        model=model,
        logdir=logdir,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        num_epochs=args.epochs,

        # We can specify the callbacks list for the experiment;
        # For this task, we will check AUC and accuracy
        callbacks=[
            AUCCallback(
                input_key="targets",
                output_key='logits',
                prefix='auc',
                class_names=class_names,
                num_classes=num_classes,
                activation='Sigmoid',
            ),
            AccuracyCallback(
                input_key="targets",
                output_key="logits",
                prefix="accuracy",
                accuracy_args=[1],
                num_classes=14,
                threshold=0.5,
                activation='Sigmoid',
            ),
        ],
        main_metric='auc/_mean',
        minimize_metric=False,
        verbose=True,
    )
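
The per-class weights fed to BCEWithLogitsLoss in this example are hard-coded. A common way to obtain such pos_weight values is the negative-to-positive count ratio per class; the sketch below assumes a binary label matrix labels of shape (num_samples, num_classes), which is not part of the original script:

import numpy as np

def estimate_pos_weight(labels):
    # labels: (num_samples, num_classes) array of 0/1 ground-truth annotations
    pos_counts = labels.sum(axis=0)
    neg_counts = labels.shape[0] - pos_counts
    # negatives-to-positives ratio per class, the convention expected by
    # torch.nn.BCEWithLogitsLoss(pos_weight=...)
    return neg_counts / np.clip(pos_counts, 1, None)
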
Example #2
def main():
    cudnn.benchmark = True
    model = DenseNet121(N_CLASSES).cuda()
    model = torch.nn.DataParallel(model).cuda()

    # data preprocess
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        # crop ten images from original
        transforms.TenCrop(224),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(
            lambda crops: torch.stack([normalize(crop) for crop in crops]))
    ])

    # load data
    if params["aug_dataset"]:
        train_dataset = ChestXrayDataSetWithAugmentation(
            data_dir=DATA_DIR,
            image_list_file=TRAIN_IMAGE_LIST,
            transform=train_transform)
    else:
        train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                         image_list_file=TRAIN_IMAGE_LIST,
                                         transform=train_transform)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=params["train_batch_size"],
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    train_evaluation_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TRAIN_IMAGE_LIST,
        transform=test_transform)

    train_evaluation_loader = DataLoader(dataset=train_evaluation_dataset,
                                         batch_size=params["test_batch_size"],
                                         shuffle=False,
                                         num_workers=8,
                                         pin_memory=True)
    dev_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                   image_list_file=DEV_IMAGE_LIST,
                                   transform=test_transform)
    dev_loader = DataLoader(dataset=dev_dataset,
                            batch_size=params["test_batch_size"],
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=test_transform)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=params["test_batch_size"],
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=params["lr"],
                                 weight_decay=params["beta"])

    def train(epoch):
        print("start training epoch %d" % (epoch))
        start_time = time()
        local_step = 0
        running_loss = 0
        running_loss_list = []
        model.train()
        if not params["aug_dataset"]:
            pos_weight = torch.tensor(train_dataset.pos_weight,
                                      dtype=torch.float32).cuda()
            neg_weight = torch.tensor(train_dataset.neg_weight,
                                      dtype=torch.float32).cuda()
            class_weight = torch.tensor(train_dataset.class_weight,
                                        dtype=torch.float32).cuda()
        for i, (inp, target) in enumerate(train_loader):
            inp = inp.cuda()
            target = target.cuda()
            optimizer.zero_grad()
            output = model(inp)

            if params["loss_function"] == "unweighted":
                local_loss = F.binary_cross_entropy(output, target)
            elif params["loss_function"] == "weighted1":
                local_loss = weighted_binary_cross_entropy(
                    output, target, pos_weight, neg_weight)
            elif params["loss_function"] == "weighted2":
                local_loss = weighted_binary_cross_entropy2(
                    output, target, class_weight)
            else:
                assert False
            running_loss += local_loss.item()
            local_loss.backward()
            optimizer.step()
            if (i + 1) % PRINT_FREQ == 0:
                running_loss /= PRINT_FREQ
                print("epoch %d, batch %d/%d, loss: %.5f" %
                      (epoch, i + 1, len(train_loader), running_loss))
                running_loss_list.append(running_loss)
                running_loss = 0
        print("end training epoch %d, time elapsed: %.2fmin" %
              (epoch, (time() - start_time) / 60))
        return dict(running_loss_list=running_loss_list)

    def evaluate(epoch, dataset_loader, pytorch_dataset, dataset_name):
        print("start evaluating epoch %d on %s" % (epoch, dataset_name))
        gt = torch.tensor([], dtype=torch.float32, device="cuda")
        pred = torch.tensor([], dtype=torch.float32, device="cuda")
        loss = 0.
        model.eval()
        if not params["aug_dataset"]:
            pos_weight = torch.tensor(pytorch_dataset.pos_weight,
                                      dtype=torch.float32).cuda()
            neg_weight = torch.tensor(pytorch_dataset.neg_weight,
                                      dtype=torch.float32).cuda()
            class_weight = torch.tensor(pytorch_dataset.class_weight,
                                        dtype=torch.float32).cuda()
        with torch.no_grad():
            for i, (inp, target) in enumerate(dataset_loader):
                target = target.cuda()
                gt = torch.cat((gt, target), 0)
                bs, n_crops, c, h, w = inp.size()
                inp_reshaped = inp.view(-1, c, h, w).cuda()
                output = model(inp_reshaped)
                output_mean = output.view(bs, n_crops, -1).mean(1)
                pred = torch.cat((pred, output_mean), 0)
                if params["loss_function"] == "unweighted":
                    local_loss = F.binary_cross_entropy(output_mean, target)
                elif params["loss_function"] == "weighted1":
                    local_loss = weighted_binary_cross_entropy(
                        output_mean, target, pos_weight, neg_weight)
                elif params["loss_function"] == "weighted2":
                    local_loss = weighted_binary_cross_entropy2(
                        output_mean, target, class_weight)
                else:
                    assert False
                loss += local_loss * len(target) / len(pytorch_dataset)

        AUROCs = compute_AUCs(gt, pred, N_CLASSES)
        AUROC_avg = np.array(AUROCs).mean()
        print("epoch %d, %s, loss: %.5f, avg_AUC: %.5f" %
              (epoch, dataset_name, loss, AUROC_avg))
        print("epoch %d, %s, individual class AUC" % (epoch, dataset_name))
        for i in range(N_CLASSES):
            print('\tthe AUROC of %s is %.5f' % (CLASS_NAMES[i], AUROCs[i]))
        return dict(auroc=dict(zip(CLASS_NAMES, AUROCs)),
                    auroc_avg=AUROC_avg,
                    loss=loss.item())

    def init_history():
        return dict(epoch=0,
                    train_eval_vals_list=[],
                    dev_eval_vals_list=[],
                    best_dev_eval_vals=dict(auroc_avg=-np.inf, loss=np.inf),
                    best_dev_eval_vals_epoch=-1)

    def update_history(history, epoch, train_eval_vals, dev_eval_vals):
        history["epoch"] = epoch
        history["train_eval_vals_list"].append(train_eval_vals)
        history["dev_eval_vals_list"].append(dev_eval_vals)
        if dev_eval_vals["auroc_avg"] > history["best_dev_eval_vals"][
                "auroc_avg"]:
            history["best_dev_eval_vals"] = dev_eval_vals
            history["best_dev_eval_vals_epoch"] = epoch
            if epoch >= 1:
                print("saving model...")
                state_dict = model.state_dict()
                torch.save(state_dict, params["best_model_file_path"])
        if epoch >= 1:
            state_dict = model.state_dict()
            torch.save(state_dict, params["model_file_path"] % (epoch))
        with open(params["history_file_path"], 'wb') as f:
            pickle.dump(history, f)

    def train_initialization():
        if not os.path.exists(params["base_dir"]):
            os.mkdir(params["base_dir"])
        with open(params["params_file_path"], 'w') as f:
            yaml.dump(params, f, default_flow_style=False)
        if os.path.exists(params["history_file_path"]):
            with open(params["history_file_path"], 'rb') as f:
                old_history = pickle.load(f)
                last_epoch = old_history["epoch"]
                if last_epoch > params["epochs"]:
                    print("training completed")
                    exit(0)
                model_file = params["model_file_path"] % (last_epoch)
                if os.path.exists(model_file):
                    model.load_state_dict(torch.load(model_file))
                    print("training resumed from epoch %d" % last_epoch)
                    return old_history, last_epoch
        return init_history(), 0

    history, last_epoch = train_initialization()
    for epoch in range(last_epoch + 1, params["epochs"] + 1):
        train_eval_vals = train(epoch)
        train_eval_vals2 = evaluate(epoch, train_evaluation_loader,
                                    train_evaluation_dataset, "train set")
        dev_eval_vals = evaluate(epoch, dev_loader, dev_dataset, "dev set")
        update_history(history, epoch, train_eval_vals, dev_eval_vals)
    print("training completed")
Example #3
def main():
    N_CLASSES = 14
    CLASS_NAMES = [
        'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
        'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
        'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
    ]

    # initialize model
    device = utils.get_device()
    model = DenseNet121(N_CLASSES).to(device)

    # restore the trained weights from the checkpoint
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])

    # initialize test loader
    test_dataset = ChestXrayDataSet(data_dir=args.path_to_images,
                                    image_list_file=args.test_list,
                                    transform=transforms_test)
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.num_workers, pin_memory=True)

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        for i, (inp, target) in enumerate(test_loader):
            target = target.cuda()
            gt = torch.cat((gt, target), 0)
            bs, c, h, w = inp.size()
            # no test-time crop augmentation here, so the batch is passed through directly
            output = model(inp.cuda())
            pred = torch.cat((pred, output), 0)

    gt_np = gt.cpu().numpy()
    pred_np = sigmoid(pred.cpu().numpy())

    Y_t = [gt_np[:, i] for i in range(N_CLASSES)]        # labels for each anomaly
    Y_pred = [pred_np[:, i] for i in range(N_CLASSES)]   # predictions for each anomaly


    AUCs = []  # AUROC for each class
    for i in range(N_CLASSES):
        # keep sklearn's auc() function unshadowed; it is used again for the ROC plots below
        AUCs.append(roc_auc_score(Y_t[i], Y_pred[i]))

    matrices = []  # confusion matrix for each class at a 0.6 decision threshold
    for i in range(N_CLASSES):
        matrix = confusion_matrix(Y_t[i], np.asarray(Y_pred[i]) > 0.6)
        matrices.append(matrix)

    
    class_names = ['no disease', 'disease']
    fig = plt.figure(figsize=(20, 20))
    for i in range(N_CLASSES):
        plt.subplot(4, 4, i + 1)
        df_cm = pd.DataFrame(
            matrices[i], index=class_names, columns=class_names)
        sns.heatmap(df_cm, annot=True, fmt="d").set_title(CLASS_NAMES[i])
        plt.ylabel('True label')
        plt.xlabel('Predicted label')

    plt.show()
    fig.savefig(os.path.join(args.test_outdir,'confusion_matrix.pdf'))

    fig, axes2d = plt.subplots(nrows=2, ncols=7,
                               sharex=True, sharey=True, figsize=(12, 4))

    for i, row in enumerate(axes2d):
        for j, cell in enumerate(row):
            # the first row plots classes 0-6, the second row classes 13 down to 7
            if i == 0:
                x = j
            else:
                x = 13 - j
            
            fpr, tpr, thresholds = roc_curve(Y_t[x], Y_pred[x])
            roc_auc = auc(fpr, tpr)

            cell.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
            cell.legend(loc='lower right', handlelength=0, handletextpad=0,
                        frameon=False, prop={'size': 8})
            cell.plot([0, 1], [0, 1], 'r--')
            cell.set_xlim([0, 1])
            cell.set_ylim([0, 1])
            cell.set_title(CLASS_NAMES[x], fontsize=10)
            
            if i == len(axes2d) - 1:
                cell.set_xlabel('False positive rate')
            if j == 0:
                cell.set_ylabel('True positive rate')
    fig.tight_layout(pad=1.0)    
    plt.show()
    fig.savefig(os.path.join(args.test_outdir,'roc_auc.pdf'))
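
The sigmoid helper applied to the raw logits above is not shown in this snippet; a minimal NumPy version consistent with that usage (an assumption about the original helper) is:

import numpy as np

def sigmoid(x):
    # element-wise logistic function mapping logits to probabilities
    return 1.0 / (1.0 + np.exp(-x))
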
Example #4
DATA_DIR = '/home/lrh/dataset/ChestXray-NIHCC/images_v1_small'
testTXTFile = '/home/lrh/git/CheXNet/ChestX-ray14/labels/test.txt'
trainTXTFile = '/home/lrh/git/CheXNet/ChestX-ray14/labels/train.txt'
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(), normalize
])
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(), normalize
])
train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                 image_list_file=trainTXTFile,
                                 transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=16,
                                           shuffle=True,
                                           num_workers=8,
                                           pin_memory=True)
test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                image_list_file=testTXTFile,
                                transform=test_transform)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=16,
                                          shuffle=False,
                                          num_workers=8,
                                          pin_memory=True)
print('\ndone')
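
A quick sanity check for the loaders defined above is to pull one batch and inspect its shape; this usage sketch assumes the loaders built in this example:

images, labels = next(iter(train_loader))
print(images.shape)  # expected torch.Size([16, 3, 224, 224]) after Resize + CenterCrop
print(labels.shape)  # one multi-hot label vector per image, e.g. torch.Size([16, 14])
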
Example #5
    def train(self,
              TRAIN_IMAGE_LIST,
              VAL_IMAGE_LIST,
              NUM_EPOCHS=10,
              LR=0.001,
              BATCH_SIZE=64,
              start_epoch=0,
              logging=True,
              save_path=None,
              freeze_feature_layers=True):
        """
        Train the CovidAID
        """
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])

        train_dataset = ChestXrayDataSet(
            image_list_file=TRAIN_IMAGE_LIST,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.TenCrop(224),
                transforms.Lambda(lambda crops: torch.stack(
                    [transforms.ToTensor()(crop) for crop in crops])),
                transforms.Lambda(lambda crops: torch.stack(
                    [normalize(crop) for crop in crops]))
            ]),
            combine_pneumonia=self.combine_pneumonia)
        if self.distributed:
            sampler = DistributedSampler(train_dataset)
            train_loader = DataLoader(dataset=train_dataset,
                                      batch_size=BATCH_SIZE,
                                      shuffle=False,
                                      num_workers=8,
                                      pin_memory=True,
                                      sampler=sampler)
        else:
            train_loader = DataLoader(dataset=train_dataset,
                                      batch_size=BATCH_SIZE,
                                      shuffle=True,
                                      num_workers=8,
                                      pin_memory=True)

        val_dataset = ChestXrayDataSet(
            image_list_file=VAL_IMAGE_LIST,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.TenCrop(224),
                transforms.Lambda(lambda crops: torch.stack(
                    [transforms.ToTensor()(crop) for crop in crops])),
                transforms.Lambda(lambda crops: torch.stack(
                    [normalize(crop) for crop in crops]))
            ]),
            combine_pneumonia=self.combine_pneumonia)
        if self.distributed:
            sampler = DistributedSampler(val_dataset)
            val_loader = DataLoader(dataset=val_dataset,
                                    batch_size=BATCH_SIZE,
                                    shuffle=False,
                                    num_workers=8,
                                    pin_memory=True,
                                    sampler=sampler)
        else:
            val_loader = DataLoader(dataset=val_dataset,
                                    batch_size=BATCH_SIZE,
                                    shuffle=True,
                                    num_workers=8,
                                    pin_memory=True)

        # Freeze the DenseNet feature layers and create the optimizer
        if freeze_feature_layers:
            print("Freezing feature layers")
            for param in self.net.densenet121.features.parameters():
                param.requires_grad = False

        # optimizer = optim.SGD(filter(lambda p: p.requires_grad, self.net.parameters()),
        #                 lr=LR, momentum=0.9)
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      self.net.parameters()),
                               lr=LR)

        for epoch in range(start_epoch, NUM_EPOCHS):
            # switch to train mode
            self.net.train()
            tot_loss = 0.0
            for i, (inputs, target) in tqdm(enumerate(train_loader),
                                            total=len(train_loader)):
                # inputs = inputs.to(self.device)
                # target = target.to(self.device)
                inputs = inputs.cuda()
                target = target.cuda()

                # Shape of inputs == [BATCH_SIZE, NUM_CROPS=10, CHANNELS=3, HEIGHT=224, WIDTH=224]
                bs, n_crops, c, h, w = inputs.size()
                # fold the crop dimension into the batch dimension
                inputs = inputs.view(-1, c, h, w)
                preds = self.net(inputs).view(bs, n_crops, -1).mean(dim=1)

                # loss = torch.sum(torch.abs(preds - target) ** 2)
                loss = train_dataset.loss(preds, target)
                # exit()
                tot_loss += float(loss.data)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            tot_loss /= len(train_dataset)

            # Clear cache
            torch.cuda.empty_cache()

            # Running on validation set
            self.net.eval()
            val_loss = 0.0
            with torch.no_grad():
                for i, (inputs, target) in tqdm(enumerate(val_loader),
                                                total=len(val_loader)):
                    inputs = inputs.cuda()
                    target = target.cuda()

                    # Shape of inputs == [BATCH_SIZE, NUM_CROPS=10, CHANNELS=3, HEIGHT=224, WIDTH=224]
                    bs, n_crops, c, h, w = inputs.size()
                    inputs = inputs.view(-1, c, h, w)

                    preds = self.net(inputs).view(bs, n_crops, -1).mean(1)
                    loss = val_dataset.loss(preds, target)

                    val_loss += float(loss.data)

            val_loss /= len(val_dataset)

            # Clear cache
            torch.cuda.empty_cache()

            # logging statistics
            timestamp = str(datetime.datetime.now()).split('.')[0]
            log = json.dumps({
                'timestamp': timestamp,
                'epoch': epoch + 1,
                'train_loss': float('%.5f' % tot_loss),
                'val_loss': float('%.5f' % val_loss),
                'lr': float('%.6f' % LR)
            })
            if logging:
                print(log)

            if save_path is not None:
                log_file = os.path.join(save_path, 'train.log')
                with open(log_file, 'a') as f:
                    f.write("{}\n".format(log))

            model_path = os.path.join(save_path, 'epoch_%d.pth' % (epoch + 1))
            self.save_model(model_path)

        print('Finished Training')
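
self.save_model, called at the end of each epoch, is defined elsewhere in the class; a minimal sketch of such a method (an assumption, not the original implementation) simply serializes the network weights:

    def save_model(self, path):
        # persist only the weights; the architecture is rebuilt when loading
        torch.save(self.net.state_dict(), path)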