コード例 #1
0
ファイル: factory.py プロジェクト: fransfolkvord/asreview
def get_reviewer(dataset,
                 mode="simulate",
                 model=DEFAULT_MODEL,
                 query_strategy=DEFAULT_QUERY_STRATEGY,
                 balance_strategy=DEFAULT_BALANCE_STRATEGY,
                 feature_extraction=DEFAULT_FEATURE_EXTRACTION,
                 n_instances=DEFAULT_N_INSTANCES,
                 n_papers=None,
                 n_queries=None,
                 embedding_fp=None,
                 verbose=0,
                 prior_idx=None,
                 prior_record_id=None,
                 n_prior_included=DEFAULT_N_PRIOR_INCLUDED,
                 n_prior_excluded=DEFAULT_N_PRIOR_EXCLUDED,
                 config_file=None,
                 state_file=None,
                 model_param=None,
                 query_param=None,
                 balance_param=None,
                 feature_param=None,
                 seed=None,
                 included_dataset=None,
                 excluded_dataset=None,
                 prior_dataset=None,
                 new=False,
                 **kwargs):
    """Build a review object from the supplied arguments.

    See __main__.py for a description of the arguments.

    The settings are assembled in this order: the keyword arguments are
    collected in an ``ASReviewSettings`` object, which is then updated from
    ``config_file``. When ``state_file`` is given and the state is not empty,
    the settings stored in the state are used instead. Finally, explicitly
    supplied ``n_queries``, ``n_papers`` and ``*_param`` values override
    whatever settings were selected.

    Parameters
    ----------
    dataset:
        Passed to ``create_as_data`` together with ``included_dataset``,
        ``excluded_dataset``, ``prior_dataset`` and ``new``.
    mode:
        Review mode; must be a key of ``AVAILABLE_REVIEW_CLASSES``. This
        function only knows how to build "simulate" and "minimal".
    model, query_strategy, balance_strategy, feature_extraction:
        Names used to look up the classifier, query model, balance model
        and feature model.
    n_instances, n_papers, n_queries:
        Forwarded (through the settings) to the review class.
    embedding_fp:
        Passed to ``feature_model.get_embedding_matrix`` when the classifier
        name starts with "lstm-".
    verbose:
        Not used by this function.
    prior_idx, prior_record_id:
        Prior knowledge, given either as indices or as record ids (not
        both). Record ids are converted with ``convert_id_to_idx``. Only
        forwarded in "simulate" mode.
    n_prior_included, n_prior_excluded:
        Forwarded (through the settings) in "simulate" mode.
    config_file:
        Passed to ``ASReviewSettings.from_file``.
    state_file:
        Optional state file; opened with ``open_state`` and forwarded to the
        review class.
    model_param, query_param, balance_param, feature_param:
        When not None, replace the corresponding ``*_param`` settings, which
        are unpacked as keyword arguments into the model factories.
    seed:
        Passed to ``get_random_state``.
    included_dataset, excluded_dataset, prior_dataset:
        Additional datasets for ``create_as_data``; ``None`` means an empty
        list.
    new:
        Forwarded to ``create_as_data``.
    **kwargs:
        Forwarded to the review class constructor.

    Returns
    -------
    ReviewSimulate or MinimalReview
        The initialized review object for the requested mode.

    Raises
    ------
    ValueError
        If the combined data contains no records, if ``mode`` is unknown, or
        if both ``prior_idx`` and ``prior_record_id`` are non-empty.
    """
    # Use None sentinels instead of mutable list defaults so that no list
    # object is shared between calls.
    if included_dataset is None:
        included_dataset = []
    if excluded_dataset is None:
        excluded_dataset = []
    if prior_dataset is None:
        prior_dataset = []

    as_data = create_as_data(dataset,
                             included_dataset,
                             excluded_dataset,
                             prior_dataset,
                             new=new)

    if len(as_data) == 0:
        raise ValueError("Supply at least one dataset"
                         " with at least one record.")

    cli_settings = ASReviewSettings(model=model,
                                    n_instances=n_instances,
                                    n_queries=n_queries,
                                    n_papers=n_papers,
                                    n_prior_included=n_prior_included,
                                    n_prior_excluded=n_prior_excluded,
                                    query_strategy=query_strategy,
                                    balance_strategy=balance_strategy,
                                    feature_extraction=feature_extraction,
                                    mode=mode,
                                    data_fp=None)
    cli_settings.from_file(config_file)

    if state_file is not None:
        with open_state(state_file) as state:
            # A fresh state gets the CLI settings; an existing state keeps
            # the settings it was created with.
            if state.is_empty():
                state.settings = cli_settings
            settings = state.settings
    else:
        settings = cli_settings

    # Explicitly supplied stopping criteria override stored settings.
    if n_queries is not None:
        settings.n_queries = n_queries
    if n_papers is not None:
        settings.n_papers = n_papers

    if model_param is not None:
        settings.model_param = model_param
    if query_param is not None:
        settings.query_param = query_param
    if balance_param is not None:
        settings.balance_param = balance_param
    if feature_param is not None:
        settings.feature_param = feature_param

    # Check if mode is valid
    if mode in AVAILABLE_REVIEW_CLASSES:
        logging.info(f"Start review in '{mode}' mode.")
    else:
        raise ValueError(f"Unknown mode '{mode}'.")
    logging.debug(settings)

    # Initialize models; all of them share the same random state.
    random_state = get_random_state(seed)
    train_model = get_classifier(settings.model,
                                 **settings.model_param,
                                 random_state=random_state)
    query_model = get_query_model(settings.query_strategy,
                                  **settings.query_param,
                                  random_state=random_state)
    balance_model = get_balance_model(settings.balance_strategy,
                                      **settings.balance_param,
                                      random_state=random_state)
    feature_model = get_feature_model(settings.feature_extraction,
                                      **settings.feature_param,
                                      random_state=random_state)

    # LSTM models need embedding matrices.
    if train_model.name.startswith("lstm-"):
        texts = as_data.texts
        train_model.embedding_matrix = feature_model.get_embedding_matrix(
            texts, embedding_fp)

    # Prior knowledge: indices and record ids are mutually exclusive.
    if prior_idx is not None and prior_record_id is not None and \
            len(prior_idx) > 0 and len(prior_record_id) > 0:
        raise ValueError(
            "Not possible to provide both prior_idx and prior_record_id")
    if prior_record_id is not None and len(prior_record_id) > 0:
        prior_idx = convert_id_to_idx(as_data, prior_record_id)

    # Initialize the review class.
    if mode == "simulate":
        reviewer = ReviewSimulate(as_data,
                                  model=train_model,
                                  query_model=query_model,
                                  balance_model=balance_model,
                                  feature_model=feature_model,
                                  n_papers=settings.n_papers,
                                  n_instances=settings.n_instances,
                                  n_queries=settings.n_queries,
                                  prior_idx=prior_idx,
                                  n_prior_included=settings.n_prior_included,
                                  n_prior_excluded=settings.n_prior_excluded,
                                  state_file=state_file,
                                  **kwargs)
    elif mode == "minimal":
        reviewer = MinimalReview(as_data,
                                 model=train_model,
                                 query_model=query_model,
                                 balance_model=balance_model,
                                 feature_model=feature_model,
                                 n_papers=settings.n_papers,
                                 n_instances=settings.n_instances,
                                 n_queries=settings.n_queries,
                                 state_file=state_file,
                                 **kwargs)
    else:
        # Reached only for a mode that is registered in
        # AVAILABLE_REVIEW_CLASSES but not handled above.
        raise ValueError("Error finding mode, should never come here...")

    return reviewer
コード例 #2
0
def train_model(project_id, label_method=None):
    """Add the new labels to the review and do the modeling.

    A non-blocking "training" lock ensures only one model is being trained
    for the project at the same time; if it cannot be acquired, the function
    returns without doing anything. The function also returns early when the
    label history contains nothing beyond what the state has already been
    trained on.

    Otherwise the new labels are classified, the model is retrained, the
    probabilities and the current query are logged to the state, and the
    pool (re-ordered according to the new query) and the probabilities are
    written back to the project.

    Parameters
    ----------
    project_id:
        Identifier of the project; used to resolve the kwargs, lock, data and
        state file locations.
    label_method:
        Passed as ``method`` to ``reviewer.classify``.

    Returns
    -------
    None
    """

    logging.info(f"Project {project_id} - Train a new model for project")

    # get file locations
    asr_kwargs_file = get_kwargs_path(project_id)
    lock_file = get_lock_path(project_id)

    # Lock so that only one training run is running at the same time.
    # It doesn't lock the flask server/client.
    with SQLiteLock(lock_file,
                    blocking=False,
                    lock_name="training",
                    project_id=project_id) as lock:

        # If the lock is not acquired, another training instance is running.
        if not lock.locked():
            logging.info(f"Project {project_id} - "
                         "Cannot acquire lock, other instance running.")
            return

        # Lock the current state. We want to have a consistent active state.
        # This does communicate with the flask backend; it prevents writing and
        # reading to the same files at the same time.
        with SQLiteLock(lock_file,
                        blocking=True,
                        lock_name="active",
                        project_id=project_id):
            # Get all labels; they are compared with the state's training
            # history below to find the ones added since the last run.
            new_label_history = read_label_history(project_id)

        data_fp = str(get_data_file_path(project_id))
        as_data = read_data(project_id)
        state_file = get_state_path(project_id)

        # collect command line arguments and pass them to the reviewer
        with open(asr_kwargs_file, "r") as fp:
            asr_kwargs = json.load(fp)
        asr_kwargs['state_file'] = str(state_file)
        reviewer = get_reviewer(dataset=data_fp, mode="minimal", **asr_kwargs)

        with open_state(state_file) as state:
            old_label_history = _get_label_train_history(state)

        diff_history = _get_diff_history(new_label_history, old_label_history)

        # If no new labels, quit.
        if len(diff_history) == 0:
            logging.info(
                f"Project {project_id} - No new labels since last run.")
            return

        # Each history entry is indexed as (record_id, inclusion).
        query_record_ids = np.array([x[0] for x in diff_history], dtype=int)
        inclusions = np.array([x[1] for x in diff_history], dtype=int)

        query_idx = convert_id_to_idx(as_data, query_record_ids)

        # Classify the new labels, train and store the results.
        with open_state(state_file) as state:
            reviewer.classify(query_idx,
                              inclusions,
                              state,
                              method=label_method)
            reviewer.train()
            reviewer.log_probabilities(state)
            # Query the whole pool so the result gives an ordering of all
            # remaining records.
            new_query_idx = reviewer.query(reviewer.n_pool()).tolist()
            reviewer.log_current_query(state)

            # write the proba to a pandas dataframe with record_ids as index
            proba = pd.DataFrame({"proba": state.pred_proba.tolist()},
                                 index=pd.Index(as_data.record_ids,
                                                name="record_id"))

        # update the pool and output the proba's
        # important: pool is sorted on query
        with SQLiteLock(lock_file,
                        blocking=True,
                        lock_name="active",
                        project_id=project_id):

            # read the pool
            current_pool = read_pool(project_id)

            # Keep only queried indices that are still in the current pool,
            # preserving the query order.
            current_pool_idx = convert_id_to_idx(as_data, current_pool)
            current_pool_idx = frozenset(current_pool_idx)
            new_pool_idx = [x for x in new_query_idx if x in current_pool_idx]

            # convert new_pool_idx back to record_ids
            new_pool = convert_idx_to_id(as_data, new_pool_idx)

            # write the pool and proba
            write_pool(project_id, new_pool)
            write_proba(project_id, proba)