예제 #1
0
    def bprop(x, out, dout):
        """Compute the gradient for ``x`` from the incoming gradient ``dout``.

        A ``dout`` whose type is a tensor is passed through ``all_reduce``.
        Any other ``dout`` is handled as a row tensor: its ``indices`` and
        ``values`` each go through ``all_gather`` and are rewrapped in a
        ``RowTensor`` that keeps the original ``dense_shape``. When the
        enclosing ``mean_flag`` is set, the gradient values are additionally
        multiplied by ``1 / dev_num``.

        Returns:
            A one-element tuple containing the gradient for ``x``.
        """
        if F.issubclass_(F.typeof(dout), mstype.tensor):
            dx = all_reduce(dout)
            if mean_flag:
                # Build the reciprocal of dev_num in the gradient's own dtype.
                reciprocal = (F.scalar_cast(1.0, F.dtype(dx)) /
                              F.scalar_cast(dev_num, F.dtype(dx)))
                dx = mul(dx, cast(F.scalar_to_array(reciprocal), F.dtype(dx)))
        else:
            gathered_indices = all_gather(dout.indices)
            gathered_values = all_gather(dout.values)
            if mean_flag:
                # Only the values are scaled; the indices are left untouched.
                reciprocal = (F.scalar_cast(1.0, F.dtype(gathered_values)) /
                              F.scalar_cast(dev_num, F.dtype(gathered_values)))
                gathered_values = mul(
                    gathered_values,
                    cast(F.scalar_to_array(reciprocal),
                         F.dtype(gathered_values)))
            dx = RowTensor(gathered_indices, gathered_values, dout.dense_shape)

        return (dx, )
예제 #2
0
    def bprop(x, z, out, dout):
        """Compute gradients for ``x`` and ``z`` from the incoming ``dout``.

        For a tensor ``dout``: when the enclosing ``do_mirror`` is set, ``dout``
        is added into ``z`` via ``F.assign_add`` and ``all_reduce`` is applied
        to ``z``; otherwise ``dout`` is used unchanged. When the enclosing
        ``mean_flag`` is set, the result is multiplied by ``1 / dev_num``.
        For any other ``dout`` the gradient for ``x`` is ``zeros_like(x)``.

        Returns:
            A tuple ``(dx, zeros_like(z))``.
        """
        if F.issubclass_(F.typeof(dout), mstype.tensor):
            if do_mirror:
                # F.depend ties the rebound ``z`` to the assign_add call, so
                # the all_reduce below consumes the updated ``z``.
                z = F.depend(z, F.assign_add(z, dout))
                dx = all_reduce(z)
            else:
                dx = dout
            if mean_flag:
                reciprocal = (F.scalar_cast(1.0, F.dtype(dx)) /
                              F.scalar_cast(dev_num, F.dtype(dx)))
                dx = mul(dx, cast(F.scalar_to_array(reciprocal), F.dtype(dx)))
        else:
            # Gradient accumulation does not support row tensors yet.
            dx = zeros_like(x)

        return (dx, zeros_like(z))
예제 #3
0
    def bprop(x, out, dout):
        """Apply ``op`` with ``divisor`` to the incoming gradient ``dout``.

        A tensor ``dout`` with a bool dtype is returned unchanged. Any other
        tensor ``dout`` is combined with ``divisor`` (cast to the dtype of
        ``dout``) through ``op``. A non-tensor ``dout`` is indexed as a tuple
        and every element is processed the same way.

        Returns:
            A one-element tuple holding either the processed tensor or the
            tuple of processed elements.
        """
        if F.issubclass_(F.typeof(dout), mstype.tensor):
            if F.issubclass_(F.dtype(dout), mstype.bool_):
                return (dout, )
            factor = cast(F.scalar_to_array(divisor), dtype(dout))
            return (op(dout, factor), )

        # Tuple input: handle each element on its own, keeping the order.
        scaled = ()
        for idx in range(F.tuple_len(dout)):
            factor = cast(F.scalar_to_array(divisor), dtype(dout[idx]))
            scaled = scaled + (op(dout[idx], factor), )
        return (scaled, )
예제 #4
0
 def bprop(x, out, dout):
     """Compute the gradient for ``x`` from the incoming gradient ``dout``.

     A tensor ``dout`` is passed through ``all_reduce_grad``. Any other
     ``dout`` is handled as a row tensor: its ``indices`` and ``values`` go
     through ``all_gather`` and are rewrapped in a ``RowTensor`` with the
     original ``dense_shape``.

     Returns:
         A one-element tuple containing the gradient for ``x``.
     """
     if F.issubclass_(F.typeof(dout), mstype.tensor):
         return (all_reduce_grad(dout), )

     gathered_indices = all_gather(dout.indices)
     gathered_values = all_gather(dout.values)
     sparse_grad = RowTensor(gathered_indices, gathered_values,
                             dout.dense_shape)
     return (sparse_grad, )
예제 #5
0
 def bprop(x, out, dout):
     """Compute the gradient for ``x`` from the incoming gradient ``dout``.

     A tensor ``dout`` is passed through ``all_reduce_grad``. Any other
     ``dout`` is indexed as a 3-item sequence: items 0 and 1 go through
     ``all_gather`` and item 2 is forwarded unchanged, giving a new 3-tuple.
     NOTE(review): the items are presumably (indices, values, dense_shape);
     confirm against the caller.

     Returns:
         A one-element tuple containing the gradient for ``x``.
     """
     if F.issubclass_(F.typeof(dout), mstype.tensor):
         return (all_reduce_grad(dout),)

     gathered_indices = all_gather(dout[0])
     gathered_values = all_gather(dout[1])
     return ((gathered_indices, gathered_values, dout[2]),)
예제 #6
0
    def bprop(x, y, z, out, dout):
        """Compute gradients for ``x``, ``y`` and ``z`` from ``dout``.

        The communication step only runs when ``y`` equals the enclosing
        ``grad_accumulation_step`` (the comparison result is reshaped to
        shape ``()`` before being used as the branch condition).

        * Tensor ``dout``: on a communication step ``all_reduce`` is applied to
          ``z + dout`` and ``z`` is subtracted from the result; otherwise
          ``dout`` is used unchanged.
        * Any other ``dout`` is handled as a row tensor: on a communication
          step its ``indices`` and ``values`` go through ``all_gather``;
          otherwise they are used unchanged. The result is a ``RowTensor``
          with the original ``dense_shape``.

        When the enclosing ``mean_flag`` is set, the gradient values are
        multiplied by ``1 / dev_num`` in both cases, whether or not the
        communication step ran.

        Returns:
            A tuple ``(dx, zeros_like(y), zeros_like(z))``.
        """
        should_sync = reshape(equal(y, grad_accumulation_step), ())

        if F.issubclass_(F.typeof(dout), mstype.tensor):
            if should_sync:
                dx = all_reduce(z + dout) - z
            else:
                dx = dout
            if mean_flag:
                reciprocal = (F.scalar_cast(1.0, F.dtype(dx)) /
                              F.scalar_cast(dev_num, F.dtype(dx)))
                dx = mul(dx, cast(F.scalar_to_array(reciprocal), F.dtype(dx)))
        else:
            if should_sync:
                row_indices = all_gather(dout.indices)
                row_values = all_gather(dout.values)
            else:
                row_indices = dout.indices
                row_values = dout.values
            if mean_flag:
                # Only the values are scaled; the indices are left untouched.
                reciprocal = (F.scalar_cast(1.0, F.dtype(row_values)) /
                              F.scalar_cast(dev_num, F.dtype(row_values)))
                row_values = mul(
                    row_values,
                    cast(F.scalar_to_array(reciprocal), F.dtype(row_values)))
            dx = RowTensor(row_indices, row_values, dout.dense_shape)

        return (dx, zeros_like(y), zeros_like(z))
예제 #7
0
 def bprop(x, out, dout):
     """Compute the gradient for ``x``, masked by where ``x`` equals ``out``.

     A tensor ``dout`` is passed through ``all_reduce_grad`` and multiplied by
     ``equal(x, out)`` cast to the gradient dtype. Any other ``dout`` is
     handled as a row tensor: its ``indices`` and ``values`` go through
     ``all_gather``, the gathered values are multiplied by the same kind of
     mask, and the result is rewrapped in a ``RowTensor`` with the original
     ``dense_shape``.

     Returns:
         A one-element tuple containing the gradient for ``x``.
     """
     if F.issubclass_(F.typeof(dout), mstype.tensor):
         reduced = all_reduce_grad(dout)
         mask = cast(equal(x, out), dtype(reduced))
         dx = mul(reduced, mask)
     else:
         gathered_indices = all_gather(dout.indices)
         gathered_values = all_gather(dout.values)
         mask = cast(equal(x, out), dtype(gathered_values))
         dx = RowTensor(gathered_indices, mul(gathered_values, mask),
                        dout.dense_shape)
     return (dx, )
예제 #8
0
 def bprop(x, out, dout):
     """Compute the gradient for ``x``, masked by where ``x`` equals ``out``.

     A tensor ``dout`` is passed through ``all_reduce_grad`` and multiplied by
     ``equal(x, out)`` cast to the gradient dtype. Any other ``dout`` is
     indexed as a 3-item sequence: items 0 and 1 go through ``all_gather``,
     the gathered item 1 is multiplied by the same kind of mask, and item 2 is
     forwarded unchanged, giving a new 3-tuple.
     NOTE(review): the items are presumably (indices, values, dense_shape);
     confirm against the caller.

     Returns:
         A one-element tuple containing the gradient for ``x``.
     """
     if F.issubclass_(F.typeof(dout), mstype.tensor):
         reduced = all_reduce_grad(dout)
         mask = cast(equal(x, out), dtype(reduced))
         dx = mul(reduced, mask)
     else:
         gathered_indices = all_gather(dout[0])
         gathered_values = all_gather(dout[1])
         mask = cast(equal(x, out), dtype(gathered_values))
         dx = (gathered_indices, mul(gathered_values, mask), dout[2])
     return (dx,)
예제 #9
0
 def bprop(x, out, dout):
     """Apply ``op`` with ``divisor`` to the incoming gradient ``dout``.

     A ``dout`` with a bool dtype is returned unchanged; otherwise ``dout`` is
     combined through ``op`` with ``divisor`` cast to the dtype of ``dout``.

     Returns:
         A one-element tuple containing the gradient for ``x``.
     """
     if F.issubclass_(F.dtype(dout), mstype.bool_):
         dx = dout
     else:
         factor = cast(F.scalar_to_array(divisor), dtype(dout))
         dx = op(dout, factor)
     return (dx, )