def translate(self, src_data_iter, tgt_data_iter, structure_iter, batch_size, out_file=None):
    """Translate every source sentence and optionally write the results.

    Builds a dataset from the source/target/structure iterators, decodes it
    batch by batch with ``self.translate_batch`` and prints every
    source/output pair. The iterator groups examples by length, so each
    translation is stored at the position given by ``batch.indices``. This
    restores the original corpus order before anything is written.

    Args:
        src_data_iter: iterator over source lines.
        tgt_data_iter: iterator over target lines, forwarded to
            ``build_dataset`` unchanged.
        structure_iter: iterator over structure lines.
        batch_size: number of examples per decoding batch.
        out_file: optional writable text file; receives one translation per
            line, in the original input order.
    """
    data = build_dataset(self.fields, src_data_iter, tgt_data_iter, None,
                         structure_iter, None, None, use_filter_pred=False)

    cur_device = "cuda" if self.cuda else "cpu"

    # Length-sorted batching; the original order is restored below through
    # batch.indices, so sorting/shuffling here does not affect the output file.
    data_iter = OrderedIterator(
        dataset=data, device=cur_device,
        batch_size=batch_size, train=False, sort=True,
        sort_within_batch=True, shuffle=True)

    start_time = time.time()
    print("Begin decoding ...")
    batch_count = 0
    all_translation = []

    for batch in data_iter:
        hyps, scores = self.translate_batch(batch)
        assert len(batch) == len(hyps)
        batch_translation = []
        # batch.src[0] is laid out (seq_len, batch) per the torchtext batch
        # dump this code was written against; transpose to walk sentences.
        for src_idx_seq, tran_idx_seq, _ in zip(batch.src[0].transpose(0, 1), hyps, scores):
            src = ' '.join(self.build_tokens(src_idx_seq, side='src'))
            tran = ' '.join(self.build_tokens(tran_idx_seq, side='tgt'))
            batch_translation.append(tran)
            print("SOURCE: " + src + "\nOUTPUT: " + tran + "\n")

        # Store each translation at its original corpus position, padding
        # with "" so that any not-yet-seen earlier index stays addressable.
        for index, tran in zip(batch.indices.data.tolist(), batch_translation):
            if index >= len(all_translation):
                all_translation.extend([""] * (index + 1 - len(all_translation)))
            all_translation[index] = tran

        batch_count += 1
        print("batch: " + str(batch_count) + "...")

    if out_file is not None:
        for tran in all_translation:
            out_file.write(tran + '\n')
    print('Decoding took %.1f minutes ...' % (float(time.time() - start_time) / 60.))
Example #2
0
def build_save_dataset(corpus_type, fields, opt):
    """Build the train or valid dataset and save it under ``opt.save_data``.

    Six parallel files are read for the split: ``opt.<split>_src``,
    ``opt.<split>_tgt``, ``opt.<split>_stgt``, ``opt.<split>_structure``,
    ``opt.<split>_mask`` and ``opt.<split>_relation``.

    Args:
        corpus_type: ``'train'`` or ``'valid'``.
        fields: fields forwarded to ``build_dataset``.
        opt: preprocessing options (paths, ``shard_size``, length limits,
            ``abundancy``, ``save_data`` prefix).

    Returns:
        List of saved ``.pt`` paths. When ``opt.shard_size > 0`` the work is
        delegated to ``build_save_in_shards_using_shards_size``.
    """
    assert corpus_type in ['train', 'valid']

    # Paths in the positional order build_dataset expects after ``fields``.
    if corpus_type == 'train':
        corpora = (opt.train_src, opt.train_tgt, opt.train_stgt,
                   opt.train_structure, opt.train_mask, opt.train_relation)
    else:
        corpora = (opt.valid_src, opt.valid_tgt, opt.valid_stgt,
                   opt.valid_structure, opt.valid_mask, opt.valid_relation)
    (src_corpus, tgt_corpus, stgt_corpus,
     structure_corpus, mask_corpus, relation_corpus) = corpora

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(
            src_corpus, tgt_corpus, stgt_corpus, structure_corpus, mask_corpus,
            relation_corpus, fields, corpus_type, opt)

    # Monolithic dataset: one line iterator per corpus file.
    corpus_iters = [make_text_iterator_from_file(path) for path in corpora]
    dataset = build_dataset(fields,
                            *corpus_iters,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
                            abundancy=opt.abundancy)

    # Fields are stored separately in vocab.pt, so none are saved here.
    dataset.fields = []

    pt_file = "{:s}_{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)
    return [pt_file]
Example #3
0
    def translate(self,
                  src_data_iter,
                  tgt_data_iter,
                  batch_size,
                  out_file=None):
        """Decode ``src_data_iter`` batch by batch and optionally write the output.

        Within each batch the translations are put back in input order using
        ``batch.indices`` before being written to ``out_file`` (one per line).
        A running sentence count is printed after every batch.

        Args:
            src_data_iter: iterator over source lines.
            tgt_data_iter: iterator over target lines, forwarded to
                ``build_dataset``.
            batch_size: number of examples per decoding batch.
            out_file: optional writable text file for the translations.
        """
        # Each example in ``data`` carries ``indices`` and ``src``.
        data = build_dataset(self.fields,
                             src_data_iter=src_data_iter,
                             tgt_data_iter=tgt_data_iter,
                             use_filter_pred=False)

        def sort_translation(indices, translation):
            """Return ``translation`` reordered so item i lands at indices[i]."""
            ordered_translation = [None] * len(translation)
            for i, index in enumerate(indices):
                ordered_translation[index] = translation[i]
            return ordered_translation

        cur_device = "cuda" if self.cuda else "cpu"

        # NOTE(review): with sort=False/shuffle=False this presumably yields
        # consecutive slices of the dataset, so one batch covers the indices
        # [idx, idx + len(batch)); the subtraction below relies on that.
        data_iter = OrderedIterator(dataset=data,
                                    device=cur_device,
                                    batch_size=batch_size,
                                    train=False,
                                    sort=False,
                                    sort_within_batch=True,
                                    shuffle=False)
        start_time = time.time()
        print("Begin decoding ...")
        # Number of sentences decoded so far. Must be a plain int (not the
        # tuple ``0,``) because it is subtracted from the index tensor.
        idx = 0
        for batch in data_iter:
            hyps, scores = self.translate_batch(batch)
            assert len(batch) == len(hyps)
            translation = [' '.join(self.build_tokens(idx_seq, side='tgt'))
                           for idx_seq, _ in zip(hyps, scores)]
            if out_file is not None:
                # Shift corpus-wide indices to batch-local positions and undo
                # the within-batch length sort before writing.
                translation = sort_translation(batch.indices.data - idx,
                                               translation)
                for tran in translation:
                    out_file.write(tran + '\n')
            idx += len(batch)
            print("sents " + str(idx) + "...")
        print('Decoding took %.1f minutes ...' %
              (float(time.time() - start_time) / 60.))
Example #4
0
    def translate(self,
                  src_data_iter,
                  tgt_data_iter,
                  batch_size,
                  out_file=None):
        """Decode ``src_data_iter`` batch by batch.

        If ``out_file`` is given, each batch's translations are reordered via
        ``batch.indices`` and written one per line. A running sentence count
        is printed after every batch, and the total time at the end.
        """
        data = build_dataset(self.fields,
                             src_data_iter=src_data_iter,
                             tgt_data_iter=tgt_data_iter,
                             use_filter_pred=False)

        def sort_translation(indices, translation):
            # Place translation[k] at slot indices[k].
            reordered = [None] * len(translation)
            for pos, slot in enumerate(indices):
                reordered[slot] = translation[pos]
            return reordered

        cur_device = "cuda" if self.cuda else "cpu"

        data_iter = OrderedIterator(dataset=data,
                                    device=cur_device,
                                    batch_size=batch_size,
                                    train=False,
                                    sort=False,
                                    sort_within_batch=True,
                                    shuffle=False)
        start_time = time.time()
        print("Begin decoding ...")
        decoded = 0
        for batch in data_iter:
            hyps, scores = self.translate_batch(batch)
            assert len(batch) == len(hyps)
            sentences = [' '.join(self.build_tokens(seq, side='tgt'))
                         for seq, _score in zip(hyps, scores)]
            if out_file is not None:
                # Shift corpus-wide indices to batch-local slots (assumes the
                # batch covers [decoded, decoded + len(batch)) — TODO confirm).
                local_slots = batch.indices.data - decoded
                for sentence in sort_translation(local_slots, sentences):
                    out_file.write(sentence + '\n')
            decoded += len(batch)
            print("sents " + str(decoded) + "...")
        elapsed = float(time.time() - start_time) / 60.
        print('Decoding took %.1f minutes ...' % elapsed)
Example #5
0
  def translate(self, src_data_iter, tgt_data_iter, batch_size, out_file=None):
    """Translate the source data and optionally write the hypotheses.

    Every hypothesis is printed next to the fixed label "Three-modal" (the
    source side is not turned back into text). Translations are collected at
    the position given by ``batch.indices``. If ``out_file`` is given, they
    are written one per line in that order once decoding finishes.
    """
    data = build_dataset(self.fields,
                         src_data_iter=src_data_iter,
                         tgt_data_iter=tgt_data_iter,
                         use_filter_pred=False)

    # Helper kept from the original; this variant does not call it.
    def sort_translation(indices, translation):
      reordered = [None] * len(translation)
      for pos, slot in enumerate(indices):
        reordered[slot] = translation[pos]
      return reordered

    cur_device = "cuda" if self.cuda else "cpu"
    # No length sorting or shuffling requested (presumably dataset order).
    data_iter = OrderedIterator(
      dataset=data, device=cur_device,
      batch_size=batch_size, train=False, sort=False,
      sort_within_batch=False, shuffle=False)
    start_time = time.time()
    print("Begin decoding ...")
    batch_count = 0
    all_translation = []
    for batch in data_iter:
      hyps, scores = self.translate_batch(batch)
      assert len(batch) == len(hyps)
      batch_translation = []
      src_columns = batch.src[0].transpose(0, 1)
      for _src_seq, tran_seq, _score in zip(src_columns, hyps, scores):
        tran = ' '.join(self.build_tokens(tran_seq, side='tgt'))
        batch_translation.append(tran)
        print("SOURCE: " + "Three-modal" + "\nOUTPUT: " + tran + "\n")
      # Pad with "" placeholders until the target index is addressable.
      for index, tran in zip(batch.indices.data, batch_translation):
        while len(all_translation) <= index:
          all_translation.append("")
        all_translation[index] = tran
      batch_count += 1
      print("batch: " + str(batch_count) + "...")

    if out_file is not None:
      for tran in all_translation:
        out_file.write(tran + '\n')
    elapsed = float(time.time() - start_time) / 60.
    print('Decoding took %.1f minutes ...' % elapsed)
def build_save_dataset(corpus_type, task_type, fields, opt):
    """Build and save the dataset for one split of one task.

    Args:
        corpus_type: ``'train'`` or ``'valid'``.
        task_type: ``'task'`` reads ``opt.<split>_src``/``opt.<split>_tgt``;
            ``'task2'`` reads ``opt.<split>_src2``/``opt.<split>_tgt2``.
        fields: fields forwarded to ``build_dataset``.
        opt: preprocessing options (paths, ``shard_size``, length limits,
            ``save_data`` prefix).

    Returns:
        List of saved ``.pt`` paths. When ``opt.shard_size > 0`` the work is
        delegated to ``build_save_in_shards_using_shards_size``.
    """
    assert corpus_type in ['train', 'valid']
    assert task_type in ['task', 'task2']

    if corpus_type == 'train':
        if task_type == 'task':
            src_corpus, tgt_corpus = opt.train_src, opt.train_tgt
        else:
            src_corpus, tgt_corpus = opt.train_src2, opt.train_tgt2
    else:
        if task_type == 'task':
            src_corpus, tgt_corpus = opt.valid_src, opt.valid_tgt
        else:
            src_corpus, tgt_corpus = opt.valid_src2, opt.valid_tgt2

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(src_corpus, tgt_corpus,
                                                      fields, corpus_type,
                                                      task_type, opt)

    # Monolithic (unsharded) dataset. task_type is forwarded to build_dataset
    # just like the sharded builder does, so both paths produce the same
    # kind of dataset for 'task2'.
    src_iter = make_text_iterator_from_file(src_corpus)
    tgt_iter = make_text_iterator_from_file(tgt_corpus)
    dataset = build_dataset(fields,
                            src_iter,
                            tgt_iter,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
                            task_type=task_type)

    # Fields are stored separately in vocab.pt, so none are saved here.
    dataset.fields = []

    pt_file = "{:s}_{:s}_{:s}.pt".format(opt.save_data, task_type, corpus_type)
    logger.info(" * saving %s %s dataset to %s." %
                (task_type, corpus_type, pt_file))
    torch.save(dataset, pt_file)

    return [pt_file]
Example #7
0
def build_save_dataset(corpus_type, fields, opt):
    """Build and save the train or valid split of the CMU-MOSEI emotion data.

    The splits come from ``data_loader.read_cmumosei_emotion_pkl()`` rather
    than from the ``opt.train_src``/``opt.valid_src`` paths.

    Args:
        corpus_type: ``'train'`` or ``'valid'``.
        fields: fields forwarded to ``build_dataset``.
        opt: preprocessing options (``shard_size``, length limits,
            ``save_data`` prefix).

    Returns:
        List of saved ``.pt`` paths.
    """
    assert corpus_type in ['train', 'valid']
    (X_train, X_valid, _X_test,
     y_train, y_valid, _y_test) = data_loader.read_cmumosei_emotion_pkl()
    src_corpus, tgt_corpus = ((X_train, y_train) if corpus_type == 'train'
                              else (X_valid, y_valid))

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(src_corpus, tgt_corpus,
                                                      fields, corpus_type, opt)

    # Monolithic dataset; sharding is handled by the branch above.
    dataset = build_dataset(fields,
                            make_text_iterator_from_file(src_corpus),
                            make_text_iterator_from_file(tgt_corpus),
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

    # Fields are stored separately in vocab.pt, so none are saved here.
    dataset.fields = []
    pt_file = "{:s}_{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))
    torch.save(dataset, pt_file)
    return [pt_file]
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus, fields,
                                           corpus_type, task_type, opt):
    """Split a parallel source/target corpus into shards and save each one.

    Each shard holds ``opt.shard_size`` line pairs; the last shard may be
    smaller. Shard ``n`` is written to ``<src_corpus>.<n>.txt`` and
    ``<tgt_corpus>.<n>.txt`` and turned into a dataset with ``build_dataset``.
    It is saved to ``<save_data>_<task_type>_<corpus_type>.<n>.pt``, and the
    temporary ``.txt`` shard files are then deleted.

    Returns:
        List of saved ``.pt`` paths, in shard order.

    Raises:
        AssertionError: if the source and target files have different
            numbers of lines.
    """
    # Read both files in full, as UTF-8 (the encoding the shards are written
    # in), and compare the real line counts: zip() would silently drop the
    # tail of the longer file and make the check below meaningless.
    with open(src_corpus, "r", encoding="utf-8") as src_file:
        src_data = src_file.readlines()
    with open(tgt_corpus, "r", encoding="utf-8") as tgt_file:
        tgt_data = tgt_file.readlines()
    if len(src_data) != len(tgt_data):
        raise AssertionError("Source and Target should have the same length")

    # Write the shards and remember exactly which files were produced,
    # instead of rediscovering them with a lexicographic glob (which orders
    # shard 10 before shard 2 and can pick up stale or unrelated *.txt files).
    shard_size = opt.shard_size
    shard_files = []
    for shard_id, start in enumerate(range(0, len(src_data), shard_size)):
        logger.info("Splitting shard %d." % shard_id)
        end = start + shard_size
        src_path = src_corpus + ".{0}.txt".format(shard_id)
        tgt_path = tgt_corpus + ".{0}.txt".format(shard_id)
        # newline="" writes "\n" verbatim on every platform.
        with open(src_path, "w", encoding="utf-8", newline="") as f:
            f.writelines(src_data[start:end])
        with open(tgt_path, "w", encoding="utf-8", newline="") as f:
            f.writelines(tgt_data[start:end])
        shard_files.append((src_path, tgt_path))

    # The full corpus is no longer needed; free it before building datasets.
    del src_data, tgt_data

    ret_list = []
    for index, (src_path, tgt_path) in enumerate(shard_files):
        logger.info("Building shard %d." % index)
        dataset = build_dataset(fields,
                                make_text_iterator_from_file(src_path),
                                make_text_iterator_from_file(tgt_path),
                                src_seq_length=opt.src_seq_length,
                                tgt_seq_length=opt.tgt_seq_length,
                                src_seq_length_trunc=opt.src_seq_length_trunc,
                                tgt_seq_length_trunc=opt.tgt_seq_length_trunc,
                                task_type=task_type)

        pt_file = "{:s}_{:s}_{:s}.{:d}.pt".format(opt.save_data, task_type,
                                                  corpus_type, index)

        # Fields are stored separately in vocab.pt, so none are saved here.
        dataset.fields = []

        logger.info(" * saving %sth %s data shard to %s." %
                    (index, corpus_type, pt_file))
        torch.save(dataset, pt_file)
        ret_list.append(pt_file)

        os.remove(src_path)
        os.remove(tgt_path)
        # Drop this shard's examples before the next one is built.
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return ret_list
Example #9
0
def build_save_dataset(corpus_type, fields, opt):  # corpus_type: train or valid
    """Build and save the dataset for the ``corpus_type`` split.

    Reads the source, target and five structure files of the split, with
    paths ``opt.<split>_src``, ``opt.<split>_tgt`` and
    ``opt.<split>_structure1`` ... ``opt.<split>_structure5``. Only these
    five structure files are consumed, so no ``*_structure6``-``8`` options
    are read.

    Args:
        corpus_type: ``'train'`` or ``'valid'``.
        fields: fields forwarded to ``build_dataset``.
        opt: preprocessing options (paths, ``shard_size``, length limits,
            ``save_data`` prefix).

    Returns:
        List of saved ``.pt`` paths. When ``opt.shard_size > 0`` the work is
        delegated to ``build_save_in_shards_using_shards_size``.
    """
    assert corpus_type in ['train', 'valid']

    # Paths of the source, target and structure files for this split.
    if corpus_type == 'train':
        src_corpus = opt.train_src
        tgt_corpus = opt.train_tgt
        structure_corpora = (opt.train_structure1, opt.train_structure2,
                             opt.train_structure3, opt.train_structure4,
                             opt.train_structure5)
    else:
        src_corpus = opt.valid_src
        tgt_corpus = opt.valid_tgt
        structure_corpora = (opt.valid_structure1, opt.valid_structure2,
                             opt.valid_structure3, opt.valid_structure4,
                             opt.valid_structure5)
    (structure_corpus1, structure_corpus2, structure_corpus3,
     structure_corpus4, structure_corpus5) = structure_corpora

    if opt.shard_size > 0:
        return build_save_in_shards_using_shards_size(
            src_corpus, tgt_corpus, structure_corpus1, structure_corpus2,
            structure_corpus3, structure_corpus4, structure_corpus5, fields,
            corpus_type, opt)

    # Monolithic dataset: one line iterator per file, in the positional
    # order build_dataset expects (src, tgt, structure1..structure5).
    src_iter = make_text_iterator_from_file(src_corpus)
    tgt_iter = make_text_iterator_from_file(tgt_corpus)
    structure_iters = [make_text_iterator_from_file(path)
                       for path in structure_corpora]

    dataset = build_dataset(fields,
                            src_iter,
                            tgt_iter,
                            *structure_iters,
                            src_seq_length=opt.src_seq_length,
                            tgt_seq_length=opt.tgt_seq_length,
                            src_seq_length_trunc=opt.src_seq_length_trunc,
                            tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

    # Fields are stored separately in vocab.pt, so none are saved here.
    dataset.fields = []

    pt_file = "{:s}_{:s}.pt".format(opt.save_data, corpus_type)
    logger.info(" * saving %s dataset to %s." % (corpus_type, pt_file))

    torch.save(dataset, pt_file)

    return [pt_file]
Example #10
0
def build_save_in_shards_using_shards_size(
        src_corpus, tgt_corpus, structure_corpus1, structure_corpus2,
        structure_corpus3, structure_corpus4, structure_corpus5, fields,
        corpus_type, opt):
    """Split the parallel corpus into shards and save each shard as a dataset.

    The source, target and five structure files are parallel: line i of each
    belongs to example i. Every structure line must contain
    ``(number of source tokens + 1) ** 2`` whitespace-separated entries.
    Each shard holds ``opt.shard_size`` lines; the last shard may be smaller.
    Shard ``n`` of every file is written to ``<file>.<n>.txt``, turned into a
    dataset with ``build_dataset`` and saved to
    ``<save_data>_<corpus_type>.<n>.pt``, after which the ``.txt`` shard
    files are deleted.

    Returns:
        List of saved ``.pt`` paths, in shard order.

    Raises:
        AssertionError: if the files have different numbers of lines or a
            structure line has the wrong number of entries.
    """
    # Positional order matches the build_dataset call below.
    corpora = [src_corpus, tgt_corpus, structure_corpus1, structure_corpus2,
               structure_corpus3, structure_corpus4, structure_corpus5]

    # Read every file in full, as UTF-8 (the encoding the shards are
    # written in).
    corpus_lines = []
    for path in corpora:
        with open(path, "r", encoding="utf-8") as corpus_file:
            corpus_lines.append(corpus_file.readlines())

    # Compare the real line counts of ALL files; iterating them with zip()
    # would silently drop the tail of any longer file.
    if len(set(len(lines) for lines in corpus_lines)) != 1:
        raise AssertionError(
            "Source,Target,structure and index should have the same length")

    # Validate structure sizes with explicit raises (asserts vanish under -O).
    src_data = corpus_lines[0]
    structure_data = corpus_lines[2:]
    for line_no, row in enumerate(zip(src_data, *structure_data), 1):
        expected = (len(row[0].split()) + 1) ** 2
        for k, structure_line in enumerate(row[1:], 1):
            actual = len(structure_line.split())
            if actual != expected:
                raise AssertionError(
                    "structure%d line %d has %d entries, expected %d"
                    % (k, line_no, actual, expected))

    # Write the shards and remember exactly which files were produced,
    # instead of rediscovering them with a lexicographic glob (which orders
    # shard 10 before shard 2 and can pick up stale or unrelated *.txt files).
    shard_size = opt.shard_size
    shard_paths = []  # shard_paths[i][j] is shard i of corpora[j]
    for shard_id, start in enumerate(range(0, len(src_data), shard_size)):
        logger.info("Splitting shard %d." % shard_id)
        paths = []
        for corpus, lines in zip(corpora, corpus_lines):
            path = corpus + ".{0}.txt".format(shard_id)
            # newline="" writes "\n" verbatim on every platform.
            with open(path, "w", encoding="utf-8", newline="") as shard_file:
                shard_file.writelines(lines[start:start + shard_size])
            paths.append(path)
        shard_paths.append(paths)

    # The full corpora are no longer needed; free them before building.
    del corpus_lines, src_data, structure_data

    ret_list = []
    for i, paths in enumerate(shard_paths):
        logger.info("Building shard %d." % i)
        # One line iterator per shard file (src, tgt, structure1..5).
        shard_iters = [make_text_iterator_from_file(path) for path in paths]

        dataset = build_dataset(fields,
                                *shard_iters,
                                src_seq_length=opt.src_seq_length,
                                tgt_seq_length=opt.tgt_seq_length,
                                src_seq_length_trunc=opt.src_seq_length_trunc,
                                tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

        pt_file = "{:s}_{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)

        # Fields are stored separately in vocab.pt, so none are saved here.
        dataset.fields = []

        logger.info(" * saving %sth %s data shard to %s." %
                    (i, corpus_type, pt_file))
        torch.save(dataset, pt_file)
        ret_list.append(pt_file)

        for path in paths:
            os.remove(path)
        # Drop this shard's examples before the next one is built.
        del dataset.examples
        gc.collect()
        del dataset
        gc.collect()

    return ret_list  # list of saved .pt file names
Example #11
0
def build_save_in_shards_using_shards_size(src_corpus, tgt_corpus,
                                           structure_corpus, mask_corpus,
                                           relation_corpus, fields,
                                           corpus_type, opt):
    """Split five line-aligned corpora into shards and save each as a dataset.

    The source, target, structure, mask and relation corpora are read in
    full, checked for equal line counts and per-line size consistency, then
    cut into chunks of ``opt.shard_size`` lines (the last chunk holds any
    remainder). Each chunk is written to temporary ``<corpus>.<shard>.txt``
    files, turned into a dataset with ``build_dataset``, saved with
    ``torch.save`` to ``"{save_data}_{corpus_type}.{shard}.pt"``, and the
    temporary text files are then deleted.

    Args:
        src_corpus, tgt_corpus, structure_corpus, mask_corpus,
        relation_corpus: paths to line-aligned UTF-8 text files.
        fields: passed through unchanged to ``build_dataset``.
        corpus_type: label inserted into the output ``.pt`` file names.
        opt: options object; reads ``shard_size``, ``save_data``,
            ``src_seq_length``, ``tgt_seq_length``, ``src_seq_length_trunc``
            and ``tgt_seq_length_trunc``.

    Returns:
        list of str: paths of the saved ``.pt`` files, in shard order.

    Raises:
        ValueError: if ``opt.shard_size`` is not positive.
        AssertionError: if the corpora have different line counts, or a
            line's structure/mask token count does not match the length of
            its source/target line.
    """
    shard_size = opt.shard_size
    if shard_size <= 0:
        raise ValueError(
            "opt.shard_size must be positive, got %r" % (shard_size, ))

    corpora = (src_corpus, tgt_corpus, structure_corpus, mask_corpus,
               relation_corpus)
    columns = _read_parallel_corpora(corpora)
    _check_structure_and_mask_sizes(*columns[:4])

    # Write every shard up front; remember the exact paths instead of
    # re-discovering them with glob (which would pick up stale files and
    # sort "10" before "2").
    shard_files = []
    for shard_id, start in enumerate(range(0, len(columns[0]), shard_size)):
        logger.info("Splitting shard %d." % shard_id)
        paths = tuple("{}.{}.txt".format(corpus, shard_id)
                      for corpus in corpora)
        for path, lines in zip(paths, columns):
            _write_shard_lines(path, lines[start:start + shard_size])
        shard_files.append(paths)
    # The full in-memory copies are no longer needed once shards are on disk.
    del columns

    ret_list = []
    for shard_id, paths in enumerate(shard_files):
        logger.info("Building shard %d." % shard_id)
        src, tgt, structure, mask, relation = paths
        # make_text_iterator_from_file yields the file's lines one at a time.
        dataset = build_dataset(fields,
                                make_text_iterator_from_file(src),
                                make_text_iterator_from_file(tgt),
                                make_text_iterator_from_file(structure),
                                make_text_iterator_from_file(mask),
                                make_text_iterator_from_file(relation),
                                src_seq_length=opt.src_seq_length,
                                tgt_seq_length=opt.tgt_seq_length,
                                src_seq_length_trunc=opt.src_seq_length_trunc,
                                tgt_seq_length_trunc=opt.tgt_seq_length_trunc)

        pt_file = "{:s}_{:s}.{:d}.pt".format(opt.save_data, corpus_type,
                                             shard_id)

        # We save fields in vocab.pt separately, so make it empty.
        dataset.fields = []

        logger.info(" * saving %sth %s data shard to %s." %
                    (shard_id, corpus_type, pt_file))
        torch.save(dataset, pt_file)
        ret_list.append(pt_file)

        # Remove this shard's temporary text files.
        for path in paths:
            os.remove(path)
        del dataset.examples
        del dataset
        gc.collect()

    return ret_list  # list of saved .pt file names


def _read_parallel_corpora(paths):
    """Read each UTF-8 file in *paths* as a list of lines; require equal counts."""
    columns = []
    for path in paths:
        with open(path, "r", encoding="utf-8") as corpus_file:
            columns.append(corpus_file.readlines())
    if len({len(lines) for lines in columns}) > 1:
        # zip() would silently truncate to the shortest file; refuse instead.
        raise AssertionError(
            "Source, target, structure, mask and relation corpora should have "
            "the same number of lines, got %s" %
            [len(lines) for lines in columns])
    return columns


def _check_structure_and_mask_sizes(src_data, tgt_data, structure_data,
                                    mask_data):
    """Require (len(src)+1)**2 structure tokens and len(tgt)**2 mask tokens per line."""
    rows = zip(src_data, tgt_data, structure_data, mask_data)
    for line_no, (src, tgt, structure, mask) in enumerate(rows, 1):
        expected_structure = (len(src.split()) + 1)**2
        if expected_structure != len(structure.split()):
            raise AssertionError(
                "line %d: expected %d structure tokens, got %d" %
                (line_no, expected_structure, len(structure.split())))
        expected_mask = len(tgt.split())**2
        if expected_mask != len(mask.split()):
            raise AssertionError(
                "line %d: expected %d mask tokens, got %d" %
                (line_no, expected_mask, len(mask.split())))


def _write_shard_lines(path, lines):
    """Write *lines* to *path* as UTF-8 without newline translation."""
    with open(path, "w", encoding="utf-8", newline="") as shard_file:
        shard_file.writelines(lines)