示例#1
0
def train_test_repeat(load_id, exp_name, n_repeat):
    """Train and evaluate the RTE model ``n_repeat`` times and report scores.

    Each repetition builds a fresh ``Experiment``, trains with ``train_rte``
    starting from ``load_id`` and evaluates the saved model with
    ``eval_rte``.  Every score and the average are printed.

    Args:
        load_id: Checkpoint identifier forwarded unchanged to ``train_rte``.
        exp_name: Experiment label; each run is named ``"rte_<exp_name>"``.
        n_repeat: Number of train/eval repetitions.

    Returns:
        ``average(scores)`` over all repetitions.
    """
    hp = hyperparams.HPBert()
    e_config = ExperimentConfig()
    e_config.num_epoch = 10
    e_config.save_interval = 30 * 60  # 30 minutes
    e_config.load_names = ['bert']
    vocab_filename = "bert_voca.txt"
    data_loader = rte.DataLoader(hp.seq_max, vocab_filename, True)

    print(load_id)
    scores = []
    for _ in range(n_repeat):
        experiment = Experiment(hp)
        print(exp_name)
        # The run name is (re)assigned before every training run; the former
        # placeholder name set ahead of the loop was never read.
        e_config.name = "rte_{}".format(exp_name)
        save_path = experiment.train_rte(e_config, data_loader, load_id)
        scores.append(experiment.eval_rte(e_config, data_loader, save_path))
    print(exp_name)
    for score in scores:
        print(score, end="\t")
    print()
    r = average(scores)
    print("Avg\n{0:.03f}".format(r))
    return r
示例#2
0
文件: nli_main.py 项目: clover3/Chair
def test_fidelity():
    """Run ``eval_fidelity_gradient`` for the 'conflict' explanation tag.

    Builds the BERT hyper-parameters, an ``Experiment`` and the NLI data
    loader, then evaluates the ``NLIEx_Y_conflict`` checkpoint.
    """
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)
    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    use_senn = False

    config = ExperimentConfig()
    config.name = "NLIEx_{}".format("Fidelity")
    config.num_epoch = 4
    config.save_interval = 30 * 60  # 30 minutes
    # 'aux_conflict' is only added to the loaded names in the SENN variant.
    restored = ['bert', 'cls_dense']
    if use_senn:
        restored.append('aux_conflict')
    config.load_names = restored
    tag = 'conflict'

    loader = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    # Alternative tried earlier: checkpoint ("NLI_Only_C", 'model-0') and the
    # non-gradient ``eval_fidelity`` entry point with the same arguments.
    checkpoint = ("NLIEx_Y_conflict", 'model-12039')
    experiment.eval_fidelity_gradient(setting, config, loader, checkpoint,
                                      tag)
示例#3
0
def ukp_train_test_repeat(load_id, exp_name, topic, n_repeat):
    """Train and evaluate the UKP model on one topic ``n_repeat`` times.

    Every repetition creates a new ``Experiment`` and ``BertDataLoader``,
    trains with ``train_ukp`` from ``load_id`` and evaluates the saved model
    with ``eval_ukp``.  Each score and the average are printed; nothing is
    returned.
    """
    hparams = hyperparams.HPBert()
    config = ExperimentConfig()
    config.num_epoch = 2
    # 6000 -- i.e. 100 minutes if the unit is seconds (not 30 as once noted).
    config.save_interval = 100 * 60
    config.voca_size = 30522
    config.load_names = ['bert']
    encode_opt = "is_good"

    print(load_id)
    results = []
    for _ in range(n_repeat):
        experiment = Experiment(hparams)
        print(exp_name)
        config.name = "arg_{}_{}_{}".format(exp_name, topic, encode_opt)
        loader = BertDataLoader(topic, True, hparams.seq_max, "bert_voca.txt",
                                option=encode_opt)
        checkpoint = experiment.train_ukp(config, loader, load_id)
        results.append(experiment.eval_ukp(config, loader, checkpoint))
    print(exp_name)
    print(encode_opt)
    for value in results:
        print(value, end="\t")
    print()
    print("Avg\n{0:.03f}".format(average(results)))
示例#4
0
文件: nli_main.py 项目: clover3/Chair
def analyze_nli_ex():
    """Run ``nli_visualization`` for the hard-coded explanation tag.

    The checkpoint to load is selected by ``explain_tag`` from a fixed table.

    Raises:
        ValueError: If ``explain_tag`` has no checkpoint in the table.
    """
    hp = hyperparams.HPBert()
    e = Experiment(hp)
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"
    explain_tag = 'match'

    e_config = ExperimentConfig()
    e_config.name = "NLIEx_{}_analyze".format(explain_tag)
    e_config.num_epoch = 4
    e_config.save_interval = 30 * 60  # 30 minutes
    e_config.load_names = ['bert', 'cls_dense', 'aux_conflict']

    data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename, True)

    # One checkpoint per tag.  For 'match' an earlier candidate
    # ("NLIEx_P_match", "model-1636") was always overwritten by the entry
    # below, so only the effective value is kept.
    checkpoints = {
        'conflict': ("NLIEx_Y_conflict", "model-12039"),
        'match': ("NLIEx_X_match", "model-12238"),
        'mismatch': ("NLIEx_U_mismatch", "model-10265"),
    }
    if explain_tag not in checkpoints:
        # Previously an unknown tag left ``load_id`` unbound and failed later
        # with an unrelated-looking error.
        raise ValueError("no checkpoint for explain_tag {!r}".format(explain_tag))
    load_id = checkpoints[explain_tag]
    e.nli_visualization(nli_setting, e_config, data_loader, load_id,
                        explain_tag)
示例#5
0
文件: nli_main.py 项目: clover3/Chair
def train_snli_ex():
    """Launch ``train_nli_any_way`` on SNLI data from ``SNLI_Only_A``.

    Sets ``compare_deletion_num`` to 20 on the hyper-parameters before the
    ``Experiment`` is created.
    """
    hparams = hyperparams.HPBert()
    hparams.compare_deletion_num = 20
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "SNLIEx_B"
    config.ex_val = False
    config.num_epoch = 1
    config.save_interval = 30 * 60  # 30 minutes
    config.load_names = ['bert', 'cls_dense']

    loader = nli.SNLIDataLoader(hparams.seq_max, setting.vocab_filename, True)
    checkpoint = ("SNLI_Only_A", 'model-0')
    experiment.train_nli_any_way(setting, config, loader, checkpoint)
示例#6
0
def get_nli_ex_model_segmented(input_ids, input_mask, segment_ids):
    """Run ``transformer_nli`` on a long input by splitting it into 3 windows.

    Each input tensor is cut along axis 1 into three pieces which are stacked
    along axis 0 (see ``spread``), fed through the model as one larger batch,
    and the model's ``conf_logits`` are stitched back together along axis 1
    (see ``collect``).

    NOTE(review): the constants below imply inputs of width 512 along
    axis 1 -- confirm against the callers.
    """
    method = 5
    hp = hyperparams.HPBert()
    voca_size = 30522
    sequence_shape = bert_common.get_shape_list2(input_ids)
    batch_size = sequence_shape[0]

    # Width of every window handed to the model.
    step = 200
    # Evaluates to 86 for step == 200: the third window holds 1 + 113 real
    # columns when the input is 512 wide, and is padded up to 200.
    # NOTE(review): the literal 200 presumably means ``step`` -- confirm.
    pad_len = 200 - 1 - (512 - (step * 2 - 1))

    def spread(t):
        # Column 0 of the input is re-used as the first column of windows
        # two and three.
        cls_token = t[:, :1]
        pad = tf.ones([batch_size, pad_len], tf.dtypes.int32) * PAD_ID
        # Window 1: columns [0, step).
        a = t[:, :step]
        # Window 2: column 0 followed by columns [step, 2*step - 1).
        b = tf.concat([cls_token, t[:, step:step * 2 - 1]], axis=1)
        # Window 3: column 0, the remaining columns, then PAD_ID padding.
        c = tf.concat([cls_token, t[:, step * 2 - 1:], pad], axis=1)
        # Stack the three windows along axis 0.
        return tf.concat([a, b, c], axis=0)

    def collect(t):
        # Undo ``spread``: rows are grouped per window in blocks of
        # ``batch_size``.
        a = t[:batch_size]
        # Drop the duplicated first column of window 2.
        b = t[batch_size:batch_size * 2, 1:]
        # Drop the duplicated first column and the trailing padding of
        # window 3.
        c = t[batch_size * 2:, 1:-pad_len]
        return tf.concat([a, b, c], axis=1)

    model = transformer_nli(hp, spread(input_ids), spread(input_mask),
                            spread(segment_ids), voca_size, method, False)
    output = model.conf_logits
    output = collect(output)
    return output
示例#7
0
def ukp_train_test(load_id, exp_name):
    """Train and evaluate the UKP model once for every topic in ``all_topics``.

    For each topic a new ``Experiment`` and ``BertDataLoader`` are created,
    the model is trained with ``train_ukp`` from ``load_id`` and evaluated
    with ``eval_ukp``.  The per-topic scores are printed; nothing is returned.
    """
    hparams = hyperparams.HPBert()
    config = ExperimentConfig()
    config.num_epoch = 2
    # 6000 -- i.e. 100 minutes if the unit is seconds (not 30 as once noted).
    config.save_interval = 100 * 60
    config.voca_size = 30522
    config.load_names = ['bert']
    encode_opt = "is_good"

    print(load_id)
    topic_scores = []
    for topic in data_generator.argmining.ukp_header.all_topics:
        experiment = Experiment(hparams)
        print(exp_name)
        config.name = "arg_{}_{}_{}".format(exp_name, topic, encode_opt)
        loader = BertDataLoader(topic, True, hparams.seq_max, "bert_voca.txt",
                                option=encode_opt)
        checkpoint = experiment.train_ukp(config, loader, load_id)
        print(topic)
        result = experiment.eval_ukp(config, loader, checkpoint)
        topic_scores.append((topic, result))
    print(exp_name)
    print(encode_opt)
    print(topic_scores)
    for name, value in topic_scores:
        print("{0}\t{1:.03f}".format(name, value))
示例#8
0
文件: nli_main.py 项目: clover3/Chair
def predict_rf():
    """Run ``Experiment.predict_rf`` on the ``mismatch_1000`` data split.

    Uses a batch size of 256 and loads the ``NLIEx_AnyA`` checkpoint.
    """
    hp = hyperparams.HPBert()
    hp.batch_size = 256
    e = Experiment(hp)
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"
    target_label = 'mismatch'
    data_id = "{}_1000".format(target_label)
    e_config = ExperimentConfig()

    e_config.name = "NLIEx_AnyA_{}".format(target_label)
    e_config.load_names = ['bert', 'cls_dense', 'aux_conflict']

    data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    # Only the last of a series of successive ``load_id`` assignments was
    # ever used; the overwritten ones (NLIEx_W_mismatch/model-12030,
    # NLIEx_Y_conflict/model-12039, NLIEx_X_match/model-12238,
    # NLIEx_CE_<target_label>/model-12199) are recorded here for reference.
    load_id = ("NLIEx_AnyA", "model-7255")
    e.predict_rf(nli_setting, e_config, data_loader, load_id, data_id, 5)
示例#9
0
def eval_ukp_with_nli(exp_name):
    """Evaluate a fixed saved model with ``eval_ukp_on_shared``.

    Only the first entry of ``all_topics`` is evaluated.  The checkpoint path
    is hard-coded.  Per-topic scores are printed; nothing is returned.
    """
    steps_in_epoch = 24544 + 970

    hparams = hyperparams.HPBert()
    config = ExperimentConfig()
    config.num_steps = steps_in_epoch
    config.voca_size = 30522
    config.num_dev_batches = 30
    config.load_names = ['bert']
    encode_opt = "is_good"
    class_counts = [3, 3]
    topic_scores = []
    checkpoint = "/mnt/scratch/youngwookim/Chair/output/model/runs/argmix_AN_B_40000_abortion_is_good/model-21306"
    first_topics = data_generator.argmining.ukp_header.all_topics[:1]
    for topic in first_topics:
        experiment = Experiment(hparams)
        print(exp_name)
        config.name = "argmix_{}_{}_{}".format(exp_name, topic, encode_opt)
        arg_loader = BertDataLoader(topic, True, hparams.seq_max,
                                    "bert_voca.txt", option=encode_opt)
        result = experiment.eval_ukp_on_shared(config, arg_loader,
                                               class_counts, checkpoint)
        topic_scores.append((topic, result))
    print(exp_name)
    print(encode_opt)
    print(topic_scores)
    for name, value in topic_scores:
        print("{0}\t{1:.03f}".format(name, value))
示例#10
0
def get_nli_ex_model(input_ids, input_mask, segment_ids):
    """Build ``transformer_nli`` on the given inputs and return ``conf_logits``.

    The model is created with method 5, a vocabulary size of 30522 and
    ``False`` as the final positional argument.
    """
    method_id = 5
    hparams = hyperparams.HPBert()
    vocab_size = 30522

    nli_model = transformer_nli(hparams, input_ids, input_mask, segment_ids,
                                vocab_size, method_id, False)
    return nli_model.conf_logits
示例#11
0
def run(token_path, ranked_list_path, start_model_path, output_path):
    """Predict over sentences of ranked documents and pickle the results.

    Sentences read by ``CluewebTokenReader`` are registered as promises with
    a ``PromiseKeeper``; ``do_duty`` then runs them through ``UkpExPredictor``
    in batches.  The per-document lists of results are pickled to
    ``output_path``.

    Args:
        token_path: Forwarded to ``CluewebTokenReader``.
        ranked_list_path: Ranked-list path; also used to derive the topic.
        start_model_path: Model path handed to ``UkpExPredictor``.
        output_path: Destination file for the pickled results.
    """
    voca_size = 30522
    target_topic = get_topic_from_path(ranked_list_path)
    print(target_topic)
    hp = hyperparams.HPBert()
    vocab_filename = "bert_voca.txt"
    voca_path = os.path.join(cpath.data_path, vocab_filename)
    batch_size = 256

    encoder = EncoderUnit(hp.seq_max, voca_path)
    # The second segment is the same for every sentence: "<topic> is good".
    seg_b = encoder.encoder.encode(target_topic + " is good")

    def encode(tokens):
        seg_a = encoder.encoder.ft.convert_tokens_to_ids(tokens)
        d = encoder.encode_inner(seg_a, seg_b)
        # The trailing 0 is a constant fourth field of every entry.
        return d["input_ids"], d["input_mask"], d["segment_ids"], 0

    token_reader = CluewebTokenReader(target_topic, ranked_list_path,
                                      token_path)

    def predict_list(batch_size, tokens_list):
        predictor = UkpExPredictor(hp, voca_size, start_model_path)
        entry_itr = [encode(tokens) for tokens in tokens_list]
        print("len(tokens_list)", len(tokens_list))
        result = []
        for idx, batch in enumerate(
                batch_iter_from_entry_iter(batch_size, entry_itr)):
            result.append(predictor.run(batch))
            if idx % 100 == 0:
                print(idx)
        r = flatten_from_batches(result)
        print("len(r)", len(r))
        return r

    # iterate token reader and schedule task with promise keeper
    pk = PromiseKeeper(partial(predict_list, batch_size))
    result_list = []
    for idx, doc in enumerate(token_reader.iter_docs()):
        future_list = []
        for sent in doc:
            promise = MyPromise(sent, pk)
            future_list.append(promise.future())
        result_list.append(future_list)

        # Stops after the document with index 1000, i.e. at most 1001 docs.
        if idx == 1000:
            break

    # encode promise into the batches and run them
    pk.do_duty()

    r = []
    for future_list in result_list:
        r.append(list_future(future_list))

    # Use a context manager so the output file is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open(output_path, "wb") as out_file:
        pickle.dump(r, out_file)
示例#12
0
def gradient_rte_visulize():
    """Run ``rte_visualize`` on the model found under the name ``"RTE_A"``."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)
    vocab_file = "bert_voca.txt"
    checkpoint = loader.find_model_name("RTE_A")

    config = ExperimentConfig()
    config.name = "RTE_{}".format("visual")
    config.save_interval = 30 * 60  # 30 minutes
    config.load_names = ['bert', 'cls_dense']

    rte_data = rte.DataLoader(hparams.seq_max, vocab_file, True)
    experiment.rte_visualize(config, rte_data, checkpoint)
示例#13
0
def do_fetch_param():
    """Fetch parameters of the "abortion" run and pickle them.

    Calls ``fetch_params`` for the model stored under
    ``arg_nli_abortion_is_good`` and writes the resulting ``(names, values)``
    pair to ``params.pickle`` inside ``output_path``.
    """
    hp = hyperparams.HPBert()
    voca_size = 30522
    encode_opt = "is_good"
    topic = "abortion"
    load_run_name = "arg_nli_{}_is_good".format(topic)
    run_name = "arg_{}_{}_{}".format("fetch_grad", topic, encode_opt)
    data_loader = BertDataLoader(topic, True, hp.seq_max, "bert_voca.txt", option=encode_opt)
    model_path = get_model_full_path(load_run_name)
    # ``values`` rather than ``vars`` to avoid shadowing the builtin.
    names, values = fetch_params(hp, voca_size, run_name, data_loader, model_path)
    # ``output_path`` is not defined in this function; it must be provided by
    # the enclosing module.
    target = os.path.join(output_path, "params.pickle")
    # Context manager guarantees the file is closed after the dump.
    with open(target, "wb") as out_file:
        pickle.dump((names, values), out_file)
示例#14
0
文件: nli_main.py 项目: clover3/Chair
def train_nil():
    """Start ``train_nli`` for two epochs under the run name ``NLI_only_B``."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLI_only_{}".format("B")
    config.num_epoch = 2
    config.save_interval = 30 * 60  # 30 minutes

    batches = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    experiment.train_nli(setting, config, batches)
示例#15
0
def protest_bert():
    """Run ``train_protest`` for one epoch from the uncased BERT checkpoint."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    config = ExperimentConfig()
    config.name = "protest"
    config.num_epoch = 1
    config.save_interval = 1 * 60  # 1 minute
    config.load_names = ['bert']

    n_vocab = 30522
    vocab_file = "bert_voca.txt"
    batches = protest.DataLoader(hparams.seq_max, vocab_file, n_vocab)

    checkpoint = ("uncased_L-12_H-768_A-12", 'bert_model.ckpt')
    experiment.train_protest(config, batches, checkpoint)
示例#16
0
def train_rte():
    """Run ``train_rte`` for ten epochs from ``tlm_simple/model.ckpt-15000``."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)
    vocab_file = "bert_voca.txt"
    # The stock ("uncased_L-12_H-768_A-12", 'bert_model.ckpt') checkpoint was
    # used here before.
    checkpoint = ("tlm_simple", "model.ckpt-15000")

    config = ExperimentConfig()
    config.name = "RTE_{}".format("tlm_simple_15000")
    config.num_epoch = 10
    config.save_interval = 30 * 60  # 30 minutes
    config.load_names = ['bert']

    batches = rte.DataLoader(hparams.seq_max, vocab_file, True)
    experiment.train_rte(config, batches, checkpoint)
示例#17
0
文件: nli_main.py 项目: clover3/Chair
def pred_snli_ex():
    """Run ``predict_rf`` on the SNLI ``"test"`` data with ``SNLIEx_B``."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "SNLIEx_B"
    config.load_names = ['bert', 'cls_dense', 'aux_conflict']

    batches = nli.SNLIDataLoader(hparams.seq_max, setting.vocab_filename, True)

    checkpoint = ("SNLIEx_B", 'model-10275')
    experiment.predict_rf(setting, config, batches, checkpoint, "test")
示例#18
0
def train_nil_cold():
    """Train with ``train_nli_ex_0`` without a checkpoint, then run ``test_acc2``.

    ``None`` is passed as the load id and the value returned by training is
    handed to ``test_acc2``.
    """
    print('train_nil_cold')
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLI_Cold"
    config.num_epoch = 2
    config.save_interval = 30 * 60  # 30 minutes

    batches = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    trained = experiment.train_nli_ex_0(setting, config, batches, None, False)
    experiment.test_acc2(setting, config, batches, trained)
示例#19
0
文件: nli_main.py 项目: clover3/Chair
def pred_mnli_anyway():
    """Run ``predict_rf`` on ``mismatch_1000`` with ``NLIEx_AnyA/model-2785``."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLIEx_AnyA"
    config.load_names = ['bert', 'cls_dense', 'aux_conflict']

    batches = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    label = 'mismatch'
    split_id = "{}_1000".format(label)
    checkpoint = ("NLIEx_AnyA", 'model-2785')
    experiment.predict_rf(setting, config, batches, checkpoint, split_id)
示例#20
0
def wikicont_bert():
    """Run ``train_wiki_contrv`` with the data loader's collection set to 'wiki'."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    config = ExperimentConfig()
    config.name = "WikiContrv2009_only_wiki"
    config.num_epoch = 1
    # 3600 -- i.e. 60 minutes if the unit is seconds (not 1 as once noted).
    config.save_interval = 60 * 60
    config.load_names = ['bert']
    config.valid_freq = 100

    n_vocab = 30522
    vocab_file = "bert_voca.txt"
    batches = Ams18.DataLoader(hparams.seq_max, vocab_file, n_vocab)
    batches.source_collection.collection_type = 'wiki'

    checkpoint = ("uncased_L-12_H-768_A-12", 'bert_model.ckpt')
    experiment.train_wiki_contrv(config, batches, checkpoint)
示例#21
0
def train_nli_with_premade(explain_tag):
    """Run ``train_nli_ex_with_premade_data`` for ``explain_tag``.

    The run is named ``"NLIEx_Premade_<explain_tag>"`` and starts from the
    uncased BERT checkpoint, loading only the 'bert' names.
    """
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    run_suffix = "Premade_" + explain_tag
    config.name = "NLIEx_{}".format(run_suffix)
    config.num_epoch = 1
    config.save_interval = 30 * 60  # 30 minutes
    config.load_names = ['bert']

    batches = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    checkpoint = ("uncased_L-12_H-768_A-12", 'bert_model.ckpt')
    experiment.train_nli_ex_with_premade_data(setting, config, batches,
                                              checkpoint, explain_tag)
示例#22
0
def test_nli():
    """Run ``test_acc2`` on a hard-coded saved model path and print that path."""
    hp = hyperparams.HPBert()
    e = Experiment(hp)
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"
    e_config = ExperimentConfig()
    e_config.name = "NLI_400k_tlm_simple_wo_hint"
    e_config.num_epoch = 2
    e_config.save_interval = 30 * 60  # 30 minutes
    e_config.load_names = ['bert', 'cls_dense']
    data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    # Only the last of two successive ``saved`` assignments was ever used;
    # the overwritten one pointed at ".../runs/NLI_400k_tlm_wo_hint/model-0".
    # NOTE(review): the run name says "wo_hint" while the path below says
    # "simple_hint" -- confirm this pairing is intended.
    saved = '/mnt/scratch/youngwookim/Chair/output/model/runs/NLI_400k_tlm_simple_hint/model-0'
    print(saved)
    e.test_acc2(nli_setting, e_config, data_loader, saved)
示例#23
0
文件: nli_main.py 项目: clover3/Chair
def attribution_explain():
    """Run ``nli_attribution_baselines`` on the ``NLI_Only_A`` checkpoint."""
    hp = hyperparams.HPBert()
    e = Experiment(hp)
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"

    e_config = ExperimentConfig()
    e_config.name = "NLI_run_{}".format("nli_eval")
    e_config.num_epoch = 4
    e_config.save_interval = 30 * 60  # 30 minutes
    e_config.load_names = ['bert', 'cls_dense']

    data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    # An earlier assignment of ("NLI_run_nli_warm", "model-97332") was always
    # overwritten by this one, so only the effective checkpoint is kept.
    load_id = ("NLI_Only_A", 'model-0')
    e.nli_attribution_baselines(nli_setting, e_config, data_loader, load_id)
示例#24
0
文件: nli_main.py 项目: clover3/Chair
def train_nil_on_bert():
    """Run ``train_nli_ex_0`` for two epochs from the uncased BERT checkpoint."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLI_Only_{}".format("C")
    config.num_epoch = 2
    config.save_interval = 30 * 60  # 30 minutes

    batches = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    # Other starting points used before: None, ("NLI_bert_w_explain",
    # 'model-91531') and ("NLI_Only_A", "model-0").
    checkpoint = ("uncased_L-12_H-768_A-12", 'bert_model.ckpt')
    experiment.train_nli_ex_0(setting, config, batches, checkpoint, False)
示例#25
0
文件: nli_main.py 项目: clover3/Chair
def interactive():
    """Run ``nli_interactive_list`` with the ``NLI_Only_B`` checkpoint."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLIInterative"
    config.num_epoch = 1
    config.save_interval = 30 * 60  # 30 minutes
    config.load_names = ['bert', 'cls_dense']

    batches = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    checkpoint = ("NLI_Only_B", 'model-0')

    experiment.nli_interactive_list(setting, config, batches, checkpoint)
示例#26
0
文件: nli_main.py 项目: clover3/Chair
def analyze_nli_pair():
    """Run ``nli_visualization_pairing`` with the ``NLIEx_T`` checkpoint."""
    hparams = hyperparams.HPBert()
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLIEx_pair_analyze"
    config.num_epoch = 4
    config.save_interval = 30 * 60  # 30 minutes
    config.load_names = ['bert', 'cls_dense', 'aux_conflict']

    batches = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    checkpoint = ("NLIEx_T", "model-12097")
    # NOTE(review): ``data`` is not defined inside this function; it must
    # exist at module level or this call raises NameError -- confirm.
    experiment.nli_visualization_pairing(setting, config, batches, checkpoint,
                                         data)
示例#27
0
文件: nli_main.py 项目: clover3/Chair
def predict_lime_snli_continue():
    """Run ``Experiment.predict_lime_snli_continue`` on the SNLI ``"test"`` data.

    Uses a batch size of 768 (512 + 256) and the ``SNLI_Only_A`` checkpoint.
    """
    hparams = hyperparams.HPBert()
    hparams.batch_size = 512 + 256
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "SNLI_LIME_{}".format("eval")
    config.load_names = ['bert', 'cls_dense']

    batches = nli.SNLIDataLoader(hparams.seq_max, setting.vocab_filename, True)
    checkpoint = ("SNLI_Only_A", 'model-0')
    experiment.predict_lime_snli_continue(setting, config, batches,
                                          checkpoint, "test")
示例#28
0
def train_agree():
    """Run ``train_agree`` for two epochs from the uncased BERT checkpoint.

    The experiment purpose string is printed before and after training.
    """
    hp = hyperparams.HPBert()

    e_config = ExperimentConfig()
    e_config.num_epoch = 2
    # 6000 -- i.e. 100 minutes if the unit is seconds (not 30 as once noted).
    e_config.save_interval = 100 * 60
    e_config.voca_size = 30522
    e_config.load_names = ['bert']
    load_id = ("uncased_L-12_H-768_A-12", 'bert_model.ckpt')
    exp_purpose = "(dis)agree train"

    e = Experiment(hp)
    print(exp_purpose)
    e_config.name = "AgreeTrain"
    vocab_filename = "bert_voca.txt"
    data_loader = agree.DataLoader(hp.seq_max, vocab_filename)
    # The return value of train_agree was bound to a local that was never
    # read, so it is no longer stored.
    e.train_agree(e_config, data_loader, load_id)
    print(exp_purpose)
示例#29
0
文件: nli_main.py 项目: clover3/Chair
def do_test_dev_acc():
    """Run ``test_acc`` on the ``NLI_Only_C`` checkpoint."""
    hp = hyperparams.HPBert()
    e = Experiment(hp)
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"

    e_config = ExperimentConfig()
    e_config.name = "NLIEx_Test"
    e_config.load_names = ['bert', 'cls_dense']

    data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    # Only the last of three successive ``load_id`` assignments was ever
    # used; the overwritten ones were ("NLIEx_S", 'model-4417') and
    # ("NLIEx_Y_conflict", "model-9636").
    load_id = ("NLI_Only_C", 'model-0')

    e.test_acc(nli_setting, e_config, data_loader, load_id)
示例#30
0
文件: nli_main.py 项目: clover3/Chair
def interactive_visual():
    """Run ``nli_interactive_visual`` with the ``NLIEx_Y_conflict`` checkpoint."""
    hp = hyperparams.HPBert()
    e = Experiment(hp)
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"

    e_config = ExperimentConfig()
    e_config.name = "NLIInterative"
    e_config.num_epoch = 1
    e_config.save_interval = 30 * 60  # 30 minutes
    e_config.load_names = ['bert', 'cls_dense', 'aux_conflict']
    # An earlier assignment of ("NLIEx_U_mismatch", "model-10265") was always
    # overwritten by this one, so only the effective checkpoint is kept.
    load_id = ("NLIEx_Y_conflict", 'model-12039')

    data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename, True)

    e.nli_interactive_visual(nli_setting, e_config, data_loader, load_id)