def top(self, body_output, _):
    """Transform the body output from model space to target space.

    The original author describes this as the Xception "Exit flow". As
    written, it optionally reshapes the body output onto a square spatial
    grid, applies `common_layers.conv_block_downsample` followed by a ReLU,
    averages over the two spatial axes, and projects the result to
    `self._vocab_size` channels with a (1, 1) convolution.

    Args:
        body_output: A Tensor; per the original documentation its shape is
            [batch, ?, ?, body_output_size].
        _: Unused.

    Returns:
        The projected Tensor with one extra axis inserted at index 3.
        NOTE(review): because the pooling keeps the reduced axes, this is
        presumably [batch, 1, 1, 1, vocab_size] -- confirm against
        `common_layers.conv`.
    """
    with tf.variable_scope(self.name):
        features = body_output

        if self._is_2d:
            # The input is assumed to be square-able: fold the two spatial
            # axes into a grid whose side is sqrt(height * width).
            height = tf.to_float(tf.shape(features)[1])
            width = tf.to_float(tf.shape(features)[2])
            side = tf.to_int32(tf.sqrt(height * width))
            channels = int(features.get_shape()[3])
            features = tf.reshape(features, [-1, side, side, channels])

        features = common_layers.conv_block_downsample(
            features, self._kernel, self._strides, self._padding)
        activated = tf.nn.relu(features)

        # Average over both spatial axes, keeping them as size-1 dimensions.
        pooled = tf.reduce_mean(activated, axis=[1, 2], keep_dims=True)
        return tf.expand_dims(
            common_layers.conv(pooled, self._vocab_size, (1, 1)), 3)
def testConvBlockDownsample(self):
    """Check the output shape of `common_layers.conv_block_downsample`.

    Feeds a random float32 tensor of shape (5, 7, 1, 11) through the block
    with kernel (3, 1), strides (2, 1) and "SAME" padding, and expects an
    output of shape (5, 4, 1, 27).
    """
    inputs = tf.constant(np.random.rand(5, 7, 1, 11), dtype=tf.float32)
    outputs = common_layers.conv_block_downsample(
        inputs, (3, 1), (2, 1), "SAME")

    # Variables created by the block must be initialized before evaluation.
    self.evaluate(tf.global_variables_initializer())
    actual = self.evaluate(outputs)

    self.assertEqual(actual.shape, (5, 4, 1, 27))
def xception_exit(inputs):
    """Reshape `inputs` onto a square spatial grid, then downsample.

    If either spatial dimension (axis 1 or 2) is statically unknown, the
    square side length is computed in the graph as sqrt(dim1 * dim2). If
    both are known but differ, the side length is computed in Python and
    validated. Inputs that are already statically square are not reshaped.
    The result is passed through `common_layers.conv_block_downsample` with
    kernel (3, 3), strides (2, 2) and "SAME" padding, followed by a ReLU.

    Args:
        inputs: A rank-4 Tensor; axes 1 and 2 are treated as spatial and
            axis 3 as the channel depth.

    Returns:
        The ReLU-activated output of the downsampling block.

    Raises:
        ValueError: If the static spatial dimensions differ and their
            product is not a perfect square.
    """
    with tf.variable_scope("xception_exit"):
        x = inputs
        x_shape = x.get_shape().as_list()
        # Both reshape branches need the channel depth. It used to be
        # assigned only in the dynamic-shape branch, so the static
        # non-square branch failed with an unbound-local error.
        x_depth = x_shape[3]

        if x_shape[1] is None or x_shape[2] is None:
            # Spatial size only known at run time: compute the side in-graph.
            length_float = tf.to_float(tf.shape(x)[1])
            length_float *= tf.to_float(tf.shape(x)[2])
            spatial_dim_float = tf.sqrt(length_float)
            spatial_dim = tf.to_int32(spatial_dim_float)
            x = tf.reshape(x, [-1, spatial_dim, spatial_dim, x_depth])
        elif x_shape[1] != x_shape[2]:
            # Static, non-square input: it must still contain a square
            # number of spatial positions to be foldable into a grid.
            spatial_dim = int(math.sqrt(float(x_shape[1] * x_shape[2])))
            if spatial_dim * spatial_dim != x_shape[1] * x_shape[2]:
                raise ValueError("Assumed inputs were square-able but they were "
                                 "not. Shape: %s" % x_shape)
            x = tf.reshape(x, [-1, spatial_dim, spatial_dim, x_depth])

        x = common_layers.conv_block_downsample(x, (3, 3), (2, 2), "SAME")
        return tf.nn.relu(x)