def run_vgg_training(model_name, data_directory, path_to_train_file,
                     path_to_val_file, path_to_labels_file, bsize, num_steps,
                     train_log_dir, optimizer, initial_checkpoint):
    """Train a slim VGG classifier on queue-fed batches with slim.learning.

    A training network and a weight-sharing validation network are built in a
    fresh graph. Training accuracy, loss, end-point activations and model
    variables are merged into the summary op handed to
    ``slim.learning.train``, which checkpoints into ``train_log_dir``.

    Args:
        model_name: 'vgg_16' or 'vgg_19'; also used as the enclosing
            variable scope name.
        data_directory: forwarded to ``gen_dict`` together with
            ``path_to_labels_file``; only the returned class count is used.
        path_to_train_file: passed to ``getImage`` for the training pipeline.
        path_to_val_file: passed to ``getImage`` for the validation pipeline.
        path_to_labels_file: forwarded to ``gen_dict``.
        bsize: batch size for both the training and validation queues.
        num_steps: ``number_of_steps`` for ``slim.learning.train``.
        train_log_dir: log/checkpoint directory for ``slim.learning.train``.
        optimizer: optimizer handed to ``slim.learning.create_train_op``.
        initial_checkpoint: currently unused.

    Raises:
        ValueError: if ``model_name`` is not 'vgg_16' or 'vgg_19'.
    """
    model_fns = {'vgg_16': vgg.vgg_16, 'vgg_19': vgg.vgg_19}
    if model_name not in model_fns:
        raise ValueError('Unsupported VGG model: %r' % (model_name,))
    model_fn = model_fns[model_name]

    # NOTE(review): the classifier head is fixed at 2 classes, as before;
    # presumably this should equal `nclass` from gen_dict -- confirm.
    num_classes = 2
    log_every_n_steps = 50

    graph = tf.Graph()
    with graph.as_default():
        _, nclass = gen_dict(data_directory, path_to_labels_file)

        label, image = getImage(path_to_train_file, nclass)
        vlabel, vimage = getImage(path_to_val_file, nclass)

        imageBatch, labelBatch = tf.train.shuffle_batch([image, label],
                                                        batch_size=bsize,
                                                        capacity=2000,
                                                        min_after_dequeue=1000)
        vimageBatch, vlabelBatch = tf.train.shuffle_batch(
            [vimage, vlabel],
            batch_size=bsize,
            capacity=2000,
            min_after_dequeue=1000)

        # Summaries already registered (e.g. by the input queues).
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Feed the batch *tensors* into the networks so every step dequeues a
        # new batch. Pre-fetching one numpy batch would freeze it into the
        # graph as a constant. slim.learning.train starts the queue runners
        # in the session it manages.
        with tf.variable_scope(model_name) as scope:
            logits, end_points = model_fn(imageBatch,
                                          num_classes=num_classes,
                                          is_training=True)
            scope.reuse_variables()
            # Shares the training weights; is_training=False disables dropout.
            vlogits, _ = model_fn(vimageBatch,
                                  num_classes=num_classes,
                                  is_training=False)

        predictions = tf.nn.softmax(logits)
        predictions_validation = tf.nn.softmax(vlogits)

        correct_prediction = tf.equal(tf.argmax(predictions, 1),
                                      tf.argmax(labelBatch, 1))
        vcorrect_prediction = tf.equal(tf.argmax(predictions_validation, 1),
                                       tf.argmax(vlabelBatch, 1))

        # Fraction of correct predictions in a batch (higher is better).
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        accuracy_validation = tf.reduce_mean(
            tf.cast(vcorrect_prediction, tf.float32))

        logits = tf.reshape(logits, [bsize, num_classes])
        tf.losses.softmax_cross_entropy(labelBatch, logits)
        total_loss = tf.losses.get_total_loss()

        train_tensor = slim.learning.create_train_op(total_loss, optimizer)

        def train_step_fn(sess, *args, **kwargs):
            """Run one slim training step and periodically report accuracy."""
            step_loss, should_stop = train_step(sess, *args, **kwargs)
            if train_step_fn.step % log_every_n_steps == 0:
                # Each run dequeues fresh training/validation batches.
                train_acc, val_acc = sess.run(
                    [train_step_fn.accuracy,
                     train_step_fn.accuracy_validation])
                print('Step %s - Loss: %.2f Accuracy: %.2f%% '
                      'Validation Accuracy: %.2f%%' %
                      (str(train_step_fn.step).rjust(6, '0'), step_loss,
                       train_acc * 100, val_acc * 100))
            train_step_fn.step += 1
            return [step_loss, should_stop]

        train_step_fn.step = 0
        train_step_fn.accuracy = accuracy
        train_step_fn.accuracy_validation = accuracy_validation

        summaries.add(tf.summary.scalar('accuracy', accuracy))
        summaries.add(tf.summary.scalar('loss', total_loss))
        for end_point, activation in end_points.items():
            summaries.add(
                tf.summary.histogram('activations/' + end_point, activation))
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(train_tensor,
                            train_log_dir,
                            number_of_steps=num_steps,
                            summary_op=summary_op,
                            train_step_fn=train_step_fn,
                            save_summaries_secs=20)
        print('completed training')


def run_inception_training(model_name, data_directory, path_to_train_file,
                           path_to_val_file, path_to_labels_file, bsize,
                           num_steps, train_log_dir, optimizer,
                           initial_checkpoint):
    """Train a slim Inception classifier on queue-fed batches with slim.learning.

    A training network and a weight-sharing validation network are built in a
    fresh graph. A TensorBoard projector config for an embedding variable is
    written to ``train_log_dir``. Training accuracy, loss, end-point
    activations and model variables are merged into the summary op handed to
    ``slim.learning.train``.

    Args:
        model_name: 'inception_v1', 'inception_v2' or 'inception_v3'.
        data_directory: forwarded to ``gen_dict`` together with
            ``path_to_labels_file``; only the returned class count is used.
        path_to_train_file: passed to ``getImage`` for the training pipeline.
        path_to_val_file: passed to ``getImage`` for the validation pipeline.
        path_to_labels_file: forwarded to ``gen_dict``.
        bsize: batch size for both the training and validation queues.
        num_steps: ``number_of_steps`` for ``slim.learning.train``.
        train_log_dir: log/checkpoint directory for ``slim.learning.train``.
        optimizer: optimizer handed to ``slim.learning.create_train_op``.
        initial_checkpoint: currently unused.

    Raises:
        ValueError: if ``model_name`` is not a supported Inception variant.
    """
    # model_name -> (enclosing variable scope, network constructor)
    model_fns = {
        'inception_v1': ('InceptionV1', inception.inception_v1),
        'inception_v2': ('InceptionV2', inception.inception_v2),
        'inception_v3': ('InceptionV3', inception.inception_v3),
    }
    if model_name not in model_fns:
        raise ValueError('Unsupported Inception model: %r' % (model_name,))
    scope_name, model_fn = model_fns[model_name]

    # NOTE(review): the classifier head is fixed at 2 classes, as before;
    # presumably this should equal `nclass` from gen_dict -- confirm.
    num_classes = 2
    log_every_n_steps = 50

    graph = tf.Graph()
    with graph.as_default():
        _, nclass = gen_dict(data_directory, path_to_labels_file)

        label, image = getImage(path_to_train_file, nclass)
        vlabel, vimage = getImage(path_to_val_file, nclass)

        imageBatch, labelBatch = tf.train.shuffle_batch([image, label],
                                                        batch_size=bsize,
                                                        capacity=2000,
                                                        min_after_dequeue=1000)
        vimageBatch, vlabelBatch = tf.train.shuffle_batch(
            [vimage, vlabel],
            batch_size=bsize,
            capacity=2000,
            min_after_dequeue=1000)

        # Summaries already registered (e.g. by the input queues).
        summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

        # Feed the batch *tensors* into the networks so every step dequeues a
        # new batch. Pre-fetching one numpy batch would freeze it into the
        # graph as a constant. slim.learning.train starts the queue runners
        # in the session it manages.
        with tf.variable_scope(scope_name) as scope:
            logits, end_points = model_fn(
                imageBatch, num_classes=num_classes, is_training=True)
            scope.reuse_variables()
            # Shares the training weights; is_training=False switches the
            # network to inference mode for validation.
            vlogits, vend_points = model_fn(
                vimageBatch, num_classes=num_classes, is_training=False)

        # Inception nets expose softmax outputs as end_points['Predictions'].
        predictions = end_points['Predictions']
        predictions_validation = vend_points['Predictions']

        correct_prediction = tf.equal(tf.argmax(predictions, 1),
                                      tf.argmax(labelBatch, 1))
        vcorrect_prediction = tf.equal(tf.argmax(predictions_validation, 1),
                                       tf.argmax(vlabelBatch, 1))

        # Fraction of correct predictions in a batch (higher is better).
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        accuracy_validation = tf.reduce_mean(
            tf.cast(vcorrect_prediction, tf.float32))

        logits = tf.reshape(logits, [bsize, num_classes])
        tf.losses.softmax_cross_entropy(labelBatch, logits)
        total_loss = tf.losses.get_total_loss()

        train_tensor = slim.learning.create_train_op(total_loss, optimizer)

        # TensorBoard projector setup. The embedding variable is created after
        # create_train_op, so the optimizer does not see it.
        # NOTE(review): Embedding_Tensor starts as zeros and nothing here
        # assigns to it, so the projector shows an all-zero embedding. The
        # former `embedding.assign()` call had no value argument and raised a
        # TypeError; populate it from a real feature tensor if needed.
        embedding_size = 1024
        embedding = tf.Variable(tf.zeros([bsize, embedding_size]),
                                name="Embedding_Tensor")
        writer = tf.summary.FileWriter(train_log_dir + '/', graph)
        config = tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
        embedding_config = config.embeddings.add()
        embedding_config.tensor_name = embedding.name
        embedding_config.sprite.image_path = train_log_dir + '/sprite.png'
        embedding_config.metadata_path = train_log_dir + '/metadata.tsv'
        print(embedding_config.sprite.image_path)
        print(embedding_config.metadata_path)
        # Width and height of a single sprite thumbnail.
        embedding_config.sprite.single_image_dim.extend([224, 224])
        tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
            writer, config)
        writer.close()

        def train_step_fn(sess, *args, **kwargs):
            """Run one slim training step and periodically report accuracy."""
            step_loss, should_stop = train_step(sess, *args, **kwargs)
            if train_step_fn.step % log_every_n_steps == 0:
                # Each run dequeues fresh training/validation batches.
                train_acc, val_acc = sess.run(
                    [train_step_fn.accuracy,
                     train_step_fn.accuracy_validation])
                print('Step %s - Loss: %.2f Accuracy: %.2f%% '
                      'Validation Accuracy: %.2f%%' %
                      (str(train_step_fn.step).rjust(6, '0'), step_loss,
                       train_acc * 100, val_acc * 100))
            train_step_fn.step += 1
            return [step_loss, should_stop]

        train_step_fn.step = 0
        train_step_fn.accuracy = accuracy
        train_step_fn.accuracy_validation = accuracy_validation

        # Add the scalars to `summaries` so they are part of summary_op.
        summaries.add(tf.summary.scalar('accuracy', accuracy))
        summaries.add(tf.summary.scalar('loss', total_loss))
        for end_point, activation in end_points.items():
            summaries.add(
                tf.summary.histogram('activations/' + end_point, activation))
        for variable in slim.get_model_variables():
            summaries.add(tf.summary.histogram(variable.op.name, variable))
        summary_op = tf.summary.merge(list(summaries), name='summary_op')

        slim.learning.train(train_tensor,
                            train_log_dir,
                            number_of_steps=num_steps,
                            summary_op=summary_op,
                            train_step_fn=train_step_fn,
                            save_summaries_secs=20)
        print('completed training')