def fn(x):
    """Upsample `x` with a separable pair of transposed convolutions.

    First `kernel_3d` is applied as a transposed convolution with strides
    (2, 1), growing only axis 1. Then a copy with its first two axes swapped
    (`permute_dimensions(..., (1, 0, 2, 3))`) is applied with strides (1, 2),
    growing only axis 2. The result is multiplied by 4.

    Target sizes come from `conv_utils.conv_input_length` with a filter size
    of 5 and stride 2.

    `kernel_3d` and `padding_mode` are read from the enclosing scope.
    NOTE(review): this assumes `kernel_3d` has 5 taps along its first axis,
    to match the hard-coded 5. Confirm where it is defined.
    """
    dims = K.shape(x)
    batch, rows, cols, chans = dims[0], dims[1], dims[2], dims[3]
    grown_rows = conv_utils.conv_input_length(rows, 5, padding=padding_mode, stride=2)
    grown_cols = conv_utils.conv_input_length(cols, 5, padding=padding_mode, stride=2)

    # Pass 1: grow axis 1 only.
    tall = K.conv2d_transpose(x, kernel_3d, (batch, grown_rows, cols, chans),
                              strides=(2, 1), padding=padding_mode)
    # Pass 2: same taps with the spatial axes swapped; grow axis 2 only.
    swapped_kernel = K.permute_dimensions(kernel_3d, (1, 0, 2, 3))
    wide = K.conv2d_transpose(tall, swapped_kernel, (batch, grown_rows, grown_cols, chans),
                              strides=(1, 2), padding=padding_mode)
    return 4 * wide
def call(self, x):
    """Gated transposed convolution.

    Computes `features = LeakyReLU(deconv(x) + 0)` and
    `gate = sigmoid(deconv(x) + 0)` with the same kernel. Returns
    `[features * gate, gate]`, tagging both outputs with `_keras_shape` taken
    from `self.compute_output_shape(x.shape)`.

    NOTE(review): the kernel (`K.random_uniform_variable`) and both biases
    (`K.zeros`) are created anew on every call. They are not registered via
    `add_weight`, so they are presumably neither trained nor stable across
    calls. This likely belongs in `build()`; confirm.
    """
    source = x
    out_channels = self.out_shape[-1]
    kernel_dims = (self.kernel_size[0], self.kernel_size[1], out_channels, int(x.get_shape()[-1]))
    kernel = K.random_uniform_variable(shape=kernel_dims, low=0, high=1)

    # Feature branch.
    features = K.conv2d_transpose(x, kernel=kernel, strides=self.strides,
                                  output_shape=self.out_shape, padding='same')
    feature_bias = K.zeros(shape=(self.out_shape[-1]))
    features = K.reshape(K.bias_add(features, feature_bias), features.get_shape())
    features = LeakyReLU()(features)

    # Gate branch: same kernel, reshaped to the feature branch's static shape.
    gate = K.conv2d_transpose(source, kernel, output_shape=self.out_shape,
                              strides=self.strides, padding='same')
    gate_bias = K.zeros(shape=(self.out_shape[-1]))
    gate = K.reshape(K.bias_add(gate, gate_bias), features.get_shape())
    gate = K.sigmoid(gate)

    gated = tf.multiply(features, gate)

    tagged = [gated, gate]
    static_shapes = self.compute_output_shape(x.shape)
    for tensor, static_shape in zip(tagged, static_shapes):
        tensor._keras_shape = static_shape
    return [gated, gate]
def gconv2d(x, kernel, gconv_indices, gconv_shape_info, strides=(1, 1),
            padding='valid', data_format=None, dilation_rate=(1, 1),
            transpose=False, output_shape=None):
    """2D group equivariant convolution.

    The kernel is first expanded with `transform_filter_2d_nhwc`. It is then
    applied with a regular 2D convolution, or with a transposed 2D
    convolution when `transpose` is True.

    # Arguments
        x: Tensor or variable.
        kernel: kernel tensor, passed to `transform_filter_2d_nhwc` as `w`.
        gconv_indices: flat indices forwarded to `transform_filter_2d_nhwc`.
        gconv_shape_info: shape info forwarded to `transform_filter_2d_nhwc`.
        strides: strides tuple.
        padding: string, `"same"` or `"valid"`.
        data_format: string, `"channels_last"` or `"channels_first"`.
            Whether to use Theano or TensorFlow data format
            for inputs/kernels/outputs.
        dilation_rate: tuple of 2 integers. Only used by the regular
            (non-transposed) convolution.
        transpose: if True, compute a transposed convolution.
        output_shape: required when `transpose` is True. Only entries 1-3
            are read. The batch dimension is always taken from
            `K.shape(x)[0]`.

    # Returns
        A tensor, result of 2D convolution.

    # Raises
        ValueError: if `transpose` is True and `output_shape` is None.
        ValueError: if `data_format` is neither
            `channels_last` or `channels_first`.
    """
    # Transform the filters once; both branches use the same expansion.
    transformed_filter = transform_filter_2d_nhwc(w=kernel, flat_indices=gconv_indices,
                                                  shape_info=gconv_shape_info)
    if transpose:
        if output_shape is None:
            raise ValueError('`output_shape` is required when `transpose=True`.')
        # Use the dynamic batch size so any batch size works at run time.
        output_shape = (K.shape(x)[0], output_shape[1], output_shape[2], output_shape[3])
        # conv2d_transpose kernels are (rows, cols, out, in): swap the channel axes.
        transposed_filter = K.permute_dimensions(transformed_filter, [0, 1, 3, 2])
        return K.conv2d_transpose(x=x, kernel=transposed_filter, output_shape=output_shape,
                                  strides=strides, padding=padding, data_format=data_format)
    return K.conv2d(x=x, kernel=transformed_filter, strides=strides, padding=padding,
                    data_format=data_format, dilation_rate=dilation_rate)
def call(self, inputs):
    """Run `self.unroll_length` coupled update steps on two channel halves.

    `inputs` is split in half along the last axis into `tmp` and `otmp`.
    Each step computes two quantities from the values at the start of the
    step:

    - `y = conv2d(otmp) [+ bias]`
    - `x = -conv2d_transpose(tmp) [+ bias2]`

    It then updates `tmp += h * act(y)` and `otmp += h * act(x)`. The two
    halves are concatenated back along the last axis.

    When `self.activation` is None the identity is used, so the updates still
    take effect. Previously the loop silently did nothing in that case.
    """
    h = self.h
    activation = self.activation if self.activation is not None else (lambda t: t)
    tmp, otmp = tf.split(inputs, num_or_size_splits=2, axis=-1)
    for _ in range(self.unroll_length):
        y_f_tmp = K.conv2d(otmp, self.kernel, strides=self.strides, padding=self.padding,
                           data_format=self.data_format, dilation_rate=self.dilation_rate)
        if self.use_bias:
            y_f_tmp = K.bias_add(y_f_tmp, self.bias, data_format=self.data_format)
        # Output shape equals tmp's shape, so both halves keep their size.
        x_f_tmp = -K.conv2d_transpose(x=tmp, kernel=self.kernel, strides=self.strides,
                                      output_shape=K.shape(tmp), padding=self.padding,
                                      data_format=self.data_format)
        if self.use_bias:
            x_f_tmp = K.bias_add(x_f_tmp, self.bias2, data_format=self.data_format)
        tmp = tmp + h * activation(y_f_tmp)
        otmp = otmp + h * activation(x_f_tmp)
    return K.concatenate([tmp, otmp], axis=-1)
def call(self, inputs):
    """2D transposed convolution whose kernel is produced by `self.U()`.

    The kernel from `self.U()` is repeated `self.input_channels` times along
    kernel axis 3 before use. That axis is the input-channel axis of a
    `K.conv2d_transpose` kernel, (rows, cols, out, in), whatever the
    `data_format`.
    NOTE(review): presumably `self.U()` returns a 4D kernel with a singleton
    input-channel axis. Confirm.

    Returns the convolution output, with optional bias and activation.
    """
    input_shape = K.shape(inputs)
    batch_size = input_shape[0]
    if self.data_format == 'channels_first':
        h_axis, w_axis = 2, 3
    else:
        h_axis, w_axis = 1, 2
    height, width = input_shape[h_axis], input_shape[w_axis]
    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides
    if self.output_padding is None:
        out_pad_h = out_pad_w = None
    else:
        out_pad_h, out_pad_w = self.output_padding

    # Infer the dynamic output shape:
    out_height = conv_utils.deconv_length(height, stride_h, kernel_h, self.padding,
                                          out_pad_h, self.dilation_rate[0])
    out_width = conv_utils.deconv_length(width, stride_w, kernel_w, self.padding,
                                         out_pad_w, self.dilation_rate[1])
    if self.data_format == 'channels_first':
        output_shape = (batch_size, self.filters, out_height, out_width)
    else:
        output_shape = (batch_size, out_height, out_width, self.filters)

    # Tile along the kernel's input-channel axis (3). The old code used the
    # *input tensor's* channel axis, which is 1 for channels_first and would
    # have tiled the kernel's width instead.
    kernel = K.repeat_elements(self.U(), self.input_channels, axis=3)
    outputs = K.conv2d_transpose(inputs, kernel, output_shape, self.strides,
                                 padding=self.padding, data_format=self.data_format,
                                 dilation_rate=self.dilation_rate)
    if self.use_bias:
        outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
    if self.activation is not None:
        return self.activation(outputs)
    return outputs
def call(self, x):
    """Transposed convolution with `self._W` and 'same' padding.

    Strides are `self._upscaling_factors`. The output shape is the dynamic
    batch size followed by whatever `self.compute_output_shape` returns for
    the static input shape.
    NOTE(review): this assumes `compute_output_shape` returns the per-sample
    shape without the batch axis. Confirm.
    """
    per_sample = self.compute_output_shape(x.shape.as_list())
    full_shape = (K.shape(x)[0],) + tuple(per_sample)
    stride_tuple = tuple(self._upscaling_factors)
    return K.conv2d_transpose(x, self._W, output_shape=full_shape,
                              strides=stride_tuple, padding="same")
def call(self, inputs):
    """Transposed convolution using a (possibly user-supplied) cached output shape.

    If `self.transpose_output_shape` is None, the output shape is inferred
    from the input's static spatial size with `conv_utils.deconv_length` and
    stored on `self.transpose_output_shape`. Otherwise the stored shape is
    reused. The output is reshaped to that shape and passed through
    `self.activation` when one is set. No bias is applied; the bias code is
    commented out.

    NOTE(review): the inferred shape stores `K.shape(inputs)[0]` from the
    *first* call. Later calls (a different graph or batch) reuse that stale
    tensor. Confirm this is intended.
    NOTE(review): the nesting below was reconstructed from a
    whitespace-collapsed source. Confirm against version control which of the
    channels_first / None-batch statements belong inside the `else` branch.
    """
    if self.transpose_output_shape is None:
        input_shape = K.shape(inputs)
        input_shape_list = inputs.get_shape().as_list()
        batch_size = input_shape[0]
        if self.data_format == 'channels_first':
            h_axis, w_axis = 2, 3
        else:
            h_axis, w_axis = 1, 2
        # Static sizes. These may be None, in which case deconv_length
        # presumably returns None too (TODO confirm).
        height, width = input_shape_list[h_axis], input_shape_list[w_axis]
        kernel_h, kernel_w = self.kernel_size
        stride_h, stride_w = self.strides
        # Infer the dynamic output shape:
        out_height = conv_utils.deconv_length(height, stride_h, kernel_h, self.padding)
        out_width = conv_utils.deconv_length(width, stride_w, kernel_w, self.padding)
        if self.data_format == 'channels_first':
            self.transpose_output_shape = (batch_size, self.filters, out_height, out_width)
        else:
            self.transpose_output_shape = (batch_size, out_height, out_width, self.filters)
        shape = self.transpose_output_shape
    else:
        shape = self.transpose_output_shape
        # Reorder a channels_first shape into (batch, h, w, channels) order.
        # NOTE(review): `K.conv2d_transpose` receives data_format=channels_first
        # below, so check this does not double-permute.
        if self.data_format == 'channels_first':
            shape = (shape[0], shape[2], shape[3], shape[1])
        # Fill an unknown batch dimension with the dynamic batch size.
        if shape[0] is None:
            shape = (tf.shape(inputs)[0], ) + tuple(shape[1:])
            shape = tf.stack(list(shape))
    outputs = K.conv2d_transpose(x=inputs, kernel=self.kernel, output_shape=shape,
                                 strides=self.strides, padding=self.padding,
                                 data_format=self.data_format)
    # Reshape to the requested shape (also pins the output shape).
    outputs = tf.reshape(outputs, shape)
    # if self.bias:
    #     outputs = K.bias_add(
    #         outputs,
    #         self.bias,
    #         data_format=self.data_format)
    if self.activation is not None:
        return self.activation(outputs)
    return outputs
def call(self, inputs):
    """Transposed convolution that removes the bias *before* the kernel.

    When `use_bias` is set, `-self.bias` is added to the inputs, and the
    result goes through `K.conv2d_transpose` with the inferred dynamic output
    shape. `self.activation` is applied when present.
    NOTE(review): presumably this inverts a Conv2D's bias-add (a tied
    decoder). Confirm with the layer's users.
    """
    dyn = K.shape(inputs)
    channels_first = self.data_format == 'channels_first'
    rows_axis, cols_axis = (2, 3) if channels_first else (1, 2)
    k_rows, k_cols = self.kernel_size
    s_rows, s_cols = self.strides
    pad_rows, pad_cols = (None, None) if self.output_padding is None else self.output_padding

    # Output spatial sizes from Keras' transposed-conv arithmetic.
    out_rows = conv_utils.deconv_length(dyn[rows_axis], s_rows, k_rows, self.padding,
                                        pad_rows, self.dilation_rate[0])
    out_cols = conv_utils.deconv_length(dyn[cols_axis], s_cols, k_cols, self.padding,
                                        pad_cols, self.dilation_rate[1])
    batch = dyn[0]
    if channels_first:
        target = (batch, self.filters, out_rows, out_cols)
    else:
        target = (batch, out_rows, out_cols, self.filters)

    shifted = K.bias_add(inputs, -self.bias, data_format=self.data_format) if self.use_bias else inputs
    result = K.conv2d_transpose(shifted, self.kernel, target, self.strides,
                                padding=self.padding, data_format=self.data_format,
                                dilation_rate=self.dilation_rate)
    return result if self.activation is None else self.activation(result)
def call(self, inputs):
    """Apply the fixed `dct_kernel`, then multiply by `self.scale_kernel`.

    With `self.keep_dims` this is a plain `K.conv2d` using the backend's
    default strides and padding. Otherwise it is a `K.conv2d_transpose` whose
    output is (batch, dim1 + 7, dim2 + 7, 1).
    NOTE(review): the `+ 7` presumably matches an 8x8 DCT kernel, and
    channels_last inputs are assumed. Confirm where `dct_kernel` is built.
    """
    dims = K.shape(inputs)
    if not self.keep_dims:
        grown = (dims[0], dims[1] + 7, dims[2] + 7, 1)
        transformed = K.conv2d_transpose(inputs, self.dct_kernel, output_shape=grown)
    else:
        transformed = K.conv2d(inputs, self.dct_kernel)
    return transformed * self.scale_kernel
def branch(self, z, kernel, bias):
    """One residual branch: conv -> bias -> ReLU -> transposed conv, scaled by `self.h`.

    The same `kernel` is used for the forward convolution and its transpose,
    both with 'same' padding and unit stride. The transposed output is given
    the spatial size of `z` and `self.filters // 2` channels.

    Args:
        z: 4D input tensor (channels_last is assumed by the indexing of
           `K.shape(z)`).
        kernel: convolution kernel shared by both passes.
        bias: bias added after the forward convolution.

    Returns:
        `self.h * conv2d_transpose(relu(conv2d(z) + bias))`.
    """
    out = K.conv2d(z, kernel=kernel, padding='same', strides=(1, 1))
    out = K.bias_add(out, bias)
    out = K.relu(out)
    z_shape = K.shape(z)
    # Integer division: `/` would put a float into output_shape under
    # Python 3, which cannot be stacked with the int32 shape entries.
    tr_shape = (z_shape[0], z_shape[1], z_shape[2], self.filters // 2)
    out = K.conv2d_transpose(out, output_shape=tr_shape, kernel=kernel,
                             padding='same', strides=(1, 1))
    return self.h * out
def max_singular_val_for_convolution(w, u, fully_differentiable=False, ip=1, padding='same',
                                     strides=(1, 1), data_format='channels_last'):
    """Estimate the largest singular value of the convolution with kernel `w`.

    Runs `ip` rounds of power iteration starting from `u`. Each round
    alternates `K.conv2d` with `K.conv2d_transpose` (output shape
    `K.int_shape(u)`) and L2-normalises after each operator.

    The iterations see `K.stop_gradient(w)` unless `fully_differentiable` is
    true. The final transposed convolution always uses `w` itself, so
    `sigma` carries gradients to the kernel.

    Returns:
        (sigma, u_bar): the estimate `sum(u_bar * conv2d_transpose(v_bar, w))`
        and the last normalised left vector.
    """
    assert ip >= 1
    w_iter = w if fully_differentiable else K.stop_gradient(w)
    conv_kwargs = dict(strides=strides, data_format=data_format, padding=padding)

    u_bar = u
    for _ in range(ip):
        v_bar = K.l2_normalize(K.conv2d(u_bar, w_iter, **conv_kwargs))
        back = K.conv2d_transpose(v_bar, w_iter, output_shape=K.int_shape(u), **conv_kwargs)
        u_bar = K.l2_normalize(back)

    back_with_grad = K.conv2d_transpose(v_bar, w, output_shape=K.int_shape(u), **conv_kwargs)
    sigma = K.sum(u_bar * back_with_grad)
    return sigma, u_bar
def call(self, input_tensor, training=None):
    """Deconvolutional capsule layer: upsample each input capsule, then route.

    `input_tensor` is transposed with perm [3, 0, 1, 2, 4]. Its (new) first
    two axes are merged so every (input capsule, sample) pair becomes one
    row of a 4D batch of shape
    (rows, input_height, input_width, input_num_atoms). From the later
    `set_shape`, axis 3 of the input appears to be the input-capsule axis and
    axis 0 the batch axis.

    Each row is upsampled by `self.scaling` in one of three ways:

    - 'resize': resize, then conv.
    - 'subpix': conv, then depth_to_space.
    - otherwise: a transposed conv.

    The rows are then unflattened into votes of shape
    (batch, input_num_capsule, H', W', num_capsule, num_atoms) and routed
    with `update_routing`. `training` is accepted but unused.
    """
    input_transposed = tf.transpose(input_tensor, [3, 0, 1, 2, 4])
    input_shape = K.shape(input_transposed)
    # Row index = i0 * input_shape[1] + i1 (row-major merge of axes 0 and 1).
    input_tensor_reshaped = K.reshape(input_transposed, [
        input_shape[1] * input_shape[0], self.input_height, self.input_width,
        self.input_num_atoms])
    input_tensor_reshaped.set_shape((None, self.input_height, self.input_width,
                                     self.input_num_atoms))

    if self.upsamp_type == 'resize':
        upsamp = K.resize_images(input_tensor_reshaped, self.scaling, self.scaling,
                                 'channels_last')
        outputs = K.conv2d(upsamp, kernel=self.W, strides=(1, 1), padding=self.padding,
                           data_format='channels_last')
    elif self.upsamp_type == 'subpix':
        conv = K.conv2d(input_tensor_reshaped, kernel=self.W, strides=(1, 1), padding='same',
                        data_format='channels_last')
        outputs = tf.depth_to_space(conv, self.scaling)
    else:
        batch_size = input_shape[1] * input_shape[0]
        # Infer the dynamic output shape:
        out_height = deconv_length(self.input_height, self.scaling, self.kernel_size,
                                   self.padding, None)
        out_width = deconv_length(self.input_width, self.scaling, self.kernel_size,
                                  self.padding, None)
        output_shape = (batch_size, out_height, out_width, self.num_capsule * self.num_atoms)
        outputs = K.conv2d_transpose(input_tensor_reshaped, self.W, output_shape,
                                     (self.scaling, self.scaling), padding=self.padding,
                                     data_format='channels_last')

    votes_shape = K.shape(outputs)
    _, conv_height, conv_width, _ = outputs.get_shape()

    # Unflatten in the SAME order the rows were merged (axis 0 of
    # input_transposed first), then swap the two leading axes. Reshaping
    # directly to [input_shape[1], input_shape[0], ...] would scramble
    # samples and capsules whenever both exceed 1.
    votes = K.reshape(outputs, [input_shape[0], input_shape[1], votes_shape[1],
                                votes_shape[2], self.num_capsule, self.num_atoms])
    votes = tf.transpose(votes, [1, 0, 2, 3, 4, 5])
    votes.set_shape((None, self.input_num_capsule, conv_height.value, conv_width.value,
                     self.num_capsule, self.num_atoms))

    logit_shape = K.stack([
        input_shape[1], input_shape[0], votes_shape[1], votes_shape[2], self.num_capsule])
    biases_replicated = K.tile(self.b, [votes_shape[1], votes_shape[2], 1, 1])

    activations = update_routing(
        votes=votes,
        biases=biases_replicated,
        logit_shape=logit_shape,
        num_dims=6,
        input_dim=self.input_num_capsule,
        output_dim=self.num_capsule,
        num_routing=self.routings)
    return activations
def call(self, inputs):
    """1D transposed convolution implemented as a 2D one.

    A singleton axis is inserted at position 1 of both the inputs and the
    (3D) kernel. A 2D transposed convolution with strides
    `(1, self.strides[0])` is applied, and the singleton axis is squeezed
    away again. Only the channels_last layout is handled: the inputs are
    treated as (batch, steps, channels).

    Returns the output, with optional bias and activation.
    """
    dyn = K.shape(inputs)
    (kernel_w,) = self.kernel_size
    (stride,) = self.strides
    (out_pad_w,) = (None,) if self.output_padding is None else self.output_padding

    # Infer the dynamic output length along the steps axis.
    out_width = conv_utils.deconv_length(dyn[1], stride, kernel_w, self.padding,
                                         out_pad_w, self.dilation_rate[0])
    target = (dyn[0], 1, out_width, self.filters)

    lifted_inputs = K.expand_dims(inputs, axis=1)
    lifted_kernel = K.expand_dims(self.kernel, axis=0)
    result = K.conv2d_transpose(lifted_inputs, lifted_kernel, target, (1, self.strides[0]),
                                padding=self.padding, data_format=self.data_format,
                                dilation_rate=(1, self.dilation_rate[0]))
    result = K.squeeze(result, axis=1)

    if self.use_bias:
        result = K.bias_add(result, self.bias, data_format=self.data_format)
    return result if self.activation is None else self.activation(result)
def _deconv_module(feature_layer, deconv_layer, name, feature_size=512):
    """Fuse an upsampled `deconv_layer` with `feature_layer` by element-wise product.

    Deconv path: a stride-2, 2x2 'valid' transposed conv of `deconv_layer`,
    sized to `feature_layer`'s spatial dims, followed by Conv3x3 and BN.
    Feature path: Conv3x3, BN, ReLU, Conv3x3, BN on `feature_layer`.
    The two paths are multiplied element-wise and passed through a ReLU
    named `name`.

    NOTE(review): the transposed-conv kernel is created with `K.variable`
    inside a `Lambda`, so it is presumably not tracked as a trainable weight
    of the model. It is also re-created with random values whenever that
    Lambda is invoked. Confirm it is meant to stay fixed.
    NOTE(review): with kernel 2 / stride 2 / 'valid', the output is 2x the
    input size. This assumes `feature_layer` is exactly twice the spatial
    size of `deconv_layer`.

    Args:
        feature_layer: Keras tensor giving the target spatial size and the
            second path's input.
        deconv_layer: Keras tensor to upsample.
        name: prefix/suffix used for the created layers' names.
        feature_size: channel count for every conv in the module.

    Returns:
        Keras tensor: ReLU(BN(deconv path) * BN(feature path)).
    """
    # deconv1 = Conv2DTranspose(feature_size, kernel_size=(2,2),
    #                           strides=(2, 2), name=name+'Transpose1', padding='same')(deconv_layer)
    #
    # deconv1 = UpSampling2D(size=(2, 2), name=name+'upsampled')(deconv_layer)
    # filter_shape = K.stack([1, 2,2, deconv_layer.shape[3]])
    # Kernel shape (2, 2, feature_size, in_channels), i.e. (rows, cols, out, in)
    # as K.conv2d_transpose expects.
    deconv_filter = Lambda(lambda args: K.variable(
        K.random_normal(K.stack([2, 2, feature_size, args.shape[3]])),
        name=name + 'kernel'))(deconv_layer)
    # output_shape = K.stack([K.shape(feature_layer)[0], feature_layer.shape[1], feature_layer.shape[2], feature_size])
    # args = [deconv_layer, feature_layer, deconv_filter]: the output takes the
    # dynamic batch size and static H/W of feature_layer.
    deconv1 = Lambda(lambda args: K.conv2d_transpose(
        args[0], args[2],
        K.stack([
            K.shape(args[1])[0], args[1].shape[1], args[1].shape[2], feature_size
        ]),
        (2, 2), 'valid'))([deconv_layer, feature_layer, deconv_filter])
    # Deconv path: Conv -> BN (no activation before the product).
    conv1 = Conv2D(feature_size, kernel_size=(3, 3), strides=1, padding='same',
                   name=name + 'Conv1')(deconv1)
    bn1 = BatchNormalization(name='bn_1' + name)(conv1)
    # Feature path: Conv -> BN -> ReLU -> Conv -> BN.
    conv2 = Conv2D(feature_size, kernel_size=(3, 3), strides=1, padding='same',
                   name=name + 'Conv2')(feature_layer)
    bn2 = BatchNormalization(name='bn_2' + name)(conv2)
    relu1 = Activation('relu', name=name + 'Relu2')(bn2)
    conv3 = Conv2D(feature_size, kernel_size=(3, 3), strides=1, padding='same',
                   name=name + 'Conv3')(relu1)
    bn3 = BatchNormalization(name='bn_3' + name)(conv3)
    # Element-wise fusion of the two paths.
    product = Multiply(name='mul_' + name)([bn1, bn3])
    return Activation('relu', name=name)(product)
def call(self, inputs):
    """2D transposed convolution with an optional fixed output shape.

    If `self._output_shape` is None, the output spatial size is inferred
    with `deconv_length`. Otherwise `self._output_shape` (a per-sample shape,
    tuple or list) is used with the dynamic batch size prepended.

    Returns the convolution output, with optional bias and activation.
    """
    input_shape = K.shape(inputs)
    batch_size = input_shape[0]
    if self.data_format == 'channels_first':
        h_axis, w_axis = 2, 3
    else:
        h_axis, w_axis = 1, 2
    height, width = input_shape[h_axis], input_shape[w_axis]
    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides

    if self._output_shape is None:
        # Infer the dynamic output shape:
        out_height = deconv_length(height, stride_h, kernel_h, self.padding)
        out_width = deconv_length(width, stride_w, kernel_w, self.padding)
        if self.data_format == 'channels_first':
            output_shape = (batch_size, self.filters, out_height, out_width)
        else:
            output_shape = (batch_size, out_height, out_width, self.filters)
    else:
        # tuple() so a list-valued _output_shape also concatenates.
        output_shape = (batch_size,) + tuple(self._output_shape)

    outputs = K.conv2d_transpose(inputs, self.kernel, output_shape, self.strides,
                                 padding=self.padding, data_format=self.data_format)
    # Compare with None: truth-testing a bias variable/tensor is either
    # ambiguous (multi-element) or raises in graph mode.
    if self.bias is not None:
        outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
    if self.activation is not None:
        return self.activation(outputs)
    return outputs
def call(self, inputs):
    """1D transposed convolution with reverse-complement weight sharing.

    The effective kernel is `self.kernel` concatenated, along axis -2, with a
    copy reversed along all three of its axes. The layer therefore emits
    `2 * self.filters` channels. The bias is tied the same way, as
    `[bias, reversed(bias)]`.

    The convolution runs as a 2D transposed convolution over a singleton
    axis inserted at position 1, with strides `(1, self.strides[0])`.
    """
    mirrored = self.kernel[::-1, ::-1, ::-1]
    revcomp_kernel = K.concatenate([self.kernel, mirrored], axis=-2)
    if self.use_bias:
        revcomp_bias = K.concatenate([self.bias, self.bias[::-1]], axis=-1)

    dyn = K.shape(inputs)
    (kernel_w,) = self.kernel_size
    (stride,) = self.strides
    (out_pad_w,) = (None,) if self.output_padding is None else self.output_padding

    # Infer the dynamic output length along the steps axis.
    out_width = conv_utils.deconv_length(dyn[1], stride, kernel_w, self.padding,
                                         out_pad_w, self.dilation_rate[0])
    target = (dyn[0], 1, out_width, 2 * self.filters)

    lifted = K.expand_dims(inputs, axis=1)
    result = K.conv2d_transpose(lifted, K.expand_dims(revcomp_kernel, axis=0), target,
                                (1, self.strides[0]), padding=self.padding,
                                data_format=self.data_format,
                                dilation_rate=(1, self.dilation_rate[0]))
    result = K.squeeze(result, axis=1)

    if self.use_bias:
        result = K.bias_add(result, revcomp_bias, data_format=self.data_format)
    return result if self.activation is None else self.activation(result)
def call(self, inputs):
    """2D transposed convolution with output size hard-coded to `input * stride`.

    Each output spatial size is the dynamic input size times the stride on
    that axis. This matches Keras' transposed-conv arithmetic only for
    `padding='same'`. `kernel_size` and `output_padding` are not consulted.

    Returns the convolution output, with optional bias and activation.
    """
    inputs_shape = array_ops.shape(inputs)
    batch_size = inputs_shape[0]
    if self.data_format == 'channels_first':
        h_axis, w_axis = 2, 3
    else:
        h_axis, w_axis = 1, 2

    # Hard-code the output shape because Keras' inference got it wrong here.
    # This only works for padding='same'.
    stride_h, stride_w = self.strides
    out_height = inputs_shape[h_axis] * stride_h
    out_width = inputs_shape[w_axis] * stride_w
    if self.data_format == 'channels_first':
        output_shape = (batch_size, self.filters, out_height, out_width)
    else:
        output_shape = (batch_size, out_height, out_width, self.filters)

    outputs = K.conv2d_transpose(inputs, self.kernel, output_shape, self.strides,
                                 padding=self.padding, data_format=self.data_format,
                                 dilation_rate=self.dilation_rate)
    if self.use_bias:
        outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
    if self.activation is not None:
        return self.activation(outputs)
    return outputs
def conv2d_transpose(
        inputs,
        filter,  # pylint: disable=redefined-builtin
        kernel_size=None,
        filters=None,
        strides=(1, 1),
        padding='same',
        output_padding=None,
        data_format='channels_last'):
    """Compatibility layer for K.conv2d_transpose.

    Take a filter defined for forward convolution and adjust it for a
    transposed convolution. The filter's last two axes are swapped before
    calling `K.conv2d_transpose`, whose kernels are
    (rows, cols, out, in). The dynamic output shape is inferred with
    `conv_utils.deconv_length`.

    Args:
        inputs: 4D input tensor.
        filter: 4D forward-convolution kernel.
        kernel_size: (rows, cols). Defaults to the first two static dims of
            `filter`.
        filters: number of output channels. Defaults to the last static dim
            of `filter`, which becomes the output-channel axis after the
            axis swap.
        strides: (stride_h, stride_w).
        padding: "same" or "valid".
        output_padding: None, a single int applied to both spatial axes, or
            a (pad_h, pad_w) pair.
        data_format: "channels_last" or "channels_first".

    Returns:
        Output tensor of the transposed convolution.
    """
    if kernel_size is None or filters is None:
        filter_shape = K.int_shape(filter)
        if kernel_size is None:
            kernel_size = filter_shape[:2]
        if filters is None:
            filters = filter_shape[3]

    input_shape = K.shape(inputs)
    batch_size = input_shape[0]
    if data_format == 'channels_first':
        h_axis, w_axis = 2, 3
    else:
        h_axis, w_axis = 1, 2

    height, width = input_shape[h_axis], input_shape[w_axis]
    kernel_h, kernel_w = kernel_size
    stride_h, stride_w = strides

    # Split output_padding per axis. A pair used to be passed unchanged to
    # both deconv_length calls, which breaks its arithmetic.
    if output_padding is None:
        out_pad_h = out_pad_w = None
    elif isinstance(output_padding, int):
        out_pad_h = out_pad_w = output_padding
    else:
        out_pad_h, out_pad_w = output_padding

    # Infer the dynamic output shape:
    out_height = conv_utils.deconv_length(height, stride_h, kernel_h, padding, out_pad_h)
    out_width = conv_utils.deconv_length(width, stride_w, kernel_w, padding, out_pad_w)
    if data_format == 'channels_first':
        output_shape = (batch_size, filters, out_height, out_width)
    else:
        output_shape = (batch_size, out_height, out_width, filters)

    filter = K.permute_dimensions(filter, (0, 1, 3, 2))
    return K.conv2d_transpose(inputs, filter, output_shape, strides,
                              padding=padding, data_format=data_format)
def __call__(self, w):
    """Rescale `w` so its power-iteration norm estimate is at most `self.max_k`.

    Runs `self.iterations` rounds of power iteration from a random probe:

    - For 4D kernels, the operator is `K.conv2d(., w)` (with `self.stride`
      and `self.padding`) and its transpose. The probe has shape
      `(1,) + in_shape[1:3] + (in_shape[0],)`.
    - Otherwise, the operator is `K.dot(w, .)` and the probe has shape
      `(w.shape[1], 1)`.

    The estimate is `||A x|| / ||x||`. `w` is divided by
    `max(1, estimate / max_k)`.

    NOTE(review): a new random variable is created on every call, and the
    probe is not normalised between rounds, so large iteration counts may
    overflow or underflow.
    """
    if len(w.shape) == 4:
        conv_opts = dict(strides=self.stride, padding=self.padding)
        probe_shape = (1,) + self.in_shape[1:3] + (self.in_shape[0],)
        x = K.random_normal_variable(shape=probe_shape, mean=0, scale=1)
        for _ in range(self.iterations):
            forward = K.conv2d(x, w, **conv_opts)
            x = K.conv2d_transpose(forward, w, x.shape, **conv_opts)
        image = K.conv2d(x, w, **conv_opts)
    else:
        x = K.random_normal_variable(shape=(int(w.shape[1]), 1), mean=0, scale=1)
        for _ in range(self.iterations):
            forward = K.dot(w, x)
            x = K.dot(K.transpose(w), forward)
        image = K.dot(w, x)
    norm = K.sqrt(K.sum(K.pow(image, 2.0)) / K.sum(K.pow(x, 2.0)))
    return w * (1.0 / K.maximum(1.0, norm / self.max_k))
def call(self, inputs, training=None):
    """Spectrally normalised 2D transposed convolution.

    The kernel is flattened to a (-1, out_dim) matrix. Its top singular
    value `sigma` is estimated with one power-iteration step seeded by
    `self.u`, and the convolution uses `kernel / sigma`.

    Args:
        inputs: 4D input tensor.
        training: Keras training flag. `0`/`False` skips writing the new
            estimate back to `self.u`. Any other value, including the
            default `None`, performs the update before the kernel is used.

    Returns:
        Output tensor, with optional bias and activation.
    """
    input_shape = K.shape(inputs)
    batch_size = input_shape[0]
    if self.data_format == 'channels_first':
        h_axis, w_axis = 2, 3
    else:
        h_axis, w_axis = 1, 2

    height, width = input_shape[h_axis], input_shape[w_axis]
    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides
    if self.output_padding is None:
        out_pad_h = out_pad_w = None
    else:
        out_pad_h, out_pad_w = self.output_padding
    # Infer the dynamic output shape:
    out_height = conv_utils.deconv_length(height, stride_h, kernel_h, self.padding, out_pad_h)
    out_width = conv_utils.deconv_length(width, stride_w, kernel_w, self.padding, out_pad_w)
    if self.data_format == 'channels_first':
        output_shape = (batch_size, self.filters, out_height, out_width)
    else:
        output_shape = (batch_size, out_height, out_width, self.filters)

    # Spectral normalization
    def _l2normalize(v, eps=1e-12):
        return v / (K.sum(v ** 2) ** 0.5 + eps)

    def power_iteration(W, u):
        # According to the paper, one power-iteration step per call suffices.
        _u = u
        _v = _l2normalize(K.dot(_u, K.transpose(W)))
        _u = _l2normalize(K.dot(_v, W))
        return _u, _v

    W_shape = self.kernel.shape.as_list()
    # Flatten the kernel to a matrix.
    W_reshaped = K.reshape(self.kernel, [-1, W_shape[-1]])
    _u, _v = power_iteration(W_reshaped, self.u)
    # sigma = v . W . u^T
    sigma = K.dot(_v, W_reshaped)
    sigma = K.dot(sigma, K.transpose(_u))
    W_bar = W_reshaped / sigma
    # Reshape back; outside inference, also persist the new u estimate.
    if training in {0, False}:
        W_bar = K.reshape(W_bar, W_shape)
    else:
        with tf.control_dependencies([self.u.assign(_u)]):
            W_bar = K.reshape(W_bar, W_shape)

    # Use the normalised kernel locally. Assigning it to self.kernel would
    # replace the trainable variable with a tensor and re-normalise an
    # already-normalised kernel on every subsequent call.
    outputs = K.conv2d_transpose(inputs, W_bar, output_shape, self.strides,
                                 padding=self.padding, data_format=self.data_format)
    if self.use_bias:
        outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
    if self.activation is not None:
        return self.activation(outputs)
    return outputs
def call(self, inputs, training=None):
    """Spectrally normalised 2D transposed convolution.

    The kernel is flattened to (-1, out_dim), and `sigma = v W u^T` is
    estimated via `power_iteration(matrix, self.u)`. The convolution runs
    with `kernel / sigma`. When `training` is truthy, `self.u` is updated
    (through a control dependency) before the normalised kernel is produced.

    Returns the output, with optional bias and activation.
    """
    dyn = K.shape(inputs)
    channels_first = self.data_format == 'channels_first'
    h_axis, w_axis = (2, 3) if channels_first else (1, 2)
    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides
    out_pad_h, out_pad_w = (None, None) if self.output_padding is None else self.output_padding

    # Output spatial sizes from Keras' transposed-conv arithmetic.
    out_height = conv_utils.deconv_length(dyn[h_axis], stride_h, kernel_h, self.padding, out_pad_h)
    out_width = conv_utils.deconv_length(dyn[w_axis], stride_w, kernel_w, self.padding, out_pad_w)
    batch = dyn[0]
    if channels_first:
        output_shape = (batch, self.filters, out_height, out_width)
    else:
        output_shape = (batch, out_height, out_width, self.filters)

    kernel_shape = self.kernel.shape.as_list()
    kernel_matrix = K.reshape(self.kernel, [-1, kernel_shape[-1]])
    new_u, new_v = power_iteration(kernel_matrix, self.u)
    sigma = K.dot(K.dot(new_v, kernel_matrix), K.transpose(new_u))
    normalised = kernel_matrix / sigma

    if training:
        with tf.control_dependencies([self.u.assign(new_u)]):
            kernel_bar = K.reshape(normalised, kernel_shape)
    else:
        kernel_bar = K.reshape(normalised, kernel_shape)

    result = K.conv2d_transpose(inputs, kernel_bar, output_shape, self.strides,
                                padding=self.padding, data_format=self.data_format)
    if self.use_bias:
        result = K.bias_add(result, self.bias, data_format=self.data_format)
    return result if self.activation is None else self.activation(result)
def call(self, inputs):
    """Apply the 2D transposed convolution, then optional bias and activation.

    Spatial sizes come from the static input shape when known, and from
    `tf.shape(inputs)` otherwise, so static shape information is kept where
    possible. Outside eager mode, the static output shape from
    `compute_output_shape` is re-attached to the result.
    """
    dynamic_shape = tf.shape(inputs)
    if self.data_format == 'channels_first':
        h_axis, w_axis = 2, 3
    else:
        h_axis, w_axis = 1, 2

    # TODO(scottzhu): Extract this into a utility function that can be applied
    # to all convolutional layers, which currently lost the static shape
    # information due to tf.shape().
    static_dims = inputs.shape.as_list() if inputs.shape.rank is not None else None

    def _spatial(axis):
        # Static size when known, otherwise the dynamic scalar tensor.
        if static_dims is not None and static_dims[axis] is not None:
            return static_dims[axis]
        return dynamic_shape[axis]

    height = _spatial(h_axis)
    width = _spatial(w_axis)

    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.strides
    if self.output_padding is None:
        out_pad_h = out_pad_w = None
    else:
        out_pad_h, out_pad_w = self.output_padding

    # Infer the dynamic output shape:
    out_height = conv_utils.deconv_output_length(
        height, kernel_h, padding=self.padding, output_padding=out_pad_h,
        stride=stride_h, dilation=self.dilation_rate[0])
    out_width = conv_utils.deconv_output_length(
        width, kernel_w, padding=self.padding, output_padding=out_pad_w,
        stride=stride_w, dilation=self.dilation_rate[1])

    batch_size = dynamic_shape[0]
    if self.data_format == 'channels_first':
        output_shape = (batch_size, self.filters, out_height, out_width)
    else:
        output_shape = (batch_size, out_height, out_width, self.filters)

    outputs = backend.conv2d_transpose(
        inputs, self.kernel, tf.stack(output_shape),
        strides=self.strides, padding=self.padding,
        data_format=self.data_format, dilation_rate=self.dilation_rate)

    if not tf.executing_eagerly():
        # Infer the static output shape:
        outputs.set_shape(self.compute_output_shape(inputs.shape))

    if self.use_bias:
        outputs = tf.nn.bias_add(
            outputs, self.bias,
            data_format=conv_utils.convert_data_format(self.data_format, ndim=4))

    return outputs if self.activation is None else self.activation(outputs)