예제 #1
0
파일: test_task.py 프로젝트: SenWu/emmental
def test_emmental_task(caplog):
    """Verify that EmmentalTask keeps the configuration it is built with.

    A three-step identity pipeline is assembled and every public attribute
    is checked afterwards. ``module_device`` deliberately names a module
    ("input_module") that is not in the module pool; the assertions expect
    that entry to be absent from ``task.module_device`` and the remaining
    integer device ids to come back as ``torch.device`` objects.
    """
    caplog.set_level(logging.INFO)

    emmental.init()

    def ce_loss(module_name, output_dict, Y):
        # Labels are flattened to 1-D before computing the cross-entropy.
        logits = output_dict[module_name][0]
        return F.cross_entropy(logits, Y.view(-1))

    def output(module_name, output_dict):
        logits = output_dict[module_name][0]
        return F.softmax(logits, dim=1)

    name = "task1"
    head = f"{name}_pred_head"
    metrics_by_task = {name: ["accuracy"]}
    task_scorer = Scorer(metrics=metrics_by_task[name])

    pool = nn.ModuleDict({
        "input_module0": IdentityModule(),
        "input_module1": IdentityModule(),
        head: IdentityModule(),
    })
    flow = [
        Action("input1", "input_module0", [("_input_", "data")]),
        Action("input2", "input_module1", [("input1", 0)]),
        Action(head, head, [("input2", 0)]),
    ]
    # "input_module" is not part of the pool; see the module_device assertion.
    devices = {"input_module0": -1, "input_module1": 0, "input_module": -1}

    task = EmmentalTask(
        name=name,
        module_pool=pool,
        task_flow=flow,
        module_device=devices,
        loss_func=partial(ce_loss, head),
        output_func=partial(output, head),
        action_outputs=None,
        scorer=task_scorer,
        require_prob_for_eval=False,
        require_pred_for_eval=True,
        weight=2.0,
    )

    assert task.name == name
    assert set(task.module_pool.keys()) == {
        "input_module0", "input_module1", head}
    assert task.action_outputs is None
    assert task.scorer == task_scorer
    assert len(task.task_flow) == 3
    assert task.module_device == {
        "input_module0": torch.device("cpu"),
        "input_module1": torch.device(0),
    }
    assert task.require_prob_for_eval is False
    assert task.require_pred_for_eval is True
    assert task.weight == 2.0
예제 #2
0
def create_task(task_name, n_class=2, model="resnet18", pretrained=True):
    """Build an image-classification EmmentalTask backed by a CNN.

    Args:
        task_name: Name of the task; also used to name the prediction head
            ``"<task_name>_pred_head"``.
        n_class: Passed to ``get_cnn`` as ``num_classes``.
        model: Architecture name passed to ``get_cnn``.
        pretrained: Passed to ``get_cnn`` as its pretrained flag.

    Returns:
        An ``EmmentalTask`` whose flow reads the ``"image"`` input, runs the
        ``"feature"`` CNN, then the prediction head. Loss and output use
        ``sce_loss`` and ``output_classification`` bound to the head name,
        and it is scored on precision, recall and F1.
    """
    head_name = f"{task_name}_pred_head"
    loss_fn = sce_loss
    output_fn = output_classification

    cnn = get_cnn(model, pretrained, num_classes=n_class)
    logger.info(f"Built model: {cnn}")

    # get_cnn is given num_classes, so presumably it already emits class
    # scores; the head is just a pass-through.
    pool = nn.ModuleDict({"feature": cnn, head_name: IdentityModule()})

    flow = [
        {
            "name": "feature",
            "module": "feature",
            "inputs": [("_input_", "image")],
        },
        {
            "name": head_name,
            "module": head_name,
            "inputs": [("feature", 0)],
        },
    ]

    return EmmentalTask(
        name=task_name,
        module_pool=pool,
        task_flow=flow,
        loss_func=partial(loss_fn, head_name),
        output_func=partial(output_fn, head_name),
        scorer=Scorer(metrics=["precision", "recall", "f1"]),
    )
예제 #3
0
def test_identity_module(caplog):
    """IdentityModule must hand back a tensor equal to the one it received."""
    caplog.set_level(logging.INFO)

    module = IdentityModule()

    x = torch.randn(10, 10)
    returned = module(x)
    assert torch.equal(x, returned)
 def get_tv_encoder(self, net_name, pretrained, drop_rate):
     """Build a torchvision backbone with its final linear layer removed.

     Looks ``net_name`` up on ``models``, records the input width of the
     removed classifier as ``self.encode_dim``, and replaces that classifier
     with ``IdentityModule`` so the network returns features.

     Args:
         net_name: Attribute name in ``models``; must contain "densenet" or
             "resnet" (case-insensitive).
         pretrained: Forwarded to the model constructor.
         drop_rate: Forwarded to DenseNet constructors only; ignored for
             ResNets.

     Returns:
         The modified model.

     Raises:
         ValueError: If ``net_name`` is not found in ``models`` or is neither
             a densenet nor a resnet variant.
     """
     # HACK: replace linear with identity -- ideally remove this
     net = getattr(models, net_name, None)
     if net is None:
         raise ValueError(f"Unknown torchvision network {net_name}")
     if "densenet" in net_name.lower():
         model = net(pretrained=pretrained, drop_rate=drop_rate)
         # Linear weight is (out_features, in_features); in_features is the
         # width of the features the backbone now emits.
         self.encode_dim = int(model.classifier.weight.size()[1])
         model.classifier = IdentityModule()
     elif "resnet" in net_name.lower():
         model = net(pretrained=pretrained)
         self.encode_dim = int(model.fc.weight.size()[1])
         model.fc = IdentityModule()
     else:
         # f-string so the offending name actually appears in the message.
         raise ValueError(f"Network {net_name} not supported")
     return model
예제 #5
0
def create_task(task_name, args, nclasses, emb_layer):
    """Build a text-classification EmmentalTask on top of an embedding layer.

    The task flow is ``_input_["feature"] -> emb -> input -> feature ->
    dropout -> <task_name>_pred_head``. The ``input`` and ``feature``
    modules depend on ``args.model``:

    * ``"cnn"``: ``IdentityModule`` then
      ``CNN(emb_layer.n_d, widths=[3, 4, 5], filters=args.n_filters)``;
      the head input width is ``3 * args.n_filters``.
    * ``"lstm"``: ``IdentityModule`` then
      ``LSTM(emb_layer.n_d, args.dim, args.depth, dropout=args.dropout)``;
      the head input width is ``args.dim``.
    * ``"mlp"``: ``Average`` then ``Linear(emb_layer.n_d, args.dim) -> ReLU``;
      the head input width is ``args.dim``.

    Args:
        task_name: Task name; also names the prediction head.
        args: Namespace providing ``model``, ``dropout`` and the
            model-specific fields read above.
        nclasses: Output width of the linear prediction head.
        emb_layer: Embedding module; its ``n_d`` attribute is the feature
            extractor's input width.

    Returns:
        The configured ``EmmentalTask`` scored on accuracy.

    Raises:
        ValueError: If ``args.model`` is not "cnn", "lstm" or "mlp".
    """
    model_kind = args.model
    if model_kind not in ("cnn", "lstm", "mlp"):
        raise ValueError(f"Unrecognized model {model_kind}.")

    if model_kind == "mlp":
        input_module = Average()
        feature_extractor = nn.Sequential(
            nn.Linear(emb_layer.n_d, args.dim), nn.ReLU())
        d_out = args.dim
    else:
        # Both sequence encoders consume the embeddings unchanged.
        input_module = IdentityModule()
        if model_kind == "cnn":
            feature_extractor = CNN(
                emb_layer.n_d, widths=[3, 4, 5], filters=args.n_filters)
            d_out = args.n_filters * 3
        else:
            feature_extractor = LSTM(
                emb_layer.n_d, args.dim, args.depth, dropout=args.dropout)
            d_out = args.dim

    head = f"{task_name}_pred_head"
    pool = nn.ModuleDict({
        "emb": emb_layer,
        "input": input_module,
        "feature": feature_extractor,
        "dropout": nn.Dropout(args.dropout),
        head: nn.Linear(d_out, nclasses),
    })

    # (step name, source) pairs; every step runs the module of the same name.
    steps = [
        ("emb", ("_input_", "feature")),
        ("input", ("emb", 0)),
        ("feature", ("input", 0)),
        ("dropout", ("feature", 0)),
        (head, ("dropout", 0)),
    ]
    flow = [
        {"name": step, "module": step, "inputs": [source]}
        for step, source in steps
    ]

    return EmmentalTask(
        name=task_name,
        module_pool=pool,
        task_flow=flow,
        loss_func=partial(ce_loss, head),
        output_func=partial(output, head),
        scorer=Scorer(metrics=["accuracy"]),
    )
예제 #6
0
def test_model_invalid_task(caplog):
    """Unit test of model with invalid task.

    Checks that removing a task (known or unknown) leaves the model without
    tasks, and that ``add_task`` raises ``ValueError`` both for a second
    task reusing an already-registered name and for a non-task argument
    (a plain string).
    """
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_model_with_invalid_task"

    Meta.reset()

    task_name = "task1"

    def build_task():
        """Return a fresh two-step identity EmmentalTask named task_name."""
        return EmmentalTask(
            name=task_name,
            module_pool=nn.ModuleDict({
                "input_module0": IdentityModule(),
                f"{task_name}_pred_head": IdentityModule(),
            }),
            task_flow=[
                {
                    "name": "input1",
                    "module": "input_module0",
                    "inputs": [("_input_", "data")],
                },
                {
                    "name": f"{task_name}_pred_head",
                    "module": f"{task_name}_pred_head",
                    "inputs": [("input1", 0)],
                },
            ],
            module_device={"input_module0": -1},
            loss_func=None,
            output_func=None,
            action_outputs=None,
            scorer=None,
            require_prob_for_eval=False,
            require_pred_for_eval=True,
        )

    try:
        init(
            dirpath,
            config={
                "meta_config": {
                    "verbose": 0
                },
            },
        )

        task = build_task()
        # Distinct instance with the same name, for the duplicate check.
        task1 = build_task()

        model = EmmentalModel(name="test")
        model.add_task(task)

        model.remove_task(task_name)
        assert model.task_names == set()

        # Removing a task that was never added must not raise.
        model.remove_task("task_2")
        assert model.task_names == set()

        model.add_task(task)

        # Duplicate task
        with pytest.raises(ValueError):
            model.add_task(task1)

        # Invalid task
        with pytest.raises(ValueError):
            model.add_task(task_name)
    finally:
        # Always remove the temp dir so a failed run doesn't leak it into
        # the next one; ignore_errors in case init never created it.
        shutil.rmtree(dirpath, ignore_errors=True)