예제 #1
0
def load_components(class_name, inst):
    """Load the component dump for ``class_name`` and publish it as GPU tensors.

    A config with ``output_class`` set to ``class_name`` is built via
    ``args.from_dict`` and handed to ``get_or_compute`` together with
    ``inst``; the value returned is opened with ``np.load``. The arrays read
    from it are moved to the CUDA device as float32 tensors and stored in
    the module-level ``components`` namespace.

    Args:
        class_name: Value placed under the ``'output_class'`` config key;
            also recorded on ``state.component_class``.
        inst: Passed through unchanged to ``get_or_compute``.

    Side effects:
        Rebinds the globals ``components`` and ``use_named_latents`` (the
        latter to ``False``), sets ``state.component_class``, and prints a
        one-line confirmation.

    Raises:
        KeyError: If the dump lacks one of the ``act_*`` / ``lat_*`` arrays
            read below.
    """
    global components, state, use_named_latents

    config = args.from_dict({'output_class': class_name})
    dump_name = get_or_compute(config, inst)

    # Use a context manager so the archive handle is released even when a
    # key lookup fails part-way through. Each lookup yields a standalone
    # array, so the values stay usable after the archive is closed.
    with np.load(dump_name, allow_pickle=False) as data:
        X_comp = data['act_comp']
        X_mean = data['act_mean']
        X_stdev = data['act_stdev']
        Z_comp = data['lat_comp']
        Z_mean = data['lat_mean']
        Z_stdev = data['lat_stdev']

    # One entry per row of the activation component matrix.
    n_comp = X_comp.shape[0]

    def to_gpu(array):
        # Same order as before: move to the device first, then cast.
        return torch.from_numpy(array).cuda().float()

    # Transfer to GPU
    components = SimpleNamespace(
        X_comp=to_gpu(X_comp),
        X_mean=to_gpu(X_mean),
        X_stdev=to_gpu(X_stdev),
        Z_comp=to_gpu(Z_comp),
        Z_stdev=to_gpu(Z_stdev),
        Z_mean=to_gpu(Z_mean),
        names=[f'Component {i}' for i in range(n_comp)],
        latent_types=[model.latent_space_name()] * n_comp,
        ranges=[(0, model.get_max_latents())] * n_comp,
    )

    state.component_class = class_name  # invalidates cache
    use_named_latents = False
    print('Loaded components for', class_name, 'from', dump_name)
예제 #2
0
    latent_shape = model.get_latent_shape()
    print('Feature shape:', feature_shape)

    # Layout of activations
    if len(feature_shape) != 4:  # non-spatial
        axis_mask = np.ones(len(feature_shape), dtype=np.int32)
    else:
        axis_mask = np.array(
            [0, 1, 1, 1])  # only batch fixed => whole activation volume used

    # Shape of sample passed to PCA
    sample_shape = feature_shape * axis_mask
    sample_shape[sample_shape == 0] = 1

    # Load or compute components
    dump_name = get_or_compute(args, inst)
    data = np.load(dump_name,
                   allow_pickle=False)  # does not contain object arrays
    X_comp = data['act_comp']
    X_global_mean = data['act_mean']
    X_stdev = data['act_stdev']
    X_var_ratio = data['var_ratio']
    X_stdev_random = data['random_stdevs']
    Z_global_mean = data['lat_mean']
    Z_comp = data['lat_comp']
    Z_stdev = data['lat_stdev']
    n_comp = X_comp.shape[0]
    data.close()

    # Transfer components to device
    tensors = SimpleNamespace(