Code Example #1
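Trains the BagOfWordsModel baseline on the first 200 WebQuestions training samples and asserts that the underlying classifier is scikit-learn's LogisticRegression. Validation data is passed only when the configured dataset path points at a train/validation split.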
def test_model_train():
    trainablemodel = models.BagOfWordsModel(parameters=config['model'],
                                            logger=logger)
    assert type(trainablemodel._model) == linear_model.LogisticRegression
    webquestions = webquestions_io.WebQuestions(config['webquestions'],
                                                logger=logger)
    input_set, targets = webquestions.get_training_samples()
    input_set, targets = input_set[:200], targets[:200]
    trainablemodel.train(
        (input_set, targets),
        validation_with_targets=webquestions.get_validation_samples()
        if 'train_validation' in config['webquestions']['path.to.dataset'] else
        None)
    print('Training finished')
Code Example #2
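Reconfigures the model for a multiclass target distribution with the "mlml" loss and 10 epochs, prints the first few silver validation targets, and trains a CNNLabelsTorchModel directly on the WebQuestions object.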
def test_model_multiclass():
    config['model']['epochs'] = 10
    config['model']['loss'] = "mlml"
    config['webquestions']['target.dist'] = "multiclass"
    webquestions = webquestions_io.WebQuestions(config['webquestions'],
                                                logger=logger)
    _, silver_test_targets = webquestions.get_full_validation()
    print(silver_test_targets[:10])
    trainablemodel = models.CNNLabelsTorchModel(parameters=config['model'],
                                                logger=logger)
    trainablemodel.prepare_model(webquestions.get_training_tokens(),
                                 webquestions.get_training_properties_tokens())
    trainablemodel.train(
        webquestions,
        validation_with_targets=webquestions.get_validation_samples())
    print('Training finished')
Code Example #3
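Loads a previously trained model from disk and evaluates it on silver data: the validation split when the dataset path contains a train/validation split, otherwise the full training split. Predictions are written next to the results log with a _silver_predictions.log suffix.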
def test_model(path_to_model, config_file_path):
    """

    :param path_to_model:
    :param config_file_path:
    :return:
    """
    config = utils.load_config(config_file_path)
    config_global = config.get('global', {})
    np.random.seed(config_global.get('random.seed', 1))

    logger = logging.getLogger(__name__)
    logger.setLevel(config['logger']['level'])
    ch = logging.StreamHandler()
    ch.setLevel(config['logger']['level'])
    logger.addHandler(ch)

    results_logger = None
    if 'log.results' in config['training']:
        results_logger = logging.getLogger("results_logger")
        results_logger.setLevel(logging.INFO)
        fh = logging.FileHandler(filename=config['training']['log.results'])
        fh.setLevel(logging.INFO)
        results_logger.addHandler(fh)
        results_logger.info(str(config))

    webquestions = webquestions_io.WebQuestions(config['webquestions'], logger=logger)
    config['model']['samples.per.epoch'] = webquestions.get_train_sample_size()
    config['model']['graph.choices'] = config['webquestions'].get("max.negative.samples", 30)

    trainablemodel = getattr(models, config['model']['class'])(parameters=config['model'], logger=logger)
    trainablemodel.load_from_file(path_to_model)

    print("Testing the model on silver data.")
    if 'train_validation' in config['webquestions']['path.to.dataset']:
        silver_test_set, silver_test_targets = webquestions.get_full_validation()
    else:
        silver_test_set, silver_test_targets = webquestions.get_full_training()
    accuracy_on_silver, predicted_targets = trainablemodel.test_on_silver((silver_test_set, silver_test_targets), verbose=True)
    print("Accuracy on silver data: {}".format(accuracy_on_silver))
    with open(config['training']['log.results'].replace(".log", "_silver_predictions.log"), "w") as out:
        if len(silver_test_targets) > 0 and not issubclass(type(silver_test_targets[0]), np.integer):
            silver_test_targets = np.argmax(silver_test_targets, axis=-1)
        json.dump((silver_test_set, predicted_targets, [int(t) for t in silver_test_targets]), out)
    if results_logger:
        results_logger.info("Accuracy on silver data: {}".format(accuracy_on_silver))
Code Example #4
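Full training entry point: seeds NumPy and PyTorch (including CUDA, if available), instantiates the model class named in the config, trains it in generator, model-sampling, or base mode, reloads the best saved checkpoint, and reports accuracy on the silver data.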
def train(config_file_path):
    """

    :param config_file_path:
    :return:
    """
    config = utils.load_config(config_file_path)
    if "training" not in config:
        print("Training parameters not in the config file!")
        sys.exit()

    config_global = config.get('global', {})
    np.random.seed(config_global.get('random.seed', 1))
    torch.manual_seed(config_global.get('random.seed', 1))
    if torch.cuda.is_available():
        print("using your CUDA device")
        torch.cuda.manual_seed(config_global.get('random.seed', 1))

    logger = logging.getLogger(__name__)
    logger.setLevel(config['logger']['level'])
    ch = logging.StreamHandler()
    ch.setLevel(config['logger']['level'])
    logger.addHandler(ch)
    logger.debug(str(datetime.datetime.now()))

    results_logger = None
    if 'log.results' in config['training']:
        results_logger = logging.getLogger("results_logger")
        results_logger.setLevel(logging.INFO)
        fh = logging.FileHandler(filename=config['training']['log.results'])
        fh.setLevel(logging.INFO)
        results_logger.addHandler(fh)
        results_logger.info(str(config))

    config['webquestions']['max.entity.options'] = config['evaluation'].get(
        'max.entity.options', 3)
    webquestions = webquestions_io.WebQuestions(config['webquestions'],
                                                logger=logger)
    config['model']['samples.per.epoch'] = webquestions.get_train_sample_size()
    config['model']['graph.choices'] = config['webquestions'].get(
        "max.negative.samples", 30)
    if config['training'].get('train.generator', False):
        config['training']['train.mode'] = "generator"

    trainablemodel = getattr(models, config['model']['class'])(
        parameters=config['model'], logger=logger)
    trainablemodel.prepare_model(webquestions.get_training_tokens(),
                                 webquestions.get_training_properties_tokens())
    if results_logger:
        results_logger.info("Model save to: {}".format(
            trainablemodel._model_file_name))
    if config['training'].get('train.mode', "base") == "generator":
        trainablemodel.train_on_generator(
            webquestions,
            validation_with_targets=webquestions.get_validation_samples()
            if 'train_validation' in config['webquestions']['path.to.dataset']
            else None)
    elif config['training'].get('train.mode', "base") == "model":
        trainablemodel.train(
            webquestions,
            validation_with_targets=webquestions.get_validation_samples()
            if 'train_validation' in config['webquestions']['path.to.dataset']
            else None,
            model_sampling=True)
    else:
        trainablemodel.train(
            webquestions,
            validation_with_targets=webquestions.get_validation_samples()
            if 'train_validation' in config['webquestions']['path.to.dataset']
            else None)
    print("Loading the best model")
    trainablemodel.load_last_saved()

    if 'train_validation' in config['webquestions']['path.to.dataset']:
        silver_test_set, silver_test_targets = webquestions.get_full_validation()
    else:
        silver_test_set, silver_test_targets = webquestions.get_full_training()
    accuracy_on_silver, predicted_targets = trainablemodel.test_on_silver(
        (silver_test_set, silver_test_targets), verbose=True)
    print("Accuracy on silver data: {}".format(accuracy_on_silver))
    with open(
            config['training']['log.results'].replace(
                ".log", "_silver_predictions.log"), "w") as out:
        if len(silver_test_targets) > 0 and not issubclass(
                type(silver_test_targets[0]), np.integer):
            silver_test_targets = np.argmax(silver_test_targets, axis=-1)
        json.dump((silver_test_set, predicted_targets,
                   [int(t) for t in silver_test_targets]), out)

    if results_logger:
        results_logger.info(
            "Accuracy on silver data: {}".format(accuracy_on_silver))
Code Example #5
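Module-level setup that loads the default YAML config and the WebQuestions data once, followed by a test that trains a CNNLabelsHashesModel with self-attention sibling pooling ("selfatt") and model checkpointing enabled, then reloads the last saved checkpoint.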
with open("../questionanswering/default_config.yaml", 'r') as config_file:
    config = yaml.load(config_file.read())

logger = logging.getLogger(__name__)
logger.setLevel(config['logger']['level'])
ch = logging.StreamHandler()
ch.setLevel(config['logger']['level'])
logger.addHandler(ch)

config['webquestions']['extensions'] = []
config['webquestions']['max.entity.options'] = 1
config['model']['graph.choices'] = config['webquestions'].get(
    "max.negative.samples", 30)
config['model']['epochs'] = 6
config['model']['threshold'] = 10
webquestions = webquestions_io.WebQuestions(config['webquestions'],
                                            logger=logger)

config['model']["batch.size"] = 10


def test_pool_selfatt():
    config['model']['sibling.pool.mode'] = "selfatt"
    config['model']["model.checkpoint"] = True
    trainablemodel = models.pytorchmodel_impl.CNNLabelsHashesModel(
        parameters=config['model'], logger=logger)
    trainablemodel.prepare_model(webquestions)
    trainablemodel.train(
        webquestions,
        validation_with_targets=webquestions.get_validation_samples())
    print('Training finished')
    trainablemodel.load_last_saved()
Code Example #6
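Two variants of generate(). The first loads a trained model and evaluates beam-search graph generation against the gold WebQuestions answers, logging precision, recall, and f1. The second builds the silver training set by generating grounded graphs against gold answers, optionally reusing graphs with f1 above 0.9 from a previous run.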
def generate(path_to_model, config_file_path):

    config = utils.load_config(config_file_path)
    if "evaluation" not in config:
        print("Evaluation parameters not in the config file!")
        sys.exit()
    config_global = config.get('global', {})
    np.random.seed(config_global.get('random.seed', 1))

    logger = logging.getLogger(__name__)
    logger.setLevel(config['logger']['level'])
    ch = logging.StreamHandler()
    ch.setLevel(config['logger']['level'])
    logger.addHandler(ch)
    # logging.basicConfig(level=config['logger']['level'])

    wdaccess.wdaccess_p["timeout"] = config['wikidata'].get("timeout", 20)
    wdaccess.wdaccess_p['wikidata_url'] = config['wikidata'].get("backend", "http://knowledgebase:8890/sparql")
    wdaccess.sparql_init()

    assert not(config['webquestions']["no.ne.tags"] and config['evaluation']["only.named.entities"])
    entity_linking.entity_linking_p["no.ne.tags"] = config['webquestions']["no.ne.tags"]
    assert not(entity_linking.entity_linking_p["no.ne.tags"] and config['evaluation']["only.named.entities"])
    entity_linking.entity_linking_p["max.entity.options"] = config['evaluation']["max.entity.options"]
    entity_linking.entity_linking_p["global.entity.grouping"] = not config['webquestions']["no.ne.tags"] and entity_linking.entity_linking_p["overlaping.nn.ne"]
    entity_linking.entity_linking_p["np.parser"] = config['evaluation']["np.parser"]

    wdaccess.wdaccess_p["restrict.hop"] = config['wikidata'].get("restrict.hop", False)
    wdaccess.update_sparql_clauses()
    staged_generation.generation_p["use.whitelist"] = config['evaluation'].get("use.whitelist", False)
    staged_generation.generation_p["strict.structure"] = config['evaluation'].get("strict.structure", False)
    staged_generation.generation_p["v.structure"] = config['evaluation'].get("v.structure", True)
    staged_generation.generation_p["max.num.entities"] = config['evaluation'].get("max.num.entities", 3)
    staged_generation.generation_p["topics.at.test"] = config['evaluation'].get("topics.at.test", False)
    if not config['wikidata'].get("addclass.action", False):
        stages.RESTRICT_ACTIONS = [stages.add_entity_and_relation, stages.last_relation_temporal,
                                   stages.add_temporal_relation, stages.last_relation_numeric]

    logger.debug("max.entity.options: {}".format(entity_linking.entity_linking_p["max.entity.options"]))
    if 'hop.types' in config['wikidata']:
        stages.HOP_TYPES = set(config['wikidata']['hop.types'])
    if 'arg.types' in config['wikidata']:
        stages.ARG_TYPES = set(config['wikidata']['arg.types'])
    if 'filter.out.relation.classes' in config['wikidata']:
        wdaccess.FILTER_RELATION_CLASSES = set(config['wikidata']['filter.out.relation.classes'])
    logger.debug("entity_linking: {}".format(entity_linking.entity_linking_p))
    with open(config['evaluation']['questions']) as f:
        webquestions_questions = json.load(f)
    webquestions = webquestions_io.WebQuestions(config['webquestions'], logger=logger)

    logger.debug('Extracting entities.')
    webquestions_entities = webquestions.extract_question_entities()
    if 'class' not in config['model']:
        config['model']['class'] = path_to_model.split('/')[-1].split('_')[0]
    logger.debug('Loading the model from: {}'.format(path_to_model))
    qa_model = getattr(models, config['model']['class'])(parameters=config['model'], logger=logger)
    qa_model.load_from_file(path_to_model)

    logger.debug('Testing')
    global_answers = []
    avg_metrics = np.zeros(3)
    len_webquestion = webquestions.get_dataset_size()
    for i in tqdm.trange(len_webquestion, ncols=100, ascii=True):
        question_entities = webquestions_entities[i]
        nes = [e for e in question_entities if e[1] != "NN"]
        if config['evaluation'].get('only.named.entities', False) and len(nes) > 0:
            question_entities = nes
        ungrounded_graph = {'tokens': webquestions.get_original_question_tokens(i),
                            'edgeSet': [],
                            'fragments': question_entities}
        chosen_graphs = staged_generation.generate_with_model(ungrounded_graph, qa_model, beam_size=config['evaluation'].get("beam.size", 10))
        model_answers = []
        g = ({},)
        if chosen_graphs:
            j = 0
            while not model_answers and j < len(chosen_graphs):
                g = chosen_graphs[j]
                model_answers = wdaccess.query_graph_denotations(g[0]).get("e1")
                j += 1
        if config['evaluation'].get('label.answers', False):
            gold_answers = [e.lower() for e in webquestions_io.get_answers_from_question(webquestions_questions[i])]
            model_answers_labels = wdaccess.label_query_results(model_answers)
            model_answers_labels = staged_generation.post_process_answers_given_graph(model_answers_labels, g[0])
            metrics = evaluation.retrieval_prec_rec_f1_with_altlabels(gold_answers, model_answers_labels)
            global_answers.append((i, list(metrics), model_answers, model_answers_labels,
                                   [(c_g[0], float(c_g[1])) for c_g in chosen_graphs[:10]]))
        else:
            gold_answers = webquestions_io.get_answers_from_question(webquestions_questions[i])
            metrics = evaluation.retrieval_prec_rec_f1(gold_answers, model_answers)
            global_answers.append((i, list(metrics), model_answers,
                                   [(c_g[0], float(c_g[1])) for c_g in chosen_graphs[:10]]))
        avg_metrics += metrics
        if i % 100 == 0:
            logger.debug("Average f1 so far: {}".format((avg_metrics/(i+1))))
            with open(config['evaluation']["save.answers.to"], 'w') as answers_out:
                json.dump(global_answers, answers_out, sort_keys=True, indent=4)

    print("Average metrics: {}".format((avg_metrics/(len_webquestion))))

    logger.debug('Testing is finished')
    with open(config['evaluation']["save.answers.to"], 'w') as answers_out:
        json.dump(global_answers, answers_out, sort_keys=True, indent=4)

def generate(config_file_path):
    config = utils.load_config(config_file_path)
    if "generation" not in config:
        print("Generation parameters not in the config file!")
        sys.exit()
    config_global = config.get('global', {})
    np.random.seed(config_global.get('random.seed', 1))

    logger = logging.getLogger(__name__)
    logger.setLevel(config['logger']['level'])
    ch = logging.StreamHandler()
    ch.setLevel(config['logger']['level'])
    logger.addHandler(ch)
    # logging.basicConfig(level=logging.ERROR)
    logger.debug(str(datetime.datetime.now()))

    staged_generation.generation_p['label.query.results'] = config[
        'generation'].get('label.query.results', False)
    staged_generation.generation_p["use.whitelist"] = config['generation'].get(
        "use.whitelist", False)
    staged_generation.generation_p["min.fscore.to.stop"] = config[
        'generation'].get("min.fscore.to.stop", 0.9)
    entity_linking.entity_linking_p[
        "global.entity.grouping"] = not config['webquestions']["no.ne.tags"]
    entity_linking.entity_linking_p["max.entity.options"] = config[
        'generation']["max.entity.options"]
    wdaccess.wdaccess_p['wikidata_url'] = config['wikidata'].get(
        "backend", "http://knowledgebase:8890/sparql")
    wdaccess.wdaccess_p["restrict.hop"] = config['wikidata'].get(
        "restrict.hop", False)
    wdaccess.wdaccess_p["timeout"] = config['wikidata'].get("timeout", 20)
    wdaccess.sparql_init()
    wdaccess.update_sparql_clauses()
    logger.debug("entity_linking: {}".format(entity_linking.entity_linking_p))
    if 'hop.types' in config['wikidata']:
        stages.HOP_TYPES = set(config['wikidata']['hop.types'])
    logger.debug("Hop types set to: {}".format(stages.HOP_TYPES))
    if 'arg.types' in config['wikidata']:
        stages.ARG_TYPES = set(config['wikidata']['arg.types'])
    if 'filter.out.relation.classes' in config['wikidata']:
        wdaccess.FILTER_RELATION_CLASSES = set(
            config['wikidata']['filter.out.relation.classes'])
    logger.debug("Arg types set to: {}".format(stages.ARG_TYPES))

    if wdaccess.wdaccess_p["restrict.hop"]:
        logger.debug("HOP UP relations set to: {}".format(
            wdaccess.HOP_UP_RELATIONS))
        logger.debug("HOP UP relations set to: {}".format(
            wdaccess.sparql_hopup_values))
        logger.debug("HOP DOWN relations set to: {}".format(
            wdaccess.HOP_DOWN_RELATIONS))
        logger.debug("HOP UP relations set to: {}".format(
            wdaccess.sparql_hopdown_values))

    webquestions = webquestions_io.WebQuestions(config['webquestions'],
                                                logger=logger)
    logger.debug('Loaded WebQuestions, size: {}'.format(
        webquestions.get_dataset_size()))

    with open(config['generation']['questions']) as f:
        webquestions_questions = json.load(f)
    logger.debug(
        'Loaded WebQuestions original training questions, size: {}'.format(
            len(webquestions_questions)))
    assert len(webquestions_questions) == webquestions.get_dataset_size()

    logger.debug('Extracting entities.')
    webquestions_entities = webquestions.extract_question_entities()

    silver_dataset = []
    previous_silver = []
    if 'previous' in config['generation']:
        logger.debug("Loading the previous result")
        with open(config['generation']['previous']) as f:
            previous_silver = json.load(f)
        logger.debug("Previous length: {}".format(len(previous_silver)))
        logger.debug("Previous number of answers covered: {}".format(
            len([
                1 for graphs in previous_silver if len(graphs) > 0
                and any([len(g) > 1 and g[1][2] > 0.0 for g in graphs])
            ]) / len(previous_silver)))
        logger.debug("Previous average f1: {}".format(
            np.average([
                np.max([g[1][2] if len(g) > 1 else 0.0
                        for g in graphs]) if len(graphs) > 0 else 0.0
                for graphs in previous_silver
            ])))
        previous_silver = [[
            g for g in graph_set
            if graph.if_graph_adheres(g[0],
                                      allowed_extensions=config['webquestions']
                                      .get("extensions", set()))
        ] for graph_set in previous_silver]
        logger.debug(
            "Previous average f1 (after extensions removed): {}".format(
                np.average([
                    np.max([g[1][2] if len(g) > 1 else 0.0
                            for g in graphs]) if len(graphs) > 0 else 0.0
                    for graphs in previous_silver
                ])))
        logger.debug("Reusable: {}".format(
            len([
                1 for graphs in previous_silver if len(graphs) > 0
                and any([len(g) > 1 and g[1][2] > 0.9 for g in graphs])
            ]) / len(previous_silver)))

    len_webquestion = webquestions.get_dataset_size()
    start_with = 0
    if 'start.with' in config['generation']:
        start_with = config['generation']['start.with']
        print("Starting with {}.".format(start_with))

    if 'take_first' in config['generation']:
        print("Taking the first {} questions.".format(
            config['generation']['take_first']))
        len_webquestion = config['generation']['take_first']
    logger.debug("First question: {} {}\n {}".format(
        start_with, webquestions_questions[start_with],
        webquestions_io.get_answers_from_question(
            webquestions_questions[start_with])))
    for i in tqdm.tqdm(range(start_with, len_webquestion), ncols=100):
        if len(previous_silver) > i and max(
                g[1][2] if len(g) > 1 and len(g[1]) > 2 else 0.0
                for g in previous_silver[i]) > 0.9:
            silver_dataset.append(previous_silver[i])
        else:
            question_entities = webquestions_entities[i]
            if "max.num.entities" in config['generation']:
                question_entities = question_entities[:config['generation']
                                                      ["max.num.entities"]]
            if config['generation'].get('include_url_entities', False):
                url_entity = webquestions_io.get_main_entity_from_question(
                    webquestions_questions[i])
                if not any(e == url_entity[0] for e, t in question_entities):
                    # question_entities = [url_entity] + [(e, t) for e, t in question_entities if e != url_entity[0]]
                    question_entities = [url_entity] + question_entities
            ungrounded_graph = {
                'tokens': webquestions.get_original_question_tokens(i),
                'edgeSet': [],
                'fragments': question_entities
            }
            logger.log(level=0,
                       msg="Generating from: {}".format(ungrounded_graph))
            gold_answers = webquestions_io.get_answers_from_question(
                webquestions_questions[i])
            if staged_generation.generation_p['label.query.results']:
                gold_answers = [e.lower() for e in gold_answers]
            generated_graphs = staged_generation.generate_with_gold(
                ungrounded_graph, gold_answers)
            silver_dataset.append(generated_graphs)
        if i % 100 == 0:
            logger.debug("Cov., avg. f1: {}, {}".format(
                (len([
                    1 for graphs in silver_dataset if len(graphs) > 0
                    and any([len(g) > 1 and g[1][2] > 0.0 for g in graphs])
                ]) / (i + 1)),
                np.average([
                    np.max([g[1][2] if len(g) > 1 else 0.0
                            for g in graphs]) if len(graphs) > 0 else 0.0
                    for graphs in silver_dataset
                ])))
            # Dump the dataset once in a while
            with open(config['generation']["save.silver.to"], 'w') as out:
                json.dump(silver_dataset, out, sort_keys=True, indent=4)

    logger.debug("Generation finished. Silver dataset size: {}".format(
        len(silver_dataset)))
    with open(config['generation']["save.silver.to"], 'w') as out:
        json.dump(silver_dataset, out, sort_keys=True, indent=4)

    print("Query cache: {}".format(len(wdaccess.query_cache)))
    print("Number of answers covered: {}".format(
        len([
            1 for graphs in silver_dataset if len(graphs) > 0
            and any([len(g) > 1 and g[1][2] > 0.0 for g in graphs])
        ]) / len_webquestion))
    print("Average f1 of the silver data: {}".format(
        np.average([
            np.max([g[1][2] if len(g) > 1 else 0.0
                    for g in graphs]) if len(graphs) > 0 else 0.0
            for graphs in silver_dataset
        ])))
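
The coverage and average-f1 summary printed at the end of generate() can also be recomputed offline from a saved silver dump. The snippet below is a standalone sketch assuming the file was written by the json.dump call above; the file name is a placeholder for whatever save.silver.to points to.

import json

import numpy as np

# Placeholder path; use the file configured under generation -> save.silver.to.
with open("webquestions_silver.json") as f:
    silver_dataset = json.load(f)

# A question counts as covered if any of its generated graphs has f1 > 0.0 (f1 is stored at g[1][2]).
coverage = len([1 for graphs in silver_dataset
                if len(graphs) > 0 and any(len(g) > 1 and g[1][2] > 0.0 for g in graphs)]) / len(silver_dataset)
avg_f1 = np.average([np.max([g[1][2] if len(g) > 1 else 0.0 for g in graphs]) if len(graphs) > 0 else 0.0
                     for graphs in silver_dataset])
print("Coverage: {:.3f}, average f1: {:.3f}".format(coverage, avg_f1))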