Example #1
    def call(self, inputs):
        # Apply the actual ops.
        if self.data_format == 'channels_last':
            strides = (1, ) + self.strides + (1, )
        else:
            strides = (1, 1) + self.strides
        outputs = tf.compat.v1.nn.separable_conv2d(
            inputs,
            self.depthwise_kernel,
            self.pointwise_kernel,
            strides=strides,
            padding=self.padding.upper(),
            rate=self.dilation_rate,
            data_format=conv_utils.convert_data_format(self.data_format,
                                                       ndim=4))

        if self.use_bias:
            outputs = tf.nn.bias_add(
                outputs,
                self.bias,
                data_format=conv_utils.convert_data_format(self.data_format,
                                                           ndim=4))

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
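
The call above follows the pattern of Keras's SeparableConv2D: build the op-level strides tuple for the chosen data format, delegate to tf.compat.v1.nn.separable_conv2d, then apply bias and activation. A minimal usage sketch at the layer level, assuming TensorFlow 2.x with its bundled Keras:

import tensorflow as tf

x = tf.random.normal((1, 32, 32, 3))  # NHWC batch
layer = tf.keras.layers.SeparableConv2D(
    filters=8, kernel_size=3, strides=2, padding="same", activation="relu")
y = layer(x)
print(y.shape)  # (1, 16, 16, 8): "same" padding with stride 2 halves H and W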
Example #2
    def call(self, inputs):
        if self.data_format == "channels_last":
            strides = (1, ) + self.strides * 2 + (1, )
            spatial_start_dim = 1
        else:
            strides = (1, 1) + self.strides * 2
            spatial_start_dim = 2
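        # Lift the 1D input and kernel to 4D (with a dummy spatial axis) so
        # the 2D depthwise op below can be reused.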
        inputs = tf.expand_dims(inputs, spatial_start_dim)
        depthwise_kernel = tf.expand_dims(self.depthwise_kernel, axis=0)
        dilation_rate = (1, ) + self.dilation_rate

        outputs = tf.nn.depthwise_conv2d(
            inputs,
            depthwise_kernel,
            strides=strides,
            padding=self.padding.upper(),
            dilations=dilation_rate,
            data_format=conv_utils.convert_data_format(self.data_format,
                                                       ndim=4),
        )

        if self.use_bias:
            outputs = tf.nn.bias_add(
                outputs,
                self.bias,
                data_format=conv_utils.convert_data_format(self.data_format,
                                                           ndim=4),
            )

        outputs = tf.squeeze(outputs, [spatial_start_dim])

        if self.activation is not None:
            return self.activation(outputs)

        return outputs
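
This call implements a 1D depthwise convolution by lifting the input to 4D, running tf.nn.depthwise_conv2d, and squeezing the dummy axis back out. A minimal layer-level sketch, assuming a Keras release that ships DepthwiseConv1D (TensorFlow >= 2.9):

import tensorflow as tf

x = tf.random.normal((1, 20, 4))  # (batch, steps, channels)
layer = tf.keras.layers.DepthwiseConv1D(kernel_size=3, padding="valid")
y = layer(x)
print(y.shape)  # (1, 18, 4): 20 - 3 + 1 steps, channels preserved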
Example #3
    def call(self, inputs):
        if self.padding == "causal":
            inputs = tf.pad(inputs, self._compute_causal_padding(inputs))
        if self.data_format == "channels_last":
            strides = (1, ) + self.strides * 2 + (1, )
            spatial_start_dim = 1
        else:
            strides = (1, 1) + self.strides * 2
            spatial_start_dim = 2

        # Explicitly broadcast inputs and kernels to 4D.
        # TODO(fchollet): refactor when a native separable_conv1d op is
        # available.
        inputs = tf.expand_dims(inputs, spatial_start_dim)
        depthwise_kernel = tf.expand_dims(self.depthwise_kernel, 0)
        pointwise_kernel = tf.expand_dims(self.pointwise_kernel, 0)
        dilation_rate = (1, ) + self.dilation_rate

        if self.padding == "causal":
            op_padding = "valid"
        else:
            op_padding = self.padding
        outputs = tf.compat.v1.nn.separable_conv2d(
            inputs,
            depthwise_kernel,
            pointwise_kernel,
            strides=strides,
            padding=op_padding.upper(),
            rate=dilation_rate,
            data_format=conv_utils.convert_data_format(self.data_format,
                                                       ndim=4),
        )

        if self.use_bias:
            outputs = tf.nn.bias_add(
                outputs,
                self.bias,
                data_format=conv_utils.convert_data_format(self.data_format,
                                                           ndim=4),
            )

        outputs = tf.squeeze(outputs, [spatial_start_dim])

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
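
Same 4D-lifting trick as above, but for a separable 1D convolution with optional causal padding: the input is padded on the left first, and the underlying op then runs with "valid" padding. A minimal sketch, assuming a Keras build whose SeparableConv1D accepts padding="causal" as this call does:

import tensorflow as tf

x = tf.random.normal((1, 10, 4))
layer = tf.keras.layers.SeparableConv1D(filters=6, kernel_size=3,
                                        padding="causal")
y = layer(x)
print(y.shape)  # (1, 10, 6): causal padding preserves sequence length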
Example #4
  def test_convert_data_format(self):
    self.assertEqual('NCDHW', conv_utils.convert_data_format(
        'channels_first', 5))
    self.assertEqual('NCHW', conv_utils.convert_data_format(
        'channels_first', 4))
    self.assertEqual('NCW', conv_utils.convert_data_format('channels_first', 3))
    self.assertEqual('NHWC', conv_utils.convert_data_format('channels_last', 4))
    self.assertEqual('NWC', conv_utils.convert_data_format('channels_last', 3))
    self.assertEqual('NDHWC', conv_utils.convert_data_format(
        'channels_last', 5))

    with self.assertRaises(ValueError):
      conv_utils.convert_data_format('invalid', 2)
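
The mapping under test is mechanical: "channels_first" yields an "NC..." string and "channels_last" an "N...C" string of the requested rank. A quick sketch (the conv_utils import path varies across Keras/TF versions; the standalone keras package's path is assumed here):

from keras.utils import conv_utils

print(conv_utils.convert_data_format("channels_last", ndim=4))   # NHWC
print(conv_utils.convert_data_format("channels_first", ndim=5))  # NCDHW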
Example #5
    def call(self, inputs):
        inputs_shape = tf.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == "channels_first":
            t_axis = 2
        else:
            t_axis = 1

        length = inputs_shape[t_axis]
        if self.output_padding is None:
            output_padding = None
        else:
            output_padding = self.output_padding[0]

        # Infer the dynamic output shape:
        out_length = conv_utils.deconv_output_length(
            length,
            self.kernel_size[0],
            padding=self.padding,
            output_padding=output_padding,
            stride=self.strides[0],
            dilation=self.dilation_rate[0],
        )
        if self.data_format == "channels_first":
            output_shape = (batch_size, self.filters, out_length)
        else:
            output_shape = (batch_size, out_length, self.filters)
        data_format = conv_utils.convert_data_format(self.data_format, ndim=3)

        output_shape_tensor = tf.stack(output_shape)
        outputs = tf.nn.conv1d_transpose(
            inputs,
            self.kernel,
            output_shape_tensor,
            strides=self.strides,
            padding=self.padding.upper(),
            data_format=data_format,
            dilations=self.dilation_rate,
        )

        if not tf.executing_eagerly():
            # Infer the static output shape:
            out_shape = self.compute_output_shape(inputs.shape)
            outputs.set_shape(out_shape)

        if self.use_bias:
            outputs = tf.nn.bias_add(outputs,
                                     self.bias,
                                     data_format=data_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
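
A 1D transposed convolution: the dynamic output length is computed with conv_utils.deconv_output_length before calling tf.nn.conv1d_transpose, since the op needs an explicit output shape. A minimal layer-level sketch, assuming TensorFlow >= 2.3 (which introduced Conv1DTranspose):

import tensorflow as tf

x = tf.random.normal((1, 8, 4))
layer = tf.keras.layers.Conv1DTranspose(
    filters=2, kernel_size=3, strides=2, padding="same")
print(layer(x).shape)  # (1, 16, 2): "same" padding with stride 2 doubles length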
Example #6
    def call(self, inputs):
        if self.data_format == 'channels_last':
            pool_shape = (1, ) + self.pool_size + (1, )
            strides = (1, ) + self.strides + (1, )
        else:
            pool_shape = (1, 1) + self.pool_size
            strides = (1, 1) + self.strides
        outputs = self.pool_function(
            inputs,
            ksize=pool_shape,
            strides=strides,
            padding=self.padding.upper(),
            data_format=conv_utils.convert_data_format(self.data_format, 4))
        return outputs
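
This is the shared call of the Keras pooling layers: batch and channel entries are prepended to pool_size and strides, and the configured pool_function does the work. A minimal sketch using the 2D max-pooling layer:

import tensorflow as tf

x = tf.random.normal((1, 8, 8, 3))
pool = tf.keras.layers.MaxPooling2D(pool_size=2, strides=2)
print(pool(x).shape)  # (1, 4, 4, 3)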
Example #7
    def test_convert_data_format(self):
        self.assertEqual("NCDHW",
                         conv_utils.convert_data_format("channels_first", 5))
        self.assertEqual("NCHW",
                         conv_utils.convert_data_format("channels_first", 4))
        self.assertEqual("NCW",
                         conv_utils.convert_data_format("channels_first", 3))
        self.assertEqual("NHWC",
                         conv_utils.convert_data_format("channels_last", 4))
        self.assertEqual("NWC",
                         conv_utils.convert_data_format("channels_last", 3))
        self.assertEqual("NDHWC",
                         conv_utils.convert_data_format("channels_last", 5))

        with self.assertRaises(ValueError):
            conv_utils.convert_data_format("invalid", 2)
Example #8
    def __init__(
        self,
        rank,
        filters,
        kernel_size,
        strides=1,
        padding="valid",
        data_format=None,
        dilation_rate=1,
        groups=1,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        trainable=True,
        name=None,
        conv_op=None,
        **kwargs,
    ):
        super().__init__(
            trainable=trainable,
            name=name,
            activity_regularizer=regularizers.get(activity_regularizer),
            **kwargs,
        )
        self.rank = rank

        if isinstance(filters, float):
            filters = int(filters)
        if filters is not None and filters <= 0:
            raise ValueError("Invalid value for argument `filters`. "
                             "Expected a strictly positive value. "
                             f"Received filters={filters}.")
        self.filters = filters
        self.groups = groups or 1
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank,
                                                      "kernel_size")
        self.strides = conv_utils.normalize_tuple(strides,
                                                  rank,
                                                  "strides",
                                                  allow_zero=True)
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(
            dilation_rate, rank, "dilation_rate")

        self.activation = activations.get(activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.input_spec = InputSpec(min_ndim=self.rank + 2)

        self._validate_init()
        self._is_causal = self.padding == "causal"
        self._channels_first = self.data_format == "channels_first"
        self._tf_data_format = conv_utils.convert_data_format(
            self.data_format, self.rank + 2)
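
Most of this constructor is argument normalization: scalar arguments become per-dimension tuples, padding and data_format strings are canonicalized, and the TF-level format string is precomputed once. A sketch of those helpers in isolation (conv_utils import path assumed as in the earlier sketches):

from keras.utils import conv_utils

print(conv_utils.normalize_tuple(3, 2, "kernel_size"))          # (3, 3)
print(conv_utils.normalize_padding("SAME"))                     # same
print(conv_utils.convert_data_format("channels_last", ndim=4))  # NHWC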
Example #9
    def call(self, inputs):
        inputs_shape = tf.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == 'channels_first':
            h_axis, w_axis = 2, 3
        else:
            h_axis, w_axis = 1, 2

        # Use the constant height and width when possible.
        # TODO(scottzhu): Extract this into a utility function that can be
        # applied to all convolutional layers, which currently lose the static
        # shape information due to tf.shape().
        height, width = None, None
        if inputs.shape.rank is not None:
            dims = inputs.shape.as_list()
            height = dims[h_axis]
            width = dims[w_axis]
        height = height if height is not None else inputs_shape[h_axis]
        width = width if width is not None else inputs_shape[w_axis]

        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        if self.output_padding is None:
            out_pad_h = out_pad_w = None
        else:
            out_pad_h, out_pad_w = self.output_padding

        # Infer the dynamic output shape:
        out_height = conv_utils.deconv_output_length(
            height,
            kernel_h,
            padding=self.padding,
            output_padding=out_pad_h,
            stride=stride_h,
            dilation=self.dilation_rate[0])
        out_width = conv_utils.deconv_output_length(
            width,
            kernel_w,
            padding=self.padding,
            output_padding=out_pad_w,
            stride=stride_w,
            dilation=self.dilation_rate[1])
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)

        output_shape_tensor = tf.stack(output_shape)
        outputs = backend.conv2d_transpose(inputs,
                                           self.kernel,
                                           output_shape_tensor,
                                           strides=self.strides,
                                           padding=self.padding,
                                           data_format=self.data_format,
                                           dilation_rate=self.dilation_rate)

        if not tf.executing_eagerly():
            # Infer the static output shape:
            out_shape = self.compute_output_shape(inputs.shape)
            outputs.set_shape(out_shape)

        if self.use_bias:
            outputs = tf.nn.bias_add(
                outputs,
                self.bias,
                data_format=conv_utils.convert_data_format(self.data_format,
                                                           ndim=4))

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
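
The 2D transposed convolution mirrors the 1D case: both output dimensions are inferred with conv_utils.deconv_output_length, static shapes are restored in graph mode via set_shape, and bias and activation are applied last. A minimal layer-level sketch, assuming TensorFlow 2.x:

import tensorflow as tf

x = tf.random.normal((1, 7, 7, 64))
layer = tf.keras.layers.Conv2DTranspose(
    filters=32, kernel_size=3, strides=2, padding="same")
print(layer(x).shape)  # (1, 14, 14, 32)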
Example #10
    def call(self, inputs):
        inputs_shape = tf.shape(inputs)
        batch_size = inputs_shape[0]
        if self.data_format == 'channels_first':
            d_axis, h_axis, w_axis = 2, 3, 4
        else:
            d_axis, h_axis, w_axis = 1, 2, 3

        depth = inputs_shape[d_axis]
        height = inputs_shape[h_axis]
        width = inputs_shape[w_axis]

        kernel_d, kernel_h, kernel_w = self.kernel_size
        stride_d, stride_h, stride_w = self.strides

        if self.output_padding is None:
            out_pad_d = out_pad_h = out_pad_w = None
        else:
            out_pad_d, out_pad_h, out_pad_w = self.output_padding

        # Infer the dynamic output shape:
        out_depth = conv_utils.deconv_output_length(depth,
                                                    kernel_d,
                                                    padding=self.padding,
                                                    output_padding=out_pad_d,
                                                    stride=stride_d)
        out_height = conv_utils.deconv_output_length(height,
                                                     kernel_h,
                                                     padding=self.padding,
                                                     output_padding=out_pad_h,
                                                     stride=stride_h)
        out_width = conv_utils.deconv_output_length(width,
                                                    kernel_w,
                                                    padding=self.padding,
                                                    output_padding=out_pad_w,
                                                    stride=stride_w)
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_depth, out_height,
                            out_width)
            strides = (1, 1, stride_d, stride_h, stride_w)
        else:
            output_shape = (batch_size, out_depth, out_height, out_width,
                            self.filters)
            strides = (1, stride_d, stride_h, stride_w, 1)

        output_shape_tensor = tf.stack(output_shape)
        outputs = tf.nn.conv3d_transpose(
            inputs,
            self.kernel,
            output_shape_tensor,
            strides,
            data_format=conv_utils.convert_data_format(self.data_format,
                                                       ndim=5),
            padding=self.padding.upper())

        if not tf.executing_eagerly():
            # Infer the static output shape:
            out_shape = self.compute_output_shape(inputs.shape)
            outputs.set_shape(out_shape)

        if self.use_bias:
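            # Note: tf.nn.bias_add only understands the 4D-style format
            # strings ('NHWC'/'NCHW') and applies them by channel position,
            # which is presumably why ndim=4 is requested for a 5D output.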
            outputs = tf.nn.bias_add(
                outputs,
                self.bias,
                data_format=conv_utils.convert_data_format(self.data_format,
                                                           ndim=4))

        if self.activation is not None:
            return self.activation(outputs)
        return outputs
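
And the 3D case, calling tf.nn.conv3d_transpose with an explicitly stacked output shape. A minimal layer-level sketch, assuming TensorFlow 2.x:

import tensorflow as tf

x = tf.random.normal((1, 4, 8, 8, 16))
layer = tf.keras.layers.Conv3DTranspose(
    filters=8, kernel_size=3, strides=2, padding="same")
print(layer(x).shape)  # (1, 8, 16, 16, 8): each spatial dim doubles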