def evaluate(self, x):
    """
    Evaluate the recurrent VAE on a batch of sequences.

    The model input is the time-reversed version of ``x`` (built with
    ``reverse_sequences``) and the reconstruction target is ``x`` itself.
    Besides the metrics returned by ``self.rvae.evaluate``, two extra keys
    are added to the result:

    * ``"abs_accuracy"``: fraction of sequences whose argmax-decoded
      reconstruction matches the original at every timestep.
    * ``"ts_accuracy"``: average over sequences of the number of matching
      timesteps divided by ``TIMESTEPS``.

    :param x: batch of one-hot sequences. NOTE(review): assumed shape is
        (batch, timesteps, vocab); confirm against callers.
    :return: ``dict`` of evaluation metrics.
    """
    x_reverse = reverse_sequences(x)
    evaluation = self.rvae.evaluate(
        x_reverse, x, verbose=0, batch_size=1, return_dict=True)

    # encode() yields (z_mean, z_logvar, z); reconstruct from the sampled z.
    _, _, z = self.rvae.encode(x_reverse)
    reconstructions = self.rvae.decode(z)

    correct = 0
    ts_correct = 0
    for original, reconstruction in zip(x, reconstructions):
        # Reduce each timestep's vector to its argmax symbol index.
        expected = np.argmax(original, axis=1)
        predicted = np.argmax(reconstruction, axis=1)
        if np.array_equal(expected, predicted):
            correct += 1
        # Use distinct names so the comprehension does not shadow `x`.
        matches = [e == p for e, p in zip(expected, predicted)]
        ts_correct += np.sum(matches) / TIMESTEPS

    evaluation["abs_accuracy"] = correct / len(x)
    evaluation["ts_accuracy"] = ts_correct / len(x)
    return evaluation
def test_reverse_sequences_one_hot():
    """Reversing one-hot sequences must flip the order of their timesteps."""
    actual = reverse_sequences([
        [[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 1, 0, 0]],
        [[0, 0, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 0, 0]],
    ])
    expected = np.array([
        [[0, 0, 1, 0, 0], [0, 1, 0, 0, 0], [1, 0, 0, 0, 0]],
        [[0, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0]],
    ])
    assert np.array_equal(actual, expected), \
        "The one-hot sequences are not correctly reversed"
def find_positions(self, entries, save_entries=True, frecency=False):
    """
    Compute one latent-space position per agent.

    For every per-agent list in ``entries``, the artefacts are one-hot
    encoded, time-reversed and encoded with ``self.rvae``. The agent's
    position is the mean of the resulting ``z_mean`` batch over its entries.

    :param entries: iterable of per-agent entry lists; each entry is a
        mapping with an ``"artefact"`` key.
    :param save_entries: when true, entries and their encodings are stored
        via ``self.save``.
    :param frecency: forwarded unchanged to ``self.save``.
    :return: ``np.ndarray`` stacking the per-agent mean positions.
    """
    def encode(agent_entries):
        artefacts = np.array([item["artefact"] for item in agent_entries])
        return self.rvae.encode(reverse_sequences(one_hot(artefacts)))

    result = []
    for agent_entries in entries:
        encodings = encode(agent_entries)
        if save_entries:
            self.save(agent_entries, encodings, frecency=frecency)
        # Element 0 of the encoding tuple is the z_mean batch.
        z_mean = encodings[0]
        result.append(np.mean(z_mean, axis=0))
    return np.array(result)
def save(self, entries, zs=None, frecency=False):
    """
    Store entries in the repository, encoding them first if needed.

    Each stored record is a copy of the original entry extended with
    ``domain_z_mean``, ``domain_z_logvar`` and ``domain_z``, each obtained by
    calling ``.numpy()`` on the corresponding encoding element.

    :param entries: sequence of entry mappings with an ``"artefact"`` key
        (and an ``"id"`` key when ``frecency`` is true).
    :param zs: optional ``(z_mean, z_logvar, z)`` batches aligned with
        ``entries``; computed with ``self.rvae`` when omitted.
    :param frecency: when true, ``self.frecency`` is updated with the value
        10 for the id of every saved entry.
    """
    if zs is None:
        artefacts = np.array([entry["artefact"] for entry in entries])
        zs = self.rvae.encode(reverse_sequences(one_hot(artefacts)))

    for item, mean, logvar, sample in zip(entries, *zs):
        record = dict(item)
        record["domain_z_mean"] = mean.numpy()
        record["domain_z_logvar"] = logvar.numpy()
        record["domain_z"] = sample.numpy()
        self.repository.append(record)
        if frecency:
            self.frecency.update({item['id']: 10})
def reconstruct(self, epoch, entries, artefacts):
    """
    Reconstruct ``artefacts`` with the recurrent VAE and annotate entries.

    The artefacts are time-reversed and encoded, then decoded from the
    sampled latent ``z``. The decoded output is reduced to indices with an
    argmax over its last axis.

    :param epoch: value stored under ``"current_epoch"`` in every result.
    :param entries: entry mappings, paired positionally with ``artefacts``.
    :param artefacts: batch of sequences to reconstruct.
    :return: ``list`` of dicts containing ``"current_epoch"``, the entry's
        own keys, then ``"rz_mean"``, ``"rz"`` and ``"reconstruction"``
        (later keys override earlier ones on collision).
    """
    z_mean, _, z = self.rvae.encode(reverse_sequences(artefacts))
    decoded = np.argmax(self.rvae.decode(z), axis=-1)
    z_mean_values = z_mean.numpy()
    z_values = z.numpy()

    results = []
    for entry, rz_mean, rz, reconstruction in zip(
            entries, z_mean_values, z_values, decoded):
        results.append({
            "current_epoch": epoch,
            **entry,
            "rz_mean": rz_mean,
            "rz": rz,
            "reconstruction": reconstruction,
        })
    return results
def test_reverse_sequences():
    """Timesteps of a batch of string sequences come back in reverse order."""
    actual = reverse_sequences([[['a', 'b', 'c'], ['c', 'b', 'a']]])
    expected = [[['c', 'b', 'a'], ['a', 'b', 'c']]]
    assert np.array_equal(actual, expected), \
        "The sequences are not correctly reversed"