Example 1
def append_reporting_sheet(git_id,
                           tasks,
                           rioc_job,
                           description,
                           log_dir,
                           target_dir,
                           env,
                           status,
                           verbose=1):
    sheet, sheet_name, tab_name = open_client()
    # Find a workbook by name and open the first sheet
    # Make sure you use the right name here.
    #worksheet_list = sheet.worksheets()
    if not rioc_job.startswith("local"):
        sheet.append_row([
            git_id, rioc_job, tasks, description, log_dir, target_dir, env,
            status, None, None, None, None, "-"
        ])
        list_of_hashes = sheet.get_all_records()
        printing(
            "REPORT : Appending report to page {} in sheet {} of {} rows and {} columns ",
            var=[
                tab_name, sheet_name,
                len(list_of_hashes) + 1,
                len(list_of_hashes[0])
            ],
            verbose=verbose,
            verbose_level=1)
    else:
        print("LOCAL env not updating sheet")
        list_of_hashes = ["NOTHING"]
    return len(list_of_hashes) + 1, len(list_of_hashes[0])
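A minimal usage sketch of the reporting helper (hypothetical values; assumes the spreadsheet client behind open_client() is already authorized):

# Hypothetical call : appends one row describing a job to the shared reporting sheet
# and returns the row index of the new entry plus the number of columns.
row_nb, col_nb = append_reporting_sheet(
    git_id="a1b2c3d",
    tasks="normalize",
    rioc_job="rioc-1234",            # any job id not starting with "local" triggers the append
    description="baseline run",
    log_dir="/tmp/logs",
    target_dir="/tmp/checkpoints",
    env="cluster",
    status="running",
    verbose=1)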
Example 2
def update_status(row, value, col_number=8, sheet=None, verbose=1):
    if sheet is None:
        sheet, sheet_name, tab_name = open_client()
    if value is not None:
        sheet.update_cell(row, col_number, value)
        printing("REPORT : col {} updated in sheet with {} ",
                 var=[col_number, value],
                 verbose=verbose,
                 verbose_level=1)
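For example, once a run finishes, the row appended in the sketch above could be marked as done (column 8 holds the status by default):

# Hypothetical follow-up to the previous sketch : write "training-done" in the status column of that row.
update_status(row=row_nb, value="training-done", col_number=8, verbose=1)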
def build_shard(dir_shard, dir_file, n_sent_max_per_file, format="conll", dry_run=False, verbose=1):

    onlyfiles = [f for f in listdir(dir_shard) if isfile(join(dir_shard, f))]
    if len(onlyfiles) > 0:
        n_shards = len(onlyfiles)
        n_sents = 0
        for file in onlyfiles:
            n_sents += count_conll_n_sent(os.path.join(dir_shard, file))

        printing("INFO : shards already filled in {} files {} sentences total", var=[n_shards, n_sents],
                 verbose=1, verbose_level=1)
        return dir_shard, n_shards, n_sents

    assert format == "conll"
    assert len(dir_file) == 1, "ONLY 1 set of simultaneous tasks supported for sharding"
    printing("STARTING SHARDING {} of {} ".format(dir_shard, dir_file), verbose=verbose, verbose_level=1)
    dir_file = dir_file[0]
    n_sents = count_conll_n_sent(dir_file)
    n_shards = n_sents//n_sent_max_per_file

    if n_shards == 0:
        printing("INFO SHARDING : n_sent_max_per_file is lower that number of files in {} so only building 1 shard", var=[dir_file], verbose=verbose, verbose_level=1)
        n_shards += 1
    split_randomly(n_shards, dir_shard, dir_file, n_sents, dry_run=dry_run)
    sys.stdout.flush()

    printing("INFO SHARD n_sent written {} splitted in {} files with "
             "in average {} sent per file written to {}",
             var=[n_sents, n_shards,n_sent_max_per_file, dir_shard], verbose=verbose, verbose_level=1)

    return dir_shard, n_shards, n_sents
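The shard count is simply n_sents // n_sent_max_per_file, with a floor of 1. A sketch of the expected call, assuming a single CoNLL training file of 25,000 sentences and at most 10,000 sentences per shard (paths are illustrative):

# Hypothetical call : 25000 // 10000 = 2 shards would be written into dir_shard by split_randomly
# (unless /tmp/shards is already filled, in which case the existing shard stats are returned).
dir_shard, n_shards, n_sents = build_shard(
    dir_shard="/tmp/shards",
    dir_file=["/data/train.conll"],   # exactly one dataset is supported
    n_sent_max_per_file=10000)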
Example 4
def get_new_shard(shard_path, n_shards, rand=True, verbose=1):
    # pick a new file randomly

    assert rand

    i_shard = random.choice(range(n_shards))

    path = os.path.join(shard_path, "train_{}.conll".format(i_shard))

    assert os.path.isfile(path), "ERROR {}".format(path)

    printing("INFO : picking shard {} ",
             var=[path],
             verbose=verbose,
             verbose_level=1)
    return [path]
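Shards are expected to be named train_<i>.conll inside shard_path; a sketch of drawing one for the next training pass (paths are illustrative):

# Hypothetical call : with 2 shards on disk this returns ["/tmp/shards/train_0.conll"]
# or ["/tmp/shards/train_1.conll"], chosen uniformly at random.
next_training_file = get_new_shard("/tmp/shards", n_shards=2)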
def setup_repoting_location(root_dir_checkpoints,
                            model_suffix="",
                            shared_id=None,
                            data_sharded=None,
                            verbose=1):
    """
    create an id for a model and the locations for its checkpoints, dictionaries, tensorboard logs and data shards
    :param model_suffix: optional suffix appended to the generated model id
    :param verbose: verbosity level
    :return: model id and the created directory paths
    """
    model_local_id = str(uuid4())[:5]
    if shared_id is not None:
        if len(shared_id) > 0:
            model_local_id = shared_id + "-" + model_local_id
    if model_suffix != "":
        model_local_id += "-" + model_suffix
    model_location = os.path.join(root_dir_checkpoints, model_local_id)
    dictionaries = os.path.join(root_dir_checkpoints, model_local_id,
                                "dictionaries")
    tensorboard_log = os.path.join(root_dir_checkpoints, model_local_id,
                                   "tensorboard")
    end_predictions = os.path.join(root_dir_checkpoints, model_local_id,
                                   "predictions")

    os.mkdir(model_location)

    if data_sharded is None:
        data_sharded = os.path.join(root_dir_checkpoints, model_local_id,
                                    "shards")
        os.mkdir(data_sharded)
    else:
        assert os.path.isdir(
            data_sharded), "ERROR data_sharded not dir {} ".format(
                data_sharded)
        printing("INFO DATA already sharded in {}",
                 var=[data_sharded],
                 verbose=verbose,
                 verbose_level=1)
    printing("CHECKPOINTING model location:{}",
             var=[model_location],
             verbose=verbose,
             verbose_level=1)
    printing("CHECKPOINTING model ID:{}",
             var=[model_local_id],
             verbose=verbose,
             verbose_level=1)
    os.mkdir(dictionaries)
    os.mkdir(tensorboard_log)
    os.mkdir(end_predictions)
    printing(
        "CHECKPOINTING \n- {} for checkpoints \n- {} for dictionaries created \n- {} predictions {} ",
        var=[model_location, dictionaries, end_predictions, data_sharded],
        verbose_level=1,
        verbose=verbose)
    return model_local_id, model_location, dictionaries, tensorboard_log, end_predictions, data_sharded
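A sketch of the directory tree this creates under root_dir_checkpoints (the id prefix is illustrative; the sub-directory names are the ones used above):

# Hypothetical call : "/tmp/checkpoints" must already exist.
model_id, model_loc, dicts, tb_log, preds, shards = setup_repoting_location(
    root_dir_checkpoints="/tmp/checkpoints",
    model_suffix="demo",
    shared_id="exp1")
# Produces, e.g. for model_id "exp1-1a2b3-demo":
#   /tmp/checkpoints/exp1-1a2b3-demo/
#       dictionaries/
#       tensorboard/
#       predictions/
#       shards/        (only created when data_sharded is None)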
def get_early_stopping_metric(tasks,
                              verbose,
                              main_task=None,
                              early_stoppin_metric=None,
                              subsample_early_stoping_metric_val=None):
    """
    get the early stopping metric and the evaluation subsample
    if early_stoppin_metric is None : uses the first eval_metrics entry stated in TASKS_PARAMETER for the first task of the first list passed in args.tasks
    :return:
    """
    if main_task is None:
        printing(
            "INFO : default main task provided is the first of the first list {} ",
            var=[tasks],
            verbose=verbose,
            verbose_level=1)
        if isinstance(tasks[0], list):
            main_task = tasks[0][0]
        else:
            main_task = tasks[0]

    if early_stoppin_metric is None:
        early_stoppin_metric = TASKS_PARAMETER[main_task]["eval_metrics"][0][0]

        printing(
            "INFO : default early_stoppin_metric is early_stoppin_metric  {} first one of "
            "the first possible in TASK_PARAMETER",
            var=[early_stoppin_metric],
            verbose=verbose,
            verbose_level=1)

    if subsample_early_stoping_metric_val is None:
        get_subsample = TASKS_PARAMETER[main_task].get("default-subsample")
        if get_subsample is None:
            get_subsample = "all"
            printing(
                "INFO : early stopping subsample is set to default {} all as not found in {}",
                var=["all", TASKS_PARAMETER[main_task]],
                verbose=verbose,
                verbose_level=1)
        subsample_early_stoping_metric_val = get_subsample
        assert subsample_early_stoping_metric_val in TASKS_PARAMETER[main_task][
            "subsample-allowed"], "ERROR task {} : subsample {} not allowed".format(
                main_task, subsample_early_stoping_metric_val)
    #sanity_check_early_stop_metric(early_stoppin_metric, TASKS_PARAMETER, tasks)

    return early_stoppin_metric, subsample_early_stoping_metric_val
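The defaults are read from TASKS_PARAMETER; a sketch of the entry shape this function assumes for a task (keys taken from the lookups above, values purely illustrative):

# Hypothetical TASKS_PARAMETER entry consistent with the lookups above:
# TASKS_PARAMETER = {
#     "normalize": {
#         "eval_metrics": [["exact_match"]],          # first metric becomes the default early stopping metric
#         "default-subsample": "all",
#         "subsample-allowed": ["all", "NEED_NORM", "NORMED"],
#     },
# }
metric, subsample = get_early_stopping_metric(tasks=[["normalize"]], verbose=1)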
Example 7
def run(args,
        n_observation_max_per_epoch_train,
        vocab_size,
        model_dir,
        voc_tokenizer,
        auxilliary_task_norm_not_norm,
        null_token_index,
        null_str,
        tokenizer,
        n_observation_max_per_epoch_dev_test=None,
        run_mode="train",
        dict_path=None,
        end_predictions=None,
        report=True,
        model_suffix="",
        description="",
        saving_every_epoch=10,
        model_location=None,
        model_id=None,
        report_full_path_shared=None,
        skip_1_t_n=False,
        heuristic_test_ls=None,
        remove_mask_str_prediction=False,
        inverse_writing=False,
        extra_label_for_prediction="",
        random_iterator_train=True,
        bucket_test=False,
        must_get_norm_test=True,
        early_stoppin_metric=None,
        subsample_early_stoping_metric_val=None,
        compute_intersection_score_test=True,
        threshold_edit=3,
        name_with_epoch=False,
        max_token_per_batch=200,
        encoder=None,
        debug=False,
        verbose=1):
    """
    Wrapper for training/prediction/evaluation

    2 modes : train (trains using the train and dev iterators, then evaluates on test_path at the end)
              test  (only evaluates at the end : requires all directories to already exist)
    :return:
    """
    assert run_mode in ["train", "test"
                        ], "ERROR run mode {} corrupted ".format(run_mode)
    input_level_ls = ["wordpiece"]
    assert early_stoppin_metric is not None and subsample_early_stoping_metric_val is not None, "ERROR : early_stoppin_metric and subsample_early_stoping_metric_val should both be defined"
    if n_observation_max_per_epoch_dev_test is None:
        n_observation_max_per_epoch_dev_test = n_observation_max_per_epoch_train
    printing("MODEL : RUNNING IN {} mode",
             var=[run_mode],
             verbose=verbose,
             verbose_level=1)
    printing(
        "WARNING : casing was set to {} (this should be consistent at train and test)",
        var=[args.case],
        verbose=verbose,
        verbose_level=1)

    if len(args.tasks) == 1:
        printing("INFO : MODEL : 1 set of simultaneous tasks {}".format(
            args.tasks),
                 verbose=verbose,
                 verbose_level=1)

    if run_mode == "test":
        assert args.test_paths is not None and isinstance(
            args.test_paths, list)
    if run_mode == "train":
        printing("CHECKPOINTING info : "
                 "saving model every {}",
                 var=saving_every_epoch,
                 verbose=verbose,
                 verbose_level=1)

    use_gpu = use_gpu_(use_gpu=None, verbose=verbose)

    train_data_label = get_dataset_label(args.train_path, default="train")

    if not debug:
        pdb.set_trace = lambda: None

    iter_train = 0
    iter_dev = 0
    row = None
    writer = None

    printout_allocated_gpu_memory(verbose, "{} starting all".format(model_id))

    if run_mode == "train":
        if os.path.isdir(args.train_path[0]) and len(args.train_path) == 1:
            data_sharded = args.train_path[0]
            printing(
                "INFO args.train_path is directory so not rebuilding shards",
                verbose=verbose,
                verbose_level=1)
        elif os.path.isdir(args.train_path[0]):
            raise (Exception(
                "{} is a directory but len(args.train_path) is {} , not supported".
                format(args.train_path[0], len(args.train_path))))
        else:
            data_sharded = None
        assert model_location is None and model_id is None, "ERROR : model_location and model_id must be None as we are creating a new model"

        model_id, model_location, dict_path, tensorboard_log, end_predictions, data_sharded \
            = setup_repoting_location(model_suffix=model_suffix, data_sharded=data_sharded,
                                      root_dir_checkpoints=CHECKPOINT_BERT_DIR,
                                      shared_id=args.overall_label, verbose=verbose)
        hyperparameters = get_hyperparameters_dict(
            args,
            args.case,
            random_iterator_train,
            seed=args.seed,
            verbose=verbose,
            dict_path=dict_path,
            model_id=model_id,
            model_location=model_location)
        args_dir = write_args(model_location,
                              model_id=model_id,
                              hyperparameters=hyperparameters,
                              verbose=verbose)

        if report:
            if report_full_path_shared is not None:
                tensorboard_log = os.path.join(report_full_path_shared,
                                               "tensorboard")
            printing("tensorboard --logdir={} --host=localhost --port=1234 ",
                     var=[tensorboard_log],
                     verbose_level=1,
                     verbose=verbose)
            writer = SummaryWriter(log_dir=tensorboard_log)
            if writer is not None:
                writer.add_text("INFO-ARGUMENT-MODEL-{}".format(model_id),
                                str(hyperparameters), 0)
    else:
        args_checkpoint = json.load(open(args.init_args_dir, "r"))
        print(f"LOADING MODEL using {args_checkpoint}")
        dict_path = args_checkpoint["hyperparameters"]["dict_path"]
        if not os.path.isdir(dict_path):
            dict_path = os.path.join(args.init_args_dir, dict_path)
        assert dict_path is not None and os.path.isdir(
            dict_path), "ERROR {} ".format(dict_path)
        end_predictions = args.end_predictions
        assert end_predictions is not None and os.path.isdir(
            end_predictions), "ERROR end_predictions"
        model_location = args_checkpoint["hyperparameters"]["model_location"]
        model_id = args_checkpoint["hyperparameters"]["model_id"]
        assert model_location is not None and model_id is not None, "ERROR : model_location and model_id are required"
        args_dir = os.path.join(model_location,
                                "{}-args.json".format(model_id))

        printing(
            "CHECKPOINTING : starting writing log \ntensorboard --logdir={} --host=localhost --port=1234 ",
            var=[os.path.join(model_id, "tensorboard")],
            verbose_level=1,
            verbose=verbose)

    # build or make dictionaries
    _dev_path = args.dev_path if args.dev_path is not None else args.train_path
    word_dictionary, word_norm_dictionary, char_dictionary, pos_dictionary, \
    xpos_dictionary, type_dictionary = \
        conllu_data.load_dict(dict_path=dict_path,
                              train_path=args.train_path if run_mode == "train" else None,
                              dev_path=args.dev_path if run_mode == "train" else None,
                              test_path=None,
                              word_embed_dict={},
                              dry_run=False,
                              expand_vocab=False,
                              word_normalization=True,
                              force_new_dic=True if run_mode == "train" else False,
                              tasks=args.tasks,
                              pos_specific_data_set=args.train_path[1] if len(args.tasks) > 1 and len(args.train_path)>1 and "pos" in args.tasks else None,
                              case=args.case,
                              # if not normalize pos or parsing in tasks we don't need dictionary
                              do_not_fill_dictionaries=len(set(["normalize", "pos", "parsing"])&set([task for tasks in args.tasks for task in tasks])) == 0,
                              add_start_char=1 if run_mode == "train" else None,
                              verbose=1)
    # we flatten the tasks
    printing("DICTIONARY CREATED/LOADED", verbose=verbose, verbose_level=1)
    num_labels_per_task, task_to_label_dictionary = get_vocab_size_and_dictionary_per_task(
        [task for tasks in args.tasks for task in tasks],
        vocab_bert_wordpieces_len=vocab_size,
        pos_dictionary=pos_dictionary,
        type_dictionary=type_dictionary,
        task_parameters=TASKS_PARAMETER)
    voc_pos_size = num_labels_per_task["pos"] if "pos" in args.tasks else None
    if voc_pos_size is not None:
        printing("MODEL : voc_pos_size defined as {}",
                 var=voc_pos_size,
                 verbose_level=1,
                 verbose=verbose)
    printing("MODEL init...", verbose=verbose, verbose_level=1)
    tokenizer = tokenizer.from_pretrained(voc_tokenizer,
                                          do_lower_case=args.case == "lower")
    mask_id = tokenizer.convert_tokens_to_ids(
        tokenizer.mask_token)  #convert_tokens_to_ids([MASK_BERT])[0]
    printout_allocated_gpu_memory(verbose,
                                  "{} loading model ".format(model_id))
    model = get_model_multi_task_bert(args=args,
                                      model_dir=model_dir,
                                      encoder=encoder,
                                      num_labels_per_task=num_labels_per_task,
                                      mask_id=mask_id)

    if use_gpu:
        model.to("cuda")
        printing("MODEL TO CUDA", verbose=verbose, verbose_level=1)
    printing("MODEL model.config {} ",
             var=[model.config],
             verbose=verbose,
             verbose_level=1)
    printout_allocated_gpu_memory(verbose, "{} model loaded".format(model_id))
    model_origin = OrderedDict()
    pruning_mask = OrderedDict()
    printout_allocated_gpu_memory(verbose, "{} model cuda".format(model_id))
    for name, param in model.named_parameters():
        model_origin[name] = param.detach().clone()
        printout_allocated_gpu_memory(verbose, "{} param cloned ".format(name))
        if args.penalization_mode == "pruning":
            abs = torch.abs(param.detach().flatten())
            median_value = torch.median(abs)
            pruning_mask[name] = (abs > median_value).float()
        printout_allocated_gpu_memory(
            verbose, "{} pruning mask loaded".format(model_id))

    printout_allocated_gpu_memory(verbose, "{} model clone".format(model_id))

    inv_word_dic = word_dictionary.instance2index
    # load , mask, bucket and index data
    assert tokenizer is not None, "ERROR : tokenizer is None , voc_tokenizer failed to be loaded {}".format(
        voc_tokenizer)
    if run_mode == "train":
        time_load_readers_train_start = time.time()
        if not args.memory_efficient_iterator:

            data_sharded, n_shards, n_sent_dataset_total_train = None, None, None
            args_load_batcher_shard_data = None
            printing("INFO : startign loading readers",
                     verbose=verbose,
                     verbose_level=1)
            readers_train = readers_load(
                datasets=args.train_path,
                tasks=args.tasks,
                word_dictionary=word_dictionary,
                bert_tokenizer=tokenizer,
                word_dictionary_norm=word_norm_dictionary,
                char_dictionary=char_dictionary,
                pos_dictionary=pos_dictionary,
                xpos_dictionary=xpos_dictionary,
                type_dictionary=type_dictionary,
                word_decoder=True,
                run_mode=run_mode,
                add_start_char=1,
                add_end_char=1,
                symbolic_end=1,
                symbolic_root=1,
                bucket=True,
                must_get_norm=True,
                input_level_ls=input_level_ls,
                verbose=verbose)
            n_sent_dataset_total_train = readers_train[list(
                readers_train.keys())[0]][3]
            printing("INFO : done with sharding",
                     verbose=verbose,
                     verbose_level=1)
        else:
            printing("INFO : building/loading shards ",
                     verbose=verbose,
                     verbose_level=1)
            data_sharded, n_shards, n_sent_dataset_total_train = build_shard(
                data_sharded,
                args.train_path,
                n_sent_max_per_file=N_SENT_MAX_CONLL_PER_SHARD,
                verbose=verbose)

        time_load_readers_dev_start = time.time()
        time_load_readers_train = time.time() - time_load_readers_train_start
        readers_dev_ls = []
        dev_data_label_ls = []
        printing("INFO : g readers for dev", verbose=verbose, verbose_level=1)
        printout_allocated_gpu_memory(
            verbose, "{} reader train loaded".format(model_id))
        for dev_path in args.dev_path:
            dev_data_label = get_dataset_label(dev_path, default="dev")
            dev_data_label_ls.append(dev_data_label)
            readers_dev = readers_load(
                datasets=dev_path,
                tasks=args.tasks,
                word_dictionary=word_dictionary,
                word_dictionary_norm=word_norm_dictionary,
                char_dictionary=char_dictionary,
                pos_dictionary=pos_dictionary,
                xpos_dictionary=xpos_dictionary,
                bert_tokenizer=tokenizer,
                type_dictionary=type_dictionary,
                word_decoder=True,
                run_mode=run_mode,
                add_start_char=1,
                add_end_char=1,
                symbolic_end=1,
                symbolic_root=1,
                bucket=False,
                must_get_norm=True,
                input_level_ls=input_level_ls,
                verbose=verbose) if args.dev_path is not None else None
            readers_dev_ls.append(readers_dev)
        printout_allocated_gpu_memory(verbose,
                                      "{} reader dev loaded".format(model_id))

        time_load_readers_dev = time.time() - time_load_readers_dev_start
        # Load tokenizer
        printing("TIME : {} ",
                 var=[
                     OrderedDict([
                         ("time_load_readers_train",
                          "{:0.4f} min".format(time_load_readers_train / 60)),
                         ("time_load_readers_dev",
                          "{:0.4f} min".format(time_load_readers_dev / 60))
                     ])
                 ],
                 verbose=verbose,
                 verbose_level=1)

        early_stoping_val_former = 1000
        # training starts when epoch is 1
        #args.epochs += 1
        #assert args.epochs >= 1, "ERROR need at least 2 epochs (1 eval , 1 train 1 eval"
        flexible_batch_size = False

        if args.optimizer == "AdamW":
            model, optimizer, scheduler = apply_fine_tuning_strategy(
                model=model,
                fine_tuning_strategy=args.fine_tuning_strategy,
                lr_init=args.lr,
                betas=(0.9, 0.99),
                epoch=0,
                weight_decay=args.weight_decay,
                optimizer_name=args.optimizer,
                t_total=n_sent_dataset_total_train / args.batch_update_train *
                args.epochs if n_sent_dataset_total_train /
                args.batch_update_train * args.epochs > 1 else 5,
                verbose=verbose)

        try:
            for epoch in range(args.epochs):
                if args.memory_efficient_iterator:
                    # we start each epoch with a new shard every time
                    training_file = get_new_shard(data_sharded, n_shards)
                    printing(
                        "INFO Memory efficient iterator triggered (only build for train data , starting with {}",
                        var=[training_file],
                        verbose=verbose,
                        verbose_level=1)
                    args_load_batcher_shard_data = {
                        "word_dictionary": word_dictionary,
                        "tokenizer": tokenizer,
                        "word_norm_dictionary": word_norm_dictionary,
                        "char_dictionary": char_dictionary,
                        "pos_dictionary": pos_dictionary,
                        "xpos_dictionary": xpos_dictionary,
                        "type_dictionary": type_dictionary,
                        "use_gpu": use_gpu,
                        "norm_not_norm": auxilliary_task_norm_not_norm,
                        "word_decoder": True,
                        "add_start_char": 1,
                        "add_end_char": 1,
                        "symbolic_end": 1,
                        "symbolic_root": 1,
                        "bucket": True,
                        "max_char_len": 20,
                        "must_get_norm": True,
                        "use_gpu_hardcoded_readers": False,
                        "bucketing_level": "bpe",
                        "input_level_ls": ["wordpiece"],
                        "auxilliary_task_norm_not_norm":
                        auxilliary_task_norm_not_norm,
                        "random_iterator_train": random_iterator_train
                    }

                    readers_train = readers_load(
                        datasets=args.train_path if
                        not args.memory_efficient_iterator else training_file,
                        tasks=args.tasks,
                        word_dictionary=word_dictionary,
                        bert_tokenizer=tokenizer,
                        word_dictionary_norm=word_norm_dictionary,
                        char_dictionary=char_dictionary,
                        pos_dictionary=pos_dictionary,
                        xpos_dictionary=xpos_dictionary,
                        type_dictionary=type_dictionary,
                        word_decoder=True,
                        run_mode=run_mode,
                        add_start_char=1,
                        add_end_char=1,
                        symbolic_end=1,
                        symbolic_root=1,
                        bucket=True,
                        must_get_norm=True,
                        input_level_ls=input_level_ls,
                        verbose=verbose)

                checkpointing_model_data = (epoch % saving_every_epoch == 0
                                            or epoch == (args.epochs - 1))
                # build iterator on the loaded data
                printout_allocated_gpu_memory(
                    verbose, "{} loading batcher".format(model_id))
                time_load_batcher_start = time.time()

                if args.batch_size == "flexible":
                    flexible_batch_size = True
                    printing(
                        "INFO : args.batch_size {} so updating it based on mean value {}",
                        var=[
                            args.batch_size,
                            update_batch_size_mean(readers_train)
                        ],
                        verbose=verbose,
                        verbose_level=1)
                    args.batch_size = update_batch_size_mean(readers_train)
                    if args.batch_update_train == "flexible":
                        args.batch_update_train = args.batch_size
                    printing(
                        "TRAINING : backward pass every {} step of size {} in average",
                        var=[
                            int(args.batch_update_train // args.batch_size),
                            args.batch_size
                        ],
                        verbose=verbose,
                        verbose_level=1)
                    try:
                        assert isinstance(args.batch_update_train // args.batch_size, int)\
                           and args.batch_update_train // args.batch_size > 0, \
                            "ERROR batch_size {} should be a multiple of {} ".format(args.batch_update_train, args.batch_size)
                    except Exception as e:
                        print("WARNING {}".format(e))
                batchIter_train = data_gen_multi_task_sampling_batch(
                    tasks=args.tasks,
                    readers=readers_train,
                    batch_size=readers_train[list(readers_train.keys())[0]][4],
                    max_token_per_batch=max_token_per_batch
                    if flexible_batch_size else None,
                    word_dictionary=word_dictionary,
                    char_dictionary=char_dictionary,
                    pos_dictionary=pos_dictionary,
                    word_dictionary_norm=word_norm_dictionary,
                    get_batch_mode=random_iterator_train,
                    print_raw=False,
                    dropout_input=0.0,
                    verbose=verbose)

                # -|-|-
                printout_allocated_gpu_memory(
                    verbose, "{} batcher train loaded".format(model_id))
                batchIter_dev_ls = []
                batch_size_DEV = 1
                print(
                    "WARNING : batch_size for final eval was hardcoded and set to {}"
                    .format(batch_size_DEV))
                for readers_dev in readers_dev_ls:
                    batchIter_dev = data_gen_multi_task_sampling_batch(
                        tasks=args.tasks,
                        readers=readers_dev,
                        batch_size=batch_size_DEV,
                        word_dictionary=word_dictionary,
                        char_dictionary=char_dictionary,
                        pos_dictionary=pos_dictionary,
                        word_dictionary_norm=word_norm_dictionary,
                        get_batch_mode=False,
                        print_raw=False,
                        dropout_input=0.0,
                        verbose=verbose) if args.dev_path is not None else None
                    batchIter_dev_ls.append(batchIter_dev)
                # TODO add optimizer (if not : dev loss)
                model.train()
                printout_allocated_gpu_memory(
                    verbose, "{} batcher dev loaded".format(model_id))
                if args.optimizer != "AdamW":

                    model, optimizer, scheduler = apply_fine_tuning_strategy(
                        model=model,
                        fine_tuning_strategy=args.fine_tuning_strategy,
                        lr_init=args.lr,
                        betas=(0.9, 0.99),
                        weight_decay=args.weight_decay,
                        optimizer_name=args.optimizer,
                        t_total=n_sent_dataset_total_train /
                        args.batch_update_train *
                        args.epochs if n_sent_dataset_total_train /
                        args.batch_update_train * args.epochs > 1 else 5,
                        epoch=epoch,
                        verbose=verbose)
                printout_allocated_gpu_memory(
                    verbose, "{} optimizer loaded".format(model_id))
                loss_train = None
                if epoch > 0:
                    printing("TRAINING : training on GET_BATCH_MODE ",
                             verbose=verbose,
                             verbose_level=2)
                    printing(
                        "TRAINING {} training 1 'epoch' = {} observation size args.batch_"
                        "update_train (foward {} batch_size {} backward  "
                        "(every int(args.batch_update_train//args.batch_size) step if {})) ",
                        var=[
                            model_id, n_observation_max_per_epoch_train,
                            args.batch_size, args.batch_update_train,
                            args.low_memory_foot_print_batch_mode
                        ],
                        verbose=verbose,
                        verbose_level=1)
                    loss_train, iter_train, perf_report_train, _ = epoch_run(
                        batchIter_train,
                        tokenizer,
                        args=args,
                        model_origin=model_origin,
                        pruning_mask=pruning_mask,
                        task_to_label_dictionary=task_to_label_dictionary,
                        data_label=train_data_label,
                        model=model,
                        dropout_input_bpe=args.dropout_input_bpe,
                        writer=writer,
                        iter=iter_train,
                        epoch=epoch,
                        writing_pred=epoch == (args.epochs - 1),
                        dir_end_pred=end_predictions,
                        optimizer=optimizer,
                        use_gpu=use_gpu,
                        scheduler=scheduler,
                        predict_mode=(epoch - 1) % 5 == 0,
                        skip_1_t_n=skip_1_t_n,
                        model_id=model_id,
                        reference_word_dic={"InV": inv_word_dic},
                        null_token_index=null_token_index,
                        null_str=null_str,
                        norm_2_noise_eval=False,
                        early_stoppin_metric=None,
                        n_obs_max=n_observation_max_per_epoch_train,
                        data_sharded_dir=data_sharded,
                        n_shards=n_shards,
                        n_sent_dataset_total=n_sent_dataset_total_train,
                        args_load_batcher_shard_data=
                        args_load_batcher_shard_data,
                        memory_efficient_iterator=args.
                        memory_efficient_iterator,
                        verbose=verbose)

                else:
                    printing(
                        "TRAINING : skipping first epoch to start by evaluating on devs dataset0",
                        verbose=verbose,
                        verbose_level=1)
                printout_allocated_gpu_memory(
                    verbose, "{} epcoh train done".format(model_id))
                model.eval()
                if args.dev_path is not None:
                    print("RUNNING DEV on ITERATION MODE")
                    early_stoping_val_ls = []
                    loss_dev_ls = []
                    for i_dev, batchIter_dev in enumerate(batchIter_dev_ls):
                        loss_dev, iter_dev, perf_report_dev, early_stoping_val = epoch_run(
                            batchIter_dev,
                            tokenizer,
                            args=args,
                            epoch=epoch,
                            model_origin=model_origin,
                            pruning_mask=pruning_mask,
                            task_to_label_dictionary=task_to_label_dictionary,
                            iter=iter_dev,
                            use_gpu=use_gpu,
                            model=model,
                            writer=writer,
                            optimizer=None,
                            writing_pred=True,  #epoch == (args.epochs - 1),
                            dir_end_pred=end_predictions,
                            predict_mode=True,
                            data_label=dev_data_label_ls[i_dev],
                            null_token_index=null_token_index,
                            null_str=null_str,
                            model_id=model_id,
                            skip_1_t_n=skip_1_t_n,
                            dropout_input_bpe=0,
                            reference_word_dic={"InV": inv_word_dic},
                            norm_2_noise_eval=False,
                            early_stoppin_metric=early_stoppin_metric,
                            subsample_early_stoping_metric_val=
                            subsample_early_stoping_metric_val,
                            #case=case,
                            n_obs_max=n_observation_max_per_epoch_dev_test,
                            verbose=verbose)

                        printing(
                            "TRAINING : loss train:{} dev {}:{} for epoch {}  out of {}",
                            var=[
                                loss_train, i_dev, loss_dev, epoch, args.epochs
                            ],
                            verbose=1,
                            verbose_level=1)
                        printing("PERFORMANCE {} DEV {} {} ",
                                 var=[epoch, i_dev + 1, perf_report_dev],
                                 verbose=verbose,
                                 verbose_level=1)
                        early_stoping_val_ls.append(early_stoping_val)
                        loss_dev_ls.append(loss_dev)

                else:
                    loss_dev, iter_dev, perf_report_dev = None, 0, None
                    early_stoping_val_ls = [None]
                # NB : early_stoping_val is based on first dev set
                printout_allocated_gpu_memory(
                    verbose, "{} epoch dev done".format(model_id))
                early_stoping_val = early_stoping_val_ls[0]
                if checkpointing_model_data or early_stoping_val < early_stoping_val_former:
                    if early_stoping_val is not None:
                        _epoch = "best" if early_stoping_val < early_stoping_val_former else epoch
                    else:
                        print(
                            'WARNING early_stoping_val is None so saving based on checkpointing_model_data only'
                        )
                        _epoch = epoch
                    # model_id possibly enriched with some epoch information if name_with_epoch
                    _model_id = get_name_model_id_with_extra_name(
                        epoch=epoch,
                        _epoch=_epoch,
                        name_with_epoch=name_with_epoch,
                        model_id=model_id)
                    checkpoint_dir = os.path.join(
                        model_location, "{}-checkpoint.pt".format(_model_id))

                    if _epoch == "best":
                        printing(
                            "CHECKPOINT : SAVING BEST MODEL {} (epoch:{}) (new loss is {} former was {})"
                            .format(checkpoint_dir, epoch, early_stoping_val,
                                    early_stoping_val_former),
                            verbose=verbose,
                            verbose_level=1)
                        last_checkpoint_dir_best = checkpoint_dir
                        early_stoping_val_former = early_stoping_val
                        best_epoch = epoch
                        best_loss = early_stoping_val
                    else:
                        printing(
                            "CHECKPOINT : NOT SAVING BEST MODEL : new loss {} did not beat first loss {}"
                            .format(early_stoping_val,
                                    early_stoping_val_former),
                            verbose_level=1,
                            verbose=verbose)
                    last_model = ""
                    if epoch == (args.epochs - 1):
                        last_model = "last"
                    printing("CHECKPOINT : epoch {} saving {} model {} ",
                             var=[epoch, last_model, checkpoint_dir],
                             verbose=verbose,
                             verbose_level=1)
                    torch.save(model.state_dict(), checkpoint_dir)

                    args_dir = write_args(
                        dir=model_location,
                        checkpoint_dir=checkpoint_dir,
                        hyperparameters=hyperparameters
                        if name_with_epoch else None,
                        model_id=_model_id,
                        info_checkpoint=OrderedDict([
                            ("epochs", epoch + 1),
                            ("batch_size", args.batch_size
                             if not args.low_memory_foot_print_batch_mode else
                             args.batch_update_train),
                            ("train_path", train_data_label),
                            ("dev_path", dev_data_label_ls),
                            ("num_labels_per_task", num_labels_per_task)
                        ]),
                        verbose=verbose)
            if row is not None:
                update_status(row=row, value="training-done", verbose=1)
        except Exception as e:
            if row is not None:
                update_status(row=row, value="ERROR", verbose=1)
            raise (e)

    # reloading last (best) checkpoint
    if run_mode in ["train", "test"] and args.test_paths is not None:
        report_all = []
        if run_mode == "train" and args.epochs > 0:
            if use_gpu:
                model.load_state_dict(torch.load(last_checkpoint_dir_best))
                model = model.cuda()
                printout_allocated_gpu_memory(
                    verbose, "{} after reloading model".format(model_id))
            else:
                model.load_state_dict(
                    torch.load(last_checkpoint_dir_best,
                               map_location=lambda storage, loc: storage))
            printing(
                "MODEL : RELOADING best model of epoch {} with loss {} based on {}({}) metric (from checkpoint {})",
                var=[
                    best_epoch, best_loss, early_stoppin_metric,
                    subsample_early_stoping_metric_val,
                    last_checkpoint_dir_best
                ],
                verbose=verbose,
                verbose_level=1)

        model.eval()

        printout_allocated_gpu_memory(verbose,
                                      "{} starting test".format(model_id))
        for test_path in args.test_paths:
            assert len(test_path) == len(
                args.tasks), "ERROR test_path {} args.tasks {}".format(
                    test_path, args.tasks)
            for test, task_to_eval in zip(test_path, args.tasks):
                label_data = get_dataset_label([test], default="test")
                if len(extra_label_for_prediction) > 0:
                    label_data += "-" + extra_label_for_prediction

                readers_test = readers_load(
                    datasets=[test],
                    tasks=[task_to_eval],
                    word_dictionary=word_dictionary,
                    word_dictionary_norm=word_norm_dictionary,
                    char_dictionary=char_dictionary,
                    pos_dictionary=pos_dictionary,
                    xpos_dictionary=xpos_dictionary,
                    type_dictionary=type_dictionary,
                    bert_tokenizer=tokenizer,
                    word_decoder=True,
                    run_mode=run_mode,
                    add_start_char=1,
                    add_end_char=1,
                    symbolic_end=1,
                    symbolic_root=1,
                    bucket=bucket_test,
                    input_level_ls=input_level_ls,
                    must_get_norm=must_get_norm_test,
                    verbose=verbose)

                heuritics_zip = [None]
                gold_error_or_not_zip = [False]
                norm2noise_zip = [False]

                if heuristic_test_ls is None:
                    assert len(gold_error_or_not_zip) == len(
                        heuritics_zip) and len(heuritics_zip) == len(
                            norm2noise_zip)

                batch_size_TEST = 1

                print(
                    "WARNING : batch_size for final eval was hardcoded and set to {}"
                    .format(batch_size_TEST))
                for (heuristic_test, gold_error,
                     norm_2_noise_eval) in zip(heuritics_zip,
                                               gold_error_or_not_zip,
                                               norm2noise_zip):

                    assert heuristic_test is None and not gold_error and not norm_2_noise_eval

                    batchIter_test = data_gen_multi_task_sampling_batch(
                        tasks=[task_to_eval],
                        readers=readers_test,
                        batch_size=batch_size_TEST,
                        word_dictionary=word_dictionary,
                        char_dictionary=char_dictionary,
                        pos_dictionary=pos_dictionary,
                        word_dictionary_norm=word_norm_dictionary,
                        get_batch_mode=False,
                        dropout_input=0.0,
                        verbose=verbose)
                    try:
                        #pdb.set_trace()
                        loss_test, iter_test, perf_report_test, _ = epoch_run(
                            batchIter_test,
                            tokenizer,
                            args=args,
                            reader=readers_test,
                            iter=iter_dev,
                            use_gpu=use_gpu,
                            model=model,
                            task_to_label_dictionary=task_to_label_dictionary,
                            writer=None,
                            writing_pred=True,
                            optimizer=None,
                            args_dir=args_dir,
                            model_id=model_id,
                            dir_end_pred=end_predictions,
                            skip_1_t_n=skip_1_t_n,
                            predict_mode=True,
                            data_label=label_data,
                            epoch="LAST",
                            extra_label_for_prediction=label_data,
                            null_token_index=null_token_index,
                            null_str=null_str,
                            log_perf=False,
                            dropout_input_bpe=0,
                            norm_2_noise_eval=norm_2_noise_eval,
                            compute_intersection_score=
                            compute_intersection_score_test,
                            remove_mask_str_prediction=
                            remove_mask_str_prediction,
                            reference_word_dic={"InV": inv_word_dic},
                            threshold_edit=threshold_edit,
                            n_obs_max=n_observation_max_per_epoch_dev_test,
                            verbose=verbose)
                        print("LOSS TEST", loss_test)
                    except Exception as e:
                        print(
                            "ERROR (epoch_run test) {} test_path {} , heuristic {} , gold error {} , norm2noise {} "
                            .format(e, test, heuristic_test, gold_error,
                                    norm_2_noise_eval))
                        raise (e)
                    print("PERFORMANCE TEST on data  {} is {} ".format(
                        label_data, perf_report_test))
                    print("DATA WRITTEN {}".format(end_predictions))
                    if writer is not None:
                        writer.add_text(
                            "Accuracy-{}-{}-{}".format(model_id, label_data,
                                                       run_mode),
                            "After {} epochs with {} : performance is \n {} ".
                            format(args.epochs, description,
                                   str(perf_report_test)), 0)
                    else:
                        printing(
                            "WARNING : could not add accuracy to tensorboard cause writer was found None",
                            verbose=verbose,
                            verbose_level=1)
                    report_all.extend(perf_report_test)
                    printout_allocated_gpu_memory(
                        verbose, "{} test done".format(model_id))
    else:
        printing("ERROR : EVALUATION none cause {} empty or run_mode {} ",
                 var=[args.test_paths, run_mode],
                 verbose_level=1,
                 verbose=verbose)

    if writer is not None:
        writer.close()
        printing("tensorboard --logdir={} --host=localhost --port=1234 ",
                 var=[tensorboard_log],
                 verbose_level=1,
                 verbose=verbose)

    report_dir = os.path.join(model_location, model_id + "-report.json")
    if report_full_path_shared is not None:
        report_full_dir = os.path.join(report_full_path_shared,
                                       args.overall_label + "-report.json")
        if os.path.isfile(report_full_dir):
            report = json.load(open(report_full_dir, "r"))
        else:
            report = []
            printing("REPORT = creating overall report at {} ",
                     var=[report_full_dir],
                     verbose=verbose,
                     verbose_level=1)
        report.extend(report_all)
        json.dump(report, open(report_full_dir, "w"))
        printing("{} {} ",
                 var=[REPORT_FLAG_DIR_STR, report_full_dir],
                 verbose=0,
                 verbose_level=0)

    json.dump(report_all, open(report_dir, "w"))
    printing("REPORTING TO {}".format(report_dir),
             verbose=verbose,
             verbose_level=1)
    if report_full_path_shared is None:
        printing("WARNING ; report_full_path_shared is None",
                 verbose=verbose,
                 verbose_level=1)
        printing("{} {} ",
                 var=[REPORT_FLAG_DIR_STR, report_dir],
                 verbose=verbose,
                 verbose_level=0)

    return model
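A heavily abbreviated sketch of how run() might be invoked in train mode; every value below is an assumption standing in for objects built elsewhere in the pipeline, not the repository's confirmed call site:

# Hypothetical call (kept commented out) : args is the argparse namespace built elsewhere,
# tokenizer/vocab_size/model_dir/voc_tokenizer come from the pretrained encoder being fine-tuned.
# model = run(args,
#             n_observation_max_per_epoch_train=50000,
#             vocab_size=vocab_size,
#             model_dir=model_dir,
#             voc_tokenizer=voc_tokenizer,
#             auxilliary_task_norm_not_norm=True,
#             null_token_index=null_token_index,
#             null_str=null_str,
#             tokenizer=tokenizer,
#             early_stoppin_metric=metric,
#             subsample_early_stoping_metric_val=subsample,
#             run_mode="train",
#             report=True,
#             verbose=1)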
def epoch_run(batchIter, tokenizer,
              args,
              iter, n_obs_max, model, epoch,
              use_gpu, data_label, null_token_index, null_str,
              model_id, early_stoppin_metric=None,reader=None,
              skip_1_t_n=True,
              writer=None, optimizer=None,
              predict_mode=False, topk=None, metric=None,
              args_dir=None,
              reference_word_dic=None, dropout_input_bpe=0.,
              writing_pred=False, dir_end_pred=None, extra_label_for_prediction="",
              log_perf=True, remove_mask_str_prediction=False,
              norm_2_noise_eval=False,
              compute_intersection_score=False,
              subsample_early_stoping_metric_val=None,
              threshold_edit=None,
              ponderation_loss_policy="static",
              samples_per_task_reporting=None,
              task_to_eval=None, task_to_label_dictionary=None,
              data_sharded_dir=None, n_shards=None, n_sent_dataset_total=None, args_load_batcher_shard_data=None,
              memory_efficient_iterator=False, model_origin=None,pruning_mask=None,  scheduler=None,
              verbose=0):
    """
    About Evaluation :
    Logic : compare gold and top-k predictions using a word level scoring function,
            then accumulate per sentence and per batch to get a global score.
            Can add a SAMPLE parameter to get scores on specific subsamples of the data : e.g. NEED_NORM, NORMED...
            Can also use different aggregation functions.
    """
    if optimizer is not None:
        # we need it to track distance in all cases
        # if training mode and penalize we need mode_origin
        assert model_origin is not None
    assert task_to_label_dictionary is not None, "ERROR : task_to_label_dictionary should be defined "

    if samples_per_task_reporting is None:
        samples_per_task_reporting = SAMPLES_PER_TASK_TO_REPORT
    if task_to_eval is not None:
        args.tasks = task_to_eval
        assert task_to_eval in task_to_label_dictionary, "ERROR : {} label was not provided in {}".format(task_to_eval, task_to_label_dictionary)
        printing("WARNING : task_to_eval was provided ", verbose=verbose, verbose_level=1)
    if ponderation_loss_policy == "static":
        assert args.multi_task_loss_ponderation is not None
    else:
        raise(Exception("Only static strategy supported so far"))

    if args.low_memory_foot_print_batch_mode:
        assert args.batch_update_train > 0, "ERROR : batch_update_train must be > 0 in low_memory_foot_print_batch_mode"

    if args.heuristic_ls is not None:
        for edit_rule in ["all", "ref", "data"]:
            if "edit_check-"+edit_rule in args.heuristic_ls:
                assert threshold_edit is not None, "ERROR threshold_edit required as args.heuristic_ls is {}".format(args.heuristic_ls)

    if args.case is not None:
        AVAILABLE_CASE_OPTIONS = ["lower"]
        assert args.case in AVAILABLE_CASE_OPTIONS
    assert args.norm_2_noise_training is None or not norm_2_noise_eval, "only one of the two should be triggered but we have args.norm_2_noise_training : {} norm_2_noise_eval:{}".format(args.norm_2_noise_training, norm_2_noise_eval)
    if args.norm_2_noise_training is not None:
        printing("WARNING : {} args.norm_2_noise_training is on ", var=[args.norm_2_noise_training],
                 verbose=verbose, verbose_level=1)
    if norm_2_noise_eval:
        printing("WARNING : {} norm_2_noise_eval is on ", var=[norm_2_noise_eval],
                 verbose=verbose, verbose_level=1)
    assert len(args.tasks) <= 2
    evaluated_task = []
    skip_score = 0
    skipping = 0
    mean_end_pred = 0
    label_heuristic = ""
    if memory_efficient_iterator:
        assert data_sharded_dir is not None and n_shards is not None, "ERROR data_sharded_dir and n_shards needed as args.memory_efficient_iterator {}".format(memory_efficient_iterator)
        assert n_sent_dataset_total is not None
    printing("INFO : HEURISTIC used {} {}", var=[args.heuristic_ls, label_heuristic], verbose=verbose, verbose_level=1)
    if predict_mode:
        if topk is None:
            topk = 1
            printing("PREDICTION MODE : setting top-k to default 1 ", verbose_level=1, verbose=verbose)
        print_pred = False
        if metric is None:
            metric = "exact_match"
            printing("PREDICTION MODE : setting metric to default 'exact_match' ", verbose_level=1, verbose=verbose)

    if writing_pred:
        assert dir_end_pred is not None
        if extra_label_for_prediction != "":
            extra_label_for_prediction = "-"+extra_label_for_prediction
        extra_label_for_prediction += "-"+label_heuristic
        dir_normalized = os.path.join(dir_end_pred, "{}_ep-prediction{}.conll".format(epoch,
                                                                                      extra_label_for_prediction))
        dir_normalized_original_only = os.path.join(dir_end_pred, "{}_ep-prediction_src{}.conll".format(epoch,
                                                                                                        extra_label_for_prediction))
        dir_gold = os.path.join(dir_end_pred, "{}_ep-gold-{}.conll".format(epoch,
                                                                          extra_label_for_prediction))
        dir_gold_original_only = os.path.join(dir_end_pred, "{}_ep-gold_src{}.conll".format(epoch,
                                                                                            extra_label_for_prediction))

    mask_token_index = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
    cls_token_index = tokenizer.convert_tokens_to_ids(tokenizer.cls_token)
    sep_token_index = tokenizer.convert_tokens_to_ids(tokenizer.sep_token)
    pad_token_index = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)

    #space_token_index = tokenizer.convert_tokens_to_ids([null_str])[0]
    printing("ITERATOR : {} : {} {} : {} {} : {} {} : {}", var=[tokenizer.mask_token, mask_token_index, tokenizer.cls_token, cls_token_index, tokenizer.sep_token, sep_token_index, tokenizer.pad_token, pad_token_index],
             verbose=verbose, verbose_level=1)
    printing("ITERATOR : PAD TAG {} PAD HEADS {}", var=[PAD_ID_TAG, PAD_ID_HEADS], verbose=verbose, verbose_level=1)
    batch_i = 0
    noisy_over_splitted = 0
    noisy_under_splitted = 0
    aligned = 0
    skipping_batch_n_to_1 = 0
    n_obs_forwarded = 0
    n_obs_backward = 0
    n_obs_backward_save = 0
    n_obs_forwarded_not_backwarded = 0
    backprop_step = 0
    loss = 0
    penalize = 0
    penalization_dic = None
    report_penalization = False

    agg_func_ls = ["sum"]
    printout_allocated_gpu_memory(verbose=verbose, comment="starting epoch")
    score_dic, n_tokens_dic, n_sents_dic = init_score_token_sent_dict(samples_per_task_reporting, [task for tasks in args.tasks for task in tasks],
                                                                      agg_func_ls, compute_intersection_score,
                                                                      task_settings=TASKS_PARAMETER)

    _samples_per_task_reporting = list(samples_per_task_reporting.keys())+["all"]

    n_tokens_counter_per_task = OrderedDict((a, 0) for a in _samples_per_task_reporting)

    loss_dic_epoch = OrderedDict((a, 0) for a in _samples_per_task_reporting)
    # vocab_index_except_pad_cls_sep = [i for i in range(1, len(tokenizer.vocab)) if i not in [mask_token_index, sep_token_index, cls_token_index]]
    # pad is the first index
    skipping_evaluated_batch = 0
    mode = "?"
    new_file = True
    loss_norm = 0
    loss_pos = 0
    loss_n_mask_prediction = 0
    n_batch_pos = 0
    n_batch_norm = 0
    n_task_normalize_sanity = 0

    counting_failure_parralel_bpe_batch = 0

    time_multitask_train = 0
    time_backprop = 0
    time_multitask_preprocess_1 = 0
    time_multitask_preprocess_2 = 0
    time_multitask_postprocess = 0
    time_score = 0
    time_penalize = 0
    time_write_pred = 0
    backprop_step_former = -1
    time_overall_pass = time.time()
    end_schedule_lr = 0

    n_shard = 0

    while True:
        try:
            if memory_efficient_iterator and n_obs_forwarded >= n_sent_dataset_total:
                printing("BREAKING ALL ITERATORS memory_efficient_iterator True (mode is {}  shard {} ending) ",
                         var=[mode, n_shard], verbose_level=1, verbose=1)
                break
            batch_i += 1

            time_multitask_preprocess_start = time.time()
            
            start_schedule = time.time()

            if args.schedule_lr is not None and optimizer is not None:
                assert args.optimizer != "AdamW", "ERROR schedule_lr not supported in AdamW"
                assert args.fine_tuning_strategy == "standart",\
                    "ERROR only fine_tuning_strategy standart supported in shedule mode but is {} ".format(args.fine_tuning_strategy)

                def get_schedule_lr(args, i_step):
                    warmup_init_lr = 0.0000001
                    assert isinstance(args.n_steps_warmup, int) and args.n_steps_warmup>0, "ERROR n_steps_warmup {} ".format(args.n_steps_warmup)
                    #args.n_steps_warmup = 100
                    lr_step = (args.lr - warmup_init_lr) / args.n_steps_warmup
                    if i_step < args.n_steps_warmup:
                        lr = warmup_init_lr + i_step * lr_step
                    else:
                        lr = args.lr * (args.n_steps_warmup / i_step)**0.5
                    print("UPDATING OPTIMIZER WITH LR {} based on {} step , step warmup {} lr_step {} backprop ({} warming up )".format(lr, i_step, args.n_steps_warmup, lr_step, i_step < args.n_steps_warmup))
                    return lr
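                # The schedule above warms the LR up linearly from warmup_init_lr to args.lr
                # over n_steps_warmup backprop steps, then decays it with an inverse square root:
                # lr = args.lr * sqrt(n_steps_warmup / i_step).
                # e.g. with args.lr = 1e-4 and n_steps_warmup = 100: step 50 -> ~5e-5 (warmup),
                # step 100 -> 1e-4 (peak), step 400 -> 1e-4 * sqrt(100/400) = 5e-5 (decay).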
                if backprop_step != backprop_step_former:
                    lr = get_schedule_lr(args, i_step=backprop_step+1)
                    backprop_step_former = backprop_step
                    writer.add_scalars("opt/lr-schedule",  {"lr_model-{}".format(model_id): lr}, backprop_step)
                    writer.add_scalars("opt/lr-schedule2", {"lr_model-{}".format(model_id): lr}, backprop_step)
                    _, optimizer = apply_fine_tuning_strategy(model=model, fine_tuning_strategy=args.fine_tuning_strategy,lr_init=lr, betas=(0.9, 0.99),weight_decay=args.weight_decay,epoch=epoch, verbose=verbose)
            end_schedule_lr += time.time()-start_schedule

            batch = batchIter.__next__()
            # Normalization task is handled separately
            # lower-case the batch if args.case is 'lower'
            batch = get_casing(args.case, batch, False, cls_token=tokenizer.cls_token, sep_token=tokenizer.sep_token)
            #n_task_normalize_sanity += int(task_normalize_is)
            # handling normalization input
            time_multitask_preprocess_start = time.time()
            printout_allocated_gpu_memory(verbose=verbose, comment="starting step")
            # added but are related to the old flow
            #batch_raw_input, norm2noise_bool, args.norm_2_noise_training = input_normalization_processing(False, batch, args.norm_2_noise_training, False)

            #input_tokens_tensor, input_segments_tensors, inp_bpe_tokenized, input_alignement_with_raw, input_mask = \
            #    get_indexes(batch_raw_input, tokenizer, verbose, use_gpu, word_norm_not_norm=None)
            #input_mask = get_mask_input(input_tokens_tensor, use_gpu)
            #print(batch.)

            head_masks, input_tokens_tensor, token_type_ids, label_per_task,\
            input_tokens_tensor_per_task, input_mask_per_task, cumulate_shift_sub_word = get_label_per_bpe(tasks=args.tasks, batch=batch,
                                                                                          pad_index=pad_token_index,
                                                                                          use_gpu=use_gpu, tasks_parameters=TASKS_PARAMETER,
                                                                                          input_alignement_with_raw=None,
                                                                                          input_tokens_tensor=None,
                                                                                          masking_strategy=args.masking_strategy,
                                                                                          vocab_len=BERT_MODEL_DIC[args.bert_model]["vocab_size"],#len(tokenizer.vocab)-2,
                                                                                          mask_token_index=mask_token_index,
                                                                                          sep_token_index=sep_token_index,
                                                                                          cls_token_index=cls_token_index,
                                                                                          dropout_input_bpe=dropout_input_bpe)


            # NB : token_type_ids not used in MultiTask (not needed, just use 0 everywhere)
            #dimension_check_label(label_per_task, input_tokens_tensor)
            time_multitask_preprocess_1 += time.time()-time_multitask_preprocess_start
            printout_allocated_gpu_memory(verbose=verbose, comment="got input/output")
            # NB : we use the aligned input with the

            _1_to_n_token = 0

            if n_obs_forwarded >= n_obs_max:# and not args.low_memory_foot_print_batch_mode) or (batch_i == n_iter_max * int(args.batch_update_train // args.batch_size) and args.low_memory_foot_print_batch_mode):
                print("BREAKING ITERATION model {} because {} n_obs_max reached  (n_obs_forwarded {})".format(model_id, n_obs_max, n_obs_forwarded))
                break
            if batch_i % 1000 == 0:
                printing("TRAINING : iteration finishing {} batch", var=[batch_i], verbose=verbose, verbose_level=1)
            if _1_to_n_token:
                skipping_batch_n_to_1 += _1_to_n_token
                #continue
            # sanity checking alignement
            # we consider only 1 sentence case
            #printing("CUDA SANITY CHECK input_tokens:{}  type:{}input_mask:{}  label:{}", var=[input_tokens_tensor.is_cuda, token_type_ids.is_cuda, input_mask.is_cuda, output_tokens_tensor_aligned.is_cuda], verbose=verbose, verbose_level="cuda")
            # we have to recompute the mask based on aligned input

            assert args.masking_strategy is None, "ERROR : {} not supported in multitask mode ".format(args.masking_strategy)

            # multitask :
            time_multitask_preprocess_2_start = time.time()

            n_tokens_counter_per_task, n_tokens_counter_current_per_task, n_tokens_all = count_tokens([task for tasks in args.tasks for task in tasks],
                                                                                                      n_tokens_counter_per_task, label_per_task, LABEL_PARAMETER)
            n_tokens_counter_per_task["all"] += n_tokens_all

            time_multitask_preprocess_2 += time.time()-time_multitask_preprocess_2_start
            time_multitask_train_start = time.time()

            logits_dic, loss_dic, _ = model(input_tokens_tensor_per_task, token_type_ids=None, labels=label_per_task, head_masks=head_masks, attention_mask=input_mask_per_task)
            printout_allocated_gpu_memory(verbose=verbose, comment="feedforward done")
            # loss_dic_epoch is the sum over all the epoch (mean computed for reporting)
            loss_dic_epoch = update_loss_dic_average(loss_dic, loss_dic_epoch)
            loss_dic = loss_mean(loss_dic, n_tokens_counter_current_per_task)

            if predict_mode:
                predictions_topk_dic = get_prediction(logits_dic, topk=topk)
                printout_allocated_gpu_memory(verbose=verbose, comment="prediction done")
                time_multitask_train += time.time()-time_multitask_train_start
                time_multitask_postprocess_start = time.time()
                assert "normalize" not in args.tasks, "ERROR : following line () was needed apparently for normalize being supported"
                # for parsing heads will leave heads untouched
                # POSTPROCESSING : get back to untokenized string
                source_preprocessed_dict, label_dic, predict_dic = get_bpe_string(predictions_topk_dic, label_per_task,
                                                                                  input_tokens_tensor_per_task, topk,
                                                                                  tokenizer, task_to_label_dictionary, #null_str, null_token_index,
                                                                                  TASKS_PARAMETER, mask_token_index, verbose)
                # input_tokens_tensor=eval("batch.{}".format(tasks_parameters[task]["input"])).clone(),
                # input_alignement_with_raw=eval("batch.{}_alignement".format(tasks_parameters[task]["input"]))
                #print("source", source_preprocessed_dict)

                src_detokenized_dic, label_detokenized_dic, predict_detokenize_dic = get_detokenized_str(source_preprocessed_dict=source_preprocessed_dict,
                                                                                                         input_alignement_with_raw=eval("batch.{}_alignement".format(TASKS_PARAMETER["pos"]["input"])),
                                                                                                         label_dic=label_dic, predict_dic=predict_dic,
                                                                                                         #remove_mask_str_prediction=remove_mask_str_prediction, null_str=null_str,
                                                                                                         # batch=batch,
                                                                                                         task_settings=TASKS_PARAMETER,
                                                                                                         flag_word_piece_token=BERT_MODEL_DIC[args.bert_model].get("wordpiece_flag",  "##"),
                                                                                                         flag_is_first_token=BERT_MODEL_DIC[args.bert_model].get("flag_is_first_token",0),
                                                                                                         # BERT_MODEL_DIC[args.bert_model].get("flag_is_first_token", 0),
                                                                                                         mask_str=tokenizer.mask_token, end_token=tokenizer.sep_token,
                                                                                                         cumulate_shift_sub_word=cumulate_shift_sub_word)

                if "parsing" in args.tasks[0]:
                    assert label_detokenized_dic["heads"]
                    try:
                        for _ind, (gold_rebuild, gold) in enumerate(zip(label_detokenized_dic["heads"][0], batch.heads[0])):
                            if gold != -1:
                                assert gold_rebuild == gold, "ERROR {}: {} {},  " \
                                                             "label_detokenized_dic[heads] {} " \
                                                             "and batch.heads[0]) {}".format(_ind, gold_rebuild, gold,
                                                                                             label_detokenized_dic["heads"],
                                                                                             batch.heads[0])
                        print("HEADS SANITY CHECK PASSED")
                    except Exception as e:
                        print("HEADS SANITY CHECK FAILED")
                        raise e
                log_data_src_label_pred(src_detokenized_dic, predict_detokenize_dic, label_detokenized_dic, tasks=args.tasks, verbose=verbose, verbose_level=2)
                printout_allocated_gpu_memory(verbose=verbose, comment="got string")
                label_processed = []
                time_multitask_postprocess += time.time() - time_multitask_postprocess_start
                # SCORING : get back to untokenized string
                time_score_start = time.time()
                for label_pred in predict_detokenize_dic:
                    label, _, _continue, label_processed = get_task_name_based_on_logit_label(label_pred, label_processed)
                    if _continue:
                        continue

                    task = re.match("(.*)-.*", label_pred).group(1)
                    src_detokenized = src_detokenized_dic[TASKS_PARAMETER[task]["input"]]
                    filter_score = samples_per_task_reporting[label_pred]
                    if label_detokenized_dic[label] is not None:
                        perf_prediction, skipping, _samples = overall_word_level_metric_measure(task_label=label, pred_label=label_pred,
                                                                                            gold_sent_ls_dict=label_detokenized_dic,
                                                                                            pred_sent_ls_topk_dict=predict_detokenize_dic,
                                                                                            topk=topk,
                                                                                            metric=metric,
                                                                                            samples=filter_score,
                                                                                            agg_func_ls=agg_func_ls,
                                                                                            reference_word_dic=reference_word_dic,
                                                                                            compute_intersection_score=compute_intersection_score,
                                                                                            mask_token=tokenizer.mask_token,
                                                                                            cls_token=tokenizer.cls_token,
                                                                                            sep_token=tokenizer.sep_token,
                                                                                            src_detokenized=src_detokenized)
                    else:
                        perf_prediction = {"score": 0,
                                           "agg_func": "sum", "metric": "exact_match",
                                           "n_tokens": 0,
                                           "n_sents": 0,}
                        skipping = 0
                        _samples = ["all"]

                    printing("PREDICTION epoch {} task {} score all {}/{} total "
                             "gold {} gold token {} pred {} pred token {} ",
                             var=[epoch, label, perf_prediction["sum"]["all"]["score"],
                                  perf_prediction["sum"]["all"]["n_tokens"],
                                  label_detokenized_dic[label], label_per_task[label],
                                  predict_detokenize_dic[label_pred], predictions_topk_dic[label_pred][:, :, 0]],
                             verbose=verbose, verbose_level="pred")

                    score_dic[label_pred], n_tokens_dic[label_pred], n_sents_dic[label_pred] = \
                        accumulate_scores_across_sents(agg_func_ls=agg_func_ls,
                                                       sample_ls=_samples, dic_prediction_score=perf_prediction,
                                                       score_dic=score_dic[label_pred],
                                                       n_tokens_dic=n_tokens_dic[label_pred],
                                                       n_sents_dic=n_sents_dic[label_pred])

                    evaluated_task.append(label_pred)
                # WRITING PREDICTIONS
                time_score += time.time()-time_score_start
                time_write_pred_start = time.time()
                if writing_pred:
                    # get the right sentence index and handle batch size > 1
                    # (raw comment/sent-id info lives in reader[task][0][0][-1][batch_i])
                    batch_sze = len(batch.raw_input)
                    raw_tests = []
                    sent_ids = []
                    append_mwe_ind = []
                    append_mwe_row = []
                    empty_mwe = True
                    for i_sent in range(batch_sze):
                        _append_mwe_row = []
                        _append_mwe_ind = []
                        # get raw sentence and index for the batch
                        comment_sent_i = reader[task][0][0][-1][batch_i-1+i_sent][-1]
                        raw_tests.append(comment_sent_i[0])
                        sent_ids.append(comment_sent_i[1])
                        # look for multi-word expressions (MWE rows)
                        for word in reader[task][0][0][-1][batch_i-1+i_sent][0]:
                            if "-" in word[0]:
                                _append_mwe_row.append("\t".join(word)+"\n")
                                _append_mwe_ind.append(int(word[0].split("-")[0]))
                                empty_mwe = False
                        append_mwe_row.append(_append_mwe_row)
                        append_mwe_ind.append(_append_mwe_ind)
                    if empty_mwe:
                        append_mwe_row = None
                        append_mwe_ind = None
                    new_file = writing_predictions_conll_multi(
                                            dir_pred=dir_normalized,
                                            append_mwe_ind=append_mwe_ind,
                                            append_mwe_row=append_mwe_row,
                                            sent_ids=sent_ids, raw_texts=raw_tests,
                                            dir_normalized_original_only=dir_normalized_original_only,
                                            dir_gold=dir_gold, dir_gold_original_only=dir_gold_original_only,
                                            src_detokenized=src_detokenized_dic, pred_per_task=predict_detokenize_dic,
                                            iter=iter, batch_i=batch_i, new_file=new_file, gold_per_tasks=label_detokenized_dic,
                                            all_indexes=batch.all_indexes, task_parameters=TASKS_PARAMETER,
                                            cls_token=tokenizer.cls_token, sep_token=tokenizer.sep_token,
                                            tasks=args.tasks, verbose=verbose)
                time_write_pred += time.time() - time_write_pred_start
                printout_allocated_gpu_memory(verbose=verbose, comment="got score")
            report_penalization = optimizer is not None or (optimizer is None and epoch == 0)

            if report_penalization and args.ponderation_per_layer is not None:
                # NB : report_penalization is required if we want to optimize using the penalization
                time_get_penalize = time.time()
                penalize, penalization_dic = get_penalization(norm_order_per_layer=args.norm_order_per_layer,
                                                              ponderation_per_layer=args.ponderation_per_layer,
                                                              model_parameters=model.named_parameters(),
                                                              model_parameters_0=model_origin,
                                                              penalization_mode=args.penalization_mode,
                                                              pruning_mask=pruning_mask)

                printout_allocated_gpu_memory(verbose=verbose, comment="got penalization")

                if not args.penalize:
                    penalize = 0
                time_penalize += time.time() - time_get_penalize

            _loss = get_loss_multitask(loss_dic, args.multi_task_loss_ponderation)
            loss_dic["multitask"] = _loss.detach().clone().cpu()
            _loss += penalize
            loss_dic["all"] = _loss
            # temporary
            # based on a policy : handle batch, epoch, batch weights, simultaneously
            # assert the policy is consistent with the available labels fed to the model
            # training :
            time_backprop_start = time.time()
            loss += _loss.detach()
            # BACKWARD PASS
            # batch_i is the iteration counter
            #back_pass = optimizer is not None and ((args.low_memory_foot_print_batch_mode and batch_i % int(args.batch_update_train // args.batch_size) == 0) or not args.low_memory_foot_print_batch_mode)
            back_pass = optimizer is not None and ((args.low_memory_foot_print_batch_mode and n_obs_forwarded >= args.batch_update_train) or not args.low_memory_foot_print_batch_mode)
            n_obs_forwarded_not_backwarded += input_tokens_tensor_per_task[list(input_tokens_tensor_per_task.keys())[0]].size(0)
            n_obs_forwarded += input_tokens_tensor_per_task[list(input_tokens_tensor_per_task.keys())[0]].size(0)
            if optimizer is not None:
                mode = "train"

                _loss.backward()

                printout_allocated_gpu_memory(verbose, "loss backwarded")
                if (args.low_memory_foot_print_batch_mode and n_obs_forwarded >= args.batch_update_train) or not args.low_memory_foot_print_batch_mode:
                    n_obs_backward = n_obs_forwarded_not_backwarded
                    n_obs_backward_save += n_obs_forwarded_not_backwarded
                    n_obs_forwarded_not_backwarded = 0
                    backprop_step += 1
                    if args.low_memory_foot_print_batch_mode:
                        printing("OPTIMIZING in low_memory_foot_print_batch_mode cause batch index {}"
                                 "we update every {} batch_update_train {} batch_size to get batch backward pass of size {}",
                                 var=[batch_i, args.batch_update_train, args.batch_size, args.batch_update_train],
                                 verbose=verbose, verbose_level=1)
                    for opti in optimizer:
                        opti.step()
                        if scheduler is not None:
                            printing("OPTIMIZING : updating scheduler current step {} backward pass {} lr {}  init lr {}",
                                     var=[batch_i, backprop_step, opti.param_groups[0]["lr"],
                                          opti.param_groups[0]["initial_lr"]],
                                     verbose=verbose, verbose_level=1)
                            scheduler.step()

                        opti.zero_grad()
                printout_allocated_gpu_memory(verbose=verbose, comment="optimized")
            else:
                mode = "dev"
            time_backprop += time.time() - time_backprop_start

            #if writer is not None:
            #    tensorboard_loss_writer_batch_level(writer, mode, model_id, _loss, batch_i, iter,  loss_dic,
            #                                        False, args.append_n_mask)
            #    tensorboard_loss_writer_batch_level_multi(writer, mode, model_id, _loss, batch_i, iter, loss_dic, tasks=args.tasks)

        except StopIteration:
            printing("BREAKING ITERATION model {} -  {} iter - {} n_obs_forwarded  -  {} n_obs_backward or {} step of {} batch_update_train"
                     "(mode is {} memory_efficient_iterator {} , shard {} ending ",
                     var=[model_id, batch_i, n_obs_forwarded, n_obs_backward_save, backprop_step, args.batch_update_train, mode, memory_efficient_iterator, n_shard],
                     verbose_level=1, verbose=1)
            if optimizer is not None:
                assert n_obs_backward > 0, "ERROR : train mode but did not backpropagate anything "
            n_shard += 1
            if not memory_efficient_iterator:
                break
            training_file = get_new_shard(data_sharded_dir, n_shards, verbose=verbose)
            printing("ITERATOR shard model {} - epoch {} , n observations forwarded {} "
                     "batch {} (n_ob_max {}) starting new {} ".format(model_id, epoch, n_obs_forwarded, batch_i, n_obs_max, training_file),
                     verbose=verbose, verbose_level=1)
            batchIter = load_batcher_shard_data(args, args_load_batcher_shard_data, training_file, verbose)

    overall_pass = time.time()-time_overall_pass

    printing("TIME epoch {}/{} done mode_id {}  {:0.3f}/{:0.3f} min "
             "({} forward obs {} backward obs) for {} iteration of {} batch_size in {} (fixed) averaged if flexible {} "
             " mode out of {} sent total or {} steps : {} min/batch {} min/sent",
             var=[epoch, args.epochs,
                  model_id,
                  overall_pass/60,n_sent_dataset_total * (overall_pass / 60) / n_obs_forwarded if n_sent_dataset_total is not None else 0,
                  n_obs_backward, n_obs_forwarded,
                  batch_i, args.batch_size, args.batch_size,
                  "train" if optimizer is not None else "dev", str(n_sent_dataset_total),
                  n_sent_dataset_total / args.batch_size if n_sent_dataset_total is not None else "_",
                  overall_pass/60/batch_i, overall_pass/60/n_obs_forwarded ],
             verbose_level=1, verbose=verbose)

    timing = OrderedDict([("time_multitask_preprocess_1 (get_label)", "{:0.4f} min/total {:0.4f} s/batch".format(time_multitask_preprocess_1/60, time_multitask_preprocess_1/batch_i)),
             ("time_multitask_preprocess_2 (count)", "{:0.4f} min/total {:0.4f} s/batch".format(time_multitask_preprocess_2/60, time_multitask_preprocess_2/batch_i)),
             ("time_multitask_feedforward (foward+pred)", "{:0.4f} min/total {:0.4f} s/batch".format(time_multitask_train/60, time_multitask_train/batch_i)),
             ("time_penalize", "{:0.4f} min/total {:0.4f} s/batch".format(time_penalize/60, time_penalize/batch_i)),
             ("time_multitask_backprop","{:0.4f} min/total {:0.4f} s/batch".format(time_backprop / 60, time_backprop / batch_i)),
             ("time_write_pred", "{:0.4f} min/total {:0.4f} s/batch".format(time_write_pred / 60, time_write_pred / batch_i)),
             ("time_multitask_get_string (get string) ", "{:0.4f} min/total {:0.4f} s/batch".format(time_multitask_postprocess/60, time_multitask_postprocess/batch_i)),
             ("time_score (score) ","{:0.4f} min/total {:0.4f} s/batch".format(time_score / 60, time_score / batch_i)),
             ("time schedule lr ", "scheduleer {:0.4} min in average".format(end_schedule_lr/batch_i)),
             ])

    print("TIME epoch {}/{} ({} step of {} size in {} mode {} pass (done mode_id {}  task:{}): {})".format(epoch, args.epochs,
                                                                                                           batch_i,
                                                                                                           args.batch_size,
                                                                                                           "predict" if optimizer is None else "train/accumulate",
                                                                                                           "backward" if back_pass else "foward",
                                                                                                           model_id,
                                                                                                           args.tasks,
                                                                                                           timing))
    log_warning(counting_failure_parralel_bpe_batch, data_label, batch_i, batch, noisy_under_splitted, skipping_batch_n_to_1, aligned, noisy_over_splitted, skip_1_t_n, skipping_evaluated_batch, verbose)

    early_stoppin_metric_val = 999
    evaluated_task = list(set(evaluated_task))
    if predict_mode:
        if writer is not None:
            # n_tokens_counter_per_task
            tensorboard_loss_writer_epoch_level_multi(writer,  mode, model_id, epoch, loss_dic_epoch,
                                                      n_tokens_counter_per_task, data_label,
                                                      penalization_dic=penalization_dic if report_penalization else None,
                                                      group_mapping=["bert.encoder.layer.*.attention.*", "bert.encoder.layer.*.intermediate.*",
                                                                     "bert.encoder.layer.*.output.*", "bert.embedding", "bert.pooler", "head"])

            tensorboard_loss_writer_epoch_level(writer, args.tasks, mode, model_id, epoch, n_batch_norm, n_batch_pos, args.append_n_mask, loss, loss_norm, loss_pos, loss_n_mask_prediction, batch_i, data_label)
        printing("TRAINING : evaluating on {} args.tasks ", var=[evaluated_task], verbose_level=1, verbose=verbose)
        reports = []
        reports, early_stoppin_metric_val, score, n_tokens = report_score_all(evaluated_task, agg_func_ls, samples_per_task_reporting, label_heuristic, score_dic, n_tokens_dic, n_sents_dic, model_id, args.tasks, args_dir,
                                                                              data_label, reports,  writer, log_perf, early_stoppin_metric_val,
                                                                              early_stoppin_metric, mode, subsample_early_stoping_metric_val, epoch)
    else:
        reports = None
    iter += batch_i
    if writing_pred:
        printing("DATA WRITTEN TO {} ", var=[dir_end_pred], verbose=verbose, verbose_level=1)
    printing("END EPOCH {} mode, iterated {} on normalisation ", var=[mode, n_task_normalize_sanity], verbose_level=1, verbose=verbose)

    # eval NER :

    if writing_pred:
        printing("SCORE computing F1 ",
                 verbose=verbose, verbose_level=1)
        f1 = evaluate(dataset_name=None,
                      dataset=None,
                      dir_end_pred=dir_end_pred,
                      prediction_file=dir_normalized, #os.path.join(dir_end_pred, "LAST_ep-prediction-fr_ftb_pos_ner-ud-test-.conll"),
                      gold_file_name=dir_gold) #os.path.join(dir_end_pred, "LAST_ep-gold--fr_ftb_pos_ner-ud-test-.conll"))
        if f1 is not None:
            f1 = f1/100
        f1_report = report_template(metric_val="f1", info_score_val="all",score_val=f1,model_full_name_val=model_id,
                                    report_path_val=None, evaluation_script_val="conlleval",
                                    data_val=data_label,
                                    model_args_dir=args_dir,
                                    n_tokens_score=None, task=args.tasks, n_sents=None, token_type="word", subsample="all",
                                    avg_per_sent=None, min_per_epoch=None)
        if predict_mode:
            reports.append(f1_report)
        if "ftb" in data_label:
            printing("WARNING : defining early_stoppin_metric_val based on F1 ", verbose=verbose, verbose_level=1)
            early_stoppin_metric_val = -f1
            printing("SCORE model {} data {} epoch {} NER : score {}", var=[model_id, data_label, epoch, f1], verbose=verbose, verbose_level=1)



    try:
        if early_stoppin_metric is not None:
            assert early_stoppin_metric_val is not None, "ERROR : early_stoppin_metric_val should have been found " \
                                                         "but was not : {} (sample metric {}) not found in {} (NB : MIGHT ALSO BE BECAUSE THE PERF DID NOT DECREASE AT ALL)".format(early_stoppin_metric, subsample_early_stoping_metric_val, reports)
    except Exception as e:
        print(e)
    if early_stoppin_metric_val is None:
        print("WARNING : early_stoppin_metric_val is None, score {} n_tokens {}".format(score, n_tokens))
    return loss/batch_i, iter, reports, early_stoppin_metric_val
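

# A minimal, self-contained sketch (not the project's API) of the gradient-accumulation
# pattern used above in low_memory_foot_print_batch_mode: every micro-batch is forwarded
# and backwarded, but the optimizer only steps once enough observations have accumulated.
# All names below (sketch_gradient_accumulation, micro_batches, batch_update_train) are illustrative.
def sketch_gradient_accumulation(micro_batches, batch_update_train, lr=1e-3):
    import torch
    model = torch.nn.Linear(4, 2)  # stand-in for the multitask model
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    n_obs_forwarded_not_backwarded = 0
    backprop_step = 0
    for x, y in micro_batches:
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()  # gradients accumulate across micro-batches
        n_obs_forwarded_not_backwarded += x.size(0)
        if n_obs_forwarded_not_backwarded >= batch_update_train:
            optimizer.step()
            optimizer.zero_grad()
            n_obs_forwarded_not_backwarded = 0
            backprop_step += 1
    return backprop_step


# e.g. with micro_batches = [(torch.randn(8, 4), torch.randint(0, 2, (8,))) for _ in range(12)]
# and batch_update_train=32, the optimizer steps every 4 micro-batches (3 steps in total).
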
def train_predict_eval(args, verbose=1):

    init_seed(args)

    model_dir = BERT_MODEL_DIC[
        args.bert_model]["model"] if args.bert_model else None
    encoder = BERT_MODEL_DIC[
        args.bert_model]["encoder"] if args.bert_model else None

    if args.init_args_dir is not None:
        args_checkpoint = json.load(open(args.init_args_dir, "r"))
        args.bert_model = args_checkpoint["hyperparameters"]["bert_model"]
    tokenizer = eval(BERT_MODEL_DIC[args.bert_model]["tokenizer"]
                     ) if args.bert_model else None  #, "BertTokenizer"))
    voc_tokenizer = BERT_MODEL_DIC[
        args.bert_model]["vocab"] if args.bert_model else None
    vocab_size = BERT_MODEL_DIC[
        args.bert_model]["vocab_size"] if args.bert_model else None

    debug = True
    if not debug:
        pdb.set_trace = lambda: None

    null_token_index = BERT_MODEL_DIC[args.bert_model][
        "vocab_size"]  # based on bert cased vocabulary
    description = "grid"

    # We checkpoint the model only if early_stoppin_metric gets better;
    # early_stoppin_metric is chosen in relation to the first task defined in the list
    # (a minimal sketch of this checkpoint-on-improvement policy is given after this function).
    early_stoppin_metric, subsample_early_stoping_metric_val = get_early_stopping_metric(
        tasks=args.tasks, early_stoppin_metric=None, verbose=verbose)

    printing("INFO : tasks is {} so setting early_stoppin_metric to {} ",
             var=[args.tasks, early_stoppin_metric],
             verbose=verbose,
             verbose_level=1)

    printing("INFO : environ is {} so debug set to {}",
             var=[os.environ.get("ENV", "Unkwnown"), debug],
             verbose_level=1,
             verbose=verbose)

    printing(
        "INFO : model {} batch_update_train {} batch_size {} ",
        var=[args.model_id_pref, args.batch_update_train, args.batch_size],
        verbose=verbose,
        verbose_level=1)

    run(args=args,
        voc_tokenizer=voc_tokenizer,
        vocab_size=vocab_size,
        model_dir=model_dir,
        report_full_path_shared=args.overall_report_dir,
        description=description,
        null_token_index=null_token_index,
        null_str=NULL_STR,
        model_suffix="{}".format(args.model_id_pref),
        debug=debug,
        random_iterator_train=True,
        bucket_test=False,
        compute_intersection_score_test=True,
        n_observation_max_per_epoch_train=args.n_iter_max_train
        if not args.demo_run else 2,
        n_observation_max_per_epoch_dev_test=50000 if not args.demo_run else 2,
        early_stoppin_metric=early_stoppin_metric,
        subsample_early_stoping_metric_val=subsample_early_stoping_metric_val,
        saving_every_epoch=args.saving_every_n_epoch,
        run_mode="train" if args.train else "test",
        auxilliary_task_norm_not_norm=True,
        tokenizer=tokenizer,
        max_token_per_batch=300,
        name_with_epoch=args.name_inflation,
        encoder=encoder,
        report=True,
        verbose=1)

    printing("MODEL {} trained and evaluated",
             var=[args.model_id_pref],
             verbose_level=1,
             verbose=verbose)
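

# A minimal sketch (assumed, not the project's run() implementation) of the
# checkpoint-on-improvement policy described above: the model is saved only when the
# early-stopping metric improves (lower is better here, e.g. a loss or -F1).
def sketch_checkpoint_on_improvement(metric_vals_per_epoch, save_fn):
    best = float("inf")
    for epoch, metric_val in enumerate(metric_vals_per_epoch):
        if metric_val is not None and metric_val < best:
            best = metric_val
            save_fn(epoch)  # e.g. torch.save(model.state_dict(), path) in a real setup
    return best
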
def get_perf_rate(metric, score_dic, n_tokens_dic, agg_func, task, verbose=1):
    """
    provides metric : the confusion matrix standart rates for the given task
    :param metric:
    :param score_dic: two level dictionay : first level for agg_func second
    for prediciton class based on CLASS_PER_TASK and task
    :param agg_func:
    :return: rate, denumerator of the rate (if means like f1 : returns all )
    """
    if metric in [
            "recall-{}".format(task), "f1-{}".format(task),
            "accuracy-{}".format(task)
    ]:

        positive_obs = n_tokens_dic[agg_func][TASKS_PARAMETER[task]
                                              ["predicted_classes"][1]]
        recall = score_dic[agg_func][TASKS_PARAMETER[task]["predicted_classes"][1]] / positive_obs \
            if positive_obs > 0 else None
        if positive_obs == 0:
            printing("WARNING : no positive observation were seen ",
                     verbose=verbose,
                     verbose_level=1)
        if metric == "recall-{}".format(task):
            return recall, positive_obs
    if metric in [
            "precision-{}".format(task), "f1-{}".format(task),
            "accuracy-{}".format(task)
    ]:
        #positive_prediction = n_tokens_dic[agg_func][TASKS_PARAMETER[task]["predicted_classes"][0]] - score_dic[agg_func][TASKS_PARAMETER[task]["predicted_classes"][0]] \
        #                      + score_dic[agg_func][TASKS_PARAMETER[task]["predicted_classes"][1]]
        positive_prediction = n_tokens_dic[agg_func][
            TASKS_PARAMETER[task]["predicted_classes_pred_field"][1]]
        precision = score_dic[agg_func][
            TASKS_PARAMETER[task]["predicted_classes"]
            [1]] / positive_prediction if positive_prediction > 0 else None
        if metric == "precision-{}".format(task):
            return precision, positive_prediction
    if metric in [
            "tnr-{}".format(task), "accuracy-{}".format(task),
            "f1-{}".format(task)
    ]:
        negative_obs = n_tokens_dic[agg_func][TASKS_PARAMETER[task]
                                              ["predicted_classes"][0]]
        if metric == "tnr-{}".format(task):
            return score_dic[agg_func][TASKS_PARAMETER[task]["predicted_classes"][0]] / negative_obs if negative_obs>0 else None, \
                   negative_obs
    if metric == "f1-{}".format(task):
        if recall is not None and precision is not None and recall > 0 and precision > 0:
            return hmean([recall, precision]), negative_obs + positive_obs
        else:
            return None, negative_obs + positive_obs

    if metric in ["npv-{}".format(task)]:
        negative_prediction = n_tokens_dic[agg_func][
            TASKS_PARAMETER[task]["predicted_classes_pred_field"][0]]
        return score_dic[agg_func][
                   TASKS_PARAMETER[task]["predicted_classes"][0]] / negative_prediction if negative_prediction > 0 else None, \
               negative_prediction
    if metric == "accuracy-{}".format(task):
        accuracy = (
            score_dic[agg_func][TASKS_PARAMETER[task]["predicted_classes"][0]]
            +
            score_dic[agg_func][TASKS_PARAMETER[task]["predicted_classes"][1]]
        ) / (positive_obs +
             negative_obs) if positive_obs > 0 and negative_obs > 0 else None
        return accuracy, positive_obs + negative_obs

    raise (Exception("metric {} not supported".format(metric)))
Esempio n. 11
0
def load_batcher_shard_data(args, args_load_batcher_shard_data, shard_dir,
                            verbose):

    word_dictionary = args_load_batcher_shard_data["word_dictionary"]
    tokenizer = args_load_batcher_shard_data["tokenizer"]
    word_norm_dictionary = args_load_batcher_shard_data["word_norm_dictionary"]
    char_dictionary = args_load_batcher_shard_data["char_dictionary"]
    pos_dictionary = args_load_batcher_shard_data["pos_dictionary"]
    xpos_dictionary = args_load_batcher_shard_data["xpos_dictionary"]
    type_dictionary = args_load_batcher_shard_data["type_dictionary"]
    use_gpu = args_load_batcher_shard_data["use_gpu"]
    norm_not_norm = args_load_batcher_shard_data["norm_not_norm"]
    word_decoder = args_load_batcher_shard_data["word_decoder"]
    add_start_char = args_load_batcher_shard_data["add_start_char"]
    add_end_char = args_load_batcher_shard_data["add_end_char"]
    symbolic_end = args_load_batcher_shard_data["symbolic_end"]
    symbolic_root = args_load_batcher_shard_data["symbolic_root"]
    bucket = args_load_batcher_shard_data["bucket"]
    max_char_len = args_load_batcher_shard_data["max_char_len"]
    must_get_norm = args_load_batcher_shard_data["must_get_norm"]
    bucketing_level = args_load_batcher_shard_data["bucketing_level"]
    use_gpu_hardcoded_readers = args_load_batcher_shard_data["use_gpu_hardcoded_readers"]
    auxilliary_task_norm_not_norm = args_load_batcher_shard_data["auxilliary_task_norm_not_norm"]
    random_iterator_train = args_load_batcher_shard_data["random_iterator_train"]

    printing("INFO ITERATOR LOADING new batcher based on {} ",
             var=[shard_dir],
             verbose=verbose,
             verbose_level=1)
    start = time.time()
    readers = readers_load(
        datasets=shard_dir,
        tasks=args.tasks,
        args=args,
        word_dictionary=word_dictionary,
        bert_tokenizer=tokenizer,
        word_dictionary_norm=word_norm_dictionary,
        char_dictionary=char_dictionary,
        pos_dictionary=pos_dictionary,
        xpos_dictionary=xpos_dictionary,
        type_dictionary=type_dictionary,
        use_gpu=use_gpu_hardcoded_readers,
        norm_not_norm=auxilliary_task_norm_not_norm,
        word_decoder=True,
        add_start_char=1,
        add_end_char=1,
        symbolic_end=1,
        symbolic_root=1,
        bucket=True,
        max_char_len=20,
        input_level_ls=args_load_batcher_shard_data["input_level_ls"],
        must_get_norm=True,
        verbose=verbose)

    batchIter = data_gen_multi_task_sampling_batch(
        tasks=args.tasks,
        readers=readers,
        batch_size=args.batch_size,
        word_dictionary=word_dictionary,
        char_dictionary=char_dictionary,
        pos_dictionary=pos_dictionary,
        word_dictionary_norm=word_norm_dictionary,
        get_batch_mode=random_iterator_train,
        print_raw=False,
        dropout_input=0.0,
        verbose=verbose)
    end = time.time() - start
    printing("INFO ITERATOR LOADED  {:0.3f}min ",
             var=[end / 60],
             verbose=verbose,
             verbose_level=1)

    return batchIter
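

# A minimal sketch (assumed file layout and helper names) of the shard-based iteration
# pattern used in the training loop above: pick a random shard file, iterate it, and
# rebuild the iterator whenever it is exhausted, until the target number of observations
# is reached. make_iterator is a hypothetical callable, e.g. one that opens a CoNLL shard
# and yields batches of sentences.
def sketch_iterate_over_shards(shard_paths, make_iterator, n_obs_target):
    import random
    n_obs = 0
    batch_iter = make_iterator(random.choice(shard_paths))
    while n_obs < n_obs_target:
        try:
            batch = next(batch_iter)
            n_obs += len(batch)  # assumes a batch with a length, e.g. a list of sentences
        except StopIteration:
            batch_iter = make_iterator(random.choice(shard_paths))
    return n_obs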