Example #1
        shuffle=False,
        num_workers=0,
    )

    with torch.no_grad():
        # Everything is in one batch, so this loop will only happen once
        for i, data in enumerate(testloader):
            ssp, coord = data

            model_outputs = model(ssp)

        fig_pred, ax_pred = plt.subplots(tight_layout=True)
        plot_predictions_v(
            predictions=model_outputs,
            coords=coord,
            ax=ax_pred,
            min_val=-args.limit * 1.1,
            max_val=args.limit * 1.1,
            fixed_axes=False,
        )
        ax_pred.set_title(enc_names[enc], fontsize=24)
        fig_pred.savefig("figures/untrained_{}_limit{}_dim{}.pdf".format(
            enc, int(args.limit), args.dim))
        # only record the ground truth once
        if ei == 0:
            fig_truth, ax_truth = plt.subplots(tight_layout=True)
            plot_predictions_v(
                predictions=coord,
                coords=coord,
                ax=ax_truth,
                min_val=-args.limit * 1.1,
                max_val=args.limit * 1.1,
Example #2
                coords_start[:, :] = ssp_to_loc_v(coord_start, heatmap_vectors,
                                                  xs, ys)
                coord_end = ssp_outputs.detach().numpy()[:, -1, :]
                coord_end = coord_end / coord_end.sum(axis=1)[:, np.newaxis]
                coords_end[:, :] = ssp_to_loc_v(coord_end, heatmap_vectors, xs,
                                                ys)

            fig_pred_start, ax_pred_start = plt.subplots()
            fig_truth_start, ax_truth_start = plt.subplots()
            fig_pred_end, ax_pred_end = plt.subplots()
            fig_truth_end, ax_truth_end = plt.subplots()

            print("plotting predicted locations")
            plot_predictions_v(predictions_start / ssp_scaling,
                               coords_start / ssp_scaling,
                               ax_pred_start,
                               min_val=0,
                               max_val=2.2)
            plot_predictions_v(predictions_end / ssp_scaling,
                               coords_end / ssp_scaling,
                               ax_pred_end,
                               min_val=0,
                               max_val=2.2)
            print("plotting ground truth locations")
            plot_predictions_v(coords_start / ssp_scaling,
                               coords_start / ssp_scaling,
                               ax_truth_start,
                               min_val=0,
                               max_val=2.2)
            plot_predictions_v(coords_end / ssp_scaling,
                               coords_end / ssp_scaling,
Example #3
                              size=(n_samples, args.dim))

# encodings /= encodings.sum(axis=1)[:, np.newaxis]

predictions = ssp_to_loc_v(
    # flat_heatmap_vectors,
    encodings,
    heatmap_vectors,
    xs,
    ys)

print(predictions)

coords = predictions.copy()

fig_pred, ax_pred = plt.subplots()

print("plotting predicted locations")
plot_predictions_v(predictions / args.ssp_scaling,
                   coords / args.ssp_scaling,
                   ax_pred,
                   min_val=args.limit_low,
                   max_val=args.limit_high)

if args.n_show_activations > 0:
    for i in range(min(args.dim, args.n_show_activations)):
        plt.figure()
        plt.imshow(heatmap_vectors[:, :, i])

plt.show()
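
All of these examples share one pipeline: decode encoded vectors to 2D locations with ssp_to_loc_v, then scatter them with plot_predictions_v. A minimal sketch of that pattern, assuming both helpers are importable from this project (their module path is not shown in the excerpts):

import matplotlib.pyplot as plt


def decode_and_plot(encodings, heatmap_vectors, xs, ys, limit):
    # Decode: match each encoding against the grid of heatmap vectors
    # to recover one (x, y) location per row.
    predictions = ssp_to_loc_v(encodings, heatmap_vectors, xs, ys)
    fig, ax = plt.subplots()
    # With no ground truth available, the predictions are passed as their
    # own colour-coding coordinates, exactly as Example #3 does.
    plot_predictions_v(predictions, predictions.copy(), ax,
                       min_val=-limit, max_val=limit)
    return fig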
Example #4
    # batch_data = viz_eval[out_p_filt][:, -1, :]
    batch_data = viz_eval[out_p_filt][:, 10:, :].mean(axis=1)
    true_ssps = vis_output[bi, :, :]

    print('pred.shape', batch_data.shape)
    print('true_ssps.shape', true_ssps.shape)

    wall_overlay = np.sum(true_ssps, axis=1) == 0

    print('wall_overlay.shape', wall_overlay.shape)

    hmv = get_encoding_heatmap_vectors(xs, ys, args.dim, encoding_func, normalize=False)

    predictions = np.zeros((res * res, 2))

    # computing 'predicted' coordinates, where the agent thinks it is
    predictions[:, :] = ssp_to_loc_v(
        batch_data,
        hmv, xs, ys
    )

    plot_predictions_v(
        predictions=predictions, coords=coords,
        ax=ax[bi],
        min_val=limit_low,
        max_val=limit_high,
        fixed_axes=True,
    )

plt.show()
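
Example #4 above (and Example #14 at the end) drops wall positions before scoring. A short sketch of that masking step, assuming, as in the excerpts, that a target SSP of all zeros marks a wall:

import numpy as np


def mask_walls(predictions, coords, true_ssps):
    # A position whose target SSP sums to zero carries no ground truth,
    # so it is excluded from both the error and the plots.
    wall_overlay = np.sum(true_ssps, axis=1) == 0
    keep = ~wall_overlay
    squared_error = np.sum(
        np.linalg.norm(predictions[keep] - coords[keep], axis=1) ** 2)
    return predictions[keep], coords[keep], squared_error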
Example #5
            fig_pred_start, ax_pred_start = plt.subplots()
            fig_truth_start, ax_truth_start = plt.subplots()
            fig_pred_end, ax_pred_end = plt.subplots()
            fig_truth_end, ax_truth_end = plt.subplots()

            # print("plotting predicted locations")
            # plot_predictions_v(predictions_start / args.ssp_scaling, coords_start / args.ssp_scaling, ax_pred_start, min_val=0, max_val=2.2, fixed_axes=True)
            # plot_predictions_v(predictions_end / args.ssp_scaling, coords_end / args.ssp_scaling, ax_pred_end, min_val=0, max_val=2.2, fixed_axes=True)
            # print("plotting ground truth locations")
            # plot_predictions_v(coords_start / args.ssp_scaling, coords_start / args.ssp_scaling, ax_truth_start, min_val=0, max_val=2.2, fixed_axes=True)
            # plot_predictions_v(coords_end / args.ssp_scaling, coords_end / args.ssp_scaling, ax_truth_end, min_val=0, max_val=2.2, fixed_axes=True)

            print("plotting predicted locations")
            plot_predictions_v(predictions_start,
                               coords_start,
                               ax_pred_start,
                               min_val=0,
                               max_val=2.2,
                               fixed_axes=True)
            plot_predictions_v(predictions_end,
                               coords_end,
                               ax_pred_end,
                               min_val=0,
                               max_val=2.2,
                               fixed_axes=True)
            # only plot ground truth once at the start
            if epoch == 0:
                print("plotting ground truth locations")
                plot_predictions_v(coords_start,
                                   coords_start,
                                   ax_truth_start,
                                   min_val=0,
Example #6
def main():
    parser = argparse.ArgumentParser(
        'Train a network to learn a mapping from an encoded value to 2D coordinate. View output over time.'
    )

    # parser.add_argument('--viz-period', type=int, default=10, help='number of epochs before a viz set run')
    parser.add_argument(
        '--val-period',
        type=int,
        default=5,
        help='number of epochs before a test/validation set run')
    parser.add_argument(
        '--spatial-encoding',
        type=str,
        default='hex-ssp',
        choices=[
            'ssp', 'hex-ssp', 'periodic-hex-ssp', 'grid-ssp', 'ind-ssp',
            'random', '2d', '2d-normalized', 'one-hot', 'hex-trig', 'trig',
            'random-trig', 'random-rotated-trig', 'random-proj', 'legendre',
            'learned', 'learned-normalized', 'frozen-learned',
            'frozen-learned-normalized', 'pc-gauss', 'pc-dog', 'tile-coding'
        ],
        help='coordinate encoding for agent location and goal')
    parser.add_argument(
        '--freq-limit',
        type=float,
        default=10,
        help='highest frequency of sine wave for random-trig encodings')
    parser.add_argument('--hex-freq-coef',
                        type=float,
                        default=2.5,
                        help='constant to scale frequencies by for hex-trig')
    parser.add_argument('--pc-gauss-sigma',
                        type=float,
                        default=0.75,
                        help='sigma for the gaussians')
    parser.add_argument('--pc-diff-sigma',
                        type=float,
                        default=1.5,
                        help='sigma for subtracted gaussian in DoG')
    parser.add_argument(
        '--hilbert-points',
        type=int,
        default=1,
        choices=[0, 1, 2, 3],
        help=
        'pc centers. 0: random uniform. 1: hilbert curve. 2: evenly spaced grid. 3: hex grid'
    )
    parser.add_argument('--n-tiles',
                        type=int,
                        default=8,
                        help='number of layers for tile coding')
    parser.add_argument('--n-bins',
                        type=int,
                        default=0,
                        help='number of bins for tile coding')
    parser.add_argument('--ssp-scaling', type=float, default=1.0)
    parser.add_argument('--grid-ssp-min',
                        type=float,
                        default=0.25,
                        help='minimum plane wave scale')
    parser.add_argument('--grid-ssp-max',
                        type=float,
                        default=2.0,
                        help='maximum plane wave scale')
    parser.add_argument('--hidden-size', type=int, default=512)

    parser.add_argument('--optimizer',
                        type=str,
                        default='adam',
                        choices=['rmsprop', 'adam', 'sgd'])

    parser.add_argument('--train-fraction',
                        type=float,
                        default=.5,
                        help='proportion of the dataset to use for training')
    # NOTE: this is changed to be smaller to see the effects of each epoch more
    parser.add_argument(
        '--n-samples',
        type=int,
        default=2500,
        help='Number of samples to generate if a dataset is not given')
    parser.add_argument('--dim',
                        type=int,
                        default=512,
                        help='Dimensionality of the semantic pointers')
    parser.add_argument('--limit',
                        type=float,
                        default=1,
                        help='The limits of the space')
    parser.add_argument('--epochs', type=int, default=25)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.001)
    # parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=13)
    parser.add_argument('--res', type=int, default=64)

    parser.add_argument('--load-model',
                        type=str,
                        default='',
                        help='Optional model to continue training from')

    args = parser.parse_args()

    fname = 'training_over_time_data.npz'

    if osp.exists(fname):
        data = np.load(fname)
        data_outputs = data['data_outputs']
        data_loss = data['data_loss']
        test_coords = data['test_coords']
    else:

        data_outputs = np.zeros((args.epochs, args.res * args.res, 2))
        data_loss = np.zeros((args.epochs, ))

        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

        encoding_func, repr_dim = get_encoding_function(args,
                                                        limit_low=-args.limit,
                                                        limit_high=args.limit)

        vectors, coords = generate_coord_dataset(
            encoding_func=encoding_func,
            n_samples=args.n_samples,
            dim=repr_dim,
            limit=args.limit,
            seed=args.seed,
        )

        n_samples = vectors.shape[0]
        n_train = int(args.train_fraction * n_samples)
        n_test = n_samples - n_train
        assert (n_train > 0 and n_test > 0)
        train_vectors = vectors[:n_train]
        train_coords = coords[:n_train]
        # test_vectors = vectors[n_train:]
        # test_coords = coords[n_train:]

        test_vectors = np.zeros((args.res * args.res, repr_dim))  # repr_dim, not args.dim: some encodings change dimensionality
        test_coords = np.zeros((args.res * args.res, 2))

        # linspace for test, for easy visualization
        xs = np.linspace(-args.limit, args.limit, args.res)
        for i, x in enumerate(xs):
            for j, y in enumerate(xs):
                test_coords[i * args.res + j, 0] = x
                test_coords[i * args.res + j, 1] = y
                test_vectors[i * args.res + j, :] = encoding_func(x, y)

        dataset_train = GenericDataset(inputs=train_vectors,
                                       outputs=train_coords)
        dataset_test = GenericDataset(inputs=test_vectors, outputs=test_coords)

        trainloader = torch.utils.data.DataLoader(
            dataset_train,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=0,
        )

        # For testing just do everything in one giant batch
        testloader = torch.utils.data.DataLoader(
            dataset_test,
            batch_size=len(dataset_test),
            shuffle=False,
            num_workers=0,
        )

        model = FeedForward(input_size=repr_dim,
                            hidden_size=args.hidden_size,
                            output_size=2)

        criterion = nn.MSELoss()

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum)
        elif args.optimizer == 'rmsprop':
            optimizer = torch.optim.RMSprop(model.parameters(),
                                            lr=args.lr,
                                            momentum=args.momentum)
        elif args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
        else:
            raise NotImplementedError

        for e in range(args.epochs):
            print('Epoch: {0}'.format(e + 1))

            with torch.no_grad():

                # Everything is in one batch, so this loop will only happen once
                for i, data in enumerate(testloader):
                    ssp, coord = data

                    outputs = model(ssp)

                    loss = criterion(outputs, coord)

                data_loss[e] = loss.data.item()
                data_outputs[e, :, :] = outputs.detach().numpy()

            avg_loss = 0
            n_batches = 0
            for i, data in enumerate(trainloader):

                ssp, coord = data

                if ssp.size()[0] != args.batch_size:
                    continue  # Drop data, not enough for a batch
                optimizer.zero_grad()

                outputs = model(ssp)

                loss = criterion(outputs, coord)
                # print(loss.data.item())
                avg_loss += loss.data.item()
                n_batches += 1

                loss.backward()

                optimizer.step()

        np.savez(
            fname,
            data_outputs=data_outputs,
            data_loss=data_loss,
            test_coords=test_coords,
        )

    # epoch_list = [0, 1, 2, 3, 4, 5]
    # epoch_list = [0, 1, 5, 10, 24]
    epoch_list = [0, 1, 10]
    fig, ax = plt.subplots(1,
                           len(epoch_list),
                           tight_layout=True,
                           figsize=(9, 3))
    # plot data
    for ei, epoch in enumerate(epoch_list):

        plot_predictions_v(
            predictions=data_outputs[epoch, :, :],
            coords=test_coords,
            ax=ax[ei],
            min_val=-args.limit,
            max_val=args.limit,
            fixed_axes=ei != 0,
        )
        ax[ei].set_title("Epoch {} - Loss {:.5f}".format(
            epoch, data_loss[epoch]))

    plt.show()
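
Example #6 builds its test grid with two nested Python loops. An equivalent vectorized construction of the same coordinates (a sketch; it reproduces the i * res + j ordering of the loops, though encoding_func itself would still be applied row by row):

import numpy as np

res, limit = 64, 1.0
xs = np.linspace(-limit, limit, res)
# Cartesian product of xs with itself, in the same (i * res + j) order
# as the nested loops above.
grid_x, grid_y = np.meshgrid(xs, xs, indexing='ij')
test_coords = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)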
Example #7
                outputs,
                heatmap_vectors, xs, ys
            )

            clean_coord = ssp_to_loc_v(
                clean,
                heatmap_vectors, xs, ys
            )

            fig_noisy_coord, ax_noisy_coord = plt.subplots()
            fig_pred_coord, ax_pred_coord = plt.subplots()
            fig_clean_coord, ax_clean_coord = plt.subplots()

            plot_predictions_v(
                noisy_coord,
                clean_coord,
                ax_noisy_coord, min_val=-args.limit*2, max_val=args.limit*2, fixed_axes=True
            )

            plot_predictions_v(
                pred_coord,
                clean_coord,
                ax_pred_coord, min_val=-args.limit*2, max_val=args.limit*2, fixed_axes=True
            )

            plot_predictions_v(
                clean_coord,
                clean_coord,
                ax_clean_coord, min_val=-args.limit*2, max_val=args.limit*2, fixed_axes=True
            )
Example #8
    def run_eval(self, model, writer, epoch):

        with torch.no_grad():
            # Everything is in one batch, so this loop will only happen once
            for i, data in enumerate(self.dataloader):
                combined_inputs, ssp_inputs, ssp_outputs = data

                ssp_pred = model(combined_inputs, ssp_inputs)

                # NOTE: need to permute axes of the targets here because the output is
                #       (sequence length, batch, units) instead of (batch, sequence_length, units)
                #       could also permute the outputs instead
                # NOTE: for cosine loss the input needs to be flattened first
                cosine_loss = self.cosine_criterion(
                    ssp_pred.reshape(ssp_pred.shape[0] * ssp_pred.shape[1],
                                     ssp_pred.shape[2]),
                    ssp_outputs.permute(1, 0, 2).reshape(
                        ssp_pred.shape[0] * ssp_pred.shape[1],
                        ssp_pred.shape[2]),
                    torch.ones(ssp_pred.shape[0] * ssp_pred.shape[1]))
                mse_loss = self.mse_criterion(ssp_pred,
                                              ssp_outputs.permute(1, 0, 2))

                print("test mse loss", mse_loss.data.item())
                print("test cosine loss", mse_loss.data.item())

            writer.add_scalar('test_mse_loss', mse_loss.data.item(), epoch)
            writer.add_scalar('test_cosine_loss', cosine_loss.data.item(),
                              epoch)

            # Just use start and end location to save on memory and computation
            predictions_start = np.zeros((ssp_pred.shape[1], 2))
            coords_start = np.zeros((ssp_pred.shape[1], 2))

            predictions_end = np.zeros((ssp_pred.shape[1], 2))
            coords_end = np.zeros((ssp_pred.shape[1], 2))

            if self.spatial_encoding == 'ssp':
                print("computing prediction locations")
                predictions_start[:, :] = ssp_to_loc_v(
                    ssp_pred.detach().numpy()[0, :, :], self.heatmap_vectors,
                    self.xs, self.ys)
                predictions_end[:, :] = ssp_to_loc_v(
                    ssp_pred.detach().numpy()[-1, :, :], self.heatmap_vectors,
                    self.xs, self.ys)
                print("computing ground truth locations")
                coords_start[:, :] = ssp_to_loc_v(
                    ssp_outputs.detach().numpy()[:, 0, :],
                    self.heatmap_vectors, self.xs, self.ys)
                coords_end[:, :] = ssp_to_loc_v(
                    ssp_outputs.detach().numpy()[:, -1, :],
                    self.heatmap_vectors, self.xs, self.ys)
            elif self.spatial_encoding == '2d':
                print("copying prediction locations")
                predictions_start[:, :] = ssp_pred.detach().numpy()[0, :, :]
                predictions_end[:, :] = ssp_pred.detach().numpy()[-1, :, :]
                print("copying ground truth locations")
                coords_start[:, :] = ssp_outputs.detach().numpy()[:, 0, :]
                coords_end[:, :] = ssp_outputs.detach().numpy()[:, -1, :]

            fig_pred_start, ax_pred_start = plt.subplots()
            fig_truth_start, ax_truth_start = plt.subplots()
            fig_pred_end, ax_pred_end = plt.subplots()
            fig_truth_end, ax_truth_end = plt.subplots()

            print("plotting predicted locations")
            plot_predictions_v(
                predictions_start / self.ssp_scaling,
                coords_start / self.ssp_scaling,
                ax_pred_start,
                min_val=self.xs[0],
                max_val=self.xs[-1],
            )
            plot_predictions_v(
                predictions_end / self.ssp_scaling,
                coords_end / self.ssp_scaling,
                ax_pred_end,
                min_val=self.xs[0],
                max_val=self.xs[-1],
            )

            writer.add_figure("predictions start", fig_pred_start, epoch)
            writer.add_figure("predictions end", fig_pred_end, epoch)

            # Only plotting ground truth if the epoch is 0
            if epoch == 0:

                print("plotting ground truth locations")
                plot_predictions_v(
                    coords_start / self.ssp_scaling,
                    coords_start / self.ssp_scaling,
                    ax_truth_start,
                    min_val=self.xs[0],
                    max_val=self.xs[-1],
                )
                plot_predictions_v(
                    coords_end / self.ssp_scaling,
                    coords_end / self.ssp_scaling,
                    ax_truth_end,
                    min_val=self.xs[0],
                    max_val=self.xs[-1],
                )

                writer.add_figure("ground truth start", fig_truth_start, epoch)
                writer.add_figure("ground truth end", fig_truth_end, epoch)
Example #9
def train(args, trainloader, testloader, input_size, output_size=2):

    model = FeedForward(input_size=input_size, hidden_size=args.hidden_size, output_size=output_size)

    # Open a tensorboard writer if a logging directory is given
    if args.logdir != '':
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        save_dir = osp.join(args.logdir, current_time)
        writer = SummaryWriter(log_dir=save_dir)
        # if args.weight_histogram:
        #     # Log the initial parameters
        #     for name, param in model.named_parameters():
        #         writer.add_histogram('parameters/' + name, param.clone().cpu().data.numpy(), 0)

    criterion = nn.MSELoss()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError

    for e in range(args.epochs):
        print('Epoch: {0}'.format(e + 1))

        if e % args.val_period == 0:
            with torch.no_grad():

                # Everything is in one batch, so this loop will only happen once
                for i, data in enumerate(testloader):
                    ssp, coord = data

                    outputs = model(ssp)

                    loss = criterion(outputs, coord)

                if args.logdir != '':

                    if not args.no_viz:
                        if output_size == 2:
                            fig_pred, ax_pred = plt.subplots()
                            plot_predictions_v(
                                predictions=outputs, coords=coord,
                                ax=ax_pred,
                                min_val=-args.limit*1.1,
                                max_val=args.limit*1.1,
                                fixed_axes=False,
                            )
                            writer.add_figure('test set predictions', fig_pred, e)
                    writer.add_scalar('test_loss', loss.data.item(), e)

        avg_loss = 0
        n_batches = 0
        for i, data in enumerate(trainloader):

            ssp, coord = data

            if ssp.size()[0] != args.batch_size:
                continue  # Drop data, not enough for a batch
            optimizer.zero_grad()

            outputs = model(ssp)

            loss = criterion(outputs, coord)
            # print(loss.data.item())
            avg_loss += loss.data.item()
            n_batches += 1

            loss.backward()

            optimizer.step()

        if args.logdir != '':
            if n_batches > 0:
                avg_loss /= n_batches
                writer.add_scalar('avg_loss', avg_loss, e + 1)

            # if args.weight_histogram and (e + 1) % 10 == 0:
            #     for name, param in model.named_parameters():
            #         writer.add_histogram('parameters/' + name, param.clone().cpu().data.numpy(), e + 1)

    print("Testing")
    with torch.no_grad():

        # Everything is in one batch, so this loop will only happen once
        for i, data in enumerate(testloader):

            ssp, coord = data

            outputs = model(ssp)

            loss = criterion(outputs, coord)

            # print(loss.data.item())

        if args.logdir != '':

            if not args.no_viz:
                if output_size == 2:
                    fig_pred, ax_pred = plt.subplots()
                    fig_truth, ax_truth = plt.subplots()
                    plot_predictions_v(
                        predictions=outputs, coords=coord,
                        ax=ax_pred,
                        min_val=-args.limit*1.1,
                        max_val=args.limit*1.1,
                        fixed_axes=False,
                    )
                    writer.add_figure('test set predictions', fig_pred, args.epochs)
                    plot_predictions_v(
                        predictions=coord, coords=coord,
                        ax=ax_truth,
                        min_val=-args.limit*1.1,
                        max_val=args.limit*1.1,
                        fixed_axes=False,
                    )
                    writer.add_figure('ground truth', fig_truth)
            # fig_hist = plot_histogram(predictions=outputs, coords=coord)
            # writer.add_figure('test set histogram', fig_hist)
            writer.add_scalar('test_loss', loss.data.item(), args.epochs)

    # Close tensorboard writer
    if args.logdir != '':
        writer.close()

        torch.save(model.state_dict(), osp.join(save_dir, 'model.pt'))

        params = vars(args)
        # # Additionally save the axis vectors used
        # params['x_axis_vec'] = list(x_axis_sp.v)
        # params['y_axis_vec'] = list(y_axis_sp.v)
        with open(osp.join(save_dir, "params.json"), "w") as f:
            json.dump(params, f)
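
Several of these examples evaluate with a DataLoader whose batch size equals the dataset length, so the inner loop body runs exactly once and covers the whole test set. That idiom in isolation (a sketch; model, dataset_test, and criterion stand in for the objects built above):

import torch


def evaluate(model, dataset_test, criterion):
    # One giant batch: the loader yields the entire test set at once.
    testloader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=len(dataset_test),
        shuffle=False,
        num_workers=0,
    )
    with torch.no_grad():
        for ssp, coord in testloader:  # executes exactly once
            outputs = model(ssp)
            loss = criterion(outputs, coord)
    return outputs, coord, loss.item()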
Example #10
    def run_eval(self, model, writer, epoch):

        with torch.no_grad():
            # Everything is in one batch, so this loop will only happen once
            for i, data in enumerate(self.dataloader):
                # sensor_inputs, map_ids, ssp_outputs = data
                # sensors and map ID combined
                combined_inputs, ssp_outputs = data

                # ssp_pred = model(sensor_inputs, map_ids)
                ssp_pred = model(combined_inputs)

                cosine_loss = self.cosine_criterion(
                    ssp_pred, ssp_outputs, torch.ones(ssp_pred.shape[0]))
                mse_loss = self.mse_criterion(ssp_pred, ssp_outputs)

                print("test mse loss", mse_loss.data.item())
                print("test cosine loss", mse_loss.data.item())

            writer.add_scalar('test_mse_loss', mse_loss.data.item(), epoch)
            writer.add_scalar('test_cosine_loss', cosine_loss.data.item(),
                              epoch)

            # One prediction and ground truth coord for every element in the batch
            # NOTE: this is assuming the eval set only has one giant batch
            predictions = np.zeros((ssp_pred.shape[0], 2))
            coords = np.zeros((ssp_pred.shape[0], 2))

            if self.spatial_encoding == 'ssp':
                print("computing prediction locations")
                predictions[:, :] = ssp_to_loc_v(
                    ssp_pred.detach().numpy()[:, :], self.heatmap_vectors,
                    self.xs, self.ys)

                print("computing ground truth locations")
                coords[:, :] = ssp_to_loc_v(ssp_outputs.detach().numpy()[:, :],
                                            self.heatmap_vectors, self.xs,
                                            self.ys)

            elif self.spatial_encoding == '2d':
                print("copying prediction locations")
                predictions[:, :] = ssp_pred.detach().numpy()[:, :]
                print("copying ground truth locations")
                coords[:, :] = ssp_outputs.detach().numpy()[:, :]

            fig_pred, ax_pred = plt.subplots()
            fig_truth, ax_truth = plt.subplots()

            print("plotting predicted locations")
            plot_predictions_v(
                # predictions / self.ssp_scaling,
                # coords / self.ssp_scaling,
                predictions,
                coords,
                ax_pred,
                # min_val=0,
                # max_val=2.2
                min_val=self.xs[0],
                max_val=self.xs[-1],
            )

            writer.add_figure("predictions", fig_pred, epoch)

            # Only plot ground truth if epoch is 0
            if epoch == 0:
                print("plotting ground truth locations")
                plot_predictions_v(
                    # coords / self.ssp_scaling,
                    # coords / self.ssp_scaling,
                    coords,
                    coords,
                    ax_truth,
                    # min_val=0,
                    # max_val=2.2
                    min_val=self.xs[0],
                    max_val=self.xs[-1],
                )

                writer.add_figure("ground truth", fig_truth, epoch)
Example #11
def main():
    parser = argparse.ArgumentParser(
        'Train a network to learn a mapping from an encoded value to 2D coordinate'
    )

    # parser.add_argument('--viz-period', type=int, default=10, help='number of epochs before a viz set run')
    parser.add_argument(
        '--val-period',
        type=int,
        default=5,
        help='number of epochs before a test/validation set run')
    parser.add_argument(
        '--spatial-encoding',
        type=str,
        default='hex-ssp',
        choices=[
            'ssp', 'hex-ssp', 'periodic-hex-ssp', 'grid-ssp', 'ind-ssp',
            'random', '2d', '2d-normalized', 'one-hot', 'hex-trig',
            'sub-toroid-ssp', 'proj-ssp', 'trig', 'random-trig',
            'random-rotated-trig', 'random-proj', 'legendre', 'learned',
            'learned-normalized', 'frozen-learned',
            'frozen-learned-normalized', 'pc-gauss', 'pc-dog', 'tile-coding'
        ],
        help='coordinate encoding for agent location and goal')
    parser.add_argument(
        '--freq-limit',
        type=float,
        default=10,
        help='highest frequency of sine wave for random-trig encodings')
    parser.add_argument('--hex-freq-coef',
                        type=float,
                        default=2.5,
                        help='constant to scale frequencies by for hex-trig')
    parser.add_argument('--pc-gauss-sigma',
                        type=float,
                        default=0.75,
                        help='sigma for the gaussians')
    parser.add_argument('--pc-diff-sigma',
                        type=float,
                        default=1.5,
                        help='sigma for subtracted gaussian in DoG')
    parser.add_argument(
        '--hilbert-points',
        type=int,
        default=1,
        choices=[0, 1, 2, 3],
        help=
        'pc centers. 0: random uniform. 1: hilbert curve. 2: evenly spaced grid. 3: hex grid'
    )
    parser.add_argument('--n-tiles',
                        type=int,
                        default=8,
                        help='number of layers for tile coding')
    parser.add_argument('--n-bins',
                        type=int,
                        default=0,
                        help='number of bins for tile coding')
    parser.add_argument('--ssp-scaling', type=float, default=1.0)
    parser.add_argument('--grid-ssp-min',
                        type=float,
                        default=0.25,
                        help='minimum plane wave scale')
    parser.add_argument('--grid-ssp-max',
                        type=float,
                        default=2.0,
                        help='maximum plane wave scale')
    parser.add_argument('--n-proj',
                        type=int,
                        default=3,
                        help='projection dimension for sub toroids')
    parser.add_argument('--scale-ratio',
                        type=float,
                        default=(1 + 5**0.5) / 2,
                        help='ratio between sub toroid scales')
    parser.add_argument('--hidden-size', type=int, default=512)

    parser.add_argument('--optimizer',
                        type=str,
                        default='adam',
                        choices=['rmsprop', 'adam', 'sgd'])

    parser.add_argument('--train-fraction',
                        type=float,
                        default=.5,
                        help='proportion of the dataset to use for training')
    parser.add_argument(
        '--n-samples',
        type=int,
        default=10000,
        help='Number of samples to generate if a dataset is not given')
    parser.add_argument('--dim',
                        type=int,
                        default=512,
                        help='Dimensionality of the semantic pointers')
    parser.add_argument('--limit',
                        type=float,
                        default=5,
                        help='The limits of the space')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=13)
    parser.add_argument('--logdir',
                        type=str,
                        default='coord_decode_function',
                        help='Directory for saved model and tensorboard log')

    parser.add_argument('--load-model',
                        type=str,
                        default='',
                        help='Optional model to continue training from')

    args = parser.parse_args()

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    encoding_func, repr_dim = get_encoding_function(args,
                                                    limit_low=-args.limit,
                                                    limit_high=args.limit)

    vectors, coords = generate_coord_dataset(
        encoding_func=encoding_func,
        n_samples=args.n_samples,
        dim=repr_dim,
        limit=args.limit,
        seed=args.seed,
    )

    n_samples = vectors.shape[0]
    n_train = int(args.train_fraction * n_samples)
    n_test = n_samples - n_train
    assert (n_train > 0 and n_test > 0)
    train_vectors = vectors[:n_train]
    train_coords = coords[:n_train]
    test_vectors = vectors[n_train:]
    test_coords = coords[n_train:]

    dataset_train = GenericDataset(inputs=train_vectors, outputs=train_coords)
    dataset_test = GenericDataset(inputs=test_vectors, outputs=test_coords)

    trainloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
    )

    # For testing just do everything in one giant batch
    testloader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=len(dataset_test),
        shuffle=False,
        num_workers=0,
    )

    model = FeedForward(input_size=repr_dim,
                        hidden_size=args.hidden_size,
                        output_size=2)

    # Open a tensorboard writer if a logging directory is given
    if args.logdir != '':
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        save_dir = osp.join(args.logdir, current_time)
        writer = SummaryWriter(log_dir=save_dir)
        # if args.weight_histogram:
        #     # Log the initial parameters
        #     for name, param in model.named_parameters():
        #         writer.add_histogram('parameters/' + name, param.clone().cpu().data.numpy(), 0)

    criterion = nn.MSELoss()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError

    for e in range(args.epochs):
        print('Epoch: {0}'.format(e + 1))

        if e % args.val_period == 0:
            with torch.no_grad():

                # Everything is in one batch, so this loop will only happen once
                for i, data in enumerate(testloader):
                    ssp, coord = data

                    outputs = model(ssp)

                    loss = criterion(outputs, coord)

                if args.logdir != '':
                    fig_pred, ax_pred = plt.subplots()

                    plot_predictions_v(
                        predictions=outputs,
                        coords=coord,
                        ax=ax_pred,
                        min_val=-args.limit * 1.1,
                        max_val=args.limit * 1.1,
                        fixed_axes=False,
                    )
                    writer.add_figure('test set predictions', fig_pred, e)
                    writer.add_scalar('test_loss', loss.data.item(), e)

        avg_loss = 0
        n_batches = 0
        for i, data in enumerate(trainloader):

            ssp, coord = data

            if ssp.size()[0] != args.batch_size:
                continue  # Drop data, not enough for a batch
            optimizer.zero_grad()

            outputs = model(ssp)

            loss = criterion(outputs, coord)
            # print(loss.data.item())
            avg_loss += loss.data.item()
            n_batches += 1

            loss.backward()

            optimizer.step()

        if args.logdir != '':
            if n_batches > 0:
                avg_loss /= n_batches
                writer.add_scalar('avg_loss', avg_loss, e + 1)

            # if args.weight_histogram and (e + 1) % 10 == 0:
            #     for name, param in model.named_parameters():
            #         writer.add_histogram('parameters/' + name, param.clone().cpu().data.numpy(), e + 1)

    print("Testing")
    with torch.no_grad():

        # Everything is in one batch, so this loop will only happen once
        for i, data in enumerate(testloader):

            ssp, coord = data

            outputs = model(ssp)

            loss = criterion(outputs, coord)

            # print(loss.data.item())

        if args.logdir != '':
            fig_pred, ax_pred = plt.subplots()
            fig_truth, ax_truth = plt.subplots()

            plot_predictions_v(
                predictions=outputs,
                coords=coord,
                ax=ax_pred,
                min_val=-args.limit * 1.1,
                max_val=args.limit * 1.1,
                fixed_axes=False,
            )
            writer.add_figure('test set predictions', fig_pred, args.epochs)
            plot_predictions_v(
                predictions=coord,
                coords=coord,
                ax=ax_truth,
                min_val=-args.limit * 1.1,
                max_val=args.limit * 1.1,
                fixed_axes=False,
            )
            writer.add_figure('ground truth', fig_truth)
            # fig_hist = plot_histogram(predictions=outputs, coords=coord)
            # writer.add_figure('test set histogram', fig_hist)
            writer.add_scalar('test_loss', loss.data.item(), args.epochs)

    # Close tensorboard writer
    if args.logdir != '':
        writer.close()

        torch.save(model.state_dict(), osp.join(save_dir, 'model.pt'))

        params = vars(args)
        # # Additionally save the axis vectors used
        # params['x_axis_vec'] = list(x_axis_sp.v)
        # params['y_axis_vec'] = list(y_axis_sp.v)
        with open(osp.join(save_dir, "params.json"), "w") as f:
            json.dump(params, f)
Example #12
                hd_outputs_sm.detach().numpy()[-1, :, :],
                centers=hd_centers,
                jitter=0.01
            )

            fig_pc_pred_start, ax_pc_pred_start = plt.subplots()
            fig_pc_truth_start, ax_pc_truth_start = plt.subplots()
            fig_pc_pred_end, ax_pc_pred_end = plt.subplots()
            fig_pc_truth_end, ax_pc_truth_end = plt.subplots()
            fig_hd_pred_start, ax_hd_pred_start = plt.subplots()
            fig_hd_truth_start, ax_hd_truth_start = plt.subplots()
            fig_hd_pred_end, ax_hd_pred_end = plt.subplots()
            fig_hd_truth_end, ax_hd_truth_end = plt.subplots()

            print("plotting predicted locations")
            plot_predictions_v(pc_predictions_start, pc_coords_start,
                               ax_pc_pred_start, min_val=0, max_val=2.2)
            plot_predictions_v(pc_predictions_end, pc_coords_end,
                               ax_pc_pred_end, min_val=0, max_val=2.2)
            plot_predictions_v(hd_predictions_start, hd_coords_start,
                               ax_hd_pred_start, min_val=-1, max_val=1)
            plot_predictions_v(hd_predictions_end, hd_coords_end,
                               ax_hd_pred_end, min_val=-1, max_val=1)

            writer.add_figure("pc predictions start", fig_pc_pred_start, epoch)
            writer.add_figure("pc predictions end", fig_pc_pred_end, epoch)
            writer.add_figure("hd predictions start", fig_hd_pred_start, epoch)
            writer.add_figure("hd predictions end", fig_hd_pred_end, epoch)

    avg_loss = 0
    n_batches = 0
    for i, data in enumerate(trainloader):
        velocity_inputs, pc_inputs, hd_inputs, pc_outputs, hd_outputs = data

        if pc_inputs.size()[0] != batch_size:
Example #13
def main():
    parser = argparse.ArgumentParser(
        'Train a network to clean up a noisy spatial semantic pointer')

    parser.add_argument('--loss',
                        type=str,
                        default='cosine',
                        choices=['cosine', 'mse'])
    parser.add_argument('--noise-type',
                        type=str,
                        default='memory',
                        choices=['memory', 'gaussian', 'both'])
    parser.add_argument(
        '--sigma',
        type=float,
        default=1.0,
        help='sigma on the gaussian noise if noise-type==gaussian')
    parser.add_argument('--train-fraction',
                        type=float,
                        default=.8,
                        help='proportion of the dataset to use for training')
    parser.add_argument(
        '--n-samples',
        type=int,
        default=10000,
        help=
        'Number of memories to generate. Total samples will be n-samples * n-items'
    )
    parser.add_argument('--n-items',
                        type=int,
                        default=12,
                        help='number of items in memory. Proxy for noisiness')
    parser.add_argument('--dim',
                        type=int,
                        default=512,
                        help='Dimensionality of the semantic pointers')
    parser.add_argument('--hidden-size',
                        type=int,
                        default=512,
                        help='Hidden size of the cleanup network')
    parser.add_argument('--limits',
                        type=str,
                        default="-5,5,-5,5",
                        help='The limits of the space')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=13)
    parser.add_argument('--logdir',
                        type=str,
                        default='ssp_cleanup',
                        help='Directory for saved model and tensorboard log')
    parser.add_argument('--load-model',
                        type=str,
                        default='',
                        help='Optional model to continue training from')
    parser.add_argument(
        '--name',
        type=str,
        default='',
        help=
        'Name of output folder within logdir. Will use current date and time if blank'
    )
    parser.add_argument('--weight-histogram',
                        action='store_true',
                        help='Save histograms of the weights if set')
    parser.add_argument('--use-hex-ssp', action='store_true')
    parser.add_argument('--optimizer',
                        type=str,
                        default='adam',
                        choices=['sgd', 'adam', 'rmsprop'])

    args = parser.parse_args()

    args.limits = tuple(float(v) for v in args.limits.split(','))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    dataset_name = 'data/ssp_cleanup_dataset_dim{}_seed{}_items{}_limit{}_samples{}.npz'.format(
        args.dim, args.seed, args.n_items, args.limits[1], args.n_samples)

    final_test_samples = 100
    final_test_items = 15
    final_test_dataset_name = 'data/ssp_cleanup_test_dataset_dim{}_seed{}_items{}_limit{}_samples{}.npz'.format(
        args.dim, args.seed, final_test_items, args.limits[1],
        final_test_samples)

    if not os.path.exists('data'):
        os.makedirs('data')

    rng = np.random.RandomState(seed=args.seed)
    if args.use_hex_ssp:
        x_axis_sp, y_axis_sp = get_axes(dim=args.dim, n=3, seed=args.seed)
    else:
        x_axis_sp = make_good_unitary(args.dim, rng=rng)
        y_axis_sp = make_good_unitary(args.dim, rng=rng)

    if args.noise_type == 'gaussian':
        # Simple generation
        clean_ssps = np.zeros((args.n_samples, args.dim))
        coords = np.zeros((args.n_samples, 2))
        for i in range(args.n_samples):
            x = np.random.uniform(low=args.limits[0], high=args.limits[1])
            y = np.random.uniform(low=args.limits[2], high=args.limits[3])

            clean_ssps[i, :] = encode_point(x,
                                            y,
                                            x_axis_sp=x_axis_sp,
                                            y_axis_sp=y_axis_sp).v
            coords[i, 0] = x
            coords[i, 1] = y
        # Gaussian noise will be added later
        noisy_ssps = clean_ssps.copy()
    else:

        if os.path.exists(dataset_name):
            print("Loading dataset")
            data = np.load(dataset_name)
            clean_ssps = data['clean_ssps']
            noisy_ssps = data['noisy_ssps']
        else:
            print("Generating SSP cleanup dataset")
            clean_ssps, noisy_ssps, coords = generate_cleanup_dataset(
                x_axis_sp=x_axis_sp,
                y_axis_sp=y_axis_sp,
                n_samples=args.n_samples,
                dim=args.dim,
                n_items=args.n_items,
                limits=args.limits,
                seed=args.seed,
            )
            print("Dataset generation complete. Saving dataset")
            np.savez(
                dataset_name,
                clean_ssps=clean_ssps,
                noisy_ssps=noisy_ssps,
                coords=coords,
                x_axis_vec=x_axis_sp.v,
                y_axis_vec=y_axis_sp.v,
            )

    # check if the final test set has been generated yet
    if os.path.exists(final_test_dataset_name):
        print("Loading final test dataset")
        final_test_data = np.load(final_test_dataset_name)
        final_test_clean_ssps = final_test_data['clean_ssps']
        final_test_noisy_ssps = final_test_data['noisy_ssps']
    else:
        print("Generating final test dataset")
        final_test_clean_ssps, final_test_noisy_ssps, final_test_coords = generate_cleanup_dataset(
            x_axis_sp=x_axis_sp,
            y_axis_sp=y_axis_sp,
            n_samples=final_test_samples,
            dim=args.dim,
            n_items=final_test_items,
            limits=args.limits,
            seed=args.seed,
        )
        print("Final test generation complete. Saving dataset")
        np.savez(
            final_test_dataset_name,
            clean_ssps=final_test_clean_ssps,
            noisy_ssps=final_test_noisy_ssps,
            coords=final_test_coords,
            x_axis_vec=x_axis_sp.v,
            y_axis_vec=y_axis_sp.v,
        )

    # Add gaussian noise if required
    if args.noise_type == 'gaussian' or args.noise_type == 'both':
        noisy_ssps += np.random.normal(loc=0,
                                       scale=args.sigma,
                                       size=noisy_ssps.shape)

    n_samples = clean_ssps.shape[0]
    n_train = int(args.train_fraction * n_samples)
    n_test = n_samples - n_train
    assert (n_train > 0 and n_test > 0)
    train_clean = clean_ssps[:n_train, :]
    train_noisy = noisy_ssps[:n_train, :]
    test_clean = clean_ssps[n_train:, :]
    test_noisy = noisy_ssps[n_train:, :]

    # NOTE: this dataset is actually generic and can take any input/output mapping
    dataset_train = CoordDecodeDataset(vectors=train_noisy, coords=train_clean)
    dataset_test = CoordDecodeDataset(vectors=test_noisy, coords=test_clean)
    dataset_final_test = CoordDecodeDataset(vectors=final_test_noisy_ssps,
                                            coords=final_test_clean_ssps)

    trainloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
    )

    # For testing just do everything in one giant batch
    testloader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=len(dataset_test),
        shuffle=False,
        num_workers=0,
    )

    final_testloader = torch.utils.data.DataLoader(
        dataset_final_test,
        batch_size=len(dataset_final_test),
        shuffle=False,
        num_workers=0,
    )

    model = FeedForward(dim=dataset_train.dim,
                        hidden_size=args.hidden_size,
                        output_size=dataset_train.dim)

    # Open a tensorboard writer if a logging directory is given
    if args.logdir != '':
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        save_dir = osp.join(args.logdir, current_time)
        writer = SummaryWriter(log_dir=save_dir)
        if args.weight_histogram:
            # Log the initial parameters
            for name, param in model.named_parameters():
                writer.add_histogram('parameters/' + name,
                                     param.clone().cpu().data.numpy(), 0)

    mse_criterion = nn.MSELoss()
    cosine_criterion = nn.CosineEmbeddingLoss()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError

    for e in range(args.epochs):
        print('Epoch: {0}'.format(e + 1))

        avg_mse_loss = 0
        avg_cosine_loss = 0
        n_batches = 0
        for i, data in enumerate(trainloader):

            noisy, clean = data

            if noisy.size()[0] != args.batch_size:
                continue  # Drop data, not enough for a batch
            optimizer.zero_grad()

            outputs = model(noisy)

            mse_loss = mse_criterion(outputs, clean)
            # Modified to use CosineEmbeddingLoss
            cosine_loss = cosine_criterion(outputs, clean,
                                           torch.ones(args.batch_size))

            avg_cosine_loss += cosine_loss.data.item()
            avg_mse_loss += mse_loss.data.item()
            n_batches += 1

            if args.loss == 'cosine':
                cosine_loss.backward()
            else:
                mse_loss.backward()

            # print(loss.data.item())

            optimizer.step()

        print(avg_cosine_loss / n_batches)

        if args.logdir != '':
            if n_batches > 0:
                avg_cosine_loss /= n_batches
                writer.add_scalar('avg_cosine_loss', avg_cosine_loss, e + 1)
                writer.add_scalar('avg_mse_loss', avg_mse_loss, e + 1)

            if args.weight_histogram and (e + 1) % 10 == 0:
                for name, param in model.named_parameters():
                    writer.add_histogram('parameters/' + name,
                                         param.clone().cpu().data.numpy(),
                                         e + 1)

    print("Testing")
    with torch.no_grad():

        for label, loader in zip(['test', 'final_test'],
                                 [testloader, final_testloader]):

            # Everything is in one batch, so this loop will only happen once
            for i, data in enumerate(loader):

                noisy, clean = data

                outputs = model(noisy)

                mse_loss = mse_criterion(outputs, clean)
                # Modified to use CosineEmbeddingLoss
                cosine_loss = cosine_criterion(outputs, clean,
                                               torch.ones(noisy.shape[0]))  # one target per sample, not per batch

                print(cosine_loss.data.item())

            if args.logdir != '':
                # TODO: get a visualization of the performance

                # show plots of the noisy, clean, and cleaned up with the network
                # note that the plotting mechanism itself uses nearest neighbors, so has a form of cleanup built in

                xs = np.linspace(args.limits[0], args.limits[1], 256)
                ys = np.linspace(args.limits[2], args.limits[3], 256)

                heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_sp,
                                                      y_axis_sp)

                noisy_coord = ssp_to_loc_v(noisy, heatmap_vectors, xs, ys)

                pred_coord = ssp_to_loc_v(outputs, heatmap_vectors, xs, ys)

                clean_coord = ssp_to_loc_v(clean, heatmap_vectors, xs, ys)

                fig_noisy_coord, ax_noisy_coord = plt.subplots()
                fig_pred_coord, ax_pred_coord = plt.subplots()
                fig_clean_coord, ax_clean_coord = plt.subplots()

                plot_predictions_v(noisy_coord,
                                   clean_coord,
                                   ax_noisy_coord,
                                   min_val=args.limits[0],
                                   max_val=args.limits[1],
                                   fixed_axes=True)

                plot_predictions_v(pred_coord,
                                   clean_coord,
                                   ax_pred_coord,
                                   min_val=args.limits[0],
                                   max_val=args.limits[1],
                                   fixed_axes=True)

                plot_predictions_v(clean_coord,
                                   clean_coord,
                                   ax_clean_coord,
                                   min_val=args.limits[0],
                                   max_val=args.limits[1],
                                   fixed_axes=True)

                writer.add_figure('{}/original_noise'.format(label),
                                  fig_noisy_coord)
                writer.add_figure('{}/test_set_cleanup'.format(label),
                                  fig_pred_coord)
                writer.add_figure('{}/ground_truth'.format(label),
                                  fig_clean_coord)
                # fig_hist = plot_histogram(predictions=outputs, coords=coord)
                # writer.add_figure('test set histogram', fig_hist)
                writer.add_scalar('{}/test_cosine_loss'.format(label),
                                  cosine_loss.data.item())
                writer.add_scalar('{}/test_mse_loss'.format(label),
                                  mse_loss.data.item())

    # Close tensorboard writer
    if args.logdir != '':
        writer.close()

        torch.save(model.state_dict(), osp.join(save_dir, 'model.pt'))

        params = vars(args)
        # # Additionally save the axis vectors used
        # params['x_axis_vec'] = list(x_axis_sp.v)
        # params['y_axis_vec'] = list(y_axis_sp.v)
        with open(osp.join(save_dir, "params.json"), "w") as f:
            json.dump(params, f)
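
A note on nn.CosineEmbeddingLoss as used in Example #13: its target tensor needs one entry per sample in the batch (hence torch.ones(args.batch_size) in the training loop), not one entry per DataLoader batch, which is why the test-time call above uses the batch dimension rather than len(loader). A minimal self-contained check:

import torch
import torch.nn as nn

cosine_criterion = nn.CosineEmbeddingLoss()
outputs = torch.randn(4, 512)  # (batch, dim) cleaned-up vectors
clean = torch.randn(4, 512)    # (batch, dim) targets
# +1 means "these two rows should point the same way"; one flag per row.
targets = torch.ones(outputs.shape[0])
loss = cosine_criterion(outputs, clean, targets)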
Example #14
    truth = np.zeros((res * res, 2))

    # computing 'predicted' coordinates, where the agent thinks it is
    predictions[:, :] = ssp_to_loc_v(batch_data, hmv, xs, ys)

    truth[:, :] = ssp_to_loc_v(true_ssps, hmv, xs, ys)

    squared_error = np.sum(
        np.linalg.norm(predictions[wall_overlay == False, :] -
                       coords[wall_overlay == False, :],
                       axis=1)**2)

    plot_predictions_v(
        predictions=predictions[wall_overlay == False, :],
        coords=coords[wall_overlay == False, :],
        ax=ax[1, bi],
        min_val=limit_low,
        max_val=limit_high,
        fixed_axes=True,
    )

    plot_predictions_v(
        predictions=truth[wall_overlay == False, :],
        coords=coords[wall_overlay == False, :],
        ax=ax[0, bi],
        min_val=limit_low,
        max_val=limit_high,
        fixed_axes=True,
    )

    ax[0, bi].set_axis_off()
    ax[1, bi].set_axis_off()
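
Taken together, the excerpts imply the call signature below for plot_predictions_v. This is inferred from usage only, not from documented source, so treat both the signature and the synthetic data as assumptions:

import numpy as np
import matplotlib.pyplot as plt

n = 100
coords = np.random.uniform(-1, 1, size=(n, 2))             # true (x, y) locations
predictions = coords + np.random.normal(0, 0.05, (n, 2))   # noisy decodes
fig, ax = plt.subplots()
# Inferred signature: (predictions, coords, ax, min_val, max_val, fixed_axes)
plot_predictions_v(predictions, coords, ax,
                   min_val=-1.1, max_val=1.1, fixed_axes=False)
plt.show()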