Example #1
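# Assumed module context (not shown in this snippet): numpy as np,
# hde_utils as utl, hde_api as hapi, a helper get_random_chunk_of_data,
# and the module-level globals data_dir, estimation_method and
# embedding_step_size.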
def get_bs_histdep_on_full(analysis_file, spike_times_full_file_name,
                           embedding, number_of_bootstraps,
                           target_recording_length_in_min):

    h5_data_dir_name = "{}/{}/{}".format(spike_times_full_file_name,
                                         target_recording_length_in_min,
                                         str(embedding))

    if h5_data_dir_name not in analysis_file:
        analysis_file.create_group(h5_data_dir_name)
    h5_data_dir = analysis_file[h5_data_dir_name]

    if "bs_history_dependence" in h5_data_dir:
        stored_bs_history_dependence = h5_data_dir["bs_history_dependence"][()]
    else:
        stored_bs_history_dependence = []

    if len(stored_bs_history_dependence) >= number_of_bootstraps:
        return stored_bs_history_dependence[:number_of_bootstraps]

    else:
        spike_times_full \
            = utl.get_spike_times_from_file("{}/{}".format(data_dir,
                                                           spike_times_full_file_name))[0]

        bs_Rs = np.zeros(number_of_bootstraps -
                         len(stored_bs_history_dependence))
        for rep in range(len(bs_Rs)):
            bs_spike_times = get_random_chunk_of_data(
                spike_times_full, 60 * target_recording_length_in_min)

            bs_recording_length = bs_spike_times[-1] - bs_spike_times[0]

            bs_R = hapi.get_history_dependence_for_single_embedding(
                bs_spike_times,
                bs_recording_length,
                estimation_method,
                embedding,
                embedding_step_size,
                bbc_tolerance=np.inf)
            bs_Rs[rep] = bs_R

        bs_history_dependence \
            = np.hstack((stored_bs_history_dependence,
                         bs_Rs))

        if len(stored_bs_history_dependence) > 0:
            del h5_data_dir["bs_history_dependence"]
        h5_data_dir.create_dataset("bs_history_dependence",
                                   data=bs_history_dependence)

        return bs_history_dependence
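A minimal usage sketch for the function above; the file names and the embedding tuple are hypothetical, and the module context noted at the top of the example is assumed:

import h5py
import numpy as np

with h5py.File("analysis.h5", "a") as analysis_file:
    # embedding = (past range T in s, number of bins d, scaling exponent)
    bs_Rs = get_bs_histdep_on_full(analysis_file,
                                   "spike_times.dat",
                                   embedding=(1.0, 5, 0.0),
                                   number_of_bootstraps=250,
                                   target_recording_length_in_min=3)
    print(np.std(bs_Rs))  # bootstrap SD of the history dependence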
Example #2
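# Assumed module context: numpy as np, hde_utils as utl, hde_api as hapi,
# matplotlib.ticker as ticker, plotting helpers vsl (module name assumed),
# a helper get_bs_data, and module-level globals such as data_dir,
# spike_times_full_file_name, target_recording_lengths_in_min,
# representative_embeddings, number_of_bootstraps and number_of_CIs.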
def compare_CIs(analysis_file, ax):
    spike_times_full \
        = utl.get_spike_times_from_file("{}/{}".format(data_dir,
                                                       spike_times_full_file_name))[0]
    recording_length_full = spike_times_full[-1] - spike_times_full[0]

    coverage = []

    for target_recording_length_in_min in target_recording_lengths_in_min:
        print(target_recording_length_in_min)

        embedding = representative_embeddings[target_recording_length_in_min]

        # first, compute the "true" R for the given embedding
        # from the full-length recording
        history_dependence_full \
            = hapi.get_history_dependence_for_single_embedding(spike_times_full,
                                                               recording_length_full,
                                                               estimation_method,
                                                               embedding,
                                                               embedding_step_size,
                                                               bbc_tolerance=np.inf)

        # now compute 95% CIs many times and check whether
        # the "true" R falls within the CI 95% of the time
        bs_CIs = get_bs_data(analysis_file,
                             spike_times_full,
                             number_of_bootstraps,
                             embedding,
                             target_recording_length_in_min,
                             number_of_CIs,
                             target_data="CI")

        coverage_this_rec_len = 0
        for CI_lo, CI_hi in bs_CIs:
            if CI_lo <= history_dependence_full <= CI_hi:
                coverage_this_rec_len += 1
        coverage_this_rec_len = 100 * (coverage_this_rec_len / number_of_CIs)
        coverage += [coverage_this_rec_len]

    ax.plot(target_recording_lengths_in_min, coverage, color='g', ls='-')

    ax.set_xscale('log')
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(vsl.format_x_label))
Example #3
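# Same assumed module context as Example #2, plus get_bs_histdep_on_full
# from Example #1.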
def compare_R_variances_recording_length(analysis_file, ax):
    spike_times_full \
        = utl.get_spike_times_from_file("{}/{}".format(data_dir,
                                                       spike_times_full_file_name))[0]

    recording_length_full = spike_times_full[-1] - spike_times_full[0]

    bs_std_means = []
    bs_std_stds = []
    full_stds = []

    for target_recording_length_in_min in target_recording_lengths_in_min:
        print(target_recording_length_in_min)

        embedding = representative_embeddings[target_recording_length_in_min]

        bs_stds = get_bs_data(analysis_file,
                              spike_times_full,
                              number_of_bootstraps,
                              embedding,
                              target_recording_length_in_min,
                              number_of_CIs,
                              target_data="std")

        bs_std_means += [np.average(bs_stds)]
        bs_std_stds += [np.std(bs_stds)]

        # now compute the reference values from the full recording
        bs_history_dependence_full = get_bs_histdep_on_full(
            analysis_file, spike_times_full_file_name, embedding,
            number_of_bootstraps, target_recording_length_in_min)

        full_stds += [np.std(bs_history_dependence_full)]

    ax.errorbar(target_recording_lengths_in_min,
                bs_std_means,
                yerr=bs_std_stds,
                color='b',
                ecolor='b',
                ls='-')
    ax.plot(target_recording_lengths_in_min, full_stds, color='k', ls='-')

    ax.set_xscale('log')
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(vsl.format_x_label))
Example #4
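# Assumed module context: argparse, ast, yaml, numpy as np, hde_utils as utl,
# argv/stderr/exit from sys, exists/isfile/isdir from os.path, and the
# module-level constants __version__, ESTIMATOR_DIR and EXIT_FAILURE.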
def parse_arguments(defined_tasks, defined_estimation_methods):
    """
    Parse the arguments passed to the script via the command line.

    Import settings from file, do some sanity checks to avoid faulty runs.
    """

    # parse arguments
    parser = argparse.ArgumentParser(
        description="""
    History dependence estimator, v. {}

    Estimate the history dependence and temporal depth of a single
    neuron, based on information-theoretical measures for spike time
    data, as presented in (Rudelt et al, in prep.) [1].  Parameters
    can be passed via the command line or through files, where command
    line options are prioritised over those passed by file.  (If none
    are supplied, settings are read from the 'default.yaml' file.)  A
    user new to this tool is encouraged to run

      python3 {} sample_data/spike_times.dat -o sample_output.pdf \\
        -s settings/test.yaml

    to test the functionality of this tool.  A more detailed
    description can be found in the guide provided with the tool [2].

    [1]: L. Rudelt, D. G. Marx, M. Wibral, V. Priesemann: Embedding
        optimization reveals long-lasting history dependence in
        neural spiking activity (in prep.)

    [2]: https://github.com/Priesemann-Group/hdestimator
        """.format(__version__, argv[0]),
        formatter_class=argparse.RawDescriptionHelpFormatter)
    optional_arguments = parser._action_groups.pop()

    required_arguments = parser.add_argument_group("required arguments")
    required_arguments.add_argument(
        'spike_times_file',
        action="store",
        help=
        "Define file from which to read spike times and on which to perform the analysis.  The file should contain one spike time per line.",
        nargs='+')

    optional_arguments.add_argument(
        "-t",
        "--task",
        metavar="TASK",
        action="store",
        help=
        "Define task to be performed.  One of {}.  Per default, the full analysis is performed."
        .format(defined_tasks),
        default="full-analysis")
    optional_arguments.add_argument(
        "-e",
        "--estimation-method",
        metavar="EST_METHOD",
        action="store",
        help="Specify estimation method for the analysis, one of {}.".format(
            defined_estimation_methods))
    optional_arguments.add_argument(
        "-h5",
        "--hdf5-dataset",
        action="store",
        help="Load data stored in a dataset in a hdf5 file.",
        nargs='+')
    optional_arguments.add_argument("-o",
                                    "--output",
                                    metavar="IMAGE_FILE",
                                    action="store",
                                    help="Save the output image to file.")
    optional_arguments.add_argument(
        "-p",
        "--persistent",
        action="store_true",
        help=
        "Save the analysis to file.  If an existing analysis is found, read it from file."
    )
    optional_arguments.add_argument(
        "-s",
        "--settings-file",
        metavar="SETTINGS_FILE",
        action="store",
        help="Specify yaml file from which to load custom settings.")
    optional_arguments.add_argument(
        "-l",
        "--label",
        metavar="LABEL",
        action="store",
        help="Include a label in the output to classify the analysis.")
    # optional_arguments.add_argument("-v", "--verbose", action="store_true", help="Print more info at run time.")
    optional_arguments.add_argument(
        '--version',
        action='version',
        version='hdestimator v. {}'.format(__version__),
        help="Show version of the tool and exit.")
    parser._action_groups.append(optional_arguments)
    args = parser.parse_args()

    # check that parsed arguments are valid

    task = args.task.lower()
    spike_times_file_names = args.spike_times_file

    task_found = False
    task_full_name = ""
    for defined_task in defined_tasks:
        if defined_task.startswith(task):
            if not task_found:
                task_found = True
                task_full_name = defined_task
            else:
                print(
                    "Task could not be uniquely determined.  Task must be one of {}.  Aborting."
                    .format(defined_tasks),
                    file=stderr,
                    flush=True)
                exit(EXIT_FAILURE)

    task = task_full_name

    if task not in defined_tasks:
        print("Task must be one of {}.  Aborting.".format(defined_tasks),
              file=stderr,
              flush=True)
        exit(EXIT_FAILURE)

    for spike_times_file_name in spike_times_file_names:
        if not exists(spike_times_file_name):
            print("Spike times file {} not found.  Aborting.".format(
                spike_times_file_name),
                  file=stderr,
                  flush=True)
            exit(EXIT_FAILURE)

    spike_times = utl.get_spike_times_from_file(spike_times_file_names,
                                                args.hdf5_dataset)

    if not isinstance(spike_times, np.ndarray):
        print("Error loading spike times. Aborting.", file=stderr, flush=True)
        exit(EXIT_FAILURE)
    elif len(spike_times) == 0:
        print("Spike times are empty. Aborting.", file=stderr, flush=True)
        exit(EXIT_FAILURE)

    #
    # PARSE SETTINGS
    #

    # create default settings file if it does not exist:
    if not isfile('{}/settings/default.yaml'.format(ESTIMATOR_DIR)):
        utl.create_default_settings_file(ESTIMATOR_DIR)

    # load default settings
    with open('{}/settings/default.yaml'.format(ESTIMATOR_DIR),
              'r') as default_settings_file:
        settings = yaml.load(default_settings_file, Loader=yaml.BaseLoader)

    # overwrite default settings with custom ones
    if args.settings_file is not None:
        if not isfile(args.settings_file):
            print("Error: Settings file {} not found. Aborting.".format(
                args.settings_file),
                  file=stderr,
                  flush=True)
            exit(EXIT_FAILURE)
        with open(args.settings_file, 'r') as custom_settings_file:
            custom_settings = yaml.load(custom_settings_file,
                                        Loader=yaml.BaseLoader)
        for setting_key in settings:
            if setting_key in custom_settings:
                settings[setting_key] = custom_settings[setting_key]

    if args.persistent:
        settings['persistent_analysis'] = "True"
    # if args.verbose:
    #     settings['verbose_output'] = "True"
    # else:
    settings['verbose_output'] = "False"

    if args.estimation_method is not None:
        settings['estimation_method'] = args.estimation_method

    if 'block_length_l' not in settings:
        settings['block_length_l'] = "None"

    # check that required settings are defined
    required_parameters = [
        'embedding_past_range_set',
        'embedding_number_of_bins_set',
        'embedding_scaling_exponent_set',
        'embedding_step_size',
        'bbc_tolerance',
        'number_of_bootstraps_R_max',
        'number_of_bootstraps_R_tot',
        'number_of_bootstraps_nonessential',
        'block_length_l',
        'bootstrap_CI_percentile_lo',
        'bootstrap_CI_percentile_hi',
        # 'number_of_permutations',
        'auto_MI_bin_size_set',
        'auto_MI_max_delay'
    ]

    required_settings = [
        'estimation_method', 'plot_AIS', 'ANALYSIS_DIR', 'persistent_analysis',
        'cross_validated_optimization', 'return_averaged_R',
        'bootstrap_CI_use_sd', 'verbose_output', 'plot_settings', 'plot_color'
    ] + required_parameters

    for required_setting in required_settings:
        if required_setting not in settings:
            print(
                "Error in settings file: {} is not defined. Aborting.".format(
                    required_setting),
                file=stderr,
                flush=True)
            exit(EXIT_FAILURE)

    # sanity check for the settings
    if settings['estimation_method'] not in defined_estimation_methods:
        print("Error: estimation_method must be one of {}. Aborting.".format(
            defined_estimation_methods),
              file=stderr,
              flush=True)
        exit(EXIT_FAILURE)

    # evaluate settings (turn strings into booleans etc if applicable)
    for setting_key in [
            'persistent_analysis', 'verbose_output',
            'cross_validated_optimization', 'return_averaged_R',
            'bootstrap_CI_use_sd', 'plot_AIS'
    ]:
        settings[setting_key] = ast.literal_eval(settings[setting_key])
    for plot_setting in settings['plot_settings']:
        try:
            settings['plot_settings'][plot_setting] \
                = ast.literal_eval(settings['plot_settings'][plot_setting])
        except (ValueError, SyntaxError):
            # keep plot settings that are not Python literals as strings
            continue

    for parameter_key in required_parameters:
        if isinstance(settings[parameter_key], list):
            settings[parameter_key] = [
                ast.literal_eval(element)
                for element in settings[parameter_key]
            ]
        elif parameter_key == 'embedding_scaling_exponent_set' \
             and isinstance(settings['embedding_scaling_exponent_set'], dict):
            # embedding_scaling_exponent_set can be passed either as a
            # list, in which case it is evaluated as such, or as a dict
            # specifying three parameters that determine how many
            # scaling exponents should be used.  In the latter case the
            # uniform embedding, the embedding for which the first bin
            # has a length of min_first_bin_size (in seconds), and
            # linearly spaced scaling factors in between are used, such
            # that in total number_of_scalings scalings are used.
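            # Illustrative YAML settings (values hypothetical); as a list:
            #   embedding_scaling_exponent_set: ['0.0', '0.25', '0.5']
            # or as a dict of three parameters:
            #   embedding_scaling_exponent_set:
            #       number_of_scalings: '10'
            #       min_first_bin_size: '0.005'
            #       ...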

            for key in settings['embedding_scaling_exponent_set']:
                settings['embedding_scaling_exponent_set'][key] \
                    = ast.literal_eval(settings['embedding_scaling_exponent_set'][key])
        else:
            settings[parameter_key] = ast.literal_eval(settings[parameter_key])

    # If R_tot is computed as an average over Rs, no confidence interval can be estimated
    if settings['return_averaged_R']:
        settings['number_of_bootstraps_R_tot'] = 0

    # if the user specifies a file in which to store output image:
    # store this in settings
    if args.output is not None:
        settings['output_image'] = args.output

    # if the user wants to store the data, do so in a dedicated directory below the
    # ANALYSIS_DIR passed via settings (here it is also checked whether there is an
    # existing analysis, for which the hash sum of the content of the spike times
    # file must match).
    #
    # If the user does not want to store the data, a temporary file is created and
    # then deleted after the program finishes
    #
    # For most tasks an existing analysis file is expected

    if settings['persistent_analysis']:
        if not isdir(settings['ANALYSIS_DIR']):
            print("Error: {} not found. Aborting.".format(
                settings['ANALYSIS_DIR']),
                  file=stderr,
                  flush=True)
            exit(EXIT_FAILURE)

        analysis_dir, analysis_num, existing_analysis_found \
            = utl.get_or_create_analysis_dir(spike_times,
                                             spike_times_file_names,
                                             settings['ANALYSIS_DIR'])

        settings['ANALYSIS_DIR'] = analysis_dir
    else:
        analysis_num = "temp"

    analysis_file = utl.get_analysis_file(settings['persistent_analysis'],
                                          settings['ANALYSIS_DIR'])

    # sanity check for tasks

    if task != "full-analysis" and not settings['persistent_analysis']:
        print(
            "Error.  Setting 'persistent_analysis' is set to 'False' and task is not 'full-analysis'.  This would produce no output.  Aborting.",
            file=stderr,
            flush=True)
        exit(EXIT_FAILURE)

    if task in [
            "confidence-intervals",
            # "permutation-test",
            "csv-files"
    ]:
        if settings['cross_validated_optimization']:
            required_dir = "h2_embeddings"
        else:
            required_dir = "embeddings"
        if required_dir not in analysis_file:
            print(
                "Error.  No existing analysis found.  Please run the 'history-dependence' task first.  Aborting.",
                file=stderr,
                flush=True)
            exit(EXIT_FAILURE)

    csv_stats_file, csv_histdep_data_file, csv_auto_MI_data_file \
        = utl.get_CSV_files(task,
                            settings['persistent_analysis'],
                            settings['ANALYSIS_DIR'])

    if task == "plots":
        for csv_file in [
                csv_stats_file, csv_histdep_data_file, csv_auto_MI_data_file
        ]:
            if csv_file is None:
                print(
                    "Error.  CSV files not found and needed to produce plots.  Please run the 'csv-files' task first.  Aborting.",
                    file=stderr,
                    flush=True)
                exit(EXIT_FAILURE)

    # label for the output
    if args.label is not None:
        settings['label'] = args.label
    elif 'label' not in settings:
        settings['label'] = ""
    if "," in settings['label']:
        settings['label'] = settings['label'].replace(",", ";")
        print(
            "Warning: Invalid label '{}'. It may not contain any commas, as this conflicts with the CSV file format.  The commas have been replaced by semicolons."
            .format(settings['label']),
            file=stderr,
            flush=True)

    # for cross-validation,
    # split the data into two halves
    spike_times_optimization = []
    spike_times_validation = []
    if settings['cross_validated_optimization']:
        for spt in spike_times:
            # midpoint of the recording (assumes spike times start at 0)
            spt_half_time = (spt[-1] - spt[0]) / 2
            spt_optimization = spt[spt < spt_half_time]
            spt_validation = spt[spt >= spt_half_time] \
                - spt_half_time
            spike_times_optimization += [spt_optimization]
            spike_times_validation += [spt_validation]
    else:
        for spt in spike_times:
            spike_times_optimization += [spt]
            spike_times_validation += [spt]

    # dtype=object, since the per-neuron spike trains may differ in length
    spike_times_optimization = np.array(spike_times_optimization, dtype=object)
    spike_times_validation = np.array(spike_times_validation, dtype=object)

    return task, spike_times, spike_times_optimization, spike_times_validation, \
        analysis_file, csv_stats_file, csv_histdep_data_file, csv_auto_MI_data_file, analysis_num, \
        settings
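A minimal call sketch for the entry point; the task list below is illustrative, and "bbc" and "shuffling" are the two estimation methods the tool documents:

if __name__ == "__main__":
    defined_tasks = ["history-dependence", "confidence-intervals",
                     "auto-mi", "csv-files", "plots", "full-analysis"]
    defined_estimation_methods = ["bbc", "shuffling"]

    (task, spike_times, spike_times_optimization, spike_times_validation,
     analysis_file, csv_stats_file, csv_histdep_data_file,
     csv_auto_MI_data_file, analysis_num, settings) \
        = parse_arguments(defined_tasks, defined_estimation_methods)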
Example #5
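# Assumes the module imports hde_utils as utl, as in Example #6.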
def test_get_spike_times_from_file():
    spike_times = utl.get_spike_times_from_file('sample_data/spike_times.dat')
    assert len(spike_times) > 0
Example #6
import hde_utils as utl

spike_times = utl.get_spike_times_from_file('sample_data/spike_times.dat')


def test_get_spike_times_from_file():
    spike_times = utl.get_spike_times_from_file('sample_data/spike_times.dat')
    assert len(spike_times) > 0
Example #7
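# Assumes hde_utils as utl, a module-level spike_times_file_name, and an
# estimator_env namespace object shared across the test suite.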
def test_get_spike_times_from_file():
    estimator_env.spike_times = utl.get_spike_times_from_file(
        spike_times_file_name)
    assert len(estimator_env.spike_times) > 0
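All three variants only assert that loading the sample data succeeds. A sketch for running them, assuming the tests live in a file that pytest collects and that sample_data/spike_times.dat exists relative to the working directory:

import pytest

# run the collected tests quietly; a non-zero return value signals failure
raise SystemExit(pytest.main(["-q"]))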