def __init__(self, args, client_id, data, train_dataloader, valid_dataloader, test_dataloader, rel_embed):
    """Hold one client's training state: config, data loaders, embeddings, model.

    Args:
        args: global run configuration (also consumed by KGEModel).
        client_id: integer identity of this client.
        data: this client's raw dataset.
        train_dataloader / valid_dataloader / test_dataloader: split loaders.
        rel_embed: this client's relation embedding tensor.
    """
    # Identity and configuration.
    self.args = args
    self.client_id = client_id
    # Local data and its split loaders.
    self.data = data
    self.train_dataloader = train_dataloader
    self.valid_dataloader = valid_dataloader
    self.test_dataloader = test_dataloader
    # Relation embeddings arrive pre-built; entity embeddings are filled in later.
    self.rel_embed = rel_embed
    self.ent_embed = None
    # Evaluation score histories (local vs. global model).
    self.score_local = []
    self.score_global = []
    # Scoring model configured from the shared args.
    self.kge_model = KGEModel(args, args.model)
def train(args):
    """Train a student KGE model via GAN-assisted distillation from a teacher.

    Per training step:
      1. WGAN-style discriminator/generator updates on embeddings of entities
         shared between teacher and student; TransNetwork maps frozen teacher
         embeddings into the student embedding space.
      2. Learner (student KGE) update combining the usual positive/negative
         triple loss, a transferred-triple loss (teacher embeddings substituted
         for shared heads/tails), and a cosine-distance loss — the latter two
         gated by the discriminator's validity score.
    Logs training averages, evaluates link prediction on valid/test at
    intervals, and checkpoints best/periodic models under args.save_path.
    """
    # ----------------
    # Load Data
    # ----------------
    logging.info("loading data..")
    teacher = KGDataset(dict_path=args.teacher_data_path,
                        model_path=args.teacher_model_path)
    student = KGDataset(data_path=args.student_data_path,
                        dict_path=args.student_data_path)
    shared_entity = load_shared_entity(args.shared_entity_path)
    # Keep only alignment pairs whose ids are valid in both vocabularies.
    shared_entity = shared_entity[
        (shared_entity["student"] < student.get_entity_count())
        & (shared_entity["teacher"] < teacher.get_entity_count())]
    student2teacher = dict(
        zip(shared_entity["student"], shared_entity["teacher"]))
    logging.info("train: valid: test = %d: %d: %d" %
                 (len(student.train_set), len(student.valid_set),
                  len(student.test_set)))
    num_entities = student.get_entity_count()
    num_relations = student.get_relation_count()
    logging.info("number of entities: %d" % (num_entities))
    logging.info("number of relations: %d" % (num_relations))

    # training data: separate head- and tail-corruption streams.
    train_loader_head = DataLoader(TrainDataset(student.train_set,
                                                num_entities, num_relations,
                                                args.num_neg_samples,
                                                'head-batch'),
                                   batch_size=args.kge_batch,
                                   shuffle=True,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TrainDataset.collate_fn)
    train_loader_tail = DataLoader(TrainDataset(student.train_set,
                                                num_entities, num_relations,
                                                args.num_neg_samples,
                                                'tail-batch'),
                                   batch_size=args.kge_batch,
                                   shuffle=True,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TrainDataset.collate_fn)
    train_iterator = BidirectionalOneShotIterator(train_loader_head,
                                                  train_loader_tail)

    # validation data and test data (all known triples are passed so the
    # TestDataset can do filtered ranking).
    all_student_data = student.train_set + student.valid_set + student.test_set
    valid_loader_head = DataLoader(TestDataset(student.valid_set,
                                               all_student_data, num_entities,
                                               num_relations, 'head-batch'),
                                   batch_size=args.test_batch,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TestDataset.collate_fn)
    valid_loader_tail = DataLoader(TestDataset(student.valid_set,
                                               all_student_data, num_entities,
                                               num_relations, 'tail-batch'),
                                   batch_size=args.test_batch,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TestDataset.collate_fn)
    valid_dataloaders = [valid_loader_head, valid_loader_tail]
    test_loader_head = DataLoader(TestDataset(student.test_set,
                                              all_student_data, num_entities,
                                              num_relations, 'head-batch'),
                                  batch_size=args.test_batch,
                                  num_workers=max(0, args.num_workers // 2),
                                  collate_fn=TestDataset.collate_fn)
    test_loader_tail = DataLoader(TestDataset(student.test_set,
                                              all_student_data, num_entities,
                                              num_relations, 'tail-batch'),
                                  batch_size=args.test_batch,
                                  num_workers=max(0, args.num_workers // 2),
                                  collate_fn=TestDataset.collate_fn)
    test_dataloaders = [test_loader_head, test_loader_tail]

    # ----------------
    # Prepare Data
    # ----------------
    logging.info("preparing data..")
    # writer = SummaryWriter(args.save_path)
    writer = None
    if args.gpu_id == -1:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:%d" % (args.gpu_id))
    # Teacher embeddings are frozen; only student-side modules are trained.
    teacher.entity_embedding = teacher.entity_embedding.to(
        device).requires_grad_(False)
    teacher.relation_embedding = teacher.relation_embedding.to(
        device).requires_grad_(False)
    learner = KGEModel(model_name=args.kge_model,
                       num_entities=num_entities,
                       num_relations=num_relations,
                       hidden_dim=args.emb_dim,
                       gamma=args.margin,
                       double_entity_embedding=args.kge_model in ["RotatE", "ComplEx"],
                       double_relation_embedding=args.kge_model in ["ComplEx"]).to(device)
    # Maps teacher embedding space -> student embedding space.
    transnet = TransNetwork(teacher.get_embedding_dim(),
                            learner.entity_dim).to(device)
    generator = Generator(learner.entity_dim).to(device)
    discriminator = Discriminator(learner.entity_dim,
                                  activation=nn.Sigmoid()).to(device)
    optimizer_g = optim.Adam(generator.parameters(), lr=args.gan_lr,
                             betas=(0.5, 0.9))
    # transnet is shared: it is updated both with the discriminator (d) and
    # with the learner (l).
    optimizer_d = optim.Adam(list(transnet.parameters()) +
                             list(discriminator.parameters()),
                             lr=args.gan_lr, betas=(0.5, 0.9))
    optimizer_l = optim.Adam(list(transnet.parameters()) +
                             list(learner.parameters()), lr=args.kge_lr)
    scheduler_g = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer_g, args.steps // 100, args.steps, num_cycles=4,
        min_percent=args.gan_lr * 0.001)
    scheduler_d = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer_d, args.steps // 100, args.steps, num_cycles=4,
        min_percent=args.gan_lr * 0.001)
    scheduler_l = get_cosine_with_hard_restarts_schedule_with_warmup(
        optimizer_l, args.steps // 100, args.steps, num_cycles=4,
        min_percent=args.kge_lr * 0.001)
    # scheduler_g = get_constant_schedule_with_warmup(optimizer_g, args.steps//100)
    # scheduler_d = get_constant_schedule_with_warmup(optimizer_d, args.steps//100)
    # scheduler_l = get_constant_schedule_with_warmup(optimizer_l, args.steps//100)
    # Dummy optimizer step before scheduler.step(0) so the warmup LR is applied
    # without PyTorch's "scheduler.step before optimizer.step" warning.
    optimizer_g.zero_grad()
    optimizer_g.step()
    scheduler_g.step(0)
    optimizer_d.zero_grad()
    optimizer_d.step()
    scheduler_d.step(0)
    optimizer_l.zero_grad()
    optimizer_l.step()
    scheduler_l.step(0)
    bce_loss = nn.BCELoss(reduction='none')
    mse_loss = nn.MSELoss(reduction='none')  # NOTE(review): unused below
    cos_loss = nn.CosineEmbeddingLoss(reduction='none')

    logging.info("begin training..")
    transnet.train()
    generator.train()
    discriminator.train()
    learner.train()
    shared_iterator = batch_iter(shared_entity, args.gan_batch)
    training_logs = []
    valid_best_metrics = {}
    test_best_metrics = {}
    for step in range(1, args.steps + 1):
        log = {}
        loss_g_selfs = []
        loss_g_dists = []
        loss_gs = []
        loss_d_reals = []
        loss_d_transes = []
        loss_d_fakes = []
        loss_ds = []
        # Shared entities touched this step; reused later for the distance loss.
        teacher_entities = []
        student_entities = []
        for gan_step in range(args.gan_steps):
            try:
                shared_batch = next(shared_iterator)
            except StopIteration:
                # One pass over shared entities exhausted; restart the iterator.
                shared_iterator = batch_iter(shared_entity, args.gan_batch)
                shared_batch = next(shared_iterator)
            teacher_entity = torch.LongTensor(
                shared_batch['teacher'].values).to(device)
            student_entity = torch.LongTensor(
                shared_batch['student'].values).to(device)
            teacher_embed = torch.index_select(teacher.entity_embedding,
                                               dim=0, index=teacher_entity)
            student_embed = torch.index_select(learner.entity_embedding,
                                               dim=0, index=student_entity)
            gan_batch_size = teacher_entity.size(0)
            gan_ones = torch.ones((gan_batch_size, 1), device=device,
                                  requires_grad=False)
            teacher_entities.append(teacher_entity)
            student_entities.append(student_entity)

            # --------------------
            # Train Discriminator
            # --------------------
            random_z = (torch.rand_like(student_embed) * 2 - 1)
            fake_emb = generator(student_embed.detach(), random_z)
            real_output = discriminator(student_embed.detach(),
                                        student_embed.detach())
            trans_output = discriminator(student_embed.detach(),
                                         transnet(teacher_embed))
            fake_output = discriminator(student_embed.detach(),
                                        fake_emb.detach())
            # WGAN-style critic losses: push real scores up, fake scores down.
            loss_d_real = -torch.mean(real_output)
            loss_d_trans = -torch.mean(trans_output)
            loss_d_fake = torch.mean(fake_output)
            # beta = anneal_fn("cosine", step, args.steps, args.kge_beta, 0)
            beta = 0  # translated-teacher term currently disabled
            loss_d = loss_d_real + loss_d_fake + beta * loss_d_trans
            optimizer_d.zero_grad()
            loss_d.backward()
            optimizer_d.step()
            scheduler_d.step(step)
            loss_d_reals.append(loss_d_real.item())
            loss_d_transes.append(loss_d_trans.item())
            loss_d_fakes.append(loss_d_fake.item())
            loss_ds.append(loss_d.item())
            # clip weights of discriminator (WGAN Lipschitz constraint)
            for param in discriminator.parameters():
                param.data.clamp_(-args.gan_clip_value, args.gan_clip_value)

            if args.gan_n_critic > 0:
                # Positive n_critic: one generator update per n critic steps.
                if gan_step % args.gan_n_critic == 0:
                    # ----------------
                    # Train Generator
                    # ----------------
                    # random_z = (torch.rand_like(student_embed) * 2 - 1)
                    # fake_emb = generator(student_embed.detach(), random_z)
                    loss_g_self = -torch.mean(
                        discriminator(student_embed.detach(), fake_emb))
                    loss_g_dist = cos_loss(student_embed.detach(), fake_emb,
                                           gan_ones).mean()
                    alpha = anneal_fn("cosine", step, args.steps,
                                      args.kge_alpha, 0)
                    loss_g = loss_g_self + alpha * loss_g_dist
                    optimizer_g.zero_grad()
                    loss_g.backward()
                    optimizer_g.step()
                    scheduler_g.step(step)
                    loss_g_selfs.append(loss_g_self.item())
                    loss_g_dists.append(loss_g_dist.item())
                    loss_gs.append(loss_g.item())
            else:
                # Negative n_critic: several generator updates per critic step.
                for gene_step in range(-args.gan_n_critic):
                    # ----------------
                    # Train Generator
                    # ----------------
                    random_z = (torch.rand_like(student_embed) * 2 - 1)
                    fake_emb = generator(student_embed.detach(), random_z)
                    loss_g_self = -torch.mean(
                        discriminator(student_embed.detach(), fake_emb))
                    loss_g_dist = cos_loss(student_embed.detach(), fake_emb,
                                           gan_ones).mean()
                    alpha = anneal_fn("cosine", step, args.steps,
                                      args.kge_alpha, 0)
                    loss_g = loss_g_self + alpha * loss_g_dist
                    optimizer_g.zero_grad()
                    loss_g.backward()
                    optimizer_g.step()
                    scheduler_g.step(step)
                    loss_g_selfs.append(loss_g_self.item())
                    loss_g_dists.append(loss_g_dist.item())
                    loss_gs.append(loss_g.item())

        # NOTE(review): the g_self/g_dist guards test len(loss_gs), not their
        # own lists; harmless since all three lists grow together.
        log["loss_g_self"] = np.mean(loss_g_selfs) if len(loss_gs) > 0 else 0.0
        log["loss_g_dist"] = np.mean(loss_g_dists) if len(loss_gs) > 0 else 0.0
        log["loss_g"] = np.mean(loss_gs) if len(loss_gs) > 0 else 0.0
        log["loss_d_real"] = np.mean(loss_d_reals) if len(loss_ds) > 0 else 0.0
        log["loss_d_trans"] = np.mean(
            loss_d_transes) if len(loss_ds) > 0 else 0.0
        log["loss_d_fake"] = np.mean(loss_d_fakes) if len(loss_ds) > 0 else 0.0
        log["loss_d"] = np.mean(loss_ds) if len(loss_ds) > 0 else 0.0

        # --------------------
        # Train Learner
        # --------------------
        positive_sample, negative_sample, subsampling_weight, mode = next(
            train_iterator)
        # Collect triples whose head (first pass) or tail (second pass) is a
        # shared entity; each gets a teacher-transferred counterpart below.
        teacher_entity = []
        student_entity = []
        transfered_sample = []
        transfered_weight = []
        transfered_index = []
        for i, s in enumerate(positive_sample[:, 0].numpy()):
            t = student2teacher.get(s, -1)
            if t != -1:
                teacher_entity.append(t)
                student_entity.append(s)
                transfered_sample.append(positive_sample[i])
                transfered_weight.append(subsampling_weight[i].item())
                transfered_index.append(i)
        # Boundary between head-shared and tail-shared rows.
        num_transfered_head = len(transfered_sample)
        for i, s in enumerate(positive_sample[:, 2].numpy()):
            t = student2teacher.get(s, -1)
            if t != -1:
                teacher_entity.append(t)
                student_entity.append(s)
                transfered_sample.append(positive_sample[i])
                transfered_weight.append(subsampling_weight[i].item())
                transfered_index.append(i)
        teacher_entity = torch.LongTensor(teacher_entity).to(device)
        student_entity = torch.LongTensor(student_entity).to(device)
        positive_sample = positive_sample.to(device)
        negative_sample = negative_sample.to(device)
        # NOTE(review): torch.cat raises on an empty list, so a batch with no
        # shared entities would fail here before reaching the else-branch
        # below — confirm whether that case can occur.
        transfered_sample = torch.cat(transfered_sample,
                                      dim=0).view(-1, 3).to(device)
        subsampling_weight = subsampling_weight.to(device)
        transfered_weight = torch.tensor(transfered_weight).to(device)
        transfered_index = torch.LongTensor(transfered_index).to(device)
        teacher_entities.append(teacher_entity)
        student_entities.append(student_entity)
        positive_score = learner(positive_sample)
        negative_score = learner((positive_sample, negative_sample), mode=mode)
        if transfered_sample.size(0) > 0:
            transfered_negative_score = torch.index_select(
                negative_score, dim=0, index=transfered_index)
            teacher_embed = torch.index_select(teacher.entity_embedding,
                                               dim=0, index=teacher_entity)
            student_embed = torch.index_select(learner.entity_embedding,
                                               dim=0, index=student_entity)
            # First num_transfered_head rows had shared heads (head replaced by
            # the translated teacher embedding); remaining rows had shared
            # tails (tail replaced instead).
            transfered_head = torch.cat([
                transnet(teacher_embed[:num_transfered_head]),
                torch.index_select(
                    learner.entity_embedding,
                    dim=0,
                    index=transfered_sample[num_transfered_head:, 0])
            ], dim=0).unsqueeze(1)
            transfered_relation = torch.index_select(
                learner.relation_embedding, dim=0,
                index=transfered_sample[:, 1]).unsqueeze(1)
            transfered_tail = torch.cat([
                torch.index_select(
                    learner.entity_embedding,
                    dim=0,
                    index=transfered_sample[:num_transfered_head, 2]),
                transnet(teacher_embed[num_transfered_head:])
            ], dim=0).unsqueeze(1)
            transfered_score = learner.score(transfered_head,
                                             transfered_relation,
                                             transfered_tail)
            # gan_validity = torch.ones((transfered_weight.size(0), 1), dtype=transfered_weight.dtype, device=transfered_weight.device)  # w/o AAM
            gan_validity = discriminator(student_embed,
                                         transnet(teacher_embed))
            transfered_weight = gan_validity * transfered_weight
        else:
            # Neutral placeholders when no shared entity appears in the batch.
            # NOTE(review): gan_validity is a plain int here, so the
            # TransE/DistMult branch's gan_validity.unsqueeze(1) would fail on
            # this path — confirm it is unreachable.
            transfered_negative_score = torch.tensor([[0.0]]).to(device)
            transfered_score = torch.tensor([[0.0]]).to(device)
            gan_validity = 1
            transfered_weight = torch.tensor([0.0]).to(device)
        if learner.model_name == 'ConvE':
            positive_sample_loss = bce_loss(
                F.logsigmoid(positive_score),
                torch.ones_like(positive_score)).mean()
            negative_sample_loss = bce_loss(
                F.logsigmoid(-negative_score),
                torch.zeros_like(negative_score)).mean()
            transfered_sample_loss = (
                gan_validity * bce_loss(F.logsigmoid(transfered_score),
                                        torch.ones_like(transfered_score))).mean()
            loss_l_self = positive_sample_loss + negative_sample_loss
            loss_l_trans = transfered_sample_loss
        elif learner.model_name == 'ComplEx' or learner.model_name == 'RotatE':
            positive_sample_loss = -F.logsigmoid(positive_score).mean()
            negative_sample_loss = -F.logsigmoid(-negative_score).mean()
            transfered_sample_loss = -(gan_validity *
                                       F.logsigmoid(transfered_score)).mean()
            loss_l_self = positive_sample_loss + negative_sample_loss
            loss_l_trans = transfered_sample_loss
        elif learner.model_name == 'TransE' or learner.model_name == 'DistMult':
            positive_sample_loss = -positive_score.mean()
            negative_sample_loss = negative_score.mean()
            transfered_sample_loss = -transfered_score.mean()
            # Margin-ranking form for distance-based models.
            loss_l_self = F.relu(-positive_score + negative_score +
                                 args.margin).mean()
            loss_l_trans = (F.relu(-transfered_score +
                                   transfered_negative_score + args.margin) *
                            gan_validity.unsqueeze(1)).mean()
        if len(teacher_entities) > 0:
            teacher_entities = torch.cat(teacher_entities, dim=0)
            student_entities = torch.cat(student_entities, dim=0)
            teacher_embeds = torch.index_select(teacher.entity_embedding,
                                                dim=0, index=teacher_entities)
            student_embeds = torch.index_select(learner.entity_embedding,
                                                dim=0, index=student_entities)
            trans_embeds = transnet(teacher_embeds)
            # gan_validity = torch.ones((trans_embeds.size(0), 1), dtype=trans_embeds.dtype, device=trans_embeds.device)  # w/o AAM
            gan_validity = discriminator(student_embeds, trans_embeds)
            # Pull student embeddings toward translated teacher embeddings,
            # weighted by how "valid" the discriminator finds each pair.
            loss_l_dist = (gan_validity * cos_loss(
                student_embeds, trans_embeds,
                torch.ones(trans_embeds.size(0), device=device))).mean()
        else:
            loss_l_dist = torch.tensor([0.0]).to(device)
        if args.reg != 0.0:
            # Use L3 regularization for ComplEx and DistMult
            # NOTE(review): the second .norm(p=3) acts on a scalar and is a
            # no-op; likely a single .norm(p=3)**3 was intended — confirm.
            reg = args.reg * (learner.entity_embedding.norm(p=3)**3 +
                              learner.relation_embedding.norm(p=3).norm(p=3)**3)
            loss_l_self = loss_l_self + reg
            reg_log = {'reg': reg.item()}
            log.update(reg_log)
        else:
            reg_log = {}
        alpha = anneal_fn("cyclical_cosine", step, args.steps // 2,
                          args.kge_alpha, 0)
        beta = anneal_fn("cyclical_cosine", step, args.steps // 2,
                         args.kge_beta, 0)
        # alpha = args.kge_alpha
        # beta = args.kge_beta
        loss_l = loss_l_self + alpha * loss_l_dist + beta * loss_l_trans
        optimizer_l.zero_grad()
        loss_l.backward()
        optimizer_l.step()
        scheduler_l.step(step)
        log["alpha"] = alpha
        log["beta"] = beta
        log["pos_loss"] = positive_sample_loss.item()
        log["neg_loss"] = negative_sample_loss.item()
        log["loss_l_self"] = loss_l_self.item()
        log["loss_l_dist"] = loss_l_dist.item()
        log["loss_l_transfer"] = transfered_sample_loss.item()
        log["loss_l"] = loss_l.item()

        # ----------------
        # Save logs
        # ----------------
        training_logs.append(log)
        # save summary
        # for k, v in log.items():
        #     writer.add_scalar('loss/%s' % (k), v, step)

        # training average
        if len(training_logs) > 0 and (step % args.print_steps == 0
                                       or step == args.steps):
            logging.info("--------------------------------------")
            metrics = {}
            for metric in training_logs[0].keys():
                metrics[metric] = sum([log[metric] for log in training_logs
                                       ]) / len(training_logs)
            log_metrics('training average', step, metrics)
            # for k, v in metrics.items():
            #     writer.add_scalar('train-metric/%s' % (k.replace("@", "_")), v, step)

        # evaluate model
        if len(training_logs) > 0 and (step % args.eval_steps == 0
                                       or step == args.steps):
            logging.info("--------------------------------------")
            logging.info('evaluating on valid dataset...')
            valid_metrics = link_prediction(args, learner, valid_dataloaders)
            log_metrics('valid', step, valid_metrics)
            # Best model selected by MRR + HITS@3.
            if len(valid_best_metrics) == 0 or \
                    valid_best_metrics["MRR"] + valid_best_metrics["HITS@3"] < valid_metrics["MRR"] + valid_metrics["HITS@3"]:
                valid_best_metrics = valid_metrics.copy()
                valid_best_metrics["step"] = step
                save_model(
                    {
                        "model": learner,
                        "transnet": transnet,
                        "generator": generator,
                        "discriminator": discriminator,
                        "optimizer_l": optimizer_l,
                        "optimizer_g": optimizer_g,
                        "optimizer_d": optimizer_d,
                        "scheduler_l": scheduler_l,
                        "scheduler_g": scheduler_g,
                        "scheduler_d": scheduler_d,
                        "step": step,
                        "steps": args.steps
                    }, os.path.join(args.save_path, "checkpoint_valid.pt"))
            # for k, v in valid_metrics.items():
            #     writer.add_scalar('valid-metric/%s' % (k.replace("@", "_")), v, step)
            logging.info("--------------------------------------")
            logging.info('evaluating on test dataset...')
            test_metrics = link_prediction(args, learner, test_dataloaders)
            log_metrics('test', step, test_metrics)
            if len(test_best_metrics) == 0 or \
                    test_best_metrics["MRR"] + test_best_metrics["HITS@3"] < test_metrics["MRR"] + test_metrics["HITS@3"]:
                test_best_metrics = test_metrics.copy()
                test_best_metrics["step"] = step
                save_model(
                    {
                        "model": learner,
                        "transnet": transnet,
                        "generator": generator,
                        "discriminator": discriminator,
                        "optimizer_l": optimizer_l,
                        "optimizer_g": optimizer_g,
                        "optimizer_d": optimizer_d,
                        "scheduler_l": scheduler_l,
                        "scheduler_g": scheduler_g,
                        "scheduler_d": scheduler_d,
                        "step": step,
                        "steps": args.steps
                    }, os.path.join(args.save_path, "checkpoint_test.pt"))
            # for k, v in test_metrics.items():
            #     writer.add_scalar('test-metric/%s' % (k.replace("@", "_")), v, step)
            # Evaluation put the learner in eval mode; restore training mode.
            learner.train()

        # save model
        if len(training_logs) > 0 and (step % args.save_steps == 0
                                       or step == args.steps):
            save_model(
                {
                    "model": learner,
                    "transnet": transnet,
                    "generator": generator,
                    "discriminator": discriminator,
                    "optimizer_l": optimizer_l,
                    "optimizer_g": optimizer_g,
                    "optimizer_d": optimizer_d,
                    "scheduler_l": scheduler_l,
                    "scheduler_g": scheduler_g,
                    "scheduler_d": scheduler_d,
                    "step": step,
                    "steps": args.steps
                }, os.path.join(args.save_path, "checkpoint_%d.pt" % (step)))
        if len(training_logs) > 0 and (step % args.print_steps == 0
                                       or step == args.steps):
            training_logs.clear()
    # Final summary of the best checkpoints seen during the run.
    logging.info("--------------------------------------")
    log_metrics('valid-best', valid_best_metrics["step"], valid_best_metrics)
    log_metrics('test-best', test_best_metrics["step"], test_best_metrics)
def train(args):
    """Train a single KGE model on one dataset (no-distillation baseline).

    Builds train/valid/test loaders, trains the configured KGE model with
    negative sampling, logs losses/metrics to TensorBoard, periodically runs
    link-prediction evaluation, and checkpoints the best/periodic models
    under args.save_path.
    """
    # ----------------
    # Load Data
    # ----------------
    logging.info("loading data..")
    data = KGDataset(data_path=args.data_path, dict_path=args.data_path)
    logging.info(
        "train: valid: test = %d: %d: %d" %
        (len(data.train_set), len(data.valid_set), len(data.test_set)))
    num_entities = data.get_entity_count()
    num_relations = data.get_relation_count()
    logging.info("number of entities: %d" % (num_entities))
    logging.info("number of relations: %d" % (num_relations))
    # training data: separate head- and tail-corruption streams.
    train_loader_head = DataLoader(TrainDataset(data.train_set,
                                                data.get_entity_count(),
                                                data.get_relation_count(),
                                                args.num_neg_samples,
                                                'head-batch'),
                                   batch_size=args.kge_batch,
                                   shuffle=True,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TrainDataset.collate_fn)
    train_loader_tail = DataLoader(TrainDataset(data.train_set,
                                                data.get_entity_count(),
                                                data.get_relation_count(),
                                                args.num_neg_samples,
                                                'tail-batch'),
                                   batch_size=args.kge_batch,
                                   shuffle=True,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TrainDataset.collate_fn)
    train_iterator = BidirectionalOneShotIterator(train_loader_head,
                                                  train_loader_tail)
    # validation data and test data (all known triples enable filtered ranking)
    all_data = data.train_set + data.valid_set + data.test_set
    valid_loader_head = DataLoader(TestDataset(data.valid_set, all_data,
                                               num_entities, num_relations,
                                               'head-batch'),
                                   batch_size=args.test_batch,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TestDataset.collate_fn)
    valid_loader_tail = DataLoader(TestDataset(data.valid_set, all_data,
                                               num_entities, num_relations,
                                               'tail-batch'),
                                   batch_size=args.test_batch,
                                   num_workers=max(0, args.num_workers // 2),
                                   collate_fn=TestDataset.collate_fn)
    valid_dataloaders = [valid_loader_head, valid_loader_tail]
    test_loader_head = DataLoader(TestDataset(data.test_set, all_data,
                                              num_entities, num_relations,
                                              'head-batch'),
                                  batch_size=args.test_batch,
                                  num_workers=max(0, args.num_workers // 2),
                                  collate_fn=TestDataset.collate_fn)
    test_loader_tail = DataLoader(TestDataset(data.test_set, all_data,
                                              num_entities, num_relations,
                                              'tail-batch'),
                                  batch_size=args.test_batch,
                                  num_workers=max(0, args.num_workers // 2),
                                  collate_fn=TestDataset.collate_fn)
    test_dataloaders = [test_loader_head, test_loader_tail]
    # ----------------
    # Prepare Data
    # ----------------
    logging.info("preparing data..")
    writer = SummaryWriter(args.save_path)
    if args.gpu_id == -1:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:%d" % (args.gpu_id))
    learner = KGEModel(model_name=args.kge_model,
                       num_entities=num_entities,
                       num_relations=num_relations,
                       hidden_dim=args.emb_dim,
                       gamma=args.margin,
                       double_entity_embedding=args.kge_model in ["RotatE", "ComplEx"],
                       double_relation_embedding=args.kge_model in ["ComplEx"]).to(device)
    # TODO COPY TOP EMBEDDING
    # top_student_entity = pickle.load(open('../intermediate/top_one_entity_transe_200.pkl', 'rb'))
    # top_shared_entity = shared_entity[shared_entity['student'].isin(top_student_entity)]
    # shared_entity = shared_entity[(True ^ shared_entity['student'].isin(top_student_entity))]
    # for t, s in np.array(top_shared_entity):
    # for t, s in np.array(shared_entity):
    #     learner.entity_embedding.data[s].copy_(teacher.entity_embedding.weight[t])
    # TODO DELETE ONCE ENTITY
    # once_entity = pickle.load(open('top_one_entity_transe_200.pkl', 'rb'))
    # shared_entity = shared_entity[(True ^ shared_entity['student'].isin(once_entity))]
    optimizer = optim.Adam(learner.parameters(), lr=args.kge_lr)
    # scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    #     optimizer, args.steps//100, args.steps, num_cycles=4, min_percent=args.kge_lr*0.001)
    scheduler = get_constant_schedule_with_warmup(optimizer, args.steps // 100)
    # Dummy optimizer step before scheduler.step(0) so the warmup LR is applied
    # without PyTorch's "scheduler.step before optimizer.step" warning.
    optimizer.zero_grad()
    optimizer.step()
    scheduler.step(0)
    bce_loss = nn.BCELoss(reduction='none')
    logging.info("begin training..")
    learner.train()
    training_logs = []
    valid_best_metrics = {}
    test_best_metrics = {}
    for step in range(1, args.steps + 1):
        # ----------------
        # Configure Input
        # ----------------
        positive_sample, negative_sample, subsampling_weight, mode = next(
            train_iterator)
        positive_sample = positive_sample.to(device)
        negative_sample = negative_sample.to(device)
        # NOTE(review): subsampling_weight is moved to device but not used in
        # any loss below — confirm whether weighting was intended.
        subsampling_weight = subsampling_weight.to(device)
        log = {}
        # --------------------
        # Train Learner
        # --------------------
        optimizer.zero_grad()
        positive_score = learner(positive_sample)
        negative_score = learner((positive_sample, negative_sample), mode=mode)
        if learner.model_name == 'ConvE':
            positive_sample_loss = bce_loss(
                F.logsigmoid(positive_score),
                torch.ones_like(positive_score)).mean()
            negative_sample_loss = bce_loss(
                F.logsigmoid(-negative_score),
                torch.zeros_like(negative_score)).mean()
            loss = positive_sample_loss + negative_sample_loss
        elif learner.model_name == 'ComplEx' or learner.model_name == 'RotatE':
            positive_sample_loss = -F.logsigmoid(positive_score).mean()
            negative_sample_loss = -F.logsigmoid(-negative_score).mean()
            loss = positive_sample_loss + negative_sample_loss
        elif learner.model_name == 'TransE' or learner.model_name == 'DistMult':
            positive_sample_loss = -positive_score.mean()
            negative_sample_loss = negative_score.mean()
            # Margin-ranking form for distance-based models.
            loss = F.relu(-positive_score + negative_score +
                          args.margin).mean()
        if args.reg != 0.0:
            # Use L3 regularization for ComplEx and DistMult
            # NOTE(review): the second .norm(p=3) acts on a scalar and is a
            # no-op; likely a single .norm(p=3)**3 was intended — confirm.
            reg = args.reg * (learner.entity_embedding.norm(p=3)**3 +
                              learner.relation_embedding.norm(p=3).norm(p=3)**3)
            loss = loss + reg
            reg_log = {'reg': reg.item()}
            log.update(reg_log)
        else:
            reg_log = {}
        loss.backward()
        optimizer.step()
        scheduler.step(step)
        log["pos_loss"] = positive_sample_loss.item()
        log["neg_loss"] = negative_sample_loss.item()
        log["loss"] = loss.item()
        # ----------------
        # Save logs
        # ----------------
        training_logs.append(log)
        # save summary
        for k, v in log.items():
            writer.add_scalar('loss/%s' % (k), v, step)
        # training average
        if len(training_logs) > 0 and (step % args.print_steps == 0
                                       or step == args.steps):
            logging.info("--------------------------------------")
            metrics = {}
            for metric in training_logs[0].keys():
                metrics[metric] = sum([log[metric] for log in training_logs
                                       ]) / len(training_logs)
            log_metrics('training average', step, metrics)
            for k, v in metrics.items():
                writer.add_scalar('train-metric/%s' % (k), v, step)
        # evaluate model
        if len(training_logs) > 0 and (step % args.eval_steps == 0
                                       or step == args.steps):
            logging.info("--------------------------------------")
            logging.info('evaluating on valid dataset...')
            valid_metrics = link_prediction(args, learner, valid_dataloaders)
            log_metrics('valid', step, valid_metrics)
            # Best model selected by MRR + HITS@3.
            if len(valid_best_metrics) == 0 or \
                    valid_best_metrics["MRR"] + valid_best_metrics["HITS@3"] < valid_metrics["MRR"] + valid_metrics["HITS@3"]:
                valid_best_metrics = valid_metrics.copy()
                valid_best_metrics["step"] = step
                # NOTE(review): this checkpoint stores the model under the key
                # "learner" while the test/periodic checkpoints use "model" —
                # confirm which key the loading code expects.
                save_model(
                    {
                        "learner": learner,
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "step": step,
                        "steps": args.steps
                    }, os.path.join(args.save_path, "checkpoint_valid.pt"))
            for k, v in valid_metrics.items():
                writer.add_scalar('valid-metric/%s' % (k.replace("@", "_")),
                                  v, step)
            logging.info("--------------------------------------")
            logging.info('evaluating on test dataset...')
            test_metrics = link_prediction(args, learner, test_dataloaders)
            log_metrics('test', step, test_metrics)
            if len(test_best_metrics) == 0 or \
                    test_best_metrics["MRR"] + test_best_metrics["HITS@3"] < test_metrics["MRR"] + test_metrics["HITS@3"]:
                test_best_metrics = test_metrics.copy()
                test_best_metrics["step"] = step
                save_model(
                    {
                        "model": learner,
                        "optimizer": optimizer,
                        "scheduler": scheduler,
                        "step": step,
                        "steps": args.steps
                    }, os.path.join(args.save_path, "checkpoint_test.pt"))
            for k, v in test_metrics.items():
                writer.add_scalar('test-metric/%s' % (k.replace("@", "_")),
                                  v, step)
            # Evaluation put the learner in eval mode; restore training mode.
            learner.train()
        # save model
        if len(training_logs) > 0 and (step % args.save_steps == 0
                                       or step == args.steps):
            save_model(
                {
                    "model": learner,
                    "optimizer": optimizer,
                    "scheduler": scheduler,
                    "step": step,
                    "steps": args.steps
                }, os.path.join(args.save_path, "checkpoint_%d.pt" % (step)))
        if len(training_logs) > 0 and (step % args.print_steps == 0
                                       or step == args.steps):
            training_logs.clear()
    # Final summary of the best checkpoints seen during the run.
    logging.info("--------------------------------------")
    log_metrics('valid-best', valid_best_metrics["step"], valid_best_metrics)
    log_metrics('test-best', test_best_metrics["step"], test_best_metrics)
def test_pretrain(args, all_data):
    """Evaluate pre-trained per-client embeddings via tail link prediction.

    For each client dataset in *all_data*, loads that client's saved entity
    and relation embeddings from its best checkpoint, scores every test
    triple with the shared KGE scoring model, applies filtered ranking
    (every known true tail except the target is masked out), and accumulates
    MR / MRR / Hits@{1,5,10} both per client and over all clients.

    Args:
        args: run configuration (model name, gpu device, test batch size, ...).
        all_data: iterable of per-client datasets, index-aligned with the
            saved checkpoint files.

    Returns:
        defaultdict(float) of aggregated metrics, already normalized by the
        total triple count (plus the raw 'count' entry).
    """
    kge_model = KGEModel(args, model_name=args.model)
    results = ddict(float)
    for i, data in enumerate(all_data):
        one_results = ddict(float)
        # Per-client best checkpoint holding 'rel_emb' / 'ent_emb'.
        state = torch.load(
            '../LTLE/fed_state/fb15k237_fed10_client_{}.best'.format(i),
            map_location=args.gpu)
        rel_embed = state['rel_emb'].detach()
        ent_embed = state['ent_emb'].detach()
        train_dataset, valid_dataset, test_dataset, nrelation, nentity = get_task_dataset(data, args)
        test_dataloader_tail = DataLoader(
            test_dataset,
            batch_size=args.test_batch_size,
            # num_workers=max(1, args.num_cpu),
            collate_fn=TestDataset.collate_fn
        )
        for batch in test_dataloader_tail:
            triplets, labels, mode = batch
            triplets, labels = triplets.to(args.gpu), labels.to(args.gpu)
            tail_idx = triplets[:, 2]
            pred = kge_model((triplets, None), rel_embed, ent_embed, mode=mode)
            b_range = torch.arange(pred.size()[0], device=args.gpu)
            # Filtered setting: mask scores of all known true tails, then
            # restore the target tail's score so it can still be ranked.
            target_pred = pred[b_range, tail_idx]
            # torch.where requires a boolean condition tensor (ByteTensor
            # conditions are rejected by modern PyTorch).
            pred = torch.where(labels.bool(),
                               -torch.ones_like(pred) * 10000000, pred)
            pred[b_range, tail_idx] = target_pred
            # Rank of the target tail among all candidates (1 = best).
            ranks = 1 + torch.argsort(
                torch.argsort(pred, dim=1, descending=True),
                dim=1, descending=False)[b_range, tail_idx]
            ranks = ranks.float()
            count = torch.numel(ranks)
            results['count'] += count
            results['mr'] += torch.sum(ranks).item()
            results['mrr'] += torch.sum(1.0 / ranks).item()
            one_results['count'] += count
            one_results['mr'] += torch.sum(ranks).item()
            one_results['mrr'] += torch.sum(1.0 / ranks).item()
            for k in [1, 5, 10]:
                results['hits@{}'.format(k)] += torch.numel(ranks[ranks <= k])
                one_results['hits@{}'.format(k)] += torch.numel(ranks[ranks <= k])
        # Per-client metrics, normalized by this client's triple count.
        for k, v in one_results.items():
            if k != 'count':
                one_results[k] = v / one_results['count']
        logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
            one_results['mrr'], one_results['hits@1'],
            one_results['hits@5'], one_results['hits@10']))
    # Aggregate metrics over all clients, normalized by the total count.
    for k, v in results.items():
        if k != 'count':
            results[k] = v / results['count']
    logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
        results['mrr'], results['hits@1'], results['hits@5'],
        results['hits@10']))
    return results
def __init__(self, args, data):
    """Set up datasets, trainable embedding tables, dataloaders, model, optimizer.

    Args:
        args: run configuration (model name, dims, gamma/epsilon, batch sizes,
            lr, gpu device, run_mode).
        data: the raw dataset handed to the task-dataset builders.
    """
    self.args = args
    self.data = data
    # 'Entire' mode uses the whole-graph task datasets and collate functions.
    entire_mode = args.run_mode == 'Entire'
    if entire_mode:
        splits = get_task_dataset_entire(data, args)
    else:
        splits = get_task_dataset(data, args)
    train_dataset, valid_dataset, test_dataset, nrelation, nentity = splits
    self.nentity = nentity
    self.nrelation = nrelation

    # Uniform initialization range, RotatE-style: (gamma + eps) / hidden_dim.
    embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
    # RotatE/ComplEx store real+imaginary parts, doubling the entity width;
    # only ComplEx doubles the relation width as well.
    entity_dim = args.hidden_dim * 2 if args.model in ['RotatE', 'ComplEx'] else args.hidden_dim
    relation_dim = args.hidden_dim * 2 if args.model in ['ComplEx'] else args.hidden_dim
    self.entity_embedding = torch.zeros(self.nentity, entity_dim).to(args.gpu).requires_grad_()
    nn.init.uniform_(
        tensor=self.entity_embedding,
        a=-embedding_range.item(),
        b=embedding_range.item()
    )
    self.relation_embedding = torch.zeros(self.nrelation, relation_dim).to(args.gpu).requires_grad_()
    nn.init.uniform_(
        tensor=self.relation_embedding,
        a=-embedding_range.item(),
        b=embedding_range.item()
    )

    # Dataloaders: training shuffles; evaluation uses the mode's collate_fn.
    self.train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=TrainDataset.collate_fn
    )
    eval_collate = TestDataset_Entire.collate_fn if entire_mode else TestDataset.collate_fn
    self.valid_dataloader = DataLoader(
        valid_dataset,
        batch_size=args.test_batch_size,
        collate_fn=eval_collate
    )
    self.test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        collate_fn=eval_collate
    )

    # Scoring model plus an optimizer over both embedding tables.
    self.kge_model = KGEModel(args, args.model)
    self.optimizer = torch.optim.Adam(
        [{'params': self.entity_embedding},
         {'params': self.relation_embedding}],
        lr=args.lr
    )
class KGERunner():
    """Trainer/evaluator for a KGE model over a single graph or the 'Entire'
    multi-client union graph.

    Entity/relation embeddings are free leaf tensors optimized directly by
    Adam; the KGEModel only scores triples against them. Early stopping is
    driven by validation MRR.

    Fixes vs. original: every method now reads ``self.args`` instead of the
    module-level global ``args``; the final test-evaluation branch in
    ``train`` checked ``run_mode == 'multi_client_train'`` while training
    checked ``'Entire'`` (see note there); a large commented-out
    ``load_from_multi`` was removed.
    """

    def __init__(self, args, data):
        self.args = args
        self.data = data
        # 'Entire' mode builds the dataset over the union of all clients'
        # graphs and keeps a per-triple client index for evaluate_multi().
        if args.run_mode == 'Entire':
            train_dataset, valid_dataset, test_dataset, nrelation, nentity = get_task_dataset_entire(data, args)
        else:
            train_dataset, valid_dataset, test_dataset, nrelation, nentity = get_task_dataset(data, args)
        self.nentity = nentity
        self.nrelation = nrelation

        # Uniform init range (gamma + epsilon) / hidden_dim; RotatE/ComplEx
        # entities (and ComplEx relations) store real + imaginary parts.
        embedding_range = torch.Tensor([(args.gamma + args.epsilon) / args.hidden_dim])
        ent_dim = args.hidden_dim * 2 if args.model in ['RotatE', 'ComplEx'] else args.hidden_dim
        rel_dim = args.hidden_dim * 2 if args.model in ['ComplEx'] else args.hidden_dim
        self.entity_embedding = torch.zeros(self.nentity, ent_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(tensor=self.entity_embedding,
                         a=-embedding_range.item(), b=embedding_range.item())
        self.relation_embedding = torch.zeros(self.nrelation, rel_dim).to(args.gpu).requires_grad_()
        nn.init.uniform_(tensor=self.relation_embedding,
                         a=-embedding_range.item(), b=embedding_range.item())

        # 'Entire' eval batches carry a per-triple client index, hence the
        # distinct collate_fn.
        self.train_dataloader = DataLoader(
            train_dataset, batch_size=args.batch_size, shuffle=True,
            collate_fn=TrainDataset.collate_fn)
        eval_collate = TestDataset_Entire.collate_fn if args.run_mode == 'Entire' \
            else TestDataset.collate_fn
        self.valid_dataloader = DataLoader(
            valid_dataset, batch_size=args.test_batch_size, collate_fn=eval_collate)
        self.test_dataloader = DataLoader(
            test_dataset, batch_size=args.test_batch_size, collate_fn=eval_collate)

        self.kge_model = KGEModel(args, args.model)
        # Optimize the two free embedding tensors directly.
        self.optimizer = torch.optim.Adam(
            [{'params': self.entity_embedding},
             {'params': self.relation_embedding}], lr=args.lr
        )

    def before_test_load(self):
        """Restore the best checkpoint's embeddings for final test evaluation."""
        state = torch.load(os.path.join(self.args.state_dir, self.args.name + '.best'),
                           map_location=self.args.gpu)
        self.relation_embedding = state['rel_emb']
        self.entity_embedding = state['ent_emb']

    def write_training_loss(self, loss, e):
        """Log training loss for epoch `e` to the tensorboard writer."""
        self.args.writer.add_scalar("training/loss", loss, e)

    def write_evaluation_result(self, results, e):
        """Log evaluation metrics for epoch `e` to the tensorboard writer."""
        self.args.writer.add_scalar("evaluation/mrr", results['mrr'], e)
        self.args.writer.add_scalar("evaluation/hits10", results['hits@10'], e)
        self.args.writer.add_scalar("evaluation/hits5", results['hits@5'], e)
        self.args.writer.add_scalar("evaluation/hits1", results['hits@1'], e)

    def save_checkpoint(self, e):
        """Save current embeddings as epoch-`e` checkpoint, removing older ones."""
        state = {'rel_emb': self.relation_embedding,
                 'ent_emb': self.entity_embedding}
        # Keep at most one checkpoint per run name on disk.
        for filename in os.listdir(self.args.state_dir):
            if self.args.name in filename.split('.') and \
                    os.path.isfile(os.path.join(self.args.state_dir, filename)):
                os.remove(os.path.join(self.args.state_dir, filename))
        torch.save(state, os.path.join(self.args.state_dir,
                                       self.args.name + '.' + str(e) + '.ckpt'))

    def save_model(self, best_epoch):
        """Promote the checkpoint of `best_epoch` to '<name>.best'."""
        os.rename(os.path.join(self.args.state_dir,
                               self.args.name + '.' + str(best_epoch) + '.ckpt'),
                  os.path.join(self.args.state_dir, self.args.name + '.best'))

    def train(self):
        """Run the training loop with early stopping on validation MRR, then
        evaluate the best checkpoint on the test split."""
        best_epoch = 0
        best_mrr = 0
        bad_count = 0
        for epoch in range(self.args.max_epoch):
            losses = []
            self.kge_model.train()
            for batch in self.train_dataloader:
                positive_sample, negative_sample, subsampling_weight = batch
                positive_sample = positive_sample.to(self.args.gpu)
                negative_sample = negative_sample.to(self.args.gpu)
                subsampling_weight = subsampling_weight.to(self.args.gpu)

                negative_score = self.kge_model((positive_sample, negative_sample),
                                                self.relation_embedding, self.entity_embedding)
                # Self-adversarial negative sampling: the softmax weights are
                # detached so no gradient flows through the sampling weights.
                negative_score = (F.softmax(negative_score * self.args.adversarial_temperature,
                                            dim=1).detach()
                                  * F.logsigmoid(-negative_score)).sum(dim=1)

                positive_score = self.kge_model(positive_sample, self.relation_embedding,
                                                self.entity_embedding, neg=False)
                positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

                positive_sample_loss = -positive_score.mean()
                negative_sample_loss = -negative_score.mean()
                loss = (positive_sample_loss + negative_sample_loss) / 2

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())

            if epoch % self.args.log_per_epoch == 0:
                logging.info('epoch: {} | loss: {:.4f}'.format(epoch, np.mean(losses)))
                self.write_training_loss(np.mean(losses), epoch)

            if epoch % self.args.check_per_epoch == 0:
                if self.args.run_mode == 'Entire':
                    eval_res = self.evaluate_multi()
                else:
                    eval_res = self.evaluate()
                self.write_evaluation_result(eval_res, epoch)

                if eval_res['mrr'] > best_mrr:
                    best_mrr = eval_res['mrr']
                    best_epoch = epoch
                    logging.info('best model | mrr {:.4f}'.format(best_mrr))
                    self.save_checkpoint(epoch)
                    bad_count = 0
                else:
                    bad_count += 1
                    logging.info('best model is at round {0}, mrr {1:.4f}, bad count {2}'.format(
                        best_epoch, best_mrr, bad_count))
                if bad_count >= self.args.early_stop_patience:
                    logging.info('early stop at round {}'.format(epoch))
                    break

        logging.info('finish training')
        logging.info('save best model')
        self.save_model(best_epoch)

        logging.info('eval on test set')
        self.before_test_load()
        # BUGFIX: this branch originally tested run_mode == 'multi_client_train',
        # which never matched the 'Entire' dataloaders built in __init__ —
        # their batches carry a third element (the per-triple client index)
        # that evaluate() cannot unpack. Use the same condition as above.
        if self.args.run_mode == 'Entire':
            eval_res = self.evaluate_multi(eval_split='test')
        else:
            eval_res = self.evaluate(eval_split='test')

    @staticmethod
    def _rank_metrics(ranks):
        """Fold a 1-D tensor of ranks into count/mr/mrr/hits@k metrics."""
        results = ddict(float)
        count = torch.numel(ranks)
        results['count'] = count
        results['mr'] = torch.sum(ranks).item() / count
        results['mrr'] = torch.sum(1.0 / ranks).item() / count
        for k in [1, 5, 10]:
            results['hits@{}'.format(k)] = torch.numel(ranks[ranks <= k]) / count
        return results

    def evaluate_multi(self, eval_split='valid'):
        """Filtered tail-ranking evaluation with a per-client breakdown
        ('Entire' mode). Per-client metrics are only logged; the returned
        dict aggregates all clients."""
        if eval_split == 'test':
            dataloader = self.test_dataloader
        elif eval_split == 'valid':
            dataloader = self.valid_dataloader

        client_ranks = ddict(list)
        all_ranks = []
        for batch in dataloader:
            triplets, labels, triple_idx = batch
            triplets, labels = triplets.to(self.args.gpu), labels.to(self.args.gpu)
            head_idx, rel_idx, tail_idx = triplets[:, 0], triplets[:, 1], triplets[:, 2]
            pred = self.kge_model((triplets, None),
                                  self.relation_embedding, self.entity_embedding)
            b_range = torch.arange(pred.size()[0], device=self.args.gpu)
            # Filtered setting: push every known-true tail to the bottom,
            # then restore the score of the target tail itself.
            target_pred = pred[b_range, tail_idx]
            pred = torch.where(labels.byte(), -torch.ones_like(pred) * 10000000, pred)
            pred[b_range, tail_idx] = target_pred
            # Double argsort yields the 1-based rank of the target tail.
            ranks = 1 + torch.argsort(torch.argsort(pred, dim=1, descending=True),
                                      dim=1, descending=False)[b_range, tail_idx]
            ranks = ranks.float()
            for i in range(self.args.num_multi):
                client_ranks[i].extend(ranks[triple_idx == i].tolist())
            all_ranks.extend(ranks.tolist())

        for i in range(self.args.num_multi):
            results = self._rank_metrics(torch.tensor(client_ranks[i]))
            logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
                results['mrr'], results['hits@1'],
                results['hits@5'], results['hits@10']))

        results = self._rank_metrics(torch.tensor(all_ranks))
        logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
            results['mrr'], results['hits@1'],
            results['hits@5'], results['hits@10']))
        return results

    def evaluate(self, eval_split='valid'):
        """Filtered tail-ranking evaluation for a single client.

        Side effects (kept from the original): saves top-10 predictions,
        per-triple ranks, and the full result list to disk — also when
        called on the valid split.
        """
        results = ddict(float)
        if eval_split == 'test':
            dataloader = self.test_dataloader
        elif eval_split == 'valid':
            dataloader = self.valid_dataloader

        pred_list = []
        rank_list = []
        results_list = []
        for batch in dataloader:
            triplets, labels = batch
            triplets, labels = triplets.to(self.args.gpu), labels.to(self.args.gpu)
            head_idx, rel_idx, tail_idx = triplets[:, 0], triplets[:, 1], triplets[:, 2]
            pred = self.kge_model((triplets, None),
                                  self.relation_embedding, self.entity_embedding)
            b_range = torch.arange(pred.size()[0], device=self.args.gpu)
            # Filtered setting (same scheme as evaluate_multi).
            target_pred = pred[b_range, tail_idx]
            pred = torch.where(labels.byte(), -torch.ones_like(pred) * 10000000, pred)
            pred[b_range, tail_idx] = target_pred
            pred_argsort = torch.argsort(pred, dim=1, descending=True)
            ranks = 1 + torch.argsort(pred_argsort, dim=1,
                                      descending=False)[b_range, tail_idx]

            pred_list.append(pred_argsort[:, :10])
            rank_list.append(ranks)

            ranks = ranks.float()
            for idx, tri in enumerate(triplets):
                results_list.append([tri.tolist(), ranks[idx].item()])
            count = torch.numel(ranks)
            results['count'] += count
            results['mr'] += torch.sum(ranks).item()
            results['mrr'] += torch.sum(1.0 / ranks).item()
            for k in [1, 5, 10]:
                results['hits@{}'.format(k)] += torch.numel(ranks[ranks <= k])

        torch.save(torch.cat(pred_list, dim=0),
                   os.path.join(self.args.state_dir, self.args.name + '_' +
                                str(self.args.one_client_idx) + '.pred'))
        torch.save(torch.cat(rank_list),
                   os.path.join(self.args.state_dir, self.args.name + '_' +
                                str(self.args.one_client_idx) + '.rank'))

        for k, v in results.items():
            if k != 'count':
                results[k] /= results['count']
        logging.info('mrr: {:.4f}, hits@1: {:.4f}, hits@5: {:.4f}, hits@10: {:.4f}'.format(
            results['mrr'], results['hits@1'],
            results['hits@5'], results['hits@10']))

        test_rst_file = os.path.join(self.args.log_dir, self.args.name + '.test.rst')
        # Close the file handle (original left it dangling).
        with open(test_rst_file, 'wb') as f:
            pickle.dump(results_list, f)
        return results
def train_eval(args, train_triples, valid_triples, test_triples):
    """Train and/or evaluate a KGEModel according to the flags in `args`.

    Args:
        args: configuration namespace (do_train/do_valid/do_test flags,
            model hyper-parameters, checkpointing/logging cadence, ...).
        train_triples, valid_triples, test_triples: lists of (h, r, t)
            triples; their union is used for filtered evaluation.

    Fixes vs. original:
      * removed a duplicated optimizer construction inside the do_train block
      * `learning_rate` was logged with %d, printing 0 for any lr < 1 -> %f
      * removed a duplicated negative_adversarial_sampling log line
      * L3 regularization applied `.norm(p=3)` twice to relation embeddings
      * training-set evaluation was logged under 'Test' -> 'Train'
      * validation during training called a bare `test_step` while final
        evaluation used the model's method; unified on the latter
      * embeddings/test_step were accessed on the DataParallel wrapper, which
        only proxies forward(); access them through the underlying module
    """
    all_true_triples = train_triples + valid_triples + test_triples
    set_seed(args.seed_num)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()

    model = KGEModel(model_name=args.model,
                     nentity=args.nentity,
                     nrelation=args.nrelation,
                     hidden_dim=args.hidden_dim,
                     gamma=args.gamma,
                     double_entity_embedding=args.double_entity_embedding,
                     double_relation_embedding=args.double_relation_embedding)
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # DataParallel proxies forward() only; attributes like entity_embedding
    # and test_step live on the wrapped module.
    core = model.module if isinstance(model, torch.nn.DataParallel) else model

    current_learning_rate = args.learning_rate
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()),
                                 lr=current_learning_rate)

    if args.do_train:
        # The bidirectional iterator alternates head- and tail-corrupted batches.
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, args.nentity, args.nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)
        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, args.nentity, args.nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model (and, when training, optimizer/schedule) from checkpoint.
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step
    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    if args.do_train:
        training_logs = []
        for step in range(init_step, args.max_steps):
            model.train()
            optimizer.zero_grad()

            positive_sample, negative_sample, subsampling_weight, mode = next(train_iterator)
            positive_sample = positive_sample.to(device)
            negative_sample = negative_sample.to(device)
            subsampling_weight = subsampling_weight.to(device)

            negative_score = model((positive_sample, negative_sample), mode=mode)
            if args.negative_adversarial_sampling:
                # Self-adversarial sampling: no back-prop through the
                # (detached) softmax sampling weights.
                negative_score = (F.softmax(negative_score * args.adversarial_temperature,
                                            dim=1).detach()
                                  * F.logsigmoid(-negative_score)).sum(dim=1)
            else:
                negative_score = F.logsigmoid(-negative_score).mean(dim=1)

            positive_score = model(positive_sample)
            positive_score = F.logsigmoid(positive_score).squeeze(dim=1)

            if args.uni_weight:
                positive_sample_loss = -positive_score.mean()
                negative_sample_loss = -negative_score.mean()
            else:
                # Weighted by inverse triple frequency (word2vec-style).
                positive_sample_loss = -(subsampling_weight * positive_score
                                         ).sum() / subsampling_weight.sum()
                negative_sample_loss = -(subsampling_weight * negative_score
                                         ).sum() / subsampling_weight.sum()
            loss = (positive_sample_loss + negative_sample_loss) / 2

            if args.regularization != 0.0:
                # L3 regularization for ComplEx and DistMult. BUGFIX: the
                # relation term called .norm(p=3) twice (a no-op on the
                # non-negative scalar, but clearly a typo).
                regularization = args.regularization * (
                    core.entity_embedding.norm(p=3) ** 3 +
                    core.relation_embedding.norm(p=3) ** 3)
                loss = loss + regularization
                regularization_log = {'regularization': regularization.item()}
            else:
                regularization_log = {}

            loss.backward()
            optimizer.step()

            log = {
                **regularization_log,
                'positive_sample_loss': positive_sample_loss.item(),
                'negative_sample_loss': negative_sample_loss.item(),
                'loss': loss.item()
            }
            training_logs.append(log)

            if step >= warm_up_steps:
                # Decay LR 10x and push the next decay 3x further out.
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d'
                             % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, model.parameters()),
                    lr=current_learning_rate)
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric] for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = core.test_step(model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = core.test_step(model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = core.test_step(model, test_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = core.test_step(model, train_triples, all_true_triples, args)
        # BUGFIX: these metrics were logged under 'Test'.
        log_metrics('Train', step, metrics)