def test_rsqrt():
    """Check P.Rsqrt against NumPy's ``1 / sqrt(x)`` on random float32 input.

    The input comes from a local, seeded ``RandomState``, so it is
    reproducible without resetting the global NumPy RNG that other tests
    in the same process may draw from.
    """
    # Same stream as ``np.random.seed(0); np.random.rand(...)``, but without
    # mutating the process-global RNG state.
    rng = np.random.RandomState(0)
    x_np = rng.rand(2, 3, 4, 4).astype(np.float32)
    output_ms = P.Rsqrt()(Tensor(x_np))
    output_np = 1 / np.sqrt(x_np)
    assert np.allclose(output_ms.asnumpy(), output_np)
def test_sqrt():
    """Check P.Sqrt and P.Rsqrt against NumPy in PyNative mode on GPU.

    The input is drawn from a locally seeded ``RandomState``, so any failure
    is reproducible regardless of the global NumPy RNG state.
    """
    x_np = np.random.RandomState(0).rand(2, 3, 4, 4).astype(np.float32)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    # sqrt(x)
    output_ms = P.Sqrt()(Tensor(x_np))
    output_np = np.sqrt(x_np)
    assert np.allclose(output_ms.asnumpy(), output_np)

    # 1 / sqrt(x), on the same input
    output_ms = P.Rsqrt()(Tensor(x_np))
    output_np = 1 / np.sqrt(x_np)
    assert np.allclose(output_ms.asnumpy(), output_np)
def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter):
    """
    Apply one LAMB-style update to ``param`` and its moment buffers ``m``/``v``.

    Computes bias-corrected Adam-style first/second moment estimates, scales the
    resulting step by a trust ratio ``||param|| / ||adam_step + weight_decay * param||``
    clipped to [0, 10], and writes the new param, m and v back via ``F.assign``.

    Args:
        beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0).
        beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0).
        eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0.
        global_step (Tensor): Global step; bias correction uses ``global_step + 1`` as the exponent.
        lr (Tensor): Learning rate.
        weight_decay (Number): Weight decay. Should be equal to or greater than 0.
        param (Tensor): Parameters. Overwritten in place with the updated value.
        m (Tensor): m value (1st moment) of parameters. Overwritten in place.
        v (Tensor): v value (2nd moment) of parameters. Overwritten in place.
        gradient (Tensor): Gradient of parameters.
        decay_flag (bool): Whether ``weight_decay * param`` is added to the applied update.
        optim_filter (bool): If False, no update is performed and ``gradient`` is returned as-is.

    Returns:
        Tensor, the updated parameter value cast back to ``param``'s dtype when
        ``optim_filter`` is True; otherwise ``gradient`` unchanged.
    """
    if optim_filter:
        op_mul = P.Mul()
        op_sqrt = P.Sqrt()
        op_rsqrt = P.Rsqrt()
        op_square = P.Square()
        op_cast = P.Cast()
        op_reshape = P.Reshape()
        op_shape = P.Shape()
        op_pow = P.Pow()
        # NOTE(review): presumably an L2 norm over the whole tensor (making the
        # trust ratio per-parameter) — confirm against layer.Norm's defaults.
        op_norm = layer.Norm()
        op_select = P.Select()
        op_greater = P.Greater()
        op_fill = P.Fill()
        op_dtype = P.DType()

        # Do all update math in float32; results are cast back to each
        # buffer's own dtype when assigned/returned below.
        param_fp32 = op_cast(param, mstype.float32)
        m_fp32 = op_cast(m, mstype.float32)
        v_fp32 = op_cast(v, mstype.float32)
        gradient_fp32 = op_cast(gradient, mstype.float32)

        # num_one is a module-level constant defined outside this view
        # (presumably 1) — confirm.
        # Exponential moving averages of the gradient and squared gradient.
        next_m = op_mul(beta1, m_fp32) + op_mul(
            op_cast(num_one, mstype.float32) - beta1, gradient_fp32)

        next_v = op_mul(beta2, v_fp32) + op_mul(
            op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))

        # Bias correction: divide by (1 - beta ** (global_step + 1)).
        next_mm = next_m / (op_cast(num_one, mstype.float32) - op_pow(
            beta1, op_cast(global_step + num_one, mstype.float32)))
        next_vv = next_v / (op_cast(num_one, mstype.float32) - op_pow(
            beta2, op_cast(global_step + num_one, mstype.float32)))
        w_norm = op_norm(param_fp32)
        g_norm = op_norm(gradient_fp32)

        # Norm of the step used as the trust-ratio denominator.
        # NOTE(review): this uses next_mm * rsqrt(next_vv + eps) and always adds
        # weight_decay * param, whereas the applied `update` below uses
        # next_mm / (sqrt(next_vv) + eps) and adds the decay term only when
        # decay_flag is True — confirm the mismatch is intended.
        g_norm_hat = op_norm(
            op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay * param_fp32)
        zeros = F.zeros_like(w_norm)
        ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
        # trust_ratio = w_norm / g_norm_hat when w_norm > 0 and g_norm > 0,
        # otherwise 1. Both branches are evaluated before Select picks one.
        # NOTE(review): the guard checks g_norm (raw gradient), not g_norm_hat.
        trust_ratio = op_select(
            op_greater(w_norm, zeros),
            op_select(op_greater(g_norm, zeros), w_norm / g_norm_hat, ones),
            ones)
        # Clamp the trust ratio to [0, 10].
        tens = op_fill(op_dtype(trust_ratio), op_shape(trust_ratio), 10.0)
        trust_ratio = C.clip_by_value(trust_ratio, zeros, tens)
        update = next_mm / (op_sqrt(next_vv) + eps)

        if decay_flag:
            update = update + op_mul(weight_decay, param_fp32)

        update_with_lr = op_mul(op_mul(trust_ratio, lr), update)

        # Reshape the scaled step to param's shape before subtracting.
        next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

        # Write param, m and v back in their original dtypes. F.depend ties each
        # assign to the returned value (assumes standard MindSpore depend
        # semantics, i.e. the assigns are kept in the graph).
        next_param = F.depend(
            next_param, F.assign(param, op_cast(next_param, F.dtype(param))))
        next_param = F.depend(next_param, F.assign(m, op_cast(next_m, F.dtype(m))))
        next_param = F.depend(next_param, F.assign(v, op_cast(next_v, F.dtype(v))))

        return op_cast(next_param, F.dtype(param))
    return gradient
'desc_bprop': [[2, 3]]}), ('Rank', { 'block': P.Rank(), 'desc_inputs': [[2, 3]], 'skip': ['backward']}), ('InvertPermutation', { 'block': P.InvertPermutation(), 'desc_const': [(0, 3, 1, 2)], 'desc_inputs': [], 'skip': ['backward']}), ('Square', { 'block': P.Square(), 'desc_inputs': [[4]], 'desc_bprop': [[4]]}), ('Rsqrt', { 'block': P.Rsqrt(), 'desc_inputs': [[4]], 'desc_bprop': [[4]]}), ('Sqrt', { 'block': P.Sqrt(), 'desc_inputs': [[4]], 'desc_bprop': [[4]]}), ('RealDiv', { 'block': P.RealDiv(), 'desc_inputs': [[4, 5], [2, 3, 4, 5]], 'desc_bprop': [[2, 3, 4, 5]]}), ('Div', { 'block': P.Div(), 'desc_inputs': [[4, 5], [2, 3, 4, 5]], 'desc_bprop': [[2, 3, 4, 5]]}), ('Equal', {
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P

# Primitive instances shared at module level (consumers are outside this view).
add = P.TensorAdd()
mul = P.Mul()
real_div = P.RealDiv()
rsqrt = P.Rsqrt()
sqrt = P.Sqrt()
make_tuple = Primitive('make_tuple')
tuple_getitem = Primitive('tuple_getitem')
LambNextMVWithDecayV1 = Primitive('LambNextMVWithDecayV1')


class FnDict:
    """Registry of functions keyed by ``__name__``, usable as a decorator.

    Decorating a function with an instance records it under its name; it can
    later be fetched with ``registry[name]``.
    """

    def __init__(self):
        # Maps function name -> function object.
        self.fnDict = {}

    def __call__(self, fn):
        """Register *fn* under ``fn.__name__`` and return it unchanged.

        Returning *fn* keeps the decorated name bound to the function rather
        than ``None``, so it remains callable after decoration.
        """
        self.fnDict[fn.__name__] = fn
        return fn

    def __getitem__(self, name):
        """Return the function registered as *name*; raises KeyError if absent."""
        return self.fnDict[name]
def __init__(self, strategy1, strategy2):
    """Create two MatMul primitives and one Rsqrt primitive with strategies.

    Args:
        strategy1: Passed to ``set_strategy`` on both ``self.matmul`` and
            ``self.matmul2``.
        strategy2: Passed to ``set_strategy`` on ``self.rsqrt``.
    """
    super().__init__()
    # Each primitive is created, configured and attached one at a time,
    # in a fixed order: matmul, rsqrt, matmul2.
    first_matmul = P.MatMul()
    self.matmul = first_matmul.set_strategy(strategy1)
    reciprocal_root = P.Rsqrt()
    self.rsqrt = reciprocal_root.set_strategy(strategy2)
    second_matmul = P.MatMul()
    self.matmul2 = second_matmul.set_strategy(strategy1)
# input two tensors, their shapes do not match ('Mul2', { 'block': (P.Mul(), {'exception': ValueError, 'error_keywords': ['Mul']}), 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], 'skip': ['backward']}), # input is Tensor(bool) ('Square1', { 'block': (P.Square(), {'exception': TypeError, 'error_keywords': ['Square']}), 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], 'skip': ['backward']}), # input is Tensor(bool) ('Rsqrt1', { 'block': (P.Rsqrt(), {'exception': TypeError, 'error_keywords': ['Rsqrt']}), 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], 'skip': ['backward']}), # input is Tensor(bool) ('Sqrt1', { 'block': (P.Sqrt(), {'exception': TypeError, 'error_keywords': ['Sqrt']}), 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], 'skip': ['backward']}), # input is not Tensor ('Reciprocal1', { 'block': (P.Reciprocal(), {'exception': TypeError, 'error_keywords': ['Reciprocal']}),