def import_keras_model(model, network_name):
    """
    Import a keras model into conx.

    Builds a conx ``Network`` named ``network_name`` that wraps ``model``.

    1. One conx layer is created per entry in ``model.layers``. Its class is
       looked up in ``conx.layers`` by the keras class name plus a
       ``"Layer"`` suffix.
    2. The conx layers are connected by following each keras layer's
       outbound nodes.
    3. Every non-input conx layer gets an internal keras ``Model``. Its
       input is ``model.layers[0].input`` and its output is that layer's
       keras output.

    Args:
        model: the keras model to import (its ``layers`` are read).
        network_name: name given to the new conx ``Network``.

    Returns:
        The newly built conx ``Network``.

    Raises:
        KeyError: if a keras layer's class has no ``<ClassName>Layer``
            counterpart in ``conx.layers``.
    """
    from .network import Network
    import inspect
    import conx

    network = Network(network_name)
    network.model = model
    # Map class name -> class for every class exposed by conx.layers.
    conx_layers = {
        name: layer
        for (name, layer) in inspect.getmembers(conx.layers, inspect.isclass)
    }

    # First, make all of the conx layers:
    for layer in model.layers:
        clayer_class = conx_layers[layer.__class__.__name__ + "Layer"]
        if clayer_class.__name__ == "InputLayerLayer":
            # Input layers are rebuilt from the keras layer's recorded
            # batch shape, rather than by wrapping the keras layer object.
            clayer = conx.layers.InputLayer(layer.name, None)
            clayer.shape = None
            clayer.params["batch_shape"] = layer.get_config()["batch_input_shape"]
            clayer.k = clayer.make_input_layer_k()
            clayer.keras_layer = clayer.k
        else:
            # The conx layer is configured from the keras config (which
            # includes the layer name), and reuses the keras layer itself.
            clayer = clayer_class(**layer.get_config())
            clayer.k = layer
            clayer.keras_layer = layer
        network.add(clayer)

    # Next, connect them up by name, following each keras layer's own
    # outbound nodes (iterate layer_from, not the previous loop's variable).
    for layer_from in model.layers:
        for node in layer_from.outbound_nodes:
            network.connect(layer_from.name, node.outbound_layer.name)
            print("connecting:", layer_from.name, node.outbound_layer.name)

    # Connect them all up, and set input banks:
    network.connect()
    for clayer in network.layers:
        clayer.input_names = network.input_bank_order

    # Finally, make the internal models:
    # NOTE(review): relies on a "keras" name in module scope that is not
    # visible in this function — confirm it is imported at file level.
    for clayer in network.layers:
        ## FIXME: the appropriate inputs:
        if clayer.kind() != "input":
            clayer.model = keras.models.Model(
                inputs=model.layers[0].input,
                outputs=clayer.keras_layer.output)
    return network