def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('document_paths',
                        type=argparse_utils.existing_file_path, nargs='+')
    parser.add_argument('--encoding', type=str, default='latin1')

    parser.add_argument('--vocabulary_min_count', type=int, default=2)
    parser.add_argument('--vocabulary_min_word_size', type=int, default=2)
    parser.add_argument('--vocabulary_max_size', type=int, default=65536)

    parser.add_argument('--include_stopwords',
                        action='store_true', default=False)

    parser.add_argument('--num_workers',
                        type=argparse_utils.positive_int, default=8)

    parser.add_argument('--dictionary_out', required=True)
    parser.add_argument('--humanreadable_dictionary_out', default=None)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    ignore_words = set()

    if not args.include_stopwords:
        ignore_words.update(nltk_utils.get_stopwords())

    logging.info('Constructing vocabulary.')
    vocabulary = io_utils.construct_vocabulary(
        args.document_paths,
        num_workers=args.num_workers,
        min_count=args.vocabulary_min_count,
        min_word_size=args.vocabulary_min_word_size,
        max_vocab_size=args.vocabulary_max_size,
        ignore_tokens=ignore_words,
        encoding=args.encoding)

    logging.info('Pickling vocabulary.')
    with open(args.dictionary_out, 'wb') as f_out:
        pickle.dump(vocabulary, f_out, pickle.HIGHEST_PROTOCOL)

    if args.humanreadable_dictionary_out is not None:
        with open(args.humanreadable_dictionary_out, 'w',
                  encoding=args.encoding) as f_out:
            for token, token_meta in vocabulary.items():
                f_out.write('{} {}\n'.format(token, token_meta.id))
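
# Usage sketch (not part of the original script): the vocabulary pickled to
# --dictionary_out above can be read back with a single pickle.load; the
# helper name below is hypothetical.
def load_pickled_vocabulary(dictionary_path):
    # Reverses the pickle.dump(vocabulary, ...) call in main() above.
    with open(dictionary_path, 'rb') as f_in:
        return pickle.load(f_in)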
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--index',
                        type=argparse_utils.existing_directory_path,
                        required=True)
    parser.add_argument('--model',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('--vocabulary_list',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    args.index = pyndri.Index(args.index)

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging.info('Loading dictionary.')
    dictionary = pyndri.extract_dictionary(args.index)

    logging.info('Loading model.')
    model_base, epoch_and_ext = args.model.rsplit('_', 1)
    epoch = int(epoch_and_ext.split('.')[0])

    if not os.path.exists('{}_meta'.format(model_base)):
        model_meta_base, batch_idx = model_base.rsplit('_', 1)
    else:
        model_meta_base = model_base

    model = nvsm.load_model(
        nvsm.load_meta(model_meta_base),
        model_base, epoch)

    with open(args.vocabulary_list, 'w') as f_vocabulary_list:
        for index_term_id in model.term_mapping:
            f_vocabulary_list.write(dictionary[index_term_id])
            f_vocabulary_list.write('\n')
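
# Illustrative helper (assumption, not in the original): --vocabulary_list is
# written above as one dictionary term per line, so it can be read back into
# an ordered list of terms as follows.
def read_vocabulary_list(vocabulary_list_path):
    with open(vocabulary_list_path, 'r') as f_vocabulary_list:
        return [line.rstrip('\n') for line in f_vocabulary_list]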
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--shard_size',
                        type=argparse_utils.positive_int,
                        default=(1 << 14))

    parser.add_argument('reviews_file',
                        type=argparse_utils.existing_file_path)
    parser.add_argument('--product_list',
                        type=argparse_utils.existing_file_path,
                        default=None)

    parser.add_argument('--trectext_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    if args.product_list:
        with open(args.product_list, 'r') as f_product_list:
            product_list = set(
                product_id.strip() for product_id in f_product_list)

        logging.info('Only considering white list of %d products.',
                     len(product_list))
    else:
        product_list = None

    writer = trec_utils.ShardedTRECTextWriter(
        args.trectext_out, args.shard_size)

    with gzip.open(args.reviews_file, 'r') as f_reviews:
        for i, raw_line in enumerate(f_reviews):
            raw_line = raw_line.decode('utf8')

            review = json.loads(raw_line)

            product_id = review['asin']

            if product_list and product_id not in product_list:
                continue

            document_id = '{product_id}_{reviewer_id}_{review_time}'.format(
                product_id=product_id,
                reviewer_id=review['reviewerID'],
                review_time=review['unixReviewTime'])

            review_summary = review['summary']
            review_text = review['reviewText']

            product_document = '{0} \n{1}'.format(
                review_summary, review_text)
            product_document = io_utils.tokenize_text(product_document)

            document = ' '.join(product_document)

            writer.write_document(document_id, document)

            if (i + 1) % 1000 == 0:
                logging.info('Processed %d reviews.', i + 1)

    writer.close()

    logging.info('All done!')
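
# Illustrative helper (assumption, not in the original): document identifiers
# are composed above as '<asin>_<reviewerID>_<unixReviewTime>'. As long as the
# ASIN contains no underscores, they can be decomposed again like this.
def parse_review_document_id(document_id):
    product_id, rest = document_id.split('_', 1)
    reviewer_id, review_time = rest.rsplit('_', 1)

    return product_id, reviewer_id, int(review_time)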
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--data',
                        type=argparse_utils.existing_file_path,
                        required=True)
    parser.add_argument('--meta',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('--type', choices=MODELS, required=True)

    parser.add_argument('--iterations',
                        type=argparse_utils.positive_int, default=1)
    parser.add_argument('--batch_size',
                        type=argparse_utils.positive_int, default=1024)

    parser.add_argument('--word_representation_size',
                        type=argparse_utils.positive_int, default=300)
    parser.add_argument('--representation_initializer',
                        type=argparse_utils.existing_file_path, default=None)

    # Specific to VectorSpaceLanguageModel.
    parser.add_argument('--entity_representation_size',
                        type=argparse_utils.positive_int, default=None)
    parser.add_argument('--num_negative_samples',
                        type=argparse_utils.positive_int, default=None)
    parser.add_argument('--one_hot_classes',
                        action='store_true', default=False)

    parser.add_argument('--regularization_lambda',
                        type=argparse_utils.ratio, default=0.01)

    # Ignore pre-computed instance weights and fall back to uniform weighting.
    parser.add_argument('--ignore_weights',
                        action='store_true', default=False)

    parser.add_argument('--model_output', type=str, required=True)

    args = parser.parse_args()

    if args.entity_representation_size is None:
        args.entity_representation_size = args.word_representation_size

    args.type = MODELS[args.type]

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging_utils.log_module_info(theano, lasagne, np, scipy)

    # Load data.
    logging.info('Loading data from %s.', args.data)
    data_sets = np.load(args.data)

    if 'w_train' in data_sets and not args.ignore_weights:
        w_train = data_sets['w_train']
    else:
        logging.warning('No weights found in data set; '
                        'assuming uniform instance weighting.')

        w_train = np.ones(data_sets['x_train'].shape[0], dtype=np.float32)

    training_set = (data_sets['x_train'], data_sets['y_train'][()], w_train)
    validation_set = (data_sets['x_validate'], data_sets['y_validate'][()])

    logging.info('Training instances: %s (%s) %s (%s) %s (%s)',
                 training_set[0].shape, training_set[0].dtype,
                 training_set[1].shape, training_set[1].dtype,
                 training_set[2].shape, training_set[2].dtype)
    logging.info('Validation instances: %s (%s) %s (%s)',
                 validation_set[0].shape, validation_set[0].dtype,
                 validation_set[1].shape, validation_set[1].dtype)

    num_entities = training_set[1].shape[1]
    assert num_entities > 1

    if args.one_hot_classes:
        logging.info('Transforming y-values to one-hot values.')

        if not scipy.sparse.issparse(training_set[1]) or \
                not scipy.sparse.issparse(validation_set[1]):
            raise RuntimeError(
                'Argument --one_hot_classes expects sparse truth values.')

        y_train, (x_train, w_train) = sparse_to_one_hot_multiple(
            training_set[1], training_set[0], training_set[2])
        training_set = (x_train, y_train, w_train)

        y_validate, (x_validate,) = sparse_to_one_hot_multiple(
            validation_set[1], validation_set[0])
        validation_set = (x_validate, y_validate)

    logging.info('Loading meta-data from %s.', args.meta)
    with open(args.meta, 'rb') as f:
        # We do not load the remainder of the vocabulary.
        data_args, words, tokens = (pickle.load(f) for _ in range(3))

    vocabulary_size = len(words)

    representations = lasagne.init.GlorotUniform().sample(
        (vocabulary_size, args.word_representation_size))

    if args.representation_initializer:
        # This way of creating the dictionary ignores duplicate words in
        # the representation initializer.
        representation_lookup = dict(
            embedding_utils.load_binary_representations(
                args.representation_initializer, tokens))

        representation_init_count = 0

        for word, meta in words.items():
            if word.lower() in representation_lookup:
                representations[meta.id] = \
                    representation_lookup[word.lower()]

                representation_init_count += 1

        logging.info(
            'Initialized representations from '
            'pre-learned collection for %d words (%.2f%%).',
            representation_init_count,
            (representation_init_count / float(len(words))) * 100.0)

    # Allow GC to clear memory.
    del words
    del tokens

    model_options = {
        'batch_size': args.batch_size,
        'window_size': data_args.window_size,
        'representations_init': representations,
        'regularization_lambda': args.regularization_lambda,
        'training_set': training_set,
        'validation_set': validation_set,
    }

    if args.type == models.LanguageModel:
        model_options.update(output_layer_size=num_entities)
    elif args.type == models.VectorSpaceLanguageModel:
        entity_representations = lasagne.init.GlorotUniform().sample(
            (num_entities, args.entity_representation_size))

        model_options.update(
            entity_representations_init=entity_representations,
            num_negative_samples=args.num_negative_samples)

    # Construct neural net.
    model = args.type(**model_options)

    train(model, args.iterations, args.model_output,
          abort_threshold=1e-5,
          early_stopping=False,
          additional_args=[args])
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--judgments',
                        type=argparse_utils.existing_file_path,
                        required=True)
    parser.add_argument('--session_topic_map',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('--track_year',
                        type=int,
                        choices=domain.INPUT_FORMATS,
                        required=True)

    parser.add_argument('--qrel_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    topic_id_to_subtopics = collections.defaultdict(lambda: set([0]))
    topic_id_to_session_ids = collections.defaultdict(set)

    with open(args.session_topic_map, 'r') as f_mapping:
        for line in f_mapping:
            line = line.strip()

            if not line:
                continue

            data = line.split()[:3]

            if len(data) == 3:
                session_id, topic_id, subtopic_id = data
            elif len(data) == 2:
                session_id, topic_id = data
                subtopic_id = 0
            else:
                logging.warning('Unable to parse %s', line)

                continue

            topic_id_to_subtopics[topic_id].add(subtopic_id)
            topic_id_to_session_ids[topic_id].add(session_id)

    with open(args.judgments, 'r') as f_judgments:
        qrel = sesh.parse_qrel(f_judgments, args.track_year)

    with open(args.qrel_out, 'w') as f_out:
        for topic_id, session_ids in topic_id_to_session_ids.items():
            relevant_items = qrel[topic_id]

            for session_id in session_ids:
                for document_id, relevance in relevant_items.items():
                    f_out.write(
                        '{session_id} 0 {document_id} {relevance}\n'.format(
                            session_id=session_id,
                            document_id=document_id,
                            relevance=relevance))
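
# Minimal sketch (assumption, not in the original): every line written to
# --qrel_out above follows the four-column layout
# 'session_id 0 document_id relevance', so it can be parsed back like this.
# The helper assumes integer relevance grades.
def parse_session_qrel_line(line):
    session_id, _, document_id, relevance = line.split()

    return session_id, document_id, int(relevance)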
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--seed',
                        type=argparse_utils.positive_int, required=True)

    parser.add_argument('document_paths',
                        type=argparse_utils.existing_file_path, nargs='+')
    parser.add_argument('--encoding', type=str, default='latin1')

    parser.add_argument('--assoc_path',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('--num_workers', type=int, default=1)

    parser.add_argument('--vocabulary_min_count', type=int, default=2)
    parser.add_argument('--vocabulary_min_word_size', type=int, default=2)
    parser.add_argument('--vocabulary_max_size', type=int, default=65536)

    parser.add_argument('--remove_stopwords', type=str, default='nltk')

    parser.add_argument('--validation_set_ratio',
                        type=argparse_utils.ratio, default=0.01)

    parser.add_argument('--window_size', type=int, default=10)

    parser.add_argument('--overlapping', action='store_true', default=False)
    parser.add_argument('--stride',
                        type=argparse_utils.positive_int, default=None)

    parser.add_argument('--resample', action='store_true', default=False)

    parser.add_argument('--no_shuffle', action='store_true', default=False)
    parser.add_argument('--no_padding', action='store_true', default=False)
    parser.add_argument('--no_instance_weights',
                        action='store_true', default=False)

    parser.add_argument('--meta_output',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)
    parser.add_argument('--data_output',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    # Seed RNG.
    np.random.seed(args.seed)

    logging_utils.log_module_info(np, scipy, sklearn)

    ignore_words = ['<doc>', '</doc>', '<docno>', '<text>', '</text>']

    if args.remove_stopwords == 'none':
        logging.info('Stopwords will be included in instances.')
    elif args.remove_stopwords == 'nltk':
        logging.info('Using NLTK stopword list.')

        ignore_words.extend(stopwords.words('english'))
        ignore_words.extend(stopwords.words('dutch'))
    elif os.path.exists(args.remove_stopwords):
        logging.info('Using custom stopword list (%s).',
                     args.remove_stopwords)

        with open(args.remove_stopwords, 'r') as f:
            ignore_words.extend(
                filter(len, map(str.strip, map(str.lower, f.readlines()))))
    else:
        logging.error('Invalid stopword removal strategy "%s".',
                      args.remove_stopwords)

        return -1

    logging.info('Ignoring words: %s.', ignore_words)

    # TODO(cvangysel): maybe switch by encapsulated call?
    words, tokens = io_utils.extract_vocabulary(
        args.document_paths,
        min_count=args.vocabulary_min_count,
        max_vocab_size=args.vocabulary_max_size,
        min_word_size=args.vocabulary_min_word_size,
        num_workers=args.num_workers,
        ignore_tokens=ignore_words,
        encoding=args.encoding)

    logging.info('Loading document identifiers.')
    reader = trec_utils.TRECTextReader(args.document_paths,
                                       encoding=args.encoding)
    document_ids = reader.iter_document_ids(num_workers=args.num_workers)

    with open(args.assoc_path, 'r') as f_assocs:
        assocs = trec_utils.EntityDocumentAssociations(
            f_assocs, document_ids=document_ids)

    logging.info('Found %d unique entities.', len(assocs.entities))
    logging.info('Document-per-entity stats: %s',
                 list(map(lambda kv: (kv[0], len(kv[1])),
                          sorted(assocs.documents_per_entity.items()))))
    logging.info(
        'Entity-per-document association stats: %s',
        collections.Counter(
            map(len, assocs.entities_per_document.values())).items())

    # Estimate the position-in-authorship distribution.
    num_associations_distribution = np.zeros(
        assocs.max_entities_per_document, dtype=np.int32)

    for association_length in (
            len(associated_entities)
            for associated_entities in
            assocs.entities_per_document.values()):
        num_associations_distribution[association_length - 1] += 1

    logging.info('Number of associations distribution: %s',
                 num_associations_distribution)

    position_in_associations_distribution = np.cumsum(
        num_associations_distribution[::-1])[::-1]

    logging.info('Position in associations distribution: %s',
                 position_in_associations_distribution)

    instances_and_labels = []

    num_documents = 0
    num_non_associated_documents = 0

    documents_per_entity = collections.defaultdict(int)
    instances_per_entity = collections.defaultdict(int)
    instances_per_document = {}

    global_label_distribution = collections.defaultdict(float)

    if args.overlapping and args.stride is None:
        args.stride = 1
    elif args.overlapping and args.stride is not None:
        logging.error('Option --overlapping passed '
                      'concurrently with --stride.')

        return -1
    elif args.stride is None:
        logging.info('Defaulting stride to window size.')

        args.stride = args.window_size

    logging.info('Generating instances with stride %d.', args.stride)

    result_q = multiprocessing.Queue()

    pool = multiprocessing.Pool(
        args.num_workers,
        initializer=prepare_initializer,
        initargs=[result_q, args,
                  assocs.entities,
                  assocs.entities_per_document,
                  position_in_associations_distribution,
                  tokens, words,
                  args.encoding])

    max_document_length = 0

    worker_result = pool.map_async(
        prepare_worker, args.document_paths)

    # We will not submit any more tasks to the pool.
    pool.close()

    it = multiprocessing_utils.QueueIterator(
        pool, worker_result, result_q)

    def _extract_key(obj):
        return tuple(sorted(obj))

    num_labels = 0
    num_instances = 0

    instances_per_label = collections.defaultdict(list)

    while True:
        try:
            result = next(it)
        except StopIteration:
            break

        num_documents += 1

        if result:
            document_id, \
                document_instances_and_labels, \
                document_label = result

            assert document_id not in instances_per_document

            num_instances_for_doc = len(document_instances_and_labels)

            instances_per_document[document_id] = num_instances_for_doc
            max_document_length = max(max_document_length,
                                      num_instances_for_doc)

            # For statistical purposes.
            for entity_id in assocs.entities_per_document[document_id]:
                documents_per_entity[entity_id] += 1

            # For statistical purposes.
            for entity_id in assocs.entities_per_document[document_id]:
                num_labels += num_instances_for_doc

            # Aggregate.
            instances_per_label[
                _extract_key(document_label.keys())].extend(
                document_instances_and_labels)

            num_instances += len(document_instances_and_labels)

            # Some more accounting.
            for entity_id, mass in document_label.items():
                global_label_distribution[entity_id] += \
                    num_instances_for_doc * mass
        else:
            num_non_associated_documents += 1

    # assert result_q.empty()

    logging.info('Global unnormalized distribution: %s',
                 global_label_distribution)

    if args.resample:
        num_entities = len(instances_per_label)
        avg_instances_per_doc = (
            float(num_instances) / float(num_entities))

        min_instances_per_entity = int(avg_instances_per_doc)

        logging.info('Setting number of sampled instances '
                     'to %d per document.', avg_instances_per_doc)

    for label in list(instances_per_label.keys()):
        label_instances_and_labels = instances_per_label.pop(label)

        if not label_instances_and_labels:
            logging.warning('Label %s has no instances; skipping.', label)

            continue

        label_num_instances = len(label_instances_and_labels)

        if args.resample:
            assert min_instances_per_entity > 0

            label_instances_and_labels_sample = []

            while (len(label_instances_and_labels_sample) <
                   min_instances_per_entity):
                label_instances_and_labels_sample.append(
                    label_instances_and_labels[
                        np.random.randint(label_num_instances)])
        else:
            label_instances_and_labels_sample = \
                label_instances_and_labels

        for entity_id in label:
            instances_per_entity[entity_id] += len(
                label_instances_and_labels_sample)

        instances_and_labels.extend(label_instances_and_labels_sample)

    logging.info(
        'Documents-per-indexed-entity stats (mean=%.2f, std_dev=%.2f): %s',
        np.mean(list(documents_per_entity.values())),
        np.std(list(documents_per_entity.values())),
        sorted(documents_per_entity.items()))

    logging.info(
        'Instances-per-indexed-entity stats '
        '(mean=%.2f, std_dev=%.2f, min=%d, max=%d): %s',
        np.mean(list(instances_per_entity.values())),
        np.std(list(instances_per_entity.values())),
        np.min(list(instances_per_entity.values())),
        np.max(list(instances_per_entity.values())),
        sorted(instances_per_entity.items()))

    logging.info(
        'Instances-per-document stats (mean=%.2f, std_dev=%.2f, max=%d).',
        np.mean(list(instances_per_document.values())),
        np.std(list(instances_per_document.values())),
        max_document_length)

    logging.info('Observed %d documents of which %d (ratio=%.2f) '
                 'are not associated with an entity.',
                 num_documents, num_non_associated_documents,
                 (float(num_non_associated_documents) / num_documents))

    training_instances_and_labels, validation_instances_and_labels = \
        sklearn.cross_validation.train_test_split(
            instances_and_labels,
            test_size=args.validation_set_ratio)

    num_training_instances = len(training_instances_and_labels)
    num_validation_instances = len(validation_instances_and_labels)

    num_instances = num_training_instances + num_validation_instances

    logging.info(
        'Processed %d instances; training=%d, validation=%d (ratio=%.2f).',
        num_instances, num_training_instances, num_validation_instances,
        (float(num_validation_instances) /
         (num_training_instances + num_validation_instances)))

    # Figure out if there are any entities with no instances, and
    # do not consider them during training.
    entity_indices = {}
    entity_indices_inv = {}

    for entity_id, num_instances in instances_per_entity.items():
        if not num_instances:
            continue

        entity_index = len(entity_indices)

        entity_indices[entity_id] = entity_index
        entity_indices_inv[entity_index] = entity_id

    logging.info('Retained %d entities after instance creation.',
                 len(entity_indices))

    directories = list(map(os.path.dirname,
                           [args.meta_output, args.data_output]))

    # Create the output directories if they do not exist yet.
    for directory in directories:
        if directory and not os.path.exists(directory):
            os.makedirs(directory)

    # Dump vocabulary.
    with open(args.meta_output, 'wb') as f:
        for obj in (args, words, tokens,
                    entity_indices_inv, assocs.documents_per_entity):
            pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)

    logging.info('Saved vocabulary to "%s".', args.meta_output)

    instance_dtype = np.min_scalar_type(len(words) - 1)
    logging.info('Instance elements will be stored using %s.',
                 instance_dtype)

    data = {}

    x_train, y_train = instances_and_labels_to_arrays(
        training_instances_and_labels,
        args.window_size,
        entity_indices,
        instance_dtype,
        not args.no_shuffle)

    data['x_train'] = x_train
    data['y_train'] = y_train

    if not args.no_instance_weights:
        w_train = np.fromiter(
            (float(max_document_length) / instances_per_document[doc_id]
             for doc_id, _, _ in training_instances_and_labels),
            np.float32, len(training_instances_and_labels))

        assert w_train.shape == (x_train.shape[0],)

        data['w_train'] = w_train

    x_validate, y_validate = instances_and_labels_to_arrays(
        validation_instances_and_labels,
        args.window_size,
        entity_indices,
        instance_dtype,
        not args.no_shuffle)

    data['x_validate'] = x_validate
    data['y_validate'] = y_validate

    with open(args.data_output, 'wb') as f:
        np.savez(f, **data)

    logging.info('Saved data sets.')

    logging.info('Entity-per-document association stats: %s',
                 collections.Counter(
                     map(len,
                         assocs.entities_per_document.values())).items())
    logging.info('Documents-per-entity stats: %s',
                 collections.Counter(
                     map(len,
                         assocs.documents_per_entity.values())).items())

    logging.info('Done.')
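
# Usage sketch (not part of the original script): --meta_output holds five
# pickled objects and --data_output is a NumPy archive with the x/y/w arrays
# written above; the training script in this code base reads them in the same
# way. The helper name below is hypothetical.
def load_prepared_data(meta_output_path, data_output_path):
    with open(meta_output_path, 'rb') as f:
        (data_args, words, tokens,
         entity_indices_inv, documents_per_entity) = (
            pickle.load(f) for _ in range(5))

    # Contains x_train, y_train, x_validate, y_validate and, unless
    # --no_instance_weights was passed, w_train.
    data_sets = np.load(data_output_path)

    return (data_args, words, tokens,
            entity_indices_inv, documents_per_entity), data_sets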
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--shard_size',
                        type=argparse_utils.positive_int,
                        default=(1 << 14))

    parser.add_argument('meta_file',
                        type=argparse_utils.existing_file_path)
    parser.add_argument('--product_list',
                        type=argparse_utils.existing_file_path,
                        default=None)

    parser.add_argument('--trectext_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    if args.product_list:
        with open(args.product_list, 'r') as f_product_list:
            product_list = set(
                product_id.strip() for product_id in f_product_list)

        logging.info('Only considering white list of %d products.',
                     len(product_list))
    else:
        product_list = None

    writer = trec_utils.ShardedTRECTextWriter(
        args.trectext_out, args.shard_size)

    department = ' '.join(
        os.path.basename(args.meta_file).split('.')[0]
        .split('_')[1:]).replace(' and ', ' & ')

    logging.info('Department: %s', department)

    with gzip.open(args.meta_file, 'r') as f_meta:
        for i, raw_line in enumerate(f_meta):
            raw_line = raw_line.decode('utf8')

            product = ast.literal_eval(raw_line)

            product_id = product['asin']

            if product_list and product_id not in product_list:
                continue

            if 'description' in product and 'title' in product:
                product_title = product['title']
                product_description = \
                    io_utils.strip_html(product['description'])

                product_document = '{0} \n{1}'.format(
                    product_title, product_description)
                product_document = io_utils.tokenize_text(product_document)

                logging.debug('Product %s has description of %d tokens.',
                              product_id, len(product_document))

                writer.write_document(
                    product_id, ' '.join(product_document))
            else:
                logging.debug(
                    'Filtering product %s due to missing description.',
                    product_id)

                continue

            if (i + 1) % 1000 == 0:
                logging.info('Processed %d products.', i + 1)

    writer.close()

    logging.info('All done!')
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('model')
    parser.add_argument('index',
                        type=argparse_utils.existing_directory_path)

    parser.add_argument('--limit',
                        type=argparse_utils.positive_int, default=None)

    parser.add_argument('--object_classification',
                        type=argparse_utils.existing_file_path,
                        nargs='+', default=None)
    parser.add_argument('--filter_unclassified',
                        action='store_true', default=False)

    parser.add_argument('--l2_normalize',
                        action='store_true', default=False)

    parser.add_argument('--mode',
                        choices=('tsne', 'embedding_projector'),
                        default='tsne')

    parser.add_argument('--legend', action='store_true', default=False)
    parser.add_argument('--tick_labels', action='store_true', default=False)
    parser.add_argument('--edges', action='store_true', default=False)
    parser.add_argument('--border', action='store_true', default=False)

    parser.add_argument('--plot_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    # Set matplotlib style.
    plt.style.use('bmh')

    logging.info('Loading index.')
    index = pyndri.Index(args.index)

    logging.info('Loading cuNVSM model.')
    model_base, epoch_and_ext = args.model.rsplit('_', 1)
    epoch = int(epoch_and_ext.split('.')[0])

    if not os.path.exists('{}_meta'.format(model_base)):
        model_meta_base, batch_idx = model_base.rsplit('_', 1)
    else:
        model_meta_base = model_base

    model = nvsm.load_model(
        nvsm.load_meta(model_meta_base),
        model_base, epoch,
        only_object_embeddings=True)

    raw_object_representations = np.copy(model.object_representations)

    if args.limit:
        raw_object_representations = \
            raw_object_representations[:args.limit, :]

    for object_classification in args.object_classification:
        root, ext = os.path.splitext(args.plot_out)
        plot_out = '{}-{}.{}'.format(
            root,
            os.path.basename(object_classification),
            ext.lstrip('.'))

        if object_classification and args.filter_unclassified:
            logging.info('Filtering unclassified.')

            with open(object_classification, 'r') as f_objects:
                object_ids = [line.strip().split()[0] for line in f_objects]

            indices = sorted(
                model.inv_object_mapping[idx]
                for _, idx in index.document_ids(object_ids)
                if idx in model.inv_object_mapping)

            logging.info('Considering %d out of %d representations.',
                         len(indices), len(object_ids))

            translation_table = {idx: i for i, idx in enumerate(indices)}

            object_representations = raw_object_representations[indices]

            assert object_representations.shape[0] == \
                len(translation_table)
        else:
            translation_table = None

            raise NotImplementedError()

        logging.info('Loading object clusters.')

        cluster_id_to_product_ids = {}

        if object_classification:
            with open(object_classification, 'r') as f_objects:
                for line in f_objects:
                    object_id, cluster_id = line.strip().split()

                    if cluster_id not in cluster_id_to_product_ids:
                        cluster_id_to_product_ids[cluster_id] = set()

                    cluster_id_to_product_ids[cluster_id].add(object_id)

            for cluster_id in list(cluster_id_to_product_ids.keys()):
                object_ids = list(cluster_id_to_product_ids[cluster_id])

                cluster_id_to_product_ids[cluster_id] = set(
                    (model.inv_object_mapping[int_object_id]
                     if translation_table is None
                     else translation_table[
                         model.inv_object_mapping[int_object_id]])
                    for ext_object_id, int_object_id in
                    index.document_ids(object_ids)
                    if int_object_id in model.inv_object_mapping and
                    (args.limit is None or
                     (model.inv_object_mapping[int_object_id] <
                      args.limit)))
        else:
            raise NotImplementedError()

        assert len(cluster_id_to_product_ids) < len(MARKERS)

        if args.l2_normalize:
            logging.info('L2-normalizing representations.')

            object_representations /= np.linalg.norm(
                object_representations, axis=1, keepdims=True)

        if args.mode == 'tsne':
            logging.info('Running t-SNE.')

            twodim_object_representations = \
                TSNE(n_components=2, init='pca', random_state=0).\
                fit_transform(object_representations)

            logging.info('Plotting %s.',
                         twodim_object_representations.shape)

            colors = cm.rainbow(
                np.linspace(0, 1, len(cluster_id_to_product_ids)))

            for idx, cluster_id in enumerate(
                    sorted(cluster_id_to_product_ids.keys(),
                           key=lambda cluster_id: len(
                               cluster_id_to_product_ids[cluster_id]),
                           reverse=True)):
                row_ids = list(cluster_id_to_product_ids[cluster_id])

                plt.scatter(
                    twodim_object_representations[row_ids, 0],
                    twodim_object_representations[row_ids, 1],
                    marker=MARKERS[idx],
                    edgecolors='grey' if args.edges else None,
                    cmap=plt.cm.Spectral,
                    color=colors[idx],
                    alpha=0.3,
                    label=pylatex.utils.escape_latex(cluster_id))

            plt.grid()
            plt.tight_layout()

            if args.legend:
                plt.legend(bbox_to_anchor=(0, -0.15, 1, 0),
                           loc=2, ncol=2, mode='expand',
                           borderaxespad=0)

            if not args.tick_labels:
                plt.gca().get_xaxis().set_visible(False)
                plt.gca().get_yaxis().set_visible(False)

            if not args.border:
                # plt.gcf().patch.set_visible(False)
                plt.gca().axis('off')

            logging.info('Writing %s.', plot_out)

            plt.savefig(plot_out,
                        bbox_inches='tight',
                        transparent=True,
                        pad_inches=0,
                        dpi=200)
        elif args.mode == 'embedding_projector':
            logging.info(
                'Dumping to TensorFlow embedding projector format.')

            with open('{}_vectors.tsv'.format(plot_out), 'w') as f_vectors, \
                    open('{}_meta.tsv'.format(plot_out), 'w') as f_meta:
                f_meta.write('document_id\tclass\n')

                def write_rowids(row_ids, cluster_id):
                    for row_id in row_ids:
                        f_vectors.write(
                            '{}\n'.format('\t'.join(
                                '{:.5f}'.format(x)
                                for x in object_representations[row_id])))

                        f_meta.write('{}\t{}\n'.format(
                            index.ext_document_id(
                                model.object_mapping[row_id]),
                            cluster_id))

                for cluster_id in cluster_id_to_product_ids.keys():
                    row_ids = list(cluster_id_to_product_ids[cluster_id])
                    write_rowids(row_ids, cluster_id)

    logging.info('All done!')
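
# Usage sketch (not part of the original script): the two TSV files written in
# embedding_projector mode above can be loaded back for inspection as follows;
# the function name is hypothetical.
def load_embedding_projector_dump(plot_out):
    vectors = np.loadtxt('{}_vectors.tsv'.format(plot_out), delimiter='\t')

    with open('{}_meta.tsv'.format(plot_out), 'r') as f_meta:
        next(f_meta)  # Skip the 'document_id<TAB>class' header.

        metadata = [line.rstrip('\n').split('\t') for line in f_meta]

    return vectors, metadata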
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--meta',
                        type=argparse_utils.existing_file_path,
                        required=True)
    parser.add_argument('--model',
                        type=argparse_utils.existing_file_path,
                        required=True)

    parser.add_argument('--topics',
                        type=argparse_utils.existing_file_path,
                        nargs='+')

    parser.add_argument('--top',
                        type=argparse_utils.positive_int, default=None)

    parser.add_argument('--run_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    with open(args.model, 'rb') as f:
        # Load model arguments and learned mapping.
        model_args, predict_fn = (pickle.load(f) for _ in range(2))

        # Load word representations.
        word_representations = pickle.load(f)

        try:
            entity_representations = pickle.load(f)
        except EOFError:
            entity_representations = None

    with open(args.meta, 'rb') as f:
        (data_args, words, tokens,
         entity_indices_inv, entity_assocs) = (
            pickle.load(f) for _ in range(5))

    # Parse topic files.
    topic_f = list(map(lambda filename: open(filename, 'r'), args.topics))
    topics = trec_utils.parse_topics(topic_f)

    [f_.close() for f_ in topic_f]

    model_name = os.path.basename(args.model)

    # Entity profiling.
    topics_per_entity = collections.defaultdict(list)

    # Entity finding.
    entities_per_topic = collections.defaultdict(list)

    def ranker_callback(topic_id, top_ranked_indices, top_ranked_values):
        for rank, (entity_internal_id, relevance) in enumerate(
                zip(top_ranked_indices, top_ranked_values)):
            entity_id = entity_indices_inv[entity_internal_id]

            # Entity profiling.
            topics_per_entity[entity_id].append((relevance, topic_id))

            # Entity finding.
            entities_per_topic[topic_id].append((relevance, entity_id))

    with open('{0}_debug'.format(args.run_out), 'w') as f_debug_out:
        if model_args.type == models.LanguageModel:
            result_callback = LogLinearCallback(
                args, model_args, tokens,
                f_debug_out,
                ranker_callback)
        elif model_args.type == models.VectorSpaceLanguageModel:
            result_callback = VectorSpaceCallback(
                entity_representations,
                args, model_args, tokens,
                f_debug_out,
                ranker_callback)

        batcher = inference.create(
            predict_fn, word_representations,
            model_args.batch_size, data_args.window_size, len(words),
            result_callback)

        logging.info('Batching queries using %s.', batcher)

        for q_id, (topic_id, terms) in enumerate(topics.items()):
            if topic_id not in topics:
                logging.error('Topic "%s" not found in topic list.',
                              topic_id)

                continue

            # Do not replace numeric tokens in queries.
            query_terms = trec_utils.parse_query(terms)

            query_tokens = []

            logging.debug('Query (%d/%d) %s: %s (%s)',
                          q_id + 1, len(topics), topic_id,
                          query_terms, terms)

            for term in query_terms:
                if term not in words:
                    logging.debug('Term "%s" is OOV.', term)

                    continue

                term_token = words[term].id
                query_tokens.append(term_token)

            if not query_tokens:
                logging.warning('Skipping query with terms "%s".', terms)

                continue

            batcher.submit(query_tokens, topic_id=topic_id)

        batcher.process()

    # Entity profiling.
    with io.open('{0}_ep'.format(args.run_out),
                 'w', encoding='utf8') as out_ep_run:
        trec_utils.write_run(model_name, topics_per_entity, out_ep_run)

    # Entity finding.
    with io.open('{0}_ef'.format(args.run_out),
                 'w', encoding='utf8') as out_ef_run:
        trec_utils.write_run(model_name, entities_per_topic, out_ef_run)

    logging.info('Saved run to %s.', args.run_out)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--shard_size',
                        type=argparse_utils.positive_int,
                        default=1000000)

    parser.add_argument('sgm',
                        type=argparse_utils.existing_file_path, nargs='+')

    parser.add_argument('--top_k_topics',
                        type=argparse_utils.positive_int, default=20)

    parser.add_argument('--trectext_out_prefix', type=str, required=True)
    parser.add_argument('--document_classification_out',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    parser = ReutersParser()

    for sgm_path in args.sgm:
        logging.info('Parsing %s.', sgm_path)

        with open(sgm_path, 'r', encoding='ISO-8859-1') as f_sgm:
            parser.feed(f_sgm.read())

    logging.info('Parsed %d documents.', len(parser.documents))

    topic_histogram = collections.Counter(
        topic
        for document in parser.documents
        for topic in document['tags']['topics'])

    top_topics = set(
        sorted(topic_histogram.keys(),
               key=lambda topic: topic_histogram[topic])
        [-args.top_k_topics:])

    logging.info('Top topics: %s', top_topics)

    writer = trec_utils.ShardedTRECTextWriter(args.trectext_out_prefix,
                                              shard_size=args.shard_size,
                                              encoding='latin1')

    with open(args.document_classification_out, 'w') as \
            f_document_classification_out:
        for document in parser.documents:
            doc_id = document['doc_id']
            doc_text = '\n'.join([
                document['texts'].get('title', ''),
                document['texts'].get('dateline', ''),
                document['texts'].get('body', ''),
            ])

            writer.write_document(doc_id, doc_text)

            doc_topics = {
                topic for topic in document['tags']['topics']
                if topic in top_topics
            }

            if doc_topics:
                most_specific_doc_topic = min(
                    doc_topics,
                    key=lambda topic: topic_histogram[topic])

                f_document_classification_out.write(doc_id)
                f_document_classification_out.write(' ')
                f_document_classification_out.write(most_specific_doc_topic)
                f_document_classification_out.write('\n')

    writer.close()
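
# Illustrative helper (assumption, not in the original): the file written to
# --document_classification_out above contains one 'doc_id topic' pair per
# line, which matches the --object_classification format consumed elsewhere
# in this code base.
def read_document_classification(path):
    with open(path, 'r') as f_classification:
        return dict(line.split() for line in f_classification
                    if line.strip())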
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('index', type=argparse_utils.existing_file_path)

    parser.add_argument('session_file',
                        type=argparse_utils.existing_file_path)

    parser.add_argument('--num_workers', type=int, default=1)

    parser.add_argument('--harvested_links_file',
                        type=argparse_utils.existing_file_path,
                        default=None)

    parser.add_argument('--qrel',
                        type=argparse_utils.existing_file_path,
                        nargs='*')

    parser.add_argument('--configuration', type=str, nargs='+')

    parser.add_argument('--top_sessions',
                        type=argparse_utils.positive_int,
                        default=None)

    parser.add_argument('--out_base',
                        type=argparse_utils.nonexisting_file_path,
                        required=True)

    args = parser.parse_args()

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    logging_utils.log_module_info(np, scipy)

    configuration = sesh_pb2.ScoreSessionsConfig()
    pb.text_format.Merge(' '.join(args.configuration), configuration)

    if not configuration.modifier:
        configuration.modifier.add()  # Create an empty modifier.
    elif len(configuration.modifier) > 1:
        modifier_identifiers = [
            modifier.identifier for modifier in configuration.modifier]

        assert all(modifier_identifiers), \
            'All session modifiers should have an identifier.'
        assert len(modifier_identifiers) == len(set(modifier_identifiers)), \
            'All session modifier identifiers should be unique: {}.'.format(
                modifier_identifiers)

    logging.info('Configuration: %s', configuration)

    logging.info('Loading index.')
    index = pyndri.Index(args.index)

    num_documents = index.document_count()
    logging.debug('Index contains %d documents.', num_documents)

    logging.info('Loading dictionary.')
    dictionary = pyndri.extract_dictionary(index)

    logging.info('Loading background corpus.')
    background_prob_dist = pyndri_utils.extract_background_prob_dist(index)

    for modifier in configuration.modifier:
        out_base = os.path.join(args.out_base, modifier.identifier)
        assert not os.path.exists(out_base)

        os.makedirs(out_base)

        logging.info('Loading sessions using %s and outputting to %s.',
                     modifier or 'no modifier', out_base)

        with codecs.open(args.session_file, 'r', 'utf8') as f_xml:
            track_edition, _, sessions, session_id_to_topic_id = \
                domain.construct_sessions(
                    f_xml, args.top_sessions, dictionary)

        logging.info('Discovered %d sessions.', len(sessions))

        sessions = domain.alter_sessions(sessions, modifier)
        documents = domain.get_document_set(sessions.values())

        logging.info('Retained %d sessions (%d SERP documents) '
                     'after filtering.',
                     len(sessions), len(documents))

        # Load QRels for debugging and oracle runs.
        qrels_per_session = []

        for qrel_path in args.qrel:
            with open(qrel_path, 'r') as f_qrel:
                qrels_per_session.append(sesh.parse_qrel(f_qrel, None))

        scorer_impls = {}

        for scorer_desc in configuration.scorer:
            assert scorer_desc.type in scorers.SESSION_SCORERS

            identifier = scorer_desc.identifier or scorer_desc.type
            assert identifier not in scorer_impls

            scorer = scorers.create_scorer(scorer_desc, qrels_per_session)
            logging.info('Scoring using %s.', repr(scorer))

            scorer_impls[identifier] = scorer

        anchor_texts = None

        if args.harvested_links_file is not None:
            urls = set(document.url for document in documents)

            logging.info(
                'Loading anchor texts for session SERPs (%d URLs).',
                len(urls))

            with codecs.open(args.harvested_links_file, 'r', 'latin1') \
                    as f_harvested_links:
                anchor_texts = load_anchor_texts(f_harvested_links, urls)

            logging.info('Discovered anchor texts for %d URLs (%d total).',
                         len(anchor_texts), len(urls))
        else:
            logging.info('No anchor texts loaded.')

        # The following will hold all the rankings.
        document_assessments_per_session_per_scorer = \
            collections.defaultdict(
                lambda: collections.defaultdict(
                    lambda: collections.defaultdict(float)))

        assert configuration.candidate_generator in \
            DOCUMENT_CANDIDATE_GENERATORS

        # Document candidate generation.
        candidate_generator = DOCUMENT_CANDIDATE_GENERATORS[
            configuration.candidate_generator](**locals())

        logging.info('Using %s for document candidate generation.',
                     candidate_generator)

        result_queue = multiprocessing.Queue()

        initargs = [
            result_queue,
            args,
            configuration,
            out_base,
            background_prob_dist,
            candidate_generator,
            scorer_impls,
            index,
            dictionary,
            anchor_texts]

        pool = multiprocessing.Pool(
            args.num_workers,
            initializer=score_session_initializer,
            initargs=initargs)

        worker_result = pool.map_async(
            score_session_worker,
            sessions.values())

        # We will not submit any more tasks to the pool.
        pool.close()

        it = multiprocessing_utils.QueueIterator(
            pool, worker_result, result_queue)

        while True:
            try:
                result = next(it)
            except StopIteration:
                break

            scorer_name, session_id, ranking = result

            document_assessments_per_session_per_scorer[
                scorer_name][session_id] = ranking

        for scorer_name in document_assessments_per_session_per_scorer:
            # Switch object assessments to lists.
            for topic_id, object_assesments in \
                    document_assessments_per_session_per_scorer[
                        scorer_name].items():
                document_assessments_per_session_per_scorer[
                    scorer_name][topic_id] = [
                        (score, document_id)
                        for document_id, score in
                        object_assesments.items()]

        # Write the runs.
        for scorer_name in document_assessments_per_session_per_scorer:
            run_out_path = os.path.join(
                out_base, '{0}.run'.format(scorer_name))

            with io.open(run_out_path, 'w', encoding='utf8') as f_run_out:
                trec_utils.write_run(
                    scorer_name,
                    document_assessments_per_session_per_scorer[
                        scorer_name],
                    f_run_out)
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--loglevel', type=str, default='INFO')

    parser.add_argument('--num_workers',
                        type=argparse_utils.positive_int, default=16)

    parser.add_argument('--topics', nargs='+',
                        type=argparse_utils.existing_file_path)

    parser.add_argument('model', type=argparse_utils.existing_file_path)
    parser.add_argument('--index', required=True)

    parser.add_argument('--linear', action='store_true', default=False)
    parser.add_argument('--self_information',
                        action='store_true', default=False)
    parser.add_argument('--l2norm_phrase',
                        action='store_true', default=False)

    parser.add_argument('--bias_coefficient',
                        type=argparse_utils.ratio, default=0.0)

    parser.add_argument('--rerank_exact_matching_documents',
                        action='store_true', default=False)

    parser.add_argument('--strict', action='store_true', default=False)

    parser.add_argument('--top_k', default=None)

    parser.add_argument('--num_queries',
                        type=argparse_utils.positive_int, default=None)

    parser.add_argument('run_out')

    args = parser.parse_args()

    args.index = pyndri.Index(args.index)

    try:
        logging_utils.configure_logging(args)
    except IOError:
        return -1

    if not args.top_k:
        args.top_k = 1000
    elif args.top_k == 'all':
        args.top_k = \
            args.index.maximum_document() - args.index.document_base()
    elif args.top_k.isdigit():
        args.top_k = int(args.top_k)
    elif all(map(os.path.exists, args.top_k.split())):
        topics_and_documents = {}

        for qrel_path in args.top_k.split():
            with open(qrel_path, 'r') as f_qrel:
                for topic_id, judgments in trec_utils.parse_qrel(f_qrel):
                    if topic_id not in topics_and_documents:
                        topics_and_documents[topic_id] = set()

                    for doc_id, _ in judgments:
                        topics_and_documents[topic_id].add(doc_id)

        args.top_k = topics_and_documents
    else:
        raise RuntimeError()

    logging.info('Loading dictionary.')
    dictionary = pyndri.extract_dictionary(args.index)

    logging.info('Loading model.')
    model_base, epoch_and_ext = args.model.rsplit('_', 1)
    epoch = int(epoch_and_ext.split('.')[0])

    if not os.path.exists('{}_meta'.format(model_base)):
        model_meta_base, batch_idx = model_base.rsplit('_', 1)
    else:
        model_meta_base = model_base

    kwargs = {
        'strict': args.strict,
    }

    if args.self_information:
        kwargs['self_information'] = True

    if args.linear:
        kwargs['bias_coefficient'] = args.bias_coefficient
        kwargs['nonlinearity'] = None

    if args.l2norm_phrase:
        kwargs['l2norm_phrase'] = True

    model = nvsm.load_model(
        nvsm.load_meta(model_meta_base),
        model_base, epoch,
        **kwargs)

    for topic_path in args.topics:
        run_out_path = '{}-{}'.format(
            args.run_out, os.path.basename(topic_path))

        if os.path.exists(run_out_path):
            logging.warning(
                'Run for topics %s already exists (%s); skipping.',
                topic_path, run_out_path)

            continue

        queries = list(pyndri.utils.parse_queries(
            args.index, dictionary, topic_path,
            strict=args.strict,
            num_queries=args.num_queries))

        if args.rerank_exact_matching_documents:
            assert not isinstance(args.top_k, dict)

            topics_and_documents = {}

            query_env = pyndri.TFIDFQueryEnvironment(args.index)

            for topic_id, topic_token_ids in queries:
                topics_and_documents[topic_id] = set()

                query_str = ' '.join(
                    dictionary[term_id]
                    for term_id in topic_token_ids
                    if term_id is not None)

                for int_doc_id, score in query_env.query(
                        query_str, results_requested=1000):
                    topics_and_documents[topic_id].add(
                        args.index.ext_document_id(int_doc_id))

            args.top_k = topics_and_documents

        run = trec_utils.OnlineTRECRun(
            'cuNVSM', rank_cutoff=(
                args.top_k if isinstance(args.top_k, int)
                else sys.maxsize))

        rank_fn = RankFn(
            args.num_workers,
            args=args, model=model)

        for idx, (topic_id, topic_data) in enumerate(rank_fn(queries)):
            if topic_data is None:
                continue

            logging.info('Query %s (%d/%d)',
                         topic_id, idx + 1, len(queries))

            topic_repr, topic_scores_and_documents = topic_data

            run.add_ranking(topic_id, topic_scores_and_documents)

            del topic_scores_and_documents

        run.close_and_write(run_out_path, overwrite=False)

        logging.info('Run outputted to %s.', run_out_path)

        del rank_fn