Esempio n. 1
0
def train_snli_ex():
    """Train the SNLI explanation model, warm-starting from the SNLI-only run."""
    hp = hyperparams.HPBert()
    hp.compare_deletion_num = 20

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "SNLIEx_B"
    config.ex_val = False
    config.num_epoch = 1
    config.save_interval = 30 * 60  # 30 minutes
    config.load_names = ['bert', 'cls_dense']  #, 'aux_conflict']

    loader = nli.SNLIDataLoader(hp.seq_max, setting.vocab_filename, True)

    # Warm-start weights from the plain SNLI classifier checkpoint.
    checkpoint = ("SNLI_Only_A", 'model-0')
    experiment = Experiment(hp)
    experiment.train_nli_any_way(setting, config, loader, checkpoint)
Esempio n. 2
0
def get_snli_data(hp, nli_setting):
    """Return batched SNLI data, encoding from CSV only on a cache miss."""
    loader = nli.SNLIDataLoader(hp.seq_max, nli_setting.vocab_filename, True)

    tok = get_tokenizer()
    loader.CLS_ID = tok.convert_tokens_to_ids(["[CLS]"])[0]
    loader.SEP_ID = tok.convert_tokens_to_ids(["[SEP]"])[0]

    # Cache key encodes the two parameters that change the batch layout.
    cache_name = "snli_batch{}_seq{}".format(hp.batch_size, hp.seq_max)
    cached = load_cache(cache_name)
    if cached is not None:
        return cached

    tf_logger.info("Encoding data from csv")
    batches = get_nli_batches_from_data_loader(loader, hp.batch_size)
    save_to_pickle(batches, cache_name)
    return batches
Esempio n. 3
0
def pred_snli_ex():
    """Run explanation prediction on the SNLI test split with the SNLIEx_B model."""
    hp = hyperparams.HPBert()
    experiment = Experiment(hp)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "SNLIEx_B"
    config.load_names = ['bert', 'cls_dense', 'aux_conflict']

    loader = nli.SNLIDataLoader(hp.seq_max, setting.vocab_filename, True)
    checkpoint = ("SNLIEx_B", 'model-10275')
    experiment.predict_rf(setting, config, loader, checkpoint, "test")
Esempio n. 4
0
def predict_lime_snli_continue():
    """Resume LIME explanation prediction for SNLI on the test split."""
    hp = hyperparams.HPBert()
    hp.batch_size = 512 + 256

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "SNLI_LIME_{}".format("eval")
    config.load_names = ['bert', 'cls_dense']

    loader = nli.SNLIDataLoader(hp.seq_max, setting.vocab_filename, True)
    checkpoint = ("SNLI_Only_A", 'model-0')
    experiment = Experiment(hp)
    experiment.predict_lime_snli_continue(setting, config, loader, checkpoint,
                                          "test")
Esempio n. 5
0
def test_snli():
    """Measure test accuracy of the SNLI-only checkpoints."""
    hp = hyperparams.HPBert()
    experiment = Experiment(hp)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "SNLIEx_Test"
    config.load_names = ['bert', 'cls_dense']  # , 'aux_conflict']

    loader = nli.SNLIDataLoader(hp.seq_max, setting.vocab_filename, True)

    checkpoints = [
        ("SNLI_Only_A", 'model-0'),
        ("SNLI_Only_1", 'model-0'),
    ]
    for checkpoint in checkpoints:
        # Fresh graph per checkpoint so variables from the previous load
        # do not collide with the next one.
        tf.reset_default_graph()
        experiment.test_acc(setting, config, loader, checkpoint)
Esempio n. 6
0
def train_snli_on_bert():
    """Fine-tune the pretrained BERT checkpoint on SNLI (classification only,
    f_train_ex=False disables explanation training)."""
    hp = hyperparams.HPBert()
    e = Experiment(hp)
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"
    e_config = ExperimentConfig()
    e_config.name = "SNLI_Only_{}".format("1")
    e_config.num_epoch = 1
    e_config.save_interval = 3 * 60 * 60  # 3 hours (previous comment wrongly said 30 minutes)

    data_loader = nli.SNLIDataLoader(hp.seq_max, nli_setting.vocab_filename,
                                     True)
    #load_id = None
    # Start from the uncased BERT-base pretrained checkpoint.
    load_id = ("uncased_L-12_H-768_A-12", 'bert_model.ckpt')
    e.train_nli_ex_0(nli_setting,
                     e_config,
                     data_loader,
                     load_id,
                     f_train_ex=False)
Esempio n. 7
0
def predict_lime_snli():
    """Run LIME explanation prediction for SNLI over a 100-example slice.

    The slice start index is read from sys.argv[1].
    """
    hp = hyperparams.HPBert()
    hp.batch_size = 1024 + 512 + 256

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "SNLI_LIME_{}".format("eval")
    config.load_names = ['bert', 'cls_dense']

    begin = int(sys.argv[1])
    print("Begin", begin)
    window = (begin, begin + 100)

    loader = nli.SNLIDataLoader(hp.seq_max, setting.vocab_filename, True)
    checkpoint = ("SNLI_Only_A", 'model-0')
    experiment = Experiment(hp)
    experiment.predict_lime_snli(setting, config, loader, checkpoint, "test",
                                 window)
Esempio n. 8
0
def get_eval_params(load_type, model_path, data_type):
    """Build the inputs needed to evaluate an NLI model checkpoint.

    Args:
        load_type: "v2" selects load_bert_v2; any other value selects
            load_model.
        model_path: checkpoint path; its last two path components are
            joined into the run name.
        data_type: "mnli" or "snli" — selects which data loader to use.

    Returns:
        Tuple of (dev_batches, hp, load_fn, nli_setting, run_name).

    Raises:
        ValueError: if data_type is not "mnli" or "snli".
    """
    hp = hyperparams.HPSENLI3_eval()
    hp.batch_size = 128
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"
    if data_type == "mnli":
        data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename,
                                     True)
    elif data_type == "snli":
        data_loader = nli.SNLIDataLoader(hp.seq_max,
                                         nli_setting.vocab_filename, True)
    else:
        # Was `assert False`: asserts are stripped under `python -O`, which
        # would leave data_loader unbound and raise a confusing NameError.
        raise ValueError("Unknown data_type: {}".format(data_type))

    dir_path, file_name = os.path.split(model_path)
    run_name = os.path.split(dir_path)[1] + "/" + file_name
    dev_batches = get_batches_ex(data_loader.get_dev_data(), hp.batch_size, 4)
    if load_type == "v2":
        load_fn = load_bert_v2
    else:
        load_fn = load_model
    return dev_batches, hp, load_fn, nli_setting, run_name