Example #1
0
File: fede.py  Project: HELL-TO-HEAVEN/FedE
    def __init__(self, args, client_id, data, train_dataloader,
                 valid_dataloader, test_dataloader, rel_embed):
        """Set up one federated client.

        Stores the client's configuration, id, local triples, the three
        data loaders built from them, and the relation embedding table
        passed in by the caller. Entity embeddings start as ``None`` and
        are created later; the local/global score histories start empty.
        """
        # Identity and configuration.
        self.client_id = client_id
        self.args = args

        # Client-local data and its loaders.
        self.data = data
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.test_dataloader = test_dataloader

        # Relation embeddings are provided; entity embeddings are lazy.
        self.rel_embed = rel_embed
        self.ent_embed = None

        # Accumulators for local and global evaluation scores.
        self.score_local = []
        self.score_global = []

        # KGE scoring model chosen via args.model.
        self.kge_model = KGEModel(args, args.model)
Example #2
0
def train(args) -> None:
    """Train a student KGE model with adversarial knowledge transfer.

    Per training step:
      1. GAN phase: a discriminator scores real student entity embeddings
         against (a) teacher embeddings mapped through ``transnet`` and
         (b) generator fakes; discriminator weights are clamped after each
         update (WGAN-style weight clipping).
      2. Learner phase: the student KGE model is trained on its own
         triples plus "transferred" triples whose shared head/tail
         embedding is replaced by the translated teacher embedding, with
         the transfer loss weighted by the discriminator's validity score.

    Periodically evaluates on the valid/test sets and saves checkpoints
    of the best models under ``args.save_path``.
    """
    # ----------------
    # Load Data
    # ----------------
    logging.info("loading data..")

    teacher = KGDataset(dict_path=args.teacher_data_path,
                        model_path=args.teacher_model_path)
    student = KGDataset(data_path=args.student_data_path,
                        dict_path=args.student_data_path)
    # Shared-entity alignment table with "student" and "teacher" id columns;
    # drop rows whose ids fall outside either vocabulary.
    shared_entity = load_shared_entity(args.shared_entity_path)
    shared_entity = shared_entity[
        (shared_entity["student"] < student.get_entity_count())
        & (shared_entity["teacher"] < teacher.get_entity_count())]
    student2teacher = dict(
        zip(shared_entity["student"], shared_entity["teacher"]))

    logging.info("train: valid: test = %d: %d: %d" % (len(
        student.train_set), len(student.valid_set), len(student.test_set)))
    num_entities = student.get_entity_count()
    num_relations = student.get_relation_count()

    logging.info("number of entities: %d" % (num_entities))
    logging.info("number of relations: %d" % (num_relations))

    # training data: one loader per corruption mode, alternated by the
    # bidirectional iterator below.
    train_loader_head = DataLoader(TrainDataset(student.train_set,
                                                num_entities, num_relations,
                                                args.num_neg_samples,
                                                'head-batch'),
                                   batch_size=args.kge_batch,
                                   shuffle=True,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TrainDataset.collate_fn)
    train_loader_tail = DataLoader(TrainDataset(student.train_set,
                                                num_entities, num_relations,
                                                args.num_neg_samples,
                                                'tail-batch'),
                                   batch_size=args.kge_batch,
                                   shuffle=True,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TrainDataset.collate_fn)
    train_iterator = BidirectionalOneShotIterator(train_loader_head,
                                                  train_loader_tail)

    # validation data and test data (all triples are passed for filtered
    # ranking).
    all_student_data = student.train_set + student.valid_set + student.test_set

    valid_loader_head = DataLoader(TestDataset(student.valid_set,
                                               all_student_data, num_entities,
                                               num_relations, 'head-batch'),
                                   batch_size=args.test_batch,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TestDataset.collate_fn)
    valid_loader_tail = DataLoader(TestDataset(student.valid_set,
                                               all_student_data, num_entities,
                                               num_relations, 'tail-batch'),
                                   batch_size=args.test_batch,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TestDataset.collate_fn)
    valid_dataloaders = [valid_loader_head, valid_loader_tail]

    test_loader_head = DataLoader(TestDataset(student.test_set,
                                              all_student_data, num_entities,
                                              num_relations, 'head-batch'),
                                  batch_size=args.test_batch,
                                  num_workers=max(0, args.num_workers // 2),
                                  collate_fn=TestDataset.collate_fn)
    test_loader_tail = DataLoader(TestDataset(student.test_set,
                                              all_student_data, num_entities,
                                              num_relations, 'tail-batch'),
                                  batch_size=args.test_batch,
                                  num_workers=max(0, args.num_workers // 2),
                                  collate_fn=TestDataset.collate_fn)
    test_dataloaders = [test_loader_head, test_loader_tail]

    # ----------------
    # Prepare Data
    # ----------------

    logging.info("preparing data..")

    # writer = SummaryWriter(args.save_path)
    writer = None

    if args.gpu_id == -1:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:%d" % (args.gpu_id))
    # Teacher embeddings are frozen: moved to device with grads disabled.
    teacher.entity_embedding = teacher.entity_embedding.to(
        device).requires_grad_(False)
    teacher.relation_embedding = teacher.relation_embedding.to(
        device).requires_grad_(False)

    # Student model; RotatE/ComplEx double the entity dim, ComplEx also
    # doubles the relation dim.
    learner = KGEModel(model_name=args.kge_model,
                       num_entities=num_entities,
                       num_relations=num_relations,
                       hidden_dim=args.emb_dim,
                       gamma=args.margin,
                       double_entity_embedding=args.kge_model
                       in ["RotatE", "ComplEx"],
                       double_relation_embedding=args.kge_model
                       in ["ComplEx"]).to(device)
    # transnet maps teacher embedding space -> student embedding space.
    transnet = TransNetwork(teacher.get_embedding_dim(),
                            learner.entity_dim).to(device)
    generator = Generator(learner.entity_dim).to(device)
    discriminator = Discriminator(learner.entity_dim,
                                  activation=nn.Sigmoid()).to(device)

    # Three optimizers: generator, discriminator(+transnet), learner(+transnet).
    # Note transnet is updated by both optimizer_d and optimizer_l.
    optimizer_g = optim.Adam(generator.parameters(),
                             lr=args.gan_lr,
                             betas=(0.5, 0.9))
    optimizer_d = optim.Adam(list(transnet.parameters()) +
                             list(discriminator.parameters()),
                             lr=args.gan_lr,
                             betas=(0.5, 0.9))
    optimizer_l = optim.Adam(list(transnet.parameters()) +
                             list(learner.parameters()),
                             lr=args.kge_lr)
    # Cosine schedules with hard restarts; warmup is 1% of total steps.
    scheduler_g = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer_g,
        args.steps // 100,
        args.steps,
        num_cycles=4,
        min_percent=args.gan_lr * 0.001)
    scheduler_d = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer_d,
        args.steps // 100,
        args.steps,
        num_cycles=4,
        min_percent=args.gan_lr * 0.001)
    scheduler_l = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer_l,
        args.steps // 100,
        args.steps,
        num_cycles=4,
        min_percent=args.kge_lr * 0.001)
    #    scheduler_g = get_constant_schedule_with_warmup(optimizer_g, args.steps//100)
    #    scheduler_d = get_constant_schedule_with_warmup(optimizer_d, args.steps//100)
    #    scheduler_l = get_constant_schedule_with_warmup(optimizer_l, args.steps//100)
    # Dummy optimizer steps so the schedulers can be advanced to step 0
    # (initializes the warmup LR) without a "step before optimizer.step()"
    # warning.
    optimizer_g.zero_grad()
    optimizer_g.step()
    scheduler_g.step(0)
    optimizer_d.zero_grad()
    optimizer_d.step()
    scheduler_d.step(0)
    optimizer_l.zero_grad()
    optimizer_l.step()
    scheduler_l.step(0)

    # Losses kept unreduced; reductions are applied explicitly below.
    bce_loss = nn.BCELoss(reduction='none')
    mse_loss = nn.MSELoss(reduction='none')
    cos_loss = nn.CosineEmbeddingLoss(reduction='none')

    logging.info("begin training..")
    transnet.train()
    generator.train()
    discriminator.train()
    learner.train()
    # Iterator over shared-entity alignment pairs; re-created on exhaustion.
    shared_iterator = batch_iter(shared_entity, args.gan_batch)

    training_logs = []
    valid_best_metrics = {}
    test_best_metrics = {}

    for step in range(1, args.steps + 1):
        log = {}

        # Per-step accumulators for the GAN phase losses, plus the entity
        # ids seen this step (reused by the learner-phase distance loss).
        loss_g_selfs = []
        loss_g_dists = []
        loss_gs = []
        loss_d_reals = []
        loss_d_transes = []
        loss_d_fakes = []
        loss_ds = []
        teacher_entities = []
        student_entities = []
        for gan_step in range(args.gan_steps):
            try:
                shared_batch = next(shared_iterator)
            except StopIteration:
                # Restart the pass over the shared-entity table.
                shared_iterator = batch_iter(shared_entity, args.gan_batch)
                shared_batch = next(shared_iterator)
            teacher_entity = torch.LongTensor(
                shared_batch['teacher'].values).to(device)
            student_entity = torch.LongTensor(
                shared_batch['student'].values).to(device)
            teacher_embed = torch.index_select(teacher.entity_embedding,
                                               dim=0,
                                               index=teacher_entity)
            student_embed = torch.index_select(learner.entity_embedding,
                                               dim=0,
                                               index=student_entity)
            gan_batch_size = teacher_entity.size(0)
            # Target labels (+1) for the cosine-embedding generator loss.
            gan_ones = torch.ones((gan_batch_size, 1),
                                  device=device,
                                  requires_grad=False)
            teacher_entities.append(teacher_entity)
            student_entities.append(student_entity)

            # --------------------
            # Train Discriminator
            # --------------------

            # Noise in [-1, 1); student embeddings are detached so the D/G
            # updates do not backprop into the learner.
            random_z = (torch.rand_like(student_embed) * 2 - 1)
            fake_emb = generator(student_embed.detach(), random_z)
            real_output = discriminator(student_embed.detach(),
                                        student_embed.detach())
            trans_output = discriminator(student_embed.detach(),
                                         transnet(teacher_embed))
            fake_output = discriminator(student_embed.detach(),
                                        fake_emb.detach())

            loss_d_real = -torch.mean(real_output)
            loss_d_trans = -torch.mean(trans_output)
            loss_d_fake = torch.mean(fake_output)

            # beta = anneal_fn("cosine", step, args.steps, args.kge_beta, 0)
            # Transfer term currently disabled in the D loss (annealing
            # commented out above).
            beta = 0
            loss_d = loss_d_real + loss_d_fake + beta * loss_d_trans

            optimizer_d.zero_grad()
            loss_d.backward()
            optimizer_d.step()
            scheduler_d.step(step)

            loss_d_reals.append(loss_d_real.item())
            loss_d_transes.append(loss_d_trans.item())
            loss_d_fakes.append(loss_d_fake.item())
            loss_ds.append(loss_d.item())

            # clip weights of discriminator (WGAN-style weight clipping)
            for param in discriminator.parameters():
                param.data.clamp_(-args.gan_clip_value, args.gan_clip_value)

            # gan_n_critic > 0: train G once every n_critic D steps;
            # gan_n_critic <= 0: train G |n_critic| times per D step.
            if args.gan_n_critic > 0:
                if gan_step % args.gan_n_critic == 0:

                    # ----------------
                    # Train Generator
                    # ----------------

                    # random_z = (torch.rand_like(student_embed) * 2 - 1)
                    # fake_emb = generator(student_embed.detach(), random_z)
                    loss_g_self = -torch.mean(
                        discriminator(student_embed.detach(), fake_emb))
                    # Pull fakes toward the real embedding direction.
                    loss_g_dist = cos_loss(student_embed.detach(), fake_emb,
                                           gan_ones).mean()
                    alpha = anneal_fn("cosine", step, args.steps,
                                      args.kge_alpha, 0)
                    loss_g = loss_g_self + alpha * loss_g_dist

                    optimizer_g.zero_grad()
                    loss_g.backward()
                    optimizer_g.step()
                    scheduler_g.step(step)
                    loss_g_selfs.append(loss_g_self.item())
                    loss_g_dists.append(loss_g_dist.item())
                    loss_gs.append(loss_g.item())
            else:
                for gene_step in range(-args.gan_n_critic):

                    # ----------------
                    # Train Generator
                    # ----------------

                    # Fresh noise/fakes per extra generator update (unlike
                    # the branch above, which reuses the D-phase fake_emb).
                    random_z = (torch.rand_like(student_embed) * 2 - 1)
                    fake_emb = generator(student_embed.detach(), random_z)
                    loss_g_self = -torch.mean(
                        discriminator(student_embed.detach(), fake_emb))
                    loss_g_dist = cos_loss(student_embed.detach(), fake_emb,
                                           gan_ones).mean()
                    alpha = anneal_fn("cosine", step, args.steps,
                                      args.kge_alpha, 0)
                    loss_g = loss_g_self + alpha * loss_g_dist

                    optimizer_g.zero_grad()
                    loss_g.backward()
                    optimizer_g.step()
                    scheduler_g.step(step)
                    loss_g_selfs.append(loss_g_self.item())
                    loss_g_dists.append(loss_g_dist.item())
                    loss_gs.append(loss_g.item())

        log["loss_g_self"] = np.mean(loss_g_selfs) if len(loss_gs) > 0 else 0.0
        log["loss_g_dist"] = np.mean(loss_g_dists) if len(loss_gs) > 0 else 0.0
        log["loss_g"] = np.mean(loss_gs) if len(loss_gs) > 0 else 0.0
        log["loss_d_real"] = np.mean(loss_d_reals) if len(loss_ds) > 0 else 0.0
        log["loss_d_trans"] = np.mean(
            loss_d_transes) if len(loss_ds) > 0 else 0.0
        log["loss_d_fake"] = np.mean(loss_d_fakes) if len(loss_ds) > 0 else 0.0
        log["loss_d"] = np.mean(loss_ds) if len(loss_ds) > 0 else 0.0

        # --------------------
        # Train Learner
        # --------------------

        positive_sample, negative_sample, subsampling_weight, mode = next(
            train_iterator)
        # Collect triples whose head (first pass) or tail (second pass) is a
        # shared entity; these become "transferred" samples whose shared-side
        # embedding is replaced by the translated teacher embedding below.
        teacher_entity = []
        student_entity = []
        transfered_sample = []
        transfered_weight = []
        transfered_index = []
        for i, s in enumerate(positive_sample[:, 0].numpy()):
            t = student2teacher.get(s, -1)
            if t != -1:
                teacher_entity.append(t)
                student_entity.append(s)
                transfered_sample.append(positive_sample[i])
                transfered_weight.append(subsampling_weight[i].item())
                transfered_index.append(i)
        # Entries before this index are head-shared; after it, tail-shared.
        num_transfered_head = len(transfered_sample)
        for i, s in enumerate(positive_sample[:, 2].numpy()):
            t = student2teacher.get(s, -1)
            if t != -1:
                teacher_entity.append(t)
                student_entity.append(s)
                transfered_sample.append(positive_sample[i])
                transfered_weight.append(subsampling_weight[i].item())
                transfered_index.append(i)
        teacher_entity = torch.LongTensor(teacher_entity).to(device)
        student_entity = torch.LongTensor(student_entity).to(device)
        positive_sample = positive_sample.to(device)
        negative_sample = negative_sample.to(device)
        transfered_sample = torch.cat(transfered_sample,
                                      dim=0).view(-1, 3).to(device)
        subsampling_weight = subsampling_weight.to(device)
        transfered_weight = torch.tensor(transfered_weight).to(device)
        transfered_index = torch.LongTensor(transfered_index).to(device)
        teacher_entities.append(teacher_entity)
        student_entities.append(student_entity)

        positive_score = learner(positive_sample)
        negative_score = learner((positive_sample, negative_sample), mode=mode)

        if transfered_sample.size(0) > 0:
            transfered_negative_score = torch.index_select(
                negative_score, dim=0, index=transfered_index)
            teacher_embed = torch.index_select(teacher.entity_embedding,
                                               dim=0,
                                               index=teacher_entity)
            student_embed = torch.index_select(learner.entity_embedding,
                                               dim=0,
                                               index=student_entity)

            # Head slot: translated teacher embedding for head-shared rows,
            # the learner's own embedding for tail-shared rows.
            transfered_head = torch.cat([
                transnet(teacher_embed[:num_transfered_head]),
                torch.index_select(
                    learner.entity_embedding,
                    dim=0,
                    index=transfered_sample[num_transfered_head:, 0])
            ],
                                        dim=0).unsqueeze(1)
            transfered_relation = torch.index_select(
                learner.relation_embedding,
                dim=0,
                index=transfered_sample[:, 1]).unsqueeze(1)
            # Tail slot: mirror image of the head construction.
            transfered_tail = torch.cat([
                torch.index_select(
                    learner.entity_embedding,
                    dim=0,
                    index=transfered_sample[:num_transfered_head, 2]),
                transnet(teacher_embed[num_transfered_head:])
            ],
                                        dim=0).unsqueeze(1)

            transfered_score = learner.score(transfered_head,
                                             transfered_relation,
                                             transfered_tail)
            # gan_validity = torch.ones((transfered_weight.size(0), 1), dtype=transfered_weight.dtype, device=transfered_weight.device) # w/o AAM
            # Discriminator output weights how much each transfer is trusted.
            gan_validity = discriminator(student_embed,
                                         transnet(teacher_embed))
            transfered_weight = gan_validity * transfered_weight
        else:
            # NOTE(review): gan_validity is a plain int here, but the
            # TransE/DistMult branch below calls gan_validity.unsqueeze(1) —
            # that would raise AttributeError if this branch is ever taken
            # with those models; confirm whether an empty transfer batch can
            # occur in practice.
            transfered_negative_score = torch.tensor([[0.0]]).to(device)
            transfered_score = torch.tensor([[0.0]]).to(device)
            gan_validity = 1
            transfered_weight = torch.tensor([0.0]).to(device)

        # Model-specific loss formulations.
        if learner.model_name == 'ConvE':
            positive_sample_loss = bce_loss(
                F.logsigmoid(positive_score),
                torch.ones_like(positive_score)).mean()
            negative_sample_loss = bce_loss(
                F.logsigmoid(-negative_score),
                torch.zeros_like(negative_score)).mean()
            transfered_sample_loss = (
                gan_validity *
                bce_loss(F.logsigmoid(transfered_score),
                         torch.ones_like(transfered_score))).mean()
            loss_l_self = positive_sample_loss + negative_sample_loss
            loss_l_trans = transfered_sample_loss
        elif learner.model_name == 'ComplEx' or learner.model_name == 'RotatE':
            positive_sample_loss = -F.logsigmoid(positive_score).mean()
            negative_sample_loss = -F.logsigmoid(-negative_score).mean()
            transfered_sample_loss = -(gan_validity *
                                       F.logsigmoid(transfered_score)).mean()
            loss_l_self = positive_sample_loss + negative_sample_loss
            loss_l_trans = transfered_sample_loss
        elif learner.model_name == 'TransE' or learner.model_name == 'DistMult':
            # Margin-ranking formulation; the *_sample_loss values here are
            # logged only, loss_l_self/loss_l_trans carry the gradient.
            positive_sample_loss = -positive_score.mean()
            negative_sample_loss = negative_score.mean()
            transfered_sample_loss = -transfered_score.mean()
            loss_l_self = F.relu(-positive_score + negative_score +
                                 args.margin).mean()
            loss_l_trans = (F.relu(-transfered_score +
                                   transfered_negative_score + args.margin) *
                            gan_validity.unsqueeze(1)).mean()

        # NOTE(review): this condition is always true — the learner-phase
        # tensors were appended unconditionally above, so the list is
        # non-empty even when no GAN steps ran.
        if len(teacher_entities) > 0:
            teacher_entities = torch.cat(teacher_entities, dim=0)
            student_entities = torch.cat(student_entities, dim=0)
            teacher_embeds = torch.index_select(teacher.entity_embedding,
                                                dim=0,
                                                index=teacher_entities)
            student_embeds = torch.index_select(learner.entity_embedding,
                                                dim=0,
                                                index=student_entities)
            trans_embeds = transnet(teacher_embeds)
            # gan_validity = torch.ones((trans_embeds.size(0), 1), dtype=trans_embeds.dtype, device=trans_embeds.device) # w/o AAM
            gan_validity = discriminator(student_embeds, trans_embeds)
            # Validity-weighted cosine distance between student and
            # translated teacher embeddings of shared entities.
            loss_l_dist = (gan_validity * cos_loss(
                student_embeds, trans_embeds,
                torch.ones(trans_embeds.size(0), device=device))).mean()
        else:
            loss_l_dist = torch.tensor([0.0]).to(device)

        if args.reg != 0.0:
            # Use L3 regularization for ComplEx and DistMult
            # NOTE(review): the relation term applies .norm(p=3) twice
            # (norm of a scalar), unlike the entity term — possibly an
            # unintended double norm; confirm against the original paper/code.
            reg = args.reg * (learner.entity_embedding.norm(
                p=3)**3 + learner.relation_embedding.norm(p=3).norm(p=3)**3)
            loss_l_self = loss_l_self + reg
            reg_log = {'reg': reg.item()}
            log.update(reg_log)
        else:
            reg_log = {}

        # Annealed weights for the distance and transfer terms.
        alpha = anneal_fn("cyclical_cosine", step, args.steps // 2,
                          args.kge_alpha, 0)
        beta = anneal_fn("cyclical_cosine", step, args.steps // 2,
                         args.kge_beta, 0)
        # alpha = args.kge_alpha
        # beta = args.kge_beta
        loss_l = loss_l_self + alpha * loss_l_dist + beta * loss_l_trans

        optimizer_l.zero_grad()
        loss_l.backward()
        optimizer_l.step()
        scheduler_l.step(step)

        log["alpha"] = alpha
        log["beta"] = beta
        log["pos_loss"] = positive_sample_loss.item()
        log["neg_loss"] = negative_sample_loss.item()
        log["loss_l_self"] = loss_l_self.item()
        log["loss_l_dist"] = loss_l_dist.item()
        log["loss_l_transfer"] = transfered_sample_loss.item()
        log["loss_l"] = loss_l.item()

        # ----------------
        # Save logs
        # ----------------
        training_logs.append(log)

        # save summary
        # for k, v in log.items():
        #     writer.add_scalar('loss/%s' % (k), v, step)

        # training average
        if len(training_logs) > 0 and (step % args.print_steps == 0
                                       or step == args.steps):
            logging.info("--------------------------------------")
            metrics = {}
            for metric in training_logs[0].keys():
                metrics[metric] = sum([log[metric] for log in training_logs
                                       ]) / len(training_logs)
            log_metrics('training average', step, metrics)

            # for k, v in metrics.items():
            #     writer.add_scalar('train-metric/%s' % (k.replace("@", "_")), v, step)

        # evaluate model; best checkpoints are selected by MRR + HITS@3.
        if len(training_logs) > 0 and (step % args.eval_steps == 0
                                       or step == args.steps):
            logging.info("--------------------------------------")
            logging.info('evaluating on valid dataset...')
            valid_metrics = link_prediction(args, learner, valid_dataloaders)
            log_metrics('valid', step, valid_metrics)
            if len(valid_best_metrics) == 0 or \
                valid_best_metrics["MRR"] + valid_best_metrics["HITS@3"] < valid_metrics["MRR"] + valid_metrics["HITS@3"]:
                valid_best_metrics = valid_metrics.copy()
                valid_best_metrics["step"] = step
                save_model(
                    {
                        "model": learner,
                        "transnet": transnet,
                        "generator": generator,
                        "discriminator": discriminator,
                        "optimizer_l": optimizer_l,
                        "optimizer_g": optimizer_g,
                        "optimizer_d": optimizer_d,
                        "scheduler_l": scheduler_l,
                        "scheduler_g": scheduler_g,
                        "scheduler_d": scheduler_d,
                        "step": step,
                        "steps": args.steps
                    }, os.path.join(args.save_path, "checkpoint_valid.pt"))

            # for k, v in valid_metrics.items():
            #     writer.add_scalar('valid-metric/%s' % (k.replace("@", "_")), v, step)

            logging.info("--------------------------------------")
            logging.info('evaluating on test dataset...')
            test_metrics = link_prediction(args, learner, test_dataloaders)
            log_metrics('test', step, test_metrics)
            if len(test_best_metrics) == 0 or \
                test_best_metrics["MRR"] + test_best_metrics["HITS@3"] < test_metrics["MRR"] + test_metrics["HITS@3"]:
                test_best_metrics = test_metrics.copy()
                test_best_metrics["step"] = step
                save_model(
                    {
                        "model": learner,
                        "transnet": transnet,
                        "generator": generator,
                        "discriminator": discriminator,
                        "optimizer_l": optimizer_l,
                        "optimizer_g": optimizer_g,
                        "optimizer_d": optimizer_d,
                        "scheduler_l": scheduler_l,
                        "scheduler_g": scheduler_g,
                        "scheduler_d": scheduler_d,
                        "step": step,
                        "steps": args.steps
                    }, os.path.join(args.save_path, "checkpoint_test.pt"))

            # for k, v in test_metrics.items():
            #     writer.add_scalar('test-metric/%s' % (k.replace("@", "_")), v, step)

            # link_prediction put the learner in eval mode; restore training.
            learner.train()

        # save model (periodic numbered checkpoint)
        if len(training_logs) > 0 and (step % args.save_steps == 0
                                       or step == args.steps):
            save_model(
                {
                    "model": learner,
                    "transnet": transnet,
                    "generator": generator,
                    "discriminator": discriminator,
                    "optimizer_l": optimizer_l,
                    "optimizer_g": optimizer_g,
                    "optimizer_d": optimizer_d,
                    "scheduler_l": scheduler_l,
                    "scheduler_g": scheduler_g,
                    "scheduler_d": scheduler_d,
                    "step": step,
                    "steps": args.steps
                }, os.path.join(args.save_path, "checkpoint_%d.pt" % (step)))

        # Reset the running-average window after each print interval.
        if len(training_logs) > 0 and (step % args.print_steps == 0
                                       or step == args.steps):
            training_logs.clear()

    logging.info("--------------------------------------")
    log_metrics('valid-best', valid_best_metrics["step"], valid_best_metrics)
    log_metrics('test-best', test_best_metrics["step"], test_best_metrics)
Example #3
0
def train(args):
    """Train a single KGE model end-to-end with periodic evaluation.

    Loads one KG dataset, builds bidirectional (head-/tail-corruption)
    training iterators plus filtered-ranking valid/test loaders, and
    optimizes a ``KGEModel`` for ``args.steps`` steps.  Best checkpoints
    (selected by MRR + HITS@3 on valid and on test) and periodic
    checkpoints are written under ``args.save_path``; scalars are logged
    to a TensorBoard ``SummaryWriter``.

    Args:
        args: parsed CLI namespace.  Uses at least: data_path, save_path,
            gpu_id, kge_model, emb_dim, margin, reg, kge_lr, kge_batch,
            test_batch, num_workers, num_neg_samples, steps, print_steps,
            eval_steps, save_steps.
    """
    # ----------------
    # Load Data
    # ----------------
    logging.info("loading data..")

    data = KGDataset(data_path=args.data_path, dict_path=args.data_path)

    logging.info(
        "train: valid: test = %d: %d: %d" %
        (len(data.train_set), len(data.valid_set), len(data.test_set)))
    num_entities = data.get_entity_count()
    num_relations = data.get_relation_count()

    logging.info("number of entities: %d" % (num_entities))
    logging.info("number of relations: %d" % (num_relations))

    # training data: one loader corrupts heads, the other corrupts tails;
    # the one-shot iterator below alternates between the two modes.
    train_loader_head = DataLoader(TrainDataset(data.train_set,
                                                data.get_entity_count(),
                                                data.get_relation_count(),
                                                args.num_neg_samples,
                                                'head-batch'),
                                   batch_size=args.kge_batch,
                                   shuffle=True,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TrainDataset.collate_fn)
    train_loader_tail = DataLoader(TrainDataset(data.train_set,
                                                data.get_entity_count(),
                                                data.get_relation_count(),
                                                args.num_neg_samples,
                                                'tail-batch'),
                                   batch_size=args.kge_batch,
                                   shuffle=True,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TrainDataset.collate_fn)
    train_iterator = BidirectionalOneShotIterator(train_loader_head,
                                                  train_loader_tail)

    # validation data and test data; the union of all splits is handed to
    # TestDataset, presumably for the filtered ranking protocol — confirm.
    all_data = data.train_set + data.valid_set + data.test_set

    valid_loader_head = DataLoader(TestDataset(data.valid_set, all_data,
                                               num_entities, num_relations,
                                               'head-batch'),
                                   batch_size=args.test_batch,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TestDataset.collate_fn)
    valid_loader_tail = DataLoader(TestDataset(data.valid_set, all_data,
                                               num_entities, num_relations,
                                               'tail-batch'),
                                   batch_size=args.test_batch,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TestDataset.collate_fn)
    valid_dataloaders = [valid_loader_head, valid_loader_tail]

    test_loader_head = DataLoader(TestDataset(data.test_set, all_data,
                                              num_entities, num_relations,
                                              'head-batch'),
                                  batch_size=args.test_batch,
                                  num_workers=max(0, args.num_workers // 2),
                                  collate_fn=TestDataset.collate_fn)
    test_loader_tail = DataLoader(TestDataset(data.test_set, all_data,
                                              num_entities, num_relations,
                                              'tail-batch'),
                                  batch_size=args.test_batch,
                                  num_workers=max(0, args.num_workers // 2),
                                  collate_fn=TestDataset.collate_fn)
    test_dataloaders = [test_loader_head, test_loader_tail]

    # ----------------
    # Prepare Data
    # ----------------

    logging.info("preparing data..")

    writer = SummaryWriter(args.save_path)

    # gpu_id == -1 selects CPU; any other value picks that CUDA device.
    if args.gpu_id == -1:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:%d" % (args.gpu_id))

    # RotatE/ComplEx store (real, imaginary) parts, hence doubled entity
    # embeddings; ComplEx additionally doubles relation embeddings.
    learner = KGEModel(model_name=args.kge_model,
                       num_entities=num_entities,
                       num_relations=num_relations,
                       hidden_dim=args.emb_dim,
                       gamma=args.margin,
                       double_entity_embedding=args.kge_model
                       in ["RotatE", "ComplEx"],
                       double_relation_embedding=args.kge_model
                       in ["ComplEx"]).to(device)

    # TODO COPY TOP EMBEDDING
    # top_student_entity = pickle.load(open('../intermediate/top_one_entity_transe_200.pkl', 'rb'))
    # top_shared_entity = shared_entity[shared_entity['student'].isin(top_student_entity)]
    # shared_entity = shared_entity[(True ^ shared_entity['student'].isin(top_student_entity))]
    # for t, s in np.array(top_shared_entity):
    # for t, s in np.array(shared_entity):
    #     learner.entity_embedding.data[s].copy_(teacher.entity_embedding.weight[t])

    # TODO DELETE ONCE ENTITY
    # once_entity = pickle.load(open('top_one_entity_transe_200.pkl', 'rb'))
    # shared_entity = shared_entity[(True ^ shared_entity['student'].isin(once_entity))]

    optimizer = optim.Adam(learner.parameters(), lr=args.kge_lr)
    # scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    #     optimizer, args.steps//100, args.steps, num_cycles=4, min_percent=args.kge_lr*0.001)
    # Warm up the LR over the first 1% of total steps.
    scheduler = get_constant_schedule_with_warmup(optimizer, args.steps // 100)
    # Dummy optimizer step before the first scheduler step — presumably to
    # avoid PyTorch's "scheduler stepped before optimizer" warning; confirm.
    optimizer.zero_grad()
    optimizer.step()
    # NOTE(review): passing an epoch/step argument to scheduler.step() is
    # deprecated in recent PyTorch releases.
    scheduler.step(0)

    bce_loss = nn.BCELoss(reduction='none')

    logging.info("begin training..")
    learner.train()

    training_logs = []
    valid_best_metrics = {}
    test_best_metrics = {}

    for step in range(1, args.steps + 1):
        # ----------------
        # Configure Input
        # ----------------

        positive_sample, negative_sample, subsampling_weight, mode = next(
            train_iterator)
        positive_sample = positive_sample.to(device)
        negative_sample = negative_sample.to(device)
        # NOTE(review): subsampling_weight is moved to device but never used
        # in any loss below — confirm whether weighting was intended.
        subsampling_weight = subsampling_weight.to(device)

        log = {}

        # --------------------
        # Train Learner
        # --------------------

        optimizer.zero_grad()

        positive_score = learner(positive_sample)
        negative_score = learner((positive_sample, negative_sample), mode=mode)

        # Per-model losses.  NOTE(review): if args.kge_model is outside the
        # models handled below, `loss` stays unbound and .backward() raises
        # NameError — consider an explicit error for unknown models.
        if learner.model_name == 'ConvE':
            # NOTE(review): BCELoss expects probabilities in [0, 1] but
            # F.logsigmoid yields values in (-inf, 0] — confirm this pairing
            # is intentional.
            positive_sample_loss = bce_loss(
                F.logsigmoid(positive_score),
                torch.ones_like(positive_score)).mean()
            negative_sample_loss = bce_loss(
                F.logsigmoid(-negative_score),
                torch.zeros_like(negative_score)).mean()
            loss = positive_sample_loss + negative_sample_loss
        elif learner.model_name == 'ComplEx' or learner.model_name == 'RotatE':
            # Negative log-likelihood of logistic loss on both sample sets.
            positive_sample_loss = -F.logsigmoid(positive_score).mean()
            negative_sample_loss = -F.logsigmoid(-negative_score).mean()
            loss = positive_sample_loss + negative_sample_loss
        elif learner.model_name == 'TransE' or learner.model_name == 'DistMult':
            # Margin ranking loss; the two per-sample losses here are only
            # computed for logging, not used in `loss`.
            positive_sample_loss = -positive_score.mean()
            negative_sample_loss = negative_score.mean()
            loss = F.relu(-positive_score + negative_score +
                          args.margin).mean()

        if args.reg != 0.0:
            # Use L3 regularization for ComplEx and DistMult
            # NOTE(review): the relation term applies .norm(p=3) twice; the
            # second norm of a scalar is a no-op, so this is redundant —
            # likely meant to mirror the entity term exactly.
            reg = args.reg * (learner.entity_embedding.norm(
                p=3)**3 + learner.relation_embedding.norm(p=3).norm(p=3)**3)
            loss = loss + reg
            reg_log = {'reg': reg.item()}
            log.update(reg_log)
        else:
            # reg_log kept for symmetry with the branch above; unused here.
            reg_log = {}

        loss.backward()
        optimizer.step()
        # NOTE(review): deprecated step(epoch) form, see scheduler.step(0) above.
        scheduler.step(step)

        log["pos_loss"] = positive_sample_loss.item()
        log["neg_loss"] = negative_sample_loss.item()
        log["loss"] = loss.item()

        # ----------------
        # Save logs
        # ----------------
        training_logs.append(log)

        # save summary
        for k, v in log.items():
            writer.add_scalar('loss/%s' % (k), v, step)

        # training average over the window since the last print/clear.
        if len(training_logs) > 0 and (step % args.print_steps == 0
                                       or step == args.steps):
            logging.info("--------------------------------------")
            metrics = {}
            for metric in training_logs[0].keys():
                metrics[metric] = sum([log[metric] for log in training_logs
                                       ]) / len(training_logs)
            log_metrics('training average', step, metrics)

            for k, v in metrics.items():
                writer.add_scalar('train-metric/%s' % (k), v, step)

        # evaluate model on valid and test; track the best of each by
        # MRR + HITS@3 and checkpoint it.
        if len(training_logs) > 0 and (step % args.eval_steps == 0
                                       or step == args.steps):
            logging.info("--------------------------------------")
            logging.info('evaluating on valid dataset...')
            valid_metrics = link_prediction(args, learner, valid_dataloaders)
            log_metrics('valid', step, valid_metrics)
            if len(valid_best_metrics) == 0 or \
                valid_best_metrics["MRR"] + valid_best_metrics["HITS@3"] < valid_metrics["MRR"] + valid_metrics["HITS@3"]:
                valid_best_metrics = valid_metrics.copy()
                valid_best_metrics["step"] = step
                # NOTE(review): this checkpoint stores the model under the key
                # "learner" while the test/periodic checkpoints below use
                # "model" — loaders must handle both; consider unifying.
                save_model(
                    {
                        "learner": learner,
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "step": step,
                        "steps": args.steps
                    }, os.path.join(args.save_path, "checkpoint_valid.pt"))

            for k, v in valid_metrics.items():
                writer.add_scalar('valid-metric/%s' % (k.replace("@", "_")), v,
                                  step)

            logging.info("--------------------------------------")
            logging.info('evaluating on test dataset...')
            test_metrics = link_prediction(args, learner, test_dataloaders)
            log_metrics('test', step, test_metrics)
            if len(test_best_metrics) == 0 or \
                test_best_metrics["MRR"] + test_best_metrics["HITS@3"] < test_metrics["MRR"] + test_metrics["HITS@3"]:
                test_best_metrics = test_metrics.copy()
                test_best_metrics["step"] = step
                save_model(
                    {
                        "model": learner,
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "step": step,
                        "steps": args.steps
                    }, os.path.join(args.save_path, "checkpoint_test.pt"))

            for k, v in test_metrics.items():
                writer.add_scalar('test-metric/%s' % (k.replace("@", "_")), v,
                                  step)

            # link_prediction presumably switches to eval mode; restore train.
            learner.train()

        # save periodic checkpoint
        if len(training_logs) > 0 and (step % args.save_steps == 0
                                       or step == args.steps):
            save_model(
                {
                    "model": learner,
                    "optimizer": optimizer,
                    "scheduler": scheduler,
                    "step": step,
                    "steps": args.steps
                }, os.path.join(args.save_path, "checkpoint_%d.pt" % (step)))

        # reset the averaging window after each print interval.
        if len(training_logs) > 0 and (step % args.print_steps == 0
                                       or step == args.steps):
            training_logs.clear()

    logging.info("--------------------------------------")
    # NOTE(review): if the eval branch never ran (args.steps < args.eval_steps
    # and no final-step eval), these dicts are empty and "step" raises KeyError.
    log_metrics('valid-best', valid_best_metrics["step"], valid_best_metrics)
    log_metrics('test-best', test_best_metrics["step"], test_best_metrics)
예제 #4
0
파일: main.py 프로젝트: HELL-TO-HEAVEN/FedE
def test_pretrain(args, all_data):
    """Evaluate saved per-client embeddings on each client's test split.

    For every client dataset in ``all_data``, loads that client's saved
    entity/relation embeddings from disk, scores every tail-prediction test
    triple with a stateless ``KGEModel`` scorer, and logs per-client as well
    as micro-averaged overall ranking metrics (MR, MRR, hits@{1,5,10}).

    Args:
        args: configuration namespace; uses at least ``model``, ``gpu`` and
            ``test_batch_size`` (plus whatever ``KGEModel`` /
            ``get_task_dataset`` require).
        all_data: list of per-client raw datasets, one entry per client.

    Returns:
        ``defaultdict(float)`` of metrics accumulated over all clients,
        normalized by the total triple count (except the ``count`` key).
    """
    # Stateless scorer: the embeddings are passed in per client at call time.
    kge_model = KGEModel(args, model_name=args.model)

    results = ddict(float)
    for i, data in enumerate(all_data):
        one_results = ddict(float)
        # NOTE(review): hard-coded checkpoint location; consider promoting
        # '../LTLE/fed_state/...' to a CLI argument.
        state = torch.load('../LTLE/fed_state/fb15k237_fed10_client_{}.best'.format(i), map_location=args.gpu)
        rel_embed = state['rel_emb'].detach()
        ent_embed = state['ent_emb'].detach()

        train_dataset, valid_dataset, test_dataset, nrelation, nentity = get_task_dataset(data, args)
        test_dataloader_tail = DataLoader(
            test_dataset,
            batch_size=args.test_batch_size,
            # num_workers=max(1, args.num_cpu),
            collate_fn=TestDataset.collate_fn
        )

        for batch in test_dataloader_tail:
            triplets, labels, mode = batch
            triplets, labels = triplets.to(args.gpu), labels.to(args.gpu)
            tail_idx = triplets[:, 2]
            pred = kge_model((triplets, None),
                             rel_embed,
                             ent_embed,
                             mode=mode)
            b_range = torch.arange(pred.size()[0], device=args.gpu)
            # Filtered setting: mask every other known-true tail with a large
            # negative score, but keep the target tail's own score.
            target_pred = pred[b_range, tail_idx]
            pred = torch.where(labels.byte(), -torch.ones_like(pred) * 10000000, pred)
            pred[b_range, tail_idx] = target_pred

            # rank = 1 + number of entities scored strictly above the target.
            ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True),
                                      dim=1, descending=False)[b_range, tail_idx]

            ranks = ranks.float()
            count = torch.numel(ranks)

            # Accumulate into both the per-client and the global tallies.
            for res in (results, one_results):
                res['count'] += count
                res['mr'] += torch.sum(ranks).item()
                res['mrr'] += torch.sum(1.0 / ranks).item()
                for k in [1, 5, 10]:
                    res['hits@{}'.format(k)] += torch.numel(ranks[ranks <= k])

        # Per-client averages.
        for k, v in one_results.items():
            if k != 'count':
                one_results[k] = v / one_results['count']

        logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
            one_results['mrr'], one_results['hits@1'],
            one_results['hits@5'], one_results['hits@10']))

    # Micro-average over every test triple of every client.
    for k, v in results.items():
        if k != 'count':
            results[k] = v / results['count']

    logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
        results['mrr'], results['hits@1'],
        results['hits@5'], results['hits@10']))

    return results
예제 #5
0
파일: main.py 프로젝트: HELL-TO-HEAVEN/FedE
    def __init__(self, args, data):
        """Set up task datasets, trainable embeddings, dataloaders, optimizer.

        The 'Entire' run mode selects the entire-graph dataset builder and its
        matching eval collate function; every other mode uses the per-client
        builder and the standard ``TestDataset`` collate function.
        """
        self.args = args
        self.data = data

        build_datasets = (get_task_dataset_entire
                          if args.run_mode == 'Entire' else get_task_dataset)
        train_dataset, valid_dataset, test_dataset, nrelation, nentity = build_datasets(data, args)

        self.nentity = nentity
        self.nrelation = nrelation

        # RotatE/ComplEx keep (real, imaginary) parts, doubling the entity
        # dimension; ComplEx also doubles the relation dimension.
        embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
        ent_dim = args.hidden_dim * 2 if args.model in ['RotatE', 'ComplEx'] else args.hidden_dim
        rel_dim = args.hidden_dim * 2 if args.model in ['ComplEx'] else args.hidden_dim

        self.entity_embedding = torch.zeros(self.nentity, ent_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(
            tensor=self.entity_embedding,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )
        self.relation_embedding = torch.zeros(self.nrelation, rel_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(
            tensor=self.relation_embedding,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )

        # dataloaders
        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=TrainDataset.collate_fn
        )

        eval_collate_fn = (TestDataset_Entire.collate_fn
                           if args.run_mode == 'Entire' else TestDataset.collate_fn)
        self.valid_dataloader = DataLoader(
            valid_dataset,
            batch_size=args.test_batch_size,
            collate_fn=eval_collate_fn
        )
        self.test_dataloader = DataLoader(
            test_dataset,
            batch_size=args.test_batch_size,
            collate_fn=eval_collate_fn
        )

        # Stateless scoring model; the embedding tensors above are what the
        # optimizer actually updates.
        self.kge_model = KGEModel(args, args.model)

        self.optimizer = torch.optim.Adam(
            [{'params': self.entity_embedding},
             {'params': self.relation_embedding}], lr=args.lr
        )
예제 #6
0
파일: main.py 프로젝트: HELL-TO-HEAVEN/FedE
class KGERunner():
    """Single-process KGE trainer.

    Owns the trainable entity/relation embedding tensors (``KGEModel`` is
    a stateless scorer that receives them at call time), the
    train/valid/test dataloaders, and the optimizer; drives the epoch loop
    with early stopping and best-checkpoint bookkeeping in
    ``args.state_dir``.

    BUG FIX applied throughout: several methods previously referenced the
    bare module-level name ``args`` instead of ``self.args``, which raises
    NameError whenever no such global exists; all uses now go through
    ``self.args``.
    """

    def __init__(self, args, data):
        """Build task datasets, embeddings, dataloaders and the optimizer.

        Args:
            args: run configuration (run_mode, model, hidden_dim, gamma,
                epsilon, gpu, batch sizes, lr, logging/eval cadence, ...).
            data: raw dataset forwarded to the task-dataset builders.
        """
        self.args = args
        self.data = data

        if args.run_mode == 'Entire':
            train_dataset, valid_dataset, test_dataset, nrelation, nentity = get_task_dataset_entire(data, args)
        else:
            train_dataset, valid_dataset, test_dataset, nrelation, nentity = get_task_dataset(data, args)

        self.nentity = nentity
        self.nrelation = nrelation

        # Embeddings: RotatE/ComplEx double the entity dimension (real +
        # imaginary parts); ComplEx also doubles the relation dimension.
        embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
        if args.model in ['RotatE', 'ComplEx']:
            self.entity_embedding = torch.zeros(self.nentity, args.hidden_dim * 2).to(args.gpu).requires_grad_()
        else:
            self.entity_embedding = torch.zeros(self.nentity, args.hidden_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(
            tensor=self.entity_embedding,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )
        if args.model in ['ComplEx']:
            self.relation_embedding = torch.zeros(self.nrelation, args.hidden_dim * 2).to(args.gpu).requires_grad_()
        else:
            self.relation_embedding = torch.zeros(self.nrelation, args.hidden_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(
            tensor=self.relation_embedding,
            a=-embedding_range.item(),
            b=embedding_range.item()
        )

        # dataloaders
        self.train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            collate_fn=TrainDataset.collate_fn
        )

        if args.run_mode == 'Entire':
            self.valid_dataloader = DataLoader(
                valid_dataset,
                batch_size=args.test_batch_size,
                collate_fn=TestDataset_Entire.collate_fn
            )
            self.test_dataloader = DataLoader(
                test_dataset,
                batch_size=args.test_batch_size,
                collate_fn=TestDataset_Entire.collate_fn
            )
        else:
            self.valid_dataloader = DataLoader(
                valid_dataset,
                batch_size=args.test_batch_size,
                collate_fn=TestDataset.collate_fn
            )
            self.test_dataloader = DataLoader(
                test_dataset,
                batch_size=args.test_batch_size,
                collate_fn=TestDataset.collate_fn
            )

        # Stateless scoring model; the tensors above are optimized directly.
        self.kge_model = KGEModel(args, args.model)

        self.optimizer = torch.optim.Adam(
            [{'params': self.entity_embedding},
             {'params': self.relation_embedding}], lr=args.lr
        )

    def before_test_load(self):
        """Swap in the best saved embeddings before final test evaluation."""
        state = torch.load(os.path.join(self.args.state_dir, self.args.name + '.best'),
                           map_location=self.args.gpu)
        self.relation_embedding = state['rel_emb']
        self.entity_embedding = state['ent_emb']

    def write_training_loss(self, loss, e):
        """Log the mean training loss for epoch ``e`` to TensorBoard."""
        self.args.writer.add_scalar("training/loss", loss, e)

    def write_evaluation_result(self, results, e):
        """Log evaluation metrics for epoch ``e`` to TensorBoard."""
        self.args.writer.add_scalar("evaluation/mrr", results['mrr'], e)
        self.args.writer.add_scalar("evaluation/hits10", results['hits@10'], e)
        self.args.writer.add_scalar("evaluation/hits5", results['hits@5'], e)
        self.args.writer.add_scalar("evaluation/hits1", results['hits@1'], e)

    def save_checkpoint(self, e):
        """Persist current embeddings for epoch ``e``, pruning older files.

        Any file in ``state_dir`` whose dot-separated name contains this
        run's name is deleted first, so at most one checkpoint per run is
        kept on disk.
        """
        state = {'rel_emb': self.relation_embedding,
                 'ent_emb': self.entity_embedding}
        # delete previous checkpoint
        for filename in os.listdir(self.args.state_dir):
            if self.args.name in filename.split('.') and os.path.isfile(os.path.join(self.args.state_dir, filename)):
                os.remove(os.path.join(self.args.state_dir, filename))
        # save current checkpoint
        torch.save(state, os.path.join(self.args.state_dir,
                                       self.args.name + '.' + str(e) + '.ckpt'))

    def save_model(self, best_epoch):
        """Rename the ``best_epoch`` checkpoint to the canonical '.best' file."""
        os.rename(os.path.join(self.args.state_dir, self.args.name + '.' + str(best_epoch) + '.ckpt'),
                  os.path.join(self.args.state_dir, self.args.name + '.best'))

    def train(self):
        """Run the epoch loop with periodic evaluation and early stopping.

        Uses self-adversarial negative sampling (RotatE-style): negative
        scores are weighted by their own softmax, detached so the weights
        receive no gradient.  The best model (by validation MRR) is
        checkpointed; training stops early after ``early_stop_patience``
        non-improving evaluations, and the best model is finally evaluated
        on the test split.
        """
        best_epoch = 0
        best_mrr = 0
        bad_count = 0

        for epoch in range(self.args.max_epoch):
            losses = []
            self.kge_model.train()
            for batch in self.train_dataloader:
                positive_sample, negative_sample, subsampling_weight = batch

                # BUG FIX: was bare `args.gpu` (module global), now self.args.
                positive_sample = positive_sample.to(self.args.gpu)
                negative_sample = negative_sample.to(self.args.gpu)
                subsampling_weight = subsampling_weight.to(self.args.gpu)

                negative_score = self.kge_model((positive_sample, negative_sample),
                                                self.relation_embedding,
                                                self.entity_embedding)

                # In self-adversarial sampling, we do not apply back-propagation on the sampling weight
                negative_score = (F.softmax(negative_score * self.args.adversarial_temperature, dim=1).detach()
                                  * F.logsigmoid(-negative_score)).sum(dim=1)

                positive_score = self.kge_model(positive_sample,
                                                self.relation_embedding, self.entity_embedding, neg=False)

                positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

                positive_sample_loss = - positive_score.mean()
                negative_sample_loss = - negative_score.mean()

                loss = (positive_sample_loss + negative_sample_loss) / 2

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                losses.append(loss.item())

            if epoch % self.args.log_per_epoch == 0:
                logging.info('epoch: {} | loss: {:.4f}'.format(epoch, np.mean(losses)))
                self.write_training_loss(np.mean(losses), epoch)

            if epoch % self.args.check_per_epoch == 0:
                if self.args.run_mode == 'Entire':
                    eval_res = self.evaluate_multi()
                else:
                    eval_res = self.evaluate()
                self.write_evaluation_result(eval_res, epoch)

                if eval_res['mrr'] > best_mrr:
                    best_mrr = eval_res['mrr']
                    best_epoch = epoch
                    logging.info('best model | mrr {:.4f}'.format(best_mrr))
                    self.save_checkpoint(epoch)
                    bad_count = 0
                else:
                    bad_count += 1
                    logging.info('best model is at round {0}, mrr {1:.4f}, bad count {2}'.format(
                        best_epoch, best_mrr, bad_count))

            if bad_count >= self.args.early_stop_patience:
                logging.info('early stop at round {}'.format(epoch))
                break

        logging.info('finish training')
        logging.info('save best model')
        self.save_model(best_epoch)

        logging.info('eval on test set')
        self.before_test_load()
        # NOTE(review): the periodic check above branches on 'Entire' while
        # this final test branches on 'multi_client_train' — confirm whether
        # the two conditions should match.
        if self.args.run_mode == 'multi_client_train':
            eval_res = self.evaluate_multi(eval_split='test')
        else:
            eval_res = self.evaluate(eval_split='test')

    def evaluate_multi(self, eval_split='valid'):
        """Filtered tail-ranking evaluation, reported per client and overall.

        Batches carry a ``triple_idx`` identifying the owning client; ranks
        are bucketed by client, each client's metrics are logged, and the
        pooled metrics over all clients are returned.
        """
        if eval_split == 'test':
            dataloader = self.test_dataloader
        elif eval_split == 'valid':
            dataloader = self.valid_dataloader

        client_ranks = ddict(list)
        all_ranks = []
        for batch in dataloader:
            triplets, labels, triple_idx = batch
            # BUG FIX: was bare `args.gpu`, now self.args.gpu.
            triplets, labels = triplets.to(self.args.gpu), labels.to(self.args.gpu)
            tail_idx = triplets[:, 2]
            pred = self.kge_model((triplets, None),
                                  self.relation_embedding,
                                  self.entity_embedding)
            b_range = torch.arange(pred.size()[0], device=self.args.gpu)
            # Filtered setting: mask all other known-true tails, keep target.
            target_pred = pred[b_range, tail_idx]
            pred = torch.where(labels.byte(), -torch.ones_like(pred) * 10000000, pred)
            pred[b_range, tail_idx] = target_pred

            # rank = 1 + number of entities scored strictly above the target.
            ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True),
                                      dim=1, descending=False)[b_range, tail_idx]

            ranks = ranks.float()

            for i in range(self.args.num_multi):
                client_ranks[i].extend(ranks[triple_idx == i].tolist())

            all_ranks.extend(ranks.tolist())

        # Per-client metrics (logged only).
        for i in range(self.args.num_multi):
            results = ddict(float)
            ranks = torch.tensor(client_ranks[i])
            count = torch.numel(ranks)
            results['count'] = count
            results['mr'] = torch.sum(ranks).item() / count
            results['mrr'] = torch.sum(1.0 / ranks).item() / count
            for k in [1, 5, 10]:
                results['hits@{}'.format(k)] = torch.numel(ranks[ranks <= k]) / count
            logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
                results['mrr'], results['hits@1'],
                results['hits@5'], results['hits@10']))

        # Pooled metrics over all clients (returned).
        results = ddict(float)
        ranks = torch.tensor(all_ranks)
        count = torch.numel(ranks)
        results['count'] = count
        results['mr'] = torch.sum(ranks).item() / count
        results['mrr'] = torch.sum(1.0 / ranks).item() / count
        for k in [1, 5, 10]:
            results['hits@{}'.format(k)] = torch.numel(ranks[ranks <= k]) / count
        logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
            results['mrr'], results['hits@1'],
            results['hits@5'], results['hits@10']))

        return results

    def evaluate(self, eval_split='valid'):
        """Filtered tail-ranking evaluation on one split.

        Also dumps the top-10 predictions and per-triple ranks as tensors in
        ``state_dir`` and a pickled [triple, rank] list in ``log_dir``.
        """
        results = ddict(float)

        if eval_split == 'test':
            dataloader = self.test_dataloader
        elif eval_split == 'valid':
            dataloader = self.valid_dataloader

        pred_list = []
        rank_list = []
        results_list = []
        for batch in dataloader:
            triplets, labels = batch
            # BUG FIX: was bare `args.gpu`, now self.args.gpu.
            triplets, labels = triplets.to(self.args.gpu), labels.to(self.args.gpu)
            tail_idx = triplets[:, 2]
            pred = self.kge_model((triplets, None),
                                  self.relation_embedding,
                                  self.entity_embedding)
            b_range = torch.arange(pred.size()[0], device=self.args.gpu)
            # Filtered setting: mask all other known-true tails, keep target.
            target_pred = pred[b_range, tail_idx]
            pred = torch.where(labels.byte(), -torch.ones_like(pred) * 10000000, pred)
            pred[b_range, tail_idx] = target_pred

            pred_argsort = torch.argsort(pred, dim=1, descending=True)
            ranks = 1 + torch.argsort(pred_argsort, dim=1, descending=False)[b_range, tail_idx]

            pred_list.append(pred_argsort[:, :10])
            rank_list.append(ranks)

            ranks = ranks.float()

            for idx, tri in enumerate(triplets):
                results_list.append([tri.tolist(), ranks[idx].item()])

            count = torch.numel(ranks)
            results['count'] += count
            results['mr'] += torch.sum(ranks).item()
            results['mrr'] += torch.sum(1.0 / ranks).item()

            for k in [1, 5, 10]:
                results['hits@{}'.format(k)] += torch.numel(ranks[ranks <= k])

        # BUG FIX: was bare `args.*` (module global), now self.args.
        torch.save(torch.cat(pred_list, dim=0),
                   os.path.join(self.args.state_dir,
                                self.args.name + '_' + str(self.args.one_client_idx) + '.pred'))
        torch.save(torch.cat(rank_list),
                   os.path.join(self.args.state_dir,
                                self.args.name + '_' + str(self.args.one_client_idx) + '.rank'))

        for k, v in results.items():
            if k != 'count':
                results[k] /= results['count']

        logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
            results['mrr'], results['hits@1'],
            results['hits@5'], results['hits@10']))

        # BUG FIX: close the dump file deterministically (was an unclosed
        # open() passed straight to pickle.dump).
        test_rst_file = os.path.join(self.args.log_dir, self.args.name + '.test.rst')
        with open(test_rst_file, 'wb') as f:
            pickle.dump(results_list, f)

        return results
예제 #7
0
def train_eval(args, train_triples, valid_triples, test_triples):
    """Train a KGE model on `train_triples` and evaluate on valid/test splits.

    Builds a KGEModel from `args`, optionally restores it from
    `args.init_checkpoint`, runs the negative-sampling training loop for
    `args.max_steps` steps (with staged learning-rate decay at warm-up
    boundaries), and finally evaluates on the requested splits.

    Args:
        args: namespace with model/training hyperparameters (model, nentity,
            nrelation, hidden_dim, gamma, learning_rate, batch_size,
            max_steps, do_train/do_valid/do_test flags, ...).
        train_triples / valid_triples / test_triples: lists of (h, r, t)
            triples for each split.

    Side effects: writes checkpoints via `save_model` and emits progress via
    `logging` / `log_metrics`.
    """
    # Filtered ranking needs the union of all known true triples.
    all_true_triples = train_triples + valid_triples + test_triples
    set_seed(args.seed_num)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    model = KGEModel(model_name=args.model,
                     nentity=args.nentity,
                     nrelation=args.nrelation,
                     hidden_dim=args.hidden_dim,
                     gamma=args.gamma,
                     double_entity_embedding=args.double_entity_embedding,
                     double_relation_embedding=args.double_relation_embedding)

    model.to(device)
    if n_gpu > 1:
        # NOTE(review): later code reads `model.entity_embedding` /
        # `model.test_step` directly; under DataParallel those live on
        # `model.module` — confirm the multi-GPU path is actually exercised.
        model = torch.nn.DataParallel(model)

    # Single optimizer instance; also used when restoring optimizer state
    # from a checkpoint below.
    current_learning_rate = args.learning_rate
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 lr=current_learning_rate)

    if args.do_train:
        # Two dataloaders corrupt heads and tails respectively; the
        # bidirectional iterator alternates between them.
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, args.nentity, args.nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, args.nentity, args.nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # First learning-rate decay happens at `warm_up_steps`.
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model (and, when training, optimizer/schedule state)
        # from the checkpoint directory.
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    # Learning rate is a float; %d would truncate e.g. 0.0001 to 0.
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    if args.do_train:
        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):

            model.train()

            optimizer.zero_grad()

            positive_sample, negative_sample, subsampling_weight, mode = next(
                train_iterator)

            positive_sample = positive_sample.to(device)
            negative_sample = negative_sample.to(device)
            subsampling_weight = subsampling_weight.to(device)

            negative_score = model((positive_sample, negative_sample),
                                   mode=mode)

            if args.negative_adversarial_sampling:
                # In self-adversarial sampling, we do not apply
                # back-propagation on the sampling weight (hence .detach()).
                negative_score = (
                    F.softmax(negative_score * args.adversarial_temperature,
                              dim=1).detach() *
                    F.logsigmoid(-negative_score)).sum(dim=1)
            else:
                negative_score = F.logsigmoid(-negative_score).mean(dim=1)

            positive_score = model(positive_sample)

            positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

            if args.uni_weight:
                positive_sample_loss = -positive_score.mean()
                negative_sample_loss = -negative_score.mean()
            else:
                # Frequency-based subsampling weights (word2vec style).
                positive_sample_loss = -(subsampling_weight * positive_score
                                         ).sum() / subsampling_weight.sum()
                negative_sample_loss = -(subsampling_weight * negative_score
                                         ).sum() / subsampling_weight.sum()

            loss = (positive_sample_loss + negative_sample_loss) / 2

            if args.regularization != 0.0:
                # Use L3 regularization for ComplEx and DistMult.
                # Fix: the relation term previously applied .norm(p=3) twice,
                # cubing the norm of a scalar instead of the embedding norm.
                regularization = args.regularization * (
                    model.entity_embedding.norm(p=3)**3 +
                    model.relation_embedding.norm(p=3)**3)
                loss = loss + regularization
                regularization_log = {'regularization': regularization.item()}
            else:
                regularization_log = {}

            loss.backward()

            optimizer.step()

            log = {
                **regularization_log, 'positive_sample_loss':
                positive_sample_loss.item(),
                'negative_sample_loss':
                negative_sample_loss.item(),
                'loss':
                loss.item()
            }

            training_logs.append(log)

            if step >= warm_up_steps:
                # Staged decay: divide lr by 10, rebuild the optimizer so
                # Adam moments restart, and push the next boundary out 3x.
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                    model.parameters()),
                                             lr=current_learning_rate)
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                # Fix: call the model's test_step (the bare name `test_step`
                # is undefined here), matching the final-eval calls below.
                metrics = model.test_step(model, valid_triples,
                                          all_true_triples, args)
                log_metrics('Valid', step, metrics)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = model.test_step(model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = model.test_step(model, test_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = model.test_step(model, train_triples, all_true_triples, args)
        # Fix: these are train-set metrics; they were mislabeled 'Test'.
        log_metrics('Train', step, metrics)