예제 #1
0
def _train(path_to_train_tfrecords_file, num_train_examples,
           path_to_val_tfrecords_file, num_val_examples, path_to_train_log_dir,
           path_to_restore_checkpoint_file, training_options):
    """Set up the training graph: input batches, loss, optimizer, summaries.

    NOTE(review): this excerpt ends right after the summaries are merged, so
    only graph construction is documented here. The validation, log-directory
    and restore-checkpoint arguments are not referenced in the visible part.

    Args:
        path_to_train_tfrecords_file: Forwarded to ``Donkey.build_batch`` as
            the training data source.
        num_train_examples: Forwarded to ``Donkey.build_batch`` as
            ``num_examples``.
        path_to_val_tfrecords_file: Not used in the visible part.
        num_val_examples: Not used in the visible part.
        path_to_train_log_dir: Not used in the visible part.
        path_to_restore_checkpoint_file: Not used in the visible part.
        training_options: Mapping read for the keys ``'batch_size'``,
            ``'patience'``, ``'learning_rate'``, ``'decay_steps'`` and
            ``'decay_rate'``.
    """
    batch_size = training_options['batch_size']
    initial_patience = training_options['patience']
    # Step intervals; presumably consumed by the training loop that follows
    # this excerpt -- TODO confirm.
    num_steps_to_show_loss = 100
    num_steps_to_check = 1000

    with tf.Graph().as_default():
        # Shuffled training batches: images plus their labels.
        image_batch, pieces_batch = Donkey.build_batch(
            path_to_train_tfrecords_file,
            num_examples=num_train_examples,
            batch_size=batch_size,
            shuffled=True)
        # Debug output emitted once, while the graph is being built.
        print pieces_batch
        pieces_logits = Model.inference(image_batch, drop_rate=0.2)
        loss = Model.loss(pieces_logits, pieces_batch)

        # global_step is incremented by minimize() and drives the
        # staircase exponential learning-rate decay.
        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(
            training_options['learning_rate'],
            global_step=global_step,
            decay_steps=training_options['decay_steps'],
            decay_rate=training_options['decay_rate'],
            staircase=True)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)

        # Summaries for the input images, the loss and the learning rate,
        # merged into a single op.
        tf.summary.image('image', image_batch)
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('learning_rate', learning_rate)
        summary = tf.summary.merge_all()
        """
예제 #2
0
    def evaluate(self, path_to_checkpoint, path_to_tfrecords_file,
                 num_examples, global_step):
        """Evaluate a checkpoint on a TFRecords file and return its accuracy.

        A fresh graph is built, the checkpoint is restored into it, and a
        streaming accuracy metric is updated over ``num_examples // 128``
        batches. The digit and letter labels of each example are joined into
        one string, so a prediction only counts as correct when every
        position matches.

        Args:
            path_to_checkpoint: Checkpoint path handed to ``Saver.restore``.
            path_to_tfrecords_file: Data source handed to
                ``Donkey.build_batch``.
            num_examples: Number of examples in the dataset; examples beyond
                the last full batch are not evaluated.
            global_step: Step value attached to the written summary.

        Returns:
            The accumulated accuracy value.
        """
        batch_size = 128
        # Floor division keeps the batch count an integer on both Python 2
        # and Python 3 (true division would yield a float on Python 3).
        num_batches = num_examples // batch_size

        with tf.Graph().as_default():
            image_batch, digits_batch, letters_batch = Donkey.build_batch(
                path_to_tfrecords_file,
                num_examples=num_examples,
                batch_size=batch_size,
                shuffled=False)
            digits_logits, letters_logits = Model.inference(image_batch,
                                                            drop_rate=0.0)
            digits_predictions = tf.argmax(digits_logits, axis=2)
            letters_predictions = tf.argmax(letters_logits, axis=2)

            labels = tf.concat([digits_batch, letters_batch], axis=1)
            predictions = tf.concat([digits_predictions, letters_predictions],
                                    axis=1)

            # Join each row into a single string so the metric compares whole
            # sequences rather than individual positions.
            labels_string = tf.reduce_join(tf.as_string(labels), axis=1)
            predictions_string = tf.reduce_join(tf.as_string(predictions),
                                                axis=1)

            accuracy, update_accuracy = tf.metrics.accuracy(
                labels=labels_string, predictions=predictions_string)

            tf.summary.image('image', image_batch)
            tf.summary.scalar('accuracy', accuracy)
            # One histogram over all trainable variables, flattened and
            # concatenated.
            tf.summary.histogram(
                'variables',
                tf.concat([
                    tf.reshape(var, [-1]) for var in tf.trainable_variables()
                ],
                          axis=0))
            summary = tf.summary.merge_all()

            with tf.Session() as sess:
                # Local variables hold the metric's running totals.
                sess.run([
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()
                ])
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                try:
                    restorer = tf.train.Saver()
                    restorer.restore(sess, path_to_checkpoint)

                    for _ in range(num_batches):
                        sess.run(update_accuracy)

                    accuracy_val, summary_val = sess.run([accuracy, summary])
                    self.summary_writer.add_summary(summary_val,
                                                    global_step=global_step)
                finally:
                    # Always stop the input threads, even when restoring or
                    # running the graph fails.
                    coord.request_stop()
                    coord.join(threads)

        return accuracy_val
예제 #3
0
파일: evaluator.py 프로젝트: NUCS-LEi/scene
    def evaluate(self, path_to_checkpoint, path_to_tfrecords_file, num_examples, global_step):
        """Evaluate a checkpoint on a TFRecords file and return the summed accuracy.

        Three predictions are taken from the model output (indices 0, 1 and 2
        along axis 1 of the logits), each is scored against the same label
        batch with a streaming accuracy metric, and the three accuracies are
        added together.

        NOTE(review): because the three accuracies are summed, the returned
        value can exceed 1.0 -- confirm this is the intended metric.

        Args:
            path_to_checkpoint: Checkpoint path handed to ``Saver.restore``.
            path_to_tfrecords_file: Data source handed to
                ``Donkey.build_batch``.
            num_examples: Number of examples in the dataset; examples beyond
                the last full batch of 128 are not evaluated.
            global_step: Step value attached to the written summary.

        Returns:
            The sum of the three accumulated accuracy values.
        """
        batch_size = 128
        num_batches = num_examples // batch_size

        with tf.Graph().as_default():
            image_batch, label_batch = Donkey.build_batch(path_to_tfrecords_file,
                                                          num_examples=num_examples,
                                                          batch_size=batch_size,
                                                          shuffled=False)
            label_logits = Model.inference(image_batch, drop_rate=0.0)
            label_predictions1 = tf.argmax(label_logits[:, 0], axis=1)
            label_predictions2 = tf.argmax(label_logits[:, 1], axis=1)
            label_predictions3 = tf.argmax(label_logits[:, 2], axis=1)
            accuracy1, update_accuracy1 = tf.metrics.accuracy(
                labels=label_batch,
                predictions=label_predictions1
            )
            accuracy2, update_accuracy2 = tf.metrics.accuracy(
                labels=label_batch,
                predictions=label_predictions2
            )
            accuracy3, update_accuracy3 = tf.metrics.accuracy(
                labels=label_batch,
                predictions=label_predictions3
            )
            accuracy = accuracy1 + accuracy2 + accuracy3
            # Evaluating this sum runs all three metric update ops at once.
            update_accuracy = update_accuracy1 + update_accuracy2 + update_accuracy3

            tf.summary.image('image', image_batch)
            tf.summary.scalar('accuracy', accuracy)
            tf.summary.histogram('variables',
                                 tf.concat([tf.reshape(var, [-1]) for var in tf.trainable_variables()], axis=0))
            summary = tf.summary.merge_all()

            with tf.Session() as sess:
                # Local variables hold the metrics' running totals.
                sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                try:
                    restorer = tf.train.Saver()
                    restorer.restore(sess, path_to_checkpoint)

                    for _ in range(num_batches):
                        sess.run(update_accuracy)

                    accuracy_val, summary_val = sess.run([accuracy, summary])
                    self.summary_writer.add_summary(summary_val, global_step=global_step)
                finally:
                    # Always stop the input threads, even when restoring or
                    # running the graph fails.
                    coord.request_stop()
                    coord.join(threads)

        return accuracy_val
예제 #4
0
    def evaluate(self, path_to_checkpoint, path_to_tfrecords_file,
                 num_examples, global_step):
        """Evaluate a checkpoint on a TFRecords file and return its accuracy.

        A fresh graph is built, the checkpoint is restored into it, and a
        streaming accuracy metric is updated over ``num_examples // 128``
        batches of unshuffled data.

        Args:
            path_to_checkpoint: Checkpoint path handed to ``Saver.restore``.
            path_to_tfrecords_file: Data source handed to
                ``Donkey.build_batch``.
            num_examples: Number of examples in the dataset; only used to
                derive the number of full batches to evaluate.
            global_step: Step value attached to the written summary.

        Returns:
            The accumulated accuracy value.
        """
        batch_size = 128
        # Floor division keeps the batch count an integer on both Python 2
        # and Python 3 (true division would yield a float on Python 3).
        num_batches = num_examples // batch_size

        with tf.Graph().as_default():
            images, labels = Donkey.build_batch(path_to_tfrecords_file,
                                                batch_size=batch_size,
                                                one_hot=False,
                                                shuffled=False)
            logits = Model.inference(images, keep_prob=1)
            predictions = tf.argmax(logits, axis=1)

            accuracy, update_accuracy = tf.metrics.accuracy(
                labels=labels, predictions=predictions)

            tf.summary.image('image', images)
            tf.summary.scalar('accuracy', accuracy)
            summary = tf.summary.merge_all()

            with tf.Session() as sess:
                # Local variables hold the metric's running totals.
                sess.run([
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()
                ])
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                try:
                    restorer = tf.train.Saver()
                    restorer.restore(sess, path_to_checkpoint)

                    for _ in range(num_batches):
                        sess.run(update_accuracy)

                    accuracy_val, summary_val = sess.run([accuracy, summary])
                    self.summary_writer.add_summary(summary_val,
                                                    global_step=global_step)
                finally:
                    # Always stop the input threads, even when restoring or
                    # running the graph fails.
                    coord.request_stop()
                    coord.join(threads)

        return accuracy_val
예제 #5
0
def _train(path_to_train_tfrecords_file, num_train_examples,
           path_to_val_tfrecords_file, num_val_examples, path_to_train_log_dir,
           path_to_restore_checkpoint_file, training_options):
    """Train the model with early stopping based on validation accuracy.

    The loop runs one optimizer step at a time. Every 100 steps the loss and
    throughput are printed. Every 1000 steps a summary is written, a
    ``latest.ckpt`` checkpoint is saved and evaluated on the validation set,
    and either a ``model.ckpt-<step>`` checkpoint is saved (accuracy
    improved, patience reset) or the patience counter is decremented.
    Training stops once patience is exhausted.

    Args:
        path_to_train_tfrecords_file: Training data source for
            ``Donkey.build_batch``.
        num_train_examples: Number of training examples, forwarded to
            ``Donkey.build_batch``.
        path_to_val_tfrecords_file: Validation data, forwarded to
            ``Evaluator.evaluate``.
        num_val_examples: Number of validation examples, forwarded to
            ``Evaluator.evaluate``.
        path_to_train_log_dir: Directory for summaries and checkpoints.
        path_to_restore_checkpoint_file: Checkpoint to resume from, or
            ``None`` to start from freshly initialized variables.
        training_options: Mapping read for the keys ``'batch_size'``,
            ``'patience'``, ``'learning_rate'``, ``'decay_steps'`` and
            ``'decay_rate'``.

    Raises:
        AssertionError: If a restore checkpoint is given but does not exist.
    """
    batch_size = training_options['batch_size']
    initial_patience = training_options['patience']
    num_steps_to_show_loss = 100
    num_steps_to_check = 1000

    with tf.Graph().as_default():
        image_batch, length_batch, digits_batch = Donkey.build_batch(
            path_to_train_tfrecords_file,
            num_examples=num_train_examples,
            batch_size=batch_size,
            shuffled=True)
        length_logits, digits_logits = Model.inference(image_batch,
                                                       drop_rate=0.2)
        loss = Model.loss(length_logits, digits_logits, length_batch,
                          digits_batch)

        # global_step is incremented by minimize() and drives the staircase
        # exponential learning-rate decay.
        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(
            training_options['learning_rate'],
            global_step=global_step,
            decay_steps=training_options['decay_steps'],
            decay_rate=training_options['decay_rate'],
            staircase=True)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)

        tf.summary.image('image', image_batch)
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('learning_rate', learning_rate)
        summary = tf.summary.merge_all()

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(path_to_train_log_dir,
                                                   sess.graph)
            evaluator = Evaluator(
                os.path.join(path_to_train_log_dir, 'eval/val'))

            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                saver = tf.train.Saver()
                if path_to_restore_checkpoint_file is not None:
                    assert tf.train.checkpoint_exists(path_to_restore_checkpoint_file), \
                        '%s not found' % path_to_restore_checkpoint_file
                    saver.restore(sess, path_to_restore_checkpoint_file)
                    print('Model restored from file: %s' %
                          path_to_restore_checkpoint_file)

                print('Start training')
                patience = initial_patience
                best_accuracy = 0.0
                # Time spent in training steps since the last loss report.
                duration = 0.0

                while True:
                    start_time = time.time()
                    _, loss_val, summary_val, global_step_val = sess.run(
                        [train_op, loss, summary, global_step])
                    duration += time.time() - start_time

                    if global_step_val % num_steps_to_show_loss == 0:
                        examples_per_sec = batch_size * num_steps_to_show_loss / duration
                        duration = 0.0
                        print('=> %s: step %d, loss = %f (%.1f examples/sec)' % (
                            datetime.now(), global_step_val, loss_val,
                            examples_per_sec))

                    if global_step_val % num_steps_to_check != 0:
                        continue

                    summary_writer.add_summary(summary_val,
                                               global_step=global_step_val)

                    print('=> Evaluating on validation dataset...')
                    # The evaluator restores from disk, so the current
                    # weights are saved first.
                    path_to_latest_checkpoint_file = saver.save(
                        sess, os.path.join(path_to_train_log_dir, 'latest.ckpt'))
                    accuracy = evaluator.evaluate(path_to_latest_checkpoint_file,
                                                  path_to_val_tfrecords_file,
                                                  num_val_examples,
                                                  global_step_val)
                    print('==> accuracy = %f, best accuracy %f' % (
                        accuracy, best_accuracy))

                    if accuracy > best_accuracy:
                        path_to_checkpoint_file = saver.save(
                            sess,
                            os.path.join(path_to_train_log_dir, 'model.ckpt'),
                            global_step=global_step_val)
                        print('=> Model saved to file: %s' %
                              path_to_checkpoint_file)
                        patience = initial_patience
                        best_accuracy = accuracy
                    else:
                        patience -= 1

                    print('=> patience = %d' % patience)
                    # "<=" rather than "==": a non-positive initial patience
                    # would otherwise step past zero and never terminate.
                    if patience <= 0:
                        break
            finally:
                # Flush pending summaries and stop the input threads even
                # when training is interrupted by an error.
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)
            print('Finished')
예제 #6
0
import tensorflow as tf
import sys
sys.path.append('.')
from donkey import Donkey
import matplotlib.pyplot as plt
import numpy as np

# Fetch one unshuffled batch of 36 examples and plot it as a 6x6 grid, each
# image titled with its length label and five digit labels.
# NOTE(review): the keyword here is `num_example`; confirm it matches the
# parameter name of Donkey.build_batch in this project.
image_batch, length_batch, digits_batch = Donkey.build_batch(r'E:\DataSets\SVHN\train.tfrecords',
                                                             num_example=100,
                                                             batch_size=36,
                                                             shuffled=False)

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        image_batch_val, length_batch_val, digits_batch_val = sess.run([image_batch, length_batch, digits_batch])
    finally:
        # Stop the input threads once the batch has been fetched (or the
        # fetch failed) so they are not left running.
        coord.request_stop()
        coord.join(threads)

    # Rescale before plotting; presumably maps [-1, 1] pixel values to
    # [0, 1] -- TODO confirm against Donkey's preprocessing.
    image_batch_val = (image_batch_val / 2.0) + 0.5
    print(np.array(image_batch_val[0]).shape)
    fig, axes = plt.subplots(6, 6, figsize=(20, 20))
    for i, ax in enumerate(axes.flat):
        digits = digits_batch_val[i]
        title = 'length: %d\ndigits= %d, %d, %d, %d, %d' % (length_batch_val[i],
                                                            digits[0],
                                                            digits[1],
                                                            digits[2],
                                                            digits[3],
                                                            digits[4])
        ax.imshow(image_batch_val[i])
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
예제 #7
0
def _train(path_to_train_tfrecords_file, num_train_examples,
           path_to_val_tfrecords_file, num_val_examples, path_to_train_log_dir,
           path_to_restore_checkpoint_file, training_options):
    """Train a small conv net fed through placeholders for 20000 steps.

    Each iteration fetches a training batch from the input queue, one-hot
    encodes its labels (depth 11), runs one RMSProp step, measures accuracy
    on one validation batch of 64 examples, prints it, writes a summary and
    overwrites the ``latest.ckpt`` checkpoint.

    Args:
        path_to_train_tfrecords_file: Training data source for
            ``Donkey.build_batch``.
        num_train_examples: Number of training examples, forwarded to
            ``Donkey.build_batch``.
        path_to_val_tfrecords_file: Validation data source for
            ``Donkey.build_batch``.
        num_val_examples: Number of validation examples, forwarded to
            ``Donkey.build_batch``.
        path_to_train_log_dir: Directory for summaries and the checkpoint.
        path_to_restore_checkpoint_file: Checkpoint to resume from, or
            ``None`` to start from freshly initialized variables.
        training_options: Mapping read for the keys ``'batch_size'`` and
            ``'num_epoch'``.

    Raises:
        AssertionError: If a restore checkpoint is given but does not exist.
    """
    batch_size = training_options['batch_size']

    with tf.Graph().as_default():
        X = tf.placeholder("float", [None, 28, 28, 1])
        Y = tf.placeholder("float", [None, 11])

        w = init_weights([3, 3, 1, 32])  # 3x3x1 conv, 32 outputs
        w2 = init_weights([3, 3, 32, 64])  # 3x3x32 conv, 64 outputs
        w3 = init_weights([3, 3, 64, 128])  # 3x3x64 conv, 128 outputs
        w4 = init_weights([128 * 4 * 4,
                           625])  # FC 128 * 4 * 4 inputs, 625 outputs
        w_o = init_weights([625, 11])  # FC 625 inputs, 11 outputs (labels)

        global_step = tf.Variable(0, name='global_step', trainable=False)

        # Keep probabilities, fed per run: 0.8 / 0.5 when training and
        # 1.0 / 1.0 when evaluating.
        p_keep_conv = tf.placeholder("float")
        p_keep_hidden = tf.placeholder("float")
        py_x = model(X, w, w2, w3, w4, w_o, p_keep_conv, p_keep_hidden)

        loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
        train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(
            loss, global_step=global_step)
        predict_op = tf.argmax(py_x, 1)

        print(num_train_examples)
        trX, trY = Donkey.build_batch(path_to_train_tfrecords_file,
                                      num_examples=num_train_examples,
                                      batch_size=batch_size,
                                      shuffled=True,
                                      num_epoch_=training_options['num_epoch'])
        teX, teY = Donkey.build_batch(path_to_val_tfrecords_file,
                                      num_examples=num_val_examples,
                                      batch_size=64,
                                      shuffled=True,
                                      num_epoch_=training_options['num_epoch'])

        # Helper sub-graph that one-hot encodes a batch of label indices.
        indices = tf.placeholder("uint8", [batch_size, 1])
        var = tf.one_hot(indices, depth=11, axis=1)

        tf.summary.image('image', trX)
        tf.summary.scalar('loss', loss)
        summary = tf.summary.merge_all()

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(path_to_train_log_dir,
                                                   sess.graph)

            # Both global and local variables are initialized before the
            # input threads start.
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                saver = tf.train.Saver()
                if path_to_restore_checkpoint_file is not None:
                    assert tf.train.checkpoint_exists(path_to_restore_checkpoint_file), \
                        '%s not found' % path_to_restore_checkpoint_file
                    saver.restore(sess, path_to_restore_checkpoint_file)
                    print('Model restored from file: %s' %
                          path_to_restore_checkpoint_file)

                print('Start training')

                for i in range(20000):
                    # Pull a training batch and turn its labels into one-hot
                    # rows of width 11.
                    tr_x, tr_y = sess.run([trX, trY])
                    tr_y.resize(batch_size, 1)
                    tr_y = sess.run(var, feed_dict={indices: tr_y})
                    tr_y.resize(batch_size, 11)

                    _, loss_val, summary_val, global_step_val = sess.run(
                        [train_op, loss, summary, global_step],
                        feed_dict={
                            X: tr_x,
                            Y: tr_y,
                            p_keep_conv: 0.8,
                            p_keep_hidden: 0.5
                        })

                    # Accuracy on a single validation batch, with dropout
                    # disabled.
                    # NOTE(review): assumes te_y has the same shape as the
                    # predict_op output; otherwise == would broadcast --
                    # TODO confirm.
                    te_x, te_y = sess.run([teX, teY])
                    te_x = te_x.reshape(-1, 28, 28, 1)  # 28x28x1 input img
                    accuracy = np.mean(te_y == sess.run(predict_op,
                                                        feed_dict={
                                                            X: te_x,
                                                            p_keep_conv: 1.0,
                                                            p_keep_hidden: 1.0
                                                        }))
                    print(i, accuracy)

                    summary_writer.add_summary(summary_val,
                                               global_step=global_step_val)
                    saver.save(
                        sess, os.path.join(path_to_train_log_dir, 'latest.ckpt'))
            finally:
                # Flush pending summaries and stop the input threads even
                # when training is interrupted by an error.
                summary_writer.close()
                coord.request_stop()
                coord.join(threads)
            print('Finished')
    def evaluate(self, path_to_checkpoint, image_eval, length_eval,
                 digits_eval, global_step):
        """Evaluate a checkpoint on in-memory data and return accuracy in percent.

        The graph is fed through placeholders. For each of
        ``num_examples // 128`` iterations a batch is obtained from
        ``Donkey.build_batch`` and the per-batch accuracy is computed; the
        result is the mean over batches times 100. A prediction counts as
        correct only when the whole digit sequence matches, because each row
        is joined into one string before comparison.

        Only the summary of the last evaluated batch is written.

        Args:
            path_to_checkpoint: Checkpoint path handed to ``Saver.restore``.
            image_eval: Evaluation images; ``image_eval.shape[0]`` is taken
                as the number of examples.
            length_eval: Length labels, forwarded to ``Donkey.build_batch``.
            digits_eval: Digit labels, forwarded to ``Donkey.build_batch``.
            global_step: Step value attached to the written summary.

        Returns:
            Mean batch accuracy multiplied by 100, or ``0.0`` when there are
            fewer examples than one full batch.
        """
        batch_size = 128
        # Switch to also require the predicted length to match.
        needs_include_length = False
        accuracy_val = 0.0
        with tf.Graph().as_default():
            image_batch = tf.placeholder(tf.float32,
                                         shape=[
                                             None, self.image_size,
                                             self.image_size, self.num_channels
                                         ])
            length_batch = tf.placeholder(tf.int32, shape=[None])
            digits_batch = tf.placeholder(tf.int32,
                                          shape=[None, self.digits_nums])
            num_examples = image_eval.shape[0]
            # Number of full batches; floor division keeps this an integer.
            num_batches = num_examples // batch_size
            length_logits, digits_logits = Model.forward(image_batch, 1.0)
            length_predictions = tf.argmax(length_logits, axis=1)
            digits_predictions = tf.argmax(digits_logits, axis=2)

            if needs_include_length:
                labels = tf.concat(
                    [tf.reshape(length_batch, [-1, 1]), digits_batch], axis=1)
                predictions = tf.concat([
                    tf.reshape(length_predictions, [-1, 1]), digits_predictions
                ],
                                        axis=1)
            else:
                labels = digits_batch
                predictions = digits_predictions

            # Join each row into a single string so whole sequences are
            # compared rather than individual positions.
            labels_string = tf.reduce_join(tf.as_string(labels), axis=1)
            predictions_string = tf.reduce_join(tf.as_string(predictions),
                                                axis=1)
            correct_pre = tf.equal(labels_string, predictions_string)
            accuracy = tf.reduce_mean(tf.cast(correct_pre, tf.float32))

            tf.summary.image('image', image_batch)
            tf.summary.scalar('accuracy', accuracy)
            tf.summary.histogram(
                'variables',
                tf.concat([
                    tf.reshape(var, [-1]) for var in tf.trainable_variables()
                ],
                          axis=0))
            summary = tf.summary.merge_all()

            with tf.Session() as sess:
                sess.run([
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()
                ])

                restorer = tf.train.Saver()
                restorer.restore(sess, path_to_checkpoint)

                # Stays None when no full batch is available.
                summary_val = None
                for _ in range(num_batches):
                    # NOTE(review): build_batch is called with identical
                    # arguments every iteration; presumably it returns a
                    # different batch per call -- verify.
                    image_batch_input, length_batch_input, digits_batch_input = Donkey.build_batch(
                        image_eval,
                        length_eval,
                        digits_eval,
                        batch_size=batch_size)
                    feed_dict = {
                        image_batch: image_batch_input,
                        length_batch: length_batch_input,
                        digits_batch: digits_batch_input
                    }
                    accuracy_step, summary_val = sess.run([accuracy, summary],
                                                          feed_dict=feed_dict)
                    accuracy_val += accuracy_step

                # With zero batches there is nothing to log and nothing to
                # average; skipping both avoids a NameError and a division
                # by zero.
                if summary_val is not None:
                    self.summary_writer.add_summary(summary_val,
                                                    global_step=global_step)
                if num_batches > 0:
                    accuracy_val = accuracy_val / num_batches
        return accuracy_val * 100
    def evaluate(self, path_to_restore_model_checkpoint_file,
                 path_to_restore_defender_checkpoint_file,
                 path_to_tfrecords_file, num_examples, global_step,
                 defend_layer, attacker_type):
        """Evaluate the model and the defender's reconstruction from two checkpoints.

        The model is built under the ``'model'`` variable scope and the
        reconstruction network under ``'defender'``; each scope is restored
        from its own checkpoint. Sequence accuracy is accumulated over
        ``num_examples // 32`` batches and written to the summary together
        with the mean absolute SSIM between input and recovered images.

        Args:
            path_to_restore_model_checkpoint_file: Checkpoint for the
                ``'model'`` scope.
            path_to_restore_defender_checkpoint_file: Checkpoint for the
                ``'defender'`` scope.
            path_to_tfrecords_file: Data source handed to
                ``Donkey.build_batch``.
            num_examples: Number of examples in the dataset; examples beyond
                the last full batch are not evaluated.
            global_step: Step value attached to the written summary.
            defend_layer: Forwarded to ``Model.inference`` and
                ``Attacker.recover_hidden``.
            attacker_type: Forwarded to ``Attacker.recover_hidden``.

        Returns:
            The accumulated accuracy value.
        """
        batch_size = 32
        num_batches = num_examples // batch_size
        # Switch to also require the predicted length to match.
        needs_include_length = False

        with tf.Graph().as_default():
            image_batch, length_batch, digits_batch = Donkey.build_batch(
                path_to_tfrecords_file,
                num_examples=num_examples,
                batch_size=batch_size,
                shuffled=False)
            with tf.variable_scope('model'):
                length_logits, digits_logits, hidden_out = Model.inference(
                    image_batch,
                    drop_rate=0.0,
                    is_training=False,
                    defend_layer=defend_layer)
            with tf.variable_scope('defender'):
                recovered = Attacker.recover_hidden(attacker_type,
                                                    hidden_out,
                                                    is_training=False,
                                                    defend_layer=defend_layer)
            # max_val=2 presumably reflects images scaled to [-1, 1] --
            # TODO confirm against Donkey's preprocessing.
            ssim = tf.reduce_mean(
                tf.abs(tf.image.ssim(image_batch, recovered, max_val=2)))
            length_predictions = tf.argmax(length_logits, axis=1)
            digits_predictions = tf.argmax(digits_logits, axis=2)

            if needs_include_length:
                labels = tf.concat(
                    [tf.reshape(length_batch, [-1, 1]), digits_batch], axis=1)
                predictions = tf.concat([
                    tf.reshape(length_predictions, [-1, 1]), digits_predictions
                ],
                                        axis=1)
            else:
                labels = digits_batch
                predictions = digits_predictions

            # Join each row into a single string so the metric compares whole
            # sequences rather than individual positions.
            labels_string = tf.reduce_join(tf.as_string(labels), axis=1)
            predictions_string = tf.reduce_join(tf.as_string(predictions),
                                                axis=1)

            accuracy, update_accuracy = tf.metrics.accuracy(
                labels=labels_string, predictions=predictions_string)

            tf.summary.image('image', image_batch, max_outputs=20)
            tf.summary.image('recovered', recovered, max_outputs=20)
            tf.summary.scalar('ssim', ssim)
            tf.summary.scalar('accuracy', accuracy)
            tf.summary.histogram(
                'variables',
                tf.concat([
                    tf.reshape(var, [-1]) for var in tf.trainable_variables()
                ],
                          axis=0))
            summary = tf.summary.merge_all()

            with tf.Session() as sess:
                # Local variables hold the metric's running totals.
                sess.run([
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()
                ])
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                try:
                    # NOTE(review): only TRAINABLE_VARIABLES are restored; any
                    # non-trainable state in these scopes keeps its
                    # initializer value -- confirm that is intended.
                    model_saver = tf.train.Saver(var_list=tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES, scope='model'))
                    defender_saver = tf.train.Saver(var_list=tf.get_collection(
                        tf.GraphKeys.TRAINABLE_VARIABLES, scope='defender'))
                    model_saver.restore(sess,
                                        path_to_restore_model_checkpoint_file)
                    print("Evaluation model restored from {}".format(
                        path_to_restore_model_checkpoint_file))
                    defender_saver.restore(
                        sess, path_to_restore_defender_checkpoint_file)

                    for _ in range(num_batches):
                        sess.run(update_accuracy)

                    accuracy_val, summary_val = sess.run([accuracy, summary])
                    self.summary_writer.add_summary(summary_val,
                                                    global_step=global_step)
                finally:
                    # Always stop the input threads, even when restoring or
                    # running the graph fails.
                    coord.request_stop()
                    coord.join(threads)

        return accuracy_val
예제 #10
0
def _train(path_to_train_tfrecords_file, num_train_examples,
           path_to_val_tfrecords_file, num_val_examples, path_to_train_log_dir,
           path_to_restore_model_checkpoint_file, training_options):
    """Train the 'defender' network on top of a restored model.

    The graph runs ``Model.inference`` under the ``'model'`` variable scope and
    ``Attacker.recover_hidden`` under the ``'defender'`` scope.  Only the
    trainable variables of the ``'defender'`` scope are handed to the
    optimizer; the loss being minimized is ``-ssim``, i.e. the mean absolute
    SSIM between the input images and the recovered images is maximized.

    Args:
        path_to_train_tfrecords_file: Training data, passed to
            ``Donkey.build_batch``.
        num_train_examples: Number of training examples, passed to
            ``Donkey.build_batch``.
        path_to_val_tfrecords_file: Validation data, forwarded to
            ``Evaluator.evaluate``.
        num_val_examples: Number of validation examples, forwarded to
            ``Evaluator.evaluate``.
        path_to_train_log_dir: Directory receiving training summaries, the
            evaluator's ``eval/val`` summaries and the ``attacker.ckpt`` /
            ``attacker_best.ckpt`` checkpoints.
        path_to_restore_model_checkpoint_file: Checkpoint restored into the
            ``'model'`` scope variables before training starts (always
            restored, so it must be a valid checkpoint path).
        training_options: Mapping with the keys ``'batch_size'``,
            ``'patience'``, ``'learning_rate'``, ``'decay_steps'`` and
            ``'decay_rate'``.

    Also reads ``FLAGS.defend_layer``, ``FLAGS.attacker_type`` and
    ``FLAGS.max_steps``.
    """
    batch_size = training_options['batch_size']
    initial_patience = training_options['patience']
    num_steps_to_show_loss = 100
    num_steps_to_check = 1000

    with tf.Graph().as_default():
        image_batch, length_batch, digits_batch = Donkey.build_batch(
            path_to_train_tfrecords_file,
            num_examples=num_train_examples,
            batch_size=batch_size,
            shuffled=True)
        with tf.variable_scope('model'):
            length_logits, digits_logits, hidden_out = Model.inference(
                image_batch,
                drop_rate=0.0,
                is_training=True,
                defend_layer=FLAGS.defend_layer)
        with tf.variable_scope('defender'):
            recovered = Attacker.recover_hidden(FLAGS.attacker_type,
                                                hidden_out, True,
                                                FLAGS.defend_layer)
        ssim = tf.reduce_mean(
            tf.abs(tf.image.ssim(image_batch, recovered, max_val=2)))
        # model_loss is only reported in the summaries; it is not optimized.
        model_loss = Model.loss(length_logits, digits_logits, length_batch,
                                digits_batch)
        # NOTE: a joint objective (model_loss + FLAGS.ssim_weight * ssim) with
        # a train op over the 'model' scope existed here as commented-out
        # code; this function optimizes the defender only.
        defender_loss = -ssim

        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(
            training_options['learning_rate'],
            global_step=global_step,
            decay_steps=training_options['decay_steps'],
            decay_rate=training_options['decay_rate'],
            staircase=True)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)

        # Run the defender's UPDATE_OPS together with every optimizer step.
        defender_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS,
                                                scope='defender')
        with tf.control_dependencies(defender_update_ops):
            defender_op = optimizer.minimize(
                defender_loss,
                var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           scope='defender'),
                global_step=global_step)

        tf.summary.image('image', image_batch, max_outputs=20)
        tf.summary.image('recovered', recovered, max_outputs=20)
        tf.summary.scalar('model_loss', model_loss)
        tf.summary.scalar('ssim', ssim)
        tf.summary.scalar('learning_rate', learning_rate)
        summary = tf.summary.merge_all()

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(path_to_train_log_dir,
                                                   sess.graph)
            evaluator = Evaluator(
                os.path.join(path_to_train_log_dir, 'eval/val'))

            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                # Separate savers so the model and the defender can be
                # restored / checkpointed independently of each other.
                model_saver = tf.train.Saver(var_list=tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='model'))
                defender_saver = tf.train.Saver(var_list=tf.get_collection(
                    tf.GraphKeys.TRAINABLE_VARIABLES, scope='defender'))
                model_saver.restore(sess,
                                    path_to_restore_model_checkpoint_file)
                print('Model restored from file: %s' %
                      path_to_restore_model_checkpoint_file)

                print('Start training')
                patience = initial_patience
                best_accuracy = 0.0
                duration = 0.0

                while True:
                    start_time = time.time()
                    _, defender_loss_val, summary_val, global_step_val = sess.run(
                        [defender_op, defender_loss, summary, global_step])
                    duration += time.time() - start_time

                    if global_step_val % num_steps_to_show_loss == 0:
                        examples_per_sec = (
                            batch_size * num_steps_to_show_loss / duration)
                        duration = 0.0
                        print('=> %s: step %d, loss = %f (%.1f examples/sec)' %
                              (datetime.now(), global_step_val,
                               defender_loss_val, examples_per_sec))

                    if global_step_val % num_steps_to_check != 0:
                        continue

                    summary_writer.add_summary(summary_val,
                                               global_step=global_step_val)

                    print('=> Evaluating on validation dataset...')
                    path_to_latest_attacker_checkpoint_file = defender_saver.save(
                        sess,
                        os.path.join(path_to_train_log_dir, 'attacker.ckpt'))
                    accuracy = evaluator.evaluate(
                        path_to_restore_model_checkpoint_file,
                        path_to_latest_attacker_checkpoint_file,
                        path_to_val_tfrecords_file, num_val_examples,
                        global_step_val, FLAGS.defend_layer,
                        FLAGS.attacker_type)
                    print('==> accuracy = %f, best accuracy %f' %
                          (accuracy, best_accuracy))

                    if accuracy > best_accuracy:
                        defender_saver.save(
                            sess,
                            os.path.join(path_to_train_log_dir,
                                         'attacker_best.ckpt'))
                        patience = initial_patience
                        best_accuracy = accuracy
                    else:
                        patience -= 1

                    # Patience is tracked and reported, but early stopping on
                    # patience is disabled; the step budget below is the only
                    # exit condition.
                    print('=> patience = %d' % patience)

                    # Only reached on evaluation steps, so training ends at
                    # the first check step past FLAGS.max_steps.
                    if global_step_val > FLAGS.max_steps:
                        break
            finally:
                # Always stop the input threads and flush pending summaries,
                # even when training is interrupted by an exception.
                try:
                    coord.request_stop()
                    coord.join(threads)
                finally:
                    summary_writer.close()
            print('Finished')
예제 #11
0
def _train(path_to_train_tfrecords_file, num_train_examples, path_to_val_tfrecords_file, num_val_examples,
           path_to_train_log_dir, path_to_restore_checkpoint_file, training_options):
    """Train the model with SGD until validation accuracy stops improving.

    Every ``num_steps_to_check`` steps the current weights are saved to
    ``latest.ckpt`` and scored with ``Evaluator.evaluate``.  An improvement
    saves a ``model.ckpt-<step>`` checkpoint and resets the patience counter;
    otherwise the counter is decremented and training stops when it hits 0.

    Args:
        path_to_train_tfrecords_file: Training data, passed to
            ``Donkey.build_batch``.
        num_train_examples: Number of training examples, passed to
            ``Donkey.build_batch``.
        path_to_val_tfrecords_file: Validation data, forwarded to
            ``Evaluator.evaluate``.
        num_val_examples: Number of validation examples, forwarded to
            ``Evaluator.evaluate``.
        path_to_train_log_dir: Directory receiving summaries and checkpoints.
        path_to_restore_checkpoint_file: Optional checkpoint to resume from;
            ``None`` starts from freshly initialized variables.
        training_options: Mapping with the keys ``'batch_size'``,
            ``'patience'``, ``'learning_rate'``, ``'decay_steps'`` and
            ``'decay_rate'``.

    Raises:
        AssertionError: If a restore checkpoint is given but does not exist.
    """
    batch_size = training_options['batch_size']
    initial_patience = training_options['patience']
    num_steps_to_show_loss = 100
    num_steps_to_check = 1000

    with tf.Graph().as_default():
        image_batch, length_batch, digits_batch = Donkey.build_batch(path_to_train_tfrecords_file,
                                                                     num_examples=num_train_examples,
                                                                     batch_size=batch_size,
                                                                     shuffled=True)
        length_logits, digits_logits = Model.inference(image_batch, drop_rate=0.2)
        loss = Model.loss(length_logits, digits_logits, length_batch, digits_batch)

        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(training_options['learning_rate'],
                                                   global_step=global_step,
                                                   decay_steps=training_options['decay_steps'],
                                                   decay_rate=training_options['decay_rate'],
                                                   staircase=True)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)

        tf.summary.image('image', image_batch)
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('learning_rate', learning_rate)
        summary = tf.summary.merge_all()

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(path_to_train_log_dir, sess.graph)
            evaluator = Evaluator(os.path.join(path_to_train_log_dir, 'eval/val'))

            sess.run(tf.global_variables_initializer())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            try:
                saver = tf.train.Saver()
                if path_to_restore_checkpoint_file is not None:
                    assert tf.train.checkpoint_exists(path_to_restore_checkpoint_file), \
                        '%s not found' % path_to_restore_checkpoint_file
                    saver.restore(sess, path_to_restore_checkpoint_file)
                    print('Model restored from file: %s' % path_to_restore_checkpoint_file)

                # Single-argument print(...) behaves identically on Python 2
                # and Python 3, unlike the bare print statement.
                print('Start training')
                patience = initial_patience
                best_accuracy = 0.0
                duration = 0.0

                while True:
                    start_time = time.time()
                    _, loss_val, summary_val, global_step_val = sess.run(
                        [train_op, loss, summary, global_step])
                    duration += time.time() - start_time

                    if global_step_val % num_steps_to_show_loss == 0:
                        examples_per_sec = batch_size * num_steps_to_show_loss / duration
                        duration = 0.0
                        print('=> %s: step %d, loss = %f (%.1f examples/sec)' % (
                            datetime.now(), global_step_val, loss_val, examples_per_sec))

                    if global_step_val % num_steps_to_check != 0:
                        continue

                    summary_writer.add_summary(summary_val, global_step=global_step_val)

                    print('=> Evaluating on validation dataset...')
                    path_to_latest_checkpoint_file = saver.save(
                        sess, os.path.join(path_to_train_log_dir, 'latest.ckpt'))
                    accuracy = evaluator.evaluate(path_to_latest_checkpoint_file,
                                                  path_to_val_tfrecords_file,
                                                  num_val_examples,
                                                  global_step_val)
                    print('==> accuracy = %f, best accuracy %f' % (accuracy, best_accuracy))

                    if accuracy > best_accuracy:
                        path_to_checkpoint_file = saver.save(
                            sess, os.path.join(path_to_train_log_dir, 'model.ckpt'),
                            global_step=global_step_val)
                        print('=> Model saved to file: %s' % path_to_checkpoint_file)
                        patience = initial_patience
                        best_accuracy = accuracy
                    else:
                        patience -= 1

                    print('=> patience = %d' % patience)
                    if patience == 0:
                        break
            finally:
                # Always stop the input threads and flush pending summaries,
                # even when training is interrupted by an exception.
                try:
                    coord.request_stop()
                    coord.join(threads)
                finally:
                    summary_writer.close()
            print('Finished')
예제 #12
0
    def evaluate(self, path_to_checkpoint, path_to_tfrecords_file,
                 num_examples, global_step, save_file):
        """Score a checkpoint and dump per-example predictions to a TSV file.

        Builds a fresh graph, restores ``path_to_checkpoint`` and evaluates
        ``num_examples // 128`` full batches.  A prediction counts as correct
        only when the whole digit sequence matches: labels and predictions are
        joined into one string per example before being compared.

        Args:
            path_to_checkpoint: Checkpoint restored into the evaluation graph.
            path_to_tfrecords_file: Data to evaluate, passed to
                ``Donkey.build_batch``.
            num_examples: Number of examples available; a trailing partial
                batch is not evaluated.
            global_step: Step under which the summary is recorded in
                ``self.summary_writer``.
            save_file: Path of the tab-separated output file with the columns
                ``ID``, ``prediction`` and ``label``.

        Returns:
            The value of the streaming accuracy metric after all batches.
        """
        batch_size = 128

        num_batches = num_examples // batch_size
        # When True, the predicted length is made part of the compared string.
        needs_include_length = False

        with tf.Graph().as_default():
            # NOTE(review): the keyword is spelled `num_example` (singular)
            # here - confirm it matches Donkey.build_batch's signature.
            id_batch, image_batch, length_batch, digits_batch = Donkey.build_batch(
                path_to_tfrecords_file,
                num_example=num_examples,
                batch_size=batch_size,
                shuffled=False)

            length_logits, digits_logits = Model.inference(image_batch,
                                                           drop_rate=0.0)
            length_predictions = tf.argmax(length_logits, axis=1)
            digits_predictions = tf.argmax(digits_logits, axis=2)

            if needs_include_length:
                labels = tf.concat(
                    [tf.reshape(length_batch, [-1, 1]), digits_batch], axis=1)
                predictions = tf.concat([
                    tf.reshape(length_predictions, [-1, 1]), digits_predictions
                ],
                                        axis=1)
            else:
                labels = digits_batch
                predictions = digits_predictions

            # One string per example, so accuracy is per whole sequence.
            labels_string = tf.reduce_join(tf.as_string(labels), axis=1)
            predictions_string = tf.reduce_join(tf.as_string(predictions),
                                                axis=1)

            accuracy, update_accuracy = tf.metrics.accuracy(
                labels=labels_string, predictions=predictions_string)

            tf.summary.image('image', image_batch)
            tf.summary.scalar('accuracy', accuracy)
            tf.summary.histogram(
                'variables',
                tf.concat([
                    tf.reshape(var, [-1]) for var in tf.trainable_variables()
                ],
                          axis=0))
            summary = tf.summary.merge_all()

            ids_total = []
            labels_total = []
            predictions_total = []
            with tf.Session() as sess:
                # tf.metrics.accuracy keeps its counters in local variables.
                sess.run([
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()
                ])
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                try:
                    restorer = tf.train.Saver()
                    restorer.restore(sess, path_to_checkpoint)

                    for _ in range(num_batches):
                        # Fetch the metric update and the values to record in
                        # ONE run.  Assuming build_batch is backed by input
                        # queues (queue runners are started above), separate
                        # sess.run calls would each pull a different batch, so
                        # the saved rows would not be the examples that were
                        # actually scored.
                        _, ids_val, labels_val, predictions_val = sess.run([
                            update_accuracy, id_batch, labels_string,
                            predictions_string
                        ])
                        ids_total.extend(ids_val)
                        labels_total.extend(labels_val)
                        predictions_total.extend(predictions_val)

                    # Read the final metric and write the summary once, after
                    # all batches; this also defines accuracy_val when
                    # num_batches is 0.
                    accuracy_val, summary_val = sess.run([accuracy, summary])
                    self.summary_writer.add_summary(summary_val,
                                                    global_step=global_step)
                finally:
                    coord.request_stop()
                    coord.join(threads)

        ids_total = np.array(ids_total)
        labels_total = np.array(labels_total)
        predictions_total = np.array(predictions_total)
        print(ids_total.shape, labels_total.shape, predictions_total.shape)
        with open(save_file, 'w') as f:
            f.write('ID\tprediction\tlabel\n')
            for example_id, prediction, label in zip(
                    ids_total, predictions_total, labels_total):
                f.write(
                    str(example_id) + '\t' + str(prediction) + '\t' +
                    str(label) + '\n')

        return accuracy_val
def train(path_to_train_h5py_file, path_to_val_h5py_file,
          path_to_train_log_dir, path_to_restore_checkpoint_file, training_options):
    """Train the model from in-memory h5py data fed through placeholders.

    Training and validation arrays are loaded up front with
    ``read_h5py_file``; each step draws a batch via ``Donkey.build_batch`` and
    feeds it to the graph.  Every ``num_steps_to_check`` steps the weights are
    saved to ``latest.ckpt`` and scored with ``Evaluator.evaluate``.  An
    improvement saves a ``model.ckpt-<step>`` checkpoint and resets the
    patience counter; otherwise the counter is decremented and training stops
    when it hits 0.

    Args:
        path_to_train_h5py_file: Training data file, read with ``Flag=1``.
        path_to_val_h5py_file: Validation data file, read with ``Flag=2``.
        path_to_train_log_dir: Directory receiving summaries and checkpoints.
        path_to_restore_checkpoint_file: Optional checkpoint to resume from;
            ``None`` starts from freshly initialized variables.
        training_options: Mapping with the keys ``'batch_size'``,
            ``'patience'``, ``'learning_rate'``, ``'decay_steps'`` and
            ``'decay_rate'``.

    Raises:
        AssertionError: If a restore checkpoint is given but does not exist.
    """
    image_train, length_train, digits_train = read_h5py_file(path_to_train_h5py_file, Flag=1)
    image_val, length_val, digits_val = read_h5py_file(path_to_val_h5py_file, Flag=2)
    batch_size = training_options['batch_size']
    initial_patience = training_options['patience']
    num_steps_to_show_loss = 100
    num_steps_to_check = 1000
    # Placeholder geometry: square images of image_size x image_size pixels
    # with num_channels channels, and digits_nums digit labels per example.
    image_size = 54
    num_channels = 3
    digits_nums = 5

    with tf.Graph().as_default():
        image_batch = tf.placeholder(
            tf.float32, shape=[None, image_size, image_size, num_channels])
        length_batch = tf.placeholder(tf.int32, shape=[None])
        digits_batch = tf.placeholder(tf.int32, shape=[None, digits_nums])
        # NOTE(review): 0.8 is presumably a dropout keep probability (the
        # replaced call used drop_rate=0.2) - confirm against Model.forward.
        length_logits, digits_logits = Model.forward(image_batch, 0.8)

        loss = Model.loss(length_logits, digits_logits, length_batch, digits_batch)

        global_step = tf.Variable(0, name='global_step', trainable=False)
        learning_rate = tf.train.exponential_decay(training_options['learning_rate'],
                                                   global_step=global_step,
                                                   decay_steps=training_options['decay_steps'],
                                                   decay_rate=training_options['decay_rate'],
                                                   staircase=True)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.minimize(loss, global_step=global_step)

        tf.summary.image('image', image_batch)
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('learning_rate', learning_rate)
        summary = tf.summary.merge_all()

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(path_to_train_log_dir, sess.graph)
            try:
                evaluator = Evaluator(os.path.join(path_to_train_log_dir, 'eval/val'))

                sess.run(tf.global_variables_initializer())
                saver = tf.train.Saver()
                if path_to_restore_checkpoint_file is not None:
                    assert tf.train.checkpoint_exists(path_to_restore_checkpoint_file), \
                        '%s not found' % path_to_restore_checkpoint_file
                    saver.restore(sess, path_to_restore_checkpoint_file)
                    print('Model restored from file: %s' % path_to_restore_checkpoint_file)

                print('Start training')
                patience = initial_patience
                best_accuracy = 0.0
                duration = 0.0

                while True:
                    # The timer deliberately starts before the batch is built,
                    # so the reported throughput includes data preparation.
                    start_time = time.time()
                    image_batch_input, length_batch_input, digits_batch_input = Donkey.build_batch(
                        image_train, length_train, digits_train, batch_size=batch_size)
                    feed_dict = {
                        image_batch: image_batch_input,
                        length_batch: length_batch_input,
                        digits_batch: digits_batch_input,
                    }
                    _, loss_val, summary_val, global_step_val = sess.run(
                        [train_op, loss, summary, global_step], feed_dict=feed_dict)
                    duration += time.time() - start_time

                    if global_step_val % num_steps_to_show_loss == 0:
                        examples_per_sec = batch_size * num_steps_to_show_loss / duration
                        duration = 0.0
                        print('=> %s: step %d, loss = %f (%.1f examples/sec)' % (
                            datetime.now(), global_step_val, loss_val, examples_per_sec))

                    if global_step_val % num_steps_to_check != 0:
                        continue

                    summary_writer.add_summary(summary_val, global_step=global_step_val)

                    print('=> Evaluating on validation dataset...')
                    path_to_latest_checkpoint_file = saver.save(
                        sess, os.path.join(path_to_train_log_dir, 'latest.ckpt'))
                    accuracy = evaluator.evaluate(path_to_latest_checkpoint_file,
                                                  image_val, length_val, digits_val,
                                                  global_step_val)
                    print('==> accuracy = %f, best accuracy %f' % (accuracy, best_accuracy))

                    if accuracy > best_accuracy:
                        path_to_checkpoint_file = saver.save(
                            sess, os.path.join(path_to_train_log_dir, 'model.ckpt'),
                            global_step=global_step_val)
                        print('=> Model saved to file: %s' % path_to_checkpoint_file)
                        patience = initial_patience
                        best_accuracy = accuracy
                    else:
                        patience -= 1

                    print('=> patience = %d' % patience)
                    if patience == 0:
                        break
            finally:
                # Flush and release the event file even if training fails.
                summary_writer.close()
            print('Finished')