def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)
    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, src_reader,
                                             tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
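These main(opt)/preprocess(opt) variants are normally driven by a small command-line wrapper. A minimal sketch of that wrapper, assuming the OpenNMT-py 1.x module layout (onmt.opts and onmt.utils.parse.ArgumentParser); exact names can differ between versions:

from onmt.utils.parse import ArgumentParser
import onmt.opts as opts


def _get_parser():
    # Register OpenNMT-py's config-file and preprocessing options on one parser.
    parser = ArgumentParser(description='preprocess.py')
    opts.config_opts(parser)
    opts.preprocess_opts(parser)
    return parser


if __name__ == "__main__":
    parser = _get_parser()
    opt = parser.parse_args()
    main(opt)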
Example #2
def main():
    opt = parse_args()

    if opt.max_shard_size > 0:
        raise AssertionError("-max_shard_size is deprecated, please use \
                             -shard_size (number of examples) instead.")
    if opt.shuffle > 0:
        raise AssertionError("-shuffle is not implemented, please make sure \
                             you shuffle your data before pre-processing.")

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            "src")
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            "tgt")
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset("train", fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset("valid", fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #3
def main():
    opt = parse_args()
    f = open("data/corpus.json", "r")
    data = f.readlines()
    samples = []
    for i in data:
        samples.append(json.loads(i)["sents"])
    corpus = []
    for i in samples:
        corpus.append(" ".join([" ".join(sent) for sent in i]))
    vectorizer = TfidfVectorizer()
    tfidf = vectorizer.fit(corpus)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, tfidf, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, tfidf, opt)

    logger.info("Building & saving vocabulary...")
    # train_dataset_files = 'data/processed.train.pt'
    build_save_vocab(train_dataset_files, opt.data_type, fields, opt)
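As a usage note, the tfidf object fitted above is a scikit-learn TfidfVectorizer, so new text can later be scored with transform; a minimal, self-contained sketch (the toy corpus here is made up):

from sklearn.feature_extraction.text import TfidfVectorizer

corpus = ["the cat sat on the mat", "dogs chase cats"]
tfidf = TfidfVectorizer().fit(corpus)
# transform returns a scipy sparse matrix of shape (n_docs, vocab_size)
scores = tfidf.transform(["a cat chased a dog"])
print(scores.shape, scores.nnz)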
Example #4
def preprocess(opt):
    # Validate the preprocessing arguments
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0

    fields = inputters.get_fields(
        src_nfeats,
        tgt_nfeats,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)
    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset(
        'train', fields, src_reader, tgt_reader, align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset(
            'valid', fields, src_reader, tgt_reader, align_reader, opt)
Example #5
def main():
    opt = parse_args()

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_dir,
                                            'src')
    qa_nfeats = inputters.get_num_features(opt.data_type, opt.train_dir, 'qa')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_dir,
                                            'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of qa features: %d." % qa_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(src_nfeats, qa_nfeats, tgt_nfeats,
                                  opt.data_type)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    # train_dataset_files = '/research/king3/yfgao/pycharm_deployment/CoQG/data/coref_flow/processed/coqg.turn3.train.pt'
    build_save_vocab(train_dataset_files, opt.data_type, fields, opt)
Example #6
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    if not (opt.overwrite):
        check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = 0
    # tgt_nfeats = 0
    for src in opt.train_src:
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
    logger.info(" * number of source features: %d." % src_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, opt)

    if opt.valid_src:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, opt)
Example #7
def main():
    #pdb.set_trace()
    opt = parse_args()
    init_logger(opt.log_file)
    logger.info("Extracting features...")

    # If there are special features added -- not in our case
    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            'src')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    valid_dataset_files = build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files + valid_dataset_files, fields, opt)
Example #8
def main():
    opt = parse_args()

    assert opt.max_shard_size == 0, \
        "-max_shard_size is deprecated. Please use \
        -shard_size (number of examples) instead."
    assert opt.shuffle == 0, \
        "-shuffle is not implemented. Please shuffle \
        your data before pre-processing."

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)
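count_features, used by several of these examples, only needs to inspect the first token of the corpus: OpenNMT-style factored text attaches extra features to each token with the '￨' separator (e.g. word￨POS￨NER), and the feature count is the number of factors beyond the surface word. A rough sketch of that behaviour, under that assumption about the file format:

import codecs


def count_features_sketch(path):
    # Count the '￨'-separated factors on the first token, minus the word itself.
    with codecs.open(path, "r", "utf-8") as f:
        first_tok = f.readline().split(" ")[0]
    return len(first_tok.split(u"￨")) - 1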
Example #9
def main():
    opt = parse_args()

    if (opt.max_shard_size > 0):
        raise AssertionError("-max_shard_size is deprecated, please use \
                             -shard_size (number of examples) instead.")

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    # The code below was an attempt to work around multiprocessing failures during prepare, but it had no effect
    torch.multiprocessing.set_sharing_strategy('file_system')
    import resource
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    resource.setrlimit(resource.RLIMIT_NOFILE, (65535, rlimit[1]))
    # END

    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            'src')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)
    myutils.add_more_field(fields)
    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #10
def main():
    # Options are parsed and stored in opt object
    opt = parse_args()
    # Logging of data
    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            'src')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

    # Generation of training, validation and dictionary data
    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    valid_dataset_files = build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files + valid_dataset_files, fields, opt)
Example #11
def main():
    opt = parse_args()
    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            'src')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            'tgt')
    ans_nfeats = inputters.get_num_features(opt.data_type, opt.train_ans,
                                            "ans")
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info(" * number of answer features: %d." % ans_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats,
                                  ans_nfeats)

    logger.info("fields src")
    logger.info(fields.src.__dict__)
    logger.info(fields.tgt.__dict__)
    logger.info(fields.src_map.__dict__)
    logger.info(fields.ans.__dict__)
    logger.info(fields.indices.__dict__)
    logger.info(fields.alignment.__dict__)
    '''
Example #12
def main():
    opt = parse_args()

    if (opt.max_shard_size > 0):
        raise AssertionError("-max_shard_size is deprecated, please use \
                             -shard_size (number of examples) instead.")

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            'src')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #13
def _get_fields(data_type, train_src, train_tgt):
    logger.info("Extracting features...")

    src_nfeats = inputters.get_num_features(data_type, train_src, 'src')
    tgt_nfeats = inputters.get_num_features(data_type, train_tgt, 'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(data_type, src_nfeats, tgt_nfeats)

    return fields
Example #14
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    src_nfeats = count_features(opt.train_src[0]) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt[0])  # tgt always text so far
    if len(opt.train_src) > 1 and opt.data_type == 'text':
        for src, tgt in zip(opt.train_src[1:], opt.train_tgt[1:]):
            assert src_nfeats == count_features(src),\
                "%s seems to mismatch features of "\
                "the other source datasets" % src
            assert tgt_nfeats == count_features(tgt),\
                "%s seems to mismatch features of "\
                "the other target datasets" % tgt
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    if opt.disable_eos_sampling:
        eos_token = "<blank>"
        logger.info("Using NO eos token")
    else:
        eos_token = "</s>"
        logger.info("Using standard eos token")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc,
                                  eos=eos_token)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader, align_reader,
                       opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader,
                           align_reader, opt)
Example #15
def main():
    opt = parse_args()

    assert opt.max_shard_size == 0, \
        "-max_shard_size is deprecated. Please use \
        -shard_size (number of examples) instead."
    assert opt.shuffle == 0, \
        "-shuffle is not implemented. Please shuffle \
        your data before pre-processing."

    assert os.path.isfile(opt.train_src) and os.path.isfile(opt.train_tgt), \
        "Please check path of your train src and tgt files!"

    assert not opt.valid_src or os.path.isfile(opt.valid_src), \
        "Please check path of your valid src file!"
    assert not opt.valid_tgt or os.path.isfile(opt.valid_tgt), \
        "Please check path of your valid tgt file!"

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset(
        'train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #16
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
        # print(src_nfeats)
    # exit()
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    # exit()
    tt_nfeats = tgt_nfeats
    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tt_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tt_reader = inputters.str2reader["text"].from_opt(opt)

    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    # for k,v in fields.items():
    #     if(k in ['src','tgt','tt']):
    #         print(("preprocess .py preprocess() fields_item",k,v.fields[0][1].include_lengths))
    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tt_reader, tgt_reader,
                       align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tt_reader, tgt_reader,
                           align_reader, opt)
Example #17
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(
        opt.train_src) if opt.data_type == 'text' else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far

    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    if len(opt.src_vocab) > 0:
        assert len(opt.src_vocab) == len(
            opt.train_src
        ), "you should provide src vocab for each dataset if you want to use your own vocab"
    for i, (train_src,
            train_tgt) in enumerate(zip(opt.train_src, opt.train_tgt)):
        valid_src = opt.valid_src[i]
        valid_tgt = opt.valid_tgt[i]
        logger.info("Working on %d dataset..." % i)
        logger.info("Building `Fields` object...")
        fields = inputters.get_fields(opt.data_type,
                                      src_nfeats,
                                      tgt_nfeats,
                                      dynamic_dict=opt.dynamic_dict,
                                      src_truncate=opt.src_seq_length_trunc,
                                      tgt_truncate=opt.tgt_seq_length_trunc)

        src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
        tgt_reader = inputters.str2reader["text"].from_opt(opt)

        logger.info("Building & saving training data...")
        train_dataset_files = build_save_dataset('train', fields, src_reader,
                                                 tgt_reader, opt, i, train_src,
                                                 train_tgt, valid_src,
                                                 valid_tgt)

        if opt.valid_src and opt.valid_tgt:
            logger.info("Building & saving validation data...")
            build_save_dataset('valid', fields, src_reader, tgt_reader, opt, i,
                               train_src, train_tgt, valid_src, valid_tgt)

        logger.info("Building & saving vocabulary...")
        build_save_vocab(train_dataset_files, fields, opt, i)
Example #18
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    confnet_nfeats = 0
    for src, tgt, cnet in zip(opt.train_src, opt.train_tgt, opt.train_confnet):
        src_nfeats += count_features(src) if opt.data_type == 'text' or opt.data_type == 'lattice' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
        #confnet_nfeats += count_features(cnet) if opt.data_type == 'lattice' \
        #    else 0
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info(" * number of confnet features: %d." % confnet_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  confnet_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  ans_truncate=opt.src_seq_length_trunc,
                                  ques_truncate=opt.confnet_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)
    #print('fields done')
    ans_reader = inputters.str2reader["text"].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)
    ques_reader = inputters.str2reader["lattice"].from_opt(opt)
    #print('src_reader', ques_reader)
    #print('tgt_reader', tgt_reader)
    #print('aglign_reader', align_reader)
    #print('confnet_reader', ans_reader)
    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, ques_reader, ans_reader, tgt_reader,
                       align_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, ques_reader, ans_reader,
                           tgt_reader, align_reader, opt)
Example #19
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    print("haha")
    logger.info("haha")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  with_align=opt.train_align[0] is not None,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    logger.info("please...")
    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)
    align_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    is_train_dataset_finished = build_save_dataset('train', fields, src_reader,
                                                   tgt_reader, align_reader,
                                                   opt)
    #if is_train_dataset_finished: # For Debugging(From sehun)
    #    logger.info("Train data is finished") # For Debugging(From sehun)
    #else: # For Debugging(From sehun)
    #    logger.info("f**k") # For Debugging(From sehun)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader,
                           align_reader, opt)
Example #20
def dump_dataset(savepath, save_dev=False):
    src_corpus = savepath + '/src-train.txt'
    tgt_corpus = savepath + '/tgt-train.txt'

    src_nfeats = inputters.get_num_features('text', src_corpus, 'src')
    tgt_nfeats = inputters.get_num_features('text', tgt_corpus, 'tgt')
    fields = inputters.get_fields('text', src_nfeats, tgt_nfeats)
    fields['graph'] = torchtext.data.Field(sequential=False)
    train_dataset_files = build_save_dataset('train', fields, src_corpus,
                                             tgt_corpus, savepath, args)

    if save_dev:
        src_corpus = savepath + '/src-dev.txt'
        tgt_corpus = savepath + '/tgt-dev.txt'
        build_save_dataset('dev', fields, src_corpus, tgt_corpus, savepath,
                           args)
    build_save_vocab(train_dataset_files, fields, savepath, args)
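The extra 'graph' field registered above is a plain torchtext field; with sequential=False the raw value of each example is kept as a single symbol rather than being tokenized. A small sketch using the legacy torchtext API these examples rely on (the toy values are made up):

import torchtext

# Non-sequential field: values are not split into tokens; each example
# contributes exactly one entry when the vocabulary is built.
graph_field = torchtext.data.Field(sequential=False)
graph_field.build_vocab(["g1", "g2", "g1"])
print(graph_field.vocab.stoi["g1"])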
Example #21
def main():
    opt = parse_args()

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    # train_dataset_files = 'data/processed.train.pt'
    build_save_vocab(train_dataset_files, opt.data_type, fields, opt)
Example #22
def main():
    opt = parse_args()
    print("[preprocess.py] opt.model_mode: {}".format(opt.model_mode))

    assert opt.max_shard_size == 0, \
        "-max_shard_size is deprecated. Please use \
        -shard_size (number of examples) instead."
    assert opt.shuffle == 0, \
        "-shuffle is not implemented. Please shuffle \
        your data before pre-processing."

    assert os.path.isfile(opt.train_src) and os.path.isfile(opt.train_tgt), \
        "Please check path of your train src and tgt files!"

    assert os.path.isfile(opt.valid_src) and os.path.isfile(opt.valid_tgt), \
        "Please check path of your valid src and tgt files!"

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    knl_nfeats = count_features(opt.train_knl) if opt.data_type == 'text' \
        else 0
    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of knowledge features: %d." % knl_nfeats)
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats,
                                  knl_nfeats)
    # fields = inputters.get_fields(opt.data_type, 1, 1, knl_nfeats)
    # {'src', 'src_feat_0', 'knl', 'src_map', 'alignment', 'tgt', 'tgt_feat_0', 'indices'}
    # {'src' 'knl', 'src_map', 'alignment', 'tgt', 'indices', 'src_da_label', 'tgt_da_label'}

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #23
def preprocess(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)

    init_logger(opt.log_file)

    logger.info("Extracting features...")

    src_nfeats = 0
    cue_nfeats = 0
    tgt_nfeats = 0
    for src, cue, tgt in zip(opt.train_src, opt.train_cue, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        cue_nfeats += count_features(cue) if opt.data_type == 'text' else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info("* number of knowledge features: %d" % cue_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  cue_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  cue_truncate=opt.cue_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    cue_reader = inputters.str2reader["text"].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, cue_reader, tgt_reader,
                       opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, cue_reader, tgt_reader,
                           opt)
Example #24
def preprocess_main(opt):
    logger = get_logger(opt.log_file)
    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            'src')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            'tgt')
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt, logger)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt, logger)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt, logger)
Example #25
def main():
    opt = parse_args()
    if (opt.max_shard_size > 0):
        raise AssertionError("-max_shard_size is deprecated, please use \
                             -shard_size (number of examples) instead.")

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, 0, 0)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(fields, opt)
Example #26
def main():
    opt = parse_args()

    print("Extracting features...")
    src_nfeats = inputters.get_num_features(opt.data_type, opt.train_src,
                                            'src')
    tgt_nfeats = inputters.get_num_features(opt.data_type, opt.train_tgt,
                                            'tgt')
    print(" * number of source features: %d." % src_nfeats)
    print(" * number of target features: %d." % tgt_nfeats)

    print("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

    print("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)

    print("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)

    print("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)
Example #27
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    if not (opt.overwrite):
        check_existing_pt_files(opt)

    init_logger(opt.log_file)

    shutil.copy2(opt.config, os.path.dirname(opt.log_file))
    logger.info(opt)
    logger.info("Extracting features...")

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader[opt.data_type].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)
Example #28
def main():
    opt = parse_args()
    init_logger(opt.log_file)
    logger.info("Extracting features...")


    src_nfeats = inputters.get_num_features(
        opt.data_type, opt.train_src, 'src')
    tgt_nfeats = inputters.get_num_features(
        opt.data_type, opt.train_tgt, 'tgt')
    ans_nfeats = inputters.get_num_features(
        opt.data_type, opt.train_ans, "ans")
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)
    logger.info(" * number of answer features: %d." % ans_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type, src_nfeats, tgt_nfeats, ans_nfeats)

    logger.info("fields src")
    logger.info(fields['src'].__dict__)
    logger.info(fields['tgt'].__dict__)
    logger.info(fields['src_map'].__dict__)
    logger.info(fields['ans'].__dict__)
    logger.info(fields['indices'].__dict__)
    logger.info(fields['alignment'].__dict__)


    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, opt)
    logger.info(train_dataset_files)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)

    logger.info("Building & saving validation data...")
    build_save_dataset('valid', fields, opt)
Example #29
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(
        opt.data_type,
        src_nfeats,
        tgt_nfeats,
        dynamic_dict=opt.dynamic_dict,
        src_truncate=opt.src_seq_length_trunc,
        tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset(
        'train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
Example #30
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    check_existing_pt_files(opt)

    init_logger(opt.log_file)
    logger.info("Extracting features...")

    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
        else 0
    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")

    if opt.fixed_vocab:
        tgt_bos = '<|endoftext|>'
        tgt_eos = '\u0120GDDR'
        tgt_pad = '\u0120SHALL'
        tgt_unk = '\u0120RELE'

        if opt.no_spec_src:
            src_pad = None
            src_unk = None
        elif opt.free_src:
            src_pad = '<blank>'
            src_unk = '<unk>'
        else:
            src_pad = '\u0120SHALL'
            src_unk = '\u0120RELE'

    else:
        tgt_bos = '<s>'
        tgt_eos = '</s>'
        tgt_pad = '<blank>'
        tgt_unk = '<unk>'
        src_pad = '<blank>'
        src_unk = '<unk>'

    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc,
                                  src_pad=src_pad,
                                  src_unk=src_unk,
                                  tgt_pad=tgt_pad,
                                  tgt_unk=tgt_unk,
                                  tgt_bos=tgt_bos,
                                  tgt_eos=tgt_eos,
                                  include_ptrs=opt.pointers_file is not None)

    if opt.data_type == 'none':
        src_reader = None
    else:
        src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader["text"].from_opt(opt)

    logger.info("Building & saving training data...")
    train_dataset_files = build_save_dataset('train', fields, src_reader,
                                             tgt_reader, opt)

    if (opt.valid_src or opt.data_type == 'none') and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)

    logger.info("Building & saving vocabulary...")
    build_save_vocab(train_dataset_files, fields, opt)
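A note on the unusual special tokens above: in GPT-2's byte-level BPE the character 'Ġ' (U+0120) encodes a leading space, so strings such as '\u0120GDDR' and '\u0120SHALL' are ordinary entries of the fixed GPT-2 vocabulary that this variant appears to repurpose as EOS/PAD/UNK markers, since the vocabulary itself cannot be extended. A one-line check of the escape:

# '\u0120' is the 'Ġ' marker GPT-2 BPE uses for a preceding space.
print('\u0120GDDR')  # prints: ĠGDDR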
Example #31
def main(opt):
    ArgumentParser.validate_preprocess_args(opt)
    torch.manual_seed(opt.seed)
    if not (opt.overwrite):
        check_existing_pt_files(opt)

    init_logger(opt.log_file)

    shutil.copy2(opt.config, os.path.dirname(opt.log_file))
    logger.info(opt)
    logger.info("Extracting features...")

    #Prepares the document embedding to initialize memory vectors.
    embedder = SentenceTransformer('bert-base-nli-mean-tokens')

    kpcorpus = []
    files_path = [  #'data/keyphrase/json/kp20k/kp20k_train.json',
        'data/keyphrase/json/kp20k/kp20k_valid.json',
        'data/keyphrase/json/kp20k/kp20k_test.json',
        'data/keyphrase/json/inspec/inspec_valid.json',
        'data/keyphrase/json/inspec/inspec_test.json',
        'data/keyphrase/json/krapivin/krapivin_valid.json',
        'data/keyphrase/json/krapivin/krapivin_test.json',
        'data/keyphrase/json/nus/split/nus_valid.json',
        'data/keyphrase/json/nus/split/nus_test.json',
        'data/keyphrase/json/semeval/semeval_valid.json',
        'data/keyphrase/json/semeval/semeval_test.json',
        'data/keyphrase/json/duc/split/duc_valid.json',
        'data/keyphrase/json/duc/split/duc_test.json'
    ]
    for file_path in files_path:
        file = open(file_path, 'r')
        for line in file.readlines():
            dic = json.loads(line)
            # print(dic)
            kpcorpus.append(dic['title'] + ' ' + dic['abstract'])
            # print(kpcorpus)

    num_of_example = len(kpcorpus)
    print("number of examples in corpus: ", num_of_example)
    time_a = time.time()
    corpus_embeddings = embedder.encode(kpcorpus[:num_of_example])
    print("elapsed time: ", time.time() - time_a)
    alldocs_emb = torch.Tensor(corpus_embeddings)
    torch.save(alldocs_emb, './data/alldocs_emb')

    src_nfeats = 0
    tgt_nfeats = 0
    for src, tgt in zip(opt.train_src, opt.train_tgt):
        src_nfeats += count_features(src) if opt.data_type == 'text' \
            else 0
        tgt_nfeats += count_features(tgt)  # tgt always text so far
    logger.info(" * number of source features: %d." % src_nfeats)
    logger.info(" * number of target features: %d." % tgt_nfeats)

    logger.info("Building `Fields` object...")
    fields = inputters.get_fields(opt.data_type,
                                  src_nfeats,
                                  tgt_nfeats,
                                  dynamic_dict=opt.dynamic_dict,
                                  src_truncate=opt.src_seq_length_trunc,
                                  tgt_truncate=opt.tgt_seq_length_trunc)

    src_reader = inputters.str2reader[opt.data_type].from_opt(opt)
    tgt_reader = inputters.str2reader[opt.data_type].from_opt(opt)

    logger.info("Building & saving training data...")
    build_save_dataset('train', fields, src_reader, tgt_reader, opt)

    if opt.valid_src and opt.valid_tgt:
        logger.info("Building & saving validation data...")
        build_save_dataset('valid', fields, src_reader, tgt_reader, opt)
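The document embeddings saved above can be loaded back when the memory vectors are needed; a one-line sketch using the same path:

import torch

# Reload the tensor written by torch.save in the preprocessing step above.
alldocs_emb = torch.load('./data/alldocs_emb')
print(alldocs_emb.shape)  # (num_of_example, embedding_dim)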