def test_emmental_task(caplog):
    """Unit test of emmental task."""
    caplog.set_level(logging.INFO)

    emmental.init()

    def loss_fn(module_name, output_dict, Y):
        return F.cross_entropy(output_dict[module_name][0], Y.view(-1))

    def prob_fn(module_name, output_dict):
        return F.softmax(output_dict[module_name][0], dim=1)

    task_name = "task1"
    head_name = f"{task_name}_pred_head"
    scorer = Scorer(metrics=["accuracy"])

    module_pool = nn.ModuleDict(
        {
            "input_module0": IdentityModule(),
            "input_module1": IdentityModule(),
            head_name: IdentityModule(),
        }
    )
    task_flow = [
        Action("input1", "input_module0", [("_input_", "data")]),
        Action("input2", "input_module1", [("input1", 0)]),
        Action(head_name, head_name, [("input2", 0)]),
    ]

    task = EmmentalTask(
        name=task_name,
        module_pool=module_pool,
        task_flow=task_flow,
        # "input_module" is not in the module pool; the assertion below
        # checks that it is dropped from the resolved device mapping.
        module_device={"input_module0": -1, "input_module1": 0, "input_module": -1},
        loss_func=partial(loss_fn, head_name),
        output_func=partial(prob_fn, head_name),
        action_outputs=None,
        scorer=scorer,
        require_prob_for_eval=False,
        require_pred_for_eval=True,
        weight=2.0,
    )

    assert task.name == task_name
    assert set(task.module_pool.keys()) == {
        "input_module0",
        "input_module1",
        head_name,
    }
    assert task.action_outputs is None
    assert task.scorer == scorer
    assert len(task.task_flow) == 3
    assert task.module_device == {
        "input_module0": torch.device("cpu"),
        "input_module1": torch.device(0),
    }
    assert task.require_prob_for_eval is False
    assert task.require_pred_for_eval is True
    assert task.weight == 2.0
def create_task(task_name, n_class=2, model="resnet18", pretrained=True):
    """Build a single-head image-classification EmmentalTask on a CNN backbone."""
    backbone = get_cnn(model, pretrained, num_classes=n_class)
    logger.info(f"Built model: {backbone}")

    head_name = f"{task_name}_pred_head"
    return EmmentalTask(
        name=task_name,
        # The backbone produces features; the head is a pass-through so the
        # backbone's own classifier output is scored directly.
        module_pool=nn.ModuleDict(
            {"feature": backbone, head_name: IdentityModule()}
        ),
        task_flow=[
            {
                "name": "feature",
                "module": "feature",
                "inputs": [("_input_", "image")],
            },
            {
                "name": head_name,
                "module": head_name,
                "inputs": [("feature", 0)],
            },
        ],
        loss_func=partial(sce_loss, head_name),
        output_func=partial(output_classification, head_name),
        scorer=Scorer(metrics=["precision", "recall", "f1"]),
    )
def test_identity_module(caplog):
    """Unit test of Identity Module."""
    caplog.set_level(logging.INFO)

    identity_module = IdentityModule()
    # Renamed from `input`, which shadowed the builtin of the same name.
    tensor = torch.randn(10, 10)
    # The identity module must return its input unchanged.
    assert torch.equal(tensor, identity_module(tensor))
def get_tv_encoder(self, net_name, pretrained, drop_rate):
    """Load a torchvision network and strip its final classifier layer.

    Sets ``self.encode_dim`` to the width of the removed linear layer's
    input features as a side effect, and returns the truncated model.

    Args:
        net_name: Attribute name of a constructor on ``torchvision.models``
            (only densenet* and resnet* families are supported).
        pretrained: Whether to load pretrained weights.
        drop_rate: Dropout rate (densenet only).

    Raises:
        ValueError: If ``net_name`` is unknown or an unsupported family.
    """
    # HACK: replace linear with identity -- ideally remove this
    net = getattr(models, net_name, None)
    if net is None:
        raise ValueError(f"Unknown torchvision network {net_name}")

    family = net_name.lower()
    if "densenet" in family:
        model = net(pretrained=pretrained, drop_rate=drop_rate)
        self.encode_dim = int(model.classifier.weight.size()[1])
        model.classifier = IdentityModule()
    elif "resnet" in family:
        model = net(pretrained=pretrained)
        self.encode_dim = int(model.fc.weight.size()[1])
        model.fc = IdentityModule()
    else:
        # Bug fix: original string lacked the f-prefix, so the literal
        # text "{net_name}" appeared in the error message.
        raise ValueError(f"Network {net_name} not supported")
    return model
def create_task(task_name, args, nclasses, emb_layer):
    """Assemble a text-classification EmmentalTask for the configured model."""
    if args.model == "cnn":
        input_module = IdentityModule()
        feature_extractor = CNN(
            emb_layer.n_d, widths=[3, 4, 5], filters=args.n_filters
        )
        d_out = args.n_filters * 3
    elif args.model == "lstm":
        input_module = IdentityModule()
        feature_extractor = LSTM(
            emb_layer.n_d, args.dim, args.depth, dropout=args.dropout
        )
        d_out = args.dim
    elif args.model == "mlp":
        input_module = Average()
        feature_extractor = nn.Sequential(
            nn.Linear(emb_layer.n_d, args.dim), nn.ReLU()
        )
        d_out = args.dim
    else:
        raise ValueError(f"Unrecognized model {args.model}.")

    head_name = f"{task_name}_pred_head"

    # Linear pipeline: _input_.feature -> emb -> input -> feature -> dropout -> head.
    stages = ["emb", "input", "feature", "dropout", head_name]
    task_flow = [
        {"name": "emb", "module": "emb", "inputs": [("_input_", "feature")]}
    ]
    for upstream, stage in zip(stages, stages[1:]):
        task_flow.append(
            {"name": stage, "module": stage, "inputs": [(upstream, 0)]}
        )

    return EmmentalTask(
        name=task_name,
        module_pool=nn.ModuleDict(
            {
                "emb": emb_layer,
                "input": input_module,
                "feature": feature_extractor,
                "dropout": nn.Dropout(args.dropout),
                head_name: nn.Linear(d_out, nclasses),
            }
        ),
        task_flow=task_flow,
        loss_func=partial(ce_loss, head_name),
        output_func=partial(output, head_name),
        scorer=Scorer(metrics=["accuracy"]),
    )
def test_model_invalid_task(caplog):
    """Unit test of model with invalid task."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_model_with_invalid_task"

    Meta.reset()
    init(dirpath, config={"meta_config": {"verbose": 0}})

    task_name = "task1"
    head_name = f"{task_name}_pred_head"

    def build_task():
        # Two structurally identical tasks are needed to exercise the
        # duplicate-name check below.
        return EmmentalTask(
            name=task_name,
            module_pool=nn.ModuleDict(
                {
                    "input_module0": IdentityModule(),
                    head_name: IdentityModule(),
                }
            ),
            task_flow=[
                {
                    "name": "input1",
                    "module": "input_module0",
                    "inputs": [("_input_", "data")],
                },
                {
                    "name": head_name,
                    "module": head_name,
                    "inputs": [("input1", 0)],
                },
            ],
            module_device={"input_module0": -1},
            loss_func=None,
            output_func=None,
            action_outputs=None,
            scorer=None,
            require_prob_for_eval=False,
            require_pred_for_eval=True,
        )

    task = build_task()
    task1 = build_task()

    model = EmmentalModel(name="test")
    model.add_task(task)

    # Removing an existing task, then a nonexistent one, leaves no tasks.
    model.remove_task(task_name)
    assert model.task_names == set()
    model.remove_task("task_2")
    assert model.task_names == set()

    model.add_task(task)

    # Duplicate task
    with pytest.raises(ValueError):
        model.add_task(task1)

    # Invalid task
    with pytest.raises(ValueError):
        model.add_task(task_name)

    shutil.rmtree(dirpath)