Example No. 1
0
    def build_graph(self, image, label):
        """Build the label-generator graph, its losses and image summaries.

        ``image`` and ``label`` are mapped with ``tf_2tanh``. A generator in
        variable scope ``gen/lbl`` (placed on ``/device:GPU:0``) predicts a
        one-channel label map ``pil`` from the image. Affinity maps ``pa``
        (ground truth) and ``pila`` (prediction) are derived with
        ``seg_to_aff_op``.

        ``self.cost`` is the sum of:
          * ``1e1 * smooth_ila``: mean of ``1 - pila`` after entries of
            ``pila`` equal to 0 have been replaced by 1;
          * ``1e0 * mae_il``: mean absolute error between ``pl`` and ``pil``;
          * ``1e0 * mae_ila``: mean absolute error between ``pa`` and
            ``pila``.
        ``aff_ila`` (binary cross-entropy minus Jaccard dice) is only tracked
        as a moving summary; it is not added to the cost.

        Args:
            image: input image tensor.
            label: ground-truth label tensor.
        """
        pi = tf_2tanh(image)
        pl = tf_2tanh(label)

        with tf.variable_scope('gen'):
            with tf.device('/device:GPU:0'):
                with tf.variable_scope('lbl'):
                    pil, _ = self.generator(pi, last_dim=1)

        losses = []

        pa = seg_to_aff_op(tf_2imag(pl) + 1.0, name='pa')  # 0 1
        pila = seg_to_aff_op(tf_2imag(pil) + 1.0, name='pila')  # 0 1

        with tf.name_scope('loss_aff'):
            aff_ila = tf.identity(
                tf.subtract(
                    binary_cross_entropy(pa, pila),
                    dice_coe(pa, pila, axis=[0, 1, 2, 3],
                             loss_type='jaccard')),
                name='aff_ila')
            # Tracked for monitoring only; intentionally not part of the cost.
            # losses.append(3e-3 * aff_ila)
            add_moving_summary(aff_ila)

        with tf.name_scope('loss_smooth'):
            # Element-wise comparison. With TF1 graph tensors ``pila == 0.0``
            # is a Python identity test that is always False, which turned the
            # tf.where below into a no-op.
            cond = tf.equal(pila, tf.zeros_like(pila))
            pilc = tf.where(cond, tf.ones_like(pila), pila, name='pilc')
            smooth_ila = tf.reduce_mean(tf.ones_like(pila) - pilc,
                                        name='smooth_ila')
            losses.append(1e1 * smooth_ila)
            add_moving_summary(smooth_ila)

        with tf.name_scope('loss_mae'):
            mae_il = tf.reduce_mean(tf.abs(pl - pil), name='mae_il')
            losses.append(1e0 * mae_il)
            add_moving_summary(mae_il)

            mae_ila = tf.reduce_mean(tf.abs(pa - pila), name='mae_ila')
            losses.append(1e0 * mae_ila)
            add_moving_summary(mae_ila)

        self.cost = tf.reduce_sum(losses, name='self.cost')
        add_moving_summary(self.cost)

        # Visualization: row 1 = image | label | prediction,
        # row 2 = the three ground-truth affinity channels,
        # row 3 = the three predicted affinity channels.
        viz = tf.concat([
            tf.concat([pi, pl, pil], axis=2),
            tf.concat([pa[..., 0:1], pa[..., 1:2], pa[..., 2:3]], axis=2),
            tf.concat([pila[..., 0:1], pila[..., 1:2], pila[..., 2:3]],
                      axis=2),
        ], axis=1)
        viz = tf_2imag(viz)

        viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
        tf.summary.image('colorized', viz, max_outputs=50)
Example No. 2
0
 def membr_loss(y_true, y_pred, name='membr_loss'):
     """Return binary cross-entropy minus Jaccard dice for two membrane maps.

     Both inputs are converted once with ``cvt2imag(..., maxVal=1.0)`` and the
     converted tensors are shared by the two terms.

     Args:
         y_true: ground-truth membrane tensor.
         y_pred: predicted membrane tensor.
         name: name given to the returned op.

     Returns:
         ``binary_cross_entropy(t, p) - dice_coe(t, p, axis=[1, 2, 3],
         loss_type='jaccard')`` evaluated on the converted tensors.
     """
     y_true_img = cvt2imag(y_true, maxVal=1.0)
     y_pred_img = cvt2imag(y_pred, maxVal=1.0)
     bce = binary_cross_entropy(y_true_img, y_pred_img)
     jaccard = dice_coe(y_true_img, y_pred_img, axis=[1, 2, 3],
                        loss_type='jaccard')
     # Honour the ``name`` argument (a hard-coded name used to override it).
     return tf.subtract(bce, jaccard, name=name)
Example No. 3
0
 def membr_loss(y_true, y_pred, name='membr_loss'):
     """Return mean(binary cross-entropy - Jaccard dice) as a named op.

     Both inputs are converted once with ``cvt2imag(..., maxVal=1.0)`` and the
     converted tensors are shared by the two terms.

     Args:
         y_true: ground-truth membrane tensor.
         y_pred: predicted membrane tensor.
         name: name given to the returned (identity) op.

     Returns:
         ``reduce_mean(bce - jaccard_dice)`` wrapped in ``tf.identity``.
     """
     y_true_img = cvt2imag(y_true, maxVal=1.0)
     y_pred_img = cvt2imag(y_pred, maxVal=1.0)
     bce = binary_cross_entropy(y_true_img, y_pred_img)
     jaccard = dice_coe(y_true_img, y_pred_img, axis=[1, 2, 3],
                        loss_type='jaccard')
     loss = tf.reduce_mean(tf.subtract(bce, jaccard))
     return tf.identity(loss, name=name)
Example No. 4
0
    def _build_graph(self, inputs):
        """Build the label/affinity GAN graph: generators, discriminator, losses.

        Two generators under ``gen/affnt`` and ``gen/label`` predict a
        3-channel affinity map ``pia`` and a 1-channel label map ``pil`` from
        the image ``pi``. A discriminator under ``discrim`` scores the
        channel-wise concatenation (image, label, affinity) of the ground
        truth and of the prediction. Inside a ``Round -> Identity`` gradient
        override, predictions are cross-converted between labels and
        affinities, and the averages ``pil_`` / ``pia_`` feed the
        reconstruction, rand and affinity losses.

        Sets ``self.g_loss`` / ``self.d_loss`` (each including L2 weight decay
        on ``gen/.*/W`` / ``discrim/.*/W``), calls ``self.collect_variables()``
        and adds scalar and image summaries.

        Args:
            inputs: sequence ``(pi, pm, pl, ui, um, ul)``. Only ``pi`` and
                ``pl`` feed the graph below; the others are converted with
                ``tf_2tanh`` but not otherwise used here.
        """
        G = tf.get_default_graph()  # used for the Round gradient override
        pi, pm, pl, ui, um, ul = inputs
        pi = tf_2tanh(pi)
        pm = tf_2tanh(pm)
        pl = tf_2tanh(pl)
        ui = tf_2tanh(ui)
        um = tf_2tanh(um)
        ul = tf_2tanh(ul)

        pl = toMaxLabels(pl, factor=MAX_LABEL)  # 0 MAX
        pl = toRangeTanh(pl, factor=MAX_LABEL)  # -1 1
        pa = seg_to_aff_op(toMaxLabels(pl, factor=MAX_LABEL),
                           name='pa')  # Calculate the affinity  # 0, 1

        with argscope([Conv2D, Deconv2D, FullyConnected],
                      W_init=tf.truncated_normal_initializer(stddev=0.02),
                      use_bias=False), \
                argscope(BatchNorm, gamma_init=tf.random_uniform_initializer()), \
                argscope([Conv2D, Deconv2D, BatchNorm], data_format='NHWC'), \
                argscope([Conv2D], dilation_rate=1):

            with tf.variable_scope('gen'):
                with tf.variable_scope('affnt'):
                    pia, feat_ia = self.generator(pi, last_dim=3)
                with tf.variable_scope('label'):
                    pil, feat_il = self.generator(pi, last_dim=1)

            with tf.variable_scope('discrim'):
                dis_real = self.discriminator(tf.concat([pi, pl, pa], axis=-1))
                dis_fake = self.discriminator(
                    tf.concat([pi, pil, pia], axis=-1))

        # Any Round op created in this context back-propagates as Identity
        # (straight-through), presumably the rounding inside toMaxLabels.
        with G.gradient_override_map({"Round": "Identity"}):
            with tf.variable_scope('fix'):
                pil = toMaxLabels(pil, factor=MAX_LABEL)  # 0 MAX
                pil = toRangeTanh(pil, factor=MAX_LABEL)  # -1 1

                pila = seg_to_aff_op(
                    toMaxLabels(pil, factor=MAX_LABEL),
                    name='pila')  # Calculate the affinity  # 0, 1

                pial = aff_to_seg_op(tf_2imag(pia, maxVal=1.0),
                                     name='pial')  # Calculate the segmentation
                pial = toRangeTanh(pial, factor=MAX_LABEL)  # -1, 1

                # Average the direct prediction with its cross-converted twin.
                pil_ = (pial + pil) / 2.0
                pia_ = (pila + pia) / 2.0

        with tf.name_scope('GAN_loss'):
            G_loss, D_loss = self.build_losses(dis_real,
                                               dis_fake,
                                               name='gan_loss')

        with tf.name_scope('rand_loss'):
            rand_il = tf.reduce_mean(tf_rand_score(pl, pil_), name='rand_loss')

        with tf.name_scope('discrim_loss'):

            def regDLF(y_true,
                       y_pred,
                       alpha=1,
                       beta=1,
                       gamma=0.01,
                       delta_v=0.5,
                       delta_d=1.5,
                       name='loss_discrim'):
                """Discriminative embedding loss grouped by the labels in y_true.

                ``y_pred`` is augmented with four coordinate channels (flat
                index, z, y, x), each rescaled to [0, 255] and passed through
                ``tf_2tanh``. Pixels are grouped by the unique values of the
                flattened ``y_true``. With ``mu`` the per-group mean
                embedding and every L1 distance divided by ``Lreg``:

                  * ``Lreg``  = mean over groups of ``|mu|_1``;
                  * ``Lvar``  = mean over groups of the per-group mean of
                    ``max(|mu_g - x|_1 / Lreg - delta_v, 0) ** 2``;
                  * ``Ldist`` = mean over ordered pairs of distinct means of
                    ``max(2 * delta_d - |mu_a - mu_b|_1 / Lreg, 0) ** 2``
                    (0 when no such pair exists).

                ``y_true`` is reshaped to length DIMZ * DIMY * DIMX, so it
                must contain exactly that many elements.

                Returns:
                    ``alpha * Lvar + beta * Ldist + gamma * Lreg`` named
                    ``name``.
                """
                num_pixels = DIMZ * DIMY * DIMX

                # NOTE(review): linspace(0, N, N) is not 0..N-1 (its step is
                # N / (N - 1)). It is left unchanged because its final sample
                # equals N, which presumably keeps max(lins_z) > 0 when
                # DIMZ == 1 and so avoids a division by zero further down.
                lins = tf.linspace(0.0, num_pixels, num_pixels)
                lins = tf.cast(lins, tf.int32)
                # Row-major (z, y, x) decomposition of the flat index. Rows
                # are DIMX wide, so the row index divides by DIMX and the
                # column is the remainder modulo DIMX (DIMY was used before,
                # which is only right when DIMY == DIMX).
                lins_z = tf.div(lins, DIMY * DIMX)
                lins_y = tf.div(tf.mod(lins, DIMY * DIMX), DIMX)
                lins_x = tf.mod(lins, DIMX)

                lins = tf.cast(lins, tf.float32)
                lins_z = tf.cast(lins_z, tf.float32)
                lins_y = tf.cast(lins_y, tf.float32)
                lins_x = tf.cast(lins_x, tf.float32)

                # Rescale each coordinate channel to [0, 255].
                lins = lins / tf.reduce_max(lins) * 255
                lins_z = lins_z / tf.reduce_max(lins_z) * 255
                lins_y = lins_y / tf.reduce_max(lins_y) * 255
                lins_x = lins_x / tf.reduce_max(lins_x) * 255

                lins = tf_2tanh(lins)
                lins_z = tf_2tanh(lins_z)
                lins_y = tf_2tanh(lins_y)
                lins_x = tf_2tanh(lins_x)

                lins = tf.reshape(lins, tf.shape(y_true), name='lins')
                lins_z = tf.reshape(lins_z, tf.shape(y_true), name='lins_z')
                lins_y = tf.reshape(lins_y, tf.shape(y_true), name='lins_y')
                lins_x = tf.reshape(lins_x, tf.shape(y_true), name='lins_x')

                y_true = tf.reshape(y_true, [num_pixels])
                y_pred = tf.concat([y_pred, lins, lins_z, lins_y, lins_x],
                                   axis=-1)

                nDim = tf.shape(y_pred)[-1]
                X = tf.reshape(y_pred, [num_pixels, nDim])
                uniqueLabels, uniqueInd = tf.unique(y_true)
                # Number of distinct labels (connected components).
                numUnique = tf.size(uniqueLabels)

                # Per-group mean embedding mu, shape (numUnique, nDim).
                Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
                ones_Sigma = tf.unsorted_segment_sum(tf.ones_like(X),
                                                     uniqueInd, numUnique)
                mu = tf.divide(Sigma, ones_Sigma)

                Lreg = tf.reduce_mean(tf.norm(mu, axis=1, ord=1))

                # Pull term: hinge on each pixel's distance to its group mean.
                T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X),
                            axis=1,
                            ord=1)
                T = tf.divide(T, Lreg)
                T = tf.subtract(T, delta_v)
                T = tf.maximum(T, 0.0)  # hinge; same as clip_by_value(T, 0, T)
                T = tf.square(T)

                group_size = tf.unsorted_segment_sum(
                    tf.ones_like(uniqueInd, dtype=tf.float32), uniqueInd,
                    numUnique)
                clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
                clusterSigma = tf.divide(clusterSigma, group_size)
                Lvar = tf.reduce_mean(clusterSigma)

                # Push term: every ordered pair (a, b) of group means.
                mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
                mu_band_rep = tf.tile(mu, [1, numUnique])
                mu_band_rep = tf.reshape(mu_band_rep,
                                         (numUnique * numUnique, nDim))
                mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)

                # Drop all-zero rows (a mean paired with itself). The mask must
                # be 1-D with one entry per row. Comparing against a (1, 1)
                # zeros tensor broadcast it to (1, K), and tf.boolean_mask
                # then flattened mu_diff and picked single scalars, not rows.
                row_l1 = tf.reduce_sum(tf.abs(mu_diff), 1)
                mu_diff = tf.boolean_mask(mu_diff, tf.not_equal(row_l1, 0.0))
                # One L1 distance per pair, so Ldist is a mean over pairs
                # (tf.norm without an axis collapses the whole tensor into a
                # single norm).
                mu_diff = tf.norm(mu_diff, axis=1, ord=1)
                mu_diff = tf.divide(mu_diff, Lreg)
                mu_diff = tf.subtract(2 * delta_d, mu_diff)
                mu_diff = tf.maximum(mu_diff, 0.0)  # hinge
                mu_diff = tf.square(mu_diff)

                # Mean over the remaining pairs. Guard the empty case (a single
                # group), where reduce_mean would return NaN.
                num_pairs = tf.cast(tf.maximum(tf.size(mu_diff), 1),
                                    tf.float32)
                Ldist = tf.reduce_sum(mu_diff) / num_pairs

                L = tf.reduce_sum([alpha * Lvar, beta * Ldist, gamma * Lreg])
                return tf.identity(L, name=name)

            discrim_il = regDLF(toMaxLabels(pl, factor=MAX_LABEL),
                                tf.concat([feat_il, feat_ia], axis=-1),
                                name='discrim_il')
        with tf.name_scope('recon_loss'):
            recon_il = tf.reduce_mean(tf.abs(pl - pil_), name='recon_il')
        with tf.name_scope('affnt_loss'):
            affnt_il = tf.reduce_mean(
                tf.subtract(
                    binary_cross_entropy(pa, pia_),
                    dice_coe(pa, pia_, axis=[0, 1, 2, 3],
                             loss_type='jaccard')))
        with tf.name_scope('residual_loss'):
            # Fraction of positions where the direct and the cross-converted
            # predictions differ. Comparison ops are not differentiable in TF,
            # so this term presumably contributes no gradient.
            residual_a = tf.reduce_mean(tf.cast(tf.not_equal(pia, pila),
                                                tf.float32),
                                        name='residual_a')
            residual_l = tf.reduce_mean(tf.cast(tf.not_equal(pil, pial),
                                                tf.float32),
                                        name='residual_l')
            residual_il = tf.reduce_mean([residual_a, residual_l],
                                         name='residual_il')

        def label_imag(y_pred_L, name='label_imag'):
            """Return a tanh-range image that is 1 where the label map changes."""
            mag_grad_L = magnitute_central_difference(y_pred_L,
                                                      name='mag_grad_L')
            cond = tf.greater(mag_grad_L, tf.zeros_like(mag_grad_L))
            thresholded_mag_grad_L = tf.where(cond,
                                              tf.ones_like(mag_grad_L),
                                              tf.zeros_like(mag_grad_L),
                                              name='thresholded_mag_grad_L')

            thresholded_mag_grad_L = tf_2tanh(thresholded_mag_grad_L,
                                              maxVal=1.0)
            return thresholded_mag_grad_L

        g_il = label_imag(pil_, name='label_il')
        g_l = label_imag(pl, name='label_l')

        self.g_loss = tf.reduce_sum([
            1 * (G_loss),
            10 * (recon_il),
            10 * (residual_il),
            1 * (rand_il),
            50 * (discrim_il),
            0.005 * affnt_il,
        ], name='G_loss_total')
        self.d_loss = tf.reduce_sum([D_loss], name='D_loss_total')
        wd_g = regularize_cost('gen/.*/W',
                               l2_regularizer(1e-5),
                               name='G_regularize')
        wd_d = regularize_cost('discrim/.*/W',
                               l2_regularizer(1e-5),
                               name='D_regularize')

        self.g_loss = tf.add(self.g_loss, wd_g, name='g_loss')
        self.d_loss = tf.add(self.d_loss, wd_d, name='d_loss')

        self.collect_variables()

        add_moving_summary(self.d_loss, self.g_loss)
        with tf.name_scope('summaries'):
            add_tensor_summary(recon_il, types=['scalar'], name='recon_il')
            add_tensor_summary(rand_il, types=['scalar'], name='rand_il')
            add_tensor_summary(discrim_il, types=['scalar'], name='discrim_il')
            add_tensor_summary(affnt_il, types=['scalar'], name='affnt_il')
            add_tensor_summary(residual_il,
                               types=['scalar'],
                               name='residual_il')

        # Segmentation
        viz = tf.concat([
            tf.concat([pi, pl, g_l, g_il], 2),
            tf.concat([tf.zeros_like(pi), pil, pial, pil_], 2),
        ], 1)

        viz = tf_2imag(viz)
        viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
        tf.summary.image('colorized', viz, max_outputs=50)

        # Affinity
        vis = tf.concat([pa, pila, pia, pia_], 2)

        vis = tf_2imag(vis)
        vis = tf.cast(tf.clip_by_value(vis, 0, 255), tf.uint8, name='vis')
        tf.summary.image('affinities', vis, max_outputs=50)
Example No. 5
0
    def build_graph(self, image, label):
        """Build the embedding + DBSCAN clustering graph, its losses, summaries.

        A generator under ``gen/feats`` produces a 16-channel map ``pid``.
        ``tf_cluster_dbscan`` turns it into a label map ``pil`` inside
        ``varreplace.freeze_variables()`` (scope ``gen/label``). Affinities
        ``pa`` / ``pila`` are derived with ``seg_to_aff_op``.

        ``self.cost`` is the sum of ``spectral_loss``, ``1e-3 * discrim_loss``,
        ``1e1 * smooth_ila``, ``mae_il`` and ``mae_ila``. ``aff_ila`` is only
        tracked as a moving summary.

        Args:
            image: input image tensor.
            label: ground-truth label tensor.
        """
        pi = tf_2tanh(image)
        pl = tf_2tanh(label)

        with tf.variable_scope('gen'):
            # Optional placements (disabled): '/device:GPU:0' for 'feats';
            # '/device:GPU:1' or '/cpu:0' for 'label'.
            with tf.variable_scope('feats'):
                pid, _ = self.generator(pi, last_dim=16)
            with tf.variable_scope('label'):
                with varreplace.freeze_variables():
                    pil = tf_cluster_dbscan(pid,
                                            feature_dim=16,
                                            label_shape=[1, DIMY, DIMX, 1])
                    pil = tf_2tanh(pil)

        losses = []

        pa = seg_to_aff_op(tf_2imag(pl) + 1.0, name='pa')  # 0 1
        pila = seg_to_aff_op(tf_2imag(pil) + 1.0, name='pila')  # 0 1

        # Shared embedding for the spectral and discriminative losses.
        # NOTE(review): ``pid`` goes through tf_2imag before the / 255.0
        # scaling but ``pil`` does not -- confirm this asymmetry is intended.
        embedding = tf.concat([tf_2imag(pid) / 255.0, pil / 255.0, pila],
                              axis=-1)

        with tf.name_scope('loss_spectral'):
            spectral_loss = supervised_clustering_loss(
                embedding,
                tf_2imag(pl),
                20,
                (DIMY, DIMX),
            )

            losses.append(1e0 * spectral_loss)
            add_moving_summary(spectral_loss)

        with tf.name_scope('loss_discrim'):
            param_var = 1.0  # was args.var
            param_dist = 1.0  # was args.dist
            param_reg = 0.001  # was args.reg
            delta_v = 0.5  # was args.dvar
            delta_d = 1.5  # was args.ddist

            # Only the total is used; the var/dist/reg components are ignored.
            discrim_loss, _, _, _ = discriminative_loss(
                embedding, tf_2imag(pl), 20, (DIMY, DIMX), delta_v, delta_d,
                param_var, param_dist, param_reg)

            losses.append(1e-3 * discrim_loss)
            add_moving_summary(discrim_loss)

        with tf.name_scope('loss_aff'):
            aff_ila = tf.identity(
                tf.subtract(
                    binary_cross_entropy(pa, pila),
                    dice_coe(pa, pila, axis=[0, 1, 2, 3],
                             loss_type='jaccard')),
                name='aff_ila')
            # Tracked for monitoring only; intentionally not part of the cost.
            # losses.append(3e-3 * aff_ila)
            add_moving_summary(aff_ila)

        with tf.name_scope('loss_smooth'):
            smooth_ila = tf.reduce_mean((tf.ones_like(pila) - pila),
                                        name='smooth_ila')
            losses.append(1e1 * smooth_ila)
            add_moving_summary(smooth_ila)

        with tf.name_scope('loss_mae'):
            mae_il = tf.reduce_mean(tf.abs(pl - pil), name='mae_il')
            losses.append(1e0 * mae_il)
            add_moving_summary(mae_il)

            mae_ila = tf.reduce_mean(tf.abs(pa - pila), name='mae_ila')
            losses.append(1e0 * mae_ila)
            add_moving_summary(mae_ila)

        self.cost = tf.reduce_sum(losses, name='self.cost')
        add_moving_summary(self.cost)

        # Visualization: row 1 = image | label | clustered prediction,
        # rows 2-3 = ground-truth and predicted affinity channels.
        viz = tf.concat([
            tf.concat([pi, pl, pil], axis=2),
            tf.concat([pa[..., 0:1], pa[..., 1:2], pa[..., 2:3]], axis=2),
            tf.concat([pila[..., 0:1], pila[..., 1:2], pila[..., 2:3]],
                      axis=2),
        ], axis=1)
        viz = tf_2imag(viz)

        viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
        tf.summary.image('colorized', viz, max_outputs=50)
    def _build_graph(self, inputs):
        """Build the 3-D residual encoder/decoder affinity model and its loss.

        Args:
            inputs: ``(volume, gt_seg, gt_affs)``. ``volume`` gets a new axis
                at index 3 (a trailing channel axis for a 3-D volume) and a
                leading batch axis before entering the network.

        Sets ``self.cost`` to ``wbce_malis_loss + dice_loss``, where
        ``dice_loss = 0.1 * (1 - dice_coe(affs, gt_affs, loss_type='jaccard'))``,
        and adds scalar summaries for the three losses.
        """
        volume, gt_seg, gt_affs = inputs
        volume = tf.expand_dims(volume, 3)
        volume = tf.expand_dims(volume, 0)

        with argscope(LeakyReLU, alpha=0.2), \
                argscope([Conv3D, DeConv3D], use_bias=False,
                         kernel_shape=3, stride=2, padding='SAME',
                         W_init=tf.contrib.layers.variance_scaling_initializer(
                             factor=0.333, uniform=True)):
            # NOTE(review): ``_in`` is built but never consumed -- ``e0`` reads
            # ``volume`` directly. Kept so the variable set (and existing
            # checkpoints) stay unchanged; confirm which input was intended.
            _in = Conv3D('in',
                         volume,
                         NB_FILTERS,
                         kernel_shape=(3, 5, 5),
                         stride=1,
                         padding='SAME',
                         nl=INELU,
                         use_bias=True)
            e0 = residual_enc('e0', volume, NB_FILTERS * 1)
            e1 = residual_enc('e1', e0, NB_FILTERS * 2)
            e2 = residual_enc('e2', e1, NB_FILTERS * 4, kernel_shape=(1, 3, 3))
            e3 = residual_enc('e3', e2, NB_FILTERS * 8, kernel_shape=(1, 3, 3))

            e3 = Dropout('dr', e3, rate=0.5)

            # NOTE(review): ``e3 + e3`` doubles e3 rather than adding a skip
            # connection like the levels below; kept as is -- confirm intent.
            d3 = residual_dec('d3',
                              e3 + e3,
                              NB_FILTERS * 4,
                              kernel_shape=(1, 3, 3))
            d2 = residual_dec('d2',
                              d3 + e2,
                              NB_FILTERS * 2,
                              kernel_shape=(1, 3, 3))
            d1 = residual_dec('d1', d2 + e1, NB_FILTERS * 1)
            d0 = residual_dec('d0', d1 + e0, NB_FILTERS * 1)

            logits = funcs.Conv3D('x_out',
                                  d0,
                                  len(nhood),
                                  kernel_shape=(3, 5, 5),
                                  stride=1,
                                  padding='SAME',
                                  nl=tf.identity,
                                  use_bias=True)

        # Drop only the batch axis added above. A bare tf.squeeze would also
        # remove any other size-1 axis (a one-slice volume, a single-offset
        # nhood) and the 4-D transpose below would then fail.
        logits = tf.squeeze(logits, axis=0)
        logits = tf.transpose(logits, perm=[3, 0, 1, 2], name='logits')

        affs = cvt2sigm(tf.tanh(logits))
        affs = tf.identity(affs, name='affs')

        wbce_malis_loss = wbce_malis(logits,
                                     affs,
                                     gt_affs,
                                     gt_seg,
                                     nhood,
                                     affs_shape,
                                     name='wbce_malis',
                                     limit_z=False)
        wbce_malis_loss = tf.identity(wbce_malis_loss, name='wbce_malis_loss')

        dice_loss = tf.identity(
            (1. -
             dice_coe(affs, gt_affs, axis=[0, 1, 2, 3], loss_type='jaccard')) *
            0.1,
            name='dice_coe')
        tot_loss = tf.identity(wbce_malis_loss + dice_loss, name='tot_loss')

        self.cost = tot_loss
        summary.add_tensor_summary(tot_loss, types=['scalar'])
        summary.add_tensor_summary(dice_loss, types=['scalar'])
        summary.add_tensor_summary(wbce_malis_loss, types=['scalar'])
Example No. 7
0
    def build_graph(self, image, membr, label):
        """Build the ENet embedding graph with membrane and discriminative losses.

        Args:
            image: input image tensor.
            membr: membrane target compared against the ``semantic`` output.
            label: instance label tensor used by the discriminative loss.

        Returns:
            ``self.cost``: the mean of ``1e1 * aff_im`` (one minus the Jaccard
            dice between ``semantic`` and ``membr``) and ``1e0 * disc_loss``.
        """
        input_image = image
        membrane = membr
        correct_label = label

        # Embedding size and discriminative-loss hyper-parameters.
        feature_dim = 16
        param_var, param_dist, param_reg = 1.0, 1.0, 0.001
        delta_v, delta_d = 0.5, 1.5
        image_shape = (DIMY, DIMX)

        semantic, prediction = self.load_enet(input_image, feature_dim)
        semantic = tf.identity(semantic, 'semantic')
        prediction = tf.identity(prediction, 'prediction')

        losses = []
        with tf.name_scope('loss_aff'):
            jaccard = dice_coe(semantic, membrane, axis=[0, 1, 2, 3],
                               loss_type='jaccard')
            aff_im = tf.identity(1.0 - jaccard, name='aff_im')
            losses.append(1e1 * aff_im)
            add_moving_summary(aff_im)
        with tf.name_scope('loss_discrim'):
            # Only the total is used; the var/dist/reg parts are discarded.
            disc_loss, _, _, _ = discriminative_loss_single(
                prediction, correct_label, feature_dim, image_shape,
                delta_v, delta_d, param_var, param_dist, param_reg)
            losses.append(1e0 * disc_loss)
            add_moving_summary(disc_loss)

        self.cost = tf.reduce_mean(losses)

        # Raw inputs and targets.
        for tag, img in (('image_', input_image),
                         ('label_', correct_label),
                         ('membr_', 255 * membrane)):
            tf.summary.image(tag, img, max_outputs=50)

        # First three embedding channels, one summary each.
        for ch in range(3):
            tf.summary.image('preds%d' % ch, prediction[..., ch:ch + 1],
                             max_outputs=50)

        # Side-by-side panel: image | membrane | label | semantic.
        panels = [
            input_image[..., 0:1],
            255 * membrane[..., 0:1],
            correct_label[..., 0:1],
            255 * semantic[..., 0:1],
        ]
        viz = tf.concat([tf.concat(panels, axis=2)], axis=1)
        viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
        tf.summary.image('labelized', viz, max_outputs=50)

        return self.cost