Example #1
    def __init__(self,
                 strategy_1="max",
                 strategy_2="random",
                 mix_ratio=0.95,
                 random_state=None,
                 **kwargs):
        """Initialize the Mixed query strategy."""
        super(MixedQuery, self).__init__()
        kwargs_1 = {}
        kwargs_2 = {}
        for key, value in kwargs.items():
            if key.startswith(strategy_1):
                new_key = key[len(strategy_1) + 1:]
                kwargs_1[new_key] = value
            elif key.startswith(strategy_2):
                new_key = key[len(strategy_2) + 1:]
                kwargs_2[new_key] = value
            else:
                logging.warn(f"Key {key} is being ignored for the mixed "
                             "({strategy_1}, {strategy_2}) query strategy.")

        self.strategy_1 = strategy_1
        self.strategy_2 = strategy_2

        self.query_model1 = get_query_model(strategy_1, **kwargs_1)
        self.query_model2 = get_query_model(strategy_2, **kwargs_2)

        self._random_state = get_random_state(random_state)
        if "random_state" in self.query_model1.default_param:
            self.query_model1 = get_query_model(
                strategy_1, **kwargs_1, random_state=self._random_state)
        if "random_state" in self.query_model2.default_param:
            self.query_model2 = get_query_model(
                strategy_2, **kwargs_2, random_state=self._random_state)
        self.mix_ratio = mix_ratio
Example #2
File: mixed.py  Project: openefsa/asreview
    def __init__(self,
                 strategy_1="max",
                 strategy_2="random",
                 mix_ratio=0.95,
                 random_state=None,
                 **kwargs):
        """Initialize the Mixed query strategy

        Arguments
        ---------
        strategy_1: str
            Name of the first query strategy.
        strategy_2: str
            Name of the second query strategy.
        mix_ratio: float
            Fraction of queries assigned to the first strategy. A mix_ratio of
            0.95 means that strategy 1 is used 95% of the time and strategy 2
            the remaining 5%.
        **kwargs: dict
            Keyword arguments for the two strategies. To specify which of the
            strategies the argument is for, prepend with the name of the query
            strategy and an underscore, e.g. 'max_' for maximal sampling.
        """
        super(MixedQuery, self).__init__()
        kwargs_1 = {}
        kwargs_2 = {}
        for key, value in kwargs.items():
            if key.startswith(strategy_1):
                new_key = key[len(strategy_1) + 1:]
                kwargs_1[new_key] = value
            elif key.startswith(strategy_2):
                new_key = key[len(strategy_2) + 1:]
                kwargs_2[new_key] = value
            else:
                logging.warn(f"Key {key} is being ignored for the mixed "
                             "({strategy_1}, {strategy_2}) query strategy.")

        self.strategy_1 = strategy_1
        self.strategy_2 = strategy_2

        self.query_model1 = get_query_model(strategy_1, **kwargs_1)
        self.query_model2 = get_query_model(strategy_2, **kwargs_2)

        self._random_state = get_random_state(random_state)
        if "random_state" in self.query_model1.default_param:
            self.query_model1 = get_query_model(
                strategy_1, **kwargs_1, random_state=self._random_state)
        if "random_state" in self.query_model2.default_param:
            self.query_model2 = get_query_model(
                strategy_2, **kwargs_2, random_state=self._random_state)
        self.mix_ratio = mix_ratio
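The kwargs-splitting above routes prefixed keyword arguments to the matching sub-strategy. A minimal, self-contained sketch of just that convention (the parameter names max_alpha and random_fraction below are hypothetical; only the "max_"/"random_" prefixes matter):

# Hypothetical usage sketch: keys prefixed with the strategy name plus an
# underscore are stripped of the prefix and forwarded to that strategy.
kwargs = {"max_alpha": 1.0, "random_fraction": 0.5}
strategy_1, strategy_2 = "max", "random"
kwargs_1 = {k[len(strategy_1) + 1:]: v
            for k, v in kwargs.items() if k.startswith(strategy_1)}
kwargs_2 = {k[len(strategy_2) + 1:]: v
            for k, v in kwargs.items() if k.startswith(strategy_2)}
print(kwargs_1)  # {'alpha': 1.0}
print(kwargs_2)  # {'fraction': 0.5}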
Example #3
def test_query(query_strategy,
               n_features=50,
               n_sample=100,
               n_instances_list=[0, 1, 5, 50],
               n_train_idx=[0, 1, 5, 50]):
    classifier = get_model("rf")
    if query_strategy == "cluster":
        data_fp = os.path.join("test", "demo_data", "generic.csv")
        texts = ASReviewData.from_file(data_fp).texts
        while len(texts) < n_features:
            texts = np.append(texts, texts)
            print(len(texts))
        texts = texts[:n_features]
        query_model = get_query_model(query_strategy,
                                      texts=texts,
                                      update_interval=None,
                                      cluster_size=int(n_sample / 3))
        assert isinstance(query_model.param, dict)
    else:
        query_model = get_query_model(query_strategy)
    X = np.random.rand(n_sample, n_features)

    y = np.concatenate((np.zeros(n_sample // 2), np.ones(n_sample // 2)),
                       axis=0)
    print(X.shape, y.shape)
    order = np.random.permutation(n_sample)
    print(order.shape)
    X = X[order]
    y = y[order]
    sources = query_strategy.split('_')

    classifier.fit(X, y)

    assert isinstance(query_model.param, dict)
    assert query_model.name == query_strategy

    for n_instances in n_instances_list:
        for n_train in n_train_idx:
            shared = {"query_src": {}, "current_queries": {}}
            train_idx = np.random.choice(np.arange(n_sample),
                                         n_train,
                                         replace=False)
            pool_idx = np.delete(np.arange(n_sample), train_idx)
            query_idx, X_query = query_model.query(X, classifier, pool_idx,
                                                   n_instances, shared)
            check_integrity(query_idx, X_query, X, pool_idx, shared,
                            n_instances, sources)
Example #4
def test_query(query_strategy,
               n_features=50,
               n_sample=100,
               n_instances_list=[0, 1, 5, 50],
               n_train_idx=[0, 1, 5, 50]):
    classifier = get_model("rf")

    query_model = get_query_model(query_strategy)
    X = np.random.rand(n_sample, n_features)

    y = np.concatenate((np.zeros(n_sample // 2), np.ones(n_sample // 2)),
                       axis=0)
    print(X.shape, y.shape)
    order = np.random.permutation(n_sample)
    print(order.shape)
    X = X[order]
    y = y[order]
    sources = query_strategy.split('_')

    classifier.fit(X, y)

    assert isinstance(query_model.param, dict)
    assert query_model.name == query_strategy

    for n_instances in n_instances_list:
        for n_train in n_train_idx:
            shared = {"query_src": {}, "current_queries": {}}
            train_idx = np.random.choice(np.arange(n_sample),
                                         n_train,
                                         replace=False)
            pool_idx = np.delete(np.arange(n_sample), train_idx)
            query_idx, X_query = query_model.query(X, classifier, pool_idx,
                                                   n_instances, shared)
            check_integrity(query_idx, X_query, X, pool_idx, shared,
                            n_instances, sources)
Example #5
    def get_hyper_space(self):
        model_hs, model_hc = get_model(self.model_name).hyper_space()
        query_hs, query_hc = get_query_model(self.query_name).hyper_space()
        balance_hs, balance_hc = get_balance_model(
            self.balance_name).hyper_space()
        feature_hs, feature_hc = get_feature_model(
            self.feature_name).hyper_space()
        hyper_space = {**model_hs, **query_hs, **balance_hs, **feature_hs}
        hyper_choices = {**model_hc, **query_hc, **balance_hc, **feature_hc}
        return hyper_space, hyper_choices
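The {**a, **b} merges above flatten the per-component search spaces into a single dictionary; on duplicate keys, later dictionaries win. A toy illustration (the keys and ranges are made up):

model_hs = {"mdl_alpha": (0.1, 10.0)}
query_hs = {"qry_mix_ratio": (0.0, 1.0)}
hyper_space = {**model_hs, **query_hs}
print(hyper_space)
# {'mdl_alpha': (0.1, 10.0), 'qry_mix_ratio': (0.0, 1.0)}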
Example #6
    def from_file(self, config_file):
        """ Fill the contents of settings by reading a config file.

        Arguments
        ---------
        config_file: str
            Source configuration file.

        """
        if config_file is None or not os.path.isfile(config_file):
            if config_file is not None:
                print(f"Didn't find configuration file: {config_file}")
            return

        config = ConfigParser()
        config.optionxform = str
        config.read(config_file)

        # Read each of the sections.
        for sect in config:
            if sect == "global_settings":
                for key, value in config.items(sect):
                    try:
                        setattr(self, key, SETTINGS_TYPE_DICT[key](value))
                    except (KeyError, TypeError):
                        print(f"Warning: value with key '{key}' is ignored "
                              "(spelling mistake, wrong type?).")

            elif sect in [
                    "model_param", "query_param", "balance_param",
                    "feature_param"
            ]:
                setattr(self, sect, dict(config.items(sect)))
            elif sect != "DEFAULT":
                print(f"Warning: section [{sect}] is ignored in "
                      f"config file {config_file}")

        model = get_model(self.model)
        _convert_types(model.default_param, self.model_param)
        balance_model = get_balance_model(self.balance_strategy)
        _convert_types(balance_model.default_param, self.balance_param)
        query_model = get_query_model(self.query_strategy)
        _convert_types(query_model.default_param, self.query_param)
        feature_model = get_feature_model(self.feature_extraction)
        _convert_types(feature_model.default_param, self.feature_param)
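A minimal sketch of a config file that from_file above would accept: one [global_settings] section plus optional *_param sections. The keys mirror those the method reads; the values here are hypothetical. Using the standard-library ConfigParser to keep the sketch runnable:

from configparser import ConfigParser

config = ConfigParser()
config.optionxform = str  # preserve key case, as from_file does
config.read_string("""
[global_settings]
model = nb
query_strategy = max

[model_param]
alpha = 3.822
""")
for sect in config.sections():
    print(sect, dict(config.items(sect)))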
Example #7
def get_reviewer(dataset,
                 mode="simulate",
                 model=DEFAULT_MODEL,
                 query_strategy=DEFAULT_QUERY_STRATEGY,
                 balance_strategy=DEFAULT_BALANCE_STRATEGY,
                 feature_extraction=DEFAULT_FEATURE_EXTRACTION,
                 n_instances=DEFAULT_N_INSTANCES,
                 n_papers=None,
                 n_queries=None,
                 embedding_fp=None,
                 verbose=0,
                 prior_idx=None,
                 n_prior_included=DEFAULT_N_PRIOR_INCLUDED,
                 n_prior_excluded=DEFAULT_N_PRIOR_EXCLUDED,
                 config_file=None,
                 state_file=None,
                 model_param=None,
                 query_param=None,
                 balance_param=None,
                 feature_param=None,
                 abstract_only=False,
                 included_dataset=[],
                 excluded_dataset=[],
                 prior_dataset=[],
                 new=False,
                 **kwargs):
    """Get a review object from arguments.

    See __main__.py for a description of the arguments.
    """
    as_data = create_as_data(dataset,
                             included_dataset,
                             excluded_dataset,
                             prior_dataset,
                             new=new)

    if len(as_data) == 0:
        raise ValueError("Supply at least one dataset"
                         " with at least one record.")

    cli_settings = ASReviewSettings(model=model,
                                    n_instances=n_instances,
                                    n_queries=n_queries,
                                    n_papers=n_papers,
                                    n_prior_included=n_prior_included,
                                    n_prior_excluded=n_prior_excluded,
                                    query_strategy=query_strategy,
                                    balance_strategy=balance_strategy,
                                    feature_extraction=feature_extraction,
                                    mode=mode,
                                    data_fp=None,
                                    abstract_only=abstract_only)
    cli_settings.from_file(config_file)

    if state_file is not None:
        with open_state(state_file) as state:
            if state.is_empty():
                state.settings = cli_settings
            settings = state.settings
    else:
        settings = cli_settings

    if n_queries is not None:
        settings.n_queries = n_queries
    if n_papers is not None:
        settings.n_papers = n_papers

    if model_param is not None:
        settings.model_param = model_param
    if query_param is not None:
        settings.query_param = query_param
    if balance_param is not None:
        settings.balance_param = balance_param
    if feature_param is not None:
        settings.feature_param = feature_param

    # Check if mode is valid
    if mode in AVAILABLE_REVIEW_CLASSES:
        logging.info(f"Start review in '{mode}' mode.")
    else:
        raise ValueError(f"Unknown mode '{mode}'.")
    logging.debug(settings)

    train_model = get_model(settings.model, **settings.model_param)
    query_model = get_query_model(settings.query_strategy,
                                  **settings.query_param)
    balance_model = get_balance_model(settings.balance_strategy,
                                      **settings.balance_param)
    feature_model = get_feature_model(settings.feature_extraction,
                                      **settings.feature_param)

    if train_model.name.startswith("lstm-"):
        texts = as_data.texts
        train_model.embedding_matrix = feature_model.get_embedding_matrix(
            texts, embedding_fp)

    # Initialize the review class.
    if mode == "simulate":
        reviewer = ReviewSimulate(as_data,
                                  model=train_model,
                                  query_model=query_model,
                                  balance_model=balance_model,
                                  feature_model=feature_model,
                                  n_papers=settings.n_papers,
                                  n_instances=settings.n_instances,
                                  n_queries=settings.n_queries,
                                  verbose=verbose,
                                  prior_idx=prior_idx,
                                  n_prior_included=settings.n_prior_included,
                                  n_prior_excluded=settings.n_prior_excluded,
                                  state_file=state_file,
                                  data_fp=dataset,
                                  **kwargs)
    elif mode == "minimal":
        reviewer = MinimalReview(as_data,
                                 model=train_model,
                                 query_model=query_model,
                                 balance_model=balance_model,
                                 feature_model=feature_model,
                                 n_papers=settings.n_papers,
                                 n_instances=settings.n_instances,
                                 n_queries=settings.n_queries,
                                 verbose=verbose,
                                 state_file=state_file,
                                 data_fp=dataset,
                                 **kwargs)
    else:
        raise ValueError("Error finding mode, should never come here...")

    return reviewer
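A hedged usage sketch for this version of get_reviewer. Every keyword argument below appears in the signature above; the dataset path and state-file name are hypothetical, and the final call assumes the returned object exposes the standard asreview review() entry point:

# Hypothetical call: simulate a review with a naive Bayes model and
# certainty-based ("max") querying, persisting progress to a state file.
reviewer = get_reviewer("my_dataset.csv",        # hypothetical path
                        mode="simulate",
                        model="nb",
                        query_strategy="max",
                        n_instances=20,
                        state_file="simulation.h5")
reviewer.review()  # assumed BaseReview-style entry point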
Example #8
def get_reviewer(dataset,
                 mode='oracle',
                 model=DEFAULT_MODEL,
                 query_strategy=DEFAULT_QUERY_STRATEGY,
                 balance_strategy=DEFAULT_BALANCE_STRATEGY,
                 feature_extraction=DEFAULT_FEATURE_EXTRACTION,
                 n_instances=DEFAULT_N_INSTANCES,
                 n_papers=None,
                 n_queries=None,
                 embedding_fp=None,
                 verbose=0,
                 prior_included=None,
                 prior_excluded=None,
                 n_prior_included=DEFAULT_N_PRIOR_INCLUDED,
                 n_prior_excluded=DEFAULT_N_PRIOR_EXCLUDED,
                 config_file=None,
                 log_file=None,
                 model_param=None,
                 query_param=None,
                 balance_param=None,
                 feature_param=None,
                 abstract_only=False,
                 extra_dataset=[],
                 **kwargs
                 ):
    """ Get a review object from arguments. See __main__.py for a description
        Of the arguments.
    """

    # Find the URL of the datasets if the dataset is an example dataset.
    if dataset in DEMO_DATASETS.keys():
        dataset = DEMO_DATASETS[dataset]

    cli_settings = ASReviewSettings(
        model=model, n_instances=n_instances, n_queries=n_queries,
        n_papers=n_papers, n_prior_included=n_prior_included,
        n_prior_excluded=n_prior_excluded, query_strategy=query_strategy,
        balance_strategy=balance_strategy,
        feature_extraction=feature_extraction,
        mode=mode, data_fp=dataset,
        abstract_only=abstract_only)
    cli_settings.from_file(config_file)

    if log_file is not None:
        with open_logger(log_file) as logger:
            if logger.is_empty():
                logger.add_settings(cli_settings)
            settings = logger.settings
    else:
        settings = cli_settings
        logger = None

    if n_queries is not None:
        settings.n_queries = n_queries
    if n_papers is not None:
        settings.n_papers = n_papers

    if model_param is not None:
        settings.model_param = model_param
    if query_param is not None:
        settings.query_param = query_param
    if balance_param is not None:
        settings.balance_param = balance_param
    if feature_param is not None:
        settings.feature_param = feature_param

    # Check if mode is valid
    if mode in AVAILABLE_REVIEW_CLASSES:
        logging.info(f"Start review in '{mode}' mode.")
    else:
        raise ValueError(f"Unknown mode '{mode}'.")
    logging.debug(settings)

    as_data = ASReviewData.from_file(dataset, extra_dataset=extra_dataset,
                                     abstract_only=settings.abstract_only)
    texts = as_data.texts
    y = as_data.labels

    data_prior_included, data_prior_excluded = as_data.get_priors()
    if len(data_prior_included) != 0:
        if prior_included is None:
            prior_included = []
        prior_included.extend(data_prior_included.tolist())
    if len(data_prior_excluded) != 0:
        if prior_excluded is None:
            prior_excluded = []
        prior_excluded.extend(data_prior_excluded.tolist())

    if as_data.final_labels is not None:
        with open_logger(log_file) as logger:
            logger.set_final_labels(as_data.final_labels)

    train_model = get_model(settings.model, **settings.model_param)
    query_model = get_query_model(settings.query_strategy,
                                  **settings.query_param)
    balance_model = get_balance_model(settings.balance_strategy,
                                      **settings.balance_param)
    feature_model = get_feature_model(settings.feature_extraction,
                                      **settings.feature_param)

    X = feature_model.fit_transform(texts, as_data.title, as_data.abstract)

    if train_model.name.startswith("lstm-"):
        train_model.embedding_matrix = feature_model.get_embedding_matrix(
            texts, embedding_fp)

    # Initialize the review class.
    if mode == "simulate":
        reviewer = ReviewSimulate(
            X, y,
            model=train_model,
            query_model=query_model,
            balance_model=balance_model,
            feature_model=feature_model,
            n_papers=settings.n_papers,
            n_instances=settings.n_instances,
            n_queries=settings.n_queries,
            verbose=verbose,
            prior_included=prior_included,
            prior_excluded=prior_excluded,
            n_prior_included=settings.n_prior_included,
            n_prior_excluded=settings.n_prior_excluded,
            log_file=log_file,
            final_labels=as_data.final_labels,
            data_fp=dataset,
            **kwargs)
    elif mode == "oracle":
        reviewer = ReviewOracle(
            X,
            model=train_model,
            query_model=query_model,
            balance_model=balance_model,
            feature_model=feature_model,
            as_data=as_data,
            n_papers=settings.n_papers,
            n_instances=settings.n_instances,
            n_queries=settings.n_queries,
            verbose=verbose,
            prior_included=prior_included,
            prior_excluded=prior_excluded,
            log_file=log_file,
            data_fp=dataset,
            **kwargs)
    elif mode == "minimal":
        reviewer = MinimalReview(
            X,
            model=train_model,
            query_model=query_model,
            balance_model=balance_model,
            feature_model=feature_model,
            n_papers=settings.n_papers,
            n_instances=settings.n_instances,
            n_queries=settings.n_queries,
            verbose=verbose,
            prior_included=prior_included,
            prior_excluded=prior_excluded,
            log_file=log_file,
            data_fp=dataset,
            **kwargs)
    else:
        raise ValueError("Error finding mode, should never come here...")

    return reviewer
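And a similar hedged sketch for this older, log-file based variant. Again, all keyword arguments come from the signature above; the path, prior record indices, and log-file name are hypothetical, and review() is assumed to be the entry point as before:

reviewer = get_reviewer("my_dataset.csv",       # hypothetical path
                        mode="oracle",
                        model="svm",
                        prior_included=[1, 5],  # hypothetical record indices
                        prior_excluded=[2, 8],
                        log_file="review.json")
reviewer.review()  # assumed BaseReview-style entry point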