def test_tensor_dict_constructor():
    """Check the node types produced by ``TensorDict.from_tree``.

    The input tree mixes a torch tensor leaf with a numpy array leaf nested
    two dicts deep. The test checks that both leaves read back as torch
    tensors and that each intermediate dict reads back as a ``TensorDict``.
    """
    source_tree = {
        "a": torch.randn(2, 2),
        "b": {"c": {"d": np.random.randn(3, 3)}},
    }
    td = TensorDict.from_tree(source_tree)

    # Top-level tensor leaf.
    assert torch.is_tensor(td["a"])

    # Each nested dict level is exposed as a TensorDict.
    level_b = td["b"]
    assert isinstance(level_b, TensorDict)
    level_c = td["b"]["c"]
    assert isinstance(level_c, TensorDict)

    # The numpy leaf is expected to come back as a torch tensor.
    leaf_d = td["b"]["c"]["d"]
    assert torch.is_tensor(leaf_d)
def test_tensor_dict_map():
    """Check ``map`` and ``map_in_place`` against each other.

    Starting from a single leaf holding ``[0]``, ``map`` with an increment
    function must give a result whose leaf is all ones. Applying the same
    function via ``map_in_place`` on the source must then make the source
    compare equal to that result.
    """

    def increment(value):
        return value + 1

    source = TensorDict.from_tree({"a": {"b": [0]}})

    # Returned result carries the incremented leaf.
    mapped = source.map(increment)
    assert (mapped["a"]["b"] == 1).all()

    # In-place variant: afterwards the source matches the mapped result.
    source.map_in_place(increment)
    assert mapped == source
def test_tensor_dict_str_index():
    """Check string-key assignment and lookup on a ``TensorDict``.

    Assigning a tensor to an existing key must make that key read back the
    assigned values, and reading a key that is absent at the top level must
    raise ``KeyError``.
    """
    td = TensorDict.from_tree(
        {
            "a": torch.randn(2, 2),
            "b": {"c": {"d": torch.randn(3, 3)}},
        }
    )

    # Overwrite "a" with a tensor of a different shape and read it back.
    replacement = torch.randn(5, 5)
    td["a"] = replacement
    assert (td["a"] == replacement).all()

    # "c" exists only under "b", not at the top level.
    with pytest.raises(KeyError):
        _ = td["c"]
def test_tensor_dict_index():
    """Check integer and slice assignment on a nested ``TensorDict``.

    Steps, in order:

    1. Assigning at index 0 with a dict whose keys do not match the subtree
       raises ``KeyError``.
    2. Assigning at index 0 with a matching dict writes row 0 of the leaf and
       leaves row 1 different.
    3. After adding an extra leaf ``"x"`` to the subtree, index assignment
       with the (now incomplete) dict raises ``KeyError``, while
       ``set(..., strict=False)`` succeeds and writes row 1.
    4. After deleting ``"x"``, slice assignment ``[2:3]`` writes the slice.
    """
    td = TensorDict.from_tree(
        {
            "a": torch.randn(2, 2),
            "b": {"c": {"d": torch.randn(3, 3)}},
        }
    )

    # Step 1: key "q" is not present under "b".
    with pytest.raises(KeyError):
        td["b"][0] = {"q": torch.randn(3)}

    # Step 2: matching structure updates row 0 only.
    row_update = {"c": {"d": torch.randn(3)}}
    td["b"][0] = row_update
    assert torch.allclose(td["b"]["c"]["d"][0], row_update["c"]["d"])
    assert not torch.allclose(td["b"]["c"]["d"][1], row_update["c"]["d"])

    # Step 3: with an extra leaf present, the update no longer covers every
    # key; plain index assignment rejects it, the non-strict setter does not.
    td["b"]["c"]["x"] = torch.randn(5, 5)
    with pytest.raises(KeyError):
        td["b"][1] = row_update
    td["b"].set(1, row_update, strict=False)
    assert torch.allclose(td["b"]["c"]["d"][1], row_update["c"]["d"])

    # Step 4: build the slice update, drop the extra leaf, assign by slice.
    slice_update = {"c": {"d": torch.randn(1, 3)}}
    del td["b"]["c"]["x"]
    td["b"][2:3] = slice_update
    assert torch.allclose(td["b"]["c"]["d"][2:3], slice_update["c"]["d"])
def _assert_trees_equal(expected, actual):
    """Assert that two nested dict trees have equal keys and tensor leaves.

    Dict nodes are compared key-set first and then recursed into; every
    other node is treated as a tensor leaf and compared with ``torch.equal``.
    """
    if isinstance(expected, dict):
        assert isinstance(actual, dict)
        assert expected.keys() == actual.keys()
        for key, expected_child in expected.items():
            _assert_trees_equal(expected_child, actual[key])
    else:
        assert torch.equal(expected, actual)


def test_tensor_dict_to_tree():
    """Check that ``from_tree`` followed by ``to_tree`` round-trips a tree.

    The comparison is done leaf by leaf rather than with ``dict ==``: dict
    equality on tensor values only succeeds through Python's identity
    shortcut, and raises ``RuntimeError`` (ambiguous truth value) as soon as
    a leaf is an equal-but-distinct multi-element tensor.
    """
    dict_tree = {
        "a": torch.randn(2, 2),
        "b": {"c": {"d": torch.randn(3, 3)}},
    }
    round_tripped = TensorDict.from_tree(dict_tree).to_tree()
    _assert_trees_equal(dict_tree, round_tripped)