Example 1
def train(local_rank, args):
    # debug = False
    # print("GPU:", gpu)
    # world_size = args.world_size
    args.global_rank = args.node_rank * args.gpus_per_node + local_rank
    args.local_rank = local_rank
    # args.warmup_steps = 20
    debug_count = 1000
    num_epoch = args.epochs

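    # Effective (global) train batch size = world_size × per-GPU batch size × gradient-accumulation steps.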
    actual_train_batch_size = args.world_size * args.per_gpu_train_batch_size * args.gradient_accumulation_steps
    args.actual_train_batch_size = actual_train_batch_size

    set_seed(args.seed)
    num_labels = 3  # NLI has three labels; change this value for other tasks.

    max_length = args.max_length

    model_class_item = MODEL_CLASSES[args.model_class_name]
    model_name = model_class_item['model_name']
    do_lower_case = model_class_item.get('do_lower_case', False)

    tokenizer = model_class_item['tokenizer'].from_pretrained(
        model_name,
        cache_dir=str(config.PRO_ROOT / "trans_cache"),
        do_lower_case=do_lower_case)

    model = model_class_item['sequence_classification'].from_pretrained(
        model_name,
        cache_dir=str(config.PRO_ROOT / "trans_cache"),
        num_labels=num_labels)

    padding_token_value = tokenizer.convert_tokens_to_ids(
        [tokenizer.pad_token])[0]
    padding_segement_value = model_class_item["padding_segement_value"]
    padding_att_value = model_class_item["padding_att_value"]
    left_pad = model_class_item.get('left_pad', False)

    batch_size_per_gpu_train = args.per_gpu_train_batch_size
    batch_size_per_gpu_eval = args.per_gpu_eval_batch_size

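    # Initialize the NCCL process group; init_method='env://' reads MASTER_ADDR/MASTER_PORT from the environment.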
    if not args.cpu and not args.single_gpu:
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.world_size,
                                rank=args.global_rank)

    train_data_str = args.train_data
    train_data_weights_str = args.train_weights
    eval_data_str = args.eval_data

    train_data_name = []
    train_data_path = []
    train_data_list = []
    train_data_weights = []

    eval_data_name = []
    eval_data_path = []
    eval_data_list = []

    train_data_named_path = train_data_str.split(',')
    weights_str = train_data_weights_str.split(
        ',') if train_data_weights_str is not None else None

    eval_data_named_path = eval_data_str.split(',')

    for named_path in train_data_named_path:
        ind = named_path.find(':')
        name = named_path[:ind]
        path = named_path[ind + 1:]
        if name in registered_path:
            d_list = common.load_jsonl(registered_path[name])
        else:
            d_list = common.load_jsonl(path)

        train_data_name.append(name)
        train_data_path.append(path)

        train_data_list.append(d_list)

    if weights_str is not None:
        for weights in weights_str:
            train_data_weights.append(float(weights))
    else:
        for i in range(len(train_data_list)):
            train_data_weights.append(1)

    for named_path in eval_data_named_path:
        ind = named_path.find(':')
        name = named_path[:ind]
        path = named_path[ind + 1:]
        if name in registered_path:
            d_list = common.load_jsonl(registered_path[name])
        else:
            d_list = common.load_jsonl(path)
        eval_data_name.append(name)
        eval_data_path.append(path)

        eval_data_list.append(d_list)

    assert len(train_data_weights) == len(train_data_list)

    batching_schema = {
        'uid': RawFlintField(),
        'y': LabelFlintField(),
        'input_ids': ArrayIndexFlintField(pad_idx=padding_token_value,
                                          left_pad=left_pad),
        'token_type_ids': ArrayIndexFlintField(pad_idx=padding_segement_value,
                                               left_pad=left_pad),
        'attention_mask': ArrayIndexFlintField(pad_idx=padding_att_value,
                                               left_pad=left_pad),
    }

    data_transformer = NLITransform(model_name, tokenizer, max_length)
    # data_transformer = NLITransform(model_name, tokenizer, max_length, with_element=True)

    eval_data_loaders = []
    for eval_d_list in eval_data_list:
        d_dataset, d_sampler, d_dataloader = build_eval_dataset_loader_and_sampler(
            eval_d_list, data_transformer, batching_schema,
            batch_size_per_gpu_eval)
        eval_data_loaders.append(d_dataloader)

    # Estimate the training size:
    training_list = []
    for i in range(len(train_data_list)):
        print("Build Training Data ...")
        train_d_list = train_data_list[i]
        train_d_name = train_data_name[i]
        train_d_weight = train_data_weights[i]
        cur_train_list = sample_data_list(
            train_d_list, train_d_weight
        )  # change later  # we can apply different sample strategy here.
        print(
            f"Data Name:{train_d_name}; Weight: {train_d_weight}; "
            f"Original Size: {len(train_d_list)}; Sampled Size: {len(cur_train_list)}"
        )
        training_list.extend(cur_train_list)
    estimated_training_size = len(training_list)
    print("Estimated training size:", estimated_training_size)
    # Estimate the training size ends:

    # t_total = estimated_training_size // args.gradient_accumulation_steps * num_epoch
    t_total = estimated_training_size * num_epoch // args.actual_train_batch_size
    if args.warmup_steps <= 0:  # default to 10% of the total steps when no positive warmup step count is given
        args.warmup_steps = int(t_total * 0.1)

    if not args.cpu:
        torch.cuda.set_device(args.local_rank)
        model.cuda(args.local_rank)

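    # Standard BERT fine-tuning practice: exempt biases and LayerNorm weights from weight decay.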
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    if not args.cpu and not args.single_gpu:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True)

    args_dict = dict(vars(args))
    file_path_prefix = '.'
    if args.global_rank in [-1, 0]:
        print("Total Steps:", t_total)
        args.total_step = t_total
        print("Warmup Steps:", args.warmup_steps)
        print("Actual Training Batch Size:", actual_train_batch_size)
        print("Arguments", pp.pprint(args))

    # Build the logger and log everything before the start of the first training epoch.
    if args.global_rank in [-1, 0]:  # only log when running on CPU or when global_rank == 0
        if not args.debug_mode:
            file_path_prefix, date = save_tool.gen_file_prefix(
                f"{args.experiment_name}")
            # # # Create Log File
            # Save the source code.
            script_name = os.path.basename(__file__)
            with open(os.path.join(file_path_prefix, script_name),
                      'w') as out_f, open(__file__, 'r') as it:
                out_f.write(it.read())
                out_f.flush()

            # Save option file
            common.save_json(args_dict,
                             os.path.join(file_path_prefix, "args.json"))
            checkpoints_path = Path(file_path_prefix) / "checkpoints"
            if not checkpoints_path.exists():
                checkpoints_path.mkdir()
            prediction_path = Path(file_path_prefix) / "predictions"
            if not prediction_path.exists():
                prediction_path.mkdir()

    global_step = 0

    # print(f"Global Rank:{args.global_rank} ### ", 'Init!')

    for epoch in tqdm(range(num_epoch),
                      desc="Epoch",
                      disable=args.global_rank not in [-1, 0]):
        # Let's build up training dataset for this epoch
        training_list = []
        for i in range(len(train_data_list)):
            print("Build Training Data ...")
            train_d_list = train_data_list[i]
            train_d_name = train_data_name[i]
            train_d_weight = train_data_weights[i]
            cur_train_list = sample_data_list(
                train_d_list, train_d_weight
            )  # change later  # we can apply different sample strategy here.
            print(
                f"Data Name:{train_d_name}; Weight: {train_d_weight}; "
                f"Original Size: {len(train_d_list)}; Sampled Size: {len(cur_train_list)}"
            )
            training_list.extend(cur_train_list)

        random.shuffle(training_list)
        train_dataset = NLIDataset(training_list, data_transformer)

        train_sampler = SequentialSampler(train_dataset)
        if not args.cpu and not args.single_gpu:
            print("Use distributed sampler.")
            train_sampler = DistributedSampler(train_dataset,
                                               args.world_size,
                                               args.global_rank,
                                               shuffle=True)

        train_dataloader = DataLoader(
            dataset=train_dataset,
            batch_size=batch_size_per_gpu_train,
            shuffle=False,  # shuffling is handled by the sampler (list pre-shuffled above)
            num_workers=0,
            pin_memory=True,
            sampler=train_sampler,
            collate_fn=BaseBatchBuilder(batching_schema))
        # training build finished.

        print(debug_node_info(args), "epoch: ", epoch)

        if not args.cpu and not args.single_gpu:
            train_sampler.set_epoch(epoch)  # set the epoch so the DistributedSampler reshuffles each epoch

        for forward_step, batch in enumerate(
                tqdm(train_dataloader,
                     desc="Iteration",
                     disable=args.global_rank not in [-1, 0]), 0):
            model.train()

            batch = move_to_device(batch, local_rank)
            # print(batch['input_ids'], batch['y'])
            if args.model_class_name in ["distilbert", "bart-large"]:
                outputs = model(batch['input_ids'],
                                attention_mask=batch['attention_mask'],
                                labels=batch['y'])
            else:
                outputs = model(batch['input_ids'],
                                attention_mask=batch['attention_mask'],
                                token_type_ids=batch['token_type_ids'],
                                labels=batch['y'])
            loss, logits = outputs[:2]
            # print(debug_node_info(args), loss, logits, batch['uid'])
            # print(debug_node_info(args), loss, batch['uid'])

            # Accumulated loss
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # Backward pass (with fp16 loss scaling when enabled):
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Take an optimizer step once enough forward steps have been accumulated.
            if (forward_step + 1) % args.gradient_accumulation_steps == 0:
                # Gradient clipping (applied only when max_grad_norm > 0):
                if args.max_grad_norm > 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()

                global_step += 1

                if (args.global_rank in [-1, 0] and args.eval_frequency > 0
                        and global_step % args.eval_frequency == 0):
                    r_dict = dict()
                    # Eval loop:
                    for i in range(len(eval_data_name)):
                        cur_eval_data_name = eval_data_name[i]
                        cur_eval_data_list = eval_data_list[i]
                        cur_eval_dataloader = eval_data_loaders[i]
                        # cur_eval_raw_data_list = eval_raw_data_list[i]

                        evaluation_dataset(args,
                                           cur_eval_dataloader,
                                           cur_eval_data_list,
                                           model,
                                           r_dict,
                                           eval_name=cur_eval_data_name)

                    # saving checkpoints
                    current_checkpoint_filename = \
                        f'e({epoch})|i({global_step})'

                    for i in range(len(eval_data_name)):
                        cur_eval_data_name = eval_data_name[i]
                        current_checkpoint_filename += \
                            f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})'

                    if not args.debug_mode:
                        # save model:
                        model_output_dir = checkpoints_path / current_checkpoint_filename
                        if not model_output_dir.exists():
                            model_output_dir.mkdir()
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training

                        torch.save(model_to_save.state_dict(),
                                   str(model_output_dir / "model.pt"))
                        torch.save(optimizer.state_dict(),
                                   str(model_output_dir / "optimizer.pt"))
                        torch.save(scheduler.state_dict(),
                                   str(model_output_dir / "scheduler.pt"))

                    # save prediction:
                    if not args.debug_mode and args.save_prediction:
                        cur_results_path = prediction_path / current_checkpoint_filename
                        if not cur_results_path.exists():
                            cur_results_path.mkdir(parents=True)
                        for key, item in r_dict.items():
                            common.save_jsonl(
                                item['predictions'],
                                cur_results_path / f"{key}.jsonl")

                        # avoid saving too many things
                        for key, item in r_dict.items():
                            del r_dict[key]['predictions']
                        common.save_json(r_dict,
                                         cur_results_path /
                                         "results_dict.json",
                                         indent=2)

        # End of epoch evaluation.
        if args.global_rank in [-1, 0]:
            r_dict = dict()
            # Eval loop:
            for i in range(len(eval_data_name)):
                cur_eval_data_name = eval_data_name[i]
                cur_eval_data_list = eval_data_list[i]
                cur_eval_dataloader = eval_data_loaders[i]
                # cur_eval_raw_data_list = eval_raw_data_list[i]

                evaluation_dataset(args,
                                   cur_eval_dataloader,
                                   cur_eval_data_list,
                                   model,
                                   r_dict,
                                   eval_name=cur_eval_data_name)

            # saving checkpoints
            current_checkpoint_filename = \
                f'e({epoch})|i({global_step})'

            for i in range(len(eval_data_name)):
                cur_eval_data_name = eval_data_name[i]
                current_checkpoint_filename += \
                    f'|{cur_eval_data_name}#({round(r_dict[cur_eval_data_name]["acc"], 4)})'

            if not args.debug_mode:
                # save model:
                model_output_dir = checkpoints_path / current_checkpoint_filename
                if not model_output_dir.exists():
                    model_output_dir.mkdir()
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training

                torch.save(model_to_save.state_dict(),
                           str(model_output_dir / "model.pt"))
                torch.save(optimizer.state_dict(),
                           str(model_output_dir / "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           str(model_output_dir / "scheduler.pt"))

            # save prediction:
            if not args.debug_mode and args.save_prediction:
                cur_results_path = prediction_path / current_checkpoint_filename
                if not cur_results_path.exists():
                    cur_results_path.mkdir(parents=True)
                for key, item in r_dict.items():
                    common.save_jsonl(item['predictions'],
                                      cur_results_path / f"{key}.jsonl")

                # avoid saving too many things
                for key, item in r_dict.items():
                    del r_dict[key]['predictions']
                common.save_json(r_dict,
                                 cur_results_path / "results_dict.json",
                                 indent=2)
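
The update step above interleaves gradient accumulation, optional fp16 loss scaling, gradient clipping, and the LR scheduler. The following is a minimal, self-contained sketch of that accumulate-then-step pattern; the model, dataloader, and hyperparameters here are placeholders, not the repository's objects:

import torch

def train_one_epoch_with_accumulation(model, dataloader, optimizer, scheduler,
                                      accumulation_steps, max_grad_norm=1.0):
    model.train()
    model.zero_grad()
    for step, (inputs, labels) in enumerate(dataloader):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
        # Scale the loss so the summed gradients match one large-batch gradient.
        (loss / accumulation_steps).backward()
        if (step + 1) % accumulation_steps == 0:
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()  # the LR schedule advances per update, not per forward pass
            model.zero_grad()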
Example 2
def model_go_with_old_data():
    seed = 12
    torch.manual_seed(seed)
    # bert_model_name = 'bert-large-uncased'
    bert_model_name = 'bert-base-uncased'
    experiment_name = 'fever_v1_nli'
    lazy = False
    # lazy = True
    forward_size = 16
    # batch_size = 64
    # batch_size = 192
    batch_size = 32
    gradient_accumulate_step = int(batch_size / forward_size)
    warmup_proportion = 0.1
    learning_rate = 5e-5
    num_train_epochs = 3
    eval_frequency = 2000
    do_lower_case = True
    pair_order = 'cq'
    # debug_mode = True
    debug_mode = False
    # est_datasize = 900_000

    num_class = 3
    # num_train_optimization_steps

    train_sent_filtering_prob = 0.35
    dev_sent_filtering_prob = 0.1

    # dev_sent_results_file = config.RESULT_PATH / "doc_retri_results/fever_results/sent_results/4-14-sent_results_v0/i(5000)|e(0)|s01(0.9170917091709171)|s05(0.8842384238423843)|seed(12)_dev_sent_results.json"
    # train_sent_results_file = config.RESULT_PATH / "doc_retri_results/fever_results/sent_results/4-14-sent_results_v0/train_sent_results.jsonl"
    from utest.utest_format_converter_for_old_sent.tool import format_convert
    dev_sent_results_file = format_convert(
        config.PRO_ROOT /
        "results/doc_retri_results/fever_results/sent_results/old_sent_data_by_NSMN/4-15-dev_sent_pred_scores_old_format.jsonl"
    )
    train_sent_results_file = format_convert(
        config.PRO_ROOT /
        "results/doc_retri_results/fever_results/sent_results/old_sent_data_by_NSMN/train_sent_scores_old_format.jsonl"
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_num = 0 if torch.cuda.is_available() else -1

    n_gpu = torch.cuda.device_count()

    unk_token_num = {'tokens': 1}  # workaround for initializing the vocabulary
    vocab = ExVocabulary(unk_token_num=unk_token_num)
    vocab.add_token_to_namespace('SUPPORTS', namespace='labels')
    vocab.add_token_to_namespace('REFUTES', namespace='labels')
    vocab.add_token_to_namespace('NOT ENOUGH INFO', namespace='labels')
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden",
                                               -2,
                                               namespace='labels')

    # Load Dataset
    # train_fitems_list = get_inference_pair('train', True, train_sent_results_file, debug_mode, train_sent_filtering_prob)
    dev_debug_num = 2481 if debug_mode else None
    dev_fitems_list, dev_list = get_inference_pair('dev', False,
                                                   dev_sent_results_file,
                                                   dev_debug_num,
                                                   dev_sent_filtering_prob)
    # = common.load_jsonl(config.FEVER_DEV)

    if debug_mode:
        dev_list = dev_list[:50]
        eval_frequency = 1
        # print(dev_list[-1]['_id'])
        # exit(0)

    # sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
    train_debug_num = 2971 if debug_mode else None
    train_fitems_list, _ = get_inference_pair('train', True,
                                              train_sent_results_file,
                                              train_debug_num,
                                              train_sent_filtering_prob)
    est_datasize = len(train_fitems_list)

    # dev_o_dict = list_dict_data_tool.list_to_dict(dev_list, 'id')
    # print(dev_o_dict)

    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name,
                                                   do_lower_case=do_lower_case)
    bert_cs_reader = BertFeverNLIReader(bert_tokenizer,
                                        lazy,
                                        is_paired=True,
                                        query_l=64,
                                        example_filter=None,
                                        max_l=364,
                                        pair_order=pair_order)

    bert_encoder = BertModel.from_pretrained(bert_model_name)
    model = BertMultiLayerSeqClassification(bert_encoder,
                                            num_labels=num_class,
                                            num_of_pooling_layer=1,
                                            act_type='tanh',
                                            use_pretrained_pooler=True,
                                            use_sigmoid=False)
    #
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    num_train_optimization_steps = int(est_datasize / forward_size / gradient_accumulate_step) * \
                                   num_train_epochs

    if debug_mode:
        num_train_optimization_steps = 100

    print("Estimated training size", est_datasize)
    print("Number of optimization steps:", num_train_optimization_steps)

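    # Note: BertAdam's `warmup` argument is a proportion of t_total, not a step count.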
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_proportion,
                         t_total=num_train_optimization_steps)

    dev_instances = bert_cs_reader.read(dev_fitems_list)

    biterator = BasicIterator(batch_size=forward_size)
    biterator.index_with(vocab)

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    forbackward_step = 0
    update_step = 0

    logging_agent = save_tool.ScoreLogger({})

    file_path_prefix = '.'
    if not debug_mode:
        file_path_prefix, date = save_tool.gen_file_prefix(
            f"{experiment_name}")
        # # # Create Log File
        # Save the source code.
        script_name = os.path.basename(__file__)
        with open(os.path.join(file_path_prefix, script_name),
                  'w') as out_f, open(__file__, 'r') as it:
            out_f.write(it.read())
            out_f.flush()
        # # # Log File end

    for epoch_i in range(num_train_epochs):
        print("Epoch:", epoch_i)

        train_fitems_list, _ = get_inference_pair('train', True,
                                                  train_sent_results_file,
                                                  train_debug_num,
                                                  train_sent_filtering_prob)
        random.shuffle(train_fitems_list)
        train_instance = bert_cs_reader.read(train_fitems_list)
        train_iter = biterator(train_instance, num_epochs=1, shuffle=True)

        for batch in tqdm(train_iter):
            model.train()
            batch = move_to_device(batch, device_num)

            paired_sequence = batch['paired_sequence']
            paired_segments_ids = batch['paired_segments_ids']
            labels_ids = batch['label']
            att_mask, _ = torch_util.get_length_and_mask(paired_sequence)
            s1_span = batch['bert_s1_span']
            s2_span = batch['bert_s2_span']

            loss = model(
                paired_sequence,
                token_type_ids=paired_segments_ids,
                attention_mask=att_mask,
                mode=BertMultiLayerSeqClassification.ForwardMode.TRAIN,
                labels=labels_ids)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            if gradient_accumulate_step > 1:
                loss = loss / gradient_accumulate_step

            loss.backward()
            forbackward_step += 1

            if forbackward_step % gradient_accumulate_step == 0:
                optimizer.step()
                optimizer.zero_grad()
                update_step += 1

                if update_step % eval_frequency == 0:
                    print("Update steps:", update_step)
                    dev_iter = biterator(dev_instances,
                                         num_epochs=1,
                                         shuffle=False)

                    cur_eval_results_list = eval_model(model,
                                                       dev_iter,
                                                       device_num,
                                                       with_probs=True,
                                                       make_int=True)

                    results_dict = list_dict_data_tool.list_to_dict(
                        cur_eval_results_list, 'oid')
                    copied_dev_list = copy.deepcopy(dev_list)
                    list_dict_data_tool.append_item_from_dict_to_list(
                        copied_dev_list, results_dict, 'id', 'predicted_label')

                    mode = {'standard': True}
                    strict_score, acc_score, pr, rec, f1 = fever_scorer.fever_score(
                        copied_dev_list,
                        dev_fitems_list,
                        mode=mode,
                        max_evidence=5)
                    logging_item = {
                        'ss': strict_score,
                        'ac': acc_score,
                        'pr': pr,
                        'rec': rec,
                        'f1': f1,
                    }

                    save_file_name = f'i({update_step})|e({epoch_i})' \
                        f'|ss({strict_score})|ac({acc_score})|pr({pr})|rec({rec})|f1({f1})' \
                        f'|seed({seed})'

                    common.save_jsonl(
                        copied_dev_list,
                        Path(file_path_prefix) /
                        f"{save_file_name}_dev_nli_results.json")

                    # print(save_file_name)
                    logging_agent.incorporate_results({}, save_file_name,
                                                      logging_item)
                    logging_agent.logging_to_file(
                        Path(file_path_prefix) / "log.json")

                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    output_model_file = Path(file_path_prefix) / save_file_name
                    torch.save(model_to_save.state_dict(),
                               str(output_model_file))
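
Example 2 derives BertAdam's step budget from the data size: with batch_size = 32 and forward_size = 16, gradient_accumulate_step is 2, so each optimizer update consumes 32 examples. A small sketch of the same arithmetic (the helper name is ours, not the repository's):

def count_optimization_steps(est_datasize, forward_size, gradient_accumulate_step, num_epochs):
    # Forward passes per epoch, divided by accumulation steps, times epochs.
    return int(est_datasize / forward_size / gradient_accumulate_step) * num_epochs

# e.g. 100,000 training pairs with the constants above -> 3,125 updates per epoch, 9,375 in total
assert count_optimization_steps(100_000, 16, 2, 3) == 9375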
Example 3
def model_go_pure_aug():
    for some_params in [0.25, 0.25, 0.25]:
        # bert_model_name = 'bert-large-uncased'
        seed = 6
        bert_model_name = 'bert-base-uncased'
        lazy = False
        forward_size = 16
        batch_size = 32
        gradient_accumulate_step = int(batch_size / forward_size)
        warmup_proportion = 0.1
        learning_rate = 5e-5
        num_train_epochs = 3
        do_ema = False
        dev_prob_threshold = 0.1
        train_prob_threshold = 0.35
        debug_mode = False
        # experiment_name = f"bert_fever_nli_baseline_on_fulldata"
        # experiment_name = f"bert_fever_nli_baseline_on_fulldata_aug_the_same_gt_mrate({some_params})"
        # experiment_name = f"bert_fever_nli_baseline_on_10p_aug_ratio({some_params})"
        experiment_name = f"bert_fever_nli_baseline_on_fulldata_aug_ratio({some_params})"
        # experiment_name = f"bert_fever_nli_baseline_pure_aug"

        data_aug = True
        # data_aug_file = config.FEVER_DATA_ROOT / "qa_aug/squad_train_turker_groundtruth.json"
        # data_aug_size = int(21_015 * some_params)   # 10p
        # data_aug_size = int(208_346 * some_params)

        # training_file = config.FEVER_DATA_ROOT / "fever_1.0/train_10.jsonl"
        training_file = config.FEVER_DATA_ROOT / "fever_1.0/train.jsonl"

        train_sample_top_k = 8

        # est_datasize = 208_346    # full
        # est_datasize = 14_544
        # est_datasize = 21_015 + data_aug_size   # 10p
        aug_size = int(208_346 * some_params)
        est_datasize = 208_346 + aug_size
        # est_datasize = 208_346 + data_aug_size

        num_class = 3

        # num_train_optimization_steps
        torch.manual_seed(seed)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()

        unk_token_num = {'tokens': 1}  # workaround for initializing the vocabulary
        vocab = ExVocabulary(unk_token_num=unk_token_num)
        vocab.add_token_to_namespace('SUPPORTS', namespace='labels')
        vocab.add_token_to_namespace('REFUTES', namespace='labels')
        vocab.add_token_to_namespace('NOT ENOUGH INFO', namespace='labels')
        vocab.add_token_to_namespace("hidden", namespace="labels")
        vocab.change_token_with_index_to_namespace("hidden", -2, namespace='labels')
        # Finished build vocabulary.

        # Load standardized sentence file
        dev_upstream_sent_list = common.load_jsonl(config.FEVER_DATA_ROOT /
                                                   "upstream_sentence_selection_Feb16/dev_sent_pred_scores.jsonl")
        dev_sent_after_threshold_filter = fever_ss_sampler.threshold_sampler_insure_unique(
            config.FEVER_DATA_ROOT / "fever_1.0/shared_task_dev.jsonl",
            dev_upstream_sent_list,
            prob_threshold=dev_prob_threshold, top_n=5)

        dev_data_list = fever_nli_sampler.select_sent_with_prob_for_eval(
            config.FEVER_DATA_ROOT / "fever_1.0/shared_task_dev.jsonl", dev_sent_after_threshold_filter,
            None, tokenized=True)

        # print(dev_data_list[0])
        # exit(0)

        train_upstream_sent_list = common.load_jsonl(config.FEVER_DATA_ROOT /
                                                     "upstream_sentence_selection_Feb16/train_sent_scores.jsonl")
        # Finished loading standardized sentence file.

        bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=True)

        bert_fever_reader = BertReaderFeverNLI(bert_tokenizer, lazy=lazy)

        dev_instances = bert_fever_reader.read(dev_data_list)

        biterator = BasicIterator(batch_size=forward_size)
        biterator.index_with(vocab)

        # print(list(mnli_dev_instances))

        # Load training model
        model_clf = BertForSequenceClassification.from_pretrained(bert_model_name, num_labels=num_class)

        ema_tracker = None
        ema_model_copy = None
        if do_ema and ema_tracker is None:
            ema_tracker = EMA(model_clf.named_parameters(), on_cpu=True)
            ema_model_copy = copy.deepcopy(model_clf)

        model_clf.to(device)

        param_optimizer = list(model_clf.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        num_train_optimization_steps = int(est_datasize / forward_size / gradient_accumulate_step) * \
                                       num_train_epochs

        print(num_train_optimization_steps)

        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=learning_rate,
                             warmup=warmup_proportion,
                             t_total=num_train_optimization_steps)

        # optimizer = optim.Adam(optimizer_grouped_parameters, lr=learning_rate)

        # # # Create Log File
        file_path_prefix, date = save_tool.gen_file_prefix(f"{experiment_name}")
        # Save the source code.
        script_name = os.path.basename(__file__)
        with open(os.path.join(file_path_prefix, script_name), 'w') as out_f, open(__file__, 'r') as it:
            out_f.write(it.read())
            out_f.flush()
        # # # Log File end

        model_clf.train()

        if n_gpu > 1:
            model_clf = nn.DataParallel(model_clf)

        forbackward_step = 0
        update_step = 0
        eval_iter_num = 2_000  # evaluate every 2,000 update steps
        best_fever_score = -1

        for n_epoch in range(num_train_epochs):
            print("Resampling...")
            train_sent_after_threshold_filter = \
                fever_ss_sampler.threshold_sampler_insure_unique(training_file,
                                                                 train_upstream_sent_list,
                                                                 train_prob_threshold,
                                                                 top_n=train_sample_top_k)
            #
            train_data_list = fever_nli_sampler.adv_simi_sample_with_prob_v1_1(
                training_file,
                train_sent_after_threshold_filter,
                None,
                tokenized=True)

            aug_d_list = []
            if data_aug:
                aug_d_list = get_sample_data(-1)
                random.shuffle(aug_d_list)
                aug_d_list = aug_d_list[:aug_size]

            train_data_list = train_data_list + aug_d_list

            random.shuffle(train_data_list)
            # train_data_list = get_sample_data(-1)
            print("Sample data length:", len(train_data_list))
            sampled_train_instances = bert_fever_reader.read(train_data_list)
            #
            train_iter = biterator(sampled_train_instances, shuffle=True, num_epochs=1)

            for i, batch in enumerate(tqdm(train_iter)):
                paired_sequence = batch['paired_sequence']
                paired_segments_ids = batch['paired_segments_ids']
                labels_ids = batch['label']
                att_mask, _ = torch_util.get_length_and_mask(paired_sequence)

                paired_sequence = paired_sequence.to(device)
                paired_segments_ids = paired_segments_ids.to(device)
                labels_ids = labels_ids.to(device)
                att_mask = att_mask.to(device)

                loss = model_clf(paired_sequence, token_type_ids=paired_segments_ids, attention_mask=att_mask,
                                 labels=labels_ids)

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.

                if gradient_accumulate_step > 1:
                    loss = loss / gradient_accumulate_step

                loss.backward()
                forbackward_step += 1

                if forbackward_step % gradient_accumulate_step == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    update_step += 1
                    if do_ema and ema_tracker is not None:
                        # If model_clf is wrapped in DataParallel, track the underlying module.
                        model_to_track = model_clf.module if hasattr(model_clf,
                                                                     'module') else model_clf
                        ema_tracker(model_to_track.named_parameters())  # update the EMA after every optimizer step

                    if update_step % eval_iter_num == 0:
                        print("Update steps:", update_step)
                        dev_iter = biterator(dev_instances, num_epochs=1, shuffle=False)

                        if do_ema and ema_model_copy is not None and ema_tracker is not None:
                            print("EMA evaluation.")
                            EMA.load_ema_to_model(ema_model_copy, ema_tracker)
                            ema_model_copy.to(device)
                            if n_gpu > 1:
                                ema_model_copy = nn.DataParallel(ema_model_copy)
                            dev_data_list = hidden_eval(ema_model_copy, dev_iter, dev_data_list, device)
                        else:
                            dev_data_list = hidden_eval(model_clf, dev_iter, dev_data_list, device)

                        eval_mode = {'check_sent_id_correct': True, 'standard': True}
                        fever_score, label_score, pr, rec, f1 = fever_scorer.fever_score(dev_data_list,
                                                                                         common.load_jsonl(config.FEVER_DATA_ROOT / "fever_1.0/shared_task_dev.jsonl"),
                                                                                         mode=eval_mode,
                                                                                         verbose=False)
                        print("Fever Score(FScore/LScore:/Precision/Recall/F1):", fever_score, label_score, pr, rec, f1)

                        print(f"Dev:{fever_score}/{label_score}")

                        if best_fever_score < fever_score:
                            print("New Best FScore")
                            best_fever_score = fever_score

                            save_path = os.path.join(
                                file_path_prefix,
                                f'i({update_step})_epoch({n_epoch})_dev({fever_score})_lacc({label_score})_seed({seed})'
                            )
                            model_to_save = model_clf.module if hasattr(model_clf,
                                                                        'module') else model_clf
                            output_model_file = os.path.join(file_path_prefix, save_path)
                            torch.save(model_to_save.state_dict(), output_model_file)

            print("Update steps:", update_step)
            dev_iter = biterator(dev_instances, num_epochs=1, shuffle=False)

            if do_ema and ema_model_copy is not None and ema_tracker is not None:
                print("EMA evaluation.")
                EMA.load_ema_to_model(ema_model_copy, ema_tracker)
                ema_model_copy.to(device)
                if n_gpu > 1:
                    ema_model_copy = nn.DataParallel(ema_model_copy)
                dev_data_list = hidden_eval(ema_model_copy, dev_iter, dev_data_list, device)
            else:
                dev_data_list = hidden_eval(model_clf, dev_iter, dev_data_list, device)

            eval_mode = {'check_sent_id_correct': True, 'standard': True}
            fever_score, label_score, pr, rec, f1 = fever_scorer.fever_score(dev_data_list,
                                                                             common.load_jsonl(config.FEVER_DATA_ROOT / "fever_1.0/shared_task_dev.jsonl"),
                                                                             mode=eval_mode,
                                                                             verbose=False)
            print("Fever Score(FScore/LScore:/Precision/Recall/F1):", fever_score, label_score, pr, rec, f1)

            print(f"Dev:{fever_score}/{label_score}")

            if best_fever_score < fever_score:
                print("New Best FScore")
                best_fever_score = fever_score

                save_path = os.path.join(
                    file_path_prefix,
                    f'i({update_step})_epoch({n_epoch})_dev({fever_score})_lacc({label_score})_seed({seed})'
                )
                model_to_save = model_clf.module if hasattr(model_clf,
                                                            'module') else model_clf
                output_model_file = os.path.join(file_path_prefix, save_path)
                torch.save(model_to_save.state_dict(), output_model_file)
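
Example 3 optionally tracks an exponential moving average (EMA) of the weights and evaluates the EMA copy instead of the live model. The EMA class it uses is project-specific (it also supports keeping shadow weights on CPU); the sketch below shows only the core update rule, under that caveat:

import torch

class SimpleEMA:
    """Minimal EMA tracker: shadow = decay * shadow + (1 - decay) * param."""

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {n: p.detach().clone() for n, p in model.named_parameters()}

    def update(self, model):
        with torch.no_grad():
            for n, p in model.named_parameters():
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    def copy_to(self, model):
        with torch.no_grad():
            for n, p in model.named_parameters():
                p.copy_(self.shadow[n])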
Example 4
def model_go():
    seed = 12
    torch.manual_seed(seed)
    # bert_model_name = 'bert-large-uncased'
    bert_model_name = 'bert-base-uncased'
    experiment_name = 'hotpot_v0_cs'
    lazy = False
    # lazy = True
    forward_size = 16
    # batch_size = 64
    batch_size = 128
    gradient_accumulate_step = int(batch_size / forward_size)
    warmup_proportion = 0.1
    learning_rate = 5e-5
    num_train_epochs = 5
    eval_frequency = 5000
    pos_ratio = 0.2
    do_lower_case = True

    debug_mode = False
    # est_datasize = 900_000

    num_class = 1
    # num_train_optimization_steps

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_num = 0 if torch.cuda.is_available() else -1

    n_gpu = torch.cuda.device_count()

    unk_token_num = {'tokens': 1}  # workaround for initializing the vocabulary
    vocab = ExVocabulary(unk_token_num=unk_token_num)
    vocab.add_token_to_namespace("false", namespace="labels")  # 0
    vocab.add_token_to_namespace("true", namespace="labels")  # 1
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden", -2, namespace='labels')

    # Load Dataset
    train_list = common.load_json(config.TRAIN_FILE)
    dev_list = common.load_json(config.DEV_FULLWIKI_FILE)

    dev_fitems_list = common.load_jsonl(
        config.PDATA_ROOT / "content_selection_forward" / "hotpot_dev_p_level_unlabeled.jsonl")
    train_fitems_list = common.load_jsonl(
        config.PDATA_ROOT / "content_selection_forward" / "hotpot_train_p_level_labeled.jsonl")

    if debug_mode:
        dev_list = dev_list[:10]
        dev_fitems_list = dev_fitems_list[:296]
        train_fitems_list = train_fitems_list[:300]
        eval_frequency = 2
        # print(dev_list[-1]['_id'])
        # exit(0)

    sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
    est_datasize = len(sampled_train_list)

    dev_o_dict = list_dict_data_tool.list_to_dict(dev_list, '_id')
    # print(dev_o_dict)

    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=do_lower_case)
    bert_cs_reader = BertContentSelectionReader(bert_tokenizer, lazy, is_paired=True,
                                                example_filter=lambda x: len(x['context']) == 0, max_l=286)

    bert_encoder = BertModel.from_pretrained(bert_model_name)
    model = BertMultiLayerSeqClassification(bert_encoder, num_labels=num_class, num_of_pooling_layer=1,
                                            act_type='tanh', use_pretrained_pooler=True, use_sigmoid=True)

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    #
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    num_train_optimization_steps = int(est_datasize / forward_size / gradient_accumulate_step) * \
                                   num_train_epochs

    print("Estimated training size", est_datasize)
    print("Number of optimization steps:", num_train_optimization_steps)

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_proportion,
                         t_total=num_train_optimization_steps)

    dev_instances = bert_cs_reader.read(dev_fitems_list)

    biterator = BasicIterator(batch_size=forward_size)
    biterator.index_with(vocab)

    forbackward_step = 0
    update_step = 0

    logging_agent = save_tool.ScoreLogger({})

    # # # Create Log File
    file_path_prefix, date = save_tool.gen_file_prefix(f"{experiment_name}")
    # Save the source code.
    script_name = os.path.basename(__file__)
    with open(os.path.join(file_path_prefix, script_name), 'w') as out_f, open(__file__, 'r') as it:
        out_f.write(it.read())
        out_f.flush()
    # # # Log File end

    for epoch_i in range(num_train_epochs):
        print("Epoch:", epoch_i)
        sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
        train_instance = bert_cs_reader.read(sampled_train_list)
        train_iter = biterator(train_instance, num_epochs=1, shuffle=True)

        for batch in tqdm(train_iter):
            model.train()
            batch = move_to_device(batch, device_num)

            paired_sequence = batch['paired_sequence']
            paired_segments_ids = batch['paired_segments_ids']
            labels_ids = batch['label']
            att_mask, _ = torch_util.get_length_and_mask(paired_sequence)
            s1_span = batch['bert_s1_span']
            s2_span = batch['bert_s2_span']

            loss = model(paired_sequence, token_type_ids=paired_segments_ids, attention_mask=att_mask,
                         mode=BertMultiLayerSeqClassification.ForwardMode.TRAIN,
                         labels=labels_ids)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            if gradient_accumulate_step > 1:
                loss = loss / gradient_accumulate_step

            loss.backward()
            forbackward_step += 1

            if forbackward_step % gradient_accumulate_step == 0:
                optimizer.step()
                optimizer.zero_grad()
                update_step += 1

                if update_step % eval_frequency == 0:
                    print("Update steps:", update_step)
                    dev_iter = biterator(dev_instances, num_epochs=1, shuffle=False)

                    cur_eval_results_list = eval_model(model, dev_iter, device_num, with_probs=True)
                    copied_dev_o_dict = copy.deepcopy(dev_o_dict)
                    list_dict_data_tool.append_subfield_from_list_to_dict(cur_eval_results_list, copied_dev_o_dict,
                                                                          'qid', 'fid', check=True)
                    # Top_5
                    cur_results_dict_top5 = select_top_k_and_to_results_dict(copied_dev_o_dict, top_k=5)
                    upperbound_results_dict_top5 = append_gt_downstream_to_get_upperbound_from_doc_retri(
                        cur_results_dict_top5,
                        dev_list)

                    cur_results_dict_top10 = select_top_k_and_to_results_dict(copied_dev_o_dict, top_k=10)
                    upperbound_results_dict_top10 = append_gt_downstream_to_get_upperbound_from_doc_retri(
                        cur_results_dict_top10,
                        dev_list)

                    _, metrics_top5 = ext_hotpot_eval.eval(cur_results_dict_top5, dev_list, verbose=False)
                    _, metrics_top5_UB = ext_hotpot_eval.eval(upperbound_results_dict_top5, dev_list, verbose=False)

                    _, metrics_top10 = ext_hotpot_eval.eval(cur_results_dict_top10, dev_list, verbose=False)
                    _, metrics_top10_UB = ext_hotpot_eval.eval(upperbound_results_dict_top10, dev_list, verbose=False)

                    # top5_doc_f1, top5_UB_sp_f1, top10_doc_f1, top10_Ub_sp_f1
                    # top5_doc_f1 = metrics_top5['doc_f1']
                    # top5_UB_sp_f1 = metrics_top5_UB['sp_f1']
                    # top10_doc_f1 = metrics_top10['doc_f1']
                    # top10_Ub_sp_f1 = metrics_top10_UB['sp_f1']

                    top5_doc_recall = metrics_top5['doc_recall']
                    top5_UB_sp_recall = metrics_top5_UB['sp_recall']
                    top10_doc_recall = metrics_top10['doc_recall']
                    top10_Ub_sp_recall = metrics_top10_UB['sp_recall']

                    logging_item = {
                        'top5': metrics_top5,
                        'top5_UB': metrics_top5_UB,
                        'top10': metrics_top10,
                        'top10_UB': metrics_top10_UB,
                    }

                    # print(logging_item)
                    save_file_name = f'i({update_step})|e({epoch_i})' \
                        f'|t5_doc_recall({top5_doc_recall})|t5_sp_recall({top5_UB_sp_recall})' \
                        f'|t10_doc_recall({top10_doc_recall})|t10_sp_recall({top10_Ub_sp_recall})|seed({seed})'

                    # print(save_file_name)
                    logging_agent.incorporate_results({}, save_file_name, logging_item)
                    logging_agent.logging_to_file(Path(file_path_prefix) / "log.json")

                    model_to_save = model.module if hasattr(model, 'module') else model
                    output_model_file = Path(file_path_prefix) / save_file_name
                    torch.save(model_to_save.state_dict(), str(output_model_file))
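
Example 4 calls down_sample_neg once per epoch so the content-selection classifier sees a fixed positive ratio despite the negative-skewed candidate pool; the function itself is defined elsewhere in the repository. A rough sketch of the idea, assuming each item carries a boolean label field (the field name 'selection_label' is our assumption):

import random

def down_sample_neg_sketch(items, ratio):
    # Keep all positives; sample negatives so positives make up ~`ratio` of the output.
    pos = [it for it in items if it['selection_label']]
    neg = [it for it in items if not it['selection_label']]
    n_neg = min(len(neg), int(len(pos) * (1 - ratio) / ratio))
    sampled = pos + random.sample(neg, n_neg)
    random.shuffle(sampled)
    return sampled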
Example 5
def model_go(sent_filter_value, sent_top_k=5):
    seed = 12
    torch.manual_seed(seed)

    bert_pretrain_path = config.PRO_ROOT / '.pytorch_pretrained_bert'
    bert_model_name = "bert-base-uncased"
    lazy = False
    forward_size = 32
    batch_size = 32
    gradient_accumulate_step = int(batch_size / forward_size)
    warmup_rate = 0.1
    learning_rate = 5e-5
    num_train_epochs = 5
    eval_frequency = 1000

    do_lower_case = True

    debug = False

    max_pre_context_length = 320
    max_query_length = 64
    doc_stride = 128
    qa_num_of_layer = 2
    do_ema = True
    ema_device_num = 1
    # s_filter_value = 0.5
    s_filter_value = sent_filter_value
    # s_top_k = 5
    s_top_k = sent_top_k

    experiment_name = f'hotpot_v0_qa_(s_top_k:{s_top_k},s_fv:{s_filter_value},qa_layer:{qa_num_of_layer})'

    print("Potential total length:",
          max_pre_context_length + max_query_length + 3)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_num = 0 if torch.cuda.is_available() else -1

    n_gpu = torch.cuda.device_count()

    tokenizer = BertTokenizer.from_pretrained(bert_model_name,
                                              do_lower_case=do_lower_case,
                                              cache_dir=bert_pretrain_path)

    # Load Dataset.
    dev_list = common.load_json(config.DEV_FULLWIKI_FILE)
    train_list = common.load_json(config.TRAIN_FILE)

    dev_sentence_level_results = common.load_jsonl(
        config.PRO_ROOT /
        "data/p_hotpotqa/hotpotqa_sentence_level/04-19-02:17:11_hotpot_v0_slevel_retri_(doc_top_k:2)/i(12000)|e(2)|v02_f1(0.7153646038858843)|v02_recall(0.7114645831323757)|v05_f1(0.7153646038858843)|v05_recall(0.7114645831323757)|seed(12)/dev_s_level_bert_v1_results.jsonl"
    )
    train_sentence_level_results = common.load_jsonl(
        config.PRO_ROOT /
        "data/p_hotpotqa/hotpotqa_sentence_level/04-19-02:17:11_hotpot_v0_slevel_retri_(doc_top_k:2)/i(12000)|e(2)|v02_f1(0.7153646038858843)|v02_recall(0.7114645831323757)|v05_f1(0.7153646038858843)|v05_recall(0.7114645831323757)|seed(12)/train_s_level_bert_v1_results.jsonl"
    )

    dev_fitem_dict, dev_fitem_list, dev_sp_results_dict = get_qa_item_with_upstream_sentence(
        dev_list,
        dev_sentence_level_results,
        is_training=False,
        tokenizer=tokenizer,
        max_context_length=max_pre_context_length,
        max_query_length=max_query_length,
        filter_value=s_filter_value,
        doc_stride=doc_stride,
        top_k=s_top_k,
        debug_mode=debug)

    train_fitem_dict, train_fitem_list, _ = get_qa_item_with_upstream_sentence(
        train_list,
        train_sentence_level_results,
        is_training=True,
        tokenizer=tokenizer,
        max_context_length=max_pre_context_length,
        max_query_length=max_query_length,
        filter_value=s_filter_value,
        doc_stride=doc_stride,
        top_k=s_top_k,
        debug_mode=debug)

    # print(len(dev_fitem_list))
    # print(len(dev_fitem_dict))

    dev_o_dict = list_dict_data_tool.list_to_dict(dev_list, '_id')

    if debug:
        dev_list = dev_list[:100]
        eval_frequency = 2

    est_datasize = len(train_fitem_list)

    span_pred_reader = BertPairedSpanPredReader(bert_tokenizer=tokenizer,
                                                lazy=lazy,
                                                example_filter=None)
    bert_encoder = BertModel.from_pretrained(bert_model_name,
                                             cache_dir=bert_pretrain_path)
    model = BertSpan(bert_encoder, qa_num_of_layer)

    ema = None
    if do_ema:
        ema = EMA(model, model.named_parameters(), device_num=ema_device_num)

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    iterator = BasicIterator(batch_size=batch_size)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]

    print("Total train instances:", len(train_fitem_list))

    num_train_optimization_steps = int(est_datasize / forward_size / gradient_accumulate_step) * \
                                   num_train_epochs

    if debug:
        num_train_optimization_steps = 100

    print("Estimated training size", est_datasize)
    print("Number of optimization steps:", num_train_optimization_steps)

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_rate,
                         t_total=num_train_optimization_steps)

    dev_instances = span_pred_reader.read(dev_fitem_list)

    forbackward_step = 0
    update_step = 0

    logging_agent = save_tool.ScoreLogger({})

    # # # Create Log File
    file_path_prefix = None
    if not debug:
        file_path_prefix, date = save_tool.gen_file_prefix(
            f"{experiment_name}")
        # Save the source code.
        script_name = os.path.basename(__file__)
        with open(os.path.join(file_path_prefix, script_name),
                  'w') as out_f, open(__file__, 'r') as it:
            out_f.write(it.read())
            out_f.flush()
    # # # Log File end

    for epoch_i in range(num_train_epochs):
        print("Epoch:", epoch_i)

        print("Resampling:")
        train_fitem_dict, train_fitem_list, _ = get_qa_item_with_upstream_sentence(
            train_list,
            train_sentence_level_results,
            is_training=True,
            tokenizer=tokenizer,
            max_context_length=max_pre_context_length,
            max_query_length=max_query_length,
            filter_value=s_filter_value,
            doc_stride=doc_stride,
            top_k=s_top_k,
            debug_mode=debug)

        random.shuffle(train_fitem_list)
        train_instances = span_pred_reader.read(train_fitem_list)
        train_iter = iterator(train_instances, num_epochs=1, shuffle=True)

        for batch in tqdm(train_iter, desc="Batch Loop"):
            model.train()
            batch = allen_util.move_to_device(batch, device_num)
            paired_sequence = batch['paired_sequence']
            paired_segments_ids = batch['paired_segments_ids']
            att_mask, _ = torch_util.get_length_and_mask(paired_sequence)
            gt_span = batch['gt_span']

            loss = model(mode=BertSpan.ForwardMode.TRAIN,
                         input_ids=paired_sequence,
                         token_type_ids=paired_segments_ids,
                         attention_mask=att_mask,
                         gt_span=gt_span)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            if gradient_accumulate_step > 1:
                loss = loss / gradient_accumulate_step
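                # Dividing the loss here makes the gradients accumulated over the
                # window average out to one effective batch's gradient.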

            loss.backward()
            forbackward_step += 1

            if forbackward_step % gradient_accumulate_step == 0:
                optimizer.step()
                if ema is not None and do_ema:
                    updated_model = model.module if hasattr(
                        model, 'module') else model
                    ema(updated_model.named_parameters())
                optimizer.zero_grad()
                update_step += 1

                if update_step % eval_frequency == 0:
                    # print("Non-EMA EVAL:")
                    # eval_iter = iterator(dev_instances, num_epochs=1, shuffle=False)
                    # cur_eitem_list, cur_eval_dict = span_eval(model, eval_iter, do_lower_case, dev_fitem_dict,
                    #                                           device_num)
                    # cur_results_dict = dict()
                    # cur_results_dict['p_answer'] = cur_eval_dict
                    # cur_results_dict['sp'] = dev_sp_results_dict
                    #
                    # _, metrics = ext_hotpot_eval.eval(cur_results_dict, dev_list, verbose=False)
                    # # print(metrics)
                    #
                    # logging_item = {
                    #     'score': metrics,
                    # }
                    #
                    # joint_f1 = metrics['joint_f1']
                    # joint_em = metrics['joint_em']
                    #
                    # print(logging_item)
                    #
                    # if not debug:
                    #     save_file_name = f'i({update_step})|e({epoch_i})' \
                    #         f'|j_f1({joint_f1})|j_em({joint_em})|seed({seed})'
                    #
                    #     # print(save_file_name)
                    #     logging_agent.incorporate_results({}, save_file_name, logging_item)
                    #     logging_agent.logging_to_file(Path(file_path_prefix) / "log.json")
                    #
                    #     model_to_save = model.module if hasattr(model, 'module') else model
                    #     output_model_file = Path(file_path_prefix) / save_file_name
                    #     torch.save(model_to_save.state_dict(), str(output_model_file))

                    if do_ema and ema is not None:
                        print("EMA EVAL")
                        ema_model = ema.get_inference_model()
                        ema_inference_device_ids = get_ema_gpu_id_list(
                            master_device_num=ema_device_num)
                        ema_model = ema_model.to(ema_device_num)
                        ema_model = torch.nn.DataParallel(
                            ema_model, device_ids=ema_inference_device_ids)
                        dev_iter = iterator(dev_instances,
                                            num_epochs=1,
                                            shuffle=False)
                        cur_eitem_list, cur_eval_dict = span_eval(
                            ema_model,
                            dev_iter,
                            do_lower_case,
                            dev_fitem_dict,
                            ema_device_num,
                            show_progress=False)
                        cur_results_dict = dict()
                        cur_results_dict['p_answer'] = cur_eval_dict
                        cur_results_dict['sp'] = dev_sp_results_dict

                        _, metrics = ext_hotpot_eval.eval(cur_results_dict,
                                                          dev_list,
                                                          verbose=False)
                        print(metrics)
                        print("---------------" * 3)

                        logging_item = {
                            'label': 'ema',
                            'score': metrics,
                        }

                        joint_f1 = metrics['joint_f1']
                        joint_em = metrics['joint_em']

                        print(logging_item)

                        if not debug:
                            save_file_name = f'ema_i({update_step})|e({epoch_i})' \
                                f'|j_f1({joint_f1})|j_em({joint_em})|seed({seed})'
                            # print(save_file_name)
                            logging_agent.incorporate_results({},
                                                              save_file_name,
                                                              logging_item)
                            logging_agent.logging_to_file(
                                Path(file_path_prefix) / "log.json")

                            model_to_save = ema_model.module if hasattr(
                                ema_model, 'module') else ema_model
                            output_model_file = Path(
                                file_path_prefix) / save_file_name
                            torch.save(model_to_save.state_dict(),
                                       str(output_model_file))
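
The training loop above depends on a project-specific EMA helper (constructed as EMA(model, model.named_parameters(), device_num=...), fed named_parameters() after each optimizer step, and queried through get_inference_model()). Its source is not shown here, so the following is only a minimal sketch of what such a helper usually does; the class name, decay constant, and CPU placement of the shadow weights are assumptions, not the project's actual implementation.

import copy


class SimpleEMA:
    # A sketch of an exponential-moving-average tracker, not the project's EMA class.
    def __init__(self, model, decay=0.999):
        self.decay = decay
        # Shadow copies live on CPU so they never compete with training for GPU memory.
        self.shadow = {n: p.detach().cpu().clone() for n, p in model.named_parameters()}
        self._template = copy.deepcopy(model).cpu()

    def __call__(self, named_parameters):
        # Called after each optimizer.step(): shadow <- decay * shadow + (1 - decay) * param.
        for name, param in named_parameters:
            self.shadow[name].mul_(self.decay).add_(param.detach().cpu(), alpha=1.0 - self.decay)

    def get_inference_model(self):
        # Return a model copy carrying the averaged weights; callers move it to a device.
        model = copy.deepcopy(self._template)
        model.load_state_dict({**model.state_dict(), **self.shadow})
        return model
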
Example n. 6
0
def model_go():
    seed = 12
    torch.manual_seed(seed)
    # bert_model_name = 'bert-large-uncased'
    bert_model_name = 'bert-base-uncased'
    lazy = False
    # lazy = True
    forward_size = 64
    # batch_size = 64
    batch_size = 128
    gradient_accumulate_step = int(batch_size / forward_size)
    warmup_proportion = 0.1
    learning_rate = 5e-5
    num_train_epochs = 5
    eval_frequency = 5000
    do_lower_case = True
    ignore_non_verifiable = True
    experiment_name = f'fever_v0_plevel_retri_(ignore_non_verifiable:{ignore_non_verifiable})'

    debug_mode = False
    max_l = 264
    # est_datasize = 900_000

    num_class = 1
    # num_train_optimization_steps

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_num = 0 if torch.cuda.is_available() else -1

    n_gpu = torch.cuda.device_count()

    unk_token_num = {'tokens': 1}  # workaround for initializing the vocabulary.
    vocab = ExVocabulary(unk_token_num=unk_token_num)
    vocab.add_token_to_namespace("false", namespace="labels")  # 0
    vocab.add_token_to_namespace("true", namespace="labels")  # 1
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden",
                                               -2,
                                               namespace='labels')
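    # Pinning "hidden" to index -2 reserves a sentinel label id, presumably for
    # instances whose gold label must stay outside the real class range.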

    # Load Dataset
    train_ruleterm_doc_results = common.load_jsonl(
        config.PRO_ROOT /
        "results/doc_retri_results/fever_results/merged_doc_results/m_doc_train.jsonl"
    )
    dev_ruleterm_doc_results = common.load_jsonl(
        config.PRO_ROOT /
        "results/doc_retri_results/fever_results/merged_doc_results/m_doc_dev.jsonl"
    )

    # train_list = common.load_json(config.TRAIN_FILE)
    dev_list = common.load_jsonl(config.FEVER_DEV)

    train_fitems = fever_p_level_sampler.get_paragraph_forward_pair(
        'train',
        train_ruleterm_doc_results,
        is_training=True,
        debug=debug_mode,
        ignore_non_verifiable=True)
    dev_fitems = fever_p_level_sampler.get_paragraph_forward_pair(
        'dev',
        dev_ruleterm_doc_results,
        is_training=False,
        debug=debug_mode,
        ignore_non_verifiable=False)

    # ratio=None: no actual down-sampling here, just print the data statistics.
    fever_p_level_sampler.down_sample_neg(train_fitems, None)
    fever_p_level_sampler.down_sample_neg(dev_fitems, None)

    if debug_mode:
        dev_list = dev_list[:100]
        eval_frequency = 2
        # print(dev_list[-1]['_id'])
        # exit(0)

    # sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
    est_datasize = len(train_fitems)

    dev_o_dict = list_dict_data_tool.list_to_dict(dev_list, 'id')
    # print(dev_o_dict)

    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name,
                                                   do_lower_case=do_lower_case)
    bert_cs_reader = BertContentSelectionReader(
        bert_tokenizer,
        lazy,
        is_paired=True,
        example_filter=lambda x: len(x['context']) == 0,
        max_l=max_l,
        element_fieldname='element')

    bert_encoder = BertModel.from_pretrained(bert_model_name)
    model = BertMultiLayerSeqClassification(bert_encoder,
                                            num_labels=num_class,
                                            num_of_pooling_layer=1,
                                            act_type='tanh',
                                            use_pretrained_pooler=True,
                                            use_sigmoid=True)

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    #
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    num_train_optimization_steps = int(est_datasize / forward_size / gradient_accumulate_step) * \
                                   num_train_epochs

    if debug_mode:
        num_train_optimization_steps = 100

    print("Estimated training size", est_datasize)
    print("Number of optimization steps:", num_train_optimization_steps)

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_proportion,
                         t_total=num_train_optimization_steps)

    dev_instances = bert_cs_reader.read(dev_fitems)

    biterator = BasicIterator(batch_size=forward_size)
    biterator.index_with(vocab)

    forbackward_step = 0
    update_step = 0

    logging_agent = save_tool.ScoreLogger({})

    if not debug_mode:
        # # # Create Log File
        file_path_prefix, date = save_tool.gen_file_prefix(
            f"{experiment_name}")
        # Save the source code.
        script_name = os.path.basename(__file__)
        with open(os.path.join(file_path_prefix, script_name),
                  'w') as out_f, open(__file__, 'r') as it:
            out_f.write(it.read())
            out_f.flush()
        # # # Log File end

    for epoch_i in range(num_train_epochs):
        print("Epoch:", epoch_i)
        # sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
        random.shuffle(train_fitems)
        train_instance = bert_cs_reader.read(train_fitems)
        train_iter = biterator(train_instance, num_epochs=1, shuffle=True)

        for batch in tqdm(train_iter):
            model.train()
            batch = move_to_device(batch, device_num)

            paired_sequence = batch['paired_sequence']
            paired_segments_ids = batch['paired_segments_ids']
            labels_ids = batch['label']
            att_mask, _ = torch_util.get_length_and_mask(paired_sequence)
            s1_span = batch['bert_s1_span']
            s2_span = batch['bert_s2_span']

            loss = model(
                paired_sequence,
                token_type_ids=paired_segments_ids,
                attention_mask=att_mask,
                mode=BertMultiLayerSeqClassification.ForwardMode.TRAIN,
                labels=labels_ids)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            if gradient_accumulate_step > 1:
                loss = loss / gradient_accumulate_step

            loss.backward()
            forbackward_step += 1

            if forbackward_step % gradient_accumulate_step == 0:
                optimizer.step()
                optimizer.zero_grad()
                update_step += 1

                if update_step % eval_frequency == 0:
                    print("Update steps:", update_step)
                    dev_iter = biterator(dev_instances,
                                         num_epochs=1,
                                         shuffle=False)

                    cur_eval_results_list = eval_model(model,
                                                       dev_iter,
                                                       device_num,
                                                       make_int=True,
                                                       with_probs=True)
                    copied_dev_o_dict = copy.deepcopy(dev_o_dict)
                    copied_dev_d_list = copy.deepcopy(dev_list)
                    list_dict_data_tool.append_subfield_from_list_to_dict(
                        cur_eval_results_list,
                        copied_dev_o_dict,
                        'qid',
                        'fid',
                        check=True)

                    cur_results_dict_th0_5 = select_top_k_and_to_results_dict(
                        copied_dev_o_dict,
                        score_field_name='prob',
                        top_k=5,
                        filter_value=0.5)

                    list_dict_data_tool.append_item_from_dict_to_list_hotpot_style(
                        copied_dev_d_list, cur_results_dict_th0_5, 'id',
                        'predicted_docids')
                    # mode = {'standard': False, 'check_doc_id_correct': True}
                    strict_score, pr, rec, f1 = fever_scorer.fever_doc_only(
                        copied_dev_d_list, dev_list, max_evidence=5)
                    score_05 = {
                        'ss': strict_score,
                        'pr': pr,
                        'rec': rec,
                        'f1': f1,
                    }
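                    # Re-run the same selection with a looser probability threshold
                    # (0.2) to compare two operating points of the retriever.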

                    list_dict_data_tool.append_subfield_from_list_to_dict(
                        cur_eval_results_list,
                        copied_dev_o_dict,
                        'qid',
                        'fid',
                        check=True)

                    cur_results_dict_th0_2 = select_top_k_and_to_results_dict(
                        copied_dev_o_dict,
                        score_field_name='prob',
                        top_k=5,
                        filter_value=0.2)

                    list_dict_data_tool.append_item_from_dict_to_list_hotpot_style(
                        copied_dev_d_list, cur_results_dict_th0_2, 'id',
                        'predicted_docids')
                    # mode = {'standard': False, 'check_doc_id_correct': True}
                    strict_score, pr, rec, f1 = fever_scorer.fever_doc_only(
                        copied_dev_d_list, dev_list, max_evidence=5)
                    score_02 = {
                        'ss': strict_score,
                        'pr': pr,
                        'rec': rec,
                        'f1': f1,
                    }

                    logging_item = {
                        'score_02': score_02,
                        'score_05': score_05,
                    }

                    print(logging_item)

                    s02_ss_score = score_02['ss']
                    s05_ss_score = score_05['ss']

                    if not debug_mode:
                        save_file_name = f'i({update_step})|e({epoch_i})' \
                            f'|v02_ofever({s02_ss_score})' \
                            f'|v05_ofever({s05_ss_score})|seed({seed})'

                        # print(save_file_name)
                        logging_agent.incorporate_results({}, save_file_name,
                                                          logging_item)
                        logging_agent.logging_to_file(
                            Path(file_path_prefix) / "log.json")

                        model_to_save = model.module if hasattr(
                            model, 'module') else model
                        output_model_file = Path(
                            file_path_prefix) / save_file_name
                        torch.save(model_to_save.state_dict(),
                                   str(output_model_file))
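
select_top_k_and_to_results_dict is imported from the surrounding project and is not shown. From its call sites (score_field_name='prob', top_k, filter_value, and in later examples an explicit result_field) it is almost certainly a filter-then-rank step over per-query candidates. A hedged reconstruction follows; the candidate and id field names are invented for illustration.

def select_top_k_and_to_results_dict(o_dict, score_field_name='prob',
                                     top_k=5, filter_value=0.5,
                                     item_field='selected_fitems', id_field='fid'):
    # item_field and id_field are guessed names: each query entry is assumed to
    # carry a list of scored candidates appended by append_subfield_from_list_to_dict.
    results_dict = {}
    for qid, entry in o_dict.items():
        candidates = entry.get(item_field, [])
        kept = [c for c in candidates if c[score_field_name] >= filter_value]
        kept.sort(key=lambda c: c[score_field_name], reverse=True)
        results_dict[qid] = [c[id_field] for c in kept[:top_k]]
    return results_dict
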
Example n. 7
0
def model_go():
    seed = 12
    torch.manual_seed(seed)
    # bert_model_name = 'bert-large-uncased'
    bert_pretrain_path = config.PRO_ROOT / '.pytorch_pretrained_bert'
    bert_model_name = 'bert-base-uncased'
    lazy = False
    # lazy = True
    forward_size = 128
    # batch_size = 64
    batch_size = 128
    gradient_accumulate_step = int(batch_size / forward_size)
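    # forward_size == batch_size here, so gradient_accumulate_step is 1 and every
    # forward/backward pass triggers an optimizer update.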
    warmup_proportion = 0.1
    learning_rate = 5e-5
    num_train_epochs = 5
    eval_frequency = 2000
    pos_ratio = 0.2
    do_lower_case = True
    document_top_k = 2
    experiment_name = f'hotpot_v0_slevel_retri_(doc_top_k:{document_top_k})'

    debug_mode = False
    do_ema = True
    # est_datasize = 900_000

    num_class = 1
    # num_train_optimization_steps

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_num = 0 if torch.cuda.is_available() else -1

    n_gpu = torch.cuda.device_count()

    unk_token_num = {'tokens': 1}  # workaround for initializing the vocabulary.
    vocab = ExVocabulary(unk_token_num=unk_token_num)
    vocab.add_token_to_namespace("false", namespace="labels")  # 0
    vocab.add_token_to_namespace("true", namespace="labels")  # 1
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden",
                                               -2,
                                               namespace='labels')

    # Load Dataset
    train_list = common.load_json(config.TRAIN_FILE)
    dev_list = common.load_json(config.DEV_FULLWIKI_FILE)

    # train_fitems = sentence_level_sampler.get_train_sentence_pair(document_top_k, True, debug_mode)
    # dev_fitems = sentence_level_sampler.get_dev_sentence_pair(document_top_k, False, debug_mode)

    # Load train eval results list
    cur_train_eval_results_list = common.load_jsonl(
        config.PRO_ROOT /
        "data/p_hotpotqa/hotpotqa_paragraph_level/04-10-17:44:54_hotpot_v0_cs/"
        "i(40000)|e(4)|t5_doc_recall(0.8793382849426064)|t5_sp_recall(0.879496479212887)|t10_doc_recall(0.888656313301823)|t5_sp_recall(0.8888325134240054)|seed(12)/train_p_level_bert_v1_results.jsonl"
    )

    cur_dev_eval_results_list = common.load_jsonl(
        config.PRO_ROOT /
        "data/p_hotpotqa/hotpotqa_paragraph_level/04-10-17:44:54_hotpot_v0_cs/"
        "i(40000)|e(4)|t5_doc_recall(0.8793382849426064)|t5_sp_recall(0.879496479212887)|t10_doc_recall(0.888656313301823)|t5_sp_recall(0.8888325134240054)|seed(12)/dev_p_level_bert_v1_results.jsonl"
    )

    train_fitems = get_sentence_pair(document_top_k,
                                     train_list,
                                     cur_train_eval_results_list,
                                     is_training=True,
                                     debug_mode=debug_mode)

    dev_fitems = get_sentence_pair(document_top_k,
                                   dev_list,
                                   cur_dev_eval_results_list,
                                   is_training=False,
                                   debug_mode=debug_mode)

    if debug_mode:
        dev_list = dev_list[:100]
        eval_frequency = 2
        # print(dev_list[-1]['_id'])
        # exit(0)

    # sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
    est_datasize = len(train_fitems)

    dev_o_dict = list_dict_data_tool.list_to_dict(dev_list, '_id')
    # print(dev_o_dict)

    bert_tokenizer = BertTokenizer.from_pretrained(
        bert_model_name,
        do_lower_case=do_lower_case,
        cache_dir=bert_pretrain_path)
    bert_cs_reader = BertContentSelectionReader(
        bert_tokenizer,
        lazy,
        is_paired=True,
        example_filter=lambda x: len(x['context']) == 0,
        max_l=128,
        element_fieldname='element')

    bert_encoder = BertModel.from_pretrained(bert_model_name,
                                             cache_dir=bert_pretrain_path)
    model = BertMultiLayerSeqClassification(bert_encoder,
                                            num_labels=num_class,
                                            num_of_pooling_layer=1,
                                            act_type='tanh',
                                            use_pretrained_pooler=True,
                                            use_sigmoid=True)

    ema = None
    if do_ema:
        ema = EMA(model, model.named_parameters(), device_num=1)

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    #
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    num_train_optimization_steps = int(est_datasize / forward_size / gradient_accumulate_step) * \
                                   num_train_epochs

    if debug_mode:
        num_train_optimization_steps = 100

    print("Estimated training size", est_datasize)
    print("Number of optimization steps:", num_train_optimization_steps)

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_proportion,
                         t_total=num_train_optimization_steps)

    dev_instances = bert_cs_reader.read(dev_fitems)

    biterator = BasicIterator(batch_size=forward_size)
    biterator.index_with(vocab)

    forbackward_step = 0
    update_step = 0

    logging_agent = save_tool.ScoreLogger({})

    # # # Create Log File
    file_path_prefix, date = save_tool.gen_file_prefix(f"{experiment_name}")
    # Save the source code.
    script_name = os.path.basename(__file__)
    with open(os.path.join(file_path_prefix, script_name),
              'w') as out_f, open(__file__, 'r') as it:
        out_f.write(it.read())
        out_f.flush()
    # # # Log File end

    for epoch_i in range(num_train_epochs):
        print("Epoch:", epoch_i)
        # sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
        random.shuffle(train_fitems)
        train_instance = bert_cs_reader.read(train_fitems)
        train_iter = biterator(train_instance, num_epochs=1, shuffle=True)

        for batch in tqdm(train_iter):
            model.train()
            batch = move_to_device(batch, device_num)

            paired_sequence = batch['paired_sequence']
            paired_segments_ids = batch['paired_segments_ids']
            labels_ids = batch['label']
            att_mask, _ = torch_util.get_length_and_mask(paired_sequence)
            s1_span = batch['bert_s1_span']
            s2_span = batch['bert_s2_span']

            loss = model(
                paired_sequence,
                token_type_ids=paired_segments_ids,
                attention_mask=att_mask,
                mode=BertMultiLayerSeqClassification.ForwardMode.TRAIN,
                labels=labels_ids)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            if gradient_accumulate_step > 1:
                loss = loss / gradient_accumulate_step

            loss.backward()
            forbackward_step += 1

            if forbackward_step % gradient_accumulate_step == 0:
                optimizer.step()
                if ema is not None and do_ema:
                    updated_model = model.module if hasattr(
                        model, 'module') else model
                    ema(updated_model.named_parameters())
                optimizer.zero_grad()
                update_step += 1

                if update_step % eval_frequency == 0:
                    print("Update steps:", update_step)
                    dev_iter = biterator(dev_instances,
                                         num_epochs=1,
                                         shuffle=False)

                    cur_eval_results_list = eval_model(model,
                                                       dev_iter,
                                                       device_num,
                                                       with_probs=True)
                    copied_dev_o_dict = copy.deepcopy(dev_o_dict)
                    list_dict_data_tool.append_subfield_from_list_to_dict(
                        cur_eval_results_list,
                        copied_dev_o_dict,
                        'qid',
                        'fid',
                        check=True)
                    # 0.5
                    cur_results_dict_v05 = select_top_k_and_to_results_dict(
                        copied_dev_o_dict,
                        top_k=5,
                        score_field_name='prob',
                        filter_value=0.5,
                        result_field='sp')

                    cur_results_dict_v02 = select_top_k_and_to_results_dict(
                        copied_dev_o_dict,
                        top_k=5,
                        score_field_name='prob',
                        filter_value=0.2,
                        result_field='sp')

                    _, metrics_v5 = ext_hotpot_eval.eval(cur_results_dict_v05,
                                                         dev_list,
                                                         verbose=False)

                    _, metrics_v2 = ext_hotpot_eval.eval(cur_results_dict_v02,
                                                         dev_list,
                                                         verbose=False)

                    v02_sp_f1 = metrics_v2['sp_f1']
                    v02_sp_recall = metrics_v2['sp_recall']
                    v02_sp_prec = metrics_v2['sp_prec']

                    v05_sp_f1 = metrics_v5['sp_f1']
                    v05_sp_recall = metrics_v5['sp_recall']
                    v05_sp_prec = metrics_v5['sp_prec']

                    logging_item = {
                        'v02': metrics_v2,
                        'v05': metrics_v5,
                    }

                    print(logging_item)

                    # print(logging_item)
                    if not debug_mode:
                        save_file_name = f'i({update_step})|e({epoch_i})' \
                            f'|v02_f1({v02_sp_f1})|v02_recall({v02_sp_recall})' \
                            f'|v05_f1({v05_sp_f1})|v05_recall({v05_sp_recall})|seed({seed})'

                        # print(save_file_name)
                        logging_agent.incorporate_results({}, save_file_name,
                                                          logging_item)
                        logging_agent.logging_to_file(
                            Path(file_path_prefix) / "log.json")

                        model_to_save = model.module if hasattr(
                            model, 'module') else model
                        output_model_file = Path(
                            file_path_prefix) / save_file_name
                        torch.save(model_to_save.state_dict(),
                                   str(output_model_file))

                    if do_ema and ema is not None:
                        ema_model = ema.get_inference_model()
                        master_device_num = 1
                        ema_inference_device_ids = get_ema_gpu_id_list(
                            master_device_num=master_device_num)
                        ema_model = ema_model.to(master_device_num)
                        ema_model = torch.nn.DataParallel(
                            ema_model, device_ids=ema_inference_device_ids)
                        dev_iter = biterator(dev_instances,
                                             num_epochs=1,
                                             shuffle=False)

                        cur_eval_results_list = eval_model(ema_model,
                                                           dev_iter,
                                                           master_device_num,
                                                           with_probs=True)
                        copied_dev_o_dict = copy.deepcopy(dev_o_dict)
                        list_dict_data_tool.append_subfield_from_list_to_dict(
                            cur_eval_results_list,
                            copied_dev_o_dict,
                            'qid',
                            'fid',
                            check=True)
                        # 0.5
                        cur_results_dict_v05 = select_top_k_and_to_results_dict(
                            copied_dev_o_dict,
                            top_k=5,
                            score_field_name='prob',
                            filter_value=0.5,
                            result_field='sp')

                        cur_results_dict_v02 = select_top_k_and_to_results_dict(
                            copied_dev_o_dict,
                            top_k=5,
                            score_field_name='prob',
                            filter_value=0.2,
                            result_field='sp')

                        _, metrics_v5 = ext_hotpot_eval.eval(
                            cur_results_dict_v05, dev_list, verbose=False)

                        _, metrics_v2 = ext_hotpot_eval.eval(
                            cur_results_dict_v02, dev_list, verbose=False)

                        v02_sp_f1 = metrics_v2['sp_f1']
                        v02_sp_recall = metrics_v2['sp_recall']
                        v02_sp_prec = metrics_v2['sp_prec']

                        v05_sp_f1 = metrics_v5['sp_f1']
                        v05_sp_recall = metrics_v5['sp_recall']
                        v05_sp_prec = metrics_v5['sp_prec']

                        logging_item = {
                            'label': 'ema',
                            'v02': metrics_v2,
                            'v05': metrics_v5,
                        }

                        print(logging_item)

                        if not debug_mode:
                            save_file_name = f'ema_i({update_step})|e({epoch_i})' \
                                f'|v02_f1({v02_sp_f1})|v02_recall({v02_sp_recall})' \
                                f'|v05_f1({v05_sp_f1})|v05_recall({v05_sp_recall})|seed({seed})'

                            # print(save_file_name)
                            logging_agent.incorporate_results({},
                                                              save_file_name,
                                                              logging_item)
                            logging_agent.logging_to_file(
                                Path(file_path_prefix) / "log.json")

                            model_to_save = ema_model.module if hasattr(
                                ema_model, 'module') else ema_model
                            output_model_file = Path(
                                file_path_prefix) / save_file_name
                            torch.save(model_to_save.state_dict(),
                                       str(output_model_file))
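
get_ema_gpu_id_list is another unshown project helper. Because torch.nn.DataParallel requires device_ids[0] to match the device the module already lives on, the helper most plausibly returns the EMA master device first, followed by the remaining visible GPUs. A guessed reconstruction:

import torch


def get_ema_gpu_id_list(master_device_num=0):
    # Master device first (DataParallel's requirement), then every other visible GPU.
    other_ids = [i for i in range(torch.cuda.device_count()) if i != master_device_num]
    return [master_device_num] + other_ids
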
Example n. 8
0
def multitask_model_go():
    seed = 12
    torch.manual_seed(seed)
    # bert_model_name = 'bert-large-uncased'
    bert_pretrain_path = config.PRO_ROOT / '.pytorch_pretrained_bert'
    bert_model_name = 'bert-base-uncased'
    lazy = False
    # lazy = True
    forward_size = 64
    # batch_size = 64
    batch_size = 128
    gradient_accumulate_step = int(batch_size / forward_size)
    warmup_proportion = 0.1
    learning_rate = 5e-5
    num_train_epochs = 1
    eval_frequency = 5000
    hotpot_pos_ratio = 0.2
    do_lower_case = True
    max_l = 264

    experiment_name = f'mtr_p_level_(num_train_epochs:{num_train_epochs})'

    debug_mode = False
    do_ema = True
    # est_datasize = 900_000

    num_class = 1
    # num_train_optimization_steps

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_num = 0 if torch.cuda.is_available() else -1

    n_gpu = torch.cuda.device_count()

    unk_token_num = {'tokens': 1}  # workaround for initializing the vocabulary.
    vocab = ExVocabulary(unk_token_num=unk_token_num)
    vocab.add_token_to_namespace("false", namespace="labels")  # 0
    vocab.add_token_to_namespace("true", namespace="labels")  # 1
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden",
                                               -2,
                                               namespace='labels')

    # Load Hotpot Dataset
    hotpot_train_list = common.load_json(config.TRAIN_FILE)
    hotpot_dev_list = common.load_json(config.DEV_FULLWIKI_FILE)
    hotpot_dev_o_dict = list_dict_data_tool.list_to_dict(
        hotpot_dev_list, '_id')

    # Load Hotpot upstream paragraph forward item
    hotpot_dev_fitems_list = common.load_jsonl(
        config.PDATA_ROOT / "content_selection_forward" /
        "hotpot_dev_p_level_unlabeled.jsonl")
    hotpot_train_fitems_list = common.load_jsonl(
        config.PDATA_ROOT / "content_selection_forward" /
        "hotpot_train_p_level_labeled.jsonl")

    hotpot_train_fitems_list = hotpot_sampler_utils.field_name_convert(
        hotpot_train_fitems_list, 'doc_t', 'element')
    hotpot_dev_fitems_list = hotpot_sampler_utils.field_name_convert(
        hotpot_dev_fitems_list, 'doc_t', 'element')

    # Load FEVER Dataset
    # fever_train_list = common.load_json(config.FEVER_TRAIN)
    fever_dev_list = common.load_jsonl(config.FEVER_DEV)
    fever_dev_o_dict = list_dict_data_tool.list_to_dict(fever_dev_list, 'id')

    train_ruleterm_doc_results = common.load_jsonl(
        config.PRO_ROOT /
        "results/doc_retri_results/fever_results/merged_doc_results/m_doc_train.jsonl"
    )
    dev_ruleterm_doc_results = common.load_jsonl(
        config.PRO_ROOT /
        "results/doc_retri_results/fever_results/merged_doc_results/m_doc_dev.jsonl"
    )

    fever_train_fitems_list = fever_p_level_sampler.get_paragraph_forward_pair(
        'train',
        train_ruleterm_doc_results,
        is_training=True,
        debug=debug_mode,
        ignore_non_verifiable=True)
    fever_dev_fitems_list = fever_p_level_sampler.get_paragraph_forward_pair(
        'dev',
        dev_ruleterm_doc_results,
        is_training=False,
        debug=debug_mode,
        ignore_non_verifiable=False)
    if debug_mode:
        hotpot_dev_list = hotpot_dev_list[:10]
        hotpot_dev_fitems_list = hotpot_dev_fitems_list[:296]
        hotpot_train_fitems_list = hotpot_train_fitems_list[:300]

        fever_dev_list = fever_dev_list[:100]
        eval_frequency = 2

    # Down_sample for hotpot.
    hotpot_sampled_train_list = down_sample_neg(hotpot_train_fitems_list,
                                                ratio=hotpot_pos_ratio)
    hotpot_est_datasize = len(hotpot_sampled_train_list)
    fever_est_datasize = len(fever_train_fitems_list)

    print("Hotpot Train Size:", hotpot_est_datasize)
    print("Fever Train Size:", fever_est_datasize)

    est_datasize = hotpot_est_datasize + fever_est_datasize

    bert_tokenizer = BertTokenizer.from_pretrained(
        bert_model_name,
        do_lower_case=do_lower_case,
        cache_dir=bert_pretrain_path)
    bert_cs_reader = BertContentSelectionReader(
        bert_tokenizer,
        lazy,
        is_paired=True,
        example_filter=lambda x: len(x['context']) == 0,
        max_l=max_l,
        element_fieldname='element')

    bert_encoder = BertModel.from_pretrained(bert_model_name,
                                             cache_dir=bert_pretrain_path)
    model = BertMultiLayerSeqClassification(bert_encoder,
                                            num_labels=num_class,
                                            num_of_pooling_layer=1,
                                            act_type='tanh',
                                            use_pretrained_pooler=True,
                                            use_sigmoid=True)

    ema = None
    if do_ema:
        ema = EMA(model, model.named_parameters(), device_num=1)

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    #
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    num_train_optimization_steps = int(est_datasize / forward_size / gradient_accumulate_step) * \
                                   num_train_epochs

    if debug_mode:
        num_train_optimization_steps = 100

    print("Estimated training size", est_datasize)
    print("Number of optimization steps:", num_train_optimization_steps)

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_proportion,
                         t_total=num_train_optimization_steps)

    hotpot_dev_instances = bert_cs_reader.read(hotpot_dev_fitems_list)
    fever_dev_instances = bert_cs_reader.read(fever_dev_fitems_list)

    biterator = BasicIterator(batch_size=forward_size)
    biterator.index_with(vocab)

    forbackward_step = 0
    update_step = 0

    logging_agent = save_tool.ScoreLogger({})

    file_path_prefix = '.'
    if not debug_mode:
        # # # Create Log File
        file_path_prefix, date = save_tool.gen_file_prefix(
            f"{experiment_name}")
        # Save the source code.
        script_name = os.path.basename(__file__)
        with open(os.path.join(file_path_prefix, script_name),
                  'w') as out_f, open(__file__, 'r') as it:
            out_f.write(it.read())
            out_f.flush()
        # # # Log File end

    for epoch_i in range(num_train_epochs):
        print("Epoch:", epoch_i)
        # sampled_train_list = down_sample_neg(train_fitems_list, ratio=pos_ratio)
        hotpot_sampled_train_list = down_sample_neg(hotpot_train_fitems_list,
                                                    ratio=hotpot_pos_ratio)
        all_train_data = hotpot_sampled_train_list + fever_train_fitems_list
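        # Shuffling the concatenated pool interleaves FEVER and HotpotQA items so
        # that every batch mixes both tasks.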
        random.shuffle(all_train_data)
        train_instance = bert_cs_reader.read(all_train_data)
        train_iter = biterator(train_instance, num_epochs=1, shuffle=True)

        for batch in tqdm(train_iter):
            model.train()
            batch = move_to_device(batch, device_num)

            paired_sequence = batch['paired_sequence']
            paired_segments_ids = batch['paired_segments_ids']
            labels_ids = batch['label']
            att_mask, _ = torch_util.get_length_and_mask(paired_sequence)
            s1_span = batch['bert_s1_span']
            s2_span = batch['bert_s2_span']

            loss = model(
                paired_sequence,
                token_type_ids=paired_segments_ids,
                attention_mask=att_mask,
                mode=BertMultiLayerSeqClassification.ForwardMode.TRAIN,
                labels=labels_ids)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            if gradient_accumulate_step > 1:
                loss = loss / gradient_accumulate_step

            loss.backward()
            forbackward_step += 1

            if forbackward_step % gradient_accumulate_step == 0:
                optimizer.step()
                if ema is not None and do_ema:
                    updated_model = model.module if hasattr(
                        model, 'module') else model
                    ema(updated_model.named_parameters())
                optimizer.zero_grad()
                update_step += 1

                if update_step % eval_frequency == 0:
                    print("Update steps:", update_step)
                    # Eval FEVER
                    eval_fever_procedure(biterator, fever_dev_instances, model,
                                         device_num, 1, fever_dev_list,
                                         fever_dev_o_dict, debug_mode,
                                         logging_agent, update_step, epoch_i,
                                         file_path_prefix, do_ema, ema, seed)
                    eval_hotpot_procedure(biterator, hotpot_dev_instances,
                                          model, device_num, 1,
                                          hotpot_dev_list, hotpot_dev_o_dict,
                                          debug_mode, logging_agent,
                                          update_step, epoch_i,
                                          file_path_prefix, do_ema, ema, seed)

    if not debug_mode:
        print("Final Saving.")
        save_file_name = f'i({update_step})|e({num_train_epochs})_final_model'
        model_to_save = model.module if hasattr(model, 'module') else model
        output_model_file = Path(file_path_prefix) / save_file_name
        torch.save(model_to_save.state_dict(), str(output_model_file))

        if do_ema and ema is not None:
            print("Final EMA Saving")
            ema_model = ema.get_inference_model()
            save_file_name = f'i({update_step})|e({num_train_epochs})_final_ema_model'
            model_to_save = ema_model.module if hasattr(
                ema_model, 'module') else ema_model
            output_model_file = Path(file_path_prefix) / save_file_name
            torch.save(model_to_save.state_dict(), str(output_model_file))
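
down_sample_neg is called above both with ratio=None (statistics only, per the inline comment) and with ratio=hotpot_pos_ratio during training. A sketch of that contract follows; the label field name and its 'true'/'false' values are guesses based on the label vocabulary defined earlier.

import random


def down_sample_neg(fitems, ratio=None, label_field='selection_label'):
    # 'selection_label' is a guessed field name; the real items may store it differently.
    pos = [f for f in fitems if f[label_field] == 'true']
    neg = [f for f in fitems if f[label_field] != 'true']
    print(f"pos: {len(pos)}, neg: {len(neg)}, total: {len(fitems)}")
    if ratio is None:  # used above just to print the statistics
        return fitems
    # Keep enough negatives that positives make up `ratio` of the result:
    # ratio = pos / (pos + kept)  =>  kept = pos * (1 - ratio) / ratio
    kept = int(len(pos) * (1 - ratio) / ratio)
    random.shuffle(neg)
    return pos + neg[:kept]
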
Example n. 9
0
def model_go(th_filter_prob=0.2, top_k_sent=5):
    seed = 12
    torch.manual_seed(seed)
    # bert_model_name = 'bert-large-uncased'
    bert_model_name = 'bert-base-uncased'
    bert_pretrain_path = config.PRO_ROOT / '.pytorch_pretrained_bert'
    lazy = False
    # lazy = True
    forward_size = 32
    # batch_size = 64
    # batch_size = 192
    batch_size = 32
    gradient_accumulate_step = int(batch_size / forward_size)
    warmup_proportion = 0.1
    # schedule_type = 'warmup_constant'
    # 'warmup_cosine': warmup_cosine,
    # 'warmup_constant': warmup_constant,
    # 'warmup_linear': warmup_linear,
    schedule_type = 'warmup_linear'
    learning_rate = 5e-5
    num_train_epochs = 5
    eval_frequency = 4000
    do_lower_case = True
    pair_order = 'cq'
    # debug_mode = True
    debug_mode = False
    do_ema = True

    maxout_model = False
    # est_datasize = 900_000

    num_class = 3
    # num_train_optimization_steps
    top_k = top_k_sent

    train_sent_filtering_prob = th_filter_prob
    dev_sent_filtering_prob = th_filter_prob
    experiment_name = f'fever_v2_nli_th{train_sent_filtering_prob}_tk{top_k}'

    # Data dataset and upstream sentence results.
    dev_sent_results_list = common.load_jsonl(
        config.PRO_ROOT / "data/p_fever/fever_sentence_level/04-24-00-11-19_fever_v0_slevel_retri_(ignore_non_verifiable-True)/fever_s_level_dev_results.jsonl")
    train_sent_results_list = common.load_jsonl(
        config.PRO_ROOT / "data/p_fever/fever_sentence_level/04-24-00-11-19_fever_v0_slevel_retri_(ignore_non_verifiable-True)/fever_s_level_train_results.jsonl")

    dev_fitems, dev_list = get_nli_pair('dev', is_training=False,
                                        sent_level_results_list=dev_sent_results_list, debug=debug_mode,
                                        sent_top_k=top_k_sent, sent_filter_value=dev_sent_filtering_prob)
    train_fitems, train_list = get_nli_pair('train', is_training=True,
                                            sent_level_results_list=train_sent_results_list, debug=debug_mode,
                                            sent_top_k=top_k_sent, sent_filter_value=train_sent_filtering_prob)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_num = 0 if torch.cuda.is_available() else -1

    n_gpu = torch.cuda.device_count()

    unk_token_num = {'tokens': 1}  # workaround for initializing the vocabulary.
    vocab = ExVocabulary(unk_token_num=unk_token_num)
    vocab.add_token_to_namespace('SUPPORTS', namespace='labels')
    vocab.add_token_to_namespace('REFUTES', namespace='labels')
    vocab.add_token_to_namespace('NOT ENOUGH INFO', namespace='labels')
    vocab.add_token_to_namespace("hidden", namespace="labels")
    vocab.change_token_with_index_to_namespace("hidden", -2, namespace='labels')

    if debug_mode:
        dev_list = dev_list[:100]
        train_list = train_list[:100]
        eval_frequency = 2

    est_datasize = len(train_fitems)

    bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name, do_lower_case=do_lower_case,
                                                   cache_dir=bert_pretrain_path)
    bert_cs_reader = BertFeverNLIReader(bert_tokenizer, lazy, is_paired=True, query_l=64,
                                        example_filter=None, max_l=384, pair_order=pair_order)

    bert_encoder = BertModel.from_pretrained(bert_model_name, cache_dir=bert_pretrain_path)
    if not maxout_model:
        model = BertMultiLayerSeqClassification(bert_encoder, num_labels=num_class, num_of_pooling_layer=1,
                                                act_type='tanh', use_pretrained_pooler=True, use_sigmoid=False)
    else:
        model = BertPairMaxOutMatcher(bert_encoder, num_of_class=num_class, act_type="gelu", num_of_out_layers=2)

    ema = None
    if do_ema:
        ema = EMA(model, model.named_parameters(), device_num=1)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    num_train_optimization_steps = int(est_datasize / forward_size / gradient_accumulate_step) * \
                                   num_train_epochs

    if debug_mode:
        num_train_optimization_steps = 100

    print("Estimated training size", est_datasize)
    print("Number of optimization steps:", num_train_optimization_steps)
    print("Do EMA:", do_ema)

    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_proportion,
                         t_total=num_train_optimization_steps,
                         schedule=schedule_type)

    dev_instances = bert_cs_reader.read(dev_fitems)

    biterator = BasicIterator(batch_size=forward_size)
    biterator.index_with(vocab)

    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    forbackward_step = 0
    update_step = 0

    logging_agent = save_tool.ScoreLogger({})

    file_path_prefix = '.'
    if not debug_mode:
        file_path_prefix, date = save_tool.gen_file_prefix(f"{experiment_name}")
        # # # Create Log File
        # Save the source code.
        script_name = os.path.basename(__file__)
        with open(os.path.join(file_path_prefix, script_name), 'w') as out_f, open(__file__, 'r') as it:
            out_f.write(it.read())
            out_f.flush()
        # # # Log File end

    for epoch_i in range(num_train_epochs):
        print("Epoch:", epoch_i)

        train_fitems_list, _ = get_nli_pair('train', is_training=True,
                                            sent_level_results_list=train_sent_results_list, debug=debug_mode,
                                            sent_top_k=top_k_sent, sent_filter_value=train_sent_filtering_prob)

        random.shuffle(train_fitems_list)
        train_instance = bert_cs_reader.read(train_fitems_list)
        train_iter = biterator(train_instance, num_epochs=1, shuffle=True)

        for batch in tqdm(train_iter):
            model.train()
            batch = move_to_device(batch, device_num)

            paired_sequence = batch['paired_sequence']
            paired_segments_ids = batch['paired_segments_ids']
            labels_ids = batch['label']
            att_mask, _ = torch_util.get_length_and_mask(paired_sequence)
            s1_span = batch['bert_s1_span']
            s2_span = batch['bert_s2_span']

            if not maxout_model:
                loss = model(paired_sequence, token_type_ids=paired_segments_ids, attention_mask=att_mask,
                             mode=BertMultiLayerSeqClassification.ForwardMode.TRAIN,
                             labels=labels_ids)
            else:
                loss = model(paired_sequence, token_type_ids=paired_segments_ids, attention_mask=att_mask,
                             s1_span=s1_span, s2_span=s2_span,
                             mode=BertPairMaxOutMatcher.ForwardMode.TRAIN,
                             labels=labels_ids)

            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.

            if gradient_accumulate_step > 1:
                loss = loss / gradient_accumulate_step

            loss.backward()
            forbackward_step += 1

            if forbackward_step % gradient_accumulate_step == 0:
                optimizer.step()
                if ema is not None and do_ema:
                    updated_model = model.module if hasattr(model, 'module') else model
                    ema(updated_model.named_parameters())
                optimizer.zero_grad()
                update_step += 1

                if update_step % eval_frequency == 0:
                    print("Update steps:", update_step)
                    # dev_iter = biterator(dev_instances, num_epochs=1, shuffle=False)
                    #
                    # cur_eval_results_list = eval_model(model, dev_iter, device_num, with_probs=True, make_int=True,
                    #                                    feed_input_span=maxout_model)
                    #
                    # ema_results_dict = list_dict_data_tool.list_to_dict(cur_eval_results_list, 'oid')
                    # copied_dev_list = copy.deepcopy(dev_list)
                    # list_dict_data_tool.append_item_from_dict_to_list(copied_dev_list, ema_results_dict,
                    #                                                   'id', 'predicted_label')
                    #
                    # mode = {'standard': True}
                    # strict_score, acc_score, pr, rec, f1 = fever_scorer.fever_score(copied_dev_list, dev_list,
                    #                                                                 mode=mode, max_evidence=5)
                    # logging_item = {
                    #     'ss': strict_score, 'ac': acc_score,
                    #     'pr': pr, 'rec': rec, 'f1': f1,
                    # }
                    #
                    # if not debug_mode:
                    #     save_file_name = f'i({update_step})|e({epoch_i})' \
                    #         f'|ss({strict_score})|ac({acc_score})|pr({pr})|rec({rec})|f1({f1})' \
                    #         f'|seed({seed})'
                    #
                    #     common.save_jsonl(copied_dev_list, Path(file_path_prefix) /
                    #                       f"{save_file_name}_dev_nli_results.json")
                    #
                    #     # print(save_file_name)
                    #     logging_agent.incorporate_results({}, save_file_name, logging_item)
                    #     logging_agent.logging_to_file(Path(file_path_prefix) / "log.json")
                    #
                    #     model_to_save = model.module if hasattr(model, 'module') else model
                    #     output_model_file = Path(file_path_prefix) / save_file_name
                    #     torch.save(model_to_save.state_dict(), str(output_model_file))

                    if do_ema and ema is not None:
                        ema_model = ema.get_inference_model()
                        ema_device_num = 0
                        ema_model = ema_model.to(device)
                        ema_model = torch.nn.DataParallel(ema_model)
                        dev_iter = biterator(dev_instances, num_epochs=1, shuffle=False)
                        cur_ema_eval_results_list = eval_model(ema_model, dev_iter, ema_device_num, with_probs=True,
                                                               make_int=True,
                                                               feed_input_span=maxout_model)

                        ema_results_dict = list_dict_data_tool.list_to_dict(cur_ema_eval_results_list, 'oid')
                        copied_dev_list = copy.deepcopy(dev_list)
                        list_dict_data_tool.append_item_from_dict_to_list(copied_dev_list, ema_results_dict,
                                                                          'id', 'predicted_label')

                        mode = {'standard': True}
                        strict_score, acc_score, pr, rec, f1 = fever_scorer.fever_score(copied_dev_list, dev_list,
                                                                                        mode=mode, max_evidence=5)
                        ema_logging_item = {
                            'label': 'ema',
                            'ss': strict_score, 'ac': acc_score,
                            'pr': pr, 'rec': rec, 'f1': f1,
                        }

                        if not debug_mode:
                            save_file_name = f'ema_i({update_step})|e({epoch_i})' \
                                f'|ss({strict_score})|ac({acc_score})|pr({pr})|rec({rec})|f1({f1})' \
                                f'|seed({seed})'

                            common.save_jsonl(copied_dev_list, Path(file_path_prefix) /
                                              f"{save_file_name}_dev_nli_results.json")

                            # print(save_file_name)
                            logging_agent.incorporate_results({}, save_file_name, ema_logging_item)
                            logging_agent.logging_to_file(Path(file_path_prefix) / "log.json")

                            model_to_save = ema_model.module if hasattr(ema_model, 'module') else ema_model
                            output_model_file = Path(file_path_prefix) / save_file_name
                            torch.save(model_to_save.state_dict(), str(output_model_file))
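
get_nli_pair is likewise imported rather than defined here. Judging by its parameters, it converts upstream per-sentence retrieval scores into claim/evidence NLI inputs by keeping sentences scoring at or above sent_filter_value and capping them at sent_top_k. The core of that selection might look like this sketch, with field names invented for illustration:

def build_nli_context(sentence_results, sent_top_k=5, sent_filter_value=0.2):
    # sentence_results: list of {'prob': float, 'text': str} dicts (field names are guesses).
    kept = [s for s in sentence_results if s['prob'] >= sent_filter_value]
    kept.sort(key=lambda s: s['prob'], reverse=True)
    return ' '.join(s['text'] for s in kept[:sent_top_k])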