def bprop(x, out, dout):
    """Backward function that synchronizes the incoming gradient across devices.

    A tensor ``dout`` is passed through ``all_reduce``. Any other ``dout`` is
    read through its ``indices`` / ``values`` / ``dense_shape`` attributes:
    indices and values are each passed through ``all_gather`` and repacked
    into a ``RowTensor``. When the enclosing ``mean_flag`` is set, the
    gradient values are additionally multiplied by ``1 / dev_num``.

    Args:
        x: Forward input (unused here).
        out: Forward output (unused here).
        dout: Gradient flowing back into this op.

    Returns:
        A 1-tuple holding the gradient with respect to ``x``.
    """
    if F.issubclass_(F.typeof(dout), mstype.tensor):
        dx = all_reduce(dout)
        if mean_flag:
            # Build the 1/dev_num factor in the gradient's own dtype.
            grad_dtype = F.dtype(dx)
            one = F.scalar_cast(1.0, grad_dtype)
            dev_count = F.scalar_cast(dev_num, grad_dtype)
            scale = cast(F.scalar_to_array(one / dev_count), grad_dtype)
            dx = mul(dx, scale)
        return (dx,)

    gathered_indices = all_gather(dout.indices)
    gathered_values = all_gather(dout.values)
    if mean_flag:
        # Only the values are scaled; indices are left untouched.
        grad_dtype = F.dtype(gathered_values)
        one = F.scalar_cast(1.0, grad_dtype)
        dev_count = F.scalar_cast(dev_num, grad_dtype)
        scale = cast(F.scalar_to_array(one / dev_count), grad_dtype)
        gathered_values = mul(gathered_values, scale)
    return (RowTensor(gathered_indices, gathered_values, dout.dense_shape),)
def bprop(x, z, out, dout):
    """Backward function that accumulates ``dout`` into ``z`` before reducing.

    For a tensor ``dout`` with the enclosing ``do_mirror`` set, ``dout`` is
    added into ``z`` via ``F.assign_add`` and the resulting ``z`` is passed
    through ``all_reduce``; otherwise ``dout`` is forwarded unchanged. When
    the enclosing ``mean_flag`` is set, the tensor gradient is multiplied by
    ``1 / dev_num``. A non-tensor ``dout`` yields ``zeros_like(x)``.

    Args:
        x: Forward input.
        z: Accumulation buffer that receives ``dout`` when mirroring.
        out: Forward output (unused here).
        dout: Gradient flowing back into this op.

    Returns:
        A 2-tuple ``(dx, zeros_like(z))``.
    """
    if F.issubclass_(F.typeof(dout), mstype.tensor):
        if do_mirror:
            # NOTE(review): F.depend presumably orders the all-reduce after
            # the in-place add -- confirm against the MindSpore docs.
            z = F.depend(z, F.assign_add(z, dout))
            dx = all_reduce(z)
        else:
            dx = dout
        if mean_flag:
            # Build the 1/dev_num factor in the gradient's own dtype.
            grad_dtype = F.dtype(dx)
            one = F.scalar_cast(1.0, grad_dtype)
            dev_count = F.scalar_cast(dev_num, grad_dtype)
            scale = cast(F.scalar_to_array(one / dev_count), grad_dtype)
            dx = mul(dx, scale)
    else:
        # Gradient accumulation does not support row tensors yet.
        dx = zeros_like(x)
    return (dx, zeros_like(z))
def bprop(x, out, dout):
    """Backward function that applies ``op`` with ``divisor`` to the gradient.

    A tensor ``dout`` with bool dtype is returned unchanged; any other tensor
    ``dout`` is combined with ``divisor`` (cast to the gradient's dtype)
    through the enclosing ``op``. A non-tensor ``dout`` is treated as a tuple
    and every element is processed the same way, without the bool shortcut.

    Args:
        x: Forward input (unused here).
        out: Forward output (unused here).
        dout: Gradient tensor, or a tuple of gradient tensors.

    Returns:
        A 1-tuple holding the processed gradient (itself a tuple when
        ``dout`` is a tuple).
    """
    if F.issubclass_(F.typeof(dout), mstype.tensor):
        if F.issubclass_(F.dtype(dout), mstype.bool_):
            return (dout,)
        scale = cast(F.scalar_to_array(divisor), dtype(dout))
        return (op(dout, scale),)

    grads = ()
    for idx in range(F.tuple_len(dout)):
        item = dout[idx]
        # Each element gets a divisor cast to its own dtype.
        scale = cast(F.scalar_to_array(divisor), dtype(item))
        grads = grads + (op(item, scale),)
    return (grads,)
def bprop(x, out, dout):
    """Backward function that synchronizes the gradient across devices.

    A tensor ``dout`` is passed through ``all_reduce_grad``. Any other
    ``dout`` is read through its ``indices`` / ``values`` / ``dense_shape``
    attributes: indices and values are each passed through ``all_gather``
    and repacked into a ``RowTensor``.

    Args:
        x: Forward input (unused here).
        out: Forward output (unused here).
        dout: Gradient flowing back into this op.

    Returns:
        A 1-tuple holding the gradient with respect to ``x``.
    """
    if F.issubclass_(F.typeof(dout), mstype.tensor):
        return (all_reduce_grad(dout),)

    gathered_indices = all_gather(dout.indices)
    gathered_values = all_gather(dout.values)
    return (RowTensor(gathered_indices, gathered_values, dout.dense_shape),)
def bprop(x, out, dout):
    """Backward function that synchronizes the gradient across devices.

    A tensor ``dout`` is passed through ``all_reduce_grad``. Any other
    ``dout`` is indexed as a 3-element sequence: elements 0 and 1 are each
    passed through ``all_gather`` and element 2 is forwarded unchanged.
    NOTE(review): presumably (indices, values, dense_shape) -- confirm.

    Args:
        x: Forward input (unused here).
        out: Forward output (unused here).
        dout: Gradient flowing back into this op.

    Returns:
        A 1-tuple holding the gradient with respect to ``x``; for a
        non-tensor ``dout`` that gradient is a plain 3-tuple.
    """
    if F.issubclass_(F.typeof(dout), mstype.tensor):
        return (all_reduce_grad(dout),)

    gathered_first = all_gather(dout[0])
    gathered_second = all_gather(dout[1])
    return ((gathered_first, gathered_second, dout[2]),)
def bprop(x, y, z, out, dout):
    """Backward function that synchronizes only on the accumulation step.

    ``do_mirror`` is computed as ``equal(y, grad_accumulation_step)`` and
    reshaped to a scalar. For a tensor ``dout``: when ``do_mirror`` holds,
    ``z + dout`` is passed through ``all_reduce`` and ``z`` is subtracted
    from the result; otherwise ``dout`` is forwarded unchanged. For any
    other ``dout`` (read through ``indices`` / ``values`` / ``dense_shape``):
    indices and values are passed through ``all_gather`` when ``do_mirror``
    holds, otherwise used as-is, and repacked into a ``RowTensor``. When the
    enclosing ``mean_flag`` is set, the gradient values are multiplied by
    ``1 / dev_num`` in either case.

    Args:
        x: Forward input (unused here).
        y: Value compared against ``grad_accumulation_step``.
        z: Tensor added before and subtracted after the all-reduce.
        out: Forward output (unused here).
        dout: Gradient flowing back into this op.

    Returns:
        A 3-tuple ``(dx, zeros_like(y), zeros_like(z))``.
    """
    do_mirror = reshape(equal(y, grad_accumulation_step), (()))

    if F.issubclass_(F.typeof(dout), mstype.tensor):
        if do_mirror:
            summed = all_reduce(z + dout)
            dx = summed - z
        else:
            dx = dout
        if mean_flag:
            # Build the 1/dev_num factor in the gradient's own dtype.
            grad_dtype = F.dtype(dx)
            one = F.scalar_cast(1.0, grad_dtype)
            dev_count = F.scalar_cast(dev_num, grad_dtype)
            dx = mul(dx, cast(F.scalar_to_array(one / dev_count), grad_dtype))
    else:
        if do_mirror:
            row_indices = all_gather(dout.indices)
            row_values = all_gather(dout.values)
        else:
            row_indices = dout.indices
            row_values = dout.values
        if mean_flag:
            # Only the values are scaled; indices are left untouched.
            grad_dtype = F.dtype(row_values)
            one = F.scalar_cast(1.0, grad_dtype)
            dev_count = F.scalar_cast(dev_num, grad_dtype)
            row_values = mul(
                row_values,
                cast(F.scalar_to_array(one / dev_count), grad_dtype))
        dx = RowTensor(row_indices, row_values, dout.dense_shape)

    return (dx, zeros_like(y), zeros_like(z))
def bprop(x, out, dout):
    """Backward function that masks the synchronized gradient by ``x == out``.

    A tensor ``dout`` is passed through ``all_reduce_grad`` and multiplied by
    ``equal(x, out)`` cast to the gradient's dtype. Any other ``dout`` is
    read through its ``indices`` / ``values`` / ``dense_shape`` attributes:
    indices and values are each passed through ``all_gather``, the values
    are multiplied by the same mask, and the result is repacked into a
    ``RowTensor``.

    Args:
        x: Forward input.
        out: Forward output.
        dout: Gradient flowing back into this op.

    Returns:
        A 1-tuple holding the gradient with respect to ``x``.
    """
    if F.issubclass_(F.typeof(dout), mstype.tensor):
        reduced = all_reduce_grad(dout)
        mask = cast(equal(x, out), dtype(reduced))
        return (mul(reduced, mask),)

    gathered_indices = all_gather(dout.indices)
    gathered_values = all_gather(dout.values)
    mask = cast(equal(x, out), dtype(gathered_values))
    masked_values = mul(gathered_values, mask)
    return (RowTensor(gathered_indices, masked_values, dout.dense_shape),)
def bprop(x, out, dout):
    """Backward function that masks the synchronized gradient by ``x == out``.

    A tensor ``dout`` is passed through ``all_reduce_grad`` and multiplied by
    ``equal(x, out)`` cast to the gradient's dtype. Any other ``dout`` is
    indexed as a 3-element sequence: elements 0 and 1 are each passed through
    ``all_gather``, the gathered element 1 is multiplied by the same mask,
    and element 2 is forwarded unchanged.
    NOTE(review): presumably (indices, values, dense_shape) -- confirm.

    Args:
        x: Forward input.
        out: Forward output.
        dout: Gradient flowing back into this op.

    Returns:
        A 1-tuple holding the gradient with respect to ``x``; for a
        non-tensor ``dout`` that gradient is a plain 3-tuple.
    """
    if F.issubclass_(F.typeof(dout), mstype.tensor):
        reduced = all_reduce_grad(dout)
        mask = cast(equal(x, out), dtype(reduced))
        return (mul(reduced, mask),)

    gathered_first = all_gather(dout[0])
    gathered_second = all_gather(dout[1])
    mask = cast(equal(x, out), dtype(gathered_second))
    masked_second = mul(gathered_second, mask)
    return ((gathered_first, masked_second, dout[2]),)
def bprop(x, out, dout):
    """Backward function that applies ``op`` with ``divisor`` to the gradient.

    A bool-dtype ``dout`` is returned unchanged; otherwise ``dout`` is
    combined with ``divisor`` (cast to the gradient's dtype) through the
    enclosing ``op``.

    Args:
        x: Forward input (unused here).
        out: Forward output (unused here).
        dout: Gradient flowing back into this op.

    Returns:
        A 1-tuple holding the gradient with respect to ``x``.
    """
    if F.issubclass_(F.dtype(dout), mstype.bool_):
        return (dout,)

    scale = cast(F.scalar_to_array(divisor), dtype(dout))
    return (op(dout, scale),)