def preproc_image(x, nlabels=None):
    """Convert an image or label map to uint8 and upsample it 2x.

    Args:
        x: array-like image; singleton dimensions are squeezed away first.
        nlabels: when truthy, ``x`` is treated as a label map and scaled by
            ``255 / nlabels``. The divisor is ``nlabels`` rather than
            ``nlabels - 1`` so the highest label renders as grey, not white.
            Otherwise ``utils.convert_to_uint8`` does the conversion.

    Returns:
        The uint8 image resized to twice its (squeezed) height and width
        using nearest-neighbour interpolation.
    """
    squeezed = np.squeeze(x)
    size = squeezed.shape[:2]

    if not nlabels:
        as_uint8 = utils.convert_to_uint8(squeezed)
    else:
        as_uint8 = np.uint8((squeezed / nlabels) * 255)

    target = (2 * size[0], 2 * size[1])
    return utils.resize_image(as_uint8, target, interp=cv2.INTER_NEAREST)
def main(model_path, exp_config):
    """Save per-level segmentation figures for one fixed test slice.

    Restores the segvae model's ``best_dice`` weights from ``model_path``,
    draws three samples for test slice ``index`` and writes PNGs into
    ``outfolder``:

    * ``segm_lvl_<l>_samp_<s>.png``: argmax of the level outputs summed from
      the last level down to level ``l``;
    * ``residual_lvl_<l>_samp_<s>.png``: softmax channel 1 of level ``l``'s
      own output;
    * ``input_img_cnts.png``: the input with the contours of all samples;
    * ``input_img.png`` and ``gt_seg.png``: the input and its label map.

    Every figure is upsampled 2x (nearest neighbour) and centre-cropped.
    """
    # Make and restore the segvae model.
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type='best_dice')

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    outfolder = '/home/baumgach/Reports/ETH/MICCAI2019_segvae/raw_figures'

    ims = exp_config.image_size
    big_size = (2 * ims[0], 2 * ims[1])
    # Centre crop applied to every saved figure.
    # NOTE(review): the bounds are hard-coded for a 192 px image (384 px after
    # upsampling) instead of being derived from ``ims``.
    crop = slice(2 * 30, 2 * 192 - 2 * 30)

    def _save_cropped(img, filename):
        """Plot the centre crop of ``img`` in grey, save it, close the figure."""
        fig = plt.figure()
        plt.imshow(img[crop, crop], cmap='gray')
        plt.axis('off')
        plt.savefig(os.path.join(outfolder, filename), bbox_inches='tight')
        # Close explicitly. This function saves dozens of figures, and pyplot
        # otherwise keeps every one of them alive (and warns past 20).
        plt.close(fig)

    # x_b, s_b = data.test.next_batch(1)

    # Fixed test slice: heart 100, prostate 165.
    index = 165  # 100 is a normal image, 15 is a very good slice
    x_b = data.test.images[index,
                           ...].reshape([1] + list(exp_config.image_size))
    if exp_config.data_identifier == 'lidc':
        # Use the first of the four annotation channels containing any
        # foreground, falling back to the last channel.
        s_b = data.test.labels[index, ...]
        if np.sum(s_b[..., 0]) > 0:
            s_b = s_b[..., 0]
        elif np.sum(s_b[..., 1]) > 0:
            s_b = s_b[..., 1]
        elif np.sum(s_b[..., 2]) > 0:
            s_b = s_b[..., 2]
        else:
            s_b = s_b[..., 3]

        s_b = s_b.reshape([1] + list(exp_config.image_size[0:2]))
    elif exp_config.data_identifier == 'uzh_prostate':
        s_b = data.test.labels[index, ...]
        s_b = s_b[..., 0]
        s_b = s_b.reshape([1] + list(exp_config.image_size[0:2]))
    else:
        s_b = data.test.labels[index,
                               ...].reshape([1] +
                                            list(exp_config.image_size[0:2]))

    # Colour copy of the input onto which contours of every sample are drawn.
    x_b_for_cnt = utils.convert_to_uint8(np.squeeze(x_b.copy()))
    x_b_for_cnt = cv2.cvtColor(x_b_for_cnt, cv2.COLOR_GRAY2BGR)
    x_b_for_cnt = utils.resize_image(x_b_for_cnt, big_size,
                                     interp=cv2.INTER_NEAREST)
    x_b_for_cnt = utils.histogram_equalization(x_b_for_cnt)

    for ss in range(3):

        print(ss)

        s_p_list = segvae_model.predict_segmentation_sample_levels(
            x_b, return_softmax=False)

        # accum_list[l] = sum of s_p_list[l:], built from the last level down.
        accum_list = [None] * exp_config.latent_levels
        accum_list[exp_config.latent_levels - 1] = s_p_list[-1]
        for lvl in reversed(range(exp_config.latent_levels - 1)):
            accum_list[lvl] = accum_list[lvl + 1] + s_p_list[lvl]

        print('Plotting accum_list')
        for ii, img in enumerate(accum_list):
            img = utils.resize_image(np.squeeze(np.argmax(img, axis=-1)),
                                     big_size,
                                     interp=cv2.INTER_NEAREST)
            _save_cropped(img, 'segm_lvl_%d_samp_%d.png' % (ii, ss))

        print('Plotting s_p_list')
        for ii, img in enumerate(s_p_list):
            img = utils.softmax(img)
            img = utils.resize_image(np.squeeze(img[..., 1]),
                                     big_size,
                                     interp=cv2.INTER_NEAREST)
            _save_cropped(img, 'residual_lvl_%d_samp_%d.png' % (ii, ss))

        # NOTE(review): this scales by (nlabels - 1), while the 85/170
        # thresholds below only match label values for nlabels == 4. With
        # nlabels == 3 no contours would be found. Confirm per dataset.
        s_p_d = np.uint8((np.squeeze(np.argmax(accum_list[0], axis=-1)) /
                          (exp_config.nlabels - 1)) * 255)
        s_p_d = utils.resize_image(s_p_d, big_size, interp=cv2.INTER_NEAREST)

        print('Calculating contours')
        print(np.unique(s_p_d))
        rv = cv2.inRange(s_p_d, 84, 86)
        my = cv2.inRange(s_p_d, 169, 171)
        rv_cnt, _ = cv2.findContours(rv, cv2.RETR_TREE,
                                     cv2.CHAIN_APPROX_SIMPLE)
        my_cnt, _ = cv2.findContours(my, cv2.RETR_TREE,
                                     cv2.CHAIN_APPROX_SIMPLE)

        # Green / red (BGR) contours accumulate over all samples.
        x_b_for_cnt = cv2.drawContours(x_b_for_cnt, rv_cnt, -1, (0, 255, 0), 1)
        x_b_for_cnt = cv2.drawContours(x_b_for_cnt, my_cnt, -1, (0, 0, 255), 1)

    # OpenCV images are BGR; matplotlib expects RGB.
    x_b_for_cnt = cv2.cvtColor(x_b_for_cnt, cv2.COLOR_BGR2RGB)

    print('Plotting final images...')
    _save_cropped(x_b_for_cnt, 'input_img_cnts.png')

    x_b = utils.convert_to_uint8(x_b)
    x_b = cv2.cvtColor(np.squeeze(x_b), cv2.COLOR_GRAY2BGR)
    x_b = utils.histogram_equalization(x_b)
    x_b = utils.resize_image(x_b, big_size, interp=cv2.INTER_NEAREST)
    _save_cropped(x_b, 'input_img.png')

    s_b = utils.resize_image(np.squeeze(s_b), big_size,
                             interp=cv2.INTER_NEAREST)
    _save_cropped(s_b, 'gt_seg.png')
Example #3
0
def main(model_path, exp_config):
    """Play (and optionally record) the ground-truth annotations of one slice.

    Cycles through the annotators in ``exp_config.annotator_range`` for 50
    frames. Each frame shows the input slice next to that annotator's label
    map, with contours overlaid and the annotator number printed. When
    ``SAVE_VIDEO`` is set, the frames are also written to
    ``gt_samples_id<index>.avi`` inside ``model_path``. Pressing ``q`` in the
    preview window stops early.
    """
    data = data_switch(exp_config.data_identifier)(exp_config)

    # Fixed test slice (a random one: data.test.next_batch(1)).
    # Cardiac: 100 normal image
    # LIDC: 200 large lesion, 203, 1757 complicated lesion
    # Prostate: 165 nice slice
    index = 165

    x_b = data.test.images[index,
                           ...].reshape([1] + list(exp_config.image_size))
    s_b = data.test.labels[index, ...]

    annotators = exp_config.annotator_range
    annotator_cycle = cycle(annotators)

    input_frame = utils.convert_to_uint8(np.squeeze(x_b))
    input_frame = utils.resize_image(input_frame, video_target_size)

    if SAVE_VIDEO:
        codec = cv2.VideoWriter_fourcc(*'XVID')
        video_path = os.path.join(model_path, 'gt_samples_id%d.avi' % index)
        # Width is doubled: input and annotation are placed side by side.
        writer = cv2.VideoWriter(
            video_path, codec, 5.0,
            (2 * video_target_size[1], video_target_size[0]))

    # (inRange bounds, BGR colour) pairs for the two labelled structures,
    # whose grey values are 85 and 170.
    structures = (((84, 86), (0, 255, 0)), ((169, 171), (0, 0, 255)))

    for _ in range(50):

        annot_index = next(annotator_cycle)
        label_frame = np.squeeze(
            np.uint8((s_b[..., annot_index] / exp_config.nlabels) * 255))
        label_frame = utils.resize_image(label_frame,
                                         video_target_size,
                                         interp=cv2.INTER_NEAREST)

        side_by_side = np.concatenate([input_frame, label_frame], axis=1)
        frame = histogram_equalization(
            cv2.cvtColor(side_by_side, cv2.COLOR_GRAY2BGR))

        dataset = exp_config.data_identifier
        if dataset in ('acdc', 'uzh_prostate'):
            if dataset == 'uzh_prostate':
                print(np.unique(label_frame))
            for (low, high), colour in structures:
                mask = cv2.inRange(label_frame, low, high)
                contours, _ = cv2.findContours(mask, cv2.RETR_TREE,
                                               cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(frame, contours, -1, colour, 1)
        elif dataset == 'lidc':
            mask = cv2.inRange(label_frame, 127, 255)
            contours, _ = cv2.findContours(mask, cv2.RETR_TREE,
                                           cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(frame, contours, -1, (0, 255, 0), 1)

        caption = 'Expert # %d/%d' % (annot_index + 1, len(annotators))
        cv2.putText(frame, caption, (30, 256 - 30), cv2.FONT_HERSHEY_SIMPLEX,
                    1, (255, 255, 255), 1, cv2.LINE_AA)

        if SAVE_VIDEO:
            writer.write(frame)

        cv2.imshow('frame', frame)
        key = cv2.waitKey(200) & 0xFF
        if key == ord('q'):
            break

    if SAVE_VIDEO:
        writer.release()
    cv2.destroyAllWindows()
def main(model_path, exp_config):
    """Draw ``samps`` PHiSeg segmentation samples for one fixed test slice.

    Restores the ``best_ged`` weights from ``model_path``. For each sample it
    shows the input slice next to the predicted label map, with contours
    overlaid. Depending on the module-level flags, frames are:

    * written to ``model_samples_id<index>.avi`` (``SAVE_VIDEO``);
    * saved as 2x-upscaled PNGs in ``model_samples_id<index>_gif``
      (``SAVE_GIF``);
    * shown in a window (``DISPLAY_VIDEO``; pressing ``q`` stops early).

    uzh_prostate slices are rotated by 270 degrees for display.
    """
    # Make and restore the PHiSeg model.
    phiseg_model = phiseg(exp_config=exp_config)
    phiseg_model.load_weights(model_path, type='best_ged')

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # RANDOM IMAGE
    # x_b, s_b = data.test.next_batch(1)

    # FIXED IMAGE
    # Cardiac: 100 normal image
    # LIDC: 200 large lesion, 203, 1757 complicated lesion
    # Prostate: 165 nice slice, 170 is a challenging and interesting slice
    index = 165

    if SAVE_GIF:
        outfolder_gif = os.path.join(model_path,
                                     'model_samples_id%d_gif' % index)
        utils.makefolder(outfolder_gif)

    x_b = data.test.images[index,
                           ...].reshape([1] + list(exp_config.image_size))

    x_b_d = utils.convert_to_uint8(np.squeeze(x_b))
    x_b_d = utils.resize_image(x_b_d, video_target_size)

    if exp_config.data_identifier == 'uzh_prostate':
        # Rotate the input; the same matrix is reused for every prediction.
        rows, cols = x_b_d.shape
        M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 270, 1)
        x_b_d = cv2.warpAffine(x_b_d, M, (cols, rows))

    if SAVE_VIDEO:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        outfile = os.path.join(model_path, 'model_samples_id%d.avi' % index)
        # Width is doubled: input and prediction are placed side by side.
        out = cv2.VideoWriter(outfile, fourcc, 5.0,
                              (2 * video_target_size[1], video_target_size[0]))

    samps = 20
    for ii in range(samps):

        feed_dict = {}
        feed_dict[phiseg_model.training_pl] = False
        feed_dict[phiseg_model.x_inp] = x_b

        s_p, _ = phiseg_model.sess.run(
            [phiseg_model.s_out_eval, phiseg_model.s_out_eval_list],
            feed_dict=feed_dict)
        s_p = np.argmax(s_p, axis=-1)

        # Map label k to grey value k / nlabels * 255.
        s_p_d = np.squeeze(np.uint8((s_p / exp_config.nlabels) * 255))
        s_p_d = utils.resize_image(s_p_d,
                                   video_target_size,
                                   interp=cv2.INTER_NEAREST)

        if exp_config.data_identifier == 'uzh_prostate':
            # Nearest-neighbour so label values are never blended into new
            # grey levels, which would break the exact inRange tests below.
            s_p_d = cv2.warpAffine(s_p_d, M, (cols, rows),
                                   flags=cv2.INTER_NEAREST)

        img = np.concatenate([x_b_d, s_p_d], axis=1)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        img = histogram_equalization(img)

        if exp_config.data_identifier == 'acdc':
            # labels (0 85 170 255)
            rv = cv2.inRange(s_p_d, 84, 86)
            my = cv2.inRange(s_p_d, 169, 171)
            rv_cnt, _ = cv2.findContours(rv, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)
            my_cnt, _ = cv2.findContours(my, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)

            cv2.drawContours(img, rv_cnt, -1, (0, 255, 0), 1)
            cv2.drawContours(img, my_cnt, -1, (0, 0, 255), 1)
        elif exp_config.data_identifier == 'uzh_prostate':

            print(np.unique(s_p_d))
            s1 = cv2.inRange(s_p_d, 84, 86)
            s2 = cv2.inRange(s_p_d, 169, 171)
            s1_cnt, _ = cv2.findContours(s1, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)
            s2_cnt, _ = cv2.findContours(s2, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)

            cv2.drawContours(img, s1_cnt, -1, (0, 255, 0), 1)
            cv2.drawContours(img, s2_cnt, -1, (0, 0, 255), 1)
        elif exp_config.data_identifier == 'lidc':
            thresh = cv2.inRange(s_p_d, 127, 255)
            lesion, _ = cv2.findContours(thresh, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(img, lesion, -1, (0, 255, 0), 1)

        if SAVE_VIDEO:
            out.write(img)

        if SAVE_GIF:
            outfile_gif = os.path.join(outfolder_gif,
                                       'frame_%s.png' % str(ii).zfill(3))
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            im = Image.fromarray(img_rgb)
            # Image.ANTIALIAS was removed in Pillow 10. LANCZOS is the same
            # filter and exists in every Pillow release since 2.7.
            im = im.resize((im.size[0] * 2, im.size[1] * 2), Image.LANCZOS)

            im.save(outfile_gif)

        if DISPLAY_VIDEO:
            cv2.imshow('frame', img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    if SAVE_VIDEO:
        out.release()
    cv2.destroyAllWindows()
Example #5
0
def main(model_path, exp_config):
    """Play ``samps`` segvae segmentation samples for one fixed test slice.

    Restores the ``best_ged`` weights from ``model_path``. For each sample it
    shows the input slice next to the predicted label map, with contours
    overlaid. When ``SAVE_VIDEO`` is set, the frames are also written to
    ``model_samples_id<index>.avi`` inside ``model_path``. Pressing ``q`` in
    the preview window stops early.
    """
    # Make and restore the segvae model.
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type='best_ged')

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    # RANDOM IMAGE
    # x_b, s_b = data.test.next_batch(1)

    # FIXED IMAGE
    # Cardiac: 100 normal image
    # LIDC: 200 large lesion, 203, 1757 complicated lesion
    # Prostate: 165 nice slice
    index = 165

    x_b = data.test.images[index,
                           ...].reshape([1] + list(exp_config.image_size))

    x_b_d = utils.convert_to_uint8(np.squeeze(x_b))
    x_b_d = utils.resize_image(x_b_d, video_target_size)

    if SAVE_VIDEO:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        outfile = os.path.join(model_path, 'model_samples_id%d.avi' % index)
        # Width is doubled: input and prediction are placed side by side.
        out = cv2.VideoWriter(outfile, fourcc, 5.0,
                              (2 * video_target_size[1], video_target_size[0]))

    # The sampling/display loop runs regardless of SAVE_VIDEO. It used to be
    # nested inside the ``if SAVE_VIDEO:`` block, so with the flag off the
    # function silently did nothing.
    samps = 50
    for _ in range(samps):

        feed_dict = {}
        feed_dict[segvae_model.training_pl] = False
        feed_dict[segvae_model.x_inp] = x_b

        s_p, _ = segvae_model.sess.run(
            [segvae_model.s_out_eval, segvae_model.s_out_eval_list],
            feed_dict=feed_dict)
        s_p = np.argmax(s_p, axis=-1)

        # Map label k to grey value k / nlabels * 255.
        s_p_d = np.squeeze(np.uint8((s_p / exp_config.nlabels) * 255))
        s_p_d = utils.resize_image(s_p_d,
                                   video_target_size,
                                   interp=cv2.INTER_NEAREST)

        img = np.concatenate([x_b_d, s_p_d], axis=1)
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        img = histogram_equalization(img)

        if exp_config.data_identifier == 'acdc':
            # labels (0 85 170 255)
            rv = cv2.inRange(s_p_d, 84, 86)
            my = cv2.inRange(s_p_d, 169, 171)
            rv_cnt, _ = cv2.findContours(rv, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)
            my_cnt, _ = cv2.findContours(my, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)

            cv2.drawContours(img, rv_cnt, -1, (0, 255, 0), 1)
            cv2.drawContours(img, my_cnt, -1, (0, 0, 255), 1)
        elif exp_config.data_identifier == 'uzh_prostate':
            print(np.unique(s_p_d))
            s1 = cv2.inRange(s_p_d, 84, 86)
            s2 = cv2.inRange(s_p_d, 169, 171)
            s1_cnt, _ = cv2.findContours(s1, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)
            s2_cnt, _ = cv2.findContours(s2, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)

            cv2.drawContours(img, s1_cnt, -1, (0, 255, 0), 1)
            cv2.drawContours(img, s2_cnt, -1, (0, 0, 255), 1)
        elif exp_config.data_identifier == 'lidc':
            thresh = cv2.inRange(s_p_d, 127, 255)
            lesion, _ = cv2.findContours(thresh, cv2.RETR_TREE,
                                         cv2.CHAIN_APPROX_SIMPLE)
            cv2.drawContours(img, lesion, -1, (0, 255, 0), 1)

        if SAVE_VIDEO:
            out.write(img)

        cv2.imshow('frame', img)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    if SAVE_VIDEO:
        out.release()
    cv2.destroyAllWindows()
def main(model_path, exp_config):
    """Record segvae samples level by level for one fixed test slice.

    Restores the ``best_ged`` weights from ``model_path``. For every latent
    level, starting from the last one, it draws ``samps`` segmentations.
    Each time, all levels above the current one are fed the values returned
    once by ``generate_prior_samples``. Each frame shows
    input | ground truth | prediction side by side, with the prediction's
    contours drawn over the input. With ``SAVE_VIDEO`` the frames also go to
    ``samplevid_id<index>.avi`` in ``model_path``. Pressing ``q`` in the
    preview window stops the whole run.
    """
    # Make and restore the segvae model.
    segvae_model = segvae(exp_config=exp_config)
    segvae_model.load_weights(model_path, type='best_ged')

    data_loader = data_switch(exp_config.data_identifier)
    data = data_loader(exp_config)

    lat_lvls = exp_config.latent_levels

    # RANDOM IMAGE
    # x_b, s_b = data.test.next_batch(1)

    # FIXED IMAGE
    # Cardiac: 100 normal image
    # LIDC: 200 large lesion, 203, 1757 complicated lesion
    # Prostate: 165 nice slice
    index = 165

    x_b = data.test.images[index,
                           ...].reshape([1] + list(exp_config.image_size))
    if exp_config.data_identifier == 'lidc':
        # Show the first of the four annotation channels containing any
        # foreground, falling back to the last channel.
        s_b = data.test.labels[index, ...]
        if np.sum(s_b[..., 0]) > 0:
            s_b = s_b[..., 0]
        elif np.sum(s_b[..., 1]) > 0:
            s_b = s_b[..., 1]
        elif np.sum(s_b[..., 2]) > 0:
            s_b = s_b[..., 2]
        else:
            s_b = s_b[..., 3]

        s_b = s_b.reshape([1] + list(exp_config.image_size[0:2]))
    elif exp_config.data_identifier == 'uzh_prostate':
        s_b = data.test.labels[index, ...]
        s_b = s_b[..., 0]
        s_b = s_b.reshape([1] + list(exp_config.image_size[0:2]))
    else:
        s_b = data.test.labels[index,
                               ...].reshape([1] +
                                            list(exp_config.image_size[0:2]))

    x_b_d = utils.convert_to_uint8(np.squeeze(x_b))
    x_b_d = utils.resize_image(x_b_d, video_target_size)

    # Map label k to grey value k / nlabels * 255.
    s_b_d = np.squeeze(np.uint8((s_b / exp_config.nlabels) * 255))
    s_b_d = utils.resize_image(s_b_d,
                               video_target_size,
                               interp=cv2.INTER_NEAREST)

    # Per-level values computed once. Levels above the one being explored are
    # pinned to them.
    # NOTE(review): presumably the prior means (return_params=True); confirm.
    _, mu_list_init, _ = segvae_model.generate_prior_samples(
        x_b, return_params=True)

    if SAVE_VIDEO:
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        outfile = os.path.join(model_path, 'samplevid_id%d.avi' % index)
        # Width is tripled: input, ground truth and prediction side by side.
        out = cv2.VideoWriter(outfile, fourcc, 10.0,
                              (3 * video_target_size[1], video_target_size[0]))

    # 'q' must end the whole run, as in the other viewers. A bare ``break``
    # only left the inner loop and silently moved on to the next level.
    quit_requested = False
    for lvl in reversed(range(lat_lvls)):
        if quit_requested:
            break

        samps = 50 if lat_lvls > 1 else 200
        for _ in range(samps):

            print('doing level %d/%d' % (lvl, lat_lvls))

            # Pin every level strictly above ``lvl`` to mu_list_init. Levels
            # <= lvl are not fed and are left to the graph.
            feed_dict = {}
            for jj in range(lvl + 1, lat_lvls):
                feed_dict[segvae_model.prior_z_list_gen[jj]] = mu_list_init[jj]
            feed_dict[segvae_model.training_pl] = False
            feed_dict[segvae_model.x_inp] = x_b

            s_p, _ = segvae_model.sess.run(
                [segvae_model.s_out_eval, segvae_model.s_out_eval_list],
                feed_dict=feed_dict)
            s_p = np.argmax(s_p, axis=-1)

            print(np.unique(s_p))

            s_p_d = np.squeeze(np.uint8((s_p / exp_config.nlabels) * 255))
            s_p_d = utils.resize_image(s_p_d,
                                       video_target_size,
                                       interp=cv2.INTER_NEAREST)

            img = np.concatenate([x_b_d, s_b_d, s_p_d], axis=1)
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

            img = histogram_equalization(img)

            # Contour coordinates come from s_p_d (one panel wide), so they
            # are drawn over the leftmost (input) panel.
            if exp_config.data_identifier == 'acdc':
                # labels (0 85 170 255)
                rv = cv2.inRange(s_p_d, 84, 86)
                my = cv2.inRange(s_p_d, 169, 171)
                rv_cnt, _ = cv2.findContours(rv, cv2.RETR_TREE,
                                             cv2.CHAIN_APPROX_SIMPLE)
                my_cnt, _ = cv2.findContours(my, cv2.RETR_TREE,
                                             cv2.CHAIN_APPROX_SIMPLE)

                cv2.drawContours(img, rv_cnt, -1, (0, 255, 0), 1)
                cv2.drawContours(img, my_cnt, -1, (0, 0, 255), 1)
            elif exp_config.data_identifier == 'uzh_prostate':
                print(np.unique(s_p_d))
                s1 = cv2.inRange(s_p_d, 84, 86)
                s2 = cv2.inRange(s_p_d, 169, 171)
                s1_cnt, _ = cv2.findContours(s1, cv2.RETR_TREE,
                                             cv2.CHAIN_APPROX_SIMPLE)
                s2_cnt, _ = cv2.findContours(s2, cv2.RETR_TREE,
                                             cv2.CHAIN_APPROX_SIMPLE)

                cv2.drawContours(img, s1_cnt, -1, (0, 255, 0), 1)
                cv2.drawContours(img, s2_cnt, -1, (0, 0, 255), 1)
            elif exp_config.data_identifier == 'lidc':
                thresh = cv2.inRange(s_p_d, 127, 255)
                lesion, _ = cv2.findContours(thresh, cv2.RETR_TREE,
                                             cv2.CHAIN_APPROX_SIMPLE)
                cv2.drawContours(img, lesion, -1, (0, 255, 0), 1)

            font = cv2.FONT_HERSHEY_SIMPLEX
            cv2.putText(img, 'Sampling level %d/%d' % (lvl + 1, lat_lvls),
                        (30, 256 - 30), font, 1, (255, 255, 255), 1,
                        cv2.LINE_AA)

            print('actual size')
            print(img.shape)

            if SAVE_VIDEO:
                out.write(img)

            cv2.imshow('frame', img)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                quit_requested = True
                break

    if SAVE_VIDEO:
        out.release()
    cv2.destroyAllWindows()