def listen_to(candidate, stimulus_set, reset_column='story', average_sentence=True):
    """
    Pass a `stimulus_set` through a model `candidate`.
    Operates on a sentence-based `stimulus_set`.

    The stimuli are split by the distinct values of `reset_column` and every part is passed to
    `candidate` in a separate call; the per-part activations are then merged.

    :param candidate: callable invoked as `candidate(stimuli=..., average_sentence=...)`;
        its return value must carry a `stimulus_id` coordinate along `presentation`.
    :param stimulus_set: stimuli table with a `name` attribute and a `reset_column` column.
    :param reset_column: column whose distinct values delimit the separate `candidate` calls.
    :param average_sentence: forwarded to `candidate` unchanged.
    :return: merged activations, ordered along `presentation` like the concatenated per-part activations.
    """
    activations = []
    for story in ordered_set(stimulus_set[reset_column].values):
        story_stimuli = stimulus_set[stimulus_set[reset_column] == story]
        story_stimuli.name = f"{stimulus_set.name}-{story}"
        story_activations = candidate(stimuli=story_stimuli, average_sentence=average_sentence)
        activations.append(story_activations)
    model_activations = merge_data_arrays(activations)
    # merging does not maintain stimulus order. the following orders again.
    # A lookup table of first occurrences replaces a `list.index` call per stimulus
    # (quadratic in the number of stimuli) while yielding the same positions.
    first_position = {}
    for position, stimulus_id in enumerate(model_activations['stimulus_id'].values.tolist()):
        first_position.setdefault(stimulus_id, position)
    try:
        idx = [first_position[stimulus_id] for stimulus_id in itertools.chain.from_iterable(
            s['stimulus_id'].values.tolist() for s in activations)]
    except KeyError as missing:
        # keep the exception type that the previous `list.index` lookup raised
        raise ValueError(f"{missing.args[0]!r} is not in list") from None
    assert len(set(idx)) == len(idx), "Found duplicate indices to order activations"
    model_activations = model_activations[{'presentation': idx}]
    return model_activations
def read_words(candidate, stimulus_set, reset_column='sentence_id', copy_columns=(), average_sentence=False):
    """
    Pass a `stimulus_set` through a model `candidate`.
    In contrast to the `listen_to` function, this function operates on a word-based `stimulus_set`.

    The rows sharing one `reset_column` value are joined (column `word`, separated by spaces) into a
    single sentence, which is passed to `candidate` in its own call.

    :param candidate: callable invoked as `candidate(stimuli=..., average_sentence=...)`;
        its return value must carry a `stimulus_id` coordinate along `presentation`.
    :param stimulus_set: table with a `name` attribute, a `word` column and a `reset_column` column
        (one row per word).
    :param reset_column: column whose distinct values delimit the separate `candidate` calls.
    :param copy_columns: columns of `stimulus_set` that are attached to the activations of each part
        as coordinates along `presentation`.
    :param average_sentence: forwarded to `candidate` unchanged.
    :return: merged activations, ordered along `presentation` like the concatenated per-part activations.
    """
    activations = []
    for reset_id in ordered_set(stimulus_set[reset_column].values):
        part_stimuli = stimulus_set[stimulus_set[reset_column] == reset_id]
        sentence_stimuli = StimulusSet({'sentence': ' '.join(part_stimuli['word']),
                                        reset_column: list(set(part_stimuli[reset_column]))})
        sentence_stimuli.name = f"{stimulus_set.name}-{reset_id}"
        sentence_activations = candidate(stimuli=sentence_stimuli, average_sentence=average_sentence)
        for column in copy_columns:
            sentence_activations[column] = ('presentation', part_stimuli[column])
        activations.append(sentence_activations)
    model_activations = merge_data_arrays(activations)
    # merging does not maintain stimulus order. the following orders again.
    # A lookup table of first occurrences replaces a `list.index` call per stimulus
    # (quadratic in the number of stimuli) while yielding the same positions.
    first_position = {}
    for position, stimulus_id in enumerate(model_activations['stimulus_id'].values.tolist()):
        first_position.setdefault(stimulus_id, position)
    try:
        idx = [first_position[stimulus_id] for stimulus_id in itertools.chain.from_iterable(
            s['stimulus_id'].values.tolist() for s in activations)]
    except KeyError as missing:
        # keep the exception type that the previous `list.index` lookup raised
        raise ValueError(f"{missing.args[0]!r} is not in list") from None
    assert len(set(idx)) == len(idx), "Found duplicate indices to order activations"
    model_activations = model_activations[{'presentation': idx}]
    return model_activations
def read_words(candidate, stimulus_set):
    """
    Pass a word-based `stimulus_set` through a model `candidate`, one sentence per call.

    This is a new version of the listen_to_stories function.
    Input: stimulus_set = pandas df, col 1 with sentence ID and 2nd col as word.

    Every presentation of the returned activations receives a running integer `stimulus_id` and the
    `sentence_id` of the sentence it belongs to. The number of presentations per sentence is read
    from the activations returned by `candidate` instead of being fixed at 8, so sentences of any
    length are supported; for sentences of exactly 8 presentations the ids are unchanged.

    :param candidate: callable invoked as `candidate(stimuli=...)`; its return value must have a
        `presentation` dimension.
    :param stimulus_set: table with a `name` attribute and the columns `sentence_id` and `word`.
    :return: merged activations, ordered along `presentation` like the concatenated per-sentence activations.
    """
    activations = []
    next_stimulus_id = 0  # id of the first presentation of the upcoming sentence
    for sentence_id in ordered_set(stimulus_set['sentence_id'].values):
        sentence_words = stimulus_set[stimulus_set['sentence_id'] == sentence_id]
        sentence_stimuli = StimulusSet({'sentence': ' '.join(sentence_words['word']),
                                        'sentence_id': list(set(sentence_words['sentence_id']))})
        sentence_stimuli.name = f"{stimulus_set.name}-{sentence_id}"
        sentence_activations = candidate(stimuli=sentence_stimuli)
        # the new coordinates have to match the length of the `presentation` dimension
        num_presentations = len(sentence_activations['presentation'])
        sentence_activations['stimulus_id'] = ('presentation', next_stimulus_id + np.arange(0, num_presentations))
        sentence_activations['sentence_id'] = ('presentation', [sentence_id] * num_presentations)
        next_stimulus_id += num_presentations
        activations.append(sentence_activations)
    model_activations = merge_data_arrays(activations)
    # merging does not maintain stimulus order. the following orders again.
    # A lookup table of first occurrences replaces a `list.index` call per stimulus
    # (quadratic in the number of stimuli) while yielding the same positions.
    first_position = {}
    for position, stimulus_id in enumerate(model_activations['stimulus_id'].values.tolist()):
        first_position.setdefault(stimulus_id, position)
    try:
        idx = [first_position[stimulus_id] for stimulus_id in itertools.chain.from_iterable(
            s['stimulus_id'].values.tolist() for s in activations)]
    except KeyError as missing:
        # keep the exception type that the previous `list.index` lookup raised
        raise ValueError(f"{missing.args[0]!r} is not in list") from None
    assert len(set(idx)) == len(idx), "Found duplicate indices to order activations"
    model_activations = model_activations[{'presentation': idx}]
    return model_activations
def _align_stimuli_recordings(stimulus_set, assembly):
    """
    Align the partial sentences recorded in `assembly` with the full story text in `stimulus_set`.

    For every story, the `stimulus_sentence` values of the assembly (passed through `compare_ignore`)
    are matched character by character against the story's sentences joined by spaces. Each partial
    sentence is thereby re-expressed with the exact characters of the story text.

    :param stimulus_set: table with the columns `story` and `sentence` and a `name` attribute.
    :param assembly: data assembly with the coordinates `story` and `stimulus_sentence`.
    :return: tuple `(aligned_stimulus_set, assembly)`. The aligned stimulus set has one row per partial
        sentence (`story`, `sentence`, `sentence_num`, `sentence_part`, `stimulus_id`). The returned
        assembly is rebuilt with the same type, values and dims; it gains `stimulus_id`, keeps the
        original partial sentence as `meta_sentence` and carries the aligned text as `stimulus_sentence`.
    :raises NotImplementedError: if a story character can neither be matched nor skipped.

    NOTE(review): `compare_ignore` and `compare_characters` are defined outside this block; presumably
    `compare_characters` lists the characters that `compare_ignore` removes -- confirm.
    """
    aligned_stimulus_set = []
    partial_sentences = assembly['stimulus_sentence'].values
    partial_sentences = [
        compare_ignore(sentence) for sentence in partial_sentences
    ]
    # maps the position of a partial sentence in the assembly to its row in the aligned stimulus set
    assembly_stimset = {}
    stimulus_set_index = 0
    stories = ordered_set(assembly['story'].values.tolist())
    for story in tqdm(sorted(stories), desc='align stimuli',
                      total=len(stories)):
        # partial sentences of this story, paired with their position in the assembly
        story_partial_sentences = [
            (sentence, i) for i, (sentence, sentence_story) in enumerate(
                zip(partial_sentences, assembly['story'].values))
            if sentence_story == story
        ]
        story_stimuli = stimulus_set[stimulus_set['story'] == story]
        stimuli_story = ' '.join(story_stimuli['sentence'])
        # cumulative character offsets of the sentence starts within `stimuli_story`;
        # the `+ 1` accounts for the joining space
        stimuli_story_sentence_starts = [0] + [
            len(sentence) + 1 for sentence in story_stimuli['sentence']
        ]
        stimuli_story_sentence_starts = np.cumsum(
            stimuli_story_sentence_starts)
        # the recordings have to cover the whole story text, up to what `compare_ignore` changes
        assert ' '.join(s for s, i in story_partial_sentences) == compare_ignore(
            stimuli_story)
        # read position in `stimuli_story`; carried over from one partial sentence to the next
        stimulus_index = 0
        Stimulus = namedtuple(
            'Stimulus', ['story', 'sentence', 'sentence_num', 'sentence_part'])
        # number of parts already emitted per sentence number
        sentence_parts = defaultdict(lambda: 0)
        for partial_sentence, assembly_index in story_partial_sentences:
            full_partial_sentence = ''
            partial_sentence_index = 0
            # Continue while the partial sentence has unmatched characters, or while the next story
            # character is one of `compare_characters` or a space (`and` binds tighter than `or`).
            while partial_sentence_index < len(partial_sentence) \
                    or stimulus_index < len(stimuli_story) \
                    and stimuli_story[stimulus_index] in compare_characters + [' ']:
                # characters agree (case-insensitive): advance both read positions
                if partial_sentence_index < len(partial_sentence) \
                        and partial_sentence[partial_sentence_index].lower() \
                        == stimuli_story[stimulus_index].lower():
                    full_partial_sentence += stimuli_story[stimulus_index]
                    stimulus_index += 1
                    partial_sentence_index += 1
                elif stimuli_story[stimulus_index] in compare_characters + [
                        ' '
                ]:
                    # this case leads to a potential issue: Beginning quotations ' are always appended to
                    # the current instead of the next sentence. For now, I'm hoping this won't lead to issues.
                    full_partial_sentence += stimuli_story[stimulus_index]
                    stimulus_index += 1
                # a dash in the story text: copy it and skip a space at the same position of the
                # partial sentence, if there is one
                elif stimuli_story[stimulus_index] == '-':
                    full_partial_sentence += '-'
                    stimulus_index += 1
                    if partial_sentence[partial_sentence_index] == ' ':
                        partial_sentence_index += 1
                else:
                    raise NotImplementedError()
            # index of the first sentence start at or after the read position, minus one
            sentence_num = next(
                index
                for index, start in enumerate(stimuli_story_sentence_starts)
                if start >= stimulus_index) - 1
            sentence_part = sentence_parts[sentence_num]
            sentence_parts[sentence_num] += 1
            row = Stimulus(sentence=full_partial_sentence,
                           story=story,
                           sentence_num=sentence_num,
                           sentence_part=sentence_part)
            aligned_stimulus_set.append(row)
            assembly_stimset[assembly_index] = stimulus_set_index
            stimulus_set_index += 1
        # check: the aligned parts of this story concatenate to exactly the story text
        aligned_story = "".join(row.sentence for row in aligned_stimulus_set
                                if row.story == story)
        assert aligned_story == stimuli_story
    # build StimulusSet
    aligned_stimulus_set = StimulusSet(aligned_stimulus_set)
    # stimulus ids have the form "<story>.<sentence_num>.<sentence_part>"
    aligned_stimulus_set['stimulus_id'] = [
        ".".join([str(value) for value in values]) for values in zip(*[
            aligned_stimulus_set[coord].values
            for coord in ['story', 'sentence_num', 'sentence_part']
        ])
    ]
    aligned_stimulus_set.name = f"{stimulus_set.name}-aligned"
    # align assembly: stimulus-set rows in the order of the assembly positions
    alignment = [
        stimset_idx for assembly_idx, stimset_idx in
        sorted(assembly_stimset.items(), key=operator.itemgetter(0))
    ]
    # keep all existing coordinates; the entries of the second dict override same-named ones
    assembly_coords = {
        **{
            coord: (dims, values)
            for coord, dims, values in walk_coords(assembly)
        },
        **{
            'stimulus_id':
            ('presentation',
             aligned_stimulus_set['stimulus_id'].values[alignment]),
            'meta_sentence':
            ('presentation', assembly['stimulus_sentence'].values),
            'stimulus_sentence':
            ('presentation',
             aligned_stimulus_set['sentence'].values[alignment])
        }
    }
    assembly = type(assembly)(assembly.values,
                              coords=assembly_coords,
                              dims=assembly.dims)
    return aligned_stimulus_set, assembly
def _merge_voxel_meta(data, meta, bold_shift_seconds):
    """
    Annotate the recorded timepoints in `data` with the words from `meta` that ended before them.

    Per story, every timepoint (shifted back by `bold_shift_seconds`) is assigned the words whose
    `time_end` lies after the last previously assigned word and at or before the timepoint.
    Timepoints without any words are discarded and the remaining ones are re-labelled as
    `presentation`s carrying a `stimulus_sentence` coordinate.

    :param data: recordings with a `story` coordinate and a `timepoint_value` coordinate.
    :param meta: word annotations with a `story` coordinate, a `time_end` coordinate and a `time_bin`
        dimension; the values are the words.
    :param bold_shift_seconds: amount subtracted from every timepoint before words are assigned.
    :return: the per-story results combined by `merge_data_arrays`.

    Stories present in only one of `data` and `meta` trigger a warning; stories missing from `meta`
    are skipped.
    """
    data_missing = set(meta['story'].values) - set(data['story'].values)
    if data_missing:
        warnings.warn(f"Stories missing from the data: {data_missing}")
    meta_missing = set(data['story'].values) - set(meta['story'].values)
    if meta_missing:
        warnings.warn(f"Stories missing from the meta: {meta_missing}")

    ignored_words = [None, '', '<s>', '</s>', '<s']
    annotated_data = []
    for story in tqdm(ordered_set(data['story'].values), desc='merge meta'):
        if story not in meta['story'].values:
            continue
        story_meta = meta.sel(story=story)
        story_meta = story_meta.sortby('time_end')
        story_data = data.sel(story=story).stack(timepoint=['timepoint_value'])
        story_data = story_data.sortby('timepoint_value')
        timepoints = story_data['timepoint_value'].values.tolist()
        assert is_sorted(timepoints)
        timepoints = [timepoint - bold_shift_seconds for timepoint in timepoints]

        # invariant over the timepoints of one story, so computed once
        word_ends = story_meta['time_end'].values
        last_word_end = max(word_ends)

        sentences = []
        last_timepoint = -np.inf
        for timepoint in timepoints:
            if last_timepoint >= last_word_end:
                break  # all words have been assigned
            if timepoint <= 0:
                sentences.append(None)
                continue  # ignore fixation period
            in_time_bin = [last_timepoint < end <= timepoint for end in word_ends]
            timebin_meta = story_meta[{'time_bin': in_time_bin}]
            sentence = ' '.join(word.strip() for word in timebin_meta.values if word not in ignored_words)
            sentence = sentence.lower().strip()
            # quick-fixes
            if story == 'Boar' and sentence == 'interactions the the':  # Boar duplicate
                sentence = 'interactions the'
            if story == 'KingOfBirds' and sentence == 'the fact that the larger':  # missing word in TextGrid
                sentence = 'earth ' + sentence
            if story == 'MrSticky' and sentence == 'worry don\'t worry i went extra slowly since it\'s':
                sentence = 'don\'t worry i went extra slowly since it\'s'
            sentences.append(sentence)
            # A bin in which no word ended yields an empty sentence (discarded below) and must not
            # move `last_timepoint`; indexing the empty selection with [-1] would raise IndexError.
            if any(in_time_bin):
                last_timepoint = timebin_meta['time_end'].values[-1]

        # positions of the timepoints that received at least one word
        sentence_index = [i for i, sentence in enumerate(sentences) if sentence]
        sentences = np.array(sentences)[sentence_index]
        if story not in ['Boar', 'KingOfBirds', 'MrSticky']:  # ignore quick-fixes
            # no word may be lost or duplicated by the binning
            annotated_sentence = ' '.join(sentences)
            meta_sentence = ' '.join(word.strip() for word in story_meta.values if word not in ignored_words) \
                .lower().strip()
            assert annotated_sentence == meta_sentence

        # re-interpret timepoints as stimuli
        coords = {}
        for coord_name, dims, coord_value in walk_coords(story_data):
            dims = [dim if not dim.startswith('timepoint') else 'presentation' for dim in dims]
            # discard the timepoints for which the stimulus did not change (empty word)
            coord_value = coord_value if not array_is_element(dims, 'presentation') \
                else coord_value[sentence_index]
            coords[coord_name] = dims, coord_value
        coords = {**coords, **{'stimulus_sentence': ('presentation', sentences)}}
        story_data = story_data[{dim: slice(None) if dim != 'timepoint' else sentence_index
                                 for dim in story_data.dims}]
        dims = [dim if not dim.startswith('timepoint') else 'presentation' for dim in story_data.dims]
        story_data = xr.DataArray(story_data.values, coords=coords, dims=dims)
        story_data['story'] = 'presentation', [story] * len(story_data['presentation'])
        gather_indexes(story_data)
        annotated_data.append(story_data)
    annotated_data = merge_data_arrays(annotated_data)
    return annotated_data
def num_features_vs_score(benchmark='Pereira2018-encoding', per_layer=True, include_untrained=True):
    """
    Plot the number of features of each model (or model layer) against its score on `benchmark`.

    The feature counts are read from `num_features.csv` next to this file if that file exists;
    otherwise they are computed from stored activations and written to that file.
    The figure is saved next to this file as `num_features-<benchmark>`, with the suffix
    `-layerwise` when `per_layer` is set.

    :param benchmark: benchmark identifier passed to `collect_scores`; also used as the y-axis label.
    :param per_layer: if True, use one data point per model layer; otherwise scores are reduced with
        `choose_best_scores` and all features of a model are counted together.
    :param include_untrained: if True, the `<model>-untrained` variant of every model is included.

    NOTE(review): the cached csv does not depend on `benchmark`, `per_layer` or `include_untrained`;
    a file written under different settings is reused as-is -- confirm that this is intended.
    """
    if include_untrained:
        all_models = [(model, f"{model}-untrained") for model in models]
        all_models = [model for model_tuple in all_models for model in model_tuple]
    else:
        all_models = models
    scores = collect_scores(benchmark=benchmark, models=all_models)
    scores = average_adjacent(scores)
    scores = scores.dropna()
    if not per_layer:
        scores = choose_best_scores(scores)
    # count number of features
    store_file = Path(__file__).parent / "num_features.csv"
    if store_file.is_file():
        num_features = pd.read_csv(store_file)
    else:
        num_features = []
        for model in tqdm(ordered_set(scores['model'].values), desc='models'):
            # untrained variants share layers and stored activations identifier with the trained model
            base_model = model.replace('-untrained', '')
            # mock-run stimuli that are already stored
            mock_extractor = ActivationsExtractorHelper(get_activations=None, reset=None)
            features = mock_extractor._from_sentences_stored(
                layers=model_layers[base_model], sentences=None,
                identifier=base_model, stimuli_identifier='Pereira2018-243sentences.astronaut')
            if per_layer:
                for layer in scores['layer'].values[scores['model'] == model]:
                    num_features.append({'model': model, 'layer': layer,
                                         'score': len(features.sel(layer=layer)['neuroid'])})
            else:
                num_features.append({'model': model, 'score': len(features['neuroid'])})
        num_features = pd.DataFrame(num_features)
        num_features['error'] = np.nan
        num_features.to_csv(store_file, index=False)
    if per_layer:
        # the rows of both tables have to correspond for plotting
        assert (scores['layer'].values == num_features['layer'].values).all()
    # plot
    colors = [model_colors[model.replace('-untrained', '')] for model in scores['model'].values]
    fig, ax = _plot_scores1_2(num_features, scores, color=colors, xlabel="number of features", ylabel=benchmark)
    # Build the complete file name before joining it to the directory: `/` binds tighter than `+`,
    # so `directory / name + suffix` would attempt `Path + str` and raise a TypeError.
    savename = f"num_features-{benchmark}" + ("-layerwise" if per_layer else "")
    savefig(fig, savename=Path(__file__).parent / savename)