Example 1
def run_test(test_dataloader, jiant_model: JiantModel, task, device, local_rank, verbose=True):
    if local_rank != -1:
        # Evaluation only runs in the single-process (non-distributed) setting.
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()

    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(test_dataloader, desc=f"Eval ({task.name}, Test)", verbose=verbose)
    ):
        batch = batch.to(device)

        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model, batch=batch, task=task, compute_loss=False,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        eval_accumulator.update(
            batch_logits=batch_logits, batch_loss=0, batch=batch, batch_metadata=batch_metadata,
        )
    return {
        "preds": evaluation_scheme.get_preds_from_accumulator(
            task=task, accumulator=eval_accumulator,
        ),
        "accumulator": eval_accumulator,
    }
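
A minimal usage sketch for run_test; the dataloader, model, and task objects are assumed to have been constructed elsewhere (see Example 8), and local_rank=-1 indicates a non-distributed run:

test_results = run_test(
    test_dataloader=test_dataloader,
    jiant_model=jiant_model,
    task=task,
    device=device,
    local_rank=-1,
)
predictions = test_results["preds"]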
Example 2
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    return_preds=False,
    verbose=True,
):
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels might contain more detailed information needed for full evaluation
    if local_rank != -1:
        return
    jiant_model.eval()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()

    for step, (batch, batch_metadata) in enumerate(
        maybe_tqdm(val_dataloader, desc=f"Eval ({task.name}, Val)", verbose=verbose)
    ):
        batch = batch.to(device)

        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model, batch=batch, task=task, compute_loss=True,
            )
        batch_logits = model_output.logits.detach().cpu().numpy()
        batch_loss = model_output.loss.mean().item()
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )

        nb_eval_examples += len(batch)
        nb_eval_steps += 1
    eval_loss = total_eval_loss / nb_eval_steps
    # DataParallel wraps the model, so the tokenizer lives on jiant_model.module.
    tokenizer = (
        jiant_model.tokenizer
        if not torch_utils.is_data_parallel(jiant_model)
        else jiant_model.module.tokenizer
    )
    output = {
        "accumulator": eval_accumulator,
        "loss": eval_loss,
        "metrics": evaluation_scheme.compute_metrics_from_accumulator(
            task=task, accumulator=eval_accumulator, labels=val_labels, tokenizer=tokenizer,
        ),
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task, accumulator=eval_accumulator,
        )
    return output
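
A hedged sketch of calling run_val and inspecting its output; the inputs are assumed to be prepared as in Example 8, and the metrics object follows the jiant convention visible in Example 7 (a .minor dict of per-metric values):

val_results = run_val(
    val_dataloader=val_dataloader,
    val_labels=val_labels,
    jiant_model=jiant_model,
    task=task,
    device=device,
    local_rank=-1,
    return_preds=True,
)
print("val loss:", val_results["loss"])
print("metrics:", val_results["metrics"].minor)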
Example 3
def load_state_dict_for_jiant_model_with_adapters(jiant_model: JiantModel,
                                                  state_dict: Dict):
    """Load state_dict into jiant model, allowing for missing_keys (untrained encoder)

    The checks aren't very rigorous
    """
    mismatched = jiant_model.load_state_dict(state_dict, strict=False)
    # Encoder weights are expected to be missing from an adapter-only checkpoint,
    # but the checkpoint should not contain keys the model does not recognize.
    assert mismatched.missing_keys
    assert not mismatched.unexpected_keys
Example 4
def get_optimized_state_dict_for_jiant_model_with_adapters(
        jiant_model: JiantModel):
    """Get the state_dict for relevant weights for a JiantModel with adapters

    Basically, the tensors for the adapters and taskmodel heads
    """
    dropped = []
    kept = {}
    for name, tensor in jiant_model.state_dict().items():
        if name.startswith("encoder.") and "adapter" not in name:
            dropped.append(name)
        elif ".encoder." in name and "adapter" not in name:
            # Do not keep the taskmodel encoder weights UNLESS they are adapter modules
            dropped.append(name)
        else:
            # This should include adapters and taskmodel heads
            kept[name] = tensor
    return kept, dropped
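
Examples 3 and 4 are meant to be used together: persist only the adapter and head weights, then restore them into a freshly built model whose encoder weights were loaded separately. A sketch of that round trip (the checkpoint path is hypothetical):

kept, dropped = get_optimized_state_dict_for_jiant_model_with_adapters(jiant_model)
torch.save(kept, "adapters_checkpoint.p")  # hypothetical path

# Later, after rebuilding jiant_model and loading its encoder weights:
state_dict = torch.load("adapters_checkpoint.p")
load_state_dict_for_jiant_model_with_adapters(jiant_model, state_dict=state_dict)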
Example 5
def get_optimized_named_parameters_for_jiant_model_with_adapters(
        jiant_model: JiantModel):
    """Does a couple things:
    1. Finds the adapter parameters and taskmodel heads (the only params to be optimized)
    2. Sets the other parameters to not require gradients
    """
    set_to_no_grad_list = []
    optimized_named_parameters = []
    for name, param in jiant_model.named_parameters():
        if name.startswith("encoder.") and "adapter" not in name:
            # Do not optimize the shared encoder
            torch_utils.set_requires_grad_single(param, requires_grad=False)
            set_to_no_grad_list.append(name)
        elif ".encoder." in name and "adapter" not in name:
            # Do not optimize the taskmodel encoder weights UNLESS they are adapter modules
            # Strictly speaking, this probably isn't necessary because .named_parameters()
            # doesn't return duplicates, but better to be safe.
            torch_utils.set_requires_grad_single(param, requires_grad=False)
            set_to_no_grad_list.append(name)
        else:
            # This should include adapters and taskmodel heads
            optimized_named_parameters.append((name, param))

    return optimized_named_parameters, set_to_no_grad_list
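
The returned (name, parameter) pairs are what gets handed to the optimizer; a minimal sketch, assuming a standard AdamW setup with a hypothetical learning rate:

optimized_named_parameters, frozen_names = (
    get_optimized_named_parameters_for_jiant_model_with_adapters(jiant_model)
)
optimizer = torch.optim.AdamW(
    [param for _, param in optimized_named_parameters],
    lr=1e-4,  # hypothetical; tune per task
)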
Example 6
def run_test(
    test_dataloader,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    verbose=True,
    return_preds=True,
    return_logits=True,
    return_encoder_output: bool = False,
):
    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()

    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(test_dataloader,
                       desc=f"Eval ({task.name}, Test)",
                       verbose=verbose)):
        regular_log(logger, step, interval=10, tag='test')

        batch = batch.to(device)

        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=False,
                get_encoder_output=return_encoder_output,
            )
            if return_encoder_output:
                model_output, encoder_output = model_outputs
                encoder_outputs.append(encoder_output)
            else:
                model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=0,
            batch=batch,
            batch_metadata=batch_metadata,
        )
    output = {
        "accumulator": eval_accumulator,
    }
    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
        if isinstance(eval_accumulator,
                      evaluate.ConcatenateLogitsAccumulator) and return_logits:
            output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    return output
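
When return_encoder_output=True, the pooled and unpooled encoder representations are concatenated across batches. A sketch of reading them back out; the shapes in the comments are assumptions based on typical transformer encoders:

test_output = run_test(
    test_dataloader=test_dataloader,
    jiant_model=jiant_model,
    task=task,
    device=device,
    local_rank=-1,
    return_encoder_output=True,
)
pooled = test_output["encoder_outputs_pooled"]      # assumed (num_examples, hidden_size)
unpooled = test_output["encoder_outputs_unpooled"]  # assumed (num_examples, max_seq_len, hidden_size)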
Example 7
def run_val(
    val_dataloader,
    val_labels,
    jiant_model: JiantModel,
    task,
    device,
    local_rank,
    tf_writer: SummaryWriter,
    global_step: Optional[int] = None,
    phase=None,
    return_preds=False,
    return_logits=True,
    return_encoder_output: bool = False,
    verbose=True,
    split='valid',
):
    # Reminder:
    #   val_dataloader contains mostly PyTorch-relevant info
    #   val_labels might contain more detailed information needed for full evaluation
    has_labels = True  # TODO: automatically determine whether the dataset has labels.

    if local_rank != -1:
        return
    jiant_model.eval()
    evaluation_scheme = evaluate.get_evaluation_scheme_for_task(task=task)
    eval_accumulator = evaluation_scheme.get_accumulator()
    total_eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0

    encoder_outputs = []
    for step, (batch, batch_metadata) in enumerate(
            maybe_tqdm(val_dataloader,
                       desc=f"Eval ({task.name}, {str(phase)})",
                       verbose=verbose)):
        regular_log(logger, step, interval=10, tag=split)

        batch = batch.to(device)

        with torch.no_grad():
            model_outputs = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=has_labels,
                get_encoder_output=return_encoder_output,
            )
            if return_encoder_output:
                model_output, encoder_output = model_outputs
                encoder_outputs.append(encoder_output)
            else:
                model_output = model_outputs
        batch_logits = model_output.logits.detach().cpu().numpy()
        if has_labels:
            batch_loss = model_output.loss.mean().item()
        else:
            batch_loss = 0
        total_eval_loss += batch_loss
        eval_accumulator.update(
            batch_logits=batch_logits,
            batch_loss=batch_loss,
            batch=batch,
            batch_metadata=batch_metadata,
        )

        nb_eval_examples += len(batch)
        nb_eval_steps += 1

    eval_loss = total_eval_loss / nb_eval_steps
    output = {
        "accumulator": eval_accumulator,
    }

    if has_labels:
        tokenizer = (jiant_model.tokenizer
                     if not torch_utils.is_data_parallel(jiant_model) else
                     jiant_model.module.tokenizer)
        metrics = evaluation_scheme.compute_metrics_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
            labels=val_labels,
            tokenizer=tokenizer,
        )

        output.update({
            "loss": eval_loss,
            "metrics": metrics,
        })

        if global_step is not None:
            for metric_name, metric_value in metrics.minor.items():
                tf_writer.add_scalar(f'{split}/{metric_name}',
                                     metric_value,
                                     global_step=global_step)

    if return_preds:
        output["preds"] = evaluation_scheme.get_preds_from_accumulator(
            task=task,
            accumulator=eval_accumulator,
        )
        if isinstance(eval_accumulator,
                      evaluate.ConcatenateLogitsAccumulator) and return_logits:
            output["logits"] = eval_accumulator.get_accumulated()
    if return_encoder_output:
        output["encoder_outputs_pooled"] = np.concatenate(
            [encoder_output.pooled for encoder_output in encoder_outputs])
        output["encoder_outputs_unpooled"] = np.concatenate(
            [encoder_output.unpooled for encoder_output in encoder_outputs])
    if global_step is not None:
        tf_writer.add_scalar(f'{split}/loss',
                             eval_loss,
                             global_step=global_step)

    tf_writer.flush()
    return output
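
This variant also logs validation scalars to TensorBoard. A minimal sketch of the extra wiring, assuming torch.utils.tensorboard and a hypothetical log directory:

from torch.utils.tensorboard import SummaryWriter

tf_writer = SummaryWriter(log_dir="runs/val")  # hypothetical log dir
val_results = run_val(
    val_dataloader=val_dataloader,
    val_labels=val_labels,
    jiant_model=jiant_model,
    task=task,
    device=device,
    local_rank=-1,
    tf_writer=tf_writer,
    global_step=1000,  # scalars are only written when global_step is not None
    phase="valid",
)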
Example 8
def main(args: RunConfiguration):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # === Shared model components setup === #
    model_type = "roberta-base"
    model_arch = ModelArchitectures.from_model_type(model_type=model_type)
    transformers_class_spec = model_setup.TRANSFORMERS_CLASS_SPEC_DICT[
        model_arch]
    ancestor_model = model_setup.get_ancestor_model(
        transformers_class_spec=transformers_class_spec,
        model_config_path=args.model_config_path,
    )
    encoder = model_setup.get_encoder(
        model_arch=model_arch,
        ancestor_model=ancestor_model,
    )
    tokenizer = shared_model_setup.get_tokenizer(
        model_type=model_type,
        tokenizer_path=args.model_tokenizer_path,
    )

    # === Taskmodels setup === #
    task_dict = {
        "mnli":
        tasks.create_task_from_config_path(
            os.path.join(
                args.task_config_base_path,
                "mnli.json",
            )),
        "qnli":
        tasks.create_task_from_config_path(
            os.path.join(
                args.task_config_base_path,
                "qnli.json",
            )),
        "rte":
        tasks.create_task_from_config_path(
            os.path.join(
                args.task_config_base_path,
                "qnli.json",
            ))
    }
    taskmodels_dict = {
        "nli":
        taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["mnli"].LABELS),
            ),
        ),
        "rte":
        taskmodels.ClassificationModel(
            encoder=encoder,
            classification_head=heads.ClassificationHead(
                hidden_size=encoder.config.hidden_size,
                hidden_dropout_prob=encoder.config.hidden_dropout_prob,
                num_labels=len(task_dict["rte"].LABELS),
            ),
        ),
    }
    # mnli and qnli share the "nli" taskmodel; rte gets its own head.
    task_to_taskmodel_map = {
        "mnli": "nli",
        "qnli": "nli",
        "rte": "rte",
    }

    # === Final === #
    jiant_model = JiantModel(
        task_dict=task_dict,
        encoder=encoder,
        taskmodels_dict=taskmodels_dict,
        task_to_taskmodel_map=task_to_taskmodel_map,
        tokenizer=tokenizer,
    )
    jiant_model = jiant_model.to(device)

    # === Run === #
    task_dataloader_dict = {}
    for task_name, task in task_dict.items():
        train_cache = caching.ChunkedFilesDataCache(
            cache_fol_path=os.path.join(args.task_cache_base_path, task_name,
                                        "train"), )
        train_dataset = train_cache.get_iterable_dataset(buffer_size=10000,
                                                         shuffle=True)
        train_dataloader = torch_utils.DataLoaderWithLength(
            dataset=train_dataset,
            batch_size=4,
            collate_fn=task.collate_fn,
        )
        task_dataloader_dict[task_name] = train_dataloader

    # Smoke test: run one batch per task through the model to confirm the multi-task wiring.
    for task_name, task in task_dict.items():
        batch, batch_metadata = next(iter(task_dataloader_dict[task_name]))
        batch = batch.to(device)
        with torch.no_grad():
            model_output = wrap_jiant_forward(
                jiant_model=jiant_model,
                batch=batch,
                task=task,
                compute_loss=True,
            )
        print(task_name)
        print(model_output)
        print()