예제 #1
0
    def bias(self, ext, module, g_inp, g_out, backproped):
        """Compute the bias quantity for the extension's backprop strategy.

        Dispatches on ``ext.get_backprop_strategy()``: the batch-average
        strategy is delegated to ``self._bias_for_batch_average`` and the
        square-root strategy to ``self._bias_for_sqrt``.

        Args:
            ext: Extension object; its ``get_backprop_strategy()`` selects
                the branch and its class name appears in error messages.
            module: Module being processed; forwarded to the helper.
            g_inp: Unused in this method.
            g_out: Unused in this method.
            backproped: Backpropagated quantity; forwarded unchanged to the
                selected helper.

        Returns:
            Whatever the selected helper returns.

        Raises:
            NotImplementedError: If the strategy is neither batch-average nor
                square-root (rather than silently returning ``None``).
        """
        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self._bias_for_batch_average(module, backproped)
        elif BackpropStrategy.is_sqrt(bp_strategy):
            return self._bias_for_sqrt(module, backproped)
        else:
            # Fail loudly: an implicit ``None`` here would only surface later,
            # far from the actual cause.
            raise NotImplementedError(
                f"Backpropagation strategy {bp_strategy!r} is not supported "
                f"by {ext.__class__.__name__}."
            )
예제 #2
0
파일: linear.py 프로젝트: f-dangel/backpack
    def bias(self, ext, module, g_inp, g_out, backproped):
        """Compute the bias quantity for the extension's backprop strategy.

        First calls ``self.check_parameters(ext, module)`` (defined elsewhere;
        presumably validates that this extension/module pair is supported —
        confirm). Then dispatches on ``ext.get_backprop_strategy()``: the
        batch-average strategy is delegated to ``self._bias_for_batch_average``
        and the square-root strategy to ``self._factor_from_sqrt``.

        Args:
            ext: Extension object; its ``get_backprop_strategy()`` selects
                the branch and its class name appears in error messages.
            module: Module being processed; only passed to
                ``check_parameters`` here.
            g_inp: Unused in this method.
            g_out: Unused in this method.
            backproped: Backpropagated quantity; forwarded unchanged to the
                selected helper.

        Returns:
            Whatever the selected helper returns.

        Raises:
            NotImplementedError: If the strategy is neither batch-average nor
                square-root (rather than silently returning ``None``).
        """
        self.check_parameters(ext, module)
        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self._bias_for_batch_average(backproped)
        elif BackpropStrategy.is_sqrt(bp_strategy):
            return self._factor_from_sqrt(backproped)
        else:
            # Fail loudly: an implicit ``None`` here would only surface later,
            # far from the actual cause.
            raise NotImplementedError(
                f"Backpropagation strategy {bp_strategy!r} is not supported "
                f"by {ext.__class__.__name__}."
            )
예제 #3
0
파일: linear.py 프로젝트: f-dangel/backpack
    def weight(self, ext, module, g_inp, g_out, backproped):
        """Compute the weight quantity for the extension's backprop strategy.

        First calls ``self.check_parameters(ext, module)`` (defined elsewhere;
        presumably validates that this extension/module pair is supported —
        confirm). Then dispatches on ``ext.get_backprop_strategy()``: the
        batch-average strategy is delegated to
        ``self._weight_for_batch_average`` and the square-root strategy to
        ``self._weight_for_sqrt``.

        Args:
            ext: Extension object; selects the branch and is forwarded to the
                helper.
            module: Module being processed; forwarded to the helper.
            g_inp: Unused in this method.
            g_out: Unused in this method.
            backproped: Backpropagated quantity; forwarded unchanged to the
                selected helper.

        Returns:
            Whatever the selected helper returns.

        Raises:
            NotImplementedError: If the strategy is neither batch-average nor
                square-root (rather than silently returning ``None``).
        """
        self.check_parameters(ext, module)
        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self._weight_for_batch_average(ext, module, backproped)

        elif BackpropStrategy.is_sqrt(bp_strategy):
            return self._weight_for_sqrt(ext, module, backproped)

        else:
            # Fail loudly: an implicit ``None`` here would only surface later,
            # far from the actual cause.
            raise NotImplementedError(
                f"Backpropagation strategy {bp_strategy!r} is not supported "
                f"by {ext.__class__.__name__}."
            )
예제 #4
0
    def backpropagate(self, ext, module, g_inp, g_out, backproped):
        """Backpropagate ``backproped`` according to the extension's strategy.

        Dispatches on ``ext.get_backprop_strategy()``: the batch-average
        strategy is delegated to ``self.backpropagate_batch_average`` and the
        square-root strategy to ``self.backpropagate_sqrt``. All arguments are
        forwarded unchanged.

        Args:
            ext: Extension object; selects the branch.
            module: Module being processed.
            g_inp: Gradient w.r.t. the module inputs (as passed by the caller).
            g_out: Gradient w.r.t. the module outputs (as passed by the caller).
            backproped: Quantity backpropagated so far.

        Returns:
            Whatever the selected helper returns.

        Raises:
            NotImplementedError: If the strategy is neither batch-average nor
                square-root (rather than silently returning ``None``).
        """
        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self.backpropagate_batch_average(ext, module, g_inp, g_out,
                                                    backproped)

        elif BackpropStrategy.is_sqrt(bp_strategy):
            return self.backpropagate_sqrt(ext, module, g_inp, g_out,
                                           backproped)

        else:
            # Fail loudly: an implicit ``None`` here would only surface later,
            # far from the actual cause.
            raise NotImplementedError(
                f"Backpropagation strategy {bp_strategy!r} is not supported "
                f"by {ext.__class__.__name__}."
            )
예제 #5
0
파일: conv2d.py 프로젝트: f-dangel/backpack
    def weight(self, ext, module, g_inp, g_out, backproped):
        """Compute the convolution weight quantity for the backprop strategy.

        Grouped convolutions (``module.groups != 1``) are rejected up front.
        Otherwise dispatches on ``ext.get_backprop_strategy()``: the
        batch-average strategy is delegated to
        ``self._weight_for_batch_average`` and the square-root strategy to
        ``self._weight_for_sqrt``.

        Args:
            ext: Extension object; selects the branch, is forwarded to the
                helper, and its class name appears in error messages.
            module: Module being processed; must have ``groups == 1``.
            g_inp: Unused in this method.
            g_out: Unused in this method.
            backproped: Backpropagated quantity; forwarded unchanged to the
                selected helper.

        Returns:
            Whatever the selected helper returns.

        Raises:
            NotImplementedError: If ``module.groups != 1``, or if the strategy
                is neither batch-average nor square-root (rather than silently
                returning ``None``).
        """
        if module.groups != 1:
            raise NotImplementedError(
                f"groups ≠ 1 is not supported by {ext.__class__.__name__} "
                f"(got {module.groups}).")

        bp_strategy = ext.get_backprop_strategy()

        if BackpropStrategy.is_batch_average(bp_strategy):
            return self._weight_for_batch_average(ext, module, backproped)

        elif BackpropStrategy.is_sqrt(bp_strategy):
            return self._weight_for_sqrt(ext, module, backproped)

        else:
            # Fail loudly: an implicit ``None`` here would only surface later,
            # far from the actual cause.
            raise NotImplementedError(
                f"Backpropagation strategy {bp_strategy!r} is not supported "
                f"by {ext.__class__.__name__}."
            )