Example #1
0
 def test_dropout(self):
     """Dropout preserves the (8, 7, 9) input shape in train and eval modes."""
     signature = ShapeDtype((8, 7, 9))
     for mode in ('train', 'eval'):
         dropout = core.Dropout(rate=0.1, mode=mode)
         shape = base.check_shape_agreement(dropout, signature)
         self.assertEqual(shape, (8, 7, 9))
Example #2
0
    def test_frn_shape(self):
        """FilterResponseNorm keeps its 4-D input shape, with or without learn_epsilon."""
        B, H, W, C = 64, 5, 7, 3  # pylint: disable=invalid-name
        signature = ShapeDtype((B, H, W, C))
        # First the default constructor, then with learn_epsilon disabled.
        for kwargs in ({}, {'learn_epsilon': False}):
            norm = normalization.FilterResponseNorm(**kwargs)
            shape = base.check_shape_agreement(norm, signature)
            self.assertEqual(shape, signature.shape)
Example #3
0
    def test_flatten_n(self):
        """Flatten keeps the requested leading axes and merges the remainder.

        Asking to keep 5 or 6 axes of this rank-5 input must raise LayerError.
        """
        signature = ShapeDtype((29, 87, 10, 20, 30))
        cases = (
            ({}, (29, 87 * 10 * 20 * 30)),
            ({'n_axes_to_keep': 2}, (29, 87, 10 * 20 * 30)),
            ({'n_axes_to_keep': 3}, (29, 87, 10, 20 * 30)),
            ({'n_axes_to_keep': 4}, (29, 87, 10, 20, 30)),
        )
        for kwargs, expected in cases:
            flatten = core.Flatten(**kwargs)
            shape = base.check_shape_agreement(flatten, signature)
            self.assertEqual(shape, expected)

        # Not enough dimensions to keep this many axes.
        for n_keep in (5, 6):
            with self.assertRaises(base.LayerError):
                base.check_shape_agreement(core.Flatten(n_axes_to_keep=n_keep),
                                           signature)
Example #4
0
 def test_ngpu(self):
   """NeuralGPU on integer inputs appends a vocab_size axis to the input shape."""
   vocab_size = 2
   shape = (3, 5, 7)
   gpu_model = neural_gpu.NeuralGPU(d_feature=30, steps=4, vocab_size=vocab_size)
   out_shape = base.check_shape_agreement(gpu_model, shape,
                                          integer_inputs=True)
   self.assertEqual(shape + (vocab_size,), out_shape)
Example #5
0
 def test_time_bin_causal_attention_bin_length(self):
   """TimeBinCausalAttention(bin_length=16) outputs the shared q/k/v shape."""
   qkv = (3, 57, 8)
   attn = attention.TimeBinCausalAttention(
       bin_length=16, dropout=0.1, mode='train')
   out_shape = base.check_shape_agreement(attn, (qkv,) * 3)
   self.assertEqual((3, 57, 8), out_shape)
Example #6
0
 def test_select_second_of_3(self):
     """Select([1], n_in=3) outputs only the shape of the second input."""
     select = cb.Select([1], n_in=3)
     signatures = (ShapeDtype((3, 2)),
                   ShapeDtype((4, 7)),
                   ShapeDtype((11, 13)))
     shape = base.check_shape_agreement(select, signatures)
     self.assertEqual(shape, (4, 7))
 def test_merged_hashed_causal_attention(self):
   """MemoryEfficientCausalAttention outputs the shared (3, 32, 8) q/k/v shape."""
   # NOTE(review): the test name says "merged hashed", but the layer under test
   # is MemoryEfficientCausalAttention -- confirm the name is intended.
   qkv = ShapeDtype((3, 32, 8))
   attn = attention.MemoryEfficientCausalAttention(
       loop_stride=16, dropout=0.1, mode='train')
   out_shape = base.check_shape_agreement(attn, (qkv, qkv, qkv))
   self.assertEqual((3, 32, 8), out_shape)
Example #8
0
 def test_lsh_causal_attention_fast_inference(self):
     """LSHCausalAttention in predict mode preserves a (3, 1, 8) q/k/v shape."""
     qkv = ShapeDtype((3, 1, 8))
     lsh = efficient_attention.LSHCausalAttention(
         n_bins=4, dropout=0.0, max_len_for_inference=128, mode='predict')
     out_shape = base.check_shape_agreement(lsh, (qkv,) * 3)
     self.assertEqual((3, 1, 8), out_shape)
 def test_time_bin_causal_attention_n_bins(self):
   """TimeBinCausalAttention(n_bins=4) outputs the shared (3, 57, 8) q/k/v shape."""
   qkv = ShapeDtype((3, 57, 8))
   attn = attention.TimeBinCausalAttention(
       n_bins=4, dropout=0.1, mode='train')
   out_shape = base.check_shape_agreement(attn, (qkv,) * 3)
   self.assertEqual((3, 57, 8), out_shape)
Example #10
0
    def test_layer_decorator_and_shape_agreement(self):
        """A function wrapped by @base.layer() builds a shape-preserving layer."""
        @base.layer()
        def add_one(x, **unused_kwargs):
            return x + 1

        layer = add_one()  # pylint: disable=no-value-for-parameter
        shape = base.check_shape_agreement(layer, ShapeDtype((12, 17)))
        self.assertEqual(shape, (12, 17))
Example #11
0
 def test_ngpu(self):
     """NeuralGPU on int32 inputs appends a vocab_size axis to the input shape."""
     vocab_size = 2
     signature = ShapeDtype((3, 5, 7), np.int32)
     gpu_model = neural_gpu.NeuralGPU(
         d_feature=30, steps=4, vocab_size=vocab_size)
     out_shape = base.check_shape_agreement(gpu_model, signature)
     self.assertEqual((3, 5, 7, vocab_size), out_shape)
Example #12
0
 def test_fn_layer_varargs_n_in(self):
     """Fn rejects a *args function unless n_in is passed explicitly."""
     with self.assertRaisesRegex(ValueError, 'variable arg'):
         base.Fn(lambda *args: args[0])
     # With n_in=1 the varargs lambda is accepted and returns its only input.
     identity = base.Fn(lambda *args: args[0], n_in=1)
     shape = base.check_shape_agreement(identity, ShapeDtype((2, 7)))
     self.assertEqual(shape, (2, 7))
Example #13
0
 def test_serial_with_side_outputs_div_div(self):
   """Two Parallel(divide_by, divide_by) stages keep all three input shapes."""
   stages = [cb.Parallel(divide_by(2.0), divide_by(5.0)) for _ in range(2)]
   layer = cb.SerialWithSideOutputs(stages)
   signatures = (ShapeDtype((3, 2)), ShapeDtype((4, 2)), ShapeDtype((5, 2)))
   shapes = base.check_shape_agreement(layer, signatures)
   self.assertEqual(shapes, ((3, 2), (4, 2), (5, 2)))
Example #14
0
 def test_fn_layer_example(self):
     """Fn without n_out yields two outputs: a sum and an axis-0 concatenation."""
     layer = base.Fn(lambda x, y: (x + y, np.concatenate([x, y], axis=0)))
     signatures = (ShapeDtype((2, 7)), ShapeDtype((2, 7)))
     shapes = base.check_shape_agreement(layer, signatures)
     self.assertEqual(shapes, ((2, 7), (4, 7)))
     # Run the layer on concrete values, not just shapes.
     total, joined = layer((np.array([2]), np.array([3])))
     self.assertEqual(int(total), 5)
     self.assertEqual([int(v) for v in joined], [2, 3])
Example #15
0
 def test_scan_axis1(self):
   """Scan along axis=1 with an add layer preserves both input shapes."""
   @base.layer(n_in=2, n_out=2)
   def add(x, **unused_kwargs):
     total = x[0] + x[1]
     return total, total  # per-step output and new carry are identical
   scan = cb.Scan(add(), axis=1)  # pylint: disable=no-value-for-parameter
   signatures = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 7)))
   shapes = base.check_shape_agreement(scan, signatures)
   self.assertEqual(shapes, ((3, 2, 7), (3, 7)))
Example #16
0
 def test_fn_layer_difficult_n_out(self):
     """Fn must be given n_out when dummy-input probing cannot determine it.

     The lambda concatenates along axis 4, which is hard to evaluate with
     dummy inputs, so building the layer without n_out must raise a ValueError
     mentioning 'n_out'. With n_out=1, the layer builds and doubles the size
     of axis 4.
     """
     with self.assertRaisesRegex(ValueError, 'n_out'):
         # Determining the output of this layer is hard with dummies.
         # `np.concatenate` must be spelled correctly here. The earlier
         # misspelling (`np.concatencate`) made the lambda fail for any input,
         # so the assertion never depended on the axis-4 concatenation.
         base.Fn(lambda x: np.concatenate([x, x], axis=4))
     # Check that this layer works when n_out is set.
     layer = base.Fn(lambda x: np.concatenate([x, x], axis=4), n_out=1)
     input_signature = ShapeDtype((2, 1, 2, 2, 3))
     expected_shape = (2, 1, 2, 2, 6)
     output_shape = base.check_shape_agreement(layer, input_signature)
     self.assertEqual(output_shape, expected_shape)
Example #17
0
 def test_scan_multiinput(self):
   """Scan along axis=1 over a three-input layer reports three output shapes."""
   @base.layer(n_in=3, n_out=2)
   def foo(x, **unused_kwargs):
     a, b, carry = x
     return a + b, b, carry + 1
   # NOTE(review): foo is declared with n_out=2 but returns three values, and
   # three output shapes are expected below -- confirm n_out is intended.
   scan = cb.Scan(foo(), axis=1)  # pylint: disable=no-value-for-parameter
   signatures = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 2, 7)),
                 ShapeDtype((3, 7)))
   shapes = base.check_shape_agreement(scan, signatures)
   self.assertEqual(shapes, ((3, 2, 7), (3, 2, 7), (3, 7)))
Example #18
0
  def test_scan_nocarry(self):
    """Scan with n_carry=0 applies an add-one Fn to each scanned element."""
    scan_layer = cb.Scan(base.Fn('addone', lambda x: x + 1), n_carry=0)
    shape = (3, 2, 7)
    out_shape = base.check_shape_agreement(scan_layer, ShapeDtype(shape))
    self.assertEqual(out_shape, shape)
    outputs = scan_layer(np.array([1, 2, 3]))
    self.assertEqual([int(v) for v in outputs], [2, 3, 4])
Example #19
0
  def test_scan_multiinput(self):
    """Scan along axis=1 over a three-argument Fn reports three output shapes."""
    def f(a, b, carry):
      return a + b, b, carry + 1

    # NOTE(review): f returns three values while the Fn is built with n_out=2,
    # yet three output shapes are expected -- confirm n_out is intended.
    scan = cb.Scan(base.Fn('foo', f, n_out=2), axis=1)
    signatures = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 2, 7)),
                  ShapeDtype((3, 7)))
    shapes = base.check_shape_agreement(scan, signatures)
    self.assertEqual(shapes, ((3, 2, 7), (3, 2, 7), (3, 7)))
Example #20
0
  def test_scan_axis1(self):
    """Scan along axis=1 with an add Fn preserves both input shapes."""
    def f(x, carry):
      total = x + carry
      return total, total  # output and carry are the same

    scan = cb.Scan(base.Fn('add', f, n_out=2), axis=1)
    signatures = (ShapeDtype((3, 2, 7)), ShapeDtype((3, 7)))
    shapes = base.check_shape_agreement(scan, signatures)
    self.assertEqual(shapes, ((3, 2, 7), (3, 7)))
Example #21
0
    def test_scan_nocarry(self):
        """Scan with n_carry=0 applies an add-one layer to each scanned element."""
        @base.layer(n_in=1, n_out=1)
        def addone(x, **unused_kwargs):
            return x + 1

        scan_layer = cb.Scan(addone(), n_carry=0)  # pylint: disable=no-value-for-parameter
        shape = (3, 2, 7)
        out_shape = base.check_shape_agreement(scan_layer, ShapeDtype(shape))
        self.assertEqual(out_shape, shape)
        outputs = scan_layer(np.array([1, 2, 3]))
        self.assertEqual([int(v) for v in outputs], [2, 3, 4])
Example #22
0
 def test_scan_basic(self):
   """Scan with an add layer yields running sums and the final carry."""
   @base.layer(n_in=2, n_out=2)
   def add(x, **unused_kwargs):
     total = x[0] + x[1]
     return total, total
   scan_layer = cb.Scan(add())  # pylint: disable=no-value-for-parameter
   signatures = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
   shapes = base.check_shape_agreement(scan_layer, signatures)
   self.assertEqual(shapes, ((3, 2, 7), (2, 7)))
   # Running sums of [1, 2, 3] starting from a carry of 0.
   outputs, carry = scan_layer((np.array([1, 2, 3]), np.array(0)))
   self.assertEqual(int(carry), 6)
   self.assertEqual([int(v) for v in outputs], [1, 3, 6])
Example #23
0
  def test_scan_basic(self):
    """Scan with an add Fn yields running sums and the final carry."""
    def f(x, carry):
      total = x + carry
      return total, total  # output and carry are the same

    scan_layer = cb.Scan(base.Fn('add', f, n_out=2))
    signatures = (ShapeDtype((3, 2, 7)), ShapeDtype((2, 7)))
    shapes = base.check_shape_agreement(scan_layer, signatures)
    self.assertEqual(shapes, ((3, 2, 7), (2, 7)))
    # Running sums of [1, 2, 3] starting from a carry of 0.
    outputs, carry = scan_layer((np.array([1, 2, 3]), np.array(0)))
    self.assertEqual(int(carry), 6)
    self.assertEqual([int(v) for v in outputs], [1, 3, 6])
 def test_self_attention(self):
     """Causal SelfAttention (reference code, jax backend) keeps shape (3, 32, 8)."""
     with math.use_backend('jax'):
         signature = ShapeDtype((3, 32, 8))
         attn = efficient_attention.SelfAttention(
             n_heads=5, d_qk=7, d_v=17, share_qk=False, causal=True,
             chunk_len=8, n_chunks_before=1, n_chunks_after=0,
             use_reference_code=True, attention_dropout=0.0, mode='train')
         out_shape = base.check_shape_agreement(attn, signature)
         self.assertEqual((3, 32, 8), out_shape)
Example #25
0
 def test_branch_add_div(self):
     """Branch(Add, divide_by) on two (3, 2) inputs yields two (3, 2) outputs."""
     branch = cb.Branch(cb.Add(), divide_by(0.5))
     signatures = (ShapeDtype((3, 2)), ShapeDtype((3, 2)))
     shapes = base.check_shape_agreement(branch, signatures)
     self.assertEqual(shapes, ((3, 2), (3, 2)))
Example #26
0
 def test_div_shapes(self):
     """Div(divisor=2.0) leaves a (3, 2) input shape unchanged."""
     shape = base.check_shape_agreement(core.Div(divisor=2.0),
                                        ShapeDtype((3, 2)))
     self.assertEqual(shape, (3, 2))
Example #27
0
 def test_branch_noop_dup(self):
     """Branch([], Dup()) on a (3, 2) input yields three (3, 2) outputs."""
     branch = cb.Branch([], cb.Dup())
     shapes = base.check_shape_agreement(branch, ShapeDtype((3, 2)))
     self.assertEqual(shapes, ((3, 2),) * 3)
Example #28
0
 def test_branch_one_layer(self):
     """A Branch holding a single layer outputs a single (3, 2) shape."""
     branch = cb.Branch(divide_by(0.5))
     shape = base.check_shape_agreement(branch, ShapeDtype((3, 2)))
     self.assertEqual(shape, (3, 2))
Example #29
0
 def test_conv_rebatch(self):
   """Conv(30, (3, 3)) maps a (3, 29, 5, 5, 20) input to (3, 29, 3, 3, 30)."""
   conv = convolution.Conv(30, (3, 3))
   shape = base.check_shape_agreement(conv, ShapeDtype((3, 29, 5, 5, 20)))
   self.assertEqual(shape, (3, 29, 3, 3, 30))
Example #30
0
 def test_causal_conv(self):
   """CausalConv(filters=30, kernel_width=3) maps (29, 5, 20) to (29, 5, 30)."""
   signature = ShapeDtype((29, 5, 20))
   shape = base.check_shape_agreement(
       convolution.CausalConv(filters=30, kernel_width=3), signature)
   self.assertEqual(shape, (29, 5, 30))