def recalculate_indexes(self, x_input: Tensor) -> None:
    """Calculate and set the indexes of the analog tile.

    Args:
        x_input: the input tensor.

    Raises:
        ModuleError: in case the input is not at least 3 dimensional
    """
    self.input_size = x_input.numel() / x_input.size(0)
    if x_input.ndim < 3:
        raise ModuleError("Expect >2-dim inputs to convolutions")
    channel_dim = 1
    self.fold_indices_lst = []
    splits = split(x_input, self.in_sizes, dim=channel_dim)
    for x, in_channels, in_tiles in zip(splits, self.in_sizes, self.analog_tile_array):
        fold_indices, image_sizes, _ = self._calculate_indexes(x, in_channels)
        self.fold_indices_lst.append(fold_indices)

        for analog_tile in in_tiles:
            analog_tile.set_indexed(fold_indices, image_sizes)
def prepare_for_ddp(self) -> None:
    """Add ignores to avoid broadcasting the analog tile states in case of
    distributed training.

    Note:
        Call this function before the model is converted with DDP.

    Important:
        Only InferenceTile supports DDP.

    Raises:
        ModuleError: In case analog tiles are used that do not support the
            data-parallel model, i.e. all analog training tiles.
    """
    # pylint: disable=attribute-defined-outside-init
    exclude_list = []
    for module in self.modules():
        if isinstance(module, AnalogModuleBase):
            for analog_tile in module.analog_tiles():
                if analog_tile.shared_weights is None:
                    raise ModuleError("DDP is only supported with shared weights "
                                      "(e.g. InferenceTile)")
            exclude_list += [module.ANALOG_CTX_PREFIX, module.ANALOG_STATE_PREFIX]

    exclude_list = list(set(exclude_list))
    params = self.state_dict().keys()
    exclude_params = []
    for param in params:
        for word in exclude_list:
            if word in param and word not in exclude_params:
                exclude_params.append(param)
                break

    self._ddp_params_and_buffers_to_ignore = exclude_params
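# Usage sketch (not part of the library source): ``prepare_for_ddp`` has to be
# called on the analog model *before* wrapping it with DistributedDataParallel,
# and only works for models whose tiles use shared weights (inference tiles).
# ``build_inference_model`` is a hypothetical factory standing in for any model
# built from analog layers configured for inference.
def _example_prepare_for_ddp(build_inference_model):
    from torch.nn.parallel import DistributedDataParallel

    model = build_inference_model()
    # Register the analog tile state keys as DDP ignores before conversion.
    model.prepare_for_ddp()
    return DistributedDataParallel(model)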
def get_split_sizes(self, size: int, split_max_size: int,
                    group_size: int = 1) -> List[int]:
    """Compute the split sizes across channels.

    Args:
        size: number of elements of the layer in one dimension
        split_max_size: max size of the split
        group_size: minimal size of features that needs to stay on one tile

    Returns:
        List of split sizes (in split groups)

    Raises:
        ModuleError: Tiling weight matrices is always done across channels
            only. If the group_size is larger than the maximal tile size,
            mapping cannot be done
    """
    if split_max_size <= 0:
        return [size // group_size]

    if group_size > split_max_size:
        raise ModuleError("Tile size too small to fit a single group (kernel): "
                          f"{group_size} > {split_max_size}")

    size_per_group = size // group_size
    split_max_per_group = split_max_size // group_size

    n_splits = (size_per_group + split_max_per_group - 1) // split_max_per_group
    base, extra = divmod(size_per_group, n_splits)
    return [(base + (i < extra)) for i in range(n_splits)]
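# Worked example (sketch, assuming the method above is available on a layer):
# a conv layer with 150 input channels and a 3x3 kernel (group_size = 9 kernel
# elements) mapped onto tiles with at most 512 input columns, and a linear
# layer with 784 inputs on the same tile size.
#
#     layer.get_split_sizes(150 * 9, 512, 9)   # -> [50, 50, 50]
#     layer.get_split_sizes(784, 512)          # -> [392, 392]
#
# For the conv case: size_per_group = 150 channels, split_max_per_group =
# 512 // 9 = 56, so n_splits = ceil(150 / 56) = 3; each split then occupies
# 50 * 9 = 450 <= 512 tile columns, and divmod distributes any remainder
# over the first splits so the pieces stay balanced.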
def program_analog_weights(self) -> None:
    """Program the analog weights.

    Raises:
        ModuleError: if the layer is not in evaluation mode.
    """
    if self.training:
        raise ModuleError('program_analog_weights can only be applied in '
                          'evaluation mode')

    if isinstance(self.analog_tile, InferenceTile):
        self.analog_tile.program_weights()
def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        rpu_config: Optional[RPUConfigAlias] = None,
        realistic_read_write: bool = False,
        weight_scaling_omega: Optional[float] = None,
):
    # Call super() after tile creation, including ``reset_parameters``.
    Linear.__init__(self, in_features, out_features, bias=bias)

    # Create tiles
    if rpu_config is None:
        rpu_config = SingleRPUConfig()

    AnalogModuleBase.__init__(
        self,
        in_features,
        out_features,
        bias,
        realistic_read_write,
        rpu_config.mapping
    )
    if self.analog_bias:
        raise ModuleError("AnalogLinearMapped only supports digital bias.")

    # More than one tile may need to be created. If so, divide
    # weight matrix into equal pieces along input dimension with
    # as many tiles as needed
    max_input_size = rpu_config.mapping.max_input_size
    max_output_size = rpu_config.mapping.max_output_size

    self.in_sizes = self.get_split_sizes(in_features, max_input_size)
    self.out_sizes = self.get_split_sizes(out_features, max_output_size)

    self.analog_tile_array = []
    for i, in_tile_size in enumerate(self.in_sizes):
        in_tiles = []
        for j, out_tile_size in enumerate(self.out_sizes):
            tile = rpu_config.tile_class(out_tile_size,
                                         in_tile_size,
                                         rpu_config,
                                         bias=self.analog_bias)
            self.register_analog_tile(tile, name=f"{i}_{j}")
            in_tiles.append(tile)
        self.analog_tile_array.append(in_tiles)

    # Set weights from the reset_parameters
    self.set_weights(self.weight, self.bias, remap_weights=True,
                     weight_scaling_omega=weight_scaling_omega)

    # Unregister weight/bias as a parameter but keep for sync
    self.unregister_parameter('weight')
    if self.analog_bias:
        self.unregister_parameter('bias')
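# Construction sketch (hypothetical sizes, assuming the standard aihwkit entry
# points ``AnalogLinearMapped``, ``InferenceRPUConfig`` and ``MappingParameter``):
# a 784 x 256 layer with max_input_size=512 and max_output_size=128 yields
# in_sizes = [392, 392] and out_sizes = [128, 128], i.e. a 2 x 2 grid of analog
# tiles covering the full weight matrix.
def _example_linear_mapped():
    from aihwkit.nn import AnalogLinearMapped
    from aihwkit.simulator.configs import InferenceRPUConfig
    from aihwkit.simulator.configs.utils import MappingParameter

    rpu_config = InferenceRPUConfig(
        mapping=MappingParameter(max_input_size=512, max_output_size=128))
    return AnalogLinearMapped(784, 256, bias=True, rpu_config=rpu_config)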
def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str]) -> None:
    """Copy parameters and buffers from ``state_dict`` into only this
    module, but not its descendants.

    This method is a specialization of ``Module._load_from_state_dict``
    that takes into account the extra ``analog_tile_state`` key used by
    analog layers.

    Raises:
        ModuleError: in case the rpu_config class mismatches.
    """
    for name, analog_tile in list(self.named_analog_tiles()):
        key = prefix + self.ANALOG_STATE_PREFIX + name
        if key not in state_dict:
            # legacy
            key = prefix + 'analog_tile_state'

        if key in state_dict:
            analog_state = state_dict.pop(key).copy()

            if not self._load_rpu_config:
                if analog_tile.rpu_config.__class__ != analog_state['rpu_config'].__class__:
                    raise ModuleError(
                        "RPU config mismatch during loading: "
                        "Tried to replace "
                        f"{analog_state['rpu_config'].__class__.__name__} "
                        f"with {analog_tile.rpu_config.__class__.__name__}")
                analog_state['rpu_config'] = analog_tile.rpu_config

            analog_tile.__setstate__(analog_state)
        elif strict:
            missing_keys.append(key)

    # update the weight / analog bias (not saved explicitly)
    self._sync_weights_from_tile()

    # remove helper parameters
    rm_keys = []
    for par_name in self._registered_helper_parameter:
        key = prefix + par_name
        if key in state_dict:
            state_dict.pop(key)
            rm_keys.append(key)

    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs)

    # remove the missing keys of the helper parameters
    for key in rm_keys:
        missing_keys.remove(key)
def program_analog_weights(self) -> None:
    """Program all analog inference layers of a given model.

    Raises:
        ModuleError: if the layer is not in evaluation mode.
    """
    if self.training:
        raise ModuleError('program_analog_weights can only be applied in '
                          'evaluation mode')

    self._apply_to_analog(lambda m: m.program_weights())
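# Usage sketch (``model`` is a hypothetical analog inference model): programming
# is only allowed in evaluation mode, so the model is switched with ``eval()``
# first; the call then applies the programming (write) noise of the inference
# tiles once.
def _example_program(model):
    model.eval()
    model.program_analog_weights()
    return model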
def program_analog_weights(self) -> None:
    """Program the analog weights.

    Raises:
        ModuleError: if the layer is not in evaluation mode.
    """
    if self.training:
        raise ModuleError('program_analog_weights can only be applied in '
                          'evaluation mode')

    for analog_tile in self.analog_tiles():
        if isinstance(analog_tile, InferenceTile):
            analog_tile.program_weights()
def drift_analog_weights(self, t_inference: float = 0.0) -> None:
    """(Program) and drift the analog weights.

    Args:
        t_inference: assumed time of inference (in sec)

    Raises:
        ModuleError: if the layer is not in evaluation mode.
    """
    if self.training:
        raise ModuleError('drift_analog_weights can only be applied in '
                          'evaluation mode')

    if isinstance(self.analog_tile, InferenceTile):
        self.analog_tile.drift_weights(t_inference)
def unregister_parameter(self, param_name: str) -> None:
    """Unregister module parameter from parameters.

    Args:
        param_name: name of the parameter to unregister

    Raises:
        ModuleError: In case parameter is not found
    """
    param = getattr(self, param_name, None)
    if not isinstance(param, Parameter):
        raise ModuleError(f"Cannot find parameter {param_name} to unregister")

    param_data = param.detach().clone()
    delattr(self, param_name)
    setattr(self, param_name, param_data)
def drift_analog_weights(self, t_inference: float = 0.0) -> None:
    """(Program) and drift all analog inference layers of a given model.

    Args:
        t_inference: assumed time of inference (in sec)

    Raises:
        ModuleError: if the layer is not in evaluation mode.
    """
    if self.training:
        raise ModuleError('drift_analog_weights can only be applied in '
                          'evaluation mode')

    self._apply_to_analog(lambda m: m.drift_analog_weights(t_inference))
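# Usage sketch (``evaluate_fn`` is a hypothetical evaluation callback): program
# the inference model once, then drift the analog weights to several assumed
# inference times (in seconds) before evaluating. Both calls require the model
# to be in evaluation mode.
def _example_drift_evaluation(model, evaluate_fn, t_inferences=(0.0, 3600.0, 86400.0)):
    model.eval()
    model.program_analog_weights()
    results = {}
    for t_inference in t_inferences:
        model.drift_analog_weights(t_inference)
        results[t_inference] = evaluate_fn(model)
    return results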
def _load_from_state_dict(
        self,
        state_dict: Dict,
        prefix: str,
        local_metadata: Dict,
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str]) -> None:
    """Copy parameters and buffers from ``state_dict`` into only this
    module, but not its descendants.

    This method is a specialization of ``Module._load_from_state_dict``
    that takes into account the extra ``analog_tile_state`` key used by
    analog layers.

    Raises:
        ModuleError: in case the rpu_config class mismatches.
    """
    key = '{}analog_tile_state'.format(prefix)
    if key in state_dict:
        analog_state = state_dict.pop(key).copy()

        if not self._load_rpu_config:
            if self.analog_tile.rpu_config.__class__ != analog_state['rpu_config'].__class__:
                raise ModuleError(
                    "RPU config mismatch during loading: "
                    "Tried to replace "
                    f"{analog_state['rpu_config'].__class__.__name__} "
                    f"with {self.analog_tile.rpu_config.__class__.__name__}")
            analog_state['rpu_config'] = self.analog_tile.rpu_config

        self.analog_tile.__setstate__(analog_state)
    elif strict:
        missing_keys.append(key)

    # update the weight / bias (not saved explicitly)
    self._sync_weights_from_tile()

    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                  missing_keys, unexpected_keys, error_msgs)
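# Usage sketch (assuming a public ``load_state_dict(..., load_rpu_config=...)``
# wrapper that sets ``self._load_rpu_config`` before dispatching to the method
# above): with ``load_rpu_config=False`` the tile state is restored from the
# checkpoint while the RPU configuration of the freshly constructed model is
# kept, and a ModuleError is raised if the two configuration classes differ.
def _example_checkpoint_roundtrip(model, checkpoint_path='analog_model.pth'):
    import torch

    torch.save(model.state_dict(), checkpoint_path)
    state_dict = torch.load(checkpoint_path)
    model.load_state_dict(state_dict, load_rpu_config=False)
    return model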
def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, ...],
        stride: Tuple[int, ...],
        padding: Tuple[int, ...],
        dilation: Tuple[int, ...],
        transposed: bool,
        output_padding: Tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: str,
        rpu_config: Optional[RPUConfigAlias] = None,
        realistic_read_write: bool = False,
        weight_scaling_omega: Optional[float] = None,
):
    # pylint: disable=too-many-arguments, too-many-locals
    if groups != 1:
        raise ValueError('Only one group is supported')

    if padding_mode != 'zeros':
        raise ValueError('Only "zeros" padding mode is supported')

    # Call super() after tile creation, including ``reset_parameters``.
    _ConvNd.__init__(self, in_channels, out_channels, kernel_size, stride,
                     padding, dilation, transposed, output_padding, groups,
                     bias, padding_mode)

    # Create tiles
    if rpu_config is None:
        rpu_config = SingleRPUConfig()

    AnalogModuleBase.__init__(
        self,
        self.get_tile_size(in_channels, groups, kernel_size),
        out_channels,
        bias,
        realistic_read_write,
        rpu_config.mapping
    )
    if self.analog_bias:
        raise ModuleError("AnalogConvNdMapped only supports digital bias.")

    max_input_size = rpu_config.mapping.max_input_size
    max_output_size = rpu_config.mapping.max_output_size
    kernel_elem = self.in_features // self.in_channels
    self.in_sizes = self.get_split_sizes(self.in_features, max_input_size, kernel_elem)
    self.out_sizes = self.get_split_sizes(self.out_features, max_output_size)

    self.analog_tile_array = []
    for i, in_tile_size in enumerate(self.in_sizes):
        in_tiles = []
        for j, out_tile_size in enumerate(self.out_sizes):
            tile = rpu_config.tile_class(out_tile_size,
                                         in_tile_size * kernel_elem,
                                         rpu_config,
                                         bias=self.analog_bias)
            self.register_analog_tile(tile, name=f"{i}_{j}")
            in_tiles.append(tile)
        self.analog_tile_array.append(in_tiles)

    # Set weights from the reset_parameters (since now the
    # analog_tiles are registered)
    self.set_weights(self.weight, self.bias, remap_weights=True,
                     weight_scaling_omega=weight_scaling_omega)

    # Set the index matrices.
    self.input_size = 0
    self.fold_indices_lst = []  # type: List[Tensor]

    # Unregister weight/bias as a parameter but keep it as a
    # field (needed for syncing still)
    self.unregister_parameter('weight')
    if self.analog_bias:
        self.unregister_parameter('bias')
def get_weights(
        self,
        force_exact: bool = False,
        apply_out_scales: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Get the weight (and bias) tensors.

    This uses a realistic read if the property ``realistic_read_write``
    of the layer is set, unless it is overwritten by ``force_exact``. It
    scales the analog weights by the digital alpha scale if
    ``weight_scaling_omega`` is positive (see
    :meth:`~aihwkit.simulator.tiles.base.BaseTile.get_weights_scaled`).

    Note:
        This is the recommended way for getting the weight/bias matrix
        from the analog tile, as it will correctly fetch the weights from
        the internal memory. Accessing ``self.weight`` and ``self.bias``
        might yield wrong results as they are not always in sync with the
        analog tile library, for performance reasons.

    Args:
        force_exact: Forces an exact read to the analog tiles
        apply_out_scales: Whether to return the weights with the
            (digital) output scaling factors applied. Note that the
            "logical" weights of the layer, which the DNN is effectively
            using, are those with the output scales applied. If
            ``apply_out_scales`` is set to False, then only the weight
            values that are programmed onto the crossbar array are
            returned, without applying the digital scales.

    Returns:
        tuple: weight matrix, bias vector

    Raises:
        ModuleError: in case of multiple defined analog tiles in the module
    """
    analog_tiles = list(self.analog_tiles())
    if len(analog_tiles) != 1:
        raise ModuleError("AnalogModuleBase.get_weights only supports a single tile.")
    analog_tile = analog_tiles[0]

    realistic = self.realistic_read_write and not force_exact
    if apply_out_scales:
        weight, analog_bias = analog_tile.get_weights_scaled(realistic=realistic)
    else:
        weight, analog_bias = analog_tile.get_weights(realistic=realistic)

    digital_bias = None
    if self.digital_bias:
        with no_grad():
            digital_bias = self.bias.data.clone().detach().cpu()

    if (digital_bias is not None) and (analog_bias is not None):
        bias = digital_bias + analog_bias
    elif digital_bias is not None:
        bias = digital_bias
    else:
        bias = analog_bias

    return weight, bias
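# Usage sketch (``layer`` is any analog layer exposing the method above): read
# back the "logical" weights (with digital output scales applied) versus the
# raw values programmed onto the crossbar, and an exact read that bypasses the
# realistic read-out.
def _example_inspect_weights(layer):
    logical_weight, bias = layer.get_weights()                  # scaled, as used by the DNN
    raw_weight, _ = layer.get_weights(apply_out_scales=False)   # values on the crossbar
    exact_weight, _ = layer.get_weights(force_exact=True)       # bypass realistic read
    return logical_weight, raw_weight, exact_weight, bias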
def set_weights(self, weight: Tensor, bias: Optional[Tensor] = None,
                force_exact: bool = False, remap_weights: bool = True,
                weight_scaling_omega: Optional[float] = None) -> None:
    """Set the weight (and bias) values with given tensors.

    This uses a realistic write if the property ``realistic_read_write``
    of the layer is set, unless it is overwritten by ``force_exact``. If
    ``weight_scaling_omega`` is larger than 0, the weights are set in a
    scaled manner (assuming a digital output scale). See
    :meth:`~aihwkit.simulator.tiles.base.BaseTile.set_weights_scaled`
    for details.

    Note:
        This is the recommended way for setting the weight/bias matrix of
        the analog tile, as it will correctly store the weights into the
        internal memory. Directly writing to ``self.weight`` and
        ``self.bias`` might yield wrong results as they are not always in
        sync with the analog tile Parameters, for performance reasons.

    Args:
        weight: weight matrix
        bias: bias vector
        force_exact: forces an exact write to the analog tiles
        remap_weights: Whether to rescale the given weight matrix and
            populate the digital output scaling factors as specified in
            the configuration
            :class:`~aihwkit.configs.utils.MappingParameter`. A new
            ``weight_scaling_omega`` can be given. Note that this will
            overwrite the existing digital out scaling factors.
        weight_scaling_omega: The weight scaling omega factor (see
            :class:`~aihwkit.configs.utils.MappingParameter`). If given
            explicitly here, it will overwrite the value in the mapping
            field.

    Raises:
        ModuleError: in case of multiple defined analog tiles in the module
    """
    shape = [self.out_features, self.in_features]
    weight = weight.clone().reshape(shape)

    realistic = self.realistic_read_write and not force_exact

    analog_tiles = list(self.analog_tiles())
    if len(analog_tiles) != 1:
        raise ModuleError("AnalogModuleBase.set_weights only supports a single tile.")
    analog_tile = analog_tiles[0]

    if remap_weights:
        omega = weight_scaling_omega
        if omega is None:
            omega = analog_tile.rpu_config.mapping.weight_scaling_omega

        analog_tile.set_weights_scaled(weight, bias if self.analog_bias else None,
                                       realistic=realistic,
                                       weight_scaling_omega=omega)
    else:
        analog_tile.set_weights(weight, bias if self.analog_bias else None,
                                realistic=realistic)

    if bias is not None and self.digital_bias:
        with no_grad():
            self.bias.data[:] = bias[:]

    self._sync_weights_from_tile()
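# Usage sketch (hypothetical layer names): copy weights from a pre-trained
# digital Linear layer into an analog layer, letting the mapping remap the
# values onto the tile and populate the digital output scales. The
# ``weight_scaling_omega=1.0`` override is an illustrative choice, not a
# recommended default.
def _example_transfer_weights(digital_layer, analog_layer):
    analog_layer.set_weights(digital_layer.weight, digital_layer.bias,
                             remap_weights=True, weight_scaling_omega=1.0)
    return analog_layer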