Beispiel #1
0
    def bias(self, ext, module, g_inp, g_out, backproped):
        """Compute the bias quantity for ``module`` under ``ext``'s strategy.

        Dispatches on ``ext.get_backprop_strategy()`` to either the
        batch-average or the square-root helper.

        Args:
            ext: Extension whose ``get_backprop_strategy()`` selects the branch.
            module: Layer whose bias is processed; forwarded to the helper.
            g_inp: Unused here (presumably kept to match a shared hook
                signature).
            g_out: Unused here (presumably kept to match a shared hook
                signature).
            backproped: Backpropagated quantity; forwarded to the helper.

        Returns:
            The result of ``_bias_for_batch_average`` or ``_bias_for_sqrt``.

        Raises:
            NotImplementedError: If the strategy is neither batch-average nor
                square-root.
        """
        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self._bias_for_batch_average(module, backproped)
        if BackpropStrategy.is_sqrt(bp_strategy):
            return self._bias_for_sqrt(module, backproped)

        # Fail loudly rather than implicitly returning None, which would only
        # surface later as a confusing error in the caller.
        raise NotImplementedError(
            f"Unsupported backprop strategy {bp_strategy!r} in "
            f"{self.__class__.__name__}.bias")
Beispiel #2
0
    def bias(self, ext, module, g_inp, g_out, backproped):
        """Compute the bias quantity under ``ext``'s backprop strategy.

        Validates ``ext``/``module`` via ``self.check_parameters`` first, then
        dispatches on ``ext.get_backprop_strategy()``.

        Args:
            ext: Extension whose ``get_backprop_strategy()`` selects the branch.
            module: Layer being processed; only passed to ``check_parameters``.
            g_inp: Unused here (presumably kept to match a shared hook
                signature).
            g_out: Unused here (presumably kept to match a shared hook
                signature).
            backproped: Backpropagated quantity; forwarded to the helper.

        Returns:
            The result of ``_bias_for_batch_average`` or ``_factor_from_sqrt``.

        Raises:
            NotImplementedError: If the strategy is neither batch-average nor
                square-root. ``check_parameters`` may raise as well.
        """
        self.check_parameters(ext, module)
        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self._bias_for_batch_average(backproped)
        if BackpropStrategy.is_sqrt(bp_strategy):
            return self._factor_from_sqrt(backproped)

        # Fail loudly rather than implicitly returning None for an
        # unrecognized strategy.
        raise NotImplementedError(
            f"Unsupported backprop strategy {bp_strategy!r} in "
            f"{self.__class__.__name__}.bias")
Beispiel #3
0
    def weight(self, ext, module, g_inp, g_out, backproped):
        """Compute the weight quantity under ``ext``'s backprop strategy.

        Validates ``ext``/``module`` via ``self.check_parameters`` first, then
        dispatches on ``ext.get_backprop_strategy()``.

        Args:
            ext: Extension whose ``get_backprop_strategy()`` selects the
                branch; also forwarded to the helper.
            module: Layer whose weight is processed; forwarded to the helper.
            g_inp: Unused here (presumably kept to match a shared hook
                signature).
            g_out: Unused here (presumably kept to match a shared hook
                signature).
            backproped: Backpropagated quantity; forwarded to the helper.

        Returns:
            The result of ``_weight_for_batch_average`` or ``_weight_for_sqrt``.

        Raises:
            NotImplementedError: If the strategy is neither batch-average nor
                square-root. ``check_parameters`` may raise as well.
        """
        self.check_parameters(ext, module)
        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self._weight_for_batch_average(ext, module, backproped)
        if BackpropStrategy.is_sqrt(bp_strategy):
            return self._weight_for_sqrt(ext, module, backproped)

        # Fail loudly rather than implicitly returning None for an
        # unrecognized strategy.
        raise NotImplementedError(
            f"Unsupported backprop strategy {bp_strategy!r} in "
            f"{self.__class__.__name__}.weight")
Beispiel #4
0
    def backpropagate(self, ext, module, g_inp, g_out, backproped):
        """Backpropagate ``backproped`` through ``module`` per ``ext``'s strategy.

        Dispatches on ``ext.get_backprop_strategy()`` to
        ``backpropagate_batch_average`` or ``backpropagate_sqrt``. All
        arguments are forwarded unchanged to the selected method.

        Args:
            ext: Extension whose ``get_backprop_strategy()`` selects the branch.
            module: Layer being backpropagated through.
            g_inp: Forwarded unchanged to the selected method.
            g_out: Forwarded unchanged to the selected method.
            backproped: Quantity received from the downstream module.

        Returns:
            The result of the selected ``backpropagate_*`` method.

        Raises:
            NotImplementedError: If the strategy is neither batch-average nor
                square-root.
        """
        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self.backpropagate_batch_average(ext, module, g_inp, g_out,
                                                    backproped)
        if BackpropStrategy.is_sqrt(bp_strategy):
            return self.backpropagate_sqrt(ext, module, g_inp, g_out,
                                           backproped)

        # Fail loudly rather than implicitly returning None, which would be
        # passed on as the "backpropagated" quantity.
        raise NotImplementedError(
            f"Unsupported backprop strategy {bp_strategy!r} in "
            f"{self.__class__.__name__}.backpropagate")
Beispiel #5
0
    def weight(self, ext, module, g_inp, g_out, backproped):
        """Compute the weight quantity under ``ext``'s backprop strategy.

        Only ``module.groups == 1`` is supported. After that check, dispatches
        on ``ext.get_backprop_strategy()``.

        Args:
            ext: Extension whose ``get_backprop_strategy()`` selects the
                branch; also forwarded to the helper and named in errors.
            module: Layer whose weight is processed; must expose ``groups``.
            g_inp: Unused here (presumably kept to match a shared hook
                signature).
            g_out: Unused here (presumably kept to match a shared hook
                signature).
            backproped: Backpropagated quantity; forwarded to the helper.

        Returns:
            The result of ``_weight_for_batch_average`` or ``_weight_for_sqrt``.

        Raises:
            NotImplementedError: If ``module.groups != 1``, or if the strategy
                is neither batch-average nor square-root.
        """
        if module.groups != 1:
            raise NotImplementedError(
                f"groups ≠ 1 is not supported by {ext.__class__.__name__} "
                f"(got {module.groups}).")

        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self._weight_for_batch_average(ext, module, backproped)
        if BackpropStrategy.is_sqrt(bp_strategy):
            return self._weight_for_sqrt(ext, module, backproped)

        # Fail loudly rather than implicitly returning None for an
        # unrecognized strategy.
        raise NotImplementedError(
            f"Unsupported backprop strategy {bp_strategy!r} in "
            f"{self.__class__.__name__}.weight")