def _run_train_job(sicnk, device=None):
    """Runs a training job and returns the trace entry of its best validation result.

    Also takes care of appropriate tracing.

    `sicnk` is a 5-tuple: (search_job, train_job_index, train_job_config,
    train_job_count, trace_keys). If `device` is given, it overrides the
    job's configured device. Returns a 3-tuple
    (train_job_index, best_trace_entry_or_None, best_metric_or_None).
    """
    search_job, train_job_index, train_job_config, train_job_count, trace_keys = sicnk
    try:
        # load the job
        if device is not None:
            train_job_config.set("job.device", device)
        search_job.config.log(
            "Starting training job {} ({}/{}) on device {}...".format(
                train_job_config.folder,
                train_job_index + 1,
                train_job_count,
                train_job_config.get("job.device"),
            ))

        # resume from an existing checkpoint when one is found, else start fresh
        checkpoint_file = get_checkpoint_file(train_job_config)
        if checkpoint_file is not None:
            checkpoint = load_checkpoint(checkpoint_file,
                                         train_job_config.get("job.device"))
            job = Job.create_from(
                checkpoint=checkpoint,
                new_config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )
        else:
            job = Job.create(
                config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )

        # process the trace entries so far (in case of a resumed job)
        metric_name = search_job.config.get("valid.metric")
        valid_trace = []

        def copy_to_search_trace(job, trace_entry=None):
            # Copy a validation trace entry of the training job into the search
            # job's trace, enriched with the searched hyperparameter values
            # (trace_keys) and the search metric. Appends to `valid_trace`.
            if trace_entry is None:
                trace_entry = job.valid_trace[-1]
            # deep copy so we never mutate the training job's own trace entries
            trace_entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Process deprecated options to some extent. Support key renames, but
                # not value renames.
                actual_key = {key: None}
                _process_deprecated_options(actual_key)
                if len(actual_key) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                actual_key = next(iter(actual_key.keys()))
                value = train_job_config.get(actual_key)
                trace_entry[key] = value

            trace_entry["folder"] = os.path.split(train_job_config.folder)[1]
            metric_value = Trace.get_metric(trace_entry, metric_name)
            trace_entry["metric_name"] = metric_name
            trace_entry["metric_value"] = metric_value
            trace_entry["parent_job_id"] = search_job.job_id
            search_job.config.trace(**trace_entry)
            valid_trace.append(trace_entry)

        # replay trace entries already present (resumed job); `job` argument is
        # unused in this call because an explicit trace_entry is provided
        for trace_entry in job.valid_trace:
            copy_to_search_trace(None, trace_entry)

        # run the job (adding new trace entries as we go)
        # TODO make this less hacky (easier once integrated into SearchJob)
        from kge.job import ManualSearchJob
        if not isinstance(
                search_job,
                ManualSearchJob) or search_job.config.get("manual_search.run"):
            job.post_valid_hooks.append(copy_to_search_trace)
            job.run()
        else:
            search_job.config.log(
                "Skipping running of training job as requested by user.")
            return (train_job_index, None, None)

        # analyze the result: pick the best validation entry by the search metric
        search_job.config.log("Best result in this training job:")
        best = None
        best_metric = None
        for trace_entry in valid_trace:
            metric = trace_entry["metric_value"]
            if not best or Metric(search_job).better(metric, best_metric):
                best = trace_entry
                best_metric = metric

        # record the best result of this job; drop keys that would clash with
        # the search job's own trace fields
        best["child_job_id"] = best["job_id"]
        for k in ["job", "job_id", "type", "parent_job_id", "scope", "event"]:
            if k in best:
                del best[k]
        search_job.trace(
            event="search_completed",
            echo=True,
            echo_prefix=" ",
            log=True,
            scope="train",
            **best,
        )

        # force releasing the GPU memory of the job to avoid memory leakage
        del job
        gc.collect()

        return (train_job_index, best, best_metric)
    except BaseException as e:
        # BaseException (not Exception) so that e.g. KeyboardInterrupt in a
        # worker is also logged before being handled per `on_error` policy
        search_job.config.log("Trial {:05d} failed: {}".format(
            train_job_index, repr(e)))
        if search_job.on_error == "continue":
            return (train_job_index, None, None)
        else:
            search_job.config.log(
                "Aborting search due to failure of trial {:05d}".format(
                    train_job_index))
            raise e
def main():
    """Compare a trained KGE model against a frequency baseline, per relation.

    For every relation occurring in the test split, scores test triples in both
    prediction directions ("s" = subject, "o" = object) with the model and with
    a frequency baseline, prints mean metrics, and optionally writes a per-
    relation CSV (``args.csv``).
    """
    args = parse_args()

    # Load model checkpoint and data
    checkpoint = load_checkpoint(args.model_checkpoint, device="cpu")
    model_pt = kge.model.KgeModel.create_from(checkpoint)
    print("Loaded model from", args.model_checkpoint)
    dataset = model_pt.dataset

    # Load all data; all_spo is used for filtered ranking over every split
    train_spo, valid_spo, test_spo = [
        dataset.split(split) for split in ("train", "valid", "test")
    ]
    all_spo = torch.cat((train_spo, valid_spo, test_spo), axis=0).long()

    # Load relation ID to string mapping
    relation_ids = dataset.relation_ids()

    metric_names = ("mrr", "hits@10")
    # metrics_all[modelname][metric_name] -> list of per-triple metric values
    metrics_all = defaultdict(lambda: defaultdict(list))
    dfs = []
    # Keep track of percentage of test triples per relation type
    for rid in tqdm(torch.unique(test_spo[:, 1]), desc="Relation"):
        rid = rid.item()
        # Get all test triples with this relation (column 1 is the relation id)
        test_filt = test_spo[test_spo[:, 1] == rid]
        for direction in ["s", "o"]:
            # (?, r, t) and (h, r, ?)
            metrics_mean = defaultdict(dict)
            for modelname, score_fn in zip(
                    ["Model", "Baseline"],
                    [score_with_model, score_by_frequency]):
                # score test triples and evaluate rankings
                scores = score_fn(model_pt, test_filt, direction=direction)
                # NOTE(review): evaluate_rankings presumably returns one list
                # per entry of metric_names, in the same order — confirm
                model_metrics = evaluate_rankings(scores,
                                                  test_filt,
                                                  all_spo,
                                                  direction=direction)
                for metric_name, metric in zip(metric_names, model_metrics):
                    metrics_mean[modelname][metric_name] = np.mean(metric)
                    metrics_all[modelname][metric_name].extend(metric)
            for metric_name in metric_names:
                model_metric = metrics_mean["Model"][metric_name]
                baseline_metric = metrics_mean["Baseline"][metric_name]
                diff = model_metric - baseline_metric
                # one CSV row per (relation, metric, direction)
                line = dict(
                    relation=relation_ids[rid],
                    metric=metric_name,
                    direction=direction,
                    count=len(test_filt),
                    diff=diff,
                    model=model_metric,
                    baseline=baseline_metric,
                )
                if args.csv is not None:
                    dfs.append(
                        pd.DataFrame.from_dict(line,
                                               orient="index").transpose())
    if args.csv is not None:
        df = pd.concat(dfs)
        df.to_csv(args.csv, index=False)
        print("Saved results to", args.csv)
    # print micro-averaged metrics over all test triples
    for modelname in metrics_all:
        for metric, scores in metrics_all[modelname].items():
            print(modelname, metric, np.mean(scores))
def main():
    """Evaluate triple classification for one or more KGE models.

    Negatives are either generated ("uniform"/"frequency") or loaded from disk.
    Per-relation decision thresholds are fitted on the validation split and
    applied to the test split; prints accuracy/F1 and optionally writes a CSV.
    """
    args = parse_args()

    # Load first model, get dataset
    # Assumes all models trained on same data
    checkpoint = load_checkpoint(args.model_files[0], device="cpu")
    model = kge.model.KgeModel.create_from(checkpoint)
    dataset = model.dataset
    splits = ("valid", "test")
    valid_spo, test_spo = [dataset.split(split).long() for split in splits]
    if args.negative in ("uniform", "frequency"):
        valid_neg_spo, test_neg_spo = [
            generate_neg_spo(dataset, split, negative_type=args.negative)
            for split in splits
        ]
    else:
        valid_neg_spo, test_neg_spo = load_neg_spo(dataset, size=args.size)
    print(
        f"Loaded {len(valid_neg_spo)} valid negatives",
        f"and {len(test_neg_spo)} test negatives",
    )
    # positives first, then negatives — must match the label order of get_X_y
    valid_spo_all = torch.cat((valid_spo, valid_neg_spo))
    test_spo_all = torch.cat((test_spo, test_neg_spo))
    metrics = []
    dfs = []
    for model_file in args.model_files:
        if os.path.exists(model_file):
            checkpoint = load_checkpoint(model_file, device="cpu")
            model = kge.model.KgeModel.create_from(checkpoint)
            # Score negative and positive validation triples
            X_valid, y_valid = get_X_y(model, valid_spo, valid_neg_spo)
            X_test, y_test = get_X_y(model, test_spo, test_neg_spo)
            valid_relations = valid_spo_all[:, 1].unique()
            test_relations = test_spo_all[:, 1].unique()
            y_pred_valid = torch.zeros(y_valid.shape,
                                       dtype=torch.long,
                                       device="cpu")
            y_pred_test = torch.zeros(y_test.shape,
                                      dtype=torch.long,
                                      device="cpu")
            ############################################################################
            # begin credits to https://github.com/uma-pi1/kge/blob/triple_classification/kge/job/triple_classification.py#L302 #
            ############################################################################
            # REL_KEY is the fallback threshold key for relations that do not
            # occur in the validation split
            REL_KEY = -1
            thresholds = {
                r: -float("inf")
                for r in range(dataset.num_relations())
            }
            thresholds[REL_KEY] = -float("inf")
            for r in valid_relations:
                # set a threshold for each relation
                current_rel = valid_spo_all[:, 1] == r
                threshold = get_threshold(X_valid[current_rel],
                                          y_valid[current_rel])
                thresholds[r.item()] = threshold
                # triples scoring at or above the threshold are predicted true
                predictions = X_valid[current_rel] >= threshold
                y_pred_valid[current_rel] = predictions.view(-1).long()
            # also set a global threshold for relations unseen in valid set
            thresholds[REL_KEY] = get_threshold(X_valid, y_valid)
            for r in test_relations:
                # get predictions based on validation thresholds
                key = r.item() if r.item() in thresholds else REL_KEY
                threshold = thresholds[key]
                current_rel = test_spo_all[:, 1] == r
                predictions = X_test[current_rel] >= threshold
                y_pred_test[current_rel] = predictions.view(-1).long()
            ############################################################################
            # end credits #
            ############################################################################
            y_test = y_test.numpy()
            y_pred_test = y_pred_test.numpy()
            line = dict(
                valid_accuracy=accuracy_score(y_valid, y_pred_valid),
                valid_f1=f1_score(y_valid, y_pred_valid),
                test_accuracy=accuracy_score(y_test, y_pred_test),
                test_f1=f1_score(y_test, y_pred_test),
                model_file=model_file,
            )
            metrics.append(line)
            if args.csv is not None:
                dfs.append(
                    pd.DataFrame.from_dict(line, orient="index").transpose())
    if args.csv is not None:
        df = pd.concat(dfs)
        df.to_csv(args.csv, index=False)
        print("Saved results to", args.csv)
    for metric in metrics:
        for key, val in metric.items():
            print(f"{key}: {val}")
        print()
from util.eval import Evaluator
import json
import torch
from kge.model import KgeModel
from kge.util.io import load_checkpoint
import numpy as np

# Load pretrained LibKGE checkpoints of RESCAL, ComplEx, ConvE, DistMult and
# TransE trained on FB15K-237 and build entity/relation index mappings.
# (The previous header comment claimed WN18RR; the checkpoints loaded below
# are the FB15K-237 ones.)
models = ['rescal', 'complex', 'conve', 'distmult', 'transe']
for m in models:
    # The old `if m == 'conex': raise NotImplementedError()` branch was dead
    # code: 'conex' never occurs in `models`. Removed without behavior change.

    # 1. Load pretrained model via LibKGE
    checkpoint = load_checkpoint(
        f'pretrained_models/FB15K-237/fb15k-237-{m}.pt')
    model = KgeModel.create_from(checkpoint)

    # 2. Create mappings.
    # 2.1 Entity index mapping: entity id string -> contiguous index.
    entity_idxs = {
        e: e_idx for e_idx, e in enumerate(model.dataset.entity_ids())
    }
    # 2.2 Relation index mapping: relation id string -> contiguous index.
    relation_idxs = {
        r: r_idx for r_idx, r in enumerate(model.dataset.relation_ids())
    }
from util.eval import Evaluator
import json
import torch
from kge.model import KgeModel
from kge.util.io import load_checkpoint
import numpy as np

# Link prediction performances of RESCAL, ComplEx, ConvE, DistMult and TransE
# on WN18RR* (out-of-vocabulary entities are removed).
models = ['rescal', 'complex', 'conve', 'distmult', 'transe']
for m in models:
    # The old `if m == 'conex': raise NotImplementedError()` branch was dead
    # code: 'conex' never occurs in `models`. Removed without behavior change.

    # 1. Load pretrained model via LibKGE
    checkpoint = load_checkpoint(f'pretrained_models/WN18RR/wnrr-{m}.pt')
    model = KgeModel.create_from(checkpoint)

    # 2. Create mappings.
    # 2.1 Entity index mapping: entity id string -> contiguous index.
    entity_idxs = {
        e: e_idx for e_idx, e in enumerate(model.dataset.entity_ids())
    }
    # 2.2 Relation index mapping: relation id string -> contiguous index.
    relation_idxs = {
        r: r_idx for r_idx, r in enumerate(model.dataset.relation_ids())
    }
def run(self):
    """Entry point of one distributed train worker process.

    Initializes the torch.distributed process group, optionally starts a
    parameter server (lapse or shared memory), builds a per-worker config and
    job, runs the job, and then tears everything down in an order that lets
    the lapse workers/servers shut down cleanly.
    """
    # "cuda" without an index is normalized to "cuda:0"
    torch_device = self.config.get("job.device")
    if self.config.get("job.device") == "cuda":
        torch_device = "cuda:0"
    if torch_device != "cpu":
        torch.cuda.set_device(torch_device)

    # seeds need to be set in every process
    set_seeds(self.config, self.rank)

    os.environ["MASTER_ADDR"] = self.config.get("job.distributed.master_ip")
    os.environ["MASTER_PORT"] = self.config.get("job.distributed.master_port")
    # min_rank offsets worker ranks past the non-worker (e.g. scheduler) ranks
    min_rank = get_min_rank(self.config)
    print("before init", self.rank + min_rank)
    dist.init_process_group(
        backend="gloo",
        init_method="env://",
        world_size=self.num_total_workers + min_rank,
        rank=self.rank + min_rank,
        timeout=datetime.timedelta(hours=6),
    )
    # sub-group containing only the train workers
    worker_ranks = list(range(min_rank, self.num_total_workers + min_rank))
    worker_group = dist.new_group(worker_ranks,
                                  timeout=datetime.timedelta(hours=6))

    # create parameter server
    server = None
    if self.config.get("job.distributed.parameter_server") == "lapse":
        # lapse reads its topology from environment variables
        os.environ["DMLC_NUM_WORKER"] = "0"
        os.environ["DMLC_NUM_SERVER"] = str(self.num_total_workers)
        os.environ["DMLC_ROLE"] = "server"
        os.environ["DMLC_PS_ROOT_URI"] = self.config.get(
            "job.distributed.master_ip"
        )
        os.environ["DMLC_PS_ROOT_PORT"] = self.config.get(
            "job.distributed.lapse_port"
        )
        num_workers_per_server = 1
        lapse.setup(self.num_keys, num_workers_per_server)
        # each key stores embedding + optimizer state, hence the summed dim
        server = lapse.Server(self.num_keys,
                              self.embedding_dim + self.optimizer_dim)
    elif self.config.get("job.distributed.parameter_server") == "shared":
        server = self.parameters

    # create train-worker config, dataset and folder; devices are assigned
    # round-robin from the device pool
    device_pool: list = self.config.get("job.device_pool")
    if len(device_pool) == 0:
        device_pool.append(self.config.get("job.device"))
    worker_id = self.rank
    config = deepcopy(self.config)
    config.set("job.device", device_pool[worker_id % len(device_pool)])
    config.folder = os.path.join(self.config.folder, f"worker-{self.rank}")
    config.init_folder()
    dataset = deepcopy(self.dataset)

    parameter_client = KgeParameterClient.create(
        client_type=self.config.get("job.distributed.parameter_server"),
        server_id=0,
        client_id=worker_id + min_rank,
        embedding_dim=self.embedding_dim + self.optimizer_dim,
        server=server,
        num_keys=self.num_keys,
        num_meta_keys=self.num_meta_keys,
        worker_group=worker_group,
    )
    # don't re-initialize the model after loading checkpoint
    init_for_load_only = self.checkpoint_name is not None
    job = Job.create(
        config=config,
        dataset=dataset,
        parameter_client=parameter_client,
        init_for_load_only=init_for_load_only,
    )
    if self.checkpoint_name is not None:
        checkpoint = load_checkpoint(self.checkpoint_name)
        job._load(checkpoint)
        job.load_distributed(checkpoint_name=self.checkpoint_name)

    job.run()

    # all done, clean up
    print("shut down everything")
    parameter_client.barrier()
    if hasattr(job, "work_scheduler_client"):
        job.work_scheduler_client.shutdown()
    parameter_client.shutdown()
    # delete all occurrences of the parameter client to properly shutdown lapse
    # del job
    del job.parameter_client
    del job.model.get_s_embedder().parameter_client
    del job.model.get_p_embedder().parameter_client
    del job.model
    if hasattr(job, "optimizer"):
        del job.optimizer
    del parameter_client
    # make sure lapse-worker destructor is called
    gc.collect()

    # shutdown server (a shared-memory "server" is a plain tensor and has no
    # shutdown method)
    if server is not None and type(server) != torch.Tensor:
        server.shutdown()

    if self.result_pipe is not None:
        if hasattr(job, "valid_trace"):
            # if we valid from checkpoint there is no valid trace
            self.result_pipe.send(job.valid_trace)
        else:
            self.result_pipe.send(None)
def main():
    """Command-line entry point of the kge tool.

    Parses arguments (including unknown ``--key value`` overrides), resolves
    meta-commands (create/eval/test/valid), handles the dump/package/start/
    resume commands, applies command-line overrides to the configuration,
    seeds all PRNGs, and finally creates and runs the requested job.
    """
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands (they rewrite args into a base command)
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()
    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )
        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        # a folder may be given instead of a config file
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        # skip CLI-only keys that are not config options
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                # coerce string values for boolean options
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                # import the model's own config module
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:
            # means: set default (timestamped folder under local/experiments)
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix=" ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix=" ")

        # set random seeds
        def get_seed(what):
            # Resolve the seed for PRNG `what`; negative per-PRNG seeds fall
            # back to a value derived from random_seed.default.
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get(f"random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different PRNGs get a
                # different seed
                seed = (
                    config.get(f"random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # stay 32-bit
            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            # numba has its own PRNG state; seed it inside a jitted function
            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        # log the traceback into the job folder, then re-raise for the caller
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
def train(data_path, neg_batch_size, batch_size, shuffle, num_workers, nb_epochs, embedding_dim, hidden_dim, relation_dim, gpu, use_cuda, patience, freeze, validate_every, hops, lr, entdrop, reldrop, scoredrop, l3_reg, model_name, decay, ls, load_from, outfile, do_batch_norm, valid_data_path=None):
    """Train the MetaQA RelationExtractor on top of pretrained ComplEx embeddings.

    Runs `validate_every` training epochs between validation passes; stops
    early (via exit()) once `patience` validations pass without improvement,
    or at the final epoch.

    NOTE(review): `neg_batch_size`, `shuffle`, `hidden_dim` and `outfile` are
    not used within this function body (DataLoader shuffling is hard-coded to
    True; the save calls using `outfile` are commented out) — confirm intent.
    """
    print('Loading entities and relations')
    # 'half' in the hops string selects the half-KG embedding checkpoint
    kg_type = 'full'
    if 'half' in hops:
        kg_type = 'half'
    checkpoint_file = '../../pretrained_models/embeddings/ComplEx_fbwq_' + kg_type + '/checkpoint_best.pt'
    print('Loading kg embeddings from', checkpoint_file)
    kge_checkpoint = load_checkpoint(checkpoint_file)
    kge_model = KgeModel.create_from(kge_checkpoint)
    kge_model.eval()
    e = getEntityEmbeddings(kge_model, hops)
    print('Loaded entities and relations')

    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)
    data = process_text_file(data_path, split=False)
    print('Train file processed, making dataloader')
    # word2ix,idx2word, max_len = get_vocab(data)
    # hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    dataset = DatasetMetaQA(data, e, entity2idx)
    data_loader = DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=num_workers)
    print('Creating model...')
    model = RelationExtractor(embedding_dim=embedding_dim,
                              num_entities=len(idx2entity),
                              relation_dim=relation_dim,
                              pretrained_embeddings=embedding_matrix,
                              freeze=freeze,
                              device=device,
                              entdrop=entdrop,
                              reldrop=reldrop,
                              scoredrop=scoredrop,
                              l3_reg=l3_reg,
                              model=model_name,
                              ls=ls,
                              do_batch_norm=do_batch_norm)
    print('Model created!')
    if load_from != '':
        # model.load_state_dict(torch.load("checkpoints/roberta_finetune/" + load_from + ".pt"))
        fname = "checkpoints/roberta_finetune/" + load_from + ".pt"
        # map_location keeps the load on CPU; moved to `device` below
        model.load_state_dict(
            torch.load(fname, map_location=lambda storage, loc: storage))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = ExponentialLR(optimizer, decay)
    optimizer.zero_grad()
    best_score = -float("inf")
    # NOTE(review): state_dict() returns references to live tensors, so
    # best_model tracks ongoing training — confirm whether a deep copy was
    # intended before un-commenting the torch.save calls below
    best_model = model.state_dict()
    no_update = 0
    # time.sleep(10)
    for epoch in range(nb_epochs):
        # `validate_every` train phases followed by one valid phase
        phases = []
        for i in range(validate_every):
            phases.append('train')
        phases.append('valid')
        for phase in phases:
            if phase == 'train':
                model.train()
                # model.apply(set_bn_eval)
                loader = tqdm(data_loader,
                              total=len(data_loader),
                              unit="batches")
                running_loss = 0
                for i_batch, a in enumerate(loader):
                    model.zero_grad()
                    # batch layout: (question tokens, attention mask, head, tail)
                    question_tokenized = a[0].to(device)
                    attention_mask = a[1].to(device)
                    positive_head = a[2].to(device)
                    positive_tail = a[3].to(device)
                    loss = model(question_tokenized=question_tokenized,
                                 attention_mask=attention_mask,
                                 p_head=positive_head,
                                 p_tail=positive_tail)
                    loss.backward()
                    optimizer.step()
                    running_loss += loss.item()
                    loader.set_postfix(Loss=running_loss /
                                       ((i_batch + 1) * batch_size),
                                       Epoch=epoch)
                    loader.set_description('{}/{}'.format(epoch, nb_epochs))
                    loader.update()
                scheduler.step()
            elif phase == 'valid':
                model.eval()
                # eps guards against counting float-noise as an improvement
                eps = 0.0001
                answers, score = validate_v2(model=model,
                                             data_path=valid_data_path,
                                             entity2idx=entity2idx,
                                             train_dataloader=dataset,
                                             device=device,
                                             model_name=model_name)
                if score > best_score + eps:
                    best_score = score
                    no_update = 0
                    best_model = model.state_dict()
                    print(
                        hops +
                        " hop Validation accuracy (no relation scoring) increased from previous epoch",
                        score)
                    # writeToFile(answers, 'results_' + model_name + '_' + hops + '.txt')
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                elif (score < best_score + eps) and (no_update < patience):
                    no_update += 1
                    print(
                        "Validation accuracy decreases to %f from %f, %d more epoch to check"
                        % (score, best_score, patience - no_update))
                elif no_update == patience:
                    # early stopping: patience exhausted
                    print(
                        "Model has exceed patience. Saving best model and exiting"
                    )
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
                if epoch == nb_epochs - 1:
                    print("Final Epoch has reached. Stoping and saving model.")
                    # torch.save(best_model, "checkpoints/roberta_finetune/best_score_model.pt")
                    # torch.save(best_model, "checkpoints/roberta_finetune/" + outfile + ".pt")
                    exit()
def perform_experiment(data_path, mode, neg_batch_size, batch_size, shuffle, num_workers, nb_epochs, embedding_dim, hidden_dim, relation_dim, gpu, use_cuda, patience, freeze, validate_every, hops, lr, entdrop, reldrop, scoredrop, l3_reg, model_name, decay, ls, load_from, outfile, do_batch_norm, que_embedding_model, valid_data_path=None, test_data_path=None):
    """Train or evaluate the WebQSP RelationExtractor on pretrained KGE embeddings.

    `mode` == 'train': runs the train/valid loop with early stopping (patience)
    and saves the best checkpoint. `mode` == 'test': loads the saved checkpoint,
    evaluates on `test_data_path` and appends a row to final_results.csv.

    NOTE(review): `neg_batch_size`, `shuffle`, `hidden_dim` and `data_path`
    (in test mode) are not used within this function body — confirm intent.
    """
    webqsp_checkpoint_folder = f"../../checkpoints/WebQSP/{model_name}_{que_embedding_model}_{outfile}/"
    if not os.path.exists(webqsp_checkpoint_folder):
        os.makedirs(webqsp_checkpoint_folder)

    print('Loading entities and relations')
    # 'half' in the hops string selects the half-KG embedding checkpoint
    kg_type = 'full'
    if 'half' in hops:
        kg_type = 'half'
    checkpoint_file = f"../../pretrained_models/embeddings/{model_name}_fbwq_{kg_type}/checkpoint_best.pt"
    print('Loading kg embeddings from', checkpoint_file)
    kge_checkpoint = load_checkpoint(checkpoint_file)
    kge_model = KgeModel.create_from(kge_checkpoint)
    kge_model.eval()
    e = getEntityEmbeddings(model_name, kge_model, hops)
    print('Loaded entities and relations')

    entity2idx, idx2entity, embedding_matrix = prepare_embeddings(e)
    # word2ix,idx2word, max_len = get_vocab(data)
    # hops = str(num_hops)
    device = torch.device(gpu if use_cuda else "cpu")
    model = RelationExtractor(embedding_dim=embedding_dim,
                              num_entities=len(idx2entity),
                              relation_dim=relation_dim,
                              pretrained_embeddings=embedding_matrix,
                              freeze=freeze,
                              device=device,
                              entdrop=entdrop,
                              reldrop=reldrop,
                              scoredrop=scoredrop,
                              l3_reg=l3_reg,
                              model=model_name,
                              que_embedding_model=que_embedding_model,
                              ls=ls,
                              do_batch_norm=do_batch_norm)
    # time.sleep(10)
    if mode == 'train':
        data = process_text_file(data_path)
        dataset = DatasetWebQSP(data, e, entity2idx, que_embedding_model,
                                model_name)
        # if model_name=="ComplEx":
        #     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        # else:
        #     data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=custom_collate_fn)
        data_loader = DataLoader(dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=num_workers)
        if load_from != '':
            # model.load_state_dict(torch.load("checkpoints/roberta_finetune/" + load_from + ".pt"))
            fname = f"checkpoints/{que_embedding_model}_finetune/{load_from}.pt"
            # map_location keeps the load on CPU; moved to `device` below
            model.load_state_dict(
                torch.load(fname, map_location=lambda storage, loc: storage))
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        scheduler = ExponentialLR(optimizer, decay)
        optimizer.zero_grad()
        best_score = -float("inf")
        # NOTE(review): state_dict() returns references to live tensors, so
        # best_model tracks ongoing training — confirm whether a deep copy was
        # intended before the torch.save calls below
        best_model = model.state_dict()
        no_update = 0
        for epoch in range(nb_epochs):
            # `validate_every` train phases followed by one valid phase
            phases = []
            for i in range(validate_every):
                phases.append('train')
            phases.append('valid')
            for phase in phases:
                if phase == 'train':
                    model.train()
                    # model.apply(set_bn_eval)
                    loader = tqdm(data_loader,
                                  total=len(data_loader),
                                  unit="batches")
                    running_loss = 0
                    for i_batch, a in enumerate(loader):
                        model.zero_grad()
                        # batch layout: (question tokens, attention mask, head, tail)
                        question_tokenized = a[0].to(device)
                        attention_mask = a[1].to(device)
                        positive_head = a[2].to(device)
                        positive_tail = a[3].to(device)
                        loss = model(question_tokenized=question_tokenized,
                                     attention_mask=attention_mask,
                                     p_head=positive_head,
                                     p_tail=positive_tail)
                        loss.backward()
                        optimizer.step()
                        running_loss += loss.item()
                        loader.set_postfix(Loss=running_loss /
                                           ((i_batch + 1) * batch_size),
                                           Epoch=epoch)
                        loader.set_description('{}/{}'.format(
                            epoch, nb_epochs))
                        loader.update()
                    scheduler.step()
                elif phase == 'valid':
                    model.eval()
                    # eps guards against counting float-noise as an improvement
                    eps = 0.0001
                    answers, score = test(model=model,
                                          data_path=valid_data_path,
                                          entity2idx=entity2idx,
                                          dataloader=dataset,
                                          device=device,
                                          model_name=model_name,
                                          return_hits_at_k=False)
                    if score > best_score + eps:
                        best_score = score
                        no_update = 0
                        best_model = model.state_dict()
                        print(
                            hops +
                            " hop Validation accuracy (no relation scoring) increased from previous epoch",
                            score)
                        writeToFile(
                            answers,
                            f'results/{model_name}_{que_embedding_model}_{outfile}.txt'
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                    elif (score < best_score + eps) and (no_update < patience):
                        no_update += 1
                        print(
                            "Validation accuracy decreases to %f from %f, %d more epoch to check"
                            % (score, best_score, patience - no_update))
                    elif no_update == patience:
                        # early stopping: patience exhausted
                        print(
                            "Model has exceed patience. Saving best model and exiting"
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                        exit(0)
                    if epoch == nb_epochs - 1:
                        print(
                            "Final Epoch has reached. Stoping and saving model."
                        )
                        torch.save(
                            best_model,
                            get_chkpt_path(model_name, que_embedding_model,
                                           outfile))
                        exit()
                    # torch.save(model.state_dict(), "checkpoints/roberta_finetune/"+str(epoch)+".pt")
                    # torch.save(model.state_dict(), "checkpoints/roberta_finetune/x.pt")
    elif mode == 'test':
        data = process_text_file(test_data_path)
        dataset = DatasetWebQSP(data, e, entity2idx, que_embedding_model,
                                model_name)
        model_chkpt_file_path = get_chkpt_path(model_name, que_embedding_model,
                                               outfile)
        model.load_state_dict(
            torch.load(model_chkpt_file_path,
                       map_location=lambda storage, loc: storage))
        model.to(device)
        # inference only: freeze all parameters
        for parameter in model.parameters():
            parameter.requires_grad = False
        model.eval()
        answers, accuracy, hits_at_1, hits_at_5, hits_at_10 = test(
            model=model,
            data_path=test_data_path,
            entity2idx=entity2idx,
            dataloader=dataset,
            device=device,
            model_name=model_name,
            return_hits_at_k=True)
        # append one result row per run (no header) to the shared CSV
        d = {
            'KG-Model': model_name,
            'KG-Type': kg_type,
            'Que-Embedding-Model': que_embedding_model,
            'Accuracy': [accuracy],
            'Hits@1': [hits_at_1],
            'Hits@5': [hits_at_5],
            'Hits@10': [hits_at_10]
        }
        df = pd.DataFrame(data=d)
        df.to_csv(f"final_results.csv", mode='a', index=False, header=False)