コード例 #1
0
 def setUp(self):
     """Build the ``chainer.Link`` fixture used by the tests of this case.

     The link is created with two constructor-declared parameters: ``x``,
     with shape ``(2, numpy.int64(3))`` and dtype code ``'d'``, and ``u``,
     with shape ``None`` and dtype code ``'d'``. Two more parameters are
     registered inside ``init_scope``: ``y`` with shape ``(2,)`` and a
     shapeless ``v``. The float32 array ``self.p`` is registered on the link
     through ``add_persistent`` under the name ``'p'``, and the link is
     named ``'a'``.

     ``x`` receives its own ``chainer.UpdateRule`` with ``enabled`` set to
     False, and ``u`` receives a fresh ``chainer.UpdateRule``. When
     ``cuda.available`` is true, the current CUDA device id is stored on
     ``self.current_device_id``.
     """
     # The second extent is a numpy.int64 rather than a plain int --
     # presumably to exercise NumPy integer shapes; confirm against tests.
     x_shape = (2, numpy.int64(3))
     link = chainer.Link(x=(x_shape, 'd'), u=(None, 'd'))
     self.link = link

     with link.init_scope():
         link.y = chainer.Parameter(shape=(2, ))
         link.v = chainer.Parameter()

     self.p = numpy.array([1, 2, 3], dtype='f')
     link.add_persistent('p', self.p)
     link.name = 'a'

     frozen_rule = chainer.UpdateRule()
     link.x.update_rule = frozen_rule
     frozen_rule.enabled = False
     link.u.update_rule = chainer.UpdateRule()

     if cuda.available:
         self.current_device_id = cuda.cupy.cuda.get_device_id()
コード例 #2
0
ファイル: test_optimizer.py プロジェクト: zwcdp/chainer
    def test_uninitialized_parameter(self):
        """Run ``UpdateRule.update`` with fp32 update on an uninitialized param.

        Checks three things:

        * ``update_core`` is invoked exactly once.
        * For a float16 parameter, ``update_core`` receives a different
          object whose dtype is float32. For any other ``self.dtype``, it
          receives the parameter itself.
        * The original parameter stays uninitialized and keeps its dtype.
          When no dtype was given, reading it still raises ``RuntimeError``.
        """
        dtype = self.dtype

        def initializer(array):
            # Building an uninitialized parameter must never run this.
            # Raise explicitly: a bare ``assert False`` is removed when
            # Python runs with -O, which would silently disable the guard.
            raise AssertionError('initializer must never be called')

        # Set initializer.dtype to specify the parameter's dtype
        if dtype is not None:
            initializer.dtype = dtype

        # Create an uninitialized parameter
        param = chainer.Parameter(initializer)
        assert param.array is None
        if dtype is not None:
            assert param.dtype == dtype

        # Create an update rule whose update_core only records its argument
        record = []
        update_rule = chainer.UpdateRule()

        def update_core(received):
            # received.dtype may not be retrievable: the variable can be
            # uninitialized with no dtype given (i.e. self.dtype is None).
            try:
                received_dtype = received.dtype
            except RuntimeError:
                received_dtype = None
            record.append({
                'param': received,
                'dtype': received_dtype,
            })

        update_rule.update_core = update_core

        # Enable fp32 update
        update_rule.use_fp32_update()
        # Call update_rule.update
        update_rule.update(param)

        # Fail with a clear message instead of an IndexError below if
        # update_core was skipped (or unexpectedly called more than once).
        assert len(record) == 1, record

        if dtype == np.float16:
            # The fp16 parameter is replaced by a separate float32 copy.
            assert record[0]['param'] is not param
            assert record[0]['dtype'] == np.float32
        else:
            assert record[0]['param'] is param
            assert record[0]['dtype'] == dtype

        # The original parameter is kept uninitialized and its dtype is
        # unchanged.
        assert param.array is None
        if dtype is not None:
            assert param.dtype == dtype
        else:
            with pytest.raises(RuntimeError):
                param.dtype