def test_transpose_undefined_input_shape(self):
    # A Transpose layer that is not joined to any input is expected to
    # report ``None`` as its input shape and one ``None`` entry per axis
    # of the permutation as its output shape.
    for permutation in [(1, 0, 2), (1, 0)]:
        transpose = layers.Transpose(permutation)
        all_unknown = (None,) * len(permutation)

        self.assertShapesEqual(transpose.input_shape, None)
        self.assertShapesEqual(transpose.output_shape, all_unknown)
def test_transpose_repr(self):
    # Layer created without an explicit name: the expected string carries
    # the name 'transpose-1'.
    unnamed = layers.Transpose((0, 2, 1))
    expected_unnamed = "Transpose((0, 2, 1), name='transpose-1')"
    self.assertEqual(expected_unnamed, str(unnamed))

    # Layer created with name='test': that name must show up verbatim.
    named = layers.Transpose((0, 2, 1), name='test')
    expected_named = "Transpose((0, 2, 1), name='test')"
    self.assertEqual(expected_named, str(named))
def test_transpose_exceptions(self):
    # Layers are created inside each ``with`` block on purpose: the error
    # may be raised by a constructor or by ``join``, and either must count.
    with self.assertRaisesRegexp(ValueError, "cannot be used"):
        source = layers.Input((7, 11))
        # cannot use 0 index (batch dim)
        invalid_transpose = layers.Transpose([2, 0])
        layers.join(source, invalid_transpose)

    # Second case: a one-dimensional input of size 20 combined with a
    # two-axis permutation.
    with self.assertRaisesRegexp(LayerConnectionError, "at least 3"):
        flat_source = layers.Input(20)
        two_axis_transpose = layers.Transpose([2, 1])
        layers.join(flat_source, two_axis_transpose)
def test_transpose_exceptions(self):
    # Joining Input(20) with a three-axis permutation must fail with a
    # connection error mentioning the transpose operation.
    expected_pattern = "Cannot apply transpose operation to the input"

    # Layers are built inside the ``with`` block so that an error raised
    # at construction time is captured as well.
    with self.assertRaisesRegexp(LayerConnectionError, expected_pattern):
        flat_source = layers.Input(20)
        transpose = layers.Transpose((0, 2, 1))
        layers.join(flat_source, transpose)
def test_transpose_unknown_input_dim(self):
    # The input declares ``None`` for one of its dimensions; the declared
    # output shape must keep that ``None`` in its transposed position.
    source = layers.Input((None, 10, 20))
    swap_axes = layers.Transpose((0, 2, 1, 3))
    network = layers.join(source, swap_axes)

    self.assertShapesEqual(network.output_shape, (None, 10, None, 20))

    # Feed two arrays that differ only in the size of the undeclared axis
    # and check that axes 1 and 2 come out swapped each time.
    for unknown_size in (100, 33):
        sample = asfloat(np.random.random((12, unknown_size, 10, 20)))
        result = self.eval(network.output(sample))
        self.assertEqual(result.shape, (12, 10, unknown_size, 20))
def test_transpose_unknown_input_dim(self):
    # The input declares ``None`` for its first listed dimension; the
    # declared output shape must carry that ``None`` to its new position.
    source = layers.Input((None, 10, 20))
    swap_axes = layers.Transpose([2, 1, 3])
    network = layers.join(source, swap_axes)

    self.assertEqual(network.output_shape, (10, None, 20))

    # Each pair is (shape fed in, shape expected out); only the size of
    # the undeclared axis differs between the two cases.
    shape_pairs = [
        ((12, 100, 10, 20), (12, 10, 100, 20)),
        ((12, 33, 10, 20), (12, 10, 33, 20)),
    ]
    for fed_shape, expected_shape in shape_pairs:
        sample = asfloat(np.random.random(fed_shape))
        result = self.eval(network.output(sample))
        self.assertEqual(result.shape, expected_shape)
def test_simple_transpose(self):
    # Input declared as (7, 11); the permutation (0, 2, 1) is expected to
    # produce an output shape of (None, 11, 7).
    source = layers.Input((7, 11))
    swap_axes = layers.Transpose((0, 2, 1))
    network = layers.join(source, swap_axes)

    expected_shape = (None, 11, 7)
    self.assertShapesEqual(network.output_shape, expected_shape)
def test_simple_transpose(self):
    # Input declared as (7, 11); the permutation [2, 1] is expected to
    # produce an output shape of (11, 7).
    source = layers.Input((7, 11))
    swap_axes = layers.Transpose([2, 1])
    network = layers.join(source, swap_axes)

    expected_shape = (11, 7)
    self.assertEqual(network.output_shape, expected_shape)