def validate(self, model):
    """Return the mean of ``model.validate`` over all validation triples.

    The validation triples are visited in order, in batches of
    ``self.validation_batch_size``. For each batch the head, relation and
    tail ids are gathered from ``self.id_validate_triples`` and passed to
    ``model.validate``. The per-batch results are summed and divided by
    ``self.num_of_validate_triples``. Callers log this value as the
    validation mean rank.

    :param model: object exposing ``validate(heads, relations, tails)``
        returning a scalar (Python number or one-element tensor).
    :return: the averaged value as a Python float.
    """
    mean_rank = 0.
    valid_dataset = MyDataset(self.num_of_validate_triples)
    valid_dataloader = DataLoader(valid_dataset, self.validation_batch_size,
                                  False)
    triples = self.id_validate_triples
    for valid_batch in valid_dataloader:
        # One conversion per batch instead of .item() per element per column.
        indices = valid_batch.tolist()
        heads = torch.tensor([triples["id_heads"][i]
                              for i in indices]).to(self.device)
        relations = torch.tensor([triples["id_relations"][i]
                                  for i in indices]).to(self.device)
        tails = torch.tensor([triples["id_tails"][i]
                              for i in indices]).to(self.device)
        mean_rank += float(model.validate(heads, relations, tails))
    # Explicit float division: this file targets Python 2 (print statements),
    # where int / int truncates, as does integer-tensor division in old torch.
    return mean_rank / float(self.num_of_validate_triples)
コード例 #2
0
 def test(self, model):
     """Return the ``loss_compute`` value summed over all test entities.

     Test entities are visited in batches of ``self.test_batch_size``. Each
     batch is split by ``offline_batch_retrieve.batch_classification`` into
     head/tail/both groups and scored with ``self.loss_compute("test", ...)``.

     :param model: the model passed through to ``self.loss_compute``.
     :return: the summed loss over every batch as a Python float (``0.0``
         when there are no test entities).
     """
     total_loss = 0.
     test_entity_set = MyDataset(self.num_of_test_entities)
     test_entity_loader = DataLoader(test_entity_set, self.test_batch_size,
                                     False)
     # Evaluation only: each loss is turned into a float immediately, so no
     # autograd graph is needed.
     with torch.no_grad():
         for entity_id_batch in test_entity_loader:
             entity_batch = [
                 self.test_entities[entity_id]
                 for entity_id in entity_id_batch.tolist()
             ]
             head_batch, tail_batch, both_batch = self.offline_batch_retrieve.batch_classification(
                 "test", entity_batch)
             # Accumulate across every batch so the result covers the whole
             # test set, not just the first batch.
             total_loss += float(
                 self.loss_compute("test", model, head_batch, tail_batch,
                                   both_batch))
     return total_loss
コード例 #3
0
 def validate(self, model):
     """Return the ``loss_compute`` value summed over all validation entities.

     Validation entities are visited in batches of
     ``self.validation_batch_size``. Each batch is split by
     ``offline_batch_retrieve.batch_classification`` into head/tail/both
     groups and scored with ``self.loss_compute("validate", ...)``.

     :param model: the model passed through to ``self.loss_compute``.
     :return: the summed loss over every batch as a Python float (``0.0``
         when there are no validation entities).
     """
     total_loss = 0.
     validate_entity_set = MyDataset(self.num_of_validate_entities)
     validate_entity_loader = DataLoader(validate_entity_set,
                                         self.validation_batch_size, False)
     # Evaluation only: each loss is turned into a float immediately, so no
     # autograd graph is needed.
     with torch.no_grad():
         for entity_id_batch in validate_entity_loader:
             entity_batch = [
                 self.validate_entities[entity_id]
                 for entity_id in entity_id_batch.tolist()
             ]
             head_batch, tail_batch, both_batch = self.offline_batch_retrieve.batch_classification(
                 "validate", entity_batch)
             # Accumulate across every batch so early stopping compares the
             # loss on the whole validation set, not just the first batch.
             total_loss += float(
                 self.loss_compute("validate", model, head_batch, tail_batch,
                                   both_batch))
     return total_loss
 def test(self, model):
     """Evaluate ``model`` on the test triples and log ranking metrics.

     The training triples are loaded from
     ``self.output_path + "train_triple_tensor.pickle"`` and passed to
     ``model.test_calc`` together with each batch of test
     (head, relation, tail) ids. ``test_calc`` accumulates into a
     4-element tensor laid out as
     ``[mean_rank, hit_n, filtered_mean_rank, filtered_hit_n]``. The raw and
     filtered mean rank and hit@``self.n_of_hit`` are then written to
     ``self.log_path``. Nothing is logged beyond the triple count when there
     are no test triples.

     :param model: object exposing ``test_calc``.
     """
     train_triple_tensor = load_data(
         self.output_path + "train_triple_tensor.pickle", self.log_path,
         "train_triple_tensor").to(self.device)
     test_dataset = MyDataset(self.num_of_test_triples)
     test_dataloader = DataLoader(test_dataset, self.test_batch_size, False)
     test_result = torch.zeros(4).to(
         self.device
     )  # [mean_rank, hit_n, filtered_mean_rank, filtered_hit_n]
     log_text(self.log_path,
              "number of test triples: %d" % self.num_of_test_triples)
     if self.num_of_test_triples == 0:
         # Every metric below divides by the number of test triples.
         return
     triples = self.id_test_triples
     count = 0
     # Pure evaluation: no autograd graph is needed for the ranking scores.
     with torch.no_grad():
         for test_batch in test_dataloader:
             if count % 1000 == 0:
                 print("%d test triples processed" % count)
             count += self.test_batch_size
             # Plain ints index the id containers; one conversion per batch.
             indices = test_batch.tolist()
             model.test_calc(
                 self.n_of_hit, test_result, train_triple_tensor,
                 torch.tensor([triples["id_heads"][i]
                               for i in indices]).to(self.device),
                 torch.tensor([triples["id_relations"][i]
                               for i in indices]).to(self.device),
                 torch.tensor([triples["id_tails"][i]
                               for i in indices]).to(self.device))
     num_triples = float(self.num_of_test_triples)
     log_text(
         self.log_path, "raw mean rank: %f" %
         (test_result[0].item() / num_triples))
     # Hit counts cover both the head side and the tail side, hence 2 * N.
     log_text(
         self.log_path,
         "raw hit@%d: %f%%" % (self.n_of_hit, 100. * test_result[1].item() /
                               (2. * num_triples)))
     log_text(
         self.log_path, "filtered mean rank: %f" %
         (test_result[2].item() / num_triples))
     log_text(
         self.log_path, "filtered hit@%d: %f%%" %
         (self.n_of_hit, 100. * test_result[3].item() /
          (2. * num_triples)))
    def train(self):
        """Train a fresh (or resumed) ``Model`` for ``self.num_of_epochs`` epochs.

        Each epoch walks the training entities in shuffled batches. Each batch
        is split by ``OfflineBatchRetrieve.batch_classification`` into
        head-only, tail-only and both groups. The batch loss is the negated
        model output summed over the non-empty groups, minimised with Adam.
        Contexts and negatives are re-sampled every ``self.re_sampling_freq``
        epochs, and ``model.output()`` is called every ``self.output_freq``
        epochs. No validation is performed here.
        """
        model = Model(self.result_path, self.log_path, self.entity_dimension,
                      self.relation_dimension, self.num_of_entities,
                      self.num_of_relations, self.norm, self.device)
        if self.continue_learning:
            model.input()
        model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)
        PrintGPUStatus.print_gpu_status("after the initialization of model")

        self.offline_batch_retrieve = OfflineBatchRetrieve(
            self.names, self.dataset)

        entity_set = MyDataset(self.num_of_train_entities)
        entity_loader = DataLoader(entity_set, self.batch_size, True)

        for epoch in range(self.num_of_epochs):
            epoch_loss = 0.
            if epoch != 0 and epoch % self.re_sampling_freq == 0:
                self.context_and_negatives.re_sampling()
                self.offline_batch_retrieve.re_read_context_and_negatives()
            for entity_id_batch in entity_loader:
                model.normalize()
                optimizer.zero_grad()
                entity_batch = [
                    self.train_entities[entity_id.item()]
                    for entity_id in entity_id_batch
                ]
                head_loss, tail_loss, both_loss = 0., 0., 0.
                head_batch, tail_batch, both_batch = self.offline_batch_retrieve.batch_classification(
                    "train", entity_batch)
                if len(head_batch) > 0:
                    head_head, head_relation = self.offline_batch_retrieve.head_context_retrieve(
                        "train", head_batch)
                    negative_head_batch = self.offline_batch_retrieve.negative_retrieves(
                        "train", head_batch)
                    head_batch = torch.LongTensor(head_batch)
                    head_loss = -1. * model(
                        head_batch.to(self.device), head_head.to(self.device),
                        head_relation.to(self.device), None, None,
                        negative_head_batch.to(self.device))
                if len(tail_batch) > 0:
                    tail_relation, tail_tail = self.offline_batch_retrieve.tail_context_retrieve(
                        "train", tail_batch)
                    negative_tail_batch = self.offline_batch_retrieve.negative_retrieves(
                        "train", tail_batch)
                    tail_batch = torch.LongTensor(tail_batch)
                    tail_loss = -1. * model(
                        tail_batch.to(self.device), None, None,
                        tail_relation.to(self.device), tail_tail.to(
                            self.device), negative_tail_batch.to(self.device))
                if len(both_batch) > 0:
                    both_head, both_head_relation = self.offline_batch_retrieve.head_context_retrieve(
                        "train", both_batch)
                    both_tail_relation, both_tail = self.offline_batch_retrieve.tail_context_retrieve(
                        "train", both_batch)
                    negative_both_batch = self.offline_batch_retrieve.negative_retrieves(
                        "train", both_batch)
                    both_batch = torch.LongTensor(both_batch)
                    both_loss = -1. * model(
                        both_batch.to(self.device), both_head.to(self.device),
                        both_head_relation.to(self.device),
                        both_tail_relation.to(self.device),
                        both_tail.to(self.device),
                        negative_both_batch.to(self.device))
                batch_loss = head_loss + tail_loss + both_loss
                if not torch.is_tensor(batch_loss):
                    # All three groups were empty: nothing to back-propagate.
                    continue
                batch_loss.backward()
                optimizer.step()
                # Accumulate a Python float; adding the tensor itself would keep
                # every batch's autograd graph alive until the epoch ends.
                epoch_loss += batch_loss.item()
            log_text(
                self.log_path,
                "\r\nepoch " + str(epoch) + ": , loss: " + str(epoch_loss))
            if (epoch + 1) % self.output_freq == 0:
                model.output()
コード例 #6
0
    def train(self):
        """Train with periodic validation, early stopping and learning-rate halving.

        After building the model (optionally resuming via ``model.input()``),
        the initial ``self.validate(model)`` loss is logged and the embeddings
        are snapshotted as the current optimum. Each epoch iterates the
        shuffled training entities and minimises
        ``self.loss_compute("train", ...)`` with Adam.

        Every ``self.validation_freq`` epochs the validation loss is
        recomputed. An improvement snapshots the embeddings and resets the
        patience counter. Otherwise, once ``patience_count`` reaches
        ``self.patience``, one of two things happens:

        - if ``self.early_stop_patience == 1``, the optimal embeddings are
          dumped and training stops;
        - otherwise the learning rate is halved, the optimal embeddings are
          restored and the optimizer is rebuilt.

        Finally the test loss is printed.
        """
        model = Model(self.result_path, self.log_path, self.entity_dimension,
                      self.relation_dimension, self.num_of_entities,
                      self.num_of_relations, self.norm, self.device)
        if self.continue_learning:
            model.input()
        model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)
        PrintGPUStatus.print_gpu_status("after the initialization of model")

        self.offline_batch_retrieve = OfflineBatchRetrieve(
            self.names, self.dataset)

        current_validate_loss = self.validate(model)
        log_text(self.log_path,
                 "initial loss (validation): %f" % current_validate_loss)
        optimal_validate_loss = current_validate_loss
        self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone(
        )
        self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone(
        )

        entity_set = MyDataset(self.num_of_train_entities)
        entity_loader = DataLoader(entity_set, self.batch_size, True)
        patience_count = 0
        for epoch in range(self.num_of_epochs):
            epoch_loss = 0.
            if epoch != 0 and epoch % self.re_sampling_freq == 0:
                self.context_and_negatives.re_sampling()
                self.offline_batch_retrieve.re_read_context_and_negatives()
            for entity_id_batch in entity_loader:
                model.normalize()
                optimizer.zero_grad()
                entity_batch = [
                    self.train_entities[entity_id.item()]
                    for entity_id in entity_id_batch
                ]
                head_batch, tail_batch, both_batch = self.offline_batch_retrieve.batch_classification(
                    "train", entity_batch)
                batch_loss = self.loss_compute("train", model, head_batch,
                                               tail_batch, both_batch)
                if not torch.is_tensor(batch_loss):
                    # Nothing was scored in this batch: nothing to back-propagate.
                    continue
                batch_loss.backward()
                optimizer.step()
                # Accumulate a Python float; adding the tensor itself would keep
                # every batch's autograd graph alive until the epoch ends.
                epoch_loss += batch_loss.item()
            log_text(
                self.log_path,
                "\r\nepoch " + str(epoch) + ": , loss: " + str(epoch_loss))
            if epoch % self.validation_freq == 0:
                current_validate_loss = self.validate(model)
                if current_validate_loss < optimal_validate_loss:
                    log_text(
                        self.log_path, "optimal validate loss: " +
                        str(optimal_validate_loss) + " -> " +
                        str(current_validate_loss))
                    patience_count = 0
                    optimal_validate_loss = current_validate_loss
                    self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone(
                    )
                    self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone(
                    )
                else:
                    patience_count += 1
                    log_text(
                        self.log_path, "early stop patience: " +
                        str(self.early_stop_patience) + ", patience count: " +
                        str(patience_count) + ", current validate loss: " +
                        str(current_validate_loss) +
                        ", optimal validate loss: " +
                        str(optimal_validate_loss))
                    if patience_count == self.patience:
                        if self.early_stop_patience == 1:
                            dump_data(
                                self.optimal_entity_embeddings.to("cpu"),
                                self.result_path +
                                "optimal_entity_embedding.pickle",
                                self.log_path,
                                "self.optimal_entity_embeddings")
                            dump_data(
                                self.optimal_relation_embeddings.to("cpu"),
                                self.result_path +
                                "optimal_relation_embedding.pickle",
                                self.log_path,
                                "self.optimal_relation_embeddings")
                            break
                        log_text(
                            self.log_path,
                            "learning rate: " + str(self.learning_rate) +
                            " -> " + str(self.learning_rate / 2))
                        self.learning_rate = self.learning_rate / 2
                        # Restart from the best weights seen so far with a
                        # fresh optimizer at the lower learning rate.
                        model.entity_embeddings.weight.data = self.optimal_entity_embeddings.clone(
                        )
                        model.relation_embeddings.weight.data = self.optimal_relation_embeddings.clone(
                        )
                        optimizer = torch.optim.Adam(model.parameters(),
                                                     lr=self.learning_rate)
                        patience_count = 0
                        self.early_stop_patience -= 1
            if (epoch + 1) % self.output_freq == 0:
                model.output()
                dump_data(self.optimal_entity_embeddings.to("cpu"),
                          self.result_path + "optimal_entity_embedding.pickle",
                          self.log_path, "self.optimal_entity_embeddings")
                dump_data(
                    self.optimal_relation_embeddings.to("cpu"),
                    self.result_path + "optimal_relation_embedding.pickle",
                    self.log_path, "self.optimal_relation_embeddings")
        # NOTE(review): the model tested here holds the latest weights, not
        # necessarily self.optimal_*_embeddings; restore those first if the
        # best checkpoint is what should be evaluated.
        print("test loss: %f" % self.test(model))
    def train(self):
        """Train with mean-rank validation, early stopping and learning-rate halving.

        The initial validation mean rank (``self.validate(model)``) is logged
        and the embeddings are snapshotted as the current optimum. Each epoch
        walks the shuffled training entities. Each batch is split into
        head-only, tail-only and both groups, and the negated model output
        summed over the non-empty groups is minimised with Adam.

        Every ``self.validation_freq`` epochs the mean rank is recomputed. An
        improvement snapshots the embeddings. Otherwise, once
        ``patience_count`` reaches ``self.patience``, one of two things
        happens:

        - if ``self.early_stop_patience == 1``, the optimal embeddings are
          dumped and training stops;
        - otherwise the learning rate is halved, the optimal embeddings are
          restored and the optimizer is rebuilt.

        ``self.test(model)`` runs at the end.
        """
        model = Model(self.result_path, self.log_path, self.entity_dimension, self.relation_dimension, self.num_of_entities, self.num_of_relations, self.norm, self.device)
        if self.continue_learning:
            model.input()
        model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)
        PrintGPUStatus.print_gpu_status("after the initialization of model")

        self.offline_batch_retrieve = OfflineBatchRetrieve(self.names, self.dataset)

        current_mean_rank = self.validate(model)
        log_text(self.log_path, "initial mean rank (validation): %f" % current_mean_rank)
        optimal_mean_rank = current_mean_rank
        self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone()
        self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone()

        entity_set = MyDataset(self.num_of_train_entities)
        entity_loader = DataLoader(entity_set, self.batch_size, True)
        patience_count = 0
        for epoch in range(self.num_of_epochs):
            epoch_loss = 0.
            if epoch != 0 and epoch % self.re_sampling_freq == 0:
                self.context_and_negatives.re_sampling()
                self.offline_batch_retrieve.re_read_context_and_negatives()
            for entity_id_batch in entity_loader:
                model.normalize()
                optimizer.zero_grad()
                entity_batch = [self.train_entities[entity_id.item()] for entity_id in entity_id_batch]
                head_loss, tail_loss, both_loss = 0., 0., 0.
                head_batch, tail_batch, both_batch = self.offline_batch_retrieve.batch_classification("train", entity_batch)
                if len(head_batch) > 0:
                    head_head, head_relation = self.offline_batch_retrieve.head_context_retrieve("train", head_batch)
                    negative_head_batch = self.offline_batch_retrieve.negative_retrieves("train", head_batch)
                    head_batch = torch.LongTensor(head_batch)
                    head_loss = -1. * model(head_batch.to(self.device),
                                            head_head.to(self.device), head_relation.to(self.device),
                                            None, None,
                                            negative_head_batch.to(self.device))
                if len(tail_batch) > 0:
                    tail_relation, tail_tail = self.offline_batch_retrieve.tail_context_retrieve("train", tail_batch)
                    negative_tail_batch = self.offline_batch_retrieve.negative_retrieves("train", tail_batch)
                    tail_batch = torch.LongTensor(tail_batch)
                    tail_loss = -1. * model(tail_batch.to(self.device),
                                            None, None,
                                            tail_relation.to(self.device), tail_tail.to(self.device),
                                            negative_tail_batch.to(self.device))
                if len(both_batch) > 0:
                    both_head, both_head_relation = self.offline_batch_retrieve.head_context_retrieve("train", both_batch)
                    both_tail_relation, both_tail = self.offline_batch_retrieve.tail_context_retrieve("train", both_batch)
                    negative_both_batch = self.offline_batch_retrieve.negative_retrieves("train", both_batch)
                    both_batch = torch.LongTensor(both_batch)
                    both_loss = -1. * model(both_batch.to(self.device),
                                            both_head.to(self.device), both_head_relation.to(self.device),
                                            both_tail_relation.to(self.device), both_tail.to(self.device),
                                            negative_both_batch.to(self.device))
                batch_loss = head_loss + tail_loss + both_loss
                if not torch.is_tensor(batch_loss):
                    # All three groups were empty: nothing to back-propagate.
                    continue
                batch_loss.backward()
                optimizer.step()
                # Accumulate a Python float; adding the tensor itself would keep
                # every batch's autograd graph alive until the epoch ends.
                epoch_loss += batch_loss.item()
            log_text(self.log_path, "\r\nepoch " + str(epoch) + ": , loss: " + str(epoch_loss))
            if epoch % self.validation_freq == 0:
                current_mean_rank = self.validate(model)
                if current_mean_rank < optimal_mean_rank:
                    log_text(self.log_path, "optimal average raw mean rank: " + str(optimal_mean_rank) + " -> " + str(current_mean_rank))
                    patience_count = 0
                    optimal_mean_rank = current_mean_rank
                    self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone()
                    self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone()
                else:
                    patience_count += 1
                    log_text(self.log_path, "early stop patience: " + str(self.early_stop_patience) + ", patience count: " + str(patience_count) + ", current rank: " + str(current_mean_rank) + ", best rank: " + str(optimal_mean_rank))
                    if patience_count == self.patience:
                        if self.early_stop_patience == 1:
                            dump_data(self.optimal_entity_embeddings.to("cpu"),
                                      self.result_path + "optimal_entity_embedding.pickle", self.log_path,
                                      "self.optimal_entity_embeddings")
                            dump_data(self.optimal_relation_embeddings.to("cpu"),
                                      self.result_path + "optimal_relation_embedding.pickle", self.log_path,
                                      "self.optimal_relation_embeddings")
                            break
                        log_text(self.log_path, "learning rate: " + str(self.learning_rate) + " -> " + str(self.learning_rate / 2))
                        self.learning_rate = self.learning_rate / 2
                        # Restart from the best weights seen so far with a
                        # fresh optimizer at the lower learning rate.
                        model.entity_embeddings.weight.data = self.optimal_entity_embeddings.clone()
                        model.relation_embeddings.weight.data = self.optimal_relation_embeddings.clone()
                        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
                        patience_count = 0
                        self.early_stop_patience -= 1
            if (epoch + 1) % self.output_freq == 0:
                model.output()
                dump_data(self.optimal_entity_embeddings.to("cpu"), self.result_path + "optimal_entity_embedding.pickle", self.log_path, "self.optimal_entity_embeddings")
                dump_data(self.optimal_relation_embeddings.to("cpu"), self.result_path + "optimal_relation_embedding.pickle", self.log_path, "self.optimal_relation_embeddings")
        # NOTE(review): tests the latest weights, not necessarily the optimal
        # snapshot; restore self.optimal_*_embeddings first if intended.
        self.test(model)
コード例 #8
0
    def test_calc(self, n_of_hit, test_result, train_triple_tensor, test_heads,
                  test_relations, test_tails):
        test_head_embeddings = self.entity_embeddings(
            test_heads)  # (num_of_test_triples, entity_dim)
        test_relation_embeddings = self.relation_embeddings(
            test_relations)  # (num_of_test_triples, relation_dim)
        test_tail_embeddings = self.entity_embeddings(
            test_tails)  # (num_of_test_triples, entity_dim)

        target_loss = torch.norm(
            test_head_embeddings + test_relation_embeddings -
            test_tail_embeddings, self.norm, 1)  # (num_of_test_triples,)
        tmp_head_loss = torch.norm(
            torch.unsqueeze(self.entity_embeddings.weight.data, 1) +
            test_relation_embeddings - test_tail_embeddings, self.norm,
            2)  # (num_of_entities, num_of_test_triples)
        tmp_tail_loss = torch.norm(
            test_head_embeddings + test_relation_embeddings -
            torch.unsqueeze(self.entity_embeddings.weight.data, 1), self.norm,
            2)  # (num_of_entities, num_of_test_triples)

        better_heads = torch.nonzero(
            nn.functional.relu(target_loss -
                               tmp_head_loss))  # (number of better heads, 2)
        better_tails = torch.nonzero(
            nn.functional.relu(target_loss -
                               tmp_tail_loss))  # (number of better tails, 2)

        rank_h = better_heads.size()[0]
        rank_t = better_tails.size()[0]

        test_result[0] += (rank_h + rank_t + 2) / 2
        if rank_h + 1 <= n_of_hit * test_heads.size()[0]:
            test_result[1] += test_heads.size()[0]
        if rank_t + 1 <= n_of_hit * test_heads.size()[0]:
            test_result[1] += test_heads.size()[0]

        existing_heads = 0
        existing_tails = 0
        batch_num = 200
        dataset_h = MyDataset(rank_h)
        data_loader_h = DataLoader(dataset_h, batch_num, False)
        for batch in data_loader_h:
            existing_heads += torch.nonzero(
                torch.relu(-1 * torch.sum(
                    torch.abs(
                        torch.cat(
                            (torch.unsqueeze(better_heads[batch, 0], 1),
                             torch.unsqueeze(
                                 test_relations[better_heads[batch, 1]], 1),
                             torch.unsqueeze(
                                 test_tails[better_heads[batch, 1]], 1)), 1) -
                        torch.unsqueeze(train_triple_tensor, 1)), 2) +
                           0.5)).size()[0]
        dataset_t = MyDataset(rank_t)
        data_loader_t = DataLoader(dataset_t, batch_num, False)
        for batch in data_loader_t:
            existing_tails += torch.nonzero(
                torch.relu(-1 * torch.sum(
                    torch.abs(
                        torch.cat(
                            (torch.unsqueeze(
                                test_heads[better_tails[batch, 1]], 1),
                             torch.unsqueeze(
                                 test_relations[better_tails[batch, 1]], 1),
                             torch.unsqueeze(better_tails[batch, 0], 1)), 1) -
                        torch.unsqueeze(train_triple_tensor, 1)), 2) +
                           0.5)).size()[0]

        filtered_rank_h = rank_h - existing_heads
        filtered_rank_t = rank_t - existing_tails

        test_result[2] += (filtered_rank_h + filtered_rank_t + 2) / 2
        if filtered_rank_h + 1 <= n_of_hit * test_heads.size()[0]:
            test_result[3] += test_heads.size()[0]
        if filtered_rank_t + 1 <= n_of_hit * test_heads.size()[0]:
            test_result[3] += test_heads.size()[0]
    def train(self):
        """Train with per-epoch mean-rank validation and early stopping.

        Batches are built by a ``BatchProcess`` over the in-memory
        head/tail context tables. Each batch is split into head-only,
        tail-only and both groups, and the negated model output summed over
        the non-empty groups is minimised with Adam.

        After every epoch ``self.validate(model)`` is recomputed. An
        improvement snapshots the embeddings. Otherwise, once
        ``patience_count`` reaches ``self.patience``, one of two things
        happens:

        - if ``self.early_stop_patience == 1``, the optimal embeddings are
          dumped and training stops;
        - otherwise the learning rate is halved, the optimal embeddings are
          restored and the optimizer is rebuilt.

        ``self.test(model)`` runs at the end.
        """
        entity_set = MyDataset(self.num_of_train_entities)
        entity_loader = DataLoader(entity_set, self.batch_size, True)
        batch_process = BatchProcess(
            self.train_entities, self.train_head_entities,
            self.train_tail_entities, self.train_both_entities,
            self.head_context_head, self.head_context_relation,
            self.head_context_statistics, self.tail_context_relation,
            self.tail_context_tail, self.tail_context_statistics,
            self.head_context_size, self.tail_context_size,
            self.num_of_train_entities, self.negative_batch_size, self.device)
        model = Model(self.result_path, self.log_path, self.entity_dimension,
                      self.relation_dimension, self.num_of_entities,
                      self.num_of_relations, self.norm, self.device)
        if self.continue_learning:
            model.input()
        model.to(self.device)
        optimizer = torch.optim.Adam(model.parameters(), self.learning_rate)
        current_mean_rank = self.validate(model)
        log_text(self.log_path,
                 "initial mean rank (validation): %f" % current_mean_rank)
        optimal_mean_rank = current_mean_rank
        self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone(
        )
        self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone(
        )

        patience_count = 0
        for epoch in range(self.num_of_epochs):
            epoch_loss = 0.
            count = 0
            for entity_id_batch in entity_loader:
                if count % 200 == 0:
                    print("%d batches processed " % count + time.strftime(
                        '%m-%d-%Y %H:%M:%S', time.localtime(time.time())))
                count += 1
                model.normalize()
                optimizer.zero_grad()
                entity_id_batch = entity_id_batch.tolist()
                entity_batch = [
                    self.train_entities[entity_id]
                    for entity_id in entity_id_batch
                ]
                head_loss, tail_loss, both_loss = 0., 0., 0.
                head_batch, tail_batch, both_batch = batch_process.batch_classification(
                    entity_batch)
                if len(head_batch) > 0:
                    head_head, head_relation = batch_process.head_context_process(
                        head_batch)
                    negative_head_batch = batch_process.negative_batch_generation(
                        head_batch)
                    head_batch = torch.LongTensor(head_batch)
                    head_loss = -1. * model(
                        head_batch.to(self.device), head_head.to(self.device),
                        head_relation.to(self.device), None, None,
                        negative_head_batch.to(self.device))
                if len(tail_batch) > 0:
                    tail_relation, tail_tail = batch_process.tail_context_process(
                        tail_batch)
                    negative_tail_batch = batch_process.negative_batch_generation(
                        tail_batch)
                    tail_batch = torch.LongTensor(tail_batch)
                    tail_loss = -1. * model(
                        tail_batch.to(self.device), None, None,
                        tail_relation.to(self.device), tail_tail.to(
                            self.device), negative_tail_batch.to(self.device))
                if len(both_batch) > 0:
                    both_head, both_head_relation = batch_process.head_context_process(
                        both_batch)
                    both_tail_relation, both_tail = batch_process.tail_context_process(
                        both_batch)
                    negative_both_batch = batch_process.negative_batch_generation(
                        both_batch)
                    both_batch = torch.LongTensor(both_batch)
                    both_loss = -1. * model(
                        both_batch.to(self.device), both_head.to(self.device),
                        both_head_relation.to(self.device),
                        both_tail_relation.to(self.device),
                        both_tail.to(self.device),
                        negative_both_batch.to(self.device))

                batch_loss = head_loss + tail_loss + both_loss
                if not torch.is_tensor(batch_loss):
                    # All three groups were empty: nothing to back-propagate.
                    continue
                batch_loss.backward()
                optimizer.step()
                # Accumulate a Python float; adding the tensor itself would keep
                # every batch's autograd graph alive until the epoch ends.
                epoch_loss += batch_loss.item()
            log_text(
                self.log_path,
                "\r\nepoch " + str(epoch) + ": , loss: " + str(epoch_loss))
            current_mean_rank = self.validate(model)
            if current_mean_rank < optimal_mean_rank:
                log_text(
                    self.log_path, "optimal average raw mean rank: " +
                    str(optimal_mean_rank) + " -> " + str(current_mean_rank))
                patience_count = 0
                optimal_mean_rank = current_mean_rank
                self.optimal_entity_embeddings = model.entity_embeddings.weight.data.clone(
                )
                self.optimal_relation_embeddings = model.relation_embeddings.weight.data.clone(
                )
            else:
                patience_count += 1
                log_text(
                    self.log_path,
                    "early stop patience: " + str(self.early_stop_patience) +
                    ", patience count: " + str(patience_count) +
                    ", current rank: " + str(current_mean_rank) +
                    ", best rank: " + str(optimal_mean_rank))
                if patience_count == self.patience:
                    if self.early_stop_patience == 1:
                        dump_data(
                            self.optimal_entity_embeddings.to("cpu"),
                            self.result_path +
                            "optimal_entity_embedding.pickle", self.log_path,
                            "self.optimal_entity_embeddings")
                        dump_data(
                            self.optimal_relation_embeddings.to("cpu"),
                            self.result_path +
                            "optimal_relation_embedding.pickle", self.log_path,
                            "self.optimal_relation_embeddings")
                        break
                    log_text(
                        self.log_path,
                        "learning rate: " + str(self.learning_rate) + " -> " +
                        str(self.learning_rate / 2))
                    self.learning_rate = self.learning_rate / 2
                    # Restart from the best weights seen so far with a fresh
                    # optimizer at the lower learning rate.
                    model.entity_embeddings.weight.data = self.optimal_entity_embeddings.clone(
                    )
                    model.relation_embeddings.weight.data = self.optimal_relation_embeddings.clone(
                    )
                    optimizer = torch.optim.Adam(model.parameters(),
                                                 lr=self.learning_rate)
                    patience_count = 0
                    self.early_stop_patience -= 1
            if epoch % self.output_freq == 0:
                model.output()
                dump_data(self.optimal_entity_embeddings.to("cpu"),
                          self.result_path + "optimal_entity_embedding.pickle",
                          self.log_path, "self.optimal_entity_embeddings")
                dump_data(
                    self.optimal_relation_embeddings.to("cpu"),
                    self.result_path + "optimal_relation_embedding.pickle",
                    self.log_path, "self.optimal_relation_embeddings")
        # NOTE(review): tests the latest weights, not necessarily the optimal
        # snapshot; restore self.optimal_*_embeddings first if intended.
        self.test(model)