def main(_run: Run, _config):
    p = Params(**_config)
    agent = create_agent(p)
    # load_agent(761, 200, agent)
    task = create_task(p)
    inner_loop = LearningLoop()
    n_experts = p.n_experts

    for epoch in tqdm(range(1, p.epochs + 1)):
        # with torch.autograd.detect_anomaly():
        should_save_images = epoch % p.image_save_period == 0
        observer = MultiObserver() if should_save_images else None

        # if epoch == 10000:
        #     task = create_task(p, 0.5)
        #
        # if epoch == 20000:
        #     task = create_task(p, 0.01)

        # if epoch % 100 == 0:
        #     task_size = random.randint(n_experts, n_experts + 10)
        #     p.task_size = task_size
        #     task = create_task(p)

        # reset_ids = epoch > 10000 or (epoch - 1) % 2000 == 0
        reset_ids = True
        agent.optim.zero_grad()
        agent.init_rollout(p.batch_size, n_experts, SearchAgentInitRolloutParams(reset_ids))

        rollout_size = p.rollout_size
        # if epoch > 10000:
        #     rollout_size = rollout_size + random.randint(-5, 5)

        err = inner_loop.train_fixed_steps(agent, task, rollout_size, p.learning_exp_decay,
                                           p.learning_rollout_steps_clip, observer)
        err.backward()
        _run.log_scalar('loss', err.cpu().detach().item())

        # save grads
        # if observer is not None:
        #     params = [p for p in agent.parameters() if p.grad is not None]
        #     for i, param in enumerate(params):
        #         observer.current.add_tensor(f'params_{i}', param.detach().cpu())
        #         observer.current.add_tensor(f'params-grads_{i}', param.grad.detach().cpu())

        agent.optim.step()

        if observer is not None:
            # print(f'Saving plots ep: {epoch}')
            tensors = [o.tensors_as_dict() for o in observer.observers]
            sacred_writer.save_tensor(tensors, 'tensors', epoch)

        # log targets
        if epoch % p.save_period == 0:
            sacred_writer.save_model(agent, 'agent', epoch)

def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False):
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)

    # --------- configuration process -------------------
    distribute_named_configs(scaffolding, named_configs)
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding)
    past_paths = set()
    for scaffold in scaffolding.values():
        scaffold.pick_relevant_config_updates(config_updates, past_paths)
        past_paths.add(scaffold.path)
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        config_updates = scaffold.run_config_hooks(
            config, config_updates, command_name, run_logger)

    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)
    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run

def _save_fitable(self, run: Run, fitable: Model):
    """
    :param run: sacred.Run object. See sacred documentation for more details on utility.
    :param fitable: tensorflow.keras.Model object.
    """
    path = self.exp_config["run_config"]["model_path"]
    if self.exp_config["run_config"]["save_verbosity"] > 0:
        fitable.summary()
    fitable.save(path)
    run.add_artifact(path)

def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False):
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)

    # --------- configuration process -------------------
    distribute_named_configs(scaffolding, named_configs)
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding)
    past_paths = set()
    for scaffold in scaffolding.values():
        scaffold.pick_relevant_config_updates(config_updates, past_paths)
        past_paths.add(scaffold.path)
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        config_updates = scaffold.run_config_hooks(
            config, config_updates, command_name, run_logger)

    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)
    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              experiment.observers, root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run

def create_run(experiment, command_name, config_updates=None,
               log_level=None, named_configs=()):
    scaffolding = create_scaffolding(experiment)
    distribute_config_updates(scaffolding, config_updates)
    distribute_named_configs(scaffolding, named_configs)

    for scaffold in scaffolding.values():
        scaffold.set_up_config()

    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    experiment_info = experiment._get_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    logger = initialize_logging(experiment, scaffolding, log_level)
    run = Run(config, config_modifications, main_function,
              experiment.observers, logger, experiment_info, host_info)

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run

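# Illustrative sketch only (not part of the snippets above): how a Run produced by one of
# the create_run variants is typically obtained and executed. Experiment.run() performs
# these same steps internally; the experiment name, the command name 'main' and the config
# update used here are placeholder assumptions, as is importing create_run from
# sacred.initialize.
from sacred import Experiment
from sacred.initialize import create_run

ex = Experiment('my_experiment')   # assumes a captured @ex.main function named 'main'
run = create_run(ex, 'main', config_updates={'a': 17})
result = run()                     # a sacred Run is callable: it executes the main function
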
def run():
    config = {'a': 17, 'foo': {'bar': True, 'baz': False}, 'seed': 1234}
    config_mod = ConfigSummary()
    main_func = mock.Mock(return_value=123)
    logger = mock.Mock()
    observer = [mock.Mock()]
    return Run(config, config_mod, main_func, observer, logger, logger, {}, {}, [], [])

def run(): config = {"a": 17, "foo": {"bar": True, "baz": False}, "seed": 1234} config_mod = ConfigSummary() signature = mock.Mock() signature.name = "main_func" main_func = mock.Mock(return_value=123, prefix="", signature=signature) logger = mock.Mock() observer = [mock.Mock(priority=10)] return Run(config, config_mod, main_func, observer, logger, logger, {}, {}, [], [])
def run():
    config = {'a': 17, 'foo': {'bar': True, 'baz': False}, 'seed': 1234}
    config_mod = ConfigSummary()
    signature = mock.Mock()
    signature.name = 'main_func'
    main_func = mock.Mock(return_value=123, prefix='', signature=signature)
    logger = mock.Mock()
    observer = [mock.Mock(priority=10)]
    return Run(config, config_mod, main_func, observer, logger, logger, {}, {}, [], [])

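# Illustrative sketch only: how fixtures like the run() helpers above are typically used
# in a test. Assumes pytest-style fixture injection and that sacred's Run object is
# callable (calling it invokes main_function and returns its result).
def test_run_returns_main_result(run):
    assert run() == 123                      # main_func was mocked with return_value=123
    run.main_function.assert_called_once()   # the captured main function ran exactly once
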
def train(train_corpus: str, dev_corpus: str, char_int: int, save_path: str,
          test_corpus: str = None, dropout: float = 0.5, num_epochs: int = 10,
          lm_loss_scale=0.1, device: int = 0, save=False, _run: Run = None):
    _run.add_resource(train_corpus)
    _run.add_resource(dev_corpus)
    trainer = TrainerMod(_run, train_corpus, save_path, dev_corpus,
                         num_epochs=num_epochs, dropout=dropout,
                         char_integration_method=char_int,
                         lm_loss_scale=lm_loss_scale, save=save, device=device)
    trainer.run()
    if test_corpus:
        _run.add_resource(test_corpus)
        ex.run_command('test', config_updates={
            'save_path': save_path,
            'test_corpus': test_corpus,
            'device': device
        })

def run(): config = {"a": 17, "foo": {"bar": True, "baz": False}, "seed": 1234} config_mod = ConfigSummary() signature = mock.Mock() signature.name = "main_func" def side_effect(*args): # TODO : Type checking ? Does mock have a function for this ? for arg in args: arg = arg + 10 return args main_func = mock.Mock(return_value=123, prefix="", signature=signature, side_effect=side_effect) logger = mock.Mock() observer = [mock.Mock(priority=10)] return Run(config, config_mod, main_func, observer, logger, logger, {}, {}, [], [])
def train_w_pretrained(train_corpus: str, dev_corpus: str, char_int: int,
                       pretrained_embeddings: str, save_path: str,
                       test_corpus: str = None, word_embedding_size: int = 300,
                       update_pretrained_embedding: bool = True,
                       dropout: float = 0.5, num_epochs: int = 10,
                       lm_loss_scale=0.1, device: int = 0, save=False,
                       _run: Run = None):
    _run.add_resource(train_corpus)
    _run.add_resource(dev_corpus)
    trainer = TrainerMod(
        _run, train_corpus, save_path, dev_corpus,
        word_embedding_size=word_embedding_size,
        num_epochs=num_epochs, dropout=dropout,
        char_integration_method=char_int,
        lm_loss_scale=lm_loss_scale, save=save, device=device,
        pretrained_embeddings=pretrained_embeddings,
        update_pretrained_embedding=update_pretrained_embedding,
        model_class=NewSequenceLabeler)
    trainer.run()
    if test_corpus:
        _run.add_resource(test_corpus)
        ex.run_command('test_w_pretrained', config_updates={
            'save_path': save_path,
            'test_corpus': test_corpus,
            'device': device
        })

def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False):
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                      reverse=True, key=lambda p: len(p))

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates,
                               join_paths(scaff.path, ncfg_key),
                               value)

    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        config_updates = scaffold.run_config_hooks(
            config, config_updates, command_name, run_logger)

    # Phase 4: finalize seeding
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run

def test(model_filename: str, test_corpus: str, pacrf: str, window_size: int = 5,
         _run: Run = None, _log: logger = None):
    _run.add_resource(test_corpus)
    _run.add_resource(f'{model_filename}.pkl')
    test_sents, _ = get_tagged_sents_and_words(test_corpus)
    X_test = [sent2features(s, window_size) for s in test_sents]
    y_test = [sent2labels(s) for s in test_sents]

    _log.info(f'load from: {model_filename}.pkl')
    # TODO modified this to call partial-crf via Popen command
    crf = sklearn_crfsuite.CRF(model_filename=model_filename)

    # TODO modified this to call partial-crf via Popen command
    y_pred = crf.predict(X_test)

    # TODO modified this to read partial-crf via tempfile
    overall, by_type = evaluate(y_test, y_pred)

    _run.info[f'overall_f1'] = overall.f1_score
    _run.log_scalar('overall_f1', overall.f1_score)
    _run.info[f'overall_precision'] = overall.precision
    _run.log_scalar('overall_precision', overall.precision)
    _run.info[f'overall_recall'] = overall.recall
    _run.log_scalar('overall_recall', overall.recall)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for _, key in enumerate(sorted(by_type.keys())):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')

def train(train_corpus: str, dev_corpus: str, pacrf: str, model_filename: str,
          labels: List, c1: float = 0.0, c2: float = 1.0, algorithm: str = 'lbfgs',
          max_iterations: int = None, all_possible_transitions: bool = False,
          window_size: int = 0, _run: Run = None, _log: logger = None):
    """ running crf experiment """
    _run.add_resource(train_corpus)
    _run.add_resource(dev_corpus)
    train_sents, _ = get_tagged_sents_and_words(train_corpus)
    dev_sents, _ = get_tagged_sents_and_words(dev_corpus)

    tmp_train = tempfile.NamedTemporaryFile(mode='w+')
    # temp_train_corpus = open(f'{model_filename}-{train_corpus}.feature', mode='w+')
    print_corpus(train_sents, labels, tmp_train, window_size=window_size)

    # X_dev = [sent2features(s, window_size) for s in dev_sents]
    y_dev = [sent2labels_colmap(s, col=1) for s in dev_sents]
    tmp_dev = tempfile.NamedTemporaryFile(mode='w+')
    # temp_test_corpus = open(f'{model_filename}-{test_corpus}.feature', mode='w+')
    print_corpus(dev_sents, labels, tmp_dev, window_size=window_size)

    # to call partial-crf via Popen command
    # command = f'{pacrf} learn -m {model_filename} -a {algorithm} {temp_train_corpus}'
    # call([pacrf, "--help"])
    crfsuite_proc = Popen([pacrf, "learn", "-m", model_filename, "-a", algorithm,
                           "-p", f"c1={c1}", "-p", f"c2={c2}", tmp_train.name])
    out, err = crfsuite_proc.communicate()
    print(out)
    print(err)
    # os.system(f'{pacrf} learn -m {model_filename} -a {algorithm} {tmp_train.name}')
    tmp_train.close()

    tmp_pred = tempfile.NamedTemporaryFile(mode='w+')
    # cmd_out([pacrf, "tag", "-m", model_filename, tmp_dev.name, ">", tmp_pred.name])
    _run.add_artifact(model_filename)

    # TODO modified this to call partial-crf via Popen command
    # y_pred = crf.predict(X_dev)
    y_pred = get_tagged_sents_and_words(tmp_pred.name)
    print(y_pred)
    y_pred = [sent2labels_colmap(s, 0) for s in y_pred]

    # TODO modified this to read partial-crf via tempfile
    overall, by_type = evaluate(y_dev, y_pred)
    tmp_pred.close()
    tmp_dev.close()

    _run.info[f'overall_f1'] = overall.f1_score
    _run.log_scalar('overall_f1', overall.f1_score)
    _run.info[f'overall_precision'] = overall.precision
    _run.log_scalar('overall_precision', overall.precision)
    _run.info[f'overall_recall'] = overall.recall
    _run.log_scalar('overall_recall', overall.recall)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for _, key in enumerate(sorted(by_type.keys())):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')

def create_run(experiment, command_name, config_updates=None,
               named_configs=(), force=False, log_level=None):
    sorted_ingredients = gather_ingredients_topological(experiment)
    scaffolding = create_scaffolding(experiment, sorted_ingredients)
    # get all split non-empty prefixes sorted from deepest to shallowest
    prefixes = sorted([s.split('.') for s in scaffolding if s != ''],
                      reverse=True, key=lambda p: len(p))

    # --------- configuration process -------------------

    # Phase 1: Config updates
    config_updates = config_updates or {}
    config_updates = convert_to_nested_dict(config_updates)
    root_logger, run_logger = initialize_logging(experiment, scaffolding, log_level)
    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 2: Named Configs
    for ncfg in named_configs:
        scaff, cfg_name = get_scaffolding_and_config_name(ncfg, scaffolding)
        scaff.gather_fallbacks()
        ncfg_updates = scaff.run_named_config(cfg_name)
        distribute_presets(prefixes, scaffolding, ncfg_updates)
        for ncfg_key, value in iterate_flattened(ncfg_updates):
            set_by_dotted_path(config_updates,
                               join_paths(scaff.path, ncfg_key),
                               value)

    distribute_config_updates(prefixes, scaffolding, config_updates)

    # Phase 3: Normal config scopes
    for scaffold in scaffolding.values():
        scaffold.gather_fallbacks()
        scaffold.set_up_config()

        # update global config
        config = get_configuration(scaffolding)
        # run config hooks
        config_hook_updates = scaffold.run_config_hooks(
            config, command_name, run_logger)
        recursive_update(scaffold.config, config_hook_updates)

    # Phase 4: finalize seeding
    for scaffold in reversed(list(scaffolding.values())):
        scaffold.set_up_seed()  # partially recursive

    config = get_configuration(scaffolding)
    config_modifications = get_config_modifications(scaffolding)

    # ----------------------------------------------------

    experiment_info = experiment.get_experiment_info()
    host_info = get_host_info()
    main_function = get_command(scaffolding, command_name)
    pre_runs = [pr for ing in sorted_ingredients for pr in ing.pre_run_hooks]
    post_runs = [pr for ing in sorted_ingredients for pr in ing.post_run_hooks]

    run = Run(config, config_modifications, main_function,
              copy(experiment.observers), root_logger, run_logger,
              experiment_info, host_info, pre_runs, post_runs,
              experiment.captured_out_filter)

    if hasattr(main_function, 'unobserved'):
        run.unobserved = main_function.unobserved

    run.force = force

    for scaffold in scaffolding.values():
        scaffold.finalize_initialization(run=run)

    return run

def train(train_corpus: str, dev_corpus: str, c1: float = 0.0, c2: float = 0.0,
          algorithm: str = 'lbfgs', max_iterations: int = 100,
          all_possible_transitions: bool = False, window_size: int = 1,
          model_filename: str = None, _run: Run = None, _log: logger = None):
    """ running crf experiment """
    _run.add_resource(train_corpus)
    _run.add_resource(dev_corpus)
    train_sents, _ = get_tagged_sents_and_words(train_corpus)
    dev_sents, _ = get_tagged_sents_and_words(dev_corpus)

    X_train = [sent2features(s, window_size) for s in train_sents]
    y_train = [sent2labels(s) for s in train_sents]
    X_dev = [sent2features(s, window_size) for s in dev_sents]
    y_dev = [sent2labels(s) for s in dev_sents]

    crf = sklearn_crfsuite.CRF(
        algorithm=algorithm,
        c1=c1,
        c2=c2,
        max_iterations=max_iterations,
        all_possible_transitions=all_possible_transitions,
        model_filename=model_filename,
    )
    crf.fit(X_train, y_train)
    y_pred = crf.predict(X_dev)
    overall, by_type = evaluate(y_dev, y_pred)

    _run.info[f'overall_f1'] = overall.f1_score
    _run.log_scalar('overall_f1', overall.f1_score)
    _run.info[f'overall_precision'] = overall.precision
    _run.log_scalar('overall_precision', overall.precision)
    _run.info[f'overall_recall'] = overall.recall
    _run.log_scalar('overall_recall', overall.recall)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for _, key in enumerate(sorted(by_type.keys())):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')

    if model_filename is not None:
        _log.info(f'saving to: {model_filename}.pkl')
        joblib.dump(crf, f'{model_filename}.pkl')
        _run.add_artifact(f'{model_filename}.pkl')

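# Illustrative refactoring sketch only: the overall/per-type metric logging block that is
# repeated across the CRF snippets above could be pulled into a helper such as this
# hypothetical log_ner_metrics(); evaluate()'s return values, the Run and the logger are
# assumed to behave exactly as they do in those snippets.
def log_ner_metrics(_run, _log, overall, by_type):
    for name, value in [('overall_f1', overall.f1_score),
                        ('overall_precision', overall.precision),
                        ('overall_recall', overall.recall)]:
        _run.info[name] = value
        _run.log_scalar(name, value)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for key in sorted(by_type.keys()):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')
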
def embedding_generation(dataset, dim, model, rel_model, loss_fn, encoder_name,
                         regularizer, max_len, num_negatives, batch_size,
                         emb_batch_size, max_epochs, checkpoint, use_cached_text,
                         _run: Run, _log: Logger):
    drop_stopwords = model in {'bert-bow', 'bert-dkrl', 'glove-bow', 'glove-dkrl'}

    # converted KG as input
    triples_file = f'data/{dataset}/all-triples.tsv'

    if device != torch.device('cpu'):
        num_devices = torch.cuda.device_count()
        if batch_size % num_devices != 0:
            raise ValueError(f'Batch size ({batch_size}) must be a multiple of'
                             f' the number of CUDA devices ({num_devices})')
        _log.info(f'CUDA devices used: {num_devices}')
    else:
        num_devices = 1
        _log.info('Training on CPU')

    if model == 'transductive':
        train_data = GraphDataset(triples_file, num_negatives,
                                  write_maps_file=True,
                                  num_devices=num_devices)
    else:
        if model.startswith('bert') or model == 'blp':
            tokenizer = BertTokenizer.from_pretrained(encoder_name)
        else:
            tokenizer = GloVeTokenizer('data/glove/glove.6B.300d-maps.pt')

        train_data = TextGraphDataset(triples_file, num_negatives, max_len,
                                      tokenizer, drop_stopwords,
                                      write_maps_file=True,
                                      use_cached_text=use_cached_text,
                                      num_devices=num_devices)

    # train_loader = DataLoader(train_data, batch_size, shuffle=True,
    #                           collate_fn=train_data.collate_fn,
    #                           num_workers=0, drop_last=True)

    # Build graph with all triples to compute filtered metrics
    graph = nx.MultiDiGraph()
    all_triples = torch.tensor(train_data.triples)
    graph.add_weighted_edges_from(all_triples.tolist())
    train_ent = set(train_data.entities.tolist())

    _run.log_scalar('num_train_entities', len(train_ent))

    train_ent = torch.tensor(list(train_ent))

    model = utils.get_model(model, dim, rel_model, loss_fn, len(train_ent),
                            train_data.num_rels, encoder_name, regularizer)

    if device != torch.device('cpu'):
        model = torch.nn.DataParallel(model).to(device)

    tokens = str(dataset).split("_")
    print(tokens[-1])

    # load language model
    if tokens[-1] == "questions":
        model.load_state_dict(torch.load("models/model-questions.pt"))
    else:
        model.load_state_dict(torch.load("models/model-entities.pt"))

    _log.info('Evaluating on training set, Embedding generation')
    ent_emb = embedding(model, train_data, train_ent, emb_batch_size,
                        filtering_graph=None)

    # Save final entity embeddings obtained with trained encoder
    torch.save(ent_emb, osp.join(OUT_PATH, f'ent_emb-{_run._id}.pt'))
    torch.save(train_ent, osp.join(OUT_PATH, f'ents-{_run._id}.pt'))

def test(test_corpus: str, model_output: str, col_ref: int = 0, col_hyp: int = 0,
         _run: Run = None, _log: logger = None):
    test_sents, _ = get_tagged_sents_and_words(test_corpus)
    print(f'num sentences: {len(test_sents)}')
    y_test = [sent2labels_colmap(s, col=int(col_ref)) for s in test_sents]

    yout_sents, _ = get_tagged_sents_and_words(model_output)
    print(f'num sentences: {len(yout_sents)}')
    y_pred = [sent2labels_colmap(s, col=int(col_hyp)) for s in yout_sents]

    if len(y_test) != len(y_pred):
        for i, j in zip_longest(y_test, y_pred):
            print(i, j)

    overall, by_type = evaluate(y_test, y_pred)
    print(overall)
    print(by_type)

    _run.info[f'overall_f1'] = overall.f1_score
    _run.log_scalar('overall_f1', overall.f1_score)
    _run.info[f'overall_precision'] = overall.precision
    _run.log_scalar('overall_precision', overall.precision)
    _run.info[f'overall_recall'] = overall.recall
    _run.log_scalar('overall_recall', overall.recall)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for _, key in enumerate(sorted(by_type.keys())):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')

def link_prediction(dataset, inductive, dim, model, rel_model, loss_fn,
                    encoder_name, regularizer, max_len, num_negatives, lr,
                    use_scheduler, batch_size, emb_batch_size, eval_batch_size,
                    max_epochs, checkpoint, use_cached_text, _run: Run, _log: Logger):
    drop_stopwords = model in {'bert-bow', 'bert-dkrl', 'glove-bow', 'glove-dkrl'}

    prefix = 'ind-' if inductive and model != 'transductive' else ''
    triples_file = f'data/{dataset}/{prefix}train.tsv'

    if device != torch.device('cpu'):
        num_devices = torch.cuda.device_count()
        if batch_size % num_devices != 0:
            raise ValueError(f'Batch size ({batch_size}) must be a multiple of'
                             f' the number of CUDA devices ({num_devices})')
        _log.info(f'CUDA devices used: {num_devices}')
    else:
        num_devices = 1
        _log.info('Training on CPU')

    if model == 'transductive':
        train_data = GraphDataset(triples_file, num_negatives,
                                  write_maps_file=True,
                                  num_devices=num_devices)
    else:
        if model.startswith('bert') or model == 'blp':
            tokenizer = BertTokenizer.from_pretrained(encoder_name)
        else:
            tokenizer = GloVeTokenizer('data/glove/glove.6B.300d-maps.pt')

        train_data = TextGraphDataset(triples_file, num_negatives, max_len,
                                      tokenizer, drop_stopwords,
                                      write_maps_file=True,
                                      use_cached_text=use_cached_text,
                                      num_devices=num_devices)

    train_loader = DataLoader(train_data, batch_size, shuffle=True,
                              collate_fn=train_data.collate_fn,
                              num_workers=0, drop_last=True)

    train_eval_loader = DataLoader(train_data, eval_batch_size)

    valid_data = GraphDataset(f'data/{dataset}/{prefix}dev.tsv')
    valid_loader = DataLoader(valid_data, eval_batch_size)

    test_data = GraphDataset(f'data/{dataset}/{prefix}test.tsv')
    test_loader = DataLoader(test_data, eval_batch_size)

    # Build graph with all triples to compute filtered metrics
    if dataset != 'Wikidata5M':
        graph = nx.MultiDiGraph()
        all_triples = torch.cat(
            (train_data.triples, valid_data.triples, test_data.triples))
        graph.add_weighted_edges_from(all_triples.tolist())
        train_ent = set(train_data.entities.tolist())
        train_val_ent = set(valid_data.entities.tolist()).union(train_ent)
        train_val_test_ent = set(test_data.entities.tolist()).union(train_val_ent)
        val_new_ents = train_val_ent.difference(train_ent)
        test_new_ents = train_val_test_ent.difference(train_val_ent)
    else:
        graph = None
        train_ent = set(train_data.entities.tolist())
        train_val_ent = set(valid_data.entities.tolist())
        train_val_test_ent = set(test_data.entities.tolist())
        val_new_ents = test_new_ents = None

    _run.log_scalar('num_train_entities', len(train_ent))

    train_ent = torch.tensor(list(train_ent))
    train_val_ent = torch.tensor(list(train_val_ent))
    train_val_test_ent = torch.tensor(list(train_val_test_ent))

    model = utils.get_model(model, dim, rel_model, loss_fn,
                            len(train_val_test_ent), train_data.num_rels,
                            encoder_name, regularizer)
    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    if device != torch.device('cpu'):
        model = torch.nn.DataParallel(model).to(device)

    optimizer = Adam(model.parameters(), lr=lr)
    total_steps = len(train_loader) * max_epochs
    if use_scheduler:
        warmup = int(0.2 * total_steps)
        scheduler = get_linear_schedule_with_warmup(optimizer,
                                                    num_warmup_steps=warmup,
                                                    num_training_steps=total_steps)
    best_valid_mrr = 0.0
    checkpoint_file = osp.join(OUT_PATH, f'model-{_run._id}.pt')
    for epoch in range(1, max_epochs + 1):
        train_loss = 0
        for step, data in enumerate(train_loader):
            loss = model(*data).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if use_scheduler:
                scheduler.step()

            train_loss += loss.item()

            if step % int(0.05 * len(train_loader)) == 0:
                _log.info(f'Epoch {epoch}/{max_epochs} '
                          f'[{step}/{len(train_loader)}]: {loss.item():.6f}')
                _run.log_scalar('batch_loss', loss.item())

        _run.log_scalar('train_loss', train_loss / len(train_loader), epoch)

        if dataset != 'Wikidata5M':
            _log.info('Evaluating on sample of training set')
            eval_link_prediction(model, train_eval_loader, train_data, train_ent,
                                 epoch, emb_batch_size, prefix='train',
                                 max_num_batches=len(valid_loader))

        _log.info('Evaluating on validation set')
        val_mrr, _ = eval_link_prediction(model, valid_loader, train_data,
                                          train_val_ent, epoch, emb_batch_size,
                                          prefix='valid')

        # Keep checkpoint of best performing model (based on raw MRR)
        if val_mrr > best_valid_mrr:
            best_valid_mrr = val_mrr
            torch.save(model.state_dict(), checkpoint_file)

    # Evaluate with best performing checkpoint
    if max_epochs > 0:
        model.load_state_dict(torch.load(checkpoint_file))

    if dataset == 'Wikidata5M':
        graph = nx.MultiDiGraph()
        graph.add_weighted_edges_from(valid_data.triples.tolist())

    if dataset == 'Wikidata5M':
        graph = nx.MultiDiGraph()
        graph.add_weighted_edges_from(test_data.triples.tolist())

    _log.info('Evaluating on test set')
    _, ent_emb = eval_link_prediction(model, test_loader, train_data,
                                      train_val_test_ent, max_epochs + 1,
                                      emb_batch_size, prefix='test',
                                      filtering_graph=graph,
                                      new_entities=test_new_ents,
                                      return_embeddings=True)

    # Save final entity embeddings obtained with trained encoder
    torch.save(ent_emb, osp.join(OUT_PATH, f'ent_emb-{_run._id}.pt'))
    torch.save(train_val_test_ent, osp.join(OUT_PATH, f'ents-{_run._id}.pt'))

def sacred_main(_run: Run, seed, showoff, out_dir, batch_size, epochs, tags, model_desc,
                experiment_id, weights, train_examples, val_examples, deterministic,
                train_datasets, val_datasets, lr, lr_milestones, lr_gamma, optim_algorithm):
    seed_all(seed)
    init_algorithms(deterministic=deterministic)

    exp_out_dir = None
    if out_dir:
        exp_out_dir = path.join(out_dir, experiment_id)
        makedirs(exp_out_dir, exist_ok=True)
    print(f'Experiment ID: {experiment_id}')

    ####
    # Model
    ####

    if weights is None:
        model = create_model(model_desc)
    else:
        details = torch.load(weights)
        model_desc = details['model_desc']
        model = create_model(model_desc)
        model.load_state_dict(details['state_dict'])
    model.to(global_opts['device'])

    print(json.dumps(model_desc, sort_keys=True, indent=2))

    ####
    # Data
    ####

    train_loader = create_train_dataloader(train_datasets, model.data_specs,
                                           batch_size, train_examples)
    if len(val_datasets) > 0:
        val_loader = create_val_dataloader(val_datasets, model.data_specs,
                                           batch_size, val_examples)
    else:
        val_loader = None

    ####
    # Reporting
    ####

    reporter = Reporter(with_val=(val_loader is not None))

    reporter.setup_console_output()
    reporter.setup_sacred_output(_run)

    notebook = None
    if showoff:
        title = '3D pose model ({}@{})'.format(model_desc['type'], model_desc['version'])
        notebook = create_showoff_notebook(title, tags)
        reporter.setup_showoff_output(notebook)

    def set_progress(value):
        if notebook is not None:
            notebook.set_progress(value)

    tel = reporter.telemetry

    tel['config'].set_value(_run.config)
    tel['host_info'].set_value(get_host_info())

    ####
    # Optimiser
    ####

    if optim_algorithm == '1cycle':
        from torch import optim
        optimiser = optim.SGD(model.parameters(), lr=0)
        scheduler = make_1cycle(optimiser, epochs * len(train_loader),
                                lr_max=lr, momentum=0.9)
    else:
        scheduler = learning_schedule(model.parameters(), optim_algorithm, lr,
                                      lr_milestones, lr_gamma)

    ####
    # Training
    ####

    model_file = None
    if exp_out_dir:
        model_file = path.join(exp_out_dir, 'model-latest.pth')
        with open(path.join(exp_out_dir, 'config.json'), 'w') as f:
            json.dump(tel['config'].value(), f, sort_keys=True, indent=2)

    for epoch in range(epochs):
        tel['epoch'].set_value(epoch)
        print('> Epoch {:3d}/{:3d}'.format(epoch + 1, epochs))

        def on_train_progress(samples_processed):
            so_far = epoch * len(train_loader.dataset) + samples_processed
            total = epochs * len(train_loader.dataset)
            set_progress(so_far / total)

        do_training_pass(epoch, model, tel, train_loader, scheduler, on_train_progress)
        if val_loader:
            do_validation_pass(epoch, model, tel, val_loader)

        _run.result = tel['train_pck'].value()[0]

        if model_file is not None:
            state = {
                'state_dict': model.state_dict(),
                'model_desc': model_desc,
                'train_datasets': train_datasets,
                'optimizer': scheduler.optimizer.state_dict(),
                'epoch': epoch + 1,
            }
            torch.save(state, model_file)

        tel.step()

    # Add the final model as a Sacred artifact
    if model_file is not None and path.isfile(model_file):
        _run.add_artifact(model_file)

    set_progress(1.0)

    return _run.result

def eval_link_prediction(model, triples_loader, text_dataset, entities, epoch,
                         emb_batch_size, _run: Run, _log: Logger,
                         prefix='', max_num_batches=None,
                         filtering_graph=None, new_entities=None,
                         return_embeddings=False):
    compute_filtered = filtering_graph is not None
    mrr_by_position = torch.zeros(3, dtype=torch.float).to(device)
    mrr_pos_counts = torch.zeros_like(mrr_by_position)

    rel_categories = triples_loader.dataset.rel_categories.to(device)
    mrr_by_category = torch.zeros([2, 4], dtype=torch.float).to(device)
    mrr_cat_count = torch.zeros([1, 4], dtype=torch.float).to(device)

    hit_positions = [1, 2, 3, 5, 8]  # 3610
    hits_at_k = {pos: 0.0 for pos in hit_positions}
    mrr = 0.0
    mrr_filt = 0.0
    hits_at_k_filt = {pos: 0.0 for pos in hit_positions}

    if device != torch.device('cpu'):
        model = model.module

    if isinstance(model, models.InductiveLinkPrediction):
        num_entities = entities.shape[0]
        if compute_filtered:
            max_ent_id = max(filtering_graph.nodes)
        else:
            max_ent_id = entities.max()
        ent2idx = utils.make_ent2idx(entities, max_ent_id)
    else:
        # Transductive models have a lookup table of embeddings
        num_entities = model.ent_emb.num_embeddings
        ent2idx = torch.arange(num_entities)
        entities = ent2idx

    # Create embedding lookup table for evaluation
    ent_emb = torch.zeros((num_entities, model.dim), dtype=torch.float,
                          device=device)
    idx = 0
    num_iters = np.ceil(num_entities / emb_batch_size)
    iters_count = 0
    while idx < num_entities:
        # Get a batch of entity IDs and encode them
        batch_ents = entities[idx:idx + emb_batch_size]

        if isinstance(model, models.InductiveLinkPrediction):
            # Encode with entity descriptions
            data = text_dataset.get_entity_description(batch_ents)
            text_tok, text_mask, text_len = data
            batch_emb = model(text_tok.unsqueeze(1).to(device),
                              text_mask.unsqueeze(1).to(device))
        else:
            # Encode from lookup table
            batch_emb = model(batch_ents)

        ent_emb[idx:idx + batch_ents.shape[0]] = batch_emb

        iters_count += 1
        if iters_count % np.ceil(0.2 * num_iters) == 0:
            _log.info(f'[{idx + batch_ents.shape[0]:,}/{num_entities:,}]')

        idx += emb_batch_size

    ent_emb = ent_emb.unsqueeze(0)

    batch_count = 0
    _log.info('Computing metrics on set of triples')
    total = len(triples_loader) if max_num_batches is None else max_num_batches
    for i, triples in enumerate(triples_loader):
        print(type(triples))
        if max_num_batches is not None and i == max_num_batches:
            break

        heads, tails, rels = torch.chunk(triples, chunks=3, dim=1)
        # Map entity IDs to positions in ent_emb
        heads = ent2idx[heads].to(device)
        tails = ent2idx[tails].to(device)

        assert heads.min() >= 0
        assert tails.min() >= 0

        # Embed triple
        head_embs = ent_emb.squeeze()[heads]
        tail_embs = ent_emb.squeeze()[tails]
        rel_embs = model.rel_emb(rels.to(device))

        # Score all possible heads and tails
        heads_predictions = model.score_fn(ent_emb, tail_embs, rel_embs)
        tails_predictions = model.score_fn(head_embs, ent_emb, rel_embs)

        pred_ents = torch.cat((heads_predictions, tails_predictions))
        true_ents = torch.cat((heads, tails))

        hits = utils.hit_at_k(pred_ents, true_ents, hit_positions)
        for j, h in enumerate(hits):
            hits_at_k[hit_positions[j]] += h
        mrr += utils.mrr(pred_ents, true_ents).mean().item()

        if compute_filtered:
            filters = utils.get_triple_filters(triples, filtering_graph,
                                               num_entities, ent2idx)
            heads_filter, tails_filter = filters
            # Filter entities by assigning them the lowest score in the batch
            filter_mask = torch.cat((tails_filter, tails_filter)).to(device)
            pred_ents[filter_mask] = pred_ents.min() - 1.0
            hits_filt = utils.hit_at_k(pred_ents, true_ents, hit_positions)
            for j, h in enumerate(hits_filt):
                hits_at_k_filt[hit_positions[j]] += h
            mrr_filt_per_triple = utils.mrr(pred_ents, true_ents)
            mrr_filt += mrr_filt_per_triple.mean().item()

            if new_entities is not None:
                by_position = utils.split_by_new_position(triples,
                                                          mrr_filt_per_triple,
                                                          new_entities)
                batch_mrr_by_position, batch_mrr_pos_counts = by_position
                mrr_by_position += batch_mrr_by_position
                mrr_pos_counts += batch_mrr_pos_counts

            if triples_loader.dataset.has_rel_categories:
                by_category = utils.split_by_category(triples,
                                                      mrr_filt_per_triple,
                                                      rel_categories)
                batch_mrr_by_cat, batch_mrr_cat_count = by_category
                mrr_by_category += batch_mrr_by_cat
                mrr_cat_count += batch_mrr_cat_count

        batch_count += 1
        if (i + 1) % int(0.2 * total) == 0:
            _log.info(f'[{i + 1:,}/{total:,}]')

    for hits_dict in (hits_at_k, hits_at_k_filt):
        for k in hits_dict:
            hits_dict[k] /= batch_count

    mrr = mrr / batch_count
    mrr_filt = mrr_filt / batch_count

    log_str = f'{prefix} mrr: {mrr:.4f} '
    _run.log_scalar(f'{prefix}_mrr', mrr, epoch)
    for k, value in hits_at_k.items():
        log_str += f'hits@{k}: {value:.4f} '
        _run.log_scalar(f'{prefix}_hits@{k}', value, epoch)

    if compute_filtered:
        log_str += f'mrr_filt: {mrr_filt:.4f} '
        _run.log_scalar(f'{prefix}_mrr_filt', mrr_filt, epoch)
        for k, value in hits_at_k_filt.items():
            log_str += f'hits@{k}_filt: {value:.4f} '
            _run.log_scalar(f'{prefix}_hits@{k}_filt', value, epoch)

    _log.info(log_str)

    if new_entities is not None and compute_filtered:
        mrr_pos_counts[mrr_pos_counts < 1.0] = 1.0
        mrr_by_position = mrr_by_position / mrr_pos_counts
        log_str = ''
        for i, t in enumerate((f'{prefix}_mrr_filt_both_new',
                               f'{prefix}_mrr_filt_head_new',
                               f'{prefix}_mrr_filt_tail_new')):
            value = mrr_by_position[i].item()
            log_str += f'{t}: {value:.4f} '
            _run.log_scalar(t, value, epoch)
        _log.info(log_str)

    if compute_filtered and triples_loader.dataset.has_rel_categories:
        mrr_cat_count[mrr_cat_count < 1.0] = 1.0
        mrr_by_category = mrr_by_category / mrr_cat_count

        for i, case in enumerate(['pred_head', 'pred_tail']):
            log_str = f'{case} '
            for cat, cat_id in CATEGORY_IDS.items():
                log_str += f'{cat}_mrr: {mrr_by_category[i, cat_id]:.4f} '
            _log.info(log_str)

    if return_embeddings:
        out = (mrr, ent_emb)
    else:
        out = (mrr, None)

    return out

def test(model_filename: str, test_corpus: str, _run: Run = None, _log: logger = None):
    """ run test crf model using stanford ner-crf features main """
    _run.add_resource(test_corpus)
    _run.add_resource(f'{model_filename}.pkl')
    test_sents, _ = get_tagged_sents_and_words(test_corpus)
    X_test = [sent2stanfordfeats(s) for s in test_sents]
    y_test = [sent2stanfordlabels(s) for s in test_sents]

    _log.info(f'load from: {model_filename}.pkl')
    crf = sklearn_crfsuite.CRF(model_filename=model_filename)
    y_pred = crf.predict(X_test)
    overall, by_type = evaluate(y_test, y_pred)

    _run.info[f'overall_f1'] = overall.f1_score
    _run.log_scalar('overall_f1', overall.f1_score)
    _run.info[f'overall_precision'] = overall.precision
    _run.log_scalar('overall_precision', overall.precision)
    _run.info[f'overall_recall'] = overall.recall
    _run.log_scalar('overall_recall', overall.recall)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for _, key in enumerate(sorted(by_type.keys())):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')

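# Illustrative sketch only: a minimal, self-contained Sacred experiment showing how the
# _run: Run argument used throughout these snippets is injected by Sacred, and how
# log_scalar, info and the return value (_run.result) relate to each other. The
# experiment and metric names here are placeholders, not taken from the snippets above.
from sacred import Experiment
from sacred.run import Run

ex = Experiment('minimal_example')

@ex.config
def config():
    epochs = 3  # hyperparameter exposed through Sacred's config system

@ex.automain
def main(epochs, _run: Run):
    loss = None
    for epoch in range(1, epochs + 1):
        loss = 1.0 / epoch                    # stand-in for a real training loss
        _run.log_scalar('loss', loss, epoch)  # forwarded to all attached observers
    _run.info['final_loss'] = loss            # stored in the run's info dict
    return loss                               # becomes _run.result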