Example #1
def get_dir_data(set, data_code, demo=False):

    assert set in ["train", "dev", "test"], "{} - {}".format(set, data_code)
    assert data_code in DATASET_CODE_LS, "ERROR {}".format(data_code)
    demo_str = "-demo" if demo else ""
    # WE ASSUME DEV AND TEST DATA CODES CANNOT END WITH INTEGER_INTEGER; IF THEY DO --> fall back to the original data_code
    if set in ["dev", "test"]:
        matching = re.match("(.*)_([0-9]+)_([0-9]+)$", data_code)
        if matching is not None:
            data_code = matching.group(1)
            print("WARNING : changed data code with {}".format(data_code))
        else:
            pass  #print("DATA_CODE no int found  {} ".format(data_code))
    file_dir = os.path.join(
        DATA_UD, "{}-ud-{}{}.conllu".format(data_code, set, demo_str))
    try:
        assert os.path.isfile(file_dir), "{} not found".format(file_dir)
    except AssertionError:
        try:
            file_dir = os.path.join(
                DATA_UD_25, "{}-ud-{}{}.conllu".format(data_code, set,
                                                       demo_str))
            assert os.path.isfile(file_dir), "{} not found".format(file_dir)
            print("WARNING : UD 25 USED ")
        except Exception as e:
            print("--> data ", e)
            demo_str = ""
            file_dir = os.path.join(
                DATA_WIKI_NER,
                "{}-wikiner-{}{}.conll".format(data_code, set, demo_str))
            assert os.path.isfile(file_dir), "{} not found".format(file_dir)
            print("WARNING : WIKI NER USED")
    return file_dir
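
A minimal usage sketch, assuming the DATA_UD / DATA_UD_25 / DATA_WIKI_NER constants and DATASET_CODE_LS exist in the host module and point at real data (the data code "en_ewt_0_1" is hypothetical):

dev_path = get_dir_data("dev", "en_ewt_0_1")
# For dev/test the trailing "_0_1" is stripped, so the lookup order is
# DATA_UD/en_ewt-ud-dev.conllu, then DATA_UD_25, then the WikiNER fallback.
print(dev_path)
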
Example #2
def get_task_name_based_on_logit_label(logit_label, label_processed):
    match = re.match("(.*)-(.*)", logit_label)
    assert match is not None, "ERROR {}".format(logit_label)
    label = match.group(2)
    task = match.group(1)
    # skip labels that were already processed; otherwise record them
    _continue = label in label_processed
    if not _continue:
        label_processed.append(label)
    return label, task, _continue, label_processed
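
A short sketch of how get_task_name_based_on_logit_label splits a "task-label" string and tracks labels that were already processed (the label names are illustrative):

label_processed = []
label, task, skip, label_processed = get_task_name_based_on_logit_label("pos-pos", label_processed)
# label == "pos", task == "pos", skip is False, label_processed == ["pos"]
label, task, skip, label_processed = get_task_name_based_on_logit_label("parsing-pos", label_processed)
# skip is True because the "pos" label was already processed
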
def get_normalized_token(norm_field,
                         n_exception,
                         verbose,
                         predict_mode_only=False):

    match = re.match("^Norm=([^|]+)|.+", norm_field)

    try:
        assert match.group(1) is not None, \
            "ERROR : no normalization found for norm_field {}".format(norm_field)
        normalized_token = match.group(1)

    except (AssertionError, AttributeError):
        match_double_bar = re.match("^Norm=([|]+)|.+", norm_field)

        if match_double_bar.group(1) is not None:
            match = match_double_bar
            n_exception += 1
            printing("Exception handled we match with {}".format(
                match_double_bar.group(1)),
                     verbose=verbose,
                     verbose_level=2)
            normalized_token = match.group(1)

        else:
            exc = Exception(
                "Failed to handle exception with | on field {} ".format(
                    norm_field))
            if not predict_mode_only:
                raise exc
            else:
                print("REPLACING with UNK", exc)
                normalized_token = "UNK"

    return normalized_token, n_exception
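
A usage sketch of the Norm= field parser, assuming printing is importable from the host module (the field values are made up):

norm, n_exc = get_normalized_token("Norm=hello|Other=1", n_exception=0, verbose=1)
# norm == "hello", n_exc == 0
norm, n_exc = get_normalized_token("Norm=||", n_exception=0, verbose=1)
# the pipe-only fallback regex is used: norm == "||", n_exc == 1
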
Example #4
def get_init_args_dir(init_args_dir):
    """
    To simplify reporting we allow three ways of providing init_args_dir:
    a full path to an args file, a path relative to CHECKPOINT_BERT_DIR,
    or a run identifier that is resolved against CHECKPOINT_BERT_DIR.
    :param init_args_dir:
    :return:
    """
    if os.path.isfile(init_args_dir):  # , "ERROR {} not found to reload checkpoint".format(init_args_dir)
        _dir = init_args_dir
    elif os.path.isfile(os.path.join(CHECKPOINT_BERT_DIR, init_args_dir)):
        printing("MODEL init {} not found as directory so using second template ", var=[init_args_dir], verbose=1,
                 verbose_level=1)
        _dir = os.path.join(CHECKPOINT_BERT_DIR, init_args_dir)
    else:
        printing("MODEL init {} not found as directory and as subdirectory so using third template template ",
                 var=[init_args_dir], verbose=1, verbose_level=1)
        match = re.match("(.*-model_[0-9]+).*", init_args_dir)
        assert match is not None, "ERROR : template {} not found in {}".format("(.*-model_[0-9]+).*", init_args_dir)
        _dir = os.path.join(CHECKPOINT_BERT_DIR, match.group(1), init_args_dir + "-args.json")
        assert os.path.isfile(_dir), "ERROR : {} does not exist (based on param {}) ".format(_dir, init_args_dir)
    return _dir
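
The three accepted forms, sketched as comments (CHECKPOINT_BERT_DIR comes from the host module and the run names below are hypothetical):

# 1. A full path to an existing args file:
#      get_init_args_dir("/tmp/my_run-args.json")
# 2. A path relative to CHECKPOINT_BERT_DIR:
#      get_init_args_dir("9876-model_1/9876-model_1-args.json")
# 3. A bare run identifier: "9876-model_1-epoch_5" is matched against
#    "(.*-model_[0-9]+).*" and resolved to
#    CHECKPOINT_BERT_DIR/9876-model_1/9876-model_1-epoch_5-args.json
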
def get_bpe_string(predictions_topk_dic, output_tokens_tensor_aligned_dic,
                   input_tokens_tensor_per_task, topk, tokenizer,
                   task_to_label_dictionary, task_settings, mask_index,
                   verbose):

    predict_dic = OrderedDict()
    source_preprocessed = OrderedDict()
    label_dic = OrderedDict()

    input_already_processed = []
    gold_already_processed = []
    for task_label in predictions_topk_dic:

        label = re.match("(.*)-(.*)", task_label)
        assert label is not None, "ERROR : {} task_label does not fit the right template (.*)-.* ".format(
            task_label)
        task = label.group(1)
        label = label.group(2)
        sent_ls_top = from_bpe_token_to_str(
            predictions_topk_dic[task_label],
            topk,
            tokenizer=tokenizer,
            pred_mode=True,
            task=task,
            mask_index=mask_index,
            bpe_tensor_src=input_tokens_tensor_per_task["input_masked"]
            if task == "mlm" else None,
            label_dictionary=task_to_label_dictionary[task_label],
            get_bpe_string=LABEL_PARAMETER[label]["bpe"],
            label=label  #, null_token_index=null_token_index, null_str=null_str
        )
        # some tasks may share the same outputs : we don't want to post-process them several times
        if label in gold_already_processed:
            continue
        if output_tokens_tensor_aligned_dic[label] is not None:
            gold_already_processed.append(label)
            gold = from_bpe_token_to_str(
                output_tokens_tensor_aligned_dic[label],
                topk,
                tokenizer=tokenizer,
                task=task,
                label_dictionary=task_to_label_dictionary[task_label],
                pred_mode=False,
                get_bpe_string=LABEL_PARAMETER[label]["bpe"],
                label=label,
                #null_token_index=null_token_index, null_str=null_str
            )
            label_dic[label] = gold
        else:
            label_dic[label] = None

        predict_dic[task_label] = sent_ls_top
        input_label = task_settings[task]["input"]
        input_tokens_tensor = input_tokens_tensor_per_task[input_label]
        # some tasks may share the same inputs : we don't want to post-process them several times
        if input_label in input_already_processed:
            continue
        input_already_processed.append(input_label)

        source_preprocessed[input_label] = from_bpe_token_to_str(
            input_tokens_tensor,
            topk,
            tokenizer=tokenizer,
            label_dictionary=task_to_label_dictionary[task_label],
            pred_mode=False,
            task=task,
            #null_token_index=null_token_index, null_str=null_str,
            get_bpe_string=True,
            verbose=verbose)

    return source_preprocessed, label_dic, predict_dic
    def forward(self,
                input_ids_dict,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                head_masks=None):
        if labels is None:
            labels = OrderedDict()
        if head_masks is None:
            head_masks = OrderedDict()
        sequence_output_dict = OrderedDict()
        logits_dict = OrderedDict()
        loss_dict = OrderedDict()
        # sanity check the labels : they should all be in self.labels_supported
        for label, value in labels.items():
            assert label in self.labels_supported, "label {} in {} not supported".format(
                label, self.labels_supported)

        # task_wise layer attention
        printout_allocated_gpu_memory(1, " foward starting ")
        for input_name, input_tensors in input_ids_dict.items():
            # not able to output all layers anymore
            #print("INPUT {} {} ".format(input_name, input_tensors))
            sequence_output, _ = self.encoder(
                input_tensors,
                token_type_ids=None,
                attention_mask=attention_mask[input_name])
            sequence_output_dict[input_name] = sequence_output
            printout_allocated_gpu_memory(1, " forward pass bert")

        for task in self.tasks:
            # we don't use a mask for parsing heads (cf. test performed below : the -1 already ignores the heads we don't want)
            # NB : head_masks for parsing only applies to heads, not types
            head_masks_task = None  # head_masks.get(task, None) if task != "parsing" else None
            # NB : head_mask means masks specific to the module heads (nothing related to parsing !)
            assert self.task_parameters[task]["input"] in sequence_output_dict, \
                "ERROR input {} of task {} was not found in input_ids_dict {}" \
                " and therefore not in sequence_output_dict {} ".format(self.task_parameters[task]["input"],
                                                                        task, input_ids_dict.keys(),
                                                                        sequence_output_dict.keys())

            if not self.head[
                    task].__class__.__name__ == BertOnlyMLMHead.__name__:  #isinstance(self.head[task], BertOnlyMLMHead):
                logits_dict[task] = self.head[task](
                    sequence_output_dict[self.task_parameters[task]["input"]],
                    head_mask=head_masks_task)
            else:
                logits_dict[task] = self.head[task](
                    sequence_output_dict[self.task_parameters[task]["input"]])
            # test performed : (logits_dict[task][0][1,2,:20]==float('-inf'))==(labels["parsing_heads"][1,:20]==-1)
            # handle several labels at output (e.g  parsing)

            printout_allocated_gpu_memory(1,
                                          " forward pass head {}".format(task))

            logits_dict = self.rename_multi_modal_task_logits(
                labels=self.task_parameters[task]["label"],
                task=task,
                logits_dict=logits_dict,
                task_parameters=self.task_parameters)

            printout_allocated_gpu_memory(1, "after renaming")

            for logit_label in logits_dict:

                label = re.match("(.*)-(.*)", logit_label)
                assert label is not None, "ERROR logit_label {}".format(
                    logit_label)
                label = label.group(2)
                if label in self.task_parameters[task]["label"]:
                    _labels = None
                    if self.task_parameters[task]["input"] == "input_masked":
                        _labels = labels.get(label)
                        if _labels is not None:
                            _labels = _labels.clone()
                            _labels[input_ids_dict["input_masked"] != self.
                                    mask_index_bert] = PAD_ID_LOSS_STANDART
                    else:
                        _labels = labels.get(label)
                    printout_allocated_gpu_memory(
                        1, " get label head {}".format(logit_label))
                    if _labels is not None:
                        #print("LABEL label {} {}".format(label, _labels))
                        loss_dict[logit_label] = self.get_loss(
                            loss_func=self.task_parameters[task]["loss"],
                            label=label,
                            num_label_dic=self.num_labels_dic,
                            labels=_labels,
                            logits_dict=logits_dict,
                            task=task,
                            logit_label=logit_label,
                            head_label=labels["heads"]
                            if label == "types" else None)
                    printout_allocated_gpu_memory(1,
                                                  " get loss {}".format(task))
                printout_allocated_gpu_memory(
                    1, " putting to cpu {}".format(logit_label))
        # the third return value is reserved for potential attention weights

        return logits_dict, loss_dict, None
Example #7
def get_group(group_mapping, penalize_lab):
    # return the first regex in group_mapping that matches penalize_lab (None otherwise)
    for group_regex in group_mapping:
        if re.match(group_regex, penalize_lab) is not None:
            return group_regex
    return None
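
A small sketch of the regex-to-group lookup (the group patterns here are made up):

group_mapping = ["^parsing-.*", "^pos-.*"]
assert get_group(group_mapping, "parsing-types") == "^parsing-.*"
assert get_group(group_mapping, "mlm-mlm") is None  # no pattern matches
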
def write_conll_multitask(format, dir_pred, dir_original, src_text_ls,
                          pred_per_task, tasks, task_parameters, cp_paste=False, gold=False,
                          all_indexes=None, sep_token=None, cls_token=None,
                          sent_id=None, raw_text=None,
                          append_mwe_ind=None,
                          append_mwe_row=None,
                          ind_batch=0, new_file=False, cut_sent=False, verbose=0):

    assert format in ["conll"]
    max_len_word = None
    writing_top = 1
    # assert that each task predicts the same number of samples per batch
    pred_task_len_former = -1
    task_former = ""

    # assertion on number of samples predicted
    for task_label in pred_per_task:

        pred_task_len = len(pred_per_task[task_label]) if gold else len(pred_per_task[task_label][writing_top-1])
        _task = re.match("(.*)-(.*)", task_label)
        if _task is not None:  # , "ERROR writer could not match {}".format(task_label)
            task = _task.group(1)
        else:
            task = task_label
        if pred_task_len_former > 0:
            assert pred_task_len == pred_task_len_former, \
                "ERROR {} and {} task ".format(task_former, task_label)
            if not gold:
                assert pred_task_len == len(src_text_ls[task_parameters[task]["input"]]), "ERROR  src len {} and pred len {} ".format(len(src_text_ls[task_parameters[task]["input"]]),pred_task_len)
            # we check also other input length
            if src_text_ls.get("input_masked") is not None:
                assert pred_task_len == len(src_text_ls["input_masked"])
            if src_text_ls.get("wordpieces_inputs_words") is not None:
                assert pred_task_len == len(src_text_ls["wordpieces_inputs_words"]), "ERROR mismatch source " \
                                                                            "wordpieces_inputs_words {}  " \
                                                                            "and prediction {} ".format(src_text_ls, pred_per_task[task_label])
            if src_text_ls.get("wordpieces_inputs_raw_tokens") is not None:
                assert pred_task_len == len(src_text_ls["wordpieces_inputs_raw_tokens"]), \
                                    "ERROR mismatch source wordpieces_inputs_" \
                                    "raw_tokens {} and prediction {} ".format(src_text_ls, pred_per_task[task_label])
            try:
                assert pred_task_len == all_indexes.shape[0], "ERROR mismatch index {}  and all_indexes {} : pred {}".format(pred_task_len, all_indexes.shape[0], pred_per_task[task_label])
            except:
                pdb.set_trace()
        pred_task_len_former = pred_task_len

        task_former = task_label
        if format == "conll":
            mode_write = "w" if new_file else "a"
        if new_file:
            printing("CREATING NEW FILE (io_/dat/normalized_writer) : {} ", var=[dir_pred], verbose=verbose,
                     verbose_level=1)

    pos_label = "pos-pos" if not gold else "pos"
    types_label = "parsing-types" if not gold else "types"
    heads_label = "parsing-heads" if not gold else "heads"
    n_masks_mwe_label = "n_masks_mwe-n_masks_mwe" if not gold else "n_masks_mwe"
    mwe_detection_label = "mwe_detection-mwe_detection" if not gold else "mwe_detection"

    with open(dir_pred, mode_write) as norm_file:
        with open(dir_original, mode_write) as original:
            len_original = 0
            for ind_sent in range(all_indexes.shape[0]):
                pred_sent = OrderedDict()
                # NB : length assertion for each input-output (correcting if possible)
                # TODO standardize !! INCONSISTENCIES WHEN GOLD IS TRUE VS FALSE; IF GOLD : pred_per_task is indexed by labels (no 1-1 relation to task and src !)
                for task_label_or_gold_label in pred_per_task:
                    #task, _, label_processed = get_task_name_based_on_logit_label(task_label, label_processed)
                    if gold:
                        pred_sent[task_label_or_gold_label] = pred_per_task[task_label_or_gold_label][ind_sent]
                    else:
                        pred_sent[task_label_or_gold_label] = pred_per_task[task_label_or_gold_label][writing_top-1][ind_sent]
                    try:
                        # TODO : standardize (the first if is needed because we handle at the same time gold data indexed by label and predictions indexed by task+label)
                        if gold:
                            try:
                                src = src_text_ls[LABEL_PARAMETER[task_label_or_gold_label]["default_input"]][ind_sent]
                            except Exception as e:
                                src = src_text_ls["input_masked"][ind_sent]
                        else:
                            _task = re.match("(.*)-(.*)", task_label_or_gold_label)
                            assert _task is not None, "ERROR writer could not match {}".format(task_label_or_gold_label)
                            _label = _task.group(2)
                            _task = _task.group(1)
                            src = src_text_ls[TASKS_PARAMETER[_task]["input"]][ind_sent]

                        assert len(src) == len(pred_sent[task_label_or_gold_label]),"WARNING : (writer) task {} original_sent len {} {} \n  predicted sent len {} {}".format(task_label_or_gold_label, len(src), src,len(pred_sent[task_label_or_gold_label]), pred_sent[task_label_or_gold_label])
                    except AssertionError as e:
                        print(e)
                        pdb.set_trace()
                        if len(src) > len(pred_sent[task_label_or_gold_label]):
                            pred_sent[task_label_or_gold_label].extend(["UNK" for _ in range(len(src)-len(pred_sent[task_label_or_gold_label]))])
                            print("WARNING (writer) : original larger than prediction : so appending UNK token for writing")
                        else:
                            print("WARNING (writer) : original smaller than prediction for ")

                if sent_id is not None and raw_text is not None:
                    #norm_file.write("\n")
                    #original.write("\n")

                    norm_file.write(sent_id[ind_sent])
                    original.write(sent_id[ind_sent])
                    norm_file.write(raw_text[ind_sent])
                    original.write(raw_text[ind_sent])
                else:
                    norm_file.write("#\n")
                    original.write("#\n")
                    norm_file.write("#sent_id = {} \n".format(ind_sent+ind_batch+1))
                    original.write("#sent_id = {} \n".format(ind_sent+ind_batch+1))
                ind_adjust = 0

                #for ind, original_token in enumerate(original_sent):
                last_mwe_index = -1
                adjust_mwe = 0
                for ind in all_indexes[ind_sent, :]:
                    # WE REMOVE SPECIAL TOKENS ONLY IF THEY APPEAR AT THE BEGINNING OR AT THE END
                    # based on the source token !! (it tells us when we stop) (we never want to use gold information)
                    if "-" in ind and ind != "-1":
                        matching_mwe_ind = re.match("([0-9]+)-([0-9]+)", str(ind))
                        assert matching_mwe_ind is not None, "ERROR ind is {} : could not find mwe index".format(ind)
                        last_mwe_index = int(matching_mwe_ind.group(2))
                        ind_mwe = int(matching_mwe_ind.group(1))

                        original_token = src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind_mwe] if mwe_detection_label in pred_per_task or "wordpieces_inputs_words" in pred_per_task or n_masks_mwe_label in pred_per_task else "NOT_NEEDED"
                        adjust_mwe += (last_mwe_index-ind_mwe)
                        #assert ind_adjust == 0, "ERROR not supported"

                        mwe_meta = "Norm={}|mwe_detection={}|n_masks_mwe={}".format("_", pred_sent[mwe_detection_label][ind_mwe] if mwe_detection_label in pred_per_task else "_",
                                                                                    pred_sent[n_masks_mwe_label][ind_mwe] if n_masks_mwe_label in pred_per_task else "_")

                        norm_file.write("{index}\t{original}\t_\t{pos}\t_\t_\t{dep}\t_\t{types}\t{norm}\n".format(index=ind, original=original_token, pos="_", types="_", dep="_", norm=mwe_meta))
                        original.write("{}\t{}\t_\t_\t_\t_\t_\t_\t{}\t_\n".format(ind, original_token, "_"))
                        continue
                    else:
                        ind = int(ind)
                        try:
                            if "normalize" in [task for _tasks in tasks for task in _tasks]:

                                original_token = src_text_ls["wordpiece_words_src_aligned_with_norm"][ind_sent][ind]
                                original_pretokenized_field = "wordpiece_words_src_aligned_with_norm"
                            else:
                                original_token = src_text_ls["wordpieces_inputs_words"][ind_sent][ind]
                                original_pretokenized_field = "wordpieces_inputs_words"
                        except Exception as e:
                            original_token = src_text_ls["input_masked"][ind_sent][ind]
                            original_pretokenized_field = "input_masked"
                        # asserting that we have everything together on the source side
                        if ind > last_mwe_index:
                            if src_text_ls.get("wordpieces_inputs_raw_tokens") is not None:
                                try:
                                    assert src_text_ls[original_pretokenized_field][ind_sent][ind] == src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind-adjust_mwe], \
                                    "ERROR sequence {} on non-mwe tokens : raw and tokenized " \
                                    "should be same but are raw {} tokenized {}".format(original_pretokenized_field, src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind],
                                                                                        src_text_ls[original_pretokenized_field][ind_sent][ind+adjust_mwe])
                                except:
                                    print("WARNING sanity checking input failed (nomalized_writer) (might be due to dropout) {}".format(e))
                    # keep track of the longest original token seen so far
                    max_len_word = max(len(original_token), max_len_word or 0)
                    #if original_token in SPECIAL_TOKEN_LS and (ind+1 == len(original_sent) or ind == 0):
                    if (original_token in SPECIAL_TOKEN_LS or original_token in [cls_token, sep_token]):
                        # ind 0 is skipped because it corresponds to CLS
                        ind_adjust = 1
                        continue

                    pos = pred_sent[pos_label][ind] if pos_label in pred_per_task else "_"
                    types = pred_sent[types_label][ind] if types_label in pred_per_task else "_"
                    heads = pred_sent[heads_label][ind] if heads_label in pred_per_task else ind - 1

                    tenth_col = "Norm={}|mwe_detection={}|n_masks_mwe={}".format(pred_sent["normalize"][ind] if "normalize" in pred_per_task else "_",
                                                                                 pred_sent[mwe_detection_label][ind-adjust_mwe] if mwe_detection_label in pred_per_task else "_",
                                                                                 pred_sent[n_masks_mwe_label][ind-adjust_mwe] if n_masks_mwe_label in pred_per_task else "_")

                    if append_mwe_ind is not None:
                        # we need one list of mwe per batch
                        assert isinstance(append_mwe_ind[ind_sent], list)
                        assert isinstance(append_mwe_row[ind_sent], list)
                        if len(append_mwe_ind[ind_sent]) > 0:
                            assert len(append_mwe_row[ind_sent])>0
                            begin_mwe = append_mwe_ind[ind_sent][0]
                            if begin_mwe == ind:
                                norm_file.write(append_mwe_row[ind_sent][0])
                                original.write(append_mwe_row[ind_sent][0])
                                # consume the MWE entry we just wrote (FIFO)
                                append_mwe_ind[ind_sent].pop(0)
                                append_mwe_row[ind_sent].pop(0)

                    norm_file.write("{index}\t{original}\t_\t{pos}\t_\t_\t{dep}\t_\t{types}\t{norm}\n".format(index=ind, original=original_token, pos=pos, types=types, dep=heads, norm=tenth_col))
                    original.write("{}\t{}\t_\t_\t_\t_\t_\t_\t{}\t_\n".format(ind, original_token, ind-1))
                    if cut_sent:
                        if ind > 50:
                            print("CUTTING SENT index {} > 50 ".format(ind))
                            break
                norm_file.write("\n")
                original.write("\n")
        printing("WRITING predicted batch of {} original and {} normalized", var=[dir_original, dir_pred], verbose=verbose, verbose_level=1)
    assert max_len_word is not None, "ERROR : something went wrong in the writer"
    return max_len_word
Example #9
def get_code_data(dir):
    matching = re.match(r".*/([^/]+).*\.conllu", dir)
    if matching is not None:
        return matching.group(1)
    return "training_set-not-found"