def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)
    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, src_reader,
                                             tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
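All of these variants lean on the same count_features helper from OpenNMT-py's legacy preprocess.py. A minimal sketch of what it does, assuming it matches the upstream implementation: it reads the first token of the corpus file and counts its ￨-delimited extra features.

import codecs

def count_features(path):
    # tokens are whitespace-delimited; features within a token are separated
    # by the '￨' character, e.g. "word￨POS￨lemma" carries 2 features
    with codecs.open(path, "r", "utf-8") as f:
        first_tok = f.readline().split(None, 1)[0]
        return len(first_tok.split(u"￨")) - 1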
Example #2
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    if not opt.overwrite:
        check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = 0
    for src in opt.train_src:
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
    logger.info(" * number of source features: %d." % src_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, opt)

    if opt.valid_src:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, opt)
Example #3
def preprocess(opt):
    # validate the arguments
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0

    fields = inputters.get_fields(
        src_nfeats,
        tgt_nfeats,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)
    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset(
        'train', fields, src_reader, tgt_reader, align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset(
            'valid', fields, src_reader, tgt_reader, align_reader, opt)
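For reference, the str2reader lookup used throughout these examples maps a data-type name to a DataReader class. In stock OpenNMT-py 1.x the mapping is roughly the sketch below; the "lattice" type used in Example #7 appears to be a fork-specific addition.

from onmt.inputters.text_dataset import TextDataReader
from onmt.inputters.image_dataset import ImageDataReader
from onmt.inputters.audio_dataset import AudioDataReader

# approximate upstream mapping (forks extend this with custom types)
str2reader = {"text": TextDataReader,
              "img": ImageDataReader,
              "audio": AudioDataReader}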
Example #4
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src[0]) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt[0])  # tgt always text so far
    if len(opt.train_src) > 1 and opt.data_type == 'text':
        for src, tgt in zip(opt.train_src[1:], opt.train_tgt[1:]):
            assert src_nfeats == count_features(src),\
                "feature count of %s does not match the "\
                "other source datasets" % src
            assert tgt_nfeats == count_features(tgt),\
                "feature count of %s does not match the "\
                "other target datasets" % tgt
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    if opt.disable_eos_sampling:
        eos_token = "<blank>"
        logger.info("Using NO eos token")
    else:
        eos_token = "</s>"
        logger.info("Using standard eos token")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc,
                                  eos=eos_token)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader, align_reader,
                       opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader,
                           align_reader, opt)
Example #5
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    #check_existing_pt_files(opt)

    init_logger(opt.log_file)
    #if not os.path.exists(os.path.dirname(opt.save_data)):
    #    os.makedirs(os.path.dirname(opt.save_data))
    #    logger.info("Creating dirs..."+os.path.dirname(opt.save_data))
    logger.info("Extracting features...")
    fields = get_fields(opt)
    logger.info("Building & saving training data...")
    build_save_dataset(opt, fields)
Example #6
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    # this variant adds a third "tt" stream whose features mirror the target's
    tt_nfeats = tgt_nfeats
    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tt_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tt_reader = inputters.str2reader["text"].from_opt(opt)

    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tt_reader, tgt_reader,
                       align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tt_reader, tgt_reader,
                           align_reader, opt)
Example #7
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    confnet_nfeats = 0
    for src, tgt, cnet in zip(opt.train_src, opt.train_tgt, opt.train_confnet):
        src_nfeats += count_features(src) if opt.data_type == 'text' or opt.data_type == 'lattice' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
        # feature counting for confnet/lattice inputs is disabled for now,
        # so confnet_nfeats stays 0
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info(" * number of confnet features: %d." % confnet_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  confnet_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  ans_truncate=opt.src_seq_length_trunc,
                                  ques_truncate=opt.confnet_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)
    ans_reader = inputters.str2reader["text"].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)
    ques_reader = inputters.str2reader["lattice"].from_opt(opt)
    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, ques_reader, ans_reader, tgt_reader,
                       align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, ques_reader, ans_reader,
                           tgt_reader, align_reader, opt)
Example #8
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(
        opt.train_src) if opt.data_type == 'text' else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far

    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    if len(opt.src_vocab) > 0:
        assert len(opt.src_vocab) == len(opt.train_src), \
            "a src vocab must be provided for each dataset " \
            "when using your own vocabularies"
    for i, (train_src,
            train_tgt) in enumerate(zip(opt.train_src, opt.train_tgt)):
        valid_src = opt.valid_src[i]
        valid_tgt = opt.valid_tgt[i]
        logger.info("Working on %d dataset..." % i)
        logger.info("Building `Fields` object...")
        fields = inputters.get_fields(opt.data_type,
                                      src_nfeats,
                                      tgt_nfeats,
                                      dynamic_dict=opt.dynamic_dict,
                                      src_truncate=opt.src_seq_length_trunc,
                                      tgt_truncate=opt.tgt_seq_length_trunc)

        src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
        tgt_reader = inputters.str2reader["text"].from_opt(opt)

        logger.info("Building & saving training data...")
        train_dataset_files = build_save_dataset('train', fields, src_reader,
                                                 tgt_reader, opt, i, train_src,
                                                 train_tgt, valid_src,
                                                 valid_tgt)

        if opt.valid_src and opt.valid_tgt:
            logger.info("Building & saving validation data...")
            build_save_dataset('valid', fields, src_reader, tgt_reader, opt, i,
                               train_src, train_tgt, valid_src, valid_tgt)

        logger.info("Building & saving vocabulary...")
        build_save_vocab(train_dataset_files, fields, opt, i)
Example #9
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    print("haha")
    logger.info("haha")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    logger.info("please...")
    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader, align_reader,
                       opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader,
                           align_reader, opt)
Example #10
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    cue_nfeats = 0
    tgt_nfeats = 0
    for src, cue, tgt in zip(opt.train_src, opt.train_cue, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        cue_nfeats += count_features(cue) if opt.data_type == 'text' else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info("* number of knowledge features: %d" % cue_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  cue_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  cue_truncate=opt.cue_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    cue_reader = inputters.str2reader["text"].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, cue_reader, tgt_reader,
                       opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, cue_reader, tgt_reader,
                           opt)
Example #11
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    # TODO: the get_fields API needs to be updated for pivot language pairs
    fields = get_fields(opt.data_type,
                        src_nfeats,
                        tgt_nfeats,
                        dynamic_dict=opt.dynamic_dict,
                        with_align=opt.train_align[0] is not None,
                        src_truncate=opt.src_seq_length_trunc,
                        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    # TODO: the build_save_dataset API needs to be updated to support a pivot language pair
    build_save_dataset('train', fields, src_reader, tgt_reader, align_reader,
                       opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader,
                           align_reader, opt)
Example #12
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    if not opt.overwrite:
        check_existing_pt_files(opt)

    init_logger(opt.log_file)

    shutil.copy2(opt.config, os.path.dirname(opt.log_file))
    logger.info(opt)
    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader[opt.data_type].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)
Example #13
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset(
        'train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #14
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")

    if opt.fixed_vocab:
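        # Note: '\u0120' is the 'Ġ' character GPT-2's byte-level BPE prepends
        # to tokens that follow a space; the rare vocabulary items 'ĠGDDR',
        # 'ĠSHALL' and 'ĠRELE' below appear to be repurposed as EOS/PAD/UNK
        # markers so the fixed GPT-2 vocabulary needs no extra special tokens.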
        tgt_bos = '<|endoftext|>'
        tgt_eos = '\u0120GDDR'
        tgt_pad = '\u0120SHALL'
        tgt_unk = '\u0120RELE'

        if opt.no_spec_src:
            src_pad = None
            src_unk = None
        elif opt.free_src:
            src_pad = '<blank>'
            src_unk = '<unk>'
        else:
            src_pad = '\u0120SHALL'
            src_unk = '\u0120RELE'

    else:
        tgt_bos = '<s>'
        tgt_eos = '</s>'
        tgt_pad = '<blank>'
        tgt_unk = '<unk>'
        src_pad = '<blank>'
        src_unk = '<unk>'

    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc,
                                  src_pad=src_pad,
                                  src_unk=src_unk,
                                  tgt_pad=tgt_pad,
                                  tgt_unk=tgt_unk,
                                  tgt_bos=tgt_bos,
                                  tgt_eos=tgt_eos,
                                  include_ptrs=opt.pointers_file is not None)

    if opt.data_type == 'none':
        src_reader = None
    else:
        src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, src_reader,
                                             tgt_reader, opt)

    if (opt.valid_src or opt.data_type == 'none') and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #15
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    if not opt.overwrite:
        check_existing_pt_files(opt)

    init_logger(opt.log_file)

    shutil.copy2(opt.config, os.path.dirname(opt.log_file))
    logger.info(opt)
    logger.info("Extracting features...")

    # Prepare the document embeddings used to initialize the memory vectors.
    embedder = SentenceTransformer('bert-base-nli-mean-tokens')

    kpcorpus = []
    files_path = [  #'data/keyphrase/json/kp20k/kp20k_train.json',
        'data/keyphrase/json/kp20k/kp20k_valid.json',
        'data/keyphrase/json/kp20k/kp20k_test.json',
        'data/keyphrase/json/inspec/inspec_valid.json',
        'data/keyphrase/json/inspec/inspec_test.json',
        'data/keyphrase/json/krapivin/krapivin_valid.json',
        'data/keyphrase/json/krapivin/krapivin_test.json',
        'data/keyphrase/json/nus/split/nus_valid.json',
        'data/keyphrase/json/nus/split/nus_test.json',
        'data/keyphrase/json/semeval/semeval_valid.json',
        'data/keyphrase/json/semeval/semeval_test.json',
        'data/keyphrase/json/duc/split/duc_valid.json',
        'data/keyphrase/json/duc/split/duc_test.json'
    ]
    for file_path in files_path:
        # each line is a JSON record with 'title' and 'abstract' fields
        with open(file_path, 'r') as f:
            for line in f:
                dic = json.loads(line)
                kpcorpus.append(dic['title'] + ' ' + dic['abstract'])

    num_of_example = len(kpcorpus)
    logger.info("number of examples in corpus: %d" % num_of_example)
    time_a = time.time()
    corpus_embeddings = embedder.encode(kpcorpus)
    logger.info("elapsed time: %.1fs" % (time.time() - time_a))
    alldocs_emb = torch.Tensor(corpus_embeddings)
    torch.save(alldocs_emb, './data/alldocs_emb')
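    # the tensor saved above can be reloaded later with
    # torch.load('./data/alldocs_emb') to initialize the memory vectors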

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader[opt.data_type].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)
Example #16
def main(opt):
    # (optional, disabled) shuffle the parallel train/valid corpora in place
    # before preprocessing, keeping ".unshuffled" backups of the originals

    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    agenda_nfeats = count_features(opt.train_agenda) if opt.train_agenda else 0
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    
    if opt.fixed_vocab:
        tgt_bos = '<|endoftext|>'
        tgt_eos = '\u0120GDDR'
        tgt_pad = '\u0120SHALL'
        tgt_unk = '\u0120RELE'

        if opt.no_spec_src:
            src_pad = None
            src_unk = None
        elif opt.free_src:
            src_pad = '<blank>'
            src_unk = '<unk>'
        else:
            src_pad = '\u0120SHALL'
            src_unk = '\u0120RELE'

    else:
        tgt_bos = '<s>'
        tgt_eos = '</s>'
        tgt_pad = '<blank>'
        tgt_unk = '<unk>'
        src_pad = '<blank>'
        src_unk = '<unk>'

    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        agenda_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc,
        src_pad=src_pad,
        src_unk=src_unk,
        tgt_pad=tgt_pad,
        tgt_unk=tgt_unk,
        tgt_bos=tgt_bos,
        tgt_eos=tgt_eos,
        include_ptrs=opt.pointers_file is not None,
        include_agenda=opt.train_agenda or opt.valid_agenda)
    
    if opt.data_type == 'none':
        readers = [None]
    else:
        readers = [inputters.str2reader[opt.data_type].from_opt(opt)]
    readers.append(inputters.str2reader["text"].from_opt(opt))
    if opt.train_agenda or opt.valid_agenda:
        readers.append(inputters.str2reader["text"].from_opt(opt))

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset(
        'train', fields, readers, opt)

    if (opt.valid_src or opt.data_type == 'none') and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, readers, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
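For completeness, a minimal sketch of the driver that feeds opt into these main(opt)/preprocess(opt) variants, assuming the upstream OpenNMT-py legacy entry-point pattern:

import onmt.opts as opts
from onmt.utils.parse import ArgumentParser

def _get_parser():
    parser = ArgumentParser(description='preprocess.py')
    opts.config_opts(parser)       # -config / -save_config handling
    opts.preprocess_opts(parser)   # -train_src, -train_tgt, -save_data, ...
    return parser

if __name__ == "__main__":
    parser = _get_parser()
    opt = parser.parse_args()
    main(opt)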