Example #1
def inputs(eval_data):
    """Construct input for ShapeOverlap evaluation using the Reader ops.
    Args:
      eval_data: bool, indicating if one should use the train or eval data set.
    Returns:
      locks, keys: Lock and key images. 4D tensors of [batch_size, IMAGE_SIZE, IMAGE_SIZE, depth] size.
      labels: Labels. 1D tensor of [batch_size] size.
    Raises:
      ValueError: If no data_dir is supplied.
    """
    # Fail fast if no data directory was supplied, before trying to download into it.
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')

    maybe_download_and_extract(FLAGS.data_dir, FLAGS.DATA_URL)

    with tf.variable_scope('READ'):
        locks, keys, labels = overlap_input.inputs(eval_data=eval_data,
                                                   data_dir=FLAGS.data_dir,
                                                   batch_size=FLAGS.batch_size)

        if FLAGS.use_fp16:
            locks = tf.cast(locks, tf.float16)
            keys = tf.cast(keys, tf.float16)
            labels = tf.cast(labels, tf.float16)

        return locks, keys, labels
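For reference, a minimal way to consume these tensors in a TF 1.x session is sketched below. The queue-runner calls are standard TensorFlow 1.x; everything outside the inputs function above is illustrative and not part of the scraped file.

import tensorflow as tf

# Build the evaluation input pipeline defined above.
locks, keys, labels = inputs(eval_data=True)

with tf.Session() as sess:
    # The Reader ops are fed by queue runners, so start them first.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    # Fetch one matched batch of locks, keys, and labels.
    lock_batch, key_batch, label_batch = sess.run([locks, keys, labels])
    print(lock_batch.shape, key_batch.shape, label_batch.shape)

    # Shut the input pipeline down cleanly.
    coord.request_stop()
    coord.join(threads)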
Example #2
File: test.py Project: ag8/magic
    if reshape:
        # Reshape the images to a long vector, if necessary
        images = tf.reshape(images,
                            shape=[FLAGS.BATCH_SIZE,
                                   FLAGS.IMAGE_SIZE * FLAGS.IMAGE_SIZE * FLAGS.NUM_LAYERS])

    # Return the images and the labels.
    return images, labels


# Get input data
images_batch, labels_batch = overlap_input.inputs(normalize=True,
                                                  reshape=False,
                                                  rotation=FLAGS.ROTATE)
n_samples = FLAGS.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN

with tf.Session() as sess:
    # Start populating the filename queue.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord, sess=sess)

    for i in range(2):
        print("Getting batch of images and labels.")

        # Fetch both tensors in a single run so the images and labels
        # come from the same dequeued batch.
        current_image_batch, current_labels_batch = sess.run([images_batch, labels_batch])

        print("labels batch: ")
Example #3
File: main.py Project: ag8/magic
import datetime
import os

import tensorflow as tf

import overlap_input
from constants import FLAGS
from utils import *
from vae import VariationalAutoencoder

# Load MNIST data in a format suited for tensorflow.
# The script input_data is available under this URL:
# https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/mnist/input_data.py
# import input_data
# mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# n_samples = mnist.train.num_examples

print("Running experiment " + os.path.basename(os.path.dirname(os.path.abspath(__file__))) + "!")


# Get input data
images_batch, labels_batch = overlap_input.inputs(reshape=True)
n_samples = FLAGS.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN


def train(network_architecture, sess, learning_rate=0.001,
          batch_size=FLAGS.BATCH_SIZE, training_epochs=10,
          step_display_step=5, epoch_display_step=5):

    vae = VariationalAutoencoder(network_architecture, sess=sess,
                                 transfer_fct=tf.nn.tanh,  # FIXME: Fix numerical issues instead of just using tanh
                                 learning_rate=learning_rate,
                                 batch_size=batch_size)
    try:
        loop_start = datetime.datetime.now()

        # Training cycle
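The snippet cuts off at the training cycle itself. Purely as an illustration, and assuming the VariationalAutoencoder exposes a partial_fit(batch) step that returns the cost for that batch (a hypothetical method name in the style of the classic TensorFlow VAE tutorial, not confirmed by this excerpt), the cycle could look like:

        # Hypothetical training cycle; vae.partial_fit is an assumed API.
        for epoch in range(training_epochs):
            avg_cost = 0.0
            total_batch = int(n_samples / batch_size)
            for i in range(total_batch):
                # The input pipeline already batches, so just pull the next batch.
                batch_xs = sess.run(images_batch)
                batch_cost = vae.partial_fit(batch_xs)
                avg_cost += batch_cost / n_samples * batch_size
            if epoch % epoch_display_step == 0:
                print("Epoch %04d: cost=%.9f" % (epoch + 1, avg_cost))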
Example #4
File: main.py Project: ag8/magic
import os
import sys

import tensorflow as tf

from constants import FLAGS
from vae import VariationalAutoencoder

sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../"))
from affinity.models.magic_autoencoder.utils.utils import *

import overlap_input

# Load MNIST data in a format suited for tensorflow.
# The script input_data is available under this URL:
# https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/mnist/input_data.py
# import input_data
# mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# n_samples = mnist.train.num_examples

# Get input data
images_batch, labels_batch = overlap_input.inputs(normalize=True, reshape=False)
n_samples = FLAGS.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN


def train(sess):
    logits = net.net(images_batch, num_fully_connected_layers=1)

    # Initialize all variables
    sess.run(tf.global_variables_initializer())

    labels = tf.cast(labels_batch, dtype=tf.float32)
    cost = tf.reduce_mean(tf.squared_difference(logits, labels))


    # Initialize all variables
    sess.run(tf.global_variables_initializer())
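The excerpt defines cost but stops before any optimizer is attached to it. A minimal training step (an illustrative sketch with an assumed learning rate, not the project's code) could be:

    # Minimize the squared-error cost with plain gradient descent.
    train_op = tf.train.GradientDescentOptimizer(learning_rate=0.001).minimize(cost)

    # One optimization step on the current input batch.
    _, current_cost = sess.run([train_op, cost])
    print("cost: %f" % current_cost)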
Example #5
import matplotlib.pyplot as plt

import tensorflow as tf

import overlap_input
from constants import FLAGS
from vae import VariationalAutoencoder

# Load MNIST data in a format suited for tensorflow.
# The script input_data is available under this URL:
# https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/mnist/input_data.py
# import input_data
# mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# n_samples = mnist.train.num_examples

# Get input data
images_batch, labels_batch = overlap_input.inputs(normalize=True, reshape=True, rotation=True)
n_samples = FLAGS.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN


class TrainingException(Exception):
    pass


def train(network_architecture, sess, learning_rate=0.001,
          batch_size=FLAGS.BATCH_SIZE, training_epochs=10, display_step=5):
    vae = VariationalAutoencoder(network_architecture, sess=sess,
                                 transfer_fct=tf.nn.softplus,  # FIXME: Fix numerical issues instead of just using tanh
                                 learning_rate=learning_rate,
                                 batch_size=batch_size)

    try:
Example #6
File: main.py Project: ag8/magic
import pickle
from random import randint

import numpy as np
import scipy.misc

import tensorflow as tf

import overlap_input

from constants import FLAGS

# Get examples
images_batch, labels_batch = overlap_input.inputs()


def brilliant_neural_network(images, labels):

    # logits = tf.random_uniform(shape=[], minval=0, maxval=10000, dtype=tf.int32)

    logits = tf.constant(0)

    return logits


# Beautiful neural network
logits = brilliant_neural_network(images_batch, labels_batch)


# Loss function
def get_loss(logits, labels):
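The excerpt ends at the signature of get_loss. Purely as an illustration (an assumption, not the project's implementation), a loss mirroring the squared-difference cost from Example #4 could be:

    # Hypothetical body: mean squared difference between logits and labels.
    labels = tf.cast(labels, tf.float32)
    logits = tf.cast(logits, tf.float32)
    return tf.reduce_mean(tf.squared_difference(logits, labels))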
Example #7
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import scipy.misc

import tensorflow as tf

import tensorflow.contrib.slim as slim

import overlap_input

images, labels = overlap_input.inputs(reshape=True)

print("Woo!")