예제 #1
0
    def test_save_load_meta_parameter(self):
        """Test saving and loading a device with custom parameters.

        A layer is built from an RPU configuration with non-default
        forward, backward, update and device settings, written with
        ``save`` to a temporary file and read back with ``load``. The
        reloaded analog tile must report the same IO/update parameters
        and the same ``is_cuda`` flag as the original tile.
        """
        # Create the device and the array.
        rpu_config = SingleRPUConfig(
            forward=IOParameters(inp_noise=0.321),
            backward=IOParameters(inp_noise=0.456),
            update=UpdateParameters(desired_bl=78),
            device=ConstantStepDevice(w_max=0.987)
        )

        model = self.get_layer(rpu_config=rpu_config)

        # Save the model to a file.
        with TemporaryFile() as file:
            save(model, file)
            # Rewind so that loading starts at the beginning of the file.
            file.seek(0)
            new_model = load(file)

        # Assert over the new model tile parameters.
        new_analog_tile = self.get_analog_tile(new_model)
        analog_tile = self.get_analog_tile(model)

        # NOTE(review): the custom ``w_max`` of the device is configured
        # above but is not asserted here.
        parameters = new_analog_tile.tile.get_parameters()
        self.assertAlmostEqual(parameters.forward_io.inp_noise, 0.321)
        self.assertAlmostEqual(parameters.backward_io.inp_noise, 0.456)
        self.assertAlmostEqual(parameters.update.desired_bl, 78)
        # ``assertEqual`` reports both values on failure, unlike
        # ``assertTrue(a == b)``.
        self.assertEqual(new_analog_tile.is_cuda, analog_tile.is_cuda)
예제 #2
0
class InferenceRPUConfig:
    """RPU configuration for analog tiles that are only used for inference.

    Training is *hardware-aware*: the non-idealities of the forward pass
    are applied, whereas the backward and update passes are ideal.

    At inference time, statistical models of programming, drift and read
    noise can be applied.
    """
    # pylint: disable=too-many-instance-attributes

    bindings_class: ClassVar[Type] = devices.AnalogTileParameter

    forward: IOParameters = field(
        default_factory=IOParameters,
    )
    """Input-output parameters of the forward direction."""

    noise_model: BaseNoiseModel = field(
        default_factory=PCMLikeNoiseModel,
    )
    """Statistical noise model applied during (realistic) inference."""

    drift_compensation: BaseDriftCompensation = field(
        default_factory=GlobalDriftCompensation,
    )
    """Drift compensation, applied during inference only."""

    clip: WeightClipParameter = field(
        default_factory=WeightClipParameter,
    )
    """Weight clip parameters."""

    modifier: WeightModifierParameter = field(
        default_factory=WeightModifierParameter,
    )
    """Weight modifier parameters."""

    # The fields below are excluded from `__init__` (``init=False``) and
    # should be treated as read-only.

    device: IdealDevice = field(
        init=False,
        default_factory=IdealDevice,
    )
    """Pulsed device parameters: always an ideal device."""

    backward: IOParameters = field(
        init=False,
        default_factory=lambda: IOParameters(is_perfect=True),
    )
    """Input-output parameters of the backward direction: perfect."""

    update: UpdateParameters = field(
        init=False,
        default_factory=lambda: UpdateParameters(pulse_type=PulseType.NONE),
    )
    """Update behavior parameters: ``NONE`` pulse type."""

    def as_bindings(self) -> devices.AnalogTileParameter:
        """Convert this configuration into a simulator bindings object."""
        bindings = tile_parameters_to_bindings(self)
        return bindings

    def requires_diffusion(self) -> bool:
        """Tell whether diffusion is enabled on the device."""
        diffusion = self.device.diffusion
        return diffusion > 0.0

    def requires_decay(self) -> bool:
        """Tell whether decay is enabled on the device."""
        lifetime = self.device.lifetime
        return lifetime > 0.0
예제 #3
0
    def set_noise_free(dev: Any) -> Any:
        """Switch off the noise sources of ``dev`` in place.

        Only attributes that already exist on ``dev`` are modified:

        * ``dw_min_std`` is set to ``0.0``.
        * ``refresh_forward`` / ``transfer_forward`` are replaced by
          perfect ``IOParameters``.
        * ``refresh_update`` / ``transfer_update`` are replaced by
          ``UpdateParameters`` with the ``NONE`` pulse type.
        * a positive ``write_noise_std`` is reduced to ``1e-6``.

        Args:
            dev: object whose noise-related attributes are overwritten.

        NOTE(review): the signature is annotated ``-> Any`` but no value
        is returned by the code visible here.
        """
        if hasattr(dev, 'dw_min_std'):
            dev.dw_min_std = 0.0  # Noise free.

        if hasattr(dev, 'refresh_forward'):
            dev.refresh_forward = IOParameters(is_perfect=True)

        if hasattr(dev, 'refresh_update'):
            dev.refresh_update = UpdateParameters(pulse_type=PulseType.NONE)

        if hasattr(dev, 'transfer_forward'):
            # Previously this branch assigned ``refresh_forward`` by
            # mistake, leaving ``transfer_forward`` noisy.
            dev.transfer_forward = IOParameters(is_perfect=True)

        if hasattr(dev, 'transfer_update'):
            dev.transfer_update = UpdateParameters(pulse_type=PulseType.NONE)

        if hasattr(dev, 'write_noise_std') and dev.write_noise_std > 0.0:
            # Just make very small to avoid hidden parameter mismatch.
            dev.write_noise_std = 1e-6
예제 #4
0
    def test_config_tile_parameters(self):
        """Check that tile parameters set on the config reach the bindings.

        The forward, backward and update settings of the RPU config are
        replaced with non-default values, a tile is created through
        ``self.get_tile(11, 22, ...)``, and the parameters reported by the
        underlying tile object are compared against those values.
        """
        config = self.get_rpu_config()

        # Override the settings with easily recognizable values.
        config.forward = IOParameters(inp_noise=0.321)
        config.backward = IOParameters(inp_noise=0.456)
        config.update = UpdateParameters(desired_bl=78)

        # Build the tile and read the parameters back from the bindings.
        bound = self.get_tile(11, 22, config).tile.get_parameters()

        check = self.assertAlmostEqual
        check(bound.forward_io.inp_noise, 0.321)
        check(bound.backward_io.inp_noise, 0.456)
        check(bound.update.desired_bl, 78)