Exemplo n.º 1
0
Arquivo: acer.py Projeto: pfnet/pfrl
 def sync_parameters(self):
     """Re-sync the local model from the shared model and refresh the average.

     First, the parameters of ``self.shared_model`` are copied into
     ``self.model``. Then ``self.shared_average_model`` is soft-updated from
     the freshly synced ``self.model`` via ``copy_param.soft_copy_param``.
     The ``tau`` passed is ``1 - self.trust_region_alpha``.
     """
     local_model = self.model
     # Hard copy: local model <- shared model.
     copy_param.copy_param(
         target_link=local_model,
         source_link=self.shared_model,
     )
     # Soft copy: shared average model <- local model. The weighting is
     # delegated to soft_copy_param.
     soft_tau = 1 - self.trust_region_alpha
     copy_param.soft_copy_param(
         target_link=self.shared_average_model,
         source_link=local_model,
         tau=soft_tau,
     )
Exemplo n.º 2
0
    def test_copy_param_shape_check(self):
        # The two layers differ in input width (2 vs. 1), so their weight
        # tensors have incompatible shapes.
        wide = nn.Linear(2, 5)
        narrow = nn.Linear(1, 5)

        # Copying in either direction must be rejected with a RuntimeError.
        # The checks run in order and stop at the first failure, matching
        # two consecutive assertRaises blocks.
        for target, source in ((wide, narrow), (narrow, wide)):
            with self.assertRaises(RuntimeError):
                copy_param.copy_param(target, source)
Exemplo n.º 3
0
    def test_copy_param_scalar(self):
        # Two bare modules, each holding one single-element parameter ``p``.
        target = nn.Module()
        target.p = nn.Parameter(torch.Tensor([1]))
        source = nn.Module()
        source.p = nn.Parameter(torch.Tensor([2]))

        def values():
            # Snapshot both parameters as numpy arrays (target first).
            return target.p.detach().numpy(), source.p.detach().numpy()

        # Precondition: the parameters start out different.
        self.assertNotEqual(*values())

        # Copy source's parameters into target.
        copy_param.copy_param(target, source)

        # Afterwards target's parameter must equal source's.
        self.assertEqual(*values())
Exemplo n.º 4
0
    def test_copy_param(self):
        target = nn.Linear(1, 5)
        source = nn.Linear(1, 5)

        # One random float32 input of shape (1, 1), shared by every forward pass.
        x = torch.from_numpy(np.random.rand(1, 1).astype(np.float32))

        def outputs(layer):
            # Forward ``x`` through ``layer`` and flatten the result to a list.
            return list(layer(x).detach().numpy().ravel())

        target_before = outputs(target)
        source_before = outputs(source)
        # Independently initialised layers should disagree on the same input.
        self.assertNotEqual(target_before, source_before)

        # Copy source's parameters into target.
        copy_param.copy_param(target, source)

        # After the copy, both layers must reproduce source's original output.
        self.assertEqual(outputs(target), source_before)
        self.assertEqual(outputs(source), source_before)
Exemplo n.º 5
0
Arquivo: a3c.py Projeto: xylee95/pfrl
 def load(self, dirname):
     """Delegate loading to the parent class, then mirror into the shared model.

     After ``super().load(dirname)`` returns, the parameters of
     ``self.model`` are copied into ``self.shared_model``.
     """
     super().load(dirname)
     # Presumably the parent load updated self.model; push those weights
     # into the shared copy so the two agree.
     source, target = self.model, self.shared_model
     copy_param.copy_param(target_link=target, source_link=source)
Exemplo n.º 6
0
Arquivo: a3c.py Projeto: xylee95/pfrl
 def sync_parameters(self):
     """Copy the parameters of ``self.shared_model`` into ``self.model``."""
     shared, local = self.shared_model, self.model
     copy_param.copy_param(target_link=local, source_link=shared)