def __init__(self, graph):
    """Construct a torch module from the layers of ``graph``.

    Each entry of ``graph.layer_list`` is turned into a real layer via its
    ``to_real_layer()`` method. When ``graph.weighted`` is truthy, weights are
    transferred from each graph layer to its real layer with
    ``set_stub_weight_to_torch``. Finally, every real layer is registered as a
    submodule whose name is its position in the list ("0", "1", ...).

    Args:
        graph: object exposing ``layer_list`` and ``weighted``; it is kept on
            ``self.graph``.
    """
    super(TorchModel, self).__init__()
    self.graph = graph

    # Phase 1: build every real layer, preserving the graph's layer order.
    self.layers = [stub.to_real_layer() for stub in graph.layer_list]

    # Phase 2: only after all layers exist, optionally copy weights across.
    # Each graph layer is paired with the real layer built from it.
    if graph.weighted:
        for stub, real in zip(self.graph.layer_list, self.layers):
            set_stub_weight_to_torch(stub, real)

    # Phase 3: register submodules under their positional index. These names
    # become the keys used by the Module machinery (e.g. parameter naming).
    for position, real in enumerate(self.layers):
        self.add_module(str(position), real)