def forward(self, stats_input: Tensor) -> Tensor:
    if self.training:
        # Phase 1: collect statistics and maintain a running average in self.value
        if self.counter < self.collect_stats_steps:
            if self.stats_permute_dims is not None:
                stats_input = stats_input.permute(*self.stats_permute_dims).contiguous()
            stats_input = self.stats_input_view_shape_impl(stats_input)
            stats = self.stats(stats_input)
            if self.counter == 0:
                self.value.detach().mul_(stats.detach())
            else:
                self.value.detach().mul_(1 - self.momentum)
                self.value.detach().add_(self.momentum * stats.detach())
            self.counter = self.counter + 1
            return stats
        # Phase 2: one-off switch from collected stats to a learned parameter
        elif self.counter == self.collect_stats_steps:
            self.restrict_inplace_preprocess(self.value.detach())
            self.counter = self.counter + 1
            return self.restrict_clamp_scaling(abs_binary_sign_grad(self.value))
        # Phase 3: self.value is trained as an ordinary parameter
        else:
            return self.restrict_clamp_scaling(abs_binary_sign_grad(self.value))
    out = self.restrict_clamp_scaling(abs_binary_sign_grad(self.value))
    return out
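# A minimal sketch of what abs_binary_sign_grad is assumed to do, inferred from
# its name and its use above (a learned scale must stay positive but remain
# trainable through zero): forward is abs(x); backward multiplies by a binary
# sign that is +1 at x == 0, so the gradient never vanishes there. This is an
# illustration, not the library's actual implementation.
import torch
from torch import Tensor


class _AbsBinarySignGradFn(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.save_for_backward(x)
        return x.abs()

    @staticmethod
    def backward(ctx, grad_output: Tensor) -> Tensor:
        x, = ctx.saved_tensors
        # Binary sign: +1 for x >= 0, -1 otherwise (never 0)
        binary_sign = torch.where(x >= 0, torch.ones_like(x), -torch.ones_like(x))
        return grad_output * binary_sign


abs_binary_sign_grad = _AbsBinarySignGradFn.apply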
def training_forward(self, stats_input: Tensor) -> Tensor:
    if self.counter < self.collect_stats_steps:
        stats_input = self.stats_input_view_shape_impl(stats_input)
        stats = self.stats(stats_input)
        new_counter = self.counter + 1
        if self.counter == 0:
            _inplace_init(self.buffer, stats.detach())
        else:
            _inplace_update(
                self.buffer, stats.detach(), self.momentum, self.counter, new_counter)
        self.counter = new_counter
        return stats
    elif self.counter == self.collect_stats_steps:
        self.restrict_inplace_preprocess(self.buffer)
        _inplace_init(self.value.detach(), self.buffer)
        self.counter = self.counter + 1
        return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
    else:
        return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
def forward(self, stats_input: Tensor) -> Tensor:
    if self.training:
        return self.training_forward(stats_input)
    else:
        # In eval mode, fall back to the running buffer until the
        # collection phase has completed at least once
        if self.counter <= self.collect_stats_steps:
            out = self.buffer
            out = self.restrict_preprocess(out)
        else:
            out = self.value
        out = abs_binary_sign_grad(self.restrict_clamp_scaling(out))
        return out
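# Plausible implementations of the in-place helpers called above, inferred
# purely from their call sites; the real helpers may differ. _inplace_init
# seeds the buffer with the first batch of statistics, while _inplace_update
# applies a momentum-weighted update (or, if momentum were None, a cumulative
# average, which would explain why both counters are passed in).
from typing import Optional

import torch
from torch import Tensor


def _inplace_init(tensor: Tensor, value: Tensor) -> None:
    # Overwrite the buffer with the first collected statistics
    tensor.copy_(value)


def _inplace_update(
        tensor: Tensor,
        value: Tensor,
        momentum: Optional[float],
        counter: int,
        new_counter: int) -> None:
    if momentum is None:
        # Cumulative running mean over all collected steps
        tensor.mul_(counter / new_counter)
        tensor.add_(value / new_counter)
    else:
        # Standard exponential moving average
        tensor.mul_(1. - momentum)
        tensor.add_(momentum * value)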
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
    if self.training:
        out = self.training_forward(x)
    else:
        if self.counter <= self.collect_stats_steps:
            out = self.buffer
        else:
            out = self.value
        out = abs_binary_sign_grad(out)
    min_int = self.int_quant.min_int(bit_width)
    out = self.int_quant.to_int(scale, min_int, bit_width, out)
    return out
def training_forward(self, stats_input: Tensor) -> Tensor:
    if self.counter < self.collect_stats_steps:
        stats_input = self.stats_input_view_shape_impl(stats_input)
        stats = self.stats(stats_input)
        new_counter = self.counter + 1
        if self.counter == 0:
            inplace_tensor_mul(self.buffer, stats.detach())
        else:
            inplace_momentum_update(
                self.buffer, stats.detach(), self.momentum, self.counter, new_counter)
        self.counter = new_counter
        # Workaround to avoid find_unused_parameters=True in DDP:
        # keep self.value in the autograd graph with a zero-weighted term
        stats = stats + 0. * self.value
        return stats
    elif self.counter == self.collect_stats_steps:
        self.restrict_inplace_preprocess(self.buffer)
        inplace_tensor_mul(self.value.detach(), self.buffer)
        self.counter = self.counter + 1
        return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
    else:
        return abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
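# Why the "stats + 0. * self.value" term above helps: during the collection
# phase self.value is only updated in-place and never used in the forward
# graph, so DistributedDataParallel would flag it as an unused parameter.
# A zero-weighted term keeps it in the autograd graph at no numerical cost,
# as this small self-contained check illustrates:
import torch

value = torch.nn.Parameter(torch.ones(1))
stats = torch.randn(1)
out = stats + 0. * value
out.sum().backward()
assert value.grad is not None          # parameter participates in backward
assert torch.all(value.grad == 0.)     # but receives a zero gradient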
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
    if self.training:
        if self.counter <= self.collect_stats_steps:
            if self.stats_permute_dims is not None:
                x = x.permute(*self.stats_permute_dims).contiguous()
            stats_input = self.stats_input_view_shape_impl(x)
            stats = self.negative_min_or_zero(stats_input)
            if self.counter == 0:
                self.value.detach().add_(stats.detach())
            else:
                self.value.detach().mul_(1 - self.momentum)
                self.value.detach().add_(self.momentum * stats.detach())
            self.counter = self.counter + 1
            out = stats
        else:
            out = self.value
    else:
        out = self.value
    out = abs_binary_sign_grad(out)
    min_int = self.int_quant.min_int(bit_width)
    out = self.int_quant.to_int(scale, min_int, bit_width, out)
    return out
def forward(self, x: Tensor, scale: Tensor, bit_width: Tensor) -> Tensor:
    out = abs_binary_sign_grad(self.value)
    min_int = self.int_quant.min_int(bit_width)
    out = self.int_quant.to_int(scale, min_int, bit_width, out)
    return out
def forward(self, placeholder: Tensor) -> Tensor:
    # The input is ignored; it only satisfies the module interface
    value = abs_binary_sign_grad(self.restrict_clamp_scaling(self.value))
    return value
def forward(self) -> Tensor:
    # Learn an offset on top of a fixed base bit width, kept non-negative
    # by the graded abs, then restrict the sum to valid bit widths
    bit_width = abs_binary_sign_grad(self.bit_width_offset) + self.bit_width_base
    bit_width = self.restrict_bit_width_impl(bit_width)
    return bit_width
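# A hedged usage sketch of the learned bit-width pattern above: the base is a
# fixed buffer, the offset is a trainable parameter kept non-negative by the
# graded abs, and restrict_bit_width_impl (e.g. a rounding op, possibly with a
# straight-through gradient) maps the sum back to an integer bit width. The
# names and values below are illustrative, not the library's API.
import torch

bit_width_base = torch.tensor(2.)                   # e.g. a minimum of 2 bits
bit_width_offset = torch.nn.Parameter(torch.tensor(6.))
bit_width = bit_width_offset.abs() + bit_width_base
bit_width = torch.round(bit_width)                  # stand-in for restrict_bit_width_impl
# bit_width is now 8.0 in this example, while bit_width_offset stays trainable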