def main(_):
    """Build the VAE-style metric-learning graph and train it in two phases.

    Graph: a GoogleNet feature extractor feeds three fully connected heads
    (``fc1`` mean, ``fc2`` log(sigma^2), ``fc3`` invariant feature). A sample
    produced by ``samplingGaussian`` is added to the invariant feature and
    decoded back to a 1024-d vector by a two-layer decoder.

    Losses: L1 (KL term), L2 (reconstruction against the GoogleNet feature),
    L3 (triplet loss on the synthesized feature), L4 (triplet loss on the
    invariant feature plus weight decay) and a softmax cross-entropy.

    Training: both phases run the same three train ops; they differ only in
    the number of epochs, the value fed to ``is_Phase`` and the lambda
    weights. Every ``RECORD_STEP`` steps the test stream is evaluated,
    summaries are written, and a checkpoint is saved when test NMI improves.

    Args:
        _: Unused positional argument.
    """
    # ---- placeholders ----
    # -> raw input data
    x_raw = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])

    # -> parameters involved
    with tf.name_scope('isTraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('isPhase'):
        is_Phase = tf.placeholder(tf.bool)
    with tf.name_scope('learningRate'):
        lr = tf.placeholder(tf.float32)
    with tf.name_scope('lambdas'):
        lambda1 = tf.placeholder(tf.float32)
        lambda2 = tf.placeholder(tf.float32)
        lambda3 = tf.placeholder(tf.float32)
        lambda4 = tf.placeholder(tf.float32)

    # ---- feature extractor layers ----
    with tf.variable_scope('FeatureExtractor'):
        # modified GoogLeNet pre-trained on ILSVRC2012
        google_net_model = GoogleNet_Model.GoogleNet_Model()
        embedding_gn = google_net_model.forward(x_raw)

        # batch normalization of the average pooling layer of GoogLeNet
        embedding = nn_Ops.bn_block(embedding_gn, normal=True,
                                    is_Training=is_Training, name='BN')

        # 3 fully connected layers of size EMBEDDING_SIZE
        # mean of cluster
        embedding_mu = nn_Ops.fc_block(
            embedding, in_d=1024, out_d=EMBEDDING_SIZE, name='fc1',
            is_bn=False, is_relu=False, is_Training=is_Training)
        # log(sigma^2) of cluster
        embedding_sigma = nn_Ops.fc_block(
            embedding, in_d=1024, out_d=EMBEDDING_SIZE, name='fc2',
            is_bn=False, is_relu=False, is_Training=is_Training)
        # invariant feature of cluster
        embedding_zi = nn_Ops.fc_block(
            embedding, in_d=1024, out_d=EMBEDDING_SIZE, name='fc3',
            is_bn=False, is_relu=False, is_Training=is_Training)

        with tf.name_scope('Loss'):
            def exclude_batch_norm(name):
                return ('batch_normalization' not in name
                        and 'Generator' not in name
                        and 'Loss' not in name)

            # L2 weight decay over the trainable variables created so far
            wdLoss = 5e-3 * tf.add_n([
                tf.nn.l2_loss(v) for v in tf.trainable_variables()
                if exclude_batch_norm(v.name)
            ])
            # collapse the [None, 1] label column to a vector
            label = tf.reduce_mean(label_raw, axis=1)
            J_m = Losses.triplet_semihard_loss(label, embedding_zi) + wdLoss

    # ---- Generator ----
    with tf.variable_scope('Generator'):
        embedding_re = samplingGaussian(embedding_mu, embedding_sigma)
        embedding_zv = tf.reshape(embedding_re, (-1, EMBEDDING_SIZE))
        # Z = Zi + Zv
        embedding_z = tf.add(embedding_zi, embedding_zv,
                             name='Synthesized_features')

    # ---- Decoder ----
    with tf.variable_scope('Decoder'):
        embedding_y1 = nn_Ops.fc_block(
            embedding_z, in_d=EMBEDDING_SIZE, out_d=512, name='decoder1',
            is_bn=True, is_relu=True, is_Training=is_Phase)
        embedding_y2 = nn_Ops.fc_block(
            embedding_y1, in_d=512, out_d=1024, name='decoder2',
            is_bn=False, is_relu=False, is_Training=is_Phase)

    print("embedding_sigma", embedding_sigma)
    print("embedding_mu", embedding_mu)

    # ---- the 4 losses ----
    # L1 loss
    # Definition: L1 = (1/2) x sum( 1 + log(sigma^2) - mu^2 - sigma^2)
    with tf.name_scope('L1_KLDivergence'):
        kl_loss = (1 + embedding_sigma - K.square(embedding_mu)
                   - K.exp(embedding_sigma))
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -(0.5 / BATCH_SIZE)
        L1 = lambda1 * K.mean(kl_loss)

    # L2 loss
    # Definition: L2 = sum( L2Norm( target - outputOf(GoogleNet) ))
    with tf.name_scope('L2_Reconstruction'):
        # 1.0 (not 1) so the scale factor is not truncated to zero by
        # integer division when run under Python 2.
        L2 = lambda2 * (1.0 / (20 * BATCH_SIZE)) * tf.reduce_sum(
            tf.square(embedding_y2 - embedding_gn))

    # L3 loss
    # Definition: L3 = Lm( Z )
    with tf.name_scope('L3_Synthesized'):
        L3 = lambda3 * Losses.triplet_semihard_loss(labels=label,
                                                    embeddings=embedding_z)

    # L4 loss
    # Definition: L4 = Lm( Zi )
    with tf.name_scope('L4_Metric'):
        L4 = lambda4 * J_m

    # Classifier loss
    with tf.name_scope('Softmax_Loss'):
        cross_entropy, W_fc, b_fc = Losses.cross_entropy(
            embedding=embedding_gn, label=label)

    # ---- train ops, one per variable scope ----
    c_train_step = nn_Ops.training(loss=L4 + L3 + L1, lr=lr,
                                   var_scope='FeatureExtractor')
    g_train_step = nn_Ops.training(loss=L2, lr=LR_gen, var_scope='Decoder')
    s_train_step = nn_Ops.training(loss=cross_entropy, lr=LR_s,
                                   var_scope='Softmax_classifier')

    # print a summary of all trainable variables
    slim.model_analyzer.analyze_vars(tf.trainable_variables(),
                                     print_info=True)

    with tf.Session(config=config) as sess:
        summary_writer = tf.summary.FileWriter(LOGDIR, sess.graph)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        _lr = LR_init

        # Collectors that record the losses for TensorBoard
        Jm_loss = nn_Ops.data_collector(tag='Jm', init=1e+6)
        L1_loss = nn_Ops.data_collector(tag='KLDivergence', init=1e+6)
        L2_loss = nn_Ops.data_collector(tag='Reconstruction', init=1e+6)
        L3_loss = nn_Ops.data_collector(tag='Synthesized', init=1e+6)
        L4_loss = nn_Ops.data_collector(tag='Metric', init=1e+6)
        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy',
                                                   init=1e+6)
        wd_Loss = nn_Ops.data_collector(tag='weight_decay', init=1e+6)

        train_ops = [c_train_step, g_train_step, s_train_step]
        # loss tensors and the collector receiving each fetched value,
        # kept in the same order
        loss_tensors = [wdLoss, L1, L4, J_m, L3, L2, cross_entropy]
        loss_collectors = [wd_Loss, L1_loss, L4_loss, Jm_loss, L3_loss,
                           L2_loss, cross_entropy_loss]

        # Mutable state shared by both phases: the step counter and the best
        # NMI carry over from phase 1 into phase 2.
        progress = {'step': 0, 'max_nmi': 0}

        def _evaluate_and_record():
            """Evaluate on the test stream, write summaries, save on NMI gain."""
            step = progress['step']
            nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                stream_test, image_mean, sess, x_raw, label_raw,
                is_Training, is_Phase, embedding_zi, 98, neighbours)

            eval_summary = tf.Summary()
            eval_summary.value.add(tag='test nmi', simple_value=nmi_te)
            eval_summary.value.add(tag='test f1', simple_value=f1_te)
            for i, k in enumerate(neighbours):
                eval_summary.value.add(tag='Recall@%d test' % k,
                                       simple_value=recalls_te[i])
            Jm_loss.write_to_tfboard(eval_summary)
            eval_summary.value.add(tag='learning_rate', simple_value=_lr)
            L1_loss.write_to_tfboard(eval_summary)
            L2_loss.write_to_tfboard(eval_summary)
            L3_loss.write_to_tfboard(eval_summary)
            L4_loss.write_to_tfboard(eval_summary)
            wd_Loss.write_to_tfboard(eval_summary)
            cross_entropy_loss.write_to_tfboard(eval_summary)
            summary_writer.add_summary(eval_summary, step)
            print('Summary Recorder')
            if nmi_te > progress['max_nmi']:
                progress['max_nmi'] = nmi_te
                saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                print("Model Saved")
            summary_writer.flush()

        def _run_phase(num_epochs, phase_flag, lambdas):
            """Run one training phase.

            Args:
                num_epochs: number of passes over the copied epoch iterator.
                phase_flag: value fed to the ``is_Phase`` placeholder.
                lambdas: 4-tuple fed to lambda1..lambda4.
            """
            epoch_iterator = stream_train.get_epoch_iterator()
            for epoch in tqdm(range(num_epochs)):
                print("Epoch: ", epoch)
                for batch in tqdm(copy.copy(epoch_iterator), total=MAX_ITER):
                    progress['step'] += 1
                    # get images and labels from batch
                    x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                    # training step
                    fetched = sess.run(
                        train_ops + loss_tensors,
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True,
                                   is_Phase: phase_flag,
                                   lambda1: lambdas[0],
                                   lambda2: lambdas[1],
                                   lambda3: lambdas[2],
                                   lambda4: lambdas[3],
                                   lr: _lr})
                    loss_values = fetched[len(train_ops):]
                    for collector, value in zip(loss_collectors,
                                                loss_values):
                        collector.update(var=value)

                    # evaluation
                    # NOTE(review): nesting reconstructed from a flattened
                    # source; evaluation is assumed to happen per batch.
                    if progress['step'] % RECORD_STEP == 0:
                        _evaluate_and_record()

        print("Phase 1")
        _run_phase(NUM_EPOCHS_PHASE1, False, (1, 1, 0.1, 1))

        print("Phase 2")
        _run_phase(NUM_EPOCHS_PHASE2, True, (0.8, 1, 0.2, 0.8))
def main(_):
    """Build the N-pair graph and dump training-set embeddings to disk.

    This is an evaluation-only entry point: the full graph (including the
    train ops) is constructed, variables are initialised and optionally
    restored from a checkpoint, then ``evaluation.Evaluation_icon2`` is run
    on the training stream and every returned embedding is written with
    ``save_1Darray`` into ``<log_save_path>_embeddings/<index>``.

    Args:
        _: Unused positional argument.

    Returns:
        0 when ``FLAGS.LossType`` is not ``'NpairLoss'``; otherwise ``None``.
    """
    if not FLAGS.LossType == 'NpairLoss':
        print("LossType n-pair-loss is required")
        return 0

    # placeholders
    x_raw = tf.placeholder(
        tf.float32,
        shape=[None, FLAGS.default_image_size, FLAGS.default_image_size, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])
    with tf.name_scope('istraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('learning_rate'):
        lr = tf.placeholder(tf.float32)

    with tf.variable_scope('Classifier'):
        google_net_model = GoogleNet_Model.GoogleNet_Model()
        embedding = google_net_model.forward(x_raw)
        if FLAGS.Apply_HDML:
            embedding_y_origin = embedding

        # Batch Normalization layer 1
        embedding = nn_Ops.bn_block(embedding, normal=FLAGS.normalize,
                                    is_Training=is_Training, name='BN1')

        # FC layer 1
        embedding_z = nn_Ops.fc_block(
            embedding, in_d=1024, out_d=FLAGS.embedding_size, name='fc1',
            is_bn=False, is_relu=False, is_Training=is_Training)

        # Embedding Visualization
        assignment, embedding_var = Embedding_Visualization.embedding_assign(
            batch_size=256, embedding=embedding_z,
            embedding_size=FLAGS.embedding_size, name='Embedding_of_fc1')

        # conventional Loss function
        with tf.name_scope('Loss'):
            def exclude_batch_norm(name):
                return ('batch_normalization' not in name
                        and 'Generator' not in name
                        and 'Loss' not in name)

            wdLoss = FLAGS.Regular_factor * tf.add_n([
                tf.nn.l2_loss(v) for v in tf.trainable_variables()
                if exclude_batch_norm(v.name)
            ])
            # Get the Label
            label = tf.reduce_mean(label_raw, axis=1, keep_dims=False)
            J_m = Loss_ops.Loss(embedding_z, label, FLAGS.LossType) + wdLoss

    # if HNG is applied
    if FLAGS.Apply_HDML:
        with tf.name_scope('Javg'):
            Javg = tf.placeholder(tf.float32)
        with tf.name_scope('Jgen'):
            Jgen = tf.placeholder(tf.float32)

        embedding_z_quta = HDML.Pulling(FLAGS.LossType, embedding_z, Javg)
        embedding_z_concate = tf.concat([embedding_z, embedding_z_quta],
                                        axis=0)

        # Generator
        with tf.variable_scope('Generator'):
            # generator fc3
            embedding_y_concate = nn_Ops.fc_block(
                embedding_z_concate, in_d=FLAGS.embedding_size, out_d=512,
                name='generator1', is_bn=True, is_relu=True,
                is_Training=is_Training)
            # generator fc4
            embedding_y_concate = nn_Ops.fc_block(
                embedding_y_concate, in_d=512, out_d=1024,
                name='generator2', is_bn=False, is_relu=False,
                is_Training=is_Training)
            embedding_yp, embedding_yq = HDML.npairSplit(embedding_y_concate)

        # Re-embed the generated features with the shared BN1/fc1 weights
        with tf.variable_scope('Classifier'):
            embedding_z_quta = nn_Ops.bn_block(
                embedding_yq, normal=FLAGS.normalize,
                is_Training=is_Training, name='BN1', reuse=True)
            embedding_z_quta = nn_Ops.fc_block(
                embedding_z_quta, in_d=1024, out_d=FLAGS.embedding_size,
                name='fc1', is_bn=False, is_relu=False, reuse=True,
                is_Training=is_Training)

            half_batch = int(FLAGS.batch_size / 2)
            embedding_zq_anc = tf.slice(
                input_=embedding_z_quta, begin=[0, 0],
                size=[half_batch, int(FLAGS.embedding_size)])
            embedding_zq_negtile = tf.slice(
                input_=embedding_z_quta, begin=[half_batch, 0],
                size=[int(np.square(FLAGS.batch_size / 2)),
                      int(FLAGS.embedding_size)])

        # NOTE(review): kept outside the 'Classifier' variable scope so the
        # softmax variables stay addressable as 'Softmax_classifier' below;
        # nesting was reconstructed from a flattened source - confirm.
        with tf.name_scope('Loss'):
            J_syn = (1. - tf.exp(-FLAGS.beta / Jgen)) * Loss_ops.new_npair_loss(
                labels=label,
                embedding_anchor=embedding_zq_anc,
                embedding_positive=embedding_zq_negtile,
                equal_shape=False,
                reg_lambda=FLAGS.loss_l2_reg)
            J_m = (tf.exp(-FLAGS.beta / Jgen)) * J_m
            J_metric = J_m + J_syn

            cross_entropy, W_fc, b_fc = HDML.cross_entropy(
                embedding=embedding_y_origin, label=label)

            embedding_yq_anc = tf.slice(
                input_=embedding_yq, begin=[0, 0],
                size=[int(FLAGS.batch_size / 2), 1024])
            embedding_yq_negtile = tf.slice(
                input_=embedding_yq, begin=[int(FLAGS.batch_size / 2), 0],
                size=[int(np.square(FLAGS.batch_size / 2)), 1024])
            J_recon = (1 - FLAGS._lambda) * tf.reduce_sum(
                tf.square(embedding_yp - embedding_y_origin))
            J_soft = HDML.genSoftmax(embedding_anc=embedding_yq_anc,
                                     embedding_neg=embedding_yq_negtile,
                                     W_fc=W_fc, b_fc=b_fc, label=label)
            J_gen = J_recon + J_soft

    # The train ops are never run here. They are still built because they
    # may add variables to the graph that tf.train.Saver() below collects
    # (presumably to match the variable set of a training checkpoint -
    # TODO confirm).
    if FLAGS.Apply_HDML:
        c_train_step = nn_Ops.training(loss=J_metric, lr=lr,
                                       var_scope='Classifier')
        g_train_step = nn_Ops.training(loss=J_gen, lr=FLAGS.lr_gen,
                                       var_scope='Generator')
        s_train_step = nn_Ops.training(loss=cross_entropy, lr=FLAGS.s_lr,
                                       var_scope='Softmax_classifier')
    else:
        train_step = nn_Ops.training(loss=J_m, lr=lr)

    # initialise the session
    with tf.Session(config=config) as sess:
        # Initialise all the variables with the session
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # Restore a checkpoint
        if FLAGS.load_formalVal:
            saver.restore(
                sess, FLAGS.log_save_path + FLAGS.dataSet + '/' +
                FLAGS.LossType + '/' + FLAGS.formerTimer)

        # evaluation
        print('only eval!')
        embeddings, _ = evaluation.Evaluation_icon2(
            stream_train, image_mean, sess, x_raw, label_raw, is_Training,
            embedding_z, FLAGS.num_class_test, neighbours)

        # Strip trailing path separators only, instead of blindly dropping
        # the last character, so a log_save_path given without a trailing
        # '/' keeps its full name.
        out_dir = os.path.expanduser(
            FLAGS.log_save_path.rstrip('/' + os.sep) + '_embeddings')
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        # one output file per embedding, named by its index
        for idx, vector in enumerate(embeddings):
            save_1Darray(vector,
                         os.path.expanduser(os.path.join(out_dir, str(idx))))
        print('End')
def main(_):
    """Build and train the triplet-loss two-stage generator/discriminator model.

    Graph: a GoogleNet classifier produces ``embedding``;
    ``Two_Stage_Model.Pulling_Positivate`` derives ``embedding_l`` from it;
    ``Generator1`` maps that to ``embedding_g``; ``GeneratorS`` maps
    ``embedding_g`` to ``embedding_s``; ``Generator2`` maps ``embedding_g`` to
    ``embedding_h``. ``Discriminator1``/``DiscriminatorS`` score the sliced
    embeddings and ``Discriminator2`` scores the full ones.

    Training: for every batch the three discriminators are updated twice,
    then all eight train ops are run together while the tracked losses are
    fetched. Every ``bp_epoch`` steps the test stream is evaluated, summaries
    are written and a checkpoint is saved when the test mAP improves.

    Args:
        _: Unused positional argument.

    Returns:
        0 when ``FLAGS.LossType`` is not ``'Triplet'``. Otherwise the process
        is terminated with ``os._exit(0)`` once ``FLAGS.max_steps`` is reached.
    """
    if not FLAGS.LossType == 'Triplet':
        print("LossType triplet loss is required")
        return 0

    # placeholders
    x_raw = tf.placeholder(
        tf.float32,
        shape=[None, FLAGS.default_image_size, FLAGS.default_image_size, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])
    with tf.name_scope('istraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('learning_rate'):
        lr = tf.placeholder(tf.float32)

    with tf.variable_scope('Classifier'):
        google_net_model = GoogleNet_Model.GoogleNet_Model(
            pooling_type=FLAGS.pooling_type)
        embedding = google_net_model.forward(x_raw)
        embedding = nn_Ops.bn_block(
            embedding, normal=FLAGS.normalize, is_Training=is_Training,
            name='BN1')
        embedding = nn_Ops.fc_block(
            embedding, in_d=1024, out_d=FLAGS.embedding_size, name='fc1',
            is_bn=False, is_relu=False, is_Training=is_Training)

        with tf.name_scope('Loss'):
            def exclude_batch_norm(name):
                return ('batch_normalization' not in name
                        and 'Generator' not in name
                        and 'Loss' not in name)

            wdLoss = FLAGS.Regular_factor * tf.add_n(
                [tf.nn.l2_loss(v) for v in tf.trainable_variables()
                 if exclude_batch_norm(v.name)])
            # Get the Label
            label = tf.reduce_mean(label_raw, axis=1, keep_dims=False)
            J_m = Loss_ops.Loss(embedding, label, FLAGS.LossType)
            J_m = J_m + wdLoss

    with tf.name_scope('Jgen2'):
        Jgen2 = tf.placeholder(tf.float32)
    with tf.name_scope('Pull_dis_mean'):
        Pull_dis_mean = tf.placeholder(tf.float32)
    Pull_dis = nn_Ops.data_collector(tag='pulling_linear', init=25)

    embedding_l, disap = Two_Stage_Model.Pulling_Positivate(
        FLAGS.LossType, embedding, can=FLAGS.alpha)

    with tf.variable_scope('Generator1'):
        embedding_g = Two_Stage_Model.generator_ori(embedding_l,
                                                    FLAGS.LossType)
        with tf.name_scope('Loss'):
            def exclude_batch_norm1(name):
                return ('batch_normalization' not in name
                        and 'Generator1' in name
                        and 'Loss' not in name)

            # weight decay restricted to Generator1 variables
            wdLoss_g1 = FLAGS.Regular_factor * tf.add_n(
                [tf.nn.l2_loss(v) for v in tf.trainable_variables()
                 if exclude_batch_norm1(v.name)])

    # Generator1_S
    with tf.variable_scope('GeneratorS'):
        embedding_s = Two_Stage_Model.generator_ori(embedding_g,
                                                    FLAGS.LossType)
        with tf.name_scope('Loss'):
            def exclude_batch_norm2(name):
                return ('batch_normalization' not in name
                        and 'GeneratorS' in name
                        and 'Loss' not in name)

            # weight decay restricted to GeneratorS variables
            wdLoss_g1s = FLAGS.Regular_factor * tf.add_n(
                [tf.nn.l2_loss(v) for v in tf.trainable_variables()
                 if exclude_batch_norm2(v.name)])

    # Discriminator1 only contains the anchor and positive message
    with tf.variable_scope('Discriminator1') as scope:
        embedding_dis1 = Two_Stage_Model.discriminator1(
            Two_Stage_Model.slice_ap_n(embedding), 1)
        scope.reuse_variables()
        embedding_g_dis1 = Two_Stage_Model.discriminator1(
            Two_Stage_Model.slice_ap_n(embedding_g), 1)

    with tf.variable_scope('DiscriminatorS') as scope:
        embedding_dis1S = Two_Stage_Model.discriminator1(
            Two_Stage_Model.slice_ap_n(embedding), 1)
        scope.reuse_variables()
        embedding_s_dis1S = Two_Stage_Model.discriminator1(
            Two_Stage_Model.slice_ap_n(embedding_s), 1)

    with tf.variable_scope('Generator2'):
        embedding_h = Two_Stage_Model.generator_ori(embedding_g)

    with tf.variable_scope('Discriminator2') as scope:
        embedding_dis2 = Two_Stage_Model.discriminator2(embedding)
        scope.reuse_variables()
        embedding_g_dis2 = Two_Stage_Model.discriminator2(embedding_g)
        embedding_h_dis2 = Two_Stage_Model.discriminator2(embedding_h)

    embedding_h_cli = embedding_h

    # using binary_crossentropy replace sparse_softmax_cross_entropy_with_logits
    with tf.name_scope('Loss'):
        J_syn = Loss_ops.Loss(embedding_h_cli, label,
                              _lossType=FLAGS.LossType,
                              hard_ori=FLAGS.HARD_ORI)
        para1 = tf.exp(-FLAGS.beta / Jgen2)
        J_metric = para1 * J_m + (1. - para1) * J_syn

        # Floor division keeps the shape an int under Python 3 as well;
        # a float is not a valid tf.zeros/tf.ones dimension.
        ap_count = FLAGS.batch_size * 2 // 3
        real_loss_d1 = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.zeros([ap_count], dtype=tf.int32),
                logits=embedding_dis1))
        generated1_loss_d1 = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.ones([ap_count], dtype=tf.int32),
                logits=embedding_g_dis1))
        J_LD1 = FLAGS.Softmax_factor * -(tf.reduce_mean(
            tf.log(1. - tf.nn.sigmoid(embedding_dis1))
            + tf.log(tf.nn.sigmoid(embedding_g_dis1))))
        J_LG1 = FLAGS.Softmax_factor * (-tf.reduce_mean(
            tf.log(1. - tf.nn.sigmoid(embedding_g_dis1)))) + wdLoss_g1

        # distance between the first two thirds of embedding_g
        # (logged only; not part of any optimised loss)
        embedding_g_split = tf.split(embedding_g, 3, axis=0)
        embedding_g_split_anc = embedding_g_split[0]
        embedding_g_split_pos = embedding_g_split[1]
        dis_g1 = tf.reduce_mean(Two_Stage_Model.distance(
            embedding_g_split_anc, embedding_g_split_pos))
        dis_g1 = tf.maximum(-dis_g1 + 1000, 0.) * 10

        # add for D1S
        J_LD1S = FLAGS.Softmax_factor * (-tf.reduce_mean(
            tf.log(1. - tf.nn.sigmoid(embedding_dis1S))
            + tf.log(tf.nn.sigmoid(embedding_s_dis1S))))
        J_LG1_S_cross = FLAGS.Softmax_factor * (-tf.reduce_mean(
            tf.log(1. - tf.nn.sigmoid(embedding_s_dis1S))))
        recon_ori_s = FLAGS.Recon_factor * tf.reduce_mean(
            Two_Stage_Model.distance(embedding_s, embedding))
        J_LG1_S = J_LG1_S_cross + recon_ori_s + wdLoss_g1s

        real_loss_d2 = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=label, logits=embedding_dis2))
        generated2_loss_d2 = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=label, logits=embedding_g_dis2))
        # embedding_h samples are labelled with the extra class index
        # FLAGS.num_class
        generated2_h_loss_d2 = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=tf.zeros([FLAGS.batch_size], dtype=tf.int32)
                + FLAGS.num_class,
                logits=embedding_h_dis2))
        J_LD2 = FLAGS.Softmax_factor * (
            real_loss_d2 + generated2_loss_d2 + generated2_h_loss_d2) / 3
        J_LD2 = J_LD2 + J_syn

        cross_entropy, W_fc, b_fc = THSG.cross_entropy(embedding=embedding,
                                                       label=label)
        Logits_q = tf.matmul(embedding_h, W_fc) + b_fc
        J_LG2_cross_entropy = (
            FLAGS.Softmax_factor * FLAGS._lambda * tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=label, logits=Logits_q)))

        J_LG2C_cross_GAN = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=label, logits=embedding_h_dis2))
        J_LG2C_cross = FLAGS.Softmax_factor * (
            J_LG2C_cross_GAN + J_LG2_cross_entropy) / 2
        recon_g_h_ancpos = FLAGS.Recon_factor * tf.reduce_mean(
            Two_Stage_Model.distance(
                Two_Stage_Model.slice_ap_n(embedding_g),
                Two_Stage_Model.slice_ap_n(embedding_h)))
        J_fan = FLAGS.Softmax_factor * Loss_ops.Loss_fan(
            embedding_h, label, _lossType=FLAGS.LossType,
            param=2 - tf.exp(-FLAGS.beta / Jgen2),
            hard_ori=FLAGS.HARD_ORI)
        J_LG2 = J_LG2C_cross + recon_g_h_ancpos + J_fan

        J_F = J_metric

    # train ops, one per variable scope
    c_train_step = nn_Ops.training(loss=J_F, lr=lr, var_scope='Classifier')
    d1_train_step = nn_Ops.training(loss=J_LD1, lr=FLAGS.lr_dis,
                                    var_scope='Discriminator1')
    g1_train_step = nn_Ops.training(loss=J_LG1, lr=FLAGS.lr_gen,
                                    var_scope='Generator1')
    d1s_train_step = nn_Ops.training(loss=J_LD1S, lr=FLAGS.lr_dis,
                                     var_scope='DiscriminatorS')
    g1s_train_step = nn_Ops.training(loss=J_LG1_S, lr=FLAGS.lr_gen,
                                     var_scope='GeneratorS*Generator1')
    d2_train_step = nn_Ops.training(loss=J_LD2, lr=FLAGS.lr_dis,
                                    var_scope='Discriminator2')
    g2_train_step = nn_Ops.training(loss=J_LG2, lr=FLAGS.lr_gen,
                                    var_scope='Generator2')
    s_train_step = nn_Ops.training(loss=cross_entropy, lr=FLAGS.s_lr,
                                   var_scope='Softmax_classifier')

    # initialise the session
    with tf.Session(config=config) as sess:
        summary_writer = tf.summary.FileWriter(LOGDIR, sess.graph)
        # Initialise all the variables with the session
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # learning rate
        _lr = FLAGS.init_learning_rate

        # Restore a checkpoint
        if FLAGS.load_formalVal:
            saver.restore(
                sess, FLAGS.log_save_path + FLAGS.dataSet + '/' +
                FLAGS.LossType + '/' + FLAGS.formerTimer)

        # Training
        epoch_iterator = stream_train.get_epoch_iterator()

        # (summary tag, tensor) for every tracked loss. Each tensor is
        # fetched, recorded and written exactly once per step, so no tag
        # appears twice in one summary.
        tracked = [
            ('J_m', J_m),
            ('J_syn', J_syn),
            ('J_metric', J_metric),
            ('real_loss_d1', real_loss_d1),
            ('generated1_loss_d1', generated1_loss_d1),
            ('L_D1S', J_LD1S),
            ('J_LG1', J_LG1),
            ('J_LG1_S_cross', J_LG1_S_cross),
            ('recon_ori_s', recon_ori_s),
            ('real_loss_d2', real_loss_d2),
            ('generated2_loss_d2', generated2_loss_d2),
            ('generated2_h_loss_d2', generated2_h_loss_d2),
            ('J_LG2C_cross', J_LG2C_cross),
            ('recon_g_h_ancpos', recon_g_h_ancpos),
            ('J_LG2', J_LG2),
            ('J_fan', J_fan),
            ('J_LD1', J_LD1),
            ('J_LD2', J_LD2),
            ('J_F', J_F),
            ('cross_entropy', cross_entropy),
            ('dis_g1', dis_g1),
        ]
        tracked_tags = [tag for tag, _ in tracked]
        tracked_tensors = [tensor for _, tensor in tracked]
        collectors = [nn_Ops.data_collector(tag=tag, init=1e+6)
                      for tag in tracked_tags]
        # the collected J_LG2 value is fed back into the graph through Jgen2
        J_LG2_loss = collectors[tracked_tags.index('J_LG2')]

        discriminator_ops = [d1_train_step, d1s_train_step, d2_train_step]
        all_train_ops = [c_train_step, s_train_step, d1_train_step,
                         g1_train_step, d1s_train_step, g1s_train_step,
                         d2_train_step, g2_train_step]

        best_map = 0
        step = 0
        bp_epoch = FLAGS.init_batch_per_epoch
        with tqdm(total=FLAGS.max_steps) as pbar:
            for batch in copy.copy(epoch_iterator):
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                pbar.update(1)

                # two discriminator-only updates before the joint update
                for _ in range(2):
                    results = sess.run(
                        discriminator_ops + [disap],
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True,
                                   lr: _lr,
                                   Pull_dis_mean: Pull_dis.read()})
                    disap_var = results[-1]
                    Pull_dis.update(var=disap_var.mean() * 0.8)

                # joint update of every sub-network plus loss fetches
                results = sess.run(
                    all_train_ops + tracked_tensors + [disap],
                    feed_dict={x_raw: x_batch_data,
                               label_raw: Label_raw,
                               is_Training: True,
                               lr: _lr,
                               Jgen2: J_LG2_loss.read(),
                               Pull_dis_mean: Pull_dis.read()})
                disap_var = results[-1]
                loss_values = results[len(all_train_ops):-1]
                Pull_dis.update(var=disap_var.mean() * 0.8)
                for collector, value in zip(collectors, loss_values):
                    collector.update(var=value)

                step += 1

                # evaluation
                if step % bp_epoch == 0:
                    print('only eval eval')
                    recalls_te_cli, map_cli = evaluation.Evaluation(
                        stream_test, image_mean, sess, x_raw, label_raw,
                        is_Training, embedding, 98, neighbours)

                    # Summary
                    eval_summary = tf.Summary()
                    eval_summary.value.add(tag='test map',
                                           simple_value=map_cli)
                    for i, k in enumerate(neighbours):
                        eval_summary.value.add(
                            tag='Recall@%d test' % k,
                            simple_value=recalls_te_cli[i])
                    for collector in collectors:
                        collector.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary written')
                    if map_cli > best_map:
                        best_map = map_cli
                        print("Saved")
                        saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                    summary_writer.flush()

                    # NOTE(review): nesting reconstructed from a flattened
                    # source; the schedule below is assumed to be checked
                    # only on evaluation steps - confirm.
                    if step in [5632, 6848]:
                        _lr = _lr * 0.5
                    if step >= 5000:
                        bp_epoch = FLAGS.batch_per_epoch
                    if step >= FLAGS.max_steps:
                        os._exit(0)
def main(_):
    """Build and train the N-pair-loss model, optionally with the HDML branch.

    Graph: a GoogleNet classifier followed by ``BN1`` and ``fc1`` yields
    ``embedding_z``. When ``FLAGS.Apply_HDML`` is set, ``HDML.Pulling``
    produces extra embeddings that are passed through a two-layer generator,
    re-embedded with the shared ``BN1``/``fc1`` weights, and used for the
    synthetic (``J_syn``), reconstruction (``J_recon``) and softmax
    (``J_soft``) losses.

    Training: one batch per step. Every ``bp_epoch`` steps the test stream
    is evaluated, summaries are written and checkpoints are saved.

    Args:
        _: Unused positional argument.

    Returns:
        0 when ``FLAGS.LossType`` is not ``'NpairLoss'``. Otherwise the
        process is terminated with ``os._exit(0)`` once ``FLAGS.max_steps``
        is reached.
    """
    if not FLAGS.LossType == 'NpairLoss':
        print("LossType n-pair-loss is required")
        return 0

    # placeholders
    x_raw = tf.placeholder(
        tf.float32,
        shape=[None, FLAGS.default_image_size, FLAGS.default_image_size, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])
    with tf.name_scope('istraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('learning_rate'):
        lr = tf.placeholder(tf.float32)

    with tf.variable_scope('Classifier'):
        google_net_model = GoogleNet_Model.GoogleNet_Model()
        embedding = google_net_model.forward(x_raw)
        if FLAGS.Apply_HDML:
            embedding_y_origin = embedding

        # Batch Normalization layer 1
        embedding = nn_Ops.bn_block(embedding, normal=FLAGS.normalize,
                                    is_Training=is_Training, name='BN1')

        # FC layer 1
        embedding_z = nn_Ops.fc_block(
            embedding, in_d=1024, out_d=FLAGS.embedding_size, name='fc1',
            is_bn=False, is_relu=False, is_Training=is_Training)

        # Embedding Visualization
        assignment, embedding_var = Embedding_Visualization.embedding_assign(
            batch_size=256, embedding=embedding_z,
            embedding_size=FLAGS.embedding_size, name='Embedding_of_fc1')

        # conventional Loss function
        with tf.name_scope('Loss'):
            def exclude_batch_norm(name):
                return ('batch_normalization' not in name
                        and 'Generator' not in name
                        and 'Loss' not in name)

            wdLoss = FLAGS.Regular_factor * tf.add_n([
                tf.nn.l2_loss(v) for v in tf.trainable_variables()
                if exclude_batch_norm(v.name)
            ])
            # Get the Label
            label = tf.reduce_mean(label_raw, axis=1, keep_dims=False)
            J_m = Loss_ops.Loss(embedding_z, label, FLAGS.LossType) + wdLoss

    # if HNG is applied
    if FLAGS.Apply_HDML:
        with tf.name_scope('Javg'):
            Javg = tf.placeholder(tf.float32)
        with tf.name_scope('Jgen'):
            Jgen = tf.placeholder(tf.float32)

        embedding_z_quta = HDML.Pulling(FLAGS.LossType, embedding_z, Javg)
        embedding_z_concate = tf.concat([embedding_z, embedding_z_quta],
                                        axis=0)

        # Generator
        with tf.variable_scope('Generator'):
            # generator fc3
            embedding_y_concate = nn_Ops.fc_block(
                embedding_z_concate, in_d=FLAGS.embedding_size, out_d=512,
                name='generator1', is_bn=True, is_relu=True,
                is_Training=is_Training)
            # generator fc4
            embedding_y_concate = nn_Ops.fc_block(
                embedding_y_concate, in_d=512, out_d=1024,
                name='generator2', is_bn=False, is_relu=False,
                is_Training=is_Training)
            embedding_yp, embedding_yq = HDML.npairSplit(embedding_y_concate)

        # Re-embed the generated features with the shared BN1/fc1 weights
        with tf.variable_scope('Classifier'):
            embedding_z_quta = nn_Ops.bn_block(
                embedding_yq, normal=FLAGS.normalize,
                is_Training=is_Training, name='BN1', reuse=True)
            embedding_z_quta = nn_Ops.fc_block(
                embedding_z_quta, in_d=1024, out_d=FLAGS.embedding_size,
                name='fc1', is_bn=False, is_relu=False, reuse=True,
                is_Training=is_Training)

            half_batch = int(FLAGS.batch_size / 2)
            embedding_zq_anc = tf.slice(
                input_=embedding_z_quta, begin=[0, 0],
                size=[half_batch, int(FLAGS.embedding_size)])
            embedding_zq_negtile = tf.slice(
                input_=embedding_z_quta, begin=[half_batch, 0],
                size=[int(np.square(FLAGS.batch_size / 2)),
                      int(FLAGS.embedding_size)])

        # NOTE(review): kept outside the 'Classifier' variable scope so the
        # softmax variables stay addressable as 'Softmax_classifier' below;
        # nesting was reconstructed from a flattened source - confirm.
        with tf.name_scope('Loss'):
            J_syn = (1. - tf.exp(-FLAGS.beta / Jgen)) * Loss_ops.new_npair_loss(
                labels=label,
                embedding_anchor=embedding_zq_anc,
                embedding_positive=embedding_zq_negtile,
                equal_shape=False,
                reg_lambda=FLAGS.loss_l2_reg)
            J_m = (tf.exp(-FLAGS.beta / Jgen)) * J_m
            J_metric = J_m + J_syn

            cross_entropy, W_fc, b_fc = HDML.cross_entropy(
                embedding=embedding_y_origin, label=label)

            embedding_yq_anc = tf.slice(
                input_=embedding_yq, begin=[0, 0],
                size=[int(FLAGS.batch_size / 2), 1024])
            embedding_yq_negtile = tf.slice(
                input_=embedding_yq, begin=[int(FLAGS.batch_size / 2), 0],
                size=[int(np.square(FLAGS.batch_size / 2)), 1024])
            J_recon = (1 - FLAGS._lambda) * tf.reduce_sum(
                tf.square(embedding_yp - embedding_y_origin))
            J_soft = HDML.genSoftmax(embedding_anc=embedding_yq_anc,
                                     embedding_neg=embedding_yq_negtile,
                                     W_fc=W_fc, b_fc=b_fc, label=label)
            J_gen = J_recon + J_soft

    if FLAGS.Apply_HDML:
        c_train_step = nn_Ops.training(loss=J_metric, lr=lr,
                                       var_scope='Classifier')
        g_train_step = nn_Ops.training(loss=J_gen, lr=FLAGS.lr_gen,
                                       var_scope='Generator')
        s_train_step = nn_Ops.training(loss=cross_entropy, lr=FLAGS.s_lr,
                                       var_scope='Softmax_classifier')
    else:
        train_step = nn_Ops.training(loss=J_m, lr=lr)

    # initialise the session
    with tf.Session(config=config) as sess:
        # Initialise all the variables with the session
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # learning rate
        _lr = FLAGS.init_learning_rate

        # Restore a checkpoint
        if FLAGS.load_formalVal:
            saver.restore(
                sess, FLAGS.log_save_path + FLAGS.dataSet + '/' +
                FLAGS.LossType + '/' + FLAGS.formerTimer)

        # Training
        epoch_iterator = stream_train.get_epoch_iterator()

        # collectors
        J_m_loss = nn_Ops.data_collector(tag='Jm', init=1e+6)
        J_syn_loss = nn_Ops.data_collector(tag='J_syn', init=1e+6)
        J_metric_loss = nn_Ops.data_collector(tag='J_metric', init=1e+6)
        J_soft_loss = nn_Ops.data_collector(tag='J_soft', init=1e+6)
        J_recon_loss = nn_Ops.data_collector(tag='J_recon', init=1e+6)
        J_gen_loss = nn_Ops.data_collector(tag='J_gen', init=1e+6)
        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy',
                                                   init=1e+6)
        wd_Loss = nn_Ops.data_collector(tag='weight_decay', init=1e+6)

        max_nmi = 0
        step = 0
        bp_epoch = FLAGS.init_batch_per_epoch
        with tqdm(total=FLAGS.max_steps) as pbar:
            for batch in copy.copy(epoch_iterator):
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                pbar.update(1)

                if not FLAGS.Apply_HDML:
                    train, J_m_var, wd_Loss_var = sess.run(
                        [train_step, J_m, wdLoss],
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True,
                                   lr: _lr})
                    J_m_loss.update(var=J_m_var)
                    wd_Loss.update(var=wd_Loss_var)
                else:
                    # The previously collected J_m and J_gen values are fed
                    # back into the graph through Javg and Jgen.
                    (c_train, g_train, s_train, wd_Loss_var, J_metric_var,
                     J_m_var, J_syn_var, J_recon_var, J_soft_var, J_gen_var,
                     cross_en_var) = sess.run(
                        [c_train_step, g_train_step, s_train_step, wdLoss,
                         J_metric, J_m, J_syn, J_recon, J_soft, J_gen,
                         cross_entropy],
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True,
                                   lr: _lr,
                                   Javg: J_m_loss.read(),
                                   Jgen: J_gen_loss.read()})
                    wd_Loss.update(var=wd_Loss_var)
                    J_metric_loss.update(var=J_metric_var)
                    J_m_loss.update(var=J_m_var)
                    J_syn_loss.update(var=J_syn_var)
                    J_recon_loss.update(var=J_recon_var)
                    J_soft_loss.update(var=J_soft_var)
                    J_gen_loss.update(var=J_gen_var)
                    cross_entropy_loss.update(cross_en_var)

                step += 1

                # evaluation
                if step % bp_epoch == 0:
                    print('only eval eval')
                    nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                        stream_test, image_mean, sess, x_raw, label_raw,
                        is_Training, embedding_z, FLAGS.num_class_test,
                        neighbours)

                    # Summary
                    eval_summary = tf.Summary()
                    eval_summary.value.add(tag='test nmi',
                                           simple_value=nmi_te)
                    eval_summary.value.add(tag='test f1', simple_value=f1_te)
                    for i, k in enumerate(neighbours):
                        eval_summary.value.add(tag='Recall@%d test' % k,
                                               simple_value=recalls_te[i])
                    J_m_loss.write_to_tfboard(eval_summary)
                    wd_Loss.write_to_tfboard(eval_summary)
                    eval_summary.value.add(tag='learning_rate',
                                           simple_value=_lr)
                    if FLAGS.Apply_HDML:
                        J_syn_loss.write_to_tfboard(eval_summary)
                        J_metric_loss.write_to_tfboard(eval_summary)
                        J_soft_loss.write_to_tfboard(eval_summary)
                        J_recon_loss.write_to_tfboard(eval_summary)
                        J_gen_loss.write_to_tfboard(eval_summary)
                        cross_entropy_loss.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary written')

                    # NOTE(review): nesting reconstructed from a flattened
                    # source. "modelBest" is assumed to be saved only on NMI
                    # improvement and "model" on every evaluation - confirm.
                    if nmi_te > max_nmi:
                        max_nmi = nmi_te
                        print("Saved")
                        saver.save(sess,
                                   os.path.join(LOGDIR, "modelBest.ckpt"))
                    saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                    summary_writer.flush()

                    if step >= 5000:
                        bp_epoch = FLAGS.batch_per_epoch
                    if step >= FLAGS.max_steps:
                        # os._exit requires an explicit status code; calling
                        # it with no argument raises TypeError.
                        os._exit(0)
def main(_):
    """Build the training graph and run the two-phase training schedule.

    Graph: a GoogLeNet forward pass feeds three batch-norm + fully connected
    heads that produce the metric embedding (``embedding_z``), a mean
    (``embedding_mu``) and a log-variance (``embedding_sigma``).  A sample
    drawn by ``mysampling`` is added to the embedding and pushed through a
    two-layer decoder.  Four lambda-weighted losses (KL, reconstruction,
    synthetic triplet, metric triplet) plus a softmax cross-entropy are
    minimised by three training ops.

    Training: two consecutive phases, each iterating over a copy of the
    training epoch iterator.  The phases differ only in the value fed to
    ``is_Phase``, the four lambda weights and the step count at which they
    stop.  Every ``bp_epoch`` steps the test stream is evaluated, a summary is
    written and the model is checkpointed whenever test NMI improves.  When
    the second phase reaches ``2 * MAX_ITER`` steps the process is terminated
    with ``os._exit(0)``.

    NOTE(review): the original source had lost its line structure; the scope
    nesting below (``Loss`` inside ``Feature_extractor``, sampling at function
    level) and the position of the stop checks (inside the evaluation block)
    are reconstructed -- confirm against the intended layout.

    Args:
        _: Unused.
    """
    # ---- Placeholders -----------------------------------------------------
    x_raw = tf.placeholder(tf.float32,
                           shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])
    with tf.name_scope('istraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('isphase'):
        is_Phase = tf.placeholder(tf.bool)
    with tf.name_scope('learning_rate'):
        lr = tf.placeholder(tf.float32)
    with tf.name_scope('lambdas'):
        lambda1 = tf.placeholder(tf.float32)
        lambda2 = tf.placeholder(tf.float32)
        lambda3 = tf.placeholder(tf.float32)
        lambda4 = tf.placeholder(tf.float32)

    # ---- Feature extractor and its three heads ----------------------------
    with tf.variable_scope('Feature_extractor'):
        google_net_model = GoogleNet_Model.GoogleNet_Model()
        embedding = google_net_model.forward(x_raw)
        # Raw backbone output: reconstruction target and classifier input.
        embedding_y_origin = embedding

        # Metric embedding head.
        embedding = nn_Ops.bn_block(embedding, normal=True,
                                    is_Training=is_Training, name='FC3')
        embedding_z = nn_Ops.fc_block(embedding, in_d=1024,
                                      out_d=EMBEDDING_SIZE, name='fc1',
                                      is_bn=False, is_relu=False,
                                      is_Training=is_Training)

        # Head predicting mu (takes the already-normalised `embedding`).
        embedding1 = nn_Ops.bn_block(embedding, normal=True,
                                     is_Training=is_Training, name='FC1')
        embedding_mu = nn_Ops.fc_block(embedding1, in_d=1024,
                                       out_d=EMBEDDING_SIZE, name='fc2',
                                       is_bn=False, is_relu=False,
                                       is_Training=is_Training)

        # Head predicting log(sigma^2).
        embedding2 = nn_Ops.bn_block(embedding, normal=True,
                                     is_Training=is_Training, name='FC2')
        embedding_sigma = nn_Ops.fc_block(embedding2, in_d=1024,
                                          out_d=EMBEDDING_SIZE, name='fc3',
                                          is_bn=False, is_relu=False,
                                          is_Training=is_Training)

        with tf.name_scope('Loss'):
            def exclude_batch_norm(name):
                return ('batch_normalization' not in name
                        and 'Generator' not in name
                        and 'Loss' not in name)

            # L2 weight decay over the trainable variables that pass the
            # name filter above.
            wdLoss = 5e-3 * tf.add_n([
                tf.nn.l2_loss(v) for v in tf.trainable_variables()
                if exclude_batch_norm(v.name)
            ])
            # Collapse the [None, 1] label column to a vector.
            label = tf.reduce_mean(label_raw, axis=1)
            J_m = Losses.triplet_semihard_loss(label, embedding_z) + wdLoss

    # ---- Sampled offset added to the embedding ----------------------------
    zv_emb = mysampling(embedding_mu, embedding_sigma)
    # Reshape to the embedding width so it can be added to `embedding_z`
    # (was a hard-coded 128, which only works when EMBEDDING_SIZE == 128).
    zv_emb1 = tf.reshape(zv_emb, (-1, EMBEDDING_SIZE))
    embedding_z_add = tf.add(embedding_z, zv_emb1,
                             name='Synthesized_features')

    # ---- Decoder: synthesized embedding -> 1024-d feature -----------------
    with tf.variable_scope('Decoder'):
        embedding_y_add = nn_Ops.fc_block(embedding_z_add,
                                          in_d=EMBEDDING_SIZE, out_d=512,
                                          name='decoder1', is_bn=True,
                                          is_relu=True, is_Training=is_Phase)
        embedding_y_add = nn_Ops.fc_block(embedding_y_add, in_d=512,
                                          out_d=1024, name='decoder2',
                                          is_bn=False, is_relu=False,
                                          is_Training=is_Phase)

    print("embedding_sigma", embedding_sigma)
    print("embedding_mu", embedding_mu)

    # ---- Losses ------------------------------------------------------------
    with tf.name_scope('Loss_KL'):
        # L1: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), batch-averaged.
        kl_loss = 1 + embedding_sigma - K.square(embedding_mu) - K.exp(
            embedding_sigma)
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -0.5
        J_KL = lambda1 * K.mean(kl_loss)

    with tf.name_scope('Loss_Recon'):
        # L2: squared error between decoder output and backbone output.
        J_recon = lambda2 * (0.5) * tf.reduce_sum(
            tf.square(embedding_y_add - embedding_y_origin))

    with tf.name_scope('Loss_synthetic'):
        # L3: triplet loss on the synthesized embedding.
        J_syn = lambda3 * Losses.triplet_semihard_loss(
            labels=label, embeddings=embedding_z_add)

    with tf.name_scope('Loss_metric'):
        # L4: triplet loss on the plain embedding (includes weight decay).
        J_metric = lambda4 * J_m

    with tf.name_scope('Loss_Softmax_classifier'):
        cross_entropy, W_fc, b_fc = Losses.cross_entropy(
            embedding=embedding_y_origin, label=label)

    # ---- Training ops (one per variable scope name) -----------------------
    c_train_step = nn_Ops.training(loss=J_metric + J_syn + J_KL, lr=lr,
                                   var_scope='Feature_extractor')
    g_train_step = nn_Ops.training(loss=J_recon, lr=LR_gen,
                                   var_scope='Decoder')
    s_train_step = nn_Ops.training(loss=cross_entropy, lr=LR_s,
                                   var_scope='Softmax_classifier')

    # Print a table of all trainable variables.
    slim.model_analyzer.analyze_vars(tf.trainable_variables(),
                                     print_info=True)

    with tf.Session(config=config) as sess:
        summary_writer = tf.summary.FileWriter(LOGDIR, sess.graph)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        _lr = LR_init
        epoch_iterator = stream_train.get_epoch_iterator()

        # Running collectors for every tracked loss.
        J_KL_loss = nn_Ops.data_collector(tag='JKL', init=1e+6)
        J_m_loss = nn_Ops.data_collector(tag='Jm', init=1e+6)
        J_syn_loss = nn_Ops.data_collector(tag='J_syn', init=1e+6)
        J_metric_loss = nn_Ops.data_collector(tag='J_metric', init=1e+6)
        J_recon_loss = nn_Ops.data_collector(tag='J_recon', init=1e+6)
        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy',
                                                   init=1e+6)
        wd_Loss = nn_Ops.data_collector(tag='weight_decay', init=1e+6)

        # (tensor, collector) pairs, in the order they are fetched/updated.
        tracked = [
            (wdLoss, wd_Loss),
            (J_KL, J_KL_loss),
            (J_metric, J_metric_loss),
            (J_m, J_m_loss),
            (J_syn, J_syn_loss),
            (J_recon, J_recon_loss),
            (cross_entropy, cross_entropy_loss),
        ]
        train_ops = [c_train_step, g_train_step, s_train_step]
        fetches = train_ops + [tensor for tensor, _ in tracked]

        def _run_phase(phase_feed, stop_step, step, max_nmi):
            """Run one training phase; return (step, max_nmi, stopped).

            ``phase_feed`` holds the per-phase placeholder values (is_Phase
            and the four lambdas).  ``stopped`` is True when the phase ended
            because ``step`` reached ``stop_step`` rather than because the
            batch iterator ran out.
            """
            stopped = False
            with tqdm(total=MAX_ITER) as pbar:
                for batch in copy.copy(epoch_iterator):
                    # get images and labels from batch
                    x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                    pbar.update(1)

                    feed = {x_raw: x_batch_data, label_raw: Label_raw,
                            is_Training: True, lr: _lr}
                    feed.update(phase_feed)
                    values = sess.run(fetches, feed_dict=feed)
                    # Skip the (empty) results of the three training ops.
                    for (_, collector), value in zip(
                            tracked, values[len(train_ops):]):
                        collector.update(var=value)
                    step += 1

                    # NOTE(review): bp_epoch is not assigned anywhere in this
                    # function; presumably a module-level constant -- confirm.
                    if step % bp_epoch != 0:
                        continue

                    # ---- periodic evaluation on the test stream ----------
                    print('only eval eval')
                    # 98 is presumably the number of test classes -- confirm.
                    nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                        stream_test, image_mean, sess, x_raw, label_raw,
                        is_Training, is_Phase, embedding_z, 98, neighbours)

                    eval_summary = tf.Summary()
                    eval_summary.value.add(tag='test nmi',
                                           simple_value=nmi_te)
                    eval_summary.value.add(tag='test f1', simple_value=f1_te)
                    for i, k in enumerate(neighbours):
                        eval_summary.value.add(tag='Recall@%d test' % k,
                                               simple_value=recalls_te[i])
                    J_m_loss.write_to_tfboard(eval_summary)
                    wd_Loss.write_to_tfboard(eval_summary)
                    J_KL_loss.write_to_tfboard(eval_summary)
                    eval_summary.value.add(tag='learning_rate',
                                           simple_value=_lr)
                    J_syn_loss.write_to_tfboard(eval_summary)
                    J_metric_loss.write_to_tfboard(eval_summary)
                    J_recon_loss.write_to_tfboard(eval_summary)
                    cross_entropy_loss.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary written')

                    # Checkpoint only when test NMI improves.
                    if nmi_te > max_nmi:
                        max_nmi = nmi_te
                        print("Saved")
                        saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                    summary_writer.flush()

                    if step >= stop_step:
                        stopped = True
                        break
            return step, max_nmi, stopped

        step = 0
        max_nmi = 0

        # Phase 1: decoder batch norm fed is_Phase=False.
        step, max_nmi, _ = _run_phase(
            {is_Phase: False, lambda1: 1, lambda2: 0.5,
             lambda3: 0.5, lambda4: 1},
            MAX_ITER, step, max_nmi)

        # Phase 2: is_Phase=True with a different loss weighting.
        step, max_nmi, finished = _run_phase(
            {is_Phase: True, lambda1: 0.8, lambda2: 1,
             lambda3: 0.4, lambda4: 0.8},
            2 * MAX_ITER, step, max_nmi)

        if finished:
            # Hard exit, as in the original; status 0 is used instead of
            # os.EX_OK because that constant only exists on Unix.
            os._exit(0)
def configTfSession(streams, summary_writer, train_steps, losses):
    """Open a session and run the training / periodic-evaluation loop.

    Args:
        streams: 3-sequence ``(stream_train, stream_train_eval,
            stream_test)``.  Only the training and test streams are used.
        summary_writer: Writer that receives one evaluation summary per
            evaluation step.
        train_steps: When ``FLAGS.Apply_HDML`` is true, a 3-sequence
            ``(c_train_step, g_train_step, s_train_step)``; otherwise a
            single training op.
        losses: 10-sequence ``(wdLoss, J_m, Javg, Jgen, J_metric, J_gen,
            J_syn, cross_entropy, J_recon, J_soft)``.  ``Javg`` and ``Jgen``
            are fed the running values of the ``Jm`` and ``J_gen``
            collectors on every step.

    The function does not return normally once ``FLAGS.max_steps`` is
    reached: it terminates the process with ``os._exit(0)``.

    NOTE(review): ``x_raw``, ``label_raw``, ``is_Training``, ``lr`` and
    ``embedding_z`` are neither parameters nor locals here, so they must be
    resolvable at module level -- confirm.
    NOTE(review): the original source had lost its line structure; the
    learning-rate decay, ``bp_epoch`` switch and exit check are placed inside
    the evaluation block -- confirm against the intended layout.
    """
    stream_train, _, stream_test = streams
    (wdLoss, J_m, Javg, Jgen, J_metric, J_gen, J_syn, cross_entropy,
     J_recon, J_soft) = losses

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # learning rate
        _lr = FLAGS.init_learning_rate

        # Restore a checkpoint
        if FLAGS.load_formalVal:
            saver.restore(
                sess, FLAGS.log_save_path + FLAGS.dataSet + '/' +
                FLAGS.LossType + '/' + FLAGS.formerTimer)

        # Training
        epoch_iterator = stream_train.get_epoch_iterator()

        # collectors
        J_m_loss = nn_Ops.data_collector(tag='Jm', init=1e+6)
        J_syn_loss = nn_Ops.data_collector(tag='J_syn', init=1e+6)
        J_metric_loss = nn_Ops.data_collector(tag='J_metric', init=1e+6)
        J_soft_loss = nn_Ops.data_collector(tag='J_soft', init=1e+6)
        J_recon_loss = nn_Ops.data_collector(tag='J_recon', init=1e+6)
        J_gen_loss = nn_Ops.data_collector(tag='J_gen', init=1e+6)
        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy',
                                                   init=1e+6)
        wd_Loss = nn_Ops.data_collector(tag='weight_decay', init=1e+6)

        # The ops to run and the (tensor, collector) pairs to track are the
        # same on every step, so build them once instead of once per batch.
        apply_hdml = FLAGS.Apply_HDML
        if apply_hdml:
            c_train_step, g_train_step, s_train_step = train_steps
            train_ops = [c_train_step, g_train_step, s_train_step]
            tracked = [
                (wdLoss, wd_Loss),
                (J_metric, J_metric_loss),
                (J_m, J_m_loss),
                (J_syn, J_syn_loss),
                (J_recon, J_recon_loss),
                (J_soft, J_soft_loss),
                (J_gen, J_gen_loss),
                (cross_entropy, cross_entropy_loss),
            ]
        else:
            train_ops = [train_steps]
            tracked = [(J_m, J_m_loss), (wdLoss, wd_Loss)]
        fetches = train_ops + [tensor for tensor, _ in tracked]

        # Class count passed to the evaluator: use the flag the other
        # training loop in this file uses, falling back to the previously
        # hard-coded 98 when that flag is not defined.
        num_class_test = getattr(FLAGS, 'num_class_test', 98)
        # Steps at which the learning rate is halved.
        lr_decay_steps = (5632, 6848)

        max_nmi = 0
        step = 0
        bp_epoch = FLAGS.init_batch_per_epoch
        with tqdm(total=FLAGS.max_steps) as pbar:
            for batch in copy.copy(epoch_iterator):
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                pbar.update(1)

                feed = {x_raw: x_batch_data, label_raw: Label_raw,
                        is_Training: True, lr: _lr}
                if apply_hdml:
                    # Running loss values, read before this step's update.
                    feed[Javg] = J_m_loss.read()
                    feed[Jgen] = J_gen_loss.read()

                values = sess.run(fetches, feed_dict=feed)
                # Skip the (empty) results of the training ops.
                for (_, collector), value in zip(
                        tracked, values[len(train_ops):]):
                    collector.update(var=value)
                step += 1

                if step % bp_epoch != 0:
                    continue

                # ---- periodic evaluation on the test stream --------------
                print('only eval eval')
                nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                    stream_test, image_mean, sess, x_raw, label_raw,
                    is_Training, embedding_z, num_class_test, neighbours)

                # Summary
                eval_summary = tf.Summary()
                eval_summary.value.add(tag='test nmi', simple_value=nmi_te)
                eval_summary.value.add(tag='test f1', simple_value=f1_te)
                for i, k in enumerate(neighbours):
                    eval_summary.value.add(tag='Recall@%d test' % k,
                                           simple_value=recalls_te[i])
                J_m_loss.write_to_tfboard(eval_summary)
                wd_Loss.write_to_tfboard(eval_summary)
                eval_summary.value.add(tag='learning_rate',
                                       simple_value=_lr)
                if apply_hdml:
                    J_syn_loss.write_to_tfboard(eval_summary)
                    J_metric_loss.write_to_tfboard(eval_summary)
                    J_soft_loss.write_to_tfboard(eval_summary)
                    J_recon_loss.write_to_tfboard(eval_summary)
                    J_gen_loss.write_to_tfboard(eval_summary)
                    cross_entropy_loss.write_to_tfboard(eval_summary)
                summary_writer.add_summary(eval_summary, step)
                print('Summary written')

                # Checkpoint only when test NMI improves.
                if nmi_te > max_nmi:
                    max_nmi = nmi_te
                    print("Saved")
                    saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                summary_writer.flush()

                if step in lr_decay_steps:
                    _lr = _lr * 0.5
                # After 5000 steps switch to the regular evaluation period.
                if step >= 5000:
                    bp_epoch = FLAGS.batch_per_epoch
                if step >= FLAGS.max_steps:
                    # Hard exit: skips interpreter cleanup on purpose.
                    os._exit(0)