def compute_output_shape(self, input_shape):
    input_shape = tf.TensorShape(input_shape).as_list()
    output_shape = list(input_shape)
    if self.data_format == 'channels_first':
        c_axis, h_axis, w_axis = 1, 2, 3
    else:
        c_axis, h_axis, w_axis = 3, 1, 2

    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides

    if self.output_padding is None:
        out_pad_h = out_pad_w = None
    else:
        out_pad_h, out_pad_w = self.output_padding

    output_shape[c_axis] = self.filters
    output_shape[h_axis] = conv_utils.deconv_output_length(
        output_shape[h_axis],
        kernel_h,
        padding=self.padding,
        output_padding=out_pad_h,
        stride=stride_h,
        dilation=self.dilation_rate[0])
    output_shape[w_axis] = conv_utils.deconv_output_length(
        output_shape[w_axis],
        kernel_w,
        padding=self.padding,
        output_padding=out_pad_w,
        stride=stride_w,
        dilation=self.dilation_rate[1])
    return tf.TensorShape(output_shape)
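# A minimal sanity check of the shape arithmetic above, assuming the public
# tf.keras.layers.Conv2DTranspose layer (which wraps this compute_output_shape):
import tensorflow as tf

layer = tf.keras.layers.Conv2DTranspose(
    filters=16, kernel_size=3, strides=2, padding='same')
# With 'same' padding and no output_padding, each spatial dim is just
# input_length * stride, so 8x8 becomes 16x16.
print(layer.compute_output_shape((None, 8, 8, 4)))  # (None, 16, 16, 16)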
def call(self, inputs):
    inputs_shape = tf.shape(inputs)
    batch_size = inputs_shape[0]
    if self.data_format == "channels_first":
        t_axis = 2
    else:
        t_axis = 1

    length = inputs_shape[t_axis]
    if self.output_padding is None:
        output_padding = None
    else:
        output_padding = self.output_padding[0]

    # Infer the dynamic output shape:
    out_length = conv_utils.deconv_output_length(
        length,
        self.kernel_size[0],
        padding=self.padding,
        output_padding=output_padding,
        stride=self.strides[0],
        dilation=self.dilation_rate[0],
    )
    if self.data_format == "channels_first":
        output_shape = (batch_size, self.filters, out_length)
    else:
        output_shape = (batch_size, out_length, self.filters)
    data_format = conv_utils.convert_data_format(self.data_format, ndim=3)

    output_shape_tensor = tf.stack(output_shape)
    outputs = tf.nn.conv1d_transpose(
        inputs,
        self.kernel,
        output_shape_tensor,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=data_format,
        dilations=self.dilation_rate,
    )

    if not tf.executing_eagerly():
        # Infer the static output shape:
        out_shape = self.compute_output_shape(inputs.shape)
        outputs.set_shape(out_shape)

    if self.use_bias:
        outputs = tf.nn.bias_add(outputs, self.bias, data_format=data_format)

    if self.activation is not None:
        return self.activation(outputs)
    return outputs
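# Example usage of the call() above via the public layer, assuming
# tf.keras.layers.Conv1DTranspose. With "valid" padding and no
# output_padding, the dynamic length works out to
# length * stride + max(kernel_size - stride, 0):
import tensorflow as tf

x = tf.random.normal((2, 10, 8))  # (batch, length, channels)
layer = tf.keras.layers.Conv1DTranspose(
    filters=4, kernel_size=3, strides=2, padding="valid")
print(layer(x).shape)  # (2, 21, 4): 10 * 2 + max(3 - 2, 0) = 21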
def compute_output_shape(self, input_shape):
    input_shape = tf.TensorShape(input_shape).as_list()
    output_shape = list(input_shape)
    if self.data_format == "channels_first":
        c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
    else:
        c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3

    kernel_d, kernel_h, kernel_w = self.kernel_size
    stride_d, stride_h, stride_w = self.strides

    if self.output_padding is None:
        out_pad_d = out_pad_h = out_pad_w = None
    else:
        out_pad_d, out_pad_h, out_pad_w = self.output_padding

    output_shape[c_axis] = self.filters
    output_shape[d_axis] = conv_utils.deconv_output_length(
        output_shape[d_axis],
        kernel_d,
        padding=self.padding,
        output_padding=out_pad_d,
        stride=stride_d,
    )
    output_shape[h_axis] = conv_utils.deconv_output_length(
        output_shape[h_axis],
        kernel_h,
        padding=self.padding,
        output_padding=out_pad_h,
        stride=stride_h,
    )
    output_shape[w_axis] = conv_utils.deconv_output_length(
        output_shape[w_axis],
        kernel_w,
        padding=self.padding,
        output_padding=out_pad_w,
        stride=stride_w,
    )
    return tf.TensorShape(output_shape)
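# Sketch of the same arithmetic with an explicit output_padding, assuming the
# public tf.keras.layers.Conv3DTranspose layer. With "valid" padding the
# length formula becomes (length - 1) * stride + kernel_size + output_padding:
import tensorflow as tf

layer = tf.keras.layers.Conv3DTranspose(
    filters=2, kernel_size=2, strides=2, padding="valid", output_padding=1)
# (4 - 1) * 2 + 2 + 1 = 9 along each of depth, height, and width.
print(layer.compute_output_shape((None, 4, 4, 4, 3)))  # (None, 9, 9, 9, 2)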
def compute_output_shape(self, input_shape):
    input_shape = tf.TensorShape(input_shape).as_list()
    output_shape = list(input_shape)
    if self.data_format == 'channels_first':
        c_axis, t_axis = 1, 2
    else:
        c_axis, t_axis = 2, 1

    if self.output_padding is None:
        output_padding = None
    else:
        output_padding = self.output_padding[0]

    output_shape[c_axis] = self.filters
    output_shape[t_axis] = conv_utils.deconv_output_length(
        output_shape[t_axis],
        self.kernel_size[0],
        padding=self.padding,
        output_padding=output_padding,
        stride=self.strides[0],
        dilation=self.dilation_rate[0])
    return tf.TensorShape(output_shape)
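# Why output_padding exists: with stride 2 and 'same' padding, a forward
# Conv1D maps inputs of length 9 and length 10 to the same output length 5,
# so the transpose is ambiguous; output_padding selects which preimage length
# to produce. A sketch, assuming the public Keras layers:
import tensorflow as tf

down = tf.keras.layers.Conv1D(1, 3, strides=2, padding='same')
print(down(tf.zeros((1, 9, 1))).shape)   # (1, 5, 1)
print(down(tf.zeros((1, 10, 1))).shape)  # (1, 5, 1)

up9 = tf.keras.layers.Conv1DTranspose(
    1, 3, strides=2, padding='same', output_padding=0)
up10 = tf.keras.layers.Conv1DTranspose(
    1, 3, strides=2, padding='same', output_padding=1)
print(up9(tf.zeros((1, 5, 1))).shape)   # (1, 9, 1)
print(up10(tf.zeros((1, 5, 1))).shape)  # (1, 10, 1)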
def test_deconv_output_length(self): self.assertEqual( 4, conv_utils.deconv_output_length(4, 2, "same", stride=1)) self.assertEqual( 8, conv_utils.deconv_output_length(4, 2, "same", stride=2)) self.assertEqual( 5, conv_utils.deconv_output_length(4, 2, "valid", stride=1)) self.assertEqual( 8, conv_utils.deconv_output_length(4, 2, "valid", stride=2)) self.assertEqual( 3, conv_utils.deconv_output_length(4, 2, "full", stride=1)) self.assertEqual( 6, conv_utils.deconv_output_length(4, 2, "full", stride=2)) self.assertEqual( 5, conv_utils.deconv_output_length(4, 2, "same", output_padding=2, stride=1), ) self.assertEqual( 7, conv_utils.deconv_output_length(4, 2, "same", output_padding=1, stride=2), ) self.assertEqual( 7, conv_utils.deconv_output_length(4, 2, "valid", output_padding=2, stride=1), ) self.assertEqual( 9, conv_utils.deconv_output_length(4, 2, "valid", output_padding=1, stride=2), ) self.assertEqual( 5, conv_utils.deconv_output_length(4, 2, "full", output_padding=2, stride=1), ) self.assertEqual( 7, conv_utils.deconv_output_length(4, 2, "full", output_padding=1, stride=2), ) self.assertEqual( 5, conv_utils.deconv_output_length(4, 2, "same", output_padding=1, stride=1, dilation=2), ) self.assertEqual( 12, conv_utils.deconv_output_length(4, 2, "valid", output_padding=2, stride=2, dilation=3), ) self.assertEqual( 6, conv_utils.deconv_output_length(4, 2, "full", output_padding=2, stride=2, dilation=3), )
def call(self, inputs):
    inputs_shape = tf.shape(inputs)
    batch_size = inputs_shape[0]
    if self.data_format == 'channels_first':
        h_axis, w_axis = 2, 3
    else:
        h_axis, w_axis = 1, 2

    # Use the constant height and width when possible.
    # TODO(scottzhu): Extract this into a utility function that can be
    # applied to all convolutional layers, which currently lose the static
    # shape information due to tf.shape().
    height, width = None, None
    if inputs.shape.rank is not None:
        dims = inputs.shape.as_list()
        height = dims[h_axis]
        width = dims[w_axis]
    height = height if height is not None else inputs_shape[h_axis]
    width = width if width is not None else inputs_shape[w_axis]

    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides

    if self.output_padding is None:
        out_pad_h = out_pad_w = None
    else:
        out_pad_h, out_pad_w = self.output_padding

    # Infer the dynamic output shape:
    out_height = conv_utils.deconv_output_length(
        height,
        kernel_h,
        padding=self.padding,
        output_padding=out_pad_h,
        stride=stride_h,
        dilation=self.dilation_rate[0])
    out_width = conv_utils.deconv_output_length(
        width,
        kernel_w,
        padding=self.padding,
        output_padding=out_pad_w,
        stride=stride_w,
        dilation=self.dilation_rate[1])
    if self.data_format == 'channels_first':
        output_shape = (batch_size, self.filters, out_height, out_width)
    else:
        output_shape = (batch_size, out_height, out_width, self.filters)

    output_shape_tensor = tf.stack(output_shape)
    outputs = backend.conv2d_transpose(
        inputs,
        self.kernel,
        output_shape_tensor,
        strides=self.strides,
        padding=self.padding,
        data_format=self.data_format,
        dilation_rate=self.dilation_rate)

    if not tf.executing_eagerly():
        # Infer the static output shape:
        out_shape = self.compute_output_shape(inputs.shape)
        outputs.set_shape(out_shape)

    if self.use_bias:
        outputs = tf.nn.bias_add(
            outputs,
            self.bias,
            data_format=conv_utils.convert_data_format(
                self.data_format, ndim=4))

    if self.activation is not None:
        return self.activation(outputs)
    return outputs
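# The static-shape branch above matters in graph mode: when the layer is
# traced non-eagerly (e.g. via the Keras functional API), set_shape() restores
# the spatial dims that tf.shape() alone would erase. A sketch, assuming the
# public tf.keras.layers.Conv2DTranspose layer:
import tensorflow as tf

inp = tf.keras.Input((16, 16, 3))
out = tf.keras.layers.Conv2DTranspose(8, 3, strides=2, padding='same')(inp)
# Static dims are preserved, not collapsed to (None, None, None, 8).
print(out.shape)  # (None, 32, 32, 8)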
def call(self, inputs):
    inputs_shape = tf.shape(inputs)
    batch_size = inputs_shape[0]
    if self.data_format == 'channels_first':
        d_axis, h_axis, w_axis = 2, 3, 4
    else:
        d_axis, h_axis, w_axis = 1, 2, 3

    depth = inputs_shape[d_axis]
    height = inputs_shape[h_axis]
    width = inputs_shape[w_axis]

    kernel_d, kernel_h, kernel_w = self.kernel_size
    stride_d, stride_h, stride_w = self.strides

    if self.output_padding is None:
        out_pad_d = out_pad_h = out_pad_w = None
    else:
        out_pad_d, out_pad_h, out_pad_w = self.output_padding

    # Infer the dynamic output shape:
    out_depth = conv_utils.deconv_output_length(
        depth,
        kernel_d,
        padding=self.padding,
        output_padding=out_pad_d,
        stride=stride_d)
    out_height = conv_utils.deconv_output_length(
        height,
        kernel_h,
        padding=self.padding,
        output_padding=out_pad_h,
        stride=stride_h)
    out_width = conv_utils.deconv_output_length(
        width,
        kernel_w,
        padding=self.padding,
        output_padding=out_pad_w,
        stride=stride_w)
    if self.data_format == 'channels_first':
        output_shape = (batch_size, self.filters, out_depth, out_height,
                        out_width)
        strides = (1, 1, stride_d, stride_h, stride_w)
    else:
        output_shape = (batch_size, out_depth, out_height, out_width,
                        self.filters)
        strides = (1, stride_d, stride_h, stride_w, 1)

    output_shape_tensor = tf.stack(output_shape)
    outputs = tf.nn.conv3d_transpose(
        inputs,
        self.kernel,
        output_shape_tensor,
        strides,
        data_format=conv_utils.convert_data_format(self.data_format, ndim=5),
        padding=self.padding.upper())

    if not tf.executing_eagerly():
        # Infer the static output shape:
        out_shape = self.compute_output_shape(inputs.shape)
        outputs.set_shape(out_shape)

    if self.use_bias:
        # tf.nn.bias_add keys only on the channel-axis position ('NHWC' vs
        # 'NCHW', aliases for 'N...C' and 'NC...'), so the 4-D format string
        # also works for the 5-D outputs here.
        outputs = tf.nn.bias_add(
            outputs,
            self.bias,
            data_format=conv_utils.convert_data_format(
                self.data_format, ndim=4))

    if self.activation is not None:
        return self.activation(outputs)
    return outputs
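# End-to-end example of the 3-D case, assuming tf.keras.layers.Conv3DTranspose.
# With 'valid' padding, kernel 2 and stride 2, each spatial axis doubles:
# 4 * 2 + max(2 - 2, 0) = 8.
import tensorflow as tf

x = tf.random.normal((1, 4, 4, 4, 2))  # (batch, depth, height, width, channels)
layer = tf.keras.layers.Conv3DTranspose(
    filters=3, kernel_size=2, strides=2, padding='valid')
print(layer(x).shape)  # (1, 8, 8, 8, 3)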