def preprocess_unlabeled_dataset(source_path,
                                 nbMFCCs=39,
                                 verbose=False,
                                 logger=None):  # TODO
    """Compute the MFCC features of every WAV file found under source_path.

    Args:
        source_path: passed to transform.loadWavs to locate the WAV files.
        nbMFCCs: number of MFCC coefficients, forwarded to create_mfcc.
        verbose: when True, log type/shape details of every feature array.
        logger: logger to write to; a module-level logger is used when omitted.

    Returns:
        A list holding one create_mfcc feature array per WAV file, in the
        order returned by transform.loadWavs.

    Raises:
        AssertionError: if no WAV files are found.
    """
    if logger is None:
        # The default of None used to crash on the first logger.debug call.
        import logging
        logger = logging.getLogger(__name__)

    wav_files = transform.loadWavs(source_path)
    logger.debug("Found %d WAV files", len(wav_files))
    assert len(wav_files) != 0, "no WAV files found in %s" % (source_path,)

    X = []
    for wav_file in tqdm(wav_files):
        wav_name = str(wav_file)
        # create_mfcc also returns the frame count, which is not needed here.
        X_val, _ = create_mfcc('DUMMY', wav_name, nbMFCCs)
        X.append(X_val)

        if verbose:
            logger.debug('type(X_val): \t\t %s', type(X_val))
            logger.debug('X_val.shape: \t\t %s', X_val.shape)
            logger.debug('type(X_val[0][0]):\t %s', type(X_val[0][0]))
    return X
    def evaluateModel(self,
                      BIDIRECTIONAL,
                      N_HIDDEN_LIST,
                      batch_size,
                      dataName,
                      wavDir,
                      data_store_dir,
                      meanStd_path,
                      model_load,
                      nbMFCCs,
                      store_dir,
                      force_overwrite=False):
        """Run the network over the WAV files under ``wavDir`` and log results.

        Preprocessed data is loaded from a cached pickle when one exists.
        Otherwise the WAVs are preprocessed, normalized with the mean/std
        stored at ``meanStd_path``, truncated to a multiple of ``batch_size``,
        padded, and cached. Accuracy is only computed when targets are
        available. Everything is logged to a per-dataset log file under
        ``store_dir`` in addition to the handlers already on
        ``logger_evaluate``.

        Args:
            BIDIRECTIONAL: not used in this method.
            N_HIDDEN_LIST: not used in this method.
            batch_size: batch size used for truncation and for the network.
            dataName: dataset name; '/' is replaced by '_' in file names.
            wavDir: directory searched for WAV (and label) files.
            data_store_dir: location of the cached pickles. It is concatenated
                with the dataset name WITHOUT a separator.
                NOTE(review): presumably callers pass a trailing separator --
                confirm.
            meanStd_path: pickle file holding ``[mean, std]`` for
                normalization.
            model_load: only logged here.
            nbMFCCs: number of MFCC coefficients (also part of the cache name).
            store_dir: directory of the log file.
            force_overwrite: when True, overwrite an existing log file without
                asking.

        Returns:
            0 when a log file already exists and the user declines to
            re-evaluate; None otherwise.
        """
        logger_evaluate.info("\n\n\n")

        ####### THE DATA you want to evaluate ##########
        safe_name = dataName.replace('/', '_')
        data_store_path = data_store_dir + safe_name + "_nbMFCC" + str(nbMFCCs)
        if not os.path.exists(data_store_dir):
            os.makedirs(data_store_dir)
        predictions_path = store_dir + os.sep + safe_name + "_predictions.pkl"

        # log file
        logFile = store_dir + os.sep + "Evaluation" + safe_name + '.log'
        if os.path.exists(logFile) and not force_overwrite:
            from general_tools import query_yes_no
            # The prompt's %s placeholder was previously never filled in.
            question = (
                "Log file already exists at %s\n Do you want to evaluate again and overwrite?"
                % logFile)
            if not query_yes_no(question, "y"):
                logger_evaluate.info(
                    "Log file already exists, not re-evaluating.... ")
                return 0

        def relative_wav_name(wav_file):
            # Keep the last three directory names plus the file name.
            parent = os.path.dirname(wav_file)
            grandparent = os.path.dirname(parent)
            great_grandparent = os.path.dirname(grandparent)
            return str(
                os.path.basename(great_grandparent) + os.sep +
                os.path.basename(grandparent) + os.sep +
                os.path.basename(parent) + os.sep +
                os.path.basename(wav_file))

        fh = logging.FileHandler(logFile, 'w')  # create new logFile
        fh.setLevel(logging.INFO)
        fh.setFormatter(formatter)
        logger_evaluate.addHandler(fh)
        try:
            logger_evaluate.info("\n  MODEL:    %s", model_load)
            logger_evaluate.info("\n  WAV_DIR:  %s", wavDir)
            logger_evaluate.info("\n  PREDICTS: %s", predictions_path)
            logger_evaluate.info("\n  LOG:      %s", logFile)
            logger_evaluate.info("\n")

            # Defaults for the unlabeled case, so later references never hit
            # an undefined name.
            # NOTE(review): how printEvaluation copes with None here is not
            # visible in this file -- confirm.
            targets = None
            valid_frames = None
            avg_acc = None

            # GATHERING DATA
            logger_evaluate.info("* Gathering Data ...")

            # File names are gathered once; they are used for logging and
            # debugging output.
            wav_files = transform.loadWavs(wavDir)
            wav_filenames = [relative_wav_name(wav_file)
                             for wav_file in wav_files]

            if os.path.exists(data_store_path + ".pkl"):
                [inputs, targets,
                 valid_frames] = unpickle(data_store_path + ".pkl")
                calculateAccuracy = True
                logger_evaluate.info(
                    "Successfully loaded preprocessed data, with targets")

            elif os.path.exists(data_store_path + "_noTargets.pkl"):
                # TODO: make it work for unlabeled datasets. see
                # RNN_tools_lstm.py, eg iterate_minibatch_noTargets.
                [inputs] = unpickle(data_store_path + "_noTargets.pkl")
                # we can't as we don't know the correct labels
                calculateAccuracy = False
                logger_evaluate.info(
                    "Successfully loaded preprocessed data, no targets")

            else:
                logger_evaluate.info("Data not found, preprocessing...")

                # From WAVS, generate X, y and valid_frames
                def preprocessLabeledWavs(wavDir, store_dir, name):
                    # fixWavs -> suppose this is done
                    X, y, valid_frames = preprocessWavs.preprocess_dataset(
                        source_path=wavDir,
                        nbMFCCs=nbMFCCs,
                        logger=logger_evaluate)

                    X = preprocessWavs.set_type(X, 'float32')
                    y = preprocessWavs.set_type(y, 'int32')
                    valid_frames = preprocessWavs.set_type(
                        valid_frames, 'int32')

                    return X, y, valid_frames

                def preprocessUnlabeledWavs(wavDir, store_dir, name):  #TODO
                    # fixWavs -> suppose this is done
                    X = preprocessWavs.preprocess_unlabeled_dataset(
                        source_path=wavDir,
                        nbMFCCs=nbMFCCs,
                        logger=logger_evaluate)

                    return preprocessWavs.set_type(X, 'float32')

                logger_evaluate.info(
                    "Found %s files to evaluate \n Example: %s",
                    len(wav_filenames), wav_filenames[0])
                label_files = transform.loadPhns(wavDir)

                # if source dir doesn't contain labels, we can't calculate
                # accuracy
                calculateAccuracy = len(wav_files) == len(label_files)
                if calculateAccuracy:
                    inputs, targets, valid_frames = preprocessLabeledWavs(
                        wavDir=wavDir, store_dir=store_dir, name=dataName)
                else:
                    inputs = preprocessUnlabeledWavs(wavDir=wavDir,
                                                     store_dir=store_dir,
                                                     name=dataName)

                # normalize inputs using dataset Mean and Std_dev; convert to
                # float32 for GPU evaluation.
                # NOTE: unpickling executes arbitrary code; meanStd_path must
                # come from a trusted source.
                with open(meanStd_path, 'rb') as cPickle_file:
                    [mean_val, std_val] = cPickle.load(cPickle_file)
                inputs = preprocessWavs.normalize(inputs, mean_val, std_val)

                # just to be sure
                inputs = preprocessWavs.set_type(inputs, 'float32')

                # Print some information
                logger_evaluate.debug("* Data information")
                logger_evaluate.debug('  inputs')
                logger_evaluate.debug('%s %s', type(inputs), len(inputs))
                logger_evaluate.debug('%s %s', type(inputs[0]),
                                      inputs[0].shape)
                logger_evaluate.debug('%s %s', type(inputs[0][0]),
                                      inputs[0][0].shape)
                logger_evaluate.debug('%s', type(inputs[0][0][0]))
                if calculateAccuracy:
                    # targets only exist for labeled data
                    logger_evaluate.debug('y train')
                    logger_evaluate.debug('  %s %s', type(targets),
                                          len(targets))
                    logger_evaluate.debug('  %s %s', type(targets[0]),
                                          targets[0].shape)
                    logger_evaluate.debug('  %s %s', type(targets[0][0]),
                                          targets[0][0].shape)

                # slice to have a number of inputs that is a multiple of
                # batch size; "or None" keeps everything when the remainder
                # is 0
                logger_evaluate.info(
                    "Not evaluating %s last files (batch size mismatch)",
                    len(inputs) % batch_size)
                inputs = inputs[:-(len(inputs) % batch_size) or None]
                if calculateAccuracy:
                    targets = targets[:-(len(targets) % batch_size) or None]
                    valid_frames = valid_frames[:-(len(valid_frames) %
                                                   batch_size) or None]

                # pad the inputs to process batches easily
                inputs = pad_sequences_X(inputs)
                if calculateAccuracy:
                    targets = pad_sequences_y(targets)

                # save the preprocessed data
                logger_evaluate.info("storing preprocessed data to: %s",
                                     data_store_path)
                if calculateAccuracy:
                    general_tools.saveToPkl(data_store_path + '.pkl',
                                            [inputs, targets, valid_frames])
                else:
                    general_tools.saveToPkl(
                        data_store_path + '_noTargets.pkl', [inputs])

            logger_evaluate.debug(" # inputs: %s, # wav files: %s",
                                  len(inputs), len(wav_files))

            # make a copy of the data because we might need it again for
            # calculating accuracy, and the iterator will remove elements
            # from the array
            inputs_bak = copy.deepcopy(inputs)
            if calculateAccuracy:
                targets_bak = copy.deepcopy(targets)
                valid_frames_bak = copy.deepcopy(valid_frames)

            logger_evaluate.info("* Evaluating: pass over Evaluation Set")

            if calculateAccuracy:
                # if .phn files are provided, we can check our predictions
                logger_evaluate.info(
                    "Getting predictions and calculating accuracy...")
                avg_error, avg_acc, predictions = self.RNN_network.run_epoch(
                    X=inputs,
                    y=targets,
                    valid_frames=valid_frames,
                    get_predictions=True,
                    batch_size=batch_size)

                logger_evaluate.info("All batches, avg Accuracy: %s", avg_acc)
                inputs = inputs_bak
                targets = targets_bak
                valid_frames = valid_frames_bak

                #uncomment if you want to save everything in one place (takes quite a lot of storage space)
                #general_tools.saveToPkl(predictions_path, [inputs, predictions, targets, valid_frames, avg_Acc])

            else:
                # TODO fix this
                # The accumulator was previously used without being created.
                predictions = []
                for batch_inputs, masks, seq_lengths in tqdm(
                        iterate_minibatches_noTargets(inputs,
                                                      batch_size=batch_size,
                                                      shuffle=False),
                        total=len(inputs)):
                    # usually batch size, but could be lower
                    nb_inputs = len(batch_inputs)
                    prediction = self.RNN_network.predictions_fn(
                        batch_inputs, masks)
                    prediction = np.reshape(prediction, (nb_inputs, -1))
                    predictions += list(prediction)

                inputs = inputs_bak
                #general_tools.saveToPkl(predictions_path, [inputs, predictions])

            # Print information about the predictions
            logger_evaluate.info("* Done")
            end_evaluation_time = time.time()
            eval_duration = end_evaluation_time - program_start_time
            logger_evaluate.info('Total time: {:.3f}'.format(eval_duration))
            # Print the results
            try:
                printEvaluation(wav_filenames,
                                inputs,
                                predictions,
                                targets,
                                valid_frames,
                                avg_acc,
                                range(len(inputs)),
                                logger=logger_evaluate,
                                only_final_accuracy=True)
            except Exception:
                # Record the traceback before dropping into the debugger, so
                # the failure is visible in the log file as well.
                logger_evaluate.exception("printEvaluation failed")
                pdb.set_trace()
            logger_evaluate.info(
                'Evaluation duration: {:.3f}'.format(eval_duration))
            logger_evaluate.info(
                'Printing duration: {:.3f}'.format(time.time() -
                                                   end_evaluation_time))
        finally:
            # close the log handler, even when evaluation failed, so later
            # evaluations do not keep writing to this file
            fh.close()
            logger_evaluate.removeHandler(fh)
def preprocess_dataset(source_path,
                       nbMFCCs=39,
                       logger=None,
                       debug=None,
                       verbose=False):
    """Build MFCC features, frame labels and valid-frame indices for a dataset.

    For every WAV/PHN pair found under source_path, the MFCC features are
    computed with create_mfcc and the .phn file is converted into one class
    number per frame plus one "valid" frame index per phoneme (the middle
    frame of that phoneme).

    Args:
        source_path: root dir of all the wav/phn files.
        nbMFCCs: number of MFCC coefficients, forwarded to create_mfcc.
        logger: logger to write to; a module-level logger is used when omitted.
        debug: when not None, stop after this many files have been processed.
        verbose: when True, log per-phoneme and per-file details.

    Returns:
        A tuple (X, y, valid_frames) of three lists with one entry per
        processed file: the feature array, the int32 label array and the
        int32 array of valid frame indices.

    Raises:
        AssertionError: if the number of WAV and PHN files differs, or if no
            WAV files are found.
    """
    if logger is None:
        # The default of None used to crash on the first logger.debug call.
        import logging
        logger = logging.getLogger(__name__)

    X = []
    y = []
    valid_frames = []

    logger.debug("nbMFCCs: %s", nbMFCCs)

    # source_path is the root dir of all the wav/phn files
    wav_files = transform.loadWavs(source_path)
    label_files = transform.loadPhns(source_path)

    logger.debug("Found %d WAV files", len(wav_files))
    logger.debug("Found %d PHN files", len(label_files))
    assert len(wav_files) == len(label_files)
    assert len(wav_files) != 0

    processed = 0
    for i in tqdm(range(len(wav_files))):
        phn_name = str(label_files[i])
        wav_name = str(wav_files[i])

        # specific for TIMIT: these files contain strong dialects; don't use
        # them.
        # NOTE(review): wav_name looks like a full path (it is handed to
        # create_mfcc as-is), so this prefix test probably never matches.
        # Testing the basename instead would change how many entries are
        # returned and could misalign callers that pair the result with the
        # full file list -- confirm the intent before changing it.
        if wav_name.startswith("SA"):
            continue

        # Get MFCC of the WAV
        # get 3 levels: 0th, 1st and 2nd derivative (=> 3*13 = 39 coefficients)
        X_val, total_frames = create_mfcc('DUMMY', wav_name, nbMFCCs)
        total_frames = int(total_frames)

        X.append(X_val)

        # Get phonemes and valid frame numbers out of .phn files
        total_duration = get_total_duration(phn_name)

        # some .PHN files don't start at 0. Set default phoneme to silence
        # (expected at the end of phoneme_set_list)
        y_vals = np.zeros(total_frames) + phoneme_classes[phoneme_set_list[-1]]
        valid_frames_vals = []

        # "with" guarantees the label file is closed, even when a line fails
        # to parse.
        with open(phn_name) as fr:
            for line in fr:
                [start_time, end_time, phoneme] = line.rstrip('\n').split()
                start_time = int(start_time)
                end_time = int(end_time)

                # convert the .phn time stamps to frame indices
                time_per_frame = total_duration / total_frames
                start_ind = int(np.round(start_time / time_per_frame))
                end_ind = int(np.round(end_time / time_per_frame))

                # the middle frame of the phoneme is its "valid" frame
                valid_ind = int((start_ind + end_ind) / 2)
                valid_frames_vals.append(valid_ind)

                phoneme_num = phoneme_classes[phoneme]
                # check that phoneme is found in dict
                if phoneme_num == -1:
                    logger.error("In file: %s, phoneme not found: %s",
                                 phn_name, phoneme)
                    pdb.set_trace()
                y_vals[start_ind:end_ind] = phoneme_num

                if verbose:
                    logger.debug('%s', (total_frames / float(total_duration)))
                    logger.debug(
                        'TIME  start: %s end: %s, phoneme: %s, class: %s',
                        start_time, end_time, phoneme, phoneme_num)
                    logger.debug(
                        'FRAME start: %s end: %s, phoneme: %s, class: %s',
                        start_ind, end_ind, phoneme, phoneme_num)

        # append the target array to our y
        y.append(y_vals.astype('int32'))

        # append the valid_frames array to our valid_frames
        valid_frames_vals = np.array(valid_frames_vals)
        valid_frames.append(valid_frames_vals.astype('int32'))

        if verbose:
            logger.debug('(%s) create_target_vector: %s', i, phn_name[:-4])
            logger.debug('type(X_val): \t\t %s', type(X_val))
            logger.debug('X_val.shape: \t\t %s', X_val.shape)
            logger.debug('type(X_val[0][0]):\t %s', type(X_val[0][0]))

            logger.debug('type(y_val): \t\t %s', type(y_vals))
            logger.debug('y_val.shape: \t\t %s', y_vals.shape)
            logger.debug('type(y_val[0]):\t %s', type(y_vals[0]))
            logger.debug('y_val: \t\t %s', (y_vals))

        processed += 1
        if debug is not None and processed >= debug:
            break

    return X, y, valid_frames