Example #1
0
def train(
    annotations,
    use_precise_form_types=True,
    optimize_hyperparameters_iters=0,
    full_form_type_names=False,
    full_field_type_names=True,
    verbose=True,
):
    """
    Train the field type model on ``annotations`` and return the fitted estimator.

    :param annotations: annotated forms. Only forms with annotated fields are
        used; when ``use_precise_form_types`` is true, forms must also have an
        annotated form type.
    :param use_precise_form_types: if true, the annotated form types are given
        to ``get_Xy``; otherwise form types predicted out-of-fold by
        ``formtype_model.get_realistic_form_labels`` (10 folds) are used.
        Also passed to ``get_model``.
    :param optimize_hyperparameters_iters: number of ``RandomizedSearchCV``
        iterations used to tune ``c1``/``c2`` (5 annotation folds).
        0 skips the search and fits the model from ``get_model`` directly;
        nonzero values below 50 emit a warning.
    :param full_form_type_names: use ``type_full`` instead of ``type`` as the
        form type label.
    :param full_field_type_names: passed to ``get_Xy`` as ``full_type_names``.
    :param verbose: print progress messages; also forwarded to
        ``RandomizedSearchCV``.
    :return: the fitted estimator (the best one found when searching).
    """
    import inspect  # local import: only needed for the scikit-learn compat check

    def log(msg):
        if verbose:
            print(msg)

    annotations = [
        a for a in annotations
        if a.fields_annotated and (a.form_annotated or not use_precise_form_types)
    ]
    log("Training on {} forms".format(len(annotations)))

    if use_precise_form_types:
        log("Using precise form types")
        if full_form_type_names:
            form_types = np.asarray([a.type_full for a in annotations])
        else:
            form_types = np.asarray([a.type for a in annotations])
    else:
        # Train on the form types a real form type classifier would produce,
        # so the model learns to cope with its mistakes.
        log("Computing realistic form types")
        form_types = formtype_model.get_realistic_form_labels(
            annotations=annotations, n_folds=10, full_type_names=full_form_type_names
        )

    log("Extracting features")
    X, y = get_Xy(annotations=annotations, form_types=form_types, full_type_names=full_field_type_names)

    crf = get_model(use_precise_form_types)

    if optimize_hyperparameters_iters != 0:
        if optimize_hyperparameters_iters < 50:
            warnings.warn(
                "RandomizedSearchCV n_iter is low, results may be unstable. "
                "Consider increasing optimize_hyperparameters_iters."
            )

        log("Finding best hyperparameters using randomized search")
        params_space = {"c1": scipy.stats.expon(scale=0.5), "c2": scipy.stats.expon(scale=0.05)}

        # ``iid`` was deprecated in scikit-learn 0.22 and removed in 0.24, so
        # passing it unconditionally raises TypeError on current releases.
        # Keep the original iid=False behaviour only where it is still accepted.
        search_kwargs = {}
        if "iid" in inspect.signature(RandomizedSearchCV).parameters:
            search_kwargs["iid"] = False

        rs = RandomizedSearchCV(
            crf,
            params_space,
            cv=get_annotation_folds(annotations, 5),
            verbose=verbose,
            n_jobs=-1,
            n_iter=optimize_hyperparameters_iters,
            scoring=scorer,
            **search_kwargs
        )
        rs.fit(X, y)
        crf = rs.best_estimator_
        log("Best hyperparameters: c1={:0.5f}, c2={:0.5f}".format(crf.c1, crf.c2))
    else:
        crf.fit(X, y)

    return crf
Example #2
0
def get_realistic_form_labels(annotations, n_folds=10, model=None,
                              full_type_names=True):
    """
    Return form type labels which form type detection model
    is likely to produce.

    Labels come from cross-validated prediction (``cross_val_predict``): each
    form is labelled by a model fitted on the other folds, so the result
    contains realistic classification mistakes rather than the annotations.

    :param annotations: annotated forms.
    :param n_folds: number of cross-validation folds.
    :param model: estimator to use; ``get_model()`` is used when None.
    :param full_type_names: forwarded to ``get_Xy``; selects full rather than
        short type names for the labels.
    :return: predictions as returned by ``cross_val_predict`` (presumably one
        label per annotation).
    """
    if model is None:
        model = get_model()

    # Pass the flag by keyword, like every other get_Xy call site, so the
    # boolean cannot be bound to a different positional parameter.
    X, y = get_Xy(annotations, full_type_names=full_type_names)
    folds = get_annotation_folds(annotations, n_folds)
    return cross_val_predict(model, X, y, cv=folds)
Example #3
0
def print_classification_report(annotations, n_folds=10, model=None):
    """ Evaluate model, print classification report

    Runs cross-validated prediction of form types and prints a per-class
    classification report followed by the overall accuracy.

    :param annotations: annotated forms; must be non-empty, because the label
        order is taken from ``annotations[0].form_schema.types``.
    :param n_folds: number of cross-validation folds.
    :param model: estimator to evaluate; ``get_model()`` is used when None.
    """
    if model is None:
        # FIXME: we're overfitting on hyperparameters - they should be chosen
        # using inner cross-validation, not set to fixed values beforehand.
        model = get_model()

    X, y = get_Xy(annotations, full_type_names=True)
    folds = get_annotation_folds(annotations, n_folds)
    y_pred = cross_val_predict(model, X, y, cv=folds)

    # Report every label seen in either the true or the predicted labels, in
    # schema order. Using only predicted labels would hide classes the model
    # never predicts (zero recall). Labels absent from the schema go last
    # instead of crashing after the (expensive) cross-validation.
    schema_pos = {label: i for i, label in enumerate(annotations[0].form_schema.types.keys())}
    labels = sorted(
        set(y) | set(y_pred),
        key=lambda k: (schema_pos.get(k, len(schema_pos)), k),
    )
    print(classification_report(y, y_pred, digits=2,
                                labels=labels, target_names=labels))

    print("{:0.1f}% forms are classified correctly.".format(
        accuracy_score(y, y_pred) * 100
    ))
Example #4
0
def print_classification_report(annotations, n_folds=10, model=None):
    """ Evaluate model, print classification report

    Runs cross-validated prediction of field types and prints a per-label
    report, the share of correctly classified fields, and the share of forms
    whose fields are all classified correctly.

    Form types given to ``get_Xy`` are the "realistic" predictions of
    ``formtype_model.get_realistic_form_labels``, not the annotated ones.

    :param annotations: annotated forms; only those with annotated fields are
        used, and at least one is required because the label order is taken
        from ``field_schema.types`` of the first one.
    :param n_folds: number of cross-validation folds.
    :param model: estimator to evaluate; ``get_model(use_precise_form_types=True)``
        is used when None.
    """
    if model is None:
        # FIXME: we're overfitting on hyperparameters - they should be chosen
        # using inner cross-validation, not set to fixed values beforehand.
        model = get_model(use_precise_form_types=True)

    annotations = [a for a in annotations if a.fields_annotated]
    form_types = formtype_model.get_realistic_form_labels(
        annotations=annotations, n_folds=n_folds, full_type_names=False
    )

    X, y = get_Xy(annotations=annotations, form_types=form_types, full_type_names=True)
    cv = get_annotation_folds(annotations, n_folds=n_folds)
    y_pred = cross_val_predict(model, X, y, cv=cv, n_jobs=-1)

    # Report every label seen in either the true or the predicted sequences,
    # in schema order. Using only predicted labels would hide field types the
    # model never predicts (zero recall). Labels absent from the schema go
    # last instead of crashing after the (expensive) cross-validation.
    schema_pos = {label: i for i, label in enumerate(annotations[0].field_schema.types.keys())}
    labels = sorted(
        set(flatten(y)) | set(flatten(y_pred)),
        key=lambda k: (schema_pos.get(k, len(schema_pos)), k),
    )
    print(flat_classification_report(y, y_pred, digits=2, labels=labels, target_names=labels))

    print("{:0.1f}% fields are classified correctly.".format(flat_accuracy_score(y, y_pred) * 100))
    print("All fields are classified correctly in {:0.1f}% forms.".format(sequence_accuracy_score(y, y_pred) * 100))