def train_snli_ex():
    """Train the SNLI explanation model (SNLIEx_B) starting from the SNLI_Only_A checkpoint."""
    hp = hyperparams.HPBert()
    hp.compare_deletion_num = 20
    e = Experiment(hp)
    nli_setting = NLI()
    nli_setting.vocab_size = 30522  # BERT uncased WordPiece vocabulary size
    nli_setting.vocab_filename = "bert_voca.txt"
    e_config = ExperimentConfig()
    e_config.name = "SNLIEx_B"
    e_config.ex_val = False
    e_config.num_epoch = 1
    e_config.save_interval = 30 * 60  # 30 minutes
    e_config.load_names = ['bert', 'cls_dense']
    data_loader = nli.SNLIDataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    load_id = ("SNLI_Only_A", 'model-0')
    e.train_nli_any_way(nli_setting, e_config, data_loader, load_id)
def get_snli_data(hp, nli_setting):
    """Return SNLI batches, reading the pickle cache when present, else encoding from csv."""
    loader = nli.SNLIDataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    tokenizer = get_tokenizer()
    # Special-token ids the loader needs for sequence construction.
    loader.CLS_ID = tokenizer.convert_tokens_to_ids(["[CLS]"])[0]
    loader.SEP_ID = tokenizer.convert_tokens_to_ids(["[SEP]"])[0]
    cache_name = "snli_batch{}_seq{}".format(hp.batch_size, hp.seq_max)
    cached = load_cache(cache_name)
    if cached is not None:
        return cached
    tf_logger.info("Encoding data from csv")
    batches = get_nli_batches_from_data_loader(loader, hp.batch_size)
    save_to_pickle(batches, cache_name)
    return batches
def pred_snli_ex():
    """Run explanation prediction (predict_rf) for SNLIEx_B on the SNLI test split."""
    hp = hyperparams.HPBert()
    experiment = Experiment(hp)
    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"
    config = ExperimentConfig()
    config.name = "SNLIEx_B"
    config.load_names = ['bert', 'cls_dense', 'aux_conflict']
    loader = nli.SNLIDataLoader(hp.seq_max, setting.vocab_filename, True)
    checkpoint = ("SNLIEx_B", 'model-10275')
    experiment.predict_rf(setting, config, loader, checkpoint, "test")
def predict_lime_snli_continue():
    """Resume LIME explanation prediction on the SNLI test split."""
    hp = hyperparams.HPBert()
    hp.batch_size = 512 + 256
    experiment = Experiment(hp)
    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"
    config = ExperimentConfig()
    config.name = "SNLI_LIME_{}".format("eval")
    config.load_names = ['bert', 'cls_dense']
    loader = nli.SNLIDataLoader(hp.seq_max, setting.vocab_filename, True)
    checkpoint = ("SNLI_Only_A", 'model-0')
    experiment.predict_lime_snli_continue(setting, config, loader, checkpoint, "test")
def test_snli():
    """Measure accuracy of each listed SNLI checkpoint."""
    hp = hyperparams.HPBert()
    experiment = Experiment(hp)
    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"
    config = ExperimentConfig()
    config.name = "SNLIEx_Test"
    config.load_names = ['bert', 'cls_dense']
    loader = nli.SNLIDataLoader(hp.seq_max, setting.vocab_filename, True)
    checkpoints = [
        ("SNLI_Only_A", 'model-0'),
        ("SNLI_Only_1", 'model-0'),
    ]
    for checkpoint in checkpoints:
        # Start from a clean graph before loading each checkpoint.
        tf.reset_default_graph()
        experiment.test_acc(setting, config, loader, checkpoint)
def train_snli_on_bert():
    """Fine-tune the base BERT checkpoint on SNLI classification (explanation training off)."""
    hp = hyperparams.HPBert()
    e = Experiment(hp)
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"
    e_config = ExperimentConfig()
    e_config.name = "SNLI_Only_{}".format("1")
    e_config.num_epoch = 1
    e_config.save_interval = 3 * 60 * 60  # 3 hours (original comment wrongly said 30 minutes)
    data_loader = nli.SNLIDataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    load_id = ("uncased_L-12_H-768_A-12", 'bert_model.ckpt')
    e.train_nli_ex_0(nli_setting, e_config, data_loader, load_id, f_train_ex=False)
def predict_lime_snli():
    """Run LIME explanation prediction on a 100-example slice of the SNLI test split.

    The slice start index is read from sys.argv[1].
    """
    hp = hyperparams.HPBert()
    hp.batch_size = 1024 + 512 + 256
    experiment = Experiment(hp)
    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"
    config = ExperimentConfig()
    config.name = "SNLI_LIME_{}".format("eval")
    config.load_names = ['bert', 'cls_dense']
    start = int(sys.argv[1])
    print("Begin", start)
    sub_range = (start, start + 100)
    loader = nli.SNLIDataLoader(hp.seq_max, setting.vocab_filename, True)
    checkpoint = ("SNLI_Only_A", 'model-0')
    experiment.predict_lime_snli(setting, config, loader, checkpoint, "test", sub_range)
def get_eval_params(load_type, model_path, data_type):
    """Build the evaluation inputs for an NLI checkpoint.

    Args:
        load_type: "v2" selects load_bert_v2; any other value selects load_model.
        model_path: path to the checkpoint file; its parent directory names the run.
        data_type: "mnli" or "snli" — selects the data loader.

    Returns:
        Tuple of (dev_batches, hp, load_fn, nli_setting, run_name).

    Raises:
        ValueError: if data_type is neither "mnli" nor "snli".
    """
    hp = hyperparams.HPSENLI3_eval()
    hp.batch_size = 128
    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"
    if data_type == "mnli":
        data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    elif data_type == "snli":
        data_loader = nli.SNLIDataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    else:
        # `assert False` would be stripped under `python -O`; validate explicitly.
        raise ValueError("Unknown data_type: {}".format(data_type))
    dir_path, file_name = os.path.split(model_path)
    run_name = os.path.split(dir_path)[1] + "/" + file_name
    dev_batches = get_batches_ex(data_loader.get_dev_data(), hp.batch_size, 4)
    load_fn = load_bert_v2 if load_type == "v2" else load_model
    return dev_batches, hp, load_fn, nli_setting, run_name