Ejemplo n.º 1
0
    def _validate_export_model(self, model: Model) -> bool:
        """Check that exporting ``model`` leaves its output unchanged.

        Puts ``model`` in eval mode and runs ``TestModelParser.INPUT`` through
        it. Then calls ``model.export()``, runs the same input again, and
        compares the two outputs element-wise.

        Note: ``model.export()`` is called on the object passed in, so the
        caller's model stays in its exported state afterwards.

        Args:
            model: Model under test.

        Returns:
            True if every element of the post-export output is close to the
            pre-export output (``rtol=1e-5``, torch's default ``atol``),
            otherwise False.
        """
        model.eval()
        model_out = model(TestModelParser.INPUT)
        model.export()
        model_out_exported = model(TestModelParser.INPUT)

        # torch.allclose returns a plain Python bool, as the annotation
        # promises. torch.all(torch.isclose(...)) returned a 0-dim tensor.
        return torch.allclose(model_out, model_out_exported, rtol=1e-5)
Ejemplo n.º 2
0
    def __call__(self, repeat: int = 1) -> nn.Module:
        """Instantiate ``Model`` from the stored config, ``repeat`` times.

        Args:
            repeat: How many ``Model`` instances to build. When it is greater
                than 1, a list of independently constructed instances is
                built. Otherwise a single instance is built.

        Returns:
            The result of ``self._get_module`` applied to the built
            instance or list of instances.
        """
        # Repeating currently requires the yaml module's input and output
        # channel counts to be equal.
        module: Union[List[nn.Module], nn.Module] = (
            [Model(self.cfg, **self.kwargs) for _ in range(repeat)]
            if repeat > 1
            else Model(self.cfg, **self.kwargs)
        )
        return self._get_module(module)
Ejemplo n.º 3
0
 def test_example(self, verbose: bool = False):
     """Build ``example.yaml``; check output shape, parameter count, export."""
     cfg_file = os.path.join("tests", "test_configs", "example.yaml")
     model = Model(cfg_file, verbose=verbose)

     expected_shape = torch.Size([1, 10])
     assert model(TestModelParser.INPUT).shape == expected_shape
     assert count_model_params(model) == 137862
     assert self._validate_export_model(model)
Ejemplo n.º 4
0
 def test_vgg(self, verbose: bool = False):
     """Build ``vgg.yaml``; check output shape, parameter count, export."""
     cfg_file = os.path.join("tests", "test_configs", "vgg.yaml")
     model = Model(cfg_file, verbose=verbose)

     expected_shape = torch.Size([1, 10])
     assert model(TestModelParser.INPUT).shape == expected_shape
     assert count_model_params(model) == 3732970
     assert self._validate_export_model(model)
Ejemplo n.º 5
0
 def test_show_case(self, verbose: bool = False):
     """Build ``show_case.yaml``; check output shape and parameter count."""
     cfg_file = os.path.join("tests", "test_configs", "show_case.yaml")
     model = Model(cfg_file, verbose=verbose)

     expected_shape = torch.Size([1, 10])
     assert model(TestModelParser.INPUT).shape == expected_shape
     assert count_model_params(model) == 168866
Ejemplo n.º 6
0
 def test_gap_model(self, verbose: bool = False):
     """Test the model built from ``gap_test_model.yaml``.

     Checks the output shape for ``TestModelParser.INPUT`` and the total
     parameter count.
     """
     model = Model(
         os.path.join("tests", "test_configs", "gap_test_model.yaml"),
         verbose=verbose,
     )
     assert model(TestModelParser.INPUT).shape == torch.Size([1, 10])
     assert count_model_params(model) == 20148
Ejemplo n.º 7
0
 def test_pretrained2(self, verbose: bool = False):
     """Test the model built from ``pretrained_example2.yaml``.

     Checks the output shape for ``TestModelParser.INPUT`` and the total
     parameter count, and that exporting the model leaves its output
     unchanged.
     """
     model = Model(
         os.path.join("tests", "test_configs", "pretrained_example2.yaml"),
         verbose=verbose,
     )
     assert model(TestModelParser.INPUT).shape == torch.Size([1, 10])
     assert count_model_params(model) == 3760122
     assert self._validate_export_model(model)
Ejemplo n.º 8
0
    def __init__(self, *args, **kwargs) -> None:
        """Initialize YamlModuleGenerator from a YAML module config file.

        Reads the YAML file at ``self.args[0]`` into ``self.cfg``. Moves its
        top-level ``module`` section to a ``backbone`` key and fills in
        ``input_channel``, ``depth_multiple`` (fixed at 1.0) and
        ``width_multiple``. Then builds ``self.module`` from the result.
        ``self.args``, ``self.in_channel`` and ``self.width_multiply`` are
        presumably set by the base-class ``__init__`` (confirm against it).

        Raises:
            KeyError: If the YAML config has no top-level ``module`` key.
        """
        super().__init__(*args, **kwargs)
        # Decode explicitly as UTF-8 instead of the platform locale default.
        # NOTE(review): FullLoader can construct some Python objects. If
        # config paths can come from untrusted sources, use yaml.safe_load.
        with open(self.args[0], "r", encoding="utf-8") as f:
            self.cfg = yaml.load(f, yaml.FullLoader)

        try:
            backbone = self.cfg.pop("module")
        except KeyError as err:
            raise KeyError(
                f"YAML module config {self.args[0]!r} has no 'module' section"
            ) from err

        self.cfg.update(
            {
                "input_channel": self.in_channel,
                "depth_multiple": 1.0,
                "width_multiple": self.width_multiply,
                "backbone": backbone,
            }
        )
        self.module = Model(self.cfg, verbose=False)