def test_struct_should_return_element_when_indexing_with_string():
    t = TensorStruct.ones({'a': (10, 2), 'b': {'c': (5, 4)}})
    t_a = t['a']
    assert isinstance(t_a, torch.Tensor)
    assert t_a.shape == (10, 2)

    t_b = t['b']
    assert isinstance(t_b, TensorStruct)
    assert t_b['c'].shape == (5, 4)
def test_struct_should_update_dict_if_valid_struct_given():
    t = TensorStruct.ones({'a': {'b': (10, 1), 'c': (10, 2)}})
    new_data = TensorStruct({
        'b': torch.zeros_like(t['a']['b']),
        'c': torch.zeros_like(t['a']['c'])
    })
    t['a'] = new_data
    assert t['a']['b'][0, 0] == 0
    assert t['a']['c'][0, 0] == 0
def test_struct_should_update_values_from_struct_at_given_indices():
    t = TensorStruct.ones({'a': {'b': (10, 1), 'c': (10, 2)}})
    new_data = TensorStruct({
        'b': torch.zeros((5, 1)),
        'c': torch.zeros((5, 2))
    })
    t['a'][2:7] = new_data
    assert t['a']['b'][2, 0] == 0
    assert t['a']['c'][2, 0] == 0
def test_cat_should_cat_nested_tensors():
    ts = [TensorStruct.ones({'a': (10, 2), 'b': (10, 3)}) for _ in range(5)]
    ts_ = cat(ts, dim=0)
    assert ts_['a'].shape == (50, 2)
    assert ts_['b'].shape == (50, 3)
def test_struct_should_raise_if_updating_with_invalid_struct():
    t = TensorStruct.ones({'a': {'b': (10, 1), 'c': (10, 2)}})
    new_data = TensorStruct({'b': torch.zeros((10, 1))})

    with pytest.raises(ValueError):
        t['a'] = new_data
def test_struct_should_update_tensor_if_tensor_given():
    t = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    t['a'] = torch.zeros_like(t['a'])
    assert t['a'][0, 0] == 0
def test_struct_should_raise_when_given_invalid_key():
    t = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    with pytest.raises(KeyError):
        _ = t['c']
def test_struct_should_raise_if_updating_invalid_key():
    t = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    with pytest.raises(KeyError):
        t['c'] = 5
def test_struct_should_raise_when_given_index_other_than_string_int_slice_tensor(
):
    t = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    with pytest.raises(ValueError):
        _ = t[None]
def test_struct_should_return_narrowed_elements_when_indexing_with_tensor():
    t = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    t_29 = t[torch.arange(2, 9, dtype=torch.long)]
    assert isinstance(t_29, TensorStruct)
    assert t_29['a'].shape == (7, 2)
    assert t_29['b'].shape == (7, 1)
def test_struct_should_return_narrowed_elements_when_indexing_with_slice():
    t = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    t_05 = t[:5]
    assert isinstance(t_05, TensorStruct)
    assert t_05['a'].shape == (5, 2)
    assert t_05['b'].shape == (5, 1)
def test_struct_should_return_single_elements_when_indexing_with_int():
    t = TensorStruct.ones({'a': (10, 2), 'b': (10, 1)})
    t_1 = t[1]
    assert isinstance(t_1, TensorStruct)
    assert t_1['a'].shape == (2, )
    assert t_1['b'].shape == (1, )
def test_struct_should_raise_if_using_unsupported_assignment():
    t = TensorStruct.ones({'a': {'b': (10, 1), 'c': (10, 2)}})
    with pytest.raises(ValueError):
        t['a'] = 7