def test_Linear_with_initial():

    Linear1_prev = LinearBayesianGaussian(10,
                                          100,
                                          LinearBayesianGaussian_init=False,
                                          p=1)
    Linear1 = LinearBayesianGaussian(10, 100, Linear1_prev, p=1)

    assert (Linear1.rho.weight.data.numpy() ==
            Linear1_prev.rho.weight.data.numpy()).all()
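

# A minimal sketch of the property test_Linear_with_initial relies on, written
# with plain nn.Linear layers rather than LinearBayesianGaussian: when a layer
# is initialized by copying parameters from a previous one, the copied tensors
# compare equal element-wise. _copy_init_sketch is illustrative only, not part
# of the tested API.
def _copy_init_sketch():
    prev = torch.nn.Linear(10, 100)
    new = torch.nn.Linear(10, 100)
    with torch.no_grad():
        new.weight.copy_(prev.weight)  # copy-from-prior, as the test expects
    return (new.weight.data.numpy() == prev.weight.data.numpy()).all()

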
def test_Linear_stack():

    Linear_stack = LinearBayesianGaussian(10,
                                          10,
                                          LinearBayesianGaussian_init=False,
                                          p=1)
    # The forward pass is run first so that w has been sampled before stack()
    # is called.
    output = Linear_stack(
        torch.tensor(np.random.uniform(1, 1, (20, 10)), dtype=torch.float64))

    mu_stack, rho_stack, w_stack = Linear_stack.stack()

    assert ((mu_stack.shape == w_stack.shape)
            and (rho_stack.data.numpy().shape[0] == 110))
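

# The 110 in the assertion above is the flattened size of a 10 -> 10 layer:
# 100 weight entries plus 10 bias entries. A plain-torch sketch of that
# bookkeeping (not the tested stack() implementation, just the flattening it
# is assumed to perform):
def _flatten_sketch():
    weight = torch.zeros(10, 10, dtype=torch.float64)  # stand-in weight matrix
    bias = torch.zeros(10, dtype=torch.float64)        # stand-in bias vector
    stacked = torch.cat((weight.reshape(-1), bias.reshape(-1)))
    return stacked.shape[0]  # 100 + 10 = 110

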
def test_Linear_multi_input():

    Linear1 = LinearBayesianGaussian(10,
                                     1,
                                     LinearBayesianGaussian_init=False,
                                     p=1)

    with torch.no_grad():
        Linear1.mu.weight.copy_(
            torch.tensor(np.random.uniform(2, 2, (1, 10)),
                         dtype=torch.float64))
        Linear1.mu.bias.copy_(
            torch.tensor(np.random.uniform(0, 0, (1)), dtype=torch.float64))

    output = Linear1(
        torch.tensor(np.random.uniform(1, 1, (20, 10)), dtype=torch.float64))

    # Every row of the input batch is identical and the weights are sampled
    # once for the whole forward pass, so every output row should match the
    # first one, and the batch dimension (20) should be preserved.

    assert ((output.data.numpy()[0, :] == output.data.numpy()).all()
            and output.data.numpy().shape[0] == 20)
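

# Sketch of the invariant test_Linear_multi_input checks, using a plain
# nn.Linear so the weights are trivially fixed for the whole batch: identical
# input rows must map to identical output rows.
def _identical_rows_sketch():
    layer = torch.nn.Linear(10, 1).double()
    batch = torch.ones(20, 10, dtype=torch.float64)  # 20 copies of the same row
    out = layer(batch)
    return (out.data.numpy()[0, :] == out.data.numpy()).all()

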
def test_Linear_reparam_trick_be():

    Linear1 = LinearBayesianGaussian(10,
                                     100,
                                     LinearBayesianGaussian_init=False,
                                     p=0.5)

    with torch.no_grad():
        Linear1.mu.weight.copy_(
            torch.tensor(np.random.uniform(2, 2, (100, 10)),
                         dtype=torch.float64))
        Linear1.rho.weight.copy_(
            torch.tensor(np.random.uniform(-20, -20, (100, 10)),
                         dtype=torch.float64))

    output = Linear1(
        torch.tensor(np.random.uniform(1, 1, (10)), dtype=torch.float64))

    loss = output.sum() + (Linear1.mu.weight * Linear1.mu.weight).sum() + (
        2 * Linear1.rho.weight).sum()
    loss.backward()

    # Gradient w.r.t. mu: the (mu * mu) term contributes 2 * mu = 4 everywhere,
    # and output.sum() adds the input value 1 for every weight kept by the
    # p=0.5 sampling, so roughly half of the mu gradients should equal 5.
    # Gradient w.r.t. rho: with rho = -20 the sigma path is effectively zero,
    # so only the (2 * rho) term survives and every rho gradient is ~2.

    assert (
        ((Linear1.mu.weight.grad.data.numpy() == 5).sum() / 1000) < 0.6
        and ((Linear1.mu.weight.grad.data.numpy() == 5).sum() / 1000) > 0.4 and
        (np.abs(Linear1.rho.weight.grad.data.numpy() - 2) < np.exp(-15)).all())
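

# A standalone sketch of the gradient arithmetic the assertion above relies on,
# written with plain tensors. It assumes the usual reparameterization
# w = mu + log(1 + exp(rho)) * eps (consistent with the exp(-15) tolerance
# asserted above) and leaves out the p=0.5 weight sampling, so here every mu
# gradient comes out at 5 instead of roughly half of them.
def _reparam_gradient_sketch():
    mu = torch.full((100, 10), 2.0, dtype=torch.float64, requires_grad=True)
    rho = torch.full((100, 10), -20.0, dtype=torch.float64, requires_grad=True)
    eps = torch.randn(100, 10, dtype=torch.float64)

    sigma = torch.log1p(torch.exp(rho))  # ~2e-9, effectively zero
    w = mu + sigma * eps                 # reparameterization trick
    x = torch.ones(10, dtype=torch.float64)

    loss = (w @ x).sum() + (mu * mu).sum() + (2 * rho).sum()
    loss.backward()

    # d loss / d mu  = x_j + 2 * mu = 1 + 4 = 5 for every entry
    # d loss / d rho = eps * sigmoid(rho) + 2, and sigmoid(-20) ~ 2e-9, so ~2
    return mu.grad, rho.grad

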
def test_Linear_without_initial():

    Linear1 = LinearBayesianGaussian(10,
                                     100,
                                     LinearBayesianGaussian_init=False,
                                     p=1)

    assert ((Linear1.mu.weight.data.numpy().shape[1] == 10
             and Linear1.mu.weight.data.numpy().shape[0] == 100
             and Linear1.mu.bias.data.numpy().shape[0] == 100)
            and (Linear1.rho.weight.data.numpy().shape[1] == 10
                 and Linear1.rho.weight.data.numpy().shape[0] == 100
                 and Linear1.rho.bias.data.numpy().shape[0] == 100))
def test_Linear_w_values():

    Linear1 = LinearBayesianGaussian(10,
                                     100,
                                     LinearBayesianGaussian_init=False,
                                     p=1)

    with torch.no_grad():
        Linear1.mu.weight.copy_(
            torch.tensor(np.random.uniform(1, 1, (100, 10)),
                         dtype=torch.float64))

    output = Linear1(
        torch.tensor(np.random.uniform(1, 1, (10)), dtype=torch.float64))

    assert (output.data.numpy() > 9).all() and (output.data.numpy() < 11).all()
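

# Why the (9, 11) bounds in test_Linear_w_values are plausible: with every mu
# weight set to 1 and a 10-dimensional all-ones input, the deterministic part
# of each output coordinate is 10; the bias and the sigma * eps noise from the
# default initialization are assumed to stay well under 1 in magnitude. The
# deterministic part, checked with plain tensors:
def _expected_mean_sketch():
    weight = torch.ones(100, 10, dtype=torch.float64)
    x = torch.ones(10, dtype=torch.float64)
    return weight @ x  # every entry equals 10.0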