def predict(model, filename, weights_path):
    """
    Predict the label for a single file with MFCC features.

    Args:
    model:          Keras model
    filename:       filepath to the file which should be processed
    weights_path:   filepath to the weights of the model

    Returns:
    The predicted label.
    """
    # Load weights only if the checkpoint exists; otherwise fall back to
    # whatever weights the model currently holds.
    if os.path.isfile(weights_path):
        model.load_weights(weights_path, by_name=False)

    sample = data_wrangler.single_data_processing(filename)
    pred = model.predict(sample)
    print(pred)
    # Decode the one-hot prediction once and reuse it (was decoded twice,
    # once for the print and once for the return).
    label = data_read.label_from_oh(pred)
    print(label)
    return label
def predict_preprocessed(weights_path, file_paths, output_filename):
    """
    Predict labels for preprocessed feature files and save them to a CSV file.

    Uses the module-level `model`.

    Args:
    weights_path:       filepath to the weights for the model
    file_paths:         filepaths of the preprocessed feature files
    output_filename:    name of the output CSV file

    Returns:
    -
    """
    if os.path.isfile(weights_path):
        print("Loading weights")
        model.load_weights(weights_path, by_name=False)

    batch_size = 256
    with open(output_filename, "w") as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=",", dialect="unix",
                                quoting=csv.QUOTE_NONE)

        iteration = 0
        max_iterations = math.ceil(len(file_paths) / batch_size)

        csv_writer.writerow(["fname", "label"])

        while len(file_paths) > 0:
            batch_filenames = file_paths[:batch_size]
            # Each preprocessed file holds a flattened (101, 128) feature
            # matrix written by np.savetxt — TODO confirm against the writer.
            batch = np.asarray([np.reshape(np.loadtxt(fn), (101, 128))
                                for fn in batch_filenames])

            pred = model.predict(batch, batch_size=batch_size)
            # Pair each prediction with the filename it came from. The old
            # code zipped against the full remaining `file_paths` list and
            # only worked because zip() truncates to the shorter argument.
            for fn, prediction in zip(batch_filenames, pred):
                clean_filename = fn.split("/")[-1].split(".")[0] + ".wav"
                csv_writer.writerow([clean_filename,
                                     data_read.label_from_oh(prediction)])

            file_paths = file_paths[len(batch_filenames):]
            iteration += 1
            print("%s/%s done." % (iteration, max_iterations))
def predict_single_unpreprocessed(weights_path, file, model):
    """
    Predict the label for a single raw '.wav' file and print it.

    Args:
    weights_path:   filepath to the weights for the model
    file:           filepath of the raw '.wav' file to predict a label for
    model:          Keras model

    Returns:
    The predicted label.
    """
    if os.path.isfile(weights_path):
        model.load_weights(weights_path, by_name=False)

    # single_data_processing_raw is expected to yield MFCC features that
    # reshape to (101, 13) — a single-sample batch of shape (1, 101, 13).
    batch = np.reshape(
        np.array([data_wrangler.single_data_processing_raw(file)]),
        (1, 101, 13))

    pred = model.predict(batch, batch_size=256)
    label = data_read.label_from_oh(pred)
    print(label)
    return label
def predict_unpreprocessed(weights_path, file_paths, output_filename):
    """
    Predict labels for raw '.wav' files and save them to a CSV file.

    Uses the module-level `model`.

    Args:
    weights_path:       filepath to the weights for the model
    file_paths:         filepaths of the '.wav' files to predict labels for
    output_filename:    name of the output CSV file

    Returns:
    -
    """
    if os.path.isfile(weights_path):
        print("Loading weights")
        model.load_weights(weights_path, by_name=False)

    batch_size = 256
    with open(output_filename, "w") as csv_file:
        csv_writer = csv.writer(csv_file, delimiter=",", dialect="unix",
                                quoting=csv.QUOTE_NONE)
        # NOTE(review): unlike predict_preprocessed, no ["fname","label"]
        # header row is written here — confirm whether the consumer of this
        # CSV expects one before adding it.

        iteration = 0
        max_iterations = math.ceil(len(file_paths) / batch_size)

        while len(file_paths) > 0:
            batch_filenames = file_paths[:batch_size]

            # Each file is processed to features that reshape to (101, 13).
            batch = np.reshape(
                np.array([data_wrangler.single_data_processing(fn)
                          for fn in batch_filenames]),
                (len(batch_filenames), 101, 13))

            pred = model.predict(batch, batch_size=batch_size)
            for fn, prediction in zip(batch_filenames, pred):
                clean_filename = fn.split("/")[-1]
                csv_writer.writerow([clean_filename,
                                     data_read.label_from_oh(prediction)])

            # len(batch) is always <= batch_size, so the old
            # min(len(batch), 256) was redundant.
            file_paths = file_paths[len(batch_filenames):]
            iteration += 1
            sys.stdout.write('\r>> Predicting classes for batch %d of %d.'
                             % (iteration, max_iterations))
            sys.stdout.flush()
        # Terminate the in-place progress line so later output starts clean.
        sys.stdout.write('\n')
    # NOTE(review): orphaned fragment. These space-indented statements sit at
    # module level after predict_unpreprocessed() (which is tab-indented) and
    # reference `filename` and `model`, neither of which is defined in this
    # scope. Presumably the `def` header of a mel-spectrogram variant of
    # predict() was lost in a merge — as written this is an indentation /
    # NameError failure at import time. Confirm and either restore the missing
    # function header or remove this fragment.
    sample = data_wrangler.single_data_processing_with_mel_spectrogram(
        filename)
    pred = model.predict(sample)
    print(pred)
    print(data_read.label_from_oh(pred))


if __name__ == "__main__":

    model = my_model()
    x_test, y_test, x_train, y_train = load_dataset()

    d = {}
    for l in y_train:
        d[data_read.label_from_oh(l)] = d.get(data_read.label_from_oh(l),
                                              0) + 1
    print(d)

    log_path = "models/v54"
    weights_path = log_path + "/weights.best.hdf5"

    learning_rate = 0.0001
    train_model(model, 1000, x_test, y_test, x_train, y_train, weights_path,
                log_path, learning_rate)

    if os.path.isfile(weights_path):
        model.load_weights(weights_path, by_name=False)
    pred = model.predict(x_test, batch_size=128)
    x = 0
    for p_i, l_i in zip(pred, [data_read.label_from_oh(l) for l in y_test]):