示例#1
0
    def test_collate_frames(self):
        """Check ``utils.collate_frames`` with right padding and left padding.

        Four sequences of lengths 3, 2, 4 and 1 are built, each frame
        broadcast to 10 identical feature columns. They are collated with a
        pad value of 0.0 and compared against hand-written 4 x 4 x 10
        reference tensors, once with ``left_pad=False`` and once with
        ``left_pad=True``.
        """
        feat_dim = 10

        def with_features(values):
            # Append a trailing axis of size 1, then broadcast that axis to
            # `feat_dim` columns while keeping every leading axis unchanged.
            column = torch.tensor(values).unsqueeze(-1)
            keep_leading = [-1] * (column.dim() - 1)
            return column.expand(*keep_leading, feat_dim)

        sequences = [
            with_features(frames)
            for frames in (
                [4.5, 2.3, 1.2],
                [6.7, 9.8],
                [7.7, 5.4, 6.2, 8.0],
                [1.5],
            )
        ]

        # Reference batch when padding is appended after each sequence.
        right_padded = with_features([
            [4.5, 2.3, 1.2, 0.0],
            [6.7, 9.8, 0.0, 0.0],
            [7.7, 5.4, 6.2, 8.0],
            [1.5, 0.0, 0.0, 0.0],
        ])
        # Reference batch when padding is inserted before each sequence.
        left_padded = with_features([
            [0.0, 4.5, 2.3, 1.2],
            [0.0, 0.0, 6.7, 9.8],
            [7.7, 5.4, 6.2, 8.0],
            [0.0, 0.0, 0.0, 1.5],
        ])

        # Same order as before: right-padded case first, then left-padded.
        for pad_on_left, reference in ((False, right_padded), (True, left_padded)):
            collated = utils.collate_frames(
                sequences, pad_value=0.0, left_pad=pad_on_left,
            )
            self.assertTensorEqual(collated, reference)
示例#2
0
 def merge(key, left_pad, move_eos_to_beginning=False):
     """Collate the ``key`` entry of every sample into one batch.

     ``samples``, ``pad_idx`` and ``eos_idx`` are taken from the enclosing
     scope.

     Args:
         key: ``'source'`` or ``'target'``.
         left_pad: forwarded to the underlying collate helper.
         move_eos_to_beginning: forwarded to ``data_utils.collate_tokens``
             (only used for ``'target'``).

     Returns:
         Whatever the selected collate helper returns.

     Raises:
         ValueError: if ``key`` is neither ``'source'`` nor ``'target'``.
     """
     if key == 'source':
         # 0.0 is the second positional argument -- presumably the pad value
         # for feature frames; confirm against speech_utils.collate_frames.
         frames = [sample[key] for sample in samples]
         return speech_utils.collate_frames(frames, 0.0, left_pad)

     if key == 'target':
         tokens = [sample[key] for sample in samples]
         return data_utils.collate_tokens(
             tokens, pad_idx, eos_idx, left_pad, move_eos_to_beginning,
         )

     raise ValueError('Invalid key.')
示例#3
0
 def merge(key):
     """Collate the ``key`` entry of every sample into one batch.

     ``samples`` is taken from the enclosing scope.

     Args:
         key: ``"source"`` or ``"target"``.

     Returns:
         For ``"source"``, the result of ``speech_utils.collate_frames``;
         for ``"target"``, a ``ChainGraphBatch`` built from all target
         graphs together with the largest ``num_transitions`` and
         ``num_states`` found among them.

     Raises:
         ValueError: if ``key`` is neither ``"source"`` nor ``"target"``.
     """
     if key == "source":
         # 0.0 is the second positional argument -- presumably the pad value;
         # confirm against speech_utils.collate_frames.
         frames = [sample[key] for sample in samples]
         return speech_utils.collate_frames(frames, 0.0)

     if key == "target":
         # Batch-wide maxima are handed to ChainGraphBatch alongside the
         # graphs themselves.
         widest_transitions = max(
             sample["target"].num_transitions for sample in samples
         )
         widest_states = max(sample["target"].num_states for sample in samples)
         graphs = [sample["target"] for sample in samples]
         return ChainGraphBatch(
             graphs,
             max_num_transitions=widest_transitions,
             max_num_states=widest_states,
         )

     raise ValueError("Invalid key.")
示例#4
0
 def merge(key):
     """Collate the ``key`` entry of every sample into one batch.

     ``samples`` and ``pad_idx`` are taken from the enclosing scope.

     Args:
         key: ``"source"`` or ``"target"``.

     Returns:
         For ``"source"``, the result of ``speech_utils.collate_frames``;
         for ``"target"``, the result of ``data_utils.collate_tokens``
         called with no EOS index, ``left_pad=False`` and
         ``move_eos_to_beginning=False``.

     Raises:
         ValueError: if ``key`` is neither ``"source"`` nor ``"target"``.
     """
     if key == "source":
         # 0.0 is the second positional argument -- presumably the pad value;
         # confirm against speech_utils.collate_frames.
         frames = [sample[key] for sample in samples]
         return speech_utils.collate_frames(frames, 0.0)

     if key == "target":
         tokens = [sample[key] for sample in samples]
         return data_utils.collate_tokens(
             tokens,
             pad_idx=pad_idx,
             eos_idx=None,
             left_pad=False,
             move_eos_to_beginning=False,
         )

     raise ValueError("Invalid key.")
示例#5
0
 def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
     """Collate the ``key`` entry of every sample into one batch.

     ``samples``, ``pad_idx``, ``eos_idx`` and ``pad_to_multiple`` are taken
     from the enclosing scope.

     Args:
         key: ``"source"``, ``"target"`` or ``"prev_output_tokens"``.
         left_pad: forwarded to the underlying collate helper.
         move_eos_to_beginning: forwarded to ``data_utils.collate_tokens``
             (unused for ``"source"``).
         pad_to_length: forwarded as ``pad_to_length`` to the underlying
             collate helper.

     Returns:
         Whatever the selected collate helper returns.

     Raises:
         ValueError: if ``key`` is not one of the three accepted names.
     """
     def gather():
         # Pull the requested field out of every sample, in sample order.
         return [sample[key] for sample in samples]

     if key == "source":
         # 0.0 is the second positional argument -- presumably the pad value;
         # confirm against speech_utils.collate_frames.
         return speech_utils.collate_frames(
             gather(),
             0.0,
             left_pad,
             pad_to_length=pad_to_length,
             pad_to_multiple=pad_to_multiple,
         )

     # Both token-valued fields share the same collate call.
     is_token_field = key == "target" or key == "prev_output_tokens"
     if is_token_field:
         return data_utils.collate_tokens(
             gather(),
             pad_idx,
             eos_idx,
             left_pad,
             move_eos_to_beginning,
             pad_to_length=pad_to_length,
             pad_to_multiple=pad_to_multiple,
         )

     raise ValueError("Invalid key.")