def test_instantiate_other_obj(self):
    """Inputs that are not lazy configs should come back from ``instantiate`` unchanged."""
    # a bare scalar
    self.assertEqual(instantiate(5), 5)

    # a plain list: only equality is checked here, not identity
    numbers = [3, 4, 5]
    self.assertEqual(instantiate(numbers), numbers)

    # an already-constructed object is returned by identity
    built = TestClass(1)
    self.assertIs(instantiate(built), built)

    # a plain dict with no ``_target_`` key is returned by identity
    mapping = {"xx": "yy"}
    self.assertIs(instantiate(mapping), mapping)
def test_instantiate_lst(self):
    """A lazy config nested inside a list argument is instantiated recursively."""
    nested = [1, 2, L(TestClass)(int_arg=1)]
    cfg = L(TestClass)(int_arg=nested)
    result = instantiate(cfg).int_arg

    # the plain leading elements are preserved
    self.assertEqual(result[:2], [1, 2])

    # the lazy element has been turned into a real TestClass
    built = result[2]
    self.assertIsInstance(built, TestClass)
    self.assertEqual(built.int_arg, 1)
def new_f(cls, save_path, **load_kwargs):
    """Load a model, wrapping it for inference when it was exported tracing-adapted.

    ``tracing_adapted`` is popped from ``load_kwargs`` (default ``False``).
    If it is falsy, the call is delegated straight to ``old_f``. Otherwise
    ``inputs_schema`` and ``outputs_schema`` must be present in ``load_kwargs``.
    Both are popped and passed through ``instantiate``. The remaining kwargs go
    to ``old_f``, and its result is wrapped in ``TracingAdapterModelWrapper``.
    """
    is_adapted = load_kwargs.pop("tracing_adapted", False)

    if is_adapted:
        logger.info(
            "The model is tracing adapted, load the schema and wrap the model for inference."
        )
        # both schemas are required; report the available keys if one is missing
        for required_key in ("inputs_schema", "outputs_schema"):
            assert required_key in load_kwargs, load_kwargs.keys()

        in_schema = instantiate(load_kwargs.pop("inputs_schema"))
        out_schema = instantiate(load_kwargs.pop("outputs_schema"))
        # the schema entries are removed above, so old_f never sees them
        loaded = old_f(cls, save_path, **load_kwargs)
        return TracingAdapterModelWrapper(loaded, in_schema, out_schema)

    logger.info("The model is not tracing adapted, load it normally.")
    return old_f(cls, save_path, **load_kwargs)
def _check_schema(self, schema):
    """Assert that ``schema`` survives a ``dump_dataclass`` -> ``instantiate`` round trip.

    The dumped form must also be accepted by ``json.dumps``.
    """
    dumped = dump_dataclass(schema)

    # JSON-serializability check. yaml may be the more practical format in
    # real use, because schemas tend to nest many levels deep.
    json.dumps(dumped)

    # rebuilding from the dumped form must give back an equal schema
    restored = instantiate(dumped)
    self.assertEqual(schema, restored)
def test_basic_construct(self):
    """Instantiate a lazy TestClass whose nested config interpolates a parent field."""
    objconf = L(TestClass)(
        int_arg=3,
        list_arg=[10],
        dict_arg={},
        # "${..list_arg}" refers one level up, to the outer config's list_arg
        extra_arg=L(TestClass)(int_arg=4, list_arg="${..list_arg}"),
    )

    first = instantiate(objconf)
    self.assertIsInstance(first, TestClass)
    self.assertEqual(first.int_arg, 3)
    self.assertEqual(first.extra_arg.int_arg, 4)
    # the interpolated value must match the outer object's list_arg
    self.assertEqual(first.extra_arg.list_arg, first.list_arg)

    # replacing the interpolation with a literal is reflected after re-instantiation
    objconf.extra_arg.list_arg = [5]
    second = instantiate(objconf)
    self.assertIsInstance(second, TestClass)
    self.assertEqual(second.extra_arg.list_arg, [5])
def load(cls, save_path, inputs_schema, outputs_schema, **load_kwargs):
    """Load a TorchScript model from ``save_path`` and wrap it with schemas.

    ``inputs_schema`` and ``outputs_schema`` are passed through ``instantiate``
    before use. The returned module flattens its positional arguments with
    ``flatten_to_tuple`` and feeds them to the loaded model. It then passes the
    model's result to the instantiated ``outputs_schema``.

    NOTE(review): ``load_kwargs`` is accepted but not used here. Confirm
    whether ``load_model`` is meant to receive it.
    """
    in_schema = instantiate(inputs_schema)
    out_schema = instantiate(outputs_schema)
    scripted = load_model(save_path, "torchscript")

    class TracingAdapterWrapper(nn.Module):
        """Adapt structured positional inputs/outputs around a flat-signature model."""

        def __init__(self, traced_model, inputs_schema, outputs_schema):
            super().__init__()
            self.traced_model = traced_model
            # kept on the module; forward() below only consults outputs_schema
            self.inputs_schema = inputs_schema
            self.outputs_schema = outputs_schema

        def forward(self, *input_args):
            flat_inputs, _unused = flatten_to_tuple(input_args)
            flat_outputs = self.traced_model(*flat_inputs)
            return self.outputs_schema(flat_outputs)

    return TracingAdapterWrapper(scripted, in_schema, out_schema)
def new_f(cls, save_path, **load_kwargs):
    """Load a traced model via ``old_f`` and wrap it to accept/return structured values.

    ``load_kwargs`` must contain ``inputs_schema`` and ``outputs_schema``. Both
    are popped and passed through ``instantiate``. The remaining kwargs are
    forwarded to ``old_f``.

    Returns an ``nn.Module`` whose ``forward`` flattens its positional arguments
    with ``flatten_to_tuple``, calls the loaded model, and passes the result to
    the instantiated ``outputs_schema``.
    """
    # Include the available keys in the failure message so a missing schema
    # is easy to diagnose.
    assert "inputs_schema" in load_kwargs, load_kwargs.keys()
    assert "outputs_schema" in load_kwargs, load_kwargs.keys()
    inputs_schema = instantiate(load_kwargs.pop("inputs_schema"))
    outputs_schema = instantiate(load_kwargs.pop("outputs_schema"))
    # the schema entries are removed above, so old_f never sees them
    traced_model = old_f(cls, save_path, **load_kwargs)

    class TracingAdapterModelWrapper(nn.Module):
        """Feed flattened positional inputs to the traced model and rebuild its outputs."""

        def __init__(self, traced_model, inputs_schema, outputs_schema):
            super().__init__()
            self.traced_model = traced_model
            # kept on the module; forward() below only consults outputs_schema
            self.inputs_schema = inputs_schema
            self.outputs_schema = outputs_schema

        def forward(self, *input_args):
            flattened_inputs, _ = flatten_to_tuple(input_args)
            flattened_outputs = self.traced_model(*flattened_inputs)
            return self.outputs_schema(flattened_outputs)

    return TracingAdapterModelWrapper(traced_model, inputs_schema, outputs_schema)
def test_instantiate_namedtuple(self):
    """A namedtuple argument survives a YAML save/load round trip and instantiation."""
    cfg = L(TestClass)(int_arg=ShapeSpec(channels=1, width=3))

    # serialize to a temporary file and read it back
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "d2_test.yaml")
        OmegaConf.save(cfg, path)
        with open(path) as stream:
            # unsafe_load is applied only to the file this test just wrote
            loaded = yaml.unsafe_load(stream)

    obj = instantiate(loaded)
    self.assertIsInstance(obj.int_arg, ShapeSpec)
    self.assertEqual(obj.int_arg.channels, 1)
def test_instantiate_lazy_target(self):
    """A ``_target_`` that is itself a lazy config gets instantiated first."""
    inner = L(len)(int_arg=3)
    objconf = L(inner)(call_arg=4)
    # swap the inner target (``len`` above) for TestClass before instantiating
    objconf._target_._target_ = TestClass
    # expected result after constructing with int_arg=3 and calling with call_arg=4
    self.assertEqual(instantiate(objconf), 7)