# --- Model construction ---------------------------------------------------
# The network is initialised from the model description parsed out of
# `param_file`.  Its input shape is laid out as
# (image_color, image_size, image_size, batch_size).
# NOTE(review): `batch_size` is not assigned in this section; presumably it is
# defined earlier in the full script.
param_file = '/home/sainbar/fastnet-confussion-layer/config/cifar-10-18pct-confussion11x22.cfg'
learning_rate = 1
image_color, image_size = 3, 32
image_shape = (image_color,) + (image_size,) * 2 + (batch_size,)
init_model = parser.parse_config_file(param_file)
net = fastnet.net.FastNet(learning_rate, image_shape, init_model)

# --- Data preparation, confusion-layer label scheme -------------------------
# Clean CIFAR-10 samples keep their loaded labels, background-noise images all
# get label 10, and noisy-labelled images are shifted by +11 (presumably into
# 11-21 if the loaded labels are 0-9 -- TODO confirm against data_loader).
# NOTE(review): every name assigned in this section (train_data, train_labels,
# test_data, test_labels, data_mean, noisy_*, back_*) is reassigned by the
# data-preparation section that follows it in this file, which reloads the
# data and uses a different label scheme.  As written, this section only adds
# loading time; confirm which scheme is intended and drop the other.
train_data, train_labels, test_data, test_labels = data_loader.load_cifar10()
# Samples are columns (labels are indexed along axis 1 of the data), so this
# is a per-row mean over the training set; the same mean is subtracted from
# every dataset below.
data_mean = train_data.mean(axis=1, keepdims=True)
train_data = train_data - data_mean
test_data = test_data - data_mean

# noisy data: shift labels by 11.  Build a new array instead of `+= 11` so the
# array handed back by the loader is never modified in place.
noisy_data, noisy_labels = data_loader.load_noisy_labeled()
noisy_data = noisy_data - data_mean
noisy_labels = noisy_labels + 11

# background noise: one shared label, 10, for every background image
back_data = data_loader.load_noise()
back_data = back_data - data_mean
back_labels = np.full(back_data.shape[1], 10.0)

# mix the first pure_sz clean, noisy_sz noisy and back_sz background samples;
# `noisy_sz` is the size name the rest of the script uses for the noisy set.
train_data = np.concatenate(
    (train_data[:, 0:pure_sz], noisy_data[:, 0:noisy_sz],
     back_data[:, 0:back_sz]),
    axis=1)
train_labels = np.concatenate(
    (train_labels[0:pure_sz], noisy_labels[0:noisy_sz],
     back_labels[0:back_sz]))
# --- Data preparation, round-robin background labels ------------------------
# Clean and noisy samples keep their loaded labels; background-noise images are
# labelled round-robin over 0-9.  Note this reassigns test_data, test_labels,
# data_mean, noisy_*, back_*, train_data and train_labels if they were set
# earlier in the script.
clean_data, clean_labels, test_data, test_labels = data_loader.load_cifar10()
data_mean = clean_data.mean(axis=1, keepdims=True)
clean_data = clean_data - data_mean
test_data = test_data - data_mean

# background noise: image i gets label i % 10 (deterministic, easy to
# reproduce).  Label every background image, not just the first back_sz, so
# the slice past back_sz that the validation split draws from follows the same
# scheme instead of holding np.ones' placeholder 1.0.  float64 matches the
# dtype of the previous np.ones-based array.
back_data = data_loader.load_noise()
back_data = back_data - data_mean
back_labels = (np.arange(back_data.shape[1]) % 10).astype(np.float64)

# noisy data
noisy_data, noisy_labels = data_loader.load_noisy_labeled()
noisy_data = noisy_data - data_mean

# mix data: first pure_sz clean, noisy_sz noisy and back_sz background samples
train_data = np.concatenate(
    (clean_data[:, 0:pure_sz], noisy_data[:, 0:noisy_sz],
     back_data[:, 0:back_sz]),
    axis=1)
train_labels = np.concatenate(
    (clean_labels[0:pure_sz], noisy_labels[0:noisy_sz],
     back_labels[0:back_sz]))

# --- Validation split -------------------------------------------------------
# Take val_sz samples from just past each source's training slice, divided
# between clean / noisy / background in proportion to pure_sz / noisy_sz /
# back_sz.  Each share is rounded down, so the validation set can hold slightly
# fewer than val_sz samples.  With val_sz = 0 every slice is empty.
val_sz = 0
total_sz = pure_sz + noisy_sz + back_sz
if total_sz <= 0:
    raise ValueError('pure_sz + noisy_sz + back_sz must be positive, got %r'
                     % (total_sz,))
# Integer floor division: same result as int(1. * a * b / c) for non-negative
# sizes, without the round trip through float.
pure_sz2 = pure_sz + val_sz * pure_sz // total_sz
noisy_sz2 = noisy_sz + val_sz * noisy_sz // total_sz
back_sz2 = back_sz + val_sz * back_sz // total_sz


def _require_fits(name, end, available):
    """Raise ValueError if slicing up to `end` would run past `available`.

    numpy slicing silently truncates out-of-range ends, so without this check
    the validation set would quietly shrink.  An explicit raise is used instead
    of `assert`, which is removed when Python runs with -O.
    """
    if end > available:
        raise ValueError('%s=%d exceeds the %d available samples'
                         % (name, end, available))


_require_fits('pure_sz2', pure_sz2, clean_data.shape[1])
_require_fits('noisy_sz2', noisy_sz2, noisy_data.shape[1])
_require_fits('back_sz2', back_sz2, back_data.shape[1])

val_data = np.concatenate(
    (clean_data[:, pure_sz:pure_sz2], noisy_data[:, noisy_sz:noisy_sz2],
     back_data[:, back_sz:back_sz2]),
    axis=1)
val_labels = np.concatenate(
    (clean_labels[pure_sz:pure_sz2], noisy_labels[noisy_sz:noisy_sz2],
     back_labels[back_sz:back_sz2]))