def test_build_head(self):
    """build_head honors unique_id and fails on a config without it."""
    cfg = self._get_config()
    built = build_head(cfg)
    self.assertEqual(built.unique_id, cfg["unique_id"])

    # Once "unique_id" is removed, building is expected to raise AssertionError.
    del cfg["unique_id"]
    with self.assertRaises(AssertionError):
        build_head(cfg)
def test_identity_forward(self):
    """A pass-through head returns its input unchanged (same size and values)."""
    config = self._get_pass_through_config()
    head = build_head(config)
    # Named input_tensor rather than `input` to avoid shadowing the builtin.
    input_tensor = torch.randn(1, config["in_plane"])
    output = head(input_tensor)
    self.assertEqual(input_tensor.size(), output.size())
    # TestCase.assert_ was a deprecated alias of assertTrue and was removed in
    # Python 3.12; use the supported spelling.
    self.assertTrue(torch.equal(input_tensor, output))
def build_model(config):
    """Builds a ClassyModel from a config.

    This assumes a 'name' key in the config which is used to determine what
    model class to instantiate. For instance, a config
    `{"name": "my_model", "foo": "bar"}` will find a class that was registered
    as "my_model" (see :func:`register_model`) and call .from_config on it.

    If the config contains a "heads" key, each head config is deep-copied,
    its "fork_block" key is removed, and the result is built with
    :func:`build_head`. The built heads are grouped into lists keyed by their
    fork_block and passed to ``model.set_heads``. This also happens when
    "heads" is an empty list.

    Args:
        config: Model config dict. Must contain "name". May contain "heads",
            a list of head config dicts, each of which must contain
            "fork_block".

    Returns:
        The model returned by the registered class's ``from_config``.

    Raises:
        KeyError: if ``config`` has no "name" key.
        AssertionError: if the model name is not registered, or a head config
            has no "fork_block" key.
    """
    # Raised explicitly instead of via `assert` so the checks still run under
    # `python -O`; AssertionError is kept so existing callers that catch it
    # keep working.
    if config["name"] not in MODEL_REGISTRY:
        raise AssertionError("unknown model")
    model = MODEL_REGISTRY[config["name"]].from_config(config)

    if "heads" in config:
        heads = defaultdict(list)
        for head_config in config["heads"]:
            if "fork_block" not in head_config:
                raise AssertionError("Expect fork_block in config")
            fork_block = head_config["fork_block"]
            # Work on a deep copy so deleting "fork_block" does not mutate the
            # caller's config.
            updated_config = copy.deepcopy(head_config)
            del updated_config["fork_block"]
            heads[fork_block].append(build_head(updated_config))
        model.set_heads(heads)

    log_class_usage("Model", model.__class__)
    return model
def test_conv_planes(self):
    """A fully_connected head with conv_planes maps images to (batch, num_classes)."""
    num_classes, in_plane, conv_planes = 10, 3, 5
    batch_size, image_size = 2, 4

    head_config = {
        "name": "fully_connected",
        "unique_id": "asd",
        "in_plane": in_plane,
        "conv_planes": conv_planes,
        "num_classes": num_classes,
    }
    self.assertIsInstance(build_head(head_config), FullyConnectedHead)

    # Rebuild the same head, this time with an explicit activation.
    head_config["activation"] = "relu"
    head = build_head(head_config)

    # A forward pass must succeed and yield one row of logits per sample.
    images = torch.rand([batch_size, in_plane, image_size, image_size])
    logits = head(images)
    self.assertEqual(logits.shape, (batch_size, num_classes))
def test_get_set_head_states(self):
    """Setting, clearing and re-setting heads round-trips the classy state."""
    config = copy.deepcopy(self._get_config(self.model_configs[0]))
    model_config = config["model"]
    head_configs = model_config["heads"]

    # Build the trunk with no heads so its state can serve as a baseline.
    model_config["heads"] = []
    model = build_model(model_config)
    trunk_state = model.get_classy_state()

    # NOTE(review): heads are grouped here as {fork_block: {unique_id: head}},
    # whereas build_model groups them as {fork_block: [head, ...]}; confirm
    # which layout set_heads actually expects.
    heads = defaultdict(dict)
    for head_config in head_configs:
        head = build_head(head_config)
        heads[head_config["fork_block"]][head.unique_id] = head
    model.set_heads(heads)
    model_state = model.get_classy_state()

    # get_heads should report exactly what was set.
    self.assertEqual(len(heads), len(model.get_heads()))
    for block_name, block_heads in model.get_heads().items():
        self.assertEqual(block_heads, heads[block_name])

    # Clearing heads brings the state back to the bare trunk...
    model._clear_heads()
    self._compare_model_state(model.get_classy_state(), trunk_state)

    # ...and setting them again restores the full-model state.
    model.set_heads(heads)
    self._compare_model_state(model.get_classy_state(), model_state)
def test_forward(self):
    """The head from the default test config maps (1, in_plane) to (1, 3)."""
    cfg = self._get_config()
    head = build_head(cfg)
    sample = torch.randn(1, cfg["in_plane"])
    self.assertEqual(head(sample).size(), torch.Size([1, 3]))