class AnalogLSTMCell(AnalogSequential):
    """Analog LSTM Cell.

    A single LSTM time-step whose input-to-hidden and hidden-to-hidden
    projections are realized as :class:`AnalogLinear` layers (i.e. computed
    on analog tiles). The four gates are packed along the output dimension
    of each projection (hence the ``4 * hidden_size`` widths).

    Args:
        input_size: in_features size for W_ih matrix
        hidden_size: in_features and out_features size for W_hh matrix
        bias: whether to use a bias row on the analog tile or not
        rpu_config: configuration for an analog resistive processing unit
        realistic_read_write: whether to enable realistic read/write
            for setting initial weights and read out of weights
    """
    # pylint: disable=abstract-method

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            bias: bool,
            rpu_config: Optional[RPUConfigAlias] = None,
            realistic_read_write: bool = False,
    ):
        super().__init__()

        # Fall back to inference-oriented defaults when no config is given.
        rpu_config = rpu_config or InferenceRPUConfig()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Both projections emit all four gates at once: [i | f | g | o].
        self.weight_ih = AnalogLinear(
            input_size, 4 * hidden_size, bias=bias,
            rpu_config=rpu_config,
            realistic_read_write=realistic_read_write)
        self.weight_hh = AnalogLinear(
            hidden_size, 4 * hidden_size, bias=bias,
            rpu_config=rpu_config,
            realistic_read_write=realistic_read_write)

    def get_zero_state(self, batch_size: int) -> Tensor:
        """Returns a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
            Zeroed state tensor
        """
        # Allocate the state on the same device as the analog tiles so the
        # first forward pass does not trigger a device transfer.
        device = self.weight_ih.get_analog_tile_devices()[0]
        hidden = zeros(batch_size, self.hidden_size, device=device)
        cell = zeros(batch_size, self.hidden_size, device=device)
        return LSTMState(hidden, cell)

    def forward(
            self,
            input_: Tensor,
            state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # pylint: disable=arguments-differ
        prev_h, prev_c = state

        # One combined analog matvec per projection; split into the gates.
        pre_act = self.weight_ih(input_) + self.weight_hh(prev_h)
        i_gate, f_gate, g_gate, o_gate = pre_act.chunk(4, 1)

        i_gate = sigmoid(i_gate)
        f_gate = sigmoid(f_gate)
        g_gate = tanh(g_gate)
        o_gate = sigmoid(o_gate)

        # Standard LSTM recurrence.
        new_c = (f_gate * prev_c) + (i_gate * g_gate)
        new_h = o_gate * tanh(new_c)

        return new_h, (new_h, new_c)
class AnalogVanillaRNNCell(AnalogSequential):
    """Analog Vanilla RNN Cell.

    A single Elman-RNN time-step, ``h' = tanh(W_ih x + W_hh h)``, where both
    projections are realized as :class:`AnalogLinear` layers running on
    analog tiles.

    Args:
        input_size: in_features size for W_ih matrix
        hidden_size: in_features and out_features size for W_hh matrix
        bias: whether to use a bias row on the analog tile or not
        rpu_config: configuration for an analog resistive processing unit
        realistic_read_write: whether to enable realistic read/write
            for setting initial weights and read out of weights
    """
    # pylint: disable=abstract-method

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            bias: bool,
            rpu_config: Optional[RPUConfigAlias] = None,
            realistic_read_write: bool = False,
    ):
        super().__init__()

        # Fall back to inference-oriented defaults when no config is given.
        rpu_config = rpu_config or InferenceRPUConfig()

        self.input_size = input_size
        self.hidden_size = hidden_size

        self.weight_ih = AnalogLinear(
            input_size, hidden_size, bias=bias,
            rpu_config=rpu_config,
            realistic_read_write=realistic_read_write)
        self.weight_hh = AnalogLinear(
            hidden_size, hidden_size, bias=bias,
            rpu_config=rpu_config,
            realistic_read_write=realistic_read_write)

    def get_zero_state(self, batch_size: int) -> Tensor:
        """Returns a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
            Zeroed state tensor
        """
        # Allocate the state on the same device as the analog tiles so the
        # first forward pass does not trigger a device transfer.
        device = self.weight_ih.get_analog_tile_devices()[0]
        return zeros(batch_size, self.hidden_size, device=device)

    def forward(self, input_: Tensor, state: Tensor) -> Tuple[Tensor, Tensor]:
        # pylint: disable=arguments-differ
        new_h = tanh(self.weight_ih(input_) + self.weight_hh(state))
        # The cell output doubles as the next hidden state.
        return new_h, new_h