Example #1
import os

from dpu_utils.utils import RichPath

# model_restore_helper and assert_valid_str come from the surrounding
# project; their modules are not shown in this snippet.


def run_export(model_path: RichPath, test_data_path: RichPath,
               output_folder: str):
    # Disable dropout for inference and tag the run for logging.
    test_hyper_overrides = {
        'run_id': 'exporting',
        'dropout_keep_rate': 1.0,
    }

    data_chunks = test_data_path.get_filtered_files_in_dir('*gz')

    # Restore model
    model = model_restore_helper.restore(model_path,
                                         is_train=False,
                                         hyper_overrides=test_hyper_overrides)

    exporting = model.export_representations(data_chunks)

    os.makedirs(output_folder, exist_ok=True)
    # Write one representation per line in vectors.tsv, with an aligned
    # metadata row per annotation in metadata.tsv.
    with open(os.path.join(output_folder, 'vectors.tsv'), 'w') as vectors_file, \
            open(os.path.join(output_folder, 'metadata.tsv'), 'w') as metadata_file:

        metadata_file.write('varname\ttype\tkind\tprovenance\n')
        for annot in exporting:
            metadata_file.write(
                f'{assert_valid_str(annot.name)}\t'
                f'{assert_valid_str(annot.type_annotation)}\t'
                f'{assert_valid_str(annot.kind)}\t'
                f'{assert_valid_str(annot.provenance)}\n'
            )
            vectors_file.write('\t'.join(str(e) for e in annot.representation))
            vectors_file.write('\n')
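
A minimal invocation sketch for run_export; the paths below are hypothetical, and RichPath.create (from dpu_utils) also accepts remote paths such as Azure blob URIs:

from dpu_utils.utils import RichPath

model_path = RichPath.create('trained_models/model.pkl.gz')  # hypothetical path
test_data_path = RichPath.create('data/test')  # directory of *.gz chunks
run_export(model_path, test_data_path, output_folder='exports')

The resulting vectors.tsv/metadata.tsv pair is plain tab-separated output and can be loaded into, for example, the TensorBoard embedding projector.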
Example #2
import json
import os
from typing import Any, Dict, Optional, Type

from dpu_utils.utils import RichPath

# Model and model_restore_helper come from the surrounding project; their
# modules are not shown in this snippet.


def run_train(model_class: Type[Model],
              train_data_path: RichPath,
              valid_data_path: RichPath,
              save_folder: str,
              hyperparameters: Dict[str, Any],
              run_name: Optional[str] = None,
              quiet: bool = False) -> RichPath:
    train_data_chunk_paths = train_data_path.get_filtered_files_in_dir(
        'chunk_*')
    valid_data_chunk_paths = valid_data_path.get_filtered_files_in_dir(
        'chunk_*')

    model = model_class(hyperparameters, run_name=run_name,
                        model_save_dir=save_folder, log_save_dir=save_folder)  # pytype: disable=not-instantiable
    # Resume from an existing checkpoint if one is present; otherwise build
    # the model from scratch.
    if os.path.exists(model.model_save_path):
        model = model_restore_helper.restore(
            RichPath.create(model.model_save_path), is_train=True)
        model.train_log("Resuming training run %s of model %s with the following hypers:\n%s" % (
            hyperparameters['run_id'], model.__class__.__name__,
            json.dumps(hyperparameters, indent=2, sort_keys=True)))
        resume = True
    else:
        try:
            # Reuse precomputed metadata if it exists ...
            model.load_existing_metadata(
                train_data_path.join('metadata.pkl.gz'))
        except Exception:
            # ... otherwise compute it from the training data.
            model.load_metadata(train_data_path)
        model.make_model(is_train=True)

        model.train_log("Starting training run %s of model %s with the following hypers:\n%s" % (
            hyperparameters['run_id'], model.__class__.__name__,
            json.dumps(hyperparameters, indent=2, sort_keys=True)))
        resume = False
    model_path = model.train(train_data_chunk_paths,
                             valid_data_chunk_paths, quiet=quiet, resume=resume)
    return model_path
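
A sketch of a training call, assuming GraphModel is a concrete (hypothetical) Model subclass and using an illustrative hyperparameter dict; real runs would load the hypers from a config:

from dpu_utils.utils import RichPath

hypers = {'run_id': 'demo', 'dropout_keep_rate': 0.9}  # illustrative values only
best_model_path = run_train(GraphModel,  # hypothetical Model subclass
                            train_data_path=RichPath.create('data/train'),
                            valid_data_path=RichPath.create('data/valid'),
                            save_folder='trained_models',
                            hyperparameters=hypers)
print('Best checkpoint saved to', best_model_path)

Because run_train checks model.model_save_path first, re-running the same command resumes the existing run instead of starting over.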
Example #3
import json
import math
import os
import time

from dpu_utils.utils import RichPath

# model_restore_helper, TypePredictionEvaluator and ignore_type_annotation
# come from the surrounding project; their modules are not shown in this
# snippet.


def run_test(model_path: RichPath, test_data_path: RichPath,
             type_lattice_path: RichPath, alias_metadata_path: RichPath,
             print_predictions: bool = False):
    # Unique run id so that concurrent test runs do not clash.
    test_run_id = '_'.join(
        [time.strftime('%Y-%m-%d-%H-%M-%S'), str(os.getpid())])

    # Disable dropout for evaluation.
    test_hyper_overrides = {
        'run_id': test_run_id,
        'dropout_keep_rate': 1.0,
    }

    test_data_chunks = test_data_path.get_filtered_files_in_dir('*gz')

    # Restore model
    model = model_restore_helper.restore(
        model_path, is_train=False, hyper_overrides=test_hyper_overrides)

    evaluator = TypePredictionEvaluator(type_lattice_path, alias_metadata_path)

    all_annotations = model.annotate(test_data_chunks)
    for annotation in all_annotations:
        # Skip ground-truth annotations that the evaluation deliberately
        # ignores.
        if ignore_type_annotation(annotation.original_annotation):
            continue
        # Pick the most likely type under the predicted distribution.
        logprob_dist = annotation.predicted_annotation_logprob_dist
        predicted_annotation = max(logprob_dist, key=logprob_dist.get)
        if print_predictions:
            confidence = math.exp(logprob_dist[predicted_annotation]) * 100
            print(f'{annotation.provenance} -- {annotation.name}: '
                  f'{annotation.original_annotation} -> {predicted_annotation} '
                  f'({confidence:.1f}%)')
        evaluator.add_sample(ground_truth=annotation.original_annotation,
                             predicted_dist=logprob_dist)

    print(json.dumps(evaluator.metrics(), indent=2, sort_keys=True))
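
A sketch of an evaluation call; all four paths are hypothetical, and the type-lattice and alias-metadata files are assumed to come from the project's data-preparation step:

from dpu_utils.utils import RichPath

run_test(model_path=RichPath.create('trained_models/model.pkl.gz'),
         test_data_path=RichPath.create('data/test'),
         type_lattice_path=RichPath.create('metadata/type_lattice.json.gz'),
         alias_metadata_path=RichPath.create('metadata/aliases.json.gz'),
         print_predictions=True)

With print_predictions=True, each line shows the annotation's provenance and name, the ground truth, the predicted type, and the model's confidence; the aggregate metrics are printed as JSON at the end.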
Example #4
from dpu_utils.utils import RichPath

# model_restore_helper comes from the surrounding project; its module is
# not shown in this snippet.


def run_indexing(model_path: RichPath, index_data_path: RichPath):
    # Disable dropout while computing representations for the index.
    test_hyper_overrides = {
        'run_id': 'indexing',
        'dropout_keep_rate': 1.0,
    }

    data_chunks = index_data_path.get_filtered_files_in_dir('*.jsonl.gz')

    # Restore model
    model = model_restore_helper.restore(model_path,
                                         is_train=False,
                                         hyper_overrides=test_hyper_overrides)

    # Build the search index over the given data and store it with the
    # model, overwriting the checkpoint at model_path.
    model.create_index(data_chunks)
    model.save(model_path)
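
A sketch of an indexing call; the paths are hypothetical, and note that run_indexing saves the model back to model_path, so the checkpoint is updated in place:

from dpu_utils.utils import RichPath

run_indexing(model_path=RichPath.create('trained_models/model.pkl.gz'),
             index_data_path=RichPath.create('data/index'))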
Example #5
import os
import time

from dpu_utils.utils import RichPath

# model_restore_helper and ignore_annotation come from the surrounding
# project; their modules are not shown in this snippet.


def run_predict(model_path: RichPath, test_data_path: RichPath,
                output_file: RichPath):
    # Unique run id so that concurrent prediction runs do not clash.
    test_run_id = '_'.join(
        [time.strftime('%Y-%m-%d-%H-%M-%S'), str(os.getpid())])

    # Disable dropout for inference.
    test_hyper_overrides = {
        'run_id': test_run_id,
        'dropout_keep_rate': 1.0,
    }

    test_data_chunks = test_data_path.get_filtered_files_in_dir('*.jsonl.gz')

    # Restore model
    model = model_restore_helper.restore(model_path,
                                         is_train=False,
                                         hyper_overrides=test_hyper_overrides)

    def predictions():
        for annotation in model.annotate(test_data_chunks):
            if ignore_annotation(annotation.original_annotation):
                continue
            # Keep only the ten most likely predicted types.
            logprob_dist = annotation.predicted_annotation_logprob_dist
            ordered_annotation_predictions = sorted(
                logprob_dist, key=logprob_dist.get, reverse=True)[:10]

            annotation_dict = annotation._asdict()
            filtered_logprobs = []
            for annot in ordered_annotation_predictions:
                logprob = float(logprob_dist[annot])
                # Map the model's unknown-type markers to typing.Any.
                if annot in ('%UNK%', '%UNKNOWN%'):
                    annot = 'typing.Any'
                filtered_logprobs.append((annot, logprob))
            annotation_dict['predicted_annotation_logprob_dist'] = \
                filtered_logprobs

            yield annotation_dict

    output_file.save_as_compressed_file(predictions())
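
A sketch of a prediction run; the paths are hypothetical, and save_as_compressed_file (from dpu_utils) picks the serialisation format from the output path's extension, so a .jsonl.gz target yields one JSON object per annotation:

from dpu_utils.utils import RichPath

run_predict(model_path=RichPath.create('trained_models/model.pkl.gz'),
            test_data_path=RichPath.create('data/test'),
            output_file=RichPath.create('predictions.jsonl.gz'))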