예제 #1
0
# Measures tracked on validation/test. The order matches the historical literal
# dicts so iteration order (and the order checkpoints are written) is unchanged.
_MEASURE_NAMES = ('loss', 'accuracy', 'precision', 'recall', 'true_accuracy',
                  'f_score', 'Xent')


def _zero_curves(max_iter):
    """Return one zero-filled curve of length ``max(1, max_iter)`` per measure."""
    length = max(1, max_iter)
    return {name: np.zeros(length) for name in _MEASURE_NAMES}


def _zero_best_epochs():
    """Return a dict mapping every tracked measure name to epoch 0."""
    return {name: 0 for name in _MEASURE_NAMES}


def _average_test_results(test_results, test_results_sampled,
                          test_long_range_results):
    """Average raw test results into (standard, sampled, long-range) dicts."""
    test_tab = {}
    test_tab_sampled = {}
    test_tab_LR = {}
    training_utils.mean_and_store_results(test_results, test_tab, None)
    training_utils.mean_and_store_results(test_results_sampled,
                                          test_tab_sampled, None)
    training_utils.mean_and_store_results(test_long_range_results,
                                          test_tab_LR, None)
    return test_tab, test_tab_sampled, test_tab_LR


def _dump_debug_measures(folder, valid_results, preds_val, truth_val):
    """Recreate ``folder`` and save the first 2000 entries of each validation
    measure, of the predictions and of the ground truth as ``.npy`` files."""
    if os.path.isdir(folder):
        shutil.rmtree(folder)
    os.makedirs(folder)
    for measure_name, measure_tab in valid_results.items():
        np.save(os.path.join(folder, measure_name + '.npy'),
                measure_tab[:2000])
    np.save(os.path.join(folder, 'preds.npy'), np.asarray(preds_val[:2000]))
    np.save(os.path.join(folder, 'truth.npy'), np.asarray(truth_val[:2000]))


def train(model, train_splits_batches, valid_splits_batches,
          test_splits_batches, normalizer, model_params, parameters,
          config_folder, start_time_train, logger_train):
    """Train ``model`` with validation-based early stopping, then test it.

    Args:
        model: model object; ``trainer()``, ``optimize()``, ``init_weights()``
            and ``is_keras()`` are called on it.
        train_splits_batches, valid_splits_batches, test_splits_batches:
            sequences of dicts, one per matrix file, each holding
            ``'chunks_folders'`` (passed to ``async_load_mat``) and
            ``'batches'`` (batch indices iterated over).
        normalizer: forwarded to ``async_load_mat`` and ``validate``.
        model_params: forwarded to ``import_trainer``.
        parameters: configuration dict (``debug``, ``walltime``, ``max_iter``,
            ``overfitting_measure``, ``save_measures``, ``memory_gpu``,
            ``pretrained_model``, ``min_number_iteration``,
            ``number_strips``, ``validation_order``). Its ``debug`` entry is
            mutated (``save_measures`` is replaced by a folder path).
        config_folder: output folder; ``which_trainer``,
            ``num_parameters.txt``, summaries, DEBUG dumps and the
            ``model_<measure>/model`` checkpoints are written under it.
        start_time_train: timestamp (``time.time()`` scale) used to enforce
            the walltime budget.
        logger_train: logger receiving progress messages.

    Returns:
        A 10-tuple ``(accuracy_training, valid_tabs, test_tab, best_epoch,
        valid_tabs_sampled, test_tab_sampled, best_epoch, valid_tabs_LR,
        test_tab_LR, best_epoch_LR)``. The validation curves are trimmed by
        ``training_utils.remove_tail_training_curves``. ``best_epoch`` appears
        twice: no separate best-epoch bookkeeping exists for sampled preds.
        Models whose ``optimize()`` is False are only evaluated on the test
        split, and those results are stored as epoch 0.
    """
    DEBUG = parameters["debug"]
    SUMMARIZE = DEBUG["summarize"]

    # Redirect the DEBUG measure dump into the config folder.
    if DEBUG["save_measures"]:
        DEBUG["save_measures"] = config_folder + "/save_measures"

    # Walltime budget minus a 30 minute safety margin, in seconds.
    time_limit = parameters['walltime'] * 3600 - 30 * 60

    # Reset graph before starting training
    tf.reset_default_graph()

    which_trainer = model.trainer()

    # Persist the trainer name so generation can rebuild the same trainer.
    with open(os.path.join(config_folder, 'which_trainer'), 'w') as ff:
        ff.write(which_trainer)
    trainer = import_trainer(which_trainer, model_params, parameters)
    # Flag to know if the model has to be trained or not
    model_optimize = model.optimize()

    ############################################################
    # Display information about the model
    num_parameters = model_statistics.count_parameters(tf.get_default_graph())
    logger_train.info(
        '** Num trainable parameters :  {}'.format(num_parameters))
    with open(os.path.join(config_folder, 'num_parameters.txt'), 'w') as ff:
        ff.write("{:d}".format(num_parameters))

    ############################################################
    # Training
    logger_train.info("#" * 60)
    logger_train.info("#### Training")
    epoch = 0
    OVERFITTING = False
    TIME_LIMIT = False

    # Mean training loss per epoch
    loss_tab = np.zeros(max(1, parameters['max_iter']))

    # Select criteria
    overfitting_measure = parameters["overfitting_measure"]
    save_measures = parameters['save_measures']

    # Per-epoch validation curves: short-term, sampled preds, long-term.
    valid_tabs = _zero_curves(parameters['max_iter'])
    valid_tabs_sampled = _zero_curves(parameters['max_iter'])
    valid_tabs_LR = _zero_curves(parameters['max_iter'])
    # Best epoch for each measure (short-term and long-term)
    best_epoch = _zero_best_epochs()
    best_epoch_LR = _zero_best_epochs()

    if parameters['memory_gpu']:
        # This apparently does not work
        configSession = tf.ConfigProto()
        configSession.gpu_options.per_process_gpu_memory_fraction = parameters[
            'memory_gpu']
    else:
        configSession = None

    with tf.Session(config=configSession) as sess:

        # Only for models with shared weights
        model.init_weights()

        ##############################
        # Create placeholders and nodes, or load them from a pretrained model
        logger_train.info((u'#### Graph'))
        start_time_building_graph = time.time()
        if parameters['pretrained_model'] is None:
            trainer.build_variables_nodes(model, parameters)
            trainer.build_preds_nodes(model)
            trainer.build_loss_nodes(model, parameters)
            trainer.build_train_step_node(model, config.optimizer())
            trainer.save_nodes(model)
            time_building_graph = time.time() - start_time_building_graph
            logger_train.info("TTT : Building the graph took {0:.2f}s".format(
                time_building_graph))
        else:
            trainer.load_pretrained_model(parameters['pretrained_model'])
            time_building_graph = time.time() - start_time_building_graph
            logger_train.info(
                "TTT : Loading pretrained model took {0:.2f}s".format(
                    time_building_graph))

        ##############################
        # Summaries
        if SUMMARIZE:
            tf.summary.scalar('loss', trainer.loss)
            merged_node = tf.summary.merge_all()
            train_writer = tf.summary.FileWriter(config_folder + '/summary',
                                                 sess.graph)
            train_writer.add_graph(tf.get_default_graph())
        else:
            merged_node = None
        summarize_dict = {'bool': SUMMARIZE, 'merged_node': merged_node}

        if model.is_keras():
            K.set_session(sess)

        # Initialize weights
        if parameters['pretrained_model']:
            trainer.saver.restore(sess,
                                  parameters['pretrained_model'] + '/model')
        else:
            sess.run(tf.global_variables_initializer())

        N_matrix_files = len(train_splits_batches)

        #######################################
        # Load first matrix
        #######################################
        load_data_start = time.time()
        # A single background worker pre-loads the next matrix file while the
        # current one is in use. It is shut down in ``finally`` so it does not
        # leak on the early-return path or when an exception is raised.
        pool = ThreadPool(processes=1)
        try:
            async_train = pool.apply_async(
                async_load_mat,
                (normalizer, train_splits_batches[0]['chunks_folders'],
                 parameters, trainer.np_type))
            matrices_from_thread = async_train.get()
            load_data_time = time.time() - load_data_start
            logger_train.info("Load the first matrix time : " +
                              str(load_data_time))

            # Baseline models (e.g. random or repeat) skip training entirely.
            if model_optimize == False:
                # WARNING : first test matrix is not necessarily the same as the first train matrix
                async_test = pool.apply_async(
                    async_load_mat,
                    (normalizer, test_splits_batches[0]['chunks_folders'],
                     parameters, trainer.np_type))
                init_matrices_test = async_test.get()
                test_results, test_results_sampled, test_long_range_results, _, _ = validate(
                    trainer, sess, init_matrices_test, test_splits_batches,
                    normalizer, parameters, logger_train)
                training_utils.mean_and_store_results(test_results,
                                                      valid_tabs, 0)
                training_utils.mean_and_store_results(test_results_sampled,
                                                      valid_tabs_sampled, 0)
                training_utils.mean_and_store_results(test_long_range_results,
                                                      valid_tabs_LR, 0)

                test_tab, test_tab_sampled, test_tab_LR = _average_test_results(
                    test_results, test_results_sampled,
                    test_long_range_results)

                accuracy_training = np.zeros((100))

                return (accuracy_training,
                        training_utils.remove_tail_training_curves(
                            valid_tabs, 1), test_tab, best_epoch,
                        training_utils.remove_tail_training_curves(
                            valid_tabs_sampled, 1), test_tab_sampled,
                        best_epoch,
                        training_utils.remove_tail_training_curves(
                            valid_tabs_LR, 1), test_tab_LR, best_epoch_LR)

            accuracy_training = []

            # Training iteration
            while (not OVERFITTING and not TIME_LIMIT
                   and epoch != parameters['max_iter']):
                start_time_epoch = time.time()

                trainer.DEBUG['epoch'] = epoch

                train_cost_epoch = []
                sparse_loss_epoch = []

                train_time = time.time()

                this_batch_accuracy = []

                for file_ind_CURRENT in range(N_matrix_files):
                    # Train on the current matrix...
                    train_index = train_splits_batches[file_ind_CURRENT][
                        'batches']
                    # ...while the worker loads the next one (wrapping to 0).
                    file_ind_NEXT = (file_ind_CURRENT + 1) % N_matrix_files
                    next_chunks = train_splits_batches[file_ind_NEXT][
                        'chunks_folders']

                    async_train = pool.apply_async(
                        async_load_mat,
                        (normalizer, next_chunks, parameters, trainer.np_type))
                    piano_input, orch_transformed, duration_piano, mask_orch = matrices_from_thread

                    for batch_index in train_index:

                        loss_batch, preds_batch, debug_outputs, summary = trainer.training_step(
                            sess, batch_index, piano_input, orch_transformed,
                            duration_piano, mask_orch, summarize_dict)

                        orch_t = orch_transformed[batch_index]
                        accuracy_batch = measure.accuracy_measure(
                            orch_t, preds_batch)
                        this_batch_accuracy.extend(accuracy_batch)

                        # Keep track of cost
                        train_cost_epoch.append(loss_batch)
                        sparse_loss_epoch.append(
                            debug_outputs["sparse_loss_batch"])

                    # Swap in the matrix loaded in the background
                    del matrices_from_thread
                    matrices_from_thread = async_train.get()

                # NOTE(review): the -100 factor suggests accuracy_measure
                # returns negated accuracies (lower-is-better) — confirm.
                this_batch_accuracy_mean = -100 * np.mean(this_batch_accuracy)
                accuracy_training.append(this_batch_accuracy_mean)

                train_time = time.time() - train_time
                logger_train.info("Training time : {}".format(train_time))

                # DEBUG
                if trainer.DEBUG["plot_weights"]:
                    weight_folder = config_folder + "/weights"
                    plot_weights.plot_weights(sess, weight_folder)

                # WARNING : first validation matrix is not necessarily the same as the first train matrix
                # The load is awaited right below, so it is effectively synchronous.
                async_valid = pool.apply_async(
                    async_load_mat,
                    (normalizer, valid_splits_batches[0]['chunks_folders'],
                     parameters, trainer.np_type))

                if SUMMARIZE:
                    if (epoch < 5) or (epoch % 10 == 0):
                        # Only the summary of the last batch of the epoch is written
                        train_writer.add_summary(summary, epoch)

                mean_loss = np.mean(train_cost_epoch)
                loss_tab[epoch] = mean_loss

                #######################################
                # Validate
                #######################################
                valid_time = time.time()
                init_matrices_validation = async_valid.get()

                # Create per-epoch DEBUG folders; the DEBUG entry is replaced
                # by the path of the folder just created.
                for debug_key, debug_subfolder in (
                        ("plot_nade_ordering_preds", "preds_nade"),
                        ("save_accuracy_along_sampling",
                         "accuracy_along_sampling"),
                        ("salience_embedding", "salience_embeddings")):
                    if trainer.DEBUG[debug_key]:
                        debug_dir = (config_folder + "/DEBUG/" +
                                     debug_subfolder + "/" + str(epoch))
                        trainer.DEBUG[debug_key] = debug_dir
                        os.makedirs(debug_dir)

                valid_results, valid_results_sampled, valid_long_range_results, preds_val, truth_val = validate(
                    trainer, sess, init_matrices_validation,
                    valid_splits_batches, normalizer, parameters,
                    logger_train)
                valid_time = time.time() - valid_time
                logger_train.info("Validation time : {}".format(valid_time))

                training_utils.mean_and_store_results(valid_results,
                                                      valid_tabs, epoch)
                training_utils.mean_and_store_results(valid_results_sampled,
                                                      valid_tabs_sampled,
                                                      epoch)
                training_utils.mean_and_store_results(valid_long_range_results,
                                                      valid_tabs_LR, epoch)
                end_time_epoch = time.time()

                #######################################
                # Overfitting ?
                if epoch >= parameters['min_number_iteration']:
                    OVERFITTING = early_stopping.up_criterion(
                        valid_tabs[overfitting_measure], epoch,
                        parameters["number_strips"],
                        parameters["validation_order"])
                    if not OVERFITTING:
                        # Also check for NaN
                        OVERFITTING = early_stopping.check_for_nan(
                            valid_tabs, save_measures, max_nan=3)

                #######################################
                # Monitor time (walltime budget)
                if (time.time() - start_time_train) > time_limit:
                    TIME_LIMIT = True

                #######################################
                # Log training
                logger_train.info(
                    "############################################################")
                logger_train.info(
                    'Epoch : {} , Training loss : {} , Validation loss : {} \n '
                    '\t\tValidation accuracy : {:.3f} ; {:.3f} \n '
                    '\t\tPrecision : {:.3f} ; {:.3f} \n '
                    '\t\tRecall : {:.3f} ; {:.3f} \n '
                    '\t\tXent : {:.6f} ; {:.3f} \n '
                    '\t\tSparse_loss : {:.3f}'.format(
                        epoch, mean_loss, valid_tabs['loss'][epoch],
                        valid_tabs['accuracy'][epoch],
                        valid_tabs_sampled['accuracy'][epoch],
                        valid_tabs['precision'][epoch],
                        valid_tabs_sampled['precision'][epoch],
                        valid_tabs['recall'][epoch],
                        valid_tabs_sampled['recall'][epoch],
                        valid_tabs['Xent'][epoch],
                        valid_tabs_sampled['Xent'][epoch],
                        np.mean(sparse_loss_epoch)))

                logger_train.info('Time : {}'.format(end_time_epoch -
                                                     start_time_epoch))

                #######################################
                # Best model ? Every measure is treated as lower-is-better;
                # a checkpoint is written for measures listed in save_measures.
                start_time_saving = time.time()
                for measure_name, measure_curve in valid_tabs.items():
                    best_measure_so_far = measure_curve[best_epoch[measure_name]]
                    measure_for_this_epoch = measure_curve[epoch]
                    if (measure_for_this_epoch <= best_measure_so_far) or (
                            epoch == 0):
                        if measure_name in save_measures:
                            trainer.saver.save(
                                sess, config_folder + "/model_" +
                                measure_name + "/model")
                        best_epoch[measure_name] = epoch

                # DEBUG: dump raw validation measures once per epoch. This was
                # previously nested in the per-measure loop above, which redid
                # the dump once per measure and shadowed ``measure_name``.
                if trainer.DEBUG["save_measures"]:
                    _dump_debug_measures(trainer.DEBUG["save_measures"],
                                         valid_results, preds_val, truth_val)

                end_time_saving = time.time()
                logger_train.info('Saving time : {:.3f}'.format(
                    end_time_saving - start_time_saving))

                if OVERFITTING:
                    logger_train.info('OVERFITTING !!')

                if TIME_LIMIT:
                    logger_train.info('TIME OUT !!')

                # Epoch +1
                epoch += 1

            #######################################
            # Test
            #######################################
            test_time = time.time()
            async_test = pool.apply_async(
                async_load_mat,
                (normalizer, test_splits_batches[0]['chunks_folders'],
                 parameters, trainer.np_type))
            init_matrices_test = async_test.get()
            test_results, test_results_sampled, test_long_range_results, _, _ = validate(
                trainer, sess, init_matrices_test, test_splits_batches,
                normalizer, parameters, logger_train)
            test_time = time.time() - test_time
            logger_train.info("Test time : {}".format(test_time))

            test_tab, test_tab_sampled, test_tab_LR = _average_test_results(
                test_results, test_results_sampled, test_long_range_results)

            logger_train.info(
                "############################################################")
            logger_train.info("""## Test Scores
Loss : {}
Validation accuracy : {:.3f} %, precision : {:.3f} %, recall : {:.3f} %
True_accuracy : {:.3f} %, f_score : {:.3f} %, Xent : {:.6f}""".format(
                test_tab['loss'], test_tab['accuracy'], test_tab['precision'],
                test_tab['recall'], test_tab['true_accuracy'],
                test_tab['f_score'], test_tab['Xent']))
            logger_train.info('Time : {}'.format(test_time))
        finally:
            # Close workers' pool on every exit path
            pool.close()
            pool.join()

    return (accuracy_training,
            training_utils.remove_tail_training_curves(valid_tabs, epoch),
            test_tab, best_epoch,
            training_utils.remove_tail_training_curves(valid_tabs_sampled,
                                                       epoch),
            test_tab_sampled, best_epoch,
            training_utils.remove_tail_training_curves(valid_tabs_LR, epoch),
            test_tab_LR, best_epoch_LR)


# bias=[v.eval() for v in tf.global_variables() if v.name == "top_layer_prediction/orch_pred/bias:0"][0]
# kernel=[v.eval() for v in tf.global_variables() if v.name == "top_layer_prediction/orch_pred/kernel:0"][0]
    def orderless_NADE_generation(self, sess, feed_dict, orch_t):
        """Sample orchestra frames with orderless NADE, then refine by Gibbs.

        The context embedding is computed once from ``feed_dict`` and shared
        by ``self.num_ordering`` random pitch orderings. The orderings run in
        parallel because the batch is stacked once per ordering along axis 0.

        Phase 1 (NADE): at each of the ``orch_dim`` steps, every ordering
        unmasks its next pitch and draws a Bernoulli sample for it from the
        current predictions.
        Phase 2 (Gibbs): one random pitch is masked for all rows, predictions
        are recomputed, and the pitch is unmasked and resampled. The step
        counter carries on from phase 1 and the loop stops once it reaches
        ``orch_dim * self.gibbs_sampling_factor``.

        Args:
            sess: session used to evaluate the graph nodes.
            feed_dict: feeds used to compute ``self.context_embedding_out``.
            orch_t: ground-truth frames, unpacked as ``(batch_size, orch_dim)``.

        Returns:
            ``(orch_pred, loss_batch)``: the sampled frames for every ordering,
            stacked along axis 0 (``batch_size * num_ordering`` rows), and the
            loss from the last ``sess.run`` call.
        """
        debug = self.DEBUG
        n_orderings = self.num_ordering

        # The context embedding does not depend on the ordering: compute it once.
        context_embedding = sess.run(self.context_embedding_out, feed_dict)

        batch_size, orch_dim = orch_t.shape
        # One copy of the batch per ordering; predictions and mask start at zero.
        orch_t_tiled = np.concatenate([orch_t] * n_orderings, axis=0)
        orch_pred = np.zeros_like(orch_t_tiled)
        mask = np.zeros_like(orch_t_tiled)
        context_tiled = np.concatenate([context_embedding] * n_orderings,
                                       axis=0)

        # Static feeds; orch_pred and mask are refreshed before every run.
        feeds = {
            self.orch_t_ph: orch_t_tiled,
            self.context_embedding_in: context_tiled,
        }

        # One random pitch permutation per ordering, shared by the whole batch.
        orderings = []
        for _ in range(n_orderings):
            permutation = list(range(orch_dim))
            random.shuffle(permutation)
            orderings.append(permutation)

        track_accuracy = debug["save_accuracy_along_sampling"]
        accuracy_trace = []

        def dumping_preds():
            # Prediction dumps are only produced for the last batch.
            return debug["plot_nade_ordering_preds"] and (
                debug["batch_counter"] == (debug["num_batch"] - 1))

        def run_step():
            # Feed the current sample/mask and evaluate loss and predictions.
            feeds[self.orch_pred] = orch_pred
            feeds[self.mask_input] = mask
            return sess.run([self.loss_val, self.preds_gen], feeds)

        def record_debug(step, preds):
            # DEBUG: accuracy along the sampling process and prediction dumps.
            if track_accuracy:
                accuracy_trace.append(
                    np.mean(accuracy_measure(orch_t_tiled, preds)))
            if dumping_preds():
                prefix = debug["plot_nade_ordering_preds"] + '/' + str(step)
                for k in range(n_orderings):
                    np.save(prefix + '_' + str(k) + '.npy',
                            preds[batch_size * k:batch_size * (k + 1), :])
                np.save(prefix + '_mean.npy',
                        self.mean_parallel_prediction(batch_size, preds))

        # Phase 1: each ordering reveals one pitch per step.
        for step in range(orch_dim):
            loss_batch, preds_batch = run_step()
            record_debug(step, preds_batch)
            for k, permutation in enumerate(orderings):
                rows = slice(batch_size * k, batch_size * (k + 1))
                pitch = permutation[step]
                mask[rows, pitch] = 1
                # Sample (not mean-field): the NADE/Gibbs chain needs binary values.
                orch_pred[rows, pitch] = np.random.binomial(
                    1, preds_batch[rows, pitch])

        # Phase 2: Gibbs sampling, continuing the step counter from phase 1.
        while step < orch_dim * self.gibbs_sampling_factor:
            pitch = random.randint(0, orch_dim - 1)
            step += 1
            # Hide one pitch for every row, predict it, then restore the mask.
            mask[:, pitch] = 0
            loss_batch, preds_batch = run_step()
            record_debug(step, preds_batch)
            mask[:, pitch] = 1
            orch_pred[:, pitch] = np.random.binomial(1, preds_batch[:, pitch])

        if dumping_preds():
            np.save(debug["plot_nade_ordering_preds"] + '/truth.npy', orch_t)

        # Save accuracy_along_sampling
        if track_accuracy:
            save_file_path = debug["save_accuracy_along_sampling"] + '/' + str(
                debug["batch_counter"]) + '.txt'
            with open(save_file_path, 'w') as out_file:
                out_file.writelines(
                    "{:.4f}\n".format(100 * value) for value in accuracy_trace)

        return orch_pred, loss_batch
예제 #3
0
# Per-example measures gathered by ``validate``, in the order in which they
# appear in the returned result dictionaries.
_VALID_MEASURE_KEYS = ('accuracy', 'precision', 'recall', 'loss',
                       'true_accuracy', 'f_score', 'Xent')


def _accumulate_valid_measures(store, loss_batch, orch_t, preds_batch):
    """Extend every list of *store* with the measures of one batch.

    ``store`` maps each key of ``_VALID_MEASURE_KEYS`` to a list; the values
    returned by the project's measure helpers are appended with ``extend``.
    """
    Xent_batch = binary_cross_entropy(orch_t, preds_batch)
    accuracy_batch = accuracy_measure(orch_t, preds_batch)
    precision_batch = precision_measure(orch_t, preds_batch)
    recall_batch = recall_measure(orch_t, preds_batch)
    true_accuracy_batch = true_accuracy_measure(orch_t, preds_batch)
    f_score_batch = f_measure(orch_t, preds_batch)

    store['loss'].extend(loss_batch)
    store['accuracy'].extend(accuracy_batch)
    store['precision'].extend(precision_batch)
    store['recall'].extend(recall_batch)
    store['true_accuracy'].extend(true_accuracy_batch)
    store['f_score'].extend(f_score_batch)
    store['Xent'].extend(Xent_batch)


def _build_long_range_inputs(batch_index, piano_input, orch_input,
                             duration_piano, temporal_order, long_range,
                             use_duration):
    """Cut the windows used by the long-range (gap filling) validation task.

    For every index ``i`` of ``batch_index`` the window starts at row
    ``i - (temporal_order - 1)`` and spans
    ``seq_len = 2 * (temporal_order - 1) + long_range`` rows.

    Returns ``(piano_extracted, orch_extracted, duration_extracted, orch_gen)``.
    ``orch_gen`` holds the known orchestral context (the first and the last
    ``temporal_order - 1`` frames of each window) and zeros in the
    ``long_range`` frames to be generated. ``duration_extracted`` is None when
    ``use_duration`` is falsy.
    """
    context_len = temporal_order - 1
    seq_len = 2 * context_len + long_range
    n_batch = len(batch_index)

    piano_extracted = np.zeros((n_batch, seq_len, piano_input.shape[1]))
    orch_extracted = np.zeros((n_batch, seq_len, orch_input.shape[1]))
    if use_duration:
        duration_extracted = np.zeros((n_batch, seq_len))
    else:
        duration_extracted = None

    for ind_b, this_batch_ind in enumerate(batch_index):
        start_ind = this_batch_ind - context_len
        end_ind = start_ind + seq_len
        piano_extracted[ind_b] = piano_input[start_ind:end_ind, :]
        orch_extracted[ind_b] = orch_input[start_ind:end_ind, :]
        if use_duration:
            duration_extracted[ind_b] = duration_piano[start_ind:end_ind]

    # The past orchestration is known at the beginning and the future one at
    # the end. The future start is computed explicitly: writing it as
    # ``-context_len:`` would turn into ``0:`` (the whole window, gap
    # included) when temporal_order == 1.
    future_start = seq_len - context_len
    orch_gen = np.zeros_like(orch_extracted)
    orch_gen[:, :context_len, :] = orch_extracted[:, :context_len, :]
    orch_gen[:, future_start:, :] = orch_extracted[:, future_start:, :]
    # Check that the ground truth did not leak into the gap.
    assert orch_gen[:, context_len:context_len + long_range, :].sum(
    ) == 0, "The gap to fill in orch_gen contains values !"
    return piano_extracted, orch_extracted, duration_extracted, orch_gen


def validate(trainer, sess, init_matrices_validation, valid_splits_batches,
             normalizer, parameters, logger, DEBUG):
    """Run short-term and long-range validation over every validation file.

    Files are processed in order. While file ``k`` is evaluated, file
    ``(k + 1) % N`` is loaded in a background thread with ``async_load_mat``.
    The last iteration therefore prefetches file 0 again, and that result is
    discarded. ``init_matrices_validation`` is used as the matrices of file 0.
    It is unpacked as ``(piano_input, orch_input, duration_piano, mask_orch)``.

    * Short-term: ``trainer.valid_step`` on each batch of ``'batches'``. When
      ``DEBUG["plot_nade_ordering_preds"]`` is set, the last batch of each file
      gets that folder (deleted and re-created) as ``PLOTING_FOLDER``.
    * Long-range: for each batch of ``'batches_lr'``, the model is given
      ``orch[0:temporal_order-1]`` and the last ``temporal_order-1`` frames of
      a window and must fill the ``parameters["long_range"]`` frames between
      them. It does so step by step, feeding back binomial samples of its own
      predictions.

    Parameters
    ----------
    trainer
        Provides ``temporal_order``, ``valid_step`` and
        ``valid_long_range_step``.
    sess
        Session forwarded to the trainer steps.
    init_matrices_validation
        Already-loaded matrices of the first validation file.
    valid_splits_batches : list of dict
        One dict per file with keys ``'batches'``, ``'batches_lr'`` and
        ``'chunks_folders'``.
    normalizer, parameters
        Forwarded to ``async_load_mat``. ``parameters`` must also provide
        ``"long_range"`` and ``"duration_piano"`` when long-range batches exist.
    logger
        Unused.
    DEBUG : dict
        Reads ``"save_measures"`` and ``"plot_nade_ordering_preds"``.

    Returns
    -------
    valid_results, valid_long_range_results : dict of np.ndarray
        Keys ``accuracy``, ``precision``, ``recall``, ``loss``,
        ``true_accuracy``, ``f_score`` and ``Xent``. Each holds the
        concatenation of the per-batch values.
    preds, truth : list or None
        Short-term predictions and ground truth when
        ``DEBUG["save_measures"]`` is truthy, otherwise None.
    """
    temporal_order = trainer.temporal_order

    short_term = {key: [] for key in _VALID_MEASURE_KEYS}
    long_term = {key: [] for key in _VALID_MEASURE_KEYS}

    if DEBUG["save_measures"]:
        preds = []
        truth = []
    else:
        preds = None
        truth = None

    N_matrix_files = len(valid_splits_batches)
    matrices_from_thread = init_matrices_validation
    pool = ThreadPool(processes=1)
    # try/finally so the loader thread is always released, even when a
    # validation step raises.
    try:
        for file_ind_CURRENT in range(N_matrix_files):
            #######################################
            # Indices of the current file, prefetch of the next one
            #######################################
            current_split = valid_splits_batches[file_ind_CURRENT]
            valid_index = current_split['batches']
            valid_long_range_index = current_split['batches_lr']
            file_ind_NEXT = (file_ind_CURRENT + 1) % N_matrix_files
            next_chunks = valid_splits_batches[file_ind_NEXT]['chunks_folders']
            async_valid = pool.apply_async(
                async_load_mat, (normalizer, next_chunks, parameters))

            piano_input, orch_input, duration_piano, mask_orch = matrices_from_thread

            #######################################
            # Short-term validation
            #######################################
            last_batch = len(valid_index) - 1
            for batch_counter, batch_index in enumerate(valid_index):
                plot_folder = None
                if DEBUG["plot_nade_ordering_preds"] and (batch_counter
                                                          == last_batch):
                    plot_folder = DEBUG["plot_nade_ordering_preds"]
                    if os.path.isdir(plot_folder):
                        shutil.rmtree(plot_folder)
                    os.makedirs(plot_folder)
                loss_batch, preds_batch, orch_t = trainer.valid_step(
                    sess,
                    batch_index,
                    piano_input,
                    orch_input,
                    duration_piano,
                    mask_orch,
                    PLOTING_FOLDER=plot_folder)

                _accumulate_valid_measures(short_term, loss_batch, orch_t,
                                           preds_batch)
                if DEBUG["save_measures"]:
                    preds.extend(preds_batch)
                    truth.extend(orch_t)

            #######################################
            # Long-range validation (fill a gap of parameters["long_range"]
            # frames between a known past and a known future)
            #######################################
            for batch_index in valid_long_range_index:
                (piano_extracted, orch_extracted, duration_piano_extracted,
                 orch_gen) = _build_long_range_inputs(
                     batch_index, piano_input, orch_input, duration_piano,
                     temporal_order, parameters["long_range"],
                     parameters["duration_piano"])

                for t in range(temporal_order - 1,
                               temporal_order - 1 + parameters["long_range"]):
                    loss_batch, preds_batch, orch_t = trainer.valid_long_range_step(
                        sess, t, piano_extracted, orch_extracted, orch_gen,
                        duration_piano_extracted)
                    # Later steps are conditioned on this sampled frame.
                    orch_gen[:, t, :] = np.random.binomial(1, preds_batch)
                    _accumulate_valid_measures(long_term, loss_batch, orch_t,
                                               preds_batch)

            # Drop the reference to the current file before taking the
            # prefetched one.
            del matrices_from_thread
            matrices_from_thread = async_valid.get()
    finally:
        pool.close()
        pool.join()

    valid_results = {
        key: np.asarray(values) for key, values in short_term.items()
    }
    valid_long_range_results = {
        key: np.asarray(values) for key, values in long_term.items()
    }
    return valid_results, valid_long_range_results, preds, truth
예제 #4
0
def accuracy_and_binary_Xent(context, batches, plot_folder, N_example):
    """Compare which examples accuracy and binary cross-entropy rate as good or bad.

    Predictions are computed over every mini-batch of ``batches``. The
    per-example accuracy (``100 * accuracy_measure``) and the Keras binary
    cross-entropy (neg-ll) are then normalised (mean/std) and ranked. For each
    of the first ``N_example`` ranks, the good, average and bad examples under
    both measures are plotted with ``visualize_mat_proba``. The means and
    standard deviations are written to ``statistics.txt``.

    Parameters
    ----------
    context : dict
        Must provide ``'sess'`` (TensorFlow session), ``'temporal_order'``,
        ``'piano'``, ``'orch'``, ``'duration_piano'``, ``'mask_orch'`` and
        ``'inputs_ph'``. ``'inputs_ph'`` holds the placeholders piano_t,
        piano_past, piano_future, orch_past and orch_future. The dict must
        also provide ``'orch_t_ph'``, ``'preds'`` (prediction tensor) and
        ``'keras_learning_phase'``. The latter is fed 0 here.
    batches : iterable of index lists
        One list of indices per mini-batch.
    plot_folder : str
        Output directory; deleted and re-created.
    N_example : int
        Number of examples plotted per category.

    Returns
    -------
    None
    """

    # Clean plot directory
    if os.path.isdir(plot_folder):
        shutil.rmtree(plot_folder)
    os.makedirs(plot_folder)

    # Unpack context variables
    sess = context['sess']
    temporal_order = context['temporal_order']
    piano = context['piano']
    orch = context['orch']
    duration_piano = context['duration_piano']
    mask_orch = context["mask_orch"]
    inputs_ph = context['inputs_ph']
    orch_t_ph = context['orch_t_ph']
    preds = context['preds']
    keras_learning_phase = context['keras_learning_phase']

    piano_t_ph, piano_past_ph, piano_future_ph, orch_past_ph, orch_future_ph = inputs_ph
    all_preds = []
    all_truth = []
    # Get all the predictions for the whole set
    for batch_index in batches:
        piano_t, piano_past, piano_future, orch_past, orch_future, orch_t = build_batch(
            batch_index, piano, orch, duration_piano, mask_orch,
            len(batch_index), temporal_order)
        feed_dict = {
            piano_t_ph: piano_t,
            piano_past_ph: piano_past,
            piano_future_ph: piano_future,
            orch_past_ph: orch_past,
            orch_future_ph: orch_future,
            orch_t_ph: orch_t,
            keras_learning_phase: 0
        }
        preds_batch = sess.run(preds, feed_dict)
        all_preds.extend(preds_batch)
        all_truth.extend(orch_t)

    preds_mat = np.asarray(all_preds)
    truth_mat = np.asarray(all_truth)

    # Compute neg-ll (use keras function for coherency with validation step)
    truth_mat_ph = tf.placeholder(tf.float32,
                                  shape=(truth_mat.shape),
                                  name="truth_mat")
    preds_mat_ph = tf.placeholder(tf.float32,
                                  shape=(preds_mat.shape),
                                  name="preds_mat")
    neg_ll_node = keras.losses.binary_crossentropy(truth_mat_ph, preds_mat_ph)
    neg_ll = sess.run(neg_ll_node, {
        preds_mat_ph: preds_mat,
        truth_mat_ph: truth_mat
    })

    acc = 100 * accuracy_measure(truth_mat, preds_mat)

    def normalize(matrix):
        """Return ``(centred/scaled matrix, mean, std)``."""
        mean = np.mean(matrix)
        std = np.std(matrix)
        # A constant measure would divide by zero. Every example is then
        # ranked equally, so centring alone is enough; the true std is still
        # reported.
        scale = std if std > 0 else 1.0
        norm_mat = (matrix - mean) / scale
        return norm_mat, mean, std

    neg_ll_norm, neg_ll_mean, neg_ll_std = normalize(neg_ll)
    acc_norm, acc_mean, acc_std = normalize(acc)

    # Rank perf: lowest neg-ll first, highest accuracy first
    arg_neg_ll = np.argsort(neg_ll_norm)
    arg_acc = np.argsort(-acc_norm)
    num_index = len(arg_neg_ll)

    def plot(ind, this_folder):
        """Plot truth and prediction of example ``ind`` into ``this_folder``."""
        temp_mat = np.stack((truth_mat[ind], preds_mat[ind]))
        visualize_mat_proba(
            temp_mat, this_folder,
            "acc_" + str(acc[ind]) + "_nll_" + str(neg_ll[ind]))

    # Walk the rankings to pick good, average and bad examples
    for index in range(N_example):
        good_index = index
        bad_index = -index - 1
        # Alternate around the middle rank. Integer division is required:
        # a float cannot index a numpy array.
        average_index = (num_index // 2) + (1 - 2 * (index % 2)) * index
        # Bad Xent
        bad_xent_folder = os.path.join(plot_folder, "bad_Xent_" + str(index))
        plot(arg_neg_ll[bad_index], bad_xent_folder)
        # Average Xent
        plot(arg_neg_ll[average_index],
             os.path.join(plot_folder, "average_Xent_" + str(index)))
        # Good Xent
        plot(arg_neg_ll[good_index],
             os.path.join(plot_folder, "good_Xent_" + str(index)))
        # Bad acc
        plot(arg_acc[bad_index],
             os.path.join(plot_folder, "bad_acc_" + str(index)))
        # Average acc
        plot(arg_acc[average_index],
             os.path.join(plot_folder, "average_acc_" + str(index)))
        # Good acc
        plot(arg_acc[good_index],
             os.path.join(plot_folder, "good_acc_" + str(index)))

    # Write the statistics. Text mode: the lines are str, and writing str to
    # a binary-mode file raises TypeError on Python 3.
    with open(os.path.join(plot_folder, "statistics.txt"), "w") as f:
        f.write("Accuracy mean : " + str(acc_mean) + "\n")
        f.write("Accuracy std : " + str(acc_std) + "\n")
        f.write("Xent mean : " + str(neg_ll_mean) + "\n")
        f.write("Xent std : " + str(neg_ll_std) + "\n")
    return
def compare_Xent_acc_corresponding_preds(context, batches, save_folder):
    """Save per-example accuracy and Xent scores with the matching truth and predictions.

    The prediction tensor ``context['preds']`` is evaluated on every
    mini-batch of ``batches``, with the Keras learning phase fed 0. Results go
    to ``save_folder``, which is deleted and re-created first:

    - ``accuracy.npy``: ``100 * accuracy_measure(truth, preds)``
    - ``Xent.npy``: ``100 *`` Keras binary cross-entropy of truth vs preds
    - ``preds.npy`` / ``truth.npy``: the stacked predictions and ground truth

    Parameters
    ----------
    context : dict
        Needs ``'sess'``, ``'temporal_order'``, ``'piano'``, ``'orch'``,
        ``'duration_piano'``, ``'mask_orch'`` and ``'inputs_ph'``.
        ``'inputs_ph'`` holds five placeholders: piano_t, piano_past,
        piano_future, orch_past and orch_future. The dict also needs
        ``'orch_t_ph'``, ``'preds'`` and ``'keras_learning_phase'``.
    batches : iterable of index lists
        One entry per mini-batch.
    save_folder : str
        Output directory.
    """
    # Always start from an empty output directory.
    if os.path.isdir(save_folder):
        shutil.rmtree(save_folder)
    os.makedirs(save_folder)

    session = context['sess']
    temporal_order = context['temporal_order']
    piano = context['piano']
    orch = context['orch']
    duration_piano = context["duration_piano"]
    mask_orch = context["mask_orch"]
    input_nodes = context['inputs_ph']
    orch_t_node = context['orch_t_ph']
    prediction_node = context['preds']
    learning_phase = context['keras_learning_phase']

    (piano_t_node, piano_past_node, piano_future_node,
     orch_past_node, orch_future_node) = input_nodes

    predictions, targets = [], []
    for indices in batches:
        batch = build_batch(indices, piano, orch, duration_piano, mask_orch,
                            len(indices), temporal_order)
        piano_t, piano_past, piano_future, orch_past, orch_future, orch_t = batch
        feeds = {
            piano_t_node: piano_t,
            piano_past_node: piano_past,
            piano_future_node: piano_future,
            orch_past_node: orch_past,
            orch_future_node: orch_future,
            orch_t_node: orch_t,
            learning_phase: 0,
        }
        predictions.extend(session.run(prediction_node, feeds))
        targets.extend(orch_t)

    preds_mat = np.asarray(predictions)
    truth_mat = np.asarray(targets)

    # Cross-entropy through Keras so the numbers match the validation loss.
    truth_node = tf.placeholder(tf.float32, shape=truth_mat.shape,
                                name="truth_mat")
    preds_mat_node = tf.placeholder(tf.float32, shape=preds_mat.shape,
                                    name="preds_mat")
    xent_node = keras.losses.binary_crossentropy(truth_node, preds_mat_node)
    neg_ll = session.run(xent_node, {
        preds_mat_node: preds_mat,
        truth_node: truth_mat,
    })

    acc = 100 * accuracy_measure(truth_mat, preds_mat)

    outputs = (
        ('accuracy.npy', acc),
        ('Xent.npy', 100 * neg_ll),
        ('preds.npy', preds_mat),
        ('truth.npy', truth_mat),
    )
    for filename, array in outputs:
        np.save(os.path.join(save_folder, filename), array)
	def orderless_NADE_generation(self, sess, feed_dict, orch_t, instrumentation=None):
		"""Sample orchestral frames with an orderless NADE, one unit at a time.

		The context embedding is computed once from ``feed_dict`` and stacked
		``self.num_ordering`` times along axis 0 together with ``orch_t``, so
		that several random orderings are sampled in parallel.

		Only some orchestral units are sampled. Their pitch (``self.pitch_orch``)
		must match a pitch played in ``feed_dict[self.piano_t_ph]``
		(``self.pitch_piano``). When ``instrumentation`` is given, they must
		also be at an index where ``instrumentation > 0``. Every other unit is
		left at 0 and marked as known in the mask from the start.

		NOTE(review): row ``k`` of the stacked matrices is updated as ordering
		``k``. This looks like it assumes ``orch_t`` holds a single example
		(batch size 1) -- TODO confirm with callers.

		Parameters
		----------
		sess
			TensorFlow session used to run the graph.
		feed_dict : dict
			Feed dict for the context network; must contain ``self.piano_t_ph``.
		orch_t : np.ndarray
			Ground-truth orchestral frame(s), 2-D ``(rows, orch_dim)``.
		instrumentation : np.ndarray or None
			Binary mask over orchestral units (presumably 1-D of length
			``orch_dim`` -- confirm); only indices > 0 may be sampled.

		Returns
		-------
		orch_pred : np.ndarray
			Sampled binary frames, ``self.num_ordering`` stacked copies.
		loss_batch
			``self.loss_val`` from the last graph evaluation.
		"""
		# The context embedding is the same for every ordering: compute it once.
		context_embedding = sess.run(self.context_embedding_out, feed_dict)

		# Run the orderings in parallel by stacking copies along the batch axis.
		_, orch_dim = orch_t.shape
		orch_t_reshape = np.concatenate([orch_t] * self.num_ordering, axis=0)
		# Prediction and mask both start at zero.
		orch_pred = np.zeros_like(orch_t_reshape)
		mask = np.zeros_like(orch_t_reshape)
		context_embedding_reshape = np.concatenate([context_embedding] * self.num_ordering, axis=0)

		# orch_pred and mask are added to this feed dict at each step.
		feed_dict_known_context = {
			self.orch_t_ph: orch_t_reshape,
			self.context_embedding_in: context_embedding_reshape,
		}

		# Candidate units: orchestral units whose pitch is played by the piano
		piano_t = feed_dict[self.piano_t_ph]
		ind_played_piano = np.where(piano_t > 0)[1]
		pitches_played_piano = set(self.pitch_piano[ind_played_piano])
		indices_to_sample = []
		for pitch_played_piano in pitches_played_piano:
			temp = np.where(self.pitch_orch == pitch_played_piano)[0]
			indices_to_sample.extend(temp.tolist())

		# ...optionally restricted by the instrumentation mask
		indices_to_sample = set(indices_to_sample)
		if instrumentation is not None:
			indices_to_sample_instru = set(np.where(instrumentation > 0)[0].tolist())
			indices_to_sample = indices_to_sample.intersection(indices_to_sample_instru)

		indices_NOT_sample = list(set(range(orch_dim)) - indices_to_sample)
		indices_to_sample = list(indices_to_sample)

		# One random ordering per parallel copy. Non-sampled units come first,
		# so the sampling loop below starts after them.
		orderings = []
		for _ in range(self.num_ordering):
			random.shuffle(indices_to_sample)
			orderings.append(indices_NOT_sample + indices_to_sample)

		# Units that are never sampled are known (value 0) from the start.
		mask[:, indices_NOT_sample] = 1

		save_accuracy = self.DEBUG["save_accuracy_along_sampling"]
		plot_preds = self.DEBUG["plot_nade_ordering_preds"] and (self.DEBUG["batch_counter"] == (self.DEBUG["num_batch"] - 1))
		if save_accuracy:
			accuracy_along_sampling = []

		loss_batch = None
		for d in range(len(indices_NOT_sample), orch_dim):
			# Generate step
			feed_dict_known_context[self.orch_pred] = orch_pred
			feed_dict_known_context[self.mask_input] = mask
			loss_batch, preds_batch = sess.run([self.loss_val, self.preds_gen], feed_dict_known_context)

			# Reveal unit d of every ordering and draw its value.
			# NOTE(review): the original author questioned sampling here
			# instead of keeping the probability.
			for ordering_ind, ordering in enumerate(orderings):
				unit = ordering[d]
				mask[ordering_ind, unit] = 1
				orch_pred[ordering_ind, unit] = np.random.binomial(1, preds_batch[ordering_ind, unit])

			# DEBUG: accuracy along the sampling process
			if save_accuracy:
				accuracy_batch = np.mean(accuracy_measure(orch_t_reshape, orch_pred))
				accuracy_along_sampling.append(accuracy_batch)
			# DEBUG: dump the predictions of the last batch
			if plot_preds:
				for ordering_ind in range(self.num_ordering):
					np.save(self.DEBUG["plot_nade_ordering_preds"] + '/' + str(d) + '_' + str(ordering_ind) + '.npy', orch_pred[ordering_ind, :])
				mean_pred_batch = self.mean_parallel_prediction(orch_pred)
				np.save(self.DEBUG["plot_nade_ordering_preds"] + '/' + str(d) + '_mean.npy', mean_pred_batch)

		if loss_batch is None:
			# Nothing to sample (no piano note played, or the instrumentation
			# excludes every candidate). The loop never ran, so evaluate the
			# loss once on the fully known frame. This returns a value instead
			# of raising UnboundLocalError.
			feed_dict_known_context[self.orch_pred] = orch_pred
			feed_dict_known_context[self.mask_input] = mask
			loss_batch = sess.run(self.loss_val, feed_dict_known_context)

		if plot_preds:
			np.save(self.DEBUG["plot_nade_ordering_preds"] + '/truth.npy', orch_t)

		# Save accuracy_along_sampling
		if save_accuracy:
			save_file_path = self.DEBUG["save_accuracy_along_sampling"] + '/' + str(self.DEBUG["batch_counter"]) + '.txt'
			with open(save_file_path, 'w') as thefile:
				for item in accuracy_along_sampling:
					thefile.write("{:.4f}\n".format(100 * item))

		return orch_pred, loss_batch