def create_ntn(image, label, n_class, filter_size, num_of_feature, num_of_layers,
               keep_prob, name, reg_weight, debug, restore=False, weights=None,
               Unsupervised=False, ClassWeights=[1, 0.2, 1]):
    """Build the NTN graph and its noisy-label training loss.

    Args:
        image: input image tensor, forwarded to ``NTN``.
        label: noisy label tensor; reshaped to ``[-1, n_class]`` and cast to
            float32 before the cross-entropy.
        n_class: number of classes.
        filter_size, num_of_feature, num_of_layers, keep_prob, name, restore,
            weights, Unsupervised: forwarded unchanged to ``NTN``.
        reg_weight: coefficient of the L2 weight decay over ``variables`` and
            ``TransProbVar``. ``None`` disables weight decay (previously
            ``None`` raised TypeError when multiplied by the decay term).
        debug: if true, add variable summaries and a scalar loss summary.
        ClassWeights: not used by the active code (only referenced in an old,
            commented-out loss); kept so existing callers keep working.

    Returns:
        ``(loss, noise_y_out, clean_y_out, MapTransProb, variables,
        TransProbVar, dw_h_convs)``: the scalar tensor named ``'loss'``
        followed by the six values returned by ``NTN``.
    """
    with tf.name_scope("NTN"):
        (noise_y_out, clean_y_out, MapTransProb, variables, TransProbVar,
         dw_h_convs) = NTN(image, n_class, filter_size, num_of_feature,
                           num_of_layers, keep_prob, name, debug, restore,
                           weights, Unsupervised)

    if debug:
        for var in variables:
            utils.add_to_regularization_and_summary(var)
        for var in TransProbVar:
            utils.add_to_regularization_and_summary(var)

    with tf.name_scope("loss"):
        # The loss is computed on the *noisy* prediction, since only noisy
        # labels are supervised here.
        loss = utils.cross_entropy(
            tf.cast(tf.reshape(label, [-1, n_class]), tf.float32),
            tf.reshape(noise_y_out, [-1, n_class]))
        if reg_weight is not None:
            weight_decay = 0
            for var in TransProbVar:
                weight_decay = weight_decay + tf.nn.l2_loss(var)
            for var in variables:
                weight_decay = weight_decay + tf.nn.l2_loss(var)
            loss = loss + reg_weight * weight_decay
        loss = tf.reduce_sum(loss, name='loss')
        if debug:
            utils.add_scalar_summary(loss)

    return (loss, noise_y_out, clean_y_out, MapTransProb, variables,
            TransProbVar, dw_h_convs)
def ANTN(image, label, n_class, filter_size, num_of_branch, num_of_feature,
         num_of_layers, clean_network_hidden, trans_network_hidden, debug,
         keep_prob=1.0, restore_clean=False, restore_tran=False, weights=None,
         all_tran_var=None):
    """Build the clean network, per-branch transition networks and their losses.

    A single U-Net ("clean network") predicts per-pixel class probabilities.
    Each of ``num_of_branch`` further U-Nets predicts, per pixel, an
    ``n_class x n_class`` row-stochastic transition matrix. The per-branch noisy
    prediction is ``clean (1 x n_class) @ transition (n_class x n_class)``.

    Args:
        image: input image tensor fed to every U-Net.
        label: indexable by branch; ``label[branch]`` is that branch's noisy
            label, reshaped to ``[-1, n_class]``.
        n_class: number of classes.
        filter_size, num_of_feature, num_of_layers, keep_prob: forwarded to
            ``unet``.
        num_of_branch: number of transition networks / noisy label sources.
        clean_network_hidden: integer targets for the clean network, reshaped
            to ``[-1]`` for sparse softmax cross-entropy.
        trans_network_hidden: indexable by branch; targets for each transition
            network, reshaped to ``[-1, n_class * n_class]``.
        debug: if true, add activation/image/variable/loss summaries.
        restore_clean, weights: restore flag and weights for the clean U-Net.
        restore_tran: if true, restore each transition U-Net from
            ``all_tran_var[branch]``.
        all_tran_var: per-branch weights used when ``restore_tran`` is true.

    Returns:
        ``(all_noise_loss, clean_net_loss, all_trans_net_loss, clean_y_out,
        all_tran_y_out, all_noise_y_out, clean_var, all_tran_var)``, where the
        ``all_*`` entries are lists with one item per branch and
        ``all_tran_var`` holds the variables created by this call.

    Raises:
        ValueError: if ``restore_tran`` is true but ``all_tran_var`` is None.
    """
    # The name ``all_tran_var`` is reused below for the variables created
    # here, so keep the caller's restore weights under their own name. (The
    # old code reset it to [] before indexing it, and also tested an
    # undefined name ``restore_trans``.)
    tran_restore_weights = all_tran_var
    if restore_tran and tran_restore_weights is None:
        raise ValueError(
            "restore_tran=True requires all_tran_var (one weight set per branch)")

    with tf.name_scope("clean-network"):
        clean_y_feat, clean_var, _ = unet(
            image, n_class, filter_size, num_of_feature, num_of_layers,
            keep_prob, 'main', debug, restore_clean, weights)
        # Per-pixel softmax over the class axis, clipped to [1e-6, 1]
        # (presumably to keep the log in the cross-entropy finite).
        clean_y_out = tf.clip_by_value(
            tf.reshape(tf.nn.softmax(tf.reshape(clean_y_feat, [-1, n_class])),
                       tf.shape(clean_y_feat), 'clean_map'),
            1e-6, 1.0)
    if debug:
        utils.add_activation_summary(clean_y_out)
        utils.add_to_image_summary(clean_y_out)
        for var in clean_var:
            utils.add_to_regularization_and_summary(var)

    # Transition networks, one per branch.
    all_tran_y_out = []
    all_tran_var = []
    for branch in range(num_of_branch):
        with tf.name_scope("transition-network" + str(branch)):
            if restore_tran:
                tran_y_feat, tran_var, _ = unet(
                    image, n_class * n_class, filter_size, num_of_feature,
                    num_of_layers, keep_prob, 'tran' + str(branch), debug,
                    restore_tran, tran_restore_weights[branch])
            else:
                tran_y_feat, tran_var, _ = unet(
                    image, n_class * n_class, filter_size, num_of_feature,
                    num_of_layers, keep_prob, 'tran' + str(branch), debug)

            # Channels [i*n_class, (i+1)*n_class) form row i of the transition
            # matrix; softmax each row separately so every row sums to 1.
            class_tran_y_out = []
            for i in range(n_class):
                row_logits = tran_y_feat[:, :, :, i * n_class:(i + 1) * n_class]
                class_tran_y_out.append(
                    tf.reshape(
                        tf.nn.softmax(tf.reshape(row_logits, [-1, n_class])),
                        tf.shape(clean_y_feat), name='tran_map' + str(i)))
            tran_y_out = tf.clip_by_value(
                tf.concat(class_tran_y_out, 3, name='tran_map'), 1e-6, 1.0)
            all_tran_y_out.append(tran_y_out)
            all_tran_var.append(tran_var)

            if debug:
                for i in range(n_class):
                    for j in range(n_class):
                        # Channel i*n_class + j is entry (i, j) of the
                        # transition matrix used in the integration step
                        # (clean class i -> noisy class j). The old code read
                        # channel i for every j.
                        z = tf.identity(tran_y_out[:, :, :, i * n_class + j],
                                        name=str(i) + 'to' + str(j))
                        utils.add_activation_summary(z)
                for var in tran_var:
                    utils.add_to_regularization_and_summary(var)

    # Integration: noisy prediction = clean prediction times transition matrix.
    all_noise_y_out = []
    for branch in range(num_of_branch):
        with tf.name_scope("integration" + str(branch)):
            noise_y_out = tf.reshape(
                tf.matmul(
                    tf.reshape(clean_y_out, [-1, 1, n_class]),
                    tf.reshape(all_tran_y_out[branch], [-1, n_class, n_class])),
                tf.shape(clean_y_feat), name='noise_output')
            all_noise_y_out.append(noise_y_out)
            if debug:
                utils.add_activation_summary(noise_y_out)
                utils.add_to_image_summary(noise_y_out)

    with tf.name_scope("loss"):
        clean_net_loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.reshape(clean_network_hidden, [-1]),
                logits=tf.reshape(clean_y_feat, [-1, n_class])),
            name='clean_net_loss')

        all_noise_loss = []
        all_trans_net_loss = []
        for branch in range(num_of_branch):
            noise_loss = utils.cross_entropy(
                tf.cast(tf.reshape(label[branch], [-1, n_class]), tf.float32),
                tf.reshape(all_noise_y_out[branch], [-1, n_class]),
                'noise_loss')
            # NOTE(review): all_tran_y_out holds clipped row-wise softmax
            # probabilities, not logits, and this op normalizes across all
            # n_class*n_class entries rather than per row — confirm intended.
            trans_net_loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(
                    labels=tf.reshape(trans_network_hidden[branch],
                                      [-1, n_class * n_class]),
                    logits=tf.reshape(all_tran_y_out[branch],
                                      [-1, n_class * n_class])),
                name='trans_net_loss')
            all_noise_loss.append(noise_loss)
            all_trans_net_loss.append(trans_net_loss)
            if debug:
                utils.add_scalar_summary(noise_loss)

    return (all_noise_loss, clean_net_loss, all_trans_net_loss, clean_y_out,
            all_tran_y_out, all_noise_y_out, clean_var, all_tran_var)
def create_autoencoder_antn(image, decoded_image, label, posterior_prob, n_class,
                            filter_size, num_of_feature, num_of_layers, keep_prob,
                            debug, constrain_weight=0, reg_weight=0, restore=False,
                            shared_weights=None, M_weights=None, AE_weights=None):
    """Build the autoencoder-ANTN graph and its three training objectives.

    Args:
        image: input image tensor; its last static dimension is the channel
            count used to reshape the reconstruction.
        decoded_image: reconstruction target for the autoencoder output.
        label: noisy one-hot labels; reshaped to ``[-1, n_class]`` and also
            arg-maxed on axis 3 for the sparse regularization term.
        posterior_prob: target probabilities for the clean output.
        n_class, filter_size, num_of_feature, num_of_layers, keep_prob, debug,
            restore, shared_weights, M_weights, AE_weights: forwarded to
            ``autoencorder_antn``.
        constrain_weight: mix between the posterior L2 term (weight
            ``1 - constrain_weight``) and the sparse cross-entropy on the noisy
            labels (weight ``constrain_weight``).
        reg_weight: L2 weight-decay coefficient; ``None`` disables weight decay
            (previously ``None`` raised TypeError).

    Returns:
        ``(noise_y_out, clean_y_out, tran_y_feat, tran_map, AE_conv, loss,
        M_loss, M_hidden_loss, AE_loss, shared_variables, AE_variables,
        M_variables, inter_feat)`` where:

        - ``loss`` is the joint objective.
        - ``M_hidden_loss`` is the classifier-side objective.
        - ``AE_loss`` is the reconstruction objective.
        - ``M_loss`` is the noisy-label cross-entropy. It is reported only; it
          is not part of ``loss``.
    """
    channel = image.get_shape().as_list()[-1]
    with tf.name_scope("autoencoder-antn"):
        (noise_y_out, clean_y_out, y_conv, tran_y_feat, tran_map, AE_conv,
         shared_variables, AE_variables, M_variables,
         inter_feat) = autoencorder_antn(
             image, n_class, filter_size, num_of_feature, num_of_layers,
             keep_prob, debug, restore, shared_weights, M_weights, AE_weights)

    with tf.name_scope("loss"):
        M_loss = tf.reduce_mean(
            utils.cross_entropy(
                tf.cast(tf.reshape(label, [-1, n_class]), tf.float32),
                tf.reshape(noise_y_out, [-1, n_class])),
            name='M_loss')
        M_hidden_loss = tf.reduce_mean(
            l2_loss(
                tf.cast(tf.reshape(posterior_prob, [-1, n_class]), tf.float32),
                tf.reshape(clean_y_out, [-1, n_class])))
        M_hidden_loss_reg = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.reshape(tf.argmax(label, 3), [-1]),
                logits=tf.reshape(y_conv, [-1, n_class])),
            name='M_hidden_loss_reg')
        AE_loss = tf.reduce_mean(
            l2_loss(tf.reshape(decoded_image, [-1, channel]),
                    tf.reshape(AE_conv, [-1, channel])))

        # Treat None as "no weight decay" instead of crashing on None * tensor.
        if reg_weight is None:
            reg_weight = 0
        # Build each group's L2 sum once and combine; the old code rebuilt the
        # shared-variable terms three times.
        shared_decay = sum(tf.nn.l2_loss(var) for var in shared_variables)
        M_decay = sum(tf.nn.l2_loss(var) for var in M_variables)
        AE_decay = sum(tf.nn.l2_loss(var) for var in AE_variables)
        decay_all = shared_decay + M_decay + AE_decay
        decay_AE = shared_decay + AE_decay
        decay_M = shared_decay + M_decay

        hidden_term = ((1 - constrain_weight) * M_hidden_loss +
                       constrain_weight * M_hidden_loss_reg)
        loss = tf.reduce_sum(hidden_term + AE_loss + reg_weight * decay_all,
                             name='loss')
        M_hidden_loss = tf.reduce_sum(hidden_term + reg_weight * decay_M,
                                      name='M_hidden_loss')
        AE_loss = tf.reduce_sum(AE_loss + reg_weight * decay_AE, name='AE_loss')

        if debug:
            utils.add_scalar_summary(loss)
            utils.add_scalar_summary(M_loss)
            utils.add_scalar_summary(AE_loss)
            for var in (shared_variables + AE_variables + M_variables):
                utils.add_to_regularization_and_summary(var)

    return (noise_y_out, clean_y_out, tran_y_feat, tran_map, AE_conv, loss,
            M_loss, M_hidden_loss, AE_loss, shared_variables, AE_variables,
            M_variables, inter_feat)
def _to_object_array(arrays):
    """Pack a list of (possibly differently shaped) arrays into a 1-D object array.

    ``np.save`` on a ragged Python list depends on implicit object-array
    creation, which NumPy >= 1.24 rejects with ValueError. Building the
    container explicitly works on every NumPy version. Element ``i`` of the
    result is ``arrays[i]``.
    """
    packed = np.empty(len(arrays), dtype=object)
    for idx, arr in enumerate(arrays):
        packed[idx] = arr
    return packed


def main(CHANNEL, NClass, FILTER_SIZE, NUM_OF_FEATURE, NUM_OF_LAYERS, NoisyInput,
         NoisyOutput, CleanInput, CleanOutput, MAX_EPOCH, BatchSize, KEEP_PROB,
         REG_WEIGHT, LearningRate, RESTORE, SizeLimitation, DirSave, DirLoad=None,
         Debug=True):
    """Train a U-Net on noisy labels while monitoring accuracy on clean labels.

    Training data is ``(NoisyInput, NoisyOutput)``. Evaluation data is
    ``(CleanInput, CleanOutput)``; only its first ``SizeLimitation`` items are
    used.

    - Every 10 iterations: write summaries, a checkpoint (``DirSave +
      "model.ckpt"``) and the variable values (``DirSave + "weights.npy"``).
    - Every 50 iterations: evaluate ``CleanAcc`` on the evaluation data.
    - After each epoch: shuffle the training data.

    Args:
        RESTORE: if true, load initial weights from ``DirLoad`` (a ``.npy``
            file written by this function).
        KEEP_PROB: dropout keep probability used for training steps only;
            evaluation runs use 1.0.
        Remaining arguments configure the network and training loop.

    Returns:
        ``(CleanAcc, NoiseAcc)`` from the most recent evaluations, or ``None``
        for a metric that was never evaluated (e.g. zero iterations).
    """
    ImageSize = [NoisyInput.shape[1], NoisyInput.shape[2]]
    OutputImageSize = [NoisyOutput.shape[1], NoisyOutput.shape[2]]
    NAME = 'unet'
    tf.reset_default_graph()

    with tf.name_scope("Input"):
        image = tf.placeholder(
            tf.float32, shape=[None, ImageSize[0], ImageSize[1], CHANNEL],
            name="input_image")
        NoisyLabel = tf.placeholder(
            tf.int64,
            shape=[None, OutputImageSize[0], OutputImageSize[1], NClass],
            name="NoisyLabel")
        CleanLabel = tf.placeholder(
            tf.int64,
            shape=[None, OutputImageSize[0], OutputImageSize[1], NClass],
            name="CleanLabel")
        keep_prob = tf.placeholder(tf.float32, shape=[], name="keep_prob")
        utils.add_to_image_summary(image)

    BATCHES = np.shape(NoisyInput)[0]
    tr_image_data = NoisyInput
    tr_label_data = NoisyOutput
    val_image_data = CleanInput
    val_label_data = CleanOutput

    with tf.name_scope("net"):
        if RESTORE:
            # The weights file is a pickled object array (see _to_object_array),
            # so allow_pickle is required; only load files you trust.
            weights = np.load(DirLoad, allow_pickle=True)
            NoisyLoss, CleanOut, variables, _ = model.create_unet(
                image, NoisyLabel, NClass, FILTER_SIZE, NUM_OF_FEATURE,
                NUM_OF_LAYERS, keep_prob, NAME, REG_WEIGHT, Debug,
                restore=RESTORE, weights=weights)
            print("Model restored...")
        else:
            NoisyLoss, CleanOut, variables, _ = model.create_unet(
                image, NoisyLabel, NClass, FILTER_SIZE, NUM_OF_FEATURE,
                NUM_OF_LAYERS, keep_prob, NAME, REG_WEIGHT, Debug)

    CleanLoss = utils.cross_entropy(
        tf.cast(tf.reshape(CleanLabel, [-1, NClass]), tf.float32),
        tf.reshape(CleanOut, [-1, NClass]), 'Cleanloss')
    NoiseAcc = tf.reduce_mean(
        tf.cast(tf.reshape(tf.equal(tf.argmax(NoisyLabel, 3),
                                    tf.argmax(CleanOut, 3)), [-1]),
                tf.float32),
        name='NoiseAcc')
    CleanAcc = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(CleanLabel, 3), tf.argmax(CleanOut, 3)),
                tf.float32),
        name='CleanAcc')
    utils.add_scalar_summary(NoiseAcc)
    utils.add_scalar_summary(CleanAcc)
    utils.add_scalar_summary(CleanLoss)

    with tf.name_scope("Train"):
        trainable_var = tf.trainable_variables()
        train_op = train(NoisyLoss, trainable_var, LearningRate, Debug)

    print("Setting up summary op...")
    summary_op = tf.summary.merge_all()

    # To run on CPU instead:
    #   sess = tf.Session(config=tf.ConfigProto(device_count={'GPU': 0}))
    sess = tf.Session()
    print("Setting up Saver...")
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(DirSave, sess.graph)

    # Initialised so the return below never hits an unbound local when no
    # iteration runs (MAX_EPOCH == 0 or fewer samples than BatchSize).
    _CleanAcc = None
    _NoiseAcc = None
    try:
        sess.run(tf.global_variables_initializer())

        # Fixed monitoring subsets; taken before any shuffling.
        tr_image_batch1 = tr_image_data[0:SizeLimitation]
        tr_label_batch1 = tr_label_data[0:SizeLimitation]
        val_image_batch = val_image_data[0:SizeLimitation]
        val_label_batch = val_label_data[0:SizeLimitation]

        # Evaluation feeds are loop-invariant. They disable dropout
        # (keep_prob 1.0) so metrics are not computed on a randomly thinned
        # network. summary_op also evaluates CleanLoss/CleanAcc, hence
        # CleanLabel in the monitoring feed.
        tr_feed_dict1 = {
            image: tr_image_batch1,
            NoisyLabel: tr_label_batch1,
            CleanLabel: val_label_batch,
            keep_prob: np.float32(1.0),
        }
        val_feed_dict = {
            image: val_image_batch,
            CleanLabel: val_label_batch,
            keep_prob: np.float32(1.0),
        }

        # Integer division: a float here made range() raise on Python 3.
        num_batches = BATCHES // BatchSize
        total_iter = 0
        for _ in range(MAX_EPOCH):
            for batch in range(num_batches):
                start = batch * BatchSize
                tr_feed_dict = {
                    image: tr_image_data[start:start + BatchSize],
                    NoisyLabel: tr_label_data[start:start + BatchSize],
                    keep_prob: np.float32(KEEP_PROB),
                }

                if total_iter % 10 == 0:
                    _NoiseAcc, tr_variables, summary_str = sess.run(
                        [NoiseAcc, variables, summary_op],
                        feed_dict=tr_feed_dict1)
                    summary_writer.add_summary(summary_str, total_iter)
                    saver.save(sess, DirSave + "model.ckpt", total_iter)
                    np.save(DirSave + "weights", _to_object_array(tr_variables))

                if total_iter % 50 == 0:
                    _CleanAcc = sess.run(CleanAcc, feed_dict=val_feed_dict)

                sess.run(train_op, feed_dict=tr_feed_dict)
                total_iter += 1

            # Reshuffle the training set between epochs.
            new_index = random.sample(range(BATCHES), BATCHES)
            tr_image_data = tr_image_data[new_index]
            tr_label_data = tr_label_data[new_index]
    finally:
        summary_writer.close()
        sess.close()

    return _CleanAcc, _NoiseAcc