Exemplo n.º 1
0
    def test_subgraph_pruning(self):
        """Checks that the pruned graph carries fewer parameters than the original.

        Counts the entries under the "params" collection of `self.new_state`
        and of `self.state`, and requires the former to be strictly smaller.
        """
        def _count(state):
            # Unfreeze so the "params" collection can be indexed as a plain dict.
            return parameter_overview.count_parameters(
                flax.core.unfreeze(state)["params"])

        pruned_total = _count(self.new_state)
        original_total = _count(self.state)
        self.assertLess(pruned_total, original_total)
Exemplo n.º 2
0
    def test_count_parameters_empty(self):
        """Counts parameters on a bare Sonnet module, then after adding one variable."""
        mod = snt.Module()
        snt.allow_empty_variables(mod)

        # Freshly constructed: the module owns no variables yet.
        self.assertEqual(0, parameter_overview.count_parameters(mod))

        # Attach a single 2-element variable; the count should follow it.
        mod.var = tf.Variable([0, 1])
        self.assertEqual(2, parameter_overview.count_parameters(mod))
Exemplo n.º 3
0
 def test_count_parameters(self):
     """Checks the parameter count of `CNN` initialized on a 3-channel input."""
     # Weights of a 2D convolution with 2 filters, a 3x3 kernel and 3 input
     # channels: 3 * (3 * 3) * 2 weights + 2 biases = 56 parameters.
     expected = 3 * (3 * 3) * 2 + 2
     key = jax.random.PRNGKey(42)
     dummy_batch = jnp.zeros((2, 5, 5, 3))
     variables = CNN().init(key, dummy_batch)
     self.assertEqual(
         expected, parameter_overview.count_parameters(variables["params"]))
Exemplo n.º 4
0
 def test_count_parameters_on_module(self):
     """Checks that variables of a called Conv2D submodule are counted."""
     module = snt.Module()
     # 2 output channels with a 3x3 kernel. Calling it once on a 3-channel
     # input creates its variables: 3 * 3*3 * 2 weights + 2 biases = 56.
     conv = snt.Conv2D(output_channels=2, kernel_shape=3, name="conv")
     module.conv = conv
     module.conv(tf.ones((2, 5, 5, 3)))
     self.assertEqual(56, parameter_overview.count_parameters(module))
Exemplo n.º 5
0
 def _initialize_train(self):
     """Builds the training input and, on first use, initializes the network.

     Always (re)builds `self._train_input`. If `self._params` is still None,
     initializes the network under `jax.pmap`, logs the per-device parameter
     count, creates the optimizer via `self._make_opt()` and initializes
     `self._opt_state` from the new parameters.
     """
     self._train_input = self._build_train_input()
     if self._params is not None:
         return  # Parameters already exist; nothing else to initialize.

     # Dummy all-ones image batch, used only to trace the network's shapes.
     input_shape = (1, self.config.image_size, self.config.image_size, 3)
     inputs = jnp.ones(input_shape, jnp.float32)
     # NOTE(review): pmap maps over axis 0 of every argument, but `inputs` has
     # leading size 1 while `init_rng` is presumably one key per local device;
     # this looks like it only works on single-device hosts — confirm.
     init_net = jax.pmap(lambda *a: self.net.init(*a, is_training=True),
                         axis_name='i')
     # Presumably replicates `self.init_rng` onto each local device.
     init_rng = jl_utils.bcast_local_devices(self.init_rng)
     self._params = init_net(init_rng, inputs)
     # The pmapped params hold one replica per local device, so divide by the
     # device count. Integer division keeps the logged count an int instead of
     # a float such as "1234.0".
     num_params = count_parameters(self._params) // jax.local_device_count()
     logging.info('Net params: %d', num_params)
     self._make_opt()
     self._opt_state = self._opt.init(self._params)
Exemplo n.º 6
0
 def test_count_parameters_empty(self):
     """An empty parameter tree should report zero parameters."""
     no_params = {}
     self.assertEqual(0, parameter_overview.count_parameters(no_params))
Exemplo n.º 7
0
 def test_cnn_params(self):
     """Pins the total size of the "params" collection in `self.cnn_state`."""
     expected_count = 2192458
     cnn_params = flax.core.unfreeze(self.cnn_state)["params"]
     self.assertEqual(
         parameter_overview.count_parameters(cnn_params), expected_count)