Example #1
def main(_):

    # placeholders

    # -> raw input data
    x_raw = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])

    # -> parameters involved
    with tf.name_scope('isTraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('isPhase'):
        is_Phase = tf.placeholder(tf.bool)
    with tf.name_scope('learningRate'):
        lr = tf.placeholder(tf.float32)
    with tf.name_scope('lambdas'):
        lambda1 = tf.placeholder(tf.float32)
        lambda2 = tf.placeholder(tf.float32)
        lambda3 = tf.placeholder(tf.float32)
        lambda4 = tf.placeholder(tf.float32)

    # -> feature extractor layers
    with tf.variable_scope('FeatureExtractor'):

        # modified GoogLeNet model pre-trained on ILSVRC2012
        google_net_model = GoogleNet_Model.GoogleNet_Model()
        embedding_gn = google_net_model.forward(x_raw)

        # batch normalization of the GoogLeNet average-pooling output
        embedding = nn_Ops.bn_block(embedding_gn,
                                    normal=True,
                                    is_Training=is_Training,
                                    name='BN')

        # 3 fully connected layers of size EMBEDDING_SIZE
        # mean of cluster
        embedding_mu = nn_Ops.fc_block(embedding,
                                       in_d=1024,
                                       out_d=EMBEDDING_SIZE,
                                       name='fc1',
                                       is_bn=False,
                                       is_relu=False,
                                       is_Training=is_Training)
        # log(sigma^2) of cluster
        embedding_sigma = nn_Ops.fc_block(embedding,
                                          in_d=1024,
                                          out_d=EMBEDDING_SIZE,
                                          name='fc2',
                                          is_bn=False,
                                          is_relu=False,
                                          is_Training=is_Training)
        # invariant feature of cluster
        embedding_zi = nn_Ops.fc_block(embedding,
                                       in_d=1024,
                                       out_d=EMBEDDING_SIZE,
                                       name='fc3',
                                       is_bn=False,
                                       is_relu=False,
                                       is_Training=is_Training)

        with tf.name_scope('Loss'):

            def exclude_batch_norm(name):
                return 'batch_normalization' not in name and 'Generator' not in name and 'Loss' not in name

            wdLoss = 5e-3 * tf.add_n([
                tf.nn.l2_loss(v)
                for v in tf.trainable_variables() if exclude_batch_norm(v.name)
            ])
            label = tf.reduce_mean(label_raw, axis=1)
            J_m = Losses.triplet_semihard_loss(label, embedding_zi) + wdLoss

    # Generator
    with tf.variable_scope('Generator'):
        embedding_re = samplingGaussian(embedding_mu, embedding_sigma)
        embedding_zv = tf.reshape(embedding_re, (-1, EMBEDDING_SIZE))
        # Z = Zi + Zv
        embedding_z = tf.add(embedding_zi,
                             embedding_zv,
                             name='Synthesized_features')

    # Decoder
    with tf.variable_scope('Decoder'):
        embedding_y1 = nn_Ops.fc_block(embedding_z,
                                       in_d=EMBEDDING_SIZE,
                                       out_d=512,
                                       name='decoder1',
                                       is_bn=True,
                                       is_relu=True,
                                       is_Training=is_Phase)
        embedding_y2 = nn_Ops.fc_block(embedding_y1,
                                       in_d=512,
                                       out_d=1024,
                                       name='decoder2',
                                       is_bn=False,
                                       is_relu=False,
                                       is_Training=is_Phase)

    print("embedding_sigma", embedding_sigma)
    print("embedding_mu", embedding_mu)

    # Defining the 4 Losses

    # L1 loss: KL divergence between N(mu, sigma^2) and the prior N(0, I)
    # Definition: L1 = -(1/2) * sum( 1 + log(sigma^2) - mu^2 - sigma^2 ),
    # averaged over the batch (scaled by an extra 1/BATCH_SIZE below)
    with tf.name_scope('L1_KLDivergence'):
        kl_loss = 1 + embedding_sigma - K.square(embedding_mu) - K.exp(
            embedding_sigma)
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -(0.5 / BATCH_SIZE)
        L1 = lambda1 * K.mean(kl_loss)

    # L2 Loss
    # Definition: L2 = sum of squared differences between the decoder output
    # and the GoogleNet embedding, scaled by 1 / (20 * BATCH_SIZE)
    with tf.name_scope('L2_Reconstruction'):
        L2 = lambda2 * (1 / (20 * BATCH_SIZE)) * tf.reduce_sum(
            tf.square(embedding_y2 - embedding_gn))

    # L3 Loss
    # Definition: L3 = Lm( Z )
    with tf.name_scope('L3_Synthesized'):
        L3 = lambda3 * Losses.triplet_semihard_loss(labels=label,
                                                    embeddings=embedding_z)

    # L4 Loss
    # Definition: L4 = Lm( Zi )
    with tf.name_scope('L4_Metric'):
        L4 = lambda4 * J_m

    # Classifier Loss
    with tf.name_scope('Softmax_Loss'):
        cross_entropy, W_fc, b_fc = Losses.cross_entropy(
            embedding=embedding_gn, label=label)

    c_train_step = nn_Ops.training(loss=L4 + L3 + L1,
                                   lr=lr,
                                   var_scope='FeatureExtractor')
    g_train_step = nn_Ops.training(loss=L2, lr=LR_gen, var_scope='Decoder')
    s_train_step = nn_Ops.training(loss=cross_entropy,
                                   lr=LR_s,
                                   var_scope='Softmax_classifier')

    def model_summary():
        model_vars = tf.trainable_variables()
        slim.model_analyzer.analyze_vars(model_vars, print_info=True)

    model_summary()

    with tf.Session(config=config) as sess:

        summary_writer = tf.summary.FileWriter(LOGDIR, sess.graph)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        _lr = LR_init

        # To Record the losses to TfBoard
        Jm_loss = nn_Ops.data_collector(tag='Jm', init=1e+6)
        L1_loss = nn_Ops.data_collector(tag='KLDivergence', init=1e+6)
        L2_loss = nn_Ops.data_collector(tag='Reconstruction', init=1e+6)
        L3_loss = nn_Ops.data_collector(tag='Synthesized', init=1e+6)
        L4_loss = nn_Ops.data_collector(tag='Metric', init=1e+6)

        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy',
                                                   init=1e+6)
        wd_Loss = nn_Ops.data_collector(tag='weight_decay', init=1e+6)
        max_nmi = 0

        print("Phase 1")
        step = 0
        epoch_iterator = stream_train.get_epoch_iterator()
        for epoch in tqdm(range(NUM_EPOCHS_PHASE1)):
            print("Epoch: ", epoch)
            for batch in tqdm(copy.copy(epoch_iterator), total=MAX_ITER):
                step += 1
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)

                # training step
                c_train, g_train, s_train, wd_Loss_var, L1_var, L4_var, J_m_var, \
                    L3_var, L2_var, cross_en_var = sess.run(
                        [c_train_step, g_train_step, s_train_step, wdLoss, L1,
                         L4, J_m, L3, L2, cross_entropy],
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True, is_Phase: False,
                                   lambda1: 1, lambda2: 1, lambda3: 0.1,
                                   lambda4: 1, lr: _lr})

                Jm_loss.update(var=J_m_var)
                L1_loss.update(var=L1_var)
                L2_loss.update(var=L2_var)
                L3_loss.update(var=L3_var)
                L4_loss.update(var=L4_var)
                cross_entropy_loss.update(cross_en_var)
                wd_Loss.update(var=wd_Loss_var)

                # evaluation
                if step % RECORD_STEP == 0:
                    nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                        stream_test, image_mean, sess, x_raw, label_raw,
                        is_Training, is_Phase, embedding_zi, 98, neighbours)

                    # Summary
                    eval_summary = tf.Summary()
                    eval_summary.value.add(tag='test nmi', simple_value=nmi_te)
                    eval_summary.value.add(tag='test f1', simple_value=f1_te)
                    for i in range(0, np.shape(neighbours)[0]):
                        eval_summary.value.add(tag='Recall@%d test' %
                                               neighbours[i],
                                               simple_value=recalls_te[i])
                    Jm_loss.write_to_tfboard(eval_summary)
                    eval_summary.value.add(tag='learning_rate',
                                           simple_value=_lr)
                    L1_loss.write_to_tfboard(eval_summary)
                    L2_loss.write_to_tfboard(eval_summary)
                    L3_loss.write_to_tfboard(eval_summary)
                    L4_loss.write_to_tfboard(eval_summary)
                    wd_Loss.write_to_tfboard(eval_summary)
                    cross_entropy_loss.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary recorded')
                    if nmi_te > max_nmi:
                        max_nmi = nmi_te
                        saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                        print("Model Saved")
                    summary_writer.flush()

        print("Phase 2")
        epoch_iterator = stream_train.get_epoch_iterator()
        for epoch in tqdm(range(NUM_EPOCHS_PHASE2)):
            print("Epoch: ", epoch)
            for batch in tqdm(copy.copy(epoch_iterator), total=MAX_ITER):
                step += 1
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)

                # training step
                c_train, g_train, s_train, wd_Loss_var, L1_var, L4_var, J_m_var, \
                    L3_var, L2_var, cross_en_var = sess.run(
                        [c_train_step, g_train_step, s_train_step, wdLoss, L1,
                         L4, J_m, L3, L2, cross_entropy],
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True, is_Phase: True,
                                   lambda1: 0.8, lambda2: 1, lambda3: 0.2,
                                   lambda4: 0.8, lr: _lr})

                Jm_loss.update(var=J_m_var)
                L1_loss.update(var=L1_var)
                L2_loss.update(var=L2_var)
                L3_loss.update(var=L3_var)
                L4_loss.update(var=L4_var)
                wd_Loss.update(var=wd_Loss_var)
                cross_entropy_loss.update(cross_en_var)

                # evaluation
                if step % RECORD_STEP == 0:
                    nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                        stream_test, image_mean, sess, x_raw, label_raw,
                        is_Training, is_Phase, embedding_zi, 98, neighbours)

                    # Summary
                    eval_summary = tf.Summary()
                    eval_summary.value.add(tag='test nmi', simple_value=nmi_te)
                    eval_summary.value.add(tag='test f1', simple_value=f1_te)
                    for i in range(0, np.shape(neighbours)[0]):
                        eval_summary.value.add(tag='Recall@%d test' %
                                               neighbours[i],
                                               simple_value=recalls_te[i])
                    Jm_loss.write_to_tfboard(eval_summary)
                    eval_summary.value.add(tag='learning_rate',
                                           simple_value=_lr)
                    L1_loss.write_to_tfboard(eval_summary)
                    L2_loss.write_to_tfboard(eval_summary)
                    L3_loss.write_to_tfboard(eval_summary)
                    L4_loss.write_to_tfboard(eval_summary)
                    wd_Loss.write_to_tfboard(eval_summary)
                    cross_entropy_loss.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary recorded')
                    if nmi_te > max_nmi:
                        max_nmi = nmi_te
                        saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                        print("Model Saved")
                    summary_writer.flush()
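
Example #1 calls a samplingGaussian helper that is not shown in this listing. Judging from how it is used (a mean and a log-variance go in, a sampled residual Zv comes out), it appears to be the standard VAE reparameterization trick; the sketch below is only an assumption about that helper, not its actual implementation, and the name and signature are taken from the call site.

# Hedged sketch (not part of the listing): what samplingGaussian(mu, log_sigma_sq)
# is assumed to do, based on how Example #1 uses its output.
import tensorflow as tf

def samplingGaussian(mu, log_sigma_sq):
    # epsilon ~ N(0, I); shift and scale it so the sample stays
    # differentiable with respect to mu and log(sigma^2).
    epsilon = tf.random_normal(shape=tf.shape(mu))
    return mu + tf.exp(0.5 * log_sigma_sq) * epsilon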
Example #2
def main(_):
    if FLAGS.LossType != 'NpairLoss':
        print("LossType 'NpairLoss' is required")
        return 0

    # placeholders
    x_raw = tf.placeholder(
        tf.float32,
        shape=[None, FLAGS.default_image_size, FLAGS.default_image_size, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])
    with tf.name_scope('istraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('learning_rate'):
        lr = tf.placeholder(tf.float32)

    with tf.variable_scope('Classifier'):
        google_net_model = GoogleNet_Model.GoogleNet_Model()
        embedding = google_net_model.forward(x_raw)
        if FLAGS.Apply_HDML:
            embedding_y_origin = embedding

        # Batch Normalization layer 1
        embedding = nn_Ops.bn_block(embedding,
                                    normal=FLAGS.normalize,
                                    is_Training=is_Training,
                                    name='BN1')

        # FC layer 1
        embedding_z = nn_Ops.fc_block(embedding,
                                      in_d=1024,
                                      out_d=FLAGS.embedding_size,
                                      name='fc1',
                                      is_bn=False,
                                      is_relu=False,
                                      is_Training=is_Training)

        # Embedding Visualization
        assignment, embedding_var = Embedding_Visualization.embedding_assign(
            batch_size=256,
            embedding=embedding_z,
            embedding_size=FLAGS.embedding_size,
            name='Embedding_of_fc1')

        # conventional Loss function
        with tf.name_scope('Loss'):
            # wdLoss = layers.apply_regularization(regularizer, weights_list=None)
            def exclude_batch_norm(name):
                return 'batch_normalization' not in name and 'Generator' not in name and 'Loss' not in name

            wdLoss = FLAGS.Regular_factor * tf.add_n([
                tf.nn.l2_loss(v)
                for v in tf.trainable_variables() if exclude_batch_norm(v.name)
            ])
            # Get the Label
            label = tf.reduce_mean(label_raw, axis=1, keep_dims=False)
            # For some kinds of Losses, the embedding should be l2 normed
            #embedding_l = embedding_z
            J_m = Loss_ops.Loss(embedding_z, label, FLAGS.LossType) + wdLoss

    # if HNG is applied
    if FLAGS.Apply_HDML:
        with tf.name_scope('Javg'):
            Javg = tf.placeholder(tf.float32)
        with tf.name_scope('Jgen'):
            Jgen = tf.placeholder(tf.float32)

        embedding_z_quta = HDML.Pulling(FLAGS.LossType, embedding_z, Javg)

        embedding_z_concate = tf.concat([embedding_z, embedding_z_quta],
                                        axis=0)

        # Generator
        with tf.variable_scope('Generator'):

            # generator fc3
            embedding_y_concate = nn_Ops.fc_block(embedding_z_concate,
                                                  in_d=FLAGS.embedding_size,
                                                  out_d=512,
                                                  name='generator1',
                                                  is_bn=True,
                                                  is_relu=True,
                                                  is_Training=is_Training)

            # generator fc4
            embedding_y_concate = nn_Ops.fc_block(embedding_y_concate,
                                                  in_d=512,
                                                  out_d=1024,
                                                  name='generator2',
                                                  is_bn=False,
                                                  is_relu=False,
                                                  is_Training=is_Training)

            embedding_yp, embedding_yq = HDML.npairSplit(embedding_y_concate)

        with tf.variable_scope('Classifier'):
            embedding_z_quta = nn_Ops.bn_block(embedding_yq,
                                               normal=FLAGS.normalize,
                                               is_Training=is_Training,
                                               name='BN1',
                                               reuse=True)

            embedding_z_quta = nn_Ops.fc_block(embedding_z_quta,
                                               in_d=1024,
                                               out_d=FLAGS.embedding_size,
                                               name='fc1',
                                               is_bn=False,
                                               is_relu=False,
                                               reuse=True,
                                               is_Training=is_Training)

            embedding_zq_anc = tf.slice(
                input_=embedding_z_quta,
                begin=[0, 0],
                size=[int(FLAGS.batch_size / 2),
                      int(FLAGS.embedding_size)])
            embedding_zq_negtile = tf.slice(
                input_=embedding_z_quta,
                begin=[int(FLAGS.batch_size / 2), 0],
                size=[
                    int(np.square(FLAGS.batch_size / 2)),
                    int(FLAGS.embedding_size)
                ])

        with tf.name_scope('Loss'):
            J_syn = (1. -
                     tf.exp(-FLAGS.beta / Jgen)) * Loss_ops.new_npair_loss(
                         labels=label,
                         embedding_anchor=embedding_zq_anc,
                         embedding_positive=embedding_zq_negtile,
                         equal_shape=False,
                         reg_lambda=FLAGS.loss_l2_reg)
            J_m = (tf.exp(-FLAGS.beta / Jgen)) * J_m
            J_metric = J_m + J_syn

            cross_entropy, W_fc, b_fc = HDML.cross_entropy(
                embedding=embedding_y_origin, label=label)

            embedding_yq_anc = tf.slice(input_=embedding_yq,
                                        begin=[0, 0],
                                        size=[int(FLAGS.batch_size / 2), 1024])
            embedding_yq_negtile = tf.slice(
                input_=embedding_yq,
                begin=[int(FLAGS.batch_size / 2), 0],
                size=[int(np.square(FLAGS.batch_size / 2)), 1024])
            J_recon = (1 - FLAGS._lambda) * tf.reduce_sum(
                tf.square(embedding_yp - embedding_y_origin))
            J_soft = HDML.genSoftmax(embedding_anc=embedding_yq_anc,
                                     embedding_neg=embedding_yq_negtile,
                                     W_fc=W_fc,
                                     b_fc=b_fc,
                                     label=label)
            J_gen = J_recon + J_soft

    if FLAGS.Apply_HDML:
        c_train_step = nn_Ops.training(loss=J_metric,
                                       lr=lr,
                                       var_scope='Classifier')
        g_train_step = nn_Ops.training(loss=J_gen,
                                       lr=FLAGS.lr_gen,
                                       var_scope='Generator')
        s_train_step = nn_Ops.training(loss=cross_entropy,
                                       lr=FLAGS.s_lr,
                                       var_scope='Softmax_classifier')
    else:
        train_step = nn_Ops.training(loss=J_m, lr=lr)

    # initialise the session
    with tf.Session(config=config) as sess:
        # Initialize all the variables within the session
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # learning rate
        _lr = FLAGS.init_learning_rate

        # Restore a checkpoint
        if FLAGS.load_formalVal:
            saver.restore(
                sess, FLAGS.log_save_path + FLAGS.dataSet + '/' +
                FLAGS.LossType + '/' + FLAGS.formerTimer)

        # Training
        # epoch_iterator = stream_train.get_epoch_iterator()

        # collectors
        J_m_loss = nn_Ops.data_collector(tag='Jm', init=1e+6)
        J_syn_loss = nn_Ops.data_collector(tag='J_syn', init=1e+6)
        J_metric_loss = nn_Ops.data_collector(tag='J_metric', init=1e+6)
        J_soft_loss = nn_Ops.data_collector(tag='J_soft', init=1e+6)
        J_recon_loss = nn_Ops.data_collector(tag='J_recon', init=1e+6)
        J_gen_loss = nn_Ops.data_collector(tag='J_gen', init=1e+6)
        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy',
                                                   init=1e+6)
        wd_Loss = nn_Ops.data_collector(tag='weight_decay', init=1e+6)
        # max_nmi = 0
        # step = 0
        #
        # bp_epoch = FLAGS.init_batch_per_epoch
        # evaluation
        print('only eval!')
        # nmi_tr, f1_tr, recalls_tr = evaluation.Evaluation(
        #     stream_train_eval, image_mean, sess, x_raw, label_raw, is_Training, embedding_z, 98, neighbours)
        # nmi_te, f1_te, recalls_te = evaluation.Evaluation(
        #     stream_test, image_mean, sess, x_raw, label_raw, is_Training, embedding_z, FLAGS.num_class_test, neighbours)
        embeddings, labels = evaluation.Evaluation_icon2(
            stream_train, image_mean, sess, x_raw, label_raw, is_Training,
            embedding_z, FLAGS.num_class_test, neighbours)
        out_dir = os.path.expanduser(FLAGS.log_save_path[:-1] + '_embeddings')
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        for idx, distance in enumerate(embeddings):
            save_1Darray(distance,
                         os.path.expanduser(os.path.join(out_dir, str(idx))))
        print('End')
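
Examples #2 and #4 weight the original loss J_m and the synthetic loss J_syn with exp(-beta / Jgen) and 1 - exp(-beta / Jgen), where Jgen is the running generator loss fed back in through a placeholder. The snippet below only illustrates how those weights shift as Jgen falls; beta and the Jgen values are made-up numbers for illustration.

# Illustration only: behaviour of the adaptive weights used in the Loss scope above.
import numpy as np

beta = 1e4
for Jgen in [1e6, 1e5, 1e4, 1e3]:
    w_m = np.exp(-beta / Jgen)      # weight on the original metric loss J_m
    w_syn = 1.0 - w_m               # weight on the synthetic loss J_syn
    print('Jgen=%g  w_m=%.3f  w_syn=%.3f' % (Jgen, w_m, w_syn))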
Example #3
def main(_):
    if FLAGS.LossType != 'Triplet':
        print("LossType 'Triplet' is required")
        return 0

    # placeholders
    x_raw = tf.placeholder(tf.float32, shape=[None, FLAGS.default_image_size, FLAGS.default_image_size, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])
    with tf.name_scope('istraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('learning_rate'):
        lr = tf.placeholder(tf.float32)

    with tf.variable_scope('Classifier'):
        google_net_model = GoogleNet_Model.GoogleNet_Model(pooling_type=FLAGS.pooling_type)
        embedding = google_net_model.forward(x_raw)
        embedding = nn_Ops.bn_block(
            embedding, normal=FLAGS.normalize, is_Training=is_Training, name='BN1')

        embedding = nn_Ops.fc_block(
            embedding, in_d=1024, out_d=FLAGS.embedding_size,
            name='fc1', is_bn=False, is_relu=False, is_Training=is_Training
        )
        with tf.name_scope('Loss'):
            # wdLoss = layers.apply_regularization(regularizer, weights_list=None)
            def exclude_batch_norm(name):
                return 'batch_normalization' not in name and 'Generator' not in name and 'Loss' not in name

            wdLoss = FLAGS.Regular_factor * tf.add_n(
                [tf.nn.l2_loss(v) for v in tf.trainable_variables() if exclude_batch_norm(v.name)]
            )
            # Get the Label
            label = tf.reduce_mean(label_raw, axis=1, keep_dims=False)
            # For some kinds of Losses, the embedding should be l2 normed
            #embedding_l = embedding_z
            J_m = Loss_ops.Loss(embedding, label, FLAGS.LossType)
            # J_m=tf.Print(J_m,[J_m])
            J_m = J_m + wdLoss
    with tf.name_scope('Jgen2'):
        Jgen2 = tf.placeholder(tf.float32)
    with tf.name_scope('Pull_dis_mean'):
        Pull_dis_mean = tf.placeholder(tf.float32)
    Pull_dis = nn_Ops.data_collector(tag='pulling_linear', init=25)

    embedding_l, disap = Two_Stage_Model.Pulling_Positivate(FLAGS.LossType, embedding, can=FLAGS.alpha)


    with tf.variable_scope('Generator1'):
        embedding_g = Two_Stage_Model.generator_ori(embedding_l, FLAGS.LossType)
        with tf.name_scope('Loss'):
            def exclude_batch_norm1(name):
                return 'batch_normalization' not in name and 'Generator1' in name and 'Loss' not in name

            wdLoss_g1 = FLAGS.Regular_factor * tf.add_n(
                [tf.nn.l2_loss(v) for v in tf.trainable_variables() if exclude_batch_norm1(v.name)]
            )
    # Generator1_S
    with tf.variable_scope('GeneratorS'):
        embedding_s = Two_Stage_Model.generator_ori(embedding_g, FLAGS.LossType)
        with tf.name_scope('Loss'):
            def exclude_batch_norm2(name):
                return 'batch_normalization' not in name and 'GeneratorS' in name and 'Loss' not in name

            wdLoss_g1s = FLAGS.Regular_factor * tf.add_n(
                [tf.nn.l2_loss(v) for v in tf.trainable_variables() if exclude_batch_norm2(v.name)]
            )

    # Discriminator1 only contains the anchor and positive message
    with tf.variable_scope('Discriminator1') as scope:
        embedding_dis1 = Two_Stage_Model.discriminator1(Two_Stage_Model.slice_ap_n(embedding),1)
        scope.reuse_variables()
        embedding_g_dis1 = Two_Stage_Model.discriminator1(Two_Stage_Model.slice_ap_n(embedding_g),1)
        # scope.reuse_variables()
        # embedding_s_dis1 = Two_Stage_Model.discriminator1(Two_Stage_Model.slice_ap_n(embedding_s))


    with tf.variable_scope('DiscriminatorS') as scope:
        embedding_dis1S = Two_Stage_Model.discriminator1(Two_Stage_Model.slice_ap_n(embedding),1)
        scope.reuse_variables()
        embedding_s_dis1S = Two_Stage_Model.discriminator1(Two_Stage_Model.slice_ap_n(embedding_s),1)

    with tf.variable_scope('Generator2'):
        embedding_h = Two_Stage_Model.generator_ori(embedding_g)
    with tf.variable_scope('Discriminator2') as scope:
        embedding_dis2 = Two_Stage_Model.discriminator2(embedding)
        scope.reuse_variables()
        embedding_g_dis2 = Two_Stage_Model.discriminator2(embedding_g)
        scope.reuse_variables()
        embedding_h_dis2 = Two_Stage_Model.discriminator2(embedding_h)

    embedding_h_cli = embedding_h

    '''
        binary cross-entropy is used here in place of sparse_softmax_cross_entropy_with_logits
    '''
    with tf.name_scope('Loss'):
        J_syn = Loss_ops.Loss(embedding_h_cli, label, _lossType=FLAGS.LossType, hard_ori=FLAGS.HARD_ORI)
        # J_syn = tf.constant(0.)
        J_m = J_m
        para1 = tf.exp(-FLAGS.beta / Jgen2)
        J_metric = para1 * J_m + (1. - para1) * J_syn

        real_loss_d1 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.zeros([FLAGS.batch_size * 2 // 3], dtype=tf.int32), logits=embedding_dis1))
        generated1_loss_d1 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.ones([FLAGS.batch_size * 2 // 3], dtype=tf.int32), logits=embedding_g_dis1))
        J_LD1 = FLAGS.Softmax_factor * -(tf.reduce_mean(tf.log(1. - tf.nn.sigmoid(embedding_dis1)) + tf.log(tf.nn.sigmoid(embedding_g_dis1))))

        J_LG1 = FLAGS.Softmax_factor * (-tf.reduce_mean(tf.log(1. - tf.nn.sigmoid(embedding_g_dis1)))) + wdLoss_g1
        embedding_g_split = tf.split(embedding_g, 3, axis=0)
        embedding_g_split_anc = embedding_g_split[0]
        embedding_g_split_pos = embedding_g_split[1]
        dis_g1 = tf.reduce_mean(
            Two_Stage_Model.distance(embedding_g_split_anc, embedding_g_split_pos))
        dis_g1 = tf.maximum(-dis_g1 + 1000, 0.)*10
        J_LG1 = J_LG1

        '''add for D1S'''
        J_LD1S = FLAGS.Softmax_factor * (-tf.reduce_mean(tf.log(1. - tf.nn.sigmoid(embedding_dis1S)) + tf.log(tf.nn.sigmoid(embedding_s_dis1S))))

        J_LG1_S_cross = FLAGS.Softmax_factor *(-tf.reduce_mean(tf.log(1. - tf.nn.sigmoid(embedding_s_dis1S))))
        recon_ori_s = FLAGS.Recon_factor * tf.reduce_mean(Two_Stage_Model.distance(embedding_s, embedding))
        J_LG1_S = J_LG1_S_cross + recon_ori_s + wdLoss_g1s

        # label_onehot = tf.one_hot(label, FLAGS.num_class+1)
        real_loss_d2 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=embedding_dis2))
        generated2_loss_d2 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=embedding_g_dis2))
        generated2_h_loss_d2 = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.zeros([FLAGS.batch_size], dtype=tf.int32) + FLAGS.num_class, logits=embedding_h_dis2))
        J_LD2 = FLAGS.Softmax_factor * (real_loss_d2 + generated2_loss_d2 + generated2_h_loss_d2) / 3
        J_LD2 = J_LD2 + J_syn

        cross_entropy, W_fc, b_fc = THSG.cross_entropy(embedding=embedding, label=label)
        Logits_q = tf.matmul(embedding_h, W_fc) + b_fc
        J_LG2_cross_entropy = FLAGS.Softmax_factor * FLAGS._lambda * tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=Logits_q))

        J_LG2C_cross_GAN = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=embedding_h_dis2))
        J_LG2C_cross = FLAGS.Softmax_factor * (J_LG2C_cross_GAN + J_LG2_cross_entropy) / 2
        recon_g_h_ancpos = FLAGS.Recon_factor * tf.reduce_mean(Two_Stage_Model.distance(Two_Stage_Model.slice_ap_n(embedding_g),
                                                                   Two_Stage_Model.slice_ap_n(embedding_h)))

        J_fan = FLAGS.Softmax_factor * Loss_ops.Loss_fan(embedding_h, label, _lossType=FLAGS.LossType,
                                                          param=2 - tf.exp(-FLAGS.beta / Jgen2),
                                                          hard_ori=FLAGS.HARD_ORI)
        J_LG2 = J_LG2C_cross + recon_g_h_ancpos + J_fan
        # J_LG2 = tf.Print(J_LG2, [J_LG2])

        J_F = J_metric
    c_train_step = nn_Ops.training(loss=J_F, lr=lr, var_scope='Classifier')
    d1_train_step = nn_Ops.training(loss=J_LD1, lr=FLAGS.lr_dis, var_scope='Discriminator1')
    g1_train_step = nn_Ops.training(loss=J_LG1, lr=FLAGS.lr_gen, var_scope='Generator1')
    d1s_train_step = nn_Ops.training(loss=J_LD1S, lr=FLAGS.lr_dis, var_scope='DiscriminatorS')
    g1s_train_step = nn_Ops.training(loss=J_LG1_S, lr=FLAGS.lr_gen, var_scope='GeneratorS*Generator1')

    d2_train_step = nn_Ops.training(loss=J_LD2, lr=FLAGS.lr_dis, var_scope='Discriminator2')
    g2_train_step = nn_Ops.training(loss=J_LG2, lr=FLAGS.lr_gen, var_scope='Generator2')

    s_train_step = nn_Ops.training(loss=cross_entropy, lr=FLAGS.s_lr, var_scope='Softmax_classifier')

    # initialise the session
    with tf.Session(config=config) as sess:
        summary_writer = tf.summary.FileWriter(LOGDIR, sess.graph)
        # Initialize all the variables within the session
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # learning rate
        _lr = FLAGS.init_learning_rate

        # Restore a checkpoint
        if FLAGS.load_formalVal:
            saver.restore(sess, FLAGS.log_save_path+FLAGS.dataSet+'/'+FLAGS.LossType+'/'+FLAGS.formerTimer)

        # Training
        epoch_iterator = stream_train.get_epoch_iterator()

        # collectors
        J_m_loss = nn_Ops.data_collector(tag='J_m', init=1e+6)
        J_syn_loss = nn_Ops.data_collector(tag='J_syn', init=1e+6)
        J_metric_loss = nn_Ops.data_collector(tag='J_metric', init=1e+6)

        real_loss_d1_loss = nn_Ops.data_collector(tag='real_loss_d1', init=1e+6)
        generated1_loss_d1_loss = nn_Ops.data_collector(tag='generated1_loss_d1', init=1e+6)
        J_LD1S_loss = nn_Ops.data_collector(tag='L_D1S', init=1e+6)
        J_LG1_loss = nn_Ops.data_collector(tag='J_LG1', init=1e+6)

        J_LG1_S_cross_loss = nn_Ops.data_collector(tag='J_LG1_S_cross', init=1e+6)
        recon_ori_s_loss = nn_Ops.data_collector(tag='recon_ori_s', init=1e+6)

        real_loss_d2_loss = nn_Ops.data_collector(tag='real_loss_d2', init=1e+6)
        generated2_loss_d2_loss = nn_Ops.data_collector(tag='generated2_loss_d2', init=1e+6)
        generated2_h_loss_d2_loss = nn_Ops.data_collector(tag='generated2_h_loss_d2', init=1e+6)

        J_LG2C_cross_loss = nn_Ops.data_collector(tag='J_LG2C_cross', init=1e+6)
        recon_g_h_ancpos_loss = nn_Ops.data_collector(tag='recon_g_h_ancpos', init=1e+6)
        J_LG2_loss = nn_Ops.data_collector(tag='J_LG2', init=1e+6)
        J_fan_loss = nn_Ops.data_collector(tag='J_fan', init=1e+6)

        J_LD1_loss = nn_Ops.data_collector(tag='J_LD1', init=1e+6)
        J_LD2_loss = nn_Ops.data_collector(tag='J_LD2', init=1e+6)
        J_F_loss = nn_Ops.data_collector(tag='J_F', init=1e+6)
        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy', init=1e+6)
        dis_g1_loss = nn_Ops.data_collector(tag='dis_g1', init=1e+6)

        max_nmi = 0
        step = 0

        bp_epoch = FLAGS.init_batch_per_epoch
        with tqdm(total=FLAGS.max_steps) as pbar:
            for batch in copy.copy(epoch_iterator):
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                pbar.update(1)
                _, _, _, disap_var = sess.run(
                    [d1_train_step, d1s_train_step, d2_train_step, disap],
                    feed_dict={x_raw: x_batch_data,
                               label_raw: Label_raw,
                               is_Training: True, lr: _lr,
                               Pull_dis_mean: Pull_dis.read()})
                Pull_dis.update(var=disap_var.mean() * 0.8)
                _, _, _, disap_var = sess.run(
                    [d1_train_step, d1s_train_step, d2_train_step, disap],
                    feed_dict={x_raw: x_batch_data,
                               label_raw: Label_raw,
                               is_Training: True, lr: _lr,
                               Pull_dis_mean: Pull_dis.read()})
                Pull_dis.update(var=disap_var.mean() * 0.8)
                c_train, s_train, d1_train, g1_train, d1s_train, g1s_train, d2_train, g2_train, real_loss_d2_var, J_metric_var, J_m_var, \
                    J_syn_var, real_loss_d1_var, generated1_loss_d1_var, J_LD1_var, J_LD2_var, J_LD1S_var, \
                    J_LG1_var, J_LG1_S_cross_var, recon_ori_s_var,  real_loss_d2_var, generated2_loss_d2_var, cross_entropy_var, \
                    generated2_h_loss_d2_var, J_LG2C_cross_var, recon_g_h_ancpos_var, J_LG2_var, J_fan_var, J_F_var, dis_g1_var,disap_var\
                    = sess.run(
                        [c_train_step, s_train_step, d1_train_step, g1_train_step, d1s_train_step, g1s_train_step, d2_train_step, g2_train_step,
                         real_loss_d2, J_metric, J_m, J_syn, real_loss_d1, generated1_loss_d1, J_LD1, J_LD2,
                         J_LD1S, J_LG1, J_LG1_S_cross, recon_ori_s, real_loss_d2,
                         generated2_loss_d2, cross_entropy, generated2_h_loss_d2, J_LG2C_cross, recon_g_h_ancpos, J_LG2, J_fan, J_F, dis_g1,disap],
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True, lr: _lr, Jgen2: J_LG2_loss.read(), Pull_dis_mean : Pull_dis.read()})
                Pull_dis.update(var=disap_var.mean()*0.8)
                real_loss_d2_loss.update(var=real_loss_d2_var)
                J_metric_loss.update(var=J_metric_var)
                J_m_loss.update(var=J_m_var)
                J_syn_loss.update(var=J_syn_var)
                real_loss_d1_loss.update(var=real_loss_d1_var)
                generated1_loss_d1_loss.update(var=generated1_loss_d1_var)
                J_LD1S_loss.update(var=J_LD1S_var)
                dis_g1_loss.update(var=dis_g1_var)
                J_LG1_loss.update(var=J_LG1_var)
                J_LG1_S_cross_loss.update(var=J_LG1_S_cross_var)
                recon_ori_s_loss.update(var=recon_ori_s_var)
                real_loss_d2_loss.update(var=real_loss_d2_var)
                generated2_loss_d2_loss.update(var=generated2_loss_d2_var)
                generated2_h_loss_d2_loss.update(var=generated2_h_loss_d2_var)
                J_LG2C_cross_loss.update(var=J_LG2C_cross_var)
                recon_g_h_ancpos_loss.update(var=recon_g_h_ancpos_var)
                J_LG2_loss.update(var=J_LG2_var)
                J_fan_loss.update(var=J_fan_var)
                J_LD1_loss.update(var=J_LD1_var)
                J_LD2_loss.update(var=J_LD2_var)
                J_F_loss.update(var=J_F_var)
                cross_entropy_loss.update(var=cross_entropy_var)
                step += 1
                # print('learning rate %f' % _lr)

                # evaluation
                if step % bp_epoch == 0:
                    print('only eval eval')
                    # nmi_te_cli, f1_te_cli, recalls_te_cli, map_cli = evaluation.Evaluation(
                    #     stream_test, image_mean, sess, x_raw, label_raw, is_Training, embedding, 98, neighbours)
                    recalls_te_cli, map_cli = evaluation.Evaluation(
                            stream_test, image_mean, sess, x_raw, label_raw, is_Training, embedding, 98, neighbours)
                    # Summary
                    eval_summary = tf.Summary()
                    # eval_summary.value.add(tag='test nmi', simple_value=nmi_te_cli)
                    # eval_summary.value.add(tag='test f1', simple_value=f1_te_cli)
                    eval_summary.value.add(tag='test map', simple_value=map_cli)
                    for i in range(0, np.shape(neighbours)[0]):
                        eval_summary.value.add(tag='Recall@%d test' % neighbours[i], simple_value=recalls_te_cli[i])

                    # Embedding_Visualization.embedding_Visual("./", embedding_var, summary_writer)

                    real_loss_d2_loss.write_to_tfboard(eval_summary)
                    J_metric_loss.write_to_tfboard(eval_summary)
                    J_m_loss.write_to_tfboard(eval_summary)
                    J_syn_loss.write_to_tfboard(eval_summary)
                    real_loss_d1_loss.write_to_tfboard(eval_summary)
                    generated1_loss_d1_loss.write_to_tfboard(eval_summary)
                    J_LD1S_loss.write_to_tfboard(eval_summary)
                    J_LG1_loss.write_to_tfboard(eval_summary)
                    dis_g1_loss.write_to_tfboard(eval_summary)
                    J_LD1_loss.write_to_tfboard(eval_summary)
                    J_LD2_loss.write_to_tfboard(eval_summary)
                    J_F_loss.write_to_tfboard(eval_summary)
                    J_LG1_S_cross_loss.write_to_tfboard(eval_summary)
                    recon_ori_s_loss.write_to_tfboard(eval_summary)
                    real_loss_d2_loss.write_to_tfboard(eval_summary)
                    generated2_loss_d2_loss.write_to_tfboard(eval_summary)
                    generated2_h_loss_d2_loss.write_to_tfboard(eval_summary)
                    J_LG2C_cross_loss.write_to_tfboard(eval_summary)
                    recon_g_h_ancpos_loss.write_to_tfboard(eval_summary)
                    J_LG2_loss.write_to_tfboard(eval_summary)
                    J_fan_loss.write_to_tfboard(eval_summary)
                    cross_entropy_loss.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary written')
                    if map_cli > max_nmi:
                        max_nmi = map_cli
                        print("Saved")
                        saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                    # saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                    summary_writer.flush()
                    if step in [5632, 6848]:
                        _lr = _lr * 0.5

                    if step >= 5000:
                        bp_epoch = FLAGS.batch_per_epoch
                    if step >= FLAGS.max_steps:
                        os._exit(0)
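
Example #3 leans on a Two_Stage_Model.slice_ap_n helper that is not shown. The tf.split(embedding_g, 3, axis=0) call and the discriminator label tensors of size batch_size * 2 // 3 suggest each batch is laid out as three equal blocks [anchors; positives; negatives], so slice_ap_n presumably keeps the first two blocks. A minimal sketch under that assumption, not the actual implementation:

import tensorflow as tf

def slice_ap_n(embedding):
    # Assumes the batch is ordered [anchors; positives; negatives] in equal thirds.
    third = tf.shape(embedding)[0] // 3
    return embedding[:2 * third]   # keep anchors and positives, drop the negative block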
Example #4
def main(_):
    if FLAGS.LossType != 'NpairLoss':
        print("LossType 'NpairLoss' is required")
        return 0

    # placeholders
    x_raw = tf.placeholder(
        tf.float32,
        shape=[None, FLAGS.default_image_size, FLAGS.default_image_size, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])
    with tf.name_scope('istraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('learning_rate'):
        lr = tf.placeholder(tf.float32)

    with tf.variable_scope('Classifier'):
        google_net_model = GoogleNet_Model.GoogleNet_Model()
        embedding = google_net_model.forward(x_raw)
        if FLAGS.Apply_HDML:
            embedding_y_origin = embedding

        # Batch Normalization layer 1
        embedding = nn_Ops.bn_block(embedding,
                                    normal=FLAGS.normalize,
                                    is_Training=is_Training,
                                    name='BN1')

        # FC layer 1
        embedding_z = nn_Ops.fc_block(embedding,
                                      in_d=1024,
                                      out_d=FLAGS.embedding_size,
                                      name='fc1',
                                      is_bn=False,
                                      is_relu=False,
                                      is_Training=is_Training)

        # Embedding Visualization
        assignment, embedding_var = Embedding_Visualization.embedding_assign(
            batch_size=256,
            embedding=embedding_z,
            embedding_size=FLAGS.embedding_size,
            name='Embedding_of_fc1')

        # conventional Loss function
        with tf.name_scope('Loss'):
            # wdLoss = layers.apply_regularization(regularizer, weights_list=None)
            def exclude_batch_norm(name):
                return 'batch_normalization' not in name and 'Generator' not in name and 'Loss' not in name

            wdLoss = FLAGS.Regular_factor * tf.add_n([
                tf.nn.l2_loss(v)
                for v in tf.trainable_variables() if exclude_batch_norm(v.name)
            ])
            # Get the Label
            label = tf.reduce_mean(label_raw, axis=1, keep_dims=False)
            # For some kinds of Losses, the embedding should be l2 normed
            #embedding_l = embedding_z
            J_m = Loss_ops.Loss(embedding_z, label, FLAGS.LossType) + wdLoss

    # if HNG is applied
    if FLAGS.Apply_HDML:
        with tf.name_scope('Javg'):
            Javg = tf.placeholder(tf.float32)
        with tf.name_scope('Jgen'):
            Jgen = tf.placeholder(tf.float32)

        embedding_z_quta = HDML.Pulling(FLAGS.LossType, embedding_z, Javg)

        embedding_z_concate = tf.concat([embedding_z, embedding_z_quta],
                                        axis=0)

        # Generator
        with tf.variable_scope('Generator'):

            # generator fc3
            embedding_y_concate = nn_Ops.fc_block(embedding_z_concate,
                                                  in_d=FLAGS.embedding_size,
                                                  out_d=512,
                                                  name='generator1',
                                                  is_bn=True,
                                                  is_relu=True,
                                                  is_Training=is_Training)

            # generator fc4
            embedding_y_concate = nn_Ops.fc_block(embedding_y_concate,
                                                  in_d=512,
                                                  out_d=1024,
                                                  name='generator2',
                                                  is_bn=False,
                                                  is_relu=False,
                                                  is_Training=is_Training)

            embedding_yp, embedding_yq = HDML.npairSplit(embedding_y_concate)

        with tf.variable_scope('Classifier'):
            embedding_z_quta = nn_Ops.bn_block(embedding_yq,
                                               normal=FLAGS.normalize,
                                               is_Training=is_Training,
                                               name='BN1',
                                               reuse=True)

            embedding_z_quta = nn_Ops.fc_block(embedding_z_quta,
                                               in_d=1024,
                                               out_d=FLAGS.embedding_size,
                                               name='fc1',
                                               is_bn=False,
                                               is_relu=False,
                                               reuse=True,
                                               is_Training=is_Training)

            embedding_zq_anc = tf.slice(
                input_=embedding_z_quta,
                begin=[0, 0],
                size=[int(FLAGS.batch_size / 2),
                      int(FLAGS.embedding_size)])
            embedding_zq_negtile = tf.slice(
                input_=embedding_z_quta,
                begin=[int(FLAGS.batch_size / 2), 0],
                size=[
                    int(np.square(FLAGS.batch_size / 2)),
                    int(FLAGS.embedding_size)
                ])

        with tf.name_scope('Loss'):
            J_syn = (1. -
                     tf.exp(-FLAGS.beta / Jgen)) * Loss_ops.new_npair_loss(
                         labels=label,
                         embedding_anchor=embedding_zq_anc,
                         embedding_positive=embedding_zq_negtile,
                         equal_shape=False,
                         reg_lambda=FLAGS.loss_l2_reg)
            J_m = (tf.exp(-FLAGS.beta / Jgen)) * J_m
            J_metric = J_m + J_syn

            cross_entropy, W_fc, b_fc = HDML.cross_entropy(
                embedding=embedding_y_origin, label=label)

            embedding_yq_anc = tf.slice(input_=embedding_yq,
                                        begin=[0, 0],
                                        size=[int(FLAGS.batch_size / 2), 1024])
            embedding_yq_negtile = tf.slice(
                input_=embedding_yq,
                begin=[int(FLAGS.batch_size / 2), 0],
                size=[int(np.square(FLAGS.batch_size / 2)), 1024])
            J_recon = (1 - FLAGS._lambda) * tf.reduce_sum(
                tf.square(embedding_yp - embedding_y_origin))
            J_soft = HDML.genSoftmax(embedding_anc=embedding_yq_anc,
                                     embedding_neg=embedding_yq_negtile,
                                     W_fc=W_fc,
                                     b_fc=b_fc,
                                     label=label)
            J_gen = J_recon + J_soft

    if FLAGS.Apply_HDML:
        c_train_step = nn_Ops.training(loss=J_metric,
                                       lr=lr,
                                       var_scope='Classifier')
        g_train_step = nn_Ops.training(loss=J_gen,
                                       lr=FLAGS.lr_gen,
                                       var_scope='Generator')
        s_train_step = nn_Ops.training(loss=cross_entropy,
                                       lr=FLAGS.s_lr,
                                       var_scope='Softmax_classifier')
    else:
        train_step = nn_Ops.training(loss=J_m, lr=lr)

    # initialise the session
    with tf.Session(config=config) as sess:
        summary_writer = tf.summary.FileWriter(LOGDIR, sess.graph)
        # Initialize all the variables within the session
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # learning rate
        _lr = FLAGS.init_learning_rate

        # Restore a checkpoint
        if FLAGS.load_formalVal:
            saver.restore(
                sess, FLAGS.log_save_path + FLAGS.dataSet + '/' +
                FLAGS.LossType + '/' + FLAGS.formerTimer)

        # Training
        epoch_iterator = stream_train.get_epoch_iterator()

        # collectors
        J_m_loss = nn_Ops.data_collector(tag='Jm', init=1e+6)
        J_syn_loss = nn_Ops.data_collector(tag='J_syn', init=1e+6)
        J_metric_loss = nn_Ops.data_collector(tag='J_metric', init=1e+6)
        J_soft_loss = nn_Ops.data_collector(tag='J_soft', init=1e+6)
        J_recon_loss = nn_Ops.data_collector(tag='J_recon', init=1e+6)
        J_gen_loss = nn_Ops.data_collector(tag='J_gen', init=1e+6)
        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy',
                                                   init=1e+6)
        wd_Loss = nn_Ops.data_collector(tag='weight_decay', init=1e+6)
        max_nmi = 0
        step = 0

        bp_epoch = FLAGS.init_batch_per_epoch
        with tqdm(total=FLAGS.max_steps) as pbar:
            for batch in copy.copy(epoch_iterator):
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                pbar.update(1)
                if not FLAGS.Apply_HDML:
                    train, J_m_var, wd_Loss_var = sess.run(
                        [train_step, J_m, wdLoss],
                        feed_dict={
                            x_raw: x_batch_data,
                            label_raw: Label_raw,
                            is_Training: True,
                            lr: _lr
                        })
                    J_m_loss.update(var=J_m_var)
                    wd_Loss.update(var=wd_Loss_var)

                else:
                    c_train, g_train, s_train, wd_Loss_var, J_metric_var, J_m_var, \
                        J_syn_var, J_recon_var,  J_soft_var, J_gen_var, cross_en_var = sess.run(
                            [c_train_step, g_train_step, s_train_step, wdLoss,
                             J_metric, J_m, J_syn, J_recon, J_soft, J_gen, cross_entropy],
                            feed_dict={x_raw: x_batch_data,
                                       label_raw: Label_raw,
                                       is_Training: True, lr: _lr, Javg: J_m_loss.read(), Jgen: J_gen_loss.read()})
                    wd_Loss.update(var=wd_Loss_var)
                    J_metric_loss.update(var=J_metric_var)
                    J_m_loss.update(var=J_m_var)
                    J_syn_loss.update(var=J_syn_var)
                    J_recon_loss.update(var=J_recon_var)
                    J_soft_loss.update(var=J_soft_var)
                    J_gen_loss.update(var=J_gen_var)
                    cross_entropy_loss.update(cross_en_var)
                step += 1
                # print('learning rate %f' % _lr)

                # evaluation
                if step % bp_epoch == 0:
                    print('only eval eval')
                    # nmi_tr, f1_tr, recalls_tr = evaluation.Evaluation(
                    #     stream_train_eval, image_mean, sess, x_raw, label_raw, is_Training, embedding_z, 98, neighbours)
                    nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                        stream_test, image_mean, sess, x_raw, label_raw,
                        is_Training, embedding_z, FLAGS.num_class_test,
                        neighbours)

                    # Summary
                    eval_summary = tf.Summary()
                    # eval_summary.value.add(tag='train nmi', simple_value=nmi_tr)
                    # eval_summary.value.add(tag='train f1', simple_value=f1_tr)
                    # for i in range(0, np.shape(neighbours)[0]):
                    #     eval_summary.value.add(tag='Recall@%d train' % neighbours[i], simple_value=recalls_tr[i])
                    eval_summary.value.add(tag='test nmi', simple_value=nmi_te)
                    eval_summary.value.add(tag='test f1', simple_value=f1_te)
                    for i in range(0, np.shape(neighbours)[0]):
                        eval_summary.value.add(tag='Recall@%d test' %
                                               neighbours[i],
                                               simple_value=recalls_te[i])
                    J_m_loss.write_to_tfboard(eval_summary)
                    wd_Loss.write_to_tfboard(eval_summary)
                    eval_summary.value.add(tag='learning_rate',
                                           simple_value=_lr)
                    if FLAGS.Apply_HDML:
                        J_syn_loss.write_to_tfboard(eval_summary)
                        J_metric_loss.write_to_tfboard(eval_summary)
                        J_soft_loss.write_to_tfboard(eval_summary)
                        J_recon_loss.write_to_tfboard(eval_summary)
                        J_gen_loss.write_to_tfboard(eval_summary)
                        cross_entropy_loss.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary written')
                    if nmi_te > max_nmi:
                        max_nmi = nmi_te
                        print("Saved")
                        saver.save(sess, os.path.join(LOGDIR,
                                                      "modelBest.ckpt"))
                    saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                    summary_writer.flush()
                    # if step in [5632, 6848]:
                    #     _lr = _lr * 0.5

                    if step >= 5000:
                        bp_epoch = FLAGS.batch_per_epoch
                    if step >= FLAGS.max_steps:
                        os._exit(0)
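
All of the examples log their losses through nn_Ops.data_collector, whose implementation is not shown. From the way it is called (data_collector(tag, init), update(var=...), read(), write_to_tfboard(summary)), a simple running tracker like the sketch below would satisfy the interface; the smoothing scheme here is an assumption, not the real code.

class data_collector(object):
    """Hedged sketch of the collector interface used above; internals assumed."""

    def __init__(self, tag, init):
        self.tag = tag
        self.value = float(init)

    def update(self, var):
        # keep an exponential moving average of the observed values (assumed)
        self.value = 0.9 * self.value + 0.1 * float(var)

    def read(self):
        return self.value

    def write_to_tfboard(self, summary):
        # summary is a tf.Summary proto; append this collector's current value
        summary.value.add(tag=self.tag, simple_value=self.value)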
Example #5
def main(_):
    x_raw = tf.placeholder(tf.float32, shape=[None, IMAGE_SIZE, IMAGE_SIZE, 3])
    label_raw = tf.placeholder(tf.int32, shape=[None, 1])
    with tf.name_scope('istraining'):
        is_Training = tf.placeholder(tf.bool)
    with tf.name_scope('isphase'):
        is_Phase = tf.placeholder(tf.bool)
    with tf.name_scope('learning_rate'):
        lr = tf.placeholder(tf.float32)
    with tf.name_scope('lambdas'):
        lambda1 = tf.placeholder(tf.float32)
        lambda2 = tf.placeholder(tf.float32)
        lambda3 = tf.placeholder(tf.float32)
        lambda4 = tf.placeholder(tf.float32)

    with tf.variable_scope('Feature_extractor'):
        google_net_model = GoogleNet_Model.GoogleNet_Model()
        embedding = google_net_model.forward(x_raw)
        embedding_y_origin = embedding
        # output
        embedding = nn_Ops.bn_block(embedding,
                                    normal=True,
                                    is_Training=is_Training,
                                    name='FC3')
        embedding_z = nn_Ops.fc_block(embedding,
                                      in_d=1024,
                                      out_d=EMBEDDING_SIZE,
                                      name='fc1',
                                      is_bn=False,
                                      is_relu=False,
                                      is_Training=is_Training)
        # predict mu
        embedding1 = nn_Ops.bn_block(embedding,
                                     normal=True,
                                     is_Training=is_Training,
                                     name='FC1')
        embedding_mu = nn_Ops.fc_block(embedding1,
                                       in_d=1024,
                                       out_d=EMBEDDING_SIZE,
                                       name='fc2',
                                       is_bn=False,
                                       is_relu=False,
                                       is_Training=is_Training)
        # predict (log (sigma^2))
        embedding2 = nn_Ops.bn_block(embedding,
                                     normal=True,
                                     is_Training=is_Training,
                                     name='FC2')
        embedding_sigma = nn_Ops.fc_block(embedding2,
                                          in_d=1024,
                                          out_d=EMBEDDING_SIZE,
                                          name='fc3',
                                          is_bn=False,
                                          is_relu=False,
                                          is_Training=is_Training)

        with tf.name_scope('Loss'):

            def exclude_batch_norm(name):
                return ('batch_normalization' not in name
                        and 'Generator' not in name
                        and 'Loss' not in name)

            wdLoss = 5e-3 * tf.add_n([
                tf.nn.l2_loss(v)
                for v in tf.trainable_variables() if exclude_batch_norm(v.name)
            ])
            label = tf.reduce_mean(label_raw, axis=1)
            J_m = Losses.triplet_semihard_loss(label, embedding_z) + wdLoss

    zv_emb = mysampling(embedding_mu, embedding_sigma)
    # zv = tfd.Independent(tfd.Normal(loc=embedding_mu, scale=embedding_sigma),
    # reinterpreted_batch_ndims=1)
    # zv_emb=zv.sample([1])
    zv_emb1 = tf.reshape(zv_emb, (-1, EMBEDDING_SIZE))
    # zv_prob=zv.prob(zv_emb)
    # prior_prob=zv.prob(zv_emb)

    embedding_z_add = tf.add(embedding_z, zv_emb1, name='Synthesized_features')
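    # NOTE: `mysampling` is defined elsewhere in this file; it is assumed to
    # implement the standard VAE reparameterization trick, i.e. roughly the
    # following sketch (names here are illustrative only):
    #
    #   def mysampling(z_mean, z_log_var):
    #       epsilon = tf.random_normal(shape=tf.shape(z_mean), mean=0.0, stddev=1.0)
    #       return z_mean + tf.exp(0.5 * z_log_var) * epsilon
    #
    # so embedding_z_add is the deterministic embedding plus a Gaussian sample
    # parameterized by (embedding_mu, embedding_sigma).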
    with tf.variable_scope('Decoder'):
        embedding_y_add = nn_Ops.fc_block(embedding_z_add,
                                          in_d=EMBEDDING_SIZE,
                                          out_d=512,
                                          name='decoder1',
                                          is_bn=True,
                                          is_relu=True,
                                          is_Training=is_Phase)
        embedding_y_add = nn_Ops.fc_block(embedding_y_add,
                                          in_d=512,
                                          out_d=1024,
                                          name='decoder2',
                                          is_bn=False,
                                          is_relu=False,
                                          is_Training=is_Phase)
    print("embedding_sigma", embedding_sigma)
    print("embedding_mu", embedding_mu)
    with tf.name_scope('Loss_KL'):
        kl_loss = 1 + embedding_sigma - K.square(embedding_mu) - K.exp(
            embedding_sigma)
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -0.5
        # L1 loss
        J_KL = lambda1 * K.mean(kl_loss)
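        # Closed-form KL divergence between N(mu, sigma^2) and N(0, I):
        # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
        # with embedding_sigma holding log(sigma^2) (see the fc3 head above).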

    with tf.name_scope('Loss_Recon'):

        # L2 Loss
        J_recon = lambda2 * (0.5) * tf.reduce_sum(
            tf.square(embedding_y_add - embedding_y_origin))
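        # embedding_y_origin holds the raw GoogLeNet features saved above, so
        # the Decoder is trained to reconstruct them from the synthesized embedding.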

    with tf.name_scope('Loss_synthetic'):

        # L3 Loss
        J_syn = lambda3 * Losses.triplet_semihard_loss(
            labels=label, embeddings=embedding_z_add)

    with tf.name_scope('Loss_metric'):

        # L4 Loss
        J_metric = lambda4 * J_m

    with tf.name_scope('Loss_Softmax_classifier'):
        cross_entropy, W_fc, b_fc = Losses.cross_entropy(
            embedding=embedding_y_origin, label=label)

    c_train_step = nn_Ops.training(loss=J_metric + J_syn + J_KL,
                                   lr=lr,
                                   var_scope='Feature_extractor')
    g_train_step = nn_Ops.training(loss=J_recon,
                                   lr=LR_gen,
                                   var_scope='Decoder')
    s_train_step = nn_Ops.training(loss=cross_entropy,
                                   lr=LR_s,
                                   var_scope='Softmax_classifier')
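    # Three optimizers, one per variable scope: the feature extractor is trained
    # on J_metric + J_syn + J_KL with the fed learning rate, the Decoder on the
    # reconstruction loss, and the softmax classifier on the cross-entropy loss
    # (LR_gen and LR_s are assumed to be module-level constants defined elsewhere
    # in this file).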

    def model_summary():
        model_vars = tf.trainable_variables()
        slim.model_analyzer.analyze_vars(model_vars, print_info=True)

    model_summary()
    with tf.Session(config=config) as sess:
        summary_writer = tf.summary.FileWriter(LOGDIR, sess.graph)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        _lr = LR_init

        epoch_iterator = stream_train.get_epoch_iterator()
        J_KL_loss = nn_Ops.data_collector(tag='JKL', init=1e+6)
        J_m_loss = nn_Ops.data_collector(tag='Jm', init=1e+6)
        J_syn_loss = nn_Ops.data_collector(tag='J_syn', init=1e+6)
        J_metric_loss = nn_Ops.data_collector(tag='J_metric', init=1e+6)
        J_recon_loss = nn_Ops.data_collector(tag='J_recon', init=1e+6)
        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy',
                                                   init=1e+6)
        wd_Loss = nn_Ops.data_collector(tag='weight_decay', init=1e+6)
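        # data_collector instances gather per-step loss values so they can be
        # written to TensorBoard at each evaluation point (write_to_tfboard below).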
        max_nmi = 0
        step = 0

        with tqdm(total=MAX_ITER) as pbar:
            for batch in copy.copy(epoch_iterator):
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                pbar.update(1)
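                # Phase 1: is_Phase is fed as False and the loss weights are
                # lambda1=1, lambda2=0.5, lambda3=0.5, lambda4=1.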
                c_train, g_train, s_train, wd_Loss_var, J_KL_var, J_metric_var, J_m_var, \
                    J_syn_var, J_recon_var, cross_en_var = sess.run(
                        [c_train_step, g_train_step, s_train_step, wdLoss, J_KL,
                         J_metric, J_m, J_syn, J_recon, cross_entropy],
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True, is_Phase: False,
                                   lambda1: 1, lambda2: 0.5, lambda3: 0.5,
                                   lambda4: 1, lr: _lr})

                wd_Loss.update(var=wd_Loss_var)
                J_KL_loss.update(var=J_KL_var)
                J_metric_loss.update(var=J_metric_var)
                J_m_loss.update(var=J_m_var)
                J_syn_loss.update(var=J_syn_var)
                J_recon_loss.update(var=J_recon_var)
                cross_entropy_loss.update(cross_en_var)
                step += 1

                # evaluation
                if step % bp_epoch == 0:
                    print('Evaluating on the test set')
                    nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                        stream_test, image_mean, sess, x_raw, label_raw,
                        is_Training, is_Phase, embedding_z, 98, neighbours)

                    # Summary
                    eval_summary = tf.Summary()
                    eval_summary.value.add(tag='test nmi', simple_value=nmi_te)
                    eval_summary.value.add(tag='test f1', simple_value=f1_te)
                    for i in range(0, np.shape(neighbours)[0]):
                        eval_summary.value.add(tag='Recall@%d test' %
                                               neighbours[i],
                                               simple_value=recalls_te[i])
                    J_m_loss.write_to_tfboard(eval_summary)
                    wd_Loss.write_to_tfboard(eval_summary)
                    J_KL_loss.write_to_tfboard(eval_summary)
                    eval_summary.value.add(tag='learning_rate',
                                           simple_value=_lr)
                    J_syn_loss.write_to_tfboard(eval_summary)
                    J_metric_loss.write_to_tfboard(eval_summary)
                    J_recon_loss.write_to_tfboard(eval_summary)
                    cross_entropy_loss.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary written')
                    if nmi_te > max_nmi:
                        max_nmi = nmi_te
                        print("Saved")
                        saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                    summary_writer.flush()
                    if step >= MAX_ITER:
                        break

        with tqdm(total=MAX_ITER) as pbar:
            for batch in copy.copy(epoch_iterator):
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                pbar.update(1)
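                # Phase 2: the same train ops, but with is_Phase=True and loss
                # weights lambda1=0.8, lambda2=1, lambda3=0.4, lambda4=0.8.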
                c_train, g_train, s_train, wd_Loss_var, J_KL_var, J_metric_var, J_m_var, \
                    J_syn_var, J_recon_var, cross_en_var = sess.run(
                        [c_train_step, g_train_step, s_train_step, wdLoss, J_KL,
                         J_metric, J_m, J_syn, J_recon, cross_entropy],
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True, is_Phase: True,
                                   lambda1: 0.8, lambda2: 1, lambda3: 0.4,
                                   lambda4: 0.8, lr: _lr})

                wd_Loss.update(var=wd_Loss_var)
                J_KL_loss.update(var=J_KL_var)
                J_metric_loss.update(var=J_metric_var)
                J_m_loss.update(var=J_m_var)
                J_syn_loss.update(var=J_syn_var)
                J_recon_loss.update(var=J_recon_var)
                cross_entropy_loss.update(cross_en_var)
                step += 1

                # evaluation
                if step % bp_epoch == 0:
                    print('Evaluating on the test set')
                    nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                        stream_test, image_mean, sess, x_raw, label_raw,
                        is_Training, is_Phase, embedding_z, 98, neighbours)

                    # Summary
                    eval_summary = tf.Summary()
                    eval_summary.value.add(tag='test nmi', simple_value=nmi_te)
                    eval_summary.value.add(tag='test f1', simple_value=f1_te)
                    for i in range(0, np.shape(neighbours)[0]):
                        eval_summary.value.add(tag='Recall@%d test' %
                                               neighbours[i],
                                               simple_value=recalls_te[i])
                    J_m_loss.write_to_tfboard(eval_summary)
                    wd_Loss.write_to_tfboard(eval_summary)
                    J_KL_loss.write_to_tfboard(eval_summary)
                    eval_summary.value.add(tag='learning_rate',
                                           simple_value=_lr)
                    J_syn_loss.write_to_tfboard(eval_summary)
                    J_metric_loss.write_to_tfboard(eval_summary)
                    J_recon_loss.write_to_tfboard(eval_summary)
                    cross_entropy_loss.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary written')
                    if nmi_te > max_nmi:
                        max_nmi = nmi_te
                        print("Saved")
                        saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                    summary_writer.flush()
                    if step >= 2 * MAX_ITER:
                        os._exit(os.EX_OK)
Example #6
def configTfSession(streams, summary_writer, train_steps, losses):
    stream_train, stream_train_eval, stream_test = streams
    wdLoss, J_m, Javg, Jgen, J_metric, J_gen, J_syn, cross_entropy, J_recon, J_soft = losses

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()

        # learning rate
        _lr = FLAGS.init_learning_rate

        # Restore a checkpoint
        if FLAGS.load_formalVal:
            saver.restore(
                sess, FLAGS.log_save_path + FLAGS.dataSet + '/' +
                FLAGS.LossType + '/' + FLAGS.formerTimer)

        # Training
        epoch_iterator = stream_train.get_epoch_iterator()

        # collectors
        J_m_loss = nn_Ops.data_collector(tag='Jm', init=1e+6)
        J_syn_loss = nn_Ops.data_collector(tag='J_syn', init=1e+6)
        J_metric_loss = nn_Ops.data_collector(tag='J_metric', init=1e+6)
        J_soft_loss = nn_Ops.data_collector(tag='J_soft', init=1e+6)
        J_recon_loss = nn_Ops.data_collector(tag='J_recon', init=1e+6)
        J_gen_loss = nn_Ops.data_collector(tag='J_gen', init=1e+6)
        cross_entropy_loss = nn_Ops.data_collector(tag='cross_entropy',
                                                   init=1e+6)
        wd_Loss = nn_Ops.data_collector(tag='weight_decay', init=1e+6)
        max_nmi = 0
        step = 0

        bp_epoch = FLAGS.init_batch_per_epoch
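        # bp_epoch controls how often evaluation runs; after step 5000 it is
        # switched from FLAGS.init_batch_per_epoch to FLAGS.batch_per_epoch below.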
        with tqdm(total=FLAGS.max_steps) as pbar:
            for batch in copy.copy(epoch_iterator):
                # get images and labels from batch
                x_batch_data, Label_raw = nn_Ops.batch_data(batch)
                pbar.update(1)
                if not FLAGS.Apply_HDML:
                    train_step = train_steps
                    train, J_m_var, wd_Loss_var = sess.run(
                        [train_step, J_m, wdLoss],
                        feed_dict={
                            x_raw: x_batch_data,
                            label_raw: Label_raw,
                            is_Training: True,
                            lr: _lr
                        })
                    J_m_loss.update(var=J_m_var)
                    wd_Loss.update(var=wd_Loss_var)

                else:
                    c_train_step, g_train_step, s_train_step = train_steps
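                    # HDML branch: run the three train ops jointly, feeding the
                    # Javg / Jgen placeholders with the values currently held by
                    # the J_m and J_gen data collectors.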
                    c_train, g_train, s_train, wd_Loss_var, J_metric_var, J_m_var, \
                        J_syn_var, J_recon_var, J_soft_var, J_gen_var, cross_en_var = sess.run(
                        [c_train_step, g_train_step, s_train_step, wdLoss,
                         J_metric, J_m, J_syn, J_recon, J_soft, J_gen, cross_entropy],
                        feed_dict={x_raw: x_batch_data,
                                   label_raw: Label_raw,
                                   is_Training: True, lr: _lr,
                                   Javg: J_m_loss.read(),
                                   Jgen: J_gen_loss.read()})
                    wd_Loss.update(var=wd_Loss_var)
                    J_metric_loss.update(var=J_metric_var)
                    J_m_loss.update(var=J_m_var)
                    J_syn_loss.update(var=J_syn_var)
                    J_recon_loss.update(var=J_recon_var)
                    J_soft_loss.update(var=J_soft_var)
                    J_gen_loss.update(var=J_gen_var)
                    cross_entropy_loss.update(cross_en_var)
                step += 1
                # print('learning rate %f' % _lr)

                # evaluation
                if step % bp_epoch == 0:
                    print('Evaluating on the test set')
                    # nmi_tr, f1_tr, recalls_tr = evaluation.Evaluation(
                    #     stream_train_eval, image_mean, sess, x_raw, label_raw, is_Training, embedding_z, 98, neighbours)
                    nmi_te, f1_te, recalls_te = evaluation.Evaluation(
                        stream_test, image_mean, sess, x_raw, label_raw,
                        is_Training, embedding_z, 98, neighbours)

                    # Summary
                    eval_summary = tf.Summary()
                    # eval_summary.value.add(tag='train nmi', simple_value=nmi_tr)
                    # eval_summary.value.add(tag='train f1', simple_value=f1_tr)
                    # for i in range(0, np.shape(neighbours)[0]):
                    #     eval_summary.value.add(tag='Recall@%d train' % neighbours[i], simple_value=recalls_tr[i])
                    eval_summary.value.add(tag='test nmi', simple_value=nmi_te)
                    eval_summary.value.add(tag='test f1', simple_value=f1_te)
                    for i in range(0, np.shape(neighbours)[0]):
                        eval_summary.value.add(tag='Recall@%d test' %
                                               neighbours[i],
                                               simple_value=recalls_te[i])
                    J_m_loss.write_to_tfboard(eval_summary)
                    wd_Loss.write_to_tfboard(eval_summary)
                    eval_summary.value.add(tag='learning_rate',
                                           simple_value=_lr)
                    if FLAGS.Apply_HDML:
                        J_syn_loss.write_to_tfboard(eval_summary)
                        J_metric_loss.write_to_tfboard(eval_summary)
                        J_soft_loss.write_to_tfboard(eval_summary)
                        J_recon_loss.write_to_tfboard(eval_summary)
                        J_gen_loss.write_to_tfboard(eval_summary)
                        cross_entropy_loss.write_to_tfboard(eval_summary)
                    summary_writer.add_summary(eval_summary, step)
                    print('Summary written')
                    if nmi_te > max_nmi:
                        max_nmi = nmi_te
                        print("Saved")
                        saver.save(sess, os.path.join(LOGDIR, "model.ckpt"))
                    summary_writer.flush()
                    if step in [5632, 6848]:
                        _lr = _lr * 0.5

                    if step >= 5000:
                        bp_epoch = FLAGS.batch_per_epoch
                    if step >= FLAGS.max_steps:
                        os._exit(0)