def _load(self, configs):
    """Load the surrogate network and its metadata from a saved PyTorch file.

    Side effects on ``self``:
        configs: set to the ``configs`` argument.
        info: the metadata dictionary returned by ``load_file``.
        model: a ``NeuralNetworkModel`` built from
            ``info["smg_configs"]["processor"]`` with the loaded state
            dictionary applied to it.

    Args:
        configs: Configuration dictionary; ``configs["driver"]["torch_model_dict"]``
            is the path handed to ``load_file`` (read as a ``"pt"`` file).
    """
    self.configs = configs
    model_path = configs["driver"]["torch_model_dict"]
    self.info, state_dict = load_file(model_path, "pt")
    processor_configs = self.info["smg_configs"]["processor"]
    self.model = NeuralNetworkModel(processor_configs)
    self.model.load_state_dict(state_dict)
class SurrogateModel(nn.Module):
    """Simulated processor backed by a trained PyTorch network.

    On construction, the network weights and an accompanying ``info``
    dictionary are read from the file at
    ``configs["driver"]["torch_model_dict"]``. The ``info`` dictionary
    configures the network architecture, the input voltage ranges, the
    output amplification and the clipping value.

    Example::

        model = SurrogateModel(configs)
        outputs = model(inputs)                       # torch tensors in/out
        outputs_np = model.forward_numpy(input_array)  # numpy in/out, no grad
    """

    # TODO: Automatically register the data type according to the
    # configurations of the amplification variable of the info dictionary

    def __init__(self, configs):
        """Load the network and derive its voltage ranges and scaling factors.

        Args:
            configs: Configuration dictionary. ``configs["driver"]["torch_model_dict"]``
                is the path of the saved model file; the whole dictionary is
                also passed to ``get_noise`` to build the output noise stage.
        """
        super().__init__()
        self._load(configs)
        self._init_voltage_ranges()
        self.amplification = TorchUtils.get_tensor_from_list(
            self.info["data_info"]["processor"]["driver"]["amplification"])
        # NOTE(review): clipping_value is stored but not applied in forward();
        # confirm whether callers are expected to clip outputs themselves.
        self.clipping_value = TorchUtils.get_tensor_from_list(
            self.info["data_info"]["clipping_value"])
        self.noise = get_noise(configs)

    def _load(self, configs):
        """Load the network and its metadata from a saved PyTorch file.

        Sets ``self.configs`` to ``configs``, ``self.info`` to the metadata
        dictionary returned by ``load_file``, and ``self.model`` to a
        ``NeuralNetworkModel`` built from ``info["smg_configs"]["processor"]``
        with the returned state dictionary loaded into it.
        """
        self.configs = configs
        self.info, state_dict = load_file(
            configs["driver"]["torch_model_dict"], "pt")
        self.model = NeuralNetworkModel(self.info["smg_configs"]["processor"])
        self.model.load_state_dict(state_dict)

    def _init_voltage_ranges(self):
        """Set ``self.voltage_ranges`` from the input offset and amplitude.

        Each range is ``[offset - amplitude, offset + amplitude]``. The minima
        and maxima are concatenated along dim 1, so for 1-D offset/amplitude
        tensors the result has shape ``(n, 2)`` (presumably one row per input
        electrode — TODO confirm).
        """
        offset = TorchUtils.get_tensor_from_list(
            self.info["data_info"]["input_data"]["offset"])
        amplitude = TorchUtils.get_tensor_from_list(
            self.info["data_info"]["input_data"]["amplitude"])
        min_voltage = (offset - amplitude).unsqueeze(dim=1)
        max_voltage = (offset + amplitude).unsqueeze(dim=1)
        self.voltage_ranges = torch.cat((min_voltage, max_voltage), dim=1)

    def forward(self, x):
        """Run the network on ``x``, scale by the amplification, apply noise.

        Args:
            x: Input tensor passed directly to the underlying network.

        Returns:
            ``self.noise(self.model(x) * self.amplification)``.
        """
        return self.noise(self.model(x) * self.amplification)

    def forward_numpy(self, input_matrix):
        """Evaluate the model on a numpy array without tracking gradients.

        Args:
            input_matrix: Numpy array converted with
                ``TorchUtils.get_tensor_from_numpy``.

        Returns:
            The output of ``forward`` converted back with
            ``TorchUtils.get_numpy_from_tensor``.
        """
        with torch.no_grad():
            inputs_torch = TorchUtils.get_tensor_from_numpy(input_matrix)
            output = self.forward(inputs_torch)
        return TorchUtils.get_numpy_from_tensor(output)

    def reset(self):
        """Print a warning; resetting is not implemented for the surrogate."""
        print("Warning: Reset function in Surrogate Model not implemented.")

    def close(self):
        """Do nothing; the surrogate model has no closing behaviour."""

    def is_hardware(self):
        """Return ``False``: this processor is a software simulation."""
        return False
import torch
from brainspy.processors.simulation.model import NeuralNetworkModel
from brainspy.utils.io import load_configs
from brainspy.utils.pytorch import TorchUtils
from bspysmg.model.data.outputs.train_model import train_surrogate_model

# Training entry script: loads the template surrogate-model configuration,
# builds the network, an Adam optimizer and an MSE loss, and hands them all to
# train_surrogate_model. Everything runs at import time.

# To force execution on the CPU, uncomment the following line:
# TorchUtils.force_cpu = True

CONFIGS = load_configs(
    'configs/training/smg_configs_template_multiple_outputs.yaml')
MODEL = NeuralNetworkModel(CONFIGS['processor'])

# Only parameters that require gradients are handed to the optimizer.
OPTIMIZER = torch.optim.Adam(
    (param for param in MODEL.parameters() if param.requires_grad),
    lr=CONFIGS['hyperparameters']['learning_rate'],
)
CRITERION = torch.nn.MSELoss()

train_surrogate_model(CONFIGS, MODEL, CRITERION, OPTIMIZER)