Пример #1
0
 def test_weights_stay_tied(self, weights_and_gradient_and_groups):
     """Tied weights should still be tied after one gradient-descent step.

     The fixture yields ``(weights, gradient, groups)``. The weights are tied
     over ``groups`` with ``optimization.WeightTying``. They are then updated
     once by ``optimization.GradientDecent`` with the extra argument ``0.1``
     (presumably the learning rate -- confirm against GradientDecent). The
     result is checked with ``self._is_tied``.
     """
     weights, gradient, groups = weights_and_gradient_and_groups
     tie = optimization.WeightTying(*groups)
     descend = optimization.GradientDecent()
     stepped = descend(tie(weights), gradient, 0.1)
     self._is_tied(stepped, groups)
Пример #2
0
 def test_copy_data(self, weights_and_gradient_and_groups):
     """WeightTying must return new data, neither mutating nor aliasing its input.

     Two properties are checked:

     * calling the tying leaves the input ``weights`` unchanged, compared
       against a ``copy()`` taken beforehand;
     * the result does not share storage with the input: writing a sentinel
       into ``weights.flat[0]`` must not show up in ``updated.flat[0]``.
     """
     weights, _, groups = weights_and_gradient_and_groups
     tying = optimization.WeightTying(*groups)
     before = weights.copy()
     updated = tying(weights)
     # The input must not have been modified in place.
     assert (before.flat == weights.flat).all()
     # Choose a sentinel guaranteed to differ from updated.flat[0]. A fixed
     # magic value could coincide with the data and fail spuriously.
     sentinel = 42 if updated.flat[0] != 42 else 43
     weights.flat[0] = sentinel
     # Guard against a vacuous test: the write must actually land in weights,
     # otherwise the aliasing check below would pass no matter what.
     assert weights.flat[0] == sentinel
     assert updated.flat[0] != sentinel
Пример #3
0
 def test_dont_affect_others(self, weights_and_gradient_and_groups):
     """Tying entries 0 and 1 must leave every later weight entry unchanged.

     A single group ``(np.s_[0, :, :], np.s_[1, :, :])`` is tied. This
     presumably selects the whole of entry 0 and entry 1 of ``weights`` --
     confirm against WeightTying's indexing. Afterwards those two entries must
     be equal. Every entry from index 2 onwards must equal its counterpart in
     the input.
     """
     weights, _, _ = weights_and_gradient_and_groups
     shapes = weights.shapes
     if len(shapes) < 2:
         pytest.skip("needs at least two weight matrices to tie")
     if tuple(shapes[0]) != tuple(shapes[1]):
         # The element-wise equality check below is only meaningful when the
         # two tied entries have the same shape.
         pytest.skip("first two weight matrices differ in shape")
     group = (np.s_[0, :, :], np.s_[1, :, :])
     tying = optimization.WeightTying(group)
     updated = tying(weights)
     assert (updated[0] == updated[1]).all()
     for before, after in zip(weights[2:], updated[2:]):
         assert (before == after).all()
Пример #4
0
 def test_shapes_match(self, weights_and_gradient_and_groups):
     """WeightTying must preserve the ``shapes`` of the weights it is given."""
     weights, _, groups = weights_and_gradient_and_groups
     result = optimization.WeightTying(*groups)(weights)
     assert weights.shapes == result.shapes
Пример #5
0
 def test_calculation(self, weights_and_gradient_and_groups):
     """Applying WeightTying over the fixture's groups yields tied weights.

     The actual check is delegated to ``self._is_tied``.
     """
     weights, _, groups = weights_and_gradient_and_groups
     tied = optimization.WeightTying(*groups)(weights)
     self._is_tied(tied, groups)