def test_parameter_lazy_init():
    """Exercise ``Parameter.init_data()`` on an initializer-backed parameter.

    Case 1: calling ``init_data()`` without assigning ``default_input`` first
    should turn the initializer into a Tensor of ones.
    Case 2: after ``default_input`` is assigned a Tensor of zeros, a later
    ``init_data()`` call is expected to leave those zeros untouched.

    NOTE(review): a later ``def test_parameter_lazy_init`` in this file
    rebinds this name, so pytest never collects this version. This version
    also ignores the value returned by ``init_data()``, whereas the later
    test and ``Net`` use it. Consider deleting this definition or
    reconciling it with the later one.
    """
    shape = (1, 2, 3)

    # Case 1: init_data() without setting default_input beforehand.
    first = Parameter(initializer('ones', list(shape), mstype.float32), 'test1')
    assert not isinstance(first.default_input, Tensor)
    first.init_data()
    assert isinstance(first.default_input, Tensor)
    assert np.array_equal(first.default_input.asnumpy(), np.ones(shape))

    # Case 2: init_data() after default_input has already been assigned.
    second = Parameter(initializer('ones', list(shape), mstype.float32), 'test2')
    assert not isinstance(second.default_input, Tensor)
    second.default_input = Tensor(np.zeros(shape))
    assert np.array_equal(second.default_input.asnumpy(), np.zeros(shape))
    second.init_data()  # expected to have no effect on the assigned zeros
    assert np.array_equal(second.default_input.asnumpy(), np.zeros(shape))
def test_parameter_lazy_init():
    """Check lazy initialization of an initializer-backed Parameter.

    Runs under SEMI_AUTO_PARALLEL with ``device_num=8``, the mode this test
    enables to get lazy init. Covers:

    * ``init_data()`` turning the initializer into a Tensor of ones;
    * ``default_input`` assignment raising TypeError before ``init_data()``
      and for a float64 Tensor after init, and raising ValueError for a
      shape mismatch;
    * assigning a matching float32 Tensor, or an initializer, after init;
    * a repeated ``init_data()`` leaving the current data unchanged;
    * ``set_parameter_data(..., slice_shape=True)`` accepting a new shape.

    The auto-parallel context is reset in a ``finally`` block, so a failing
    assertion cannot leak ``semi_auto_parallel`` into later tests.
    """
    _set_has_initializer(False)
    # support lazy init in SEMI_AUTO_PARALLEL mode
    context.reset_auto_parallel_context()
    try:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8)

        # Call init_data() without set default_input.
        para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
        assert not isinstance(para.default_input, Tensor)
        para = para.init_data()
        assert isinstance(para.default_input, Tensor)
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))

        # Call init_data() after default_input is set.
        para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
        assert not isinstance(para.default_input, Tensor)
        # expect type error when not init
        with pytest.raises(TypeError):
            para.default_input = Tensor(np.zeros((1, 2, 3)))
        # init then assign
        para = para.init_data()
        # check the type
        with pytest.raises(TypeError):
            para.default_input = Tensor(np.zeros((1, 2, 3)))
        # check the shape
        with pytest.raises(ValueError):
            para.default_input = Tensor(np.zeros((1, 2)))
        # expect change ok
        para.default_input = Tensor(np.zeros((1, 2, 3)).astype(np.float32))
        assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))
        para.default_input = initializer('ones', [1, 2, 3], mstype.float32)
        assert isinstance(para.default_input, Tensor)
        # same object and has inited
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
        # expect no effect.
        para.init_data()
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))

        # slice_shape=True lets the data take a shape other than (1, 2, 3).
        para.set_parameter_data(Tensor(np.zeros((1, 2)).astype(np.float32)), slice_shape=True)
        assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2)))
        para.set_parameter_data(initializer('ones', [1, 2], mstype.float32), slice_shape=True)
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2)))
    finally:
        # Always restore the global parallel context, even on failure.
        context.reset_auto_parallel_context()
class Net(nn.Cell):
    """Cell whose output is a Parameter built from ``initial`` and then updated with ``updated``.

    ``__init__`` wraps ``initial`` in a Parameter named "weight" and calls
    ``init_data()`` on it. It stores the object returned by ``init_data()``
    as ``new_p`` and calls ``set_data(updated)`` on that object.
    ``construct`` returns ``new_p``.
    """

    def __init__(self, initial, updated):
        """Create the "weight" Parameter and overwrite its data.

        Args:
            initial: data used to build the "weight" Parameter (presumably a
                Tensor or an initializer -- confirm against callers).
            updated: data passed to ``set_data`` on the Parameter returned by
                ``init_data()``.
        """
        super().__init__()
        self.initial = initial
        self.updated = updated
        self.p = Parameter(self.initial, name="weight")
        # Keep the object returned by init_data() rather than assuming it is
        # self.p; that returned object is what set_data() updates and what
        # construct() returns.
        self.new_p = self.p.init_data()
        self.new_p.set_data(self.updated)

    # Forward pass: return the Parameter obtained from init_data().
    def construct(self):
        return self.new_p