# Example #1
0
def test_checkpoint_metric(caplog):
    """Check that a user-supplied checkpoint metric lands in the full config.

    Initializes emmental with only ``checkpoint_metric`` overridden, then
    compares the complete resulting ``Meta.config`` with the expected
    default configuration plus that one override.
    """
    caplog.set_level(logging.INFO)

    log_dir = "temp_parse_args"
    Meta.reset()

    override = {
        "logging_config": {
            "checkpointer_config": {
                "checkpoint_metric": {"model/valid/all/accuracy": "max"}
            }
        }
    }
    init(log_dir=log_dir, config=override)

    # Expected optimizer settings (defaults for every supported optimizer).
    optimizer_config = {
        "optimizer": "adam",
        "parameters": None,
        "lr": 0.001,
        "l2": 0.0,
        "grad_clip": None,
        "gradient_accumulation_steps": 1,
        "asgd_config": {"lambd": 0.0001, "alpha": 0.75, "t0": 1000000.0},
        "adadelta_config": {"rho": 0.9, "eps": 1e-06},
        "adagrad_config": {
            "lr_decay": 0,
            "initial_accumulator_value": 0,
            "eps": 1e-10,
        },
        "adam_config": {"betas": (0.9, 0.999), "amsgrad": False, "eps": 1e-08},
        "adamw_config": {"betas": (0.9, 0.999), "amsgrad": False, "eps": 1e-08},
        "adamax_config": {"betas": (0.9, 0.999), "eps": 1e-08},
        "lbfgs_config": {
            "max_iter": 20,
            "max_eval": None,
            "tolerance_grad": 1e-07,
            "tolerance_change": 1e-09,
            "history_size": 100,
            "line_search_fn": None,
        },
        "rms_prop_config": {
            "alpha": 0.99,
            "eps": 1e-08,
            "momentum": 0,
            "centered": False,
        },
        "r_prop_config": {"etas": (0.5, 1.2), "step_sizes": (1e-06, 50)},
        "sgd_config": {"momentum": 0, "dampening": 0, "nesterov": False},
        "sparse_adam_config": {"betas": (0.9, 0.999), "eps": 1e-08},
        "bert_adam_config": {"betas": (0.9, 0.999), "eps": 1e-08},
    }

    # Expected learning-rate scheduler settings.
    lr_scheduler_config = {
        "lr_scheduler": None,
        "lr_scheduler_step_unit": "batch",
        "lr_scheduler_step_freq": 1,
        "warmup_steps": None,
        "warmup_unit": "batch",
        "warmup_percentage": None,
        "min_lr": 0.0,
        "reset_state": False,
        "exponential_config": {"gamma": 0.9},
        "plateau_config": {
            "metric": "model/train/all/loss",
            "mode": "min",
            "factor": 0.1,
            "patience": 10,
            "threshold": 0.0001,
            "threshold_mode": "rel",
            "cooldown": 0,
            "eps": 1e-08,
        },
        "step_config": {"step_size": 1, "gamma": 0.1, "last_epoch": -1},
        "multi_step_config": {
            "milestones": [1000],
            "gamma": 0.1,
            "last_epoch": -1,
        },
        "cyclic_config": {
            "base_lr": 0.001,
            "base_momentum": 0.8,
            "cycle_momentum": True,
            "gamma": 1.0,
            "last_epoch": -1,
            "max_lr": 0.1,
            "max_momentum": 0.9,
            "mode": "triangular",
            "scale_fn": None,
            "scale_mode": "cycle",
            "step_size_down": None,
            "step_size_up": 2000,
        },
        "one_cycle_config": {
            "anneal_strategy": "cos",
            "base_momentum": 0.85,
            "cycle_momentum": True,
            "div_factor": 25.0,
            "final_div_factor": 10000.0,
            "last_epoch": -1,
            "max_lr": 0.1,
            "max_momentum": 0.95,
            "pct_start": 0.3,
        },
        "cosine_annealing_config": {"last_epoch": -1},
    }

    # Expected task scheduler settings.
    task_scheduler_config = {
        "task_scheduler": "round_robin",
        "sequential_scheduler_config": {"fillup": False},
        "round_robin_scheduler_config": {"fillup": False},
        "mixed_scheduler_config": {"fillup": False},
    }

    learner_config = {
        "optimizer_path": None,
        "scheduler_path": None,
        "fp16": False,
        "fp16_opt_level": "O1",
        "local_rank": -1,
        "epochs_learned": 0,
        "n_epochs": 1,
        "steps_learned": 0,
        "n_steps": None,
        "train_split": ["train"],
        "valid_split": ["valid"],
        "test_split": ["test"],
        "ignore_index": None,
        "online_eval": False,
        "global_evaluation_metric_dict": None,
        "optimizer_config": optimizer_config,
        "lr_scheduler_config": lr_scheduler_config,
        "task_scheduler_config": task_scheduler_config,
    }

    # The only non-default value is the checkpoint metric set above.
    logging_config = {
        "counter_unit": "epoch",
        "evaluation_freq": 1,
        "writer_config": {"writer": "tensorboard", "verbose": True},
        "checkpointing": False,
        "checkpointer_config": {
            "checkpoint_path": None,
            "checkpoint_freq": 1,
            "checkpoint_metric": {"model/valid/all/accuracy": "max"},
            "checkpoint_task_metrics": None,
            "checkpoint_runway": 0,
            "checkpoint_all": False,
            "clear_intermediate_checkpoints": True,
            "clear_all_checkpoints": False,
        },
    }

    expected = {
        "meta_config": {
            "seed": None,
            "verbose": True,
            "log_path": "logs",
            "use_exact_log_path": False,
        },
        "data_config": {"min_data_len": 0, "max_data_len": 0},
        "model_config": {
            "model_path": None,
            "device": 0,
            "dataparallel": True,
            "distributed_backend": "nccl",
        },
        "learner_config": learner_config,
        "logging_config": logging_config,
    }

    assert Meta.config == expected

    shutil.rmtree(log_dir)
# Example #2
0
def test_step_scheduler(caplog):
    """Check linear warmup driven by ``warmup_steps`` and ``warmup_percentage``.

    With no main scheduler and SGD at lr=10, the learning rate is expected
    to start at 0, reach half the base rate after the first update, and stay
    at the base rate once warmup is over.
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_scheduler"
    model = nn.Linear(1, 1)
    learner = EmmentalLearner()

    # Expected lr after each call to _update_lr_scheduler (steps 0..3).
    expected_lrs = [5, 10, 10, 10]

    warmup_variants = [
        # Warmup length given directly in batches.
        {"lr_scheduler": None, "warmup_steps": 2, "warmup_unit": "batch"},
        # Warmup length given as a fraction of training, counted in epochs.
        {"lr_scheduler": None, "warmup_percentage": 0.5, "warmup_unit": "epoch"},
    ]

    for scheduler_settings in warmup_variants:
        # Start every variant from a fresh global config.
        Meta.reset()
        emmental.init(dirpath)

        emmental.Meta.update_config(
            {
                "learner_config": {
                    "n_epochs": 4,
                    "optimizer_config": {"optimizer": "sgd", "lr": 10},
                    "lr_scheduler_config": scheduler_settings,
                }
            }
        )
        learner.n_batches_per_epoch = 1
        learner._set_optimizer(model)
        learner._set_lr_scheduler(model)

        assert learner.optimizer.param_groups[0]["lr"] == 0

        for step, target_lr in enumerate(expected_lrs):
            learner.optimizer.step()
            learner._update_lr_scheduler(model, step, {})
            assert abs(learner.optimizer.param_groups[0]["lr"] - target_lr) < 1e-5

    shutil.rmtree(dirpath)
# Example #3
0
def test_emmental_dataloader(caplog):
    """Check batching, padding and dataset refresh in EmmentalDataLoader.

    Variable-length features are expected to be right-padded with zeros to
    the longest item in each batch, and a loader is expected to reflect
    later edits to the dataset it wraps.
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_data"

    Meta.reset()
    emmental.init(dirpath)

    # x1 holds [1], [1, 2], ..., [1, ..., 5]; x2 holds the same in reverse.
    x1 = [torch.Tensor(list(range(1, length + 1))) for length in range(1, 6)]
    y1 = torch.Tensor([0] * 5)
    x2 = [torch.Tensor(list(range(1, length + 1))) for length in range(5, 0, -1)]
    y2 = torch.Tensor([1] * 5)

    dataset = EmmentalDataset(
        X_dict={"data1": x1, "data2": x2},
        Y_dict={"label1": y1, "label2": y2},
        name="new_data",
    )

    train_loader = EmmentalDataLoader(
        task_to_label_dict={"task1": "label1"},
        dataset=dataset,
        split="train",
        batch_size=2,
    )
    features, labels = next(iter(train_loader))

    # Loader metadata plus a zero-padded batch of two.
    assert train_loader.task_to_label_dict == {"task1": "label1"}
    assert train_loader.split == "train"
    assert torch.equal(features["data1"], torch.Tensor([[1, 0], [1, 2]]))
    assert torch.equal(
        features["data2"], torch.Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 0]])
    )
    assert torch.equal(labels["label1"], torch.Tensor([0, 0]))
    assert torch.equal(labels["label2"], torch.Tensor([1, 1]))

    test_loader = EmmentalDataLoader(
        task_to_label_dict={"task2": "label2"},
        dataset=dataset,
        split="test",
        batch_size=3,
    )
    features, labels = next(iter(test_loader))

    # Same dataset, a different task mapping and a batch size of three.
    assert test_loader.task_to_label_dict == {"task2": "label2"}
    assert test_loader.split == "test"
    assert test_loader.is_learnable is True
    assert torch.equal(
        features["data1"], torch.Tensor([[1, 0, 0], [1, 2, 0], [1, 2, 3]])
    )
    assert torch.equal(
        features["data2"],
        torch.Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 0], [1, 2, 3, 0, 0]]),
    )
    assert torch.equal(labels["label1"], torch.Tensor([0, 0, 0]))
    assert torch.equal(labels["label2"], torch.Tensor([1, 1, 1]))

    # Replace label2 with one-element tensors; both loaders should pick it up.
    dataset.Y_dict["label2"] = [torch.Tensor([2]) for _ in range(5)]

    features, labels = next(iter(train_loader))
    assert torch.equal(
        features["data2"], torch.Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 0]])
    )
    assert torch.equal(labels["label2"], torch.Tensor([[2], [2]]))

    features, labels = next(iter(test_loader))
    assert torch.equal(
        features["data2"],
        torch.Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 0], [1, 2, 3, 0, 0]]),
    )
    assert torch.equal(labels["label2"], torch.Tensor([[2], [2], [2]]))

    # A dataset without labels yields bare feature dicts and is not learnable.
    unlabeled = EmmentalDataset(X_dict={"data1": x1}, name="new_data")
    unlabeled_loader = EmmentalDataLoader(
        task_to_label_dict=["task1"],
        dataset=unlabeled,
        split="train",
        batch_size=2,
    )
    features = next(iter(unlabeled_loader))

    assert unlabeled_loader.task_to_label_dict == ["task1"]
    assert unlabeled_loader.split == "train"
    assert unlabeled_loader.is_learnable is False
    assert torch.equal(features["data1"], torch.Tensor([[1, 0], [1, 2]]))

    shutil.rmtree(dirpath)
# Example #4
0
def test_model(caplog):
    """Check task management and save/load round-trip on EmmentalModel."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_model"

    Meta.reset()
    emmental.init(dirpath)

    def ce_loss(module_name, immediate_output_dict, Y, active):
        # Cross-entropy over the active rows of the module's first output.
        logits = immediate_output_dict[module_name][0]
        return F.cross_entropy(logits[active], Y.view(-1)[active])

    def output(module_name, immediate_output_dict):
        # Softmax over dim 1 of the module's first output.
        return F.softmax(immediate_output_dict[module_name][0], dim=1)

    def build_task(name, hidden_dim):
        # Two bias-free linear layers chained as 10 -> hidden_dim -> 2.
        return EmmentalTask(
            name=name,
            module_pool=nn.ModuleDict(
                {
                    "m1": nn.Linear(10, hidden_dim, bias=False),
                    "m2": nn.Linear(hidden_dim, 2, bias=False),
                }
            ),
            task_flow=[
                {"name": "m1", "module": "m1", "inputs": [("_input_", "data")]},
                {"name": "m2", "module": "m2", "inputs": [("m1", 0)]},
            ],
            loss_func=partial(ce_loss, "m2"),
            output_func=partial(output, "m2"),
            scorer=Scorer(metrics=["accuracy"]),
        )

    task1 = build_task("task_1", 10)
    new_task1 = build_task("task_1", 5)
    task2 = build_task("task_2", 5)

    emmental.Meta.update_config({"model_config": {"dataparallel": False}})

    model = EmmentalModel(name="test", tasks=task1)

    assert repr(model) == "EmmentalModel(name=test)"
    assert model.name == "test"
    assert model.task_names == {"task_1"}
    assert model.module_pool["m1"].weight.data.size() == (10, 10)
    assert model.module_pool["m2"].weight.data.size() == (2, 10)

    # Updating a registered task swaps in the new task's modules.
    model.update_task(new_task1)
    assert model.module_pool["m1"].weight.data.size() == (5, 10)
    assert model.module_pool["m2"].weight.data.size() == (2, 5)

    # Updating a task that is not registered leaves the task set unchanged.
    model.update_task(task2)
    assert model.task_names == {"task_1"}

    model.add_task(task2)
    assert model.task_names == {"task_1", "task_2"}

    # Removing the same task twice: the second removal changes nothing.
    for _ in range(2):
        model.remove_task("task_1")
        assert model.task_names == {"task_2"}

    checkpoint_file = f"{dirpath}/saved_model.pth"
    model.save(checkpoint_file)
    model.load(checkpoint_file)

    # add_tasks registers several tasks in one call.
    model = EmmentalModel(name="test")
    model.add_tasks([task1, task2])
    assert model.task_names == {"task_1", "task_2"}

    shutil.rmtree(dirpath)
def test_bert_adam_optimizer(caplog):
    """Unit test of BertAdam optimizer.

    Covers the default hyper-parameters, user overrides, one optimization
    step, and the ValueError expected when the optimizer is built with an
    invalid lr, eps, or beta.
    """
    caplog.set_level(logging.INFO)

    optimizer = "bert_adam"
    dirpath = "temp_test_optimizer"
    model = nn.Linear(1, 1)
    emmental_learner = EmmentalLearner()

    def _optimizer_config(lr, l2, betas, eps):
        # Build a learner_config override selecting BertAdam with these values.
        return {
            "learner_config": {
                "optimizer_config": {
                    "optimizer": optimizer,
                    "lr": lr,
                    "l2": l2,
                    f"{optimizer}_config": {"betas": betas, "eps": eps},
                }
            }
        }

    Meta.reset()
    emmental.init(dirpath)

    # Test default BertAdam setting
    emmental.Meta.update_config(
        {"learner_config": {"optimizer_config": {"optimizer": optimizer}}}
    )
    emmental_learner._set_optimizer(model)

    assert emmental_learner.optimizer.defaults == {
        "lr": 0.001,
        "betas": (0.9, 0.999),
        "eps": 1e-08,
        "weight_decay": 0.0,
    }

    # Test new BertAdam setting
    emmental.Meta.update_config(_optimizer_config(0.02, 0.05, (0.8, 0.9), 1e-05))
    emmental_learner._set_optimizer(model)

    assert emmental_learner.optimizer.defaults == {
        "lr": 0.02,
        "betas": (0.8, 0.9),
        "eps": 1e-05,
        "weight_decay": 0.05,
    }

    # Test a single BertAdam optimization step
    emmental_learner.optimizer.zero_grad()
    F.mse_loss(model(torch.randn(1, 1)), torch.randn(1, 1)).backward()
    emmental_learner.optimizer.step()

    # Invalid settings must be rejected when the optimizer is built. Only
    # _set_optimizer sits inside pytest.raises, so a ValueError from anywhere
    # else cannot make the check pass by accident.
    invalid_settings = [
        (-0.1, 0.05, (0.8, 0.9), 1e-05),  # negative lr
        (0.1, 0.05, (0.8, 0.9), -1e-05),  # negative eps
        (0.1, 0.05, (-0.8, 0.9), 1e-05),  # negative first beta
        (0.1, 0.05, (0.8, -0.9), 1e-05),  # negative second beta
    ]
    for lr, l2, betas, eps in invalid_settings:
        emmental.Meta.update_config(_optimizer_config(lr, l2, betas, eps))
        with pytest.raises(ValueError):
            emmental_learner._set_optimizer(model)

    shutil.rmtree(dirpath)
# Example #6
0
def test_emmental_dataset(caplog):
    """Check construction and feature/label editing of EmmentalDataset."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_data"

    Meta.reset()
    emmental.init(dirpath)

    # Ragged features: [1], [1, 2], ..., [1, ..., 5].
    x1 = [torch.Tensor(list(range(1, length + 1))) for length in range(1, 6)]
    y1 = torch.Tensor([0] * 5)

    labeled = EmmentalDataset(
        X_dict={"data1": x1}, Y_dict={"label1": y1}, name="new_data"
    )

    # With labels present, item 0 is indexed as (features, labels).
    assert torch.equal(labeled[0][0]["data1"], x1[0])
    assert torch.equal(labeled[0][1]["label1"], y1[0])

    # Same ragged features, longest first.
    x2 = [torch.Tensor(list(range(1, length + 1))) for length in range(5, 0, -1)]

    # A feature can be added, removed, and added back.
    labeled.add_features(X_dict={"data2": x2})
    labeled.remove_feature("data2")
    assert "data2" not in labeled.X_dict
    labeled.add_features(X_dict={"data2": x2})
    assert torch.equal(labeled[0][0]["data2"], x2[0])

    y2 = torch.Tensor([1] * 5)
    labeled.add_labels(Y_dict={"label2": y2})

    # Passing a plain list (not a tensor) as a label is expected to fail.
    with pytest.raises(ValueError):
        labeled.add_labels(Y_dict={"label2": x2})

    assert torch.equal(labeled[0][1]["label2"], y2[0])

    labeled.remove_label(label_name="label1")
    assert "label1" not in labeled.Y_dict

    # A uid key that is absent from X_dict is expected to be rejected.
    with pytest.raises(ValueError):
        EmmentalDataset(
            X_dict={"data1": x1},
            Y_dict={"label1": y1},
            name="new_data",
            uid="ids",
        )

    # A feature named "_uids_" must construct without raising.
    EmmentalDataset(X_dict={"_uids_": x1}, Y_dict={"label1": y1}, name="new_data")

    unlabeled = EmmentalDataset(X_dict={"data1": x1}, name="new_data")

    # Without labels, item 0 is indexed directly as the feature dict.
    assert torch.equal(unlabeled[0]["data1"], x1[0])

    unlabeled.add_features(X_dict={"data2": x2})
    unlabeled.remove_feature("data2")
    assert "data2" not in unlabeled.X_dict
    unlabeled.add_features(X_dict={"data2": x2})
    assert torch.equal(unlabeled[0]["data2"], x2[0])

    y2 = torch.Tensor([1] * 5)
    unlabeled.add_labels(Y_dict={"label2": y2})

    # After labels are added, the test indexes item 0 as (features, labels).
    assert torch.equal(unlabeled[0][1]["label2"], y2[0])

    shutil.rmtree(dirpath)
def test_sparse_adam_optimizer(caplog):
    """Check SparseAdam defaults and overrides with grouped parameters."""
    caplog.set_level(logging.INFO)

    optimizer = "sparse_adam"
    dirpath = "temp_test_optimizer"
    model = nn.Linear(1, 1)
    emmental_learner = EmmentalLearner()

    Meta.reset()
    emmental.init(dirpath)

    def grouped_parameters(model):
        # Two parameter groups: names containing "bias" get weight_decay 0.0;
        # every other parameter uses the configured l2 value.
        decayed, not_decayed = [], []
        for param_name, param in model.named_parameters():
            if "bias" in param_name:
                not_decayed.append(param)
            else:
                decayed.append(param)
        l2 = emmental.Meta.config["learner_config"]["optimizer_config"]["l2"]
        return [
            {"params": decayed, "weight_decay": l2},
            {"params": not_decayed, "weight_decay": 0.0},
        ]

    current_optimizer_config = emmental.Meta.config["learner_config"][
        "optimizer_config"
    ]
    current_optimizer_config["parameters"] = grouped_parameters

    # Default SparseAdam setting.
    emmental.Meta.update_config(
        {"learner_config": {"optimizer_config": {"optimizer": optimizer}}}
    )
    emmental_learner._set_optimizer(model)

    assert emmental_learner.optimizer.defaults == {
        "lr": 0.001,
        "betas": (0.9, 0.999),
        "eps": 1e-08,
    }

    # Overridden SparseAdam setting.
    emmental.Meta.update_config(
        {
            "learner_config": {
                "optimizer_config": {
                    "optimizer": optimizer,
                    "lr": 0.02,
                    "l2": 0.05,
                    f"{optimizer}_config": {"betas": (0.8, 0.9), "eps": 1e-05},
                }
            }
        }
    )
    emmental_learner._set_optimizer(model)

    assert emmental_learner.optimizer.defaults == {
        "lr": 0.02,
        "betas": (0.8, 0.9),
        "eps": 1e-05,
    }

    shutil.rmtree(dirpath)
def test_exponential_scheduler(caplog):
    """Check the exponential LR scheduler with per-batch and per-epoch steps.

    With SGD at lr=10 and gamma=0.1, stepping every batch is expected to
    divide the rate by 10 after each update. Stepping every epoch, with two
    batches per epoch, is expected to decay it only on every second update.
    """
    caplog.set_level(logging.INFO)

    lr_scheduler = "exponential"
    dirpath = "temp_test_scheduler"
    model = nn.Linear(1, 1)
    emmental_learner = EmmentalLearner()

    Meta.reset()
    emmental.init(dirpath)

    # Each case: extra lr_scheduler_config keys, batches per epoch, and the
    # expected lr after each of the four scheduler updates.
    cases = [
        ({}, 1, [1, 0.1, 0.01, 0.001]),
        ({"lr_scheduler_step_unit": "epoch"}, 2, [10, 1, 1, 0.1]),
    ]

    for extra_settings, batches_per_epoch, expected_lrs in cases:
        scheduler_settings = {"lr_scheduler": lr_scheduler}
        scheduler_settings.update(extra_settings)
        scheduler_settings["exponential_config"] = {"gamma": 0.1}

        emmental.Meta.update_config(
            {
                "learner_config": {
                    "n_epochs": 4,
                    "optimizer_config": {"optimizer": "sgd", "lr": 10},
                    "lr_scheduler_config": scheduler_settings,
                }
            }
        )
        emmental_learner.n_batches_per_epoch = batches_per_epoch
        emmental_learner._set_optimizer(model)
        emmental_learner._set_lr_scheduler(model)

        assert emmental_learner.optimizer.param_groups[0]["lr"] == 10

        for step, expected_lr in enumerate(expected_lrs):
            emmental_learner.optimizer.step()
            emmental_learner._update_lr_scheduler(model, step, {})
            current_lr = emmental_learner.optimizer.param_groups[0]["lr"]
            assert abs(current_lr - expected_lr) < 1e-5

    shutil.rmtree(dirpath)
# Example #9
0
def test_cyclic_scheduler(caplog):
    """Check the cyclic LR scheduler over four per-batch updates.

    With base_lr=10, max_lr=0.1 and step_size_up=2000, the expected rates
    move from 10 toward 0.1 by (10 - 0.1) / 2000 per update.
    """
    caplog.set_level(logging.INFO)

    lr_scheduler = "cyclic"
    dirpath = "temp_test_scheduler"
    model = nn.Linear(1, 1)
    emmental_learner = EmmentalLearner()

    Meta.reset()
    emmental.init(dirpath)

    cyclic_settings = {
        "base_lr": 10,
        "base_momentum": 0.8,
        "cycle_momentum": True,
        "gamma": 1.0,
        "last_epoch": -1,
        "max_lr": 0.1,
        "max_momentum": 0.9,
        "mode": "triangular",
        "scale_fn": None,
        "scale_mode": "cycle",
        "step_size_down": None,
        "step_size_up": 2000,
    }
    emmental.Meta.update_config(
        {
            "learner_config": {
                "n_epochs": 4,
                "optimizer_config": {"optimizer": "sgd", "lr": 10},
                "lr_scheduler_config": {
                    "lr_scheduler": lr_scheduler,
                    "cyclic_config": cyclic_settings,
                },
            }
        }
    )
    emmental_learner.n_batches_per_epoch = 1
    emmental_learner._set_optimizer(model)
    emmental_learner._set_lr_scheduler(model)

    assert emmental_learner.optimizer.param_groups[0]["lr"] == 10

    expected_lrs = [
        9.995049999999999,
        9.990100000000002,
        9.985149999999999,
        9.980200000000002,
    ]
    for step, expected_lr in enumerate(expected_lrs):
        emmental_learner.optimizer.step()
        emmental_learner._update_lr_scheduler(model, step, {})
        current_lr = emmental_learner.optimizer.param_groups[0]["lr"]
        assert abs(current_lr - expected_lr) < 1e-5

    shutil.rmtree(dirpath)
# Example #10
0
def test_plateau_scheduler(caplog):
    """Check the reduce-on-plateau LR scheduler.

    With patience=1 and factor=0.1, the expected lr stays at 10 while the
    monitored loss has stalled for one update. It drops to 1 once the loss
    has stalled for two updates, then stays at 1 when the loss improves.
    """
    caplog.set_level(logging.INFO)

    lr_scheduler = "plateau"
    dirpath = "temp_test_scheduler"
    model = nn.Linear(1, 1)
    emmental_learner = EmmentalLearner()

    Meta.reset()
    emmental.init(dirpath)

    plateau_settings = {
        "metric": "model/train/all/loss",
        "mode": "min",
        "factor": 0.1,
        "patience": 1,
        "threshold": 0.0001,
        "threshold_mode": "rel",
        "cooldown": 0,
        "eps": 1e-08,
    }
    emmental.Meta.update_config(
        {
            "learner_config": {
                "n_epochs": 4,
                "optimizer_config": {"optimizer": "sgd", "lr": 10},
                "lr_scheduler_config": {
                    "lr_scheduler": lr_scheduler,
                    "plateau_config": plateau_settings,
                },
            }
        }
    )
    emmental_learner.n_batches_per_epoch = 1
    emmental_learner._set_optimizer(model)
    emmental_learner._set_lr_scheduler(model)

    assert emmental_learner.optimizer.param_groups[0]["lr"] == 10

    # (reported training loss, expected lr after the scheduler update)
    trajectory = [(1, 10), (1, 10), (1, 1), (0.1, 1)]
    for step, (loss, expected_lr) in enumerate(trajectory):
        emmental_learner.optimizer.step()
        emmental_learner._update_lr_scheduler(
            model, step, {"model/train/all/loss": loss}
        )
        current_lr = emmental_learner.optimizer.param_groups[0]["lr"]
        assert abs(current_lr - expected_lr) < 1e-5

    shutil.rmtree(dirpath)
# Example #11
0
def test_e2e(caplog):
    """Run an end-to-end multi-task test on synthetic data.

    Two binary tasks share the same 2-d inputs. Labels are stored as 1/2,
    which is why the loss subtracts 1 before ``cross_entropy``. Each task
    must be trained and evaluated on its own label column (``label1`` uses
    Y1, ``label2`` uses Y2).
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_e2e"

    Meta.reset()
    emmental.init(dirpath)

    # Generate synthetic data
    N = 50
    X = np.random.random((N, 2)) * 2 - 1
    Y1 = (X[:, 0] > X[:, 1] + 0.25).astype(int) + 1
    Y2 = (-X[:, 0] > X[:, 1] + 0.25).astype(int) + 1

    # Create dataset and dataloader

    splits = [0.8, 0.1, 0.1]

    X_train, X_dev, X_test = [], [], []
    Y1_train, Y1_dev, Y1_test = [], [], []
    Y2_train, Y2_dev, Y2_test = [], [], []

    for i in range(N):
        if i <= N * splits[0]:
            X_train.append(torch.Tensor(X[i]))
            Y1_train.append(Y1[i])
            Y2_train.append(Y2[i])
        elif i < N * (splits[0] + splits[1]):
            X_dev.append(torch.Tensor(X[i]))
            Y1_dev.append(Y1[i])
            Y2_dev.append(Y2[i])
        else:
            X_test.append(torch.Tensor(X[i]))
            Y1_test.append(Y1[i])
            Y2_test.append(Y2[i])

    Y1_train = torch.from_numpy(np.array(Y1_train))
    Y1_dev = torch.from_numpy(np.array(Y1_dev))
    Y1_test = torch.from_numpy(np.array(Y1_test))

    # Build task2's tensors from the task2 label lists (not from Y1_*).
    Y2_train = torch.from_numpy(np.array(Y2_train))
    Y2_dev = torch.from_numpy(np.array(Y2_dev))
    Y2_test = torch.from_numpy(np.array(Y2_test))

    train_dataset1 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_train}, Y_dict={"label1": Y1_train}
    )

    train_dataset2 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_train}, Y_dict={"label2": Y2_train}
    )

    dev_dataset1 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_dev}, Y_dict={"label1": Y1_dev}
    )

    dev_dataset2 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_dev}, Y_dict={"label2": Y2_dev}
    )

    # task1's test set must carry task1's labels.
    test_dataset1 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_test}, Y_dict={"label1": Y1_test}
    )

    test_dataset2 = EmmentalDataset(
        name="synthetic", X_dict={"data": X_test}, Y_dict={"label2": Y2_test}
    )

    task_to_label_dict = {"task1": "label1"}

    train_dataloader1 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=train_dataset1,
        split="train",
        batch_size=10,
    )
    dev_dataloader1 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=dev_dataset1,
        split="valid",
        batch_size=10,
    )
    test_dataloader1 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=test_dataset1,
        split="test",
        batch_size=10,
    )

    task_to_label_dict = {"task2": "label2"}

    train_dataloader2 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=train_dataset2,
        split="train",
        batch_size=10,
    )
    dev_dataloader2 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=dev_dataset2,
        split="valid",
        batch_size=10,
    )
    test_dataloader2 = EmmentalDataLoader(
        task_to_label_dict=task_to_label_dict,
        dataset=test_dataset2,
        split="test",
        batch_size=10,
    )

    # Create task
    def ce_loss(task_name, immediate_ouput_dict, Y, active):
        module_name = f"{task_name}_pred_head"
        # Labels are 1/2; shift to 0/1 class indices for cross_entropy.
        return F.cross_entropy(
            immediate_ouput_dict[module_name][0][active], (Y.view(-1) - 1)[active]
        )

    def output(task_name, immediate_ouput_dict):
        module_name = f"{task_name}_pred_head"
        return F.softmax(immediate_ouput_dict[module_name][0], dim=1)

    task_name = "task1"

    task1 = EmmentalTask(
        name=task_name,
        module_pool=nn.ModuleDict(
            {"input_module": nn.Linear(2, 8), f"{task_name}_pred_head": nn.Linear(8, 2)}
        ),
        task_flow=[
            {
                "name": "input",
                "module": "input_module",
                "inputs": [("_input_", "data")],
            },
            {
                "name": f"{task_name}_pred_head",
                "module": f"{task_name}_pred_head",
                "inputs": [("input", 0)],
            },
        ],
        loss_func=partial(ce_loss, task_name),
        output_func=partial(output, task_name),
        scorer=Scorer(metrics=["accuracy", "roc_auc"]),
    )

    task_name = "task2"

    task2 = EmmentalTask(
        name=task_name,
        module_pool=nn.ModuleDict(
            {"input_module": nn.Linear(2, 8), f"{task_name}_pred_head": nn.Linear(8, 2)}
        ),
        task_flow=[
            {
                "name": "input",
                "module": "input_module",
                "inputs": [("_input_", "data")],
            },
            {
                "name": f"{task_name}_pred_head",
                "module": f"{task_name}_pred_head",
                "inputs": [("input", 0)],
            },
        ],
        loss_func=partial(ce_loss, task_name),
        output_func=partial(output, task_name),
        scorer=Scorer(metrics=["accuracy", "roc_auc"]),
    )

    # Build model

    mtl_model = EmmentalModel(name="all", tasks=[task1, task2])

    # Create learner

    emmental_learner = EmmentalLearner()

    # Update learning config
    Meta.update_config(
        config={"learner_config": {"n_epochs": 10, "optimizer_config": {"lr": 0.01}}}
    )

    # Learning
    emmental_learner.learn(
        mtl_model,
        [train_dataloader1, train_dataloader2, dev_dataloader1, dev_dataloader2],
    )

    test1_score = mtl_model.score(test_dataloader1)
    test2_score = mtl_model.score(test_dataloader2)

    assert test1_score["task1/synthetic/test/accuracy"] >= 0.5
    assert test1_score["task1/synthetic/test/roc_auc"] >= 0.6
    assert test2_score["task2/synthetic/test/accuracy"] >= 0.5
    assert test2_score["task2/synthetic/test/roc_auc"] >= 0.6

    shutil.rmtree(dirpath)
示例#12
0
def test_model_invalid_task(caplog):
    """Unit test of model with invalid task.

    Checks that:
    * ``remove_task`` empties the model, and removing an unknown task name
      leaves the (already empty) task set unchanged;
    * ``add_task`` raises ``ValueError`` for a second, distinct task object
      that reuses an already-registered task name;
    * ``add_task`` raises ``ValueError`` when given something that is not a
      task (here a plain string).
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_model_with_invalid_task"

    Meta.reset()
    init(
        dirpath,
        config={
            "meta_config": {
                "verbose": 0
            },
        },
    )

    task_name = "task1"

    def make_task():
        """Build a fresh two-stage identity task named ``task_name``."""
        return EmmentalTask(
            name=task_name,
            module_pool=nn.ModuleDict({
                "input_module0": IdentityModule(),
                f"{task_name}_pred_head": IdentityModule(),
            }),
            task_flow=[
                {
                    "name": "input1",
                    "module": "input_module0",
                    "inputs": [("_input_", "data")],
                },
                {
                    "name": f"{task_name}_pred_head",
                    "module": f"{task_name}_pred_head",
                    "inputs": [("input1", 0)],
                },
            ],
            module_device={"input_module0": -1},
            loss_func=None,
            output_func=None,
            action_outputs=None,
            scorer=None,
            require_prob_for_eval=False,
            require_pred_for_eval=True,
        )

    # Two distinct task objects that share the same name.
    task = make_task()
    task1 = make_task()

    model = EmmentalModel(name="test")
    model.add_task(task)

    model.remove_task(task_name)
    assert model.task_names == set()

    model.remove_task("task_2")
    assert model.task_names == set()

    model.add_task(task)

    # Duplicate task
    with pytest.raises(ValueError):
        model.add_task(task1)

    # Invalid task
    with pytest.raises(ValueError):
        model.add_task(task_name)

    shutil.rmtree(dirpath)
                                     clobber_label=True)

# Setting random seed
seed = config['seed']
random.seed(seed)
np.random.seed(seed)

import emmental
from emmental import Meta
from emmental.data import EmmentalDataLoader
from emmental.learner import EmmentalLearner
from emmental.model import EmmentalModel
from emmental.utils.parse_arg import parse_arg, parse_arg_to_config

# HACK: To get Emmental to initialize correctly
Meta.reset()

# The checkpoint path is read from Meta.config["model_config"]["model_path"],
# so it must be nested under "model_config". A top-level "model_path" key is
# ignored, which previously forced a manual override of Meta.config after init.
prediction_checkpoint_path = f"{config['prediction_model_path']}/checkpoint.pth"
emmental.init(
    config={"model_config": {"model_path": prediction_checkpoint_path}})

# Defining tasks
task_names = ["ht_page"]
示例#14
0
def test_e2e_skip_trained_epoch(caplog):
    """Run an end-to-end test that resumes training from a checkpoint.

    The first run trains for one epoch and checkpoints the best model by
    ``model/all/train/loss``. The second run reloads that checkpoint
    (model, optimizer and scheduler paths), declares one epoch already
    learned with ``skip_learned_data`` enabled, and continues to five
    epochs, after which the test loss must have dropped.
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_e2e_skip_trained"
    use_exact_log_path = True
    Meta.reset()
    init(dirpath, use_exact_log_path=use_exact_log_path)

    # Synthetic binary-classification data on [-1, 1]^2.
    N = 500
    X = np.random.random((N, 2)) * 2 - 1
    Y = (X[:, 0] > X[:, 1] + 0.25).astype(int)
    X = [torch.Tensor(X[i]) for i in range(N)]

    # 80/10/10 split, keyed by the Emmental split name of each dataloader.
    train_end, dev_end = int(0.8 * N), int(0.9 * N)
    split_bounds = {
        "train": (0, train_end),
        "valid": (train_end, dev_end),
        "test": (dev_end, N),
    }

    datasets = {
        split: EmmentalDataset(
            name="synthetic",
            X_dict={"data": X[start:end]},
            Y_dict={"label1": torch.tensor(Y[start:end])},
        )
        for split, (start, end) in split_bounds.items()
    }

    task_to_label_dict = {"task1": "label1"}
    loaders = {
        split: EmmentalDataLoader(
            task_to_label_dict=task_to_label_dict,
            dataset=dataset,
            split=split,
            batch_size=10,
        )
        for split, dataset in datasets.items()
    }
    train_dataloader = loaders["train"]
    dev_dataloader = loaders["valid"]
    test_dataloader = loaders["test"]

    # Create task
    def ce_loss(task_name, immediate_output_dict, Y):
        return F.cross_entropy(
            immediate_output_dict[f"{task_name}_pred_head"], Y)

    def output(task_name, immediate_output_dict):
        return F.softmax(
            immediate_output_dict[f"{task_name}_pred_head"], dim=1)

    task_metrics = {"task1": ["accuracy"]}

    class IdentityModule(nn.Module):
        def __init__(self):
            """Initialize IdentityModule."""
            super().__init__()

        def forward(self, input):
            return {"out": input}

    def make_task(task_name):
        """Build identity -> Linear(2, 8) -> Linear(8, 2) for ``task_name``."""
        head = f"{task_name}_pred_head"
        return EmmentalTask(
            name=task_name,
            module_pool=nn.ModuleDict({
                "input_module0": IdentityModule(),
                "input_module1": nn.Linear(2, 8),
                head: nn.Linear(8, 2),
            }),
            task_flow=[
                Action(name="input",
                       module="input_module0",
                       inputs=[("_input_", "data")]),
                Action(name="input1",
                       module="input_module1",
                       inputs=[("input", "out")]),
                Action(name=head, module=head, inputs=[("input1", 0)]),
            ],
            module_device={"input_module0": -1},
            loss_func=partial(ce_loss, task_name),
            output_func=partial(output, task_name),
            action_outputs=None,
            scorer=Scorer(metrics=task_metrics[task_name]),
            require_prob_for_eval=False,
            require_pred_for_eval=True,
        )

    # Build model
    model = EmmentalModel(name="all", tasks=[make_task("task1")])

    # Create learner
    emmental_learner = EmmentalLearner()

    def logging_config():
        """Return a fresh logging/checkpointing config used by both runs."""
        return {
            "counter_unit": "batch",
            "evaluation_freq": 5,
            "writer_config": {
                "writer": "json",
                "write_loss_per_step": True,
                "verbose": True,
            },
            "checkpointing": True,
            "checkpointer_config": {
                "checkpoint_path": None,
                "checkpoint_freq": 1,
                "checkpoint_metric": {"model/all/train/loss": "min"},
                "checkpoint_task_metrics": None,
                "checkpoint_runway": 1,
                "checkpoint_all": False,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": False,
            },
        }

    # First run: a single epoch from scratch.
    Meta.update_config({
        "meta_config": {"seed": 0, "verbose": True},
        "learner_config": {
            "n_epochs": 1,
            "epochs_learned": 0,
            "steps_learned": 0,
            "skip_learned_data": False,
            "online_eval": True,
            "optimizer_config": {"lr": 0.01, "grad_clip": 100},
        },
        "logging_config": logging_config(),
    })

    emmental_learner.learn(model, [train_dataloader, dev_dataloader])

    assert model.score(test_dataloader)["task1/synthetic/test/loss"] > 0.3

    # Second run: resume from the best checkpoint of the first run.
    Meta.reset()
    init(dirpath, use_exact_log_path=use_exact_log_path)

    best_prefix = f"{dirpath}/best_model_model_all_train_loss"
    Meta.update_config({
        "meta_config": {"seed": 0, "verbose": False},
        "learner_config": {
            "n_epochs": 5,
            "epochs_learned": 1,
            "steps_learned": 0,
            "skip_learned_data": True,
            "online_eval": False,
            "optimizer_config": {"lr": 0.01, "grad_clip": 100},
            "optimizer_path": f"{best_prefix}.optimizer.pth",
            "scheduler_path": f"{best_prefix}.scheduler.pth",
        },
        "model_config": {"model_path": f"{best_prefix}.model.pth"},
        "logging_config": logging_config(),
    })

    saved_model_path = Meta.config["model_config"]["model_path"]
    if saved_model_path:
        model.load(saved_model_path)

    emmental_learner.learn(model, [train_dataloader, dev_dataloader])

    assert model.score(test_dataloader)["task1/synthetic/test/loss"] <= 0.4

    shutil.rmtree(dirpath)
示例#15
0
def test_linear_scheduler(caplog):
    """Unit test of linear scheduler.

    Four epochs of one batch each, starting from lr 10 with SGD. With the
    default step frequency the lr drops by 2.5 after every batch to reach 0;
    with ``lr_scheduler_step_freq`` set to 2 it only changes on every
    second batch.
    """
    caplog.set_level(logging.INFO)

    lr_scheduler = "linear"
    dirpath = "temp_test_scheduler"
    model = nn.Linear(1, 1)
    emmental_learner = EmmentalLearner()

    Meta.reset()
    emmental.init(dirpath)

    def current_lr():
        return emmental_learner.optimizer.param_groups[0]["lr"]

    def configure(lr_scheduler_config, reset_counter):
        """Apply the scheduler config, rebuild optimizer + scheduler."""
        emmental.Meta.update_config({
            "learner_config": {
                "n_epochs": 4,
                "optimizer_config": {"optimizer": "sgd", "lr": 10},
                "lr_scheduler_config": lr_scheduler_config,
            }
        })
        emmental_learner.n_batches_per_epoch = 1
        if reset_counter:
            emmental_learner._set_learning_counter()
        emmental_learner._set_optimizer(model)
        emmental_learner._set_lr_scheduler(model)
        assert current_lr() == 10

    def check_schedule(expected_lrs):
        """Step once per batch and compare the lr after each step."""
        for batch_idx, expected_lr in enumerate(expected_lrs):
            emmental_learner.optimizer.step()
            emmental_learner._update_lr_scheduler(model, batch_idx, {})
            assert abs(current_lr() - expected_lr) < 1e-5

    # Test per batch
    configure({"lr_scheduler": lr_scheduler}, reset_counter=True)
    check_schedule([7.5, 5, 2.5, 0])

    # Test every 2 batches
    configure(
        {"lr_scheduler": lr_scheduler, "lr_scheduler_step_freq": 2},
        reset_counter=False,
    )
    check_schedule([10, 7.5, 7.5, 5.0])

    shutil.rmtree(dirpath)
示例#16
0
def main(args):
    """Train and/or evaluate SuperGLUE tasks with Emmental and write outputs.

    For each task in ``args.task`` this builds the dataloaders and model task
    (slice-aware when ``args.slices`` is set). It optionally trains, then
    scores the model (computing per-slice scores manually when the model is
    not slice-aware). It writes the command line, config, metrics, best
    metrics (when training) and one ``<task>.jsonl`` submission file per task
    into ``Meta.log_path``.

    Args:
        args: Parsed command-line namespace (task list, data/model options,
            training and slicing flags).
    """
    # Ensure that global state is fresh
    Meta.reset()

    # Initialize Emmental
    config = parse_arg_to_config(args)
    # HACK: handle None in model_path, proper way to handle this in the
    # next release of Emmental
    if (config["model_config"]["model_path"]
            and config["model_config"]["model_path"].lower() == "none"):
        config["model_config"]["model_path"] = None
    emmental.init(config["meta_config"]["log_path"], config=config)

    # Save command line argument into file
    cmd_msg = " ".join(sys.argv)
    logger.info(f"COMMAND: {cmd_msg}")
    write_to_file(Meta.log_path, "cmd.txt", cmd_msg)

    # Save Emmental config into file
    logger.info(f"Config: {Meta.config}")
    write_to_file(Meta.log_path, "config.txt", Meta.config)

    Meta.config["learner_config"]["global_evaluation_metric_dict"] = {
        f"model/SuperGLUE/{split}/score": partial(superglue_scorer,
                                                  split=split)
        for split in ["val"]
    }

    # Construct dataloaders and tasks and load slices
    dataloaders = []
    tasks = []
    # Each task's own test dataloaders, kept so the submission step can pick
    # the right one without filtering the shared dataloader list.
    test_dataloaders = {}

    for task_name in args.task:
        task_dataloaders = get_dataloaders(
            data_dir=args.data_dir,
            task_name=task_name,
            splits=["train", "val", "test"],
            max_sequence_length=args.max_sequence_length,
            max_data_samples=args.max_data_samples,
            tokenizer_name=args.bert_model,
            batch_size=args.batch_size,
            augment=args.augmentations,
        )
        task = models.model[task_name](
            args.bert_model,
            last_hidden_dropout_prob=args.last_hidden_dropout_prob)
        if args.slices:
            logger.info("Initializing task-specific slices")
            # Copy so that adding the general slices below does not mutate
            # the shared slicing.slice_func_dict entry for this task.
            slice_func_dict = dict(slicing.slice_func_dict[task_name])
            # Include general purpose slices
            if args.general_slices:
                logger.info("Including general slices")
                slice_func_dict.update(slicing.slice_func_dict["general"])

            task_dataloaders = slicing.add_slice_labels(
                task_name, task_dataloaders, slice_func_dict)

            slice_tasks = slicing.add_slice_tasks(task_name, task,
                                                  slice_func_dict,
                                                  args.slice_hidden_dim)
            tasks.extend(slice_tasks)
        else:
            tasks.append(task)

        dataloaders.extend(task_dataloaders)
        test_dataloaders[task_name] = [
            d for d in task_dataloaders if d.split == "test"
        ]

    # Build Emmental model
    model = EmmentalModel(name="SuperGLUE", tasks=tasks)

    # Load pretrained model if necessary
    if Meta.config["model_config"]["model_path"]:
        model.load(Meta.config["model_config"]["model_path"])

    # Training
    if args.train:
        emmental_learner = EmmentalLearner()
        emmental_learner.learn(model, dataloaders)

    # If model is slice-aware, slice scores will be calculated from slice heads
    # If model is not slice-aware, manually calculate performance on slices
    if not args.slices:
        slice_func_dict = {}
        # Copy args.task: appending "general" in place would leak into the
        # task list passed to score_slices and into the submission loop.
        slice_keys = list(args.task)
        if args.general_slices:
            slice_keys.append("general")

        for k in slice_keys:
            slice_func_dict.update(slicing.slice_func_dict[k])

        scores = slicing.score_slices(model, dataloaders, args.task,
                                      slice_func_dict)
    else:
        scores = model.score(dataloaders)

    # Save metrics into file
    logger.info(f"Metrics: {scores}")
    write_to_file(Meta.log_path, "metrics.txt", scores)

    # Save best metrics into file
    if args.train:
        logger.info(
            f"Best metrics: "
            f"{emmental_learner.logging_manager.checkpointer.best_metric_dict}"
        )
        write_to_file(
            Meta.log_path,
            "best_metrics.txt",
            emmental_learner.logging_manager.checkpointer.best_metric_dict,
        )

    # Save submission file, one per task, from that task's test split.
    for task_name in args.task:
        task_test_dataloaders = test_dataloaders[task_name]
        assert len(task_test_dataloaders) == 1
        filepath = os.path.join(Meta.log_path, f"{task_name}.jsonl")
        make_submission_file(model, task_test_dataloaders[0], task_name,
                             filepath)
示例#17
0
def test_emmental_dataloader(caplog):
    """Unit test of emmental dataloader.

    Covers zero-padding of variable-length inputs within a batch, different
    batch sizes and an explicit ``collate_fn``, dataloaders picking up an
    in-place change to the dataset's ``Y_dict``, a dataset without labels,
    and the ``ValueError`` raised when a task maps to a label while the
    dataset has no ``Y_dict``.
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_data"

    Meta.reset()
    emmental.init(dirpath)

    # Variable-length sequences: x1 has lengths 1..5, x2 has lengths 5..1.
    x1 = [torch.Tensor(list(range(1, length + 1))) for length in range(1, 6)]
    x2 = [torch.Tensor(list(range(1, length + 1))) for length in range(5, 0, -1)]
    y1 = torch.Tensor([0] * 5)
    y2 = torch.Tensor([1] * 5)

    # Expected zero-padded batches.
    padded_x1_pair = torch.Tensor([[1, 0], [1, 2]])
    padded_x1_triple = torch.Tensor([[1, 0, 0], [1, 2, 0], [1, 2, 3]])
    padded_x2_pair = torch.Tensor([[1, 2, 3, 4, 5], [1, 2, 3, 4, 0]])
    padded_x2_triple = torch.Tensor(
        [[1, 2, 3, 4, 5], [1, 2, 3, 4, 0], [1, 2, 3, 0, 0]])

    dataset = EmmentalDataset(
        X_dict={"data1": x1, "data2": x2},
        Y_dict={"label1": y1, "label2": y2},
        name="new_data",
    )

    dataloader1 = EmmentalDataLoader(
        task_to_label_dict={"task1": "label1"},
        dataset=dataset,
        split="train",
        batch_size=2,
        num_workers=2,
    )

    x_batch, y_batch = next(iter(dataloader1))

    # Check if the dataloader is correctly constructed
    assert dataloader1.task_to_label_dict == {"task1": "label1"}
    assert dataloader1.split == "train"
    assert torch.equal(x_batch["data1"], padded_x1_pair)
    assert torch.equal(x_batch["data2"], padded_x2_pair)
    assert torch.equal(y_batch["label1"], torch.Tensor([0, 0]))
    assert torch.equal(y_batch["label2"], torch.Tensor([1, 1]))

    dataloader2 = EmmentalDataLoader(
        task_to_label_dict={"task2": "label2"},
        dataset=dataset,
        split="test",
        batch_size=3,
        collate_fn=partial(emmental_collate_fn, min_data_len=0,
                           max_data_len=0),
    )

    x_batch, y_batch = next(iter(dataloader2))

    # Check if the dataloader with different batch size is correctly constructed
    assert dataloader2.task_to_label_dict == {"task2": "label2"}
    assert dataloader2.split == "test"
    assert torch.equal(x_batch["data1"], padded_x1_triple)
    assert torch.equal(x_batch["data2"], padded_x2_triple)
    assert torch.equal(y_batch["label1"], torch.Tensor([0, 0, 0]))
    assert torch.equal(y_batch["label2"], torch.Tensor([1, 1, 1]))

    # Swap label2 for per-example one-element tensors.
    dataset.Y_dict["label2"] = [torch.Tensor([2]) for _ in range(5)]

    # Check dataloader is correctly updated with update dataset
    x_batch, y_batch = next(iter(dataloader1))
    assert torch.equal(x_batch["data2"], padded_x2_pair)
    assert torch.equal(y_batch["label2"], torch.Tensor([[2], [2]]))

    x_batch, y_batch = next(iter(dataloader2))
    assert torch.equal(x_batch["data2"], padded_x2_triple)
    assert torch.equal(y_batch["label2"], torch.Tensor([[2], [2], [2]]))

    # A dataset without labels yields only the X dict per batch.
    dataset = EmmentalDataset(X_dict={"data1": x1}, name="new_data")

    dataloader3 = EmmentalDataLoader(task_to_label_dict={"task1": None},
                                     dataset=dataset,
                                     split="train",
                                     batch_size=2)

    x_batch = next(iter(dataloader3))

    # Check if the dataloader is correctly constructed
    assert dataloader3.task_to_label_dict == {"task1": None}
    assert dataloader3.split == "train"
    assert torch.equal(x_batch["data1"], padded_x1_pair)

    # Check there is an error if task_to_label_dict has task to label mapping while
    # no y_dict in dataset
    with pytest.raises(ValueError):
        EmmentalDataLoader(
            task_to_label_dict={"task1": "label1"},
            dataset=dataset,
            split="train",
            batch_size=2,
        )

    shutil.rmtree(dirpath)
示例#18
0
def test_e2e(caplog):
    """Run an end-to-end two-task test with checkpointing and tensorboard.

    task1 and task2 are binary problems on the same 2-d inputs, using
    slightly different decision thresholds (0.25 and 0.2). They are trained
    jointly for three epochs. The test then checks per-task test scores and
    that the model-level macro average of a single-task score equals that
    task's accuracy.
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_e2e"
    use_exact_log_path = False
    Meta.reset()
    emmental.init(dirpath, use_exact_log_path=use_exact_log_path)

    emmental.Meta.update_config({
        "meta_config": {"seed": 0},
        "learner_config": {
            "n_epochs": 3,
            "optimizer_config": {"lr": 0.01, "grad_clip": 100},
        },
        "logging_config": {
            "counter_unit": "epoch",
            "evaluation_freq": 1,
            "writer_config": {"writer": "tensorboard", "verbose": True},
            "checkpointing": True,
            "checkpointer_config": {
                "checkpoint_path": None,
                "checkpoint_freq": 1,
                "checkpoint_metric": {"model/all/train/loss": "min"},
                "checkpoint_task_metrics": None,
                "checkpoint_runway": 1,
                "checkpoint_all": False,
                "clear_intermediate_checkpoints": True,
                "clear_all_checkpoints": True,
            },
        },
    })

    # Generate synthetic data
    N = 500
    X = np.random.random((N, 2)) * 2 - 1
    labels = {
        "label1": (X[:, 0] > X[:, 1] + 0.25).astype(int),
        "label2": (X[:, 0] > X[:, 1] + 0.2).astype(int),
    }
    X = [torch.Tensor(X[i]) for i in range(N)]

    # 80/10/10 split, keyed by the Emmental split name.
    train_end, dev_end = int(0.8 * N), int(0.9 * N)
    split_bounds = {
        "train": (0, train_end),
        "valid": (train_end, dev_end),
        "test": (dev_end, N),
    }

    # One dataloader per (task, split); each task reads its own label column.
    loaders = {}
    for task_name, label_name in (("task1", "label1"), ("task2", "label2")):
        task_to_label_dict = {task_name: label_name}
        for split, (start, end) in split_bounds.items():
            dataset = EmmentalDataset(
                name="synthetic",
                X_dict={"data": X[start:end]},
                Y_dict={label_name: torch.tensor(labels[label_name][start:end])},
            )
            loaders[task_name, split] = EmmentalDataLoader(
                task_to_label_dict=task_to_label_dict,
                dataset=dataset,
                split=split,
                batch_size=10,
            )

    # Create task
    def ce_loss(task_name, immediate_ouput_dict, Y, active):
        logits = immediate_ouput_dict[f"{task_name}_pred_head"][0]
        return F.cross_entropy(logits[active], Y.view(-1)[active])

    def output(task_name, immediate_ouput_dict):
        logits = immediate_ouput_dict[f"{task_name}_pred_head"][0]
        return F.softmax(logits, dim=1)

    task_metrics = {"task1": ["accuracy"], "task2": ["accuracy", "roc_auc"]}

    def make_task(task_name):
        """Build Linear(2, 8) -> Linear(8, 2) for ``task_name``."""
        head = f"{task_name}_pred_head"
        return EmmentalTask(
            name=task_name,
            module_pool=nn.ModuleDict({
                "input_module": nn.Linear(2, 8),
                head: nn.Linear(8, 2),
            }),
            task_flow=[
                {
                    "name": "input",
                    "module": "input_module",
                    "inputs": [("_input_", "data")],
                },
                {
                    "name": head,
                    "module": head,
                    "inputs": [("input", 0)],
                },
            ],
            loss_func=partial(ce_loss, task_name),
            output_func=partial(output, task_name),
            scorer=Scorer(metrics=task_metrics[task_name]),
        )

    # Build model (tasks created in order: task1, then task2)
    mtl_model = EmmentalModel(
        name="all", tasks=[make_task(name) for name in ["task1", "task2"]])

    # Create learner and train on both train splits plus both valid splits
    emmental_learner = EmmentalLearner()
    emmental_learner.learn(
        mtl_model,
        [
            loaders["task1", "train"],
            loaders["task2", "train"],
            loaders["task1", "valid"],
            loaders["task2", "valid"],
        ],
    )

    test1_score = mtl_model.score(loaders["task1", "test"])
    test2_score = mtl_model.score(loaders["task2", "test"])

    task1_accuracy = test1_score["task1/synthetic/test/accuracy"]
    assert task1_accuracy >= 0.7
    assert test1_score["model/all/test/macro_average"] == task1_accuracy
    assert test2_score["task2/synthetic/test/accuracy"] >= 0.7
    assert test2_score["task2/synthetic/test/roc_auc"] >= 0.7

    shutil.rmtree(dirpath)
def test_one_cycle_scheduler(caplog):
    """Unit test of one cycle scheduler.

    Configures a ``one_cycle`` LR scheduler for 4 epochs with a single batch
    per epoch. It checks the learning rate of the first optimizer parameter
    group right after setup and after each (optimizer step, scheduler update)
    pair, against precomputed values. The temporary log directory is removed
    even when an assertion fails.
    """
    caplog.set_level(logging.INFO)

    lr_scheduler = "one_cycle"
    dirpath = "temp_test_scheduler"
    model = nn.Linear(1, 1)
    emmental_learner = EmmentalLearner()

    Meta.reset()
    emmental.init(dirpath)

    try:
        config = {
            "learner_config": {
                "n_epochs": 4,
                "optimizer_config": {"optimizer": "sgd", "lr": 10},
                "lr_scheduler_config": {
                    "lr_scheduler": lr_scheduler,
                    "one_cycle_config": {
                        "anneal_strategy": "cos",
                        "base_momentum": 0.85,
                        "cycle_momentum": True,
                        "div_factor": 1,
                        "final_div_factor": 10000.0,
                        "last_epoch": -1,
                        "max_lr": 0.1,
                        "max_momentum": 0.95,
                        "pct_start": 0.3,
                    },
                },
            }
        }
        emmental.Meta.update_config(config)
        # One batch per epoch, so each scheduler update below is one epoch.
        emmental_learner.n_batches_per_epoch = 1
        emmental_learner._set_optimizer(model)
        emmental_learner._set_lr_scheduler(model)

        # With div_factor=1 the schedule starts at max_lr (0.1), not at the
        # optimizer's configured lr of 10.
        assert emmental_learner.optimizer.param_groups[0]["lr"] == 0.1

        # Expected LR after the optimizer step + scheduler update for each
        # step index passed to _update_lr_scheduler.
        expected_lrs = [
            0.08117637264392738,
            0.028312982462817687,
            1e-05,
            0.028312982462817677,
        ]
        for step, expected_lr in enumerate(expected_lrs):
            emmental_learner.optimizer.step()
            emmental_learner._update_lr_scheduler(model, step, {})
            current_lr = emmental_learner.optimizer.param_groups[0]["lr"]
            assert abs(current_lr - expected_lr) < 1e-5, (step, current_lr)
    finally:
        # Always remove the log directory created by emmental.init.
        shutil.rmtree(dirpath)
示例#20
0
def test_e2e(caplog):
    """Run an end-to-end test.

    Trains a two-task ``EmmentalModel`` on synthetic 2-D data:
    task1 is scored by accuracy, and task2 by accuracy and roc_auc.
    After training it checks:

    - the test scores,
    - that ``predict`` gives identical probabilities and action outputs with
      and without gold labels,
    - that action outputs are only returned when requested.

    The temporary log directory is removed even if an assertion fails.
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_e2e"
    use_exact_log_path = False
    Meta.reset()
    init(dirpath, use_exact_log_path=use_exact_log_path)

    try:
        config = {
            "meta_config": {"seed": 0},
            "learner_config": {
                "n_epochs": 3,
                "online_eval": True,
                "optimizer_config": {"lr": 0.01, "grad_clip": 100},
            },
            "data_config": {"max_data_len": 10},
            "logging_config": {
                "counter_unit": "epoch",
                "evaluation_freq": 0.2,
                "writer_config": {"writer": "tensorboard", "verbose": True},
                "checkpointing": True,
                "checkpointer_config": {
                    "checkpoint_path": None,
                    "checkpoint_freq": 1,
                    "checkpoint_metric": {"model/all/train/loss": "min"},
                    "checkpoint_task_metrics": None,
                    "checkpoint_runway": 1,
                    "checkpoint_all": False,
                    "clear_intermediate_checkpoints": True,
                    "clear_all_checkpoints": True,
                },
            },
        }
        Meta.update_config(config)

        def grouped_parameters(model):
            # Split parameters into "decay" / "no decay" groups by name.
            # Both groups use weight_decay 0.0 here. Presumably this only
            # exercises the callable "parameters" hook of optimizer_config.
            no_decay = ["bias", "LayerNorm.weight"]
            return [
                {
                    "params": [
                        p for n, p in model.named_parameters()
                        if not any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": 0.0,
                },
                {
                    "params": [
                        p for n, p in model.named_parameters()
                        if any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": 0.0,
                },
            ]

        # Generate synthetic data. Two nearby thresholds on x0 - x1 give two
        # closely related binary labels.
        N = 500
        X = np.random.random((N, 2)) * 2 - 1
        Y1 = (X[:, 0] > X[:, 1] + 0.25).astype(int)
        Y2 = (X[:, 0] > X[:, 1] + 0.2).astype(int)

        X = [torch.Tensor(row) for row in X]

        # 80% / 10% / 10% train / dev / test split.
        train_end, dev_end = int(0.8 * N), int(0.9 * N)
        X_train, X_dev, X_test = (
            X[:train_end],
            X[train_end:dev_end],
            X[dev_end:],
        )
        Y1_train, Y1_dev, Y1_test = (
            torch.tensor(Y1[:train_end]),
            torch.tensor(Y1[train_end:dev_end]),
            torch.tensor(Y1[dev_end:]),
        )
        Y2_train, Y2_dev, Y2_test = (
            torch.tensor(Y2[:train_end]),
            torch.tensor(Y2[train_end:dev_end]),
            torch.tensor(Y2[dev_end:]),
        )

        train_dataset1 = EmmentalDataset(name="synthetic",
                                         X_dict={"data": X_train},
                                         Y_dict={"label1": Y1_train})
        train_dataset2 = EmmentalDataset(name="synthetic",
                                         X_dict={"data": X_train},
                                         Y_dict={"label2": Y2_train})
        dev_dataset1 = EmmentalDataset(name="synthetic",
                                       X_dict={"data": X_dev},
                                       Y_dict={"label1": Y1_dev})
        dev_dataset2 = EmmentalDataset(name="synthetic",
                                       X_dict={"data": X_dev},
                                       Y_dict={"label2": Y2_dev})
        test_dataset1 = EmmentalDataset(name="synthetic",
                                        X_dict={"data": X_test},
                                        Y_dict={"label1": Y1_test})
        test_dataset2 = EmmentalDataset(name="synthetic",
                                        X_dict={"data": X_test},
                                        Y_dict={"label2": Y2_test})
        # Same inputs as test_dataset2, but without any gold labels.
        test_dataset3 = EmmentalDataset(name="synthetic",
                                        X_dict={"data": X_test})

        task_to_label_dict = {"task1": "label1"}

        train_dataloader1 = EmmentalDataLoader(
            task_to_label_dict=task_to_label_dict,
            dataset=train_dataset1,
            split="train",
            batch_size=10,
            num_workers=2,
        )
        dev_dataloader1 = EmmentalDataLoader(
            task_to_label_dict=task_to_label_dict,
            dataset=dev_dataset1,
            split="valid",
            batch_size=10,
            num_workers=2,
        )
        test_dataloader1 = EmmentalDataLoader(
            task_to_label_dict=task_to_label_dict,
            dataset=test_dataset1,
            split="test",
            batch_size=10,
            num_workers=2,
        )

        task_to_label_dict = {"task2": "label2"}

        train_dataloader2 = EmmentalDataLoader(
            task_to_label_dict=task_to_label_dict,
            dataset=train_dataset2,
            split="train",
            batch_size=10,
        )
        dev_dataloader2 = EmmentalDataLoader(
            task_to_label_dict=task_to_label_dict,
            dataset=dev_dataset2,
            split="valid",
            batch_size=10,
        )
        test_dataloader2 = EmmentalDataLoader(
            task_to_label_dict=task_to_label_dict,
            dataset=test_dataset2,
            split="test",
            batch_size=10,
        )

        # task2 with no label: used to check that predict() works without gold.
        test_dataloader3 = EmmentalDataLoader(
            task_to_label_dict={"task2": None},
            dataset=test_dataset3,
            split="test",
            batch_size=10,
        )

        # Create task
        def ce_loss(task_name, immediate_output_dict, Y):
            module_name = f"{task_name}_pred_head"
            return F.cross_entropy(immediate_output_dict[module_name], Y)

        def output(task_name, immediate_output_dict):
            module_name = f"{task_name}_pred_head"
            return F.softmax(immediate_output_dict[module_name], dim=1)

        task_metrics = {
            "task1": ["accuracy"],
            "task2": ["accuracy", "roc_auc"],
        }

        class IdentityModule(nn.Module):
            def __init__(self):
                """Initialize IdentityModule."""
                super().__init__()

            def forward(self, input):
                # Return the input twice so downstream actions can index it.
                return input, input

        tasks = [
            EmmentalTask(
                name=task_name,
                module_pool=nn.ModuleDict({
                    "input_module0": IdentityModule(),
                    "input_module1": nn.Linear(2, 8),
                    f"{task_name}_pred_head": nn.Linear(8, 2),
                }),
                task_flow=[
                    Action(name="input",
                           module="input_module0",
                           inputs=[("_input_", "data")]),
                    Action(name="input1",
                           module="input_module1",
                           inputs=[("input", 0)]),
                    Action(
                        name=f"{task_name}_pred_head",
                        module=f"{task_name}_pred_head",
                        inputs="input1",
                    ),
                ],
                # NOTE(review): -1 presumably pins input_module0 to CPU;
                # confirm against EmmentalTask.
                module_device={"input_module0": -1},
                loss_func=partial(ce_loss, task_name),
                output_func=partial(output, task_name),
                # Only task2 records action outputs. The assertions below show
                # the resulting keys: "<action>_<index>" for tuples, and the
                # bare action name for plain strings.
                action_outputs=[
                    (f"{task_name}_pred_head", 0),
                    ("_input_", "data"),
                    (f"{task_name}_pred_head", 0),
                    f"{task_name}_pred_head",
                ] if task_name == "task2" else None,
                scorer=Scorer(metrics=task_metrics[task_name]),
                require_prob_for_eval=task_name == "task2",
                require_pred_for_eval=task_name == "task1",
            ) for task_name in ["task1", "task2"]
        ]

        # Build model
        mtl_model = EmmentalModel(name="all", tasks=tasks)

        Meta.config["learner_config"]["optimizer_config"][
            "parameters"] = grouped_parameters

        # Create learner
        emmental_learner = EmmentalLearner()

        # Learning
        emmental_learner.learn(
            mtl_model,
            [
                train_dataloader1,
                train_dataloader2,
                dev_dataloader1,
                dev_dataloader2,
            ],
        )

        test1_score = mtl_model.score(test_dataloader1)
        test2_score = mtl_model.score(test_dataloader2)

        assert test1_score["task1/synthetic/test/accuracy"] >= 0.7
        assert (test1_score["model/all/test/macro_average"] ==
                test1_score["task1/synthetic/test/accuracy"])
        assert test2_score["task2/synthetic/test/accuracy"] >= 0.7
        assert test2_score["task2/synthetic/test/roc_auc"] >= 0.7

        test2_pred = mtl_model.predict(test_dataloader2,
                                       return_action_outputs=True)
        test3_pred = mtl_model.predict(
            test_dataloader3,
            return_action_outputs=True,
            return_loss=False,
        )

        assert test2_pred["uids"] == test3_pred["uids"]

        # Predictions must not depend on whether gold labels were present.
        # Lengths are checked explicitly so zip() cannot hide extra items.
        probs2 = test2_pred["probs"]["task2"]
        probs3 = test3_pred["probs"]["task2"]
        assert len(probs2) == len(probs3)
        assert all(np.array_equal(p2, p3) for p2, p3 in zip(probs2, probs3))

        assert "outputs" in test2_pred
        assert "outputs" in test3_pred

        for key in ["task2_pred_head_0", "_input__data"]:
            outputs2 = test2_pred["outputs"]["task2"][key]
            outputs3 = test3_pred["outputs"]["task2"][key]
            assert len(outputs2) == len(outputs3), key
            assert all(
                np.array_equal(o2, o3) for o2, o3 in zip(outputs2, outputs3)
            ), key

        # 50 == size of the 10% test split of N=500.
        assert len(test3_pred["outputs"]["task2"]["task2_pred_head"]) == 50
        assert len(test2_pred["outputs"]["task2"]["task2_pred_head"]) == 50

        test4_pred = mtl_model.predict(test_dataloader2,
                                       return_action_outputs=False)
        assert "outputs" not in test4_pred
    finally:
        # Always remove the log directory created by init.
        shutil.rmtree(dirpath)