def test_file_wise_generator(PARAMS, file_name_sp, file_name_mu, target_dB):
    """Build the patches and one-hot labels for a single test item.

    Despite the ``test_`` prefix this is a data-preparation helper, not a
    unit test. The class is decided from which file name is empty:

    * ``file_name_mu == ''``  -> speech only, label 1
    * ``file_name_sp == ''``  -> music only, label 0
    * otherwise               -> speech+music mixture, label 2

    Args:
        PARAMS: configuration dict. Keys read here: 'Model', 'n_fft',
            'n_mels', 'featName', 'feature_opDir', 'frame_level_scaling',
            'fold', 'mean_fold<k>', 'stdev_fold<k>', 'W', 'W_shift',
            'skewness_vector' and 'classes'.
        file_name_sp: speech file name, or '' for a music-only item.
        file_name_mu: music file name, or '' for a speech-only item.
        target_dB: forwarded to ``preproc.get_featuregram`` only for the
            speech+music case; ``None`` is passed otherwise.

    Returns:
        Tuple ``(batchData, OHE_batchLabel)``: the patches of the item and
        the matching one-hot labels, one row per patch.
    """
    model_name = PARAMS['Model']
    n_fft = PARAMS['n_fft'][model_name]
    n_mels = PARAMS['n_mels'][model_name]
    featName = PARAMS['featName'][model_name]

    # Resolve everything that depends on the class in a single place.
    if file_name_mu == '':
        class_dir, sp_arg, mu_arg, db_arg, class_idx = 'speech', file_name_sp, '', None, 1
    elif file_name_sp == '':
        class_dir, sp_arg, mu_arg, db_arg, class_idx = 'music', '', file_name_mu, None, 0
    else:
        class_dir, sp_arg, mu_arg, db_arg, class_idx = 'speech_music', file_name_sp, file_name_mu, target_dB, 2

    fv = preproc.get_featuregram(
        PARAMS, class_dir, PARAMS['feature_opDir'], sp_arg, mu_arg, db_arg,
        n_fft, n_mels, featName, save_feat=True)

    if PARAMS['frame_level_scaling']:
        fold_id = str(PARAMS['fold'])
        fv = cscale_data(fv, PARAMS['mean_fold' + fold_id], PARAMS['stdev_fold' + fold_id])

    batchData = preproc.get_feature_patches(PARAMS, fv, PARAMS['W'], PARAMS['W_shift'], featName)

    # Optional reduction of the patches to a statistics vector; any value
    # other than 'Harm' or 'Perc' leaves the patches untouched.
    skew_mode = PARAMS['skewness_vector']
    if skew_mode == 'Harm':
        batchData = np.expand_dims(cget_data_statistics(batchData, axis=1), axis=2)
    elif skew_mode == 'Perc':
        batchData = np.expand_dims(cget_data_statistics(batchData, axis=0), axis=1)

    if 'Lemaire_et_al' in model_name:
        # TCN input shape=(batch_size, timesteps, ndim)
        batchData = np.transpose(batchData, axes=(0, 2, 1))

    n_patches = np.shape(batchData)[0]
    batchLabel = np.array([class_idx] * n_patches)
    OHE_batchLabel = to_categorical(batchLabel, num_classes=len(PARAMS['classes']))

    return batchData, OHE_batchLabel
def generator(PARAMS, folder, file_list, batchSize):
    """Endlessly yield class-balanced batches of feature patches.

    NOTE(review): a second function named ``generator`` with the same
    signature is defined later in this module. The later ``def`` rebinds the
    name, so this version is not reachable as ``generator`` once the module
    has been fully executed -- confirm whether it can be removed.

    Each batch holds ``batchSize`` music patches (label 0) followed by
    ``batchSize`` speech patches (label 1) and, when ``PARAMS['classes']``
    has three entries, ``batchSize`` speech+music patches (label 2).
    Patches left over from a file are kept and used in the next batch.
    Files that do not exist on disk are skipped silently.

    Args:
        PARAMS: configuration dict. Keys read here: 'feature_opDir',
            'classes', 'Model', 'n_fft', 'n_mels', 'featName',
            'frame_level_scaling', 'fold', 'mean_fold<k>', 'stdev_fold<k>',
            'W', 'W_shift' and 'data_augmentation_with_noise'.
        folder: root directory; files are looked up under
            ``folder + '/speech/'`` and ``folder + '/music/'``.
        file_list: dict with 'speech' and 'music' lists of file names and,
            for three classes, a 'speech+music' list of dicts with the keys
            'speech', 'music' and 'SMR'. The lists are shuffled in place.
        batchSize: number of patches taken per class for every batch.

    Yields:
        Tuple ``(batchData, OHE_batchLabel)`` with the stacked patches and
        their one-hot labels.
    """
    batch_count = 0
    # In-place shuffles: the caller's lists are reordered.
    np.random.shuffle(file_list['speech'])
    np.random.shuffle(file_list['music'])

    # Working copies that are consumed with pop() and refilled when empty.
    file_list_sp_temp = file_list['speech'].copy()
    file_list_mu_temp = file_list['music'].copy()

    # Placeholders; replaced by the first patches loaded for each class.
    batchData_sp = np.empty([], dtype=float)
    batchData_mu = np.empty([], dtype=float)

    # Number of patches currently buffered per class.
    balance_sp = 0
    balance_mu = 0

    if not os.path.exists(PARAMS['feature_opDir'] + '/speech/'):
        os.makedirs(PARAMS['feature_opDir'] + '/speech/')
    if not os.path.exists(PARAMS['feature_opDir'] + '/music/'):
        os.makedirs(PARAMS['feature_opDir'] + '/music/')

    # Third class (speech+music mixtures) is only set up when configured.
    if len(PARAMS['classes']) == 3:
        np.random.shuffle(file_list['speech+music'])
        file_list_spmu_temp = file_list['speech+music'].copy()
        batchData_spmu = np.empty([], dtype=float)
        balance_spmu = 0
        if not os.path.exists(PARAMS['feature_opDir'] + '/speech_music/'):
            os.makedirs(PARAMS['feature_opDir'] + '/speech_music/')

    # Model-specific feature settings, looked up once.
    n_fft = PARAMS['n_fft'][PARAMS['Model']]
    n_mels = PARAMS['n_mels'][PARAMS['Model']]
    featName = PARAMS['featName'][PARAMS['Model']]

    while 1:
        # Placeholder only; overwritten once both buffers are filled.
        batchData = np.empty([], dtype=float)

        # Load speech files until at least batchSize patches are buffered.
        while balance_sp < batchSize:
            if not file_list_sp_temp:
                # List exhausted: start over with the (already shuffled) list.
                file_list_sp_temp = file_list['speech'].copy()
            sp_fName = file_list_sp_temp.pop()
            sp_fName_path = folder + '/speech/' + sp_fName
            if not os.path.exists(sp_fName_path):
                # print(sp_fName_path, os.path.exists(sp_fName_path))
                continue
            fv_sp = preproc.get_featuregram(PARAMS, 'speech',
                                            PARAMS['feature_opDir'],
                                            sp_fName_path, '', None, n_fft,
                                            n_mels, featName)
            if PARAMS['frame_level_scaling']:
                fv_sp = cscale_data(fv_sp,
                                    PARAMS['mean_fold' + str(PARAMS['fold'])],
                                    PARAMS['stdev_fold' + str(PARAMS['fold'])])
            fv_sp_patches = preproc.get_feature_patches(
                PARAMS, fv_sp, PARAMS['W'], PARAMS['W_shift'], featName)
            # An empty buffer is replaced rather than appended to.
            if balance_sp == 0:
                batchData_sp = fv_sp_patches
            else:
                batchData_sp = np.append(batchData_sp, fv_sp_patches, axis=0)
            balance_sp += np.shape(fv_sp_patches)[0]
            # print('Speech: ', batchSize, balance_sp, np.shape(batchData_sp))

        # Same procedure for music files.
        while balance_mu < batchSize:
            if not file_list_mu_temp:
                file_list_mu_temp = file_list['music'].copy()
            mu_fName = file_list_mu_temp.pop()
            mu_fName_path = folder + '/music/' + mu_fName
            if not os.path.exists(mu_fName_path):
                # print(mu_fName_path, os.path.exists(mu_fName_path))
                continue
            fv_mu = preproc.get_featuregram(PARAMS, 'music',
                                            PARAMS['feature_opDir'], '',
                                            mu_fName_path, None, n_fft, n_mels,
                                            featName)
            if PARAMS['frame_level_scaling']:
                fv_mu = cscale_data(fv_mu,
                                    PARAMS['mean_fold' + str(PARAMS['fold'])],
                                    PARAMS['stdev_fold' + str(PARAMS['fold'])])
            fv_mu_patches = preproc.get_feature_patches(
                PARAMS, fv_mu, PARAMS['W'], PARAMS['W_shift'], featName)
            if balance_mu == 0:
                batchData_mu = fv_mu_patches
            else:
                batchData_mu = np.append(batchData_mu, fv_mu_patches, axis=0)
            balance_mu += np.shape(fv_mu_patches)[0]
            # print('Music: ', batchSize, balance_mu, np.shape(batchData_mu))

        # Batch layout: music rows first, then speech rows.
        batchData = batchData_mu[:batchSize, :]  # music label=0
        batchData = np.append(batchData, batchData_sp[:batchSize, :],
                              axis=0)  # speech label=1

        # Account for the rows just used ...
        balance_mu -= batchSize
        balance_sp -= batchSize

        # ... and keep the remainder for the next batch.
        batchData_mu = batchData_mu[batchSize:, :]
        batchData_sp = batchData_sp[batchSize:, :]

        batchLabel = [0] * batchSize  # music
        batchLabel.extend([1] * batchSize)  # speech

        if len(PARAMS['classes']) == 3:
            while balance_spmu < batchSize:
                if not file_list_spmu_temp:
                    file_list_spmu_temp = file_list['speech+music'].copy()
                # The remaining entries are re-shuffled before every pop.
                np.random.shuffle(file_list_spmu_temp)
                spmu_info = file_list_spmu_temp.pop()
                sp_fName = spmu_info['speech']
                sp_fName_path = folder + '/speech/' + sp_fName
                mu_fName = spmu_info['music']
                mu_fName_path = folder + '/music/' + mu_fName
                target_dB = spmu_info['SMR']
                # Both source files must exist for a mixture to be built.
                if (not os.path.exists(mu_fName_path)) or (
                        not os.path.exists(sp_fName_path)):
                    continue
                fv_spmu = preproc.get_featuregram(PARAMS, 'speech_music',
                                                  PARAMS['feature_opDir'],
                                                  sp_fName_path, mu_fName_path,
                                                  target_dB, n_fft, n_mels,
                                                  featName)
                if PARAMS['frame_level_scaling']:
                    fv_spmu = cscale_data(
                        fv_spmu, PARAMS['mean_fold' + str(PARAMS['fold'])],
                        PARAMS['stdev_fold' + str(PARAMS['fold'])])
                fv_spmu_patches = preproc.get_feature_patches(
                    PARAMS, fv_spmu, PARAMS['W'], PARAMS['W_shift'], featName)
                if balance_spmu == 0:
                    batchData_spmu = fv_spmu_patches
                else:
                    batchData_spmu = np.append(batchData_spmu,
                                               fv_spmu_patches,
                                               axis=0)
                balance_spmu += np.shape(fv_spmu_patches)[0]
                # print('SpeechMusic: ', batchSize, balance_spmu, np.shape(batchData_spmu))

            # speech_music label=2
            batchData = np.append(batchData,
                                  batchData_spmu[:batchSize, :],
                                  axis=0)
            balance_spmu -= batchSize
            batchData_spmu = batchData_spmu[batchSize:, :]
            batchLabel.extend([2] * batchSize)  # speech+music

        # Exact match on the model name (no substring test here).
        if PARAMS['Model'] == 'Lemaire_et_al':
            batchData = np.transpose(
                batchData,
                axes=(0, 2,
                      1))  # TCN input shape=(batch_size, timesteps, ndim)
        ''' Adding Normal (Gaussian) noise for data augmentation '''
        if PARAMS['data_augmentation_with_noise']:
            # One noise level is drawn per batch and applied to every row.
            scale = np.random.choice([5e-3, 1e-3, 5e-4, 1e-4])
            noise = np.random.normal(loc=0.0,
                                     scale=scale,
                                     size=np.shape(batchData))
            batchData = np.add(batchData, noise)

        OHE_batchLabel = to_categorical(batchLabel,
                                        num_classes=len(PARAMS['classes']))

        batch_count += 1
        # print('Batch ', batch_count, ' shape=', np.shape(batchData), np.shape(OHE_batchLabel))
        yield batchData, OHE_batchLabel
def _generator_patches(PARAMS, class_name, sp_path, mu_path, target_dB, n_fft, n_mels, featName):
    """Return the training patches of one file (or one speech+music pair).

    Runs the per-file pipeline shared by all classes: feature extraction,
    optional frame-level scaling, patch extraction and the optional
    'Harm' / 'Perc' statistics reduction.

    Args:
        PARAMS: configuration dict (see ``generator``).
        class_name: 'speech', 'music' or 'speech_music'; forwarded to
            ``preproc.get_featuregram``.
        sp_path: path of the speech file, or '' when there is none.
        mu_path: path of the music file, or '' when there is none.
        target_dB: value forwarded for mixtures, ``None`` otherwise.
        n_fft, n_mels, featName: model-specific feature settings.

    Returns:
        The patches; the first axis indexes the patches.
    """
    fv = preproc.get_featuregram(PARAMS, class_name, PARAMS['feature_opDir'], sp_path, mu_path, target_dB, n_fft, n_mels, featName)
    if PARAMS['frame_level_scaling']:
        fold = str(PARAMS['fold'])
        fv = cscale_data(fv, PARAMS['mean_fold'+fold], PARAMS['stdev_fold'+fold])
    patches = preproc.get_feature_patches(PARAMS, fv, PARAMS['W'], PARAMS['W_shift'], featName)

    # Any value other than 'Harm' / 'Perc' (e.g. False) leaves the patches as is.
    skewness = PARAMS['skewness_vector']
    if skewness=='Harm':
        patches = np.expand_dims(cget_data_statistics(patches, axis=1), axis=2)
    elif skewness=='Perc':
        patches = np.expand_dims(cget_data_statistics(patches, axis=0), axis=1)
    return patches


def generator(PARAMS, folder, file_list, batchSize):
    """Endlessly yield class-balanced batches of feature patches.

    Each batch holds ``batchSize`` music patches (label 0) followed by
    ``batchSize`` speech patches (label 1) and, when ``PARAMS['classes']``
    has three entries, ``batchSize`` speech+music patches (label 2).
    Patches left over from a file are kept and used in the next batch.

    Files that do not exist on disk are skipped silently. Note that if none
    of the listed files of a class exist, the loading loop for that class
    never terminates.

    Args:
        PARAMS: configuration dict. Keys read here: 'feature_opDir',
            'classes', 'Model', 'n_fft', 'n_mels', 'featName',
            'frame_level_scaling', 'fold', 'mean_fold<k>', 'stdev_fold<k>',
            'W', 'W_shift', 'skewness_vector' and
            'data_augmentation_with_noise'.
        folder: root directory; files are looked up under
            ``folder + '/speech/'`` and ``folder + '/music/'``.
        file_list: dict with 'speech' and 'music' lists of file names and,
            for three classes, a 'speech+music' list of dicts with the keys
            'speech', 'music' and 'SMR'. The lists are shuffled in place.
        batchSize: number of patches taken per class for every batch.

    Yields:
        ``(batchData, labels)``. ``labels`` is the one-hot label array,
        except when 'MTL' occurs in ``PARAMS['Model']``: then it is a dict
        with the keys 'R' (two-column ratio target), 'S' (1 for speech-only
        rows), 'M' (1 for music-only rows) and '3C' (one-hot labels).
    """
    three_class = len(PARAMS['classes'])==3

    # In-place shuffles: the caller's lists are reordered.
    np.random.shuffle(file_list['speech'])
    np.random.shuffle(file_list['music'])

    # Working copies that are consumed with pop() and refilled when empty.
    file_list_sp_temp = file_list['speech'].copy()
    file_list_mu_temp = file_list['music'].copy()

    # Placeholders; replaced by the first patches loaded for each class.
    batchData_sp = np.empty([], dtype=float)
    batchData_mu = np.empty([], dtype=float)

    # Number of patches currently buffered per class.
    balance_sp = 0
    balance_mu = 0

    # exist_ok avoids the check-then-create race when several generators
    # start at the same time.
    os.makedirs(PARAMS['feature_opDir']+'/speech/', exist_ok=True)
    os.makedirs(PARAMS['feature_opDir']+'/music/', exist_ok=True)

    if three_class:
        np.random.shuffle(file_list['speech+music'])
        file_list_spmu_temp = file_list['speech+music'].copy()
        batchData_spmu = np.empty([], dtype=float)
        # One SMR value per buffered speech+music patch, kept in lock-step
        # with batchData_spmu.
        batchData_spmu_target_dB = np.empty([], dtype=float)
        balance_spmu = 0
        os.makedirs(PARAMS['feature_opDir']+'/speech_music/', exist_ok=True)

    n_fft = PARAMS['n_fft'][PARAMS['Model']]
    n_mels = PARAMS['n_mels'][PARAMS['Model']]
    featName = PARAMS['featName'][PARAMS['Model']]

    while 1:
        # Load speech files until at least batchSize patches are buffered.
        while balance_sp<batchSize:
            if not file_list_sp_temp:
                file_list_sp_temp = file_list['speech'].copy()
            sp_fName_path = folder + '/speech/' + file_list_sp_temp.pop()
            if not os.path.exists(sp_fName_path):
                continue
            fv_sp_patches = _generator_patches(PARAMS, 'speech', sp_fName_path, '', None, n_fft, n_mels, featName)
            # An empty buffer is replaced rather than appended to.
            if balance_sp==0:
                batchData_sp = fv_sp_patches
            else:
                batchData_sp = np.append(batchData_sp, fv_sp_patches, axis=0)
            balance_sp += np.shape(fv_sp_patches)[0]

        # Same procedure for music files.
        while balance_mu<batchSize:
            if not file_list_mu_temp:
                file_list_mu_temp = file_list['music'].copy()
            mu_fName_path = folder + '/music/' + file_list_mu_temp.pop()
            if not os.path.exists(mu_fName_path):
                continue
            fv_mu_patches = _generator_patches(PARAMS, 'music', '', mu_fName_path, None, n_fft, n_mels, featName)
            if balance_mu==0:
                batchData_mu = fv_mu_patches
            else:
                batchData_mu = np.append(batchData_mu, fv_mu_patches, axis=0)
            balance_mu += np.shape(fv_mu_patches)[0]

        # Batch layout: music rows (label 0) first, then speech rows (label 1).
        batchData = np.append(batchData_mu[:batchSize, :], batchData_sp[:batchSize, :], axis=0)

        balance_mu -= batchSize
        balance_sp -= batchSize

        # Keep the unused remainder for the next batch.
        batchData_mu = batchData_mu[batchSize:, :]
        batchData_sp = batchData_sp[batchSize:, :]

        batchLabel = [0]*batchSize # music
        batchLabel.extend([1]*batchSize) # speech

        # Ratio target: one row per sample of the batch. It is sized from the
        # number of classes actually present so that it matches batchData in
        # the two-class configuration as well.
        n_samples = (3 if three_class else 2)*batchSize
        batchLabel_smr = np.ones((n_samples, 2))
        batchLabel_smr[:batchSize, :] = [1, 0] # music
        batchLabel_smr[batchSize:2*batchSize, :] = [0, 1] # speech

        if three_class:
            while balance_spmu<batchSize:
                if not file_list_spmu_temp:
                    file_list_spmu_temp = file_list['speech+music'].copy()
                # The remaining entries are re-shuffled before every pop.
                np.random.shuffle(file_list_spmu_temp)
                spmu_info = file_list_spmu_temp.pop()
                sp_fName_path = folder + '/speech/' + spmu_info['speech']
                mu_fName_path = folder + '/music/' + spmu_info['music']
                target_dB = spmu_info['SMR']
                # Both source files must exist for a mixture to be built.
                if (not os.path.exists(mu_fName_path)) or (not os.path.exists(sp_fName_path)):
                    continue
                fv_spmu_patches = _generator_patches(PARAMS, 'speech_music', sp_fName_path, mu_fName_path, target_dB, n_fft, n_mels, featName)
                n_new = np.shape(fv_spmu_patches)[0]
                new_target_dB = np.array([target_dB]*n_new)
                if balance_spmu==0:
                    batchData_spmu = fv_spmu_patches
                    batchData_spmu_target_dB = new_target_dB
                else:
                    batchData_spmu = np.append(batchData_spmu, fv_spmu_patches, axis=0)
                    batchData_spmu_target_dB = np.append(batchData_spmu_target_dB, new_target_dB)
                balance_spmu += n_new

            # speech_music label=2
            batchData = np.append(batchData, batchData_spmu[:batchSize, :], axis=0)
            balance_spmu -= batchSize
            batchData_spmu = batchData_spmu[batchSize:, :]
            batchLabel.extend([2]*batchSize) # speech+music

            # Convert the dB value of every mixture row to a linear ratio and
            # normalise so that the larger of the two columns equals 1
            # (column 0 is the music column, column 1 the speech column, as
            # in the pure-class rows above).
            batch_dB = np.asarray(batchData_spmu_target_dB[:batchSize], dtype=float)
            ratio = np.power(10, batch_dB/10)
            non_negative = batch_dB>=0
            batchLabel_smr[2*batchSize:, 0] = np.where(non_negative, 1/ratio, 1)
            batchLabel_smr[2*batchSize:, 1] = np.where(non_negative, 1, ratio)
            batchData_spmu_target_dB = batchData_spmu_target_dB[batchSize:]

        if 'Lemaire_et_al' in PARAMS['Model']:
            batchData = np.transpose(batchData, axes=(0,2,1)) # TCN input shape=(batch_size, timesteps, ndim)

        # Adding Normal (Gaussian) noise for data augmentation; one noise
        # level is drawn per batch and applied to every row.
        if PARAMS['data_augmentation_with_noise']:
            scale = np.random.choice([5e-3, 1e-3, 5e-4, 1e-4])
            noise = np.random.normal(loc=0.0, scale=scale, size=np.shape(batchData))
            batchData = np.add(batchData, noise)

        OHE_batchLabel = to_categorical(batchLabel, num_classes=len(PARAMS['classes']))

        # Speech / non-speech: only the speech-only rows are positive.
        batchLabel_sp_nsp = np.copy(batchLabel)
        batchLabel_sp_nsp[:batchSize] = 0
        batchLabel_sp_nsp[batchSize:2*batchSize] = 1
        batchLabel_sp_nsp[2*batchSize:] = 0

        # Music / non-music: only the music-only rows are positive.
        batchLabel_mu_nmu = np.copy(batchLabel)
        batchLabel_mu_nmu[:batchSize] = 1
        batchLabel_mu_nmu[batchSize:2*batchSize] = 0
        batchLabel_mu_nmu[2*batchSize:] = 0

        batchLabel_MTL = {'R': batchLabel_smr, 'S': batchLabel_sp_nsp, 'M': batchLabel_mu_nmu, '3C': OHE_batchLabel}

        if ('MTL' in PARAMS['Model']) or ('Cascaded_MTL' in PARAMS['Model']):
            yield batchData, batchLabel_MTL
        else:
            yield batchData, OHE_batchLabel