Exemple #1
0
    def testForModel(self):
        """Build the training inputs and the model, then check output shapes.

        A shuffled batch of 32 examples is produced by ``Inputs.build_batch()``
        and fed through ``Model.inference`` and ``Model.loss``. The batch
        tensors and the logits are then passed, together with their names and
        expected shapes, to ``self._checkOutputs``.
        """
        # Alternative data location: '/notebooks/dataVolume/workspace/data'
        tfrecords_path = '/home/amax/Documents/wit/data/train.tfrecords'

        batch_builder = Inputs(
            path_to_tfrecords_file=tfrecords_path,
            batch_size=32,
            shuffle=True,
            min_queue_examples=5000,
            num_preprocess_threads=4,
            num_reader_threads=1,
        )
        images, input_seqs, target_seqs, mask = batch_builder.build_batch()

        net = Model(
            vocab_size=11,
            mode='train',
            embedding_size=512,
            num_lstm_units=128,
            lstm_dropout_keep_prob=0.7,
            cnn_drop_rate=0.2,
            initializer_scale=0.08,
        )
        logits = net.inference(images, input_seqs, mask, state=None)

        # The loss is still constructed, but its tensor is not among the
        # outputs checked below.
        net.loss(logits, target_seqs)

        # One (name, tensor, expected shape) triple per checked output.
        # NOTE(review): the sequence tensors were originally annotated as
        # [batch_size, sequence_length] although the expected tuples are
        # rank-1 -- confirm against _checkOutputs.
        checks = [
            ('images', images, (32, 54, 54, 3)),  # [batch, height, width, 3]
            ('input_seqs', input_seqs, (32, )),
            ('target_seqs', target_seqs, (32, )),
            ('mask', mask, (32, )),
            ('logits', logits, ()),
        ]
        tensors = [tensor for _, tensor, _ in checks]
        names = [name for name, _, _ in checks]
        shapes = [shape for _, _, shape in checks]

        self._checkOutputs(tensors, names, shapes)
Exemple #2
0
    def testForBatch_not_shuffle(self):
        """Check the shapes of an unshuffled validation batch.

        A batch of 128 examples is built with ``shuffle=False`` and the four
        tensors returned by ``Inputs.build_batch()`` are passed, with their
        names and expected shapes, to ``self._checkOutputs``.
        """
        # Alternative data location: '/notebooks/dataVolume/workspace/data'
        tfrecords_path = '/home/amax/Documents/wit/data/val.tfrecords'

        batch_builder = Inputs(
            path_to_tfrecords_file=tfrecords_path,
            batch_size=128,
            shuffle=False,
            min_queue_examples=5000,
            num_preprocess_threads=4,
            num_reader_threads=1,
        )
        images, input_seqs, target_seqs, mask = batch_builder.build_batch()

        # One (name, tensor, expected shape) triple per checked output.
        # NOTE(review): the sequence tensors were originally annotated as
        # [batch_size, sequence_length] although the expected tuples are
        # rank-1 -- confirm against _checkOutputs.
        checks = [
            ('images', images, (128, 54, 54, 3)),  # [batch, height, width, 3]
            ('input_seqs', input_seqs, (128, )),
            ('target_seqs', target_seqs, (128, )),
            ('mask', mask, (128, )),
        ]
        tensors = [tensor for _, tensor, _ in checks]
        names = [name for name, _, _ in checks]
        shapes = [shape for _, _, shape in checks]

        self._checkOutputs(tensors, names, shapes)
Exemple #3
0
    def evaluate(self, path_to_checkpoint, path_to_tfrecords_file,
                 num_examples, global_step):
        """Restore a checkpoint and compute weighted accuracy on a dataset.

        A fresh graph is built with an unshuffled input pipeline and a model
        in ``'evaluate'`` mode. Accuracy is accumulated with
        ``tf.metrics.accuracy`` over ``int(num_examples / 128)`` batches, so
        any examples beyond the last full batch are not evaluated. The merged
        summary (image, accuracy, variable histogram) is written through
        ``self.summary_writer``.

        Args:
            path_to_checkpoint: Checkpoint path handed to ``Saver.restore``.
            path_to_tfrecords_file: TFRecords file read by ``Inputs``.
            num_examples: Number of examples available in the file.
            global_step: Step value recorded with the written summary.

        Returns:
            The accumulated accuracy value after all batches have been run.
        """
        batch_size = 128
        num_batches = int(num_examples / batch_size)

        with tf.Graph().as_default():
            input_ops = Inputs(path_to_tfrecords_file=path_to_tfrecords_file,
                               batch_size=batch_size,
                               shuffle=False,
                               min_queue_examples=5000,
                               num_preprocess_threads=4,
                               num_reader_threads=1)
            images, input_seqs, target_seqs, mask = input_ops.build_batch()

            mymodel = Model(vocab_size=12,
                            mode='evaluate',
                            embedding_size=512,
                            num_lstm_units=64,
                            lstm_dropout_keep_prob=0.7,
                            cnn_drop_rate=0.2,
                            initializer_scale=0.08)

            logits = mymodel.inference(images, input_seqs, mask)
            digit_predictions = tf.argmax(logits, axis=1)

            # Flatten everything so labels, predictions and weights line up
            # element by element; the mask is used as the per-element weight.
            labels = tf.reshape(target_seqs, [-1])
            # tf.cast(..., tf.float32) is the non-deprecated equivalent of
            # tf.to_float.
            weights = tf.cast(tf.reshape(mask, [-1]), tf.float32)
            predictions = tf.reshape(digit_predictions, [-1])

            accuracy, update_accuracy = tf.metrics.accuracy(
                labels=labels, predictions=predictions, weights=weights)

            tf.summary.image('image', images)
            tf.summary.scalar('accuracy', accuracy)
            flat_variables = [
                tf.reshape(var, [-1]) for var in tf.trainable_variables()
            ]
            tf.summary.histogram('variables',
                                 tf.concat(flat_variables, axis=0))
            summary = tf.summary.merge_all()

            with tf.Session() as sess:
                # Local variables hold the accumulators created by
                # tf.metrics.accuracy, so they must be initialized as well.
                sess.run([
                    tf.global_variables_initializer(),
                    tf.local_variables_initializer()
                ])
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(sess=sess, coord=coord)

                # Always stop and join the queue-runner threads, even when
                # restoring or running the graph raises; otherwise they would
                # be left running against a closing session.
                try:
                    restorer = tf.train.Saver()
                    restorer.restore(sess, path_to_checkpoint)

                    for _ in range(num_batches):
                        sess.run(update_accuracy)

                    accuracy_val, summary_val = sess.run([accuracy, summary])
                    self.summary_writer.add_summary(summary_val,
                                                    global_step=global_step)
                finally:
                    coord.request_stop()
                    coord.join(threads)

        return accuracy_val
Exemple #4
0
def _train(path_to_train_tfrecords_file, num_train_examples,
           path_to_val_tfrecords_file, num_val_examples,
           path_to_train_log_dir, path_to_restore_checkpoint_file,
           training_options):
    """Train the model with SGD and early stopping on validation accuracy.

    Every 100 steps the loss and throughput are printed. Every 1000 steps the
    training summary is written, a ``latest.ckpt`` checkpoint is saved and
    evaluated on the validation set, and the model is saved as ``model.ckpt``
    whenever validation accuracy improves. Training stops once ``patience``
    consecutive evaluations have failed to improve on the best accuracy.

    Args:
        path_to_train_tfrecords_file: TFRecords file with training examples.
        num_train_examples: Number of training examples (currently unused).
        path_to_val_tfrecords_file: TFRecords file used for evaluation.
        num_val_examples: Number of validation examples passed to the
            evaluator.
        path_to_train_log_dir: Directory for summaries and checkpoints.
        path_to_restore_checkpoint_file: Checkpoint to resume from, or
            ``None`` to start from freshly initialized variables.
        training_options: Mapping providing ``'batch_size'``, ``'patience'``,
            ``'learning_rate'``, ``'decay_steps'`` and ``'decay_rate'``.

    Raises:
        AssertionError: If a restore checkpoint is given but does not exist.
    """
    batch_size = training_options['batch_size']  # 32
    initial_patience = training_options['patience']  # 100
    # output information setting
    num_steps_to_show_loss = 100
    num_steps_to_check = 1000

    with tf.Graph().as_default():
        # Use the configured batch size so the input pipeline agrees with the
        # examples/sec figure reported below (it was hard-coded to 32 before).
        input_ops = Inputs(path_to_tfrecords_file=path_to_train_tfrecords_file,
                           batch_size=batch_size,
                           shuffle=True,
                           # int(0.4 * num_train_examples)
                           min_queue_examples=5000,
                           num_preprocess_threads=4,
                           num_reader_threads=1)
        images, input_seqs, target_seqs, mask = input_ops.build_batch()

        mymodel = Model(vocab_size=12,
                        mode='train',
                        embedding_size=512,
                        num_lstm_units=64,
                        lstm_dropout_keep_prob=0.7,
                        cnn_drop_rate=0.2,
                        initializer_scale=0.08)

        logits = mymodel.inference(images, input_seqs, mask)

        total_loss = mymodel.loss(logits, target_seqs)

        global_step = tf.Variable(
            initial_value=0,
            name="global_step",
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_STEP,
                         tf.GraphKeys.GLOBAL_VARIABLES])

        initial_learning_rate = training_options['learning_rate']  # 1e-2
        decay_steps = training_options['decay_steps']  # 10000
        decay_rate = training_options['decay_rate']  # 0.9

        learning_rate = tf.train.exponential_decay(initial_learning_rate,
                                                   global_step=global_step,
                                                   decay_steps=decay_steps,
                                                   decay_rate=decay_rate,
                                                   staircase=True)
        optimizer = tf.train.GradientDescentOptimizer(learning_rate)
        train_op = optimizer.minimize(total_loss, global_step=global_step)

        tf.summary.image('image', images)
        tf.summary.scalar('loss', total_loss)
        tf.summary.scalar('learning_rate', learning_rate)
        summary = tf.summary.merge_all()

        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(
                path_to_train_log_dir, sess.graph)
            evaluator = Evaluator(os.path.join(
                path_to_train_log_dir, 'eval/val'))

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess=sess, coord=coord)

            # Guarantee that the queue-runner threads are stopped and the
            # event file is flushed and closed even if training is
            # interrupted by an exception.
            try:
                saver = tf.train.Saver()
                if path_to_restore_checkpoint_file is not None:
                    assert tf.train.checkpoint_exists(
                        path_to_restore_checkpoint_file), \
                        '%s not found' % path_to_restore_checkpoint_file
                    saver.restore(sess, path_to_restore_checkpoint_file)
                    print('Model restored from file: %s' %
                          path_to_restore_checkpoint_file)

                print('Start training')
                patience = initial_patience
                best_accuracy = 0.0
                duration = 0.0

                while True:
                    start_time = time.time()
                    (_, loss_val, summary_val, global_step_val,
                     learning_rate_val) = sess.run(
                        [train_op, total_loss, summary, global_step,
                         learning_rate])
                    duration += time.time() - start_time

                    if global_step_val % num_steps_to_show_loss == 0:
                        examples_per_sec = (
                            batch_size * num_steps_to_show_loss / duration)
                        duration = 0.0
                        print('=> %s: step %d, loss = %f (%.1f examples/sec)' %
                              (datetime.now(), global_step_val, loss_val,
                               examples_per_sec))

                    # Everything below runs only on evaluation steps.
                    if global_step_val % num_steps_to_check != 0:
                        continue

                    summary_writer.add_summary(
                        summary_val, global_step=global_step_val)

                    print('=> Evaluating on validation dataset...')
                    path_to_latest_checkpoint_file = saver.save(
                        sess, os.path.join(path_to_train_log_dir,
                                           'latest.ckpt'))
                    accuracy = evaluator.evaluate(
                        path_to_latest_checkpoint_file,
                        path_to_val_tfrecords_file,
                        num_val_examples,  # 23508
                        global_step_val)
                    print('==> accuracy = %f, best accuracy %f' %
                          (accuracy, best_accuracy))

                    if accuracy > best_accuracy:
                        path_to_checkpoint_file = saver.save(
                            sess,
                            os.path.join(path_to_train_log_dir, 'model.ckpt'),
                            global_step=global_step_val)
                        print('=> Model saved to file: %s' %
                              path_to_checkpoint_file)
                        patience = initial_patience
                        best_accuracy = accuracy
                    else:
                        patience -= 1

                    print('=> patience = %d' % patience)
                    # "<= 0" rather than "== 0": a non-positive configured
                    # patience would otherwise step past zero and the loop
                    # would never terminate.
                    if patience <= 0:
                        break
            finally:
                coord.request_stop()
                coord.join(threads)
                summary_writer.close()

            print('Finished')