Exemplo n.º 1
0
def layer_alignment(model, output_fn, loader, n_output, centering=True):
    """Score every layer of ``model`` by its alignment with the targets.

    For each layer found by ``LayerCollection.from_model(model)``, a dense
    matrix ``K`` is built from a ``Jacobian`` generator restricted to that
    single layer, and the score

        y^T K y / (||K||_F * ||y||^2)

    is evaluated, where ``y`` is the column-centered one-hot encoding of the
    targets collected from ``loader``.

    Args:
        model: model whose layers are scored one at a time.
        output_fn: forwarded to ``Jacobian`` as ``function``.
        loader: iterable of batches; element ``[1]`` of each batch is used as
            the class targets. It is iterated once here to gather the targets
            and is also handed to ``Jacobian``.
            NOTE(review): presumably it must yield examples in the same order
            on every pass (no shuffling) so targets line up with the rows of
            ``K`` -- confirm.
        n_output: forwarded to ``Jacobian`` as ``n_output`` and used as the
            width of the one-hot encoding.
        centering: forwarded to ``Jacobian``. The targets themselves are
            always mean-centered, independently of this flag.

    Returns:
        list of float: one alignment value per layer, in the iteration order
        of ``LayerCollection.from_model(model).layers``.
    """
    layer_collection = LayerCollection.from_model(model)

    targets = torch.cat([batch[1] for batch in loader])
    # Pin the encoding width to n_output: without it the width is inferred
    # from the largest label present, so a class missing from `loader` would
    # produce a target matrix narrower than the model output.
    # (assumes `one_hot` accepts torch's `num_classes` keyword, i.e. it is
    # torch.nn.functional.one_hot -- confirm against the file's imports)
    targets = one_hot(targets, num_classes=n_output).float()
    targets -= targets.mean(dim=0)
    targets = FVector(vector_repr=targets.t().contiguous())

    # The target norm does not depend on the layer, so compute it once.
    targets_sq_norm = torch.norm(targets.get_flat_representation()) ** 2

    alignments = []
    for layer_id, layer in layer_collection.layers.items():
        # Collection holding only the layer currently being scored.
        single_layer = LayerCollection()
        single_layer.add_layer(layer_id, layer)

        generator = Jacobian(layer_collection=single_layer,
                             model=model,
                             loader=loader,
                             function=output_fn,
                             n_output=n_output,
                             centering=centering)

        K_dense = FMatDense(generator)
        yTKy = K_dense.vTMv(targets)
        frobK = K_dense.frobenius_norm()

        align = yTKy / (frobK * targets_sq_norm)
        alignments.append(align.item())

    return alignments
Exemplo n.º 2
0
def get_fullyconnect_onlylast_task():
    """Build the fully-connected task restricted to a single layer.

    Starts from ``get_fullyconnect_task()`` and keeps only the entry returned
    by ``popitem()`` on its layer collection, together with the parameters of
    ``net.net[-1]``.

    Returns:
        tuple: ``(train_loader, layer_collection, parameters, net, output_fn,
        n_output)`` where ``layer_collection`` holds the single popped layer
        and ``parameters`` is a list of the parameters of ``net.net[-1]``.
    """
    train_loader, lc_full, _, net, output_fn, n_output = \
        get_fullyconnect_task()
    layer_collection = LayerCollection()
    # only keep last layer parameters
    # (presumably popitem() returns the most recently added (name, layer)
    # pair, i.e. the last layer -- confirm against LayerCollection.layers)
    layer_collection.add_layer(*lc_full.layers.popitem())
    # Materialize as a list: the raw result of .parameters() is presumably a
    # one-shot generator, which would be empty for any caller iterating it a
    # second time. This also matches the list returned by the sibling
    # batchnorm/conv task builder.
    parameters = list(net.net[-1].parameters())

    return train_loader, layer_collection, parameters, net, output_fn, n_output
Exemplo n.º 3
0
def get_batchnorm_conv_linear_task():
    """Build a small MNIST task around ``BatchNormConvLinearNet``.

    The first 70 MNIST examples are served in order in batches of 30, and the
    network is put in eval mode. Only two layers of the network are kept in
    the returned layer collection.

    Returns:
        tuple: ``(train_loader, layer_collection, parameters, net, output_fn,
        2)`` where ``parameters`` lists the parameters of ``net.conv2``
        followed by those of ``net.conv1`` and the trailing ``2`` is the
        number of outputs reported for this task.
    """
    mnist_subset = Subset(get_mnist(), range(70))
    train_loader = DataLoader(dataset=mnist_subset,
                              batch_size=30,
                              shuffle=False)

    net = BatchNormConvLinearNet()
    to_device_model(net)
    net.eval()

    # `target` is accepted but unused; only the input is moved to the device
    # and passed through the network.
    def output_fn(input, target):
        return net(to_device(input))

    full_collection = LayerCollection.from_model(net)
    layer_collection = LayerCollection()
    # Keep the two entries popped from the end of the full collection.
    # NOTE(review): an earlier comment called these "fc1 and fc2", yet the
    # parameter list below is built from conv2 and conv1 -- confirm which
    # layers are actually intended.
    for _ in range(2):
        layer_collection.add_layer(*full_collection.layers.popitem())

    parameters = [*net.conv2.parameters(), *net.conv1.parameters()]
    n_output = 2

    return (train_loader, layer_collection, parameters, net, output_fn,
            n_output)