# fname = '/home/bjkomer/ssp_navigation_sandbox/axis_vector_explorations/heatmap_output/256dim_25seeds.npz'
# data = np.load(fname)
# square_heatmaps = data['square_heatmaps']
# hex_heatmaps = data['hex_heatmaps']
#
# avg_square_heatmap = square_heatmaps.mean(axis=0)
# avg_hex_heatmap = hex_heatmaps.mean(axis=0)

dim = 512
res = 256
limit = 5
xs = np.linspace(-limit, limit, res)

Xh, Yh = get_axes(dim=dim, seed=13)

rng = np.random.RandomState(seed=13)
X = make_good_unitary(dim=dim, rng=rng)
Y = make_good_unitary(dim=dim, rng=rng)
Z = make_good_unitary(dim=dim, rng=rng)

fig, ax = plt.subplots(1, 3, figsize=(8, 4))

sigma_normal = 0.5
sigma_hex = 0.5
sigma_hex_c = 0.5

sim_hex = np.zeros((res, ))
sim_hex_c = np.zeros((res, ))

# this version has axes generated together and then converted to 2D
sim_normal = np.zeros((res, ))
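
# A minimal sketch of one way these similarity profiles could be filled in
# (an assumed continuation, not recovered from the original script; assumes
# encode_point is imported alongside get_axes and make_good_unitary): sweep a
# point along the x-axis and record its dot product with the origin encoding.
origin_hex = encode_point(0, 0, Xh, Yh).v
origin_normal = encode_point(0, 0, X, Y).v
for i, x in enumerate(xs):
    sim_hex[i] = np.dot(encode_point(x, 0, Xh, Yh).v, origin_hex)
    sim_normal[i] = np.dot(encode_point(x, 0, X, Y).v, origin_normal)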
import nengo
import numpy as np

from spatial_semantic_pointers.utils import (
    encode_point,
    make_good_unitary,
    get_axes,
    generate_region_vector,
    get_heatmap_vectors,
)
from spatial_semantic_pointers.plots import SpatialHeatmap
from ssp_navigation.utils.encodings import hilbert_2d

# generate_region_vector(desired, xs, ys, x_axis_sp, y_axis_sp, normalize=True)

dim = 128  # 256
limit_low = -5
limit_high = 5

rng = np.random.RandomState(seed=13)
X, Y = get_axes(dim=dim, n=3, seed=13, period=0, optimal_phi=False)


def to_ssp(v):
    return encode_point(v[0], v[1], X, Y).v


# 3 directions, 120 degrees apart
vec_dirs = [0, 2 * np.pi / 3, 4 * np.pi / 3]
spacing = 4


def to_hex_region_ssp(v, spacing=4):
    ret = np.zeros((dim, ))
    ret[:] = encode_point(v[0], v[1], X, Y).v
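    # The body is truncated here in the original file. A plausible completion
    # (an assumption, not recovered source): superimpose points offset from v
    # along the three directions, spacing apart, so the returned vector covers
    # a small hexagonal region as a single superposition.
    for direction in vec_dirs:
        ret += encode_point(v[0] + spacing * np.cos(direction),
                            v[1] + spacing * np.sin(direction), X, Y).v
    return ret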
path_prefix = '/media/ctnuser/53f2c4b3-4b3b-4768-ba69-f0a3da30c237/ctnuser/data/neural_implementation_output'

# if not os.path.exists('output'):
#     os.makedirs('output')

diff_axis = True
grid_axes = False

rng = np.random.RandomState(seed=13)

# one pair of axis vectors per environment, each from a different seed
Xs = []
Ys = []
for i in range(args.n_envs):
    X, Y = get_axes(dim=args.dim, n=3, seed=13 + i, period=0, optimal_phi=False)
    Xs.append(X)
    Ys.append(Y)

# if grid_axes:
#     X, Y = get_fixed_dim_grid_axes(dim=args.dim, seed=13)
# else:
#     X, Y = get_axes(dim=args.dim, n=3, seed=13, period=0, optimal_phi=False)
#
# X_new, Y_new = get_axes(dim=dim, n=3, seed=14, period=0, optimal_phi=False)


# NOTE: X and Y here are whatever the loop above left bound,
# i.e. the axes of the last environment
def to_ssp(v):
    return encode_point(v[0], v[1], X, Y).v
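
# A hedged helper (not in the original file): since to_ssp closes over the
# last X, Y from the loop, encoding a point under a specific environment's
# axes needs an explicit index into Xs and Ys.
def to_ssp_env(v, env_index):
    return encode_point(v[0], v[1], Xs[env_index], Ys[env_index]).v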
# only the 4th channel (alpha) has the data; the raw values run from
# 0 to 255, so divide by 255 to normalize
desired_complex = resize(
    imageio.imread('assets/icons8-star-96.png') / 255,
    (res, res),
)[:, :, 3]
# desired_complex[128-48:128+48, 128-48:128+48] = star_im[:, :, 3]
# print(star_im.shape)
# print(np.min(star_im[:, :, 3]))
# print(np.max(star_im[:, :, 3]))

rng = np.random.RandomState(seed=seed)
# X = make_good_unitary(dim=dim, rng=rng)
# Y = make_good_unitary(dim=dim, rng=rng)
X, Y = get_axes(dim=dim, n=3, seed=seed)

hmv = get_heatmap_vectors(xs, ys, X, Y)

circular_sp = generate_region_vector(desired_circular, xs, ys, X, Y, normalize=True)
rectangular_sp = generate_region_vector(desired_rectangular, xs, ys, X, Y, normalize=True)
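
# A quick sanity check (not in the original script; assumes matplotlib.pyplot
# is imported as plt and that generate_region_vector returns a semantic
# pointer exposing .v, as the axis vectors do): dot the region vector with
# every heatmap vector to see which locations it covers.
circular_sim = np.tensordot(hmv, circular_sp.v, axes=([2], [0]))
plt.imshow(circular_sim.T, origin='lower', extent=(xs[0], xs[-1], ys[0], ys[-1]))
plt.title('similarity of circular_sp across space')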
def main():
    parser = argparse.ArgumentParser(
        'Train a network to clean up a noisy spatial semantic pointer')

    parser.add_argument('--loss', type=str, default='cosine',
                        choices=['cosine', 'mse'])
    parser.add_argument('--noise-type', type=str, default='memory',
                        choices=['memory', 'gaussian', 'both'])
    parser.add_argument('--sigma', type=float, default=1.0,
                        help='sigma on the gaussian noise if noise-type==gaussian')
    parser.add_argument('--train-fraction', type=float, default=.8,
                        help='proportion of the dataset to use for training')
    parser.add_argument('--n-samples', type=int, default=10000,
                        help='Number of memories to generate. '
                             'Total samples will be n-samples * n-items')
    parser.add_argument('--n-items', type=int, default=12,
                        help='number of items in memory. Proxy for noisiness')
    parser.add_argument('--dim', type=int, default=512,
                        help='Dimensionality of the semantic pointers')
    parser.add_argument('--hidden-size', type=int, default=512,
                        help='Hidden size of the cleanup network')
    parser.add_argument('--limits', type=str, default="-5,5,-5,5",
                        help='The limits of the space')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=13)
    parser.add_argument('--logdir', type=str, default='ssp_cleanup',
                        help='Directory for saved model and tensorboard log')
    parser.add_argument('--load-model', type=str, default='',
                        help='Optional model to continue training from')
    parser.add_argument('--name', type=str, default='',
                        help='Name of output folder within logdir. '
                             'Will use current date and time if blank')
    parser.add_argument('--weight-histogram', action='store_true',
                        help='Save histograms of the weights if set')
    parser.add_argument('--use-hex-ssp', action='store_true')
    parser.add_argument('--optimizer', type=str, default='adam',
                        choices=['sgd', 'adam', 'rmsprop'])

    args = parser.parse_args()

    args.limits = tuple(float(v) for v in args.limits.split(','))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    dataset_name = 'data/ssp_cleanup_dataset_dim{}_seed{}_items{}_limit{}_samples{}.npz'.format(
        args.dim, args.seed, args.n_items, args.limits[1], args.n_samples)

    final_test_samples = 100
    final_test_items = 15
    final_test_dataset_name = 'data/ssp_cleanup_test_dataset_dim{}_seed{}_items{}_limit{}_samples{}.npz'.format(
        args.dim, args.seed, final_test_items, args.limits[1], final_test_samples)

    if not os.path.exists('data'):
        os.makedirs('data')

    rng = np.random.RandomState(seed=args.seed)

    if args.use_hex_ssp:
        x_axis_sp, y_axis_sp = get_axes(dim=args.dim, n=3, seed=args.seed)
    else:
        x_axis_sp = make_good_unitary(args.dim, rng=rng)
        y_axis_sp = make_good_unitary(args.dim, rng=rng)

    if args.noise_type == 'gaussian':
        # Simple generation: encode uniformly sampled points, no memory noise
        clean_ssps = np.zeros((args.n_samples, args.dim))
        coords = np.zeros((args.n_samples, 2))
        for i in range(args.n_samples):
            x = np.random.uniform(low=args.limits[0], high=args.limits[1])
            y = np.random.uniform(low=args.limits[2], high=args.limits[3])
            clean_ssps[i, :] = encode_point(x, y, x_axis_sp=x_axis_sp, y_axis_sp=y_axis_sp).v
            coords[i, 0] = x
            coords[i, 1] = y
        # Gaussian noise will be added later
        noisy_ssps = clean_ssps.copy()
    else:
        if os.path.exists(dataset_name):
            print("Loading dataset")
            data = np.load(dataset_name)
            clean_ssps = data['clean_ssps']
            noisy_ssps = data['noisy_ssps']
        else:
            print("Generating SSP cleanup dataset")
            clean_ssps, noisy_ssps, coords = generate_cleanup_dataset(
                x_axis_sp=x_axis_sp,
                y_axis_sp=y_axis_sp,
                n_samples=args.n_samples,
                dim=args.dim,
                n_items=args.n_items,
                limits=args.limits,
                seed=args.seed,
            )
            print("Dataset generation complete. Saving dataset")
            np.savez(
                dataset_name,
                clean_ssps=clean_ssps,
                noisy_ssps=noisy_ssps,
                coords=coords,
                x_axis_vec=x_axis_sp.v,
                y_axis_vec=y_axis_sp.v,
            )

    # check if the final test set has been generated yet
    if os.path.exists(final_test_dataset_name):
        print("Loading final test dataset")
        final_test_data = np.load(final_test_dataset_name)
        final_test_clean_ssps = final_test_data['clean_ssps']
        final_test_noisy_ssps = final_test_data['noisy_ssps']
    else:
        print("Generating final test dataset")
        final_test_clean_ssps, final_test_noisy_ssps, final_test_coords = generate_cleanup_dataset(
            x_axis_sp=x_axis_sp,
            y_axis_sp=y_axis_sp,
            n_samples=final_test_samples,
            dim=args.dim,
            n_items=final_test_items,
            limits=args.limits,
            seed=args.seed,
        )
        print("Final test generation complete. Saving dataset")
        np.savez(
            final_test_dataset_name,
            clean_ssps=final_test_clean_ssps,
            noisy_ssps=final_test_noisy_ssps,
            coords=final_test_coords,
            x_axis_vec=x_axis_sp.v,
            y_axis_vec=y_axis_sp.v,
        )

    # Add gaussian noise if required
    if args.noise_type == 'gaussian' or args.noise_type == 'both':
        noisy_ssps += np.random.normal(loc=0, scale=args.sigma, size=noisy_ssps.shape)

    n_samples = clean_ssps.shape[0]
    n_train = int(args.train_fraction * n_samples)
    n_test = n_samples - n_train
    assert n_train > 0 and n_test > 0
    train_clean = clean_ssps[:n_train, :]
    train_noisy = noisy_ssps[:n_train, :]
    test_clean = clean_ssps[n_train:, :]
    test_noisy = noisy_ssps[n_train:, :]

    # NOTE: this dataset is actually generic and can take any input/output mapping
    dataset_train = CoordDecodeDataset(vectors=train_noisy, coords=train_clean)
    dataset_test = CoordDecodeDataset(vectors=test_noisy, coords=test_clean)
    dataset_final_test = CoordDecodeDataset(vectors=final_test_noisy_ssps,
                                            coords=final_test_clean_ssps)

    trainloader = torch.utils.data.DataLoader(
        dataset_train, batch_size=args.batch_size, shuffle=True, num_workers=0,
    )

    # For testing just do everything in one giant batch
    testloader = torch.utils.data.DataLoader(
        dataset_test, batch_size=len(dataset_test), shuffle=False, num_workers=0,
    )
    final_testloader = torch.utils.data.DataLoader(
        dataset_final_test, batch_size=len(dataset_final_test), shuffle=False, num_workers=0,
    )

    model = FeedForward(dim=dataset_train.dim, hidden_size=args.hidden_size,
                        output_size=dataset_train.dim)

    # Open a tensorboard writer if a logging directory is given
    if args.logdir != '':
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        save_dir = osp.join(args.logdir, current_time)
        writer = SummaryWriter(log_dir=save_dir)
        if args.weight_histogram:
            # Log the initial parameters
            for name, param in model.named_parameters():
                writer.add_histogram('parameters/' + name,
                                     param.clone().cpu().data.numpy(), 0)

    mse_criterion = nn.MSELoss()
    cosine_criterion = nn.CosineEmbeddingLoss()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError

    for e in range(args.epochs):
        print('Epoch: {0}'.format(e + 1))

        avg_mse_loss = 0
        avg_cosine_loss = 0
        n_batches = 0
        for i, data in enumerate(trainloader):
            noisy, clean = data

            if noisy.size()[0] != args.batch_size:
                continue  # drop the final partial batch; not enough samples for a full batch
            optimizer.zero_grad()

            outputs = model(noisy)

            mse_loss = mse_criterion(outputs, clean)
            # Modified to use CosineEmbeddingLoss
            cosine_loss = cosine_criterion(outputs, clean, torch.ones(args.batch_size))

            avg_cosine_loss += cosine_loss.data.item()
            avg_mse_loss += mse_loss.data.item()
            n_batches += 1

            if args.loss == 'cosine':
                cosine_loss.backward()
            else:
                mse_loss.backward()

            # print(loss.data.item())

            optimizer.step()

        print(avg_cosine_loss / n_batches)

        if args.logdir != '':
            if n_batches > 0:
                avg_cosine_loss /= n_batches
                avg_mse_loss /= n_batches
                writer.add_scalar('avg_cosine_loss', avg_cosine_loss, e + 1)
                writer.add_scalar('avg_mse_loss', avg_mse_loss, e + 1)

            if args.weight_histogram and (e + 1) % 10 == 0:
                for name, param in model.named_parameters():
                    writer.add_histogram('parameters/' + name,
                                         param.clone().cpu().data.numpy(), e + 1)

    print("Testing")
    with torch.no_grad():
        for label, loader in zip(['test', 'final_test'], [testloader, final_testloader]):
            # Everything is in one batch, so this loop will only happen once
            for i, data in enumerate(loader):
                noisy, clean = data

                outputs = model(noisy)

                mse_loss = mse_criterion(outputs, clean)
                # CosineEmbeddingLoss needs one target entry per sample;
                # len(loader) is the number of batches, not the batch size
                cosine_loss = cosine_criterion(outputs, clean, torch.ones(noisy.shape[0]))

                print(cosine_loss.data.item())

            if args.logdir != '':
                # TODO: get a visualization of the performance:
                # show plots of the noisy, clean, and network-cleaned coordinates.
                # Note that the plotting mechanism itself uses nearest neighbours,
                # so it has a form of cleanup built in.
                xs = np.linspace(args.limits[0], args.limits[1], 256)
                ys = np.linspace(args.limits[0], args.limits[1], 256)

                heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp)

                noisy_coord = ssp_to_loc_v(noisy, heatmap_vectors, xs, ys)
                pred_coord = ssp_to_loc_v(outputs, heatmap_vectors, xs, ys)
                clean_coord = ssp_to_loc_v(clean, heatmap_vectors, xs, ys)

                fig_noisy_coord, ax_noisy_coord = plt.subplots()
                fig_pred_coord, ax_pred_coord = plt.subplots()
                fig_clean_coord, ax_clean_coord = plt.subplots()

                plot_predictions_v(noisy_coord, clean_coord, ax_noisy_coord,
                                   min_val=args.limits[0], max_val=args.limits[1],
                                   fixed_axes=True)
                plot_predictions_v(pred_coord, clean_coord, ax_pred_coord,
                                   min_val=args.limits[0], max_val=args.limits[1],
                                   fixed_axes=True)
                plot_predictions_v(clean_coord, clean_coord, ax_clean_coord,
                                   min_val=args.limits[0], max_val=args.limits[1],
                                   fixed_axes=True)

                writer.add_figure('{}/original_noise'.format(label), fig_noisy_coord)
                writer.add_figure('{}/test_set_cleanup'.format(label), fig_pred_coord)
                writer.add_figure('{}/ground_truth'.format(label), fig_clean_coord)

                # fig_hist = plot_histogram(predictions=outputs, coords=coord)
                # writer.add_figure('test set histogram', fig_hist)

                writer.add_scalar('{}/test_cosine_loss'.format(label), cosine_loss.data.item())
                writer.add_scalar('{}/test_mse_loss'.format(label), mse_loss.data.item())

    # Close the tensorboard writer and save the trained model and parameters
    if args.logdir != '':
        writer.close()

        torch.save(model.state_dict(), osp.join(save_dir, 'model.pt'))

        params = vars(args)
        # # Additionally save the axis vectors used
        # params['x_axis_vec'] = list(x_axis_sp.v)
        # params['y_axis_vec'] = list(y_axis_sp.v)
        with open(osp.join(save_dir, "params.json"), "w") as f:
            json.dump(params, f)
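

# FeedForward is defined elsewhere in this repository and not shown in this
# file. A minimal stand-in sketch consistent with how it is called above (an
# assumption, not the repository's actual definition): a one-hidden-layer MLP
# mapping a noisy SSP back to a clean SSP of the same dimensionality.
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_size, output_size):
        super(FeedForward, self).__init__()
        self.input_layer = nn.Linear(dim, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # ReLU hidden layer followed by a linear readout
        return self.output_layer(torch.relu(self.input_layer(x)))


if __name__ == '__main__':
    main()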