Beispiel #1
0
def test_query(query_strategy,
               n_features=50,
               n_sample=100,
               n_instances_list=(0, 1, 5, 50),
               n_train_idx=(0, 1, 5, 50)):
    """Check that a query strategy produces consistent queries.

    A random-forest classifier is fitted on random features with a balanced
    binary label vector. Then, for every combination of number of queried
    instances and number of already-labeled (training) samples, the query
    model is asked for new instances and the result is validated with
    ``check_integrity``.

    Arguments
    ---------
    query_strategy: str
        Name of the query strategy. It is also split on '_' to obtain the
        expected query sources.
    n_features: int
        Number of random features per sample.
    n_sample: int
        Total number of samples. Odd values are supported.
    n_instances_list: iterable of int
        Numbers of instances to request from the query model.
    n_train_idx: iterable of int
        Numbers of samples to remove from the pool as training data. Each
        must not exceed ``n_sample`` (np.random.choice without replacement).
    """
    classifier = get_classifier("rf")
    query_model = get_query_model(query_strategy)

    X = np.random.rand(n_sample, n_features)
    # Split labels into two halves; the ones get the remainder so that
    # len(y) == n_sample also for odd n_sample (otherwise y[order] fails).
    n_zeros = n_sample // 2
    y = np.concatenate((np.zeros(n_zeros), np.ones(n_sample - n_zeros)),
                       axis=0)

    # Shuffle X and y with the same permutation to keep them aligned.
    order = np.random.permutation(n_sample)
    X = X[order]
    y = y[order]
    sources = query_strategy.split('_')

    classifier.fit(X, y)

    assert isinstance(query_model.param, dict)
    assert query_model.name == query_strategy

    for n_instances in n_instances_list:
        for n_train in n_train_idx:
            # Fresh bookkeeping dict for every query call.
            shared = {"query_src": {}, "current_queries": {}}
            train_idx = np.random.choice(np.arange(n_sample),
                                         n_train,
                                         replace=False)
            pool_idx = np.delete(np.arange(n_sample), train_idx)
            query_idx, X_query = query_model.query(X, classifier, pool_idx,
                                                   n_instances, shared)
            check_integrity(query_idx, X_query, X, pool_idx, shared,
                            n_instances, sources)
Beispiel #2
0
    def from_file(self, config_file):
        """Fill the contents of settings by reading a config file.

        Arguments
        ---------
        config_file: str
            Source configuration file. If None, nothing happens. If the path
            does not point to an existing file, a message is printed and
            nothing is read.

        Notes
        -----
        - Keys in the ``[global_settings]`` section are converted with the
          matching callable in ``SETTINGS_TYPE_DICT`` and set as attributes.
          Unknown keys and values that fail conversion are reported and
          skipped.
        - The ``model_param``, ``query_param``, ``balance_param`` and
          ``feature_param`` sections are stored as plain dicts on the
          attribute of the same name.
        - Any other section except ``DEFAULT`` is reported and ignored.
        - Each parameter dict is then passed to ``_convert_types`` together
          with the ``default_param`` of the corresponding model.
        """
        if config_file is None or not os.path.isfile(config_file):
            if config_file is not None:
                print(f"Didn't find configuration file: {config_file}")
            return

        config = ConfigParser()
        # Keep option names case-sensitive (the default lower-cases them),
        # so they match attribute names / SETTINGS_TYPE_DICT keys exactly.
        config.optionxform = str
        config.read(config_file)

        # Read each of the sections.
        for sect in config:
            if sect == "global_settings":
                for key, value in config.items(sect):
                    try:
                        setattr(self, key, SETTINGS_TYPE_DICT[key](value))
                    # KeyError: unknown key; TypeError/ValueError: the value
                    # cannot be converted (e.g. int("abc") -> ValueError).
                    except (KeyError, TypeError, ValueError):
                        print(f"Warning: value with key '{key}' is ignored "
                              "(spelling mistake, wrong type?).")

            elif sect in [
                    "model_param", "query_param", "balance_param",
                    "feature_param"
            ]:
                setattr(self, sect, dict(config.items(sect)))
            elif sect != "DEFAULT":
                print(f"Warning: section [{sect}] is ignored in "
                      f"config file {config_file}")

        # Config values are strings; convert them per model defaults.
        model = get_classifier(self.model)
        _convert_types(model.default_param, self.model_param)
        balance_model = get_balance_model(self.balance_strategy)
        _convert_types(balance_model.default_param, self.balance_param)
        query_model = get_query_model(self.query_strategy)
        _convert_types(query_model.default_param, self.query_param)
        feature_model = get_feature_model(self.feature_extraction)
        _convert_types(feature_model.default_param, self.feature_param)
Beispiel #3
0
def get_reviewer(dataset,
                 mode="simulate",
                 model=DEFAULT_MODEL,
                 query_strategy=DEFAULT_QUERY_STRATEGY,
                 balance_strategy=DEFAULT_BALANCE_STRATEGY,
                 feature_extraction=DEFAULT_FEATURE_EXTRACTION,
                 n_instances=DEFAULT_N_INSTANCES,
                 n_papers=None,
                 n_queries=None,
                 embedding_fp=None,
                 verbose=0,
                 prior_idx=None,
                 prior_record_id=None,
                 n_prior_included=DEFAULT_N_PRIOR_INCLUDED,
                 n_prior_excluded=DEFAULT_N_PRIOR_EXCLUDED,
                 config_file=None,
                 state_file=None,
                 model_param=None,
                 query_param=None,
                 balance_param=None,
                 feature_param=None,
                 seed=None,
                 included_dataset=None,
                 excluded_dataset=None,
                 prior_dataset=None,
                 new=False,
                 **kwargs):
    """Get a review object from arguments.

    See __main__.py for a description of the arguments.

    Settings come from the arguments, optionally overridden by
    ``config_file``. When ``state_file`` holds a non-empty state, the
    settings stored in that state are used instead. Explicitly passed
    ``n_queries``, ``n_papers`` and ``*_param`` dicts always take precedence.
    ``verbose`` is accepted but not used here. Extra ``kwargs`` are
    forwarded to the review class constructor.

    Returns
    -------
    ReviewSimulate or MinimalReview
        Reviewer for the requested ``mode``.

    Raises
    ------
    ValueError
        If the data contains no records, if both ``prior_idx`` and
        ``prior_record_id`` are given (non-empty), or if ``mode`` is unknown.
    """
    # Avoid shared mutable defaults: use a fresh list per call.
    if included_dataset is None:
        included_dataset = []
    if excluded_dataset is None:
        excluded_dataset = []
    if prior_dataset is None:
        prior_dataset = []

    as_data = create_as_data(dataset,
                             included_dataset,
                             excluded_dataset,
                             prior_dataset,
                             new=new)

    if len(as_data) == 0:
        raise ValueError("Supply at least one dataset"
                         " with at least one record.")

    cli_settings = ASReviewSettings(model=model,
                                    n_instances=n_instances,
                                    n_queries=n_queries,
                                    n_papers=n_papers,
                                    n_prior_included=n_prior_included,
                                    n_prior_excluded=n_prior_excluded,
                                    query_strategy=query_strategy,
                                    balance_strategy=balance_strategy,
                                    feature_extraction=feature_extraction,
                                    mode=mode,
                                    data_fp=None)
    cli_settings.from_file(config_file)

    # An existing (non-empty) state keeps its own settings; an empty state
    # is initialized with the settings built above.
    if state_file is not None:
        with open_state(state_file) as state:
            if state.is_empty():
                state.settings = cli_settings
            settings = state.settings
    else:
        settings = cli_settings

    # Explicit arguments override whatever the settings contain.
    if n_queries is not None:
        settings.n_queries = n_queries
    if n_papers is not None:
        settings.n_papers = n_papers

    if model_param is not None:
        settings.model_param = model_param
    if query_param is not None:
        settings.query_param = query_param
    if balance_param is not None:
        settings.balance_param = balance_param
    if feature_param is not None:
        settings.feature_param = feature_param

    # Check if mode is valid
    if mode in AVAILABLE_REVIEW_CLASSES:
        logging.info("Start review in '%s' mode.", mode)
    else:
        raise ValueError(f"Unknown mode '{mode}'.")
    logging.debug(settings)

    # Initialize models.
    random_state = get_random_state(seed)
    train_model = get_classifier(settings.model,
                                 **settings.model_param,
                                 random_state=random_state)
    query_model = get_query_model(settings.query_strategy,
                                  **settings.query_param,
                                  random_state=random_state)
    balance_model = get_balance_model(settings.balance_strategy,
                                      **settings.balance_param,
                                      random_state=random_state)
    feature_model = get_feature_model(settings.feature_extraction,
                                      **settings.feature_param,
                                      random_state=random_state)

    # LSTM models need embedding matrices.
    if train_model.name.startswith("lstm-"):
        texts = as_data.texts
        train_model.embedding_matrix = feature_model.get_embedding_matrix(
            texts, embedding_fp)

    # prior knowledge
    if prior_idx is not None and prior_record_id is not None and \
            len(prior_idx) > 0 and len(prior_record_id) > 0:
        raise ValueError(
            "Not possible to provide both prior_idx and prior_record_id")
    if prior_record_id is not None and len(prior_record_id) > 0:
        prior_idx = convert_id_to_idx(as_data, prior_record_id)

    # Initialize the review class.
    if mode == "simulate":
        reviewer = ReviewSimulate(as_data,
                                  model=train_model,
                                  query_model=query_model,
                                  balance_model=balance_model,
                                  feature_model=feature_model,
                                  n_papers=settings.n_papers,
                                  n_instances=settings.n_instances,
                                  n_queries=settings.n_queries,
                                  prior_idx=prior_idx,
                                  n_prior_included=settings.n_prior_included,
                                  n_prior_excluded=settings.n_prior_excluded,
                                  state_file=state_file,
                                  **kwargs)
    elif mode == "minimal":
        # NOTE(review): prior_idx / n_prior_* are not forwarded in this
        # mode — confirm this is intended.
        reviewer = MinimalReview(as_data,
                                 model=train_model,
                                 query_model=query_model,
                                 balance_model=balance_model,
                                 feature_model=feature_model,
                                 n_papers=settings.n_papers,
                                 n_instances=settings.n_instances,
                                 n_queries=settings.n_queries,
                                 state_file=state_file,
                                 **kwargs)
    else:
        raise ValueError("Error finding mode, should never come here...")

    return reviewer