Example #1
                batch_size=batch_size,
                encoding=args.spatial_encoding,
                encoding_func=encoding_func,
                encoding_dim=args.dim,
                train_split=args.train_split,
                hd_dim=args.n_hd_cells,
                hd_encoding_func=hd_encoding_func,
                sin_cos_ang=args.sin_cos_ang,
            )
        else:
            trainloader, testloader = train_test_loaders(
                data,
                n_train_samples=n_samples,
                n_test_samples=n_samples,
                rollout_length=rollout_length,
                batch_size=batch_size,
                encoding=args.spatial_encoding,
                encoding_func=encoding_func,
                encoding_dim=args.dim,
                train_split=args.train_split,
            )

    if args.allow_cache:

        if not os.path.exists('dataset_cache'):
            os.makedirs('dataset_cache')

        np.savez(
            cache_fname,
            train_velocity_inputs=trainloader.dataset.velocity_inputs,
            train_ssp_inputs=trainloader.dataset.ssp_inputs,
Example #2
# n_samples = 5000
n_samples = args.n_samples  # 1000
rollout_length = args.trajectory_length  # 100
batch_size = args.minibatch_size  # 10

model = SSPPathIntegrationModel(unroll_length=rollout_length,
                                sp_dim=encoding_dim)

if args.model:
    model.load_state_dict(torch.load(args.model), strict=False)

trainloader, testloader = train_test_loaders(
    data,
    n_train_samples=n_samples,
    n_test_samples=n_samples,
    rollout_length=rollout_length,
    batch_size=batch_size,
    encoding=args.encoding,
)

print("Testing")
with torch.no_grad():
    # Everything is in one batch, so this loop will only happen once
    for i, data in enumerate(testloader):
        velocity_inputs, ssp_inputs, ssp_outputs = data

        ssp_pred, lstm_outputs = model.forward_activations(
            velocity_inputs, ssp_inputs)

    print("ssp_pred.shape", ssp_pred.shape)
    print("ssp_outputs.shape", ssp_outputs.shape)
Example #3
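This excerpt assumes the module-level imports of the surrounding project: numpy as np, torch, and the helpers SSPPathIntegrationModel, train_test_loaders, get_heatmap_vectors, ssp_to_loc_v, and pc_to_loc_v.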
def run_and_gather_activations(
        seed=13,
        n_samples=1000,
        dataset='../../lab/reproducing/data/path_integration_trajectories_logits_200t_15s_seed13.npz',
        model_path='../output/ssp_path_integration/clipped/Mar22_15-24-10/ssp_path_integration_model.pt',
        encoding='ssp',
        rollout_length=100,
        batch_size=10,
        n_place_cells=256,
        encoding_func=None,  # added for frozen-learned encoding option
):

    torch.manual_seed(seed)
    np.random.seed(seed)

    data = np.load(dataset)

    x_axis_vec = data['x_axis_vec']
    y_axis_vec = data['y_axis_vec']

    pc_centers = data['pc_centers']
    #pc_activations = data['pc_activations']

    if encoding == 'ssp':
        encoding_dim = 512
        ssp_scaling = data['ssp_scaling']
    elif encoding == '2d':
        encoding_dim = 2
        ssp_scaling = 1
    elif encoding == 'pc':
        encoding_dim = n_place_cells  # used below as the model's sp_dim
        ssp_scaling = 1
    elif encoding == 'frozen-learned':
        encoding_dim = 512
        ssp_scaling = 1
    elif encoding in ('pc-gauss', 'pc-gauss-softmax'):
        encoding_dim = 512
        ssp_scaling = 1
    else:
        raise NotImplementedError

    limit_low = 0 * ssp_scaling
    limit_high = 2.2 * ssp_scaling
    res = 128  # 256

    xs = np.linspace(limit_low, limit_high, res)
    ys = np.linspace(limit_low, limit_high, res)

    if encoding in ('frozen-learned', 'pc-gauss', 'pc-gauss-softmax'):
        # encoding for every point in a 2D linspace, for approximating a readout

        # FIXME: inefficient but will work for now
        heatmap_vectors = np.zeros((len(xs), len(ys), encoding_dim))

        for i, x in enumerate(xs):
            for j, y in enumerate(ys):
                # encoding_func takes a single (x, y) point without a batch dimension
                heatmap_vectors[i, j, :] = encoding_func(np.array([x, y]))

                # normalize each encoded vector to unit length
                heatmap_vectors[i, j, :] /= np.linalg.norm(heatmap_vectors[i, j, :])

    else:
        # Used for visualization of test set performance using pos = ssp_to_loc(sp, heatmap_vectors, xs, ys)
        heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_vec, y_axis_vec)


    model = SSPPathIntegrationModel(unroll_length=rollout_length, sp_dim=encoding_dim)

    model.load_state_dict(torch.load(model_path), strict=False)

    trainloader, testloader = train_test_loaders(
        data,
        n_train_samples=n_samples,
        n_test_samples=n_samples,
        rollout_length=rollout_length,
        batch_size=batch_size,
        encoding=encoding,
        encoding_func=encoding_func,
    )

    print("Testing")
    with torch.no_grad():
        # Everything is in one batch, so this loop will only happen once
        for i, data in enumerate(testloader):
            velocity_inputs, ssp_inputs, ssp_outputs = data

            ssp_pred, lstm_outputs = model.forward_activations(velocity_inputs, ssp_inputs)


        predictions = np.zeros((ssp_pred.shape[0]*ssp_pred.shape[1], 2))
        coords = np.zeros((ssp_pred.shape[0]*ssp_pred.shape[1], 2))
        activations = np.zeros((ssp_pred.shape[0]*ssp_pred.shape[1], model.lstm_hidden_size))

        assert rollout_length == ssp_pred.shape[0]

        # # For each neuron, contains the average activity at each spatial bin
        # # Computing for both ground truth and predicted location
        # rate_maps_pred = np.zeros((model.lstm_hidden_size, len(xs), len(ys)))
        # rate_maps_truth = np.zeros((model.lstm_hidden_size, len(xs), len(ys)))

        print("Computing predicted locations and true locations")
        # Using all data, one chunk at a time
        for ri in range(rollout_length):

            if encoding == 'ssp':
                # computing 'predicted' coordinates, where the agent thinks it is
                predictions[ri * ssp_pred.shape[1]:(ri + 1) * ssp_pred.shape[1], :] = ssp_to_loc_v(
                    ssp_pred.detach().numpy()[ri, :, :],
                    heatmap_vectors, xs, ys
                )

                # computing 'ground truth' coordinates, where the agent should be
                coords[ri * ssp_pred.shape[1]:(ri + 1) * ssp_pred.shape[1], :] = ssp_to_loc_v(
                    ssp_outputs.detach().numpy()[:, ri, :],
                    heatmap_vectors, xs, ys
                )
            elif encoding == '2d':
                # copying 'predicted' coordinates, where the agent thinks it is
                predictions[ri * ssp_pred.shape[1]:(ri + 1) * ssp_pred.shape[1], :] = ssp_pred.detach().numpy()[ri, :, :]

                # copying 'ground truth' coordinates, where the agent should be
                coords[ri * ssp_pred.shape[1]:(ri + 1) * ssp_pred.shape[1], :] = ssp_outputs.detach().numpy()[:, ri, :]
            elif encoding == 'pc':
                # quick hack: decode a location as the most activated place cell center
                # computing 'predicted' coordinates, where the agent thinks it is
                predictions[ri * ssp_pred.shape[1]:(ri + 1) * ssp_pred.shape[1], :] = pc_to_loc_v(
                    pc_activations=ssp_pred.detach().numpy()[ri, :, :],
                    centers=pc_centers,
                    jitter=0.01,
                )

                # computing 'ground truth' coordinates, where the agent should be
                coords[ri * ssp_pred.shape[1]:(ri + 1) * ssp_pred.shape[1], :] = pc_to_loc_v(
                    pc_activations=ssp_outputs.detach().numpy()[:, ri, :],
                    centers=pc_centers,
                    jitter=0.01,
                )
            elif encoding in ('frozen-learned', 'pc-gauss', 'pc-gauss-softmax'):
                # computing 'predicted' coordinates, where the agent thinks it is
                pred = ssp_pred.detach().numpy()[ri, :, :]
                pred = pred / pred.sum(axis=1)[:, np.newaxis]
                predictions[ri * ssp_pred.shape[1]:(ri + 1) * ssp_pred.shape[1], :] = ssp_to_loc_v(
                    pred,
                    heatmap_vectors, xs, ys
                )

                # computing 'ground truth' coordinates, where the agent should be
                coord = ssp_outputs.detach().numpy()[:, ri, :]
                coord = coord / coord.sum(axis=1)[:, np.newaxis]
                coords[ri * ssp_pred.shape[1]:(ri + 1) * ssp_pred.shape[1], :] = ssp_to_loc_v(
                    coord,
                    heatmap_vectors, xs, ys
                )

            # reshaping activations and converting to numpy array
            activations[ri * ssp_pred.shape[1]:(ri + 1) * ssp_pred.shape[1], :] = lstm_outputs.detach().numpy()[ri, :, :]

    return activations, predictions, coords
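
A minimal usage sketch (not from the original source): it reuses the default dataset and model paths baked into the signature, and passes no encoding_func since the plain 'ssp' encoding does not need one. The printed shapes follow from the arrays allocated inside the function.

if __name__ == '__main__':
    # Relies on the defaults in the signature for dataset and model_path.
    activations, predictions, coords = run_and_gather_activations(
        seed=13,
        n_samples=1000,
        encoding='ssp',
        rollout_length=100,
        batch_size=10,
    )

    # Each row corresponds to one (rollout step, trajectory) pair in the test batch.
    print("activations:", activations.shape)  # (rollout_length * batch, lstm_hidden_size)
    print("predictions:", predictions.shape)  # (rollout_length * batch, 2)
    print("coords:", coords.shape)            # (rollout_length * batch, 2)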