Example #1
def filter_units_by_accuracy(units, stimulus_list, threshold=0.8):
    # each integrity trial consists of three stimuli, hence the division by 3
    ntrial = len(list(filter(
        lambda x: ('metadata' in x['stimulus']
                   and 'label' in x['stimulus']['metadata']
                   and x['stimulus']['metadata']['label'] == 'integrity'),
        stimulus_list))) / 3
    ntrain = int(np.ceil(ntrial/2))
    ntest = int(np.floor(ntrial/2))
    tvt = glia.TVT(ntrain, ntest, 0)

    get_solid = glia.compose(
        glia.f_create_experiments(stimulus_list),
        glia.filter_integrity,
        partial(glia.group_by,
            key=lambda x: x["stimulus"]["metadata"]["group"]),
        glia.group_dict_to_list,
        glia.f_split_list(tvt)
    )

    classification_data = glia.apply_pipeline(get_solid, units, progress=True)
    units_accuracy = glia.pmap(unit_classification_accuracy, classification_data)
    filter_threshold = glia.f_filter(lambda k, x: x['on'] > threshold or x['off'] > threshold)
    return set(filter_threshold(units_accuracy).keys())
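
A minimal usage sketch with hypothetical variable names, assuming units and stimulus_list are already loaded (for example by the analyze command in Example #2):

# keep only units whose 'on' or 'off' integrity accuracy beats the threshold
good_unit_ids = filter_units_by_accuracy(units, stimulus_list, threshold=0.9)
units = {uid: unit for uid, unit in units.items() if uid in good_unit_ids}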
Example #2
def analyze(ctx,
            filename,
            trigger,
            threshold,
            eyecandy,
            ignore_extra=False,
            fix_missing=False,
            output=None,
            notebook=None,
            configuration=None,
            verbose=False,
            debug=False,
            processes=None,
            by_channel=False,
            integrity_filter=0.0,
            analog_idx=1,
            default_channel_map=False,
            dev=False):
    """Analyze data recorded with eyecandy.
    
    This command preprocesses the data and aligns stimuli to the ephys
    recording.
    """
    print("version 0.5.1")
    init_logging(filename, processes, verbose, debug)
    #### FILEPATHS
    logger.debug(str(filename) + "   " + str(os.path.curdir))
    if not os.path.isfile(filename):
        try:
            filename = glia.match_filename(filename, "txt")
        except Exception:
            filename = glia.match_filename(filename, "bxr")

    data_directory, data_name = os.path.split(filename)
    name, extension = os.path.splitext(data_name)
    analog_file = os.path.join(data_directory, name + '.analog')
    if not os.path.isfile(analog_file):
        # use 3brain analog file
        analog_file = os.path.join(data_directory, name + '.analog.brw')

    stimulus_file = os.path.join(data_directory, name + ".stim")
    ctx.obj = {"filename": os.path.join(data_directory, name)}
    print(f"Analyzing {name}")

    if configuration is not None:
        with open(configuration, 'r') as f:
            user_config = yaml.safe_load(f)
        config.user_config = user_config
        if "analog_calibration" in user_config:
            config.analog_calibration = user_config["analog_calibration"]
        if "notebook" in user_config:
            notebook = user_config["notebook"]
        if "eyecandy" in user_config:
            eyecandy = user_config["eyecandy"]
        if "processes" in user_config:
            processes = user_config["processes"]
        if "integrity_filter" in user_config:
            integrity_filter = user_config["integrity_filter"]
        if "by_channel" in user_config:
            by_channel = user_config["by_channel"]

    if not notebook:
        notebook = glia.find_notebook(data_directory)

    lab_notebook = glia.open_lab_notebook(notebook)
    logger.info(name)
    experiment_protocol = glia.get_experiment_protocol(lab_notebook, name)
    flicker_version = experiment_protocol["flickerVersion"]

    #### LOAD STIMULUS
    try:
        metadata, stimulus_list, method = glia.read_stimulus(stimulus_file)
        ctx.obj["stimulus_list"] = stimulus_list
        ctx.obj["metadata"] = metadata
        # assert method=='analog-flicker'
    except Exception:
        print("No .stim file found. Creating from .analog file.")
        if flicker_version == 0.3:
            metadata, stimulus_list = glia.create_stimuli(
                analog_file, stimulus_file, notebook, name, eyecandy,
                analog_idx, ignore_extra, config.analog_calibration, threshold)
            ctx.obj["stimulus_list"] = stimulus_list
            ctx.obj["metadata"] = metadata
            print('finished creating .stim file')
        elif trigger == "ttl":
            raise ValueError('not implemented')
        else:
            raise ValueError("invalid trigger: {}".format(trigger))

    # look for .frames file
    try:
        lab_notebook_notype = glia.open_lab_notebook(notebook,
                                                     convert_types=False)
        protocol_notype = glia.get_experiment_protocol(lab_notebook_notype,
                                                       name)
        date_prefix = os.path.join(data_directory,
                                   protocol_notype['date'].replace(':', '_'))
        frames_file = date_prefix + "_eyecandy_frames.log"
        video_file = date_prefix + "_eyecandy.mkv"
        frame_log = pd.read_csv(frames_file)
        frame_log = frame_log[:-1]  # last frame is not encoded for some reason
        ctx.obj["frame_log"] = frame_log
        ctx.obj["video_file"] = video_file
    except Exception as e:
        extype, value, tb = sys.exc_info()
        traceback.print_exc()
        print(e)
        ctx.obj["frame_log"] = None
        ctx.obj["video_file"] = None
        print("Attempting to continue without frame log...")

    #### LOAD SPIKES
    spyking_regex = re.compile(r'.*\.result\.hdf5$')
    eye = experiment_protocol['eye']
    experiment_n = experiment_protocol['experimentNumber']

    date = experiment_protocol['date'].date().strftime("%y%m%d")

    retina_id = date + '_R' + eye + '_E' + experiment_n
    if extension == ".txt":
        ctx.obj["units"] = glia.read_plexon_txt_file(filename, retina_id,
                                                     channel_map)
    elif extension == ".bxr":
        if default_channel_map:
            channel_map_3brain = config.channel_map_3brain
        else:
            channel_map_3brain = None
        ctx.obj["units"] = glia.read_3brain_spikes(filename,
                                                   retina_id,
                                                   channel_map_3brain,
                                                   truncate=dev)
    elif re.match(spyking_regex, filename):
        ctx.obj["units"] = glia.read_spyking_results(filename)
    else:
        raise ValueError(
            f'could not read {filename}. Is it a plexon or spyking circus file?')

    #### DATA MUNGING OPTIONS
    if integrity_filter > 0.0:
        good_units = solid.filter_units_by_accuracy(ctx.obj["units"],
                                                    ctx.obj['stimulus_list'],
                                                    integrity_filter)
        filter_good_units = glia.f_filter(lambda u, v: u in good_units)
        ctx.obj["units"] = filter_good_units(ctx.obj["units"])

    if by_channel:
        ctx.obj["units"] = glia.combine_units_by_channel(ctx.obj["units"])

    # prepare_output
    plot_directory = os.path.join(data_directory, name + "-plots")
    config.plot_directory = plot_directory

    os.makedirs(plot_directory, exist_ok=True)
    os.chmod(plot_directory, 0o777)

    if output == "pdf":
        logger.debug("Outputting pdf")
        ctx.obj["retina_pdf"] = PdfPages(
            glia.plot_pdf_path(plot_directory, "retina"))
        ctx.obj["unit_pdfs"] = glia.open_pdfs(plot_directory,
                                              list(ctx.obj["units"].keys()),
                                              Unit.name_lookup())
        # c connotes 'continuation' for continuation passing style
        ctx.obj["c_unit_fig"] = partial(glia.add_to_unit_pdfs,
                                        unit_pdfs=ctx.obj["unit_pdfs"])
        ctx.obj["c_retina_fig"] = lambda x: ctx.obj["retina_pdf"].savefig(x)

    elif output == "png":
        logger.debug("Outputting png")
        ctx.obj["c_unit_fig"] = glia.save_unit_fig
        ctx.obj["c_retina_fig"] = glia.save_retina_fig
        os.makedirs(os.path.join(plot_directory, "00-all"), exist_ok=True)

        for unit_id in ctx.obj["units"].keys():
            os.makedirs(os.path.join(plot_directory, unit_id), exist_ok=True)
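
The optional configuration file is plain YAML read with yaml.safe_load; a hypothetical example follows (every value is invented for illustration), limited to the keys the function checks for. The full dict is also kept on config.user_config for downstream use.

import yaml

# hypothetical configuration; analyze() looks for exactly these keys
user_config = yaml.safe_load("""
analog_calibration: 1.0
notebook: lab-notebook.yaml
eyecandy: http://localhost:3000
processes: 4
integrity_filter: 0.9
by_channel: false
""")
print(user_config["integrity_filter"])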
Example #3
def save_grating_npz(units,
                     stimulus_list,
                     name,
                     append,
                     group_by,
                     sinusoid=False):
    "Psychophysics discrimination grating 0.2.0"
    print("Saving grating NPZ file.")
    if sinusoid:
        stimulus_type = "SINUSOIDAL_GRATING"
    else:
        stimulus_type = 'GRATING'
    get_gratings = glia.compose(
        partial(glia.create_experiments,
                stimulus_list=stimulus_list,
                append_lifespan=append),
        glia.f_filter(lambda x: x['stimulusType'] == stimulus_type),
        partial(glia.group_by, key=group_by),
        glia.f_map(partial(glia.group_by, key=lambda x: x["width"])),
        glia.f_map(
            glia.f_map(
                partial(glia.group_by,
                        key=lambda x: x["metadata"]["cohort"]))))
    gratings = get_gratings(units)

    max_duration = 0.0
    for condition, sizes in gratings.items():
        for size, cohorts in sizes.items():
            for cohort, experiments in cohorts.items():
                max_duration = max(max_duration, experiments[0]['lifespan'])
    max_duration += append

    conditions = sorted(list(gratings.keys()))
    print("Conditions:", name, conditions)
    nconditions = len(conditions)
    example_condition = glia.get_value(gratings)
    sizes = sorted(list(example_condition.keys()))
    print("Sizes:", sizes)
    nsizes = len(sizes)

    example_size = glia.get_value(example_condition)
    ncohorts = len(example_size)
    # print(list(gratings.values()))
    d = int(np.ceil(max_duration * 1000))  # 1ms bins
    tvt = glia.tvt_by_percentage(ncohorts, 60, 40, 0)
    # 2 per cohort
    training_data = np.full((nconditions, nsizes, tvt.training * 2, d,
                             Unit.nrow, Unit.ncol, Unit.nunit),
                            0,
                            dtype='int8')
    training_target = np.full((nconditions, nsizes, tvt.training * 2),
                              0,
                              dtype='int8')
    validation_data = np.full((nconditions, nsizes, tvt.validation * 2, d,
                               Unit.nrow, Unit.ncol, Unit.nunit),
                              0,
                              dtype='int8')
    validation_target = np.full((nconditions, nsizes, tvt.validation * 2),
                                0,
                                dtype='int8')

    condition_map = {c: i for i, c in enumerate(conditions)}
    size_map = {s: i for i, s in enumerate(sizes)}
    for condition, sizes in gratings.items():
        for size, cohorts in sizes.items():
            X = glia.f_split_dict(tvt)(cohorts)

            td, tt = glia.experiments_to_ndarrays(glia.training_cohorts(X),
                                                  get_grating_class_from_stim,
                                                  append)
            missing_duration = d - td.shape[1]
            pad_td = np.pad(td, ((0, 0), (0, missing_duration), (0, 0), (0, 0),
                                 (0, 0)),
                            mode='constant')
            condition_index = condition_map[condition]
            size_index = size_map[size]
            training_data[condition_index, size_index] = pad_td
            training_target[condition_index, size_index] = tt

            td, tt = glia.experiments_to_ndarrays(glia.validation_cohorts(X),
                                                  get_grating_class_from_stim,
                                                  append)
            pad_td = np.pad(td, ((0, 0), (0, missing_duration), (0, 0), (0, 0),
                                 (0, 0)),
                            mode='constant')
            validation_data[condition_index, size_index] = pad_td
            validation_target[condition_index, size_index] = tt

    print('saving to ', name)
    np.savez(name,
             training_data=training_data,
             training_target=training_target,
             validation_data=validation_data,
             validation_target=validation_target)
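
Loading the archive back is plain NumPy; a short sketch (np.savez adds the .npz extension when name does not already end in it):

import numpy as np

archive = np.load(name + ".npz")
# arrays are indexed (condition, size, example, time_bin, row, col, unit)
training_data = archive["training_data"]
training_target = archive["training_target"]
validation_data = archive["validation_data"]
validation_target = archive["validation_target"]
print(training_data.shape, validation_data.shape)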
Example #4
def save_checkerboard_flicker_npz(units,
                                  stimulus_list,
                                  name,
                                  append,
                                  group_by,
                                  quad=False):
    "Psychophysics discrimination checkerboard 0.2.0"
    print("Saving checkerboard NPZ file.")

    get_checkers = glia.compose(
        partial(
            glia.create_experiments,
            progress=True,
            append_lifespan=append,
            # stimulus_list=stimulus_list,append_lifespan=0.5),
            stimulus_list=stimulus_list),
        partial(glia.group_by, key=lambda x: x["metadata"]["group"]),
        glia.group_dict_to_list,
        glia.f_filter(group_contains_checkerboard),
        glia.f_map(
            glia.f_filter(lambda x: x['stimulusType'] == 'CHECKERBOARD')),
        glia.f_map(glia.merge_experiments),
        partial(glia.group_by, key=group_by),
        glia.f_map(partial(glia.group_by, key=lambda x: x["size"])),
        glia.f_map(
            glia.f_map(
                partial(glia.group_by,
                        key=lambda x: x["metadata"]["cohort"]))))
    checkers = get_checkers(units)

    max_duration = 0.0
    for condition, sizes in checkers.items():
        for size, cohorts in sizes.items():
            for cohort, experiments in cohorts.items():
                max_duration = max(max_duration, experiments[0]['lifespan'])
    max_duration += append
    print(f"max_duration: {max_duration}")

    conditions = sorted(list(checkers.keys()))
    print("Conditions:", name, conditions)
    nconditions = len(conditions)
    example_condition = glia.get_value(checkers)
    sizes = sorted(list(example_condition.keys()))
    nsizes = len(sizes)
    # TODO remove
    if max_duration < 9:
        print(example_condition)

    example_size = glia.get_value(example_condition)
    ncohorts = len(example_size)
    # print(list(checkers.values()))
    d = int(np.ceil(max_duration * 1000))  # 1ms bins

    tvt = glia.tvt_by_percentage(ncohorts, 60, 40, 0)
    logger.info(f"{tvt}, {ncohorts}")
    # (TODO?) 2 dims for first checkerboard and second checkerboard
    # 4 per cohort
    if quad:
        ntraining = tvt.training * 4
        nvalid = tvt.validation * 4
    else:
        ntraining = tvt.training * 2
        nvalid = tvt.validation * 2

    training_data = np.full(
        (nconditions, nsizes, ntraining, d, Unit.nrow, Unit.ncol, Unit.nunit),
        0,
        dtype='int8')
    training_target = np.full((nconditions, nsizes, ntraining),
                              0,
                              dtype='int8')
    validation_data = np.full(
        (nconditions, nsizes, nvalid, d, Unit.nrow, Unit.ncol, Unit.nunit),
        0,
        dtype='int8')
    validation_target = np.full((nconditions, nsizes, nvalid), 0, dtype='int8')
    # test_data = np.full((nsizes,tvt.test,d,nunits),0,dtype='int8')
    # test_target = np.full((nsizes,tvt.test),0,dtype='int8')

    if quad:
        get_class = get_checker_quad_discrimination_class
    else:
        get_class = get_checker_discrimination_class
    condition_map = {c: i for i, c in enumerate(conditions)}
    size_map = {s: i for i, s in enumerate(sizes)}
    for condition, sizes in checkers.items():
        for size, cohorts in sizes.items():
            X = glia.f_split_dict(tvt)(cohorts)

            td, tt = glia.experiments_to_ndarrays(glia.training_cohorts(X),
                                                  get_class, append)
            logger.info(td.shape)
            missing_duration = d - td.shape[1]
            pad_td = np.pad(td, ((0, 0), (0, missing_duration), (0, 0), (0, 0),
                                 (0, 0)),
                            mode='constant')
            condition_index = condition_map[condition]
            size_index = size_map[size]
            training_data[condition_index, size_index] = pad_td
            training_target[condition_index, size_index] = tt

            td, tt = glia.experiments_to_ndarrays(glia.validation_cohorts(X),
                                                  get_class, append)
            pad_td = np.pad(td, ((0, 0), (0, missing_duration), (0, 0), (0, 0),
                                 (0, 0)),
                            mode='constant')
            validation_data[condition_index, size_index] = pad_td
            validation_target[condition_index, size_index] = tt

    print('saving to ', name)
    np.savez(name,
             training_data=training_data,
             training_target=training_target,
             validation_data=validation_data,
             validation_target=validation_target)
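
The np.pad calls above extend only the time axis (axis 1) so every block reaches the common length d; a standalone NumPy illustration with toy dimensions:

import numpy as np

# toy block: 4 examples, 900 one-ms bins, 8 rows, 8 cols, 2 units per channel
td = np.zeros((4, 900, 8, 8, 2), dtype='int8')
d = 1000  # target number of 1 ms bins
missing_duration = d - td.shape[1]
pad_td = np.pad(td, ((0, 0), (0, missing_duration), (0, 0), (0, 0), (0, 0)),
                mode='constant')  # zero-pad the end of the time axis
assert pad_td.shape == (4, 1000, 8, 8, 2)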
Example #5
def save_images_h5(units, stimulus_list, name, frame_log, video_file, append):
    """Assumes each group is three stimuli with image in second position.
    
    Concatenate second stimuli with first 0.5s of third stimuli"""
    # open first so if there's a problem we don't waste time
    compression_level = 3
    dset_filter = tables.filters.Filters(complevel=compression_level,
                                         complib='blosc:zstd')
    with tables.open_file(name + ".h5", 'w') as h5:
        class_resolver = get_classes_from_stimulus_list(stimulus_list)
        nclasses = len(class_resolver)
        frames, image_classes = glia.get_images_from_vid(
            stimulus_list, frame_log, video_file)

        image_class_num = list(
            map(lambda x: class_resolver[str(x)], image_classes))
        idx_sorted_order = np.argsort(image_class_num)

        # save mapping of class_num target to class metadata
        # this way h5.root.image_classes[n] will give the class metadata string
        logger.info("create class_resolver with max string of 256")
        resolver = h5.create_carray(h5.root, "image_classes",
                                    tables.StringAtom(itemsize=256),
                                    (nclasses, ))
        img_class_array = np.array(image_classes,
                                   dtype="S256")[idx_sorted_order]
        for i, image_class in enumerate(img_class_array):
            resolver[i] = image_class

        atom = tables.Atom.from_dtype(frames[0].dtype)
        images = h5.create_carray(h5.root,
                                  "images",
                                  atom, (nclasses, *frames[0].shape),
                                  filters=dset_filter)

        frames = np.array(frames)
        nFrames = len(frames)
        for i, idx in enumerate(idx_sorted_order):
            if idx >= nFrames:
                logger.warning(
                    f"skipping class {image_classes[idx]} as no accompanying frame. This should only occur if experiment stopped early."
                )
                continue
            images[i] = frames[idx]

        print("finished saving images")
        get_image_responses = glia.compose(
            # returns a list
            partial(glia.create_experiments,
                    stimulus_list=stimulus_list,
                    progress=True,
                    append_lifespan=append),
            partial(glia.group_by, key=lambda x: x["metadata"]["group"]),
            glia.group_dict_to_list,
            glia.f_filter(partial(glia.group_contains, "IMAGE")),
            # truncate to 0.5s
            glia.f_map(lambda x: [x[1], truncate(x[2], 0.5)]),
            glia.f_map(glia.merge_experiments),
            partial(glia.group_by, key=lambda x: x["metadata"]["cohort"]),
            # glia.f_map(f_flatten)
        )

        image_responses = get_image_responses(units)
        ncohorts = len(image_responses)
        ex_cohort = glia.get_value(image_responses)
        images_per_cohort = len(ex_cohort)
        print("images_per_cohort", images_per_cohort)
        duration = ex_cohort[0]["lifespan"]

        d = int(np.ceil(duration * 1000))  # 1ms bins
        logger.info(f"ncohorts: {ncohorts}")
        # import pdb; pdb.set_trace()

        logger.info(f"nclasses: {nclasses}")
        if nclasses < 256:
            class_dtype = np.dtype('uint8')
        else:
            class_dtype = np.dtype('uint16')

        class_resolver_func = lambda c: class_resolver[str(c)]

        # determine shape
        experiments = glia.flatten_group_dict(image_responses)
        nE = len(experiments)
        data_shape = (nE, d, Unit.nrow, Unit.ncol, Unit.nunit)

        print(f"writing to {name}.h5 with zstd compression...")
        data = h5.create_carray("/",
                                "data",
                                tables.Atom.from_dtype(np.dtype('uint8')),
                                shape=data_shape,
                                filters=dset_filter)
        target = h5.create_carray("/",
                                  "target",
                                  tables.Atom.from_dtype(class_dtype),
                                  shape=(nE, ),
                                  filters=dset_filter)

        glia.experiments_to_h5(experiments,
                               data,
                               target,
                               partial(get_image_class_from_stim,
                                       class_resolver=class_resolver_func),
                               append,
                               class_dtype=class_dtype)
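
Reading the resulting HDF5 back with PyTables might look like the sketch below; the array names match the create_carray calls above:

import tables

with tables.open_file(name + ".h5", 'r') as h5:
    image_classes = h5.root.image_classes[:]  # class metadata strings (bytes)
    images = h5.root.images[:]  # one image per class, in class-index order
    data = h5.root.data[:]  # (nE, d, Unit.nrow, Unit.ncol, Unit.nunit)
    target = h5.root.target[:]  # class index for each experiment
print(data.shape, target.shape)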
Example #6
def save_acuity_image_npz(units, stimulus_list, name, append):
    "Assumes metadata includes a parameter to group by, as well as a blank image"

    get_letters = glia.compose(
        partial(glia.create_experiments,
                stimulus_list=stimulus_list,
                progress=True,
                append_lifespan=append),
        partial(glia.group_by, key=lambda x: x["metadata"]["group"]),
        glia.group_dict_to_list,
        glia.f_filter(partial(glia.group_contains, "IMAGE")),
        glia.f_map(lambda x: x[0:2]),
        partial(glia.group_by, key=lambda x: x[1]["metadata"]["parameter"]),
        glia.f_map(
            partial(glia.group_by, key=lambda x: x[1]["metadata"]["cohort"])),
        glia.f_map(glia.f_map(f_flatten)),
        glia.f_map(glia.f_map(partial(balance_blanks, key='image'))))
    letters = get_letters(units)
    sizes = sorted(list(letters.keys()))
    nsizes = len(sizes)
    ncohorts = len(list(letters.values())[0])
    ex_letters = glia.get_value(list(letters.values())[0])
    nletters = len(ex_letters)
    print("nletters", nletters)
    duration = ex_letters[0]["lifespan"]

    # small hack to fix bug in letters 0.2.0
    letter_duration = ex_letters[1]['lifespan']
    if duration != letter_duration:
        new_letters = {}
        for size, cohorts in letters.items():
            new_letters[size] = {}
            for cohort, stimuli in cohorts.items():
                new_letters[size][cohort] = list(
                    map(lambda s: truncate(s, letter_duration), stimuli))
        letters = new_letters

    d = int(np.ceil(duration * 1000))  # 1ms bins
    nunits = len(units.keys())
    tvt = glia.tvt_by_percentage(ncohorts, 60, 40, 0)
    logger.info(f"{tvt}, ncohorts: {ncohorts}")

    experiments_per_cohort = 11
    training_data = np.full((nsizes, tvt.training * experiments_per_cohort, d,
                             Unit.nrow, Unit.ncol, Unit.nunit),
                            0,
                            dtype='int8')
    training_target = np.full((nsizes, tvt.training * experiments_per_cohort),
                              0,
                              dtype='int8')
    validation_data = np.full((nsizes, tvt.validation * experiments_per_cohort,
                               d, Unit.nrow, Unit.ncol, Unit.nunit),
                              0,
                              dtype='int8')
    validation_target = np.full(
        (nsizes, tvt.validation * experiments_per_cohort), 0, dtype='int8')

    size_map = {s: i for i, s in enumerate(sizes)}
    for size, cohorts in letters.items():
        X = glia.f_split_dict(tvt)(cohorts)
        logger.info(f"ncohorts: {len(cohorts)}")
        td, tt = glia.experiments_to_ndarrays(glia.training_cohorts(X),
                                              acuity_image_class, append)
        logger.info(td.shape)
        missing_duration = d - td.shape[1]
        pad_td = np.pad(td, ((0, 0), (0, missing_duration), (0, 0), (0, 0),
                             (0, 0)),
                        mode='constant')
        size_index = size_map[size]
        training_data[size_index] = pad_td
        training_target[size_index] = tt

        td, tt = glia.experiments_to_ndarrays(glia.validation_cohorts(X),
                                              acuity_image_class, append)
        pad_td = np.pad(td, ((0, 0), (0, missing_duration), (0, 0), (0, 0),
                             (0, 0)),
                        mode='constant')
        validation_data[size_index] = pad_td
        validation_target[size_index] = tt

    np.savez(name,
             training_data=training_data,
             training_target=training_target,
             validation_data=validation_data,
             validation_target=validation_target)
Example #7
def save_letter_npz(units, stimulus_list, name, append):
    print(
        "Saving letter NPZ file. Warning: not including Off response; performance can be improved!"
    )
    # TODO use merge_experiment
    # TODO add TEST!!!
    get_letters = glia.compose(
        partial(glia.create_experiments,
                stimulus_list=stimulus_list,
                progress=True,
                append_lifespan=append),
        partial(glia.group_by, key=lambda x: x["metadata"]["group"]),
        glia.group_dict_to_list, glia.f_filter(group_contains_letter),
        glia.f_map(lambda x: x[0:2]),
        partial(glia.group_by, key=lambda x: x[1]["size"]),
        glia.f_map(
            partial(glia.group_by, key=lambda x: x[1]["metadata"]["cohort"])),
        glia.f_map(glia.f_map(f_flatten)),
        glia.f_map(glia.f_map(balance_blanks)))
    letters = get_letters(units)
    sizes = sorted(list(letters.keys()))
    nsizes = len(sizes)
    ncohorts = len(list(letters.values())[0])
    ex_letters = glia.get_value(list(letters.values())[0])
    nletters = len(ex_letters)
    print("nletters", nletters)
    duration = ex_letters[0]["lifespan"]

    d = int(np.ceil(duration * 1000))  # 1ms bins
    nunits = len(units.keys())
    tvt = glia.tvt_by_percentage(ncohorts, 60, 40, 0)
    logger.info(f"{tvt}, ncohorts: {ncohorts}")

    experiments_per_cohort = 11
    training_data = np.full((nsizes, tvt.training * experiments_per_cohort, d,
                             Unit.nrow, Unit.ncol, Unit.nunit),
                            0,
                            dtype='int8')
    training_target = np.full((nsizes, tvt.training * experiments_per_cohort),
                              0,
                              dtype='int8')
    validation_data = np.full((nsizes, tvt.validation * experiments_per_cohort,
                               d, Unit.nrow, Unit.ncol, Unit.nunit),
                              0,
                              dtype='int8')
    validation_target = np.full(
        (nsizes, tvt.validation * experiments_per_cohort), 0, dtype='int8')

    size_map = {s: i for i, s in enumerate(sizes)}
    for size, cohorts in letters.items():
        X = glia.f_split_dict(tvt)(cohorts)
        logger.info(f"ncohorts: {len(cohorts)}")
        td, tt = glia.experiments_to_ndarrays(glia.training_cohorts(X),
                                              letter_class, append)
        logger.info(td.shape)
        missing_duration = d - td.shape[1]
        pad_td = np.pad(td, ((0, 0), (0, missing_duration), (0, 0), (0, 0),
                             (0, 0)),
                        mode='constant')
        size_index = size_map[size]
        training_data[size_index] = pad_td
        training_target[size_index] = tt

        td, tt = glia.experiments_to_ndarrays(glia.validation_cohorts(X),
                                              letter_class, append)
        pad_td = np.pad(td, ((0, 0), (0, missing_duration), (0, 0), (0, 0),
                             (0, 0)),
                        mode='constant')
        validation_data[size_index] = pad_td
        validation_target[size_index] = tt

    np.savez(name,
             training_data=training_data,
             training_target=training_target,
             validation_data=validation_data,
             validation_target=validation_target)
Example #8
def save_eyechart_npz(units, stimulus_list, name, append=0.5):
    print("Saving eyechart NPZ file.")

    # TODO add blanks
    get_letters = glia.compose(
        partial(glia.create_experiments,
                stimulus_list=stimulus_list,
                append_lifespan=append),
        partial(glia.group_by, key=lambda x: x["metadata"]["group"]),
        glia.group_dict_to_list,
        glia.f_filter(group_contains_letter),
        glia.f_map(lambda x: adjust_lifespan(x[1])),
        partial(glia.group_by, key=lambda x: x["size"]),
        glia.f_map(
            partial(glia.group_by,
                    key=lambda x: x[1]["stimulus"]["metadata"]["cohort"])),
        glia.f_map(glia.f_map(f_flatten)),
    )
    letters = get_letters(units)
    sizes = sorted(list(letters.keys()))
    nsizes = len(sizes)
    ncohorts = len(list(letters.values())[0])
    ex_letters = glia.get_value(list(letters.values())[0])
    nletters = len(ex_letters)
    print("nletters", nletters)
    duration = ex_letters[0]["lifespan"]
    d = int(np.ceil(duration * 1000))  # 1ms bins
    nunits = len(units.keys())
    tvt = glia.tvt_by_percentage(ncohorts, 60, 40, 0)
    logger.info(f"{tvt}, {ncohorts}")
    training_data = np.full((nsizes, tvt.training, d, nunits), 0, dtype='int8')
    training_target = np.full((nsizes, tvt.training), 0, dtype='int8')
    validation_data = np.full((nsizes, tvt.validation, d, nunits),
                              0,
                              dtype='int8')
    validation_target = np.full((nsizes, tvt.validation), 0, dtype='int8')
    test_data = np.full((nsizes, tvt.test, d, nunits), 0, dtype='int8')
    test_target = np.full((nsizes, tvt.test), 0, dtype='int8')

    size_map = {s: i for i, s in enumerate(sizes)}
    for size, experiments in letters.items():
        split = glia.f_split_dict(tvt)
        flatten_cohort = glia.compose(glia.group_dict_to_list, f_flatten)
        X = glia.tvt_map(split(experiments), flatten_cohort)

        td, tt = glia.experiments_to_ndarrays(X.training, letter_class, append)
        size_index = size_map[size]
        training_data[size_index] = td
        training_target[size_index] = tt

        td, tt = glia.experiments_to_ndarrays(X.validation, letter_class,
                                              append)
        validation_data[size_index] = td
        validation_target[size_index] = tt

        td, tt = glia.experiments_to_ndarrays(X.test, letter_class, append)
        test_data[size_index] = td
        test_target[size_index] = tt

    np.savez(name,
             training_data=training_data,
             training_target=training_target,
             validation_data=validation_data,
             validation_target=validation_target)