Ejemplo n.º 1
0
    def evaluate(self, x):
        """ Evaluates the RVAE on `x` and adds two reconstruction accuracies.

            The model is fed the time-reversed batch (`reverse_sequences(x)`)
            and scored against the original `x`. The same reversed batch is
            then encoded and decoded to measure how well `x` is reconstructed.

            Note: `x` is used directly as model data; string selectors such as
            `"seed"` or `"repository"` are NOT resolved by this method.

            Args:
                x: batch of sequences; each item is argmax-ed along axis 1,
                   presumably a (timesteps, symbols) one-hot array — TODO confirm.

            Returns:
                The `dict` from `self.rvae.evaluate(..., return_dict=True)`
                extended with:
                  "abs_accuracy": fraction of sequences whose decoded symbols
                                  equal the original symbols exactly.
                  "ts_accuracy": mean over sequences of the number of matching
                                 timesteps divided by `TIMESTEPS`.
        """
        x_reverse = reverse_sequences(x)
        evaluation = self.rvae.evaluate(x_reverse,
                                        x,
                                        verbose=0,
                                        batch_size=1,
                                        return_dict=True)

        # Only the sampled latent code (third value of `encode`) is decoded.
        _, _, z = self.rvae.encode(x_reverse)
        reconstructions = self.rvae.decode(z)

        correct = 0
        ts_correct = 0

        for original, reconstruction in zip(x, reconstructions):
            # Collapse each timestep's vector to its most likely symbol index.
            expected = np.argmax(original, axis=1)
            predicted = np.argmax(reconstruction, axis=1)

            if np.array_equal(expected, predicted):
                correct += 1

            # Use names distinct from the argument `x` to avoid visual
            # shadowing; `zip` stops at the shorter of the two sequences.
            ts_correct += np.sum(
                [e == p for e, p in zip(expected, predicted)]) / TIMESTEPS

        evaluation["abs_accuracy"] = correct / len(x)
        evaluation["ts_accuracy"] = ts_correct / len(x)
        return evaluation
def test_reverse_sequences_one_hot():
    """ `reverse_sequences` must flip the timestep order of one-hot sequences. """
    sequences = [
        [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]],
        [[0, 0, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 0, 0]],
    ]
    expected = np.array([
        [[0, 0, 1, 0, 0], [0, 1, 0, 0, 0], [1, 0, 0, 0, 0]],
        [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0]],
    ])

    actual = reverse_sequences(sequences)

    assert np.array_equal(actual, expected), \
        "The one-hot sequences are not correctly reversed"
Ejemplo n.º 3
0
    def find_positions(self, entries, save_entries=True, frecency=False):
        """ Returns one latent position per group of entries.

            Each element of `entries` is a group of entry dicts carrying an
            "artefact" key. A group's artefacts are one-hot encoded,
            time-reversed and passed to `self.rvae.encode`; the group's
            position is the mean of the returned z_mean over axis 0.

            If `save_entries` is true, each group and its encodings are handed
            to `self.save` (forwarding `frecency`) before the mean is taken.

            Returns:
                `np.ndarray` stacking the per-group positions.
        """
        def _position_of(group):
            artefact_batch = np.array([item["artefact"] for item in group])
            encodings = self.rvae.encode(
                reverse_sequences(one_hot(artefact_batch)))

            if save_entries:
                self.save(group, encodings, frecency=frecency)

            return np.mean(encodings[0], axis=0)

        return np.array([_position_of(group) for group in entries])
Ejemplo n.º 4
0
    def save(self, entries, zs=None, frecency=False):
        """ Appends `entries` to the repository together with their latent codes.

            Args:
                entries: entry dicts; each needs an "artefact" key when `zs`
                    is omitted, and an 'id' key when `frecency` is true.
                zs: optional `(z_mean, z_logvar, z)` sequences of per-entry
                    encodings; computed with `self.rvae.encode` when `None`.
                frecency: when true, sets each entry's frecency value to 10.

            Each stored record is a copy of the entry plus "domain_z_mean",
            "domain_z_logvar" and "domain_z" (the `.numpy()` of each code).
        """
        if zs is None:
            artefact_batch = np.array([item["artefact"] for item in entries])
            zs = self.rvae.encode(reverse_sequences(one_hot(artefact_batch)))

        for entry, mean, logvar, sample in zip(entries, *zs):
            record = dict(entry)
            record["domain_z_mean"] = mean.numpy()
            record["domain_z_logvar"] = logvar.numpy()
            record["domain_z"] = sample.numpy()
            self.repository.append(record)

            if not frecency:
                continue
            self.frecency.update({entry['id']: 10})
Ejemplo n.º 5
0
    def reconstruct(self, epoch, entries, artefacts):
        """ Encodes and decodes `artefacts`, merging the results into `entries`.

            `artefacts` are time-reversed, encoded by the RVAE and decoded from
            the sampled code `z`; the decoded output is reduced with argmax
            over its last axis.

            Returns:
                A `list` of dicts, one per (entry, artefact) pair, holding
                "current_epoch", then the entry's own keys, then "rz_mean",
                "rz" and "reconstruction" (later keys override earlier ones).
                Pairing is positional, so the list is as long as the shorter
                of the inputs.
        """
        z_mean, _, z = self.rvae.encode(reverse_sequences(artefacts))
        decoded = np.argmax(self.rvae.decode(z), axis=-1)

        return [
            {
                "current_epoch": epoch,
                **entry,
                "rz_mean": mean,
                "rz": sample,
                "reconstruction": symbols,
            }
            for entry, mean, sample, symbols in zip(
                entries, z_mean.numpy(), z.numpy(), decoded)
        ]
def test_reverse_sequences():
    """ `reverse_sequences` must also reverse sequences of non-numeric items. """
    batch = [[['a', 'b', 'c'], ['c', 'b', 'a']]]
    expected = [[['c', 'b', 'a'], ['a', 'b', 'c']]]

    actual = reverse_sequences(batch)
    assert np.array_equal(actual, expected), "The sequences are not correctly reversed"