def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('document_paths',
                        type=argparse_utils.existing_file_path,
                        nargs='+')

    parser.add_argument('--encoding', type=str, default='latin1')

    parser.add_argument('--vocabulary_min_count', type=int, default=2)
    parser.add_argument('--vocabulary_min_word_size', type=int, default=2)
    parser.add_argument('--vocabulary_max_size', type=int, default=65536)

    parser.add_argument('--include_stopwords',
                        action='store_true',
                        default=False)

    parser.add_argument('--num_workers',
                        type=argparse_utils.positive_int,
                        default=8)

    parser.add_argument('--dictionary_out', required=True)
    parser.add_argument('--humanreadable_dictionary_out', default=None)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    ignore_words = set()

    if not args.include_stopwords:
        ignore_words.update(nltk_utils.get_stopwords())

    logging.info('Constructing vocabulary.')

    vocabulary = io_utils.construct_vocabulary(
        args.document_paths,
        num_workers=args.num_workers,
        min_count=args.vocabulary_min_count,
        min_word_size=args.vocabulary_min_word_size,
        max_vocab_size=args.vocabulary_max_size,
        ignore_tokens=ignore_words,
        encoding=args.encoding)

    logging.info('Pickling vocabulary.')

    with open(args.dictionary_out, 'wb') as f_out:
        pickle.dump(vocabulary, f_out, pickle.HIGHEST_PROTOCOL)

    if args.humanreadable_dictionary_out is not None:
        with open(args.humanreadable_dictionary_out,
                  'w',
                  encoding=args.encoding) as f_out:
            for token, token_meta in vocabulary.items():
                f_out.write('{} {}\n'.format(token, token_meta.id))
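
A minimal sketch (not part of the original script) of reading the pickled vocabulary back in; it assumes only what the dump loop above shows, namely a mapping from token to a token_meta object with an `id` attribute, and the path name is hypothetical.

import pickle

def load_vocabulary(dictionary_path):
    with open(dictionary_path, 'rb') as f_in:
        return pickle.load(f_in)

# vocabulary = load_vocabulary('vocabulary.pkl')
# for token, token_meta in vocabulary.items():
#     print(token, token_meta.id)
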
Example #2
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--index',
                        type=argparse_utils.existing_directory_path,
                        required=True)
    parser.add_argument('--model',
                        type=argparse_utils.existing_file_path,
                        required=True)
    parser.add_argument('--vocabulary_list',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    args.index = pyndri.Index(args.index)

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging.info('Loading dictionary.')
    dictionary = pyndri.extract_dictionary(args.index)

    logging.info('Loading model.')
    model_base, epoch_and_ext = args.model.rsplit('_', 1)
    epoch = int(epoch_and_ext.split('.')[0])

    if not os.path.exists('{}_meta'.format(model_base)):
        model_meta_base, batch_idx = model_base.rsplit('_', 1)
    else:
        model_meta_base = model_base

    model = nvsm.load_model(nvsm.load_meta(model_meta_base), model_base, epoch)

    with open(args.vocabulary_list, 'w') as f_vocabulary_list:
        for index_term_id in model.term_mapping:
            f_vocabulary_list.write(dictionary[index_term_id])
            f_vocabulary_list.write('\n')
Example #3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')
    parser.add_argument('--shard_size',
                        type=argparse_utils.positive_int, default=(1 << 14))

    parser.add_argument('reviews_file',
                        type=argparse_utils.existing_file_path)

    parser.add_argument('--product_list',
                        type=argparse_utils.existing_file_path,
                        default=None)

    parser.add_argument('--trectext_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    if args.product_list:
        with open(args.product_list, 'r') as f_product_list:
            product_list = set(
                product_id.strip()
                for product_id in f_product_list)

        logging.info('Only considering white list of %d products.',
                     len(product_list))
    else:
        product_list = None

    writer = trec_utils.ShardedTRECTextWriter(
        args.trectext_out, args.shard_size)

    with gzip.open(args.reviews_file, 'r') as f_reviews:
        for i, raw_line in enumerate(f_reviews):
            raw_line = raw_line.decode('utf8')

            review = json.loads(raw_line)

            product_id = review['asin']

            if product_list and product_id not in product_list:
                continue

            document_id = '{product_id}_{reviewer_id}_{review_time}'.format(
                product_id=product_id,
                reviewer_id=review['reviewerID'],
                review_time=review['unixReviewTime'])

            review_summary = review['summary']
            review_text = review['reviewText']

            product_document = '{0} \n{1}'.format(
                review_summary, review_text)

            product_document = io_utils.tokenize_text(product_document)

            document = ' '.join(product_document)

            writer.write_document(document_id, document)

            if (i + 1) % 1000 == 0:
                logging.info('Processed %d reviews.', i + 1)

    writer.close()

    logging.info('All done!')
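
Since the document identifiers above are built as {product_id}_{reviewer_id}_{review_time}, here is a small hypothetical helper for splitting them back up; it assumes (as is typical for Amazon ASINs and reviewer identifiers) that neither component itself contains an underscore.

def split_review_document_id(document_id):
    product_id, reviewer_id, review_time = document_id.split('_')

    return product_id, reviewer_id, int(review_time)
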
Example #4
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--data',
                        type=argparse_utils.existing_file_path,
                        required=True)
    parser.add_argument('--meta',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('--type', choices=MODELS, required=True)

    parser.add_argument('--iterations',
                        type=argparse_utils.positive_int,
                        default=1)

    parser.add_argument('--batch_size',
                        type=argparse_utils.positive_int,
                        default=1024)

    parser.add_argument('--word_representation_size',
                        type=argparse_utils.positive_int,
                        default=300)
    parser.add_argument('--representation_initializer',
                        type=argparse_utils.existing_file_path,
                        default=None)

    # Specific to VectorSpaceLanguageModel.
    parser.add_argument('--entity_representation_size',
                        type=argparse_utils.positive_int,
                        default=None)
    parser.add_argument('--num_negative_samples',
                        type=argparse_utils.positive_int,
                        default=None)
    parser.add_argument('--one_hot_classes',
                        action='store_true',
                        default=False)

    # Referenced below when deciding whether to use stored instance weights.
    parser.add_argument('--ignore_weights',
                        action='store_true',
                        default=False)

    parser.add_argument('--regularization_lambda',
                        type=argparse_utils.ratio,
                        default=0.01)

    parser.add_argument('--model_output', type=str, required=True)

    args = parser.parse_args()

    if args.entity_representation_size is None:
        args.entity_representation_size = args.word_representation_size

    args.type = MODELS[args.type]

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging_utils.log_module_info(theano, lasagne, np, scipy)

    # Load data.
    logging.info('Loading data from %s.', args.data)
    data_sets = np.load(args.data)

    if 'w_train' in data_sets and not args.ignore_weights:
        w_train = data_sets['w_train']
    else:
        logging.warning('No weights found in data set; '
                        'assuming uniform instance weighting.')

        w_train = np.ones(data_sets['x_train'].shape[0], dtype=np.float32)

    training_set = (data_sets['x_train'], data_sets['y_train'][()], w_train)
    validation_set = (data_sets['x_validate'], data_sets['y_validate'][()])

    logging.info('Training instances: %s (%s) %s (%s) %s (%s)',
                 training_set[0].shape, training_set[0].dtype,
                 training_set[1].shape, training_set[1].dtype,
                 training_set[2].shape, training_set[2].dtype)
    logging.info('Validation instances: %s (%s) %s (%s)',
                 validation_set[0].shape, validation_set[0].dtype,
                 validation_set[1].shape, validation_set[1].dtype)

    num_entities = training_set[1].shape[1]
    assert num_entities > 1

    if args.one_hot_classes:
        logging.info('Transforming y-values to one-hot values.')

        if not scipy.sparse.issparse(training_set[1]) or \
           not scipy.sparse.issparse(validation_set[1]):
            raise RuntimeError(
                'Argument --one_hot_classes expects sparse truth values.')

        y_train, (x_train, w_train) = sparse_to_one_hot_multiple(
            training_set[1], training_set[0], training_set[2])

        training_set = (x_train, y_train, w_train)

        y_validate, (x_validate, ) = sparse_to_one_hot_multiple(
            validation_set[1], validation_set[0])

        validation_set = (x_validate, y_validate)

    logging.info('Loading meta-data from %s.', args.meta)
    with open(args.meta, 'rb') as f:
        # We do not load the remainder of the vocabulary.
        data_args, words, tokens = (pickle.load(f) for _ in range(3))

        vocabulary_size = len(words)

    representations = lasagne.init.GlorotUniform().sample(
        (vocabulary_size, args.word_representation_size))

    if args.representation_initializer:
        # This way of creating the dictionary ignores duplicate words in
        # the representation initializer.
        representation_lookup = dict(
            embedding_utils.load_binary_representations(
                args.representation_initializer, tokens))

        representation_init_count = 0

        for word, meta in words.items():
            if word.lower() in representation_lookup:
                representations[meta.id] = \
                    representation_lookup[word.lower()]

                representation_init_count += 1

        logging.info(
            'Initialized representations from '
            'pre-learned collection for %d words (%.2f%%).',
            representation_init_count,
            (representation_init_count / float(len(words))) * 100.0)

    # Allow GC to clear memory.
    del words
    del tokens

    model_options = {
        'batch_size': args.batch_size,
        'window_size': data_args.window_size,
        'representations_init': representations,
        'regularization_lambda': args.regularization_lambda,
        'training_set': training_set,
        'validation_set': validation_set,
    }

    if args.type == models.LanguageModel:
        model_options.update(output_layer_size=num_entities)
    elif args.type == models.VectorSpaceLanguageModel:
        entity_representations = lasagne.init.GlorotUniform().sample(
            (num_entities, args.entity_representation_size))

        model_options.update(
            entity_representations_init=entity_representations,
            num_negative_samples=args.num_negative_samples)

    # Construct neural net.
    model = args.type(**model_options)

    train(model,
          args.iterations,
          args.model_output,
          abort_threshold=1e-5,
          early_stopping=False,
          additional_args=[args])
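
The implementation of sparse_to_one_hot_multiple is not shown on this page; the sketch below is a hypothetical illustration of the transformation the call sites above suggest: each instance whose sparse truth row has k non-zero classes is expanded into k one-hot instances, and any aligned arrays (inputs, weights) are duplicated accordingly.

import numpy as np
import scipy.sparse

def sparse_to_one_hot_multiple_sketch(y_sparse, *aligned_arrays):
    # One (row, col) pair per class assignment of an instance.
    y_coo = scipy.sparse.coo_matrix(y_sparse)

    order = np.argsort(y_coo.row, kind='stable')
    rows, cols = y_coo.row[order], y_coo.col[order]

    # One one-hot row per (instance, class) pair.
    y_one_hot = np.zeros((rows.size, y_sparse.shape[1]), dtype=np.float32)
    y_one_hot[np.arange(rows.size), cols] = 1.0

    # Duplicate the aligned arrays (e.g. x and w) along the first axis.
    duplicated = tuple(array[rows] for array in aligned_arrays)

    return y_one_hot, duplicated
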
Example #5
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--judgments',
                        type=argparse_utils.existing_file_path,
                        required=True)
    parser.add_argument('--session_topic_map',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('--track_year',
                        type=int,
                        choices=domain.INPUT_FORMATS, required=True)

    parser.add_argument('--qrel_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    topic_id_to_subtopics = collections.defaultdict(lambda: set([0]))
    topic_id_to_session_ids = collections.defaultdict(set)

    with open(args.session_topic_map, 'r') as f_mapping:
        for line in f_mapping:
            line = line.strip()

            if not line:
                continue

            try:
                data = line.strip().split()[:3]
            except:
                logging.warning('Unable to parse %s', line)

                continue

            if len(data) == 3:
                session_id, topic_id, subtopic_id = data
            else:
                session_id, topic_id, subtopic_id = data + [0]

            topic_id_to_subtopics[topic_id].add(subtopic_id)
            topic_id_to_session_ids[topic_id].add(session_id)

    with open(args.judgments, 'r') as f_judgments:
        qrel = sesh.parse_qrel(f_judgments, args.track_year)

    with open(args.qrel_out, 'w') as f_out:
        for topic_id, session_ids in topic_id_to_session_ids.items():
            relevant_items = qrel[topic_id]

            for session_id in session_ids:
                for document_id, relevance in relevant_items.items():
                    f_out.write(
                        '{session_id} 0 {document_id} {relevance}\n'.format(
                            session_id=session_id,
                            document_id=document_id,
                            relevance=relevance))
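
Each line written above has the form "<session_id> 0 <document_id> <relevance>"; the helper below is a hypothetical sketch of reading such a file back, assuming integer relevance grades.

import collections

def read_session_qrel(path):
    qrel = collections.defaultdict(dict)

    with open(path, 'r') as f_qrel:
        for line in f_qrel:
            if not line.strip():
                continue

            session_id, _, document_id, relevance = line.split()
            qrel[session_id][document_id] = int(relevance)

    return qrel
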
Example #6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--seed',
                        type=argparse_utils.positive_int,
                        required=True)

    parser.add_argument('document_paths',
                        type=argparse_utils.existing_file_path, nargs='+')

    parser.add_argument('--encoding', type=str, default='latin1')

    parser.add_argument('--assoc_path',
                        type=argparse_utils.existing_file_path, required=True)

    parser.add_argument('--num_workers', type=int, default=1)

    parser.add_argument('--vocabulary_min_count', type=int, default=2)
    parser.add_argument('--vocabulary_min_word_size', type=int, default=2)
    parser.add_argument('--vocabulary_max_size', type=int, default=65536)

    parser.add_argument('--remove_stopwords', type=str, default='nltk')

    parser.add_argument('--validation_set_ratio',
                        type=argparse_utils.ratio, default=0.01)

    parser.add_argument('--window_size', type=int, default=10)
    parser.add_argument('--overlapping', action='store_true', default=False)
    parser.add_argument('--stride',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('--resample', action='store_true', default=False)

    parser.add_argument('--no_shuffle', action='store_true', default=False)
    parser.add_argument('--no_padding', action='store_true', default=False)

    parser.add_argument('--no_instance_weights',
                        action='store_true',
                        default=False)

    parser.add_argument('--meta_output',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)
    parser.add_argument('--data_output',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    # Seed RNG.
    np.random.seed(args.seed)

    logging_utils.log_module_info(np, scipy, sklearn)

    ignore_words = ['<doc>', '</doc>', '<docno>', '<text>', '</text>']
    if args.remove_stopwords == 'none':
        logging.info('Stopwords will be included in instances.')
    elif args.remove_stopwords == 'nltk':
        logging.info('Using NLTK stopword list.')

        ignore_words.extend(stopwords.words('english'))
        ignore_words.extend(stopwords.words('dutch'))
    elif os.path.exists(args.remove_stopwords):
        logging.info('Using custom stopword list (%s).', args.remove_stopwords)

        with open(args.remove_stopwords, 'r') as f:
            ignore_words.extend(
                filter(len, map(str.strip, map(str.lower, f.readlines()))))
    else:
        logging.error('Invalid stopword removal strategy "%s".',
                      args.remove_stopwords)

        return -1

    logging.info('Ignoring words: %s.', ignore_words)

    # TODO(cvangysel): maybe switch by encapsulated call?
    words, tokens = io_utils.extract_vocabulary(
        args.document_paths,
        min_count=args.vocabulary_min_count,
        max_vocab_size=args.vocabulary_max_size,
        min_word_size=args.vocabulary_min_word_size,
        num_workers=args.num_workers,
        ignore_tokens=ignore_words,
        encoding=args.encoding)

    logging.info('Loading document identifiers.')

    reader = trec_utils.TRECTextReader(args.document_paths,
                                       encoding=args.encoding)

    document_ids = reader.iter_document_ids(num_workers=args.num_workers)

    with open(args.assoc_path, 'r') as f_assocs:
        assocs = trec_utils.EntityDocumentAssociations(
            f_assocs,
            document_ids=document_ids)

    logging.info('Found %d unique entities.', len(assocs.entities))

    logging.info('Document-per-entity stats: %s',
                 list(map(lambda kv: (kv[0], len(kv[1])),
                          sorted(assocs.documents_per_entity.items()))))

    logging.info(
        'Entity-per-document association stats: %s',
        collections.Counter(
            map(len, assocs.entities_per_document.values())).items())

    # Estimate the position in authorship distribution.
    num_associations_distribution = np.zeros(
        assocs.max_entities_per_document, dtype=np.int32)

    for association_length in (
            len(associated_entities)
            for associated_entities in
            assocs.entities_per_document.values()):
        num_associations_distribution[association_length - 1] += 1

    logging.info('Number of associations distribution: %s',
                 num_associations_distribution)

    position_in_associations_distribution = np.cumsum(
        num_associations_distribution[::-1])[::-1]

    logging.info('Position in associations distribution: %s',
                 position_in_associations_distribution)

    instances_and_labels = []

    num_documents = 0
    num_non_associated_documents = 0

    documents_per_entity = collections.defaultdict(int)
    instances_per_entity = collections.defaultdict(int)
    instances_per_document = {}

    global_label_distribution = collections.defaultdict(float)

    if args.overlapping and args.stride is None:
        args.stride = 1
    elif args.overlapping and args.stride is not None:
        logging.error('Option --overlapping passed '
                      'concurrently with --stride.')

        return -1
    elif args.stride is None:
        logging.info('Defaulting stride to window size.')

        args.stride = args.window_size

    logging.info('Generating instances with stride %d.', args.stride)

    result_q = multiprocessing.Queue()

    pool = multiprocessing.Pool(
        args.num_workers,
        initializer=prepare_initializer,
        initargs=[result_q,
                  args, assocs.entities, assocs.entities_per_document,
                  position_in_associations_distribution,
                  tokens, words,
                  args.encoding])

    max_document_length = 0

    worker_result = pool.map_async(
        prepare_worker, args.document_paths)

    # We will not submit any more tasks to the pool.
    pool.close()

    it = multiprocessing_utils.QueueIterator(
        pool, worker_result, result_q)

    def _extract_key(obj):
        return tuple(sorted(obj))

    num_labels = 0
    num_instances = 0

    instances_per_label = collections.defaultdict(list)

    while True:
        try:
            result = next(it)
        except StopIteration:
            break

        num_documents += 1

        if result:
            document_id, \
                document_instances_and_labels, \
                document_label = result

            assert document_id not in instances_per_document

            num_instances_for_doc = len(document_instances_and_labels)
            instances_per_document[document_id] = num_instances_for_doc

            max_document_length = max(max_document_length,
                                      num_instances_for_doc)

            # For statistical purposes.
            for entity_id in assocs.entities_per_document[document_id]:
                documents_per_entity[entity_id] += 1

            # For statistical purposes.
            for entity_id in assocs.entities_per_document[document_id]:
                num_labels += num_instances_for_doc

            # Aggregate.
            instances_per_label[
                _extract_key(document_label.keys())].extend(
                document_instances_and_labels)

            num_instances += len(document_instances_and_labels)

            # Some more accounting.
            for entity_id, mass in document_label.items():
                global_label_distribution[entity_id] += \
                    num_instances_for_doc * mass
        else:
            num_non_associated_documents += 1

    # assert result_q.empty()

    logging.info('Global unnormalized distribution: %s',
                 global_label_distribution)

    if args.resample:
        num_entities = len(instances_per_label)
        avg_instances_per_doc = (
            float(num_instances) / float(num_entities))

        min_instances_per_entity = int(avg_instances_per_doc)

        logging.info('Setting number of sampled instances '
                     'to %d per label.',
                     min_instances_per_entity)

    for label in list(instances_per_label.keys()):
        label_instances_and_labels = instances_per_label.pop(label)

        if not label_instances_and_labels:
            logging.warning('Label %s has no instances; skipping.', label)

            continue

        label_num_instances = len(label_instances_and_labels)

        if args.resample:
            assert min_instances_per_entity > 0

            # Sample (with replacement) until the label reaches the
            # target number of instances.
            label_instances_and_labels_sample = []

            while (len(label_instances_and_labels_sample) <
                   min_instances_per_entity):
                label_instances_and_labels_sample.append(
                    label_instances_and_labels[
                        np.random.randint(label_num_instances)])
        else:
            label_instances_and_labels_sample = \
                label_instances_and_labels

        for entity_id in label:
            instances_per_entity[entity_id] += len(
                label_instances_and_labels_sample)

        instances_and_labels.extend(label_instances_and_labels_sample)

    logging.info(
        'Documents-per-indexed-entity stats (mean=%.2f, std_dev=%.2f): %s',
        np.mean(list(documents_per_entity.values())),
        np.std(list(documents_per_entity.values())),
        sorted(documents_per_entity.items()))

    logging.info(
        'Instances-per-indexed-entity stats '
        '(mean=%.2f, std_dev=%.2f, min=%d, max=%d): %s',
        np.mean(list(instances_per_entity.values())),
        np.std(list(instances_per_entity.values())),
        np.min(list(instances_per_entity.values())),
        np.max(list(instances_per_entity.values())),
        sorted(instances_per_entity.items()))

    logging.info(
        'Instances-per-document stats (mean=%.2f, std_dev=%.2f, max=%d).',
        np.mean(list(instances_per_document.values())),
        np.std(list(instances_per_document.values())),
        max_document_length)

    logging.info('Observed %d documents of which %d (ratio=%.2f) '
                 'are not associated with an entity.',
                 num_documents, num_non_associated_documents,
                 (float(num_non_associated_documents) / num_documents))

    training_instances_and_labels, validation_instances_and_labels = \
        sklearn.cross_validation.train_test_split(
            instances_and_labels, test_size=args.validation_set_ratio)

    num_training_instances = len(training_instances_and_labels)
    num_validation_instances = len(validation_instances_and_labels)

    num_instances = num_training_instances + num_validation_instances

    logging.info(
        'Processed %d instances; training=%d, validation=%d (ratio=%.2f).',
        num_instances,
        num_training_instances, num_validation_instances,
        (float(num_validation_instances) /
         (num_training_instances + num_validation_instances)))

    # Figure out if there are any entities with no instances, and
    # do not consider them during training.
    entity_indices = {}
    entity_indices_inv = {}

    for entity_id, num_instances in instances_per_entity.items():
        if not num_instances:
            continue

        entity_index = len(entity_indices)

        entity_indices[entity_id] = entity_index
        entity_indices_inv[entity_index] = entity_id

    logging.info('Retained %d entities after instance creation.',
                 len(entity_indices))

    directories = list(map(os.path.dirname,
                           [args.meta_output, args.data_output]))

    # Create those directories if they do not exist yet.
    for directory in directories:
        if directory and not os.path.exists(directory):
            os.makedirs(directory)

    # Dump vocabulary.
    with open(args.meta_output, 'wb') as f:
        for obj in (args, words, tokens,
                    entity_indices_inv, assocs.documents_per_entity):
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)

        logging.info('Saved vocabulary to "%s".', args.meta_output)

    instance_dtype = np.min_scalar_type(len(words) - 1)
    logging.info('Instance elements will be stored using %s.', instance_dtype)

    data = {}

    x_train, y_train = instances_and_labels_to_arrays(
        training_instances_and_labels,
        args.window_size,
        entity_indices,
        instance_dtype,
        not args.no_shuffle)

    data['x_train'] = x_train
    data['y_train'] = y_train

    if not args.no_instance_weights:
        w_train = np.fromiter(
            (float(max_document_length) / instances_per_document[doc_id]
             for doc_id, _, _ in training_instances_and_labels),
            np.float32, len(training_instances_and_labels))

        assert w_train.shape == (x_train.shape[0],)

        data['w_train'] = w_train

    x_validate, y_validate = instances_and_labels_to_arrays(
        validation_instances_and_labels,
        args.window_size,
        entity_indices,
        instance_dtype,
        not args.no_shuffle)

    data['x_validate'] = x_validate
    data['y_validate'] = y_validate

    with open(args.data_output, 'wb') as f:
        np.savez(f, **data)

    logging.info('Saved data sets.')

    logging.info('Entity-per-document association stats: {0}'.format(
        collections.Counter(
            map(len, assocs.entities_per_document.values())).items()))

    logging.info('Documents-per-entity stats: {0}'.format(
        collections.Counter(
            map(len, assocs.documents_per_entity.values())).items()))

    logging.info('Done.')
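
The meta file written above stores five pickled objects in a fixed order (args, words, tokens, entity_indices_inv, documents_per_entity), which matches the order in which the query scripts further down this page read them back. A minimal sketch:

import pickle

def load_meta(meta_path):
    with open(meta_path, 'rb') as f_meta:
        # args, words, tokens, entity_indices_inv, documents_per_entity.
        return tuple(pickle.load(f_meta) for _ in range(5))
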
Example #7
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')
    parser.add_argument('--shard_size',
                        type=argparse_utils.positive_int,
                        default=(1 << 14))

    parser.add_argument('meta_file', type=argparse_utils.existing_file_path)

    parser.add_argument('--product_list',
                        type=argparse_utils.existing_file_path,
                        default=None)

    parser.add_argument('--trectext_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    if args.product_list:
        with open(args.product_list, 'r') as f_product_list:
            product_list = set(product_id.strip()
                               for product_id in f_product_list)

        logging.info('Only considering white list of %d products.',
                     len(product_list))
    else:
        product_list = None

    writer = trec_utils.ShardedTRECTextWriter(args.trectext_out,
                                              args.shard_size)

    department = ' '.join(
        os.path.basename(args.meta_file).split('.')[0].split('_')[1:]).replace(
            ' and ', ' & ')

    logging.info('Department: %s', department)

    with gzip.open(args.meta_file, 'r') as f_meta:
        for i, raw_line in enumerate(f_meta):
            raw_line = raw_line.decode('utf8')

            product = ast.literal_eval(raw_line)

            product_id = product['asin']

            if product_list and product_id not in product_list:
                continue

            if 'description' in product and 'title' in product:
                product_title = product['title']
                product_description = \
                    io_utils.strip_html(product['description'])

                product_document = '{0} \n{1}'.format(product_title,
                                                      product_description)

                product_document = io_utils.tokenize_text(product_document)

                logging.debug('Product %s has description of %d tokens.',
                              product_id, len(product_document))

                writer.write_document(product_id, ' '.join(product_document))
            else:
                logging.debug(
                    'Filtering product %s due to missing description.',
                    product_id)

                continue

            if (i + 1) % 1000 == 0:
                logging.info('Processed %d products.', i + 1)

    writer.close()

    logging.info('All done!')
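
A worked example, with a hypothetical file name, of the department derivation above:

# os.path.basename('meta_Clothing_Shoes_and_Jewelry.json.gz')
#   .split('.')[0]             -> 'meta_Clothing_Shoes_and_Jewelry'
#   .split('_')[1:]            -> ['Clothing', 'Shoes', 'and', 'Jewelry']
# ' '.join(...)                -> 'Clothing Shoes and Jewelry'
# .replace(' and ', ' & ')     -> 'Clothing Shoes & Jewelry'
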
Example #8
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('model')

    parser.add_argument('index', type=argparse_utils.existing_directory_path)

    parser.add_argument('--limit',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('--object_classification',
                        type=argparse_utils.existing_file_path,
                        nargs='+',
                        default=None)

    parser.add_argument('--filter_unclassified',
                        action='store_true',
                        default=False)

    parser.add_argument('--l2_normalize',
                        action='store_true',
                        default=False)

    parser.add_argument('--mode',
                        choices=('tsne', 'embedding_projector'),
                        default='tsne')

    parser.add_argument('--legend',
                        action='store_true',
                        default=False)

    parser.add_argument('--tick_labels',
                        action='store_true',
                        default=False)

    parser.add_argument('--edges',
                        action='store_true',
                        default=False)

    parser.add_argument('--border',
                        action='store_true',
                        default=False)

    parser.add_argument('--plot_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    # Set matplotlib style.
    plt.style.use('bmh')

    logging.info('Loading index.')
    index = pyndri.Index(args.index)

    logging.info('Loading cuNVSM model.')
    model_base, epoch_and_ext = args.model.rsplit('_', 1)
    epoch = int(epoch_and_ext.split('.')[0])

    if not os.path.exists('{}_meta'.format(model_base)):
        model_meta_base, batch_idx = model_base.rsplit('_', 1)
    else:
        model_meta_base = model_base

    model = nvsm.load_model(
        nvsm.load_meta(model_meta_base),
        model_base, epoch,
        only_object_embeddings=True)

    raw_object_representations = np.copy(model.object_representations)

    if args.limit:
        raw_object_representations = raw_object_representations[:args.limit, :]

    # --object_classification defaults to None; avoid iterating over None.
    for object_classification in (args.object_classification or []):
        root, ext = os.path.splitext(args.plot_out)

        plot_out = '{}-{}.{}'.format(
            root, os.path.basename(object_classification), ext.lstrip('.'))

        if object_classification and args.filter_unclassified:
            logging.info('Filtering unclassified.')

            with open(object_classification, 'r') as f_objects:
                object_ids = [line.strip().split()[0] for line in f_objects]
                indices = sorted(model.inv_object_mapping[idx]
                                 for _, idx in index.document_ids(object_ids)
                                 if idx in model.inv_object_mapping)

                logging.info('Considering %d out of %d representations.',
                             len(indices), len(object_ids))

                translation_table = {idx: i for i, idx in enumerate(indices)}

                object_representations = raw_object_representations[indices]

                assert object_representations.shape[0] == \
                    len(translation_table)
        else:
            translation_table = None

            raise NotImplementedError()

        logging.info('Loading object clusters.')

        cluster_id_to_product_ids = {}

        if object_classification:
            with open(object_classification, 'r') as f_objects:
                for line in f_objects:
                    object_id, cluster_id = line.strip().split()

                    if cluster_id not in cluster_id_to_product_ids:
                        cluster_id_to_product_ids[cluster_id] = set()

                    cluster_id_to_product_ids[cluster_id].add(object_id)

                for cluster_id in list(cluster_id_to_product_ids.keys()):
                    object_ids = list(cluster_id_to_product_ids[cluster_id])

                    cluster_id_to_product_ids[cluster_id] = set(
                        (model.inv_object_mapping[int_object_id]
                            if translation_table is None
                            else translation_table[
                                model.inv_object_mapping[int_object_id]])
                        for ext_object_id, int_object_id in
                        index.document_ids(object_ids)
                        if int_object_id in model.inv_object_mapping and
                        (args.limit is None or
                         (model.inv_object_mapping[int_object_id] <
                             args.limit)))
        else:
            raise NotImplementedError()

        assert len(cluster_id_to_product_ids) < len(MARKERS)

        if args.l2_normalize:
            logging.info('L2-normalizing representations.')

            object_representations /= np.linalg.norm(
                object_representations,
                axis=1, keepdims=True)

        if args.mode == 'tsne':
            logging.info('Running t-SNE.')

            twodim_object_representations = \
                TSNE(n_components=2, init='pca', random_state=0).\
                fit_transform(object_representations)

            logging.info('Plotting %s.', twodim_object_representations.shape)

            colors = cm.rainbow(
                np.linspace(0, 1, len(cluster_id_to_product_ids)))

            for idx, cluster_id in enumerate(
                    sorted(cluster_id_to_product_ids.keys(),
                           key=lambda cluster_id: len(
                               cluster_id_to_product_ids[cluster_id]),
                           reverse=True)):
                row_ids = list(cluster_id_to_product_ids[cluster_id])

                plt.scatter(
                    twodim_object_representations[row_ids, 0],
                    twodim_object_representations[row_ids, 1],
                    marker=MARKERS[idx],
                    edgecolors='grey' if args.edges else None,
                    cmap=plt.cm.Spectral,
                    color=colors[idx],
                    alpha=0.3,
                    label=pylatex.utils.escape_latex(cluster_id))

            plt.grid()

            plt.tight_layout()

            if args.legend:
                plt.legend(bbox_to_anchor=(0, -0.15, 1, 0),
                           loc=2,
                           ncol=2,
                           mode='expand',
                           borderaxespad=0)

            if not args.tick_labels:
                plt.gca().get_xaxis().set_visible(False)
                plt.gca().get_yaxis().set_visible(False)

            if not args.border:
                # plt.gcf().patch.set_visible(False)
                plt.gca().axis('off')

            logging.info('Writing %s.', plot_out)

            plt.savefig(plot_out,
                        bbox_inches='tight',
                        transparent=True,
                        pad_inches=0,
                        dpi=200)
        elif args.mode == 'embedding_projector':
            logging.info('Dumping to TensorFlow embedding projector format.')

            with open('{}_vectors.tsv'.format(plot_out), 'w') as f_vectors, \
                    open('{}_meta.tsv'.format(plot_out), 'w') as f_meta:
                f_meta.write('document_id\tclass\n')

                def write_rowids(row_ids, cluster_id):
                    for row_id in row_ids:
                        f_vectors.write(
                            '{}\n'.format('\t'.join(
                                '{:.5f}'.format(x)
                                for x in object_representations[row_id])))

                        f_meta.write('{}\t{}\n'.format(
                            index.ext_document_id(
                                model.object_mapping[row_id]),
                            cluster_id))

                for cluster_id in cluster_id_to_product_ids.keys():
                    row_ids = list(cluster_id_to_product_ids[cluster_id])
                    write_rowids(row_ids, cluster_id)

    logging.info('All done!')
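
In embedding_projector mode the script writes a tab-separated vectors file and a metadata file with a document_id/class header, which is the pair of TSV files the TensorFlow Embedding Projector expects. A small sketch (hypothetical helper) of loading the vectors back:

import numpy as np

def load_projector_vectors(vectors_tsv_path):
    # Each line holds one tab-separated object representation.
    return np.loadtxt(vectors_tsv_path, delimiter='\t', dtype=np.float32)
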
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--meta',
                        type=argparse_utils.existing_file_path, required=True)
    parser.add_argument('--model',
                        type=argparse_utils.existing_file_path, required=True)

    parser.add_argument('--topics',
                        type=argparse_utils.existing_file_path, nargs='+')

    parser.add_argument('--top',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('--run_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    with open(args.model, 'rb') as f:
        # Load model arguments and learned mapping.
        model_args, predict_fn = (pickle.load(f) for _ in range(2))

        # Load word representations.
        word_representations = pickle.load(f)

        try:
            entity_representations = pickle.load(f)
        except EOFError:
            entity_representations = None

    with open(args.meta, 'rb') as f:
        (data_args,
         words, tokens,
         entity_indices_inv, entity_assocs) = (
            pickle.load(f) for _ in range(5))

    # Parse topic files.
    topic_f = list(map(lambda filename: open(filename, 'r'), args.topics))
    topics = trec_utils.parse_topics(topic_f)
    [f_.close() for f_ in topic_f]

    model_name = os.path.basename(args.model)

    # Entity profiling.
    topics_per_entity = collections.defaultdict(list)

    # Entity finding.
    entities_per_topic = collections.defaultdict(list)

    def ranker_callback(topic_id, top_ranked_indices, top_ranked_values):
        for rank, (entity_internal_id, relevance) in enumerate(
                zip(top_ranked_indices, top_ranked_values)):
            entity_id = entity_indices_inv[entity_internal_id]

            # Entity profiling.
            topics_per_entity[entity_id].append((relevance, topic_id))

            # Entity finding.
            entities_per_topic[topic_id].append((relevance, entity_id))

    with open('{0}_debug'.format(args.run_out), 'w') as f_debug_out:
        if model_args.type == models.LanguageModel:
            result_callback = LogLinearCallback(
                args, model_args, tokens,
                f_debug_out,
                ranker_callback)
        elif model_args.type == models.VectorSpaceLanguageModel:
            result_callback = VectorSpaceCallback(
                entity_representations,
                args, model_args, tokens,
                f_debug_out,
                ranker_callback)

        batcher = inference.create(
            predict_fn, word_representations,
            model_args.batch_size, data_args.window_size, len(words),
            result_callback)

        logging.info('Batching queries using %s.', batcher)

        for q_id, (topic_id, terms) in enumerate(topics.items()):
            if topic_id not in topics:
                logging.error('Topic "%s" not found in topic list.', topic_id)

                continue

            # Do not replace numeric tokens in queries.
            query_terms = trec_utils.parse_query(terms)

            query_tokens = []

            logging.debug('Query (%d/%d) %s: %s (%s)',
                          q_id + 1, len(topics),
                          topic_id, query_terms, terms)

            for term in query_terms:
                if term not in words:
                    logging.debug('Term "%s" is OOV.', term)

                    continue

                term_token = words[term].id

                query_tokens.append(term_token)

            if not query_tokens:
                logging.warning('Skipping query with terms "%s".', terms)

                continue

            batcher.submit(query_tokens, topic_id=topic_id)

        batcher.process()

    # Entity profiling.
    with io.open('{0}_ep'.format(args.run_out),
                 'w', encoding='utf8') as out_ep_run:
        trec_utils.write_run(model_name, topics_per_entity, out_ep_run)

    # Entity finding.
    with io.open('{0}_ef'.format(args.run_out),
                 'w', encoding='utf8') as out_ef_run:
        trec_utils.write_run(model_name, entities_per_topic, out_ef_run)

    logging.info('Saved run to %s.', args.run_out)
Example #10
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--meta',
                        type=argparse_utils.existing_file_path,
                        required=True)
    parser.add_argument('--model',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('--topics',
                        type=argparse_utils.existing_file_path,
                        nargs='+')

    parser.add_argument('--top',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('--run_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    with open(args.model, 'rb') as f:
        # Load model arguments and learned mapping.
        model_args, predict_fn = (pickle.load(f) for _ in range(2))

        # Load word representations.
        word_representations = pickle.load(f)

        try:
            entity_representations = pickle.load(f)
        except EOFError:
            entity_representations = None

    with open(args.meta, 'rb') as f:
        (data_args, words, tokens, entity_indices_inv,
         entity_assocs) = (pickle.load(f) for _ in range(5))

    # Parse topic files.
    topic_f = list(map(lambda filename: open(filename, 'r'), args.topics))
    topics = trec_utils.parse_topics(topic_f)
    [f_.close() for f_ in topic_f]

    model_name = os.path.basename(args.model)

    # Entity profiling.
    topics_per_entity = collections.defaultdict(list)

    # Entity finding.
    entities_per_topic = collections.defaultdict(list)

    def ranker_callback(topic_id, top_ranked_indices, top_ranked_values):
        for rank, (entity_internal_id, relevance) in enumerate(
                zip(top_ranked_indices, top_ranked_values)):
            entity_id = entity_indices_inv[entity_internal_id]

            # Entity profiling.
            topics_per_entity[entity_id].append((relevance, topic_id))

            # Entity finding.
            entities_per_topic[topic_id].append((relevance, entity_id))

    with open('{0}_debug'.format(args.run_out), 'w') as f_debug_out:
        if model_args.type == models.LanguageModel:
            result_callback = LogLinearCallback(args, model_args, tokens,
                                                f_debug_out, ranker_callback)
        elif model_args.type == models.VectorSpaceLanguageModel:
            result_callback = VectorSpaceCallback(entity_representations, args,
                                                  model_args, tokens,
                                                  f_debug_out, ranker_callback)

        batcher = inference.create(predict_fn, word_representations,
                                   model_args.batch_size,
                                   data_args.window_size, len(words),
                                   result_callback)

        logging.info('Batching queries using %s.', batcher)

        for q_id, (topic_id, terms) in enumerate(topics.items()):
            if topic_id not in topics:
                logging.error('Topic "%s" not found in topic list.', topic_id)

                continue

            # Do not replace numeric tokens in queries.
            query_terms = trec_utils.parse_query(terms)

            query_tokens = []

            logging.debug('Query (%d/%d) %s: %s (%s)', q_id + 1, len(topics),
                          topic_id, query_terms, terms)

            for term in query_terms:
                if term not in words:
                    logging.debug('Term "%s" is OOV.', term)

                    continue

                term_token = words[term].id

                query_tokens.append(term_token)

            if not query_tokens:
                logging.warning('Skipping query with terms "%s".', terms)

                continue

            batcher.submit(query_tokens, topic_id=topic_id)

        batcher.process()

    # Entity profiling.
    with io.open('{0}_ep'.format(args.run_out), 'w',
                 encoding='utf8') as out_ep_run:
        trec_utils.write_run(model_name, topics_per_entity, out_ep_run)

    # Entity finding.
    with io.open('{0}_ef'.format(args.run_out), 'w',
                 encoding='utf8') as out_ef_run:
        trec_utils.write_run(model_name, entities_per_topic, out_ef_run)

    logging.info('Saved run to %s.', args.run_out)
Example #11
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--shard_size',
                        type=argparse_utils.positive_int,
                        default=1000000)

    parser.add_argument('sgm',
                        type=argparse_utils.existing_file_path,
                        nargs='+')

    parser.add_argument('--top_k_topics',
                        type=argparse_utils.positive_int,
                        default=20)

    parser.add_argument('--trectext_out_prefix', type=str, required=True)

    parser.add_argument('--document_classification_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    parser = ReutersParser()

    for sgm_path in args.sgm:
        logging.info('Parsing %s.', sgm_path)

        with open(sgm_path, 'r', encoding='ISO-8859-1') as f_sgm:
            parser.feed(f_sgm.read())

    logging.info('Parsed %d documents.', len(parser.documents))

    topic_histogram = collections.Counter(
        topic for document in parser.documents
        for topic in document['tags']['topics'])

    top_topics = set(
        sorted(topic_histogram.keys(),
               key=lambda topic: topic_histogram[topic])[-args.top_k_topics:])

    logging.info('Top topics: %s', top_topics)

    writer = trec_utils.ShardedTRECTextWriter(args.trectext_out_prefix,
                                              shard_size=args.shard_size,
                                              encoding='latin1')

    with open(args.document_classification_out, 'w') as \
            f_document_classification_out:
        for document in parser.documents:
            doc_id = document['doc_id']

            doc_text = '\n'.join([
                document['texts'].get('title', ''),
                document['texts'].get('dateline',
                                      ''), document['texts'].get('body', '')
            ])

            writer.write_document(doc_id, doc_text)

            doc_topics = {
                topic
                for topic in document['tags']['topics'] if topic in top_topics
            }

            if doc_topics:
                most_specific_doc_topic = min(
                    doc_topics, key=lambda topic: topic_histogram[topic])

                f_document_classification_out.write(doc_id)
                f_document_classification_out.write(' ')
                f_document_classification_out.write(most_specific_doc_topic)
                f_document_classification_out.write('\n')

    writer.close()
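
The classification file written above contains one "<doc_id> <topic>" pair per line; a minimal hypothetical helper for reading it back:

def load_document_classification(path):
    with open(path, 'r') as f_classification:
        return dict(line.split() for line in f_classification
                    if line.strip())
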
Example #12
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')
    parser.add_argument('--shard_size',
                        type=argparse_utils.positive_int, default=(1 << 14))

    parser.add_argument('meta_file',
                        type=argparse_utils.existing_file_path)

    parser.add_argument('--product_list',
                        type=argparse_utils.existing_file_path,
                        default=None)

    parser.add_argument('--trectext_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    if args.product_list:
        with open(args.product_list, 'r') as f_product_list:
            product_list = set(
                product_id.strip()
                for product_id in f_product_list)

        logging.info('Only considering white list of %d products.',
                     len(product_list))
    else:
        product_list = None

    writer = trec_utils.ShardedTRECTextWriter(
        args.trectext_out, args.shard_size)

    department = ' '.join(
        os.path.basename(args.meta_file).split('.')[0]
        .split('_')[1:]).replace(' and ', ' & ')

    logging.info('Department: %s', department)

    with gzip.open(args.meta_file, 'r') as f_meta:
        for i, raw_line in enumerate(f_meta):
            raw_line = raw_line.decode('utf8')

            product = ast.literal_eval(raw_line)

            product_id = product['asin']

            if product_list and product_id not in product_list:
                continue

            if 'description' in product and 'title' in product:
                product_title = product['title']
                product_description = \
                    io_utils.strip_html(product['description'])

                product_document = '{0} \n{1}'.format(
                    product_title, product_description)

                product_document = io_utils.tokenize_text(product_document)

                logging.debug('Product %s has description of %d tokens.',
                              product_id, len(product_document))

                writer.write_document(
                    product_id, ' '.join(product_document))
            else:
                logging.debug(
                    'Filtering product %s due to missing description.',
                    product_id)

                continue

            if (i + 1) % 1000 == 0:
                logging.info('Processed %d products.', i + 1)

    writer.close()

    logging.info('All done!')
Example #13
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('index', type=argparse_utils.existing_file_path)

    parser.add_argument('session_file', type=argparse_utils.existing_file_path)

    parser.add_argument('--num_workers', type=int, default=1)

    parser.add_argument('--harvested_links_file',
                        type=argparse_utils.existing_file_path,
                        default=None)

    parser.add_argument('--qrel',
                        type=argparse_utils.existing_file_path,
                        nargs='*')

    parser.add_argument('--configuration', type=str, nargs='+')

    parser.add_argument('--top_sessions',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('--out_base',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging_utils.log_module_info(np, scipy)

    configuration = sesh_pb2.ScoreSessionsConfig()
    pb.text_format.Merge(' '.join(args.configuration), configuration)

    if not configuration.modifier:
        configuration.modifier.add()  # Create an empty modifier.
    elif len(configuration.modifier) > 1:
        modifier_identifiers = [
            modifier.identifier for modifier in configuration.modifier]

        assert all(modifier_identifiers), \
            'All session modifiers should have an identifier.'

        assert len(modifier_identifiers) == len(set(modifier_identifiers)), \
            'All session modifier identifiers should be unique: {}.'.format(
                modifier_identifiers)

    logging.info('Configuration: %s', configuration)

    logging.info('Loading index.')
    index = pyndri.Index(args.index)

    num_documents = index.document_count()
    logging.debug('Index contains %d documents.', num_documents)

    logging.info('Loading dictionary.')
    dictionary = pyndri.extract_dictionary(index)

    logging.info('Loading background corpus.')
    background_prob_dist = pyndri_utils.extract_background_prob_dist(index)

    for modifier in configuration.modifier:
        out_base = os.path.join(args.out_base, modifier.identifier)
        assert not os.path.exists(out_base)

        os.makedirs(out_base)

        logging.info('Loading sessions using %s and outputting to %s.',
                     modifier or 'no modifier', out_base)

        with codecs.open(args.session_file, 'r', 'utf8') as f_xml:
            track_edition, _, sessions, session_id_to_topic_id = \
                domain.construct_sessions(
                    f_xml, args.top_sessions, dictionary)

        logging.info('Discovered %d sessions.', len(sessions))

        sessions = domain.alter_sessions(sessions, modifier)
        documents = domain.get_document_set(sessions.values())

        logging.info('Retained %d sessions (%d SERP documents) '
                     'after filtering.',
                     len(sessions), len(documents))

        # Load QRels for debugging and oracle runs.
        qrels_per_session = []

        for qrel_path in args.qrel:
            with open(qrel_path, 'r') as f_qrel:
                qrels_per_session.append(sesh.parse_qrel(f_qrel, None))

        scorer_impls = {}

        for scorer_desc in configuration.scorer:
            assert scorer_desc.type in scorers.SESSION_SCORERS

            identifier = scorer_desc.identifier or scorer_desc.type
            assert identifier not in scorer_impls

            scorer = scorers.create_scorer(scorer_desc, qrels_per_session)
            logging.info('Scoring using %s.', repr(scorer))

            scorer_impls[identifier] = scorer

        anchor_texts = None

        if args.harvested_links_file is not None:
            urls = set(document.url for document in documents)

            logging.info('Loading anchor texts for session SERPs (%d URLs).',
                         len(urls))

            with codecs.open(args.harvested_links_file, 'r', 'latin1') \
                    as f_harvested_links:
                anchor_texts = load_anchor_texts(f_harvested_links, urls)

            logging.info('Discovered anchor texts for %d URLs (%d total).',
                         len(anchor_texts), len(urls))
        else:
            logging.info('No anchor texts loaded.')

        # The following will hold all the rankings.
        document_assessments_per_session_per_scorer = collections.defaultdict(
            lambda: collections.defaultdict(
                lambda: collections.defaultdict(float)))

        assert configuration.candidate_generator in \
            DOCUMENT_CANDIDATE_GENERATORS

        # Document candidate generation.
        candidate_generator = DOCUMENT_CANDIDATE_GENERATORS[
            configuration.candidate_generator](**locals())

        logging.info('Using %s for document candidate generation.',
                     candidate_generator)

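        # Fan the sessions out over a worker pool; each worker pushes
        # (scorer_name, session_id, ranking) tuples onto the result queue.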
        result_queue = multiprocessing.Queue()

        initargs = [
            result_queue,
            args,
            configuration,
            out_base,
            background_prob_dist,
            candidate_generator,
            scorer_impls,
            index,
            dictionary,
            anchor_texts]

        pool = multiprocessing.Pool(
            args.num_workers,
            initializer=score_session_initializer,
            initargs=initargs)

        worker_result = pool.map_async(
            score_session_worker,
            sessions.values())

        # We will not submit any more tasks to the pool.
        pool.close()

        it = multiprocessing_utils.QueueIterator(
            pool, worker_result, result_queue)

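        # Drain the result queue until all workers have finished.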
        while True:
            try:
                result = next(it)
            except StopIteration:
                break

            scorer_name, session_id, ranking = result

            document_assessments_per_session_per_scorer[
                scorer_name][session_id] = ranking

        for scorer_name in document_assessments_per_session_per_scorer:
            # Convert the per-document assessment dicts into
            # (score, document_id) lists.
            for topic_id, object_assessments in \
                    document_assessments_per_session_per_scorer[
                        scorer_name].items():
                document_assessments_per_session_per_scorer[
                    scorer_name][topic_id] = [
                    (score, document_id)
                    for document_id, score in object_assessments.items()]

        # Write the runs.
        for scorer_name in document_assessments_per_session_per_scorer:
            run_out_path = os.path.join(
                out_base, '{0}.run'.format(scorer_name))

            with io.open(run_out_path, 'w', encoding='utf8') as f_run_out:
                trec_utils.write_run(
                    scorer_name,
                    document_assessments_per_session_per_scorer[scorer_name],
                    f_run_out)
Example #14
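# Trains a LanguageModel or VectorSpaceLanguageModel (Theano/Lasagne) on
# pre-extracted training and validation data.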
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--data',
                        type=argparse_utils.existing_file_path, required=True)
    parser.add_argument('--meta',
                        type=argparse_utils.existing_file_path, required=True)

    parser.add_argument('--type', choices=MODELS, required=True)

    parser.add_argument('--iterations',
                        type=argparse_utils.positive_int, default=1)

    parser.add_argument('--batch_size',
                        type=argparse_utils.positive_int, default=1024)

    parser.add_argument('--word_representation_size',
                        type=argparse_utils.positive_int, default=300)
    parser.add_argument('--representation_initializer',
                        type=argparse_utils.existing_file_path, default=None)

    # Specific to VectorSpaceLanguageModel.
    parser.add_argument('--entity_representation_size',
                        type=argparse_utils.positive_int, default=None)
    parser.add_argument('--num_negative_samples',
                        type=argparse_utils.positive_int, default=None)
    parser.add_argument('--one_hot_classes',
                        action='store_true',
                        default=False)

    parser.add_argument('--regularization_lambda',
                        type=argparse_utils.ratio, default=0.01)

    # When set, instance weights stored in the data set are ignored and
    # uniform weighting is used instead.
    parser.add_argument('--ignore_weights',
                        action='store_true',
                        default=False)

    parser.add_argument('--model_output', type=str, required=True)

    args = parser.parse_args()

    if args.entity_representation_size is None:
        args.entity_representation_size = args.word_representation_size

    args.type = MODELS[args.type]

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging_utils.log_module_info(theano, lasagne, np, scipy)

    # Load data.
    logging.info('Loading data from %s.', args.data)
    data_sets = np.load(args.data)

    if 'w_train' in data_sets and not args.ignore_weights:
        w_train = data_sets['w_train']
    else:
        logging.warning('No weights found in data set; '
                        'assuming uniform instance weighting.')

        w_train = np.ones(data_sets['x_train'].shape[0], dtype=np.float32)

    training_set = (data_sets['x_train'], data_sets['y_train'][()], w_train)
    validation_set = (data_sets['x_validate'], data_sets['y_validate'][()])

    logging.info('Training instances: %s (%s) %s (%s) %s (%s)',
                 training_set[0].shape, training_set[0].dtype,
                 training_set[1].shape, training_set[1].dtype,
                 training_set[2].shape, training_set[2].dtype)
    logging.info('Validation instances: %s (%s) %s (%s)',
                 validation_set[0].shape, validation_set[0].dtype,
                 validation_set[1].shape, validation_set[1].dtype)

    num_entities = training_set[1].shape[1]
    assert num_entities > 1

    if args.one_hot_classes:
        logging.info('Transforming y-values to one-hot values.')

        if not scipy.sparse.issparse(training_set[1]) or \
           not scipy.sparse.issparse(validation_set[1]):
            raise RuntimeError(
                'Argument --one_hot_classes expects sparse truth values.')

        y_train, (x_train, w_train) = sparse_to_one_hot_multiple(
            training_set[1], training_set[0], training_set[2])

        training_set = (x_train, y_train, w_train)

        y_validate, (x_validate,) = sparse_to_one_hot_multiple(
            validation_set[1], validation_set[0])

        validation_set = (x_validate, y_validate)

    logging.info('Loading meta-data from %s.', args.meta)
    with open(args.meta, 'rb') as f:
        # We do not load the remainder of the vocabulary.
        data_args, words, tokens = (pickle.load(f) for _ in range(3))

        vocabulary_size = len(words)

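    # Start from Glorot-uniform word representations; entries are overwritten
    # below with pre-trained embeddings when --representation_initializer is
    # given.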
    representations = lasagne.init.GlorotUniform().sample(
        (vocabulary_size, args.word_representation_size))

    if args.representation_initializer:
        # This way of creating the dictionary ignores duplicate words in
        # the representation initializer.
        representation_lookup = dict(
            embedding_utils.load_binary_representations(
                args.representation_initializer, tokens))

        representation_init_count = 0

        for word, meta in words.items():
            if word.lower() in representation_lookup:
                representations[meta.id] = \
                    representation_lookup[word.lower()]

                representation_init_count += 1

        logging.info('Initialized representations from '
                     'pre-learned collection for %d words (%.2f%%).',
                     representation_init_count,
                     (representation_init_count /
                      float(len(words))) * 100.0)

    # Allow GC to clear memory.
    del words
    del tokens

    model_options = {
        'batch_size': args.batch_size,
        'window_size': data_args.window_size,
        'representations_init': representations,
        'regularization_lambda': args.regularization_lambda,
        'training_set': training_set,
        'validation_set': validation_set,
    }

    if args.type == models.LanguageModel:
        model_options.update(
            output_layer_size=num_entities)
    elif args.type == models.VectorSpaceLanguageModel:
        entity_representations = lasagne.init.GlorotUniform().sample(
            (num_entities, args.entity_representation_size))

        model_options.update(
            entity_representations_init=entity_representations,
            num_negative_samples=args.num_negative_samples)

    # Construct neural net.
    model = args.type(**model_options)

    train(model, args.iterations, args.model_output,
          abort_threshold=1e-5,
          early_stopping=False,
          additional_args=[args])
Example #15
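# Ranks documents for one or more topic files with a trained NVSM model and
# writes a TREC run per topic file.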
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--num_workers',
                        type=argparse_utils.positive_int, default=16)

    parser.add_argument('--topics', nargs='+',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('model', type=argparse_utils.existing_file_path)

    parser.add_argument('--index', required=True)

    parser.add_argument('--linear', action='store_true', default=False)
    parser.add_argument('--self_information',
                        action='store_true',
                        default=False)
    parser.add_argument('--l2norm_phrase', action='store_true', default=False)

    parser.add_argument('--bias_coefficient',
                        type=argparse_utils.ratio,
                        default=0.0)

    parser.add_argument('--rerank_exact_matching_documents',
                        action='store_true',
                        default=False)

    parser.add_argument('--strict', action='store_true', default=False)

    parser.add_argument('--top_k', default=None)

    parser.add_argument('--num_queries',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('run_out')

    args = parser.parse_args()

    args.index = pyndri.Index(args.index)

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

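    # --top_k is either unset (defaults to 1000), the literal 'all', an
    # integer cutoff, or a whitespace-separated list of qrel files that
    # restricts the candidate documents per topic.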
    if not args.top_k:
        args.top_k = 1000
    elif args.top_k == 'all':
        args.top_k = \
            args.index.maximum_document() - args.index.document_base()
    elif args.top_k.isdigit():
        args.top_k = int(args.top_k)
    elif all(map(os.path.exists, args.top_k.split())):
        topics_and_documents = {}

        for qrel_path in args.top_k.split():
            with open(qrel_path, 'r') as f_qrel:
                for topic_id, judgments in trec_utils.parse_qrel(f_qrel):
                    if topic_id not in topics_and_documents:
                        topics_and_documents[topic_id] = set()

                    for doc_id, _ in judgments:
                        topics_and_documents[topic_id].add(doc_id)

        args.top_k = topics_and_documents
    else:
        raise RuntimeError(
            'Unable to interpret --top_k value: {}.'.format(args.top_k))

    logging.info('Loading dictionary.')
    dictionary = pyndri.extract_dictionary(args.index)

    logging.info('Loading model.')
    model_base, epoch_and_ext = args.model.rsplit('_', 1)
    epoch = int(epoch_and_ext.split('.')[0])

    if not os.path.exists('{}_meta'.format(model_base)):
        model_meta_base, batch_idx = model_base.rsplit('_', 1)
    else:
        model_meta_base = model_base

    kwargs = {
        'strict': args.strict,
    }

    if args.self_information:
        kwargs['self_information'] = True

    if args.linear:
        kwargs['bias_coefficient'] = args.bias_coefficient
        kwargs['nonlinearity'] = None

    if args.l2norm_phrase:
        kwargs['l2norm_phrase'] = True

    model = nvsm.load_model(
        nvsm.load_meta(model_meta_base),
        model_base, epoch, **kwargs)

    for topic_path in args.topics:
        run_out_path = '{}-{}'.format(
            args.run_out, os.path.basename(topic_path))

        if os.path.exists(run_out_path):
            logging.warning('Run for topics %s already exists (%s); skipping.',
                            topic_path, run_out_path)

            continue

        queries = list(pyndri.utils.parse_queries(
            args.index, dictionary, topic_path,
            strict=args.strict,
            num_queries=args.num_queries))

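        # Optionally restrict the candidates per query to the top-1000
        # TF-IDF (exact matching) documents retrieved from the index.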
        if args.rerank_exact_matching_documents:
            assert not isinstance(args.top_k, dict)

            topics_and_documents = {}

            query_env = pyndri.TFIDFQueryEnvironment(args.index)

            for topic_id, topic_token_ids in queries:
                topics_and_documents[topic_id] = set()

                query_str = ' '.join(
                    dictionary[term_id] for term_id in topic_token_ids
                    if term_id is not None)

                for int_doc_id, score in query_env.query(
                        query_str, results_requested=1000):
                    topics_and_documents[topic_id].add(
                        args.index.ext_document_id(int_doc_id))

            args.top_k = topics_and_documents

        run = trec_utils.OnlineTRECRun(
            'cuNVSM', rank_cutoff=(
                args.top_k if isinstance(args.top_k, int)
                else sys.maxsize))

        rank_fn = RankFn(
            args.num_workers,
            args=args, model=model)

        for idx, (topic_id, topic_data) in enumerate(rank_fn(queries)):
            if topic_data is None:
                continue

            logging.info('Query %s (%d/%d)', topic_id, idx + 1, len(queries))

            (topic_repr,
             topic_scores_and_documents) = topic_data

            run.add_ranking(topic_id, topic_scores_and_documents)

            del topic_scores_and_documents

        run.close_and_write(run_out_path, overwrite=False)

        logging.info('Run written to %s.', run_out_path)

    del rank_fn