예제 #1
0
def build_template_dataset(fields, src_path, tgt_path, template_path,
                           src_seq_length=0, tgt_seq_length=0,
                           src_seq_length_trunc=0, tgt_seq_length_trunc=0,
                           dynamic_dict=True, use_filter_pred=True,
                           with_pos=False):
    """
    Assemble a `TextDataset` whose source examples are produced from a
    source corpus together with a template file.

    Args:
        fields: passed through unchanged to `TextDataset`.
        src_path: source corpus path handed to the template example maker.
        tgt_path: target corpus path handed to the text example maker.
        template_path: template file path handed to the template example
            maker alongside `src_path`.
        src_seq_length: forwarded to `TextDataset`.
        tgt_seq_length: forwarded to `TextDataset`.
        src_seq_length_trunc: truncation argument for the source side.
        tgt_seq_length_trunc: truncation argument for the target side.
        dynamic_dict: forwarded to `TextDataset`.
        use_filter_pred: forwarded to `TextDataset`.
        with_pos: forwarded to the template example maker.

    Returns:
        The constructed `TextDataset`.
    """
    # Source side: (examples iterator, number of features) built from the
    # source corpus and the template file.
    src_tpl = AdvancedTextDataset.make_template_examples_nfeats_tpl(
        src_path, template_path, src_seq_length_trunc, "src", with_pos)
    src_examples_iter, num_src_feats = src_tpl

    # Target side: always plain text.
    tgt_tpl = TextDataset.make_text_examples_nfeats_tpl(
        tgt_path, tgt_seq_length_trunc, "tgt")
    tgt_examples_iter, num_tgt_feats = tgt_tpl

    return TextDataset(fields,
                       src_examples_iter,
                       tgt_examples_iter,
                       num_src_feats,
                       num_tgt_feats,
                       src_seq_length=src_seq_length,
                       tgt_seq_length=tgt_seq_length,
                       dynamic_dict=dynamic_dict,
                       use_filter_pred=use_filter_pred)
예제 #2
0
    def num_feats(self):
        """
        Peek at the next line of `self.corpus` to determine the number of
        features, then rewind to where the corpus was before the peek.

        The result is stored on `self.n_feats` and also returned.
        """
        # Remember the current offset so the peek leaves no trace.
        start_offset = self.corpus.tell()

        tokens = self.corpus.readline().split()
        if self.line_truncate:
            tokens = tokens[:self.line_truncate]

        _words, _feats, feat_count = TextDataset.extract_text_features(tokens)
        self.n_feats = feat_count

        # Rewind to the saved offset.
        self.corpus.seek(start_offset)
        return self.n_feats
예제 #3
0
    def _example_dict_iter(self, line):
        """
        Turn one corpus line into an example dict.

        The line is whitespace-split, optionally truncated to
        `self.line_truncate` tokens, and split into words and features.
        The dict maps `self.side` to the words, "indices" to
        `self.line_index`, and `<side>_feat_<j>` to the j-th feature
        when features are present.
        """
        tokens = line.split()
        if self.line_truncate:
            tokens = tokens[:self.line_truncate]

        words, feats, n_feats = TextDataset.extract_text_features(tokens)
        example = {self.side: words, "indices": self.line_index}
        if not feats:
            return example

        # All examples must have same number of features.
        aeq(self.n_feats, n_feats)

        key_prefix = self.side + "_feat_"
        for j, feat in enumerate(feats):
            example[key_prefix + str(j)] = feat
        return example
예제 #4
0
def build_dataset(fields, data_type, src_path, tgt_path, src_dir=None,
                  src_seq_length=0, tgt_seq_length=0,
                  src_seq_length_trunc=0, tgt_seq_length_trunc=0,
                  dynamic_dict=True, sample_rate=0,
                  window_size=0, window_stride=0, window=None,
                  normalize_audio=True, use_filter_pred=True):
    """
    Build a dataset of the kind selected by `data_type`.

    Args:
        fields: passed through unchanged to the dataset constructor.
        data_type: one of 'text', 'img' or 'audio'.
        src_path: source corpus path.
        tgt_path: target corpus path (always treated as text).
        src_dir: forwarded to the source example maker.
        src_seq_length: forwarded to `TextDataset` (text only).
        tgt_seq_length: forwarded to the dataset constructor.
        src_seq_length_trunc: truncation argument for the source side.
        tgt_seq_length_trunc: truncation argument for the target side.
        dynamic_dict: forwarded to `TextDataset` (text only).
        sample_rate, window_size, window_stride, window, normalize_audio:
            forwarded to the source example maker and to `AudioDataset`.
        use_filter_pred: forwarded to the dataset constructor.

    Returns:
        A `TextDataset`, `ImageDataset` or `AudioDataset`.

    Raises:
        ValueError: if `data_type` is not one of the supported values.
    """
    # Hide this import inside to avoid circular dependency problem.
    from onmt.io import TextDataset, ImageDataset, AudioDataset

    # Build src/tgt examples iterator from corpus files, also extract
    # number of features. For all data types, the tgt side corpus is
    # in form of text.
    src_examples_iter, num_src_feats = \
        _make_examples_nfeats_tpl(data_type, src_path, src_dir,
                                  src_seq_length_trunc, sample_rate,
                                  window_size, window_stride,
                                  window, normalize_audio)

    tgt_examples_iter, num_tgt_feats = \
        _make_text_examples_nfeats_tpl(tgt_path, tgt_seq_length_trunc, "tgt")

    if data_type == 'text':
        dataset = TextDataset(fields, src_examples_iter, tgt_examples_iter,
                              num_src_feats, num_tgt_feats,
                              src_seq_length=src_seq_length,
                              tgt_seq_length=tgt_seq_length,
                              dynamic_dict=dynamic_dict,
                              use_filter_pred=use_filter_pred)

    elif data_type == 'img':
        dataset = ImageDataset(fields, src_examples_iter, tgt_examples_iter,
                               num_src_feats, num_tgt_feats,
                               tgt_seq_length=tgt_seq_length,
                               use_filter_pred=use_filter_pred)

    elif data_type == 'audio':
        dataset = AudioDataset(fields, src_examples_iter, tgt_examples_iter,
                               num_src_feats, num_tgt_feats,
                               tgt_seq_length=tgt_seq_length,
                               sample_rate=sample_rate,
                               window_size=window_size,
                               window_stride=window_stride,
                               window=window,
                               normalize_audio=normalize_audio,
                               use_filter_pred=use_filter_pred)

    else:
        # Without this branch `dataset` would be unbound at the return
        # below and surface as an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported data_type %r; expected 'text', 'img' or 'audio'"
            % (data_type,))

    return dataset
예제 #5
0
    def textDataFromString(data, truncate, side):
        """
        Yield one `(example_dict, n_feats)` pair per line of the string
        `data`.

        Each line is stripped, whitespace-split and, when `truncate` is
        truthy, cut to its first `truncate` tokens. The example dict maps
        `side` to the words, "indices" to the zero-based line number, and
        `<side>_feat_<j>` to the j-th feature when features are present.
        """
        with io.StringIO(data) as buffer:
            for line_no, raw_line in enumerate(buffer):
                tokens = raw_line.strip().split()
                if truncate:
                    tokens = tokens[:truncate]

                extracted = TextDataset.extract_text_features(tokens)
                words, feats, n_feats = extracted

                example = {side: words, "indices": line_no}
                if feats:
                    key_prefix = side + "_feat_"
                    for j, feat in enumerate(feats):
                        example[key_prefix + str(j)] = feat
                yield example, n_feats
예제 #6
0
    def get_num_features(corpus_file, side):
        """
        Read the first line of a corpus file and report how many
        features it carries (all lines are expected to carry the same
        number). Both sides of a text corpus are plain text, so the
        procedure is the same for either side.

        Args:
            corpus_file (str): path of the file to inspect.
            side (str): 'src' or 'tgt'.

        Returns:
            number of features on `side`.
        """
        with codecs.open(corpus_file, "r", "utf-8") as handle:
            first_tokens = handle.readline().strip().split()

        # Only the feature count of the extracted triple is needed.
        _words, _feats, num_feats = \
            TextDataset.extract_text_features(first_tokens)
        return num_feats
예제 #7
0
    def build_example_from_pair(truncate,
                                side,
                                i,
                                line,
                                line_t,
                                with_pos=False):
        """
        Build one example from a source line and its template line.

        The source line is whitespace-split and, when `truncate` is
        truthy, cut to its first `truncate` tokens. The surviving tokens,
        `TEMPLATE_SPLITER` and `line_t` are joined with spaces and
        re-split before feature extraction.

        Returns:
            `(example_dict, n_feats)`. The dict maps `side` to the words
            and "indices" to `i`. It also holds 'spliter_pos' (the number
            of source tokens kept) when `with_pos` is set, and
            `<side>_feat_<j>` entries when features are present.
        """
        src_tokens = line.split()
        if truncate:
            src_tokens = src_tokens[:truncate]
        # Number of source tokens that precede the splitter.
        n_src_tokens = len(src_tokens)

        joined = ' '.join(src_tokens + [TEMPLATE_SPLITER, line_t])
        all_tokens = joined.split()

        words, feats, n_feats = TextDataset.extract_text_features(all_tokens)
        example = {side: words, "indices": i}
        if with_pos:
            example['spliter_pos'] = n_src_tokens
        if feats:
            key_prefix = side + "_feat_"
            for j, feat in enumerate(feats):
                example[key_prefix + str(j)] = feat
        return example, n_feats
예제 #8
0
    def translate(self, in_text, partial_decode=None, attn_overwrite=None,
                  k=5, attn=None, dump_data=False, roundTo=5):
        """
        Translate a list of sentences as one batch and collect the
        per-token internals returned by the translator.

        Args:
            in_text: list of strings, one source sentence each.
            partial_decode: list of strings; each is whitespace-split and
                mapped through the target vocabulary to ids, then handed
                to `translate_batch` as `partial`. Defaults to an empty
                list.
            attn_overwrite: handed to `translate_batch` unchanged.
                Defaults to an empty list.
            k: int, number of top translations to return. The translator's
                `n_best` is set to `k` and its beam size is raised to `k`
                if it is smaller.
            attn: not used by this method.
            dump_data: when true, write the source/target vocabularies and
                the first encoder/decoder embedding tables to
                "s2s/src.dict", "s2s/tgt.dict" and "s2s/embs.h5".
            roundTo: number of decimal digits used when rounding states,
                attention values and context vectors.

        Returns:
            dict mapping each sentence's index within the batch to a dict
            with keys 'encoder', 'decoder', 'attn', 'scores', 'beam' and
            'beam_trace'.
        """
        # Default lists are created per call: a mutable default would be
        # shared between calls, and attn_overwrite is passed on to the
        # translator.
        if partial_decode is None:
            partial_decode = []
        if attn_overwrite is None:
            attn_overwrite = []

        def rr(values):
            # Round every element to `roundTo` decimal digits.
            return [round(value, roundTo) for value in values]

        def convert_to_py(entry):
            # Convert one beam entry into plain Python values.
            return {"pred": entry['pred'].item(),
                    "score": entry['score'].item(),
                    "state": rr([s.item() for s in entry['state']])}

        # Set batch size to number of requested translations
        self.opt.batch_size = len(in_text)

        if dump_data:
            # Code to extract the source and target dict
            with open("s2s/src.dict", 'w') as f:
                for w, ix in self.translator.fields['src'].vocab.stoi.items():
                    f.write(str(ix) + " " + w + "\n")
            with open("s2s/tgt.dict", 'w') as f:
                for w, ix in self.translator.fields['tgt'].vocab.stoi.items():
                    f.write(str(ix) + " " + w + "\n")
            with h5py.File("s2s/embs.h5", 'w') as f:
                model = self.translator.model
                f.create_dataset(
                    "encoder",
                    data=model.encoder.embeddings.emb_luts[
                        0].weight.data.numpy())
                f.create_dataset(
                    "decoder",
                    data=model.decoder.embeddings.emb_luts[
                        0].weight.data.numpy())

        # Build the source examples directly from the in-memory text
        # instead of going through a temporary file.
        (src_examples_iter, num_src_feats) = \
            ONMTmodelAPI.make_text_examples_nfeats_tpl('\n'.join(in_text), 0,
                                                       'src')

        data = TextDataset(self.fields, src_examples_iter, None,
                           num_src_feats, 0,
                           src_seq_length=0,
                           tgt_seq_length=0,
                           dynamic_dict=True,
                           use_filter_pred=False)

        # Iterating over the single batch... torchtext requirement
        test_data = onmt.io.OrderedIterator(
            dataset=data, device=self.opt.gpu,
            batch_size=self.opt.batch_size, train=False, sort=False,
            sort_within_batch=True,
            shuffle=False)

        # set n_best in translator
        self.translator.n_best = k

        # Increase Beam size if asked for large k
        if self.translator.beam_size < k:
            self.translator.beam_size = k

        # Builder used to convert translation to text
        # NOTE(review): the builder receives self.opt.n_best rather than k;
        # confirm this is intended.
        builder = onmt.translate.TranslationBuilder(
            data, self.translator.fields,
            self.opt.n_best, self.opt.replace_unk, self.opt.tgt)

        # Convert partial decode into valid input to decoder
        print("partial:", partial_decode)
        vocab = self.fields["tgt"].vocab
        partial = [[vocab.stoi[tok] for tok in p.split()]
                   for p in partial_decode]

        reply = {}

        # Only has one batch, but indexing does not work
        for batch in test_data:
            print(attn_overwrite, 'over')
            batch_data = self.translator.translate_batch(
                batch, data, return_states=True,
                partial=partial, attn_overwrite=attn_overwrite)
            translations = builder.from_batch(batch_data)
            # iterates over items in batch
            for transIx, trans in enumerate(translations):
                context = batch_data['context'][:, transIx, :]
                print(trans.pred_sents)
                res = {}

                # Fill encoder Result
                encoderRes = []
                for token, state in zip(in_text[transIx].split(), context):
                    encoderRes.append({'token': token,
                                       'state': rr(list(state.data))
                                       })
                res['encoder'] = encoderRes

                # Fill decoder Result
                decoderRes = []
                attnRes = []
                for ix, p in enumerate(trans.pred_sents[:k]):
                    if p:
                        topIx = []
                        topIxAttn = []
                        target_states = \
                            batch_data["target_states"][transIx][ix]
                        target_cstar = \
                            batch_data['target_cstar'][transIx][ix]
                        # `step_attn` avoids shadowing the `attn` parameter.
                        for token, step_attn, state, cstar in zip(
                                p, trans.attns[ix],
                                target_states, target_cstar):
                            currentDec = {}
                            currentDec['token'] = token
                            currentDec['state'] = rr(list(state.data))
                            currentDec['cstar'] = rr(list(cstar.data))
                            topIx.append(currentDec)
                            topIxAttn.append(rr(list(step_attn)))
                        decoderRes.append(topIx)
                        attnRes.append(topIxAttn)
                res['scores'] = list(np.array(trans.pred_scores))[:k]
                res['decoder'] = decoderRes
                res['attn'] = attnRes
                res['beam'] = [[convert_to_py(entry) for entry in t]
                               for t in batch_data['beam'][transIx]]
                res['beam_trace'] = batch_data['beam_trace'][transIx]
                reply[transIx] = res
        return reply
예제 #9
0
def build_dataset(fields,
                  data_type,
                  src_path,
                  tgt_path,
                  src_dir=None,
                  src_seq_length=0,
                  tgt_seq_length=0,
                  src_seq_length_trunc=0,
                  tgt_seq_length_trunc=0,
                  dynamic_dict=True,
                  sample_rate=0,
                  window_size=0,
                  window_stride=0,
                  window=None,
                  normalize_audio=True,
                  use_filter_pred=True):
    """
    Build a dataset of the kind selected by `data_type`.

    Args:
        fields: passed through unchanged to the dataset constructor.
        data_type: one of 'text', 'img' or 'audio'.
        src_path: source corpus path.
        tgt_path: target corpus path (always read as text).
        src_dir: forwarded to the image/audio readers.
        src_seq_length: forwarded to `TextDataset` (text only).
        tgt_seq_length: forwarded to the dataset constructor.
        src_seq_length_trunc: truncation argument for the text source side.
        tgt_seq_length_trunc: truncation argument for the target side.
        dynamic_dict: forwarded to `TextDataset` (text only).
        sample_rate, window_size, window_stride, window, normalize_audio:
            forwarded to the audio reader and to `AudioDataset`.
        use_filter_pred: forwarded to the dataset constructor.

    Returns:
        A `TextDataset`, `ImageDataset` or `AudioDataset`.

    Raises:
        ValueError: if `data_type` is not one of the supported values.
    """
    # Hide this import inside to avoid circular dependency problem.
    from onmt.io import TextDataset, ImageDataset, AudioDataset

    if data_type == 'text':
        src_examples_iter, num_src_feats = \
            _make_examples_numfeats_tpl(src_path, src_seq_length_trunc, "src")

        tgt_examples_iter, num_tgt_feats = \
            _make_examples_numfeats_tpl(tgt_path, tgt_seq_length_trunc, "tgt")

        dataset = TextDataset(fields,
                              src_examples_iter,
                              tgt_examples_iter,
                              num_src_feats,
                              num_tgt_feats,
                              src_seq_length=src_seq_length,
                              tgt_seq_length=tgt_seq_length,
                              dynamic_dict=dynamic_dict,
                              use_filter_pred=use_filter_pred)

    elif data_type == 'img':
        src_examples_iter = _read_img_file(src_path, src_dir, "src")
        num_src_feats = 0  # Source side(image) has no features.

        tgt_examples_iter, num_tgt_feats = \
            _make_examples_numfeats_tpl(tgt_path, tgt_seq_length_trunc, "tgt")

        dataset = ImageDataset(fields,
                               src_examples_iter,
                               tgt_examples_iter,
                               num_src_feats,
                               num_tgt_feats,
                               tgt_seq_length=tgt_seq_length,
                               use_filter_pred=use_filter_pred)

    elif data_type == 'audio':
        src_examples_iter = _read_audio_file(src_path, src_dir, "src",
                                             sample_rate, window_size,
                                             window_stride, window,
                                             normalize_audio)
        num_src_feats = 0  # Source side(audio) has no features.

        tgt_examples_iter, num_tgt_feats = \
            _make_examples_numfeats_tpl(tgt_path, tgt_seq_length_trunc, "tgt")

        dataset = AudioDataset(fields,
                               src_examples_iter,
                               tgt_examples_iter,
                               num_src_feats,
                               num_tgt_feats,
                               tgt_seq_length=tgt_seq_length,
                               sample_rate=sample_rate,
                               window_size=window_size,
                               window_stride=window_stride,
                               window=window,
                               normalize_audio=normalize_audio,
                               use_filter_pred=use_filter_pred)

    else:
        # Without this branch `dataset` would be unbound at the return
        # below and surface as an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported data_type %r; expected 'text', 'img' or 'audio'"
            % (data_type,))

    return dataset