Example #1
def test_logging_manager_no_checkpointing(caplog):
    """Unit test of logging_manager (no checkpointing)"""

    caplog.set_level(logging.INFO)

    emmental.init()
    Meta.update_config(
        config={
            "meta_config": {
                "verbose": False
            },
            "logging_config": {
                "counter_unit": "epoch",
                "evaluation_freq": 1,
                "checkpointing": False,
                "checkpointer_config": {
                    "checkpoint_freq": 2
                },
                "writer_config": {
                    "writer": "json"
                },
            },
        })

    logging_manager = LoggingManager(n_batches_per_epoch=2)

    logging_manager.update(5)
    assert logging_manager.trigger_evaluation() is False
    assert logging_manager.trigger_checkpointing() is False

    logging_manager.update(5)
    assert logging_manager.trigger_evaluation() is True
    assert logging_manager.trigger_checkpointing() is False

    logging_manager.update(10)
    assert logging_manager.trigger_evaluation() is False
    assert logging_manager.trigger_checkpointing() is False

    logging_manager.update(5)
    assert logging_manager.trigger_evaluation() is True
    assert logging_manager.trigger_checkpointing() is False

    assert logging_manager.epoch_count == 0

    assert logging_manager.sample_total == 25
    assert logging_manager.batch_total == 4
    assert logging_manager.epoch_total == 2

    model = EmmentalModel()

    logging_manager.close(model)
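
For reference, the counters asserted above follow directly from the configuration; a minimal sketch of the assumed bookkeeping (not LoggingManager's actual implementation):

# Each update(batch_size) is assumed to advance the counters as follows:
n_batches_per_epoch = 2
batch_sizes = [5, 5, 10, 5]                       # the four update() calls above

sample_total = sum(batch_sizes)                   # 25
batch_total = len(batch_sizes)                    # 4
epoch_total = batch_total / n_batches_per_epoch   # 2.0

# With counter_unit="epoch" and evaluation_freq=1, evaluation triggers whenever a
# full epoch completes, i.e. after the 2nd and 4th updates, matching the asserts.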
Example #2
def predict_and_write(task_name, path, data_dir, submit_subdir, batch_size):
    bert_model_name, max_seq_len = extract_from_cmd(path)
    msg = (
        f"Using {bert_model_name} and max_sequence_len={max_seq_len} for task "
        f"{task_name}")
    logger.info(msg)

    # Build model
    task = build_model[task_name](bert_model_name)
    model = EmmentalModel(name=f"SuperGLUE_{task_name}", tasks=[task])
    try:
        model.load(path)
    except UnboundLocalError:
        msg = (
            "Failed to load state dict; confirm that your model was saved with "
            "a command such as 'torch.save(model.state_dict(), PATH)'")
        logging.error(msg)
        raise

    # Build dataloaders
    dataloaders = get_dataloaders(
        data_dir,
        task_name=task_name,
        splits=["val",
                "test"],  # TODO: replace with ['split'] and update below
        max_data_samples=None,
        max_sequence_length=max_seq_len,
        tokenizer_name=bert_model_name,
        batch_size=batch_size,
        uid="uids",
    )
    # TEMP: Sanity check val performance
    logger.info(f"Valid score: {model.score(dataloaders[0])}")
    # TEMP

    filename = f"{task_name}.jsonl"
    filepath = os.path.join(submit_subdir, filename)
    make_submission_file(model, dataloaders[-1], task_name, filepath)
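
A hypothetical invocation of this helper (all values below are placeholders; the checkpoint path is assumed to encode the BERT model name and max sequence length so that extract_from_cmd can recover them):

predict_and_write(
    task_name="CB",                    # placeholder SuperGLUE task name
    path="logs/CB/best_model.pth",     # placeholder checkpoint path
    data_dir="data",
    submit_subdir="submission",
    batch_size=16,
)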
Example #3
    def _set_optimizer(self, model: EmmentalModel) -> None:
        r"""Set optimizer for learning process.

        Args:
          model(EmmentalModel): The model to set up the optimizer.

        """
        optimizer_config = Meta.config["learner_config"]["optimizer_config"]
        opt = optimizer_config["optimizer"]

        parameters = filter(lambda p: p.requires_grad, model.parameters())

        optim_dict = {
            # PyTorch optimizers
            "asgd": optim.ASGD,  # type: ignore
            "adadelta": optim.Adadelta,  # type: ignore
            "adagrad": optim.Adagrad,  # type: ignore
            "adam": optim.Adam,  # type: ignore
            "adamw": optim.AdamW,  # type: ignore
            "adamax": optim.Adamax,  # type: ignore
            "lbfgs": optim.LBFGS,  # type: ignore
            "rms_prop": optim.RMSprop,  # type: ignore
            "r_prop": optim.Rprop,  # type: ignore
            "sgd": optim.SGD,  # type: ignore
            "sparse_adam": optim.SparseAdam,  # type: ignore
            # Custom optimizers
            "bert_adam": BertAdam,
        }

        if opt in ["lbfgs", "r_prop", "sparse_adam"]:
            optimizer = optim_dict[opt](
                parameters,
                lr=optimizer_config["lr"],
                **optimizer_config[f"{opt}_config"],
            )
        elif opt in optim_dict.keys():
            optimizer = optim_dict[opt](
                parameters,
                lr=optimizer_config["lr"],
                weight_decay=optimizer_config["l2"],
                **optimizer_config[f"{opt}_config"],
            )
        elif isinstance(opt, optim.Optimizer):  # type: ignore
            optimizer = opt(parameters)
        else:
            raise ValueError(f"Unrecognized optimizer option '{opt}'")

        self.optimizer = optimizer

        logger.info(f"Using optimizer {self.optimizer}")
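
The config layout this method expects can be read off the lookups above: an "optimizer" name, "lr", "l2" (forwarded as weight_decay), and a per-optimizer "<opt>_config" dict. A minimal sketch with illustrative values (not Emmental's actual defaults):

config = {
    "learner_config": {
        "optimizer_config": {
            "optimizer": "adamw",    # any key of optim_dict above
            "lr": 1e-3,
            "l2": 0.0,               # forwarded as weight_decay
            "adamw_config": {},      # extra kwargs passed straight to the optimizer
        }
    }
}
emmental.Meta.update_config(config)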
Example #4
    def _set_optimizer(self, model: EmmentalModel) -> None:
        r"""Set optimizer for learning process.

        Args:
          model(EmmentalModel): The model to set up the optimizer.

        """

        # TODO: add more optimizer support and fp16
        optimizer_config = Meta.config["learner_config"]["optimizer_config"]
        opt = optimizer_config["optimizer"]

        parameters = filter(lambda p: p.requires_grad, model.parameters())

        if opt == "sgd":
            optimizer = optim.SGD(
                parameters,
                lr=optimizer_config["lr"],
                weight_decay=optimizer_config["l2"],
                **optimizer_config["sgd_config"],
            )
        elif opt == "adam":
            optimizer = optim.Adam(
                parameters,
                lr=optimizer_config["lr"],
                weight_decay=optimizer_config["l2"],
                **optimizer_config["adam_config"],
            )
        elif opt == "adamax":
            optimizer = optim.Adamax(
                parameters,
                lr=optimizer_config["lr"],
                weight_decay=optimizer_config["l2"],
                **optimizer_config["adamax_config"],
            )
        elif opt == "bert_adam":
            optimizer = BertAdam(
                parameters,
                lr=optimizer_config["lr"],
                weight_decay=optimizer_config["l2"],
                **optimizer_config["bert_adam_config"],
            )
        else:
            raise ValueError(f"Unrecognized optimizer option '{opt}'")

        logger.info(f"Using optimizer {optimizer}")

        self.optimizer = optimizer
Example #5
    def collect_state_dict(
        self,
        iteration: Union[float, int],
        model: EmmentalModel,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler,
        metric_dict: Dict[str, float],
    ) -> Dict[str, Any]:
        r"""Collect the state dict of the model.

        Args:
          iteration(float or int): The current iteration.
          model(EmmentalModel): The model to checkpoint.
          optimizer(Optimizer): The optimizer used during training process.
          lr_scheduler(_LRScheduler): Learning rate scheduler.
          metric_dict(dict): the metric dict.

        Returns:
          dict: The state dict.
        """

        model_params = {
            "name": model.name,
            "module_pool": model.collect_state_dict(),
            # "task_names": model.task_names,
            # "task_flows": model.task_flows,
            # "loss_funcs": model.loss_funcs,
            # "output_funcs": model.output_funcs,
            # "scorers": model.scorers,
        }

        state_dict = {
            "iteration": iteration,
            "model": model_params,
            "optimizer": optimizer.state_dict(),
            "lr_scheduler":
            lr_scheduler.state_dict() if lr_scheduler else None,
            "metric_dict": metric_dict,
        }

        return state_dict
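
A hedged sketch of reading such a checkpoint back with plain PyTorch (the file name is a placeholder, and the optimizer/lr_scheduler objects are assumed to already exist; in practice Emmental's Checkpointer manages this restore path):

import torch

checkpoint = torch.load("checkpoint.pth", map_location="cpu")  # placeholder path

print(checkpoint["iteration"], checkpoint["metric_dict"])
# The module weights collected from the model live under checkpoint["model"]["module_pool"].

optimizer.load_state_dict(checkpoint["optimizer"])
if checkpoint["lr_scheduler"] is not None:
    lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])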
Example #6
def test_predict(mocker, setup_common_components: Dict):
    """Test if a Fonduer model can predict."""
    kwargs = setup_common_components
    featurizer = Featurizer(None, [PartTemp])
    # Mock the get_keys()
    featurizer.get_keys = MagicMock(return_value=[FeatureKey(name="key1")])
    emmental.meta.init_config()

    # Log the model with FonduerModel()
    log_model(
        FonduerModel(),
        artifact_path,
        **kwargs,
        code_paths=[
            "tests"
        ],  # pass a directory name to preserve the directory hierarchy
        featurizer=featurizer,
        emmental_model=EmmentalModel(),
        word2id={"foo": 1},
    )
    # Load the model
    fonduer_model = mlflow.pyfunc.load_model(
        os.path.join(mlflow.active_run().info.artifact_uri, artifact_path)
    )
    with pytest.raises(NotImplementedError):
        _ = fonduer_model.predict(
            pd.DataFrame(data={"html_path": ["tests/data/html/112823.html"]})
        )

    # Log the model with HardwareFonduerModel()
    log_model(
        HardwareFonduerModel(),
        artifact_path,
        **kwargs,
        code_paths=[
            "tests"
        ],  # pass a directory name to preserve the directory hierarchy
        featurizer=featurizer,
        emmental_model=EmmentalModel(),
        word2id={"foo": 1},
    )
    # Load the model
    fonduer_model = mlflow.pyfunc.load_model(
        os.path.join(mlflow.active_run().info.artifact_uri, artifact_path)
    )

    # Mock the _classify as we don't test the implementation of _classify here.
    mock_output = pd.DataFrame(data={"col1": ["val1"], "col2": ["val2"]})
    fonduer_model._classify = MagicMock(return_value=mock_output)

    # Input both html_path and pdf_html
    spy = mocker.spy(fonduer_model, "_process")
    output = fonduer_model.predict(
        pd.DataFrame(
            data={
                "html_path": ["tests/data/html/112823.html"],
                "pdf_path": ["tests/data/pdf/112823.pdf"],
            }
        )
    )
    spy.assert_called_once_with(
        "tests/data/html/112823.html", "tests/data/pdf/112823.pdf"
    )
    assert output.equals(
        pd.DataFrame(
            data={
                "col1": ["val1"],
                "col2": ["val2"],
                "html_path": ["tests/data/html/112823.html"],
            }
        )
    )

    # Input only html_path
    spy.reset_mock()
    output = fonduer_model.predict(
        pd.DataFrame(data={"html_path": ["tests/data/html/112823.html"]})
    )
    spy.assert_called_once_with("tests/data/html/112823.html", None)
    assert output.equals(
        pd.DataFrame(
            data={
                "col1": ["val1"],
                "col2": ["val2"],
                "html_path": ["tests/data/html/112823.html"],
            }
        )
    )

    # Input html_path that does not exist
    spy.reset_mock()

    with pytest.raises(ValueError):
        _ = fonduer_model.predict(
            pd.DataFrame(data={"html_path": ["tests/data/html/foo.html"]})
        )

    # Test when _classify produces multiple relations per doc.
    mock_output = pd.DataFrame(data={"col0": ["00", "10"], "col1": ["01", "11"]})
    fonduer_model._classify = MagicMock(return_value=mock_output)
    output = fonduer_model.predict(
        pd.DataFrame(data={"html_path": ["tests/data/html/112823.html"]})
    )
    assert output.equals(
        pd.DataFrame(
            data={
                "col0": ["00", "10"],
                "col1": ["01", "11"],
                "html_path": [
                    "tests/data/html/112823.html",
                    "tests/data/html/112823.html",
                ],
            }
        )
    )
Example #7
def test_model(caplog):
    """Unit test of model."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_model"

    Meta.reset()
    emmental.init(dirpath)

    def ce_loss(module_name, immediate_output_dict, Y, active):
        return F.cross_entropy(immediate_output_dict[module_name][0][active],
                               (Y.view(-1))[active])

    def output(module_name, immediate_output_dict):
        return F.softmax(immediate_output_dict[module_name][0], dim=1)

    task1 = EmmentalTask(
        name="task_1",
        module_pool=nn.ModuleDict({
            "m1": nn.Linear(10, 10, bias=False),
            "m2": nn.Linear(10, 2, bias=False)
        }),
        task_flow=[
            {
                "name": "m1",
                "module": "m1",
                "inputs": [("_input_", "data")]
            },
            {
                "name": "m2",
                "module": "m2",
                "inputs": [("m1", 0)]
            },
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    new_task1 = EmmentalTask(
        name="task_1",
        module_pool=nn.ModuleDict({
            "m1": nn.Linear(10, 5, bias=False),
            "m2": nn.Linear(5, 2, bias=False)
        }),
        task_flow=[
            {
                "name": "m1",
                "module": "m1",
                "inputs": [("_input_", "data")]
            },
            {
                "name": "m2",
                "module": "m2",
                "inputs": [("m1", 0)]
            },
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    task2 = EmmentalTask(
        name="task_2",
        module_pool=nn.ModuleDict({
            "m1": nn.Linear(10, 5, bias=False),
            "m2": nn.Linear(5, 2, bias=False)
        }),
        task_flow=[
            {
                "name": "m1",
                "module": "m1",
                "inputs": [("_input_", "data")]
            },
            {
                "name": "m2",
                "module": "m2",
                "inputs": [("m1", 0)]
            },
        ],
        loss_func=partial(ce_loss, "m2"),
        output_func=partial(output, "m2"),
        scorer=Scorer(metrics=["accuracy"]),
    )

    config = {"model_config": {"dataparallel": False}}
    emmental.Meta.update_config(config)

    model = EmmentalModel(name="test", tasks=task1)

    assert repr(model) == "EmmentalModel(name=test)"
    assert model.name == "test"
    assert model.task_names == set(["task_1"])
    assert model.module_pool["m1"].weight.data.size() == (10, 10)
    assert model.module_pool["m2"].weight.data.size() == (2, 10)

    model.update_task(new_task1)

    assert model.module_pool["m1"].weight.data.size() == (5, 10)
    assert model.module_pool["m2"].weight.data.size() == (2, 5)

    model.update_task(task2)

    assert model.task_names == set(["task_1"])

    model.add_task(task2)

    assert model.task_names == set(["task_1", "task_2"])

    model.remove_task("task_1")
    assert model.task_names == set(["task_2"])

    model.remove_task("task_1")
    assert model.task_names == set(["task_2"])

    model.save(f"{dirpath}/saved_model.pth")

    model.load(f"{dirpath}/saved_model.pth")

    # Test add_tasks
    model = EmmentalModel(name="test")

    model.add_tasks([task1, task2])
    assert model.task_names == set(["task_1", "task_2"])

    shutil.rmtree(dirpath)
Example #8
def test_e2e(caplog):
    """Run an end-to-end test."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_e2e"
    use_exact_log_path = False
    Meta.reset()
    emmental.init(dirpath, use_exact_log_path=use_exact_log_path)

    config = {
        "meta_config": {
            "seed": 0
        },
        "learner_config": {
            "n_epochs": 3,
            "optimizer_config": {
                "lr": 0.01,
                "grad_clip": 100
            },
        },
        "logging_config": {
            "counter_unit": "epoch",
            "evaluation_freq": 1,
            "writer_config": {
                "writer": "tensorboard",
                "verbose": True
            },
            "checkpointing": True,
            "checkpointer_config": {
                "checkpoint_path": None,
                "checkpoint_freq": 1,
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_task_metrics": None,
                "checkpoint_runway": 1,
                "checkpoint_all": False,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config)

    # Generate synthetic data
    N = 500
    X = np.random.random((N, 2)) * 2 - 1
    Y1 = (X[:, 0] > X[:, 1] + 0.25).astype(int)
    Y2 = (X[:, 0] > X[:, 1] + 0.2).astype(int)

    X = [torch.Tensor(X[i]) for i in range(N)]
    # Create dataset and dataloader

    X_train, X_dev, X_test = (
        X[:int(0.8 * N)],
        X[int(0.8 * N):int(0.9 * N)],
        X[int(0.9 * N):],
    )
    Y1_train, Y1_dev, Y1_test = (
        torch.tensor(Y1[:int(0.8 * N)]),
        torch.tensor(Y1[int(0.8 * N):int(0.9 * N)]),
        torch.tensor(Y1[int(0.9 * N):]),
    )
    Y2_train, Y2_dev, Y2_test = (
        torch.tensor(Y2[:int(0.8 * N)]),
        torch.tensor(Y2[int(0.8 * N):int(0.9 * N)]),
        torch.tensor(Y2[int(0.9 * N):]),
    )

    train_dataset1 = EmmentalDataset(name="synthetic",
                                     X_dict={"data": X_train},
                                     Y_dict={"label1": Y1_train})

    train_dataset2 = EmmentalDataset(name="synthetic",
                                     X_dict={"data": X_train},
                                     Y_dict={"label2": Y2_train})

    dev_dataset1 = EmmentalDataset(name="synthetic",
                                   X_dict={"data": X_dev},
                                   Y_dict={"label1": Y1_dev})

    dev_dataset2 = EmmentalDataset(name="synthetic",
                                   X_dict={"data": X_dev},
                                   Y_dict={"label2": Y2_dev})

    test_dataset1 = EmmentalDataset(name="synthetic",
                                    X_dict={"data": X_test},
                                    Y_dict={"label1": Y1_test})

    test_dataset2 = EmmentalDataset(name="synthetic",
                                    X_dict={"data": X_test},
                                    Y_dict={"label2": Y2_test})

    task_to_label_dict = {"task1": "label1"}

    train_dataloader1 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=train_dataset1,
        split="train",
        batch_size=10,
    )
    dev_dataloader1 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=dev_dataset1,
        split="valid",
        batch_size=10,
    )
    test_dataloader1 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=test_dataset1,
        split="test",
        batch_size=10,
    )

    task_to_label_dict = {"task2": "label2"}

    train_dataloader2 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=train_dataset2,
        split="train",
        batch_size=10,
    )
    dev_dataloader2 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=dev_dataset2,
        split="valid",
        batch_size=10,
    )
    test_dataloader2 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=test_dataset2,
        split="test",
        batch_size=10,
    )

    # Create task
    def ce_loss(task_name, immediate_output_dict, Y, active):
        module_name = f"{task_name}_pred_head"
        return F.cross_entropy(immediate_output_dict[module_name][0][active],
                               (Y.view(-1))[active])

    def output(task_name, immediate_output_dict):
        module_name = f"{task_name}_pred_head"
        return F.softmax(immediate_output_dict[module_name][0], dim=1)

    task_metrics = {"task1": ["accuracy"], "task2": ["accuracy", "roc_auc"]}

    tasks = [
        EmmentalTask(
            name=task_name,
            module_pool=nn.ModuleDict({
                "input_module":
                nn.Linear(2, 8),
                f"{task_name}_pred_head":
                nn.Linear(8, 2),
            }),
            task_flow=[
                {
                    "name": "input",
                    "module": "input_module",
                    "inputs": [("_input_", "data")],
                },
                {
                    "name": f"{task_name}_pred_head",
                    "module": f"{task_name}_pred_head",
                    "inputs": [("input", 0)],
                },
            ],
            loss_func=partial(ce_loss, task_name),
            output_func=partial(output, task_name),
            scorer=Scorer(metrics=task_metrics[task_name]),
        ) for task_name in ["task1", "task2"]
    ]

    # Build model

    mtl_model = EmmentalModel(name="all", tasks=tasks)

    # Create learner
    emmental_learner = EmmentalLearner()

    # Learning
    emmental_learner.learn(
        mtl_model,
        [
            train_dataloader1, train_dataloader2, dev_dataloader1,
            dev_dataloader2
        ],
    )

    test1_score = mtl_model.score(test_dataloader1)
    test2_score = mtl_model.score(test_dataloader2)

    assert test1_score["task1/synthetic/test/accuracy"] >= 0.7
    assert (test1_score["model/all/test/macro_average"] ==
            test1_score["task1/synthetic/test/accuracy"])
    assert test2_score["task2/synthetic/test/accuracy"] >= 0.7
    assert test2_score["task2/synthetic/test/roc_auc"] >= 0.7

    shutil.rmtree(dirpath)
Example #9
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]

    emmental.Meta.config["learner_config"]["optimizer_config"][
        "parameters"] = grouped_parameters

    # Create tasks
    model = EmmentalModel(name="TACRED_task")
    model.add_task(create_task(args))

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # Remove all extra augmentation policy
    for idx in range(len(dataloaders)):
        dataloaders[idx].dataset.transform_cls = None

    scores = model.score(dataloaders)
Example #10
def run_model(mode, config, run_config_path=None):
    """
    Main run method for Emmental Bootleg models.
    Args:
        mode: run mode (train, eval, dump_preds, dump_embs)
        config: parsed model config
        run_config_path: original config path (for saving)

    Returns:
        Scores dict in train/eval mode, or (final_result_file, final_out_emb_file)
        in dump_preds/dump_embs mode.
    """

    # Set up distributed backend and save configuration files
    setup(config, run_config_path)

    # Load entity symbols
    log_rank_0_info(logger, f"Loading entity symbols...")
    entity_symbols = EntitySymbols.load_from_cache(
        load_dir=os.path.join(config.data_config.entity_dir,
                              config.data_config.entity_map_dir),
        alias_cand_map_file=config.data_config.alias_cand_map,
        alias_idx_file=config.data_config.alias_idx_map,
    )
    # Create tasks
    tasks = [NED_TASK]
    if config.data_config.type_prediction.use_type_pred is True:
        tasks.append(TYPE_PRED_TASK)

    # Create splits for data loaders
    data_splits = [TRAIN_SPLIT, DEV_SPLIT, TEST_SPLIT]
    # Slices are for eval so we only split on test/dev
    slice_splits = [DEV_SPLIT, TEST_SPLIT]
    # If doing eval, only run on test data
    if mode in ["eval", "dump_preds", "dump_embs"]:
        data_splits = [TEST_SPLIT]
        slice_splits = [TEST_SPLIT]
        # We only do dumping if weak labels is True
        if mode in ["dump_preds", "dump_embs"]:
            if config.data_config[
                    f"{TEST_SPLIT}_dataset"].use_weak_label is False:
                raise ValueError(
                    f"When calling dump_preds or dump_embs, we require use_weak_label to be True."
                )

    # Gets embeddings that need to be prepped during data prep or in the __getitem__ method
    batch_on_the_fly_kg_adj = get_dataloader_embeddings(config, entity_symbols)
    # Gets dataloaders
    dataloaders = get_dataloaders(
        config,
        tasks,
        data_splits,
        entity_symbols,
        batch_on_the_fly_kg_adj,
    )
    slice_datasets = get_slicedatasets(config, slice_splits, entity_symbols)

    configure_optimizer(config)

    # Create models and add tasks
    if config.model_config.attn_class == "BERTNED":
        log_rank_0_info(logger, f"Starting NED-Base Model")
        assert (config.data_config.type_prediction.use_type_pred is
                False), f"NED-Base does not support type prediction"
        assert (
            config.data_config.word_embedding.use_sent_proj is False
        ), f"NED-Base requires word_embeddings.use_sent_proj to be False"
        model = EmmentalModel(name="NED-Base")
        model.add_tasks(
            ned_task.create_task(config, entity_symbols, slice_datasets))
    else:
        log_rank_0_info(logger, f"Starting Bootleg Model")
        model = EmmentalModel(name="Bootleg")
        # TODO: make this more general for other tasks -- iterate through list of tasks
        # and add task for each
        model.add_task(
            ned_task.create_task(config, entity_symbols, slice_datasets))
        if TYPE_PRED_TASK in tasks:
            model.add_task(
                type_pred_task.create_task(config, entity_symbols,
                                           slice_datasets))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(model)

    # Print param counts
    if mode == "train":
        log_rank_0_debug(logger, "PARAMS WITH GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=True,
                                        logger=logger)
        log_rank_0_info(logger, f"===> Total Params With Grad: {total_params}")
        log_rank_0_debug(logger, "PARAMS WITHOUT GRAD\n" + "=" * 30)
        total_params = count_parameters(model,
                                        requires_grad=False,
                                        logger=logger)
        log_rank_0_info(logger,
                        f"===> Total Params Without Grad: {total_params}")

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    # Barrier
    if config["learner_config"]["local_rank"] == 0:
        torch.distributed.barrier()

    # Train model
    if mode == "train":
        emmental_learner = EmmentalLearner()
        emmental_learner._set_optimizer(model)
        emmental_learner.learn(model, dataloaders)
        if config.learner_config.local_rank in [0, -1]:
            model.save(f"{emmental.Meta.log_path}/last_model.pth")

    # Multi-gpu DataParallel eval (NOT distributed)
    if mode in ["eval", "dump_embs", "dump_preds"]:
        # This happens inside EmmentalLearner for training
        if (config["learner_config"]["local_rank"] == -1
                and config["model_config"]["dataparallel"]):
            model._to_dataparallel()

    # If just finished training a model or in eval mode, run eval
    if mode in ["train", "eval"]:
        scores = model.score(dataloaders)
        # Save metrics and models
        log_rank_0_info(logger, f"Saving metrics to {emmental.Meta.log_path}")
        log_rank_0_info(logger, f"Metrics: {scores}")
        scores["log_path"] = emmental.Meta.log_path
        if config.learner_config.local_rank in [0, -1]:
            write_to_file(f"{emmental.Meta.log_path}/{mode}_metrics.txt",
                          scores)
            eval_utils.write_disambig_metrics_to_csv(
                f"{emmental.Meta.log_path}/{mode}_disambig_metrics.csv",
                scores)
        return scores

    # If you want detailed dumps, save model outputs
    assert mode in [
        "dump_preds",
        "dump_embs",
    ], 'Mode must be "dump_preds" or "dump_embs"'
    dump_embs = mode == "dump_embs"
    assert (
        len(dataloaders) == 1
    ), f"We should only have length 1 dataloaders for dump_embs and dump_preds!"
    final_result_file, final_out_emb_file = None, None
    if config.learner_config.local_rank in [0, -1]:
        # Setup files/folders
        filename = os.path.basename(dataloaders[0].dataset.raw_filename)
        log_rank_0_debug(
            logger,
            f"Collecting sentence to mention map {os.path.join(config.data_config.data_dir, filename)}",
        )
        sentidx2num_mentions, sent_idx2row = eval_utils.get_sent_idx2num_mens(
            os.path.join(config.data_config.data_dir, filename))
        log_rank_0_debug(logger, f"Done collecting sentence to mention map")
        eval_folder = eval_utils.get_eval_folder(filename)
        subeval_folder = os.path.join(eval_folder, "batch_results")
        utils.ensure_dir(subeval_folder)
        # Will keep track of sentences dumped already. These will only be ones with mentions
        all_dumped_sentences = set()
        number_dumped_batches = 0
        total_mentions_seen = 0
        all_result_files = []
        all_out_emb_files = []
        # Iterating over batches of predictions
        for res_i, res_dict in enumerate(
                eval_utils.batched_pred_iter(
                    model,
                    dataloaders[0],
                    config.run_config.eval_accumulation_steps,
                    sentidx2num_mentions,
                )):
            (
                result_file,
                out_emb_file,
                final_sent_idxs,
                mentions_seen,
            ) = eval_utils.disambig_dump_preds(
                res_i,
                total_mentions_seen,
                config,
                res_dict,
                sentidx2num_mentions,
                sent_idx2row,
                subeval_folder,
                entity_symbols,
                dump_embs,
                NED_TASK,
            )
            all_dumped_sentences.update(final_sent_idxs)
            all_result_files.append(result_file)
            all_out_emb_files.append(out_emb_file)
            total_mentions_seen += mentions_seen
            number_dumped_batches += 1

        # Dump the sentences that had no mentions and were not already dumped
        # Assert all remaining sentences have no mentions
        assert all(
            v == 0 for k, v in sentidx2num_mentions.items()
            if k not in all_dumped_sentences
        ), (f"Sentences with mentions were not dumped: "
            f"{[k for k, v in sentidx2num_mentions.items() if k not in all_dumped_sentences]}"
            )
        empty_sentidx2row = {
            k: v
            for k, v in sent_idx2row.items() if k not in all_dumped_sentences
        }
        empty_resultfile = eval_utils.get_result_file(number_dumped_batches,
                                                      subeval_folder)
        all_result_files.append(empty_resultfile)
        # Dump the outputs
        eval_utils.write_data_labels_single(
            sentidx2row=empty_sentidx2row,
            output_file=empty_resultfile,
            filt_emb_data=None,
            sental2embid={},
            alias_cand_map=entity_symbols.get_alias2qids(),
            qid2eid=entity_symbols.get_qid2eid(),
            result_alias_offset=total_mentions_seen,
            train_in_cands=config.data_config.train_in_candidates,
            max_cands=entity_symbols.max_candidates,
            dump_embs=dump_embs,
        )

        log_rank_0_info(
            logger,
            f"Finished dumping. Merging results across accumulation steps.")
        # Final result files for labels and embeddings
        final_result_file = os.path.join(eval_folder,
                                         config.run_config.result_label_file)
        # Copy labels
        output = open(final_result_file, "wb")
        for file in all_result_files:
            shutil.copyfileobj(open(file, "rb"), output)
        output.close()
        log_rank_0_info(logger, f"Bootleg labels saved at {final_result_file}")
        # Try to copy embeddings
        if dump_embs:
            final_out_emb_file = os.path.join(
                eval_folder, config.run_config.result_emb_file)
            log_rank_0_info(
                logger,
                f"Trying to merge numpy embedding arrays. "
                f"If your machine is limited in memory, this may cause OOM errors. "
                f"If that happens, result files should be saved in {subeval_folder}.",
            )
            all_arrays = []
            for i, npfile in enumerate(all_out_emb_files):
                all_arrays.append(np.load(npfile))
            np.save(final_out_emb_file, np.concatenate(all_arrays))
            log_rank_0_info(
                logger, f"Bootleg embeddings saved at {final_out_emb_file}")

        # Cleanup
        try_rmtree(subeval_folder)
    return final_result_file, final_out_emb_file
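
A hypothetical call into this entry point (config would come from Bootleg's config parsing, which is not shown here; mode is one of the values handled above):

scores = run_model(mode="train", config=config, run_config_path="run_config.yaml")  # placeholder path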
Example #11
    def train_model(cands, F, align_type, model_type="LogisticRegression"):
        # Extract candidates and features based on the align type (row/column)
        align_val = 0 if align_type == "row" else 1
        train_cands = cands[align_val][0]
        F_train = F[align_val][0]
        train_marginals = np.array([[0, 1] if gold[align_val](x) else [1, 0]
                                    for x in train_cands[0]])

        # 1.) Setup training config
        config = {
            "meta_config": {
                "verbose": True
            },
            "model_config": {
                "model_path": None,
                "device": 0,
                "dataparallel": False
            },
            "learner_config": {
                "n_epochs": 50,
                "optimizer_config": {
                    "lr": 0.001,
                    "l2": 0.0
                },
                "task_scheduler": "round_robin",
            },
            "logging_config": {
                "evaluation_freq": 1,
                "counter_unit": "epoch",
                "checkpointing": False,
                "checkpointer_config": {
                    "checkpoint_metric": {
                        f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"
                    },
                    "checkpoint_freq": 1,
                    "checkpoint_runway": 2,
                    "clear_intermediate_checkpoints": True,
                    "clear_all_checkpoints": True,
                },
            },
        }

        emmental.init(Meta.log_path)
        emmental.Meta.update_config(config=config)

        # 2.) Collect word counter from training data
        word_counter = collect_word_counter(train_cands)

        # 3.) Generate word embedding module for LSTM model
        # (also generated for Logistic Regression, since the Fonduer dataset requires a word2id dict)
        # Generate special tokens
        arity = 2
        specials = []
        for i in range(arity):
            specials += [f"~~[[{i}", f"{i}]]~~"]

        emb_layer = EmbeddingModule(word_counter=word_counter,
                                    word_dim=300,
                                    specials=specials)

        # 4.) Generate dataloader for training set
        # No noise in Gold labels
        train_dataloader = EmmentalDataLoader(
            task_to_label_dict={ATTRIBUTE: "labels"},
            dataset=FonduerDataset(
                ATTRIBUTE,
                train_cands[0],
                F_train[0],
                emb_layer.word2id,
                train_marginals,
            ),
            split="train",
            batch_size=100,
            shuffle=True,
        )

        # 5.) Training
        tasks = create_task(
            ATTRIBUTE,
            2,
            F_train[0].shape[1],
            2,
            emb_layer,
            model=model_type,  # or "LSTM"
        )

        model = EmmentalModel(name=f"{ATTRIBUTE}_task")

        for task in tasks:
            model.add_task(task)

        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, [train_dataloader])

        return (model, emb_layer)
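
A hypothetical call of this helper (cands and F come from Fonduer's candidate extraction and featurization steps, which are not shown here; align_type is "row" or "column" as handled at the top of the function):

model, emb_layer = train_model(cands, F, align_type="row", model_type="LogisticRegression")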
Example #12
    def __init__(
        self,
        config=None,
        device=None,
        max_alias_len=6,
        cand_map=None,
        threshold=0.0,
        cache_dir=None,
        model_name=None,
        verbose=False,
    ):
        self.max_alias_len = max_alias_len
        self.verbose = verbose
        # Minimum probability of a prediction to return a mention
        self.threshold = threshold

        if not cache_dir:
            self.cache_dir = get_default_cache()
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"
        else:
            self.cache_dir = Path(cache_dir)
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"

        if not model_name:
            model_name = "bootleg_uncased"

        assert model_name in {
            "bootleg_cased",
            "bootleg_cased_mini",
            "bootleg_uncased",
            "bootleg_uncased_mini",
        }, (f"model_name must be one of [bootleg_cased, bootleg_cased_mini, "
            f"bootleg_uncased_mini, bootleg_uncased]. You have {model_name}.")

        if not config:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.model_path.mkdir(parents=True, exist_ok=True)
            self.data_path.mkdir(parents=True, exist_ok=True)
            create_sources(self.model_path, self.data_path, model_name)
            self.config = create_config(self.model_path, self.data_path,
                                        model_name)
        else:
            if "emmental" in config:
                config = parse_boot_and_emm_args(config)
            self.config = config
            # Ensure some of the critical annotator args are the correct type
            self.config.data_config.max_aliases = int(
                self.config.data_config.max_aliases)
            self.config.run_config.eval_batch_size = int(
                self.config.run_config.eval_batch_size)
            self.config.data_config.max_seq_len = int(
                self.config.data_config.max_seq_len)
            self.config.data_config.train_in_candidates = bool(
                self.config.data_config.train_in_candidates)

        if not device:
            device = 0 if torch.cuda.is_available() else -1

        if self.verbose:
            self.config.run_config.log_level = "DEBUG"
        else:
            self.config.run_config.log_level = "INFO"

        self.torch_device = (torch.device(device)
                             if device != -1 else torch.device("cpu"))
        self.config.model_config.device = device

        log_level = logging.getLevelName(
            self.config["run_config"]["log_level"].upper())
        emmental.init(
            log_dir=self.config["meta_config"]["log_path"],
            config=self.config,
            use_exact_log_path=self.config["meta_config"]
            ["use_exact_log_path"],
            level=log_level,
        )

        logger.debug("Reading entity database")
        self.entity_db = EntitySymbols.load_from_cache(
            os.path.join(
                self.config.data_config.entity_dir,
                self.config.data_config.entity_map_dir,
            ),
            alias_cand_map_file=self.config.data_config.alias_cand_map,
            alias_idx_file=self.config.data_config.alias_idx_map,
        )
        logger.debug("Reading word tokenizers")
        self.tokenizer = BertTokenizer.from_pretrained(
            self.config.data_config.word_embedding.bert_model,
            do_lower_case=True if "uncased"
            in self.config.data_config.word_embedding.bert_model else False,
            cache_dir=self.config.data_config.word_embedding.cache_dir,
        )

        # Create tasks
        tasks = [NED_TASK]
        if self.config.data_config.type_prediction.use_type_pred is True:
            tasks.append(TYPE_PRED_TASK)
        self.task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}

        # Create tasks
        self.model = EmmentalModel(name="Bootleg")
        self.model.add_task(ned_task.create_task(self.config, self.entity_db))
        if TYPE_PRED_TASK in tasks:
            self.model.add_task(
                type_pred_task.create_task(self.config, self.entity_db))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(self.model)

        logger.debug("Loading model")
        # Load the best model from the pretrained model
        assert (
            self.config["model_config"]["model_path"] is not None
        ), f"Must have a model to load in the model_path for the BootlegAnnotator"
        self.model.load(self.config["model_config"]["model_path"])
        self.model.eval()
        if cand_map is None:
            alias_map = self.entity_db.get_alias2qids()
        else:
            logger.debug(f"Loading candidate map")
            alias_map = ujson.load(open(cand_map))

        self.all_aliases_trie = get_all_aliases(alias_map, verbose)

        logger.debug("Reading in alias table")
        self.alias2cands = AliasEntityTable(
            data_config=self.config.data_config, entity_symbols=self.entity_db)

        # get batch_on_the_fly embeddings
        self.batch_on_the_fly_embs = get_dataloader_embeddings(
            self.config, self.entity_db)
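
A hypothetical instantiation, assuming this __init__ belongs to the BootlegAnnotator class referenced in the assertion above (all arguments are optional and the values shown are illustrative):

annotator = BootlegAnnotator(model_name="bootleg_uncased", device=-1, verbose=True)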
Example #13
    def _logging(
        self,
        model: EmmentalModel,
        dataloaders: List[EmmentalDataLoader],
        batch_size: int,
    ) -> Dict[str, float]:
        r"""Check whether it is time to evaluate or checkpoint the model.

        Args:
          model(EmmentalModel): The model to log.
          dataloaders(List[EmmentalDataLoader]): The data to evaluate.
          batch_size(int): Batch size.

        Returns:
          dict: The score dict.

        """

        # Switch to eval mode for evaluation
        model.eval()

        metric_dict = dict()

        self.logging_manager.update(batch_size)

        # Log the loss and lr
        metric_dict.update(self._aggregate_running_metrics(model))

        # Evaluate the model and log the metric
        trigger_evaluation = self.logging_manager.trigger_evaluation()
        if trigger_evaluation:

            # Log task specific metric
            metric_dict.update(
                self._evaluate(
                    model, dataloaders, Meta.config["learner_config"]["valid_split"]
                )
            )

            self.logging_manager.write_log(metric_dict)

            self._reset_losses()

        # Log metric dict every trigger evaluation time or full epoch
        if Meta.config["meta_config"]["verbose"] and (
            trigger_evaluation
            or self.logging_manager.epoch_total == int(self.logging_manager.epoch_total)
        ):
            logger.info(
                f"{self.logging_manager.counter_unit.capitalize()}: "
                f"{self.logging_manager.unit_total:.2f} {metric_dict}"
            )

        # Checkpoint the model
        if self.logging_manager.trigger_checkpointing():
            self.logging_manager.checkpoint_model(
                model, self.optimizer, self.lr_scheduler, metric_dict
            )

            self.logging_manager.write_log(metric_dict)

            self._reset_losses()

        # Switch to train mode
        model.train()

        return metric_dict
Example #14
def main(
    conn_string,
    stg_temp_min=False,
    stg_temp_max=False,
    polarity=False,
    ce_v_max=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    rel_list = []
    if stg_temp_min:
        rel_list.append("stg_temp_min")

    if stg_temp_max:
        rel_list.append("stg_temp_max")

    if polarity:
        rel_list.append("polarity")

    if ce_v_max:
        rel_list.append("ce_v_max")

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info(f"Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    Part = mention_subclass("Part")
    part_matcher = get_matcher("part")
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)

    mentions.append(Part)
    ngrams.append(part_ngrams)
    matchers.append(part_matcher)

    if stg_temp_min:
        StgTempMin = mention_subclass("StgTempMin")
        stg_temp_min_matcher = get_matcher("stg_temp_min")
        stg_temp_min_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMin)
        ngrams.append(stg_temp_min_ngrams)
        matchers.append(stg_temp_min_matcher)

    if stg_temp_max:
        StgTempMax = mention_subclass("StgTempMax")
        stg_temp_max_matcher = get_matcher("stg_temp_max")
        stg_temp_max_ngrams = MentionNgramsTemp(n_max=2)

        mentions.append(StgTempMax)
        ngrams.append(stg_temp_max_ngrams)
        matchers.append(stg_temp_max_matcher)

    if polarity:
        Polarity = mention_subclass("Polarity")
        polarity_matcher = get_matcher("polarity")
        polarity_ngrams = MentionNgrams(n_max=1)

        mentions.append(Polarity)
        ngrams.append(polarity_ngrams)
        matchers.append(polarity_matcher)

    if ce_v_max:
        CeVMax = mention_subclass("CeVMax")
        ce_v_max_matcher = get_matcher("ce_v_max")
        ce_v_max_ngrams = MentionNgramsVolt(n_max=1)

        mentions.append(CeVMax)
        ngrams.append(ce_v_max_ngrams)
        matchers.append(ce_v_max_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")
    logger.info(f"Total Part: {session.query(Part).count()}")
    if stg_temp_min:
        logger.info(f"Total StgTempMin: {session.query(StgTempMin).count()}")
    if stg_temp_max:
        logger.info(f"Total StgTempMax: {session.query(StgTempMax).count()}")
    if polarity:
        logger.info(f"Total Polarity: {session.query(Polarity).count()}")
    if ce_v_max:
        logger.info(f"Total CeVMax: {session.query(CeVMax).count()}")

    # Candidate Extraction
    cands = []
    throttlers = []
    if stg_temp_min:
        PartStgTempMin = candidate_subclass("PartStgTempMin",
                                            [Part, StgTempMin])
        stg_temp_min_throttler = stg_temp_filter

        cands.append(PartStgTempMin)
        throttlers.append(stg_temp_min_throttler)

    if stg_temp_max:
        PartStgTempMax = candidate_subclass("PartStgTempMax",
                                            [Part, StgTempMax])
        stg_temp_max_throttler = stg_temp_filter

        cands.append(PartStgTempMax)
        throttlers.append(stg_temp_max_throttler)

    if polarity:
        PartPolarity = candidate_subclass("PartPolarity", [Part, Polarity])
        polarity_throttler = polarity_filter

        cands.append(PartPolarity)
        throttlers.append(polarity_throttler)

    if ce_v_max:
        PartCeVMax = candidate_subclass("PartCeVMax", [Part, CeVMax])
        ce_v_max_throttler = ce_v_max_filter

        cands.append(PartCeVMax)
        throttlers.append(ce_v_max_throttler)

    candidate_extractor = CandidateExtractor(session,
                                             cands,
                                             throttlers=throttlers)

    if first_time:
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)
            num_cands = session.query(Candidate).filter(
                Candidate.split == i).count()
            logger.info(f"Candidates in split={i}: {num_cands}")

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)

    end = timer()
    logger.warning(
        f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"Total train candidate: {sum(len(_) for _ in train_cands)}")
    logger.info(f"Total dev candidate: {sum(len(_) for _ in dev_cands)}")
    logger.info(f"Total test candidate: {sum(len(_) for _ in test_cands)}")

    pickle_file = os.path.join(dirname, "data/parts_by_doc_new.pkl")
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    # Check total recall
    for i, name in enumerate(rel_list):
        logger.info(name)
        result = entity_level_scores(
            candidates_to_entities(dev_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=dev_docs,
        )
        logger.info(f"{name} Total Dev Recall: {result.rec:.3f}")
        result = entity_level_scores(
            candidates_to_entities(test_cands[i], parts_by_doc=parts_by_doc),
            attribute=name,
            corpus=test_docs,
        )
        logger.info(f"{name} Total Test Recall: {result.rec:.3f}")

    # Featurization
    start = timer()
    cands = []
    if stg_temp_min:
        cands.append(PartStgTempMin)

    if stg_temp_max:
        cands.append(PartStgTempMax)

    if polarity:
        cands.append(PartPolarity)

    if ce_v_max:
        cands.append(PartCeVMax)

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cands, parallelism=1)
    if first_time:
        logger.info("Starting featurizer...")
        featurizer.apply(split=0, train=True)
        featurizer.apply(split=1)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]

        pickle.dump(F_train_dict,
                    open(os.path.join(dirname, "F_train_dict.pkl"), "wb"))
        pickle.dump(F_dev_dict,
                    open(os.path.join(dirname, "F_dev_dict.pkl"), "wb"))
        pickle.dump(F_test_dict,
                    open(os.path.join(dirname, "F_test_dict.pkl"), "wb"))
    else:
        F_train_dict = pickle.load(
            open(os.path.join(dirname, "F_train_dict.pkl"), "rb"))
        F_dev_dict = pickle.load(
            open(os.path.join(dirname, "F_dev_dict.pkl"), "rb"))
        F_test_dict = pickle.load(
            open(os.path.join(dirname, "F_test_dict.pkl"), "rb"))

        F_train = []
        F_dev = []
        F_test = []
        for relation in rel_list:
            F_train.append(F_train_dict[relation])
            F_dev.append(F_dev_dict[relation])
            F_test.append(F_test_dict[relation])

    logger.info("Done.")

    for i, cand in enumerate(cands):
        logger.info(f"{cand} Train shape: {F_train[i].shape}")
        logger.info(f"{cand} Test shape: {F_test[i].shape}")
        logger.info(f"{cand} Dev shape: {F_dev[i].shape}")

    logger.info("Labeling training data...")

    # Labeling
    start = timer()
    lfs = []
    if stg_temp_min:
        lfs.append(stg_temp_min_lfs)

    if stg_temp_max:
        lfs.append(stg_temp_max_lfs)

    if polarity:
        lfs.append(polarity_lfs)

    if ce_v_max:
        lfs.append(ce_v_max_lfs)

    # Using parallelism = 1 for deterministic behavior.
    labeler = Labeler(session, cands, parallelism=1)

    if first_time:
        logger.info("Applying LFs...")
        labeler.apply(split=0, lfs=lfs, train=True)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  load_transistor_labels(session, cands, ["ce_v_max"])
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    elif re_label:
        logger.info("Updating LFs...")
        labeler.update(split=0, lfs=lfs)
        logger.info("Done...")

        # Uncomment if debugging LFs
        #  labeler.apply(split=1, lfs=lfs, train=False, parallelism=parallel)
        #  labeler.apply(split=2, lfs=lfs, train=False, parallelism=parallel)

    logger.info("Getting label matrices...")

    L_train = labeler.get_label_matrices(train_cands)

    # Uncomment if debugging LFs
    #  L_dev = labeler.get_label_matrices(dev_cands)
    #  L_dev_gold = labeler.get_gold_labels(dev_cands, annotator="gold")
    #
    #  L_test = labeler.get_label_matrices(test_cands)
    #  L_test_gold = labeler.get_gold_labels(test_cands, annotator="gold")

    logger.info("Done.")

    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            marginals_dict[relation] = generative_model(L_train[idx])

        pickle.dump(marginals_dict,
                    open(os.path.join(dirname, "marginals_dict.pkl"), "wb"))
    else:
        marginals_dict = pickle.load(
            open(os.path.join(dirname, "marginals_dict.pkl"), "rb"))

    marginals = []
    for relation in rel_list:
        marginals.append(marginals_dict[relation])

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 17
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])

        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    train_cands[idx],
                    F_train[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=100,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name="transistor_tasks")

    # Create one task per relation: names, arities, number of feature keys, and classes
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # Given a list of dataloaders (one per relation), the learner trains all tasks jointly
    emmental_learner.learn(model, train_dataloader)

    # Build a test dataloader for each relation and evaluate it
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=100,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        best_result, best_b = scoring(
            relation,
            test_preds,
            test_cands[idx],
            test_docs,
            F_test[idx],
            parts_by_doc,
            num=100,
        )

        # Dump CSV files for CE_V_MAX for digi-key analysis
        if relation == "ce_v_max":
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob, "ce_v_max_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob, "ce_v_max_dev_probs.csv")

        # Dump CSV files for POLARITY for digi-key analysis
        if relation == "polarity":
            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=100,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            dump_candidates(test_cands[idx], Y_prob, "polarity_test_probs.csv")
            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            dump_candidates(dev_cands[idx], Y_prob, "polarity_dev_probs.csv")

    end = timer()
    logger.warning(f"Classification Time (min): {((end - start) / 60.0):.1f}")
Пример #15
0
class BootlegAnnotator(object):
    """BootlegAnnotator class: convenient wrapper of preprocessing and model
    eval to allow for annotating single sentences at a time for quick
    experimentation, e.g. in notebooks.

    Args:
        config: model config (default None)
        device: model device, -1 for CPU (default None)
        max_alias_len: maximum alias length (default 6)
        cand_map: alias candidate map (default None)
        threshold: probability threshold (default 0.0)
        cache_dir: cache directory (default None)
        model_name: model name (default None)
        verbose: verbose boolean (default False)
    """
    def __init__(
        self,
        config=None,
        device=None,
        max_alias_len=6,
        cand_map=None,
        threshold=0.0,
        cache_dir=None,
        model_name=None,
        verbose=False,
    ):
        self.max_alias_len = max_alias_len  # maximum alias length
        self.verbose = verbose
        self.threshold = threshold

        if not cache_dir:
            self.cache_dir = get_default_cache()
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"
        else:
            self.cache_dir = Path(cache_dir)
            self.model_path = self.cache_dir / "models"
            self.data_path = self.cache_dir / "data"

        if not model_name:
            model_name = "bootleg_uncased"

        assert model_name in {
            "bootleg_cased",
            "bootleg_cased_mini",
            "bootleg_uncased",
            "bootleg_uncased_mini",
        }, (f"model_name must be one of [bootleg_cased, bootleg_cased_mini, "
            f"bootleg_uncased_mini, bootleg_uncased]. You have {model_name}.")

        if not config:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            self.model_path.mkdir(parents=True, exist_ok=True)
            self.data_path.mkdir(parents=True, exist_ok=True)
            create_sources(self.model_path, self.data_path, model_name)
            self.config = create_config(self.model_path, self.data_path,
                                        model_name)
        else:
            if "emmental" in config:
                config = parse_boot_and_emm_args(config)
            self.config = config
            # Ensure some of the critical annotator args are the correct type
            self.config.data_config.max_aliases = int(
                self.config.data_config.max_aliases)
            self.config.run_config.eval_batch_size = int(
                self.config.run_config.eval_batch_size)
            self.config.data_config.max_seq_len = int(
                self.config.data_config.max_seq_len)
            self.config.data_config.train_in_candidates = bool(
                self.config.data_config.train_in_candidates)

        if not device:
            device = 0 if torch.cuda.is_available() else -1

        if self.verbose:
            self.config.run_config.log_level = "DEBUG"
        else:
            self.config.run_config.log_level = "INFO"

        self.torch_device = (torch.device(device)
                             if device != -1 else torch.device("cpu"))
        self.config.model_config.device = device

        log_level = logging.getLevelName(
            self.config["run_config"]["log_level"].upper())
        emmental.init(
            log_dir=self.config["meta_config"]["log_path"],
            config=self.config,
            use_exact_log_path=self.config["meta_config"]
            ["use_exact_log_path"],
            level=log_level,
        )

        logger.debug("Reading entity database")
        self.entity_db = EntitySymbols.load_from_cache(
            os.path.join(
                self.config.data_config.entity_dir,
                self.config.data_config.entity_map_dir,
            ),
            alias_cand_map_file=self.config.data_config.alias_cand_map,
            alias_idx_file=self.config.data_config.alias_idx_map,
        )
        logger.debug("Reading word tokenizers")
        self.tokenizer = BertTokenizer.from_pretrained(
            self.config.data_config.word_embedding.bert_model,
            do_lower_case=True if "uncased"
            in self.config.data_config.word_embedding.bert_model else False,
            cache_dir=self.config.data_config.word_embedding.cache_dir,
        )

        # Create tasks
        tasks = [NED_TASK]
        if self.config.data_config.type_prediction.use_type_pred is True:
            tasks.append(TYPE_PRED_TASK)
        self.task_to_label_dict = {t: NED_TASK_TO_LABEL[t] for t in tasks}

        # Create tasks
        self.model = EmmentalModel(name="Bootleg")
        self.model.add_task(ned_task.create_task(self.config, self.entity_db))
        if TYPE_PRED_TASK in tasks:
            self.model.add_task(
                type_pred_task.create_task(self.config, self.entity_db))
            # Add the mention type embedding to the embedding payload
            type_pred_task.update_ned_task(self.model)

        logger.debug("Loading model")
        # Load the best model from the pretrained model
        assert (
            self.config["model_config"]["model_path"] is not None
        ), f"Must have a model to load in the model_path for the BootlegAnnotator"
        self.model.load(self.config["model_config"]["model_path"])
        self.model.eval()
        if cand_map is None:
            alias_map = self.entity_db.get_alias2qids()
        else:
            logger.debug(f"Loading candidate map")
            alias_map = ujson.load(open(cand_map))

        self.all_aliases_trie = get_all_aliases(alias_map, verbose)

        logger.debug("Reading in alias table")
        self.alias2cands = AliasEntityTable(
            data_config=self.config.data_config, entity_symbols=self.entity_db)

        # get batch_on_the_fly embeddings
        self.batch_on_the_fly_embs = get_dataloader_embeddings(
            self.config, self.entity_db)

    def extract_mentions(self, text, label_func):
        """Wrapper function for mention extraction.

        Args:
            text: text to extract mentions from
            label_func: function that performs extraction (input is (text, alias trie, max alias length);
                        output is a list of found aliases and found spans)

        Returns: JSON object of sentence to be used in eval
        """
        found_aliases, found_spans = label_func(text, self.all_aliases_trie,
                                                self.max_alias_len)
        return {
            "sentence": text,
            "aliases": found_aliases,
            "spans": found_spans,
            # we don't know the true QID
            "qids": ["Q-1" for i in range(len(found_aliases))],
            "gold": [True for i in range(len(found_aliases))],
        }

    def set_threshold(self, value):
        """Sets threshold.

        Args:
            value: threshold value

        Returns:
        """
        self.threshold = value

    def label_mentions(self,
                       text_list,
                       label_func=find_aliases_in_sentence_tag):
        """Extracts mentions and runs disambiguation.

        Args:
            text_list: list of text to disambiguate (or single sentence)
            label_func: mention extraction function (optional)

        Returns: Dict of

            * ``qids``: final predicted QIDs,
            * ``probs``: final predicted probs,
            * ``titles``: final predicted titles,
            * ``cands``: all entity candidates,
            * ``cand_probs``: probabilities of all candidates,
            * ``spans``: final extracted word spans,
            * ``aliases``: final extracted aliases,
        """
        if type(text_list) is str:
            text_list = [text_list]
        else:
            assert (type(text_list) is list and len(text_list) > 0
                    and type(text_list[0]) is str
                    ), f"We only accept inputs of strings and lists of strings"

        ebs = int(self.config.run_config.eval_batch_size)
        self.config.data_config.max_aliases = int(
            self.config.data_config.max_aliases)
        total_start_exs = 0
        total_final_exs = 0
        dropped_by_thresh = 0

        final_char_spans = []

        batch_example_aliases = []
        batch_example_aliases_locs_start = []
        batch_example_aliases_locs_end = []
        batch_example_alias_list_pos = []
        batch_example_true_entities = []
        batch_word_indices = []
        batch_spans_arr = []
        batch_aliases_arr = []
        batch_idx_unq = []
        batch_subsplit_idx = []
        for idx_unq, text in tqdm(
                enumerate(text_list),
                desc="Prepping data",
                total=len(text_list),
                disable=not self.verbose,
        ):
            sample = self.extract_mentions(text, label_func)
            total_start_exs += len(sample["aliases"])
            char_spans = self.get_char_spans(sample["spans"], text)

            final_char_spans.append(char_spans)

            (
                idxs_arr,
                aliases_to_predict_per_split,
                spans_arr,
                phrase_tokens_arr,
                pos_idxs,
            ) = sentence_utils.split_sentence(
                max_aliases=self.config.data_config.max_aliases,
                phrase=sample["sentence"],
                spans=sample["spans"],
                aliases=sample["aliases"],
                aliases_seen_by_model=list(range(len(sample["aliases"]))),
                seq_len=self.config.data_config.max_seq_len,
                is_bert=True,
                tokenizer=self.tokenizer,
            )
            aliases_arr = [[sample["aliases"][idx] for idx in idxs]
                           for idxs in idxs_arr]
            old_spans_arr = [[sample["spans"][idx] for idx in idxs]
                             for idxs in idxs_arr]
            qids_arr = [[sample["qids"][idx] for idx in idxs]
                        for idxs in idxs_arr]
            word_indices_arr = [
                self.tokenizer.convert_tokens_to_ids(pt)
                for pt in phrase_tokens_arr
            ]
            # iterate over each sample in the split

            for sub_idx in range(len(idxs_arr)):
                # ====================================================
                # GENERATE MODEL INPUTS
                # ====================================================
                aliases_to_predict_arr = aliases_to_predict_per_split[sub_idx]

                assert (
                    len(aliases_to_predict_arr) >= 0
                ), f"There are no aliases to predict for an example. This should not happen at this point."
                assert (
                    len(aliases_arr[sub_idx]) <=
                    self.config.data_config.max_aliases
                ), f"{sample} should have no more than {self.config.data_config.max_aliases} aliases."

                example_aliases = np.ones(
                    self.config.data_config.max_aliases) * PAD_ID
                example_aliases_locs_start = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_aliases_locs_end = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_alias_list_pos = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)
                example_true_entities = (
                    np.ones(self.config.data_config.max_aliases) * PAD_ID)

                for mention_idx, alias in enumerate(aliases_arr[sub_idx]):
                    span_start_idx, span_end_idx = spans_arr[sub_idx][
                        mention_idx]
                    # generate indexes into alias table.
                    alias_trie_idx = self.entity_db.get_alias_idx(alias)
                    alias_qids = np.array(self.entity_db.get_qid_cands(alias))
                    if not qids_arr[sub_idx][mention_idx] in alias_qids:
                        # assert not data_args.train_in_candidates
                        if not self.config.data_config.train_in_candidates:
                            # set class label to be "not in candidate set"
                            true_entity_idx = 0
                        else:
                            true_entity_idx = -2
                    else:
                        # Here we are getting the correct class label for training.
                        # Our training is "which of the max_entities entity candidates is the right one
                        # (class labels 1 to max_entities) or is it none of these (class label 0)".
                        # + (not discard_noncandidate_entities) is to ensure label 0 is
                        # reserved for "not in candidate set" class
                        true_entity_idx = np.nonzero(
                            alias_qids == qids_arr[sub_idx][mention_idx]
                        )[0][0] + (
                            not self.config.data_config.train_in_candidates)
                    example_aliases[mention_idx] = alias_trie_idx
                    example_aliases_locs_start[mention_idx] = span_start_idx
                    # The span_idxs are [start, end). We want [start, end]. So subtract 1 from end idx.
                    example_aliases_locs_end[mention_idx] = span_end_idx - 1
                    example_alias_list_pos[mention_idx] = idxs_arr[sub_idx][
                        mention_idx]
                    # leave as -1 if it's not an alias we want to predict; we get these if we split a sentence
                    # and need to only predict subsets
                    if mention_idx in aliases_to_predict_arr:
                        example_true_entities[mention_idx] = true_entity_idx

                # get word indices
                word_indices = word_indices_arr[sub_idx]

                batch_example_aliases.append(example_aliases)
                batch_example_aliases_locs_start.append(
                    example_aliases_locs_start)
                batch_example_aliases_locs_end.append(example_aliases_locs_end)
                batch_example_alias_list_pos.append(example_alias_list_pos)
                batch_example_true_entities.append(example_true_entities)
                batch_word_indices.append(word_indices)
                batch_aliases_arr.append(aliases_arr[sub_idx])
                # Add the original sample spans because spans_arr is w.r.t. BERT subword tokens
                batch_spans_arr.append(old_spans_arr[sub_idx])
                batch_idx_unq.append(idx_unq)
                batch_subsplit_idx.append(sub_idx)

        batch_example_aliases = torch.tensor(batch_example_aliases).long()
        batch_example_aliases_locs_start = torch.tensor(
            batch_example_aliases_locs_start, device=self.torch_device)
        batch_example_aliases_locs_end = torch.tensor(
            batch_example_aliases_locs_end, device=self.torch_device)
        batch_example_true_entities = torch.tensor(batch_example_true_entities,
                                                   device=self.torch_device)
        batch_word_indices = torch.tensor(batch_word_indices,
                                          device=self.torch_device)

        final_pred_cands = [[] for _ in range(len(text_list))]
        final_all_cands = [[] for _ in range(len(text_list))]
        final_cand_probs = [[] for _ in range(len(text_list))]
        final_pred_probs = [[] for _ in range(len(text_list))]
        final_titles = [[] for _ in range(len(text_list))]
        final_spans = [[] for _ in range(len(text_list))]
        final_aliases = [[] for _ in range(len(text_list))]
        for b_i in tqdm(
                range(0, batch_example_aliases.shape[0], ebs),
                desc="Evaluating model",
                disable=not self.verbose,
        ):
            start_span_idx = batch_example_aliases_locs_start[b_i:b_i + ebs]
            end_span_idx = batch_example_aliases_locs_end[b_i:b_i + ebs]
            word_indices = batch_word_indices[b_i:b_i + ebs]
            alias_indices = batch_example_aliases[b_i:b_i + ebs]
            x_dict = self.get_forward_batch(start_span_idx, end_span_idx,
                                            word_indices, alias_indices)
            x_dict["guid"] = torch.arange(b_i,
                                          b_i + ebs,
                                          device=self.torch_device)

            (uid_bdict, _, prob_bdict, _) = self.model(  # type: ignore
                uids=x_dict["guid"],
                X_dict=x_dict,
                Y_dict=None,
                task_to_label_dict=self.task_to_label_dict,
                return_action_outputs=False,
            )
            # ====================================================
            # EVALUATE MODEL OUTPUTS
            # ====================================================
            # recover predictions
            probs = prob_bdict[NED_TASK]
            max_probs = probs.max(2)
            max_probs_indices = probs.argmax(2)
            for ex_i in range(probs.shape[0]):
                idx_unq = batch_idx_unq[b_i + ex_i]
                entity_cands = eval_utils.map_aliases_to_candidates(
                    self.config.data_config.train_in_candidates,
                    self.config.data_config.max_aliases,
                    self.entity_db.get_alias2qids(),
                    batch_aliases_arr[b_i + ex_i],
                )
                # batch size is 1 so we can reshape
                probs_ex = probs[ex_i].reshape(
                    self.config.data_config.max_aliases, probs.shape[2])
                for alias_idx, true_entity_pos_idx in enumerate(
                        batch_example_true_entities[b_i + ex_i]):
                    if true_entity_pos_idx != PAD_ID:
                        pred_idx = max_probs_indices[ex_i][alias_idx]
                        pred_prob = max_probs[ex_i][alias_idx].item()
                        all_cands = entity_cands[alias_idx]
                        pred_qid = all_cands[pred_idx]
                        if pred_prob > self.threshold:
                            final_all_cands[idx_unq].append(all_cands)
                            final_cand_probs[idx_unq].append(
                                probs_ex[alias_idx])
                            final_pred_cands[idx_unq].append(pred_qid)
                            final_pred_probs[idx_unq].append(pred_prob)
                            final_aliases[idx_unq].append(
                                batch_aliases_arr[b_i + ex_i][alias_idx])
                            final_spans[idx_unq].append(
                                batch_spans_arr[b_i + ex_i][alias_idx])
                            final_titles[idx_unq].append(
                                self.entity_db.get_title(pred_qid)
                                if pred_qid != "NC" else "NC")
                            total_final_exs += 1
                        else:
                            dropped_by_thresh += 1
        assert total_final_exs + dropped_by_thresh == total_start_exs, (
            f"Something went wrong and we have predicted fewer mentions than extracted. "
            f"Start {total_start_exs}, Out {total_final_exs}, No cand {dropped_by_thresh}"
        )
        res_dict = {
            "qids": final_pred_cands,
            "probs": final_pred_probs,
            "titles": final_titles,
            "cands": final_all_cands,
            "cand_probs": final_cand_probs,
            "spans": final_spans,
            "aliases": final_aliases,
        }
        return res_dict

    def get_forward_batch(self, start_span_idx, end_span_idx, token_ids,
                          alias_idx):
        """Preps the forward batch for disambiguation.

        Args:
            start_span_idx: start span tensor
            end_span_idx: end span tensor
            token_ids: word token tensor
            alias_idx: alias index used for extracting candidate eids

        Returns: X_dict used in Emmental
        """
        entity_cand_eid = self.alias2cands(alias_idx).long()
        entity_cand_eid_mask = entity_cand_eid == -1
        entity_cand_eid_noneg = torch.where(
            entity_cand_eid >= 0,
            entity_cand_eid,
            (torch.ones_like(entity_cand_eid, dtype=torch.long) *
             (self.entity_db.num_entities_with_pad_and_nocand - 1)),
        )

        kg_prepped_embs = {}
        for emb_key in self.batch_on_the_fly_embs:
            kg_adj = self.batch_on_the_fly_embs[emb_key]["kg_adj"]
            prep_func = self.batch_on_the_fly_embs[emb_key][
                "kg_adj_process_func"]
            batch_prep = []
            for j in range(entity_cand_eid_noneg.shape[0]):
                batch_prep.append(
                    prep_func(entity_cand_eid_noneg[j].cpu(),
                              kg_adj).reshape(1, -1))
            kg_prepped_embs[emb_key] = torch.tensor(batch_prep,
                                                    device=self.torch_device)

        X_dict = {
            "guids": [],
            "start_span_idx": start_span_idx,
            "end_span_idx": end_span_idx,
            "token_ids": token_ids,
            "entity_cand_eid": entity_cand_eid_noneg,
            "entity_cand_eid_mask": entity_cand_eid_mask,
            "batch_on_the_fly_kg_adj": kg_prepped_embs,
        }
        return X_dict

    def get_char_spans(self, spans, text):
        """Helper function to get character spans instead of default word
        spans.

        Args:
            spans: word spans
            text: text

        Returns: character spans
        """
        query_toks = text.split()
        char_spans = []
        for span in spans:
            space_btwn_toks = (len(" ".join(query_toks[0:span[0] + 1])) -
                               len(" ".join(query_toks[0:span[0]])) -
                               len(query_toks[span[0]]))
            char_b = len(" ".join(query_toks[0:span[0]])) + space_btwn_toks
            char_e = char_b + len(" ".join(query_toks[span[0]:span[1]]))
            char_spans.append([char_b, char_e])
        return char_spans
Пример #16
0
def main(
    conn_string,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    gpu=None,
    parallel=4,
    log_dir=None,
    verbose=False,
):
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    session = Meta.init(conn_string).Session()

    # Parsing
    logger.info(f"Starting parsing...")
    start = timer()
    docs, train_docs, dev_docs, test_docs = parse_dataset(
        session, dirname, first_time=first_time, parallel=parallel, max_docs=max_docs
    )
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")

    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    start = timer()

    Thumbnails = mention_subclass("Thumbnails")

    thumbnails_img = MentionFigures()

    class HasFigures(_Matcher):
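        # Keep only figure mentions whose image file exists under one of the data
        # splits and whose smaller image dimension exceeds 50 pixels.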
        def _f(self, m):
            file_path = ""
            for prefix in [
                f"{dirname}/data/train/html/",
                f"{dirname}/data/dev/html/",
                f"{dirname}/data/test/html/",
            ]:
                if os.path.exists(prefix + m.figure.url):
                    file_path = prefix + m.figure.url
            if file_path == "":
                return False
            img = Image.open(file_path)
            width, height = img.size
            min_value = min(width, height)
            return min_value > 50

    mention_extractor = MentionExtractor(
        session, [Thumbnails], [thumbnails_img], [HasFigures()], parallelism=parallel
    )

    if first_time:
        mention_extractor.apply(docs)

    logger.info("Total Mentions: {}".format(session.query(Mention).count()))

    ThumbnailLabel = candidate_subclass("ThumbnailLabel", [Thumbnails])

    candidate_extractor = CandidateExtractor(
        session, [ThumbnailLabel], throttlers=[None], parallelism=parallel
    )

    if first_time:
        candidate_extractor.apply(train_docs, split=0)
        candidate_extractor.apply(dev_docs, split=1)
        candidate_extractor.apply(test_docs, split=2)

    train_cands = candidate_extractor.get_candidates(split=0)
    # Sort the dev_cands, which are used for training, for deterministic behavior
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2)

    end = timer()
    logger.warning(f"Candidate Extraction Time (min): {((end - start) / 60.0):.1f}")

    logger.info("Total train candidate:\t{}".format(len(train_cands[0])))
    logger.info("Total dev candidate:\t{}".format(len(dev_cands[0])))
    logger.info("Total test candidate:\t{}".format(len(test_cands[0])))

    fin = open(f"{dirname}/data/ground_truth.txt", "r")
    gt = set()
    for line in fin:
        gt.add("::".join(line.lower().split()))
    fin.close()

    # Labeling
    start = timer()

    def LF_gt_label(c):
        doc_file_id = (
            f"{c[0].context.figure.document.name.lower()}.pdf::"
            f"{os.path.basename(c[0].context.figure.url.lower())}"
        )
        return TRUE if doc_file_id in gt else FALSE

    gt_dev = [LF_gt_label(cand) for cand in dev_cands[0]]
    gt_test = [LF_gt_label(cand) for cand in test_cands[0]]

    end = timer()
    logger.warning(f"Supervision Time (min): {((end - start) / 60.0):.1f}")

    batch_size = 64
    input_size = 224
    K = 2

    emmental.init(log_dir=Meta.log_path, config=emmental_config)

    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"
    ] = DauphinScheduler(augment_k=K, enlarge=1)

    train_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "train",
        prob_label=True,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        transform_cls=Augmentation(2),
        k=K,
    )

    val_dataset = ThumbnailDataset(
        "Thumbnail",
        dev_cands[0],
        gt_dev,
        "valid",
        prob_label=False,
        prefix=f"{dirname}/data/dev/html/",
        input_size=input_size,
        k=1,
    )

    test_dataset = ThumbnailDataset(
        "Thumbnail",
        test_cands[0],
        gt_test,
        "test",
        prob_label=False,
        prefix=f"{dirname}/data/test/html/",
        input_size=input_size,
        k=1,
    )

    dataloaders = []

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=train_dataset,
            split="train",
            shuffle=True,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=val_dataset,
            split="valid",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"Thumbnail": "labels"},
            dataset=test_dataset,
            split="test",
            shuffle=False,
            batch_size=batch_size,
            num_workers=1,
        )
    )

    model = EmmentalModel(name=f"Thumbnail")
    model.add_task(
        create_task("Thumbnail", n_class=2, model="resnet18", pretrained=True)
    )

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)

    logger.warning("Model Score:")
    logger.warning(f"precision: {scores['Thumbnail/Thumbnail/test/precision']:.3f}")
    logger.warning(f"recall: {scores['Thumbnail/Thumbnail/test/recall']:.3f}")
    logger.warning(f"f1: {scores['Thumbnail/Thumbnail/test/f1']:.3f}")
Пример #17
0
        dataloaders.append(
            EmmentalDataLoader(
                task_to_label_dict=task_to_label_dict,
                dataset=dataset,
                split=split,
                shuffle=True if split == "train" else False,
                batch_size=args.batch_size,
                num_workers=8,
            ))
        logger.info(f"Built dataloader for {dataset.name} {split} set.")

    tasks = create_task(list(task_to_label_dict.keys()),
                        cnn_encoder=args.model)

    # Build Emmental model
    model = EmmentalModel(name=DATA_NAME, tasks=tasks)

    # Load the pre-trained model
    if Meta.config["model_config"]["model_path"]:
        model.load(Meta.config["model_config"]["model_path"])

    # Training
    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)

    # Save metrics into file
    logger.info(f"Metrics: {scores}")
    write_to_json_file(f"{Meta.log_path}/metrics.txt", scores)
Пример #18
0
def test_e2e(caplog):
    """Run an end-to-end test."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_e2e"

    Meta.reset()
    emmental.init(dirpath)

    # Generate synthetic data
    N = 50
    X = np.random.random((N, 2)) * 2 - 1
    Y1 = (X[:, 0] > X[:, 1] + 0.25).astype(int) + 1
    Y2 = (-X[:, 0] > X[:, 1] + 0.25).astype(int) + 1
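    # Labels are 1-indexed ({1, 2}); the loss function defined below shifts them
    # back to 0-indexed classes before computing the cross entropy.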

    # Create dataset and dataloader

    splits = [0.8, 0.1, 0.1]

    X_train, X_dev, X_test = [], [], []
    Y1_train, Y1_dev, Y1_test = [], [], []
    Y2_train, Y2_dev, Y2_test = [], [], []

    for i in range(N):
        if i <= N * splits[0]:
            X_train.append(torch.Tensor(X[i]))
            Y1_train.append(Y1[i])
            Y2_train.append(Y2[i])
        elif i < N * (splits[0] + splits[1]):
            X_dev.append(torch.Tensor(X[i]))
            Y1_dev.append(Y1[i])
            Y2_dev.append(Y2[i])
        else:
            X_test.append(torch.Tensor(X[i]))
            Y1_test.append(Y1[i])
            Y2_test.append(Y2[i])

    Y1_train = torch.from_numpy(np.array(Y1_train))
    Y1_dev = torch.from_numpy(np.array(Y1_dev))
    Y1_test = torch.from_numpy(np.array(Y1_test))

    Y2_train = torch.from_numpy(np.array(Y2_train))
    Y2_dev = torch.from_numpy(np.array(Y2_dev))
    Y2_test = torch.from_numpy(np.array(Y2_test))

    train_dataset1 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_train}, Y_dict={"label1": Y1_train}
    )

    train_dataset2 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_train}, Y_dict={"label2": Y2_train}
    )

    dev_dataset1 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_dev}, Y_dict={"label1": Y1_dev}
    )

    dev_dataset2 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_dev}, Y_dict={"label2": Y2_dev}
    )

    test_dataset1 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_test}, Y_dict={"label1": Y2_test}
    )

    test_dataset2 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_test}, Y_dict={"label2": Y2_test}
    )

    task_to_label_dict = {"task1": "label1"}

    train_dataloader1 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=train_dataset1,
        split="train",
        batch_size=10,
    )
    dev_dataloader1 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=dev_dataset1,
        split="valid",
        batch_size=10,
    )
    test_dataloader1 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=test_dataset1,
        split="test",
        batch_size=10,
    )

    task_to_label_dict = {"task2": "label2"}

    train_dataloader2 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=train_dataset2,
        split="train",
        batch_size=10,
    )
    dev_dataloader2 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=dev_dataset2,
        split="valid",
        batch_size=10,
    )
    test_dataloader2 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=test_dataset2,
        split="test",
        batch_size=10,
    )

    # Create task
    def ce_loss(task_name, immediate_output_dict, Y, active):
        module_name = f"{task_name}_pred_head"
        return F.cross_entropy(
            immediate_output_dict[module_name][0][active], (Y.view(-1) - 1)[active]
        )

    def output(task_name, immediate_output_dict):
        module_name = f"{task_name}_pred_head"
        return F.softmax(immediate_output_dict[module_name][0], dim=1)

    task_name = "task1"

    task1 = EmmentalTask(
        name=task_name,
        module_pool=nn.ModuleDict(
            {"input_module": nn.Linear(2, 8), f"{task_name}_pred_head": nn.Linear(8, 2)}
        ),
        task_flow=[
            {
                "name": "input",
                "module": "input_module",
                "inputs": [("_input_", "data")],
            },
            {
                "name": f"{task_name}_pred_head",
                "module": f"{task_name}_pred_head",
                "inputs": [("input", 0)],
            },
        ],
        loss_func=partial(ce_loss, task_name),
        output_func=partial(output, task_name),
        scorer=Scorer(metrics=["accuracy", "roc_auc"]),
    )

    task_name = "task2"

    task2 = EmmentalTask(
        name=task_name,
        module_pool=nn.ModuleDict(
            {"input_module": nn.Linear(2, 8), f"{task_name}_pred_head": nn.Linear(8, 2)}
        ),
        task_flow=[
            {
                "name": "input",
                "module": "input_module",
                "inputs": [("_input_", "data")],
            },
            {
                "name": f"{task_name}_pred_head",
                "module": f"{task_name}_pred_head",
                "inputs": [("input", 0)],
            },
        ],
        loss_func=partial(ce_loss, task_name),
        output_func=partial(output, task_name),
        scorer=Scorer(metrics=["accuracy", "roc_auc"]),
    )

    # Build model

    mtl_model = EmmentalModel(name="all", tasks=[task1, task2])

    # Create learner

    emmental_learner = EmmentalLearner()

    # Update learning config
    Meta.update_config(
        config={"learner_config": {"n_epochs": 10, "optimizer_config": {"lr": 0.01}}}
    )

    # Learning
    emmental_learner.learn(
        mtl_model,
        [train_dataloader1, train_dataloader2, dev_dataloader1, dev_dataloader2],
    )

    test1_score = mtl_model.score(test_dataloader1)
    test2_score = mtl_model.score(test_dataloader2)

    assert test1_score["task1/synthetic/test/accuracy"] >= 0.5
    assert test1_score["task1/synthetic/test/roc_auc"] >= 0.6
    assert test2_score["task2/synthetic/test/accuracy"] >= 0.5
    assert test2_score["task2/synthetic/test/roc_auc"] >= 0.6

    shutil.rmtree(dirpath)
Пример #19
0
    def learn(self, model: EmmentalModel,
              dataloaders: List[EmmentalDataLoader]) -> None:
        """Learning procedure of emmental MTL.

        Args:
          model: The emmental model that needs to learn.
          dataloaders: A list of dataloaders used to learn the model.
        """
        start_time = time.time()

        # Generate the list of dataloaders for learning process
        train_split = Meta.config["learner_config"]["train_split"]
        if isinstance(train_split, str):
            train_split = [train_split]

        train_dataloaders = [
            dataloader for dataloader in dataloaders
            if dataloader.split in train_split
        ]

        if not train_dataloaders:
            raise ValueError(
                f"Cannot find the specified train_split "
                f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
            )

        # Set up task_scheduler
        self._set_task_scheduler()

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch: int = self.task_scheduler.get_num_batches(
            train_dataloaders)
        if self.n_batches_per_epoch == 0:
            logger.info("No batches in training dataloaders, existing...")
            return

        # Set up learning counter
        self._set_learning_counter()
        # Set up logging manager
        self._set_logging_manager()
        # Set up wandb watch model
        if (Meta.config["logging_config"]["writer_config"]["writer"] == "wandb"
                and Meta.config["logging_config"]["writer_config"]
            ["wandb_watch_model"]):
            if Meta.config["logging_config"]["writer_config"][
                    "wandb_model_watch_freq"]:
                wandb.watch(
                    model,
                    log_freq=Meta.config["logging_config"]["writer_config"]
                    ["wandb_model_watch_freq"],
                )
            else:
                wandb.watch(model)
        # Set up optimizer
        self._set_optimizer(model)
        # Set up lr_scheduler
        self._set_lr_scheduler(model)

        if Meta.config["learner_config"]["fp16"]:
            try:
                from apex import amp  # type: ignore
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to "
                    "use fp16 training.")
            logger.info(
                f"Modeling training with 16-bit (mixed) precision "
                f"and {Meta.config['learner_config']['fp16_opt_level']} opt level."
            )
            model, self.optimizer = amp.initialize(
                model,
                self.optimizer,
                opt_level=Meta.config["learner_config"]["fp16_opt_level"],
            )

        # Multi-gpu training (after apex fp16 initialization)
        if (Meta.config["learner_config"]["local_rank"] == -1
                and Meta.config["model_config"]["dataparallel"]):
            model._to_dataparallel()

        # Distributed training (after apex fp16 initialization)
        if Meta.config["learner_config"]["local_rank"] != -1:
            model._to_distributed_dataparallel()

        # Set to training mode
        model.train()

        if Meta.config["meta_config"]["verbose"]:
            logger.info("Start learning...")

        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        # Set gradients of all model parameters to zero
        self.optimizer.zero_grad()

        batch_iterator = self.task_scheduler.get_batches(
            train_dataloaders, model)
        for epoch_num in range(self.start_epoch, self.end_epoch):
            for train_dataloader in train_dataloaders:
                # Set epoch for distributed sampler
                if isinstance(train_dataloader, DataLoader) and isinstance(
                        train_dataloader.sampler, DistributedSampler):
                    train_dataloader.sampler.set_epoch(epoch_num)
            step_pbar = tqdm(
                range(self.start_step, self.end_step),
                desc=f"Step {self.start_step + 1}/{self.end_step}"
                if self.use_step_base_counter else
                f"Epoch {epoch_num + 1}/{self.end_epoch}",
                disable=not Meta.config["meta_config"]["verbose"]
                or Meta.config["learner_config"]["local_rank"] not in [-1, 0],
            )
            for step_num in step_pbar:
                if self.use_step_base_counter:
                    step_pbar.set_description(
                        f"Step {step_num + 1}/{self.total_steps}")
                    step_pbar.refresh()
                try:
                    batch = next(batch_iterator)
                except StopIteration:
                    batch_iterator = self.task_scheduler.get_batches(
                        train_dataloaders, model)
                    batch = next(batch_iterator)

                # Check if skip the current batch
                if epoch_num < self.start_train_epoch or (
                        epoch_num == self.start_train_epoch
                        and step_num < self.start_train_step):
                    continue

                # Convert a single batch into a batch list
                if not isinstance(batch, list):
                    batch = [batch]

                total_step_num = epoch_num * self.n_batches_per_epoch + step_num
                batch_size = 0

                for _batch in batch:
                    batch_size += len(_batch.uids)
                    # Perform the forward pass and calculate the loss and count
                    uid_dict, loss_dict, prob_dict, gold_dict = model(
                        _batch.uids,
                        _batch.X_dict,
                        _batch.Y_dict,
                        _batch.task_to_label_dict,
                        return_probs=Meta.config["learner_config"]
                        ["online_eval"],
                        return_action_outputs=False,
                    )

                    # Update running loss and count
                    for task_name in uid_dict.keys():
                        identifier = f"{task_name}/{_batch.data_name}/{_batch.split}"
                        self.running_uids[identifier].extend(
                            uid_dict[task_name])
                        self.running_losses[identifier] += (
                            loss_dict[task_name].item() *
                            len(uid_dict[task_name])
                            if len(loss_dict[task_name].size()) == 0 else
                            torch.sum(loss_dict[task_name]).item()
                        ) * model.task_weights[task_name]
                        if (Meta.config["learner_config"]["online_eval"]
                                and prob_dict and gold_dict):
                            self.running_probs[identifier].extend(
                                prob_dict[task_name])
                            self.running_golds[identifier].extend(
                                gold_dict[task_name])

                    # Calculate the average loss
                    loss = sum([
                        model.task_weights[task_name] *
                        task_loss if len(task_loss.size()) == 0 else
                        torch.mean(model.task_weights[task_name] * task_loss)
                        for task_name, task_loss in loss_dict.items()
                    ])

                    # Perform backward pass to calculate gradients
                    if Meta.config["learner_config"]["fp16"]:
                        with amp.scale_loss(loss,
                                            self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()  # type: ignore

                if (total_step_num +
                        1) % Meta.config["learner_config"]["optimizer_config"][
                            "gradient_accumulation_steps"] == 0 or (
                                step_num + 1 == self.end_step
                                and epoch_num + 1 == self.end_epoch):
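                    # Either a gradient accumulation window has completed or this is
                    # the final step, so apply the accumulated gradients.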
                    # Clip gradient norm
                    if Meta.config["learner_config"]["optimizer_config"][
                            "grad_clip"]:
                        if Meta.config["learner_config"]["fp16"]:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                Meta.config["learner_config"]
                                ["optimizer_config"]["grad_clip"],
                            )
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(),
                                Meta.config["learner_config"]
                                ["optimizer_config"]["grad_clip"],
                            )

                    # Update the parameters
                    self.optimizer.step()

                    # Set gradients of all model parameters to zero
                    self.optimizer.zero_grad()

                if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
                    self.metrics.update(
                        self._logging(model, dataloaders, batch_size))

                    step_pbar.set_postfix(self.metrics)

                # Update lr using lr scheduler
                self._update_lr_scheduler(model, total_step_num, self.metrics)
            step_pbar.close()

        if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
            model = self.logging_manager.close(model)
        logger.info(
            f"Total learning time: {time.time() - start_time} seconds.")
Пример #20
0
    )

    dataloaders = []
    for task_name in args.task:
        dataloaders += create_dataloaders(
            task_name, datasets[task_name], args.batch_size, emb_layer.word2id
        )

    tasks = {
        task_name: create_task(
            task_name, args, datasets[task_name]["nclasses"], emb_layer
        )
        for task_name in args.task
    }

    model = EmmentalModel(name="TC_task")

    if Meta.config["model_config"]["model_path"]:
        model.load(Meta.config["model_config"]["model_path"])
    else:
        for task_name, task in tasks.items():
            model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)
    logger.info(f"Metrics: {scores}")
    write_to_json_file(f"{Meta.log_path}/metrics.txt", scores)

    if args.checkpointing:
Пример #21
0
    def _set_optimizer(self, model: EmmentalModel) -> None:
        """Set optimizer for learning process.

        Args:
          model: The model to set up the optimizer.
        """
        optimizer_config = Meta.config["learner_config"]["optimizer_config"]
        opt = optimizer_config["optimizer"]

        # If Meta.config["learner_config"]["optimizer_config"]["parameters"] is None,
        # create a parameter group with all parameters in the model, else load user
        # specified parameter groups.
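        # A user-specified value is expected to be a callable that takes the model
        # and returns optimizer parameter groups, e.g. (illustrative sketch):
        #   lambda model: [{"params": [p for p in model.parameters() if p.requires_grad]}]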
        if optimizer_config["parameters"] is None:
            parameters = filter(lambda p: p.requires_grad, model.parameters())
        else:
            parameters = optimizer_config["parameters"](model)

        optim_dict = {
            # PyTorch optimizer
            "asgd": optim.ASGD,
            "adadelta": optim.Adadelta,
            "adagrad": optim.Adagrad,
            "adam": optim.Adam,
            "adamw": optim.AdamW,
            "adamax": optim.Adamax,
            "lbfgs": optim.LBFGS,
            "rms_prop": optim.RMSprop,
            "r_prop": optim.Rprop,
            "sgd": optim.SGD,
            "sparse_adam": optim.SparseAdam,
            # Customized optimizer
            "bert_adam": BertAdam,
        }

        if opt in ["lbfgs", "r_prop", "sparse_adam"]:
            optimizer = optim_dict[opt](
                parameters,
                lr=optimizer_config["lr"],
                **optimizer_config[f"{opt}_config"],
            )
        elif opt in optim_dict.keys():
            optimizer = optim_dict[opt](
                parameters,
                lr=optimizer_config["lr"],
                weight_decay=optimizer_config["l2"],
                **optimizer_config[f"{opt}_config"],
            )
        elif (isinstance(opt, type) and issubclass(opt, optim.Optimizer)) or (
                isinstance(opt, partial)
                and issubclass(opt.func, optim.Optimizer)  # type: ignore
        ):
            optimizer = opt(parameters)  # type: ignore
        else:
            raise ValueError(f"Unrecognized optimizer option '{opt}'")

        self.optimizer = optimizer

        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Using optimizer {self.optimizer}")

        if Meta.config["learner_config"]["optimizer_path"]:
            try:
                self.optimizer.load_state_dict(
                    torch.load(
                        Meta.config["learner_config"]["optimizer_path"],
                        map_location=torch.device("cpu"),
                    )["optimizer"])
                logger.info(
                    f"Optimizer state loaded from "
                    f"{Meta.config['learner_config']['optimizer_path']}")
            except BaseException:
                logger.error(
                    f"Loading failed... Cannot load optimizer state from "
                    f"{Meta.config['learner_config']['optimizer_path']}, "
                    f"continuing anyway.")
Пример #22
0
def main(args):
    # Initialize Emmental
    config = parse_args_to_config(args)
    emmental.init(log_dir=config["meta_config"]["log_path"], config=config)

    # Log configuration into files
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(f"{emmental.Meta.log_path}/cmd.txt", cmd_msg)

    logger.info(f"Config: {emmental.Meta.config}")
    write_to_file(f"{emmental.Meta.log_path}/config.txt", emmental.Meta.config)

    # Create dataloaders
    dataloaders = get_dataloaders(args)

    config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = AugScheduler(augment_k=args.augment_k,
                                         enlarge=args.augment_enlarge)
    emmental.Meta.config["learner_config"]["task_scheduler_config"][
        "task_scheduler"] = config["learner_config"]["task_scheduler_config"][
            "task_scheduler"]

    # Specify parameter group for Adam BERT
    def grouped_parameters(model):
        no_decay = ["bias", "LayerNorm.weight"]
        return [
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                emmental.Meta.config["learner_config"]["optimizer_config"]
                ["l2"],
            },
            {
                "params": [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]

    emmental.Meta.config["learner_config"]["optimizer_config"][
        "parameters"] = grouped_parameters

    # Create tasks
    model = EmmentalModel(name=f"{args.task}_task")
    model.add_task(create_task(args))

    # Load the best model from the pretrained model
    if config["model_config"]["model_path"] is not None:
        model.load(config["model_config"]["model_path"])

    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # Remove all extra augmentation policy
    for idx in range(len(dataloaders)):
        dataloaders[idx].dataset.transform_cls = None
        dataloaders[idx].dataset.k = 1

    scores = model.score(dataloaders)

    # Save metrics and models
    logger.info(f"Metrics: {scores}")
    scores["log_path"] = emmental.Meta.log_path
    write_to_json_file(f"{emmental.Meta.log_path}/metrics.txt", scores)
    model.save(f"{emmental.Meta.log_path}/last_model.pth")
Пример #23
0
    for task_name in args.task:
        for split in ["train", "dev", "test"]:
            dataloaders.append(
                EmmentalDataLoader(
                    task_to_label_dict={task_name: "labels"},
                    dataset=datasets[task_name][split],
                    split=split,
                    batch_size=args.batch_size,
                    shuffle=True if split == "train" else False,
                )
            )
            logger.info(f"Built dataloader for {task_name} {split} set.")

    tasks = get_gule_task(args.task, args.bert_model)

    mtl_model = EmmentalModel(name="GLUE_multi_task")

    if Meta.config["model_config"]["model_path"]:
        mtl_model.load(Meta.config["model_config"]["model_path"])
    else:
        for task_name, task in tasks.items():
            mtl_model.add_task(task)

    emmental_learner = EmmentalLearner()

    emmental_learner.learn(mtl_model, dataloaders)

    scores = mtl_model.score(dataloaders)
    logger.info(f"Metrics: {scores}")
    write_to_file("metrics.txt", scores)
    logger.info(
Пример #24
0
for split in splits:
    dataloaders.append(
        EmmentalDataLoader(
            task_to_label_dict={"ht_page": "label"},
            dataset=datasets[split],
            split=split,
            batch_size=16,
            shuffle=False,
        ))
    print(f"Built dataloader for {split} set.")

# Getting tasks
tasks = get_task(task_names, config['embed_dim'], char_dict_size)

# Build Emmental model
model = EmmentalModel(name="HT", tasks=tasks)

if Meta.config["model_config"]["model_path"]:
    print('Loading model...')
    model.load(Meta.config["model_config"]["model_path"])

# Scoring
import torch
print("Running prediction model...")
sft = torch.nn.Softmax(dim=0)
res = model.predict(dataloaders[0], return_preds=True, return_uids=True)
doc_extractions = {}
doc_extractions = {
    res['uids'][task_names[0]][ii]: {
        'prediction':
        str(np.array(sft(torch.Tensor(res['probs'][task_names[0]][ii])))[0])
    }
    # The source snippet is truncated here; the comprehension is closed with an
    # assumed loop over all returned uids so that the example parses.
    for ii in range(len(res['uids'][task_names[0]]))
}
Example #25
0
    def learn(
        self, model: EmmentalModel, dataloaders: List[EmmentalDataLoader]
    ) -> None:
        r"""The learning procedure of emmental MTL.

        Args:
          model(EmmentalModel): The emmental model that needs to learn.
          dataloaders(List[EmmentalDataLoader]): a list of dataloaders used to
            learn the model.

        """

        # Generate the list of dataloaders for learning process
        train_split = Meta.config["learner_config"]["train_split"]
        if isinstance(train_split, str):
            train_split = [train_split]

        train_dataloaders = [
            dataloader for dataloader in dataloaders if dataloader.split in train_split
        ]

        if not train_dataloaders:
            raise ValueError(
                f"Cannot find the specified train_split "
                f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
            )

        # Set up task_scheduler
        self._set_task_scheduler()

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch = self.task_scheduler.get_num_batches(
            train_dataloaders
        )

        # Set up logging manager
        self._set_logging_manager()
        # Set up optimizer
        self._set_optimizer(model)
        # Set up lr_scheduler
        self._set_lr_scheduler(model)

        # Set to training mode
        model.train()

        if Meta.config["meta_config"]["verbose"]:
            logger.info(f"Start learning...")

        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        for epoch_num in range(Meta.config["learner_config"]["n_epochs"]):
            batches = tqdm(
                enumerate(self.task_scheduler.get_batches(train_dataloaders, model)),
                total=self.n_batches_per_epoch,
                disable=(not Meta.config["meta_config"]["verbose"]),
                desc=f"Epoch {epoch_num}:",
            )

            for batch_num, batch in batches:

                # Convert a single batch into a batch list
                if not isinstance(batch, list):
                    batch = [batch]

                total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
                batch_size = 0

                # Set gradients of all model parameters to zero
                self.optimizer.zero_grad()

                for uids, X_dict, Y_dict, task_to_label_dict, data_name, split in batch:
                    batch_size += len(next(iter(Y_dict.values())))

                    # Perform forward pass and calculate the loss and count
                    uid_dict, loss_dict, prob_dict, gold_dict = model(
                        uids, X_dict, Y_dict, task_to_label_dict
                    )

                    # Update running loss and count
                    for task_name in uid_dict.keys():
                        identifier = f"{task_name}/{data_name}/{split}"
                        self.running_uids[identifier].extend(uid_dict[task_name])
                        self.running_losses[identifier] += (
                            loss_dict[task_name].item() * len(uid_dict[task_name])
                            if len(loss_dict[task_name].size()) == 0
                            else torch.sum(loss_dict[task_name]).item()
                        )
                        self.running_probs[identifier].extend(prob_dict[task_name])
                        self.running_golds[identifier].extend(gold_dict[task_name])

                    # Skip the backward pass if no loss is calculated
                    if not loss_dict:
                        continue

                    # Calculate the average loss
                    loss = sum(
                        [
                            model.weights[task_name] * task_loss
                            if len(task_loss.size()) == 0
                            else torch.mean(model.weights[task_name] * task_loss)
                            for task_name, task_loss in loss_dict.items()
                        ]
                    )

                    # Perform backward pass to calculate gradients
                    loss.backward()  # type: ignore

                # Clip gradient norm
                if Meta.config["learner_config"]["optimizer_config"]["grad_clip"]:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(),
                        Meta.config["learner_config"]["optimizer_config"]["grad_clip"],
                    )

                # Update the parameters
                self.optimizer.step()

                self.metrics.update(self._logging(model, dataloaders, batch_size))

                batches.set_postfix(self.metrics)

                # Update lr using lr scheduler
                self._update_lr_scheduler(model, total_batch_num, self.metrics)

        model = self.logging_manager.close(model)
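
# Hedged illustration (not part of the original learner): what the
# "Calculate the average loss" step above computes. A scalar task loss is
# weighted directly, while a per-sample loss tensor is weighted and then
# averaged over the batch. The task names, weights, and loss values below are
# invented for demonstration only.
def _demo_weighted_multitask_loss():
    import torch

    weights = {"task_a": 1.0, "task_b": 0.5}
    loss_dict = {
        "task_a": torch.tensor(0.8),              # already reduced to a scalar
        "task_b": torch.tensor([0.2, 0.6, 0.4]),  # per-sample losses
    }
    loss = sum(
        weights[name] * task_loss
        if len(task_loss.size()) == 0
        else torch.mean(weights[name] * task_loss)
        for name, task_loss in loss_dict.items()
    )
    # 1.0 * 0.8 + 0.5 * mean([0.2, 0.6, 0.4]) == 1.0
    return loss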
Example #26
0
def main(
    conn_string,
    gain=False,
    current=False,
    max_docs=float("inf"),
    parse=False,
    first_time=False,
    re_label=False,
    parallel=8,
    log_dir="logs",
    verbose=False,
):
    # Setup initial configuration
    if not log_dir:
        log_dir = "logs"

    if verbose:
        level = logging.INFO
    else:
        level = logging.WARNING

    dirname = os.path.dirname(os.path.abspath(__file__))
    init_logging(log_dir=os.path.join(dirname, log_dir), level=level)

    rel_list = []
    if gain:
        rel_list.append("gain")

    if current:
        rel_list.append("current")

    logger.info(f"=" * 30)
    logger.info(f"Running with parallel: {parallel}, max_docs: {max_docs}")

    session = Meta.init(conn_string).Session()

    # Parsing
    start = timer()
    logger.info(f"Starting parsing...")
    docs, train_docs, dev_docs, test_docs = parse_dataset(session,
                                                          dirname,
                                                          first_time=parse,
                                                          parallel=parallel,
                                                          max_docs=max_docs)
    logger.debug(f"Done")
    end = timer()
    logger.warning(f"Parse Time (min): {((end - start) / 60.0):.1f}")

    logger.info(f"# of Documents: {len(docs)}")
    logger.info(f"# of train Documents: {len(train_docs)}")
    logger.info(f"# of dev Documents: {len(dev_docs)}")
    logger.info(f"# of test Documents: {len(test_docs)}")
    logger.info(f"Documents: {session.query(Document).count()}")
    logger.info(f"Sections: {session.query(Section).count()}")
    logger.info(f"Paragraphs: {session.query(Paragraph).count()}")
    logger.info(f"Sentences: {session.query(Sentence).count()}")
    logger.info(f"Figures: {session.query(Figure).count()}")

    # Mention Extraction
    start = timer()
    mentions = []
    ngrams = []
    matchers = []

    # Only do those that are enabled
    if gain:
        Gain = mention_subclass("Gain")
        gain_matcher = get_gain_matcher()
        gain_ngrams = MentionNgrams(n_max=2)
        mentions.append(Gain)
        ngrams.append(gain_ngrams)
        matchers.append(gain_matcher)

    if current:
        Current = mention_subclass("SupplyCurrent")
        current_matcher = get_supply_current_matcher()
        current_ngrams = MentionNgramsCurrent(n_max=3)
        mentions.append(Current)
        ngrams.append(current_ngrams)
        matchers.append(current_matcher)

    mention_extractor = MentionExtractor(session, mentions, ngrams, matchers)

    if first_time:
        mention_extractor.apply(docs, parallelism=parallel)

    logger.info(f"Total Mentions: {session.query(Mention).count()}")

    if gain:
        logger.info(f"Total Gain: {session.query(Gain).count()}")

    if current:
        logger.info(f"Total Current: {session.query(Current).count()}")

    cand_classes = []
    if gain:
        GainCand = candidate_subclass("GainCand", [Gain])
        cand_classes.append(GainCand)
    if current:
        CurrentCand = candidate_subclass("CurrentCand", [Current])
        cand_classes.append(CurrentCand)

    candidate_extractor = CandidateExtractor(session, cand_classes)

    if first_time:
        for i, docs in enumerate([train_docs, dev_docs, test_docs]):
            candidate_extractor.apply(docs, split=i, parallelism=parallel)

    # These must be sorted for deterministic behavior.
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    logger.info(
        f"Total train candidate: {len(train_cands[0]) + len(train_cands[1])}")
    logger.info(
        f"Total dev candidate: {len(dev_cands[0]) + len(dev_cands[1])}")
    logger.info(
        f"Total test candidate: {len(test_cands[0]) + len(test_cands[1])}")

    logger.info("Done w/ candidate extraction.")
    end = timer()
    logger.warning(f"CE Time (min): {((end - start) / 60.0):.1f}")

    # First, check total recall
    #  result = entity_level_scores(
    #      candidates_to_entities(dev_cands[0], is_gain=True),
    #      corpus=dev_docs,
    #      is_gain=True,
    #  )
    #  logger.info(f"Gain Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(
    #      candidates_to_entities(test_cands[0], is_gain=True),
    #      corpus=test_docs,
    #      is_gain=True,
    #  )
    #  logger.info(f"Gain Total Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #
    #  result = entity_level_scores(
    #      candidates_to_entities(dev_cands[1], is_gain=False),
    #      corpus=dev_docs,
    #      is_gain=False,
    #  )
    #  logger.info(f"Current Total Dev Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")
    #  result = entity_level_scores(
    #      candidates_to_entities(test_cands[1], is_gain=False),
    #      corpus=test_docs,
    #      is_gain=False,
    #  )
    #  logger.info(f"Current Test Recall: {result.rec:.3f}")
    #  logger.info(f"\n{pformat(result.FN)}")

    start = timer()

    # Using parallelism = 1 for deterministic behavior.
    featurizer = Featurizer(session, cand_classes, parallelism=1)

    if first_time:
        logger.info("Starting featurizer...")
        # Set feature space based on dev set, which we use for training rather
        # than the large train set.
        featurizer.apply(split=1, train=True)
        featurizer.apply(split=0)
        featurizer.apply(split=2)
        logger.info("Done")

    logger.info("Getting feature matrices...")
    # Serialize feature matrices on first run
    if first_time:
        F_train = featurizer.get_feature_matrices(train_cands)
        F_dev = featurizer.get_feature_matrices(dev_cands)
        F_test = featurizer.get_feature_matrices(test_cands)
        end = timer()
        logger.warning(
            f"Featurization Time (min): {((end - start) / 60.0):.1f}")

        F_train_dict = {}
        F_dev_dict = {}
        F_test_dict = {}
        for idx, relation in enumerate(rel_list):
            F_train_dict[relation] = F_train[idx]
            F_dev_dict[relation] = F_dev[idx]
            F_test_dict[relation] = F_test[idx]

        pickle.dump(F_train_dict,
                    open(os.path.join(dirname, "F_train_dict.pkl"), "wb"))
        pickle.dump(F_dev_dict,
                    open(os.path.join(dirname, "F_dev_dict.pkl"), "wb"))
        pickle.dump(F_test_dict,
                    open(os.path.join(dirname, "F_test_dict.pkl"), "wb"))
    else:
        F_train_dict = pickle.load(
            open(os.path.join(dirname, "F_train_dict.pkl"), "rb"))
        F_dev_dict = pickle.load(
            open(os.path.join(dirname, "F_dev_dict.pkl"), "rb"))
        F_test_dict = pickle.load(
            open(os.path.join(dirname, "F_test_dict.pkl"), "rb"))

        F_train = []
        F_dev = []
        F_test = []
        for relation in rel_list:
            F_train.append(F_train_dict[relation])
            F_dev.append(F_dev_dict[relation])
            F_test.append(F_test_dict[relation])

    logger.info("Done.")

    start = timer()
    logger.info("Labeling training data...")
    #  labeler = Labeler(session, cand_classes)
    #  lfs = []
    #  if gain:
    #      lfs.append(gain_lfs)
    #
    #  if current:
    #      lfs.append(current_lfs)
    #
    #  if first_time:
    #      logger.info("Applying LFs...")
    #      labeler.apply(split=0, lfs=lfs, train=True, parallelism=parallel)
    #  elif re_label:
    #      logger.info("Re-applying LFs...")
    #      labeler.update(split=0, lfs=lfs, parallelism=parallel)
    #
    #  logger.info("Done...")

    #  logger.info("Getting label matrices...")
    #  L_train = labeler.get_label_matrices(train_cands)
    #  logger.info("Done...")

    if first_time:
        marginals_dict = {}
        for idx, relation in enumerate(rel_list):
            # Manually create marginals from human annotations
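            # Each dev candidate gets a one-hot marginal: [0.0, 1.0] if it
            # matches a gold entity, [1.0, 0.0] otherwise.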
            marginal = []
            dev_gold_entities = get_gold_set(is_gain=(relation == "gain"))
            for c in dev_cands[idx]:
                flag = False
                for entity in cand_to_entity(c, is_gain=(relation == "gain")):
                    if entity in dev_gold_entities:
                        flag = True

                if flag:
                    marginal.append([0.0, 1.0])
                else:
                    marginal.append([1.0, 0.0])

            marginals_dict[relation] = np.array(marginal)

        pickle.dump(marginals_dict,
                    open(os.path.join(dirname, "marginals_dict.pkl"), "wb"))
    else:
        marginals_dict = pickle.load(
            open(os.path.join(dirname, "marginals_dict.pkl"), "rb"))

    marginals = []
    for relation in rel_list:
        marginals.append(marginals_dict[relation])

    end = timer()
    logger.warning(
        f"Weak Supervision Time (min): {((end - start) / 60.0):.1f}")

    start = timer()

    word_counter = collect_word_counter(train_cands)

    # Training config
    config = {
        "meta_config": {
            "verbose": True,
            "seed": 30
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 500,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.005
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    "model/all/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }

    emmental.init(log_dir=Meta.log_path, config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]
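    # For arity 2 this yields ["~~[[0", "0]]~~", "~~[[1", "1]]~~"]: markers that
    # delimit each mention span in a candidate's token sequence.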

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)
    train_idxs = []
    train_dataloader = []
    for idx, relation in enumerate(rel_list):
        diffs = marginals[idx].max(axis=1) - marginals[idx].min(axis=1)
        train_idxs.append(np.where(diffs > 1e-6)[0])
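        # Keep only candidates with non-uniform marginals, i.e. those that
        # actually received a label above.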

        # Only use the dev set (with human annotations) as training data
        train_dataloader.append(
            EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(
                    relation,
                    dev_cands[idx],
                    F_dev[idx],
                    emb_layer.word2id,
                    marginals[idx],
                    train_idxs[idx],
                ),
                split="train",
                batch_size=256,
                shuffle=True,
            ))

    num_feature_keys = len(featurizer.get_keys())

    model = EmmentalModel(name=f"opamp_tasks")

    # Create one task per relation: task names, arities, number of features,
    # class counts, embedding layer, and model type
    tasks = create_task(
        rel_list,
        [2] * len(rel_list),
        num_feature_keys,
        [2] * len(rel_list),
        emb_layer,
        model="LogisticRegression",
    )

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()

    # Given a list of dataloaders (one per relation), train all tasks jointly
    emmental_learner.learn(model, train_dataloader)

    # Build a test dataloader for each relation and evaluate it
    for idx, relation in enumerate(rel_list):
        test_dataloader = EmmentalDataLoader(
            task_to_label_dict={relation: "labels"},
            dataset=FonduerDataset(relation, test_cands[idx], F_test[idx],
                                   emb_layer.word2id, 2),
            split="test",
            batch_size=256,
            shuffle=False,
        )

        test_preds = model.predict(test_dataloader, return_preds=True)

        best_result, best_b = scoring(
            test_preds,
            test_cands[idx],
            test_docs,
            is_gain=(relation == "gain"),
            num=100,
        )

        # Dump CSV files for analysis
        if relation == "gain":
            train_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, train_cands[idx],
                                       F_train[idx], emb_layer.word2id, 2),
                split="train",
                batch_size=256,
                shuffle=False,
            )

            train_preds = model.predict(train_dataloader, return_preds=True)
            Y_prob = np.array(train_preds["probs"][relation])[:, TRUE]
            output_csv(train_cands[idx], Y_prob, is_gain=True)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            output_csv(test_cands[idx], Y_prob, is_gain=True, append=True)
            dump_candidates(test_cands[idx],
                            Y_prob,
                            "gain_test_probs.csv",
                            is_gain=True)

            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=256,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            output_csv(dev_cands[idx], Y_prob, is_gain=True, append=True)
            dump_candidates(dev_cands[idx],
                            Y_prob,
                            "gain_dev_probs.csv",
                            is_gain=True)

        if relation == "current":
            train_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, train_cands[idx],
                                       F_train[idx], emb_layer.word2id, 2),
                split="train",
                batch_size=256,
                shuffle=False,
            )

            train_preds = model.predict(train_dataloader, return_preds=True)
            Y_prob = np.array(train_preds["probs"][relation])[:, TRUE]
            output_csv(train_cands[idx], Y_prob, is_gain=False)

            Y_prob = np.array(test_preds["probs"][relation])[:, TRUE]
            output_csv(test_cands[idx], Y_prob, is_gain=False, append=True)
            dump_candidates(test_cands[idx],
                            Y_prob,
                            "current_test_probs.csv",
                            is_gain=False)

            dev_dataloader = EmmentalDataLoader(
                task_to_label_dict={relation: "labels"},
                dataset=FonduerDataset(relation, dev_cands[idx], F_dev[idx],
                                       emb_layer.word2id, 2),
                split="dev",
                batch_size=256,
                shuffle=False,
            )

            dev_preds = model.predict(dev_dataloader, return_preds=True)

            Y_prob = np.array(dev_preds["probs"][relation])[:, TRUE]
            output_csv(dev_cands[idx], Y_prob, is_gain=False, append=True)
            dump_candidates(dev_cands[idx],
                            Y_prob,
                            "current_dev_probs.csv",
                            is_gain=False)

    end = timer()
    logger.warning(
        f"Classification AND dump data Time (min): {((end - start) / 60.0):.1f}"
    )
Example #27
0
def test_e2e():
    """Run an end-to-end test on documents of the hardware domain."""
    # GitHub Actions gives 2 cores
    # help.github.com/en/actions/reference/virtual-environments-for-github-hosted-runners
    PARALLEL = 2

    max_docs = 12

    fonduer.init_logging(
        format="[%(asctime)s][%(levelname)s] %(name)s:%(lineno)s - %(message)s",
        level=logging.INFO,
    )

    session = fonduer.Meta.init(CONN_STRING).Session()

    docs_path = "tests/data/html/"
    pdf_path = "tests/data/pdf/"

    doc_preprocessor = HTMLDocPreprocessor(docs_path, max_docs=max_docs)

    corpus_parser = Parser(
        session,
        parallelism=PARALLEL,
        structural=True,
        lingual=True,
        visual=True,
        pdf_path=pdf_path,
    )
    corpus_parser.apply(doc_preprocessor)
    assert session.query(Document).count() == max_docs

    num_docs = session.query(Document).count()
    logger.info(f"Docs: {num_docs}")
    assert num_docs == max_docs

    num_sentences = session.query(Sentence).count()
    logger.info(f"Sentences: {num_sentences}")

    # Divide into test and train
    docs = sorted(corpus_parser.get_documents())
    last_docs = sorted(corpus_parser.get_last_documents())

    ld = len(docs)
    assert ld == len(last_docs)
    assert len(docs[0].sentences) == len(last_docs[0].sentences)

    assert len(docs[0].sentences) == 799
    assert len(docs[1].sentences) == 663
    assert len(docs[2].sentences) == 784
    assert len(docs[3].sentences) == 661
    assert len(docs[4].sentences) == 513
    assert len(docs[5].sentences) == 700
    assert len(docs[6].sentences) == 528
    assert len(docs[7].sentences) == 161
    assert len(docs[8].sentences) == 228
    assert len(docs[9].sentences) == 511
    assert len(docs[10].sentences) == 331
    assert len(docs[11].sentences) == 528

    # Check table numbers
    assert len(docs[0].tables) == 9
    assert len(docs[1].tables) == 9
    assert len(docs[2].tables) == 14
    assert len(docs[3].tables) == 11
    assert len(docs[4].tables) == 11
    assert len(docs[5].tables) == 10
    assert len(docs[6].tables) == 10
    assert len(docs[7].tables) == 2
    assert len(docs[8].tables) == 7
    assert len(docs[9].tables) == 10
    assert len(docs[10].tables) == 6
    assert len(docs[11].tables) == 9

    # Check figure numbers
    assert len(docs[0].figures) == 32
    assert len(docs[1].figures) == 11
    assert len(docs[2].figures) == 38
    assert len(docs[3].figures) == 31
    assert len(docs[4].figures) == 7
    assert len(docs[5].figures) == 38
    assert len(docs[6].figures) == 10
    assert len(docs[7].figures) == 31
    assert len(docs[8].figures) == 4
    assert len(docs[9].figures) == 27
    assert len(docs[10].figures) == 5
    assert len(docs[11].figures) == 27

    # Check caption numbers
    assert len(docs[0].captions) == 0
    assert len(docs[1].captions) == 0
    assert len(docs[2].captions) == 0
    assert len(docs[3].captions) == 0
    assert len(docs[4].captions) == 0
    assert len(docs[5].captions) == 0
    assert len(docs[6].captions) == 0
    assert len(docs[7].captions) == 0
    assert len(docs[8].captions) == 0
    assert len(docs[9].captions) == 0
    assert len(docs[10].captions) == 0
    assert len(docs[11].captions) == 0

    train_docs = set()
    dev_docs = set()
    test_docs = set()
    splits = (0.5, 0.75)
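    # 50% of documents go to train, the next 25% to dev, and the rest to test,
    # assigned deterministically by sorted document name.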
    data = [(doc.name, doc) for doc in docs]
    data.sort(key=lambda x: x[0])
    for i, (doc_name, doc) in enumerate(data):
        if i < splits[0] * ld:
            train_docs.add(doc)
        elif i < splits[1] * ld:
            dev_docs.add(doc)
        else:
            test_docs.add(doc)
    logger.info([x.name for x in train_docs])

    # NOTE: With multi-relation support, return values of getting candidates,
    # mentions, or sparse matrices are formatted as a list of lists. This means
    # that with a single relation, we need to index into the list of lists to
    # get the candidates/mentions/sparse matrix for a particular relation or
    # mention.
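    # For example, with the two candidate classes defined below (PartTemp and
    # PartVolt), get_candidates(split=0) returns a list of two lists:
    # train_cands[0] holds PartTemp candidates and train_cands[1] holds
    # PartVolt candidates.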

    # Mention Extraction
    part_ngrams = MentionNgramsPart(parts_by_doc=None, n_max=3)
    temp_ngrams = MentionNgramsTemp(n_max=2)
    volt_ngrams = MentionNgramsVolt(n_max=1)

    Part = mention_subclass("Part")
    Temp = mention_subclass("Temp")
    Volt = mention_subclass("Volt")

    mention_extractor = MentionExtractor(
        session,
        [Part, Temp, Volt],
        [part_ngrams, temp_ngrams, volt_ngrams],
        [part_matcher, temp_matcher, volt_matcher],
    )

    mention_extractor.apply(docs, parallelism=PARALLEL)

    assert session.query(Part).count() == 299
    assert session.query(Temp).count() == 138
    assert session.query(Volt).count() == 140
    assert len(mention_extractor.get_mentions()) == 3
    assert len(mention_extractor.get_mentions()[0]) == 299
    assert (len(
        mention_extractor.get_mentions(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 70)

    # Candidate Extraction
    PartTemp = candidate_subclass("PartTemp", [Part, Temp])
    PartVolt = candidate_subclass("PartVolt", [Part, Volt])

    candidate_extractor = CandidateExtractor(
        session, [PartTemp, PartVolt],
        throttlers=[temp_throttler, volt_throttler])

    for i, docs in enumerate([train_docs, dev_docs, test_docs]):
        candidate_extractor.apply(docs, split=i, parallelism=PARALLEL)

    assert session.query(PartTemp).filter(PartTemp.split == 0).count() == 3493
    assert session.query(PartTemp).filter(PartTemp.split == 1).count() == 61
    assert session.query(PartTemp).filter(PartTemp.split == 2).count() == 416
    assert session.query(PartVolt).count() == 4282

    # Grab candidate lists
    train_cands = candidate_extractor.get_candidates(split=0, sort=True)
    dev_cands = candidate_extractor.get_candidates(split=1, sort=True)
    test_cands = candidate_extractor.get_candidates(split=2, sort=True)
    assert len(train_cands) == 2
    assert len(train_cands[0]) == 3493
    assert (len(
        candidate_extractor.get_candidates(docs=[
            session.query(Document).filter(Document.name == "112823").first()
        ])[0]) == 1432)

    # Featurization
    featurizer = Featurizer(session, [PartTemp, PartVolt])

    # Test that FeatureKey is properly reset
    featurizer.apply(split=1, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 214
    assert session.query(FeatureKey).count() == 1260

    # Test Dropping FeatureKey
    # Should force a row deletion
    featurizer.drop_keys(["DDL_e1_W_LEFT_POS_3_[NNP NN IN]"])
    assert session.query(FeatureKey).count() == 1259

    # Should only remove the part_volt as a relation and leave part_temp
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])
    assert session.query(FeatureKey).filter(
        FeatureKey.name ==
        "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes == ["part_temp"]
    assert session.query(FeatureKey).count() == 1259

    # Inserting the removed key
    featurizer.upsert_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                           candidate_classes=[PartTemp, PartVolt])
    assert set(
        session.query(FeatureKey).filter(
            FeatureKey.name ==
            "DDL_e1_LEMMA_SEQ_[bc182]").one().candidate_classes) == {
                "part_temp", "part_volt"
            }
    assert session.query(FeatureKey).count() == 1259
    # Removing the key again
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartVolt])

    # Removing the last relation from a key should delete the row
    featurizer.drop_keys(["DDL_e1_LEMMA_SEQ_[bc182]"],
                         candidate_classes=[PartTemp])
    assert session.query(FeatureKey).count() == 1258
    session.query(Feature).delete(synchronize_session="fetch")
    session.query(FeatureKey).delete(synchronize_session="fetch")

    featurizer.apply(split=0, train=True, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6478
    assert session.query(FeatureKey).count() == 4538
    F_train = featurizer.get_feature_matrices(train_cands)
    assert F_train[0].shape == (3493, 4538)
    assert F_train[1].shape == (2985, 4538)
    assert len(featurizer.get_keys()) == 4538

    featurizer.apply(split=1, parallelism=PARALLEL)
    assert session.query(Feature).count() == 6692
    assert session.query(FeatureKey).count() == 4538
    F_dev = featurizer.get_feature_matrices(dev_cands)
    assert F_dev[0].shape == (61, 4538)
    assert F_dev[1].shape == (153, 4538)

    featurizer.apply(split=2, parallelism=PARALLEL)
    assert session.query(Feature).count() == 8252
    assert session.query(FeatureKey).count() == 4538
    F_test = featurizer.get_feature_matrices(test_cands)
    assert F_test[0].shape == (416, 4538)
    assert F_test[1].shape == (1144, 4538)

    gold_file = "tests/data/hardware_tutorial_gold.csv"

    labeler = Labeler(session, [PartTemp, PartVolt])

    labeler.apply(
        docs=last_docs,
        lfs=[[gold], [gold]],
        table=GoldLabel,
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(GoldLabel).count() == 8252

    stg_temp_lfs = [
        LF_storage_row,
        LF_operating_row,
        LF_temperature_row,
        LF_tstg_row,
        LF_to_left,
        LF_negative_number_left,
    ]

    ce_v_max_lfs = [
        LF_bad_keywords_in_row,
        LF_current_in_row,
        LF_non_ce_voltages_in_row,
    ]

    with pytest.raises(ValueError):
        labeler.apply(split=0,
                      lfs=stg_temp_lfs,
                      train=True,
                      parallelism=PARALLEL)

    labeler.apply(
        docs=train_docs,
        lfs=[stg_temp_lfs, ce_v_max_lfs],
        train=True,
        parallelism=PARALLEL,
    )
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 9
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 9)
    assert L_train[1].shape == (2985, 9)
    assert len(labeler.get_keys()) == 9

    # Test Dropping LabelerKey
    labeler.drop_keys(["LF_storage_row"])
    assert len(labeler.get_keys()) == 8

    # Test Upserting LabelerKey
    labeler.upsert_keys(["LF_storage_row"])
    assert "LF_storage_row" in [label.name for label in labeler.get_keys()]

    L_train_gold = labeler.get_gold_labels(train_cands)
    assert L_train_gold[0].shape == (3493, 1)

    L_train_gold = labeler.get_gold_labels(train_cands, annotator="gold")
    assert L_train_gold[0].shape == (3493, 1)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    # Collect word counter
    word_counter = collect_word_counter(train_cands)

    emmental.init(fonduer.Meta.log_path)

    # Training config
    config = {
        "meta_config": {
            "verbose": False
        },
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": False
        },
        "learner_config": {
            "n_epochs": 5,
            "optimizer_config": {
                "lr": 0.001,
                "l2": 0.0
            },
            "task_scheduler": "round_robin",
        },
        "logging_config": {
            "evaluation_freq": 1,
            "counter_unit": "epoch",
            "checkpointing": False,
            "checkpointer_config": {
                "checkpoint_metric": {
                    f"{ATTRIBUTE}/{ATTRIBUTE}/train/loss": "min"
                },
                "checkpoint_freq": 1,
                "checkpoint_runway": 2,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    }
    emmental.Meta.update_config(config=config)

    # Generate word embedding module
    arity = 2
    # Generate special tokens
    specials = []
    for i in range(arity):
        specials += [f"~~[[{i}", f"{i}]]~~"]

    emb_layer = EmbeddingModule(word_counter=word_counter,
                                word_dim=300,
                                specials=specials)

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(ATTRIBUTE, test_cands[0], F_test[0],
                               emb_layer.word2id, 2),
        split="test",
        batch_size=100,
        shuffle=False,
    )

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.6)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    pickle_file = "tests/data/parts_by_doc_dict.pkl"
    with open(pickle_file, "rb") as f:
        parts_by_doc = pickle.load(f)

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 < 0.7 and f1 > 0.3

    stg_temp_lfs_2 = [
        LF_to_left,
        LF_test_condition_aligned,
        LF_collector_aligned,
        LF_current_aligned,
        LF_voltage_row_temp,
        LF_voltage_row_part,
        LF_typ_row,
        LF_complement_left_row,
        LF_too_many_numbers_row,
        LF_temp_on_high_page_num,
        LF_temp_outside_table,
        LF_not_temp_relevant,
    ]
    labeler.update(split=0,
                   lfs=[stg_temp_lfs_2, ce_v_max_lfs],
                   parallelism=PARALLEL)
    assert session.query(Label).count() == 6478
    assert session.query(LabelKey).count() == 16
    L_train = labeler.get_label_matrices(train_cands)
    assert L_train[0].shape == (3493, 16)

    label_model = LabelModel()
    label_model.fit(L_train=L_train[0], n_epochs=500, log_freq=100)

    train_marginals = label_model.predict_proba(L_train[0])

    diffs = train_marginals.max(axis=1) - train_marginals.min(axis=1)
    train_idxs = np.where(diffs > 1e-6)[0]

    train_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            train_marginals,
            train_idxs,
        ),
        split="train",
        batch_size=100,
        shuffle=True,
    )

    valid_dataloader = EmmentalDataLoader(
        task_to_label_dict={ATTRIBUTE: "labels"},
        dataset=FonduerDataset(
            ATTRIBUTE,
            train_cands[0],
            F_train[0],
            emb_layer.word2id,
            np.argmax(train_marginals, axis=1),
            train_idxs,
        ),
        split="valid",
        batch_size=100,
        shuffle=False,
    )
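    # The valid split reuses the training candidates but with hard labels taken
    # as the argmax of the soft marginals.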

    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LogisticRegression")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader, valid_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7

    # Testing LSTM
    emmental.Meta.reset()
    emmental.init(fonduer.Meta.log_path)
    emmental.Meta.update_config(config=config)

    tasks = create_task(ATTRIBUTE,
                        2,
                        F_train[0].shape[1],
                        2,
                        emb_layer,
                        model="LSTM")

    model = EmmentalModel(name=f"{ATTRIBUTE}_task")

    for task in tasks:
        model.add_task(task)

    emmental_learner = EmmentalLearner()
    emmental_learner.learn(model, [train_dataloader])

    test_preds = model.predict(test_dataloader, return_preds=True)
    positive = np.where(
        np.array(test_preds["probs"][ATTRIBUTE])[:, TRUE] > 0.7)
    true_pred = [test_cands[0][_] for _ in positive[0]]

    (TP, FP, FN) = entity_level_f1(true_pred,
                                   gold_file,
                                   ATTRIBUTE,
                                   test_docs,
                                   parts_by_doc=parts_by_doc)

    tp_len = len(TP)
    fp_len = len(FP)
    fn_len = len(FN)
    prec = tp_len / (tp_len + fp_len) if tp_len + fp_len > 0 else float("nan")
    rec = tp_len / (tp_len + fn_len) if tp_len + fn_len > 0 else float("nan")
    f1 = 2 * (prec * rec) / (prec + rec) if prec + rec > 0 else float("nan")

    logger.info(f"prec: {prec}")
    logger.info(f"rec: {rec}")
    logger.info(f"f1: {f1}")

    assert f1 > 0.7
Example #28
0
            max_seq_length=args.max_seq_length,
        )
        logger.info(f"Loaded {split} containing {len(dataset)} samples.")
        dataloaders.append(
            EmmentalDataLoader(
                task_to_label_dict={args.task_name: "labels"},
                dataset=dataset,
                split=split,
                shuffle=True if split == "train" else False,
                batch_size=args.batch_size,
                # num_workers=8,
            ))
        logger.info(f"Built dataloader for {dataset.name} {split} set.")

    # Build Emmental model
    model = EmmentalModel(name=args.task_name, tasks=create_task(args))

    # Load the pre-trained model
    if Meta.config["model_config"]["model_path"]:
        model.load(Meta.config["model_config"]["model_path"])

    # Training
    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    scores = model.score(dataloaders)

    # Save metrics into file
    logger.info(f"Metrics: {scores}")
    write_to_json_file(f"{Meta.log_path}/metrics.txt", scores)
Example #29
0
def main(args):
    # Ensure that global state is fresh
    Meta.reset()

    # Initialize Emmental
    config = parse_arg_to_config(args)
    emmental.init(config["meta_config"]["log_path"], config=config)

    # Save command line argument into file
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(Meta.log_path, "cmd.txt", cmd_msg)

    # Save Emmental config into file
    logger.info(f"Config: {Meta.config}")
    write_to_file(Meta.log_path, "config.txt", Meta.config)

    Meta.config["learner_config"]["global_evaluation_metric_dict"] = {
        f"model/SuperGLUE/{split}/score": partial(superglue_scorer,
                                                  split=split)
        for split in ["val"]
    }

    # Construct dataloaders and tasks and load slices
    dataloaders = []
    tasks = []

    for task_name in args.task:
        task_dataloaders = get_dataloaders(
            data_dir=args.data_dir,
            task_name=task_name,
            splits=["train", "val", "test"],
            max_sequence_length=args.max_sequence_length,
            max_data_samples=args.max_data_samples,
            tokenizer_name=args.bert_model,
            batch_size=args.batch_size,
            augment=args.augmentations,
        )
        task = models.model[task_name](
            args.bert_model,
            last_hidden_dropout_prob=args.last_hidden_dropout_prob)
        if args.slices:
            logger.info("Initializing task-specific slices")
            slice_func_dict = slicing.slice_func_dict[task_name]
            # Include general purpose slices
            if args.general_slices:
                logger.info("Including general slices")
                slice_func_dict.update(slicing.slice_func_dict["general"])

            task_dataloaders = slicing.add_slice_labels(
                task_name, task_dataloaders, slice_func_dict)

            slice_tasks = slicing.add_slice_tasks(task_name, task,
                                                  slice_func_dict,
                                                  args.slice_hidden_dim)
            tasks.extend(slice_tasks)
        else:
            tasks.append(task)

        dataloaders.extend(task_dataloaders)

    # Build Emmental model
    model = EmmentalModel(name="SuperGLUE", tasks=tasks)

    # Load pretrained model if necessary
    if Meta.config["model_config"]["model_path"]:
        model.load(Meta.config["model_config"]["model_path"])

    # Training
    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # If model is slice-aware, slice scores will be calculated from slice heads
    # If model is not slice-aware, manually calculate performance on slices
    if not args.slices:
        slice_func_dict = {}
        slice_keys = args.task
        if args.general_slices:
            slice_keys.append("general")

        for k in slice_keys:
            slice_func_dict.update(slicing.slice_func_dict[k])

        scores = slicing.score_slices(model, dataloaders, args.task,
                                      slice_func_dict)
    else:
        scores = model.score(dataloaders)

    # Save metrics into file
    logger.info(f"Metrics: {scores}")
    write_to_file(Meta.log_path, "metrics.txt", scores)

    # Save best metrics into file
    if args.train:
        logger.info(
            f"Best metrics: "
            f"{emmental_learner.logging_manager.checkpointer.best_metric_dict}"
        )
        write_to_file(
            Meta.log_path,
            "best_metrics.txt",
            emmental_learner.logging_manager.checkpointer.best_metric_dict,
        )

    # Save submission file
    for task_name in args.task:
        # Pick the test-split dataloader that belongs to this task without
        # clobbering the full dataloader list.
        test_dataloaders = [
            d for d in dataloaders
            if d.split == "test" and task_name in d.task_to_label_dict
        ]
        assert len(test_dataloaders) == 1
        filepath = os.path.join(Meta.log_path, f"{task_name}.jsonl")
        make_submission_file(model, test_dataloaders[0], task_name, filepath)
Example #30
0
    def learn(self, model: EmmentalModel,
              dataloaders: List[EmmentalDataLoader]) -> None:
        """Learning procedure of emmental MTL.

        Args:
          model: The emmental model that needs to learn.
          dataloaders: A list of dataloaders used to learn the model.
        """
        # Generate the list of dataloaders for learning process
        start_time = time.time()

        train_split = Meta.config["learner_config"]["train_split"]
        if isinstance(train_split, str):
            train_split = [train_split]

        train_dataloaders = [
            dataloader for dataloader in dataloaders
            if dataloader.split in train_split
        ]

        if not train_dataloaders:
            raise ValueError(
                f"Cannot find the specified train_split "
                f'{Meta.config["learner_config"]["train_split"]} in dataloaders.'
            )

        # Set up task_scheduler
        self._set_task_scheduler()

        # Calculate the total number of batches per epoch
        self.n_batches_per_epoch = self.task_scheduler.get_num_batches(
            train_dataloaders)

        # Set up logging manager
        self._set_logging_manager()
        # Set up optimizer
        self._set_optimizer(model)
        # Set up lr_scheduler
        self._set_lr_scheduler(model)

        if Meta.config["learner_config"]["fp16"]:
            try:
                from apex import amp  # type: ignore
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to "
                    "use fp16 training.")
            logger.info(
                f"Modeling training with 16-bit (mixed) precision "
                f"and {Meta.config['learner_config']['fp16_opt_level']} opt level."
            )
            model, self.optimizer = amp.initialize(
                model,
                self.optimizer,
                opt_level=Meta.config["learner_config"]["fp16_opt_level"],
            )

        # Multi-gpu training (after apex fp16 initialization)
        if (Meta.config["learner_config"]["local_rank"] == -1
                and Meta.config["model_config"]["dataparallel"]):
            model._to_dataparallel()

        # Distributed training (after apex fp16 initialization)
        if Meta.config["learner_config"]["local_rank"] != -1:
            model._to_distributed_dataparallel()

        # Set to training mode
        model.train()

        if Meta.config["meta_config"]["verbose"]:
            logger.info("Start learning...")

        self.metrics: Dict[str, float] = dict()
        self._reset_losses()

        # Set gradients of all model parameters to zero
        self.optimizer.zero_grad()

        for epoch_num in range(Meta.config["learner_config"]["n_epochs"]):
            batches = tqdm(
                enumerate(
                    self.task_scheduler.get_batches(train_dataloaders, model)),
                total=self.n_batches_per_epoch,
                disable=(not Meta.config["meta_config"]["verbose"]
                         or Meta.config["learner_config"]["local_rank"]
                         not in [-1, 0]),
                desc=f"Epoch {epoch_num}:",
            )

            for batch_num, batch in batches:
                # Convert a single batch into a batch list
                if not isinstance(batch, list):
                    batch = [batch]

                total_batch_num = epoch_num * self.n_batches_per_epoch + batch_num
                batch_size = 0

                for uids, X_dict, Y_dict, task_to_label_dict, data_name, split in batch:
                    batch_size += len(next(iter(Y_dict.values())))

                    # Perform forward pass and calculate the loss and count
                    uid_dict, loss_dict, prob_dict, gold_dict = model(
                        uids, X_dict, Y_dict, task_to_label_dict)

                    # Update running loss and count
                    for task_name in uid_dict.keys():
                        identifier = f"{task_name}/{data_name}/{split}"
                        self.running_uids[identifier].extend(
                            uid_dict[task_name])
                        self.running_losses[identifier] += (
                            loss_dict[task_name].item() *
                            len(uid_dict[task_name])
                            if len(loss_dict[task_name].size()) == 0 else
                            torch.sum(loss_dict[task_name]).item())
                        self.running_probs[identifier].extend(
                            prob_dict[task_name])
                        self.running_golds[identifier].extend(
                            gold_dict[task_name])

                    # Skip the backward pass if no loss is calculated
                    if not loss_dict:
                        continue

                    # Calculate the average loss
                    loss = sum([
                        model.weights[task_name] *
                        task_loss if len(task_loss.size()) == 0 else
                        torch.mean(model.weights[task_name] * task_loss)
                        for task_name, task_loss in loss_dict.items()
                    ])

                    # Perform backward pass to calculate gradients
                    if Meta.config["learner_config"]["fp16"]:
                        with amp.scale_loss(loss,
                                            self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()  # type: ignore

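                # Step the optimizer only once every
                # gradient_accumulation_steps batches (or on the very last
                # batch of the final epoch); gradients accumulate in between.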
                if (total_batch_num +
                        1) % Meta.config["learner_config"]["optimizer_config"][
                            "gradient_accumulation_steps"] == 0 or (
                                batch_num + 1 == self.n_batches_per_epoch
                                and epoch_num + 1
                                == Meta.config["learner_config"]["n_epochs"]):
                    # Clip gradient norm
                    if Meta.config["learner_config"]["optimizer_config"][
                            "grad_clip"]:
                        if Meta.config["learner_config"]["fp16"]:
                            torch.nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                Meta.config["learner_config"]
                                ["optimizer_config"]["grad_clip"],
                            )
                        else:
                            torch.nn.utils.clip_grad_norm_(
                                model.parameters(),
                                Meta.config["learner_config"]
                                ["optimizer_config"]["grad_clip"],
                            )

                    # Update the parameters
                    self.optimizer.step()

                    # Set gradients of all model parameters to zero
                    self.optimizer.zero_grad()

                if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
                    self.metrics.update(
                        self._logging(model, dataloaders, batch_size))

                    batches.set_postfix(self.metrics)

                # Update lr using lr scheduler
                self._update_lr_scheduler(model, total_batch_num, self.metrics)

        if Meta.config["learner_config"]["local_rank"] in [-1, 0]:
            model = self.logging_manager.close(model)
        logger.info(
            f"Total learning time: {time.time() - start_time} seconds.")