示例#1
0
    def get_embedding_matrix(self, texts, embedding_fp):
        """Fit on ``texts`` and build the embedding matrix for the vocabulary.

        Parameters
        ----------
        texts:
            Passed unchanged to ``self.fit_transform``.
        embedding_fp:
            Location of the embedding file. When ``None``, the file named
            ``EMBEDDING_EN["name"]`` inside ``get_data_home()`` is used and
            downloaded first if it does not exist yet.

        Returns
        -------
        The result of ``sample_embedding`` applied to the loaded embedding
        and ``self.word_index``.
        """
        # Return value is discarded; presumably this call is what populates
        # self.word_index, which is read below -- confirm in fit_transform.
        self.fit_transform(texts)

        fall_back_to_default = embedding_fp is None
        if fall_back_to_default:
            default_location = Path(get_data_home(), EMBEDDING_EN["name"])
            embedding_fp = default_location.expanduser()

            file_missing = not embedding_fp.exists()
            if file_missing:
                # Pause before downloading so the warning can be acted on.
                logging.warning(
                    "Warning: will start to download large embedding file "
                    "in 10 seconds.")
                time.sleep(10)
                download_embedding()

        logging.info(
            "Loading embedding matrix. This can take several minutes.")

        vectors = load_embedding(embedding_fp, n_jobs=self.n_jobs)
        return sample_embedding(vectors, self.word_index)
示例#2
0
def download_embedding(url=EMBEDDING_EN['url'],
                       name=EMBEDDING_EN['name'],
                       data_home=None):
    """Download word embedding file.

    Download the gzipped word embedding file, decompress it while it is
    being received and save the result to the file system.

    The data is first written to a temporary ``<name>.part`` file next to
    the destination and only moved into place once the download finished
    successfully, so an interrupted download never leaves a truncated
    file at the final location.

    Parameters
    ----------
    url: str
        The URL of the gzipped word embedding file
    name: str
        The filename of the embedding file.
    data_home: str
        The location of the ASR datasets.
        Default `asreview.utils.get_data_home()`

    Raises
    ------
    Whatever ``urlopen``, ``gzip`` or the file operations raise (e.g.
    ``urllib.error.URLError``, ``OSError``); the partial file is removed
    before the exception propagates.
    """
    # Local import: only this function needs it.
    import shutil

    if data_home is None:
        data_home = get_data_home()

    out_fp = Path(data_home, name)
    # Temporary sibling file, so the final rename stays on one file system.
    tmp_fp = out_fp.with_name(out_fp.name + '.part')

    logging.info(f'Start downloading: {url}')

    try:
        # Stream response -> gunzip -> disk in chunks instead of holding the
        # complete archive in memory; the context managers close the
        # connection and the files even when an error occurs.
        with urlopen(url) as response, \
                gzip.GzipFile(fileobj=response, mode='rb') as decompressed, \
                open(tmp_fp, 'wb') as out_file:
            logging.info(f'Save embedding to {out_fp}')
            shutil.copyfileobj(decompressed, out_file)

        # Publish the file only after it has been written completely.
        tmp_fp.replace(out_fp)
    finally:
        # After a successful replace the temp file is gone; otherwise clean
        # up the partial download.
        if tmp_fp.exists():
            tmp_fp.unlink()
示例#3
0
    def get_embedding_matrix(self):
        """Return the embedding matrix, computing and caching it on first use.

        If ``self.embedding_matrix`` is already set it is returned as is.
        Otherwise ``self.get_Xy()`` is called when ``self.word_index`` is
        still ``None``, the embedding file is located (falling back to
        ``EMBEDDING_EN["name"]`` inside ``get_data_home()`` and downloading
        it when absent), and the sampled matrix is stored on the instance.

        Returns
        -------
        The cached or freshly computed ``self.embedding_matrix``.
        """
        cached = self.embedding_matrix
        if cached is not None:
            return cached

        if self.word_index is None:
            # Presumably populates self.word_index -- confirm in get_Xy.
            self.get_Xy()

        if self.embedding_fp is None:
            default_fp = Path(get_data_home(), EMBEDDING_EN["name"])
            default_fp = default_fp.expanduser()
            self.embedding_fp = default_fp

            if not default_fp.exists():
                # Pause before downloading so the warning can be acted on.
                print("Warning: will start to download large embedding "
                      "file in 10 seconds.")
                time.sleep(10)
                download_embedding()

        # Load the embedding and keep the sampled matrix for later calls.
        logging.info(
            "Loading embedding matrix. This can take several minutes.")
        vectors = load_embedding(self.embedding_fp,
                                 word_index=self.word_index)
        matrix = sample_embedding(vectors, self.word_index)
        self.embedding_matrix = matrix
        return matrix
def get_reviewer(dataset,
                 mode='oracle',
                 model=DEFAULT_MODEL,
                 query_strategy=DEFAULT_QUERY_STRATEGY,
                 balance_strategy=DEFAULT_BALANCE_STRATEGY,
                 n_instances=DEFAULT_N_INSTANCES,
                 n_queries=1,
                 embedding_fp=None,
                 verbose=1,
                 prior_included=None,
                 prior_excluded=None,
                 n_prior_included=DEFAULT_N_PRIOR_INCLUDED,
                 n_prior_excluded=DEFAULT_N_PRIOR_EXCLUDED,
                 config_file=None,
                 src_log_fp=None,
                 **kwargs):
    """Build a reviewer object for the given dataset and settings.

    Parameters
    ----------
    dataset:
        A key of ``DEMO_DATASETS`` (replaced by the mapped value) or the
        location of a dataset. Pickle files must contain a tuple
        ``(X, y, embedding_matrix)`` with an optional fourth element.
    mode: str
        One of ``AVAILABLE_REVIEW_CLASSES``; "simulate", "oracle" and
        "minimal" are handled here. Note that the ``mode`` argument itself
        is used for dispatching, not the value stored in the settings.
    model: str
        "lstm_base", "lstm_pool", "nb", "svm" or "svc". Overridden by the
        settings read from ``src_log_fp`` or ``config_file``.
    query_strategy, balance_strategy, n_instances, n_queries,
    n_prior_included, n_prior_excluded:
        Stored in the ``ASReviewSettings`` object when no ``src_log_fp`` is
        given.
    embedding_fp:
        Embedding file for the LSTM models. When ``None``, the file named
        ``EMBEDDING_EN["name"]`` in ``get_data_home()`` is used and
        downloaded if it does not exist.
    verbose: int
        Verbosity, forwarded to the model helpers and the reviewer.
    prior_included, prior_excluded:
        Forwarded to the reviewer unchanged.
    config_file:
        Passed to ``settings.from_file``; only used when ``src_log_fp`` is
        ``None``.
    src_log_fp:
        When given, settings are taken from ``Logger(log_fp=src_log_fp)``
        and that logger is handed to the reviewer.
    **kwargs:
        Forwarded to the reviewer class.

    Returns
    -------
    A ``ReviewSimulate``, ``ReviewOracle`` or ``MinimalReview`` instance.

    Raises
    ------
    ValueError
        For an unknown mode or model, a malformed pickle object, or when
        "oracle" mode is combined with a pickled dataset.
    """

    # Find the URL of the datasets if the dataset is an example dataset.
    if dataset in DEMO_DATASETS:
        dataset = DEMO_DATASETS[dataset]

    if src_log_fp is not None:
        logger = Logger(log_fp=src_log_fp)
        settings = logger.settings
    else:
        logger = None
        settings = ASReviewSettings(model=model,
                                    n_instances=n_instances,
                                    n_queries=n_queries,
                                    n_prior_included=n_prior_included,
                                    n_prior_excluded=n_prior_excluded,
                                    query_strategy=query_strategy,
                                    balance_strategy=balance_strategy,
                                    mode=mode,
                                    data_fp=dataset)

        settings.from_file(config_file)
    # The settings (log or config file) win over the `model` argument.
    model = settings.model

    if model in ["lstm_base", "lstm_pool"]:
        base_model = "RNN"
    else:
        base_model = "other"

    # Check if mode is valid
    if mode in AVAILABLE_REVIEW_CLASSES:
        if verbose:
            print(f"Start review in '{mode}' mode.")
    else:
        raise ValueError(f"Unknown mode '{mode}'.")
    print(f"Model: '{model}'")

    # if the provided file is a pickle file
    if is_pickle(dataset):
        # A pickled dataset only holds features, labels and the embedding
        # matrix; there is no ASReviewData object in this case.
        as_data = None
        # SECURITY: unpickling executes arbitrary code. Only load pickle
        # files from trusted sources.
        with open(dataset, 'rb') as f:
            data_obj = pickle.load(f)
        if isinstance(data_obj, tuple) and len(data_obj) == 3:
            X, y, embedding_matrix = data_obj
        elif isinstance(data_obj, tuple) and len(data_obj) == 4:
            X, y, embedding_matrix, _ = data_obj
        else:
            raise ValueError("Incorrect pickle object.")
    else:
        as_data = ASReviewData.from_file(dataset)
        _, texts, labels = as_data.get_data()

        # get the model
        if base_model == "RNN":

            if embedding_fp is None:
                embedding_fp = Path(get_data_home(),
                                    EMBEDDING_EN["name"]).expanduser()

                if not embedding_fp.exists():
                    print("Warning: will start to download large "
                          "embedding file in 10 seconds.")
                    time.sleep(10)
                    # NOTE(review): confirm that download_embedding accepts
                    # a `verbose` keyword in this version of the code base.
                    download_embedding(verbose=verbose)

            # create features and labels
            X, word_index = text_to_features(texts)
            y = labels
            embedding = load_embedding(embedding_fp, word_index=word_index)
            embedding_matrix = sample_embedding(embedding, word_index)

        elif model.lower() in ['nb', 'svc', 'svm']:
            from sklearn.pipeline import Pipeline
            from sklearn.feature_extraction.text import TfidfTransformer
            from sklearn.feature_extraction.text import CountVectorizer

            text_clf = Pipeline([('vect', CountVectorizer()),
                                 ('tfidf', TfidfTransformer())])

            X = text_clf.fit_transform(texts)
            y = labels

    settings.fit_kwargs = {}
    settings.query_kwargs = {}

    if base_model == 'RNN':
        if model == "lstm_base":
            model_kwargs = lstm_base_model_defaults(settings, verbose)
            create_lstm_model = create_lstm_base_model
        elif model == "lstm_pool":
            model_kwargs = lstm_pool_model_defaults(settings, verbose)
            create_lstm_model = create_lstm_pool_model
        else:
            raise ValueError(f"Unknown model {model}")

        settings.fit_kwargs = lstm_fit_defaults(settings, verbose)
        settings.query_kwargs['verbose'] = verbose
        # create the model
        model = KerasClassifier(create_lstm_model(
            embedding_matrix=embedding_matrix, **model_kwargs),
                                verbose=verbose)

    elif model.lower() in ['nb']:
        from asreview.models import create_nb_model

        model = create_nb_model()

    elif model.lower() in ['svm', 'svc']:
        from asreview.models import create_svc_model

        model = create_svc_model()
    else:
        raise ValueError('Model not found.')

    # Pick query strategy
    query_fn, query_str = get_query_strategy(settings)
    if verbose:
        print(f"Query strategy: {query_str}")

    train_data_fn, train_method = get_balance_strategy(settings)
    if verbose:
        print(f"Using {train_method} method to obtain training data.")

    # Initialize the review class.
    if mode == "simulate":
        reviewer = ReviewSimulate(X,
                                  y,
                                  model=model,
                                  query_strategy=query_fn,
                                  train_data_fn=train_data_fn,
                                  n_instances=settings.n_instances,
                                  n_queries=settings.n_queries,
                                  verbose=verbose,
                                  prior_included=prior_included,
                                  prior_excluded=prior_excluded,
                                  n_prior_included=settings.n_prior_included,
                                  n_prior_excluded=settings.n_prior_excluded,
                                  fit_kwargs=settings.fit_kwargs,
                                  balance_kwargs=settings.balance_kwargs,
                                  query_kwargs=settings.query_kwargs,
                                  logger=logger,
                                  **kwargs)

    elif mode == "oracle":
        # Previously this combination failed with a NameError because
        # `as_data` was never bound for pickled datasets.
        if as_data is None:
            raise ValueError(
                "Mode 'oracle' cannot be used with a pickled dataset: "
                "no ASReviewData is available to pass to ReviewOracle.")
        reviewer = ReviewOracle(X,
                                model=model,
                                query_strategy=query_fn,
                                as_data=as_data,
                                train_data_fn=train_data_fn,
                                n_instances=settings.n_instances,
                                n_queries=settings.n_queries,
                                verbose=verbose,
                                prior_included=prior_included,
                                prior_excluded=prior_excluded,
                                fit_kwargs=settings.fit_kwargs,
                                balance_kwargs=settings.balance_kwargs,
                                query_kwargs=settings.query_kwargs,
                                logger=logger,
                                **kwargs)
    elif mode == "minimal":
        reviewer = MinimalReview(X,
                                 model=model,
                                 query_strategy=query_fn,
                                 train_data_fn=train_data_fn,
                                 n_instances=settings.n_instances,
                                 n_queries=settings.n_queries,
                                 verbose=verbose,
                                 prior_included=prior_included,
                                 prior_excluded=prior_excluded,
                                 fit_kwargs=settings.fit_kwargs,
                                 balance_kwargs=settings.balance_kwargs,
                                 query_kwargs=settings.query_kwargs,
                                 logger=logger,
                                 **kwargs)
    else:
        raise ValueError("Error finding mode, should never come here...")

    # Record the effective settings on the reviewer's logger.
    reviewer._logger.add_settings(settings)

    return reviewer