import logging

# io_utils and trec_utils are assumed to be importable from the
# surrounding package; _candidate_centric_label is defined elsewhere in
# this module.


def prepare_worker_(document_path):
    """Extracts training instances and labels from a single TREC file.

    Shared state (arguments, vocabulary, document associations and the
    result queue) is read from attributes on `prepare_worker`; see the
    initializer sketch below.
    """
    reader = trec_utils.TRECTextReader(
        [document_path], encoding=prepare_worker.encoding)

    num_documents = 0

    for doc_id, doc_text in reader.iter_documents(
            replace_digits=True, strip_html=True):
        # Instances collected for this document; put on the result queue
        # below.
        instances_and_labels = []

        if doc_id not in prepare_worker.document_assocs:
            logging.debug('Document "%s" does not exist in associations.',
                          doc_id)
            continue

        def _callback(num_yielded_windows, remaining_tokens):
            if not num_yielded_windows:
                logging.error('Document "%s" (%s) yielded zero instances; '
                              'remaining tokens: %s.',
                              doc_id, doc_text, remaining_tokens)

        padding_token = (
            '</s>' if not prepare_worker.args.no_padding else None)

        # Clean the raw text (strip non-alphanumeric and non-Latin
        # characters, lowercase, replace digits) and slide a window over
        # the translated token ids; eos_chars=[] ignores end-of-sentence
        # markers.
        windowed_word_stream = io_utils.windowed_translated_token_stream(
            io_utils.replace_numeric_tokens_stream(
                io_utils.token_stream(
                    io_utils.lowercased_stream(
                        io_utils.filter_non_latin_stream(
                            io_utils.filter_non_alphanumeric_stream(
                                iter(doc_text)))),
                    eos_chars=[])),
            prepare_worker.args.window_size,
            prepare_worker.words,
            stride=prepare_worker.args.stride,
            padding_token=padding_token,
            callback=_callback)

        # Matrix indices of the entities associated with the document.
        entity_ids = list(prepare_worker.document_assocs[doc_id])

        label = _candidate_centric_label(entity_ids)

        # Normalize the label weights into a distribution.
        partition_function = float(sum(label.values()))

        for index in label:
            label[index] /= partition_function

        for instance in windowed_word_stream:
            instances_and_labels.append((doc_id, instance, label))

        prepare_worker.result_queue.put(
            (doc_id, instances_and_labels, label))

        num_documents += 1

    return num_documents
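# `prepare_worker_` reads its shared state from attributes on a
# module-level `prepare_worker` callable; attaching state to the worker
# function is a common way to share read-only data with
# multiprocessing.Pool workers. A minimal, hypothetical sketch of such an
# initializer (the name and signature below are assumptions for
# illustration, not part of the original module):
def initialize_prepare_worker(args, words, document_assocs,
                              result_queue, encoding):
    prepare_worker.args = args
    prepare_worker.words = words
    prepare_worker.document_assocs = document_assocs
    prepare_worker.result_queue = result_queue
    prepare_worker.encoding = encoding

# Hypothetical usage with a worker pool:
#
#     pool = multiprocessing.Pool(
#         initializer=initialize_prepare_worker,
#         initargs=(args, words, document_assocs, result_queue, 'utf8'))
#     num_documents = sum(pool.map(prepare_worker, document_paths))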
def test_windowed_translated_token_stream_bigrams_trigrams(self):
    words = {
        '</s>': io_utils.Word(id=0, count=1),
        'world': io_utils.Word(id=1, count=1),
        'foo': io_utils.Word(id=2, count=1),
        'bar': io_utils.Word(id=3, count=1),
        'hello': io_utils.Word(id=4, count=1),
    }

    text = ('hello', 'world', 'world', 'hello', 'bar', 'hello',
            'foo', 'world', 'foo', 'foo', 'world', 'bar')

    # A tuple-valued window_size requests multiple window widths at
    # once; here both bigrams (2) and trigrams (3) are yielded.
    stream = io_utils.windowed_translated_token_stream(
        iter(text), window_size=(2, 3), words=words)

    world, foo, bar, hello = range(1, 5)

    self.assertEqual(
        sorted(list(stream)),
        sorted([(hello, world), (world, world), (world, hello),
                (hello, bar), (bar, hello), (hello, foo),
                (foo, world), (world, foo), (foo, foo),
                (foo, world), (world, bar),
                (hello, world, world), (world, world, hello),
                (world, hello, bar), (hello, bar, hello),
                (bar, hello, foo), (hello, foo, world),
                (foo, world, foo), (world, foo, foo),
                (foo, foo, world), (foo, world, bar)]))
def test_windowed_translated_token_stream_0and1skip(self):
    words = {
        '</s>': io_utils.Word(id=0, count=1),
        'world': io_utils.Word(id=1, count=1),
        'foo': io_utils.Word(id=2, count=1),
        'bar': io_utils.Word(id=3, count=1),
        'hello': io_utils.Word(id=4, count=1),
    }

    text = ('hello', 'world', 'world', 'hello', 'bar', 'hello',
            'foo', 'world', 'foo', 'foo', 'world', 'bar')

    # skips=(0, 1) yields the contiguous trigrams plus the trigrams
    # taken at every second token; note that (foo, foo, world) occurs
    # once in each set.
    stream = io_utils.windowed_translated_token_stream(
        iter(text), window_size=3, words=words, skips=(0, 1))

    world, foo, bar, hello = range(1, 5)

    self.assertEqual(
        sorted(list(stream)),
        sorted([(hello, world, world), (world, world, hello),
                (world, hello, bar), (hello, bar, hello),
                (bar, hello, foo), (hello, foo, world),
                (foo, world, foo), (world, foo, foo),
                (foo, foo, world), (foo, world, bar),
                (hello, world, bar), (world, hello, hello),
                (world, bar, foo), (hello, hello, world),
                (bar, foo, foo), (hello, world, foo),
                (foo, foo, world), (world, foo, bar)]))
def test_windowed_translated_token_stream_window_too_large(self):
    words = {
        '</s>': io_utils.Word(id=0, count=1),
        'world': io_utils.Word(id=1, count=1),
        'foo': io_utils.Word(id=2, count=1),
        'bar': io_utils.Word(id=3, count=1),
        'hello': io_utils.Word(id=4, count=1),
    }

    text = ('hello', 'world', 'world', 'hello', 'bar', 'hello',
            'foo', 'world', 'foo', 'foo', 'world', 'bar')

    stream = io_utils.windowed_translated_token_stream(
        iter(text), window_size=16, words=words)

    self.assertEqual(list(stream), [])
def test_windowed_translated_token_stream_3stride(self):
    words = {
        '</s>': io_utils.Word(id=0, count=1),
        'world': io_utils.Word(id=1, count=1),
        'foo': io_utils.Word(id=2, count=1),
        'bar': io_utils.Word(id=3, count=1),
        'hello': io_utils.Word(id=4, count=1),
    }

    text = ('hello', 'world', 'world', 'hello', 'bar', 'hello',
            'foo', 'world', 'foo', 'foo', 'world', 'bar')

    # stride=3 advances the window by three tokens at a time, yielding
    # non-overlapping trigrams.
    stream = io_utils.windowed_translated_token_stream(
        iter(text), window_size=3, words=words, stride=3)

    world, foo, bar, hello = range(1, 5)

    self.assertEqual(list(stream),
                     [(hello, world, world), (hello, bar, hello),
                      (foo, world, foo), (foo, world, bar)])
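# For intuition, the sliding-window behaviour exercised by the tests
# above can be approximated by the following self-contained sketch. It
# is a simplification, not io_utils' actual implementation: tuple-valued
# window_size (as in the bigrams/trigrams test), padding and the
# callback hook are omitted.
def simple_windowed_stream(tokens, window_size, words, stride=1,
                           skips=(0,)):
    # Translate tokens to their integer ids.
    ids = [words[token].id for token in tokens]

    for skip in skips:
        step = skip + 1  # distance between tokens inside one window
        span = (window_size - 1) * step + 1  # raw tokens one window covers

        for start in range(0, len(ids) - span + 1, stride):
            yield tuple(ids[start:start + span:step])

# For example, with the `words` and `text` of the tests above,
# list(simple_windowed_stream(text, 3, words, stride=3)) reproduces the
# expected output of test_windowed_translated_token_stream_3stride.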