Ejemplo n.º 1
0
    def test_get_parameter_overview_empty(self):
        """A Sonnet module whose variables are not yet built has an empty overview."""
        root = snt.Module()
        snt.allow_empty_variables(root)

        # Nothing has been attached to the module, so there are no variables.
        overview = parameter_overview.get_parameter_overview(root)
        self.assertEqual(EMPTY_PARAMETER_OVERVIEW, overview)

        # Attach a conv layer. Its variables are only created on the first
        # forward pass, so the overview is still expected to be empty.
        root.conv = snt.Conv2D(output_channels=2, kernel_shape=3)
        overview = parameter_overview.get_parameter_overview(root)
        self.assertEqual(EMPTY_PARAMETER_OVERVIEW, overview)
Ejemplo n.º 2
0
 def test_get_parameter_overview(self):
     """Checks the overview of a Flax CNN's params, with and without stats."""
     rng = jax.random.PRNGKey(42)
     # Weights of a 2D convolution with 2 filters.
     variables = CNN().init(rng, jnp.zeros((2, 5, 5, 3)))
     params = variables["params"]
     self.assertEqual(
         FLAX_CONV2D_PARAMETER_OVERVIEW,
         parameter_overview.get_parameter_overview(params, include_stats=False))
     # Without `include_stats`, the overview is expected to contain stats.
     self.assertEqual(
         FLAX_CONV2D_PARAMETER_OVERVIEW_WITH_STATS,
         parameter_overview.get_parameter_overview(params))
Ejemplo n.º 3
0
 def test_get_parameter_overview_on_module(self):
   """Checks the overview of a built Sonnet Conv2D, with and without stats."""
   module = snt.Module()
   # A 2D convolution with 2 filters and a 3x3 kernel.
   module.conv = snt.Conv2D(output_channels=2, kernel_shape=3, name="conv")
   # Calling the layer on 3-channel input creates its variables: a kernel of
   # 3^2 * 3 * 2 = 54 weights plus (Sonnet's default with_bias=True) a bias of
   # 2 entries, i.e. 56 parameters in total.
   module.conv(tf.ones((2, 5, 5, 3)))
   # Overwrite the initial values with ones so the expected stats do not
   # depend on initialization.
   for v in module.variables:
     v.assign(tf.ones_like(v))
   self.assertEqual(
       SNT_CONV2D_PARAMETER_OVERVIEW,
       parameter_overview.get_parameter_overview(module, include_stats=False))
   self.assertEqual(SNT_CONV2D_PARAMETER_OVERVIEW_WITH_STATS,
                    parameter_overview.get_parameter_overview(module))
Ejemplo n.º 4
0
 def test_get_parameter_overview_empty(self):
     """An empty parameter dict yields the empty overview."""
     no_params = {}
     overview = parameter_overview.get_parameter_overview(no_params)
     self.assertEqual(EMPTY_PARAMETER_OVERVIEW, overview)