def call(self, inputs):
    """Apply the convolution, then the optional bias and activation.

    Args:
      inputs: Input tensor laid out according to `self.data_format`.

    Returns:
      The convolved tensor, with bias added and activation applied when the
      layer has them.
    """
    outputs = nn.convolution(
        input=inputs,
        filter=self.kernel,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=utils.convert_data_format(self.data_format, self.rank + 2))
    if self.bias is not None:
        if self.rank != 2 and self.data_format == 'channels_first':
            # bias_add does not support channels_first for non-4D inputs, so
            # broadcast-add a bias reshaped to (1, filters, 1, ..., 1) with
            # one trailing singleton per spatial dimension. With fewer
            # singletons, broadcasting aligns `filters` with a spatial axis
            # instead of the channel axis. The old rank-3 shape
            # (1, filters, 1, 1) did exactly that on a 5-D output.
            bias = array_ops.reshape(self.bias,
                                     (1, self.filters) + (1,) * self.rank)
            outputs += bias
        else:
            # rank=4 is passed on purpose: bias_add only accepts NHWC/NCHW
            # even when the inputs are 3-D or 5-D.
            outputs = nn.bias_add(
                outputs,
                self.bias,
                data_format=utils.convert_data_format(self.data_format, 4))

    if self.activation is not None:
        return self.activation(outputs)
    return outputs
Example #2
0
    def call(self, inputs, training=True):
        """Transposed 2-D convolution with a spectrally normalised kernel.

        Args:
            inputs: 4-D input tensor in the layout given by `self.data_format`.
            training: Passed through to `self.compute_spectral_normal`.

        Returns:
            The output tensor, with bias and activation applied when configured.
        """
        dyn_shape = array_ops.shape(inputs)
        batch = dyn_shape[0]
        channels_first = self.data_format == 'channels_first'
        c_axis, h_axis, w_axis = (1, 2, 3) if channels_first else (3, 1, 2)

        in_h, in_w = dyn_shape[h_axis], dyn_shape[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        # Run-time output spatial sizes (derived from the dynamic input shape).
        out_h = utils.deconv_output_length(in_h, kernel_h, self.padding,
                                           stride_h)
        out_w = utils.deconv_output_length(in_w, kernel_w, self.padding,
                                           stride_w)

        if channels_first:
            target_shape = (batch, self.filters, out_h, out_w)
            full_strides = (1, 1, stride_h, stride_w)
        else:
            target_shape = (batch, out_h, out_w, self.filters)
            full_strides = (1, stride_h, stride_w, 1)
        target_shape_t = array_ops.stack(target_shape)

        outputs = nn.conv2d_transpose(
            inputs,
            self.compute_spectral_normal(training=training),
            target_shape_t,
            full_strides,
            padding=self.padding.upper(),
            data_format=utils.convert_data_format(self.data_format, ndim=4))

        if not context.executing_eagerly():
            # Outside eager mode, also record the statically known shape.
            static_shape = inputs.get_shape().as_list()
            static_shape[c_axis] = self.filters
            for axis, k, s in ((h_axis, kernel_h, stride_h),
                               (w_axis, kernel_w, stride_w)):
                static_shape[axis] = utils.deconv_output_length(
                    static_shape[axis], k, self.padding, s)
            outputs.set_shape(static_shape)

        if self.use_bias:
            outputs = nn.bias_add(
                outputs,
                self.bias,
                data_format=utils.convert_data_format(self.data_format, ndim=4))

        if self.activation is None:
            return outputs
        return self.activation(outputs)
  def call(self, inputs):
    """Apply the transposed 2-D convolution, then optional bias/activation.

    Args:
      inputs: 4-D input tensor in the layout given by `self.data_format`.

    Returns:
      The output tensor. Its static shape is derived from the static shape of
      `inputs`; unknown dimensions stay unknown.
    """
    inputs_shape = array_ops.shape(inputs)
    batch_size = inputs_shape[0]
    if self.data_format == 'channels_first':
      c_axis, h_axis, w_axis = 1, 2, 3
    else:
      c_axis, h_axis, w_axis = 3, 1, 2

    height, width = inputs_shape[h_axis], inputs_shape[w_axis]
    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides

    def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
      # Output length of one spatial dimension. `dim_size` may be a Tensor
      # (dynamic shape), a Python int, or None (unknown static size, which is
      # passed through unchanged).
      if isinstance(dim_size, ops.Tensor):
        dim_size = math_ops.mul(dim_size, stride_size)
      elif dim_size is not None:
        dim_size *= stride_size

      # With 'valid' padding, the part of the kernel that extends past the
      # stride adds to the output length.
      if padding == 'valid' and dim_size is not None:
        dim_size += max(kernel_size - stride_size, 0)
      return dim_size

    # Infer the dynamic output shape:
    out_height = get_deconv_dim(height, stride_h, kernel_h, self.padding)
    out_width = get_deconv_dim(width, stride_w, kernel_w, self.padding)

    if self.data_format == 'channels_first':
      output_shape = (batch_size, self.filters, out_height, out_width)
      strides = (1, 1, stride_h, stride_w)
    else:
      output_shape = (batch_size, out_height, out_width, self.filters)
      strides = (1, stride_h, stride_w, 1)

    output_shape_tensor = array_ops.stack(output_shape)
    outputs = nn.conv2d_transpose(
        inputs,
        self.kernel,
        output_shape_tensor,
        strides,
        padding=self.padding.upper(),
        data_format=utils.convert_data_format(self.data_format, ndim=4))

    # Infer the static output shape:
    out_shape = inputs.get_shape().as_list()
    out_shape[c_axis] = self.filters
    out_shape[h_axis] = get_deconv_dim(
        out_shape[h_axis], stride_h, kernel_h, self.padding)
    out_shape[w_axis] = get_deconv_dim(
        out_shape[w_axis], stride_w, kernel_w, self.padding)
    outputs.set_shape(out_shape)

    # Test for presence with `is not None`. Truth-testing a variable is not a
    # presence check: depending on the variable type it may try to read the
    # value (ambiguous for a multi-element bias) or raise outright.
    if self.bias is not None:
      outputs = nn.bias_add(
          outputs,
          self.bias,
          data_format=utils.convert_data_format(self.data_format, ndim=4))

    if self.activation is not None:
      return self.activation(outputs)
    return outputs
Example #4
0
    def call(self, inputs):
        """Transposed 2-D convolution with an L2-normalised, optionally scaled kernel.

        Args:
            inputs: 4-D input tensor in the layout given by `self.data_format`.

        Returns:
            The output tensor, with bias and activation applied when configured.
        """
        dyn_shape = array_ops.shape(inputs)
        batch = dyn_shape[0]
        channels_first = self.data_format == 'channels_first'
        c_axis, h_axis, w_axis = (1, 2, 3) if channels_first else (3, 1, 2)

        in_h, in_w = dyn_shape[h_axis], dyn_shape[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides

        # Run-time output spatial sizes.
        out_h = utils.deconv_output_length(in_h, kernel_h, self.padding,
                                           stride_h)
        out_w = utils.deconv_output_length(in_w, kernel_w, self.padding,
                                           stride_w)
        if channels_first:
            target_shape = (batch, self.filters, out_h, out_w)
            full_strides = (1, 1, stride_h, stride_w)
        else:
            target_shape = (batch, out_h, out_w, self.filters)
            full_strides = (1, stride_h, stride_w, 1)
        target_shape_t = array_ops.stack(target_shape)

        # Normalise over kernel axes 0, 1 and 3, leaving axis 2 (the axis the
        # scale reshape below sizes to `self.filters`). Then optionally
        # rescale each filter.
        kernel = nn.l2_normalize(self.kernel, [0, 1, 3])
        if self.use_scale:
            kernel = tf.reshape(self.scale, [1, 1, self.filters, 1]) * kernel

        outputs = nn.conv2d_transpose(
            inputs,
            kernel,
            target_shape_t,
            full_strides,
            padding=self.padding.upper(),
            data_format=utils.convert_data_format(self.data_format, ndim=4))

        if context.in_graph_mode():
            # In graph mode, also record the statically known output shape.
            static_shape = inputs.get_shape().as_list()
            static_shape[c_axis] = self.filters
            for axis, k, s in ((h_axis, kernel_h, stride_h),
                               (w_axis, kernel_w, stride_w)):
                static_shape[axis] = utils.deconv_output_length(
                    static_shape[axis], k, self.padding, s)
            outputs.set_shape(static_shape)

        if self.use_bias:
            outputs = nn.bias_add(
                outputs,
                self.bias,
                data_format=utils.convert_data_format(self.data_format, ndim=4))

        if self.activation is None:
            return outputs
        return self.activation(outputs)
  def call(self, inputs):
    """Apply the separable 2-D convolution, then optional bias/activation.

    The separable op is given channels_last data: channels_first inputs are
    transposed in and the result is transposed back.

    Args:
      inputs: 4-D input tensor in the layout given by `self.data_format`.

    Returns:
      The output tensor in the layer's data format.
    """
    if self.data_format == 'channels_first':
      # Reshape to channels last
      inputs = array_ops.transpose(inputs, (0, 2, 3, 1))

    # Apply the actual ops.
    outputs = nn.separable_conv2d(
        inputs,
        self.depthwise_kernel,
        self.pointwise_kernel,
        strides=(1,) + self.strides + (1,),
        padding=self.padding.upper(),
        rate=self.dilation_rate)

    if self.data_format == 'channels_first':
      # Reshape to channels first
      outputs = array_ops.transpose(outputs, (0, 3, 1, 2))

    # Presence check must be `is not None`: truth-testing a variable is not a
    # reliable "layer has a bias" test.
    if self.bias is not None:
      outputs = nn.bias_add(
          outputs,
          self.bias,
          data_format=utils.convert_data_format(self.data_format, ndim=4))

    if self.activation is not None:
      return self.activation(outputs)
    return outputs
Example #6
0
  def call(self, inputs):
    """Convolve with the masked kernel, then add bias and apply activation.

    Args:
      inputs: Input tensor laid out according to `self.data_format`.

    Returns:
      The output tensor.
    """
    outputs = nn.convolution(
        input=inputs,
        filter=self.masked_kernel,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        data_format=utils.convert_data_format(self.data_format, self.rank + 2))

    if self.bias is not None:
      if self.data_format == 'channels_first':
        if self.rank == 1:
          # nn.bias_add does not accept a 1D input tensor.
          bias = array_ops.reshape(self.bias, (1, self.filters, 1))
          outputs += bias
        elif self.rank == 2:
          outputs = nn.bias_add(outputs, self.bias, data_format='NCHW')
        elif self.rank == 3:
          # As of Mar 2017, direct addition is significantly slower than
          # bias_add when computing gradients. To use bias_add, we collapse Z
          # and Y into a single dimension to obtain a 4D input tensor.
          outputs_shape = outputs.shape.as_list()
          if outputs_shape[0] is None:
            # The static batch size is unknown; reshape cannot take None, so
            # let it infer that dimension.
            outputs_shape[0] = -1
          outputs_4d = array_ops.reshape(outputs, [
              outputs_shape[0], outputs_shape[1],
              outputs_shape[2] * outputs_shape[3], outputs_shape[4]
          ])
          outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW')
          outputs = array_ops.reshape(outputs_4d, outputs_shape)
      else:
        outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

    if self.activation is not None:
      return self.activation(outputs)
    return outputs
    def call(self, inputs):
        """Separable 2-D convolution via `separable_conv2d_tf_nn`, plus bias/activation.

        The op's strides are built channels_last-style, so channels_first
        inputs are transposed in and the result is transposed back out.
        """
        channels_first = self.data_format == 'channels_first'
        if channels_first:
            x = array_ops.transpose(inputs, (0, 2, 3, 1))
        else:
            x = inputs

        y = separable_conv2d_tf_nn(x,
                                   self.depthwise_kernel,
                                   self.pointwise_kernel,
                                   strides=(1, ) + self.strides + (1, ),
                                   padding=self.padding.upper(),
                                   rate=self.dilation_rate)
        if channels_first:
            y = array_ops.transpose(y, (0, 3, 1, 2))

        if self.bias is not None:
            layout = utils.convert_data_format(self.data_format, ndim=4)
            y = nn.bias_add(y, self.bias, data_format=layout)

        if self.activation is None:
            return y
        return self.activation(y)
Example #8
0
  def call(self, inputs):
    """Apply `self.pool_function` over the spatial dimensions of `inputs`.

    When the layer is channels_first and `gpu_device_name()` reports no GPU,
    the input is transposed to channels_last for the pooling op and the
    result is transposed back.

    Args:
      inputs: 4-D input tensor in the layout given by `self.data_format`.

    Returns:
      The pooled tensor in the layer's data format.
    """
    # Query the devices once, so the transpose-in and transpose-out decisions
    # always agree and the (potentially costly) lookup runs only once.
    transpose_to_nhwc = (self.data_format == 'channels_first' and
                         not framework.test_util.gpu_device_name())
    if transpose_to_nhwc:
      # The channels_first variant is not available on CPU here.
      # TODO(chollet): remove this when `nn.convolution` is feature-complete.
      data_format = 'channels_last'
      inputs = array_ops.transpose(inputs, (0, 2, 3, 1))
    else:
      data_format = self.data_format

    if data_format == 'channels_last':
      pool_shape = (1,) + self.pool_size + (1,)
      strides = (1,) + self.strides + (1,)
    else:
      pool_shape = (1, 1) + self.pool_size
      strides = (1, 1) + self.strides
    outputs = self.pool_function(
        inputs,
        ksize=pool_shape,
        strides=strides,
        padding=self.padding.upper(),
        data_format=utils.convert_data_format(data_format, 4))

    if transpose_to_nhwc:
      outputs = array_ops.transpose(outputs, (0, 3, 1, 2))
    return outputs
Example #9
0
  def testConvertDataFormat(self):
    # (data_format, ndim, expected layout string), checked in this order.
    cases = (
        ('channels_first', 5, 'NCDHW'),
        ('channels_first', 4, 'NCHW'),
        ('channels_first', 3, 'NCW'),
        ('channels_last', 4, 'NHWC'),
        ('channels_last', 3, 'NWC'),
        ('channels_last', 5, 'NDHWC'),
    )
    for data_format, ndim, expected in cases:
      self.assertEqual(expected, utils.convert_data_format(data_format, ndim))

    # An unknown data_format string must be rejected.
    with self.assertRaises(ValueError):
      utils.convert_data_format('invalid', 2)
Example #10
0
  def testConvertDataFormat(self):
    # Expected layout per rank, grouped by data_format (checked in order).
    expected_layouts = (
        ('channels_first', ((5, 'NCDHW'), (4, 'NCHW'), (3, 'NCW'))),
        ('channels_last', ((4, 'NHWC'), (3, 'NWC'), (5, 'NDHWC'))),
    )
    for data_format, per_rank in expected_layouts:
      for ndim, layout in per_rank:
        self.assertEqual(utils.convert_data_format(data_format, ndim), layout)

    with self.assertRaises(ValueError):
      utils.convert_data_format('invalid', 2)
Example #11
0
    def call(self, inputs):
        """Run the (optionally quantized) 2-D convolution on `inputs`.

        The kernel passes through `self.weight_quantizer`, then
        `self.quantizer`, when those are set. The convolution uses
        `q2dconvolution` when `self.quantizer` is set and `nn.convolution`
        otherwise. Bias, output quantization and activation follow, in that
        order.

        Args:
            inputs: 4-D input tensor in the layout given by `self.data_format`.

        Returns:
            The output tensor.

        Raises:
            ValueError: If `self.rank` is not 2.
        """
        # Only 2-D convolution is implemented; reject other ranks before
        # building any quantization ops.
        if self.rank != 2:
            raise ValueError("quantized convolution not supported for input rank %d" % (self.rank))

        data_format = utils.convert_data_format(self.data_format, self.rank + 2)

        # Quantize the weights, if there is a weight quantizer.
        used_kernel = self.kernel
        if self.weight_quantizer is not None:
            used_kernel = self.weight_quantizer.quantize(used_kernel)
        # With intrinsic quantization, apply it to the weights too.
        if self.quantizer is not None:
            used_kernel = self.quantizer.quantize(used_kernel)

        if self.quantizer is None:
            outputs = nn.convolution(
                input=inputs,
                filter=used_kernel,
                dilation_rate=self.dilation_rate,
                strides=self.strides,
                padding=self.padding.upper(),
                data_format=data_format)
        else:
            outputs = q2dconvolution(
                input=inputs,
                filter=used_kernel,
                quantizer=self.quantizer,
                padding=self.padding.upper(),
                strides=self.strides,
                dilation_rate=self.dilation_rate,
                data_format=data_format)

        if self.bias is not None:
            # rank is 2 here, so the output is 4-D and bias_add handles both
            # NHWC and NCHW directly.
            outputs = nn.bias_add(outputs, self.bias, data_format=data_format)

        # Output quantization runs BEFORE the activation below.
        # NOTE(review): the previous comment here said "quantize after
        # activation"; confirm which order is actually intended.
        if self.quantizer is not None:
            outputs = self.quantizer.quantize(outputs)
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs
    def build(self, input_shape):
        """Create the kernel, optional weight-norm variables, the bias and the conv op.

        Args:
            input_shape: Shape of the layer input. Its channel dimension (axis
                1 for channels_first, the last axis otherwise) must be known.

        Raises:
            ValueError: If the channel dimension of `input_shape` is `None`.
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        input_dim = input_shape[channel_axis].value
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        self.kernel = self.add_variable(name='kernel',
                                        shape=kernel_shape,
                                        initializer=self.kernel_initializer,
                                        regularizer=self.kernel_regularizer,
                                        constraint=self.kernel_constraint,
                                        trainable=True,
                                        dtype=self.dtype)

        if self.weight_norm:
            # Weight-norm variables: `V` has the kernel's shape, `g` has one
            # entry per filter.
            # NOTE(review): `V` is created as tf.float32 while every other
            # variable here uses self.dtype; confirm this is intended.
            self.V = self.add_variable(
                name='V_weight_norm',
                shape=kernel_shape,
                dtype=tf.float32,
                initializer=tf.random_normal_initializer(0, 0.05),
                trainable=True)
            self.g = self.add_variable(name='g_weight_norm',
                                       shape=(self.filters, ),
                                       initializer=init_ops.ones_initializer(),
                                       dtype=self.dtype,
                                       trainable=True)
        if self.mean_only_batch_norm:
            self.batch_norm_running_average = []

        self.bias = None
        if self.use_bias:
            self.bias = self.add_variable(name='bias',
                                          shape=(self.filters, ),
                                          initializer=self.bias_initializer,
                                          regularizer=self.bias_regularizer,
                                          constraint=self.bias_constraint,
                                          trainable=True,
                                          dtype=self.dtype)

        self.input_spec = base.InputSpec(ndim=self.rank + 2,
                                         axes={channel_axis: input_dim})
        self._convolution_op = nn_ops.Convolution(
            input_shape,
            filter_shape=self.kernel.get_shape(),
            dilation_rate=self.dilation_rate,
            strides=self.strides,
            padding=self.padding.upper(),
            data_format=utils.convert_data_format(self.data_format,
                                                  self.rank + 2))
        self.built = True
Example #13
0
 def call(self, inputs):
     """Max-pool `inputs` with `self.pool_function`, honouring dilation.

     `self.dilation_rate` may be a single value, which is applied to both
     spatial dimensions as before, or an already-expanded list/tuple pair.
     The pair form used to be wrapped again into a nested tuple.

     Args:
         inputs: 4-D input tensor in the layout given by `self.data_format`.

     Returns:
         The pooled tensor.
     """
     dilation = self.dilation_rate
     if isinstance(dilation, (list, tuple)):
         dilation = tuple(dilation)
     else:
         dilation = (dilation, dilation)
     outputs = self.pool_function(
         inputs,
         window_shape=self.pool_size,
         pooling_type="MAX",
         strides=self.strides,
         padding=self.padding.upper(),
         dilation_rate=dilation,
         data_format=utils.convert_data_format(self.data_format, 4))
     return outputs
Example #14
0
    def build(self, input_shape):
        """Create the (optionally weight-normalised) kernel, the bias and the conv op.

        Args:
            input_shape: Shape of the layer input. Its channel dimension (axis
                1 for channels_first, the last axis otherwise) must be known.

        Raises:
            ValueError: If the channel dimension of `input_shape` is `None`.
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        channel_axis = 1 if self.data_format == 'channels_first' else -1

        input_dim = input_shape[channel_axis].value
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        # For a 1-D conv this is (kernel_size, input_dim, filters).
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        raw_kernel = self.add_variable(name='kernel',
                                       shape=kernel_shape,
                                       initializer=self.kernel_initializer,
                                       regularizer=self.kernel_regularizer,
                                       constraint=self.kernel_constraint,
                                       trainable=True,
                                       dtype=self.dtype)

        if not self.weight_norm:
            self.kernel = raw_kernel
        else:
            # Weight normalisation: per-filter gain times the L2-normalised
            # kernel.
            # NOTE(review): normalisation runs over kernel axes [0, 1] only.
            # That is "every axis except the filter axis" for a 3-D (rank-1
            # conv) kernel but not for higher ranks; confirm before using
            # rank > 1.
            gain = self.add_variable(name='wn/g',
                                     shape=(self.filters, ),
                                     initializer=init_ops.ones_initializer(),
                                     dtype=raw_kernel.dtype,
                                     trainable=True)
            scale = tf.reshape(gain, [1, 1, self.filters])
            self.kernel = scale * nn_impl.l2_normalize(raw_kernel, [0, 1])

        self.bias = None
        if self.use_bias:
            self.bias = self.add_variable(name='bias',
                                          shape=(self.filters, ),
                                          initializer=self.bias_initializer,
                                          regularizer=self.bias_regularizer,
                                          constraint=self.bias_constraint,
                                          trainable=True,
                                          dtype=self.dtype)

        self.input_spec = base.InputSpec(ndim=self.rank + 2,
                                         axes={channel_axis: input_dim})
        self._convolution_op = nn_ops.Convolution(
            input_shape,
            filter_shape=self.kernel.get_shape(),
            dilation_rate=self.dilation_rate,
            strides=self.strides,
            padding=self.padding.upper(),
            data_format=utils.convert_data_format(self.data_format,
                                                  self.rank + 2))
        self.built = True
Example #15
0
    def build(self, input_shape):
        """Create posterior mu/rho variables for the kernel and bias, plus the conv op.

        Args:
            input_shape: Shape of the layer input. Its channel dimension must
                be statically known.

        Raises:
            ValueError: If the channel dimension of `input_shape` is `None`.
        """
        input_shape = tf.TensorShape(input_shape)
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        input_dim = input_shape[channel_axis].value
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        self.kernel_mu = self.add_variable('posterior_kernel_mu',
                                           shape=kernel_shape,
                                           initializer=self.kernel_mu_initializer,
                                           trainable=True,
                                           dtype=self.dtype)
        self.kernel_rho = self.add_variable('posterior_kernel_rho',
                                            shape=kernel_shape,
                                            initializer=self.kernel_rho_initializer,
                                            trainable=True,
                                            dtype=self.dtype)

        self.bias_mu = None
        self.bias_rho = None
        if self.use_bias:
            self.bias_mu = self.add_variable('posterior_bias_mu',
                                             shape=[self.filters],
                                             initializer=self.bias_mu_initializer,
                                             dtype=self.dtype,
                                             trainable=True)
            # NOTE(review): unlike bias_mu, no dtype/trainable is passed here,
            # so add_variable's defaults apply.
            self.bias_rho = self.add_variable('posterior_bias_rho',
                                              shape=[self.filters],
                                              initializer=self.bias_rho_initializer)

        self.input_spec = base.InputSpec(ndim=self.rank + 2,
                                         axes={channel_axis: input_dim})
        self._convolution_op = nn_ops.Convolution(
            input_shape,
            filter_shape=self.kernel_mu.get_shape(),
            dilation_rate=self.dilation_rate,
            strides=self.strides,
            padding=self.padding.upper(),
            data_format=utils.convert_data_format(self.data_format,
                                                  self.rank + 2))
        self.built = True
Example #16
0
 def build(self, input_shape):
     """Create the kernel/bias posteriors and priors and the convolution op.

     Args:
         input_shape: Shape of the layer input. Its channel dimension (axis 1
             for channels_first, the last axis otherwise) must be known.

     Raises:
         ValueError: If the channel dimension of `input_shape` is `None`.
     """
     input_shape = tf.TensorShape(input_shape)
     channel_axis = 1 if self.data_format == "channels_first" else -1
     input_dim = tf.compat.dimension_value(input_shape[channel_axis])
     if input_dim is None:
         raise ValueError(
             "The channel dimension of the inputs should be defined. "
             "Found `None`."
         )
     kernel_shape = self.kernel_size + (input_dim, self.filters)
     # If self.dtype is None, build weights using the default dtype.
     dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
     # Must have a posterior kernel.
     self.kernel_posterior = self.kernel_posterior_fn(
         dtype, kernel_shape, "kernel_posterior", self.trainable, self.add_variable
     )
     if self.kernel_prior_fn is None:
         self.kernel_prior = None
     else:
         self.kernel_prior = self.kernel_prior_fn(
             dtype, kernel_shape, "kernel_prior", self.trainable, self.add_variable
         )
     if self.bias_posterior_fn is None:
         self.bias_posterior = None
     else:
         self.bias_posterior = self.bias_posterior_fn(
             dtype,
             (self.filters,),
             "bias_posterior",
             self.trainable,
             self.add_variable,
         )
     if self.bias_prior_fn is None:
         self.bias_prior = None
     else:
         self.bias_prior = self.bias_prior_fn(
             dtype, (self.filters,), "bias_prior", self.trainable, self.add_variable
         )
     self.input_spec = tf.keras.layers.InputSpec(
         ndim=self.rank + 2, axes={channel_axis: input_dim}
     )
     self._convolution_op = nn_ops.Convolution(
         input_shape,
         filter_shape=tf.TensorShape(kernel_shape),
         dilation_rate=self.dilation_rate,
         strides=self.strides,
         padding=self.padding.upper(),
         data_format=tf_layers_util.convert_data_format(
             self.data_format, self.rank + 2
         ),
     )
     self.built = True
Example #17
0
 def call(self, inputs):
     """Pool `inputs` with `self.pool_function` using 4-D ksize/strides.

     The spatial `pool_size`/`strides` are padded with size-1 entries for
     the non-spatial axes, whose positions depend on `self.data_format`.
     """
     if self.data_format == 'channels_last':
         lead, trail = (1, ), (1, )
     else:
         lead, trail = (1, 1), ()
     ksize = lead + self.pool_size + trail
     full_strides = lead + self.strides + trail
     layout = utils.convert_data_format(self.data_format, 4)
     return self.pool_function(inputs,
                               ksize=ksize,
                               strides=full_strides,
                               padding=self.padding.upper(),
                               data_format=layout)
Example #18
0
 def call(self, inputs):
   """Pool `inputs` with `self.pool_function` over 4-D windows/strides."""
   channels_last = self.data_format == 'channels_last'

   def expand(spatial):
     # Surround the spatial sizes with size-1 entries for the other axes.
     if channels_last:
       return (1,) + spatial + (1,)
     return (1, 1) + spatial

   return self.pool_function(
       inputs,
       ksize=expand(self.pool_size),
       strides=expand(self.strides),
       padding=self.padding.upper(),
       data_format=utils.convert_data_format(self.data_format, 4))
Example #19
0
    def build(self, input_shape):
        """Create or validate the DAU parameter variables, then build the DAU op.

        Each of `dau_weights`, `dau_mu1`, `dau_mu2` and `dau_sigma` is created
        through its `add_<name>_var` method when unset. When it is already
        set, its shape must match `self.get_dau_variable_shape(input_shape)`.

        Args:
            input_shape: Shape of the layer input.

        Raises:
            ValueError: If a pre-set DAU variable has a mismatching shape.
        """
        input_shape = tensor_shape.TensorShape(input_shape)

        expected_shape = self.get_dau_variable_shape(input_shape)
        for param in ('dau_weights', 'dau_mu1', 'dau_mu2', 'dau_sigma'):
            current = getattr(self, param)
            if current is None:
                create = getattr(self, 'add_%s_var' % param)
                setattr(self, param, create(input_shape))
            elif np.any(current.shape != expected_shape):
                raise ValueError('Shape mismatch for variable `%s`' % param)

        self.bias = self.add_bias_var() if self.use_bias else None

        input_channel_axis = self._get_input_channel_axis()
        num_input_channels = self._get_input_channels(input_shape)

        self.input_spec = base.InputSpec(
            ndim=self.rank + 2, axes={input_channel_axis: num_input_channels})

        forbid_positive_dim1 = self.dau_aggregation_forbid_positive_dim1
        self._dau_convolution_op = _DAUConvolution2d(
            input_shape,
            num_output=self.filters,
            dau_units=self.dau_units,
            max_kernel_size=self.max_kernel_size,
            padding=self.padding,
            strides=1,
            num_dau_units_ignore=self.num_dau_units_ignore,
            mu_learning_rate_factor=self.mu_learning_rate_factor,
            dau_unit_border_bound=self.dau_unit_border_bound,
            dau_unit_single_dim=self.dau_unit_single_dim,
            dau_aggregation_forbid_positive_dim1=forbid_positive_dim1,
            unit_testing=self.unit_testing,
            data_format=utils.convert_data_format(self.data_format,
                                                  self.rank + 2))
        self.built = True
Example #20
0
    def build(self, input_shape):
        """Create the weight-normed kernel variables, the bias and the conv op.

        The effective kernel is `kernel_g * tf.nn.l2_normalize(kernel_v)`,
        where `kernel_g` is a single scalar variable.

        Args:
            input_shape: Shape of the layer input. Its channel dimension must
                be statically known.

        Raises:
            ValueError: If the channel dimension of `input_shape` is `None`.
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        # pylint: disable=no-member
        input_dim = input_shape[channel_axis].value
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        # Variables specific to the weight-normed conv class.
        self.kernel_v = self.add_variable(name='kernel_v',
                                          shape=kernel_shape,
                                          initializer=self.kernel_initializer,
                                          regularizer=self.kernel_regularizer,
                                          constraint=self.kernel_constraint,
                                          trainable=True,
                                          dtype=self.dtype)
        self.kernel_g = self.add_variable(name='kernel_g',
                                          shape=[],
                                          trainable=True,
                                          dtype=self.dtype)
        # NOTE(review): l2_normalize is called without an axis, so (assuming
        # TF's default) the whole kernel is normalised at once rather than
        # per output filter; confirm this is intended.
        self.kernel = self.kernel_g * tf.nn.l2_normalize(self.kernel_v)

        self.bias = None
        if self.use_bias:
            self.bias = self.add_variable(name='bias',
                                          shape=(self.filters, ),
                                          initializer=self.bias_initializer,
                                          regularizer=self.bias_regularizer,
                                          constraint=self.bias_constraint,
                                          trainable=True,
                                          dtype=self.dtype)

        self.input_spec = base.InputSpec(ndim=self.rank + 2,
                                         axes={channel_axis: input_dim})
        self._convolution_op = nn_ops.Convolution(
            input_shape,
            filter_shape=self.kernel.get_shape(),
            dilation_rate=self.dilation_rate,
            strides=self.strides,
            padding=self.padding.upper(),
            data_format=utils.convert_data_format(self.data_format,
                                                  self.rank + 2))
        self.built = True
Example #21
0
    def call(self, inputs):
        """Apply the depthwise convolution, then optional bias and activation.

        Args:
          inputs: input tensor laid out according to ``self.data_format``.

        Returns:
          The depthwise-convolved tensor, with ``self.bias`` added when
          ``self.use_bias`` is set and ``self.activation`` applied when present.
        """
        tf_format = utils.convert_data_format(self.data_format, ndim=4)

        # depthwise_conv2d expects a 4-element stride vector; pad the spatial
        # strides with 1s on the batch and channel axes of the active layout.
        if self.data_format != 'channels_last':
            full_strides = (1, 1) + self.strides
        else:
            full_strides = (1, ) + self.strides + (1, )

        result = tf.nn.depthwise_conv2d(
            inputs,
            self.kernel,
            strides=full_strides,
            padding=self.padding.upper(),
            rate=self.dilation_rate,
            data_format=tf_format)

        if self.use_bias:
            result = tf.nn.bias_add(result, self.bias, data_format=tf_format)

        return result if self.activation is None else self.activation(result)
Example #22
0
    def build(self, input_shape):
        """Create (or validate) the DAU parameters and set up the conv op.

        Each DAU parameter slot (weights, mu1, mu2, sigma) is handled in that
        order. If the slot is ``None``, a variable is created through the
        matching ``add_dau_*_var`` helper. Otherwise the shape of the
        pre-assigned value is compared with ``get_dau_variable_shape``.

        Args:
          input_shape: shape of the layer input; converted to a ``TensorShape``.

        Raises:
          ValueError: if a pre-assigned DAU parameter has the wrong shape.
        """
        input_shape = tensor_shape.TensorShape(input_shape)
        expected_shape = self.get_dau_variable_shape(input_shape)

        # (attribute, factory, mismatch message) in the original creation order.
        dau_slots = (
            ('dau_weights',
             lambda: self.add_dau_weights_var(input_shape),
             'Shape mismatch for variable `dau_weights`'),
            ('dau_mu1',
             lambda: self.add_dau_mu1_var(input_shape),
             'Shape mismatch for variable `dau_mu1`'),
            ('dau_mu2',
             lambda: self.add_dau_mu2_var(input_shape),
             'Shape mismatch for variable `dau_mu2`'),
            ('dau_sigma',
             lambda: self.add_dau_sigma_var(
                 input_shape, trainable=self.dau_sigma_trainable),
             'Shape mismatch for variable `dau_sigma`'),
        )
        for attr_name, make_var, mismatch_msg in dau_slots:
            current = getattr(self, attr_name)
            if current is None:
                setattr(self, attr_name, make_var())
            elif np.any(current.shape != expected_shape):
                raise ValueError(mismatch_msg)

        self.bias = self.add_bias_var() if self.use_bias else None

        channel_axis = self._get_input_channel_axis()
        in_channels = self._get_input_channels(input_shape)
        self.input_spec = base.InputSpec(ndim=self.rank + 2,
                                         axes={channel_axis: in_channels})

        # Square max-size kernel footprint used only to configure the op.
        filter_shape = tf.TensorShape(
            (self.max_kernel_size, self.max_kernel_size, in_channels,
             self.filters))
        self._convolution_op = nn_ops.Convolution(
            input_shape,
            filter_shape=filter_shape,
            dilation_rate=(1, 1),
            strides=(self.strides, self.strides),
            padding="SAME",
            data_format=utils.convert_data_format(self.data_format,
                                                  self.rank + 2))
        self.built = True
Example #23
0
  def call(self, inputs):
    """Run the (optionally quantized) convolution, add bias, apply activation.

    Args:
      inputs: input tensor laid out according to ``self.data_format``.

    Returns:
      The convolution output, with ``self.bias`` added when it is not None and
      ``self.activation`` applied when it is not None.
    """
    outputs = nn_ops.convolution(  # modified by Lin
        input=inputs,
        filter=self.kernel,
        dilation_rate=self.dilation_rate,
        strides=self.strides,
        padding=self.padding.upper(),
        quantized=self.quantized, # add by Lin
        quantization_params=self.quantization_params, # add by Lin
        data_format=utils.convert_data_format(self.data_format, self.rank + 2))

    if self.bias is not None:
      if self.data_format == 'channels_first':
        # For channels_first, rank 1 and 2 broadcast a reshaped bias directly;
        # rank 3 is folded into a 4-D tensor so that bias_add (NCHW) can be used.
        if self.rank == 1:
          bias = array_ops.reshape(self.bias, (1, self.filters, 1))
          outputs += bias
        if self.rank == 2:
          bias = array_ops.reshape(self.bias, (1, self.filters, 1, 1))
          outputs += bias
        if self.rank == 3:
          # As of Mar 2017, direct addition is significantly slower than
          # bias_add when computing gradients. To use bias_add, we collapse Z
          # and Y into a single dimension to obtain a 4D input tensor.
          outputs_shape = outputs.shape.as_list()
          # The static batch size is None when it is not known at graph
          # construction time; reshape rejects None, so let it infer that axis.
          if outputs_shape[0] is None:
            outputs_shape[0] = -1
          outputs_4d = array_ops.reshape(outputs,
                                         [outputs_shape[0], outputs_shape[1],
                                          outputs_shape[2] * outputs_shape[3],
                                          outputs_shape[4]])
          outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW')
          outputs = array_ops.reshape(outputs_4d, outputs_shape)
      else:
        outputs = nn.bias_add(outputs, self.bias, data_format='NHWC')

    if self.activation is not None:
      return self.activation(outputs)
    return outputs
    def build(self, input_shape):
        """Create the posterior/prior variables and the convolution op.

        Builds the kernel posterior (required) and the optional kernel prior,
        bias posterior and bias prior via their ``*_fn`` factories. Configures
        the input spec and an ``nn_ops.Convolution`` op. When a bias posterior
        exists, also creates a non-trainable ``bias_center`` variable and
        registers the bias in the client/server variable dicts.

        Args:
          input_shape: shape of the layer input; converted to a ``TensorShape``.

        Raises:
          ValueError: if the channel dimension of ``input_shape`` is unknown.
        """
        input_shape = tf.TensorShape(input_shape)
        if self.data_format == 'channels_first':
            channel_axis = 1
        else:
            channel_axis = -1
        input_dim = tf.compat.dimension_value(input_shape[channel_axis])
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')
        kernel_shape = self.kernel_size + (input_dim, self.filters)

        # If self.dtype is None, build weights using the default dtype.
        dtype = tf.as_dtype(self.dtype or tf.keras.backend.floatx())
        name = 'kernel'

        self.kernel_posterior_fn, self.kernel_prior_fn = \
            self.build_posterior_fn_natural(kernel_shape, dtype, name,
                                            self.kernel_posterior_fn,
                                            self.kernel_prior_fn)

        natural_initializer = natural_initializer_fn(
            loc_stdev=0.1,
            u_scale_init_avg=-5,
            u_scale_init_stdev=0.1,
            untransformed_scale_initializer=self.untransformed_scale_initializer)

        # Must have a posterior kernel.
        self.kernel_posterior = self.kernel_posterior_fn(
            dtype, kernel_shape, 'kernel_posterior',
            self.trainable, self.add_variable,
            natural_initializer=natural_initializer)

        if self.kernel_prior_fn is None:
            self.kernel_prior = None
        else:
            self.kernel_prior = self.kernel_prior_fn(
                dtype, kernel_shape, 'kernel_prior',
                self.trainable, self.add_variable)
        self._built_kernel_divergence = False

        if self.bias_posterior_fn is None:
            self.bias_posterior = None
        else:
            self.bias_posterior = self.bias_posterior_fn(
                dtype, (self.filters,), 'bias_posterior',
                self.trainable, self.add_variable)

        if self.bias_prior_fn is None:
            self.bias_prior = None
        else:
            self.bias_prior = self.bias_prior_fn(
                dtype, (self.filters,), 'bias_prior',
                self.trainable, self.add_variable)
        self._built_bias_divergence = False

        self.input_spec = tf.keras.layers.InputSpec(
            ndim=self.rank + 2, axes={channel_axis: input_dim})
        self._convolution_op = nn_ops.Convolution(
            input_shape,
            filter_shape=tf.TensorShape(kernel_shape),
            dilation_rate=self.dilation_rate,
            strides=self.strides,
            padding=self.padding.upper(),
            data_format=tf_layers_util.convert_data_format(
                self.data_format, self.rank + 2))

        if self.bias_posterior is not None:
            # The center must have the same shape as the conv bias, i.e.
            # (filters,). A conv layer has no `units` attribute.
            self.bias_center = self.add_weight(
                'bias_center',
                shape=[self.filters, ],
                initializer=tf.keras.initializers.constant(0.),
                dtype=dtype,
                trainable=False)
            self.client_variable_dict['bias'] = self.bias_posterior.distribution.loc
            self.server_variable_dict['bias'] = self.bias_posterior.distribution.loc
            self.client_center_variable_dict['bias'] = self.bias_center

        self.built = True
Example #25
0
    def build(self, input_shape):
        """Create (or validate) the DAU parameters and the DAU convolution op.

        For each DAU parameter (weights, mu1, mu2, sigma): if it is ``None``, a
        variable of shape ``get_dau_variable_shape(input_shape)`` is created.
        Only sigma is created non-trainable. Otherwise the ``shape`` of the
        pre-assigned value is checked against that expected shape.

        Args:
          input_shape: shape of the layer input; converted to a ``TensorShape``.

        Raises:
          ValueError: if a pre-assigned DAU parameter has the wrong shape.
        """
        input_shape = tensor_shape.TensorShape(input_shape)

        dau_params_shape = self.get_dau_variable_shape(input_shape)
        # NOTE: the mismatch checks compare each parameter's `.shape`, not the
        # parameter values, against the expected shape.
        if self.dau_weights is None:
            self.dau_weights = self.add_variable(
                name='weights',
                shape=dau_params_shape,
                initializer=self.weight_initializer,
                regularizer=self.weight_regularizer,
                constraint=self.weight_constraint,
                trainable=True,
                dtype=self.dtype)
        elif np.any(self.dau_weights.shape != dau_params_shape):
            raise ValueError('Shape mismatch for variable `dau_weights`')
        if self.dau_mu1 is None:
            self.dau_mu1 = self.add_variable(name='mu1',
                                             shape=dau_params_shape,
                                             initializer=self.mu1_initializer,
                                             regularizer=self.mu1_regularizer,
                                             constraint=self.mu1_constraint,
                                             trainable=True,
                                             dtype=self.dtype)
        elif np.any(self.dau_mu1.shape != dau_params_shape):
            raise ValueError('Shape mismatch for variable `dau_mu1`')

        if self.dau_mu2 is None:
            self.dau_mu2 = self.add_variable(name='mu2',
                                             shape=dau_params_shape,
                                             initializer=self.mu2_initializer,
                                             regularizer=self.mu2_regularizer,
                                             constraint=self.mu2_constraint,
                                             trainable=True,
                                             dtype=self.dtype)
        elif np.any(self.dau_mu2.shape != dau_params_shape):
            raise ValueError('Shape mismatch for variable `dau_mu2`')
        if self.dau_sigma is None:
            self.dau_sigma = self.add_variable(
                name='sigma',
                shape=dau_params_shape,
                initializer=self.sigma_initializer,
                regularizer=self.sigma_regularizer,
                constraint=self.sigma_constraint,
                trainable=False,
                dtype=self.dtype)
        elif np.any(self.dau_sigma.shape != dau_params_shape):
            raise ValueError('Shape mismatch for variable `dau_sigma`')

        if self.use_bias:
            self.bias = self.add_variable(name='bias',
                                          shape=(self.filters, ),
                                          initializer=self.bias_initializer,
                                          regularizer=self.bias_regularizer,
                                          constraint=self.bias_constraint,
                                          trainable=True,
                                          dtype=self.dtype)
        else:
            self.bias = None

        input_channel_axis = self._get_input_channel_axis()
        num_input_channels = self._get_input_channels(input_shape)

        self.input_spec = base.InputSpec(
            ndim=self.rank + 2, axes={input_channel_axis: num_input_channels})

        self._dau_convolution_op = _DAUConvolution2d(
            input_shape,
            num_output=self.filters,
            dau_units=self.dau_units,
            max_kernel_size=self.max_kernel_size,
            padding=self.padding,
            strides=self.strides,
            num_dau_units_ignore=self.num_dau_units_ignore,
            mu_learning_rate_factor=self.mu_learning_rate_factor,
            unit_testing=self.unit_testing,
            data_format=utils.convert_data_format(self.data_format,
                                                  self.rank + 2))
        self.built = True
    def _testConvFlipout(self, layer_class):  # pylint: disable=invalid-name
        """Check a Flipout conv layer against a reference computed in the test.

        The reference output is built as:
          conv(inputs, kernel_posterior.distribution.loc)
          + sign_output * conv(inputs * sign_input, perturbation_kernel)
          + bias_posterior.result_sample
        Outputs, bias, and both KL terms are then compared numerically. The
        arguments recorded by the mocked divergence functions are checked by
        identity.

        Args:
          layer_class: the conv layer class under test; forwarded to
            ``self._testConvSetUp``.
        """
        batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
        with self.cached_session() as sess:
            (kernel_posterior, kernel_prior, kernel_divergence, bias_posterior,
             bias_prior, bias_divergence, layer, inputs, outputs, kl_penalty,
             kernel_shape) = self._testConvSetUp(layer_class,
                                                 batch_size,
                                                 depth=depth,
                                                 height=height,
                                                 width=width,
                                                 channels=channels,
                                                 filters=filters,
                                                 seed=44)

            # Graph-level seed. Presumably needed so the op-level seeds below
            # reproduce the layer's own random draws -- TODO confirm.
            tf.compat.v1.set_random_seed(5995)

            # Reference convolution: SAME padding, test's data format.
            convolution_op = nn_ops.Convolution(
                tf.TensorShape(inputs.shape),
                filter_shape=tf.TensorShape(kernel_shape),
                padding='SAME',
                data_format=tf_layers_util.convert_data_format(
                    self.data_format, inputs.shape.rank))

            # Perturbation kernel: zero-mean Normal with the posterior's scale.
            # NOTE(review): op seed 42 presumably matches the layer's sampling.
            expected_kernel_posterior_affine = tfd.Normal(
                loc=tf.zeros_like(kernel_posterior.result_loc),
                scale=kernel_posterior.result_scale)
            expected_kernel_posterior_affine_tensor = (
                expected_kernel_posterior_affine.sample(seed=42))

            # Deterministic part: convolve with the posterior mean kernel.
            expected_outputs = convolution_op(
                inputs, kernel_posterior.distribution.loc)

            input_shape = tf.shape(input=inputs)
            batch_shape = tf.expand_dims(input_shape[0], 0)
            if self.data_format == 'channels_first':
                channels = input_shape[1]
            else:
                channels = input_shape[-1]
            # Number of axes excluding batch and channel.
            rank = len(inputs.shape) - 2

            # The order of the two seed_stream() calls below presumably has to
            # mirror the layer's implementation so the signs match; keep it.
            seed_stream = tfd.SeedStream(layer.seed, salt='ConvFlipout')

            # Random signs in {-1, +1}: uniform ints in {0, 1} mapped by 2*x - 1.
            # Shape (batch, input channels).
            sign_input = tf.random.uniform(tf.concat(
                [batch_shape, tf.expand_dims(channels, 0)], 0),
                                           minval=0,
                                           maxval=2,
                                           dtype=tf.int64,
                                           seed=seed_stream())
            sign_input = tf.cast(2 * sign_input - 1, inputs.dtype)
            # Shape (batch, filters).
            sign_output = tf.random.uniform(tf.concat(
                [batch_shape, tf.expand_dims(filters, 0)], 0),
                                            minval=0,
                                            maxval=2,
                                            dtype=tf.int64,
                                            seed=seed_stream())
            sign_output = tf.cast(2 * sign_output - 1, inputs.dtype)

            # Insert singleton spatial axes so the sign tensors broadcast
            # against the activations in either data format.
            if self.data_format == 'channels_first':
                for _ in range(rank):
                    sign_input = tf.expand_dims(sign_input,
                                                -1)  # 2D ex: (B, C, 1, 1)
                    sign_output = tf.expand_dims(sign_output, -1)
            else:
                for _ in range(rank):
                    sign_input = tf.expand_dims(sign_input,
                                                1)  # 2D ex: (B, 1, 1, C)
                    sign_output = tf.expand_dims(sign_output, 1)

            # Sign-flipped perturbation term.
            perturbed_inputs = convolution_op(
                inputs * sign_input, expected_kernel_posterior_affine_tensor)
            perturbed_inputs *= sign_output

            expected_outputs += perturbed_inputs
            expected_outputs = tf.nn.bias_add(
                expected_outputs,
                bias_posterior.result_sample,
                data_format=tf_layers_util.convert_data_format(
                    self.data_format, 4))

            [
                expected_outputs_,
                actual_outputs_,
                expected_kernel_divergence_,
                actual_kernel_divergence_,
                expected_bias_,
                actual_bias_,
                expected_bias_divergence_,
                actual_bias_divergence_,
            ] = sess.run([
                expected_outputs,
                outputs,
                kernel_divergence.result,
                kl_penalty[0],
                bias_posterior.result_sample,
                layer.bias_posterior_tensor,
                bias_divergence.result,
                kl_penalty[1],
            ])

            self.assertAllClose(expected_bias_, actual_bias_, rtol=1e-6)
            self.assertAllClose(expected_outputs_, actual_outputs_, rtol=1e-6)
            self.assertAllClose(expected_kernel_divergence_,
                                actual_kernel_divergence_,
                                rtol=1e-6)
            self.assertAllClose(expected_bias_divergence_,
                                actual_bias_divergence_,
                                rtol=1e-6)

            # The third expected kernel-divergence arg is None (unlike the
            # bias case below). Presumably the Flipout layer passes no kernel
            # sample to the divergence fn -- TODO confirm against the layer.
            expected_args = [kernel_posterior, kernel_prior, None]
            # We expect that there was one call to kernel_divergence, with the above
            # args; MockKLDivergence appends the list of args to a list, so the above
            # args should be in the 0th position of that list.
            actual_args = kernel_divergence.args[0]
            # Test for identity with 'is'. TensorFlowTestCase.assertAllEqual actually
            # coerces the inputs to numpy arrays, so we can't use that to assert that
            # the arguments (which are a mixture of Distributions and Tensors) are
            # equal.
            for a, b in zip(expected_args, actual_args):
                self.assertIs(a, b)

            # Same story as above.
            expected_args = [
                bias_posterior, bias_prior, bias_posterior.result_sample
            ]
            actual_args = bias_divergence.args[0]
            for a, b in zip(expected_args, actual_args):
                self.assertIs(a, b)
    def _testConvReparameterization(self, layer_class):  # pylint: disable=invalid-name
        """Check a reparameterization conv layer against a reference in the test.

        The expected output is the SAME-padded convolution of ``inputs`` with
        the kernel posterior sample, plus the bias posterior sample. Kernel,
        bias, output and both KL terms are compared numerically. The arguments
        recorded by the mocked divergence functions are then checked by
        identity.

        Args:
          layer_class: the conv layer class under test; forwarded to
            ``self._testConvSetUp``.
        """
        batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5
        with self.cached_session() as sess:
            setup = self._testConvSetUp(layer_class,
                                        batch_size,
                                        depth=depth,
                                        height=height,
                                        width=width,
                                        channels=channels,
                                        filters=filters)
            (kernel_posterior, kernel_prior, kernel_divergence, bias_posterior,
             bias_prior, bias_divergence, layer, inputs, outputs, kl_penalty,
             kernel_shape) = setup

            reference_conv = nn_ops.Convolution(
                tf.TensorShape(inputs.shape),
                filter_shape=tf.TensorShape(kernel_shape),
                padding='SAME',
                data_format=tf_layers_util.convert_data_format(
                    self.data_format, inputs.shape.rank))
            reference_outputs = tf.nn.bias_add(
                reference_conv(inputs, kernel_posterior.result_sample),
                bias_posterior.result_sample,
                data_format=tf_layers_util.convert_data_format(
                    self.data_format, 4))

            # (expected, actual) fetch pairs, listed in the order they are
            # asserted; all of them are evaluated in a single session run.
            fetch_pairs = [
                (kernel_posterior.result_sample, layer.kernel_posterior_tensor),
                (bias_posterior.result_sample, layer.bias_posterior_tensor),
                (reference_outputs, outputs),
                (kernel_divergence.result, kl_penalty[0]),
                (bias_divergence.result, kl_penalty[1]),
            ]
            for want_value, got_value in sess.run(fetch_pairs):
                self.assertAllClose(want_value, got_value, rtol=1e-6)

            # Each mocked divergence records the args of every call; index 0
            # is the single call made while building the layer. Compare with
            # `is`: assertAllEqual would coerce the mix of Distributions and
            # Tensors to numpy arrays.
            identity_checks = (
                ((kernel_posterior, kernel_prior,
                  kernel_posterior.result_sample), kernel_divergence),
                ((bias_posterior, bias_prior,
                  bias_posterior.result_sample), bias_divergence),
            )
            for wanted, mock_divergence in identity_checks:
                for want, got in zip(wanted, mock_divergence.args[0]):
                    self.assertIs(want, got)
Example #28
0
    def call(self, inputs):
        """Apply the signal convolution/correlation, then bias and activation.

        The pipeline is:
          1. Explicit padding for ``same_zeros`` / ``same_reflect``.
          2. A valid convolution or correlation, built from whichever TF op
             supports the requested combination of ``corr``, ``strides_up``,
             ``strides_down`` and ``channel_separable``.
          3. Optional bias addition.
          4. Optional activation.

        Args:
          inputs: input tensor; converted to ``self.dtype``.

        Returns:
          The output tensor. Outside eager mode its static shape is set from
          ``self.compute_output_shape(inputs.shape)``.

        Raises:
          NotImplementedError: if no supported TF op implements the requested
            argument combination.
        """
        inputs = ops.convert_to_tensor(inputs, dtype=self.dtype)
        input_shape = array_ops.shape(inputs)
        outputs = inputs

        # First, perform any requested padding.
        if self.padding in ("same_zeros", "same_reflect"):
            padding = padding_ops.same_padding_for_kernel(
                self.kernel_support, self.corr, self.strides_up)
            if self.data_format == "channels_last":
                padding = [[0, 0]] + list(padding) + [[0, 0]]
            else:
                padding = [[0, 0], [0, 0]] + list(padding)
            outputs = array_ops.pad(outputs, padding, self._pad_mode)

        # Now, perform valid convolutions/correlations.

        # Not for all possible combinations of (`kernel_support`, `corr`,
        # `strides_up`, `strides_down`) TF ops exist. We implement some additional
        # combinations by manipulating the kernels and toggling `corr`.
        kernel = self.kernel
        corr = self.corr

        # If a convolution with no upsampling is desired, we flip the kernels and
        # use cross correlation to implement it, provided the kernels are odd-length
        # in every dimension (with even-length kernels, the boundary handling
        # would have to change, so we'll throw an error instead).
        if (not corr and all(s == 1 for s in self.strides_up)
                and all(s % 2 == 1 for s in self.kernel_support)):
            corr = True
            slices = self._rank * (slice(None, None,
                                         -1), ) + 2 * (slice(None), )
            kernel = kernel[slices]

        # Similarly, we can implement a cross correlation with no downsampling using
        # convolutions. However, we do this only if upsampling is requested, as we
        # are wasting computation in the boundaries whenever we call the transpose
        # convolution ops.
        if (corr and all(s == 1 for s in self.strides_down)
                and any(s != 1 for s in self.strides_up)
                and all(s % 2 == 1 for s in self.kernel_support)):
            corr = False
            slices = self._rank * (slice(None, None,
                                         -1), ) + 2 * (slice(None), )
            kernel = kernel[slices]

        data_format = utils.convert_data_format(self.data_format,
                                                self._rank + 2)
        if (corr and self.channel_separable and self._rank == 2
                and all(s == 1 for s in self.strides_up)
                and all(s == self.strides_down[0] for s in self.strides_down)):
            # `nn.depthwise_conv2d_native` performs channel-separable correlations
            # followed by optional downsampling.
            outputs = nn.depthwise_conv2d_native(outputs,
                                                 kernel,
                                                 strides=self._pad_strides(
                                                     self.strides_down),
                                                 padding="VALID",
                                                 data_format=data_format)
        elif (corr and all(s == 1 for s in self.strides_up)
              and not self.channel_separable):
            # `nn.convolution` performs correlations followed by optional
            # downsampling.
            outputs = nn.convolution(outputs,
                                     kernel,
                                     strides=self.strides_down,
                                     padding="VALID",
                                     data_format=data_format)
        elif (not corr and all(s == 1 for s in self.strides_down)
              and ((not self.channel_separable and 1 <= self._rank <= 3) or
                   (self.channel_separable and self.filters == 1
                    and self._rank == 2 and all(s == self.strides_up[0]
                                                for s in self.strides_up)))):
            # `nn.conv?d_transpose` perform convolutions, preceded by optional
            # upsampling. Generally, they increase the spatial support of their
            # inputs, so in order to implement 'valid', we need to crop their outputs.

            # Transpose convolutions expect the output and input channels in reversed
            # order. We implement this by swapping those dimensions of the kernel.
            # For channel separable convolutions, we can't currently perform anything
            # other than one filter per channel, so the last dimension needs to be of
            # length one. Since this happens to be the format that the op expects it,
            # we can skip the transpose in that case.
            if not self.channel_separable:
                kernel = array_ops.transpose(
                    kernel,
                    list(range(self._rank)) + [self._rank + 1, self._rank])

            # Compute shape of temporary.
            pad_shape = array_ops.shape(outputs)
            temp_shape = [pad_shape[0]] + (self._rank + 1) * [None]
            if self.data_format == "channels_last":
                spatial_axes = range(1, self._rank + 1)
                if self.channel_separable:
                    temp_shape[-1] = input_shape[-1]
                else:
                    temp_shape[-1] = self.filters
            else:
                spatial_axes = range(2, self._rank + 2)
                if self.channel_separable:
                    temp_shape[1] = input_shape[1]
                else:
                    temp_shape[1] = self.filters
            if self.extra_pad_end:
                get_length = lambda l, s, k: l * s + (k - 1)
            else:
                get_length = lambda l, s, k: l * s + (k - s)
            for i, a in enumerate(spatial_axes):
                temp_shape[a] = get_length(pad_shape[a], self.strides_up[i],
                                           self.kernel_support[i])

            # Compute convolution.
            if self._rank == 1 and not self.channel_separable:
                # There's no 1D transpose convolution op, so we insert an extra
                # dimension and use 2D.
                extradim = {
                    "channels_first": 2,
                    "channels_last": 1
                }[self.data_format]
                strides = self._pad_strides(self.strides_up)
                temp = array_ops.squeeze(
                    nn.conv2d_transpose(
                        array_ops.expand_dims(outputs, extradim),
                        array_ops.expand_dims(kernel, 0),
                        temp_shape[:extradim] + [1] + temp_shape[extradim:],
                        strides=strides[:extradim] + (1, ) +
                        strides[extradim:],
                        padding="VALID",
                        data_format=data_format.replace("W", "HW")),
                    [extradim])
            elif self._rank == 2 and self.channel_separable:
                temp = nn.depthwise_conv2d_native_backprop_input(
                    temp_shape,
                    kernel,
                    outputs,
                    strides=self._pad_strides(self.strides_up),
                    padding="VALID",
                    data_format=data_format)
            elif self._rank == 2 and not self.channel_separable:
                temp = nn.conv2d_transpose(outputs,
                                           kernel,
                                           temp_shape,
                                           strides=self._pad_strides(
                                               self.strides_up),
                                           padding="VALID",
                                           data_format=data_format)
            elif self._rank == 3 and not self.channel_separable:
                temp = nn.conv3d_transpose(outputs,
                                           kernel,
                                           temp_shape,
                                           strides=self._pad_strides(
                                               self.strides_up),
                                           padding="VALID",
                                           data_format=data_format)
            else:
                # Unreachable: the enclosing `elif` only admits the cases handled
                # above. Raise explicitly so the guard is not stripped by `-O`.
                raise AssertionError(
                    "Unhandled transpose-convolution configuration.")

            # Perform crop.
            slices = [slice(None)] * (self._rank + 2)
            if self.padding == "valid":
                # Take `kernel_support - 1` samples away from both sides. This leaves
                # just samples computed without padding.
                for i, a in enumerate(spatial_axes):
                    slices[a] = slice(
                        self.kernel_support[i] - 1,
                        None if self.kernel_support[i] == 1 else 1 -
                        self.kernel_support[i])
            else:
                # Take `kernel_support // 2` plus the padding away from beginning, and
                # crop end to input length multiplied by upsampling factor.
                for i, a in enumerate(spatial_axes):
                    offset = padding[a][0] * self.strides_up[i]
                    offset += self.kernel_support[i] // 2
                    length = get_length(input_shape[a], self.strides_up[i],
                                        offset + 1)
                    slices[a] = slice(offset, length)
            outputs = temp[tuple(slices)]
        else:
            raise NotImplementedError(
                "The provided combination of SignalConv arguments is not currently "
                "implemented (kernel_support={}, corr={}, strides_down={}, "
                "strides_up={}, channel_separable={}, filters={}). "
                "Try using odd-length kernels or turning off separability?".
                format(self.kernel_support, self.corr, self.strides_down,
                       self.strides_up, self.channel_separable, self.filters))

        # Now, add bias if requested.
        if self.bias is not None:
            if self.data_format == "channels_first":
                # As of Mar 2017, direct addition is significantly slower than
                # bias_add when computing gradients.
                if self._rank == 1:
                    # nn.bias_add does not accept a 1D input tensor.
                    outputs = array_ops.expand_dims(outputs, 2)
                    outputs = nn.bias_add(outputs,
                                          self.bias,
                                          data_format="NCHW")
                    outputs = array_ops.squeeze(outputs, [2])
                elif self._rank == 2:
                    outputs = nn.bias_add(outputs,
                                          self.bias,
                                          data_format="NCHW")
                elif self._rank >= 3:
                    # Collapse the trailing spatial axes into one so bias_add sees
                    # a 4-D NCHW tensor. `shape` is a 1-D int tensor, so it must be
                    # extended with concat; `shape[:3] + [-1]` would instead add
                    # -1 elementwise and yield a wrong target shape.
                    shape = array_ops.shape(outputs)
                    outputs = array_ops.reshape(
                        outputs, array_ops.concat([shape[:3], [-1]], 0))
                    outputs = nn.bias_add(outputs,
                                          self.bias,
                                          data_format="NCHW")
                    outputs = array_ops.reshape(outputs, shape)
            else:
                outputs = nn.bias_add(outputs, self.bias)

        # Finally, pass through activation function if requested.
        if self.activation is not None:
            outputs = self.activation(outputs)  # pylint:disable=not-callable

        # Aid shape inference, for some reason shape info is not always available.
        if not context.executing_eagerly():
            outputs.set_shape(self.compute_output_shape(inputs.shape))

        return outputs