def sample_multi_frame_dataset(tmpdir):
    """Build a MultiFrameDataset backed by two freshly written HDF5 containers.

    Writes five utterances ('utt-1' .. 'utt-5') into an inputs container
    (4 values per frame) and a targets container (2 values per frame).
    Each matrix holds consecutive integers starting at 0.
    Per-utterance frame counts are 15, 20, 11, 3 and 4.

    Args:
        tmpdir: pytest ``tmpdir`` path object; the containers are written as
            ``inputs.hdf5`` and ``targets.hdf5`` inside it.

    Returns:
        ``feeding.MultiFrameDataset`` over the corpus from
        ``resources.create_dataset()`` and both containers, using 4 frames
        per chunk.
    """
    # NOTE(review): presumably registered as a pytest fixture (the decorator
    # is not visible here). The containers are left open because the
    # returned dataset reads from them; nothing closes them afterwards.
    base_dir = tmpdir.strpath
    inputs_file = os.path.join(base_dir, 'inputs.hdf5')
    targets_file = os.path.join(base_dir, 'targets.hdf5')

    corpus = resources.create_dataset()

    inputs = containers.Container(inputs_file)
    targets = containers.Container(targets_file)
    for container in (inputs, targets):
        container.open()

    frame_counts = [
        ('utt-1', 15),
        ('utt-2', 20),
        ('utt-3', 11),
        ('utt-4', 3),
        ('utt-5', 4),
    ]
    # All inputs are written first, then all targets, each in utterance order.
    for utt_id, num_frames in frame_counts:
        inputs.set(utt_id, np.arange(num_frames * 4).reshape(num_frames, 4))
    for utt_id, num_frames in frame_counts:
        targets.set(utt_id, np.arange(num_frames * 2).reshape(num_frames, 2))

    return feeding.MultiFrameDataset(corpus, [inputs, targets], 4)
def test_return_correct_length_for_chunk_at_end_of_utterance(
        self, sample_multi_frame_dataset):
    """A trailing, shorter chunk reports its real frame count.

    With 4 frames per chunk, index 11 is expected to be the final chunk of
    'utt-3' (11 frames -> 4 + 4 + 3), so the reported length is 3.
    """
    dataset = feeding.MultiFrameDataset(
        sample_multi_frame_dataset.utt_ids,
        sample_multi_frame_dataset.containers,
        4,
        return_length=True,
    )
    tail_chunk = 11
    # With return_length enabled, each item carries a third element: the length.
    assert len(dataset[tail_chunk]) == 3
    assert dataset[tail_chunk][2] == 3
def test_return_correct_length_for_chunk_with_full_size(
        self, sample_multi_frame_dataset):
    """A chunk that is completely filled reports the full chunk size.

    With 4 frames per chunk, index 9 is expected to be the first chunk of
    'utt-3', which has all 4 frames available.
    """
    dataset = feeding.MultiFrameDataset(
        sample_multi_frame_dataset.utt_ids,
        sample_multi_frame_dataset.containers,
        4,
        return_length=True,
    )
    full_chunk = 9
    # With return_length enabled, each item carries a third element: the length.
    assert len(dataset[full_chunk]) == 3
    assert dataset[full_chunk][2] == 4
def test_pads_shorter_chunks_with_zeros(self, sample_multi_frame_dataset):
    """Short chunks are zero-padded up to the chunk size when pad=True.

    Index 11 holds 3 real frames: inputs valued 32..43 and targets valued
    16..21, i.e. frames 8-10 of 'utt-3' in the fixture. Both are expected to
    gain one trailing all-zero row. The third element still reports the
    unpadded length 3.
    """
    dataset = feeding.MultiFrameDataset(
        sample_multi_frame_dataset.utt_ids,
        sample_multi_frame_dataset.containers,
        4,
        pad=True,
    )
    idx = 11
    assert len(dataset[idx]) == 3

    # Append exactly one row after the data along the frame axis, none on features.
    one_row_at_end = ((0, 1), (0, 0))

    expected_inputs = np.pad(
        np.arange(12).reshape(3, 4) + 32,
        one_row_at_end,
        mode='constant',
        constant_values=0,
    )
    assert np.array_equal(dataset[idx][0], expected_inputs)

    expected_targets = np.pad(
        np.arange(6).reshape(3, 2) + 16,
        one_row_at_end,
        mode='constant',
        constant_values=0,
    )
    assert np.array_equal(dataset[idx][1], expected_targets)

    assert dataset[idx][2] == 3
def test_raises_error_if_frames_per_chunk_is_smaller_than_one(self):
    """Requesting zero frames per chunk must raise ValueError on construction."""
    utt_ids = ['utt-1']
    no_containers = []
    zero_frames = 0
    with pytest.raises(ValueError):
        feeding.MultiFrameDataset(utt_ids, no_containers, zero_frames)