def forward(self, input):
    """Method to perform forward propagations.

    When ``self.forward_legacy_enabled`` is set, this is a plain
    ``torch.nn.functional.conv1d``. Otherwise the convolution is simulated
    on the memristive crossbars:

    1. The input is optionally padded.
    2. It is optionally rescaled to ``[-max_input_voltage, max_input_voltage]``.
    3. Each sample is unfolded (im2col) into receptive-field rows.
    4. The rows are multiplied by the crossbar conductance matrices.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor, indexed as (batch, in_channels, length).

    Returns
    -------
    torch.Tensor
        Output tensor of shape (batch, out_channels, output_length).
    """
    if self.forward_legacy_enabled:
        return torch.nn.functional.conv1d(
            input.to(self.device),
            self.weight.to(self.device),
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    output_dim = (
        int(
            (input.shape[2] - self.kernel_size[0] + 2 * self.padding[0])
            / self.stride[0]
        )
        + 1
    )
    out = torch.zeros((input.shape[0], self.out_channels, output_dim)).to(
        self.device
    )
    if not all(item == 0 for item in self.padding):
        input = nn.functional.pad(input, pad=(self.padding[0], self.padding[0]))

    if self.max_input_voltage is not None:
        # Exact type match on purpose: bool (and other int/float subclasses)
        # are rejected, as before.
        assert (
            type(self.max_input_voltage) in (int, float)
            and self.max_input_voltage > 0
        ), "The maximum input voltage (max_input_voltage) must be >0."
        input = convert_range(
            input,
            input.min(),
            input.max(),
            -self.max_input_voltage,
            self.max_input_voltage,
        )

    for batch in range(input.shape[0]):
        # im2col: one row per output position, flattened in the same
        # (in_channels, kernel) order used for the crossbar weights.
        unfolded_batch_input = (
            input[batch]
            .unfold(-1, size=self.kernel_size[0], step=self.stride[0])
            .permute(1, 0, 2)
            .reshape(-1, self.in_channels * self.kernel_size[0])
        )
        unfolded_batch_input_shape = unfolded_batch_input.shape

        if hasattr(self, "non_linear"):
            if self.tile_shape is not None:
                tiles_map = self.crossbars[0].tiles_map
                crossbar_shape = (
                    self.crossbars[0].rows,
                    self.crossbars[0].columns,
                )
            else:
                tiles_map = None
                crossbar_shape = None
            # Non-linear device behaviour is modelled unless ``simulate`` is set.
            nl = not hasattr(self, "simulate")
            out_ = (
                self.crossbar_operation(
                    self.crossbars,
                    lambda crossbar, input_: simulate_matmul(
                        unfolded_batch_input,
                        crossbar,
                        nl=nl,
                        tiles_map=tiles_map,
                        crossbar_shape=crossbar_shape,
                        max_input_voltage=self.max_input_voltage,
                        ADC_resolution=self.ADC_resolution,
                        ADC_overflow_rate=self.ADC_overflow_rate,
                        quant_method=self.quant_method,
                    ),
                    input_=unfolded_batch_input,
                )
                .to(self.device)
                .T
            )
        elif self.tile_shape is not None:
            (
                unfolded_batch_input_tiles,
                unfolded_batch_input_tiles_map,
            ) = gen_tiles(unfolded_batch_input, self.tile_shape, input=True)
            crossbar_shape = (self.crossbars[0].rows, self.crossbars[0].columns)
            tiles_map = self.crossbars[0].tiles_map
            # tile_matmul is handed the ADC settings directly.
            out_ = tile_matmul(
                unfolded_batch_input_tiles,
                unfolded_batch_input_tiles_map,
                unfolded_batch_input_shape,
                self.crossbar_operation(
                    self.crossbars,
                    lambda crossbar: crossbar.conductance_matrix,
                ),
                tiles_map,
                crossbar_shape,
                self.ADC_resolution,
                self.ADC_overflow_rate,
                self.quant_method,
            ).T
        else:
            out_ = torch.matmul(
                unfolded_batch_input,
                self.crossbar_operation(
                    self.crossbars,
                    lambda crossbar: crossbar.conductance_matrix,
                ),
            ).T
            if self.quant_method is not None:
                # Bit-width is passed positionally so the call matches the
                # other memristive layers regardless of the keyword name.
                out_ = memtorch.bh.Quantize.quantize(
                    out_,
                    self.ADC_resolution,
                    overflow_rate=self.ADC_overflow_rate,
                    quant_method=self.quant_method,
                )

        out[batch] = out_.view(size=(1, self.out_channels, output_dim))

    out = self.transform_output(out)
    if self.bias is not None:
        out += self.bias.view(-1, 1).expand_as(out)

    return out
def forward(self, input):
    """Method to perform forward propagations.

    When ``self.forward_legacy_enabled`` is set, this is a plain
    ``torch.nn.functional.conv3d``. Otherwise each sample is processed in turn:

    1. It is optionally padded.
    2. It is optionally rescaled to ``[-max_input_voltage, max_input_voltage]``.
    3. It is unfolded (im2col) into receptive-field rows.
    4. The rows are multiplied by the crossbar conductance matrices.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor, indexed as (batch, in_channels, depth, height, width).

    Returns
    -------
    torch.Tensor
        Output tensor of shape (batch, out_channels, out_depth, out_height,
        out_width).
    """
    if self.forward_legacy_enabled:
        return torch.nn.functional.conv3d(
            input.to(self.device),
            self.weight.to(self.device),
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
        )

    output_dim = [
        int(
            (input.shape[2 + i] - self.kernel_size[i] + 2 * self.padding[i])
            / self.stride[i]
        )
        + 1
        for i in range(3)
    ]
    out = torch.zeros((input.shape[0], self.out_channels, *output_dim)).to(
        self.device
    )

    for batch in range(input.shape[0]):
        if not all(item == 0 for item in self.padding):
            # F.pad takes (last dim, ..., first dim) pairs: W, H, then D.
            batch_input = nn.functional.pad(
                input[batch],
                pad=(
                    self.padding[2],
                    self.padding[2],
                    self.padding[1],
                    self.padding[1],
                    self.padding[0],
                    self.padding[0],
                ),
            )
        else:
            batch_input = input[batch]

        if self.max_input_voltage is not None:
            # Exact type match on purpose: bool (and other int/float
            # subclasses) are rejected, as before.
            assert (
                type(self.max_input_voltage) in (int, float)
                and self.max_input_voltage > 0
            ), "The maximum input voltage (max_input_voltage) must be >0."
            batch_input = convert_range(
                batch_input,
                batch_input.min(),
                batch_input.max(),
                -self.max_input_voltage,
                self.max_input_voltage,
            )

        # im2col: one row per output voxel, flattened in the same
        # (in_channels, kD, kH, kW) order used for the crossbar weights.
        unfolded_batch_input = (
            batch_input.unfold(1, self.kernel_size[0], self.stride[0])
            .unfold(2, self.kernel_size[1], self.stride[1])
            .unfold(3, self.kernel_size[2], self.stride[2])
            .permute(1, 2, 3, 0, 4, 5, 6)
            .reshape(
                -1,
                self.in_channels
                * self.kernel_size[0]
                * self.kernel_size[1]
                * self.kernel_size[2],
            )
        )
        unfolded_batch_input_shape = unfolded_batch_input.shape

        if hasattr(self, "non_linear"):
            if self.tile_shape is not None:
                tiles_map = self.crossbars[0].tiles_map
                crossbar_shape = (self.crossbars[0].rows, self.crossbars[0].columns)
            else:
                tiles_map = None
                crossbar_shape = None
            # Non-linear device behaviour is modelled unless ``simulate`` is set.
            nl = not hasattr(self, "simulate")
            out_ = (
                self.crossbar_operation(
                    self.crossbars,
                    lambda crossbar, input_: simulate_matmul(
                        unfolded_batch_input,
                        crossbar,
                        nl=nl,
                        tiles_map=tiles_map,
                        crossbar_shape=crossbar_shape,
                        max_input_voltage=self.max_input_voltage,
                        ADC_resolution=self.ADC_resolution,
                        ADC_overflow_rate=self.ADC_overflow_rate,
                        quant_method=self.quant_method,
                    ),
                    input_=unfolded_batch_input,
                )
                .to(self.device)
                .T
            )
        elif self.tile_shape is not None:
            unfolded_batch_input_tiles, unfolded_batch_input_tiles_map = gen_tiles(
                unfolded_batch_input, self.tile_shape, input=True
            )
            crossbar_shape = (self.crossbars[0].rows, self.crossbars[0].columns)
            tiles_map = self.crossbars[0].tiles_map
            # tile_matmul is handed the ADC settings directly.
            out_ = tile_matmul(
                unfolded_batch_input_tiles,
                unfolded_batch_input_tiles_map,
                unfolded_batch_input_shape,
                self.crossbar_operation(
                    self.crossbars, lambda crossbar: crossbar.conductance_matrix
                ),
                tiles_map,
                crossbar_shape,
                self.ADC_resolution,
                self.ADC_overflow_rate,
                self.quant_method,
            ).T
        else:
            out_ = torch.matmul(
                unfolded_batch_input,
                self.crossbar_operation(
                    self.crossbars, lambda crossbar: crossbar.conductance_matrix
                ),
            ).T
            if self.quant_method is not None:
                # Bit-width is passed positionally so the call matches the
                # other memristive layers regardless of the keyword name.
                out_ = memtorch.bh.Quantize.quantize(
                    out_,
                    self.ADC_resolution,
                    overflow_rate=self.ADC_overflow_rate,
                    quant_method=self.quant_method,
                )

        out[batch] = out_.view(size=(1, self.out_channels, *output_dim))

    out = self.transform_output(out)
    if self.bias is not None:
        # Add the bias to every sample. The previous code indexed the leftover
        # loop variable, so only the last batch element received it.
        out += self.bias.data.view(-1, 1, 1, 1).expand_as(out)

    return out
def forward(self, input):
    """Method to perform forward propagations.

    When ``self.forward_legacy_enabled`` is set, this is a plain matmul with
    the transposed weight plus bias. Otherwise the input is optionally
    rescaled to ``[-max_input_voltage, max_input_voltage]``. It is then
    multiplied by the crossbar conductance matrices, using the non-linear
    simulation, the tiled path, or a direct matmul.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor whose last dimension holds the input features.

    Returns
    -------
    torch.Tensor
        Output tensor.
    """
    if self.forward_legacy_enabled:
        out = torch.matmul(input.to(self.device), self.weight.data.T.to(self.device))
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out

    input_shape = input.shape
    if self.max_input_voltage is not None:
        # Exact type match on purpose: bool (and other int/float subclasses)
        # are rejected, as before.
        assert (
            type(self.max_input_voltage) in (int, float)
            and self.max_input_voltage > 0
        ), "The maximum input voltage (max_input_voltage) must be >0."
        input = convert_range(
            input,
            input.min(),
            input.max(),
            -self.max_input_voltage,
            self.max_input_voltage,
        )

    if hasattr(self, "non_linear"):
        if self.tile_shape is not None:
            tiles_map = self.crossbars[0].tiles_map
            # Pair the tiles map with the crossbar (tile) dimensions, as every
            # other tiled path does. Previously this used the full weight
            # shape here.
            crossbar_shape = (self.crossbars[0].rows, self.crossbars[0].columns)
        else:
            tiles_map = None
            crossbar_shape = None
        # Non-linear device behaviour is modelled unless ``simulate`` is set.
        nl = not hasattr(self, "simulate")
        out = self.crossbar_operation(
            self.crossbars,
            lambda crossbar, input_: simulate_matmul(
                input,
                crossbar,
                nl=nl,
                tiles_map=tiles_map,
                crossbar_shape=crossbar_shape,
                max_input_voltage=self.max_input_voltage,
                ADC_resolution=self.ADC_resolution,
                ADC_overflow_rate=self.ADC_overflow_rate,
                quant_method=self.quant_method,
            ),
            input_=input,
        ).to(self.device)
    elif self.tile_shape is not None:
        input_tiles, input_tiles_map = gen_tiles(input, self.tile_shape, input=True)
        crossbar_shape = (self.crossbars[0].rows, self.crossbars[0].columns)
        tiles_map = self.crossbars[0].tiles_map
        # tile_matmul is handed the ADC settings directly.
        out = tile_matmul(
            input_tiles,
            input_tiles_map,
            input_shape,
            self.crossbar_operation(
                self.crossbars, lambda crossbar: crossbar.conductance_matrix
            ),
            tiles_map,
            crossbar_shape,
            self.ADC_resolution,
            self.ADC_overflow_rate,
            self.quant_method,
        )
    else:
        out = torch.matmul(
            input.to(self.device),
            self.crossbar_operation(
                self.crossbars, lambda crossbar: crossbar.conductance_matrix
            ),
        )
        if self.quant_method is not None:
            # Bit-width is passed positionally so the call matches the other
            # memristive layers regardless of the keyword name.
            out = memtorch.bh.Quantize.quantize(
                out,
                self.ADC_resolution,
                overflow_rate=self.ADC_overflow_rate,
                quant_method=self.quant_method,
            )

    out = self.transform_output(out)
    if self.bias is not None:
        out += self.bias.data.view(1, -1).expand_as(out)

    return out