def find_image_mean(image_path, dataset_path = '/raid/mainak/datasets/nabirds/'):
    """Compute the element-wise mean image over the training split.

    Every training image is loaded with ``cv2.IMREAD_COLOR``, resized to
    ``IM_ORIG_SIZE``, converted to float64 and scaled by 1/255; the result is
    the element-wise mean of those images. No cropping is applied.

    Args:
        image_path: Prefix forwarded to ``load_image_paths`` as
            ``path_prefix`` to build each image's on-disk location.
        dataset_path: Root directory of the dataset metadata files.

    Returns:
        A float64 ndarray with the shape ``cv2.resize(..., IM_ORIG_SIZE)``
        produces for a colour image.

    Raises:
        ValueError: If the training split is empty.
        IOError: If an image cannot be read by OpenCV.
    """
    image_paths = load_image_paths(dataset_path, path_prefix=image_path)
    train_images, _ = load_train_test_split(dataset_path)

    # Accumulate a plain float64 sum and divide once at the end. The resized
    # pixels are integer-valued, so the sum is exact; this avoids the rounding
    # drift of the (count*mean + x)/(count+1) running-mean recurrence.
    im_sum = None
    count = 0
    for image_id in train_images:
        path = image_paths[image_id]
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError("Could not read image: %s" % path)
        im_resized = np.asarray(cv2.resize(image, IM_ORIG_SIZE), dtype=np.float64)
        if im_sum is None:
            # asarray(..., float64) on cv2 output is already a fresh array.
            im_sum = im_resized
        else:
            im_sum += im_resized
        count += 1

    if count == 0:
        raise ValueError("Training split is empty; cannot compute a mean image")
    # The division by 255 scales pixels into [0, 1], as the original per-image step did.
    return im_sum / (255.0 * count)
def load_image_batch(im_mean, start_idx, batch_size, image_path, im_size, dataset_path = '/raid/mainak/datasets/nabirds/', is_train=1):
    """Load one preprocessed mini-batch of images and their class labels.

    Args:
        im_mean: Mean image, forwarded unchanged to ``preprocessImage``.
        start_idx: Offset of the batch's first image within the chosen split.
            When 0, the split list is shuffled in place before slicing.
        batch_size: Maximum number of images returned (fewer at the end of
            the split, possibly zero).
        image_path: Prefix forwarded to ``load_image_paths`` as
            ``path_prefix``.
        im_size: Target size, forwarded to ``preprocessImage``.
        dataset_path: Root directory of the dataset metadata files.
        is_train: ``1`` selects the training split, any other value the test
            split. Also forwarded to ``preprocessImage``.

    Returns:
        ``(batch_x, batch_y)``: ``np.asarray`` of the preprocessed images and
        ``np.asarray`` of their class labels, in batch order.

    Raises:
        IOError: If an image cannot be read by OpenCV.
    """
    image_paths = load_image_paths(dataset_path, path_prefix=image_path)
    image_class_labels = load_image_labels(dataset_path)
    train_images, test_images = load_train_test_split(dataset_path)

    image_set = train_images if is_train == 1 else test_images

    # NOTE(review): the split is reloaded on every call, so unless
    # load_train_test_split returns a cached list this shuffle only affects
    # the batch at start_idx == 0; later batches are then sliced from the
    # unshuffled order. Confirm against load_train_test_split.
    if start_idx == 0:
        random.shuffle(image_set)
    batch_images = image_set[start_idx:start_idx + batch_size]

    out_images = []
    image_labels = []
    for image_id in batch_images:
        path = image_paths[image_id]
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise IOError("Could not read image: %s" % path)
        out_images.append(preprocessImage(image, im_mean, im_size, is_train))
        image_labels.append(image_class_labels[image_id])

    # Stack once, after the loop. The original stack sat inside the loop via
    # tab indentation. The original comment describes the layout as
    # n_images x n_row x n_col x n_ch, presumably matching preprocessImage's
    # output; confirm there.
    batch_x = np.asarray(out_images)
    batch_y = np.asarray(image_labels)
    return batch_x, batch_y


def accuracy(predictions, labels):
    """Return the percentage of rows whose argmax matches between two arrays.

    Restored from unreachable code that previously trailed
    ``load_image_batch``'s ``return``. Each row is flattened before taking
    its argmax, matching the per-row ``np.argmax(predictions[i])`` of that code.

    Args:
        predictions: Array whose first axis indexes samples.
        labels: Array with the same number of rows as ``predictions``.

    Returns:
        Accuracy in percent, in ``[0, 100]``.

    Raises:
        ValueError: If there are no rows, or the row counts differ.
    """
    n = predictions.shape[0]
    if n == 0:
        raise ValueError("accuracy() needs at least one prediction")
    if labels.shape[0] != n:
        raise ValueError("predictions and labels must have the same number of rows")
    pred_cls = np.argmax(np.reshape(predictions, (n, -1)), axis=1)
    true_cls = np.argmax(np.reshape(labels, (n, -1)), axis=1)
    return 100.0 * np.sum(pred_cls == true_cls) / n

if __name__ == "__main__":
    # Script entry point: parse CLI args, load dataset metadata and build the
    # VGG16 graph.
    args = parse_args()
    datapath = Path(args.dataset_path)
    samplepath = Path(args.sample_path)
    summaries_dir = "summaries/"
    num_patches = 10
    num_epochs = 10
    num_classes = 555
    savepath = Path(args.save_path)
    # NOTE(review): these loaders are called without args.dataset_path, so
    # they use their own defaults rather than the CLI path; confirm intended.
    train_images, test_images = load_train_test_split()
    image_paths = load_image_paths()
    image_labels = load_image_labels()
    # Limit this process to 40% of GPU memory.
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
    batch_size = 30

    # gpu_options must be passed to ConfigProto; previously it was built but
    # never used, so the 0.4 memory cap had no effect.
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        vgg = vgg16.Vgg16(num_classes, 80)
        images = tf.placeholder("float", [batch_size, 224, 224, 3])
        with tf.name_scope("content_vgg"):
            vgg.build(images)
        conv_1 = vgg.conv1_1
        conv_6 = vgg.conv3_2
        conv_9 = vgg.conv4_2
        conv_12 = vgg.conv5_2
        batch_cnt = tf.Variable(0, trainable=False)
def get_data_length(dataset_path = '/raid/mainak/datasets/nabirds/'):
    """Return the sizes of the train and test splits.

    Args:
        dataset_path: Root directory of the dataset metadata files, forwarded
            to ``load_train_test_split``.

    Returns:
        A ``(num_train, num_test)`` tuple of ints.
    """
    train_ids, test_ids = load_train_test_split(dataset_path)
    return tuple(len(split) for split in (train_ids, test_ids))