from chainer.functions.normalization import batch_normalization

# AdaLossChainer is assumed to be importable from the enclosing ada_loss
# package; its exact import path is omitted in this excerpt.


class AdaLossFixedBatchNormalization(
        batch_normalization.FixedBatchNormalization):
    """ Wrapping the fixed batch normalization function """

    def __init__(self, eps=2e-5, axis=None, ada_loss_cfg=None):
        super().__init__(eps=eps, axis=axis)

        if ada_loss_cfg is None:
            ada_loss_cfg = {}
        self.ada_loss = AdaLossChainer(**ada_loss_cfg)

    def backward(self, indexes, grad_outputs):
        """ Wrap around the original backward function """
        x, gamma, mean, var = self.get_retained_inputs()
        gy, = grad_outputs

        f = batch_normalization.FixedBatchNormalizationGrad(
            self.eps, self.expander, self.axis, self.inv_std, self.inv_var)
        gx, ggamma, gbeta, gmean, gvar = f.apply((x, gamma, mean, var, gy))

        # Unscale all parameter gradients by the loss scale carried on gy,
        # then pass that scale along through gx.
        prev_scale = self.ada_loss.get_prev_scale(gy)
        ggamma_ = self.ada_loss.get_unscaled_gradient(ggamma,
                                                      prev_scale,
                                                      dtype=ggamma.dtype)
        gbeta_ = self.ada_loss.get_unscaled_gradient(gbeta,
                                                     prev_scale,
                                                     dtype=gbeta.dtype)
        gmean_ = self.ada_loss.get_unscaled_gradient(gmean,
                                                     prev_scale,
                                                     dtype=gmean.dtype)
        gvar_ = self.ada_loss.get_unscaled_gradient(gvar,
                                                    prev_scale,
                                                    dtype=gvar.dtype)

        self.ada_loss.set_loss_scale(gx, prev_scale)  # pass along

        return gx, ggamma_, gbeta_, gmean_, gvar_
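# --- Usage sketch (not part of the original source) -------------------------
# AdaLossFixedBatchNormalization is a Chainer FunctionNode, so it is applied
# the same way F.fixed_batch_normalization applies the wrapped FunctionNode.
# The helper name below is hypothetical.
def ada_loss_fixed_batch_normalization(x, gamma, beta, mean, var,
                                        eps=2e-5, axis=None,
                                        ada_loss_cfg=None):
    """Hypothetical drop-in for F.fixed_batch_normalization (inference)."""
    fn = AdaLossFixedBatchNormalization(eps=eps, axis=axis,
                                        ada_loss_cfg=ada_loss_cfg)
    y, = fn.apply((x, gamma, beta, mean, var))
    return y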
class AdaLossBatchNormalization(batch_normalization.BatchNormalization):
    """ Wrapping the training-mode batch normalization function """

    def __init__(self,
                 eps=2e-5,
                 mean=None,
                 var=None,
                 decay=0.9,
                 axis=None,
                 ada_loss_cfg=None):
        super().__init__(eps=eps, mean=mean, var=var, decay=decay, axis=axis)

        if ada_loss_cfg is None:
            ada_loss_cfg = {}
        self.ada_loss = AdaLossChainer(**ada_loss_cfg)

    def backward(self, indexes, grad_outputs):
        x, gamma = self.get_retained_inputs()
        gy, = grad_outputs

        if self.use_ideep:
            assert self.var is not None
            var = self.var
        else:
            var = None

        f = batch_normalization.BatchNormalizationGrad(
            self.eps, self.use_cudnn, self.mode, self.expander, self.axis,
            self.mean, var, self.inv_std, self.key_axis)
        gx, ggamma, gbeta = f.apply((x, gamma, gy))

        # update the loss scale
        prev_scale = self.ada_loss.get_prev_scale(gy)

        # NOTE: the numerical stability here?
        # NOTE: efficiency?
        ggamma_ = self.ada_loss.get_unscaled_gradient(ggamma,
                                                      prev_scale,
                                                      dtype=ggamma.dtype)
        gbeta_ = self.ada_loss.get_unscaled_gradient(gbeta,
                                                     prev_scale,
                                                     dtype=gbeta.dtype)

        self.ada_loss.set_loss_scale(gx, prev_scale)  # pass along

        return gx, ggamma_, gbeta_
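# --- Usage sketch (not part of the original source) -------------------------
# Likewise, AdaLossBatchNormalization mirrors the FunctionNode behind
# F.batch_normalization (training mode); the helper name is hypothetical.
def ada_loss_batch_normalization(x, gamma, beta, eps=2e-5,
                                 running_mean=None, running_var=None,
                                 decay=0.9, axis=None, ada_loss_cfg=None):
    """Hypothetical drop-in for F.batch_normalization (training)."""
    fn = AdaLossBatchNormalization(eps=eps, mean=running_mean,
                                   var=running_var, decay=decay, axis=axis,
                                   ada_loss_cfg=ada_loss_cfg)
    y, = fn.apply((x, gamma, beta))
    return y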