def backward(grad_output, input_, input_low, input_range, output, level_low, level_high, range_sign):
    """NumPy reference backward pass for asymmetric fake-quantization.

    Uses the straight-through estimator: inside the quantization range the
    input gradient passes through unchanged; outside it, the gradient is
    routed into the range parameters (``input_low`` / ``input_range``).

    Returns ``[grad_input, grad_low, grad_range]``, with the parameter
    gradients reduced via ``sum_like`` to the shapes of ``input_low`` and
    ``input_range`` respectively (handles broadcast parameters).

    NOTE(review): assumes ``output`` is the forward fake-quantized result of
    ``input_`` and ``range_sign == sign(input_range)`` — confirm at call site.
    """
    # Indicator masks: above the range, below it, and strictly inside.
    above = (input_ > (input_low + input_range)).astype(input_.dtype)
    below = (input_ < input_low).astype(input_.dtype)
    inside = 1 - above - below

    # Quantization error normalized by the (sign-corrected) range width.
    quant_err = (output - input_) * np.reciprocal(input_range * range_sign)

    # Range gradient: error term inside the range, clamp terms outside.
    range_term = quant_err * inside + range_sign * (level_low / level_high) * below + above
    grad_range = sum_like(grad_output * range_term, input_range)

    # Straight-through: input gradient flows only where input is in range.
    grad_input = grad_output * inside

    # Low-bound gradient collects contributions from both clamped sides.
    grad_low = sum_like(grad_output * (above + below), input_low)

    return [grad_input, grad_low, grad_range]
def backward(ctx, grad_output):
    """Backward pass for symmetric fake-quantization (torch autograd).

    Applies the straight-through estimator: the input gradient passes
    through only where the input lies inside the quantization range; the
    scale gradient combines the in-range quantization error with the clamp
    contributions from both sides.

    Returns ``(grad_input, grad_scale, None)`` — the trailing ``None``
    corresponds to a non-tensor forward argument (presumably the level
    bounds; confirm against the matching ``forward``).
    """
    input_, scale, output = ctx.saved_tensors
    level_high = ctx.level_high
    level_low = ctx.level_low
    # Ratio of the negative clamp bound to the positive one.
    alpha = float(level_low) / float(level_high)

    # Indicator masks: above the positive bound, below the negative bound,
    # and strictly inside the quantization range.
    above = (input_ > scale).type(input_.dtype)
    below = (input_ < scale * alpha).type(input_.dtype)
    inside = 1 - above - below

    # Quantization error relative to the scale.
    quant_err = (output - input_) * scale.reciprocal()

    # Scale gradient: error term inside, clamp terms (weighted by alpha on
    # the low side) outside; reduced to the shape of `scale`.
    grad_scale = sum_like(grad_output * (quant_err * inside + above + alpha * below), scale)

    # Straight-through estimator for the input gradient.
    grad_input = grad_output * inside

    return grad_input, grad_scale, None