def __init__(self, lhs: EntityList, rhs: EntityList, rel: LongTensorType) -> None:
    """Validate and store the two entity lists and the relation tensor.

    Args:
        lhs: entity list for the left-hand side.
        rhs: entity list for the right-hand side; must have the same
            length as ``lhs``.
        rel: long tensor (CPU or CUDA) holding the relation. Either a
            scalar (0-dimensional) or a 1-dimensional tensor whose length
            equals the length of the entity lists.

    Raises:
        TypeError: if ``lhs``/``rhs`` are not entity lists or ``rel`` is
            not a long tensor.
        ValueError: if the lengths of ``lhs`` and ``rhs`` differ, if
            ``rel`` has more than one dimension, or if a 1-dimensional
            ``rel`` does not match the entity lists' length.
    """
    # Type checks come first so that the len()/dim() calls below are only
    # made on objects of the expected kind.
    both_entity_lists = isinstance(lhs, EntityList) and isinstance(rhs, EntityList)
    if not both_entity_lists:
        raise TypeError(
            "Expected left- and right-hand side to be entity lists, got "
            "%s and %s instead" % (type(lhs), type(rhs))
        )
    long_tensor_types = (torch.LongTensor, torch.cuda.LongTensor)
    if not isinstance(rel, long_tensor_types):
        raise TypeError("Expected relation to be a long tensor, got %s" % type(rel))

    num_lhs = len(lhs)
    num_rhs = len(rhs)
    if num_lhs != num_rhs:
        raise ValueError(
            "The left- and right-hand side entity lists have different "
            "lengths: %d != %d" % (num_lhs, num_rhs)
        )

    rel_ndim = rel.dim()
    if rel_ndim > 1:
        raise ValueError(
            "The relation can be either a scalar or a 1-dimensional "
            "tensor, got a %d-dimensional tensor" % rel_ndim
        )
    if rel_ndim == 1:
        # A per-edge relation must line up with the entity lists; a scalar
        # relation (0 dimensions) needs no length check.
        num_rel = rel.shape[0]
        if num_rel != num_lhs:
            raise ValueError(
                "The relation has a different length than the entity lists: "
                "%d != %d" % (num_rel, num_lhs)
            )

    self.lhs = lhs
    self.rhs = rhs
    self.rel = rel
def __init__(self, tensor: LongTensorType, tensor_list: TensorList) -> None:
    """Validate and store a long tensor together with a parallel tensor list.

    Args:
        tensor: 1-dimensional long tensor, on CPU or CUDA.
        tensor_list: tensor list with exactly as many items as ``tensor``
            has elements.

    Raises:
        TypeError: if ``tensor`` is not a long tensor or ``tensor_list``
            is not a ``TensorList``.
        ValueError: if ``tensor`` is not 1-dimensional or its length
            differs from the length of ``tensor_list``.
    """
    # Accept CUDA long tensors as well as CPU ones: this is the same pair
    # of types that the constructor taking (lhs, rhs, rel) accepts for its
    # relation tensor, so entity lists and relations can live on the same
    # device.
    # NOTE(review): the other methods of this class are not visible here;
    # confirm they handle a CUDA ``tensor`` before relying on this.
    if not isinstance(tensor, (torch.LongTensor, torch.cuda.LongTensor)):
        raise TypeError("Expected long tensor as first argument, got %s" % type(tensor))
    if not isinstance(tensor_list, TensorList):
        raise TypeError("Expected tensor list as second argument, got %s" % type(tensor_list))
    if tensor.dim() != 1:
        raise ValueError(
            "Expected 1-dimensional tensor, got %d-dimensional one" % tensor.dim())
    if tensor.shape[0] != len(tensor_list):
        raise ValueError(
            "The tensor and tensor list have different lengths: %d != %d"
            % (tensor.shape[0], len(tensor_list)))

    # TODO We could check that, for all i, we have either tensor[i] < 0 or
    # tensor_list[i] empty, however it's expensive and we're already doing
    # something similar at retrieval inside to_tensor(_list).

    self.tensor: LongTensorType = tensor
    self.tensor_list: TensorList = tensor_list
def from_tensor(cls, tensor: LongTensorType) -> 'EntityList':
    """Build an instance from a 1-dimensional tensor and an empty tensor list.

    The tensor list is created with ``TensorList.empty`` and has one slot
    per element of ``tensor``.

    Args:
        tensor: 1-dimensional tensor to wrap.

    Returns:
        The result of calling ``cls(tensor, tensor_list)``.

    Raises:
        ValueError: if ``tensor`` is not 1-dimensional.
    """
    ndim = tensor.dim()
    if ndim != 1:
        raise ValueError("Expected 1D tensor, got %dD" % ndim)
    return cls(tensor, TensorList.empty(num_tensors=tensor.shape[0]))