def __init__(self, test_name):
    """Build a WaveformManager with fixed plateau/slope lengths for testing.

    Args:
        test_name: accepted for the unittest constructor signature but not
            forwarded to the base class (NOTE(review): intentional? confirm).
    """
    super(WaveformTest, self).__init__()
    # 80 samples per plateau, 20 samples per slope.
    self.configs = {"plateau_length": 80, "slope_length": 20}
    self.waveform_mgr = WaveformManager(self.configs)
class DataPointsToPlateau:
    """Callable transform that expands both items of an (inputs, targets) pair.

    Args:
        configs: configuration passed unchanged to ``WaveformManager``.
    """

    def __init__(self, configs):
        manager = WaveformManager(configs)
        self.mgr = manager

    def __call__(self, data):
        """Convert ``data[0]`` and ``data[1]`` with ``points_to_plateaus``.

        Only the first two items of ``data`` are used; anything after them is
        dropped. Returns a 2-tuple ``(converted_inputs, converted_targets)``.
        """
        first, second = data[0], data[1]
        return tuple(self.mgr.points_to_plateaus(item) for item in (first, second))
def __init__(self, configs, logger=None):
    """Initialise the driver, voltage ranges, waveform manager and clipping.

    Args:
        configs: nested configuration dict; this method reads
            ``configs["processor_type"]``, ``configs["data"]["waveform"]``,
            ``configs["driver"]["amplification"]`` and
            ``configs["driver"]["output_clipping_range"]`` (and passes the
            whole dict to ``get_driver``).
        logger: optional logger object, stored as-is.
    """
    super(HardwareProcessor, self).__init__()
    self.driver = get_driver(configs)
    debug_simulation = configs['processor_type'] == 'simulation_debug'
    ranges = self.driver.voltage_ranges
    # The debug simulation driver's ranges are kept untouched; otherwise they
    # are converted from numpy into a tensor.
    self.voltage_ranges = ranges if debug_simulation else TorchUtils.get_tensor_from_numpy(ranges)
    self.waveform_mgr = WaveformManager(configs["data"]["waveform"])
    self.logger = logger
    # TODO: Manage amplification from this class
    driver_configs = configs["driver"]
    self.amplification = driver_configs["amplification"]
    clip_range = driver_configs["output_clipping_range"]
    # NOTE(review): this variant scales by ``amplification`` as a whole and
    # stores a plain list, whereas another HardwareProcessor version in this
    # file uses ``amplification[0]`` and a tensor — confirm which is current.
    self.clipping_value = [clip_range[idx] * self.amplification for idx in (0, 1)]
class HardwareProcessor(nn.Module):
    """Expose a (hardware or simulated) driver as a PyTorch module.

    ``forward`` expands plateau-format input into a full waveform with a
    ``WaveformManager``, runs it through the driver, and keeps only the output
    samples selected by the mask returned from the waveform manager.

    Attributes:
        driver: object returned by ``get_driver(configs)``; measurement,
            reset, task closing and ``is_hardware`` are delegated to it.
        voltage_ranges: the driver's ``voltage_ranges``; converted from numpy
            to a tensor unless ``configs["processor_type"]`` is
            ``"simulation_debug"``.
        waveform_mgr: ``WaveformManager`` built from
            ``configs["data"]["waveform"]``.
        logger: optional object with a ``log_output`` method, or ``None``.
        amplification: ``configs["driver"]["amplification"]``; only its first
            element is used by this class.
        clipping_value: tensor holding the two entries of
            ``configs["driver"]["output_clipping_range"]``, each multiplied by
            ``amplification[0]``, placed on the accelerator device.
    """

    # TODO: Automatically register the data type according to the configurations of the amplification variable of the info dictionary
    def __init__(self, configs, logger=None):
        super().__init__()
        self.driver = get_driver(configs)
        if configs["processor_type"] == "simulation_debug":
            self.voltage_ranges = self.driver.voltage_ranges
        else:
            self.voltage_ranges = TorchUtils.get_tensor_from_numpy(
                self.driver.voltage_ranges
            )
        self.waveform_mgr = WaveformManager(configs["data"]["waveform"])
        self.logger = logger
        # TODO: Manage amplification from this class
        self.amplification = configs["driver"]["amplification"]
        clipping_range = configs["driver"]["output_clipping_range"]
        self.clipping_value = torch.tensor(
            [
                clipping_range[0] * self.amplification[0],
                clipping_range[1] * self.amplification[0],
            ],
            device=TorchUtils.get_accelerator_type(),
        )

    def forward(self, x):
        """Run plateau-format input ``x`` through the driver.

        The input is expanded into a waveform (as numpy, since
        ``return_pytorch=False``), measured via ``forward_numpy``, and the
        masked output samples are returned as a tensor. Gradients are not
        tracked.
        """
        with torch.no_grad():
            x, mask = self.waveform_mgr.plateaus_to_waveform(
                x, return_pytorch=False
            )
            output = self.forward_numpy(x)
            if self.logger is not None:
                # NOTE(review): this passes the *input* waveform to
                # ``log_output``; confirm whether ``output`` was intended.
                self.logger.log_output(x)
            return TorchUtils.get_tensor_from_numpy(output[mask])

    def get_clipping_value(self):
        """Return ``clipping_value`` with its dimensions reversed (like ``.T``).

        ``Tensor.T`` on tensors that are not 2-D is deprecated in recent
        PyTorch releases and emits a warning. Reversing the dimensions
        explicitly gives the same result for every rank without the warning;
        for 0-D/1-D tensors the transpose is a no-op.
        """
        value = self.clipping_value
        if value.dim() < 2:
            return value
        return value.permute(*reversed(range(value.dim())))

    def forward_numpy(self, x):
        """Delegate a numpy-level forward pass to the driver."""
        return self.driver.forward_numpy(x)

    def reset(self):
        """Reset the underlying driver."""
        self.driver.reset()

    def close(self):
        """Close the driver's tasks if it supports it; otherwise print a warning."""
        if hasattr(self.driver, "close_tasks"):
            self.driver.close_tasks()
        else:
            print('Warning: Driver tasks have not been closed.')

    def is_hardware(self):
        """Return whatever the driver reports from ``is_hardware()``."""
        return self.driver.is_hardware()
class WaveformTest(unittest.TestCase):
    """Round-trip consistency checks for ``WaveformManager`` conversions.

    Uses a manager with 80-sample plateaus and 20-sample slopes and checks
    that points, plateaus and full waveforms convert into one another
    consistently.
    """

    def __init__(self, test_name):
        # NOTE(review): ``test_name`` is accepted but not forwarded, so
        # ``unittest.TestCase`` uses its default method name ("runTest").
        # Kept for backward compatibility with existing callers.
        super(WaveformTest, self).__init__()
        configs = {"plateau_length": 80, "slope_length": 20}
        self.configs = configs
        self.waveform_mgr = WaveformManager(configs)

    def full_check(self, point_no):
        """Run every conversion check on random points.

        Args:
            point_no: shape passed to ``torch.rand``; its first dimension is
                the number of points (``len(points)``).
        """
        mgr = self.waveform_mgr
        points = torch.rand(
            point_no,
            device=TorchUtils.get_accelerator_type(),
            dtype=TorchUtils.get_data_type(),
        )

        waveform = mgr.points_to_waveform(points)
        self.assertTrue(
            bool((waveform[0, :] == 0.0).all()) and bool((waveform[-1, :] == 0.0).all()),
            "Waveforms do not start and end with zero",
        )
        # One plateau per point, plus a slope before, between and after them.
        expected_length = mgr.plateau_length * len(points) + mgr.slope_length * (
            len(points) + 1
        )
        self.assertEqual(len(waveform), expected_length, "Waveform has an incorrect shape")

        mask = mgr.generate_mask(len(waveform))
        self.assertEqual(len(mask), len(waveform))
        plateaus = waveform[mask]

        # Compared at half precision, presumably to tolerate rounding in the
        # conversion back to points.
        reference = points.half().float()
        waveform_to_points = mgr.waveform_to_points(waveform)
        plateaus_to_points = mgr.plateaus_to_points(plateaus)
        self.assertTrue(
            bool((reference == waveform_to_points.half().float()).all()),
            "Inconsistent to_point conversion",
        )
        self.assertTrue(
            bool((reference == plateaus_to_points.half().float()).all()),
            "Inconsistent to_point conversion",
        )

        # Check each plateau conversion against the masked waveform directly;
        # comparing the two results only with each other would also pass
        # when both were wrong.
        points_to_plateau = mgr.points_to_plateaus(points)
        waveform_to_plateau = mgr.waveform_to_plateaus(waveform)
        self.assertTrue(
            bool((plateaus == points_to_plateau).all()), "Inconsistent plateau conversion"
        )
        self.assertTrue(
            bool((plateaus == waveform_to_plateau).all()), "Inconsistent plateau conversion"
        )

        plateaus_to_waveform, _ = mgr.plateaus_to_waveform(plateaus)
        self.assertTrue(
            bool((waveform == plateaus_to_waveform).all()), "Inconsistent waveform conversion"
        )

    def runTest(self):
        """Run ``full_check`` over several point counts and widths."""
        for shape in ((1, 1), (10, 1), (100, 1), (10, 2), (100, 7)):
            self.full_check(shape)
def __init__(self, configs):
    """Create the ``WaveformManager`` used by this transform.

    Args:
        configs: configuration passed unchanged to ``WaveformManager``.
    """
    waveform_manager = WaveformManager(configs)
    self.mgr = waveform_manager
class PointsToPlateaus:
    """Callable transform that expands points into plateaus.

    Args:
        configs: configuration passed unchanged to ``WaveformManager``.
    """

    def __init__(self, configs):
        manager = WaveformManager(configs)
        self.mgr = manager

    def __call__(self, x):
        """Return ``x`` converted by the manager's ``points_to_plateaus``."""
        convert = self.mgr.points_to_plateaus
        return convert(x)