Beispiel #1
0
def pred_word_to_list(pred_word, special_symb_ls):
    """Split a word into characters while keeping special symbols whole.

    The word is scanned left to right. Each occurrence of a symbol from
    ``special_symb_ls`` becomes one element of the result; all other text is
    split into single characters. When several symbols start at the same
    position, the one listed first in ``special_symb_ls`` wins. A remainder
    that is only whitespace after a special symbol is dropped. Empty symbols
    are ignored (they previously made the function loop forever).

    :param pred_word: the string to split.
    :param special_symb_ls: sequence of symbols that must stay intact.
    :return: list of single characters and special symbols, in order.
    """
    new_word = []
    remaining = pred_word
    while remaining:
        # earliest special symbol in the remaining text (ties -> first listed)
        best_index, best_symb = -1, ""
        for special_symb in special_symb_ls:
            if not special_symb:
                continue
            found = remaining.find(special_symb)
            if found >= 0 and (best_index < 0 or found < best_index):
                best_index, best_symb = found, special_symb
        if best_index < 0:
            # no special symbol left: the rest is plain characters
            new_word.extend(remaining)
            break
        new_word.extend(remaining[:best_index])
        new_word.append(best_symb)
        remaining = remaining[best_index + len(best_symb):]
        if not remaining.strip():
            # a whitespace-only tail after a special symbol is discarded
            break
    return new_word
    def get_loss(loss_func,
                 label,
                 num_label_dic,
                 labels,
                 logits_dict,
                 task,
                 logit_label,
                 head_label=None):
        """Compute the loss of one prediction head.

        NOTE(review): as laid out here this ``def`` is nested inside
        ``pred_word_to_list`` after its ``return`` statement, so it is never
        bound at runtime; it presumably belongs at module level -- confirm.

        * any ``label`` other than ``"heads"``/``"types"``: ``loss_func`` applied
          to ``logits_dict[logit_label]`` reshaped to
          ``(-1, num_label_dic[logit_label])`` and to the flattened ``labels``;
        * ``"heads"``: summed cross entropy over
          ``logits_dict[logit_label].view(-1, size(2))``, ignoring
          ``LABEL_PARAMETER["heads"]["pad_value"]``;
        * ``"types"``: ``loss_func`` on the ``"parsing-types"`` logits gathered at
          every position whose gold head is not padding (and at that gold head),
          against the gold types at those positions.

        :param head_label: gold heads; required when ``label == "types"``.
        :return: the loss returned by the underlying criterion.
        :raises AssertionError: if ``label == "types"`` and ``head_label`` is None.
        """
        if label not in ["heads", "types"]:
            try:
                loss = loss_func(
                    logits_dict[logit_label].view(-1,
                                                  num_label_dic[logit_label]),
                    labels.view(-1))
            except Exception as e:
                # report the context, then propagate with the original traceback
                print(e)
                print("ERROR task {} num_label {} , labels {} ".format(
                    task, num_label_dic, labels.view(-1)))
                raise

        elif label == "heads":
            # padding heads are excluded through ignore_index
            loss = CrossEntropyLoss(
                ignore_index=LABEL_PARAMETER[label]["pad_value"],
                reduction="sum")(logits_dict[logit_label].view(
                    -1, logits_dict[logit_label].size(2)), labels.view(-1))
            # other possibility: log softmax then L1 loss (leads to other results)

        else:  # label == "types"
            assert head_label is not None, "ERROR head_label should be passed"
            not_padded = head_label != LABEL_PARAMETER["heads"]["pad_value"]
            # gold types at the non-padded positions
            gold = labels[not_padded]
            # type logits at each non-padded position, taken at its gold head
            positions = not_padded.nonzero()
            pred = logits_dict["parsing-types"][positions[:, 0],
                                                positions[:, 1],
                                                head_label[not_padded]]
            # padding is already removed here, so no ignore_index is needed
            loss = loss_func(pred, gold)

        return loss
def preprocess_batch_string_for_bert(batch,
                                     start_token,
                                     end_token,
                                     rp_space=False):
    """Overwrite the boundary tokens of each sentence and join it into a string.

    Each sentence of ``batch`` is a mutable list of tokens. Its first token is
    replaced by ``start_token`` and its last token by ``end_token``, so a
    one-token sentence ends up holding only ``end_token``. If ``rp_space`` is
    set, the sentence is then passed through ``rp_space_func``. Finally its
    tokens are joined with single spaces. ``batch`` is modified in place and
    also returned.

    :param batch: list of token lists.
    :param start_token: token written at position 0 of every sentence.
    :param end_token: token written at the last position of every sentence.
    :param rp_space: apply ``rp_space_func`` to each sentence before joining.
    :return: ``batch``, now holding one space-joined string per sentence.
    :raises IndexError: if a sentence is empty.
    """
    for i, sent in enumerate(batch):
        if len(sent) == 0:
            raise IndexError(
                "ERROR (preprocess_batch_string_for_bert): sentence {} is empty, "
                "cannot set start/end tokens".format(i))
        sent[0] = start_token
        sent[-1] = end_token
        if rp_space:
            sent = rp_space_func(sent)
        batch[i] = " ".join(sent)
    return batch
Beispiel #4
0
 def sanity_test_parsing_label(labels, output_tokens_tensor_new,
                               input_alignement_with_raw, cumulate_shift):
     """Spot-check re-indexed head labels against the original ones.

     For each sentence, five positions are drawn at random from
     ``range(len(cumulate_shift[sent]) - 1)``. A drawn position is checked
     unless its new label is ``ROOT_HEADS_INDEX + 1``, ``END_HEADS_INDEX`` or
     ``PAD_ID_LOSS_STANDART``. The new label must then equal the original head
     of the aligned word plus ``CLS_ADJUST`` plus
     ``cumulate_shift[sent][original head]``. Mismatches are printed, not
     raised. Sentences with nothing to sample are skipped.
     """
     for sent in range(labels.size(0)):
         ind_max = len(cumulate_shift[sent]) - 1
         if ind_max <= 0:
             # np.random.choice raises ValueError on an empty range
             continue
         for _ in range(5):
             ind = np.random.choice(range(ind_max))
             new_label = output_tokens_tensor_new[sent][ind]
             if new_label in [ROOT_HEADS_INDEX + 1, END_HEADS_INDEX,
                              PAD_ID_LOSS_STANDART]:
                 continue
             old_label = labels[sent, int(input_alignement_with_raw[sent][ind])]
             shift = cumulate_shift[sent][old_label]
             # new label = old label + CLS offset + non-first word pieces before the head word
             if not (new_label == old_label + CLS_ADJUST + shift):
                 print("ERROR sent {} ind word {} "
                       "new {} and old {} cumulted {} ".format(
                           sent, ind, new_label, old_label, shift))
def readers_load(datasets,
                 tasks,
                 word_dictionary,
                 word_dictionary_norm,
                 char_dictionary,
                 pos_dictionary,
                 xpos_dictionary,
                 type_dictionary,
                 bert_tokenizer,
                 word_decoder=False,
                 must_get_norm=True,
                 bucket=True,
                 input_level_ls=None,
                 run_mode="train",
                 add_start_char=1,
                 add_end_char=1,
                 symbolic_end=True,
                 symbolic_root=True,
                 verbose=1):
    """Build one data reader per task with ``conllu_data.read_data_to_variable``.

    ``tasks`` and ``datasets`` are zipped. The i-th entry of ``tasks`` (a
    sequence of task names, joined with "," to form the reader key) is read from
    the i-th dataset. If the lengths differ, a warning is printed and the first
    dataset is reused for every task. Simultaneous training is not supported.

    :return: dict mapping ``",".join(task)`` to the reader built for it.
    """
    readers = {}
    assert "all" not in tasks, "ERROR not supported yet (pb for simultanuous training..) "
    if "all" in tasks:
        # only reachable when assertions are stripped (python -O)
        assert len(
            tasks) == 1, "ERROR : if all should have only all nothing else"
        printing("TRAINING : MultiTask Iterator wit task 'all' ",
                 verbose_level=1,
                 verbose=verbose)
    elif len(tasks) != len(datasets):
        # SHOULD NOT DO THAT !! best-effort fallback: reuse the first dataset for every task
        print("WARNING : duplicating readers",
              "ERROR : as simultanuous_training is False : "
              "we need 1 dataset per task but have only {} for task {} ".format(datasets, tasks))
        datasets = [datasets[0] for _ in tasks]

    for simul_task, data in zip(tasks, datasets):
        normalization_in_reader = does_one_task_require_normalization(
            simul_task)
        # 1 reader per simultaneously trained task
        readers[",".join(simul_task)] = conllu_data.read_data_to_variable(
            data,
            word_dictionary,
            char_dictionary,
            pos_dictionary,
            xpos_dictionary,
            type_dictionary,
            word_decoder=word_decoder,
            symbolic_end=symbolic_end,
            symbolic_root=symbolic_root,
            dry_run=0,
            normalization=normalization_in_reader,
            bucket=bucket,
            add_start_char=add_start_char,
            add_end_char=add_end_char,
            tasks=simul_task,
            max_char_len=None,
            must_get_norm=must_get_norm,
            bert_tokenizer=bert_tokenizer,
            input_level_ls=input_level_ls,
            run_mode=run_mode,
            word_norm_dictionary=word_dictionary_norm,
            pad_id=bert_tokenizer.convert_tokens_to_ids(
                bert_tokenizer.pad_token),
            verbose=verbose)

    return readers
Beispiel #6
0
def get_bpe_label_word_level_task(labels,
                                  batch,
                                  input_tokens_tensor,
                                  input_alignement_with_raw,
                                  use_gpu,
                                  label_name,
                                  pad,
                                  graph_labels=False):
    """Project word-level labels onto a word-piece (BPE) tokenized input.

    The first word piece of each word keeps the word's label. Padding positions
    and the other pieces of a multi-piece word get
    ``LABEL_PARAMETER[label_name]["pad_value"]``. With ``graph_labels``, the
    labels are word indexes (heads). Each one is shifted by ``CLS_ADJUST`` and
    by the cumulated count of non-first pieces for the pointed-to word.

    :param labels: word-level label tensor indexed ``[sentence, word]``, or None.
    :param batch: batch object; ``raw_input``/``raw_output`` are only printed on error.
    :param input_tokens_tensor: word-piece id tensor indexed ``[sentence, piece]``.
    :param input_alignement_with_raw: per sentence, the word index of each word
        piece (entries equal to 1000 are treated as past-the-end positions).
    :param use_gpu: move the returned tensors to CUDA.
    :param label_name: key into ``LABEL_PARAMETER`` for the label pad value.
    :param pad: input pad id; input positions holding it are masked.
    :param graph_labels: True when labels are word indexes (dependency heads).
    :return: ``(output_tokens_tensor, head_mask, input_tokens_tensor, cumulate_shift)``.
        ``output_tokens_tensor`` and ``head_mask`` are None when ``labels`` is
        None (previously this raised UnboundLocalError). ``cumulate_shift`` is
        None unless ``graph_labels``.
    """
    output_tokens_tensor = np.array(labels.cpu()) if labels is not None else None

    new_input = np.array(input_tokens_tensor.cpu())
    len_max = max([len(sent) for sent in new_input], default=0)
    new_input = [sent.tolist() + [pad] * (len_max - len(sent)) for sent in new_input]

    # 1 marks the first word piece of a word; 0 marks padding and word pieces
    # aligned to the same word as the previous piece (the first piece is never
    # compared with the last one, which the former [ind_tok - 1] did at index 0)
    _input_mask = []
    for ind_sent, sent in enumerate(new_input):
        alignement = input_alignement_with_raw[ind_sent]
        _input_mask.append([
            0 if token == pad or (ind_tok > 0 and alignement[ind_tok - 1] == alignement[ind_tok]) else 1
            for ind_tok, token in enumerate(sent)
        ])

    def get_cumulated_non_first_bpe_counter(sent):
        """Return, per word of the alignment ``sent``, the shift applied to heads pointing at it.

        Entries equal to 1000 are ignored, e.g.
        [0, 1, 2, 2, 3, 4, 5, 5, 5, 6, 7, 8, 8, 8, 9, 10, 11, 1000]
        -> [0, 0, 0, 1, 1, 1, 3, 3, 3, 5, 5, 5]
        """
        counter = 0
        counter_former = 0
        cumulated = 0
        new_sent = []
        for ind, token in enumerate(sent):
            if token == 1000:
                continue
            if ind + 1 < len(sent) and token == sent[ind + 1]:
                # a further piece of the same word follows
                counter += 1
            else:
                new_sent.append(counter_former + cumulated)
                cumulated += counter_former
                counter_former = counter
                counter = 0
        return new_sent

    cumulate_shift = None
    if graph_labels:
        cumulate_shift = [get_cumulated_non_first_bpe_counter(alignement)
                          for alignement in input_alignement_with_raw]

    output_tokens_tensor_new = []
    if labels is not None:
        label_pad = LABEL_PARAMETER[label_name]["pad_value"]
        for ind_sent, sent_mask in enumerate(_input_mask):
            new_labels = []
            shift = 0  # masked pieces seen so far: maps piece index back to word index
            for ind_tok, mask in enumerate(sent_mask):
                try:
                    label = output_tokens_tensor[ind_sent, ind_tok - shift]
                    if graph_labels:
                        # CLS is prepended to every sentence, so head indexes are adjusted for it
                        if label not in [ROOT_HEADS_INDEX, END_HEADS_INDEX] \
                                and cumulate_shift[ind_sent][label] > 0:
                            label += cumulate_shift[ind_sent][label]
                        label += CLS_ADJUST
                except Exception as e:
                    try:
                        reached_end = input_alignement_with_raw[ind_sent][ind_tok] == 1000
                    except Exception:
                        reached_end = False
                    if not reached_end:
                        print("ERROR (get_bpe_labels): we reached the end of output labels but input is not done ! ")
                        print("ERROR ind_send:{} ind_tok {} shift {} output_tokens_tensor {} alignement {} -  {}"
                              .format(ind_sent, ind_tok, shift, output_tokens_tensor,
                                      input_alignement_with_raw[ind_sent], e))
                        print("ERROR ind_send ", getattr(batch, "raw_input", None),
                              getattr(batch, "raw_output", None))
                        raise e
                    label = label_pad
                if mask == 0:
                    new_labels.append(label_pad)
                    shift += 1
                else:
                    new_labels.append(label)
            output_tokens_tensor_new.append(new_labels)

    def sanity_test_parsing_label(labels, output_tokens_tensor_new,
                                  input_alignement_with_raw, cumulate_shift):
        """Spot-check five random positions per sentence; print mismatches."""
        for sent in range(labels.size(0)):
            ind_max = len(cumulate_shift[sent]) - 1
            if ind_max <= 0:
                # np.random.choice raises ValueError on an empty range
                continue
            for _ in range(5):
                ind = np.random.choice(range(ind_max))
                new_label = output_tokens_tensor_new[sent][ind]
                if new_label in [ROOT_HEADS_INDEX + 1, END_HEADS_INDEX,
                                 PAD_ID_LOSS_STANDART]:
                    continue
                old_label = labels[sent, int(input_alignement_with_raw[sent][ind])]
                shift_old = cumulate_shift[sent][old_label]
                # new label = old label + CLS offset + non-first pieces before the head word
                if not (new_label == old_label + CLS_ADJUST + shift_old):
                    print("ERROR sent {} ind word {} "
                          "new {} and old {} cumulted {} ".format(
                              sent, ind, new_label, old_label, shift_old))

    if graph_labels and labels is not None:
        sanity_test_parsing_label(labels, output_tokens_tensor_new,
                                  input_alignement_with_raw, cumulate_shift)

    head_mask = None
    if labels is not None:
        output_tokens_tensor = torch.Tensor(output_tokens_tensor_new).long()
        head_mask = torch.Tensor(_input_mask).long()
    input_tokens_tensor = torch.Tensor(new_input).long()
    if use_gpu:
        if labels is not None:
            head_mask = head_mask.cuda()
            output_tokens_tensor = output_tokens_tensor.cuda()
        input_tokens_tensor = input_tokens_tensor.cuda()

    return output_tokens_tensor, head_mask, input_tokens_tensor, cumulate_shift
def write_conll_multitask(format, dir_pred, dir_original, src_text_ls,
                          pred_per_task, tasks, task_parameters, cp_paste=False, gold=False,
                          all_indexes=None, sep_token=None, cls_token=None,
                          sent_id=None, raw_text=None,
                          append_mwe_ind= None,
                          append_mwe_row= None,
                          ind_batch=0, new_file=False, cut_sent=False, verbose=0):
    """Write one batch of predictions (or gold labels) to two CoNLL-like files.

    For every sentence ``ind_sent`` and every token index string in
    ``all_indexes[ind_sent, :]``:

    * ``dir_pred`` gets ``index, token, _, pos, _, _, head, _, types, misc``,
      where ``misc`` is ``Norm=...|mwe_detection=...|n_masks_mwe=...``;
    * ``dir_original`` gets the token and ``index - 1`` in its ninth column.

    Indexes of the form ``"a-b"`` are multi-word-expression ranges and are
    written with the raw token. Tokens in ``SPECIAL_TOKEN_LS`` or equal to
    ``cls_token``/``sep_token`` are skipped.

    :param format: only ``"conll"`` is supported.
    :param dir_pred: path of the prediction file.
    :param dir_original: path of the source-token file.
    :param src_text_ls: dict of per-sentence source token lists, keyed by input name.
    :param pred_per_task: predictions keyed by ``"<task>-<label>"`` and nested by
        top-k rank; when ``gold`` is True, keyed by label and not nested.
    :param tasks: list of task lists; ``"normalize"`` in it selects the
        normalization-aligned source field.
    :param task_parameters: per-task parameters; ``["input"]`` names the source field.
    :param cp_paste: unused.
    :param gold: True when writing gold labels instead of predictions.
    :param all_indexes: 2-D array of token index strings, one row per sentence.
    :param sent_id: optional per-sentence header lines, written verbatim together with ``raw_text``.
    :param raw_text: see ``sent_id``.
    :param append_mwe_ind: optional per-sentence lists of token indexes before which
        the matching row of ``append_mwe_row`` is written; both are consumed in place.
    :param append_mwe_row: see ``append_mwe_ind``.
    :param ind_batch: offset added to generated ``#sent_id`` values.
    :param new_file: truncate the files instead of appending to them.
    :param cut_sent: stop writing a sentence after a token index above 50.
    :param verbose: verbosity passed to ``printing``.
    :return: length of the longest non-MWE source token read.
    """
    assert format in ["conll"]
    max_len_word = None
    writing_top = 1
    pred_task_len_former = -1
    task_former = ""
    # defined up front: it used to be set only inside the loop below
    mode_write = "w" if new_file else "a"

    # every task must have predicted the same number of sentences
    for task_label in pred_per_task:
        if gold:
            pred_task_len = len(pred_per_task[task_label])
        else:
            pred_task_len = len(pred_per_task[task_label][writing_top - 1])
        _task = re.match("(.*)-(.*)", task_label)
        task = _task.group(1) if _task is not None else task_label
        if pred_task_len_former > 0:
            assert pred_task_len == pred_task_len_former, \
                "ERROR {} and {} task ".format(task_former, task_label)
            if not gold:
                assert pred_task_len == len(src_text_ls[task_parameters[task]["input"]]), \
                    "ERROR  src len {} and pred len {} ".format(len(src_text_ls[task_parameters[task]["input"]]), pred_task_len)
            # we check also other input length
            if src_text_ls.get("input_masked") is not None:
                assert pred_task_len == len(src_text_ls["input_masked"])
            if src_text_ls.get("wordpieces_inputs_words") is not None:
                assert pred_task_len == len(src_text_ls["wordpieces_inputs_words"]), \
                    "ERROR mismatch source wordpieces_inputs_words {}  " \
                    "and prediction {} ".format(src_text_ls, pred_per_task[task_label])
            if src_text_ls.get("wordpieces_inputs_raw_tokens") is not None:
                assert pred_task_len == len(src_text_ls["wordpieces_inputs_raw_tokens"]), \
                    "ERROR mismatch source wordpieces_inputs_" \
                    "raw_tokens {} and prediction {} ".format(src_text_ls, pred_per_task[task_label])
            assert pred_task_len == all_indexes.shape[0], \
                "ERROR mismatch index {}  and all_indexes {} : pred {}".format(pred_task_len, all_indexes.shape[0], pred_per_task[task_label])
        pred_task_len_former = pred_task_len
        task_former = task_label

    if new_file:
        printing("CREATING NEW FILE (io_/dat/normalized_writer) : {} ", var=[dir_pred], verbose=verbose,
                 verbose_level=1)

    pos_label = "pos-pos" if not gold else "pos"
    types_label = "parsing-types" if not gold else "types"
    heads_label = "parsing-heads" if not gold else "heads"
    n_masks_mwe_label = "n_masks_mwe-n_masks_mwe" if not gold else "n_masks_mwe"
    mwe_detection_label = "mwe_detection-mwe_detection" if not gold else "mwe_detection"
    normalize_in_tasks = "normalize" in [task for _tasks in tasks for task in _tasks]

    with open(dir_pred, mode_write) as norm_file, open(dir_original, mode_write) as original:
        for ind_sent in range(all_indexes.shape[0]):
            pred_sent = OrderedDict()
            # length check of each prediction against its source (padded with UNK if shorter)
            # TODO standardize: gold pred_per_task is keyed by label, predictions by task-label
            for task_label_or_gold_label in pred_per_task:
                if gold:
                    pred_sent[task_label_or_gold_label] = pred_per_task[task_label_or_gold_label][ind_sent]
                    try:
                        src = src_text_ls[LABEL_PARAMETER[task_label_or_gold_label]["default_input"]][ind_sent]
                    except Exception:
                        src = src_text_ls["input_masked"][ind_sent]
                else:
                    pred_sent[task_label_or_gold_label] = pred_per_task[task_label_or_gold_label][writing_top - 1][ind_sent]
                    _task = re.match("(.*)-(.*)", task_label_or_gold_label)
                    assert _task is not None, "ERROR writer could not match {}".format(task_label_or_gold_label)
                    src = src_text_ls[TASKS_PARAMETER[_task.group(1)]["input"]][ind_sent]

                predicted = pred_sent[task_label_or_gold_label]
                if len(src) != len(predicted):
                    print("WARNING : (writer) task {} original_sent len {} {} \n  predicted sent len {} {}".format(
                        task_label_or_gold_label, len(src), src, len(predicted), predicted))
                    if len(src) > len(predicted):
                        predicted.extend(["UNK" for _ in range(len(src) - len(predicted))])
                        print("WARNING (writer) : original larger than prediction : so appending UNK token for writing")
                    else:
                        print("WARNING (writer) : original smaller than prediction for ")

            if sent_id is not None and raw_text is not None:
                norm_file.write(sent_id[ind_sent])
                original.write(sent_id[ind_sent])
                norm_file.write(raw_text[ind_sent])
                original.write(raw_text[ind_sent])
            else:
                norm_file.write("#\n")
                original.write("#\n")
                norm_file.write("#sent_id = {} \n".format(ind_sent + ind_batch + 1))
                original.write("#sent_id = {} \n".format(ind_sent + ind_batch + 1))

            last_mwe_index = -1
            adjust_mwe = 0
            for ind in all_indexes[ind_sent, :]:
                # special tokens are detected on the source side (never on gold information)
                if "-" in ind and ind != "-1":
                    matching_mwe_ind = re.match("([0-9]+)-([0-9]+)", str(ind))
                    assert matching_mwe_ind is not None, "ERROR ind is {} : could not found mwe index".format(ind)
                    last_mwe_index = int(matching_mwe_ind.group(2))
                    ind_mwe = int(matching_mwe_ind.group(1))
                    if mwe_detection_label in pred_per_task or "wordpieces_inputs_words" in pred_per_task \
                            or n_masks_mwe_label in pred_per_task:
                        original_token = src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind_mwe]
                    else:
                        original_token = "NOT_NEEDED"
                    # tokens after the MWE are shifted in the raw-token indexing
                    adjust_mwe += (last_mwe_index - ind_mwe)
                    mwe_meta = "Norm={}|mwe_detection={}|n_masks_mwe={}".format(
                        "_",
                        pred_sent[mwe_detection_label][ind_mwe] if mwe_detection_label in pred_per_task else "_",
                        pred_sent[n_masks_mwe_label][ind_mwe] if n_masks_mwe_label in pred_per_task else "_")
                    norm_file.write("{index}\t{original}\t_\t{pos}\t_\t_\t{dep}\t_\t{types}\t{norm}\n".format(index=ind, original=original_token, pos="_", types="_", dep="_", norm=mwe_meta))
                    original.write("{}\t{}\t_\t_\t_\t_\t_\t_\t{}\t_\n".format(ind, original_token, "_"))
                    continue

                ind = int(ind)
                try:
                    if normalize_in_tasks:
                        original_pretokenized_field = "wordpiece_words_src_aligned_with_norm"
                    else:
                        original_pretokenized_field = "wordpieces_inputs_words"
                    original_token = src_text_ls[original_pretokenized_field][ind_sent][ind]
                except Exception:
                    original_pretokenized_field = "input_masked"
                    original_token = src_text_ls["input_masked"][ind_sent][ind]
                # outside MWEs, tokenized and raw source tokens must agree (best-effort check)
                if ind > last_mwe_index and src_text_ls.get("wordpieces_inputs_raw_tokens") is not None:
                    try:
                        assert src_text_ls[original_pretokenized_field][ind_sent][ind] == src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind - adjust_mwe], \
                            "ERROR sequence {} on non-mwe tokens : raw and tokenized " \
                            "should be same but are raw {} tokenized {}".format(
                                original_pretokenized_field,
                                src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind - adjust_mwe],
                                src_text_ls[original_pretokenized_field][ind_sent][ind])
                    except Exception as e:
                        print("WARNING sanity checking input failed (nomalized_writer) (might be due to dropout) {}".format(e))
                # running maximum (previously compared against a constant 0)
                max_len_word = max(len(original_token), max_len_word or 0)
                if original_token in SPECIAL_TOKEN_LS or original_token in [cls_token, sep_token]:
                    continue

                pos = pred_sent[pos_label][ind] if pos_label in pred_per_task else "_"
                types = pred_sent[types_label][ind] if types_label in pred_per_task else "_"
                heads = pred_sent[heads_label][ind] if heads_label in pred_per_task else ind - 1

                tenth_col = "Norm={}|mwe_detection={}|n_masks_mwe={}".format(
                    pred_sent["normalize"][ind] if "normalize" in pred_per_task else "_",
                    pred_sent[mwe_detection_label][ind - adjust_mwe] if mwe_detection_label in pred_per_task else "_",
                    pred_sent[n_masks_mwe_label][ind - adjust_mwe] if n_masks_mwe_label in pred_per_task else "_")

                if append_mwe_ind is not None:
                    # one list of mwe per sentence, consumed front to back
                    assert isinstance(append_mwe_ind[ind_sent], list)
                    assert isinstance(append_mwe_row[ind_sent], list)
                    if len(append_mwe_ind[ind_sent]) > 0:
                        assert len(append_mwe_row[ind_sent]) > 0
                        if append_mwe_ind[ind_sent][0] == ind:
                            mwe_row = append_mwe_row[ind_sent].pop(0)
                            append_mwe_ind[ind_sent].pop(0)
                            norm_file.write(mwe_row)
                            original.write(mwe_row)

                norm_file.write("{index}\t{original}\t_\t{pos}\t_\t_\t{dep}\t_\t{types}\t{norm}\n".format(index=ind, original=original_token, pos=pos, types=types, dep=heads, norm=tenth_col))
                original.write("{}\t{}\t_\t_\t_\t_\t_\t_\t{}\t_\n".format(ind, original_token, ind - 1))
                if cut_sent and ind > 50:
                    print("CUTTING SENT index {}>50 ".format(ind))
                    break
            norm_file.write("\n")
            original.write("\n")
    printing("WRITING predicted batch of {} original and {} normalized", var=[dir_original, dir_pred], verbose=verbose, verbose_level=1)
    assert max_len_word is not None, "ERROR : something went wrong in the writer"
    return max_len_word
def realigne_multi(
        ls_sent_str,
        input_alignement_with_raw,
        mask_str,
        label,
        end_token,
        #remove_null_str=True, null_str,
        remove_mask_str=False,
        remove_extra_predicted_token=False,
        keep_mask=False,
        gold_sent=False,
        flag_word_piece_token="##",
        flag_is_first_token=False,
        cumulate_shift_sub_word=None):
    """Re-align sub-word level sequences onto their word-level alignment index.

    Each sentence of ``ls_sent_str`` is walked together with its alignment
    indexes.  A change of index (or index -1, or the last position) marks the
    start of a new word: the token accumulated so far (``former_token``) is
    appended to the output sentence.  Depending on
    ``LABEL_PARAMETER[label]["realignement_mode"]`` pieces sharing an index are
    concatenated ("detokenize_bpe", stripping ``flag_word_piece_token``) or
    ignored ("ignore_non_first_bpe").  Because ``DEPLOY_MODE`` is hard-coded to
    True, every continuation piece (same index, not the last position) appends
    ``-1`` to the output sentence.

    TODO: factorize with the network-side realign.

    :param ls_sent_str: list of sentences, each a sequence of tokens or labels.
    :param input_alignement_with_raw: one alignment sequence per sentence, same
        length as the sentence; each entry must be ``int()``-convertible.  When
        ``remove_extra_predicted_token`` is set, index 1000 or -1 marks the end
        of the gold sentence.
    :param mask_str: mask token; replaced by "X" (or "" if ``remove_mask_str``)
        unless ``keep_mask`` is True.
    :param label: task name used as key into ``LABEL_PARAMETER``; "heads"
        requires ``cumulate_shift_sub_word``.
    :param end_token: only read by the non-deploy branch (dead while
        ``DEPLOY_MODE`` is True).
    :param remove_extra_predicted_token: pred mode only -- stop at the gold end.
    :param gold_sent: sanity-check that gold sentences end on padding symbols
        and print debug information.
    :param flag_word_piece_token: prefix of non-first word pieces.
    :param flag_is_first_token: whether the flag marks first pieces instead.
    :param cumulate_shift_sub_word: per-sentence shifts; only printed here.
    :return: list (one per sentence) of re-aligned token lists.
    :raises AssertionError: if a sentence and its alignment differ in length.
    """
    assert len(ls_sent_str) == len(input_alignement_with_raw), "ERROR : ls_sent_str {} : {} input_alignement_with_raw {}" \
                                                               " : {} ".format(ls_sent_str, len(ls_sent_str),
                                                                               input_alignement_with_raw,
                                                                               len(input_alignement_with_raw))
    new_sent_ls = []
    if label == "heads":
        assert cumulate_shift_sub_word is not None
    ind_sent = 0
    # Hard-coded: the `elif DEPLOY_MODE` branch is always taken and the
    # `if not DEPLOY_MODE` end-of-sentence block below is currently dead code.
    DEPLOY_MODE = True
    for sent, index_ls in zip(ls_sent_str, input_alignement_with_raw):
        # alignement index and input/label should have same len
        assert len(sent) == len(
            index_ls
        ), "ERROR : {} sent {} len {} and index_ls {} len {} not same len".format(
            label, sent, len(sent), index_ls, len(index_ls))

        former_index = -1
        new_sent = []
        former_token = ""

        for _i, (token, index) in enumerate(zip(sent, index_ls)):

            trigger_end_sent = False
            index = int(index)

            if remove_extra_predicted_token:
                if index == 1000 or index == -1:
                    # we reach the end according to gold data
                    # (we stop looking at the model prediction: word alignment is assumed)
                    trigger_end_sent = True
                    if gold_sent and token not in PADING_SYMBOLS:
                        # sanity check only: warn, do not fail
                        print("WARNING 123 : breaking gold sequence on {} token not in {}".format(
                            token, PADING_SYMBOLS))
            # if working with input : handling mask token in a specific way
            if token == mask_str and not keep_mask:
                token = "X" if not remove_mask_str else ""
            # if working with input : concatenating wordpieces
            if LABEL_PARAMETER[label]["realignement_mode"] == "detokenize_bpe":
                if index == former_index:
                    if token.startswith(
                            flag_word_piece_token) and not flag_is_first_token:
                        former_token += token[len(flag_word_piece_token):]
                    else:
                        former_token += token
            elif LABEL_PARAMETER[label][
                    "realignement_mode"] == "ignore_non_first_bpe":
                # sequence labelling: pieces that are not the first of a word
                # are simply not merged into former_token
                pass
            # new token (or end of sentence) --> flush what was accumulated
            if (index != former_index
                    or index == -1) or _i + 1 == len(index_ls):
                if not flag_is_first_token:
                    new_sent.append(former_token)
                elif flag_is_first_token and (
                        isinstance(former_token, str)
                        and former_token.startswith(flag_word_piece_token)):
                    new_sent.append(former_token[len(flag_word_piece_token):])
                else:
                    if label == "heads":
                        if isinstance(former_token, int):
                            try:
                                if former_token != -1:
                                    # NOTE(review): eval() on alignment entries;
                                    # only safe if they are trusted, internally
                                    # produced strings -- confirm with callers.
                                    former_token = eval(
                                        input_alignement_with_raw[ind_sent]
                                        [former_token])
                                    token = eval(
                                        input_alignement_with_raw[ind_sent]
                                        [token])
                            except Exception:
                                # best effort: report and keep the raw head
                                print(
                                    "error could not process former_token {} too long for cumulated_shift {} "
                                    .format(former_token,
                                            cumulate_shift_sub_word[ind_sent]))
                    new_sent.append(token)

                former_token = token
                if trigger_end_sent:
                    print("break trigger_end_sent")
                    break

            elif DEPLOY_MODE:
                # NOTE(review): this overwrites any "detokenize_bpe"
                # concatenation done above and marks the piece with -1.
                # EXCEPT PUNCTUATION FOR WHICH SHOULD ADD -1 BEFORE !
                former_token = token
                new_sent.append(-1)

            # if not pred mode : trigger_end_sent is always False
            # (required for the model to not stop too early if predict SEP too soon)
            if remove_extra_predicted_token and trigger_end_sent:
                if not flag_is_first_token:
                    new_sent.append(former_token)
                elif flag_is_first_token and (
                        isinstance(former_token, str)
                        and former_token.startswith(flag_word_piece_token)):
                    new_sent.append(former_token[len(flag_word_piece_token):])
                else:
                    new_sent.append(former_token)
                print("break remove_extra_predicted_token")
                break
            # TODO : SHOULD be cleaned
            # XLM (same first and end token) so not activated for </s>

            if not DEPLOY_MODE:
                if ((former_token == end_token and end_token != "</s>")
                        or _i + 1 == len(index_ls)
                        and not remove_extra_predicted_token) or (
                            (remove_extra_predicted_token and
                             (former_token == end_token and trigger_end_sent)
                             or _i + 1 == len(index_ls))):
                    new_sent.append(token)

                    print(
                        f"break new_sent {((former_token == end_token and end_token != '</s>') or _i + 1 == len(index_ls) and not remove_extra_predicted_token)} or { ((remove_extra_predicted_token and (former_token == end_token and trigger_end_sent) or _i + 1 == len(index_ls)))}"
                    )
                    break
            former_index = index
        new_sent_ls.append(new_sent)
        ind_sent += 1
    if gold_sent:
        print("CUMULATED SHIFT", cumulate_shift_sub_word)
        print("GOLD:OUTPUT BEFORE DETOKENIZATION ", ls_sent_str)
        print("GOLD:OUTPUT AFTER DETOKENIZATION", new_sent_ls)

    return new_sent_ls
def detokenized_src_label(source_preprocessed,
                          predict_dic,
                          label_ls,
                          label_dic=None,
                          special_after_space_flag="▁",
                          input_key="wordpieces_inputs_words"):
    """
    Re-alignement from
    sub-word tokenized --> word
    all predictions    --> first token of each word prediction
    + removing first and last special characters
    :param source_preprocessed:
    :param predict_dic:
    :param label_dic:
    :return:
    """
    detokenized_source_preprocessed = OrderedDict([(input_key, [])])
    detokenized_label_batch = OrderedDict([(key, [])
                                           for key in predict_dic.keys()])
    gold_label_batch = OrderedDict([(key, []) for key in predict_dic.keys()])

    for batch_i, batch in enumerate(source_preprocessed[input_key]):
        # for src in batch[1:-1]:
        src = batch[1:-1]
        detokenized_src = []
        for ind, (label, pred_label) in enumerate(zip(label_ls, predict_dic)):
            # remove first and last specilal token

            prediction = predict_dic[pred_label][0][batch_i][1:-1]

            if label_dic is not None:
                gold = label_dic[pred_label.split("-")[0]][batch_i][1:-1]
                assert len(prediction) == len(gold)

            detokenized_label = []
            detokenized_gold = []

            try:
                assert len(prediction) == len(src), f"ERROR should have 1-1 alignement here " \
                f"for word level task prediction {len(prediction)} {len(src)}"

            except Exception as e:
                pdb.set_trace()
                print(Exception(e))

            if label_dic is None:
                for subword, label in zip(src, prediction):

                    if subword[0] != special_after_space_flag:
                        # we are in the middle of a token : so we ignore the label and we join strings
                        if ind == 0:
                            # ind 0 in case several label set --> we work on src for only 1 of them
                            # we build detokenized_src only for the first label type
                            if subword not in ["</s>", "<pad>"]:
                                # Handling special case where number have been splitted and failed to be reconstructed
                                try:
                                    fix = isinstance(eval("1" + subword), int)
                                except Exception as e:
                                    fix = False
                                # if its a number and label is -1 (pointed as non-first sub-token) then we need to fix
                                if fix and len(
                                        detokenized_src[-1]
                                ) == 0 and detokenized_label[-1] == -1:
                                    detokenized_label = detokenized_label[:-1]
                                    detokenized_src = detokenized_src[:-1]

                                detokenized_src[-1] += subword
                    else:
                        detokenized_label.append(label)
                        try:
                            fix = isinstance(eval("1" + subword[1:]), int)
                        except Exception as e:
                            fix = False
                        # we build detokenized_src only for the first label type
                        #print("Label append", detokenized_label)
                        if ind == 0:
                            if fix and detokenized_label[-1] == -1:
                                detokenized_label = detokenized_label[:-1]
                                detokenized_src[-1] += subword[1:]
                            else:
                                detokenized_src.append(subword[1:])
                detokenized_label_batch[pred_label].append(detokenized_label)
            else:
                for subword, label, gold in zip(src, prediction, gold):
                    if subword[0] != special_after_space_flag:
                        # we are in the middle of a token : so we ignore the label and we join strings
                        if ind == 0:
                            # we build detokenized_src only for the first label type
                            if subword not in ["</s>", "<pad>"]:
                                detokenized_src[-1] += subword
                    else:
                        detokenized_label.append(label)
                        detokenized_gold.append(gold)
                        # we build detokenized_src only for the first label type
                        if ind == 0:
                            # we remove the special character
                            detokenized_src.append(subword[1:])

                detokenized_label_batch[pred_label].append(detokenized_label)
                gold_label_batch[pred_label].append(detokenized_gold)

            if ind == 0:
                assert len(detokenized_src) == len(
                    detokenized_label), "Should be aligned"
                detokenized_source_preprocessed[input_key].append(
                    detokenized_src)
                if label_dic is not None:
                    assert len(detokenized_gold) == len(detokenized_label)

    def sanity_check_batch():
        batch_size = -1
        # checking that input and output have same batch size and all labels
        for key in predict_dic.keys():
            if batch_size != -1:
                assert len(detokenized_label_batch[key]) == batch_size
                if len(gold_label_batch[key]) > 0:
                    assert len(gold_label_batch[key]) == batch_size
            batch_size = len(detokenized_label_batch[key])
        assert len(detokenized_source_preprocessed[input_key]) == batch_size

    sanity_check_batch()

    return detokenized_source_preprocessed, detokenized_label_batch, gold_label_batch