Пример #1
0
    def settings(self):
        """Return an ``ASReviewSettings`` snapshot of this review.

        The prior-knowledge counts are the lengths of ``self.prior_included``
        and ``self.prior_excluded``; an attribute that is ``None`` counts as
        zero records.
        """
        n_included = (0 if self.prior_included is None
                      else len(self.prior_included))
        n_excluded = (0 if self.prior_excluded is None
                      else len(self.prior_excluded))

        return ASReviewSettings(
            mode=self.name,
            model=self.model.name,
            query_strategy=self.query_model.name,
            balance_strategy=self.balance_model.name,
            feature_extraction=self.feature_model.name,
            n_instances=self.n_instances,
            n_queries=self.n_queries,
            n_papers=self.n_papers,
            n_prior_included=n_included,
            n_prior_excluded=n_excluded,
            model_param=self.model.param,
            query_param=self.query_model.param,
            balance_param=self.balance_model.param,
            feature_param=self.feature_model.param,
            data_fp=self.data_fp,
        )
Пример #2
0
def check_write_logger(tmpdir, log_file):
    """Write a small simulated review to a logger, checking it after each step.

    Six records are used: every even record is labelled 1, the others 0.
    The first two records are logged as "initial", the remaining even ones
    as "max" and the remaining odd ones as "random".  Probabilities are
    logged for even records from index 2 onwards.

    Parameters
    ----------
    tmpdir:
        Directory in which the log file is created.
    log_file:
        File name of the log inside ``tmpdir``, or ``None`` to call
        ``open_logger(None)``.
    """
    log_fp = os.path.join(tmpdir, log_file) if log_file is not None else None

    settings = ASReviewSettings(mode="simulate", model="nb",
                                query_strategy="rand_max",
                                balance_strategy="simple")

    n_records = 6
    n_half = n_records // 2
    # ``np.int`` was only a deprecated alias of the builtin ``int`` and was
    # removed in NumPy 1.24, so ``int`` is used directly.
    # NOTE(review): NaN cannot be stored in an integer array; numpy casts it
    # unsafely to an arbitrary integer here.  Kept as-is so the data handed
    # to the logger is unchanged — confirm whether a float array was meant.
    start_labels = np.full(n_records, np.nan, dtype=int)
    labels = np.zeros(n_records, dtype=int)
    labels[::2] = np.ones(n_half, dtype=int)
    methods = np.full(n_records, "initial")
    methods[2:] = np.full(n_records - 2, "random")
    methods[2::2] = np.full((n_records - 2) // 2, "max")

    with open_logger(log_fp) as logger:
        logger.add_settings(settings)
        logger.set_labels(start_labels)
        current_labels = np.copy(start_labels)
        for i in range(n_records):
            query_i = i // 2
            # Probabilities are only produced for even records after the
            # first two.
            proba = (np.random.rand(n_records)
                     if i >= 2 and i % 2 == 0 else None)
            logger.add_classification([i], [labels[i]], [methods[i]], query_i)
            if proba is not None:
                logger.add_proba(np.arange(i + 1, n_records), np.arange(i + 1),
                                 proba, query_i)
            current_labels[i] = labels[i]
            logger.set_labels(current_labels)
            check_logger(logger, i, query_i, labels, methods, proba)
Пример #3
0
 def restore(self, fp):
     """Load the log dictionary and settings from the JSON file ``fp``.

     When ``fp`` does not exist, ``self.initialize_structure()`` is called
     instead.

     Raises
     ------
     ValueError
         If the ``version`` stored in the file differs from ``self.version``.
     """
     # Keep the ``try`` limited to opening/reading the file, so that only a
     # missing log file (and nothing raised later) triggers re-initialization.
     try:
         with open(fp, "r") as f:
             self._log_dict = OrderedDict(json.load(f))
     except FileNotFoundError:
         self.initialize_structure()
         return

     log_version = self._log_dict["version"]
     if log_version != self.version:
         raise ValueError(
             f"Log cannot be read: logger version {self.version}, "
             f"logfile version {log_version}.")
     self.settings = ASReviewSettings(**self._log_dict["settings"])
Пример #4
0
 def restore(self, fp):
     """Load the state dictionary and settings from the JSON file ``fp``.

     When ``fp`` does not exist, ``self.initialize_structure()`` is called
     instead.  JSON object keys are always strings, so the keys of the
     optional ``current_queries`` mapping are converted back to ``int``.

     Raises
     ------
     ValueError
         If the ``version`` stored in the file differs from ``self.version``.
     """
     # Keep the ``try`` limited to opening/reading the file, so that only a
     # missing state file (and nothing raised later) triggers
     # re-initialization.
     try:
         with open(fp, "r") as f:
             self._state_dict = OrderedDict(json.load(f))
     except FileNotFoundError:
         self.initialize_structure()
         return

     state_version = self._state_dict["version"]
     if state_version != self.version:
         raise ValueError(
             f"State cannot be read: state version {self.version}, "
             f"state file version {state_version}.")
     self.settings = ASReviewSettings(**self._state_dict["settings"])

     # ``current_queries`` is optional in the state file.
     if "current_queries" in self._state_dict:
         self._state_dict["current_queries"] = {
             int(key): val
             for key, val in self._state_dict["current_queries"].items()}
Пример #5
0
    def restore(self, fp):
        """Open (or create) the HDF5 log file ``fp`` and load its settings.

        The file is opened read-only when ``self.read_only`` is set and in
        append mode otherwise; missing parent directories are created.  If
        the file has no ``version`` or ``settings`` attribute,
        ``self.initialize_structure()`` is called.  Settings are only loaded
        when the stored settings contain a ``mode`` entry.

        Raises
        ------
        ValueError
            If the stored version differs from ``self.version``.  The file
            handle is closed before raising.
        """
        mode = 'r' if self.read_only else 'a'

        Path(fp).parent.mkdir(parents=True, exist_ok=True)
        self.f = h5py.File(fp, mode)
        try:
            log_version = self.f.attrs['version']
            # h5py < 3 (and fixed-length strings) give bytes, while h5py >= 3
            # returns variable-length string attributes as ``str``.
            if isinstance(log_version, bytes):
                log_version = log_version.decode("ascii")
            if log_version != self.version:
                # Don't leak the open file handle when refusing the file.
                self.f.close()
                raise ValueError(
                    f"Log cannot be read: logger version {self.version}, "
                    f"logfile version {log_version}.")
            settings_dict = json.loads(self.f.attrs['settings'])
            if "mode" in settings_dict:
                self.settings = ASReviewSettings(**settings_dict)
        except KeyError:
            self.initialize_structure()
Пример #6
0
 def settings(self):
     """Return an ``ASReviewSettings`` describing this review.

     ``n_prior_included`` and ``n_prior_excluded`` are only passed on when
     the instance has the corresponding attribute; otherwise they are
     omitted from the call.
     """
     optional_names = ('n_prior_included', 'n_prior_excluded')
     extra_kwargs = {name: getattr(self, name)
                     for name in optional_names if hasattr(self, name)}
     return ASReviewSettings(
         mode=self.name,
         model=self.model.name,
         query_strategy=self.query_model.name,
         balance_strategy=self.balance_model.name,
         feature_extraction=self.feature_model.name,
         n_instances=self.n_instances,
         n_queries=self.n_queries,
         n_papers=self.n_papers,
         model_param=self.model.param,
         query_param=self.query_model.param,
         balance_param=self.balance_model.param,
         feature_param=self.feature_model.param,
         data_name=self.as_data.data_name,
         **extra_kwargs)
Пример #7
0
def get_reviewer(dataset,
                 mode="simulate",
                 model=DEFAULT_MODEL,
                 query_strategy=DEFAULT_QUERY_STRATEGY,
                 balance_strategy=DEFAULT_BALANCE_STRATEGY,
                 feature_extraction=DEFAULT_FEATURE_EXTRACTION,
                 n_instances=DEFAULT_N_INSTANCES,
                 n_papers=None,
                 n_queries=None,
                 embedding_fp=None,
                 verbose=0,
                 prior_idx=None,
                 prior_record_id=None,
                 n_prior_included=DEFAULT_N_PRIOR_INCLUDED,
                 n_prior_excluded=DEFAULT_N_PRIOR_EXCLUDED,
                 config_file=None,
                 state_file=None,
                 model_param=None,
                 query_param=None,
                 balance_param=None,
                 feature_param=None,
                 seed=None,
                 included_dataset=None,
                 excluded_dataset=None,
                 prior_dataset=None,
                 new=False,
                 **kwargs):
    """Get a review object from arguments.

    See __main__.py for a description of the arguments.

    ``included_dataset``, ``excluded_dataset`` and ``prior_dataset`` default
    to ``None``, which is treated as an empty list.  A fresh list is created
    on every call, so no state can leak between calls through a shared
    mutable default.

    Raises
    ------
    ValueError
        If the combined data has no records, if ``mode`` is unknown, or if
        both ``prior_idx`` and ``prior_record_id`` are non-empty.
    """
    # Avoid mutable default arguments: build fresh lists on every call.
    if included_dataset is None:
        included_dataset = []
    if excluded_dataset is None:
        excluded_dataset = []
    if prior_dataset is None:
        prior_dataset = []

    as_data = create_as_data(dataset,
                             included_dataset,
                             excluded_dataset,
                             prior_dataset,
                             new=new)

    if len(as_data) == 0:
        raise ValueError("Supply at least one dataset"
                         " with at least one record.")

    cli_settings = ASReviewSettings(model=model,
                                    n_instances=n_instances,
                                    n_queries=n_queries,
                                    n_papers=n_papers,
                                    n_prior_included=n_prior_included,
                                    n_prior_excluded=n_prior_excluded,
                                    query_strategy=query_strategy,
                                    balance_strategy=balance_strategy,
                                    feature_extraction=feature_extraction,
                                    mode=mode,
                                    data_fp=None)
    cli_settings.from_file(config_file)

    # A non-empty state file keeps its stored settings; an empty one is
    # initialized with the settings derived from the arguments.
    if state_file is not None:
        with open_state(state_file) as state:
            if state.is_empty():
                state.settings = cli_settings
            settings = state.settings
    else:
        settings = cli_settings

    # Explicitly passed values override stored/config settings.
    if n_queries is not None:
        settings.n_queries = n_queries
    if n_papers is not None:
        settings.n_papers = n_papers

    if model_param is not None:
        settings.model_param = model_param
    if query_param is not None:
        settings.query_param = query_param
    if balance_param is not None:
        settings.balance_param = balance_param
    if feature_param is not None:
        settings.feature_param = feature_param

    # Check if mode is valid
    if mode in AVAILABLE_REVIEW_CLASSES:
        logging.info(f"Start review in '{mode}' mode.")
    else:
        raise ValueError(f"Unknown mode '{mode}'.")
    logging.debug(settings)

    # Initialize models; all of them share one random state.
    random_state = get_random_state(seed)
    train_model = get_classifier(settings.model,
                                 **settings.model_param,
                                 random_state=random_state)
    query_model = get_query_model(settings.query_strategy,
                                  **settings.query_param,
                                  random_state=random_state)
    balance_model = get_balance_model(settings.balance_strategy,
                                      **settings.balance_param,
                                      random_state=random_state)
    feature_model = get_feature_model(settings.feature_extraction,
                                      **settings.feature_param,
                                      random_state=random_state)

    # LSTM models need embedding matrices.
    if train_model.name.startswith("lstm-"):
        texts = as_data.texts
        train_model.embedding_matrix = feature_model.get_embedding_matrix(
            texts, embedding_fp)

    # Prior knowledge: accept either positional indices or record ids.
    if prior_idx is not None and prior_record_id is not None and \
            len(prior_idx) > 0 and len(prior_record_id) > 0:
        raise ValueError(
            "Not possible to provide both prior_idx and prior_record_id")
    if prior_record_id is not None and len(prior_record_id) > 0:
        prior_idx = convert_id_to_idx(as_data, prior_record_id)

    # Initialize the review class.
    if mode == "simulate":
        reviewer = ReviewSimulate(as_data,
                                  model=train_model,
                                  query_model=query_model,
                                  balance_model=balance_model,
                                  feature_model=feature_model,
                                  n_papers=settings.n_papers,
                                  n_instances=settings.n_instances,
                                  n_queries=settings.n_queries,
                                  prior_idx=prior_idx,
                                  n_prior_included=settings.n_prior_included,
                                  n_prior_excluded=settings.n_prior_excluded,
                                  state_file=state_file,
                                  **kwargs)
    elif mode == "minimal":
        reviewer = MinimalReview(as_data,
                                 model=train_model,
                                 query_model=query_model,
                                 balance_model=balance_model,
                                 feature_model=feature_model,
                                 n_papers=settings.n_papers,
                                 n_instances=settings.n_instances,
                                 n_queries=settings.n_queries,
                                 state_file=state_file,
                                 **kwargs)
    else:
        raise ValueError("Error finding mode, should never come here...")

    return reviewer
Пример #8
0
def get_reviewer(dataset,
                 mode='oracle',
                 model=DEFAULT_MODEL,
                 query_strategy=DEFAULT_QUERY_STRATEGY,
                 balance_strategy=DEFAULT_BALANCE_STRATEGY,
                 n_instances=DEFAULT_N_INSTANCES,
                 n_papers=None,
                 n_queries=None,
                 embedding_fp=None,
                 verbose=0,
                 prior_included=None,
                 prior_excluded=None,
                 n_prior_included=DEFAULT_N_PRIOR_INCLUDED,
                 n_prior_excluded=DEFAULT_N_PRIOR_EXCLUDED,
                 config_file=None,
                 log_file=None,
                 model_param=None,
                 query_param=None,
                 balance_param=None,
                 abstract_only=False,
                 **kwargs):
    """Create a reviewer (simulate, oracle or minimal) from CLI-style arguments.

    See __main__.py for a description of the arguments.

    Settings come from ``log_file`` when it already holds settings, and
    otherwise from the arguments plus ``config_file``.  Explicitly passed
    ``n_queries``, ``n_papers`` and ``*_param`` values always override them.

    Raises
    ------
    ValueError
        If ``mode`` is not an available review mode.
    """
    # Example datasets are given by name: translate them to their URL.
    if dataset in DEMO_DATASETS:
        dataset = DEMO_DATASETS[dataset]

    cli_settings = ASReviewSettings(
        model=model, n_instances=n_instances, n_queries=n_queries,
        n_papers=n_papers, n_prior_included=n_prior_included,
        n_prior_excluded=n_prior_excluded, query_strategy=query_strategy,
        balance_strategy=balance_strategy, mode=mode, data_fp=dataset,
        abstract_only=abstract_only)
    cli_settings.from_file(config_file)

    if log_file is None:
        settings = cli_settings
    else:
        # A log file that already holds settings keeps them; an empty one
        # is initialized with the settings derived from the arguments.
        with open_logger(log_file) as logger:
            if logger.is_empty():
                logger.add_settings(cli_settings)
            settings = logger.settings

    # Explicit arguments take precedence over stored/config settings.
    for attr_name, value in (("n_queries", n_queries),
                             ("n_papers", n_papers),
                             ("model_param", model_param),
                             ("query_param", query_param),
                             ("balance_param", balance_param)):
        if value is not None:
            setattr(settings, attr_name, value)

    if mode not in AVAILABLE_REVIEW_CLASSES:
        raise ValueError(f"Unknown mode '{mode}'.")
    logging.info(f"Start review in '{mode}' mode.")
    logging.debug(settings)

    as_data = ASReviewData.from_file(dataset,
                                     abstract_only=settings.abstract_only)
    _, texts, labels = as_data.get_data()

    if as_data.final_labels is not None:
        with open_logger(log_file) as logger:
            logger.set_final_labels(as_data.final_labels)

    model_inst = get_model_class(settings.model)(
        param=settings.model_param, embedding_fp=embedding_fp)
    X, y = model_inst.get_Xy(texts, labels)

    model_fn = model_inst.model()
    settings.fit_kwargs = model_inst.fit_kwargs()

    settings.query_kwargs = {}
    query_fn, query_str = get_query_with_settings(settings)
    logging.info(f"Query strategy: {query_str}")

    train_data_fn, train_method = get_balance_with_settings(settings)
    logging.info(f"Using {train_method} method to obtain training data.")

    if mode not in ("simulate", "oracle", "minimal"):
        raise ValueError("Error finding mode, should never come here...")

    # Keyword arguments shared by every review class.
    common_kwargs = dict(
        model=model_fn,
        query_strategy=query_fn,
        train_data_fn=train_data_fn,
        n_papers=settings.n_papers,
        n_instances=settings.n_instances,
        n_queries=settings.n_queries,
        verbose=verbose,
        prior_included=prior_included,
        prior_excluded=prior_excluded,
        fit_kwargs=settings.fit_kwargs,
        balance_kwargs=settings.balance_kwargs,
        query_kwargs=settings.query_kwargs,
        log_file=log_file,
    )

    if mode == "simulate":
        return ReviewSimulate(X, y,
                              n_prior_included=settings.n_prior_included,
                              n_prior_excluded=settings.n_prior_excluded,
                              final_labels=as_data.final_labels,
                              **common_kwargs, **kwargs)
    if mode == "oracle":
        return ReviewOracle(X, as_data=as_data, **common_kwargs, **kwargs)
    return MinimalReview(X, **common_kwargs, **kwargs)
def get_reviewer(dataset,
                 mode='oracle',
                 model=DEFAULT_MODEL,
                 query_strategy=DEFAULT_QUERY_STRATEGY,
                 balance_strategy=DEFAULT_BALANCE_STRATEGY,
                 feature_extraction=DEFAULT_FEATURE_EXTRACTION,
                 n_instances=DEFAULT_N_INSTANCES,
                 n_papers=None,
                 n_queries=None,
                 embedding_fp=None,
                 verbose=0,
                 prior_included=None,
                 prior_excluded=None,
                 n_prior_included=DEFAULT_N_PRIOR_INCLUDED,
                 n_prior_excluded=DEFAULT_N_PRIOR_EXCLUDED,
                 config_file=None,
                 log_file=None,
                 model_param=None,
                 query_param=None,
                 balance_param=None,
                 feature_param=None,
                 abstract_only=False,
                 extra_dataset=None,
                 **kwargs
                 ):
    """Get a review object from arguments.

    See __main__.py for a description of the arguments.

    Prior inclusions/exclusions found in the dataset itself (from
    ``as_data.get_priors()``) are appended to ``prior_included`` /
    ``prior_excluded``.  The caller's lists are copied first and are never
    modified.  ``extra_dataset`` defaults to ``None``, which is treated as an
    empty list.

    Raises
    ------
    ValueError
        If ``mode`` is unknown.
    """
    # Avoid a shared mutable default argument.
    if extra_dataset is None:
        extra_dataset = []

    # Find the URL of the datasets if the dataset is an example dataset.
    if dataset in DEMO_DATASETS.keys():
        dataset = DEMO_DATASETS[dataset]

    cli_settings = ASReviewSettings(
        model=model, n_instances=n_instances, n_queries=n_queries,
        n_papers=n_papers, n_prior_included=n_prior_included,
        n_prior_excluded=n_prior_excluded, query_strategy=query_strategy,
        balance_strategy=balance_strategy,
        feature_extraction=feature_extraction,
        mode=mode, data_fp=dataset,
        abstract_only=abstract_only)
    cli_settings.from_file(config_file)

    # A log file that already holds settings keeps them; an empty one is
    # initialized with the settings derived from the arguments.
    if log_file is not None:
        with open_logger(log_file) as logger:
            if logger.is_empty():
                logger.add_settings(cli_settings)
            settings = logger.settings
    else:
        settings = cli_settings

    # Explicitly passed values override stored/config settings.
    if n_queries is not None:
        settings.n_queries = n_queries
    if n_papers is not None:
        settings.n_papers = n_papers

    if model_param is not None:
        settings.model_param = model_param
    if query_param is not None:
        settings.query_param = query_param
    if balance_param is not None:
        settings.balance_param = balance_param
    if feature_param is not None:
        settings.feature_param = feature_param

    # Check if mode is valid
    if mode in AVAILABLE_REVIEW_CLASSES:
        logging.info(f"Start review in '{mode}' mode.")
    else:
        raise ValueError(f"Unknown mode '{mode}'.")
    logging.debug(settings)

    as_data = ASReviewData.from_file(dataset, extra_dataset=extra_dataset,
                                     abstract_only=settings.abstract_only)
    texts = as_data.texts
    y = as_data.labels

    # Merge priors stored in the dataset into copies of the caller's lists.
    data_prior_included, data_prior_excluded = as_data.get_priors()
    if len(data_prior_included) != 0:
        prior_included = ([] if prior_included is None
                          else list(prior_included))
        prior_included.extend(data_prior_included.tolist())
    if len(data_prior_excluded) != 0:
        prior_excluded = ([] if prior_excluded is None
                          else list(prior_excluded))
        prior_excluded.extend(data_prior_excluded.tolist())

    if as_data.final_labels is not None:
        with open_logger(log_file) as logger:
            logger.set_final_labels(as_data.final_labels)

    train_model = get_model(settings.model, **settings.model_param)
    query_model = get_query_model(settings.query_strategy,
                                  **settings.query_param)
    balance_model = get_balance_model(settings.balance_strategy,
                                      **settings.balance_param)
    feature_model = get_feature_model(settings.feature_extraction,
                                      **settings.feature_param)

    X = feature_model.fit_transform(texts, as_data.title, as_data.abstract)

    # LSTM models need an embedding matrix.
    if train_model.name.startswith("lstm-"):
        train_model.embedding_matrix = feature_model.get_embedding_matrix(
            texts, embedding_fp)

    # Initialize the review class.
    if mode == "simulate":
        reviewer = ReviewSimulate(
            X, y,
            model=train_model,
            query_model=query_model,
            balance_model=balance_model,
            feature_model=feature_model,
            n_papers=settings.n_papers,
            n_instances=settings.n_instances,
            n_queries=settings.n_queries,
            verbose=verbose,
            prior_included=prior_included,
            prior_excluded=prior_excluded,
            n_prior_included=settings.n_prior_included,
            n_prior_excluded=settings.n_prior_excluded,
            log_file=log_file,
            final_labels=as_data.final_labels,
            data_fp=dataset,
            **kwargs)
    elif mode == "oracle":
        reviewer = ReviewOracle(
            X,
            model=train_model,
            query_model=query_model,
            balance_model=balance_model,
            feature_model=feature_model,
            as_data=as_data,
            n_papers=settings.n_papers,
            n_instances=settings.n_instances,
            n_queries=settings.n_queries,
            verbose=verbose,
            prior_included=prior_included,
            prior_excluded=prior_excluded,
            log_file=log_file,
            data_fp=dataset,
            **kwargs)
    elif mode == "minimal":
        # Pass the constructed classifier, as the other modes do (previously
        # the raw ``model`` argument string was passed here).
        reviewer = MinimalReview(
            X,
            model=train_model,
            query_model=query_model,
            balance_model=balance_model,
            feature_model=feature_model,
            n_papers=settings.n_papers,
            n_instances=settings.n_instances,
            n_queries=settings.n_queries,
            verbose=verbose,
            prior_included=prior_included,
            prior_excluded=prior_excluded,
            log_file=log_file,
            data_fp=dataset,
            **kwargs)
    else:
        raise ValueError("Error finding mode, should never come here...")

    return reviewer
Пример #10
0
 def settings(self):
     """Return the settings stored in the HDF5 ``settings`` attribute.

     The attribute is parsed as JSON and used as keyword arguments for
     ``ASReviewSettings``.  Returns ``None`` when no settings are stored.
     """
     raw_settings = self.f.attrs.get('settings', None)
     if raw_settings is not None:
         return ASReviewSettings(**json.loads(raw_settings))
     return None
 def restore(self, fp):
     """Load the JSON log at ``fp`` and move its ``settings`` entry out.

     The whole file becomes ``self._log_dict``.  Its ``settings`` entry is
     removed from that dict and turned into ``self.settings``.
     """
     with open(fp, "r") as log_file:
         loaded = json.load(log_file)
     self._log_dict = OrderedDict(loaded)
     settings_kwargs = self._log_dict.pop("settings")
     self.settings = ASReviewSettings(**settings_kwargs)
def get_reviewer(dataset,
                 mode='oracle',
                 model=DEFAULT_MODEL,
                 query_strategy=DEFAULT_QUERY_STRATEGY,
                 balance_strategy=DEFAULT_BALANCE_STRATEGY,
                 n_instances=DEFAULT_N_INSTANCES,
                 n_queries=1,
                 embedding_fp=None,
                 verbose=1,
                 prior_included=None,
                 prior_excluded=None,
                 n_prior_included=DEFAULT_N_PRIOR_INCLUDED,
                 n_prior_excluded=DEFAULT_N_PRIOR_EXCLUDED,
                 config_file=None,
                 src_log_fp=None,
                 **kwargs):
    """Build a reviewer object from arguments.

    Settings come from the log at ``src_log_fp`` when given, and otherwise
    from the arguments plus ``config_file``.  ``dataset`` may be a pickle
    file holding ``(X, y, embedding_matrix[, extra])`` or a dataset file
    readable by ``ASReviewData``.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - ``mode`` or the model is unknown.
        - The pickle object has an unexpected shape.
        - ``mode`` is ``"oracle"`` and ``dataset`` is a pickle file, because
          oracle mode needs the parsed ``ASReviewData``.
    """
    # Find the URL of the datasets if the dataset is an example dataset.
    if dataset in DEMO_DATASETS.keys():
        dataset = DEMO_DATASETS[dataset]

    if src_log_fp is not None:
        logger = Logger(log_fp=src_log_fp)
        settings = logger.settings
    else:
        logger = None
        settings = ASReviewSettings(model=model,
                                    n_instances=n_instances,
                                    n_queries=n_queries,
                                    n_prior_included=n_prior_included,
                                    n_prior_excluded=n_prior_excluded,
                                    query_strategy=query_strategy,
                                    balance_strategy=balance_strategy,
                                    mode=mode,
                                    data_fp=dataset)

        settings.from_file(config_file)
    model = settings.model

    if model in ["lstm_base", "lstm_pool"]:
        base_model = "RNN"
    else:
        base_model = "other"

    # Check if mode is valid
    if mode in AVAILABLE_REVIEW_CLASSES:
        if verbose:
            print(f"Start review in '{mode}' mode.")
    else:
        raise ValueError(f"Unknown mode '{mode}'.")
    print(f"Model: '{model}'")

    # if the provided file is a pickle file
    if is_pickle(dataset):
        # ReviewOracle is constructed with ``as_data``, which only exists when
        # the dataset is read with ASReviewData; fail early and clearly
        # instead of hitting an unbound local later on.
        if mode == "oracle":
            raise ValueError(
                "Oracle mode requires a dataset file, not a pickle file.")
        # SECURITY: pickle.load can execute arbitrary code; only load pickle
        # files from a trusted source.
        with open(dataset, 'rb') as f:
            data_obj = pickle.load(f)
        if isinstance(data_obj, tuple) and len(data_obj) == 3:
            X, y, embedding_matrix = data_obj
        elif isinstance(data_obj, tuple) and len(data_obj) == 4:
            X, y, embedding_matrix, _ = data_obj
        else:
            raise ValueError("Incorrect pickle object.")
    else:
        as_data = ASReviewData.from_file(dataset)
        _, texts, labels = as_data.get_data()

        # get the model
        if base_model == "RNN":

            if embedding_fp is None:
                embedding_fp = Path(get_data_home(),
                                    EMBEDDING_EN["name"]).expanduser()

                if not embedding_fp.exists():
                    print("Warning: will start to download large "
                          "embedding file in 10 seconds.")
                    time.sleep(10)
                    download_embedding(verbose=verbose)

            # create features and labels
            X, word_index = text_to_features(texts)
            y = labels
            embedding = load_embedding(embedding_fp, word_index=word_index)
            embedding_matrix = sample_embedding(embedding, word_index)

        elif model.lower() in ['nb', 'svc', 'svm']:
            from sklearn.pipeline import Pipeline
            from sklearn.feature_extraction.text import TfidfTransformer
            from sklearn.feature_extraction.text import CountVectorizer

            text_clf = Pipeline([('vect', CountVectorizer()),
                                 ('tfidf', TfidfTransformer())])

            X = text_clf.fit_transform(texts)
            y = labels

    settings.fit_kwargs = {}
    settings.query_kwargs = {}

    if base_model == 'RNN':
        if model == "lstm_base":
            model_kwargs = lstm_base_model_defaults(settings, verbose)
            create_lstm_model = create_lstm_base_model
        elif model == "lstm_pool":
            model_kwargs = lstm_pool_model_defaults(settings, verbose)
            create_lstm_model = create_lstm_pool_model
        else:
            raise ValueError(f"Unknown model {model}")

        settings.fit_kwargs = lstm_fit_defaults(settings, verbose)
        settings.query_kwargs['verbose'] = verbose
        # create the model
        model = KerasClassifier(
            create_lstm_model(embedding_matrix=embedding_matrix,
                              **model_kwargs),
            verbose=verbose)

    elif model.lower() in ['nb']:
        from asreview.models import create_nb_model

        model = create_nb_model()

    elif model.lower() in ['svm', 'svc']:
        from asreview.models import create_svc_model

        model = create_svc_model()
    else:
        raise ValueError('Model not found.')

    # Pick query strategy
    query_fn, query_str = get_query_strategy(settings)
    if verbose:
        print(f"Query strategy: {query_str}")

    train_data_fn, train_method = get_balance_strategy(settings)
    if verbose:
        print(f"Using {train_method} method to obtain training data.")

    # Initialize the review class.
    if mode == "simulate":
        reviewer = ReviewSimulate(X,
                                  y,
                                  model=model,
                                  query_strategy=query_fn,
                                  train_data_fn=train_data_fn,
                                  n_instances=settings.n_instances,
                                  n_queries=settings.n_queries,
                                  verbose=verbose,
                                  prior_included=prior_included,
                                  prior_excluded=prior_excluded,
                                  n_prior_included=settings.n_prior_included,
                                  n_prior_excluded=settings.n_prior_excluded,
                                  fit_kwargs=settings.fit_kwargs,
                                  balance_kwargs=settings.balance_kwargs,
                                  query_kwargs=settings.query_kwargs,
                                  logger=logger,
                                  **kwargs)

    elif mode == "oracle":
        reviewer = ReviewOracle(X,
                                model=model,
                                query_strategy=query_fn,
                                as_data=as_data,
                                train_data_fn=train_data_fn,
                                n_instances=settings.n_instances,
                                n_queries=settings.n_queries,
                                verbose=verbose,
                                prior_included=prior_included,
                                prior_excluded=prior_excluded,
                                fit_kwargs=settings.fit_kwargs,
                                balance_kwargs=settings.balance_kwargs,
                                query_kwargs=settings.query_kwargs,
                                logger=logger,
                                **kwargs)
    elif mode == "minimal":
        reviewer = MinimalReview(X,
                                 model=model,
                                 query_strategy=query_fn,
                                 train_data_fn=train_data_fn,
                                 n_instances=settings.n_instances,
                                 n_queries=settings.n_queries,
                                 verbose=verbose,
                                 prior_included=prior_included,
                                 prior_excluded=prior_excluded,
                                 fit_kwargs=settings.fit_kwargs,
                                 balance_kwargs=settings.balance_kwargs,
                                 query_kwargs=settings.query_kwargs,
                                 logger=logger,
                                 **kwargs)
    else:
        raise ValueError("Error finding mode, should never come here...")

    reviewer._logger.add_settings(settings)

    return reviewer
Пример #13
0
 def settings(self):
     """Return the stored settings as ``ASReviewSettings``, or ``None``.

     ``None`` is returned when the state dict has no ``settings`` entry or
     the entry is ``None``.
     """
     stored = self._state_dict.get("settings")
     return None if stored is None else ASReviewSettings(**stored)