def sync_parameters(self):
    # Pull the latest parameters from the shared model into the local model.
    copy_param.copy_param(target_link=self.model, source_link=self.shared_model)
    # Softly update the shared average model toward the local model with
    # factor tau = 1 - trust_region_alpha.
    copy_param.soft_copy_param(
        target_link=self.shared_average_model,
        source_link=self.model,
        tau=1 - self.trust_region_alpha,
    )
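# The copy_param helpers called in this section are not shown here. The sketch
# below is a minimal, assumed implementation inferred from the call sites and
# the tests that follow: copy_param hard-copies parameters from source_link to
# target_link (raising RuntimeError on shape mismatch), and soft_copy_param
# blends them with interpolation factor tau (Polyak-style averaging). Treat it
# as an illustration of the expected behaviour, not the library's actual code.
import torch
import torch.nn as nn


def copy_param(target_link: nn.Module, source_link: nn.Module) -> None:
    """Hard-copy every parameter of source_link into target_link."""
    target_params = dict(target_link.named_parameters())
    for name, source_param in source_link.named_parameters():
        if target_params[name].shape != source_param.shape:
            # Mirrors the behaviour exercised by test_copy_param_shape_check.
            raise RuntimeError("Parameter shape mismatch for {}".format(name))
        target_params[name].detach().copy_(source_param.detach())


def soft_copy_param(target_link: nn.Module, source_link: nn.Module, tau: float) -> None:
    """Soft update: target = (1 - tau) * target + tau * source."""
    target_params = dict(target_link.named_parameters())
    for name, source_param in source_link.named_parameters():
        with torch.no_grad():
            target_params[name].mul_(1 - tau)
            target_params[name].add_(tau * source_param.detach())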
def test_copy_param_shape_check(self):
    a = nn.Linear(2, 5)
    b = nn.Linear(1, 5)

    with self.assertRaises(RuntimeError):
        # Different shape
        copy_param.copy_param(a, b)

    with self.assertRaises(RuntimeError):
        # Different shape
        copy_param.copy_param(b, a)
def test_copy_param_scalar(self):
    a = nn.Module()
    a.p = nn.Parameter(torch.Tensor([1]))
    b = nn.Module()
    b.p = nn.Parameter(torch.Tensor([2]))
    self.assertNotEqual(a.p.detach().numpy(), b.p.detach().numpy())

    # Copy b's parameters to a
    copy_param.copy_param(a, b)

    self.assertEqual(a.p.detach().numpy(), b.p.detach().numpy())
def test_copy_param(self):
    a = nn.Linear(1, 5)
    b = nn.Linear(1, 5)

    s = torch.from_numpy(np.random.rand(1, 1).astype(np.float32))
    a_out = list(a(s).detach().numpy().ravel())
    b_out = list(b(s).detach().numpy().ravel())
    self.assertNotEqual(a_out, b_out)

    # Copy b's parameters to a
    copy_param.copy_param(a, b)

    a_out_new = list(a(s).detach().numpy().ravel())
    b_out_new = list(b(s).detach().numpy().ravel())
    self.assertEqual(a_out_new, b_out)
    self.assertEqual(b_out_new, b_out)
def load(self, dirname):
    super().load(dirname)
    # Propagate the restored parameters to the shared model as well.
    copy_param.copy_param(target_link=self.shared_model, source_link=self.model)
def sync_parameters(self):
    copy_param.copy_param(target_link=self.model, source_link=self.shared_model)