예제 #1
0
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import time
from datetime import timedelta

import cifar

# Ensure the CIFAR dataset is available locally; the exact behaviour is defined
# in the project-local ``cifar`` module (presumably a no-op if already
# downloaded -- TODO confirm).
cifar.download()

print(cifar.load_class_names())

# Each loader returns a 3-tuple: images, class numbers and labels (presumably
# one-hot vectors, given the [None, 10] ``y_true`` placeholder below -- confirm).
train_img, train_cls, train_labels = cifar.load_training_data()
test_img, test_cls, test_labels = cifar.load_test_data()

print('Training set:', len(train_img), 'Testing set:', len(test_img))
# TF1-style graph inputs: a variable-size batch of 32x32 images with 3 channels
# (presumably RGB) and a variable-size batch of 10-element label vectors.
x = tf.placeholder(tf.float32, [None, 32, 32, 3])
y_true = tf.placeholder(tf.float32, [None, 10])


def conv_layer(input, size_in, size_out, use_pooling=True):
    """Build a 3x3 convolution + ReLU layer, optionally followed by max-pooling.

    Args:
        input: 4-D input tensor; presumably NHWC (batch, height, width,
            channels) like the ``x`` placeholder used in this file -- TODO
            confirm for other callers.
        size_in: Number of input channels (third dimension of the kernel).
        size_out: Number of output channels / filters.
        use_pooling: If True, apply 2x2 max-pooling with stride 2, which
            halves the spatial dimensions (rounded up, since padding='SAME').

    Returns:
        The layer's output tensor. (Previously the function fell off the end
        and returned None, discarding the computed output.)
    """
    # Kernel layout for tf.nn.conv2d: [height, width, in_channels, out_channels].
    w = tf.Variable(tf.truncated_normal([3, 3, size_in, size_out], stddev=0.1))
    # Small positive bias, presumably so ReLU units start out active.
    b = tf.Variable(tf.constant(0.1, shape=[size_out]))
    # Stride 1 with SAME padding keeps the spatial size unchanged.
    conv = tf.nn.conv2d(input, w, strides=[1, 1, 1, 1], padding='SAME')
    y = tf.nn.relu(conv + b)

    if use_pooling:
        y = tf.nn.max_pool(y,
                           ksize=[1, 2, 2, 1],
                           strides=[1, 2, 2, 1],
                           padding='SAME')

    return y
예제 #2
0
 def prepare_data(self):
     """Load the CIFAR train and test splits.

     Returns:
         ``(train_images, train_labels, test_images, test_labels)``. The labels
         are the third element returned by each ``cifar`` loader (presumably
         one-hot vectors -- confirm); the middle element is discarded.
     """
     # Loaders return 3-tuples; the middle item (presumably class ids) is unused.
     train_x, _, train_y = cifar.load_training_data()
     test_x, _, test_y = cifar.load_test_data()
     return train_x, train_y, test_x, test_y
예제 #3
0
# Oświadczam że kod napisałem samodzielnie - Konrad Kalita

import numpy as np
import tensorflow as tf
from tqdm import tqdm
import cifar
import math

# This ``cifar`` loader returns (data, labels) pairs for each split.
train_data, train_labels = cifar.load_training_data()
test_data, test_labels = cifar.load_test_data()
# Keep only the first 45000 training examples (presumably the remainder is
# held out elsewhere, e.g. for validation -- confirm).
train_data, train_labels = train_data[:45000], train_labels[:45000]
print(train_data.shape, train_labels.shape)
print(test_data.shape, test_labels.shape)


def permute(data, labels):
    """Shuffle ``data`` and ``labels`` together using one random permutation.

    Both arrays are indexed by the same permutation of ``range(data.shape[0])``,
    so row ``i`` of ``data`` stays paired with entry ``i`` of ``labels``.
    Uses NumPy's global random state.

    Returns:
        ``(shuffled_data, shuffled_labels)``.
    """
    row_count = data.shape[0]
    order = np.random.permutation(row_count)
    shuffled_data = data[order]
    shuffled_labels = labels[order]
    return shuffled_data, shuffled_labels


def get_batch(data, labels, size):
    """Yield consecutive mini-batches of ``data`` and ``labels``.

    Args:
        data: Array supporting ``.shape`` and slicing (e.g. a NumPy array).
        labels: Sliceable sequence aligned with ``data`` along the first axis.
        size: Batch size; must be a positive integer.

    Yields:
        ``(batch_index, (data_batch, labels_batch))`` where ``batch_index`` is
        an int counting from 0. The last batch may be shorter than ``size``.

    Raises:
        ValueError: If ``size`` is not positive (raised on first iteration,
            because this is a generator).
    """
    if size <= 0:
        raise ValueError("batch size must be positive, got {!r}".format(size))
    for start in range(0, data.shape[0], size):
        # Floor division: ``start / size`` would give a float index on Python 3.
        yield start // size, (data[start:start + size], labels[start:start + size])


def crop(imgs, train=False):
    imgs = imgs.reshape(imgs.shape[0], 32, 32, 3)
    if train:
        for i in range(8):
            side = np.random.randint(2)