Ejemplo n.º 1
0
def test_merge_extractor_results_by_features():
    """Check ExtractorResult.merge_features on static and dummy results.

    The first step merges Brightness/Vibrance results for one static image.
    It is only a smoke check: the merged frame is not inspected. The second
    step merges three named DummyExtractor results for the same image. It
    checks the resulting shape and both levels of the column MultiIndex.
    """
    # Seed the global NumPy RNG before any extraction runs. Presumably
    # DummyExtractor draws random values, so the expected shape below
    # depends on this seed -- TODO confirm.
    np.random.seed(100)
    apple = ImageStim(join(get_test_data_path(), 'image', 'apple.jpg'))

    # Static Stims carry no onsets; this merge is only required not to raise.
    static_results = [extractor.extract(apple)
                      for extractor in (BrightnessExtractor(),
                                        VibranceExtractor())]
    ExtractorResult.merge_features(static_results)

    dummy = DummyExtractor()
    dummy_names = ['Extractor1', 'Extractor2', 'Extractor3']
    merged = ExtractorResult.merge_features(
        [dummy.extract(apple, label) for label in dummy_names])

    assert merged.shape == (177, 10)
    # Level 1 holds the feature labels; level 0 holds the extractor names
    # followed by the shared 'onset' and 'stim' columns.
    feature_level = merged.columns.levels[1]
    assert feature_level.unique().tolist() == ['duration', 0, 1, 2, '']
    extractor_level = merged.columns.levels[0]
    assert extractor_level.unique().tolist() == dummy_names + ['onset', 'stim']
Ejemplo n.º 2
0
def test_merge_extractor_results_by_stims():
    """Check that ExtractorResult.merge_stims stacks results from two images.

    A single DummyExtractor is run on two image stims. The merged frame must
    have 200 rows and the columns ``onset, duration, 0, 1, 2`` in that order.
    Both stim file names must appear in the second level of its row index.
    """
    images = join(get_test_data_path(), 'image')
    apple = ImageStim(join(images, 'apple.jpg'))
    obama = ImageStim(join(images, 'obama.jpg'))
    dummy = DummyExtractor()

    merged = ExtractorResult.merge_stims(
        [dummy.extract(stim) for stim in (apple, obama)])

    assert merged.shape == (200, 5)
    assert merged.columns.tolist() == ['onset', 'duration', 0, 1, 2]
    # Compare as sets: the order of stim names in the index level is not pinned.
    assert set(merged.index.levels[1].unique()) == {'obama.jpg', 'apple.jpg'}
Ejemplo n.º 3
0
def test_convert_to_long():
    """Check to_long_format on a raw STFT result and on a feature-merged one.

    The STFT extractor is configured with three frequency bins. The long frame
    built from the raw result is expected to have three rows per wide row and
    four columns, including 'feature' and 'value'. After
    ExtractorResult.merge_features, the long frame must also expose an
    'extractor' column. In both cases the bin label '100_300' must not remain
    a column name.
    """
    barber = AudioStim(join(get_test_data_path(), 'audio', 'barber.wav'))
    stft = STFTAudioExtractor(
        frame_size=1.,
        spectrogram=False,
        bins=[(100, 300), (300, 3000), (3000, 20000)],
    )
    result = stft.extract(barber)

    long_df = to_long_format(result)
    n_wide_rows = result.to_df().shape[0]
    assert long_df.shape == (n_wide_rows * 3, 4)
    for column in ('feature', 'value'):
        assert column in long_df.columns
    # Bin labels such as '100_300' should not survive as column names.
    assert '100_300' not in long_df.columns

    merged = ExtractorResult.merge_features([result])
    long_merged = to_long_format(merged)
    for column in ('feature', 'extractor'):
        assert column in long_merged.columns
    assert '100_300' not in long_merged.columns