コード例 #1
0
def prepare_worker_(document_path):
    """Convert the documents of one TREC file into windowed instances.

    Every document in ``document_path`` whose id has an entry in
    ``prepare_worker.document_assocs`` is tokenized into a stream of
    word-id windows, and a label over the document's associated entities is
    computed and normalized to sum to one. For each such document the tuple
    ``(doc_id, instances_and_labels, label)`` is put on
    ``prepare_worker.result_queue``, where ``instances_and_labels`` is a list
    of ``(doc_id, instance, label)`` triples, one per window.

    Configuration is read from attributes attached to the ``prepare_worker``
    function object (``encoding``, ``document_assocs``, ``args``, ``words``,
    ``result_queue``). NOTE(review): presumably these are set by a worker
    initializer elsewhere -- verify against the caller.

    Args:
        document_path: path of the TREC text file to read.

    Returns:
        The number of documents enqueued. Documents without associations
        are skipped and not counted.
    """
    reader = trec_utils.TRECTextReader([document_path],
                                       encoding=prepare_worker.encoding)

    # The padding setting is the same for every document; read it once.
    padding_token = (
        '</s>' if not prepare_worker.args.no_padding else None)

    num_documents = 0

    for doc_id, doc_text in reader.iter_documents(replace_digits=True,
                                                  strip_html=True):
        if doc_id not in prepare_worker.document_assocs:
            logging.debug('Document "%s" does not exist in associations.',
                          doc_id)
            continue

        # This closure refers to the loop variables doc_id/doc_text by name.
        # It reports the right document only if the stream invokes it while
        # being consumed below, within this same iteration (presumably the
        # case -- confirm in io_utils).
        def _callback(num_yielded_windows, remaining_tokens):
            if not num_yielded_windows:
                logging.error('Document "%s" (%s) yielded zero instances; '
                              'remaining tokens: %s.',
                              doc_id, doc_text, remaining_tokens)

        # Character pipeline, innermost step first:
        # filter_non_alphanumeric -> filter_non_latin -> lowercase ->
        # tokenize -> replace numeric tokens -> window + translate to ids.
        # eos_chars=[] means no character is treated as an end-of-sentence
        # marker.
        windowed_word_stream = io_utils.windowed_translated_token_stream(
            io_utils.replace_numeric_tokens_stream(
                io_utils.token_stream(
                    io_utils.lowercased_stream(
                        io_utils.filter_non_latin_stream(
                            io_utils.filter_non_alphanumeric_stream(
                                iter(doc_text)))),
                    eos_chars=[])),
            prepare_worker.args.window_size,
            prepare_worker.words,
            stride=prepare_worker.args.stride,
            padding_token=padding_token,
            callback=_callback)

        # Entities associated with the document; used to determine their
        # matrix indices.
        entity_ids = list(prepare_worker.document_assocs[doc_id])

        label = _candidate_centric_label(entity_ids)

        # Normalize the label in place so its values sum to one.
        # NOTE(review): raises ZeroDivisionError if label is non-empty but
        # its values sum to zero -- confirm _candidate_centric_label
        # cannot produce that.
        partition_function = float(sum(label.values()))
        for index in label:
            label[index] /= partition_function

        # Every window shares the same (normalized) label object.
        instances_and_labels = [(doc_id, instance, label)
                                for instance in windowed_word_stream]

        prepare_worker.result_queue.put(
            (doc_id, instances_and_labels, label))

        num_documents += 1

    return num_documents
コード例 #2
0
    def test_windowed_translated_token_stream_bigrams_trigrams(self):
        """A window_size of (2, 3) yields every bigram and every trigram.

        Windows are compared after sorting, so only the multiset of emitted
        windows is checked, not their order.
        """
        words = {
            '</s>': io_utils.Word(id=0, count=1),
            'world': io_utils.Word(id=1, count=1),
            'foo': io_utils.Word(id=2, count=1),
            'bar': io_utils.Word(id=3, count=1),
            'hello': io_utils.Word(id=4, count=1),
        }

        text = ('hello', 'world', 'world', 'hello', 'bar', 'hello', 'foo',
                'world', 'foo', 'foo', 'world', 'bar')

        stream = io_utils.windowed_translated_token_stream(iter(text),
                                                           window_size=(2, 3),
                                                           words=words)

        # Word ids, matching the `words` table above.
        world, foo, bar, hello = range(1, 5)

        # All 11 consecutive pairs of the 12-token text.
        bigrams = [(hello, world), (world, world), (world, hello),
                   (hello, bar), (bar, hello), (hello, foo), (foo, world),
                   (world, foo), (foo, foo), (foo, world), (world, bar)]

        # All 10 consecutive triples of the 12-token text.
        trigrams = [(hello, world, world), (world, world, hello),
                    (world, hello, bar), (hello, bar, hello),
                    (bar, hello, foo), (hello, foo, world),
                    (foo, world, foo), (world, foo, foo),
                    (foo, foo, world), (foo, world, bar)]

        self.assertEqual(sorted(stream), sorted(bigrams + trigrams))
コード例 #3
0
    def test_windowed_translated_token_stream_0and1skip(self):
        """skips=(0, 1) yields contiguous trigrams plus one-token-skip trigrams.

        The emitted windows are compared as a sorted multiset, so ordering
        is not asserted.
        """
        vocabulary = {
            '</s>': io_utils.Word(id=0, count=1),
            'world': io_utils.Word(id=1, count=1),
            'foo': io_utils.Word(id=2, count=1),
            'bar': io_utils.Word(id=3, count=1),
            'hello': io_utils.Word(id=4, count=1),
        }

        tokens = ('hello', 'world', 'world', 'hello', 'bar', 'hello', 'foo',
                  'world', 'foo', 'foo', 'world', 'bar')

        stream = io_utils.windowed_translated_token_stream(
            iter(tokens), window_size=3, words=vocabulary, skips=(0, 1))

        # Word ids, matching the vocabulary above.
        world, foo, bar, hello = range(1, 5)

        # Skip 0: tokens at positions i, i + 1, i + 2.
        contiguous = [
            (hello, world, world), (world, world, hello),
            (world, hello, bar), (hello, bar, hello),
            (bar, hello, foo), (hello, foo, world),
            (foo, world, foo), (world, foo, foo),
            (foo, foo, world), (foo, world, bar),
        ]

        # Skip 1: tokens at positions i, i + 2, i + 4.
        skip_one = [
            (hello, world, bar), (world, hello, hello),
            (world, bar, foo), (hello, hello, world),
            (bar, foo, foo), (hello, world, foo),
            (foo, foo, world), (world, foo, bar),
        ]

        self.assertEqual(sorted(stream), sorted(contiguous + skip_one))
コード例 #4
0
    def test_windowed_translated_token_stream_window_too_large(self):
        """A window wider than the whole text yields no windows at all."""
        vocabulary = {
            '</s>': io_utils.Word(id=0, count=1),
            'world': io_utils.Word(id=1, count=1),
            'foo': io_utils.Word(id=2, count=1),
            'bar': io_utils.Word(id=3, count=1),
            'hello': io_utils.Word(id=4, count=1),
        }

        # 12 tokens: fewer than the window size of 16 used below.
        tokens = ('hello', 'world', 'world', 'hello', 'bar', 'hello', 'foo',
                  'world', 'foo', 'foo', 'world', 'bar')

        windows = list(io_utils.windowed_translated_token_stream(
            iter(tokens), window_size=16, words=vocabulary))

        self.assertEqual(windows, [])
コード例 #5
0
    def test_windowed_translated_token_stream_3stride(self):
        """stride=3 emits trigrams starting only at offsets 0, 3, 6 and 9.

        Unlike the sorted comparisons elsewhere, this test also checks the
        order in which windows are emitted.
        """
        vocabulary = {
            '</s>': io_utils.Word(id=0, count=1),
            'world': io_utils.Word(id=1, count=1),
            'foo': io_utils.Word(id=2, count=1),
            'bar': io_utils.Word(id=3, count=1),
            'hello': io_utils.Word(id=4, count=1),
        }

        tokens = ('hello', 'world', 'world', 'hello', 'bar', 'hello', 'foo',
                  'world', 'foo', 'foo', 'world', 'bar')

        stream = io_utils.windowed_translated_token_stream(
            iter(tokens), window_size=3, words=vocabulary, stride=3)

        # Word ids, matching the vocabulary above.
        world, foo, bar, hello = range(1, 5)

        expected = [
            (hello, world, world),  # tokens 0-2
            (hello, bar, hello),    # tokens 3-5
            (foo, world, foo),      # tokens 6-8
            (foo, world, bar),      # tokens 9-11
        ]

        self.assertEqual(list(stream), expected)