def output(self):
    """Dump the current entity and relation embedding weights.

    Each weight tensor is moved to the CPU and passed to ``dump_data``
    together with a target path under ``self.result_path``, the log path
    and a label string.
    """
    entity_weights = self.entity_embeddings.weight.data.to("cpu")
    entity_target = self.result_path + "entity_embeddings.pickle"
    dump_data(entity_weights, entity_target, self.log_path,
              "self.entity_embeddings.weight.data")

    relation_weights = self.relation_embeddings.weight.data.to("cpu")
    relation_target = self.result_path + "relation_embeddings.pickle"
    dump_data(relation_weights, relation_target, self.log_path,
              "self.relation_embeddings.weight.data")
def entity_classification(self):
    """Classify the entities of every split and dump the results.

    For each entity id below ``self.statistics["num_of_entities"]`` and each
    split in ``self.names``:

    * if the entity is a key of the split's head->relation->tails or
      tail->relation->heads mapping, it is stored in the split's entity
      table under a running counter;
    * it is then registered (as a key mapped to ``None``) in exactly one of
      the split's "head", "tail" or "both" tables, depending on which of
      its head/tail context counts are positive.

    The per-split entity counts are written into ``self.statistics`` and
    every table, plus the statistics, is handed to ``dump_data``.
    """
    entities_by_split = [
        self.train_entities, self.validate_entities, self.test_entities
    ]
    head_only_by_split = [
        self.train_head_entities, self.validate_head_entities,
        self.test_head_entities
    ]
    tail_only_by_split = [
        self.train_tail_entities, self.validate_tail_entities,
        self.test_tail_entities
    ]
    head_and_tail_by_split = [
        self.train_both_entities, self.validate_both_entities,
        self.test_both_entities
    ]
    found = [0, 0, 0]

    for entity in range(self.statistics["num_of_entities"]):
        for split in range(len(self.names)):
            present = (entity in self.head_relation_to_tails[split]
                       or entity in self.tail_relation_to_heads[split])
            if not present:
                continue

            # Entities are stored under consecutive positions 0, 1, 2, ...
            entities_by_split[split][found[split]] = entity
            found[split] += 1

            head_total = self.head_context_statistics[split][entity]
            tail_total = self.tail_context_statistics[split][entity]
            # The three cases are mutually exclusive.
            if head_total > 0 and tail_total == 0:
                head_only_by_split[split][entity] = None
            elif head_total == 0 and tail_total > 0:
                tail_only_by_split[split][entity] = None
            elif head_total > 0 and tail_total > 0:
                head_and_tail_by_split[split][entity] = None

    self.statistics["num_of_train_entities"] = found[0]
    self.statistics["num_of_validate_entities"] = found[1]
    self.statistics["num_of_test_entities"] = found[2]

    for split, split_name in enumerate(self.names):
        dump_data(entities_by_split[split],
                  self.output_path + "%s_entities.pickle" % split_name,
                  self.log_path, "")
        dump_data(head_only_by_split[split],
                  self.output_path + "%s_head_entities.pickle" % split_name,
                  self.log_path, "")
        dump_data(tail_only_by_split[split],
                  self.output_path + "%s_tail_entities.pickle" % split_name,
                  self.log_path, "")
        dump_data(head_and_tail_by_split[split],
                  self.output_path + "%s_both_entities.pickle" % split_name,
                  self.log_path, "")
    dump_data(self.statistics, self.output_path + "statistics.pickle",
              self.log_path, "")
def train_triple_tensor_generation(self):
    """Pack the training triples into a ``(num_of_train_triples, 3)`` tensor.

    Column 0 holds the head ids, column 1 the relation ids and column 2 the
    tail ids, all read from ``self.id_train_triples``.  The tensor is created
    with ``torch.zeros`` (default dtype) and then passed to ``dump_data``.
    """
    triple_total = self.num_of_train_triples
    column_keys = ("id_heads", "id_relations", "id_tails")

    train_triple_tensor = torch.zeros(triple_total, 3)
    for row in range(triple_total):
        for column, key in enumerate(column_keys):
            train_triple_tensor[row][column] = self.id_train_triples[key][row]

    target = self.output_path + "train_triple_tensor.pickle"
    dump_data(train_triple_tensor, target, self.log_path,
              "train_triple_tensor")
def head_relation_to_tail_and_reverse(self):
    """Index the id triples of every split in both directions and dump them.

    For each of the train / valid / test splits two nested mappings are
    filled in place:

    * ``head -> relation -> [tails]``
    * ``tail -> relation -> [heads]``

    Both mappings are then passed to ``dump_data`` under a split-specific
    file name.
    """
    splits = (
        ("train", self.num_of_train_triples, self.id_train_triples,
         self.train_head_relation_to_tail, self.train_tail_relation_to_head),
        ("valid", self.num_of_validate_triples, self.id_validate_triples,
         self.validate_head_relation_to_tail,
         self.validate_tail_relation_to_head),
        ("test", self.num_of_test_triples, self.id_test_triples,
         self.test_head_relation_to_tail, self.test_tail_relation_to_head),
    )

    for name, triple_total, id_triple, forward, backward in splits:
        for position in range(triple_total):
            head = id_triple["id_heads"][position]
            relation = id_triple["id_relations"][position]
            tail = id_triple["id_tails"][position]

            # Create the missing levels on demand, then record the triple.
            forward.setdefault(head, {}).setdefault(relation, []).append(tail)
            backward.setdefault(tail, {}).setdefault(relation, []).append(head)

        dump_data(forward,
                  self.output_path + "%s_head_relation_to_tail.pickle" % name,
                  self.log_path, "head_relation_to_tail")
        dump_data(backward,
                  self.output_path + "%s_tail_relation_to_head.pickle" % name,
                  self.log_path, "tail_relation_to_head")
def context_sampling(self):
    """Sample head and tail contexts for every entity of every split.

    For each split in ``self.names`` and each entity listed in the split's
    entity dictionary:

    * if the entity has head contexts, ids are drawn with
      ``sampled_id_generation(0, count, self.head_context_size)`` and the
      selected heads / relations are stored as ``LongTensor`` rows with a
      leading dimension of 1;
    * the same is done for tail contexts with ``self.tail_context_size``.

    The four resulting tables of the split are passed to ``dump_data``.
    """
    for index, name in enumerate(self.names):
        entity_total = self.num_of_entities[index]
        entity_dict = self.entity_dicts[index]
        head_counts = self.head_context_statistics[index]
        tail_counts = self.tail_context_statistics[index]
        known_heads = self.head_context_heads[index]
        known_head_relations = self.head_context_relations[index]
        known_tail_relations = self.tail_context_relations[index]
        known_tails = self.tail_context_tails[index]
        sampled_heads = self.entity_heads[index]
        sampled_head_relations = self.entity_head_relations[index]
        sampled_tail_relations = self.entity_tail_relations[index]
        sampled_tails = self.entity_tails[index]

        for entity_id in range(entity_total):
            entity = entity_dict[entity_id]
            head_total = head_counts[entity]
            tail_total = tail_counts[entity]

            if head_total > 0:
                heads = known_heads[entity]
                relations = known_head_relations[entity]
                picks = sampled_id_generation(0, head_total,
                                              self.head_context_size)
                sampled_heads[entity] = torch.LongTensor(
                    [heads[pick] for pick in picks]).unsqueeze(0)
                sampled_head_relations[entity] = torch.LongTensor(
                    [relations[pick] for pick in picks]).unsqueeze(0)

            if tail_total > 0:
                relations = known_tail_relations[entity]
                tails = known_tails[entity]
                picks = sampled_id_generation(0, tail_total,
                                              self.tail_context_size)
                sampled_tail_relations[entity] = torch.LongTensor(
                    [relations[pick] for pick in picks]).unsqueeze(0)
                sampled_tails[entity] = torch.LongTensor(
                    [tails[pick] for pick in picks]).unsqueeze(0)

        dump_data(sampled_heads,
                  self.output_path + name + "_context_head.pickle",
                  self.log_path, "")
        dump_data(sampled_head_relations,
                  self.output_path + name + "_context_head_relation.pickle",
                  self.log_path, "")
        dump_data(sampled_tail_relations,
                  self.output_path + name + "_context_tail_relation.pickle",
                  self.log_path, "")
        dump_data(sampled_tails,
                  self.output_path + name + "_context_tail.pickle",
                  self.log_path, "")
def negative_sampling(self):
    """Sample negative entities for every entity of every split and dump them.

    For each index into ``self.names`` and each entity listed in
    ``self.entity_dicts[index]``, a list of negative entities is collected
    and stored in ``self.negatives[index][entity]`` as a ``LongTensor`` with
    a leading dimension of 1.  The per-split table is then passed to
    ``dump_data``.

    Candidates are drawn through ``sampled_id_generation`` and accepted when
    ``self.negative_or_not(entity, candidate)`` is truthy.
    """
    for index in range(len(self.names)):
        name = self.names[index]
        num_of_entity = self.num_of_entities[index]
        entity_dict = self.entity_dicts[index]
        negative = self.negatives[index]
        for entity_id in range(num_of_entity):
            entity = entity_dict[entity_id]
            negative_entities = []
            # Candidates already tried for this entity (dict used as a set).
            sampled_entities = {}
            sampled_entity_count = 0
            # Keep drawing until enough negatives were found or every entity
            # of the split has been tried once.
            while len(
                    negative_entities
            ) < self.negative_batch_size and sampled_entity_count < num_of_entity:
                # presumably sampled_id_generation returns a sequence of ids
                # in [0, num_of_entity) -- TODO confirm against its definition
                sampled_entity = entity_dict[sampled_id_generation(
                    0, num_of_entity, 1)[0]]
                # Re-draw until an untried candidate comes up.
                while sampled_entity in sampled_entities:
                    sampled_entity = entity_dict[sampled_id_generation(
                        0, num_of_entity, 1)[0]]
                sampled_entities[sampled_entity] = None
                sampled_entity_count += 1
                if self.negative_or_not(entity, sampled_entity):
                    negative_entities.append(sampled_entity)
            # No candidate was accepted: fall back to entities drawn without
            # applying the negative_or_not filter.
            if len(negative_entities) == 0:
                sampled_ids = sampled_id_generation(
                    0, num_of_entity, self.negative_batch_size)
                for sampled_id in sampled_ids:
                    negative_entities.append(entity_dict[sampled_id])
            # Too few negatives: pad by repeating entries that were already
            # collected (ids index into negative_entities itself).
            if len(negative_entities) < self.negative_batch_size:
                sampled_ids = sampled_id_generation(
                    0, len(negative_entities),
                    self.negative_batch_size - len(negative_entities))
                for sampled_id in sampled_ids:
                    negative_entities.append(negative_entities[sampled_id])
            negative[entity] = torch.unsqueeze(
                torch.LongTensor(negative_entities), 0)
        dump_data(negative, self.output_path + "%s_negatives.pickle" % name,
                  self.log_path, "")
def read_dataset(self):
    """Read the train / valid / test triple files and dump their contents.

    For each split the file ``<input_path><name>.txt`` is read line by line.
    Every line is split on whitespace; the first three fields are taken as
    head, relation and tail.  The raw strings are appended to the split's
    string-triple lists and the ids returned by ``entity_id_generation`` /
    ``relation_id_generation`` to its id-triple lists.

    Reading a file stops at end of file or at the first line that consists
    only of a line terminator.

    After each split its string and id triples are dumped; after all splits
    ``self.entity2id`` and ``self.relation2id`` are dumped and the per-split
    triple counts are stored in ``self.num_of_train_triples``,
    ``self.num_of_validate_triples`` and ``self.num_of_test_triples``.

    Raises:
        IndexError: if a line has fewer than three whitespace-separated
            fields.  The message names the file and the line number.
    """
    names = ["train", "valid", "test"]
    string_triples = [
        self.string_train_triples, self.string_validate_triples,
        self.string_test_triples
    ]
    id_triples = [
        self.id_train_triples, self.id_validate_triples, self.id_test_triples
    ]
    num_of_triples = [0, 0, 0]
    for index, name in enumerate(names):
        string_triple = string_triples[index]
        id_triple = id_triples[index]
        file_path = self.input_path + name + ".txt"
        log_text(self.log_path, "reading file %s" % file_path)
        with open(file_path) as data_reader:
            for line_number, tmp_line in enumerate(data_reader, 1):
                # A bare line terminator marks the end of the data.
                if tmp_line in ("\n", "\r\n", "\r"):
                    break
                fields = tmp_line.split()
                if len(fields) < 3:
                    # Same exception type as an out-of-range field access,
                    # but with enough context to locate the bad line.
                    raise IndexError(
                        "malformed triple in %s at line %d: %r" %
                        (file_path, line_number, tmp_line))
                tmp_head, tmp_relation, tmp_tail = fields[0], fields[1], fields[2]
                string_triple["heads"].append(tmp_head)
                string_triple["relations"].append(tmp_relation)
                string_triple["tails"].append(tmp_tail)
                # Keep the head -> relation -> tail order of id generation.
                id_triple["id_heads"].append(
                    self.entity_id_generation(tmp_head))
                id_triple["id_relations"].append(
                    self.relation_id_generation(tmp_relation))
                id_triple["id_tails"].append(
                    self.entity_id_generation(tmp_tail))
                num_of_triples[index] += 1
        dump_data(string_triple,
                  self.output_path + "string_%s_triples.pickle" % name,
                  self.log_path, "string_%s_triples" % name)
        dump_data(id_triple,
                  self.output_path + "id_%s_triples.pickle" % name,
                  self.log_path, "id_%s_triples" % name)
    dump_data(self.entity2id, self.output_path + "entity2id.pickle",
              self.log_path, "self.entity2id")
    dump_data(self.relation2id, self.output_path + "relation2id.pickle",
              self.log_path, "self.relation2id")
    self.num_of_train_triples = num_of_triples[0]
    self.num_of_validate_triples = num_of_triples[1]
    self.num_of_test_triples = num_of_triples[2]
def statistics(self):
    """Log the dataset sizes and dump them as a statistics dictionary.

    The triple counts of the three splits and the entity / relation counts
    are logged one per line.  The dumped dictionary contains those five
    values plus three per-split entity counts initialised to ``None``.
    """
    measured = [
        ("num_of_train_triples", "number of train triples: %d",
         self.num_of_train_triples),
        ("num_of_validate_triples", "number of validate triples: %d",
         self.num_of_validate_triples),
        ("num_of_test_triples", "number of test triples: %d",
         self.num_of_test_triples),
        ("num_of_entities", "number of entities: %d", self.num_of_entities),
        ("num_of_relations", "number of relations: %d",
         self.num_of_relations),
    ]
    for _, template, value in measured:
        log_text(self.log_path, template % value)

    summary = {}
    for key, _, value in measured:
        summary[key] = value
    # Per-split entity counts are not known at this point.
    summary["num_of_train_entities"] = None
    summary["num_of_validate_entities"] = None
    summary["num_of_test_entities"] = None

    dump_data(summary, self.output_path + "statistics.pickle", self.log_path,
              "statistics")
def train(self):
    """Train the model, using the validation loss for early stopping.

    Workflow:

    1. Build the ``Model`` (optionally restoring it via ``model.input()``
       when ``self.continue_learning`` is set), move it to ``self.device``
       and create an Adam optimizer.
    2. Validate once and remember the embeddings as the current optimum.
    3. For every epoch: optionally re-sample contexts/negatives, run all
       batches of training entities, log the epoch loss, and every
       ``self.validation_freq`` epochs validate again.
    4. On an improved validation loss the optimal embeddings are updated.
       Otherwise a patience counter grows; when it reaches ``self.patience``
       the learning rate is halved, the optimal embeddings are restored,
       the optimizer is rebuilt and ``self.early_stop_patience`` is
       decremented -- unless ``self.early_stop_patience`` is 1, in which
       case the optimal embeddings are dumped and training stops.
    5. Every ``self.output_freq`` epochs the model and the optimal
       embeddings are dumped.
    6. Finally the test loss is printed.

    The epoch loss is accumulated as a plain Python float so that the
    autograd history of the individual batches is not kept alive for the
    whole epoch.
    """
    model = Model(self.result_path, self.log_path, self.entity_dimension,
                  self.relation_dimension, self.num_of_entities,
                  self.num_of_relations, self.norm, self.device)
    if self.continue_learning:
        model.input()
    model.to(self.device)
    optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)
    PrintGPUStatus.print_gpu_status("after the initialization of model")
    self.offline_batch_retrieve = OfflineBatchRetrieve(
        self.names, self.dataset)

    current_validate_loss = self.validate(model)
    log_text(self.log_path,
             "initial loss (validation): %f" % current_validate_loss)
    optimal_validate_loss = current_validate_loss
    self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone()
    self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone()

    entity_set = MyDataset(self.num_of_train_entities)
    entity_loader = DataLoader(entity_set, self.batch_size, True)
    patience_count = 0
    for epoch in range(self.num_of_epochs):
        epoch_loss = 0.
        if epoch != 0 and epoch % self.re_sampling_freq == 0:
            self.context_and_negatives.re_sampling()
            self.offline_batch_retrieve.re_read_context_and_negatives()
        for entity_id_batch in entity_loader:
            model.normalize()
            optimizer.zero_grad()
            entity_batch = [
                self.train_entities[entity_id.item()]
                for entity_id in entity_id_batch
            ]
            head_batch, tail_batch, both_batch = \
                self.offline_batch_retrieve.batch_classification(
                    "train", entity_batch)
            batch_loss = self.loss_compute("train", model, head_batch,
                                           tail_batch, both_batch)
            batch_loss.backward()
            optimizer.step()
            # float() detaches the value from the graph; adding the tensor
            # itself would chain every batch's history onto epoch_loss.
            epoch_loss += float(batch_loss)
        log_text(self.log_path,
                 "\r\nepoch " + str(epoch) + ": , loss: " + str(epoch_loss))

        if epoch % self.validation_freq == 0:
            current_validate_loss = self.validate(model)
            if current_validate_loss < optimal_validate_loss:
                log_text(
                    self.log_path,
                    "optimal validate loss: " + str(optimal_validate_loss) +
                    " -> " + str(current_validate_loss))
                patience_count = 0
                optimal_validate_loss = current_validate_loss
                self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone()
                self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone()
            else:
                patience_count += 1
                log_text(
                    self.log_path,
                    "early stop patience: " + str(self.early_stop_patience) +
                    ", patience count: " + str(patience_count) +
                    ", current validate loss: " + str(current_validate_loss) +
                    ", optimal validate loss: " + str(optimal_validate_loss))
                if patience_count == self.patience:
                    if self.early_stop_patience == 1:
                        # Last chance used up: persist the optimum and stop.
                        dump_data(
                            self.optimal_entity_embeddings.to("cpu"),
                            self.result_path + "optimal_entity_embedding.pickle",
                            self.log_path, "self.optimal_entity_embeddings")
                        dump_data(
                            self.optimal_relation_embeddings.to("cpu"),
                            self.result_path + "optimal_relation_embedding.pickle",
                            self.log_path, "self.optimal_relation_embeddings")
                        break
                    log_text(
                        self.log_path,
                        "learning rate: " + str(self.learning_rate) + " -> " +
                        str(self.learning_rate / 2))
                    self.learning_rate = self.learning_rate / 2
                    # Restart from the best embeddings with a fresh optimizer.
                    model.entity_embeddings.weight.data = self.optimal_entity_embeddings.clone()
                    model.relation_embeddings.weight.data = self.optimal_relation_embeddings.clone()
                    optimizer = torch.optim.Adam(model.parameters(),
                                                 lr=self.learning_rate)
                    patience_count = 0
                    self.early_stop_patience -= 1

        if (epoch + 1) % self.output_freq == 0:
            model.output()
            dump_data(self.optimal_entity_embeddings.to("cpu"),
                      self.result_path + "optimal_entity_embedding.pickle",
                      self.log_path, "self.optimal_entity_embeddings")
            dump_data(self.optimal_relation_embeddings.to("cpu"),
                      self.result_path + "optimal_relation_embedding.pickle",
                      self.log_path, "self.optimal_relation_embeddings")
    # Parenthesised single argument: prints identically on Python 2 and 3.
    print("test loss: %f" % self.test(model))
def context_process(self):
    """Derive per-entity head and tail contexts for every split and dump them.

    For each split and each entity id below ``self.num_of_entities``:

    * the head context enumerates every ``(head, relation)`` found under the
      entity in the split's tail->relation->heads mapping; heads and
      relations are stored in two dictionaries keyed by a running position,
      and the number of pairs is stored as the head-context statistic;
    * the tail context does the same with ``(relation, tail)`` pairs from
      the head->relation->tails mapping.

    Entities without any context get empty dictionaries and a count of 0.
    All six tables of a split are passed to ``dump_data``.
    """
    splits = (
        ("train", self.train_head_relation_to_tail,
         self.train_tail_relation_to_head, self.train_head_context_head,
         self.train_head_context_relation,
         self.train_head_context_statistics,
         self.train_tail_context_relation, self.train_tail_context_tail,
         self.train_tail_context_statistics),
        ("valid", self.validate_head_relation_to_tail,
         self.validate_tail_relation_to_head, self.validate_head_context_head,
         self.validate_head_context_relation,
         self.validate_head_context_statistics,
         self.validate_tail_context_relation, self.validate_tail_context_tail,
         self.validate_tail_context_statistics),
        ("test", self.test_head_relation_to_tail,
         self.test_tail_relation_to_head, self.test_head_context_head,
         self.test_head_context_relation, self.test_head_context_statistics,
         self.test_tail_context_relation, self.test_tail_context_tail,
         self.test_tail_context_statistics),
    )

    for (name, head_relation_to_tail, tail_relation_to_head,
         head_context_head, head_context_relation, head_context_statistics,
         tail_context_relation, tail_context_tail,
         tail_context_statistics) in splits:
        for entity in range(self.num_of_entities):
            # Head context: triples in which the entity is the tail.
            heads_by_slot = {}
            head_relations_by_slot = {}
            head_context_head[entity] = heads_by_slot
            head_context_relation[entity] = head_relations_by_slot
            if entity in tail_relation_to_head:
                for relation, heads in tail_relation_to_head[entity].items():
                    for head in heads:
                        slot = len(heads_by_slot)
                        heads_by_slot[slot] = head
                        head_relations_by_slot[slot] = relation
            head_context_statistics[entity] = len(heads_by_slot)

            # Tail context: triples in which the entity is the head.
            tail_relations_by_slot = {}
            tails_by_slot = {}
            tail_context_relation[entity] = tail_relations_by_slot
            tail_context_tail[entity] = tails_by_slot
            if entity in head_relation_to_tail:
                for relation, tails in head_relation_to_tail[entity].items():
                    for tail in tails:
                        slot = len(tails_by_slot)
                        tail_relations_by_slot[slot] = relation
                        tails_by_slot[slot] = tail
            tail_context_statistics[entity] = len(tails_by_slot)

        dump_data(head_context_head,
                  self.output_path + "%s_head_context_head.pickle" % name,
                  self.log_path, "head_context_head")
        dump_data(head_context_relation,
                  self.output_path + "%s_head_context_relation.pickle" % name,
                  self.log_path, "head_context_relation")
        dump_data(head_context_statistics,
                  self.output_path + "%s_head_context_statistics.pickle" % name,
                  self.log_path, "head_context_statistics")
        dump_data(tail_context_relation,
                  self.output_path + "%s_tail_context_relation.pickle" % name,
                  self.log_path, "tail_context_relation")
        dump_data(tail_context_tail,
                  self.output_path + "%s_tail_context_tail.pickle" % name,
                  self.log_path, "tail_context_tail")
        dump_data(tail_context_statistics,
                  self.output_path + "%s_tail_context_statistics.pickle" % name,
                  self.log_path, "tail_context_statistics")
def train(self):
    """Train the model, using the validation mean rank for early stopping.

    Contexts and negatives are fetched from ``OfflineBatchRetrieve``.  Each
    batch of training entities is split into head-only, tail-only and
    head-and-tail entities; the loss of each non-empty group is the negated
    model output, and the three losses are summed before back-propagation.

    Every ``self.validation_freq`` epochs the model is validated.  A lower
    mean rank updates the stored optimal embeddings; otherwise a patience
    counter grows.  When it reaches ``self.patience`` the learning rate is
    halved, the optimal embeddings are restored, the optimizer is rebuilt
    and ``self.early_stop_patience`` is decremented -- unless
    ``self.early_stop_patience`` is 1, in which case the optimal embeddings
    are dumped and training stops.  Every ``self.output_freq`` epochs the
    model and the optimal embeddings are dumped.  ``self.test(model)`` runs
    at the end.

    The epoch loss is accumulated as a plain Python float so that the
    autograd history of the individual batches is not kept alive for the
    whole epoch.
    """
    model = Model(self.result_path, self.log_path, self.entity_dimension,
                  self.relation_dimension, self.num_of_entities,
                  self.num_of_relations, self.norm, self.device)
    if self.continue_learning:
        model.input()
    model.to(self.device)
    optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)
    PrintGPUStatus.print_gpu_status("after the initialization of model")
    self.offline_batch_retrieve = OfflineBatchRetrieve(self.names, self.dataset)

    current_mean_rank = self.validate(model)
    log_text(self.log_path,
             "initial mean rank (validation): %f" % current_mean_rank)
    optimal_mean_rank = current_mean_rank
    self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone()
    self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone()

    entity_set = MyDataset(self.num_of_train_entities)
    entity_loader = DataLoader(entity_set, self.batch_size, True)
    patience_count = 0
    for epoch in range(self.num_of_epochs):
        epoch_loss = 0.
        if epoch != 0 and epoch % self.re_sampling_freq == 0:
            self.context_and_negatives.re_sampling()
            self.offline_batch_retrieve.re_read_context_and_negatives()
        for entity_id_batch in entity_loader:
            model.normalize()
            optimizer.zero_grad()
            entity_batch = [self.train_entities[entity_id.item()]
                            for entity_id in entity_id_batch]
            head_loss, tail_loss, both_loss, batch_loss = 0., 0., 0., 0.
            head_batch, tail_batch, both_batch = \
                self.offline_batch_retrieve.batch_classification(
                    "train", entity_batch)
            if len(head_batch) > 0:
                # Entities that only have head context.
                head_head, head_relation = \
                    self.offline_batch_retrieve.head_context_retrieve(
                        "train", head_batch)
                negative_head_batch = \
                    self.offline_batch_retrieve.negative_retrieves(
                        "train", head_batch)
                head_batch = torch.LongTensor(head_batch)
                head_loss = -1. * model(head_batch.to(self.device),
                                        head_head.to(self.device),
                                        head_relation.to(self.device),
                                        None, None,
                                        negative_head_batch.to(self.device))
            if len(tail_batch) > 0:
                # Entities that only have tail context.
                tail_relation, tail_tail = \
                    self.offline_batch_retrieve.tail_context_retrieve(
                        "train", tail_batch)
                negative_tail_batch = \
                    self.offline_batch_retrieve.negative_retrieves(
                        "train", tail_batch)
                tail_batch = torch.LongTensor(tail_batch)
                tail_loss = -1. * model(tail_batch.to(self.device),
                                        None, None,
                                        tail_relation.to(self.device),
                                        tail_tail.to(self.device),
                                        negative_tail_batch.to(self.device))
            if len(both_batch) > 0:
                # Entities that have both head and tail context.
                both_head, both_head_relation = \
                    self.offline_batch_retrieve.head_context_retrieve(
                        "train", both_batch)
                both_tail_relation, both_tail = \
                    self.offline_batch_retrieve.tail_context_retrieve(
                        "train", both_batch)
                negative_both_batch = \
                    self.offline_batch_retrieve.negative_retrieves(
                        "train", both_batch)
                both_batch = torch.LongTensor(both_batch)
                both_loss = -1. * model(both_batch.to(self.device),
                                        both_head.to(self.device),
                                        both_head_relation.to(self.device),
                                        both_tail_relation.to(self.device),
                                        both_tail.to(self.device),
                                        negative_both_batch.to(self.device))
            batch_loss += head_loss + tail_loss + both_loss
            batch_loss.backward()
            optimizer.step()
            # float() detaches the value from the graph; adding the tensor
            # itself would chain every batch's history onto epoch_loss.
            epoch_loss += float(batch_loss)
        log_text(self.log_path,
                 "\r\nepoch " + str(epoch) + ": , loss: " + str(epoch_loss))

        if epoch % self.validation_freq == 0:
            current_mean_rank = self.validate(model)
            if current_mean_rank < optimal_mean_rank:
                log_text(self.log_path,
                         "optimal average raw mean rank: " +
                         str(optimal_mean_rank) + " -> " +
                         str(current_mean_rank))
                patience_count = 0
                optimal_mean_rank = current_mean_rank
                self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone()
                self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone()
            else:
                patience_count += 1
                log_text(self.log_path,
                         "early stop patience: " +
                         str(self.early_stop_patience) +
                         ", patience count: " + str(patience_count) +
                         ", current rank: " + str(current_mean_rank) +
                         ", best rank: " + str(optimal_mean_rank))
                if patience_count == self.patience:
                    if self.early_stop_patience == 1:
                        # Last chance used up: persist the optimum and stop.
                        dump_data(self.optimal_entity_embeddings.to("cpu"),
                                  self.result_path + "optimal_entity_embedding.pickle",
                                  self.log_path,
                                  "self.optimal_entity_embeddings")
                        dump_data(self.optimal_relation_embeddings.to("cpu"),
                                  self.result_path + "optimal_relation_embedding.pickle",
                                  self.log_path,
                                  "self.optimal_relation_embeddings")
                        break
                    log_text(self.log_path,
                             "learning rate: " + str(self.learning_rate) +
                             " -> " + str(self.learning_rate / 2))
                    self.learning_rate = self.learning_rate / 2
                    # Restart from the best embeddings with a fresh optimizer.
                    model.entity_embeddings.weight.data = self.optimal_entity_embeddings.clone()
                    model.relation_embeddings.weight.data = self.optimal_relation_embeddings.clone()
                    optimizer = torch.optim.Adam(model.parameters(),
                                                 lr=self.learning_rate)
                    patience_count = 0
                    self.early_stop_patience -= 1

        if (epoch + 1) % self.output_freq == 0:
            model.output()
            dump_data(self.optimal_entity_embeddings.to("cpu"),
                      self.result_path + "optimal_entity_embedding.pickle",
                      self.log_path, "self.optimal_entity_embeddings")
            dump_data(self.optimal_relation_embeddings.to("cpu"),
                      self.result_path + "optimal_relation_embedding.pickle",
                      self.log_path, "self.optimal_relation_embeddings")
    self.test(model)
def train(self):
    """Train the model with on-the-fly batches and mean-rank early stopping.

    Contexts and negatives are produced per batch by a ``BatchProcess``
    instance.  Each batch of training entities is split into head-only,
    tail-only and head-and-tail entities; the loss of each non-empty group
    is the negated model output, and the three losses are summed before
    back-propagation.  A progress line is printed every 200 batches.

    The model is validated after every epoch.  A lower mean rank updates
    the stored optimal embeddings; otherwise a patience counter grows.
    When it reaches ``self.patience`` the learning rate is halved, the
    optimal embeddings are restored, the optimizer is rebuilt and
    ``self.early_stop_patience`` is decremented -- unless
    ``self.early_stop_patience`` is 1, in which case the optimal embeddings
    are dumped and training stops.  Whenever ``epoch % self.output_freq``
    is 0 the model and the optimal embeddings are dumped.
    ``self.test(model)`` runs at the end.

    The epoch loss is accumulated as a plain Python float so that the
    autograd history of the individual batches is not kept alive for the
    whole epoch.
    """
    entity_set = MyDataset(self.num_of_train_entities)
    entity_loader = DataLoader(entity_set, self.batch_size, True)
    batch_process = BatchProcess(
        self.train_entities, self.train_head_entities,
        self.train_tail_entities, self.train_both_entities,
        self.head_context_head, self.head_context_relation,
        self.head_context_statistics, self.tail_context_relation,
        self.tail_context_tail, self.tail_context_statistics,
        self.head_context_size, self.tail_context_size,
        self.num_of_train_entities, self.negative_batch_size, self.device)
    model = Model(self.result_path, self.log_path, self.entity_dimension,
                  self.relation_dimension, self.num_of_entities,
                  self.num_of_relations, self.norm, self.device)
    if self.continue_learning:
        model.input()
    model.to(self.device)
    optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)

    current_mean_rank = self.validate(model)
    log_text(self.log_path,
             "initial mean rank (validation): %f" % current_mean_rank)
    optimal_mean_rank = current_mean_rank
    self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone()
    self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone()

    patience_count = 0
    for epoch in range(self.num_of_epochs):
        epoch_loss = 0.
        count = 0
        for entity_id_batch in entity_loader:
            if count % 200 == 0:
                # Parenthesised single argument: prints identically on
                # Python 2 and 3.
                print("%d batches processed " % count + time.strftime(
                    '%m-%d-%Y %H:%M:%S', time.localtime(time.time())))
            count += 1
            model.normalize()
            optimizer.zero_grad()
            entity_id_batch = entity_id_batch.tolist()
            entity_batch = [
                self.train_entities[entity_id]
                for entity_id in entity_id_batch
            ]
            head_loss, tail_loss, both_loss, batch_loss = 0., 0., 0., 0.
            head_batch, tail_batch, both_batch = \
                batch_process.batch_classification(entity_batch)
            if len(head_batch) > 0:
                # Entities that only have head context.
                head_head, head_relation = \
                    batch_process.head_context_process(head_batch)
                negative_head_batch = \
                    batch_process.negative_batch_generation(head_batch)
                head_batch = torch.LongTensor(head_batch)
                head_loss = -1. * model(
                    head_batch.to(self.device), head_head.to(self.device),
                    head_relation.to(self.device), None, None,
                    negative_head_batch.to(self.device))
            if len(tail_batch) > 0:
                # Entities that only have tail context.
                tail_relation, tail_tail = \
                    batch_process.tail_context_process(tail_batch)
                negative_tail_batch = \
                    batch_process.negative_batch_generation(tail_batch)
                tail_batch = torch.LongTensor(tail_batch)
                tail_loss = -1. * model(
                    tail_batch.to(self.device), None, None,
                    tail_relation.to(self.device), tail_tail.to(self.device),
                    negative_tail_batch.to(self.device))
            if len(both_batch) > 0:
                # Entities that have both head and tail context.
                both_head, both_head_relation = \
                    batch_process.head_context_process(both_batch)
                both_tail_relation, both_tail = \
                    batch_process.tail_context_process(both_batch)
                negative_both_batch = \
                    batch_process.negative_batch_generation(both_batch)
                both_batch = torch.LongTensor(both_batch)
                both_loss = -1. * model(
                    both_batch.to(self.device), both_head.to(self.device),
                    both_head_relation.to(self.device),
                    both_tail_relation.to(self.device),
                    both_tail.to(self.device),
                    negative_both_batch.to(self.device))
            batch_loss += head_loss + tail_loss + both_loss
            batch_loss.backward()
            optimizer.step()
            # float() detaches the value from the graph; adding the tensor
            # itself would chain every batch's history onto epoch_loss.
            epoch_loss += float(batch_loss)
        log_text(self.log_path,
                 "\r\nepoch " + str(epoch) + ": , loss: " + str(epoch_loss))

        current_mean_rank = self.validate(model)
        if current_mean_rank < optimal_mean_rank:
            log_text(
                self.log_path,
                "optimal average raw mean rank: " + str(optimal_mean_rank) +
                " -> " + str(current_mean_rank))
            patience_count = 0
            optimal_mean_rank = current_mean_rank
            self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone()
            self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone()
        else:
            patience_count += 1
            log_text(
                self.log_path,
                "early stop patience: " + str(self.early_stop_patience) +
                ", patience count: " + str(patience_count) +
                ", current rank: " + str(current_mean_rank) +
                ", best rank: " + str(optimal_mean_rank))
            if patience_count == self.patience:
                if self.early_stop_patience == 1:
                    # Last chance used up: persist the optimum and stop.
                    dump_data(
                        self.optimal_entity_embeddings.to("cpu"),
                        self.result_path + "optimal_entity_embedding.pickle",
                        self.log_path, "self.optimal_entity_embeddings")
                    dump_data(
                        self.optimal_relation_embeddings.to("cpu"),
                        self.result_path + "optimal_relation_embedding.pickle",
                        self.log_path, "self.optimal_relation_embeddings")
                    break
                log_text(
                    self.log_path,
                    "learning rate: " + str(self.learning_rate) + " -> " +
                    str(self.learning_rate / 2))
                self.learning_rate = self.learning_rate / 2
                # Restart from the best embeddings with a fresh optimizer.
                model.entity_embeddings.weight.data = self.optimal_entity_embeddings.clone()
                model.relation_embeddings.weight.data = self.optimal_relation_embeddings.clone()
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=self.learning_rate)
                patience_count = 0
                self.early_stop_patience -= 1

        if epoch % self.output_freq == 0:
            model.output()
            dump_data(self.optimal_entity_embeddings.to("cpu"),
                      self.result_path + "optimal_entity_embedding.pickle",
                      self.log_path, "self.optimal_entity_embeddings")
            dump_data(self.optimal_relation_embeddings.to("cpu"),
                      self.result_path + "optimal_relation_embedding.pickle",
                      self.log_path, "self.optimal_relation_embeddings")
    self.test(model)