Example #1
0
 def forward(self, stats_input: Tensor) -> Tensor:
     """Return a scaling value, collecting running statistics while training.

     For the first `collect_stats_steps` training steps the scale comes
     straight from the batch statistics, while an exponential moving
     average is accumulated into ``self.value`` in-place. On the step
     where the counter equals `collect_stats_steps`, the accumulated
     value is preprocessed once (in-place) into the restricted domain;
     from then on — and always in eval mode — the stored value is used
     through the restrict/clamp transforms.
     """
     if self.training:
         if self.counter < self.collect_stats_steps:
             # Still collecting: optionally permute, then reshape the
             # input before reducing it with the stats function.
             if self.stats_permute_dims is not None:
                 stats_input = stats_input.permute(
                     *self.stats_permute_dims).contiguous()
             stats_input = self.stats_input_view_shape_impl(stats_input)
             stats = self.stats(stats_input)
             if self.counter == 0:
                 # First step seeds the running value. The in-place mul
                 # assumes self.value was initialized to ones — TODO confirm.
                 self.value.detach().mul_(stats.detach())
             else:
                 # Momentum-weighted moving average, done in-place on the
                 # detached tensor so the parameter object is never replaced.
                 self.value.detach().mul_(1 - self.momentum)
                 self.value.detach().add_(self.momentum * stats.detach())
             self.counter = self.counter + 1
             return stats
         elif self.counter == self.collect_stats_steps:
             # One-time transition: map the accumulated value into the
             # restricted domain in-place, then switch to learned mode.
             self.restrict_inplace_preprocess(self.value.detach())
             self.counter = self.counter + 1
             return self.restrict_clamp_scaling(
                 abs_binary_sign_grad(self.value))
         else:
             # Collection finished: use the learned value.
             return self.restrict_clamp_scaling(
                 abs_binary_sign_grad(self.value))
     # Eval mode: always read the stored value (even mid-collection).
     out = self.restrict_clamp_scaling(abs_binary_sign_grad(self.value))
     return out
Example #2
0
 def training_forward(self, stats_input: Tensor) -> Tensor:
     """Training-mode scale computation with running-stats collection.

     While ``counter < collect_stats_steps``, batch statistics are folded
     into ``self.buffer`` in-place and returned directly. On the step
     where the two are equal, the buffer is preprocessed and copied into
     ``self.value``; from then on the learned value is used.
     """
     step = self.counter
     if step < self.collect_stats_steps:
         # Accumulate this batch's statistics into the buffer.
         viewed = self.stats_input_view_shape_impl(stats_input)
         batch_stats = self.stats(viewed)
         next_step = step + 1
         if step == 0:
             _inplace_init(self.buffer, batch_stats.detach())
         else:
             _inplace_update(
                 self.buffer, batch_stats.detach(), self.momentum, step,
                 next_step)
         self.counter = next_step
         return batch_stats
     if step == self.collect_stats_steps:
         # One-time handover from the stats buffer to the learned value.
         self.restrict_inplace_preprocess(self.buffer)
         _inplace_init(self.value.detach(), self.buffer)
         self.counter = step + 1
     return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
Example #3
0
 def forward(self, stats_input: Tensor) -> Tensor:
     """Dispatch to training-mode logic or compute the eval-mode scale.

     In eval mode, while statistics are still being collected (i.e. the
     handover step has not completed), the preprocessed buffer stands in
     for the learned value.
     """
     if self.training:
         # All stats collection and counter transitions happen there.
         return self.training_forward(stats_input)
     still_collecting = self.counter <= self.collect_stats_steps
     if still_collecting:
         source = self.restrict_preprocess(self.buffer)
     else:
         source = self.value
     return abs_binary_sign_grad(self.restrict_clamp_scaling(source))
Example #4
0
 def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
     """Quantize the tracked value to integer at the given scale/bit width.

     Training mode delegates to ``training_forward``; in eval mode the
     stats buffer is read while collection is still in progress,
     otherwise the learned value.
     """
     if self.training:
         value = self.training_forward(x)
     elif self.counter <= self.collect_stats_steps:
         value = self.buffer
     else:
         value = self.value
     value = abs_binary_sign_grad(value)
     # Convert to integer through the quantizer, clamped at min_int.
     lower_bound = self.int_quant.min_int(bit_width)
     return self.int_quant.to_int(scale, lower_bound, bit_width, value)
Example #5
0
 def training_forward(self, stats_input: Tensor) -> Tensor:
     """Training-mode scale computation with running-stats collection.

     For the first `collect_stats_steps` steps, batch statistics are
     accumulated into ``self.buffer`` in-place and returned directly. On
     the transition step the buffer is preprocessed in-place and copied
     into the learnable ``self.value``; afterwards the learned value is
     used through the restrict/clamp transforms.
     """
     if self.counter < self.collect_stats_steps:
         stats_input = self.stats_input_view_shape_impl(stats_input)
         stats = self.stats(stats_input)
         new_counter = self.counter + 1
         if self.counter == 0:
             # First step seeds the buffer; the in-place mul assumes the
             # buffer was initialized to ones — TODO confirm.
             inplace_tensor_mul(self.buffer, stats.detach())
         else:
             # Momentum-weighted running update of the buffer, in-place.
             inplace_momentum_update(self.buffer, stats.detach(),
                                     self.momentum, self.counter,
                                     new_counter)
         self.counter = new_counter
         # Workaround: reference self.value in the graph so DDP does not
         # require find_unused_parameters=True during stats collection.
         stats = stats + 0. * self.value
         return stats
     elif self.counter == self.collect_stats_steps:
         # One-time handover: map the buffer into the restricted domain
         # in-place, then copy it into the learnable value.
         self.restrict_inplace_preprocess(self.buffer)
         inplace_tensor_mul(self.value.detach(), self.buffer)
         self.counter = self.counter + 1
         return abs_binary_sign_grad(self.restrict_clamp_scaling(
             self.value))
     else:
         # Collection finished: use the learned value.
         return abs_binary_sign_grad(self.restrict_clamp_scaling(
             self.value))
Example #6
0
 def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
     """Compute an integer value from a running negative-min statistic.

     During training, the statistic is accumulated into ``self.value``
     in-place for the first steps; afterwards — and always in eval mode —
     the stored value is used. The result is then quantized via
     ``self.int_quant``.
     """
     if self.training:
         # NOTE(review): `<=` here means stats are collected for
         # collect_stats_steps + 1 steps, and unlike the sibling
         # variants there is no explicit transition step — confirm
         # this is intended.
         if self.counter <= self.collect_stats_steps:
             if self.stats_permute_dims is not None:
                 x = x.permute(*self.stats_permute_dims).contiguous()
             stats_input = self.stats_input_view_shape_impl(x)
             stats = self.negative_min_or_zero(stats_input)
             if self.counter == 0:
                 # First step seeds the value; the in-place add assumes
                 # it was initialized to zeros — TODO confirm.
                 self.value.detach().add_(stats.detach())
             else:
                 # In-place momentum-weighted moving average.
                 self.value.detach().mul_(1 - self.momentum)
                 self.value.detach().add_(self.momentum * stats.detach())
             self.counter = self.counter + 1
             out = stats
         else:
             out = self.value
     else:
         out = self.value
     out = abs_binary_sign_grad(out)
     # Convert to integer through the quantizer, clamped at min_int.
     min_int = self.int_quant.min_int(bit_width)
     out = self.int_quant.to_int(scale, min_int, bit_width, out)
     return out
Example #7
0
 def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
     """Quantize the stored value to integer at the given scale/bit width.

     `x` is accepted for interface compatibility but not used: the output
     depends only on the stored ``self.value``.
     """
     out = abs_binary_sign_grad(self.value)
     min_int = self.int_quant.min_int(bit_width)
     # Fixed argument order: (scale, min_int, bit_width, out), matching
     # the other int_quant.to_int call sites in this file — the original
     # swapped bit_width and min_int.
     out = self.int_quant.to_int(scale, min_int, bit_width, out)
     return out
Example #8
0
 def forward(self, placeholder: Tensor) -> Tensor:
     """Return the restricted/clamped stored scale; `placeholder` is ignored."""
     return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
Example #9
0
 def forward(self) -> Tensor:
     """Return the bit width: base plus the (sign-gradient-safe absolute)
     learned offset, mapped through the bit-width restriction."""
     offset = abs_binary_sign_grad(self.bit_width_offset)
     return self.restrict_bit_width_impl(offset + self.bit_width_base)