def convnet_preprocessor(input_shapes,
                         image_shape,
                         output_size,
                         name="convnet_preprocessor",
                         make_picklable=True,
                         *args,
                         **kwargs):
    """Build a preprocessor that runs the image part of a flat observation
    through a convnet.

    The concatenated flat input is assumed to start with the flattened
    image pixels (``prod(image_shape)`` entries); everything after that is
    treated as raw features and concatenated, untouched, onto the convnet
    output so the final width equals ``output_size``.
    Extra ``*args``/``**kwargs`` are forwarded to ``convnet``.
    """
    model_inputs = [layers.Input(shape=shape) for shape in input_shapes]
    merged = layers.Lambda(lambda x: tf.concat(x, axis=-1))(model_inputs)

    # Split: first prod(image_shape) entries are pixels, the rest are raw
    # (non-image) features that bypass the convnet.
    n_pixels = np.prod(image_shape)
    flat_pixels, raw_features = layers.Lambda(
        lambda x: [x[..., :n_pixels], x[..., n_pixels:]])(merged)
    image_tensor = layers.Reshape(image_shape)(flat_pixels)

    # The convnet's width is shrunk so that re-appending the raw features
    # yields exactly output_size units.
    conv_features = convnet(
        input_shape=image_shape,
        output_size=output_size - raw_features.shape[-1],
        *args,
        **kwargs,
    )(image_tensor)

    output = layers.Lambda(lambda x: tf.concat(x, axis=-1))(
        [conv_features, raw_features])

    return PicklableKerasModel(model_inputs, output, name=name)
def spatial_ae(latent_dim):
    """Deep Spatial AutoEncoder described in Finn et al. (2016).

    Encoder: three strided convolutions ending in ``latent_dim // 2``
    channels, followed by a spatial softmax that yields one (x, y)
    feature point per channel — ``latent_dim`` latent values in total.
    Decoder: dense layer reshaped to a 7x7x32 seed map, upsampled by
    transposed convolutions back to a 3-channel image.

    Args:
        latent_dim: Length of the feature-point vector. Must be even,
            since each spatial-softmax channel contributes a coordinate
            pair.

    Returns:
        A ``PicklableKerasModel`` mapping an (84, 84, 3) input image to
        ``[feature_points, reconstruction]``.
    """
    # Each spatial-softmax channel produces an (x, y) pair, so the final
    # encoder conv uses latent_dim // 2 filters and latent_dim must be even.
    assert latent_dim % 2 == 0, latent_dim

    input_image = tf.keras.layers.Input(shape=(84, 84, 3))
    conv = tf.keras.layers.Conv2D(filters=32,
                                  kernel_size=5,
                                  strides=(3, 3),
                                  activation=tf.nn.relu)(input_image)
    conv = tf.keras.layers.Conv2D(filters=32,
                                  kernel_size=5,
                                  strides=(3, 3),
                                  activation=tf.nn.relu)(conv)
    conv = tf.keras.layers.Conv2D(filters=latent_dim // 2,
                                  kernel_size=5,
                                  strides=(3, 3),
                                  activation=tf.nn.relu)(conv)

    feature_points = SpatialSoftMax()(conv)
    feature_points_dropout = tf.keras.layers.Dropout(0.5)(feature_points)

    low_dim = 7  # spatial side length of the decoder's seed feature map
    out = tf.keras.layers.Dense(
        units=low_dim * low_dim * 32,
        activation=tf.nn.relu)(feature_points_dropout)
    out = tf.keras.layers.Reshape(target_shape=(low_dim, low_dim, 32))(out)
    out = tf.keras.layers.Conv2DTranspose(filters=64,
                                          kernel_size=5,
                                          strides=(3, 3),
                                          padding="SAME",
                                          activation=tf.nn.relu)(out)
    out = tf.keras.layers.Conv2DTranspose(filters=64,
                                          kernel_size=5,
                                          strides=(2, 2),
                                          padding="SAME",
                                          activation=tf.nn.relu)(out)
    out = tf.keras.layers.Conv2DTranspose(filters=32,
                                          kernel_size=5,
                                          strides=(2, 2),
                                          padding="SAME",
                                          activation=tf.nn.relu)(out)

    # No activation on the final layer: linear pixel output.
    reconstruction = tf.keras.layers.Conv2DTranspose(
        filters=3,
        kernel_size=3,
        strides=(1, 1),
        padding="SAME",
        name='reconstruction')(out)

    return PicklableKerasModel(inputs=input_image,
                               outputs=[feature_points, reconstruction])
def vanilla_ae(latent_dim):
    """Plain convolutional autoencoder over (84, 84, 3) images.

    Returns a ``PicklableKerasModel`` whose outputs are the dense latent
    feature vector and the decoded image reconstruction.
    """
    image_in = tf.keras.layers.Input(shape=(84, 84, 3))

    # Encoder: three strided 5x5 convolutions, then a dense bottleneck.
    encoded = image_in
    for n_filters in (32, 64, 64):
        encoded = tf.keras.layers.Conv2D(filters=n_filters,
                                         kernel_size=5,
                                         strides=(3, 3),
                                         activation=tf.nn.relu)(encoded)
    flattened = tf.keras.layers.Flatten()(encoded)
    latent_features = tf.keras.layers.Dense(
        latent_dim, activation=tf.nn.relu)(flattened)

    # Decoder: expand the latent vector into a low_dim x low_dim x 32 seed
    # map, then upsample with transposed convolutions.
    low_dim = 7  # image dimension of downsampled image
    decoded = tf.keras.layers.Dense(
        units=low_dim * low_dim * 32,
        activation=tf.nn.relu)(latent_features)
    decoded = tf.keras.layers.Reshape(
        target_shape=(low_dim, low_dim, 32))(decoded)
    for n_filters, strides in ((64, (3, 3)), (64, (2, 2)), (32, (2, 2))):
        decoded = tf.keras.layers.Conv2DTranspose(
            filters=n_filters,
            kernel_size=5,
            strides=strides,
            padding="SAME",
            activation=tf.nn.relu)(decoded)

    # No activation: the reconstruction head is linear.
    reconstruction = tf.keras.layers.Conv2DTranspose(
        filters=3,
        kernel_size=3,
        strides=(1, 1),
        padding="SAME",
        name='reconstruction')(decoded)

    return PicklableKerasModel(inputs=image_in,
                               outputs=[latent_features, reconstruction])
def feedforward_model(input_shapes,
                      output_size,
                      hidden_layer_sizes,
                      activation='relu',
                      output_activation='linear',
                      preprocessors=None,
                      name='feedforward_model',
                      *args,
                      **kwargs):
    """Build a dense MLP over the concatenation of (optionally
    preprocessed) inputs.

    Each entry of ``preprocessors`` is applied to the matching input
    (``None`` means identity). Extra ``*args``/``**kwargs`` are forwarded
    to every ``Dense`` layer.
    """
    model_inputs = [
        tf.keras.layers.Input(shape=shape) for shape in input_shapes
    ]

    if preprocessors is None:
        preprocessors = (None, ) * len(model_inputs)

    # Pair each input with its preprocessor; None passes through as-is.
    transformed = []
    for preprocessor, tensor in zip(preprocessors, model_inputs):
        transformed.append(
            tensor if preprocessor is None else preprocessor(tensor))

    out = tf.keras.layers.Lambda(
        lambda x: tf.concat(x, axis=-1))(transformed)

    for units in hidden_layer_sizes:
        out = tf.keras.layers.Dense(units,
                                    *args,
                                    activation=activation,
                                    **kwargs)(out)
    out = tf.keras.layers.Dense(output_size,
                                *args,
                                activation=output_activation,
                                **kwargs)(out)

    return PicklableKerasModel(model_inputs, out, name=name)
def convnet_preprocessor(input_shapes,
                         image_shape,
                         output_size,
                         conv_filters=(32, 32),
                         conv_kernel_sizes=((5, 5), (5, 5)),
                         pool_type='MaxPool2D',
                         pool_sizes=((2, 2), (2, 2)),
                         pool_strides=(2, 2),
                         dense_hidden_layer_sizes=(64, 64),
                         data_format='channels_last',
                         name="convnet_preprocessor",
                         make_picklable=True,
                         *args,
                         **kwargs):
    """Conv + pool + dense preprocessor over flattened image observations.

    The concatenated flat input is split into the first ``H * W * C``
    entries (the image, reshaped to ``image_shape``) and the remainder
    (raw features). The image goes through the conv/pool stack; the raw
    features are concatenated back in before the dense head built by
    ``feedforward_model`` (skipped when ``dense_hidden_layer_sizes`` is
    empty).

    Raises:
        ValueError: If ``data_format`` is neither 'channels_last' nor
            'channels_first'.

    NOTE(review): this shadows an earlier function of the same name
    defined in this module.
    """
    if data_format == 'channels_last':
        H, W, C = image_shape
    elif data_format == 'channels_first':
        C, H, W = image_shape
    else:
        # Fail fast: previously an unknown data_format fell through and
        # surfaced later as a confusing NameError on H/W/C.
        raise ValueError(
            "data_format must be 'channels_last' or 'channels_first', "
            "got {!r}".format(data_format))

    inputs = [
        tf.keras.layers.Input(shape=input_shape)
        for input_shape in input_shapes
    ]
    concatenated_input = tf.keras.layers.Lambda(
        lambda x: tf.concat(x, axis=-1))(inputs)

    images_flat, input_raw = tf.keras.layers.Lambda(
        lambda x: [x[..., :H * W * C], x[..., H * W * C:]])(
            concatenated_input)
    images = tf.keras.layers.Reshape(image_shape)(images_flat)

    conv_out = images
    # pool_strides holds one scalar stride per pooling layer (Keras
    # broadcasts a scalar stride to both spatial dims), unlike pool_sizes
    # which holds per-layer (h, w) tuples.
    for filters, kernel_size, pool_size, strides in zip(
            conv_filters, conv_kernel_sizes, pool_sizes, pool_strides):
        conv_out = tf.keras.layers.Conv2D(filters=filters,
                                          kernel_size=kernel_size,
                                          padding="SAME",
                                          activation=tf.nn.relu,
                                          *args,
                                          **kwargs)(conv_out)
        conv_out = getattr(tf.keras.layers, pool_type)(
            pool_size=pool_size, strides=strides)(conv_out)

    flattened = tf.keras.layers.Flatten()(conv_out)
    concatenated_output = tf.keras.layers.Lambda(
        lambda x: tf.concat(x, axis=-1))([flattened, input_raw])

    output = (feedforward_model(
        input_shapes=(concatenated_output.shape[1:].as_list(), ),
        output_size=output_size,
        hidden_layer_sizes=dense_hidden_layer_sizes,
        activation='relu',
        output_activation='linear',
        *args,
        **kwargs)([concatenated_output])
              if dense_hidden_layer_sizes else concatenated_output)

    model = PicklableKerasModel(inputs, output, name=name)

    return model