def _gradients(*args, **kwargs):
  """Compute gradients using the API that matches the active TF behavior.

  When `tf2.enabled()` is true, the call is forwarded to
  `gradients_impl.gradients_v2`; otherwise it is forwarded to
  `gradients_impl.gradients`. All positional and keyword arguments are
  passed through unchanged, and the selected function's result is returned
  as-is.
  """
  if tf2.enabled():
    grad_fn = gradients_impl.gradients_v2
  else:
    grad_fn = gradients_impl.gradients
  return grad_fn(*args, **kwargs)
def step(c):
  """Return the gradient w.r.t. `x` of `all_gather_fn([x, y]) * c`.

  Two fixed 2x1 constant tensors, `x` and `y`, are handed as a list to
  `all_gather_fn` (a callable taken from the enclosing scope). The gathered
  result is multiplied by `c`, and the gradient of that product with respect
  to `x` is computed with `gradients_impl.gradients_v2`.

  Args:
    c: Multiplier applied to the gathered tensor. It is presumably a scalar or
      a tensor broadcastable against the gathered result; TODO confirm
      against callers.

  Returns:
    The gradient tensor for `x`, i.e. the first (and only) element of the list
    returned by `gradients_v2`.
  """
  x = constant_op.constant([[3.], [5.]])
  y = constant_op.constant([[2.], [4.]])
  gathered = all_gather_fn([x, y])
  # Use a fresh name instead of re-binding `y`, so the input constant and the
  # differentiated output stay distinguishable.
  scaled = gathered * c
  grads = gradients_impl.gradients_v2(scaled, [x])
  return grads[0]