Пример #1
0
def main():
    """Build the test dataset, report its size, then launch CheXNet training.

    Relies on the module-level configuration constants (DATA_DIR,
    IMAGE_LIST_TEST, IMAGE_LIST_TRAIN, IMAGE_LIST_VAL) and on
    ChestXrayDataSet / ChexnetTrainer being importable.
    """
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=IMAGE_LIST_TEST)

    # Use the len() builtin rather than calling the dunder directly.
    length = len(test_dataset)
    print("The length of test data is ", length)

    dataDir = DATA_DIR
    imageListFileTrain = IMAGE_LIST_TRAIN
    imageListFileVal = IMAGE_LIST_VAL

    # Timestamp tag for this training run (ddmmYYYY-HHMMSS).
    timestampTime = time.strftime("%H%M%S")
    timestampDate = time.strftime("%d%m%Y")
    timestampLaunch = timestampDate + '-' + timestampTime

    transResize = 256   # resize the shorter side before cropping
    transCrop = 224     # final crop size fed to the network

    isTrained = True    # start from a pretrained backbone
    classCount = 156    # number of output classes

    batchSize = 16
    epochSize = 100

    ChexnetTrainer.train(dataDir, imageListFileTrain, imageListFileVal,
                         transResize, transCrop, isTrained, classCount,
                         batchSize, epochSize, timestampLaunch, None)
Пример #2
0
def main():
    """Evaluate a trained DenseNet121 on the chest X-ray test split.

    Restores the checkpoint when present, runs ten-crop inference
    without gradient tracking, and prints per-class and average AUROC.
    """
    cudnn.benchmark = True

    # initialize and load the model
    model = DenseNet121(N_CLASSES).cuda()
    model = torch.nn.DataParallel(model).cuda()

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    # ImageNet channel statistics expected by the pretrained backbone.
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    # Ten-crop evaluation: each sample becomes a (10, C, H, W) stack
    # whose per-crop predictions are averaged below.
    test_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ]))
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True)

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    # switch to evaluate mode
    model.eval()

    # `volatile=True` was removed from PyTorch; torch.no_grad() is the
    # supported way to disable autograd during inference.
    with torch.no_grad():
        for inp, target in test_loader:
            target = target.cuda()
            gt = torch.cat((gt, target), 0)
            bs, n_crops, c, h, w = inp.size()
            output = model(inp.view(-1, c, h, w).cuda())
            output_mean = output.view(bs, n_crops, -1).mean(1)
            pred = torch.cat((pred, output_mean.data), 0)

    AUROCs = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROCs).mean()
    print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))
Пример #3
0
def show_dataset():
    """Step through the training loader under the pdb debugger.

    Debugging helper: builds the ten-crop training pipeline and stops
    on every batch so the tensors and labels can be inspected.
    """
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    ten_crop_pipeline = transforms.Compose([
        transforms.Resize(224),
        transforms.TenCrop(224),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(lambda crops: torch.stack(
            [normalize(crop) for crop in crops])),
    ])
    train_data = ChestXrayDataSet(data_dir=DATA_DIR,
                                  image_list_file=TRAIN_IMAGE_LIST,
                                  transform=ten_crop_pipeline)
    loader = DataLoader(dataset=train_data,
                        batch_size=BATCH_SIZE,
                        shuffle=False,
                        num_workers=8,
                        pin_memory=True)

    for batch_index, (inputs, labels) in enumerate(loader):
        pdb.set_trace()  # inspect `inputs` / `labels` interactively
Пример #4
0
 def __init_loader(self):
     """Build the shuffled training and validation DataLoaders.

     Dataset locations and batch sizes come from self.args; both
     loaders use the transform pipeline from __init_transform().
     """
     train_set = ChestXrayDataSet(data_dir=self.args.data_dir,
                                  file_list=self.args.train_csv,
                                  transforms=self.__init_transform())
     val_set = ChestXrayDataSet(data_dir=self.args.data_dir,
                                file_list=self.args.val_csv,
                                transforms=self.__init_transform())
     train_loader = DataLoader(train_set,
                               batch_size=self.args.batch_size,
                               shuffle=True)
     val_loader = DataLoader(val_set,
                             batch_size=self.args.val_batch_size,
                             shuffle=True)
     return train_loader, val_loader
Пример #5
0
 def __init_loader(self):
     """Build the (unshuffled) test DataLoader from self.args."""
     test_set = ChestXrayDataSet(data_dir=self.args.data_dir,
                                 file_list=self.args.test_csv,
                                 transforms=self.__init_transform())
     return DataLoader(test_set,
                       batch_size=self.args.batch_size,
                       shuffle=False)
Пример #6
0
def init():
    """Lazily build the global `dataloaders` and `dataset_sizes` dicts.

    Safe to call repeatedly: only the first call (guarded by the
    module-level `isinit` flag) does any work.
    """
    global isinit, dataloaders, dataset_sizes
    if isinit:
        return
    isinit = True

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    # Ten-crop pipeline shared by the evaluation splits.
    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.TenCrop(224),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(
            lambda crops: torch.stack([normalize(crop) for crop in crops]))
    ])
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'test': eval_transform,
        'val': eval_transform,
    }

    image_datasets = {}
    dataloaders = {}
    dataset_sizes = {}
    for split in ('train', 'test', 'val'):
        dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                   image_list_file=IMAGE_LIST_FILES[split],
                                   transform=data_transforms[split])
        image_datasets[split] = dataset
        # Only the test split keeps its original sample order.
        dataloaders[split] = DataLoader(dataset=dataset,
                                        batch_size=BATCH_SIZE,
                                        shuffle=(split != 'test'),
                                        num_workers=8,
                                        pin_memory=True)
        dataset_sizes[split] = len(dataset)
def main():
    """Evaluate the AG-CNN branches on the chest X-ray test split.

    Builds a single-center-crop test loader, restores the global,
    local and fusion branch models from their checkpoints when
    available, then runs the project's `test` routine.
    """
    print('********************load data********************')
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    eval_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=eval_transform)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=128,
                             shuffle=False,
                             num_workers=4,
                             pin_memory=True)
    print('********************load data succeed!********************')

    print('********************load model********************')
    Global_Branch_model = Densenet121_AG(pretrained=False,
                                         num_classes=N_CLASSES).cuda()
    Local_Branch_model = Densenet121_AG(pretrained=False,
                                        num_classes=N_CLASSES).cuda()
    Fusion_Branch_model = Fusion_Branch(input_size=2048,
                                        output_size=N_CLASSES).cuda()

    # Restore each branch from its checkpoint when the file exists.
    for ckpt_path, branch, tag in (
            (CKPT_PATH_G, Global_Branch_model, 'Global_Branch_model'),
            (CKPT_PATH_L, Local_Branch_model, 'Local_Branch_model'),
            (CKPT_PATH_F, Fusion_Branch_model, 'Fusion_Branch_model')):
        if os.path.isfile(ckpt_path):
            branch.load_state_dict(torch.load(ckpt_path))
            print("=> loaded {} checkpoint".format(tag))

    cudnn.benchmark = True
    print('******************** load model succeed!********************')

    print('******* begin testing!*********')
    test(Global_Branch_model, Local_Branch_model, Fusion_Branch_model,
         test_loader)
Пример #8
0
 def train(epoch, dev_auc):
     """Run one training epoch over the chest X-ray training split.

     When per-class dev AUCs are supplied, the dataset re-augments
     itself each epoch based on them; otherwise the plain dataset is
     used. Returns a dict with the smoothed running-loss trace.
     """
     print("start training epoch %d" % (epoch))
     start_time = time()
     local_step = 0
     running_loss = 0
     running_loss_list = []
     model.train()
     if dev_auc is None:
         train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                          image_list_file=TRAIN_IMAGE_LIST,
                                          transform=train_transform)
     else:
         train_dataset = ChestXrayDataSetWithAugmentationEachEpoch(
             data_dir=DATA_DIR,
             image_list_file=TRAIN_IMAGE_LIST,
             aucs=dev_auc,
             transform=train_transform)
     train_loader = DataLoader(dataset=train_dataset,
                               batch_size=params["train_batch_size"],
                               shuffle=True,
                               num_workers=8,
                               pin_memory=True)
     total_batches = len(train_loader)
     for step, (images, labels) in enumerate(train_loader, start=1):
         images = images.cuda()
         labels = labels.cuda()
         optimizer.zero_grad()
         predictions = model(images)
         batch_loss = F.binary_cross_entropy(predictions, labels)
         running_loss += batch_loss.item()
         batch_loss.backward()
         optimizer.step()
         # Report and record the smoothed loss every PRINT_FREQ batches.
         if step % PRINT_FREQ == 0:
             running_loss /= PRINT_FREQ
             print("epoch %d, batch %d/%d, loss: %.5f" %
                   (epoch, step, total_batches, running_loss))
             running_loss_list.append(running_loss)
             running_loss = 0
     print("end training epoch %d, time elapsed: %.2fmin" %
           (epoch, (time() - start_time) / 60))
     return dict(running_loss_list=running_loss_list)
Пример #9
0
def main():
    """Federated ten-crop evaluation with PySyft.

    Restores the checkpoint (rewriting legacy DenseNet state-dict
    keys), evaluates the test split on the remote `workers`, prints
    AUROC / specificity / sensitivity / F1 per class, then reports
    the detected diseases and renders a heat map for every X-ray.
    """
    cudnn.benchmark = True

    model = DenseNet121(N_CLASSES).cuda()

    # Image names are needed for the per-image report and heat-map
    # generation at the end; without this list the final loop raises
    # NameError (it was commented out in the original).
    image_names = []
    with open(TEST_IMAGE_LIST, 'r') as f:
        for line in f:
            items = line.split()
            image_names.append(items[0])

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        # Legacy torchvision checkpoints name parameters e.g.
        # "norm.1.weight"; current DenseNet code expects "norm1.weight",
        # so rewrite those keys in place.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
        )
        state_dict = checkpoint['state_dict']
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # Strip the first dotted component (the DataParallel "module."
        # prefix) from every remaining key.
        for key in list(state_dict.keys()):
            split_key = key.split('.')
            new_key = '.'.join(split_key[1:])
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
        model.load_state_dict(state_dict)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ]))
    # Distribute the dataset over the configured PySyft workers.
    test_loader = sy.FederatedDataLoader(
        federated_dataset=test_dataset.federate(tuple(workers)),
        batch_size=BATCH_SIZE,
        shuffle=False)

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor()
    gt = gt.cuda()
    pred = torch.FloatTensor()
    pred = pred.cuda()

    # switch to evaluate mode
    model.eval()
    model_nosyft = model  # plain reference kept for heat-map rendering
    for i, (inp, target) in enumerate(test_loader):
        location = inp.location
        target = target.get().cuda()
        gt = torch.cat((gt, target), 0)
        inp = inp.get()
        bs, n_crops, c, h, w = inp.size()
        # Ship the crops and the model to the batch's worker, run
        # inference remotely, then pull the averaged result back.
        input_var = inp.view(-1, c, h, w).cuda().send(location)
        model.send(location)
        output = model(input_var)
        output_mean = output.view(bs, n_crops, -1).mean(1)
        output_mean_data = output_mean.get()
        model.get()
        pred = torch.cat((pred, output_mean_data.data), 0)

    # Binarize predictions for the per-image disease report.
    pred_np = pred.cpu().numpy()
    pred_np[pred_np < THRESHOLD] = 0
    pred_np[pred_np >= THRESHOLD] = 1
    indexes = []

    AUROCs, specificity, sensitivity, F1score = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROCs).mean()
    print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))
        print('The specificity of {} is {}'.format(CLASS_NAMES[i],
                                                   specificity[i]))
        print('The sensitivity of {} is {}'.format(CLASS_NAMES[i],
                                                   sensitivity[i]))
        print('The F1 score of {} is {}'.format(CLASS_NAMES[i], F1score[i]))

    for arr in pred_np:
        indexes.append([index for index, val in enumerate(arr) if val == 1])
    for index, disease_array in enumerate(indexes):
        if len(disease_array) > 0:
            print(
                f'XRAY :: {index + 1} :: {[CLASS_NAMES[i] for i in disease_array]}'
            )
        else:
            print(f'XRAY :: {index + 1} :: No Disease detected')
        print("Confidence of the 14 classes are {}".format({
            CLASS_NAMES[j]: pred[index][j] * 100
            for j in range(len(pred[index]))
        }))
        get_heat_map(model_nosyft, image_names[index], DATA_DIR)
Пример #10
0
    print("=> no checkpoint found")

# --- module-level test-set configuration ---
DATA_DIR = '/store/dataset/ChestXray-NIHCC/images_v1_small'
TEST_IMAGE_LIST = '../ChestX-ray14/labels/test_list.txt'
BATCH_SIZE = 10
#IMAGE_PATH = '/store/dataset/ChestXray-NIHCC/images/00017046_015.png'
#img = Image.open(IMAGE_PATH).convert('RGB')
# ImageNet channel statistics used by pretrained backbones.
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Single center-crop evaluation pipeline (no ten-crop here).
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(), normalize
])

test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                image_list_file=TEST_IMAGE_LIST,
                                transform=test_transform)
# Deterministic order (shuffle=False) so results line up with the list file.
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=BATCH_SIZE,
                                          shuffle=False,
                                          num_workers=8,
                                          pin_memory=True)

#img_torch = transform(img)
#img_torch.shape

pgd_params = {
    'ord': np.inf,
    'y': None,
    'y_target': None,
    'eps': 5.0 / 255,
Пример #11
0
def main():
    """Train DenseNet121 on ten-crop chest X-rays with a custom loss,
    checkpointing the model and loss history so a run can resume.

    Notes:
        * The checkpoint stores the whole pickled model (torch.save on
          the module), so loading replaces `model` entirely.
        * Built-in CrossEntropyLoss expects class indices rather than
          multi-label targets, hence the project's `custom_cross` loss.
    """
    cudnn.benchmark = True

    # initialize and load the model
    model = DenseNet121(N_CLASSES).cuda()
    model = torch.nn.DataParallel(model).cuda()

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        model = torch.load(CKPT_PATH)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    # Adam rather than the paper's SGD; possibly quicker convergence.
    optimiser = optim.Adam(model.parameters(),
                           lr=0.001,
                           betas=(0.9, 0.999),
                           eps=1e-08,
                           weight_decay=1e-5)

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    train_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TRAIN_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ]))
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=LOADER_WORKERS,
                              pin_memory=True)

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor()
    gt = gt.cuda()
    pred = torch.FloatTensor()
    pred = pred.cuda()

    # Resume the loss history of a previous run when available.
    # train_vec entries are [epoch, batch, loss, running_loss].
    if os.path.isfile(TRAIN_VEC_FILE):
        with open(TRAIN_VEC_FILE, 'rb') as vec_file:
            train_vec = pickle.load(vec_file)
        running_loss = train_vec[-1][3]
        # BUGFIX: resume from the recorded *epoch* (index 0), not the
        # batch index (index 1); the range() below adds the final +1.
        starting_epoch = train_vec[-1][0]
    else:
        starting_epoch = 0
        running_loss = 0.0
        train_vec = []
    # switch to train mode
    model.train()
    print("starting loop")
    i = 0  # last batch index, kept for the error report below
    try:
        for epoch in range(starting_epoch + 1, N_EPOCHS):
            for i, (inp, label) in enumerate(train_loader):
                label = label.cuda()
                gt = torch.cat((gt, label), 0)
                bs, n_crops, c, h, w = inp.size()
                input_var = torch.autograd.Variable(
                    inp.view(-1, c, h, w).cuda())

                # forward + backward + optimise
                output = model(input_var)  # (bs*n_crops, class_count)
                output_mean = output.view(bs, n_crops, -1).mean(1)

                loss = custom_cross(output_mean,
                                    label.type(torch.FloatTensor).cuda())

                optimiser.zero_grad()
                loss.backward()
                optimiser.step()

                loss = loss.item()
                running_loss += loss
                # print statistics and checkpoint every 100 batches
                if not i % 100:
                    print('[%d, %5d] loss=%.3f' % (epoch, i, running_loss))
                    torch.save(model, CKPT_PATH)
                    train_vec.append([epoch, i, loss, running_loss])
                    with open(TRAIN_VEC_FILE, 'wb') as vec_file:
                        pickle.dump(train_vec, vec_file)
    except Exception:
        # BUGFIX: `raise ()` raised a TypeError and discarded the
        # original traceback; a bare `raise` re-raises it intact.
        print('error in iteration: ' + str(i))
        raise
Пример #12
0
def main(modelfile):
    """Benchmark a converted chest-X-ray model with OpenVINO on CPU.

    Loads model/<modelfile> (.xml plus its .bin weights), runs
    ten-crop inference over the test list, and prints per-class and
    mean AUC. Processing stops after eleven batches (debug limit).
    """
    model_xml = os.path.join('model', modelfile)
    model_bin = model_xml.replace('xml', 'bin')

    log.info('Creating Inference Engine')
    ie = IECore()
    net = ie.read_network(model=model_xml, weights=model_bin)

    log.info('Preparing input blobs')
    input_blob = next(iter(net.input_info))
    out_blob = next(iter(net.outputs))
    # One network batch holds every crop of a loader batch.
    net.batch_size = (args.batch_size * N_CROPS)

    n, c, h, w = net.input_info[input_blob].input_data.shape

    # for image load
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    ten_crop = transforms.Compose([
        transforms.Resize(256),
        transforms.TenCrop(224),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(
            lambda crops: torch.stack([normalize(crop) for crop in crops]))
    ])
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=ten_crop)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=args.batch_size,
                             shuffle=False,
                             pin_memory=False)

    gt = torch.FloatTensor()
    pred = torch.FloatTensor()

    # loading model to the plugin
    log.info('Loading model to the plugin')
    exec_net = ie.load_network(network=net, device_name='CPU')

    batch_count = len(test_loader)
    for index, (data, target) in enumerate(test_loader):
        start_time = timeit.default_timer()

        gt = torch.cat((gt, target), 0)
        bs, n_crops, c, h, w = data.size()
        images = data.view(-1, c, h, w).numpy()

        # Zero-pad a short final batch up to the fixed network batch.
        if bs != args.batch_size:
            padded = np.zeros(shape=(args.batch_size * n_crops, c, h, w))
            padded[:bs * n_crops, :c, :h, :w] = images
            images = padded

        outputs = exec_net.infer(inputs={input_blob: images})[out_blob]
        # Average the ten crop predictions of each sample.
        mean_outputs = np.mean(outputs.reshape(args.batch_size, n_crops, -1),
                               axis=1)

        # Drop the padding rows again for a short batch.
        if bs != args.batch_size:
            mean_outputs = mean_outputs[:bs, :mean_outputs.shape[1]]

        pred = torch.cat((pred, torch.from_numpy(mean_outputs)), 0)

        print('%03d/%03d, time: %6.3f sec' %
              (index, batch_count, (timeit.default_timer() - start_time)))

        if index == 10:
            break

    # Per-class ROC-AUC over everything evaluated before the break.
    AUCs = [
        roc_auc_score(gt.cpu()[:, i],
                      pred.cpu()[:, i]) for i in range(N_CLASSES)
    ]
    AUC_avg = np.array(AUCs).mean()
    print('The average AUC is {AUC_avg:.3f}'.format(AUC_avg=AUC_avg))
    for i in range(N_CLASSES):
        print('The AUC of {} is {:.3f}'.format(CLASS_NAMES[i], AUCs[i]))
Пример #13
0
def main():
    """Train DenseNet121 on ChestX-ray14 with Catalyst.

    Builds train/val loaders from the CLI arguments, configures a
    class-weighted BCE-with-logits loss, RAdam + ReduceLROnPlateau,
    and runs SupervisedRunner with AUC/accuracy callbacks, selecting
    the best checkpoint by mean AUC.
    """
    args = get_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
    SEED = 42
    utils.set_global_seed(SEED)
    utils.prepare_cudnn(deterministic=True)
    num_classes = 14

    # datasets for the two phases
    train_dataset = ChestXrayDataSet(data_dir=args.path_to_images,
                                     image_list_file=args.train_list,
                                     transform=transforms_train)
    val_dataset = ChestXrayDataSet(data_dir=args.path_to_images,
                                   image_list_file=args.val_list,
                                   transform=transforms_val)

    loaders = {
        'train': DataLoader(train_dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers),
        'valid': DataLoader(val_dataset,
                            batch_size=2,
                            shuffle=False,
                            num_workers=args.num_workers),
    }

    logdir = args.log_dir  # where model weights and logs are stored

    # model (optionally wrapped for multi-GPU runs)
    model = DenseNet121(num_classes)
    if len(args.gpus) > 1:
        model = nn.DataParallel(model)
    device = utils.get_device()
    runner = SupervisedRunner(device=device)

    optimizer = RAdam(model.parameters(), lr=args.lr, weight_decay=0.0003)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     factor=0.25,
                                                     patience=2)

    # Per-class positive weights to counter label imbalance.
    weights = torch.Tensor(
        [10, 100, 30, 8, 40, 40, 330, 140, 35, 155, 110, 250, 155,
         200]).to(device)
    criterion = BCEWithLogitsLoss(pos_weight=weights)

    class_names = [
        'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration', 'Mass',
        'Nodule', 'Pneumonia', 'Pneumothorax', 'Consolidation', 'Edema',
        'Emphysema', 'Fibrosis', 'Pleural_Thickening', 'Hernia'
    ]

    # Mean AUC drives model selection; accuracy is tracked alongside.
    callbacks = [
        AUCCallback(
            input_key="targets",
            output_key='logits',
            prefix='auc',
            class_names=class_names,
            num_classes=num_classes,
            activation='Sigmoid',
        ),
        AccuracyCallback(
            input_key="targets",
            output_key="logits",
            prefix="accuracy",
            accuracy_args=[1],
            num_classes=14,
            threshold=0.5,
            activation='Sigmoid',
        ),
    ]

    runner.train(
        model=model,
        logdir=logdir,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        loaders=loaders,
        num_epochs=args.epochs,
        callbacks=callbacks,
        main_metric='auc/_mean',
        minimize_metric=False,
        verbose=True,
    )
Пример #14
0
def main():
    """Federated ten-crop inference demo with PySyft (workers bob/alice).

    Restores the checkpoint (rewriting legacy DenseNet state-dict
    keys), runs remote inference per batch, thresholds the averaged
    predictions, and prints the detected diseases plus a heat map for
    each X-ray.
    """
    cudnn.benchmark = True

    # initialize and load the model
    model = DenseNet121(N_CLASSES).cuda()
    # model = torch.nn.DataParallel(model).cuda()

    ## getting the images list to be used for heatmap generation
    image_names = []
    with open(TEST_IMAGE_LIST, 'r') as f:
        for line in f:
            items = line.split()
            image_names.append(items[0])
    ##############################################################

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        # Legacy torchvision checkpoints use keys like "norm.1.weight";
        # current DenseNet code expects "norm1.weight" — rewrite them.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
        )
        state_dict = checkpoint['state_dict']
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # Strip the first dotted component (the DataParallel "module."
        # prefix) from every remaining key.
        for key in list(state_dict.keys()):
            split_key = key.split('.')
            new_key = '.'.join(split_key[1:])
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
        model.load_state_dict(state_dict)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    # ImageNet channel statistics used by the pretrained backbone.
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ]))
    # Split the dataset between the two PySyft workers.
    test_loader = sy.FederatedDataLoader(
        federated_dataset=test_dataset.federate((bob, alice)),
        batch_size=BATCH_SIZE,
        shuffle=False)

    # initialize the ground truth and output tensor
    # NOTE(review): `gt` is never appended to or read below — the
    # per-batch targets are ignored in this script.
    gt = torch.FloatTensor()
    gt = gt.cuda()
    pred = torch.FloatTensor()
    pred = pred.cuda()

    # switch to evaluate mode
    model.eval()
    model_nosyft = model  # plain reference used for heat-map rendering
    for i, (inp, target) in enumerate(test_loader):
        location = inp.location
        inp = inp.get()
        bs, n_crops, c, h, w = inp.size()
        # print(inp.size())
        # input_var = torch.autograd.Variable(torch.FloatTensor(inp.view(-1, c, h, w)).cuda(), volatile=True).send(location)
        # Send the crops and the model to this batch's worker, infer
        # remotely, then pull the crop-averaged output back.
        input_var = inp.view(-1, c, h, w).cuda().send(location)
        model.send(location)
        output = model(input_var)
        # print(f'output of the inference is {output.shape}')
        output_mean = output.view(bs, n_crops, -1).mean(1)
        output_mean_data = output_mean.get()
        model.get()
        pred = torch.cat((pred, output_mean_data.data), 0)
    # Binarize predictions against the decision threshold.
    pred_np = pred.cpu().numpy()
    pred_np[pred_np < THRESHOLD] = 0
    pred_np[pred_np >= THRESHOLD] = 1
    indexes = []
    for arr in pred_np:
        indexes.append([index for index, val in enumerate(arr) if val == 1])
    for index, disease_array in enumerate(indexes):
        if len(disease_array) > 0:
            print(
                f'XRAY :: {index + 1} :: {[CLASS_NAMES[i] for i in disease_array]}'
            )
        else:
            print(f'XRAY :: {index + 1} :: No Disease detected')
        get_heat_map(model_nosyft, image_names[index], DATA_DIR)
Пример #15
0
def main():
    """Train a DenseNet121 multi-label chest-X-ray classifier.

    Relies on module-level configuration (DATA_DIR, TRAIN/DEV/TEST image
    lists, ``params``, N_CLASSES, CLASS_NAMES, PRINT_FREQ) and on CUDA.
    Each epoch: train (optionally re-sampling the training set based on
    the previous epoch's per-class dev AUCs), evaluate on the train and
    dev sets with TenCrop averaging, then checkpoint model weights and
    pickle the history so an interrupted run can be resumed.
    """
    cudnn.benchmark = True
    model = DenseNet121(N_CLASSES).cuda()
    model = torch.nn.DataParallel(model).cuda()

    # data preprocess
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    # Training-time augmentation: random crop + horizontal flip.
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Evaluation-time transform: ten deterministic crops per image, so
    # each loader batch is 5-D (batch, n_crops, C, H, W).
    test_transform = transforms.Compose([
        transforms.Resize(256),
        # crop ten images from original
        transforms.TenCrop(224),
        transforms.Lambda(lambda crops: torch.stack(
            [transforms.ToTensor()(crop) for crop in crops])),
        transforms.Lambda(
            lambda crops: torch.stack([normalize(crop) for crop in crops]))
    ])

    # load data
    # train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
    #                                  image_list_file=TRAIN_IMAGE_LIST,
    #                                  transform=train_transform)
    # train_loader = DataLoader(
    #     dataset=train_dataset, batch_size=params["train_batch_size"], shuffle=True, num_workers=8, pin_memory=True)
    # The training split is also evaluated with the (deterministic)
    # test_transform to get comparable AUCs.
    train_evaluation_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TRAIN_IMAGE_LIST,
        transform=test_transform)
    train_evaluation_loader = DataLoader(dataset=train_evaluation_dataset,
                                         batch_size=params["test_batch_size"],
                                         shuffle=False,
                                         num_workers=8,
                                         pin_memory=True)
    dev_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                   image_list_file=DEV_IMAGE_LIST,
                                   transform=test_transform)
    dev_loader = DataLoader(dataset=dev_dataset,
                            batch_size=params["test_batch_size"],
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=test_transform)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=params["test_batch_size"],
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True)

    # Adam with weight decay taken from params["beta"].
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=params["lr"],
                                 weight_decay=params["beta"])

    def train(epoch, dev_auc):
        """Run one training epoch; returns the running-loss trace.

        When per-class dev AUCs from the previous epoch are given, the
        training set is rebuilt with AUC-driven augmentation; otherwise
        the plain training set is used.
        """
        print("start training epoch %d" % (epoch))
        start_time = time()
        local_step = 0
        running_loss = 0
        running_loss_list = []
        model.train()
        if dev_auc is not None:
            train_dataset = ChestXrayDataSetWithAugmentationEachEpoch(
                data_dir=DATA_DIR,
                image_list_file=TRAIN_IMAGE_LIST,
                aucs=dev_auc,
                transform=train_transform)
        else:
            train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                             image_list_file=TRAIN_IMAGE_LIST,
                                             transform=train_transform)
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=params["train_batch_size"],
                                  shuffle=True,
                                  num_workers=8,
                                  pin_memory=True)
        for i, (inp, target) in enumerate(train_loader):
            inp = inp.cuda()
            target = target.cuda()
            optimizer.zero_grad()
            output = model(inp)
            local_loss = F.binary_cross_entropy(output, target)
            running_loss += local_loss.item()
            local_loss.backward()
            optimizer.step()
            # Report (and record) the mean loss every PRINT_FREQ batches.
            if (i + 1) % PRINT_FREQ == 0:
                running_loss /= PRINT_FREQ
                print("epoch %d, batch %d/%d, loss: %.5f" %
                      (epoch, i + 1, len(train_loader), running_loss))
                running_loss_list.append(running_loss)
                running_loss = 0
        print("end training epoch %d, time elapsed: %.2fmin" %
              (epoch, (time() - start_time) / 60))
        return dict(running_loss_list=running_loss_list)

    def evaluate(epoch, dataset_loader, pytorch_dataset, dataset_name):
        """Evaluate with TenCrop averaging; returns per-class AUROC,
        average AUROC, and the dataset-size-weighted mean loss."""
        print("start evaluating epoch %d on %s" % (epoch, dataset_name))
        gt = torch.tensor([], dtype=torch.float32, device="cuda")
        pred = torch.tensor([], dtype=torch.float32, device="cuda")
        loss = 0.
        model.eval()
        with torch.no_grad():
            for i, (inp, target) in enumerate(dataset_loader):
                target = target.cuda()
                gt = torch.cat((gt, target), 0)
                # Flatten the ten crops into the batch dim, then average
                # the per-crop predictions back to one row per sample.
                bs, n_crops, c, h, w = inp.size()
                inp_reshaped = inp.view(-1, c, h, w).cuda()
                output = model(inp_reshaped)
                output_mean = output.view(bs, n_crops, -1).mean(1)
                pred = torch.cat((pred, output_mean), 0)
                # Weight each batch loss by its share of the dataset so
                # the sum is the mean loss over all samples.
                local_loss = F.binary_cross_entropy(output_mean, target)
                loss += local_loss * len(target) / len(pytorch_dataset)
        AUROCs = compute_AUCs(gt, pred, N_CLASSES)
        AUROC_avg = np.array(AUROCs).mean()
        print("epoch %d, %s, loss: %.5f, avg_AUC: %.5f" %
              (epoch, dataset_name, loss, AUROC_avg))
        print("epoch %d, %s, individual class AUC" % (epoch, dataset_name))
        for i in range(N_CLASSES):
            print('\tthe AUROC of %s is %.5f' % (CLASS_NAMES[i], AUROCs[i]))
        return dict(auroc=dict(zip(CLASS_NAMES, AUROCs)),
                    auroc_avg=AUROC_avg,
                    loss=loss.item())

    def init_history():
        """Fresh training history for a run starting at epoch 0."""
        return dict(epoch=0,
                    train_eval_vals_list=[],
                    dev_eval_vals_list=[],
                    best_dev_eval_vals=dict(auroc_avg=-np.inf, loss=np.inf),
                    best_dev_eval_vals_epoch=-1)

    def update_history(history, epoch, train_eval_vals, dev_eval_vals):
        """Record this epoch's metrics, checkpoint the model (best-by-dev
        AUC plus a per-epoch snapshot), and pickle the history."""
        history["epoch"] = epoch
        history["train_eval_vals_list"].append(train_eval_vals)
        history["dev_eval_vals_list"].append(dev_eval_vals)
        if dev_eval_vals["auroc_avg"] > history["best_dev_eval_vals"][
                "auroc_avg"]:
            history["best_dev_eval_vals"] = dev_eval_vals
            history["best_dev_eval_vals_epoch"] = epoch
            if epoch >= 1:
                print("saving model...")
                state_dict = model.state_dict()
                torch.save(state_dict, params["best_model_file_path"])
        if epoch >= 1:
            state_dict = model.state_dict()
            torch.save(state_dict, params["model_file_path"] % (epoch))
        with open(params["history_file_path"], 'wb') as f:
            pickle.dump(history, f)

    def train_initialization():
        """Prepare the run directory and resume from a previous run when a
        pickled history plus the matching checkpoint exist.

        Returns (history, last_completed_epoch)."""
        if not os.path.exists(params["base_dir"]):
            os.mkdir(params["base_dir"])
        with open(params["params_file_path"], 'w') as f:
            yaml.dump(params, f, default_flow_style=False)
        if os.path.exists(params["history_file_path"]):
            with open(params["history_file_path"], 'rb') as f:
                old_history = pickle.load(f)
                last_epoch = old_history["epoch"]
                if last_epoch > params["epochs"]:
                    print("training completed")
                    exit(0)
                model_file = params["model_file_path"] % (last_epoch)
                if os.path.exists(model_file):
                    model.load_state_dict(torch.load(model_file))
                    print("training resumed from epoch %d" % last_epoch)
                    return old_history, last_epoch
        return init_history(), 0

    history, last_epoch = train_initialization()
    for epoch in range(last_epoch + 1, params["epochs"] + 1):
        # Feed last epoch's per-class dev AUCs into the augmented sampler
        # (None on the very first epoch, before any dev evaluation).
        if len(history["dev_eval_vals_list"]) == 0:
            dev_roc = None
        else:
            dev_roc = [
                history["dev_eval_vals_list"][-1]["auroc"][class_name]
                for class_name in CLASS_NAMES
            ]
        train_eval_vals = train(epoch, dev_roc)
        train_eval_vals2 = evaluate(epoch, train_evaluation_loader,
                                    train_evaluation_dataset, "train set")
        dev_eval_vals = evaluate(epoch, dev_loader, dev_dataset, "dev set")
        update_history(history, epoch, train_eval_vals, dev_eval_vals)
    print("training completed")
Пример #16
0
    else:
        print("=> no checkpoint found")

    print("=======>load dataset")

    BATCH_SIZE = 6
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    test_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])
    print("=>load train dataset")
    train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                     image_list_file=trainTXTFile,
                                     transform=test_transform)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=BATCH_SIZE,
                                               shuffle=False,
                                               num_workers=8,
                                               pin_memory=True)
    print("=>load test dataset")
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=testTXTFile,
                                    transform=test_transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=BATCH_SIZE,
                                              shuffle=False,
                                              num_workers=8,
                                              pin_memory=True)
    def train(self,
              TRAIN_IMAGE_LIST,
              VAL_IMAGE_LIST,
              NUM_EPOCHS=10,
              LR=0.001,
              BATCH_SIZE=64,
              start_epoch=0,
              logging=True,
              save_path=None,
              freeze_feature_layers=True):
        """
        Train the CovidAID.

        Args:
            TRAIN_IMAGE_LIST / VAL_IMAGE_LIST: image-list files for the
                training and validation ChestXrayDataSet splits.
            NUM_EPOCHS: total number of epochs to run.
            LR: Adam learning rate.
            BATCH_SIZE: samples per step (each sample is a TenCrop stack).
            start_epoch: epoch index to resume from.
            logging: when True, print a JSON log line per epoch.
            save_path: directory for 'train.log' and per-epoch checkpoints;
                when None, nothing is written to disk.
            freeze_feature_layers: freeze the DenseNet feature extractor and
                train only the classifier head.
        """
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        # Every sample becomes a (n_crops, C, H, W) stack of ten 224x224
        # crops; per-crop predictions are averaged back to one row.
        tencrop_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack(
                [normalize(crop) for crop in crops]))
        ])

        def make_loader(image_list_file, shuffle):
            # Build one dataset+loader pair; distributed runs use a
            # DistributedSampler (which forbids loader-level shuffling).
            dataset = ChestXrayDataSet(
                image_list_file=image_list_file,
                transform=tencrop_transform,
                combine_pneumonia=self.combine_pneumonia)
            if self.distributed:
                loader = DataLoader(dataset=dataset,
                                    batch_size=BATCH_SIZE,
                                    shuffle=False,
                                    num_workers=8,
                                    pin_memory=True,
                                    sampler=DistributedSampler(dataset))
            else:
                loader = DataLoader(dataset=dataset,
                                    batch_size=BATCH_SIZE,
                                    shuffle=shuffle,
                                    num_workers=8,
                                    pin_memory=True)
            return dataset, loader

        train_dataset, train_loader = make_loader(TRAIN_IMAGE_LIST, True)
        val_dataset, val_loader = make_loader(VAL_IMAGE_LIST, True)

        # Freeze the feature extractor so only the head is optimized.
        if freeze_feature_layers:
            print("Freezing feature layers")
            for param in self.net.densenet121.features.parameters():
                param.requires_grad = False

        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      self.net.parameters()),
                               lr=LR)

        for epoch in range(start_epoch, NUM_EPOCHS):
            # ---- training pass ----
            self.net.train()
            tot_loss = 0.0
            for inputs, target in tqdm(train_loader, total=len(train_loader)):
                inputs = inputs.cuda()
                target = target.cuda()

                # inputs: [BATCH_SIZE, NUM_CROPS, C, H, W] -> flatten the
                # crops into the batch dim, average predictions per sample.
                bs, n_crops, c, h, w = inputs.size()
                preds = self.net(inputs.view(-1, c, h, w)) \
                    .view(bs, n_crops, -1).mean(dim=1)

                loss = train_dataset.loss(preds, target)
                tot_loss += float(loss.data)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            tot_loss /= len(train_dataset)

            # Clear cache
            torch.cuda.empty_cache()

            # ---- validation pass (no gradient tracking) ----
            # Replaces the removed Variable(..., volatile=True) idiom.
            self.net.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, target in tqdm(val_loader, total=len(val_loader)):
                    inputs = inputs.cuda()
                    target = target.cuda()

                    bs, n_crops, c, h, w = inputs.size()
                    preds = self.net(inputs.view(-1, c, h, w)) \
                        .view(bs, n_crops, -1).mean(1)
                    loss = val_dataset.loss(preds, target)

                    val_loss += float(loss.data)

            val_loss /= len(val_dataset)

            # Clear cache
            torch.cuda.empty_cache()

            # ---- logging & checkpointing ----
            timestamp = str(datetime.datetime.now()).split('.')[0]
            log = json.dumps({
                'timestamp': timestamp,
                'epoch': epoch + 1,
                'train_loss': float('%.5f' % tot_loss),
                'val_loss': float('%.5f' % val_loss),
                'lr': float('%.6f' % LR)
            })
            if logging:
                print(log)

            # BUGFIX: the original joined save_path (default None) into the
            # log/checkpoint paths unconditionally, raising TypeError; only
            # touch the filesystem when a save directory was provided.
            if save_path is not None:
                log_file = os.path.join(save_path, 'train.log')
                with open(log_file, 'a') as f:
                    f.write("{}\n".format(log))

                model_path = os.path.join(save_path,
                                          'epoch_%d.pth' % (epoch + 1))
                self.save_model(model_path)

        print('Finished Training')
Пример #18
0
def main(modelfile):
    """Benchmark an OpenVINO IR chest-X-ray model with async inference.

    Loads model/<modelfile> (.xml + sibling .bin), streams TenCrop batches
    from the test set through a pool of asynchronous infer requests, then
    reports per-class and average AUC plus wall-clock time.  Relies on
    module-level ``args``, DATA_DIR, TEST_IMAGE_LIST, N_CROPS, N_CLASSES,
    CLASS_NAMES and the InferRequestsQueue helper.
    """
    model_xml = os.path.join('model', modelfile)
    # Weights file sits next to the .xml with the same stem.
    model_bin = model_xml.replace('.xml', '.bin')

    log.info('Creating Inference Engine')
    ie = IECore()
    net = ie.read_network(model=model_xml, weights=model_bin)

    log.info('Preparing input blobs')
    input_blob = next(iter(net.input_info))
    out_blob = next(iter(net.outputs))
    # TenCrop expands every sample into N_CROPS images, so the network
    # batch must hold batch_size * N_CROPS flattened crops.
    net.batch_size = (args.batch_size * N_CROPS)

    n, c, h, w = net.input_info[input_blob].input_data.shape

    # for image load
    normalize = transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
            image_list_file=TEST_IMAGE_LIST,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.TenCrop(224),
                transforms.Lambda
                (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                transforms.Lambda
                (lambda crops: torch.stack([normalize(crop) for crop in crops]))
                ]))

    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=False)

    # Ground-truth / prediction accumulators, filled only after all async
    # requests have completed.
    gt = torch.FloatTensor()
    pred = torch.FloatTensor()

    # loading model to the plugin
    log.info('Loading model to the plugin')

    #config = {'CPU_THREADS_NUM': '48', 'CPU_THROUGHPUT_STREAMS': 'CPU_THROUGHPUT_AUTO'}
    config = {'CPU_THROUGHPUT_STREAMS': '%d' % args.cpu_throughput_streams}
    exec_net = ie.load_network(network=net, device_name='CPU', config=config, num_requests=args.num_requests)

    # Number of requests
    infer_requests = exec_net.requests
    print('reqeuest len', len(infer_requests))
    request_queue = InferRequestsQueue(infer_requests, out_blob)

    start_time = timeit.default_timer()

    for i, (inp, target) in enumerate(test_loader):
        bs, n_crops, c, h, w = inp.size()

        # Flatten the crops into the batch dimension for the IR network.
        images = inp.view(-1, c, h, w).numpy()

        # The network batch size is fixed: zero-pad a short final batch so
        # it still fills batch_size * n_crops slots (bs tells the request
        # how many rows are real).
        if bs !=  args.batch_size:
            images2 = np.zeros(shape=(args.batch_size * n_crops, c, h, w))
            images2[:bs*n_crops, :c, :h, :w] = images
            images = images2

        # Blocks until one of the pooled requests is idle.
        infer_request = request_queue.get_idle_request()

        infer_request.start_async({input_blob: images}, bs, target)

        # NOTE(review): hard cap at 21 batches — presumably a benchmarking
        # limit, not full-test-set evaluation; confirm before reuse.
        if i == 20:
            break

    # wait the latest inference executions
    request_queue.wait_all()
    for i, queue in enumerate(request_queue.requests):
        # print(i, queue)
        gt = torch.cat((gt, queue.get_ground_truth()), 0)
        pred = torch.cat((pred, queue.get_prediction()), 0)

    print('Elapsed time: %0.2f sec.' % (timeit.default_timer() - start_time))

    AUCs = [roc_auc_score(gt.cpu()[:, i], pred.cpu()[:, i]) for i in range(N_CLASSES)]
    AUC_avg = np.array(AUCs).mean()
    print('The average AUC is {AUC_avg:.3f}'.format(AUC_avg=AUC_avg))
    for i in range(N_CLASSES):
        print('The AUC of {} is {:.3f}'.format(CLASS_NAMES[i], AUCs[i]))
Пример #19
0
def main():
    """Train a chest-X-ray classifier on CPU, then evaluate on the test set.

    Builds DenseNet121 or ResNet18 depending on module-level USE_DENSENET,
    optionally restores a DenseNet checkpoint from CKPT_PATH, trains for
    RUNS epochs via train_run/val_run, and finally runs TenCrop-averaged
    evaluation on the test split, printing per-class AUROCs.
    """
    # initialize and load the model
    if USE_DENSENET:
        model = DenseNet121(N_CLASSES).cpu()
    else:
        model = ResNet18(N_CLASSES).cpu()

    print(model)

    model = torch.nn.DataParallel(model).cpu()

    # Only the DenseNet variant has a pretrained checkpoint to restore.
    if USE_DENSENET:
        if os.path.isfile(CKPT_PATH):
            print("=> loading checkpoint")
            checkpoint = torch.load(CKPT_PATH)
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint")
        else:
            print("=> no checkpoint found")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    # read training data and train
    train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                     image_list_file=TRAIN_IMAGE_LIST,
                                     transform=transforms.Compose([
                                         transforms.Resize((224, 224)),
                                         transforms.ToTensor(),
                                         normalize
                                     ]))
    train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=8, pin_memory=True)

    val_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                   image_list_file=VAL_IMAGE_LIST,
                                   transform=transforms.Compose([
                                       transforms.Resize((224, 224)),
                                       transforms.ToTensor(),
                                       normalize
                                   ]))
    val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE,
                            shuffle=True, num_workers=8, pin_memory=True)

    criterion = nn.BCELoss().cpu()
    optimizer = optim.Adam(model.parameters())

    for epoch in range(0, RUNS):
        print("Epoch " + str(epoch + 1))
        train_run(model, train_loader, optimizer, criterion, epoch)
        val_run(model, val_loader, criterion, epoch)

    # Test set uses ten crops per image; predictions are averaged per crop.
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=transforms.Compose([
                                        transforms.Resize(256),
                                        transforms.TenCrop(224),
                                        transforms.Lambda
                                        (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                                        transforms.Lambda
                                        (lambda crops: torch.stack([normalize(crop) for crop in crops]))
                                    ]))
    test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=8, pin_memory=True)

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor().cpu()
    pred = torch.FloatTensor().cpu()

    # switch to evaluate mode; torch.no_grad() replaces the removed
    # Variable(..., volatile=True) idiom and avoids building autograd
    # graphs during evaluation.
    model.eval()
    with torch.no_grad():
        for inp, target in test_loader:
            target = target.cpu()
            gt = torch.cat((gt, target), 0)
            # Flatten crops into the batch dim, then average per sample.
            bs, n_crops, c, h, w = inp.size()
            output = model(inp.view(-1, c, h, w).cpu())
            output_mean = output.view(bs, n_crops, -1).mean(1)
            pred = torch.cat((pred, output_mean.data), 0)

    AUROCs = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROCs).mean()
    print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))
Пример #20
0
def main():
    """Evaluate a DenseNet121 checkpoint on the chest-X-ray test set.

    Restores weights from CKPT_PATH (converting the legacy checkpoint
    layout when OLD_CHECKPOINT is set), runs TenCrop-averaged evaluation
    on the test split, prints per-class and average AUROC, and writes the
    AUROCs next to the checkpoint as test_auroc.txt.
    """
    cudnn.benchmark = True

    # initialize and load the model
    model = DenseNet121(N_CLASSES).cuda()
    model = torch.nn.DataParallel(model).cuda()

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        if OLD_CHECKPOINT:
            # Legacy checkpoints wrap the weights in a dict and use
            # outdated parameter names; rewrite in place before loading.
            update_checkpoint_dict(checkpoint)
            model.load_state_dict(checkpoint['state_dict'])
        else:
            model.load_state_dict(checkpoint)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            # crop ten images from original
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ]))
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True)

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        for inp, target in test_loader:
            target = target.cuda()

            # batch_size, num_class
            gt = torch.cat((gt, target), 0)
            # batch_size, n_crops, channels, height, width
            bs, n_crops, c, h, w = inp.size()
            # Flatten crops into the batch dimension.  The removed
            # Variable(..., volatile=True) wrapper was deprecated and
            # redundant inside torch.no_grad().
            output = model(inp.view(-1, c, h, w).cuda())
            # average across crops -> (batch_size, num_class)
            output_mean = output.view(bs, n_crops, -1).mean(1)
            pred = torch.cat((pred, output_mean), 0)

    AUROCs = compute_AUCs(gt, pred, N_CLASSES)
    AUROCs = np.array(AUROCs)
    AUROC_avg = AUROCs.mean()
    print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))
    # Persist per-class AUROCs alongside the checkpoint.
    base_dir = os.path.dirname(CKPT_PATH)
    np.savetxt(os.path.join(base_dir, "test_auroc.txt"), AUROCs)
Пример #21
0
def main():
    """Load a (possibly legacy) DenseNet121 checkpoint and report test AUROCs.

    Restores weights from CKPT_PATH — rewriting legacy torchvision DenseNet
    parameter names on the fly — then evaluates the test split with TenCrop
    averaging and prints per-class and mean AUROC.
    """
    cudnn.benchmark = True

    # Build the network and wrap it for (multi-)GPU evaluation.
    net = DenseNet121(N_CLASSES).cuda()
    net = torch.nn.DataParallel(net).cuda()

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        weights = checkpoint['state_dict']
        # Old torchvision DenseNets named sub-layers "norm.1", "conv.2",
        # ...; current versions use "norm1", "conv2".  Rewrite matching
        # keys by dropping the dot between module name and index.
        legacy = re.compile(
                r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
                )
        for old_key in list(weights.keys()):
            m = legacy.match(old_key)
            if m:
                weights[m.group(1) + m.group(2)] = weights.pop(old_key)
        net.load_state_dict(weights)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    # Ten crops per test image; every sample arrives as (n_crops, C, H, W).
    dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack(
                [normalize(crop) for crop in crops]))
        ]))
    loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE,
                        shuffle=False, num_workers=8, pin_memory=True)

    # Running ground-truth and prediction accumulators.
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    # Inference only: evaluation mode, no gradient tracking.
    net.eval()
    with torch.no_grad():
        for images, labels in loader:
            labels = labels.cuda()
            gt = torch.cat((gt, labels), 0)
            bs, n_crops, c, h, w = images.size()
            batch_input = torch.autograd.Variable(
                images.view(-1, c, h, w).cuda())
            crop_scores = net(batch_input)
            # Collapse the per-crop scores back to one row per sample.
            mean_scores = crop_scores.view(bs, n_crops, -1).mean(1)
            pred = torch.cat((pred, mean_scores.data), 0)

    AUROCs = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROCs).mean()
    print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))
Пример #22
0
def main(args):
    """Evaluate a saved DenseNet121 on the chest-X-ray test set.

    Chooses CUDA when available, restores weights from args.model_path,
    runs TenCrop-averaged inference over the test split, and prints the
    per-class and average AUC.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using %s device.' % device)

    # Build the network; mirror across GPUs when several are present.
    model = DenseNet121(N_CLASSES).to(device)
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).to(device)

    model.load_state_dict(torch.load(args.model_path, map_location=device))
    print('model state has loaded')

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    # Each sample is expanded into ten normalized 224x224 crops.
    dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(lambda crops: torch.stack(
                [normalize(crop) for crop in crops]))
        ]))
    loader = DataLoader(dataset=dataset, batch_size=args.batch_size,
                        shuffle=False)

    # Running ground-truth / prediction tensors.
    labels_all = torch.FloatTensor().to(device)
    scores_all = torch.FloatTensor().to(device)

    # Inference only.
    model.eval()

    for batch_idx, (images, labels) in enumerate(loader):
        tic = timeit.default_timer()

        labels = labels.to(device)
        bs, n_crops, c, h, w = images.size()
        flat = images.view(-1, c, h, w).to(device)

        with torch.no_grad():
            crop_scores = model(flat)

        # Average the ten per-crop scores back to one row per sample.
        mean_scores = crop_scores.view(bs, n_crops, -1).mean(1)

        labels_all = torch.cat((labels_all, labels))
        scores_all = torch.cat((scores_all, mean_scores))

        print('\rbatch %03d/%03d %6.3fsec' %
              (batch_idx, len(loader), (timeit.default_timer() - tic)),
              end='')

    aucs = [roc_auc_score(labels_all.cpu()[:, i], scores_all.cpu()[:, i])
            for i in range(N_CLASSES)]
    print('The average AUC is %6.3f' % np.mean(aucs))

    for i in range(N_CLASSES):
        print('The AUC of %s is %6.3f' % (CLASS_NAMES[i], aucs[i]))
Пример #23
0
def main():
    """Extract DenseNet-121 fc/attention features for every test image and save them to disk.

    Loads a checkpointed DenseNet121, runs the test set through it once
    (single center crop per image, no crop augmentation) and writes one
    feature file per image:
      * <SAVE_DIR>/densenet121_fc_feats/<image_id>.npy   -- fc features
      * <SAVE_DIR>/densenet121_att_feats/<image_id>.npz  -- attention features

    Relies on module-level DATA_DIR, TEST_IMAGE_LIST, CKPT_PATH, SAVE_DIR,
    BATCH_SIZE, N_CLASSES, image_id_dict and generate_image_list().
    This function is feature-extraction only; no AUROC evaluation is done.
    """
    cudnn.benchmark = True
    generate_image_list(DATA_DIR, TEST_IMAGE_LIST)

    # initialize and load the model
    model = DenseNet121(N_CLASSES).cuda()
    model = torch.nn.DataParallel(model).cuda()

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    # ImageNet channel statistics
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=transforms.Compose([
                                        transforms.Resize(224),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        normalize,
                                    ]))
    test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=8, pin_memory=True)

    # ground truth / prediction accumulators (kept on GPU)
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    # switch to evaluate mode
    model.eval()

    # create the feature output directories under SAVE_DIR
    fc_dir = os.path.join(SAVE_DIR, 'densenet121_fc_feats')
    att_dir = os.path.join(SAVE_DIR, 'densenet121_att_feats')
    if not os.path.exists(fc_dir):
        os.mkdir(fc_dir)
    if not os.path.exists(att_dir):
        os.mkdir(att_dir)

    for i, (names, inp, target) in enumerate(test_loader):
        # BUGFIX: was Python-2 `print i, len(test_loader)` — a syntax error
        # under Python 3, which the rest of this file targets.
        print(i, len(test_loader))
        target = target.cuda()
        gt = torch.cat((gt, target), 0)
        bs, c, h, w = inp.size()
        # BUGFIX: `Variable(..., volatile=True)` was removed in PyTorch 0.4;
        # torch.no_grad() is the supported way to disable autograd here.
        with torch.no_grad():
            input_var = inp.view(-1, c, h, w).cuda()
            att_feats, fc_feats, output = model(input_var)
        pred = torch.cat((pred, output.data), 0)

        # save per-image features under their numeric image id
        att_feats, fc_feats = att_feats.data.cpu().numpy(), fc_feats.data.cpu().numpy()
        for i_name, name in enumerate(names):
            name = name.split('/')[-1].split('.')[0]
            name = str(image_id_dict[name])
            np.save(os.path.join(fc_dir, name + '.npy'), fc_feats[i_name])
            np.savez(os.path.join(att_dir, name + '.npz'), feats=att_feats[i_name])
Пример #24
0
def main():
    """Evaluate a CPU DenseNet121 on the test set with ten-crop augmentation.

    Loads a (possibly pre-0.4, DataParallel-prefixed) checkpoint, remaps its
    keys to the current torchvision DenseNet naming scheme, loads the weights,
    then computes and prints per-class and mean AUROC.

    Relies on module-level DATA_DIR, TEST_IMAGE_LIST, CKPT_PATH, BATCH_SIZE,
    N_CLASSES, CLASS_NAMES and compute_AUCs().
    """
    cudnn.benchmark = True

    # initialize and load the model (CPU-only evaluation)
    model = DenseNet121(N_CLASSES).cpu()

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)

        # Code modified from torchvision densenet source for loading from pre .4 densenet weights.
        state_dict = checkpoint['state_dict']
        remove_data_parallel = True  # Change if you don't want to use nn.DataParallel(model)

        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
        )
        for key in list(state_dict.keys()):
            match = pattern.match(key)
            new_key = match.group(1) + match.group(2) if match else key
            # strip the 'module.' prefix that nn.DataParallel adds
            new_key = new_key[7:] if remove_data_parallel else new_key
            state_dict[new_key] = state_dict[key]
            # Delete old key only if modified.
            if match or remove_data_parallel:
                del state_dict[key]

        # BUGFIX: the remapped weights were never loaded into the model —
        # evaluation silently ran on randomly initialized weights.
        model.load_state_dict(state_dict)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    # ten-crop test-time augmentation: each image yields 10 crops that are
    # scored independently and averaged.
    test_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ]))
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             num_workers=8,
                             pin_memory=True)

    # initialize the ground truth and output tensors (CPU)
    gt = torch.FloatTensor()
    pred = torch.FloatTensor()

    # switch to evaluate mode
    model.eval()

    # inference only: avoid building autograd graphs (the deprecated
    # `torch.autograd.Variable` wrapper added nothing and is removed here)
    with torch.no_grad():
        for i, (inp, target) in enumerate(test_loader):
            gt = torch.cat((gt, target.cpu()), 0)
            bs, n_crops, c, h, w = inp.size()
            # flatten the crop dimension into the batch dimension
            output = model(inp.view(-1, c, h, w))
            print(output.tolist())
            # average the 10 crop scores back to one prediction per image
            output_mean = output.view(bs, n_crops, -1).mean(1)
            pred = torch.cat((pred, output_mean.data), 0)

    AUROCs = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROCs).mean()
    print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))
Пример #25
0
    def train(dataDir, imageListFileTrain, imageListFileVal, transResize,
              transCrop, isTrained, classCount, batchSize, epochSize,
              launchTimestamp, checkpoint):
        """Train the network, checkpointing whenever validation loss improves.

        Args:
            dataDir - path to the data dir
            imageListFileTrain - path to the image list file to train on
            imageListFileVal - path to the image list file to validate on
            transResize - size of the image to scale down to
            transCrop - size of the cropped image
            isTrained - if True, uses pre-trained version of the network (pre-trained on imagenet)
            classCount - number of output classes
            batchSize - batch size
            epochSize - number of epochs
            launchTimestamp - date/time, used to assign unique name for the checkpoint file
            checkpoint - if not None loads the model and continues training

        """

        # SETTINGS
        # ^^^^^^^^

        # initialize and load the model
        # ---------------
        print("train begins!=======================")
        model = CheXNet(classCount, isTrained).cuda()
        model = torch.nn.DataParallel(model).cuda()

        # data transforms (random crop/flip augmentation + ImageNet stats)
        # ---------------
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
        transformList = []
        transformList.append(transforms.Resize(transResize))
        transformList.append(transforms.RandomResizedCrop(transCrop))
        transformList.append(transforms.RandomHorizontalFlip())
        transformList.append(transforms.ToTensor())
        transformList.append(normalize)
        transform = transforms.Compose(transformList)

        # datasets
        # ---------------
        datasetTrain = ChestXrayDataSet(data_dir=dataDir,
                                        image_list_file=imageListFileTrain,
                                        transform=transform)
        datasetVal = ChestXrayDataSet(data_dir=dataDir,
                                      image_list_file=imageListFileVal,
                                      transform=transform)

        print(len(datasetTrain))
        # NOTE(review): shuffle=False on a *training* loader is unusual —
        # confirm whether shuffling was intentionally disabled here.
        dataLoaderTrain = DataLoader(dataset=datasetTrain,
                                     batch_size=batchSize,
                                     shuffle=False,
                                     num_workers=8,
                                     pin_memory=True)
        dataLoaderVal = DataLoader(dataset=datasetVal,
                                   batch_size=batchSize,
                                   shuffle=False,
                                   num_workers=8,
                                   pin_memory=True)

        # optimizer and scheduler
        # ---------------
        optimizer = optim.Adam(model.parameters(),
                               lr=0.0001,
                               betas=(0.9, 0.999),
                               eps=1e-08,
                               weight_decay=1e-5)
        scheduler = ReduceLROnPlateau(optimizer,
                                      factor=0.1,
                                      patience=5,
                                      mode='min')

        # loss
        # ---------------
        # BCELoss() defaults to mean reduction, identical to the deprecated
        # size_average=True argument previously passed here.
        loss = torch.nn.BCELoss()

        # Load checkpoint (resume training)
        # ---------------
        if checkpoint is not None:
            modelCheckpoint = torch.load(checkpoint)
            model.load_state_dict(modelCheckpoint['state_dict'])
            optimizer.load_state_dict(modelCheckpoint['optimizer'])

        # Train
        # ^^^^^

        lossMin = 100000

        for epochIdx in range(0, epochSize):
            print("EpochIdx: ################ ", epochIdx)

            ChexnetTrainer.epochTrain(model, dataLoaderTrain, optimizer,
                                      scheduler, classCount, loss)
            lossVal = ChexnetTrainer.epochVal(model, dataLoaderVal, optimizer,
                                              scheduler, classCount, loss)

            timestampTime = time.strftime("%H%M%S")
            timestampDate = time.strftime("%d%m%Y")
            timestampEND = timestampDate + '-' + timestampTime

            # ReduceLROnPlateau steps on the validation loss
            scheduler.step(lossVal)

            # checkpoint only when validation loss improves
            if lossVal < lossMin:
                lossMin = lossVal
                torch.save(
                    {
                        'epoch': epochIdx + 1,
                        'state_dict': model.state_dict(),
                        'best_loss': lossMin,
                        'optimizer': optimizer.state_dict()
                    }, 'm-' + launchTimestamp + '.pth.tar')
                print('Epoch [' + str(epochIdx + 1) + '] [save] [' +
                      timestampEND + '] loss= ' + str(lossVal))
            else:
                print('Epoch [' + str(epochIdx + 1) + '] [----] [' +
                      timestampEND + '] loss= ' + str(lossVal))
Пример #26
0
def main():
    """Fine-tune DenseNet121 on the ten-crop training set for two epochs.

    Builds a ten-crop training loader, runs a plain SGD training loop and
    prints the cumulative loss every 100 batches.  Relies on module-level
    DATA_DIR, TRAIN_IMAGE_LIST, BATCH_SIZE, LOADER_WORKERS, LR, MOMENTUM,
    N_CLASSES and CKPT_PATH.
    """
    cudnn.benchmark = True

    # initialize and load the model
    model = DenseNet121(N_CLASSES).cuda()
    model = torch.nn.DataParallel(model).cuda()

    #if os.path.isfile(CKPT_PATH):
    if 0:  # checkpoint loading intentionally disabled
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    # define loss and optimizer
    # NOTE(review): CrossEntropyLoss expects class indices (or a probability
    # distribution over classes); chest X-ray labels are multi-label, for
    # which BCEWithLogitsLoss is the usual choice — confirm the intended
    # target format before training seriously.
    criterion = nn.CrossEntropyLoss()
    optimiser = optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    train_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TRAIN_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ]))
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              num_workers=LOADER_WORKERS,
                              pin_memory=True)

    # ground-truth accumulator
    gt = torch.FloatTensor()
    gt = gt.cuda()

    running_loss = 0.0
    # switch to train mode
    model.train()
    print("starting loop")
    try:
        for epoch in range(2):
            for i, (inp, label) in enumerate(train_loader):
                label = label.cuda()
                gt = torch.cat((gt, label), 0)
                bs, n_crops, c, h, w = inp.size()
                input_var = inp.view(-1, c, h, w).cuda()
                #zero parameter gradients
                optimiser.zero_grad()
                #fw + back + optimise
                output = model(input_var)
                # BUGFIX: `output` has one row per crop (bs * n_crops) while
                # `label` has one row per image (bs); average the crop scores
                # back to per-image predictions before computing the loss.
                output_mean = output.view(bs, n_crops, -1).mean(1)
                loss = criterion(output_mean, label)
                loss.backward()
                optimiser.step()
                running_loss += loss.item()
                #print statistics
                if not i % 100:
                    print('[%d, %5d] loss=%.3f' % (epoch, i, running_loss))
    except Exception:
        # BUGFIX: `raise ()` raised "TypeError: exceptions must derive from
        # BaseException" instead of re-raising; a bare `raise` re-raises the
        # original exception.  The bare `except:` is narrowed to Exception.
        print('error in iteration: ' + str(i))
        raise
Пример #27
0
def main():
	"""Train the ResidualAttention model on the ChestX-ray dataset.

	Builds train/val/test loaders, optionally resumes model/optimizer/
	scheduler state from ``<exp_dir>/<exp_name>_global.pth``, then runs the
	training loop for ``exp_cfg['NUM_EPOCH']`` epochs, printing the mean
	loss per epoch.  Relies on module-level data_dir, exp_cfg, args, device
	and max_batch_capacity.
	"""
	# ================= TRANSFORMS ================= #
	normalize = transforms.Normalize(
		mean=[0.485, 0.456, 0.406],
		std=[0.229, 0.224, 0.225]
	)

	transform_train = transforms.Compose([
		transforms.Resize(tuple(exp_cfg['dataset']['resize'])),
		transforms.RandomResizedCrop(tuple(exp_cfg['dataset']['crop'])),
		transforms.RandomHorizontalFlip(),
		transforms.ToTensor(),
		normalize,
	])

	transform_test = transforms.Compose([
		transforms.Resize(tuple(exp_cfg['dataset']['resize'])),
		transforms.CenterCrop(tuple(exp_cfg['dataset']['crop'])),
		transforms.ToTensor(),
		normalize,
	])

	# ================= LOAD DATASET ================= #
	train_dataset = ChestXrayDataSet(data_dir = data_dir,split = 'train', transform = transform_train)
	train_loader = DataLoader(dataset = train_dataset, batch_size = max_batch_capacity, shuffle = True, num_workers = 4)

	val_dataset = ChestXrayDataSet(data_dir = data_dir, split = 'val', transform = transform_test)
	val_loader = DataLoader(dataset = val_dataset, batch_size = 32, shuffle = False, num_workers = 4)

	test_dataset = ChestXrayDataSet(data_dir = data_dir, split = 'test', transform = transform_test)
	test_loader = DataLoader(dataset = test_dataset, batch_size = 32, shuffle = False, num_workers = 4)

	# ================= MODELS ================= #
	Model = ResidualAttention().to(device)

	# ================= OPTIMIZER ================= #
	optimizer = optim.SGD(Model.parameters(), **exp_cfg['optimizer']['SGD'])

	# ================= SCHEDULER ================= #
	lr_scheduler = optim.lr_scheduler.StepLR(optimizer , **exp_cfg['lr_scheduler'])

	# ================= LOSS FUNCTION ================= #
	criterion = nn.BCELoss()

	if args.resume:
		start_epoch = 0
		checkpoint_global = path.join(args.exp_dir, args.exp_dir.split('/')[-1] + '_global.pth')

		if path.isfile(checkpoint_global):
			save_dict = torch.load(checkpoint_global)
			start_epoch = max(save_dict['epoch'], start_epoch)
			# BUGFIX: the checkpoint must be loaded into the objects that
			# actually exist here (Model/optimizer/lr_scheduler); the old
			# code referenced undefined GlobalModel/optimizer_global/
			# lr_scheduler_global names and raised NameError on resume.
			Model.load_state_dict(save_dict['net'])
			optimizer.load_state_dict(save_dict['optim'])
			lr_scheduler.load_state_dict(save_dict['lr_scheduler'])
			print(" Loaded Global Branch Model checkpoint")

		# NOTE(review): resume of the Local/Fusion branches was removed —
		# LocalModel/FusionModel are not constructed in this version, so
		# those resume blocks could only raise NameError.

		start_epoch += 1

	else:
		start_epoch = 0

	for epoch in range(start_epoch, exp_cfg['NUM_EPOCH']):
		print(' Epoch [{}/{}]'.format(epoch , exp_cfg['NUM_EPOCH'] - 1))

		Model.train()

		running_loss = 0.

		progressbar = tqdm(range(len(train_loader)))

		for i, (image, target) in enumerate(train_loader):

			image_cuda = image.to(device)
			target_cuda = target.to(device)

			optimizer.zero_grad()

			# compute output
			output = Model(image_cuda)

			# loss
			loss = criterion(output, target_cuda)
			loss.backward()
			# BUGFIX: optimizer.step() was missing (lost in commented-out
			# batch-multiplier code), so the parameters never updated.
			optimizer.step()

			progressbar.update(1)

			running_loss += loss.data.item()

		progressbar.close()

		lr_scheduler.step()

		# BUGFIX: average over the number of batches; dividing by the last
		# enumerate index was off by one and raised ZeroDivisionError for
		# single-batch epochs.
		epoch_train_loss = float(running_loss) / float(len(train_loader))
		print(' Epoch over Loss: {:.5f}'.format(epoch_train_loss))
def main():
    """Train the AG-CNN three-branch model (global / local / fusion).

    Loads data, initializes the three branch models (optionally from a
    baseline or per-branch checkpoints), trains all three jointly with a
    weighted loss, tests after every epoch, and saves per-branch weights.
    Relies on module-level DATA_DIR, TRAIN_IMAGE_LIST, VAL_IMAGE_LIST,
    BATCH_SIZE, N_CLASSES, CKPT_PATH*, LR_G/LR_L/LR_F, num_epochs,
    save_model_path/save_model_name, Attention_gen_patchs() and test().
    """
    print("[Info]: Loading Data ...")
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    train_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                     image_list_file=TRAIN_IMAGE_LIST,
                                     transform=transforms.Compose([
                                         transforms.Resize(224),
                                         transforms.CenterCrop(224),
                                         transforms.ToTensor(),
                                         normalize,
                                     ]))
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=4,
                              pin_memory=True)

    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=VAL_IMAGE_LIST,
                                    transform=transforms.Compose([
                                        transforms.Resize(256),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        normalize,
                                    ]))
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=128,
                             shuffle=False,
                             num_workers=4,
                             pin_memory=True)
    print("[Info]: Data has been loaded ...")

    print("[Info]: Loading Model ...")
    # initialize and load the model
    Global_Branch_model = Densenet121_AG(pretrained=False,
                                         num_classes=N_CLASSES).cuda()
    Local_Branch_model = Densenet121_AG(pretrained=False,
                                        num_classes=N_CLASSES).cuda()
    Fusion_Branch_model = Fusion_Branch(input_size=2048,
                                        output_size=N_CLASSES).cuda()

    if os.path.isfile(CKPT_PATH):
        print("[Info]: Loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        # to load state
        # Code modified from torchvision densenet source for loading from pre .4 densenet weights.
        state_dict = checkpoint['state_dict']
        remove_data_parallel = True  # Change if you don't want to use nn.DataParallel(model)

        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
        )
        for key in list(state_dict.keys()):
            ori_key = key
            key = key.replace('densenet121.', '')
            match = pattern.match(key)
            new_key = match.group(1) + match.group(2) if match else key
            # strip the 'module.' prefix that nn.DataParallel adds
            new_key = new_key[7:] if remove_data_parallel else new_key
            if '.0.' in new_key:
                new_key = new_key.replace('0.', '')
            state_dict[new_key] = state_dict[ori_key]
            # Delete old key only if modified.
            if match or remove_data_parallel:
                del state_dict[ori_key]

        Global_Branch_model.load_state_dict(state_dict)
        Local_Branch_model.load_state_dict(state_dict)
        print("[Info]: Loaded baseline checkpoint")

    else:
        print("[Info]: No previous checkpoint found ...")

    # per-branch checkpoints override the baseline weights when present
    if os.path.isfile(CKPT_PATH_G):
        checkpoint = torch.load(CKPT_PATH_G)
        Global_Branch_model.load_state_dict(checkpoint)
        print("[Info]: Loaded Global_Branch_model checkpoint")

    if os.path.isfile(CKPT_PATH_L):
        checkpoint = torch.load(CKPT_PATH_L)
        Local_Branch_model.load_state_dict(checkpoint)
        print("[Info]: Loaded Local_Branch_model checkpoint")

    if os.path.isfile(CKPT_PATH_F):
        checkpoint = torch.load(CKPT_PATH_F)
        Fusion_Branch_model.load_state_dict(checkpoint)
        print("[Info]: Loaded Fusion_Branch_model checkpoint")

    cudnn.benchmark = True
    criterion = nn.BCELoss()
    optimizer_global = optim.Adam(Global_Branch_model.parameters(),
                                  lr=LR_G,
                                  betas=(0.9, 0.999),
                                  eps=1e-08,
                                  weight_decay=1e-5)
    lr_scheduler_global = lr_scheduler.StepLR(optimizer_global,
                                              step_size=10,
                                              gamma=1)

    optimizer_local = optim.Adam(Local_Branch_model.parameters(),
                                 lr=LR_L,
                                 betas=(0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=1e-5)
    lr_scheduler_local = lr_scheduler.StepLR(optimizer_local,
                                             step_size=10,
                                             gamma=1)

    optimizer_fusion = optim.Adam(Fusion_Branch_model.parameters(),
                                  lr=LR_F,
                                  betas=(0.9, 0.999),
                                  eps=1e-08,
                                  weight_decay=1e-5)
    lr_scheduler_fusion = lr_scheduler.StepLR(optimizer_fusion,
                                              step_size=15,
                                              gamma=0.1)
    print("[Info]: Model has been loaded ...")

    print("[Info]: Starting training ...")
    for epoch in range(num_epochs):
        since = time.time()
        print('Epoch: {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # set models to training mode
        Global_Branch_model.train()
        Local_Branch_model.train()
        Fusion_Branch_model.train()

        running_loss = 0.0
        # Iterate over data (`images` renamed from `input`, which shadowed
        # the builtin)
        for i, (images, target) in enumerate(train_loader):
            input_var = images.cuda()
            target_var = target.cuda()
            optimizer_global.zero_grad()
            optimizer_local.zero_grad()
            optimizer_fusion.zero_grad()

            # compute output of all three branches
            output_global, fm_global, pool_global = Global_Branch_model(
                input_var)

            patchs_var = Attention_gen_patchs(images, fm_global)
            torch.cuda.empty_cache()
            output_local, _, pool_local = Local_Branch_model(patchs_var)
            output_fusion = Fusion_Branch_model(pool_global, pool_local)
            torch.cuda.empty_cache()

            # weighted joint loss over the three branches
            loss1 = criterion(output_global, target_var)
            loss2 = criterion(output_local, target_var)
            loss3 = criterion(output_fusion, target_var)

            loss = loss1 * 0.8 + loss2 * 0.1 + loss3 * 0.1

            if (i % 500) == 0:
                print(
                    'step: {} totalloss: {loss:.3f} loss1: {loss1:.3f} loss2: {loss2:.3f} loss3: {loss3:.3f}'
                    .format(i,
                            loss=loss,
                            loss1=loss1,
                            loss2=loss2,
                            loss3=loss3))

            loss.backward()
            optimizer_global.step()
            optimizer_local.step()
            optimizer_fusion.step()

            running_loss += loss.data.item()

        # BUGFIX: schedulers must step *after* the epoch's optimizer updates
        # (PyTorch >= 1.1 ordering); stepping before training skewed the
        # learning-rate schedule.
        lr_scheduler_global.step()
        lr_scheduler_local.step()
        lr_scheduler_fusion.step()

        # BUGFIX: average over the batch count; dividing by the last
        # enumerate index was off by one and crashed for 1-batch epochs.
        epoch_loss = float(running_loss) / float(len(train_loader))
        print(' Epoch over  Loss: {:.5f}'.format(epoch_loss))

        print("[Info]: Starting testing ...")
        test(Global_Branch_model, Local_Branch_model, Fusion_Branch_model,
             test_loader)

        # save all three branches every epoch
        if epoch % 1 == 0:
            save_path = save_model_path
            torch.save(
                Global_Branch_model.state_dict(), save_path + save_model_name +
                '_Global' + '_epoch_' + str(epoch) + '.pkl')
            print('Global_Branch_model already save!')
            torch.save(
                Local_Branch_model.state_dict(), save_path + save_model_name +
                '_Local' + '_epoch_' + str(epoch) + '.pkl')
            print('Local_Branch_model already save!')
            torch.save(
                Fusion_Branch_model.state_dict(), save_path + save_model_name +
                '_Fusion' + '_epoch_' + str(epoch) + '.pkl')
            print('Fusion_Branch_model already save!')

        time_elapsed = time.time() - since
        print('Training one epoch complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
Пример #29
0
def main():
    """Evaluate a trained DenseNet121 chest X-ray classifier on the test set.

    Loads model weights from ``args.checkpoint``, runs inference over the
    test split, prints per-class ROC-AUC, and saves per-class confusion
    matrices (decision threshold 0.6) and ROC curves as PDFs under
    ``args.test_outdir``.
    """
    N_CLASSES = 14
    CLASS_NAMES = ['Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
                   'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
                   'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
                   'Pleural_Thickening', 'Hernia']

    # Initialize the model and restore trained weights.
    device = utils.get_device()
    model = DenseNet121(N_CLASSES).to(device)
    checkpoint = torch.load(args.checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Test loader (no shuffling so predictions stay aligned with the list).
    test_dataset = ChestXrayDataSet(data_dir=args.path_to_images,
                                    image_list_file=args.test_list,
                                    transform=transforms_test)
    test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size,
                             shuffle=False, num_workers=args.num_workers,
                             pin_memory=True)

    # Accumulate ground truth and raw model outputs on the model's device.
    # (The original hard-coded .cuda() here, which crashed on CPU-only runs
    # even though the model itself was placed via utils.get_device().)
    gt = torch.FloatTensor().to(device)
    pred = torch.FloatTensor().to(device)

    model.eval()
    with torch.no_grad():
        for inp, target in test_loader:
            target = target.to(device)
            gt = torch.cat((gt, target), 0)
            bs, c, h, w = inp.size()
            # The view is a no-op for plain 4-D batches; kept so multi-crop
            # inputs (extra crop dims) fold into the batch dimension.
            output = model(inp.view(-1, c, h, w).to(device))
            pred = torch.cat((pred, output.view(bs, -1).data), 0)

    gt_np = gt.cpu().numpy()
    pred_np = sigmoid(pred.cpu().numpy())  # logits -> probabilities

    # Per-class label / prediction columns (vectorized; replaces the
    # original O(samples * classes) nested Python loops).
    Y_t = [gt_np[:, i] for i in range(N_CLASSES)]
    Y_pred = [pred_np[:, i] for i in range(N_CLASSES)]

    # Per-class ROC-AUC.  NOTE: the original bound the per-class score to
    # the name `auc`, shadowing sklearn's auc() function and making the
    # ROC-curve code below raise TypeError; use a distinct name.
    AUCs = [roc_auc_score(Y_t[i], Y_pred[i]) for i in range(N_CLASSES)]
    for cls_name, score in zip(CLASS_NAMES, AUCs):
        print('AUC {}: {:.4f}'.format(cls_name, score))

    # Per-class confusion matrices at a fixed decision threshold of 0.6.
    matrices = [confusion_matrix(Y_t[i], np.asarray(Y_pred[i]) > 0.6)
                for i in range(N_CLASSES)]

    class_names = ['no disease', 'disease']
    fig = plt.figure(figsize=(20, 20))
    for i in range(N_CLASSES):
        plt.subplot(4, 4, i + 1)
        df_cm = pd.DataFrame(matrices[i], index=class_names,
                             columns=class_names)
        sns.heatmap(df_cm, annot=True, fmt="d").set_title(CLASS_NAMES[i])
        plt.ylabel('True label')
        plt.xlabel('Predicted label')
    plt.show()
    fig.savefig(os.path.join(args.test_outdir, 'confusion_matrix.pdf'))

    # ROC curves on a 2x7 grid.  Row 0 shows classes 0..6 left-to-right;
    # row 1 shows classes 13..7 (the original's reverse ordering is kept
    # so figures remain comparable across runs).
    fig, axes2d = plt.subplots(nrows=2, ncols=7, sharex=True, sharey=True,
                               figsize=(12, 4))
    for i, row in enumerate(axes2d):
        for j, cell in enumerate(row):
            x = i + j if i == 0 else 13 - i * j
            fpr, tpr, _ = roc_curve(Y_t[x], Y_pred[x])
            roc_auc = auc(fpr, tpr)
            cell.plot(fpr, tpr, 'b', label='AUC = %0.2f' % roc_auc)
            cell.legend(loc='lower right', handlelength=0, handletextpad=0,
                        frameon=False, prop={'size': 8})
            cell.plot([0, 1], [0, 1], 'r--')
            plt.xlim([0, 1])
            plt.ylim([0, 1])
            cell.set_title(CLASS_NAMES[x], fontsize=10)
            if i == len(axes2d) - 1:
                cell.set_xlabel('False positive rate')
            if j == 0:
                # Fixed label: tpr is the TRUE POSITIVE rate (the original
                # incorrectly said "True negative rate").
                cell.set_ylabel('True positive rate')
    fig.tight_layout(pad=1.0)
    plt.show()
    fig.savefig(os.path.join(args.test_outdir, 'roc_auc.pdf'))