class TestObj_SmoothedLasso_Gradient_lowtemp(unittest.TestCase):
    """Check ``SmoothedLasso_Gradient.oracle`` as the temperature shrinks.

    The reference values (at 4 decimals) stop changing from ``temp=1e-2``
    onward, so every temperature at or below that shares one reference.
    """

    def setUp(self):
        _init_lasso(self)
        self.obj = SmoothedLasso_Gradient(self.hparams)

    def test_oracle(self):
        """Compare ``oracle`` output with reference values for small temps.

        Each temperature is paired directly with its expected
        ``{'dw', 'obj'}`` dict. Parallel lists combined with ``zip`` would
        silently drop cases if their lengths ever diverged. Each temperature
        runs in its own ``subTest``, so a failure names the offending value
        and the remaining temperatures are still checked.
        """
        def make_ref(dw, obj):
            # Build fresh tensors and a fresh dict per case, so no reference
            # object is shared between iterations.
            return {'dw': torch.tensor(dw), 'obj': torch.tensor(obj)}

        settled_dw = [[0.7414], [1.4836], [-0.3669]]
        cases = [(1e-1, make_ref([[0.7357], [1.4836], [-0.3669]], 1.3400))]
        cases += [(temp, make_ref(settled_dw, 1.3397))
                  for temp in (1e-2, 1e-3, 1e-4, 1e-5)]

        for temp, cache_ref in cases:
            with self.subTest(temp=temp):
                self.hparams.temp = temp
                self.obj = SmoothedLasso_Gradient(self.hparams)
                try:
                    cache_test = self.obj.oracle(self.w, self.x, self.y)
                    assert_all_close_dict(
                        cache_ref, cache_test,
                        "oracle_info with parameter temp={}".format(temp))
                finally:
                    # The oracle may populate ``self.w.grad`` (presumably via
                    # autograd; confirm). Reset it even when this case fails,
                    # so the next temperature does not see an accumulated
                    # gradient.
                    if self.w.grad is not None:
                        self.w.grad.zero_()


class TestObj_SmoothedLasso_Gradient(unittest.TestCase):
    """Check ``SmoothedLasso_Gradient`` task error and oracle at several temps."""

    def setUp(self):
        _init_lasso(self)
        self.obj = SmoothedLasso_Gradient(self.hparams)

    def test_error(self):
        """``task_error`` on the fixture data matches the stored reference."""
        error_test = self.obj.task_error(self.w, self.x, self.y)
        error_ref = torch.tensor(1.3251)
        assert_all_close(error_test, error_ref, "task_error returned value")

    def test_oracle(self):
        """Compare ``oracle`` output with reference values for temps 0.1, 1, 10.

        Each temperature is paired directly with its expected
        ``{'dw', 'obj'}`` dict. Parallel lists combined with ``zip`` would
        silently drop cases if their lengths ever diverged. Each temperature
        runs in its own ``subTest``, so a failure names the offending value
        and the remaining temperatures are still checked.
        """
        def make_ref(dw, obj):
            # Build fresh tensors and a fresh dict per case, so no reference
            # object is shared between iterations.
            return {'dw': torch.tensor(dw), 'obj': torch.tensor(obj)}

        # Note: 1 and 10 are ints, kept as such to match the stored
        # hparams value and the failure message text.
        cases = [
            (0.1, make_ref([[0.7357], [1.4836], [-0.3669]], 1.3400)),
            (1, make_ref([[0.7319], [1.4774], [-0.3645]], 1.3511)),
            (10, make_ref([[0.7315], [1.4740], [-0.3579]], 1.5336)),
        ]

        for temp, cache_ref in cases:
            with self.subTest(temp=temp):
                self.hparams.temp = temp
                self.obj = SmoothedLasso_Gradient(self.hparams)
                try:
                    cache_test = self.obj.oracle(self.w, self.x, self.y)
                    assert_all_close_dict(
                        cache_ref, cache_test,
                        "oracle_info with parameter temp={}".format(temp))
                finally:
                    # The oracle may populate ``self.w.grad`` (presumably via
                    # autograd; confirm). Reset it even when this case fails,
                    # so the next temperature does not see an accumulated
                    # gradient.
                    if self.w.grad is not None:
                        self.w.grad.zero_()