def test_collate_frames(self):
    """collate_frames pads variable-length frame sequences on the right or left.

    Four sequences of lengths 3, 2, 4 and 1 are batched; each frame is a
    constant value broadcast to 10 feature dims, so padding is easy to spot.
    """
    inputs = []
    for values in ([4.5, 2.3, 1.2], [6.7, 9.8], [7.7, 5.4, 6.2, 8.0], [1.5]):
        inputs.append(torch.tensor(values).unsqueeze(-1).expand(-1, 10))
    # Right padding: trailing zeros bring every sequence up to length 4.
    right_padded = torch.tensor([
        [4.5, 2.3, 1.2, 0.0],
        [6.7, 9.8, 0.0, 0.0],
        [7.7, 5.4, 6.2, 8.0],
        [1.5, 0.0, 0.0, 0.0],
    ]).unsqueeze(-1).expand(-1, -1, 10)
    # Left padding: the zeros precede the data instead.
    left_padded = torch.tensor([
        [0.0, 4.5, 2.3, 1.2],
        [0.0, 0.0, 6.7, 9.8],
        [7.7, 5.4, 6.2, 8.0],
        [0.0, 0.0, 0.0, 1.5],
    ]).unsqueeze(-1).expand(-1, -1, 10)
    self.assertTensorEqual(
        utils.collate_frames(inputs, pad_value=0.0, left_pad=False),
        right_padded,
    )
    self.assertTensorEqual(
        utils.collate_frames(inputs, pad_value=0.0, left_pad=True),
        left_padded,
    )
def merge(key, left_pad, move_eos_to_beginning=False):
    """Collate the per-sample entries under ``key`` into one padded batch.

    ``'source'`` entries are float frame tensors padded with 0.0;
    ``'target'`` entries are token sequences padded with ``pad_idx``.
    Raises ``ValueError`` for any other key.
    """
    # Reject unknown keys up front, before touching the samples.
    if key not in ('source', 'target'):
        raise ValueError('Invalid key.')
    entries = [s[key] for s in samples]
    if key == 'source':
        return speech_utils.collate_frames(entries, 0.0, left_pad)
    return data_utils.collate_tokens(
        entries, pad_idx, eos_idx, left_pad, move_eos_to_beginning,
    )
def merge(key):
    """Batch the per-sample entries stored under ``key``.

    ``'source'`` frames are collated with 0.0 padding; ``'target'`` chain
    graphs are wrapped in a ``ChainGraphBatch`` sized to the largest graph.
    Raises ``ValueError`` for any other key.
    """
    if key == "source":
        return speech_utils.collate_frames([s[key] for s in samples], 0.0)
    if key == "target":
        graphs = [s["target"] for s in samples]
        # The batch must be sized to the largest graph in either dimension.
        return ChainGraphBatch(
            graphs,
            max_num_transitions=max(g.num_transitions for g in graphs),
            max_num_states=max(g.num_states for g in graphs),
        )
    raise ValueError("Invalid key.")
def merge(key):
    """Collate the per-sample values under ``key`` into a padded batch.

    ``'source'`` frames pad with 0.0; ``'target'`` tokens pad with
    ``pad_idx`` (right-padded, no EOS handling). Raises ``ValueError``
    for any other key.
    """
    # Fail fast on unknown keys, before indexing into the samples.
    if key not in ("source", "target"):
        raise ValueError("Invalid key.")
    values = [s[key] for s in samples]
    if key == "source":
        return speech_utils.collate_frames(values, 0.0)
    return data_utils.collate_tokens(
        values,
        pad_idx=pad_idx,
        eos_idx=None,
        left_pad=False,
        move_eos_to_beginning=False,
    )
def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
    """Collate per-sample entries under ``key`` into one padded batch.

    ``'source'`` frames pad with 0.0; ``'target'`` /
    ``'prev_output_tokens'`` sequences pad with ``pad_idx`` and may move
    EOS to the front. Both honor ``pad_to_length`` and the enclosing
    ``pad_to_multiple``. Raises ``ValueError`` for any other key.
    """
    if key == "source":
        return speech_utils.collate_frames(
            [s[key] for s in samples],
            0.0,
            left_pad,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )
    if key in ("target", "prev_output_tokens"):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx,
            left_pad,
            move_eos_to_beginning,
            pad_to_length=pad_to_length,
            pad_to_multiple=pad_to_multiple,
        )
    raise ValueError("Invalid key.")