コード例 #1
0
 def test_example(self, verbose: bool = False):
     """Build the model from ``example.yaml`` and check shape, size and export.

     Args:
         verbose: forwarded to ``Model``.
     """
     config_path = os.path.join("tests", "test_configs", "example.yaml")
     model = Model(config_path, verbose=verbose)

     expected_shape = torch.Size([1, 10])
     assert model(TestModelParser.INPUT).shape == expected_shape
     assert count_model_params(model) == 137862
     assert self._validate_export_model(model)
コード例 #2
0
 def test_vgg(self, verbose: bool = False):
     """Build the model from ``vgg.yaml`` and check shape, size and export.

     Args:
         verbose: forwarded to ``Model``.
     """
     config_path = os.path.join("tests", "test_configs", "vgg.yaml")
     model = Model(config_path, verbose=verbose)

     expected_shape = torch.Size([1, 10])
     assert model(TestModelParser.INPUT).shape == expected_shape
     assert count_model_params(model) == 3732970
     assert self._validate_export_model(model)
コード例 #3
0
def test_custom_module(verbose: bool = False):
    """Build the model from ``custom_module_model.yaml`` and check shape and size.

    A random ``(1, 3, 32, 32)`` tensor is fed through the model. Only the
    output shape is asserted, so the random values themselves do not matter.

    Args:
        verbose: forwarded to ``Model``.
    """
    cfg = os.path.join("tests", "test_configs", "custom_module_model.yaml")
    model = Model(cfg, verbose=verbose)

    dummy_input = torch.rand(1, 3, 32, 32)
    assert model(dummy_input).shape == torch.Size([1, 10])
    assert count_model_params(model) == 138568
コード例 #4
0
 def test_pretrained2(self, verbose: bool = False):
     """Test pretrained_example2 model.

     Builds the model from ``tests/test_configs/pretrained_example2.yaml`` and
     checks its output shape, its parameter count and that export validation
     succeeds.

     Args:
         verbose: forwarded to ``Model``.
     """
     model = Model(
         os.path.join("tests", "test_configs", "pretrained_example2.yaml"),
         verbose=verbose,
     )
     assert model(TestModelParser.INPUT).shape == torch.Size([1, 10])
     assert count_model_params(model) == 3760122
     assert self._validate_export_model(model)
コード例 #5
0
 def test_show_case(self, verbose: bool = False):
     """Build the model from ``show_case.yaml`` and check size, shape and export.

     Args:
         verbose: forwarded to ``Model``.
     """
     config_path = os.path.join("tests", "test_configs", "show_case.yaml")
     model = Model(config_path, verbose=verbose)

     # The parameter count is asserted before the forward pass, as in the
     # original ordering of this test.
     assert count_model_params(model) == 185666
     expected_shape = torch.Size([1, 10])
     assert model(TestModelParser.INPUT).shape == expected_shape
     assert self._validate_export_model(model)