Example #1
0
def _sharding_grad(op, grad):
    """Gradient function for the XlaSharding op.

    Re-applies the forward op's "sharding" attribute to the incoming gradient
    by wrapping it in a new XlaSharding op.

    Args:
      op: The forward XlaSharding operation; only its "sharding" attr is read.
      grad: The gradient flowing back into the op's output.

    Returns:
      A one-element list containing the sharding-annotated gradient.
    """
    attr = op.get_attr("sharding")
    sharded = gen_xla_ops.xla_sharding(grad, sharding=attr)
    # Also record the same sharding value on the new op's private
    # "_XlaSharding" attribute, stored as AttrValue.s (presumably a
    # serialized sharding proto -- TODO confirm).
    attr_value = attr_value_pb2.AttrValue(s=attr)
    # pylint: disable=protected-access
    sharded.op._set_attr("_XlaSharding", attr_value)
    return [sharded]
Example #2
0
def _sharding_grad(op, grad):
  """Gradient for XlaSharding op.

  The gradient is wrapped in a fresh XlaSharding op that reuses the forward
  op's "sharding" and "unspecified_dims" attributes.

  Args:
    op: The forward XlaSharding operation whose attributes are copied.
    grad: The gradient flowing back into the op's output.

  Returns:
    A one-element list containing the sharding-annotated gradient.
  """
  sharding = op.get_attr("sharding")
  unspecified = op.get_attr("unspecified_dims")
  sharded = gen_xla_ops.xla_sharding(
      grad, sharding=sharding, unspecified_dims=unspecified)
  # Mirror the sharding value onto the new op's private "_XlaSharding"
  # attribute as AttrValue.s (presumably serialized sharding bytes -- TODO
  # confirm).
  attr_value = attr_value_pb2.AttrValue(s=sharding)
  # pylint: disable=protected-access
  sharded.op._set_attr("_XlaSharding", attr_value)
  return [sharded]