Exemplo n.º 1
0
    def translate(self, src_data_iter, tgt_data_iter, structure_iter, batch_size, out_file=None):
        """Decode every example and optionally write the translations to a file.

        A dataset is built from the input iterators, batched with
        ``OrderedIterator`` and decoded batch by batch with
        ``self.translate_batch``.  Each source/output pair is printed as soon
        as it is decoded.  The iterator is created with ``sort=True``, so
        batches need not arrive in input order; translations are therefore
        collected by ``batch.indices`` and written in ascending index order
        only after all batches have been decoded.

        Args:
            src_data_iter: source-side data, forwarded to ``build_dataset``.
            tgt_data_iter: target-side data, forwarded to ``build_dataset``.
            structure_iter: structure data, forwarded to ``build_dataset``.
            batch_size: batch size handed to ``OrderedIterator``.
            out_file: optional object exposing ``write``; when given, one
                translation per line is written to it.

        Returns:
            None.
        """
        data = build_dataset(self.fields, src_data_iter, tgt_data_iter, None, structure_iter,
                             None, None, use_filter_pred=False)

        cur_device = "cuda" if self.cuda else "cpu"

        # NOTE(review): shuffle=True is passed together with train=False and
        # sort=True.  Whether it has any effect depends on OrderedIterator,
        # which is defined elsewhere -- confirm before changing it.
        data_iter = OrderedIterator(
            dataset=data, device=cur_device,
            batch_size=batch_size, train=False, sort=True,
            sort_within_batch=True, shuffle=True)

        start_time = time.time()
        print("Begin decoding ...")
        batch_count = 0
        # Maps the example index to its translation.
        translation_by_index = {}

        for batch in data_iter:
            # Shapes recorded by the original author for one batch of 30:
            #   batch.src       -> (LongTensor 4x30, LongTensor 30)
            #   batch.indices   -> LongTensor 30
            #   batch.structure -> LongTensor 30x4x4
            # The scores returned by translate_batch are not used here.
            hyps, _scores = self.translate_batch(batch)
            assert len(batch) == len(hyps)

            # Convert the index tensor to plain ints once per batch instead of
            # comparing/indexing with tensor elements for every example.
            example_indices = batch.indices.data.tolist()
            # batch.src[0] is transposed so that iterating yields one row per
            # example, in step with hyps.
            src_rows = batch.src[0].transpose(0, 1)
            for index, src_idx_seq, tran_idx_seq in zip(example_indices, src_rows, hyps):
                src = ' '.join(self.build_tokens(src_idx_seq, side='src'))
                tran = ' '.join(self.build_tokens(tran_idx_seq, side='tgt'))
                translation_by_index[index] = tran
                print("SOURCE: " + src + "\nOUTPUT: " + tran + "\n")
            batch_count += 1
            print("batch: " + str(batch_count) + "...")

        if out_file is not None:
            # Line i always corresponds to example index i; an index that was
            # never produced by the iterator yields an empty line.
            last_index = max(translation_by_index, default=-1)
            for index in range(last_index + 1):
                out_file.write(translation_by_index.get(index, "") + '\n')
        print('Decoding took %.1f minutes ...' % (float(time.time() - start_time) / 60.))
Exemplo n.º 2
0
    def translate(self,
                  src_data_iter,
                  tgt_data_iter,
                  batch_size,
                  out_file=None):
        """Decode every example batch by batch, optionally writing the results.

        A dataset is built from the source/target iterators, batched with
        ``OrderedIterator`` and decoded with ``self.translate_batch``.  When
        ``out_file`` is given, the translations of each batch are put back
        into example-index order and written one per line before the next
        batch is decoded.

        Args:
            src_data_iter: source-side data, forwarded to ``build_dataset``.
            tgt_data_iter: target-side data, forwarded to ``build_dataset``.
            batch_size: batch size handed to ``OrderedIterator``.
            out_file: optional object exposing ``write``.

        Returns:
            None.
        """
        # Original author's note: data yields one example at a time, each
        # carrying example.indices and example.src.
        data = build_dataset(self.fields,
                             src_data_iter=src_data_iter,
                             tgt_data_iter=tgt_data_iter,
                             use_filter_pred=False)

        def sort_translation(indices, translation):
            # indices is a 1-D tensor of target positions, translation is a
            # list; translation[i] is placed at position indices[i].
            ordered_transalation = [None] * len(translation)
            for i, index in enumerate(indices):
                ordered_transalation[index] = translation[i]
            return ordered_transalation

        if self.cuda:
            cur_device = "cuda"
        else:
            cur_device = "cpu"

        data_iter = OrderedIterator(dataset=data,
                                    device=cur_device,
                                    batch_size=batch_size,
                                    train=False,
                                    sort=False,
                                    sort_within_batch=True,
                                    shuffle=False)
        start_time = time.time()
        print("Begin decoding ...")
        # Number of sentences decoded so far.  This must be an int: the
        # previous "idx = 0," (trailing comma) created the tuple (0,), which
        # made "idx += len(batch)" raise TypeError.
        # Original author's note: the src rows of the batches here are not
        # aligned in length.
        idx = 0
        for batch in data_iter:
            # Original author's observations from one run (not guaranteed):
            #   batch.src[0] had shape (27, batch_size);
            #   hyps was an array of size (batch_size, 4) and scores a 1-D
            #   array of length batch_size, i.e. every sentence was
            #   translated into 4 words.
            # The batch is not iterated below; its fields are read directly.
            hyps, scores = self.translate_batch(batch)
            assert len(batch) == len(hyps)
            transtaltion = []
            for idx_seq, score in zip(hyps, scores):
                words = self.build_tokens(idx_seq, side='tgt')
                tran = ' '.join(words)
                transtaltion.append(tran)
            if out_file is not None:
                # Subtracting the running count turns the example indices
                # into positions inside this batch.
                # NOTE(review): this assumes batches cover consecutive index
                # ranges in input order (sort=False, shuffle=False) -- confirm
                # against OrderedIterator.
                transtaltion = sort_translation(batch.indices.data - idx,
                                                transtaltion)
                for tran in transtaltion:
                    out_file.write(tran + '\n')
            idx += len(batch)
            print("sents " + str(idx) + "...")
        print('Decoding took %.1f minutes ...' %
              (float(time.time() - start_time) / 60.))
Exemplo n.º 3
0
    def translate(self,
                  src_data_iter,
                  tgt_data_iter,
                  batch_size,
                  out_file=None):
        """Translate all examples, streaming the results to ``out_file``.

        The source/target iterators are wrapped into a dataset, batched by
        ``OrderedIterator`` and decoded with ``self.translate_batch``.  If
        ``out_file`` is supplied, each batch's translations are rearranged
        according to ``batch.indices`` (offset by the number of sentences
        already decoded) and written one per line.  Progress and total
        decoding time are printed.

        Args:
            src_data_iter: source-side data, forwarded to ``build_dataset``.
            tgt_data_iter: target-side data, forwarded to ``build_dataset``.
            batch_size: batch size handed to ``OrderedIterator``.
            out_file: optional object exposing ``write``.

        Returns:
            None.
        """
        dataset = build_dataset(self.fields,
                                src_data_iter=src_data_iter,
                                tgt_data_iter=tgt_data_iter,
                                use_filter_pred=False)

        device_name = "cuda" if self.cuda else "cpu"
        batches = OrderedIterator(dataset=dataset,
                                  device=device_name,
                                  batch_size=batch_size,
                                  train=False,
                                  sort=False,
                                  sort_within_batch=True,
                                  shuffle=False)

        started_at = time.time()
        print("Begin decoding ...")
        sentences_done = 0
        for batch in batches:
            hyps, scores = self.translate_batch(batch)
            assert len(batch) == len(hyps)

            # One space-joined target sentence per (hypothesis, score) pair.
            decoded = [' '.join(self.build_tokens(hyp, side='tgt'))
                       for hyp, _ in zip(hyps, scores)]

            if out_file is not None:
                # Position of every decoded sentence relative to the start
                # of this batch.
                offsets = batch.indices.data - sentences_done
                reordered = [None] * len(decoded)
                for position, offset in enumerate(offsets):
                    reordered[offset] = decoded[position]
                for sentence in reordered:
                    out_file.write(sentence + '\n')

            sentences_done += len(batch)
            print("sents " + str(sentences_done) + "...")

        elapsed_minutes = float(time.time() - started_at) / 60.
        print('Decoding took %.1f minutes ...' % elapsed_minutes)
Exemplo n.º 4
0
  def translate(self, src_data_iter, tgt_data_iter, batch_size, out_file=None):
    """Decode every example and optionally write the translations to a file.

    A dataset is built from the source/target iterators, batched with
    ``OrderedIterator`` and decoded batch by batch with
    ``self.translate_batch``.  Every decoded sentence is printed (the source
    side is printed as the fixed placeholder "Three-modal").  Translations
    are collected by ``batch.indices`` and, when ``out_file`` is given,
    written in ascending index order after all batches have been decoded.

    Args:
      src_data_iter: source-side data, forwarded to ``build_dataset``.
      tgt_data_iter: target-side data, forwarded to ``build_dataset``.
      batch_size: batch size handed to ``OrderedIterator``.
      out_file: optional object exposing ``write``; when given, one
        translation per line is written to it.

    Returns:
      None.
    """
    data = build_dataset(self.fields,
                         src_data_iter=src_data_iter,
                         tgt_data_iter=tgt_data_iter,
                         use_filter_pred=False)

    cur_device = "cuda" if self.cuda else "cpu"

    data_iter = OrderedIterator(
      dataset=data, device=cur_device,
      batch_size=batch_size, train=False, sort=False,
      sort_within_batch=False, shuffle=False)
    start_time = time.time()
    print("Begin decoding ...")
    batch_count = 0
    # Maps the example index to its translation.
    translation_by_index = {}
    for batch in data_iter:
      # The scores returned by translate_batch are not used here.
      hyps, _scores = self.translate_batch(batch)
      assert len(batch) == len(hyps)

      # Convert the index tensor to plain ints once per batch instead of
      # comparing/indexing with tensor elements for every example.
      example_indices = batch.indices.data.tolist()
      for index, tran_idx_seq in zip(example_indices, hyps):
        tran = ' '.join(self.build_tokens(tran_idx_seq, side='tgt'))
        translation_by_index[index] = tran
        print("SOURCE: " + "Three-modal" + "\nOUTPUT: " + tran + "\n")
      batch_count += 1
      print("batch: " + str(batch_count) + "...")

    if out_file is not None:
      # Line i always corresponds to example index i; an index that was
      # never produced by the iterator yields an empty line.
      last_index = max(translation_by_index, default=-1)
      for index in range(last_index + 1):
        out_file.write(translation_by_index.get(index, "") + '\n')
    print('Decoding took %.1f minutes ...'%(float(time.time() - start_time) / 60.))