import os
import sys
import numpy as np

# tag the checkpoint with the weight-decay multiplier and scale each layer's weight cost
net.checkpoint_name += '_wdX' + str(args.wdecayX)
for l in net.layers:
	if hasattr(l, 'wc'):
		l.wc *= args.wdecayX

# expand '~' so the path is usable, and create missing parent directories
net.output_dir = os.path.expanduser('~/data/outside-noise-results/results_BU_robust/') + net.checkpoint_name + '/'
if not os.path.exists(net.output_dir):
	os.makedirs(net.output_dir)

# prepare data
clean_data, clean_labels, test_data, test_labels = data_loader.load_cifar10()
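# images are stored as (pixels, examples); center both splits with the per-pixel training mean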
data_mean = clean_data.mean(axis=1,keepdims=True)
clean_data = clean_data - data_mean
test_data = test_data - data_mean

# background noise
back_data = data_loader.load_noise()
back_data = back_data - data_mean
back_labels = np.ones(back_data.shape[1])
# cycle labels 0-9 over the background images (deterministic, easy to reproduce)
back_labels[:back_sz] = np.arange(back_sz) % 10

# noisy data
noisy_data, noisy_labels = data_loader.load_noisy_labeled()
noisy_data = noisy_data - data_mean

# mix data
train_data = np.concatenate((clean_data[:,0:pure_sz], noisy_data[:,0:noisy_sz], back_data[:,0:back_sz]), axis=1)
train_labels = np.concatenate((clean_labels[0:pure_sz], noisy_labels[0:noisy_sz], back_labels[0:back_sz]))
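# mixed training set: pure_sz clean images, noisy_sz noisily-labeled images, and
# back_sz background images; in this setup everything stays within the 10 CIFAR labels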

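# a separate setup follows: rebuild the net, then give noisy and background images label classes of their own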
# optional clean validation split (disabled here: with val_sz = 0, pure_sz2 == pure_sz)
val_sz = 0
pure_sz2 = pure_sz + int(1. * val_sz * pure_sz / (pure_sz + noisy_sz + back_sz))

# build the network from the layer config file
init_model = parser.parse_config_file(param_file)
net = fastnet.net.FastNet(learning_rate, image_shape, init_model)

# prepare data
train_data, train_labels, test_data, test_labels = data_loader.load_cifar10()
data_mean = train_data.mean(axis=1, keepdims=True)
train_data = train_data - data_mean
test_data = test_data - data_mean

# noisy data
noisy_data, noisy_labels = data_loader.load_noisy_labeled()
noisy_data = noisy_data - data_mean
noisy_labels += 11  # shift the noisy labels out of the clean 0-9 range

# background noise
back_data = data_loader.load_noise()
back_data = back_data - data_mean
back_labels = np.ones(back_data.shape[1]) * 10  # extra class 10 for background noise

train_data = np.concatenate(
    (train_data[:, 0:pure_sz], noisy_data[:, 0:noise_sz],
     back_data[:, 0:back_sz]),
    axis=1)
train_labels = np.concatenate(
    (train_labels[0:pure_sz], noisy_labels[0:noise_sz],
     back_labels[0:back_sz]))
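# label space for this run: 0-9 clean CIFAR classes, 10 background noise, and
# 11-20 the noisily-labeled copies (assuming the loaded noisy labels start at 0-9)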

# shuffle data with one permutation so images and labels stay aligned
order = np.random.permutation(pure_sz + back_sz + noise_sz)
train_data = train_data[:, order]
train_labels = train_labels[order]
# initialize the second-to-last layer (presumably the label-noise model) near identity
l = net.layers[-2]
w = np.eye(11)
w[:, 10] = 0.05  # small mass from every class into the extra noise class
w[10, 10] = 0.5  # the noise class keeps half of its own mass
#w = np.concatenate((w, 0.1*np.ones((10,1))), axis=1)
l.weight = data_loader.copy_to_gpu(w)
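# Rough intuition, assuming fastnet applies this matrix linearly to the 11-way
# class scores: the near-identity start takes the observed labels at face value,
# with only a small prior for confusion with the noise class, and training can
# then adapt w toward the actual noise distribution.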


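# a third setup follows: clean data mixed with pure noise images, half of them given random labels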
# prepare data
data, labels, test_data, test_labels = data_loader.load_cifar10()
data_mean = data.mean(axis=1,keepdims=True)
data = data - data_mean
test_data = test_data - data_mean

# pure-noise images, all labeled with the extra class 10 initially
noisy_data = data_loader.load_noise()
noisy_data = noisy_data - data_mean
noisy_labels = np.ones(noisy_data.shape[1]) * 10

# subset sizes from the command line
pure_sz = int(sys.argv[1])
noise_sz = int(sys.argv[2])

# give half of the noise images random CIFAR labels; the rest keep label 10
for i in range(noise_sz // 2):
	noisy_labels[i] = np.random.randint(10)

# mix data
train_data = np.concatenate((data[:,0:pure_sz], noisy_data[:,0:noise_sz]), axis=1)
train_labels = np.concatenate((labels[0:pure_sz], noisy_labels[0:noise_sz]))

# shuffle, keeping images and labels aligned
order = np.random.permutation(train_data.shape[1])
train_data = train_data[:, order]
train_labels = train_labels[order]