import numpy as np
import pandas as pd
import tensorflow as tf

# Project-local modules (assumed importable at the top level, as used below).
import consts
import dataset
import paths


def train_dev_split(sess,
                    tf_records_path,
                    dev_set_size=2000,
                    batch_size=64,
                    train_sample_size=2000):
    ds_, filename = dataset.features_dataset()

    # Shuffle once with a fixed seed and no per-epoch reshuffle, so the
    # skip()/take() calls below carve out a stable train/dev split. With the
    # defaults, each iterator (and each pass of repeat()) would draw a fresh
    # shuffle, leaking dev examples into the training stream.
    ds = ds_.shuffle(buffer_size=20000, seed=42,
                     reshuffle_each_iteration=False)

    # Training stream: everything past the dev slice, repeated indefinitely
    # and re-shuffled before batching.
    train_ds = ds.skip(dev_set_size).repeat()
    train_ds_iter = train_ds.shuffle(buffer_size=20000) \
        .batch(batch_size) \
        .make_initializable_iterator()

    # A fixed-size random sample of the training split, served as a single
    # batch (useful for evaluating training-set metrics without a full pass).
    train_sample_ds = ds.skip(dev_set_size)
    train_sample_ds_iter = train_sample_ds.shuffle(buffer_size=20000) \
        .take(train_sample_size) \
        .batch(train_sample_size) \
        .make_initializable_iterator()

    # Dev split: the first dev_set_size shuffled examples, as one batch.
    dev_ds_iter = ds.take(dev_set_size).batch(
        dev_set_size).make_initializable_iterator()

    sess.run(train_ds_iter.initializer, feed_dict={filename: tf_records_path})
    sess.run(dev_ds_iter.initializer, feed_dict={filename: tf_records_path})
    sess.run(train_sample_ds_iter.initializer,
             feed_dict={filename: tf_records_path})

    return (train_ds_iter.get_next(), dev_ds_iter.get_next(),
            train_sample_ds_iter.get_next())
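
# A minimal usage sketch (not part of the original repo): driving
# train_dev_split in a TF 1.x training loop. paths.TRAIN_TF_RECORDS and the
# feature keys 'inception_output' / 'label' are assumptions carried over from
# the other functions in this file.
def _example_training_loop():
    with tf.Session() as sess:
        train_next, dev_next, train_sample_next = train_dev_split(
            sess, paths.TRAIN_TF_RECORDS)
        for _ in range(100):
            batch = sess.run(train_next)  # dict of feature arrays
            features = batch['inception_output']
            labels = batch['label']
            # ... run an optimizer step on (features, labels) here ...
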
def infer_train(model_name, output_probs, x):
    BATCH_SIZE = 20000

    # Using the session as a context manager installs it as the default
    # session (needed by the .run() calls below) and closes it on exit;
    # Session().as_default() alone would leave it open.
    with tf.Session() as sess:
        ds, filename = dataset.features_dataset()
        ds_iter = ds.batch(BATCH_SIZE).make_initializable_iterator()
        sess.run(ds_iter.initializer, feed_dict={filename: paths.TRAIN_TF_RECORDS})

        tf.global_variables_initializer().run()

        saver = tf.train.Saver()
        # Resolve the newest checkpoint from the model's custom
        # checkpoint-state file instead of parsing the file by hand.
        last_checkpoint = tf.train.latest_checkpoint(
            paths.CHECKPOINTS_DIR, latest_filename=model_name + '_latest')
        saver.restore(sess, last_checkpoint)

        _, one_hot_decoder = dataset.one_hot_label_encoder()

        # Decode the identity matrix to recover the breed name for each
        # one-hot class index (not used further in this function).
        breeds = one_hot_decoder(np.identity(consts.CLASSES_COUNT))
        agg_test_df = None

        # Build the get_next op once, outside the loop: calling
        # ds_iter.get_next() on every iteration would keep adding new ops to
        # the graph and leak memory.
        next_batch = ds_iter.get_next()

        try:
            while True:
                test_batch = sess.run(next_batch)

                inception_output = test_batch['inception_output']
                labels = test_batch['label']

                # The classifier takes examples as columns, hence the
                # transposes; comparing against the per-example max turns the
                # probability matrix into one-hot predictions.
                pred_probs = sess.run(output_probs, feed_dict={x: inception_output.T})
                pred_probs_max = pred_probs >= np.max(pred_probs, axis=0)
                pred_breeds = one_hot_decoder(pred_probs_max.T)

                test_df = pd.DataFrame(data={'pred': pred_breeds, 'actual': labels})

                if agg_test_df is None:
                    agg_test_df = test_df
                else:
                    # DataFrame.append is deprecated in recent pandas;
                    # pd.concat is the supported way to stack the batches.
                    agg_test_df = pd.concat([agg_test_df, test_df])

        except tf.errors.OutOfRangeError:
            print('End of the dataset')

        print(agg_test_df.head(10))

        agg_test_df.to_csv(paths.TRAIN_CONFUSION, index_label='id', float_format='%.17f')

        print('predictions saved to %s' % paths.TRAIN_CONFUSION)
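
# A minimal usage sketch (not part of the original repo): infer_train assumes
# the inference graph already exists. build_model below is hypothetical; the
# hard requirements, implied by the feed_dict above, are that x accepts the
# transposed inception features (features along rows, examples along columns)
# and that output_probs holds class probabilities with classes along axis 0.
#
#   x = tf.placeholder(tf.float32, shape=[None, None])
#   output_probs = build_model(x)  # hypothetical graph-building helper
#   infer_train('my_model', output_probs, x)
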
def get_data_iter(sess_, tf_records_paths_, buffer_size=20, batch_size=64):
    # Convenience wrapper: an endlessly repeating, shuffled, batched stream
    # over the given TFRecord files, with the iterator already initialized
    # against sess_.
    ds_, file_names_ = dataset.features_dataset()
    ds_iter = ds_.shuffle(buffer_size).repeat().batch(
        batch_size).make_initializable_iterator()
    sess_.run(ds_iter.initializer, feed_dict={file_names_: tf_records_paths_})
    return ds_iter.get_next()
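
# A minimal usage sketch (assumed): consuming batches from get_data_iter. The
# feature keys mirror those used in infer_train above; passing a list assumes
# the file-names placeholder accepts multiple record paths, as the plural
# parameter name suggests.
def _example_consume_batches():
    with tf.Session() as sess:
        next_batch = get_data_iter(sess, [paths.TRAIN_TF_RECORDS])
        batch = sess.run(next_batch)
        return batch['inception_output'], batch['label']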