def get_conv_residual_model(input_size, nonlinearity, kernel_size, num_features,
                            num_layers):
    """Generic conv model with residual convolutions.

    Args:
        input_size(tuple): Tuple containing input shape, excluding batch size
        nonlinearity(str): 'relu' or '3-piece'
        kernel_size(int): Dimension of each convolutional filter
        num_features(int): Number of features in each intermediate network layer
        num_layers(int): Number of convolutional layers in model

    Returns:
        model: Keras model comprised of num_layers core convolutional layers
            with specified nonlinearities
    """
    inputs = Input(input_size)
    prev_layer = inputs
    for i in range(num_layers):
        conv = layers.BatchNormConv(num_features, kernel_size)(prev_layer)
        nonlinear = layers.get_nonlinear_layer(nonlinearity)(conv)
        # Residual connection: tile the 2-channel input to num_features
        # channels and add it to each layer's output.
        prev_layer = nonlinear + tf.tile(inputs, [1, 1, 1, num_features // 2])
    # Final convolution maps back to 2 channels, with a skip connection from
    # the input.
    output = Conv2D(
        2,
        kernel_size,
        activation=None,
        padding='same',
        kernel_initializer='he_normal')(prev_layer) + inputs
    model = keras.models.Model(inputs=inputs, outputs=output)
    return model
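# A minimal usage sketch, not part of the original module: it builds the
# residual conv model with example parameters. The 64x64 input size, ReLU
# nonlinearity, 3x3 kernels, 32 features, and 3 layers are illustrative
# assumptions, not values prescribed by this code.
def _example_conv_residual_model():
    model = get_conv_residual_model(
        input_size=(64, 64, 2),  # (n, n, 2): real/imaginary channels
        nonlinearity='relu',
        kernel_size=3,
        num_features=32,  # must be even: the 2-channel input is tiled num_features // 2 times
        num_layers=3)
    model.compile(optimizer='adam', loss='mse')
    return model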
def get_alternating_residual_model(input_size, nonlinearity, kernel_size,
                                   num_features, num_convs, num_layers,
                                   enforce_dc):
    """Alternating model with residual convolutions.

    Returns a model that takes a frequency-space input (of shape
    (batch_size, n, n, 2)) and returns a frequency-space output of the same
    size, comprised of alternating frequency- and image-space convolutional
    layers and with connections from the input to each layer.

    Args:
        input_size(tuple): Tuple containing input shape, excluding batch size
        nonlinearity(str): 'relu' or '3-piece'
        kernel_size(int): Dimension of each convolutional filter
        num_features(int): Number of features in each intermediate network layer
        num_convs(int): Number of convolutions per layer
        num_layers(int): Number of convolutional layers in model
        enforce_dc(bool): If True, the model takes an additional sampling-mask
            input and copies the input values at sampled frequency locations
            into the output

    Returns:
        model: Keras model comprised of num_layers alternating image- and
            frequency-space convolutional layers with specified nonlinearities
    """
    inputs = Input(input_size)
    if enforce_dc:
        masks = Input(input_size)
    n = inputs.get_shape().as_list()[1]

    # Tile the 2-channel (real, imaginary) frequency-space input to
    # num_features channels for the residual connections.
    inp_real = tf.expand_dims(inputs[:, :, :, 0], -1)
    inp_imag = tf.expand_dims(inputs[:, :, :, 1], -1)
    n_copies = num_features // 2
    inp_copy = tf.reshape(
        tf.tile(
            tf.expand_dims(tf.concat([inp_real, inp_imag], axis=3), 4),
            [1, 1, 1, 1, n_copies]), [-1, n, n, num_features])

    # Same tiling for the image-domain version of the input.
    inputs_img = utils.convert_tensor_to_image_domain(inputs)
    inp_img_real = tf.expand_dims(inputs_img[:, :, :, 0], -1)
    inp_img_imag = tf.expand_dims(inputs_img[:, :, :, 1], -1)
    inp_img_copy = tf.reshape(
        tf.tile(
            tf.expand_dims(tf.concat([inp_img_real, inp_img_imag], axis=3), 4),
            [1, 1, 1, 1, n_copies]), [-1, n, n, num_features])

    prev_layer = inputs
    for i in range(num_layers):
        # Frequency-space convolutions, each with a residual connection from
        # the tiled input.
        for j in range(num_convs):
            k_conv = layers.BatchNormConv(num_features,
                                          kernel_size)(prev_layer) + inp_copy
            prev_layer = layers.get_nonlinear_layer(nonlinearity)(k_conv)
        prev_layer = utils.convert_channels_to_image_domain(prev_layer)
        # Image-space convolutions; these always use ReLU regardless of the
        # requested nonlinearity.
        for j in range(num_convs):
            img_conv = layers.BatchNormConv(
                num_features, kernel_size)(prev_layer) + inp_img_copy
            prev_layer = layers.get_nonlinear_layer('relu')(img_conv)
        prev_layer = utils.convert_channels_to_frequency_domain(prev_layer)

    # Final convolution maps back to 2 channels, with a skip connection from
    # the input.
    output = Conv2D(
        2,
        kernel_size,
        activation=None,
        padding='same',
        kernel_initializer='he_normal')(prev_layer) + inputs

    if enforce_dc:
        # Data consistency: keep the measured input values at sampled
        # frequency locations and use the network output elsewhere.
        output = masks * inputs + (1 - masks) * output
        model = keras.models.Model(inputs=(inputs, masks), outputs=output)
    else:
        model = keras.models.Model(inputs=inputs, outputs=output)
    return model
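# A minimal usage sketch, not part of the original module: it builds the
# alternating-domain model with data consistency enabled and runs one forward
# pass on random data. All parameter values and array shapes below are
# illustrative assumptions only.
def _example_alternating_residual_model():
    import numpy as np  # local import to keep this sketch self-contained

    n = 64
    model = get_alternating_residual_model(
        input_size=(n, n, 2),
        nonlinearity='relu',
        kernel_size=3,
        num_features=32,
        num_convs=2,
        num_layers=3,
        enforce_dc=True)
    # With enforce_dc=True the model expects two inputs: the undersampled
    # frequency-space data and a binary sampling mask of the same shape.
    kspace = np.random.randn(1, n, n, 2).astype(np.float32)
    mask = np.random.randint(0, 2, size=(1, n, n, 2)).astype(np.float32)
    output = model.predict([kspace, mask])
    return output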