Example #1
def main():
    args = parser.parse_args()

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    ######################
    # directory preparation
    filewriter_path = args.tensorboard_dir
    checkpoint_path = args.checkpoint_dir

    test_mkdir(filewriter_path)
    test_mkdir(checkpoint_path)

    ######################
    # data preparation
    train_file = os.path.join(args.list_dir, "train.txt")
    val_file = os.path.join(args.list_dir, "val.txt")

    train_generator = ImageDataGenerator(train_file, shuffle=True)
    val_generator = ImageDataGenerator(val_file, shuffle=False)

    batch_size = args.batch_size
    # Get the number of training/validation steps per epoch
    train_batches_per_epoch = train_generator.data_size // batch_size
    val_batches_per_epoch = val_generator.data_size // batch_size

    ######################
    # model graph preparation
    patch_height = args.patch_size
    patch_width = args.patch_size

    # TF placeholder for graph input
    leftx = tf.placeholder(tf.float32,
                           shape=[batch_size, patch_height, patch_width, 1])
    rightx_pos = tf.placeholder(
        tf.float32, shape=[batch_size, patch_height, patch_width, 1])
    rightx_neg = tf.placeholder(
        tf.float32, shape=[batch_size, patch_height, patch_width, 1])

    # Initialize model
    left_model = NET(leftx,
                     input_patch_size=patch_height,
                     batch_size=batch_size)
    right_model_pos = NET(rightx_pos,
                          input_patch_size=patch_height,
                          batch_size=batch_size)
    right_model_neg = NET(rightx_neg,
                          input_patch_size=patch_height,
                          batch_size=batch_size)

    featuresl = tf.squeeze(left_model.features, [1, 2])
    featuresr_pos = tf.squeeze(right_model_pos.features, [1, 2])
    featuresr_neg = tf.squeeze(right_model_neg.features, [1, 2])

    # Op for calculating cosine distance/dot product
    with tf.name_scope("correlation"):
        cosine_pos = tf.reduce_sum(tf.multiply(featuresl, featuresr_pos),
                                   axis=-1)
        cosine_neg = tf.reduce_sum(tf.multiply(featuresl, featuresr_neg),
                                   axis=-1)

    # Op for calculating the loss
    with tf.name_scope("hinge_loss"):
        margin = tf.ones(shape=[batch_size], dtype=tf.float32) * args.margin
        loss = tf.maximum(0.0, margin - cosine_pos + cosine_neg)
        loss = tf.reduce_mean(loss)

    # Train op
    with tf.name_scope("train"):
        var_list = tf.trainable_variables()
        for var in var_list:
            print("{}: {}".format(var.name, var.shape))
        # Get gradients of all trainable variables
        gradients = tf.gradients(loss, var_list)
        gradients = list(zip(gradients, var_list))

        # Create optimizer and apply gradient descent with momentum to the trainable variables
        optimizer = tf.train.MomentumOptimizer(args.learning_rate, args.beta)
        train_op = optimizer.apply_gradients(grads_and_vars=gradients)

    # summary Ops for tensorboard visualization
    with tf.name_scope("training_metric"):
        training_summary = []
        # Add loss to summary
        training_summary.append(tf.summary.scalar('hinge_loss', loss))

        # Merge all summaries together
        training_merged_summary = tf.summary.merge(training_summary)

    # validation loss
    with tf.name_scope("val_metric"):
        val_summary = []
        val_loss = tf.placeholder(tf.float32, [])

        # Add val loss to summary
        val_summary.append(tf.summary.scalar('val_hinge_loss', val_loss))
        val_merged_summary = tf.summary.merge(val_summary)

    # Initialize the FileWriter
    writer = tf.summary.FileWriter(filewriter_path)
    # Initialize a saver to store model checkpoints
    saver = tf.train.Saver(max_to_keep=10)

    ######################
    # Do training
    # Start Tensorflow session
    with tf.Session(config=tf.ConfigProto(
            log_device_placement=False,
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True))) as sess:

        # Initialize all variables
        sess.run(tf.global_variables_initializer())

        # resume from checkpoint or not
        if args.resume is None:
            # Add the model graph to TensorBoard before initial training
            writer.add_graph(sess.graph)
        else:
            saver.restore(sess, args.resume)

        print "training_batches_per_epoch: {}, val_batches_per_epoch: {}.".format(\
                train_batches_per_epoch, val_batches_per_epoch)
        print("{} Start training...".format(datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(
            datetime.now(), filewriter_path))

        # Loop training
        for epoch in range(args.start_epoch, args.end_epoch):
            print("{} Epoch number: {}".format(datetime.now(), epoch + 1))

            for batch in tqdm(range(train_batches_per_epoch)):
                # Get a batch of data
                batch_left, batch_right_pos, batch_right_neg = train_generator.next_batch(
                    batch_size)

                # And run the training op
                sess.run(train_op,
                         feed_dict={
                             leftx: batch_left,
                             rightx_pos: batch_right_pos,
                             rightx_neg: batch_right_neg
                         })

                # Generate summary with the current batch of data and write to file
                if (batch + 1) % args.print_freq == 0:
                    s = sess.run(training_merged_summary,
                                 feed_dict={
                                     leftx: batch_left,
                                     rightx_pos: batch_right_pos,
                                     rightx_neg: batch_right_neg
                                 })
                    writer.add_summary(s,
                                       epoch * train_batches_per_epoch + batch)

            if (epoch + 1) % args.save_freq == 0:
                print("{} Saving checkpoint of model...".format(
                    datetime.now()))
                # save checkpoint of the model
                checkpoint_name = os.path.join(
                    checkpoint_path, 'model_epoch' + str(epoch + 1) + '.ckpt')
                save_path = saver.save(sess, checkpoint_name)

            if (epoch + 1) % args.val_freq == 0:
                # Validate the model on the entire validation set
                print("{} Start validation".format(datetime.now()))
                val_ls = 0.
                for _ in tqdm(range(val_batches_per_epoch)):
                    batch_left, batch_right_pos, batch_right_neg = val_generator.next_batch(
                        batch_size)
                    result = sess.run(loss,
                                      feed_dict={
                                          leftx: batch_left,
                                          rightx_pos: batch_right_pos,
                                          rightx_neg: batch_right_neg
                                      })
                    val_ls += result

                val_ls = val_ls / (1. * val_batches_per_epoch)

                print('validation loss: {}'.format(val_ls))
                s = sess.run(val_merged_summary,
                             feed_dict={val_loss: np.float32(val_ls)})
                writer.add_summary(s, train_batches_per_epoch * (epoch + 1))

            # Reset the file pointer of the image data generator
            val_generator.reset_pointer()
            train_generator.reset_pointer()
        print("{} Start validation".format(datetime.now()))
        test_acc = 0.
        test_count = 0
        for _ in range(val_batches_per_epoch):
            batch_tx, batch_ty = val_generator.next_batch(batch_size)
            acc = sess.run(accuracy,
                           feed_dict={
                               x: batch_tx,
                               y: batch_ty,
                               keep_prob: 1.
                           })
            test_acc += acc
            test_count += 1
        test_acc /= test_count
        print("{} Validation Accuracy = {:.4f}".format(datetime.now(),
                                                       test_acc))

        # Reset the file pointer of the image data generator
        val_generator.reset_pointer()
        train_generator.reset_pointer()

        print("{} Saving checkpoint of model...".format(datetime.now()))

        # save checkpoint of the model
        checkpoint_name = os.path.join(
            checkpoint_path, 'model_epoch' + str(epoch + 1) + '.ckpt')
        save_path = saver.save(sess, checkpoint_name)

        print("{} Model checkpoint saved at {}".format(datetime.now(),
                                                       checkpoint_name))
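
Example #1 relies on a test_mkdir helper for the directory-preparation step that is not shown. A minimal sketch, assuming it simply creates the directory when it does not already exist:

import os

def test_mkdir(path):
    # Create the directory if it does not exist yet
    if not os.path.exists(path):
        os.makedirs(path)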
Example #3
def train(batch_size, learning_rate, conv, fc, dropout_rate, additional):
    x = tf.placeholder(tf.float32, [batch_size, input_size, input_size, 3])
    y = tf.placeholder(tf.float32, [None, num_classes])
    keep_prob = tf.placeholder(tf.float32)

    # If fewer convolutional layers are used, the weights of the first fully
    # connected layer also need to be learned
    if conv < 5:
        train_layers.append('fc6')

    # Load the model with the desired parameters
    model = AlexNet(x, keep_prob, num_classes, train_layers, fc, conv, additional)
    score = model.fc8

    # List of trainable variables of the layers we want to train
    var_list = [v for v in tf.trainable_variables() if v.name.split('/')[0] in
                train_layers]

    # Op for calculating the loss
    with tf.name_scope('cross_entropy'):
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=score,
                              labels=y))

    # Train op
    with tf.name_scope('train'):
        gradients = tf.gradients(loss, var_list)
        gradients = list(zip(gradients, var_list))

        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.apply_gradients(grads_and_vars=gradients)

    # Add gradients to summary
    for gradient, var in gradients:
        tf.summary.histogram(var.name + '/gradient', gradient)

    # Add the variables we train to the summary
    for var in var_list:
        tf.summary.histogram(var.name, var)

    # Add the loss to the summary
    tf.summary.scalar('cross_entropy', loss)

    # Op for the accuracy of the model
    with tf.name_scope('accuracy'):
        correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    tf.summary.scalar('accuracy', accuracy)

    # Merge all summaries together
    merged_summary = tf.summary.merge_all()

    # Initialize the FileWriter
    writer = tf.summary.FileWriter(filewriter_path)

    saver = tf.train.Saver()

    train_generator = ImageDataGenerator(train_file, horizontal_flip=True,
                                         shuffle=True)
    val_generator = ImageDataGenerator(val_file, shuffle=False)

    # Get the number of training / validation steps per epoch
    train_batches_per_epoch = np.floor(train_generator.data_size /
                                       batch_size).astype(int)
    val_batches_per_epoch = np.floor(val_generator.data_size / batch_size).astype(int)

    # sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    # saver = tf.train.import_meta_graph('/tmp/finetune_alexnet/model_epoch0.ckpt.meta')

    # saver.restore(sess, '/tmp/finetune_alexnet/model_epoch0.ckpt')
    
    # Comment out the line below and uncomment the ones above to resume from a
    # checkpoint for faster training
    model.load_initial_weights(sess)
    print("{} Loaded initial weights...".format(datetime.now()))

    test_acc_prev = 0
    test_count = 0
    for _ in range(val_batches_per_epoch):
        batch_tx, batch_ty = val_generator.next_batch(batch_size)
        acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, keep_prob: 1.})
        test_acc_prev += acc
        test_count += 1

    test_acc_prev /= test_count
    test_acc = test_acc_prev
    print("{} Initial Validation Accuracy = {:.4f}".format(datetime.now(), test_acc_prev))
    writer.add_graph(sess.graph)

    print('{} Start training...'.format(datetime.now()))
    print('{} Open TensorBoard at --logdir {}'.format(datetime.now(),
          filewriter_path))

    for epoch in range(num_epochs):
        print('{} Epoch number: {}'.format(datetime.now(), epoch + 1))

        step = 1

        while step < train_batches_per_epoch:
            print('{} Step number: {}'.format(datetime.now(), step))
            batch_xs, batch_ys = train_generator.next_batch(batch_size)
            sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys,
                                          keep_prob: dropout_rate})
            if step % display_step == 0:
                s = sess.run(merged_summary, feed_dict={x: batch_xs,
                             y: batch_ys, keep_prob: dropout_rate})
                writer.add_summary(s, epoch * train_batches_per_epoch + step)

            step += 1
    
        # Validate the model on the entire validation set
        print("{} Start validation".format(datetime.now()))
        test_acc = 0.
        test_count = 0
        for ind in range(val_batches_per_epoch):
            print('{} Valid batch number: {}'.format(datetime.now(), ind))
            batch_tx, batch_ty = val_generator.next_batch(batch_size)
            acc = sess.run(accuracy, feed_dict={x: batch_tx,
                                                y: batch_ty,
                                                keep_prob: 1.})
            test_acc += acc
            test_count += 1
        test_acc /= test_count
        print("{} Validation Accuracy = {:.4f}".format(datetime.now(), test_acc))

        if test_acc > test_acc_prev:
            print("{} Saving checkpoint of model...".format(datetime.now()))

            # save checkpoint of the model
            checkpoint_name = os.path.join(checkpoint_path, 'model_epoch'+str(epoch)+'.ckpt')
            save_path = saver.save(sess, checkpoint_name)

            print("{} Model checkpoint saved at {}".format(datetime.now(), checkpoint_name))

            if abs(test_acc - test_acc_prev) < 0.05:
                print("Early stopping.... exiting")
                break
    
            test_acc_prev = test_acc

        # Reset the file pointer of the image data generator
        val_generator.reset_pointer()
        train_generator.reset_pointer()
    return test_acc
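
All of these examples lean on an ImageDataGenerator class with a common interface: data_size, labels, next_batch(), and reset_pointer(). A minimal sketch of that interface, assuming a list file with one "path label" pair per line (the triplet variant in Example #1 would return three image batches from next_batch() instead):

import numpy as np

class ImageDataGenerator(object):
    # Hypothetical minimal stand-in for the generator these examples assume;
    # the real class also decodes, resizes, and mean-subtracts the images.
    def __init__(self, class_list, shuffle=False, horizontal_flip=False):
        with open(class_list) as f:
            lines = [line.split() for line in f if line.strip()]
        self.images = [l[0] for l in lines]
        self.labels = [int(l[1]) for l in lines]
        self.data_size = len(self.labels)
        self.n_classes = max(self.labels) + 1
        self.shuffle = shuffle
        self.pointer = 0
        if shuffle:
            self.shuffle_data()

    def shuffle_data(self):
        idx = np.random.permutation(self.data_size)
        self.images = [self.images[i] for i in idx]
        self.labels = [self.labels[i] for i in idx]

    def reset_pointer(self):
        # Rewind to the start of the epoch, reshuffling if requested
        self.pointer = 0
        if self.shuffle:
            self.shuffle_data()

    def next_batch(self, batch_size):
        # A real implementation loads and preprocesses the images here;
        # this sketch only returns the paths and one-hot labels
        paths = self.images[self.pointer:self.pointer + batch_size]
        labels = self.labels[self.pointer:self.pointer + batch_size]
        self.pointer += batch_size
        one_hot = np.zeros((len(labels), self.n_classes), np.float32)
        one_hot[np.arange(len(labels)), labels] = 1.
        return paths, one_hot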
Example #4
                                      y: batch_ty,
                                      keep_var: 1.
                                  })
            #print pre_labels
            test_pre_label += pre_labels.tolist()
            #test_true_label += tf.argmax(batch_ty,1).tolist()
        #pdb.set_trace()
        test_pre_label += rest_test_pre_label
        test_acc = accuracy_score(val_generator.labels, test_pre_label)
        print("{} Iter {}: Testing Accuracy = {:.4f}".format(
            datetime.now(), step, test_acc))
        #all_test_acc.append(test_acc)

        print("F1 score = {:.4f}".format(
            f1_score(val_generator.labels, test_pre_label, average='macro')))

        print("Confusionmatrix =")
        print(" {} ".format(
            confusion_matrix(val_generator.labels, test_pre_label)))

        # Reset the file pointer of the image data generator
        val_generator.reset_pointer()
        train_generator.reset_pointer()

        #save model
        temp = '%05d' % step

        model_path_1 = model_path + temp + '.ckpt'
        #save_path = saver.save(sess, model_path_1)
        #print("Model saved in file: {}".format(save_path))
Example #5
def main(_):
    # Create training directories
    now = datetime.datetime.now()
    train_dir_name = now.strftime('resnet_%Y%m%d_%H%M%S')
    train_dir = os.path.join(FLAGS.tensorboard_root_dir, train_dir_name)
    checkpoint_dir = os.path.join(train_dir, 'checkpoint')
    tensorboard_dir = os.path.join(train_dir, 'tensorboard')
    tensorboard_train_dir = os.path.join(tensorboard_dir, 'train')
    tensorboard_val_dir = os.path.join(tensorboard_dir, 'val')

    if not os.path.isdir(FLAGS.tensorboard_root_dir): os.mkdir(FLAGS.tensorboard_root_dir)
    if not os.path.isdir(train_dir): os.mkdir(train_dir)
    if not os.path.isdir(checkpoint_dir): os.mkdir(checkpoint_dir)
    if not os.path.isdir(tensorboard_dir): os.mkdir(tensorboard_dir)
    if not os.path.isdir(tensorboard_train_dir): os.mkdir(tensorboard_train_dir)
    if not os.path.isdir(tensorboard_val_dir): os.mkdir(tensorboard_val_dir)

    # Write flags to txt
    flags_file_path = os.path.join(train_dir, 'flags.txt')
    with open(flags_file_path, 'w') as flags_file:
        flags_file.write('learning_rate={}\n'.format(FLAGS.learning_rate))
        flags_file.write('resnet_depth={}\n'.format(FLAGS.resnet_depth))
        flags_file.write('num_epochs={}\n'.format(FLAGS.num_epochs))
        flags_file.write('batch_size={}\n'.format(FLAGS.batch_size))
        flags_file.write('train_layers={}\n'.format(FLAGS.train_layers))
        flags_file.write('multi_scale={}\n'.format(FLAGS.multi_scale))
        flags_file.write('tensorboard_root_dir={}\n'.format(FLAGS.tensorboard_root_dir))
        flags_file.write('log_step={}\n'.format(FLAGS.log_step))

    # Placeholders
    x = tf.placeholder(tf.float32, [FLAGS.batch_size, 224, 224, 3])
    y = tf.placeholder(tf.float32, [None, FLAGS.num_classes])
    is_training = tf.placeholder('bool', [])

    # Model
    train_layers = FLAGS.train_layers.split(',')
    model = ResNetModel(is_training, depth=FLAGS.resnet_depth, num_classes=FLAGS.num_classes)
    loss = model.loss(x, y)
    train_op = model.optimize(FLAGS.learning_rate, train_layers)

    # Training accuracy of the model
    correct_pred = tf.equal(tf.argmax(model.prob, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Summaries
    tf.summary.scalar('train_loss', loss)
    tf.summary.scalar('train_accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    train_writer = tf.summary.FileWriter(tensorboard_train_dir)
    val_writer = tf.summary.FileWriter(tensorboard_val_dir)
    saver = tf.train.Saver()

    # Batch preprocessors
    multi_scale = FLAGS.multi_scale.split(',')
    if len(multi_scale) == 2:
        multi_scale = [int(multi_scale[0]), int(multi_scale[1])]
    else:
        multi_scale = None
    # Initialize the data generators separately for the training and validation set
    train_generator = ImageDataGenerator(FLAGS.training_file,
                                         horizontal_flip=True, shuffle=True)
    val_generator = ImageDataGenerator(FLAGS.val_file, shuffle=False)

    #train_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.training_file, num_classes=FLAGS.num_classes,
    #                                       output_size=[224, 224], horizontal_flip=True, shuffle=True, multi_scale=multi_scale)
    #val_preprocessor = BatchPreprocessor(dataset_file_path=FLAGS.val_file, num_classes=FLAGS.num_classes, output_size=[224, 224])
    # Get the number of training/validation steps per epoch
    train_batches_per_epoch = np.floor(train_generator.data_size / FLAGS.batch_size).astype(np.int16)
    val_batches_per_epoch = np.floor(val_generator.data_size / FLAGS.batch_size).astype(np.int16)

    # Get the number of training/validation steps per epoch
    #train_batches_per_epoch = np.floor(len(train_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
    #val_batches_per_epoch = np.floor(len(val_preprocessor.labels) / FLAGS.batch_size).astype(np.int16)
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.9)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_writer.add_graph(sess.graph)

        # Load the pretrained weights
        model.load_original_weights(sess, skip_layers=train_layers)

        # Directly restore (your model should be exactly the same with checkpoint)
        # saver.restore(sess, "/Users/dgurkaynak/Projects/marvel-training/alexnet64-fc6/model_epoch10.ckpt")

        print("{} Start training...".format(datetime.datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(datetime.datetime.now(), tensorboard_dir))

        for epoch in range(FLAGS.num_epochs):
            print("{} Epoch number: {}".format(datetime.datetime.now(), epoch+1))
            step = 1

            # Start training
            while step < train_batches_per_epoch:
                # Get a batch of images and labels
                batch_xs, batch_ys = train_generator.next_batch(FLAGS.batch_size)
                #batch_xs, batch_ys = train_preprocessor.next_batch(FLAGS.batch_size)
                sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys, is_training: True})

                # Logging
                if step % FLAGS.log_step == 0:
                    s = sess.run(merged_summary, feed_dict={x: batch_xs, y: batch_ys, is_training: False})
                    train_writer.add_summary(s, epoch * train_batches_per_epoch + step)

                step += 1

            # Epoch completed, start validation
            print("{} Start validation".format(datetime.datetime.now()))
            test_acc = 0.
            test_count = 0

            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_generator.next_batch(FLAGS.batch_size)
                #batch_tx, batch_ty = val_preprocessor.next_batch(FLAGS.batch_size)
                acc = sess.run(accuracy, feed_dict={x: batch_tx, y: batch_ty, is_training: False})
                test_acc += acc
                test_count += 1

            test_acc /= test_count
            s = tf.Summary(value=[
                tf.Summary.Value(tag="validation_accuracy", simple_value=test_acc)
            ])
            val_writer.add_summary(s, epoch+1)
            print("{} Validation Accuracy = {:.4f}".format(datetime.datetime.now(), test_acc))

            # Reset the file pointer of the image data generator
            val_generator.reset_pointer()
            train_generator.reset_pointer()

            print("{} Saving checkpoint of model...".format(datetime.datetime.now()))

            # save checkpoint of the model
            checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch'+str(epoch+1)+'.ckpt')
            save_path = saver.save(sess, checkpoint_path)

            print("{} Model checkpoint saved at {}".format(datetime.datetime.now(), checkpoint_path))
Example #6
def main(unused_argv):
    if FLAGS.job_name is None or FLAGS.job_name == '':
        raise ValueError('Must specify an explicit job_name')
    else:
        print('job_name: %s' % FLAGS.job_name)

    if FLAGS.task_index is None or FLAGS.task_index == '':
        raise ValueError('Must specify an explicit task_index')
    else:
        print('task_index: %s' % FLAGS.task_index)

    ps_spec = FLAGS.ps_hosts.split(',')
    worker_spec = FLAGS.worker_hosts.split(',')
    num_worker = len(worker_spec)

    cluster = tf.train.ClusterSpec({'ps': ps_spec, 'worker': worker_spec})

    kill_ps_queue = create_done_queue(num_worker)

    server = tf.train.Server(cluster,
                             job_name=FLAGS.job_name,
                             task_index=FLAGS.task_index)

    if FLAGS.job_name == 'ps':
        # server.join()
        with tf.Session(server.target) as sess:
            for i in range(num_worker):
                sess.run(kill_ps_queue.dequeue())
        return

    is_chief = (FLAGS.task_index == 0)

    if FLAGS.use_gpu:
        worker_device = '/job:worker/task:%d/gpu:%d' % (FLAGS.task_index,
                                                        FLAGS.gpu_id)
    else:
        worker_device = '/job:worker/task:%d/cpu:0' % FLAGS.task_index

    with tf.device(
            tf.train.replica_device_setter(worker_device=worker_device,
                                           ps_device='/job:ps/cpu:0',
                                           cluster=cluster)):

        global_step = tf.Variable(0, name='global_step', trainable=False)

        x = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x')
        y = tf.placeholder(tf.float32, [None, FLAGS.n_classes], name='y')

        keep_prob = tf.placeholder(tf.float32, name='kp')

        model = AlexNet(x, keep_prob, FLAGS.n_classes)

        score = model.fc3

        cross_entropy = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=score))

        tf.summary.scalar('loss', cross_entropy)

        opt = get_optimizer('Adam', FLAGS.learning_rate)

        if FLAGS.sync_replicas:
            replicas_to_aggregate = num_worker
            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=replicas_to_aggregate,
                total_num_replicas=num_worker,
                use_locking=False,
                name='sync_replicas')

        train_op = opt.minimize(cross_entropy, global_step=global_step)

        correct_prediction = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        if FLAGS.sync_replicas:
            local_init_op = opt.local_step_init_op
            if is_chief:
                local_init_op = opt.chief_init_op
            ready_for_local_init_op = opt.ready_for_local_init_op

            chief_queue_runner = opt.get_chief_queue_runner()
            init_token_op = opt.get_init_tokens_op()

        init_op = tf.global_variables_initializer()
        kill_ps_enqueue_op = kill_ps_queue.enqueue(1)

        summary_op = tf.summary.merge_all()
        writer = tf.summary.FileWriter(FLAGS.logdir)
        saver = tf.train.Saver()

        # train_dir = tempfile.mkdtemp()

        if FLAGS.sync_replicas:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=FLAGS.checkpoint,
                init_op=init_op,
                local_init_op=local_init_op,
                ready_for_local_init_op=ready_for_local_init_op,
                summary_op=summary_op,
                saver=saver,
                summary_writer=writer,
                recovery_wait_secs=1,
                global_step=global_step)
        else:
            sv = tf.train.Supervisor(is_chief=is_chief,
                                     logdir=FLAGS.checkpoint,
                                     init_op=init_op,
                                     recovery_wait_secs=1,
                                     summary_op=summary_op,
                                     saver=saver,
                                     summary_writer=writer,
                                     global_step=global_step)

        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     log_device_placement=False,
                                     device_filters=[
                                         '/job:ps',
                                         '/job:worker/task:%d' %
                                         FLAGS.task_index
                                     ])

        if is_chief:
            print('Worker %d: Initializing session...' % FLAGS.task_index)
        else:
            print('Worker %d: Waiting for session to be initialized...' %
                  FLAGS.task_index)

        sess = sv.prepare_or_wait_for_session(server.target,
                                              config=sess_config)
        print('Worker %d: Session initialization complete.' % FLAGS.task_index)

        if FLAGS.sync_replicas and is_chief:
            sess.run(init_token_op)
            sv.start_queue_runners(sess, [chief_queue_runner])

        train_generator = ImageDataGenerator(FLAGS.train_file,
                                             horizontal_flip=True,
                                             shuffle=True)
        val_generator = ImageDataGenerator(FLAGS.val_file, shuffle=False)

        # Get the number of training/validation steps per epoch
        train_batches_per_epoch = np.floor(train_generator.data_size /
                                           FLAGS.batch_size).astype(np.int16)
        val_batches_per_epoch = np.floor(val_generator.data_size /
                                         FLAGS.batch_size).astype(np.int16)

        print("{} Start training...".format(datetime.now()))
        print("{} Open Tensorboard at --logdir {}".format(
            datetime.now(), FLAGS.logdir))

        for epoch in range(FLAGS.num_epoches):

            print("{} Epoch number: {}".format(datetime.now(), epoch + 1))

            step = 1

            while step < train_batches_per_epoch:

                start_time = time.time()
                # Get a batch of images and labels

                batch_xs, batch_ys = train_generator.next_batch(
                    FLAGS.batch_size)

                # And run the training op
                _, loss, gstep = sess.run(
                    [train_op, cross_entropy, global_step],
                    feed_dict={
                        x: batch_xs,
                        y: batch_ys,
                        keep_prob: FLAGS.dropout
                    })

                print('total step: %d, loss: %f' % (gstep, loss))
                duration = time.time() - start_time

                # Generate summary with the current batch of data and write to file
                if step % FLAGS.display_step == 0:
                    s = sess.run(sv.summary_op,
                                 feed_dict={
                                     x: batch_xs,
                                     y: batch_ys,
                                     keep_prob: 1.
                                 })
                    writer.add_summary(s,
                                       epoch * train_batches_per_epoch + step)

                if step % 10 == 0:
                    print("[INFO] {} images trained; last step took {:.3f}s".format(
                        step * FLAGS.batch_size, duration))

                step += 1

            # Validate the model on the entire validation set
            print("{} Start validation".format(datetime.now()))
            test_acc = 0.
            test_count = 0
            for _ in range(val_batches_per_epoch):
                batch_tx, batch_ty = val_generator.next_batch(FLAGS.batch_size)
                acc = sess.run(accuracy,
                               feed_dict={
                                   x: batch_tx,
                                   y: batch_ty,
                                   keep_prob: 1.
                               })
                test_acc += acc
                test_count += 1
            test_acc /= test_count
            print("Validation Accuracy = {} {}".format(datetime.now(),
                                                       test_acc))

            # Reset the file pointer of the image data generator
            val_generator.reset_pointer()
            train_generator.reset_pointer()

            print("{} Saving checkpoint of model...".format(datetime.now()))

            # save checkpoint of the model
            checkpoint_name = os.path.join(
                FLAGS.checkpoint, 'model_epoch' + str(epoch + 1) + '.ckpt')
            save_path = sv.saver.save(sess, checkpoint_name)

            print("{} Model checkpoint saved at {}".format(
                datetime.now(), checkpoint_name))
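
Example #6 uses two helpers that are not shown: get_optimizer and create_done_queue. A plausible sketch of both, assuming the done queue is the usual pattern of a shared FIFOQueue pinned to the ps job, which each worker enqueues into when it finishes so the parameter server knows when to exit (see the 'ps' branch at the top of main):

import tensorflow as tf

def get_optimizer(name, learning_rate):
    # Map a name to a TF1 optimizer; only the case Example #6 uses is sketched
    if name == 'Adam':
        return tf.train.AdamOptimizer(learning_rate)
    return tf.train.GradientDescentOptimizer(learning_rate)

def create_done_queue(num_workers):
    # Shared queue between workers and the ps job: workers enqueue one token
    # each on shutdown; the ps dequeues num_workers tokens, then returns
    with tf.device('/job:ps/task:0'):
        return tf.FIFOQueue(num_workers, tf.int32, shared_name='done_queue')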
    def fine_tuning(self, train_list, test_list, mean, snapshot,
                    filewriter_path):
        # Learning params
        learning_rate = 0.0005
        num_epochs = 151
        batch_size = 64

        # Network params
        in_img_size = (227, 227)  #(height, width)
        dropout_rate = 1
        num_classes = 2
        train_layers = ['fc7', 'fc8']

        # How often we want to write the tf.summary data to disk
        display_step = 30

        x = tf.placeholder(tf.float32,
                           [batch_size, in_img_size[0], in_img_size[1], 3])
        y = tf.placeholder(tf.float32, [None, num_classes])
        keep_prob = tf.placeholder(tf.float32)

        # Initialize model
        model = alexnet(x,
                        keep_prob,
                        num_classes,
                        train_layers,
                        in_size=in_img_size)
        # link variable to model output
        score = model.fc8
        # List of trainable variables of the layers we want to train
        var_list = [
            v for v in tf.trainable_variables()
            if v.name.split('/')[0] in train_layers
        ]
        # Op for calculating the loss
        with tf.name_scope("cross_ent"):
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(logits=score,
                                                        labels=y))
            # Train op

            # Get gradients of all trainable variables (used only for the
            # gradient summaries below; minimize() recomputes them)
            gradients = tf.gradients(loss, var_list)
            gradients = list(zip(gradients, var_list))
            '''
            # Create optimizer and apply gradient descent to the trainable variables
            learning_rate = tf.train.exponential_decay(learning_rate,
                                           global_step=tf.Variable(0, trainable=False),
                                           decay_steps=10,decay_rate=0.9)
            '''
            optimizer = tf.train.MomentumOptimizer(learning_rate, 0.9)
            train_op = optimizer.minimize(loss)

        # Add gradients to summary
        for gradient, var in gradients:
            tf.summary.histogram(var.name + '/gradient', gradient)
        # Add the variables we train to the summary
        for var in var_list:
            tf.summary.histogram(var.name, var)
        # Add the loss to summary
        tf.summary.scalar('cross_entropy', loss)

        # Evaluation op: Accuracy of the model
        with tf.name_scope("accuracy"):
            correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
            accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        # Add the accuracy to the summary
        tf.summary.scalar('accuracy', accuracy)

        # Merge all summaries together
        merged_summary = tf.summary.merge_all()
        # Initialize the FileWriter
        writer = tf.summary.FileWriter(filewriter_path)
        # Initialize a saver to store model checkpoints
        saver = tf.train.Saver()
        # Initialize the data generators separately for the training and validation set
        train_generator = ImageDataGenerator(train_list,
                                             horizontal_flip=True,
                                             shuffle=False,
                                             mean=mean,
                                             scale_size=in_img_size,
                                             nb_classes=num_classes)
        val_generator = ImageDataGenerator(test_list,
                                           shuffle=False,
                                           mean=mean,
                                           scale_size=in_img_size,
                                           nb_classes=num_classes)
        # Get the number of training/validation steps per epoch
        train_batches_per_epoch = np.floor(train_generator.data_size /
                                           batch_size).astype(np.int16)
        val_batches_per_epoch = np.floor(val_generator.data_size /
                                         batch_size).astype(np.int16)

        # Start Tensorflow session
        with tf.Session() as sess:
            # Initialize all variables
            sess.run(tf.global_variables_initializer())
            # Add the model graph to TensorBoard
            writer.add_graph(sess.graph)
            # Load the pretrained weights into the non-trainable layer
            model.load_initial_weights(sess)
            print("{} Start training...".format(datetime.now()))
            print("{} Open Tensorboard at --logdir {}".format(
                datetime.now(), filewriter_path))
            # Loop over number of epochs
            for epoch in range(num_epochs):
                print("{} Epoch number: {}/{}".format(datetime.now(),
                                                      epoch + 1, num_epochs))
                step = 1
                while step < train_batches_per_epoch:
                    # Get a batch of images and labels
                    batch_xs, batch_ys = train_generator.next_batch(batch_size)
                    # And run the training op
                    sess.run(train_op,
                             feed_dict={
                                 x: batch_xs,
                                 y: batch_ys,
                                 keep_prob: dropout_rate
                             })
                    # Generate summary with the current batch of data and write to file
                    if step % display_step == 0:
                        s = sess.run(merged_summary,
                                     feed_dict={
                                         x: batch_xs,
                                         y: batch_ys,
                                         keep_prob: 1.
                                     })
                        writer.add_summary(
                            s, epoch * train_batches_per_epoch + step)
                    step += 1

                # Validate the model on the entire validation set
                print("{} Start validation".format(datetime.now()))
                test_acc = 0.
                test_count = 0
                for _ in range(val_batches_per_epoch):
                    batch_tx, batch_ty = val_generator.next_batch(batch_size)
                    acc = sess.run(accuracy,
                                   feed_dict={
                                       x: batch_tx,
                                       y: batch_ty,
                                       keep_prob: 1.
                                   })
                    test_acc += acc
                    test_count += 1
                test_acc /= test_count
                print("{} Validation Accuracy = {:.4f}".format(
                    datetime.now(), test_acc))

                # Reset the file pointer of the image data generator
                val_generator.reset_pointer()
                train_generator.reset_pointer()
                print("{} Saving checkpoint of model...".format(
                    datetime.now()))

                #save checkpoint of the model
                if epoch % display_step == 0:
                    checkpoint_name = os.path.join(
                        snapshot, 'model_epoch' + str(epoch) + '.ckpt')
                    save_path = saver.save(sess, checkpoint_name)
                    print("{} Model checkpoint saved at {}".format(
                        datetime.now(), checkpoint_name))
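
The fine_tuning method above belongs to a larger class that is not shown. A hypothetical invocation, in which the class name, file paths, and mean pixel values are all made up for illustration:

import numpy as np

trainer = AlexNetFineTuner()  # hypothetical wrapper class
trainer.fine_tuning(train_list='train.txt',
                    test_list='val.txt',
                    mean=np.array([104., 117., 124.]),  # assumed channel means
                    snapshot='./snapshots',
                    filewriter_path='./tensorboard')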