Code Example #1
def train_one_epoch(classifier, data_files, hyper_pars):
    classifier.train(
        input_fn=lambda: make_fashion_iterators(
            data_files['train'],
            hyper_pars['batch_size'],
            shuffle=True,
            tfrecord=hyper_pars['tfrecord']),
        steps=hyper_pars['train_steps'])
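
All of these examples depend on make_fashion_iterators, which is not shown
here. A minimal sketch of what it plausibly looks like, inferred from the
call sites above and below, follows; the TFRecord feature names, the HDF5
dataset names, and the tensor shapes are assumptions, not the project's
actual code.

import h5py
import tensorflow as tf

def make_fashion_iterators(data_file, batch_size, num_epochs=1,
                           shuffle=False, tfrecord=False):
    # Hypothetical reconstruction: build a tf.data pipeline over
    # Fashion-MNIST and return (features, labels) tensors drawn from a
    # one-shot iterator.
    def parse_tfrecord(example_proto):
        spec = {
            'images': tf.FixedLenFeature([28, 28, 1], tf.float32),
            'labels': tf.FixedLenFeature([10], tf.float32),
        }
        parsed = tf.parse_single_example(example_proto, spec)
        return parsed['images'], parsed['labels']

    def hdf5_generator():
        # dataset names 'images'/'labels' are guesses
        with h5py.File(data_file, 'r') as f:
            for image, label in zip(f['images'], f['labels']):
                yield image, label

    if tfrecord:
        dataset = tf.data.TFRecordDataset(data_file).map(parse_tfrecord)
    else:
        dataset = tf.data.Dataset.from_generator(
            hdf5_generator, (tf.float32, tf.float32),
            (tf.TensorShape([28, 28, 1]), tf.TensorShape([10])))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.repeat(num_epochs).batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()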
Code Example #2
def train(
        train_file, batch_size, num_epochs, train_steps, model_dir,
        learning_rate
):
    tf.reset_default_graph()
    chkpt_dir = model_dir + '/checkpoints'
    run_dest_dir = model_dir + '/%d' % time.time()
    n_steps = train_steps or 1000000000

    with tf.Graph().as_default() as g:
        with tf.Session(graph=g) as sess:

            features, _ = make_fashion_iterators(
                train_file, batch_size, num_epochs, shuffle=True
            )
            model = FashionAutoencoder(learning_rate=learning_rate)
            model.build_network(features)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            writer = tf.summary.FileWriter(
                logdir=run_dest_dir, graph=sess.graph
            )
            saver = tf.train.Saver(save_relative_paths=True)

            ckpt = tf.train.get_checkpoint_state(os.path.dirname(chkpt_dir))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                LOGGER.info('Restored session from {}'.format(chkpt_dir))

            writer.add_graph(sess.graph)
            initial_step = model.global_step.eval()
            LOGGER.info('initial step = {}'.format(initial_step))

            summary_t = None  # stays None if no training step completes
            try:
                for b_num in range(initial_step, initial_step + n_steps):
                    _, loss_batch, encoded, summary_t = sess.run(
                        [model.optimizer,
                         model.loss,
                         model.encoded,
                         model.train_summary_op]
                    )
                    if (b_num + 1) % 50 == 0:
                        LOGGER.info(
                            ' Loss @step {}: {:5.1f}'.format(b_num, loss_batch)
                        )
                        LOGGER.debug(str(encoded))
                        saver.save(sess, chkpt_dir, b_num)
                        writer.add_summary(summary_t, global_step=b_num)

            except tf.errors.OutOfRangeError:
                LOGGER.info('Training stopped - queue is empty.')

            # final checkpoint and summary for the last step taken
            saver.save(sess, chkpt_dir, b_num)
            if summary_t is not None:
                writer.add_summary(summary_t, global_step=b_num)

        writer.close()
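
The training loop above touches a number of attributes on FashionAutoencoder
(X, Y, encoded, loss, optimizer, global_step, train_summary_op) without
showing the class itself. A minimal skeleton consistent with those call
sites follows; the layer sizes (beyond the 64-unit code implied by the
8x8 reshape in the test loop below) and the activations are guesses.

class FashionAutoencoder:
    # Hypothetical skeleton; the real architecture is not in these examples.
    def __init__(self, learning_rate=0.001):
        self.learning_rate = learning_rate

    def build_network(self, features):
        self.global_step = tf.train.get_or_create_global_step()
        # flatten the input images; X is what the test loop visualizes
        self.X = tf.reshape(features, [-1, 784])
        # 64-dimensional code, matching encoded_batch[0].reshape(8, 8)
        self.encoded = tf.layers.dense(self.X, 64, activation=tf.nn.sigmoid)
        # Y is the reconstruction compared against X
        self.Y = tf.layers.dense(self.encoded, 784, activation=tf.nn.sigmoid)
        self.loss = tf.reduce_mean(tf.squared_difference(self.Y, self.X))
        self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(
            self.loss, global_step=self.global_step)
        self.train_summary_op = tf.summary.scalar('loss', self.loss)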
Code Example #3
def test(test_file, model_dir):
    tf.reset_default_graph()
    chkpt_dir = model_dir + '/checkpoints'

    with tf.Graph().as_default() as g:
        with tf.Session(graph=g) as sess:

            features, labels = make_fashion_iterators(
                test_file, batch_size=1, num_epochs=1, shuffle=False
            )
            model = FashionAutoencoder()
            model.build_network(features)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            saver = tf.train.Saver()

            ckpt = tf.train.get_checkpoint_state(os.path.dirname(chkpt_dir))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                LOGGER.info('Restored session from {}'.format(chkpt_dir))

            initial_step = model.global_step.eval()
            LOGGER.info('initial step = {}'.format(initial_step))

            try:
                for i in range(20):
                    loss_batch, encoded_batch, labels_batch, input_batch, \
                        recon = sess.run(
                            [model.loss,
                             model.encoded,
                             labels,
                             model.X,
                             model.Y]
                        )
                    print(loss_batch, encoded_batch.shape, recon.shape)
                    n_encoded = encoded_batch.shape[1]
                    n_labels = labels_batch.shape[1]

                    fig = plt.figure()
                    gs = plt.GridSpec(1, 3)
                    ax1 = plt.subplot(gs[0])
                    ax1.imshow(recon[0].reshape(28, 28))
                    ax2 = plt.subplot(gs[1])
                    ax2.imshow(input_batch[0].reshape(28, 28))
                    plt.title(np.argmax(labels_batch[0]))
                    ax3 = plt.subplot(gs[2])
                    ax3.imshow(encoded_batch[0].reshape(8, 8))
                    figname = 'image_{:04d}.pdf'.format(i)
                    plt.savefig(figname, bbox_inches='tight')
                    plt.close()

            except tf.errors.OutOfRangeError:
                LOGGER.info('Testing stopped - queue is empty.')

    return n_encoded, n_labels
Code Example #4
def evaluate(classifier, data_files, hyper_pars):
    eval_result = classifier.evaluate(
        input_fn=lambda: make_fashion_iterators(data_files['test'],
                                                hyper_pars['batch_size'],
                                                tfrecord=hyper_pars['tfrecord']
                                                ),
        steps=100,
    )
    print('\nEval:')
    print('acc: {accuracy:0.3f}, loss: {loss:0.3f}, MPCA {mpca:0.3f}'.format(
        **eval_result))
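
For context, the Estimator-based helpers (train_one_epoch, evaluate, and
predict below) could be driven by a loop like the following. The model_fn
and the file paths here are hypothetical; note that evaluate()'s format
string requires the model_fn to report an 'mpca' (mean per-class accuracy)
eval metric in addition to accuracy and loss.

data_files = {'train': 'fashion_train.tfrecord',
              'test': 'fashion_test.tfrecord'}
hyper_pars = {'batch_size': 64, 'train_steps': 1000, 'tfrecord': True}

# model_fn is not shown in these examples and is assumed to exist
classifier = tf.estimator.Estimator(model_fn=model_fn,
                                    model_dir='/tmp/fashion_mnist')

for _ in range(10):  # epochs
    train_one_epoch(classifier, data_files, hyper_pars)
    evaluate(classifier, data_files, hyper_pars)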
Code Example #5
def test(test_file, batch_size, model_dir):
    tf.reset_default_graph()
    chkpt_dir = model_dir + '/checkpoints'

    with tf.Graph().as_default() as g:
        with tf.Session(graph=g) as sess:

            features, labels = make_fashion_iterators(test_file,
                                                      batch_size,
                                                      num_epochs=1,
                                                      shuffle=False)
            features = tf.reshape(features, [-1, 784])

            model = FashionMNISTLogReg()
            model.build_network(features, labels)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            saver = tf.train.Saver()

            ckpt = tf.train.get_checkpoint_state(os.path.dirname(chkpt_dir))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                LOGGER.info('Restored session from {}'.format(chkpt_dir))

            average_loss = 0.0
            total_correct_preds = 0
            n_processed = 0

            initial_step = model.global_step.eval()
            LOGGER.info('initial step = {}'.format(initial_step))

            try:
                for i in range(1000000000):
                    loss_batch, logits_batch, Y_batch = sess.run(
                        [model.loss, model.logits, labels])
                    n_processed += batch_size
                    average_loss += loss_batch
                    # argmax of logits equals argmax of softmax(logits), so
                    # accuracy can be computed in NumPy; building new TF ops
                    # inside this loop would grow the graph every iteration
                    correct_preds = np.equal(np.argmax(logits_batch, 1),
                                             np.argmax(Y_batch, 1))
                    total_correct_preds += np.sum(correct_preds)
                    LOGGER.info('  batch {} loss = {} for nproc {}'.format(
                        i, loss_batch, n_processed))
                    LOGGER.info("  total_corr_preds / nproc = {} / {}".format(
                        total_correct_preds, n_processed))
                    LOGGER.info("  Cumul. Accuracy {0}".format(
                        total_correct_preds / n_processed))
            except tf.errors.OutOfRangeError:
                LOGGER.info('Testing stopped - queue is empty.')
Code Example #6
def predict(classifier, data_files, hyper_pars):
    # predictions is a generator - evaluation is lazy
    predictions = classifier.predict(input_fn=lambda: make_fashion_iterators(
        data_files['test'],
        hyper_pars['batch_size'],
        tfrecord=hyper_pars['tfrecord']), )
    counter = 0
    for p in predictions:
        # TODO? - add a persistence mechanism for predictions
        print(p)
        counter += 1
        if counter > 10:
            break
Code Example #7
def train(train_file, batch_size, num_epochs, train_steps, model_dir,
          learning_rate):
    tf.reset_default_graph()
    chkpt_dir = model_dir + '/checkpoints'
    run_dest_dir = model_dir + '/%d' % time.time()

    with tf.Graph().as_default() as g:
        with tf.Session(graph=g) as sess:

            features, labels = make_fashion_iterators(train_file,
                                                      batch_size,
                                                      num_epochs,
                                                      shuffle=True)
            features = tf.reshape(features, [-1, 784])

            model = FashionMNISTLogReg()
            model.build_network(features, labels)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            writer = tf.summary.FileWriter(logdir=run_dest_dir,
                                           graph=sess.graph)
            saver = tf.train.Saver(save_relative_paths=True)

            ckpt = tf.train.get_checkpoint_state(os.path.dirname(chkpt_dir))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                LOGGER.info('Restored session from {}'.format(chkpt_dir))

            writer.add_graph(sess.graph)
            initial_step = model.global_step.eval()
            LOGGER.info('initial step = {}'.format(initial_step))

            for b_num in range(initial_step, initial_step + train_steps):
                _, loss_batch, summary_t = sess.run(
                    [model.optimizer, model.loss, model.train_summary_op])
                LOGGER.info(' Loss @step {}: {:5.1f}'.format(
                    b_num, loss_batch))
                # write the summary once per step; checkpoint every 5 steps
                writer.add_summary(summary_t, global_step=b_num)
                if (b_num + 1) % 5 == 0:
                    saver.save(sess, chkpt_dir, b_num)

        writer.close()
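
FashionMNISTLogReg is used in the same attribute-driven way as the
autoencoder. A minimal logistic-regression skeleton consistent with the
call sites (build_network(features, labels), logits, loss, optimizer,
global_step, train_summary_op) might look like this; the initializers and
the learning-rate default are assumptions.

class FashionMNISTLogReg:
    # Hypothetical skeleton; the real implementation is not shown here.
    def __init__(self, learning_rate=0.01):
        self.learning_rate = learning_rate

    def build_network(self, features, labels):
        self.global_step = tf.train.get_or_create_global_step()
        W = tf.get_variable('weights', [784, 10],
                            initializer=tf.glorot_uniform_initializer())
        b = tf.get_variable('bias', [10],
                            initializer=tf.zeros_initializer())
        # features arrive already flattened to [-1, 784] by the callers
        self.logits = tf.matmul(features, W) + b
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(
                labels=labels, logits=self.logits))
        self.optimizer = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(
                self.loss, global_step=self.global_step)
        self.train_summary_op = tf.summary.scalar('loss', self.loss)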
Code Example #8
def test_graph_one_shot_iterator_read(hdf5_file=TFILE,
                                      batch_size=25,
                                      num_epochs=1):
    feats, labs = make_fashion_iterators(hdf5_file, batch_size, num_epochs)
    with tf.Session() as sess:
        total_batches = 0
        total_examples = 0
        try:
            while True:
                fs, ls = sess.run([feats, labs])
                logger.info('{}, {}, {}, {}'.format(fs.shape, fs.dtype,
                                                    ls.shape, ls.dtype))
                total_batches += 1
                total_examples += ls.shape[0]
        except tf.errors.OutOfRangeError:
            logger.info(
                'end of dataset at total_batches={}'.format(total_batches))
        except Exception as e:
            logger.error(e)
    logger.info('saw {} total examples'.format(total_examples))
Code Example #9
def encode(data_file, model_dir, encoded_file_name, n_encoded, n_labels):
    tf.reset_default_graph()
    chkpt_dir = model_dir + '/checkpoints'

    with tf.Graph().as_default() as g:
        with tf.Session(graph=g) as sess:

            features, labels = make_fashion_iterators(
                data_file, batch_size=50, num_epochs=1, shuffle=False
            )
            model = FashionAutoencoder()
            model.build_network(features)

            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())

            saver = tf.train.Saver()

            ckpt = tf.train.get_checkpoint_state(os.path.dirname(chkpt_dir))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                LOGGER.info('Restored session from {}'.format(chkpt_dir))

            initial_step = model.global_step.eval()
            LOGGER.info('initial step = {}'.format(initial_step))

            f = setup_hdf5(encoded_file_name, n_encoded, n_labels)

            try:
                for i in range(1000000000):
                    encoded_batch, labels_batch = sess.run([
                        model.encoded, labels
                    ])
                    add_batch_to_hdf5(f, encoded_batch, labels_batch)
            except tf.errors.OutOfRangeError:
                LOGGER.info('Encoding stopped - queue is empty.')

            f.close()
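
encode() relies on two HDF5 helpers, setup_hdf5 and add_batch_to_hdf5, that
are not shown. A minimal h5py-based sketch follows; the dataset names
('encoded', 'labels'), dtypes, and compression settings are guesses.

import h5py

def setup_hdf5(encoded_file_name, n_encoded, n_labels):
    # Hypothetical helper: create a file with two growable datasets
    f = h5py.File(encoded_file_name, 'w')
    f.create_dataset('encoded', shape=(0, n_encoded),
                     maxshape=(None, n_encoded), dtype='float32',
                     compression='gzip')
    f.create_dataset('labels', shape=(0, n_labels),
                     maxshape=(None, n_labels), dtype='float32',
                     compression='gzip')
    return f

def add_batch_to_hdf5(f, encoded_batch, labels_batch):
    # Hypothetical helper: extend each dataset and append the batch rows
    n = encoded_batch.shape[0]
    for name, batch in (('encoded', encoded_batch), ('labels', labels_batch)):
        dset = f[name]
        dset.resize(dset.shape[0] + n, axis=0)
        dset[-n:] = batch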