def test_calculate_cob_weights(network,
                               model_name=None,
                               input_shape=(1, 1, 28, 28),
                               noise=False,
                               verbose=True):
    """
        Test if a cob can be calculated and applied to a network to teleport the network from the initial weights to
        the target weights.

        The network is randomly teleported twice to produce the target weights. A change of basis (cob) mapping the
        initial (bias-free) weights onto the target (bias-free) weights is computed with ``calculate_cob``. The
        network is then reset to the initial weights, that cob is applied, and the resulting weights are compared
        with the target weights.

    Args:
        network (nn.Module): Network to be tested
        model_name (str): The name or label assigned to differentiate the model. Defaults to the network's class
            name.
        input_shape (tuple): Input shape of network
        noise (bool): whether to add noise to the target weights before optimisation.
        verbose (bool): whether to display sample outputs during the test

    Raises:
        AssertionError: if the weights obtained with the calculated cob are not close to the target weights.
    """
    model_name = model_name or network.__class__.__name__
    model = NeuralTeleportationModel(network=network, input_shape=input_shape)

    # Reference point: the network's weights before any teleportation.
    initial_weights = model.get_weights()
    w1 = model.get_weights(concat=False, flatten=False, bias=False)

    # Two successive random teleportations; c1 and c2 are the individual cobs, kept only for verbose output.
    model.random_teleport()
    c1 = model.get_cob()
    model.random_teleport()
    c2 = model.get_cob()

    target_weights = model.get_weights()
    w2 = model.get_weights(concat=False, flatten=False, bias=False)

    if noise:
        # Build new noisy tensors rather than adding in place. In-place addition raises on leaf tensors that require
        # grad, and may mutate model state if get_weights returns views (TODO confirm). rand_like also keeps the
        # noise on w's device and dtype.
        w2 = [w + torch.rand_like(w) * 0.001 for w in w2]

    calculated_cob = model.calculate_cob(w1, w2)

    # Start from a fresh cob and the initial weights, then apply only the calculated cob.
    model.initialize_cob()
    model.set_weights(initial_weights)
    model.teleport(calculated_cob, reset_teleportation=True)

    calculated_weights = model.get_weights()

    # Measure the error against the target weights, which are what the assertion below checks.
    error = (calculated_weights - target_weights).abs().mean()

    if verbose:
        print("weights: ", target_weights.flatten())
        print("Calculated cob weights: ", calculated_weights.flatten())
        print("Weight error ", error)
        print("C1: ", c1.flatten()[:10])
        print("C2: ", c2.flatten()[:10])
        print("C1 * C2: ", (c1 * c2).flatten()[:10])
        print("Calculated cob: ", calculated_cob.flatten()[:10])

    assert np.allclose(calculated_weights.detach().numpy(), target_weights.detach().numpy()), \
        f"Calculate cob and weights FAILED for {model_name} model with error: {error.item()}"

    print(f"Calculate cob and weights successful for {model_name} model.")
# Example #2
# 0
def test_set_cob(network, model_name=None, input_shape=(1, 1, 28, 28), verbose=False):
    """
        Test that a teleported network's prediction can be restored from its saved weights and change of basis.

        The network is randomly teleported, and its weights, cob and prediction on a random input are recorded. After
        ``reset_weights`` (which must change the prediction), the recorded weights are set back and the recorded cob
        is applied with ``teleport_activations``. The prediction must then match the teleported one.

    Args:
        network (nn.Module): Network to be tested
        model_name (str): The name or label assigned to differentiate the model. Defaults to the network's class
            name.
        input_shape (tuple): Input shape of network
        verbose (bool): Flag to print comparison between network and a teleportation

    Raises:
        AssertionError: if resetting the weights leaves the prediction unchanged, or if restoring the weights and
            cob does not reproduce the teleported prediction.
    """
    model_name = model_name or network.__class__.__name__
    x = torch.rand(input_shape)
    model = NeuralTeleportationModel(network, input_shape=input_shape)

    # Record the teleported state and its prediction.
    model.random_teleport()
    w1 = model.get_weights()
    t1 = model.get_cob()
    pred1 = model(x)

    # After a reset the prediction is expected to differ (checked below).
    model.reset_weights()
    pred2 = model(x)

    # Restore the recorded weights and cob; the teleported prediction should come back.
    model.set_weights(w1)
    model.teleport_activations(t1)
    pred3 = model(x)

    if verbose:
        # Absolute differences, so positive and negative errors cannot cancel out.
        print("Diff prediction average: ", (pred1 - pred3).abs().mean())
        print("Pre teleportation: ", pred1.flatten()[:10])
        print("Post teleportation: ", pred3.flatten()[:10])

    assert not np.allclose(pred1.detach().numpy(), pred2.detach().numpy(), atol=1e-5), \
        f"Reset weights did not change the prediction for {model_name} model."
    assert np.allclose(pred1.detach().numpy(), pred3.detach().numpy(), atol=1e-5), "Set cob/weights did not work."

    print(f"Set cob successful for {model_name} model.")
def test_calculate_ones(network,
                        model_name=None,
                        input_shape=(1, 1, 28, 28),
                        noise=False,
                        verbose=False):
    """
        Test if the correct change of basis can be calculated for a cob of ones.

    Args:
        network (nn.Module): Network to be tested
        model_name (str): The name or label assigned to differentiate the model
        input_shape (tuple): Input shape of network
        noise (bool): whether to add noise to the target weights before optimisation.
        verbose (bool): whether to display sample ouputs during the test
    """
    model_name = model_name or network.__class__.__name__
    model = NeuralTeleportationModel(network=network, input_shape=input_shape)

    model.initialize_cob()

    w1 = model.get_weights(concat=False, flatten=False, bias=False)
    _w1 = model.get_weights(concat=False, flatten=False, bias=False)

    if noise:
        for w in _w1:
            w += torch.rand(w.shape) * 0.001

    cob = model.get_cob()
    calculated_cob = model.calculate_cob(w1, _w1)

    error = (cob - calculated_cob).abs().mean()

    if verbose:
        print("Cob: ", cob.flatten()[:10])
        print("Calculated cob: ", calculated_cob.flatten()[:10])
        print("cob error ", (calculated_cob - cob).flatten()[:10])
        print("cob error : ", error)

    assert np.allclose(
        cob, calculated_cob
    ), "Calculate cob (ones) FAILED for " + model_name + " model."

    print("Calculate cob (ones) successful for " + model_name + " model.")
    args = argument_parser()

    torch.manual_seed(args.seed)

    model = NeuralTeleportationModel(network=MLPCOB(input_shape=(1, 28, 28),
                                                    num_classes=10),
                                     input_shape=(1, 1, 28, 28))

    # Get the initial set of weights and teleport.
    initial_weights = model.get_weights()
    model.random_teleport(cob_range=args.cob_range)

    # Get second set of weights (target weights)
    target_weights = model.get_weights()
    # Get the change of basis that created this set of weights.
    target_cob = model.get_cob(concat=True)

    # Generate a new random cob
    cob = model.generate_random_cob(cob_range=args.cob_range,
                                    requires_grad=True)

    history = []
    cob_error_history = []

    print("Initial error: ", (cob - target_cob).abs().mean().item())
    print("Target cob sample: ", target_cob[0:10].data)
    print("cob sample: ", cob[0:10].data)

    optimizer = optim.Adam([cob], lr=args.lr)
    """
    Optimize the cob to find the 'target_cob' that produced the original teleportation.