def _sharding_grad(op, grad):
  """Gradient for XlaSharding op.

  Re-applies the forward op's sharding annotation to the incoming gradient,
  including the forward op's `unspecified_dims`, so the gradient is sharded
  the same way as the forward value.

  Args:
    op: The forward XlaSharding operation whose attrs are copied.
    grad: The gradient flowing back into the op's output.

  Returns:
    A single-element list holding the sharding-annotated gradient.
  """
  sharding_attr = op.get_attr("sharding")
  try:
    unspecified_dims = op.get_attr("unspecified_dims")
  except ValueError:
    # NOTE(review): ops created without this attr (e.g. older graphs) fall
    # back to an empty list, which is the attr's documented default.
    unspecified_dims = []
  grad_sharding = gen_xla_ops.xla_sharding(
      grad, sharding=sharding_attr, unspecified_dims=unspecified_dims)
  # Also mirror the sharding into the private `_XlaSharding` attr on the
  # produced op, as the forward path does.
  # pylint: disable=protected-access
  grad_sharding.op._set_attr("_XlaSharding",
                             attr_value_pb2.AttrValue(s=sharding_attr))
  # pylint: enable=protected-access
  return [grad_sharding]
def _sharding_grad(op, grad):
  """Gradient function for the XlaSharding op.

  The incoming gradient is wrapped in a new XlaSharding op that carries the
  forward op's `sharding` and `unspecified_dims` attrs, and the sharding is
  also recorded on the new op's `_XlaSharding` attr.

  Args:
    op: The forward XlaSharding operation.
    grad: The gradient with respect to the forward op's output.

  Returns:
    A list containing exactly one tensor: the sharded gradient.
  """
  forward_sharding = op.get_attr("sharding")
  forward_unspecified = op.get_attr("unspecified_dims")
  sharded = gen_xla_ops.xla_sharding(
      grad,
      sharding=forward_sharding,
      unspecified_dims=forward_unspecified,
  )
  # pylint: disable=protected-access
  sharded.op._set_attr(
      "_XlaSharding", attr_value_pb2.AttrValue(s=forward_sharding))
  # pylint: enable=protected-access
  return [sharded]