def test_set_cob(network, model_name=None, input_shape=(1, 1, 28, 28), verbose=False):
    """Check that a teleported model can be restored from its saved weights and cob.

    The model is randomly teleported, and its weights, change of basis (cob) and
    prediction on a random input are recorded. The weights are then reset, which
    must change the prediction. Next, the recorded weights are restored with
    ``set_weights`` and the recorded cob is re-applied with
    ``teleport_activations``. The restored model must reproduce the original
    prediction.

    Args:
        network (nn.Module): Network to be tested.
        model_name (str): Name used to identify the model in the printed report.
            Falls back to the network's class name when empty or ``None``.
        input_shape (tuple): Input shape of the network.
        verbose (bool): If True, print a comparison between the predictions
            before and after restoring the weights and cob.

    Raises:
        AssertionError: If resetting the weights does not change the prediction,
            or if restoring the weights and cob does not reproduce it.
    """
    # Same fallback as test_reset_weights, so a missing name does not crash the
    # final report with a TypeError.
    model_name = model_name or network.__class__.__name__

    x = torch.rand(input_shape)
    model = NeuralTeleportationModel(network, input_shape=input_shape)
    model.random_teleport()
    w1 = model.get_weights()
    t1 = model.get_cob()
    pred1 = model(x)

    # Resetting must change the output. Otherwise the final comparison could
    # pass trivially, without the restore step doing anything.
    model.reset_weights()
    pred2 = model(x)

    model.set_weights(w1)
    model.teleport_activations(t1)
    pred3 = model(x)

    pred1_np = pred1.detach().numpy()
    pred2_np = pred2.detach().numpy()
    pred3_np = pred3.detach().numpy()

    if verbose:
        # Absolute difference: a signed mean lets positive and negative errors
        # cancel out and could report ~0 for badly mismatched predictions.
        print("Diff prediction average: ", np.abs(pred1_np - pred3_np).mean())
        print("Pre teleportation: ", pred1.flatten()[:10])
        print("Post teleportation: ", pred3.flatten()[:10])

    assert not np.allclose(pred1_np, pred2_np, atol=1e-5), \
        "Reset weights did not change the prediction."
    assert np.allclose(pred1_np, pred3_np, atol=1e-5), "Set cob/weights did not work."

    print("Set cob successful for " + model_name + " model.")


def test_reset_weights(network: nn.Module, input_shape: Tuple = (1, 1, 28, 28), model_name: str = None):
    """Verify that ``reset_weights`` re-initializes a NeuralTeleportationModel.

    The flattened weights are captured before and after calling
    ``reset_weights``. The check fails if the two snapshots are element-wise
    close according to ``np.allclose``.

    Args:
        network (nn.Module): Network to be tested.
        input_shape (tuple): Input shape of the network.
        model_name (str): Label printed in the success message. When empty or
            ``None``, the network's class name is used instead.
    """
    if not model_name:
        model_name = network.__class__.__name__

    teleport_model = NeuralTeleportationModel(network, input_shape=input_shape)

    def snapshot():
        # Detached numpy copy of the model's current weights.
        return teleport_model.get_weights().detach().numpy()

    weights_before = snapshot()
    teleport_model.reset_weights()
    weights_after = snapshot()

    assert not np.allclose(weights_before, weights_after)
    print(" ".join(("Reset weights successful for", model_name, "model.")))