def test_batch_hard_triplet_loss():
    """Test the triplet loss with batch hard triplet mining.

    A NumPy reference is computed per anchor:

    * the hardest positive is the largest distance to any sample that shares
      the anchor's label (the anchor itself included);
    * the hardest negative is the smallest distance to any sample with a
      different label.

    The mean hinge ``max(0, d_pos - d_neg + margin)`` over all anchors is
    compared with the PyTorch ``batch_hard_triplet_loss``. Both squared and
    non-squared distances are checked.
    """
    num_data = 50
    feat_dim = 6
    margin = 0.2
    num_classes = 5

    # Fixed seed so that a failure is reproducible instead of intermittent.
    rng = np.random.default_rng(0)
    embeddings = rng.random((num_data, feat_dim)).astype(np.float32)
    labels = rng.integers(0, num_classes, size=num_data).astype(np.float32)
    # The reference below needs at least one negative for every anchor.
    assert np.unique(labels).size > 1

    for squared in [True, False]:
        pdist_matrix = pairwise_distance_np(embeddings, squared=squared)

        loss_np = 0.0
        for i in range(num_data):
            same_label = labels == labels[i]
            # Hardest positive / hardest negative for anchor i.
            max_pos_dist = np.max(pdist_matrix[i][same_label])
            min_neg_dist = np.min(pdist_matrix[i][~same_label])
            loss_np += np.maximum(0.0, max_pos_dist - min_neg_dist + margin)

        loss_np /= num_data

        # Compute the loss with the PyTorch implementation under test.
        loss_torch_val = batch_hard_triplet_loss(torch.from_numpy(labels),
                                                 torch.from_numpy(embeddings),
                                                 margin,
                                                 squared=squared)
        assert np.allclose(loss_np, loss_torch_val)
# Example #2
# 0
def loss_function(labels, embeddings, alpha, margin=None):
    """Triplet loss with global orthogonal regularization (GOR).

    Loss function as described in
    http://cs230.stanford.edu/projects_fall_2019/reports/26251543.pdf:
    a batch-hard triplet loss plus ``alpha`` times a global orthogonal
    regularization term computed over all anchor/negative pairs.

    Args:
        labels: batch labels, passed to ``_get_anchor_negative_triplet_mask``
            and ``batch_hard_triplet_loss``.
        embeddings: 2-D tensor with one embedding per row. The embedding
            dimension is read from ``tf.shape(embeddings)[1]``.
        alpha: weight of the GOR term.
        margin: triplet-loss margin. Defaults to the module-level ``MARGIN``,
            which is the value that was always used before.

    Returns:
        Scalar tensor ``l_triplet + alpha * l_gor``.
    """
    if margin is None:
        margin = MARGIN

    # Do all regularization arithmetic in the embeddings' own dtype. The
    # previous hard-coded float16 only worked for float16 embeddings and
    # raised a dtype mismatch otherwise.
    dtype = embeddings.dtype

    # Pairwise dot products <e_i, e_j> and their squares.
    dot_product = tf.matmul(embeddings, embeddings, transpose_b=True)
    dot_product_squared = tf.math.square(dot_product)

    # Mask of anchor/negative pairs and the number of such pairs.
    neg_mask = tf.cast(_get_anchor_negative_triplet_mask(labels), dtype)
    num_pairs = tf.reduce_sum(neg_mask)

    # First and second moments of the negative-pair dot products.
    # divide_no_nan returns 0 rather than NaN when there is no negative pair.
    m1 = tf.math.divide_no_nan(tf.reduce_sum(dot_product * neg_mask),
                               num_pairs)
    m2 = tf.math.divide_no_nan(
        tf.reduce_sum(dot_product_squared * neg_mask), num_pairs)

    # 1 / d, where d is the embedding dimension.
    dim_term = 1.0 / tf.cast(tf.shape(embeddings)[1], dtype)

    # GOR term: M1^2 + max(0, M2 - 1/d).
    l_gor = tf.math.square(m1) + tf.maximum(tf.constant(0, dtype=dtype),
                                            m2 - dim_term)

    l_triplet = batch_hard_triplet_loss(labels, embeddings, margin)

    return l_triplet + alpha * l_gor
def train_model_with_tripletloss(model,
                                 train_data,
                                 train_labels,
                                 test_data,
                                 test_labels,
                                 num_classes,
                                 len_encoding,
                                 num_epochs=20,
                                 batch_size=128,
                                 learning_rate=0.001,
                                 margin=0.5,
                                 triplet_loss_strategy="batch_hard"):
    """Train ``model`` with a triplet loss and print per-epoch average losses.

    Each epoch has two passes:

    1. One pass over the training set in mini-batches. Each batch goes
       through ``train_one_step_tripletloss`` (one Adam step).
    2. One pass over the test set that computes the same triplet loss with
       no gradient step.

    Args:
        model: callable mapping a batch of inputs to embeddings.
        train_data, train_labels: training inputs and labels.
        test_data, test_labels: evaluation inputs and labels.
        num_classes, len_encoding: not used here; kept for interface
            compatibility.
        num_epochs: number of passes over the training data.
        batch_size: mini-batch size for both training and evaluation.
        learning_rate: Adam learning rate.
        margin: triplet-loss margin.
        triplet_loss_strategy: "batch_hard" or "batch_all".

    Raises:
        ValueError: if ``triplet_loss_strategy`` is not supported.
    """
    # Resolve the evaluation loss once. An unknown strategy previously ended
    # in a NameError/UnboundLocalError deep inside the loop.
    if triplet_loss_strategy == "batch_hard":
        eval_loss_fn = triplet_loss.batch_hard_triplet_loss
    elif triplet_loss_strategy == "batch_all":
        eval_loss_fn = triplet_loss.batch_all_triplet_loss
    else:
        raise ValueError(
            "Unknown triplet_loss_strategy: {}".format(triplet_loss_strategy))

    # Generate tf.data.Dataset objects. Both are identical in every epoch,
    # so they are built once.
    train_dataset = tf.data.Dataset.from_tensor_slices(
        (train_data, train_labels)).batch(batch_size)
    test_dataset = tf.data.Dataset.from_tensor_slices(
        (test_data, test_labels)).batch(batch_size)

    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Accumulator for the summed per-batch losses of one pass.
    overall_triplet_loss = tf.Variable(0, dtype=tf.float32)

    for epoch in range(num_epochs):
        # Training pass.
        overall_triplet_loss.assign(0.0)
        num_batches = 0
        triplet_loss_value = None
        for x_batch_train, y_batch_train in train_dataset:
            triplet_loss_value = train_one_step_tripletloss(
                model, x_batch_train, y_batch_train, triplet_loss_strategy,
                margin, optimizer)
            overall_triplet_loss.assign_add(triplet_loss_value)
            num_batches += 1
        # The average is over batches. Dividing by batch_size, as before,
        # did not give an average.
        print('Epoch: {}: Average Train Loss: {} Last Batch Loss {}'.format(
            epoch,
            overall_triplet_loss.numpy() / max(num_batches, 1),
            triplet_loss_value))

        # Model evaluation on the test set for this epoch (no update).
        overall_triplet_loss.assign(0.0)
        num_batches = 0
        for x_batch_test, y_batch_test in test_dataset:
            embeddings = model(x_batch_test)
            triplet_loss_value = eval_loss_fn(y_batch_test, embeddings, margin)
            overall_triplet_loss.assign_add(triplet_loss_value)
            num_batches += 1
        print('Epoch: {}:  Average Test Loss: {} Last Batch Loss {}'.format(
            epoch,
            overall_triplet_loss.numpy() / max(num_batches, 1),
            triplet_loss_value))
    def batch_hard_triplet_loss(self,
                                embeddings,
                                labels,
                                margin=0.5,
                                squared=False):
        """Batch-hard triplet loss that takes ``(embeddings, labels)`` in that order.

        This is an adapter over ``triplet_loss.batch_hard_triplet_loss``, which
        expects ``labels`` first. ``self`` is not used.
        """
        loss_options = dict(margin=margin, squared=squared)
        return triplet_loss.batch_hard_triplet_loss(labels, embeddings,
                                                    **loss_options)
# Example #5
# 0
 def hard_triplet_loss(self, y_true, y_pred, margin=0.6):
     """Loss with a ``(y_true, y_pred)`` signature: batch-hard triplet loss.

     Args:
         y_true: label tensor. Only its first column (``y_true[:, 0]``) is
             used as the labels, so it must be at least 2-D.
         y_pred: embeddings for the batch.
         margin: triplet margin. Defaults to 0.6, the previously hard-coded
             value, so two-argument callers are unaffected.

     Returns:
         The result of ``triplet_loss.batch_hard_triplet_loss`` with
         non-squared distances.
     """
     # Leftover debug prints of y_true were removed; they printed on every
     # call.
     labels = y_true[:, 0]
     return triplet_loss.batch_hard_triplet_loss(labels,
                                                 y_pred,
                                                 margin=margin,
                                                 squared=False)
def train_one_step_tripletloss(model, x_batch_train, y_batch_train,
                               triplet_loss_strategy, margin, optimizer):
    """Run one optimizer step on a batch and return that batch's triplet loss.

    Args:
        model: callable producing embeddings, exposing
            ``trainable_variables``.
        x_batch_train: input batch.
        y_batch_train: labels of the batch.
        triplet_loss_strategy: "batch_hard" or "batch_all".
        margin: triplet-loss margin.
        optimizer: optimizer whose ``apply_gradients`` performs the update.

    Returns:
        The loss value computed inside the gradient tape.

    Raises:
        ValueError: if ``triplet_loss_strategy`` is not supported. Previously
            this fell through to an UnboundLocalError.
    """
    if triplet_loss_strategy == "batch_hard":
        loss_fn = triplet_loss.batch_hard_triplet_loss
    elif triplet_loss_strategy == "batch_all":
        loss_fn = triplet_loss.batch_all_triplet_loss
    else:
        raise ValueError(
            "Unknown triplet_loss_strategy: {}".format(triplet_loss_strategy))

    # Record the forward pass so the gradients of the loss can be taken.
    with tf.GradientTape() as tape:
        embeddings = model(x_batch_train)
        triplet_loss_value = loss_fn(y_batch_train, embeddings, margin)

    grads = tape.gradient(triplet_loss_value, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    return triplet_loss_value
# Example #7
# 0
def cnn_model(features, labels, mode):
    """Estimator ``model_fn`` for an AlexNet or ResNet-v2-50 classifier.

    The network's logits also act as embeddings:

    * When FLAGS.mode is "triplet_training", the loss is a batch-all or
      batch-hard triplet loss over ``features['axillary_labels']``.
      Softmax cross-entropy is added when FLAGS.loss_mode == "mix".
    * Otherwise the loss is plain softmax cross-entropy against the one-hot
      ``labels``.

    Args:
        features: dict holding an 'images' batch and 'axillary_labels'.
        labels: one-hot class labels.
        mode: a ``tf.estimator.ModeKeys`` value.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for PREDICT, TRAIN or EVAL.

    Raises:
        ValueError: if FLAGS.network is neither 'alexnet' nor 'resnet'.
            Previously this surfaced as a NameError on ``logits``.
    """
    import math

    images = features['images']
    onehot_labels = labels
    axillary_labels = features['axillary_labels']
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    if FLAGS.network == 'alexnet':
        # Format data
        if FLAGS.data_format == 'NCHW':
            print(colored("Converting data format to channels first (NCHW)",
                          'blue'))
            images = tf.transpose(images, [0, 3, 1, 2])

        # Batch-normalization settings handed to alexnet(). Outside TRAIN,
        # 'updates_collections' is set to None.
        norm_params = {'is_training': is_training,
                       'data_format': FLAGS.data_format}
        if not is_training:
            norm_params['updates_collections'] = None

        logits = alexnet(images, norm_params, mode)
    elif FLAGS.network == 'resnet':
        logits, _ = resnet_v2.resnet_v2_50(inputs=images,
                                           num_classes=ts._NUM_CLASSES,
                                           is_training=is_training)
    else:
        raise ValueError('Unsupported network: {}'.format(FLAGS.network))

    # Inference
    predicted_classes = tf.argmax(logits, 1)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
                mode,
                predictions={
                    'pred_class': predicted_classes,
                    'gt_class': axillary_labels,
                    'embedding': logits,
                })

    def _softmax_loss():
        # Softmax cross-entropy of the logits against the one-hot labels.
        return tf.losses.softmax_cross_entropy(
                onehot_labels=onehot_labels, logits=logits)

    # Loss
    groundtruth_classes = tf.argmax(onehot_labels, 1)
    if FLAGS.mode == "triplet_training":
        if FLAGS.triplet_mining_method == "batchall":
            loss, _, _ = triplet_loss.batch_all_triplet_loss(
                    axillary_labels, logits, FLAGS.triplet_margin)
        elif FLAGS.triplet_mining_method == "batchhard":
            loss = triplet_loss.batch_hard_triplet_loss(
                    axillary_labels, logits, FLAGS.triplet_margin)
        else:
            # This message was a bare string expression before, so it was
            # never shown.
            print("ERROR: Wrong Triplet loss mining method, using softmax")
            loss = _softmax_loss()
        if FLAGS.loss_mode == "mix":
            loss += _softmax_loss()
    else:
        loss = _softmax_loss()

    # Training
    if mode == tf.estimator.ModeKeys.TRAIN:
        global_step = tf.train.get_global_step()
        if FLAGS.optimizer in ('GD', 'Momentum'):
            # Exponential decay by 0.96 per epoch (decay_steps = steps per
            # epoch). The cast now happens before the division; the old
            # float(a / b) truncated under Python 2 integer division.
            steps_per_epoch = int(math.ceil(
                    float(ts._SPLITS_TO_SIZES['train']) / FLAGS.batch_size))
            learning_rate = tf.train.exponential_decay(FLAGS.learning_rate,
                                                       global_step,
                                                       steps_per_epoch,
                                                       0.96)
            if FLAGS.optimizer == 'GD':
                optimizer = tf.train.GradientDescentOptimizer(
                        learning_rate=learning_rate)
            else:
                optimizer = tf.train.MomentumOptimizer(
                        learning_rate=learning_rate, momentum=0.9)
        else:
            optimizer = tf.train.AdamOptimizer(
                    learning_rate=FLAGS.learning_rate)

        # Make the step depend on the ops in the UPDATE_OPS collection.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.minimize(loss, global_step=global_step)
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    # Evaluation
    eval_metric_ops = {
            'eval/accuracy': tf.metrics.accuracy(
                labels=groundtruth_classes,
                predictions=predicted_classes),
            }
    return tf.estimator.EstimatorSpec(
            mode, loss=loss, eval_metric_ops=eval_metric_ops)
# Example #8
# 0
def my_model(features, labels, mode, params):
    """``model_fn`` for a ``tf.estimator.Estimator``: builds the embedding model and its training.

    Args:
        features: input images. The original notes give the shape as
            (batch_size, 784); they are reshaped here to
            (batch, image_size, image_size, 1).
        labels: labels with shape (batch_size,).
        mode: a ``tf.estimator.ModeKeys`` value (the current phase).
        params: dict of hyper-parameters. Read here: 'image_size',
            'triplet_strategy', 'margin', 'squared', 'learning_rate',
            'use_batch_norm'. ``build_model`` may read more.

    Returns:
        A ``tf.estimator.EstimatorSpec`` for PREDICT, EVAL or TRAIN.

    Raises:
        ValueError: if params['triplet_strategy'] is neither 'batch_all'
            nor 'batch_hard'.
    """
    training = mode == tf.estimator.ModeKeys.TRAIN
    side = params['image_size']
    images = tf.reshape(features, shape=[-1, side, side, 1])
    with tf.variable_scope("model"):
        # Build the embedding network.
        embeddings = build_model(training, images, params)

    # Prediction phase: return the embeddings directly.
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions={'embeddings': embeddings})

    # Pick the triplet-loss variant.
    labels = tf.cast(labels, tf.int64)
    strategy = params['triplet_strategy']
    batch_all = strategy == 'batch_all'
    if batch_all:
        loss, fraction = batch_all_triplet_loss(labels,
                                                embeddings,
                                                margin=params['margin'],
                                                squared=params['squared'])
    elif strategy == 'batch_hard':
        loss = batch_hard_triplet_loss(labels,
                                       embeddings,
                                       margin=params['margin'],
                                       squared=params['squared'])
    else:
        raise ValueError("triplet_strategy 配置不正确: {}".format(strategy))

    # Mean L2 norm of the embeddings, logged as a summary and an eval metric.
    embedding_mean_norm = tf.reduce_mean(tf.norm(embeddings, axis=1))
    tf.summary.scalar("embedding_mean_norm", embedding_mean_norm)
    with tf.variable_scope("metrics"):
        eval_metric_ops = {
            'embedding_mean_norm': tf.metrics.mean(embedding_mean_norm)
        }
        if batch_all:
            eval_metric_ops['fraction_positive_triplets'] = tf.metrics.mean(
                fraction)

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode,
                                          loss=loss,
                                          eval_metric_ops=eval_metric_ops)

    # Training summaries.
    tf.summary.scalar('loss', loss)
    if batch_all:
        tf.summary.scalar('fraction_positive_triplets', fraction)
    # max_outputs=1: log at most one image of the batch.
    tf.summary.image('train_image', images, max_outputs=1)

    optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
    global_step = tf.train.get_global_step()
    if not params['use_batch_norm']:
        train_op = optimizer.minimize(loss, global_step=global_step)
    else:
        # With batch norm, the batch mean/variance update ops collected in
        # UPDATE_OPS must run before the optimization step.
        with tf.control_dependencies(tf.get_collection(
                tf.GraphKeys.UPDATE_OPS)):
            train_op = optimizer.minimize(loss, global_step=global_step)
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def is_first_batch_in_epoch(step):
    """Return True when global step ``step`` begins a new pass over the training set.

    Relies on the module-level ``batch_size`` and ``num_train_samples``. The
    sample offset of this step is taken modulo the training-set size, and
    the step is first in its epoch when that offset falls inside the first
    batch.
    """
    offset_in_epoch = (step * batch_size) % num_train_samples
    return offset_in_epoch < batch_size


# Initialize the model for training set and validation sets.
# ResNet is called twice in one variable scope. The first call (train_data)
# creates the variables. After scope.reuse_variables(), the second call
# (val_data) shares those same weights instead of creating new ones. The
# boolean second argument presumably toggles training behaviour; confirm
# against ResNet().
with tf.variable_scope("ResNet") as scope:
    # The training branch keeps outputs 0, 1 and 3 (pred, feat,
    # classification_weights).
    pred, feat, _, classification_weights = ResNet(train_data, True)
    scope.reuse_variables()
    # The validation branch keeps outputs 0 and 2 (valpred, saliency).
    valpred, _, saliency, _ = ResNet(val_data, False)

# Forming the triplet loss by the hard-triplet sampler: batch-hard triplet
# loss on the training features `feat`, using non-squared distances.
# `train_labels` and `margin` are not defined in this section.
with tf.name_scope('Triplet_loss'):
    sialoss = tri.batch_hard_triplet_loss(train_labels,
                                          feat,
                                          margin,
                                          squared=False)


def compute_regularization(classification_weights):
    """Sum ``tf.nn.l2_loss`` over every tensor in ``classification_weights``.

    The accumulation starts from the plain integer ``0``, so an empty
    iterable yields ``0``.
    """
    return sum(tf.nn.l2_loss(w) for w in classification_weights)


# Forming the cross-entropy loss and accuracy for classifier learning
with tf.name_scope('Loss_and_Accuracy'):
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
# Example #10
# 0
                                        batch_size=params['bs_test'],
                                        shuffle_buffer=params['sb_test'])

    handle = tf.placeholder(tf.string, shape=[])
    x, song_id, time = tf.data.Iterator.from_string_handle(
        handle, trn_itr.output_types, trn_itr.output_shapes).get_next()

    input_layer = tf.reshape(x, params['x.shape'])
    y_ = encode_labels(song_id, one_hot=False)
""" Calculations """
embeddings = model(input_layer, params)

with tf.name_scope("training") as scope:
    # loss, positive_triplets = batch_all_triplet_loss(labels=y_, embeddings=embeddings, margin=params['loss_margin'])
    loss = batch_hard_triplet_loss(labels=y_,
                                   embeddings=embeddings,
                                   margin=params['loss_margin'])
    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # needed for batch normalizations
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(params['lr']).minimize(
            loss, global_step=tf.train.create_global_step())
    tf.summary.scalar('loss', loss)
    # tf.summary.scalar('positive_triplets', positive_triplets)
""" Creation of the distance matrix for the tree reconstruction """
with tf.name_scope('summaries') as scope:
    distance_matrix = pairwise_distances(embeddings)
    tf.summary.tensor_summary("distance_matrix", distance_matrix)
""" Classification """
with tf.name_scope("classify") as scope:
    predictions = classify_logistic(embeddings, params)
# Example #11
# 0
def main(argv=None):
    """Build the re-identification graph and run it in the mode set by ``FLAGS.mode``.

    The graph is built in four stages:

    1. ResNet-v1-50 body (``nets.resnet_v1_50``) with an ``fc1024`` head.
    2. A 1x1 convolution on the ``resnet_v1_50/block4`` feature map.
    3. ``part_attend`` on that convolution's output.
    4. A loss equal to ``batch_hard_triplet_loss`` (margin 0.3) scaled by
       ``FLAGS.global_rate``, minimized with momentum SGD.

    The session then runs one mode:

    * 'train': the training loop.
    * 'val': accuracy over 10 batches.
    * 'test': runs the network on FLAGS.image1 / FLAGS.image2.

    NOTE(review): Python 2 syntax (print statements, ``xrange``).
    NOTE(review): ``inference`` and ``droup_is_training``, used in the 'val'
    and 'test' branches, are not defined in this function. Confirm they exist
    at module level; otherwise those modes raise NameError.
    """

    # 'test' and 'cmc' modes use a batch size of 1.
    if FLAGS.mode == 'test':
        FLAGS.batch_size = 1

    if FLAGS.mode == 'cmc':
        FLAGS.batch_size = 1

    # Fed with the Python-side `lr` on every training step.
    learning_rate = tf.placeholder(tf.float32, name='learning_rate')
    # A stack of three image batches, only fed in the 'val' / 'test'
    # branches.
    # NOTE(review): the network below is built on `images_total`, not on
    # this placeholder.
    images = tf.placeholder(tf.float32, [3, FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3], name='images')

    # Network input: one batch of images.
    images_total = tf.placeholder(tf.float32, [FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3], name='images_total')

    # One float label per image, consumed by the triplet loss.
    labels = tf.placeholder(tf.float32, [FLAGS.batch_size], name='labels')

    is_train = tf.placeholder(tf.bool, name='is_train')
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # L2 weight decay for the 1x1 convolution and for part_attend.
    weight_decay = 0.0005
    # (sic) "tarin": presumably the number of training identities returned
    # by get_num_id.
    tarin_num_id = 0
    val_num_id = 0

    if FLAGS.mode == 'train':
        tarin_num_id = cuhk03_dataset_label2.get_num_id(FLAGS.data_dir, 'train')
        print(tarin_num_id, '               11111111111111111111               1111111111111111')
    elif FLAGS.mode == 'val':
        val_num_id = cuhk03_dataset_label2.get_num_id(FLAGS.data_dir, 'val')

    # Create the model and an embedding head.
    model = import_module('nets.' + 'resnet_v1_50')
    head = import_module('heads.' + 'fc1024')


    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    # NOTE(review): is_training=True is hard-coded, including in 'val'/'test'.
    endpoints, body_prefix = model.endpoints(images_total, is_training=True)

    with tf.name_scope('head'):
        endpoints = head.head(endpoints, FLAGS.embedding_dim, is_training=True)


    '''
    print endpoints['model_output'] # (bt,2048)
    print endpoints['global_pool'] # (bt,2048)
    print endpoints['resnet_v1_50/block4']# (bt,7,7,2048)
    
    print ' 1\n'
    '''

    # Fed in the training feed_dict; not referenced by the graph built in
    # this function.
    train_mode = tf.placeholder(tf.bool)


    print('Build network')

    feat = endpoints['resnet_v1_50/block4']# (bt,7,7,2048)

    #feat = tf.convert_to_tensor(feat, dtype=tf.float32)

    # 1x1 convolution that keeps 2048 channels, L2-regularized.
    feat_1x1 = tf.layers.conv2d(feat, 2048, [1, 1],padding='valid',
            kernel_regularizer=tf.contrib.layers.l2_regularizer(weight_decay), reuse=None, name='conv1x1')

    feature = part_attend(feat_1x1,weight_decay)

    # This batch_hard_triplet_loss variant returns three values: the loss
    # plus two tensors (PP, NN), which the training loop prints as
    # "global P" / "global N". Margin is 0.3.
    #loss_triplet,PP,NN = triplet_hard_loss(feature,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    loss_triplet ,PP,NN = batch_hard_triplet_loss(labels,feature,0.3)

    loss = loss_triplet*FLAGS.global_rate

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES, body_prefix)

    #optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

    #optimizer = tf.train.AdadeltaOptimizer(learning_rate)
    #train = optimizer.minimize(loss, global_step=global_step)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        #train_op = optimizer.minimize(loss_mean, global_step=global_step)

        # Momentum SGD (momentum 0.9) using the fed learning rate.
        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
        train = optimizer.minimize(loss, global_step=global_step)

    # Python-side learning rate, recomputed after every training step.
    lr = FLAGS.learning_rate

    #config=tf.ConfigProto(log_device_placement=True)
    #config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    # GPU: cap this process at 90% of GPU memory.
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.9

    with tf.Session(config=config) as sess:

        print '\n'
        #print model_variables
        print '\n'
        #sess.run(tf.global_variables_initializer())
        #saver = tf.train.Saver()

        #checkpoint_saver = tf.train.Saver(max_to_keep=0)
        checkpoint_saver = tf.train.Saver()

        # If FLAGS.logs_dir has a checkpoint, resume from it. Otherwise
        # initialize every variable and restore only the backbone
        # (`body_prefix`) variables from FLAGS.initial_checkpoint.
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('Restore model')
            print ckpt.model_checkpoint_path
            #saver.restore(sess, ckpt.model_checkpoint_path)
            checkpoint_saver.restore(sess, ckpt.model_checkpoint_path)

        #for first , training load imagenet
        else:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(model_variables)
            print FLAGS.initial_checkpoint
            saver.restore(sess, FLAGS.initial_checkpoint)

        if FLAGS.mode == 'train':
            # Continue from the (possibly restored) global step up to
            # FLAGS.max_steps inclusive.
            step = sess.run(global_step)
            for i in xrange(step, FLAGS.max_steps + 1):

                batch_images, batch_labels, batch_images_total = cuhk03_dataset_label2.read_data(FLAGS.data_dir, 'train', tarin_num_id,
                    IMAGE_WIDTH, IMAGE_HEIGHT, FLAGS.batch_size,FLAGS.ID_num,FLAGS.IMG_PER_ID)

                # `batch_images` is read but not fed in this branch.
                feed_dict = {learning_rate: lr,  is_train: True , train_mode: True, images_total: batch_images_total, labels: batch_labels}

                _,train_loss = sess.run([train,loss], feed_dict=feed_dict)

                print('Step: %d, Learning rate: %f, Train loss: %f ' % (i, lr, train_loss))

                # NOTE(review): second forward pass on the same batch, after
                # the update above; it could be merged into the previous
                # sess.run.
                gtoloss,gp,gn = sess.run([loss_triplet,PP,NN], feed_dict=feed_dict)
                print 'global hard: ',gtoloss
                print 'global P: ',gp
                print 'global N: ',gn

                #lr = FLAGS.learning_rate / ((2) ** (i/160000)) * 0.1
                # Inverse decay: lr = base_lr * (1 + 1e-4 * i) ** -0.75.
                lr = FLAGS.learning_rate * ((0.0001 * i + 1) ** -0.75)
                if i % 100 == 0:
                    #saver.save(sess, FLAGS.logs_dir + 'model.ckpt', i)
                    # test save
                    #vgg.save_npy(sess, './big.npy')

                    # Checkpoint every 100 steps.
                    checkpoint_saver.save(sess,FLAGS.logs_dir + 'model.ckpt', i)


        elif FLAGS.mode == 'val':
            # Accuracy over 10 validation batches.
            # NOTE(review): three things to confirm here:
            #   - read_data returns three values in the 'train' call above,
            #     but two are unpacked here;
            #   - the `labels` placeholder is 1-D, yet
            #     np.argmax(batch_labels, axis=1) expects 2-D labels;
            #   - `inference` is not defined in this function.
            total = 0.
            for _ in xrange(10):
                batch_images, batch_labels = cuhk03_dataset_label2.read_data(FLAGS.data_dir, 'val', val_num_id,
                    IMAGE_WIDTH, IMAGE_HEIGHT, FLAGS.batch_size)
                feed_dict = {images: batch_images, labels: batch_labels, is_train: False}
                prediction = sess.run(inference, feed_dict=feed_dict)
                prediction = np.argmax(prediction, axis=1)
                label = np.argmax(batch_labels, axis=1)

                for i in xrange(len(prediction)):
                    if prediction[i] == label[i]:
                        total += 1
            print('Accuracy: %f' % (total / (FLAGS.batch_size * 10)))

            '''
            for i in xrange(len(prediction)):
                print('Prediction: %s, Label: %s' % (prediction[i] == 0, labels[i] == 0))
                image1 = cv2.cvtColor(batch_images[0][i], cv2.COLOR_RGB2BGR)
                image2 = cv2.cvtColor(batch_images[1][i], cv2.COLOR_RGB2BGR)
                image = np.concatenate((image1, image2), axis=1)
                cv2.imshow('image', image)
                key = cv2.waitKey(0)
                if key == 1048603:  # ESC key
                    break
            '''


        elif FLAGS.mode == 'test':
            # Load FLAGS.image1 and FLAGS.image2. Each is resized, converted
            # BGR -> RGB and reshaped to (1, H, W, 3) floats. They are fed
            # into `images` stacked as [image1, image2, image2].
            image1 = cv2.imread(FLAGS.image1)
            image1 = cv2.resize(image1, (IMAGE_WIDTH, IMAGE_HEIGHT))
            image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
            image1 = np.reshape(image1, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)
            image2 = cv2.imread(FLAGS.image2)
            image2 = cv2.resize(image2, (IMAGE_WIDTH, IMAGE_HEIGHT))
            image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
            image2 = np.reshape(image2, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)
            test_images = np.array([image1, image2,image2])

            feed_dict = {images: test_images, is_train: False, droup_is_training: False}
            #prediction, prediction2 = sess.run([DD,DD2], feed_dict=feed_dict)
            prediction = sess.run([inference], feed_dict=feed_dict)
            prediction = np.array(prediction)
            print prediction.shape
            # Prints argmax of the first fetched output, plus one.
            print( np.argmax(prediction[0])+1)
    logit = net.outputs
    logit_norm = tf.norm(logit, axis=1, keep_dims=True)
    logit = tf.div(logit, logit_norm, name='norm_logit')

    # test net  because of batch normal layer
    tl.layers.set_name_reuse(True)
    test_net = get_resnet(images,
                          args.net_depth,
                          type='ir',
                          w_init=w_init_method,
                          trainable=False,
                          reuse=True,
                          keep_rate=dropout_rate)
    embedding_tensor = test_net.outputs
    # 3.3 define the cross entropy
    t_loss = batch_hard_triplet_loss(
        labels, logit, margin=args.triplet_margin) * args.triplet_weight

    wd_loss = 0
    for weights in tl.layers.get_variables_with_name('W_conv2d', True, True):
        wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(weights)
    for W in tl.layers.get_variables_with_name('resnet_v1_50/E_DenseLayer/W',
                                               True, True):
        wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(W)
    for weights in tl.layers.get_variables_with_name('embedding_weights', True,
                                                     True):
        wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(weights)
    for gamma in tl.layers.get_variables_with_name('gamma', True, True):
        wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(gamma)
    # for beta in tl.layers.get_variables_with_name('beta', True, True):
    #     wd_loss += tf.contrib.layers.l2_regularizer(args.weight_deacy)(beta)
    for alphas in tl.layers.get_variables_with_name('alphas', True, True):
# Example #13
# 0
def main(argv=None):
    """Build a two-backbone (ResNet-50 + ResNet-101) re-ID graph and run it.

    Both backbones (``nets.resnet_v1_50`` and ``nets.resnet_v1_101``) are fed
    ``images_total``. For each branch the total loss is

        local_triplet * FLAGS.local_rate + global_triplet * FLAGS.global_rate
        + mul_loss + softmax_cross_entropy + kl_term

    where ``mul_loss`` is computed from the two branches' distance matrices
    and the KL term from the two branches' logits. Each total is minimized by
    its own MomentumOptimizer (momentum 0.9).

    After the graph is built, a session restores all variables from the
    latest checkpoint in FLAGS.logs_dir if one exists. Otherwise it
    initializes everything and loads FLAGS.initial_checkpoint /
    FLAGS.initial_checkpoint2 into the ResNet-50 / ResNet-101 variables.
    FLAGS.mode then selects:

      * 'train': optimize both branches up to FLAGS.max_steps, printing the
        loss terms, checkpointing every 100 steps and writing TensorBoard
        summaries every 20 steps.
      * 'val':   count argmax matches of ``inference`` over 10 batches and
        print the accuracy.
      * 'cmc':   score 100 probe ids against 100 gallery ids and print the
        output of ``cmc.cmc``.
      * 'test':  run ``inference`` on FLAGS.image1 / FLAGS.image2 and print
        the argmax index + 1.

    NOTE(review): the val/cmc/test paths use names that this function does
    not define (``anchor_feature``, ``positive_feature``, ``negative_feature``,
    ``DD``, ``droup_is_training``). Unless they are module-level globals,
    these modes raise NameError. See the inline notes below.

    Args:
        argv: not used by the body. Presumably present for a
            ``tf.app.run``-style entry point -- confirm with the caller.
    """

    # The evaluation modes work on one example at a time.
    if FLAGS.mode == 'test':
        FLAGS.batch_size = 1
    
    if FLAGS.mode == 'cmc':
        FLAGS.batch_size = 1

    learning_rate = tf.placeholder(tf.float32, name='learning_rate')
    # Stacked image triple, shape [3, batch, H, W, 3]. Only the val/cmc/test
    # paths feed it; the backbones below consume images_total instead.
    #images = tf.placeholder(tf.float32, [2, FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3], name='images')
    images = tf.placeholder(tf.float32, [3, FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3], name='images')
    
    # Flat image batch [batch, H, W, 3], fed to both backbones when training.
    images_total = tf.placeholder(tf.float32, [FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3], name='images_total')
    
    # One identity label per image, stored as float; it is cast to int64
    # for the softmax losses below.
    labels = tf.placeholder(tf.float32, [FLAGS.batch_size], name='labels')
    #labels_neg = tf.placeholder(tf.float32, [FLAGS.batch_size, 743], name='labels')
    
    #total
    #labels = tf.placeholder(tf.float32, [FLAGS.batch_size, 847], name='labels')
    #labels_neg = tf.placeholder(tf.float32, [FLAGS.batch_size, 847], name='labels')
    
    #eye
    #labels = tf.placeholder(tf.float32, [FLAGS.batch_size, 104], name='labels')
    #labels_neg = tf.placeholder(tf.float32, [FLAGS.batch_size, 104], name='labels')

    
    
    # NOTE(review): is_train (and train_mode further down) are fed in the
    # feed_dicts, but no op built in this function reads them. Both backbones
    # are built with is_training=True regardless of mode.
    is_train = tf.placeholder(tf.bool, name='is_train')
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Passed to the global/local pooling helpers.
    weight_decay = 0.0005
    tarin_num_id = 0
    val_num_id = 0

    # Presumably the number of identities in the split in use (only the
    # split matching FLAGS.mode is queried).
    if FLAGS.mode == 'train':
        tarin_num_id = cuhk03_dataset_label2.get_num_id(FLAGS.data_dir, 'train')
        print(tarin_num_id, '               11111111111111111111               1111111111111111')
    elif FLAGS.mode == 'val':
        val_num_id = cuhk03_dataset_label2.get_num_id(FLAGS.data_dir, 'val')

    
    # Create the model and an embedding head.
    model = import_module('nets.' + 'resnet_v1_50')
    head = import_module('heads.' + 'fc1024')
    
    
    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images_total, is_training=True)

    with tf.name_scope('head'):
        endpoints = head.head(endpoints, FLAGS.embedding_dim, is_training=True)
    
    
    '''
    print endpoints['model_output'] # (bt,2048)
    print endpoints['global_pool'] # (bt,2048)
    print endpoints['resnet_v1_50/block4']# (bt,7,7,2048)
    '''

    # Second backbone (ResNet-101) on the same input batch; body_prefix2 is
    # used later to warm-start its variables from FLAGS.initial_checkpoint2.
    # Create the model and an embedding head.
    model2 = import_module('nets.' + 'resnet_v1_101')
    endpoints2, body_prefix2 = model2.endpoints(images_total, is_training=True)
       
    train_mode = tf.placeholder(tf.bool)

    print('Build network')
    
    # Last residual block of each backbone. The shape comments are the
    # author's and are not verified here.
    feat = endpoints['resnet_v1_50/block4']# (bt,7,7,2048)
    
    feat2 = endpoints2['resnet_v1_101/block4']# (bt,7,7,2048)

    #feat = tf.convert_to_tensor(feat, dtype=tf.float32)
    # global
    feature,feature2 = global_pooling(feat,feat2,weight_decay)
    # Batch-hard triplet loss on the ResNet-50 global feature (margin 0.3).
    # PP / NN are auxiliary outputs that are only averaged and printed below.
    loss_triplet ,PP,NN = batch_hard_triplet_loss(labels,feature,0.3)
    
    
    # Only the distance matrices from triplet_hard_loss are kept. They feed
    # the cross-branch term multual_loss, presumably a mutual-learning
    # consistency loss -- confirm against its definition.
    _,dis_matrix1 = triplet_hard_loss(feature,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    _,dis_matrix2 = triplet_hard_loss(feature2,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    mul_loss = multual_loss(dis_matrix1,dis_matrix2)

    
    
    # Local (part-level) features and triplet loss for the ResNet-50 branch.
    local_anchor_feature, local_anchor_feature2 = local_pooling(feat,feat2,weight_decay)
    local_loss_triplet ,local_pos_loss, local_neg_loss = local_triplet_hard_loss(local_anchor_feature,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    
    
    
    # The same global and local triplet terms for the ResNet-101 branch.
    loss_triplet2 ,PP2,NN2 = batch_hard_triplet_loss(labels,feature2,0.3)
    local_loss_triplet2 ,local_pos_loss2, local_neg_loss2 = local_triplet_hard_loss(local_anchor_feature2,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    

    
    # Classification heads. NOTE(review): num_classes is hard-coded to 1000
    # instead of tarin_num_id (see the trailing comment), so sparse softmax
    # requires every label to be < 1000.
    s1 = fully_connected_class(feature,feature_dim=2048,num_classes=1000)#tarin_num_id
    cross_entropy_var = slim.losses.sparse_softmax_cross_entropy(s1, tf.cast(labels, tf.int64))
    loss_softmax = cross_entropy_var
    #loss_softmax = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels_softmax, logits=s1))
    inference = tf.nn.softmax(s1)
    
    s2 = fully_connected_class2(feature2,feature_dim=2048,num_classes=1000)
    cross_entropy_var2 = slim.losses.sparse_softmax_cross_entropy(s2, tf.cast(labels, tf.int64))
    loss_softmax2 = cross_entropy_var2
    
    #loss_softmax2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels_softmax, logits=s2))
    inference2 = tf.nn.softmax(s2)
    
    # KL terms between the two branches' logits: the (s1, s2) term goes into
    # loss and the (s2, s1) term into loss2.
    multual_softmax1 = kl_loss_compute(s1, s2)
    multual_softmax2 = kl_loss_compute(s2, s1)
    
    
    
    
    # Scalar means of the auxiliary triplet-loss outputs, printed in training.
    P1= tf.reduce_mean(PP)
    P2= tf.reduce_mean(PP2)
    N1= tf.reduce_mean(NN)
    N2= tf.reduce_mean(NN2)
    
    LP1= tf.reduce_mean(local_pos_loss)
    LN1= tf.reduce_mean(local_neg_loss)
    
    
    
    '''
    
    # global
    feature2 = global_pooling(feat2,weight_decay)
    #loss_triplet,PP,NN = triplet_hard_loss(feature,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    loss_triplet2 ,PP2,NN2 = batch_hard_triplet_loss(labels,feature2,0.3)

    
    #local
    local_anchor_feature2 = local_pooling(feat2,weight_decay)
    local_loss_triplet2 ,local_pos_loss2, local_neg_loss2 = local_triplet_hard_loss(local_anchor_feature2,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    '''
    
    

    # Per-branch totals. NOTE(review): mul_loss and the KL terms depend on
    # both branches, and minimize() below is called without var_list, so
    # each optimizer also updates the other branch's variables.
    loss = local_loss_triplet*FLAGS.local_rate + loss_triplet*FLAGS.global_rate + mul_loss + loss_softmax + multual_softmax1
   
    #DD = compute_euclidean_distance(anchor_feature,positive_feature)
    loss2 = local_loss_triplet2*FLAGS.local_rate + loss_triplet2*FLAGS.global_rate + mul_loss + loss_softmax2 + multual_softmax2
    

    
    
    # NOTE(review): anchor_feature, positive_feature and negative_feature are
    # not defined in this function. In val/cmc/test mode this line raises
    # NameError unless they are module globals. The bare `None` is a no-op.
    if FLAGS.mode == 'val' or FLAGS.mode == 'cmc' or FLAGS.mode == 'test':
       loss ,pos_loss, neg_loss = triplet_loss(anchor_feature, positive_feature, negative_feature, 0.3)
       print ' ERROR                 ERROR '
       None
    

    
    
    
    
    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES, body_prefix)
    
    model_variables2 = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES, body_prefix2)
    
      
    
    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        

    
        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
        train = optimizer.minimize(loss, global_step=global_step)
        
        # NOTE(review): both minimize() calls pass global_step, so running
        # train and train2 together increments it by 2 per iteration. The
        # resume loop below starts from that doubled count.
        optimizer2 = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
        train2 = optimizer2.minimize(loss2, global_step=global_step)
    

    tf.summary.scalar("total_loss 1", loss)
    tf.summary.scalar("total_loss 2", loss2)
    tf.summary.scalar("learning_rate", learning_rate)

    # Sum of whatever regularization losses were registered in the graph.
    regularization_var = tf.reduce_sum(tf.losses.get_regularization_loss())
    tf.summary.scalar("weight_loss", regularization_var)
    


    lr = FLAGS.learning_rate

    #config=tf.ConfigProto(log_device_placement=True)
    #config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)) 
    # GPU
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.95

    with tf.Session(config=config) as sess:
        
        merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter("TensorBoard_1x1_a_1x7/", graph = sess.graph)

        #sess.run(tf.global_variables_initializer())
        #saver = tf.train.Saver()
        
        #checkpoint_saver = tf.train.Saver(max_to_keep=0)
        checkpoint_saver = tf.train.Saver()


        # Resume from the latest checkpoint in FLAGS.logs_dir if there is one.
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('Restore model')
            print ckpt.model_checkpoint_path
            #saver.restore(sess, ckpt.model_checkpoint_path)
            checkpoint_saver.restore(sess, ckpt.model_checkpoint_path)
                    
        #for first , training load imagenet
        # Otherwise initialize everything and warm-start each backbone from
        # its own pretrained checkpoint (only the variables under its prefix).
        else:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(model_variables)
            print FLAGS.initial_checkpoint
            saver.restore(sess, FLAGS.initial_checkpoint)
            
            
            saver2 = tf.train.Saver(model_variables2)
            print FLAGS.initial_checkpoint2
            saver2.restore(sess, FLAGS.initial_checkpoint2)
   
            
            
        if FLAGS.mode == 'train':
            step = sess.run(global_step)
            for i in xrange(step, FLAGS.max_steps + 1):

                batch_images, batch_labels, batch_images_total = cuhk03_dataset_label2.read_data(FLAGS.data_dir, 'train', tarin_num_id,
                    IMAGE_WIDTH, IMAGE_HEIGHT, FLAGS.batch_size,FLAGS.ID_num,FLAGS.IMG_PER_ID)
                
                #feed_dict = {learning_rate: lr,  is_train: True , labels: batch_labels, droup_is_training: False, train_mode: True, images_total: batch_images_total} #no label   images: batch_images,
              
                feed_dict = {learning_rate: lr,  is_train: True , train_mode: True, images_total: batch_images_total, labels: batch_labels}

                                              
                start = time.time()
                                
                # One run call applies both branches' updates on this batch.
                _,_,train_loss,train_loss2 = sess.run([train,train2,loss,loss2 ], feed_dict=feed_dict) 
                    
                print('Step: %d, Learning rate: %f, Train loss: %f , Train loss2: %f' % (i, lr, train_loss,train_loss2))
                
                # Logging only: these extra run calls re-evaluate loss terms
                # on the same batch without running the train ops.
                gtoloss,gp,gn = sess.run([loss_triplet,P1,N1], feed_dict=feed_dict)   
                print 'global hard: ',gtoloss
                print 'global P1: ',gp
                print 'global N1: ',gn
                             
                toloss,p,n = sess.run([local_loss_triplet,LP1,LN1], feed_dict=feed_dict)   
                print 'local hard: ',toloss
                print 'local P: ',p
                print 'local N: ',n
                                
                mul,p2,n2 = sess.run([mul_loss,loss_triplet2,local_loss_triplet2], feed_dict=feed_dict)   
                print 'mul loss: ',mul
                print 'loss_triplet2: ',p2
                print 'local_loss_triplet2: ',n2                               
                
                end = time.time()
                elapsed = end - start
                print "Time taken: ", elapsed, "seconds."
                                
               
                # Inverse decay: lr = base * (1 + 1e-4 * i) ** -0.75.
                #lr = FLAGS.learning_rate / ((2) ** (i/160000)) * 0.1
                lr = FLAGS.learning_rate * ((0.0001 * i + 1) ** -0.75)
                # Checkpoint every 100 iterations.
                if i % 100 == 0:
               
                    checkpoint_saver.save(sess,FLAGS.logs_dir + 'model.ckpt', i)
                
                # Write TensorBoard summaries every 20 iterations.
                if i % 20 == 0:
                    result = sess.run(merged, feed_dict=feed_dict)
                    writer.add_summary(result, i)
                
                
        

        # NOTE(review): this path feeds `images`, but `inference` is computed
        # from `images_total`, which is not fed here. The labels placeholder is
        # 1-D, so argmax(axis=1) on batch_labels fails if batch_labels is 1-D.
        # read_data is also called with fewer arguments than in 'train'.
        # Verify before relying on this mode.
        elif FLAGS.mode == 'val':
            total = 0.
            for _ in xrange(10):
                batch_images, batch_labels = cuhk03_dataset_label2.read_data(FLAGS.data_dir, 'val', val_num_id,
                    IMAGE_WIDTH, IMAGE_HEIGHT, FLAGS.batch_size)
                feed_dict = {images: batch_images, labels: batch_labels, is_train: False}
                prediction = sess.run(inference, feed_dict=feed_dict)
                prediction = np.argmax(prediction, axis=1)
                label = np.argmax(batch_labels, axis=1)

                for i in xrange(len(prediction)):
                    if prediction[i] == label[i]:
                        total += 1
            print('Accuracy: %f' % (total / (FLAGS.batch_size * 10)))

            '''
            for i in xrange(len(prediction)):
                print('Prediction: %s, Label: %s' % (prediction[i] == 0, labels[i] == 0))
                image1 = cv2.cvtColor(batch_images[0][i], cv2.COLOR_RGB2BGR)
                image2 = cv2.cvtColor(batch_images[1][i], cv2.COLOR_RGB2BGR)
                image = np.concatenate((image1, image2), axis=1)
                cv2.imshow('image', image)
                key = cv2.waitKey(0)
                if key == 1048603:  # ESC key
                    break
            '''

        
        # CMC evaluation over ids 0..99 of data/labeled/val. For every probe
        # id j and gallery id i, one image of each is loaded, stacked into the
        # `images` triple (gallery, probe, probe), scored with DD, and stored
        # in cmc_array[j, i]. Gallery image indices (0..9) are drawn at random
        # once, during the first probe row, and reused for all later rows. A
        # probe image is drawn once per row and must differ from that id's
        # gallery image index.
        elif FLAGS.mode == 'cmc':    
          do_times = 1
          cmc_sum=np.zeros((100, 100), dtype='f')
          for times in xrange(do_times):  
              path = 'data' 
              # NOTE(review): `set` shadows the builtin for the rest of main.
              set = 'val'
              
              cmc_array=np.ones((100, 100), dtype='f')
              
              batch_images = []
              batch_labels = []
              index_gallery_array=np.ones((1, 100), dtype='f')
              gallery_bool = True
              probe_bool = True
              for j in xrange(100):
                      id_probe = j
                      for i in xrange(100):
                              batch_images = []
                              batch_labels = []
                              filepath = ''
                              
                              #filepath_gallery = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, i, index_gallery)
                              #filepath_probe = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, id_probe, index_probe)                          
                              
                              # NOTE(review): these `while True` searches retry
                              # forever if no matching image file exists.
                              if gallery_bool == True:
                                    while True:
                                          index_gallery = int(random.random() * 10)
                                          index_gallery_array[0,i] = index_gallery
  
                                          filepath_gallery = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, i, index_gallery)
                                          if not os.path.exists(filepath_gallery):
                                              continue
                                          break
                              # After the first row, reuse the stored gallery indices.
                              if i ==99:
                                  gallery_bool = False
                              if gallery_bool == False:
                                          index_gallery = index_gallery_array[0,i]
                                          filepath_gallery = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, i, index_gallery)
                              
                              
                              
                              # Pick the probe image once per row (reset at i == 99).
                              if probe_bool == True:
                                    while True:
                                          index_probe = int(random.random() * 10)
                                          filepath_probe = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, id_probe, index_probe)
                                          if not os.path.exists(filepath_probe):
                                              continue
                                          if index_gallery_array[0,id_probe] == index_probe:
                                              continue
                                          probe_bool = False
                                          break
                              if i ==99:
                                  probe_bool = True
                              
                              
                              '''
                              while True:
                                    index_probe = int(random.random() * 10)
                                    filepath_probe = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, id_probe, index_probe)
                                    if not os.path.exists(filepath_gallery):
                                        continue
                                    if index_gallery_array[1,id_probe] == index_probe:
                                        continue
                                    break
                              '''
                              
                              #filepath_gallery = 'data/labeled/val/0000_01.jpg'
                              #filepath_probe   = 'data/labeled/val/0000_02.jpg'
                                                                          
                              # Load as RGB, resize, and add a leading batch axis of 1.
                              image1 = cv2.imread(filepath_gallery)
                              image1 = cv2.resize(image1, (IMAGE_WIDTH, IMAGE_HEIGHT))
                              image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
                              image1 = np.reshape(image1, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)
                              
                              image2 = cv2.imread(filepath_probe)
                              image2 = cv2.resize(image2, (IMAGE_WIDTH, IMAGE_HEIGHT))
                              image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
                              image2 = np.reshape(image2, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)
                              
                              # Shape (3, 1, H, W, 3), matching the `images` placeholder.
                              test_images = np.array([image1, image2, image2])
                              
                              #print (filepath_gallery)
                              #print (filepath_probe)
                              #print ('1111111111111111111111')
          
                              # batch_labels is built but never fed below.
                              if i == j:
                                  batch_labels = [1., 0.]
                              if i != j:    
                                  batch_labels = [0., 1.]
                              batch_labels = np.array(batch_labels)
                              print('test  img :',test_images.shape)
                              
                              # NOTE(review): DD is not defined in this function.
                              feed_dict = {images: test_images, is_train: False}
                              prediction = sess.run(DD, feed_dict=feed_dict)
                              #print (prediction, prediction[0][1])
                              
                              print (filepath_gallery,filepath_probe)
                              
                              #print(bool(not np.argmax(prediction[0])))
                              print (prediction)
                              
                              # Row = probe id, column = gallery id.
                              cmc_array[j,i] = prediction
                              
                              #print(i,j)
                             
                              
                              #prediction = sess.run(inference, feed_dict=feed_dict)
                              #prediction = np.argmax(prediction, axis=1)
                              #label = np.argmax(batch_labels, axis=1)
                              
  
              
              cmc_score = cmc.cmc(cmc_array)
              cmc_sum = cmc_score + cmc_sum
              print(cmc_score)
          cmc_sum = cmc_sum/do_times
          print(cmc_sum)
          print('final cmc') 
        
        
        
        # NOTE(review): droup_is_training is not defined in this function, and
        # `inference` depends on images_total, which this feed_dict omits.
        elif FLAGS.mode == 'test':
            image1 = cv2.imread(FLAGS.image1)
            image1 = cv2.resize(image1, (IMAGE_WIDTH, IMAGE_HEIGHT))
            image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
            image1 = np.reshape(image1, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)
            image2 = cv2.imread(FLAGS.image2)
            image2 = cv2.resize(image2, (IMAGE_WIDTH, IMAGE_HEIGHT))
            image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
            image2 = np.reshape(image2, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)
            test_images = np.array([image1, image2,image2])

            feed_dict = {images: test_images, is_train: False, droup_is_training: False}
            #prediction, prediction2 = sess.run([DD,DD2], feed_dict=feed_dict)
            prediction = sess.run([inference], feed_dict=feed_dict)
            prediction = np.array(prediction)
            print prediction.shape
            print( np.argmax(prediction[0])+1)
Exemple #14
0
def triplet_loss(y_true, y_pred, margin=0.2):
    """Keras-style ``(y_true, y_pred)`` loss wrapper around ``batch_hard_triplet_loss``.

    Args:
        y_true: target tensor. Only ``y_true[:, 0, 0]`` is used (so it must
            be at least rank 3), flattened into a 1-D label vector.
        y_pred: model output. ``y_pred[:, 1]`` and ``y_pred[:, 0]`` are
            passed, in that order, as the second and third arguments of
            ``batch_hard_triplet_loss``.
        margin: triplet margin forwarded to ``batch_hard_triplet_loss``.
            Defaults to 0.2, the value previously hard-coded here, so
            ``model.compile(loss=triplet_loss)`` behaves exactly as before.

    Returns:
        Whatever ``batch_hard_triplet_loss`` returns for these inputs.
    """
    # Collapse the per-sample label slot into a flat label vector.
    label = K.flatten(y_true[:, 0, 0])

    return batch_hard_triplet_loss(label, y_pred[:, 1], y_pred[:, 0], margin)
Exemple #15
0
def main(argv=None):
    """Build a single-backbone (MobileNet v1) triplet-loss graph and run it.

    ``nets.mobilenet_v1_1_224`` is fed ``images_total``. Its
    'Conv2d_12_pointwise' endpoint goes through ``global_pooling`` and a
    batch-hard triplet loss with margin 0.3. That loss, scaled by
    FLAGS.global_rate, is the only training objective, minimized by a
    MomentumOptimizer (momentum 0.9).

    After the graph is built, a session restores all variables from the
    latest checkpoint in FLAGS.logs_dir if one exists. Otherwise it
    initializes everything and loads FLAGS.initial_checkpoint into the
    backbone's variables. FLAGS.mode then selects:

      * 'train': optimize up to FLAGS.max_steps, printing the loss terms and
        checkpointing every 100 steps.
      * 'val' / 'cmc' / 'test': evaluation paths (see NOTE).

    NOTE(review): the evaluation paths use ``inference``, ``DD`` and
    ``droup_is_training``, none of which is defined in this function. Unless
    they are module-level globals, these modes raise NameError.

    Args:
        argv: not used by the body.
    """

    # The evaluation modes work on one example at a time.
    if FLAGS.mode == 'test':
        FLAGS.batch_size = 1

    if FLAGS.mode == 'cmc':
        FLAGS.batch_size = 1

    learning_rate = tf.placeholder(tf.float32, name='learning_rate')

    # Stacked image triple [3, batch, H, W, 3]; only the evaluation paths
    # feed it. The backbone below consumes images_total.
    images = tf.placeholder(
        tf.float32, [3, FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
        name='images')

    images_total = tf.placeholder(
        tf.float32, [FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
        name='images_total')

    # One identity label per image (stored as float).
    labels = tf.placeholder(tf.float32, [FLAGS.batch_size], name='labels')

    # NOTE(review): is_train and train_mode are fed in feed_dicts, but no op
    # built here reads them; the backbone always uses is_training=True.
    is_train = tf.placeholder(tf.bool, name='is_train')
    global_step = tf.Variable(0, name='global_step', trainable=False)
    # Passed to global_pooling.
    weight_decay = 0.0005
    tarin_num_id = 0
    val_num_id = 0

    if FLAGS.mode == 'train':
        tarin_num_id = cuhk03_dataset_label2.get_num_id(
            FLAGS.data_dir, 'train')
        print(
            tarin_num_id,
            '               11111111111111111111               1111111111111111'
        )
    elif FLAGS.mode == 'val':
        val_num_id = cuhk03_dataset_label2.get_num_id(FLAGS.data_dir, 'val')
    #images1, images2,images3 = preprocess(images, is_train)

    # Create the model and an embedding head.
    model = import_module('nets.' + 'mobilenet_v1_1_224')
    head = import_module('heads.' + 'fc1024')

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images_total, is_training=True)

    with tf.name_scope('head'):
        endpoints = head.head(endpoints, FLAGS.embedding_dim, is_training=True)

    # Debug prints of a few endpoint tensors. NOTE(review): the third label
    # says 'resnet_v1_50/block4' but prints the MobileNet
    # 'Conv2d_12_pointwise' endpoint, and the "(bt,7,7,2048)" shape comments
    # look copied from the ResNet variant -- verify the actual shapes.
    print 'model_output : ', endpoints['model_output']  # (bt,2048)
    print 'global_pool : ', endpoints['global_pool']  # (bt,2048)
    print 'resnet_v1_50/block4 : ', endpoints[
        'Conv2d_12_pointwise']  # (bt,7,7,2048)
    #  see   net.resnet_V1   line 258
    print ' 1\n'

    train_mode = tf.placeholder(tf.bool)

    print('Build network')

    feat = endpoints['Conv2d_12_pointwise']  # (bt,7,7,2048)

    #feat = tf.convert_to_tensor(feat, dtype=tf.float32)
    # global
    feature = global_pooling(feat, weight_decay)
    #loss_triplet,PP,NN = triplet_hard_loss(feature,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    # Batch-hard triplet loss (margin 0.3); PP / NN are auxiliary outputs
    # that are only printed during training.
    loss_triplet, PP, NN = batch_hard_triplet_loss(labels, feature, 0.3)

    loss = loss_triplet * FLAGS.global_rate

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    #optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

    #optimizer = tf.train.AdadeltaOptimizer(learning_rate)
    #train = optimizer.minimize(loss, global_step=global_step)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        #train_op = optimizer.minimize(loss_mean, global_step=global_step)

        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
        train = optimizer.minimize(loss, global_step=global_step)

    lr = FLAGS.learning_rate

    #config=tf.ConfigProto(log_device_placement=True)
    #config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    # GPU
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.9

    with tf.Session(config=config) as sess:

        print '\n'
        #print model_variables
        print '\n'
        #sess.run(tf.global_variables_initializer())
        #saver = tf.train.Saver()

        #checkpoint_saver = tf.train.Saver(max_to_keep=0)
        checkpoint_saver = tf.train.Saver()

        # Resume from the latest checkpoint in FLAGS.logs_dir if there is one.
        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('Restore model')
            print ckpt.model_checkpoint_path
            #saver.restore(sess, ckpt.model_checkpoint_path)
            checkpoint_saver.restore(sess, ckpt.model_checkpoint_path)

        #for first , training load imagenet
        # Otherwise initialize everything, then warm-start only the backbone
        # variables (those under body_prefix) from the pretrained checkpoint.
        else:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(model_variables)
            print FLAGS.initial_checkpoint
            saver.restore(sess, FLAGS.initial_checkpoint)

        if FLAGS.mode == 'train':
            step = sess.run(global_step)
            for i in xrange(step, FLAGS.max_steps + 1):

                batch_images, batch_labels, batch_images_total = cuhk03_dataset_label2.read_data(
                    FLAGS.data_dir, 'train', tarin_num_id, IMAGE_WIDTH,
                    IMAGE_HEIGHT, FLAGS.batch_size, FLAGS.ID_num,
                    FLAGS.IMG_PER_ID)

                feed_dict = {
                    learning_rate: lr,
                    is_train: True,
                    train_mode: True,
                    images_total: batch_images_total,
                    labels: batch_labels
                }

                _, train_loss = sess.run([train, loss], feed_dict=feed_dict)

                print('Step: %d, Learning rate: %f, Train loss: %f ' %
                      (i, lr, train_loss))

                # Logging only: re-evaluates the triplet terms on the same
                # batch without running the train op.
                gtoloss, gp, gn = sess.run([loss_triplet, PP, NN],
                                           feed_dict=feed_dict)
                print 'global hard: ', gtoloss
                print 'global P: ', gp
                print 'global N: ', gn

                # Inverse decay: lr = base * (1 + 1e-4 * i) ** -0.75.
                #lr = FLAGS.learning_rate / ((2) ** (i/160000)) * 0.1
                lr = FLAGS.learning_rate * ((0.0001 * i + 1)**-0.75)
                # Checkpoint every 100 iterations.
                if i % 100 == 0:
                    #saver.save(sess, FLAGS.logs_dir + 'model.ckpt', i)

                    checkpoint_saver.save(sess, FLAGS.logs_dir + 'model.ckpt',
                                          i)

        # NOTE(review): `inference` is not defined in this function. The
        # labels placeholder is 1-D, so argmax(axis=1) on batch_labels fails
        # if batch_labels is 1-D.
        elif FLAGS.mode == 'val':
            total = 0.
            for _ in xrange(10):
                batch_images, batch_labels = cuhk03_dataset_label2.read_data(
                    FLAGS.data_dir, 'val', val_num_id, IMAGE_WIDTH,
                    IMAGE_HEIGHT, FLAGS.batch_size)
                feed_dict = {
                    images: batch_images,
                    labels: batch_labels,
                    is_train: False
                }
                prediction = sess.run(inference, feed_dict=feed_dict)
                prediction = np.argmax(prediction, axis=1)
                label = np.argmax(batch_labels, axis=1)

                for i in xrange(len(prediction)):
                    if prediction[i] == label[i]:
                        total += 1
            print('Accuracy: %f' % (total / (FLAGS.batch_size * 10)))
            '''
            for i in xrange(len(prediction)):
                print('Prediction: %s, Label: %s' % (prediction[i] == 0, labels[i] == 0))
                image1 = cv2.cvtColor(batch_images[0][i], cv2.COLOR_RGB2BGR)
                image2 = cv2.cvtColor(batch_images[1][i], cv2.COLOR_RGB2BGR)
                image = np.concatenate((image1, image2), axis=1)
                cv2.imshow('image', image)
                key = cv2.waitKey(0)
                if key == 1048603:  # ESC key
                    break
            '''

        # CMC evaluation over ids 0..99 of data/labeled/train. For each probe
        # id j and gallery id i, one image of each is loaded, stacked into the
        # `images` triple (gallery, probe, probe), scored with DD, and stored
        # in cmc_array[j, i]. Gallery image indices (0..9) are drawn once,
        # during the first probe row, and reused. A probe image is drawn once
        # per row and must differ from that id's gallery index.
        elif FLAGS.mode == 'cmc':
            do_times = 1
            cmc_sum = np.zeros((100, 100), dtype='f')
            for times in xrange(do_times):
                path = 'data'
                # NOTE(review): `set` shadows the builtin for the rest of main.
                set = 'train'

                cmc_array = np.ones((100, 100), dtype='f')

                batch_images = []
                batch_labels = []
                index_gallery_array = np.ones((1, 100), dtype='f')
                gallery_bool = True
                probe_bool = True
                for j in xrange(100):
                    id_probe = j
                    for i in xrange(100):
                        batch_images = []
                        batch_labels = []
                        filepath = ''

                        #filepath_gallery = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, i, index_gallery)
                        #filepath_probe = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, id_probe, index_probe)

                        # NOTE(review): these `while True` searches retry
                        # forever if no matching image file exists.
                        if gallery_bool == True:
                            while True:
                                index_gallery = int(random.random() * 10)
                                index_gallery_array[0, i] = index_gallery

                                filepath_gallery = '%s/labeled/%s/%04d_%02d.jpg' % (
                                    path, set, i, index_gallery)
                                if not os.path.exists(filepath_gallery):
                                    continue
                                break
                        # After the first row, reuse the stored gallery indices.
                        if i == 99:
                            gallery_bool = False
                        if gallery_bool == False:
                            index_gallery = index_gallery_array[0, i]
                            filepath_gallery = '%s/labeled/%s/%04d_%02d.jpg' % (
                                path, set, i, index_gallery)

                        # Pick the probe image once per row (reset at i == 99).
                        if probe_bool == True:
                            while True:
                                index_probe = int(random.random() * 10)
                                filepath_probe = '%s/labeled/%s/%04d_%02d.jpg' % (
                                    path, set, id_probe, index_probe)
                                if not os.path.exists(filepath_probe):
                                    continue
                                if index_gallery_array[
                                        0, id_probe] == index_probe:
                                    continue
                                probe_bool = False
                                break
                        if i == 99:
                            probe_bool = True
                        '''
                              while True:
                                    index_probe = int(random.random() * 10)
                                    filepath_probe = '%s/labeled/%s/%04d_%02d.jpg' % (path, set, id_probe, index_probe)
                                    if not os.path.exists(filepath_gallery):
                                        continue
                                    if index_gallery_array[1,id_probe] == index_probe:
                                        continue
                                    break
                              '''

                        #filepath_gallery = 'data/labeled/val/0000_01.jpg'
                        #filepath_probe   = 'data/labeled/val/0000_02.jpg'

                        # Load as RGB, resize, and add a leading batch axis of 1.
                        image1 = cv2.imread(filepath_gallery)
                        image1 = cv2.resize(image1,
                                            (IMAGE_WIDTH, IMAGE_HEIGHT))
                        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
                        image1 = np.reshape(
                            image1,
                            (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)

                        image2 = cv2.imread(filepath_probe)
                        image2 = cv2.resize(image2,
                                            (IMAGE_WIDTH, IMAGE_HEIGHT))
                        image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
                        image2 = np.reshape(
                            image2,
                            (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)

                        # Shape (3, 1, H, W, 3), matching the `images` placeholder.
                        test_images = np.array([image1, image2, image2])

                        #print (filepath_gallery)
                        #print (filepath_probe)
                        #print ('1111111111111111111111')

                        # batch_labels is built but never fed below.
                        if i == j:
                            batch_labels = [1., 0.]
                        if i != j:
                            batch_labels = [0., 1.]
                        batch_labels = np.array(batch_labels)
                        print('test  img :', test_images.shape)

                        # NOTE(review): DD is not defined in this function.
                        feed_dict = {images: test_images, is_train: False}
                        prediction = sess.run(DD, feed_dict=feed_dict)
                        #print (prediction, prediction[0][1])

                        print(filepath_gallery, filepath_probe)

                        #print(bool(not np.argmax(prediction[0])))
                        print(prediction)

                        # Row = probe id, column = gallery id.
                        cmc_array[j, i] = prediction

                        #print(i,j)

                        #prediction = sess.run(inference, feed_dict=feed_dict)
                        #prediction = np.argmax(prediction, axis=1)
                        #label = np.argmax(batch_labels, axis=1)

                cmc_score = cmc.cmc(cmc_array)
                cmc_sum = cmc_score + cmc_sum
                print(cmc_score)
            cmc_sum = cmc_sum / do_times
            print(cmc_sum)
            print('final cmc')

        # NOTE(review): droup_is_training and inference are not defined in
        # this function.
        elif FLAGS.mode == 'test':
            image1 = cv2.imread(FLAGS.image1)
            image1 = cv2.resize(image1, (IMAGE_WIDTH, IMAGE_HEIGHT))
            image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
            image1 = np.reshape(
                image1, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)
            image2 = cv2.imread(FLAGS.image2)
            image2 = cv2.resize(image2, (IMAGE_WIDTH, IMAGE_HEIGHT))
            image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
            image2 = np.reshape(
                image2, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)).astype(float)
            test_images = np.array([image1, image2, image2])

            feed_dict = {
                images: test_images,
                is_train: False,
                droup_is_training: False
            }
            #prediction, prediction2 = sess.run([DD,DD2], feed_dict=feed_dict)
            prediction = sess.run([inference], feed_dict=feed_dict)
            prediction = np.array(prediction)
            print prediction.shape
            print(np.argmax(prediction[0]) + 1)
    handle = tf.placeholder(tf.string, shape=[])
    x, comp_id, song_id = tf.data.Iterator.from_string_handle(
        handle, trn_itr.output_types, trn_itr.output_shapes).get_next()

    input_layer = tf.reshape(x, PARAMS['x.shape'])  # shape [-1, 233, 1_323, 1]
    y_ = tf.one_hot(comp_id, PARAMS['n_composers'])
    y_ = tf.squeeze(
        y_, 1
    )  # squeeze because tf.graph doesn't know that there is only one comp_id per data point
""" Calculations """
_, embeddings = model(input_layer, PARAMS)
distance_matrix = pairwise_distances(embeddings)

with tf.name_scope("training") as scope:
    loss = batch_hard_triplet_loss(tf.squeeze(comp_id, 1), embeddings,
                                   PARAMS['triplet_loss_margin'])
    update_ops = tf.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # needed for batch normalizations
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(PARAMS['lr']).minimize(
            loss, global_step=tf.train.create_global_step())
tf.summary.scalar('loss', loss)
""" Session run """
with tf.Session() as sess:
    for var in tf.trainable_variables():
        tf.summary.histogram(var.name.replace(":", "_"), var)
        if "kernel" in var.name and "conv1d" in var.name:
            tf.summary.image(
                var.name.replace(":", "_") + "_image", tf.expand_dims(var, -1))
    merged = tf.summary.merge_all(
    )  # compute all the summaries (why name merge?)
Exemple #17
0
def model_fn(features, labels, mode, params):
    """Model function for tf.estimator.

    Builds the embedding network, the loss selected by ``params.loss``, the
    evaluation metrics and, in TRAIN mode, an Adam training op.

    Args:
        features: (tf.Tensor) input batch of images
        labels: (tf.Tensor) labels of the images
        mode: (int) An instance of tf.estimator.ModeKeys (TRAIN, EVAL, PREDICT)
        params: experiment parameters, read as attributes (``loss``,
            ``margin``, ``num_classes``, ``learning_rate``)

    Returns:
        model_spec: (tf.estimator.EstimatorSpec)

    Raises:
        ValueError: if ``params.loss`` is not "triplet_batch_hard",
            "triplet_batch_all" or "cross_entropy".
    """

    images = features
    tf.summary.image('train_image', images, max_outputs=3)

    # Create model
    embeddings = build_model(images, mode, params)

    # if predicting for new data, just compute and return the embeddings
    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {'embeddings': embeddings}
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # The metrics dict must exist before the loss is built, because
    # batch_all_triplet_loss receives it via `metrics=` (presumably to add
    # its own entries). It must not be re-created afterwards, or those
    # entries would be lost.
    metrics = {}

    # Define loss
    if params.loss == "triplet_batch_hard":
        loss = batch_hard_triplet_loss(labels,
                                       embeddings,
                                       margin=params.margin)
    elif params.loss == "triplet_batch_all":
        loss = batch_all_triplet_loss(labels,
                                      embeddings,
                                      margin=params.margin,
                                      metrics=metrics)
    elif params.loss == "cross_entropy":
        # Here the network output is treated as class logits.
        one_hot_labels = tf.one_hot(labels, params.num_classes)
        loss = tf.losses.softmax_cross_entropy(one_hot_labels, embeddings)
    else:
        raise ValueError(
            'Unknown loss "{}"; expected "triplet_batch_hard", '
            '"triplet_batch_all" or "cross_entropy"'.format(params.loss))
    tf.summary.scalar('loss', loss)

    # Metrics: each value is a (metric_value, update_op) pair from tf.metrics.
    embedding_mean_norm = tf.metrics.mean(tf.norm(embeddings, axis=1))
    metrics["metrics/embedding_mean_norm"] = embedding_mean_norm
    with tf.name_scope("metrics/"):
        tf.summary.scalar('embedding_mean_norm', embedding_mean_norm[1])

    if params.loss == "cross_entropy":
        predictions = tf.argmax(embeddings, 1)
        accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions)
        metrics["metrics/accuracy"] = accuracy
        # The trailing slash makes name_scope re-enter the existing
        # "metrics/" scope instead of creating a uniquified one.
        with tf.name_scope("metrics/"):
            # tf.summary.scalar makes accuracy available to TensorBoard
            # in both TRAIN and EVAL modes.
            tf.summary.scalar('accuracy', accuracy[1])

    if mode == tf.estimator.ModeKeys.EVAL:
        # loss and metrics run on validation set, https://www.tensorflow.org/guide/custom_estimators
        # Values of eval_metric_ops must be (metric_value, update_op) tuples.
        return tf.estimator.EstimatorSpec(mode,
                                          loss=loss,
                                          eval_metric_ops=metrics)

    # Optimizer
    optimizer = tf.train.AdamOptimizer(params.learning_rate)
    global_step = tf.train.get_global_step()

    # Run UPDATE_OPS (e.g. batch-norm moving statistics) with every training
    # step; this is a no-op when the model registers none.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
        train_op = optimizer.minimize(loss, global_step=global_step)

    return tf.estimator.EstimatorSpec(mode,
                                      loss=loss,
                                      train_op=train_op,
                                      training_hooks=[])
Exemple #18
0
def main(argv=None):

    if FLAGS.mode == 'test':
        FLAGS.batch_size = 1

    if FLAGS.mode == 'cmc':
        FLAGS.batch_size = 1

    learning_rate = tf.placeholder(tf.float32, name='learning_rate')
    #images = tf.placeholder(tf.float32, [2, FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3], name='images')
    images = tf.placeholder(
        tf.float32, [3, FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
        name='images')

    images_total = tf.placeholder(
        tf.float32, [FLAGS.batch_size, IMAGE_HEIGHT, IMAGE_WIDTH, 3],
        name='images_total')

    labels = tf.placeholder(tf.float32, [FLAGS.batch_size], name='labels')
    #labels_neg = tf.placeholder(tf.float32, [FLAGS.batch_size, 743], name='labels')

    #total
    #labels = tf.placeholder(tf.float32, [FLAGS.batch_size, 847], name='labels')
    #labels_neg = tf.placeholder(tf.float32, [FLAGS.batch_size, 847], name='labels')

    #eye
    #labels = tf.placeholder(tf.float32, [FLAGS.batch_size, 104], name='labels')
    #labels_neg = tf.placeholder(tf.float32, [FLAGS.batch_size, 104], name='labels')

    is_train = tf.placeholder(tf.bool, name='is_train')
    global_step = tf.Variable(0, name='global_step', trainable=False)
    weight_decay = 0.0005
    tarin_num_id = 0
    val_num_id = 0

    if FLAGS.mode == 'train':
        tarin_num_id = cuhk03_dataset_label2.get_num_id(
            FLAGS.data_dir, 'train')
        print(
            tarin_num_id,
            '               11111111111111111111               1111111111111111'
        )
    elif FLAGS.mode == 'val':
        val_num_id = cuhk03_dataset_label2.get_num_id(FLAGS.data_dir, 'val')

    # Create the model and an embedding head.
    model = import_module('nets.' + 'resnet_v1_50')
    head = import_module('heads.' + 'fc1024')

    # Feed the image through the model. The returned `body_prefix` will be used
    # further down to load the pre-trained weights for all variables with this
    # prefix.
    endpoints, body_prefix = model.endpoints(images_total, is_training=True)

    with tf.name_scope('head'):
        endpoints = head.head(endpoints, FLAGS.embedding_dim, is_training=True)
    '''
    print endpoints['model_output'] # (bt,2048)
    print endpoints['global_pool'] # (bt,2048)
    print endpoints['resnet_v1_50/block4']# (bt,7,7,2048)
    '''

    # Create the model and an embedding head.
    model2 = import_module('nets.' + 'resnet_v1_101')
    endpoints2, body_prefix2 = model2.endpoints(images_total, is_training=True)

    train_mode = tf.placeholder(tf.bool)

    print('Build network')

    feat = endpoints['resnet_v1_50/block4']  # (bt,7,7,2048)

    feat2 = endpoints2['resnet_v1_101/block4']  # (bt,7,7,2048)

    #feat = tf.convert_to_tensor(feat, dtype=tf.float32)
    # global
    feature, feature2 = global_pooling(feat, feat2, weight_decay)
    loss_triplet, PP, NN = batch_hard_triplet_loss(labels, feature, 0.3)

    _, dis_matrix1 = triplet_hard_loss(feature, FLAGS.ID_num, FLAGS.IMG_PER_ID)
    _, dis_matrix2 = triplet_hard_loss(feature2, FLAGS.ID_num,
                                       FLAGS.IMG_PER_ID)
    mul_loss = multual_loss(dis_matrix1, dis_matrix2)

    local_anchor_feature, local_anchor_feature2 = local_pooling(
        feat, feat2, weight_decay)
    local_loss_triplet, local_pos_loss, local_neg_loss = local_triplet_hard_loss(
        local_anchor_feature, FLAGS.ID_num, FLAGS.IMG_PER_ID)

    loss_triplet2, PP2, NN2 = batch_hard_triplet_loss(labels, feature2, 0.3)
    local_loss_triplet2, local_pos_loss2, local_neg_loss2 = local_triplet_hard_loss(
        local_anchor_feature2, FLAGS.ID_num, FLAGS.IMG_PER_ID)

    s1 = fully_connected_class(feature, feature_dim=2048,
                               num_classes=1000)  #tarin_num_id
    cross_entropy_var = slim.losses.sparse_softmax_cross_entropy(
        s1, tf.cast(labels, tf.int64))
    loss_softmax = cross_entropy_var
    #loss_softmax = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels_softmax, logits=s1))
    inference = tf.nn.softmax(s1)

    s2 = fully_connected_class2(feature2, feature_dim=2048, num_classes=1000)
    cross_entropy_var2 = slim.losses.sparse_softmax_cross_entropy(
        s2, tf.cast(labels, tf.int64))
    loss_softmax2 = cross_entropy_var2

    #loss_softmax2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels_softmax, logits=s2))
    inference2 = tf.nn.softmax(s2)

    multual_softmax1 = kl_loss_compute(s1, s2)
    multual_softmax2 = kl_loss_compute(s2, s1)

    P1 = tf.reduce_mean(PP)
    P2 = tf.reduce_mean(PP2)
    N1 = tf.reduce_mean(NN)
    N2 = tf.reduce_mean(NN2)

    LP1 = tf.reduce_mean(local_pos_loss)
    LN1 = tf.reduce_mean(local_neg_loss)
    '''
    
    # global
    feature2 = global_pooling(feat2,weight_decay)
    #loss_triplet,PP,NN = triplet_hard_loss(feature,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    loss_triplet2 ,PP2,NN2 = batch_hard_triplet_loss(labels,feature2,0.3)

    
    #local
    local_anchor_feature2 = local_pooling(feat2,weight_decay)
    local_loss_triplet2 ,local_pos_loss2, local_neg_loss2 = local_triplet_hard_loss(local_anchor_feature2,FLAGS.ID_num,FLAGS.IMG_PER_ID)
    '''

    loss = local_loss_triplet * FLAGS.local_rate + loss_triplet * FLAGS.global_rate + mul_loss + loss_softmax + multual_softmax1

    #DD = compute_euclidean_distance(anchor_feature,positive_feature)
    loss2 = local_loss_triplet2 * FLAGS.local_rate + loss_triplet2 * FLAGS.global_rate + mul_loss + loss_softmax2 + multual_softmax2

    if FLAGS.mode == 'val' or FLAGS.mode == 'cmc' or FLAGS.mode == 'test':
        loss, pos_loss, neg_loss = triplet_loss(anchor_feature,
                                                positive_feature,
                                                negative_feature, 0.3)
        print ' ERROR                 ERROR '
        None

    # These are collected here before we add the optimizer, because depending
    # on the optimizer, it might add extra slots, which are also global
    # variables, with the exact same prefix.
    model_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                        body_prefix)

    model_variables2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                         body_prefix2)

    # Update_ops are used to update batchnorm stats.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):

        optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
        train = optimizer.minimize(loss, global_step=global_step)

        optimizer2 = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)
        train2 = optimizer2.minimize(loss2, global_step=global_step)

    tf.summary.scalar("total_loss 1", loss)
    tf.summary.scalar("total_loss 2", loss2)
    tf.summary.scalar("learning_rate", learning_rate)

    regularization_var = tf.reduce_sum(tf.losses.get_regularization_loss())
    tf.summary.scalar("weight_loss", regularization_var)

    lr = FLAGS.learning_rate

    #config=tf.ConfigProto(log_device_placement=True)
    #config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    # GPU
    config = tf.ConfigProto()
    config.gpu_options.per_process_gpu_memory_fraction = 0.95

    with tf.Session(config=config) as sess:

        merged = tf.summary.merge_all()
        writer = tf.summary.FileWriter("TensorBoard_1x1_a_1x7/",
                                       graph=sess.graph)

        #sess.run(tf.global_variables_initializer())
        #saver = tf.train.Saver()

        #checkpoint_saver = tf.train.Saver(max_to_keep=0)
        checkpoint_saver = tf.train.Saver()

        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print('Restore model')
            print ckpt.model_checkpoint_path
            #saver.restore(sess, ckpt.model_checkpoint_path)
            checkpoint_saver.restore(sess, ckpt.model_checkpoint_path)

        #for first , training load imagenet
        else:
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(model_variables)
            print FLAGS.initial_checkpoint
            saver.restore(sess, FLAGS.initial_checkpoint)

            saver2 = tf.train.Saver(model_variables2)
            print FLAGS.initial_checkpoint2
            saver2.restore(sess, FLAGS.initial_checkpoint2)

        if FLAGS.mode == 'train':
            step = sess.run(global_step)
            for i in xrange(step, FLAGS.max_steps + 1):

                batch_images, batch_labels, batch_images_total = cuhk03_dataset_label2.read_data(
                    FLAGS.data_dir, 'train', tarin_num_id, IMAGE_WIDTH,
                    IMAGE_HEIGHT, FLAGS.batch_size, FLAGS.ID_num,
                    FLAGS.IMG_PER_ID)

                #feed_dict = {learning_rate: lr,  is_train: True , labels: batch_labels, droup_is_training: False, train_mode: True, images_total: batch_images_total} #no label   images: batch_images,

                feed_dict = {
                    learning_rate: lr,
                    is_train: True,
                    train_mode: True,
                    images_total: batch_images_total,
                    labels: batch_labels
                }

                start = time.time()

                _, _, train_loss, train_loss2 = sess.run(
                    [train, train2, loss, loss2], feed_dict=feed_dict)

                print(
                    'Step: %d, Learning rate: %f, Train loss: %f , Train loss2: %f'
                    % (i, lr, train_loss, train_loss2))

                gtoloss, gp, gn = sess.run([loss_triplet, P1, N1],
                                           feed_dict=feed_dict)
                print 'global hard: ', gtoloss
                print 'global P1: ', gp
                print 'global N1: ', gn

                toloss, p, n = sess.run([local_loss_triplet, LP1, LN1],
                                        feed_dict=feed_dict)
                print 'local hard: ', toloss
                print 'local P: ', p
                print 'local N: ', n

                mul, p2, n2 = sess.run(
                    [mul_loss, loss_triplet2, local_loss_triplet2],
                    feed_dict=feed_dict)
                print 'mul loss: ', mul
                print 'loss_triplet2: ', p2
                print 'local_loss_triplet2: ', n2

                end = time.time()
                elapsed = end - start
                print "Time taken: ", elapsed, "seconds."

                #lr = FLAGS.learning_rate / ((2) ** (i/160000)) * 0.1
                lr = FLAGS.learning_rate * ((0.0001 * i + 1)**-0.75)
                if i % 100 == 0:

                    checkpoint_saver.save(sess, FLAGS.logs_dir + 'model.ckpt',
                                          i)

                if i % 20 == 0:
                    result = sess.run(merged, feed_dict=feed_dict)
                    writer.add_summary(result, i)