def cross_virtual_reward(
    host_observs: Tensor,
    host_rewards: Tensor,
    ext_observs: Tensor,
    ext_rewards: Tensor,
    dist_coef: float = 1.0,
    reward_coef: float = 1.0,
    return_compas: bool = False,
    distance_function: Callable = l2_norm,
):
    """Calculate the virtual rewards between two cloud of points.

    Each point of one cloud is compared against a randomly chosen companion
    from the other cloud; the relativized distance and reward are combined
    into a virtual reward per point.
    """
    n_host, n_ext = len(host_rewards), len(ext_rewards)
    host_flat = host_observs.reshape(n_host, -1)
    ext_flat = ext_observs.reshape(n_ext, -1)
    # Random companion indexes, one permutation per cloud.
    compas_host = random_state.permutation(judo.arange(n_host))
    compas_ext = random_state.permutation(judo.arange(n_ext))
    # TODO: check if it's better for the distances to be the same for host and ext
    host_distance = relativize(distance_function(host_flat, ext_flat[compas_host]).flatten())
    ext_distance = relativize(distance_function(ext_flat, host_flat[compas_ext]).flatten())
    host_vr = host_distance**dist_coef * relativize(host_rewards)**reward_coef
    ext_vr = ext_distance**dist_coef * relativize(ext_rewards)**reward_coef
    if return_compas:
        return (host_vr, compas_host), (ext_vr, compas_ext)
    return host_vr, ext_vr
def get_alive_indexes(oobs: Tensor):
    """Get indexes representing random alive walkers given a vector of death conditions."""
    if judo.all(oobs):
        # Every walker is out of bounds: fall back to returning all indexes.
        return judo.arange(len(oobs))
    alive = judo.logical_not(oobs).flatten()
    candidates = judo.arange(len(alive))[alive]
    # Sample with replacement only when some walkers are dead (fewer
    # candidates than the number of slots to fill).
    return random_state.choice(candidates, size=len(alive), replace=alive.sum() < len(alive))
def _get_merge_indexes(self, walkers: ExportedWalkers) -> Tuple[Tensor, Tensor]:
    """Get the indexes for selecting the walkers that will be compared in \
    the clone operation.

    Args:
        walkers: Imported walkers that will be merged with the local swarm.

    Returns:
        Tuple ``(local_ix, import_ix)`` with ``n_import`` indexes each, for the
        local swarm and the imported walkers respectively.
    """
    local_ix = random_state.choice(
        judo.arange(len(self.swarm.walkers)), size=self.n_import, replace=False
    )
    # Bug fix: the original branches were inverted. ``replace=False`` with
    # ``len(walkers) < n_import`` raises ValueError (cannot draw that many
    # distinct indexes), and ``replace=True`` for larger pools allowed
    # unintended duplicates. Sample with replacement only when the pool of
    # imported walkers is smaller than ``n_import``.
    import_ix = random_state.choice(
        judo.arange(len(walkers)),
        size=self.n_import,
        replace=len(walkers) < self.n_import,
    )
    return local_ix, import_ix
def calculate_virtual_reward(
    observs: Tensor,
    rewards: Tensor,
    oobs: Tensor = None,
    dist_coef: float = 1.0,
    reward_coef: float = 1.0,
    other_reward: Tensor = 1.0,
    return_compas: bool = False,
    distance_function: Callable = l2_norm,
):
    """Calculate the virtual rewards given the required data.

    Args:
        observs: Observations of the walkers; reshaped to ``(n_walkers, -1)``.
        rewards: Rewards of the walkers.
        oobs: Optional out-of-bounds mask used to restrict companion sampling
            to alive walkers. When None, all walkers are candidates.
        dist_coef: Exponent applied to the relativized distances.
        reward_coef: Exponent applied to the relativized rewards.
        other_reward: Extra multiplicative reward term (tensor or scalar).
        return_compas: If True, also return the sampled companion indexes.
        distance_function: Callable computing the distance between two sets of points.

    Returns:
        Flattened virtual rewards, or ``(virtual_rewards, compas)`` when
        ``return_compas`` is True.
    """
    compas = get_alive_indexes(oobs) if oobs is not None else judo.arange(len(rewards))
    compas = random_state.permutation(compas)
    # Bug fix: reshape by the number of walkers (len(rewards)). The original
    # used len(oobs), which raises TypeError when ``oobs`` is None (its default).
    flattened_observs = observs.reshape(len(rewards), -1)
    other_reward = other_reward.flatten() if dtype.is_tensor(other_reward) else other_reward
    distance = distance_function(flattened_observs, flattened_observs[compas])
    distance_norm = relativize(distance.flatten())
    rewards_norm = relativize(rewards)
    virtual_reward = distance_norm**dist_coef * rewards_norm**reward_coef * other_reward
    return (virtual_reward.flatten(), compas) if return_compas else virtual_reward.flatten()
def test_clone(self, states_class):
    """Cloning with reversed companions keeps the expected attribute values."""
    n_walkers = 10
    states = states_class(batch_size=n_walkers)
    states.miau = judo.arange(states.n)
    states.miau_2 = judo.arange(states.n)
    will_clone = judo.zeros(states.n, dtype=judo.bool)
    will_clone[3:6] = True
    compas_ix = tensor(list(reversed(range(states.n))))
    states.clone(will_clone=will_clone, compas_ix=compas_ix)
    expected = judo.arange(10)
    assert bool(judo.all(expected == states.miau)), (expected - states.miau, states_class)
def test_update(self, states_walkers):
    """update() stores the provided virtual_rewards and distances."""
    walkers_state = StatesWalkers(10)
    walkers_state.reset()
    values = judo.arange(walkers_state.n)
    walkers_state.update(virtual_rewards=values, distances=values)
    assert (walkers_state.virtual_rewards == values).all()
    assert (walkers_state.distances == values).all()
def get_in_bounds_compas(self) -> Tensor:
    """
    Return the indexes of walkers inside bounds chosen at random.

    Returns:
        Numpy array containing the int indexes of in bounds walkers chosen at \
        random with replacement. Its length is equal to the number of walkers.
    """
    in_bounds = self.states.in_bounds
    if not in_bounds.any():
        # No need to sample if all walkers are dead.
        return judo.arange(self.n)
    alive_ix = judo.arange(self.n, dtype=int)[in_bounds]
    shuffled = self.random_state.permutation(alive_ix)
    compas = self.random_state.choice(shuffled, self.n, replace=True)
    # Ensure every alive walker appears at least once among the companions.
    compas[: len(shuffled)] = shuffled
    return compas
def reset(self):
    """Clear the internal data of the class."""
    params = self.get_params_dict()
    # Drop every attribute that is not part of the declared parameters.
    for name in [key for key in self.keys() if key not in params]:
        setattr(self, name, None)
    n = self.n
    self.update(
        id_walkers=judo.zeros(n, dtype=judo.hash_type),
        compas_dist=judo.arange(n),
        compas_clone=judo.arange(n),
        processed_rewards=judo.zeros(n, dtype=judo.float),
        cum_rewards=judo.zeros(n, dtype=judo.float),
        virtual_rewards=judo.ones(n, dtype=judo.float),
        distances=judo.zeros(n, dtype=judo.float),
        clone_probs=judo.zeros(n, dtype=judo.float),
        will_clone=judo.zeros(n, dtype=judo.bool),
        in_bounds=judo.ones(n, dtype=judo.bool),
    )
def test_setitem(self, states_class):
    """__setitem__ accepts both plain python objects and tensors."""
    str_key = "miau"
    str_value = str_key
    tensor_key = "elephant"
    tensor_value = judo.arange(10)
    states = states_class(batch_size=2)
    states[str_key] = str_value
    states[tensor_key] = tensor_value
    assert states[str_key] == str_value, type(states)
    assert (states[tensor_key] == tensor_value).all(), type(states)
def test_accumulate_rewards(self, walkers):
    """Rewards accumulate only when _accumulate_rewards is enabled."""
    walkers.reset()
    # Accumulating on top of a None cum_rewards starts from the raw rewards.
    walkers._accumulate_rewards = True
    walkers.states.update(cum_rewards=[0, 0])  # Override array of Floats and set to None
    walkers.states.update(cum_rewards=None)
    step_rewards = judo.arange(len(walkers))
    walkers._accumulate_and_update_rewards(step_rewards)
    assert (walkers.states.cum_rewards == step_rewards).all()
    # With accumulation disabled the new rewards overwrite the stored values.
    walkers._accumulate_rewards = False
    walkers.states.update(cum_rewards=judo.zeros(len(walkers)))
    step_rewards = judo.arange(len(walkers))
    walkers._accumulate_and_update_rewards(step_rewards)
    assert (walkers.states.cum_rewards == step_rewards).all()
    # With accumulation enabled the new rewards are added to the stored values.
    walkers._accumulate_rewards = True
    walkers.states.update(cum_rewards=judo.ones(len(walkers)))
    step_rewards = judo.arange(len(walkers))
    walkers._accumulate_and_update_rewards(step_rewards)
    assert (walkers.states.cum_rewards == step_rewards + 1).all()
def small_tree():
    """Build a small DiGraph fixture with tensor payloads on nodes and edges."""
    node_attrs = {"a": judo.arange(10), "b": judo.zeros(10)}
    edge_attrs = {"c": judo.ones(10)}
    graph = networkx.DiGraph()
    for node in range(8):
        graph.add_node(to_node_id(node), **node_attrs)
    edges = [(0, 1), (1, 2), (2, 3), (2, 4), (2, 5), (3, 6), (3, 7)]
    for parent, child in edges:
        graph.add_edge(to_node_id(parent), to_node_id(child), **edge_attrs)
    return graph
def calculate_clone(virtual_rewards: Tensor, oobs: Tensor = None, eps=1e-3):
    """Calculate the clone indexes and masks from the virtual rewards."""
    if oobs is not None:
        compas_ix = get_alive_indexes(oobs)
    else:
        compas_ix = judo.arange(len(virtual_rewards))
    compas_ix = random_state.permutation(compas_ix)
    vr = virtual_rewards.flatten()
    # Clamp the denominator at eps so near-zero virtual rewards do not explode.
    denominator = judo.where(vr > eps, vr, tensor(eps))
    clone_probs = (vr[compas_ix] - vr) / denominator
    will_clone = clone_probs.flatten() > random_state.random(len(clone_probs))
    return compas_ix, will_clone
def calculate_distances(self) -> None:
    """Calculate the corresponding distance function for each observation with \
    respect to another observation chosen at random.

    The internal :class:`StateWalkers` is updated with the relativized distance values.
    """
    # TODO(guillemdb): Check if self.get_in_bounds_compas() works better.
    companions = self.random_state.permutation(judo.arange(self.n))
    flat_observs = self.env_states.observs.reshape(self.n, -1)
    raw_distances = self.distance_function(flat_observs, flat_observs[companions])
    self.update_states(distances=relativize(raw_distances.flatten()), compas_dist=companions)
def test_create_export_walkers(self, export_swarm):
    """Exported walkers mirror the swarm data at the selected indexes."""
    selection = judo.arange(5)
    exported = export_swarm._create_export_walkers(selection)
    assert isinstance(exported, ExportedWalkers)
    assert len(exported) == 5
    swarm_walkers = export_swarm.walkers
    assert (exported.observs == swarm_walkers.env_states.observs[selection]).all()
    assert (exported.rewards == swarm_walkers.states.cum_rewards[selection]).all()
    assert (exported.states == swarm_walkers.env_states.states[selection]).all()
    assert (exported.id_walkers == swarm_walkers.states.id_walkers[selection]).all()
def test_states_from_data(self, env_data, batch_size, states_dim):
    """states_from_data builds a StatesEnv whose tensors have batch_size length."""
    env, _model_states = env_data
    data = dict(
        states=judo.zeros((batch_size, states_dim)),
        observs=judo.ones((batch_size, states_dim)),
        rewards=judo.arange(batch_size),
        oobs=judo.zeros(batch_size, dtype=dtype.bool),
    )
    state = env.states_from_data(batch_size=batch_size, **data)
    assert isinstance(state, StatesEnv)
    for value in state.vals():
        assert dtype.is_tensor(value)
        assert len(value) == batch_size
def cross_clone(
    host_virtual_rewards: Tensor,
    ext_virtual_rewards: Tensor,
    host_oobs: Tensor = None,
    eps=1e-3,
):
    """Perform a clone operation between two different groups of points."""
    compas_ix = random_state.permutation(judo.arange(len(ext_virtual_rewards)))
    host_vr = judo.astype(host_virtual_rewards.flatten(), dtype=dtype.float32)
    ext_vr = judo.astype(ext_virtual_rewards.flatten(), dtype=dtype.float32)
    # Clamp the denominator at eps so tiny external rewards do not blow up.
    safe_ext = judo.where(ext_vr > eps, ext_vr, tensor(eps, dtype=dtype.float32))
    clone_probs = (ext_vr[compas_ix] - host_vr) / safe_ext
    will_clone = clone_probs.flatten() > random_state.random(len(clone_probs))
    if host_oobs is not None:
        # Out-of-bounds host walkers are always forced to clone.
        will_clone[host_oobs] = True
    return compas_ix, will_clone
def calculate_distance(
    observs: Tensor,
    distance_function: Callable = l2_norm,
    return_compas: bool = False,
    oobs: Tensor = None,
    compas: Tensor = None,
):
    """Calculate a distance metric for each walker with respect to a random companion.

    Args:
        observs: Observations of the walkers; reshaped to ``(n_walkers, -1)``.
        distance_function: Callable computing the distance between two sets of points.
        return_compas: If True, also return the companion indexes used.
        oobs: Optional out-of-bounds mask used to restrict companion sampling
            to alive walkers; only used when ``compas`` is None.
        compas: Optional pre-computed companion indexes. When None they are
            sampled (and permuted) internally.

    Returns:
        Relativized distances, or ``(distances, compas)`` when ``return_compas``
        is True.
    """
    if compas is None:
        compas = get_alive_indexes(oobs) if oobs is not None else judo.arange(observs.shape[0])
        compas = random_state.permutation(compas)
    # Consistency/bug fix: use ``reshape`` instead of the torch-only ``view``.
    # ``view`` fails for numpy tensors and non-contiguous torch tensors, and
    # every other function in this module flattens observations with reshape.
    flattened_observs = observs.reshape(observs.shape[0], -1)
    distance = distance_function(flattened_observs, flattened_observs[compas])
    distance_norm = relativize(distance.flatten())
    return (distance_norm, compas) if return_compas else distance_norm
def test_update_clone_probs(self, walkers):
    """Clone probabilities are uniform only when all virtual rewards are equal."""
    walkers.reset()
    virtual_rewards = relativize(judo.arange(walkers.n, dtype=dtype.float32))
    walkers.states.update(virtual_rewards=virtual_rewards)
    walkers.update_clone_probs()
    probs = walkers.states.clone_probs
    assert 0 < judo.sum(probs == probs[0]), (
        walkers.states.virtual_rewards,
        walkers.states.clone_probs,
    )
    walkers.reset()
    walkers.update_clone_probs()
    probs = walkers.states.clone_probs
    assert judo.sum(probs == probs[0]) == walkers.n
    assert probs.shape[0] == walkers.n
    assert len(probs.shape) == 1
def test_merge_states(self, states_class):
    """Merging split states recovers the original batch, params and data."""
    batch_size = 21
    data = judo.repeat(judo.arange(5).reshape(1, -1), batch_size, 0)
    states = states_class(batch_size=batch_size, test="test", data=data)
    # Splitting into single-element chunks or chunks of 5 must round-trip.
    for n_chunks in (batch_size, 5):
        chunks = tuple(states.split_states(n_chunks))
        merged = states.merge_states(chunks)
        assert len(merged) == batch_size
        assert merged.test == "test"
        assert (merged.data == data).all()
def test_append_leaf(self, tree):
    """append_leaf stores node/edge data and updates the leaf bookkeeping."""
    node_payload = {"node": judo.arange(10)}
    edge_payload = {"edge": False}
    new_leaf = to_node_id(-421)
    epoch = 123
    tree.append_leaf(
        leaf_id=new_leaf,
        parent_id=tree.root_id,
        node_data=node_payload,
        edge_data=edge_payload,
        epoch=epoch,
    )
    assert (tree.data.nodes[new_leaf]["node"] == node_payload["node"]).all()
    assert tree.data.nodes[new_leaf]["epoch"] == epoch
    assert tree.data.edges[(tree.root_id, new_leaf)] == edge_payload
    assert new_leaf in tree.leafs
    # The parent stops being a leaf once it has a child.
    assert tree.root_id not in tree.leafs
def update_clone_probs(self) -> None:
    """
    Calculate the new probability of cloning for each walker.

    Updates the :class:`StatesWalkers` with both the probability of cloning \
    and the index of the randomly chosen companions that were selected to \
    compare the virtual rewards.
    """
    virtual_rewards = self.states.virtual_rewards
    if (virtual_rewards == virtual_rewards[0]).all():
        # Degenerate case: identical virtual rewards mean nothing to clone to.
        clone_probs = judo.zeros(self.n, dtype=dtype.float)
        compas_ix = judo.arange(self.n)
    else:
        compas_ix = self.get_in_bounds_compas()
        companions = virtual_rewards[compas_ix]
        # This value can be negative!!
        clone_probs = (companions - virtual_rewards) / virtual_rewards
    self.update_states(clone_probs=clone_probs, compas_clone=compas_ix)
def test_split_states(self, states_class):
    """split_states yields chunks carrying copied params and sliced data."""
    batch_size = 20
    states = states_class(batch_size=batch_size, test="test")
    for chunk in states.split_states(batch_size):
        assert len(chunk) == 1
        assert chunk.test == "test"
    data = judo.repeat(judo.arange(5).reshape(1, -1), batch_size, 0)
    states = states_class(batch_size=batch_size, test="test", data=data)
    for chunk in states.split_states(batch_size):
        assert len(chunk) == 1
        assert chunk.test == "test"
        assert bool((chunk.data == judo.arange(5)).all()), chunk.data
    chunk_len = 4
    expected = judo.repeat(judo.arange(5).reshape(1, -1), chunk_len, 0)
    for chunk in states.split_states(5):
        assert len(chunk) == chunk_len
        assert chunk.test == "test"
        assert (chunk.data == expected).all(), (chunk.data.shape, expected.shape)
    # An uneven batch leaves a shorter remainder chunk at the end.
    batch_size = 21
    data = judo.repeat(judo.arange(5).reshape(1, -1), batch_size, 0)
    states = states_class(batch_size=batch_size, test="test", data=data)
    chunk_len = 5
    expected = judo.repeat(judo.arange(5).reshape(1, -1), chunk_len, 0)
    chunks = list(states.split_states(5))
    for chunk in chunks[:-1]:
        assert len(chunk) == chunk_len
        assert chunk.test == "test"
        assert (chunk.data == expected).all(), (chunk.data.shape, expected.shape)
    last = chunks[-1]
    assert len(last) == 1
    assert last.test == "test"
    assert (last.data == judo.arange(5)).all(), (last.data.shape, expected.shape)