def test_augment_audio_with_invalid_spectrogram_height(test_raw_file, segments, model):
    """Check that augment_audio raises ValueError for a (10, 10, 1) spectrogram.

    Per the test name, the rejected property is the spectrogram height;
    presumably 10 never matches the model's expected height -- TODO confirm.
    """
    # Both helpers are used as context managers around the whole test body;
    # presumably they stage the raw file inside the first segment's directory
    # and clean up afterwards -- verify against their definitions.
    with temp_dir(segments[0].dir), temp_copy(test_raw_file, segments[0].dir):
        generator = SegmentsGenerator(segments, model)
        bad_spectrogram = np.ones((10, 10, 1))
        with pytest.raises(ValueError):
            generator.augment_audio(bad_spectrogram)
def test_augment_audio_with_invalid_spectrogram_length(test_raw_file, segments, model):
    """Check that augment_audio raises ValueError for an over-long spectrogram.

    The input reuses the first dimension of ``model.audio_input_shape`` and
    sets the second dimension to 1000; per the test name that length is
    expected to be invalid (presumably longer than the model accepts --
    TODO confirm).
    """
    # Both helpers are used as context managers around the whole test body;
    # presumably they stage the raw file inside the first segment's directory
    # and clean up afterwards -- verify against their definitions.
    with temp_dir(segments[0].dir), temp_copy(test_raw_file, segments[0].dir):
        generator = SegmentsGenerator(segments, model)
        target_shape = model.audio_input_shape
        # Build the input outside pytest.raises: a ValueError raised while
        # constructing the array (e.g. from a bad model shape) must fail the
        # test rather than satisfy the expectation.
        spectrogram = np.ones((target_shape[0], 1000, 1))
        with pytest.raises(ValueError):
            generator.augment_audio(spectrogram)
def test_augment_audio_should_match_model_audio_input_shape_with_augmentor(test_raw_file, segments, model, augmentor):
    """Check that every item from augment_audio has the model's audio input shape.

    The generator is built with the ``augmentor`` fixture, and the input
    spectrogram reuses the first dimension of ``model.audio_input_shape``
    with a second dimension of 10.
    """
    # Both helpers are used as context managers around the whole test body;
    # presumably they stage the raw file inside the first segment's directory
    # and clean up afterwards -- verify against their definitions.
    with temp_dir(segments[0].dir), temp_copy(test_raw_file, segments[0].dir):
        # NOTE(review): the augmentor is passed as ``vision_augmentor`` although
        # this test exercises the audio path -- confirm this is intended.
        generator = SegmentsGenerator(segments, model, vision_augmentor=augmentor)
        target_shape = model.audio_input_shape
        spectrogram = np.ones((target_shape[0], 10, 1))

        shapes = [item.shape for item in generator.augment_audio(spectrogram)]

        # all() over an empty result is vacuously true, so require at least
        # one output before checking shapes.
        assert len(shapes) > 0, "augment_audio returned no spectrograms"
        assert all(shape == target_shape for shape in shapes), (
            f"expected every output to have shape {target_shape}, got {shapes}"
        )