Beispiel #1
0
    def test_output_signature(self):
        """Check `output_signature` on a 2-in/1-out and a 1-in/3-out `Fn`."""
        # Scenario 1: a layer that adds its two inputs; both inputs share one
        # shape, and a single signature of that same shape is expected back.
        pair_sig = (ShapeDtype((2, 3, 5)), ShapeDtype((2, 3, 5)))
        adder = Fn('2in1out', lambda x, y: x + y)
        self.assertEqual(
            adder.output_signature(pair_sig),
            ShapeDtype((2, 3, 5)),
        )

        # Scenario 2: a layer returning three values derived from one input;
        # a 3-tuple of signatures matching the input shape is expected back.
        single_sig = ShapeDtype((5, 7))
        fan_out = Fn('1in3out', lambda x: (x, 2 * x, 3 * x), n_out=3)
        fan_out_sig = fan_out.output_signature(single_sig)
        self.assertEqual(fan_out_sig, (ShapeDtype((5, 7)), ) * 3)
        # Negative checks: a different shape, and a different number of
        # outputs, must both compare unequal to the computed signature.
        self.assertNotEqual(fan_out_sig, (ShapeDtype((4, 7)), ) * 3)
        self.assertNotEqual(fan_out_sig, (ShapeDtype((5, 7)), ) * 2)