コード例 #1
0
def nmt_construction():
    """Run ``make_dataset`` twice over the ``*.s2b`` splits of the NMT data.

    The settings come from the hard-coded ``nmt.yaml`` data config. The first
    run writes back into ``args.origin_tgts`` and passes ``-1`` for both
    length limits; the second run writes into ``args.sample_tgts`` and passes
    the configured ``max_src_length`` / ``max_tgt_length``.
    """
    cfg_path = "/home/user_data/baoy/projects/seq2seq_parser/configs/data_configs/nmt.yaml"
    args = dict_to_args(yaml_load_dict(cfg_path))
    source_dir = args.origin_tgts
    splits = {"train": "train.s2b", "dev": "dev.s2b", "test": "test.s2b"}

    def _build(out_dir, src_len, tgt_len):
        # Only the output directory and the length limits vary between runs.
        make_dataset(
            data_dir=source_dir,
            data_dict=splits,
            tgt_dir=out_dir,
            max_src_vocab=args.max_src_vocab,
            max_tgt_vocab=args.max_tgt_vocab,
            vocab_freq_cutoff=args.cut_off,
            max_src_length=src_len,
            max_tgt_length=tgt_len,
        )

    # NOTE(review): -1 presumably means "no length limit" -- confirm in make_dataset.
    _build(source_dir, -1, -1)
    _build(args.sample_tgts, args.max_src_length, args.max_tgt_length)
コード例 #2
0
def prepare_con():
    """Randomly split the Quora trees into train/dev/test ``.con`` files.

    Reads the hard-coded ``quora.yaml`` data config, loads the trees from
    ``args.pair_file`` and ``args.unpair_file`` (both under
    ``args.origin_tgts``), draws ``train_size + valid_size + test_size``
    distinct random indices and writes the three resulting splits to
    ``train.con``, ``dev.con`` and ``test.con`` inside ``args.data_tgts``
    (created if missing).

    Raises:
        ValueError: from ``random.sample`` when the requested total size
            exceeds the number of available trees.
    """
    config_file = "/home/user_data/baoy/projects/seq2seq_parser/configs/data_configs/quora.yaml"
    args = dict_to_args(yaml_load_dict(config_file))

    pair_file = os.path.join(args.origin_tgts, args.pair_file)
    unpair_file = os.path.join(args.origin_tgts, args.unpair_file)
    all_tree = load_tree_file(pair_file) + load_tree_file(unpair_file)

    train_end = args.train_size
    valid_end = train_end + args.valid_size
    all_size = valid_end + args.test_size

    all_idx = random.sample(range(len(all_tree)), all_size)
    train_idx = all_idx[:train_end]
    valid_idx = all_idx[train_end:valid_end]
    # Slice from the explicit offset rather than ``[-test_size:]``: with
    # test_size == 0 the negative form becomes ``[-0:]`` == the whole list,
    # which would leak every train/dev sample into the test split.
    test_idx = all_idx[valid_end:]

    splits = (
        ("train.con", train_idx),
        ("dev.con", valid_idx),
        ("test.con", test_idx),
    )

    if not os.path.exists(args.data_tgts):
        os.makedirs(args.data_tgts)
    for file_name, indices in splits:
        write_docs(docs=[all_tree[idx] for idx in indices],
                   fname=os.path.join(args.data_tgts, file_name))
コード例 #3
0
def prepare_s2b(
    config_file="/home/user_data/baoy/projects/seq2seq_parser/configs/data_configs/quora.yaml"
):
    """Call ``make_s2b_dataset`` on the ``.con`` splits in ``args.data_tgts``.

    Args:
        config_file: path of the YAML data config to load; its ``data_tgts``
            entry is both where ``train.con`` / ``dev.con`` / ``test.con``
            are read from and the ``tgt_dir`` handed to ``make_s2b_dataset``.
    """
    args = dict_to_args(yaml_load_dict(config_file))
    out_dir = args.data_tgts
    split_paths = {
        "train_file": os.path.join(out_dir, "train.con"),
        "dev_file": os.path.join(out_dir, "dev.con"),
        "test_file": os.path.join(out_dir, "test.con"),
    }
    make_s2b_dataset(tgt_dir=out_dir, **split_paths)
コード例 #4
0
def webnlg_preprocess():
    """Call ``make_nmt_simple_dataset`` with the paths from ``webnlg.yaml``.

    The train/dev/test file paths and the target directory
    (``args.origin_tgts``) are all read from the hard-coded config file.
    """
    cfg_path = "/home/user_data/baoy/projects/seq2seq_parser/configs/data_configs/webnlg.yaml"
    args = dict_to_args(yaml_load_dict(cfg_path))
    split_files = {
        "train_file": args.train_file,
        "dev_file": args.dev_file,
        "test_file": args.test_file,
    }
    make_nmt_simple_dataset(tgt_dir=args.origin_tgts, **split_files)
コード例 #5
0
def prepare_raw_paraphrase():
    """Sample ``args.test_size`` lines from ``para.raw.token``.

    Reads the hard-coded ``quora.yaml`` data config, picks distinct random
    line indices from ``para.raw.token`` under ``args.origin_tgts``, strips
    the selected lines and writes them to ``para.raw.text`` in the same
    directory via ``write_docs``.
    """
    cfg_path = "/home/user_data/baoy/projects/seq2seq_parser/configs/data_configs/quora.yaml"
    args = dict_to_args(yaml_load_dict(cfg_path))

    raw_path = os.path.join(args.origin_tgts, "para.raw.token")
    with open(raw_path, "r") as handle:
        raw_lines = handle.readlines()

    chosen = random.sample(range(len(raw_lines)), args.test_size)
    sampled = [raw_lines[i].strip() for i in chosen]

    out_path = os.path.join(args.origin_tgts, "para.raw.text")
    write_docs(fname=out_path, docs=sampled)
コード例 #6
0
def prepare_paraphrase():
    """Sample ``args.test_size`` tree pairs and dump their words as text.

    Reads the hard-coded ``quora.yaml`` data config, loads the pairs from
    ``args.pair_file`` with ``load_to_pair_tree``, picks distinct random
    pairs and writes one line per pair to ``para.text`` under
    ``args.origin_tgts``: the ``words`` of both members joined by a tab.
    """
    cfg_path = "/home/user_data/baoy/projects/seq2seq_parser/configs/data_configs/quora.yaml"
    args = dict_to_args(yaml_load_dict(cfg_path))

    pair_trees = load_to_pair_tree(os.path.join(args.origin_tgts, args.pair_file))
    chosen = random.sample(range(len(pair_trees)), args.test_size)

    lines = []
    for pair in (pair_trees[i] for i in chosen):
        first_words = pair[0].words
        second_words = pair[1].words
        lines.append("\t".join([first_words, second_words]))

    out_path = os.path.join(args.origin_tgts, "para.text")
    write_docs(fname=out_path, docs=lines)
コード例 #7
0
def parsed_to_pair(args=None):
    """Turn (id1, id2, label) rows into tab-separated sentence-pair files.

    Each line of the label file is split on single spaces into
    ``id1 id2 label``; the ids are 1-based line numbers into the origin file.
    Rows labelled ``1`` are written to the pair file and rows labelled ``0``
    to the unpair file, one ``"<line id1>\\t<line id2>"`` entry per row.
    Rows with any other label are ignored.

    Malformed rows (non-integer or missing fields) and rows whose ids fall
    outside the origin file are skipped. Rows that reference an empty origin
    line are skipped as well, after printing both referenced lines.

    Args:
        args: object exposing ``origin_tgts``, ``origin_file``,
            ``label_file``, ``pair_file`` and ``unpair_file``. When ``None``
            it is loaded from the hard-coded ``quora.yaml`` data config.
    """
    if args is None:
        config_file = "/home/user_data/baoy/projects/seq2seq_parser/configs/data_configs/quora.yaml"
        args_dict = yaml_load_dict(config_file)
        args = dict_to_args(args_dict)
    origin_file = os.path.join(args.origin_tgts, args.origin_file)
    label_file = os.path.join(args.origin_tgts, args.label_file)
    pair_file = os.path.join(args.origin_tgts, args.pair_file)
    unpair_file = os.path.join(args.origin_tgts, args.unpair_file)

    with open(origin_file, "r") as f:
        tree_list = [line.strip() for line in f]

    with open(label_file, "r") as f:
        label_list = [line.strip().split(" ") for line in f]

    tree_count = len(tree_list)
    pair_list = []
    unpair_list = []
    for label in label_list:
        try:
            num_i1 = int(label[0]) - 1
            num_i2 = int(label[1]) - 1
            num_l = int(label[2])
        except (ValueError, IndexError):
            # Malformed row: non-numeric field or fewer than three fields.
            continue

        # Explicit bounds check: an id of 0 (or a negative id) would become a
        # negative index and silently select a line from the END of the file.
        if not (0 <= num_i1 < tree_count and 0 <= num_i2 < tree_count):
            continue

        t1 = tree_list[num_i1]
        t2 = tree_list[num_i2]
        if t1 and t2:
            item_str = "\t".join((t1, t2))
            if num_l == 0:
                unpair_list.append(item_str)
            elif num_l == 1:
                pair_list.append(item_str)
        else:
            # At least one side is an empty line; report the row and drop it.
            print(t1)
            print(t2)

    for out_path, lines in ((pair_file, pair_list), (unpair_file, unpair_list)):
        with open(out_path, "w") as f:
            for line in lines:
                f.write(line.strip())
                f.write("\n")
コード例 #8
0
def snli_sample_construction(is_write=True):
    """Run ``make_dataset`` for the SNLI sample described by ``snli-sample.yaml``.

    The ``*.s2b`` splits are read from ``args.origin_tgts`` and the output
    directory is ``args.target_tgts``; both length limits are passed as ``-1``.

    Args:
        is_write: forwarded to ``make_dataset`` as ``write_down``.
    """
    cfg_path = "/home/user_data/baoy/projects/seq2seq_parser/configs/data_configs/snli-sample.yaml"
    args = dict_to_args(yaml_load_dict(cfg_path))
    dataset_kwargs = dict(
        data_dir=args.origin_tgts,
        data_dict={"train": "train.s2b", "dev": "dev.s2b", "test": "test.s2b"},
        tgt_dir=args.target_tgts,
        max_src_vocab=args.max_src_vocab,
        max_tgt_vocab=args.max_tgt_vocab,
        vocab_freq_cutoff=args.cut_off,
        # NOTE(review): -1 presumably disables length filtering -- confirm in make_dataset.
        max_src_length=-1,
        max_tgt_length=-1,
        write_down=is_write,
    )
    make_dataset(**dataset_kwargs)
コード例 #9
0
ファイル: main.py プロジェクト: wangqi1996/DSS-VAE-pytorch
def process_args():
    """Parse the command line and load the per-component config sections.

    The YAML file named by ``--config_files`` is loaded with
    ``yaml_load_dict``; every known section that is present is converted with
    ``dict_to_args``. ``--mode``, ``--exp_name`` and ``--load_src_lm``
    override the attributes of the same name on the base config when they are
    given and a ``base_configs`` section exists.

    Returns:
        dict with the keys ``'base'``, ``'baseline'``, ``'prior'``,
        ``'encoder'``, ``'decoder'``, ``'vae'``, ``'ae'`` and ``'nae'``; each
        value is the converted section, or ``None`` when the section is
        absent from the config file.
    """
    opt_parser = argparse.ArgumentParser()
    opt_parser.add_argument('--config_files', type=str, help='config_files')
    # Help text previously read 'config_files' (copy-paste slip).
    opt_parser.add_argument('--exp_name', type=str,
                            help='experiment name (overrides base_configs.exp_name)')
    opt_parser.add_argument('--load_src_lm', type=str, default=None)
    opt_parser.add_argument('--mode', type=str, default=None)
    opt = opt_parser.parse_args()

    configs = yaml_load_dict(opt.config_files)

    # (returned key, section name in the YAML file)
    sections = (
        ('base', 'base_configs'),
        ('baseline', 'baseline_configs'),
        ('prior', 'prior_configs'),
        ('encoder', 'encoder_configs'),
        ('decoder', 'decoder_configs'),
        ('vae', 'vae_configs'),
        ('ae', 'ae_configs'),
        # NOTE(review): key is 'nae' but the section is 'nag_configs' -- confirm intended.
        ('nae', 'nag_configs'),
    )
    parsed = {
        key: dict_to_args(configs[section]) if section in configs else None
        for key, section in sections
    }

    base_args = parsed['base']
    if base_args is not None:
        # Command-line values win over the file, but only when supplied.
        for option in ('mode', 'exp_name', 'load_src_lm'):
            value = getattr(opt, option)
            if value is not None:
                setattr(base_args, option, value)

    return parsed
コード例 #10
0
def quora_construction(
        config_file="/home/user_data/baoy/projects/seq2seq_parser/configs/data_configs/quora-50k.yaml",
        is_write=True):
    """Run ``make_dataset`` for the Quora data described by ``config_file``.

    The ``*.s2b`` splits are read from ``args.origin_tgts`` and the output
    directory is ``args.data_tgts``; vocabulary sizes, frequency cut-off,
    length limits and ``train_size`` all come from the config.

    Args:
        config_file: path of the YAML data config to load.
        is_write: forwarded to ``make_dataset`` as ``write_down``.
    """
    args = dict_to_args(yaml_load_dict(config_file))
    dataset_kwargs = dict(
        data_dir=args.origin_tgts,
        data_dict={"train": "train.s2b", "dev": "dev.s2b", "test": "test.s2b"},
        tgt_dir=args.data_tgts,
        max_src_vocab=args.max_src_vocab,
        max_tgt_vocab=args.max_tgt_vocab,
        vocab_freq_cutoff=args.cut_off,
        max_src_length=args.max_src_length,
        max_tgt_length=args.max_tgt_length,
        train_size=args.train_size,
        write_down=is_write,
    )
    make_dataset(**dataset_kwargs)
コード例 #11
0
    print("Filter")
    data_details(train, dev, test)
    test, dev, train = load_ptb_to_s2b(
        data_dir=args.data_dirs,
        data_dict=data_dicts,
        same_filter=False,
    )
    print("Origin")
    data_details(train, dev, test)

"""

if __name__ == '__main__':
    config_file = "/home/user_data/baoy/projects/seq2seq_parser/configs/snli_data.yaml"
    args_dict = yaml_load_dict(config_file)
    args = dict_to_args(args_dict)
    data_dicts = {
        'train': args.train_file,
        'dev': args.dev_file,
        'test': args.test_file,
    }
    t_dict = {
        'train': args.process_train,
        'dev': args.process_dev,
        'test': args.process_test,
    }

    prepare_ptb_to_s2b(
        data_dir=args.data_dirs,
        data_dict=data_dicts,
        target_dir=args.sample_dirs,