Code example #1
0
File: test_parameter.py, Project: tomzhang/mindspore
def test_parameter_lazy_init():
    """Check lazy initialization of a Parameter built from an initializer.

    Two scenarios are exercised:

    1. Calling ``init_data()`` while ``default_input`` is still the lazy
       initializer turns it into a ones-filled Tensor.
    2. After ``default_input`` has been replaced by a concrete Tensor, a later
       ``init_data()`` call leaves that Tensor's values untouched.
    """
    shape = (1, 2, 3)

    # Scenario 1: init_data() on a parameter whose default_input was never set.
    lazy = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
    assert not isinstance(lazy.default_input, Tensor)
    lazy.init_data()
    assert isinstance(lazy.default_input, Tensor)
    assert np.array_equal(lazy.default_input.asnumpy(), np.ones(shape))

    # Scenario 2: init_data() after default_input was assigned explicitly.
    preset = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
    assert not isinstance(preset.default_input, Tensor)
    preset.default_input = Tensor(np.zeros(shape))
    expected = np.zeros(shape)
    assert np.array_equal(preset.default_input.asnumpy(), expected)
    preset.init_data()  # must not overwrite the explicitly assigned value
    assert np.array_equal(preset.default_input.asnumpy(), expected)
Code example #2
0
def test_parameter_lazy_init():
    """Check lazy Parameter initialization in SEMI_AUTO_PARALLEL mode.

    Covers:

    * ``init_data()`` turning an initializer-backed parameter into a Tensor;
    * assignment to ``default_input`` raising ``TypeError`` before the
      parameter is initialized and for a mismatched dtype, and ``ValueError``
      for a mismatched shape;
    * valid reassignment from a Tensor and from an initializer;
    * ``set_parameter_data(..., slice_shape=True)`` accepting a different shape.

    The auto-parallel context is global state. It is therefore reset in a
    ``finally`` block, so a failing assertion cannot leak the semi-auto-parallel
    configuration into tests that run afterwards.
    """
    _set_has_initializer(False)
    context.reset_auto_parallel_context()
    try:
        # support lazy init in SEMI_AUTO_PARALLEL mode
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel",
                                          device_num=8)
        # Call init_data() without set default_input.
        para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
        assert not isinstance(para.default_input, Tensor)
        para = para.init_data()
        assert isinstance(para.default_input, Tensor)
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))

        # Assigning default_input before init_data() must fail; after init_data()
        # only a value with matching dtype and shape is accepted.
        para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
        assert not isinstance(para.default_input, Tensor)
        # expect type error when not init
        with pytest.raises(TypeError):
            para.default_input = Tensor(np.zeros((1, 2, 3)))
        # init then assign
        para = para.init_data()
        # check the type: np.zeros defaults to float64, the parameter is float32
        with pytest.raises(TypeError):
            para.default_input = Tensor(np.zeros((1, 2, 3)))
        # check the shape
        with pytest.raises(ValueError):
            para.default_input = Tensor(np.zeros((1, 2)))
        # expect change ok
        para.default_input = Tensor(np.zeros((1, 2, 3)).astype(np.float32))
        assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))
        para.default_input = initializer('ones', [1, 2, 3], mstype.float32)
        assert isinstance(para.default_input, Tensor)
        # same object and has inited
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
        # expect no effect: the return value is discarded on purpose and `para`
        # must keep its current data.
        para.init_data()
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
        # slice_shape=True lets the data take a shape different from the declared one.
        para.set_parameter_data(Tensor(np.zeros((1, 2)).astype(np.float32)),
                                slice_shape=True)
        assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2)))
        para.set_parameter_data(initializer('ones', [1, 2], mstype.float32),
                                slice_shape=True)
        assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2)))
    finally:
        # Always restore the global parallel context, even if an assertion failed.
        context.reset_auto_parallel_context()
Code example #3
0
 class Net(nn.Cell):
     """Test cell that wraps a Parameter, initializes it, then overwrites its data.

     Args:
         initial: value used to build the ``weight`` Parameter (presumably an
             initializer or a Tensor -- TODO confirm against callers).
         updated: value written into the initialized parameter via ``set_data``.
     """
     def __init__(self, initial, updated):
         super().__init__()
         self.initial = initial
         self.updated = updated
         self.p = Parameter(self.initial, name="weight")
         # init_data()'s return value is stored separately from self.p; whether it
         # is the same object as self.p depends on the Parameter implementation,
         # which is not shown here.
         self.new_p = self.p.init_data()
         # Overwrite the freshly initialized data with `updated`.
         self.new_p.set_data(self.updated)

     def construct(self):
         """Return the initialized-then-updated parameter (``self.new_p``)."""
         return self.new_p