optimizer = torch.optim.Adam(
    filter(lambda p: p.requires_grad, kge_model.parameters()),
    lr=learning_rate)

training_logs = []
valid_logs = []
for step in range(args["n_epoch"] * train_iterator.epoch_size):
    loss = kge_model.train_step(optimizer, train_iterator, args)
    training_logs.append(('train', loss))

    # `now` is assumed to be a run identifier (e.g. a timestamp) defined earlier
    if step % args["save_checkpoint_steps"] == 0 and step > 0:
        torch.save(
            {
                'step': step,
                'loss': loss,
                'model': kge_model.state_dict()
            }, "checkpoint_" + str(now))

    if step % args["log_steps"] == 0:
        print("step:", step, "loss:", loss)

    if step % args["valid_steps"] == 0:
        logging.info('Evaluating on Valid Dataset...')
        valid_loss, metrics = kge_model.test_step(validation_iterator, args)
        training_logs.append(('validation', valid_loss))
        valid_logs.append(metrics)
        # Save progress so far
        DataFrame(valid_logs).to_csv("valid_logs.csv")
        DataFrame(training_logs, columns=['type', 'loss']).to_csv("training_logs.csv")
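# For reference, resuming from one of the ad-hoc checkpoints written by the
# loop above. This is a minimal sketch based only on the dict saved there;
# `now` (the filename suffix) and a constructed `kge_model` are assumed to
# already exist, exactly as the loop itself assumes.
checkpoint = torch.load("checkpoint_" + str(now))
kge_model.load_state_dict(checkpoint['model'])  # weights were saved under 'model'
resume_step = checkpoint['step']                # step counter to resume from
print("resumed at step", resume_step, "last loss", checkpoint['loss'])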
def main(args):
    if args.seed != -1:
        torch.manual_seed(args.seed)
        if args.cuda:
            torch.cuda.manual_seed(args.seed)

    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('At least one of train/valid/test modes must be chosen.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('One of init_checkpoint/data_path must be provided.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    if args.do_train and args.do_valid:
        if not os.path.exists("%s/best/" % args.save_path):
            os.makedirs("%s/best/" % args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    # Load entity and relation vocabularies (tab-separated "id<TAB>name" files)
    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'),
                                entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    # Shape [3, n_triples]: component index (head, relation, tail) x batch
    train_triples_tsr = torch.LongTensor(train_triples).transpose(0, 1)

    # All true triples, used to filter false negatives during evaluation
    all_true_triples = train_triples + valid_triples + test_triples

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        args=args,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding,
    )

    logging.info('Model Configuration:')
    logging.info(str(kge_model))
    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, requires_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()
        train_triples_tsr = train_triples_tsr.cuda()

    if args.do_train:
        # Set up the training dataloaders. With same_head_tail, shuffle
        # train_triples once up front and disable shuffling inside the
        # dataloaders, so the head-batch and tail-batch loaders walk the
        # triples in the same order (head and tail share the same index order).
        if args.same_head_tail:
            shuffle(train_triples)
        loader_shuffle = not args.same_head_tail

        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=loader_shuffle,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)
        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=loader_shuffle,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()),
            lr=current_learning_rate,
            weight_decay=args.weight_decay,
        )
        # Halve the learning rate each time the scheduler is stepped
        # (done every args.schedule_steps training steps below)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size=1,
                                                    gamma=0.5,
                                                    last_epoch=-1)

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        # Backfill score_weights for checkpoints saved before that parameter existed
        if 'score_weights' in kge_model.state_dict() \
                and 'score_weights' not in checkpoint['model_state_dict']:
            checkpoint['model_state_dict']['score_weights'] = \
                kge_model.state_dict()['score_weights']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        else:
            current_learning_rate = 0
    elif args.init_embedding:
        # Warm-start entity/relation embeddings from .npy files; all other
        # parameters keep their random initialization
        logging.info('Loading pretrained embedding %s ...' % args.init_embedding)
        if kge_model.entity_embedding is not None:
            entity_embedding = np.load(
                os.path.join(args.init_embedding, 'entity_embedding.npy'))
            relation_embedding = np.load(
                os.path.join(args.init_embedding, 'relation_embedding.npy'))
            entity_embedding = torch.from_numpy(entity_embedding).to(
                kge_model.entity_embedding.device)
            relation_embedding = torch.from_numpy(relation_embedding).to(
                kge_model.relation_embedding.device)
            kge_model.entity_embedding.data[:entity_embedding.size(0)] = entity_embedding
            kge_model.relation_embedding.data[:relation_embedding.size(0)] = relation_embedding
        init_step = 1
        current_learning_rate = 0
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 1
        if not args.do_train:
            # Keep the logging below well-defined when only evaluating
            current_learning_rate = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('learning_rate = %.5f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    # Select the training loss: negative log-sigmoid by default
    if args.use_bceloss:
        loss_func = nn.BCELoss(reduction="none")
    elif args.use_softmarginloss:
        loss_func = nn.SoftMarginLoss(reduction="none")
    else:
        loss_func = nn.LogSigmoid()

    if args.do_train:
        training_logs = []
        best_metrics = None

        # Training loop with gradient accumulation: gradients are zeroed at
        # the start of each window of args.update_freq steps and applied at
        # its end
        optimizer.zero_grad()
        for step in range(init_step, args.max_steps + 1):
            if step % args.update_freq == 1 or args.update_freq == 1:
                optimizer.zero_grad()
            log = kge_model.train_step(kge_model, train_iterator,
                                       train_triples_tsr, loss_func, args)
            if step % args.update_freq == 0:
                optimizer.step()
            training_logs.append(log)

            if step % args.schedule_steps == 0:
                scheduler.step()

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                # Average each logged metric over the window since the last log
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric] for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, [metrics])
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples,
                                              all_true_triples,
                                              train_triples_tsr, args)
                log_metrics('Valid', step, metrics)
                # Keep a separate "best" checkpoint whenever validation improves
                if is_better_metric(best_metrics, metrics):
                    save_variable_list = {
                        'step': step,
                        'current_learning_rate': current_learning_rate,
                    }
                    save_model(kge_model, optimizer, save_variable_list, args, True)
                    best_metrics = metrics

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid and args.do_train:
        # Load the best model found during training before final evaluation
        best_checkpoint = torch.load("%s/best/checkpoint" % args.save_path)
        kge_model.load_state_dict(best_checkpoint['model_state_dict'])
        logging.info("Loading best model from step %d" % best_checkpoint['step'])
        step = best_checkpoint['step']
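    # NOTE: `save_model` and `is_better_metric` are defined elsewhere in this
    # repo. From the calls above, save_model(..., True) evidently writes the
    # "best" checkpoint under "%s/best/" % args.save_path, which is what gets
    # reloaded here. A plausible comparator, shown for illustration only (the
    # metric actually used for model selection is an assumption):
    #
    #     def is_better_metric(best_metrics, metrics):
    #         if best_metrics is None:
    #             return True
    #         return metrics[0]['MRR'] > best_metrics[0]['MRR']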
    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples, all_true_triples,
                                      train_triples_tsr, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples, all_true_triples,
                                      train_triples_tsr, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples, all_true_triples,
                                      train_triples_tsr, args)
        log_metrics('Train', step, metrics)
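# `read_triple` is used above but defined elsewhere in this repo. Below is a
# minimal sketch consistent with its call sites: each line of
# train/valid/test.txt holds one tab-separated "head relation tail" triple,
# mapped to integer ids so that torch.LongTensor(triples) has shape
# [n_triples, 3].
def read_triple(file_path, entity2id, relation2id):
    '''Read triples from file_path and map them into (h, r, t) id tuples.'''
    triples = []
    with open(file_path) as fin:
        for line in fin:
            h, r, t = line.strip().split('\t')
            triples.append((entity2id[h], relation2id[r], entity2id[t]))
    return triples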