Example 1
    def train(self):
        """Train a TransE model with SGD, margin ranking loss, and manual
        early stopping / learning-rate decay.

        Side effects: populates self.trainTriple, self.numOf* counters, keeps
        the best-so-far embeddings in self.entityEmbedding /
        self.relationEmbedding, periodically calls self.write(), and finally
        restores the best embeddings and runs self.test().
        """
        # Load training triples and fill the lookup structures in place;
        # self.nums receives [numOfTriple, numOfEntity, numOfRelation].
        read = readData(self.inAdd, self.train2id, self.headRelation2Tail, self.tailRelation2Head,
                      self.entity2id, self.id2entity, self.relation2id, self.id2relation, self.nums)
        self.trainTriple = read.out()
        self.numOfTriple = self.nums[0]
        self.numOfEntity = self.nums[1]
        self.numOfRelation = self.nums[2]

        self.readValidateTriples()
        self.readTestTriples()

        transE = TransE(self.numOfEntity, self.numOfRelation, self.entityDimension, self.relationDimension, self.margin,
                        self.norm)

        # Optionally warm-start from previously saved embeddings.
        if self.preOrNot:
            self.preRead(transE)

        transE.to(self.device)

        # Baseline validation score before any training; lower mean rank is better.
        self.bestAvFiMR = self.validate(transE)
        self.entityEmbedding = transE.entity_embeddings.weight.data.clone()
        self.relationEmbedding = transE.relation_embeddings.weight.data.clone()

        # NOTE(review): second positional arg is the deprecated `size_average`
        # flag (sum reduction) — intentional here, kept for compatibility.
        criterion = nn.MarginRankingLoss(self.margin, False).to(self.device)
        optimizer = optim.SGD(transE.parameters(), lr=self.learningRate, weight_decay=self.weight_decay)

        dataSet = dataset(self.numOfTriple)
        # DataLoader requires an int batch size; float(...) here would raise.
        batchSize = int(self.numOfTriple / self.numOfBatches)
        dataLoader = DataLoader(dataSet, batchSize, True)

        patienceCount = 0

        for epoch in range(self.numOfEpochs):
            epochLoss = 0
            for batch in dataLoader:
                self.positiveBatch = {}
                self.corruptedBatch = {}
                # Builds a positive batch and a negatively-sampled (corrupted)
                # batch of equal size, written into the two dicts above.
                generateBatches(batch, self.train2id, self.positiveBatch, self.corruptedBatch, self.numOfEntity,
                                self.headRelation2Tail, self.tailRelation2Head)
                optimizer.zero_grad()
                positiveBatchHead = self.positiveBatch["h"].to(self.device)
                positiveBatchRelation = self.positiveBatch["r"].to(self.device)
                positiveBatchTail = self.positiveBatch["t"].to(self.device)
                corruptedBatchHead = self.corruptedBatch["h"].to(self.device)
                corruptedBatchRelation = self.corruptedBatch["r"].to(self.device)
                corruptedBatchTail = self.corruptedBatch["t"].to(self.device)
                output = transE(positiveBatchHead, positiveBatchRelation, positiveBatchTail, corruptedBatchHead,
                                   corruptedBatchRelation, corruptedBatchTail)
                # The model returns positive and negative scores concatenated.
                positiveLoss, negativeLoss = output.view(2, -1)
                # target = -1: we want positiveLoss < negativeLoss by `margin`.
                tmpTensor = torch.tensor([-1], dtype=torch.float).to(self.device)
                batchLoss = criterion(positiveLoss, negativeLoss, tmpTensor)
                batchLoss.backward()
                optimizer.step()
                # .item() detaches the scalar; accumulating the tensor itself
                # would retain every batch's autograd graph for the whole epoch.
                epochLoss += batchLoss.item()

            print("epoch: {}, loss: {}".format(epoch, epochLoss))

            tmpAvFiMR = self.validate(transE)

            if tmpAvFiMR < self.bestAvFiMR:
                # Improvement: snapshot the embeddings and reset patience.
                print("best averaged raw mean rank: {} -> {}".format(self.bestAvFiMR, tmpAvFiMR))
                patienceCount = 0
                self.bestAvFiMR = tmpAvFiMR
                self.entityEmbedding = transE.entity_embeddings.weight.data.clone()
                self.relationEmbedding = transE.relation_embeddings.weight.data.clone()
            else:
                patienceCount += 1
                print("early stop patience: {}, patience count: {}, current rank: {}, best rank: {}".format(self.earlyStopPatience, patienceCount, tmpAvFiMR, self.bestAvFiMR))
                if patienceCount == self.patience:
                    # Out of patience: either stop for good, or halve the LR,
                    # double the weight decay, roll back to the best weights,
                    # and spend one unit of earlyStopPatience.
                    if self.earlyStopPatience == 1:
                        break
                    print("learning rate: {} -> {}".format(self.learningRate, self.learningRate/2))
                    print("weight decay: {} -> {}".format(self.weight_decay, self.weight_decay*2))
                    self.learningRate /= 2
                    self.weight_decay *= 2
                    transE.entity_embeddings.weight.data = self.entityEmbedding.clone()
                    transE.relation_embeddings.weight.data = self.relationEmbedding.clone()
                    optimizer = optim.SGD(transE.parameters(), lr=self.learningRate, weight_decay=self.weight_decay)
                    patienceCount = 0
                    self.earlyStopPatience -= 1

            # Periodic checkpoint, plus one at the very last epoch.
            if (epoch+1)%self.outputFreq == 0 or (epoch+1) == self.numOfEpochs:
                self.write()
            print()

        # Restore the best embeddings seen during training before testing.
        transE.entity_embeddings.weight.data = self.entityEmbedding.clone()
        transE.relation_embeddings.weight.data = self.relationEmbedding.clone()
        self.test(transE)
Example 2
File: train.py Project: zjs123/MPKE
    def train(self):
        """Train the MPKE temporal KG embedding model with Adam.

        Reads the dataset from ./dataset/<self.dataset>, trains for
        self.numOfEpochs epochs, and every 20th epoch runs entity/relation
        validation on the test split, logging results to a per-run log file.
        """
        path = "./dataset/" + self.dataset
        data = readData(path, self.train2id, self.year2id, self.step_list,
                        self.headRelation2Tail, self.tailRelation2Head,
                        self.headTail2Relation, self.nums)  #read training data

        self.Triples = data.out()

        # self.nums is filled in place by readData:
        # [numOfTrainTriple, numOfEntity, numOfRelation, numOfTime, numOfMaxLen]
        self.numOfTrainTriple = self.nums[0]
        self.numOfEntity = self.nums[1]
        self.numOfRelation = self.nums[2]
        self.numOfTime = self.nums[3]
        self.numOfMaxLen = self.nums[4]

        self.readValidateTriples(path)
        self.readTestTriples(path)

        self.model = MPKE(self.numOfEntity, self.numOfRelation, self.numOfTime,
                          self.numOfMaxLen, self.entityDimension,
                          self.relationDimension, self.norm, self.norm_m,
                          self.hyper_m)  #init the model

        self.model.to(self.device)

        Margin_Loss_D = Loss.double_marginLoss()

        optimizer = optim.Adam(self.model.parameters(), lr=self.learningRate)

        Dataset = dataset(self.numOfTrainTriple)
        batchsize = int(self.numOfTrainTriple / self.numOfBatches)
        dataLoader = DataLoader(Dataset, batchsize, True)

        Log_path = "./dataset/" + self.dataset + "/" + str(
            self.learningRate) + "_MP_" + "log.txt"
        # Context manager guarantees the log is closed even if training raises.
        with open(Log_path, "w") as Log:
            for epoch in range(self.numOfEpochs):
                epochLoss = 0
                for batch in dataLoader:
                    self.positiveBatch = {}
                    self.corruptedBatch = {}
                    # Fills the two dicts with a positive batch and a
                    # negatively-sampled batch (h/r/t/time/step tensors).
                    generateBatches(batch, self.train2id, self.step_list,
                                    self.positiveBatch, self.corruptedBatch,
                                    self.numOfEntity, self.numOfRelation,
                                    self.headRelation2Tail, self.tailRelation2Head,
                                    self.headTail2Relation, self.ns)
                    optimizer.zero_grad()
                    positiveBatchHead = self.positiveBatch["h"].to(self.device)
                    positiveBatchRelation = self.positiveBatch["r"].to(self.device)
                    positiveBatchTail = self.positiveBatch["t"].to(self.device)
                    positiveBatchTime = self.positiveBatch["time"].to(self.device)
                    positiveBatchStep = self.positiveBatch["step"].to(self.device)
                    corruptedBatchHead = self.corruptedBatch["h"].to(self.device)
                    corruptedBatchRelation = self.corruptedBatch["r"].to(
                        self.device)
                    corruptedBatchTail = self.corruptedBatch["t"].to(self.device)
                    corruptedBatchTime = self.corruptedBatch["time"].to(
                        self.device)
                    corruptedBatchStep = self.corruptedBatch["step"].to(
                        self.device)

                    positiveScore, negativeScore = self.model(
                        positiveBatchHead, positiveBatchRelation,
                        positiveBatchTail, positiveBatchTime, positiveBatchStep,
                        corruptedBatchHead, corruptedBatchRelation,
                        corruptedBatchTail, corruptedBatchTime, corruptedBatchStep)

                    # Margin loss on the scores, plus norm regularization on
                    # the time/step embeddings touched by this batch.
                    loss = Margin_Loss_D(positiveScore, negativeScore,
                                         self.margin_triple)

                    time_embeddings = self.model.time_embeddings(positiveBatchTime)
                    step_embeddings = self.model.step_embeddings(positiveBatchStep)

                    batchLoss = loss + Loss.normLoss(
                        time_embeddings) + Loss.normLoss(step_embeddings)
                    batchLoss.backward()
                    optimizer.step()
                    # .item() detaches the scalar so the per-batch autograd
                    # graphs are not retained across the epoch.
                    epochLoss += batchLoss.item()

                print("epoch " + str(epoch) + ": , loss: " + str(epochLoss))

                # Every 20 epochs (skipping epoch 0): evaluate on the test
                # split and append the metrics to the log file.
                if epoch % 20 == 0 and epoch != 0:
                    Log.write("epoch " + str(epoch) + ": , loss: " +
                              str(epochLoss) + "\n")
                    meanRank_H, Hits10_H = self.model.Validate_entity_H(
                        validateHead=self.testHead.to(self.device),
                        validateRelation=self.testRelation.to(self.device),
                        validateTail=self.testTail.to(self.device),
                        validateTime=self.testTime.to(self.device),
                        validateStepH=self.testStepH.to(self.device),
                        trainTriple=self.Triples.to(self.device),
                        numOfvalidateTriple=self.numOfTestTriple)
                    print("mean rank H_2_2_.1_nonorm: " + str(meanRank_H))
                    meanRank_T, Hits10_T = self.model.Validate_entity_T(
                        validateHead=self.testHead.to(self.device),
                        validateRelation=self.testRelation.to(self.device),
                        validateTail=self.testTail.to(self.device),
                        validateTime=self.testTime.to(self.device),
                        validateStepT=self.testStepT.to(self.device),
                        trainTriple=self.Triples.to(self.device),
                        numOfvalidateTriple=self.numOfTestTriple)
                    print("mean rank T_2_2_.1_nonorm: " + str(meanRank_T))
                    Log.write("valid H MR: " + str(meanRank_H) + "\n")
                    Log.write("valid T MR: " + str(meanRank_T) + "\n")
                    Log.write("valid entity MR: " +
                              str((meanRank_H + meanRank_T) / 2) + "\n")
                    print("valid entity MR: " + str((meanRank_H + meanRank_T) / 2))
                    Log.write("valid entity H10: " +
                              str((Hits10_H + Hits10_T) / 2) + "\n")
                    print("valid entity H10: " + str((Hits10_H + Hits10_T) / 2))
                    ValidMR_relation = self.model.fastValidate_relation(
                        validateHead=self.testHead.to(self.device),
                        validateRelation=self.testRelation.to(self.device),
                        validateTail=self.testTail.to(self.device),
                        validateTime=self.testTime.to(self.device),
                        validateStepH=self.testStepH.to(self.device),
                        numOfvalidateTriple=self.numOfTestTriple)
                    Log.write("valid relation MR: " + str(ValidMR_relation) + "\n")
                    Log.write("\n")
                    print("valid relation MR: " + str(ValidMR_relation))
Example 3
    def train(self):
        """Train the TKGFrame (Transmit) model with SGD.

        Combines a margin ranking loss on triples, a double-margin loss on
        relation pairs, and norm regularization on the embeddings touched by
        each batch; validates relation ranking after every epoch.
        """
        path = "./dataset/" + self.dataset
        data = readData(path, self.train2id, self.seq_withTime,
                        self.seq_relation, self.headRelation2Tail,
                        self.tailRelation2Head, self.headTail2Relation,
                        self.nums)

        self.Triples = data.out()

        # self.nums is filled in place by readData:
        # [numOfTrainTriple, numOfEntity, numOfRelation]
        self.numOfTrainTriple = self.nums[0]
        self.numOfEntity = self.nums[1]
        self.numOfRelation = self.nums[2]

        self.readValidateTriples(path)
        self.readTestTriples(path)

        self.Transmit = TKGFrame(self.numOfEntity, self.numOfRelation,
                                 self.entityDimension, self.relationDimension,
                                 self.norm)

        self.Transmit.to(self.device)

        Margin_Loss_H = Loss.marginLoss()
        Margin_Loss_D = Loss.double_marginLoss()

        optimizer = optim.SGD(self.Transmit.parameters(), lr=self.learningRate)

        Dataset = dataset(self.numOfTrainTriple)
        batchsize = int(self.numOfTrainTriple / self.numOfBatches)
        dataLoader = DataLoader(Dataset, batchsize, True)

        for epoch in range(self.numOfEpochs):
            epochLoss = 0
            for batch in dataLoader:
                self.positiveBatch = {}
                self.corruptedBatch = {}
                self.relation_pair_batch = {}
                # Fills the three dicts with a positive triple batch, a
                # corrupted (negatively-sampled) batch, and a relation-pair
                # batch for the relation-evolution objective.
                generateBatches(batch, self.train2id, self.seq_relation,
                                self.positiveBatch, self.corruptedBatch,
                                self.relation_pair_batch, self.numOfEntity,
                                self.numOfRelation, self.headRelation2Tail,
                                self.tailRelation2Head, self.headTail2Relation)
                optimizer.zero_grad()
                positiveBatchHead = self.positiveBatch["h"].to(self.device)
                positiveBatchRelation = self.positiveBatch["r"].to(self.device)
                positiveBatchTail = self.positiveBatch["t"].to(self.device)
                corruptedBatchHead = self.corruptedBatch["h"].to(self.device)
                corruptedBatchRelation = self.corruptedBatch["r"].to(
                    self.device)
                corruptedBatchTail = self.corruptedBatch["t"].to(self.device)
                relation_pair_h = self.relation_pair_batch["h"].to(self.device)
                relation_pair_t = self.relation_pair_batch["t"].to(self.device)
                relation_pair_step = self.relation_pair_batch["step"].to(
                    self.device)
                positiveLoss, negativeLoss, positive_relation_pair_Loss, negative_relation_pair_Loss = self.Transmit(
                    positiveBatchHead, positiveBatchRelation,
                    positiveBatchTail, corruptedBatchHead,
                    corruptedBatchRelation, corruptedBatchTail,
                    relation_pair_h, relation_pair_t, relation_pair_step)
                # Triple-ranking margin loss.
                transLoss = Margin_Loss_H(positiveLoss, negativeLoss,
                                          self.margin_triple)

                # Relation-pair double-margin loss.
                relationLoss = Margin_Loss_D(positive_relation_pair_Loss,
                                             negative_relation_pair_Loss,
                                             self.margin_relation)

                # Regularize only the embeddings touched by this batch.
                ent_embeddings = self.Transmit.entity_embeddings(
                    torch.cat([
                        positiveBatchHead, positiveBatchTail,
                        corruptedBatchHead, corruptedBatchTail
                    ]))
                rel_embeddings = self.Transmit.relation_embeddings(
                    torch.cat([
                        positiveBatchRelation, relation_pair_h, relation_pair_t
                    ]))

                normloss = Loss.normLoss(ent_embeddings) + Loss.normLoss(
                    rel_embeddings
                )  #+Loss.F_norm(self.Transmit.relation_trans)#+0.1*Loss.orthogonalLoss(rel_embeddings,self.Transmit.relation_trans)
                # trade_off weights the relation objective against the rest.
                batchLoss = transLoss + normloss + self.trade_off * relationLoss
                batchLoss.backward()
                optimizer.step()
                # .item() detaches the scalar so per-batch autograd graphs
                # are not retained for the whole epoch.
                epochLoss += batchLoss.item()

            print("epoch " + str(epoch) + ": , loss: " + str(epochLoss))

            ValidMR_relation = self.validate_relation()

            print("valid relation MR: " + str(ValidMR_relation))