예제 #1
0
def train_nli_smart_rf():
    """Launch ``Experiment.train_nli_smart`` for the 'match' explain tag.

    Builds ``HPSENLI`` hyper-parameters (``compare_deletion_num`` = 20), an
    ``NLI`` setting pointing at ``bert_voca.txt``, and an
    ``ExperimentConfig`` whose run name and save name embed the explain tag
    and loss type. Training is started from the checkpoint identifier
    ``("NLI_run_A", "model-0")`` with loss type 2.
    """
    # The tag was toggled by hand between 'match' and 'mismatch'.
    tag = 'match'
    loss_kind = 2

    params = hyperparams.HPSENLI()
    params.compare_deletion_num = 20
    experiment = Experiment(params)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLIEx_Hinge_" + tag
    config.num_epoch = 1
    config.ex_val = True
    config.save_interval = 30 * 60  # 30 minutes
    # 'aux_conflict' was also listed here at one point but is disabled.
    config.load_names = ['bert', 'cls_dense']
    config.save_eval = True
    config.save_name = "LossFn_{0}_{1}".format(loss_kind, tag)

    loader = nli.DataLoader(params.seq_max, setting.vocab_filename, True)
    checkpoint_id = ("NLI_run_A", 'model-0')

    print("Loss : ", loss_kind)

    experiment.train_nli_smart(
        setting, config, loader, checkpoint_id, tag, loss_kind)
예제 #2
0
파일: nli_main.py 프로젝트: clover3/Chair
def train_snli_ex():
    """Launch ``Experiment.train_nli_any_way`` on the SNLI data loader.

    Uses ``HPBert`` hyper-parameters (``compare_deletion_num`` = 20), the
    ``bert_voca.txt`` vocabulary, and the run name ``"SNLIEx_B"`` with
    ``ex_val`` disabled. Training starts from the checkpoint identifier
    ``("SNLI_Only_A", "model-0")``.
    """
    params = hyperparams.HPBert()
    params.compare_deletion_num = 20
    experiment = Experiment(params)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "SNLIEx_B"
    config.ex_val = False
    config.num_epoch = 1
    config.save_interval = 30 * 60  # 30 minutes
    # 'aux_conflict' was also listed here at one point but is disabled.
    config.load_names = ['bert', 'cls_dense']

    loader = nli.SNLIDataLoader(
        params.seq_max, setting.vocab_filename, True)

    # Alternative checkpoint ids kept for reference:
    #   ("NLI_run_nli_warm", "model-97332")
    #   ("NLIEx_A", "model-16910")
    #   ("uncased_L-12_H-768_A-12", 'bert_model.ckpt')
    #   ("NLIEx_D", "model-1964")
    #   ("NLIEx_D", "model-1317")
    checkpoint_id = ("SNLI_Only_A", 'model-0')

    experiment.train_nli_any_way(setting, config, loader, checkpoint_id)
예제 #3
0
def train_mnli_any_way():
    """Launch ``Experiment.train_nli_any_way`` from the ``nli512`` checkpoint.

    Uses ``HP`` hyper-parameters with ``batch_size`` = 8 and
    ``compare_deletion_num`` = 20, the ``bert_voca.txt`` vocabulary, and the
    run name ``"NLIEx_Any_512"`` with ``v2_load`` enabled. Training starts
    from the checkpoint identifier ``("nli512", "model.ckpt-65000")``.
    """
    params = HP()
    params.batch_size = 8
    params.compare_deletion_num = 20
    experiment = Experiment(params)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLIEx_Any_512"
    config.ex_val = False
    config.num_epoch = 1
    config.save_interval = 30 * 60  # 30 minutes
    # 'aux_conflict' was also listed here at one point but is disabled.
    config.load_names = ['bert', 'cls_dense']
    config.v2_load = True

    loader = nli.DataLoader(params.seq_max, setting.vocab_filename, True)
    checkpoint_id = ("nli512", 'model.ckpt-65000')

    experiment.train_nli_any_way(setting, config, loader, checkpoint_id)
예제 #4
0
파일: nli_main.py 프로젝트: clover3/Chair
def train_mnli_any_way():
    """Launch ``Experiment.train_nli_any_way`` from the ``NLI_run_A`` checkpoint.

    Uses ``HPBert`` hyper-parameters (``compare_deletion_num`` = 20), the
    ``bert_voca.txt`` vocabulary, and the run name ``"NLIEx_AnyA"`` with
    ``ex_val`` disabled. Training starts from the checkpoint identifier
    ``("NLI_run_A", "model-0")``.
    """
    params = hyperparams.HPBert()
    params.compare_deletion_num = 20
    experiment = Experiment(params)

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLIEx_AnyA"
    config.ex_val = False
    config.num_epoch = 1
    config.save_interval = 30 * 60  # 30 minutes
    # 'aux_conflict' was also listed here at one point but is disabled.
    config.load_names = ['bert', 'cls_dense']

    # No explain tag is passed to this entry point; 'match', 'mismatch',
    # 'conflict' and 'dontcare' appeared here only as commented-out values.
    loader = nli.DataLoader(params.seq_max, setting.vocab_filename, True)
    checkpoint_id = ("NLI_run_A", 'model-0')

    experiment.train_nli_any_way(setting, config, loader, checkpoint_id)
예제 #5
0
def train_nli_smart_rf(explain_tag):
    """Launch ``Experiment.train_nli_smart`` for the given explain tag.

    Both experiment loggers are raised to ``logging.WARNING`` before an
    info-level message is emitted on ``log``. The run is named
    ``"NLIEx_CO_<explain_tag>"``, has ``save_payload`` enabled, starts from
    the checkpoint identifier ``("NLI_run_A", "model-0")`` and passes 5 as
    the final argument to ``train_nli_smart``.

    Args:
        explain_tag: String tag appended to the run name and forwarded to
            ``train_nli_smart``.
    """
    params = hyperparams.HPSENLI()
    params.compare_deletion_num = 20
    experiment = Experiment(params)

    # Raise both loggers to WARNING, then emit an info-level message.
    for logger in (experiment.log, experiment.log2):
        logger.setLevel(logging.WARNING)
    experiment.log.info("I don't want to see")

    setting = NLI()
    setting.vocab_size = 30522
    setting.vocab_filename = "bert_voca.txt"

    config = ExperimentConfig()
    config.name = "NLIEx_" + ("CO_" + explain_tag)
    config.num_epoch = 1
    config.ex_val = False
    config.save_interval = 30 * 60  # 30 minutes
    # 'aux_conflict' was also listed here at one point but is disabled.
    config.load_names = ['bert', 'cls_dense']
    config.save_payload = True

    loader = nli.DataLoader(params.seq_max, setting.vocab_filename, True)
    checkpoint_id = ("NLI_run_A", 'model-0')

    experiment.train_nli_smart(
        setting, config, loader, checkpoint_id, explain_tag, 5)
예제 #6
0
파일: nli_main.py 프로젝트: clover3/Chair
def tuning_train_nli_rf():
    """Run ``Experiment.train_nli_smart`` once per candidate ``g_val``.

    For every value in the candidate list (currently only 0.9) the default
    TensorFlow graph is reset, ``HPSENLI`` hyper-parameters are created with
    ``g_val`` set to that value, and a 'match' run named
    ``"NLIEx_match_del_<value>"`` is started from the checkpoint identifier
    ``("NLI_run_A", "model-0")`` with 5 as the final argument.
    """
    g_val_candidates = [0.9]
    for g_del in g_val_candidates:
        # Start each run from a fresh default graph.
        tf.reset_default_graph()

        params = hyperparams.HPSENLI()
        params.g_val = g_del
        params.compare_deletion_num = 20
        experiment = Experiment(params)

        setting = NLI()
        setting.vocab_size = 30522
        setting.vocab_filename = "bert_voca.txt"

        config = ExperimentConfig()
        config.name = "NLIEx_match_del_{}".format(g_del)
        config.num_epoch = 1
        config.ex_val = True
        config.save_interval = 30 * 60  # 30 minutes
        # 'aux_conflict' was also listed here at one point but is disabled.
        config.load_names = ['bert', 'cls_dense']

        # Other tags seen here as alternatives: 'dontcare', 'mismatch',
        # 'conflict'.
        tag = 'match'

        loader = nli.DataLoader(
            params.seq_max, setting.vocab_filename, True)

        # Alternative checkpoint ids kept for reference:
        #   ("NLI_run_nli_warm", "model-97332")
        #   ("NLIEx_A", "model-16910")
        #   ("uncased_L-12_H-768_A-12", 'bert_model.ckpt')
        #   ("NLIEx_D", "model-1964")
        #   ("NLIEx_D", "model-1317")
        checkpoint_id = ("NLI_run_A", 'model-0')

        experiment.train_nli_smart(
            setting, config, loader, checkpoint_id, tag, 5)