Example #1
0
    def apply_config(self):
        """Unpack ``self.config`` and rebuild the axes, grids and heatmap vectors.

        ``self.config`` is read positionally as
        (seed, dim, res, limits, x_spacing, y_spacing, x_offset, y_offset).
        The axis vectors come from ``angle_spacing_axes`` built from the
        spacing/offset values, and the heatmap covers a ``res`` x ``res`` grid
        spanning ``[-limits, limits]`` on both axes.
        """
        field_names = ('seed', 'dim', 'res', 'limits',
                       'x_spacing', 'y_spacing', 'x_offset', 'y_offset')
        for position, field in enumerate(field_names):
            setattr(self, field, self.config[position])

        bound = self.limits
        self.xs = np.linspace(-bound, bound, self.res)
        self.ys = np.linspace(-bound, bound, self.res)

        # NOTE: no rng is passed to angle_spacing_axes, so self.seed does not
        # influence the axes built here.
        axes = angle_spacing_axes(
            ang_x=self.x_spacing,
            ang_y=self.y_spacing,
            off_x=self.x_offset,
            off_y=self.y_offset,
            dim=self.dim,
        )
        self.x_axis_vec, self.y_axis_vec = axes

        self.heatmap_vectors = get_heatmap_vectors(
            xs=self.xs,
            ys=self.ys,
            x_axis_sp=self.x_axis_vec,
            y_axis_sp=self.y_axis_vec,
        )

        # Origin point for similarity measures: [1, 0, ..., 0], the identity
        # element of circular convolution.
        origin = np.zeros((self.dim, ))
        origin[0] = 1
        self.origin = origin
Example #2
0
def test_generate_maze_sp(size=10,
                          limit_low=-5,
                          limit_high=5,
                          res=64,
                          dim=512,
                          seed=13):
    """Visually check ``generate_maze_sp`` on a randomly generated maze.

    Shows four panels side by side:
    - the coarse maze;
    - the fine maze;
    - the returned maze SSP's heatmap with a fixed [-1, 1] colour range;
    - the same heatmap with automatic colour limits.
    Axes come from ``make_good_unitary`` seeded by ``seed``.
    """
    from spatial_semantic_pointers.utils import make_good_unitary, get_heatmap_vectors
    from spatial_semantic_pointers.plots import plot_heatmap

    rng = np.random.RandomState(seed=seed)
    # X is drawn before Y from the same generator.
    x_axis_sp = make_good_unitary(dim=dim, rng=rng)
    y_axis_sp = make_good_unitary(dim=dim, rng=rng)

    grid_args = (limit_low, limit_high, res)
    xs = np.linspace(*grid_args)
    ys = np.linspace(*grid_args)

    sp, maze, fine_maze = generate_maze_sp(
        size, xs, ys, x_axis_sp, y_axis_sp,
        normalize=True,
        obstacle_ratio=.2,
        map_style='blocks',
    )

    _, panels = plt.subplots(1, 4)
    panels[0].imshow(maze)
    panels[1].imshow(fine_maze)

    heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp)
    # Same SSP heatmap twice: fixed [-1, 1] colour range, then autoscaled.
    for panel, (lo, hi) in zip(panels[2:], ((-1, 1), (None, None))):
        plot_heatmap(sp.v, heatmap_vectors, panel, xs, ys,
                     name='', vmin=lo, vmax=hi, cmap='plasma', invert=True)

    plt.show()
Example #3
0
    def apply_config(self):
        """Unpack ``self.config`` and rebuild the periodic axes, grids and heatmap.

        ``self.config`` is read positionally as (seed, dim, res, spacing, limits).
        The axes come from ``make_periodic_axes`` with axis angles [0, 120, 240],
        using a RandomState seeded with ``self.seed``. The heatmap covers a
        ``res`` x ``res`` grid spanning ``[-limits, limits]`` on both axes.
        """
        field_names = ('seed', 'dim', 'res', 'spacing', 'limits')
        for position, field in enumerate(field_names):
            setattr(self, field, self.config[position])

        bound = self.limits
        self.xs = np.linspace(-bound, bound, self.res)
        self.ys = np.linspace(-bound, bound, self.res)

        axis_rng = np.random.RandomState(seed=self.seed)
        axes = make_periodic_axes(
            rng=axis_rng,
            dim=self.dim,
            spacing=self.spacing,
            axis_angles=[0, 120, 240],
        )
        self.x_axis_vec, self.y_axis_vec = axes

        self.heatmap_vectors = get_heatmap_vectors(
            xs=self.xs,
            ys=self.ys,
            x_axis_sp=self.x_axis_vec,
            y_axis_sp=self.y_axis_vec,
        )

        # Origin point for similarity measures: [1, 0, ..., 0], the identity
        # element of circular convolution.
        origin = np.zeros((self.dim, ))
        origin[0] = 1
        self.origin = origin
Example #4
0
def _normalize_rows(vecs):
    """Scale every row of the 2D array ``vecs`` to unit L2 norm, in place.

    Returns ``vecs`` so the call can be chained. A zero row turns into NaNs
    (numpy emits a RuntimeWarning), exactly like an explicit per-row division.
    """
    for row in vecs:
        row /= np.linalg.norm(row)
    return vecs


def _closest_row(query, candidates):
    """Return the index of the row of ``candidates`` with the largest dot product with ``query``."""
    return np.argmax([np.dot(query, row) for row in candidates])


def _item_at_location(item_vec, location, X, Y):
    """Return the ``.v`` data of ``item_vec`` bound (``*``) to the SSP of ``location``.

    ``location`` is an (x, y) pair encoded with ``encode_point`` on axes ``X``/``Y``.
    """
    return (spa.SemanticPointer(data=item_vec) *
            encode_point(location[0], location[1], X, Y)).v


def _score_estimates(estims, locations, X, Y, hmv, xs, ys, thresh):
    """Compare retrieved location vectors against the ground-truth locations.

    Only the first ``len(estims)`` rows of ``locations`` are used. The neural
    runs evaluate at most ``max_items`` items, and all three metrics are
    computed over exactly those items.

    Returns:
        (rmse, accuracy, mean_sim):
        - rmse: RMS of the Euclidean decoding errors;
        - accuracy: fraction of evaluated items decoded within ``thresh``;
        - mean_sim: mean dot product between each estimate and its true
          location encoding.
    """
    n_eval = estims.shape[0]

    sims = np.zeros((n_eval, ))
    for i in range(n_eval):
        sims[i] = np.dot(
            estims[i, :],
            encode_point(locations[i, 0], locations[i, 1], X, Y).v)

    pred_locs = ssp_to_loc_v(estims, hmv, xs, ys)
    errors = np.linalg.norm(pred_locs - locations[:n_eval, :], axis=1)

    # Divide by the number of items actually evaluated, not n_items.
    accuracy = np.count_nonzero(errors < thresh) / n_eval
    rmse = np.sqrt(np.mean(errors**2))
    return rmse, accuracy, np.mean(sims)


def _simulate_and_average(model, probe, n_exp_items, time_per_item, offset=100):
    """Simulate ``model`` for one ``time_per_item`` window per item; average the probe.

    The first ``offset`` timesteps of each window are discarded to cancel the
    transient after the input switches. The simulator is closed on exit.

    Returns:
        Array with one averaged probe output vector per item (``n_exp_items`` rows).
    """
    with nengo.Simulator(model) as sim:
        sim.run(n_exp_items * time_per_item)
        output_data = sim.data[probe]
        timesteps_per_item = int(time_per_item / sim.dt)

    estims = np.zeros((n_exp_items, output_data.shape[1]))
    for i in range(n_exp_items):
        start = i * timesteps_per_item + offset
        stop = (i + 1) * timesteps_per_item
        estims[i, :] = output_data[start:stop, :].mean(axis=0)
    return estims


def _flat_estimates(item_vecs, locations, X, Y, n_items, dim, seed, neural,
                    neurons_per_dim, time_per_item, max_items):
    """Hierarchy 1: a single memory holding every item bound to its location.

    Returns the retrieved location vectors. The array has ``n_items`` rows,
    or ``min(n_items, max_items)`` rows when ``neural`` is set.
    """
    # Encode items into memory
    mem = np.zeros((dim, ))
    for i in range(n_items):
        mem += _item_at_location(item_vecs[i, :], locations[i], X, Y)
    mem /= np.linalg.norm(mem)

    if neural:
        # save time for very large numbers of items
        n_exp_items = min(n_items, max_items)

        model = nengo.Network(seed=seed)
        with model:
            # Item k is presented during [k * time_per_item, (k + 1) * time_per_item),
            # matching the readout windows in _simulate_and_average.
            input_node = nengo.Node(
                lambda t: item_vecs[int(np.floor(t / time_per_item)) % n_items, :],
                size_in=0,
                size_out=dim)
            mem_node = nengo.Node(mem, size_in=0, size_out=dim)

            # Neural counterpart of ``mem_sp * ~item`` in the non-neural path.
            cconv = nengo.networks.CircularConvolution(
                n_neurons=neurons_per_dim, dimensions=dim, invert_b=True)
            nengo.Connection(mem_node, cconv.input_a)
            nengo.Connection(input_node, cconv.input_b)

            out_node = nengo.Node(size_in=dim, size_out=0)
            nengo.Connection(cconv.output, out_node)
            p_out = nengo.Probe(out_node, synapse=0.01)

        return _simulate_and_average(model, p_out, n_exp_items, time_per_item)

    # retrieve items
    mem_sp = spa.SemanticPointer(data=mem)
    estims = np.zeros((n_items, dim))
    for i in range(n_items):
        estims[i, :] = (mem_sp * ~spa.SemanticPointer(data=item_vecs[i, :])).v
    return estims


def _two_level_estimates(item_vecs, locations, X, Y, n_items, dim, rng):
    """Hierarchy 2: memory of ID vectors, each mapped to a sum of item*location bindings.

    Returns an (n_items, dim) array of retrieved location vectors.
    """
    n_ids = int(np.sqrt(n_items))
    id_vecs = _normalize_rows(rng.normal(size=(n_ids, dim)))

    # items to be included in each ID vec
    item_sums = np.zeros((n_ids, dim))
    item_loc_sums = np.zeros((n_ids, dim))
    for i in range(n_items):
        id_ind = min(i // n_ids, n_ids - 1)
        item_sums[id_ind, :] += item_vecs[i, :]
        item_loc_sums[id_ind, :] += _item_at_location(
            item_vecs[i, :], locations[i], X, Y)
    _normalize_rows(item_sums)
    _normalize_rows(item_loc_sums)

    # Encode id_vecs into memory; each ID is bound to something similar to all
    # items in that ID's map.
    mem = np.zeros((dim, ))
    for i in range(n_ids):
        mem += (spa.SemanticPointer(data=id_vecs[i, :]) *
                spa.SemanticPointer(data=item_sums[i, :])).v
    mem /= np.linalg.norm(mem)
    mem_sp = spa.SemanticPointer(data=mem)

    estims = np.zeros((n_items, dim))
    for i in range(n_items):
        item_inv = ~spa.SemanticPointer(data=item_vecs[i, :])
        # noisy ID of the map containing this item -> closest clean ID
        best_ind = _closest_row((mem_sp * item_inv).v, id_vecs)
        # item_loc_sums acts as the associative mapping from the clean ID
        estims[i, :] = (spa.SemanticPointer(data=item_loc_sums[best_ind, :]) *
                        item_inv).v
    return estims


def _three_level_estimates(item_vecs, locations, X, Y, n_items, dim, rng):
    """Hierarchy 3: outer-ID memory -> per-outer inner-ID memory -> item locations.

    Returns an (n_items, dim) array of retrieved location vectors.
    """
    f_n_ids = np.cbrt(n_items)
    n_ids = int(np.ceil(np.cbrt(n_items)))
    n_ids_inner = int(np.ceil(np.sqrt(n_items / n_ids)))

    # Outer IDs are drawn before inner IDs; keep this order for reproducibility.
    id_outer_vecs = _normalize_rows(rng.normal(size=(n_ids, dim)))
    id_inner_vecs = _normalize_rows(rng.normal(size=(n_ids_inner, dim)))

    # items to be included in each ID vec
    item_outer_sums = np.zeros((n_ids, dim))
    item_inner_sums = np.zeros((n_ids_inner, dim))
    item_loc_inner_sums = np.zeros((n_ids_inner, dim))
    for i in range(n_items):
        id_outer_ind = min(int(i / (f_n_ids * f_n_ids)), n_ids - 1)
        id_inner_ind = min(int(i / f_n_ids), n_ids_inner - 1)
        item_outer_sums[id_outer_ind, :] += item_vecs[i, :]
        item_inner_sums[id_inner_ind, :] += item_vecs[i, :]
        item_loc_inner_sums[id_inner_ind, :] += _item_at_location(
            item_vecs[i, :], locations[i], X, Y)
    _normalize_rows(item_outer_sums)
    _normalize_rows(item_inner_sums)
    _normalize_rows(item_loc_inner_sums)

    mem_outer = np.zeros((dim, ))
    for i in range(n_ids):
        mem_outer += (spa.SemanticPointer(data=id_outer_vecs[i, :]) *
                      spa.SemanticPointer(data=item_outer_sums[i, :])).v
    mem_outer /= np.linalg.norm(mem_outer)

    mem_inner = np.zeros((n_ids, dim))
    for j in range(n_ids_inner):
        i = min(int(j / n_ids), n_ids - 1)
        mem_inner[i, :] += (spa.SemanticPointer(data=id_inner_vecs[j, :]) *
                            spa.SemanticPointer(data=item_inner_sums[j, :])).v
        # NOTE(review): row i is renormalized after every binding added to it,
        # so earlier bindings in a row end up with less weight than later ones.
        # Confirm whether a single normalization after the loop was intended.
        mem_inner[i, :] /= np.linalg.norm(mem_inner[i, :])

    mem_outer_sp = spa.SemanticPointer(data=mem_outer)
    estims = np.zeros((n_items, dim))
    for i in range(n_items):
        item_inv = ~spa.SemanticPointer(data=item_vecs[i, :])
        # noisy outer ID for the map with this item -> closest clean match
        outer_ind = _closest_row((mem_outer_sp * item_inv).v, id_outer_vecs)
        # noisy inner ID within that outer map -> closest clean match
        inner_ind = _closest_row(
            (spa.SemanticPointer(data=mem_inner[outer_ind, :]) * item_inv).v,
            id_inner_vecs)
        estims[i, :] = (
            spa.SemanticPointer(data=item_loc_inner_sums[inner_ind, :]) *
            item_inv).v
    return estims


def _four_split_estimates(locations, X, Y, n_items, dim, seed, neural,
                          neurons_per_dim, time_per_item, max_items):
    """Any other ``n_hierarchy``: a "4-split" hierarchy built in a nengo_spa vocabulary.

    Each ``Item{k}`` is a normalized sum of LevelSlot * LevelFillerID bindings.
    Each second-last-level ``LevelFiller`` holds the ``ItemID*Loc`` pairs of
    four items.

    Retrieval steps:
    1. ``Item * ~LevelSlot{n_levels - 2}`` gives a noisy filler ID.
    2. Clean it up: argmax over the filler IDs, or a heteroassociative
       ThresholdingAssocMem in the neural model.
    3. ``filler * ~ItemID`` gives the location estimate.

    Not every ``n_items`` works. Only ``4 * (n_items // 4)`` items get an
    ``Item`` pointer, so other sizes fail on a missing key. TODO: confirm the
    intended valid sizes.

    Returns the retrieved location vectors. The array has ``n_items`` rows,
    or ``min(n_items, max_items)`` rows when ``neural`` is set.

    Raises:
        ValueError: if the hierarchy has fewer than two levels (``n_items <= 4``).
    """
    items_left = n_items
    n_levels = 0
    while items_left > 1:
        n_levels += 1
        items_left /= 4

    print(n_levels)

    if n_levels < 2:
        # 'LevelSlot{n_levels - 2}' below would name a slot that never exists.
        raise ValueError(
            'the 4-split hierarchy needs at least two levels (n_items > 4), '
            'got n_items={}'.format(n_items))

    vocab = spa.Vocabulary(dimensions=dim,
                           pointer_gen=np.random.RandomState(seed=seed))
    filler_id_keys = []
    filler_keys = []
    mapping = {}

    # Location Values, labelled SSP
    for i in range(n_items):
        vocab.add('Loc{}'.format(i),
                  encode_point(locations[i, 0], locations[i, 1], X, Y).v)

    # level IDs, e.g. CITY, PROVINCE, COUNTRY
    for i in range(n_levels):
        vocab.populate('LevelSlot{}.unitary()'.format(i))

    # Item IDs, e.g. Waterloo_ID
    for i in range(n_items):
        vocab.populate('ItemID{}.unitary()'.format(i))

    # level labels (fillers for level ID slots), e.g. Waterloo_ID, Ontario_ID, Canada_ID
    for i in range(n_levels):
        for j in range(int(n_items / (4**(n_levels - i - 1)))):
            vocab.populate('LevelFillerID{}_{}.unitary()'.format(i, j))

    # Second last level with item*location pairs
    for i in range(int(n_items / 4)):
        id_str = [
            'LevelSlot{} * LevelFillerID{}_{}'.format(
                k, k, int(i * 4 / (4**(n_levels - k - 1))))
            for k in range(n_levels - 1)
        ]

        data_str = []
        for j in range(4):
            ind = i * 4 + j
            data_str.append('ItemID{}*Loc{}'.format(ind, ind))
            vocab.populate('Item{} = ({}).normalized()'.format(
                ind,
                ' + '.join(id_str + [
                    'LevelSlot{} * LevelFillerID{}_{}'.format(
                        n_levels - 1, n_levels - 1, j)
                ])))

        vocab.populate('LevelFiller{}_{} = ({}).normalized()'.format(
            n_levels - 2, i, ' + '.join(data_str)))

        # only appending the ones used
        filler_id_key = 'LevelFillerID{}_{}'.format(n_levels - 2, i)
        filler_key = 'LevelFiller{}_{}'.format(n_levels - 2, i)
        filler_id_keys.append(filler_id_key)
        filler_keys.append(filler_key)
        mapping[filler_id_key] = filler_key

    print(sorted(list(vocab.keys())))

    if neural:
        # save time for very large numbers of items
        n_exp_items = min(n_items, max_items)

        filler_id_vocab = vocab.create_subset(keys=filler_id_keys)
        filler_vocab = vocab.create_subset(keys=filler_keys)

        def item_index(t):
            # Item k is presented during [k * time_per_item, (k + 1) * time_per_item),
            # matching the readout windows in _simulate_and_average.
            return int(np.floor(t / time_per_item)) % n_items

        model = nengo.Network(seed=seed)
        with model:
            # The changing item query. Full expanded item, not just ID
            item_input_node = nengo.Node(
                lambda t: vocab['Item{}'.format(item_index(t))].v,
                size_in=0,
                size_out=dim)

            # The ID for the changing item query
            item_id_input_node = nengo.Node(
                lambda t: vocab['ItemID{}'.format(item_index(t))].v,
                size_in=0,
                size_out=dim)

            # Fixed memory based on the level slot to access
            level_slot_input_node = nengo.Node(
                lambda t: vocab['LevelSlot{}'.format(n_levels - 2)].v,
                size_in=0,
                size_out=dim)

            model.cconv_noisy_level_filler = nengo.networks.CircularConvolution(
                n_neurons=neurons_per_dim * 2,
                dimensions=dim,
                invert_b=True)
            nengo.Connection(item_input_node,
                             model.cconv_noisy_level_filler.input_a)
            nengo.Connection(level_slot_input_node,
                             model.cconv_noisy_level_filler.input_b)

            # Note: this is set up as heteroassociative between ID and the content (should clean up as well)
            model.noisy_level_filler_id_cleanup = spa.ThresholdingAssocMem(
                threshold=0.4,
                input_vocab=filler_id_vocab,
                output_vocab=filler_vocab,
                mapping=mapping,
                function=lambda x: x > 0.)
            nengo.Connection(model.cconv_noisy_level_filler.output,
                             model.noisy_level_filler_id_cleanup.input)

            model.cconv_location = nengo.networks.CircularConvolution(
                n_neurons=neurons_per_dim * 2,
                dimensions=dim,
                invert_b=True)
            nengo.Connection(model.noisy_level_filler_id_cleanup.output,
                             model.cconv_location.input_a)
            nengo.Connection(item_id_input_node,
                             model.cconv_location.input_b)

            out_node = nengo.Node(size_in=dim, size_out=0)
            nengo.Connection(model.cconv_location.output, out_node)
            p_out = nengo.Probe(out_node, synapse=0.01)

        return _simulate_and_average(model, p_out, n_exp_items, time_per_item)

    # non-neural version
    n_fillers = int(n_items / 4)
    possible_level_filler_id_vecs = np.zeros((n_fillers, dim))
    for i in range(n_fillers):
        possible_level_filler_id_vecs[i] = vocab['LevelFillerID{}_{}'.format(
            n_levels - 2, i)].v

    level_slot_inv = ~vocab['LevelSlot{}'.format(n_levels - 2)]
    estims = np.zeros((n_items, dim))
    for i in range(n_items):
        noisy_level_filler_id = vocab['Item{}'.format(i)] * level_slot_inv
        # cleanup filler id
        filler_id_ind = _closest_row(noisy_level_filler_id.v,
                                     possible_level_filler_id_vecs)
        # query the appropriate filler
        loc_estim = vocab['LevelFiller{}_{}'.format(
            n_levels - 2, filler_id_ind)] * ~vocab['ItemID{}'.format(i)]
        estims[i, :] = loc_estim.v
    return estims


def experiment(dim=512,
               n_hierarchy=3,
               n_items=16,
               seed=0,
               limit=5,
               res=128,
               thresh=0.5,
               neural=False,
               neurons_per_dim=25,
               time_per_item=1.0,
               max_items=100):
    """Measure how well item locations are recovered from an SSP-based memory.

    ``n_items`` random unit vectors are placed at uniformly random points in
    ``[-limit, limit]^2``. Each item is bound to the SSP encoding of its point
    (axes from ``get_fixed_dim_sub_toriod_axes``) and stored in a memory whose
    structure depends on ``n_hierarchy``. Each item's location vector is then
    retrieved by unbinding and decoded with ``ssp_to_loc_v`` against a
    ``res`` x ``res`` grid of heatmap vectors.

    Args:
        dim: dimensionality of every vector.
        n_hierarchy: memory structure. 1 = one flat memory, 2 = two-level ID
            hierarchy, 3 = three-level ID hierarchy, any other value = the
            nengo_spa "4-split" hierarchy.
        n_items: number of stored items.
        seed: seed for all random draws (axes, items, locations, IDs, vocab,
            nengo network).
        limit: half-width of the square used for locations and decoding grid.
        res: number of decoding grid points per axis.
        thresh: a retrieval counts as correct when its decoding error is below this.
        neural: retrieve with a nengo model instead of exact vector algebra.
            Supported for ``n_hierarchy == 1`` and the 4-split case; ignored
            for ``n_hierarchy == 2``.
        neurons_per_dim: passed as ``n_neurons`` to the circular convolution
            networks (doubled in the 4-split model).
        time_per_item: simulated seconds each item is presented (neural only).
        max_items: cap on the number of items evaluated in a neural run.

    Returns:
        (rmse, accuracy, sim):
        - rmse: RMS Euclidean decoding error;
        - accuracy: fraction of evaluated items with error below ``thresh``;
        - sim: mean dot product between each retrieved vector and its true
          location encoding.

    Raises:
        NotImplementedError: ``neural=True`` with ``n_hierarchy == 3``.
        ValueError: 4-split hierarchy with ``n_items <= 4``.
    """
    rng = np.random.RandomState(seed=seed)

    X, Y = get_fixed_dim_sub_toriod_axes(
        dim=dim,
        n_proj=3,
        scale_ratio=0,
        scale_start_index=0,
        rng=rng,
        eps=0.001,
    )

    xs = np.linspace(-limit, limit, res)
    ys = np.linspace(-limit, limit, res)
    hmv = get_heatmap_vectors(xs, ys, X, Y)

    # The draw order from ``rng`` (axes, items, locations, then any IDs inside
    # the hierarchy helpers) determines the experiment's results.
    item_vecs = _normalize_rows(rng.normal(size=(n_items, dim)))
    locations = rng.uniform(low=-limit, high=limit, size=(n_items, 2))

    if n_hierarchy == 1:
        estims = _flat_estimates(item_vecs, locations, X, Y, n_items, dim,
                                 seed, neural, neurons_per_dim, time_per_item,
                                 max_items)
    elif n_hierarchy == 2:
        # NOTE(review): ``neural`` is ignored here; only the algebraic version exists.
        estims = _two_level_estimates(item_vecs, locations, X, Y, n_items, dim,
                                      rng)
    elif n_hierarchy == 3:
        if neural:
            raise NotImplementedError(
                'neural retrieval is not implemented for n_hierarchy=3')
        estims = _three_level_estimates(item_vecs, locations, X, Y, n_items,
                                        dim, rng)
    else:
        estims = _four_split_estimates(locations, X, Y, n_items, dim, seed,
                                       neural, neurons_per_dim, time_per_item,
                                       max_items)

    rmse, accuracy, sim = _score_estimates(estims, locations, X, Y, hmv, xs,
                                           ys, thresh)
    return rmse, accuracy, sim
Example #5
0
# Sizes read from the loaded dataset: the sensor count is axis 3 of
# dist_sensors and the maze count is axis 0 of coarse_mazes.
n_sensors = data['dist_sensors'].shape[3]
n_mazes = data['coarse_mazes'].shape[0]

# Output dimensionality depends on how positions are encoded:
# 'ssp' -> 512-d vectors, '2d' -> raw (x, y), 'pc' -> args.n_place_cells
# (presumably one unit per place cell). Any other value is unsupported.
if args.encoding == 'ssp':
    dim = 512
elif args.encoding == '2d':
    dim = 2
    # ssp_scaling = 1  # no scaling used for 2D coordinates directly
elif args.encoding == 'pc':
    dim = args.n_place_cells
    # ssp_scaling = 1
else:
    raise NotImplementedError

# Used for visualization of test set performance using pos = ssp_to_loc(sp, heatmap_vectors, xs, ys)
heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_vec, y_axis_vec)

# Training settings copied from ``args``.
n_samples = args.n_samples
batch_size = args.batch_size
n_epochs = args.n_epochs

# Input is the distance sensor measurements
if args.n_hidden_layers > 1:
    model = MLP(input_size=n_sensors + n_mazes,
                hidden_size=args.hidden_size,
                output_size=dim,
                n_layers=args.n_hidden_layers)
else:
    model = FeedForward(
        input_size=n_sensors + n_mazes,
        hidden_size=args.hidden_size,
    for seed in range(args.n_seeds):
        print("\x1b[2K\r Seed {} of {}".format(seed + 1, args.n_seeds),
              end="\r")
        rng = np.random.RandomState(seed=seed)
        X = make_good_unitary(args.dim, rng=rng)
        Y = make_good_unitary(args.dim, rng=rng)
        Z = make_good_unitary(args.dim, rng=rng)

        axes[seed, 0, :] = X.v
        axes[seed, 1, :] = Y.v
        axes[seed, 2, :] = Z.v

        square_heatmaps[seed, :, :] = np.tensordot(origin_vec,
                                                   get_heatmap_vectors(
                                                       xs=xs,
                                                       ys=ys,
                                                       x_axis_sp=X,
                                                       y_axis_sp=Y),
                                                   axes=([0], [2]))

        hex_heatmaps[seed, :, :] = np.tensordot(origin_vec,
                                                get_heatmap_vectors_hex(
                                                    xs=xs,
                                                    ys=ys,
                                                    x_axis_sp=X,
                                                    y_axis_sp=Y,
                                                    z_axis_sp=Z),
                                                axes=([0], [2]))

    np.savez(
        fname,
# Axis dimensionality (512 and 16 were also tried).
dim = 2048
X, Y = get_fixed_dim_sub_toriod_axes(
    dim=dim,
    n_proj=3,
    scale_ratio=0,
    scale_start_index=0,
    rng=np.random.RandomState(seed=13),
    eps=0.001,
)
# Alternative axes: make_good_unitary(dim=dim) for X and for Y.

res = 256  # 128 also used
limit = 25  # 15 and 5 also used
xs = np.linspace(-limit, limit, res)
ys = np.linspace(-limit, limit, res)
heatmap_vectors = get_heatmap_vectors(xs, ys, X, Y)
print(heatmap_vectors.shape)

# NOTE(review): n_rows, n_cols and ax are not defined in this snippet. They
# presumably come from an earlier plt.subplots(n_rows, n_cols) call; confirm.
# Panel (r, c) shows slice r * n_cols + c of the last axis of heatmap_vectors.
for r in range(n_rows):
    for c in range(n_cols):

        ax[r, c].imshow(heatmap_vectors[:, :, r * n_cols + c])
        ax[r, c].set_axis_off()

plt.show()
dim = 512
# dim = 1024

# Node positions lie on a res x res grid over [0, 10] x [0, 10].
res = 128
xs = np.linspace(0, 10, res)
ys = np.linspace(0, 10, res)

# These will include the space that is the difference between any two nodes:
# differences of points in [0, 10] fall in [-10, 10].
xs_larger = np.linspace(-10, 10, res)
ys_larger = np.linspace(-10, 10, res)

# NOTE(review): no rng is passed here, so unless make_good_unitary seeds
# itself these axes differ between runs. Confirm whether that is intended.
x_axis_sp = make_good_unitary(dim)
y_axis_sp = make_good_unitary(dim)

heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp)
heatmap_vectors_larger = get_heatmap_vectors(xs_larger, ys_larger, x_axis_sp,
                                             y_axis_sp)

# The memories below all start as the zero vector (presumably accumulated
# into later in the script).

# Map
map_sp = spa.SemanticPointer(data=np.zeros((dim, )))
# version of the map with landmark IDs bound to each location
landmark_map_sp = spa.SemanticPointer(data=np.zeros((dim, )))

# Connectivity
# contains each connection egocentrically
con_ego_sp = spa.SemanticPointer(data=np.zeros((dim, )))
# contains each connection allocentrically
con_allo_sp = spa.SemanticPointer(data=np.zeros((dim, )))

# Agent Location
Example #9
0
        np.savez(cache_fname,
                 clean_vectors=clean_vectors,
                 noisy_vectors=noisy_vectors)


def angle_to_ssp(x):
    """Return the ``.v`` array of the SSP for the unit-circle point at angle ``x``.

    The point ``(cos x, sin x)`` is encoded with the module-level axes ``X``
    and ``Y``. ``x`` is in radians, as required by ``np.cos``/``np.sin``.
    """
    point_x, point_y = np.cos(x), np.sin(x)
    return encode_point(point_x, point_y, X, Y).v


def to_ssp(v):
    """Return the ``.v`` array of the SSP for the 2D point ``(v[0], v[1])``.

    Uses the module-level axes ``X`` and ``Y``. Elements of ``v`` beyond the
    first two are ignored.
    """
    first, second = v[0], v[1]
    return encode_point(first, second, X, Y).v


# Heatmap vectors over the main (xs, ys) grid.
heatmap_vectors = get_heatmap_vectors(xs, ys, X, Y)

# for direction SSP visualization: a coarse 32 x 32 grid over [-1.5, 1.5]^2,
# which contains the unit-circle points (cos x, sin x) that angle_to_ssp encodes.
xs_dir = np.linspace(-1.5, 1.5, 32)
ys_dir = np.linspace(-1.5, 1.5, 32)
heatmap_vectors_dir = get_heatmap_vectors(xs_dir, ys_dir, X, Y)

if not use_spa:

    # spatial_heatmap = SpatialHeatmap(heatmap_vectors, xs, ys, cmap='plasma', vmin=None, vmax=None)

    # preferred_locations = hilbert_2d(-limit, limit, n_neurons, rng, p=8, N=2, normal_std=3)
    preferred_locations = hilbert_2d(limit_low,
                                     limit_high,
                                     n_neurons,
                                     rng,
Example #10
0
    phi_ays = [0., 0., -np.pi / 4., -np.pi / 6]
    phi_bys = [np.pi / 2., np.pi / 4., np.pi / 4., np.pi / 4]
    titles = [
        'X = [\u03C0/2, 0], Y = [0, \u03C0/2]',
        'X = [\u03C0/4, 0], Y = [0, \u03C0/4]',
        'X = [\u03C0/4, \u03C0/4], Y = [-\u03C0/4, \u03C0/4]',
        'X = [\u03C0/4, \u03C0/2], Y = [-\u03C0/6, \u03C0/4]',
    ]
    fontsize = 14
    angles = [0, 0, np.pi / 4., np.pi / 6.]
    loc = plticker.MultipleLocator(
        base=5)  # this locator puts ticks at regular intervals
    for i in range(4):
        X = unitary_5d(dim=dim, phi_a=phi_axs[i], phi_b=phi_bxs[i])
        Y = unitary_5d(dim=dim, phi_a=phi_ays[i], phi_b=phi_bys[i])
        hmv = get_heatmap_vectors(xs, ys, X, Y)
        im = ax[i].imshow(hmv[:, :, 0],
                          origin='lower',
                          interpolation='none',
                          extent=(xs[0], xs[-1], ys[0], ys[-1]),
                          vmin=None,
                          vmax=1)
        ax[i].set_title(titles[i], fontsize=fontsize)
        ax[i].xaxis.set_major_locator(loc)
        ax[i].yaxis.set_major_locator(loc)

    fig.colorbar(im, cax=ax[-1])
elif dim == 7:
    titles = [
        '\u03C6 = \u03C0/2, \u03B8 = 0',
        '\u03C6 = \u03C0/4, \u03B8 = 0',
def snapshot_localization_train_test_loaders(data,
                                             n_train_samples=1000,
                                             n_test_samples=1000,
                                             batch_size=10,
                                             encoding='ssp',
                                             n_mazes_to_use=0):
    """Build train/test DataLoaders of single-snapshot localization samples.

    Each sample picks a random maze and a random grid cell (xi, yi) that is not
    a wall (``fine_mazes[maze, xi, yi] != 1``). The sample's inputs are the
    distance-sensor readings at that cell plus a one-hot maze ID. Its target is
    either the SSP for that cell (``encoding='ssp'``, taken from the heatmap
    vectors) or the raw ``(xs[xi], ys[yi])`` coordinates (``encoding='2d'``).

    Args:
        data: mapping (e.g. a loaded ``.npz``) providing ``xs``, ``ys``,
            ``x_axis_sp``, ``y_axis_sp``, ``dist_sensors``, ``fine_mazes`` and
            ``coarse_mazes``.
        n_train_samples: number of samples in the training set.
        n_test_samples: number of samples in the test set.
        batch_size: batch size of the training loader. The test loader
            returns its whole set as a single batch.
        encoding: ``'ssp'`` or ``'2d'``; selects the regression target.
        n_mazes_to_use: if > 0, only the first ``n_mazes_to_use`` mazes are
            sampled from (clamped to the number of mazes available).
            Otherwise all mazes are used.

    Returns:
        ``(trainloader, testloader)``.
    """

    # Option to use SSPs or the 2D location directly
    assert encoding in ['ssp', '2d']

    xs = data['xs']
    ys = data['ys']
    x_axis_vec = data['x_axis_sp']
    y_axis_vec = data['y_axis_sp']
    # heatmap_vectors[xi, yi, :] is the SSP target for grid cell (xi, yi)
    heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_vec, y_axis_vec)

    # shape is (n_mazes, res, res, n_sensors)
    dist_sensors = data['dist_sensors']

    fine_mazes = data['fine_mazes']

    n_sensors = dist_sensors.shape[3]

    n_mazes = data['coarse_mazes'].shape[0]
    dim = x_axis_vec.shape[0]

    # Number of mazes that positions are drawn from. Clamping prevents asking
    # for more mazes than exist, which would produce an out-of-range maze index.
    if n_mazes_to_use <= 0:
        n_sample_mazes = n_mazes
    else:
        n_sample_mazes = min(n_mazes_to_use, n_mazes)

    for test_set, n_samples in enumerate([n_train_samples, n_test_samples]):

        sensor_inputs = np.zeros((n_samples, n_sensors))

        # SSP target of each sample
        ssp_outputs = np.zeros((n_samples, dim))

        # (x, y) target of each sample, for the 2D encoding method
        pos_outputs = np.zeros((n_samples, 2))

        # one-hot maze ID of each sample (width is always the total maze count)
        maze_ids = np.zeros((n_samples, n_mazes))

        for i in range(n_samples):
            # choose random maze and position in maze
            maze_ind = np.random.randint(low=0, high=n_sample_mazes)
            xi = np.random.randint(low=0, high=len(xs))
            yi = np.random.randint(low=0, high=len(ys))
            # Keep choosing position until it is not inside a wall.
            # NOTE: this never terminates if the chosen maze has no open cell.
            while fine_mazes[maze_ind, xi, yi] == 1:
                xi = np.random.randint(low=0, high=len(xs))
                yi = np.random.randint(low=0, high=len(ys))

            sensor_inputs[i, :] = dist_sensors[maze_ind, xi, yi, :]

            ssp_outputs[i, :] = heatmap_vectors[xi, yi, :]

            maze_ids[i, maze_ind] = 1

            pos_outputs[i, :] = np.array([xs[xi], ys[yi]])

        # The dataset field is called ssp_outputs for both encodings; for '2d'
        # it holds the raw coordinates instead.
        targets = ssp_outputs if encoding == 'ssp' else pos_outputs
        dataset = LocalizationSnapshotDataset(
            sensor_inputs=sensor_inputs,
            maze_ids=maze_ids,
            ssp_outputs=targets,
        )

        if test_set == 0:
            trainloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=0,
            )
        else:
            # The whole test set is a single batch
            testloader = torch.utils.data.DataLoader(
                dataset,
                batch_size=n_samples,
                shuffle=True,
                num_workers=0,
            )

    return trainloader, testloader
def main():
    """Evaluate SSP elliptic-expansion path finding and report summary metrics.

    Command-line options set the number of trials, the random seed, the SSP
    dimensionality, the grid resolution, whether semantic pointers are
    normalized, and the ellipse diameter increment. Each trial compares the
    path returned by ``EllipticExpansion.find_path`` with the one returned by
    ``graph.search_graph``. At the end, the fractions of trials in which a path
    was found, was valid, and matched that optimal path are printed.
    """

    # The text must be passed as ``description``: the first positional
    # argument of ArgumentParser is ``prog`` (the program name in the usage line).
    parser = argparse.ArgumentParser(
        description='Traverse many graphs with an SSP algorithm and report metrics')

    parser.add_argument('--n-samples',
                        type=int,
                        default=5,
                        help='Number of different graphs to test')
    parser.add_argument('--seed',
                        type=int,
                        default=13,
                        help='Seed for training and generating axis SSPs')
    parser.add_argument('--dim',
                        type=int,
                        default=512,
                        help='Dimensionality of the SSPs')
    parser.add_argument('--res',
                        type=int,
                        default=128,
                        help='Resolution of the linspaces used')
    parser.add_argument('--normalize',
                        type=int,
                        default=1,
                        choices=[0, 1],
                        help='Whether or not to normalize SPs')
    parser.add_argument(
        '--diameter-increment',
        type=float,
        default=1.0,
        help='How much to expand ellipse diameter by on each step')

    args = parser.parse_args()

    # Convert to boolean
    args.normalize = args.normalize == 1

    # Per-trial metrics (1 = success, 0 = failure)
    # set to 1 if the found path is the shortest
    shortest_path = np.zeros((args.n_samples))

    # set to 1 if the path found is valid (only uses connections that exist)
    valid_path = np.zeros((args.n_samples))

    # set to 1 if any path is found in the time allotted
    found_path = np.zeros((args.n_samples))

    # Global seed: make_good_unitary below is called without an explicit rng
    np.random.seed(args.seed)

    xs = np.linspace(0, 10, args.res)
    ys = np.linspace(0, 10, args.res)

    x_axis_sp = make_good_unitary(args.dim)
    y_axis_sp = make_good_unitary(args.dim)

    heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp)

    # TEMP: putting this outside the loop for debugging.
    # NOTE: as a result every trial reuses the same graph.
    graph_params = generate_graph(dim=args.dim,
                                  x_axis_sp=x_axis_sp,
                                  y_axis_sp=y_axis_sp)

    for n in range(args.n_samples):

        print("Sample {} of {}".format(n + 1, args.n_samples))

        elliptic_expansion = EllipticExpansion(
            x_axis_sp=x_axis_sp,
            y_axis_sp=y_axis_sp,
            xs=xs,
            ys=ys,
            heatmap_vectors=heatmap_vectors,
            diameter_increment=args.diameter_increment,
            normalize=args.normalize,
            debug_mode=True,
            **graph_params)

        path = elliptic_expansion.find_path(
            max_steps=10,
            display=False,
            graph=graph_params['graph'],
            xs=xs,
            ys=ys,
            heatmap_vectors=heatmap_vectors)

        optimal_path = graph_params['graph'].search_graph(
            start_node=graph_params['start_landmark_id'],
            end_node=graph_params['end_landmark_id'],
        )

        print("found path is: {}".format(path))
        print("optimal path is: {}".format(optimal_path))

        if path is None:
            # find_path gave up within max_steps; all metrics stay 0
            print("no path found")
            continue

        found_path[n] = 1

        if graph_params['graph'].is_valid_path(path):
            valid_path[n] = 1
            print("path is valid")
        else:
            print("path is invalid")

        if path == optimal_path:
            shortest_path[n] = 1
            print("path is optimal")
        else:
            print("path is not optimal")

    print("Found path: {}".format(found_path.mean()))
    print("Valid path: {}".format(valid_path.mean()))
    print("Shortest path: {}".format(shortest_path.mean()))
Example #13
0
def run_and_gather_activations(
        seed=13,
        n_samples=1000,
        dataset='../../lab/reproducing/data/path_integration_trajectories_logits_200t_15s_seed13.npz',
        model_path='../output/ssp_path_integration/clipped/Mar22_15-24-10/ssp_path_integration_model.pt',
        encoding='ssp',
        rollout_length=100,
        batch_size=10,
        n_place_cells=256,
        encoding_func=None,  # added for frozen-learned encoding option

):
    """Run a trained path-integration model on test data and collect activations.

    The model is loaded from ``model_path`` and run on one test batch built by
    ``train_test_loaders``. Both its output and the ground-truth target at every
    step are decoded into 2D coordinates according to ``encoding``.

    Args:
        seed: seed for torch and numpy global RNGs.
        n_samples: number of train and test trajectories to build.
        dataset: path of the ``.npz`` trajectory dataset.
        model_path: path of the saved model state dict.
        encoding: one of ``'ssp'``, ``'2d'``, ``'pc'``, ``'frozen-learned'``,
            ``'pc-gauss'``, ``'pc-gauss-softmax'``.
        rollout_length: number of steps per trajectory.
        batch_size: training-loader batch size.
        n_place_cells: output dimensionality used for the ``'pc'`` encoding.
        encoding_func: maps an (x, y) array to an encoding vector. Required
            for the learned / gaussian encodings.

    Returns:
        ``(activations, predictions, coords)``. Rows are grouped by step: all
        trajectories for step 0, then step 1, and so on. ``activations`` has
        ``model.lstm_hidden_size`` columns; ``predictions`` (decoded model
        output) and ``coords`` (decoded ground truth) have 2 columns.

    Raises:
        NotImplementedError: for an unknown ``encoding``.
    """

    torch.manual_seed(seed)
    np.random.seed(seed)

    data = np.load(dataset)

    x_axis_vec = data['x_axis_vec']
    y_axis_vec = data['y_axis_vec']

    pc_centers = data['pc_centers']

    learned_encodings = ('frozen-learned', 'pc-gauss', 'pc-gauss-softmax')

    # Output dimensionality of the model, and scaling of the coordinate limits
    if encoding == 'ssp':
        encoding_dim = 512
        ssp_scaling = data['ssp_scaling']
    elif encoding == '2d':
        encoding_dim = 2
        ssp_scaling = 1
    elif encoding == 'pc':
        # one output unit per place cell
        encoding_dim = n_place_cells
        ssp_scaling = 1
    elif encoding in learned_encodings:
        encoding_dim = 512
        ssp_scaling = 1
    else:
        raise NotImplementedError

    limit_low = 0 * ssp_scaling
    limit_high = 2.2 * ssp_scaling
    res = 128

    xs = np.linspace(limit_low, limit_high, res)
    ys = np.linspace(limit_low, limit_high, res)

    if encoding in learned_encodings:
        # L2-normalised encoding of every grid point, used to decode outputs
        # back to locations (one encoding_func call per point: slow but simple)
        heatmap_vectors = np.zeros((len(xs), len(ys), encoding_dim))

        for i, x in enumerate(xs):
            for j, y in enumerate(ys):
                # encoding_func takes a single point, without a batch dimension
                heatmap_vectors[i, j, :] = encoding_func(np.array([x, y]))
                heatmap_vectors[i, j, :] /= np.linalg.norm(heatmap_vectors[i, j, :])
    else:
        # Used for visualization of test set performance using pos = ssp_to_loc(sp, heatmap_vectors, xs, ys)
        heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_vec, y_axis_vec)

    model = SSPPathIntegrationModel(unroll_length=rollout_length, sp_dim=encoding_dim)

    model.load_state_dict(torch.load(model_path), strict=False)

    trainloader, testloader = train_test_loaders(
        data,
        n_train_samples=n_samples,
        n_test_samples=n_samples,
        rollout_length=rollout_length,
        batch_size=batch_size,
        encoding=encoding,
        encoding_func=encoding_func,
    )

    print("Testing")
    with torch.no_grad():
        # Everything is in one batch, so this loop will only happen once
        for batch in testloader:
            velocity_inputs, ssp_inputs, ssp_outputs = batch

            ssp_pred, lstm_outputs = model.forward_activations(velocity_inputs, ssp_inputs)

        n_steps = ssp_pred.shape[0]
        n_traj = ssp_pred.shape[1]

        predictions = np.zeros((n_steps * n_traj, 2))
        coords = np.zeros((n_steps * n_traj, 2))
        activations = np.zeros((n_steps * n_traj, model.lstm_hidden_size))

        assert rollout_length == n_steps

        # Model outputs / LSTM states are indexed [step, trajectory, :], while
        # the targets are indexed [trajectory, step, :]
        pred_np = ssp_pred.detach().numpy()
        target_np = ssp_outputs.detach().numpy()
        lstm_np = lstm_outputs.detach().numpy()

        print("Computing predicted locations and true locations")
        # Using all data, one chunk at a time
        for ri in range(rollout_length):
            rows = slice(ri * n_traj, (ri + 1) * n_traj)
            pred_step = pred_np[ri, :, :]
            target_step = target_np[:, ri, :]

            if encoding == 'ssp':
                # 'predicted' coordinates (where the agent thinks it is) and
                # 'ground truth' coordinates (where it should be)
                predictions[rows, :] = ssp_to_loc_v(pred_step, heatmap_vectors, xs, ys)
                coords[rows, :] = ssp_to_loc_v(target_step, heatmap_vectors, xs, ys)
            elif encoding == '2d':
                # outputs are already coordinates
                predictions[rows, :] = pred_step
                coords[rows, :] = target_step
            elif encoding == 'pc':
                # (quick hack is to just use the most activated place cell center)
                # Predictions must decode the model output, not the target.
                predictions[rows, :] = pc_to_loc_v(
                    pc_activations=pred_step,
                    centers=pc_centers,
                    jitter=0.01,
                )
                coords[rows, :] = pc_to_loc_v(
                    pc_activations=target_step,
                    centers=pc_centers,
                    jitter=0.01,
                )
            elif encoding in learned_encodings:
                # NOTE(review): rows are divided by their sum. A row with a
                # negative sum flips sign, which would change which grid
                # point matches best. Confirm this normalisation is intended.
                pred = pred_step / pred_step.sum(axis=1)[:, np.newaxis]
                predictions[rows, :] = ssp_to_loc_v(pred, heatmap_vectors, xs, ys)

                coord = target_step / target_step.sum(axis=1)[:, np.newaxis]
                coords[rows, :] = ssp_to_loc_v(coord, heatmap_vectors, xs, ys)

            activations[rows, :] = lstm_np[ri, :, :]

    return activations, predictions, coords
Example #14
0
def run_and_gather_localization_activations(
        seed=13,
        n_samples=1000,
        dataset='../../localization/data/localization_trajectories_5m_200t_250s_seed13.npz',
        model_path='../../localization/output/ssp_trajectory_localization/May13_16-00-27/ssp_trajectory_localization_model.pt',
        encoding='ssp',
        rollout_length=100,
        batch_size=10,

):
    """Run a trained localization model on test data and collect activations.

    The model is loaded from ``model_path`` and run on one test batch built by
    ``localization_train_test_loaders``. Both its output and the ground-truth
    target at every step are decoded into 2D coordinates.

    Args:
        seed: seed for torch and numpy global RNGs.
        n_samples: number of train and test trajectories to build.
        dataset: path of the ``.npz`` localization dataset.
        model_path: path of the saved model state dict.
        encoding: ``'ssp'`` (outputs are SSPs, decoded via heatmap vectors)
            or ``'2d'`` (outputs are already coordinates).
        rollout_length: number of steps per trajectory.
        batch_size: training-loader batch size.

    Returns:
        ``(activations, predictions, coords)``. Rows are grouped by step: all
        trajectories for step 0, then step 1, and so on. ``activations`` has
        ``model.lstm_hidden_size`` columns; ``predictions`` and ``coords``
        have 2 columns.

    Raises:
        NotImplementedError: for an encoding other than ``'ssp'`` or ``'2d'``.
    """
    # Any other encoding would silently leave predictions/coords as zeros
    if encoding not in ('ssp', '2d'):
        raise NotImplementedError("Unsupported encoding: {}".format(encoding))

    torch.manual_seed(seed)
    np.random.seed(seed)

    data = np.load(dataset)

    x_axis_vec = data['x_axis_vec']
    y_axis_vec = data['y_axis_vec']
    ssp_scaling = data['ssp_scaling']
    ssp_offset = data['ssp_offset']

    # shape of coarse maps is (n_maps, env_size, env_size)
    # some npz files had different naming, try both
    try:
        coarse_maps = data['coarse_maps']
    except KeyError:
        coarse_maps = data['coarse_mazes']
    n_maps = coarse_maps.shape[0]
    env_size = coarse_maps.shape[1]

    # shape of dist_sensors is (n_maps, n_trajectories, n_steps, n_sensors)
    n_sensors = data['dist_sensors'].shape[3]

    # shape of ssps is (n_maps, n_trajectories, n_steps, dim)
    dim = data['ssps'].shape[3]

    # Coordinate range of the environment in SSP space
    limit_low = -ssp_offset * ssp_scaling
    limit_high = (env_size - ssp_offset) * ssp_scaling
    res = 256

    xs = np.linspace(limit_low, limit_high, res)
    ys = np.linspace(limit_low, limit_high, res)

    # Used for visualization of test set performance using pos = ssp_to_loc(sp, heatmap_vectors, xs, ys)
    heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_vec, y_axis_vec)

    # The '2d' branch below copies model outputs straight into 2-column arrays,
    # so a 2d model must have 2 outputs (same convention as the path
    # integration variant, which uses encoding_dim = 2 for '2d')
    output_dim = 2 if encoding == '2d' else dim

    model = LocalizationModel(
        input_size=2 + n_sensors + n_maps,
        unroll_length=rollout_length,
        sp_dim=output_dim
    )

    model.load_state_dict(torch.load(model_path), strict=False)

    trainloader, testloader = localization_train_test_loaders(
        data,
        n_train_samples=n_samples,
        n_test_samples=n_samples,
        rollout_length=rollout_length,
        batch_size=batch_size,
        encoding=encoding,
    )

    print("Testing")
    with torch.no_grad():
        # Everything is in one batch, so this loop will only happen once
        for batch in testloader:
            combined_inputs, ssp_inputs, ssp_outputs = batch

            ssp_pred, lstm_outputs = model.forward_activations(combined_inputs, ssp_inputs)

        n_steps = ssp_pred.shape[0]
        n_traj = ssp_pred.shape[1]

        predictions = np.zeros((n_steps * n_traj, 2))
        coords = np.zeros((n_steps * n_traj, 2))
        activations = np.zeros((n_steps * n_traj, model.lstm_hidden_size))

        assert rollout_length == n_steps

        # Model outputs / LSTM states are indexed [step, trajectory, :], while
        # the targets are indexed [trajectory, step, :]
        pred_np = ssp_pred.detach().numpy()
        target_np = ssp_outputs.detach().numpy()
        lstm_np = lstm_outputs.detach().numpy()

        print("Computing predicted locations and true locations")
        # Using all data, one chunk at a time
        for ri in range(rollout_length):
            rows = slice(ri * n_traj, (ri + 1) * n_traj)

            if encoding == 'ssp':
                # 'predicted' coordinates (where the agent thinks it is) and
                # 'ground truth' coordinates (where it should be)
                predictions[rows, :] = ssp_to_loc_v(pred_np[ri, :, :], heatmap_vectors, xs, ys)
                coords[rows, :] = ssp_to_loc_v(target_np[:, ri, :], heatmap_vectors, xs, ys)
            else:  # '2d': outputs are already coordinates
                predictions[rows, :] = pred_np[ri, :, :]
                coords[rows, :] = target_np[:, ri, :]

            activations[rows, :] = lstm_np[ri, :, :]

    return activations, predictions, coords
# Memory pointer: starts as an all-zero pointer of dimensionality args.dim and
# accumulates, for every item, the item's vocabulary vector combined (via *)
# with the SSP of that item's location.
mem_sp = spa.SemanticPointer(data=np.zeros((args.dim, )))
for key, value in items.items():
    # value[0], value[1] are the item's x and y coordinates
    # mem_sp += item_vocab[key] * encode_point_hex(value[0], value[1], X, Y, Z)
    mem_sp += item_vocab[key] * encode_point(value[0], value[1], X, Y)


def encode_func(pos):
    """Encode the position ``pos`` = (x, y) as an SSP and return its raw vector.

    Uses the module-level axis pointers ``X`` and ``Y``. A hexagonal variant
    (``encode_point_hex(pos[0], pos[1], X, Y, Z)``) exists but is disabled.
    """
    pointer = encode_point(pos[0], pos[1], X, Y)
    return pointer.v


# xs = np.linspace(-1, args.env_size+1, 256)
# Heatmap vectors (one SSP per (x, y) point of the xs x xs grid), cached on
# disk. The save and load paths must match, otherwise the cache never hits.
# NOTE(review): the cache key only includes the dimensionality. Delete the file
# if xs or the axis vectors X, Y change.
hmv_cache_path = 'hmv_attractor_exp_{}.npz'.format(args.dim)
if not os.path.exists(hmv_cache_path):
    # hmv = get_heatmap_vectors_hex(xs, xs, X, Y, Z)
    hmv = get_heatmap_vectors(xs, xs, X, Y)
    np.savez(hmv_cache_path, hmv=hmv)
else:
    # context manager closes the underlying .npz file handle
    with np.load(hmv_cache_path) as hmv_file:
        hmv = hmv_file['hmv']


def decode_func(ssp):
    """Decode ``ssp`` to the grid point whose heatmap vector matches it best.

    Takes the dot product of ``ssp`` with every vector in the module-level
    heatmap ``hmv`` (contracting over hmv's last axis). It then returns
    ``[xs[i], xs[j]]`` for the index (i, j) of the largest similarity; ``xs``
    is used for both axes.
    """
    similarity = np.tensordot(ssp, hmv, axes=([0], [2]))

    best = np.unravel_index(np.argmax(similarity), similarity.shape)

    return np.array([xs[best[0]], xs[best[1]]])
# Coarse y grid spanning the same range as ys, with coarse_size points
coarse_ys = np.linspace(ys[0], ys[-1], coarse_size)

# Coarse layout of the maze selected by --maze-index
map_array = coarse_mazes[args.maze_index, :, :]
limit_low = 0
limit_high = 13
# Position encoding chosen from args; repr_dim is presumably the size of its
# output vectors (TODO confirm against get_encoding_function)
encoding_func, repr_dim = get_encoding_function(args,
                                                limit_low=limit_low,
                                                limit_high=limit_high)

# x_axis_sp = spa.SemanticPointer(data=data['x_axis_sp'])
# y_axis_sp = spa.SemanticPointer(data=data['y_axis_sp'])
# Axis vectors are obtained by encoding the unit points (1, 0) and (0, 1)
x_axis_vec = encoding_func(1, 0)
y_axis_vec = encoding_func(0, 1)
x_axis_sp = nengo_spa.SemanticPointer(data=x_axis_vec)
y_axis_sp = nengo_spa.SemanticPointer(data=y_axis_vec)
# Heatmap vectors for the fine (xs, ys) grid and the coarse grid
heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_sp, y_axis_sp)
coarse_heatmap_vectors = get_heatmap_vectors(coarse_xs, coarse_ys, x_axis_sp,
                                             y_axis_sp)

# Width of the fine coordinate range, used to rescale goals to the coarse grid
limit_range = xs[-1] - xs[0]

# Fixed set of goal locations and their SSPs, loaded from the dataset
goal_sps = data['goal_sps']
goals = data['goals']
# Goal coordinates mapped from [xs[0], xs[-1]] onto [0, coarse_size]
# (the x range is used for both coordinates)
goals_scaled = ((goals - xs[0]) / limit_range) * coarse_size

n_goals = args.n_goals
def main():
    parser = argparse.ArgumentParser(
        'Train a network to clean up a noisy spatial semantic pointer')

    parser.add_argument('--loss',
                        type=str,
                        default='cosine',
                        choices=['cosine', 'mse'])
    parser.add_argument('--noise-type',
                        type=str,
                        default='memory',
                        choices=['memory', 'gaussian', 'both'])
    parser.add_argument(
        '--sigma',
        type=float,
        default=1.0,
        help='sigma on the gaussian noise if noise-type==gaussian')
    parser.add_argument('--train-fraction',
                        type=float,
                        default=.8,
                        help='proportion of the dataset to use for training')
    parser.add_argument(
        '--n-samples',
        type=int,
        default=10000,
        help=
        'Number of memories to generate. Total samples will be n-samples * n-items'
    )
    parser.add_argument('--n-items',
                        type=int,
                        default=12,
                        help='number of items in memory. Proxy for noisiness')
    parser.add_argument('--dim',
                        type=int,
                        default=512,
                        help='Dimensionality of the semantic pointers')
    parser.add_argument('--hidden-size',
                        type=int,
                        default=512,
                        help='Hidden size of the cleanup network')
    parser.add_argument('--limits',
                        type=str,
                        default="-5,5,-5,5",
                        help='The limits of the space')
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--seed', type=int, default=13)
    parser.add_argument('--logdir',
                        type=str,
                        default='ssp_cleanup',
                        help='Directory for saved model and tensorboard log')
    parser.add_argument('--load-model',
                        type=str,
                        default='',
                        help='Optional model to continue training from')
    parser.add_argument(
        '--name',
        type=str,
        default='',
        help=
        'Name of output folder within logdir. Will use current date and time if blank'
    )
    parser.add_argument('--weight-histogram',
                        action='store_true',
                        help='Save histograms of the weights if set')
    parser.add_argument('--use-hex-ssp', action='store_true')
    parser.add_argument('--optimizer',
                        type=str,
                        default='adam',
                        choices=['sgd', 'adam', 'rmsprop'])

    args = parser.parse_args()

    args.limits = tuple(float(v) for v in args.limits.split(','))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    dataset_name = 'data/ssp_cleanup_dataset_dim{}_seed{}_items{}_limit{}_samples{}.npz'.format(
        args.dim, args.seed, args.n_items, args.limits[1], args.n_samples)

    final_test_samples = 100
    final_test_items = 15
    final_test_dataset_name = 'data/ssp_cleanup_test_dataset_dim{}_seed{}_items{}_limit{}_samples{}.npz'.format(
        args.dim, args.seed, final_test_items, args.limits[1],
        final_test_samples)

    if not os.path.exists('data'):
        os.makedirs('data')

    rng = np.random.RandomState(seed=args.seed)
    if args.use_hex_ssp:
        x_axis_sp, y_axis_sp = get_axes(dim=args.dim, n=3, seed=args.seed)
    else:
        x_axis_sp = make_good_unitary(args.dim, rng=rng)
        y_axis_sp = make_good_unitary(args.dim, rng=rng)

    if args.noise_type == 'gaussian':
        # Simple generation
        clean_ssps = np.zeros((args.n_samples, args.dim))
        coords = np.zeros((args.n_samples, 2))
        for i in range(args.n_samples):
            x = np.random.uniform(low=args.limits[0], high=args.limits[1])
            y = np.random.uniform(low=args.limits[2], high=args.limits[3])

            clean_ssps[i, :] = encode_point(x,
                                            y,
                                            x_axis_sp=x_axis_sp,
                                            y_axis_sp=y_axis_sp).v
            coords[i, 0] = x
            coords[i, 1] = y
        # Gaussian noise will be added later
        noisy_ssps = clean_ssps.copy()
    else:

        if os.path.exists(dataset_name):
            print("Loading dataset")
            data = np.load(dataset_name)
            clean_ssps = data['clean_ssps']
            noisy_ssps = data['noisy_ssps']
        else:
            print("Generating SSP cleanup dataset")
            clean_ssps, noisy_ssps, coords = generate_cleanup_dataset(
                x_axis_sp=x_axis_sp,
                y_axis_sp=y_axis_sp,
                n_samples=args.n_samples,
                dim=args.dim,
                n_items=args.n_items,
                limits=args.limits,
                seed=args.seed,
            )
            print("Dataset generation complete. Saving dataset")
            np.savez(
                dataset_name,
                clean_ssps=clean_ssps,
                noisy_ssps=noisy_ssps,
                coords=coords,
                x_axis_vec=x_axis_sp.v,
                y_axis_vec=x_axis_sp.v,
            )

    # check if the final test set has been generated yet
    if os.path.exists(final_test_dataset_name):
        print("Loading final test dataset")
        final_test_data = np.load(final_test_dataset_name)
        final_test_clean_ssps = final_test_data['clean_ssps']
        final_test_noisy_ssps = final_test_data['noisy_ssps']
    else:
        print("Generating final test dataset")
        final_test_clean_ssps, final_test_noisy_ssps, final_test_coords = generate_cleanup_dataset(
            x_axis_sp=x_axis_sp,
            y_axis_sp=y_axis_sp,
            n_samples=final_test_samples,
            dim=args.dim,
            n_items=final_test_items,
            limits=args.limits,
            seed=args.seed,
        )
        print("Final test generation complete. Saving dataset")
        np.savez(
            final_test_dataset_name,
            clean_ssps=final_test_clean_ssps,
            noisy_ssps=final_test_noisy_ssps,
            coords=final_test_coords,
            x_axis_vec=x_axis_sp.v,
            y_axis_vec=x_axis_sp.v,
        )

    # Add gaussian noise if required
    if args.noise_type == 'gaussian' or args.noise_type == 'both':
        noisy_ssps += np.random.normal(loc=0,
                                       scale=args.sigma,
                                       size=noisy_ssps.shape)

    n_samples = clean_ssps.shape[0]
    n_train = int(args.train_fraction * n_samples)
    n_test = n_samples - n_train
    assert (n_train > 0 and n_test > 0)
    train_clean = clean_ssps[:n_train, :]
    train_noisy = noisy_ssps[:n_train, :]
    test_clean = clean_ssps[n_train:, :]
    test_noisy = noisy_ssps[n_train:, :]

    # NOTE: this dataset is actually generic and can take any input/output mapping
    dataset_train = CoordDecodeDataset(vectors=train_noisy, coords=train_clean)
    dataset_test = CoordDecodeDataset(vectors=test_noisy, coords=test_clean)
    dataset_final_test = CoordDecodeDataset(vectors=final_test_noisy_ssps,
                                            coords=final_test_clean_ssps)

    trainloader = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=0,
    )

    # For testing just do everything in one giant batch
    testloader = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=len(dataset_test),
        shuffle=False,
        num_workers=0,
    )

    final_testloader = torch.utils.data.DataLoader(
        dataset_final_test,
        batch_size=len(dataset_final_test),
        shuffle=False,
        num_workers=0,
    )

    model = FeedForward(dim=dataset_train.dim,
                        hidden_size=args.hidden_size,
                        output_size=dataset_train.dim)

    # Open a tensorboard writer if a logging directory is given
    if args.logdir != '':
        current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        save_dir = osp.join(args.logdir, current_time)
        writer = SummaryWriter(log_dir=save_dir)
        if args.weight_histogram:
            # Log the initial parameters
            for name, param in model.named_parameters():
                writer.add_histogram('parameters/' + name,
                                     param.clone().cpu().data.numpy(), 0)

    mse_criterion = nn.MSELoss()
    cosine_criterion = nn.CosineEmbeddingLoss()

    if args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.momentum)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        raise NotImplementedError

    for e in range(args.epochs):
        print('Epoch: {0}'.format(e + 1))

        avg_mse_loss = 0
        avg_cosine_loss = 0
        n_batches = 0
        for i, data in enumerate(trainloader):

            noisy, clean = data

            if noisy.size()[0] != args.batch_size:
                continue  # Drop data, not enough for a batch
            optimizer.zero_grad()

            outputs = model(noisy)

            mse_loss = mse_criterion(outputs, clean)
            # Modified to use CosineEmbeddingLoss
            cosine_loss = cosine_criterion(outputs, clean,
                                           torch.ones(args.batch_size))

            avg_cosine_loss += cosine_loss.data.item()
            avg_mse_loss += mse_loss.data.item()
            n_batches += 1

            if args.loss == 'cosine':
                cosine_loss.backward()
            else:
                mse_loss.backward()

            # print(loss.data.item())

            optimizer.step()

        print(avg_cosine_loss / n_batches)

        if args.logdir != '':
            if n_batches > 0:
                avg_cosine_loss /= n_batches
                writer.add_scalar('avg_cosine_loss', avg_cosine_loss, e + 1)
                writer.add_scalar('avg_mse_loss', avg_mse_loss, e + 1)

            if args.weight_histogram and (e + 1) % 10 == 0:
                for name, param in model.named_parameters():
                    writer.add_histogram('parameters/' + name,
                                         param.clone().cpu().data.numpy(),
                                         e + 1)

    print("Testing")
    with torch.no_grad():

        for label, loader in zip(['test', 'final_test'],
                                 [testloader, final_testloader]):

            # Everything is in one batch, so this loop will only happen once
            for i, data in enumerate(loader):

                noisy, clean = data

                outputs = model(noisy)

                mse_loss = mse_criterion(outputs, clean)
                # Modified to use CosineEmbeddingLoss
                cosine_loss = cosine_criterion(outputs, clean,
                                               torch.ones(len(loader)))

                print(cosine_loss.data.item())

            if args.logdir != '':
                # TODO: get a visualization of the performance

                # show plots of the noisy, clean, and cleaned up with the network
                # note that the plotting mechanism itself uses nearest neighbors, so has a form of cleanup built in

                xs = np.linspace(args.limits[0], args.limits[1], 256)
                ys = np.linspace(args.limits[0], args.limits[1], 256)

                heatmap_vectors = get_heatmap_vectors(xs, ys, x_axis_sp,
                                                      y_axis_sp)

                noisy_coord = ssp_to_loc_v(noisy, heatmap_vectors, xs, ys)

                pred_coord = ssp_to_loc_v(outputs, heatmap_vectors, xs, ys)

                clean_coord = ssp_to_loc_v(clean, heatmap_vectors, xs, ys)

                fig_noisy_coord, ax_noisy_coord = plt.subplots()
                fig_pred_coord, ax_pred_coord = plt.subplots()
                fig_clean_coord, ax_clean_coord = plt.subplots()

                plot_predictions_v(noisy_coord,
                                   clean_coord,
                                   ax_noisy_coord,
                                   min_val=args.limits[0],
                                   max_val=args.limits[1],
                                   fixed_axes=True)

                plot_predictions_v(pred_coord,
                                   clean_coord,
                                   ax_pred_coord,
                                   min_val=args.limits[0],
                                   max_val=args.limits[1],
                                   fixed_axes=True)

                plot_predictions_v(clean_coord,
                                   clean_coord,
                                   ax_clean_coord,
                                   min_val=args.limits[0],
                                   max_val=args.limits[1],
                                   fixed_axes=True)

                writer.add_figure('{}/original_noise'.format(label),
                                  fig_noisy_coord)
                writer.add_figure('{}/test_set_cleanup'.format(label),
                                  fig_pred_coord)
                writer.add_figure('{}/ground_truth'.format(label),
                                  fig_clean_coord)
                # fig_hist = plot_histogram(predictions=outputs, coords=coord)
                # writer.add_figure('test set histogram', fig_hist)
                writer.add_scalar('{}/test_cosine_loss'.format(label),
                                  cosine_loss.data.item())
                writer.add_scalar('{}/test_mse_loss'.format(label),
                                  mse_loss.data.item())

    # Close tensorboard writer
    if args.logdir != '':
        writer.close()

        torch.save(model.state_dict(), osp.join(save_dir, 'model.pt'))

        params = vars(args)
        # # Additionally save the axis vectors used
        # params['x_axis_vec'] = list(x_axis_sp.v)
        # params['y_axis_vec'] = list(y_axis_sp.v)
        with open(osp.join(save_dir, "params.json"), "w") as f:
            json.dump(params, f)