Example #1

import numpy as np
import matplotlib.pyplot as plt
from spatial_semantic_pointers.utils import get_axes, make_good_unitary

# fname = '/home/bjkomer/ssp_navigation_sandbox/axis_vector_explorations/heatmap_output/256dim_25seeds.npz'
# data = np.load(fname)
# square_heatmaps = data['square_heatmaps']
# hex_heatmaps = data['hex_heatmaps']
#
# avg_square_heatmap = square_heatmaps.mean(axis=0)
# avg_hex_heatmap = hex_heatmaps.mean(axis=0)

dim = 512
res = 256
limit = 5
xs = np.linspace(-limit, limit, res)

Xh, Yh = get_axes(dim=dim, seed=13)

rng = np.random.RandomState(seed=13)
X = make_good_unitary(dim=dim, rng=rng)
Y = make_good_unitary(dim=dim, rng=rng)
Z = make_good_unitary(dim=dim, rng=rng)

fig, ax = plt.subplots(1, 3, figsize=(8, 4))

sigma_normal = 0.5
sigma_hex = 0.5
sigma_hex_c = 0.5

sim_hex = np.zeros((res, ))
sim_hex_c = np.zeros((res, ))  # this version has axes generated together and then converted to 2D
sim_normal = np.zeros((res, ))
import nengo
import numpy as np
from spatial_semantic_pointers.utils import encode_point, make_good_unitary, get_axes, generate_region_vector, get_heatmap_vectors
from spatial_semantic_pointers.plots import SpatialHeatmap
from ssp_navigation.utils.encodings import hilbert_2d

# generate_region_vector(desired, xs, ys, x_axis_sp, y_axis_sp, normalize=True)

dim = 128  #256

limit_low = -5
limit_high = 5

rng = np.random.RandomState(seed=13)
X, Y = get_axes(dim=dim, n=3, seed=13, period=0, optimal_phi=False)


def to_ssp(v):
    return encode_point(v[0], v[1], X, Y).v


# 3 directions 120 degrees apart
vec_dirs = [0, 2 * np.pi / 3, 4 * np.pi / 3]
spacing = 4


def to_hex_region_ssp(v, spacing=4):
    # assumed behaviour: superimpose SSPs for the point and for points offset
    # by 'spacing' along each of the three directions defined above
    ret = np.zeros((dim, ))
    ret[:] = encode_point(v[0], v[1], X, Y).v
    for d in vec_dirs:
        ret += encode_point(v[0] + spacing * np.cos(d), v[1] + spacing * np.sin(d), X, Y).v
    return ret
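
# Example usage (sketch, not from the original script): a hex-region vector
# centred at the origin
# region_vec = to_hex_region_ssp(np.array([0.0, 0.0]), spacing=spacing)
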
path_prefix = '/media/ctnuser/53f2c4b3-4b3b-4768-ba69-f0a3da30c237/ctnuser/data/neural_implementation_output'

# if not os.path.exists('output'):
#     os.makedirs('output')

diff_axis = True
grid_axes = False


rng = np.random.RandomState(seed=13)

Xs = []
Ys = []
for i in range(args.n_envs):
    X, Y = get_axes(dim=args.dim, n=3, seed=13+i, period=0, optimal_phi=False)
    Xs.append(X)
    Ys.append(Y)


# if grid_axes:
#     X, Y = get_fixed_dim_grid_axes(dim=args.dim, seed=13)
# else:
#     X, Y = get_axes(dim=args.dim, n=3, seed=13, period=0, optimal_phi=False)
# # X_new, Y_new = get_axes(dim=dim, n=3, seed=14, period=0, optimal_phi=False)



def to_ssp(v):
    return encode_point(v[0], v[1], X, Y).v
# only the alpha (4th) channel has the data, values from 0 to 255 in the raw image
import imageio
from skimage.transform import resize  # assumed source of resize()

desired_complex = resize(
    imageio.imread('assets/icons8-star-96.png') / 255, (res, res))[:, :, 3]

# desired_complex[128-48:128+48, 128-48:128+48] = star_im[:, :, 3]

# print(star_im.shape)
# print(np.min(star_im[:, :, 3]))
# print(np.max(star_im[:, :, 3]))

rng = np.random.RandomState(seed=seed)

# X = make_good_unitary(dim=dim, rng=rng)
# Y = make_good_unitary(dim=dim, rng=rng)

X, Y = get_axes(dim=dim, n=3, seed=seed)

hmv = get_heatmap_vectors(xs, ys, X, Y)
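# get_heatmap_vectors is assumed to return an array of shape
# (len(xs), len(ys), dim): the SSP encoding of every grid point, usable for
# similarity maps and nearest-neighbour decoding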

circular_sp = generate_region_vector(desired_circular,
                                     xs,
                                     ys,
                                     X,
                                     Y,
                                     normalize=True)
rectangular_sp = generate_region_vector(desired_rectangular,
                                        xs,
                                        ys,
                                        X,
                                        Y,
                                        normalize=True)
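
# A minimal sketch (not part of the original script) of how these region
# vectors could be inspected: dot each one against every heatmap vector in
# hmv (assumed shape (len(xs), len(ys), dim)) to get a 2D similarity map.
#
#   sim_circular = np.tensordot(hmv, circular_sp, axes=([2], [0]))
#   plt.imshow(sim_circular, origin='lower',
#              extent=(xs[0], xs[-1], ys[0], ys[-1]))
#   plt.colorbar()
#   plt.show()
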
def main():
    parser = argparse.ArgumentParser(
        'Train a network to clean up a noisy spatial semantic pointer')

    parser.add_argument('--loss',
                        type=str,
                        default='cosine',
                        choices=['cosine', 'mse'])
    parser.add_argument('--noise-type',
                        type=str,
                        default='memory',
                        choices=['memory', 'gaussian', 'both'])
    parser.add_argument(
        '--sigma',
        type=float,
        default=1.0,
        help='sigma on the gaussian noise if noise-type==gaussian')
    parser.add_argument('--train-fraction',
                        type=float,
                        default=.8,
                        help='proportion of the dataset to use for training')
    parser.add_argument(
        '--n-samples',
        type=int,
        default=10000,
        help=
        'Number of memories to generate. Total samples will be n-samples * n-items'
    )
    parser.add_argument('--n-items',
                        type=int,
                        default=12,
                        help='number of items in memory. Proxy for noisiness')
    parser.add_argument('--dim',
                        type=int,
                        default=512,
                        help='Dimensionality of the semantic pointers')
    parser.add_argument('--hidden-size',
                        type=int,
                        default=512,
                        help='Hidden size of the cleanup network')
    parser.add_argument('--limits',
                        type=str,
                        default="-5,5,-5,5",
                        help='The limits of the space')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=13)
    parser.add_argument('--logdir',
                        type=str,
                        default='ssp_cleanup',
                        help='Directory for saved model and tensorboard log')
    parser.add_argument('--load-model',
                        type=str,
                        default='',
                        help='Optional model to continue training from')
    parser.add_argument(
        '--name',
        type=str,
        default='',
        help=
        'Name of output folder within logdir. Will use current date and time if blank'
    )
    parser.add_argument('--weight-histogram',
                        action='store_true',
                        help='Save histograms of the weights if set')
    parser.add_argument('--use-hex-ssp', action='store_true')
    parser.add_argument('--optimizer',
                        type=str,
                        default='adam',
                        choices=['sgd', 'adam', 'rmsprop'])

    args = parser.parse_args()

    args.limits = tuple(float(v) for v in args.limits.split(','))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    dataset_name = 'data/ssp_cleanup_dataset_dim{}_seed{}_items{}_limit{}_samples{}.npz'.format(
        args.dim, args.seed, args.n_items, args.limits[1], args.n_samples)

    final_test_samples = 100
    final_test_items = 15
    final_test_dataset_name = 'data/ssp_cleanup_test_dataset_dim{}_seed{}_items{}_limit{}_samples{}.npz'.format(
        args.dim, args.seed, final_test_items, args.limits[1],
        final_test_samples)

    if not os.path.exists('data'):
        os.makedirs('data')

    rng = np.random.RandomState(seed=args.seed)
    if args.use_hex_ssp:
        x_axis_sp, y_axis_sp = get_axes(dim=args.dim, n=3, seed=args.seed)
    else:
        x_axis_sp = make_good_unitary(args.dim, rng=rng)
        y_axis_sp = make_good_unitary(args.dim, rng=rng)

    if args.noise_type == 'gaussian':
        # Simple generation
        clean_ssps = np.zeros((args.n_samples, args.dim))
        coords = np.zeros((args.n_samples, 2))
        for i in range(args.n_samples):
            x = np.random.uniform(low=args.limits[0], high=args.limits[1])
            y = np.random.uniform(low=args.limits[2], high=args.limits[3])

            clean_ssps[i, :] = encode_point(x,
                                            y,
                                            x_axis_sp=x_axis_sp,
                                            y_axis_sp=y_axis_sp).v
            coords[i, 0] = x
            coords[i, 1] = y
        # Gaussian noise will be added later
        noisy_ssps = clean_ssps.copy()
    else:

        if os.path.exists(dataset_name):
            print("Loading dataset")
            data = np.load(dataset_name)
            clean_ssps = data['clean_ssps']
            noisy_ssps = data['noisy_ssps']
        else:
            print("Generating SSP cleanup dataset")
            clean_ssps, noisy_ssps, coords = generate_cleanup_dataset(
                x_axis_sp=x_axis_sp,
                y_axis_sp=y_axis_sp,
                n_samples=args.n_samples,
                dim=args.dim,
                n_items=args.n_items,
                limits=args.limits,
                seed=args.seed,
            )
            print("Dataset generation complete. Saving dataset")
            np.savez(
                dataset_name,
                clean_ssps=clean_ssps,
                noisy_ssps=noisy_ssps,
                coords=coords,
                x_axis_vec=x_axis_sp.v,
                y_axis_vec=y_axis_sp.v,
            )
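
            # Note: generate_cleanup_dataset is defined elsewhere in this repo.
            # A rough sketch of what it is assumed to return (not the actual
            # implementation): for each of n_samples memories, n_items random
            # locations are encoded as SSPs and superimposed into one memory
            # vector; each item's clean SSP is its own encoding and its noisy
            # SSP is the whole memory, so the other items act as the noise.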

    # check if the final test set has been generated yet
    if os.path.exists(final_test_dataset_name):
        print("Loading final test dataset")
        final_test_data = np.load(final_test_dataset_name)
        final_test_clean_ssps = final_test_data['clean_ssps']
        final_test_noisy_ssps = final_test_data['noisy_ssps']
    else:
        print("Generating final test dataset")
        final_test_clean_ssps, final_test_noisy_ssps, final_test_coords = generate_cleanup_dataset(
            x_axis_sp=x_axis_sp,
            y_axis_sp=y_axis_sp,
            n_samples=final_test_samples,
            dim=args.dim,
            n_items=final_test_items,
            limits=args.limits,
            seed=args.seed,
        )
        print("Final test generation complete. Saving dataset")
        np.savez(
            final_test_dataset_name,
            clean_ssps=final_test_clean_ssps,
            noisy_ssps=final_test_noisy_ssps,
            coords=final_test_coords,
            x_axis_vec=x_axis_sp.v,
            y_axis_vec=y_axis_sp.v,
        )

    # Add gaussian noise if required
    if args.noise_type == 'gaussian' or args.noise_type == 'both':
        noisy_ssps += np.random.normal(loc=0,
                                       scale=args.sigma,
                                       size=noisy_ssps.shape)

    n_samples = clean_ssps.shape[0]
    n_train = int(args.train_fraction * n_samples)
    n_test = n_samples - n_train
    assert (n_train > 0 and n_test > 0)
    train_clean = clean_ssps[:n_train, :]
    train_noisy = noisy_ssps[:n_train, :]
    test_clean = clean_ssps[n_train:, :]
    test_noisy = noisy_ssps[n_train:, :]

    # NOTE: this dataset is actually generic and can take any input/output mapping
    dataset_train = CoordDecodeDataset(vectors=train_noisy, coords=train_clean)
    dataset_test = CoordDecodeDataset(vectors=test_noisy, coords=test_clean)
    dataset_final_test = CoordDecodeDataset(vectors=final_test_noisy_ssps,
                                            coords=final_test_clean_ssps)
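
    # CoordDecodeDataset is defined elsewhere in this repo; a minimal sketch of
    # the interface assumed by the code below (it exposes .dim and yields
    # (input_vector, target_vector) pairs):
    #
    #   class CoordDecodeDataset(torch.utils.data.Dataset):
    #       def __init__(self, vectors, coords):
    #           self.vectors = vectors.astype(np.float32)
    #           self.coords = coords.astype(np.float32)
    #           self.dim = self.vectors.shape[1]
    #
    #       def __len__(self):
    #           return len(self.vectors)
    #
    #       def __getitem__(self, idx):
    #           return self.vectors[idx], self.coords[idx]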

    trainloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
    )

    # For testing just do everything in one giant batch
    testloader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=len(dataset_test),
        shuffle=False,
        num_workers=0,
    )

    final_testloader = torch.utils.data.DataLoader(
        dataset_final_test,
        batch_size=len(dataset_final_test),
        shuffle=False,
        num_workers=0,
    )

    model = FeedForward(dim=dataset_train.dim,
                        hidden_size=args.hidden_size,
                        output_size=dataset_train.dim)
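
    # FeedForward is defined elsewhere in this repo; a minimal sketch of the
    # architecture assumed here (a single-hidden-layer MLP mapping noisy SSPs
    # back to clean SSPs):
    #
    #   class FeedForward(nn.Module):
    #       def __init__(self, dim, hidden_size, output_size):
    #           super().__init__()
    #           self.net = nn.Sequential(
    #               nn.Linear(dim, hidden_size),
    #               nn.ReLU(),
    #               nn.Linear(hidden_size, output_size),
    #           )
    #
    #       def forward(self, x):
    #           return self.net(x)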

    # Open a tensorboard writer if a logging directory is given
    if args.logdir != '':
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        save_dir = osp.join(args.logdir, current_time)
        writer = SummaryWriter(log_dir=save_dir)
        if args.weight_histogram:
            # Log the initial parameters
            for name, param in model.named_parameters():
                writer.add_histogram('parameters/' + name,
                                     param.clone().cpu().data.numpy(), 0)

    mse_criterion = nn.MSELoss()
    cosine_criterion = nn.CosineEmbeddingLoss()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError

    for e in range(args.epochs):
        print('Epoch: {0}'.format(e + 1))

        avg_mse_loss = 0
        avg_cosine_loss = 0
        n_batches = 0
        for i, data in enumerate(trainloader):

            noisy, clean = data

            if noisy.size()[0] != args.batch_size:
                continue  # Drop data, not enough for a batch
            optimizer.zero_grad()

            outputs = model(noisy)

            mse_loss = mse_criterion(outputs, clean)
            # Modified to use CosineEmbeddingLoss
            cosine_loss = cosine_criterion(outputs, clean,
                                           torch.ones(args.batch_size))

            avg_cosine_loss += cosine_loss.data.item()
            avg_mse_loss += mse_loss.data.item()
            n_batches += 1

            if args.loss == 'cosine':
                cosine_loss.backward()
            else:
                mse_loss.backward()

            # print(loss.data.item())

            optimizer.step()

        print(avg_cosine_loss / n_batches)

        if args.logdir != '':
            if n_batches > 0:
                avg_cosine_loss /= n_batches
                avg_mse_loss /= n_batches
                writer.add_scalar('avg_cosine_loss', avg_cosine_loss, e + 1)
                writer.add_scalar('avg_mse_loss', avg_mse_loss, e + 1)

            if args.weight_histogram and (e + 1) % 10 == 0:
                for name, param in model.named_parameters():
                    writer.add_histogram('parameters/' + name,
                                         param.clone().cpu().data.numpy(),
                                         e + 1)

    print("Testing")
    with torch.no_grad():

        for label, loader in zip(['test', 'final_test'],
                                 [testloader, final_testloader]):

            # Everything is in one batch, so this loop will only happen once
            for i, data in enumerate(loader):

                noisy, clean = data

                outputs = model(noisy)

                mse_loss = mse_criterion(outputs, clean)
                # Modified to use CosineEmbeddingLoss
                cosine_loss = cosine_criterion(outputs, clean,
                                               torch.ones(noisy.shape[0]))

                print(cosine_loss.data.item())

            if args.logdir != '':
                # TODO: get a visualization of the performance

                # show plots of the noisy, clean, and cleaned up with the network
                # note that the plotting mechanism itself uses nearest neighbors, so has a form of cleanup built in

                xs = np.linspace(args.limits[0], args.limits[1], 256)
                ys = np.linspace(args.limits[2], args.limits[3], 256)

                heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_sp,
                                                      y_axis_sp)

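                # ssp_to_loc_v decodes each SSP to the grid coordinate whose
                # heatmap vector is most similar (a nearest-neighbour cleanup)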
                noisy_coord = ssp_to_loc_v(noisy, heatmap_vectors, xs, ys)

                pred_coord = ssp_to_loc_v(outputs, heatmap_vectors, xs, ys)

                clean_coord = ssp_to_loc_v(clean, heatmap_vectors, xs, ys)

                fig_noisy_coord, ax_noisy_coord = plt.subplots()
                fig_pred_coord, ax_pred_coord = plt.subplots()
                fig_clean_coord, ax_clean_coord = plt.subplots()

                plot_predictions_v(noisy_coord,
                                   clean_coord,
                                   ax_noisy_coord,
                                   min_val=args.limits[0],
                                   max_val=args.limits[1],
                                   fixed_axes=True)

                plot_predictions_v(pred_coord,
                                   clean_coord,
                                   ax_pred_coord,
                                   min_val=args.limits[0],
                                   max_val=args.limits[1],
                                   fixed_axes=True)

                plot_predictions_v(clean_coord,
                                   clean_coord,
                                   ax_clean_coord,
                                   min_val=args.limits[0],
                                   max_val=args.limits[1],
                                   fixed_axes=True)

                writer.add_figure('{}/original_noise'.format(label),
                                  fig_noisy_coord)
                writer.add_figure('{}/test_set_cleanup'.format(label),
                                  fig_pred_coord)
                writer.add_figure('{}/ground_truth'.format(label),
                                  fig_clean_coord)
                # fig_hist = plot_histogram(predictions=outputs, coords=coord)
                # writer.add_figure('test set histogram', fig_hist)
                writer.add_scalar('{}/test_cosine_loss'.format(label),
                                  cosine_loss.data.item())
                writer.add_scalar('{}/test_mse_loss'.format(label),
                                  mse_loss.data.item())

    # Close tensorboard writer
    if args.logdir != '':
        writer.close()

        torch.save(model.state_dict(), osp.join(save_dir, 'model.pt'))

        params = vars(args)
        # # Additionally save the axis vectors used
        # params['x_axis_vec'] = list(x_axis_sp.v)
        # params['y_axis_vec'] = list(y_axis_sp.v)
        with open(osp.join(save_dir, "params.json"), "w") as f:
            json.dump(params, f)