def build_graph(self):
    assert tf.test.is_gpu_available()
    image_raw = tf.constant(self.image, tf.float32)
    noise = tf.get_variable(
        'noise', shape=[BATCH_SIZE, 28, 28],
        initializer=tf.random_normal_initializer(stddev=0.01))
    image = tf.add(image_raw, noise, name='image')
    image = tf.clip_by_value(image, -1e-5, 1 + 1e-5)

    # Tile the first N*N noise/image patches into an N x N grid for visualization.
    noise_show = tf.concat([
        tf.concat([noise[i * N + j] for j in range(N)], axis=1)
        for i in range(N)], axis=0, name='noise_show')
    image_show = tf.concat([
        tf.concat([image[i * N + j] for j in range(N)], axis=1)
        for i in range(N)], axis=0, name='image_show')

    l = tf.expand_dims(image, 3)
    l = l * 2 - 1  # center the pixel values at zero
    l = Conv2D('conv1', l, 32, 3, activation=tf.nn.relu)
    l = Conv2D('conv2', l, 64, 3, activation=tf.nn.relu)
    l = MaxPooling('pool2', l, 2)
    # l = Dropout(l, keep_prob=0.75)
    l = FullyConnected('fc1', l, 128)
    l = tf.nn.relu(l)
    # l = Dropout(l, keep_prob=0.5)
    logits = FullyConnected('fc2', l, 10)
    tf.argmax(logits, 1, name='pred_label')

    cost = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=self.label)
    cost = tf.reduce_mean(cost, name='cross_entropy_loss')
    add_tensor_summary(cost, ['scalar'])

    # l1_cost = regularize_cost('noise', tf.contrib.layers.l1_regularizer(1e-2))
    l2_cost = regularize_cost('noise', tf.contrib.layers.l2_regularizer(1e-2))
    add_tensor_summary(l2_cost, ['scalar'])
    add_tensor_summary(
        tf.div(tf.reduce_sum(tf.abs(noise)), BATCH_SIZE, name='L1Norm'), ['scalar'])
    add_tensor_summary(
        tf.div(tf.sqrt(tf.reduce_sum(tf.square(noise))), BATCH_SIZE, name='L2Norm'),
        ['scalar'])

    wrong = tf.to_float(
        tf.logical_not(tf.nn.in_top_k(logits, self.label, 1)), name='wrong_vector')
    add_tensor_summary(tf.reduce_mean(wrong, name='train_error'), ['scalar'])
    return tf.add_n([cost, l2_cost], name='cost')
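# A minimal numpy sketch (illustrative only, not part of the model) verifying
# the N x N tiling built above with nested tf.concat: it is equivalent to a
# reshape/transpose. Assumes BATCH_SIZE >= N * N; the helper name is hypothetical.
import numpy as np

def tile_grid(batch, n):
    """Tile the first n*n images of `batch` (shape (n*n, H, W)) into an (n*H, n*W) grid."""
    h, w = batch.shape[1:3]
    return batch[:n * n].reshape(n, n, h, w).transpose(0, 2, 1, 3).reshape(n * h, n * w)

# Sanity check:
# x = np.random.rand(16, 28, 28); assert tile_grid(x, 4).shape == (112, 112)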
def optimizer(self):
    lr_var = tf.get_variable("learning_rate", initializer=self._lr,
                             trainable=False)
    add_tensor_summary(lr_var, ['scalar'], main_tower_only=True)
    return tf.train.AdamOptimizer(lr_var, beta1=self._beta1,
                                  beta2=self._beta2, epsilon=1e-3)
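# Because the learning rate lives in a non-trainable variable named
# 'learning_rate', it can be annealed from a tensorpack callback. A minimal
# sketch, assuming the standard tensorpack API; the schedule values below are
# illustrative only.
from tensorpack.callbacks import ScheduledHyperParamSetter

lr_schedule = ScheduledHyperParamSetter(
    'learning_rate',
    [(0, 1e-4), (100, 1e-5)])  # (epoch, value) pairs
# Pass `lr_schedule` in the callbacks list of the TrainConfig.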
def wbce_malis(logits, affs, gt_affs, gt_seg, neighborhood, affs_shape,
               name='wbce_malis', limit_z=False):
    with tf.name_scope(name):
        # Class-balancing weight: ratio of negative to positive affinities.
        pos_cnt = tf.cast(tf.count_nonzero(tf.cast(gt_affs, tf.int32)), tf.float32)
        neg_cnt = tf.cast(tf.constant(np.prod(affs_shape)), tf.float32) - pos_cnt
        pos_weight = neg_cnt / pos_cnt
        summary.add_tensor_summary(pos_weight, types=['scalar'])
        weighted_bce_losses = tf.nn.weighted_cross_entropy_with_logits(
            targets=gt_affs, logits=logits, pos_weight=pos_weight)
        malis_weights, pos_weights, neg_weights = malis_weights_op(
            affs, gt_affs, gt_seg, neighborhood,
            name='malis_weights', limit_z=limit_z)
        pos_weights = tf.identity(pos_weights, name='pos_weight')
        neg_weights = tf.identity(neg_weights, name='neg_weight')
        malis_weighted_bce_loss = tf.reduce_mean(
            tf.multiply(malis_weights, weighted_bce_losses),
            name='malis_weighted_bce_loss')
    return malis_weighted_bce_loss
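# A quick numpy sanity check (illustrative only) of the pos_weight computed
# above: with 1 positive out of 8 affinities, positives are up-weighted 7x.
import numpy as np
gt = np.array([0, 0, 0, 1, 0, 0, 0, 0], dtype=np.float32)
pos = np.count_nonzero(gt)
assert (gt.size - pos) / pos == 7.0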
def _build_graph(self, inputs):
    G = tf.get_default_graph()  # For round
    tf.local_variables_initializer()
    tf.global_variables_initializer()
    pi, pm, pl, ui, um, ul, sm = inputs
    pi = cvt2tanh(pi)
    pm = cvt2tanh(pm)
    pl = cvt2tanh(pl)
    ui = cvt2tanh(ui)
    um = cvt2tanh(um)
    ul = cvt2tanh(ul)
    sm = cvt2tanh(sm)

    def tf_rand_score(x1, x2):
        return np.mean(1.0 - adjusted_rand_score(x1.flatten(), x2.flatten()))

    def rounded(label, factor=MAX_LABEL, name='quantized'):
        # Straight-through rounding: Round behaves as Identity in the backward pass.
        with G.gradient_override_map({"Round": "Identity"}):
            with freeze_variables():
                with tf.name_scope(name=name):
                    label = cvt2imag(label, maxVal=factor)
                    label = tf.round(label)
                    label = cvt2tanh(label, maxVal=factor)
                    return tf.identity(label, name=name)

    with argscope([Conv2D, Deconv2D, FullyConnected],
                  W_init=tf.truncated_normal_initializer(stddev=0.02),
                  use_bias=False), \
            argscope(BatchNorm, gamma_init=tf.random_uniform_initializer()), \
            argscope([Conv2D, Deconv2D, BatchNorm], data_format='NHWC'), \
            argscope(LeakyReLU, alpha=0.2):
        with tf.variable_scope('gen'):
            # Real pair: image for the generator
            with tf.variable_scope('I2M'):
                pim, feat_im = self.generator(pi)
            with tf.variable_scope('M2L'):
                piml, feat_iml = self.generator(tf.concat([pim, sm], axis=-1))
                pml, feat_ml = self.generator(tf.concat([pm, sm], axis=-1))
        with tf.variable_scope('discrim'):
            # with tf.variable_scope('I'):
            #     i_dis_real = self.discriminator(ui)
            #     i_dis_fake_from_label = self.discriminator(plmi)
            with tf.variable_scope('M'):
                m_dis_real = self.discriminator(um)
                m_dis_fake_from_image = self.discriminator(pim)
                # m_dis_fake_from_label = self.discriminator(plm)
            with tf.variable_scope('L'):
                l_dis_real = self.discriminator(ul)
                l_dis_fake_from_image = self.discriminator(piml)

    # piml = rounded(piml)
    # pml = rounded(pml)

    with tf.name_scope('Recon_L_loss'):
        recon_iml = tf.reduce_mean(tf.abs(pl - piml), name='recon_iml')
        recon_ml = tf.reduce_mean(tf.abs(pl - pml), name='recon_ml')
    with tf.name_scope('Recon_M_loss'):
        recon_im = tf.reduce_mean(tf.abs(pm - pim), name='recon_im')
    with tf.name_scope('GAN_loss'):
        # G_loss_IL, D_loss_IL = self.build_losses(i_dis_real, i_dis_fake_from_label, name='IL')
        G_loss_IL, D_loss_IL = self.build_losses(l_dis_real, l_dis_fake_from_image, name='IL')
        G_loss_MI, D_loss_MI = self.build_losses(m_dis_real, m_dis_fake_from_image, name='MI')
        # G_loss_ML, D_loss_ML = self.build_losses(m_dis_real, m_dis_fake_from_label, name='ML')

    # custom loss for membrane
    with tf.name_scope('membr_loss'):
        def membr_loss(y_true, y_pred, name='membr_loss'):
            # BCE minus Jaccard dice coefficient, computed in image range [0, 1].
            loss = tf.reduce_mean(tf.subtract(
                binary_cross_entropy(cvt2imag(y_true, maxVal=1.0),
                                     cvt2imag(y_pred, maxVal=1.0)),
                dice_coe(cvt2imag(y_true, maxVal=1.0),
                         cvt2imag(y_pred, maxVal=1.0),
                         axis=[1, 2, 3], loss_type='jaccard')))
            # loss = tf.reshape(loss, [-1])
            return tf.identity(loss, name=name)

        membr_im = membr_loss(pm, pim, name='membr_im')
        membr_iml = membr_loss(pl, piml, name='membr_iml')
        membr_ml = membr_loss(pl, pml, name='membr_ml')

    # custom loss for label
    with tf.name_scope('label_loss'):
        def label_loss(y_true_L, y_pred_L, y_grad_M, name='label_loss'):
            g_mag_grad_M = cvt2imag(y_grad_M, maxVal=1.0)
            mag_grad_L = magnitute_central_difference(y_pred_L, name='mag_grad_L')
            cond = tf.greater(mag_grad_L, tf.zeros_like(mag_grad_L))
            thresholded_mag_grad_L = tf.where(cond,
                                              tf.ones_like(mag_grad_L),
                                              tf.zeros_like(mag_grad_L),
                                              name='thresholded_mag_grad_L')
            gtv_guess = tf.multiply(g_mag_grad_M, thresholded_mag_grad_L,
                                    name='gtv_guess')
            loss_gtv_guess = tf.reduce_mean(gtv_guess, name='loss_gtv_guess')
            # loss_gtv_guess = tf.reshape(loss_gtv_guess, [-1])
            thresholded_mag_grad_L = cvt2tanh(thresholded_mag_grad_L, maxVal=1.0)
            gtv_guess = cvt2tanh(gtv_guess, maxVal=1.0)
            return tf.identity(loss_gtv_guess, name=name), thresholded_mag_grad_L

        label_iml, g_iml = label_loss(None, piml, pim, name='label_iml')
        label_ml, g_ml = label_loss(None, pml, pm, name='label_loss_ml')

    # custom loss for tf_rand_score
    with tf.name_scope('rand_loss'):
        rand_iml = tf.reduce_mean(tf.cast(
            tf.py_func(tf_rand_score, [piml, pl], tf.float64), tf.float32))
        rand_ml = tf.reduce_mean(tf.cast(
            tf.py_func(tf_rand_score, [pml, pl], tf.float64), tf.float32))

    with tf.name_scope('discrim_loss'):
        def regDLF(y_true, y_pred, alpha=1, beta=1, gamma=0.01,
                   delta_v=0.5, delta_d=1.5, name='loss_discrim'):
            def tf_norm(inputs, axis=1, epsilon=1e-7, name='safe_norm'):
                squared_norm = tf.reduce_sum(tf.square(inputs), axis=axis,
                                             keep_dims=True)
                safe_norm = tf.sqrt(squared_norm + epsilon)
                return tf.identity(safe_norm, name=name)
            ###
            y_true = tf.reshape(y_true, [DIMZ * DIMY * DIMX])
            nDim = tf.shape(y_pred)[-1]
            X = tf.reshape(y_pred, [DIMZ * DIMY * DIMX, nDim])
            uniqueLabels, uniqueInd = tf.unique(y_true)
            numUnique = tf.size(uniqueLabels)  # number of connected components
            Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
            # ones_Sigma = tf.ones((tf.shape(X)[0], 1))
            ones_Sigma = tf.ones_like(X)
            ones_Sigma = tf.unsorted_segment_sum(ones_Sigma, uniqueInd, numUnique)
            mu = tf.divide(Sigma, ones_Sigma)  # per-cluster mean embedding

            Lreg = tf.reduce_mean(tf.norm(mu, axis=1, ord=1))

            T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X), axis=1, ord=1)
            T = tf.divide(T, Lreg)
            T = tf.subtract(T, delta_v)
            T = tf.clip_by_value(T, 0, T)
            T = tf.square(T)

            ones_Sigma = tf.ones_like(uniqueInd, dtype=tf.float32)
            ones_Sigma = tf.unsorted_segment_sum(ones_Sigma, uniqueInd, numUnique)
            clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
            clusterSigma = tf.divide(clusterSigma, ones_Sigma)
            # Lvar = tf.reduce_mean(clusterSigma, axis=0)
            Lvar = tf.reduce_mean(clusterSigma)

            mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
            mu_band_rep = tf.tile(mu, [1, numUnique])
            mu_band_rep = tf.reshape(mu_band_rep, (numUnique * numUnique, nDim))
            mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)
            # Remove zero vectors (differences of a cluster mean with itself):
            # intermediate_tensor = reduce_sum(tf.abs(x), 1)
            # zero_vector = tf.zeros(shape=(1,1), dtype=tf.float32)
            # bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
            # omit_zeros = tf.boolean_mask(x, bool_mask)
            intermediate_tensor = tf.reduce_sum(tf.abs(mu_diff), 1)
            zero_vector = tf.zeros(shape=(1, 1), dtype=tf.float32)
            bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
            omit_zeros = tf.boolean_mask(mu_diff, bool_mask)
            mu_diff = tf.expand_dims(omit_zeros, axis=1)
            print(mu_diff)
            mu_diff = tf.norm(mu_diff, ord=1)
            # squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
            # safe_norm = tf.sqrt(squared_norm + epsilon)
            # squared_norm = tf.reduce_sum(tf.square(omit_zeros), axis=-1, keep_dims=True)
            # safe_norm = tf.sqrt(squared_norm + 1e-6)
            # mu_diff = safe_norm
            mu_diff = tf.divide(mu_diff, Lreg)
            mu_diff = tf.subtract(2 * delta_d, mu_diff)
            mu_diff = tf.clip_by_value(mu_diff, 0, mu_diff)
            mu_diff = tf.square(mu_diff)
            numUniqueF = tf.cast(numUnique, tf.float32)
            Ldist = tf.reduce_mean(mu_diff)

            # L = alpha * Lvar + beta * Ldist + gamma * Lreg
            # L = tf.reduce_mean(L, keep_dims=True)
            L = tf.reduce_sum([alpha * Lvar, beta * Ldist, gamma * Lreg],
                              keep_dims=False)
            print(L)
            print(Ldist)
            print(Lvar)
            print(Lreg)
            return tf.identity(L, name=name)

        discrim_im = regDLF(cvt2imag(pm, maxVal=1.0), feat_im, name='discrim_im')
        discrim_iml = regDLF(cvt2imag(pl, maxVal=1.0), feat_iml, name='discrim_iml')
        discrim_ml = regDLF(cvt2imag(pl, maxVal=1.0), feat_ml, name='discrim_ml')
        print(discrim_im)
        print(discrim_iml)
        print(discrim_ml)
        print(rand_iml)
        print(rand_ml)

    self.g_loss = tf.reduce_sum([
        # (recon_imi),  # + recon_lmi + recon_imlmi
        # (recon_iml),  # + recon_lml + recon_lmiml
        # (recon_im),   # + recon_lm + recon_mim + recon_mlm
        # (recon_ml),   # + recon_lm + recon_mim + recon_mlm
        # (rand_iml),   # + rand_lml + rand_lmiml
        # (rand_ml),    # + rand_lm + rand_mim + rand_mlm
        # (G_loss_IL + G_loss_LI + G_loss_MI + G_loss_ML),
        (G_loss_IL + G_loss_MI),
        # 0.1 * (discrim_im + discrim_iml + discrim_ml),
        0.001 * (membr_im),   # + membr_lm + membr_imlm + membr_lmim + membr_mlm + membr_mim
        0.001 * (membr_iml),  # + membr_lm + membr_imlm + membr_lmim + membr_mlm + membr_mim
        0.001 * (membr_ml),   # + membr_lm + membr_imlm + membr_lmim + membr_mlm + membr_mim
        # (label_iml + label_lml + label_lmiml + label_ml)
        # (label_iml + label_ml)
    ], name='G_loss_total')
    self.d_loss = tf.reduce_sum([
        # (D_loss_IL + D_loss_LI + D_loss_MI + D_loss_ML),
        (D_loss_IL + D_loss_MI),
    ], name='D_loss_total')

    wd_g = regularize_cost('gen/.*/W', l2_regularizer(1e-5), name='G_regularize')
    wd_d = regularize_cost('discrim/.*/W', l2_regularizer(1e-5), name='D_regularize')
    self.g_loss = tf.add(self.g_loss, wd_g, name='g_loss')
    self.d_loss = tf.add(self.d_loss, wd_d, name='d_loss')

    self.collect_variables()
    add_moving_summary(self.d_loss, self.g_loss)

    with tf.name_scope('summaries'):
        add_tensor_summary(recon_iml, types=['scalar'], name='recon_iml')
        add_tensor_summary(recon_im, types=['scalar'], name='recon_im')
        add_tensor_summary(recon_ml, types=['scalar'], name='recon_ml')
        add_tensor_summary(label_iml, types=['scalar'], name='label_iml')
        add_tensor_summary(label_ml, types=['scalar'], name='label_ml')
        add_tensor_summary(rand_iml, types=['scalar'], name='rand_iml')
        add_tensor_summary(rand_ml, types=['scalar'], name='rand_ml')
        add_tensor_summary(membr_im, types=['scalar'], name='membr_im')
        add_tensor_summary(membr_iml, types=['scalar'], name='membr_iml')
        add_tensor_summary(membr_ml, types=['scalar'], name='membr_ml')
        add_tensor_summary(discrim_im, types=['scalar'], name='discrim_im')
        add_tensor_summary(discrim_iml, types=['scalar'], name='discrim_iml')
        add_tensor_summary(discrim_ml, types=['scalar'], name='discrim_ml')

    viz = tf.concat([
        tf.concat([ui, pi, sm, pim, piml, g_iml], 2),
        # tf.concat([ul, pl, plm, plmi, plmim, plmiml], 2),
        tf.concat([um, pl, ul, pm, pml, g_ml], 2),
        # tf.concat([pl, pl, g_iml, g_lml, g_lmiml, g_ml], 2),
    ], 1)
    # viz = tf.concat([tf.concat([ui, pi, pim, piml], 2),
    #                  tf.concat([ul, pl, plm, plmi], 2),
    #                  tf.concat([um, pm, pmi, pmim], 2),
    #                  tf.concat([um, pm, pml, pmlm], 2),
    #                  ], 1)
    viz = cvt2imag(viz)
    viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
    tf.summary.image('colorized', viz, max_outputs=50)
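# cvt2tanh / cvt2imag are used throughout but not defined in this file. A
# minimal sketch of what they are assumed to do (map [0, maxVal] <-> [-1, 1]);
# the real helpers may differ.
def cvt2tanh_sketch(x, maxVal=255.0):
    return x / maxVal * 2.0 - 1.0

def cvt2imag_sketch(x, maxVal=255.0):
    return (x + 1.0) / 2.0 * maxVal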
def _build_graph(self, inputs):
    G = tf.get_default_graph()  # For round
    tf.local_variables_initializer()
    tf.global_variables_initializer()
    pi, pm, pl, ui, um, ul = inputs
    pi = tf_2tanh(pi)
    pm = tf_2tanh(pm)
    pl = tf_2tanh(pl)
    ui = tf_2tanh(ui)
    um = tf_2tanh(um)
    ul = tf_2tanh(ul)

    pl = toMaxLabels(pl, factor=MAX_LABEL)  # 0..MAX
    pl = toRangeTanh(pl, factor=MAX_LABEL)  # -1..1
    pa = seg_to_aff_op(toMaxLabels(pl, factor=MAX_LABEL),
                       name='pa')  # calculate the affinity, 0..1

    with argscope([Conv2D, Deconv2D, FullyConnected],
                  W_init=tf.truncated_normal_initializer(stddev=0.02),
                  use_bias=False), \
            argscope(BatchNorm, gamma_init=tf.random_uniform_initializer()), \
            argscope([Conv2D, Deconv2D, BatchNorm], data_format='NHWC'), \
            argscope([Conv2D], dilation_rate=1):
        with tf.variable_scope('gen'):
            with tf.variable_scope('affnt'):
                pia, feat_ia = self.generator(pi, last_dim=3)
            with tf.variable_scope('label'):
                pil, feat_il = self.generator(pi, last_dim=1)
        with tf.variable_scope('discrim'):
            dis_real = self.discriminator(tf.concat([pi, pl, pa], axis=-1))
            dis_fake = self.discriminator(tf.concat([pi, pil, pia], axis=-1))

    with G.gradient_override_map({"Round": "Identity"}):
        with tf.variable_scope('fix'):  # Round
            pil = toMaxLabels(pil, factor=MAX_LABEL)  # 0..MAX
            pil = toRangeTanh(pil, factor=MAX_LABEL)  # -1..1
            pila = seg_to_aff_op(toMaxLabels(pil, factor=MAX_LABEL),
                                 name='pila')  # affinity from the predicted label, 0..1
            pial = aff_to_seg_op(tf_2imag(pia, maxVal=1.0),
                                 name='pial')  # segmentation from the predicted affinity
            pial = toRangeTanh(pial, factor=MAX_LABEL)  # -1..1
            pil_ = (pial + pil) / 2.0  # averaged label prediction
            pia_ = (pila + pia) / 2.0  # averaged affinity prediction

    with tf.name_scope('GAN_loss'):
        G_loss, D_loss = self.build_losses(dis_real, dis_fake, name='gan_loss')
    with tf.name_scope('rand_loss'):
        rand_il = tf.reduce_mean(tf_rand_score(pl, pil_), name='rand_loss')
        # rand_il = tf.reduce_mean(tf.cast(tf.py_func(tf_rand_score, [pl, pil_], tf.float64), tf.float32), name='rand_loss')

    with tf.name_scope('discrim_loss'):
        def regDLF(y_true, y_pred, alpha=1, beta=1, gamma=0.01,
                   delta_v=0.5, delta_d=1.5, name='loss_discrim'):
            def tf_norm(inputs, axis=1, epsilon=1e-7, name='safe_norm'):
                squared_norm = tf.reduce_sum(tf.square(inputs), axis=axis,
                                             keep_dims=True)
                safe_norm = tf.sqrt(squared_norm + epsilon)
                return tf.identity(safe_norm, name=name)
            ###
            # Append normalized flat/z/y/x coordinate channels to the embedding.
            lins = tf.linspace(0.0, DIMZ * DIMY * DIMX, DIMZ * DIMY * DIMX)
            lins = tf.cast(lins, tf.int32)
            # lins = lins / tf.reduce_max(lins) * 255
            # lins = tf_2tanh(lins)
            # lins = tf.reshape(lins, tf.shape(y_true), name='lins_3d')
            # print(lins)
            lins_z = tf.div(lins, (DIMY * DIMX))
            lins_y = tf.div(tf.mod(lins, (DIMY * DIMX)), DIMY)
            lins_x = tf.mod(tf.mod(lins, (DIMY * DIMX)), DIMY)

            lins = tf.cast(lins, tf.float32)
            lins_z = tf.cast(lins_z, tf.float32)
            lins_y = tf.cast(lins_y, tf.float32)
            lins_x = tf.cast(lins_x, tf.float32)

            lins = lins / tf.reduce_max(lins) * 255
            lins_z = lins_z / tf.reduce_max(lins_z) * 255
            lins_y = lins_y / tf.reduce_max(lins_y) * 255
            lins_x = lins_x / tf.reduce_max(lins_x) * 255

            lins = tf_2tanh(lins)
            lins_z = tf_2tanh(lins_z)
            lins_y = tf_2tanh(lins_y)
            lins_x = tf_2tanh(lins_x)

            lins = tf.reshape(lins, tf.shape(y_true), name='lins')
            lins_z = tf.reshape(lins_z, tf.shape(y_true), name='lins_z')
            lins_y = tf.reshape(lins_y, tf.shape(y_true), name='lins_y')
            lins_x = tf.reshape(lins_x, tf.shape(y_true), name='lins_x')

            y_true = tf.reshape(y_true, [DIMZ * DIMY * DIMX])
            y_pred = tf.concat([y_pred, lins, lins_z, lins_y, lins_x], axis=-1)

            nDim = tf.shape(y_pred)[-1]
            X = tf.reshape(y_pred, [DIMZ * DIMY * DIMX, nDim])
            uniqueLabels, uniqueInd = tf.unique(y_true)
            numUnique = tf.size(uniqueLabels)  # number of connected components
            Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
            # ones_Sigma = tf.ones((tf.shape(X)[0], 1))
            ones_Sigma = tf.ones_like(X)
            ones_Sigma = tf.unsorted_segment_sum(ones_Sigma, uniqueInd, numUnique)
            mu = tf.divide(Sigma, ones_Sigma)  # per-cluster mean embedding

            Lreg = tf.reduce_mean(tf.norm(mu, axis=1, ord=1))

            T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X), axis=1, ord=1)
            T = tf.divide(T, Lreg)
            T = tf.subtract(T, delta_v)
            T = tf.clip_by_value(T, 0, T)
            T = tf.square(T)

            ones_Sigma = tf.ones_like(uniqueInd, dtype=tf.float32)
            ones_Sigma = tf.unsorted_segment_sum(ones_Sigma, uniqueInd, numUnique)
            clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
            clusterSigma = tf.divide(clusterSigma, ones_Sigma)
            # Lvar = tf.reduce_mean(clusterSigma, axis=0)
            Lvar = tf.reduce_mean(clusterSigma)

            mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
            mu_band_rep = tf.tile(mu, [1, numUnique])
            mu_band_rep = tf.reshape(mu_band_rep, (numUnique * numUnique, nDim))
            mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)
            # Remove zero vectors (differences of a cluster mean with itself):
            # intermediate_tensor = reduce_sum(tf.abs(x), 1)
            # zero_vector = tf.zeros(shape=(1,1), dtype=tf.float32)
            # bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
            # omit_zeros = tf.boolean_mask(x, bool_mask)
            intermediate_tensor = tf.reduce_sum(tf.abs(mu_diff), 1)
            zero_vector = tf.zeros(shape=(1, 1), dtype=tf.float32)
            bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
            omit_zeros = tf.boolean_mask(mu_diff, bool_mask)
            mu_diff = tf.expand_dims(omit_zeros, axis=1)
            print(mu_diff)
            mu_diff = tf.norm(mu_diff, ord=1)
            # squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
            # safe_norm = tf.sqrt(squared_norm + epsilon)
            # squared_norm = tf.reduce_sum(tf.square(omit_zeros), axis=-1, keep_dims=True)
            # safe_norm = tf.sqrt(squared_norm + 1e-6)
            # mu_diff = safe_norm
            mu_diff = tf.divide(mu_diff, Lreg)
            mu_diff = tf.subtract(2 * delta_d, mu_diff)
            mu_diff = tf.clip_by_value(mu_diff, 0, mu_diff)
            mu_diff = tf.square(mu_diff)
            numUniqueF = tf.cast(numUnique, tf.float32)
            Ldist = tf.reduce_mean(mu_diff)

            # L = alpha * Lvar + beta * Ldist + gamma * Lreg
            # L = tf.reduce_mean(L, keep_dims=True)
            L = tf.reduce_sum([alpha * Lvar, beta * Ldist, gamma * Lreg],
                              keep_dims=False)
            print(L)
            print(Ldist)
            print(Lvar)
            print(Lreg)
            return tf.identity(L, name=name)

        discrim_il = regDLF(toMaxLabels(pl, factor=MAX_LABEL),
                            tf.concat([feat_il, feat_ia], axis=-1),
                            name='discrim_il')

    with tf.name_scope('recon_loss'):
        recon_il = tf.reduce_mean(tf.abs(pl - pil_), name='recon_il')
    with tf.name_scope('affnt_loss'):
        # affnt_il = tf.reduce_mean(tf.abs(pa - pia_), name='affnt_il')
        affnt_il = tf.reduce_mean(tf.subtract(
            binary_cross_entropy(pa, pia_),
            dice_coe(pa, pia_, axis=[0, 1, 2, 3], loss_type='jaccard')))
    with tf.name_scope('residual_loss'):
        # residual_a = tf.reduce_mean(tf.abs(pia - pila), name='residual_a')
        # residual_l = tf.reduce_mean(tf.abs(pil - pial), name='residual_l')
        residual_a = tf.reduce_mean(tf.cast(tf.not_equal(pia, pila), tf.float32),
                                    name='residual_a')
        residual_l = tf.reduce_mean(tf.cast(tf.not_equal(pil, pial), tf.float32),
                                    name='residual_l')
        residual_il = tf.reduce_mean([residual_a, residual_l], name='residual_il')

    def label_imag(y_pred_L, name='label_imag'):
        mag_grad_L = magnitute_central_difference(y_pred_L, name='mag_grad_L')
        cond = tf.greater(mag_grad_L, tf.zeros_like(mag_grad_L))
        thresholded_mag_grad_L = tf.where(cond,
                                          tf.ones_like(mag_grad_L),
                                          tf.zeros_like(mag_grad_L),
                                          name='thresholded_mag_grad_L')
        thresholded_mag_grad_L = tf_2tanh(thresholded_mag_grad_L, maxVal=1.0)
        return thresholded_mag_grad_L

    g_il = label_imag(pil_, name='label_il')
    g_l = label_imag(pl, name='label_l')

    self.g_loss = tf.reduce_sum([
        1 * (G_loss),
        10 * (recon_il),
        10 * (residual_il),
        1 * (rand_il),
        50 * (discrim_il),
        0.005 * affnt_il,
    ], name='G_loss_total')
    self.d_loss = tf.reduce_sum([D_loss], name='D_loss_total')

    wd_g = regularize_cost('gen/.*/W', l2_regularizer(1e-5), name='G_regularize')
    wd_d = regularize_cost('discrim/.*/W', l2_regularizer(1e-5), name='D_regularize')
    self.g_loss = tf.add(self.g_loss, wd_g, name='g_loss')
    self.d_loss = tf.add(self.d_loss, wd_d, name='d_loss')

    self.collect_variables()
    add_moving_summary(self.d_loss, self.g_loss)

    with tf.name_scope('summaries'):
        add_tensor_summary(recon_il, types=['scalar'], name='recon_il')
        add_tensor_summary(rand_il, types=['scalar'], name='rand_il')
        add_tensor_summary(discrim_il, types=['scalar'], name='discrim_il')
        add_tensor_summary(affnt_il, types=['scalar'], name='affnt_il')
        add_tensor_summary(residual_il, types=['scalar'], name='residual_il')

    # Segmentation
    viz = tf.concat([
        tf.concat([pi, pl, g_l, g_il], 2),
        tf.concat([tf.zeros_like(pi), pil, pial, pil_], 2),
    ], 1)
    viz = tf_2imag(viz)
    viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
    tf.summary.image('colorized', viz, max_outputs=50)

    # Affinity
    vis = tf.concat([pa, pila, pia, pia_], 2)
    vis = tf_2imag(vis)
    vis = tf.cast(tf.clip_by_value(vis, 0, 255), tf.uint8, name='vis')
    tf.summary.image('affinities', vis, max_outputs=50)
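# A numpy sketch (illustrative only) of the flat-index decomposition used to
# build the coordinate channels in regDLF above. Note the original divides and
# takes mod by DIMY for both y and x, which is only correct when DIMY == DIMX;
# the conventional decomposition uses DIMX, as shown here.
import numpy as np
DIMZ_, DIMY_, DIMX_ = 2, 3, 4  # hypothetical sizes
i = np.arange(DIMZ_ * DIMY_ * DIMX_)
z = i // (DIMY_ * DIMX_)
y = (i % (DIMY_ * DIMX_)) // DIMX_
x = i % DIMX_
assert np.all(i == (z * DIMY_ + y) * DIMX_ + x)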
def _build_graph(self, inputs):
    # sImg2d # sImg the projection 2D, reshape from
    VGG19_MEAN = np.array([123.68, 116.779, 103.939])  # RGB
    VGG19_MEAN_TENSOR = tf.constant(VGG19_MEAN, dtype=tf.float32)

    image, style = inputs  # Split the input

    @auto_reuse_variable_scope
    def vgg19_encoder(source, name='VGG19_Encoder'):
        with tf.variable_scope(name):
            with varreplace.freeze_variables():
                with argscope([Conv2D], kernel_shape=3, nl=tf.nn.relu):
                    source = source - VGG19_MEAN_TENSOR
                    conv1_1 = Conv2D('conv1_1', source, 64)
                    conv1_2 = Conv2D('conv1_2', conv1_1, 64)
                    pool1 = MaxPooling('pool1', conv1_2, 2)  # 64
                    conv2_1 = Conv2D('conv2_1', pool1, 128)
                    conv2_2 = Conv2D('conv2_2', conv2_1, 128)
                    pool2 = MaxPooling('pool2', conv2_2, 2)  # 32
                    conv3_1 = Conv2D('conv3_1', pool2, 256)
                    conv3_2 = Conv2D('conv3_2', conv3_1, 256)
                    conv3_3 = Conv2D('conv3_3', conv3_2, 256)
                    conv3_4 = Conv2D('conv3_4', conv3_3, 256)
                    pool3 = MaxPooling('pool3', conv3_4, 2)  # 16
                    conv4_1 = Conv2D('conv4_1', pool3, 512)
                    conv4_2 = Conv2D('conv4_2', conv4_1, 512)
                    conv4_3 = Conv2D('conv4_3', conv4_2, 512)
                    conv4_4 = Conv2D('conv4_4', conv4_3, 512)
                    pool4 = MaxPooling('pool4', conv4_4, 2)  # 8
                    conv5_1 = Conv2D('conv5_1', pool4, 512)
                    conv5_2 = Conv2D('conv5_2', conv5_1, 512)
                    conv5_3 = Conv2D('conv5_3', conv5_2, 512)
                    conv5_4 = Conv2D('conv5_4', conv5_3, 512)
                    pool5 = MaxPooling('pool5', conv5_4, 2)  # 4
                    # return normalize(conv4_1), [normalize(conv1_1), normalize(conv2_1), normalize(conv3_1), normalize(conv4_1)]
                    return conv4_1, [conv1_1, conv2_1, conv3_1, conv4_1]  # list of returned feature maps

    @auto_reuse_variable_scope
    def vgg19_decoder(source, name='VGG19_Decoder'):
        with tf.variable_scope(name):
            # with varreplace.freeze_variables():
            with argscope([Conv2D], kernel_shape=3, nl=tf.nn.elu):
                with argscope([Deconv2D], kernel_shape=3, strides=(2, 2), nl=tf.nn.elu):
                    # conv5_4 = Conv2D('conv5_4', input, 512)
                    # conv5_3 = Conv2D('conv5_3', conv5_4, 512)
                    # conv5_2 = Conv2D('conv5_2', conv5_3, 512)
                    # conv5_1 = Conv2D('conv5_1', conv5_2, 512)
                    # pool4 = Deconv2D('pool4', input, 512)  # 8
                    # conv4_4 = Conv2D('conv4_4', pool4, 512)
                    # conv4_3 = Conv2D('conv4_3', conv4_4, 512)
                    # conv4_2 = Conv2D('conv4_2', conv4_3, 512)
                    # conv4_1 = Conv2D('conv4_1', conv4_2, 512)
                    pool3 = Subpix2D('pool3', source, 256)  # 16
                    conv3_4 = Conv2D('conv3_4', pool3, 256)
                    conv3_3 = Conv2D('conv3_3', conv3_4, 256)
                    conv3_2 = Conv2D('conv3_2', conv3_3, 256)
                    conv3_1 = Conv2D('conv3_1', conv3_2, 256)
                    pool2 = Subpix2D('pool2', conv3_1, 128)  # 32
                    conv2_2 = Conv2D('conv2_2', pool2, 128)
                    conv2_1 = Conv2D('conv2_1', conv2_2, 128)
                    pool1 = Subpix2D('pool1', conv2_1, 64)  # 64
                    conv1_2 = Conv2D('conv1_2', pool1, 64)
                    conv1_1 = Conv2D('conv1_1', conv1_2, 64)
                    conv1_0 = Conv2D('conv1_0', conv1_1, 3)
                    conv1_0 = conv1_0 + VGG19_MEAN_TENSOR
                    # conv1_0 = 255.0 * tf.nn.tanh(conv1_0)
                    # conv1_0 = tf_2tanh(conv1_0, maxVal=255.0)
                    # conv1_0 = tf_2imag(conv1_0, maxVal=255.0)
                    # conv1_0 = conv1_0 - VGG19_MEAN_TENSOR
                    return conv1_0

    @auto_reuse_variable_scope
    def vgg19_feature(source, name='VGG19_Feature'):
        with tf.variable_scope(name):
            with varreplace.freeze_variables():
                with argscope([Conv2D], kernel_shape=3, nl=tf.nn.relu):
                    source = source - VGG19_MEAN_TENSOR
                    conv1_1 = Conv2D('conv1_1', source, 64)
                    conv1_2 = Conv2D('conv1_2', conv1_1, 64)
                    pool1 = MaxPooling('pool1', conv1_2, 2)  # 64
                    conv2_1 = Conv2D('conv2_1', pool1, 128)
                    conv2_2 = Conv2D('conv2_2', conv2_1, 128)
                    pool2 = MaxPooling('pool2', conv2_2, 2)  # 32
                    conv3_1 = Conv2D('conv3_1', pool2, 256)
                    conv3_2 = Conv2D('conv3_2', conv3_1, 256)
                    conv3_3 = Conv2D('conv3_3', conv3_2, 256)
                    conv3_4 = Conv2D('conv3_4', conv3_3, 256)
                    pool3 = MaxPooling('pool3', conv3_4, 2)  # 16
                    conv4_1 = Conv2D('conv4_1', pool3, 512)
                    conv4_2 = Conv2D('conv4_2', conv4_1, 512)
                    conv4_3 = Conv2D('conv4_3', conv4_2, 512)
                    conv4_4 = Conv2D('conv4_4', conv4_3, 512)
                    pool4 = MaxPooling('pool4', conv4_4, 2)  # 8
                    conv5_1 = Conv2D('conv5_1', pool4, 512)
                    conv5_2 = Conv2D('conv5_2', conv5_1, 512)
                    conv5_3 = Conv2D('conv5_3', conv5_2, 512)
                    conv5_4 = Conv2D('conv5_4', conv5_3, 512)
                    pool5 = MaxPooling('pool5', conv5_4, 2)  # 4
                    # return normalize(conv4_1), [normalize(conv1_1), normalize(conv2_1), normalize(conv3_1), normalize(conv4_1)]
                    return conv4_1, [conv1_1, conv2_1, conv3_1, conv4_1]  # list of returned feature maps

    # Step 1: run through the encoder
    image_encoded, image_feature = vgg19_encoder(image)
    style_encoded, style_feature = vgg19_encoder(style)

    # Step 2: run through the AdaIN block to get t = AdaIN(f(c), f(s))
    merge_encoded = self._build_adain_layers(image_encoded, style_encoded)

    # Step 3: run through the decoder to get the painted image
    paint = vgg19_decoder(merge_encoded)

    # vgg19_feature and vgg19_encoder are actually identical;
    # they are kept separate to improve programmability.
    paint_encoded, paint_feature = vgg19_feature(paint)
    style_encoded, style_feature = vgg19_feature(style)
    # print(merge_encoded.get_shape())
    # print(paint_encoded.get_shape())

    # Build the losses here
    with tf.name_scope('losses'):
        losses = []
        # Content loss between t and f(g(t))
        content_loss = self._build_content_loss(merge_encoded, paint_encoded,
                                                weight=args.weight_c)
        # add_moving_summary(content_loss)
        add_tensor_summary(content_loss, types=['scalar'], name='content_loss')
        losses.append(content_loss)

        # Style losses between the paint and the style
        style_losses = self._build_style_losses(paint_feature, style_feature,
                                                weight=args.weight_s)
        for idx, style_loss in enumerate(style_losses):
            add_tensor_summary(style_loss, types=['scalar'],
                               name='style_loss_%d' % idx)
            losses.append(style_loss)

        # Total variation loss
        smoothness = tf.reduce_sum(tf.image.total_variation(paint))
        add_tensor_summary(smoothness, types=['scalar'], name='smoothness')
        losses.append(smoothness * args.weight_tv)

        # Total loss: this one goes to the optimizer
        self.cost = tf.reduce_sum(losses, name='self.cost')
        add_tensor_summary(self.cost, types=['scalar'], name='self.cost')

    # Reconstruct the image
    # image = image + VGG19_MEAN_TENSOR
    # style = style + VGG19_MEAN_TENSOR
    paint = tf.identity(paint, name='paint')

    # Visualization
    viz = tf.concat([image, style, paint], axis=2)
    viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
    tf.summary.image('colorized', viz, max_outputs=50)
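# _build_adain_layers is not shown in this file. A minimal sketch of the
# standard AdaIN operation it is assumed to perform (Huang & Belongie, 2017):
# the content features are renormalized to match the channel-wise mean and
# standard deviation of the style features. NHWC layout is assumed here.
def adain_sketch(content, style, eps=1e-5):
    c_mean, c_var = tf.nn.moments(content, axes=[1, 2], keep_dims=True)
    s_mean, s_var = tf.nn.moments(style, axes=[1, 2], keep_dims=True)
    return s_mean + tf.sqrt(s_var + eps) * (content - c_mean) / tf.sqrt(c_var + eps)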
def build_graph(self, *inputs):
    is_training = get_current_tower_context().is_training
    image, anchor_labels, anchor_boxes, gt_boxes, gt_labels, gt_ids, orig_shape = inputs
    image = self.preprocess(image)  # 1CHW

    featuremap = resnet_c4_backbone(image, cfg.BACKBONE.RESNET_NUM_BLOCK[:3])
    rpn_label_logits, rpn_box_logits = rpn_head(
        'rpn', featuremap, cfg.RPN.HEAD_DIM, cfg.RPN.NUM_ANCHOR)

    anchors = RPNAnchors(get_all_anchors(), anchor_labels, anchor_boxes)
    anchors = anchors.narrow_to(featuremap)

    image_shape2d = tf.shape(image)[2:]  # h,w
    # decode into actual image coordinates
    pred_boxes_decoded = anchors.decode_logits(rpn_box_logits)  # fHxfWxNAx4, floatbox
    proposal_boxes, proposal_scores = generate_rpn_proposals(
        tf.reshape(pred_boxes_decoded, [-1, 4]),
        tf.reshape(rpn_label_logits, [-1]),
        image_shape2d,
        cfg.RPN.TRAIN_PRE_NMS_TOPK if is_training else cfg.RPN.TEST_PRE_NMS_TOPK,
        cfg.RPN.TRAIN_POST_NMS_TOPK if is_training else cfg.RPN.TEST_POST_NMS_TOPK)

    if is_training:
        # sample proposal boxes in training
        rcnn_boxes, rcnn_labels, fg_inds_wrt_gt = sample_fast_rcnn_targets(
            proposal_boxes, gt_boxes, gt_labels)
    else:
        # The boxes to be used to crop RoIs.
        # Use all proposal boxes in inference.
        rcnn_boxes = proposal_boxes

    boxes_on_featuremap = rcnn_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE)
    # size? #proposals*h*w*c?
    roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)

    feature_fastrcnn = resnet_conv5(roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1])  # nxcx7x7
    # Keep C5 feature to be shared with the mask branch
    feature_gap = GlobalAvgPooling('gap', feature_fastrcnn, data_format='channels_first')
    fastrcnn_label_logits, fastrcnn_box_logits = fastrcnn_outputs(
        'fastrcnn', feature_gap, cfg.DATA.NUM_CLASS)

    if is_training:
        # rpn loss
        rpn_label_loss, rpn_box_loss = rpn_losses(
            anchors.gt_labels, anchors.encoded_gt_boxes(),
            rpn_label_logits, rpn_box_logits)

        # fastrcnn loss
        matched_gt_boxes = tf.gather(gt_boxes, fg_inds_wrt_gt)
        fg_inds_wrt_sample = tf.reshape(tf.where(rcnn_labels > 0), [-1])  # fg inds w.r.t all samples
        # outputs from fg proposals
        fg_sampled_boxes = tf.gather(rcnn_boxes, fg_inds_wrt_sample)
        fg_fastrcnn_box_logits = tf.gather(fastrcnn_box_logits, fg_inds_wrt_sample)

        # rcnn_labels: the labels of the proposals
        # fg_sampled_boxes: fg proposals
        # matched_gt_boxes: just like RPN, the gt boxes
        #                   that match the corresponding fg proposals
        fastrcnn_label_loss, fastrcnn_box_loss = self.fastrcnn_training(
            image, rcnn_labels, fg_sampled_boxes,
            matched_gt_boxes, fastrcnn_label_logits, fg_fastrcnn_box_logits)

        # acquire predictions for re-id training;
        # turning NMS off gives the re-id branch more training samples
        if cfg.RE_ID.NMS:
            boxes, final_labels, final_probs = self.fastrcnn_inference(
                image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)
        else:
            boxes, final_labels, final_probs = self.fastrcnn_inference_id(
                image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)

        # scale = tf.sqrt(tf.cast(image_shape2d[0], tf.float32) / tf.cast(orig_shape[0], tf.float32) *
        #                 tf.cast(image_shape2d[1], tf.float32) / tf.cast(orig_shape[1], tf.float32))
        # final_boxes = boxes / scale
        # # boxes are already clipped inside the graph, but after the floating point scaling, this may not be true any more.
        # final_boxes = tf_clip_boxes(final_boxes, orig_shape)

        # IOU, discard bad dets, assign re-id labels
        # the results are already NMS so no need to NMS again
        # crop from conv4 with dets (maybe plus gts)
        # feedforward re-id branch
        # resizing during ROIalign?
        iou = pairwise_iou(boxes, gt_boxes)  # are the gt boxes resized?
        tp_mask = tf.reduce_max(iou, axis=1) >= cfg.RE_ID.IOU_THRESH
        iou = tf.boolean_mask(iou, tp_mask)
        # return iou to debug

        def re_id_loss(pred_boxes, pred_matching_gt_ids, featuremap):
            with tf.variable_scope('id_head'):
                num_of_samples_used = tf.get_variable(
                    'num_of_samples_used', initializer=0, trainable=False)
                num_of_samples_used = num_of_samples_used.assign_add(
                    tf.shape(pred_boxes)[0])

                boxes_on_featuremap = pred_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE)
                # name scope?
                # stop gradient
                roi_resized = roi_align(featuremap, boxes_on_featuremap, 14)
                feature_idhead = resnet_conv5(
                    roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1])  # nxcx7x7
                feature_gap = GlobalAvgPooling(
                    'gap', feature_idhead, data_format='channels_first')

                init = tf.variance_scaling_initializer()
                hidden = FullyConnected('fc6', feature_gap, 1024,
                                        kernel_initializer=init, activation=tf.nn.relu)
                hidden = FullyConnected('fc7', hidden, 1024,
                                        kernel_initializer=init, activation=tf.nn.relu)
                hidden = FullyConnected('fc8', hidden, 256,
                                        kernel_initializer=init, activation=tf.nn.relu)
                id_logits = FullyConnected(
                    'class', hidden, cfg.DATA.NUM_ID,
                    kernel_initializer=tf.random_normal_initializer(stddev=0.01))

                label_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=pred_matching_gt_ids, logits=id_logits)
                label_loss = tf.reduce_mean(label_loss, name='label_loss')
            return label_loss, num_of_samples_used

        def check_unid_pedes(iou, gt_ids, boxes, tp_mask, featuremap):
            pred_gt_ind = tf.argmax(iou, axis=1)
            # output the following tensors
            # pick out the -2 class here
            pred_matching_gt_ids = tf.gather(gt_ids, pred_gt_ind)
            pred_boxes = tf.boolean_mask(boxes, tp_mask)

            # label 1 corresponds to unidentified pedestrians
            unid_ind = tf.not_equal(pred_matching_gt_ids, 1)
            pred_matching_gt_ids = tf.boolean_mask(pred_matching_gt_ids, unid_ind)
            pred_boxes = tf.boolean_mask(pred_boxes, unid_ind)

            ret = tf.cond(
                tf.equal(tf.size(pred_boxes), 0),
                lambda: (tf.constant(cfg.RE_ID.STABLE_LOSS), tf.constant(0)),
                lambda: re_id_loss(pred_boxes, pred_matching_gt_ids, featuremap))
            return ret

        with tf.name_scope('id_head'):
            # if no detection has IOU above the threshold, re-id returns a constant loss
            re_id_loss, num_of_samples_used = tf.cond(
                tf.equal(tf.size(iou), 0),
                lambda: (tf.constant(cfg.RE_ID.STABLE_LOSS), tf.constant(0)),
                lambda: check_unid_pedes(iou, gt_ids, boxes, tp_mask, featuremap))
            add_tensor_summary(num_of_samples_used, ['scalar'],
                               name='num_of_samples_used')
        # for debug, use the tensor name to take out the handle
        # return re_id_loss
        # pred_gt_ind = tf.argmax(iou, axis=1)
        # # output the following tensors
        # # pick out the -2 class here
        # pred_gt_ids = tf.gather(gt_ids, pred_gt_ind)
        # pred_boxes = tf.boolean_mask(boxes, tp_mask)
        # unid_ind = pred_gt_ids != 1
        # return unid_ind
        # return tf.shape(boxes)[0]

        unnormed_id_loss = tf.identity(re_id_loss, name='unnormed_id_loss')
        re_id_loss = tf.divide(re_id_loss, cfg.RE_ID.LOSS_NORMALIZATION, 're_id_loss')
        add_moving_summary(unnormed_id_loss)
        add_moving_summary(re_id_loss)

        wd_cost = regularize_cost('.*/W', l2_regularizer(cfg.TRAIN.WEIGHT_DECAY),
                                  name='wd_cost')
        # weights on the losses?
        total_cost = tf.add_n([
            rpn_label_loss, rpn_box_loss,
            fastrcnn_label_loss, fastrcnn_box_loss,
            re_id_loss, wd_cost], 'total_cost')
        add_moving_summary(total_cost, wd_cost)
        return total_cost
    else:
        if cfg.RE_ID.QUERY_EVAL:
            # the gt_boxes are resized in the dataflow
            final_boxes = gt_boxes
        else:
            final_boxes, final_labels, _ = self.fastrcnn_inference(
                image_shape2d, rcnn_boxes, fastrcnn_label_logits, fastrcnn_box_logits)

        with tf.variable_scope('id_head'):
            preds_on_featuremap = final_boxes * (1.0 / cfg.RPN.ANCHOR_STRIDE)
            # name scope?
            # stop gradient
            roi_resized = roi_align(featuremap, preds_on_featuremap, 14)
            feature_idhead = resnet_conv5(
                roi_resized, cfg.BACKBONE.RESNET_NUM_BLOCK[-1])  # nxcx7x7
            feature_gap = GlobalAvgPooling('gap', feature_idhead,
                                           data_format='channels_first')

            hidden = FullyConnected('fc6', feature_gap, 1024, activation=tf.nn.relu)
            hidden = FullyConnected('fc7', hidden, 1024, activation=tf.nn.relu)
            fv = FullyConnected('fc8', hidden, 256, activation=tf.nn.relu)
            id_logits = FullyConnected(
                'class', fv, cfg.DATA.NUM_ID,
                kernel_initializer=tf.random_normal_initializer(stddev=0.01))

        scale = tf.sqrt(
            tf.cast(image_shape2d[0], tf.float32) / tf.cast(orig_shape[0], tf.float32) *
            tf.cast(image_shape2d[1], tf.float32) / tf.cast(orig_shape[1], tf.float32))
        rescaled_final_boxes = final_boxes / scale
        # boxes are already clipped inside the graph, but after the floating
        # point scaling, this may not be true any more.
        # rescaled_final_boxes_pre_clip = tf.identity(rescaled_final_boxes, name='re_boxes_pre_clip')
        rescaled_final_boxes = tf_clip_boxes(rescaled_final_boxes, orig_shape)
        rescaled_final_boxes = tf.identity(rescaled_final_boxes, 'rescaled_final_boxes')

        fv = tf.identity(fv, name='feature_vector')
        prob = tf.nn.softmax(id_logits, name='re_id_probs')
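# tf_clip_boxes is referenced above but not defined in this file. A minimal
# sketch of what it is assumed to do: clip (x1, y1, x2, y2) boxes to the
# original image extent. The real helper may differ.
def tf_clip_boxes_sketch(boxes, image_shape):  # image_shape: (h, w)
    h = tf.cast(image_shape[0], tf.float32)
    w = tf.cast(image_shape[1], tf.float32)
    x1, y1, x2, y2 = tf.unstack(boxes, axis=1)
    return tf.stack([tf.clip_by_value(x1, 0.0, w),
                     tf.clip_by_value(y1, 0.0, h),
                     tf.clip_by_value(x2, 0.0, w),
                     tf.clip_by_value(y2, 0.0, h)], axis=1)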
def _build_graph(self, inputs):
    volume, gt_seg, gt_affs = inputs
    volume = tf.expand_dims(volume, 3)
    volume = tf.expand_dims(volume, 0)
    # image = image * 2 - 1
    with argscope(LeakyReLU, alpha=0.2), \
            argscope([Conv3D, DeConv3D], use_bias=False, kernel_shape=3,
                     stride=2, padding='SAME',
                     W_init=tf.contrib.layers.variance_scaling_initializer(
                         factor=0.333, uniform=True)):
        _in = Conv3D('in', volume, NB_FILTERS, kernel_shape=(3, 5, 5),
                     stride=1, padding='SAME', nl=INELU, use_bias=True)
        e0 = residual_enc('e0', volume, NB_FILTERS * 1)
        e1 = residual_enc('e1', e0, NB_FILTERS * 2)
        e2 = residual_enc('e2', e1, NB_FILTERS * 4, kernel_shape=(1, 3, 3))
        e3 = residual_enc('e3', e2, NB_FILTERS * 8, kernel_shape=(1, 3, 3))
        e3 = Dropout('dr', e3, rate=0.5)

        d3 = residual_dec('d3', e3 + e3, NB_FILTERS * 4, kernel_shape=(1, 3, 3))
        d2 = residual_dec('d2', d3 + e2, NB_FILTERS * 2, kernel_shape=(1, 3, 3))
        d1 = residual_dec('d1', d2 + e1, NB_FILTERS * 1)
        d0 = residual_dec('d0', d1 + e0, NB_FILTERS * 1)

        logits = funcs.Conv3D('x_out', d0, len(nhood), kernel_shape=(3, 5, 5),
                              stride=1, padding='SAME', nl=tf.identity,
                              use_bias=True)

    logits = tf.squeeze(logits)
    logits = tf.transpose(logits, perm=[3, 0, 1, 2], name='logits')
    affs = cvt2sigm(tf.tanh(logits))
    affs = tf.identity(affs, name='affs')

    ######################################################################
    wbce_malis_loss = wbce_malis(logits, affs, gt_affs, gt_seg, nhood,
                                 affs_shape, name='wbce_malis', limit_z=False)
    wbce_malis_loss = tf.identity(wbce_malis_loss, name='wbce_malis_loss')
    dice_loss = tf.identity(
        (1. - dice_coe(affs, gt_affs, axis=[0, 1, 2, 3],
                       loss_type='jaccard')) * 0.1,
        name='dice_coe')
    tot_loss = tf.identity(wbce_malis_loss + dice_loss, name='tot_loss')
    ######################################################################

    self.cost = tot_loss
    summary.add_tensor_summary(tot_loss, types=['scalar'])
    summary.add_tensor_summary(dice_loss, types=['scalar'])
    summary.add_tensor_summary(wbce_malis_loss, types=['scalar'])
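# cvt2sigm is assumed to map tanh-range values [-1, 1] to the sigmoid range
# [0, 1], so that cvt2sigm(tf.tanh(logits)) above behaves like a sigmoid
# activation. A sketch of the assumed helper:
def cvt2sigm_sketch(x):
    return (x + 1.0) / 2.0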
def _build_graph(self, inputs):
    G = tf.get_default_graph()  # For round
    tf.local_variables_initializer()
    tf.global_variables_initializer()
    pi, pm, pl, ui, um, ul = inputs
    pi = cvt2tanh(pi)
    pm = cvt2tanh(pm)
    pl = cvt2tanh(pl)
    ui = cvt2tanh(ui)
    um = cvt2tanh(um)
    ul = cvt2tanh(ul)

    # def tf_membr(label):
    #     with freeze_variables():
    #         label = np_2imag(label, maxVal=MAX_LABEL)
    #         label = np.squeeze(label)  # NotImplementedError: only images of dimension 1-3 are supported, got a 4D one
    #         # label, nb_labels = skimage.measure.label(color, return_num=True)
    #         # label = np.expand_dims(label, axis=-1).astype(np.float32)
    #         # Modify here for batch:
    #         # for z in range(membr.shape[0]):
    #         #     membr[z,...] = 1 - skimage.segmentation.find_boundaries(np.squeeze(label[z,...]), mode='thick')  # , mode='inner'
    #         membr = 1 - skimage.segmentation.find_boundaries(np.squeeze(label), mode='thick')  # , mode='inner'
    #         membr = np.expand_dims(membr, axis=-1).astype(np.float32)
    #         membr = np.expand_dims(membr, axis=0).astype(np.float32)
    #         membr = np_2tanh(membr, maxVal=1.0)
    #         membr = np.reshape(membr, label.shape)
    #         return membr

    # def tf_label(color):
    #     with freeze_variables():
    #         color = np_2imag(color, maxVal=MAX_LABEL)
    #         color = np.squeeze(color)  # NotImplementedError: only images of dimension 1-3 are supported, got a 4D one
    #         label, nb_labels = skimage.measure.label(color, return_num=True)
    #         label = np.expand_dims(label, axis=-1).astype(np.float32)
    #         label = np.expand_dims(label, axis=0).astype(np.float32)
    #         label = np_2tanh(label, maxVal=MAX_LABEL)
    #         label = np.reshape(label, color.shape)
    #         return label

    def tf_rand_score(x1, x2):
        return np.mean(1.0 - adjusted_rand_score(x1.flatten(), x2.flatten()))

    def rounded(label, factor=MAX_LABEL, name='quantized'):
        # Straight-through rounding: Round behaves as Identity in the backward pass.
        with G.gradient_override_map({"Round": "Identity"}):
            with freeze_variables():
                with tf.name_scope(name=name):
                    # label = cvt2imag(label, maxVal=factor)
                    # label = tf.round(label)
                    # label = cvt2tanh(label, maxVal=factor)
                    # cvt from -1 ~ 1 to 0 ~ 255
                    # label = cvt2imag(label, maxVal=255.0)
                    cond0 = tf.equal(label, -1.0 * tf.ones_like(label))
                    label = tf.where(cond0, tf.zeros_like(label), label,
                                     name='removedBackground')  # from -1 to 0
                    label = label * factor  # from 0~1 to 0~MAX_LABEL
                    label = tf.round(label)
                    label = label / factor  # from 0~MAX_LABEL to 0~1
                    cond1 = tf.equal(label, 0.0 * tf.zeros_like(label))
                    label = tf.where(cond1, -1.0 * tf.ones_like(label), label,
                                     name='addedBackground')  # from 0 back to -1
                    return tf.identity(label, name=name)

    with argscope([Conv2D, Deconv2D, FullyConnected],
                  W_init=tf.truncated_normal_initializer(stddev=0.02),
                  use_bias=False), \
            argscope(BatchNorm, gamma_init=tf.random_uniform_initializer()), \
            argscope([Conv2D, Deconv2D, BatchNorm], data_format='NHWC'), \
            argscope(LeakyReLU, alpha=0.2):
        with tf.variable_scope('gen'):
            # Real pair: image for the generator
            with tf.variable_scope('I2M'):
                pim, feat_im = self.generator(pi)
            with tf.variable_scope('M2L'):
                piml, feat_iml = self.generator(pim)
                pml, feat_ml = self.generator(pm)
                # piml = tf.py_func(tf_label, [(pim)], tf.float32)
                # pml = tf.py_func(tf_label, [(pm)], tf.float32)
                # print(pim)
                # print(piml)
            # with tf.variable_scope('L2M'):
            #     # with freeze_variables():
            #     pimlm = self.generator(piml)
            #     # plm = self.generator(pl)
            #     pmlm = self.generator(pml)
            #     # pimlm = tf.py_func(tf_membr, [(piml)], tf.float32)
            #     # plm = tf.py_func(tf_membr, [(pl)], tf.float32)
            #     # pmlm = tf.py_func(tf_membr, [(pml)], tf.float32)
            #     # print(piml)
            #     # print(pimlm)
            # with tf.variable_scope('M2I'):
            #     pimlmi = self.generator(pimlm)
            #     # pimi = self.generator(pim)

            # Real pair: label for the generator
            # with tf.variable_scope('L2M'):
            #     # with freeze_variables():
            #     plm = self.generator(pl)
            #     # plm = tf.py_func(tf_membr, [(pl)], tf.float32)
            # with tf.variable_scope('M2I'):
            #     plmi = self.generator(plm)
            #     # pmi = self.generator(pi)
            # with tf.variable_scope('I2M'):
            #     plmim = self.generator(plmi)
            #     # pim = self.generator(pi)
            #     pmim = self.generator(pmi)
            # with tf.variable_scope('M2L'):
            #     plmiml = self.generator(plmim)
            #     # plml = self.generator(plm)
            #     # plmiml = tf.py_func(tf_label, [(plmim)], tf.float32)
            #     # plml = tf.py_func(tf_label, [(plm)], tf.float32)

        with tf.variable_scope('discrim'):
            # with tf.variable_scope('I'):
            #     i_dis_real = self.discriminator(ui)
            #     i_dis_fake_from_label = self.discriminator(plmi)
            with tf.variable_scope('M'):
                m_dis_real = self.discriminator(um)
                m_dis_fake_from_image = self.discriminator(pim)
                # m_dis_fake_from_label = self.discriminator(plm)
            with tf.variable_scope('L'):
                l_dis_real = self.discriminator(ul)
                l_dis_fake_from_image = self.discriminator(piml)

    piml = rounded(piml)
    # pml = rounded(pml)
    # plmiml = rounded(plmiml)
    # # plml = rounded(plml)

    # with tf.name_scope('Recon_I_loss'):
    #     recon_imi = tf.reduce_mean(tf.abs((pi) - (pimi)), name='recon_imi')
    #     recon_lmi = tf.reduce_mean(tf.abs((pi) - (plmi)), name='recon_lmi')
    #     recon_imlmi = tf.reduce_mean(tf.abs((pi) - (pimlmi)), name='recon_imlmi')

    # with tf.name_scope('Recon_L_loss'):
    #     recon_lml = tf.reduce_mean(tf.abs((pl) - (plml)), name='recon_lml')
    #     recon_iml = tf.reduce_mean(tf.abs((pl) - (piml)), name='recon_iml')
    #     recon_ml = tf.reduce_mean(tf.abs((pl) - (pml)), name='recon_ml')
    #     recon_lmiml = tf.reduce_mean(tf.abs((pl) - (plmiml)), name='recon_lmiml')
    # recon_iml = tf.reduce_mean(tf.cast(tf.not_equal(pl, piml), tf.float32), name='recon_iml')
    # recon_ml = tf.reduce_mean(tf.cast(tf.not_equal(pl, pml), tf.float32), name='recon_ml')
    recon_ml = tf.reduce_mean(tf.abs(cvt2imag(pl, maxVal=MAX_LABEL) -
                                     cvt2imag(pml, maxVal=MAX_LABEL)),
                              name='recon_ml')
    recon_iml = tf.reduce_mean(tf.abs(cvt2imag(pl, maxVal=MAX_LABEL) -
                                      cvt2imag(piml, maxVal=MAX_LABEL)),
                               name='recon_iml')

    # with tf.name_scope('Recon_M_loss'):
    #     recon_mim = tf.reduce_mean(tf.abs((pm) - (pmim)), name='recon_mim')
    #     recon_mlm = tf.reduce_mean(tf.abs((pm) - (pmlm)), name='recon_mlm')
    #     recon_im = tf.reduce_mean(tf.abs((pm) - (pim)), name='recon_im')
    #     recon_lm = tf.reduce_mean(tf.abs((pm) - (plm)), name='recon_lm')
    # recon_im = tf.reduce_mean(tf.cast(tf.not_equal(pm, pim), tf.float32), name='recon_im')
    recon_im = tf.reduce_mean(tf.abs(cvt2imag(pm, maxVal=1.0) -
                                     cvt2imag(pim, maxVal=1.0)),
                              name='recon_im')

    with tf.name_scope('GAN_loss'):
        # G_loss_IL, D_loss_IL = self.build_losses(i_dis_real, i_dis_fake_from_label, name='IL')
        G_loss_LI, D_loss_LI = self.build_losses(l_dis_real, l_dis_fake_from_image, name='LL')
        G_loss_MI, D_loss_MI = self.build_losses(m_dis_real, m_dis_fake_from_image, name='MI')
        # G_loss_ML, D_loss_ML = self.build_losses(m_dis_real, m_dis_fake_from_label, name='ML')

    # custom loss for membrane
    with tf.name_scope('membr_loss'):
        def membr_loss(y_true, y_pred, name='membr_loss'):
            loss = tf.reduce_mean(tf.subtract(
                binary_cross_entropy(cvt2imag(y_true, maxVal=1.0),
                                     cvt2imag(y_pred, maxVal=1.0)),
                dice_coe(cvt2imag(y_true, maxVal=1.0),
                         cvt2imag(y_pred, maxVal=1.0),
                         axis=[1, 2, 3], loss_type='jaccard')))
            # loss = tf.reshape(loss, [-1])
            return tf.identity(loss, name=name)

        membr_im = membr_loss(pm, pim, name='membr_im')
        # print(membr_im)
        # membr_lm = membr_loss(pm, plm, name='membr_lm')
        # membr_imlm = membr_loss(pm, pimlm, name='membr_imlm')
        # membr_lmim = membr_loss(pm, plmim, name='membr_lmim')
        # membr_mlm = membr_loss(pm, pmlm, name='membr_mlm')
        # membr_mim = membr_loss(pm, pmim, name='membr_mim')

    # custom loss for label
    with tf.name_scope('label_loss'):
        def label_loss(y_true_L, y_pred_L, y_grad_M, name='label_loss'):
            g_mag_grad_M = cvt2imag(y_grad_M, maxVal=1.0)
            mag_grad_L = magnitute_central_difference(y_pred_L, name='mag_grad_L')
            cond = tf.greater(mag_grad_L, tf.zeros_like(mag_grad_L))
            thresholded_mag_grad_L = tf.where(cond,
                                              tf.ones_like(mag_grad_L),
                                              tf.zeros_like(mag_grad_L),
                                              name='thresholded_mag_grad_L')
            gtv_guess = tf.multiply(g_mag_grad_M, thresholded_mag_grad_L,
                                    name='gtv_guess')
            loss_gtv_guess = tf.reduce_mean(gtv_guess, name='loss_gtv_guess')
            # loss_gtv_guess = tf.reshape(loss_gtv_guess, [-1])
            thresholded_mag_grad_L = cvt2tanh(thresholded_mag_grad_L, maxVal=1.0)
            gtv_guess = cvt2tanh(gtv_guess, maxVal=1.0)
            return tf.identity(loss_gtv_guess, name=name), thresholded_mag_grad_L

        label_iml, g_iml = label_loss(None, piml, pim, name='label_iml')
        # label_lml, g_lml = label_loss(None, plml, plm, name='label_lml')
        # label_lmiml, g_lmiml = label_loss(None, plmiml, plmim, name='label_lmiml')
        label_ml, g_ml = label_loss(None, pml, pm, name='label_loss_ml')

    # custom loss for tf_rand_score
    with tf.name_scope('rand_loss'):
        rand_iml = tf.reduce_mean(tf.cast(
            tf.py_func(tf_rand_score, [piml, pl], tf.float64), tf.float32))
        rand_ml = tf.reduce_mean(tf.cast(
            tf.py_func(tf_rand_score, [pml, pl], tf.float64), tf.float32))

    with tf.name_scope('discrim_loss'):
        def regDLF(y_true, y_pred, alpha=1, beta=1, gamma=0.01,
                   delta_v=0.5, delta_d=1.5, name='loss_discrim'):
            def tf_norm(inputs, axis=1, epsilon=1e-7, name='safe_norm'):
                squared_norm = tf.reduce_sum(tf.square(inputs), axis=axis,
                                             keep_dims=True)
                safe_norm = tf.sqrt(squared_norm + epsilon)
                return tf.identity(safe_norm, name=name)
            ###
            # Append normalized flat/z/y/x coordinate channels to the embedding.
            lins = tf.linspace(0.0, DIMZ * DIMY * DIMX, DIMZ * DIMY * DIMX)
            lins = tf.cast(lins, tf.int32)
            # lins = lins / tf.reduce_max(lins) * 255
            # lins = cvt2tanh(lins)
            # lins = tf.reshape(lins, tf.shape(y_true), name='lins_3d')
            # print(lins)
            lins_z = tf.div(lins, (DIMY * DIMX))
            lins_y = tf.div(tf.mod(lins, (DIMY * DIMX)), DIMY)
            lins_x = tf.mod(tf.mod(lins, (DIMY * DIMX)), DIMY)

            lins = tf.cast(lins, tf.float32)
            lins_z = tf.cast(lins_z, tf.float32)
            lins_y = tf.cast(lins_y, tf.float32)
            lins_x = tf.cast(lins_x, tf.float32)

            lins = lins / tf.reduce_max(lins) * 255
            lins_z = lins_z / tf.reduce_max(lins_z) * 255
            lins_y = lins_y / tf.reduce_max(lins_y) * 255
            lins_x = lins_x / tf.reduce_max(lins_x) * 255

            lins = cvt2tanh(lins)
            lins_z = cvt2tanh(lins_z)
            lins_y = cvt2tanh(lins_y)
            lins_x = cvt2tanh(lins_x)

            lins = tf.reshape(lins, tf.shape(y_true), name='lins')
            lins_z = tf.reshape(lins_z, tf.shape(y_true), name='lins_z')
            lins_y = tf.reshape(lins_y, tf.shape(y_true), name='lins_y')
            lins_x = tf.reshape(lins_x, tf.shape(y_true), name='lins_x')

            y_true = tf.reshape(y_true, [DIMZ * DIMY * DIMX])
            y_pred = tf.concat([y_pred, lins, lins_z, lins_y, lins_x], axis=-1)

            nDim = tf.shape(y_pred)[-1]
            X = tf.reshape(y_pred, [DIMZ * DIMY * DIMX, nDim])
            uniqueLabels, uniqueInd = tf.unique(y_true)
            numUnique = tf.size(uniqueLabels)  # number of connected components
            Sigma = tf.unsorted_segment_sum(X, uniqueInd, numUnique)
            # ones_Sigma = tf.ones((tf.shape(X)[0], 1))
            ones_Sigma = tf.ones_like(X)
            ones_Sigma = tf.unsorted_segment_sum(ones_Sigma, uniqueInd, numUnique)
            mu = tf.divide(Sigma, ones_Sigma)  # per-cluster mean embedding

            Lreg = tf.reduce_mean(tf.norm(mu, axis=1, ord=1))

            T = tf.norm(tf.subtract(tf.gather(mu, uniqueInd), X), axis=1, ord=1)
            T = tf.divide(T, Lreg)
            T = tf.subtract(T, delta_v)
            T = tf.clip_by_value(T, 0, T)
            T = tf.square(T)

            ones_Sigma = tf.ones_like(uniqueInd, dtype=tf.float32)
            ones_Sigma = tf.unsorted_segment_sum(ones_Sigma, uniqueInd, numUnique)
            clusterSigma = tf.unsorted_segment_sum(T, uniqueInd, numUnique)
            clusterSigma = tf.divide(clusterSigma, ones_Sigma)
            # Lvar = tf.reduce_mean(clusterSigma, axis=0)
            Lvar = tf.reduce_mean(clusterSigma)

            mu_interleaved_rep = tf.tile(mu, [numUnique, 1])
            mu_band_rep = tf.tile(mu, [1, numUnique])
            mu_band_rep = tf.reshape(mu_band_rep, (numUnique * numUnique, nDim))
            mu_diff = tf.subtract(mu_band_rep, mu_interleaved_rep)
            # Remove zero vectors (differences of a cluster mean with itself):
            # intermediate_tensor = reduce_sum(tf.abs(x), 1)
            # zero_vector = tf.zeros(shape=(1,1), dtype=tf.float32)
            # bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
            # omit_zeros = tf.boolean_mask(x, bool_mask)
            intermediate_tensor = tf.reduce_sum(tf.abs(mu_diff), 1)
            zero_vector = tf.zeros(shape=(1, 1), dtype=tf.float32)
            bool_mask = tf.not_equal(intermediate_tensor, zero_vector)
            omit_zeros = tf.boolean_mask(mu_diff, bool_mask)
            mu_diff = tf.expand_dims(omit_zeros, axis=1)
            print(mu_diff)
            mu_diff = tf.norm(mu_diff, ord=1)
            # squared_norm = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
            # safe_norm = tf.sqrt(squared_norm + epsilon)
            # squared_norm = tf.reduce_sum(tf.square(omit_zeros), axis=-1, keep_dims=True)
            # safe_norm = tf.sqrt(squared_norm + 1e-6)
            # mu_diff = safe_norm
            mu_diff = tf.divide(mu_diff, Lreg)
            mu_diff = tf.subtract(2 * delta_d, mu_diff)
            mu_diff = tf.clip_by_value(mu_diff, 0, mu_diff)
            mu_diff = tf.square(mu_diff)
            numUniqueF = tf.cast(numUnique, tf.float32)
            Ldist = tf.reduce_mean(mu_diff)

            # L = alpha * Lvar + beta * Ldist + gamma * Lreg
            # L = tf.reduce_mean(L, keep_dims=True)
            L = tf.reduce_sum([alpha * Lvar, beta * Ldist, gamma * Lreg],
                              keep_dims=False)
            print(L)
            print(Ldist)
            print(Lvar)
            print(Lreg)
            return tf.identity(L, name=name)

        discrim_im = regDLF(cvt2imag(pm, maxVal=1.0), feat_im, name='discrim_im')
        discrim_iml = regDLF(cvt2imag(pl, maxVal=MAX_LABEL), feat_iml, name='discrim_iml')
        discrim_ml = regDLF(cvt2imag(pl, maxVal=MAX_LABEL), feat_ml, name='discrim_ml')
        print(discrim_im)
        print(discrim_iml)
        print(discrim_ml)
        print(rand_iml)
        print(rand_ml)

    self.g_loss = tf.reduce_sum([
        # (recon_imi),  # + recon_lmi + recon_imlmi
        10 * (recon_iml),  # + recon_lml + recon_lmiml
        10 * (recon_im),   # + recon_lm + recon_mim + recon_mlm
        10 * (recon_ml),   # + recon_lm + recon_mim + recon_mlm
        (rand_iml),        # + rand_lml + rand_lmiml
        (rand_ml),         # + rand_lm + rand_mim + rand_mlm
        # (G_loss_IL + G_loss_LI + G_loss_MI + G_loss_ML),
        (G_loss_LI + G_loss_MI),
        (0.1 * discrim_im + 10 * discrim_iml + 10 * discrim_ml),
        (0.001 * membr_im),  # + membr_lm + membr_imlm + membr_lmim + membr_mlm + membr_mim
        # (label_iml + label_lml + label_lmiml + label_ml)
        (label_iml + label_ml)
    ], name='G_loss_total')
    self.d_loss = tf.reduce_sum([
        # (D_loss_IL + D_loss_LI + D_loss_MI + D_loss_ML),
        (D_loss_LI + D_loss_MI),
    ], name='D_loss_total')

    wd_g = regularize_cost('gen/.*/W', l2_regularizer(1e-5), name='G_regularize')
    wd_d = regularize_cost('discrim/.*/W', l2_regularizer(1e-5), name='D_regularize')
    self.g_loss = tf.add(self.g_loss, wd_g, name='g_loss')
    self.d_loss = tf.add(self.d_loss, wd_d, name='d_loss')

    self.collect_variables()
    add_moving_summary(self.d_loss, self.g_loss)

    with tf.name_scope('summaries'):
        add_tensor_summary(recon_iml, types=['scalar'], name='recon_iml')
        add_tensor_summary(recon_im, types=['scalar'], name='recon_im')
        add_tensor_summary(recon_ml, types=['scalar'], name='recon_ml')
        add_tensor_summary(label_iml, types=['scalar'], name='label_iml')
        add_tensor_summary(label_ml, types=['scalar'], name='label_ml')
        add_tensor_summary(rand_iml, types=['scalar'], name='rand_iml')
        add_tensor_summary(rand_ml, types=['scalar'], name='rand_ml')
        add_tensor_summary(membr_im, types=['scalar'], name='membr_im')
        add_tensor_summary(discrim_im, types=['scalar'], name='discrim_im')
        add_tensor_summary(discrim_iml, types=['scalar'], name='discrim_iml')
        add_tensor_summary(discrim_ml, types=['scalar'], name='discrim_ml')

    viz = tf.concat([
        tf.concat([ui, pi, pim, piml, g_iml], 2),
        # tf.concat([ul, pl, plm, plmi, plmim, plmiml], 2),
        tf.concat([um, pl, pm, pml, g_ml], 2),
        # tf.concat([pl, pl, g_iml, g_lml, g_lmiml, g_ml], 2),
    ], 1)
    # viz = tf.concat([tf.concat([ui, pi, pim, piml], 2),
    #                  tf.concat([ul, pl, plm, plmi], 2),
    #                  tf.concat([um, pm, pmi, pmim], 2),
    #                  tf.concat([um, pm, pml, pmlm], 2),
    #                  ], 1)
    viz = cvt2imag(viz)
    viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
    tf.summary.image('colorized', viz, max_outputs=50)
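# A minimal standalone sketch of the straight-through rounding trick used by
# `rounded` above: gradient_override_map replaces the (zero) gradient of the
# Round op with the identity, so quantization is transparent to backprop.
# Names here are illustrative.
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.placeholder(tf.float32, [None])
    with g.gradient_override_map({"Round": "Identity"}):
        y = tf.round(x)
    grad = tf.gradients(tf.reduce_sum(y), x)[0]  # all ones instead of zeros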