Example #1
0
    def data(self):
        """Load the dataset selected by ``FLAGS.data``.

        Returns:
            Tuple ``(train_data, train_labels, vali_data, vali_labels)``.
            For 'cifar10'/'cifar100' the images are standardized with the
            training-set mean/std and the training images are padded; for
            'svhn' the labels are shifted from [1, 10] to [0, 9] and only
            the validation images are scaled to [0, 1] here (training
            images are divided by 255 at feed time to save memory).

        Raises:
            ValueError: if ``FLAGS.data`` names an unknown dataset.
        """

        def _standardize_and_pad(train, vali):
            # Standardize both splits using statistics from the training
            # split only, then pad the training images (used by the
            # random-crop augmentation downstream).
            mean = np.mean(train, axis=(0, 1, 2))
            std = np.std(train, axis=(0, 1, 2))
            train = (train - mean) / std
            vali = (vali - mean) / std
            train = preprocessing.pad_images(train, FLAGS.padding_size)
            return train, vali

        if FLAGS.data == 'cifar10':
            import cifar10_input
            cifar10_input.maybe_download_and_extract()

            all_data, all_labels = cifar10_input.read_train_data()
            vali_data, vali_labels = cifar10_input.read_validation_data()
            all_data, vali_data = _standardize_and_pad(all_data, vali_data)

        elif FLAGS.data == 'cifar100':
            import cifar100_input
            cifar100_input.maybe_download_and_extract()

            all_data, all_labels = cifar100_input.read_train_data()
            vali_data, vali_labels = cifar100_input.read_validation_data()
            all_data, vali_data = _standardize_and_pad(all_data, vali_data)

        elif FLAGS.data == 'svhn':
            import svhn_input
            svhn_input.maybe_download()

            all_data, all_labels = svhn_input.read_train_data()
            vali_data, vali_labels = svhn_input.read_test_data()

            # label is [1,...,10], so subtract one to make [0,...,9]
            all_labels -= 1
            vali_labels -= 1

            # To save memory, train_data will be divided by 255 when it feeds.
            vali_data = vali_data.astype(np.float32) / 255.
        else:
            # Raise instead of `assert False`: asserts are stripped under -O.
            raise ValueError('unknown data: %s' % FLAGS.data)

        return all_data, all_labels, vali_data, vali_labels
def main():
    """Prepare data and a clean model directory, then train the estimator."""
    # Ensure the dataset is present before building the input pipeline.
    cifar10_input.maybe_download_and_extract(data_dir=args.data_dir)

    # Start from an empty model_dir so stale checkpoints cannot interfere.
    model_dir = args.model_dir
    if tf.gfile.Exists(model_dir):
        tf.gfile.DeleteRecursively(model_dir)
    tf.gfile.MakeDirs(model_dir)

    def train_input_fn():
        # Estimator expects a zero-argument callable producing input tensors.
        return input_fn(args.data_dir, args.batch_size)

    classifier = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=model_dir,
        params={'batch_size': args.batch_size})
    classifier.train(input_fn=train_input_fn, steps=10000)
def main(argv=None):  # pylint: disable=unused-argument
    """Seed all RNGs, fetch CIFAR-10, reset train_dir, and start training."""
    import random
    # Fix every random seed so repeated runs produce identical results.
    for seed in (random.seed, tf.set_random_seed, np.random.seed):
        seed(0)

    cifar10_input.maybe_download_and_extract(
        os.path.join(os.getcwd(), 'tmp/cifar10_data'))

    # Recreate the training directory from scratch.
    train_dir = FLAGS.train_dir
    if tf.gfile.Exists(train_dir):
        tf.gfile.DeleteRecursively(train_dir)
    tf.gfile.MakeDirs(train_dir)

    train()
Example #4
0
def main(argv=None):  # pylint: disable=unused-argument
    """Download CIFAR-10 if needed, recreate train_dir, and run training."""
    cifar10_input.maybe_download_and_extract()
    # Wipe any previous run so training starts from a clean directory.
    train_dir = FLAGS.train_dir
    if tf.gfile.Exists(train_dir):
        tf.gfile.DeleteRecursively(train_dir)
    tf.gfile.MakeDirs(train_dir)
    train()
Example #5
0
import cifar10_input
cifar10_input.maybe_download_and_extract()

import tensorflow as tf
from tensorflow.python import control_flow_ops
import numpy as np
import time, os

# Architecture: widths of the two hidden layers.
n_hidden_1 = 256
n_hidden_2 = 256

# Training hyperparameters.
learning_rate = 0.01
training_epochs = 1000
batch_size = 128
display_step = 1  # presumably a logging interval — usage not shown in this chunk

def inputs(eval_data=True):
  """Return batched CIFAR-10 input tensors without training distortions."""
  return cifar10_input.inputs(
      eval_data=eval_data,
      data_dir=os.path.join('data/cifar10_data', 'cifar-10-batches-bin'),
      batch_size=batch_size)

def distorted_inputs():
  """Return batched CIFAR-10 training inputs with random distortions applied."""
  binaries_dir = os.path.join('data/cifar10_data', 'cifar-10-batches-bin')
  return cifar10_input.distorted_inputs(
      data_dir=binaries_dir, batch_size=batch_size)

def conv_batch_norm(x, n_out, phase_train):
    beta_init = tf.constant_initializer(value=0.0, dtype=tf.float32)
    gamma_init = tf.constant_initializer(value=1.0, dtype=tf.float32)
Example #6
0
def maybe_download_and_extract():
    """Fetch/extract the CIFAR-10 archive if needed, delegating to cifar10_input."""
    # Thin wrapper: propagate whatever the dataset module returns.
    result = cifar10_input.maybe_download_and_extract()
    return result
Example #7
0
def main(_):
    """Entry point: require the --name flag, fetch the dataset, then train."""
    if FLAGS.name is None:
        print("Usage: train.py --name=NAME")
        # BUG FIX: use sys.exit instead of the exit() builtin — exit() is
        # injected by the interactive `site` module and may be missing when
        # the script runs with -S or is frozen.
        import sys
        sys.exit(1)
    data_input.maybe_download_and_extract()
    run_training()
import os
sys.path.insert(1, os.path.join(sys.path[0], '..'))

from gcnn_lib.grid_graph import grid_graph
from gcnn_lib.coarsening import coarsen
from gcnn_lib.coarsening import lmaxX
from gcnn_lib.coarsening import perm_data
from gcnn_lib.coarsening import lmaxX
from gcnn_lib.coarsening import rescale_L

import cifar10_input

# Root directory where the CIFAR-10 archive is downloaded and extracted.
DATA_DIR = "./data"
# mnist = input_data.read_data_sets("data/", one_hot=False)

# Make sure the dataset binaries exist locally before building input ops.
cifar10_input.maybe_download_and_extract(DATA_DIR)

# train_data = mnist.train.images.astype(np.float32)
# Read train + validation examples as one batch (eval_data=False); the
# validation portion is sliced off from the tail further below.
train_data, train_labels = cifar10_input.inputs(
    False, os.path.join(DATA_DIR, 'cifar-10-batches-bin'),
    cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN +
    cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_VAL)

# Read the evaluation split (eval_data=True).
test_data, test_labels = cifar10_input.inputs(
    True, os.path.join(DATA_DIR, 'cifar-10-batches-bin'),
    cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL)
####### test_data = tf.contrib.layers.flatten(test_data)
val_data = tf.slice(train_data,
                    [cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN, 0, 0, 0],
                    [-1, -1, -1, -1])
val_labels = tf.slice(train_labels,
Example #9
0
import os
import shutil
import sys
from timeit import default_timer as timer
import argparse

import tensorflow as tf
import numpy as np
import math

from model import Model
import cifar10_input
import tiny_input
from pgd_attack import LinfPGDAttack

# Make sure the CIFAR-10 data is available before anything else runs.
cifar10_input.maybe_download_and_extract('../DATA')

num_classes = 10  # CIFAR-10

parser = argparse.ArgumentParser()
parser.add_argument('--oat', help='apply OAT', action='store_true')
parser.add_argument('--alpha',
                    default=1.0,
                    type=float,
                    help='hyperparameter alpha for OAT')
parser.add_argument('--suffix', help='suffix')
args = parser.parse_args()

# BUG FIX: `json` was used below without ever being imported, raising a
# NameError; import it before loading the configuration file.
import json

with open('config.json') as config_file:
    config = json.load(config_file)
Example #10
0
def main(_):
  """Entry point: require the --name flag, fetch the dataset, then train."""
  if FLAGS.name is None:
    print("Usage: train.py --name=NAME")
    # BUG FIX: use sys.exit instead of the exit() builtin — exit() is
    # injected by the interactive `site` module and may be missing when
    # the script runs with -S or is frozen.
    import sys
    sys.exit(1)
  data_input.maybe_download_and_extract()
  run_training()
Example #11
0
def main(argv=None):  # pylint: disable=unused-argument
  """Fetch CIFAR-10 if needed, recreate a clean eval_dir, and evaluate."""
  cifar10_input.maybe_download_and_extract(FLAGS.data_dir)
  # Remove results from any previous evaluation run.
  eval_dir = FLAGS.eval_dir
  if tf.gfile.Exists(eval_dir):
    tf.gfile.DeleteRecursively(eval_dir)
  tf.gfile.MakeDirs(eval_dir)
  evaluate()