import tempfile
from typing import Sequence

import h5py
import numpy as np
import torch
import torch.nn.functional as F

# TensorList (and, below, EntityList, FeaturizedEmbedding, allocate_shared_tensor,
# FileEdgeAppender) are project-internal helpers; their imports are omitted in
# these excerpts. TensorList is a ragged-tensor container: a flat `data` tensor
# plus an `offsets` tensor such that row i spans data[offsets[i]:offsets[i + 1]].

def read_dynamic(
    hf: h5py.File,
    key: str,
    begin: int,
    end: int,
) -> TensorList:
    try:
        offsets_ds = hf[f"{key}_offsets"]
        data_ds = hf[f"{key}_data"]
    except LookupError:
        # Empty tensor_list representation
        return TensorList(
            offsets=torch.zeros((), dtype=torch.long).expand(end - begin + 1),
            data=torch.empty((0,), dtype=torch.long),
        )

    offsets = torch.empty((end - begin + 1,), dtype=torch.long)
    offsets_ds.read_direct(offsets.numpy(), source_sel=np.s_[begin:end + 1])
    data_begin = offsets[0].item()
    data_end = offsets[-1].item()
    data = torch.empty((data_end - data_begin,), dtype=torch.long)
    # Needed because https://github.com/h5py/h5py/issues/870.
    if data_end - data_begin > 0:
        data_ds.read_direct(data.numpy(), source_sel=np.s_[data_begin:data_end])

    offsets -= int(offsets[0])

    return TensorList(offsets, data)

def read_dynamic(
    hf: h5py.File,
    key: str,
    begin: int,
    end: int,
    *,
    shared: bool = False,
) -> TensorList:
    try:
        offsets_ds = hf[f"{key}_offsets"]
        data_ds = hf[f"{key}_data"]
    except LookupError:
        return TensorList.empty(num_tensors=end - begin)

    # allocate_shared_tensor is a project-internal helper that backs the tensor
    # with shared memory so it can be handed to worker processes.
    allocator = allocate_shared_tensor if shared else torch.empty
    offsets = allocator((end - begin + 1,), dtype=torch.long)
    offsets_ds.read_direct(offsets.numpy(), source_sel=np.s_[begin:end + 1])
    data_begin = offsets[0].item()
    data_end = offsets[-1].item()
    data = allocator((data_end - data_begin,), dtype=torch.long)
    # Needed because https://github.com/h5py/h5py/issues/870.
    if data_end - data_begin > 0:
        data_ds.read_direct(data.numpy(), source_sel=np.s_[data_begin:data_end])

    offsets -= int(offsets[0])

    return TensorList(offsets, data)

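# A minimal round-trip sketch, not from the source (file name and values are
# made up), showing the dataset layout read_dynamic expects: a "{key}_offsets"
# dataset delimiting ragged rows and a "{key}_data" dataset with the flat ids.

with h5py.File("example.h5", "w") as hf:
    # Three ragged rows: [10, 11], [12], [13, 14, 15].
    hf.create_dataset("foo_offsets", data=np.array([0, 2, 3, 6], dtype=np.int64))
    hf.create_dataset("foo_data", data=np.arange(10, 16, dtype=np.int64))

with h5py.File("example.h5", "r") as hf:
    # Reads rows 1 and 2; offsets are re-based to start at zero, so the result
    # holds offsets [0, 1, 4] over data [12, 13, 14, 15].
    tl = read_dynamic(hf, "foo", 1, 3)
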
def test_tensor_list(self):
    with tempfile.NamedTemporaryFile() as bf:
        with h5py.File(bf.name, "w") as hf, FileEdgeAppender(hf) as buffered_hf:
            buffered_hf.append_tensor_list(
                "foo",
                TensorList(
                    torch.tensor([0, 3, 5], dtype=torch.long),
                    torch.tensor([1, 2, 3, 4, 5], dtype=torch.long),
                ),
            )
            buffered_hf.append_tensor_list(
                "bar",
                TensorList(
                    torch.tensor([0, 1_000_000], dtype=torch.long),
                    torch.arange(1_000_000, dtype=torch.long),
                ),
            )
            buffered_hf.append_tensor_list(
                "foo",
                TensorList(
                    torch.tensor([0, 1, 1, 3], dtype=torch.long),
                    torch.tensor([6, 7, 8], dtype=torch.long),
                ),
            )

def get(self, input_: TensorList) -> FloatTensorType:
    if input_.size(0) == 0:
        return torch.empty((0, self.weight.size(1)))
    # embedding_bag wants one offset per bag, while a TensorList keeps
    # num_tensors + 1 offsets, hence the [:-1]. The default mode averages
    # each entity's feature embeddings.
    return F.embedding_bag(
        input_.data.long(),
        self.weight,
        input_.offsets[:-1],
        max_norm=self.max_norm,
        sparse=True,
    )

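# A minimal sketch of what the call above computes (weights and ids made up):
# F.embedding_bag with the default mode="mean" averages the embeddings of each
# bag, where offsets mark where each bag starts in the flat id list.

weight = torch.tensor([[1.0, 1.0], [3.0, 3.0]])
data = torch.tensor([0, 1, 1])   # flattened feature ids
offsets = torch.tensor([0, 2])   # bag 0 = ids[0:2], bag 1 = ids[2:]
out = F.embedding_bag(data, weight, offsets)
# out == tensor([[2., 2.], [3., 3.]]) -- per-bag means
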
def test_forward(self):
    embeddings = torch.tensor(
        [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], requires_grad=True
    )
    module = FeaturizedEmbedding(weight=embeddings)
    result = module(
        EntityList.from_tensor_list(
            TensorList(
                torch.tensor([0, 1, 3, 6, 6]), torch.tensor([0, 2, 1, 0, 1, 0])
            )
        )
    )
    self.assertTensorEqual(
        result,
        torch.tensor(
            [
                [1.0000, 1.0000, 1.0000],
                [2.5000, 2.5000, 2.5000],
                [1.3333, 1.3333, 1.3333],
                [0.0000, 0.0000, 0.0000],
            ]
        ),
    )
    result.sum().backward()
    self.assertTrue((embeddings.grad.to_dense() != 0).any())

def test_from_tensor(self):
    self.assertEqual(
        EntityList.from_tensor(torch.tensor([3, 4], dtype=torch.long)),
        EntityList(
            torch.tensor([3, 4], dtype=torch.long), TensorList.empty(num_tensors=2)
        ),
    )

def test_empty(self):
    embeddings = torch.empty((0, 3))
    module = FeaturizedEmbedding(weight=embeddings)
    self.assertTensorEqual(
        module(
            EntityList.from_tensor_list(
                TensorList(
                    torch.zeros((1,), dtype=torch.long),
                    torch.empty((0,), dtype=torch.long),
                )
            )
        ),
        torch.empty((0, 3)),
    )

def test_max_norm(self):
    embeddings = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
    module = FeaturizedEmbedding(weight=embeddings, max_norm=2)
    # Rows with norm above 2 are rescaled before averaging, e.g. [2, 2, 2]
    # (norm 2 * sqrt(3)) becomes 2 / sqrt(3) = 1.1547 per coordinate.
    self.assertTensorEqual(
        module(
            EntityList.from_tensor_list(
                TensorList(
                    torch.tensor([0, 1, 3, 6, 6]), torch.tensor([0, 2, 1, 0, 1, 0])
                )
            )
        ),
        torch.tensor(
            [
                [1.0000, 1.0000, 1.0000],
                [1.1547, 1.1547, 1.1547],
                [1.0516, 1.0516, 1.0516],
                [0.0000, 0.0000, 0.0000],
            ]
        ),
    )

def test_empty(self):
    self.assertEqual(
        EntityList.empty(),
        EntityList(torch.empty((0,), dtype=torch.long), TensorList.empty()),
    )

def tensor_list_from_lists(lists: Sequence[Sequence[int]]) -> TensorList:
    # Offsets are the cumulative lengths, so row i spans
    # data[offsets[i]:offsets[i + 1]].
    offsets = torch.tensor(
        [0] + [len(l) for l in lists], dtype=torch.long
    ).cumsum(dim=0)
    data = torch.cat([torch.tensor(l, dtype=torch.long) for l in lists], dim=0)
    return TensorList(offsets, data)

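# Hedged example (values made up): ragged lists are packed into one flat data
# tensor plus cumulative offsets, with empty rows taking up no data.

tl = tensor_list_from_lists([[1, 2, 3], [], [4, 5]])
# tl.offsets == tensor([0, 3, 3, 5]); tl.data == tensor([1, 2, 3, 4, 5])
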
@classmethod
def cat(cls, entity_lists: Sequence["EntityList"]) -> "EntityList":
    return cls(
        torch.cat([el.tensor for el in entity_lists]),
        TensorList.cat(el.tensor_list for el in entity_lists),
    )

@classmethod
def from_tensor(cls, tensor: LongTensorType) -> "EntityList":
    if tensor.dim() != 1:
        raise ValueError("Expected 1D tensor, got %dD" % tensor.dim())
    tensor_list = TensorList.empty(num_tensors=tensor.shape[0])
    return cls(tensor, tensor_list)

@classmethod
def empty(cls) -> "EntityList":
    return cls(torch.empty((0,), dtype=torch.long), TensorList.empty())