예제 #1
0
def mnist(layers,  # pylint: disable=invalid-name
          activation="sigmoid",
          batch_size=128,
          mode="train"):
  """Builds an MNIST classification problem using a multi-layer perceptron.

  Args:
    layers: Iterable of sizes handed to `nn.MLP`; a final size of 10 is
      appended to it.
    activation: Name of the nonlinearity, either "sigmoid" or "relu".
    batch_size: Number of random example indices drawn per call to the
      returned function.
    mode: Name of the attribute read from the loaded dataset object
      (defaults to "train").

  Returns:
    A zero-argument callable that gathers a random batch of images and labels,
    runs the network on the images and returns `_xent_loss(output, labels)`.

  Raises:
    ValueError: If `activation` is neither "sigmoid" nor "relu".
  """
  # Reject unknown activations up front, before any graph construction.
  if activation not in ("sigmoid", "relu"):
    raise ValueError("{} activation not supported".format(activation))
  nonlinearity = tf.sigmoid if activation == "sigmoid" else tf.nn.relu

  # Data: the whole split is embedded in the graph as constants.
  split = getattr(mnist_dataset.load_mnist(), mode)
  all_images = tf.reshape(
      tf.constant(split.images, dtype=tf.float32, name="MNIST_images"),
      [-1, 28, 28, 1])
  all_labels = tf.constant(split.labels, dtype=tf.int64, name="MNIST_labels")

  # Network: flatten each image, then apply the MLP with a 10-way output.
  output_sizes = list(layers) + [10]
  classifier = nn.MLP(output_sizes,
                      activation=nonlinearity,
                      initializers=_nn_initializers)
  model = nn.Sequential([nn.BatchFlatten(), classifier])

  def build():
    # Indices are drawn uniformly from [0, num_examples), so a batch may
    # contain repeated examples.
    sample = tf.random_uniform(
        [batch_size], minval=0, maxval=split.num_examples, dtype=tf.int64)
    sampled_images = tf.gather(all_images, sample)
    sampled_labels = tf.gather(all_labels, sample)
    logits = model(sampled_images)
    return _xent_loss(logits, sampled_labels)

  return build
예제 #2
0
def cifar10(
        path,  # pylint: disable=invalid-name
        conv_channels=None,
        linear_layers=None,
        batch_norm=True,
        batch_size=128,
        num_threads=4,
        min_queue_examples=1000,
        mode="train"):
    """Cifar10 classification with a convolutional network.

    Args:
      path: Directory handed to `_maybe_download_cifar10`; the binary record
        files are read from the `CIFAR10_FOLDER` sub-directory of it.
      conv_channels: Passed to `nn.ConvNet2D` as `output_channels`.
        NOTE(review): the default of None is forwarded unchanged and is
        presumably not usable -- confirm against Sonnet and callers.
      linear_layers: Iterable of sizes handed to `nn.MLP`; a final size of 10
        is appended. None is treated as an empty iterable.
      batch_norm: If true, the conv net is built with `use_batch_norm=True`
        and an `nn.BatchNorm` is applied before each ReLU of the MLP.
      batch_size: Number of examples dequeued per call to the returned
        function.
      num_threads: Number of enqueue ops registered with the queue runner.
      min_queue_examples: `min_after_dequeue` of the shuffle queue; the queue
        capacity is this value plus `3 * batch_size`.
      mode: Either "train" (data_batch_1..5) or "test" (test_batch).

    Returns:
      A zero-argument callable that dequeues a batch, runs the network on the
      images and returns `_xent_loss(output, label_batch)`.

    Raises:
      ValueError: If `mode` is neither "train" nor "test". This is raised
        before any download is attempted.
    """
    # Resolve the record files first: this is pure string work and validates
    # `mode`, so a bad mode fails fast instead of after a download.
    if mode == "train":
        filenames = [
            os.path.join(path, CIFAR10_FOLDER, "data_batch_{}.bin".format(i))
            for i in range(1, 6)
        ]
    elif mode == "test":
        # The test batch is looked up in the same folder as the training
        # batches.
        filenames = [os.path.join(path, CIFAR10_FOLDER, "test_batch.bin")]
    else:
        raise ValueError("Mode {} not recognised".format(mode))

    # Data.
    _maybe_download_cifar10(path)

    # Each fixed-length record is one label byte followed by the image bytes
    # in depth x height x width order.
    depth = 3
    height = 32
    width = 32
    label_bytes = 1
    image_bytes = depth * height * width
    record_length = label_bytes + image_bytes
    reader = tf.FixedLengthRecordReader(record_bytes=record_length)
    _, record = reader.read(tf.train.string_input_producer(filenames))
    record_data = tf.decode_raw(record, tf.uint8)

    label = tf.cast(tf.slice(record_data, [0], [label_bytes]), tf.int32)
    raw_image = tf.slice(record_data, [label_bytes], [image_bytes])
    image = tf.cast(tf.reshape(raw_image, [depth, height, width]), tf.float32)
    # height x width x depth.
    image = tf.transpose(image, [1, 2, 0])
    # Scale the uint8 pixel values by 1/255.
    image = tf.div(image, 255)

    queue = tf.RandomShuffleQueue(
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples,
        dtypes=[tf.float32, tf.int32],
        shapes=[image.get_shape(), label.get_shape()])
    enqueue_ops = [queue.enqueue([image, label]) for _ in range(num_threads)]
    tf.train.add_queue_runner(tf.train.QueueRunner(queue, enqueue_ops))

    # Network.
    def _conv_activation(x):  # pylint: disable=invalid-name
        # ReLU followed by 2x2 max-pooling with stride 2.
        return tf.nn.max_pool(tf.nn.relu(x),
                              ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1],
                              padding="SAME")

    conv = nn.ConvNet2D(output_channels=conv_channels,
                        kernel_shapes=[5],
                        strides=[1],
                        paddings=[nn.SAME],
                        activation=_conv_activation,
                        activate_final=True,
                        initializers=_nn_initializers,
                        use_batch_norm=batch_norm)

    if batch_norm:
        # A new BatchNorm module is constructed each time the activation is
        # invoked.
        linear_activation = lambda x: tf.nn.relu(nn.BatchNorm()(x))
    else:
        linear_activation = tf.nn.relu

    # The default of None would otherwise make list() raise a TypeError.
    hidden_sizes = [] if linear_layers is None else list(linear_layers)
    mlp = nn.MLP(hidden_sizes + [10],
                 activation=linear_activation,
                 initializers=_nn_initializers)
    network = nn.Sequential([conv, nn.BatchFlatten(), mlp])

    def build():
        image_batch, label_batch = queue.dequeue_many(batch_size)
        # Labels are dequeued with a trailing dimension of size label_bytes
        # (1); flatten them to a vector of length batch_size.
        label_batch = tf.reshape(label_batch, [batch_size])

        output = network(image_batch)
        return _xent_loss(output, label_batch)

    return build