Example #1
0
 def test_unspecified_weight_dim_fails(self):
     """Building fails when a weight-spec letter appears nowhere else.

     In "ab,zd->ad" the letter 'z' occurs only in the weight spec, so the
     test expects the layer call to raise a ValueError naming 'z'.
     """
     expected_error = ".*Weight dimension 'z' did not have a match "
     symbolic_input = keras.Input(shape=(32,))
     dense = einsum_dense.EinsumDense(equation="ab,zd->ad", output_shape=64)
     with self.assertRaisesRegex(ValueError, expected_error):
         # Calling the layer triggers build(), where the equation is validated.
         dense(symbolic_input)
Example #2
0
    def test_layer_creation(
        self,
        equation,
        bias_axes,
        input_shape,
        output_shape,
        expected_weight_shape,
        expected_bias_shape,
        expected_output_shape,
    ):
        """Builds an EinsumDense layer and checks its kernel, bias and output shapes.

        Args:
          equation: Einsum equation forwarded to the layer.
          bias_axes: `bias_axes` argument forwarded to the layer.
          input_shape: Input shape whose first entry is dropped before building
            the symbolic input.
          output_shape: `output_shape` argument forwarded to the layer.
          expected_weight_shape: Expected shape of `layer.kernel`.
          expected_bias_shape: Expected shape of `layer.bias`, or None when the
            layer is expected to have no bias.
          expected_output_shape: Expected shape of the layer's symbolic output.
        """
        # keras.Input takes the shape without the leading batch entry.
        per_example_shape = list(input_shape)[1:]

        symbolic_input = keras.Input(shape=per_example_shape)
        dense = einsum_dense.EinsumDense(
            equation=equation,
            output_shape=output_shape,
            bias_axes=bias_axes,
        )
        # Calling the layer builds it, so kernel/bias exist only after this.
        symbolic_output = dense(symbolic_input)

        kernel_dims = dense.kernel.shape.as_list()
        self.assertAllEqual(expected_weight_shape, kernel_dims)

        if expected_bias_shape is not None:
            bias_dims = dense.bias.shape.as_list()
            self.assertAllEqual(expected_bias_shape, bias_dims)
        else:
            self.assertIsNone(dense.bias)

        output_dims = symbolic_output.shape.as_list()
        self.assertAllEqual(expected_output_shape, output_dims)
Example #3
0
 def test_unspecified_output_dim_fails(self):
     """Building fails when an output letter has no source dimension.

     In "ab,bc->cd" the letter 'd' appears only in the output spec, so the
     test expects a ValueError that mentions 'd' and the output spec 'cd'.
     """
     expected_error = (".*Dimension 'd' was specified in the output 'cd' but has "
                       "no corresponding dim.*")
     symbolic_input = keras.Input(shape=(32, ))
     dense = einsum_dense.EinsumDense(equation="ab,bc->cd", output_shape=64)
     with self.assertRaisesRegex(ValueError, expected_error):
         # The equation is validated when the layer is built on first call.
         dense(symbolic_input)
Example #4
0
 def test_unspecified_bias_dim_fails(self):
     """Building fails when bias_axes names a letter missing from the output.

     The output spec of "ab,bc->ac" is "ac", but bias_axes asks for "y", so
     the test expects a ValueError about the output spec.
     """
     expected_error = ".*is not part of the output spec.*"
     symbolic_input = keras.Input(shape=(32, ))
     dense = einsum_dense.EinsumDense(
         equation="ab,bc->ac",
         output_shape=64,
         bias_axes="y",
     )
     with self.assertRaisesRegex(ValueError, expected_error):
         # bias_axes is checked against the output spec during build.
         dense(symbolic_input)
Example #5
0
 def test_incompatible_input_output_shape_fails(self):
     """Building fails when input and output disagree on a shared dimension.

     With "abc,cd->abd" the letter 'b' is shared by input and output. The
     input supplies 32 for it, while output_shape=(10, 96) presumably maps
     (b, d) -> (10, 96). The test expects a ValueError naming 'b'.
     """
     expected_error = (".*Input shape and output shape do not match at shared "
                       "dimension 'b'.*")
     symbolic_input = keras.Input(shape=(32, 64))
     dense = einsum_dense.EinsumDense(
         equation="abc,cd->abd",
         output_shape=(10, 96),
     )
     with self.assertRaisesRegex(ValueError, expected_error):
         # Shape compatibility is checked when the layer is built on first call.
         dense(symbolic_input)