Example #1
def __init__(
        self,
        out_size: int,
        in_size: int,
        rpu_config: Optional[FloatingPointRPUConfig] = None,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
):
    # Fall back to a default floating-point configuration when none is given.
    rpu_config = rpu_config or FloatingPointRPUConfig()
    super().__init__(out_size, in_size, rpu_config, bias, in_trans, out_trans)
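The constructor falls back to a default FloatingPointRPUConfig when no configuration is supplied. A minimal usage sketch, assuming the snippet above is the constructor of aihwkit's FloatingPointTile (the class name is not shown in the excerpt):

# Usage sketch (assumes the snippet above is FloatingPointTile.__init__).
from aihwkit.simulator.tiles import FloatingPointTile

tile = FloatingPointTile(3, 4)  # rpu_config defaults to FloatingPointRPUConfig()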
Example #2
    def get_custom_tile(self, out_size, in_size, **parameters):
        """Return a tile with custom parameters for the resistive device."""
        if 'FloatingPoint' in self.parameter:
            rpu_config = FloatingPointRPUConfig(device=FloatingPointDevice(
                **parameters))
        else:
            rpu_config = SingleRPUConfig(device=ConstantStepDevice(
                **parameters))

        python_tile = self.get_tile(out_size, in_size, rpu_config)
        self.set_init_weights(python_tile)

        return python_tile
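The helper branches on the test parameter to build either a floating-point or an analog (constant-step) configuration before requesting a tile. A standalone sketch of the two branches, with illustrative device parameters:

# Standalone sketch of the two config branches (parameter values are illustrative).
from aihwkit.simulator.configs import FloatingPointRPUConfig, SingleRPUConfig
from aihwkit.simulator.configs.devices import ConstantStepDevice, FloatingPointDevice

fp_config = FloatingPointRPUConfig(device=FloatingPointDevice(diffusion=0.1))
analog_config = SingleRPUConfig(device=ConstantStepDevice(w_max=0.6, w_min=-0.6))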
Example #3
    def test_load_state_load_rpu_config_wrong(self):
        """Test creating a new model using a state dict, while using a different RPU config."""

        # Skipped for FP
        if isinstance(self.get_rpu_config(), FloatingPointRPUConfig):
            raise SkipTest('Not available for FP')

        # Create the device and the array.
        model = self.get_layer()
        state_dict = model.state_dict()

        rpu_config = FloatingPointRPUConfig()

        new_model = self.get_layer(rpu_config=rpu_config)
        assert_raises(ModuleError, new_model.load_state_dict, state_dict, load_rpu_config=False)
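The assertion exercises the failing path: with load_rpu_config=False the checkpoint must be compatible with the model's existing RPU config, so loading an analog state dict into a floating-point layer raises ModuleError. For contrast, a hedged sketch of the non-failing call (get_layer and state_dict come from the harness above):

# Hedged sketch: with load_rpu_config=True the RPU config stored in the
# checkpoint is loaded along with the weights, so no mismatch occurs.
new_model = self.get_layer(rpu_config=FloatingPointRPUConfig())
new_model.load_state_dict(state_dict, load_rpu_config=True)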
Example #4
def __init__(
        self,
        out_size: int,
        in_size: int,
        rpu_config: Optional['FloatingPointRPUConfig'] = None,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
):
    if not rpu_config:
        # Import `FloatingPointRPUConfig` dynamically to avoid import cycles.
        # pylint: disable=import-outside-toplevel
        from aihwkit.simulator.configs import FloatingPointRPUConfig
        rpu_config = FloatingPointRPUConfig()
    super().__init__(out_size, in_size, rpu_config, bias, in_trans, out_trans)
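Unlike Example #1, this variant quotes the type annotation and defers the import to the function body, breaking the circular import between this module and aihwkit.simulator.configs. A generic sketch of the same pattern (mypackage and DefaultConfig are hypothetical names):

def make_config():
    # Deferred import: the cycle is resolved at call time, after both
    # modules have finished loading (mypackage is hypothetical).
    from mypackage.configs import DefaultConfig
    return DefaultConfig()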
Example #5
# Imports assumed to make the excerpt self-contained; PATH_DATASET is
# illustrative (the original definition is not shown).
import os

import torch
from torchvision import datasets, transforms

from aihwkit.nn import AnalogLinear
from aihwkit.simulator.configs import FloatingPointRPUConfig, SingleRPUConfig
from aihwkit.simulator.configs.devices import ConstantStepDevice, FloatingPointDevice

PATH_DATASET = os.path.join('data', 'DATASET')

SEED = 1
N_EPOCHS = 30
BATCH_SIZE = 8
LEARNING_RATE = 0.01
N_CLASSES = 10

# Select the device model to use in the training.
# * If `SingleRPUConfig(device=ConstantStepDevice())` then analog tiles with
#   constant step devices will be used,
# * If `FloatingPointRPUConfig(device=FloatingPointDevice())` then standard
#   floating point devices will be used.
USE_ANALOG_TRAINING = True
if USE_ANALOG_TRAINING:
    RPU_CONFIG = SingleRPUConfig(device=ConstantStepDevice())
else:
    RPU_CONFIG = FloatingPointRPUConfig(device=FloatingPointDevice())


def load_images():
    """Load images for train from torchvision datasets."""

    transform = transforms.Compose([transforms.ToTensor()])
    train_set = datasets.MNIST(PATH_DATASET, download=True, train=True, transform=transform)
    val_set = datasets.MNIST(PATH_DATASET, download=True, train=False, transform=transform)
    train_data = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
    validation_data = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False)

    return train_data, validation_data


def create_analog_network():
    """Create the analog model (minimal sketch).

    Illustrative single-layer model: flattened 28x28 MNIST inputs mapped
    to N_CLASSES outputs; the original example may define a deeper network.
    """
    model = AnalogLinear(28 * 28, N_CLASSES, bias=True, rpu_config=RPU_CONFIG)
    return model
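Put together, a short driver sketch for the pieces above (flattening the images is an assumption of the minimal single-layer model):

# Driver sketch: one forward pass over a single batch.
train_data, _ = load_images()
model = create_analog_network()
images, labels = next(iter(train_data))
logits = model(images.view(images.size(0), -1))  # flatten 28x28 -> 784
print(logits.shape)  # expected: torch.Size([BATCH_SIZE, N_CLASSES])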
Example #6
def get_rpu_config(self):
    return FloatingPointRPUConfig()
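This harness hook pins the parametrized tests to a plain floating-point configuration. A quick sketch of what the default config holds (the printed repr is illustrative):

# Sketch: the default config simply wraps an ideal FloatingPointDevice.
config = FloatingPointRPUConfig()
print(config.device)  # e.g. FloatingPointDevice(diffusion=0.0, lifetime=0.0)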
Example #7
    def test_against_fp(self):
        """Test whether FP is same as is_perfect inference tile."""
        # pylint: disable-msg=too-many-locals
        # Prepare the datasets (input and expected output).
        x = Tensor([[0.1, 0.2, 0.4, 0.3], [0.2, 0.1, 0.1, 0.3]])
        y = Tensor([[1.0, 0.5], [0.7, 0.3]])

        # Define a single-layer network, using a constant step device type.
        rpu_config = self.get_rpu_config()
        rpu_config.forward.is_perfect = True
        model_torch = Linear(4, 2, bias=True)
        model = AnalogLinear(4, 2, bias=True, rpu_config=rpu_config)
        model.set_weights(model_torch.weight, model_torch.bias)
        model_fp = AnalogLinear(4, 2, bias=True,
                                rpu_config=FloatingPointRPUConfig())
        model_fp.set_weights(model_torch.weight, model_torch.bias)

        self.assertTensorAlmostEqual(model.get_weights()[0],
                                     model_torch.weight)
        self.assertTensorAlmostEqual(model.get_weights()[0],
                                     model_fp.get_weights()[0])

        # Move the model and tensors to cuda if it is available.
        if self.use_cuda:
            x = x.cuda()
            y = y.cuda()
            model.cuda()
            model_fp.cuda()
            model_torch.cuda()

        # Define an analog-aware optimizer, preparing it for using the layers.
        opt = AnalogSGD(model.parameters(), lr=0.1)
        opt_fp = AnalogSGD(model_fp.parameters(), lr=0.1)
        opt_torch = SGD(model_torch.parameters(), lr=0.1)

        for _ in range(100):

            # inference
            opt.zero_grad()
            pred = model(x)
            loss = mse_loss(pred, y)
            loss.backward()
            opt.step()

            # same for fp
            opt_fp.zero_grad()
            pred_fp = model_fp(x)
            loss_fp = mse_loss(pred_fp, y)
            loss_fp.backward()
            opt_fp.step()

            # same for torch
            opt_torch.zero_grad()
            pred_torch = model_torch(x)
            loss_torch = mse_loss(pred_torch, y)
            loss_torch.backward()
            opt_torch.step()

            self.assertTensorAlmostEqual(pred_torch, pred)
            self.assertTensorAlmostEqual(loss_torch, loss)
            self.assertTensorAlmostEqual(model.get_weights()[0],
                                         model_torch.weight)

            self.assertTensorAlmostEqual(pred_fp, pred)
            self.assertTensorAlmostEqual(loss_fp, loss)
            self.assertTensorAlmostEqual(model.get_weights()[0],
                                         model_fp.get_weights()[0])
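The comparisons rely on an assertTensorAlmostEqual helper defined by the test base class; a hedged sketch of what such a helper typically looks like (the actual aihwkit version may differ):

# Hedged sketch of the tensor-comparison helper used throughout the test.
from numpy.testing import assert_array_almost_equal

def assertTensorAlmostEqual(self, tensor_a, tensor_b):
    """Assert that two torch tensors are element-wise almost equal."""
    assert_array_almost_equal(tensor_a.detach().cpu().numpy(),
                              tensor_b.detach().cpu().numpy())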