import numpy as np
import tensorflow as tf

import baxter_writer as bw
import dataset
import utils
import vae_assoc

# Seed both numpy and TensorFlow RNGs so data shuffling and weight
# initialization are reproducible across runs.
np.random.seed(0)
tf.set_random_seed(0)  # TF1-style API; TF2 equivalent is tf.random.set_seed

# Parenthesized single-argument print is valid and identical in Python 2 and 3.
print('Loading image data...')
img_data = utils.extract_images(fname='bin/img_data_extend.pkl', only_digits=False)
# img_data = utils.extract_images(fname='bin/img_data.pkl', only_digits=False)
# img_data_sets = dataset.construct_datasets(img_data)
print('Loading joint motion data...')
fa_data, fa_mean, fa_std = utils.extract_jnt_fa_parms(fname='bin/jnt_ik_fa_data_extend.pkl', only_digits=False)
# fa_data, fa_mean, fa_std = utils.extract_jnt_fa_parms(fname='bin/jnt_fa_data.pkl', only_digits=False)
# Standardize the joint-motion features: zero mean, unit variance per dimension.
fa_data_normed = (fa_data - fa_mean) / fa_std

# fa_data_sets = dataset.construct_datasets(fa_data_normed)
print('Constructing dataset...')
# Pair each image with its normalized joint-motion features by concatenating
# along the feature axis (axis=1); rows stay aligned per sample.
aug_data = np.concatenate((img_data, fa_data_normed), axis=1)

# 80/10/10 train/validation/test split (ratios taken from the remaining data).
data_sets = dataset.construct_datasets(aug_data, validation_ratio=.1, test_ratio=.1)
print('Start training...')
# Hyperparameter grids for the training sweep that follows.
batch_sizes = [64]
# n_z_array = [3, 5, 10, 20]
n_z_array = [4]  # latent-space dimensionality of the VAE
# assoc_lambda_array = [1, 3, 5, 10]