def write_args(dir,
               model_id,
               checkpoint_dir=None,
               hyperparameters=None,
               info_checkpoint=None,
               verbose=1):

    args_dir = os.path.join(dir, "{}-args.json".format(model_id))
    if os.path.isfile(args_dir):
        info = "updated"
        args = json.load(open(args_dir, "r"))
        args["checkpoint_dir"] = checkpoint_dir
        args["info_checkpoint"] = info_checkpoint
        json.dump(args, open(args_dir, "w"))
    else:
        assert hyperparameters is not None, "REPORT : args.json created for the first time : hyperparameters dict required"
        #assert info_checkpoint is None, "REPORT : args. created for the first time : no checkpoint yet "
        info = "new"
        json.dump(
            OrderedDict([("checkpoint_dir", checkpoint_dir),
                         ("hyperparameters", hyperparameters),
                         ("info_checkpoint", info_checkpoint)]),
            open(args_dir, "w"))
    printing("MODEL args.json {} written {} ".format(info, args_dir),
             verbose_level=1,
             verbose=verbose)
    return args_dir
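
# Usage sketch (hypothetical paths and ids; os, json, OrderedDict and the
# project's printing() helper are assumed to be imported as above). The first
# call creates <model_id>-args.json with the hyperparameters; a later call only
# updates checkpoint_dir and info_checkpoint in the existing file:
#
#   args_path = write_args("/tmp/run-dir", model_id="model-0",
#                          hyperparameters={"lr": 1e-4, "epochs": 3})
#   write_args("/tmp/run-dir", model_id="model-0",
#              checkpoint_dir="/tmp/run-dir/model-0-checkpoint.pt",
#              info_checkpoint={"epochs": 3})
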
# Example #2
def save(self, output_directory, name=None):
    saving_name = name if name else self.__name
    try:
        dic_dir = os.path.join(output_directory, saving_name + ".json")
        if os.path.isfile(dic_dir):
            print("Overwriting dictionary {}".format(dic_dir))
        json.dump(self.get_content(), open(dic_dir, 'w'), indent=4)
    except Exception as e:
        raise RuntimeError("Dictionary is not saved: %s" % repr(e))
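
# Usage sketch (hypothetical: `Dictionary` stands for whichever owner class
# defines the get_content() method and the private __name attribute used above):
#
#   d = Dictionary(name="word_dictionary")
#   d.save("/tmp/dictionaries")                  # -> /tmp/dictionaries/word_dictionary.json
#   d.save("/tmp/dictionaries", name="backup")   # -> /tmp/dictionaries/backup.json
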
def main(args, dict_path, model_dir):

    encoder = BERT_MODEL_DIC[args.bert_model]["encoder"]
    vocab_size = BERT_MODEL_DIC[args.bert_model]["vocab_size"]
    voc_tokenizer = BERT_MODEL_DIC[args.bert_model]["vocab"]

    tokenizer = eval(BERT_MODEL_DIC[args.bert_model]["tokenizer"])
    random.seed(args.seed)

    if args.model_id_pref is None:
        run_id = str(uuid4())[:4]
    else:
        run_id = args.model_id_pref + "1"

    if args.init_args_dir is None:
        dict_path += "/" + run_id
        os.mkdir(dict_path)
    tokenizer = tokenizer.from_pretrained(voc_tokenizer,
                                          do_lower_case=args.case == "lower",
                                          shuffle_bpe_embedding=False)
    mask_id = tokenizer.encode([
        "[MASK]"
    ])[0] if args.bert_model == "bert_base_multilingual_cased" else None

    _dev_path = args.dev_path if args.dev_path is not None else args.train_path
    word_dictionary, word_norm_dictionary, char_dictionary, pos_dictionary, \
    xpos_dictionary, type_dictionary = \
        conllu_data.load_dict(dict_path=dict_path,
                              train_path=args.train_path if args.init_args_dir is None else None,
                              dev_path=args.dev_path if args.init_args_dir is None else None,
                              test_path=args.test_paths if args.init_args_dir is None else None,
                              word_embed_dict={},
                              dry_run=False,
                              expand_vocab=False,
                              word_normalization=True,
                              force_new_dic=False,
                              tasks=args.tasks,
                              pos_specific_data_set=None,
                              #pos_specific_data_set=args.train_path[1] if len(args.tasks) > 1 and len(
                              #    args.train_path) > 1 and "pos" in args.tasks else None,
                              case=args.case,
                              # if not normalize pos or parsing in tasks we don't need dictionary
                              do_not_fill_dictionaries=len(set(["normalize", "pos", "parsing"]) & set(
                                  [task for tasks in args.tasks for task in tasks])) == 0,
                              add_start_char=True if args.init_args_dir is None else None,
                              verbose=1)

    num_labels_per_task, task_to_label_dictionary = get_vocab_size_and_dictionary_per_task(
        [task for tasks in args.tasks for task in tasks],
        vocab_bert_wordpieces_len=vocab_size,
        pos_dictionary=pos_dictionary,
        type_dictionary=type_dictionary,
        task_parameters=TASKS_PARAMETER)

    model = make_bert_multitask(
        args=args,
        pretrained_model_dir=model_dir,
        init_args_dir=args.init_args_dir,
        tasks=[task for tasks in args.tasks for task in tasks],
        mask_id=mask_id,
        encoder=encoder,
        num_labels_per_task=num_labels_per_task)

    def get_n_params(model):
        # total number of parameters (trainable or not)
        pp = 0
        for p in list(model.parameters()):
            nn = 1
            for s in list(p.size()):
                nn = nn * s
            pp += nn
        return pp
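    # a shorter PyTorch equivalent (hypothetical one-liner, not used below):
    # n_params_total = sum(p.numel() for p in model.parameters())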

    # total vs. trainable parameter counts, kept around for inspection
    param = get_n_params(model)
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])

    data = ["I am here", "How are you"]
    model.eval()
    n_obs = args.n_sent
    max_len = args.max_seq_len
    lang_ls = args.raw_text_code
    data_all, y_all = load_lang_ls(DATA_UD_RAW, lang_ls=lang_ls)

    reg = linear_model.LogisticRegression()
    X_train = OrderedDict()
    X_test = OrderedDict()
    y_train = OrderedDict()
    y_test = OrderedDict()
    # run once just to get the layer-head keys
    layer_head_att = get_hidden_representation(data,
                                               model,
                                               tokenizer,
                                               max_len=max_len,
                                               output_dic=False,
                                               pad_below_max_len=True)
    layer_head_att = layer_head_att[0]
    report_ls = []
    accuracy_dic = OrderedDict()
    sampling = args.sampling
    for ind, layer_head in enumerate(list(layer_head_att.keys())):
        report = OrderedDict()
        accuracy_ls = []
        # walk the layer-head keys in reverse order
        layer_head = list(
            layer_head_att.keys())[len(list(layer_head_att.keys())) - ind - 1]
        for _ in range(sampling):
            sample_ind = random.sample(population=range(len(data_all)),
                                       k=n_obs)
            sample_ind_test = random.sample(population=range(len(data_all)),
                                            k=n_obs)

            data = data_all[sample_ind]
            y = y_all[sample_ind]

            all = get_hidden_representation(data,
                                            model,
                                            tokenizer,
                                            max_len=max_len,
                                            output_dic=False,
                                            pad_below_max_len=True)

            layer_head_att = all[0]

            #pdb.set_trace()
            def reshape_x(z):
                # flatten (n_sent, seq_len, dim) -> (n_sent * seq_len, dim)
                return np.array(z.view(z.size(0) * z.size(1), -1))

            def reshape_y(z, n_seq):
                """Repeat each sentence-level label n_seq times,
                e.g. reshape_y([0, 1], 3) -> [0, 0, 0, 1, 1, 1]."""
                new_z = []
                for _z in z:
                    new_z.extend([_z for _ in range(n_seq)])
                return np.array(new_z)

            #X_train[layer_head] = np.array(layer_head_att[layer_head].view(layer_head_att[layer_head].size(0), -1).transpose(1,0))
            X_train[layer_head] = reshape_x(layer_head_att[layer_head])
            y_train[layer_head] = reshape_y(y, max_len)
            #db.set_trace()
            #y_train[layer_head] = y

            reg.fit(X=X_train[layer_head], y=y_train[layer_head])

            # test
            data_test = data_all[sample_ind_test]
            layer_head_att_test = get_hidden_representation(
                data_test,
                model,
                tokenizer,
                max_len=max_len,
                output_dic=False,
                pad_below_max_len=True)
            X_test[layer_head] = reshape_x(layer_head_att_test[layer_head])
            y_test[layer_head] = reshape_y(y_all[sample_ind_test], max_len)

            y_pred = reg.predict(X_test[layer_head])

            accuracy = np.sum(
                (y_test[layer_head] == y_pred)) / len(y_test[layer_head])
            accuracy_ls.append(accuracy)

        # keep both the per layer-head accuracy and a per-layer aggregate
        accuracy_dic[layer_head] = np.mean(accuracy_ls)
        layer = layer_head.split("-")[0]
        if layer not in accuracy_dic:
            accuracy_dic[layer] = []
        accuracy_dic[layer].append(np.mean(accuracy_ls))

        print(
            f"Regression {layer_head} Accuracy test {np.mean(accuracy_ls)} on {n_obs * max_len}"
            f" word sample from {len(lang_ls)} languages task {args.tasks} args {'/'.join(args.init_args_dir.split('/')[-2:]) if args.init_args_dir is not None else None} "
            f"bert {args.bert_model} random init {args.random_init} std {np.std(accuracy_ls)} sampling {len(accuracy_ls)}=={sampling}"
        )

        #report["model_type"] = args.bert_model if args.init_args_dir is None else args.tasks[0][0]+"-tune"
        #report["accuracy"] = np.mean(accuracy_ls)
        #report["sampling"] = len(accuracy_ls)
        #report["std"] = np.std(accuracy_ls)
        #report["n_sent"] = n_obs
        #report["n_obs"] = n_obs*max_len

        report = report_template(
            metric_val="accuracy",
            subsample=",".join(lang_ls),
            info_score_val=sampling,
            score_val=np.mean(accuracy_ls),
            n_sents=n_obs,
            avg_per_sent=np.std(accuracy_ls),
            n_tokens_score=n_obs * max_len,
            model_full_name_val=run_id,
            task="attention_analysis",
            evaluation_script_val="exact_match",
            model_args_dir=args.init_args_dir
            if args.init_args_dir is not None else args.random_init,
            token_type="word",
            report_path_val=None,
            data_val=layer_head)
        report_ls.append(report)

        # break

    for key in accuracy_dic:
        print(
            f"Summary {key} {np.mean(accuracy_dic[key])} model word sample from {len(lang_ls)} languages task {args.tasks} args {'/'.join(args.init_args_dir.split('/')[-2:]) if args.init_args_dir is not None else None} "
            f"bert {args.bert_model} random init {args.random_init} std {np.std(accuracy_ls)} sampling {len(accuracy_ls)}=={sampling}"
        )

    if args.report_dir is None:
        report_dir = PROJECT_PATH + f"/../../analysis/attention_analysis/report/{run_id}-report"
        os.mkdir(report_dir)
    else:
        report_dir = args.report_dir
    assert os.path.isdir(report_dir)
    with open(report_dir + "/report.json", "w") as f:
        json.dump(report_ls, f)
    overall_report = args.overall_report_dir + "/" + args.overall_label + "-grid-report.json"
    with open(overall_report, "r") as g:
        report_all = json.load(g)
        report_all.extend(report_ls)
    with open(overall_report, "w") as file:
        json.dump(report_all, file)

    print("{} {} ".format(REPORT_FLAG_DIR_STR, overall_report))
# Example #4
def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError(
                "HEAD request failed for url {} with status code {}".format(
                    url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url,
                        temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name,
                        cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path
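
# Minimal, self-contained sketch of the pattern get_from_cache() relies on:
# write the payload to a temporary file first, then copy it into the cache once
# the transfer is complete, so an interrupted download never leaves a corrupt
# cache entry. The payload and paths are synthetic stand-ins for the real
# http_get / s3_get downloaders.
import json
import os
import shutil
import tempfile


def cache_payload(payload, cache_dir, filename):
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, filename)
    with tempfile.NamedTemporaryFile() as temp_file:
        temp_file.write(payload)   # stands in for the actual download
        temp_file.flush()          # make sure everything is on disk
        temp_file.seek(0)          # copyfileobj starts at the current position
        with open(cache_path, "wb") as cache_file:
            shutil.copyfileobj(temp_file, cache_file)
    with open(cache_path + ".json", "w", encoding="utf-8") as meta_file:
        json.dump({"filename": filename}, meta_file)
    return cache_path

# cache_payload(b"hello", "/tmp/demo_cache", "hello.bin")
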
# Example #5
def run(args,
        n_observation_max_per_epoch_train,
        vocab_size,
        model_dir,
        voc_tokenizer,
        auxilliary_task_norm_not_norm,
        null_token_index,
        null_str,
        tokenizer,
        n_observation_max_per_epoch_dev_test=None,
        run_mode="train",
        dict_path=None,
        end_predictions=None,
        report=True,
        model_suffix="",
        description="",
        saving_every_epoch=10,
        model_location=None,
        model_id=None,
        report_full_path_shared=None,
        skip_1_t_n=False,
        heuristic_test_ls=None,
        remove_mask_str_prediction=False,
        inverse_writing=False,
        extra_label_for_prediction="",
        random_iterator_train=True,
        bucket_test=False,
        must_get_norm_test=True,
        early_stoppin_metric=None,
        subsample_early_stoping_metric_val=None,
        compute_intersection_score_test=True,
        threshold_edit=3,
        name_with_epoch=False,
        max_token_per_batch=200,
        encoder=None,
        debug=False,
        verbose=1):
    """
    Wrapper for training/prediction/evaluation

    2 modes : train (will train using train and dev iterators with test at the end on test_path)
              test : only test at the end : requires all directories to be created
    :return: the trained (or reloaded) model
    """
    assert run_mode in ["train", "test"
                        ], "ERROR run mode {} corrupted ".format(run_mode)
    input_level_ls = ["wordpiece"]
    assert early_stoppin_metric is not None and subsample_early_stoping_metric_val is not None, "ERROR : assert early_stoppin_metric should be defined and subsample_early_stoping_metric_val "
    if n_observation_max_per_epoch_dev_test is None:
        n_observation_max_per_epoch_dev_test = n_observation_max_per_epoch_train
    printing("MODEL : RUNNING IN {} mode",
             var=[run_mode],
             verbose=verbose,
             verbose_level=1)
    printing(
        "WARNING : casing was set to {} (this should be consistent at train and test)",
        var=[args.case],
        verbose=verbose,
        verbose_level=2)

    if len(args.tasks) == 1:
        printing("INFO : MODEL : 1 set of simultaneous tasks {}".format(
            args.tasks),
                 verbose=verbose,
                 verbose_level=1)

    if run_mode == "test":
        assert args.test_paths is not None and isinstance(
            args.test_paths, list)
    if run_mode == "train":
        printing("CHECKPOINTING info : "
                 "saving model every {}",
                 var=saving_every_epoch,
                 verbose=verbose,
                 verbose_level=1)

    use_gpu = use_gpu_(use_gpu=None, verbose=verbose)

    def get_commit_id():
        repo = git.Repo(os.path.dirname(os.path.realpath(__file__)),
                        search_parent_directories=True)
        git_commit_id = str(repo.head.commit)  # object.hexsha
        return git_commit_id

    if verbose > 1:
        print(f"GIT ID : {get_commit_id()}")

    train_data_label = get_dataset_label(args.train_path, default="train")

    iter_train = 0
    iter_dev = 0
    row = None
    writer = None

    printout_allocated_gpu_memory(verbose, "{} starting all".format(model_id))

    if run_mode == "train":
        if os.path.isdir(args.train_path[0]) and len(args.train_path) == 1:
            data_sharded = args.train_path[0]
            printing(
                "INFO args.train_path is directory so not rebuilding shards",
                verbose=verbose,
                verbose_level=1)
        elif os.path.isdir(args.train_path[0]):
            raise Exception(
                "{} is a directory but {} train paths were provided : not supported".format(
                    args.train_path[0], len(args.train_path)))
        else:
            data_sharded = None
        assert model_location is None and model_id is None, \
            "ERROR : model_location and model_id must be None when creating a new model"

        model_id, model_location, dict_path, tensorboard_log, end_predictions, data_sharded \
            = setup_repoting_location(model_suffix=model_suffix, data_sharded=data_sharded,
                                      root_dir_checkpoints=CHECKPOINT_BERT_DIR,
                                      shared_id=args.overall_label, verbose=verbose)
        hyperparameters = get_hyperparameters_dict(
            args,
            args.case,
            random_iterator_train,
            seed=args.seed,
            verbose=verbose,
            dict_path=dict_path,
            model_id=model_id,
            model_location=model_location)
        args_dir = write_args(model_location,
                              model_id=model_id,
                              hyperparameters=hyperparameters,
                              verbose=verbose)

        if report:
            if report_full_path_shared is not None:
                tensorboard_log = os.path.join(report_full_path_shared,
                                               "tensorboard")
            printing("tensorboard --logdir={} --host=localhost --port=1234 ",
                     var=[tensorboard_log],
                     verbose_level=1,
                     verbose=verbose)
            writer = SummaryWriter(log_dir=tensorboard_log)
            if writer is not None:
                writer.add_text("INFO-ARGUMENT-MODEL-{}".format(model_id),
                                str(hyperparameters), 0)
    else:
        args_checkpoint = json.load(open(args.init_args_dir, "r"))
        dict_path = args_checkpoint["hyperparameters"]["dict_path"]
        assert dict_path is not None and os.path.isdir(
            dict_path), "ERROR {} ".format(dict_path)
        end_predictions = args.end_predictions
        assert end_predictions is not None and os.path.isdir(
            end_predictions), "ERROR end_predictions"
        model_location = args_checkpoint["hyperparameters"]["model_location"]
        model_id = args_checkpoint["hyperparameters"]["model_id"]
        assert model_location is not None and model_id is not None, "ERROR model_location model_id "
        args_dir = os.path.join(model_location,
                                "{}-args.json".format(model_id))

        printing(
            "CHECKPOINTING : starting writing log \ntensorboard --logdir={} --host=localhost --port=1234 ",
            var=[os.path.join(model_id, "tensorboard")],
            verbose_level=1,
            verbose=verbose)

    # build or load dictionaries
    _dev_path = args.dev_path if args.dev_path is not None else args.train_path
    word_dictionary, word_norm_dictionary, char_dictionary, pos_dictionary, \
    xpos_dictionary, type_dictionary = \
        conllu_data.load_dict(dict_path=dict_path,
                              train_path=args.train_path if run_mode == "train" else None,
                              dev_path=args.dev_path if run_mode == "train" else None,
                              test_path=None,
                              word_embed_dict={},
                              dry_run=False,
                              expand_vocab=False,
                              word_normalization=True,
                              force_new_dic=True if run_mode == "train" else False,
                              tasks=args.tasks,
                              pos_specific_data_set=args.train_path[1] if len(args.tasks) > 1 and len(args.train_path)>1 and "pos" in args.tasks else None,
                              case=args.case,
                              # if not normalize pos or parsing in tasks we don't need dictionary
                              do_not_fill_dictionaries=len(set(["normalize", "pos", "parsing"])&set([task for tasks in args.tasks for task in tasks])) == 0,
                              add_start_char=1 if run_mode == "train" else None,
                              verbose=verbose)
    # we flatten the tasks (list of lists of tasks) below
    printing("DICTIONARY CREATED/LOADED", verbose=verbose, verbose_level=1)
    num_labels_per_task, task_to_label_dictionary = get_vocab_size_and_dictionary_per_task(
        [task for tasks in args.tasks for task in tasks],
        vocab_bert_wordpieces_len=vocab_size,
        pos_dictionary=pos_dictionary,
        type_dictionary=type_dictionary,
        task_parameters=TASKS_PARAMETER)
    voc_pos_size = num_labels_per_task["pos"] if "pos" in args.tasks else None
    if voc_pos_size is not None:
        printing("MODEL : voc_pos_size defined as {}",
                 var=voc_pos_size,
                 verbose_level=1,
                 verbose=verbose)
    printing("MODEL init...", verbose=verbose, verbose_level=1)
    if verbose > 1:
        print("DEBUG : TOKENIZER :voc_tokenizer from_pretrained",
              voc_tokenizer)
    #pdb.set_trace()
    #voc_tokenizer = "bert-base-multilingual-cased"
    tokenizer = tokenizer.from_pretrained(
        voc_tokenizer,
        do_lower_case=args.case == "lower",
        shuffle_bpe_embedding=args.shuffle_bpe_embedding)
    mask_id = tokenizer.convert_tokens_to_ids(
        tokenizer.mask_token)  #convert_tokens_to_ids([MASK_BERT])[0]
    printout_allocated_gpu_memory(verbose,
                                  "{} loading model ".format(model_id))
    model = get_model_multi_task_bert(args=args,
                                      model_dir=model_dir,
                                      encoder=encoder,
                                      num_labels_per_task=num_labels_per_task,
                                      mask_id=mask_id)

    def prune_heads(prune_heads):
        # expected format : comma-separated items (with a trailing comma) like
        # "prune_heads-layer-<layer_index>-heads-<head_index>_<head_index>..."
        if prune_heads is not None:
            prune_heads_ls = prune_heads.split(",")[:-1]
            assert len(prune_heads_ls) > 0
            for layer in prune_heads_ls:
                parsed_layer_to_prune = layer.split("-")
                assert parsed_layer_to_prune[0] == "prune_heads"
                assert parsed_layer_to_prune[1] == "layer"
                assert parsed_layer_to_prune[3] == "heads"
                heads = parsed_layer_to_prune[4]
                head_index_ls = heads.split("_")
                heads_ls = [int(index) for index in head_index_ls]
                print(
                    f"MODEL : pruning layer {parsed_layer_to_prune[2]} heads {heads_ls}"
                )
                model.encoder.encoder.layer[int(
                    parsed_layer_to_prune[2])].attention.prune_heads(heads_ls)
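
    # Worked example (hypothetical value): prune_heads("prune_heads-layer-2-heads-0_3_7,")
    # prunes heads 0, 3 and 7 of encoder layer 2 (note the trailing comma the parser expects).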

    if args.prune_heads is not None and args.prune_heads != "None":
        print(f"INFO : args.prune_heads {args.prune_heads}")
        prune_heads(args.prune_heads)

    if use_gpu:
        model.to("cuda")
        printing("MODEL TO CUDA", verbose=verbose, verbose_level=1)
    printing("MODEL model.config {} ",
             var=[model.config],
             verbose=verbose,
             verbose_level=1)
    printout_allocated_gpu_memory(verbose, "{} model loaded".format(model_id))
    model_origin = OrderedDict()
    pruning_mask = OrderedDict()
    printout_allocated_gpu_memory(verbose, "{} model cuda".format(model_id))
    for name, param in model.named_parameters():
        model_origin[name] = param.detach().clone()
        printout_allocated_gpu_memory(verbose, "{} param cloned ".format(name))
        if args.penalization_mode == "pruning":
            abs = torch.abs(param.detach().flatten())
            median_value = torch.median(abs)
            pruning_mask[name] = (abs > median_value).float()
        printout_allocated_gpu_memory(
            verbose, "{} pruning mask loaded".format(model_id))

    printout_allocated_gpu_memory(verbose, "{} model clone".format(model_id))

    inv_word_dic = word_dictionary.instance2index
    # load , mask, bucket and index data

    assert tokenizer is not None, "ERROR : tokenizer is None , voc_tokenizer failed to be loaded {}".format(
        voc_tokenizer)
    if run_mode == "train":
        time_load_readers_train_start = time.time()
        if not args.memory_efficient_iterator:

            data_sharded, n_shards, n_sent_dataset_total_train = None, None, None
            args_load_batcher_shard_data = None
            printing("INFO : starting loading readers",
                     verbose=verbose,
                     verbose_level=1)
            readers_train = readers_load(
                datasets=args.train_path,
                tasks=args.tasks,
                word_dictionary=word_dictionary,
                bert_tokenizer=tokenizer,
                word_dictionary_norm=word_norm_dictionary,
                char_dictionary=char_dictionary,
                pos_dictionary=pos_dictionary,
                xpos_dictionary=xpos_dictionary,
                type_dictionary=type_dictionary,
                word_decoder=True,
                run_mode=run_mode,
                add_start_char=1,
                add_end_char=1,
                symbolic_end=1,
                symbolic_root=1,
                bucket=True,
                must_get_norm=True,
                input_level_ls=input_level_ls,
                verbose=verbose)
            n_sent_dataset_total_train = readers_train[list(
                readers_train.keys())[0]][3]
            printing("INFO : done with sharding",
                     verbose=verbose,
                     verbose_level=1)
        else:
            printing("INFO : building/loading shards ",
                     verbose=verbose,
                     verbose_level=1)
            data_sharded, n_shards, n_sent_dataset_total_train = build_shard(
                data_sharded,
                args.train_path,
                n_sent_max_per_file=N_SENT_MAX_CONLL_PER_SHARD,
                verbose=verbose)

        time_load_readers_dev_start = time.time()
        time_load_readers_train = time.time() - time_load_readers_train_start
        readers_dev_ls = []
        dev_data_label_ls = []
        printing("INFO : g readers for dev", verbose=verbose, verbose_level=1)
        printout_allocated_gpu_memory(
            verbose, "{} reader train loaded".format(model_id))
        for dev_path in args.dev_path:
            dev_data_label = get_dataset_label(dev_path, default="dev")
            dev_data_label_ls.append(dev_data_label)
            readers_dev = readers_load(
                datasets=dev_path,
                tasks=args.tasks,
                word_dictionary=word_dictionary,
                word_dictionary_norm=word_norm_dictionary,
                char_dictionary=char_dictionary,
                pos_dictionary=pos_dictionary,
                xpos_dictionary=xpos_dictionary,
                bert_tokenizer=tokenizer,
                type_dictionary=type_dictionary,
                word_decoder=True,
                run_mode=run_mode,
                add_start_char=1,
                add_end_char=1,
                symbolic_end=1,
                symbolic_root=1,
                bucket=False,
                must_get_norm=True,
                input_level_ls=input_level_ls,
                verbose=verbose) if args.dev_path is not None else None
            readers_dev_ls.append(readers_dev)
        printout_allocated_gpu_memory(verbose,
                                      "{} reader dev loaded".format(model_id))

        time_load_readers_dev = time.time() - time_load_readers_dev_start
        # Load tokenizer
        printing("TIME : {} ",
                 var=[
                     OrderedDict([
                         ("time_load_readers_train",
                          "{:0.4f} min".format(time_load_readers_train / 60)),
                         ("time_load_readers_dev",
                          "{:0.4f} min".format(time_load_readers_dev / 60))
                     ])
                 ],
                 verbose=verbose,
                 verbose_level=2)

        early_stoping_val_former = 1000
        # training starts when epoch is 1
        #args.epochs += 1
        #assert args.epochs >= 1, "ERROR need at least 2 epochs (1 eval , 1 train 1 eval"
        flexible_batch_size = False

        if args.optimizer == "AdamW":
            model, optimizer, scheduler = apply_fine_tuning_strategy(
                model=model,
                fine_tuning_strategy=args.fine_tuning_strategy,
                lr_init=args.lr,
                betas=(0.9, 0.99),
                epoch=0,
                weight_decay=args.weight_decay,
                optimizer_name=args.optimizer,
                t_total=n_sent_dataset_total_train / args.batch_update_train *
                args.epochs if n_sent_dataset_total_train /
                args.batch_update_train * args.epochs > 1 else 5,
                verbose=verbose)

        try:
            for epoch in range(args.epochs):
                if args.memory_efficient_iterator:
                    # we start each epoch with a new shard
                    training_file = get_new_shard(data_sharded, n_shards)
                    printing(
                        "INFO : memory efficient iterator triggered (only built for train data), starting with {}",
                        var=[training_file],
                        verbose=verbose,
                        verbose_level=1)
                    args_load_batcher_shard_data = {
                        "word_dictionary": word_dictionary,
                        "tokenizer": tokenizer,
                        "word_norm_dictionary": word_norm_dictionary,
                        "char_dictionary": char_dictionary,
                        "pos_dictionary": pos_dictionary,
                        "xpos_dictionary": xpos_dictionary,
                        "type_dictionary": type_dictionary,
                        "use_gpu": use_gpu,
                        "norm_not_norm": auxilliary_task_norm_not_norm,
                        "word_decoder": True,
                        "add_start_char": 1,
                        "add_end_char": 1,
                        "symbolic_end": 1,
                        "symbolic_root": 1,
                        "bucket": True,
                        "max_char_len": 20,
                        "must_get_norm": True,
                        "use_gpu_hardcoded_readers": False,
                        "bucketing_level": "bpe",
                        "input_level_ls": ["wordpiece"],
                        "auxilliary_task_norm_not_norm":
                        auxilliary_task_norm_not_norm,
                        "random_iterator_train": random_iterator_train
                    }

                    readers_train = readers_load(
                        datasets=args.train_path if
                        not args.memory_efficient_iterator else training_file,
                        tasks=args.tasks,
                        word_dictionary=word_dictionary,
                        bert_tokenizer=tokenizer,
                        word_dictionary_norm=word_norm_dictionary,
                        char_dictionary=char_dictionary,
                        pos_dictionary=pos_dictionary,
                        xpos_dictionary=xpos_dictionary,
                        type_dictionary=type_dictionary,
                        word_decoder=True,
                        run_mode=run_mode,
                        add_start_char=1,
                        add_end_char=1,
                        symbolic_end=1,
                        symbolic_root=1,
                        bucket=True,
                        must_get_norm=True,
                        input_level_ls=input_level_ls,
                        verbose=verbose)

                checkpointing_model_data = (epoch % saving_every_epoch == 0
                                            or epoch == (args.epochs - 1))
                # build iterator on the loaded data
                printout_allocated_gpu_memory(
                    verbose, "{} loading batcher".format(model_id))

                if args.batch_size == "flexible":
                    flexible_batch_size = True

                    printing(
                        "INFO : args.batch_size {} so updating it based on mean value {}",
                        var=[
                            args.batch_size,
                            update_batch_size_mean(readers_train)
                        ],
                        verbose=verbose,
                        verbose_level=1)
                    args.batch_size = update_batch_size_mean(readers_train)

                    if args.batch_update_train == "flexible":
                        args.batch_update_train = args.batch_size
                    printing(
                        "TRAINING : backward pass every {} steps of size {} on average",
                        var=[
                            int(args.batch_update_train // args.batch_size),
                            args.batch_size
                        ],
                        verbose=verbose,
                        verbose_level=1)
                    try:
                        assert isinstance(args.batch_update_train // args.batch_size, int)\
                           and args.batch_update_train // args.batch_size > 0, \
                            "ERROR : batch_update_train {} should be a positive multiple of batch_size {} ".format(args.batch_update_train, args.batch_size)
                    except Exception as e:
                        print("WARNING {}".format(e))
                batchIter_train = data_gen_multi_task_sampling_batch(
                    tasks=args.tasks,
                    readers=readers_train,
                    batch_size=readers_train[list(readers_train.keys())[0]][4],
                    max_token_per_batch=max_token_per_batch
                    if flexible_batch_size else None,
                    word_dictionary=word_dictionary,
                    char_dictionary=char_dictionary,
                    pos_dictionary=pos_dictionary,
                    word_dictionary_norm=word_norm_dictionary,
                    get_batch_mode=random_iterator_train,
                    print_raw=False,
                    dropout_input=0.0,
                    verbose=verbose)

                # -|-|-
                printout_allocated_gpu_memory(
                    verbose, "{} batcher train loaded".format(model_id))
                batchIter_dev_ls = []
                batch_size_DEV = 1

                if verbose > 1:
                    print(
                        "WARNING : batch_size for dev eval was hardcoded and set to {}"
                        .format(batch_size_DEV))
                for readers_dev in readers_dev_ls:
                    batchIter_dev = data_gen_multi_task_sampling_batch(
                        tasks=args.tasks,
                        readers=readers_dev,
                        batch_size=batch_size_DEV,
                        word_dictionary=word_dictionary,
                        char_dictionary=char_dictionary,
                        pos_dictionary=pos_dictionary,
                        word_dictionary_norm=word_norm_dictionary,
                        get_batch_mode=False,
                        print_raw=False,
                        dropout_input=0.0,
                        verbose=verbose) if args.dev_path is not None else None
                    batchIter_dev_ls.append(batchIter_dev)

                model.train()
                printout_allocated_gpu_memory(
                    verbose, "{} batcher dev loaded".format(model_id))
                if args.optimizer != "AdamW":

                    model, optimizer, scheduler = apply_fine_tuning_strategy(
                        model=model,
                        fine_tuning_strategy=args.fine_tuning_strategy,
                        lr_init=args.lr,
                        betas=(0.9, 0.99),
                        weight_decay=args.weight_decay,
                        optimizer_name=args.optimizer,
                        t_total=n_sent_dataset_total_train /
                        args.batch_update_train *
                        args.epochs if n_sent_dataset_total_train /
                        args.batch_update_train * args.epochs > 1 else 5,
                        epoch=epoch,
                        verbose=verbose)
                printout_allocated_gpu_memory(
                    verbose, "{} optimizer loaded".format(model_id))
                loss_train = None

                if epoch >= 0:
                    printing("TRAINING : training on GET_BATCH_MODE ",
                             verbose=verbose,
                             verbose_level=2)
                    printing(
                        "TRAINING {} : 1 'epoch' = {} observations ; forward batch_size {} , "
                        "gradient update every batch_update_train {} observations "
                        "(low_memory_foot_print_batch_mode {})",
                        var=[
                            model_id, n_observation_max_per_epoch_train,
                            args.batch_size, args.batch_update_train,
                            args.low_memory_foot_print_batch_mode
                        ],
                        verbose=verbose,
                        verbose_level=1)
                    loss_train, iter_train, perf_report_train, _ = epoch_run(
                        batchIter_train,
                        tokenizer,
                        args=args,
                        model_origin=model_origin,
                        pruning_mask=pruning_mask,
                        task_to_label_dictionary=task_to_label_dictionary,
                        data_label=train_data_label,
                        model=model,
                        dropout_input_bpe=args.dropout_input_bpe,
                        writer=writer,
                        iter=iter_train,
                        epoch=epoch,
                        writing_pred=epoch == (args.epochs - 1),
                        dir_end_pred=end_predictions,
                        optimizer=optimizer,
                        use_gpu=use_gpu,
                        scheduler=scheduler,
                        predict_mode=(epoch - 1) % 5 == 0,
                        skip_1_t_n=skip_1_t_n,
                        model_id=model_id,
                        reference_word_dic={"InV": inv_word_dic},
                        null_token_index=null_token_index,
                        null_str=null_str,
                        norm_2_noise_eval=False,
                        early_stoppin_metric=None,
                        n_obs_max=n_observation_max_per_epoch_train,
                        data_sharded_dir=data_sharded,
                        n_shards=n_shards,
                        n_sent_dataset_total=n_sent_dataset_total_train,
                        args_load_batcher_shard_data=
                        args_load_batcher_shard_data,
                        memory_efficient_iterator=args.
                        memory_efficient_iterator,
                        verbose=verbose)

                else:
                    printing(
                        "TRAINING : skipping first epoch to start by evaluating on the dev datasets",
                        verbose=verbose,
                        verbose_level=1)
                printout_allocated_gpu_memory(
                    verbose, "{} epoch train done".format(model_id))
                model.eval()

                if args.dev_path is not None and (epoch % 3 == 0
                                                  or epoch <= 6):
                    if verbose > 1:
                        print("RUNNING DEV on ITERATION MODE")
                    early_stoping_val_ls = []
                    loss_dev_ls = []
                    for i_dev, batchIter_dev in enumerate(batchIter_dev_ls):
                        loss_dev, iter_dev, perf_report_dev, early_stoping_val = epoch_run(
                            batchIter_dev,
                            tokenizer,
                            args=args,
                            epoch=epoch,
                            model_origin=model_origin,
                            pruning_mask=pruning_mask,
                            task_to_label_dictionary=task_to_label_dictionary,
                            iter=iter_dev,
                            use_gpu=use_gpu,
                            model=model,
                            writer=writer,
                            optimizer=None,
                            writing_pred=True,  #epoch == (args.epochs - 1),
                            dir_end_pred=end_predictions,
                            predict_mode=True,
                            data_label=dev_data_label_ls[i_dev],
                            null_token_index=null_token_index,
                            null_str=null_str,
                            model_id=model_id,
                            skip_1_t_n=skip_1_t_n,
                            dropout_input_bpe=0,
                            reference_word_dic={"InV": inv_word_dic},
                            norm_2_noise_eval=False,
                            early_stoppin_metric=early_stoppin_metric,
                            subsample_early_stoping_metric_val=
                            subsample_early_stoping_metric_val,
                            #case=case,
                            n_obs_max=n_observation_max_per_epoch_dev_test,
                            verbose=verbose)

                        printing(
                            "TRAINING : loss train:{} dev {}:{} for epoch {}  out of {}",
                            var=[
                                loss_train, i_dev, loss_dev, epoch, args.epochs
                            ],
                            verbose=1,
                            verbose_level=1)
                        printing("PERFORMANCE {} DEV {} {} ",
                                 var=[epoch, i_dev + 1, perf_report_dev],
                                 verbose=verbose,
                                 verbose_level=1)
                        early_stoping_val_ls.append(early_stoping_val)
                        loss_dev_ls.append(loss_dev)

                else:
                    if verbose > 1:
                        print("NO DEV EVAL")
                    loss_dev, iter_dev, perf_report_dev = None, 0, None
                # NB : early_stoping_val is based on first dev set
                printout_allocated_gpu_memory(
                    verbose, "{} epoch dev done".format(model_id))

                early_stoping_val = early_stoping_val_ls[0]
                if checkpointing_model_data or early_stoping_val < early_stoping_val_former:
                    if early_stoping_val is not None:
                        _epoch = "best" if early_stoping_val < early_stoping_val_former else epoch
                    else:
                        if verbose > 1:
                            print(
                                'WARNING early_stoping_val is None so saving based on checkpointing_model_data only'
                            )
                        _epoch = epoch
                    # model_id possibly enriched with epoch information if name_with_epoch
                    _model_id = get_name_model_id_with_extra_name(
                        epoch=epoch,
                        _epoch=_epoch,
                        name_with_epoch=name_with_epoch,
                        model_id=model_id)
                    checkpoint_dir = os.path.join(
                        model_location, "{}-checkpoint.pt".format(_model_id))

                    if _epoch == "best":
                        printing(
                            "CHECKPOINT : SAVING BEST MODEL {} (epoch:{}) (new loss is {} former was {})"
                            .format(checkpoint_dir, epoch, early_stoping_val,
                                    early_stoping_val_former),
                            verbose=verbose,
                            verbose_level=1)
                        last_checkpoint_dir_best = checkpoint_dir
                        early_stoping_val_former = early_stoping_val
                        best_epoch = epoch
                        best_loss = early_stoping_val
                    else:
                        printing(
                            "CHECKPOINT : NOT SAVING BEST MODEL : new loss {} did not beat first loss {}"
                            .format(early_stoping_val,
                                    early_stoping_val_former),
                            verbose_level=1,
                            verbose=verbose)
                    last_model = ""
                    if epoch == (args.epochs - 1):
                        last_model = "last"
                    printing("CHECKPOINT : epoch {} saving {} model {} ",
                             var=[epoch, last_model, checkpoint_dir],
                             verbose=verbose,
                             verbose_level=1)
                    torch.save(model.state_dict(), checkpoint_dir)

                    args_dir = write_args(
                        dir=model_location,
                        checkpoint_dir=checkpoint_dir,
                        hyperparameters=hyperparameters
                        if name_with_epoch else None,
                        model_id=_model_id,
                        info_checkpoint=OrderedDict([
                            ("epochs", epoch + 1),
                            ("batch_size", args.batch_size
                             if not args.low_memory_foot_print_batch_mode else
                             args.batch_update_train),
                            ("train_path", train_data_label),
                            ("dev_path", dev_data_label_ls),
                            ("num_labels_per_task", num_labels_per_task)
                        ]),
                        verbose=verbose)

            if row is not None and update_status is not None:
                update_status(row=row, value="training-done", verbose=1)
        except Exception as e:
            if row is not None and update_status is not None:
                update_status(row=row, value="ERROR", verbose=1)
            raise (e)

    # reloading last (best) checkpoint
    if run_mode in ["train", "test"] and args.test_paths is not None:
        report_all = []
        if run_mode == "train" and args.epochs > 0:
            if use_gpu:
                model.load_state_dict(torch.load(last_checkpoint_dir_best))
                model = model.cuda()
                printout_allocated_gpu_memory(
                    verbose, "{} after reloading model".format(model_id))
            else:
                model.load_state_dict(
                    torch.load(last_checkpoint_dir_best,
                               map_location=lambda storage, loc: storage))
            printing(
                "MODEL : RELOADING best model of epoch {} with loss {} based on {}({}) metric (from checkpoint {})",
                var=[
                    best_epoch, best_loss, early_stoppin_metric,
                    subsample_early_stoping_metric_val,
                    last_checkpoint_dir_best
                ],
                verbose=verbose,
                verbose_level=1)

        model.eval()

        printout_allocated_gpu_memory(verbose,
                                      "{} starting test".format(model_id))
        for test_path in args.test_paths:
            assert len(test_path) == len(
                args.tasks), "ERROR test_path {} args.tasks {}".format(
                    test_path, args.tasks)
            for test, task_to_eval in zip(test_path, args.tasks):
                label_data = get_dataset_label([test], default="test")
                if len(extra_label_for_prediction) > 0:
                    label_data += "-" + extra_label_for_prediction

                if args.shuffle_bpe_embedding and args.test_mode_no_shuffle_embedding:
                    printing(
                        "TOKENIZER: as args.shuffle_bpe_embedding {} and test_mode_no_shuffle {} : reloading tokenizer with no shuffle_embedding",
                        var=[
                            args.shuffle_bpe_embedding,
                            args.test_mode_no_shuffle_embedding
                        ],
                        verbose=1,
                        verbose_level=1)
                    tokenizer = tokenizer.from_pretrained(
                        voc_tokenizer,
                        do_lower_case=args.case == "lower",
                        shuffle_bpe_embedding=False)
                readers_test = readers_load(
                    datasets=[test],
                    tasks=[task_to_eval],
                    word_dictionary=word_dictionary,
                    word_dictionary_norm=word_norm_dictionary,
                    char_dictionary=char_dictionary,
                    pos_dictionary=pos_dictionary,
                    xpos_dictionary=xpos_dictionary,
                    type_dictionary=type_dictionary,
                    bert_tokenizer=tokenizer,
                    word_decoder=True,
                    run_mode=run_mode,
                    add_start_char=1,
                    add_end_char=1,
                    symbolic_end=1,
                    symbolic_root=1,
                    bucket=bucket_test,
                    input_level_ls=input_level_ls,
                    must_get_norm=must_get_norm_test,
                    verbose=verbose)

                heuritics_zip = [None]
                gold_error_or_not_zip = [False]
                norm2noise_zip = [False]

                if heuristic_test_ls is None:
                    assert len(gold_error_or_not_zip) == len(
                        heuritics_zip) and len(heuritics_zip) == len(
                            norm2noise_zip)

                batch_size_TEST = 1
                if verbose > 1:
                    print(
                        "WARNING : batch_size for final eval was hardcoded and set to {}"
                        .format(batch_size_TEST))
                for (heuristic_test, gold_error,
                     norm_2_noise_eval) in zip(heuritics_zip,
                                               gold_error_or_not_zip,
                                               norm2noise_zip):

                    assert heuristic_test is None and not gold_error and not norm_2_noise_eval

                    batchIter_test = data_gen_multi_task_sampling_batch(
                        tasks=[task_to_eval],
                        readers=readers_test,
                        batch_size=batch_size_TEST,
                        word_dictionary=word_dictionary,
                        char_dictionary=char_dictionary,
                        pos_dictionary=pos_dictionary,
                        word_dictionary_norm=word_norm_dictionary,
                        get_batch_mode=False,
                        dropout_input=0.0,
                        verbose=verbose)
                    try:
                        loss_test, iter_test, perf_report_test, _ = epoch_run(
                            batchIter_test,
                            tokenizer,
                            args=args,
                            iter=iter_dev,
                            use_gpu=use_gpu,
                            model=model,
                            task_to_label_dictionary=task_to_label_dictionary,
                            writer=None,
                            writing_pred=True,
                            optimizer=None,
                            args_dir=args_dir,
                            model_id=model_id,
                            dir_end_pred=end_predictions,
                            skip_1_t_n=skip_1_t_n,
                            predict_mode=True,
                            data_label=label_data,
                            epoch="LAST",
                            extra_label_for_prediction=label_data,
                            null_token_index=null_token_index,
                            null_str=null_str,
                            log_perf=False,
                            dropout_input_bpe=0,
                            norm_2_noise_eval=norm_2_noise_eval,
                            compute_intersection_score=
                            compute_intersection_score_test,
                            remove_mask_str_prediction=
                            remove_mask_str_prediction,
                            reference_word_dic={"InV": inv_word_dic},
                            threshold_edit=threshold_edit,
                            verbose=verbose,
                            n_obs_max=n_observation_max_per_epoch_dev_test)
                        if verbose > 1:
                            print("LOSS TEST", loss_test)
                    except Exception as e:
                        print(
                            "ERROR (epoch_run test) {} test_path {} , heuristic {} , gold error {} , norm2noise {} "
                            .format(e, test, heuristic_test, gold_error,
                                    norm_2_noise_eval))
                        raise (e)
                    print("PERFORMANCE TEST on data  {} is {} ".format(
                        label_data, perf_report_test))
                    print("DATA WRITTEN {}".format(end_predictions))
                    if writer is not None:
                        writer.add_text(
                            "Accuracy-{}-{}-{}".format(model_id, label_data,
                                                       run_mode),
                            "After {} epochs with {} : performance is \n {} ".
                            format(args.epochs, description,
                                   str(perf_report_test)), 0)
                    else:
                        printing(
                            "WARNING : could not add accuracy to tensorboard because writer is None",
                            verbose=verbose,
                            verbose_level=2)
                    report_all.extend(perf_report_test)
                    printout_allocated_gpu_memory(
                        verbose, "{} test done".format(model_id))
    else:
        printing("ERROR : EVALUATION none cause {} empty or run_mode {} ",
                 var=[args.test_paths, run_mode],
                 verbose_level=1,
                 verbose=verbose)

    if writer is not None:
        writer.close()
        printing("tensorboard --logdir={} --host=localhost --port=1234 ",
                 var=[tensorboard_log],
                 verbose_level=1,
                 verbose=verbose)

    report_dir = os.path.join(model_location, model_id + "-report.json")
    if report_full_path_shared is not None:
        report_full_dir = os.path.join(report_full_path_shared,
                                       args.overall_label + "-report.json")
        if os.path.isfile(report_full_dir):
            report = json.load(open(report_full_dir, "r"))
        else:
            report = []
            printing("REPORT = creating overall report at {} ",
                     var=[report_dir],
                     verbose=verbose,
                     verbose_level=1)
        report.extend(report_all)
        json.dump(report, open(report_full_dir, "w"))
        printing("{} {} ",
                 var=[REPORT_FLAG_DIR_STR, report_full_dir],
                 verbose=0,
                 verbose_level=0)

    json.dump(report_all, open(report_dir, "w"))
    printing("REPORTING TO {}".format(report_dir),
             verbose=verbose,
             verbose_level=1)
    if report_full_path_shared is None:
        printing("WARNING ; report_full_path_shared is None",
                 verbose=verbose,
                 verbose_level=1)
        printing("{} {} ",
                 var=[REPORT_FLAG_DIR_STR, report_dir],
                 verbose=verbose,
                 verbose_level=0)

    return model


def main(args, dict_path, model_dir):

    model, tokenizer, run_id = load_all_analysis(args, dict_path, model_dir)

    if args.compare_to_pretrained:
        print("Loading Pretrained model also for comparison with pretrained")
        args_origin = args_attention_analysis()
        args_origin = args_preprocess_attention_analysis(args_origin)
        args_origin.init_args_dir = None
        args_origin, dict_path_0, model_dir_0 = get_dirs(args_origin)
        args_origin.model_id_pref += "again"
        model_origin, tokenizer_0, _ = load_all_analysis(
            args_origin, dict_path_0, model_dir_0)
        model_origin.eval()
        print("seco,")
    # only allow output of the model to be hidden states here
    print("Checkpoint loaded")
    assert not args.output_attentions
    assert args.output_all_encoded_layers and args.output_hidden_states_per_head

    data = ["I am here", "How are you"]
    model.eval()
    n_obs = args.n_sent
    max_len = args.max_seq_len
    lang_ls = args.raw_text_code

    lang = [
        "fr_pud", "de_pud", "ru_pud", "tr_pud", "id_pud", "ar_pud", "pt_pud",
        "es_pud", "fi_pud", "it_pud", "sv_pud", "cs_pud", "pl_pud", "hi_pud",
        "zh_pud", "ko_pud", "ja_pud", "th_pud"
    ]
    #lang = ["fr_pud", "fr_gsd"]
    src_lang_ls = ["en_pud"]  #, "fr_pud", "ru_pud", "ar_pud"]
    print("Loading data...")
    data_target_ls = [
        load_data(DATA_UD + f"/{target}-ud-test.conllu",
                  line_filter="# text = ") for target in lang
    ]
    data_target_dic = OrderedDict([
        (target_lang, target_data)
        for target_lang, target_data in zip(lang, data_target_ls)
    ])
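    # the *_pud test sets are parallel (sentence-aligned) translations of the same
    # corpus, so sentence i in the source set can be compared to sentence i in each target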
    src = src_lang_ls[0]  # "en_pud"

    # NB: data_en starts as the first *target* set; it is only used below for
    # length alignment and batch counting, and is reloaded per source language
    # inside the main loop with load_data(DATA_UD + f"/{src_lang}-ud-test.conllu", ...)
    data_en = data_target_ls[0]

    for _data_target in data_target_dic:
        try:
            assert len(data_target_dic[_data_target]) == len(data_en), \
                f"Should have as many sentences on both sides en:{len(data_en)} target:{len(data_target_dic[_data_target])}"
        except AssertionError:
            data_en = data_en[:len(data_target_dic[_data_target])]
            print(f"Cutting {src} dataset based on target")
        assert len(data_target_dic[_data_target]) == len(data_en), \
            f"Should have as many sentences on both sides en:{len(data_en)} target:{len(data_target_dic[_data_target])}"
    # run a small probe batch just to get the layer keys
    layer_all = get_hidden_representation(data,
                                          model,
                                          tokenizer,
                                          max_len=max_len)
    # removed hidden_per_layer
    assert len(layer_all) == 1
    #assert len(layer_all) == 2, "ERROR should only have hidden_per_layer and hidden_per_head_layer"
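    # layer_all is indexed by analysis mode: entry 0 holds the per-layer hidden
    # states and entry 1 (when available) the per-head hidden states; the index
    # actually analysed is chosen below via args.analysis_mode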

    report_ls = []
    accuracy_dic = OrderedDict()
    sampling = args.sampling
    metric = args.similarity_metric
    if metric == "cka":
        pad_below_max_len, output_dic = False, True
    else:
        pad_below_max_len, output_dic = False, True
    assert metric in ["cos", "cka"]
    if metric == "cos":
        batch_size = 1
    else:
        batch_size = len(data_en) // 4
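    # kernel_CKA below is imported from the repository's analysis utilities; as an
    # illustration only (an assumption about the kind of quantity it returns, not
    # its actual implementation), a minimal *linear*-CKA sketch over
    # (n_sentences, hidden_dim) matrices would look like this:
    def _linear_cka_sketch(X, Y):
        # center each representation matrix column-wise
        X = X - X.mean(axis=0, keepdims=True)
        Y = Y - Y.mean(axis=0, keepdims=True)
        # CKA(X, Y) = ||X^T Y||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
        return (np.linalg.norm(X.T @ Y) ** 2
                / (np.linalg.norm(X.T @ X) * np.linalg.norm(Y.T @ Y)))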

    task_tuned = "No"
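    # when no fine-tuned checkpoint is given, an args file is written on the fly
    # below so the raw pretrained (or randomly initialised) encoder can be reported
    # with the same machinery as a trained model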

    if args.init_args_dir is None:
        id_model = f"{args.bert_model}-init-{args.random_init}"

        hyperparameters = OrderedDict([
            ("bert_model", args.bert_model),
            ("random_init", args.random_init),
            ("not_load_params_ls", args.not_load_params_ls),
            ("dict_path", dict_path),
            ("model_id", id_model),
        ])
        info_checkpoint = OrderedDict([("epochs", 0),
                                       ("batch_size", batch_size),
                                       ("train_path", 0), ("dev_path", 0),
                                       ("num_labels_per_task", 0)])

        args.init_args_dir = write_args(os.environ.get("MT_NORM_PARSE", "./"),
                                        model_id=id_model,
                                        info_checkpoint=info_checkpoint,
                                        hyperparameters=hyperparameters,
                                        verbose=1)
        print("args_dir checkout ", args.init_args_dir)
        model_full_name_val = task_tuned + "-" + id_model
    else:
        argument = json.load(open(args.init_args_dir, 'r'))
        task_tuned = argument["hyperparameters"]["tasks"][0][0] \
            if "wiki" not in argument["info_checkpoint"]["train_path"] else "ner"
        model_full_name_val = task_tuned + "-" + args.init_args_dir.split("/")[-1]

    if args.analysis_mode == "layer":
        studied_ind = 0
    elif args.analysis_mode == "layer_head":
        studied_ind = 1
    else:
        raise Exception(f"args.analysis_mode: {args.analysis_mode} not supported")
    layer_analysed = layer_all[studied_ind]

    report = OrderedDict()
    accuracy_ls = []
    src_lang = src

    cosine_sent_to_src = OrderedDict([(src_lang + "-" + lang, OrderedDict())
                                      for src_lang in src_lang_ls
                                      for lang in data_target_dic.keys()])
    cosine_sent_to_origin = OrderedDict([(lang, OrderedDict())
                                         for lang in data_target_dic.keys()])
    cosine_sent_to_origin_src = OrderedDict([(lang, OrderedDict())
                                             for lang in src_lang_ls])
    cosine_sent_to_former_layer_src = OrderedDict([(lang, OrderedDict())
                                                   for lang in src_lang_ls])
    cosine_sent_to_former_layer = OrderedDict([
        (lang, OrderedDict()) for lang in data_target_dic.keys()
    ])
    cosine_sent_to_first_layer = OrderedDict([
        (lang, OrderedDict()) for lang in data_target_dic.keys()
    ])

    cosinus = nn.CosineSimilarity(dim=1)
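    # cosine similarity over the hidden dimension; used when metric == "cos",
    # where batch_size == 1 so each call compares a single pair of sentence means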
    info_model = f" task {args.tasks} args {'/'.join(args.init_args_dir.split('/')[-2:]) if args.init_args_dir is not None else None} bert {args.bert_model} random init {args.random_init} "
    #"cka"
    output_dic = True
    pad_below_max_len = False
    max_len = 200

    n_batch = len(data_en) // batch_size
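    # each batch contributes one similarity value per (language pair, layer);
    # the summary at the end reports mean/std over these per-batch values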

    for i_data in range(n_batch):

        for src_lang in src_lang_ls:
            print(f"Starting src", {src_lang})
            data_en = load_data(DATA_UD + f"/{src_lang}-ud-test.conllu",
                                line_filter="# text = ")
            # slice non-overlapping batches of size batch_size
            en_batch = data_en[i_data * batch_size:(i_data + 1) * batch_size]
            all = get_hidden_representation(
                en_batch,
                model,
                tokenizer,
                pad_below_max_len=pad_below_max_len,
                max_len=max_len,
                output_dic=output_dic)
            analysed_batch_dic_en = all[studied_ind]
            i_lang = 0

            if args.compare_to_pretrained:
                all_origin = get_hidden_representation(
                    en_batch,
                    model_origin,
                    tokenizer_0,
                    pad_below_max_len=pad_below_max_len,
                    max_len=max_len,
                    output_dic=output_dic)
                analysed_batch_dic_src_origin = all_origin[studied_ind]

            for lang, target in data_target_dic.items():
                print(f"Starting target", {lang})
                i_lang += 1
                target_batch = target[i_data * batch_size:(i_data + 1) * batch_size]

                all = get_hidden_representation(
                    target_batch,
                    model,
                    tokenizer,
                    pad_below_max_len=pad_below_max_len,
                    max_len=max_len,
                    output_dic=output_dic)

                if args.compare_to_pretrained:
                    all_origin = get_hidden_representation(
                        target_batch,
                        model_origin,
                        tokenizer_0,
                        pad_below_max_len=pad_below_max_len,
                        max_len=max_len,
                        output_dic=output_dic)
                    analysed_batch_dic_target_origin = all_origin[studied_ind]

                analysed_batch_dic_target = all[studied_ind]
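                # for every layer: average the token vectors of each sentence into a
                # single sentence vector (dropping the first and last special tokens),
                # then compare source vs target sentence vectors with the chosen metric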

                former_layer, former_mean_target, former_mean_src = None, None, None
                for layer in analysed_batch_dic_target:
                    print(f"Starting layer", {layer})
                    # get average for sentence removing first and last special tokens
                    if output_dic:
                        mean_over_sent_src = []
                        mean_over_sent_target = []
                        mean_over_sent_target_origin = []
                        mean_over_sent_src_origin = []
                        for i_sent in range(len(analysed_batch_dic_en[layer])):
                            # drop the first and last special tokens before averaging over the sentence
                            mean_over_sent_src.append(
                                np.array(analysed_batch_dic_en[layer][i_sent][
                                    0, 1:-1, :].mean(dim=0).cpu()))
                            mean_over_sent_target.append(
                                np.array(analysed_batch_dic_target[layer]
                                         [i_sent][0,
                                                  1:-1, :].mean(dim=0).cpu()))

                            if args.compare_to_pretrained:
                                mean_over_sent_target_origin.append(
                                    np.array(
                                        analysed_batch_dic_target_origin[layer]
                                        [i_sent][0,
                                                 1:-1, :].mean(dim=0).cpu()))
                                if i_lang == 1:
                                    mean_over_sent_src_origin.append(
                                        np.array(analysed_batch_dic_src_origin[
                                            layer][i_sent][0, 1:-1, :].mean(
                                                dim=0).cpu()))

                        if args.compare_to_pretrained:
                            mean_over_sent_target_origin = np.array(
                                mean_over_sent_target_origin)
                            if i_lang == 1:
                                mean_over_sent_src_origin = np.array(
                                    mean_over_sent_src_origin)
                        mean_over_sent_src = np.array(mean_over_sent_src)
                        mean_over_sent_target = np.array(mean_over_sent_target)

                    else:
                        mean_over_sent_src = analysed_batch_dic_en[
                            layer][:, 1:-1, :].mean(dim=1).cpu()
                        mean_over_sent_target = analysed_batch_dic_target[
                            layer][:, 1:-1, :].mean(dim=1).cpu()

                    if layer not in cosine_sent_to_src[src_lang + "-" + lang]:
                        cosine_sent_to_src[src_lang + "-" + lang][layer] = []
                    if layer not in cosine_sent_to_origin[lang]:
                        cosine_sent_to_origin[lang][layer] = []
                    if layer not in cosine_sent_to_origin_src[src_lang]:
                        cosine_sent_to_origin_src[src_lang][layer] = []

                    if metric == "cka":
                        mean_over_sent_src = np.array(mean_over_sent_src)
                        mean_over_sent_target = np.array(mean_over_sent_target)

                        cosine_sent_to_src[src_lang + "-" +
                                           lang][layer].append(
                                               kernel_CKA(
                                                   mean_over_sent_src,
                                                   mean_over_sent_target))
                        if args.compare_to_pretrained:
                            cosine_sent_to_origin[lang][layer].append(
                                kernel_CKA(mean_over_sent_target,
                                           mean_over_sent_target_origin))
                            if i_lang == 1:
                                cosine_sent_to_origin_src[src_lang][
                                    layer].append(
                                        kernel_CKA(mean_over_sent_src_origin,
                                                   mean_over_sent_src))
                                print(
                                    f"Measured EN TO ORIGIN {metric} {layer} {cosine_sent_to_origin_src[src_lang][layer][-1]} "
                                    + info_model)
                            print(
                                f"Measured LANG {lang} TO ORIGIN {metric} {layer} {cosine_sent_to_origin[lang][layer][-1]} "
                                + info_model)

                        print(
                            f"Measured {metric} {layer} {cosine_sent_to_src[src_lang + '-' + lang][layer][-1]} "
                            + info_model)
                    else:
                        cosine_sent_to_src[
                            src_lang + "-" + lang][layer].append(
                                cosinus(mean_over_sent_src,
                                        mean_over_sent_target).item())

                    if former_layer is not None:
                        if layer not in cosine_sent_to_former_layer[lang]:
                            cosine_sent_to_former_layer[lang][layer] = []
                        if layer not in cosine_sent_to_former_layer_src[
                                src_lang]:
                            cosine_sent_to_former_layer_src[src_lang][
                                layer] = []
                        if metric == "cka":
                            cosine_sent_to_former_layer[lang][layer].append(
                                kernel_CKA(former_mean_target,
                                           mean_over_sent_target))
                            if i_lang == 1:
                                cosine_sent_to_former_layer_src[src_lang][
                                    layer].append(
                                        kernel_CKA(former_mean_src,
                                                   mean_over_sent_src))
                        else:
                            cosine_sent_to_former_layer[lang][layer].append(
                                cosinus(former_mean_target,
                                        mean_over_sent_target).item())
                            if i_lang == 1:
                                cosine_sent_to_former_layer_src[src_lang][layer].append(
                                    cosinus(former_mean_src, mean_over_sent_src).item())

                    former_layer = layer
                    former_mean_target = mean_over_sent_target
                    former_mean_src = mean_over_sent_src
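                    # keep this layer's sentence means so the next layer iteration can
                    # measure layer-to-layer drift (the *_to_former_layer scores)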

    # summary
    print_all = True
    lang_i = 0
    src_lang_i = 0
    #for lang, cosine_per_layer in cosine_sent_to_src.items():
    for lang, cosine_per_layer in cosine_sent_to_former_layer.items():
        layer_i = 0
        src_lang_i += 1
        for src_lang in src_lang_ls:
            lang_i += 1
            for layer, cosine_ls in cosine_per_layer.items():
                print(
                    f"Mean {metric} between {src_lang} and {lang} for {layer} is {np.mean(cosine_sent_to_src[src_lang+'-'+lang][layer])} std:{np.std(cosine_sent_to_src[src_lang+'-'+lang][layer])} measured on {len(cosine_sent_to_src[src_lang+'-'+lang][layer])} model  "
                    + info_model)
                if layer_i > 0 and print_all:

                    print(
                        f"Mean {metric} for {lang} between {layer} and former is {np.mean(cosine_sent_to_former_layer[lang][layer])} std:{np.std(cosine_sent_to_former_layer[lang][layer])} measured on {len(cosine_sent_to_former_layer[lang][layer])} model "
                        + info_model)

                    report = report_template(
                        metric_val=metric,
                        subsample=lang + "_to_former_layer",
                        info_score_val=None,
                        score_val=np.mean(
                            cosine_sent_to_former_layer[lang][layer]),
                        n_sents=n_obs,
                        avg_per_sent=np.std(
                            cosine_sent_to_former_layer[lang][layer]),
                        n_tokens_score=n_obs * max_len,
                        model_full_name_val=model_full_name_val,
                        task="hidden_state_analysis",
                        evaluation_script_val="exact_match",
                        model_args_dir=args.init_args_dir,
                        token_type="word",
                        report_path_val=None,
                        data_val=layer,
                    )
                    report_ls.append(report)

                    if lang_i == 1:
                        print(
                            f"Mean {metric} for {src_lang} between {layer} and former is {np.mean(cosine_sent_to_former_layer_src[src_lang][layer])} std:{np.std(cosine_sent_to_former_layer_src[src_lang][layer])} measured on {len(cosine_sent_to_former_layer_src[src_lang][layer])} model "
                            + info_model)

                        report = report_template(
                            metric_val=metric,
                            subsample=src_lang + "_to_former_layer",
                            info_score_val=None,
                            score_val=np.mean(
                                cosine_sent_to_former_layer_src[src_lang]
                                [layer]),
                            n_sents=n_obs,
                            avg_per_sent=np.std(
                                cosine_sent_to_former_layer_src[src_lang]
                                [layer]),
                            n_tokens_score=n_obs * max_len,
                            model_full_name_val=model_full_name_val,
                            task="hidden_state_analysis",
                            evaluation_script_val="exact_match",
                            model_args_dir=args.init_args_dir,
                            token_type="word",
                            report_path_val=None,
                            data_val=layer,
                        )
                        report_ls.append(report)

                layer_i += 1

                report = report_template(
                    metric_val=metric,
                    subsample=lang + "_to_" + src_lang,
                    info_score_val=None,
                    score_val=np.mean(cosine_sent_to_src[src_lang + '-' +
                                                         lang][layer]),
                    n_sents=n_obs,
                    avg_per_sent=np.std(cosine_sent_to_src[src_lang + '-' +
                                                           lang][layer]),
                    n_tokens_score=n_obs * max_len,
                    model_full_name_val=model_full_name_val,
                    task="hidden_state_analysis",
                    evaluation_script_val="exact_match",
                    model_args_dir=args.init_args_dir,
                    token_type="word",
                    report_path_val=None,
                    data_val=layer,
                )

                report_ls.append(report)

                if args.compare_to_pretrained:

                    print(
                        f"Mean {metric} for {lang} between {layer} and origin model is {np.mean(cosine_sent_to_origin[lang][layer])} std:{np.std(cosine_sent_to_origin[lang][layer])} measured on {len(cosine_sent_to_origin[lang][layer])} model "
                        + info_model)
                    report = report_template(
                        metric_val=metric,
                        subsample=lang + "_to_origin",
                        info_score_val=None,
                        score_val=np.mean(cosine_sent_to_origin[lang][layer]),
                        n_sents=n_obs,
                        avg_per_sent=np.std(
                            cosine_sent_to_origin[lang][layer]),
                        n_tokens_score=n_obs * max_len,
                        model_full_name_val=model_full_name_val,
                        task="hidden_state_analysis",
                        evaluation_script_val="exact_match",
                        model_args_dir=args.init_args_dir,
                        token_type="word",
                        report_path_val=None,
                        data_val=layer)
                    report_ls.append(report)

                    if lang_i == 1:

                        print(
                            f"Mean {metric} for {src_lang} between {layer} and origin model is {np.mean(cosine_sent_to_origin_src[src_lang][layer])} std:{np.std(cosine_sent_to_origin_src[src_lang][layer])} measured on {len(cosine_sent_to_origin_src[src_lang][layer])} model "
                            + info_model)
                        report = report_template(
                            metric_val=metric,
                            subsample=src_lang + "_to_origin",
                            info_score_val=None,
                            score_val=np.mean(
                                cosine_sent_to_origin_src[src_lang][layer]),
                            n_sents=n_obs,
                            avg_per_sent=np.std(
                                cosine_sent_to_origin_src[src_lang][layer]),
                            n_tokens_score=n_obs * max_len,
                            model_full_name_val=model_full_name_val,
                            task="hidden_state_analysis",
                            evaluation_script_val="exact_match",
                            model_args_dir=args.init_args_dir,
                            token_type="word",
                            report_path_val=None,
                            data_val=layer)
                        report_ls.append(report)


    if args.report_dir is None:
        report_dir = PROJECT_PATH + f"/../../analysis/attention_analysis/report/{run_id}-report"
        os.mkdir(report_dir)
    else:
        report_dir = args.report_dir
    assert os.path.isdir(report_dir)
    with open(report_dir + "/report.json", "w") as f:
        json.dump(report_ls, f)
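    # the per-run report is saved above; below, the same rows are appended to the
    # shared grid report so runs launched together aggregate into a single file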

    overall_report = args.overall_report_dir + "/" + args.overall_label + "-grid-report.json"
    if os.path.isfile(overall_report):
        with open(overall_report, "r") as g:
            report_all = json.load(g)
    else:
        report_all = []
    report_all.extend(report_ls)
    with open(overall_report, "w") as file:
        json.dump(report_all, file)

    print("{} {} ".format(REPORT_FLAG_DIR_STR, overall_report))