예제 #1
0
    steps = 0
    for batch_idx, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)
        
        model.train()
        opt.zero_grad()
        preds = model(images)

        loss = losses.dice_loss(preds, masks)
        
        loss.backward()
        #opt.step()
        
        train_losses.append(loss.item())
        train_dsc.append(losses.dice_score(preds, masks).item())


    else:        
        val_loss = 0
        val_acc = 0
        model.eval()
        with torch.no_grad():
            for inputs, masks in val_loader:
                inputs, masks = inputs.to(device), masks.to(device)
                preds = model.forward(inputs)
                loss = losses.dice_loss(preds, masks)

                val_losses.append(loss.item())
                val_dsc.append(losses.dice_score(preds,masks).item())
                scheduler.step(loss)
예제 #2
0
        test_dice_score = AverageMeter()

        with torch.no_grad():
            for i, data in enumerate(test_loader, 0):
                inputs, labels = data
                if torch.cuda.is_available():
                    inputs = inputs.cuda(non_blocking=True)
                    labels = labels.cuda(non_blocking=True)
                outputs = net(inputs)
                loss_test_tmp = loss_fct(outputs, labels)
                test_loss.append(loss_test_tmp.item())

                if torch.cuda.is_available():
                    res = np.round(
                        outputs[0, 1, :, :, :].cpu().numpy()).astype(int)
                    test_dice_score.append(
                        losses.dice_score(res,
                                          labels[0, 0, :, :, :].cpu().numpy()))
                else:
                    res = np.round(outputs[0, 1, :, :, :].numpy()).astype(int)
                    test_dice_score.append(
                        losses.dice_score(res, labels[0, 0, :, :, :].numpy()))

                if epoch == params.N_EPOCHS - 1:
                    np.save("./last_epoch_results/test_" + str(i) + ".npy",
                            res)

        print("epoch " + str(epoch + 1) + ": %.3f, %.3f, %.3f" %
              (train_loss.avrg, test_loss.avrg, test_dice_score.avrg))

print('Finished Training')
예제 #3
0
def _build_train_model():
    """Build the segmentation network selected by the module-level configuration.

    Reads the globals ``use_multiinput_architecture``, ``modeltype``,
    ``n_classes``, ``model_depth``, ``wf``, ``use_attention`` and ``device``.

    Returns:
        The model converted to float64 (``.double()``) and moved to ``device``.

    Raises:
        ValueError: if ``modeltype`` is not ``'unet'``/``'resunet'`` or
            ``use_multiinput_architecture`` is not a bool. (Previously such
            values left ``model`` unbound and failed later with
            ``UnboundLocalError``.)
    """
    if modeltype not in ('unet', 'resunet'):
        raise ValueError("modeltype must be 'unet' or 'resunet', got %r" %
                         (modeltype,))
    # The two model types differ only in whether residual blocks are used.
    residual = modeltype == 'resunet'
    if use_multiinput_architecture is False:
        model = UNet(n_classes=n_classes,
                     padding=True,
                     depth=model_depth,
                     wf=wf,
                     up_mode='upconv',
                     batch_norm=True,
                     residual=residual)
    elif use_multiinput_architecture is True:
        model = Attention_UNet(n_classes=n_classes,
                               padding=True,
                               up_mode='upconv',
                               batch_norm=True,
                               residual=residual,
                               wf=wf,
                               use_attention=use_attention)
    else:
        raise ValueError(
            'use_multiinput_architecture must be True or False, got %r' %
            (use_multiinput_architecture,))
    return model.double().to(device)


def _save_history_plot(x_values, y_values, label, ylabel, path):
    """Plot one metric history (red line, yellow markers), save it to *path*, then clear the figure."""
    plt.plot(x_values,
             y_values,
             label=label,
             color='red',
             marker='o',
             markerfacecolor='yellow',
             markersize=5)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.savefig(path)
    plt.clf()


def _pickle_to_file(obj, path):
    """Pickle *obj* into *path*, closing the file even if pickling fails.

    Note: callers use a ``.npy`` suffix, but the content is a pickle, not a
    NumPy array file.
    """
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def main():
    """Train the segmentation model, optionally resuming from a checkpoint.

    Configuration comes from module-level globals and ``args``.

    * If ``args.restart_training == 'true'``, the model starts from fresh
      weights. For any other value, weights are loaded from
      ``args.model_path``.
    * If ``load_old_lists == True`` and ``args.restart_training == 'false'``,
      the epoch counter and metric histories are also restored from that
      checkpoint.

    Each epoch:

    * trains on every batch of the ``'train'`` split;
    * writes training plots and pickled histories;
    * saves ``model_<epoch>.pth`` every ``save_every`` epochs;
    * every ``valid_every`` epochs, evaluates the ``'valid'`` split, writes
      validation plots/histories, and saves ``model_best.pth`` on the first
      validation without restored history or whenever the mean validation
      Dice score improves.
    """
    model = _build_train_model()

    checkpoint = None
    if args.restart_training != 'true':
        # Any value other than 'true' resumes from the saved weights.
        checkpoint = torch.load(args.model_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['model_state_dict'])

    train_loader = dataloader_cxr.DataLoader(data_path,
                                             dataloader_type='train',
                                             batchsize=batch_size,
                                             device=device,
                                             image_resolution=image_resolution)
    print('trainloader loaded')
    valid_loader = dataloader_cxr.DataLoader(data_path,
                                             dataloader_type='valid',
                                             batchsize=batch_size,
                                             device=device,
                                             image_resolution=image_resolution)
    print('validloader loaded')

    # Every history list starts with a None placeholder; plots and the best
    # score skip it with [1:].
    loss_list_train_epoch = [None]
    dice_score_list_train_epoch = [None]
    epoch_data_list = [None]

    loss_list_validation = [None]
    loss_list_validation_index = [None]
    dice_score_list_validation = [None]
    dice_score_list_validation_0 = [None]
    dice_score_list_validation_1 = [None]

    epoch_old = 0
    # Best mean validation Dice so far; None until the first validation (or a
    # restored validation history) provides one.
    best_model_accuracy = None
    # `== True` (rather than truthiness) is kept to preserve the original check.
    if load_old_lists == True and args.restart_training == 'false':  # noqa: E712
        epoch_old = checkpoint['epochs']

        # A history whose last entry is still the None placeholder was never
        # filled, so the fresh [None] lists above are kept.
        if checkpoint['train_loss_list_epoch'][-1] is not None:
            dice_score_list_train_epoch = checkpoint[
                'train_dice_score_list_epoch']
            loss_list_train_epoch = checkpoint['train_loss_list_epoch']
            epoch_data_list = checkpoint['train_loss_index_epoch']

        if checkpoint['valid_loss_list'][-1] is not None:
            loss_list_validation = checkpoint['valid_loss_list']
            loss_list_validation_index = checkpoint['valid_loss_index']
            dice_score_list_validation = checkpoint['valid_dice_score_list']
            dice_score_list_validation_0 = checkpoint[
                'valid_dice_score_list_0']
            dice_score_list_validation_1 = checkpoint[
                'valid_dice_score_list_1']
            best_model_accuracy = np.max(dice_score_list_validation[1:])

    # Ceiling division: a trailing partial batch counts as one more batch.
    total_idx_train = -(-len(train_loader.data_list) // batch_size)
    total_idx_valid = -(-len(valid_loader.data_list) // batch_size)

    # NOTE(review): a learning rate decayed for the resumed epoch
    # (LR * scheduler_gamma ** (epoch_old // scheduler_step_size)) used to be
    # computed here and then immediately overwritten with LR. Resumed runs
    # therefore restart at the base LR, with the scheduler counting from zero.
    # That behaviour is kept; confirm whether resumes should use a decayed rate.
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = StepLR(optimizer,
                       step_size=scheduler_step_size,
                       gamma=scheduler_gamma)

    multi_output = use_multiinput_architecture is True
    train_loss_fn = (losses.dice_loss_deep_supervised
                     if multi_output else losses.dice_loss)
    # Per-class loss weights: built once instead of once per batch.
    class_weights = torch.Tensor([gamma0, gamma1]).double().to(device)

    def final_prediction(output):
        """Return the output to score: the last element for the multi-input architecture."""
        return output[-1] if multi_output else output

    def checkpoint_state(epoch):
        """Bundle the epoch counter, current metric histories and weights for torch.save."""
        return {
            'epochs': epoch + 1,
            'batchsize': batch_size,
            'train_loss_list_epoch': loss_list_train_epoch,
            'train_dice_score_list_epoch': dice_score_list_train_epoch,
            'train_loss_index_epoch': epoch_data_list,
            'valid_loss_list': loss_list_validation,
            'valid_dice_score_list': dice_score_list_validation,
            'valid_dice_score_list_0': dice_score_list_validation_0,
            'valid_dice_score_list_1': dice_score_list_validation_1,
            'valid_loss_index': loss_list_validation_index,
            'model_state_dict': model.state_dict(),
        }

    for epoch in range(epoch_old, train_epoch):

        # The scheduler is stepped once every 10 epochs, so StepLR decays the
        # rate every 10 * scheduler_step_size epochs.
        if (epoch + 1) % 10 == 0:
            scheduler.step()

        epoch_loss = 0.0
        epoch_dice_score = 0.0
        train_count = 0

        model.train()
        for idx in range(total_idx_train):
            optimizer.zero_grad()

            batch_images_input, batch_label_input = train_loader[idx]
            output = model(batch_images_input)
            loss = train_loss_fn(output,
                                 batch_label_input,
                                 weights=class_weights)

            loss.backward()
            optimizer.step()

            score = losses.dice_score(final_prediction(output),
                                      batch_label_input)
            batch_n = batch_images_input.shape[0]
            # Per-batch means are weighted by batch size so the epoch average
            # is per image.
            epoch_dice_score += (score.sum().item() / score.size(0)) * batch_n
            epoch_loss += loss.item() * batch_n
            train_count += batch_n

        loss_list_train_epoch.append(epoch_loss / train_count)
        epoch_data_list.append(epoch + 1)
        dice_score_list_train_epoch.append(epoch_dice_score / train_count)

        print(
            'Epoch %d Training Loss: %.3f Dice Score: %.3f' %
            (epoch + 1, loss_list_train_epoch[-1],
             dice_score_list_train_epoch[-1]), ' Time:',
            datetime.datetime.now())

        _save_history_plot(epoch_data_list[1:], loss_list_train_epoch[1:],
                           "Training", 'Training Loss',
                           plots_dir + '/train_loss_plot.png')
        _save_history_plot(epoch_data_list[1:],
                           dice_score_list_train_epoch[1:], "Training",
                           'Training Dice Score',
                           plots_dir + '/train_dice_score_plot.png')

        _pickle_to_file(loss_list_train_epoch,
                        plots_pickle_dir + "/loss_list_train.npy")
        _pickle_to_file(epoch_data_list,
                        plots_pickle_dir + "/epoch_list_train.npy")
        _pickle_to_file(dice_score_list_train_epoch,
                        plots_pickle_dir + "/dice_score_list_train_epoch.npy")

        if (epoch + 1) % save_every == 0:
            print('Saving model at %d epoch' % (epoch + 1), ' Time:',
                  datetime.datetime.now())
            torch.save(checkpoint_state(epoch),
                       model_checkpoint_dir + '/model_%d.pth' % (epoch + 1))

        if (epoch + 1) % valid_every == 0:
            model.eval()
            optimizer.zero_grad()

            valid_count = 0
            total_loss_valid = 0.0
            valid_dice_score = 0.0
            valid_dice_score_0 = 0.0
            valid_dice_score_1 = 0.0

            with torch.no_grad():
                for idx in range(total_idx_valid):
                    batch_images_input, batch_label_input = valid_loader[idx]
                    prediction = final_prediction(model(batch_images_input))

                    # Validation loss is the unweighted Dice loss on the final output.
                    loss = losses.dice_loss(prediction, batch_label_input)
                    score = losses.dice_score(prediction, batch_label_input)

                    batch_n = batch_images_input.shape[0]
                    total_loss_valid += loss.item() * batch_n
                    valid_count += batch_n
                    valid_dice_score += (score.sum().item() /
                                         score.size(0)) * batch_n
                    valid_dice_score_0 += score[0].item() * batch_n
                    valid_dice_score_1 += score[1].item() * batch_n

            loss_list_validation.append(total_loss_valid / valid_count)
            dice_score_list_validation.append(valid_dice_score / valid_count)
            dice_score_list_validation_0.append(valid_dice_score_0 /
                                                valid_count)
            dice_score_list_validation_1.append(valid_dice_score_1 /
                                                valid_count)
            loss_list_validation_index.append(epoch + 1)

            print(
                'Epoch %d Valid Loss: %.3f' %
                (epoch + 1, loss_list_validation[-1]), ' Time:',
                datetime.datetime.now())

            print('Valid Dice Score: ', dice_score_list_validation[-1],
                  ' Valid Dice Score 0: ', dice_score_list_validation_0[-1],
                  ' Valid Dice Score 1: ', dice_score_list_validation_1[-1])

            for history, ylabel, suffix in (
                (loss_list_validation, 'Validation Loss',
                 '/valid_loss_plot.png'),
                (dice_score_list_validation, 'Validation Dice Score',
                 '/valid_dice_score_plot.png'),
                (dice_score_list_validation_0, 'Validation Dice Score',
                 '/valid_dice_score_0_plot.png'),
                (dice_score_list_validation_1, 'Validation Dice Score',
                 '/valid_dice_score_1_plot.png'),
            ):
                _save_history_plot(loss_list_validation_index[1:],
                                   history[1:], "Validation", ylabel,
                                   plots_dir + suffix)

            _pickle_to_file(loss_list_validation,
                            plots_pickle_dir + "/loss_list_validation.npy")
            _pickle_to_file(loss_list_validation_index,
                            plots_pickle_dir + "/index_list_validation.npy")
            _pickle_to_file(
                dice_score_list_validation,
                plots_pickle_dir + "/dice_score_list_validation.npy")

            latest_dice = dice_score_list_validation[-1]
            # With no best score yet (first validation, nothing restored), the
            # model is always saved; afterwards only on improvement.
            if best_model_accuracy is None or latest_dice > best_model_accuracy:
                best_model_accuracy = latest_dice
                torch.save(checkpoint_state(epoch),
                           model_checkpoint_dir + '/model_best.pth')
예제 #4
0
# Load the saved per-volume test predictions and the matching ground truth,
# then report Dice scores.
test_dataset = dataset_generator.getTestDataset(0, 1)
test_size = len(test_dataset)
# One array slot per test sample, followed by the per-sample image shape.
# The [i, :, :, :] indexing below assumes basic_image_size has 3 entries.
volume_shape = (test_size, *basic_image_size)

test_results = np.zeros(volume_shape, dtype=int)
for i in range(test_size):
    test_results[i, :, :, :] = np.load(
        os.path.join(res_dir, f"test_{i}.npy"))

images = np.zeros(volume_shape, dtype=int)
GT_labels = np.zeros(volume_shape, dtype=int)
for i in range(test_size):
    # Fetch each sample once. Indexing twice doubled the I/O and, if the
    # dataset's __getitem__ is stochastic, could pair an image with a label
    # from a different draw.
    sample = test_dataset[i]
    # [0] drops the leading (presumably channel) axis of each tensor.
    images[i, :, :, :] = sample[0].numpy()[0]
    GT_labels[i, :, :, :] = sample[1].numpy()[0]

dice_scores = np.array([
    losses.dice_score(test_results[i], GT_labels[i]) for i in range(test_size)
])

# Test sample and slice indices selected for visual inspection.
img_id = 9
slice_id_list = [7, 10, 13]

print("AVERAGE DICE:", np.mean(dice_scores))
print()
print("Original image index: ", test_dataset.indices[img_id])
print("dice = ", dice_scores[img_id])
comb = []
for slice_id in slice_id_list:
    margin_size = 5

    color_scale_GT = [0.3, 0.0, 0.0]
    image_GT = np.zeros((basic_image_size[1], basic_image_size[2], 3))
예제 #5
0
def main():
    """Evaluate a trained checkpoint on one data split and optionally save masks.

    Builds the network selected by the module-level configuration, loads the
    weights stored in ``args.model_path``, and runs the model over every batch
    of the ``dataloader_type`` split. Each image's class-argmax mask is passed
    through ``remove_small_regions`` with a size threshold of 2% of the image
    area and stored in ``prediction_array``.

    For splits other than ``"test"`` (the only ones whose loader yields
    labels), the function prints:

    * the batch-size-weighted mean Dice score;
    * the means of entries 0 and 1 of the score vector.

    When ``generate_mask is True``, each prediction is resized back to the
    image's original size with nearest-neighbour interpolation (``order=0``)
    and written to ``save_path + '/Pred_mask_<name>.png'``.

    Raises:
        ValueError: if ``modeltype`` is not ``'unet'``/``'resunet'`` or
            ``use_multiinput_architecture`` is not a bool (previously
            ``model`` was left unbound).
    """
    import os  # for os.path.splitext when naming output masks

    if modeltype not in ('unet', 'resunet'):
        raise ValueError("modeltype must be 'unet' or 'resunet', got %r" %
                         (modeltype,))
    # The two model types differ only in whether residual blocks are used.
    residual = modeltype == 'resunet'
    if use_multiinput_architecture is False:
        model = UNet(n_classes=n_classes,
                     padding=True,
                     depth=model_depth,
                     up_mode='upconv',
                     batch_norm=True,
                     residual=residual)
    elif use_multiinput_architecture is True:
        model = Attention_UNet(n_classes=n_classes,
                               padding=True,
                               up_mode='upconv',
                               batch_norm=True,
                               residual=residual,
                               wf=wf,
                               use_attention=use_attention)
    else:
        raise ValueError(
            'use_multiinput_architecture must be True or False, got %r' %
            (use_multiinput_architecture,))
    model = model.double().to(device)

    checkpoint = torch.load(args.model_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['model_state_dict'])

    valid_loader = dataloader_cxr.DataLoader(data_path,
                                             dataloader_type=dataloader_type,
                                             batchsize=batch_size,
                                             device=device,
                                             image_resolution=image_resolution)

    # The "test" split yields images only; all other splits also yield labels.
    has_labels = valid_loader.dataloader_type != "test"
    n_images = len(valid_loader.data_list)
    # Ceiling division: a trailing partial batch counts as one more batch.
    total_idx_valid = -(-n_images // batch_size)
    # Size threshold handed to remove_small_regions: 2% of the image area.
    min_region_size = 0.02 * np.prod(image_resolution)

    model.eval()
    prediction_array = np.zeros(
        (n_images, image_resolution[0], image_resolution[1]))

    valid_count = 0
    valid_dice_score = 0.0
    valid_dice_score_0 = 0.0
    valid_dice_score_1 = 0.0

    with torch.no_grad():
        for idx in range(total_idx_valid):
            if has_labels:
                batch_images_input, batch_label_input = valid_loader[idx]
            else:
                batch_images_input = valid_loader[idx]

            output = model(batch_images_input)
            # For the multi-input architecture the last element of the output
            # is the one that is scored.
            final_output = (output[-1]
                            if use_multiinput_architecture is True else output)

            # Argmax over dim 1 (assumed to be the class axis), cast to a
            # boolean mask. `np.bool` was removed in NumPy 1.24; builtin
            # `bool` is the same dtype.
            predicted = torch.max(final_output,
                                  1)[1].cpu().numpy().astype(bool)
            # NumPy slicing clamps at the array end, so the last (possibly
            # partial) batch needs no special case.
            prediction_array[idx * batch_size:(idx + 1) *
                             batch_size] = remove_small_regions(
                                 predicted, min_region_size)

            if has_labels:
                score = losses.dice_score(final_output, batch_label_input)
                batch_n = batch_images_input.shape[0]
                valid_count += batch_n
                valid_dice_score += (score.sum().item() /
                                     score.size(0)) * batch_n
                valid_dice_score_0 += score[0].item() * batch_n
                valid_dice_score_1 += score[1].item() * batch_n

    if has_labels:
        valid_dice_score = valid_dice_score / valid_count
        valid_dice_score_0 = valid_dice_score_0 / valid_count
        valid_dice_score_1 = valid_dice_score_1 / valid_count

    if generate_mask is True:
        for i, file_name in enumerate(valid_loader.data_list):
            mask = prediction_array[i].astype(int)
            zoom_factors = (np.asarray(
                valid_loader.original_size_array[file_name]) /
                            np.asarray(mask.shape))
            mask = ndimage.zoom(mask, zoom_factors, order=0)
            # splitext strips only the final extension. split('.')[0]
            # truncated multi-dot names, so distinct inputs could overwrite
            # each other's masks.
            stem = os.path.splitext(file_name)[0]
            io.imsave(save_path + '/Pred_mask_' + stem + '.png', mask)

    if has_labels:
        print('Valid Dice Score: ', valid_dice_score, ' Valid Dice Score 0: ',
              valid_dice_score_0, ' Valid Dice Score 1: ', valid_dice_score_1)
예제 #6
0
def _build_eval_model():
    """Instantiate the network selected by the module-level configuration.

    Reads the module-level settings ``use_multiinput_architecture``,
    ``modeltype``, ``n_classes``, ``model_depth``, ``wf``, ``use_attention``
    and ``device``.

    Returns:
        The model, cast to double precision and moved to ``device``.

    Raises:
        ValueError: if ``modeltype`` is neither ``'unet'`` nor ``'resunet'``
            (previously ``model`` was left unbound and the code crashed later).
    """
    if modeltype not in ('unet', 'resunet'):
        raise ValueError(
            "Unsupported modeltype %r; expected 'unet' or 'resunet'" %
            (modeltype,))
    # 'resunet' is the same architecture with residual connections enabled.
    residual = modeltype == 'resunet'

    if use_multiinput_architecture:
        model = Attention_UNet(n_classes=n_classes,
                               padding=True,
                               up_mode='upconv',
                               batch_norm=True,
                               residual=residual,
                               wf=wf,
                               use_attention=use_attention)
    else:
        model = UNet(n_classes=n_classes,
                     padding=True,
                     depth=model_depth,
                     up_mode='upsample',
                     batch_norm=True,
                     residual=residual)
    return model.double().to(device)


def _select_eval_output(output):
    """Return the model output used for prediction and scoring.

    With the multi-input architecture the model returns a sequence of
    outputs, and only the last one is evaluated.
    """
    return output[-1] if use_multiinput_architecture else output


def _to_clean_mask(output):
    """Turn per-class scores into a cleaned boolean mask batch.

    Takes the arg-max over dim 1, so any non-zero class index becomes True.
    Then drops small connected regions via ``remove_small_regions``, using a
    size threshold of 2% of the image area.
    """
    # Builtin ``bool`` instead of ``np.bool``: that alias was removed in
    # NumPy 1.24.
    labels = torch.max(output, 1)[1].detach().cpu().numpy().astype(bool)
    return remove_small_regions(labels, 0.02 * np.prod(image_resolution))


def main():
    """Evaluate a trained checkpoint and optionally export predicted masks.

    Everything is driven by module-level configuration (``args.model_path``,
    ``data_path``, ``dataloader_type``, ``batch_size``, ``image_resolution``,
    ``classes``, ``clubbed``, ``exclude_0``, ``generate_mask``, ``save_path``,
    ...).

    Side effects:
        - Prints the averaged Dice scores when the loader provides labels
          (``dataloader_type != "test"``).
        - When ``generate_mask`` is True, writes one
          ``<save_path>/Pred_mask_<name>.nii.gz`` file per input.
    """
    model = _build_eval_model()

    checkpoint = torch.load(args.model_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['model_state_dict'])

    valid_loader = dataloader_cxr.DataLoader(data_path,
                                             dataloader_type=dataloader_type,
                                             batchsize=batch_size,
                                             device=device,
                                             image_resolution=image_resolution,
                                             invert=invert,
                                             remove_wires=remove_wires)
    has_labels = valid_loader.dataloader_type != "test"

    num_items = len(valid_loader.data_list)
    # Ceiling division: a trailing partial batch still gets its own index.
    total_idx_valid = (num_items + batch_size - 1) // batch_size

    model.eval()
    prediction_array = np.zeros(
        (num_items, image_resolution[0], image_resolution[1]))

    # Per-class scores are only meaningful for a plain binary (0/1) setup;
    # this single flag drives both accumulation and reporting.
    track_per_class = (0 in classes and 1 in classes and len(classes) == 2
                       and len(clubbed) == 0)

    valid_count = 0
    valid_dice_score = 0.0
    valid_dice_score_0 = 0.0
    valid_dice_score_1 = 0.0

    with torch.no_grad():
        for idx in range(total_idx_valid):
            if has_labels:
                batch_images_input, batch_label_input = valid_loader[idx]
            else:
                batch_images_input = valid_loader[idx]

            output = _select_eval_output(model(batch_images_input))

            # NumPy clips slice bounds, so this also covers a short last batch.
            batch_slice = slice(idx * batch_size, (idx + 1) * batch_size)
            prediction_array[batch_slice] = _to_clean_mask(output)

            if not has_labels:
                continue

            n_batch = batch_images_input.shape[0]
            valid_count += n_batch

            score = losses.dice_score(output, batch_label_input, exclude_0)
            # Weight each batch's mean score by its size so the final
            # average is per-sample.
            valid_dice_score += (score.sum().item() / score.size(0)) * n_batch
            if track_per_class:
                valid_dice_score_0 += score[0].item() * n_batch
                valid_dice_score_1 += score[1].item() * n_batch

    # Guard against an empty data list (previously a ZeroDivisionError).
    if has_labels and valid_count > 0:
        valid_dice_score /= valid_count
        valid_dice_score_0 /= valid_count
        valid_dice_score_1 /= valid_count

    if generate_mask is True:
        for i, files in enumerate(valid_loader.data_list):
            temp_mask = prediction_array[i].astype(int)
            # Nearest-neighbour (order=0) zoom back to the original size keeps
            # label values intact.
            temp_mask = ndimage.zoom(
                temp_mask,
                np.asarray(valid_loader.original_size_array[files]) /
                np.asarray(temp_mask.shape),
                order=0)
            temp_mask = sitk.Cast(sitk.GetImageFromArray(temp_mask),
                                  sitk.sitkUInt8)

            resampler = sitk.ResampleImageFilter()
            resampler.SetReferenceImage(temp_mask)
            resampler.SetOutputSpacing(valid_loader.spacing[files])
            resampler.SetSize(valid_loader.size_[files])
            resampler.SetInterpolator(sitk.sitkNearestNeighbor)
            temp_mask = resampler.Execute(temp_mask)
            temp_mask.SetOrigin(valid_loader.origin[files])

            # NOTE(review): strips the last four characters, presumably a
            # ".nii" suffix -- TODO confirm for other extensions (".nii.gz").
            temp_name = files[:-4]
            sitk.WriteImage(temp_mask,
                            save_path + '/Pred_mask_' + temp_name + '.nii.gz')

    if has_labels:
        if track_per_class:
            print('Valid Dice Score: ', valid_dice_score,
                  ' Valid Dice Score 0: ', valid_dice_score_0,
                  ' Valid Dice Score 1: ', valid_dice_score_1)
        else:
            print('Valid Dice Score: ', valid_dice_score)