示例#1
0
 def __init__(self, lhs: EntityList, rhs: EntityList, rel: LongTensorType) -> None:
     """Validate and store the two endpoint lists and the relation tensor.

     Args:
         lhs: Left-hand side entities; must be an ``EntityList``.
         rhs: Right-hand side entities; must be an ``EntityList`` of the
             same length as ``lhs``.
         rel: Long tensor (``torch.LongTensor`` or
             ``torch.cuda.LongTensor``) that is either 0-dimensional or
             1-dimensional with exactly ``len(lhs)`` entries.

     Raises:
         TypeError: If ``lhs``/``rhs`` are not entity lists, or ``rel`` is
             not a long tensor.
         ValueError: If ``lhs`` and ``rhs`` differ in length, if ``rel`` has
             more than one dimension, or if a 1-D ``rel`` does not match the
             length of ``lhs``.
     """
     if not (isinstance(lhs, EntityList) and isinstance(rhs, EntityList)):
         raise TypeError(
             "Expected left- and right-hand side to be entity lists, got "
             "%s and %s instead" % (type(lhs), type(rhs))
         )

     accepted_rel_types = (torch.LongTensor, torch.cuda.LongTensor)
     if not isinstance(rel, accepted_rel_types):
         raise TypeError("Expected relation to be a long tensor, got %s" % type(rel))

     num_lhs, num_rhs = len(lhs), len(rhs)
     if num_lhs != num_rhs:
         raise ValueError(
             "The left- and right-hand side entity lists have different "
             "lengths: %d != %d" % (num_lhs, num_rhs)
         )

     # A 0-dimensional ``rel`` skips the length comparison below (presumably
     # one relation shared by all edges; confirm with callers). Anything with
     # more than one dimension is rejected outright.
     rel_ndim = rel.dim()
     if rel_ndim > 1:
         raise ValueError(
             "The relation can be either a scalar or a 1-dimensional "
             "tensor, got a %d-dimensional tensor" % rel_ndim
         )
     if rel_ndim == 1:
         num_rels = rel.shape[0]
         if num_rels != num_lhs:
             raise ValueError(
                 "The relation has a different length than the entity lists: "
                 "%d != %d" % (num_rels, num_lhs)
             )

     self.lhs, self.rhs, self.rel = lhs, rhs, rel
示例#2
0
文件: entitylist.py 项目: RweBs/PDKE
 def __init__(self, tensor: LongTensorType, tensor_list: TensorList) -> None:
     """Pair a 1-D long tensor with a tensor list of the same length.

     Args:
         tensor: A ``torch.LongTensor`` with exactly one dimension. The
             isinstance check below does not accept ``torch.cuda.LongTensor``.
         tensor_list: A ``TensorList`` with one entry per element of
             ``tensor``.

     Raises:
         TypeError: If ``tensor`` is not a ``torch.LongTensor`` or
             ``tensor_list`` is not a ``TensorList``.
         ValueError: If ``tensor`` is not 1-dimensional or its length differs
             from ``len(tensor_list)``.
     """
     if not isinstance(tensor, torch.LongTensor):
         raise TypeError(
             "Expected long tensor as first argument, got %s" % type(tensor))
     if not isinstance(tensor_list, TensorList):
         raise TypeError(
             "Expected tensor list as second argument, got %s"
             % type(tensor_list))

     ndim = tensor.dim()
     if ndim != 1:
         raise ValueError(
             "Expected 1-dimensional tensor, got %d-dimensional one" % ndim)

     length, list_length = tensor.shape[0], len(tensor_list)
     if length != list_length:
         raise ValueError(
             "The tensor and tensor list have different lengths: %d != %d"
             % (length, list_length))

     # TODO: We could verify that, for every i, either tensor[i] < 0 or
     # tensor_list[i] is empty. However, that is expensive, and something
     # similar already happens at retrieval time inside to_tensor(_list).
     self.tensor: LongTensorType = tensor
     self.tensor_list: TensorList = tensor_list
示例#3
0
文件: entitylist.py 项目: RweBs/PDKE
 def from_tensor(cls, tensor: LongTensorType) -> 'EntityList':
     """Build an instance of ``cls`` from a 1-D tensor alone.

     The tensor is paired with ``TensorList.empty(num_tensors=len)``, where
     ``len`` is the tensor's length. Presumably this means one empty entry
     per element; confirm against ``TensorList.empty``.

     Args:
         tensor: A 1-dimensional tensor. Further type validation is left to
             the ``cls`` constructor.

     Returns:
         A new ``cls`` instance wrapping ``tensor``.

     Raises:
         ValueError: If ``tensor`` is not 1-dimensional.
     """
     ndim = tensor.dim()
     if ndim == 1:
         return cls(tensor, TensorList.empty(num_tensors=tensor.shape[0]))
     raise ValueError("Expected 1D tensor, got %dD" % ndim)