def read_dynamic(
    hf: h5py.File,
    key: str,
    begin: int,
    end: int,
) -> TensorList:
    """Read rows [begin, end) of the dynamically-sized list stored under key.

    The list is persisted as two datasets, "<key>_offsets" and "<key>_data".
    If they are absent, that encodes an empty tensor_list (all-zero offsets,
    no data).
    """
    num_rows = end - begin
    try:
        offsets_dataset = hf[f"{key}_offsets"]
        data_dataset = hf[f"{key}_data"]
    except LookupError:
        # Missing datasets are the representation of an empty tensor_list.
        all_zero_offsets = torch.zeros((), dtype=torch.long).expand(num_rows + 1)
        return TensorList(
            offsets=all_zero_offsets,
            data=torch.empty((0,), dtype=torch.long),
        )
    # One more offset than rows: each row's extent is a pair of consecutive
    # offsets.
    offsets = torch.empty((num_rows + 1,), dtype=torch.long)
    offsets_dataset.read_direct(offsets.numpy(), source_sel=np.s_[begin:end + 1])
    first = offsets[0].item()
    last = offsets[-1].item()
    data = torch.empty((last - first,), dtype=torch.long)
    # h5py rejects zero-length selections (see
    # https://github.com/h5py/h5py/issues/870), so skip the read in that case.
    if last > first:
        data_dataset.read_direct(data.numpy(), source_sel=np.s_[first:last])
    # Rebase the offsets so they index into the data slice we just read.
    offsets -= int(offsets[0])
    return TensorList(offsets, data)
def test_tensor_list(self):
    """Appending tensor lists (including a large one and a repeated key) works."""
    appends = [
        ("foo", TensorList(
            torch.tensor([0, 3, 5], dtype=torch.long),
            torch.tensor([1, 2, 3, 4, 5], dtype=torch.long),
        )),
        ("bar", TensorList(
            torch.tensor([0, 1_000_000], dtype=torch.long),
            torch.arange(1_000_000, dtype=torch.long),
        )),
        ("foo", TensorList(
            torch.tensor([0, 1, 1, 3], dtype=torch.long),
            torch.tensor([6, 7, 8], dtype=torch.long),
        )),
    ]
    with tempfile.NamedTemporaryFile() as tmp_file:
        with h5py.File(tmp_file.name, "w") as h5f, FileEdgeAppender(h5f) as appender:
            for key, tensor_list in appends:
                appender.append_tensor_list(key, tensor_list)
def get(self, input_: TensorList) -> FloatTensorType:
    """Pool the feature embeddings of each entity in input_ into one vector.

    Returns an empty (0, dim) tensor when there are no entities, since
    embedding_bag cannot be called with zero bags.
    """
    num_entities = input_.size(0)
    if num_entities == 0:
        return torch.empty((0, self.weight.size(1)))
    # The offsets tensor carries a trailing total-length entry;
    # embedding_bag only wants the per-bag start positions.
    bag_starts = input_.offsets[:-1]
    return F.embedding_bag(
        input_.data.long(),
        self.weight,
        bag_starts,
        max_norm=self.max_norm,
        sparse=True,
    )
def test_empty(self):
    """A featurized embedding over zero entities yields an empty (0, dim) result."""
    weights = torch.empty((0, 3))
    module = FeaturizedEmbedding(weight=weights)
    no_entities = EntityList.from_tensor_list(TensorList(
        torch.zeros((1,), dtype=torch.long),
        torch.empty((0,), dtype=torch.long),
    ))
    self.assertTensorEqual(module(no_entities), torch.empty((0, 3)))
def new_with_tensor(
    cls: Type[EntityListType],
    tensor: FloatTensorType,
) -> EntityListType:
    """Build an entity list from a (squeezable-to-)1-D tensor, with no features.

    Raises:
        ValueError: if the tensor is not 1-D after squeezing.
    """
    tensor = tensor.squeeze()
    # Validate with an exception rather than assert (asserts are stripped
    # under -O); matches the check and message style of from_tensor.
    if tensor.ndimension() != 1:
        raise ValueError("Expected 1D tensor, got %dD" % tensor.ndimension())
    # An empty tensor_list: N+1 zero offsets (expanded from a scalar, so no
    # real allocation) and no data.
    tensor_list = TensorList(
        torch.zeros((), dtype=torch.long).expand(tensor.nelement() + 1),
        torch.empty((0,), dtype=torch.long),
    )
    return cls(tensor, tensor_list)
def test_max_norm(self):
    """Rows whose norm exceeds max_norm are renormalized before pooling."""
    weights = torch.tensor([
        [1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.],
    ])
    module = FeaturizedEmbedding(weight=weights, max_norm=2)
    input_list = TensorList(
        torch.tensor([0, 1, 3, 6, 6]),
        torch.tensor([0, 2, 1, 0, 1, 0]),
    )
    expected = torch.tensor([
        [1.0000, 1.0000, 1.0000],
        [1.1547, 1.1547, 1.1547],
        [1.0516, 1.0516, 1.0516],
        [0.0000, 0.0000, 0.0000],
    ])
    self.assertTensorEqual(
        module(EntityList.from_tensor_list(input_list)), expected)
def test_forward(self):
    """Per-entity pooled embeddings are correct and gradients reach the weights."""
    weights = torch.tensor([
        [1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.],
    ], requires_grad=True)
    module = FeaturizedEmbedding(weight=weights)
    output = module(EntityList.from_tensor_list(TensorList(
        torch.tensor([0, 1, 3, 6, 6]),
        torch.tensor([0, 2, 1, 0, 1, 0]),
    )))
    expected = torch.tensor([
        [1.0000, 1.0000, 1.0000],
        [2.5000, 2.5000, 2.5000],
        [1.3333, 1.3333, 1.3333],
        [0.0000, 0.0000, 0.0000],
    ])
    self.assertTensorEqual(output, expected)
    # Backprop should produce a nonzero (sparse) gradient on the weights.
    output.sum().backward()
    self.assertTrue((weights.grad.to_dense() != 0).any())
def test_from_tensor(self):
    """from_tensor pairs the given tensor with an all-empty feature list."""
    actual = EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long))
    expected = EntityList(
        torch.tensor([3, 4], dtype=torch.long),
        TensorList.empty(num_tensors=2),
    )
    self.assertEqual(actual, expected)
def test_empty(self):
    """EntityList.empty() equals an explicitly-constructed empty list."""
    expected = EntityList(
        torch.empty((0,), dtype=torch.long),
        TensorList.empty(),
    )
    self.assertEqual(EntityList.empty(), expected)
def tensor_list_from_lists(lists: Sequence[Sequence[int]]) -> TensorList:
    """Build a TensorList from a sequence of integer sequences.

    Offsets are the cumulative lengths (starting at 0); data is the
    concatenation of all the inner sequences.
    """
    offsets = torch.tensor(
        [0] + [len(inner) for inner in lists], dtype=torch.long,
    ).cumsum(dim=0)
    # Flatten in Python rather than torch.cat-ing per-list tensors:
    # torch.cat raises on an empty list of tensors, i.e. when `lists` is
    # empty, whereas this handles it and yields an empty long tensor.
    data = torch.tensor(
        [value for inner in lists for value in inner], dtype=torch.long,
    )
    return TensorList(offsets, data)
def cat(cls, entity_lists: List['EntityList']) -> 'EntityList':
    """Concatenate several EntityLists into one, preserving order."""
    tensors = [el.tensor for el in entity_lists]
    tensor_lists = (el.tensor_list for el in entity_lists)
    return cls(torch.cat(tensors), TensorList.cat(tensor_lists))
def from_tensor_list(cls, tensor_list: TensorList) -> 'EntityList':
    """Build an EntityList whose entity ids are all -1 placeholders."""
    num_entities = tensor_list.nelement()
    placeholder_ids = torch.full((num_entities,), -1, dtype=torch.long)
    return cls(placeholder_ids, tensor_list)
def from_tensor(cls, tensor: LongTensorType) -> 'EntityList':
    """Wrap a plain 1-D entity tensor, attaching empty feature lists.

    Raises:
        ValueError: if the tensor is not 1-D.
    """
    if tensor.ndimension() != 1:
        raise ValueError("Expected 1D tensor, got %dD" % tensor.ndimension())
    empty_features = TensorList.empty(num_tensors=tensor.nelement())
    return cls(tensor, empty_features)
def empty(cls) -> 'EntityList':
    """An EntityList with no entities and no features."""
    no_entities = torch.empty((0,), dtype=torch.long)
    return cls(no_entities, TensorList.empty())