def calc_data_error(tdoa_sec, m, sound_speed_mps, hydrophones_config,
                    hydrophone_pairs):
    """Calculate the TDOA measurement error standard deviation.

    Eq. (9) in Mouy et al. 2018: the residuals between the measured TDOAs
    and the TDOAs predicted for the model ``m`` are squared, summed,
    normalized by ``Q * (N - M)`` and square-rooted.

    NOTE(review): this module defines ``calc_data_error`` a second time
    further down; that later definition overrides this one at import time.

    Parameters
    ----------
    tdoa_sec : array-like
        Measured TDOAs (seconds), one per hydrophone pair.
    m : pandas.DataFrame
        Source location model (run_localization passes the output of
        ``linearized_inversion``, read via ``m['x']`` etc.). ``Q`` is its
        number of rows and ``M`` its total number of elements.
    sound_speed_mps : float
        Sound speed used to predict the model TDOAs.
    hydrophones_config, hydrophone_pairs
        Passed through to ``predict_tdoa``.

    Returns
    -------
    error_std
        TDOA error standard deviation. When there are not more measurements
        than model parameters (``N <= M``) the Eq. (9) normalization is
        undefined, so the un-normalized root-sum-of-squares is returned.
    """
    tdoa_m = predict_tdoa(m, sound_speed_mps, hydrophones_config,
                          hydrophone_pairs)
    Q = len(m)
    M = m.size  # number of model parameters (here: X, Y, and Z)
    N = len(tdoa_sec)  # number of TDOA measurements
    # NOTE: builtin sum (not np.sum) reduces only along the first axis;
    # run_localization indexes the result with [0], so preserve its shape.
    sum_sq_residuals = sum((tdoa_sec - tdoa_m)**2)
    if N > M:
        error_std = np.sqrt((1 / (Q * (N - M))) * sum_sq_residuals)
    else:
        # Zero or negative degrees of freedom: Eq. (9) would divide by zero
        # or take the square root of a negative number.
        error_std = np.sqrt(sum_sq_residuals)
    return error_std
def calc_data_error(tdoa_sec, m, sound_speed_mps, hydrophones_config,
                    hydrophone_pairs):
    """Calculate the TDOA measurement error standard deviation.

    Eq. (9) in Mouy et al. 2018: the residuals between the measured TDOAs
    and the TDOAs predicted for the model ``m`` are squared, summed,
    normalized by ``Q * (N - M)`` and square-rooted.

    Parameters
    ----------
    tdoa_sec : array-like
        Measured TDOAs (seconds), one per hydrophone pair.
    m : pandas.DataFrame
        Source location model (run_localization passes the output of
        ``linearized_inversion``, read via ``m['x']`` etc.). ``Q`` is its
        number of rows and ``M`` its total number of elements.
    sound_speed_mps : float
        Sound speed used to predict the model TDOAs.
    hydrophones_config, hydrophone_pairs
        Passed through to ``predict_tdoa``.

    Returns
    -------
    error_std
        TDOA error standard deviation. When there are not more measurements
        than model parameters (``N <= M``) the Eq. (9) normalization is
        undefined, so the un-normalized root-sum-of-squares is returned.
    """
    tdoa_m = predict_tdoa(m, sound_speed_mps, hydrophones_config,
                          hydrophone_pairs)
    Q = len(m)
    M = m.size  # number of model parameters (here: X, Y, and Z)
    N = len(tdoa_sec)  # number of TDOA measurements
    # NOTE: builtin sum (not np.sum) reduces only along the first axis;
    # run_localization indexes the result with [0], so preserve its shape.
    sum_sq_residuals = sum((tdoa_sec - tdoa_m)**2)
    if N > M:
        error_std = np.sqrt((1 / (Q * (N - M))) * sum_sq_residuals)
    else:
        # Zero or negative degrees of freedom: Eq. (9) would divide by zero
        # or take the square root of a negative number.
        error_std = np.sqrt(sum_sq_residuals)
    return error_std

# Define Measurement object for the localization results
# if localization_config['METHOD']['linearized_inversion']:
#     localizations = Measurement()
#     localizations.metadata['measurer_name'] = localization_method_name
#     localizations.metadata['measurer_version'] = '0.1'
def _compute_grid_tdoas(gridsearch_config, sound_speed_mps,
                        hydrophones_config, hydrophone_pairs):
    """Build the grid-search candidate sources and their predicted TDOAs.

    Returns ``(sources, sources_tdoa)``: ``sources`` is the grid returned by
    ``defineSphereVolumeGrid`` with added ``theta`` (azimuth, degrees) and
    ``phi`` (angle from the +z axis, degrees) columns; ``sources_tdoa`` has
    one row per hydrophone pair and one column per grid point.
    """
    sources = defineSphereVolumeGrid(
        gridsearch_config['spacing_m'],
        gridsearch_config['radius_m'],
        origin=gridsearch_config['origin'])
    sources_tdoa = np.zeros(shape=(len(hydrophone_pairs), len(sources)))
    # Fill by position so a non-default grid index cannot misplace columns.
    for col, (_, source) in enumerate(sources.iterrows()):
        sources_tdoa[:, col] = predict_tdoa(source, sound_speed_mps,
                                            hydrophones_config,
                                            hydrophone_pairs).T
    x = sources['x'].to_numpy()
    y = sources['y'].to_numpy()
    z = sources['z'].to_numpy()
    sources['theta'] = np.arctan2(y, x) * (180 / np.pi)  # azimuth
    # Polar angle uses the horizontal distance sqrt(x^2 + y^2), not x^2 + y^2.
    sources['phi'] = np.arctan2(np.hypot(x, y), z) * (180 / np.pi)
    return sources, sources_tdoa


def _plot_grid_search(sources, hydrophones_config):
    """Plot the grid-search TDOA misfit in 3D and as a theta/phi hexbin map."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    colors = matplotlib.cm.tab10(hydrophones_config.index.values)
    for index, hp in hydrophones_config.iterrows():
        ax.scatter(
            hp['x'],
            hp['y'],
            hp['z'],
            s=40,
            color=colors[index],
            label=hp['name'],
        )
    ax.scatter(
        sources['x'],
        sources['y'],
        sources['z'],
        c=sources['delta_tdoa'],
        s=2,
        alpha=0.5,
    )
    ax.set_xlabel('X (m)', labelpad=10)
    ax.set_ylabel('Y (m)', labelpad=10)
    ax.set_zlabel('Z (m)', labelpad=10)
    ax.legend(bbox_to_anchor=(1.07, 0.7, 0.3, 0.2), loc='upper left')
    plt.tight_layout()
    plt.show()

    # hexbin creates its own figure unless given an axis; pass one so no
    # empty extra figure is left behind.
    _, hex_ax = plt.subplots()
    sources.plot.hexbin(x="theta",
                        y="phi",
                        C="delta_tdoa",
                        reduce_C_function=np.mean,
                        gridsize=40,
                        cmap="viridis",
                        ax=hex_ax)


def run_localization(infile, deployment_info_file, detection_config,
                     hydrophones_config, localization_config,
                     chunk=(0, 70)):
    """Detect sounds in an audio recording and localize each detection.

    Runs the detector on the channel selected by
    ``detection_config['AUDIO']['channel']``, then, for every detection,
    measures the TDOAs between hydrophone pairs and, depending on
    ``localization_config['METHOD']``:

    * ``grid_search``: compares the measured TDOAs with TDOAs predicted on a
      spherical grid of candidate sources and plots the misfit;
    * ``linearized_inversion``: estimates the source position (x, y, z),
      its standard deviations, and the TDOA error standard deviation.

    Parameters
    ----------
    infile
        Passed to ``find_audio_files`` to locate the audio of all channels.
    deployment_info_file
        Passed to ``run_detector`` as ``deployment_file``.
    detection_config : dict-like
        Detector configuration (``['AUDIO']['channel']`` is used here).
    hydrophones_config : pandas.DataFrame
        One row per hydrophone with at least ``x``, ``y``, ``z``, ``name``.
    localization_config : dict-like
        Sections ``ENVIRONMENT``, ``TDOA``, ``METHOD``, ``GRIDSEARCH`` and
        ``INVERSION``.
    chunk : (start, end), optional
        Section of the audio to analyze, passed to ``run_detector(chunk=...)``
        (presumably seconds — confirm). Defaults to the previously hard-coded
        ``(0, 70)``.

    Returns
    -------
    Measurement or None
        One row per detection with the localization columns when
        ``linearized_inversion`` is enabled; ``None`` otherwise (the
        grid-search output format is not defined yet).
    """
    import pandas as pd  # only needed to assemble the localization table

    channel = detection_config['AUDIO']['channel']

    # Look up data files for all channels
    audio_files = find_audio_files(infile, hydrophones_config)

    # Run detector on the selected channel
    print('DETECTION')
    detections = run_detector(
        audio_files['path'][channel],
        audio_files['channel'][channel],
        detection_config,
        chunk=list(chunk),
        deployment_file=deployment_info_file)
    print(str(len(detections)) + ' detections')

    # Localization settings
    sound_speed_mps = localization_config['ENVIRONMENT']['sound_speed_mps']
    ref_channel = localization_config['TDOA']['ref_channel']
    do_grid_search = localization_config['METHOD']['grid_search']
    do_inversion = localization_config['METHOD']['linearized_inversion']

    # TDOA search window: largest hydrophone separation / sound speed
    hydrophones_dist_matrix = calc_hydrophones_distances(hydrophones_config)
    TDOA_max_sec = np.max(hydrophones_dist_matrix) / sound_speed_mps

    hydrophone_pairs = defineReceiverPairs(len(hydrophones_config),
                                           ref_receiver=ref_channel)

    # Pre-compute grid-search TDOAs once, shared by all detections
    if do_grid_search:
        sources, sources_tdoa = _compute_grid_tdoas(
            localization_config['GRIDSEARCH'], sound_speed_mps,
            hydrophones_config, hydrophone_pairs)

    # Results container (only defined for the linearized inversion)
    localizations = None
    if do_inversion:
        localizations = Measurement()
        # NOTE(review): localization_method_name is not defined in this
        # function; presumably a module-level name — confirm.
        localizations.metadata['measurer_name'] = localization_method_name
        localizations.metadata['measurer_version'] = '0.1'
        localizations.metadata['measurements_name'] = [[
            'x', 'y', 'z', 'x_std', 'y_std', 'z_std', 'tdoa_errors_std'
        ]]
    # Rows are collected and concatenated once (DataFrame.append is removed
    # in pandas 2.0 and is quadratic when called in a loop).
    localized_rows = []

    print('LOCALIZATION')
    for detec_idx, detec in detections.data.iterrows():
        print(str(detec_idx + 1) + '/' + str(len(detections)))

        # Load data from all channels for that detection
        waveform_stack = stack_waveforms(audio_files, detec, TDOA_max_sec)

        # Readjust signal boundaries to only focus on section with most energy
        percentage_max_energy = 90
        signal_limits = ecosound.core.tools.tighten_signal_limits_peak(
            waveform_stack[channel], percentage_max_energy)
        waveform_stack = [x[signal_limits[0]:signal_limits[1]]
                          for x in waveform_stack]

        # Calculate TDOAs
        tdoa_sec, _ = calc_tdoa(
            waveform_stack,
            hydrophone_pairs,
            detec['audio_sampling_frequency'],
            TDOA_max_sec=TDOA_max_sec,
            upsample_res_sec=localization_config['TDOA']['upsample_res_sec'],
            normalize=localization_config['TDOA']['normalize'],
            doplot=False,
        )

        if do_grid_search:
            # Misfit norm between measured TDOAs and each grid point's TDOAs
            delta_tdoa = sources_tdoa - tdoa_sec
            sources['delta_tdoa'] = np.linalg.norm(delta_tdoa, axis=0)
            # NOTE: plt.show() may block for each detection on
            # interactive backends.
            _plot_grid_search(sources, hydrophones_config)

        if do_inversion:
            m, _ = linearized_inversion(tdoa_sec,
                                        hydrophones_config,
                                        hydrophone_pairs,
                                        localization_config['INVERSION'],
                                        sound_speed_mps,
                                        doplot=False)

            # Estimate uncertainty
            tdoa_errors_std = calc_data_error(tdoa_sec, m, sound_speed_mps,
                                              hydrophones_config,
                                              hydrophone_pairs)
            loc_errors_std = calc_loc_errors(tdoa_errors_std, m,
                                             sound_speed_mps,
                                             hydrophones_config,
                                             hydrophone_pairs)

            # Bring detection and localization information together
            row = detec.copy()
            row.loc['x'] = m['x'].values[0]
            row.loc['y'] = m['y'].values[0]
            row.loc['z'] = m['z'].values[0]
            row.loc['x_std'] = loc_errors_std['x_std'].values[0]
            row.loc['y_std'] = loc_errors_std['y_std'].values[0]
            row.loc['z_std'] = loc_errors_std['z_std'].values[0]
            row.loc['tdoa_errors_std'] = tdoa_errors_std[0]
            localized_rows.append(row)

    # Rows only exist when the inversion ran, so localizations is set here.
    if localized_rows:
        localizations.data = pd.concat(
            [localizations.data, pd.DataFrame(localized_rows)],
            ignore_index=True)
    return localizations