def context_sampling(self):
    """Sample fixed-size head/tail contexts for every entity of every
    dataset and dump the four resulting tensor dicts to pickle files.

    For each entity with at least one head (resp. tail) context, draws
    `head_context_size` (resp. `tail_context_size`) ids via
    `sampled_id_generation` and stores the picked heads/relations/tails
    as (1, context_size) LongTensors.
    """
    datasets = zip(
        self.names, self.num_of_entities, self.entity_dicts,
        self.head_context_statistics, self.tail_context_statistics,
        self.head_context_heads, self.head_context_relations,
        self.tail_context_relations, self.tail_context_tails,
        self.entity_heads, self.entity_head_relations,
        self.entity_tail_relations, self.entity_tails)
    for (name, num_of_entity, entity_dict, head_stats, tail_stats,
         ctx_heads, ctx_head_rels, ctx_tail_rels, ctx_tails,
         out_heads, out_head_rels, out_tail_rels, out_tails) in datasets:
        for entity_id in range(num_of_entity):
            entity = entity_dict[entity_id]
            if head_stats[entity] > 0:
                picks = sampled_id_generation(
                    0, head_stats[entity], self.head_context_size)
                out_heads[entity] = torch.LongTensor(
                    [ctx_heads[entity][p] for p in picks]).unsqueeze(0)
                out_head_rels[entity] = torch.LongTensor(
                    [ctx_head_rels[entity][p] for p in picks]).unsqueeze(0)
            if tail_stats[entity] > 0:
                picks = sampled_id_generation(
                    0, tail_stats[entity], self.tail_context_size)
                out_tail_rels[entity] = torch.LongTensor(
                    [ctx_tail_rels[entity][p] for p in picks]).unsqueeze(0)
                out_tails[entity] = torch.LongTensor(
                    [ctx_tails[entity][p] for p in picks]).unsqueeze(0)
        # One dump per sampled-context dict, mirroring the input order.
        for payload, suffix in (
                (out_heads, "_context_head.pickle"),
                (out_head_rels, "_context_head_relation.pickle"),
                (out_tail_rels, "_context_tail_relation.pickle"),
                (out_tails, "_context_tail.pickle")):
            dump_data(payload, self.output_path + name + suffix,
                      self.log_path, "")
 def tail_context_process(self, tail_batch):
     tail_relation = torch.LongTensor(len(tail_batch), self.tail_context_size)
     tail_tail = torch.LongTensor(len(tail_batch), self.tail_context_size)
     for index in range(len(tail_batch)):
         entity = tail_batch[index]
         relations = self.tail_context_relation[entity]
         tails = self.tail_context_tail[entity]
         num_of_tail_context = self.tail_context_statistics[entity]
         sampled_ids = sampled_id_generation(0, num_of_tail_context, self.tail_context_size)
         tail_relation[index] = torch.LongTensor([relations[_] for _ in sampled_ids])
         tail_tail[index] = torch.LongTensor([tails[_] for _ in sampled_ids])
     return tail_relation, tail_tail
 def head_context_process(self, head_batch):
     head_head = torch.LongTensor(len(head_batch), self.head_context_size)
     head_relation = torch.LongTensor(len(head_batch), self.head_context_size)
     for index in range(len(head_batch)):
         entity = head_batch[index]
         heads = self.head_context_head[entity]
         relations = self.head_context_relation[entity]
         num_of_head_context = self.head_context_statistics[entity]
         sampled_ids = sampled_id_generation(0, num_of_head_context, self.head_context_size)
         head_head[index] = torch.LongTensor([heads[_] for _ in sampled_ids])
         head_relation[index] = torch.LongTensor([relations[_] for _ in sampled_ids])
     return head_head, head_relation
 def negative_sampling(self):
     for index in range(len(self.names)):
         name = self.names[index]
         num_of_entity = self.num_of_entities[index]
         entity_dict = self.entity_dicts[index]
         negative = self.negatives[index]
         for entity_id in range(num_of_entity):
             entity = entity_dict[entity_id]
             negative_entities = []
             sampled_entities = {}
             sampled_entity_count = 0
             while len(
                     negative_entities
             ) < self.negative_batch_size and sampled_entity_count < num_of_entity:
                 sampled_entity = entity_dict[sampled_id_generation(
                     0, num_of_entity, 1)[0]]
                 while sampled_entity in sampled_entities:
                     sampled_entity = entity_dict[sampled_id_generation(
                         0, num_of_entity, 1)[0]]
                 sampled_entities[sampled_entity] = None
                 sampled_entity_count += 1
                 if self.negative_or_not(entity, sampled_entity):
                     negative_entities.append(sampled_entity)
             if len(negative_entities) == 0:
                 sampled_ids = sampled_id_generation(
                     0, num_of_entity, self.negative_batch_size)
                 for sampled_id in sampled_ids:
                     negative_entities.append(entity_dict[sampled_id])
             if len(negative_entities) < self.negative_batch_size:
                 sampled_ids = sampled_id_generation(
                     0, len(negative_entities),
                     self.negative_batch_size - len(negative_entities))
                 for sampled_id in sampled_ids:
                     negative_entities.append(negative_entities[sampled_id])
             negative[entity] = torch.unsqueeze(
                 torch.LongTensor(negative_entities), 0)
         dump_data(negative,
                   self.output_path + "%s_negatives.pickle" % name,
                   self.log_path, "")
 def negative_batch_generation(self, positive_batch):
     negative_batch = torch.LongTensor(len(positive_batch), self.negative_batch_size)
     for index in range(negative_batch.size()[0]):
         entity = positive_batch[index]
         negative_entities = []
         sampled_entities = {}
         sampled_entity_count = 0
         while len(negative_entities) < self.negative_batch_size and sampled_entity_count < self.num_of_train_entities:
             sampled_entity_id = sampled_id_generation(0, self.num_of_train_entities, 1)[0]
             while sampled_entity_id in sampled_entities:
                 sampled_entity_id = sampled_id_generation(0, self.num_of_train_entities, 1)[0]
             sampled_entities[sampled_entity_id] = None
             sampled_entity_count += 1
             if self.negative_or_not(entity, self.train_entities[sampled_entity_id]):
                 negative_entities.append(self.train_entities[sampled_entity_id])
         if len(negative_entities) == 0:
             negative_entities = [self.train_entities[tmp_id] for tmp_id in sampled_id_generation(0, self.num_of_train_entities, self.negative_batch_size)]
         if len(negative_entities) < self.negative_batch_size:
             sampled_indices = sampled_id_generation(0, len(negative_entities), self.negative_batch_size - len(negative_entities))
             for sampled_index in sampled_indices:
                 negative_entities.append(negative_entities[sampled_index])
         negative_batch[index] = torch.FloatTensor(negative_entities)
     return negative_batch
# NOTE(review): these imports sit at the bottom of the file — they should
# live at the top per PEP 8; kept in place here to avoid reordering code
# outside this block.
import pickle
import torch

from tools.uniform_sampling import sampled_id_generation

# Smoke test: draw a single id from the half-open range [0, 3).
ids = sampled_id_generation(0, 3, 1)
# BUG FIX: `print ids[0]` is Python 2 statement syntax and a SyntaxError
# under Python 3; use the print() function.
print(ids[0])