Exemplo n.º 1
0
def train_pro_gan() -> None:
    """
    Train a progressively-growing GAN.

    Builds the TF1 graph (input placeholders, the generator/discriminator
    network selected by ``config.network``, one Adam optimizer per
    sub-network, and the initializer), then runs the training loop:
    minibatches are loaded at the resolution dictated by
    ``training_schedule(cur_img)`` and the discriminator and generator are
    each updated once per iteration.  Summaries, sample images and model
    checkpoints are written periodically through ``train.Saver``.  All
    hyper-parameters come from the module-level ``config`` object.
    """
    # Graph definition: inputs, model, loss, optimizer, initializer
    print('Graph building...')
    # Inputs.  res_training is a float placeholder so the transition
    # between resolutions can be interpolated smoothly (fractional
    # resolution during the fade-in phase of progressive growing).
    image_inputs        = tf.compat.v1.placeholder(tf.compat.v1.float32, shape=[None]+config.image_shape, name='image_inputs')
    noise_inputs        = tf.compat.v1.placeholder(tf.compat.v1.float32, shape=[None,config.latent_size], name='noise_inputs')
    res_training        = tf.compat.v1.placeholder(tf.compat.v1.float32, shape=[], name='res_training')
    minibatch_size      = tf.compat.v1.placeholder(tf.compat.v1.int64, shape=[], name='minibatch_size')
    # Network: config.network names a builder function in the `network`
    # module; it returns both losses and the generated image tensor.
    gen_loss, disc_loss, output_images = getattr(network, config.network)(
        image_inputs,
        noise_inputs,
        latent_size=config.latent_size,
        minibatch_size=minibatch_size,
        res_building=int(np.log2(config.image_size)),
        res_training=res_training)
    var_list = tf.compat.v1.global_variables() # Store the list of variables
    # Variables whose names contain '8' or 'Output_6' are excluded from the
    # restore list — presumably the layers added for the newest (highest)
    # resolution, which have no counterpart in the checkpoint being
    # restored.  TODO(review): confirm against the checkpoint's variables.
    restore_list = []
    for v in var_list:
        if not '8' in v.name and not 'Output_6' in v.name:
            restore_list += [v]
            print(v)
    # Optimizer: one Adam instance per sub-network, both driven by the
    # same learning-rate placeholder.
    lr                  = tf.compat.v1.placeholder(tf.compat.v1.float32, name='lr')
    optimizer_disc      = tf.compat.v1.train.AdamOptimizer(learning_rate=lr, **config.optimzer_kwargs)
    optimizer_gen       = tf.compat.v1.train.AdamOptimizer(learning_rate=lr, **config.optimzer_kwargs)
    disc_vars           = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')
    gen_vars            = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    training_op_disc    = optimizer_disc.minimize(disc_loss, var_list=disc_vars, name='training_op_disc')
    training_op_gen     = optimizer_gen.minimize(gen_loss, var_list=gen_vars, name='training_op_gen')
    # Initializer
    init                = tf.compat.v1.global_variables_initializer()
    print('Done: graph built.')

    # Init dataset
    inputs = getattr(dataset, config.data_initilize)(**config.data_initilize_kwargs)

    # Define the minibatch selector (partial application: only the
    # per-iteration arguments remain to be supplied).
    select_minibatch = partial(getattr(dataset, config.data_selector), inputs)

    # Saver
    saver = train.Saver(var_list,config.logs_path, {'gen_loss':gen_loss, 'disc_loss':disc_loss})

    # Save minibatch test: rescale images from [-1, 1] to [0, 1] before
    # writing them to disk.
    minibatch = select_minibatch(crt_img=0, res=config.image_size, minibatch_size=config.nbof_test_sample)/2 +0.5
    saver.save_images(minibatch, 'test')

    # Time measurements
    graph_time = time()

    # Training --> use the configuration file
    print('Training...')
    with tf.compat.v1.Session() as sess:
        # Initialize
        init.run()

        # Restore former parameters (only the filtered restore_list, so
        # newly-added layers keep their fresh initialization)
        if config.restore:
            print('Restoring weight stored in {}'.format(config.restore))
            tf.compat.v1.train.Saver(var_list=restore_list).restore(sess, config.restore)

        # Fixed noise vectors in [-1, 1), reused for every image dump so
        # successive sample grids are directly comparable.
        noise_test  = 2*np.random.random((config.nbof_test_sample, config.latent_size))-1

        # Training ...
        cur_img = config.start_img
        while cur_img < config.end_img:
            # Change input dataset: the (possibly fractional) resolution
            # and the per-resolution minibatch size follow the schedule.
            cur_res = training_schedule(cur_img)
            cur_minibatch_size = config.minibatch_size[int(np.ceil(cur_res))]
            minibatch = select_minibatch(crt_img=cur_img, res=2**int(np.ceil(cur_res)), minibatch_size=cur_minibatch_size)
            noises = 2*np.random.random((cur_minibatch_size, config.latent_size))-1

            feed_dict = {}
            feed_dict['image_inputs:0'] = minibatch
            feed_dict['noise_inputs:0'] = noises
            feed_dict['lr:0'] = config.learning_rate
            feed_dict['res_training:0'] = cur_res
            feed_dict['minibatch_size:0'] = cur_minibatch_size

            # One discriminator step then one generator step per iteration.
            sess.run([training_op_disc], feed_dict=feed_dict)
            sess.run([training_op_gen], feed_dict=feed_dict)

            # Display time information (elapsed since the previous display)
            if cur_img % config.img_per_images == 0:
                graph_time = time() - graph_time
                minutes = int(graph_time // 60)
                secondes = graph_time - minutes * 60
                print('{} kimgs: {:4d} minutes {:2f} secondes, {:2f} resolution'.format(cur_img//1000, minutes, secondes, cur_res))
                graph_time = time()
            # Save logs
            if cur_img % config.img_per_summary == 0:
                saver.save_summary(sess, feed_dict, cur_img)
            # Save images
            if cur_img % config.img_per_images == 0:
                # NOTE(review): this feed omits the 'minibatch_size'
                # placeholder the graph was built with (see above) —
                # confirm output_images does not depend on it, otherwise
                # this eval will fail at the first image dump.
                feed_dict = {noise_inputs:noise_test, res_training:cur_res}
                outputs = output_images.eval(feed_dict=feed_dict)
                outputs = outputs / 2 + 0.5 # Re scale output images
                outputs = np.clip(outputs, 0, 1)
                saver.save_images(outputs, cur_img//1000)
            # Save model
            if cur_img % config.img_per_save == 0:
                saver.save_model(sess, cur_img//1000)
            # Update current image
            cur_img += cur_minibatch_size
        # Final Saving
        saver.save_model(sess, 'final')
        saver.close_summary()
    print('Done: training')
Exemplo n.º 2
0
def train_recent_from_npy() -> None:
    """
    Train a recognition (embedding) network from data stored in .npy files.

    The .npy files at ``config.data_path`` / ``config.data_test_path`` each
    hold a (filenames-or-encoded-jpegs, labels-as-bytes) pair.  A
    WideResNet produces L2-normalised embeddings trained with a CosineFace
    softmax loss; the loss layer itself is re-trained from scratch on the
    current batch's embeddings at every iteration (100 inner steps).
    Validation accuracy is computed with ``evaluate.evaluate`` and a
    checkpoint is written whenever validation accuracy improves.  All
    hyper-parameters come from the module-level ``config`` object.
    """
    # Prepare the dataset.  Labels are stored as ASCII byte strings and
    # converted to ints here.
    # NOTE(review): recent NumPy needs allow_pickle=True to load object
    # arrays like these — confirm against the NumPy version in use.
    training_filenames, training_labels_bytes = np.load(config.data_path)
    training_labels = np.array([
        int(tlb.decode('ascii')) for tlb in training_labels_bytes
    ])  # Convert labels type: from bytes to int
    validation_filenames, validation_labels_bytes = np.load(
        config.data_test_path)
    validation_labels = np.array([
        int(vlb.decode('ascii')) for vlb in validation_labels_bytes
    ])  # Convert labels type: from bytes to int

    # Graph definition: inputs, model, loss, optimizer, initializer
    print('Graph building...')
    # Small graph for decoding jpegs (run once per image, outside the
    # training graph)
    with tf.compat.v1.variable_scope('DecodeJPEG'):
        image_jpeg = tf.compat.v1.placeholder(tf.string,
                                              shape=None,
                                              name='images_raw')
        image_decoded = tf.io.decode_jpeg(image_jpeg)
    # Inputs
    images = tf.compat.v1.placeholder(
        tf.float32,
        shape=[None, config.image_size, config.image_size, 3],
        name='image_inputs')
    label_inputs = tf.compat.v1.placeholder(tf.int64,
                                            shape=[
                                                None,
                                            ],
                                            name='label_inputs')
    training = tf.placeholder(tf.bool, shape=[], name='training')
    lr = tf.compat.v1.placeholder(tf.float32, shape=[], name='learning_rate')
    # Augmentation: rescale from [0, 255] to [-1, 1], then apply random
    # augmentations only when the `training` placeholder is True.
    image_inputs = images / 127.5 - 1
    aug = lambda: dataset.augment_image(image_inputs,
                                        config.minibatch_size,
                                        use_horizontal_flip=True,
                                        rotation_rate=0.3,
                                        translation_rate=0.2,
                                        cutout_size=25,
                                        crop_pixels=10)
    image_inputs = tf.cond(training, aug, lambda: image_inputs)

    with tf.variable_scope('Network'):
        # output = deep_cnn_v1(image_inputs_ph, training, config.emb_size, regularizer_rate)
        output = wideresnet(image_inputs,
                            training,
                            config.emb_size,
                            regularizer_rate=config.regularizer_rate,
                            fmaps=[80, 160, 320, 640],
                            nbof_unit=[4, 4, 4, 4],
                            strides=[2, 2, 2, 2],
                            dropouts=[0., 0., 0., 0.])
        # output = deep_cnn_v1(image_inputs, training, config.emb_size, regularizer_rate=config.regularizer_rate, fmaps=[16,32,64,128])
        # output = insightface_resnet(
        #     inputs,
        #     training,
        #     emb_size,
        #     regularizer_rate=regularizer_rate,
        #     dropout_rate=0.,
        #     fmaps       = [16,16,32,64,128],
        #     nbof_unit   = [1,1,1,1,1],
        #     strides     = [2,2,2,2,2])
        emb = tf.nn.l2_normalize(output, axis=1)

    # Loss network, fed directly from the embedding tensor (used for the
    # network's own training step).
    # NOTE(review): the second positional argument here is
    # config.minibatch_size, while the commented-out variant above passes
    # config.nbof_labels — confirm what cosineface_losses expects (it
    # plausibly wants the number of classes).
    with tf.variable_scope('LossLayers'):
        # net_logit_layer = layers.dense(emb, config.nbof_labels)
        # net_logit_layer = losses.cosineface_losses(emb, label_inputs, config.nbof_labels, config.regularizer_rate)
        net_logit_layer = losses.cosineface_losses(emb, label_inputs,
                                                   config.minibatch_size,
                                                   config.regularizer_rate)
        net_loss_layer = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=net_logit_layer, labels=label_inputs))

    # Second copy of the loss layer (reuse=True shares its variables) fed
    # from precomputed embeddings, so the loss layer can be overtrained
    # without re-running the backbone.
    input_emb = tf.compat.v1.placeholder(tf.float32,
                                         shape=[None, config.emb_size],
                                         name='input_emb')
    with tf.variable_scope('LossLayers', reuse=True):
        # logit_layer = layers.dense(input_emb, config.nbof_labels)
        # logit_layer = losses.cosineface_losses(input_emb, label_inputs, config.nbof_labels, config.regularizer_rate)
        logit_layer = losses.cosineface_losses(input_emb, label_inputs,
                                               config.minibatch_size,
                                               config.regularizer_rate)
        loss_layer = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logit_layer, labels=label_inputs))
    # Optimizer: Adam for the backbone, plain SGD for the loss layer.
    optimizer_net = tf.compat.v1.train.AdamOptimizer(learning_rate=lr,
                                                     **config.optimzer_kwargs)
    optimizer_loss = tf.compat.v1.train.GradientDescentOptimizer(
        learning_rate=0.01)
    # optimizer_loss      = tf.compat.v1.train.AdamOptimizer(learning_rate=lr*10, **config.optimzer_kwargs)
    var_list_net = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope='Network')
    var_list_loss = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope='LossLayers')
    assert len(var_list_net) > 0
    # Training operations.  loss_net is a summary-only placeholder fed
    # with the already-computed scalar loss value.
    loss_net = tf.compat.v1.placeholder(tf.float32, shape=[], name='loss')
    training_op_net = optimizer_net.minimize(net_loss_layer,
                                             var_list=var_list_net,
                                             name='training_op_net')
    training_op_loss = optimizer_loss.minimize(loss_layer,
                                               var_list=var_list_loss,
                                               name='training_op_loss')
    extra_training_op = tf.compat.v1.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # For batch normalization
    # Initializer
    init = tf.compat.v1.global_variables_initializer()
    print('Done: graph built.')

    # Saver.  tf.arg_max is a deprecated alias of tf.argmax (kept as-is).
    with tf.compat.v1.variable_scope('Saver'):
        train_acc = tf.reduce_mean(
            tf.cast(
                tf.equal(tf.arg_max(tf.nn.softmax(net_logit_layer), 1),
                         label_inputs), tf.float32))
        training_accuracy = tf.compat.v1.placeholder(
            tf.float32, shape=[], name='training_accuracy_placeholder')
        validation_accuracy = tf.compat.v1.placeholder(
            tf.float32, shape=[], name='validation_accuracy_placeholder')
        saver = train.Saver(var_list_net,
                            config.logs_path,
                            summary_dict={
                                'loss': loss_net,
                                'training_accuracy': training_accuracy,
                                'validation_accuracy': validation_accuracy,
                                'learning_rate': lr
                            },
                            restore=None)

    # Time measurements
    init_time = time()
    # Training --> use the configuration file
    print('Training...')
    with tf.compat.v1.Session() as sess:
        # Initialize
        init.run()
        # Restore former parameters
        if config.restore:
            print('Restoring weight stored in {}'.format(config.restore))
            saver.restore(sess, config.restore)

        # Training ...
        cur_img = config.start_img
        max_valid_acc = 0
        validation_accuracy_ = 0
        while cur_img < config.end_img:

            # Inputs: decode each JPEG individually through the small
            # DecodeJPEG sub-graph.
            img_, lab_ = define_batch(training_filenames, training_labels,
                                      config.minibatch_size)
            images_ = [
                sess.run(image_decoded, feed_dict={image_jpeg: im})
                for im in img_
            ]
            # Change label range: remap batch labels to a contiguous
            # 0..k-1 range (assumes lab_ is grouped/sorted by class —
            # TODO confirm define_batch guarantees this).
            _, counts = np.unique(lab_, return_counts=True)
            lab_ = np.concatenate([[i] * j for i, j in enumerate(counts)])
            # Embeddings: run the backbone once to get fixed embeddings
            # for the loss-layer overtraining below.
            emb_ = sess.run(emb,
                            feed_dict={
                                images: images_,
                                label_inputs: lab_,
                                training: True
                            })
            # Training operation
            for _ in range(100):  # Overtrain the last layer
                sess.run(training_op_loss,
                         feed_dict={
                             input_emb: emb_,
                             label_inputs: lab_,
                             lr: training_schedule(cur_img)
                         })
            # Train the network (one backbone step + batch-norm updates)
            loss_net_, training_accuracy_, _, _ = sess.run(
                [
                    net_loss_layer, train_acc, training_op_net,
                    extra_training_op
                ],
                feed_dict={
                    images: images_,
                    label_inputs: lab_,
                    training: True,
                    lr: training_schedule(cur_img)
                })

            # Validation: embed the whole validation set in minibatches,
            # then score with the pair-based evaluation.
            if cur_img % config.img_per_val == 0:
                pred_validation = np.empty((0, config.emb_size))
                labels_test = np.empty((0))
                for i in range(0, len(validation_labels),
                               config.minibatch_size):
                    # images_ = load_images(validation_filenames[i:i+config.minibatch_size])
                    images_ = [
                        sess.run(image_decoded, feed_dict={image_jpeg: im})
                        for im in validation_filenames[i:i +
                                                       config.minibatch_size]
                    ]
                    feed_dict = {images: images_, training: False}
                    pred_validation = np.append(pred_validation,
                                                sess.run(emb, feed_dict),
                                                axis=0)
                    labels_test = np.append(
                        labels_test,
                        validation_labels[i:i + config.minibatch_size])
                # labels_test[::2] — presumably validation samples come in
                # consecutive pairs and evaluate() takes one label per
                # pair; TODO confirm against evaluate.evaluate.
                _, _, acc, val, _, far = evaluate.evaluate(
                    pred_validation, labels_test[::2])
                validation_accuracy_ = np.mean(acc)
                # Display information
                graph_time = time() - init_time
                hours = int(graph_time // 3600)
                minutes = int(graph_time // 60) - hours * 60
                secondes = graph_time - minutes * 60 - hours * 3600
                print(
                    '{:4d} kimgs, {}h {:2d}m {:2.1f}s, lr:{:0.6f}, training: {:2.2f}%, validation: {:2.2f}%'
                    .format(cur_img // 1000, hours, minutes, secondes,
                            training_schedule(cur_img),
                            training_accuracy_ * 100,
                            validation_accuracy_ * 100))
            # Save logs
            if cur_img % config.img_per_summary == 0:
                feed_dict_saver = {
                    images: images_,
                    label_inputs: lab_,
                    validation_accuracy: validation_accuracy_,
                    training_accuracy: training_accuracy_,
                    training: False,
                    loss_net: loss_net_,
                    lr: training_schedule(cur_img)
                }
                saver.save_summary(sess, feed_dict_saver, cur_img)
            # Save model only when validation accuracy improves
            # if cur_img % config.img_per_save == 0:
            if max_valid_acc < validation_accuracy_:
                saver.save_model(sess, cur_img // 1000)
                max_valid_acc = validation_accuracy_
            # Update current image
            cur_img += config.minibatch_size
        # Final Saving
        saver.save_model(sess, 'final')
        saver.close_summary()
    print('Done: training')
Exemplo n.º 3
0
def train_triplet_from_images() -> None:
    """
    Train a recognition (embedding) network with a triplet loss, loading
    images from disk.

    Builds the TF1 graph (augmented image inputs, the embedding network
    from ``network.recognition`` and its triplet + regularization losses,
    an Adam optimizer), prepares training/validation file lists via
    ``prepare_data``, then runs the training loop: each iteration samples
    a batch of triplets with ``define_triplets_batch`` and performs one
    optimizer step.  Validation accuracy is computed periodically with
    ``evaluate.evaluate``; summaries and checkpoints are written through
    ``train.Saver``.  All hyper-parameters come from the module-level
    ``config`` object.
    """
    # Graph definition: inputs, model, loss, optimizer, initializer
    print('Graph building...')
    # Inputs
    images = tf.compat.v1.placeholder(
        tf.float32,
        shape=[None, config.image_size, config.image_size, 3],
        name='image_inputs')
    # label_inputs = tf.compat.v1.placeholder(tf.int64, shape=[None,], name='label_inputs')
    training = tf.placeholder(tf.bool, shape=[], name='training')
    lr = tf.compat.v1.placeholder(tf.float32, shape=[], name='learning_rate')
    # Augmentation: rescale from [0, 255] to [-1, 1], then apply random
    # augmentations only when the `training` placeholder is True.
    image_inputs = images / 127.5 - 1
    aug = lambda: dataset.augment_image(image_inputs,
                                        config.minibatch_size,
                                        use_horizontal_flip=True,
                                        rotation_rate=0.3,
                                        translation_rate=0.2,
                                        cutout_size=25,
                                        crop_pixels=10)
    image_inputs = tf.cond(training, aug, lambda: image_inputs)
    # Network
    emb, loss, reg_loss = network.recognition(image_inputs, config.emb_size,
                                              training,
                                              config.regularizer_rate)
    # Optimizer
    optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=lr,
                                                 **config.optimzer_kwargs)
    var_list = tf.compat.v1.get_collection(
        tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=config.network)
    assert len(var_list) > 0
    # Training operations
    training_op = optimizer.minimize(loss + reg_loss,
                                     var_list=var_list,
                                     name='training_op')
    extra_training_op = tf.compat.v1.get_collection(
        tf.GraphKeys.UPDATE_OPS)  # For batch normalization
    # Initializer
    init = tf.compat.v1.global_variables_initializer()
    print('Done: graph built.')

    # Prepare the dataset (lists of filenames + integer labels)
    training_filenames, training_labels = prepare_data.prepare_data(
        config.data_path)
    validation_filenames, validation_labels = prepare_data.prepare_dogfacenet(
        config.data_test_path)

    # Saver
    with tf.compat.v1.variable_scope('Saver'):
        training_accuracy = losses.triplet_accuracy(emb)
        # Saver placeholders: scalar values computed in Python are fed
        # back through these for summary writing.
        loss_saver = tf.compat.v1.placeholder(tf.float32,
                                              shape=[],
                                              name='loss_saver')
        training_accuracy_saver = tf.compat.v1.placeholder(
            tf.float32, shape=[], name='training_accuracy_saver')
        validation_accuracy_saver = tf.compat.v1.placeholder(
            tf.float32, shape=[], name='validation_accuracy_saver')
        saver = train.Saver(var_list,
                            config.logs_path,
                            summary_dict={
                                'loss': loss_saver,
                                'training_accuracy': training_accuracy_saver,
                                'validation_accuracy':
                                validation_accuracy_saver,
                                'learning_rate': lr
                            },
                            restore=None)

    # Time measurements
    init_time = time()
    # Training --> use the configuration file
    print('Training...')
    with tf.compat.v1.Session() as sess:
        # Initialize
        init.run()
        # Restore former parameters
        if config.restore:
            print('Restoring weight stored in {}'.format(config.restore))
            saver.restore(sess, config.restore)

        # Training ...
        # NOTE(review): validation_accuracy_ is first assigned inside the
        # `img_per_val` branch; if `img_per_summary` fires first this
        # would raise UnboundLocalError — confirm config guarantees a
        # validation pass on the first iteration (start_img % img_per_val
        # == 0).
        cur_img = config.start_img
        while cur_img < config.end_img:
            filenames_, _ = define_triplets_batch(
                training_filenames,
                training_labels,
                nbof_triplet=config.minibatch_size)

            images_ = load_images(filenames_)

            feed_dict = {
                images: images_,
                training: True,
                lr: training_schedule(cur_img)
            }

            loss_, training_accuracy_, _, _ = sess.run(
                [loss, training_accuracy, training_op, extra_training_op],
                feed_dict=feed_dict)

            # Validation: embed the whole validation set in minibatches,
            # then score with the pair-based evaluation.
            if cur_img % config.img_per_val == 0:
                pred_validation = np.empty((0, config.emb_size))
                labels_test = np.empty((0))
                for i in range(0, len(validation_labels),
                               config.minibatch_size):
                    images_ = load_images(
                        validation_filenames[i:i + config.minibatch_size])
                    feed_dict = {images: images_, training: False}
                    pred_validation = np.append(pred_validation,
                                                sess.run(emb, feed_dict),
                                                axis=0)
                    labels_test = np.append(
                        labels_test,
                        validation_labels[i:i + config.minibatch_size])
                # labels_test[::2] — presumably validation samples come in
                # consecutive pairs and evaluate() takes one label per
                # pair; TODO confirm against evaluate.evaluate.
                _, _, acc, val, _, far = evaluate.evaluate(
                    pred_validation, labels_test[::2])
                validation_accuracy_ = np.mean(acc)
                feed_dict_saver = {
                    loss_saver: loss_,
                    training_accuracy_saver: training_accuracy_,
                    validation_accuracy_saver: validation_accuracy_,
                    lr: training_schedule(cur_img)
                }
                saver.save_summary(sess, feed_dict_saver, cur_img)
                # Display information
                graph_time = time() - init_time
                hours = int(graph_time // 3600)
                minutes = int(graph_time // 60) - hours * 60
                secondes = graph_time - minutes * 60 - hours * 3600
                print(
                    '{:4d} kimgs, {}h {:2d}m {:2.1f}s, lr:{:0.6f}, training: {:2.2f}%, validation: {:2.2f}%'
                    .format(cur_img // 1000, hours, minutes, secondes,
                            training_schedule(cur_img),
                            training_accuracy_ * 100,
                            validation_accuracy_ * 100))
            # Save logs
            if cur_img % config.img_per_summary == 0:
                feed_dict_saver = {
                    loss_saver: loss_,
                    training_accuracy_saver: training_accuracy_,
                    validation_accuracy_saver: validation_accuracy_,
                    lr: training_schedule(cur_img)
                }
                saver.save_summary(sess, feed_dict_saver, cur_img)
            # Save model
            if cur_img % config.img_per_save == 0:
                saver.save_model(sess, cur_img // 1000)
            # Update current image
            cur_img += config.minibatch_size
        # Final Saving
        saver.save_model(sess, 'final')
        saver.close_summary()
    print('Done: training')
Exemplo n.º 4
0
def train_classification(mode: str = 'classification',
                         use_adaptive_loss: bool = False) -> None:
    """
    Train the classification or the recognition network.

    Builds a tf.data input pipeline (feedable string-handle iterator so a
    single graph serves both training and validation datasets), the
    network from the ``network`` module, an Adam optimizer and savers,
    then runs the training loop driven by the module-level ``config``
    object.

    Args:
        mode: one of ['classification', 'recognition'].  Selects which
            network builder is used and how validation accuracy is
            computed (argmax accuracy vs. pair-based embedding
            evaluation).
        use_adaptive_loss: currently unused — the recognition branch
            hard-codes use_adaptive_loss=False (see NOTE below).
    """
    assert mode in ['classification', 'recognition']
    # Graph definition: inputs, model, loss, optimizer, initializer
    print('Graph building...')
    with tf.compat.v1.variable_scope('Dataset'):
        # Init dataset
        training_dataset = getattr(dataset, config.data_initilize)(minibatch_size=config.minibatch_size, shuffle=True, **config.data_initilize_kwargs)
        validation_dataset = getattr(dataset, config.data_test_initilize)(minibatch_size=config.minibatch_size, shuffle=False, repeat=False, **config.data_test_initilize_kwargs)
    # Inputs: a feedable-handle iterator lets the same get_next() tensors
    # read from either dataset depending on the handle fed at run time.
    handle = tf.placeholder(tf.string, shape=[], name='handle_input')
    iterator = tf.data.Iterator.from_string_handle(handle, training_dataset.output_types, training_dataset.output_shapes)
    # image_ins, label_ins = iterator.get_next()
    image_inputs, label_inputs = iterator.get_next()
    image_inputs.set_shape([None,config.image_size,config.image_size,3])
    label_inputs.set_shape([None,1])
    label_inputs = tf.reshape(label_inputs, [-1])
    label_inputs = tf.cast(label_inputs, tf.int64)
    # Rescale pixels from [0, 255] to [-1, 1]
    image_inputs = tf.cast(image_inputs, dtype=tf.float32)/127.5 - 1
    # Network parameters
    nbof_labels = config.nbof_labels
    training = tf.placeholder(tf.bool, shape=[], name='training')
    regularizer_rate=config.regularizer_rate
    # Augmentation
    # if mode=='classification':
    # aug = lambda: dataset.augment_image(
    #     image_inputs,
    #     config.minibatch_size,
    #     use_horizontal_flip=(mode=='classification'),
    #     rotation_rate   =0.3,
    #     translation_rate=0.1,
    #     cutout_size     =25,
    #     crop_pixels     =10)
    # image_inputs = tf.cond(training, aug, lambda: image_inputs)
    # Network
    if mode=='classification':
        logit,loss,reg_loss = network.classification(image_inputs, label_inputs, nbof_labels, training, regularizer_rate)
    else:
        emb_size = config.emb_size
        # NOTE(review): use_adaptive_loss is hard-coded to False here,
        # ignoring the function argument — confirm whether this is
        # intentional or a leftover.
        emb,logit,loss,reg_loss = network.recognition(image_inputs, label_inputs, emb_size, nbof_labels, training, use_adaptive_loss=False, regularizer_rate=regularizer_rate)
        # emb,loss,reg_loss = network.recognition(image_inputs, label_inputs, emb_size, nbof_labels, training, use_adaptive_loss=use_adaptive_loss, regularizer_rate=regularizer_rate)
    # Optimizer
    # optimizer           = tf.compat.v1.train.MomentumOptimizer(learning_rate=config.learning_rate, momentum=0.9, use_nesterov=True)
    lr                  = tf.compat.v1.placeholder(tf.float32, shape=[], name='lr')
    optimizer           = tf.compat.v1.train.AdamOptimizer(learning_rate=lr, **config.optimzer_kwargs)
    var_list            = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=config.network)
    assert len(var_list)>0
    # Training operations
    training_op         = optimizer.minimize(loss+reg_loss, var_list=var_list, name='training_op')
    extra_training_op   = tf.compat.v1.get_collection(tf.GraphKeys.UPDATE_OPS) # For batch normalization
    # Initializer
    init                = tf.compat.v1.global_variables_initializer()
    print('Done: graph built.')

    # Define the minibatch selector: training repeats forever (one-shot),
    # validation is re-initialized before each evaluation pass.
    with tf.compat.v1.variable_scope('Dataset'):
        training_iterator = training_dataset.make_one_shot_iterator()
        validation_iterator = validation_dataset.make_initializable_iterator()

    # Saver
    with tf.compat.v1.variable_scope('Saver'):
        if mode == 'classification':
            pred = tf.nn.softmax(logit)
            training_output = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred, axis=1), label_inputs), dtype=tf.float32))
        else:
            # Same argmax accuracy as classification (the triplet-style
            # accuracy is commented out below).
            pred = tf.nn.softmax(logit)
            training_output = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred, axis=1), label_inputs), dtype=tf.float32))
            # training_output = losses.general_triplet_accuracy(emb, label_inputs, nbof_labels)
        training_accuracy = tf.compat.v1.placeholder(tf.float32, shape=[], name='training_accuracy_placeholder')
        validation_accuracy = tf.compat.v1.placeholder(tf.float32, shape=[], name='validation_accuracy_placeholder')
        saver = train.Saver(var_list,config.logs_path, summary_dict={'loss':loss, 'reg_loss':reg_loss, 'training_accuracy':training_accuracy, 'validation_accuracy':validation_accuracy, 'lr':lr}, restore=None)

    # Time measurements and running-average accumulators for the
    # training accuracy reported between summaries.
    init_time = time()
    train_acc = 0
    img_per_train_acc = 0
    saved_train_acc = 0
    valid_acc = 0
    # Training --> use the configuration file
    print('Training...')
    with tf.compat.v1.Session() as sess:
        # Initialize
        init.run()
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        # Restore former parameters
        if config.restore:
            print('Restoring weight stored in {}'.format(config.restore))
            saver.restore(sess, config.restore)

        # Training ...
        cur_img = config.start_img
        while cur_img < config.end_img:
            # Training operation
            feed_dict = {handle: training_handle, lr:training_schedule(cur_img), training:True}

            # Debugging helper (kept for reference): visualize a
            # validation minibatch with matplotlib.
            # feed_dict = {handle: validation_handle, lr:training_schedule(cur_img), training:False}
            # import matplotlib.pyplot as plt
            # sess.run(validation_iterator.initializer)
            # for i in range(1):
            #     imgs, labs = sess.run([image_inputs,label_inputs],feed_dict=feed_dict)
            # print(labs)
            # fig = plt.figure()
            # fig.set_size_inches(1, 1, forward=False)
            # ax = plt.Axes(fig, [0., 0., 1., 1.])
            # ax.set_axis_off()
            # fig.add_axes(ax)
            # ax.imshow(imgs[1]/2+0.5)
            # n = int(np.sqrt(len(imgs)))
            # for i in range(n):
            #     for j in range(n):
            #         plt.subplot(n,n,i*n+j+1)
            #         plt.imshow(imgs[i*n+j]/2 + 0.5)
            # plt.show()
            # return 0

            _,_, train_out = sess.run([training_op, extra_training_op, training_output], feed_dict=feed_dict)

            # Accumulate accuracy for the running average
            train_acc += train_out
            img_per_train_acc += 1

            # Validation: exhaust the (non-repeating) validation dataset,
            # collecting predictions/embeddings until OutOfRangeError.
            if cur_img % config.img_per_val == 0:
                pred_validation = np.empty((0,nbof_labels) if mode=='classification' else (0,emb_size))
                labels_test = np.empty((0))
                sess.run(validation_iterator.initializer)
                while True:
                    try:
                        pred_, labs = sess.run([pred, label_inputs] if mode=='classification' else [emb, label_inputs],
                            feed_dict={handle: validation_handle, training:False})
                        pred_validation = np.append(pred_validation, pred_, axis=0)
                        labels_test = np.append(labels_test,labs)
                    except tf.errors.OutOfRangeError:
                        break
                if mode=='classification':
                    valid_acc = np.mean(np.equal(np.argmax(pred_validation, axis=1), labels_test))
                else:
                    # labels_test[::2] — presumably validation samples come
                    # in consecutive pairs and evaluate() takes one label
                    # per pair; TODO confirm against evaluate.evaluate.
                    _, _, acc, val, _, far=evaluate.evaluate(pred_validation, labels_test[::2])
                    valid_acc = np.mean(acc)
                feed_dict_saver = {handle: training_handle, validation_accuracy:valid_acc, training_accuracy:saved_train_acc, training:False, lr:training_schedule(cur_img)}
                saver.save_summary(sess, feed_dict_saver, cur_img)
                # Display information
                graph_time  = time() - init_time
                hours       = int(graph_time // 3600)
                minutes     = int(graph_time // 60) - hours*60
                secondes    = graph_time - minutes * 60 - hours * 3600
                print('{:4d} kimgs, {}h {:2d}m {:2.1f}s, lr:{:0.6f}, training: {:2.2f}%, validation: {:2.2f}%'.format(cur_img//1000, hours, minutes, secondes, training_schedule(cur_img), saved_train_acc*100, valid_acc*100))
            # Save logs: flush the running-average training accuracy
            if cur_img % config.img_per_summary == 0:
                assert img_per_train_acc!=0
                train_acc /= img_per_train_acc
                feed_dict_saver = {handle: training_handle, validation_accuracy:valid_acc, training_accuracy:train_acc, training:False, lr:training_schedule(cur_img)}
                saver.save_summary(sess, feed_dict_saver, cur_img)
                saved_train_acc = train_acc
                train_acc = 0
                img_per_train_acc = 0
            # Save model
            if cur_img % config.img_per_save == 0:
                saver.save_model(sess, cur_img//1000)
            # Update current image
            cur_img += config.minibatch_size
        # Final Saving
        saver.save_model(sess, 'final')
        saver.close_summary()
        sess.close()
    print('Done: training')