Example #1
def naive_map(weight, r_on, r_off, scheme, p_l=None):
    """Method to naively map network parameters to memristive device conductances, using two crossbars to represent both positive and negative weights.

        Parameters
        ----------
        weight : torch.Tensor
            Weight tensor to map.
        r_on : float
            Low resistance state.
        r_off : float
            High resistance state.
        scheme : memtorch.bh.crossbar.Scheme
            Weight representation scheme.
        p_l : float
            If not None, the proportion of weights to retain.

        Returns
        -------
        torch.Tensor, torch.Tensor
            Positive and negative crossbar conductances (a single conductance
            tensor is returned for the SingleColumn scheme).
    """

    if p_l is not None:
        assert p_l >= 0 and p_l <= 1, 'p_l must be None or between 0 and 1.'
        weight_max, _ = torch.sort(weight.clone().abs().flatten(),
                                   descending=True)
        weight_max = weight_max[int(p_l * (weight.numel() - 1))]
        weight_min = weight_max / (r_off / r_on)
    else:
        weight_max = weight.abs().max()
        weight_min = 0

    if scheme == memtorch.bh.crossbar.Scheme.DoubleColumn:
        pos = weight.clone()
        neg = weight.clone() * -1
        pos[pos < 0] = 0
        neg[neg < 0] = 0
        pos = torch.clamp(pos, weight_min, weight_max)
        neg = torch.clamp(neg, weight_min, weight_max)
        pos = convert_range(pos, weight_min, weight_max, 1 / r_off, 1 / r_on)
        neg = convert_range(neg, weight_min, weight_max, 1 / r_off, 1 / r_on)
        return pos, neg
    elif scheme == memtorch.bh.crossbar.Scheme.SingleColumn:
        crossbar = weight.clone()
        crossbar = torch.clamp(crossbar, weight_min, weight_max)
        crossbar = convert_range(crossbar, weight_min, weight_max, 1 / r_off,
                                 1 / r_on)
        return crossbar
    else:
        raise Exception('%s is not currently supported.' % scheme)
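
A minimal usage sketch, assuming naive_map is available as defined above and that memtorch.bh.crossbar.Scheme is importable; the tensor shape and resistance values are illustrative, not taken from the example.

import torch
import memtorch

weight = torch.randn(64, 32)  # hypothetical layer weight tensor
pos, neg = naive_map(weight,
                     r_on=1e2,   # low resistance state (Ohm)
                     r_off=1e5,  # high resistance state (Ohm)
                     scheme=memtorch.bh.crossbar.Scheme.DoubleColumn)
# pos and neg hold conductances in [1 / r_off, 1 / r_on] for the positive
# and negative crossbars, respectively.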
Example #2
    def __init__(self,
                 time_series_resolution=1e-4,
                 u_v=1e-14,
                 d=10e-9,
                 r_on=100,
                 r_off=16e3,
                 pos_write_threshold=0.55,
                 neg_write_threshold=-0.55,
                 p=1,
                 **kwargs):

        args = memtorch.bh.unpack_parameters(locals())
        super(LinearIonDrift, self).__init__(
            args.r_off,
            args.r_on,
            args.time_series_resolution,
            args.pos_write_threshold,
            args.neg_write_threshold,
        )
        self.u_v = args.u_v
        self.d = args.d
        self.r_i = args.r_on
        self.p = args.p
        self.g = 1 / self.r_i
        self.x = convert_range(self.r_i, self.r_on, self.r_off, 0, 1)
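
A brief instantiation sketch, assuming the constructor above belongs to memtorch.bh.memristor.LinearIonDrift; the argument values simply mirror the defaults shown.

import memtorch

device = memtorch.bh.memristor.LinearIonDrift(time_series_resolution=1e-4,
                                              r_on=100,
                                              r_off=16e3)
print(device.g)  # initial conductance, equal to 1 / r_on per the constructor above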
Example #3
    def forward(self, input):
        """Method to perform forward propagations.

            Parameters
            ----------
            input : torch.Tensor
                Input tensor.

            Returns
            -------
            torch.Tensor
                Output tensor.
        """
        if self.forward_legacy_enabled:
            out = torch.matmul(input.to(self.device), self.weight.data.T.to(self.device))
            if self.bias is not None:
                out += self.bias.view(1, -1).expand_as(out)

            return out
        else:
            if hasattr(self, 'non_linear'):
                input = convert_range(input, input.min(), input.max(), -1, 1)
                input = input.cpu().detach().numpy()
                if hasattr(self, 'simulate'):
                    out = self.transform_output(self.crossbar_operation(self.crossbars, lambda crossbar, input_: simulate_matmul(input, crossbar.devices, nl=False), input_=input)).to(self.device)
                else:
                    out = self.transform_output(self.crossbar_operation(self.crossbars, lambda crossbar, input_: simulate_matmul(input, crossbar.devices, nl=True), input_=input)).to(self.device)
            else:
                out = torch.matmul(input.to(self.device), self.crossbar_operation(self.crossbars, lambda crossbar: crossbar.conductance_matrix))

            out = self.transform_output(out)
            if self.bias is not None:
                out += self.bias.view(1, -1).expand_as(out)

            return out
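
The legacy branch above reduces to a plain dense matmul, so a patched layer can be sanity-checked by toggling forward_legacy_enabled. A hedged sketch, where m_linear (a patched memristive linear layer) and x (an input batch) are hypothetical:

m_linear.forward_legacy_enabled = True
out_legacy = m_linear(x)        # torch.matmul against self.weight
m_linear.forward_legacy_enabled = False
out_memristive = m_linear(x)    # crossbar-based path shown above
# With ideal devices and no non-linear modelling, the two outputs should agree closely.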
Example #4
    def forward(self, input):
        """Method to perform forward propagations.

            Parameters
            ----------
            input : torch.Tensor
                Input tensor.

            Returns
            -------
            torch.Tensor
                Output tensor.
        """
        output_dim = int((input.shape[2] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1
        out = torch.zeros((input.shape[0], self.out_channels, output_dim)).to(self.device)
        if hasattr(self, 'non_linear'):
            input = convert_range(input, input.min(), input.max(), -1, 1)
        else:
            weight = self.crossbar_operation(self.crossbars, lambda crossbar: crossbar.conductance_matrix).view(self.weight.shape)

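        # Slide the kernel one position at a time along the input, accumulating
        # the contribution of each (output channel, input channel, kernel tap)
        # triple, either from simulated devices or from the ideal
        # conductance-derived weights.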
        for batch in range(input.shape[0]):
            filter = torch.zeros((self.in_channels, self.kernel_size[0]))
            count = 0
            for i in range(self.out_channels):
                while count < (input.shape[-1] - self.kernel_size[0] + 1):
                    for j in range(self.in_channels):
                        for k in range(count, self.kernel_size[0] + count):
                            if hasattr(self, 'non_linear') and hasattr(self, 'simulate'):
                                out[batch][i][count] = out[batch][i][count] + self.crossbar_operation(self.crossbars, lambda crossbar: crossbar.devices[i][j][k - count].simulate(input[batch][j][k], return_current=True)).item()
                            elif hasattr(self, 'non_linear'):
                                out[batch][i][count] = out[batch][i][count] + self.crossbar_operation(self.crossbars, lambda crossbar: crossbar.devices[i][j][k - count].det_current(input[batch][j][k])).item()
                            else:
                                out[batch][i][count] = out[batch][i][count] + (input[batch][j][k] * weight[i][j][k - count].item())

                    count = count + 1
                count = 0

        out = self.transform_output(out)
        if self.bias is not None:
            out += self.bias.view(-1, 1).expand_as(out)

        return out
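
A quick check of the output-length formula used above, output_dim = (L - K + 2 * P) / S + 1, with illustrative values:

L, K, P, S = 28, 3, 1, 1                  # input length, kernel size, padding, stride (illustrative)
output_dim = int((L - K + 2 * P) / S) + 1
print(output_dim)                         # 28: the spatial length is preserved when K = 3, P = 1, S = 1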
Example #5
    def forward(self, input):
        """Method to perform forward propagations.

            Parameters
            ----------
            input : torch.Tensor
                Input tensor.

            Returns
            -------
            torch.Tensor
                Output tensor.
        """
        if self.forward_legacy_enabled:
            return torch.nn.functional.conv2d(input.to(self.device), self.weight, bias=self.bias, stride=self.stride, padding=self.padding)
        else:
            output_dim = int((input.shape[2] - self.kernel_size[0] + 2 * self.padding[0]) / self.stride[0]) + 1
            out = torch.zeros((input.shape[0], self.out_channels, output_dim, output_dim)).to(self.device)
            for batch in range(input.shape[0]):
                unfolded_batch_input = torch.nn.functional.unfold(input[batch, :, :, :].unsqueeze(0), kernel_size=self.kernel_size, padding=self.padding)
                if hasattr(self, 'non_linear'):
                    unfolded_batch_input = convert_range(unfolded_batch_input, unfolded_batch_input.min(), unfolded_batch_input.max(), -1, 1).squeeze(0)
                    unfolded_batch_input = unfolded_batch_input.transpose(1, 0).cpu().detach().numpy()
                    if hasattr(self, 'simulate'):
                        out_ = torch.tensor(self.transform_output(self.crossbar_operation(self.crossbars, lambda crossbar, input: simulate_matmul(input, crossbar.devices.transpose(1, 0), nl=False), unfolded_batch_input))).to(self.device)
                    else:
                        out_ = torch.tensor(self.transform_output(self.crossbar_operation(self.crossbars, lambda crossbar, input: simulate_matmul(input, crossbar.devices.transpose(1, 0), nl=True), unfolded_batch_input))).to(self.device)
                else:
                    out_ = self.transform_output(torch.matmul(self.crossbar_operation(self.crossbars, lambda crossbar: crossbar.conductance_matrix), unfolded_batch_input))

                if self.bias is not None:
                    out_ += self.bias.view(-1, 1).expand_as(out_)

                out[batch] = out_.view(size=(1, self.out_channels, output_dim, output_dim))

            return out
Example #6
    def forward(self, input):
        """Method to perform forward propagations.

            Parameters
            ----------
            input : torch.Tensor
                Input tensor.

            Returns
            -------
            torch.Tensor
                Output tensor.
        """
        if self.forward_legacy_enabled:
            return torch.nn.functional.conv3d(input.to(self.device),
                                              self.weight.to(self.device),
                                              bias=self.bias,
                                              stride=self.stride,
                                              padding=self.padding)
        else:
            output_dim = [0, 0, 0]
            output_dim[0] = int(
                (input.shape[2] - self.kernel_size[0] + 2 * self.padding[0]) /
                self.stride[0]) + 1
            output_dim[1] = int(
                (input.shape[3] - self.kernel_size[1] + 2 * self.padding[1]) /
                self.stride[1]) + 1
            output_dim[2] = int(
                (input.shape[4] - self.kernel_size[2] + 2 * self.padding[2]) /
                self.stride[2]) + 1
            out = torch.zeros(
                (input.shape[0], self.out_channels, output_dim[0],
                 output_dim[1], output_dim[2])).to(self.device)
            for batch in range(input.shape[0]):
                if not all(item == 0 for item in self.padding):
                    batch_input = nn.functional.pad(
                        input[batch],
                        pad=(self.padding[2], self.padding[2], self.padding[1],
                             self.padding[1], self.padding[0],
                             self.padding[0]))
                else:
                    batch_input = input[batch]

                if self.max_input_voltage is not None:
                    assert (
                        type(self.max_input_voltage) == int
                        or type(self.max_input_voltage) == float
                    ) and self.max_input_voltage > 0, 'The maximum input voltage (max_input_voltage) must be >0.'
                    batch_input = convert_range(
                        batch_input, batch_input.min(), batch_input.max(),
                        -self.max_input_voltage, self.max_input_voltage)

                unfolded_batch_input = batch_input.unfold(1, self.kernel_size[0], self.stride[0]).unfold(2, self.kernel_size[1], self.stride[1]).unfold(3, self.kernel_size[2], self.stride[2]) \
                    .permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, self.in_channels * self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2])
                unfolded_batch_input_shape = unfolded_batch_input.shape
                if hasattr(self, 'non_linear'):
                    if self.tile_shape is not None:
                        tiles_map = self.crossbars[0].tiles_map
                        crossbar_shape = (self.crossbars[0].rows,
                                          self.crossbars[0].columns)
                    else:
                        tiles_map = None
                        crossbar_shape = None

                    if hasattr(self, 'simulate'):
                        nl = False
                    else:
                        nl = True

                    out_ = self.crossbar_operation(self.crossbars, lambda crossbar, input_: simulate_matmul(unfolded_batch_input, crossbar, nl=nl, \
                                                   tiles_map=tiles_map, crossbar_shape=crossbar_shape, max_input_voltage=self.max_input_voltage,
                                                   ADC_resolution=self.ADC_resolution, ADC_overflow_rate=self.ADC_overflow_rate,
                                                   quant_method=self.quant_method), input_=unfolded_batch_input).to(self.device).T
                else:
                    if self.tile_shape is not None:
                        unfolded_batch_input_tiles, unfolded_batch_input_tiles_map = gen_tiles(
                            unfolded_batch_input, self.tile_shape, input=True)
                        crossbar_shape = (self.crossbars[0].rows,
                                          self.crossbars[0].columns)
                        tiles_map = self.crossbars[0].tiles_map
                        out_ = tile_matmul(unfolded_batch_input_tiles, unfolded_batch_input_tiles_map, unfolded_batch_input_shape, \
                                self.crossbar_operation(self.crossbars, lambda crossbar: crossbar.conductance_matrix), tiles_map, crossbar_shape,
                                self.ADC_resolution, self.ADC_overflow_rate, self.quant_method).T
                    else:
                        out_ = torch.matmul(
                            unfolded_batch_input,
                            self.crossbar_operation(
                                self.crossbars, lambda crossbar: crossbar.
                                conductance_matrix)).T
                        if self.quant_method is not None:
                            out_ = memtorch.bh.Quantize.quantize(
                                out_,
                                bits=self.ADC_resolution,
                                overflow_rate=self.ADC_overflow_rate,
                                quant_method=self.quant_method)

                out[batch] = out_.view(size=(1, self.out_channels,
                                             *output_dim))

            out = self.transform_output(out)
            if self.bias is not None:
                out += self.bias.data.view(1, -1, 1, 1, 1).expand_as(out)

            return out
Example #7
    def forward(self, input):
        """Method to perform forward propagations.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        if self.forward_legacy_enabled:
            return torch.nn.functional.conv1d(
                input.to(self.device),
                self.weight.to(self.device),
                bias=self.bias,
                stride=self.stride,
                padding=self.padding,
            )
        else:
            output_dim = (
                int(
                    (input.shape[2] - self.kernel_size[0] + 2 * self.padding[0])
                    / self.stride[0]
                )
                + 1
            )
            out = torch.zeros((input.shape[0], self.out_channels, output_dim)).to(
                self.device
            )
            if not all(item == 0 for item in self.padding):
                input = nn.functional.pad(input, pad=(self.padding[0], self.padding[0]))

            if self.max_input_voltage is not None:
                assert (
                    type(self.max_input_voltage) == int
                    or type(self.max_input_voltage) == float
                ) and self.max_input_voltage > 0, (
                    "The maximum input voltage (max_input_voltage) must be >0."
                )
                input = convert_range(
                    input,
                    input.min(),
                    input.max(),
                    -self.max_input_voltage,
                    self.max_input_voltage,
                )

            for batch in range(input.shape[0]):
                unfolded_batch_input = (
                    input[batch]
                    .unfold(-1, size=self.kernel_size[0], step=self.stride[0])
                    .permute(1, 0, 2)
                    .reshape(-1, self.in_channels * self.kernel_size[0])
                )
                unfolded_batch_input_shape = unfolded_batch_input.shape
                if hasattr(self, "non_linear"):
                    if self.tile_shape is not None:
                        tiles_map = self.crossbars[0].tiles_map
                        crossbar_shape = (
                            self.crossbars[0].rows,
                            self.crossbars[0].columns,
                        )
                    else:
                        tiles_map = None
                        crossbar_shape = None

                    if hasattr(self, "simulate"):
                        nl = False
                    else:
                        nl = True

                    out_ = (
                        self.crossbar_operation(
                            self.crossbars,
                            lambda crossbar, input_: simulate_matmul(
                                unfolded_batch_input,
                                crossbar,
                                nl=nl,
                                tiles_map=tiles_map,
                                crossbar_shape=crossbar_shape,
                                max_input_voltage=self.max_input_voltage,
                                ADC_resolution=self.ADC_resolution,
                                ADC_overflow_rate=self.ADC_overflow_rate,
                                quant_method=self.quant_method,
                            ),
                            input_=unfolded_batch_input,
                        )
                        .to(self.device)
                        .T
                    )
                else:
                    if self.tile_shape is not None:
                        (
                            unfolded_batch_input_tiles,
                            unfolded_batch_input_tiles_map,
                        ) = gen_tiles(unfolded_batch_input, self.tile_shape, input=True)
                        crossbar_shape = (
                            self.crossbars[0].rows,
                            self.crossbars[0].columns,
                        )
                        tiles_map = self.crossbars[0].tiles_map
                        out_ = tile_matmul(
                            unfolded_batch_input_tiles,
                            unfolded_batch_input_tiles_map,
                            unfolded_batch_input_shape,
                            self.crossbar_operation(
                                self.crossbars,
                                lambda crossbar: crossbar.conductance_matrix,
                            ),
                            tiles_map,
                            crossbar_shape,
                            self.ADC_resolution,
                            self.ADC_overflow_rate,
                            self.quant_method,
                        ).T
                    else:
                        out_ = torch.matmul(
                            unfolded_batch_input,
                            self.crossbar_operation(
                                self.crossbars,
                                lambda crossbar: crossbar.conductance_matrix,
                            ),
                        ).T
                        if self.quant_method is not None:
                            out_ = memtorch.bh.Quantize.quantize(
                                out_,
                                quant=self.ADC_resolution,
                                overflow_rate=self.ADC_overflow_rate,
                                quant_method=self.quant_method,
                            )

                out[batch] = out_.view(size=(1, self.out_channels, output_dim))

            out = self.transform_output(out)
            if self.bias is not None:
                out += self.bias.view(-1, 1).expand_as(out)

            return out
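
A small self-contained check of the unfold/permute/reshape step used above for a single sample; the sizes are illustrative.

import torch

C, L, K, S = 2, 6, 3, 1                        # channels, length, kernel size, stride
x = torch.arange(C * L, dtype=torch.float).view(C, L)
patches = x.unfold(-1, size=K, step=S).permute(1, 0, 2).reshape(-1, C * K)
print(patches.shape)                           # torch.Size([4, 6]): (L - K) / S + 1 rows of C * K values each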
Example #8
    def forward(self, input):
        """Method to perform forward propagations.

            Parameters
            ----------
            input : torch.Tensor
                Input tensor.

            Returns
            -------
            torch.Tensor
                Output tensor.
        """
        if self.forward_legacy_enabled:
            out = torch.matmul(input.to(self.device),
                               self.weight.data.T.to(self.device))
            if self.bias is not None:
                out += self.bias.view(1, -1).expand_as(out)

            return out
        else:
            input_shape = input.shape
            if self.max_input_voltage is not None:
                assert (
                    type(self.max_input_voltage) == int
                    or type(self.max_input_voltage) == float
                ) and self.max_input_voltage > 0, 'The maximum input voltage (max_input_voltage) must be >0.'
                input = convert_range(input, input.min(), input.max(),
                                      -self.max_input_voltage,
                                      self.max_input_voltage)

            if hasattr(self, 'non_linear'):
                if self.tile_shape is not None:
                    tiles_map = self.crossbars[0].tiles_map
                    crossbar_shape = self.weight.data.shape
                else:
                    tiles_map = None
                    crossbar_shape = None

                if hasattr(self, 'simulate'):
                    nl = False
                else:
                    nl = True

                out = self.crossbar_operation(self.crossbars, lambda crossbar, input_: simulate_matmul(input, crossbar, nl=nl, \
                                              tiles_map=tiles_map, crossbar_shape=crossbar_shape, max_input_voltage=self.max_input_voltage,
                                              ADC_resolution=self.ADC_resolution, ADC_overflow_rate=self.ADC_overflow_rate,
                                              quant_method=self.quant_method), input_=input).to(self.device)
            else:
                if self.tile_shape is not None:
                    input_tiles, input_tiles_map = gen_tiles(input,
                                                             self.tile_shape,
                                                             input=True)
                    crossbar_shape = (self.crossbars[0].rows,
                                      self.crossbars[0].columns)
                    tiles_map = self.crossbars[0].tiles_map
                    out = tile_matmul(input_tiles, input_tiles_map, input_shape, self.crossbar_operation(self.crossbars, \
                        lambda crossbar: crossbar.conductance_matrix), tiles_map, crossbar_shape,
                        self.ADC_resolution, self.ADC_overflow_rate, self.quant_method)
                else:
                    out = torch.matmul(
                        input.to(self.device),
                        self.crossbar_operation(
                            self.crossbars,
                            lambda crossbar: crossbar.conductance_matrix))
                    if self.quant_method is not None:
                        out = memtorch.bh.Quantize.quantize(
                            out,
                            bits=self.ADC_resolution,
                            overflow_rate=self.ADC_overflow_rate,
                            quant_method=self.quant_method)

            out = self.transform_output(out)
            if self.bias is not None:
                out += self.bias.data.view(1, -1).expand_as(out)

            return out
Example #9
    def forward(self, input):
        """Method to perform forward propagations.

            Parameters
            ----------
            input : torch.Tensor
                Input tensor.

            Returns
            -------
            torch.Tensor
                Output tensor.
        """
        if self.forward_legacy_enabled:
            return torch.nn.functional.conv3d(input.to(self.device),
                                              self.weight.to(self.device),
                                              bias=self.bias,
                                              stride=self.stride,
                                              padding=self.padding)
        else:
            output_dim = [0, 0, 0]
            output_dim[0] = int(
                (input.shape[2] - self.kernel_size[0] + 2 * self.padding[0]) /
                self.stride[0]) + 1
            output_dim[1] = int(
                (input.shape[3] - self.kernel_size[1] + 2 * self.padding[1]) /
                self.stride[1]) + 1
            output_dim[2] = int(
                (input.shape[4] - self.kernel_size[2] + 2 * self.padding[2]) /
                self.stride[2]) + 1
            out = torch.zeros(
                (input.shape[0], self.out_channels, output_dim[0],
                 output_dim[1], output_dim[2])).to(self.device)
            for batch in range(input.shape[0]):
                if all(item == 0 for item in self.padding):
                    batch_input = input[batch]
                else:
                    batch_input = nn.functional.pad(
                        input[batch],
                        pad=(self.padding[2], self.padding[2], self.padding[1],
                             self.padding[1], self.padding[0],
                             self.padding[0]),
                        mode="constant",
                        value=0)

                channel_idx = 0
                for channel in range(self.in_channels):
                    batch_channel_input = batch_input[channel].unsqueeze(
                        0).unsqueeze(0)
                    unfolded_batch_channel_input = batch_channel_input.unfold(
                        2, self.kernel_size[0], self.stride[0]).unfold(
                            3, self.kernel_size[1],
                            self.stride[1]).unfold(4, self.kernel_size[2],
                                                   self.stride[2])
                    unfolded_batch_channel_input = unfolded_batch_channel_input.reshape(
                        -1, self.kernel_size[0] * self.kernel_size[1] *
                        self.kernel_size[2])
                    if hasattr(self, 'non_linear'):
                        unfolded_batch_channel_input = convert_range(
                            unfolded_batch_channel_input,
                            unfolded_batch_channel_input.min(),
                            unfolded_batch_channel_input.max(), -1,
                            1).squeeze(0)
                        unfolded_batch_channel_input = unfolded_batch_channel_input.transpose(
                            1, 0).cpu().detach().numpy()
                        if hasattr(self, 'simulate'):
                            nl = False
                        else:
                            nl = True

                        if self.scheme == memtorch.bh.Scheme.DoubleColumn:
                            out[batch, :, :, :, :] += self.transform_output(
                                self.crossbar_operation(
                                    self.crossbars,
                                    lambda crossbar, input_: simulate_matmul(
                                        input_,
                                        crossbar.devices.transpose(1, 0),
                                        nl=nl),
                                    input_=unfolded_batch_channel_input.T,
                                    idx=(channel_idx, channel_idx + 1))).view(
                                        self.out_channels, output_dim[0],
                                        output_dim[1],
                                        output_dim[2]).to(self.device)
                        elif self.scheme == memtorch.bh.Scheme.SingleColumn:
                            out[batch, :, :, :, :] += self.transform_output(
                                self.crossbar_operation(
                                    self.crossbars,
                                    lambda crossbar, input_: simulate_matmul(
                                        input_,
                                        crossbar.devices.transpose(1, 0),
                                        nl=nl),
                                    input_=unfolded_batch_channel_input.T,
                                    idx=channel_idx)).view(
                                        self.out_channels, output_dim[0],
                                        output_dim[1],
                                        output_dim[2]).to(self.device)
                        else:
                            raise Exception('Scheme is currently unsupported.')
                    else:
                        if self.scheme == memtorch.bh.Scheme.DoubleColumn:
                            out[batch, :, :, :, :] += self.transform_output(
                                torch.matmul(
                                    self.crossbar_operation(
                                        self.crossbars, lambda crossbar:
                                        crossbar.conductance_matrix,
                                        (channel_idx, channel_idx + 1)),
                                    unfolded_batch_channel_input.T)).view(
                                        self.out_channels, output_dim[0],
                                        output_dim[1], output_dim[2])
                            channel_idx += 2
                        elif self.scheme == memtorch.bh.Scheme.SingleColumn:
                            out[batch, :, :, :, :] += self.transform_output(
                                torch.matmul(
                                    self.crossbar_operation(
                                        self.crossbars, lambda crossbar:
                                        crossbar.conductance_matrix,
                                        channel_idx),
                                    unfolded_batch_channel_input.T)).view(
                                        self.out_channels, output_dim[0],
                                        output_dim[1], output_dim[2])
                            channel_idx += 1
                        else:
                            raise Exception('Scheme is currently unsupported.')

                if self.bias is not None:
                    out[batch] += self.bias.view(-1, 1).expand_as(out[batch])

            return out
Example #10
    def set_conductance(self, conductance):
        """Method to set the conductance of the memristive device."""
        # Clip to the physically attainable range, then update the internal
        # state variable and the stored conductance accordingly.
        conductance = clip(conductance, 1 / self.r_off, 1 / self.r_on)
        self.x = convert_range(1 / conductance, self.r_on, self.r_off, 0, 1)
        self.g = conductance
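
A worked check of the state update above, assuming convert_range performs a linear rescale of its input from [r_on, r_off] onto [0, 1]:

r_on, r_off = 100.0, 16e3
conductance = 1 / 5e3                            # target conductance (S); 5 kOhm lies within [r_on, r_off]
x = ((1 / conductance) - r_on) / (r_off - r_on)  # linear rescale of the corresponding resistance onto [0, 1]
print(round(x, 4))                               # 0.3082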