示例#1
0
 def __init__(self, test_name="runTest"):
     """Create the test case and its WaveformManager.

     Args:
         test_name: Name of the test method to run. It is forwarded to the
             parent ``__init__`` instead of being silently ignored.
     """
     super(WaveformTest, self).__init__(test_name)
     configs = {}
     configs["plateau_length"] = 80
     configs["slope_length"] = 20
     self.configs = configs
     self.waveform_mgr = WaveformManager(configs)
示例#2
0
class DataPointsToPlateau:
    """Callable transform that converts an (inputs, targets) pair to plateaus.

    Only the first two items of the incoming sample are used. Each item is
    passed through ``points_to_plateaus`` of a ``WaveformManager`` built from
    *configs*, and the two results are returned as a tuple.
    """

    def __init__(self, configs):
        self.mgr = WaveformManager(configs)

    def __call__(self, data):
        # Index both items before converting either one, so a sample with
        # fewer than two items raises before any conversion runs.
        pair = (data[0], data[1])
        to_plateaus = self.mgr.points_to_plateaus
        return tuple(to_plateaus(item) for item in pair)
示例#3
0
 def __init__(self, configs, logger=None):
     """Set up the driver, voltage ranges, waveform manager and clipping bounds.

     Args:
         configs: Configuration dict. This method reads ``processor_type``,
             ``data.waveform``, ``driver.amplification`` and
             ``driver.output_clipping_range``. The whole dict is also passed
             to ``get_driver``.
         logger: Optional logger, stored on the instance as-is.
     """
     super(HardwareProcessor, self).__init__()
     self.driver = get_driver(configs)
     debug_simulation = configs['processor_type'] == 'simulation_debug'
     driver_ranges = self.driver.voltage_ranges
     if not debug_simulation:
         # Outside the debug simulation, the driver's ranges are converted
         # with TorchUtils.get_tensor_from_numpy before being stored.
         driver_ranges = TorchUtils.get_tensor_from_numpy(driver_ranges)
     self.voltage_ranges = driver_ranges
     self.waveform_mgr = WaveformManager(configs["data"]["waveform"])
     self.logger = logger
     # TODO: Manage amplification from this class
     driver_configs = configs["driver"]
     gain = driver_configs["amplification"]
     self.amplification = gain
     bounds = driver_configs["output_clipping_range"]
     # Scale both ends of the configured output range by the amplification.
     self.clipping_value = [bounds[0] * gain, bounds[1] * gain]
示例#4
0
class HardwareProcessor(nn.Module):
    """
    Expose a hardware (or simulated) driver as a ``torch.nn.Module``.

    ``forward`` works in four steps:

    1. Take plateau inputs and expand them into a full waveform with a
       ``WaveformManager``.
    2. Run that waveform through the driver's ``forward_numpy``.
    3. Keep only the masked (plateau) part of the driver output.
    4. Return that part as a tensor.

    The usage is expected to be as follows::

        processor = HardwareProcessor(configs)
        output = processor(x)
        processor.close()
    """

    # TODO: Automatically register the data type according to the configurations of the amplification variable of the  info dictionary

    def __init__(self, configs, logger=None):
        """
        Args:
            configs: Processor configuration dict. This method reads
                ``processor_type``, ``data.waveform`` (passed to
                ``WaveformManager``), ``driver.amplification`` and
                ``driver.output_clipping_range``. The whole dict is also
                passed to ``get_driver``.
            logger: Optional object with a ``log_output`` method. When given,
                it is called on every forward pass.
        """
        super(HardwareProcessor, self).__init__()
        self.driver = get_driver(configs)
        if configs['processor_type'] == 'simulation_debug':
            # The debug simulation driver's ranges are stored unconverted.
            self.voltage_ranges = self.driver.voltage_ranges
        else:
            # Other drivers presumably expose numpy ranges; convert them to
            # tensors.
            self.voltage_ranges = TorchUtils.get_tensor_from_numpy(
                self.driver.voltage_ranges)
        self.waveform_mgr = WaveformManager(configs["data"]["waveform"])
        self.logger = logger
        # TODO: Manage amplification from this class
        self.amplification = configs["driver"]["amplification"]
        # ``amplification`` is indexed here, so it is expected to be a
        # sequence; only its first entry scales the clipping range.
        self.clipping_value = torch.tensor(
            [
                configs["driver"]["output_clipping_range"][0] *
                self.amplification[0],
                configs["driver"]["output_clipping_range"][1] *
                self.amplification[0],
            ],
            device=TorchUtils.get_accelerator_type())

    def forward(self, x):
        """Run plateau inputs *x* through the driver; return the masked output.

        Gradients are not tracked, because the driver operates on the
        non-PyTorch waveform (``return_pytorch=False``).
        """
        with torch.no_grad():
            x, mask = self.waveform_mgr.plateaus_to_waveform(
                x, return_pytorch=False)
            output = self.forward_numpy(x)
            if self.logger is not None:
                # NOTE(review): this logs the *input* waveform, although the
                # method is named ``log_output``; confirm whether ``output``
                # was intended.
                self.logger.log_output(x)
        return TorchUtils.get_tensor_from_numpy(output[mask])

    def get_clipping_value(self):
        """Return the amplified ``[low, high]`` output clipping bounds."""
        value = self.clipping_value
        # ``.T`` is the identity for 0-D/1-D tensors, and PyTorch deprecates
        # it for anything other than 2-D. Only transpose when there is
        # something to transpose.
        if value.dim() >= 2:
            return value.T
        return value

    def forward_numpy(self, x):
        """Delegate *x* (presumably a numpy array) to the driver's ``forward_numpy``."""
        return self.driver.forward_numpy(x)

    def reset(self):
        """Reset the underlying driver."""
        self.driver.reset()

    def close(self):
        """Close the driver's tasks if it supports that; otherwise print a warning."""
        if hasattr(self.driver, "close_tasks"):
            self.driver.close_tasks()
        else:
            print('Warning: Driver tasks have not been closed.')

    def is_hardware(self):
        """Return the driver's own ``is_hardware()`` answer."""
        return self.driver.is_hardware()
示例#5
0
class WaveformTest(unittest.TestCase):
    """Consistency checks for ``WaveformManager`` conversions.

    Uses a manager configured with ``plateau_length=80`` and
    ``slope_length=20``. Checks that points, plateaus and waveforms convert
    into one another consistently.
    """

    def __init__(self, test_name="runTest"):
        # Forward the requested method name to TestCase instead of dropping
        # it. The default keeps no-argument construction working.
        super(WaveformTest, self).__init__(test_name)
        configs = {}
        configs["plateau_length"] = 80
        configs["slope_length"] = 20
        self.configs = configs
        self.waveform_mgr = WaveformManager(configs)

    def full_check(self, point_no):
        """Check every conversion path for random points of shape *point_no*.

        Args:
            point_no: Shape passed to ``torch.rand``. ``runTest`` uses
                ``(count, channels)`` tuples.
        """
        points = torch.rand(point_no, device=TorchUtils.get_accelerator_type(), dtype=TorchUtils.get_data_type())  # .unsqueeze(dim=1)
        waveform = self.waveform_mgr.points_to_waveform(points)
        assert (
            (waveform[0, :] == 0.0).all() and (waveform[-1, :] == 0.0).all()
        ), "Waveforms do not start and end with zero"
        # One plateau per point, plus one slope before, between and after them.
        assert len(waveform) == (
            (self.waveform_mgr.plateau_length * len(points))
            + (self.waveform_mgr.slope_length * (len(points) + 1))
        ), "Waveform has an incorrect shape"

        mask = self.waveform_mgr.generate_mask(len(waveform))
        assert len(mask) == len(waveform)

        waveform_to_points = self.waveform_mgr.waveform_to_points(waveform)
        plateaus_to_points = self.waveform_mgr.plateaus_to_points(waveform[mask])
        # Compare at half precision to tolerate rounding in the conversions.
        # NOTE(review): values straddling a half-precision rounding boundary
        # can still compare unequal; torch.allclose may be more robust.
        assert (
            (points.half().float() == waveform_to_points.half().float()).all()
            and (points.half().float() == plateaus_to_points.half().float()).all()
        ), "Inconsistent to_point conversion"

        points_to_plateau = self.waveform_mgr.points_to_plateaus(points)
        waveform_to_plateau = self.waveform_mgr.waveform_to_plateaus(waveform)
        # Both conversions must reproduce the masked waveform. Comparing the
        # two results with ``==`` would also pass when both are wrong.
        assert (waveform[mask] == points_to_plateau).all() and (
            waveform[mask] == waveform_to_plateau
        ).all(), "Inconsistent plateau conversion"

        plateaus_to_waveform, _ = self.waveform_mgr.plateaus_to_waveform(
            waveform[mask]
        )
        assert (
            waveform == plateaus_to_waveform
        ).all(), "Inconsistent waveform conversion"

    def runTest(self):
        self.full_check((1, 1))
        self.full_check((10, 1))
        self.full_check((100, 1))
        self.full_check((10, 2))
        self.full_check((100, 7))
示例#6
0
 def __init__(self, configs):
     """Build the WaveformManager used by this transform from *configs*."""
     manager = WaveformManager(configs)
     self.mgr = manager
示例#7
0
class PointsToPlateaus:
    """Callable transform that converts points to plateaus.

    Wraps a ``WaveformManager`` built from *configs* and forwards every call
    to its ``points_to_plateaus`` method.
    """

    def __init__(self, configs):
        manager = WaveformManager(configs)
        self.mgr = manager

    def __call__(self, x):
        converted = self.mgr.points_to_plateaus(x)
        return converted