Esempio n. 1
0
def test_rsqrt():
    """Check P.Rsqrt against the NumPy reference 1 / sqrt(x) on seeded random input."""
    np.random.seed(0)
    samples = np.random.rand(2, 3, 4, 4).astype(np.float32)

    # Reference value computed purely with NumPy.
    expected = 1 / np.sqrt(samples)

    # Run the operator under test and bring the result back to NumPy.
    rsqrt_op = P.Rsqrt()
    actual = rsqrt_op(Tensor(samples)).asnumpy()

    assert np.allclose(actual, expected)
Esempio n. 2
0
def test_sqrt():
    """Compare P.Sqrt and P.Rsqrt with their NumPy references.

    The operators are executed in PyNative mode with the GPU device target;
    both are fed the same float32 input of shape (2, 3, 4, 4).
    """
    # Seed the global NumPy RNG so that a failing input can be reproduced.
    np.random.seed(0)
    x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    # sqrt(x)
    output_ms = P.Sqrt()(Tensor(x_np))
    output_np = np.sqrt(x_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    # 1 / sqrt(x)
    output_ms = P.Rsqrt()(Tensor(x_np))
    output_np = 1 / np.sqrt(x_np)
    assert np.allclose(output_ms.asnumpy(), output_np)
Esempio n. 3
0
def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m,
                   v, gradient, decay_flag, optim_filter):
    """
    Update parameters.

    When `optim_filter` is true, computes the bias-corrected first and second moment estimates,
    scales the resulting update by a clipped trust ratio and the learning rate, and assigns the
    new values to `param`, `m` and `v` through `F.assign`. Otherwise nothing is updated.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        global_step (Tensor): Global step.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters.
        m (Tensor): m value of parameters.
        v (Tensor): v value of parameters.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Specifies whether param update with weight decay.
        optim_filter(bool): Applies parameter update or not.

    Returns:
        Tensor, the updated parameter cast to the dtype of `param` when `optim_filter` is true;
        otherwise `gradient` is returned unchanged.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_sqrt = P.Sqrt()
        op_rsqrt = P.Rsqrt()
        op_square = P.Square()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_pow = P.Pow()
        op_norm = layer.Norm()
        op_select = P.Select()
        op_greater = P.Greater()
        op_fill = P.Fill()
        op_dtype = P.DType()

        # All arithmetic below is carried out in float32, whatever the stored dtypes are.
        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        # `num_one` is defined outside this function (not visible here);
        # presumably the constant 1 -- TODO confirm.
        # next_m = beta1 * m + (1 - beta1) * g
        next_m = op_mul(beta1, m_fp32) + op_mul(
            op_cast(num_one, mstype.float32) - beta1, gradient_fp32)

        # next_v = beta2 * v + (1 - beta2) * g^2
        next_v = op_mul(beta2, v_fp32) + op_mul(
            op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))

        # Bias correction: divide by (1 - beta^(global_step + 1)).
        next_mm = next_m / (op_cast(num_one, mstype.float32) - op_pow(
            beta1, op_cast(global_step + num_one, mstype.float32)))
        next_vv = next_v / (op_cast(num_one, mstype.float32) - op_pow(
            beta2, op_cast(global_step + num_one, mstype.float32)))
        w_norm = op_norm(param_fp32)
        g_norm = op_norm(gradient_fp32)

        # NOTE(review): this norm uses rsqrt(next_vv + eps) and always adds the
        # weight-decay term, whereas `update` below uses sqrt(next_vv) + eps and
        # adds weight decay only when `decay_flag` is set -- confirm the
        # difference is intended.
        g_norm_hat = op_norm(
            op_mul(next_mm, op_rsqrt(next_vv + eps)) +
            weight_decay * param_fp32)
        zeros = F.zeros_like(w_norm)
        ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
        # trust_ratio = w_norm / g_norm_hat when both w_norm and g_norm are > 0, else 1.
        trust_ratio = op_select(
            op_greater(w_norm, zeros),
            op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
            ones)
        # Clip the trust ratio to the range [0, 10].
        tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
        trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
        update = next_mm / (op_sqrt(next_vv) + eps)

        if decay_flag:
            update = update + op_mul(weight_decay, param_fp32)

        update_with_lr = op_mul(op_mul(trust_ratio, lr), update)

        next_param = param_fp32 - op_reshape(update_with_lr,
                                             op_shape(param_fp32))

        # Attach the three assignments to `next_param` with F.depend; presumably
        # this ensures they are executed before the value is returned -- TODO confirm.
        # Each new value is cast back to the dtype of the tensor it overwrites.
        # Note that the uncorrected moments (next_m, next_v) are stored, not the
        # bias-corrected ones.
        next_param = F.depend(
            next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param,
                              F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param,
                              F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    # Update filtered out: no state is modified and the gradient is passed through.
    return gradient
Esempio n. 4
0
     'desc_bprop': [[2, 3]]}),
 ('Rank', {
     'block': P.Rank(),
     'desc_inputs': [[2, 3]],
     'skip': ['backward']}),
 ('InvertPermutation', {
     'block': P.InvertPermutation(),
     'desc_const': [(0, 3, 1, 2)],
     'desc_inputs': [],
     'skip': ['backward']}),
 ('Square', {
     'block': P.Square(),
     'desc_inputs': [[4]],
     'desc_bprop': [[4]]}),
 ('Rsqrt', {
     'block': P.Rsqrt(),
     'desc_inputs': [[4]],
     'desc_bprop': [[4]]}),
 ('Sqrt', {
     'block': P.Sqrt(),
     'desc_inputs': [[4]],
     'desc_bprop': [[4]]}),
 ('RealDiv', {
     'block': P.RealDiv(),
     'desc_inputs': [[4, 5], [2, 3, 4, 5]],
     'desc_bprop': [[2, 3, 4, 5]]}),
 ('Div', {
     'block': P.Div(),
     'desc_inputs': [[4, 5], [2, 3, 4, 5]],
     'desc_bprop': [[2, 3, 4, 5]]}),
 ('Equal', {
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Module-level operator instances created from the `operations` namespace.
add = P.TensorAdd()
mul = P.Mul()
real_div = P.RealDiv()
rsqrt = P.Rsqrt()
sqrt = P.Sqrt()
# Primitives constructed directly from their names.
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
LambNextMVWithDecayV1 = Primitive('LambNextMVWithDecayV1')


class FnDict:
    """Registry that maps a function's ``__name__`` to the function object.

    An instance is callable: calling it with a function registers that
    function, so it can be applied directly or used as a decorator.
    Registered functions are retrieved with ``registry[name]``.
    """

    def __init__(self):
        # Name -> function mapping. The attribute name is kept unchanged so
        # existing code that reads it keeps working.
        self.fnDict = {}

    def __call__(self, fn):
        """Register ``fn`` under ``fn.__name__`` and return it.

        Returning ``fn`` keeps the decorated name bound to the function when
        the instance is used as a decorator, instead of rebinding it to
        ``None``. A later registration with the same name replaces the
        earlier one.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered under ``name``.

        Raises:
            KeyError: If no function with that name has been registered.
        """
        return self.fnDict[name]
Esempio n. 6
0
 def __init__(self, strategy1, strategy2):
     """Create two MatMul ops configured with `strategy1` and an Rsqrt op with `strategy2`."""
     super().__init__()

     # First matrix multiplication.
     first_matmul = P.MatMul()
     self.matmul = first_matmul.set_strategy(strategy1)

     # Element-wise reciprocal square root.
     rsqrt_op = P.Rsqrt()
     self.rsqrt = rsqrt_op.set_strategy(strategy2)

     # Second matrix multiplication, same strategy as the first.
     second_matmul = P.MatMul()
     self.matmul2 = second_matmul.set_strategy(strategy1)
    # input two tensors, their shapes do not match
    ('Mul2', {
        'block': (P.Mul(), {'exception': ValueError, 'error_keywords': ['Mul']}),
        'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),

    # input is Tensor(bool)
    ('Square1', {
        'block': (P.Square(),
        {'exception': TypeError, 'error_keywords': ['Square']}),
        'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
        'skip': ['backward']}),

    # input is Tensor(bool)
    ('Rsqrt1', {
        'block': (P.Rsqrt(),
        {'exception': TypeError, 'error_keywords': ['Rsqrt']}),
        'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
        'skip': ['backward']}),

    # input is Tensor(bool)
    ('Sqrt1', {
        'block': (P.Sqrt(),
        {'exception': TypeError, 'error_keywords': ['Sqrt']}),
        'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))],
        'skip': ['backward']}),

    # input is not Tensor
    ('Reciprocal1', {
        'block': (P.Reciprocal(),
        {'exception': TypeError, 'error_keywords': ['Reciprocal']}),