def test_weights_stay_tied(self, weights_and_gradient_and_groups):
    """Weights tied by WeightTying are still tied after a GradientDecent step.

    The fixture supplies ``(weights, gradient, groups)``. The weights are
    tied over ``groups`` and then passed through one ``GradientDecent`` call
    with ``0.1`` as the third argument (presumably the step size -- confirm
    against ``GradientDecent``). The result is checked with ``_is_tied``.
    """
    weights, gradient, groups = weights_and_gradient_and_groups
    tie = optimization.WeightTying(*groups)
    descend = optimization.GradientDecent()
    stepped = descend(tie(weights), gradient, 0.1)
    self._is_tied(stepped, groups)
def test_copy_data(self, weights_and_gradient_and_groups):
    """WeightTying leaves its input untouched and returns non-aliased data.

    First, the input must still equal a snapshot taken before the call.
    Second, writing to the input's first flat element must not be visible
    through the returned object.
    """
    original, _, groups = weights_and_gradient_and_groups
    tying = optimization.WeightTying(*groups)
    snapshot = original.copy()
    result = tying(original)
    # Calling the tying must not modify its argument.
    assert (snapshot.flat == original.flat).all()
    marker = 42
    # Precondition: the result must not already hold the marker. Otherwise
    # the aliasing check below would pass trivially.
    assert result.flat[0] != marker
    original.flat[0] = marker
    # If `result` shared storage with `original`, the write would show up here.
    assert result.flat[0] != marker
def test_dont_affect_others(self, weights_and_gradient_and_groups):
    """Tying entries 0 and 1 leaves every later entry of the weights unchanged.

    One tying group is built from the full slices of entries 0 and 1. The
    ``[i, :, :]`` slices suggest each entry is a 2-D array (TODO confirm).
    The test then checks two things:
      * entries 0 and 1 of the result are equal;
      * every entry from index 2 on is identical before and after tying.

    The test is skipped when the weights have fewer than two entries.
    """
    weights, _, _ = weights_and_gradient_and_groups
    if len(weights.shapes) < 2:
        # Give the skip a reason so it is readable in the pytest report.
        pytest.skip("need at least two weight entries to tie entries 0 and 1")
    group = (np.s_[0, :, :], np.s_[1, :, :])
    tying = optimization.WeightTying(group)
    updated = tying(weights)
    assert (updated[0] == updated[1]).all()
    # zip() stops at the shorter input. Without this check, missing trailing
    # entries in `updated` would make the loop below pass vacuously.
    assert len(updated.shapes) == len(weights.shapes)
    for before, after in zip(weights[2:], updated[2:]):
        assert (before == after).all()
def test_shapes_match(self, weights_and_gradient_and_groups):
    """WeightTying returns weights whose ``shapes`` equal those of its input."""
    weights, _, groups = weights_and_gradient_and_groups
    updated = optimization.WeightTying(*groups)(weights)
    assert weights.shapes == updated.shapes
def test_calculation(self, weights_and_gradient_and_groups):
    """A single WeightTying pass yields weights that ``_is_tied`` accepts.

    The fixture's gradient is unused here; only the tying itself is checked.
    """
    weights, _, groups = weights_and_gradient_and_groups
    tied = optimization.WeightTying(*groups)(weights)
    self._is_tied(tied, groups)