Пример #1
0
def main(model_path, exp_config):
    """Evaluate how well the segvae predicted variance tracks its error.

    Restores the latest segvae checkpoint from ``model_path`` and runs it on
    every test batch (batch size 1), drawing 100 samples per image. For each
    image it computes ``utils.ncc`` between the predicted variance map and the
    error map. Finally it prints the mean and standard deviation of those
    per-image NCC values.

    Args:
        model_path: Directory containing the saved model weights.
        exp_config: Experiment configuration; must provide ``data_identifier``
            and whatever ``segvae`` requires.
    """
    # Make and restore the segvae model
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type='latest')

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    ncc_list = []

    for ii, batch in enumerate(data.test.iterate_batches(1)):

        if ii % 10 == 0:
            print('Progress: %d' % ii)

        x_b, s_b = batch

        # NOTE(review): s_b is passed before x_b here; confirm this matches the
        # signature of predict_mean_variance_and_error_maps.
        s_m, s_v, s_e = segvae_model.predict_mean_variance_and_error_maps(
            s_b, x_b, num_samples=100)

        ncc_list.append(utils.ncc(s_v, s_e))

    ncc_arr = np.asarray(ncc_list)

    ncc_mean = np.mean(ncc_arr, axis=0)
    ncc_std = np.std(ncc_arr, axis=0)

    # Previously the format string and the value were passed to print() as two
    # separate arguments, so '%.4f' was never applied. array2string formats
    # both scalar and per-label (array) results to 4 digits of precision.
    print('NCC mean: %s' % np.array2string(np.asarray(ncc_mean), precision=4))
    print('NCC std: %s' % np.array2string(np.asarray(ncc_std), precision=4))
Пример #2
0
def main(model_path, exp_config, do_plots=False, n_samples=50,
         model_selection='best_ged'):
    """Compute GED and variance-NCC of segvae samples on the test set.

    For every test image, draws ``n_samples`` segmentations in a single
    forward pass by tiling the image along the batch axis. It computes the
    generalised energy distance (foreground labels only) and the variance NCC
    against all ground-truth annotations. It then logs the mean/std of both
    and saves the per-image values as ``.npz`` files in ``model_path``.

    Args:
        model_path: Directory containing the checkpoints; outputs go here too.
        exp_config: Experiment configuration providing ``data_identifier``,
            ``image_size`` and ``nlabels``.
        do_plots: Currently unused; kept for interface compatibility.
        n_samples: Number of samples drawn per test image.
        model_selection: Which checkpoint to restore (passed as ``type`` to
            ``load_weights``) and used in the output file names.
    """
    # Make and restore the segvae model
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type=model_selection)

    # Get Data
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    nlabels = exp_config.nlabels
    num_test = data.test.images.shape[0]

    ged_list = []
    ncc_list = []

    for ii in range(num_test):

        if ii % 10 == 0:
            logging.info("Progress: %d", ii)

        x_b = data.test.images[ii, ...].reshape([1] + list(exp_config.image_size))
        s_b = data.test.labels[ii, ...]

        # Replicate the image so a single run yields n_samples segmentations.
        x_b_stacked = np.tile(x_b, [n_samples, 1, 1, 1])

        feed_dict = {
            segvae_model.training_pl: False,
            segvae_model.x_inp: x_b_stacked,
        }

        s_arr_sm = segvae_model.sess.run(segvae_model.s_out_eval_sm, feed_dict=feed_dict)
        s_arr = np.argmax(s_arr_sm, axis=-1)

        s_b_r = s_b.transpose((2, 0, 1))  # num gts x X x Y
        s_b_r_sm = utils.convert_batch_to_onehot(s_b_r, nlabels)  # num gts x X x Y x nlabels

        # GED is computed over the foreground labels 1..nlabels-1 only.
        ged = utils.generalised_energy_distance(s_arr, s_b_r, nlabels=nlabels - 1,
                                                label_range=range(1, nlabels))
        ged_list.append(ged)

        ncc = utils.variance_ncc_dist(s_arr_sm, s_b_r_sm)
        ncc_list.append(ncc)

    ged_arr = np.asarray(ged_list)
    ncc_arr = np.asarray(ncc_list)

    logging.info('-- GED: --')
    logging.info(np.mean(ged_arr))
    logging.info(np.std(ged_arr))

    logging.info('-- NCC: --')
    logging.info(np.mean(ncc_arr))
    logging.info(np.std(ncc_arr))

    np.savez(os.path.join(model_path, 'ged%s_%s.npz' % (str(n_samples), model_selection)), ged_arr)
    np.savez(os.path.join(model_path, 'ncc%s_%s.npz' % (str(n_samples), model_selection)), ncc_arr)
def main(model_path, exp_config):
    """Display guided-backprop saliency maps of a trained classifier.

    Loads the dataset once and builds a classifier with batch size 1. It
    restores the ``best_xent`` checkpoint and then, for every ``testAD``
    batch, shows the input image and its saliency map for label 1 in two
    matplotlib figures.

    Args:
        model_path: Directory containing the classifier checkpoints.
        exp_config: Experiment configuration providing ``data_identifier``.
    """
    # Get Data (the loader was previously invoked twice, reloading the data)
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # Make the classifier; images are processed one at a time.
    classifier_model = classifier(exp_config=exp_config, data=data, fixed_batch_size=1)

    # Other supported modes: 'additive_pertubation', 'backprop',
    # 'integrated_gradients', 'CAM' (the latter requires a CAM net).
    classifier_model.initialise_saliency(mode='guided_backprop')

    classifier_model.load_weights(model_path, type='best_xent')

    for x, _ in data.testAD.iterate_batches(1):

        sal = classifier_model.compute_saliency(x, label=1)

        plt.figure()
        plt.imshow(np.squeeze(x))

        plt.figure()
        plt.imshow(np.squeeze(sal))
        plt.show()
Пример #4
0
def calculate_expert_diversity(exp_config):
    """Print the inter-annotator diversity of the test set.

    For every test image, the expert annotations are compared with themselves
    via ``utils.generalised_energy_distance``. Only its second return value
    (the diversity term) is kept. The result is printed as
    ``mean +- nanstderr`` with six decimals; ``nanstderr`` is presumably a
    NaN-aware standard error, to be confirmed at its definition.

    Args:
        exp_config: Experiment configuration providing ``data_identifier``
            and ``nlabels``.
    """
    loader = data_switch(exp_config.data_identifier)
    dataset = loader(exp_config)
    n_images = dataset.test.images.shape[0]

    per_image = []
    for idx in tqdm(range(n_images)):
        # Put the annotator axis first: (num experts, X, Y).
        experts = dataset.test.labels[idx, ...].transpose((2, 0, 1))
        _, div = utils.generalised_energy_distance(
            experts, experts, exp_config.nlabels - 1, range(1, exp_config.nlabels))
        per_image.append(div)

    per_image = np.array(per_image)
    print(f'{np.mean(per_image):.6f} +- {nanstderr(per_image):.6f}')
Пример #5
0
def main(exp_config):
    """Train a PHiSeg model for the given experiment configuration.

    Logs a banner with ``exp_config.experiment_name``, then loads the data
    selected by ``exp_config.data_identifier``, builds the model and fits it.

    Args:
        exp_config: Experiment configuration module/object.
    """
    separator = '**************************************************************'
    logging.info(separator)
    logging.info(' *** Running Experiment: %s', exp_config.experiment_name)
    logging.info(separator)

    dataset = data_switch(exp_config.data_identifier)(exp_config)

    model = phiseg_model.phiseg(exp_config)
    model.train(dataset)
def main():
    """Train the classifier for the experiment selected below.

    To switch experiments, change the import. Other available configs are
    ``classifier.experiments.synthetic_CAM`` and
    ``classifier.experiments.synthetic_vgg16``.
    """
    from classifier.experiments import synthetic_resnet34 as exp_config

    load_data = data_switch(exp_config.data_identifier)

    # Build and train the classifier
    model = classifier(
        exp_config=exp_config,
        data=load_data(exp_config),
        fixed_batch_size=exp_config.batch_size,
    )
    model.train()
Пример #7
0
def main():
    """Train a ``detseg`` segmenter for the experiment selected below.

    To switch experiments, change the import. Other available configs in
    ``detseg.experiments`` are ``acdc_unet``, ``lidc_probunetarch``,
    ``uzh_prostate_probunetarch`` and ``twolbl_uzh_prostate_probunetarch``.
    """
    from detseg.experiments import acdc_probunetarch as exp_config

    dataset = data_switch(exp_config.data_identifier)(exp_config)

    # Build and train the segmenter
    model = segmenter(exp_config=exp_config,
                      data=dataset,
                      fixed_batch_size=exp_config.batch_size)
    model.train()
def main():
    """Train a ``segmenter`` model for the experiment selected below.

    To switch experiments, change the import. Other available configs in
    ``segmenter.experiments`` are ``synthetic_unet``, ``acdc_resunet``,
    ``acdc_unet3D``, ``nci_prostate_unet`` and ``nci_prostate_unet3D``.
    """
    from segmenter.experiments import acdc_unet as exp_config

    dataset = data_switch(exp_config.data_identifier)(exp_config)

    # Build and train the segmenter
    model = segmenter(exp_config=exp_config,
                      data=dataset,
                      fixed_batch_size=exp_config.batch_size)
    model.train()
def _save_current_figure(image, path, cmap=None):
    """Show ``image`` in a new axis-less figure and save it to ``path``.

    The figure is left open; the caller closes figures in bulk with
    ``plt.close('all')``.
    """
    plt.figure()
    plt.imshow(image, cmap=cmap)
    plt.axis('off')
    plt.savefig(path, bbox_inches='tight')


def main(model_path, exp_config):
    """Save PHiSeg sample, ground-truth and error-map figures for chosen test slices.

    For each hard-coded test index, the function draws 100 softmax samples
    from the restored model. It saves the following as PNGs under
    ``<model_path>/samples_<model_selection>/<index>/``:

    - the input image;
    - the first 16 sample segmentations;
    - every ground-truth annotation;
    - the error maps returned by ``generate_error_maps``.

    Slices with almost no labelled structure are skipped.

    NOTE(review): ``model_selection`` is neither a parameter nor a local; it
    is looked up in module scope and must be defined there.

    Args:
        model_path: Directory holding the checkpoints; outputs are written
            beneath it.
        exp_config: Experiment configuration providing ``data_identifier``,
            ``image_size`` and ``nlabels``.
    """
    # Make and restore the PHiSeg model
    phiseg_model = phiseg(exp_config=exp_config)
    phiseg_model.load_weights(model_path, type=model_selection)

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    n_generated = 100  # samples drawn per image (all feed the error maps)
    n_samples = 16  # how many of those samples are plotted

    # To pick images at random instead:
    # sample_inds = np.random.choice(np.arange(data.test.images.shape[0]), 16)
    sample_inds = [165, 280, 213]  # <-- prostate
    # sample_inds = [1551]  # [907, 1296, 1551]  <-- LIDC

    for ii in sample_inds:

        print('------- Processing image %d -------' % ii)

        outfolder = os.path.join(model_path, 'samples_%s' % model_selection,
                                 str(ii))
        utils.makefolder(outfolder)

        x_b = data.test.images[ii, ...].reshape([1] + list(exp_config.image_size))
        s_b = data.test.labels[ii, ...]

        if np.sum(s_b) < 10:
            print('WARNING: skipping cases with no structures')
            continue

        # Move the last (annotation) axis to the front so each annotation
        # becomes one item of a one-hot batch.
        s_b_r = utils.convert_batch_to_onehot(s_b.transpose((2, 0, 1)),
                                              exp_config.nlabels)

        print('Plotting input image')
        _save_current_figure(preproc_image(x_b),
                             os.path.join(outfolder, 'input_img_%d.png' % ii),
                             cmap='gray')

        print('Generating %d samples' % n_generated)
        s_p_arr = np.squeeze(np.asarray([
            phiseg_model.predict_segmentation_sample(x_b, return_softmax=True)
            for _ in range(n_generated)
        ]))

        print('Plotting %d of those samples' % n_samples)
        for jj in range(n_samples):
            s_p_am = np.argmax(s_p_arr[jj, ...], axis=-1)
            _save_current_figure(
                preproc_image(s_p_am, nlabels=exp_config.nlabels),
                os.path.join(outfolder, 'sample_img_%d_samp_%d.png' % (ii, jj)),
                cmap='gray')

        print('Plotting ground-truths masks')
        for jj in range(s_b_r.shape[0]):
            s_b_am = np.argmax(s_b_r[jj, ...], axis=-1)
            _save_current_figure(
                preproc_image(s_b_am, nlabels=exp_config.nlabels),
                os.path.join(outfolder, 'gt_img_%d_samp_%d.png' % (ii, jj)),
                cmap='gray')

        print('Generating error masks')
        E_ss, E_sy_avg, E_yy_avg = generate_error_maps(s_p_arr, s_b_r)

        print('Plotting them')
        _save_current_figure(preproc_image(E_ss),
                             os.path.join(outfolder, 'E_ss_%d.png' % ii))

        print('Plotting them')
        # NOTE(review): np.log gives -inf wherever E_ss == 0.
        _save_current_figure(preproc_image(np.log(E_ss)),
                             os.path.join(outfolder, 'log_E_ss_%d.png' % ii))

        _save_current_figure(preproc_image(E_sy_avg),
                             os.path.join(outfolder, 'E_sy_avg_%d_.png' % ii))

        _save_current_figure(preproc_image(E_yy_avg),
                             os.path.join(outfolder, 'E_yy_avg_%d_.png' % ii))

        plt.close('all')
def main(model_path, exp_config):
    """Render PHiSeg samples for one fixed test image as video/GIF frames.

    Restores the ``best_ged`` checkpoint and draws 20 samples for a fixed
    test image. Each frame shows the input image next to the sampled
    segmentation, with label contours drawn on top. The module-level flags
    control the output:

    - ``SAVE_VIDEO``: write frames to an XVID ``.avi``;
    - ``SAVE_GIF``: save each frame as a PNG upscaled 2x;
    - ``DISPLAY_VIDEO``: show frames in an OpenCV window (``q`` stops early).

    Args:
        model_path: Directory with the checkpoints; outputs are written here.
        exp_config: Experiment configuration providing ``data_identifier``,
            ``image_size`` and ``nlabels``.
    """
    # Make and restore the PHiSeg model
    phiseg_model = phiseg(exp_config=exp_config)
    phiseg_model.load_weights(model_path, type='best_ged')

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    def find_contours(mask):
        # cv2.findContours returns (image, contours, hierarchy) in OpenCV 3.x
        # and (contours, hierarchy) in 4.x; [-2] is the contour list in both.
        return cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # RANDOM IMAGE
    # x_b, s_b = data.test.next_batch(1)

    # FIXED IMAGE
    # Cardiac: 100 normal image
    # LIDC: 200 large lesion, 203, 1757 complicated lesion
    # Prostate: 165 nice slice, 170 is a challenging and interesting slice
    index = 165

    if SAVE_GIF:
        outfolder_gif = os.path.join(model_path,
                                     'model_samples_id%d_gif' % index)
        utils.makefolder(outfolder_gif)

    x_b = data.test.images[index, ...].reshape([1] + list(exp_config.image_size))

    x_b_d = utils.convert_to_uint8(np.squeeze(x_b))
    x_b_d = utils.resize_image(x_b_d, video_target_size)

    is_prostate = exp_config.data_identifier == 'uzh_prostate'
    if is_prostate:
        # Rotate by 270 degrees about the centre; M is reused for every sample.
        rows, cols = x_b_d.shape
        M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 270, 1)
        x_b_d = cv2.warpAffine(x_b_d, M, (cols, rows))

    if SAVE_VIDEO:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        outfile = os.path.join(model_path, 'model_samples_id%d.avi' % index)
        # Frame size is (width, height): image and sample side by side.
        out = cv2.VideoWriter(outfile, fourcc, 5.0,
                              (2 * video_target_size[1], video_target_size[0]))

    samps = 20
    for ii in range(samps):

        feed_dict = {
            phiseg_model.training_pl: False,
            phiseg_model.x_inp: x_b,
        }

        s_p, s_p_list = phiseg_model.sess.run(
            [phiseg_model.s_out_eval, phiseg_model.s_out_eval_list],
            feed_dict=feed_dict)
        s_p = np.argmax(s_p, axis=-1)

        # Map label ids to grey levels label / nlabels * 255.
        s_p_d = np.squeeze(np.uint8((s_p / exp_config.nlabels) * 255))
        s_p_d = utils.resize_image(s_p_d,
                                   video_target_size,
                                   interp=cv2.INTER_NEAREST)

        if is_prostate:
            s_p_d = cv2.warpAffine(s_p_d, M, (cols, rows))

        img = np.concatenate([x_b_d, s_p_d], axis=1)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        img = histogram_equalization(img)

        # NOTE(review): grey levels 85/170 correspond to labels 1/2 only when
        # nlabels == 3 (see the mapping above); confirm per dataset.
        if exp_config.data_identifier == 'acdc':
            rv_cnt = find_contours(cv2.inRange(s_p_d, 84, 86))
            my_cnt = find_contours(cv2.inRange(s_p_d, 169, 171))
            cv2.drawContours(img, rv_cnt, -1, (0, 255, 0), 1)
            cv2.drawContours(img, my_cnt, -1, (0, 0, 255), 1)
        elif is_prostate:
            print(np.unique(s_p_d))
            s1_cnt = find_contours(cv2.inRange(s_p_d, 84, 86))
            s2_cnt = find_contours(cv2.inRange(s_p_d, 169, 171))
            cv2.drawContours(img, s1_cnt, -1, (0, 255, 0), 1)
            cv2.drawContours(img, s2_cnt, -1, (0, 0, 255), 1)
        elif exp_config.data_identifier == 'lidc':
            lesion = find_contours(cv2.inRange(s_p_d, 127, 255))
            cv2.drawContours(img, lesion, -1, (0, 255, 0), 1)

        if SAVE_VIDEO:
            out.write(img)

        if SAVE_GIF:
            outfile_gif = os.path.join(outfolder_gif,
                                       'frame_%s.png' % str(ii).zfill(3))
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            im = Image.fromarray(img_rgb)
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
            im = im.resize((im.size[0] * 2, im.size[1] * 2), Image.LANCZOS)
            im.save(outfile_gif)

        if DISPLAY_VIDEO:
            cv2.imshow('frame', img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    if SAVE_VIDEO:
        out.release()
    cv2.destroyAllWindows()
Пример #11
0
def main(model_path, exp_config, model_selection='latest'):
    """Compute per-label Dice scores of PHiSeg predictions on the test set.

    Predicts each test batch (batch size 1) and scores every label
    ``0..nlabels-1`` with ``dc``. When a label is absent from both prediction
    and ground truth it scores 1; when it is absent from only one of them it
    scores 0. The function logs per-label, overall and foreground means, and
    saves the full (num_images x nlabels) array as
    ``dice_<model_selection>.npz``.

    Args:
        model_path: Directory holding the checkpoints; output is saved here.
        exp_config: Experiment configuration providing ``data_identifier``,
            ``likelihood`` and ``nlabels``.
        model_selection: Checkpoint type passed to ``load_weights``.
    """
    # Make and restore the PHiSeg model
    phiseg_model = phiseg(exp_config=exp_config)
    phiseg_model.load_weights(model_path, type=model_selection)

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    dice_list = []

    # A single sample for the det_unet2D likelihood (presumably deterministic),
    # 100 otherwise.
    num_samples = 1 if exp_config.likelihood is likelihoods.det_unet2D else 100

    for ii, batch in enumerate(data.test.iterate_batches(1)):

        if ii % 10 == 0:
            logging.info("Progress: %d", ii)

        x, y = batch

        y_ = np.squeeze(phiseg_model.predict(x, num_samples=num_samples))

        per_lbl_dice = []
        for lbl in range(exp_config.nlabels):

            binary_pred = (y_ == lbl) * 1
            binary_gt = (y == lbl) * 1
            pred_present = np.sum(binary_pred) > 0
            gt_present = np.sum(binary_gt) > 0

            if not pred_present and not gt_present:
                per_lbl_dice.append(1)
            elif pred_present != gt_present:
                logging.warning(
                    'Structure missing in either GT (x)or prediction. ASSD and HD will not be accurate.'
                )
                per_lbl_dice.append(0)
            else:
                per_lbl_dice.append(dc(binary_pred, binary_gt))

        dice_list.append(per_lbl_dice)

    dice_arr = np.asarray(dice_list)

    mean_per_lbl_dice = dice_arr.mean(axis=0)

    logging.info('Dice')
    logging.info(mean_per_lbl_dice)
    logging.info(np.mean(mean_per_lbl_dice))
    # Label 0 is excluded from the foreground mean.
    logging.info('foreground mean: %f', np.mean(mean_per_lbl_dice[1:]))

    np.savez(os.path.join(model_path, 'dice_%s.npz' % model_selection),
             dice_arr)
Пример #12
0
def main(model_path, exp_config, do_plots=False):
    """Compute per-label Dice scores of segvae predictions on the test set.

    Predicts each test batch (batch size 1) and optionally plots the image,
    the prediction and the ground truth. Every label ``0..nlabels-1`` is then
    scored with ``dc``. When a label is absent from both prediction and
    ground truth it scores 1; when it is absent from only one of them it
    scores 0. The function logs per-label, overall and foreground means, and
    saves the full array as ``dice_<model_selection>.npz``.

    NOTE(review): ``model_selection`` is neither a parameter nor a local; it
    is looked up in module scope and must be defined there.

    Args:
        model_path: Directory holding the checkpoints; output is saved here.
        exp_config: Experiment configuration providing ``data_identifier``,
            ``likelihood`` and ``nlabels``.
        do_plots: If true (and not running on a GPU host), show a plot per image.
    """
    # Make and restore the segvae model
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type=model_selection)

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    dice_list = []

    # A single sample for the det_unet2D likelihood (presumably deterministic),
    # 100 otherwise.
    num_samples = 1 if exp_config.likelihood is likelihoods.det_unet2D else 100

    for ii, batch in enumerate(data.test.iterate_batches(1)):

        if ii % 10 == 0:
            logging.info("Progress: %d", ii)

        x, y = batch

        # Optional corruption experiments (disabled):
        # motion corruption
        # x = utils.add_motion_artefacts(np.squeeze(x), 15)
        # x = x.reshape([1] + list(exp_config.image_size))
        # box corruption
        # x[:, 192 // 2 - 20:192 // 2 + 20, 192 // 2 - 5:192 // 2 + 5, :] = 0

        y_ = np.squeeze(segvae_model.predict(x, num_samples=num_samples))

        if do_plots and not sys_config.running_on_gpu_host:
            fig = plt.figure()
            fig.add_subplot(131)
            plt.imshow(np.squeeze(x), cmap='gray')
            fig.add_subplot(132)
            plt.imshow(np.squeeze(y_))
            fig.add_subplot(133)
            plt.imshow(np.squeeze(y))
            plt.show()

        per_lbl_dice = []
        for lbl in range(exp_config.nlabels):

            binary_pred = (y_ == lbl) * 1
            binary_gt = (y == lbl) * 1
            pred_present = np.sum(binary_pred) > 0
            gt_present = np.sum(binary_gt) > 0

            if not pred_present and not gt_present:
                per_lbl_dice.append(1)
            elif pred_present != gt_present:
                logging.warning(
                    'Structure missing in either GT (x)or prediction. ASSD and HD will not be accurate.'
                )
                per_lbl_dice.append(0)
            else:
                per_lbl_dice.append(dc(binary_pred, binary_gt))

        dice_list.append(per_lbl_dice)

    dice_arr = np.asarray(dice_list)

    mean_per_lbl_dice = dice_arr.mean(axis=0)

    logging.info('Dice')
    logging.info(mean_per_lbl_dice)
    logging.info(np.mean(mean_per_lbl_dice))
    # Label 0 is excluded from the foreground mean.
    logging.info('foreground mean: %f', np.mean(mean_per_lbl_dice[1:]))

    np.savez(os.path.join(model_path, 'dice_%s.npz' % model_selection),
             dice_arr)
Пример #13
0
def test(model_path, exp_config, model_selection='latest', num_samples=100, overwrite=False, mode=False):
    """Evaluate a PHiSeg checkpoint on the test set and pickle the metrics.

    If the output pickle already exists and ``overwrite`` is false, nothing is
    done. For every test image, ``num_samples`` softmax samples are drawn in
    one run (a single sample for the ``det_unet2D`` likelihood). The following
    per-image metrics are collected:

    - DSC and presence per expert;
    - per-sample DSC;
    - GED and diversity;
    - variance NCC;
    - entropy of the mean prediction;
    - class-wise ECE;
    - log-likelihood.

    Images are saved via ``ImageSaver``, and each metric list is converted to
    a NumPy array in the pickled dict.

    Args:
        model_path: Directory holding the checkpoints; outputs go here too.
        exp_config: Experiment configuration.
        model_selection: Checkpoint type passed to ``load_weights``.
        num_samples: Samples per image (ignored for ``det_unet2D``).
        overwrite: Recompute even if the output file exists.
        mode: If true, the prediction is the rounded mean of the sampled label
            maps; otherwise it is derived from the summed probabilities (or,
            for 'proposed' experiments, from the distribution mean).
    """
    output_path = get_output_path(model_path, num_samples, model_selection, mode) + '.pickle'
    if os.path.exists(output_path) and not overwrite:
        return
    image_saver = ImageSaver(os.path.join(model_path, 'samples'))
    # Close the saver even if evaluation fails part-way through.
    try:
        tf.reset_default_graph()
        phiseg_model = phiseg(exp_config=exp_config)
        phiseg_model.load_weights(model_path, type=model_selection)

        data_loader = data_switch(exp_config.data_identifier)
        data = data_loader(exp_config)

        metrics = {key: [] for key in
                   ['dsc', 'presence', 'ged', 'ncc', 'entropy', 'diversity', 'sample_dsc', 'ece', 'unweighted_ece',
                    'loglikelihood']}

        num_samples = 1 if exp_config.likelihood is likelihoods.det_unet2D else num_samples
        nlabels = exp_config.nlabels

        for ii in tqdm(range(data.test.images.shape[0])):
            image = data.test.images[ii, ...].reshape([1] + list(exp_config.image_size))
            targets = data.test.labels[ii, ...].transpose((2, 0, 1))

            feed_dict = {phiseg_model.training_pl: False,
                         phiseg_model.x_inp: np.tile(image, [num_samples, 1, 1, 1])}

            prob_maps = phiseg_model.sess.run(phiseg_model.s_out_eval_sm, feed_dict=feed_dict)
            samples = np.argmax(prob_maps, axis=-1)
            mean_prob = np.mean(prob_maps, axis=0)
            # Epsilon avoids log(0) in the entropy.
            probability = mean_prob + 1e-10
            metrics['entropy'].append(float(np.sum(-probability * np.log(probability))))
            if mode:
                prediction = np.round(np.mean(samples, axis=0)).astype(np.int64)
            elif 'proposed' not in exp_config.experiment_name:
                prediction = np.argmax(np.sum(prob_maps, axis=0), axis=-1)
            else:
                mean = phiseg_model.sess.run(phiseg_model.dist_eval.loc, feed_dict=feed_dict)[0]
                # NOTE(review): the trailing 2 assumes a two-class problem.
                mean = np.reshape(mean, image.shape[:-1] + (2,))
                prediction = np.argmax(mean, axis=-1)

            metrics['loglikelihood'].append(calculate_log_likelihood(targets, prob_maps))
            # calculate DSC per expert
            metrics['dsc'].append(
                [[calc_dsc(target == i, prediction == i) for i in range(nlabels)] for target in targets])
            metrics['presence'].append([[np.any(target == i) for i in range(nlabels)] for target in targets])

            metrics['sample_dsc'].append([[[calc_dsc(target == i, sample == i) for i in range(nlabels)]
                                           for target in targets] for sample in samples])

            # ged and diversity (foreground labels 1..nlabels-1)
            ged_, diversity_ = utils.generalised_energy_distance(samples, targets, nlabels - 1,
                                                                 range(1, nlabels))
            metrics['ged'].append(ged_)
            metrics['diversity'].append(diversity_)
            # NCC
            targets_one_hot = utils.to_one_hot(targets, nlabels)
            metrics['ncc'].append(utils.variance_ncc_dist(prob_maps, targets_one_hot)[0])
            # NOTE(review): 2 classes / 10 bins are hard-coded here.
            ece, unweighted_ece = calc_class_wise_expected_calibration_error(targets, mean_prob, 2, 10)
            metrics['ece'].append(ece)
            metrics['unweighted_ece'].append(unweighted_ece)
            image_saver(str(ii) + '/', image[0, ..., 0], targets, prediction, samples)

        metrics = {key: np.array(metric) for key, metric in metrics.items()}
        with open(output_path, 'wb') as f:
            pickle.dump(metrics, f)
    finally:
        image_saver.close()
Пример #14
0
def main(model_path, exp_config):
    """Show (and optionally record) the expert annotations of one test image.

    Cycles through the annotators in ``exp_config.annotator_range`` for 50
    frames. Each frame shows the input image next to one expert's
    segmentation, with contours and an "Expert # k/K" caption. Frames are
    displayed in an OpenCV window for 200 ms each (``q`` stops early). They
    are also written to an XVID ``.avi`` when the module-level ``SAVE_VIDEO``
    flag is set.

    Args:
        model_path: Directory where the video is written.
        exp_config: Experiment configuration providing ``data_identifier``,
            ``image_size``, ``nlabels`` and ``annotator_range``.
    """
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    def find_contours(mask):
        # cv2.findContours returns (image, contours, hierarchy) in OpenCV 3.x
        # and (contours, hierarchy) in 4.x; [-2] is the contour list in both.
        return cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]

    # RANDOM IMAGE
    # x_b, s_b = data.test.next_batch(1)

    # FIXED IMAGE
    # Cardiac: 100 normal image
    # LIDC: 200 large lesion, 203, 1757 complicated lesion
    # Prostate: 165 nice slice
    index = 165

    x_b = data.test.images[index, ...].reshape([1] + list(exp_config.image_size))
    s_b = data.test.labels[index, ...]

    annot_index_gen = cycle(exp_config.annotator_range)

    x_b_d = utils.convert_to_uint8(np.squeeze(x_b))
    x_b_d = utils.resize_image(x_b_d, video_target_size)

    if SAVE_VIDEO:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        outfile = os.path.join(model_path, 'gt_samples_id%d.avi' % index)
        # Frame size is (width, height): image and annotation side by side.
        out = cv2.VideoWriter(outfile, fourcc, 5.0,
                              (2 * video_target_size[1], video_target_size[0]))

    for _ in range(50):

        annot_index = next(annot_index_gen)
        s_b_d = s_b[..., annot_index]

        # Map label ids to grey levels label / nlabels * 255.
        s_b_d = np.squeeze(np.uint8((s_b_d / exp_config.nlabels) * 255))
        s_b_d = utils.resize_image(s_b_d,
                                   video_target_size,
                                   interp=cv2.INTER_NEAREST)

        img = np.concatenate([x_b_d, s_b_d], axis=1)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        img = histogram_equalization(img)

        # NOTE(review): grey levels 85/170 correspond to labels 1/2 only when
        # nlabels == 3 (see the mapping above); confirm per dataset.
        if exp_config.data_identifier == 'acdc':
            rv_cnt = find_contours(cv2.inRange(s_b_d, 84, 86))
            my_cnt = find_contours(cv2.inRange(s_b_d, 169, 171))
            cv2.drawContours(img, rv_cnt, -1, (0, 255, 0), 1)
            cv2.drawContours(img, my_cnt, -1, (0, 0, 255), 1)
        elif exp_config.data_identifier == 'uzh_prostate':
            print(np.unique(s_b_d))
            s1_cnt = find_contours(cv2.inRange(s_b_d, 84, 86))
            s2_cnt = find_contours(cv2.inRange(s_b_d, 169, 171))
            cv2.drawContours(img, s1_cnt, -1, (0, 255, 0), 1)
            cv2.drawContours(img, s2_cnt, -1, (0, 0, 255), 1)
        elif exp_config.data_identifier == 'lidc':
            lesion = find_contours(cv2.inRange(s_b_d, 127, 255))
            cv2.drawContours(img, lesion, -1, (0, 255, 0), 1)

        # Caption baseline 30 px above the bottom edge of the actual frame
        # (previously hard-coded for a 256-pixel-high frame).
        font = cv2.FONT_HERSHEY_SIMPLEX
        cv2.putText(
            img, 'Expert # %d/%d' %
            (annot_index + 1, len(exp_config.annotator_range)), (30, img.shape[0] - 30),
            font, 1, (255, 255, 255), 1, cv2.LINE_AA)

        if SAVE_VIDEO:
            out.write(img)

        cv2.imshow('frame', img)
        if cv2.waitKey(200) & 0xFF == ord('q'):
            break

    if SAVE_VIDEO:
        out.release()
    cv2.destroyAllWindows()
def _centre_axis_slices(size, target):
    """Return ``(src, dst)`` slices mapping an axis of length ``size`` onto ``target``.

    When ``size > target`` the source axis is centre-cropped
    (``src = [start:start + target]``) and the destination is used in full.
    Otherwise the whole source is written into the centre of the destination
    (``dst = [offset:offset + size]``). Applying the same pair with the roles
    of source and destination swapped undoes the mapping.
    """
    if size > target:
        start = (size - target) // 2
        return slice(start, start + target), slice(None)
    offset = (target - size) // 2
    return slice(None), slice(offset, offset + size)


def _crop_or_pad_slice(slice_2d, nx, ny):
    """Centre-crop and/or zero-pad the first two axes of ``slice_2d`` to ``(nx, ny)``."""
    x, y = slice_2d.shape[:2]
    src_x, dst_x = _centre_axis_slices(x, nx)
    src_y, dst_y = _centre_axis_slices(y, ny)
    cropped = np.zeros((nx, ny) + slice_2d.shape[2:])
    cropped[dst_x, dst_y] = slice_2d[src_x, src_y]
    return cropped


def _paste_back(cropped, x, y, dtype=np.float64):
    """Invert `_crop_or_pad_slice`: place ``cropped`` back on an ``(x, y, ...)`` canvas.

    Padding added by the crop step is dropped, regions that were cropped away
    are filled with zeros, and any trailing axes of ``cropped`` are kept.
    """
    nx, ny = cropped.shape[:2]
    src_x, dst_x = _centre_axis_slices(x, nx)
    src_y, dst_y = _centre_axis_slices(y, ny)
    restored = np.zeros((x, y) + cropped.shape[2:], dtype=dtype)
    restored[src_x, src_y] = cropped[dst_x, dst_y]
    return restored


def _read_info_cfg(folder_path):
    """Parse ``<folder_path>/Info.cfg`` (``key: value`` lines) into a dict of strings."""
    infos = {}
    with open(os.path.join(folder_path, 'Info.cfg')) as cfg:
        for line in cfg:
            if not line.strip():
                continue
            # Split only on the first ':' so values may contain colons.
            label, value = line.split(':', 1)
            infos[label] = value.rstrip('\n').lstrip(' ')
    return infos


def _per_label_metrics(prediction, ground_truth, labels, voxelspacing):
    """Return per-label ``(dice, assd, hd)`` lists for the label ids in ``labels``.

    A label absent from both volumes scores Dice 1, ASSD 0, HD 0. A label
    absent from exactly one of them scores Dice 0, ASSD 1, HD 1; those
    distances are placeholders, not measurements.
    """
    dice_vals, assd_vals, hd_vals = [], [], []
    for lbl in labels:
        binary_pred = (prediction == lbl) * 1
        binary_gt = (ground_truth == lbl) * 1
        pred_empty = np.sum(binary_pred) == 0
        gt_empty = np.sum(binary_gt) == 0

        if pred_empty and gt_empty:
            dice_vals.append(1)
            assd_vals.append(0)
            hd_vals.append(0)
        elif pred_empty or gt_empty:
            logging.warning(
                'Structure missing in either GT (x)or prediction. ASSD and HD will not be accurate.'
            )
            dice_vals.append(0)
            assd_vals.append(1)
            hd_vals.append(1)
        else:
            dice_vals.append(dc(binary_pred, binary_gt))
            assd_vals.append(
                assd(binary_pred, binary_gt, voxelspacing=voxelspacing))
            hd_vals.append(
                hd(binary_pred, binary_gt, voxelspacing=voxelspacing))
    return dice_vals, assd_vals, hd_vals


def main(input_folder,
         output_folder,
         model_path,
         exp_config,
         do_postprocessing=False,
         gt_exists=True):
    """Segment the evaluation volumes in ``input_folder`` and report metrics.

    Only ``patientXXX`` folders whose numeric id is divisible by 5 are
    processed. For every ``patientXXX_frameYY.nii.gz`` in such a folder:

    1. The image is resampled to ``exp_config.target_resolution``.
    2. It is centre-cropped/padded to ``exp_config.image_size``.
    3. It is segmented with the restored model and mapped back to the
       original grid.

    The prediction and the image are written to
    ``<output_folder>/prediction`` and ``<output_folder>/image``. When
    ``gt_exists``, the ground truth and a difference mask are also written
    to ``ground_truth`` / ``difference``, and Dice, ASSD and HD are
    computed. These subfolders are assumed to exist already.

    Args:
        input_folder: Directory holding one ``patientXXX`` folder per subject,
            each with an ``Info.cfg`` providing ``ED`` and ``ES`` frames.
        output_folder: Root directory for the saved NIfTI files.
        model_path: Checkpoint directory passed to ``load_weights``.
        exp_config: Experiment configuration.
        do_postprocessing: Keep only the largest connected components.
        gt_exists: Whether ``*_gt.nii.gz`` masks are available. Metrics are
            only computed when they are.

    Raises:
        ValueError: for an unknown ``exp_config.dimensionality_mode``, a frame
            that is neither ED nor ES, or a 3D volume with more slices than
            the network input after rescaling.
    """

    # Get Data
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # Make and restore the segmentation model
    segmenter_model = segmenter(
        exp_config=exp_config, data=data,
        fixed_batch_size=1)  # CRF model requires fixed batch size
    segmenter_model.load_weights(model_path, type='best_dice')

    total_time = 0
    total_volumes = 0

    dice_list = []
    assd_list = []
    hd_list = []

    for folder in os.listdir(input_folder):

        folder_path = os.path.join(input_folder, folder)
        if not os.path.isdir(folder_path):
            continue

        # Remove the literal 'patient' prefix (str.lstrip would strip a
        # character set rather than a prefix).
        if folder.startswith('patient'):
            patient_id = folder[len('patient'):]
        else:
            patient_id = folder

        # Only every fifth patient belongs to the evaluation split.
        if not int(patient_id) % 5 == 0:
            continue

        infos = _read_info_cfg(folder_path)
        ED_frame = int(infos['ED'])
        ES_frame = int(infos['ES'])

        for file_path in glob.glob(
                os.path.join(folder_path, 'patient???_frame??.nii.gz')):

            logging.info(' ----- Doing image: -------------------------')
            logging.info('Doing: %s' % file_path)
            logging.info(' --------------------------------------------')

            file_base = file_path.split('.nii.gz')[0]
            frame = int(file_base.split('frame')[-1])

            img, img_affine, img_header = utils.load_nii(file_path)
            img = utils.normalise_image(img)
            zooms = img_header.get_zooms()

            if gt_exists:
                file_mask = file_base + '_gt.nii.gz'
                mask, mask_affine, mask_header = utils.load_nii(file_mask)

            # Timing covers preprocessing, prediction and resampling in both modes.
            start_time = time.time()
            pixdim = img_header.structarr['pixdim']

            if exp_config.dimensionality_mode == '2D':

                nx, ny = exp_config.image_size
                scale_vector = (pixdim[1] / exp_config.target_resolution[0],
                                pixdim[2] / exp_config.target_resolution[1])

                predictions = []
                for zz in range(img.shape[2]):

                    slice_img = np.squeeze(img[:, :, zz])
                    slice_rescaled = transform.rescale(slice_img,
                                                       scale_vector,
                                                       order=1,
                                                       preserve_range=True,
                                                       multichannel=False,
                                                       mode='constant')
                    x, y = slice_rescaled.shape

                    # Crop/pad to the network input size and predict.
                    slice_cropped = _crop_or_pad_slice(slice_rescaled, nx, ny)
                    network_input = np.float32(
                        np.reshape(slice_cropped, (1, nx, ny, 1)))
                    _, softmax = segmenter_model.predict(network_input)
                    prediction_cropped = np.squeeze(softmax[0, ...])

                    # Undo the crop/pad on the per-class probabilities.
                    slice_predictions = _paste_back(prediction_cropped, x, y)

                    # Resample the probabilities (not the labels) back to the
                    # original grid, then take the argmax.
                    if gt_exists:
                        # The GT grid gives the exact target size.
                        prediction = transform.resize(
                            slice_predictions,
                            (mask.shape[0], mask.shape[1], exp_config.nlabels),
                            order=1,
                            preserve_range=True,
                            mode='constant')
                    else:
                        # Inverting the scale factor can occasionally be off
                        # by one voxel, hence the GT grid is preferred above.
                        prediction = transform.rescale(
                            slice_predictions, (1.0 / scale_vector[0],
                                                1.0 / scale_vector[1], 1),
                            order=1,
                            preserve_range=True,
                            multichannel=False,
                            mode='constant')

                    predictions.append(np.uint8(np.argmax(prediction, axis=-1)))

                prediction_arr = np.transpose(
                    np.asarray(predictions, dtype=np.uint8), (1, 2, 0))

            elif exp_config.dimensionality_mode == '3D':

                nx, ny, nz_max = exp_config.image_size
                scale_vector = (pixdim[1] / exp_config.target_resolution[0],
                                pixdim[2] / exp_config.target_resolution[1],
                                pixdim[3] / exp_config.target_resolution[2])

                vol_scaled = transform.rescale(img,
                                               scale_vector,
                                               order=1,
                                               preserve_range=True,
                                               multichannel=False,
                                               mode='constant')
                x, y, nz_curr = vol_scaled.shape

                if nz_curr > nz_max:
                    raise ValueError(
                        'Volume has %d slices after rescaling but the network '
                        'accepts at most %d' % (nz_curr, nz_max))

                # Centre the rescaled slices along z in a zero-filled input.
                stack_from = (nz_max - nz_curr) // 2
                stack_to = stack_from + nz_curr
                slice_vol = np.zeros((nx, ny, nz_max), dtype=np.float32)
                for zz in range(nz_curr):
                    slice_vol[:, :, stack_from + zz] = _crop_or_pad_slice(
                        vol_scaled[:, :, zz], nx, ny)

                network_input = np.float32(
                    np.reshape(slice_vol, (1, nx, ny, nz_max, 1)))
                predict_start = time.time()
                mask_out, _ = segmenter_model.predict(network_input)
                logging.info('Classified 3D: %f secs' %
                             (time.time() - predict_start))

                prediction_nzs = mask_out[0, :, :, stack_from:stack_to]  # non-zero-slices
                if not prediction_nzs.shape[2] == nz_curr:
                    raise ValueError('sizes mismatch')

                prediction_scaled = _paste_back(prediction_nzs, x, y,
                                                dtype=np.uint8)
                logging.info('Prediction_scaled mean %f' %
                             (np.mean(prediction_scaled)))

                # mask_out already holds integer labels, so resample with
                # nearest neighbour (no smoothing) straight to the 3D target
                # grid instead of interpolating and taking an argmax.
                target_shape = mask.shape if gt_exists else img.shape
                prediction = transform.resize(prediction_scaled,
                                              target_shape,
                                              order=0,
                                              preserve_range=True,
                                              anti_aliasing=False,
                                              mode='constant')
                prediction_arr = np.asarray(np.rint(prediction), dtype=np.uint8)

            else:
                raise ValueError('Unknown dimensionality_mode: %s' %
                                 exp_config.dimensionality_mode)

            # This is the same for 2D and 3D again
            if do_postprocessing:
                prediction_arr = utils.keep_largest_connected_components(
                    prediction_arr)

            elapsed_time = time.time() - start_time
            total_time += elapsed_time
            total_volumes += 1

            logging.info('Evaluation of volume took %f secs.' % elapsed_time)

            if frame == ED_frame:
                frame_suffix = '_ED'
            elif frame == ES_frame:
                frame_suffix = '_ES'
            else:
                raise ValueError(
                    'Frame doesnt correspond to ED or ES. frame = %d, ED = %d, ES = %d'
                    % (frame, ED_frame, ES_frame))

            out_name = 'patient' + patient_id + frame_suffix + '.nii.gz'
            if gt_exists:
                out_affine, out_header = mask_affine, mask_header
            else:
                out_affine, out_header = img_affine, img_header

            # Save predicted mask
            out_file_name = os.path.join(output_folder, 'prediction', out_name)
            logging.info('saving to: %s' % out_file_name)
            utils.save_nii(out_file_name, prediction_arr, out_affine, out_header)

            # Save image data to the same folder for convenience
            image_file_name = os.path.join(output_folder, 'image', out_name)
            logging.info('saving to: %s' % image_file_name)
            utils.save_nii(image_file_name, img, out_affine, out_header)

            if not gt_exists:
                # Everything below needs a reference segmentation.
                continue

            # Save GT image
            gt_file_name = os.path.join(output_folder, 'ground_truth', out_name)
            logging.info('saving to: %s' % gt_file_name)
            utils.save_nii(gt_file_name, mask, out_affine, out_header)

            # Save difference mask between predictions and ground truth
            difference_mask = np.asarray(prediction_arr != mask, dtype=np.uint8)
            diff_file_name = os.path.join(output_folder, 'difference', out_name)
            logging.info('saving to: %s' % diff_file_name)
            utils.save_nii(diff_file_name, difference_mask, out_affine,
                           out_header)

            # Metrics for the hard-coded label ids 3, 1, 2 (in that order).
            per_lbl_dice, per_lbl_assd, per_lbl_hd = _per_label_metrics(
                prediction_arr, mask, [3, 1, 2], zooms)

            dice_list.append(per_lbl_dice)
            assd_list.append(per_lbl_assd)
            hd_list.append(per_lbl_hd)

    if total_volumes == 0:
        logging.warning('No volumes were evaluated in %s' % input_folder)
        return

    logging.info('Average time per volume: %f' % (total_time / total_volumes))

    if not dice_list:
        # gt_exists was False: there are no metrics to summarise.
        return

    dice_arr = np.asarray(dice_list)
    assd_arr = np.asarray(assd_list)
    hd_arr = np.asarray(hd_list)

    mean_per_lbl_dice = dice_arr.mean(axis=0)
    mean_per_lbl_assd = assd_arr.mean(axis=0)
    mean_per_lbl_hd = hd_arr.mean(axis=0)

    logging.info('Dice')
    logging.info(mean_per_lbl_dice)
    logging.info(np.mean(mean_per_lbl_dice))
    logging.info('ASSD')
    logging.info(mean_per_lbl_assd)
    logging.info(np.mean(mean_per_lbl_assd))
    logging.info('HD')
    logging.info(mean_per_lbl_hd)
    logging.info(np.mean(mean_per_lbl_hd))
def main(model_path, exp_config, do_plots=False):
    """Evaluate a trained segmenter on the test split and report per-label Dice.

    The test set is iterated once with batch size 1. For every label id in
    ``range(exp_config.nlabels)`` the Dice score is computed per example:

    * A label absent from both prediction and ground truth scores 1.
    * A label absent from exactly one of them scores 0.

    The per-label means, their overall mean and the foreground mean (labels
    1 and up) are logged. The raw per-example scores are saved to
    ``<model_path>/dice.npz``.

    Args:
        model_path: Checkpoint directory; also receives ``dice.npz``.
        exp_config: Experiment configuration.
        do_plots: Show input / prediction / ground truth for each example
            (skipped when ``sys_config.running_on_gpu_host`` is set).
    """

    # Get Data
    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # Make and restore the segmentation model
    segmenter_model = segmenter(
        exp_config=exp_config, data=data,
        fixed_batch_size=1)  # CRF model requires fixed batch size
    segmenter_model.load_weights(model_path, type='best_dice')

    dice_list = []

    # One pass over the test set.
    for ii, batch in enumerate(data.test.iterate_batches(1)):

        if ii % 100 == 0:
            logging.info("Progress: %d" % ii)

        x, y = batch

        # Robustness experiments (disabled): uncomment to corrupt the input.
        # x = utils.add_motion_artefacts(np.squeeze(x), 15)
        # x = x.reshape([1] + list(exp_config.image_size) + [1])
        # x[:, 192 // 2 - 20:192 // 2 + 20, 192 // 2 - 5:192 // 2 + 5, :] = 0

        y_ = segmenter_model.predict(x)[0]

        if do_plots and not sys_config.running_on_gpu_host:
            fig = plt.figure()
            fig.add_subplot(131)
            plt.imshow(np.squeeze(x), cmap='gray')
            fig.add_subplot(132)
            plt.imshow(np.squeeze(y_))
            fig.add_subplot(133)
            plt.imshow(np.squeeze(y))
            plt.show()

        per_lbl_dice = []
        for lbl in range(exp_config.nlabels):

            binary_pred = (y_ == lbl) * 1
            binary_gt = (y == lbl) * 1
            pred_empty = np.sum(binary_pred) == 0
            gt_empty = np.sum(binary_gt) == 0

            if pred_empty and gt_empty:
                per_lbl_dice.append(1)
            elif pred_empty or gt_empty:
                logging.warning(
                    'Structure missing in either GT (x)or prediction. ASSD and HD will not be accurate.'
                )
                per_lbl_dice.append(0)
            else:
                per_lbl_dice.append(dc(binary_pred, binary_gt))

        dice_list.append(per_lbl_dice)

    if not dice_list:
        logging.warning('Test set is empty; no Dice scores were computed.')
        return

    dice_arr = np.asarray(dice_list)
    mean_per_lbl_dice = dice_arr.mean(axis=0)

    logging.info('Dice')
    logging.info(mean_per_lbl_dice)
    logging.info(np.mean(mean_per_lbl_dice))
    logging.info('foreground mean: %f' % (np.mean(mean_per_lbl_dice[1:])))
    np.savez(os.path.join(model_path, 'dice.npz'), dice_arr)
Пример #17
0
def main(model_path, exp_config):
    """Display (and optionally record) repeated segmentation samples for one test image.

    Draws 50 samples from the restored ``segvae`` model (``best_ged``
    weights) for a fixed test image. For each sample, the label contours are
    drawn on the input/segmentation pair and the frame is shown in an
    OpenCV window; press ``q`` to stop early. When the module-level
    ``SAVE_VIDEO`` flag is set, the frames are also written to
    ``<model_path>/model_samples_id<index>.avi``. The writer and the windows
    are released even if an error occurs.
    """

    # Make and restore the model
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type='best_ged')

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # FIXED IMAGE (a random one could be drawn with data.test.next_batch(1))
    # Cardiac: 100 normal image
    # LIDC: 200 large lesion, 203, 1757 complicated lesion
    # Prostate: 165 nice slice
    index = 165
    n_samples = 50

    x_b = data.test.images[index,
                           ...].reshape([1] + list(exp_config.image_size))

    x_b_d = utils.convert_to_uint8(np.squeeze(x_b))
    x_b_d = utils.resize_image(x_b_d, video_target_size)

    if SAVE_VIDEO:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        outfile = os.path.join(model_path, 'model_samples_id%d.avi' % index)
        # Frame = input and segmentation side by side, hence twice the width.
        out = cv2.VideoWriter(outfile, fourcc, 5.0,
                              (2 * video_target_size[1], video_target_size[0]))

    # The input never changes, so the feed dict is built once.
    feed_dict = {
        segvae_model.training_pl: False,
        segvae_model.x_inp: x_b,
    }

    try:
        # Display every sample, whether or not it is also recorded.
        for _ in range(n_samples):

            s_p = segvae_model.sess.run(segvae_model.s_out_eval,
                                        feed_dict=feed_dict)
            s_p = np.argmax(s_p, axis=-1)

            # Grey level = label / nlabels * 255.
            # NOTE(review): the 84-86 / 169-171 thresholds below match labels
            # 1 and 2 only when nlabels == 3; with 4 labels the levels are
            # 63/127/191. Confirm for 'acdc'.
            s_p_d = np.squeeze(np.uint8((s_p / exp_config.nlabels) * 255))
            s_p_d = utils.resize_image(s_p_d,
                                       video_target_size,
                                       interp=cv2.INTER_NEAREST)

            img = np.concatenate([x_b_d, s_p_d], axis=1)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
            img = histogram_equalization(img)

            if exp_config.data_identifier == 'acdc':
                rv = cv2.inRange(s_p_d, 84, 86)
                my = cv2.inRange(s_p_d, 169, 171)
                rv_cnt, hierarchy = cv2.findContours(rv, cv2.RETR_TREE,
                                                     cv2.CHAIN_APPROX_SIMPLE)
                my_cnt, hierarchy = cv2.findContours(my, cv2.RETR_TREE,
                                                     cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(img, rv_cnt, -1, (0, 255, 0), 1)
                cv2.drawContours(img, my_cnt, -1, (0, 0, 255), 1)
            elif exp_config.data_identifier == 'uzh_prostate':
                print(np.unique(s_p_d))
                s1 = cv2.inRange(s_p_d, 84, 86)
                s2 = cv2.inRange(s_p_d, 169, 171)
                s1_cnt, hierarchy = cv2.findContours(s1, cv2.RETR_TREE,
                                                     cv2.CHAIN_APPROX_SIMPLE)
                s2_cnt, hierarchy = cv2.findContours(s2, cv2.RETR_TREE,
                                                     cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(img, s1_cnt, -1, (0, 255, 0), 1)
                cv2.drawContours(img, s2_cnt, -1, (0, 0, 255), 1)
            elif exp_config.data_identifier == 'lidc':
                thresh = cv2.inRange(s_p_d, 127, 255)
                lesion, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE,
                                                     cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(img, lesion, -1, (0, 255, 0), 1)

            if SAVE_VIDEO:
                out.write(img)

            cv2.imshow('frame', img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        if SAVE_VIDEO:
            out.release()
        cv2.destroyAllWindows()
def main(model_path,
         exp_config,
         outfolder='/home/baumgach/Reports/ETH/MICCAI2019_segvae/raw_figures'):
    """Export per-level sample figures, a contour overlay, the input and the GT.

    For a fixed test image, three samples are drawn with
    ``predict_segmentation_sample_levels``. For each sample the following
    PNGs are saved:

    * the accumulated segmentation at every latent level
      (``segm_lvl_<l>_samp_<s>.png``);
    * the class-1 softmax of each level's output
      (``residual_lvl_<l>_samp_<s>.png``).

    The contours of grey levels ~85 and ~170 of each full segmentation are
    accumulated on the input image. Finally the overlay, the plain input and
    the ground truth are saved. Every image is upsampled 2x and a margin of
    30 original pixels is removed on each side. Figures are closed after
    saving.

    Args:
        model_path: Checkpoint directory passed to ``load_weights``.
        exp_config: Experiment configuration.
        outfolder: Output directory for the PNG files; created if missing.
    """

    # Make and restore the model
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type='best_dice')

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    os.makedirs(outfolder, exist_ok=True)

    ims = exp_config.image_size
    up_size = (2 * ims[0], 2 * ims[1])
    margin = 2 * 30  # 30 original pixels at 2x upsampling
    crop = (slice(margin, up_size[0] - margin),
            slice(margin, up_size[1] - margin))

    def save_panel(image, filename):
        """Save the centre crop of ``image`` as ``outfolder/filename`` and free the figure."""
        fig = plt.figure()
        plt.imshow(image[crop], cmap='gray')
        plt.axis('off')
        plt.savefig(os.path.join(outfolder, filename), bbox_inches='tight')
        plt.close(fig)

    # Fixed test image: heart 100, prostate 165.
    index = 165
    x_b = data.test.images[index,
                           ...].reshape([1] + list(exp_config.image_size))

    seg_shape = [1] + list(exp_config.image_size[0:2])
    labels = data.test.labels[index, ...]
    if exp_config.data_identifier == 'lidc':
        # First of annotators 0-2 with a non-empty mask, otherwise annotator 3.
        annotator = next(
            (aa for aa in range(3) if np.sum(labels[..., aa]) > 0), 3)
        s_b = labels[..., annotator].reshape(seg_shape)
    elif exp_config.data_identifier == 'uzh_prostate':
        s_b = labels[..., 0].reshape(seg_shape)
    else:
        s_b = labels.reshape(seg_shape)

    x_b_for_cnt = utils.convert_to_uint8(np.squeeze(x_b.copy()))
    x_b_for_cnt = cv2.cvtColor(x_b_for_cnt, cv2.COLOR_GRAY2BGR)
    x_b_for_cnt = utils.resize_image(x_b_for_cnt, up_size,
                                     interp=cv2.INTER_NEAREST)
    x_b_for_cnt = utils.histogram_equalization(x_b_for_cnt)

    n_levels = exp_config.latent_levels

    for ss in range(3):

        print(ss)

        s_p_list = segvae_model.predict_segmentation_sample_levels(
            x_b, return_softmax=False)

        # accum_list[lvl] = sum of the level outputs from lvl to the last level.
        accum_list = [None] * n_levels
        accum_list[n_levels - 1] = s_p_list[-1]
        for lvl in reversed(range(n_levels - 1)):
            accum_list[lvl] = accum_list[lvl + 1] + s_p_list[lvl]

        print('Plotting accum_list')
        for ii, logits in enumerate(accum_list):
            seg = utils.resize_image(np.squeeze(np.argmax(logits, axis=-1)),
                                     up_size,
                                     interp=cv2.INTER_NEAREST)
            save_panel(seg, 'segm_lvl_%d_samp_%d.png' % (ii, ss))

        print('Plotting s_p_list')
        for ii, logits in enumerate(s_p_list):
            probs = utils.softmax(logits)
            fg_prob = utils.resize_image(np.squeeze(probs[..., 1]),
                                         up_size,
                                         interp=cv2.INTER_NEAREST)
            save_panel(fg_prob, 'residual_lvl_%d_samp_%d.png' % (ii, ss))

        # Grey level = label / (nlabels - 1) * 255, so labels 1/2 map to 85/170
        # when nlabels == 4.
        s_p_d = np.uint8((np.squeeze(np.argmax(accum_list[0], axis=-1)) /
                          (exp_config.nlabels - 1)) * 255)
        s_p_d = utils.resize_image(s_p_d, up_size, interp=cv2.INTER_NEAREST)

        print('Calculating contours')
        print(np.unique(s_p_d))
        rv = cv2.inRange(s_p_d, 84, 86)
        my = cv2.inRange(s_p_d, 169, 171)
        rv_cnt, hierarchy = cv2.findContours(rv, cv2.RETR_TREE,
                                             cv2.CHAIN_APPROX_SIMPLE)
        my_cnt, hierarchy = cv2.findContours(my, cv2.RETR_TREE,
                                             cv2.CHAIN_APPROX_SIMPLE)

        # Contours of all samples are accumulated on the same image.
        x_b_for_cnt = cv2.drawContours(x_b_for_cnt, rv_cnt, -1, (0, 255, 0), 1)
        x_b_for_cnt = cv2.drawContours(x_b_for_cnt, my_cnt, -1, (0, 0, 255), 1)

    # OpenCV works in BGR, matplotlib expects RGB.
    x_b_for_cnt = cv2.cvtColor(x_b_for_cnt, cv2.COLOR_BGR2RGB)

    print('Plotting final images...')
    save_panel(x_b_for_cnt, 'input_img_cnts.png')

    x_b_disp = utils.convert_to_uint8(x_b)
    x_b_disp = cv2.cvtColor(np.squeeze(x_b_disp), cv2.COLOR_GRAY2BGR)
    x_b_disp = utils.histogram_equalization(x_b_disp)
    x_b_disp = utils.resize_image(x_b_disp, up_size, interp=cv2.INTER_NEAREST)
    save_panel(x_b_disp, 'input_img.png')

    s_b_disp = utils.resize_image(np.squeeze(s_b), up_size,
                                  interp=cv2.INTER_NEAREST)
    save_panel(s_b_disp, 'gt_seg.png')
def _select_gt_slice(labels, index, data_identifier, image_size):
    """Return one ground-truth annotation of test item `index` as a (1, H, W) array.

    For 'lidc' the first of annotation channels 0-2 that contains any
    foreground is used, falling back to channel 3. For 'uzh_prostate'
    channel 0 is used. Any other dataset's label is reshaped as-is.
    """
    s_b = labels[index, ...]
    if data_identifier == 'lidc':
        # Prefer an annotation that actually marks something; channel 3 is
        # the fallback when channels 0-2 are all empty.
        channel = next((k for k in range(3) if np.sum(s_b[..., k]) > 0), 3)
        s_b = s_b[..., channel]
    elif data_identifier == 'uzh_prostate':
        s_b = s_b[..., 0]
    return s_b.reshape([1] + list(image_size[0:2]))


def _contour_masks(label_map, data_identifier):
    """Return (mask, BGR colour) pairs whose outlines are drawn for a label map.

    Masks are computed from integer label indices rather than from the
    grey-scaled display image, so they do not depend on how labels are
    mapped to grey values. Each mask is uint8 with values 0/255, as
    expected by cv2.findContours. Unknown datasets get no contours.
    """
    if data_identifier in ('acdc', 'uzh_prostate'):
        # Label 1 in green, label 2 in red (RV / myocardium for ACDC).
        return [((label_map == 1).astype(np.uint8) * 255, (0, 255, 0)),
                ((label_map == 2).astype(np.uint8) * 255, (0, 0, 255))]
    if data_identifier == 'lidc':
        # Any foreground label is outlined in green.
        return [((label_map >= 1).astype(np.uint8) * 255, (0, 255, 0))]
    return []


def main(model_path, exp_config, index=165):
    """Visualise per-level prior samples of a segvae model for one test image.

    Levels are visited from the highest index down to 0. While visiting
    level `lvl`, every latent level with index greater than `lvl` is fed
    `mu_list_init` (from `generate_prior_samples`); the other levels are
    left to the model graph. Each frame shows input, ground truth and the
    sampled segmentation side by side, with contours of the sample drawn on
    the input panel. Frames are shown in an OpenCV window and, when the
    module-level SAVE_VIDEO flag is set, written to
    ``<model_path>/samplevid_id<index>.avi``. Press 'q' to stop early.

    Args:
        model_path: directory passed to `load_weights`; the video goes here too.
        exp_config: experiment configuration (data_identifier, image_size,
            nlabels, latent_levels, ...).
        index: index of the test image to visualise. The original notes list
            100 (cardiac, normal image), 200/203/1757 (LIDC lesions) and
            165 (prostate, default).
    """
    # Build and restore the segvae model.
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type='best_ged')

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    lat_lvls = exp_config.latent_levels
    identifier = exp_config.data_identifier

    x_b = data.test.images[index, ...].reshape([1] + list(exp_config.image_size))
    s_b = _select_gt_slice(data.test.labels, index, identifier,
                           exp_config.image_size)

    x_b_d = utils.convert_to_uint8(np.squeeze(x_b))
    x_b_d = utils.resize_image(x_b_d, video_target_size)

    s_b_d = np.squeeze(np.uint8((s_b / exp_config.nlabels) * 255))
    s_b_d = utils.resize_image(s_b_d,
                               video_target_size,
                               interp=cv2.INTER_NEAREST)

    _, mu_list_init, _ = segvae_model.generate_prior_samples(
        x_b, return_params=True)

    out = None
    if SAVE_VIDEO:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        outfile = os.path.join(model_path, 'samplevid_id%d.avi' % index)
        # Frame size is (width, height): three panels side by side.
        out = cv2.VideoWriter(outfile, fourcc, 10.0,
                              (3 * video_target_size[1], video_target_size[0]))

    samps = 50 if lat_lvls > 1 else 200
    font = cv2.FONT_HERSHEY_SIMPLEX

    # try/finally so the video file is finalised and windows are closed even
    # if sampling or drawing raises.
    try:
        quit_requested = False
        for lvl in reversed(range(lat_lvls)):
            print('doing level %d/%d' % (lvl + 1, lat_lvls))

            # Fix every level above `lvl`; the feed is identical for every
            # sample of this level, so build it once.
            feed_dict = {segvae_model.prior_z_list_gen[jj]: mu_list_init[jj]
                         for jj in range(lvl + 1, lat_lvls)}
            feed_dict[segvae_model.training_pl] = False
            feed_dict[segvae_model.x_inp] = x_b

            for _ in range(samps):
                s_p = segvae_model.sess.run(segvae_model.s_out_eval,
                                            feed_dict=feed_dict)
                s_p = np.argmax(s_p, axis=-1)

                # Integer label map at display resolution, used for contours.
                s_p_lbl = utils.resize_image(
                    np.squeeze(s_p).astype(np.uint8),
                    video_target_size,
                    interp=cv2.INTER_NEAREST)

                # Grey-scale rendering of the sample for the third panel.
                s_p_d = np.squeeze(np.uint8((s_p / exp_config.nlabels) * 255))
                s_p_d = utils.resize_image(s_p_d,
                                           video_target_size,
                                           interp=cv2.INTER_NEAREST)

                img = np.concatenate([x_b_d, s_b_d, s_p_d], axis=1)
                img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                img = utils.histogram_equalization(img)

                # Contours are found on a single-panel mask, so they land on
                # the left (input image) panel.
                for mask, colour in _contour_masks(s_p_lbl, identifier):
                    # [-2] selects the contour list on both OpenCV 3 (3-tuple)
                    # and OpenCV 4 (2-tuple) return conventions.
                    contours = cv2.findContours(mask, cv2.RETR_TREE,
                                                cv2.CHAIN_APPROX_SIMPLE)[-2]
                    cv2.drawContours(img, contours, -1, colour, 1)

                cv2.putText(img, 'Sampling level %d/%d' % (lvl + 1, lat_lvls),
                            (30, img.shape[0] - 30), font, 1, (255, 255, 255),
                            1, cv2.LINE_AA)

                if out is not None:
                    out.write(img)

                cv2.imshow('frame', img)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    # Leave both loops, not just the current level.
                    quit_requested = True
                    break
            if quit_requested:
                break
    finally:
        if out is not None:
            out.release()
        cv2.destroyAllWindows()