Example #1
def predict(batch_size, data_path,
            dict_path, model_full_name,
            bucket=False, model_specific_dictionary=True,
            print_raw=False, dir_normalized=None, dir_original=None,
            get_batch_mode=False,
            normalization=True, debug=False, use_gpu=None, verbose=0):

    assert model_specific_dictionary, "ERROR : only model_specific_dictionary=True is supported for now"
    # NB : for now the dictionary must be loaded when evaluating (it cannot be recomputed ; this ability could be added to LexNormalizer)
    use_gpu = use_gpu_(use_gpu)
    hardware_choosen = "GPU" if use_gpu else "CPU"
    printing("{} mode ", var=([hardware_choosen]), verbose_level=0, verbose=verbose)

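    # outside debug mode, pdb.set_trace is monkey-patched into a no-op so that
    # breakpoints left in the code path are silently skipped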
    if not debug:
        pdb.set_trace = lambda: None

    model = LexNormalizer(generator=Generator, load=True, model_full_name=model_full_name,
                          voc_size=None, use_gpu=use_gpu, dict_path=dict_path, model_specific_dictionary=True,
                          dir_model=os.path.join(PROJECT_PATH, "checkpoints",
                                                 model_full_name + "-folder"),
                          char_decoding=True, word_decoding=False,
                          verbose=verbose
                          )

    data_read = conllu_data.read_data_to_variable(data_path, model.word_dictionary, model.char_dictionary,
                                                  model.pos_dictionary,
                                                  model.xpos_dictionary, model.type_dictionary,
                                                  use_gpu=use_gpu,
                                                  norm_not_norm=model.auxilliary_task_norm_not_norm,
                                                  symbolic_end=True, symbolic_root=True,
                                                  dry_run=0, lattice=False, verbose=verbose,
                                                  normalization=normalization,
                                                  bucket=bucket,
                                                  add_start_char=1, add_end_char=1)

    batchIter = data_gen_conllu(data_read, model.word_dictionary, model.char_dictionary,
                                batch_size=batch_size,
                                get_batch_mode=get_batch_mode,
                                normalization=normalization,
                                print_raw=print_raw,  verbose=verbose)
    model.eval()
    greedy_decode_batch(char_dictionary=model.char_dictionary, verbose=verbose,
                        gold_output=False,
                        use_gpu=use_gpu,
                        write_output=True,
                        label_data=REPO_DATASET[data_path],
                        batchIter=batchIter, model=model, dir_normalized=dir_normalized, dir_original=dir_original,
                        batch_size=batch_size)
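
# Hedged usage sketch (not part of the original code) : how predict() might be
# invoked on a trained LexNormalizer checkpoint. All paths and the model name are
# illustrative assumptions ; note that data_path must be a key of REPO_DATASET,
# since it is used to look up the evaluation label.
if __name__ == "__main__":
    predict(batch_size=2,
            data_path="./data/lexnorm-test.conll",  # assumed dataset path (registered in REPO_DATASET)
            dict_path="./checkpoints/MODEL-folder/dictionaries",  # assumed dictionary directory
            model_full_name="MODEL",  # assumed model id
            dir_normalized="./predictions/normalized.conll",
            dir_original="./predictions/original.conll",
            normalization=True,
            use_gpu=None,
            verbose=1)
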
Example #2
def train_eval(train_path,
               dev_path,
               model_id_pref,
               pos_specific_path=None,
               expand_vocab_dev_test=False,
               checkpointing_metric="loss-dev-all",
               n_epochs=11,
               test_path=None,
               args=None,
               overall_report_dir=CHECKPOINT_DIR,
               overall_label="DEFAULT",
               get_batch_mode_all=True,
               warmup=False,
               freq_checkpointing=1,
               debug=False,
               compute_scoring_curve=False,
               compute_mean_score_per_sent=False,
               print_raw=False,
               freq_scoring=5,
               bucketing_train=True,
               freq_writer=None,
               extend_n_batch=1,
               score_to_compute_ls=None,
               symbolic_end=False,
               symbolic_root=False,
               gpu=None,
               use_gpu=None,
               scoring_func_sequence_pred=DEFAULT_SCORING_FUNCTION,
               max_char_len=None,
               verbose=0):
    if gpu is not None and use_gpu_(use_gpu):
        assert use_gpu or use_gpu is None, "ERROR : use_gpu should be neutral (None) or True since 'gpu' is defined"
        #assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "ERROR : no CUDA_VISIBLE_DEVICES env variable (gpu should be None)"
        os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)
        printing("ENV : CUDA_VISIBLE_DEVICES set to {}",
                 var=[gpu],
                 verbose=verbose,
                 verbose_level=1)

    else:
        printing("CPU mode cause {} gpu arg or use_gpu detected {} ",
                 var=(gpu, use_gpu_(use_gpu)),
                 verbose_level=1,
                 verbose=verbose)

    hidden_size_encoder = args.get("hidden_size_encoder", 10)
    word_embed = args.get("word_embed", False)
    word_embedding_projected_dim = args.get("word_embedding_projected_dim",
                                            None)
    word_embedding_dim = args.get("word_embedding_dim", 0)
    mode_word_encoding = args.get("mode_word_encoding", "cat")
    char_level_embedding_projection_dim = args.get(
        "char_level_embedding_projection_dim", 0)

    output_dim = args.get("output_dim", 10)
    char_embedding_dim = args.get("char_embedding_dim", 10)
    hidden_size_sent_encoder = args.get("hidden_size_sent_encoder", 10)
    hidden_size_decoder = args.get("hidden_size_decoder", 10)
    batch_size = args.get("batch_size", 2)
    dropout_sent_encoder = args.get("dropout_sent_encoder", 0)
    dropout_word_encoder_cell = args.get("dropout_word_encoder_cell", 0)
    dropout_word_decoder = args.get("dropout_word_decoder", 0)
    n_layers_word_encoder = args.get("n_layers_word_encoder", 1)
    n_layers_sent_cell = args.get("n_layers_sent_cell", 1)
    dir_sent_encoder = args.get("dir_sent_encoder", 1)

    drop_out_word_encoder_out = args.get("drop_out_word_encoder_out", 0)
    drop_out_sent_encoder_out = args.get("drop_out_sent_encoder_out", 0)
    dropout_bridge = args.get("dropout_bridge", 0)

    word_recurrent_cell_encoder = args.get("word_recurrent_cell_encoder",
                                           "GRU")
    word_recurrent_cell_decoder = args.get("word_recurrent_cell_decoder",
                                           "GRU")
    dense_dim_auxilliary = args.get("dense_dim_auxilliary", None)
    dense_dim_auxilliary_2 = args.get("dense_dim_auxilliary_2", None)

    drop_out_char_embedding_decoder = args.get(
        "drop_out_char_embedding_decoder", 0)
    unrolling_word = args.get("unrolling_word", False)

    #auxilliary_task_norm_not_norm = args.get("auxilliary_task_norm_not_norm",False)
    char_src_attention = args.get("char_src_attention", False)
    weight_binary_loss = args.get("weight_binary_loss", 1)
    dir_word_encoder = args.get("dir_word_encoder", 1)
    shared_context = args.get("shared_context", "all")

    schedule_training_policy = args.get("policy", None)
    lr = args.get("lr", 0.001)
    gradient_clipping = args.get("gradient_clipping", None)

    teacher_force = args.get("teacher_force", True)
    proportion_pred_train = args.get("proportion_pred_train", None)
    if teacher_force and proportion_pred_train is not None:
        printing(
            "WARNING : inconsistent arguments solved :  proportion_pred_train forced to None while it was {} "
            "cause teacher_force mode",
            var=[proportion_pred_train],
            verbose=verbose,
            verbose_level=0)
        proportion_pred_train = None

    stable_decoding_state = args.get("stable_decoding_state", False)
    init_context_decoder = args.get("init_context_decoder", True)
    optimizer = args.get("optimizer", "adam")

    word_decoding = args.get("word_decoding", False)
    dense_dim_word_pred = args.get("dense_dim_word_pred", None)
    dense_dim_word_pred_2 = args.get("dense_dim_word_pred_2", None)
    dense_dim_word_pred_3 = args.get("dense_dim_word_pred_3", 0)
    word_embed_init = args.get("word_embed_init", None)

    char_decoding = args.get("char_decoding", True)

    #auxilliary_task_pos = args.get("auxilliary_task_pos", False)
    dense_dim_auxilliary_pos = args.get("dense_dim_auxilliary_pos", None)
    dense_dim_auxilliary_pos_2 = args.get("dense_dim_auxilliary_pos_2", None)

    activation_char_decoder = args.get("activation_char_decoder", None)
    activation_word_decoder = args.get("activation_word_decoder", None)
    tasks = args.get("tasks", ["normalize"])

    attention_tagging = args.get("attention_tagging", False)

    multi_task_loss_ponderation = args.get("multi_task_loss_ponderation",
                                           "all")
    dropout_input = args.get("dropout_input", None)

    n_epochs = WARMUP_N_EPOCHS if warmup else n_epochs

    if test_path is not None:
        assert isinstance(
            test_path, list
        ), "ERROR : test_path should be a list with one element per task"
        assert isinstance(
            test_path[0], list
        ), "ERROR : each element of test_path should be a list of dataset paths, even if of len 1"
    print(
        "WARNING : only datasets present in test_path will be evaluated (test_path : {}) "
        .format(test_path))
    if warmup:
        printing("Warm up : running 1 epoch ",
                 verbose=verbose,
                 verbose_level=0)
    printing("GRID : START TRAINING ", verbose_level=0, verbose=verbose)
    printing("SANITY CHECK : TASKS {} ",
             var=[tasks],
             verbose=verbose,
             verbose_level=1)
    normalization = "normalize" in tasks or "norm_not_norm" in tasks
    printing("SANITY CHECK : normalization {} ",
             var=normalization,
             verbose=verbose,
             verbose_level=1)
    model_full_name = train(
        train_path,
        dev_path,
        pos_specific_path=pos_specific_path,
        checkpointing_metric=checkpointing_metric,
        expand_vocab_dev_test=expand_vocab_dev_test
        if word_embed_init is not None else False,
        dense_dim_auxilliary=dense_dim_auxilliary,
        dense_dim_auxilliary_2=dense_dim_auxilliary_2,
        lr=lr,
        extend_n_batch=extend_n_batch,
        n_epochs=n_epochs,
        normalization=normalization,
        get_batch_mode_all=get_batch_mode_all,
        batch_size=batch_size,
        model_specific_dictionary=True,
        freq_writer=freq_writer,
        dict_path=None,
        model_dir=None,
        add_start_char=1,
        freq_scoring=freq_scoring,
        add_end_char=1,
        use_gpu=use_gpu,
        dir_sent_encoder=dir_sent_encoder,
        dropout_sent_encoder_cell=dropout_sent_encoder,
        dropout_word_encoder_cell=dropout_word_encoder_cell,
        dropout_word_decoder_cell=dropout_word_decoder,
        policy=schedule_training_policy,
        dir_word_encoder=dir_word_encoder,
        compute_mean_score_per_sent=compute_mean_score_per_sent,
        overall_label=overall_label,
        overall_report_dir=overall_report_dir,
        label_train=get_data_set_label(train_path),
        label_dev=get_data_set_label(dev_path),
        word_recurrent_cell_encoder=word_recurrent_cell_encoder,
        word_recurrent_cell_decoder=word_recurrent_cell_decoder,
        drop_out_sent_encoder_out=drop_out_sent_encoder_out,
        drop_out_char_embedding_decoder=drop_out_char_embedding_decoder,
        word_embedding_dim=word_embedding_dim,
        word_embed=word_embed,
        word_embedding_projected_dim=word_embedding_projected_dim,
        mode_word_encoding=mode_word_encoding,
        char_level_embedding_projection_dim=char_level_embedding_projection_dim,
        drop_out_word_encoder_out=drop_out_word_encoder_out,
        dropout_bridge=dropout_bridge,
        freq_checkpointing=freq_checkpointing,
        reload=False,
        model_id_pref=model_id_pref,
        score_to_compute_ls=score_to_compute_ls,
        mode_norm_ls=["all", "NEED_NORM", "NORMED"],
        hidden_size_encoder=hidden_size_encoder,
        output_dim=output_dim,
        char_embedding_dim=char_embedding_dim,
        extern_emb_dir=word_embed_init,
        hidden_size_sent_encoder=hidden_size_sent_encoder,
        hidden_size_decoder=hidden_size_decoder,
        n_layers_word_encoder=n_layers_word_encoder,
        n_layers_sent_cell=n_layers_sent_cell,
        compute_scoring_curve=compute_scoring_curve,
        unrolling_word=unrolling_word,
        char_src_attention=char_src_attention,
        print_raw=print_raw,
        debug=debug,
        shared_context=shared_context,
        bucketing=bucketing_train,
        weight_binary_loss=weight_binary_loss,
        teacher_force=teacher_force,
        proportion_pred_train=proportion_pred_train,
        clipping=gradient_clipping,
        tasks=tasks,
        optimizer=optimizer,
        #auxilliary_task_pos=auxilliary_task_pos,
        dense_dim_auxilliary_pos=dense_dim_auxilliary_pos,
        dense_dim_auxilliary_pos_2=dense_dim_auxilliary_pos_2,
        word_decoding=word_decoding,
        dense_dim_word_pred=dense_dim_word_pred,
        dense_dim_word_pred_2=dense_dim_word_pred_2,
        dense_dim_word_pred_3=dense_dim_word_pred_3,
        char_decoding=char_decoding,
        activation_char_decoder=activation_char_decoder,
        activation_word_decoder=activation_word_decoder,
        symbolic_end=symbolic_end,
        symbolic_root=symbolic_root,
        attention_tagging=attention_tagging,
        stable_decoding_state=stable_decoding_state,
        init_context_decoder=init_context_decoder,
        multi_task_loss_ponderation=multi_task_loss_ponderation,
        dropout_input=dropout_input,
        test_path=test_path[0] if isinstance(test_path, list) else test_path,
        max_char_len=max_char_len,
        checkpointing=True,
        verbose=verbose)

    model_dir = os.path.join(CHECKPOINT_DIR, model_full_name + "-folder")
    if test_path is not None:
        dict_path = os.path.join(CHECKPOINT_DIR, model_full_name + "-folder",
                                 "dictionaries")
        printing("GRID : START EVALUATION FINAL ",
                 verbose_level=0,
                 verbose=verbose)
        # NB : every dataset you want evaluated must be listed explicitly in test_path
        eval_data_paths = test_path
        #eval_data_paths = list(set(eval_data_paths))
        start_eval = time.time()
        if len(tasks) > 1:
            assert isinstance(eval_data_paths,
                              list), "ERROR : one element per task"
            assert isinstance(
                eval_data_paths[0], list
            ), "ERROR : in multitask mode eval_data_paths {} should be a list of lists, one sublist per task {} ".format(
                eval_data_paths, tasks)
        if len(tasks) == 1:
            tasks = [tasks[0] for _ in eval_data_paths]
        print("EVALUATING WITH {}".format(scoring_func_sequence_pred))
        for task, eval_data in zip(tasks, eval_data_paths):
            for eval_data_set in eval_data:
                printing("EVALUATING task {} on dataset {}",
                         var=[task, eval_data_set],
                         verbose=verbose,
                         verbose_level=1)
                evaluate(model_full_name=model_full_name,
                         data_path=eval_data_set,
                         dict_path=dict_path,
                         use_gpu=use_gpu,
                         label_report=REPO_DATASET[eval_data_set],
                         overall_label=overall_label + "-last",
                         score_to_compute_ls=score_to_compute_ls,
                         mode_norm_ls=["all", "NEED_NORM", "NORMED"],
                         normalization=normalization,
                         print_raw=print_raw,
                         model_specific_dictionary=True,
                         get_batch_mode_evaluate=False,
                         bucket=True,
                         compute_mean_score_per_sent=compute_mean_score_per_sent,
                         batch_size=batch_size,
                         debug=debug,
                         word_decoding=word_decoding,
                         char_decoding=char_decoding,
                         scoring_func_sequence_pred=scoring_func_sequence_pred,
                         evaluated_task=task,
                         tasks=tasks,
                         max_char_len=max_char_len,
                         dir_report=model_dir,
                         verbose=1)
            printing("GRID : END EVAL {:.3f}s ".format(time.time() - start_eval),
                     verbose=verbose,
                     verbose_level=1)
    else:
        printing("WARNING : no evaluation ", verbose=verbose, verbose_level=0)

    return model_full_name, model_dir
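
# Hedged usage sketch (an illustration, not from the source) : train_eval() reads
# most hyperparameters from the `args` dict via args.get(key, default), so a
# minimal run only needs data paths, a model id prefix and an args dict. The paths
# below are assumptions and must be registered in REPO_DATASET ; test_path is a
# list of lists, one sublist per task.
if __name__ == "__main__":
    args = {"batch_size": 10, "hidden_size_encoder": 50, "lr": 0.001,
            "tasks": ["normalize"]}  # omitted keys fall back to the defaults above
    model_full_name, model_dir = train_eval(
        train_path="./data/lexnorm-train.conll",  # assumed path
        dev_path="./data/lexnorm-dev.conll",  # assumed path
        model_id_pref="demo",
        test_path=[["./data/lexnorm-test.conll"]],  # assumed path
        n_epochs=2,
        args=args,
        verbose=1)
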
Example #3
def train(train_path,
          dev_path,
          n_epochs,
          normalization,
          dict_path=None,
          pos_specific_path=None,
          expand_vocab_dev_test=False,
          checkpointing_metric="loss-dev-all",
          batch_size=10,
          test_path=None,
          label_train="",
          label_dev="",
          use_gpu=None,
          lr=0.001,
          n_layers_word_encoder=1,
          n_layers_sent_cell=1,
          get_batch_mode_all=True,
          dropout_sent_encoder_cell=0,
          dropout_word_encoder_cell=0,
          dropout_word_decoder_cell=0,
          dropout_bridge=0,
          drop_out_word_encoder_out=0,
          drop_out_sent_encoder_out=0,
          dir_word_encoder=1,
          word_embed=False,
          word_embedding_dim=None,
          word_embedding_projected_dim=None,
          mode_word_encoding="cat",
          char_level_embedding_projection_dim=0,
          word_recurrent_cell_encoder=None,
          word_recurrent_cell_decoder=None,
          drop_out_char_embedding_decoder=0,
          hidden_size_encoder=None,
          output_dim=None,
          char_embedding_dim=None,
          hidden_size_decoder=None,
          hidden_size_sent_encoder=None,
          freq_scoring=5,
          compute_scoring_curve=False,
          score_to_compute_ls=None,
          mode_norm_ls=None,
          checkpointing=True,
          freq_checkpointing=None,
          freq_writer=None,
          model_dir=None,
          reload=False,
          model_full_name=None,
          model_id_pref="",
          print_raw=False,
          model_specific_dictionary=False,
          dir_sent_encoder=1,
          add_start_char=None,
          add_end_char=1,
          overall_label="DEFAULT",
          overall_report_dir=CHECKPOINT_DIR,
          compute_mean_score_per_sent=False,
          weight_binary_loss=1,
          dense_dim_auxilliary=None,
          dense_dim_auxilliary_2=None,
          unrolling_word=False,
          char_src_attention=False,
          debug=False,
          timing=False,
          dev_report_loss=True,
          bucketing=True,
          policy=None,
          teacher_force=True,
          proportion_pred_train=None,
          shared_context="all",
          clipping=None,
          extend_n_batch=1,
          stable_decoding_state=False,
          init_context_decoder=True,
          dense_dim_auxilliary_pos=None,
          dense_dim_auxilliary_pos_2=None,
          tasks=None,
          word_decoding=False,
          char_decoding=True,
          dense_dim_word_pred=None,
          dense_dim_word_pred_2=None,
          dense_dim_word_pred_3=None,
          symbolic_root=False,
          symbolic_end=False,
          extern_emb_dir=None,
          activation_word_decoder=None,
          activation_char_decoder=None,
          extra_arg_specific_label="",
          freezing_mode=False,
          freeze_ls_param_prefix=None,
          multi_task_loss_ponderation=None,
          max_char_len=None,
          attention_tagging=False,
          dropout_input=None,
          optimizer="adam",
          verbose=1):

    if multi_task_loss_ponderation is not None:
        sanity_check_loss_poneration(multi_task_loss_ponderation,
                                     verbose=verbose)
    if teacher_force:
        assert proportion_pred_train is None, "proportion_pred_train should be None in teacher_force mode"
    else:
        assert 100 > proportion_pred_train > 0, "proportion_pred_train should be between 0 and 100"
    auxilliary_task_norm_not_norm = "norm_not_norm" in tasks  # auxilliary_task_norm_not_norm
    auxilliary_task_pos = "pos" in tasks
    if "normalize" not in tasks:
        word_decoding = False
        char_decoding = False
    if not unrolling_word:
        assert not char_src_attention, "ERROR : char attention requires step-by-step unrolling"
    printing("WARNING bucketing is {} ",
             var=bucketing,
             verbose=verbose,
             verbose_level=1)
    if freq_writer is None:
        freq_writer = freq_checkpointing
        printing("REPORTING freq_writer set to freq_checkpointing {}",
                 var=[freq_checkpointing],
                 verbose=verbose,
                 verbose_level=1)
    if auxilliary_task_norm_not_norm:
        printing(
            "MODEL : training model with auxillisary task (loss weighted with {})",
            var=[weight_binary_loss],
            verbose=verbose,
            verbose_level=1)
    #if compute_scoring_curve:
    #assert score_to_compute_ls is not None and mode_norm_ls is not None and freq_scoring is not None, \
    #    "ERROR score_to_compute_ls and mode_norm_ls should not be None"
    use_gpu = use_gpu_(use_gpu)
    hardware_choosen = "GPU" if use_gpu else "CPU"
    printing("{} hardware mode ",
             var=([hardware_choosen]),
             verbose_level=0,
             verbose=verbose)
    freq_checkpointing = int(
        n_epochs / 10
    ) if checkpointing and freq_checkpointing is None else freq_checkpointing
    assert add_start_char == 1, "ERROR : add_start_char must be activated due to the decoding behavior of output_text_"
    printing("WARNING : add_start_char is {} and add_end_char {}  ".format(
        add_start_char, add_end_char),
             verbose=verbose,
             verbose_level=0)
    printing("TRAINING : checkpointing every {} epoch",
             var=freq_checkpointing,
             verbose=verbose,
             verbose_level=1)
    if reload:
        assert model_full_name is not None and len(
            model_id_pref
        ) == 0 and model_dir is not None and dict_path is not None
    else:
        assert model_full_name is None and model_dir is None

    if not debug:
        pdb.set_trace = lambda: None

    loss_training = []
    loss_developing = []
    # the template could not be reused because the variable is never reinitialized
    loss_details_template = {
        'loss_seq_prediction': [],
        'other': {},
        'loss_binary': [],
        'loss_overall': []
    } if auxilliary_task_norm_not_norm else None
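    # loss_details_template accumulates the per-epoch dev loss breakdown (sequence
    # prediction vs binary norm/not-norm loss) when the auxiliary task is active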

    # used to compute scores for early stopping (when checkpointing_metric != loss) and for curve plots
    evaluation_set_reporting = dev_path
    mode_norm_ls = ["all"]
    score_to_compute_ls = ["exact_match"]
    print(
        "WARNING : train.py is overwriting the mode_norm_ls and score_to_compute_ls arguments "
    )
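    # curve_scores maps keys of the form "<score>-<mode_norm>-<dataset label>"
    # (e.g. "exact_match-all-<REPO_DATASET[data]>") to per-epoch score lists used for plotting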
    curve_scores = {
        score + "-" + mode_norm + "-" + REPO_DATASET[data]: []
        for score in score_to_compute_ls for mode_norm in mode_norm_ls
        for data in evaluation_set_reporting
    } if compute_scoring_curve else None

    printing("WARNING :  lr {} ".format(lr, add_start_char, add_end_char),
             verbose=verbose,
             verbose_level=0)
    printing(
        "INFO : dictionary is (re)created from scratch on train_path {} and dev_path {}"
        .format(train_path, dev_path),
        verbose=verbose,
        verbose_level=1)

    if not model_specific_dictionary:
        word_dictionary, char_dictionary, pos_dictionary, \
        xpos_dictionary, type_dictionary = \
        conllu_data.load_dict(dict_path=dict_path,
                              train_path=train_path,
                              dev_path=dev_path,
                              test_path=test_path,
                              word_embed_dict={},
                              dry_run=False,
                              force_new_dic=True,
                              add_start_char=add_start_char, verbose=1)

        voc_size = len(char_dictionary.instance2index) + 1
        word_voc_input_size = len(word_dictionary.instance2index) + 1
        printing("DICTIONARY ; character vocabulary is len {} : {} ",
                 var=str(
                     len(char_dictionary.instance2index) + 1,
                     char_dictionary.instance2index),
                 verbose=verbose,
                 verbose_level=0)
        _train_path, _dev_path, _add_start_char = None, None, None
    else:
        voc_size = None
        word_voc_input_size = 0
        if not reload:
            # we need to feed the model the data so that it computes the model_specific_dictionary
            _train_path = train_path
            _dev_path = dev_path
            _test_path = test_path
            _add_start_char = add_start_char
        else:
            # as it reloads, we don't need the data
            _train_path, _dev_path, _test_path, _add_start_char = None, None, None, None

    model = LexNormalizer(
        generator=Generator,
        expand_vocab_dev_test=expand_vocab_dev_test,
        dense_dim_auxilliary=dense_dim_auxilliary,
        dense_dim_auxilliary_2=dense_dim_auxilliary_2,
        tasks=tasks,
        weight_binary_loss=weight_binary_loss,
        dense_dim_auxilliary_pos=dense_dim_auxilliary_pos,
        dense_dim_auxilliary_pos_2=dense_dim_auxilliary_pos_2,
        load=reload,
        char_embedding_dim=char_embedding_dim,
        voc_size=voc_size,
        dir_model=model_dir,
        use_gpu=use_gpu,
        dict_path=dict_path,
        word_recurrent_cell_decoder=word_recurrent_cell_decoder,
        word_recurrent_cell_encoder=word_recurrent_cell_encoder,
        train_path=_train_path,
        dev_path=_dev_path,
        pos_specific_path=pos_specific_path,
        add_start_char=_add_start_char,
        model_specific_dictionary=model_specific_dictionary,
        dir_word_encoder=dir_word_encoder,
        drop_out_sent_encoder_cell=dropout_sent_encoder_cell,
        drop_out_word_encoder_cell=dropout_word_encoder_cell,
        drop_out_word_decoder_cell=dropout_word_decoder_cell,
        drop_out_bridge=dropout_bridge,
        drop_out_char_embedding_decoder=drop_out_char_embedding_decoder,
        drop_out_word_encoder_out=drop_out_word_encoder_out,
        drop_out_sent_encoder_out=drop_out_sent_encoder_out,
        n_layers_word_encoder=n_layers_word_encoder,
        dir_sent_encoder=dir_sent_encoder,
        n_layers_sent_cell=n_layers_sent_cell,
        hidden_size_encoder=hidden_size_encoder,
        output_dim=output_dim,
        model_id_pref=model_id_pref,
        model_full_name=model_full_name,
        hidden_size_sent_encoder=hidden_size_sent_encoder,
        shared_context=shared_context,
        unrolling_word=unrolling_word,
        char_src_attention=char_src_attention,
        word_decoding=word_decoding,
        dense_dim_word_pred=dense_dim_word_pred,
        dense_dim_word_pred_2=dense_dim_word_pred_2,
        dense_dim_word_pred_3=dense_dim_word_pred_3,
        char_decoding=char_decoding,
        mode_word_encoding=mode_word_encoding,
        char_level_embedding_projection_dim=char_level_embedding_projection_dim,
        stable_decoding_state=stable_decoding_state,
        init_context_decoder=init_context_decoder,
        symbolic_root=symbolic_root,
        symbolic_end=symbolic_end,
        word_embed=word_embed,
        word_embedding_dim=word_embedding_dim,
        word_embedding_projected_dim=word_embedding_projected_dim,
        word_embed_dir=extern_emb_dir,
        word_voc_input_size=word_voc_input_size,
        teacher_force=teacher_force,
        activation_char_decoder=activation_char_decoder,
        activation_word_decoder=activation_word_decoder,
        test_path=_test_path,
        extend_vocab_with_test=_test_path is not None,
        attention_tagging=attention_tagging,
        multi_task_loss_ponderation=multi_task_loss_ponderation,  # needed for save/reloading purposes
        hidden_size_decoder=hidden_size_decoder,
        verbose=verbose,
        timing=timing)

    pos_batch = auxilliary_task_pos

    if use_gpu:
        model = model.cuda()
        printing("TYPE model is cuda : {} ",
                 var=(next(model.parameters()).is_cuda),
                 verbose=verbose,
                 verbose_level=4)
        #model.decoder.attn_layer = model.decoder.attn_layer.cuda()
    if not model_specific_dictionary:
        model.word_dictionary, model.char_dictionary, model.pos_dictionary, \
        model.xpos_dictionary, model.type_dictionary = word_dictionary, char_dictionary, pos_dictionary, xpos_dictionary, type_dictionary

    starting_epoch = model.arguments["info_checkpoint"][
        "n_epochs"] if reload else 1
    reloading = "" if not reload else " reloaded from " + str(starting_epoch)
    n_epochs += starting_epoch
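    # when reloading, training resumes at the checkpointed epoch and runs for n_epochs additional epochs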
    if freezing_mode:
        assert freeze_ls_param_prefix is not None, "freeze_ls_param_prefix should not be None"
        printing("TRAINING : freezing is on for layers {} ",
                 var=[freeze_ls_param_prefix],
                 verbose=verbose,
                 verbose_level=1)
        for name, param in model.named_parameters():
            for freeze_param in freeze_ls_param_prefix:
                if name.startswith(freeze_param):
                    param.requires_grad = False
                    printing("TRAINING : freezing {} parameter ",
                             var=[name],
                             verbose=verbose,
                             verbose_level=1)

    _loss_dev = 1000
    checkpoint_score_saved = 1000
    _loss_train = 1000
    counter_no_decrease = 0
    saved_epoch = 1
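    # 1000 is a sentinel "worst so far" value for the train/dev losses and the saved checkpoint score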
    if reload:
        printing(
            "TRAINING : RELOADED MODE , starting from checkpointed epoch {} ",
            var=starting_epoch,
            verbose_level=0,
            verbose=verbose)
    printing(
        "TRAINING : Running from {} to {} epochs : training on {} evaluating on {}",
        var=(starting_epoch, n_epochs, train_path, dev_path),
        verbose=verbose,
        verbose_level=0)
    starting_time = time.time()
    total_time = 0
    x_axis_epochs = []
    epoch_ls_dev = []
    epoch_ls_train = []

    train_path = [train_path] if isinstance(train_path, str) else train_path
    dev_path = [dev_path] if isinstance(dev_path, str) else dev_path

    readers_train = readers_load(
        datasets=train_path,
        tasks=tasks,
        word_dictionary=model.word_dictionary,
        word_dictionary_norm=model.word_nom_dictionary,
        char_dictionary=model.char_dictionary,
        pos_dictionary=model.pos_dictionary,
        xpos_dictionary=model.xpos_dictionary,
        type_dictionary=model.type_dictionary,
        use_gpu=use_gpu,
        norm_not_norm=auxilliary_task_norm_not_norm,
        word_decoder=word_decoding,
        add_start_char=add_start_char,
        add_end_char=add_end_char,
        symbolic_end=symbolic_end,
        symbolic_root=symbolic_root,
        bucket=bucketing,
        max_char_len=max_char_len,
        verbose=verbose)

    readers_dev = readers_load(datasets=dev_path,
                               tasks=tasks,
                               word_dictionary=model.word_dictionary,
                               word_dictionary_norm=model.word_nom_dictionary,
                               char_dictionary=model.char_dictionary,
                               pos_dictionary=model.pos_dictionary,
                               xpos_dictionary=model.xpos_dictionary,
                               type_dictionary=model.type_dictionary,
                               use_gpu=use_gpu,
                               norm_not_norm=auxilliary_task_norm_not_norm,
                               word_decoder=word_decoding,
                               add_start_char=add_start_char,
                               add_end_char=add_end_char,
                               symbolic_end=symbolic_end,
                               symbolic_root=symbolic_root,
                               bucket=bucketing,
                               max_char_len=max_char_len,
                               verbose=verbose)

    dir_writer = os.path.join(overall_report_dir, "runs",
                              "{}-model".format(model.model_full_name))
    writer = SummaryWriter(log_dir=dir_writer)
    printing(
        "REPORT : run \ntensorboard --logdir={} --host=localhost --port=9101 "
        "(run tensorboard remotely : sh $EXPERIENCE/track/run_tensorboard_serveo.sh $log_dir $port )  ",
        var=[dir_writer],
        verbose=verbose,
        verbose_level=1)
    printing("REPORT : summary writer will be located {}",
             var=[dir_writer],
             verbose_level=1,
             verbose=verbose)
    step_train = 0
    step_dev = 0
    if ADAPTABLE_SCORING:
        printing("WARNING : scoring epochs not regular (more at the begining ",
                 verbose_level=1,
                 verbose=verbose)
        freq_scoring = 1
    checkpoint_dir_former = None

    for epoch in tqdm(range(starting_epoch, n_epochs),
                      disable_tqdm_level(verbose=verbose, verbose_level=0)):
        index_look = 25
        #parameters = filter(lambda p: p.requires_grad, model.parameters())
        decay_rate = 1
        opt = dptx.get_optimizer(model.parameters(),
                                 lr=lr * decay_rate**epoch,
                                 optimizer=optimizer)
        assert policy in AVAILABLE_SCHEDULING_POLICIES
        policy_dic = eval(policy)(epoch) if policy is not None else None
        # TODO : no need to re-output multi_task_mode : tasks should be harmonized instead
        multi_task_mode, ponderation_normalize_loss, weight_binary_loss, weight_pos_loss = scheduling_policy(
            epoch=epoch, phases_ls=policy_dic, tasks=tasks)
        printing(
            "TRAINING Tasks scheduling : ponderation_normalize_loss is {} weight_binary_loss is {}"
            " weight_pos_loss is {} mode is {} ",
            var=[
                ponderation_normalize_loss, weight_binary_loss,
                weight_pos_loss, multi_task_mode
            ],
            verbose=verbose,
            verbose_level=2)

        printing("TRAINING : Starting {} epoch out of {} ",
                 var=(epoch + 1, n_epochs),
                 verbose=verbose,
                 verbose_level=1)
        model.train()
        #batchIter = data_gen_conllu(data_read_train,model.word_dictionary, model.char_dictionary,normalization=normalization,get_batch_mode=get_batch_mode_all,batch_size=batch_size, extend_n_batch=extend_n_batch,print_raw=print_raw, timing=timing, pos_dictionary=model.pos_dictionary,verbose=verbose)
        batchIter = data_gen_multi_task_sampling_batch(
            tasks=tasks,
            readers=readers_train,
            batch_size=batch_size,
            word_dictionary=model.word_dictionary,
            char_dictionary=model.char_dictionary,
            pos_dictionary=model.pos_dictionary,
            word_dictionary_norm=model.word_nom_dictionary,
            get_batch_mode=get_batch_mode_all,
            extend_n_batch=extend_n_batch,
            dropout_input=dropout_input,
            verbose=verbose)
        start = time.time()
        printing(
            "TRAINING : TEACHER FORCE : Schedule Sampling proportion of train on prediction is {} ",
            var=[proportion_pred_train],
            verbose=verbose,
            verbose_level=2)

        #rep_tl.checkout_layer_name("encoder.seq_encoder.weight_ih_l0", model.named_parameters(), info_epoch=epoch)

        loss_train, loss_details_train, step_train = run_epoch(
            batchIter,
            model,
            LossCompute(
                model.generator,
                opt=opt,
                multi_task_loss_ponderation=model.multi_task_loss_ponderation,
                auxilliary_task_norm_not_norm=auxilliary_task_norm_not_norm,
                model=model,
                writer=writer,
                use="train",
                use_gpu=use_gpu,
                verbose=verbose,
                tasks=tasks,
                char_decoding=char_decoding,
                word_decoding=word_decoding,
                pos_pred=auxilliary_task_pos,
                vocab_char_size=len(
                    list(model.char_dictionary.instance2index.keys())) + 1,
                timing=timing),
            verbose=verbose,
            i_epoch=epoch,
            multi_task_mode=multi_task_mode,
            n_epochs=n_epochs,
            timing=timing,
            weight_binary_loss=weight_binary_loss,
            weight_pos_loss=weight_pos_loss,
            ponderation_normalize_loss=ponderation_normalize_loss,
            step=step_train,
            clipping=clipping,
            pos_batch=pos_batch,
            proportion_pred_train=proportion_pred_train,
            log_every_x_batch=100)

        writer_weights_and_grad(model=model,
                                freq_writer=freq_writer,
                                epoch=epoch,
                                writer=writer,
                                verbose=verbose)

        _train_ep_time, start = get_timing(start)
        model.eval()
        # TODO : should be moved under the freq_checkpointing condition, otherwise useless
        #batchIter_eval = data_gen_conllu(data_read_dev,model.word_dictionary, model.char_dictionary,batch_size=batch_size, get_batch_mode=False,normalization=normalization, extend_n_batch=1,pos_dictionary=model.pos_dictionary, verbose=verbose)
        batchIter_eval = data_gen_multi_task_sampling_batch(
            tasks=tasks,
            readers=readers_dev,
            batch_size=batch_size,
            word_dictionary=model.word_dictionary,
            char_dictionary=model.char_dictionary,
            word_dictionary_norm=model.word_nom_dictionary,
            pos_dictionary=model.pos_dictionary,
            dropout_input=0,
            extend_n_batch=1,
            get_batch_mode=False,
            verbose=verbose)
        _create_iter_time, start = get_timing(start)
        # TODO : should be able to factorize this into a single run_epoch() for train and dev (the computation should be the same)
        # TODO : should not evaluate at every epoch : evaluate every x epochs, check that the loss decreases and checkpoint
        if (dev_report_loss and
            (epoch % freq_checkpointing == 0)) or (epoch + 1 == n_epochs):
            printing("EVALUATION : computing loss on dev epoch {}  ",
                     var=epoch,
                     verbose=verbose,
                     verbose_level=1)
            loss_obj = LossCompute(
                model.generator,
                use_gpu=use_gpu,
                verbose=verbose,
                multi_task_loss_ponderation=model.multi_task_loss_ponderation,
                writer=writer,
                use="dev",
                vocab_char_size=len(
                    list(model.char_dictionary.instance2index.keys())) + 1,
                pos_pred=auxilliary_task_pos,
                tasks=tasks,
                char_decoding=char_decoding,
                word_decoding=word_decoding,
                auxilliary_task_norm_not_norm=auxilliary_task_norm_not_norm)
            loss_dev, loss_details_dev, step_dev = run_epoch(
                batchIter_eval,
                model,
                loss_compute=loss_obj,
                i_epoch=epoch,
                n_epochs=n_epochs,
                verbose=verbose,
                timing=timing,
                step=step_dev,
                weight_binary_loss=weight_binary_loss,
                ponderation_normalize_loss=ponderation_normalize_loss,
                weight_pos_loss=weight_pos_loss,
                pos_batch=pos_batch,
                log_every_x_batch=100)

            loss_developing.append(loss_dev)
            epoch_ls_dev.append(epoch)

            if auxilliary_task_norm_not_norm:
                # in this case we report the loss details
                for ind, loss_key in enumerate(loss_details_dev.keys()):
                    if loss_key != "other":
                        loss_details_template[loss_key].append(
                            loss_details_dev[loss_key])
            else:
                loss_details_template = None

        _eval_time, start = get_timing(start)
        loss_training.append(loss_train)
        epoch_ls_train.append(epoch)
        time_per_epoch = time.time() - starting_time
        total_time += time_per_epoch
        starting_time = time.time()

        # computing exact/edit score
        exact_only = False
        overall_report_ls = None
        # MODIFIED FREQ SCORING TO FREQ CHECKPOINTING

        if compute_scoring_curve and (
            (epoch %
             (freq_checkpointing if checkpointing_metric != "loss-dev-all" else
              freq_scoring) == 0) or (epoch + 1 == n_epochs)):
            if epoch < 1 and ADAPTABLE_SCORING:
                freq_scoring *= 5
            if epoch == 5 and ADAPTABLE_SCORING:
                freq_scoring *= 3
            # assumed intent for the next step : the condition as originally written
            # (epoch > 14 and epoch < 15) can never hold for an integer epoch
            if epoch == 15 and ADAPTABLE_SCORING:
                freq_scoring *= 2
            if epoch + 1 == n_epochs:
                printing("EVALUATION : final scoring ",
                         verbose=verbose,
                         verbose_level=0)
            x_axis_epochs.append(epoch)
            printing("EVALUATION : Computing score on {} and {}  ",
                     var=(score_to_compute_ls, mode_norm_ls),
                     verbose=verbose,
                     verbose_level=1)
            overall_report_ls = []
            for task, eval_data in zip(tasks, evaluation_set_reporting):
                eval_label = REPO_DATASET[eval_data]
                assert len(set(evaluation_set_reporting)) == len(evaluation_set_reporting),\
                    "ERROR : the same dataset was provided twice for reporting, which would mess up the loss"
                printing("EVALUATION on {} ",
                         var=[eval_data],
                         verbose=verbose,
                         verbose_level=1)
                scores = evaluate(
                    data_path=eval_data,
                    use_gpu=use_gpu,
                    overall_label=overall_label,
                    overall_report_dir=overall_report_dir,
                    score_to_compute_ls=score_to_compute_ls,
                    mode_norm_ls=mode_norm_ls,
                    label_report=eval_label,
                    model=model,
                    normalization=normalization,
                    print_raw=False,
                    model_specific_dictionary=True,
                    get_batch_mode_evaluate=False,
                    compute_mean_score_per_sent=compute_mean_score_per_sent,
                    batch_size=batch_size,
                    word_decoding=word_decoding,
                    dir_report=model.dir_model,
                    debug=debug,
                    evaluated_task=task,
                    tasks=tasks,
                    verbose=verbose)
                # we keep everything here in case we want a fancier early-stopping metric
                overall_report_ls.extend(scores)

                # dirty but does the job
                exact_only = True
                DEPRECATED = False
                if DEPRECATED:
                    curve_scores = update_curve_dic(
                        score_to_compute_ls=score_to_compute_ls,
                        mode_norm_ls=mode_norm_ls,
                        eval_data=eval_label,
                        former_curve_scores=curve_scores,
                        scores=scores,
                        exact_only=exact_only)
                    curve_ls_tuple = [
                        (loss_ls, label)
                        for label, loss_ls in curve_scores.items()
                        if isinstance(loss_ls, list)
                    ]
                    curves = [tupl[0] for tupl in curve_ls_tuple]
                    val_ls = [
                        tupl[1] + "({}tok)".format(info_token)
                        for tupl in curve_ls_tuple
                        for data, info_token in curve_scores.items()
                        if not isinstance(info_token, list)
                        if tupl[1].endswith(data)
                    ]
            score_to_compute_ls = ["exact"
                                   ] if exact_only else score_to_compute_ls
            if DEPRECATED:
                for score_plot in score_to_compute_ls:
                    # dirty but does the job
                    print(val_ls)
                    if exact_only:
                        val_ls = [
                            val for val in val_ls
                            if val.startswith("exact-all")
                            or val.startswith("exact-NORMED")
                            or val.startswith("exact-NEED_NORM")
                        ]
                        #val_ls = ["{}-all-{}".format(metric,REPO_DATASET[eval]) for eval in evaluation_set_reporting for metric in ["exact", "edit"]]
                        curves = [curve for curve in curves if len(curve) > 0]

                    simple_plot_ls(losses_ls=curves,
                                   labels=val_ls,
                                   final_loss="",
                                   save=True,
                                   filter_by_label=score_plot,
                                   x_axis=x_axis_epochs,
                                   dir=model.dir_model,
                                   prefix=model.model_full_name,
                                   epochs=str(epoch) + reloading,
                                   verbose=verbose,
                                   lr=lr,
                                   label_color_0=REPO_DATASET[
                                       evaluation_set_reporting[0]],
                                   label_color_1=REPO_DATASET[
                                       evaluation_set_reporting[1]])

        # WARNING : we only save when the metric improves ; we do not load the former model when reloading
        if (checkpointing
                and epoch % freq_checkpointing == 0) or (epoch + 1
                                                         == n_epochs):
            if checkpointing_metric != "loss-dev-all" and epoch < STARTING_CHECKPOINTING_WITH_SCORE:
                _checkpointing_metric = "loss-dev-all"
            elif checkpointing_metric != "loss-dev-all":
                _checkpointing_metric = checkpointing_metric
                if epoch == STARTING_CHECKPOINTING_WITH_SCORE:
                    checkpoint_score_saved = -report["score"]
                    printing("Checkoint info : switching "
                             "checkpoint_score_saved to {} : {}".format(
                                 checkpointing_metric, checkpoint_score_saved),
                             verbose_level=1,
                             verbose=verbose)
            elif checkpointing_metric == "loss-dev-all":
                _checkpointing_metric = checkpointing_metric
            else:
                raise (Exception("You missed a case"))

            dir_plot_detailed = simple_plot(
                final_loss=0,
                epoch_ls_1=epoch_ls_dev,
                epoch_ls_2=epoch_ls_dev,
                loss_2=loss_details_template.get("loss_binary", None),
                loss_ls=loss_details_template["loss_seq_prediction"],
                epochs=str(epoch) + reloading,
                label="dev-seq_prediction",
                label_2="dev-binary",
                save=True,
                dir=model.dir_model,
                verbose=verbose,
                verbose_level=1,
                lr=lr,
                prefix=model.model_full_name + "-details",
                show=False) if loss_details_template is not None else None

            dir_plot = simple_plot(final_loss=loss_train,
                                   loss_2=loss_developing,
                                   loss_ls=loss_training,
                                   epochs=str(epoch) + reloading,
                                   epoch_ls_1=epoch_ls_train,
                                   epoch_ls_2=epoch_ls_dev,
                                   label=label_train + "-train",
                                   label_2=label_dev + "-dev",
                                   save=True,
                                   dir=model.dir_model,
                                   verbose=verbose,
                                   verbose_level=1,
                                   lr=lr,
                                   prefix=model.model_full_name,
                                   show=False)

            sanity_check_checkpointing_metric(
                tasks, checkpointing_metric=_checkpointing_metric)

            if _checkpointing_metric != "loss-dev-all" or \
                    (epoch == (STARTING_CHECKPOINTING_WITH_SCORE-1) and checkpointing_metric != "loss-dev-all"):
                # for now only useful when different from the loss --> compute the metric on dev "all" by default
                # assuming a single task, thanks to the sanity check above
                assert overall_report_ls is not None, "ERROR : overall_report_ls was not defined"
                report = rep_tl.get_score(
                    overall_report_ls,
                    metric=TASKS_PARAMETER[tasks[0]].get("default_metric"),
                    data=REPO_DATASET[dev_path[0]],
                    info_score="all",
                    task=tasks[0])
                # Negated because it's an accuracy (checkpointing minimizes the score)
                checkpoint_score = -report["score"]
            else:
                checkpoint_score = loss_dev

            model, checkpoint_score_saved, counter_no_decrease, saved_epoch, checkpoint_dir_former = \
                    checkpoint(loss_saved=checkpoint_score_saved, loss=checkpoint_score,
                               checkpointing_metric=_checkpointing_metric,
                               model=model, counter_no_decrease=counter_no_decrease,
                               checkpoint_dir_former=checkpoint_dir_former,
                               saved_epoch=saved_epoch, model_dir=model.dir_model,
                               extra_checkpoint_label="1st_train" if not reload else "start_{}_ep-{}".format(starting_epoch, extra_arg_specific_label),
                               extra_arg_specific_label=extra_arg_specific_label,
                               info_checkpoint={"n_epochs": epoch, "batch_size": batch_size, "optimizer": optimizer,
                                                "gradient_clipping": clipping,
                                                "tasks_schedule_policy": policy,
                                                "teacher_force": teacher_force,
                                                "proportion_pred_train": proportion_pred_train,
                                                "train_data_path": train_path, "dev_data_path": dev_path,
                                                "other": {"error_curves": dir_plot, "loss": loss_dev,
                                                          "sanity_test": {"loss": loss_dev,
                                                                          "data": [REPO_DATASET[_dev_path] for _dev_path in dev_path],
                                                                          "batch_size": batch_size},
                                                          "error_curves_details": dir_plot_detailed,
                                                          "dropout_input": dropout_input,
                                                          "checkpointing_metric": _checkpointing_metric,
                                                          "multi_task_loss_ponderation": multi_task_loss_ponderation,
                                                          "weight_binary_loss": weight_binary_loss*int(auxilliary_task_norm_not_norm),
                                                          "weight_pos_loss": weight_pos_loss*int(auxilliary_task_pos),
                                                          "ponderation_normalize_loss": ponderation_normalize_loss,
                                                          "data": "dev", "seed(np/torch)": (SEED_NP, SEED_TORCH),
                                                          "extend_n_batch": extend_n_batch,
                                                          "lr": lr, "optim_strategy": "lr_constant",
                                                          "time_training(min)": "{0:.2f}".format(total_time/60),
                                                          "average_per_epoch(min)": "{0:.2f}".format((total_time/n_epochs)/60)}},
                               epoch=epoch, epochs=n_epochs-1,
                               keep_all_checkpoint=False if epoch > starting_epoch else True,# we have nothing to remove after 1st epoch
                               verbose=verbose)
            if counter_no_decrease * freq_checkpointing >= BREAKING_NO_DECREASE:
                printing(
                    "CHECKPOINTING : Breaking training : loss did not decrease on dev for 10 checkpoints "
                    "so keeping model from {} epoch  ".format(saved_epoch),
                    verbose=verbose,
                    verbose_level=0)
                break
        printing(
            "LOSS train {:.3f}, dev {:.3f} for epoch {} out of {} epochs ",
            var=(loss_train, loss_dev, epoch, n_epochs),
            verbose=verbose,
            verbose_level=1)

        if timing:
            print("Summary : {}".format(
                OrderedDict([("_train_ep_time", _train_ep_time),
                             ("_create_iter_time", _create_iter_time),
                             ("_eval_time", _eval_time)])))

    writer.close()
    printing(
        "REPORT : run : \n tensorboard --logdir={} --host=localhost --port=9101  ",
        var=[dir_writer],
        verbose=verbose,
        verbose_level=1)

    #rep_tl.checkout_layer_name("encoder.seq_encoder.weight_ih_l0", model.named_parameters(), info_epoch="LAST")

    simple_plot(final_loss=loss_dev,
                loss_ls=loss_training,
                loss_2=loss_developing,
                epoch_ls_1=epoch_ls_train,
                epoch_ls_2=epoch_ls_dev,
                epochs=n_epochs,
                save=True,
                dir=model.dir_model,
                label=label_train,
                label_2=label_dev,
                lr=lr,
                prefix=model.model_full_name + "-LAST",
                verbose=verbose)

    return model.model_full_name
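
# Hedged usage sketch (an illustration, not from the source) : a minimal direct
# call to train(). add_start_char=1 is mandatory (asserted above), tasks must be
# provided since it is iterated over, and freq_checkpointing is set explicitly
# because the default int(n_epochs / 10) would be 0 for a short run. The paths
# are assumptions and must be keys of REPO_DATASET.
if __name__ == "__main__":
    model_full_name = train(train_path="./data/lexnorm-train.conll",  # assumed path
                            dev_path="./data/lexnorm-dev.conll",  # assumed path
                            n_epochs=2,
                            normalization=True,
                            tasks=["normalize"],
                            model_id_pref="demo",
                            batch_size=10,
                            add_start_char=1,
                            add_end_char=1,
                            freq_checkpointing=1,
                            hidden_size_encoder=50,
                            output_dim=50,
                            char_embedding_dim=20,
                            hidden_size_decoder=50,
                            hidden_size_sent_encoder=50,
                            verbose=1)
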
Example #4
        args.tasks, args.train_path)
    assert len(args.dev_path) == len(args.train_path)

    if run_mode == "test":
        assert args.test_paths is not None and isinstance(
            args.test_paths, list)
    if args.test_paths is not None:
        assert isinstance(args.test_paths, list) and isinstance(
            args.test_paths[0], list), "ERROR args.test_paths should be a list"

    if run_mode == "train":
        printing("CHECKPOINTING info : saving model every {}",
                 var=saving_every_epoch,
                 verbose=verbose,
                 verbose_level=1)
    use_gpu = use_gpu_(use_gpu=None, verbose=verbose)

    train_data_label = "|".join([
        REPO_DATASET.get(_train_path, "train_{}".format(i))
        for i, _train_path in enumerate(args.train_path)
    ])
    dev_data_label = "|".join([
        REPO_DATASET.get(_dev_path, "dev_{}".format(i))
        for i, _dev_path in enumerate(args.dev_path)
    ]) if args.dev_path is not None else None
    if use_gpu:
        model.to("cuda")

    if not debug:
        pdb.set_trace = lambda: None
def fine_tune(train_path,
              dev_path,
              test_path,
              n_epochs,
              model_full_name,
              learning_rate,
              tasks,
              evaluation=False,
              freq_checkpointing=1,
              freq_writer=1,
              fine_tune_label="",
              batch_size=2,
              freeze_ls_param_prefix=None,
              debug=False,
              verbose=0):
    if not debug:
        pdb.set_trace = lambda: None
    use_gpu = use_gpu_(None)
    hardware_choosen = "GPU" if use_gpu else "CPU"
    printing("{} mode ",
             var=([hardware_choosen]),
             verbose_level=0,
             verbose=verbose)
    dict_path = os.path.join(CHECKPOINT_DIR, model_full_name + "-folder",
                             "dictionaries")
    model_dir = os.path.join(CHECKPOINT_DIR, model_full_name + "-folder")

    # batch_size = 50  # model.arguments["info_checkpoint"]["batch_size"]
    word_decoding = False  # model.arguments["hyperparameters"]["decoder_arch"]["word_decoding"]
    char_decoding = True  # model.arguments["hyperparameters"]["decoder_arch"]["char_decoding"]
    # learning_rate = 0.00001  # model.arguments["other"]["lr"]
    printing(
        "Optimization arguments used for fine tuning : "
        "learning rate {} batch_size {} ",
        var=[learning_rate, batch_size],
        verbose_level=0,
        verbose=verbose)
    print(
        "WARNING : char_decoding {} and word_decoding {} should not be loaded here "
        .format(char_decoding, word_decoding))

    warmup = False
    test_before_run = False
    RUN_ID = str(uuid4())[0:5]
    LABEL_GRID = fine_tune_label if not warmup else "WARMUP"
    LABEL_GRID = "test_before_run-" + LABEL_GRID if test_before_run else LABEL_GRID

    oar_job_id = os.environ.get('OAR_JOB_ID')
    OAR = oar_job_id + "_rioc-" if oar_job_id is not None else ""
    print("OAR=", OAR)
    OAR = RUN_ID if OAR == "" else OAR
    LABEL_GRID = OAR + "-" + LABEL_GRID

    GRID_FOLDER_NAME = LABEL_GRID if len(LABEL_GRID) > 0 else RUN_ID
    GRID_FOLDER_NAME += "-summary"
    dir_grid = os.path.join(CHECKPOINT_DIR, GRID_FOLDER_NAME)
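    # Naming illustration (assumed values) : with OAR_JOB_ID unset,
    # RUN_ID = "a1b2c", fine_tune_label = "demo", warmup and test_before_run False,
    # LABEL_GRID = "a1b2c-demo" and the summary folder is
    # CHECKPOINT_DIR/a1b2c-demo-summary.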
    os.mkdir(dir_grid)
    printing("FINE TUNE RUN : target directory   {} made".format(dir_grid),
             verbose=verbose,
             verbose_level=0)

    train(
        train_path=train_path,
        dev_path=dev_path,
        test_path=test_path,
        n_epochs=n_epochs,
        batch_size=batch_size,
        get_batch_mode_all=True,
        bucketing=True,
        dict_path=dict_path,
        model_full_name=model_full_name,
        reload=True,
        model_dir=model_dir,
        symbolic_root=True,
        symbolic_end=True,
        overall_label=LABEL_GRID,
        overall_report_dir=dir_grid,
        model_specific_dictionary=True,
        print_raw=False,
        compute_mean_score_per_sent=False,
        word_decoding=word_decoding,
        char_decoding=char_decoding,
        checkpointing=True,
        normalization=True,
        pos_specific_path=None,
        expand_vocab_dev_test=True,
        lr=learning_rate,
        extend_n_batch=2,  # model.arguments["info_checkpoint"]["other"]["extend_n_batch"]
        freq_writer=freq_writer,
        freq_checkpointing=freq_checkpointing,  # compute_mean_score_per_sent=True,
        score_to_compute_ls=["exact"],
        mode_norm_ls=["all", "NEED_NORM", "NORMED"],
        compute_scoring_curve=False,
        add_start_char=1,
        add_end_char=1,
        tasks=tasks,
        extra_arg_specific_label=fine_tune_label,
        freeze_ls_param_prefix=freeze_ls_param_prefix,
        freezing_mode=freeze_ls_param_prefix is not None,
        debug=False,
        use_gpu=None,
        verbose=0)

    if evaluation:
        if test_path is not None:
            dict_path = os.path.join(CHECKPOINT_DIR,
                                     model_full_name + "-folder",
                                     "dictionaries")
            printing("GRID : START EVALUATION FINAL ",
                     verbose_level=0,
                     verbose=verbose)
            eval_data_paths = []  #[train_path, dev_path]
            if isinstance(test_path, list):
                eval_data_paths.extend(test_path)
            else:
                eval_data_paths.append(test_path)
            eval_data_paths = list(set(eval_data_paths))
            for eval_data in eval_data_paths:
                eval_label = REPO_DATASET[eval_data]
                evaluate(model_full_name=model_full_name,
                         data_path=eval_data,
                         dict_path=dict_path,
                         use_gpu=use_gpu,
                         label_report=eval_label,
                         overall_label=LABEL_GRID,
                         score_to_compute_ls=["exact"],
                         mode_norm_ls=["all", "NEED_NORM", "NORMED"],
                         normalization=True,
                         print_raw=False,
                         model_specific_dictionary=True,
                         get_batch_mode_evaluate=False,
                         bucket=True,
                         compute_mean_score_per_sent=True,
                         batch_size=batch_size,
                         debug=debug,
                         extra_arg_specific_label=fine_tune_label,
                         word_decoding=word_decoding,
                         char_decoding=char_decoding,
                         dir_report=model_dir,
                         verbose=1)
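# Hypothetical usage sketch for fine_tune() : all paths, the model name and the
# label below are placeholders, and test paths must be registered in
# REPO_DATASET for the final evaluation step. Wrapped in a helper so nothing
# runs at import time (evaluate(), called inside fine_tune, is defined below).
def _demo_fine_tune():
    fine_tune(train_path="data/lexnorm/train.conllu",
              dev_path="data/lexnorm/dev.conllu",
              test_path=["data/lexnorm/test.conllu"],
              n_epochs=5,
              model_full_name="a1b2c-model",
              learning_rate=1e-5,
              tasks=["normalize"],
              evaluation=True,
              fine_tune_label="demo-finetune",
              batch_size=2,
              verbose=1)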
def evaluate(batch_size, data_path, tasks, evaluated_task,
             write_report=True, dir_report=None,
             dict_path=None, model_full_name=None,
             score_to_compute_ls=None, mode_norm_ls=None, get_batch_mode_evaluate=True,
             overall_label="ALL_MODELS", overall_report_dir=CHECKPOINT_DIR, bucket=False,
             model_specific_dictionary=True, label_report="",
             print_raw=False,
             model=None,
             compute_mean_score_per_sent=False, write_output=False,
             word_decoding=False, char_decoding=True,
             extra_arg_specific_label="", scoring_func_sequence_pred="BLUE",
             max_char_len=None,
             normalization=True, debug=False,
             force_new_dic=False,
             use_gpu=None, verbose=0):
    assert model_specific_dictionary, "ERROR : only model_specific_dictionary = True supported now"
    # NB : now : you have to load dictionary when evaluating (cannot recompute) (could add in the LexNormalizer ability)
    use_gpu = use_gpu_(use_gpu)
    hardware_choosen = "GPU" if use_gpu else "CPU"
    printing("{} mode ", var=([hardware_choosen]), verbose_level=0, verbose=verbose)
    printing("EVALUATION : evaluating with compute_mean_score_per_sent {}".format(compute_mean_score_per_sent), verbose=verbose, verbose_level=1)

    if mode_norm_ls is None:
        mode_norm_ls = ["all", "NORMED", "NEED_NORM"]
    if write_report:
        assert dir_report is not None
    if model is not None:
        assert model_full_name is None and dict_path is None, \
            "ERROR as model is provided : model_full_name and dict_path should be None"
    else:
        assert model_full_name is not None and dict_path is not None,\
            "ERROR : model_full_name and dict_path required to load model "
    voc_size = None
    reportint_unavailable = False  # assumed default so the report block below runs ; set to True to skip report_template
    if not debug:
        pdb.set_trace = lambda: 1

    model = LexNormalizer(generator=Generator, load=True, model_full_name=model_full_name,
                          tasks=tasks,
                          word_decoding=word_decoding, char_decoding=char_decoding,
                          voc_size=voc_size, use_gpu=use_gpu, dict_path=dict_path, model_specific_dictionary=True,
                          dir_model=os.path.join(PROJECT_PATH, "checkpoints", model_full_name + "-folder"),
                          extra_arg_specific_label=extra_arg_specific_label,
                          loading_sanity_test=True,
                          verbose=verbose
                          ) if model is None else model

    if score_to_compute_ls is None:
        score_to_compute_ls = ["edit", "exact"]
        if model.auxilliary_task_norm_not_norm:
            score_to_compute_ls.extend(SCORE_AUX)

    printing("EVALUATION : Evaluating {} metric with details {}  ", var=[score_to_compute_ls, mode_norm_ls], verbose=verbose, verbose_level=3)

    #rep_tl.checkout_layer_name("encoder.seq_encoder.weight_ih_l0", model.named_parameters(), info_epoch="EVAL")

    readers_eval = readers_load(datasets=[data_path], tasks=[evaluated_task], word_dictionary=model.word_dictionary,
                                word_dictionary_norm=model.word_nom_dictionary, char_dictionary=model.char_dictionary,
                                pos_dictionary=model.pos_dictionary, xpos_dictionary=model.xpos_dictionary,
                                type_dictionary=model.type_dictionary, use_gpu=use_gpu,
                                norm_not_norm=model.auxilliary_task_norm_not_norm, word_decoder=word_decoding,
                                bucket=bucket, max_char_len=max_char_len,
                                add_start_char=1, add_end_char=1, symbolic_end=model.symbolic_end, symbolic_root=model.symbolic_root,
                                verbose=verbose)
    batchIter = data_gen_multi_task_sampling_batch(tasks=[evaluated_task], readers=readers_eval, batch_size=batch_size,
                                                   word_dictionary=model.word_dictionary,
                                                   char_dictionary=model.char_dictionary,
                                                   pos_dictionary=model.pos_dictionary,
                                                   get_batch_mode=get_batch_mode_evaluate,
                                                   word_dictionary_norm=model.word_nom_dictionary,
                                                   extend_n_batch=1, dropout_input=0,
                                                   verbose=verbose)

    model.eval()
    # the formulas come from the normalization_errors functions
    score_dic_new, formulas = greedy_decode_batch(char_dictionary=model.char_dictionary, verbose=verbose, gold_output=True,
                                                  score_to_compute_ls=score_to_compute_ls, use_gpu=use_gpu,
                                                  write_output=write_output, eval_new=True,
                                                  task_simultaneous_eval=[evaluated_task],
                                                  stat="sum", mode_norm_score_ls=mode_norm_ls,
                                                  label_data=REPO_DATASET[data_path],
                                                  batchIter=batchIter, model=model,
                                                  scoring_func_sequence_pred=scoring_func_sequence_pred,
                                                  compute_mean_score_per_sent=compute_mean_score_per_sent,
                                                  batch_size=batch_size)
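    # Assumed shapes, for illustration only (actual keys come from
    # normalization_errors / greedy_decode_batch) :
    #   formulas      e.g. {"exact": ("all-normalize-correct-count", "all-normalize-gold-count")}
    #   score_dic_new e.g. {"all-normalize-correct-count": 90, "all-normalize-gold-count": 100, "n_sents": 10}
    # so the loop below would report an "exact" score of 90 / 100 = 0.9.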
    for score_name, formula in formulas.items():
        if isinstance(formula, tuple) and len(formula) > 1:
            (num, denom) = formula
            score_value = score_dic_new[num] / score_dic_new[denom] if score_dic_new[denom] > 0 else None
            if score_dic_new[denom] == 0:
                print("WARNING : score {} has a null denominator {} (numerator {} equal to {})".format(
                    score_name, denom, num, score_dic_new[num]))
            reg = re.match("([^-]+)-([^-]+)-.*", num)
            mode_norm = reg.group(1)
            task = reg.group(2)
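            # e.g. a hypothetical num = "all-normalize-correct-count" parses to
            # mode_norm = "all" and task = "normalize" ("<mode>-<task>-..." layout)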
            # report all in a dictionary
            if not reportint_unavailable:
                report = report_template(metric_val=score_name,
                                         info_score_val=mode_norm,
                                         score_val=score_value,
                                         n_sents=score_dic_new["n_sents"],
                                         avg_per_sent=0,
                                         n_tokens_score=score_dic_new.get(mode_norm + "-" + task + "-gold-count", -1),
                                         model_full_name_val=model.model_full_name,
                                         task=task,
                                         report_path_val=model.arguments["checkpoint_dir"],
                                         evaluation_script_val="exact_match",
                                         model_args_dir=model.args_dir,
                                         data_val=REPO_DATASET[data_path])
            else:
                report = {"report ":0}
            over_all_report_dir = os.path.join(dir_report, model.model_full_name + "-report-" + label_report + ".json")
            over_all_report_dir_all_models = os.path.join(overall_report_dir, overall_label + "-report.json")
            writing_mode = "w" if not os.path.isfile(over_all_report_dir) else "a"
            writing_mode_all_models = "w" if not os.path.isfile(over_all_report_dir_all_models) else "a"
            for report_path, mode in zip([over_all_report_dir, over_all_report_dir_all_models],
                                         [writing_mode, writing_mode_all_models]):
                if mode == "w":
                    # first write : create the report file as a one-element json list
                    _all_report = [report]
                    json.dump([report], open(report_path, mode))
                    printing("REPORT : Creating new report {} ".format(report_path), verbose=verbose, verbose_level=1)
                else:
                    # otherwise reload the existing report list, append and rewrite it
                    all_report = json.load(open(report_path, "r"))
                    all_report.append(report)
                    json.dump(all_report, open(report_path, "w"))
    printing("NEW REPORT metric : {} ", var=[" ".join(list(formulas.keys()))], verbose=verbose, verbose_level=1)
    try:
        printing("NEW REPORT : model specific report saved {} ".format(over_all_report_dir), verbose=verbose, verbose_level=1)
        printing("NEW REPORT : overall report saved {} ".format(over_all_report_dir_all_models), verbose=verbose,verbose_level=1)
    except Exception as e:
        print(Exception(e))
    if writing_mode == "w":
        all_report = _all_report
    return all_report
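# A minimal, hypothetical invocation sketch for evaluate() : the model name,
# dictionary path and data path below are placeholders, and data_path must be
# a key of REPO_DATASET for the label lookup above to succeed.
if __name__ == "__main__":
    evaluate(batch_size=2,
             data_path="data/lexnorm/test.conllu",
             tasks=["normalize"],
             evaluated_task="normalize",
             model_full_name="a1b2c-model",
             dict_path="checkpoints/a1b2c-model-folder/dictionaries",
             dir_report="checkpoints/a1b2c-model-folder",
             score_to_compute_ls=["exact"],
             mode_norm_ls=["all", "NEED_NORM", "NORMED"],
             get_batch_mode_evaluate=False,
             bucket=True,
             verbose=1)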