Example #1
def main():
    opt = Options(isTrain=False)
    opt.parse()
    opt.save_options()
    opt.print_options()

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.test['gpu'])

    img_dir = opt.test['img_dir']
    label_dir = opt.test['label_dir']
    save_dir = opt.test['save_dir']
    model_path = opt.test['model_path']
    save_flag = opt.test['save_flag']
    tta = opt.test['tta']

    # compute accuracy metrics only if a ground-truth label directory is given
    eval_flag = bool(label_dir)

    # data transforms
    test_transform = get_transforms(opt.transform['test'])

    # load model
    model = FullNet(opt.model['in_c'],
                    opt.model['out_c'],
                    n_layers=opt.model['n_layers'],
                    growth_rate=opt.model['growth_rate'],
                    drop_rate=opt.model['drop_rate'],
                    dilations=opt.model['dilations'],
                    is_hybrid=opt.model['is_hybrid'],
                    compress_ratio=opt.model['compress_ratio'],
                    layer_type=opt.model['layer_type'])
    model = torch.nn.DataParallel(model).cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    print("=> loading trained model")
    best_checkpoint = torch.load(model_path)
    model.load_state_dict(best_checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(best_checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    print("=> Test begins:")

    img_names = os.listdir(img_dir)

    # pixel_accu, recall, precision, F1, dice, iou, haus, (AJI)
    num_metrics = 8 if opt.dataset == 'MultiOrgan' else 7
    avg_results = utils.AverageMeter(num_metrics)
    all_results = dict()

    if save_flag:
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
        if not os.path.exists(prob_maps_folder):
            os.mkdir(prob_maps_folder)
        if not os.path.exists(seg_folder):
            os.mkdir(seg_folder)

    for img_name in img_names:
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        if eval_flag:
            if opt.dataset == 'MultiOrgan':
                label_path = '{:s}/{:s}.png'.format(label_dir, name)
            else:
                label_path = '{:s}/{:s}_anno.bmp'.format(label_dir, name)
            label_img = misc.imread(label_path)

        input = test_transform((img, ))[0].unsqueeze(0)

        print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input, model, opt)
        if tta:
            img_hf = img.transpose(Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_vf = img.transpose(Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_hvf = img_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_hf = test_transform(
                (img_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_vf = test_transform(
                (img_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_hvf = test_transform((img_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_hf = get_probmaps(input_hf, model, opt)
            prob_maps_vf = get_probmaps(input_vf, model, opt)
            prob_maps_hvf = get_probmaps(input_hvf, model, opt)

            # flip back to the original orientation
            prob_maps_hf = np.flip(prob_maps_hf, 2)
            prob_maps_vf = np.flip(prob_maps_vf, 1)
            prob_maps_hvf = np.flip(np.flip(prob_maps_hvf, 1), 2)

            # 90-degree rotation and flips
            img_r90 = img.rotate(90, expand=True)
            img_r90_hf = img_r90.transpose(
                Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_r90_vf = img_r90.transpose(
                Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_r90_hvf = img_r90_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_r90 = test_transform((img_r90, ))[0].unsqueeze(0)
            input_r90_hf = test_transform(
                (img_r90_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_r90_vf = test_transform(
                (img_r90_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_r90_hvf = test_transform((img_r90_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_r90 = get_probmaps(input_r90, model, opt)
            prob_maps_r90_hf = get_probmaps(input_r90_hf, model, opt)
            prob_maps_r90_vf = get_probmaps(input_r90_vf, model, opt)
            prob_maps_r90_hvf = get_probmaps(input_r90_hvf, model, opt)

            # flip back and undo the 90-degree rotation
            prob_maps_r90 = np.rot90(prob_maps_r90, k=3, axes=(1, 2))
            prob_maps_r90_hf = np.rot90(np.flip(prob_maps_r90_hf, 2),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_vf = np.rot90(np.flip(prob_maps_r90_vf, 1),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_hvf = np.rot90(np.flip(np.flip(prob_maps_r90_hvf, 1),
                                                 2),
                                         k=3,
                                         axes=(1, 2))

            prob_maps = (prob_maps + prob_maps_hf + prob_maps_vf +
                         prob_maps_hvf + prob_maps_r90 + prob_maps_r90_hf +
                         prob_maps_r90_vf + prob_maps_r90_hvf) / 8

        pred = np.argmax(prob_maps, axis=0)  # prediction
        pred_inside = pred == 1
        pred2 = morph.remove_small_objects(
            pred_inside, opt.post['min_area'])  # remove small objects

        if 'scale' in opt.transform['test']:
            pred2 = misc.imresize(pred2.astype(np.uint8) * 255, (ori_h, ori_w),
                                  interp='bilinear')
            pred2 = (pred2 > 127.5)

        pred_labeled = measure.label(pred2)  # connected component labeling
        pred_labeled = morph.dilation(pred_labeled,
                                      selem=morph.selem.disk(
                                          opt.post['radius']))

        if eval_flag:
            print('\tComputing metrics...')
            result = utils.accuracy_pixel_level(
                np.expand_dims(pred_labeled > 0, 0),
                np.expand_dims(label_img > 0, 0))
            pixel_accu = result[0]

            if opt.dataset == 'MultiOrgan':
                result_object = utils.nuclei_accuracy_object_level(
                    pred_labeled, label_img)
            else:
                result_object = utils.gland_accuracy_object_level(
                    pred_labeled, label_img)

            all_results[name] = tuple([pixel_accu, *result_object])

            # update values
            avg_results.update([pixel_accu, *result_object])

        # save image
        if save_flag:
            print('\tSaving image results...')
            misc.imsave(
                '{:s}/{:s}_prob_inside.png'.format(prob_maps_folder, name),
                prob_maps[1, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_contour.png'.format(prob_maps_folder, name),
                prob_maps[2, :, :])
            final_pred = Image.fromarray(pred_labeled.astype(np.uint16))
            final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))

            # save colored objects
            pred_colored = np.zeros((ori_h, ori_w, 3))
            for k in range(1, pred_labeled.max() + 1):
                pred_colored[pred_labeled == k, :] = np.array(
                    utils.get_random_color())
            filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
            misc.imsave(filename, pred_colored)

        counter += 1
        if counter % 10 == 0:
            print('\tProcessed {:d} images'.format(counter))

    print('=> Processed all {:d} images'.format(counter))
    if eval_flag:
        print('Average of all images:\n'
              'pixel_accu: {r[0]:.4f}\n'
              'recall: {r[1]:.4f}\n'
              'precision: {r[2]:.4f}\n'
              'F1: {r[3]:.4f}\n'
              'dice: {r[4]:.4f}\n'
              'iou: {r[5]:.4f}\n'
              'haus: {r[6]:.4f}'.format(r=avg_results.avg))
        if opt.dataset == 'MultiOrgan':
            print('AJI: {r[7]:.4f}'.format(r=avg_results.avg))

        strs = img_dir.split('/')
        header = [
            'pixel_acc', 'recall', 'precision', 'F1', 'Dice', 'IoU',
            'Hausdorff'
        ]
        if opt.dataset == 'MultiOrgan':
            header.append('AJI')
        save_results(header, avg_results.avg, all_results,
                     '{:s}/{:s}_test_result.txt'.format(save_dir, strs[-1]))
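
Every test example on this page calls a get_probmaps(input, model, opt) helper that is not shown here. Judging from how its result is used (a NumPy array of shape (num_classes, H, W) that is flipped, rotated and arg-maxed), a minimal sketch could look like the one below. The actual helper most likely performs sliding-window inference on large images according to the test options, and the uncertainty-selection example further down even returns an extra log-variance map, so treat this only as an illustration of the assumed interface.

import torch
import torch.nn.functional as F


def get_probmaps(input, model, opt):
    # single full-image forward pass; a real implementation may tile the image
    with torch.no_grad():
        output = model(input.cuda())         # (1, num_classes, H, W)
        prob = F.softmax(output, dim=1)      # per-pixel class probabilities
    return prob.squeeze(0).cpu().numpy()     # (num_classes, H, W) NumPy array
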
Example #2
def main():
    opt = Options(isTrain=False)
    opt.parse()
    opt.save_options()

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.test['gpus'])

    img_dir = opt.test['img_dir']
    label_dir = opt.test['label_dir']
    save_dir = opt.test['save_dir']
    model_path = opt.test['model_path']
    save_flag = opt.test['save_flag']

    # data transforms
    test_transform = get_transforms(opt.transform['test'])

    model = ResUNet34(pretrained=opt.model['pretrained'])
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    print("=> loading trained model")
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    print("=> Test begins:")

    img_names = os.listdir(img_dir)

    if save_flag:
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
        if not os.path.exists(prob_maps_folder):
            os.mkdir(prob_maps_folder)
        if not os.path.exists(seg_folder):
            os.mkdir(seg_folder)

    metric_names = ['acc', 'p_F1', 'p_recall', 'p_precision', 'dice', 'aji']
    test_results = dict()
    all_result = utils.AverageMeter(len(metric_names))

    for img_name in img_names:
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
        gt = misc.imread(label_path)

        input = test_transform((img, ))[0].unsqueeze(0)

        print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input, model, opt)
        pred = np.argmax(prob_maps, axis=0)  # prediction

        pred_labeled = measure.label(pred)
        pred_labeled = morph.remove_small_objects(pred_labeled,
                                                  opt.post['min_area'])
        pred_labeled = ndi_morph.binary_fill_holes(pred_labeled > 0)
        pred_labeled = measure.label(pred_labeled)

        print('\tComputing metrics...')
        metrics = compute_metrics(pred_labeled, gt, metric_names)

        # save result for each image
        test_results[name] = [
            metrics['acc'], metrics['p_F1'], metrics['p_recall'],
            metrics['p_precision'], metrics['dice'], metrics['aji']
        ]

        # update the average result
        all_result.update([
            metrics['acc'], metrics['p_F1'], metrics['p_recall'],
            metrics['p_precision'], metrics['dice'], metrics['aji']
        ])

        # save image
        if save_flag:
            print('\tSaving image results...')
            misc.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name),
                        pred.astype(np.uint8) * 255)
            misc.imsave('{:s}/{:s}_prob.png'.format(prob_maps_folder, name),
                        prob_maps[1, :, :])
            final_pred = Image.fromarray(pred_labeled.astype(np.uint16))
            final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))

            # save colored objects
            pred_colored_instance = np.zeros((ori_h, ori_w, 3))
            for k in range(1, pred_labeled.max() + 1):
                pred_colored_instance[pred_labeled == k, :] = np.array(
                    utils.get_random_color())
            filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
            misc.imsave(filename, pred_colored_instance)

        counter += 1
        if counter % 10 == 0:
            print('\tProcessed {:d} images'.format(counter))

    print('=> Processed all {:d} images'.format(counter))
    print('Average Acc: {r[0]:.4f}\nF1: {r[1]:.4f}\nRecall: {r[2]:.4f}\n'
          'Precision: {r[3]:.4f}\nDice: {r[4]:.4f}\nAJI: {r[5]:.4f}\n'.format(
              r=all_result.avg))

    header = metric_names
    utils.save_results(header, all_result.avg, test_results,
                       '{:s}/test_results.txt'.format(save_dir))
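
Both test scripts accumulate per-image metrics with utils.AverageMeter, whose implementation is not included on this page. Based on the calls AverageMeter(n), update([...]) and the .avg attribute, a minimal sketch (an assumption, not the original class) might be:

import numpy as np


class AverageMeter(object):
    """Running average of several metrics at once (assumed interface)."""

    def __init__(self, num_metrics=1):
        self.sum = np.zeros(num_metrics)
        self.avg = np.zeros(num_metrics)
        self.count = 0

    def update(self, values, n=1):
        self.sum += np.asarray(values, dtype=float) * n
        self.count += n
        self.avg = self.sum / self.count
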
Example #3
def main(opt):
    global best_score, logger, logger_results
    best_score = 0
    opt.save_options()

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.train['gpus'])

    # set up logger
    logger, logger_results = setup_logging(opt)
    opt.print_options(logger)

    if opt.train['random_seed'] >= 0:
        # logger.info("=> Using random seed {:d}".format(opt.train['random_seed']))
        torch.manual_seed(opt.train['random_seed'])
        torch.cuda.manual_seed(opt.train['random_seed'])
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(opt.train['random_seed'])
        random.seed(opt.train['random_seed'])
    else:
        torch.backends.cudnn.benchmark = True

    # ----- create model ----- #
    model = ResUNet34(pretrained=opt.model['pretrained'],
                      with_uncertainty=opt.with_uncertainty)
    # model = nn.DataParallel(model)
    model = model.cuda()

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(),
                                 opt.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=opt.train['weight_decay'])

    # ----- define criterion ----- #
    criterion = torch.nn.NLLLoss(ignore_index=2).cuda()

    # ----- load data ----- #
    data_transforms = {
        'train': get_transforms(opt.transform['train']),
        'val': get_transforms(opt.transform['val'])
    }

    img_dir = '{:s}/train'.format(opt.train['img_dir'])
    target_vor_dir = '{:s}/train'.format(opt.train['label_vor_dir'])
    target_cluster_dir = '{:s}/train'.format(opt.train['label_cluster_dir'])
    dir_list = [img_dir, target_vor_dir, target_cluster_dir]
    post_fix = ['label_vor.png', 'label_cluster.png']
    num_channels = [3, 3, 3]
    train_set = DataFolder(dir_list, post_fix, num_channels,
                           data_transforms['train'])
    train_loader = DataLoader(train_set,
                              batch_size=opt.train['batch_size'],
                              shuffle=True,
                              num_workers=opt.train['workers'])

    # ----- optionally load from a checkpoint for validation or resuming training ----- #
    if opt.train['checkpoint']:
        if os.path.isfile(opt.train['checkpoint']):
            logger.info("=> loading checkpoint '{}'".format(
                opt.train['checkpoint']))
            checkpoint = torch.load(opt.train['checkpoint'])
            opt.train['start_epoch'] = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                opt.train['checkpoint'], checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(
                opt.train['checkpoint']))

    # ----- training and validation ----- #
    num_epochs = opt.train['num_epochs']

    for epoch in range(opt.train['start_epoch'], num_epochs):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1, num_epochs))
        train_loss, train_loss_vor, train_loss_cluster = train(
            opt, train_loader, model, optimizer, criterion)

        # evaluate on val set
        with torch.no_grad():
            val_acc, val_aji = validate(opt, model, data_transforms['val'])

        # check if it is the best accuracy
        is_best = val_aji > best_score
        best_score = max(val_aji, best_score)

        cp_flag = (epoch + 1) % opt.train['checkpoint_freq'] == 0
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, epoch, opt.train['save_dir'], is_best, cp_flag)

        # save the training results to txt files
        logger_results.info(
            '{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(
                epoch + 1, train_loss, train_loss_vor, train_loss_cluster,
                val_acc, val_aji))

    for i in list(logger.handlers):
        logger.removeHandler(i)
        i.flush()
        i.close()
    for i in list(logger_results.handlers):
        logger_results.removeHandler(i)
        i.flush()
        i.close()
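
The training examples rely on a save_checkpoint helper that is likewise not shown. A hedged sketch matching the call in this example (state, epoch, save_dir, is_best, cp_flag) is given below; the argument order differs slightly in the other training examples on this page, and the real helper may use different file names.

import os
import shutil

import torch


def save_checkpoint(state, epoch, save_dir, is_best, cp_flag):
    cp_dir = '{:s}/checkpoints'.format(save_dir)
    os.makedirs(cp_dir, exist_ok=True)
    filename = '{:s}/checkpoint.pth.tar'.format(cp_dir)
    torch.save(state, filename)    # always keep the latest checkpoint
    if cp_flag:                    # periodic snapshot every checkpoint_freq epochs
        shutil.copyfile(filename,
                        '{:s}/checkpoint_{:d}.pth.tar'.format(cp_dir, epoch + 1))
    if is_best:                    # best validation score so far
        shutil.copyfile(filename, '{:s}/checkpoint_best.pth.tar'.format(cp_dir))
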
Example #4
def main():
    global opt, best_iou, num_iter, tb_writer, logger, logger_results
    best_iou = 0
    opt = Options(isTrain=True)
    opt.parse()
    opt.save_options()

    tb_writer = SummaryWriter('{:s}/tb_logs'.format(opt.train['save_dir']))
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(x) for x in opt.train['gpu'])

    # set up logger
    logger, logger_results = setup_logging(opt)
    opt.print_options(logger)

    # ----- create model ----- #
    model = FullNet(opt.model['in_c'], opt.model['out_c'], n_layers=opt.model['n_layers'],
                    growth_rate=opt.model['growth_rate'], drop_rate=opt.model['drop_rate'],
                    dilations=opt.model['dilations'], is_hybrid=opt.model['is_hybrid'],
                    compress_ratio=opt.model['compress_ratio'], layer_type=opt.model['layer_type'])
    model = nn.DataParallel(model)
    model = model.cuda()
    torch.backends.cudnn.benchmark = True

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(), opt.train['lr'], betas=(0.9, 0.99),
                                 weight_decay=opt.train['weight_decay'])

    # ----- define criterion ----- #
    criterion = torch.nn.NLLLoss(reduction='none').cuda()

    if opt.train['alpha'] > 0:
        logger.info('=> Using variance term in loss...')
        global criterion_var
        criterion_var = LossVariance()

    data_transforms = {'train': get_transforms(opt.transform['train']),
                       'val': get_transforms(opt.transform['val'])}

    # ----- load data ----- #
    dsets = {}
    for x in ['train', 'val']:
        img_dir = '{:s}/{:s}'.format(opt.train['img_dir'], x)
        target_dir = '{:s}/{:s}'.format(opt.train['label_dir'], x)
        weight_map_dir = '{:s}/{:s}'.format(opt.train['weight_map_dir'], x)
        dir_list = [img_dir, weight_map_dir, target_dir]
        if opt.dataset == 'MultiOrgan':
            post_fix = ['weight.png', 'label.png']
        else:
            post_fix = ['anno_weight.png', 'anno.bmp']
        num_channels = [3, 1, 3]
        dsets[x] = DataFolder(dir_list, post_fix, num_channels, data_transforms[x])
    train_loader = DataLoader(dsets['train'], batch_size=opt.train['batch_size'], shuffle=True,
                              num_workers=opt.train['workers'])
    val_loader = DataLoader(dsets['val'], batch_size=1, shuffle=False,
                            num_workers=opt.train['workers'])  # validation batch size is also 1

    # ----- optionally load from a checkpoint for validation or resuming training ----- #
    if opt.train['checkpoint']:
        if os.path.isfile(opt.train['checkpoint']):
            logger.info("=> loading checkpoint '{}'".format(opt.train['checkpoint']))
            checkpoint = torch.load(opt.train['checkpoint'])
            opt.train['start_epoch'] = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})"
                        .format(opt.train['checkpoint'], checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(opt.train['checkpoint']))

    # ----- training and validation ----- #
    for epoch in range(opt.train['start_epoch'], opt.train['num_epochs']):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch+1, opt.train['num_epochs']))
        train_results = train(train_loader, model, optimizer, criterion, epoch)
        train_loss, train_loss_ce, train_loss_var, train_pixel_acc, train_iou = train_results

        # evaluate on the validation set
        with torch.no_grad():
            val_loss, val_pixel_acc, val_iou = validate(val_loader, model, criterion)

        # check if it is the best accuracy
        is_best = val_iou > best_iou
        best_iou = max(val_iou, best_iou)

        cp_flag = (epoch+1) % opt.train['checkpoint_freq'] == 0

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_iou': best_iou,
            'optimizer' : optimizer.state_dict(),
        }, epoch, is_best, opt.train['save_dir'], cp_flag)

        # save the training results to txt files
        logger_results.info('{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'
                            .format(epoch+1, train_loss, train_loss_ce, train_loss_var, train_pixel_acc,
                                    train_iou, val_loss, val_pixel_acc, val_iou))
        # tensorboard logs
        tb_writer.add_scalars('epoch_losses',
                              {'train_loss': train_loss, 'train_loss_ce': train_loss_ce,
                               'train_loss_var': train_loss_var, 'val_loss': val_loss}, epoch)
        tb_writer.add_scalars('epoch_accuracies',
                              {'train_pixel_acc': train_pixel_acc, 'train_iou': train_iou,
                               'val_pixel_acc': val_pixel_acc, 'val_iou': val_iou}, epoch)
    tb_writer.close()
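
setup_logging(opt) returns two loggers in all training examples: one for console/file messages and one (logger_results) that appends a tab-separated line of results per epoch. Its implementation is not shown; a minimal sketch using the standard logging module, with assumed file names, could be:

import logging


def setup_logging(opt):
    # main logger: messages to the console and to a file in the save directory
    logger = logging.getLogger('train')
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler())
    logger.addHandler(
        logging.FileHandler('{:s}/train_log.txt'.format(opt.train['save_dir'])))

    # results logger: one tab-separated line per epoch, parsed later for plots
    logger_results = logging.getLogger('results')
    logger_results.setLevel(logging.INFO)
    logger_results.addHandler(
        logging.FileHandler('{:s}/epoch_results.txt'.format(opt.train['save_dir'])))
    return logger, logger_results
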
Example #5
def main(opt):
    # opt = Options(isTrain=False)
    opt.isTrain = False
    # opt.parse()
    opt.define_transforms()

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.test['gpus'])

    img_dir = opt.test['img_dir']
    label_dir = opt.test['label_dir']
    model_path = opt.test['model_path']
    save_flag = opt.test['save_flag']
    save_dir = opt.test['save_dir']
    if save_flag and not os.path.exists(save_dir):
        os.mkdir(save_dir)
    opt.save_options()

    # compute accuracy metrics only if a ground-truth label directory is given
    eval_flag = bool(label_dir)
    if eval_flag:
        test_results = dict()
        total_TP = 0.0
        total_FP = 0.0
        total_FN = 0.0
        total_d_list = []

    # data transforms
    test_transform = get_transforms(opt.transform['test'])

    model_name = opt.model['name']
    model = create_model(model_name, opt.model['out_c'],
                         opt.model['pretrained'])
    model = torch.nn.DataParallel(model)
    model = model.cuda()

    # ----- load trained model ----- #
    # print("=> loading trained model")
    best_checkpoint = torch.load(model_path)
    model.load_state_dict(best_checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(best_checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    # print("=> Test begins:")

    img_names = os.listdir(img_dir)

    if save_flag:
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        if not os.path.exists(prob_maps_folder):
            os.mkdir(prob_maps_folder)

    # img_names = ['BP-5.png']
    # total_time = 0.0
    for k in range(len(img_names)):
        img_name = img_names[k]
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        if eval_flag:
            # label_path = '{:s}/{:s}.png'.format(label_dir, name)
            label_path = '{:s}/{:s}_label_point.png'.format(label_dir, name)
            gt = io.imread(label_path)
            # gt_dilated = ski_morph.dilation(gt, ski_morph.disk(5))
            # utils.show_figures((gt, gt_dilated))
            # continue

        input = test_transform((img, ))[0].unsqueeze(0)

        # print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input, model, opt)

        pred = prob_maps > opt.test['threshold']
        pred_labeled, N = measure.label(pred, return_num=True)
        if N > 1:
            bg_area = ski_morph.remove_small_objects(pred_labeled,
                                                     opt.post['max_area']) > 0
            large_area = ski_morph.remove_small_objects(
                pred_labeled, opt.post['min_area']) > 0
            pred = pred * (bg_area == 0) * (large_area > 0)

        if eval_flag:
            # print('\tComputing metrics...')
            TP, FP, FN, d_list = utils.compute_accuracy(pred,
                                                        gt,
                                                        radius=opt.r1,
                                                        return_distance=True)
            total_TP += TP
            total_FP += FP
            total_FN += FN
            total_d_list += d_list

            # save result for each image
            test_results[name] = [
                float(TP) / (TP + FN + 1e-8),
                float(TP) / (TP + FP + 1e-8),
                float(2 * TP) / (2 * TP + FP + FN + 1e-8)
            ]

        # save image
        if save_flag:
            # print('\tSaving image results...')
            io.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name),
                      pred.astype(np.uint8) * 255)
            io.imsave('{:s}/{:s}_prob.png'.format(prob_maps_folder, name),
                      prob_maps)

        counter += 1
        # if counter % 10 == 0:
        # print('\tProcessed {:d} images'.format(counter))

    # print('Time: {:4f}'.format(total_time/counter))

    # print('=> Processed all {:d} images'.format(counter))
    if eval_flag:
        recall = float(total_TP) / (total_TP + total_FN + 1e-8)
        precision = float(total_TP) / (total_TP + total_FP + 1e-8)
        F1 = 2 * precision * recall / (precision + recall + 1e-8)
        if len(total_d_list) > 0:
            mu = np.mean(np.array(total_d_list))
            sigma = np.sqrt(np.var(np.array(total_d_list)))
        else:
            mu = -1
            sigma = -1

        print('Average: precision\trecall\tF1\tmean\tstd:'
              '\t\t{:.4f}\t{:.4f}\t{:.4f}\t{:.3f}\t{:.3f}'.format(
                  precision, recall, F1, mu, sigma))

        header = ['precision', 'recall', 'F1', 'mean', 'std']
        strs = img_dir.split('/')
        save_results(
            header, [precision, recall, F1, mu, sigma], test_results,
            '{:s}/{:s}_test_result_{:.2f}.txt'.format(save_dir, strs[-1],
                                                      opt.test['threshold']))
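
The detection test above scores predictions with utils.compute_accuracy(pred, gt, radius=opt.r1, return_distance=True), which returns TP, FP, FN and the distances of matched pairs. The helper itself is not shown; the sketch below illustrates one plausible greedy matching within a radius and is only an assumption about its behaviour (the original may match differently, e.g. via dilated point masks or an optimal assignment).

import numpy as np
from skimage import measure


def compute_accuracy(pred, gt_points, radius=8, return_distance=False):
    # centroids of predicted blobs and coordinates of ground-truth points
    pred_labeled = measure.label(pred > 0)
    centroids = [r.centroid for r in measure.regionprops(pred_labeled)]
    gt_coords = np.argwhere(gt_points > 0)

    TP, d_list = 0, []
    matched = np.zeros(len(centroids), dtype=bool)
    for gy, gx in gt_coords:
        if len(centroids) == 0:
            break
        dists = np.array([np.hypot(cy - gy, cx - gx) for cy, cx in centroids])
        dists[matched] = np.inf          # each prediction may match only once
        j = int(np.argmin(dists))
        if dists[j] <= radius:
            TP += 1
            matched[j] = True
            d_list.append(float(dists[j]))
    FP = len(centroids) - TP
    FN = len(gt_coords) - TP
    return (TP, FP, FN, d_list) if return_distance else (TP, FP, FN)
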
Example #6
def main(opt):
    global best_score, num_iter, tb_writer, logger, logger_results
    best_score = 0
    opt.isTrain = True

    if not os.path.exists(opt.train['save_dir']):
        os.makedirs(opt.train['save_dir'])
    tb_writer = SummaryWriter('{:s}/tb_logs'.format(opt.train['save_dir']))

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.train['gpus'])

    opt.define_transforms()
    opt.save_options()

    # set up logger
    logger, logger_results = setup_logging(opt)

    # ----- create model ----- #
    model_name = opt.model['name']
    model = create_model(model_name, opt.model['out_c'],
                         opt.model['pretrained'])
    # if not opt.train['checkpoint']:
    #     logger.info(model)
    model = nn.DataParallel(model)
    model = model.cuda()

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(),
                                 opt.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=opt.train['weight_decay'])

    # ----- define criterion ----- #
    criterion = torch.nn.MSELoss(reduction='none').cuda()

    # ----- load data ----- #
    img_dir = '{:s}/train'.format(opt.train['img_dir'])
    target_dir = '{:s}/train'.format(opt.train['label_dir'])
    if opt.round == 0:
        dir_list = [img_dir, target_dir]
        post_fix = ['label_detect.png']
        num_channels = [3, 1]
        train_transform = get_transforms(opt.transform['train_stage1'])
    else:
        bg_dir = '{:s}/train'.format(opt.train['bg_dir'])
        dir_list = [img_dir, target_dir, bg_dir]
        post_fix = ['label_detect.png', 'label_bg.png']
        num_channels = [3, 1, 1]
        train_transform = get_transforms(opt.transform['train_stage2'])
    train_set = DataFolder(dir_list, post_fix, num_channels, train_transform)
    train_loader = DataLoader(train_set,
                              batch_size=opt.train['batch_size'],
                              shuffle=True,
                              num_workers=opt.train['workers'])
    val_transform = get_transforms(opt.transform['val'])

    # ----- training and validation ----- #
    num_epoch = opt.train['train_epochs']
    num_iter = num_epoch * len(train_loader)
    # print training parameters
    logger.info("=> Initial learning rate: {:g}".format(opt.train['lr']))
    logger.info("=> Batch size: {:d}".format(opt.train['batch_size']))
    logger.info("=> Number of training iterations: {:d}".format(num_iter))
    logger.info("=> Training epochs: {:d}".format(opt.train['train_epochs']))

    for epoch in range(num_epoch):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1, num_epoch))
        train_loss = train(opt, train_loader, model, optimizer, criterion)

        # evaluate on val set
        with torch.no_grad():
            val_recall, val_prec, val_F1 = validate(opt, model, val_transform)

        # check if it is the best accuracy
        is_best = val_F1 > best_score
        best_score = max(val_F1, best_score)

        cp_flag = (epoch + 1) % opt.train['checkpoint_freq'] == 0
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, epoch, is_best, opt.train['save_dir'], cp_flag)

        # save the training results to txt files
        logger_results.info('{:d}\t{:.4f} || {:.4f}\t{:.4f}\t{:.4f}'.format(
            epoch + 1, train_loss, val_recall, val_prec, val_F1))
        # tensorboard logs
        tb_writer.add_scalars('epoch_loss', {'train_loss': train_loss}, epoch)
        tb_writer.add_scalars('epoch_acc', {
            'val_recall': val_recall,
            'val_prec': val_prec,
            'val_F1': val_F1
        }, epoch)

    tb_writer.close()
    for i in list(logger.handlers):
        logger.removeHandler(i)
        i.flush()
        i.close()
    for i in list(logger_results.handlers):
        logger_results.removeHandler(i)
        i.flush()
        i.close()
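
In this detection training script the criterion is nn.MSELoss(reduction='none'), so the train() function (not shown here) has to reduce the per-pixel loss map itself, for example by weighting it with the background mask loaded from *_label_bg.png in later rounds. The step below is only a sketch under that assumption; names such as masked_regression_step and mask are illustrative.

import torch


def masked_regression_step(model, criterion, optimizer, image, target, mask):
    # criterion is nn.MSELoss(reduction='none'): element-wise loss map
    output = model(image)                 # predicted detection map
    loss_map = criterion(output, target)  # same shape as target
    loss = (loss_map * mask).sum() / (mask.sum() + 1e-8)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
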
Example #7
def main():
    global opt, num_iter, tb_writer, logger, logger_results
    opt = Options(isTrain=True)
    opt.parse()
    opt.save_options()

    tb_writer = SummaryWriter('{:s}/tb_logs'.format(opt.train['save_dir']))

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.train['gpus'])

    # set up logger
    logger, logger_results = setup_logging(opt)

    # ----- create model ----- #
    model = ResUNet34(pretrained=opt.model['pretrained'])
    # if not opt.train['checkpoint']:
    #     logger.info(model)
    model = nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(),
                                 opt.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=opt.train['weight_decay'])

    # ----- define criterion ----- #
    criterion = torch.nn.NLLLoss(ignore_index=2).cuda()
    if opt.train['crf_weight'] > 0:
        logger.info('=> Using CRF loss...')
        global criterion_crf
        criterion_crf = CRFLoss(opt.train['sigmas'][0], opt.train['sigmas'][1])

    # ----- load data ----- #
    data_transforms = {
        'train': get_transforms(opt.transform['train']),
        'test': get_transforms(opt.transform['test'])
    }

    img_dir = '{:s}/train'.format(opt.train['img_dir'])
    target_vor_dir = '{:s}/train'.format(opt.train['label_vor_dir'])
    target_cluster_dir = '{:s}/train'.format(opt.train['label_cluster_dir'])
    dir_list = [img_dir, target_vor_dir, target_cluster_dir]
    post_fix = ['label_vor.png', 'label_cluster.png']
    num_channels = [3, 3, 3]
    train_set = DataFolder(dir_list, post_fix, num_channels,
                           data_transforms['train'])
    train_loader = DataLoader(train_set,
                              batch_size=opt.train['batch_size'],
                              shuffle=True,
                              num_workers=opt.train['workers'])

    # ----- optionally load from a checkpoint for validation or resuming training ----- #
    if opt.train['checkpoint']:
        if os.path.isfile(opt.train['checkpoint']):
            logger.info("=> loading checkpoint '{}'".format(
                opt.train['checkpoint']))
            checkpoint = torch.load(opt.train['checkpoint'])
            opt.train['start_epoch'] = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                opt.train['checkpoint'], checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(
                opt.train['checkpoint']))

    # ----- training and validation ----- #
    num_epoch = opt.train['train_epochs'] + opt.train['finetune_epochs']
    num_iter = num_epoch * len(train_loader)
    # print training parameters
    logger.info("=> Initial learning rate: {:g}".format(opt.train['lr']))
    logger.info("=> Batch size: {:d}".format(opt.train['batch_size']))
    logger.info("=> Number of training iterations: {:d}".format(num_iter))
    logger.info("=> Training epochs: {:d}".format(opt.train['train_epochs']))
    logger.info("=> Fine-tune epochs using dense CRF loss: {:d}".format(
        opt.train['finetune_epochs']))
    logger.info("=> CRF loss weight: {:.2g}".format(opt.train['crf_weight']))

    for epoch in range(opt.train['start_epoch'], num_epoch):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1, num_epoch))
        finetune_flag = epoch >= opt.train['train_epochs']
        if epoch == opt.train['train_epochs']:
            logger.info("Fine-tune begins, lr = {:.2g}".format(
                opt.train['lr'] * 0.1))
            for param_group in optimizer.param_groups:
                param_group['lr'] = opt.train['lr'] * 0.1

        train_results = train(train_loader, model, optimizer, criterion,
                              finetune_flag)
        train_loss, train_loss_vor, train_loss_cluster, train_loss_crf = train_results

        cp_flag = (epoch + 1) % opt.train['checkpoint_freq'] == 0
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, epoch, opt.train['save_dir'], cp_flag)

        # save the training results to txt files
        logger_results.info('{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(
            epoch + 1, train_loss, train_loss_vor, train_loss_cluster,
            train_loss_crf))
        # tensorboard logs
        tb_writer.add_scalars(
            'epoch_losses', {
                'train_loss': train_loss,
                'train_loss_vor': train_loss_vor,
                'train_loss_cluster': train_loss_cluster,
                'train_loss_crf': train_loss_crf
            }, epoch)
    tb_writer.close()
    for i in list(logger.handlers):
        logger.removeHandler(i)
        i.flush()
        i.close()
    for i in list(logger_results.handlers):
        logger_results.removeHandler(i)
        i.flush()
        i.close()
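
This example and the earlier Voronoi/cluster one pair NLLLoss(ignore_index=2) with a network that is expected to output log-probabilities; pixels labelled 2, presumably the unlabeled region of the weak annotations, contribute nothing to the loss. A small self-contained illustration of that behaviour:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 3, 4, 4)                  # toy network output: 3 classes
target = torch.tensor([[[0, 1, 2, 1],
                        [2, 2, 0, 1],
                        [0, 1, 1, 2],
                        [2, 0, 1, 0]]])
criterion = torch.nn.NLLLoss(ignore_index=2)
loss = criterion(F.log_softmax(logits, dim=1), target)  # pixels labelled 2 are ignored
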
Example #8
def main(opt, save_dir):
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.test['gpus'])

    # img_dir = opt.test['img_dir']
    ratio = opt.ratio
    img_dir = './data/{:s}/images'.format(opt.dataset)
    label_dir = './data/{:s}/labels_point'.format(opt.dataset)
    label_instance_dir = './data/{:s}/labels_instance'.format(opt.dataset)
    # save_dir = './data/{:s}/selected_masks'.format(opt.dataset)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    model_path = opt.test['model_path']

    # data transforms
    test_transform = get_transforms(opt.transform['test'])

    model = ResUNet34(pretrained=opt.model['pretrained'],
                      with_uncertainty=opt.with_uncertainty)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    # print("=> loading trained model")
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['state_dict'])
    # print("=> loaded model at epoch {}".format(checkpoint['epoch']))

    # switch to evaluate mode
    model.eval()
    apply_dropout(model)

    with open('./data/{:s}/train_val_test.json'.format(opt.dataset),
              'r') as file:
        data_list = json.load(file)
        train_list = data_list['train']

    for img_name in tqdm(train_list):
        # load test image
        # print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        label_point = misc.imread('{:s}/{:s}_label_point.png'.format(
            label_dir, name))

        input = test_transform((img, ))[0].unsqueeze(0)
        # print('\tComputing uncertainty maps...')
        mean_sigma = np.zeros((2, ori_h, ori_w))
        mean_sigma_normalized = np.zeros((2, ori_h, ori_w))
        mean_prob = np.zeros((2, ori_h, ori_w))
        for _ in range(opt.T):
            output, log_var = get_probmaps(input, model, opt)
            output = output.astype(np.float64)
            log_var = log_var.astype(np.float64)
            sigma_map = np.exp(log_var / 2)
            sigma_map_normalized = sigma_map / (np.exp(output) + 1e-8)

            mean_prob += np.exp(output) / np.sum(np.exp(output), axis=0)
            mean_sigma += sigma_map
            mean_sigma_normalized += sigma_map_normalized

        mean_prob /= opt.T
        mean_sigma /= opt.T
        mean_sigma_normalized /= opt.T

        un_data_normalized = mean_sigma_normalized**2

        pred = np.argmax(mean_prob, axis=0)
        un_data_normalized = np.sum(un_data_normalized *
                                    utils.onehot_encoding(pred),
                                    axis=0)

        # find the area of largest uncertainty for visualization
        threshed = un_data_normalized > 1.0
        large_unc_area = morph.opening(threshed, selem=morph.disk(1))
        large_unc_area = morph.remove_small_objects(large_unc_area,
                                                    min_size=64)
        un_data_smoothed = gaussian_filter(un_data_normalized * large_unc_area,
                                           sigma=5)

        # cmap = plt.cm.jet
        # plt.imsave('{:s}/{:s}_uncertainty.png'.format(save_dir, name), cmap(un_data_normalized))

        points = measure.label(label_point)
        uncertainty_list = []
        radius = 10
        for k in range(1, np.max(points) + 1):
            x, y = np.argwhere(points == k)[0]
            r1 = x - radius if x - radius > 0 else 0
            r2 = x + radius if x + radius < ori_h else ori_h
            c1 = y - radius if y - radius > 0 else 0
            c2 = y + radius if y + radius < ori_w else ori_w
            uncertainty = np.mean(un_data_smoothed[r1:r2, c1:c2])
            uncertainty_list.append([k, uncertainty])

        uncertainty_list = np.array(uncertainty_list)
        sorted_list = uncertainty_list[uncertainty_list[:, 1].argsort()[::-1]]
        indices = sorted_list[:int(ratio * np.max(points)), 0]

        # annotation
        label_instance = misc.imread('{:s}/{:s}_label.png'.format(
            label_instance_dir, name))
        new_anno = np.zeros_like(label_instance)
        counter = 1
        for idx in indices:
            nuclei_idx = np.unique(label_instance[points == idx])[0]
            if nuclei_idx == 0:
                continue
            new_anno += (label_instance == nuclei_idx) * counter
            counter += 1
            # utils.show_figures((new_anno,))

        misc.imsave('{:s}/{:s}_label_partial_mask.png'.format(save_dir, name),
                    new_anno.astype(np.uint8))
        misc.imsave(
            '{:s}/{:s}_label_partial_mask_binary.png'.format(save_dir, name),
            (new_anno > 0).astype(np.uint8) * 255)

    print('=> Processed all images')
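
The uncertainty estimation above uses Monte-Carlo dropout: the model is set to eval mode and then apply_dropout(model) re-enables the dropout layers so that the opt.T forward passes remain stochastic. The helper is not shown on this page; a common sketch (the real one may select modules by name) is:

import torch.nn as nn


def apply_dropout(model):
    # keep BatchNorm etc. in eval mode, but switch dropout layers back to train mode
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d)):
            m.train()
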
Example #9
def val(img_dir, label_dir, model, transform, opt, tb_writer, epoch):
    model.eval()
    img_names = os.listdir(img_dir)
    metric_names = ['acc', 'p_F1', 'p_recall', 'p_precision', 'dice', 'aji']
    val_results = dict()
    all_results = utils.AverageMeter(len(metric_names))

    plot_num = 10  # or len(img_names) to visualize every image
    for img_name in img_names:
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
        gt = misc.imread(label_path)

        input = transform((img, ))[0].unsqueeze(0)

        prob_maps = get_probmaps(input, model, opt)
        pred = np.argmax(prob_maps, axis=0)

        pred_labeled = measure.label(pred)
        pred_labeled = morph.remove_small_objects(pred_labeled,
                                                  opt.post['min_area'])
        pred_labeled = ndi_morph.binary_fill_holes(pred_labeled > 0)
        pred_labeled = measure.label(pred_labeled)

        metrics = compute_metrics(pred_labeled, gt, metric_names)

        if plot_num > 0:

            unNorm = get_transforms({
                'unnormalize':
                np.load('{:s}/mean_std.npy'.format(opt.train['data_dir']))
            })
            img_tensor = unNorm(input.squeeze(0))
            img_np = img_tensor.permute(1, 2, 0).numpy()
            font = ImageFont.truetype(
                '/usr/share/fonts/truetype/dejavu/DejaVuSansMono.ttf', 42)
            metrics_text = Image.new("RGB", (512, 512), (255, 255, 255))
            draw = ImageDraw.Draw(metrics_text)
            draw.text(
                (32, 128),
                'Acc: {:.4f}\nF1: {:.4f}\nRecall: {:.4f}\nPrecision: {:.4f}\nDice: {:.4f}\nAJI: {:.4f}'
                .format(metrics['acc'], metrics['p_F1'], metrics['p_recall'],
                        metrics['p_precision'], metrics['dice'],
                        metrics['aji']),
                fill='rgb(0,0,0)',
                font=font)
            #tb_writer.add_scalars('{:s}'.format(name), 'Acc: {:.4f}\nF1: {:.4f}\nRecall: {:.4f}\nPrecision: {:.4f}\nDice: {:.4f}\nAJI: {:.4f}'.format(metrics['acc'], metrics['p_F1'], metrics['p_recall'], metrics['p_precision'], metrics['dice'], metrics['aji']), epoch)
            metrics_text = metrics_text.resize((ori_w, ori_h), Image.ANTIALIAS)
            trans_to_tensor = transforms.Compose([
                transforms.ToTensor(),
            ])
            text_tensor = trans_to_tensor(metrics_text).float()
            colored_gt = np.zeros((ori_h, ori_w, 3))
            colored_pred = np.zeros((ori_h, ori_w, 3))
            img_w_colored_gt = img_np.copy()
            img_w_colored_pred = img_np.copy()
            alpha = 0.5
            for k in range(1, gt.max() + 1):
                colored_gt[gt == k, :] = np.array(
                    utils.get_random_color(seed=k))
                img_w_colored_gt[gt == k, :] = img_w_colored_gt[gt == k, :] * (
                    1 - alpha) + colored_gt[gt == k, :] * alpha
            for k in range(1, pred_labeled.max() + 1):
                colored_pred[pred_labeled == k, :] = np.array(
                    utils.get_random_color(seed=k))
                img_w_colored_pred[
                    pred_labeled ==
                    k, :] = img_w_colored_pred[pred_labeled == k, :] * (
                        1 - alpha) + colored_pred[pred_labeled == k, :] * alpha

            gt_tensor = torch.from_numpy(colored_gt).permute(2, 0, 1).float()
            pred_tensor = torch.from_numpy(colored_pred).permute(2, 0,
                                                                 1).float()
            img_w_gt_tensor = torch.from_numpy(img_w_colored_gt).permute(
                2, 0, 1).float()
            img_w_pred_tensor = torch.from_numpy(img_w_colored_pred).permute(
                2, 0, 1).float()
            tb_writer.add_image(
                '{:s}'.format(name),
                make_grid([
                    img_tensor, img_w_gt_tensor, img_w_pred_tensor,
                    text_tensor, gt_tensor, pred_tensor
                ],
                          nrow=3,
                          padding=10,
                          pad_value=1), epoch)
            plot_num -= 1

        # update the average result
        all_results.update([
            metrics['acc'], metrics['p_F1'], metrics['p_recall'],
            metrics['p_precision'], metrics['dice'], metrics['aji']
        ])
    logger.info('\t=> Val Avg: Acc {r[0]:.4f}'
                '\tF1 {r[1]:.4f}'
                '\tRecall {r[2]:.4f}'
                '\tPrecision {r[3]:.4f}'
                '\tDice {r[4]:.4f}'
                '\tAJI {r[5]:.4f}'.format(r=all_results.avg))

    return all_results.avg
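
The colored instance overlays in several examples come from utils.get_random_color(), optionally seeded with the instance id so the same object keeps a stable colour across visualizations. Its implementation is not shown; a minimal sketch consistent with that usage:

import random


def get_random_color(seed=None):
    # returns an RGB triple in [0, 1]; seeding makes the colour reproducible
    if seed is not None:
        random.seed(seed)
    return random.random(), random.random(), random.random()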