# NOTE: these snippets are excerpted from a larger project ; the imports below cover the
# standard-library / third-party names used here. Project-specific helpers and constants
# (printing, TASKS_PARAMETER, LABEL_PARAMETER, BERT_MODEL_DIC, DIC_ARGS, conllu_data,
# readers_load, SPECIAL_TOKEN_LS, ...) are assumed to be provided by the surrounding code base.
import json
import os
import random
import re
import time
from collections import OrderedDict
from uuid import uuid4

import git  # GitPython, used to log the commit id
import numpy as np
import torch
from sklearn import linear_model
# assumption: SummaryWriter may also come from tensorboardX in the original project
from torch.utils.tensorboard import SummaryWriter


def report_template(metric_val,
                    info_score_val,
                    score_val,
                    model_full_name_val,
                    report_path_val,
                    evaluation_script_val,
                    data_val,
                    model_args_dir,
                    n_tokens_score=None,
                    task=None,
                    n_sents=None,
                    token_type=None,
                    subsample=None,
                    avg_per_sent=None,
                    min_per_epoch=None,
                    layer_i=None):

    return OrderedDict([
        ("metric", metric_val), ("info_score", info_score_val),
        ("score", score_val), ("model_full_name", model_full_name_val),
        ("n_tokens_score", n_tokens_score), ("n_sents", n_sents),
        ("token_type", token_type), ("subsample", subsample),
        ("avg_per_sent", avg_per_sent), ("min/epoch", min_per_epoch),
        ("model_args_dir", model_args_dir), ("report_path", report_path_val),
        ("evaluation_script", evaluation_script_val), ("task", task),
        ("layer", layer_i), ("data", data_val)
    ])
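# A minimal usage sketch (the values below are illustrative assumptions, not taken from the
# original project): report_template simply packages evaluation metadata into an OrderedDict
# row that can later be dumped to a JSON report.
_example_report_row = report_template(
    metric_val="accuracy", info_score_val=10, score_val=0.87,
    model_full_name_val="run-1a2b", report_path_val=None,
    evaluation_script_val="exact_match", data_val="layer_0-head_1",
    model_args_dir=None, n_sents=100, token_type="word",
    task="attention_analysis")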
def get_penalization(model_parameters,
                     model_parameters_0,
                     norm_order_per_layer,
                     ponderation_per_layer,
                     penalization_mode=None,
                     pruning_mask=None):
    penalization_dic = OrderedDict()
    assert isinstance(ponderation_per_layer,
                      dict), "{} should be dict ".format(ponderation_per_layer)
    #assert set(ponderation_per_layer) == set(norm_order_per_layer), "ERROR {} not same keys as {}".format(norm_order_per_layer, ponderation_per_layer)
    if penalization_mode is None:
        penalization_mode = "pruning"
    assert penalization_mode in AVAILALE_PENALIZATION_MODE, "ERROR {} should be in {}".format(
        penalization_mode, AVAILALE_PENALIZATION_MODE)
    penalization = 0

    for (name, param), (name_0, param_0) in zip(model_parameters,
                                                model_parameters_0.items()):
        assert name == name_0, "ERROR {} <> {}".format(name, name_0)
        key_norm = _get_key_ponderation(name, norm_order_per_layer)
        key_pond = _get_key_ponderation(name, ponderation_per_layer)
        n_param_layer = _get_n_param(param)
        # Each individual parameter counts the same (modulo ponderation_per_layer)
        #print("SANITY CHECKING debugging param {} has norm {} for dim {} ".format(name, torch.norm(param_0, p=norm_order_per_layer[key]), n_param_layer))
        power = norm_order_per_layer[key_norm] if norm_order_per_layer[
            key_norm] == 2 else 1

        # new --
        if penalization_mode == "pruning":
            assert pruning_mask is not None, "ERROR pruning_mask needed"
            # only norm 2 supported so far
            # ponderation applies on the non pruned mask
            pruning_mask_non = 1 - pruning_mask[name_0]
            diff = (param - param_0).flatten()
            norm_2_other = torch.sum((pruning_mask_non * diff)**2)
            norm_2_on_mask_param = torch.sum(
                (pruning_mask[name_0] * (diff))**2)
            _penalization = ponderation_per_layer[
                key_pond] * norm_2_other + 10 * norm_2_on_mask_param
            penalization += _penalization
            penalization_dic[name] = (n_param_layer,
                                      ponderation_per_layer[key_pond],
                                      _penalization.detach().cpu(),
                                      norm_2_other.detach().cpu(),
                                      norm_2_on_mask_param.detach().cpu())
        elif penalization_mode == "layer_wise":
            penalization += ponderation_per_layer[key_pond] * torch.norm(
                (param - param_0).flatten(),
                p=norm_order_per_layer[key_norm])**power
            penalization_dic[name] = (
                n_param_layer, ponderation_per_layer[key_pond], (torch.norm(
                    (param.detach() - param_0).flatten(),
                    p=norm_order_per_layer[key_norm])).cpu()**power)
        else:
            raise (Exception(
                "penalization_mode {} not supported".format(penalization_mode))
                   )

    penalization_dic["all"] = ("all", "total", penalization)
    return penalization, penalization_dic
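# A self-contained sketch (assumption, simplified from the "pruning" branch above): the squared
# distance to the original parameters is split by the pruning mask ; the kept (mask == 1)
# coordinates are pulled back toward their original values with a fixed weight of 10, while the
# remaining coordinates use the layer ponderation.
def _example_pruning_penalty(param, param_0, mask, ponderation=1.0):
    diff = (param - param_0).flatten()
    mask = mask.flatten().float()
    norm_2_other = torch.sum(((1 - mask) * diff) ** 2)
    norm_2_on_mask = torch.sum((mask * diff) ** 2)
    return ponderation * norm_2_other + 10 * norm_2_on_mask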
def write_args(dir,
               model_id,
               checkpoint_dir=None,
               hyperparameters=None,
               info_checkpoint=None,
               verbose=1):

    args_dir = os.path.join(dir, "{}-args.json".format(model_id))
    if os.path.isfile(args_dir):
        info = "updated"
        args = json.load(open(args_dir, "r"))
        args["checkpoint_dir"] = checkpoint_dir
        args["info_checkpoint"] = info_checkpoint
        json.dump(args, open(args_dir, "w"))
    else:
        assert hyperparameters is not None, "REPORT : args.json created for the first time : hyperparameters dict required "
        #assert info_checkpoint is None, "REPORT : args. created for the first time : no checkpoint yet "
        info = "new"
        json.dump(
            OrderedDict([("checkpoint_dir", checkpoint_dir),
                         ("hyperparameters", hyperparameters),
                         ("info_checkpoint", info_checkpoint)]),
            open(args_dir, "w"))
    printing("MODEL args.json {} written {} ".format(info, args_dir),
             verbose_level=1,
             verbose=verbose)
    return args_dir
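# Usage note (paths and values below are illustrative assumptions): the first call creates
# <dir>/<model_id>-args.json from the hyperparameters dict ; subsequent calls only update the
# checkpoint_dir / info_checkpoint fields of the existing file, e.g.
#   args_dir = write_args("/tmp/models", model_id="demo", hyperparameters={"lr": 1e-5})
#   write_args("/tmp/models", model_id="demo", checkpoint_dir="/tmp/models/demo.ckpt",
#              info_checkpoint={"epoch": 3})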
def count_tokens(task_ls, n_tokens_counter_per_task, label_per_task,
                 label_paremeter):
    """Get the exact number of non-pad tokens for each task."""
    n_tokens_counter_current_per_task = OrderedDict()
    for task in task_ls:
        for label in TASKS_PARAMETER[task]["label"]:
            n_tokens_counter_current_per_task[task + "-" + label] = (
                label_per_task[label] !=
                label_paremeter[label]["pad_value"]).sum().item()
            n_tokens_counter_per_task[
                task + "-" +
                label] += n_tokens_counter_current_per_task[task + "-" + label]
    ## TODO : handle in a more standard way
    n_tokens_all = n_tokens_counter_current_per_task[task + "-" + label]
    return n_tokens_counter_per_task, n_tokens_counter_current_per_task, n_tokens_all
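# Self-contained sketch (assumption) of the core counting step above: the number of non-pad
# target tokens is a mask-and-sum over the gold label tensor.
def _example_count_non_pad(labels, pad_value=-1):
    """labels: torch tensor of gold label ids ; returns the number of entries != pad_value."""
    return (labels != pad_value).sum().item()
# e.g. _example_count_non_pad(torch.tensor([[2, 5, -1], [7, -1, -1]])) == 3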
def get_config_param_to_modify(args):
    """ for now only internal bert dropout can be modifed a such"""
    config_to_update = OrderedDict()
    if args.dropout_bert is not None:
        assert args.dropout_bert >= 0, "ERROR {}".format(args.dropout_bert)
        config_to_update["attention_probs_dropout_prob"] = args.dropout_bert
        config_to_update["hidden_dropout_prob"] = args.dropout_bert
    if args.dropout_classifier is not None:
        assert args.dropout_classifier >= 0
        config_to_update["dropout_classifier"] = args.dropout_classifier
    if args.graph_head_hidden_size_mlp_rel is not None:
        config_to_update[
            "graph_head_hidden_size_mlp_rel"] = args.graph_head_hidden_size_mlp_rel
    if args.graph_head_hidden_size_mlp_arc is not None:
        config_to_update[
            "graph_head_hidden_size_mlp_arc"] = args.graph_head_hidden_size_mlp_arc

    return config_to_update
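# Hedged usage sketch (the config object and the way it is consumed are assumptions): the
# returned dict is meant to overwrite the corresponding fields of the encoder config before
# the model is built.
def _example_apply_config_update(config, config_to_update):
    for key, value in config_to_update.items():
        setattr(config, key, value)
    return config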
def get_batch_per_layer_head(attention, layer_head_att_batch=None, head=True):
    """

    :param attention:
    :param layer_head_att_batch:  if not None, will append to it as a list of tensor
    :return:
    """
    if layer_head_att_batch is None:
        layer_head_att_batch = OrderedDict()
    for i_layer in range(len(attention)):
        if head:
            for i_head in range(attention[0].size(1)):
                if f"layer_{i_layer}-head_{i_head}" not in layer_head_att_batch:
                    layer_head_att_batch[f"layer_{i_layer}-head_{i_head}"] = []
                #pdb.set_trace()
                layer_head_att_batch[f"layer_{i_layer}-head_{i_head}"].append(
                    attention[i_layer][:, i_head].detach())
        else:
            if f"layer_{i_layer}" not in layer_head_att_batch:
                layer_head_att_batch[f"layer_{i_layer}"] = []
            layer_head_att_batch[f"layer_{i_layer}"].append(
                attention[i_layer][:].detach())
    return layer_head_att_batch
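# Self-contained sketch (shapes are assumptions based on BERT-base conventions): given a tuple
# of per-layer attention tensors of shape (batch, n_heads, seq, seq), the helper above produces
# keys such as "layer_0-head_0", each mapping to a list of (batch, seq, seq) tensors that can
# later be concatenated along dim 0.
def _example_fake_attention(n_layers=2, n_heads=3, batch=1, seq=4):
    return tuple(torch.rand(batch, n_heads, seq, seq) for _ in range(n_layers))
# get_batch_per_layer_head(_example_fake_attention()).keys()
# -> odict_keys(['layer_0-head_0', 'layer_0-head_1', ..., 'layer_1-head_2'])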
def main(args, dict_path, model_dir):

    encoder = BERT_MODEL_DIC[args.bert_model]["encoder"]
    vocab_size = BERT_MODEL_DIC[args.bert_model]["vocab_size"]
    voc_tokenizer = BERT_MODEL_DIC[args.bert_model]["vocab"]

    tokenizer = eval(BERT_MODEL_DIC[args.bert_model]["tokenizer"])
    random.seed(args.seed)

    if args.model_id_pref is None:
        run_id = str(uuid4())[:4]
    else:
        run_id = args.model_id_pref + "1"

    if args.init_args_dir is None:
        dict_path += "/" + run_id
        os.mkdir(dict_path)
    tokenizer = tokenizer.from_pretrained(voc_tokenizer,
                                          do_lower_case=args.case == "lower",
                                          shuffle_bpe_embedding=False)
    mask_id = tokenizer.encode([
        "[MASK]"
    ])[0] if args.bert_model == "bert_base_multilingual_cased" else None

    _dev_path = args.dev_path if args.dev_path is not None else args.train_path
    word_dictionary, word_norm_dictionary, char_dictionary, pos_dictionary, \
    xpos_dictionary, type_dictionary = \
        conllu_data.load_dict(dict_path=dict_path,
                              train_path=args.train_path if args.init_args_dir is None else None,
                              dev_path=args.dev_path if args.init_args_dir is None else None,
                              test_path=args.test_paths if args.init_args_dir is None else None,
                              word_embed_dict={},
                              dry_run=False,
                              expand_vocab=False,
                              word_normalization=True,
                              force_new_dic=False,
                              tasks=args.tasks,
                              pos_specific_data_set=None,
                              #pos_specific_data_set=args.train_path[1] if len(args.tasks) > 1 and len(
                              #    args.train_path) > 1 and "pos" in args.tasks else None,
                              case=args.case,
                              # if not normalize pos or parsing in tasks we don't need dictionary
                              do_not_fill_dictionaries=len(set(["normalize", "pos", "parsing"]) & set(
                                  [task for tasks in args.tasks for task in tasks])) == 0,
                              add_start_char=True if args.init_args_dir is None else None,
                              verbose=1)

    num_labels_per_task, task_to_label_dictionary = get_vocab_size_and_dictionary_per_task(
        [task for tasks in args.tasks for task in tasks],
        vocab_bert_wordpieces_len=vocab_size,
        pos_dictionary=pos_dictionary,
        type_dictionary=type_dictionary,
        task_parameters=TASKS_PARAMETER)

    model = make_bert_multitask(
        args=args,
        pretrained_model_dir=model_dir,
        init_args_dir=args.init_args_dir,
        tasks=[task for tasks in args.tasks for task in tasks],
        mask_id=mask_id,
        encoder=encoder,
        num_labels_per_task=num_labels_per_task)

    def get_n_params(model):
        pp = 0
        for p in list(model.parameters()):
            nn = 1
            for s in list(p.size()):
                nn = nn * s
            pp += nn
        return pp

    n_params_total = get_n_params(model)
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    n_params_trainable = sum([np.prod(p.size()) for p in model_parameters])
    print("MODEL : {} parameters ({} trainable)".format(n_params_total, n_params_trainable))

    data = ["I am here", "How are you"]
    model.eval()
    n_obs = args.n_sent
    max_len = args.max_seq_len
    lang_ls = args.raw_text_code
    data_all, y_all = load_lang_ls(DATA_UD_RAW, lang_ls=lang_ls)

    reg = linear_model.LogisticRegression()
    X_train = OrderedDict()
    X_test = OrderedDict()
    y_train = OrderedDict()
    y_test = OrderedDict()
    # just to get the keys
    layer_head_att = get_hidden_representation(data,
                                               model,
                                               tokenizer,
                                               max_len=max_len,
                                               output_dic=False,
                                               pad_below_max_len=True)
    layer_head_att = layer_head_att[0]
    report_ls = []
    accuracy_dic = OrderedDict()
    sampling = args.sampling
    layer_head_keys = list(layer_head_att.keys())
    for ind, _ in enumerate(layer_head_keys):
        report = OrderedDict()
        accuracy_ls = []
        # probe layers/heads from the last key to the first
        layer_head = layer_head_keys[len(layer_head_keys) - ind - 1]
        for _ in range(sampling):
            sample_ind = random.sample(population=range(len(data_all)),
                                       k=n_obs)
            sample_ind_test = random.sample(population=range(len(data_all)),
                                            k=n_obs)

            data = data_all[sample_ind]
            y = y_all[sample_ind]

            all = get_hidden_representation(data,
                                            model,
                                            tokenizer,
                                            max_len=max_len,
                                            output_dic=False,
                                            pad_below_max_len=True)

            layer_head_att = all[0]

            #pdb.set_trace()
            def reshape_x(z):
                return np.array(z.view(z.size(0) * z.size(1), -1))

            def reshape_y(z, n_seq):
                """Repeat each label n_seq times (one label per wordpiece position)."""
                new_z = []
                for _z in z:
                    new_z.extend([_z for _ in range(n_seq)])
                return np.array(new_z)

            #X_train[layer_head] = np.array(layer_head_att[layer_head].view(layer_head_att[layer_head].size(0), -1).transpose(1,0))
            X_train[layer_head] = reshape_x(layer_head_att[layer_head])
            y_train[layer_head] = reshape_y(y, max_len)
            #db.set_trace()
            #y_train[layer_head] = y

            reg.fit(X=X_train[layer_head], y=y_train[layer_head])

            # test
            data_test = data_all[sample_ind_test]
            layer_head_att_test = get_hidden_representation(
                data_test,
                model,
                tokenizer,
                max_len=max_len,
                output_dic=False,
                pad_below_max_len=True)
            X_test[layer_head] = reshape_x(layer_head_att_test[layer_head])
            y_test[layer_head] = reshape_y(y_all[sample_ind_test], max_len)

            y_pred = reg.predict(X_test[layer_head])

            accuracy = np.sum(
                (y_test[layer_head] == y_pred)) / len(y_test[layer_head])
            accuracy_ls.append(accuracy)

        accuracy_dic[layer_head] = np.mean(accuracy_ls)
        layer = layer_head.split("-")[0]
        if layer not in accuracy_dic:
            accuracy_dic[layer] = []
        accuracy_dic[layer].append(np.mean(accuracy_ls))

        print(
            f"Regression {layer_head} Accuracy test {np.mean(accuracy_ls)} on {n_obs * max_len}"
            f" word sample from {len(lang_ls)} languages task {args.tasks} args {'/'.join(args.init_args_dir.split('/')[-2:]) if args.init_args_dir is not None else None} "
            f"bert {args.bert_model} random init {args.random_init} std {np.std(accuracy_ls)} sampling {len(accuracy_ls)}=={sampling}"
        )

        #report["model_type"] = args.bert_model if args.init_args_dir is None else args.tasks[0][0]+"-tune"
        #report["accuracy"] = np.mean(accuracy_ls)
        #report["sampling"] = len(accuracy_ls)
        #report["std"] = np.std(accuracy_ls)
        #report["n_sent"] = n_obs
        #report["n_obs"] = n_obs*max_len

        report = report_template(
            metric_val="accuracy",
            subsample=",".join(lang_ls),
            info_score_val=sampling,
            score_val=np.mean(accuracy_ls),
            n_sents=n_obs,
            avg_per_sent=np.std(accuracy_ls),
            n_tokens_score=n_obs * max_len,
            model_full_name_val=run_id,
            task="attention_analysis",
            evaluation_script_val="exact_match",
            model_args_dir=args.init_args_dir
            if args.init_args_dir is not None else args.random_init,
            token_type="word",
            report_path_val=None,
            data_val=layer_head)
        report_ls.append(report)

        # break

    for key in accuracy_dic:
        print(
            f"Summary {key} {np.mean(accuracy_dic[key])} model word sample from {len(lang_ls)} languages task {args.tasks} args {'/'.join(args.init_args_dir.split('/')[-2:]) if args.init_args_dir is not None else None} "
            f"bert {args.bert_model} random init {args.random_init} std {np.std(accuracy_ls)} sampling {len(accuracy_ls)}=={sampling}"
        )

    if args.report_dir is None:
        report_dir = PROJECT_PATH + f"/../../analysis/attention_analysis/report/{run_id}-report"
        os.mkdir(report_dir)
    else:
        report_dir = args.report_dir
    assert os.path.isdir(report_dir)
    with open(report_dir + "/report.json", "w") as f:
        json.dump(report_ls, f)
    overall_report = args.overall_report_dir + "/" + args.overall_label + "-grid-report.json"
    with open(overall_report, "r") as g:
        report_all = json.load(g)
        report_all.extend(report_ls)
    with open(overall_report, "w") as file:
        json.dump(report_all, file)

    print("{} {} ".format(REPORT_FLAG_DIR_STR, overall_report))
def write_conll_multitask(format, dir_pred, dir_original, src_text_ls,
                          pred_per_task, tasks, task_parameters, cp_paste=False, gold=False,
                          all_indexes=None, sep_token=None, cls_token=None,
                          ind_batch=0, new_file=False, cut_sent=False, verbose=0):

    assert format in ["conll"]
    max_len_word = None
    writing_top = 1
    # assert that each task predicts the same number of samples per batch
    pred_task_len_former = -1
    task_former = ""

    # assertion on number of samples predicted
    for task_label in pred_per_task:

        pred_task_len = len(pred_per_task[task_label]) if gold else len(pred_per_task[task_label][writing_top-1])
        _task = re.match("(.*)-(.*)", task_label)
        if _task is not None:  # , "ERROR writer could not match {}".format(task_label)
            task = _task.group(1)
        else:
            task = task_label
        if pred_task_len_former > 0:
            assert pred_task_len == pred_task_len_former, \
                "ERROR {} and {} task ".format(task_former, task_label)
            if not gold:
                assert pred_task_len == len(src_text_ls[task_parameters[task]["input"]]), "ERROR  src len {} and pred len {} ".format(len(src_text_ls[task_parameters[task]["input"]]),pred_task_len)
            # we check also other input length
            if src_text_ls.get("input_masked") is not None:
                assert pred_task_len == len(src_text_ls["input_masked"])
            if src_text_ls.get("wordpieces_inputs_words") is not None:
                assert pred_task_len == len(src_text_ls["wordpieces_inputs_words"]), "ERROR mismatch source " \
                                                                            "wordpieces_inputs_words {}  " \
                                                                            "and prediction {} ".format(src_text_ls, pred_per_task[task_label])
            if src_text_ls.get("wordpieces_inputs_raw_tokens") is not None:
                assert pred_task_len == len(src_text_ls["wordpieces_inputs_raw_tokens"]), \
                                    "ERROR mismatch source wordpieces_inputs_" \
                                    "raw_tokens {} and prediction {} ".format(src_text_ls, pred_per_task[task_label])
            assert pred_task_len == all_indexes.shape[0], "ERROR mismatch index {}  and all_indexes {} : pred {}".format(pred_task_len, all_indexes.shape[0], pred_per_task[task_label])
        pred_task_len_former = pred_task_len

        task_former = task_label
        if format == "conll":
            mode_write = "w" if new_file else "a"
        if new_file:
            printing("CREATING NEW FILE (io_/dat/normalized_writer) : {} ", var=[dir_pred], verbose=verbose,
                     verbose_level=1)

    pos_label = "pos-pos" if not gold else "pos"
    types_label = "parsing-types" if not gold else "types"
    heads_label = "parsing-heads" if not gold else "heads"
    n_masks_mwe_label = "n_masks_mwe-n_masks_mwe" if not gold else "n_masks_mwe"
    mwe_detection_label = "mwe_detection-mwe_detection" if not gold else "mwe_detection"

    with open(dir_pred, mode_write) as norm_file:
        with open(dir_original, mode_write) as original:
            len_original = 0
            for ind_sent in range(all_indexes.shape[0]):
                pred_sent = OrderedDict()
                # NB : length assertion for each input-output (correcting if possible)
                # TODO standardize !!  INCONSISTENCIES WHEN GOLD TRUE AND GOLD FALSE, IF GOLD : pred_per_task is indexed by labels (no 1-1 relation to task and src !)
                for task_label_or_gold_label in pred_per_task:
                    #task, _, label_processed = get_task_name_based_on_logit_label(task_label, label_processed)
                    if gold:
                        pred_sent[task_label_or_gold_label] = pred_per_task[task_label_or_gold_label][ind_sent]
                    else:
                        pred_sent[task_label_or_gold_label] = pred_per_task[task_label_or_gold_label][writing_top-1][ind_sent]
                    try:
                        # TODO : standardize (the first if is needed because we handle at the same time gold data indexed by label and predictions indexed by task+label)
                        if gold:
                            try:
                                src = src_text_ls[LABEL_PARAMETER[task_label_or_gold_label]["default_input"]][ind_sent]
                            except Exception as e:
                                src = src_text_ls["input_masked"][ind_sent]
                        else:
                            _task = re.match("(.*)-(.*)", task_label_or_gold_label)
                            assert _task is not None, "ERROR writer could not match {}".format(task_label_or_gold_label)
                            _label = _task.group(2)
                            _task = _task.group(1)
                            src = src_text_ls[TASKS_PARAMETER[_task]["input"]][ind_sent]

                        assert len(src) == len(pred_sent[task_label_or_gold_label]),"WARNING : (writer) task {} original_sent len {} {} \n  predicted sent len {} {}".format(task_label_or_gold_label, len(src), src,len(pred_sent[task_label_or_gold_label]), pred_sent[task_label_or_gold_label])
                    except AssertionError as e:
                        print(e)
                        if len(src) > len(pred_sent[task_label_or_gold_label]):
                            pred_sent[task_label_or_gold_label].extend(["UNK" for _ in range(len(src)-len(pred_sent[task_label_or_gold_label]))])
                            print("WARNING (writer) : original longer than prediction : appending UNK tokens for writing")
                        else:
                            print("WARNING (writer) : original shorter than prediction")

                norm_file.write("#\n")
                original.write("#\n")
                norm_file.write("#sent_id = {} \n".format(ind_sent+ind_batch+1))
                original.write("#sent_id = {} \n".format(ind_sent+ind_batch+1))
                ind_adjust = 0

                #for ind, original_token in enumerate(original_sent):
                last_mwe_index = -1
                adjust_mwe = 0
                for ind in all_indexes[ind_sent, :]:
                    # WE REMOVE SPECIAL TOKENS ONLY IF THEY APPEAR AT THE BEGINNING OR AT THE END
                    # based on the source token (it tells us when we stop) (we never want to use gold information)
                    if "-" in ind and ind != "-1":
                        matching_mwe_ind = re.match("([0-9]+)-([0-9]+)", str(ind))
                        assert matching_mwe_ind is not None, "ERROR ind is {} : could not found mwe index".format(ind)
                        last_mwe_index = int(matching_mwe_ind.group(2))
                        ind_mwe = int(matching_mwe_ind.group(1))

                        original_token = src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind_mwe] if mwe_detection_label in pred_per_task or "wordpieces_inputs_words" in pred_per_task or n_masks_mwe_label in pred_per_task else "NOT_NEEDED"
                        adjust_mwe += (last_mwe_index-ind_mwe)
                        #assert ind_adjust == 0, "ERROR not supported"

                        mwe_meta = "Norm={}|mwe_detection={}|n_masks_mwe={}".format("_", pred_sent[mwe_detection_label][ind_mwe] if mwe_detection_label in pred_per_task else "_",
                                                                                    pred_sent[n_masks_mwe_label][ind_mwe] if n_masks_mwe_label in pred_per_task else "_")

                        norm_file.write("{index}\t{original}\t_\t{pos}\t_\t_\t{dep}\t_\t{types}\t{norm}\n".format(index=ind, original=original_token, pos="_", types="_", dep="_", norm=mwe_meta))
                        original.write("{}\t{}\t_\t_\t_\t_\t_\t_\t{}\t_\n".format(ind, original_token, "_"))
                        continue
                    else:
                        ind = int(ind)
                        try:
                            if "normalize" in [task for _tasks in tasks for task in _tasks]:

                                original_token = src_text_ls["wordpiece_words_src_aligned_with_norm"][ind_sent][ind]
                                original_pretokenized_field = "wordpiece_words_src_aligned_with_norm"
                            else:
                                original_token = src_text_ls["wordpieces_inputs_words"][ind_sent][ind]
                                original_pretokenized_field = "wordpieces_inputs_words"
                        except Exception as e:
                            original_token = src_text_ls["input_masked"][ind_sent][ind]
                            original_pretokenized_field = "input_masked"
                        # asserting that we have everything together on the source side
                        if ind > last_mwe_index:
                            if src_text_ls.get("wordpieces_inputs_raw_tokens") is not None:
                                try:
                                    assert src_text_ls[original_pretokenized_field][ind_sent][ind] == src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind-adjust_mwe], \
                                    "ERROR sequence {} on non-mwe tokens : raw and tokenized " \
                                    "should be same but are raw {} tokenized {}".format(original_pretokenized_field, src_text_ls["wordpieces_inputs_raw_tokens"][ind_sent][ind],
                                                                                        src_text_ls[original_pretokenized_field][ind_sent][ind+adjust_mwe])
                                except AssertionError as e:
                                    print("WARNING sanity checking input failed (normalized_writer) (might be due to dropout) {}".format(e))
                    max_len_word = max(len(original_token), max_len_word if max_len_word is not None else 0)
                    #if original_token in SPECIAL_TOKEN_LS and (ind+1 == len(original_sent) or ind == 0):
                    if (original_token in SPECIAL_TOKEN_LS or original_token in [cls_token, sep_token]):
                        # ind 0 is skipped because it corresponds to CLS
                        ind_adjust = 1
                        continue

                    pos = pred_sent[pos_label][ind] if pos_label in pred_per_task else "_"
                    types = pred_sent[types_label][ind] if types_label in pred_per_task else "_"
                    heads = pred_sent[heads_label][ind] if heads_label in pred_per_task else ind - 1

                    tenth_col = "Norm={}|mwe_detection={}|n_masks_mwe={}".format(pred_sent["normalize"][ind] if "normalize" in pred_per_task else "_",
                                                                                 pred_sent[mwe_detection_label][ind-adjust_mwe] if mwe_detection_label in pred_per_task else "_",
                                                                                 pred_sent[n_masks_mwe_label][ind-adjust_mwe] if n_masks_mwe_label in pred_per_task else "_")

                    norm_file.write("{index}\t{original}\t_\t{pos}\t_\t_\t{dep}\t_\t{types}\t{norm}\n".format(index=ind, original=original_token, pos=pos, types=types, dep=heads, norm=tenth_col))
                    original.write("{}\t{}\t_\t_\t_\t_\t_\t_\t{}\t_\n".format(ind, original_token, ind-1))
                    if cut_sent:
                        if ind > 50:
                            print("CUTTING SENT index {}>50 ".format(ind))
                            break
                norm_file.write("\n")
                original.write("\n")
        printing("WRITING predicted batch of {} original and {} normalized", var=[dir_original, dir_pred], verbose=verbose, verbose_level=2)
    assert max_len_word is not None, "ERROR : something went wrong in the writer"
    return max_len_word
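# Illustration (the token and tags below are assumptions): each token is written as a
# 10-column, tab-separated CoNLL-U style line, with "_" for columns that are not predicted.
_EXAMPLE_CONLL_LINE = "{index}\t{original}\t_\t{pos}\t_\t_\t{dep}\t_\t{types}\t{norm}\n".format(
    index=1, original="cats", pos="NOUN", dep=0, types="root",
    norm="Norm=_|mwe_detection=_|n_masks_mwe=_")
# -> "1\tcats\t_\tNOUN\t_\t_\t0\t_\troot\tNorm=_|mwe_detection=_|n_masks_mwe=_\n"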
def parse_argument_dictionary(argument_as_string,
                              logits_label=None,
                              hyperparameter="multi_task_loss_ponderation",
                              verbose=1):
    """
    All arguments that are meant to be defined as dictionaries are passed to the Argument Parser as string:
    following  template :  i.e 'key1=value1,key2=value,'  (matched with "{}=([^=]*),".format(sub) )
    ALl the dictionary arguments are listed in DIC_ARGS
    """
    assert hyperparameter in DIC_ARGS, "ERROR : hyperparameter {} not supported (not in DIC_ARGS)".format(hyperparameter)
    if argument_as_string in MULTI_TASK_LOSS_PONDERATION_PREDEFINED_MODE:
        return argument_as_string
    else:
        dic = OrderedDict()
        if hyperparameter == "multi_task_loss_ponderation":
            assert logits_label is not None
            for task in logits_label:
                # useless (I think)
                if task == "parsing":
                    for sub in ["parsing-heads", "parsing-types"]:
                        pattern = "{}=([^=]*),".format(sub)
                        match = re.search(pattern, argument_as_string)
                        assert match is not None, "ERROR : pattern {} not found for task {} in argument_as_string {}  ".format(
                            pattern, task, argument_as_string)
                        dic[sub] = eval(match.group(1))
                # useless (I think)
                elif task == "normalize":
                    for sub in ["normalize", "append_masks"]:
                        pattern = "{}=([^=]*),".format(sub)
                        match = re.search(pattern, argument_as_string)
                        if sub == "normalize":
                            assert match is not None, "ERROR : pattern {} not found for task {} " \
                                                      "in argument_as_string {}  ".format( pattern, task, argument_as_string)
                            dic[sub] = eval(match.group(1))
                        else:
                            if match is not None:
                                dic[sub] = eval(match.group(1))
                # all cases should be in this one
                if task != "all" and task != "parsing":

                    pattern = "{}=([^=]*),".format(task)
                    match = re.search(pattern, argument_as_string)
                    assert match is not None, "ERROR : pattern {} not found for task {} in argument_as_string {}  ".format(
                        pattern, task, argument_as_string)
                    dic[task] = eval(match.group(1))

            printing("SANITY CHECK : multi_task_loss_ponderation {} ",
                     var=[argument_as_string],
                     verbose_level=3,
                     verbose=verbose)

        elif hyperparameter in [
                "lr", "norm_order_per_layer", "ponderation_per_layer"
        ]:
            # to handle several optimizers
            try:
                assert isinstance(eval(argument_as_string), float)
                return eval(argument_as_string)
            except Exception as e:
                print("Exception", hyperparameter, e)
                argument_as_string = argument_as_string.split(",")
                for arg in argument_as_string[:-1]:
                    # DIFFERENCE WITH ABOVE IS THE COMMA
                    pattern = "([^=]*)=([^=]*)"
                    match = re.search(pattern, arg)
                    assert match is not None, "ERROR : pattern {} not found in argument_as_string {}  ".format(
                        pattern, arg)
                    if hyperparameter in ["lr"]:
                        dic[match.group(1)] = float(match.group(2))
                    elif hyperparameter in ["norm_order_per_layer"]:
                        if match.group(2) != "fro":
                            dic[match.group(1)] = float(match.group(2))
                        else:
                            dic[match.group(1)] = match.group(2)
                    elif hyperparameter in ["ponderation_per_layer"]:
                        dic[match.group(1)] = float(match.group(2))

        return dic
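# A minimal, self-contained sketch (assumption, mirroring the "lr"-style branch above) of how
# the 'key1=value1,key2=value2,' template is parsed ; the real parse_argument_dictionary
# additionally validates keys against DIC_ARGS and the task-specific labels.
def _example_parse_key_value_string(argument_as_string):
    parsed = OrderedDict()
    for chunk in argument_as_string.split(",")[:-1]:
        match = re.search("([^=]*)=([^=]*)", chunk)
        assert match is not None, "could not parse chunk {}".format(chunk)
        parsed[match.group(1)] = float(match.group(2))
    return parsed
# e.g. _example_parse_key_value_string("pos=1,parsing-heads=0.5,")
# -> OrderedDict([('pos', 1.0), ('parsing-heads', 0.5)])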
def get_hidden_representation(data,
                              model,
                              tokenizer,
                              special_start="[CLS]",
                              special_end="[SEP]",
                              pad="[PAD]",
                              max_len=100,
                              pad_below_max_len=False,
                              output_dic=True):
    """
    get hidden representation (ie contetualized vector at the word level : add it as list or padded tensor : output[attention|layer|layer_head]["layer_x"] list or tensor)
    :param data: list of raw text
    :param pad: will add padding below max_len
    :return: output a dictionary (if output_dic) or a tensor (if not output_dic) : of contextualized representation at the word level per layer/layer_head
    """
    model.eval()
    special_start = tokenizer.bos_token
    special_end = tokenizer.eos_token
    if special_start is None or special_end is None:
        special_start = "[CLS]"
        special_end = "[SEP]"

    layer_head_att_tensor_dic = OrderedDict()
    layer_hidden_state_tensor_dic = OrderedDict()
    layer_head_hidden_state_tensor_dic = OrderedDict()
    layer_head_att_batch_dic = OrderedDict()
    layer_head_hidden_state_dic = OrderedDict()
    layer_hidden_state_dic = OrderedDict()
    print(
        f"Getting hidden representation : adding special char start:{special_start} end:{special_end}"
    )
    for seq in data:
        seq = special_start + " " + seq + " " + special_end
        tokenized = tokenizer.encode(seq)
        if len(tokenized) >= max_len:
            tokenized = tokenized[:max_len - 1]
            tokenized += tokenizer.encode(special_end)
        mask = [1 for _ in range(len(tokenized))]
        real_len = len(tokenized)
        if pad_below_max_len:
            if len(tokenized) < max_len:
                for _ in range(max_len - len(tokenized)):
                    tokenized += tokenizer.encode(pad)
                    mask.append(0)
            assert len(tokenized) == max_len
        assert len(tokenized) <= max_len + 2

        encoded = torch.tensor(tokenized).unsqueeze(0)
        inputs = OrderedDict([("wordpieces_inputs_words", encoded)])
        attention_mask = OrderedDict([("wordpieces_inputs_words",
                                       torch.tensor(mask).unsqueeze(0))])
        assert real_len
        if torch.cuda.is_available():
            inputs["wordpieces_inputs_words"] = inputs[
                "wordpieces_inputs_words"].cuda()

            attention_mask["wordpieces_inputs_words"] = attention_mask[
                "wordpieces_inputs_words"].cuda()
        model_output = model(input_ids_dict=inputs,
                             attention_mask=attention_mask)
        #pdb.set_trace()
        #logits = model_output[0]

        # getting the output index based on what we are asking the model
        hidden_state_per_layer_index = 2 if model.config.output_hidden_states else False
        attention_index_original_index = 3 - int(
            not hidden_state_per_layer_index
        ) if model.config.output_attentions else False
        hidden_state_per_layer_per_head_index = False  #4-int(not attention_index_original_index) if model.config.output_hidden_states_per_head else False
        # getting the output
        hidden_state_per_layer = model_output[
            hidden_state_per_layer_index] if hidden_state_per_layer_index else None
        attention = model_output[
            attention_index_original_index] if attention_index_original_index else None
        hidden_state_per_layer_per_head = model_output[
            hidden_state_per_layer_per_head_index] if hidden_state_per_layer_per_head_index else None

        # checking that we got the correct output
        try:
            if attention is not None:
                assert len(attention) == 12, "ERROR attention"
                assert attention[0].size()[-1] == attention[0].size(
                )[-2], "ERROR attention"
            if hidden_state_per_layer is not None:
                assert len(
                    hidden_state_per_layer) == 12 + 1, "ERROR hidden state"
                assert hidden_state_per_layer[0].size(
                )[-1] == 768, "ERROR hidden state"
            if hidden_state_per_layer_per_head is not None:
                assert len(hidden_state_per_layer_per_head
                           ) == 12, "ERROR hidden state per layer"
                assert hidden_state_per_layer_per_head[0].size(
                )[1] == 12 and hidden_state_per_layer_per_head[0].size(
                )[-1] == 64, "ERROR hidden state per layer"
        except Exception as e:
            raise (Exception(e))

        # concat as a batch per layer/layer_head
        if hidden_state_per_layer is not None:
            layer_hidden_state_dic = get_batch_per_layer_head(
                hidden_state_per_layer, layer_hidden_state_dic, head=False)
        if attention is not None:
            layer_head_att_batch_dic = get_batch_per_layer_head(
                attention, layer_head_att_batch_dic)
        if hidden_state_per_layer_per_head is not None:
            layer_head_hidden_state_dic = get_batch_per_layer_head(
                hidden_state_per_layer_per_head, layer_head_hidden_state_dic)

    output = ()
    if output_dic:
        if len(layer_hidden_state_dic) > 0:
            output = output + (layer_hidden_state_dic, )
        if len(layer_head_att_batch_dic) > 0:
            output = output + (layer_head_att_batch_dic, )
        if len(layer_head_hidden_state_dic) > 0:
            output = output + (layer_head_hidden_state_dic, )
    else:
        # concatenate into tensors
        # should have padding on !
        assert pad_below_max_len
        if len(layer_hidden_state_dic) > 0:
            for key in layer_hidden_state_dic:
                layer_hidden_state_tensor_dic[key] = torch.cat(
                    layer_hidden_state_dic[key], 0)
            output = output + (layer_hidden_state_tensor_dic, )
        if len(layer_head_att_batch_dic) > 0:
            for key in layer_head_att_batch_dic:
                layer_head_att_tensor_dic[key] = torch.cat(
                    layer_head_att_batch_dic[key], 0)
            output = output + (layer_head_att_tensor_dic, )
        if len(layer_head_hidden_state_dic) > 0:
            for key in layer_head_hidden_state_dic:
                layer_head_hidden_state_tensor_dic[key] = torch.cat(
                    layer_head_hidden_state_dic[key], 0)
            output = output + (layer_head_hidden_state_tensor_dic, )

    return output
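# Usage note (assumption: `model` is a multitask BERT wrapper returning attentions and/or
# hidden states, and `tokenizer` is the matching wordpiece tokenizer, as in main() above):
#   outputs = get_hidden_representation(["I am here", "How are you"], model, tokenizer,
#                                       max_len=50, output_dic=False, pad_below_max_len=True)
#   # outputs is a tuple of OrderedDicts keyed by "layer_i" and/or "layer_i-head_j",
#   # each mapping to a tensor stacked over the input sentences.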
def run(args,
        n_observation_max_per_epoch_train,
        vocab_size,
        model_dir,
        voc_tokenizer,
        auxilliary_task_norm_not_norm,
        null_token_index,
        null_str,
        tokenizer,
        n_observation_max_per_epoch_dev_test=None,
        run_mode="train",
        dict_path=None,
        end_predictions=None,
        report=True,
        model_suffix="",
        description="",
        saving_every_epoch=10,
        model_location=None,
        model_id=None,
        report_full_path_shared=None,
        skip_1_t_n=False,
        heuristic_test_ls=None,
        remove_mask_str_prediction=False,
        inverse_writing=False,
        extra_label_for_prediction="",
        random_iterator_train=True,
        bucket_test=False,
        must_get_norm_test=True,
        early_stoppin_metric=None,
        subsample_early_stoping_metric_val=None,
        compute_intersection_score_test=True,
        threshold_edit=3,
        name_with_epoch=False,
        max_token_per_batch=200,
        encoder=None,
        debug=False,
        verbose=1):
    """
    Wrapper for training/prediction/evaluation

    2 modes : train (will train using train and dev iterators with test at the end on test_path)
              test : only test at the end : requires all directories to be created
    :return:
    """
    assert run_mode in ["train", "test"
                        ], "ERROR run mode {} corrupted ".format(run_mode)
    input_level_ls = ["wordpiece"]
    assert early_stoppin_metric is not None and subsample_early_stoping_metric_val is not None, "ERROR : early_stoppin_metric and subsample_early_stoping_metric_val should both be defined"
    if n_observation_max_per_epoch_dev_test is None:
        n_observation_max_per_epoch_dev_test = n_observation_max_per_epoch_train
    printing("MODEL : RUNNING IN {} mode",
             var=[run_mode],
             verbose=verbose,
             verbose_level=1)
    printing(
        "WARNING : casing was set to {} (this should be consistent at train and test)",
        var=[args.case],
        verbose=verbose,
        verbose_level=2)

    if len(args.tasks) == 1:
        printing("INFO : MODEL : 1 set of simultaneous tasks {}".format(
            args.tasks),
                 verbose=verbose,
                 verbose_level=1)

    if run_mode == "test":
        assert args.test_paths is not None and isinstance(
            args.test_paths, list)
    if run_mode == "train":
        printing("CHECKPOINTING info : "
                 "saving model every {}",
                 var=saving_every_epoch,
                 verbose=verbose,
                 verbose_level=1)

    use_gpu = use_gpu_(use_gpu=None, verbose=verbose)

    def get_commit_id():
        repo = git.Repo(os.path.dirname(os.path.realpath(__file__)),
                        search_parent_directories=True)
        git_commit_id = str(repo.head.commit)  # object.hexsha
        return git_commit_id

    if verbose > 1:
        print(f"GIT ID : {get_commit_id()}")

    train_data_label = get_dataset_label(args.train_path, default="train")

    iter_train = 0
    iter_dev = 0
    row = None
    writer = None

    printout_allocated_gpu_memory(verbose, "{} starting all".format(model_id))

    if run_mode == "train":
        if os.path.isdir(args.train_path[0]) and len(args.train_path) == 1:
            data_sharded = args.train_path[0]
            printing(
                "INFO args.train_path is directory so not rebuilding shards",
                verbose=verbose,
                verbose_level=1)
        elif os.path.isdir(args.train_path[0]):
            raise (Exception(
                "{} is a directory but {} train paths were given : not supported".
                format(args.train_path[0], len(args.train_path))))
        else:
            data_sharded = None
        assert model_location is None and model_id is None, "ERROR : model_location and model_id should be None when training a new model"

        model_id, model_location, dict_path, tensorboard_log, end_predictions, data_sharded \
            = setup_repoting_location(model_suffix=model_suffix, data_sharded=data_sharded,
                                      root_dir_checkpoints=CHECKPOINT_BERT_DIR,
                                      shared_id=args.overall_label, verbose=verbose)
        hyperparameters = get_hyperparameters_dict(
            args,
            args.case,
            random_iterator_train,
            seed=args.seed,
            verbose=verbose,
            dict_path=dict_path,
            model_id=model_id,
            model_location=model_location)
        args_dir = write_args(model_location,
                              model_id=model_id,
                              hyperparameters=hyperparameters,
                              verbose=verbose)

        if report:
            if report_full_path_shared is not None:
                tensorboard_log = os.path.join(report_full_path_shared,
                                               "tensorboard")
            printing("tensorboard --logdir={} --host=localhost --port=1234 ",
                     var=[tensorboard_log],
                     verbose_level=1,
                     verbose=verbose)
            writer = SummaryWriter(log_dir=tensorboard_log)
            if writer is not None:
                writer.add_text("INFO-ARGUMENT-MODEL-{}".format(model_id),
                                str(hyperparameters), 0)
    else:
        args_checkpoint = json.load(open(args.init_args_dir, "r"))
        dict_path = args_checkpoint["hyperparameters"]["dict_path"]
        assert dict_path is not None and os.path.isdir(
            dict_path), "ERROR {} ".format(dict_path)
        end_predictions = args.end_predictions
        assert end_predictions is not None and os.path.isdir(
            end_predictions), "ERROR end_predictions"
        model_location = args_checkpoint["hyperparameters"]["model_location"]
        model_id = args_checkpoint["hyperparameters"]["model_id"]
        assert model_location is not None and model_id is not None, "ERROR model_location model_id "
        args_dir = os.path.join(model_location,
                                "{}-args.json".format(model_id))

        printing(
            "CHECKPOINTING : starting writing log \ntensorboard --logdir={} --host=localhost --port=1234 ",
            var=[os.path.join(model_id, "tensorboard")],
            verbose_level=1,
            verbose=verbose)

    # build or make dictionaries
    _dev_path = args.dev_path if args.dev_path is not None else args.train_path
    word_dictionary, word_norm_dictionary, char_dictionary, pos_dictionary, \
    xpos_dictionary, type_dictionary = \
        conllu_data.load_dict(dict_path=dict_path,
                              train_path=args.train_path if run_mode == "train" else None,
                              dev_path=args.dev_path if run_mode == "train" else None,
                              test_path=None,
                              word_embed_dict={},
                              dry_run=False,
                              expand_vocab=False,
                              word_normalization=True,
                              force_new_dic=True if run_mode == "train" else False,
                              tasks=args.tasks,
                              pos_specific_data_set=args.train_path[1] if len(args.tasks) > 1 and len(args.train_path)>1 and "pos" in args.tasks else None,
                              case=args.case,
                              # if not normalize pos or parsing in tasks we don't need dictionary
                              do_not_fill_dictionaries=len(set(["normalize", "pos", "parsing"])&set([task for tasks in args.tasks for task in tasks])) == 0,
                              add_start_char=1 if run_mode == "train" else None,
                              verbose=verbose)
    # we flatten the tasks
    printing("DICTIONARY CREATED/LOADED", verbose=verbose, verbose_level=1)
    num_labels_per_task, task_to_label_dictionary = get_vocab_size_and_dictionary_per_task(
        [task for tasks in args.tasks for task in tasks],
        vocab_bert_wordpieces_len=vocab_size,
        pos_dictionary=pos_dictionary,
        type_dictionary=type_dictionary,
        task_parameters=TASKS_PARAMETER)
    voc_pos_size = num_labels_per_task["pos"] if "pos" in args.tasks else None
    if voc_pos_size is not None:
        printing("MODEL : voc_pos_size defined as {}",
                 var=voc_pos_size,
                 verbose_level=1,
                 verbose=verbose)
    printing("MODEL init...", verbose=verbose, verbose_level=1)
    if verbose > 1:
        print("DEBUG : TOKENIZER :voc_tokenizer from_pretrained",
              voc_tokenizer)
    #pdb.set_trace()
    #voc_tokenizer = "bert-base-multilingual-cased"
    tokenizer = tokenizer.from_pretrained(
        voc_tokenizer,
        do_lower_case=args.case == "lower",
        shuffle_bpe_embedding=args.shuffle_bpe_embedding)
    mask_id = tokenizer.convert_tokens_to_ids(
        tokenizer.mask_token)  #convert_tokens_to_ids([MASK_BERT])[0]
    printout_allocated_gpu_memory(verbose,
                                  "{} loading model ".format(model_id))
    model = get_model_multi_task_bert(args=args,
                                      model_dir=model_dir,
                                      encoder=encoder,
                                      num_labels_per_task=num_labels_per_task,
                                      mask_id=mask_id)

    def prune_heads(prune_heads):
        if prune_heads is not None:
            prune_heads_ls = prune_heads.split(",")[:-1]
            assert len(prune_heads_ls) > 0
            for layer in prune_heads_ls:
                parsed_layer_to_prune = layer.split("-")
                assert parsed_layer_to_prune[0] == "prune_heads"
                assert parsed_layer_to_prune[1] == "layer"
                assert parsed_layer_to_prune[3] == "heads"
                heads = parsed_layer_to_prune[4]
                head_index_ls = heads.split("_")
                heads_ls = [int(index) for index in head_index_ls]
                print(
                    f"MODEL : pruning layer {parsed_layer_to_prune[2]} heads {heads_ls}"
                )
                model.encoder.encoder.layer[int(
                    parsed_layer_to_prune[2])].attention.prune_heads(heads_ls)

    if args.prune_heads is not None and args.prune_heads != "None":
        print(f"INFO : args.prune_heads {args.prune_heads}")
        prune_heads(args.prune_heads)
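    # Expected format of args.prune_heads (inferred from the parsing above): a comma-terminated
    # list of "prune_heads-layer-<layer_index>-heads-<h1>_<h2>_..." items, e.g.
    # "prune_heads-layer-0-heads-1_2_3,prune_heads-layer-5-heads-0,"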

    if use_gpu:
        model.to("cuda")
        printing("MODEL TO CUDA", verbose=verbose, verbose_level=1)
    printing("MODEL model.config {} ",
             var=[model.config],
             verbose=verbose,
             verbose_level=1)
    printout_allocated_gpu_memory(verbose, "{} model loaded".format(model_id))
    model_origin = OrderedDict()
    pruning_mask = OrderedDict()
    printout_allocated_gpu_memory(verbose, "{} model cuda".format(model_id))
    for name, param in model.named_parameters():
        model_origin[name] = param.detach().clone()
        printout_allocated_gpu_memory(verbose, "{} param cloned ".format(name))
        if args.penalization_mode == "pruning":
            abs = torch.abs(param.detach().flatten())
            median_value = torch.median(abs)
            pruning_mask[name] = (abs > median_value).float()
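            # (assumption on intent) the mask marks the parameters whose magnitude is above the
            # per-tensor median ; in get_penalization's "pruning" mode these are pulled back
            # toward their original values more strongly than the below-median half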
        printout_allocated_gpu_memory(
            verbose, "{} pruning mask loaded".format(model_id))

    printout_allocated_gpu_memory(verbose, "{} model clone".format(model_id))

    inv_word_dic = word_dictionary.instance2index
    # load , mask, bucket and index data

    assert tokenizer is not None, "ERROR : tokenizer is None , voc_tokenizer failed to be loaded {}".format(
        voc_tokenizer)
    if run_mode == "train":
        time_load_readers_train_start = time.time()
        if not args.memory_efficient_iterator:

            data_sharded, n_shards, n_sent_dataset_total_train = None, None, None
            args_load_batcher_shard_data = None
            printing("INFO : starting loading readers",
                     verbose=verbose,
                     verbose_level=1)
            readers_train = readers_load(
                datasets=args.train_path,
                tasks=args.tasks,
                word_dictionary=word_dictionary,
                bert_tokenizer=tokenizer,
                word_dictionary_norm=word_norm_dictionary,
                char_dictionary=char_dictionary,
                pos_dictionary=pos_dictionary,
                xpos_dictionary=xpos_dictionary,
                type_dictionary=type_dictionary,
                word_decoder=True,
                run_mode=run_mode,
                add_start_char=1,
                add_end_char=1,
                symbolic_end=1,
                symbolic_root=1,
                bucket=True,
                must_get_norm=True,
                input_level_ls=input_level_ls,
                verbose=verbose)
            n_sent_dataset_total_train = readers_train[list(
                readers_train.keys())[0]][3]
            printing("INFO : done with sharding",
                     verbose=verbose,
                     verbose_level=1)
        else:
            printing("INFO : building/loading shards ",
                     verbose=verbose,
                     verbose_level=1)
            data_sharded, n_shards, n_sent_dataset_total_train = build_shard(
                data_sharded,
                args.train_path,
                n_sent_max_per_file=N_SENT_MAX_CONLL_PER_SHARD,
                verbose=verbose)

        time_load_readers_dev_start = time.time()
        time_load_readers_train = time.time() - time_load_readers_train_start
        readers_dev_ls = []
        dev_data_label_ls = []
        printing("INFO : g readers for dev", verbose=verbose, verbose_level=1)
        printout_allocated_gpu_memory(
            verbose, "{} reader train loaded".format(model_id))
        for dev_path in args.dev_path:
            dev_data_label = get_dataset_label(dev_path, default="dev")
            dev_data_label_ls.append(dev_data_label)
            readers_dev = readers_load(
                datasets=dev_path,
                tasks=args.tasks,
                word_dictionary=word_dictionary,
                word_dictionary_norm=word_norm_dictionary,
                char_dictionary=char_dictionary,
                pos_dictionary=pos_dictionary,
                xpos_dictionary=xpos_dictionary,
                bert_tokenizer=tokenizer,
                type_dictionary=type_dictionary,
                word_decoder=True,
                run_mode=run_mode,
                add_start_char=1,
                add_end_char=1,
                symbolic_end=1,
                symbolic_root=1,
                bucket=False,
                must_get_norm=True,
                input_level_ls=input_level_ls,
                verbose=verbose) if args.dev_path is not None else None
            readers_dev_ls.append(readers_dev)
        printout_allocated_gpu_memory(verbose,
                                      "{} reader dev loaded".format(model_id))

        time_load_readers_dev = time.time() - time_load_readers_dev_start
        # Load tokenizer
        printing("TIME : {} ",
                 var=[
                     OrderedDict([
                         ("time_load_readers_train",
                          "{:0.4f} min".format(time_load_readers_train / 60)),
                         ("time_load_readers_dev",
                          "{:0.4f} min".format(time_load_readers_dev / 60))
                     ])
                 ],
                 verbose=verbose,
                 verbose_level=2)

        early_stoping_val_former = 1000
        # training starts when epoch is 1
        #args.epochs += 1
        #assert args.epochs >= 1, "ERROR need at least 2 epochs (1 eval , 1 train 1 eval"
        flexible_batch_size = False
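        # (assumption) t_total passed below approximates the total number of optimizer updates
        # (n_sent_dataset_total_train / args.batch_update_train * args.epochs), falling back to 5
        # when that estimate is not above 1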

        if args.optimizer == "AdamW":
            model, optimizer, scheduler = apply_fine_tuning_strategy(
                model=model,
                fine_tuning_strategy=args.fine_tuning_strategy,
                lr_init=args.lr,
                betas=(0.9, 0.99),
                epoch=0,
                weight_decay=args.weight_decay,
                optimizer_name=args.optimizer,
                t_total=n_sent_dataset_total_train / args.batch_update_train *
                args.epochs if n_sent_dataset_total_train /
                args.batch_update_train * args.epochs > 1 else 5,
                verbose=verbose)

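        # Main training loop: one pass over the (possibly sharded) training data per epoch,
        # with periodic evaluation on the dev sets and checkpointing of the best model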
        try:
            for epoch in range(args.epochs):
                if args.memory_efficient_iterator:
                    # we start each epoch with a new shard every time
                    training_file = get_new_shard(data_sharded, n_shards)
                    printing(
                        "INFO Memory efficient iterator triggered (only build for train data , starting with {}",
                        var=[training_file],
                        verbose=verbose,
                        verbose_level=1)
                    args_load_batcher_shard_data = {
                        "word_dictionary": word_dictionary,
                        "tokenizer": tokenizer,
                        "word_norm_dictionary": word_norm_dictionary,
                        "char_dictionary": char_dictionary,
                        "pos_dictionary": pos_dictionary,
                        "xpos_dictionary": xpos_dictionary,
                        "type_dictionary": type_dictionary,
                        "use_gpu": use_gpu,
                        "norm_not_norm": auxilliary_task_norm_not_norm,
                        "word_decoder": True,
                        "add_start_char": 1,
                        "add_end_char": 1,
                        "symbolic_end": 1,
                        "symbolic_root": 1,
                        "bucket": True,
                        "max_char_len": 20,
                        "must_get_norm": True,
                        "use_gpu_hardcoded_readers": False,
                        "bucketing_level": "bpe",
                        "input_level_ls": ["wordpiece"],
                        "auxilliary_task_norm_not_norm":
                        auxilliary_task_norm_not_norm,
                        "random_iterator_train": random_iterator_train
                    }

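                    # Reload the training reader on the freshly drawn shard so each epoch sees a different slice of the data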
                    readers_train = readers_load(
                        datasets=args.train_path if
                        not args.memory_efficient_iterator else training_file,
                        tasks=args.tasks,
                        word_dictionary=word_dictionary,
                        bert_tokenizer=tokenizer,
                        word_dictionary_norm=word_norm_dictionary,
                        char_dictionary=char_dictionary,
                        pos_dictionary=pos_dictionary,
                        xpos_dictionary=xpos_dictionary,
                        type_dictionary=type_dictionary,
                        word_decoder=True,
                        run_mode=run_mode,
                        add_start_char=1,
                        add_end_char=1,
                        symbolic_end=1,
                        symbolic_root=1,
                        bucket=True,
                        must_get_norm=True,
                        input_level_ls=input_level_ls,
                        verbose=verbose)

                checkpointing_model_data = (epoch % saving_every_epoch == 0
                                            or epoch == (args.epochs - 1))
                # build iterator on the loaded data
                printout_allocated_gpu_memory(
                    verbose, "{} loading batcher".format(model_id))

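                # 'flexible' batch size: derive the forward batch size from update_batch_size_mean(readers_train),
                # then make the gradient-accumulation step (batch_update_train) a multiple of it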
                if args.batch_size == "flexible":
                    flexible_batch_size = True

                    printing(
                        "INFO : args.batch_size {} so updating it based on mean value {}",
                        var=[
                            args.batch_size,
                            update_batch_size_mean(readers_train)
                        ],
                        verbose=verbose,
                        verbose_level=1)
                    args.batch_size = update_batch_size_mean(readers_train)

                    if args.batch_update_train == "flexible":
                        args.batch_update_train = args.batch_size
                    printing(
                        "TRAINING : backward pass every {} step of size {} in average",
                        var=[
                            int(args.batch_update_train // args.batch_size),
                            args.batch_size
                        ],
                        verbose=verbose,
                        verbose_level=1)
                    try:
                        assert isinstance(args.batch_update_train // args.batch_size, int)\
                           and args.batch_update_train // args.batch_size > 0, \
                            "ERROR batch_size {} should be a multiple of {} ".format(args.batch_update_train, args.batch_size)
                    except Exception as e:
                        print("WARNING {}".format(e))
                batchIter_train = data_gen_multi_task_sampling_batch(
                    tasks=args.tasks,
                    readers=readers_train,
                    batch_size=readers_train[list(readers_train.keys())[0]][4],
                    max_token_per_batch=max_token_per_batch
                    if flexible_batch_size else None,
                    word_dictionary=word_dictionary,
                    char_dictionary=char_dictionary,
                    pos_dictionary=pos_dictionary,
                    word_dictionary_norm=word_norm_dictionary,
                    get_batch_mode=random_iterator_train,
                    print_raw=False,
                    dropout_input=0.0,
                    verbose=verbose)

                # -|-|-
                printout_allocated_gpu_memory(
                    verbose, "{} batcher train loaded".format(model_id))
                batchIter_dev_ls = []
                batch_size_DEV = 1

                if verbose > 1:
                    print(
                        "WARNING : batch_size for final eval was hardcoded and set to {}"
                        .format(batch_size_DEV))
                for readers_dev in readers_dev_ls:
                    batchIter_dev = data_gen_multi_task_sampling_batch(
                        tasks=args.tasks,
                        readers=readers_dev,
                        batch_size=batch_size_DEV,
                        word_dictionary=word_dictionary,
                        char_dictionary=char_dictionary,
                        pos_dictionary=pos_dictionary,
                        word_dictionary_norm=word_norm_dictionary,
                        get_batch_mode=False,
                        print_raw=False,
                        dropout_input=0.0,
                        verbose=verbose) if args.dev_path is not None else None
                    batchIter_dev_ls.append(batchIter_dev)

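                # Switch to train mode; for non-AdamW optimizers the fine-tuning strategy rebuilds
                # the optimizer and scheduler at every epoch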
                model.train()
                printout_allocated_gpu_memory(
                    verbose, "{} batcher dev loaded".format(model_id))
                if args.optimizer != "AdamW":

                    model, optimizer, scheduler = apply_fine_tuning_strategy(
                        model=model,
                        fine_tuning_strategy=args.fine_tuning_strategy,
                        lr_init=args.lr,
                        betas=(0.9, 0.99),
                        weight_decay=args.weight_decay,
                        optimizer_name=args.optimizer,
                        t_total=n_sent_dataset_total_train /
                        args.batch_update_train *
                        args.epochs if n_sent_dataset_total_train /
                        args.batch_update_train * args.epochs > 1 else 5,
                        epoch=epoch,
                        verbose=verbose)
                printout_allocated_gpu_memory(
                    verbose, "{} optimizer loaded".format(model_id))
                loss_train = None

                if epoch >= 0:
                    printing("TRAINING : training on GET_BATCH_MODE ",
                             verbose=verbose,
                             verbose_level=2)
                    printing(
                        "TRAINING {} training 1 'epoch' = {} observation size args.batch_"
                        "update_train (foward {} batch_size {} backward  "
                        "(every int(args.batch_update_train//args.batch_size) step if {})) ",
                        var=[
                            model_id, n_observation_max_per_epoch_train,
                            args.batch_size, args.batch_update_train,
                            args.low_memory_foot_print_batch_mode
                        ],
                        verbose=verbose,
                        verbose_level=1)
                    loss_train, iter_train, perf_report_train, _ = epoch_run(
                        batchIter_train,
                        tokenizer,
                        args=args,
                        model_origin=model_origin,
                        pruning_mask=pruning_mask,
                        task_to_label_dictionary=task_to_label_dictionary,
                        data_label=train_data_label,
                        model=model,
                        dropout_input_bpe=args.dropout_input_bpe,
                        writer=writer,
                        iter=iter_train,
                        epoch=epoch,
                        writing_pred=epoch == (args.epochs - 1),
                        dir_end_pred=end_predictions,
                        optimizer=optimizer,
                        use_gpu=use_gpu,
                        scheduler=scheduler,
                        predict_mode=(epoch - 1) % 5 == 0,
                        skip_1_t_n=skip_1_t_n,
                        model_id=model_id,
                        reference_word_dic={"InV": inv_word_dic},
                        null_token_index=null_token_index,
                        null_str=null_str,
                        norm_2_noise_eval=False,
                        early_stoppin_metric=None,
                        n_obs_max=n_observation_max_per_epoch_train,
                        data_sharded_dir=data_sharded,
                        n_shards=n_shards,
                        n_sent_dataset_total=n_sent_dataset_total_train,
                        args_load_batcher_shard_data=
                        args_load_batcher_shard_data,
                        memory_efficient_iterator=args.
                        memory_efficient_iterator,
                        verbose=verbose)

                else:
                    printing(
                        "TRAINING : skipping first epoch to start by evaluating on devs dataset0",
                        verbose=verbose,
                        verbose_level=1)
                printout_allocated_gpu_memory(
                    verbose, "{} epoch train done".format(model_id))
                model.eval()

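                # Evaluate on the dev sets during the first epochs (epoch <= 6) and every third epoch afterwards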
                if args.dev_path is not None and (epoch % 3 == 0
                                                  or epoch <= 6):
                    if verbose > 1:
                        print("RUNNING DEV on ITERATION MODE")
                    early_stoping_val_ls = []
                    loss_dev_ls = []
                    for i_dev, batchIter_dev in enumerate(batchIter_dev_ls):
                        loss_dev, iter_dev, perf_report_dev, early_stoping_val = epoch_run(
                            batchIter_dev,
                            tokenizer,
                            args=args,
                            epoch=epoch,
                            model_origin=model_origin,
                            pruning_mask=pruning_mask,
                            task_to_label_dictionary=task_to_label_dictionary,
                            iter=iter_dev,
                            use_gpu=use_gpu,
                            model=model,
                            writer=writer,
                            optimizer=None,
                            writing_pred=True,  #epoch == (args.epochs - 1),
                            dir_end_pred=end_predictions,
                            predict_mode=True,
                            data_label=dev_data_label_ls[i_dev],
                            null_token_index=null_token_index,
                            null_str=null_str,
                            model_id=model_id,
                            skip_1_t_n=skip_1_t_n,
                            dropout_input_bpe=0,
                            reference_word_dic={"InV": inv_word_dic},
                            norm_2_noise_eval=False,
                            early_stoppin_metric=early_stoppin_metric,
                            subsample_early_stoping_metric_val=
                            subsample_early_stoping_metric_val,
                            #case=case,
                            n_obs_max=n_observation_max_per_epoch_dev_test,
                            verbose=verbose)

                        printing(
                            "TRAINING : loss train:{} dev {}:{} for epoch {}  out of {}",
                            var=[
                                loss_train, i_dev, loss_dev, epoch, args.epochs
                            ],
                            verbose=1,
                            verbose_level=1)
                        printing("PERFORMANCE {} DEV {} {} ",
                                 var=[epoch, i_dev + 1, perf_report_dev],
                                 verbose=verbose,
                                 verbose_level=1)
                        early_stoping_val_ls.append(early_stoping_val)
                        loss_dev_ls.append(loss_dev)

                else:
                    if verbose > 1:
                        print("NO DEV EVAL")
                    loss_dev, iter_dev, perf_report_dev = None, 0, None
                # NB : early_stoping_val is based on first dev set
                printout_allocated_gpu_memory(
                    verbose, "{} epoch dev done".format(model_id))

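                # Checkpoint when the scheduled saving epoch is reached or when the early-stopping
                # metric (lower is better) improves on the first dev set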
                early_stoping_val = early_stoping_val_ls[0]
                if checkpointing_model_data or early_stoping_val < early_stoping_val_former:
                    if early_stoping_val is not None:
                        _epoch = "best" if early_stoping_val < early_stoping_val_former else epoch
                    else:
                        if verbose > 1:
                            print(
                                'WARNING early_stoping_val is None so saving based on checkpointing_model_data only'
                            )
                        _epoch = epoch
                    # model_id possibly enriched with some epoch information if name_with_epoch
                    _model_id = get_name_model_id_with_extra_name(
                        epoch=epoch,
                        _epoch=_epoch,
                        name_with_epoch=name_with_epoch,
                        model_id=model_id)
                    checkpoint_dir = os.path.join(
                        model_location, "{}-checkpoint.pt".format(_model_id))

                    if _epoch == "best":
                        printing(
                            "CHECKPOINT : SAVING BEST MODEL {} (epoch:{}) (new loss is {} former was {})"
                            .format(checkpoint_dir, epoch, early_stoping_val,
                                    early_stoping_val_former),
                            verbose=verbose,
                            verbose_level=1)
                        last_checkpoint_dir_best = checkpoint_dir
                        early_stoping_val_former = early_stoping_val
                        best_epoch = epoch
                        best_loss = early_stoping_val
                    else:
                        printing(
                            "CHECKPOINT : NOT SAVING BEST MODEL : new loss {} did not beat first loss {}"
                            .format(early_stoping_val,
                                    early_stoping_val_former),
                            verbose_level=1,
                            verbose=verbose)
                    last_model = ""
                    if epoch == (args.epochs - 1):
                        last_model = "last"
                    printing("CHECKPOINT : epoch {} saving {} model {} ",
                             var=[epoch, last_model, checkpoint_dir],
                             verbose=verbose,
                             verbose_level=1)
                    torch.save(model.state_dict(), checkpoint_dir)

                    args_dir = write_args(
                        dir=model_location,
                        checkpoint_dir=checkpoint_dir,
                        hyperparameters=hyperparameters
                        if name_with_epoch else None,
                        model_id=_model_id,
                        info_checkpoint=OrderedDict([
                            ("epochs", epoch + 1),
                            ("batch_size", args.batch_size
                             if not args.low_memory_foot_print_batch_mode else
                             args.batch_update_train),
                            ("train_path", train_data_label),
                            ("dev_path", dev_data_label_ls),
                            ("num_labels_per_task", num_labels_per_task)
                        ]),
                        verbose=verbose)

            if row is not None and update_status is not None:
                update_status(row=row, value="training-done", verbose=1)
        except Exception as e:
            if row is not None and update_status is not None:
                update_status(row=row, value="ERROR", verbose=1)
            raise (e)

    # reloading last (best) checkpoint
    if run_mode in ["train", "test"] and args.test_paths is not None:
        report_all = []
        if run_mode == "train" and args.epochs > 0:
            if use_gpu:
                model.load_state_dict(torch.load(last_checkpoint_dir_best))
                model = model.cuda()
                printout_allocated_gpu_memory(
                    verbose, "{} after reloading model".format(model_id))
            else:
                model.load_state_dict(
                    torch.load(last_checkpoint_dir_best,
                               map_location=lambda storage, loc: storage))
            printing(
                "MODEL : RELOADING best model of epoch {} with loss {} based on {}({}) metric (from checkpoint {})",
                var=[
                    best_epoch, best_loss, early_stoppin_metric,
                    subsample_early_stoping_metric_val,
                    last_checkpoint_dir_best
                ],
                verbose=verbose,
                verbose_level=1)

        model.eval()

        printout_allocated_gpu_memory(verbose,
                                      "{} starting test".format(model_id))
        for test_path in args.test_paths:
            assert len(test_path) == len(
                args.tasks), "ERROR test_path {} args.tasks {}".format(
                    test_path, args.tasks)
            for test, task_to_eval in zip(test_path, args.tasks):
                label_data = get_dataset_label([test], default="test")
                if len(extra_label_for_prediction) > 0:
                    label_data += "-" + extra_label_for_prediction

                if args.shuffle_bpe_embedding and args.test_mode_no_shuffle_embedding:
                    printing(
                        "TOKENIZER: as args.shuffle_bpe_embedding {} and test_mode_no_shuffle {} : reloading tokenizer with no shuffle_embedding",
                        var=[
                            args.shuffle_bpe_embedding,
                            args.test_mode_no_shuffle_embedding
                        ],
                        verbose=1,
                        verbose_level=1)
                    tokenizer = tokenizer.from_pretrained(
                        voc_tokenizer,
                        do_lower_case=args.case == "lower",
                        shuffle_bpe_embedding=False)
                readers_test = readers_load(
                    datasets=[test],
                    tasks=[task_to_eval],
                    word_dictionary=word_dictionary,
                    word_dictionary_norm=word_norm_dictionary,
                    char_dictionary=char_dictionary,
                    pos_dictionary=pos_dictionary,
                    xpos_dictionary=xpos_dictionary,
                    type_dictionary=type_dictionary,
                    bert_tokenizer=tokenizer,
                    word_decoder=True,
                    run_mode=run_mode,
                    add_start_char=1,
                    add_end_char=1,
                    symbolic_end=1,
                    symbolic_root=1,
                    bucket=bucket_test,
                    input_level_ls=input_level_ls,
                    must_get_norm=must_get_norm_test,
                    verbose=verbose)

                heuritics_zip = [None]
                gold_error_or_not_zip = [False]
                norm2noise_zip = [False]
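                # Only the plain configuration is evaluated here: no heuristic, no gold-error injection, no norm2noise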

                if heuristic_test_ls is None:
                    assert len(gold_error_or_not_zip) == len(
                        heuritics_zip) and len(heuritics_zip) == len(
                            norm2noise_zip)

                batch_size_TEST = 1
                if verbose > 1:
                    print(
                        "WARNING : batch_size for final eval was hardcoded and set to {}"
                        .format(batch_size_TEST))
                for (heuristic_test, gold_error,
                     norm_2_noise_eval) in zip(heuritics_zip,
                                               gold_error_or_not_zip,
                                               norm2noise_zip):

                    assert heuristic_test is None and not gold_error and not norm_2_noise_eval

                    batchIter_test = data_gen_multi_task_sampling_batch(
                        tasks=[task_to_eval],
                        readers=readers_test,
                        batch_size=batch_size_TEST,
                        word_dictionary=word_dictionary,
                        char_dictionary=char_dictionary,
                        pos_dictionary=pos_dictionary,
                        word_dictionary_norm=word_norm_dictionary,
                        get_batch_mode=False,
                        dropout_input=0.0,
                        verbose=verbose)
                    try:
                        loss_test, iter_test, perf_report_test, _ = epoch_run(
                            batchIter_test,
                            tokenizer,
                            args=args,
                            iter=iter_dev,
                            use_gpu=use_gpu,
                            model=model,
                            task_to_label_dictionary=task_to_label_dictionary,
                            writer=None,
                            writing_pred=True,
                            optimizer=None,
                            args_dir=args_dir,
                            model_id=model_id,
                            dir_end_pred=end_predictions,
                            skip_1_t_n=skip_1_t_n,
                            predict_mode=True,
                            data_label=label_data,
                            epoch="LAST",
                            extra_label_for_prediction=label_data,
                            null_token_index=null_token_index,
                            null_str=null_str,
                            log_perf=False,
                            dropout_input_bpe=0,
                            norm_2_noise_eval=norm_2_noise_eval,
                            compute_intersection_score=
                            compute_intersection_score_test,
                            remove_mask_str_prediction=
                            remove_mask_str_prediction,
                            reference_word_dic={"InV": inv_word_dic},
                            threshold_edit=threshold_edit,
                            verbose=verbose,
                            n_obs_max=n_observation_max_per_epoch_dev_test)
                        if verbose > 1:
                            print("LOSS TEST", loss_test)
                    except Exception as e:
                        print(
                            "ERROR (epoch_run test) {} test_path {} , heuristic {} , gold error {} , norm2noise {} "
                            .format(e, test, heuristic_test, gold_error,
                                    norm_2_noise_eval))
                        raise (e)
                    print("PERFORMANCE TEST on data  {} is {} ".format(
                        label_data, perf_report_test))
                    print("DATA WRITTEN {}".format(end_predictions))
                    if writer is not None:
                        writer.add_text(
                            "Accuracy-{}-{}-{}".format(model_id, label_data,
                                                       run_mode),
                            "After {} epochs with {} : performance is \n {} ".
                            format(args.epochs, description,
                                   str(perf_report_test)), 0)
                    else:
                        printing(
                            "WARNING : could not add accuracy to tensorboard cause writer was found None",
                            verbose=verbose,
                            verbose_level=2)
                    report_all.extend(perf_report_test)
                    printout_allocated_gpu_memory(
                        verbose, "{} test done".format(model_id))
    else:
        printing("ERROR : EVALUATION none cause {} empty or run_mode {} ",
                 var=[args.test_paths, run_mode],
                 verbose_level=1,
                 verbose=verbose)

    if writer is not None:
        writer.close()
        printing("tensorboard --logdir={} --host=localhost --port=1234 ",
                 var=[tensorboard_log],
                 verbose_level=1,
                 verbose=verbose)

    report_dir = os.path.join(model_location, model_id + "-report.json")
    if report_full_path_shared is not None:
        report_full_dir = os.path.join(report_full_path_shared,
                                       args.overall_label + "-report.json")
        if os.path.isfile(report_full_dir):
            report = json.load(open(report_full_dir, "r"))
        else:
            report = []
            printing("REPORT = creating overall report at {} ",
                     var=[report_dir],
                     verbose=verbose,
                     verbose_level=1)
        report.extend(report_all)
        json.dump(report, open(report_full_dir, "w"))
        printing("{} {} ",
                 var=[REPORT_FLAG_DIR_STR, report_full_dir],
                 verbose=0,
                 verbose_level=0)

    json.dump(report_all, open(report_dir, "w"))
    printing("REPORTING TO {}".format(report_dir),
             verbose=verbose,
             verbose_level=1)
    if report_full_path_shared is None:
        printing("WARNING ; report_full_path_shared is None",
                 verbose=verbose,
                 verbose_level=1)
        printing("{} {} ",
                 var=[REPORT_FLAG_DIR_STR, report_dir],
                 verbose=verbose,
                 verbose_level=0)

    return model
Exemple #12
def get_label_per_bpe(tasks,
                      batch,
                      input_tokens_tensor,
                      input_alignement_with_raw,
                      use_gpu,
                      tasks_parameters,
                      pad_index,
                      vocab_len=None,
                      masking_strategy=0,
                      mask_token_index=None,
                      sep_token_index=None,
                      cls_token_index=None,
                      dropout_input_bpe=None):
    """
    Returns the inputs, input masks and outputs for each task
    (with regard to the task type; so far only the word level is supported).
    """
    #  TODO : should be done in pytorch + redundancies with get_index
    label_per_task = OrderedDict()
    input_tokens_tensor_per_task = OrderedDict()
    token_type_ids = OrderedDict()
    input_mask_per_task = OrderedDict()
    input_mask, output_tokens_tensor = None, None
    cumulate_shift = None
    head_masks = OrderedDict()
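    # tasks is a list of task groups trained simultaneously; for each label of each task we build
    # the aligned gold output, the input tensor, the padding mask and the token type ids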
    for simul_task in tasks:
        for task in simul_task:
            for task_batch_name in tasks_parameters[task]["label"]:
                task_batch = eval("batch.{}".format(task_batch_name)).clone()
                # why not is_mwe and n_masks also
                if task in ["parsing", "pos"]:
                    # for now we align parsing and tagging signal with raw input using
                    # get_bpe_label_word_level_task here
                    output_tokens_tensor, head_mask, input_tokens_tensor, _cumulate_shift = get_bpe_label_word_level_task(
                        labels=task_batch,
                        pad=pad_index,
                        batch=batch,
                        #input_tokens_tensor,
                        #input_alignement_with_raw,
                        use_gpu=use_gpu,
                        label_name=task_batch_name,
                        input_tokens_tensor=eval("batch.{}".format(
                            tasks_parameters[task]["input"])).clone(),
                        input_alignement_with_raw=eval(
                            "batch.{}_alignement".format(
                                tasks_parameters[task]["input"])),
                        graph_labels=LABEL_PARAMETER[task_batch_name].get(
                            "graph_label", False))

                    output_tokens_tensor_aligned = output_tokens_tensor  #[:, : input_tokens_tensor.size(1)]
                    if task == "parsing" and task_batch_name == "heads":
                        cumulate_shift = _cumulate_shift
                else:
                    # for tokenization-related tasks, alignment was already handled in the CoNLLReader
                    output_tokens_tensor_aligned = task_batch
                    head_mask = None

                head_masks[task] = head_mask
                if output_tokens_tensor_aligned is not None:
                    output_tokens_tensor_aligned = output_tokens_tensor_aligned.contiguous(
                    )

                    if use_gpu:
                        output_tokens_tensor_aligned = output_tokens_tensor_aligned.cuda(
                        )
                # if the task has several labels : we just append the label name to the task in the label dictionary
                # ALL outputs padded with the BERT pad token are re-padded with the LOSS pad (-1)
                label_per_task[task_batch_name] = output_tokens_tensor_aligned

            if not tasks_parameters[task].get("mask_input", False):
                #input_tokens_tensor_per_task[tasks_parameters[task]["input"]] = eval("batch.{}".format(tasks_parameters[task]["input"])).clone() if task not in ["parsing", "pos"] else input_tokens_tensor.clone()
                input_tokens_tensor_per_task[
                    tasks_parameters[task]["input"]] = eval("batch.{}".format(
                        tasks_parameters[task]["input"])).clone()

                # we drop out the input here for regularization purposes if needed
                if dropout_input_bpe is not None and dropout_input_bpe > 0:
                    input_tokens_tensor_per_task[
                        tasks_parameters[task]["input"]] = dropout_mlm(
                            input_tokens_tensor_per_task[tasks_parameters[task]
                                                         ["input"]],
                            mask_token_index=mask_token_index,
                            sep_token_index=sep_token_index,
                            cls_token_index=cls_token_index,
                            pad_index=pad_index,
                            use_gpu=False,
                            dropout_mask=dropout_input_bpe,
                            dropout_random_bpe_of_masked=0.5,
                            vocab_len=vocab_len)

                input_mask_per_task[tasks_parameters[task]["input"]] = (
                    input_tokens_tensor_per_task[tasks_parameters[task]
                                                 ["input"]] != pad_index)
            else:  # mlm
                # mask_input is for the Masked Language Model task : which means masking + replacing by a random wordpiece
                assert masking_strategy is None
                # NB : dropout_input_bpe is ignored in MLM : set to 15% as in the BERT paper
                assert tasks_parameters[task].get("original") is not None, \
                    "ERROR 'original' field is needed to get raw sequence before preprocssing for task {} ".format(task)
                input_tokens_tensor_per_task[
                    tasks_parameters[task]["input"]] = dropout_mlm(
                        eval("batch.{}".format(
                            tasks_parameters[task]["original"])).clone(),
                        mask_token_index=mask_token_index,
                        sep_token_index=sep_token_index,
                        cls_token_index=cls_token_index,
                        pad_index=pad_index,
                        use_gpu=False,
                        dropout_mask=0.15,
                        dropout_random_bpe_of_masked=0.5,
                        vocab_len=vocab_len)
                # NB : this mask is for PADDING (bad naming)
                input_mask_per_task[tasks_parameters[task]["input"]] = (
                    input_tokens_tensor_per_task[tasks_parameters[task]
                                                 ["input"]] != pad_index)

            token_type_ids[tasks_parameters[task]["input"]] = torch.zeros_like(
                input_tokens_tensor_per_task[tasks_parameters[task]["input"]])
            if use_gpu:
                input_tokens_tensor_per_task[tasks_parameters[task][
                    "input"]] = input_tokens_tensor_per_task[
                        tasks_parameters[task]["input"]].cuda()
                input_mask_per_task[
                    tasks_parameters[task]["input"]] = input_mask_per_task[
                        tasks_parameters[task]["input"]].cuda()
                token_type_ids[tasks_parameters[task]
                               ["input"]] = token_type_ids[
                                   tasks_parameters[task]["input"]].cuda()

    return head_masks, input_tokens_tensor, token_type_ids, label_per_task, \
           input_tokens_tensor_per_task, input_mask_per_task, cumulate_shift
Exemple #13
    def forward(self,
                input_ids_dict,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                head_masks=None):
        if labels is None:
            labels = OrderedDict()
        if head_masks is None:
            head_masks = OrderedDict()
        sequence_output_dict = OrderedDict()
        logits_dict = OrderedDict()
        loss_dict = OrderedDict()
        # sanity check the labels : they should all be supported by the model
        for label, value in labels.items():
            assert label in self.labels_supported, "label {} in {} not supported".format(
                label, self.labels_supported)

        # task_wise layer attention
        printout_allocated_gpu_memory(1, " foward starting ")
        for input_name, input_tensors in input_ids_dict.items():
            # not able to output all layers anymore
            #print("INPUT {} {} ".format(input_name, input_tensors))

            sequence_output, _ = self.encoder(
                input_tensors,
                token_type_ids=None,
                attention_mask=attention_mask[input_name])
            sequence_output_dict[input_name] = sequence_output
            printout_allocated_gpu_memory(1, " forward pass bert")

        for task in self.tasks:
            # we don't use mask for parsing heads (cf. test performed below : the -1 already ignores the heads we don't want)
            # NB : head_masks for parsing only applies to heads not types
            head_masks_task = None  # head_masks.get(task, None) if task != "parsing" else None
            # NB : head_mask here means masks specific to the module's attention heads (nothing related to parsing)
            assert self.task_parameters[task]["input"] in sequence_output_dict, \
                "ERROR input {} of task {} was not found in input_ids_dict {}" \
                " and therefore not in sequence_output_dict {} ".format(self.task_parameters[task]["input"],
                                                                        task, input_ids_dict.keys(),
                                                                        sequence_output_dict.keys())

            if not self.head[
                    task].__class__.__name__ == BertOnlyMLMHead.__name__:  #isinstance(self.head[task], BertOnlyMLMHead):

                logits_dict[task] = self.head[task](
                    sequence_output_dict[self.task_parameters[task]["input"]],
                    head_mask=head_masks_task)
            else:
                logits_dict[task] = self.head[task](
                    sequence_output_dict[self.task_parameters[task]["input"]])
            # test performed : (logits_dict[task][0][1,2,:20]==float('-inf'))==(labels["parsing_heads"][1,:20]==-1)
            # handle several labels at output (e.g  parsing)

            printout_allocated_gpu_memory(1,
                                          " forward pass head {}".format(task))

            logits_dict = self.rename_multi_modal_task_logits(
                labels=self.task_parameters[task]["label"],
                task=task,
                logits_dict=logits_dict,
                task_parameters=self.task_parameters)

            printout_allocated_gpu_memory(1, "after renaming")

            for logit_label in logits_dict:

                label = re.match("(.*)-(.*)", logit_label)
                assert label is not None, "ERROR logit_label {}".format(
                    logit_label)
                label = label.group(2)
                if label in self.task_parameters[task]["label"]:
                    _labels = None
                    if self.task_parameters[task]["input"] == "input_masked":
                        _labels = labels.get(label)
                        if _labels is not None:
                            _labels = _labels.clone()
                            _labels[input_ids_dict["input_masked"] != self.
                                    mask_index_bert] = PAD_ID_LOSS_STANDART
                    else:
                        _labels = labels.get(label)
                    printout_allocated_gpu_memory(
                        1, " get label head {}".format(logit_label))
                    if _labels is not None:
                        #print("LABEL label {} {}".format(label, _labels))
                        loss_dict[logit_label] = self.get_loss(
                            loss_func=self.task_parameters[task]["loss"],
                            label=label,
                            num_label_dic=self.num_labels_dic,
                            labels=_labels,
                            logits_dict=logits_dict,
                            task=task,
                            logit_label=logit_label,
                            head_label=labels["heads"]
                            if label == "types" else None)
                    printout_allocated_gpu_memory(1,
                                                  " get loss {}".format(task))
                printout_allocated_gpu_memory(
                    1, " puting to cpu {}".format(logit_label))
        # third output is for potential attention weights

        return logits_dict, loss_dict, None
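
# main() runs a cross-lingual hidden-representation analysis: for every source/target language pair,
# per-layer sentence representations are averaged and compared (CKA or cosine similarity) to the source
# language, to the pretrained (origin) model and to the previous layer, then logged via report_template.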
def main(args, dict_path, model_dir):

    model, tokenizer, run_id = load_all_analysis(args, dict_path, model_dir)

    if args.compare_to_pretrained:
        print("Loading Pretrained model also for comparison with pretrained")
        args_origin = args_attention_analysis()
        args_origin = args_preprocess_attention_analysis(args_origin)
        args_origin.init_args_dir = None
        args_origin, dict_path_0, model_dir_0 = get_dirs(args_origin)
        args_origin.model_id_pref += "again"
        model_origin, tokenizer_0, _ = load_all_analysis(
            args_origin, dict_path_0, model_dir_0)
        model_origin.eval()
        print("seco,")
    # only allow output of the model to be hidden states here
    print("Checkpoint loaded")
    assert not args.output_attentions
    assert args.output_all_encoded_layers and args.output_hidden_states_per_head

    data = ["I am here", "How are you"]
    model.eval()
    n_obs = args.n_sent
    max_len = args.max_seq_len
    lang_ls = args.raw_text_code

    lang = [
        "fr_pud", "de_pud", "ru_pud", "tr_pud", "id_pud", "ar_pud", "pt_pud",
        "es_pud", "fi_pud", "it_pud", "sv_pud", "cs_pud", "pl_pud", "hi_pud",
        "zh_pud", "ko_pud", "ja_pud", "th_pud"
    ]
    #lang = ["fr_pud", "fr_gsd"]
    src_lang_ls = ["en_pud"]  #, "fr_pud", "ru_pud", "ar_pud"]
    print("Loading data...")
    data_target_ls = [
        load_data(DATA_UD + f"/{target}-ud-test.conllu",
                  line_filter="# text = ") for target in lang
    ]
    data_target_dic = OrderedDict([
        (lang, data) for lang, data in zip(lang, data_target_ls)
    ])
    #pdb.set_trace()
    src = src_lang_ls[0]  #"en_pud"

    data_en = data_target_ls[
        0]  #load_data(DATA_UD+f"/{src}-ud-test.conllu", line_filter="# text = ")

    for _data_target in data_target_dic:
        try:
            assert len(data_target_dic[_data_target]) == len(
                data_en
            ), f"Should have as many sentences on both sides en:{len(data_en)} target:{len(data_target_dic[_data_target])}"
        except AssertionError:
            data_en = data_en[:len(data_target_dic[_data_target])]
            print(f"Cutting {src} dataset based on target")
        assert len(data_target_dic[_data_target]) == len(
            data_en
        ), f"Should have as many sentences on both sides en:{len(data_en)} target:{len(data_target_dic[_data_target])}"
    #reg = linear_model.LogisticRegression()
    # just to get the keys
    layer_all = get_hidden_representation(data,
                                          model,
                                          tokenizer,
                                          max_len=max_len)
    # removed hidden_per_layer
    #pdb.set_trace()
    assert len(layer_all) == 1
    #assert len(layer_all) == 2, "ERROR should only have hidden_per_layer and hidden_per_head_layer"

    report_ls = []
    accuracy_dic = OrderedDict()
    sampling = args.sampling
    metric = args.similarity_metric
    if metric == "cka":
        pad_below_max_len, output_dic = False, True
    else:
        pad_below_max_len, output_dic = False, True
    assert metric in ["cos", "cka"]
    if metric == "cos":
        batch_size = 1
    else:
        batch_size = len(data_en) // 4

    task_tuned = "No"

    if args.init_args_dir is None:
        #args.init_args_dir =
        id_model = f"{args.bert_model}-init-{args.random_init}"

        hyperparameters = OrderedDict([
            ("bert_model", args.bert_model),
            ("random_init", args.random_init),
            ("not_load_params_ls", args.not_load_params_ls),
            ("dict_path", dict_path),
            ("model_id", id_model),
        ])
        info_checkpoint = OrderedDict([("epochs", 0),
                                       ("batch_size", batch_size),
                                       ("train_path", 0), ("dev_path", 0),
                                       ("num_labels_per_task", 0)])

        args.init_args_dir = write_args(os.environ.get("MT_NORM_PARSE", "./"),
                                        model_id=id_model,
                                        info_checkpoint=info_checkpoint,
                                        hyperparameters=hyperparameters,
                                        verbose=1)
        print("args_dir checkout ", args.init_args_dir)
        model_full_name_val = task_tuned + "-" + id_model
    else:
        argument = json.load(open(args.init_args_dir, 'r'))
        task_tuned = argument["hyperparameters"]["tasks"][0][
            0] if not "wiki" in argument["info_checkpoint"][
                "train_path"] else "ner"
        model_full_name_val = task_tuned + "-" + args.init_args_dir.split(
            "/")[-1]

    if args.analysis_mode == "layer":
        studied_ind = 0
    elif args.analysis_mode == "layer_head":
        studied_ind = 1
    else:
        raise (
            Exception(f"args.analysis_mode : {args.analysis_mode} corrupted"))
    layer_analysed = layer_all[studied_ind]

    #for ind, layer_head in enumerate(list(layer_analysed.keys())):
    report = OrderedDict()
    accuracy_ls = []
    src_lang = src

    cosine_sent_to_src = OrderedDict([(src_lang + "-" + lang, OrderedDict())
                                      for src_lang in src_lang_ls
                                      for lang in data_target_dic.keys()])
    cosine_sent_to_origin = OrderedDict([(lang, OrderedDict())
                                         for lang in data_target_dic.keys()])
    cosine_sent_to_origin_src = OrderedDict([(lang, OrderedDict())
                                             for lang in src_lang_ls])
    cosine_sent_to_former_layer_src = OrderedDict([(lang, OrderedDict())
                                                   for lang in src_lang_ls])
    cosine_sent_to_former_layer = OrderedDict([
        (lang, OrderedDict()) for lang in data_target_dic.keys()
    ])
    cosine_sent_to_first_layer = OrderedDict([
        (lang, OrderedDict()) for lang in data_target_dic.keys()
    ])
    #layer_head = list(layer_analysed.keys())[len(list(layer_analysed.keys())) - ind -1]

    cosinus = nn.CosineSimilarity(dim=1)
    info_model = f" task {args.tasks} args {'/'.join(args.init_args_dir.split('/')[-2:]) if args.init_args_dir is not None else None} bert {args.bert_model} random init {args.random_init} "
    #"cka"
    output_dic = True
    pad_below_max_len = False
    max_len = 200

    n_batch = len(data_en) // batch_size

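    # For each batch: extract hidden representations for the source sentences and for each target language,
    # then accumulate per-layer similarities to the source, to the pretrained model and to the previous layer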
    for i_data in range(n_batch):

        for src_lang in src_lang_ls:
            print(f"Starting src", {src_lang})
            data_en = load_data(DATA_UD + f"/{src_lang}-ud-test.conllu",
                                line_filter="# text = ")
            en_batch = data_en[i_data:i_data + batch_size]
            all = get_hidden_representation(
                en_batch,
                model,
                tokenizer,
                pad_below_max_len=pad_below_max_len,
                max_len=max_len,
                output_dic=output_dic)
            analysed_batch_dic_en = all[studied_ind]
            i_lang = 0

            if args.compare_to_pretrained:
                all_origin = get_hidden_representation(
                    en_batch,
                    model_origin,
                    tokenizer_0,
                    pad_below_max_len=pad_below_max_len,
                    max_len=max_len,
                    output_dic=output_dic)
                analysed_batch_dic_src_origin = all_origin[studied_ind]

            for lang, target in data_target_dic.items():
                print(f"Starting target", {lang})
                i_lang += 1
                target_batch = target[i_data:i_data + batch_size]

                all = get_hidden_representation(
                    target_batch,
                    model,
                    tokenizer,
                    pad_below_max_len=pad_below_max_len,
                    max_len=max_len,
                    output_dic=output_dic)

                if args.compare_to_pretrained:
                    all_origin = get_hidden_representation(
                        target_batch,
                        model_origin,
                        tokenizer_0,
                        pad_below_max_len=pad_below_max_len,
                        max_len=max_len,
                        output_dic=output_dic)
                    analysed_batch_dic_target_origin = all_origin[studied_ind]

                analysed_batch_dic_target = all[studied_ind]

                former_layer, former_mean_target, former_mean_src = None, None, None
                for layer in analysed_batch_dic_target:
                    print(f"Starting layer", {layer})
                    # get average for sentence removing first and last special tokens
                    if output_dic:
                        mean_over_sent_src = []
                        mean_over_sent_target = []
                        mean_over_sent_target_origin = []
                        mean_over_sent_src_origin = []
                        for i_sent in range(len(analysed_batch_dic_en[layer])):
                            # remove the first and last special tokens, then average over the sentence
                            mean_over_sent_src.append(
                                np.array(analysed_batch_dic_en[layer][i_sent][
                                    0, 1:-1, :].mean(dim=0).cpu()))
                            mean_over_sent_target.append(
                                np.array(analysed_batch_dic_target[layer]
                                         [i_sent][0,
                                                  1:-1, :].mean(dim=0).cpu()))

                            if args.compare_to_pretrained:
                                mean_over_sent_target_origin.append(
                                    np.array(
                                        analysed_batch_dic_target_origin[layer]
                                        [i_sent][0,
                                                 1:-1, :].mean(dim=0).cpu()))
                                if i_lang == 1:
                                    mean_over_sent_src_origin.append(
                                        np.array(analysed_batch_dic_src_origin[
                                            layer][i_sent][0, 1:-1, :].mean(
                                                dim=0).cpu()))

                        if args.compare_to_pretrained:
                            mean_over_sent_target_origin = np.array(
                                mean_over_sent_target_origin)
                            if i_lang == 1:
                                mean_over_sent_src_origin = np.array(
                                    mean_over_sent_src_origin)
                        mean_over_sent_src = np.array(mean_over_sent_src)
                        mean_over_sent_target = np.array(mean_over_sent_target)

                    else:
                        mean_over_sent_src = analysed_batch_dic_en[
                            layer][:, 1:-1, :].mean(dim=1).cpu()
                        mean_over_sent_target = analysed_batch_dic_target[
                            layer][:, 1:-1, :].mean(dim=1).cpu()

                    if layer not in cosine_sent_to_src[src_lang + "-" + lang]:
                        cosine_sent_to_src[src_lang + "-" + lang][layer] = []
                    if layer not in cosine_sent_to_origin[lang]:
                        cosine_sent_to_origin[lang][layer] = []
                    if layer not in cosine_sent_to_origin_src[src_lang]:
                        cosine_sent_to_origin_src[src_lang][layer] = []

                    if metric == "cka":
                        mean_over_sent_src = np.array(mean_over_sent_src)
                        mean_over_sent_target = np.array(mean_over_sent_target)

                        cosine_sent_to_src[src_lang + "-" +
                                           lang][layer].append(
                                               kernel_CKA(
                                                   mean_over_sent_src,
                                                   mean_over_sent_target))
                        if args.compare_to_pretrained:
                            cosine_sent_to_origin[lang][layer].append(
                                kernel_CKA(mean_over_sent_target,
                                           mean_over_sent_target_origin))
                            if i_lang == 1:
                                cosine_sent_to_origin_src[src_lang][
                                    layer].append(
                                        kernel_CKA(mean_over_sent_src_origin,
                                                   mean_over_sent_src))
                                print(
                                    f"Measured EN TO ORIGIN {metric} {layer} {cosine_sent_to_origin_src[src_lang][layer][-1]} "
                                    + info_model)
                            print(
                                f"Measured LANG {lang} TO ORIGIN {metric} {layer} {cosine_sent_to_origin[lang][layer][-1]} "
                                + info_model)

                        print(
                            f"Measured {metric} {layer} {kernel_CKA(mean_over_sent_src,mean_over_sent_target)} "
                            + info_model)
                    else:
                        cosine_sent_to_src[
                            src_lang + "-" + lang][layer].append(
                                cosinus(mean_over_sent_src,
                                        mean_over_sent_target).item())

                    if former_layer is not None:
                        if layer not in cosine_sent_to_former_layer[lang]:
                            cosine_sent_to_former_layer[lang][layer] = []
                        if layer not in cosine_sent_to_former_layer_src[
                                src_lang]:
                            cosine_sent_to_former_layer_src[src_lang][
                                layer] = []
                        if metric == "cka":
                            cosine_sent_to_former_layer[lang][layer].append(
                                kernel_CKA(former_mean_target,
                                           mean_over_sent_target))
                            if i_lang == 1:
                                cosine_sent_to_former_layer_src[src_lang][
                                    layer].append(
                                        kernel_CKA(former_mean_src,
                                                   mean_over_sent_src))
                        else:
                            cosine_sent_to_former_layer[lang][layer].append(
                                cosinus(former_mean_target,
                                        mean_over_sent_target).item())
                            if i_lang == 1:
                                cosine_sent_to_former_layer_src[src_lang][
                                    layer].append(
                                        cosinus(former_mean_target,
                                                mean_over_sent_target).item())

                    former_layer = layer
                    former_mean_target = mean_over_sent_target
                    former_mean_src = mean_over_sent_src

    # summary
    print_all = True
    lang_i = 0
    src_lang_i = 0
    #for lang, cosine_per_layer in cosine_sent_to_src.items():
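    # Aggregate the accumulated similarities: print mean/std per (language, layer) pair and
    # store one report_template entry per measurement in report_ls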
    for lang, cosine_per_layer in cosine_sent_to_former_layer.items():
        layer_i = 0
        src_lang_i += 1
        for src_lang in src_lang_ls:
            lang_i += 1
            for layer, cosine_ls in cosine_per_layer.items():
                print(
                    f"Mean {metric} between {src_lang} and {lang} for {layer} is {np.mean(cosine_sent_to_src[src_lang+'-'+lang][layer])} std:{np.std(cosine_sent_to_src[src_lang+'-'+lang][layer])} measured on {len(cosine_sent_to_src[src_lang+'-'+lang][layer])} model  "
                    + info_model)
                if layer_i > 0 and print_all:

                    print(
                        f"Mean {metric} for {lang} beween {layer} and former is {np.mean(cosine_sent_to_former_layer[lang][layer])} std:{np.std(cosine_sent_to_former_layer[lang][layer])} measured on {len(cosine_sent_to_former_layer[lang][layer])} model "
                        + info_model)

                    report = report_template(
                        metric_val=metric,
                        subsample=lang + "_to_former_layer",
                        info_score_val=None,
                        score_val=np.mean(
                            cosine_sent_to_former_layer[lang][layer]),
                        n_sents=n_obs,
                        avg_per_sent=np.std(
                            cosine_sent_to_former_layer[lang][layer]),
                        n_tokens_score=n_obs * max_len,
                        model_full_name_val=model_full_name_val,
                        task="hidden_state_analysis",
                        evaluation_script_val="exact_match",
                        model_args_dir=args.init_args_dir,
                        token_type="word",
                        report_path_val=None,
                        data_val=layer,
                    )
                    report_ls.append(report)

                    if lang_i == 1:
                        print(
                            f"Mean {metric} for {lang} beween {layer} and former is {np.mean(cosine_sent_to_former_layer_src[src_lang][layer])} std:{np.std(cosine_sent_to_former_layer_src[src_lang][layer])} measured on {len(cosine_sent_to_former_layer_src[src_lang][layer])} model "
                            + info_model)

                        report = report_template(
                            metric_val=metric,
                            subsample=src_lang + "_to_former_layer",
                            info_score_val=None,
                            score_val=np.mean(
                                cosine_sent_to_former_layer_src[src_lang]
                                [layer]),
                            n_sents=n_obs,
                            avg_per_sent=np.std(
                                cosine_sent_to_former_layer_src[src_lang]
                                [layer]),
                            n_tokens_score=n_obs * max_len,
                            model_full_name_val=model_full_name_val,
                            task="hidden_state_analysis",
                            evaluation_script_val="exact_match",
                            model_args_dir=args.init_args_dir,
                            token_type="word",
                            report_path_val=None,
                            data_val=layer,
                        )
                        report_ls.append(report)

                layer_i += 1

                report = report_template(
                    metric_val=metric,
                    subsample=lang + "_to_" + src_lang,
                    info_score_val=None,
                    score_val=np.mean(cosine_sent_to_src[src_lang + '-' +
                                                         lang][layer]),
                    n_sents=n_obs,
                    #avg_per_sent=np.std(cosine_ls),
                    avg_per_sent=np.std(cosine_sent_to_src[src_lang + '-' +
                                                           lang][layer]),
                    n_tokens_score=n_obs * max_len,
                    model_full_name_val=model_full_name_val,
                    task="hidden_state_analysis",
                    evaluation_script_val="exact_match",
                    model_args_dir=args.init_args_dir,
                    token_type="word",
                    report_path_val=None,
                    data_val=layer,
                )

                report_ls.append(report)

                #
                if args.compare_to_pretrained:

                    print(
                        f"Mean {metric} for {lang} beween {layer} and origin model is {np.mean(cosine_sent_to_origin[lang][layer])} std:{np.std(cosine_sent_to_origin[lang][layer])} measured on {len(cosine_sent_to_origin[lang][layer])} model "
                        + info_model)
                    report = report_template(
                        metric_val=metric,
                        subsample=lang + "_to_origin",
                        info_score_val=None,
                        score_val=np.mean(cosine_sent_to_origin[lang][layer]),
                        n_sents=n_obs,
                        avg_per_sent=np.std(
                            cosine_sent_to_origin[lang][layer]),
                        n_tokens_score=n_obs * max_len,
                        model_full_name_val=model_full_name_val,
                        task="hidden_state_analysis",
                        evaluation_script_val="exact_match",
                        model_args_dir=args.init_args_dir,
                        token_type="word",
                        report_path_val=None,
                        data_val=layer)
                    report_ls.append(report)

                    if lang_i == 1:

                        print(
                            f"Mean {metric} for en beween {layer} and origin model is {np.mean(cosine_sent_to_origin_src[src_lang][layer])} std:{np.std(cosine_sent_to_origin_src[src_lang][layer])} measured on {len(cosine_sent_to_origin_src[src_lang][layer])} model "
                            + info_model)
                        report = report_template(
                            metric_val=metric,
                            subsample=src_lang + "_to_origin",
                            info_score_val=None,
                            score_val=np.mean(
                                cosine_sent_to_origin_src[src_lang][layer]),
                            n_sents=n_obs,
                            avg_per_sent=np.std(
                                cosine_sent_to_origin_src[src_lang][layer]),
                            n_tokens_score=n_obs * max_len,
                            model_full_name_val=model_full_name_val,
                            task="hidden_state_analysis",
                            evaluation_script_val="exact_match",
                            model_args_dir=args.init_args_dir,
                            token_type="word",
                            report_path_val=None,
                            data_val=layer)
                        report_ls.append(report)

        # break

    if args.report_dir is None:
        report_dir = PROJECT_PATH + f"/../../analysis/attention_analysis/report/{run_id}-report"
        os.mkdir(report_dir)
    else:
        report_dir = args.report_dir
    assert os.path.isdir(report_dir)
    with open(report_dir + "/report.json", "w") as f:
        json.dump(report_ls, f)

    overall_report = args.overall_report_dir + "/" + args.overall_label + "-grid-report.json"
    with open(overall_report, "r") as g:
        report_all = json.load(g)
        report_all.extend(report_ls)
    with open(overall_report, "w") as file:
        json.dump(report_all, file)

    print("{} {} ".format(REPORT_FLAG_DIR_STR, overall_report))
def get_vocab_size_and_dictionary_per_task(tasks,
                                           pos_dictionary=None,
                                           type_dictionary=None,
                                           vocab_bert_wordpieces_len=None,
                                           task_parameters=None,
                                           verbose=1):
    # TODO: should be factorized with the dictionary loading code
    if pos_dictionary is None and type_dictionary is None:
        assert "pos" not in tasks and "parsing" not in tasks, \
            "ERROR: 'pos' or 'parsing' is in tasks but the related dictionaries are None"
        printing("INFO : no dictionaries and voc_sizes needed",
                 verbose=verbose,
                 verbose_level=1)
        return None, None
    num_labels_per_task = OrderedDict()
    task_to_label_dictionary = OrderedDict()

    if "pos" in tasks:
        assert pos_dictionary is not None
        task_to_label_dictionary["pos-pos"] = pos_dictionary
        num_labels_per_task["pos-" + task_parameters["pos"]["label"][0]] = len(
            pos_dictionary.instance2index) + 1
    if "parsing" in tasks:
        assert type_dictionary is not None
        num_labels_per_task["parsing-types"] = len(
            type_dictionary.instance2index) + 1
        num_labels_per_task["parsing-heads"] = 0

        task_to_label_dictionary["parsing-types"] = type_dictionary
        task_to_label_dictionary["parsing-heads"] = "index"

    if "n_masks_mwe" in tasks:
        num_labels_per_task["n_masks_mwe-" +
                            task_parameters["n_masks_mwe"]["label"][0]] = 3
        task_to_label_dictionary["n_masks_mwe-n_masks_mwe"] = "index"

    if "mwe_detection" in tasks:
        num_labels_per_task["mwe_detection-" +
                            task_parameters["mwe_detection"]["label"][0]] = 2
        task_to_label_dictionary["mwe_detection-mwe_detection"] = "index"
    if "mwe_prediction" in tasks:
        assert vocab_bert_wordpieces_len is not None
        num_labels_per_task["mwe_prediction-" +
                            task_parameters["mwe_prediction"]["label"]
                            [0]] = vocab_bert_wordpieces_len
        task_to_label_dictionary[
            "mwe_prediction-" +
            task_parameters["mwe_prediction"]["label"][0]] = "index"
    if "mlm" in tasks:
        assert vocab_bert_wordpieces_len is not None
        num_labels_per_task[
            "mlm-" +
            task_parameters["mlm"]["label"][0]] = vocab_bert_wordpieces_len
        task_to_label_dictionary["mlm-" +
                                 task_parameters["mlm"]["label"][0]] = "index"

    if "normalize" in tasks:
        # TODO : add ones (will go along the prediction module : MLM + 1 token)
        num_labels_per_task["normalize-" + task_parameters["normalize"]
                            ["label"][0]] = vocab_bert_wordpieces_len + 1
        task_to_label_dictionary[
            "normalize-" + task_parameters["normalize"]["label"][0]] = "index"

    if "norm_not_norm" in tasks:
        # TODO : add ones (will go along the prediction module : MLM + 1 token)
        num_labels_per_task["norm_not_norm-" +
                            task_parameters["norm_not_norm"]["label"][0]] = 2
        task_to_label_dictionary[
            "norm_not_norm-" +
            task_parameters["norm_not_norm"]["label"][0]] = "index"

    if "norm_n_masks" in tasks:
        # TODO : add ones (will go along the prediction module : MLM + 1 token)
        num_labels_per_task["norm_n_masks-" +
                            task_parameters["norm_n_masks"]["label"][0]] = 5
        task_to_label_dictionary[
            "norm_n_masks-" +
            task_parameters["norm_n_masks"]["label"][0]] = "index"

    return num_labels_per_task, task_to_label_dictionary
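

# Illustration only (never called): a minimal usage sketch of
# get_vocab_size_and_dictionary_per_task with stand-in objects. The real dictionary
# classes and task_parameters come from elsewhere in the repository; the sketch only
# assumes that a dictionary exposes `instance2index` and that
# task_parameters["pos"]["label"][0] == "pos".
def _example_get_vocab_size_and_dictionary_per_task():
    class _FakePosDictionary:
        instance2index = {"NOUN": 0, "VERB": 1, "ADJ": 2}

    num_labels_per_task, task_to_label_dictionary = get_vocab_size_and_dictionary_per_task(
        tasks=["pos"],
        pos_dictionary=_FakePosDictionary(),
        task_parameters={"pos": {"label": ["pos"]}})
    # num_labels_per_task -> {"pos-pos": 4}   (len(instance2index) + 1)
    # task_to_label_dictionary -> {"pos-pos": <_FakePosDictionary instance>}
    return num_labels_per_task, task_to_label_dictionary
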
def main(args, dict_path, model_dir):
    model, tokenizer, run_id = load_all_analysis(args, dict_path, model_dir)
    
    if args.compare_to_pretrained:
        print("Loading Pretrained model also for comparison with pretrained")
        args_origin = args_attention_analysis()
        args_origin = args_preprocess_attention_analysis(args_origin)
        args_origin.init_args_dir = None
        args_origin, dict_path_0, model_dir_0 = get_dirs(args_origin)
        args_origin.model_id_pref += "again"
        model_origin, tokenizer_0, _ = load_all_analysis(args_origin, dict_path_0, model_dir_0)
        model_origin.eval()
        print("seco,")
    # only allow output of the model to be hidden states here
    print("Checkpoint loaded")
    assert not args.output_attentions
    assert args.output_all_encoded_layers and args.output_hidden_states_per_head

    data = ["I am here", "How are you"]
    model.eval()
    n_obs = args.n_sent
    max_len = args.max_seq_len
    lang_ls = args.raw_text_code

    lang = ["fr_pud",  # "de_pud", "ru_pud", "tr_pud", "id_pud", "ar_pud", "pt_pud",  "es_pud", "fi_pud",
            # "it_pud", "sv_pud", "cs_pud", "pl_pud", "hi_pud", "zh_pud", "ko_pud", "ja_pud","th_pud"
            ]
    src_lang_ls = ["tr_imst", "en_ewt", #"ja_gsd", "ar_padt",  #"en_pud", "tr_pud", "ru_pud",# "ar_pud", #"de_pud", "ko_pud",
                   "ug_udt"
                    ]  # , "fr_pud", "ru_pud", "ar_pud"]
    src_lang_ls = ["tr_dedup", "az_100k_shuff",
                    "en_100k", "kk_100k_shuff", #"hu_dedup", #"ar_padt",
                   # "en_pud", "tr_pud", "ru_pud",# "ar_pud", #"de_pud", "ko_pud",
                   #"ckb_dedup",# "ja_dedup_200k",
                   #"ar_dedup_200k", "fa_dedup_200k", 
                   "ug_udt",
                   ]
    src_lang_ls = [#"ar_oscar", "tr_dedup", "az_100k_shuff", "fa_dedup_200k",
                   # "it_oscar", "en_oscar", #"hu_dedup", #"ar_padt",
                   "ar_oscar","de_oscar","en_oscar","fa_oscar" ,"fi_oscar" ,"fr_oscar", "he_oscar", "hi_oscar","hu_oscar","it_oscar","ja_oscar", "ko_oscar", "ru_oscar","tr_oscar", 
                   ]
    src_lang_ls.append(args.target_lang)

    def add_demo(src_lang_ls):
        for i in range(len(src_lang_ls)):
            if src_lang_ls[i]!="mt_mudt":
                src_lang_ls[i] += "_demo"
        return src_lang_ls

    #add_demo(src_lang_ls)
    

    # target is last
    target_class_ind = len(src_lang_ls)-1
    target_lang = src_lang_ls[target_class_ind]
    #to_class = [""]
    set_ = "test"
    #set_ = "test-demo"
    #print("Loading data...")

    #data_en = load_data(DATA_UD + f"/{src_lang_ls[0]}-ud-{set_}.conllu", line_filter="# text = ")

    #id_start_start_class, id_end_target_class = get_id_sent_target(target_class_ind, data_target_dic)

    # reg = linear_model.LogisticRegression()
    # just to get the layer keys
    layer_all = get_hidden_representation(data, model, tokenizer, max_len=max_len)
    # removed hidden_per_layer

    assert len(layer_all) == 1
    # assert len(layer_all) == 2, "ERROR should only have hidden_per_layer and hidden_per_head_layer"
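    # get_hidden_representation is assumed to return a tuple/list of per-layer
    # dictionaries ({layer_name: hidden_states}); with the current output flags only
    # one element comes back, which is what the assert above checks.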

    report_ls = []
    accuracy_dic = OrderedDict()
    sampling = args.sampling
    metric = args.similarity_metric
    if metric == "cka":
        pad_below_max_len, output_dic = False, True
    else:
        pad_below_max_len, output_dic = False, True
    assert metric in ["cos", "cka"]

    batch_size = args.batch_size #len(data_en) // 4

    task_tuned = "No"

    if args.init_args_dir is None:
        # args.init_args_dir =
        id_model = f"{args.bert_model}-init-{args.random_init}"

        hyperparameters = OrderedDict([("bert_model", args.bert_model),
                                       ("random_init", args.random_init),
                                       ("not_load_params_ls", args.not_load_params_ls),
                                       ("dict_path", dict_path),
                                       ("model_id", id_model), ])
        info_checkpoint = OrderedDict([("epochs", 0), ("batch_size", batch_size),
                                       ("train_path", 0), ("dev_path", 0), ("num_labels_per_task", 0)])

        args.init_args_dir = write_args(os.environ.get("MT_NORM_PARSE", "./"), model_id=id_model,
                                        info_checkpoint=info_checkpoint,
                                        hyperparameters=hyperparameters, verbose=1)
        print("args_dir checkout ", args.init_args_dir)
        model_full_name_val = task_tuned + "-" + id_model
    else:
        with open(args.init_args_dir, 'r') as args_file:
            argument = json.load(args_file)
        task_tuned = argument["hyperparameters"]["tasks"][0][0] \
            if "wiki" not in argument["info_checkpoint"]["train_path"] else "ner"
        model_full_name_val = task_tuned + "-" + args.init_args_dir.split("/")[-1]

    if args.analysis_mode == "layer":
        studied_ind = 0
    elif args.analysis_mode == "layer_head":
        studied_ind = 1
    else:
        raise (Exception(f"args.analysis_mode : {args.analysis_mode} corrupted"))


    output_dic = True
    pad_below_max_len = False
    max_len = 500

    sent_embeddings_per_lang = OrderedDict()
    sent_text_per_lang = OrderedDict()
    pick_layer = ["layer_6"]
    n_batch = args.n_batch
    #assert n_batch==1, "ERROR not working otherwise ! "
    demo = 0
    assert args.n_sent_extract <= args.batch_size * args.n_batch * (len(src_lang_ls) - 1), "ERROR not enough data provided for the selection"
    
    print(f"Starting processing : {n_batch} batch of size {batch_size}")

    def sanity_len_check(src_lang_ls, n_sent_per_lang):
        for src_lang in src_lang_ls:
            
            dir_data = OSCAR + f"/{src_lang}-train.txt"
            with open(dir_data) as data_file:
                num_lines = sum(1 for _ in data_file)
            print(f"Sanity check: {src_lang} should have at least {n_sent_per_lang} sentences, it has {num_lines}")
            assert num_lines >= n_sent_per_lang, f"ERROR {src_lang}: {num_lines} < {n_sent_per_lang} required sentences"
    
    sanity_len_check(src_lang_ls[:-1], n_sent_per_lang=args.batch_size * args.n_batch)
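
    # The loop below streams n_batch batches of batch_size sentences per language:
    # for each batch it extracts hidden states with get_hidden_representation,
    # keeps only the layers listed in pick_layer, mean-pools each sentence over its
    # tokens (dropping the first and last special tokens), and accumulates the result
    # in sent_embeddings_per_lang[layer][src_lang] for the clustering/selection step.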



    for i_data in tqdm(range(n_batch)):
        if demo:
            batch_size = 50
            n_batch = 1
            if i_data > 0:
                break
        for src_lang in tqdm(src_lang_ls):
            print(f"Loading lang {src_lang} batch size {batch_size}")
            
            #data_en = load_data(DATA_UD + f"/{src_lang}-ud-{set_}.conllu", line_filter="# text = ")
            #en_batch =  # data[i_data:i_data + batch_size]
            try:
                dir_data = get_dir_data(set="train", data_code=src_lang)
                filter_row = "# text = "
            except Exception as e:
                dir_data = OSCAR + f"/{src_lang}-train.txt"
                filter_row = ""
                print(f"{src_lang} not supported or missing : data defined as {dir_data} filter empty")
            try:
                en_batch = load_data(dir_data, line_filter=filter_row, id_start=i_data*batch_size, id_end=(i_data+1)*batch_size)
            except Exception as e:
                print(f"ERROR: cannot load data {dir_data} skipping")
                if i_data==0:
                    raise(Exception(e))
                continue
            if en_batch is None:
                print(f"lang {src_lang} reading {i_data*batch_size} seems empty so skipping")
                continue

            if src_lang not in sent_text_per_lang:
                sent_text_per_lang[src_lang] = []
            sent_text_per_lang[src_lang].extend(en_batch)

            all = get_hidden_representation(en_batch, model, tokenizer, pad_below_max_len=pad_below_max_len,
                                            max_len=max_len, output_dic=output_dic)

            analysed_batch_dic_en = all[studied_ind]
            i_lang = 0

            if args.compare_to_pretrained:
                all_origin = get_hidden_representation(en_batch, model_origin, tokenizer_0,
                                                       pad_below_max_len=pad_below_max_len,
                                                       max_len=max_len, output_dic=output_dic)
                analysed_batch_dic_src_origin = all_origin[studied_ind]

            for layer in analysed_batch_dic_en:
                if layer not in pick_layer:
                    continue
                else:
                    print(f"Picking {pick_layer} layer")
                print(f"Starting layer", {layer})
                # get average for sentence removing first and last special tokens
                if layer not in sent_embeddings_per_lang:
                    sent_embeddings_per_lang[layer] = OrderedDict()
                if src_lang not in sent_embeddings_per_lang[layer]:
                    sent_embeddings_per_lang[layer][src_lang] = []
                if output_dic:
                    mean_over_sent_src = []
                    #mean_over_sent_target = []
                    #mean_over_sent_target_origin = []
                    mean_over_sent_src_origin = []
                    for i_sent in range(len(analysed_batch_dic_en[layer])):
                        # removing the first and last special tokens before averaging
                        mean_over_sent_src.append(
                            np.array(analysed_batch_dic_en[layer][i_sent][0, 1:-1, :].cpu().mean(dim=0)))
                        #mean_over_sent_target.append(
                        #    np.array(analysed_batch_dic_target[layer][i_sent][0, 1:-1, :].mean(dim=0)))
                        if args.compare_to_pretrained:
                        #    mean_over_sent_target_origin.append(
                        #        np.array(analysed_batch_dic_target_origin[layer][i_sent][0, 1:-1, :].mean(dim=0)))
                            if i_lang == 1:
                                mean_over_sent_src_origin.append(
                                    np.array(analysed_batch_dic_src_origin[layer][i_sent][0, 1:-1, :].mean(dim=0)))
                    if args.compare_to_pretrained:
                    #    mean_over_sent_target_origin = np.array(mean_over_sent_target_origin)
                        if i_lang == 1:
                            mean_over_sent_src_origin = np.array(mean_over_sent_src_origin)
                    mean_over_sent_src = np.array(mean_over_sent_src)
                    #mean_over_sent_target = np.array(mean_over_sent_target)
                else:
                    mean_over_sent_src = analysed_batch_dic_en[layer][:, 1:-1, :].mean(dim=1)
                    #mean_over_sent_target = analysed_batch_dic_target[layer][:, 1:-1, :].mean(dim=1)

                sent_embeddings_per_lang[layer][src_lang].append(mean_over_sent_src)

    def get_id_sent_target(target_class_ind, data_target_dic):
        n_sent_total = 0

        assert target_class_ind <= len(data_target_dic)
        for ind_class, lang in enumerate(src_lang_ls):
            n_sent_total += len(data_target_dic[lang])
            if ind_class == target_class_ind:
                n_sent_class = len(data_target_dic[lang])
                id_start_start_class = n_sent_total
                id_end_target_class = n_sent_total + n_sent_class
        return id_start_start_class, id_end_target_class

    clustering = "distance"

    if clustering in ["gmm", "spectral"]:
        concat_train,  concat_test, y_train, y_test, lang2id = concat_all_lang_space_split_train_test(sent_embeddings_per_lang, src_lang_ls, pick_layer)
        #X = np.array(concat).squeeze(1)
        X_train = np.array(concat_train)
        X_test = np.array(concat_test)

        if len(X_train.shape) > 2:

            X_train = X_train.reshape(X_train.shape[0]*X_train.shape[1],-1)
            X_test = X_test.reshape(X_test.shape[0]*X_test.shape[1],-1)
        if clustering == "gmm":
            model = mixture.GaussianMixture(n_components=len(src_lang_ls)-1, covariance_type='full')
            model.fit(X_train)
            model_based_clustering = True
        elif clustering == "spectral":
            model = cluster.SpectralClustering(n_clusters=len(src_lang_ls))
            model.fit(X_train)
            model_based_clustering = True

    elif clustering == "distance":
        # concat batch_size

        for layer in sent_embeddings_per_lang:
            assert len(sent_embeddings_per_lang[layer]) > 1, "ERROR: distance measure requires more than one language"
            for lang in sent_embeddings_per_lang[layer]:
                arr = np.array(sent_embeddings_per_lang[layer][lang])
                if arr.shape[0] != n_batch:
                    print(f"WARNING: {lang} has shape {arr.shape} (expected first dim {n_batch}), reshaping to {arr.shape[0]*arr.shape[1]} sentences")
                sent_embeddings_per_lang[layer][lang] = arr.reshape(arr.shape[0] * arr.shape[1], -1)
            assert sent_embeddings_per_lang[layer][lang].shape[0] == len(sent_text_per_lang[lang]), f"ERROR lang {lang} layer {layer}  {sent_embeddings_per_lang[layer][lang].shape}[0]<>{len(sent_text_per_lang[lang])}"

        sent_embeddings_per_lang_train, sent_embeddings_per_lang_test, sent_text_per_lang = \
            split_train_test(sent_embeddings_per_lang, sent_text_per_lang,
                             keep_text_test=True, target_lang=target_lang,
                             target_lang_no_test=True,
                             prop_train=1 / 20)
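        # split_train_test is assumed to keep prop_train (here 1/20) of each
        # language's embeddings for centroid estimation and the rest for scoring,
        # with the target language excluded from the test side (target_lang_no_test=True).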

        centroid_train, ls_lang = get_centroid(sent_embeddings_per_lang_train, target_lang=target_lang, only_target_centoid=False)
        # outputting for each sentence (with layer x lang information)
        print("ls_lang", ls_lang)
        closest_lang, score_to_target_test = get_closest_centroid(sent_embeddings_per_lang_test, centroid_train, ls_lang, ind_lang_target=target_class_ind)
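        # get_centroid / get_closest_centroid are assumed to (i) average the train
        # embeddings of each language into one centroid per layer and (ii) score each
        # test sentence by its similarity to those centroids, returning the closest
        # language and the raw similarity to the target-language centroid.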

        get_stat_distance(closest_lang, ls_lang, target_lang)

        count_n_extracted_sent = 0
        for layer in score_to_target_test:
            for lang in score_to_target_test[layer]:
                count_n_extracted_sent += len(score_to_target_test[layer][lang])
        print(f"Cosine extracted sent {count_n_extracted_sent}")
        test_sent_extracted, index_test_extracted, info_per_layer_select = \
            get_closest_n_sent(n_sent=args.n_sent_extract, score_to_target=score_to_target_test,
                               sent_text_per_lang=sent_text_per_lang, lang_ls=src_lang_ls,
                               target_lang=target_lang)
        get_iou_inter(index_test_extracted)


        dir_file = os.path.join(os.environ.get("OSCAR", "/Users/bemuller/Documents/Work/INRIA/dev/data"),"data_selected")
        #dir_file = "/Users/bemuller/Documents/Work/INRIA/dev/data/data_selected"
        write_down_selected(test_sent_extracted, info_per_layer_select, dir_file, id=f"select-{args.overall_label}-{args.bert_model}-{target_lang}-n_sent-{args.n_sent_extract}")


    if clustering in ["gmm", "spectral"]:
        target_class_ind = X_train
        predict_proba_train = model.predict_proba(X_train)
        predict_train = model.predict(X_train)
        predict_proba = model.predict_proba(X_test)
        predict_test = model.predict(X_test)

        def get_most_common_per_class(predict, lang2id):
            " for each class : finding the clustering predicting using majority vote "
            id_class_start = 0
            id_class_end = 0
            pred_label_to_real_label = {}
            for lang in lang2id:
                id_class_end += lang2id[lang]["n_sent_train"]

                pred_class = predict[id_class_start:id_class_end]

                assert len(pred_class)>0
                id_class_start = id_class_end
                from collections import Counter
                pred_class_counter = Counter(pred_class)
                most_common_pred = pred_class_counter.most_common()[0][0]
                lang2id[lang]["pred_label"] = most_common_pred
                if most_common_pred in pred_label_to_real_label:
                    print(f"WARNING: cluster label {most_common_pred} is the majority label of more than one language class")
                pred_label_to_real_label[most_common_pred] = lang2id[lang]["id"]
            return lang2id, pred_label_to_real_label

        lang2id, pred_label_to_real_label = get_most_common_per_class(predict_train, lang2id)
        print(f"V metric train {v_measure_score(predict_train, y_train)}")
        print(f"V metric test {v_measure_score(predict_test, y_test)}")

        def adapt_label(pred_label_to_real_label, pred):
            " based on majority bvote prediction : adapt prediction set to real label set"
            pred_new = []
            for label_pred in pred:
                if label_pred not in pred_label_to_real_label:
                    print("Warning : pred label not associated to any true label")
                pred_new.append(pred_label_to_real_label.get(label_pred, label_pred))
            return pred_new

        def print_report(report, src_lang_ls, lang2id):
            for lang in src_lang_ls:
                id_label = lang2id[lang]["id"]
                print(f"Lang {lang} summary {report[str(id_label)]}")

            print(f"Macro Avg {lang} summary {report['macro avg']}")

        pred_new_train = adapt_label(pred_label_to_real_label, predict_train)
        report = classification_report(y_pred=pred_new_train, y_true=y_train, output_dict=True)
        print_report(report, src_lang_ls, lang2id)

        pred_new_test = adapt_label(pred_label_to_real_label, predict_test)
        report = classification_report(y_pred=pred_new_test, y_true=y_test, output_dict=True)

        print_report(report, src_lang_ls, lang2id)

        #print(predict_proba_train, predict_proba)



    # based on this --> for a given source set of sentences (e.g. Uyghur sentences):
    # 1 - find the cluster id of the Uyghur sentences
    # 2 - get the top-x sentences that have a high probability for that cluster
    # 3 - print them to check that the selection makes sense
    # do it for 1000 Uyghur sentences, 1000k for 10 other languages
    # then do the same and compare the overlap per layer


    # summary
    print_all = True
    lang_i = 0
    src_lang_i = 0