def test_1d(self):
    '''Check `k2.simple_ragged_index_select` with a 1-D int32 `src`.

    For every device in `self.devices` the same ragged index is applied
    to a contiguous `src` and then to a strided view that holds the same
    values.  Both calls must produce the same result, and the result of
    the strided call must itself be contiguous.
    '''
    for device in self.devices:
        int32_on_device = dict(dtype=torch.int32, device=device)

        row_splits = torch.tensor([0, 3, 5, 6, 6, 9], **int32_on_device)
        # There is no need to move the shape to `device`: it is created
        # on the same device as the row_splits it is built from.
        shape = k2.ragged.create_ragged_shape2(row_splits, None, 9)
        index_values = torch.tensor([1, 0, 4, 2, 3, 0, 4, 5, 2],
                                    **int32_on_device)
        index = k2.RaggedTensor(shape, index_values)

        def select(source):
            # Run the op and do the checks shared by both cases.
            result = k2.simple_ragged_index_select(source, index)
            self.assertEqual(result.dtype, source.dtype)
            self.assertEqual(result.numel(), shape.dim0)
            return result

        def matches_expected(result):
            wanted = torch.tensor([2, 10, 0, 0, -1], **int32_on_device)
            return torch.allclose(result, wanted)

        # contiguous
        src = torch.tensor([0, 2, 0, 10, 0, -1], **int32_on_device)
        ans = select(src)
        self.assertTrue(matches_expected(ans))

        # non-contiguous: same values as above, but with stride 3
        src = src.expand(3, -1).t().flatten()[::3]
        self.assertFalse(src.is_contiguous())
        self.assertEqual(src.stride(0), 3)
        ans = select(src)
        self.assertTrue(ans.is_contiguous())
        self.assertEqual(ans.stride(0), 1)
        self.assertTrue(matches_expected(ans))
def test_with_negative_1(self):
    '''Check `k2.ragged.index_and_sum` when the indexes contain -1.

    The expected values treat a -1 index as contributing 0 to the sum of
    its sublist.  The gradient w.r.t. `src` is checked as well.
    '''
    for device in self.devices:
        src = torch.tensor([0, 1, 2, 3],
                           dtype=torch.float32,
                           requires_grad=True,
                           device=device)
        indexes = k2.RaggedTensor([[1, 2, -1], [0, 3], [-1],
                                   [0, 2, 3, 1, 3], []])
        indexes = indexes.to(device)

        ans = k2.ragged.index_and_sum(src, indexes)

        # One entry per sublist of `indexes`.
        sublist_sums = [1 + 2, 0 + 3, 0, 0 + 2 + 3 + 1 + 3, 0]
        expected = torch.tensor(sublist_sums).to(src)
        assert torch.allclose(ans, expected)

        # now for autograd
        scale = torch.tensor([10, 20, 30, 40, 50]).to(device)
        weighted_total = (ans * scale).sum()
        weighted_total.backward()

        # src[i] collects the scale of every sublist that refers to i
        # (sublist 3 refers to index 3 twice).
        grad_terms = (
            scale[1] + scale[3],
            scale[0] + scale[3],
            scale[0] + scale[3],
            scale[1] + scale[3] * 2,
        )
        expected_grad = torch.empty_like(src.grad)
        for i, term in enumerate(grad_terms):
            expected_grad[i] = term
        assert torch.allclose(src.grad, expected_grad)
def test_aux_as_ragged(self):
    '''Invert a CPU FSA whose aux_labels are a ragged tensor.

    This is a smoke test: it only prints the inverted FSA and makes no
    assertion about its contents.
    '''
    fsa_text = '''
        0 1 1 0
        0 1 0 0
        0 3 2 0
        1 2 3 0
        1 3 4 0
        2 1 5 0
        2 5 -1 0
        3 1 6 0
        4 5 -1 0
        5
    '''
    fsa = k2.Fsa.from_str(fsa_text)
    assert fsa.device.type == 'cpu'

    # One row of aux labels per arc: 9 rows holding 11 values in total.
    row_splits = torch.tensor([0, 2, 3, 3, 6, 6, 7, 8, 10, 11],
                              dtype=torch.int32)
    values = torch.tensor([1, 2, 3, 5, 6, 7, 8, -1, 9, 10, -1],
                          dtype=torch.int32)
    shape = k2.ragged.create_ragged_shape2(row_splits, None, 11)
    fsa.aux_labels = k2.RaggedTensor(shape, values)

    dest = k2.invert(fsa)
    print(dest)  # will print aux_labels as well
def ctc_graph(symbols: Union[List[List[int]], k2.RaggedTensor],
              modified: bool = False,
              device: Optional[Union[torch.device, str]] = "cpu") -> Fsa:
    '''Construct ctc graphs from symbols.

    Note:
      The scores of arcs in the returned FSA are all 0.

    Args:
      symbols:
        It can be one of the following types:

            - A list of list-of-integers, e.g., `[ [1, 2], [1, 2, 3] ]`
            - An instance of :class:`k2.RaggedTensor`.
              Must have `num_axes == 2`.

      modified:
        Option to specify the type of CTC topology: "standard" or
        "modified", where the "standard" one makes the blank mandatory
        between a pair of identical symbols. Default False, i.e., the
        standard topology is built.
      device:
        Optional. It can be either a string (e.g., 'cpu', 'cuda:0') or a
        torch.device.
        By default, the returned FSA is on CPU.
        It is used only when `symbols` is not a
        :class:`k2.RaggedTensor`. If `symbols` is an instance of
        :class:`k2.RaggedTensor`, the returned FSA will be on the same
        device as `k2.RaggedTensor`.

    Returns:
      An FsaVec containing the returned ctc graphs, with "Dim0()"
      the same as "len(symbols)"(List[List[int]]) or
      "dim0"(k2.RaggedTensor)
    '''
    if not isinstance(symbols, k2.RaggedTensor):
        # A plain list carries no device information, so `device`
        # decides where the ragged tensor (and hence the graph) lives.
        symbols = k2.RaggedTensor(symbols, device=device)

    ragged_arc, aux_labels = _k2.ctc_graph(symbols, modified)
    fsa = Fsa(ragged_arc, aux_labels=aux_labels)
    return fsa
def test(self):
    '''Drive `k2.RnntDecodingStreams` through a few decoding steps.

    Three streams with different decoding graphs and different extra
    attributes are decoded together with random log-probs.  This is a
    smoke test: the formatted output is only printed, not checked.
    '''
    for device in self.devices:
        # Stream 1: standard CTC topology with one tensor attribute.
        ctc_fsa = k2.ctc_topo(5, device=device)
        ctc_fsa.attr1 = torch.tensor([1] * ctc_fsa.num_arcs, device=device)

        # Stream 2: trivial graph with two tensor attributes.
        trivial_fsa = k2.trivial_graph(3, device=device)
        trivial_fsa.attr1 = torch.tensor([2] * trivial_fsa.num_arcs,
                                         device=device)
        trivial_fsa.attr2 = torch.tensor([22] * trivial_fsa.num_arcs,
                                         device=device)

        # Stream 3: modified CTC topology with a ragged attribute.
        modified_ctc_fsa = k2.ctc_topo(3, modified=True, device=device)
        threes = torch.ones((modified_ctc_fsa.num_arcs, 2),
                            dtype=torch.int32,
                            device=device) * 3
        modified_ctc_fsa.attr3 = k2.RaggedTensor(threes)

        stream_list = [
            k2.RnntDecodingStream(graph)
            for graph in (ctc_fsa, trivial_fsa, modified_ctc_fsa)
        ]
        config = k2.RnntDecodingConfig(10, 2, 3.0, 3, 3)
        streams = k2.RnntDecodingStreams(stream_list, config)

        for _ in range(5):
            shape, context = streams.get_contexts()
            # One row of random log-probs per context row.
            logprobs = torch.randn((context.shape[0], 10),
                                   dtype=torch.float32,
                                   device=device)
            streams.advance(logprobs)

        streams.terminate_and_flush_to_streams()
        ofsa = streams.format_output([3, 4, 5])
        print(ofsa)
def test_sum_per_sublist(self):
    '''Check `RaggedTensor.sum()` on per-state arc probabilities.

    After `set_scores_stochastic_`, the exponentiated scores of the arcs
    leaving each state are expected to sum to 1, except for the final
    state, which has no leaving arcs and therefore sums to 0.
    '''
    fsa_text = '''
        0 1 1 0.
        0 1 2 0.
        0 1 3 0.
        1 2 4 0.
        1 2 5 0.
        2 3 -1 0.
        3
    '''
    for device in self.devices:
        fsa = k2.Fsa.from_str(fsa_text).to(device)
        random_scores = torch.randn_like(fsa.scores)
        fsa.set_scores_stochastic_(random_scores)

        arc_probs = fsa.scores.exp()
        per_state = k2.RaggedTensor(fsa.arcs.shape(), arc_probs).sum()

        num_states = per_state.numel()
        assert num_states == fsa.arcs.dim0()

        all_ones = torch.ones(num_states - 1, device=device)
        assert torch.allclose(per_state[:-1], all_ones)

        # the final state has no leaving arcs
        assert per_state[-1].item() == 0
def test(self):
    '''Check `k2.utils.fsa_from_unary_function_tensor` with arc_sort.

    The arc map returned by `_k2.arc_sort` swaps the first two arcs, so
    every per-arc attribute of the result (scores, float, int and ragged
    attributes) must be permuted the same way, non-tensor attributes must
    be copied over, and gradients must flow back to the source.
    '''
    fsa_text = '''
        0 1 2 10
        0 1 1 20
        1 2 -1 30
        2
    '''
    for device in self.devices:

        def f32(data, **kwargs):
            return torch.tensor(data,
                                dtype=torch.float32,
                                device=device,
                                **kwargs)

        def i32(data):
            return torch.tensor(data, dtype=torch.int32, device=device)

        src = k2.Fsa.from_str(fsa_text).to(device).requires_grad_(True)
        src.float_attr = f32([0.1, 0.2, 0.3], requires_grad=True)
        src.int_attr = i32([1, 2, 3])
        src.ragged_attr = k2.RaggedTensor([[1, 2, 3], [5, 6],
                                           []]).to(device)
        src.attr1 = 'src'
        src.attr2 = 'fsa'

        ragged_arc, arc_map = _k2.arc_sort(src.arcs, need_arc_map=True)
        dest = k2.utils.fsa_from_unary_function_tensor(
            src, ragged_arc, arc_map)

        assert torch.allclose(dest.float_attr, f32([0.2, 0.1, 0.3]))
        assert torch.all(torch.eq(dest.scores, f32([20, 10, 30])))
        assert torch.all(torch.eq(dest.int_attr, i32([2, 1, 3])))

        expected_ragged_attr = k2.RaggedTensor([[5, 6], [1, 2, 3],
                                                []]).to(device)
        assert dest.ragged_attr == expected_ragged_attr

        assert dest.attr1 == src.attr1
        assert dest.attr2 == src.attr2

        # now for autograd
        scale = torch.tensor([10, 20, 30], device=device)
        for output in (dest.float_attr, dest.scores):
            (output * scale).sum().backward()

        # The first two arcs are swapped, and so are their gradients.
        expected_grad = f32([20, 10, 30])
        assert torch.all(torch.eq(src.float_attr.grad, expected_grad))
        assert torch.all(torch.eq(src.scores.grad, expected_grad))
def test_single_fsa(self):
    '''Check `k2.remove_epsilon_self_loops` on a single FSA.

    Arcs 2 and 3 of the source are epsilon self-loops.  The result must
    keep arcs [0, 1, 4, 5], carry over the matching entries of every
    attribute, and propagate gradients to the source scores and
    float attribute.
    '''
    # See https://git.io/JY7r4
    src_text = '''
        0 1 0 0.1
        0 2 0 0.2
        0 0 0 0.3
        1 1 0 0.4
        1 2 0 0.5
        2 3 -1 0.6
        3
    '''
    # See https://git.io/JY7oC
    expected_text = '''
        0 1 0 0.1
        0 2 0 0.2
        1 2 0 0.5
        2 3 -1 0.6
        3
    '''
    # arc map is [0, 1, 4, 5]
    kept_arcs = (0, 1, 4, 5)

    for device in self.devices:
        src = k2.Fsa.from_str(src_text).to(device).requires_grad_(True)
        scores_copy = src.scores.detach().clone().requires_grad_(True)
        src.attr1 = "hello"
        src.attr2 = "k2"

        # `float_attr` is the reference leaf; the FSA gets its own copy so
        # that the two gradients are computed independently.
        float_attr = torch.tensor([0.1, 0.2, 0.3, 4, 5, 6],
                                  dtype=torch.float32,
                                  requires_grad=True,
                                  device=device)
        src.float_attr = float_attr.detach().clone().requires_grad_(True)
        src.int_attr = torch.tensor([1, 2, 3, 4, 5, 6],
                                    dtype=torch.int32,
                                    device=device)
        src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                           [60, 70], [80], [],
                                           [0]]).to(device)

        dest = k2.remove_epsilon_self_loops(src)

        expected_fsa = k2.Fsa.from_str(expected_text)
        dest_str = k2.to_str_simple(dest)
        expected_str = k2.to_str_simple(expected_fsa)
        assert dest_str == expected_str, \
            f'{str(dest)}\n{str(expected_fsa)}'

        assert dest.attr1 == src.attr1
        assert dest.attr2 == src.attr2

        expected_int_attr = torch.tensor([1, 2, 5, 6],
                                         dtype=torch.int32,
                                         device=device)
        assert torch.all(torch.eq(dest.int_attr, expected_int_attr))

        expected_ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                                [], [0]]).to(device)
        assert dest.ragged_attr == expected_ragged_attr

        expected_float_attr = torch.empty_like(dest.float_attr)
        for out_idx, src_idx in enumerate(kept_arcs):
            expected_float_attr[out_idx] = float_attr[src_idx]
        assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

        expected_scores = torch.empty_like(dest.scores)
        for out_idx, src_idx in enumerate(kept_arcs):
            expected_scores[out_idx] = scores_copy[src_idx]
        assert torch.all(torch.eq(dest.scores, expected_scores))

        # Same scalar loss on the actual and the expected tensors; the
        # gradients on the respective leaves must agree.
        scale = torch.tensor([10, 20, 30, 40]).to(float_attr)

        (dest.float_attr * scale).sum().backward()
        (expected_float_attr * scale).sum().backward()
        assert torch.all(torch.eq(src.float_attr.grad, float_attr.grad))

        (dest.scores * scale).sum().backward()
        (expected_scores * scale).sum().backward()
        assert torch.all(torch.eq(src.scores.grad, scores_copy.grad))
def test_fsa_vec(self):
    '''Check `k2.remove_epsilon_self_loops` on an FsaVec of two FSAs.

    Both FSAs share the same topology, in which arcs 2 and 3 are epsilon
    self-loops.  The result must keep arcs [0, 1, 4, 5] of each FSA,
    carry over the matching entries of every attribute, and propagate
    gradients to the scores and float attributes of both source FSAs.
    '''
    # See https://git.io/JY7r4
    src_text = '''
        0 1 0 0.1
        0 2 0 0.2
        0 0 0 0.3
        1 1 0 0.4
        1 2 0 0.5
        2 3 -1 0.6
        3
    '''
    # See https://git.io/JY7oC
    expected_text = '''
        0 1 0 0.1
        0 2 0 0.2
        1 2 0 0.5
        2 3 -1 0.6
        3
    '''
    # Arcs kept within each FSA; over the whole FsaVec the
    # arc map is [0, 1, 4, 5, 6, 7, 10, 11]
    kept_arcs = (0, 1, 4, 5)

    for device in self.devices:
        fsa1 = k2.Fsa.from_str(src_text).to(device).requires_grad_(True)
        scores_copy1 = fsa1.scores.detach().clone().requires_grad_(True)
        fsa1.attr1 = "hello"
        float_attr1 = torch.tensor([0.1, 0.2, 0.3, 4, 5, 6],
                                   dtype=torch.float32,
                                   requires_grad=True,
                                   device=device)
        # Give the FSA its own leaf tensor. If `float_attr1` itself were
        # stored, the gradient comparison at the end would compare a
        # tensor with itself and could never fail.
        fsa1.float_attr = float_attr1.detach().clone().requires_grad_(True)
        fsa1.int_attr = torch.tensor([1, 2, 3, 4, 5, 6],
                                     dtype=torch.int32,
                                     device=device)
        fsa1.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                            [60, 70], [80], [],
                                            [0]]).to(device)

        fsa2 = k2.Fsa.from_str(src_text).to(device).requires_grad_(True)
        scores_copy2 = fsa2.scores.detach().clone().requires_grad_(True)
        fsa2.attr2 = "k2"
        float_attr2 = torch.tensor([1, 2, 3, 40, 50, 60],
                                   dtype=torch.float32,
                                   requires_grad=True,
                                   device=device)
        # Same reasoning as for fsa1: keep reference and FSA leaves apart.
        fsa2.float_attr = float_attr2.detach().clone().requires_grad_(True)
        fsa2.int_attr = torch.tensor([10, 20, 30, 4, 5, 6],
                                     dtype=torch.int32,
                                     device=device)
        fsa2.ragged_attr = k2.RaggedTensor([[100, 200], [300, 400, 500],
                                            [600, 700], [800], [22],
                                            [33, 55]]).to(device)

        src = k2.create_fsa_vec([fsa1, fsa2])
        dest = k2.remove_epsilon_self_loops(src)

        expected_fsa = k2.Fsa.from_str(expected_text)
        assert k2.to_str_simple(dest[0]) == k2.to_str_simple(expected_fsa)
        assert k2.to_str_simple(dest[1]) == k2.to_str_simple(expected_fsa)

        assert dest.attr1 == fsa1.attr1
        assert dest.attr2 == fsa2.attr2

        expected_int_attr = torch.tensor([1, 2, 5, 6, 10, 20, 5, 6],
                                         dtype=torch.int32,
                                         device=device)
        assert torch.all(torch.eq(dest.int_attr, expected_int_attr))

        expected_ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                                [], [0], [100, 200],
                                                [300, 400, 500], [22],
                                                [33, 55]]).to(device)
        assert dest.ragged_attr == expected_ragged_attr

        # Entries 0..3 come from the first FSA, entries 4..7 from the
        # second one.
        num_kept = len(kept_arcs)

        expected_float_attr = torch.empty_like(dest.float_attr)
        for out_idx, src_idx in enumerate(kept_arcs):
            expected_float_attr[out_idx] = float_attr1[src_idx]
        for out_idx, src_idx in enumerate(kept_arcs, start=num_kept):
            expected_float_attr[out_idx] = float_attr2[src_idx]
        assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

        expected_scores = torch.empty_like(dest.scores)
        for out_idx, src_idx in enumerate(kept_arcs):
            expected_scores[out_idx] = scores_copy1[src_idx]
        for out_idx, src_idx in enumerate(kept_arcs, start=num_kept):
            expected_scores[out_idx] = scores_copy2[src_idx]
        assert torch.all(torch.eq(dest.scores, expected_scores))

        # Same scalar loss on the actual and the expected tensors; the
        # gradients on the respective leaves must agree.
        scale = torch.tensor([10, 20, 30, 40, 50, 60, 70,
                              80]).to(dest.float_attr)

        (dest.float_attr * scale).sum().backward()
        (expected_float_attr * scale).sum().backward()
        assert torch.all(torch.eq(fsa1.float_attr.grad, float_attr1.grad))
        assert torch.all(torch.eq(fsa2.float_attr.grad, float_attr2.grad))

        (dest.scores * scale).sum().backward()
        (expected_scores * scale).sum().backward()
        assert torch.all(torch.eq(fsa1.scores.grad, scores_copy1.grad))
        assert torch.all(torch.eq(fsa2.scores.grad, scores_copy2.grad))
import k2

# Example: invert an FSA whose aux_labels are a ragged tensor and draw the
# FSA before and after the inversion.
fsa_text = '''
0 1 2 0.1
1 2 -1 0.2
2
'''
fsa = k2.Fsa.from_str(fsa_text)

# One sublist of aux labels per arc.
ragged_aux_labels = k2.RaggedTensor('[ [10 20] [-1] ]')
fsa.aux_labels = ragged_aux_labels

inverted_fsa = k2.invert(fsa)

fsa.draw(
    'before_invert_aux.svg',
    title='before invert with ragged tensors as aux_labels',
)
inverted_fsa.draw(
    'after_invert_aux.svg',
    title='after invert',
)
def test(self):
    '''Check `k2.union` on an FsaVec of three FSAs with extra attributes.

    The union starts with one new arc per input FSA (label 0, score 0,
    attribute value 0), followed by the arcs of the inputs with their
    state numbers shifted.  Tensor and ragged attributes must follow the
    same arc order.
    '''
    fsa_texts = (
        '''
            0 1 1 0.1
            0 2 2 0.2
            1 2 3 0.3
            1 3 -1 0.4
            2 3 -1 0.5
            2 1 5 0.55
            3
        ''',
        '''
            0 1 -1 0.6
            1
        ''',
        '''
            0 1 6 0.7
            1 0 7 0.8
            1 0 8 0.9
            1 2 -1 1.0
            2
        ''',
    )
    # One attribute value per arc of the corresponding FSA above.
    attr_values = ([1, 2, 3, 4, 5, 6], [7], [8, 9, 10, 11])

    for device in self.devices:
        parts = [k2.Fsa.from_str(text) for text in fsa_texts]
        for part, values in zip(parts, attr_values):
            part.tensor_attr = torch.tensor(values,
                                            dtype=torch.int32,
                                            device=device)
            # Ragged counterpart: one single-element sublist per arc.
            part.ragged_tensor_attr = k2.RaggedTensor(
                part.tensor_attr.unsqueeze(-1))

        fsa_vec = k2.create_fsa_vec(parts).to(device)
        fsa = k2.union(fsa_vec)

        expected_tensor_attr = torch.tensor(
            [0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10,
             11]).to(fsa.tensor_attr)
        assert torch.all(torch.eq(fsa.tensor_attr, expected_tensor_attr))

        # The ragged attribute is expected to have empty sublists where
        # the tensor attribute has 0.
        expected_ragged_tensor_attr = k2.RaggedTensor(
            expected_tensor_attr.unsqueeze(-1)).remove_values_eq(0)
        assert str(expected_ragged_tensor_attr) == str(
            fsa.ragged_tensor_attr)

        # Columns are (src_state, dest_state, label).
        expected_arcs = torch.tensor([
            [0, 1, 0],  # fsa 0
            [0, 4, 0],  # fsa 1
            [0, 5, 0],  # fsa 2
            # now for fsa0
            [1, 2, 1],
            [1, 3, 2],
            [2, 3, 3],
            [2, 7, -1],
            [3, 7, -1],
            [3, 2, 5],
            # fsa1
            [4, 7, -1],
            # fsa2
            [5, 6, 6],
            [6, 5, 7],
            [6, 5, 8],
            [6, 7, -1]
        ]).to(torch.int32).to(device)
        assert torch.allclose(fsa.arcs.values()[:, :3], expected_arcs)

        expected_scores = torch.tensor([
            0., 0., 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.55, 0.6, 0.7, 0.8,
            0.9, 1.0
        ]).to(device)
        assert torch.allclose(fsa.scores, expected_scores)
def test_autograd_remove_epsilon_and_add_self_loops(self):
    '''Check attributes and autograd of
    `k2.remove_epsilon_and_add_self_loops` on CUDA devices.

    The test returns without checking anything unless
    `torch.cuda.is_available()` and `k2.with_cuda` are both true.
    It runs on cuda:0 and, if more than one device is present, also on
    cuda:1.
    '''
    if not (torch.cuda.is_available() and k2.with_cuda):
        return

    devices = [torch.device('cuda', 0)]
    if torch.cuda.device_count() > 1:
        torch.cuda.set_device(1)
        devices.append(torch.device('cuda', 1))

    fsa_text = '''
        0 1 0 0.1
        0 1 1 0.2
        1 2 -1 0.3
        2
    '''
    for device in devices:
        src = k2.Fsa.from_str(fsa_text).to(device).requires_grad_(True)
        scores_copy = src.scores.detach().clone().requires_grad_(True)
        src.attr1 = "hello"
        src.attr2 = "k2"

        # `float_attr` is the reference leaf; the FSA gets its own copy so
        # that the two gradients are computed independently.
        float_attr = torch.tensor([0.1, 0.2, 0.3],
                                  dtype=torch.float32,
                                  requires_grad=True,
                                  device=device)
        src.float_attr = float_attr.detach().clone().requires_grad_(True)
        src.int_attr = torch.tensor([1, 2, 3],
                                    dtype=torch.int32,
                                    device=device)
        src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                           [60, 70]]).to(device)

        dest = k2.remove_epsilon_and_add_self_loops(src)
        # without add_self_loops, the arc map is [[1] [0 2] [2]]
        # with add_self_loops, the arc map is [[] [1] [0 2] [] [2]]

        assert dest.attr1 == src.attr1
        assert dest.attr2 == src.attr2

        expected_int_attr = k2.RaggedTensor([[], [2], [1, 3], [],
                                             [3]]).to(device)
        assert dest.int_attr == expected_int_attr

        expected_ragged_attr = k2.RaggedTensor([[], [30, 40, 50],
                                                [10, 20, 60, 70], [],
                                                [60, 70]]).to(device)
        assert dest.ragged_attr == expected_ragged_attr

        # Arcs 0 and 3 of the result map to no source arc, so their
        # expected value is 0; arc 2 sums source arcs 0 and 2.
        float_terms = (
            0,
            float_attr[1],
            float_attr[0] + float_attr[2],
            0,
            float_attr[2],
        )
        expected_float_attr = torch.empty_like(dest.float_attr)
        for i, term in enumerate(float_terms):
            expected_float_attr[i] = term
        assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

        score_terms = (
            0,
            scores_copy[1],
            scores_copy[0] + scores_copy[2],
            0,
            scores_copy[2],
        )
        expected_scores = torch.empty_like(dest.scores)
        for i, term in enumerate(score_terms):
            expected_scores[i] = term
        assert torch.all(torch.eq(dest.scores, expected_scores))

        # Same scalar loss on the actual and the expected tensors; the
        # gradients on the respective leaves must agree.
        scale = torch.tensor([10, 20, 30, 40, 50]).to(float_attr)

        (dest.float_attr * scale).sum().backward()
        (expected_float_attr * scale).sum().backward()
        assert torch.all(torch.eq(src.float_attr.grad, float_attr.grad))

        (dest.scores * scale).sum().backward()
        (expected_scores * scale).sum().backward()
        assert torch.all(torch.eq(src.scores.grad, scores_copy.grad))
def test_without_empty_list(self):
    '''Check `k2.utils.fsa_from_unary_function_ragged` with an arc map
    that has no empty sublist.

    The ragged arc map comes from `_k2.remove_epsilon` and is expected to
    be [[1], [0, 2], [2]].  Attributes of the result must combine the
    source entries listed in each sublist, and gradients must flow back
    to the source scores and float attribute.
    '''
    fsa_text = '''
        0 1 0 0
        0 1 1 0
        1 2 -1 0
        2
    '''
    for device in self.devices:
        scores = torch.tensor([1, 2, 3],
                              dtype=torch.float32,
                              device=device,
                              requires_grad=True)
        scores_copy = scores.detach().clone().requires_grad_(True)

        src = k2.Fsa.from_str(fsa_text).to(device)
        src.scores = scores  # see https://git.io/Jufpl
        src.attr1 = "hello"
        src.attr2 = "k2"

        # `float_attr` is the reference leaf; the FSA gets its own copy so
        # that the two gradients are computed independently.
        float_attr = torch.tensor([0.1, 0.2, 0.3],
                                  dtype=torch.float32,
                                  requires_grad=True,
                                  device=device)
        src.float_attr = float_attr.detach().clone().requires_grad_(True)
        src.int_attr = torch.tensor([1, 2, 3],
                                    dtype=torch.int32,
                                    device=device)
        src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                           [60, 70]]).to(device)

        # see https://git.io/Jufpe
        ragged_arc, arc_map = _k2.remove_epsilon(src.arcs, src.properties)
        dest = k2.utils.fsa_from_unary_function_ragged(
            src, ragged_arc, arc_map)

        assert dest.attr1 == src.attr1
        assert dest.attr2 == src.attr2

        expected_arc_map = k2.RaggedTensor([[1], [0, 2], [2]]).to(device)
        assert arc_map == expected_arc_map

        expected_int_attr = k2.RaggedTensor([[2], [1, 3], [3]]).to(device)
        assert dest.int_attr == expected_int_attr

        expected_ragged_attr = k2.RaggedTensor([[30, 40, 50],
                                                [10, 20, 60, 70],
                                                [60, 70]]).to(device)
        assert dest.ragged_attr == expected_ragged_attr

        # Float-valued attributes are summed over each arc-map sublist.
        float_terms = (
            float_attr[1],
            float_attr[0] + float_attr[2],
            float_attr[2],
        )
        expected_float_attr = torch.empty_like(dest.float_attr)
        for i, term in enumerate(float_terms):
            expected_float_attr[i] = term
        assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

        score_terms = (
            scores_copy[1],
            scores_copy[0] + scores_copy[2],
            scores_copy[2],
        )
        expected_scores = torch.empty_like(dest.scores)
        for i, term in enumerate(score_terms):
            expected_scores[i] = term
        assert torch.all(torch.eq(dest.scores, expected_scores))

        # Same scalar loss on the actual and the expected tensors; the
        # gradients on the respective leaves must agree.
        scale = torch.tensor([10, 20, 30]).to(float_attr)

        (dest.float_attr * scale).sum().backward()
        (expected_float_attr * scale).sum().backward()
        assert torch.all(torch.eq(src.float_attr.grad, float_attr.grad))

        (dest.scores * scale).sum().backward()
        (expected_scores * scale).sum().backward()
        assert torch.all(torch.eq(scores.grad, scores_copy.grad))
def test(self):
    '''Check `k2.arc_sort` together with its arc map.

    Sorting swaps the first two arcs (arc map [1, 0, 2]).  Every per-arc
    attribute of the result must be permuted accordingly, non-tensor
    attributes must be copied over, and gradients must flow back to the
    source scores and float attribute.
    '''
    fsa_text = '''
        0 1 2 0.1
        0 1 1 0.2
        1 2 -1 0.3
        2
    '''
    # Source arc index for each arc of the sorted FSA.
    sorted_order = (1, 0, 2)

    for device in self.devices:
        src = k2.Fsa.from_str(fsa_text).to(device)
        src.requires_grad_(True)
        scores_copy = src.scores.detach().clone().requires_grad_(True)
        src.attr1 = "hello"
        src.attr2 = "k2"

        # `float_attr` is the reference leaf; the FSA gets its own copy so
        # that the two gradients are computed independently.
        float_attr = torch.tensor([0.1, 0.2, 0.3],
                                  dtype=torch.float32,
                                  requires_grad=True,
                                  device=device)
        src.float_attr = float_attr.detach().clone().requires_grad_(True)
        src.int_attr = torch.tensor([1, 2, 3],
                                    dtype=torch.int32,
                                    device=device)
        src.ragged_attr = k2.RaggedTensor([[10, 20], [30, 40, 50],
                                           [60, 70]]).to(device)

        dest, arc_map = k2.arc_sort(src, ret_arc_map=True)

        assert dest.attr1 == src.attr1
        assert dest.attr2 == src.attr2

        expected_arc_map = torch.tensor([1, 0, 2],
                                        dtype=torch.int32,
                                        device=device)
        assert torch.all(torch.eq(arc_map, expected_arc_map))

        actual_str = k2.to_str_simple(dest)
        expected_lines = ['0 1 1 0.2', '0 1 2 0.1', '1 2 -1 0.3', '2']
        expected_str = '\n'.join(expected_lines)
        assert actual_str.strip() == expected_str

        expected_int_attr = torch.tensor([2, 1, 3],
                                         dtype=torch.int32,
                                         device=device)
        assert torch.all(torch.eq(dest.int_attr, expected_int_attr))

        expected_ragged_attr = k2.RaggedTensor([[30, 40, 50], [10, 20],
                                                [60, 70]]).to(device)
        assert dest.ragged_attr == expected_ragged_attr

        expected_float_attr = torch.empty_like(dest.float_attr)
        for out_idx, src_idx in enumerate(sorted_order):
            expected_float_attr[out_idx] = float_attr[src_idx]
        assert torch.all(torch.eq(dest.float_attr, expected_float_attr))

        expected_scores = torch.empty_like(dest.scores)
        for out_idx, src_idx in enumerate(sorted_order):
            expected_scores[out_idx] = scores_copy[src_idx]
        assert torch.all(torch.eq(dest.scores, expected_scores))

        # Same scalar loss on the actual and the expected tensors; the
        # gradients on the respective leaves must agree.
        scale = torch.tensor([10, 20, 30]).to(float_attr)

        (dest.float_attr * scale).sum().backward()
        (expected_float_attr * scale).sum().backward()
        assert torch.all(torch.eq(src.float_attr.grad, float_attr.grad))

        (dest.scores * scale).sum().backward()
        (expected_scores * scale).sum().backward()
        assert torch.all(torch.eq(src.scores.grad, scores_copy.grad))