def bias(self, ext, module, g_inp, g_out, backproped):
    """Return the bias result for ``module`` under the extension's backprop strategy.

    Dispatches on ``ext.get_backprop_strategy()``: uses
    ``self._bias_for_batch_average`` when ``BackpropStrategy.is_batch_average``
    accepts the strategy, and ``self._bias_for_sqrt`` when
    ``BackpropStrategy.is_sqrt`` does.

    Args:
        ext: Extension object; queried via ``get_backprop_strategy()``.
        module: Layer whose bias is processed; forwarded to the helper.
        g_inp: Not used by this method (presumably kept to match a shared
            hook signature -- confirm).
        g_out: Not used by this method (same note as ``g_inp``).
        backproped: Backpropagated quantity forwarded to the helper.

    Returns:
        Whatever the selected helper returns.

    Raises:
        NotImplementedError: If the strategy is accepted by neither predicate.
    """
    bp_strategy = ext.get_backprop_strategy()
    if BackpropStrategy.is_batch_average(bp_strategy):
        return self._bias_for_batch_average(module, backproped)
    if BackpropStrategy.is_sqrt(bp_strategy):
        return self._bias_for_sqrt(module, backproped)
    # Fail loudly rather than implicitly returning None, which the caller
    # would otherwise receive as if it were a valid result.
    raise NotImplementedError(
        f"{ext.__class__.__name__}: unsupported backpropagation strategy "
        f"{bp_strategy!r}."
    )
def bias(self, ext, module, g_inp, g_out, backproped):
    """Return the bias result for ``module`` under the extension's backprop strategy.

    First calls ``self.check_parameters(ext, module)``, then dispatches on
    ``ext.get_backprop_strategy()``: uses ``self._bias_for_batch_average``
    when ``BackpropStrategy.is_batch_average`` accepts the strategy, and
    ``self._factor_from_sqrt`` when ``BackpropStrategy.is_sqrt`` does.

    Args:
        ext: Extension object; passed to ``check_parameters`` and queried via
            ``get_backprop_strategy()``.
        module: Layer being processed; only passed to ``check_parameters``
            here (neither helper receives it).
        g_inp: Not used by this method (presumably kept to match a shared
            hook signature -- confirm).
        g_out: Not used by this method (same note as ``g_inp``).
        backproped: Backpropagated quantity forwarded to the helper.

    Returns:
        Whatever the selected helper returns.

    Raises:
        NotImplementedError: If the strategy is accepted by neither predicate.
        Any exception raised by ``check_parameters`` (not visible here).
    """
    self.check_parameters(ext, module)
    bp_strategy = ext.get_backprop_strategy()
    if BackpropStrategy.is_batch_average(bp_strategy):
        return self._bias_for_batch_average(backproped)
    if BackpropStrategy.is_sqrt(bp_strategy):
        return self._factor_from_sqrt(backproped)
    # Fail loudly rather than implicitly returning None, which the caller
    # would otherwise receive as if it were a valid result.
    raise NotImplementedError(
        f"{ext.__class__.__name__}: unsupported backpropagation strategy "
        f"{bp_strategy!r}."
    )
def weight(self, ext, module, g_inp, g_out, backproped):
    """Return the weight result for ``module`` under the extension's backprop strategy.

    First calls ``self.check_parameters(ext, module)``, then dispatches on
    ``ext.get_backprop_strategy()``: uses ``self._weight_for_batch_average``
    when ``BackpropStrategy.is_batch_average`` accepts the strategy, and
    ``self._weight_for_sqrt`` when ``BackpropStrategy.is_sqrt`` does.

    Args:
        ext: Extension object; passed to ``check_parameters``, queried via
            ``get_backprop_strategy()``, and forwarded to the helper.
        module: Layer whose weight is processed; forwarded to the helper.
        g_inp: Not used by this method (presumably kept to match a shared
            hook signature -- confirm).
        g_out: Not used by this method (same note as ``g_inp``).
        backproped: Backpropagated quantity forwarded to the helper.

    Returns:
        Whatever the selected helper returns.

    Raises:
        NotImplementedError: If the strategy is accepted by neither predicate.
        Any exception raised by ``check_parameters`` (not visible here).
    """
    self.check_parameters(ext, module)
    bp_strategy = ext.get_backprop_strategy()
    if BackpropStrategy.is_batch_average(bp_strategy):
        return self._weight_for_batch_average(ext, module, backproped)
    if BackpropStrategy.is_sqrt(bp_strategy):
        return self._weight_for_sqrt(ext, module, backproped)
    # Fail loudly rather than implicitly returning None, which the caller
    # would otherwise receive as if it were a valid result.
    raise NotImplementedError(
        f"{ext.__class__.__name__}: unsupported backpropagation strategy "
        f"{bp_strategy!r}."
    )
def backpropagate(self, ext, module, g_inp, g_out, backproped):
    """Backpropagate ``backproped`` through ``module`` per the extension's strategy.

    Dispatches on ``ext.get_backprop_strategy()``: calls
    ``self.backpropagate_batch_average`` when
    ``BackpropStrategy.is_batch_average`` accepts the strategy, and
    ``self.backpropagate_sqrt`` when ``BackpropStrategy.is_sqrt`` does. All
    arguments are forwarded unchanged.

    Args:
        ext: Extension object; queried via ``get_backprop_strategy()``.
        module: Layer being backpropagated through.
        g_inp: Forwarded to the selected method.
        g_out: Forwarded to the selected method.
        backproped: Incoming backpropagated quantity.

    Returns:
        Whatever the selected method returns.

    Raises:
        NotImplementedError: If the strategy is accepted by neither predicate.
    """
    bp_strategy = ext.get_backprop_strategy()
    if BackpropStrategy.is_batch_average(bp_strategy):
        return self.backpropagate_batch_average(
            ext, module, g_inp, g_out, backproped
        )
    if BackpropStrategy.is_sqrt(bp_strategy):
        return self.backpropagate_sqrt(ext, module, g_inp, g_out, backproped)
    # Fail loudly rather than implicitly returning None, which the caller
    # would otherwise receive as if it were a valid backpropagated quantity.
    raise NotImplementedError(
        f"{ext.__class__.__name__}: unsupported backpropagation strategy "
        f"{bp_strategy!r}."
    )
def weight(self, ext, module, g_inp, g_out, backproped):
    """Return the weight result for ``module`` under the extension's backprop strategy.

    Only ``module.groups == 1`` is supported. Dispatches on
    ``ext.get_backprop_strategy()``: uses ``self._weight_for_batch_average``
    when ``BackpropStrategy.is_batch_average`` accepts the strategy, and
    ``self._weight_for_sqrt`` when ``BackpropStrategy.is_sqrt`` does.

    Args:
        ext: Extension object; queried via ``get_backprop_strategy()``, used
            for its class name in error messages, and forwarded to the helper.
        module: Layer whose weight is processed; must have ``groups == 1``.
        g_inp: Not used by this method (presumably kept to match a shared
            hook signature -- confirm).
        g_out: Not used by this method (same note as ``g_inp``).
        backproped: Backpropagated quantity forwarded to the helper.

    Returns:
        Whatever the selected helper returns.

    Raises:
        NotImplementedError: If ``module.groups != 1``, or if the strategy is
            accepted by neither predicate.
    """
    if module.groups != 1:
        raise NotImplementedError(
            f"groups ≠ 1 is not supported by {ext.__class__.__name__} "
            f"(got {module.groups})."
        )
    bp_strategy = ext.get_backprop_strategy()
    if BackpropStrategy.is_batch_average(bp_strategy):
        return self._weight_for_batch_average(ext, module, backproped)
    if BackpropStrategy.is_sqrt(bp_strategy):
        return self._weight_for_sqrt(ext, module, backproped)
    # Fail loudly rather than implicitly returning None, which the caller
    # would otherwise receive as if it were a valid result.
    raise NotImplementedError(
        f"{ext.__class__.__name__}: unsupported backpropagation strategy "
        f"{bp_strategy!r}."
    )