def operate(self):
    """Run every surgery job queued on this surgeon and return the new model.

    Nodes in ``self.nodes`` are visited from greatest to smallest depth (as
    reported by ``utils.get_node_depth``). For each node the graph is rebuilt
    from the model inputs up to that node's inbound nodes, then the function
    registered for the node in ``self._mod_func_map`` is called with the
    rebuilt outputs, their masks and the keyword arguments stored for the
    node in ``self._kwargs_map``. Afterwards the graph is rebuilt up to the
    nodes that produce the model's outputs and wrapped in ``self._model_cls``.

    Returns:
        The rebuilt model. When ``self._copy`` is truthy the model is first
        passed through ``utils.clean_copy`` together with
        ``self._custom_objects``.
    """

    def depth_key(candidate):
        # Sort key: depth of the node within the current model.
        return utils.get_node_depth(self.model, candidate)

    def producing_node(tensor):
        # Look up the node that produced `tensor` from its Keras history.
        layer, node_index, _ = tensor._keras_history
        return get_inbound_nodes(layer)[node_index]

    # Deepest nodes are operated on first.
    for target in sorted(self.nodes, key=depth_key, reverse=True):
        # Rebuild the sub-model that feeds this node.
        feeders = utils.get_node_inbound_nodes(target)
        rebuilt, rebuilt_masks = self._rebuild_graph(self.model.inputs, feeders)

        # Apply the surgery registered for this node.
        extra = self._kwargs_map[target]
        self._mod_func_map[target](target, rebuilt, rebuilt_masks, **extra)

    # Rebuild the remainder of the graph, up to the model outputs.
    final_nodes = [producing_node(tensor) for tensor in self.model.outputs]
    final_outputs, _ = self._rebuild_graph(self.model.inputs, final_nodes)
    rebuilt_model = self._model_cls(self.model.inputs, final_outputs)

    if not self._copy:
        return rebuilt_model
    return utils.clean_copy(rebuilt_model, self._custom_objects)
def operate(self):
    """Run every surgery job queued on this surgeon and return the new model.

    Nodes in ``self.nodes`` are visited from greatest to smallest depth (as
    reported by ``utils.get_node_depth``). For each node the graph is rebuilt
    from the model inputs up to that node's inbound nodes, then the function
    registered for the node in ``self._mod_func_map`` is called with the
    rebuilt outputs, their masks and the keyword arguments stored for the
    node in ``self._kwargs_map``. Afterwards the graph is rebuilt up to the
    nodes of the model's output layers and wrapped in a new ``Model``.

    Returns:
        The rebuilt model, passed through ``utils.clean_copy`` first when
        ``self._copy`` is truthy.
    """

    def depth_key(candidate):
        # Sort key: depth of the node within the current model.
        return utils.get_node_depth(self.model, candidate)

    # Deepest nodes are operated on first.
    for target in sorted(self.nodes, key=depth_key, reverse=True):
        # Rebuild the sub-model that feeds this node.
        feeders = utils.get_node_inbound_nodes(target)
        rebuilt, rebuilt_masks = self._rebuild_graph(self.model.inputs, feeders)

        # Apply the surgery registered for this node.
        extra = self._kwargs_map[target]
        self._mod_func_map[target](target, rebuilt, rebuilt_masks, **extra)

    # Collect, for each output layer, the node selected by its node index.
    final_nodes = []
    for position, node_index in enumerate(self.model.output_layers_node_indices):
        out_layer = self.model.output_layers[position]
        final_nodes.append(get_inbound_nodes(out_layer)[node_index])

    # Rebuild the remainder of the graph, up to the model outputs.
    final_outputs, _ = self._rebuild_graph(self.model.inputs, final_nodes)
    rebuilt_model = Model(self.model.inputs, final_outputs)

    if not self._copy:
        return rebuilt_model
    return utils.clean_copy(rebuilt_model)
def _rebuild_rec(node):
    """Rebuild the graph up to `node` recursively.

    This is a closure: `self`, `graph_inputs` and `graph_input_masks` are
    taken from the enclosing scope.

    The recursion bottoms out, checked in this order, when:
      1. the node's output tensor name is a key of `self._replace_tensors`;
      2. the node is already recorded in `self._finished_nodes`;
      3. the node's output tensor name is among the names of `graph_inputs`.
    Otherwise the inbound nodes are rebuilt first and the node's layer (as
    returned by `self._apply_delete_mask`) is applied to their outputs. The
    result is stored in `self._finished_nodes` before returning.

    Args:
        node(Node): Node to rebuild up to.

    Returns:
        (tuple) containing :
            The output tensor of the rebuilt `node` (`None` when every
            input to the node is `None`)
            The output mask of the rebuilt `node`

    Raises:
        TypeError: If some, but not all, inputs of `node` are `None` and
            its layer is not a `Concatenate` layer.
    """
    # TODO: What happens if nodes have multiple output tensors?
    # Does that ever happen?
    layer = node.outbound_layer
    logger.debug("getting inputs for: {0}".format(layer.name))
    node_output = utils.single_element(node.output_tensors)

    # First check for conditions to bottom out the recursion.
    # Check for replaced tensors before any other checks:
    # these are created by the surgery methods.
    if node_output.name in self._replace_tensors:
        logger.debug(
            "bottomed out at replaced output: {0}".format(node_output))
        output, output_mask = self._replace_tensors[node_output.name]
        return output, output_mask

    # Next check if the current node has already been rebuilt.
    if node in self._finished_nodes:
        logger.debug("reached finished node: {0}".format(node))
        return self._finished_nodes[node]

    # Next check if one of the graph_inputs has been reached.
    # NOTE(review): membership is tested by tensor name while the mask is
    # looked up with `list.index` on the tensor itself; confirm the two
    # always agree.
    if node_output.name in names(graph_inputs):
        logger.debug("bottomed out at a model input")
        output_mask = graph_input_masks[graph_inputs.index(node_output)]
        return node_output, output_mask

    # Otherwise recursively call this function on the inbound nodes.
    inbound_nodes = utils.get_node_inbound_nodes(node)
    logger.debug("inbound_layers: {0}".format(
        [inbound.outbound_layer.name for inbound in inbound_nodes]))

    # Recursively rebuild the model up to `node`s inbound nodes to
    # obtain its inputs and input masks.
    inputs, input_masks = zip(*[_rebuild_rec(n) for n in inbound_nodes])

    if all(i is None for i in inputs):
        # Every input is missing, so this node produces no output either;
        # its mask is all False.
        output = None
        output_mask = np.zeros(node.output_shapes[0][1:], dtype=bool)
    elif any(i is None for i in inputs):
        # Only Concatenate can tolerate a partially missing input list;
        # fail loudly for every other layer type.
        if layer.__class__.__name__ != "Concatenate":
            raise TypeError(
                "Inputs can only be missing for concatenate layers."
            )
        # Remove Nones from the inputs list.
        inputs = [i for i in inputs if i is not None]
        new_layer, output_mask = self._apply_delete_mask(node, input_masks)
        if len(inputs) == 1:
            # A single remaining input is passed through; the layer is
            # not applied to it.
            output = utils.single_element(list(inputs))
        else:
            output = new_layer(utils.single_element(list(inputs)))
    else:
        new_layer, output_mask = self._apply_delete_mask(node, input_masks)
        output = new_layer(utils.single_element(list(inputs)))

    # Record that this node has been rebuilt.
    self._finished_nodes[node] = (output, output_mask)
    logger.debug("layer complete: {0}".format(layer.name))
    return output, output_mask