def data_gen_dummy(V,
                   batch,
                   nbatches,
                   sent_len=9,
                   word_len=5,
                   verbose=0,
                   seed=None):
    "Generate random data for a src-tgt copy task."
    if seed is not None:
        np.random.seed(seed)
    for i in tqdm(range(nbatches),
                  disable=disable_tqdm_level(verbose, verbose_level=2)):
        data = torch.from_numpy(
            np.random.randint(low=2, high=V, size=(batch, sent_len, word_len)))
        data[:, :, 0] = 2
        # we force padding in the dummy model
        data[:, :, -1] = 1
        data[:, :, -2] = 1
        printing("DATA dummy {} ",
                 var=(data),
                 verbose=verbose,
                 verbose_level=5)
        src = Variable(data, requires_grad=False)
        tgt = Variable(data, requires_grad=False)
        yield MaskBatch(src, tgt, pad=1)
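# A minimal usage sketch of data_gen_dummy (hypothetical values): draw two
# batches of 4 sentences from a vocabulary of size 11; src and tgt are built
# from the same tensor, which is what makes it a copy task.
_gen = data_gen_dummy(V=11, batch=4, nbatches=2, seed=123)
_first_batch = next(_gen)  # a MaskBatch wrapping identical src / tgt data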
def printout_allocated_gpu_memory(verbose, comment):

    if verbose == "gpu":
        try:
            printing("GPU {} {}",
                     var=[comment, torch.cuda.memory_allocated()],
                     verbose=verbose,
                     verbose_level="gpu")
        except Exception as e:
            print(e)
def use_gpu_(use_gpu, verbose=0):
    if use_gpu is not None and use_gpu:
        assert torch.cuda.is_available(
        ), "ERROR : use_gpu was set to True but cuda not available "
    use_gpu = torch.cuda.is_available() if use_gpu is None else use_gpu
    printing("HARDWARE : use_gpu set to {} ",
             var=[use_gpu],
             verbose=verbose,
             verbose_level=1)
    return use_gpu
def sanity_check_loss_poneration(ponderation_dic, verbose=1):
    if isinstance(ponderation_dic, dict):
        for task in TASKS_PARAMETER:
            assert task in ponderation_dic, "ERROR : task {} has no ponderation although it should".format(task)
    elif isinstance(ponderation_dic, str):
        assert ponderation_dic in MULTI_TASK_LOSS_PONDERATION_PREDEFINED_MODE, "ERROR : ponderation {} should be in {}".format(ponderation_dic, MULTI_TASK_LOSS_PONDERATION_PREDEFINED_MODE)
        printing("WARNING : COULD NOT SANITY CHECK ponderation_dic {} ", var=[ponderation_dic], verbose=verbose,
                 verbose_level=1)
    else:
        raise Exception("ponderation_dic is neither a string nor a dict {}".format(ponderation_dic))
def align_bpe(n_bpe_target_minus_source,
              source_aligned,
              source_aligned_index,
              target_aligned,
              target_aligned_index,
              n_masks_to_add,
              src_token_len,
              bert_tokenizer,
              mask_token,
              mode="dummy",
              index_src=None,
              index_target=None,
              verbose=0):
    """
    align bpe of a given token using mode
    :return:
    """
    assert mode in ["dummy"]
    # dummy means appending with SPACE or MASK when needed
    if n_bpe_target_minus_source > 0:
        assert index_src is not None
        source_aligned_index.extend(
            [index_src for _ in range(n_bpe_target_minus_source)])
        source_aligned.extend(
            bert_tokenizer.convert_tokens_to_ids(
                [mask_token for _ in range(n_bpe_target_minus_source)]))

    elif n_bpe_target_minus_source < 0:
        assert index_target is not None
        # we add a NULL_STR (to be predicted) and index it as the former bpe token
        target_aligned_index.extend(
            [index_target for _ in range(-n_bpe_target_minus_source)])
        target_aligned.extend(
            bert_tokenizer.convert_tokens_to_ids(
                [NULL_STR for _ in range(-n_bpe_target_minus_source)]))

    n_masks_to_add.append(n_bpe_target_minus_source)
    n_masks_to_add.extend([-1 for _ in range(src_token_len - 1)])

    if verbose == "reader":
        printing(
            "SRC appending word bpe align : {}\nTARGET appending word bpe align : {} \nN_MASKS------------ : {}",
            var=[[mask_token for _ in range(n_bpe_target_minus_source)]
                 if n_bpe_target_minus_source > 0 else "",
                 [NULL_STR for _ in range(-n_bpe_target_minus_source)]
                 if n_bpe_target_minus_source < 0 else "",
                 [n_bpe_target_minus_source] +
                 [-1 for _ in range(src_token_len - 1)]],
            verbose_level="reader",
            verbose=verbose)

    return source_aligned, source_aligned_index, target_aligned, target_aligned_index, n_masks_to_add
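# A toy illustration of the "dummy" alignment above, with a stub tokenizer
# (hypothetical; a real bert_tokenizer would return true wordpiece ids). With a
# source word split into 2 bpe and a target split into 3, one mask id is
# appended to the source side and n_masks_to_add becomes [1, -1].
class _StubTokenizer:
    def convert_tokens_to_ids(self, tokens):
        return [0 for _ in tokens]  # stand-in ids, enough for the sketch

_src, _src_idx, _tgt, _tgt_idx, _n_masks = align_bpe(
    n_bpe_target_minus_source=1,  # the target has one extra bpe
    source_aligned=[], source_aligned_index=[],
    target_aligned=[], target_aligned_index=[],
    n_masks_to_add=[], src_token_len=2,
    bert_tokenizer=_StubTokenizer(), mask_token="[MASK]",
    index_src=0)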
def get_vocab_size_and_dictionary_per_task(tasks, pos_dictionary=None, type_dictionary=None, vocab_bert_wordpieces_len=None, task_parameters=None, verbose=1):

    if pos_dictionary is None and type_dictionary is None:
        assert "pos" not in tasks and "parsing" not in tasks, \
            "ERROR : pos or parsing are in tasks but related dictionaries are None"
        printing("INFO : no dictionaries and voc_sizes needed",     verbose=verbose, verbose_level=1)
        return None, None
    num_labels_per_task = OrderedDict()
    task_to_label_dictionary = OrderedDict()

    if "pos" in tasks:
        assert pos_dictionary is not None
        task_to_label_dictionary["pos-pos"] = pos_dictionary
        num_labels_per_task["pos-"+task_parameters["pos"]["label"][0]] = len(pos_dictionary.instance2index) + 1
    if "parsing" in tasks:
        assert type_dictionary is not None
        num_labels_per_task["parsing-types"] = len(type_dictionary.instance2index) + 1
        num_labels_per_task["parsing-heads"] = 0

        task_to_label_dictionary["parsing-types"] = type_dictionary
        task_to_label_dictionary["parsing-heads"] = "index"

    if "n_masks_mwe" in tasks:
        num_labels_per_task["n_masks_mwe-"+task_parameters["n_masks_mwe"]["label"][0]] = 3
        task_to_label_dictionary["n_masks_mwe-n_masks_mwe"] = "index"

    if "mwe_detection" in tasks:
        num_labels_per_task["mwe_detection-"+task_parameters["mwe_detection"]["label"][0]] = 2
        task_to_label_dictionary["mwe_detection-mwe_detection"] = "index"
    if "mwe_prediction" in tasks:
        assert vocab_bert_wordpieces_len is not None
        num_labels_per_task["mwe_prediction-"+task_parameters["mwe_prediction"]["label"][0]] = vocab_bert_wordpieces_len
        task_to_label_dictionary["mwe_prediction-"+task_parameters["mwe_prediction"]["label"][0]] = "index"
    if "mlm" in tasks:
        assert vocab_bert_wordpieces_len is not None
        num_labels_per_task["mlm-" + task_parameters["mlm"]["label"][0]] = vocab_bert_wordpieces_len
        task_to_label_dictionary["mlm-"+task_parameters["mlm"]["label"][0]] = "index"

    if "normalize" in tasks:
        num_labels_per_task["normalize-" + task_parameters["normalize"]["label"][0]] = vocab_bert_wordpieces_len+1
        task_to_label_dictionary["normalize-" + task_parameters["normalize"]["label"][0]] = "index"

    if "norm_not_norm" in tasks:
        num_labels_per_task["norm_not_norm-" + task_parameters["norm_not_norm"]["label"][0]] = 2
        task_to_label_dictionary["norm_not_norm-" + task_parameters["norm_not_norm"]["label"][0]] = "index"

    if "norm_n_masks" in tasks:
        num_labels_per_task["norm_n_masks-" + task_parameters["norm_n_masks"]["label"][0]] = 5
        task_to_label_dictionary["norm_n_masks-" + task_parameters["norm_n_masks"]["label"][0]] = "index"

    return num_labels_per_task, task_to_label_dictionary
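# A minimal sketch (hypothetical task_parameters): for an "mlm" task the number
# of labels is simply the wordpiece vocabulary size and labels are plain
# indices. Note that a non-None dictionary must be passed to skip the early
# return that fires when both pos and type dictionaries are None.
_num_labels, _task2dic = get_vocab_size_and_dictionary_per_task(
    tasks=["mlm"],
    pos_dictionary=object(),  # any non-None placeholder, only for this sketch
    vocab_bert_wordpieces_len=30000,
    task_parameters={"mlm": {"label": ["mlm"]}})
# _num_labels == {"mlm-mlm": 30000} ; _task2dic == {"mlm-mlm": "index"}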
def get_dataset_label(dataset_dir_ls, default):
    if dataset_dir_ls is None:
        return None

    if REPO_DATASET.get(dataset_dir_ls[0], None) is None:
        try:
            label = "|".join([get_code_data(path) for _, path in enumerate(dataset_dir_ls)])
        except Exception:
            printing("REPORT : dataset name of directory {} not found as UD so using default ", var=[dataset_dir_ls], verbose=1, verbose_level=1)
            label = "|".join([REPO_DATASET.get(path, "{}_{}".format(default, i)) for i, path in enumerate(dataset_dir_ls)])
    else:
        label = "|".join([REPO_DATASET.get(path, "{}_{}".format(default, i)) for i, path in enumerate(dataset_dir_ls)])

    return label
def write_args(dir, model_id, checkpoint_dir=None,
               hyperparameters=None,
               info_checkpoint=None, verbose=1):

    args_dir = os.path.join(dir, "{}-args.json".format(model_id))
    if os.path.isfile(args_dir):
        info = "updated"
        args = json.load(open(args_dir, "r"))
        args["checkpoint_dir"] = checkpoint_dir
        args["info_checkpoint"] = info_checkpoint
        json.dump(args, open(args_dir, "w"))
    else:
        assert hyperparameters is not None, "REPORT : args.json created for the first time : hyperparameters dic required "
        #assert info_checkpoint is None, "REPORT : args. created for the first time : no checkpoint yet "
        info = "new"
        json.dump(OrderedDict([("checkpoint_dir", checkpoint_dir),
                               ("hyperparameters", hyperparameters),
                               ("info_checkpoint", info_checkpoint)]), open(args_dir, "w"))
    printing("MODEL args.json {} written {} ".format(info, args_dir), verbose_level=1, verbose=verbose)
    return args_dir
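# A minimal usage sketch (hypothetical paths and values): the first call
# creates <model_id>-args.json holding the hyperparameters ; a later call with
# the same dir and model_id updates its checkpoint fields in place.
_args_path = write_args(dir="/tmp", model_id="demo-model_1",
                        hyperparameters={"lr": 1e-4, "tasks": [["mlm"]]})
write_args(dir="/tmp", model_id="demo-model_1",
           checkpoint_dir="/tmp/checkpoint.pt", info_checkpoint={"epoch": 3})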
def get_init_args_dir(init_args_dir):
    """
    to simplify reporting we allow three ways of providing init_args_dir
    :param init_args_dir:
    :return:
    """
    if os.path.isfile(init_args_dir):
        _dir = init_args_dir
    elif os.path.isfile(os.path.join(CHECKPOINT_BERT_DIR, init_args_dir)):
        printing("MODEL init {} not found as directory so using second template ", var=[init_args_dir], verbose=1,
                 verbose_level=1)
        _dir = os.path.join(CHECKPOINT_BERT_DIR, init_args_dir)
    else:
        printing("MODEL init {} not found as directory and as subdirectory so using third template template ",
                 var=[init_args_dir], verbose=1, verbose_level=1)
        match = re.match("(.*-model_[0-9]+).*", init_args_dir)
        assert match is not None, "ERROR : template {} not found in {}".format("(.*-model_[0-9]+).*", init_args_dir)
        _dir = os.path.join(CHECKPOINT_BERT_DIR, match.group(1), init_args_dir + "-args.json")
        assert os.path.isfile(_dir), "ERROR : {} does not exist (based on param {}) ".format(_dir, init_args_dir)
    return _dir
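# The three accepted forms of init_args_dir, in resolution order (all paths
# hypothetical):
#   1. an existing file path            : /models/xp-model_3-args.json
#   2. a path under CHECKPOINT_BERT_DIR : xp/xp-model_3-args.json
#   3. a bare id matching "(.*-model_[0-9]+).*", expanded to
#      CHECKPOINT_BERT_DIR/xp-model_3/<id>-args.json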
def print_align_bpe(source_preprocessed, gold, input_alignement_with_raw,
                    labels_n_mask_prediction, verbose, verbose_level):
    if labels_n_mask_prediction is None:
        labels_n_mask_prediction = [[None for _ in range(len(sent))]
                                    for sent in input_alignement_with_raw]
    if isinstance(verbose, int) or verbose == "alignement":
        if verbose == "alignement" or verbose >= verbose_level:
            assert len(source_preprocessed) == len(gold), ""
            assert len(input_alignement_with_raw) == len(gold), ""
            for sent_src, sent_gold, index_match_with_src, append_masks in zip(
                    source_preprocessed, gold, input_alignement_with_raw,
                    labels_n_mask_prediction):
                assert len(sent_src) == len(sent_gold)
                for src, gold_tok, index, masks in zip(sent_src, sent_gold,
                                                       index_match_with_src,
                                                       append_masks):
                    printing("{}:{} --> {} (n_masks {})",
                             var=[index, src, gold_tok, masks],
                             verbose=1,
                             verbose_level=1)
def freeze_param(model, freeze_layer_prefix_ls=None, not_freeze_layer_prefix_ls=None,verbose=1):
    freezing_layer = 0

    if not_freeze_layer_prefix_ls is None:
        not_freeze_layer_prefix_ls = []
    if freeze_layer_prefix_ls is None:
        freeze_layer_prefix_ls = []
    for name, param in model.named_parameters():
        for prefix in freeze_layer_prefix_ls:
            if name.startswith(prefix):
                param.requires_grad = False
                freezing_layer += 1
                printing("TRAINING : freezing {} parameter ", var=[name], verbose=verbose, verbose_level=1)
        if not_freeze_layer_prefix_ls:
            # freeze the parameter only if its name matches none of the prefixes
            # that were explicitly requested to stay trainable
            if not any(name.startswith(prefix) for prefix in not_freeze_layer_prefix_ls):
                param.requires_grad = False
                freezing_layer += 1
                printing("TRAINING :- freezing {} parameter ", var=[name], verbose=verbose, verbose_level=1)
    printing("TRAINING : freezing {} layers : {} prefix , not freezing {} ",
             var=[freezing_layer, freeze_layer_prefix_ls, not_freeze_layer_prefix_ls],
             verbose=verbose,
             verbose_level=1)
    assert freezing_layer > 0, "ERROR : did not find any layers matching prefixes {} (to freeze) / {} (not to freeze)".format(freeze_layer_prefix_ls, not_freeze_layer_prefix_ls)

    return model
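# A minimal usage sketch : freeze the encoder of a toy two-layer module and
# leave its head trainable (module and prefixes are hypothetical).
import torch.nn as nn

class _ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)

_toy = freeze_param(_ToyModel(), freeze_layer_prefix_ls=["encoder"])
# _toy.encoder.weight.requires_grad is now False ; the head is untouched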
def make_bert_multitask(pretrained_model_dir, tasks, num_labels_per_task, init_args_dir, mask_id, encoder=None, args=None, model_dir=None, hugging_face_name=None):
    assert num_labels_per_task is not None and isinstance(num_labels_per_task, dict), \
        "ERROR : num_labels_per_task {} should be a dictionary".format(num_labels_per_task)
    assert isinstance(tasks, list) and len(tasks) >= 1, "ERROR tasks {} should be a list of len >=1".format(tasks)
    # we modify programmatically the config file base on argument passed to args

    if pretrained_model_dir is not None and init_args_dir is None:
        raise(Exception("Not supported yet"))
        # ugly, but handles the specific inheritance of XLMModel (should be made better!)
        #multitask_wrapper = BertMultiTask#BertMultiTaskXLM if encoder == "XLMModel" else BertMultiTask
        #printing("WARNING : as encoder is {} using {} ", var=["XLMModel", multitask_wrapper], verbose=1, verbose_level=1)
        #model = multitask_wrapper.from_pretrained(pretrained_model_dir, tasks=tasks, mask_id=mask_id,
        #                                          num_labels_per_task=num_labels_per_task, mapping_keys_state_dic=DIR_2_STAT_MAPPING[pretrained_model_dir],
        #                                          encoder=eval(encoder), dropout_classifier=args.dropout_classifier,
        #                                          hidden_dropout_prob=args.hidden_dropout_prob, random_init=False)

    elif init_args_dir is not None:
        init_args_dir = get_init_args_dir(init_args_dir)
        args_checkpoint = json.load(open(init_args_dir, "r"))
        #assert "checkpoint_dir" in args_checkpoint, "ERROR checkpoint_dir not in {} ".format(args_checkpoint)

        #checkpoint_dir = args_checkpoint.get("checkpoint_dir")
        #if checkpoint_dir is None or not os.path.isfile(checkpoint_dir):
        assert model_dir is not None
        checkpoint_dir = os.path.join(model_dir, "checkpoint.pt")
        assert os.path.isfile(checkpoint_dir), f"ERROR checkpoint file was not found {checkpoint_dir} "
        # redefining model and reloading

        encoder = CamembertModel
        config = CamembertConfig.from_pretrained(hugging_face_name)

        model = BertMultiTask(config=config, tasks=[task for tasks in args_checkpoint["hyperparameters"]["tasks"] for task in tasks], num_labels_per_task=args_checkpoint["info_checkpoint"]["num_labels_per_task"],
                              encoder=encoder, mask_id=mask_id)
        printing("MODEL : loading model from checkpoint {}", var=[checkpoint_dir], verbose=1, verbose_level=1)
        model.load_state_dict(torch.load(checkpoint_dir, map_location=lambda storage, loc: storage))
        model.append_extra_heads_model(downstream_tasks=tasks, num_labels_dic_new=num_labels_per_task)
    else:
        raise(Exception("only one of pretrained_model_dir checkpoint_dir can be defined "))

    return model
def get_optimizer(parameters, lr, optimizer="adam", betas=None, weight_decay=None, verbose=1):

    assert optimizer in AVAILABLE_OPTIMIZER, "ERROR optimizers supported are {} ".format(AVAILABLE_OPTIMIZER)
    if optimizer == "adam":
        if betas is None:
            betas = (0.9, 0.9)
            print("DEFAULT betas:", betas)
        if weight_decay is None:
            weight_decay = 0

        opt = torch.optim.Adam(parameters, lr=lr, betas=betas, eps=1e-9, weight_decay=weight_decay)
    elif optimizer == "SGD":
        assert betas is None, "ERROR "
        opt = torch.optim.SGD(parameters, lr=lr)
    elif optimizer == "bahdanu-adadelta":
        assert betas is None, "ERROR betas not supported for optimizer {}".format(optimizer)
        opt = torch.optim.Adadelta(parameters, eps=10e-6, rho=0.95)
    elif optimizer == "AdamW":
        opt = AdamW(parameters, lr=lr, weight_decay=weight_decay)

    printing("TRAINING : optimizer {} has been reloaded with lr {} betas {} , decay {}", var=[optimizer, lr, betas, weight_decay], verbose=verbose, verbose_level=1)

    return opt
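# A minimal sketch : build an Adam optimizer over a single toy parameter
# (AVAILABLE_OPTIMIZER and printing are assumed to be module-level, as above).
_w = torch.nn.Parameter(torch.zeros(2, 2))
_opt = get_optimizer([_w], lr=1e-3, optimizer="adam", betas=(0.9, 0.999))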
def get_normalized_token(norm_field,
                         n_exception,
                         verbose,
                         predict_mode_only=False):

    match = re.match("^Norm=([^|]+)|.+", norm_field)

    try:
        assert match.group(
            1
        ) is not None, " ERROR : no normalization found for norm_field {} ".format(
            norm_field)
        normalized_token = match.group(1)

    except (AssertionError, AttributeError):
        match_double_bar = re.match("^Norm=([|]+)|.+", norm_field)

        if match_double_bar.group(1) is not None:
            match = match_double_bar
            n_exception += 1
            printing("Exception handled we match with {}".format(
                match_double_bar.group(1)),
                     verbose=verbose,
                     verbose_level=2)
            normalized_token = match.group(1)

        else:
            exc = Exception(
                "Failed to handle exception with | on field {} ".format(
                    norm_field))
            if not predict_mode_only:
                raise (exc)
            else:
                print("REPLACING with UNK", exc)
                normalized_token = "UNK"

    return normalized_token, n_exception
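# A worked example of the "Norm=" field parsing : the first regex branch
# captures everything up to the first bar.
_tok, _n_exc = get_normalized_token("Norm=the|Other=x", n_exception=0, verbose=0)
# _tok == "the" and _n_exc == 0 ; a field like "Norm=|" instead falls into the
# double-bar exception branch, which increments n_exception.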
def log_data_src_label_pred(src_detokenized_dic, predict_detokenize_dic,
                            label_detokenized_dic, tasks, verbose,
                            verbose_level):

    if isinstance(verbose, int) or verbose == "alignment":
        if verbose == "alignment" or verbose >= verbose_level:
            for task in [_task for _tasks in tasks for _task in _tasks]:
                input_name = TASKS_PARAMETER[task]["input"]
                label_name_ls = TASKS_PARAMETER[task]["label"]

                for ind_src_sent, src_sent in enumerate(
                        src_detokenized_dic[input_name]):
                    print("      ")
                    for label in label_name_ls:
                        try:
                            assert len(predict_detokenize_dic[task + "-" + label][0][ind_src_sent]) == len(label_detokenized_dic[label][ind_src_sent]), \
                                "ERROR pred {} label {} ".format(predict_detokenize_dic[task + "-" + label][ind_src_sent], label_detokenized_dic[label][ind_src_sent])
                            assert len(src_detokenized_dic[input_name]
                                       [ind_src_sent]) == len(
                                           label_detokenized_dic[label]
                                           [ind_src_sent]), "ERROR "
                            for ind_src, src in enumerate(src_sent):
                                to_print = "SRC : {} ,    ".format(
                                    src) + " ".join([
                                        "PRED:{}  GOLD:{} (label {})".format(
                                            predict_detokenize_dic[task + "-" +
                                                                   label][0]
                                            [ind_src_sent][ind_src],
                                            label_detokenized_dic[label]
                                            [ind_src_sent][ind_src], label)
                                        for label in label_name_ls
                                    ])
                                printing(to_print, verbose=1, verbose_level=1)
                        except Exception as e:
                            print("ERROR : not aligned labels so cannot log ",
                                  e)
def log_warning(counting_failure_parralel_bpe_batch, data_label, batch_i,
                batch, noisy_under_splitted, skipping_batch_n_to_1, aligned,
                noisy_over_splitted, skip_1_t_n, skipping_evaluated_batch,
                verbose):
    printing("WARNING {} aignement failure caused by parallel ",
             var=[counting_failure_parralel_bpe_batch],
             verbose=verbose,
             verbose_level=1)
    printing(
        "WARNING on {} : Out of {} batch of X sentences each {} skipped ({} batch aligned ; {} with at least 1 sentence noisy MORE SPLITTED ; {} with  LESS SPLITTED {} + SENT with skipped_1_to_n : {}) ",
        var=[
            data_label, batch_i, noisy_under_splitted + skipping_batch_n_to_1,
            aligned, noisy_over_splitted, noisy_under_splitted,
            "SKIPPED" if skip_1_t_n else "", skipping_batch_n_to_1
        ],
        verbose=verbose,
        verbose_level=0)
    printing("WARNING on {} ON THE EVALUATION SIDE we skipped extra {} batch ",
             var=[data_label, skipping_evaluated_batch],
             verbose_level=1,
             verbose=1)
def input_normalization_processing(task_normalize_is, batch,
                                   norm_2_noise_training, norm_2_noise_eval):
    norm2noise_bool = False
    if (norm_2_noise_training is not None
            or norm_2_noise_eval) and task_normalize_is:
        portion_norm2noise = norm_2_noise_training if norm_2_noise_training is not None else 1.
        norm_2_noise_training = portion_norm2noise is not None
        rand = np.random.uniform(low=0, high=1, size=1)[0]
        norm2noise_bool = portion_norm2noise >= rand
        if norm2noise_bool:
            batch_raw_input = preprocess_batch_string_for_bert(
                batch.raw_output)
            printing("WARNING : input is gold norm",
                     verbose_level=2,
                     verbose=1)
        else:
            printing("WARNING : input is input", verbose_level=2, verbose=1)
            batch_raw_input = preprocess_batch_string_for_bert(batch.raw_input)
    else:
        printing("WARNING : input is input ", verbose_level=2, verbose=1)
        batch_raw_input = preprocess_batch_string_for_bert(batch.raw_input)
    return batch_raw_input, norm2noise_bool, norm_2_noise_training
def focused_masking(masking_strategy, input_tokens_tensor,
                    output_tokens_tensor_aligned, dropout_input_bpe,
                    mask_token_index, sep_token_index, use_gpu, epoch, n_epoch,
                    portion_mask, input_mask, tokenizer, verbose):

    if masking_strategy in ["mlm", "mlm_need_norm"]:

        dropout = 0.15
        assert dropout_input_bpe == 0., "in args.masking_strategy mlm dropout is hardcoded to {}".format(
            dropout)
        # standart_mlm means : standard MLM prediction
        standart_mlm = True
        # unmask_loss : do we also optimize the loss on tokens other than the MASKed ones
        unmask_loss = portion_mask
        if masking_strategy == "mlm_need_norm":
            # with the mlm_need_norm strategy : args.portion_mask of the time we learn as a standard mlm ,
            # the rest of the time we do the same but only on need_norm tokens (masking them)
            standart_mlm = np.random.random() < portion_mask
            # we force unmask loss to 0
            unmask_loss = 0
        if standart_mlm:
            # standart mlm
            input_tokens_tensor, mask_dropout, dropout_applied = dropout_input_tensor(
                input_tokens_tensor,
                mask_token_index,
                sep_token_index=sep_token_index,
                applied_dropout_rate=0.8,
                dropout=dropout)
        elif masking_strategy == "mlm_need_norm" and not standart_mlm:
            # todo : factorize
            feeding_the_model_with_label = output_tokens_tensor_aligned.clone()
            # we only learn on tokens that are different from gold
            feeding_the_model_with_label[input_tokens_tensor ==
                                         output_tokens_tensor_aligned] = -1
            if np.random.random() < 0.85:
                # 85% of the time we mask the tokens as in standard mlm
                input_tokens_tensor[
                    input_tokens_tensor !=
                    output_tokens_tensor_aligned] = mask_token_index
            else:
                # within the 15% rest : 50% of the time we replace by random 50% we keep
                if np.random.random() < 0.5:
                    permute = (torch.randperm(
                        torch.tensor(len(tokenizer.vocab) - 2)
                    )[:len(input_tokens_tensor[
                        input_tokens_tensor != output_tokens_tensor_aligned])]
                               + 1)
                    permute[permute == sep_token_index] = sep_token_index + 10
                    permute[permute ==
                            mask_token_index] = mask_token_index + 10
                    permute[permute == 0] = 53
                    if use_gpu:
                        permute = permute.cuda()
                    input_tokens_tensor[input_tokens_tensor !=
                                        output_tokens_tensor_aligned] = permute
            mask_dropout = (
                input_tokens_tensor == output_tokens_tensor_aligned)

        if standart_mlm and not dropout_applied:
            random_bpe_instead = np.random.random() < 0.5
            if random_bpe_instead:
                permute = (
                    torch.randperm(torch.tensor(len(tokenizer.vocab) - 2))
                    [:len(input_tokens_tensor[mask_dropout == 0])] + 1)
                permute[permute == sep_token_index] = sep_token_index + 10
                permute[permute == mask_token_index] = mask_token_index + 10
                permute[permute == 0] = 53
                if use_gpu:
                    permute = permute.cuda()

                input_tokens_tensor[mask_dropout == 0] = permute

        if unmask_loss:
            print(
                "WARNING : unmask_loss is {} : 0 means optimizing only on the MASKed tokens , > 0 means also "
                "optimizing on some extra tokens sampled based on dropout_adated".format(
                    unmask_loss))
            power = 3
            capped = 0.5
            dropout_adated = min(((epoch + 1) / n_epoch)**power, capped)
            printing(
                "LABEL NOT MASKING {}/1 of gold labels with power {} and capped {}"
                .format(dropout_adated, power, capped),
                verbose=verbose,
                verbose_level=2)
            _, mask_losses = dropout_input_tensor(
                input_tokens_tensor,
                mask_token_index,
                sep_token_index=sep_token_index,
                apply_dropout=False,
                dropout=dropout_adated)
            # we backpropagate only on tokens that receive a mask (MLM objective) +
            #  some extra ones that we control with dropout_adated
            mask_loss = mask_dropout * mask_losses
        else:
            mask_loss = mask_dropout
        feeding_the_model_with_label = output_tokens_tensor_aligned.clone()
        feeding_the_model_with_label[mask_loss != 0] = -1
        # half the time we actually mask those tokens otherwise we predict
    elif masking_strategy in ["norm_mask", "norm_mask_variable"]:
        if masking_strategy == "norm_mask_variable":
            # args.portion_mask = min(((epoch + 1) / n_epoch), 0.6)
            portion_mask = 1 - (epoch + 1) / n_epoch
        mask_normed = np.random.random() < portion_mask
        feeding_the_model_with_label = output_tokens_tensor_aligned.clone()
        if mask_normed:
            print("MASKING NORMED in mode {} portion mask {}".format(
                masking_strategy, portion_mask))
            feeding_the_model_with_label[input_tokens_tensor ==
                                         output_tokens_tensor_aligned] = -1
            if np.random.random() < 0.5:
                # half the time we mask not to make the model only normalizing
                input_tokens_tensor[
                    input_tokens_tensor !=
                    output_tokens_tensor_aligned] = mask_token_index
    else:
        feeding_the_model_with_label = output_tokens_tensor_aligned.clone()
        # TODO -- handle logging of output_tokens_tensor_aligned everywhere
        printing("MASK mask:{} \nMASK input:{} \nMASK output:{}",
                 var=[
                     input_mask, input_tokens_tensor,
                     feeding_the_model_with_label
                 ],
                 verbose_level="raw_data",
                 verbose=verbose)

    return input_tokens_tensor, feeding_the_model_with_label
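# A distilled sketch of the corruption recipe implemented above for the tokens
# selected for masking (ids and sizes hypothetical ; not the exact code path) :
# ~85% of the time the positions get the mask id, otherwise half the time a
# random id, and half the time the original token is kept.
_ids = torch.tensor([5, 6, 7, 8])
_mask_id, _vocab_size = 3, 100
if np.random.random() < 0.85:
    _ids[:] = _mask_id  # standard masking
elif np.random.random() < 0.5:
    _ids[:] = torch.randint(4, _vocab_size, _ids.shape)  # random replacement
# else : keep the original ids and still optimize on them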
def data_gen_conllu(data,
                    word_dictionary,
                    char_dictionary,
                    word_dictionary_norm,
                    batch_size,
                    task_info="",
                    get_batch_mode=True,
                    print_raw=False,
                    normalization=False,
                    pos_dictionary=None,
                    max_token_per_batch=None,
                    dropout_input=0,
                    timing=False,
                    verbose=0):

    n_sents = data[3]
    nbatch = n_sents // batch_size

    if nbatch == 0:
        printing("INFO : n_sents < batch_size so nbatch set to 1 ",
                 verbose=verbose,
                 verbose_level=0)

    printing(
        "TRAINING : Task {} Running {} batches of {} dim (n_sents : {}  time(s)) (if 0 will be set to 1) "
        .format(task_info, nbatch, batch_size, n_sents),
        verbose=verbose,
        verbose_level=2)
    printing("ITERATOR INFO : 1 epoch is {} iteration/step/batch  ",
             var=[nbatch],
             verbose=verbose,
             verbose_level=2)
    nbatch = 1 if nbatch == 0 else nbatch
    # deterministic run over all the dataset (for evaluation)
    if not get_batch_mode:

        for batch in tqdm(conllu_data.iterate_batch_variable(
                data, batch_size=batch_size, normalization=normalization),
                          disable=disable_tqdm_level(verbose,
                                                     verbose_level=2)):

            all_indexes, words, word_norm, wordpieces_words, wordpieces_raw_aligned_with_words, wordpieces_inputs_raw_tokens, \
            ind_wordpieces_words_alignement_index, ind_wordpieces_raw_aligned_alignement_index, ind_wordpieces_inputs_raw_tokens_alignement_index, \
            is_mwe_label, n_masks_to_app_in_raw_label, \
            wordpiece_normalization, ind_wordpiece_normalization_alignement_index,\
            wordpiece_normalization_target_aligned_with_word, ind_wordpiece_normalization_target_aligned_with_word_index,\
            wordpiece_words_src_aligned_with_norm, ind_wordpiece_words_src_aligned_with_norm_index,\
            n_masks_for_norm, to_norm_np,\
            chars, chars_norm, word_norm_not_norm, edit, pos, xpos, heads, types, \
                masks, lengths, order_ids, raw_word_inputs, normalized_str, raw_lines = batch

            outputing_raw_data_from_iterator(
                words,
                word_norm,
                chars,
                chars_norm,
                word_norm_not_norm,
                pos,
                word_dictionary=word_dictionary,
                pos_dictionary=pos_dictionary,
                word_norm_dictionary=word_dictionary_norm,
                char_dictionary=char_dictionary,
                verbose=verbose,
                print_raw=print_raw,
                normalization=normalization)
            yield MaskBatch(
                chars,
                chars_norm,
                edit=edit,
                types=types,
                heads=heads,
                output_word=word_norm,
                pos=pos,
                input_word=words,
                raw_input=raw_word_inputs,
                raw_output=normalized_str,
                wordpieces_words=wordpieces_words,
                ind_wordpieces_words_alignement_index=
                ind_wordpieces_words_alignement_index,
                ind_wordpieces_raw_aligned_alignement_index=
                ind_wordpieces_raw_aligned_alignement_index,
                ind_wordpieces_inputs_raw_tokens_alignement_index=
                ind_wordpieces_inputs_raw_tokens_alignement_index,
                wordpieces_raw_aligned_with_words=
                wordpieces_raw_aligned_with_words,
                wordpieces_inputs_raw_tokens=wordpieces_inputs_raw_tokens,
                is_mwe_label=is_mwe_label,
                n_masks_to_app_in_raw_label=n_masks_to_app_in_raw_label,
                wordpiece_normalization=wordpiece_normalization,
                ind_wordpiece_normalization_alignement_index=
                ind_wordpiece_normalization_alignement_index,
                wordpiece_normalization_target_aligned_with_word=
                wordpiece_normalization_target_aligned_with_word,
                ind_wordpiece_normalization_target_aligned_with_word_index=
                ind_wordpiece_normalization_target_aligned_with_word_index,
                wordpiece_words_src_aligned_with_norm=
                wordpiece_words_src_aligned_with_norm,
                ind_wordpiece_words_src_aligned_with_norm_index=
                ind_wordpiece_words_src_aligned_with_norm_index,
                n_masks_for_norm=n_masks_for_norm,
                to_norm_np=to_norm_np,
                all_indexes=all_indexes,
            ), order_ids

    # get_batch randomly (for training purpose)
    elif get_batch_mode:
        for ibatch in tqdm(range(1, nbatch + 1),
                           disable=disable_tqdm_level(verbose,
                                                      verbose_level=2)):
            # word, char, pos, xpos, heads, types, masks, lengths, morph
            printing("Data : getting {} out of {} batches",
                     var=(ibatch, nbatch + 1),
                     verbose=verbose,
                     verbose_level=2)

            all_indexes, word, word_norm, wordpieces_words, wordpieces_raw_aligned_with_words, wordpieces_inputs_raw_tokens, \
            ind_wordpieces_words_alignement_index, ind_wordpieces_raw_aligned_alignement_index, ind_wordpieces_inputs_raw_tokens_alignement_index, \
            is_mwe_label, n_masks_to_app_in_raw_label, \
            wordpiece_normalization, ind_wordpiece_normalization_alignement_index, \
            wordpiece_normalization_target_aligned_with_word, ind_wordpiece_normalization_target_aligned_with_word_index, \
            wordpiece_words_src_aligned_with_norm, ind_wordpiece_words_src_aligned_with_norm_index, \
            n_masks_for_norm, to_norm_np, \
            char, chars_norm, word_norm_not_norm, edit, pos, _, heads, types, _, \
            lenght, order_ids, raw_word_inputs, normalized_str, _ = conllu_data.get_batch_variable(data, batch_size=batch_size, normalization=normalization, max_token_per_batch=max_token_per_batch)

            assert min(
                lenght.data) > 0, "ERROR : min(lenght.data) is {} ".format(
                    min(lenght.data))

            outputing_raw_data_from_iterator(
                word,
                word_norm,
                char,
                chars_norm,
                word_norm_not_norm,
                pos,
                word_dictionary=word_dictionary,
                pos_dictionary=pos_dictionary,
                char_dictionary=char_dictionary,
                word_norm_dictionary=word_dictionary_norm,
                verbose=verbose,
                print_raw=print_raw,
                normalization=normalization)
            yield MaskBatch(
                char,
                chars_norm,
                output_word=word_norm,
                edit=edit,
                wordpieces_words=wordpieces_words,
                wordpieces_raw_aligned_with_words=
                wordpieces_raw_aligned_with_words,
                wordpieces_inputs_raw_tokens=wordpieces_inputs_raw_tokens,
                is_mwe_label=is_mwe_label,
                types=types,
                heads=heads,
                ind_wordpieces_words_alignement_index=
                ind_wordpieces_words_alignement_index,
                ind_wordpieces_raw_aligned_alignement_index=
                ind_wordpieces_raw_aligned_alignement_index,
                ind_wordpieces_inputs_raw_tokens_alignement_index=
                ind_wordpieces_inputs_raw_tokens_alignement_index,
                n_masks_to_app_in_raw_label=n_masks_to_app_in_raw_label,
                all_indexes=all_indexes,
                wordpiece_normalization=wordpiece_normalization,
                ind_wordpiece_normalization_alignement_index=
                ind_wordpiece_normalization_alignement_index,
                wordpiece_normalization_target_aligned_with_word=
                wordpiece_normalization_target_aligned_with_word,
                ind_wordpiece_normalization_target_aligned_with_word_index=
                ind_wordpiece_normalization_target_aligned_with_word_index,
                wordpiece_words_src_aligned_with_norm=
                wordpiece_words_src_aligned_with_norm,
                ind_wordpiece_words_src_aligned_with_norm_index=
                ind_wordpiece_words_src_aligned_with_norm_index,
                n_masks_for_norm=n_masks_for_norm,
                to_norm_np=to_norm_np,
                pos=pos,
                input_word=word,
                raw_input=raw_word_inputs,
                raw_output=normalized_str), order_ids
def data_gen_multi_task_sampling_batch(tasks,
                                       readers,
                                       word_dictionary,
                                       char_dictionary,
                                       pos_dictionary,
                                       word_dictionary_norm,
                                       batch_size,
                                       get_batch_mode,
                                       mode_batch_sampling="proportional",
                                       dropout_input=0,
                                       max_token_per_batch=None,
                                       print_raw=False,
                                       verbose=1):
    "multitask learning iterator"
    assert len(tasks) == len(readers)
    assert mode_batch_sampling in MODE_BATCH_SAMPLING_AVAILABLE
    iterator = {}
    end_task_flag = {}
    n_sents_per_task_dataset_cumul = {}
    cumul_n_sent = 0
    for simult_task in tasks:
        needs_normalization = does_one_task_require_normalization(simult_task)
        iterator[",".join(simult_task)] = data_gen_conllu(
            data=readers[",".join(simult_task)],
            word_dictionary=word_dictionary,
            task_info=",".join(simult_task),
            char_dictionary=char_dictionary,
            pos_dictionary=pos_dictionary,
            word_dictionary_norm=word_dictionary_norm,
            batch_size=batch_size,
            get_batch_mode=get_batch_mode,
            dropout_input=dropout_input,
            max_token_per_batch=max_token_per_batch,
            print_raw=print_raw,
            normalization=needs_normalization,
            verbose=verbose)
        end_task_flag[",".join(simult_task)] = False
        cumul_n_sent += readers[",".join(simult_task)][-1]
        n_sents_per_task_dataset_cumul[",".join(simult_task)] = cumul_n_sent
    n_sents_per_task_dataset_cumul["all"] = n_sents_per_task_dataset_cumul[
        ",".join(tasks[-1])]

    batch_iter = 0
    while True:
        n_sent_start = 0
        random_sample_id = np.random.randint(0, 100)
        for ind, simult_task in enumerate(tasks):
            simult_task = ",".join(simult_task)
            if sampling_proportion(
                    n_sent_start, n_sents_per_task_dataset_cumul["all"]
            ) < random_sample_id < sampling_proportion(
                    n_sents_per_task_dataset_cumul[simult_task],
                    n_sents_per_task_dataset_cumul["all"]
            ) and not end_task_flag[simult_task]:
                try:
                    batch, order = iterator[simult_task].__next__()
                    sanity_check_batch_label(simult_task,
                                             batch,
                                             verbose=verbose)
                    batch_iter += 1
                    yield batch
                except StopIteration:
                    end_task_flag[simult_task] = True
                    printing("ITERATOR END {} ",
                             var=[simult_task],
                             verbose_level=2,
                             verbose=verbose)
                    break
            else:
                n_sent_start = n_sents_per_task_dataset_cumul[simult_task]
        if sum(end_task_flag.values()) == len(tasks):
            break
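# The sampling rule above, distilled (assuming sampling_proportion maps a
# cumulative sentence count to a percentile) : a task is picked when a uniform
# draw in [0, 100) falls inside its cumulative interval, so each task is
# sampled proportionally to its dataset size.
_cumul = {"taskA": 30, "all": 100}  # taskA owns draws in [0, 30)
_draw = np.random.randint(0, 100)
_picked_taskA = _draw < 100 * _cumul["taskA"] / _cumul["all"]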
def readers_load(datasets,
                 tasks,
                 word_dictionary,
                 word_dictionary_norm,
                 char_dictionary,
                 pos_dictionary,
                 xpos_dictionary,
                 type_dictionary,
                 bert_tokenizer,
                 word_decoder=False,
                 must_get_norm=True,
                 bucket=True,
                 input_level_ls=None,
                 run_mode="train",
                 add_start_char=1,
                 add_end_char=1,
                 symbolic_end=True,
                 symbolic_root=True,
                 verbose=1):

    readers = {}
    simultanuous_training = False  # deprecated
    assert "all" not in tasks, "ERROR not supported yet (problem for simultaneous training..) "
    if not "all" in tasks and not simultanuous_training:
        try:
            assert len(tasks) == len(datasets), "ERROR : as simultanuous_training is {} : " \
                                                "we need 1 dataset per task but have only {} for task {} ".format(simultanuous_training, datasets, tasks)

        except Exception as e:
            datasets = [datasets[0] for _ in tasks]
            # SHOULD NOT DO THAT !!
            print("WARNING : duplicating readers", e)

    elif not simultanuous_training:
        assert len(
            tasks) == 1, "ERROR : if 'all' is used it should be the only task"
        printing("TRAINING : MultiTask Iterator wit task 'all' ",
                 verbose_level=1,
                 verbose=verbose)
    elif simultanuous_training:
        printing(
            "TRAINING : Training simultaneously tasks provided in {} (should have all required labels in datasets)",
            var=[tasks],
            verbose_level=1,
            verbose=verbose)
        raise Exception("Not supported yet --> should handle the loop ")

    for simul_task, data in zip(tasks, datasets):
        normalization_in_reader = does_one_task_require_normalization(
            simul_task)
        # 1 reader per simultaneously trained task
        readers[",".join(simul_task)] = conllu_data.read_data_to_variable(
            data,
            word_dictionary,
            char_dictionary,
            pos_dictionary,
            xpos_dictionary,
            type_dictionary,
            word_decoder=word_decoder,
            symbolic_end=symbolic_end,
            symbolic_root=symbolic_root,
            dry_run=0,
            normalization=normalization_in_reader,
            bucket=bucket,
            add_start_char=add_start_char,
            add_end_char=add_end_char,
            tasks=simul_task,
            max_char_len=None,
            must_get_norm=must_get_norm,
            bert_tokenizer=bert_tokenizer,
            input_level_ls=input_level_ls,
            run_mode=run_mode,
            word_norm_dictionary=word_dictionary_norm,
            pad_id=bert_tokenizer.convert_tokens_to_ids(
                bert_tokenizer.pad_token),
            verbose=verbose)

    return readers
def write_conll(format, dir_normalized, dir_original, src_text_ls, text_decoded_ls,
                src_text_pos, pred_pos_ls, tasks, inverse=False, permuting_mode=None, cp_paste=False, sep_token=None, cls_token=None,
                ind_batch=0, new_file=False, cut_sent=False, verbose=0):
    assert format in ["conll"]
    #assert len(tasks) == 1, "ERROR : only supported so far 1 task at a time"

    if tasks[0] == "normalize":
        src_ls = src_text_ls
        pred_ls = text_decoded_ls
        if text_decoded_ls is None:
            assert permuting_mode is not None or cp_paste
            pred_ls = src_text_ls
    elif tasks[0] == "pos":
        src_ls = src_text_pos
        pred_ls = pred_pos_ls
    if format == "conll":
        mode_write = "w" if new_file else "a"
        if new_file:
            printing("CREATING NEW FILE (io_/dat/normalized_writer) : {} ", var=[dir_normalized], verbose=verbose, verbose_level=1)
        with open(dir_normalized, mode_write) as norm_file:
            with open(dir_original, mode_write) as original:
                len_original = 0
                max_len_word = 0  # initialized so the final return cannot fail on empty input
                for ind_sent, (original_sent, normalized_sent) in enumerate(zip(src_ls, pred_ls)):
                    try:
                        assert len(original_sent) == len(normalized_sent), "WARNING : (writer) original_sent len {} {} \n  " \
                                                                           "normalized_sent len {} {} ".format(len(original_sent), original_sent, len(normalized_sent), normalized_sent)
                    except AssertionError as e:
                        print(e)
                        if len(original_sent) > len(normalized_sent):
                            normalized_sent.extend(["UNK" for _ in range(len(original_sent)-len(normalized_sent))])
                            print("WARNING (writer) : original larger than prediction : so appending UNK token for writing")
                        else:
                            print("WARNING (writer) : original smaller than prediction ! ")

                    norm_file.write("#\n")
                    original.write("#\n")
                    norm_file.write("#sent_id = {} \n".format(ind_sent+ind_batch+1))
                    original.write("#sent_id = {} \n".format(ind_sent+ind_batch+1))
                    ind_adjust = 0

                    if permuting_mode == "sample_mode":
                        noise_level_sentence = np.random.random(1)[0]

                    for ind, (original_token, normalized_token) in enumerate(zip(original_sent,
                                                                                 normalized_sent)):
                        # WE REMOVE SPECIAL TOKENS ONLY IF THEY APPEAR AT THE BEGINNING OR AT THE END
                        # on the source token !! (it tells us when we stop) (we never want to use gold information)
                        max_len_word = max(len(original_token), len_original)
                        len_original = max_len_word  # keep track of the longest token seen so far
                        if (original_token in SPECIAL_TOKEN_LS or original_token in [cls_token, sep_token]) and (ind+1 == len(original_sent) or ind == 0):
                            ind_adjust = 1
                            continue

                        if permuting_mode == "sample_mode":
                            # for 20% of sentences we apply an 80% word-noise level ; for the other 80% only a 20% level
                            rand_word = np.random.random(1)[0]
                            threshold_word = 0.8 if noise_level_sentence < 0.2 else 0.2
                            if rand_word < threshold_word:
                                permuting_mode = np.random.choice(["permute", "double", "random_replace",
                                                                   "multiply_last", "double_last","remove",
                                                                   "remove_last", "z_replace_s"])
                            #print("PERMUTATION is ", permuting_mode, rand_word, APPLY_PERMUTE_WORD,noise_level_sentence)

                        else:
                            rand_word = None

                        # TODO : when we want simultaneous training : assert src_pos src_norm same
                        #   --> assert pred_pos and pred_norm have the same length (number of words) and write
                        if tasks[0] == "normalize":
                            if inverse:
                                assert not cp_paste
                                _original_token = normalized_token
                                _normalized_token = original_token

                            else:
                                _original_token = original_token
                                _normalized_token = normalized_token
                                if permuting_mode is not None:
                                    assert not cp_paste
                                    # rule one
                                    #print("ORIGINAL TOKEN", original_token)
                                    if ( _original_token == _normalized_token or _original_token.lower() == _normalized_token.lower())\
                                        and not (original_token.startswith("#") or original_token.startswith("@")):
                                        # rule 1
                                        if permuting_mode == "z_replace_s" and len(original_token) > 1:
                                            if original_token.endswith("s"):
                                                _original_token = original_token[:-1] + "z"
                                            else:
                                                permuting_mode = np.random.choice(["permute", "double",
                                                                                   "random_replace",
                                                                                   "remove", "remove_last",
                                                                                   "multiply_last","double_last",
                                                                                    "z_replace_s"])

                                        if permuting_mode == "permute" and len(original_token) > 1:
                                            start_index = 0 if not (original_token.startswith("#") or original_token.startswith("@")) else 1
                                            to_permute = np.random.randint(start_index, len(original_token)-1)
                                            second_letter = original_token[to_permute+1]
                                            first_letter = original_token[to_permute]
                                            list_original_token = list(original_token)
                                            #pdb.set_trace()
                                            list_original_token[to_permute] = second_letter
                                            list_original_token[to_permute+1] = first_letter
                                            _original_token = "".join(list_original_token)
                                        # rule 2
                                        if (permuting_mode == "double" or permuting_mode == "remove") and len(original_token) > 1:
                                            start_index = 0
                                            to_double = np.random.randint(start_index, len(original_token)-1)
                                            first_letter = original_token[to_double]
                                            list_original_token = list(original_token)
                                            #pdb.set_trace()
                                            if permuting_mode == "double":
                                                list_original_token = list_original_token[:to_double] + [first_letter] + list_original_token[to_double:]
                                            else:
                                                # "remove" : actually drop the sampled character
                                                list_original_token = list_original_token[:to_double] + list_original_token[to_double+1:]

                                            _original_token = "".join(list_original_token)

                                        if permuting_mode == "remove_last" and len(original_token) > 1:
                                            _original_token = _original_token[:-1]
                                        if permuting_mode == "double_last" and len(original_token) > 1:
                                            _original_token = _original_token+_original_token[-1]
                                        if permuting_mode == "random_replace" and len(original_token) > 1:
                                            start_index = 0
                                            to_replace = np.random.randint(start_index, len(original_token) - 1)
                                            random_letter = np.random.choice(list("abcdefghijklmnopqrstuvwxyz"))
                                            first_letter = original_token[to_replace]
                                            list_original_token = list(original_token)
                                            # pdb.set_trace()

                                            list_original_token[to_replace] = random_letter

                                            _original_token = "".join(list_original_token)



                                        #print("NEW TOKEN", permuting_mode, _original_token)

                                    #pdb.set_trace()

                            if cp_paste:
                                _normalized_token = _original_token

                            norm_file.write("{}\t{}\t_\t_\t_\t_\t{}\t_\t_\tNorm={}|\n".format(ind + 1 - ind_adjust,
                                                                                              _original_token,
                                                                                              ind - ind_adjust if ind - ind_adjust > 0 else 0,
                                                                                              _normalized_token))
                        if tasks[0] == "pos":
                            norm_file.write("{}\t{}\t_\t{}\t_\t_\t{}\t_\t_\tNorm=()|\n".format(ind + 1 - ind_adjust,
                                                                                               original_token,
                                                                                               normalized_token,
                                                                                               ind-ind_adjust if ind - ind_adjust > 0 else 0
                                                                                               ))
                        original.write("{}\t{}\t_\t_\t_\t_\t_\t_\t{}\t_\n".format(ind+1,
                                                                                  original_token,
                                                                                  ind - ind_adjust if ind - ind_adjust > 0 else 0))

                        if cut_sent:
                            if ind > 50:
                                break
                    norm_file.write("\n")
                    original.write("\n")
            printing("WRITING predicted batch of {} original and {} normalized",
                     var=[dir_original, dir_normalized], verbose=verbose, verbose_level="raw_data")

    return max_len_word
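# For reference, the normalized file above is made of 10-column CoNLL-U-style
# rows ; e.g. a source token "thz" normalized to "the" at position 1 is
# written as (tab-separated) :
# 1	thz	_	_	_	_	0	_	_	Norm=the|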
def from_bpe_token_to_str(
        bpe_tensor,
        topk,
        pred_mode,  #null_token_index, null_str,
        task,
        tokenizer=None,
        bpe_tensor_src=None,
        pos_dictionary=None,
        label="normalize",
        label_dictionary=None,
        mask_index=None,
        get_bpe_string=False,
        verbose=1):
    """
    it actually supports not only bpe tokens but also pos tokens
    pred_mode allows handling gold data too (which only has 2 dims, not three)
    :param bpe_tensor:
    :param topk: int : number of top predictions : arranged as all the top-1, then all the 2nd, all the 3rd...
    :param pred_mode: bool
    :return:
    """
    assert label is not None or get_bpe_string, \
        "ERROR : task {} get_string {} : one of them should be defined or True".format(label, get_bpe_string)
    if task == "mlm" and pred_mode:
        assert bpe_tensor_src is not None and mask_index is not None, "ERROR : bpe_tensor_src and mask_index are needed to recover the non-predicted tokens"
        predictions_topk_ls = [[[
            bpe_tensor[sent, word, top].item() if bpe_tensor_src[sent,
                                                                 word].item()
            == mask_index else bpe_tensor_src[sent, word].item()
            for word in range(bpe_tensor.size(1))
        ] for sent in range(bpe_tensor.size(0))] for top in range(topk)]
    else:
        predictions_topk_ls = [[[
            bpe_tensor[sent, word,
                       top].item() if pred_mode else bpe_tensor[sent,
                                                                word].item()
            for word in range(bpe_tensor.size(1))
        ] for sent in range(bpe_tensor.size(0))] for top in range(topk)]
    # all labels that require the tokenizer (could be factorized)
    if get_bpe_string:
        assert tokenizer is not None
        sent_ls_top = [[
            tokenizer.convert_ids_to_tokens(sent_bpe)
            for sent_bpe in predictions_topk
        ] for predictions_topk in predictions_topk_ls]

        printing("DATA : bpe string again {}",
                 var=[sent_ls_top],
                 verbose=verbose,
                 verbose_level="raw_data")
    else:
        dictionary = label_dictionary

        if label_dictionary == "index":
            sent_ls_top = [[[token_ind for token_ind in sent_bpe]
                            for sent_bpe in predictions_topk]
                           for predictions_topk in predictions_topk_ls]
        else:
            try:
                sent_ls_top = [[[
                    dictionary.instances[token_ind -
                                         1] if token_ind > 0 else "UNK"
                    for token_ind in sent_bpe
                ] for sent_bpe in predictions_topk]
                               for predictions_topk in predictions_topk_ls]
            # add more context about the exception before re-raising
            except Exception as e:
                print(
                    "{} : dictionary : {} and prediction {} (POS specificity was removed)"
                    .format(e, dictionary.instances, predictions_topk_ls))
                raise e

    if not pred_mode:
        sent_ls_top = sent_ls_top[0]

    return sent_ls_top
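
# A minimal usage sketch for from_bpe_token_to_str with a stub tokenizer; the
# id-to-token mapping below is illustrative only, not a real BERT vocabulary.
import torch

class _StubTokenizer:
    _vocab = {0: "[PAD]", 1: "hel", 2: "##lo", 3: "world"}

    def convert_ids_to_tokens(self, ids):
        return [self._vocab.get(i, "[UNK]") for i in ids]

# one sentence, three bpe positions, top-2 predictions per position
_bpe_tensor = torch.tensor([[[1, 3], [2, 0], [3, 1]]])  # (batch, seq, topk)
_sent_ls_top = from_bpe_token_to_str(_bpe_tensor, topk=2, pred_mode=True,
                                     task="normalize", tokenizer=_StubTokenizer(),
                                     get_bpe_string=True, verbose=0)
# _sent_ls_top[0] holds the top-1 strings of each sentence, _sent_ls_top[1] the top-2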
def logging_processing_data(_verbose, verbose, verbose_level, batch_raw_input,
                            input_tokens_tensor, batch_raw_output,
                            output_tokens_tensor, inp_bpe_tokenized,
                            out_bpe_tokenized):
    printing("DATA : pre-tokenized input {} ",
             var=[batch_raw_input],
             verbose_level=verbose_level,
             verbose=_verbose)
    printing("DATA : BPEtokenized input ids {}",
             var=[input_tokens_tensor],
             verbose_level=3,
             verbose=verbose)

    printing("DATA : pre-tokenized output {} ",
             var=[batch_raw_output],
             verbose_level=verbose_level,
             verbose=_verbose)
    printing("DATA : BPE tokenized output ids  {}",
             var=[output_tokens_tensor],
             verbose_level=4,
             verbose=verbose)
    # BPE
    printing("DATA : BPE tokenized input  {}",
             var=[inp_bpe_tokenized],
             verbose_level=4,
             verbose=_verbose)
    printing("DATA : BPE tokenized output  {}",
             var=[out_bpe_tokenized],
             verbose_level=4,
             verbose=_verbose)
                                   ("norm_2_noise_training", args.norm_2_noise_training),
                                   ("random_iterator_train", random_iterator_train),
                                   ("aggregating_bert_layer_mode", args.aggregating_bert_layer_mode),
                                   ("tokenize_and_bpe", args.tokenize_and_bpe),
                                   ("seed", seed), ("case", case), ("bert_module", args.bert_module),
                                   ("freeze_layer_prefix_ls", args.freeze_parameters),
                                   ("layer_wise_attention", args.layer_wise_attention),
                                   ("append_n_mask", args.append_n_mask),
                                   ("multi_task_loss_ponderation", args.multi_task_loss_ponderation),
                                   ("multitask", args.multitask),
                                   ("low_memory_foot_print_batch_mode", args.low_memory_foot_print_batch_mode),
                                   ("graph_head_hidden_size_mlp_rel", args.graph_head_hidden_size_mlp_rel),
                                   ("graph_head_hidden_size_mlp_arc", args.graph_head_hidden_size_mlp_arc),
                                   ("ponderation_per_layer", args.ponderation_per_layer),
                                   ("norm_order_per_layer", args.norm_order_per_layer),
                                   ("weight_decay", args.weight_decay),
                                   ("penalize", args.penalize),
                                   ("hidden_dropout_prob", args.hidden_dropout_prob),
                                   ("schedule_lr", args.schedule_lr),
                                   ("n_steps_warmup",args.n_steps_warmup),
                                   ("random_init", args.random_init),
                                   ("dict_path", dict_path),
                                   ("model_id",model_id),
                                   ("optimizer", args.optimizer),
                                   ("model_location", model_location)
                                   ])
    printing("HYPERPARAMETERS {} ", var=[hyperparameters], verbose=verbose, verbose_level=1)
    printing("HYPERPARAMETERS KEYS {} ", var=[hyperparameters.keys()], verbose=verbose, verbose_level=1)
    return hyperparameters

def args_preprocessing(args, verbose=1):
    """
    sanity checking , changing types of arguments and parsing arguments
    """

    if isinstance(args.schedule_lr, str) and args.schedule_lr == "None":
        args.schedule_lr = None

    if args.batch_size != "flexible":
        args.batch_size = int(args.batch_size)

    if args.low_memory_foot_print_batch_mode is not None and args.low_memory_foot_print_batch_mode != "flexible_forward_batch_size":
        args.low_memory_foot_print_batch_mode = int(
            args.low_memory_foot_print_batch_mode)
    low_memory_foot_print_batch_mode_available = [
        0, 1, "flexible_forward_batch_size"
    ]

    assert args.low_memory_foot_print_batch_mode is None or args.low_memory_foot_print_batch_mode in low_memory_foot_print_batch_mode_available, "ERROR args.low_memory_foot_print_batch_mode {} should be in {}".format(
        args.low_memory_foot_print_batch_mode,
        low_memory_foot_print_batch_mode_available)

    if args.low_memory_foot_print_batch_mode:
        args.batch_update_train = args.batch_size
        args.batch_size = "flexible" if args.low_memory_foot_print_batch_mode == "flexible_forward_batch_size" else 2
        printing(
            "INFO : args.low_memory_foot_print_batch_mode {} "
            "so setting batch_size to {} and args.batch_update_train {}",
            var=[
                args.low_memory_foot_print_batch_mode, args.batch_size,
                args.batch_update_train
            ],
            verbose=verbose,
            verbose_level=1)

        if args.low_memory_foot_print_batch_mode != "flexible_forward_batch_size":
            # floor division always returns an int, so check divisibility explicitly
            assert args.batch_update_train % args.batch_size == 0 \
                and args.batch_update_train // args.batch_size > 0, \
                "ERROR batch_update_train {} should be a multiple of {} ".format(
                    args.batch_update_train, args.batch_size)
        printing(
            "INFO iterator : updating with {} equivalent batch size : forward pass is {} batch size",
            var=[args.batch_update_train, args.batch_size],
            verbose=verbose,
            verbose_level=1)
    else:
        args.batch_update_train = args.batch_size
    params = vars(args)
    args.lr = parse_argument_dictionary(params["lr"], hyperparameter="lr")

    if args.test_paths is not None:
        args.test_paths = [
            test_path_task.split(",") for test_path_task in args.test_paths
        ]

    if args.dev_path is not None:
        args.dev_path = [
            dev_path_task.split(",") for dev_path_task in args.dev_path
        ]

    if args.ponderation_per_layer is not None:
        args.ponderation_per_layer = parse_argument_dictionary(
            params["ponderation_per_layer"],
            hyperparameter="ponderation_per_layer")
    if args.norm_order_per_layer is not None:
        args.norm_order_per_layer = parse_argument_dictionary(
            params["norm_order_per_layer"],
            hyperparameter="norm_order_per_layer")

    args.tasks = [task_simul.split(",") for task_simul in args.tasks]

    if args.test_paths is not None:
        assert isinstance(args.test_paths, list) and isinstance(
            args.test_paths[0], list), "ERROR args.test_paths should be a list"
    # 1 simultaneous set of tasks per training dataset
    assert len(args.tasks) == len(
        args.train_path
    ), "ERROR args.tasks is {} but train paths are {}".format(
        args.tasks, args.train_path)

    assert args.penalization_mode in AVAILALE_PENALIZATION_MODE, "ERROR args.penalization_mode {} should be in {}".format(
        args.penalization_mode, AVAILALE_PENALIZATION_MODE)

    if args.multi_task_loss_ponderation is not None:
        argument_as_string = args.multi_task_loss_ponderation
        assert args.tasks is not None
        tasks = [task for tasks in args.tasks for task in tasks]
        # should add test on task X label calling task setting
        for task in tasks:
            if task != "all":
                for label in TASKS_PARAMETER[task]["label"]:
                    pattern = "{}-{}=([^=]*),".format(task, label)
                    match = re.search(pattern, argument_as_string)
                    assert match is not None, "ERROR : pattern {} not found for task {} in argument_as_string {}  ".format(
                        pattern, task, argument_as_string)

    if args.bert_model is not None:
        assert args.bert_model in BERT_MODEL_DIC, "ERROR args.bert_model {} should be in {}".format(
            args.bert_model, BERT_MODEL_DIC.keys())

    return args
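
# A minimal sketch of the task parsing and batch-accumulation arithmetic enforced
# in args_preprocessing, assuming a bare argparse.Namespace with illustrative values.
from argparse import Namespace

_args = Namespace(batch_size=16, low_memory_foot_print_batch_mode=1,
                  tasks=["pos,normalize"])
# one simultaneous set of tasks per training dataset
_args.tasks = [task_simul.split(",") for task_simul in _args.tasks]
assert _args.tasks == [["pos", "normalize"]]

# under low-memory mode the forward batch shrinks to 2 while gradients are
# accumulated until batch_update_train samples have been seen
_args.batch_update_train = _args.batch_size
_args.batch_size = 2
assert _args.batch_update_train % _args.batch_size == 0
_n_accumulation_steps = _args.batch_update_train // _args.batch_size  # -> 8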
def get_indexes(list_pretokenized_str,
                tokenizer,
                verbose,
                use_gpu,
                word_norm_not_norm=None):
    """
    from pretokenized string : it will bpe-tokenize it using BERT 'tokenizer'
    and then convert it to tokens ids
    :param list_pretokenized_str:
    :param tokenizer:
    :param verbose:
    :param use_gpu:
    :return:
    """
    all_tokenized_ls = [
        tokenizer.tokenize_origin(inp, ) for inp in list_pretokenized_str
    ]
    tokenized_ls = [tup[0] for tup in all_tokenized_ls]

    aligned_index = [tup[1] for tup in all_tokenized_ls]
    segments_ids = [[0 for _ in range(len(tokenized))]
                    for tokenized in tokenized_ls]

    printing("DATA : bpe tokenized {} , {} {} ",
             var=[tokenized_ls,
                  len(tokenized_ls),
                  len(tokenized_ls[0])],
             verbose=verbose,
             verbose_level="raw_data")
    printing("DATA : bpe tokenized {} , {} {} ",
             var=[tokenized_ls,
                  len(tokenized_ls),
                  len(tokenized_ls[0])],
             verbose=verbose,
             verbose_level="alignement")
    ids_ls = [tokenizer.convert_tokens_to_ids(inp) for inp in tokenized_ls]
    max_sent_len = max([len(inp) for inp in tokenized_ls])
    ids_padded = [
        inp + [PAD_ID_BERT for _ in range(max_sent_len - len(inp))]
        for inp in ids_ls
    ]
    aligned_index_padded = [
        inp + [1000 for _ in range(max_sent_len - len(inp))]
        for inp in aligned_index
    ]
    segments_padded = [
        inp + [PAD_ID_BERT for _ in range(max_sent_len - len(inp))]
        for inp in segments_ids
    ]

    if word_norm_not_norm is not None:
        mask = mask_group(word_norm_not_norm,
                          bpe_aligned_index=aligned_index_padded)
    else:
        mask = [[1 for _ in inp] + [0 for _ in range(max_sent_len - len(inp))]
                for inp in segments_ids]
    mask = torch.LongTensor(mask)
    tokens_tensor = torch.LongTensor(ids_padded)
    segments_tensors = torch.LongTensor(segments_padded)
    if use_gpu:
        mask = mask.cuda()
        tokens_tensor = tokens_tensor.cuda()
        segments_tensors = segments_tensors.cuda()

    printing("DATA {}", var=[tokens_tensor], verbose=verbose, verbose_level=3)

    sanity_check_data_len(tokens_tensor,
                          segments_tensors,
                          tokenized_ls,
                          aligned_index,
                          raising_error=True)

    return tokens_tensor, segments_tensors, tokenized_ls, aligned_index_padded, mask
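
# A minimal sketch of the padding and masking scheme used in get_indexes,
# assuming pad id 0 (the real code uses PAD_ID_BERT).
def _pad_and_mask(ids_ls, pad_id=0):
    "Pad variable-length id lists to the batch max and build a 1/0 attention mask."
    max_sent_len = max(len(inp) for inp in ids_ls)
    ids_padded = [inp + [pad_id] * (max_sent_len - len(inp)) for inp in ids_ls]
    mask = [[1] * len(inp) + [0] * (max_sent_len - len(inp)) for inp in ids_ls]
    return ids_padded, mask

_ids_padded, _mask = _pad_and_mask([[101, 7592, 102], [101, 102]])
assert _ids_padded == [[101, 7592, 102], [101, 102, 0]]
assert _mask == [[1, 1, 1], [1, 1, 0]]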
def write_conll_multitask(format, dir_pred, dir_original, src_text_ls,
                          pred_per_task, tasks, task_parameters, cp_paste=False, gold=False,
                          all_indexes=None, sep_token=None, cls_token=None,
                          sent_id=None, raw_text=None,
                          append_mwe_ind= None,
                          append_mwe_row= None,
                          ind_batch=0, new_file=False, cut_sent=False, verbose=0):

    assert format in ["conll"]
    max_len_word = None
    writing_top = 1
    # assert each task predicts as many samples per batch
    pred_task_len_former = -1
    task_former = ""

    # assertion on number of samples predicted
    for task_label in pred_per_task:

        pred_task_len = len(pred_per_task[task_label]) if gold else len(pred_per_task[task_label][writing_top-1])
        _task = re.match("(.*)-(.*)", task_label)
        if _task is not None:
            task = _task.group(1)
        else:
            task = task_label
        if pred_task_len_former > 0:
            assert pred_task_len == pred_task_len_former, \
                "ERROR : tasks {} and {} predicted a different number of samples".format(task_former, task_label)
            if not gold:
                assert pred_task_len == len(src_text_ls[task_parameters[task]["input"]]), "ERROR  src len {} and pred len {} ".format(len(src_text_ls[task_parameters[task]["input"]]),pred_task_len)
            # we check also other input length
            if src_text_ls.get("input_masked") is not None:
                assert pred_task_len == len(src_text_ls["input_masked"])
            if src_text_ls.get("wordpieces_inputs_words") is not None:
                assert pred_task_len == len(src_text_ls["wordpieces_inputs_words"]), "ERROR mismatch source " \
                                                                            "wordpieces_inputs_words {}  " \
                                                                            "and prediction {} ".format(src_text_ls, pred_per_task[task_label])
            if src_text_ls.get("wordpieces_inputs_raw_tokens") is not None:
                assert pred_task_len == len(src_text_ls["wordpieces_inputs_raw_tokens"]), \
                                    "ERROR mismatch source wordpieces_inputs_" \
                                    "raw_tokens {} and prediction {} ".format(src_text_ls, pred_per_task[task_label])
            assert pred_task_len == all_indexes.shape[0], \
                "ERROR mismatch index {} and all_indexes {} : pred {}".format(
                    pred_task_len, all_indexes.shape[0], pred_per_task[task_label])
        pred_task_len_former = pred_task_len

        task_former = task_label
        if format == "conll":
            mode_write = "w" if new_file else "a"
        if new_file:
            printing("CREATING NEW FILE (io_/dat/normalized_writer) : {} ", var=[dir_pred], verbose=verbose,
                     verbose_level=1)

    pos_label = "pos-pos" if not gold else "pos"
    types_label = "parsing-types" if not gold else "types"
    heads_label = "parsing-heads" if not gold else "heads"
    n_masks_mwe_label = "n_masks_mwe-n_masks_mwe" if not gold else "n_masks_mwe"
    mwe_detection_label = "mwe_detection-mwe_detection" if not gold else "mwe_detection"

    with open(dir_pred, mode_write) as norm_file:
        with open(dir_original, mode_write) as original:
            len_original = 0
            for ind_sent in range(all_indexes.shape[0]):
                pred_sent = OrderedDict()
                # NB : length assertion for each input-output (correcting if possible)
                # TODO standardize !! INCONSISTENCIES WHEN GOLD IS TRUE vs FALSE ; IF GOLD : pred_per_task is indexed by labels (no 1-1 relation to task and src !)
                for task_label_or_gold_label in pred_per_task:
                    if gold:
                        pred_sent[task_label_or_gold_label] = pred_per_task[task_label_or_gold_label][ind_sent]
                    else:
                        pred_sent[task_label_or_gold_label] = pred_per_task[task_label_or_gold_label][writing_top-1][ind_sent]
                    try:
                        # TODO : standardize (the first if is needed because we handle at the same time gold data indexed by label and predictions labelled by task+label)
                        if gold:
                            try:
                                src = src_text_ls[LABEL_PARAMETER[task_label_or_gold_label]["default_input"]][ind_sent]
                            except Exception as e:
                                src = src_text_ls["input_masked"][ind_sent]
                        else:
                            _task = re.match("(.*)-(.*)", task_label_or_gold_label)
                            assert _task is not None, "ERROR writer could not match {}".format(task_label_or_gold_label)
                            _label = _task.group(2)
                            _task = _task.group(1)
                            src = src_text_ls[TASKS_PARAMETER[_task]["input"]][ind_sent]

                        assert len(src) == len(pred_sent[task_label_or_gold_label]),"WARNING : (writer) task {} original_sent len {} {} \n  predicted sent len {} {}".format(task_label_or_gold_label, len(src), src,len(pred_sent[task_label_or_gold_label]), pred_sent[task_label_or_gold_label])
                    except AssertionError as e:
                        print(e)
                        if len(src) > len(pred_sent[task_label_or_gold_label]):
                            pred_sent[task_label_or_gold_label].extend(["UNK" for _ in range(len(src) - len(pred_sent[task_label_or_gold_label]))])
                            print("WARNING (writer) : original larger than prediction : appending UNK tokens for writing")
                        else:
                            print("WARNING (writer) : original smaller than prediction")

                if sent_id is not None and raw_text is not None:

                    norm_file.write(sent_id[ind_sent])
                    original.write(sent_id[ind_sent])
                    norm_file.write(raw_text[ind_sent])
                    original.write(raw_text[ind_sent])
                else:
                    norm_file.write("#\n")
                    original.write("#\n")
                    norm_file.write("#sent_id = {} \n".format(ind_sent+ind_batch+1))
                    original.write("#sent_id = {} \n".format(ind_sent+ind_batch+1))
                ind_adjust = 0

                last_mwe_index = -1
                adjust_mwe = 0
                for ind in all_indexes[ind_sent, :]:
                    # WE REMOVE SPECIAL TOKENS ONLY IF THEY APPEAR AT THE BEGINNING OR AT THE END
                    # of the source tokens !! (it tells us when to stop) (we never want to use gold information)
                    if "-" in ind and ind != "-1":
                        matching_mwe_ind = re.match("([0-9]+)-([0-9]+)", str(ind))
                        assert matching_mwe_ind is not None, "ERROR ind is {} : could not found mwe index".format(ind)
                        last_mwe_index = int(matching_mwe_ind.group(2))
                        ind_mwe = int(matching_mwe_ind.group(1))

                        original_token = src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind_mwe] if mwe_detection_label in pred_per_task or "wordpieces_inputs_words" in pred_per_task or n_masks_mwe_label in pred_per_task else "NOT_NEEDED"
                        adjust_mwe += (last_mwe_index-ind_mwe)
                        #assert ind_adjust == 0, "ERROR not supported"

                        mwe_meta = "Norm={}|mwe_detection={}|n_masks_mwe={}".format("_", pred_sent[mwe_detection_label][ind_mwe] if mwe_detection_label in pred_per_task else "_",
                                                                                    pred_sent[n_masks_mwe_label][ind_mwe] if n_masks_mwe_label in pred_per_task else "_")

                        norm_file.write("{index}\t{original}\t_\t{pos}\t_\t_\t{dep}\t_\t{types}\t{norm}\n".format(index=ind, original=original_token, pos="_", types="_", dep="_", norm=mwe_meta))
                        original.write("{}\t{}\t_\t_\t_\t_\t_\t_\t{}\t_\n".format(ind, original_token, "_"))
                        continue
                    else:
                        ind = int(ind)
                        try:
                            if "normalize" in [task for _tasks in tasks for task in _tasks]:

                                original_token = src_text_ls["wordpiece_words_src_aligned_with_norm"][ind_sent][ind]
                                original_pretokenized_field = "wordpiece_words_src_aligned_with_norm"
                            else:
                                original_token = src_text_ls["wordpieces_inputs_words"][ind_sent][ind]
                                original_pretokenized_field = "wordpieces_inputs_words"
                        except Exception as e:
                            original_token = src_text_ls["input_masked"][ind_sent][ind]
                            original_pretokenized_field = "input_masked"
                        # asserting that we have everything together on the source side
                        if ind > last_mwe_index:
                            if src_text_ls.get("wordpieces_inputs_raw_tokens") is not None:
                                try:
                                    assert src_text_ls[original_pretokenized_field][ind_sent][ind] == src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind-adjust_mwe], \
                                        "ERROR sequence {} on non-mwe tokens : raw and tokenized " \
                                        "should be the same but are raw {} tokenized {}".format(original_pretokenized_field,
                                                                                                src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind-adjust_mwe],
                                                                                                src_text_ls[original_pretokenized_field][ind_sent][ind])
                                except Exception as e:
                                    print("WARNING sanity checking input failed (normalized_writer) (might be due to dropout) {}".format(e))
                    len_original = max(len(original_token), len_original)
                    max_len_word = len_original
                    if (original_token in SPECIAL_TOKEN_LS or original_token in [cls_token, sep_token]):
                        # ind 0 is skipped because it corresponds to CLS
                        ind_adjust = 1
                        continue

                    pos = pred_sent[pos_label][ind] if pos_label in pred_per_task else "_"
                    types = pred_sent[types_label][ind] if types_label in pred_per_task else "_"
                    heads = pred_sent[heads_label][ind] if heads_label in pred_per_task else ind - 1

                    tenth_col = "Norm={}|mwe_detection={}|n_masks_mwe={}".format(pred_sent["normalize"][ind] if "normalize" in pred_per_task else "_",
                                                                                 pred_sent[mwe_detection_label][ind-adjust_mwe] if mwe_detection_label in pred_per_task else "_",
                                                                                 pred_sent[n_masks_mwe_label][ind-adjust_mwe] if n_masks_mwe_label in pred_per_task else "_")

                    if append_mwe_ind is not None:
                        # we need one list of mwe per batch
                        assert isinstance(append_mwe_ind[ind_sent], list)
                        assert isinstance(append_mwe_row[ind_sent], list)
                        if len(append_mwe_ind[ind_sent]) > 0:
                            assert len(append_mwe_row[ind_sent])>0
                            begin_mwe = append_mwe_ind[ind_sent][0]
                            if begin_mwe == ind:
                                norm_file.write(append_mwe_row[ind_sent][0])
                                original.write(append_mwe_row[ind_sent][0])
                                # pop from the front : index [0] was matched and its row written
                                append_mwe_ind[ind_sent].pop(0)
                                append_mwe_row[ind_sent].pop(0)

                    norm_file.write("{index}\t{original}\t_\t{pos}\t_\t_\t{dep}\t_\t{types}\t{norm}\n".format(index=ind, original=original_token, pos=pos, types=types, dep=heads, norm=tenth_col))
                    original.write("{}\t{}\t_\t_\t_\t_\t_\t_\t{}\t_\n".format(ind, original_token, ind-1))
                    if cut_sent and ind > 50:
                        print("CUTTING SENT index {} > 50 ".format(ind))
                        break
                norm_file.write("\n")
                original.write("\n")
        printing("WRITING predicted batch of {} original and {} normalized", var=[dir_original, dir_pred], verbose=verbose, verbose_level=1)
    assert max_len_word is not None, "ERROR : something went wrong in the writer"
    return max_len_word
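
# A minimal sketch of the 10-column rows emitted by write_conll_multitask; the
# helper and values are illustrative only (the real writer fills pos, heads,
# types and the Norm= field from pred_per_task).
def _conll_row(index, form, pos="_", head="_", dep_type="_", norm="_"):
    "Columns : ID FORM LEMMA UPOS XPOS FEATS HEAD DEPREL DEPS MISC."
    return "{}\t{}\t_\t{}\t_\t_\t{}\t_\t{}\tNorm={}|\n".format(
        index, form, pos, head, dep_type, norm)

assert _conll_row(1, "helo", pos="NOUN", head=0, norm="hello") == \
    "1\thelo\t_\tNOUN\t_\t_\t0\t_\t_\tNorm=hello|\n"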
def outputing_raw_data_from_iterator(words, word_norm, chars, chars_norm, word_norm_not_norm, pos,
                                     verbose, print_raw, normalization, char_dictionary, word_dictionary,
                                     word_norm_dictionary,
                                     pos_dictionary):
    """
    printing real data on the fly for debugging, data sanity check, ...
    TODO : may factorize a few things here
    :param words:
    :param word_norm:
    :param chars:
    :param chars_norm:
    :param word_norm_not_norm:
    :param pos:
    :param verbose:
    :param print_raw:
    :param normalization:
    :param char_dictionary:
    :param word_dictionary:
    :param word_norm_dictionary:
    :param pos_dictionary:
    :return:
    """
    _verbose = verbose if isinstance(verbose, int) else 0
    if print_raw:
        _verbose = 5

    if _verbose >= 5:
        if word_norm_not_norm is not None:
            character_display = [
                " ".join([char_dictionary.get_instance(chars[sent, word_ind, char_i]) for char_i in range(chars.size(2))]) +
                " | NORM : {} |SENT {} WORD {}| ".format(word_norm_not_norm[sent, word_ind], sent, word_ind) for
                ind_sent, sent in enumerate(range(chars.size(0)))
                for ind_w, word_ind in enumerate(range(chars.size(1)))]
        else:
            character_display = [
                " ".join(
                    [char_dictionary.get_instance(chars[sent, word_ind, char_i]) for char_i in range(chars.size(2))])
                for ind_sent, sent in enumerate(range(chars.size(0)))
                for ind_w, word_ind in enumerate(range(chars.size(1)))]

        if word_norm is not None:
            assert word_norm_dictionary is not None
            word_norm_display = " ".join([word_norm_dictionary.get_instance(word_norm[sent, word_ind]) for word_ind in range(word_norm.size(1)) for sent in range(word_norm.size(0))])
        else:
            print("No word level normalized word (only char)")
            word_norm_display = ["NONE"]

        word_display = [word_dictionary.get_instance(words[batch, word_ind]) + " "
                        for batch in range(chars.size(0)) for word_ind in range(chars.size(1))]

        if pos_dictionary is not None:
            pos_display = [pos_dictionary.get_instance(pos[batch, 0]) + " " for batch in
                           range(chars.size(0))]
        else:
            pos_display = None

    else:
        word_display = []
        character_display = []
        pos_display = []
    if not normalization and chars is not None:
        chars_norm = chars.clone()

    # TODO add word_norm
    if _verbose >= 5:
        if word_norm_not_norm is not None:
            character_norm_display = [" ".join([char_dictionary.get_instance(chars_norm[sent, word_ind, char_i])
                                                for char_i in range(chars_norm.size(2))]) +
                                      "|  NORM : {} |SENT {} WORD {}| \n ".format(word_norm_not_norm[sent, word_ind], sent,
                                                                                  word_ind)
                                      for ind_sent, sent in enumerate(range(chars_norm.size(0)))
                                      for ind_w, word_ind in enumerate(range(chars_norm.size(1)))]
        else:
            character_norm_display = [" ".join([char_dictionary.get_instance(chars_norm[sent, word_ind, char_i])
                                                for char_i in range(chars_norm.size(2))])
                                      for ind_sent, sent in enumerate(range(chars_norm.size(0)))
                                      for ind_w, word_ind in enumerate(range(chars_norm.size(1)))]
        printing("Feeding source characters {} \n ------ Target characters {}  "
                 "(NB : the character vocabulary is the same at input and output)",
                 var=(character_display, character_norm_display),
                 verbose=_verbose, verbose_level=5)
        printing("Feeding source words {} ", var=[word_display], verbose=_verbose, verbose_level=5)
        printing("Feeding Word normalized (word level) {}", var=[word_norm_display], verbose=_verbose, verbose_level=5)
        printing("Feeding source pos {} ", var=[pos_display], verbose=_verbose, verbose_level=5)
        if chars is not None and chars_norm is not None:
            printing("TYPE {} char before batch chars_norm {} ", var=(chars.is_cuda, chars_norm.is_cuda), verbose=verbose, verbose_level=5)
def parse_argument_dictionary(argument_as_string,
                              logits_label=None,
                              hyperparameter="multi_task_loss_ponderation",
                              verbose=1):
    """
    All arguments that are meant to be defined as dictionaries are passed to the Argument Parser as string:
    following  template :  i.e 'key1=value1,key2=value,'  (matched with "{}=([^=]*),".format(sub) )
    ALl the dictionary arguments are listed in DIC_ARGS
    """
    assert hyperparameter in DIC_ARGS, "ERROR : hyperparameter {} not supported, should be listed in DIC_ARGS".format(hyperparameter)
    if argument_as_string in MULTI_TASK_LOSS_PONDERATION_PREDEFINED_MODE:
        return argument_as_string
    else:
        dic = OrderedDict()
        if hyperparameter == "multi_task_loss_ponderation":
            assert logits_label is not None
            for task in logits_label:
                # useless (I think)
                if task == "parsing":
                    for sub in ["parsing-heads", "parsing-types"]:
                        pattern = "{}=([^=]*),".format(sub)
                        match = re.search(pattern, argument_as_string)
                        assert match is not None, "ERROR : pattern {} not found for task {} in argument_as_string {}  ".format(
                            pattern, task, argument_as_string)
                        dic[sub] = eval(match.group(1))
                # useless (I think)
                elif task == "normalize":
                    for sub in ["normalize", "append_masks"]:
                        pattern = "{}=([^=]*),".format(sub)
                        match = re.search(pattern, argument_as_string)
                        if sub == "normalize":
                            assert match is not None, "ERROR : pattern {} not found for task {} " \
                                                      "in argument_as_string {}  ".format( pattern, task, argument_as_string)
                            dic[sub] = eval(match.group(1))
                        else:
                            if match is not None:
                                dic[sub] = eval(match.group(1))
                # all cases should be in this one
                if task != "all" and task != "parsing":

                    pattern = "{}=([^=]*),".format(task)
                    match = re.search(pattern, argument_as_string)
                    assert match is not None, "ERROR : pattern {} not found for task {} in argument_as_string {}  ".format(
                        pattern, task, argument_as_string)
                    dic[task] = eval(match.group(1))

            printing("SANITY CHECK : multi_task_loss_ponderation {} ",
                     var=[argument_as_string],
                     verbose_level=3,
                     verbose=verbose)

        elif hyperparameter in [
                "lr", "norm_order_per_layer", "ponderation_per_layer"
        ]:
            # to handle several optimizers
            # a single float (e.g. "0.0001") is returned as-is
            try:
                value = eval(argument_as_string)
                assert isinstance(value, float)
                return value
            except (AssertionError, NameError, SyntaxError):
                argument_as_string = argument_as_string.split(",")
                for arg in argument_as_string[:-1]:
                    # unlike the pattern above there is no trailing comma : entries were already split on ","
                    pattern = "([^=]*)=([^=]*)"
                    match = re.search(pattern, arg)
                    assert match is not None, "ERROR : pattern {} not found in argument_as_string {}  ".format(
                        pattern, arg)
                    if hyperparameter in ["lr"]:
                        dic[match.group(1)] = float(match.group(2))
                    elif hyperparameter in ["norm_order_per_layer"]:
                        if match.group(2) != "fro":
                            dic[match.group(1)] = float(match.group(2))
                        else:
                            dic[match.group(1)] = match.group(2)
                    elif hyperparameter in ["ponderation_per_layer"]:
                        dic[match.group(1)] = float(match.group(2))

        return dic
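
# A minimal sketch of the 'key1=value1,key2=value2,' template handled above
# (note the trailing comma, which the splitting in the lr branch relies on);
# `_parse_comma_dict` is a hypothetical helper mirroring that branch.
import re
from collections import OrderedDict

def _parse_comma_dict(argument_as_string):
    "Parse 'name=float,' pairs into an OrderedDict, as in the lr branch."
    dic = OrderedDict()
    for arg in argument_as_string.split(",")[:-1]:
        match = re.search("([^=]*)=([^=]*)", arg)
        assert match is not None, "pattern not found in {}".format(arg)
        dic[match.group(1)] = float(match.group(2))
    return dic

assert _parse_comma_dict("default=0.0001,bert=0.00005,") == \
    OrderedDict([("default", 0.0001), ("bert", 0.00005)])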