def start(p=12, num_clusters=100, max_iter=50, batch_size=500, init='random'):
    """Fit mini-batch k-means on CIFAR image patches and save diagnostic plots.

    Loads the CIFAR data through ``DataLoader`` and cuts ``p`` x ``p`` patches
    with ``MiniBatchKMeans.generate_patches``. It then fits the centroids and
    writes two files into the directory returned by
    ``check_create_observations_dir()``:

    * ``repFields.png``: every centroid drawn as a small image.
    * ``clusterCount.png``: the per-cluster counts, drawn by ``display_bar``.

    Args:
        p: side length of the square patches.
        num_clusters: number of k-means centroids.
        max_iter, batch_size, init: forwarded positionally, together with
            ``num_clusters``, to the ``MiniBatchKMeans`` constructor.
    """
    import os

    data_loader = DataLoader()
    cifar_data = data_loader.load_cifar_data()
    images = cifar_data['data'].reshape((-1, 3, 32, 32)).astype('float32')
    # Move the channel axis last, giving (N, 32, 32, 3), then reverse the
    # channel order.
    # NOTE(review): presumably an RGB<->BGR swap; confirm this is intended.
    images = np.rollaxis(images, 1, 4)
    images = images[:, :, :, ::-1]

    patch_size = [p, p]
    kmeans = MiniBatchKMeans(num_clusters, max_iter, batch_size, init)
    patches_img = kmeans.generate_patches(images, patch_size)
    # Flatten every patch into one row: (number_of_patches, features).
    patches = patches_img.reshape(patches_img.shape[0], -1)

    centers, counts = kmeans.fit(patches)

    fig = _plot_receptive_fields(centers, kmeans.num_clusters, patch_size)
    directory = check_create_observations_dir()
    try:
        fig.savefig(os.path.join(directory, 'repFields.png'))
    finally:
        # Release the figure. Otherwise every call leaks one into pyplot, and
        # it would remain the "current" figure while display_bar runs.
        plt.close(fig)

    display_bar(counts, os.path.join(directory, 'clusterCount.png'))

    print("THE END")


def _plot_receptive_fields(centers, num_clusters, patch_size):
    """Draw centers[:, i] for each cluster i on a square subplot grid; return the figure."""
    fig = plt.figure()
    # add_subplot requires integer grid dimensions, but np.ceil returns a float.
    grid_side = int(np.ceil(np.sqrt(num_clusters)))
    for i in range(num_clusters):
        ax = fig.add_subplot(grid_side, grid_side, i + 1)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        # Each column of `centers` is treated as one centroid, as before.
        ax.imshow(display(centers[:, i], patch_size), interpolation='none')
    return fig
def fit_cifar():
    """Save a pairwise scatter plot of the CIFAR data to the "PCA" directory.

    Loads ``data`` and ``labels`` through ``DataLoader`` and builds the plot
    with ``get_pairwise_plot``. The result is saved as
    ``scatterplotCIFAR.png`` inside ``check_create_observations_dir("PCA")``.
    """
    print("Computing for CIFAR 10")
    cifar_data = DataLoader().load_cifar_data()
    train_set_x = cifar_data['data']
    train_set_y = cifar_data['labels']

    # Named pair_plot rather than plt so the module-level pyplot name is not
    # shadowed inside this function.
    pair_plot = get_pairwise_plot(train_set_x, train_set_y)

    obs_dir = check_create_observations_dir("PCA")
    target_path = os.path.join(obs_dir, "scatterplotCIFAR.png")
    pair_plot.savefig(target_path)
    print("THE END")

def display(data_row, patch_size):
    """Turn a flattened 3-channel patch (or centroid) into an image array.

    The vector is min-max rescaled into [0, 1] and reshaped to
    ``(3, patch_size[0], patch_size[1])``. The channel axis is then moved
    last, giving a layout suitable for ``imshow``.

    Args:
        data_row: 1-D numeric array with ``3 * patch_size[0] * patch_size[1]``
            elements.
        patch_size: ``(rows, cols)`` of the patch.

    Returns:
        float32 array of shape ``(rows, cols, 3)`` with values in [0, 1].
        A constant input yields an all-zero image instead of NaNs.
    """
    # NOTE(review): the vector is read as channel-first (C, H, W). If patches
    # were flattened from channel-last images, the channels are scrambled
    # here; confirm against kmeans.generate_patches.
    # Convert to float first so integer input does not floor-divide (Python 2).
    data_row = np.asarray(data_row, dtype=float)
    data_row = data_row - data_row.min()
    peak = data_row.max()
    if peak > 0:
        # Guard: a constant row would otherwise compute 0/0 and produce NaNs.
        data_row = data_row / peak
    img = data_row.reshape(3, patch_size[0], patch_size[1]).astype("float32")
    return np.rollaxis(img, 0, 3)


if __name__ == "__main__":
    data_loader = DataLoader()
    cifar_data = data_loader.load_cifar_data()
    images = cifar_data["data"].reshape((-1, 3, 32, 32)).astype("float32")
    #     img_test = images[2,:,:,:]
    #     img_test = np.rollaxis(img_test,0,3)
    #     img_test = img_test[:,:,::-1]
    #     plt.imshow(img_test)
    #     plt.show()

    images = np.rollaxis(images, 1, 4)
    images = images[:, :, :, ::-1]

    num_patches = images.shape[0]
    patch_size = [12, 12]
    #
    kmeans = KMeans()