def test_dropout(self):
  """Dropout(rate=0.1) leaves an (8, 7, 9) input shape unchanged.

  The check is repeated for the "train" mode and then the "eval" mode.
  """
  shape = (8, 7, 9)
  for mode in ("train", "eval"):
    dropout = core.Dropout(rate=0.1, mode=mode)
    self.assertEqual(base.check_shape_agreement(dropout, shape), shape)
def test_flatten_n(self):
  """Flatten merges every axis after the first `n_axes_to_keep` into one.

  The cases cover the default constructor and n_axes_to_keep = 2, 3, 4 on a
  rank-5 input. Asking to keep 5 or 6 axes of that rank-5 input must raise
  `base.LayerError`.
  """
  in_shape = (29, 87, 10, 20, 30)
  # (n_axes_to_keep or None for the default constructor, expected shape)
  cases = (
      (None, (29, 87 * 10 * 20 * 30)),
      (2, (29, 87, 10 * 20 * 30)),
      (3, (29, 87, 10, 20 * 30)),
      (4, (29, 87, 10, 20, 30)),
  )
  for keep, want in cases:
    flatten = core.Flatten() if keep is None else core.Flatten(
        n_axes_to_keep=keep)
    got = base.check_shape_agreement(flatten, in_shape)
    self.assertEqual(got, want)

  # Not enough dimensions: keeping >= rank axes leaves nothing to flatten.
  for keep in (5, 6):
    with self.assertRaises(base.LayerError):
      base.check_shape_agreement(core.Flatten(n_axes_to_keep=keep), in_shape)
def test_layer_decorator_and_shape_agreement(self):
  """A `base.layer()`-decorated elementwise function keeps a (12, 17) shape."""

  @base.layer()
  def add_one(x, **unused_kwargs):
    return x + 1

  plus_one = add_one()  # pylint: disable=no-value-for-parameter
  result = base.check_shape_agreement(plus_one, (12, 17))
  self.assertEqual(result, (12, 17))
def test_parallel_no_ops(self):
  """Parallel built from an empty list and None leaves both input shapes as is."""
  shapes = ((3, 2), (4, 7))
  passthrough = cb.Parallel([], None)
  self.assertEqual(base.check_shape_agreement(passthrough, shapes), shapes)
def test_parallel_div_div(self):
  """Running two Div layers in parallel preserves each input's shape."""
  shapes = ((3, 2), (4, 7))
  divs = cb.Parallel(core.Div(divisor=0.5), core.Div(divisor=3.0))
  self.assertEqual(base.check_shape_agreement(divs, shapes), shapes)
def test_parallel_dup_dup(self):
  """A Dup on each of two parallel branches yields each shape twice, in order."""
  left, right = (3, 2), (4, 7)
  dups = cb.Parallel(cb.Dup(), cb.Dup())
  result = base.check_shape_agreement(dups, (left, right))
  self.assertEqual(result, (left, left, right, right))
def test_serial_dup_dup(self):
  """Two Dups in series turn one (3, 2) input into three (3, 2) outputs."""
  shape = (3, 2)
  chain = cb.Serial(cb.Dup(), cb.Dup())
  self.assertEqual(base.check_shape_agreement(chain, shape), (shape,) * 3)
def test_serial_div_div(self):
  """Chaining two Div layers keeps a (3, 2) shape."""
  shape = (3, 2)
  chain = cb.Serial(core.Div(divisor=2.0), core.Div(divisor=5.0))
  self.assertEqual(base.check_shape_agreement(chain, shape), shape)
def test_serial_no_op_list(self):
  """Serial over an empty list leaves both input shapes unchanged."""
  shapes = ((3, 2), (4, 7))
  empty_chain = cb.Serial([])
  self.assertEqual(base.check_shape_agreement(empty_chain, shapes), shapes)
def test_swap(self):
  """Swap exchanges the order of its two input shapes."""
  top, below = (3, 2), (4, 7)
  swapped = base.check_shape_agreement(cb.Swap(), (top, below))
  self.assertEqual(swapped, (below, top))
def test_drop(self):
  """Drop applied to a single (3, 2) input produces an empty shape tuple."""
  dropped = base.check_shape_agreement(cb.Drop(), (3, 2))
  self.assertEqual(dropped, ())
def _test_cell_runs(self, layer, input_shape, output_shape):
  """Assert that `layer` maps `input_shape` to `output_shape`.

  Args:
    layer: the layer under test.
    input_shape: shape passed to `base.check_shape_agreement`.
    output_shape: the shape that check is expected to return.
  """
  self.assertEqual(output_shape,
                   base.check_shape_agreement(layer, input_shape))
def test_div_shapes(self):
  """Div(divisor=2.0) keeps a (3, 2) shape."""
  shape = (3, 2)
  halve = core.Div(divisor=2.0)
  self.assertEqual(base.check_shape_agreement(halve, shape), shape)
def test_causal_conv(self):
  """CausalConv(filters=30, kernel_width=3) maps (29, 5, 20) to (29, 5, 30).

  Only the last axis changes: it becomes the number of filters.
  """
  layer = convolution.CausalConv(filters=30, kernel_width=3)
  out = base.check_shape_agreement(layer, (29, 5, 20))
  self.assertEqual(out, (29, 5, 30))
def test_conv_rebatch(self):
  """Conv(30, (3, 3)) on a rank-5 (3, 29, 5, 5, 20) input gives (3, 29, 3, 3, 30).

  The two leading axes are preserved. The 5x5 middle axes shrink to 3x3,
  which looks like unpadded (VALID) convolution; confirm against Conv's
  default padding. The test name suggests the point is that Conv accepts an
  extra leading batch-like axis.
  """
  conv = convolution.Conv(30, (3, 3))
  out = base.check_shape_agreement(conv, (3, 29, 5, 5, 20))
  self.assertEqual(out, (3, 29, 3, 3, 30))
def test_layer_norm_shape(self):
  """LayerNorm returns the same shape it receives for a rank-4 input."""
  shape = (29, 5, 7, 20)
  norm = normalization.LayerNorm()
  self.assertEqual(base.check_shape_agreement(norm, shape), shape)