示例#1
0
class AnalogLSTMCell(AnalogSequential):
    """Analog LSTM Cell.

    Holds two analog linear layers, ``weight_ih`` (input to gates) and
    ``weight_hh`` (hidden to gates). Each produces ``4 * hidden_size``
    outputs that ``forward`` splits into the input, forget, cell and
    output gates.

    Args:
        input_size: in_features size for W_ih matrix
        hidden_size: in_features and out_features size for W_hh matrix
        bias: whether to use a bias row on the analog tile or not
        rpu_config: configuration for an analog resistive processing unit;
            an ``InferenceRPUConfig()`` is created when ``None`` is given
        realistic_read_write: whether to enable realistic read/write
            for setting initial weights and read out of weights
    """

    # pylint: disable=abstract-method

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool,
        rpu_config: Optional[RPUConfigAlias] = None,
        realistic_read_write: bool = False,
    ):
        super().__init__()

        # Default to InferenceRPUConfig. The check is an explicit identity
        # test so that only an omitted config triggers the default, never a
        # config object that merely evaluates as falsy.
        if rpu_config is None:
            rpu_config = InferenceRPUConfig()

        self.input_size = input_size
        self.hidden_size = hidden_size

        # Both layers emit the four gate pre-activations side by side,
        # hence the 4 * hidden_size output width.
        self.weight_ih = AnalogLinear(
            input_size,
            4 * hidden_size,
            bias=bias,
            rpu_config=rpu_config,
            realistic_read_write=realistic_read_write)
        self.weight_hh = AnalogLinear(
            hidden_size,
            4 * hidden_size,
            bias=bias,
            rpu_config=rpu_config,
            realistic_read_write=realistic_read_write)

    def get_zero_state(self, batch_size: int) -> Tuple[Tensor, Tensor]:
        """Returns a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
           ``LSTMState`` built from two zero tensors, each of shape
           ``(batch_size, hidden_size)``, in the (hidden, cell) order that
           ``forward`` unpacks. Both live on the first device reported by
           ``weight_ih.get_analog_tile_devices()``.
        """
        device = self.weight_ih.get_analog_tile_devices()[0]
        return LSTMState(zeros(batch_size, self.hidden_size, device=device),
                         zeros(batch_size, self.hidden_size, device=device))

    def forward(
            self, input_: Tensor,
            state: Tuple[Tensor,
                         Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Computes one LSTM step.

        Args:
            input_: input tensor for the current step, passed to ``weight_ih``
            state: pair ``(h_x, c_x)`` of previous hidden and cell state

        Returns:
            Tuple ``(h_y, (h_y, c_y))``: the new hidden state as the step
            output, followed by the new (hidden, cell) state pair.
        """
        # pylint: disable=arguments-differ
        h_x, c_x = state
        gates = self.weight_ih(input_) + self.weight_hh(h_x)

        # The gate pre-activations are concatenated along the feature axis of
        # the linear output, i.e. the last one. Splitting on dim=-1 matches
        # the previous dim=1 for 2-D input and stays correct for other ranks.
        in_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, dim=-1)

        in_gate = sigmoid(in_gate)
        forget_gate = sigmoid(forget_gate)
        cell_gate = tanh(cell_gate)
        out_gate = sigmoid(out_gate)

        c_y = (forget_gate * c_x) + (in_gate * cell_gate)
        h_y = out_gate * tanh(c_y)

        return h_y, (h_y, c_y)
示例#2
0
class AnalogVanillaRNNCell(AnalogSequential):
    """Analog Vanilla RNN Cell.

    Holds two analog linear layers, ``weight_ih`` (input to hidden) and
    ``weight_hh`` (hidden to hidden). ``forward`` sums their outputs and
    applies ``tanh``.

    Args:
        input_size: in_features size for W_ih matrix
        hidden_size: in_features and out_features size for W_hh matrix
        bias: whether to use a bias row on the analog tile or not
        rpu_config: configuration for an analog resistive processing unit;
            an ``InferenceRPUConfig()`` is created when ``None`` is given
        realistic_read_write: whether to enable realistic read/write
            for setting initial weights and read out of weights
    """

    # pylint: disable=abstract-method
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        bias: bool,
        rpu_config: Optional[RPUConfigAlias] = None,
        realistic_read_write: bool = False,
    ):
        super().__init__()

        # Default to InferenceRPUConfig. The check is an explicit identity
        # test so that only an omitted config triggers the default, never a
        # config object that merely evaluates as falsy.
        if rpu_config is None:
            rpu_config = InferenceRPUConfig()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = AnalogLinear(
            input_size,
            hidden_size,
            bias=bias,
            rpu_config=rpu_config,
            realistic_read_write=realistic_read_write)
        self.weight_hh = AnalogLinear(
            hidden_size,
            hidden_size,
            bias=bias,
            rpu_config=rpu_config,
            realistic_read_write=realistic_read_write)

    def get_zero_state(self, batch_size: int) -> Tensor:
        """Returns a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
           Zero tensor of shape ``(batch_size, hidden_size)`` on the first
           device reported by ``weight_ih.get_analog_tile_devices()``.
        """
        device = self.weight_ih.get_analog_tile_devices()[0]
        return zeros(batch_size, self.hidden_size, device=device)

    def forward(self, input_: Tensor, state: Tensor) -> Tuple[Tensor, Tensor]:
        """Computes one vanilla RNN step.

        Args:
            input_: input tensor for the current step, passed to ``weight_ih``
            state: previous hidden state, passed to ``weight_hh``

        Returns:
            Tuple ``(out, out)``: the same tensor serves as both the step
            output and the new hidden state.
        """
        # pylint: disable=arguments-differ
        igates = self.weight_ih(input_)
        hgates = self.weight_hh(state)

        out = tanh(igates + hgates)

        return out, out  # output will also be hidden state