def inverse_wavelet_transform(vres, inv_filters=None, output_shape=None, levels=None):
    """Invert a multi-level 3D 'db4' wavelet decomposition with transposed convolutions.

    The lowest-frequency sub-band (channel 0) is recursively reconstructed first,
    then all 8 sub-bands are combined by a strided conv3d_transpose with the
    separable db4 reconstruction filters.

    Args:
        vres: coefficient tensor; assumes layout (batch, t, h, w, channels) with
            the 8 sub-bands recoverable by reshaping to
            (-1, t//2, h//2, w//2, 8) — TODO confirm against the forward transform.
        inv_filters: precomputed (k, k, k, 1, 8) reconstruction filter bank;
            built from pywt's db4 filters when None.
        output_shape: target shape used to crop 'same'-padding excess on odd
            dims; when None the uncropped reconstruction is returned (the
            conv output already has K.shape(vres) spatial dims, so this is
            equivalent to cropping with vres's own shape).
        levels: number of decomposition levels to invert; inferred via
            pywt.dwtn_max_level when None.

    Returns:
        The reconstructed volume tensor.
    """
    if inv_filters is None:
        wav = pywt.Wavelet('db4')
        rec_hi = np.array(wav.rec_hi)
        rec_lo = np.array(wav.rec_lo)
        # Build the 8 separable 3D reconstruction filters (LLL ... HHH) by
        # outer products of the 1D low/high-pass reconstruction filters.
        inv_filters = np.stack([
            rec_lo[None, None, :] * rec_lo[None, :, None] * rec_lo[:, None, None],
            rec_lo[None, None, :] * rec_lo[None, :, None] * rec_hi[:, None, None],
            rec_lo[None, None, :] * rec_hi[None, :, None] * rec_lo[:, None, None],
            rec_lo[None, None, :] * rec_hi[None, :, None] * rec_hi[:, None, None],
            rec_hi[None, None, :] * rec_lo[None, :, None] * rec_lo[:, None, None],
            rec_hi[None, None, :] * rec_lo[None, :, None] * rec_hi[:, None, None],
            rec_hi[None, None, :] * rec_hi[None, :, None] * rec_lo[:, None, None],
            rec_hi[None, None, :] * rec_hi[None, :, None] * rec_hi[:, None, None]
        ]).transpose((1, 2, 3, 0))[:, :, :, None, :]  # -> (k, k, k, in=1, out=8)
        inv_filters = K.constant(inv_filters)

    if levels is None:
        levels = pywt.dwtn_max_level(K.int_shape(vres)[1:4], 'db4')

    t = vres.shape[1]
    h = vres.shape[2]
    w = vres.shape[3]

    # Regroup coefficients so the last axis holds the 8 sub-bands.
    res = K.reshape(vres, (-1, t // 2, h // 2, w // 2, 8))
    if levels > 1:
        # Recursively reconstruct the approximation band (channel 0), keeping
        # the 7 detail bands untouched.
        res = K.concatenate([
            inverse_wavelet_transform(
                res[:, :, :, :, :1], inv_filters,
                output_shape=(K.shape(vres)[0], K.shape(vres)[1] // 2,
                              K.shape(vres)[2] // 2, K.shape(vres)[3] // 2,
                              K.shape(vres)[4]),
                levels=(levels - 1)),
            res[:, :, :, :, 1:]
        ], axis=-1)

    # Upsample-and-filter: one strided transposed conv applies all 8
    # reconstruction filters at once.
    res = K.conv3d_transpose(res, inv_filters, output_shape=K.shape(vres),
                             strides=(2, 2, 2), padding='same')
    if output_shape is None:
        # BUG FIX: the original unconditionally subscripted output_shape and
        # crashed with TypeError on the default None.
        return res
    return res[:, :output_shape[1], :output_shape[2], :output_shape[3], :]
def _alphabeta_dtd(layer, R, beta, parameter2):
    """Alpha-beta deep-Taylor decomposition rule for a 3D convolution layer.

    Splits the kernel into positive and negative parts and redistributes the
    relevance R through each, weighted by alpha = 1 + beta and -beta
    respectively (so alpha - beta = 1).

    Args:
        layer: the Conv3D layer whose input/kernel are used.
        R: relevance tensor at the layer's output.
        beta: weight of the negative-kernel contribution.
        parameter2: unused; kept for a uniform rule signature.

    Returns:
        Relevance tensor redistributed onto the layer's input.
    """
    print('_convolutional3d_alphabeta_dtd')
    alpha = 1 + beta
    x = layer.input + 1e-12  # epsilon keeps the final product well-defined
    in_shape = K.shape(layer.input)

    def _contribution(weights, scale):
        # Forward pass with the clipped weights, then redistribute the scaled
        # ratio back through a transposed convolution.
        z = K.conv3d(x, weights, strides=layer.strides,
                     padding=layer.padding, data_format=layer.data_format)
        s = scale * (R / z)
        return K.conv3d_transpose(s, weights, in_shape, strides=layer.strides,
                                  padding=layer.padding,
                                  data_format=layer.data_format)

    pos = (_contribution(K.maximum(layer.kernel, 1e-12), alpha)
           if alpha != 0 else 0)
    neg = (_contribution(K.minimum(layer.kernel, -1e-12), -beta)
           if beta != 0 else 0)
    return x * (pos + neg)
def _ww_dtd(layer, R, parameter1, parameter2):
    """w^2 deep-Taylor decomposition rule for a 3D convolution layer.

    Redistributes relevance R in proportion to the squared kernel weights,
    normalized per output channel (sum over the spatial and input-channel
    axes), via a transposed convolution.

    Args:
        layer: the Conv3D layer whose kernel is used.
        R: relevance tensor at the layer's output.
        parameter1, parameter2: unused; kept for a uniform rule signature.

    Returns:
        Relevance tensor redistributed onto the layer's input.
    """
    print('_convolutional3d_ww_dtd')
    squared = K.square(layer.kernel)
    per_channel_norm = K.sum(squared, axis=[0, 1, 2, 3])
    normalized_kernel = squared / per_channel_norm
    return K.conv3d_transpose(R, normalized_kernel, K.shape(layer.input),
                              strides=layer.strides, padding=layer.padding,
                              data_format=layer.data_format)
def call(self, input_tensor, training=None):
    """Forward pass of a 3D deconvolutional capsule layer.

    Folds the input-capsule axis into the batch axis, upsamples each input
    capsule's feature grid (resize / sub-pixel / transposed conv), forms the
    per-capsule votes, and runs dynamic routing to produce the output
    capsule activations.

    Args:
        input_tensor: capsule tensor; the transpose below implies layout
            (batch, H, W, D, input_num_capsule, input_num_atoms) —
            TODO confirm against the layer's build().
        training: unused; kept for the Keras call signature.

    Returns:
        Routed activations from update_routing.
    """
    # Bring the capsule axis first so it can be merged with batch:
    # -> (in_caps, batch, H, W, D, atoms)
    input_transposed = tf.transpose(input_tensor, [4, 0, 1, 2, 3, 5])
    input_shape = K.shape(input_transposed)
    input_tensor_reshaped = K.reshape(input_transposed, [
        input_shape[1] * input_shape[0], self.input_height, self.input_width,
        self.input_depth, self.input_num_atoms])
    input_tensor_reshaped.set_shape((None, self.input_height, self.input_width,
                                     self.input_depth, self.input_num_atoms))

    if self.upsamp_type == 'resize':
        # NOTE(review): K.resize_images is a 2D op (height/width factors);
        # passing three scale factors suggests this was meant to be a
        # resize_volumes-style call — confirm against the backend in use.
        upsamp = K.resize_images(input_tensor_reshaped, self.scaling,
                                 self.scaling, self.scaling, 'channels_last')
        outputs = K.conv3d(upsamp, kernel=self.W, strides=(1, 1, 1),
                           padding=self.padding, data_format='channels_last')
    elif self.upsamp_type == 'subpix':
        conv = K.conv3d(input_tensor_reshaped, kernel=self.W,
                        strides=(1, 1, 1), padding='same',
                        data_format='channels_last')
        # NOTE(review): tf.depth_to_space shuffles only two spatial dims; a
        # 5D volume likely needs a 3D pixel-shuffle — confirm vs. the 2D
        # implementation this was ported from.
        outputs = tf.depth_to_space(conv, self.scaling)
    else:
        batch_size = input_shape[1] * input_shape[0]
        # Infer the dynamic output shape of the transposed convolution.
        out_height = deconv_length(self.input_height, self.scaling,
                                   self.kernel_size, self.padding)
        out_width = deconv_length(self.input_width, self.scaling,
                                  self.kernel_size, self.padding)
        out_depth = deconv_length(self.input_depth, self.scaling,
                                  self.kernel_size, self.padding)
        output_shape = (batch_size, out_height, out_width, out_depth,
                        self.num_capsule * self.num_atoms)
        outputs = K.conv3d_transpose(input_tensor_reshaped, self.W,
                                     output_shape,
                                     (self.scaling, self.scaling, self.scaling),
                                     padding=self.padding,
                                     data_format='channels_last')

    votes_shape = K.shape(outputs)
    _, conv_height, conv_width, conv_depth, _ = outputs.get_shape()
    # BUG FIX: the original reshape was
    # [input_shape[2], input_shape[1], input_shape[0], votes_shape[1],
    #  votes_shape[2], num_capsule, num_atoms], which dropped the depth axis
    # (votes_shape[3]) and injected a spatial dim where batch belongs. The
    # set_shape below documents the intended 7D layout
    # (batch, in_caps, H, W, D, out_caps, atoms).
    votes = K.reshape(outputs, [input_shape[1], input_shape[0],
                                votes_shape[1], votes_shape[2],
                                votes_shape[3],
                                self.num_capsule, self.num_atoms])
    votes.set_shape((None, self.input_num_capsule, conv_height.value,
                     conv_width.value, conv_depth.value,
                     self.num_capsule, self.num_atoms))

    logit_shape = K.stack([
        input_shape[1], input_shape[0], votes_shape[1], votes_shape[2],
        votes_shape[3], self.num_capsule])
    biases_replicated = K.tile(self.b, [votes_shape[1], votes_shape[2],
                                        votes_shape[3], 1, 1])

    activations = update_routing(
        votes=votes,
        biases=biases_replicated,
        logit_shape=logit_shape,
        num_dims=7,
        input_dim=self.input_num_capsule,
        output_dim=self.num_capsule,
        num_routing=self.routings)
    return activations
def _z_dtd(layer, R, parameter1, parameter2):
    """z-rule deep-Taylor decomposition for a 3D convolution layer.

    Computes the layer's pre-activations z, divides the output relevance R
    by z, propagates the ratio back through a transposed convolution, and
    rescales by the (epsilon-shifted) input.

    Args:
        layer: the Conv3D layer whose input/kernel are used.
        R: relevance tensor at the layer's output.
        parameter1, parameter2: unused; kept for a uniform rule signature.

    Returns:
        Relevance tensor redistributed onto the layer's input.
    """
    print('_convolutional3d_z_dtd')
    x = layer.input + 1e-12  # epsilon avoids exact zeros in the final product
    conv_kwargs = dict(strides=layer.strides, padding=layer.padding,
                       data_format=layer.data_format)
    z = K.conv3d(x, layer.kernel, **conv_kwargs)
    backprojected = K.conv3d_transpose(R / z, layer.kernel,
                                       K.shape(layer.input), **conv_kwargs)
    return x * backprojected
def call(self, inputs):
    """Forward pass of a 3D transposed-convolution layer.

    Infers the dynamic output shape from the input's spatial dims, applies
    conv3d_transpose, then optional bias and activation.

    Args:
        inputs: 5D tensor, channels_first or channels_last per
            self.data_format.

    Returns:
        The (optionally activated) transposed-convolution output.
    """
    input_shape = K.shape(inputs)
    batch_size = input_shape[0]
    if self.data_format == 'channels_first':
        d_axis, h_axis, w_axis = 2, 3, 4
    else:
        d_axis, h_axis, w_axis = 1, 2, 3

    depth = input_shape[d_axis]
    height = input_shape[h_axis]
    width = input_shape[w_axis]
    kernel_d, kernel_h, kernel_w = self.kernel_size
    stride_d, stride_h, stride_w = self.strides

    # Infer the dynamic output shape.
    # BUG FIX: out_depth previously used stride_h/kernel_h (copy-paste from
    # the height line), giving a wrong depth whenever the depth kernel or
    # stride differs from the height one.
    out_depth = conv_utils.deconv_length(depth, stride_d, kernel_d,
                                         self.padding)
    out_height = conv_utils.deconv_length(height, stride_h, kernel_h,
                                          self.padding)
    out_width = conv_utils.deconv_length(width, stride_w, kernel_w,
                                         self.padding)
    if self.data_format == 'channels_first':
        output_shape = (batch_size, self.filters, out_depth, out_height,
                        out_width)
    else:
        output_shape = (batch_size, out_depth, out_height, out_width,
                        self.filters)

    outputs = K.conv3d_transpose(inputs, self.kernel, output_shape,
                                 self.strides, padding=self.padding,
                                 data_format=self.data_format)
    # BUG FIX: `if self.bias:` truth-tests a TF variable when a bias exists,
    # which raises; `is not None` is the safe equivalent (assumes self.bias
    # is None when the layer has no bias — standard Keras convention,
    # confirm in build()).
    if self.bias is not None:
        outputs = K.bias_add(outputs, self.bias,
                             data_format=self.data_format)
    if self.activation is not None:
        return self.activation(outputs)
    return outputs