def test_dropout(self):
  """Dropout must be shape-preserving in both train and eval modes."""
  shape = (8, 7, 9)
  for mode in ("train", "eval"):
    layer = core.Dropout(rate=0.1, mode=mode)
    self.assertEqual(base.check_shape_agreement(layer, shape), shape)
def test_serial_no_op_list(self):
  """Serial([]) acts as the identity on the input stack."""
  layer = cb.Serial([])
  for extra in ((), _REST_OF_STACK):
    stack = ((3, 2), (4, 7)) + extra
    self.assertEqual(base.check_shape_agreement(layer, stack), stack)
def test_swap(self):
  """Swap exchanges the top two stack entries and leaves the rest alone."""
  layer = cb.Swap()
  for extra in ((), _REST_OF_STACK):
    got = base.check_shape_agreement(layer, ((3, 2), (4, 7)) + extra)
    self.assertEqual(got, ((4, 7), (3, 2)) + extra)
def test_drop(self):
  """Drop removes the top stack entry."""
  layer = cb.Drop()
  # Single-entry stack drops to the empty stack.
  got = base.check_shape_agreement(layer, ((3, 2),))
  self.assertEqual(got, _EMPTY_STACK)
  # With more entries below, only the top is removed.
  got = base.check_shape_agreement(layer, ((3, 2),) + _REST_OF_STACK)
  self.assertEqual(got, _REST_OF_STACK)
def test_flatten_n(self):
  """Flatten collapses all axes past `num_axis_to_keep` into one."""
  # NOTE(review): a second test_flatten_n (using the `n_axes_to_keep`
  # keyword) appears later in this source; if both live in the same test
  # class the later definition shadows this one — confirm they belong to
  # different classes/files.
  in_shape = (29, 87, 10, 20, 30)
  cases = [
      (core.Flatten(), (29, 87 * 10 * 20 * 30)),
      (core.Flatten(num_axis_to_keep=2), (29, 87, 10 * 20 * 30)),
      (core.Flatten(num_axis_to_keep=3), (29, 87, 10, 20 * 30)),
      (core.Flatten(num_axis_to_keep=4), (29, 87, 10, 20, 30)),
  ]
  for layer, want in cases:
    self.assertEqual(base.check_shape_agreement(layer, in_shape), want)
  # Not enough dimensions.
  for n_keep in (5, 6):
    with self.assertRaises(base.LayerError):
      base.check_shape_agreement(
          core.Flatten(num_axis_to_keep=n_keep), in_shape)
def test_flatten_n(self):
  """Flatten collapses all axes past `n_axes_to_keep` into one."""
  in_shape = (29, 87, 10, 20, 30)
  cases = [
      (core.Flatten(), (29, 87 * 10 * 20 * 30)),
      (core.Flatten(n_axes_to_keep=2), (29, 87, 10 * 20 * 30)),
      (core.Flatten(n_axes_to_keep=3), (29, 87, 10, 20 * 30)),
      (core.Flatten(n_axes_to_keep=4), (29, 87, 10, 20, 30)),
  ]
  for layer, want in cases:
    self.assertEqual(base.check_shape_agreement(layer, in_shape), want)
  # Not enough dimensions.
  for n_keep in (5, 6):
    with self.assertRaises(base.LayerError):
      base.check_shape_agreement(
          core.Flatten(n_axes_to_keep=n_keep), in_shape)
def test_merged_hashed_causal_attention(self):
  """MemoryEfficientCausalAttention keeps the (batch, time, d) shape."""
  # NOTE(review): the test name mentions "hashed" but the layer under test
  # is MemoryEfficientCausalAttention — presumably a rename artifact;
  # confirm against the attention module.
  qkv = (3, 32, 8)
  layer = attention.MemoryEfficientCausalAttention(
      loop_stride=16, dropout=0.1, mode='train')
  got = base.check_shape_agreement(layer, (qkv, qkv, qkv))
  self.assertEqual((3, 32, 8), got)
def test_layer_decorator_and_shape_agreement(self):
  """A @base.layer()-decorated function works as a shape-preserving layer."""
  @base.layer()
  def add_one(x, **unused_kwargs):
    return x + 1

  shape = (12, 17)
  got = base.check_shape_agreement(add_one(), shape)  # pylint: disable=no-value-for-parameter
  self.assertEqual(got, shape)
def test_time_bin_causal_attention_n_bins(self):
  """TimeBinCausalAttention with an explicit n_bins preserves qkv shape."""
  qkv = (3, 57, 8)
  layer = attention.TimeBinCausalAttention(
      n_bins=4, dropout=0.1, mode='train')
  got = base.check_shape_agreement(layer, (qkv, qkv, qkv))
  self.assertEqual((3, 57, 8), got)
def test_branch_named(self):
  """Branch with keyword sublayers yields a dict of per-branch shapes."""
  layer = combinators.Branch(a=combinators.NoOp(), b=combinators.NoOp())
  got = base.check_shape_agreement(layer, (2, 3))
  self.assertEqual(got, {'a': (2, 3), 'b': (2, 3)})
def test_parallel(self):
  """Parallel over two NoOp branches preserves the pair of shapes."""
  layer = combinators.Parallel(combinators.NoOp(), combinators.NoOp())
  shapes = ((2, 3), (2, 3))
  self.assertEqual(base.check_shape_agreement(layer, shapes), shapes)
def test_branch(self):
  """Branch with two Copy sublayers duplicates the input shape."""
  layer = combinators.Branch(combinators.Copy(), combinators.Copy())
  got = base.check_shape_agreement(layer, (2, 3))
  self.assertEqual(got, ((2, 3), (2, 3)))
def test_rebatch(self):
  """Rebatch lets Conv consume inputs with extra leading batch dims."""
  # Baseline: a plain Conv on a standard 4-d input.
  got = base.check_shape_agreement(
      convolution.Conv(30, (3, 3)), (29, 5, 5, 20))
  self.assertEqual(got, (29, 3, 3, 30))
  # Rebatch with a single batch dim matches the baseline.
  got = base.check_shape_agreement(
      combinators.Rebatch(convolution.Conv(30, (3, 3)), n_batch_dims=1),
      (29, 5, 5, 20))
  self.assertEqual(got, (29, 3, 3, 30))
  # Two batch dims: both leading axes survive unchanged.
  got = base.check_shape_agreement(
      combinators.Rebatch(convolution.Conv(30, (3, 3)), n_batch_dims=2),
      (19, 29, 5, 5, 20))
  self.assertEqual(got, (19, 29, 3, 3, 30))
def test_ngpu(self):
  """NeuralGPU maps integer inputs of shape s to outputs of shape s + [vocab]."""
  vocab_size = 2
  dims = [3, 5, 7]
  model = neural_gpu.NeuralGPU(
      d_feature=30, steps=4, vocab_size=vocab_size)
  got = base.check_shape_agreement(
      model, tuple(dims), integer_inputs=True)
  self.assertEqual(tuple(dims + [vocab_size]), got)
def test_weighted_mean_shape(self):
  """WeightedMean reduces a (values, weights) pair to a scalar shape."""
  got = base.check_shape_agreement(
      metrics.WeightedMean(), ((29, 4, 4, 20), (29, 4, 4, 20)))
  self.assertEqual(got, ())
def test_select(self):
  """Select(1) keeps only the second stack entry."""
  got = base.check_shape_agreement(
      combinators.Select(1), ((2, 3), (3, 4)))
  self.assertEqual(got, (3, 4))
def test_select_named(self):
  """Select('b') picks the 'b' entry out of a dict input."""
  got = base.check_shape_agreement(
      combinators.Select('b'), {'a': (2, 3), 'b': (3, 4)})
  self.assertEqual(got, (3, 4))
def test_parallel_div_div(self):
  """Parallel over two Div layers preserves each branch's shape."""
  layer = cb.Parallel(core.Div(divisor=0.5), core.Div(divisor=3.0))
  shapes = ((3, 2), (4, 7))
  self.assertEqual(base.check_shape_agreement(layer, shapes), shapes)
def test_parallel_named(self):
  """Parallel with one keyword sublayer passes untouched dict entries through."""
  layer = combinators.Parallel(a=combinators.NoOp())
  got = base.check_shape_agreement(layer, {'a': (2, 3), 'b': (2, 3)})
  self.assertEqual(got, {'a': (2, 3), 'b': (2, 3)})
def test_div_shapes(self):
  """Div is elementwise, so its output shape equals its input shape."""
  shape = (3, 2)
  got = base.check_shape_agreement(core.Div(divisor=2.0), shape)
  self.assertEqual(got, shape)
def test_parallel_no_ops(self):
  """Parallel treats [] and None arguments as pass-through branches."""
  layer = cb.Parallel([], None)
  shapes = ((3, 2), (4, 7))
  self.assertEqual(base.check_shape_agreement(layer, shapes), shapes)
def test_reversible_swap(self):
  """ReversibleSwap reverses the order of its two inputs."""
  shapes = ((2, 3), (3, 3))
  got = base.check_shape_agreement(reversible.ReversibleSwap(), shapes)
  self.assertEqual(got, shapes[::-1])
def test_batch_norm_shape(self):
  """BatchNorm is shape-preserving."""
  shape = (29, 5, 7, 20)
  got = base.check_shape_agreement(normalization.BatchNorm(), shape)
  self.assertEqual(got, shape)
def test_conv_rebatch(self):
  """Conv on a 5-d input keeps the extra leading axis intact."""
  got = base.check_shape_agreement(
      convolution.Conv(30, (3, 3)), (3, 29, 5, 5, 20))
  self.assertEqual(got, (3, 29, 3, 3, 30))
def test_layer_norm(self):
  """LayerNorm is shape-preserving."""
  shape = (29, 5, 7, 20)
  got = base.check_shape_agreement(normalization.LayerNorm(), shape)
  self.assertEqual(got, shape)
def _test_cell_runs(self, layer, input_shape, output_shape):
  """Shared helper: assert `layer` maps input_shape to output_shape."""
  got = base.check_shape_agreement(layer, input_shape)
  self.assertEqual(output_shape, got)
def test_accuracy_scalar(self):
  """AccuracyScalar reduces (predictions, targets) to a scalar shape."""
  got = base.check_shape_agreement(
      metrics.AccuracyScalar(), ((29, 4, 4, 20), (29, 4, 4)))
  self.assertEqual(got, ())
def test_cross_entropy_loss_scalar(self):
  """CrossEntropyLossScalar reduces (predictions, targets) to a scalar shape."""
  got = base.check_shape_agreement(
      metrics.CrossEntropyLossScalar(), ((29, 4, 4, 20), (29, 4, 4)))
  self.assertEqual(got, ())
def test_causal_conv(self):
  """CausalConv preserves the time axis and maps channels to `filters`."""
  layer = convolution.CausalConv(filters=30, kernel_width=3)
  got = base.check_shape_agreement(layer, (29, 5, 20))
  self.assertEqual(got, (29, 5, 30))
def test_weight_mask(self):
  """WeightMask is shape-preserving."""
  shape = (29, 4, 4, 20)
  got = base.check_shape_agreement(metrics.WeightMask(), shape)
  self.assertEqual(got, shape)