def _eval_lex_model(label_est_samples, valid_samples) -> float:
    """Evaluate the lexicon model on ``valid_samples`` using label
    proportions estimated from ``label_est_samples``.

    Reads the module-level ``model``, ``vocab``, ``_DATADEF``,
    ``_DATASET_NAME`` and ``_LEXICON_CONFIG``. Returns the ``"valid_f1"``
    entry of the metrics produced by ``eval_lexicon_model``.
    """
    # Estimate per-domain class proportions from the estimation samples;
    # keyed by the split name we will evaluate under.
    labelprops_by_split = {
        "estimated": calculate_labelprops(
            label_est_samples,
            _DATADEF.n_classes,
            _DATADEF.domain_names,
        )
    }

    # Patch a fresh datadef rather than the global _DATADEF, so only this
    # evaluation sees the estimated proportions.
    patched_datadef = get_datadef(_DATASET_NAME)
    patched_datadef.load_labelprops_func = (
        lambda split_name: labelprops_by_split[split_name]
    )

    valid_metrics = eval_lexicon_model(
        model,
        patched_datadef,
        valid_samples,
        vocab,
        use_source_individual_norm=_LEXICON_CONFIG[
            "use_source_individual_norm"
        ],
        labelprop_split="estimated",  # key registered in labelprops_by_split
    )
    return valid_metrics["valid_f1"]
# NOTE(review): fragment — this span begins mid-call (the callee's name and
# opening paren precede the visible source) and the whole chunk has been
# collapsed onto one physical line, so it is left byte-identical here.
# From what is visible it: (1) finishes a training call parameterized by
# `config`; (2) reloads the trained model and vocab from `savedir`;
# (3) evaluates on the holdout sources' "test" split with labelprop_split
# "test"; and (4) writes leaf_test.json / config.json and reduces metrics.
# Presumably this sits inside a per-`train_source` loop — confirm against
# the full file before restructuring.
config, _DATADEF, train_samples=train_samples, valid_samples=valid_samples, vocab_size=config["vocab_size"], logdir=join(savedir, train_source), train_labelprop_split="train", valid_labelprop_split="train", ) model = torch.load(join(savedir, train_source, "model.pth")) vocab = read_txt_as_str_list(join(savedir, train_source, "vocab.txt")) test_samples = _DATADEF.load_splits_func(holdout_sources, ["test"])["test"] test_metrics = eval_lexicon_model( model, _DATADEF, test_samples, vocab, use_lemmatize=False, use_source_individual_norm=config["use_source_individual_norm"], labelprop_split="test", ) save_json(test_metrics, join(savedir, train_source, "leaf_test.json")) save_json(config, join(savedir, "config.json")) reduce_and_save_metrics(_SAVE_ROOT) reduce_and_save_metrics(_SAVE_ROOT, "leaf_test.json", "mean_test.json")
# NOTE(review): fragment — cut at BOTH edges: it starts mid-call
# (`vocab_size=len(vocab), )` closes a call opened outside this view) and
# ends mid-call (`eval_lexicon_model( model, _DATADEF,` has no closing
# paren here). Collapsed onto one physical line; left byte-identical.
# From what is visible it: builds a model from a config, seeds its weights
# from `lexicon_df`, evaluates the validation split with labelprop_split
# "train", persists leaf_metrics.json / vocab.txt / model.pth under
# `logdir`, then begins a test-split evaluation for `holdout_source` that
# continues past this view — confirm against the full file.
vocab_size=len(vocab), ) model = get_model(config).to(AUTO_DEVICE) model.set_weight_from_lexicon(lexicon_df, _DATADEF.label_names) use_source_individual_norm = config["use_source_individual_norm"] use_lemmatize = config["use_lemmatize"] metrics = {} # run validation set valid_metrics = eval_lexicon_model( model=model, datadef=_DATADEF, valid_samples=valid_samples, vocab=vocab, use_source_individual_norm=use_source_individual_norm, use_lemmatize=use_lemmatize, labelprop_split="train", ) metrics.update(valid_metrics) save_json(metrics, join(logdir, "leaf_metrics.json")) write_str_list_as_txt(vocab, join(logdir, "vocab.txt")) torch.save(model, join(logdir, "model.pth")) # run test set test_samples = _DATADEF.load_splits_func([holdout_source], ["test"])["test"] test_metrics = eval_lexicon_model( model, _DATADEF,