Ejemplo n.º 1
0
    def forward(self, input):
        """Method to perform forward propagations.

        When ``forward_legacy_enabled`` is set, a plain
        ``torch.nn.functional.conv1d`` with ``self.weight`` is used. Otherwise
        every sample is unfolded into sliding patches that are multiplied by
        the crossbar conductance matrices (tiled, routed through
        ``simulate_matmul``, and/or quantized depending on the layer's
        configuration); ``transform_output`` and the bias are applied last.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor, indexed as (batch, channels, length).

        Returns
        -------
        torch.Tensor
            Output tensor (allocated as (batch, out_channels, output_length)
            before ``transform_output`` is applied).
        """
        if self.forward_legacy_enabled:
            # conv1d needs every operand on one device, so the bias is moved
            # together with the input and the weight.
            bias = self.bias.to(self.device) if self.bias is not None else None
            return torch.nn.functional.conv1d(
                input.to(self.device),
                self.weight.to(self.device),
                bias=bias,
                stride=self.stride,
                padding=self.padding,
            )

        # Standard convolution output length, computed on the unpadded input.
        output_dim = (
            input.shape[2] - self.kernel_size[0] + 2 * self.padding[0]
        ) // self.stride[0] + 1
        out = torch.zeros(
            (input.shape[0], self.out_channels, output_dim), device=self.device
        )
        if not all(item == 0 for item in self.padding):
            input = nn.functional.pad(input, pad=(self.padding[0], self.padding[0]))

        if self.max_input_voltage is not None:
            assert (
                type(self.max_input_voltage) in (int, float)
                and self.max_input_voltage > 0
            ), "The maximum input voltage (max_input_voltage) must be >0."
            # convert_range is given the whole (padded) batch's min/max and
            # +/- max_input_voltage; presumably a linear rescale -- confirm.
            input = convert_range(
                input,
                input.min(),
                input.max(),
                -self.max_input_voltage,
                self.max_input_voltage,
            )

        for batch in range(input.shape[0]):
            # (in_channels, length) -> (num_windows, in_channels * kernel_size)
            unfolded_batch_input = (
                input[batch]
                .unfold(-1, size=self.kernel_size[0], step=self.stride[0])
                .permute(1, 0, 2)
                .reshape(-1, self.in_channels * self.kernel_size[0])
            )
            unfolded_batch_input_shape = unfolded_batch_input.shape
            if hasattr(self, "non_linear"):
                if self.tile_shape is not None:
                    tiles_map = self.crossbars[0].tiles_map
                    crossbar_shape = (
                        self.crossbars[0].rows,
                        self.crossbars[0].columns,
                    )
                else:
                    tiles_map = None
                    crossbar_shape = None

                # ``nl`` is disabled whenever the layer has a ``simulate`` attribute.
                nl = not hasattr(self, "simulate")

                out_ = (
                    self.crossbar_operation(
                        self.crossbars,
                        lambda crossbar, input_: simulate_matmul(
                            unfolded_batch_input,
                            crossbar,
                            nl=nl,
                            tiles_map=tiles_map,
                            crossbar_shape=crossbar_shape,
                            max_input_voltage=self.max_input_voltage,
                            ADC_resolution=self.ADC_resolution,
                            ADC_overflow_rate=self.ADC_overflow_rate,
                            quant_method=self.quant_method,
                        ),
                        input_=unfolded_batch_input,
                    )
                    .to(self.device)
                    .T
                )
            elif self.tile_shape is not None:
                (
                    unfolded_batch_input_tiles,
                    unfolded_batch_input_tiles_map,
                ) = gen_tiles(unfolded_batch_input, self.tile_shape, input=True)
                crossbar_shape = (
                    self.crossbars[0].rows,
                    self.crossbars[0].columns,
                )
                tiles_map = self.crossbars[0].tiles_map
                out_ = tile_matmul(
                    unfolded_batch_input_tiles,
                    unfolded_batch_input_tiles_map,
                    unfolded_batch_input_shape,
                    self.crossbar_operation(
                        self.crossbars,
                        lambda crossbar: crossbar.conductance_matrix,
                    ),
                    tiles_map,
                    crossbar_shape,
                    self.ADC_resolution,
                    self.ADC_overflow_rate,
                    self.quant_method,
                ).T
            else:
                out_ = torch.matmul(
                    unfolded_batch_input,
                    self.crossbar_operation(
                        self.crossbars,
                        lambda crossbar: crossbar.conductance_matrix,
                    ),
                ).T
                if self.quant_method is not None:
                    # The resolution is passed positionally: call sites in this
                    # file spell the keyword both ``quant=`` and ``bits=``.
                    # NOTE(review): assumes it is quantize's second positional
                    # parameter in the installed memtorch -- confirm.
                    out_ = memtorch.bh.Quantize.quantize(
                        out_,
                        self.ADC_resolution,
                        overflow_rate=self.ADC_overflow_rate,
                        quant_method=self.quant_method,
                    )

            out[batch] = out_.view(size=(1, self.out_channels, output_dim))

        out = self.transform_output(out)
        if self.bias is not None:
            out += self.bias.view(-1, 1).expand_as(out)

        return out
Ejemplo n.º 2
0
    def forward(self, input):
        """Method to perform forward propagations.

        When ``forward_legacy_enabled`` is set, a plain
        ``torch.nn.functional.conv3d`` with ``self.weight`` is used. Otherwise
        every sample is (optionally) padded, unfolded into 3-D patches, and the
        patches are multiplied by the crossbar conductance matrices (tiled,
        routed through ``simulate_matmul``, and/or quantized depending on the
        layer's configuration); ``transform_output`` and the bias are applied
        last.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor, indexed as (batch, channels, depth, height, width).

        Returns
        -------
        torch.Tensor
            Output tensor (allocated as (batch, out_channels, D', H', W')
            before ``transform_output`` is applied).
        """
        if self.forward_legacy_enabled:
            # conv3d needs every operand on one device, so the bias is moved
            # together with the input and the weight.
            bias = self.bias.to(self.device) if self.bias is not None else None
            return torch.nn.functional.conv3d(
                input.to(self.device),
                self.weight.to(self.device),
                bias=bias,
                stride=self.stride,
                padding=self.padding,
            )

        # Standard convolution output extent for each of the three spatial axes.
        output_dim = [
            (input.shape[axis + 2] - self.kernel_size[axis] + 2 * self.padding[axis])
            // self.stride[axis]
            + 1
            for axis in range(3)
        ]
        out = torch.zeros(
            (input.shape[0], self.out_channels, *output_dim), device=self.device
        )
        for batch in range(input.shape[0]):
            if not all(item == 0 for item in self.padding):
                # F.pad takes pairs starting from the LAST axis, hence 2, 1, 0.
                batch_input = nn.functional.pad(
                    input[batch],
                    pad=(
                        self.padding[2],
                        self.padding[2],
                        self.padding[1],
                        self.padding[1],
                        self.padding[0],
                        self.padding[0],
                    ),
                )
            else:
                batch_input = input[batch]

            if self.max_input_voltage is not None:
                assert (
                    type(self.max_input_voltage) in (int, float)
                    and self.max_input_voltage > 0
                ), 'The maximum input voltage (max_input_voltage) must be >0.'
                # convert_range is given this sample's min/max and
                # +/- max_input_voltage; presumably a linear rescale -- confirm.
                batch_input = convert_range(
                    batch_input,
                    batch_input.min(),
                    batch_input.max(),
                    -self.max_input_voltage,
                    self.max_input_voltage,
                )

            # (C, D, H, W) -> (D' * H' * W', C * kD * kH * kW) patch matrix.
            unfolded_batch_input = (
                batch_input.unfold(1, self.kernel_size[0], self.stride[0])
                .unfold(2, self.kernel_size[1], self.stride[1])
                .unfold(3, self.kernel_size[2], self.stride[2])
                .permute(1, 2, 3, 0, 4, 5, 6)
                .reshape(
                    -1,
                    self.in_channels
                    * self.kernel_size[0]
                    * self.kernel_size[1]
                    * self.kernel_size[2],
                )
            )
            unfolded_batch_input_shape = unfolded_batch_input.shape
            if hasattr(self, 'non_linear'):
                if self.tile_shape is not None:
                    tiles_map = self.crossbars[0].tiles_map
                    crossbar_shape = (
                        self.crossbars[0].rows,
                        self.crossbars[0].columns,
                    )
                else:
                    tiles_map = None
                    crossbar_shape = None

                # ``nl`` is disabled whenever the layer has a ``simulate`` attribute.
                nl = not hasattr(self, 'simulate')

                out_ = (
                    self.crossbar_operation(
                        self.crossbars,
                        lambda crossbar, input_: simulate_matmul(
                            unfolded_batch_input,
                            crossbar,
                            nl=nl,
                            tiles_map=tiles_map,
                            crossbar_shape=crossbar_shape,
                            max_input_voltage=self.max_input_voltage,
                            ADC_resolution=self.ADC_resolution,
                            ADC_overflow_rate=self.ADC_overflow_rate,
                            quant_method=self.quant_method,
                        ),
                        input_=unfolded_batch_input,
                    )
                    .to(self.device)
                    .T
                )
            elif self.tile_shape is not None:
                (
                    unfolded_batch_input_tiles,
                    unfolded_batch_input_tiles_map,
                ) = gen_tiles(unfolded_batch_input, self.tile_shape, input=True)
                crossbar_shape = (
                    self.crossbars[0].rows,
                    self.crossbars[0].columns,
                )
                tiles_map = self.crossbars[0].tiles_map
                out_ = tile_matmul(
                    unfolded_batch_input_tiles,
                    unfolded_batch_input_tiles_map,
                    unfolded_batch_input_shape,
                    self.crossbar_operation(
                        self.crossbars,
                        lambda crossbar: crossbar.conductance_matrix,
                    ),
                    tiles_map,
                    crossbar_shape,
                    self.ADC_resolution,
                    self.ADC_overflow_rate,
                    self.quant_method,
                ).T
            else:
                out_ = torch.matmul(
                    unfolded_batch_input,
                    self.crossbar_operation(
                        self.crossbars,
                        lambda crossbar: crossbar.conductance_matrix,
                    ),
                ).T
                if self.quant_method is not None:
                    # The resolution is passed positionally: call sites in this
                    # file spell the keyword both ``quant=`` and ``bits=``.
                    # NOTE(review): assumes it is quantize's second positional
                    # parameter in the installed memtorch -- confirm.
                    out_ = memtorch.bh.Quantize.quantize(
                        out_,
                        self.ADC_resolution,
                        overflow_rate=self.ADC_overflow_rate,
                        quant_method=self.quant_method,
                    )

            # Store the result for every branch, including the non-linear one.
            out[batch] = out_.view(size=(1, self.out_channels, *output_dim))

        out = self.transform_output(out)
        if self.bias is not None:
            # Per-output-channel bias, broadcast over every sample in the batch.
            out += self.bias.data.view(-1, 1, 1, 1).expand_as(out)

        return out
Ejemplo n.º 3
0
    def forward(self, input):
        """Method to perform forward propagations.

        When ``forward_legacy_enabled`` is set, the output is
        ``input @ self.weight.T`` plus the bias. Otherwise the input is
        multiplied by the crossbar conductance matrices (tiled, routed through
        ``simulate_matmul``, and/or quantized depending on the layer's
        configuration); ``transform_output`` and the bias are applied last.

        Parameters
        ----------
        input : torch.Tensor
            Input tensor; its last dimension is contracted with the weights.

        Returns
        -------
        torch.Tensor
            Output tensor.
        """
        if self.forward_legacy_enabled:
            out = torch.matmul(
                input.to(self.device), self.weight.data.T.to(self.device)
            )
            if self.bias is not None:
                # The bias is brought onto the output's device before the
                # in-place add.
                out += self.bias.to(self.device).view(1, -1).expand_as(out)

            return out

        input_shape = input.shape
        if self.max_input_voltage is not None:
            assert (
                type(self.max_input_voltage) in (int, float)
                and self.max_input_voltage > 0
            ), 'The maximum input voltage (max_input_voltage) must be >0.'
            # convert_range is given the input's min/max and
            # +/- max_input_voltage; presumably a linear rescale -- confirm.
            input = convert_range(
                input,
                input.min(),
                input.max(),
                -self.max_input_voltage,
                self.max_input_voltage,
            )

        if hasattr(self, 'non_linear'):
            if self.tile_shape is not None:
                tiles_map = self.crossbars[0].tiles_map
                # The crossbar's (rows, columns) is used here, as the tiled
                # linear branch below does.
                crossbar_shape = (
                    self.crossbars[0].rows,
                    self.crossbars[0].columns,
                )
            else:
                tiles_map = None
                crossbar_shape = None

            # ``nl`` is disabled whenever the layer has a ``simulate`` attribute.
            nl = not hasattr(self, 'simulate')

            out = self.crossbar_operation(
                self.crossbars,
                lambda crossbar, input_: simulate_matmul(
                    input,
                    crossbar,
                    nl=nl,
                    tiles_map=tiles_map,
                    crossbar_shape=crossbar_shape,
                    max_input_voltage=self.max_input_voltage,
                    ADC_resolution=self.ADC_resolution,
                    ADC_overflow_rate=self.ADC_overflow_rate,
                    quant_method=self.quant_method,
                ),
                input_=input,
            ).to(self.device)
        elif self.tile_shape is not None:
            input_tiles, input_tiles_map = gen_tiles(
                input, self.tile_shape, input=True
            )
            crossbar_shape = (self.crossbars[0].rows, self.crossbars[0].columns)
            tiles_map = self.crossbars[0].tiles_map
            out = tile_matmul(
                input_tiles,
                input_tiles_map,
                input_shape,
                self.crossbar_operation(
                    self.crossbars,
                    lambda crossbar: crossbar.conductance_matrix,
                ),
                tiles_map,
                crossbar_shape,
                self.ADC_resolution,
                self.ADC_overflow_rate,
                self.quant_method,
            )
        else:
            out = torch.matmul(
                input.to(self.device),
                self.crossbar_operation(
                    self.crossbars,
                    lambda crossbar: crossbar.conductance_matrix,
                ),
            )
            if self.quant_method is not None:
                # The resolution is passed positionally: call sites in this
                # file spell the keyword both ``quant=`` and ``bits=``.
                # NOTE(review): assumes it is quantize's second positional
                # parameter in the installed memtorch -- confirm.
                out = memtorch.bh.Quantize.quantize(
                    out,
                    self.ADC_resolution,
                    overflow_rate=self.ADC_overflow_rate,
                    quant_method=self.quant_method,
                )

        out = self.transform_output(out)
        if self.bias is not None:
            out += self.bias.data.view(1, -1).expand_as(out)

        return out