Example #1
    return vocab


if __name__ == '__main__':

    # python 2/3 compatibility
    if sys.version_info < (3, 0):
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin)
    else:
        sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer)
        sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer)
        sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer)

    parser = create_parser()
    args = parser.parse_args()

    # read/write files as UTF-8
    if args.input.name != '<stdin>':
        args.input = codecs.open(args.input.name, encoding='utf-8')
    if args.output.name != '<stdout>':
        args.output = codecs.open(args.output.name, 'w', encoding='utf-8')

    vocab = main_args(args, args.input, args.output, args.dict_input)
    if len(args.vocab):
        with codecs.open(args.output.name, encoding='UTF-8') as codes:
            bpe = apply_bpe.BPE(codes, args.separator, None)
            # apply BPE to each training corpus and get vocabulary
            make_vocabularies(bpe, vocab, args.vocab)
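The __main__ block above depends on a create_parser() helper that is not included in the fragment. Judging only from the attributes it reads (input, output, vocab, separator, dict_input), the parser presumably resembles the sketch below; the option names and defaults here are guesses, not the original definitions.

import argparse
import sys

def create_parser():
    # guessed approximation of the missing create_parser(); the real one in
    # the source script may define different flags and help texts
    parser = argparse.ArgumentParser(
        description='learn BPE codes jointly and write per-corpus vocabularies')
    parser.add_argument('--input', '-i', type=argparse.FileType('r'),
                        default=sys.stdin,
                        help='input corpus (default: standard input)')
    parser.add_argument('--output', '-o', type=argparse.FileType('w'),
                        default=sys.stdout,
                        help='file the BPE codes are written to')
    parser.add_argument('--vocab', nargs='*', default=[],
                        help='vocabulary files to write, one per corpus')
    parser.add_argument('--separator', '-s', default='@@',
                        help='subword separator')
    parser.add_argument('--dict-input', action='store_true',
                        help='treat the input as a "<word> <count>" dictionary')
    return parser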
Example #2
    def set_bpe(self, codes_file):
        with codecs.open(codes_file, encoding='UTF-8') as codes:
            self.bpe = apply_bpe.BPE(codes, separator=self.separator)


def learn_joint_bpe_and_vocab(args):

    if args.vocab and len(args.input) != len(args.vocab):
        sys.stderr.write(
            'Error: number of input files and vocabulary files must match\n')
        sys.exit(1)

    # read/write files as utf-8
    args.input = [codecs.open(f.name, encoding='utf-8') for f in args.input]
    args.vocab = [
        codecs.open(f.name, 'w', encoding='utf-8') for f in args.vocab
    ]

    # get combined vocabulary of all input texts
    full_vocab = Counter()
    for f in args.input:
        full_vocab += learn_bpe.get_vocabulary(f)
        f.seek(0)

    vocab_list = [
        '{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()
    ]

    # learn BPE on combined vocabulary
    with codecs.open(args.output.name, 'w', encoding='utf-8') as output:
        learn_bpe.learn_bpe(vocab_list,
                            output,
                            args.symbols,
                            args.min_frequency,
                            args.verbose,
                            is_dict=True,
                            total_symbols=args.total_symbols)

    with codecs.open(args.output.name, encoding='utf-8') as codes:
        bpe = apply_bpe.BPE(codes, separator=args.separator)

    # apply BPE to each training corpus and get vocabulary
    for train_file, vocab_file in zip(args.input, args.vocab):

        tmp = tempfile.NamedTemporaryFile(delete=False)
        tmp.close()

        tmpout = codecs.open(tmp.name, 'w', encoding='utf-8')

        train_file.seek(0)
        for line in train_file:
            tmpout.write(bpe.segment(line).strip())
            tmpout.write('\n')

        tmpout.close()
        tmpin = codecs.open(tmp.name, encoding='utf-8')

        vocab = learn_bpe.get_vocabulary(tmpin)
        tmpin.close()
        os.remove(tmp.name)

        for key, freq in sorted(vocab.items(),
                                key=lambda x: x[1],
                                reverse=True):
            vocab_file.write("{0} {1}\n".format(key, freq))
        vocab_file.close()
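Because learn_joint_bpe_and_vocab() only reads plain attributes from args, it can also be driven programmatically rather than through a parser. The sketch below is a hypothetical driver: the field names are taken from the function body, while the file names and parameter values are placeholders.

import argparse

args = argparse.Namespace(
    input=[open('train.de'), open('train.en')],            # placeholder corpora
    vocab=[open('vocab.de', 'w'), open('vocab.en', 'w')],  # per-corpus vocab files
    output=open('bpe.codes', 'w'),                         # where codes are written
    symbols=10000,        # number of BPE merge operations to learn
    min_frequency=2,      # discard merges rarer than this
    verbose=False,
    separator='@@',
    total_symbols=False,
)
learn_joint_bpe_and_vocab(args)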
Example #4
def learn_joint_bpe_and_vocab(args):

    if args.vocab and len(args.input) != len(args.vocab):
        sys.stderr.write(
            'Error: number of input files and vocabulary files must match\n')
        sys.exit(1)

    # read/write files as UTF-8
    args.input = [codecs.open(f.name, encoding='UTF-8') for f in args.input]
    args.vocab = [
        codecs.open(f.name, 'w', encoding='UTF-8') for f in args.vocab
    ]

    if args.special_vocab:
        # read reserved tokens, one per line
        with codecs.open(args.special_vocab, encoding='UTF-8') as f:
            args.special_vocab = [line.strip('\r\n ') for line in f]

    # get combined vocabulary of all input texts
    full_vocab = Counter()
    for f in args.input:
        full_vocab += learn_bpe.get_vocabulary(f, args.dict_input)
        f.seek(0)
    if args.special_vocab:
        for word in args.special_vocab:
            full_vocab[word] += 1  # integrate special vocab to full_vocab

    vocab_list = yield_dict_lines(full_vocab)

    # learn BPE on combined vocabulary
    with codecs.open(args.output.name, 'w', encoding='UTF-8') as output:
        learn_bpe.learn_bpe(vocab_list,
                            output,
                            args.symbols,
                            args.min_frequency,
                            args.verbose,
                            is_dict=True,
                            total_symbols=args.total_symbols,
                            is_postpend=args.postpend,
                            special_vocab=args.special_vocab)

    with codecs.open(args.output.name, encoding='UTF-8') as codes:
        bpe = apply_bpe.BPE(codes,
                            separator=args.separator,
                            is_postpend=args.postpend)

    # apply BPE to each training corpus and get vocabulary
    for train_file, vocab_file in zip(args.input, args.vocab):
        if args.dict_input:
            vocab = Counter()
            for i, line in enumerate(train_file):
                try:
                    word, count = line.strip('\r\n ').split(' ')
                    segments = bpe.segment_tokens([word])
                except:
                    print('Failed reading vocabulary file at line {0}: {1}'.
                          format(i, line))
                    sys.exit(1)
                for seg in segments:
                    vocab[seg] += int(count)
        else:
            tmp = tempfile.NamedTemporaryFile(delete=False)
            tmp.close()

            tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')

            train_file.seek(0)
            for line in train_file:
                tmpout.write(bpe.process_line(line).strip())
                tmpout.write('\n')

            tmpout.close()
            tmpin = codecs.open(tmp.name, encoding='UTF-8')

            vocab = learn_bpe.get_vocabulary(tmpin)
            tmpin.close()
            os.remove(tmp.name)

        # if special vocab is defined, include them
        if args.special_vocab:
            for i, word in enumerate(args.special_vocab):
                try:
                    segments = bpe.segment_tokens([word])
                except:
                    print(
                        'Failed reading special vocabulary file at line {0}: {1}'
                        .format(i, word))
                    sys.exit(1)
                if len(segments) != 1:
                    sys.stderr.write(
                        'WARNING: special vocab \'{0}\' not captured by merges, split into \'{1}\'\n'
                        .format(word, ' '.join(segments)))
                for seg in segments:
                    vocab[seg] += 1

        sys.stderr.write('Vocabulary got {0:d} unique items\n'.format(
            len(vocab)))

        # if character vocab is to be included
        if args.character_vocab:
            char_internal, char_terminal = learn_bpe.extract_uniq_chars(
                full_vocab, args.postpend)
            sys.stderr.write(
                'Got {0:d} non-terminal and {1:d} terminal characters\n'.
                format(len(char_internal), len(char_terminal)))
            # terminal characters sort before internal characters, which in
            # turn sort before ordinary vocabulary items
            pseudo_count_terminal = max(vocab.values()) + 2
            pseudo_count_internal = max(vocab.values()) + 1
            for c in char_terminal:
                vocab[c] = pseudo_count_terminal
            for c in char_internal:
                if args.postpend:
                    c = '{0}{1}'.format(args.separator, c)
                else:
                    c = '{0}{1}'.format(c, args.separator)
                vocab[c] = pseudo_count_internal
        for key, freq in sorted(vocab.items(), key=lambda x: (-x[1], x[0])):
            vocab_file.write("{0} {1}\n".format(key, freq))
        train_file.close()
        vocab_file.close()
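The WARNING in the special-vocabulary branch fires whenever a reserved token cannot be reproduced by the learned merge operations, so bpe.segment_tokens() splits it into several subword units. A minimal reproduction of that check, assuming a subword-nmt build that provides segment_tokens() and using a placeholder codes file and token:

import codecs
import apply_bpe

with codecs.open('bpe.codes', encoding='UTF-8') as codes:   # placeholder path
    bpe = apply_bpe.BPE(codes, separator='@@')

segments = bpe.segment_tokens(['<placeholder-term>'])        # hypothetical token
if len(segments) != 1:
    # the token was not captured by the merges and got split into subwords
    print('split into:', ' '.join(segments))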
Example #5
    # learn BPE on combined vocabulary
    with io.open(args.output.name, 'w', encoding='UTF-8') as output:
        learn_bpe.main(vocab_list,
                       output,
                       args.symbols,
                       args.min_frequency,
                       args.verbose,
                       is_dict=True,
                       case_insensitive=args.case_insensitive)

    separator = '@@'
    if args.opennmt_separator: separator = '■'
    with io.open(args.output.name, encoding='UTF-8') as codes:
        bpe = apply_bpe.BPE(codes,
                            separator,
                            None,
                            case_feature=args.case_insensitive)

    # apply BPE to each training corpus and get vocabulary
    for train_file, vocab_file in zip(args.input, args.vocab):

        tmp = tempfile.NamedTemporaryFile(delete=False)
        tmp.close()

        tmpout = io.open(tmp.name, 'w', encoding='UTF-8')

        train_file.seek(0)
        for line in train_file:
            tmpout.write(bpe.segment(line).strip())
            tmpout.write('\n')
Example #6
    vocab_list = [
        '{0} {1}'.format(key, freq) for (key, freq) in full_vocab.items()
    ]

    # learn BPE on combined vocabulary
    with codecs.open(args.output.name, 'w', encoding='UTF-8') as output:
        learn_bpe.main(vocab_list,
                       output,
                       args.symbols,
                       args.min_frequency,
                       args.verbose,
                       is_dict=True)

    with codecs.open(args.output.name, encoding='UTF-8') as codes:
        bpe = apply_bpe.BPE(codes, separator=args.separator)

    # apply BPE to each training corpus and get vocabulary
    for train_file, vocab_file in zip(args.input, args.vocab):

        tmp = tempfile.NamedTemporaryFile(delete=False)
        tmp.close()

        tmpout = codecs.open(tmp.name, 'w', encoding='UTF-8')

        train_file.seek(0)
        for line in train_file:
            tmpout.write(bpe.segment(line).strip())
            tmpout.write('\n')

        tmpout.close()
Example #7
    def load_src_trg_ids(self, fpattern, tar_fname):
        converters = [
            Converter(vocab=self._src_vocab,
                      beg=self._bos_idx,
                      end=self._eos_idx,
                      unk=self._unk_idx,
                      delimiter=self._token_delimiter,
                      add_beg=False)
        ]
        if not self._only_src:
            converters.append(
                Converter(vocab=self._trg_vocab,
                          beg=self._bos_idx,
                          end=self._eos_idx,
                          unk=self._unk_idx,
                          delimiter=self._token_delimiter,
                          add_beg=True))

        converters = ComposedConverter(converters)

        self._src_seq_ids = []
        self._trg_seq_ids = None if self._only_src else []
        self._sample_infos = []

        if self._stream:
            # read json
            import json
            import jieba
            import apply_bpe
            parser = apply_bpe.create_parser()
            args = parser.parse_args(args=['-c', self._src_bpe_dict])
            bpe = apply_bpe.BPE(args.codes, args.merges, args.separator, None,
                                args.glossaries)

            src_js = []
            if fpattern.endswith('wav.txt'):
                # This is for Chinese-to-English asr translation
                src_js.append([[]])
                with open(fpattern, 'r') as f:
                    #src_js = json.load(f)
                    for line in f.readlines():
                        line = line.strip().split(', ')
                        # if len(line) < 3 or len(line[2].split(': ')) < 2:
                        #     embed()
                        #     print(line)
                        if len(line) > 5:
                            sent = ', '.join(line[1:-3])[14:]
                        else:
                            sent = line[1].split(': ')[1]
                        final = line[-3].split(': ')[1] == 'final'
                        # t = line[4].split(': ')[1]
                        #if len(src_js[-1]) == 0:
                        #    src_js[-1].append([])
                        #print(sent)
                        #embed()
                        src_js[-1][-1].append(sent)
                        if final:
                            src_js[-1].append([])
                    src_js[-1] = src_js[-1][:-1]
            else:
                src_js.append([])
                # This is for Chinese-to-English transcription translation
                with open(fpattern, 'r') as f:
                    #src_js = json.load(f)
                    for line in f.readlines():
                        txt = line.strip()
                        if (len(src_js[-1]) == 0
                                or not txt.startswith(src_js[-1][-1][-1])):
                            src_js[-1].append([])
                        src_js[-1][-1].append(txt)

            # do tokenization and bpe in this step
            # remove the final word because it's highly unstable (might be different bpe)
            stream_output = []
            for talk in src_js:
                stream_output.append([])
                for stream in talk:
                    _stream = []
                    for s in stream:
                        s = s.strip()
                        s = ' '.join(jieba.cut(s))
                        s = bpe.process_line(s, args.dropout)
                        _stream.append(s)

                    stream_output[-1].append([])
                    prev_len = 1
                    prev_sent = ''
                    for s in _stream[:-1]:
                        if len(s.split()) > prev_len and prev_sent != '':
                            stream_output[-1][-1].append(' '.join(
                                s.split()[:-1]))
                        else:
                            stream_output[-1][-1].append('')
                        prev_sent = s
                        prev_len = len(s.split())
                    if len(_stream) > 0:
                        stream_output[-1][-1].append(_stream[-1])

            # each line is a read
            filter_src = []
            # incremental number of reads consumed before emitting each line
            real_read = []
            for talk in stream_output:
                for stream in talk:
                    filter_src.append([])
                    real_read.append([0])
                    for idx, s in enumerate(stream):
                        real_read[-1][-1] += 1
                        if len(s) > 0:
                            filter_src[-1].append(s)
                            if idx < len(stream) - 1:
                                real_read[-1].append(0)

            # embed()

            self._real_read = []
            i = 0
            for src_list, read_list in zip(filter_src, real_read):
                self._sample_infos.append([])
                for line, read_num in zip(src_list, read_list):
                    src_trg_ids = converters([line.encode()])
                    self._src_seq_ids.append(src_trg_ids[0])
                    self._real_read.append(read_num)
                    lens = [len(src_trg_ids[0])]
                    if not self._only_src:
                        self._trg_seq_ids.append(src_trg_ids[1])
                        lens.append(len(src_trg_ids[1]))
                    self._sample_infos[-1].append(
                        SampleInfo(i, max(lens), min(lens)))
                    i += 1
        else:
            for i, line in enumerate(self._load_lines(fpattern, tar_fname)):
                src_trg_ids = converters(line)
                self._src_seq_ids.append(src_trg_ids[0])
                lens = [len(src_trg_ids[0])]
                if not self._only_src:
                    self._trg_seq_ids.append(src_trg_ids[1])
                    lens.append(len(src_trg_ids[1]))
                self._sample_infos.append(SampleInfo(i, max(lens), min(lens)))
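In the streaming branch above, the BPE segmenter is built by reusing apply_bpe's own command-line parser instead of opening the codes file directly. The stand-alone sketch below isolates that pattern together with the jieba pre-tokenization step; 'zh.codes' is a placeholder path, and the positional BPE constructor arguments assume a subword-nmt version matching the one used in the snippet.

import jieba
import apply_bpe

parser = apply_bpe.create_parser()
args = parser.parse_args(args=['-c', 'zh.codes'])   # placeholder codes file
bpe = apply_bpe.BPE(args.codes, args.merges, args.separator, None,
                    args.glossaries)

# same preprocessing as the streaming branch: Chinese word segmentation,
# then BPE on the space-joined tokens
sentence = ' '.join(jieba.cut('这是一个测试'))
print(bpe.process_line(sentence, args.dropout))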
Example #8
parser.add_argument('--add-sos',
                    default=False,
                    action='store_true',
                    help='add <s> token')
parser.add_argument('--add-eos',
                    default=False,
                    action='store_true',
                    help='add </s> token')
parser.add_argument('--all',
                    '-a',
                    action='store_true',
                    default=False,
                    help="Add SOS and EOS tokens to the constraint")
args = parser.parse_args()

bpe = apply_bpe.BPE(open(args.bpe))

termsdict = {}
if args.dictionary is not None:
    with open(args.dictionary) as fh:
        for i, line in enumerate(fh):
            source, target = line.rstrip().split('\t')
            termsdict[source] = target
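The constraint dictionary is expected to contain one tab-separated source/target pair per line. A tiny self-contained illustration of that format (the entries are made up):

import io

sample = 'Haus\thouse\nRechner\tcomputer\n'   # hypothetical dictionary contents
termsdict = {}
for line in io.StringIO(sample):
    source, target = line.rstrip().split('\t')
    termsdict[source] = target
assert termsdict['Rechner'] == 'computer'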


def get_phrase(words, index, length):
    assert (index < len(words) - length + 1)
    phr = ' '.join(words[index:index + length])
    for i in range(length):
        words.pop(index)