Example #1
0
    def test_group_norm_connection_exception(self):
        """Connecting GroupNorm to incompatible inputs raises LayerConnectionError.

        Two cases are checked:
        * 11 channels cannot be split evenly into 4 groups.
        * A 2-D ``Input((10, 12))`` produces a 3-dimensional tensor (batch
          included), but GroupNorm expects a 4-dimensional one.
        """
        # ``assertRaisesRegex`` replaces the deprecated ``assertRaisesRegexp``
        # alias, which was removed in Python 3.12.
        message = "Cannot divide 11 input channels into 4 groups"
        with self.assertRaisesRegex(LayerConnectionError, message):
            layers.join(
                layers.Input((10, 10, 11)),
                layers.GroupNorm(4),
            )

        message = ("Group normalization layer expects 4 "
                   "dimensional input, got 3 instead.")
        with self.assertRaisesRegex(LayerConnectionError, message):
            layers.join(
                layers.Input((10, 12)),
                layers.GroupNorm(4),
            )
Example #2
0
    def test_group_norm_unknown_shape(self):
        """GroupNorm accepts inputs whose spatial dimensions are unknown.

        The network is declared with ``None`` height/width and 16 channels.
        It is then evaluated on a concrete 6x6 batch of 7 samples, and the
        output must keep the input's shape.
        """
        network = layers.join(
            layers.Input((None, None, 16)),
            layers.GroupNorm(4),
        )

        # Named ``input_value`` rather than ``input`` to avoid shadowing the builtin.
        input_value = np.random.random((7, 6, 6, 16))
        actual_output = self.eval(network.output(input_value))
        self.assertEqual(actual_output.shape, (7, 6, 6, 16))
Example #3
0
    def test_group_norm(self):
        """GroupNorm keeps the shape of a (10, 10, 12) input unchanged.

        The static shapes are checked first. The runtime output shape is then
        checked on a random batch of 7 samples.
        """
        network = layers.join(
            layers.Input((10, 10, 12)),
            layers.GroupNorm(4),
        )
        self.assertShapesEqual(network.input_shape, (None, 10, 10, 12))
        self.assertShapesEqual(network.output_shape, (None, 10, 10, 12))

        # Named ``input_value`` rather than ``input`` to avoid shadowing the builtin.
        input_value = np.random.random((7, 10, 10, 12))
        actual_output = self.eval(network.output(input_value))
        self.assertEqual(actual_output.shape, (7, 10, 10, 12))
Example #4
0
    def test_group_norm_weight_init_exception(self):
        """GroupNorm cannot create its variables when the channel count is unknown.

        Building the network with ``None`` channels succeeds. However, both an
        explicit ``create_variables()`` call and accessing ``network.outputs``
        must raise WeightInitializationError.
        """
        network = layers.join(
            layers.Input((10, 10, None)),
            layers.GroupNorm(4),
        )

        # ``assertRaisesRegex`` replaces the deprecated ``assertRaisesRegexp``
        # alias, which was removed in Python 3.12.
        message = ("Cannot initialize variables when number of "
                   "channels is unknown.")
        with self.assertRaisesRegex(WeightInitializationError, message):
            network.create_variables()

        with self.assertRaisesRegex(WeightInitializationError, message):
            # Attribute access alone is expected to trigger the error.
            network.outputs
Example #5
0
 def test_group_norm_repr(self):
     """str() of a GroupNorm(4) lists its settings and auto-assigned name."""
     expected_repr = (
         "GroupNorm(n_groups=4, beta=Constant(0), "
         "gamma=Constant(1), epsilon=1e-05, name='group-norm-1')"
     )
     group_norm = layers.GroupNorm(4)
     self.assertEqual(str(group_norm), expected_repr)