def test_failed_to_repeat_network(self):
    # Build Input(10) -> Relu(5), then try to repeat the Relu taken out of
    # that network; ``layers.repeat`` is expected to refuse with a
    # LayerConnectionError mentioning the input/output shape mismatch.
    connected = layers.join(layers.Input(10), layers.Relu(5))
    connected.create_variables()

    expected_error = "input shape is incompatible with the output shape"
    relu_layer = connected.layers[1]

    with self.assertRaisesRegexp(LayerConnectionError, expected_error):
        layers.repeat(relu_layer, n=4)
def test_repeat_network(self):
    # A three-layer block repeated 5 times should yield 15 layers and an
    # output shape whose last dimension is 32.
    conv_block = layers.join(
        layers.Convolution((3, 3, 32)),
        layers.Relu(),
        layers.BatchNorm(),
    )
    repeated = layers.repeat(conv_block, n=5)

    n_layers = len(repeated)
    self.assertEqual(n_layers, 15)

    expected_shape = (None, None, None, 32)
    self.assertShapesEqual(repeated.output_shape, expected_shape)
def test_repeat_with_name_patterns(self):
    # A layer named with a ``{}`` placeholder, repeated 4 times, should
    # produce layers named rl1 .. rl4 in order.
    template = layers.Relu(10, name='rl{}')
    repeated = layers.repeat(template, n=4)

    actual_names = []
    for layer in repeated.layers:
        actual_names.append(layer.name)

    expected_names = ['rl1', 'rl2', 'rl3', 'rl4']
    self.assertSequenceEqual(actual_names, expected_names)
def test_repeat_once(self):
    # Repeating with n=1 must hand back the very same layer object.
    single_layer = layers.Relu(10)
    self.assertIs(layers.repeat(single_layer, n=1), single_layer)
def test_wrong_number_of_repeats(self):
    # Zero, a fractional float, and a float equal to a whole number
    # (9. / 3. == 3.0) must each be rejected with a ValueError.
    expected_error = "parameter should be a positive integer"
    invalid_counts = (0, 1.5, 9. / 3.)

    for n_repeats in invalid_counts:
        with self.assertRaisesRegexp(ValueError, expected_error):
            layers.repeat(layers.Relu(10), n=n_repeats)
def test_repeat_layer(self):
    # A single Relu(10) repeated 5 times should give 5 layers with
    # output shape (None, 10).
    repeated = layers.repeat(layers.Relu(10), n=5)

    n_layers = len(repeated)
    self.assertEqual(n_layers, 5)

    expected_shape = (None, 10)
    self.assertShapesEqual(repeated.output_shape, expected_shape)