def train_nli_smart_rf():
    """Launch ``Experiment.train_nli_smart`` with explain tag 'match' and loss type 2.

    The run is named ``NLIEx_Hinge_match`` and saved as ``LossFn_2_match``.
    ``ex_val`` and ``save_eval`` are enabled, and a save happens every 30
    minutes. Weights listed in ``load_names`` are loaded from
    ``("NLI_run_A", 'model-0')``.

    NOTE(review): this file later defines ``train_nli_smart_rf(explain_tag)``
    under the same name, which rebinds it at module level. This zero-argument
    version is therefore shadowed once that later definition executes.
    """
    tag = 'match'  # 'mismatch' was the other tag tried here
    loss_kind = 2

    hparams = hyperparams.HPSENLI()
    hparams.compare_deletion_num = 20
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    cfg = ExperimentConfig()
    cfg.name = "NLIEx_Hinge_" + tag
    cfg.num_epoch = 1
    cfg.ex_val = True
    cfg.save_interval = 1800  # seconds, i.e. 30 minutes
    cfg.load_names = ['bert', 'cls_dense']  # 'aux_conflict' is not loaded
    cfg.save_eval = True
    cfg.save_name = "LossFn_%d_%s" % (loss_kind, tag)

    loader = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    checkpoint = ("NLI_run_A", 'model-0')
    print("Loss : ", loss_kind)
    experiment.train_nli_smart(setting, cfg, loader, checkpoint, tag, loss_kind)
def train_snli_ex():
    """Run ``Experiment.train_nli_any_way`` on SNLI data under the name ``SNLIEx_B``.

    Hyperparameters come from ``hyperparams.HPBert`` with
    ``compare_deletion_num = 20``. The vocabulary is ``bert_voca.txt``
    (size 30522), and ``ex_val`` is off. The ``load_names`` ``bert`` and
    ``cls_dense`` are loaded from ``("SNLI_Only_A", 'model-0')``.
    """
    hparams = hyperparams.HPBert()
    hparams.compare_deletion_num = 20
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    cfg = ExperimentConfig()
    cfg.name = "SNLIEx_B"
    cfg.ex_val = False
    cfg.num_epoch = 1
    cfg.save_interval = 1800  # seconds, i.e. 30 minutes
    cfg.load_names = ['bert', 'cls_dense']  # 'aux_conflict' is not loaded

    loader = nli.SNLIDataLoader(hparams.seq_max, setting.vocab_filename, True)
    # Presumably (run directory, checkpoint name) -- confirm against Experiment.
    checkpoint = ("SNLI_Only_A", 'model-0')
    experiment.train_nli_any_way(setting, cfg, loader, checkpoint)
def train_mnli_any_way():
    """Run ``Experiment.train_nli_any_way`` on MNLI data as run ``NLIEx_Any_512``.

    Uses ``HP()`` hyperparameters with ``batch_size = 8`` and
    ``compare_deletion_num = 20``, and sets ``v2_load`` on the config. The
    ``load_names`` ``bert`` and ``cls_dense`` are loaded from
    ``("nli512", 'model.ckpt-65000')``.

    NOTE(review): a second ``def train_mnli_any_way()`` follows in this file
    and rebinds the name at module level, so this variant is shadowed once
    that later definition executes.
    """
    hparams = HP()
    hparams.batch_size = 8
    hparams.compare_deletion_num = 20
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    cfg = ExperimentConfig()
    cfg.name = "NLIEx_Any_512"
    cfg.ex_val = False
    cfg.num_epoch = 1
    cfg.save_interval = 1800  # seconds, i.e. 30 minutes
    cfg.load_names = ['bert', 'cls_dense']  # 'aux_conflict' is not loaded
    cfg.v2_load = True

    loader = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    checkpoint = ("nli512", 'model.ckpt-65000')
    experiment.train_nli_any_way(setting, cfg, loader, checkpoint)
def train_mnli_any_way():
    """Run ``Experiment.train_nli_any_way`` on MNLI data as run ``NLIEx_AnyA``.

    Hyperparameters come from ``hyperparams.HPBert`` with
    ``compare_deletion_num = 20``, and ``ex_val`` is off. The ``load_names``
    ``bert`` and ``cls_dense`` are loaded from ``("NLI_run_A", 'model-0')``.
    """
    hparams = hyperparams.HPBert()
    hparams.compare_deletion_num = 20
    experiment = Experiment(hparams)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    cfg = ExperimentConfig()
    cfg.name = "NLIEx_AnyA"
    cfg.ex_val = False
    cfg.num_epoch = 1
    cfg.save_interval = 1800  # seconds, i.e. 30 minutes
    cfg.load_names = ['bert', 'cls_dense']  # 'aux_conflict' is not loaded

    loader = nli.DataLoader(hparams.seq_max, setting.vocab_filename, True)
    checkpoint = ("NLI_run_A", 'model-0')
    experiment.train_nli_any_way(setting, cfg, loader, checkpoint)
def train_nli_smart_rf(explain_tag, *, loss_type=5):
    """Launch ``Experiment.train_nli_smart`` for the given explanation tag.

    The run is named ``NLIEx_CO_<explain_tag>``. ``ex_val`` is off,
    ``save_payload`` is on, and a save happens every 30 minutes. The
    ``load_names`` ``bert`` and ``cls_dense`` are loaded from
    ``("NLI_run_A", 'model-0')``.

    Args:
        explain_tag: Tag forwarded to ``train_nli_smart``. Values used
            elsewhere in this file include 'match', 'mismatch' and 'conflict'.
        loss_type: Loss selector forwarded to ``train_nli_smart``. It was
            previously hard-coded as ``5``, which remains the default, so
            existing callers are unaffected.
    """
    hp = hyperparams.HPSENLI()
    hp.compare_deletion_num = 20
    e = Experiment(hp)
    # Raise both experiment loggers to WARNING so INFO chatter is hidden.
    e.log.setLevel(logging.WARNING)
    e.log2.setLevel(logging.WARNING)
    # Sanity probe: with the level at WARNING this INFO record should be
    # filtered out (assuming e.log is a standard logging.Logger -- confirm).
    e.log.info("I don't want to see")

    nli_setting = NLI()
    nli_setting.vocab_size = 30522
    nli_setting.vocab_filename = "bert_voca.txt"

    e_config = ExperimentConfig()
    e_config.name = "NLIEx_CO_{}".format(explain_tag)
    e_config.num_epoch = 1
    e_config.ex_val = False
    e_config.save_interval = 30 * 60  # 30 minutes
    e_config.load_names = ['bert', 'cls_dense']  # 'aux_conflict' is not loaded
    e_config.save_payload = True

    data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename, True)
    load_id = ("NLI_run_A", 'model-0')
    e.train_nli_smart(nli_setting, e_config, data_loader, load_id, explain_tag, loss_type)
def tuning_train_nli_rf(g_values=(0.9,)):
    """Sweep ``hp.g_val`` over *g_values*, launching one training run per value.

    For each value, the default TensorFlow graph is reset, a fresh
    ``HPSENLI`` / ``Experiment`` pair is built with ``g_val`` set, and
    ``Experiment.train_nli_smart`` is called with explain tag 'match' and
    loss type 5. Each run is named ``NLIEx_match_del_<value>`` and loads the
    ``load_names`` ``bert`` and ``cls_dense`` from ``("NLI_run_A", 'model-0')``.

    Args:
        g_values: Iterable of values to assign to ``hp.g_val``. The default
            ``(0.9,)`` reproduces the previously hard-coded single-value
            sweep. An empty iterable launches nothing.
    """
    # Loop-invariant settings shared by every run in the sweep.
    explain_tag = 'match'  # other tags used in this file: 'mismatch', 'conflict'
    load_id = ("NLI_run_A", 'model-0')

    for g_del in g_values:
        # Start each run from a clean default graph.
        tf.reset_default_graph()
        hp = hyperparams.HPSENLI()
        hp.g_val = g_del
        hp.compare_deletion_num = 20
        e = Experiment(hp)

        nli_setting = NLI()
        nli_setting.vocab_size = 30522
        nli_setting.vocab_filename = "bert_voca.txt"

        e_config = ExperimentConfig()
        e_config.name = "NLIEx_{}_del_{}".format(explain_tag, g_del)
        e_config.num_epoch = 1
        e_config.ex_val = True
        e_config.save_interval = 30 * 60  # 30 minutes
        e_config.load_names = ['bert', 'cls_dense']  # 'aux_conflict' is not loaded

        data_loader = nli.DataLoader(hp.seq_max, nli_setting.vocab_filename, True)
        e.train_nli_smart(nli_setting, e_config, data_loader, load_id, explain_tag, 5)