def __init__(self, graph):
    """Build one real layer for every stub layer in ``graph``.

    Args:
        graph: The graph to realize. Each entry of ``graph.layer_list`` is
            converted with ``to_real_layer``; when ``graph.weighted`` is true,
            the stub's weights are copied into its real counterpart with
            ``set_stub_weight_to_torch``.
    """
    super(TorchModel, self).__init__()
    self.graph = graph
    # self.layers[i] is the real layer for graph.layer_list[i]; it stays a
    # plain list so existing index-based access keeps working.
    self.layers = [to_real_layer(layer) for layer in graph.layer_list]
    if graph.weighted:
        for stub_layer, real_layer in zip(graph.layer_list, self.layers):
            set_stub_weight_to_torch(stub_layer, real_layer)
    # A plain Python list is invisible to nn.Module (assumes TorchModel
    # subclasses torch.nn.Module, as its name and helpers suggest), so the
    # layers must be registered explicitly; otherwise their parameters are
    # missing from parameters()/state_dict() and ignore .to()/.cuda().
    for index, real_layer in enumerate(self.layers):
        self.add_module(str(index), real_layer)
def produce_model(self):
    """Build a new Keras model based on the current graph.

    A fresh ``Input`` of shape ``self.input_shape`` is created. The graph is
    then walked in topological order, and every stub layer on an edge is
    turned into a real layer with ``to_real_layer`` and applied to the tensor
    of its input node(s). The first node of the topological order is used as
    the model input and the last one as the model output. When
    ``self.weighted`` is true, weights are copied from each stub layer into
    its real counterpart.

    Returns:
        A Keras ``Model`` from the input tensor to the output node's tensor.
    """
    input_tensor = Input(shape=self.input_shape)
    topo_node_list = self._topological_order()
    output_id = topo_node_list[-1]
    input_id = topo_node_list[0]

    # Scratch copy of the node list: each entry is overwritten with the
    # tensor computed for that node as the walk reaches it.
    node_list = deepcopy(self.node_list)
    node_list[input_id] = input_tensor

    new_to_old_layer = {}
    for v in topo_node_list:
        # A merge layer reads all of its inputs from
        # layer_id_to_input_node_ids (independent of ``u``), so building it
        # once per output node is enough; repeated edges would only create
        # identical, orphaned copies.
        built_merge_ids = set()
        for u, layer_id in self.reverse_adj_list[v]:
            layer = self.layer_list[layer_id]
            if isinstance(layer, (StubAdd, StubConcatenate)):
                if layer_id in built_merge_ids:
                    continue
                built_merge_ids.add(layer_id)
                edge_input_tensor = [
                    node_list[x]
                    for x in self.layer_id_to_input_node_ids[layer_id]
                ]
            else:
                edge_input_tensor = node_list[u]

            new_layer = to_real_layer(layer)
            new_to_old_layer[new_layer] = layer
            node_list[v] = new_layer(edge_input_tensor)

    model = Model(input_tensor, node_list[output_id])

    if self.weighted:
        # model.layers[0] is the input layer; the excluded types carry no
        # weights to copy.
        for layer in model.layers[1:]:
            if isinstance(layer, (Activation, Dropout, Concatenate, Add)):
                continue
            layer.set_weights(new_to_old_layer[layer].get_weights())
    return model