# Example 1
def train(model, train_splits_batches, valid_splits_batches, test_splits_batches, normalizer,
		  model_params, parameters, config_folder, start_time_train, logger_train):
	"""Train a TensorFlow orchestration model and return its training curves.

	Parameters
	----------
	model : object
		Model instance; must expose trainer(), optimize(), init_weights()
		and is_keras().
	train_splits_batches, valid_splits_batches, test_splits_batches : list of dict
		Data splits; each element holds 'chunks_folders' (paths to matrix
		chunks on disk) and 'batches' (lists of batch indices).
	normalizer : object
		Input normalizer forwarded to the asynchronous matrix loader.
	model_params : dict
		Model hyper-parameters forwarded to the trainer.
	parameters : dict
		Global configuration (walltime, max_iter, debug flags, ...).
	config_folder : str
		Folder where the model checkpoints, measures and summaries are saved.
	start_time_train : float
		time.time() value taken at the start of the whole run (walltime check).
	logger_train : logging.Logger
		Logger used for progress messages.

	Returns
	-------
	tuple
		(valid_tabs, test_tab, best_epoch,
		 valid_tabs_sampled, test_tab_sampled, best_epoch_sampled,
		 valid_tabs_LR, test_tab_LR, best_epoch_LR)
		i.e. per-epoch validation curves, final test scores and best-epoch
		indices for the short-term, sampled and long-range measures.
	"""
	DEBUG = parameters["debug"]
	SUMMARIZE = DEBUG["summarize"]

	# Build DEBUG dict: replace the boolean flag by the destination folder
	if DEBUG["save_measures"]:
		DEBUG["save_measures"] = config_folder + "/save_measures"

	# Time budget: walltime minus a 30-minute safety margin, in seconds
	time_limit = parameters['walltime'] * 3600 - 30*60

	# Reset graph before starting training
	tf.reset_default_graph()

	which_trainer = model.trainer()

	# Save the trainer name; it is needed again at generation time
	with open(os.path.join(config_folder, 'which_trainer'), 'w') as ff:
		ff.write(which_trainer)
	trainer = import_trainer(which_trainer, model_params, parameters)
	# Dumb baselines (random, repeat) return False here and skip optimization
	model_optimize = model.optimize()

	############################################################
	# Display information about the model
	num_parameters = model_statistics.count_parameters(tf.get_default_graph())
	logger_train.info('** Num trainable parameters :  {}'.format(num_parameters))
	with open(os.path.join(config_folder, 'num_parameters.txt'), 'w') as ff:
		ff.write("{:d}".format(num_parameters))

	############################################################
	# Training
	logger_train.info("#" * 60)
	logger_train.info("#### Training")
	epoch = 0
	OVERFITTING = False
	TIME_LIMIT = False

	# All curves are pre-allocated for the maximum number of epochs and
	# trimmed with remove_tail_training_curves before being returned
	n_iter = max(1, parameters['max_iter'])

	# Train error
	loss_tab = np.zeros(n_iter)

	# Select criteria
	overfitting_measure = parameters["overfitting_measure"]
	save_measures = parameters['save_measures']

	measure_names = ['loss', 'accuracy', 'precision', 'recall',
					 'true_accuracy', 'f_score', 'Xent']

	# Short-term validation error and best epoch for each measure
	valid_tabs = {name: np.zeros(n_iter) for name in measure_names}
	best_epoch = {name: 0 for name in measure_names}

	# Sampled-predictions measures.  NOTE: best_epoch_sampled was referenced
	# in the early-return path below but never defined (NameError for
	# non-optimized models) -- it is now initialized like the others.
	valid_tabs_sampled = {name: np.zeros(n_iter) for name in measure_names}
	best_epoch_sampled = {name: 0 for name in measure_names}

	# Long-range validation error and best epoch for each measure
	valid_tabs_LR = {name: np.zeros(n_iter) for name in measure_names}
	best_epoch_LR = {name: 0 for name in measure_names}

	if parameters['memory_gpu']:
		# This apparently does not work
		configSession = tf.ConfigProto()
		configSession.gpu_options.per_process_gpu_memory_fraction = parameters['memory_gpu']
	else:
		configSession = None

	with tf.Session(config=configSession) as sess:

		# Only for models with shared weights
		model.init_weights()

		##############################
		# Create placeholders and nodes, or reload a pretrained graph
		if parameters['pretrained_model'] is None:
			logger_train.info('#### Graph')
			start_time_building_graph = time.time()
			trainer.build_variables_nodes(model, parameters)
			trainer.build_preds_nodes(model)
			trainer.build_loss_nodes(model, parameters)
			trainer.build_train_step_node(model, config.optimizer())
			trainer.save_nodes(model)
			time_building_graph = time.time() - start_time_building_graph
			logger_train.info("TTT : Building the graph took {0:.2f}s".format(time_building_graph))
		else:
			logger_train.info('#### Graph')
			start_time_building_graph = time.time()
			trainer.load_pretrained_model(parameters['pretrained_model'])
			time_building_graph = time.time() - start_time_building_graph
			logger_train.info("TTT : Loading pretrained model took {0:.2f}s".format(time_building_graph))

		if SUMMARIZE:
			tf.summary.scalar('loss', trainer.loss)
		##############################

		summarize_dict = {}
		if SUMMARIZE:
			merged_node = tf.summary.merge_all()
			train_writer = tf.summary.FileWriter(config_folder + '/summary', sess.graph)
			train_writer.add_graph(tf.get_default_graph())
		else:
			merged_node = None
		summarize_dict['bool'] = SUMMARIZE
		summarize_dict['merged_node'] = merged_node

		if model.is_keras():
			K.set_session(sess)

		# Initialize weights: restore a checkpoint or start fresh
		if parameters['pretrained_model']:
			trainer.saver.restore(sess, parameters['pretrained_model'] + '/model')
		else:
			sess.run(tf.global_variables_initializer())

		N_matrix_files = len(train_splits_batches)

		#######################################
		# Load first matrix
		#######################################
		load_data_start = time.time()
		pool = ThreadPool(processes=1)
		async_train = pool.apply_async(async_load_mat, (normalizer, train_splits_batches[0]['chunks_folders'], parameters))
		matrices_from_thread = async_train.get()
		load_data_time = time.time() - load_data_start
		logger_train.info("Load the first matrix time : " + str(load_data_time))

		# Dumb baseline models (random, repeat) need no training step:
		# evaluate once on the test set and return immediately
		if not model_optimize:
			# WARNING : first test matrix is not necessarily the same as the first train matrix
			async_test = pool.apply_async(async_load_mat, (normalizer, test_splits_batches[0]['chunks_folders'], parameters))
			init_matrices_test = async_test.get()
			test_results, test_results_sampled, test_long_range_results, _, _ = validate(trainer, sess,
					init_matrices_test, test_splits_batches,
					normalizer, parameters,
					logger_train)
			training_utils.mean_and_store_results(test_results, valid_tabs, 0)
			training_utils.mean_and_store_results(test_results_sampled, valid_tabs_sampled, 0)
			training_utils.mean_and_store_results(test_long_range_results, valid_tabs_LR, 0)
			return training_utils.remove_tail_training_curves(valid_tabs, 1), {}, best_epoch, \
				training_utils.remove_tail_training_curves(valid_tabs_sampled, 1), {}, best_epoch_sampled, \
				training_utils.remove_tail_training_curves(valid_tabs_LR, 1), {}, best_epoch_LR

		# Training iterations: stop on overfitting, walltime or max_iter
		while (not OVERFITTING and not TIME_LIMIT
			   and epoch != parameters['max_iter']):

			start_time_epoch = time.time()

			trainer.DEBUG['epoch'] = epoch

			train_cost_epoch = []
			sparse_loss_epoch = []

			train_time = time.time()
			for file_ind_CURRENT in range(N_matrix_files):

				#######################################
				# We train on the current matrix while pre-loading the
				# next one in a worker thread
				#######################################
				train_index = train_splits_batches[file_ind_CURRENT]['batches']
				file_ind_NEXT = (file_ind_CURRENT + 1) % N_matrix_files
				next_chunks = train_splits_batches[file_ind_NEXT]['chunks_folders']
				async_train = pool.apply_async(async_load_mat, (normalizer, next_chunks, parameters))

				piano_input, orch_transformed, duration_piano, mask_orch = matrices_from_thread

				#######################################
				# Train over all batches of the current matrix
				#######################################
				for batch_index in train_index:
					loss_batch, _, debug_outputs, summary = trainer.training_step(sess, batch_index, piano_input, orch_transformed, duration_piano, mask_orch, summarize_dict)

					# Keep track of costs
					train_cost_epoch.append(loss_batch)
					sparse_loss_epoch.append(debug_outputs["sparse_loss_batch"])

				#######################################
				# Free the current matrices, then block until the
				# pre-loaded ones are available
				#######################################
				del matrices_from_thread
				matrices_from_thread = async_train.get()

			train_time = time.time() - train_time
			logger_train.info("Training time : {}".format(train_time))

			###
			# DEBUG
			if trainer.DEBUG["plot_weights"]:
				weight_folder = config_folder + "/weights"
				plot_weights.plot_weights(sess, weight_folder)
			#
			###

			# WARNING : first validation matrix is not necessarily the same as the first train matrix
			# So now that it's here, parallelization is absolutely useless....
			async_valid = pool.apply_async(async_load_mat, (normalizer, valid_splits_batches[0]['chunks_folders'], parameters))

			if SUMMARIZE:
				if (epoch < 5) or (epoch % 10 == 0):
					# Only the summary of the last batch of the epoch is
					# written here; per-batch summaries would have to be
					# added inside the batch loop
					train_writer.add_summary(summary, epoch)

			mean_loss = np.mean(train_cost_epoch)
			loss_tab[epoch] = mean_loss

			#######################################
			# Validate
			#######################################
			valid_time = time.time()
			init_matrices_validation = async_valid.get()

			# Create DEBUG folders; each flag is replaced by its destination path
			if trainer.DEBUG["plot_nade_ordering_preds"]:
				debug_folder = config_folder + "/DEBUG/preds_nade/" + str(epoch)
				trainer.DEBUG["plot_nade_ordering_preds"] = debug_folder
				os.makedirs(debug_folder)
			if trainer.DEBUG["save_accuracy_along_sampling"]:
				debug_folder = config_folder + "/DEBUG/accuracy_along_sampling/" + str(epoch)
				trainer.DEBUG["save_accuracy_along_sampling"] = debug_folder
				os.makedirs(debug_folder)
			if trainer.DEBUG["salience_embedding"]:
				debug_folder = config_folder + "/DEBUG/salience_embeddings/" + str(epoch)
				trainer.DEBUG["salience_embedding"] = debug_folder
				os.makedirs(debug_folder)

			valid_results, valid_results_sampled, valid_long_range_results, preds_val, truth_val = \
				validate(trainer, sess,
					init_matrices_validation, valid_splits_batches,
					normalizer, parameters,
					logger_train)
			valid_time = time.time() - valid_time
			logger_train.info("Validation time : {}".format(valid_time))

			training_utils.mean_and_store_results(valid_results, valid_tabs, epoch)
			training_utils.mean_and_store_results(valid_results_sampled, valid_tabs_sampled, epoch)
			training_utils.mean_and_store_results(valid_long_range_results, valid_tabs_LR, epoch)
			end_time_epoch = time.time()

			#######################################
			# Overfitting ?
			if epoch >= parameters['min_number_iteration']:
				# Choose short/long range and the measure
				OVERFITTING = early_stopping.up_criterion(valid_tabs[overfitting_measure], epoch, parameters["number_strips"], parameters["validation_order"])
				if not OVERFITTING:
					# Also check for NaN
					OVERFITTING = early_stopping.check_for_nan(valid_tabs, save_measures, max_nan=3)
			#######################################

			#######################################
			# Monitor time (guillimin walltime)
			if (time.time() - start_time_train) > time_limit:
				TIME_LIMIT = True
			#######################################

			#######################################
			# Log training
			#######################################
			logger_train.info("############################################################")
			logger_train.info('Epoch : {} , Training loss : {} , Validation loss : {} \n \
		Validation accuracy : {:.3f} ; {:.3f} \n \
		Precision : {:.3f} ; {:.3f} \n \
		Recall : {:.3f} ; {:.3f} \n \
		Xent : {:.6f} ; {:.3f} \n \
		Sparse_loss : {:.3f}'
							  .format(epoch, mean_loss,
								valid_tabs['loss'][epoch],
								valid_tabs['accuracy'][epoch], valid_tabs_sampled['accuracy'][epoch],
								valid_tabs['precision'][epoch], valid_tabs_sampled['precision'][epoch],
								valid_tabs['recall'][epoch],  valid_tabs_sampled['recall'][epoch],
								valid_tabs['Xent'][epoch], valid_tabs_sampled['Xent'][epoch],
								np.mean(sparse_loss_epoch)))

			logger_train.info('Time : {}'
							  .format(end_time_epoch - start_time_epoch))

			#######################################
			# Best model ?  (lower is better for every stored measure)
			start_time_saving = time.time()
			for measure_name, measure_curve in valid_tabs.items():
				best_measure_so_far = measure_curve[best_epoch[measure_name]]
				measure_for_this_epoch = measure_curve[epoch]
				if (measure_for_this_epoch <= best_measure_so_far) or (epoch == 0):
					if measure_name in save_measures:
						trainer.saver.save(sess, config_folder + "/model_" + measure_name + "/model")
					best_epoch[measure_name] = epoch

			#######################################
			# DEBUG : save numpy arrays of measure values.
			# Hoisted out of the loop above: it does not depend on
			# measure_name and was needlessly re-executed (rmtree +
			# makedirs + np.save) once per measure, shadowing the loop
			# variable along the way.
			if trainer.DEBUG["save_measures"]:
				if os.path.isdir(trainer.DEBUG["save_measures"]):
					shutil.rmtree(trainer.DEBUG["save_measures"])
				os.makedirs(trainer.DEBUG["save_measures"])
				for measure_name, measure_tab in valid_results.items():
					np.save(os.path.join(trainer.DEBUG["save_measures"], measure_name + '.npy'), measure_tab[:2000])
				np.save(os.path.join(trainer.DEBUG["save_measures"], 'preds.npy'), np.asarray(preds_val[:2000]))
				np.save(os.path.join(trainer.DEBUG["save_measures"], 'truth.npy'), np.asarray(truth_val[:2000]))
			#######################################

			end_time_saving = time.time()
			logger_train.info('Saving time : {:.3f}'.format(end_time_saving - start_time_saving))
			#######################################

			if OVERFITTING:
				logger_train.info('OVERFITTING !!')

			if TIME_LIMIT:
				logger_train.info('TIME OUT !!')

			#######################################
			# Epoch +1
			#######################################
			epoch += 1

		#######################################
		# Test
		#######################################
		test_time = time.time()
		async_test = pool.apply_async(async_load_mat, (normalizer, test_splits_batches[0]['chunks_folders'], parameters))
		init_matrices_test = async_test.get()
		test_results, test_results_sampled, test_long_range_results, preds_test, truth_test = \
			validate(trainer, sess,
				init_matrices_test, test_splits_batches,
				normalizer, parameters,
				logger_train)
		test_time = time.time() - test_time
		logger_train.info("Test time : {}".format(test_time))

		test_tab = {}
		test_tab_sampled = {}
		test_tab_LR = {}
		training_utils.mean_and_store_results(test_results, test_tab, None)
		training_utils.mean_and_store_results(test_results_sampled, test_tab_sampled, None)
		training_utils.mean_and_store_results(test_long_range_results, test_tab_LR, None)

		logger_train.info("############################################################")
		logger_train.info("""## Test Scores
Loss : {}
Validation accuracy : {:.3f} %, precision : {:.3f} %, recall : {:.3f} %
True_accuracy : {:.3f} %, f_score : {:.3f} %, Xent : {:.6f}"""
		.format(
			test_tab['loss'], test_tab['accuracy'], test_tab['precision'],
			test_tab['recall'], test_tab['true_accuracy'], test_tab['f_score'], test_tab['Xent']))
		logger_train.info('Time : {}'
						  .format(test_time))

		#######################################
		# Close workers' pool
		#######################################
		pool.close()
		pool.join()

	return training_utils.remove_tail_training_curves(valid_tabs, epoch), test_tab, best_epoch, \
		training_utils.remove_tail_training_curves(valid_tabs_sampled, epoch), test_tab_sampled, best_epoch,\
		training_utils.remove_tail_training_curves(valid_tabs_LR, epoch), test_tab_LR, best_epoch_LR

# bias=[v.eval() for v in tf.global_variables() if v.name == "top_layer_prediction/orch_pred/bias:0"][0]
# kernel=[v.eval() for v in tf.global_variables() if v.name == "top_layer_prediction/orch_pred/kernel:0"][0]
# Example 2
def generate_midi(config_folder_fd, config_folder_bd, config_folder_corr,
                  score_source, save_folder, initialization_type,
                  number_of_version, duration_gen, num_pass_correct,
                  logger_generate):
    """Orchestrate a midi piano score with a forward/backward/correction chain.

    The piece is first orchestrated by a forward model, re-generated backward
    from the end of the forward output, then refined by ``num_pass_correct``
    passes of a correction model.  Intermediate and final results are written
    as midi files under *save_folder*.

    Parameters
    ----------
    config_folder_fd : str
        Absolute path to the configuration folder (saved model and results)
        of the forward model.
    config_folder_bd : str
        Configuration folder of the backward model.
    config_folder_corr : str
        Configuration folder of the correction model.
    score_source : str
        Either a path to a folder containing two midi files (piano and
        orchestration) or the path to a piano midi file.
    save_folder : str
        Root folder where the generated scores are written.
    initialization_type : str
        How the orchestra seed is initialized: "seed", "zeros", "constant"
        or "random".
    number_of_version : int
        Number of versions generated in a batch manner. Since the generation
        process involves sampling it might be interesting to generate
        several versions.
    duration_gen : int
        Length of the generated score (in number of events). Useful for
        generating only the beginning of the piece.
    num_pass_correct : int
        Number of correction passes applied after the backward generation.
    logger_generate : logger
        Instanciation of logging. Can be None

    Raises
    ------
    ValueError
        If *initialization_type* is not one of the supported values.
    """

    logger_generate.info("#############################################")
    logger_generate.info("Orchestrating : " + score_source)

    # Load parameters of the three models
    parameters = pkl.load(
        open(config_folder_fd + '/script_parameters.pkl', 'rb'))
    model_parameters_fd = pkl.load(
        open(config_folder_fd + '/model_params.pkl', 'rb'))
    #
    parameters_bd = pkl.load(
        open(config_folder_bd + '/script_parameters.pkl', 'rb'))
    model_parameters_bd = pkl.load(
        open(config_folder_bd + '/model_params.pkl', 'rb'))
    #
    parameters_corr = pkl.load(
        open(config_folder_corr + '/script_parameters.pkl', 'rb'))
    model_parameters_corr = pkl.load(
        open(config_folder_corr + '/model_params.pkl', 'rb'))

    # The three models must share their data-representation settings,
    # otherwise mixing their outputs makes no sense
    assert (model_parameters_fd["temporal_order"]
            == model_parameters_bd["temporal_order"]) and (
                model_parameters_fd["temporal_order"]
                == model_parameters_corr["temporal_order"]
            ), "The two model have different seed_size"
    assert (parameters["quantization"] == parameters_bd["quantization"]) and (
        parameters["quantization"] == parameters_corr["quantization"]
    ), "The two model have different quantization"
    assert (parameters["temporal_granularity"]
            == parameters_bd["temporal_granularity"]) and (
                parameters["temporal_granularity"]
                == parameters_corr["temporal_granularity"]
            ), "The two model have different temporal_granularity"
    assert (parameters["instru_mapping"] == parameters_bd["instru_mapping"]
            ) and (parameters["instru_mapping"]
                   == parameters_corr["instru_mapping"]
                   ), "The two model have different instru_mapping"
    assert (parameters["normalizer"] == parameters_bd["normalizer"]) and (
        parameters["normalizer"] == parameters_corr["normalizer"]
    ), "The two model have different normalizer"

    # Set a minimum seed size, because for very short models you don't
    # even see the beginning of the piece
    seed_size = max(model_parameters_fd['temporal_order'], 10) - 1

    #######################
    # Load data: a lone piano midi file, or an aligned (piano, orchestra) pair
    if re.search(r'mid$', score_source):
        pr_piano, event_piano, duration_piano, name_piano, pr_orch, instru_orch, duration = generation_utils.load_solo(
            score_source, parameters["quantization"],
            parameters["binarize_piano"], parameters["temporal_granularity"])
    else:
        if initialization_type == "seed":
            pr_piano, event_piano, duration_piano, name_piano, pr_orch, instru_orch, duration = generation_utils.load_from_pair(
                score_source,
                parameters["quantization"],
                parameters["binarize_piano"],
                parameters["binarize_orch"],
                parameters["temporal_granularity"],
                align_bool=True)
        else:
            pr_piano, event_piano, duration_piano, name_piano, pr_orch, instru_orch, duration = generation_utils.load_from_pair(
                score_source,
                parameters["quantization"],
                parameters["binarize_piano"],
                parameters["binarize_orch"],
                parameters["temporal_granularity"],
                align_bool=False)

    if (duration is None) or (duration < duration_gen):
        logger_generate.info("Track too short to be used")
        return
    ########################

    ########################
    # Shorten: keep only the first duration_gen events of the piece
    pr_piano = pianoroll_processing.extract_pianoroll_part(
        pr_piano, 0, duration_gen)
    if parameters["duration_piano"]:
        duration_piano = np.asarray(duration_piano[:duration_gen])
    else:
        duration_piano = None
    if parameters["temporal_granularity"] == "event_level":
        event_piano = event_piano[:duration_gen]
    pr_orch = pianoroll_processing.extract_pianoroll_part(
        pr_orch, 0, duration_gen)
    ########################

    ########################
    # Instanciate piano pianoroll and locate silent piano events
    N_piano = parameters["instru_mapping"]['Piano']['index_max']
    pr_piano_gen = np.zeros((duration_gen, N_piano), dtype=np.float32)
    pr_piano_gen = build_data_aux.cast_small_pr_into_big_pr(
        pr_piano, {}, 0, duration_gen, parameters["instru_mapping"],
        pr_piano_gen)
    pr_piano_gen_flat = pr_piano_gen.sum(axis=1)
    silence_piano = [
        e for e in range(duration_gen) if pr_piano_gen_flat[e] == 0
    ]
    ########################

    ########################
    # Initialize orchestra pianoroll with an orchestra seed (choose one)
    N_orchestra = parameters['N_orchestra']
    pr_orchestra_truth = np.zeros((duration_gen, N_orchestra),
                                  dtype=np.float32)
    pr_orchestra_truth = build_data_aux.cast_small_pr_into_big_pr(
        pr_orch, instru_orch, 0, duration_gen, parameters["instru_mapping"],
        pr_orchestra_truth)
    if initialization_type == "seed":
        pr_orchestra_seed = generation_utils.init_with_seed(
            pr_orch, number_of_version, seed_size, N_orchestra, instru_orch,
            parameters["instru_mapping"])
    elif initialization_type == "zeros":
        pr_orchestra_seed = generation_utils.init_with_zeros(
            number_of_version, seed_size, N_orchestra)
    elif initialization_type == "constant":
        const_value = 0.1
        pr_orchestra_seed = generation_utils.init_with_constant(
            number_of_version, seed_size, N_orchestra, const_value)
    elif initialization_type == "random":
        proba_activation = 0.01
        pr_orchestra_seed = generation_utils.init_with_random(
            number_of_version, seed_size, N_orchestra, proba_activation)
    else:
        # Fail fast: previously an unknown value left pr_orchestra_seed
        # undefined and crashed much later with a confusing NameError
        raise ValueError(
            "Unknown initialization_type: {}".format(initialization_type))
    ########################

    #######################################
    # Embed piano
    time_embedding = time.time()
    if parameters['embedded_piano']:
        # Load model
        embedding_path = parameters["embedding_path"]
        embedding_model = torch.load(embedding_path, map_location="cpu")

        # Build embedding (no need to batch here, len(pr_piano_gen) is sufficiently small)
        # Plus no CUDA here because : afradi of mix with TF  +  possibly very long piano chunks
        piano_resize_emb = np.zeros(
            (len(pr_piano_gen), 1, 128))  # Embeddings accetp size 128 samples
        piano_resize_emb[:, 0, parameters["instru_mapping"]['Piano']
                         ['pitch_min']:parameters["instru_mapping"]['Piano']
                         ['pitch_max']] = pr_piano_gen
        piano_resize_emb_TT = torch.tensor(piano_resize_emb)
        piano_embedded_TT = embedding_model(piano_resize_emb_TT.float(), 0)
        pr_piano_gen_embedded = piano_embedded_TT.numpy()
    else:
        pr_piano_gen_embedded = pr_piano_gen
    time_embedding = time.time() - time_embedding
    #######################################

    ########################
    # Inputs' normalization
    normalizer = pkl.load(
        open(os.path.join(config_folder_fd, 'normalizer.pkl'), 'rb'))
    if parameters["embedded_piano"]:  # When using embedding, no normalization
        pr_piano_gen_norm = pr_piano_gen_embedded
    else:
        pr_piano_gen_norm = normalizer.transform(pr_piano_gen_embedded)
    ########################

    ########################
    # Store folder, one sub-folder per track
    string = re.split(r'/', name_piano)[-1]
    name_track = re.sub('piano_solo.mid', '', string)
    generated_folder = save_folder + '/fd_bd_corr_' + initialization_type + '_init/' + name_track
    if not os.path.isdir(generated_folder):
        os.makedirs(generated_folder)
    ########################

    ########################
    # Get the trainer of each model (name was saved at training time)
    with open(os.path.join(config_folder_fd, 'which_trainer'), 'r') as ff:
        which_trainer_fd = ff.read()
    # Trainer
    trainer_fd = import_trainer(which_trainer_fd, model_parameters_fd,
                                parameters)
    #
    with open(os.path.join(config_folder_bd, 'which_trainer'), 'r') as ff:
        which_trainer_bd = ff.read()
    # Trainer
    trainer_bd = import_trainer(which_trainer_bd, model_parameters_bd,
                                parameters)
    #
    with open(os.path.join(config_folder_corr, 'which_trainer'), 'r') as ff:
        which_trainer_corr = ff.read()
    # Trainer
    trainer_corr = import_trainer(which_trainer_corr, model_parameters_corr,
                                  parameters)
    ########################

    ############################################################
    # Generate
    ############################################################
    time_generate_0 = time.time()
    model_path = 'model_accuracy'
    # Forward pass
    pr_orchestra_gen = generate(trainer_fd,
                                pr_piano_gen_norm,
                                silence_piano,
                                duration_piano,
                                config_folder_fd,
                                model_path,
                                pr_orchestra_seed,
                                batch_size=number_of_version)
    prefix_name = 'fd_'
    generation_utils.reconstruct_generation(pr_orchestra_gen, event_piano,
                                            generated_folder, prefix_name,
                                            parameters, seed_size)
    # Backward pass, seeded with the tail of the forward generation
    pr_orchestra_seed = pr_orchestra_gen[:, -seed_size:]
    pr_orchestra_gen = generate_backward(trainer_bd,
                                         pr_piano_gen_norm,
                                         silence_piano,
                                         duration_piano,
                                         config_folder_bd,
                                         model_path,
                                         pr_orchestra_seed,
                                         batch_size=number_of_version)
    prefix_name = 'bd_'
    generation_utils.reconstruct_generation(pr_orchestra_gen, event_piano,
                                            generated_folder, prefix_name,
                                            parameters, seed_size)
    # Correction passes
    for pass_index in range(num_pass_correct):
        pr_orchestra_gen = correct(trainer_corr,
                                   pr_piano_gen_norm,
                                   silence_piano,
                                   duration_piano,
                                   config_folder_corr,
                                   model_path,
                                   pr_orchestra_gen,
                                   batch_size=number_of_version)
        # BUG FIX: the prefix was updated only *after* writing, so the first
        # corrected pass overwrote the 'bd_' files and every pass was
        # labeled with the previous pass index
        prefix_name = 'corr_' + str(pass_index) + '_'
        generation_utils.reconstruct_generation(pr_orchestra_gen, event_piano,
                                                generated_folder, prefix_name,
                                                parameters, seed_size)
    time_generate_1 = time.time()
    logger_generate.info(
        'TTT : Generating data took {} seconds'.format(time_generate_1 -
                                                       time_generate_0))

    ############################################################
    # Reconstruct and write the final generation and the original
    ############################################################
    prefix_name = 'final_'
    generation_utils.reconstruct_generation(pr_orchestra_gen, event_piano,
                                            generated_folder, prefix_name,
                                            parameters, seed_size)
    generation_utils.reconstruct_original(pr_piano_gen, pr_orchestra_truth,
                                          event_piano, generated_folder,
                                          parameters)
    return
def generate_midi(config_folder, score_source, number_of_version, duration_gen,
                  logger_generate):
    """Generate the orchestration of a midi piano score.

    Loads the script/model parameters and the trained model saved in
    ``config_folder``, loads the piano score (and, when available, its
    reference orchestration), generates ``number_of_version`` orchestrations
    for every measure listed in ``parameters['save_measures']``, and writes
    the resulting midi files under
    ``config_folder/generation_reference_example/<track_name>``.

    Parameters
    ----------
    config_folder : str
        Absolute path to the configuration folder, i.e. the folder containing
        the saved model and the results
    score_source : str
        Either a path to a folder containing two midi files (piano and
        orchestration) or the path to a piano midi file
    number_of_version : int
        Number of versions generated in a batch manner. Since the generation
        process involves sampling it might be interesting to generate several
        versions
    duration_gen : int
        Length of the generated score (in number of events). Useful for
        generating only the beginning of the piece.
    logger_generate : logger
        Instantiation of logging. Must not be None (``.info`` is called on it).
    """

    logger_generate.info("#############################################")
    logger_generate.info("Orchestrating : " + score_source)
    ############################################################
    # Load model, config and data
    ############################################################

    ########################
    # Load config and model
    # Context managers so the pickle file handles are closed deterministically
    with open(config_folder + '/script_parameters.pkl', 'rb') as ff:
        parameters = pkl.load(ff)
    with open(config_folder + '/model_params.pkl', 'rb') as ff:
        model_parameters = pkl.load(ff)
    # Set a minimum seed size, because for very short models you don't even
    # see the beginning of the piece
    seed_size = max(model_parameters['temporal_order'], 10) - 1
    quantization = parameters['quantization']
    temporal_granularity = parameters['temporal_granularity']
    instru_mapping = parameters['instru_mapping']
    ########################

    #######################
    # Load data
    # A '*.mid' path is a solo piano file; anything else is assumed to be a
    # folder holding a (piano, orchestration) pair
    if re.search(r'mid$', score_source):
        pr_piano, event_piano, duration_piano, name_piano, pr_orch, instru_orch, duration = load_solo(
            score_source, quantization, parameters["binarize_piano"],
            temporal_granularity)
    else:
        pr_piano, event_piano, duration_piano, name_piano, pr_orch, instru_orch, duration = load_from_pair(
            score_source, quantization, parameters["binarize_piano"],
            parameters["binarize_orch"], temporal_granularity)

    # Abort on tracks shorter than the requested generation length
    if (duration is None) or (duration < duration_gen):
        logger_generate.info("Track too short to be used")
        return
    ########################

    ########################
    # Shorten
    # Keep only the beginning of the pieces (duration_gen events)
    pr_piano = extract_pianoroll_part(pr_piano, 0, duration_gen)
    if parameters["duration_piano"]:
        duration_piano = np.asarray(duration_piano[:duration_gen])
    else:
        duration_piano = None
    if parameters["temporal_granularity"] == "event_level":
        event_piano = event_piano[:duration_gen]
    pr_orch = extract_pianoroll_part(pr_orch, 0, duration_gen)
    ########################

    ########################
    # Instantiate piano pianoroll
    N_piano = instru_mapping['Piano']['index_max']
    pr_piano_gen = np.zeros((duration_gen, N_piano), dtype=np.float32)
    pr_piano_gen = build_data_aux.cast_small_pr_into_big_pr(
        pr_piano, {}, 0, duration_gen, instru_mapping, pr_piano_gen)
    # Events where the piano plays nothing (all-zero frames)
    pr_piano_gen_flat = pr_piano_gen.sum(axis=1)
    silence_piano = [
        e for e in range(duration_gen) if pr_piano_gen_flat[e] == 0
    ]
    ########################

    ########################
    # Instantiate orchestra pianoroll with orchestra seed (when a reference
    # orchestration is available)
    N_orchestra = parameters['N_orchestra']
    if pr_orch:
        pr_orchestra_gen = np.zeros((seed_size, N_orchestra), dtype=np.float32)
        orch_seed_beginning = {k: v[:seed_size] for k, v in pr_orch.items()}
        pr_orchestra_gen = build_data_aux.cast_small_pr_into_big_pr(
            orch_seed_beginning, instru_orch, 0, seed_size, instru_mapping,
            pr_orchestra_gen)
        pr_orchestra_truth = np.zeros((duration_gen, N_orchestra),
                                      dtype=np.float32)
        pr_orchestra_truth = build_data_aux.cast_small_pr_into_big_pr(
            pr_orch, instru_orch, 0, duration_gen, instru_mapping,
            pr_orchestra_truth)
    else:
        pr_orchestra_gen = None
        pr_orchestra_truth = None
    ########################

    #######################################
    # Embed piano
    time_embedding = time.time()
    if parameters['embedded_piano']:
        # Load model
        embedding_path = parameters["embedding_path"]
        embedding_model = torch.load(embedding_path, map_location="cpu")

        # Build embedding (no need to batch here, len(pr_piano_gen) is
        # sufficiently small).
        # No CUDA here: afraid of mixing with TF + possibly very long chunks
        piano_resize_emb = np.zeros(
            (len(pr_piano_gen), 1, 128))  # Embeddings accept size-128 samples
        piano_resize_emb[:, 0, instru_mapping['Piano']['pitch_min']:
                         instru_mapping['Piano']['pitch_max']] = pr_piano_gen
        piano_resize_emb_TT = torch.tensor(piano_resize_emb)
        # Inference only: no_grad avoids building the autograd graph and is
        # required for .numpy(), which raises on a tensor that requires grad
        with torch.no_grad():
            piano_embedded_TT = embedding_model(piano_resize_emb_TT.float(), 0)
        pr_piano_gen_embedded = piano_embedded_TT.numpy()
    else:
        pr_piano_gen_embedded = pr_piano_gen
    time_embedding = time.time() - time_embedding
    #######################################

    ########################
    # Inputs' normalization
    with open(os.path.join(config_folder, 'normalizer.pkl'), 'rb') as ff:
        normalizer = pkl.load(ff)
    if parameters["embedded_piano"]:  # When using embedding, no normalization
        pr_piano_gen_norm = pr_piano_gen_embedded
    else:
        pr_piano_gen_norm = normalizer.transform(pr_piano_gen_embedded)
    ########################

    ########################
    # Store folder
    string = re.split(r'/', name_piano)[-1]
    # Escape the dot so e.g. 'piano_soloXmid' is not accidentally matched
    name_track = re.sub(r'piano_solo\.mid', '', string)
    generated_folder = config_folder + '/generation_reference_example/' + name_track
    # exist_ok avoids a race between an isdir check and the creation
    os.makedirs(generated_folder, exist_ok=True)
    ########################

    ########################
    # Get trainer (its name was saved at training time)
    with open(os.path.join(config_folder, 'which_trainer'), 'r') as ff:
        which_trainer = ff.read()
    # Trainer
    trainer = import_trainer(which_trainer, model_parameters, parameters)

    ########################

    ############################################################
    # Generate
    ############################################################
    # One batch of generations per saved measure (e.g. best-loss, best-accuracy)
    time_generate_0 = time.time()
    generated_sequences = {}
    for measure_name in parameters['save_measures']:
        model_path = 'model_' + measure_name
        generated_sequences[measure_name] = generate(
            trainer,
            pr_piano_gen_norm,
            silence_piano,
            duration_piano,
            config_folder,
            model_path,
            pr_orchestra_gen,
            batch_size=number_of_version)

    time_generate_1 = time.time()
    logger_generate.info(
        'TTT : Generating data took {} seconds'.format(time_generate_1 -
                                                       time_generate_0))

    ############################################################
    # Reconstruct and write
    ############################################################
    def reconstruct_write_aux(generated_sequences, prefix):
        """Write one midi file per generated version, named with *prefix*."""
        for write_counter in range(generated_sequences.shape[0]):
            # To distinguish when seed stops, insert a sustained note
            this_seq = generated_sequences[write_counter] * 127
            this_seq[:seed_size, 0] = 20
            # Reconstruct; at event level also write a rhythm-expanded version
            if parameters['temporal_granularity'] == 'event_level':
                pr_orchestra_rhythm = from_event_to_frame(
                    this_seq, event_piano)
                pr_orchestra_rhythm_I = instrument_reconstruction(
                    pr_orchestra_rhythm, instru_mapping)
                write_path = generated_folder + '/' + prefix + '_' + str(
                    write_counter) + '_generated_rhythm.mid'
                write_midi(pr_orchestra_rhythm_I,
                           quantization,
                           write_path,
                           tempo=80)
            pr_orchestra_event = this_seq
            pr_orchestra_event_I = instrument_reconstruction(
                pr_orchestra_event, instru_mapping)
            write_path = generated_folder + '/' + prefix + '_' + str(
                write_counter) + '_generated.mid'
            write_midi(pr_orchestra_event_I, 1, write_path, tempo=80)
        return

    for measure_name in parameters["save_measures"]:
        reconstruct_write_aux(generated_sequences[measure_name], measure_name)

    ############################################################
    ############################################################
    # Write original orchestration and piano scores, but reconstructed
    # version, just to check
    if parameters["temporal_granularity"] == 'event_level':
        # Piano, rhythm-expanded
        A_rhythm = from_event_to_frame(pr_piano_gen, event_piano)
        B_rhythm = A_rhythm * 127
        piano_reconstructed_rhythm = instrument_reconstruction_piano(
            B_rhythm, instru_mapping)
        write_path = generated_folder + '/piano_reconstructed_rhythm.mid'
        write_midi(piano_reconstructed_rhythm,
                   quantization,
                   write_path,
                   tempo=80)
        # Truth, rhythm-expanded
        A_rhythm = from_event_to_frame(pr_orchestra_truth, event_piano)
        B_rhythm = A_rhythm * 127
        orchestra_reconstructed_rhythm = instrument_reconstruction(
            B_rhythm, instru_mapping)
        write_path = generated_folder + '/orchestra_reconstructed_rhythm.mid'
        write_midi(orchestra_reconstructed_rhythm,
                   quantization,
                   write_path,
                   tempo=80)
        # Piano, event level
        A = pr_piano_gen
        B = A * 127
        piano_reconstructed = instrument_reconstruction_piano(
            B, instru_mapping)
        write_path = generated_folder + '/piano_reconstructed.mid'
        write_midi(piano_reconstructed, 1, write_path, tempo=80)
        # Truth, event level
        A = pr_orchestra_truth
        B = A * 127
        orchestra_reconstructed = instrument_reconstruction(B, instru_mapping)
        write_path = generated_folder + '/orchestra_reconstructed.mid'
        write_midi(orchestra_reconstructed, 1, write_path, tempo=80)
    else:
        A = pr_piano_gen
        B = A * 127
        piano_reconstructed = instrument_reconstruction_piano(
            B, instru_mapping)
        write_path = generated_folder + '/piano_reconstructed.mid'
        write_midi(piano_reconstructed, quantization, write_path, tempo=80)
        #
        A = pr_orchestra_truth
        B = A * 127
        orchestra_reconstructed = instrument_reconstruction(B, instru_mapping)
        write_path = generated_folder + '/orchestra_reconstructed.mid'
        write_midi(orchestra_reconstructed, quantization, write_path, tempo=80)
    ############################################################
    ############################################################
    return