예제 #1
0
def convert_triple_tuple_to_torch(batch, config, enable_cuda_override=None):
    """Run every element of ``batch`` through a ``_BatchElementConverter``.

    Args:
        batch: Iterable of batch elements (e.g. the h, r, t arrays); each is
            converted in order.
        config: Configuration object. Only ``config.enable_cuda`` is read,
            and only when ``enable_cuda_override`` is None.
        enable_cuda_override: When not None, used instead of
            ``config.enable_cuda`` to build the converter.

    Returns:
        A tuple of converted elements in the same order as ``batch``
        (presumably torch tensors, given the name — confirm against
        ``_BatchElementConverter``).
    """
    deprecation("Not tested anymore", since="0.3.0")
    use_cuda = (config.enable_cuda if enable_cuda_override is None
                else enable_cuda_override)
    convert = _BatchElementConverter(use_cuda)
    return tuple(convert(element) for element in batch)
예제 #2
0
def expand_triple_to_sets(triple, num_expands, arange_target):
    """Tile a single triple into large parallel arrays for testing.

    The element selected by ``arange_target`` is replaced by
    ``np.arange(num_expands)``. The other two elements are each repeated
    ``num_expands`` times. All arrays are int64.

    Args:
        triple: An (h, r, t) triple to unpack.
        num_expands: Length of each returned array.
        arange_target: A ``constants.TripleElement`` value (HEAD, RELATION or
            TAIL) naming the element that gets the range.

    Returns:
        A tuple ``(h, r, t)``; each element has shape (num_expands,).

    Raises:
        RuntimeError: if ``arange_target`` is not a TripleElement value.
    """
    deprecation("Not tested anymore", since="0.3.0")

    element = constants.TripleElement
    if not element.has_value(arange_target):
        raise RuntimeError(
            "arange_target is set to wrong value. It has to be one of TripleElement but it's "
            + str(arange_target))

    head, relation, tail = triple

    def repeated(value):
        # A one-element int64 array tiled to the full length.
        return np.tile(np.array([value], dtype=np.int64), num_expands)

    def ranged():
        return np.arange(num_expands, dtype=np.int64)

    if arange_target == element.HEAD:
        return (ranged(), repeated(relation), repeated(tail))
    if arange_target == element.RELATION:
        return (repeated(head), ranged(), repeated(tail))
    if arange_target == element.TAIL:
        return (repeated(head), repeated(relation), ranged())

    # Defensive: has_value() above should make this unreachable.
    raise RuntimeError(
        "Miracle happened. arange_target passed the validation and reached impossible branch."
    )
예제 #3
0
def label_prediction_collate(sample):
    """Attach an all-positive label array to a collated sample.

    Args:
        sample: A ``(tiled, batch, splits)`` tuple. ``tiled`` must expose
            ``shape`` (presumably a numpy array — confirm with callers).

    Returns:
        ``(tiled, batch, splits, labels)``. ``labels`` is a 1-D int64 array
        of ones whose length is ``tiled.shape[0]``.
    """
    deprecation("Not tested anymore", since="0.3.0")

    tiled, batch, splits = sample
    # Every row is treated as a positive example, so every label is 1.
    num_rows = tiled.shape[0]
    labels = np.ones(num_rows, dtype=np.int64)
    return tiled, batch, splits, labels
예제 #4
0
 def __init__(self, source, negative_sampler, literals, transforms,
              sample_negative_for_non_triples=False):
     """Store the sampling pipeline components (marked work in progress).

     Args:
         source: Stored as ``self.source``.
         negative_sampler: Stored as ``self.negative_sampler``.
         literals: Stored as ``self.literals``.
         transforms: Stored as ``self.transforms``.
         sample_negative_for_non_triples: Stored as
             ``self.sample_negative_for_non_triples``. Presumably enables
             negative sampling for non-triple samples; confirm where it is read.
     """
     deprecation("WIP", since="0.3.0")
     self.source, self.negative_sampler = source, negative_sampler
     self.literals = literals
     self.sample_negative_for_non_triples = sample_negative_for_non_triples
     self.transforms = transforms
예제 #5
0
def get_triples_from_batch(batch):
    """Split a batch into one array per column (h, r, t and a possible label).

    Args:
        batch: Array with ``ndim`` and ``shape``. A 3-D batch is read as
            (batch_size, num_samples, num_element). Any other batch must be
            2-D (batch_size, num_element); otherwise unpacking its shape
            raises ValueError.

    Returns:
        A generator yielding ``num_element`` arrays. Each has shape
        (batch_size, num_samples) for 3-D input or (batch_size,) for 2-D
        input. The split happens eagerly; only the reshapes are lazy.
    """
    deprecation("Not tested anymore", since="0.3.0")

    if batch.ndim != 3:
        batch_size, num_element = batch.shape
        split_axis = 1
        target_shape = (batch_size,)
    else:
        batch_size, num_samples, num_element = batch.shape
        split_axis = 2
        target_shape = (batch_size, num_samples)

    columns = np.split(batch, num_element, axis=split_axis)
    return (column.reshape(target_shape) for column in columns)
예제 #6
0
def validation_resource_manager(config,
                                triple_source,
                                required_modes=('train_validate', 'test')):
  """Prepare validation resources when the configured mode needs them.

  This is a generator (presumably wrapped with ``contextlib.contextmanager``
  where it is used — confirm). If ``config.mode`` is in ``required_modes``,
  it builds a ``ParallelEvaluator`` on a 'spawn' multiprocessing context,
  starts it and yields it. The evaluator is always stopped afterwards, even
  if ``start()`` or the consumer raises. Otherwise it yields None.

  Args:
    config: Configuration object; ``config.mode`` is read.
    triple_source: Passed through to ``ParallelEvaluator``.
    required_modes: Modes that enable validation. The default is an immutable
      tuple so it can never be mutated and shared across calls.

  Yields:
    The started ``ParallelEvaluator``, or None when validation is disabled.
  """
  utils.deprecation("process manager not in use any more", since="0.5.0")
  enabled = config.mode in required_modes
  if enabled:
    ctx = mp.get_context('spawn')
    pool = ParallelEvaluator(config, triple_source, ctx)
    try:
      pool.start()
      yield pool
    finally:
      pool.stop()
  else:
    yield None
예제 #7
0
def _evaluate_predict_element(model, config, triple_index, num_expands,
                              element_type, rank_fn, ranks_list,
                              filtered_ranks_list):
  """Evaluate a single triple with expanded sets.

  The triple is expanded so that the element selected by ``element_type``
  takes every value in ``range(num_expands)``. The expanded batch is scored
  with ``model`` and ranked with ``rank_fn``. The raw and filtered ranks are
  appended to the given output lists.

  Args:
    model: Object with ``forward(batch)``. Its result must support ``.cpu()``,
      ``len()``, indexing and ``.data.numpy()`` (presumably a torch module —
      confirm).
    config: Forwarded to ``data.convert_triple_tuple_to_torch``.
    triple_index: The triple to evaluate; unpacked with
      ``kgekit.data.unpack``.
    num_expands: Number of candidates for the expanded element.
    element_type: Which element to expand (a TripleElement value).
    rank_fn: Callable ``(scores, triple_index) -> (rank, filtered_rank)``.
    ranks_list: Output list; ``rank`` is appended.
    filtered_ranks_list: Output list; ``filtered_rank`` is appended.
  """
  utils.deprecation("multiprocess validation is not in use any more", since="0.5.0")
  batch = data.expand_triple_to_sets(
      kgekit.data.unpack(triple_index), num_expands, element_type)
  batch = data.convert_triple_tuple_to_torch(batch, config)
  # Lazy %-style arguments: batch/tensor samples are only stringified when
  # DEBUG logging is enabled, instead of on every call.
  logging.debug("%s", element_type)
  logging.debug("Batch len: %s; batch sample: %s", len(batch), batch[0])
  predicted_batch = model.forward(batch).cpu()
  logging.debug("Predicted batch len: %s; batch sample: %s",
                len(predicted_batch), predicted_batch[0])
  rank, filtered_rank = rank_fn(predicted_batch.data.numpy(), triple_index)
  logging.debug("Rank: %s; Filtered rank: %s", rank, filtered_rank)
  ranks_list.append(rank)
  filtered_ranks_list.append(filtered_rank)
예제 #8
0
def sieve_and_expand_triple(triple_source, entities, relations, head, relation,
                            tail):
    """Tile a query triple over its unknown element (marked ``'?'``).

    ``head``, ``relation`` and ``tail`` are checked in that order. The first
    one equal to ``'?'`` is the element to predict, and the other two are
    looked up in ``entities``/``relations``.

    Args:
        triple_source: Provides ``num_entity`` and ``num_relation``, used as
            the number of candidates to tile over.
        entities: Indexable by entity name; the looked-up values are used as
            integer ids.
        relations: Indexable by relation name; the looked-up values are used
            as integer ids.
        head: Head entity name, or ``'?'``.
        relation: Relation name, or ``'?'``.
        tail: Tail entity name, or ``'?'``.

    Returns:
        ``((h, r, t), prediction_type, triple_index)``.
        - h, r, t: int64 arrays of equal length. The unknown slot holds
          ``arange`` over all candidates; the known ids are repeated.
        - prediction_type: ``constants.HEAD_KEY``, ``RELATION_KEY`` or
          ``TAIL_KEY``.
        - triple_index: a ``kgedata.TripleIndex`` with -1 in the unknown slot.

    Raises:
        RuntimeError: if none of the elements is ``'?'``.
    """
    deprecation("Not tested anymore", since="0.3.0")

    if head == '?':
        r = relations[relation]
        t = entities[tail]
        triple_index = kgedata.TripleIndex(-1, r, t)

        h = np.arange(triple_source.num_entity, dtype=np.int64)
        r = np.tile(np.array([r], dtype=np.int64), triple_source.num_entity)
        t = np.tile(np.array([t], dtype=np.int64), triple_source.num_entity)
        prediction_type = constants.HEAD_KEY
    elif relation == '?':
        h = entities[head]
        t = entities[tail]
        triple_index = kgedata.TripleIndex(h, -1, t)

        h = np.tile(np.array([h], dtype=np.int64), triple_source.num_relation)
        r = np.arange(triple_source.num_relation, dtype=np.int64)
        t = np.tile(np.array([t], dtype=np.int64), triple_source.num_relation)
        prediction_type = constants.RELATION_KEY
    elif tail == '?':
        r = relations[relation]
        h = entities[head]
        triple_index = kgedata.TripleIndex(h, r, -1)

        h = np.tile(np.array([h], dtype=np.int64), triple_source.num_entity)
        r = np.tile(np.array([r], dtype=np.int64), triple_source.num_entity)
        t = np.arange(triple_source.num_entity, dtype=np.int64)
        prediction_type = constants.TAIL_KEY
    else:
        raise RuntimeError("head, relation, tail are known.")

    return (h, r, t), prediction_type, triple_index
예제 #9
0
def evaulate_prediction_np_collate(model, triple_source, config, ranker,
                                   data_loader):
  """Rank every triple produced by ``data_loader`` (use with NumpyCollate).

  Head and tail predictions are evaluated when ``config.report_dimension``
  includes SEPERATE_ENTITY or COMBINED_ENTITY. Relation predictions are
  evaluated when it includes RELATION.

  Args:
    model: Model to evaluate; ``model.eval()`` is called first.
    triple_source: Provides ``num_entity`` and ``num_relation``.
    config: Configuration; ``report_dimension`` selects what is ranked.
    ranker: Provides ``rank_head``, ``rank_tail`` and ``rank_relation``.
    data_loader: Iterable of batches. Each batch is an iterable of triples,
      and each triple is indexed as ``triple[0, :]`` (presumably a (1, 3)
      numpy array — confirm).

  Returns:
    ``((head_ranks, filtered_head_ranks), (tail_ranks, filtered_tail_ranks),
    (relation_ranks, filtered_relation_ranks))``. The lists for dimensions
    that were not requested stay empty.
  """
  # Keyword ``since`` matches the other utils.deprecation calls in this module.
  utils.deprecation("multiprocess validation is not in use any more",
                    since="0.5.0")
  model.eval()

  head_ranks, filtered_head_ranks = [], []
  tail_ranks, filtered_tail_ranks = [], []
  relation_ranks, filtered_relation_ranks = [], []

  # These flags do not change inside the loop, so they are computed once.
  dimension = config.report_dimension
  report_entities = bool(
      (dimension & stats.StatisticsDimension.SEPERATE_ENTITY) or
      (dimension & stats.StatisticsDimension.COMBINED_ENTITY))
  report_relations = bool(dimension & stats.StatisticsDimension.RELATION)

  for sample_batched in data_loader:
    # Each item is one triple with shape (1, 3); it is tiled per prediction
    # type inside _evaluate_predict_element.
    for triple in sample_batched:
      triple_index = kgedata.TripleIndex(*triple[0, :])

      if report_entities:
        _evaluate_predict_element(model, config, triple_index,
                                  triple_source.num_entity,
                                  data.TripleElement.HEAD, ranker.rank_head,
                                  head_ranks, filtered_head_ranks)
        _evaluate_predict_element(model, config, triple_index,
                                  triple_source.num_entity,
                                  data.TripleElement.TAIL, ranker.rank_tail,
                                  tail_ranks, filtered_tail_ranks)
      if report_relations:
        _evaluate_predict_element(model, config, triple_index,
                                  triple_source.num_relation,
                                  data.TripleElement.RELATION,
                                  ranker.rank_relation, relation_ranks,
                                  filtered_relation_ranks)

  return ((head_ranks, filtered_head_ranks),
          (tail_ranks, filtered_tail_ranks),
          (relation_ranks, filtered_relation_ranks))
예제 #10
0
 def __init__(self, triple_order):
     """Validate the triple order and keep it on the instance.

     Args:
         triple_order: Checked with ``kgekit.utils.assert_triple_order``,
             then stored as ``self.triple_order``.
     """
     deprecation("Input is changed to numpy array.", since="0.3.0")
     order = triple_order
     kgekit.utils.assert_triple_order(order)
     self.triple_order = order
예제 #11
0
 def __init__(self, config):
     """Keep a reference to the configuration object.

     Args:
         config: Stored unchanged as ``self.config``.
     """
     deprecation("Not tested anymore", since="0.3.0")
     self.config = config
예제 #12
0
 def __init__(self, transform=None):
     """Remember an optional transform.

     Args:
         transform: Stored unchanged as ``self.transform``; may be None.
     """
     deprecation("Not tested anymore", since="0.3.0")
     self.transform = transform
예제 #13
0
 def __init__(self, cuda_enabled=False):
     """Record whether CUDA is enabled.

     Args:
         cuda_enabled: Stored unchanged as ``self.cuda_enabled``; defaults
             to False.
     """
     deprecation("Not tested anymore", since="0.3.0")
     self.cuda_enabled = cuda_enabled