Esempio n. 1
0
    def test_soft_copy_param(self):
        """soft_copy_param blends the source into the target, leaving the source intact.

        The update rule is ``target <- (1 - tau) * target + tau * source``.
        Starting from target W = 0.5 and source W = 1 with tau = 0.1, the
        target W must be 0.55 after one call and 0.595 after a second call.
        The source W must stay at 1.0 throughout.
        """
        target = L.Linear(1, 5)
        source = L.Linear(1, 5)

        target.W.array[:] = 0.5
        source.W.array[:] = 1

        # Each iteration applies one soft update and checks both links.
        for expected in (0.55, 0.595):
            copy_param.soft_copy_param(
                target_link=target, source_link=source, tau=0.1)
            np.testing.assert_almost_equal(
                target.W.array, np.full(target.W.shape, expected))
            # The source link must never be modified by the soft copy.
            np.testing.assert_almost_equal(
                source.W.array, np.full(source.W.shape, 1.0))
Esempio n. 2
0
    def test_soft_copy_param_scalar(self):
        """soft_copy_param also handles 0-dimensional (scalar) parameters.

        The target starts at 0.5 and the source at 1. With tau = 0.1 the
        target must read 0.55 and then 0.595 after two soft updates, while
        the source stays at 1.0.
        """
        target = chainer.Chain()
        with target.init_scope():
            target.p = chainer.Parameter(np.array(0.5))
        source = chainer.Chain()
        with source.init_scope():
            source.p = chainer.Parameter(np.array(1))

        # target <- (1 - tau) * target + tau * source, applied twice.
        for expected in (0.55, 0.595):
            copy_param.soft_copy_param(
                target_link=target, source_link=source, tau=0.1)
            np.testing.assert_almost_equal(target.p.array, expected)
            np.testing.assert_almost_equal(source.p.array, 1.0)
Esempio n. 3
0
    def test_soft_copy_param_type_check(self):
        """soft_copy_param raises TypeError when the target's input size is unset.

        The target is built as ``L.Linear(None, 5)``. Presumably its weights
        are therefore not yet allocated; confirm against Chainer's
        lazy-initialization docs. The source is fully sized.
        """
        uninitialized_target = L.Linear(None, 5)
        initialized_source = L.Linear(1, 5)

        with self.assertRaises(TypeError):
            copy_param.soft_copy_param(
                target_link=uninitialized_target,
                source_link=initialized_source,
                tau=0.1,
            )
Esempio n. 4
0
 def sync_parameters(self):
     """Refresh the local model and the averaged model from the shared model.

     The steps run in order:

     1. Copy the parameters of ``self.shared_model`` into ``self.model``.
     2. Soft-copy ``self.model`` into ``self.shared_average_model`` with
        ``tau = 1 - self.trust_region_alpha``. Given the soft-copy rule
        ``target <- (1 - tau) * target + tau * source``, this yields
        ``avg <- alpha * avg + (1 - alpha) * model``.

     Step 2 must follow step 1 so the average absorbs the freshly synced
     local weights.
     """
     copy_param.copy_param(
         target_link=self.model,
         source_link=self.shared_model,
     )
     # A larger trust_region_alpha keeps the average closer to its old value.
     blend_rate = 1 - self.trust_region_alpha
     copy_param.soft_copy_param(
         target_link=self.shared_average_model,
         source_link=self.model,
         tau=blend_rate,
     )