Beispiel #1
0
def train_and_save_distance_model(
    ethnicity_model_path,
    save_distance_model_path,
    sampled_pairs_size,
):
    """Fit a ``DistanceEstimator`` on curated signatures and write it to disk.

    The curated signatures are fetched, grouped into input clusters and
    sampled into pairs; the estimator is then loaded with that data, fitted
    and saved.

    Args:
        ethnicity_model_path (str): Full path of the saved ethnicity model,
            handed to ``EthnicityEstimator``.
        save_distance_model_path (str): Full path the fitted distance model
            is saved to.
        sampled_pairs_size (int): Number of pairs to sample for training.
            Note:
                Must be multiple of 12.

    Returns:
        None
    """
    LOGGER.info("Pulling training data from ES")
    signatures = get_signatures(only_curated=True)
    clusters = get_input_clusters(signatures)

    LOGGER.info(
        "Preparing %s pairs from sampled data for training.",
        sampled_pairs_size,
    )
    # Materialise the sampler output before handing it to the estimator.
    pair_source = sample_signature_pairs(signatures, clusters,
                                         sampled_pairs_size)
    training_pairs = list(pair_source)

    estimator = DistanceEstimator(EthnicityEstimator(ethnicity_model_path))
    estimator.load_data(signatures, training_pairs, sampled_pairs_size)

    LOGGER.info("Training DistanceEstimator...")
    estimator.fit()
    estimator.save_model(save_distance_model_path)
Beispiel #2
0
def train_and_save_distance_model(
    ethnicity_model_path,
    save_distance_model_path,
    sampled_pairs_size,
):
    """Fit a ``DistanceEstimator``, save it, and log how long each stage took.

    A wall-clock checkpoint (``datetime.now()``) is taken after every stage;
    the differences between consecutive checkpoints are logged as strings in
    a single summary event once the model has been saved.

    Args:
        ethnicity_model_path (str): Full path of the saved ethnicity model,
            handed to ``EthnicityEstimator``.
        save_distance_model_path (str): Full path the fitted distance model
            is saved to.
        sampled_pairs_size (int): Number of pairs to sample for training.
            Note:
                Must be multiple of 12.

    Returns:
        None
    """
    LOGGER.info("Pulling training data from ES")
    started_at = datetime.now()

    # Stage 1: fetch curated signatures and derive their input clusters.
    signatures = get_signatures(only_curated=True)
    clusters = get_input_clusters(signatures)
    input_ready_at = datetime.now()

    # Stage 2: sample the training pairs.
    LOGGER.info(
        "Preparing pairs from sampled data for training.",
        pairs_count=sampled_pairs_size,
    )
    pair_source = sample_signature_pairs(signatures, clusters,
                                         sampled_pairs_size)
    training_pairs = list(pair_source)
    pairs_ready_at = datetime.now()

    # Stage 3: build the estimators.
    estimator = DistanceEstimator(EthnicityEstimator(ethnicity_model_path))
    estimators_ready_at = datetime.now()

    # Stage 4: feed the data into the distance estimator.
    estimator.load_data(signatures, training_pairs, sampled_pairs_size)
    data_loaded_at = datetime.now()

    # Stage 5: fit.
    LOGGER.info("Training DistanceEstimator...")
    estimator.fit()
    fitted_at = datetime.now()

    # Stage 6: persist.
    estimator.save_model(save_distance_model_path)
    saved_at = datetime.now()

    LOGGER.info(
        "Train distance model",
        prepare_input_runtime=str(input_ready_at - started_at),
        prepare_pairs_runtime=str(pairs_ready_at - input_ready_at),
        prepare_estimators_runtime=str(estimators_ready_at - pairs_ready_at),
        load_data_runtime=str(data_loaded_at - estimators_ready_at),
        training_model_runtime=str(fitted_at - data_loaded_at),
        save_model_runtime=str(saved_at - fitted_at),
        total_runtime=str(saved_at - started_at),
    )
Beispiel #3
0
def train_and_save_distance_model(
    ethnicity_model_path,
    save_distance_model_path,
    sampled_pairs_size,
    train_to_validation_split_fraction=0.8,
):
    """Fit a ``DistanceEstimator`` on a training split, save it, then score it.

    Curated signatures are split into a training and a test part. Pairs are
    sampled separately from each part; the estimator is fitted on the training
    pairs and saved, after which the test pairs are loaded into the same
    estimator and ``score()`` is called. Stage runtimes and the score are
    logged in one summary event.

    Args:
        ethnicity_model_path (str): Full path of the saved ethnicity model,
            handed to ``EthnicityEstimator``.
        save_distance_model_path (str): Full path the fitted distance model
            is saved to.
        sampled_pairs_size (int): Number of training pairs to sample.
            Note:
                Must be multiple of 4.
        train_to_validation_split_fraction (float): Fraction of the data
            used for training.

    Returns:
        set: ``set()`` of the test-split mapping returned by
            ``train_validation_split`` (its keys, assuming it is a plain
            dict -- presumably signature identifiers; confirm against
            ``train_validation_split``).
    """
    started_at = datetime.now()
    all_signatures = get_signatures(only_curated=True)

    LOGGER.info(
        "Splitting data into training and test set.",
        training_set_fraction=train_to_validation_split_fraction,
    )
    split = train_validation_split(all_signatures,
                                   train_to_validation_split_fraction)
    train_by_key, test_by_key = split
    train_signatures = train_by_key.values()
    test_signatures = test_by_key.values()
    train_clusters = get_input_clusters(train_signatures)
    test_clusters = get_input_clusters(test_signatures)
    input_ready_at = datetime.now()

    LOGGER.info(
        "Preparing pairs from sampled data for training.",
        pairs_count=sampled_pairs_size,
    )
    train_pairs = list(
        sample_signature_pairs(train_signatures, train_clusters,
                               sampled_pairs_size))
    # Checkpoint is taken before the test pairs are sampled, so the time spent
    # sampling them is counted in ``prepare_estimators_runtime`` below.
    pairs_ready_at = datetime.now()

    # Scale the requested pair count by ((1 - f) / f) ** 2 and round up to the
    # next multiple of 4.
    split_ratio = ((1 - train_to_validation_split_fraction) /
                   train_to_validation_split_fraction)
    test_pair_count = 4 * math.ceil(
        (split_ratio**2 * sampled_pairs_size) / 4)
    test_pairs = list(
        sample_signature_pairs(test_signatures, test_clusters,
                               test_pair_count))
    LOGGER.info(
        "Pairs prepared.",
        n_training_pairs=len(train_pairs),
        n_test_pairs=len(test_pairs),
    )

    estimator = DistanceEstimator(EthnicityEstimator(ethnicity_model_path))
    estimators_ready_at = datetime.now()

    estimator.load_data(train_signatures, train_pairs, sampled_pairs_size)
    data_loaded_at = datetime.now()

    estimator.fit()
    fitted_at = datetime.now()

    estimator.save_model(save_distance_model_path)
    saved_at = datetime.now()

    # Re-use the fitted estimator for evaluation; this happens after the last
    # checkpoint, so scoring is not part of ``total_runtime``.
    estimator.load_data(test_signatures, test_pairs, test_pair_count)
    held_out_score = estimator.score()

    LOGGER.info(
        "Train distance model",
        prepare_input_runtime=str(input_ready_at - started_at),
        prepare_pairs_runtime=str(pairs_ready_at - input_ready_at),
        prepare_estimators_runtime=str(estimators_ready_at - pairs_ready_at),
        load_data_runtime=str(data_loaded_at - estimators_ready_at),
        training_model_runtime=str(fitted_at - data_loaded_at),
        save_model_runtime=str(saved_at - fitted_at),
        total_runtime=str(saved_at - started_at),
        test_score=str(held_out_score),
    )
    return set(test_by_key)
def test_distance_estimator_load_data(scan_mock,
                                      es_record_with_many_curated_authors):
    """Verify ``DistanceEstimator.load_data`` populates ``X`` and ``y``.

    ``scan_mock`` is primed to yield one batch holding the
    ``es_record_with_many_curated_authors`` fixture, and the signatures are
    read back through ``get_signatures()``. Four pairs are loaded (the first
    two flagged ``same_cluster=True``, the last two ``False``). The test
    expects ``X`` to hold the two ``Signature`` objects of each pair, in pair
    order, and ``y`` to equal ``[0, 0, 1, 1]``.
    """
    uuid_52 = "94fc2b0a-dc17-42c2-bae3-ca0024079e52"
    uuid_53 = "94fc2b0a-dc17-42c2-bae3-ca0024079e53"
    uuid_54 = "94fc2b0a-dc17-42c2-bae3-ca0024079e54"
    uuid_55 = "94fc2b0a-dc17-42c2-bae3-ca0024079e55"
    uuid_56 = "94fc2b0a-dc17-42c2-bae3-ca0024079e56"
    uuid_57 = "94fc2b0a-dc17-42c2-bae3-ca0024079e57"

    def expected_signature(
        signature_uuid,
        author_id,
        author_name="Doe, John",
        author_affiliation="Rutgers U., Piscataway",
        signature_block="JOhn",
    ):
        # All expected signatures carry an identical publication payload; a
        # fresh Publication is built on every call so no object is shared.
        publication = Publication(
            abstract="Many curated authors",
            authors=[
                "Doe, John",
                "Doe, J",
                "Doe, John",
                "Doe, John",
                "Doe, John",
                "Doe, John",
                "Jamie",
                "Jamie",
            ],
            collaborations=[],
            keywords=["keyword"],
            publication_id=1,
            title="Title",
            topics=["category"],
        )
        return Signature(
            author_affiliation=author_affiliation,
            author_id=author_id,
            author_name=author_name,
            publication=publication,
            signature_block=signature_block,
            signature_uuid=signature_uuid,
            is_curated_author_id=True,
        )

    scan_mock.side_effect = [[es_record_with_many_curated_authors]]
    signatures = get_signatures()

    # (same_cluster, first uuid, second uuid) for each pair fed to the model.
    pair_table = [
        (True, uuid_52, uuid_53),
        (True, uuid_54, uuid_55),
        (False, uuid_56, uuid_57),
        (False, uuid_52, uuid_54),
    ]
    pairs = [
        {"same_cluster": same_cluster, "signature_uuids": [first, second]}
        for same_cluster, first, second in pair_table
    ]

    distance_estimator = DistanceEstimator(None)
    distance_estimator.load_data(signatures, pairs, 4)

    # One row per pair, in the same order as ``pair_table``.
    expected_rows = [
        [
            expected_signature(uuid_52, 1),
            expected_signature(uuid_53, 1, author_name="Doe, J"),
        ],
        [
            expected_signature(uuid_54, 2),
            expected_signature(uuid_55, 2),
        ],
        [
            expected_signature(uuid_56, 6, author_affiliation=""),
            expected_signature(uuid_57, 7, author_name="Jamie",
                               signature_block="Jana"),
        ],
        [
            expected_signature(uuid_52, 1),
            expected_signature(uuid_54, 2),
        ],
    ]
    expected_X = array(expected_rows, dtype=object)
    expected_y = array([0, 0, 1, 1])

    assert (distance_estimator.X == expected_X).all()
    assert (distance_estimator.y == expected_y).all()