示例#1
0
    def build_graph(self):
        """Build the adversarial-noise classifier graph and return its cost.

        A trainable ``noise`` variable (shape ``[BATCH_SIZE, 28, 28]``) is
        added to the constant ``self.image`` batch. The sum is clipped to
        roughly [0, 1] and classified by a small CNN. Named tensors
        ``noise_show`` / ``image_show`` (N x N mosaics), ``pred_label``,
        ``cross_entropy_loss``, ``L1Norm``, ``L2Norm``, ``wrong_vector`` and
        ``train_error`` are created for monitoring.

        Returns:
            Scalar tensor ``cost``: mean cross-entropy against ``self.label``
            plus an L2 penalty (scale 1e-2) on the ``noise`` variable.
        """
        assert tf.test.is_gpu_available()

        def tile_grid(batch, name):
            # Lay out the first N*N entries of `batch` as an N x N mosaic:
            # each row joins N images along axis 1, rows stack along axis 0.
            rows = [
                tf.concat([batch[row * N + col] for col in range(N)], axis=1)
                for row in range(N)
            ]
            return tf.concat(rows, axis=0, name=name)

        clean = tf.constant(self.image, tf.float32)
        noise = tf.get_variable(
            'noise',
            shape=[BATCH_SIZE, 28, 28],
            initializer=tf.random_normal_initializer(stddev=0.01))
        perturbed = tf.clip_by_value(tf.add(clean, noise, name='image'),
                                     -1e-5, 1 + 1e-5)
        tile_grid(noise, 'noise_show')
        tile_grid(perturbed, 'image_show')

        net = tf.expand_dims(perturbed, 3)
        net = net * 2 - 1  # shift pixel values so they are centred at zero
        net = Conv2D('conv1', net, 32, 3, activation=tf.nn.relu)
        net = Conv2D('conv2', net, 64, 3, activation=tf.nn.relu)
        net = MaxPooling('pool2', net, 2)
        net = tf.nn.relu(FullyConnected('fc1', net, 128))
        logits = FullyConnected('fc2', net, 10)
        tf.argmax(logits, 1, name='pred_label')

        xent = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                           labels=self.label),
            name='cross_entropy_loss')
        add_tensor_summary(xent, ['scalar'])

        noise_penalty = regularize_cost('noise',
                                        tf.contrib.layers.l2_regularizer(1e-2))
        add_tensor_summary(noise_penalty, ['scalar'])

        # Monitoring only: noise magnitude divided by the batch size.
        l1_norm = tf.div(tf.reduce_sum(tf.abs(noise)), BATCH_SIZE, name='L1Norm')
        add_tensor_summary(l1_norm, ['scalar'])
        l2_norm = tf.div(tf.sqrt(tf.reduce_sum(tf.square(noise))),
                         BATCH_SIZE,
                         name='L2Norm')
        add_tensor_summary(l2_norm, ['scalar'])

        # 1.0 where the true label is not the top-1 prediction.
        missed = tf.logical_not(tf.nn.in_top_k(logits, self.label, 1))
        wrong = tf.to_float(missed, name='wrong_vector')
        add_tensor_summary(tf.reduce_mean(wrong, name='train_error'), ['scalar'])

        return tf.add_n([xent, noise_penalty], name='cost')
示例#2
0
 def optimizer(self):
     """Return an Adam optimizer driven by a non-trainable learning-rate variable.

     The rate is stored in a graph variable named ``learning_rate``. The
     variable is initialised from ``self._lr`` and logged as a scalar summary
     on the main tower only. ``self._beta1`` and ``self._beta2`` feed Adam,
     and epsilon is fixed at 1e-3.
     """
     learning_rate = tf.get_variable(
         "learning_rate", initializer=self._lr, trainable=False)
     add_tensor_summary(learning_rate, ['scalar'], main_tower_only=True)
     optimizer = tf.train.AdamOptimizer(
         learning_rate, beta1=self._beta1, beta2=self._beta2, epsilon=1e-3)
     return optimizer
示例#3
0
def wbce_malis(logits, affs, gt_affs, gt_seg, neighborhood, affs_shape,
               name='wbce_malis', limit_z=False):
    """MALIS-weighted, class-balanced binary cross-entropy on affinities.

    Positive edges are up-weighted by ``neg_cnt / pos_cnt``. Each per-edge
    weighted BCE term is then multiplied by the MALIS weights returned by
    ``malis_weights_op``, and the products are averaged.

    Args:
        logits: predicted affinity logits.
        affs: predicted affinities, forwarded to ``malis_weights_op``.
        gt_affs: ground-truth affinities. Entries that are nonzero after an
            int32 cast count as positive edges.
        gt_seg: ground-truth segmentation, forwarded to ``malis_weights_op``.
        neighborhood: affinity neighborhood, forwarded to ``malis_weights_op``.
        affs_shape: shape whose element count is used as the total edge count.
        name: name scope for the created ops.
        limit_z: forwarded to ``malis_weights_op``.

    Returns:
        Scalar tensor named ``malis_weighted_bce_loss``.
    """
    with tf.name_scope(name):
        pos_cnt = tf.cast(tf.count_nonzero(tf.cast(gt_affs, tf.int32)), tf.float32)
        neg_cnt = tf.cast(tf.constant(np.prod(affs_shape)), tf.float32) - pos_cnt
        # Clamp the denominator: with no positive edges, neg_cnt / 0 is inf.
        # Inside weighted_cross_entropy_with_logits, (pos_weight - 1) * 0 then
        # becomes NaN. Assuming gt_affs is binary, zero positives means the
        # weight never scales a nonzero target, so clamping is otherwise inert.
        pos_weight = neg_cnt / tf.maximum(pos_cnt, 1.0)
        summary.add_tensor_summary(pos_weight, types=['scalar'])
        weighted_bce_losses = tf.nn.weighted_cross_entropy_with_logits(
            targets=gt_affs, logits=logits, pos_weight=pos_weight)

        malis_weights, pos_weights, neg_weights = malis_weights_op(
            affs, gt_affs, gt_seg, neighborhood,
            name='malis_weights', limit_z=limit_z)
        # Named identities of the MALIS positive/negative weights. They are
        # not used in the loss; presumably kept so they can be fetched by name.
        tf.identity(pos_weights, name='pos_weight')
        tf.identity(neg_weights, name='neg_weight')

        return tf.reduce_mean(tf.multiply(malis_weights, weighted_bce_losses),
                              name='malis_weighted_bce_loss')
示例#4
0
    def _build_graph(self, inputs):
        """Build the image -> membrane -> label GAN graph, its losses and summaries.

        ``inputs`` is unpacked as ``(pi, pm, pl, ui, um, ul, sm)``, and every
        tensor is passed through ``cvt2tanh`` first. Generators are built under
        ``gen/I2M`` (``pi`` -> ``pim``) and ``gen/M2L``, which maps
        ``[pim, sm]`` -> ``piml`` and ``[pm, sm]`` -> ``pml``. Discriminators
        are built under ``discrim/M`` and ``discrim/L``.

        Side effects:
            * Sets ``self.g_loss`` and ``self.d_loss``, each including an L2
              weight-decay term.
            * Calls ``self.collect_variables()``.
            * Registers moving/scalar summaries and an image summary
              ``colorized``.

        Only the GAN terms and the membrane losses enter ``self.g_loss``. The
        reconstruction, label, rand and discriminative losses are built for
        summaries only.
        """
        pi, pm, pl, ui, um, ul, sm = [cvt2tanh(t) for t in inputs]

        def tf_rand_score(x1, x2):
            # 1 - adjusted Rand index of the flattened arrays. It is evaluated
            # in numpy through tf.py_func, so it carries no gradient.
            return np.mean(1.0 -
                           adjusted_rand_score(x1.flatten(), x2.flatten()))

        with argscope([Conv2D, Deconv2D, FullyConnected],
                      W_init=tf.truncated_normal_initializer(stddev=0.02),
                      use_bias=False), \
                argscope(BatchNorm, gamma_init=tf.random_uniform_initializer()), \
                argscope([Conv2D, Deconv2D, BatchNorm], data_format='NHWC'), \
                argscope(LeakyReLU, alpha=0.2):

            with tf.variable_scope('gen'):
                # Real pair image for the generators.
                with tf.variable_scope('I2M'):
                    pim, feat_im = self.generator(pi)
                with tf.variable_scope('M2L'):
                    piml, feat_iml = self.generator(
                        tf.concat([pim, sm], axis=-1))
                    pml, feat_ml = self.generator(tf.concat([pm, sm], axis=-1))

            with tf.variable_scope('discrim'):
                with tf.variable_scope('M'):
                    m_dis_real = self.discriminator(um)
                    m_dis_fake_from_image = self.discriminator(pim)
                with tf.variable_scope('L'):
                    l_dis_real = self.discriminator(ul)
                    l_dis_fake_from_image = self.discriminator(piml)

        with tf.name_scope('Recon_L_loss'):
            recon_iml = tf.reduce_mean(tf.abs(pl - piml), name='recon_iml')
            recon_ml = tf.reduce_mean(tf.abs(pl - pml), name='recon_ml')

        with tf.name_scope('Recon_M_loss'):
            recon_im = tf.reduce_mean(tf.abs(pm - pim), name='recon_im')

        with tf.name_scope('GAN_loss'):
            G_loss_IL, D_loss_IL = self.build_losses(l_dis_real,
                                                     l_dis_fake_from_image,
                                                     name='IL')
            G_loss_MI, D_loss_MI = self.build_losses(m_dis_real,
                                                     m_dis_fake_from_image,
                                                     name='MI')

        # Membrane loss: binary cross-entropy minus the Jaccard-style
        # dice_coe, both computed on values mapped back by cvt2imag(maxVal=1).
        with tf.name_scope('membr_loss'):

            def membr_loss(y_true, y_pred, name='membr_loss'):
                loss = tf.reduce_mean(
                    tf.subtract(
                        binary_cross_entropy(cvt2imag(y_true, maxVal=1.0),
                                             cvt2imag(y_pred, maxVal=1.0)),
                        dice_coe(cvt2imag(y_true, maxVal=1.0),
                                 cvt2imag(y_pred, maxVal=1.0),
                                 axis=[1, 2, 3],
                                 loss_type='jaccard')))
                return tf.identity(loss, name=name)

            membr_im = membr_loss(pm, pim, name='membr_im')
            membr_iml = membr_loss(pl, piml, name='membr_iml')
            membr_ml = membr_loss(pl, pml, name='membr_ml')

        # Label loss: mean of the membrane magnitude (cvt2imag of y_grad_M)
        # masked by pixels where the predicted label map has a nonzero
        # central-difference gradient. y_true_L is currently unused.
        with tf.name_scope('label_loss'):

            def label_loss(y_true_L, y_pred_L, y_grad_M, name='label_loss'):
                g_mag_grad_M = cvt2imag(y_grad_M, maxVal=1.0)
                mag_grad_L = magnitute_central_difference(y_pred_L,
                                                          name='mag_grad_L')
                cond = tf.greater(mag_grad_L, tf.zeros_like(mag_grad_L))
                thresholded_mag_grad_L = tf.where(
                    cond,
                    tf.ones_like(mag_grad_L),
                    tf.zeros_like(mag_grad_L),
                    name='thresholded_mag_grad_L')

                gtv_guess = tf.multiply(g_mag_grad_M,
                                        thresholded_mag_grad_L,
                                        name='gtv_guess')
                loss_gtv_guess = tf.reduce_mean(gtv_guess,
                                                name='loss_gtv_guess')
                # Edge map in tanh range, returned for visualisation.
                thresholded_mag_grad_L = cvt2tanh(thresholded_mag_grad_L,
                                                  maxVal=1.0)
                return tf.identity(loss_gtv_guess,
                                   name=name), thresholded_mag_grad_L

            label_iml, g_iml = label_loss(None, piml, pim, name='label_iml')
            label_ml, g_ml = label_loss(None, pml, pm, name='label_loss_ml')

        with tf.name_scope('rand_loss'):
            rand_iml = tf.reduce_mean(
                tf.cast(tf.py_func(tf_rand_score, [piml, pl], tf.float64),
                        tf.float32))
            rand_ml = tf.reduce_mean(
                tf.cast(tf.py_func(tf_rand_score, [pml, pl], tf.float64),
                        tf.float32))

        with tf.name_scope('discrim_loss'):

            def regDLF(y_true,
                       y_pred,
                       alpha=1,
                       beta=1,
                       gamma=0.01,
                       delta_v=0.5,
                       delta_d=1.5,
                       name='loss_discrim'):
                """Discriminative embedding loss alpha*Lvar + beta*Ldist + gamma*Lreg.

                Pixels are grouped by the distinct values of ``y_true``. All
                distances are L1 and divided by ``Lreg``, the mean L1 norm of
                the cluster means.
                """
                n_pix = DIMZ * DIMY * DIMX
                y_true = tf.reshape(y_true, [n_pix])

                nDim = tf.shape(y_pred)[-1]
                X = tf.reshape(y_pred, [n_pix, nDim])
                uniqueLabels, uniqueInd = tf.unique(y_true)
                numUnique = tf.size(uniqueLabels)  # number of distinct labels

                # Per-cluster mean embedding mu[k].
                Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
                counts = tf.unsorted_segment_sum(tf.ones_like(X), uniqueInd,
                                                 numUnique)
                mu = tf.divide(Sigma, counts)

                Lreg = tf.reduce_mean(tf.norm(mu, axis=1, ord=1))

                # Variance term: hinge [d / Lreg - delta_v]_+^2 on each pixel's
                # distance d to its cluster mean, averaged per cluster, then
                # averaged over clusters.
                T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X),
                            axis=1,
                            ord=1)
                T = tf.divide(T, Lreg)
                T = tf.subtract(T, delta_v)
                T = tf.maximum(T, 0.0)
                T = tf.square(T)

                pix_per_cluster = tf.unsorted_segment_sum(
                    tf.ones_like(uniqueInd, dtype=tf.float32), uniqueInd,
                    numUnique)
                clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
                clusterSigma = tf.divide(clusterSigma, pix_per_cluster)
                Lvar = tf.reduce_mean(clusterSigma)

                # All ordered pairs of cluster means: row i*K + j = mu[i] - mu[j].
                mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
                mu_band_rep = tf.tile(mu, [1, numUnique])
                mu_band_rep = tf.reshape(mu_band_rep,
                                         (numUnique * numUnique, nDim))
                mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)

                # Drop all-zero rows (e.g. the i == j pairs). The mask must be a
                # per-row vector of length K*K. Comparing against a (1, 1)
                # zeros tensor broadcast it to (1, K*K), and boolean_mask then
                # applied it to the flattened values, selecting scalars at the
                # row indices instead of whole rows.
                row_l1 = tf.reduce_sum(tf.abs(mu_diff), 1)
                nonzero_rows = tf.boolean_mask(mu_diff,
                                               tf.not_equal(row_l1, 0.0))
                mu_diff = tf.expand_dims(nonzero_rows, axis=1)
                print(mu_diff)
                # NOTE(review): with no axis this is ONE L1 norm over all pairs
                # combined, so the hinge below acts on a single scalar rather
                # than per pair of clusters -- confirm this is intended.
                mu_diff = tf.norm(mu_diff, ord=1)
                mu_diff = tf.divide(mu_diff, Lreg)

                # Distance term: hinge [2*delta_d - d]_+^2.
                mu_diff = tf.subtract(2 * delta_d, mu_diff)
                mu_diff = tf.maximum(mu_diff, 0.0)
                mu_diff = tf.square(mu_diff)
                Ldist = tf.reduce_mean(mu_diff)

                L = tf.reduce_sum([alpha * Lvar, beta * Ldist, gamma * Lreg])
                print(L)
                print(Ldist)
                print(Lvar)
                print(Lreg)
                return tf.identity(L, name=name)

            discrim_im = regDLF(cvt2imag(pm, maxVal=1.0),
                                feat_im,
                                name='discrim_im')
            discrim_iml = regDLF(cvt2imag(pl, maxVal=1.0),
                                 feat_iml,
                                 name='discrim_iml')
            discrim_ml = regDLF(cvt2imag(pl, maxVal=1.0),
                                feat_ml,
                                name='discrim_ml')
            print(discrim_im)
            print(discrim_iml)
            print(discrim_ml)

            print(rand_iml)
            print(rand_ml)

        # Optimised generator objective: GAN terms plus lightly weighted
        # membrane losses. The other losses above are only summarised.
        self.g_loss = tf.reduce_sum(
            [
                G_loss_IL + G_loss_MI,
                0.001 * membr_im,
                0.001 * membr_iml,
                0.001 * membr_ml,
            ],
            name='G_loss_total')
        self.d_loss = tf.reduce_sum([D_loss_IL + D_loss_MI],
                                    name='D_loss_total')

        wd_g = regularize_cost('gen/.*/W',
                               l2_regularizer(1e-5),
                               name='G_regularize')
        wd_d = regularize_cost('discrim/.*/W',
                               l2_regularizer(1e-5),
                               name='D_regularize')

        self.g_loss = tf.add(self.g_loss, wd_g, name='g_loss')
        self.d_loss = tf.add(self.d_loss, wd_d, name='d_loss')

        self.collect_variables()

        add_moving_summary(self.d_loss, self.g_loss)
        with tf.name_scope('summaries'):
            add_tensor_summary(recon_iml, types=['scalar'], name='recon_iml')
            add_tensor_summary(recon_im, types=['scalar'], name='recon_im')
            add_tensor_summary(recon_ml, types=['scalar'], name='recon_ml')
            add_tensor_summary(label_iml, types=['scalar'], name='label_iml')
            add_tensor_summary(label_ml, types=['scalar'], name='label_ml')
            add_tensor_summary(rand_iml, types=['scalar'], name='rand_iml')
            add_tensor_summary(rand_ml, types=['scalar'], name='rand_ml')
            add_tensor_summary(membr_im, types=['scalar'], name='membr_im')
            add_tensor_summary(membr_iml, types=['scalar'], name='membr_iml')
            add_tensor_summary(membr_ml, types=['scalar'], name='membr_ml')
            add_tensor_summary(discrim_im, types=['scalar'], name='discrim_im')
            add_tensor_summary(discrim_iml,
                               types=['scalar'],
                               name='discrim_iml')
            add_tensor_summary(discrim_ml, types=['scalar'], name='discrim_ml')

        # Two image rows concatenated along axis 1; within a row, panels are
        # concatenated along axis 2.
        viz = tf.concat(
            [
                tf.concat([ui, pi, sm, pim, piml, g_iml], 2),
                tf.concat([um, pl, ul, pm, pml, g_ml], 2),
            ],
            1)
        viz = cvt2imag(viz)
        viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
        tf.summary.image('colorized', viz, max_outputs=50)
示例#5
0
    def _build_graph(self, inputs):
        """Build the joint label/affinity GAN graph, its losses and summaries.

        ``inputs`` is unpacked as ``(pi, pm, pl, ui, um, ul)``, and every tensor
        is passed through ``tf_2tanh`` first. ``pl`` is quantised with
        ``toMaxLabels``/``toRangeTanh``, and ``pa`` is its affinity map. Two
        generators run under ``gen/affnt`` (``pia``, 3 channels) and
        ``gen/label`` (``pil``, 1 channel). One discriminator under
        ``discrim`` scores ``[pi, label, affinity]`` triples.

        Under a ``Round -> Identity`` gradient override, the label prediction
        is quantised and each representation is converted into the other. The
        results are averaged into ``pil_`` and ``pia_``.

        Side effects:
            * Sets ``self.g_loss`` and ``self.d_loss``, each including an L2
              weight-decay term.
            * Calls ``self.collect_variables()``.
            * Registers scalar summaries and the image summaries
              ``colorized`` and ``affinities``.
        """
        G = tf.get_default_graph()  # for the Round gradient override below
        pi, pm, pl, ui, um, ul = [tf_2tanh(t) for t in inputs]

        pl = toMaxLabels(pl, factor=MAX_LABEL)  # 0 .. MAX
        pl = toRangeTanh(pl, factor=MAX_LABEL)  # -1 .. 1
        pa = seg_to_aff_op(toMaxLabels(pl, factor=MAX_LABEL),
                           name='pa')  # ground-truth affinity, 0 .. 1

        with argscope([Conv2D, Deconv2D, FullyConnected],
                      W_init=tf.truncated_normal_initializer(stddev=0.02),
                      use_bias=False), \
                argscope(BatchNorm, gamma_init=tf.random_uniform_initializer()), \
                argscope([Conv2D, Deconv2D, BatchNorm], data_format='NHWC'), \
                argscope([Conv2D], dilation_rate=1):

            with tf.variable_scope('gen'):
                with tf.variable_scope('affnt'):
                    pia, feat_ia = self.generator(pi, last_dim=3)
                with tf.variable_scope('label'):
                    pil, feat_il = self.generator(pi, last_dim=1)

            with tf.variable_scope('discrim'):
                dis_real = self.discriminator(tf.concat([pi, pl, pa], axis=-1))
                dis_fake = self.discriminator(
                    tf.concat([pi, pil, pia], axis=-1))

        with G.gradient_override_map({"Round": "Identity"}):
            with tf.variable_scope('fix'):
                # Quantise the predicted labels.
                pil = toMaxLabels(pil, factor=MAX_LABEL)  # 0 .. MAX
                pil = toRangeTanh(pil, factor=MAX_LABEL)  # -1 .. 1

                pila = seg_to_aff_op(
                    toMaxLabels(pil, factor=MAX_LABEL),
                    name='pila')  # affinity of the predicted labels, 0 .. 1

                pial = aff_to_seg_op(tf_2imag(pia, maxVal=1.0),
                                     name='pial')  # labels from predicted affinity
                pial = toRangeTanh(pial, factor=MAX_LABEL)  # -1 .. 1

                # Average each prediction with its converted counterpart.
                pil_ = (pial + pil) / 2.0
                pia_ = (pila + pia) / 2.0

        with tf.name_scope('GAN_loss'):
            G_loss, D_loss = self.build_losses(dis_real,
                                               dis_fake,
                                               name='gan_loss')

        with tf.name_scope('rand_loss'):
            rand_il = tf.reduce_mean(tf_rand_score(pl, pil_), name='rand_loss')

        with tf.name_scope('discrim_loss'):

            def regDLF(y_true,
                       y_pred,
                       alpha=1,
                       beta=1,
                       gamma=0.01,
                       delta_v=0.5,
                       delta_d=1.5,
                       name='loss_discrim'):
                """Discriminative embedding loss alpha*Lvar + beta*Ldist + gamma*Lreg.

                ``y_pred`` is augmented with normalised voxel-coordinate
                channels before clustering. Pixels are grouped by the distinct
                values of ``y_true``. All distances are L1 and divided by
                ``Lreg``, the mean L1 norm of the cluster means.
                """
                n_pix = DIMZ * DIMY * DIMX
                true_shape = tf.shape(y_true)

                def coord_feature(idx, name):
                    # Scale an integer coordinate to [0, 255], map it to the
                    # tanh range and reshape it like y_true. The max is clamped
                    # to 1 so a size-1 axis (all zeros) yields 0 instead of NaN.
                    idx = tf.cast(idx, tf.float32)
                    idx = idx / tf.maximum(tf.reduce_max(idx), 1.0) * 255
                    return tf.reshape(tf_2tanh(idx), true_shape, name=name)

                # Flat row-major voxel index i = z*DIMY*DIMX + y*DIMX + x, so
                # z = i div (DIMY*DIMX), y = (i mod DIMY*DIMX) div DIMX and
                # x = i mod DIMX. tf.range gives exact integers 0 .. n_pix-1.
                # A linspace(0, n_pix, n_pix) would step by n_pix/(n_pix-1),
                # ending at n_pix and losing float32 precision for large sizes.
                flat = tf.range(n_pix, dtype=tf.int32)
                lins = coord_feature(flat, 'lins')
                lins_z = coord_feature(tf.div(flat, DIMY * DIMX), 'lins_z')
                lins_y = coord_feature(
                    tf.div(tf.mod(flat, DIMY * DIMX), DIMX), 'lins_y')
                lins_x = coord_feature(tf.mod(flat, DIMX), 'lins_x')

                y_true = tf.reshape(y_true, [n_pix])
                y_pred = tf.concat([y_pred, lins, lins_z, lins_y, lins_x],
                                   axis=-1)

                nDim = tf.shape(y_pred)[-1]
                X = tf.reshape(y_pred, [n_pix, nDim])
                uniqueLabels, uniqueInd = tf.unique(y_true)
                numUnique = tf.size(uniqueLabels)  # number of distinct labels

                # Per-cluster mean embedding mu[k].
                Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
                counts = tf.unsorted_segment_sum(tf.ones_like(X), uniqueInd,
                                                 numUnique)
                mu = tf.divide(Sigma, counts)

                Lreg = tf.reduce_mean(tf.norm(mu, axis=1, ord=1))

                # Variance term: hinge [d / Lreg - delta_v]_+^2 on each pixel's
                # distance d to its cluster mean, averaged per cluster, then
                # averaged over clusters.
                T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X),
                            axis=1,
                            ord=1)
                T = tf.divide(T, Lreg)
                T = tf.subtract(T, delta_v)
                T = tf.maximum(T, 0.0)
                T = tf.square(T)

                pix_per_cluster = tf.unsorted_segment_sum(
                    tf.ones_like(uniqueInd, dtype=tf.float32), uniqueInd,
                    numUnique)
                clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
                clusterSigma = tf.divide(clusterSigma, pix_per_cluster)
                Lvar = tf.reduce_mean(clusterSigma)

                # All ordered pairs of cluster means: row i*K + j = mu[i] - mu[j].
                mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
                mu_band_rep = tf.tile(mu, [1, numUnique])
                mu_band_rep = tf.reshape(mu_band_rep,
                                         (numUnique * numUnique, nDim))
                mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)

                # Drop all-zero rows (e.g. the i == j pairs). The mask must be a
                # per-row vector of length K*K. Comparing against a (1, 1)
                # zeros tensor broadcast it to (1, K*K), and boolean_mask then
                # applied it to the flattened values, selecting scalars at the
                # row indices instead of whole rows.
                row_l1 = tf.reduce_sum(tf.abs(mu_diff), 1)
                nonzero_rows = tf.boolean_mask(mu_diff,
                                               tf.not_equal(row_l1, 0.0))
                mu_diff = tf.expand_dims(nonzero_rows, axis=1)
                print(mu_diff)
                # NOTE(review): with no axis this is ONE L1 norm over all pairs
                # combined, so the hinge below acts on a single scalar rather
                # than per pair of clusters -- confirm this is intended.
                mu_diff = tf.norm(mu_diff, ord=1)
                mu_diff = tf.divide(mu_diff, Lreg)

                # Distance term: hinge [2*delta_d - d]_+^2.
                mu_diff = tf.subtract(2 * delta_d, mu_diff)
                mu_diff = tf.maximum(mu_diff, 0.0)
                mu_diff = tf.square(mu_diff)
                Ldist = tf.reduce_mean(mu_diff)

                L = tf.reduce_sum([alpha * Lvar, beta * Ldist, gamma * Lreg])
                print(L)
                print(Ldist)
                print(Lvar)
                print(Lreg)
                return tf.identity(L, name=name)

            discrim_il = regDLF(toMaxLabels(pl, factor=MAX_LABEL),
                                tf.concat([feat_il, feat_ia], axis=-1),
                                name='discrim_il')

        with tf.name_scope('recon_loss'):
            recon_il = tf.reduce_mean(tf.abs(pl - pil_), name='recon_il')

        with tf.name_scope('affnt_loss'):
            # Binary cross-entropy minus the Jaccard-style dice_coe.
            affnt_il = tf.reduce_mean(
                tf.subtract(
                    binary_cross_entropy(pa, pia_),
                    dice_coe(pa, pia_, axis=[0, 1, 2, 3],
                             loss_type='jaccard')))

        with tf.name_scope('residual_loss'):
            # Fraction of elements where each prediction disagrees with its
            # converted counterpart.
            # NOTE(review): tf.cast of a boolean has no gradient, so these
            # terms do not train the generators despite their weight in
            # g_loss -- confirm intended.
            residual_a = tf.reduce_mean(tf.cast(tf.not_equal(pia, pila),
                                                tf.float32),
                                        name='residual_a')
            residual_l = tf.reduce_mean(tf.cast(tf.not_equal(pil, pial),
                                                tf.float32),
                                        name='residual_l')
            residual_il = tf.reduce_mean([residual_a, residual_l],
                                         name='residual_il')

        def label_imag(y_pred_L, name='label_imag'):
            # Binary edge map (tanh range) of pixels whose central-difference
            # gradient magnitude is > 0; used only for visualisation.
            mag_grad_L = magnitute_central_difference(y_pred_L,
                                                      name='mag_grad_L')
            cond = tf.greater(mag_grad_L, tf.zeros_like(mag_grad_L))
            thresholded_mag_grad_L = tf.where(cond,
                                              tf.ones_like(mag_grad_L),
                                              tf.zeros_like(mag_grad_L),
                                              name='thresholded_mag_grad_L')
            return tf_2tanh(thresholded_mag_grad_L, maxVal=1.0)

        g_il = label_imag(pil_, name='label_il')
        g_l = label_imag(pl, name='label_l')

        self.g_loss = tf.reduce_sum([
            1 * (G_loss),
            10 * (recon_il),
            10 * (residual_il),
            1 * (rand_il),
            50 * (discrim_il),
            0.005 * affnt_il,
        ],
                                    name='G_loss_total')
        self.d_loss = tf.reduce_sum([D_loss], name='D_loss_total')
        wd_g = regularize_cost('gen/.*/W',
                               l2_regularizer(1e-5),
                               name='G_regularize')
        wd_d = regularize_cost('discrim/.*/W',
                               l2_regularizer(1e-5),
                               name='D_regularize')

        self.g_loss = tf.add(self.g_loss, wd_g, name='g_loss')
        self.d_loss = tf.add(self.d_loss, wd_d, name='d_loss')

        self.collect_variables()

        add_moving_summary(self.d_loss, self.g_loss)
        with tf.name_scope('summaries'):
            add_tensor_summary(recon_il, types=['scalar'], name='recon_il')
            add_tensor_summary(rand_il, types=['scalar'], name='rand_il')
            add_tensor_summary(discrim_il, types=['scalar'], name='discrim_il')
            add_tensor_summary(affnt_il, types=['scalar'], name='affnt_il')
            add_tensor_summary(residual_il,
                               types=['scalar'],
                               name='residual_il')

        # Segmentation panels: two rows (axis 1) of four panels (axis 2).
        viz = tf.concat([
            tf.concat([pi, pl, g_l, g_il], 2),
            tf.concat([tf.zeros_like(pi), pil, pial, pil_], 2),
        ], 1)

        viz = tf_2imag(viz)
        viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
        tf.summary.image('colorized', viz, max_outputs=50)

        # Affinity panels.
        vis = tf.concat([pa, pila, pia, pia_], 2)

        vis = tf_2imag(vis)
        vis = tf.cast(tf.clip_by_value(vis, 0, 255), tf.uint8, name='vis')
        tf.summary.image('affinities', vis, max_outputs=50)
示例#6
0
    def _build_graph(self, inputs):
        """Build the AdaIN style-transfer graph and its training cost.

        Content and style images are encoded by a frozen VGG19 (up to
        conv4_1), fused by ``self._build_adain_layers``, and decoded back to
        an image ``paint``. ``paint`` and ``style`` are then re-encoded by a
        second frozen VGG19 copy (its own variable scope) to compute the
        content, style and total-variation losses.

        Args:
            inputs: ``(image, style)`` image tensors. Both have the VGG19 RGB
                channel means subtracted, so they presumably hold RGB values
                in [0, 255] -- TODO confirm against the dataflow.

        Side effects:
            Sets ``self.cost`` (the summed losses) and creates the named
            tensors ``paint`` and ``viz`` plus scalar/image summaries. Loss
            weights are read from the module-level ``args``
            (``weight_c``, ``weight_s``, ``weight_tv``).
        """
        VGG19_MEAN = np.array([123.68, 116.779, 103.939])  # RGB
        VGG19_MEAN_TENSOR = tf.constant(VGG19_MEAN, dtype=tf.float32)

        image, style = inputs  # Split the input

        def _vgg19_trunk(source):
            """Frozen VGG19 conv stack; must be called inside a variable scope.

            Returns ``conv4_1`` and the feature taps
            ``[conv1_1, conv2_1, conv3_1, conv4_1]``. Layers after conv4_1 do
            not feed the outputs but are still built so the set of variables
            is unchanged (presumably to match the pretrained VGG19
            checkpoint -- TODO confirm).
            """
            with varreplace.freeze_variables():
                with argscope([Conv2D], kernel_shape=3, nl=tf.nn.relu):
                    source = source - VGG19_MEAN_TENSOR
                    conv1_1 = Conv2D('conv1_1', source, 64)
                    conv1_2 = Conv2D('conv1_2', conv1_1, 64)
                    pool1 = MaxPooling('pool1', conv1_2, 2)  # 64
                    conv2_1 = Conv2D('conv2_1', pool1, 128)
                    conv2_2 = Conv2D('conv2_2', conv2_1, 128)
                    pool2 = MaxPooling('pool2', conv2_2, 2)  # 32
                    conv3_1 = Conv2D('conv3_1', pool2, 256)
                    conv3_2 = Conv2D('conv3_2', conv3_1, 256)
                    conv3_3 = Conv2D('conv3_3', conv3_2, 256)
                    conv3_4 = Conv2D('conv3_4', conv3_3, 256)
                    pool3 = MaxPooling('pool3', conv3_4, 2)  # 16
                    conv4_1 = Conv2D('conv4_1', pool3, 512)
                    conv4_2 = Conv2D('conv4_2', conv4_1, 512)
                    conv4_3 = Conv2D('conv4_3', conv4_2, 512)
                    conv4_4 = Conv2D('conv4_4', conv4_3, 512)
                    pool4 = MaxPooling('pool4', conv4_4, 2)  # 8
                    conv5_1 = Conv2D('conv5_1', pool4, 512)
                    conv5_2 = Conv2D('conv5_2', conv5_1, 512)
                    conv5_3 = Conv2D('conv5_3', conv5_2, 512)
                    conv5_4 = Conv2D('conv5_4', conv5_3, 512)
                    MaxPooling('pool5', conv5_4, 2)  # 4
                    return conv4_1, [conv1_1, conv2_1, conv3_1, conv4_1]

        @auto_reuse_variable_scope
        def vgg19_encoder(source, name='VGG19_Encoder'):
            """Frozen VGG19 encoder f(.) feeding the AdaIN block."""
            with tf.variable_scope(name):
                return _vgg19_trunk(source)

        @auto_reuse_variable_scope
        def vgg19_decoder(source, name='VGG19_Decoder'):
            """Trainable decoder g(.) from conv4_1 features back to RGB.

            Mirrors the encoder's conv3..conv1 stages with three ``Subpix2D``
            upsampling steps, then adds the VGG19 means back onto the
            3-channel output.
            """
            with tf.variable_scope(name):
                with argscope([Conv2D], kernel_shape=3, nl=tf.nn.elu):
                    with argscope([Deconv2D],
                                  kernel_shape=3,
                                  strides=(2, 2),
                                  nl=tf.nn.elu):
                        # The decoder starts from conv4_1 features, so there
                        # are no conv5_x / conv4_x mirror stages.
                        pool3 = Subpix2D('pool3', source, 256)  # 16
                        conv3_4 = Conv2D('conv3_4', pool3, 256)
                        conv3_3 = Conv2D('conv3_3', conv3_4, 256)
                        conv3_2 = Conv2D('conv3_2', conv3_3, 256)
                        conv3_1 = Conv2D('conv3_1', conv3_2, 256)
                        pool2 = Subpix2D('pool2', conv3_1, 128)  # 32
                        conv2_2 = Conv2D('conv2_2', pool2, 128)
                        conv2_1 = Conv2D('conv2_1', conv2_2, 128)
                        pool1 = Subpix2D('pool1', conv2_1, 64)  # 64
                        conv1_2 = Conv2D('conv1_2', pool1, 64)
                        conv1_1 = Conv2D('conv1_1', conv1_2, 64)
                        # NOTE: conv1_0 also inherits nl=tf.nn.elu from the
                        # argscope above.
                        conv1_0 = Conv2D('conv1_0', conv1_1, 3)
                        conv1_0 = conv1_0 + VGG19_MEAN_TENSOR
                        return conv1_0

        @auto_reuse_variable_scope
        def vgg19_feature(source, name='VGG19_Feature'):
            """Same frozen VGG19 as the encoder, in its own scope, for the losses."""
            with tf.variable_scope(name):
                return _vgg19_trunk(source)

        # Step 1: Run thru the encoder
        image_encoded, _ = vgg19_encoder(image)
        style_encoded, _ = vgg19_encoder(style)

        # Step 2: Run thru the adain block to get t=AdIN(f(c), f(s))
        merge_encoded = self._build_adain_layers(image_encoded, style_encoded)

        # Step 3: Run thru the decoder to get the paint image
        paint = vgg19_decoder(merge_encoded)

        # vgg19_feature is the same network as vgg19_encoder; keeping it in a
        # separate scope separates the loss network from the encoder.
        paint_encoded, paint_feature = vgg19_feature(paint)
        _, style_feature = vgg19_feature(style)

        with tf.name_scope('losses'):
            losses = []
            # Content loss between t and f(g(t))
            content_loss = self._build_content_loss(merge_encoded,
                                                    paint_encoded,
                                                    weight=args.weight_c)
            add_tensor_summary(content_loss,
                               types=['scalar'],
                               name='content_loss')
            losses.append(content_loss)

            # Style losses between paint and style, one per VGG feature tap.
            style_losses = self._build_style_losses(paint_feature,
                                                    style_feature,
                                                    weight=args.weight_s)
            for idx, style_loss in enumerate(style_losses):
                # One tag per layer instead of reusing 'style_loss' for all.
                add_tensor_summary(style_loss,
                                   types=['scalar'],
                                   name='style_loss_{}'.format(idx))
                losses.append(style_loss)

            # Total variation loss
            smoothness = tf.reduce_sum(tf.image.total_variation(paint))
            add_tensor_summary(smoothness, types=['scalar'], name='smoothness')
            losses.append(smoothness * args.weight_tv)

            # Total loss; this one goes to the optimizer.
            self.cost = tf.reduce_sum(losses, name='self.cost')
            add_tensor_summary(self.cost, types=['scalar'], name='self.cost')

        paint = tf.identity(paint, name='paint')

        # Visualization: content | style | paint concatenated along axis 2.
        viz = tf.concat([image, style, paint], axis=2)
        viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
        tf.summary.image('colorized', viz, max_outputs=50)
示例#7
0
文件: train.py 项目: hakillha/maria03
    def build_graph(self, *inputs):
        """Build the Faster R-CNN (ResNet-C4) graph with an extra re-ID head.

        In training, returns ``total_cost``. This is the sum of:
          * the RPN losses;
          * the Fast R-CNN losses;
          * the re-ID classification loss, divided by
            ``cfg.RE_ID.LOSS_NORMALIZATION``;
          * weight decay.
        In inference, returns nothing. Results are exposed through the named
        tensors ``rescaled_final_boxes``, ``feature_vector`` and
        ``re_id_probs``.

        Args:
            *inputs: ``(image, anchor_labels, anchor_boxes, gt_boxes,
                gt_labels, gt_ids, orig_shape)``. ``image`` goes through
                ``self.preprocess`` (commented as producing 1CHW). In
                ``gt_ids``, id 1 marks unidentified pedestrians, which are
                excluded from the re-ID loss.
        """
        is_training = get_current_tower_context().is_training
        image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_ids, orig_shape = inputs
        image = self.preprocess(image)  # 1CHW

        featuremap = resnet_c4_backbone(image,
                                        cfg.BACKBONE.RESNET_NUM_BLOCK[:3])
        rpn_label_logits, rpn_box_logits = rpn_head('rpn', featuremap,
                                                    cfg.RPN.HEAD_DIM,
                                                    cfg.RPN.NUM_ANCHOR)

        anchors = RPNAnchors(get_all_anchors(), anchor_labels, anchor_boxes)
        anchors = anchors.narrow_to(featuremap)

        image_shape2d = tf.shape(image)[2:]  # h,w
        # decode into actual image coordinates
        pred_boxes_decoded = anchors.decode_logits(
            rpn_box_logits)  # fHxfWxNAx4, floatbox
        if is_training:
            pre_nms_topk = cfg.RPN.TRAIN_PRE_NMS_TOPK
            post_nms_topk = cfg.RPN.TRAIN_POST_NMS_TOPK
        else:
            pre_nms_topk = cfg.RPN.TEST_PRE_NMS_TOPK
            post_nms_topk = cfg.RPN.TEST_POST_NMS_TOPK
        proposal_boxes, proposal_scores = generate_rpn_proposals(
            tf.reshape(pred_boxes_decoded, [-1, 4]),
            tf.reshape(rpn_label_logits, [-1]), image_shape2d,
            pre_nms_topk, post_nms_topk)

        if is_training:
            # sample proposal boxes in training
            rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
                proposal_boxes, gt_boxes, gt_labels)
        else:
            # The boxes to be used to crop RoIs.
            # Use all proposal boxes in inference
            rcnn_boxes = proposal_boxes

        boxes_on_featuremap = rcnn_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE)
        roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

        feature_fastrcnn = resnet_conv5(
            roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1])  # nxcx7x7
        # Keep C5 feature to be shared with mask branch
        feature_gap = GlobalAvgPooling('gap',
                                       feature_fastrcnn,
                                       data_format='channels_first')
        fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
            'fastrcnn', feature_gap, cfg.DATA.NUM_CLASS)

        if is_training:
            # rpn loss
            rpn_label_loss, rpn_box_loss = rpn_losses(
                anchors.gt_labels, anchors.encoded_gt_boxes(),
                rpn_label_logits, rpn_box_logits)

            # fastrcnn loss
            matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)

            fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0),
                                            [-1])  # fg inds w.r.t all samples
            # outputs from fg proposals
            fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
            fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits,
                                               fg_inds_wrt_sample)

            # rcnn_labels: the labels of the proposals
            # fg_sampled_boxes: fg proposals
            # matched_gt_boxes: just like RPN, the gt boxes
            #                   that match the corresponding fg proposals
            fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
                image, rcnn_labels, fg_sampled_boxes, matched_gt_boxes,
                fastrcnn_label_logits, fg_fastrcnn_box_logits)

            # Detections used to train the re-ID branch. Turning NMS off
            # gives the re-ID branch more training samples.
            if cfg.RE_ID.NMS:
                boxes, _, _ = self.fastrcnn_inference(
                    image_shape2d, rcnn_boxes, fastrcnn_label_logits,
                    fastrcnn_box_logits)
            else:
                boxes, _, _ = self.fastrcnn_inference_id(
                    image_shape2d, rcnn_boxes, fastrcnn_label_logits,
                    fastrcnn_box_logits)

            # Keep detections whose best IoU with any gt box reaches
            # cfg.RE_ID.IOU_THRESH. These become the re-ID training samples.
            # NOTE(review): confirm gt_boxes are in the same (resized) image
            # frame as `boxes`.
            iou = pairwise_iou(boxes, gt_boxes)
            tp_mask = tf.reduce_max(iou, axis=1) >= cfg.RE_ID.IOU_THRESH
            iou = tf.boolean_mask(iou, tp_mask)

            def compute_re_id_loss(pred_boxes, pred_matching_gt_ids,
                                   featuremap):
                """Run the re-ID head on `pred_boxes`.

                Returns ``(label_loss, num_of_samples_used)``. The second value
                comes from a non-trainable counter variable that is
                incremented by the number of boxes whenever this branch runs.
                """
                with tf.variable_scope('id_head'):
                    num_of_samples_used = tf.get_variable(
                        'num_of_samples_used', initializer=0, trainable=False)
                    num_of_samples_used = num_of_samples_used.assign_add(
                        tf.shape(pred_boxes)[0])

                    boxes_on_featuremap = pred_boxes * (1.0 /
                                                        cfg.RPN.ANCHOR_STRIDE)
                    roi_resized = roi_align(featuremap, boxes_on_featuremap,
                                            14)
                    feature_idhead = resnet_conv5(
                        roi_resized,
                        cfg.BACKBONE.RESNET_NUM_BLOCK[-1])  # nxcx7x7
                    feature_gap = GlobalAvgPooling(
                        'gap', feature_idhead, data_format='channels_first')

                    init = tf.variance_scaling_initializer()
                    hidden = FullyConnected('fc6',
                                            feature_gap,
                                            1024,
                                            kernel_initializer=init,
                                            activation=tf.nn.relu)
                    hidden = FullyConnected('fc7',
                                            hidden,
                                            1024,
                                            kernel_initializer=init,
                                            activation=tf.nn.relu)
                    hidden = FullyConnected('fc8',
                                            hidden,
                                            256,
                                            kernel_initializer=init,
                                            activation=tf.nn.relu)
                    id_logits = FullyConnected(
                        'class',
                        hidden,
                        cfg.DATA.NUM_ID,
                        kernel_initializer=tf.random_normal_initializer(
                            stddev=0.01))

                label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=pred_matching_gt_ids, logits=id_logits)
                label_loss = tf.reduce_mean(label_loss, name='label_loss')

                return label_loss, num_of_samples_used

            def check_unid_pedes(iou, gt_ids, boxes, tp_mask, featuremap):
                """Drop detections matched to unidentified pedestrians.

                Computes the re-ID loss on the remaining detections. If none
                remain, falls back to ``(cfg.RE_ID.STABLE_LOSS, 0)``.
                """
                pred_gt_ind = tf.argmax(iou, axis=1)
                # Re-ID id of the best-matching gt box for each kept detection.
                pred_matching_gt_ids = tf.gather(gt_ids, pred_gt_ind)
                pred_boxes = tf.boolean_mask(boxes, tp_mask)
                # id 1 corresponds to unidentified pedestrians; drop them.
                known_id_mask = tf.not_equal(pred_matching_gt_ids, 1)
                pred_matching_gt_ids = tf.boolean_mask(pred_matching_gt_ids,
                                                       known_id_mask)
                pred_boxes = tf.boolean_mask(pred_boxes, known_id_mask)

                return tf.cond(
                    tf.equal(tf.size(pred_boxes), 0),
                    lambda: (tf.constant(cfg.RE_ID.STABLE_LOSS),
                             tf.constant(0)),
                    lambda: compute_re_id_loss(pred_boxes,
                                               pred_matching_gt_ids,
                                               featuremap))

            with tf.name_scope('id_head'):
                # If no detection passed the IoU filter, the re-ID loss is the
                # constant cfg.RE_ID.STABLE_LOSS and no samples are counted.
                re_id_loss, num_of_samples_used = tf.cond(
                    tf.equal(tf.size(iou), 0),
                    lambda: (tf.constant(cfg.RE_ID.STABLE_LOSS),
                             tf.constant(0)),
                    lambda: check_unid_pedes(iou, gt_ids, boxes, tp_mask,
                                             featuremap))
                add_tensor_summary(num_of_samples_used, ['scalar'],
                                   name='num_of_samples_used')

            unnormed_id_loss = tf.identity(re_id_loss, name='unnormed_id_loss')
            re_id_loss = tf.divide(re_id_loss, cfg.RE_ID.LOSS_NORMALIZATION,
                                   're_id_loss')
            add_moving_summary(unnormed_id_loss)
            add_moving_summary(re_id_loss)

            wd_cost = regularize_cost('.*/W',
                                      l2_regularizer(cfg.TRAIN.WEIGHT_DECAY),
                                      name='wd_cost')

            # NOTE: all loss terms are currently weighted equally.
            total_cost = tf.add_n([
                rpn_label_loss, rpn_box_loss, fastrcnn_label_loss,
                fastrcnn_box_loss, re_id_loss, wd_cost
            ], 'total_cost')

            add_moving_summary(total_cost, wd_cost)
            return total_cost
        else:
            if cfg.RE_ID.QUERY_EVAL:
                # Query evaluation: embed the given gt boxes instead of
                # detections (presumably already resized in the dataflow).
                final_boxes = gt_boxes
            else:
                final_boxes, _, _ = self.fastrcnn_inference(
                    image_shape2d, rcnn_boxes, fastrcnn_label_logits,
                    fastrcnn_box_logits)

            # Same scope and layer names as compute_re_id_loss, so the same
            # 'id_head' variables are used at inference.
            with tf.variable_scope('id_head'):
                preds_on_featuremap = final_boxes * (1.0 /
                                                     cfg.RPN.ANCHOR_STRIDE)
                roi_resized = roi_align(featuremap, preds_on_featuremap, 14)
                feature_idhead = resnet_conv5(
                    roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1])  # nxcx7x7
                feature_gap = GlobalAvgPooling('gap',
                                               feature_idhead,
                                               data_format='channels_first')

                hidden = FullyConnected('fc6',
                                        feature_gap,
                                        1024,
                                        activation=tf.nn.relu)
                hidden = FullyConnected('fc7',
                                        hidden,
                                        1024,
                                        activation=tf.nn.relu)
                fv = FullyConnected('fc8', hidden, 256, activation=tf.nn.relu)
                id_logits = FullyConnected(
                    'class',
                    fv,
                    cfg.DATA.NUM_ID,
                    kernel_initializer=tf.random_normal_initializer(
                        stddev=0.01))

            # Map boxes from the preprocessed image back to the original
            # image, using the geometric mean of the h/w resize ratios.
            scale = tf.sqrt(
                tf.cast(image_shape2d[0], tf.float32) /
                tf.cast(orig_shape[0], tf.float32) *
                tf.cast(image_shape2d[1], tf.float32) /
                tf.cast(orig_shape[1], tf.float32))
            rescaled_final_boxes = final_boxes / scale
            # boxes are already clipped inside the graph, but after the
            # floating point scaling, this may not be true any more.
            rescaled_final_boxes = tf_clip_boxes(rescaled_final_boxes,
                                                 orig_shape)
            rescaled_final_boxes = tf.identity(rescaled_final_boxes,
                                               'rescaled_final_boxes')

            fv = tf.identity(fv, name='feature_vector')
            prob = tf.nn.softmax(id_logits, name='re_id_probs')
    def _build_graph(self, inputs):
        """Build the 3D encoder-decoder that predicts affinity maps.

        Args:
            inputs: ``(volume, gt_seg, gt_affs)``. ``volume`` gets a trailing
                channel axis and a leading batch axis of size 1 before entering
                the network. It is presumably a single (D, H, W) volume --
                TODO confirm against the dataflow.

        Side effects:
            Sets ``self.cost`` to the total loss: the WBCE+MALIS loss plus
            0.1 * the Jaccard dice loss. Adds a scalar summary for each term.
            Creates the named tensors ``logits`` (batch axis dropped, channel
            axis moved first by the transpose) and ``affs``.
        """
        volume, gt_seg, gt_affs = inputs
        volume = tf.expand_dims(volume, 3)
        volume = tf.expand_dims(volume, 0)

        with argscope(LeakyReLU, alpha=0.2), \
                argscope([Conv3D, DeConv3D],
                         use_bias=False,
                         kernel_shape=3,
                         stride=2,
                         padding='SAME',
                         W_init=tf.contrib.layers.variance_scaling_initializer(
                             factor=0.333, uniform=True)):
            # NOTE(review): `_in` is built (so its variables exist) but never
            # consumed; e0 reads `volume` directly. Confirm whether e0 was
            # meant to take `_in`. Changing it would alter the model and its
            # checkpoint shapes, so it is left as-is.
            _in = Conv3D('in',
                         volume,
                         NB_FILTERS,
                         kernel_shape=(3, 5, 5),
                         stride=1,
                         padding='SAME',
                         nl=INELU,
                         use_bias=True)
            e0 = residual_enc('e0', volume, NB_FILTERS * 1)
            e1 = residual_enc('e1', e0, NB_FILTERS * 2)
            e2 = residual_enc('e2', e1, NB_FILTERS * 4, kernel_shape=(1, 3, 3))
            e3 = residual_enc('e3', e2, NB_FILTERS * 8, kernel_shape=(1, 3, 3))

            e3 = Dropout('dr', e3, rate=0.5)

            # NOTE(review): there is no d4 stage, so the bottleneck input is
            # `e3 + e3` (i.e. 2 * e3). Kept to preserve trained behaviour.
            d3 = residual_dec('d3',
                              e3 + e3,
                              NB_FILTERS * 4,
                              kernel_shape=(1, 3, 3))
            d2 = residual_dec('d2',
                              d3 + e2,
                              NB_FILTERS * 2,
                              kernel_shape=(1, 3, 3))
            d1 = residual_dec('d1', d2 + e1, NB_FILTERS * 1)
            d0 = residual_dec('d0', d1 + e0, NB_FILTERS * 1)

            logits = funcs.Conv3D('x_out',
                                  d0,
                                  len(nhood),
                                  kernel_shape=(3, 5, 5),
                                  stride=1,
                                  padding='SAME',
                                  nl=tf.identity,
                                  use_bias=True)

        # Drop only the size-1 batch axis added above. A bare tf.squeeze would
        # also drop any spatial or channel axis of extent 1 (e.g. a
        # single-slice volume or a one-entry nhood), breaking the rank-4
        # transpose below.
        logits = tf.squeeze(logits, axis=0)
        logits = tf.transpose(logits, perm=[3, 0, 1, 2], name='logits')

        affs = cvt2sigm(tf.tanh(logits))
        affs = tf.identity(affs, name='affs')

        wbce_malis_loss = wbce_malis(logits,
                                     affs,
                                     gt_affs,
                                     gt_seg,
                                     nhood,
                                     affs_shape,
                                     name='wbce_malis',
                                     limit_z=False)
        wbce_malis_loss = tf.identity(wbce_malis_loss, name='wbce_malis_loss')

        # Jaccard-style dice loss, down-weighted by 0.1.
        dice_loss = tf.identity(
            (1. -
             dice_coe(affs, gt_affs, axis=[0, 1, 2, 3], loss_type='jaccard')) *
            0.1,
            name='dice_coe')
        tot_loss = tf.identity(wbce_malis_loss + dice_loss, name='tot_loss')

        self.cost = tot_loss
        summary.add_tensor_summary(tot_loss, types=['scalar'])
        summary.add_tensor_summary(dice_loss, types=['scalar'])
        summary.add_tensor_summary(wbce_malis_loss, types=['scalar'])
    def _build_graph(self, inputs):
        G = tf.get_default_graph()  # For round
        tf.local_variables_initializer()
        tf.global_variables_initializer()
        pi, pm, pl, ui, um, ul = inputs
        pi = cvt2tanh(pi)
        pm = cvt2tanh(pm)
        pl = cvt2tanh(pl)
        ui = cvt2tanh(ui)
        um = cvt2tanh(um)
        ul = cvt2tanh(ul)

        # def tf_membr(label):
        # 	with freeze_variables():
        # 		label = np_2imag(label, maxVal=MAX_LABEL)
        # 		label = np.squeeze(label) # Unimplemented: exceptions.NotImplementedError: Only for images of dimension 1-3 are supported, got a 4D one
        # 		# label, nb_labels = skimage.measure.label(color, return_num=True)
        # 		# label = np.expand_dims(label, axis=-1).astype(np.float32) # Modify here for batch
        # 		# for z in range(membr.shape[0]):
        # 		# 	membr[z,...] = 1-skimage.segmentation.find_boundaries(np.squeeze(label[z,...]), mode='thick') #, mode='inner'
        # 		membr = 1-skimage.segmentation.find_boundaries(np.squeeze(label), mode='thick') #, mode='inner'
        # 		membr = np.expand_dims(membr, axis=-1).astype(np.float32)
        # 		membr = np.expand_dims(membr, axis=0).astype(np.float32)
        # 		membr = np_2tanh(membr, maxVal=1.0)
        # 		membr = np.reshape(membr, label.shape)
        # 		return membr

        # def tf_label(color):
        # 	with freeze_variables():
        # 		color = np_2imag(color, maxVal=MAX_LABEL)
        # 		color = np.squeeze(color) # Unimplemented: exceptions.NotImplementedError: Only for images of dimension 1-3 are supported, got a 4D one
        # 		label, nb_labels = skimage.measure.label(color, return_num=True)
        # 		label = np.expand_dims(label, axis=-1).astype(np.float32)
        # 		label = np.expand_dims(label, axis=0).astype(np.float32)
        # 		label = np_2tanh(label, maxVal=MAX_LABEL)
        # 		label = np.reshape(label, color.shape)
        # 		return label

        def tf_rand_score(x1, x2):
            return np.mean(1.0 -
                           adjusted_rand_score(x1.flatten(), x2.flatten()))

        def rounded(label, factor=MAX_LABEL, name='quantized'):
            with G.gradient_override_map({"Round": "Identity"}):
                with freeze_variables():
                    with tf.name_scope(name=name):
                        # label = cvt2imag(label, maxVal=factor)
                        # label = tf.round(label)
                        # label = cvt2tanh(label, maxVal=factor)
                        # cvt from -1 ~ 1 to 0 255
                        # label = cvt2imag(label, maxVal=255.0)
                        cond0 = tf.equal(label, -1.0 * tf.ones_like(label))
                        label = tf.where(
                            cond0,
                            tf.zeros_like(label),
                            label,
                            name='removedBackground')  # From -1 to 0
                        label = label * factor  # From 0~1 to 0~MAXLABEL
                        label = tf.round(label)
                        label = label / factor  # From 0~MAXLABEL to 0~1
                        cond1 = tf.equal(label, 0.0 * tf.zeros_like(label))
                        label = tf.where(
                            cond1,
                            -1.0 * tf.ones_like(label),
                            label,
                            name='addedBackground')  # From -1 to 0
                    return tf.identity(label, name=name)


        with argscope([Conv2D, Deconv2D, FullyConnected],
             W_init=tf.truncated_normal_initializer(stddev=0.02),
             use_bias=False), \
          argscope(BatchNorm, gamma_init=tf.random_uniform_initializer()), \
          argscope([Conv2D, Deconv2D, BatchNorm], data_format='NHWC'), \
          argscope(LeakyReLU, alpha=0.2):

            with tf.variable_scope('gen'):
                # Real pair image 4 gen
                with tf.variable_scope('I2M'):
                    pim, feat_im = self.generator(pi)
                with tf.variable_scope('M2L'):
                    piml, feat_iml = self.generator(pim)
                    pml, feat_ml = self.generator(pm)
                    # piml  = tf.py_func(tf_label, [(pim)], tf.float32)
                    # pml   = tf.py_func(tf_label, [(pm)], tf.float32)
                    # print pim
                    # print piml
                # with tf.variable_scope('L2M'):
                # # with freeze_variables():
                # 	pimlm = self.generator(piml) #
                # 	plm   = self.generator(pl)
                # 	pmlm  = self.generator(pml)
                # 	# pimlm = tf.py_func(tf_membr, [(piml)], tf.float32) #
                # 	# plm   = tf.py_func(tf_membr, [(pl)	], tf.float32)
                # 	# pmlm  = tf.py_func(tf_membr, [(pml)	], tf.float32)
                # 	# print piml
                # 	# print pimlm
                # with tf.variable_scope('M2I'):
                # 	pimlmi = self.generator(pimlm) #
                # 	pimi   = self.generator(pim)

                # # Real pair label 4 gen
                # with tf.variable_scope('L2M'):
                # # with freeze_variables():
                # 	plm = self.generator(pl)
                # 	# plm  = tf.py_func(tf_membr, [(pl)	, tf.float32])
                # with tf.variable_scope('M2I'):
                # 	plmi = self.generator(plm)
                # 	pmi  = self.generator(pi)
                # with tf.variable_scope('I2M'):
                # 	plmim = self.generator(plmi) #
                # 	pim   = self.generator(pi)
                # 	pmim  = self.generator(pmi)

                # with tf.variable_scope('M2L'):
                # 	plmiml = self.generator(plmim) #
                # 	plml   = self.generator(plm)
                # 	# plmiml = tf.py_func(tf_label, [(plmim)], tf.float32)
                # 	# plml   = tf.py_func(tf_label, [(plm)], tf.float32)

            with tf.variable_scope('discrim'):
                # with tf.variable_scope('I'):
                # 	i_dis_real 			  = self.discriminator(ui)
                # 	i_dis_fake_from_label = self.discriminator(plmi)
                with tf.variable_scope('M'):
                    m_dis_real = self.discriminator(um)
                    m_dis_fake_from_image = self.discriminator(pim)
                    # m_dis_fake_from_label = self.discriminator(plm)
                with tf.variable_scope('L'):
                    l_dis_real = self.discriminator(ul)
                    l_dis_fake_from_image = self.discriminator(piml)

        piml = rounded(piml)  #
        pml = rounded(pml)
        # plmiml = rounded(plmiml) #
        # plml   = rounded(plml)

        # with tf.name_scope('Recon_I_loss'):
        # 	recon_imi 		= tf.reduce_mean(tf.abs((pi) - (pimi)), name='recon_imi')
        # 	recon_lmi 		= tf.reduce_mean(tf.abs((pi) - (plmi)), name='recon_lmi')
        # 	recon_imlmi 	= tf.reduce_mean(tf.abs((pi) - (pimlmi)), name='recon_imlmi') #

        with tf.name_scope('Recon_L_loss'):
            # recon_lml 		= tf.reduce_mean(tf.abs((pl) - (plml)), name='recon_lml')
            # recon_iml 		= tf.reduce_mean(tf.abs((pl) - (piml)), name='recon_iml')
            # recon_ml 		= tf.reduce_mean(tf.abs((pl) - (pml)), name='recon_ml')
            # recon_lmiml 	= tf.reduce_mean(tf.abs((pl) - (plmiml)), name='recon_lmiml') #
            # recon_iml 		= tf.reduce_mean(tf.cast(tf.not_equal(pl, piml), tf.float32), name='recon_iml')
            # recon_ml 		= tf.reduce_mean(tf.cast(tf.not_equal(pl, pml), tf.float32), name='recon_ml')
            recon_ml = tf.reduce_mean(tf.abs(
                cvt2imag(pl, maxVal=MAX_LABEL) -
                cvt2imag(pml, maxVal=MAX_LABEL)),
                                      name='recon_ml')
            recon_iml = tf.reduce_mean(tf.abs(
                cvt2imag(pl, maxVal=MAX_LABEL) -
                cvt2imag(piml, maxVal=MAX_LABEL)),
                                       name='recon_iml')  #

        with tf.name_scope('Recon_M_loss'):
            # recon_mim 		= tf.reduce_mean(tf.abs((pm) - (pmim)), name='recon_mim')
            # recon_mlm 		= tf.reduce_mean(tf.abs((pm) - (pmlm)), name='recon_mlm')

            # recon_im 		= tf.reduce_mean(tf.abs((pm) - (pim)), name='recon_im')
            # recon_lm 		= tf.reduce_mean(tf.abs((pm) - (plm)), name='recon_lm')
            # recon_im 		= tf.reduce_mean(tf.cast(tf.not_equal(pm, pim), tf.float32), name='recon_im')
            recon_im = tf.reduce_mean(
                tf.abs(cvt2imag(pm, maxVal=1.0) - cvt2imag(pim, maxVal=1.0)),
                name='recon_im')

        with tf.name_scope('GAN_loss'):
            # G_loss_IL, D_loss_IL = self.build_losses(i_dis_real, i_dis_fake_from_label, name='IL')
            G_loss_LI, D_loss_LI = self.build_losses(l_dis_real,
                                                     l_dis_fake_from_image,
                                                     name='LL')
            G_loss_MI, D_loss_MI = self.build_losses(m_dis_real,
                                                     m_dis_fake_from_image,
                                                     name='MI')
            # G_loss_ML, D_loss_ML = self.build_losses(m_dis_real, m_dis_fake_from_label, name='ML')

        # custom loss for membr
        with tf.name_scope('membr_loss'):

            def membr_loss(y_true, y_pred, name='membr_loss'):
                """Membrane loss: binary cross-entropy minus a Jaccard score.

                Both tensors are first mapped with ``cvt2imag(..., maxVal=1.0)``.
                The result is ``reduce_mean(BCE - jaccard)``, identity-wrapped
                under ``name``. Presumably ``dice_coe`` returns an overlap
                coefficient, so subtracting it rewards overlap (TODO confirm
                against the helper's definition).
                """
                truth = cvt2imag(y_true, maxVal=1.0)
                guess = cvt2imag(y_pred, maxVal=1.0)
                bce_term = binary_cross_entropy(truth, guess)
                jaccard_term = dice_coe(truth,
                                        guess,
                                        axis=[1, 2, 3],
                                        loss_type='jaccard')
                combined = tf.reduce_mean(tf.subtract(bce_term, jaccard_term))
                return tf.identity(combined, name=name)

            membr_im = membr_loss(pm, pim, name='membr_im')
            # print membr_im
            # membr_lm = membr_loss(pm, plm, name='membr_lm')
            # membr_imlm = membr_loss(pm, pimlm, name='membr_imlm')
            # membr_lmim = membr_loss(pm, plmim, name='membr_lmim')
            # membr_mlm = membr_loss(pm, pmlm, name='membr_mlm')
            # membr_mim = membr_loss(pm, pmim, name='membr_mim')
        # custom loss for label
        with tf.name_scope('label_loss'):

            def label_loss(y_true_L, y_pred_L, y_grad_M, name='label_loss'):
                """Score how much of the predicted label boundary overlaps ``y_grad_M``.

                Args:
                    y_true_L: unused; kept so existing call sites stay valid.
                    y_pred_L: predicted label tensor. Every position where its
                        central-difference gradient magnitude is > 0 is treated
                        as a boundary pixel.
                    y_grad_M: map converted with ``cvt2imag(maxVal=1.0)`` and
                        multiplied by the binary boundary mask.
                    name: name of the returned scalar loss tensor.

                Returns:
                    ``(loss, mask)``. ``loss`` is ``reduce_mean(cvt2imag(y_grad_M)
                    * boundary_mask)``. ``mask`` is the 0/1 boundary mask passed
                    through ``cvt2tanh(maxVal=1.0)`` (used for visualisation).

                NOTE(review): the mask is built with tf.greater/tf.where from
                constant ones/zeros, so no gradient reaches ``y_pred_L``
                through this loss; only ``y_grad_M`` receives gradients.
                """
                g_mag_grad_M = cvt2imag(y_grad_M, maxVal=1.0)
                mag_grad_L = magnitute_central_difference(y_pred_L,
                                                          name='mag_grad_L')
                # Any non-zero gradient magnitude marks a boundary -> 1, else 0.
                cond = tf.greater(mag_grad_L, tf.zeros_like(mag_grad_L))
                thresholded_mag_grad_L = tf.where(
                    cond,
                    tf.ones_like(mag_grad_L),
                    tf.zeros_like(mag_grad_L),
                    name='thresholded_mag_grad_L')

                gtv_guess = tf.multiply(g_mag_grad_M,
                                        thresholded_mag_grad_L,
                                        name='gtv_guess')
                loss_gtv_guess = tf.reduce_mean(gtv_guess,
                                                name='loss_gtv_guess')
                # Rescale the mask for display; the loss above uses the raw 0/1 mask.
                thresholded_mag_grad_L = cvt2tanh(thresholded_mag_grad_L,
                                                  maxVal=1.0)
                return tf.identity(loss_gtv_guess,
                                   name=name), thresholded_mag_grad_L

            label_iml, g_iml = label_loss(None, piml, pim, name='label_iml')
            # label_lml, g_lml = label_loss(None, plml, plm, name='label_lml')
            # label_lmiml, g_lmiml = label_loss(None, plmiml, plmim, name='label_lmiml')
            label_ml, g_ml = label_loss(None, pml, pm, name='label_loss_ml')

        # custom loss for tf_rand_score
        with tf.name_scope('rand_loss'):
            rand_iml = tf.reduce_mean(
                tf.cast(tf.py_func(tf_rand_score, [piml, pl], tf.float64),
                        tf.float32))
            rand_ml = tf.reduce_mean(
                tf.cast(tf.py_func(tf_rand_score, [pml, pl], tf.float64),
                        tf.float32))

        with tf.name_scope('discrim_loss'):

            def regDLF(y_true,
                       y_pred,
                       alpha=1,
                       beta=1,
                       gamma=0.01,
                       delta_v=0.5,
                       delta_d=1.5,
                       name='loss_discrim'):
                """Discriminative clustering loss (pull / push / regularisation).

                Voxels that share a value in ``y_true`` form one cluster. Each
                voxel's feature vector is ``y_pred`` with four coordinate
                channels appended: flat index, z, y and x. Each coordinate is
                rescaled to [0, 255] and then mapped through ``cvt2tanh``.

                Let ``mu`` be the per-cluster mean feature. All distances are
                L1 and divided by ``Lreg``:

                * Lvar: for each cluster, the mean over its voxels of
                  max(|mu_c - x| / Lreg - delta_v, 0) ** 2, then averaged over
                  clusters (pull).
                * Ldist: the mean over ordered pairs of distinct clusters of
                  max(2 * delta_d - |mu_a - mu_b| / Lreg, 0) ** 2; it is 0 when
                  there is only one cluster (push).
                * Lreg: the mean L1 norm of the cluster means.

                Args:
                    y_true: per-voxel cluster ids. Assumed to hold exactly
                        DIMZ * DIMY * DIMX elements laid out row-major as
                        [DIMZ, DIMY, DIMX] (batch 1, one channel). TODO confirm.
                    y_pred: per-voxel embedding whose leading dims match
                        ``y_true``; features on the last axis.
                    alpha, beta, gamma: weights of Lvar, Ldist and Lreg.
                    delta_v: pull margin.
                    delta_d: push margin (pair distances are pushed to 2*delta_d).
                    name: name of the returned tensor.

                Returns:
                    Scalar tensor ``alpha * Lvar + beta * Ldist + gamma * Lreg``.
                """
                num_voxels = DIMZ * DIMY * DIMX

                # Flat voxel index 0..N-1 and its row-major (z, y, x) split.
                # tf.range replaces tf.linspace(0, N, N), whose step N/(N-1)
                # assigned index N to the last voxel. y and x are now split by
                # DIMX (the fastest axis) rather than DIMY; the two only agreed
                # for square slices.
                flat_idx = tf.range(num_voxels)
                z_idx = flat_idx // (DIMY * DIMX)
                y_idx = (flat_idx % (DIMY * DIMX)) // DIMX
                x_idx = flat_idx % DIMX

                def coord_channel(idx, channel_name):
                    # Rescale an integer coordinate to [0, 255], map it with
                    # cvt2tanh and lay it out like y_true. The divisor is
                    # floored at 1 so a singleton axis (all zeros) gives 0
                    # instead of 0/0 = NaN.
                    idx = tf.cast(idx, tf.float32)
                    idx = idx / tf.maximum(tf.reduce_max(idx), 1.0) * 255
                    return tf.reshape(cvt2tanh(idx),
                                      tf.shape(y_true),
                                      name=channel_name)

                lins = coord_channel(flat_idx, 'lins')
                lins_z = coord_channel(z_idx, 'lins_z')
                lins_y = coord_channel(y_idx, 'lins_y')
                lins_x = coord_channel(x_idx, 'lins_x')

                y_true = tf.reshape(y_true, [num_voxels])
                y_pred = tf.concat([y_pred, lins, lins_z, lins_y, lins_x],
                                   axis=-1)
                nDim = tf.shape(y_pred)[-1]
                X = tf.reshape(y_pred, [num_voxels, nDim])

                # K clusters; uniqueInd maps every voxel to its cluster 0..K-1.
                uniqueLabels, uniqueInd = tf.unique(y_true)
                numUnique = tf.size(uniqueLabels)

                # Cluster means mu: [K, nDim].
                Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
                counts = tf.unsorted_segment_sum(tf.ones_like(X), uniqueInd,
                                                 numUnique)
                mu = tf.divide(Sigma, counts)

                # NOTE(review): the distances below are divided by Lreg, which
                # is assumed non-zero (all cluster means at the origin -> NaN).
                Lreg = tf.reduce_mean(tf.norm(mu, axis=1, ord=1))

                # Pull term: a hinge on each voxel's L1 distance to its
                # cluster mean, averaged per cluster and then over clusters.
                T = tf.norm(tf.gather(mu, uniqueInd) - X, axis=1, ord=1)
                T = tf.square(tf.maximum(T / Lreg - delta_v, 0.0))
                voxels_per_cluster = tf.unsorted_segment_sum(
                    tf.ones_like(uniqueInd, dtype=tf.float32), uniqueInd,
                    numUnique)
                clusterSigma = tf.divide(
                    tf.unsorted_segment_sum(T, uniqueInd, numUnique),
                    voxels_per_cluster)
                Lvar = tf.reduce_mean(clusterSigma)

                # Push term over ordered cluster pairs (a, b). Row a*K + b of
                # mu_band_rep is mu[a] and the same row of mu_interleaved_rep
                # is mu[b]. Self-pairs (a == b) are dropped by index. The old
                # code masked with a [1, K*K] tensor, so boolean_mask picked
                # scalar entries of the flattened difference, and one hinge was
                # applied to their total instead of one hinge per pair.
                mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
                mu_band_rep = tf.reshape(tf.tile(mu, [1, numUnique]),
                                         (numUnique * numUnique, nDim))
                pair_ids = tf.range(numUnique * numUnique)
                off_diagonal = tf.not_equal(pair_ids // numUnique,
                                            pair_ids % numUnique)
                mu_diff = tf.boolean_mask(mu_band_rep - mu_interleaved_rep,
                                          off_diagonal)
                mu_diff = tf.norm(mu_diff, axis=1, ord=1) / Lreg
                mu_diff = tf.square(tf.maximum(2 * delta_d - mu_diff, 0.0))
                # Mean over pairs; flooring the count at 1 makes Ldist 0 rather
                # than NaN when there is a single cluster (no pairs).
                num_pairs = tf.cast(tf.size(mu_diff), tf.float32)
                Ldist = tf.reduce_sum(mu_diff) / tf.maximum(num_pairs, 1.0)

                L = tf.reduce_sum([alpha * Lvar, beta * Ldist, gamma * Lreg])
                return tf.identity(L, name=name)

            discrim_im = regDLF(cvt2imag(pm, maxVal=1.0),
                                feat_im,
                                name='discrim_im')
            discrim_iml = regDLF(cvt2imag(pl, maxVal=MAX_LABEL),
                                 feat_iml,
                                 name='discrim_iml')
            discrim_ml = regDLF(cvt2imag(pl, maxVal=MAX_LABEL),
                                feat_ml,
                                name='discrim_ml')
            print discrim_im
            print discrim_iml
            print discrim_ml

            print rand_iml
            print rand_ml
        self.g_loss = tf.reduce_sum(
            [
                #(recon_imi), # + recon_lmi + recon_imlmi), #
                10 * (recon_iml),  # + recon_lml + recon_lmiml), #
                10 * (recon_im),  #  + recon_lm + recon_mim + recon_mlm),
                10 * (recon_ml),  #  + recon_lm + recon_mim + recon_mlm),
                (rand_iml),  # + rand_lml + rand_lmiml), #
                (rand_ml),  #  + rand_lm + rand_mim + rand_mlm),
                # (G_loss_IL + G_loss_LI + G_loss_MI + G_loss_ML),
                (G_loss_LI + G_loss_MI),
                (0.1 * discrim_im + 10 * discrim_iml + 10 * discrim_ml),
                (
                    0.001 * membr_im
                ),  # + membr_lm + membr_imlm + membr_lmim + membr_mlm + membr_mim),
                # (label_iml + label_lml + label_lmiml + label_ml)
                (label_iml + label_ml)
            ],
            name='G_loss_total')
        self.d_loss = tf.reduce_sum(
            [
                # (D_loss_IL + D_loss_LI + D_loss_MI + D_loss_ML),
                (D_loss_LI + D_loss_MI),
            ],
            name='D_loss_total')

        wd_g = regularize_cost('gen/.*/W',
                               l2_regularizer(1e-5),
                               name='G_regularize')
        wd_d = regularize_cost('discrim/.*/W',
                               l2_regularizer(1e-5),
                               name='D_regularize')

        self.g_loss = tf.add(self.g_loss, wd_g, name='g_loss')
        self.d_loss = tf.add(self.d_loss, wd_d, name='d_loss')

        self.collect_variables()

        add_moving_summary(self.d_loss, self.g_loss)
        # add_moving_summary(
        with tf.name_scope('summaries'):
            add_tensor_summary(recon_iml, types=['scalar'], name='recon_iml')
            add_tensor_summary(recon_im, types=['scalar'], name='recon_im')
            add_tensor_summary(recon_ml, types=['scalar'], name='recon_ml')
            add_tensor_summary(label_iml, types=['scalar'], name='label_iml')
            add_tensor_summary(label_ml, types=['scalar'], name='label_ml')
            add_tensor_summary(rand_iml, types=['scalar'], name='rand_iml')
            add_tensor_summary(rand_ml, types=['scalar'], name='rand_ml')
            add_tensor_summary(membr_im, types=['scalar'], name='membr_im')
            add_tensor_summary(discrim_im, types=['scalar'], name='discrim_im')
            add_tensor_summary(discrim_iml,
                               types=['scalar'],
                               name='discrim_iml')
            add_tensor_summary(discrim_ml, types=['scalar'], name='discrim_ml')
            # recon_imi, recon_lmi, recon_imlmi,
            # recon_lml, recon_iml, recon_lmiml,
            # recon_mim, recon_mlm, recon_im , recon_lm,
            # )

        viz = tf.concat(
            [
                tf.concat([ui, pi, pim, piml, g_iml], 2),
                # tf.concat([ul, pl, plm, plmi, plmim, plmiml], 2),
                tf.concat([um, pl, pm, pml, g_ml], 2),
                # tf.concat([pl, pl, g_iml, g_lml, g_lmiml,   g_ml], 2),
            ],
            1)
        # add_moving_summary(
        # 	recon_imi, recon_lmi,# recon_imlmi,
        # 	recon_lml, recon_iml,# recon_lmiml,
        # 	recon_mim, recon_mlm, recon_im , recon_lm,
        # 	)
        # viz = tf.concat([tf.concat([ui, pi, pim, piml], 2),
        # 				 tf.concat([ul, pl, plm, plmi], 2),
        # 				 tf.concat([um, pm, pmi, pmim], 2),
        # 				 tf.concat([um, pm, pml, pmlm], 2),
        # 				 ], 1)
        viz = cvt2imag(viz)
        viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
        tf.summary.image('colorized', viz, max_outputs=50)