Example #1
0
 def test_dropout(self):
     """Dropout leaves the input shape unchanged in both train and eval mode."""
     shape = (8, 7, 9)
     for mode in ("train", "eval"):
         layer = core.Dropout(rate=0.1, mode=mode)
         self.assertEqual(base.check_shape_agreement(layer, shape), shape)
Example #2
0
    def test_flatten_n(self):
        """Flatten merges every axis after the first `n_axes_to_keep` into one.

        With no argument only the leading axis is kept. Asking to keep as
        many axes as the input has (or more) must raise `base.LayerError`.
        """
        input_shape = (29, 87, 10, 20, 30)

        # Pairs of (Flatten constructor kwargs, expected output shape);
        # empty kwargs exercises the layer's default.
        cases = [
            ({}, (29, 87 * 10 * 20 * 30)),
            ({"n_axes_to_keep": 2}, (29, 87, 10 * 20 * 30)),
            ({"n_axes_to_keep": 3}, (29, 87, 10, 20 * 30)),
            ({"n_axes_to_keep": 4}, (29, 87, 10, 20, 30)),
        ]
        for kwargs, expected_shape in cases:
            layer = core.Flatten(**kwargs)
            actual_shape = base.check_shape_agreement(layer, input_shape)
            self.assertEqual(actual_shape, expected_shape)

        # Not enough dimensions: n_axes_to_keep >= input rank is rejected.
        for too_many in (5, 6):
            with self.assertRaises(base.LayerError):
                base.check_shape_agreement(
                    core.Flatten(n_axes_to_keep=too_many), input_shape)
Example #3
0
  def test_layer_decorator_and_shape_agreement(self):
    """A function decorated with @base.layer() instantiates and shape-checks."""

    @base.layer()
    def add_one(x, **unused_kwargs):
      return x + 1

    # The decorated function is called with no args to build the layer.
    layer = add_one()  # pylint: disable=no-value-for-parameter
    shape = base.check_shape_agreement(layer, (12, 17))
    self.assertEqual(shape, (12, 17))
Example #4
0
 def test_parallel_no_ops(self):
     """Parallel built from empty branches leaves both input shapes intact."""
     shapes = ((3, 2), (4, 7))
     result = base.check_shape_agreement(cb.Parallel([], None), shapes)
     self.assertEqual(result, ((3, 2), (4, 7)))
Example #5
0
 def test_parallel_div_div(self):
     """Each Div branch of a Parallel keeps the shape of its own input."""
     left = core.Div(divisor=0.5)
     right = core.Div(divisor=3.0)
     shapes = ((3, 2), (4, 7))
     result = base.check_shape_agreement(cb.Parallel(left, right), shapes)
     self.assertEqual(result, ((3, 2), (4, 7)))
Example #6
0
 def test_parallel_dup_dup(self):
     """Parallel(Dup, Dup) duplicates each input in place, keeping order."""
     first, second = shapes = ((3, 2), (4, 7))
     result = base.check_shape_agreement(cb.Parallel(cb.Dup(), cb.Dup()),
                                         shapes)
     self.assertEqual(result, (first, first, second, second))
Example #7
0
 def test_serial_dup_dup(self):
     """Two Dups in series turn a single input into three copies."""
     shape = (3, 2)
     result = base.check_shape_agreement(cb.Serial(cb.Dup(), cb.Dup()), shape)
     self.assertEqual(result, (shape,) * 3)
Example #8
0
 def test_serial_div_div(self):
     """Chaining two Div layers does not change the input shape."""
     chain = cb.Serial(core.Div(divisor=2.0), core.Div(divisor=5.0))
     shape = (3, 2)
     self.assertEqual(base.check_shape_agreement(chain, shape), (3, 2))
Example #9
0
 def test_serial_no_op_list(self):
     """Serial over an empty list acts as identity on the input shapes."""
     shapes = ((3, 2), (4, 7))
     result = base.check_shape_agreement(cb.Serial([]), shapes)
     self.assertEqual(result, ((3, 2), (4, 7)))
Example #10
0
 def test_swap(self):
     """Swap exchanges the positions of its two inputs."""
     first, second = shapes = ((3, 2), (4, 7))
     result = base.check_shape_agreement(cb.Swap(), shapes)
     self.assertEqual(result, (second, first))
Example #11
0
 def test_drop(self):
     """Drop consumes its input and produces no outputs (empty shape tuple)."""
     result = base.check_shape_agreement(cb.Drop(), (3, 2))
     self.assertEqual(result, ())
Example #12
0
 def _test_cell_runs(self, layer, input_shape, output_shape):
     """Assert that `layer` maps `input_shape` to exactly `output_shape`."""
     self.assertEqual(output_shape,
                      base.check_shape_agreement(layer, input_shape))
Example #13
0
 def test_div_shapes(self):
     """A single Div layer preserves the shape of its input."""
     divide = core.Div(divisor=2.0)
     result = base.check_shape_agreement(divide, (3, 2))
     self.assertEqual(result, (3, 2))
Example #14
0
 def test_causal_conv(self):
   """CausalConv keeps the leading axes and sets the last axis to `filters`."""
   layer = convolution.CausalConv(filters=30, kernel_width=3)
   shape = base.check_shape_agreement(layer, (29, 5, 20))
   self.assertEqual(shape, (29, 5, 30))
Example #15
0
 def test_conv_rebatch(self):
   """Conv on a 5-D input keeps both leading axes (as asserted below).

   The 5x5 spatial extent with a 3x3 kernel comes out as 3x3, and the last
   axis becomes the 30 requested filters.
   """
   layer = convolution.Conv(30, (3, 3))
   shape = base.check_shape_agreement(layer, (3, 29, 5, 5, 20))
   self.assertEqual(shape, (3, 29, 3, 3, 30))
Example #16
0
 def test_layer_norm_shape(self):
     """LayerNorm returns an output with the same shape as its input."""
     shape = (29, 5, 7, 20)
     norm = normalization.LayerNorm()
     self.assertEqual(base.check_shape_agreement(norm, shape), shape)