Example #1
    def __setstate__(self, state: Dict) -> None:
        """Set the state after unpickling.

        This method recreates the ``tile`` member, creating a new one from
        scratch, as the binding Tiles are not serializable.

        Caution:
            RPU configs are overwritten by loading the state.

        Raises:
            TileError: if the tile class or the hidden parameters do not
                match.
        """
        current_dict = state.copy()
        weights = current_dict.pop('analog_tile_weights')
        hidden_parameters = current_dict.pop('analog_tile_hidden_parameters')
        hidden_parameters_names = current_dict.pop('analog_tile_hidden_parameter_names', [])
        alpha_scale = current_dict.pop('analog_alpha_scale')
        tile_class = current_dict.pop('analog_tile_class', self.__class__.__name__)
        self.__dict__.update(current_dict)

        x_size = self.in_size + 1 if self.bias else self.in_size
        d_size = self.out_size

        # Recreate the tile.
        # Check for tile mismatch
        if tile_class != self.__class__.__name__:
            raise TileError(
                'Mismatch of tile class: {} versus {}. Can only load analog '
                'state from the same tile class.'.format(self.__class__.__name__, tile_class))

        self.tile = self._create_simulator_tile(x_size, d_size, self.rpu_config)
        names = self.tile.get_hidden_parameter_names()
        if len(hidden_parameters_names) > 0 and names != hidden_parameters_names:
            # Check whether names match
            raise TileError('Mismatch with loaded analog state: '
                            'Hidden parameter structure is unexpected.')
        self.tile.set_hidden_parameters(Tensor(hidden_parameters))
        self.tile.set_weights(weights)
        if alpha_scale is not None:
            self.tile.set_alpha_scale(alpha_scale)

        # Keep the data (for future use)
        data = self.analog_ctx.data.detach()
        self.analog_ctx = AnalogContext(self)
        self.analog_ctx.set_data(data)

        if self.is_cuda:
            self.cuda(self.device)

        self.ensure_shared_weights()
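A minimal round-trip sketch of what this method enables (assuming an AnalogLinear layer from aihwkit.nn; pickling the layer goes through __getstate__/__setstate__, so the simulator tile is rebuilt on load and the stored weights and hidden parameters are restored):

import pickle

from aihwkit.nn import AnalogLinear

# Pickle and restore an analog layer; unpickling rebuilds the simulator tile.
layer = AnalogLinear(4, 3)
restored = pickle.loads(pickle.dumps(layer))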
Example #2
    def forward_indexed(self,
                        x_input: Tensor,
                        is_test: bool = False) -> Tensor:
        """Perform the forward pass for convolutions. Depending on the input tensor size
        it performs the forward pass for a 2D image or a 3D one.

        Args:
            x_input: ``[N, in_size]`` tensor. If ``in_trans`` is set, transposed.
            is_test: whether to assume testing mode.

        Returns:
            torch.Tensor: ``[N, out_size]`` tensor. If ``out_trans`` is set, transposed.

        Raises:
            TileError: if the indexed tile has not been initialized.
        """
        if not self.image_sizes:
            raise TileError('self.image_sizes is not initialized. Please use '
                            'set_indexed()')

        n_batch = x_input.size(0)
        channel_out = self.out_size

        if len(self.image_sizes) == 5:
            _, _, _, height_out, width_out = self.image_sizes
            d_tensor = empty(n_batch, channel_out, height_out, width_out)

        if len(self.image_sizes) == 7:
            _, _, _, _, depth_out, height_out, width_out = self.image_sizes
            d_tensor = empty(n_batch, channel_out, depth_out, height_out,
                             width_out)

        return self.tile.forward_indexed(x_input, d_tensor, is_test)
Example #3
    def forward_indexed(self,
                        x_input: Tensor,
                        is_test: bool = False) -> Tensor:
        """Perform the forward pass for convolutions.

        Depending on the input tensor size it performs the forward pass for a
        2D image or a 3D one.

        Args:
            x_input: ``[N, in_size]`` tensor. If ``in_trans`` is set, transposed.
            is_test: whether to assume testing mode.

        Returns:
            torch.Tensor: ``[N, out_size]`` tensor. If ``out_trans`` is set, transposed.

        Raises:
            TileError: if the indexed tile has not been initialized, or if
                ``self.image_sizes`` does not have a valid dimension.
        """
        if not self.image_sizes:
            raise TileError('self.image_sizes is not initialized. Please use '
                            'set_indexed()')

        n_batch = x_input.size(0)
        channel_out = self.out_size

        if len(self.image_sizes) == 3:
            _, _, height_out = self.image_sizes
            d_tensor = empty(n_batch, channel_out, height_out)

        elif len(self.image_sizes) == 5:
            _, _, _, height_out, width_out = self.image_sizes
            d_tensor = empty(n_batch, channel_out, height_out, width_out)

        elif len(self.image_sizes) == 7:
            _, _, _, _, depth_out, height_out, width_out = self.image_sizes
            d_tensor = empty(n_batch, channel_out, depth_out, height_out,
                             width_out)
        else:
            raise TileError('self.image_sizes length is not 3, 5 or 7')

        # Move helper tensor to cuda if needed.
        if self.is_cuda:
            d_tensor = d_tensor.to(self.device)

        return self.tile.forward_indexed(x_input, d_tensor, is_test)
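The dispatch on ``len(self.image_sizes)`` above is shape bookkeeping: lengths 3, 5 and 7 correspond to 1D, 2D and 3D convolutions. A standalone sketch of the same mapping (the helper name is illustrative only):

import torch

def output_buffer(image_sizes, n_batch, channel_out):
    """Build the output helper tensor for 1D/2D/3D indexed convolutions."""
    if len(image_sizes) == 3:
        *_, height_out = image_sizes
        return torch.empty(n_batch, channel_out, height_out)
    if len(image_sizes) == 5:
        *_, height_out, width_out = image_sizes
        return torch.empty(n_batch, channel_out, height_out, width_out)
    if len(image_sizes) == 7:
        *_, depth_out, height_out, width_out = image_sizes
        return torch.empty(n_batch, channel_out, depth_out, height_out, width_out)
    raise ValueError('image_sizes length is not 3, 5 or 7')

# Example: a 2D convolution with 16 output channels and 28x28 output maps.
buf = output_buffer([3, 30, 30, 28, 28], n_batch=4, channel_out=16)
assert buf.shape == (4, 16, 28, 28)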
Example #4
    def backward_indexed(self, d_input: Tensor) -> Tensor:
        """Perform the backward pass for convolutions.

        Depending on the input tensor size it performs the backward pass for a
        2D image or a 3D one.

        Args:
            d_input: ``[N, out_size]`` tensor. If ``out_trans`` is set, transposed.

        Returns:
            torch.Tensor: ``[N, in_size]`` tensor. If ``in_trans`` is set, transposed.

        Raises:
            TileError: if the indexed tile has not been initialized, or if
                ``self.image_sizes`` does not have a valid dimension.
        """
        if not self.image_sizes:
            raise TileError('self.image_sizes is not initialized. Please use '
                            'set_indexed()')

        n_batch = d_input.size(0)

        if len(self.image_sizes) == 3:
            channel_in, height_in, _ = self.image_sizes
            x_tensor = empty(n_batch, channel_in, height_in)

        elif len(self.image_sizes) == 5:
            channel_in, height_in, width_in, _, _ = self.image_sizes
            x_tensor = empty(n_batch, channel_in, height_in, width_in)

        elif len(self.image_sizes) == 7:
            channel_in, depth_in, height_in, width_in, _, _, _ \
                = self.image_sizes
            x_tensor = empty(n_batch, channel_in, depth_in, height_in,
                             width_in)
        else:
            raise TileError('self.image_sizes length is not 3, 5 or 7')

        # Move helper tensor to cuda if needed.
        if self.is_cuda:
            x_tensor = x_tensor.to(self.device)

        return self.tile.backward_indexed(d_input, x_tensor)
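In practice these indexed passes are not called directly; they are driven by the analog convolution layers. A minimal end-to-end sketch (assuming AnalogConv2d from aihwkit.nn, which uses the indexed forward/backward paths internally):

import torch

from aihwkit.nn import AnalogConv2d

conv = AnalogConv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
x = torch.randn(4, 3, 8, 8)
y = conv(x)         # indexed forward pass on the tile
y.sum().backward()  # indexed backward pass on the tile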
Example #5
    def set_indexed(self, indices: Tensor, image_sizes: List) -> None:
        """Sets the index matrix for convolutions ans switches to
        indexed forward/backward/update versions.

        Args:
            indices : torch.tensor with int indices
            image_sizes: [C_in, H_in, W_in, H_out, W_out] sizes
        """
        if len(image_sizes) != 5:
            raise ValueError(
                'image_sizes expects 5 sizes [C_in, H_in, W_in, H_out, W_out]')

        if self.in_trans or self.out_trans:
            raise TileError(
                'Transposed indexed versions not supported (assumes NCHW)')

        self.image_sizes = image_sizes
        self.tile.set_matrix_indices(indices)
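The index matrix encodes which input pixel feeds which position of the unrolled convolution. One way to derive such an index matrix is to unfold an index image, the usual im2col trick; a hedged sketch for the 2D case follows (the exact index convention expected by the simulator tile is an assumption here, and the set_indexed call is shown commented out):

import torch
from torch.nn.functional import unfold

# Hypothetical illustration of building image_sizes and a gather-index matrix
# for a 2D convolution; the tile's exact index convention may differ.
c_in, h_in, w_in = 3, 8, 8
kernel, stride, padding = 3, 1, 1

h_out = (h_in + 2 * padding - kernel) // stride + 1
w_out = (w_in + 2 * padding - kernel) // stride + 1

# Index image: 0 marks padded positions, 1..H*W mark real pixel positions.
index_image = torch.arange(1, h_in * w_in + 1, dtype=torch.float32)
index_image = index_image.reshape(1, 1, h_in, w_in)
indices = unfold(index_image, kernel_size=kernel, stride=stride, padding=padding)

image_sizes = [c_in, h_in, w_in, h_out, w_out]
# tile.set_indexed(indices.flatten().int(), image_sizes)  # hypothetical call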
Example #6
    def forward_indexed(self,
                        x_input: Tensor,
                        is_test: bool = False) -> Tensor:
        """Perform the forward pass for convolutions.

        Args:
            x_input: ``[N, in_size]`` tensor. If ``in_trans`` is set, transposed.
            is_test: whether to assume testing mode.

        Returns:
            torch.Tensor: ``[N, out_size]`` tensor. If ``out_trans`` is set, transposed.
        """
        if not self.image_sizes:
            raise TileError('self.image_sizes is not initialized. Please use '
                            'set_indexed()')

        _, _, _, height_out, width_out = self.image_sizes
        return self.tile.forward_indexed(x_input, height_out, width_out,
                                         is_test)
Example #7
    def set_indexed(self, indices: Tensor, image_sizes: List) -> None:
        """Set the index matrix for convolutions ans switches to
        indexed forward/backward/update versions.

        Args:
            indices : torch.tensor with int indices
            image_sizes: [C_in, H_in, W_in, H_out, W_out] sizes

        Raises:
            ValueError: if ``image_sizes`` does not have valid dimensions.
            TileError: if the tile uses transposition.
        """
        if len(image_sizes) not in (3, 5, 7):
            raise ValueError('image_sizes expects 3, 5 or 7 sizes '
                             '[C_in, (D_in), H_in, (W_in), (D_out), H_out, (W_out)]')

        if self.in_trans or self.out_trans:
            raise TileError('Transposed indexed versions not supported (assumes NC(D)HW)')

        self.image_sizes = image_sizes
        self.tile.set_matrix_indices(indices)
Example #8
    def set_hidden_parameters(self, ordered_parameters: OrderedDict) -> None:
        """Set the hidden parameters of the tile.

        Caution:
            Usually the hidden parameters are drawn according to the
            parameter definitions (those given in the RPU config). If
            the hidden parameters are set arbitrarily by the user, this
            correspondence might be broken. This can cause problems
            during learning; in particular, the `weight granularity`
            (usually ``dw_min``, depending on the device) is needed for
            the dynamic adjustment of the bit length
            (``update_bl_management``, see
            :class:`~aihwkit.simulator.configs.utils.UpdateParameters`).

            Currently, an attempt is made to estimate the new ``dw_min``
            parameter from the average of the hidden parameters if the
            discrepancy with the ``dw_min`` from the definition is too
            large.

        Args:
            ordered_parameters: Ordered dictionary of hidden parameter tensors.

        Raises:
            TileError: if the ordered dict keys do not conform to the
                hidden parameter structure of the current RPU config
                tile.
        """
        if len(ordered_parameters) == 0:
            return

        hidden_parameters = stack(list(ordered_parameters.values()), dim=0)
        names = self.tile.get_hidden_parameter_names()
        if names != list(ordered_parameters.keys()):
            raise TileError('Mismatch with loaded analog state: '
                            'Hidden parameter structure is unexpected.')

        self.tile.set_hidden_parameters(hidden_parameters)
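A minimal round-trip sketch (assuming get_hidden_parameters on the same tile returns the matching OrderedDict, as the aihwkit simulator tiles do; the parameter names depend on the device in the RPU config):

from aihwkit.simulator.tiles import AnalogTile

tile = AnalogTile(out_size=3, in_size=4)
params = tile.get_hidden_parameters()  # OrderedDict of hidden parameter tensors
print(list(params.keys()))             # device-dependent names
tile.set_hidden_parameters(params)     # keys and order must match the tile's own names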
Example #9
    def __setstate__(self, state: Dict) -> None:
        """Set the state after unpickling.

        This method recreates the ``tile`` member, creating a new one from
        scratch, as the binding Tiles are not serializable.

        Caution:
            RPU configs are overwritten by loading the state.

        Raises:
            TileError: if the tile class or the hidden parameters do not
                match.
        """
        # pylint: disable=too-many-locals

        # Note: self is NOT initialized here, so we need to recreate the
        # attributes that were not saved in __getstate__.

        current_dict = state.copy()
        weights = current_dict.pop('analog_tile_weights')
        hidden_parameters = current_dict.pop('analog_tile_hidden_parameters')
        hidden_parameters_names = current_dict.pop(
            'analog_tile_hidden_parameter_names', [])
        alpha_scale = current_dict.pop('analog_alpha_scale', None)
        tile_class = current_dict.pop('analog_tile_class',
                                      self.__class__.__name__)
        analog_lr = current_dict.pop('analog_lr', 0.01)
        analog_ctx = current_dict.pop('analog_ctx')
        shared_weights = current_dict.pop('shared_weights')
        shared_weights_if = shared_weights is not None

        self.__dict__.update(current_dict)

        self.device = torch_device('cpu')
        self.is_cuda = False
        # get the current map location from analog_ctx (which is restored)
        to_device = analog_ctx.device

        # recreate attributes not saved
        # always first create on CPU
        x_size = self.in_size + 1 if self.bias else self.in_size
        d_size = self.out_size

        # Recreate the tile.
        # Check for tile mismatch
        if tile_class != self.__class__.__name__:
            raise TileError(
                'Mismatch of tile class: {} versus {}. Can only load analog '
                'state from the same tile class.'.format(
                    self.__class__.__name__, tile_class))

        self.tile = self._create_simulator_tile(x_size, d_size,
                                                self.rpu_config)
        names = self.tile.get_hidden_parameter_names()
        if (len(hidden_parameters_names) > 0
                and names != hidden_parameters_names):
            # Check whether names match
            raise TileError('Mismatch with loaded analog state: '
                            'Hidden parameter structure is unexpected.')
        self.tile.set_hidden_parameters(Tensor(hidden_parameters))
        self.tile.set_weights(weights)

        self.tile.set_learning_rate(analog_lr)

        # re-generate shared weights (CPU)
        if shared_weights_if:
            if not hasattr(self, 'shared_weights'):
                # this is needed when pkl loading
                self.shared_weights = shared_weights

            with no_grad():
                # Always create a new tensor; it will be populated when the
                # weights are set.
                self.shared_weights.data = zeros(d_size,
                                                 x_size,
                                                 requires_grad=True)
            self.ensure_shared_weights()
        else:
            self.shared_weights = None

        # Regenerate context but keep the object ID
        if not hasattr(self, 'analog_ctx'):  # when loading
            self.analog_ctx = AnalogContext(self, parameter=analog_ctx)
        self.analog_ctx.reset(self)
        self.analog_ctx.set_data(analog_ctx.data)

        if to_device.type.startswith('cuda'):
            self.cuda(to_device)

        if alpha_scale is not None:
            # Legacy: we apply the alpha scale instead of the
            # out_scaling_alpha when loading. The alpha_scale
            # mechanism is now replaced by the out scaling factors.
            #
            # Caution: this will overwrite the loaded out_scaling_alphas
            # if they also exist (they should not, for old checkpoints).

            self.set_out_scaling_alpha(alpha_scale)
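A minimal checkpoint sketch (assuming an AnalogLinear layer whose state_dict carries the analog tile state; loading the state_dict eventually invokes this __setstate__, which rebuilds the simulator tile, restores weights, hidden parameters and the analog context, and moves the tile back to its original device):

import torch

from aihwkit.nn import AnalogLinear

model = AnalogLinear(4, 3)
torch.save(model.state_dict(), 'analog_model.pth')
model.load_state_dict(torch.load('analog_model.pth'))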