Example #1: Conv2DTranspose.compute_output_shape
Static shape inference for a 2D transposed convolution: the channel axis is set to the layer's filters, and each spatial axis is mapped through conv_utils.deconv_output_length.
    def compute_output_shape(self, input_shape):
        input_shape = tf.TensorShape(input_shape).as_list()
        output_shape = list(input_shape)
        # Locate the channel and spatial axes for the layer's data format.
        if self.data_format == 'channels_first':
            c_axis, h_axis, w_axis = 1, 2, 3
        else:
            c_axis, h_axis, w_axis = 3, 1, 2

        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        if self.output_padding is None:
            out_pad_h = out_pad_w = None
        else:
            out_pad_h, out_pad_w = self.output_padding

        output_shape[c_axis] = self.filters
        output_shape[h_axis] = conv_utils.deconv_output_length(
            output_shape[h_axis],
            kernel_h,
            padding=self.padding,
            output_padding=out_pad_h,
            stride=stride_h,
            dilation=self.dilation_rate[0])
        output_shape[w_axis] = conv_utils.deconv_output_length(
            output_shape[w_axis],
            kernel_w,
            padding=self.padding,
            output_padding=out_pad_w,
            stride=stride_w,
            dilation=self.dilation_rate[1])
        return tf.TensorShape(output_shape)
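
For reference, a minimal usage sketch (standard tf.keras API; the layer configuration is illustrative):

    import tensorflow as tf

    layer = tf.keras.layers.Conv2DTranspose(
        filters=8, kernel_size=3, strides=2, padding='same')
    # 'same' padding with stride 2 doubles each spatial axis:
    # deconv_output_length(16, 3, 'same', stride=2) == 32.
    print(layer.compute_output_shape((None, 16, 16, 3)))
    # -> (None, 32, 32, 8)
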
Example #2: Conv1DTranspose.call
Forward pass of a 1D transposed convolution: the dynamic output length is inferred with conv_utils.deconv_output_length, tf.nn.conv1d_transpose produces the output, and the static shape is restored when running in graph mode.
    def call(self, inputs):
        inputs_shape = tf.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == "channels_first":
            t_axis = 2
        else:
            t_axis = 1

        length = inputs_shape[t_axis]
        if self.output_padding is None:
            output_padding = None
        else:
            output_padding = self.output_padding[0]

        # Infer the dynamic output shape:
        out_length = conv_utils.deconv_output_length(
            length,
            self.kernel_size[0],
            padding=self.padding,
            output_padding=output_padding,
            stride=self.strides[0],
            dilation=self.dilation_rate[0],
        )
        if self.data_format == "channels_first":
            output_shape = (batch_size, self.filters, out_length)
        else:
            output_shape = (batch_size, out_length, self.filters)
        data_format = conv_utils.convert_data_format(self.data_format, ndim=3)

        output_shape_tensor = tf.stack(output_shape)
        outputs = tf.nn.conv1d_transpose(
            inputs,
            self.kernel,
            output_shape_tensor,
            strides=self.strides,
            padding=self.padding.upper(),
            data_format=data_format,
            dilations=self.dilation_rate,
        )

        if not tf.executing_eagerly():
            # Infer the static output shape:
            out_shape = self.compute_output_shape(inputs.shape)
            outputs.set_shape(out_shape)

        if self.use_bias:
            outputs = tf.nn.bias_add(outputs,
                                     self.bias,
                                     data_format=data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
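
An end-to-end sketch of the shapes this call produces (standard tf.keras API; values chosen for illustration):

    import tensorflow as tf

    layer = tf.keras.layers.Conv1DTranspose(
        filters=2, kernel_size=3, strides=2, padding='valid')
    x = tf.random.normal((4, 10, 3))  # (batch, length, channels)
    # deconv_output_length(10, 3, 'valid', stride=2) == 10 * 2 + max(3 - 2, 0) == 21
    print(layer(x).shape)  # (4, 21, 2)
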
Example #3: Conv3DTranspose.compute_output_shape
Static shape inference for a 3D transposed convolution; structurally identical to the 2D case with a depth axis added (note that no dilation argument is passed here).
    def compute_output_shape(self, input_shape):
        input_shape = tf.TensorShape(input_shape).as_list()
        output_shape = list(input_shape)
        if self.data_format == "channels_first":
            c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
        else:
            c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3

        kernel_d, kernel_h, kernel_w = self.kernel_size
        stride_d, stride_h, stride_w = self.strides

        if self.output_padding is None:
            out_pad_d = out_pad_h = out_pad_w = None
        else:
            out_pad_d, out_pad_h, out_pad_w = self.output_padding

        output_shape[c_axis] = self.filters
        output_shape[d_axis] = conv_utils.deconv_output_length(
            output_shape[d_axis],
            kernel_d,
            padding=self.padding,
            output_padding=out_pad_d,
            stride=stride_d,
        )
        output_shape[h_axis] = conv_utils.deconv_output_length(
            output_shape[h_axis],
            kernel_h,
            padding=self.padding,
            output_padding=out_pad_h,
            stride=stride_h,
        )
        output_shape[w_axis] = conv_utils.deconv_output_length(
            output_shape[w_axis],
            kernel_w,
            padding=self.padding,
            output_padding=out_pad_w,
            stride=stride_w,
        )
        return tf.TensorShape(output_shape)
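
A usage sketch (standard tf.keras API; configuration illustrative):

    import tensorflow as tf

    layer = tf.keras.layers.Conv3DTranspose(
        filters=4, kernel_size=2, strides=2, padding='valid')
    # deconv_output_length(5, 2, 'valid', stride=2) == 5 * 2 + max(2 - 2, 0) == 10
    print(layer.compute_output_shape((None, 5, 5, 5, 1)))
    # -> (None, 10, 10, 10, 4)
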
Example #4: Conv1DTranspose.compute_output_shape
The 1D variant: a single temporal axis is mapped through conv_utils.deconv_output_length.
    def compute_output_shape(self, input_shape):
        input_shape = tf.TensorShape(input_shape).as_list()
        output_shape = list(input_shape)
        if self.data_format == 'channels_first':
            c_axis, t_axis = 1, 2
        else:
            c_axis, t_axis = 2, 1

        if self.output_padding is None:
            output_padding = None
        else:
            output_padding = self.output_padding[0]
        output_shape[c_axis] = self.filters
        output_shape[t_axis] = conv_utils.deconv_output_length(
            output_shape[t_axis],
            self.kernel_size[0],
            padding=self.padding,
            output_padding=output_padding,
            stride=self.strides[0],
            dilation=self.dilation_rate[0])
        return tf.TensorShape(output_shape)
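
A sketch showing how output_padding pins the otherwise ambiguous length (standard tf.keras API; values illustrative):

    import tensorflow as tf

    # Without output_padding, 'same' with stride 2 gives 10 * 2 == 20.
    # With output_padding=0: (10 - 1) * 2 + 3 - 2 * (3 // 2) + 0 == 19.
    layer = tf.keras.layers.Conv1DTranspose(
        filters=2, kernel_size=3, strides=2, padding='same', output_padding=0)
    print(layer.compute_output_shape((None, 10, 4)))
    # -> (None, 19, 2)
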
Example #5: test_deconv_output_length (conv_utils unit test)
Pins the expected output lengths for 'same', 'valid', and 'full' padding, with and without output_padding and dilation.
 def test_deconv_output_length(self):
     self.assertEqual(
         4, conv_utils.deconv_output_length(4, 2, 'same', stride=1))
     self.assertEqual(
         8, conv_utils.deconv_output_length(4, 2, 'same', stride=2))
     self.assertEqual(
         5, conv_utils.deconv_output_length(4, 2, 'valid', stride=1))
     self.assertEqual(
         8, conv_utils.deconv_output_length(4, 2, 'valid', stride=2))
     self.assertEqual(
         3, conv_utils.deconv_output_length(4, 2, 'full', stride=1))
     self.assertEqual(
         6, conv_utils.deconv_output_length(4, 2, 'full', stride=2))
     self.assertEqual(
         5,
         conv_utils.deconv_output_length(4,
                                         2,
                                         'same',
                                         output_padding=2,
                                         stride=1))
     self.assertEqual(
         7,
         conv_utils.deconv_output_length(4,
                                         2,
                                         'same',
                                         output_padding=1,
                                         stride=2))
     self.assertEqual(
         7,
         conv_utils.deconv_output_length(4,
                                         2,
                                         'valid',
                                         output_padding=2,
                                         stride=1))
     self.assertEqual(
         9,
         conv_utils.deconv_output_length(4,
                                         2,
                                         'valid',
                                         output_padding=1,
                                         stride=2))
     self.assertEqual(
         5,
         conv_utils.deconv_output_length(4,
                                         2,
                                         'full',
                                         output_padding=2,
                                         stride=1))
     self.assertEqual(
         7,
         conv_utils.deconv_output_length(4,
                                         2,
                                         'full',
                                         output_padding=1,
                                         stride=2))
     self.assertEqual(
         5,
         conv_utils.deconv_output_length(4,
                                         2,
                                         'same',
                                         output_padding=1,
                                         stride=1,
                                         dilation=2))
     self.assertEqual(
         12,
         conv_utils.deconv_output_length(4,
                                         2,
                                         'valid',
                                         output_padding=2,
                                         stride=2,
                                         dilation=3))
     self.assertEqual(
         6,
         conv_utils.deconv_output_length(4,
                                         2,
                                         'full',
                                         output_padding=2,
                                         stride=2,
                                         dilation=3))
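
The expected values above all follow from a single length formula. The sketch below reproduces that formula as standalone Python for reference; it is consistent with every assertion in the test:

    def deconv_output_length(input_length, filter_size, padding,
                             output_padding=None, stride=0, dilation=1):
        """Output length of a transposed convolution along one dimension."""
        assert padding in {'same', 'valid', 'full'}
        if input_length is None:
            return None
        # Effective (dilated) kernel size.
        filter_size = filter_size + (filter_size - 1) * (dilation - 1)
        if output_padding is None:
            if padding == 'valid':
                length = input_length * stride + max(filter_size - stride, 0)
            elif padding == 'full':
                length = input_length * stride - (stride + filter_size - 2)
            else:  # 'same'
                length = input_length * stride
        else:
            if padding == 'same':
                pad = filter_size // 2
            elif padding == 'valid':
                pad = 0
            else:  # 'full'
                pad = filter_size - 1
            length = ((input_length - 1) * stride + filter_size
                      - 2 * pad + output_padding)
        return length
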
Example #6: Conv2DTranspose.call
Forward pass of a 2D transposed convolution: static height/width are preferred when known, the dynamic output shape is inferred, and backend.conv2d_transpose performs the operation.
    def call(self, inputs):
        inputs_shape = tf.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == 'channels_first':
            h_axis, w_axis = 2, 3
        else:
            h_axis, w_axis = 1, 2

        # Use the constant height and width when possible.
        # TODO(scottzhu): Extract this into a utility function that can be
        # applied to all convolutional layers, which currently lose the
        # static shape information due to tf.shape().
        height, width = None, None
        if inputs.shape.rank is not None:
            dims = inputs.shape.as_list()
            height = dims[h_axis]
            width = dims[w_axis]
        height = height if height is not None else inputs_shape[h_axis]
        width = width if width is not None else inputs_shape[w_axis]

        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        if self.output_padding is None:
            out_pad_h = out_pad_w = None
        else:
            out_pad_h, out_pad_w = self.output_padding

        # Infer the dynamic output shape:
        out_height = conv_utils.deconv_output_length(
            height,
            kernel_h,
            padding=self.padding,
            output_padding=out_pad_h,
            stride=stride_h,
            dilation=self.dilation_rate[0])
        out_width = conv_utils.deconv_output_length(
            width,
            kernel_w,
            padding=self.padding,
            output_padding=out_pad_w,
            stride=stride_w,
            dilation=self.dilation_rate[1])
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)

        output_shape_tensor = tf.stack(output_shape)
        outputs = backend.conv2d_transpose(inputs,
                                           self.kernel,
                                           output_shape_tensor,
                                           strides=self.strides,
                                           padding=self.padding,
                                           data_format=self.data_format,
                                           dilation_rate=self.dilation_rate)

        if not tf.executing_eagerly():
            # Infer the static output shape:
            out_shape = self.compute_output_shape(inputs.shape)
            outputs.set_shape(out_shape)

        if self.use_bias:
            outputs = tf.nn.bias_add(
                outputs,
                self.bias,
                data_format=conv_utils.convert_data_format(self.data_format,
                                                           ndim=4))

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
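
A short sketch exercising the dilation path (standard tf.keras API; values illustrative):

    import tensorflow as tf

    layer = tf.keras.layers.Conv2DTranspose(
        filters=1, kernel_size=3, padding='valid', dilation_rate=2)
    x = tf.random.normal((1, 8, 8, 3))
    # Dilated kernel size is 3 + (3 - 1) * (2 - 1) == 5, so
    # deconv_output_length(8, 3, 'valid', stride=1, dilation=2) == 8 + (5 - 1) == 12.
    print(layer(x).shape)  # (1, 12, 12, 1)
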
Example #7: Conv3DTranspose.call
Forward pass of a 3D transposed convolution. Unlike the 2D version above, the spatial sizes are taken directly from tf.shape, and strides are expanded to the full 5-D form expected by tf.nn.conv3d_transpose.
    def call(self, inputs):
        inputs_shape = tf.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == 'channels_first':
            d_axis, h_axis, w_axis = 2, 3, 4
        else:
            d_axis, h_axis, w_axis = 1, 2, 3

        # Dynamic spatial sizes; the static shape is restored after the op.
        depth = inputs_shape[d_axis]
        height = inputs_shape[h_axis]
        width = inputs_shape[w_axis]

        kernel_d, kernel_h, kernel_w = self.kernel_size
        stride_d, stride_h, stride_w = self.strides

        if self.output_padding is None:
            out_pad_d = out_pad_h = out_pad_w = None
        else:
            out_pad_d, out_pad_h, out_pad_w = self.output_padding

        # Infer the dynamic output shape:
        out_depth = conv_utils.deconv_output_length(depth,
                                                    kernel_d,
                                                    padding=self.padding,
                                                    output_padding=out_pad_d,
                                                    stride=stride_d)
        out_height = conv_utils.deconv_output_length(height,
                                                     kernel_h,
                                                     padding=self.padding,
                                                     output_padding=out_pad_h,
                                                     stride=stride_h)
        out_width = conv_utils.deconv_output_length(width,
                                                    kernel_w,
                                                    padding=self.padding,
                                                    output_padding=out_pad_w,
                                                    stride=stride_w)
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_depth, out_height,
                            out_width)
            strides = (1, 1, stride_d, stride_h, stride_w)
        else:
            output_shape = (batch_size, out_depth, out_height, out_width,
                            self.filters)
            strides = (1, stride_d, stride_h, stride_w, 1)

        output_shape_tensor = tf.stack(output_shape)
        outputs = tf.nn.conv3d_transpose(
            inputs,
            self.kernel,
            output_shape_tensor,
            strides,
            data_format=conv_utils.convert_data_format(self.data_format,
                                                       ndim=5),
            padding=self.padding.upper())

        if not tf.executing_eagerly():
            # Infer the static output shape:
            out_shape = self.compute_output_shape(inputs.shape)
            outputs.set_shape(out_shape)

        if self.use_bias:
            # tf.nn.bias_add accepts 'NHWC'/'NCHW', which select the channel
            # position (last vs. second) regardless of tensor rank, hence
            # ndim=4 even though the output here is 5-D.
            outputs = tf.nn.bias_add(
                outputs,
                self.bias,
                data_format=conv_utils.convert_data_format(self.data_format,
                                                           ndim=4))

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
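
Finally, a usage sketch (standard tf.keras API; configuration illustrative):

    import tensorflow as tf

    layer = tf.keras.layers.Conv3DTranspose(
        filters=2, kernel_size=2, strides=2, padding='same')
    x = tf.random.normal((1, 4, 4, 4, 1))
    print(layer(x).shape)  # (1, 8, 8, 8, 2): 'same' with stride 2 doubles D, H, W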