class RankAndCrowdingSurvival(Survival):

    def __init__(self) -> None:
        super().__init__(filter_infeasible=True)
        self.nds = NonDominatedSorting()

    def _do(self, problem, pop, n_survive, D=None, **kwargs):

        # get the objective space values and objects
        F = pop.get("F").astype(float, copy=False)

        # the final indices of surviving individuals
        survivors = []

        # do the non-dominated sorting until the splitting front
        fronts = self.nds.do(F, n_stop_if_ranked=n_survive)

        for k, front in enumerate(fronts):

            # calculate the crowding distance of the front
            crowding_of_front = calc_crowding_distance(F[front, :])

            # save rank and crowding in the individual class
            for j, i in enumerate(front):
                pop[i].set("rank", k)
                pop[i].set("crowding", crowding_of_front[j])

            # current front sorted by crowding distance if splitting
            if len(survivors) + len(front) > n_survive:
                I = randomized_argsort(crowding_of_front, order='descending', method='numpy')
                I = I[:(n_survive - len(survivors))]

            # otherwise take the whole front unsorted
            else:
                I = np.arange(len(front))

            # extend the survivors by all or selected individuals
            survivors.extend(front[I])

        return pop[survivors]
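# ---------------------------------------------------------------------------
# Illustration only, not part of RankAndCrowdingSurvival: a minimal NumPy-only
# sketch of the rank-and-crowding idea above, without pymoo's Survival and
# Population machinery. The helpers pareto_fronts() and crowding_distance()
# are naive stand-ins for NonDominatedSorting.do() and calc_crowding_distance().
# ---------------------------------------------------------------------------
import numpy as np


def pareto_fronts(F: np.ndarray) -> list:
    """ naive non-dominated sorting (minimization), returns one index array per front """
    remaining, fronts = np.arange(len(F)), []
    while len(remaining) > 0:
        sub = F[remaining]
        # a point is dominated if another one is <= in all objectives and < in at least one
        dominated = np.array([
            np.any(np.all(sub <= sub[i], axis=1) & np.any(sub < sub[i], axis=1))
            for i in range(len(sub))
        ])
        fronts.append(remaining[~dominated])
        remaining = remaining[dominated]
    return fronts


def crowding_distance(F: np.ndarray) -> np.ndarray:
    """ NSGA-II crowding distance within one front, larger = less crowded """
    n, m = F.shape
    dist = np.zeros(n)
    for j in range(m):
        order = np.argsort(F[:, j])
        span = max(F[order[-1], j] - F[order[0], j], 1e-12)
        dist[order[0]] = dist[order[-1]] = np.inf  # boundary points are always kept
        dist[order[1:-1]] += (F[order[2:], j] - F[order[:-2], j]) / span
    return dist


# select 5 of 8 random 2-objective points: take whole fronts while they fit,
# then split the last front by descending crowding distance
F, n_survive, survivors = np.random.rand(8, 2), 5, []
for front in pareto_fronts(F):
    if len(survivors) + len(front) > n_survive:
        keep = np.argsort(-crowding_distance(F[front]))[:n_survive - len(survivors)]
        survivors.extend(front[keep])
        break
    survivors.extend(front)
print(sorted(survivors))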
def get_all_sorted(self, sorted_by: [str], maximize: [bool], data_set: str = None, only_best=False) -> [MiniResult]:
    """
    get all / the best entries, sorted by 1 to N objectives

    :param sorted_by: sort by 'acc1', 'flops', ... single or multi-objective
    :param maximize: for each sorted_by key, whether to maximize the value
    :param data_set: use a specific data set, otherwise the default one
    :param only_best: return only the pareto front instead of all entries
    """
    assert len(sorted_by) == len(maximize) > 0

    # only one objective
    if len(sorted_by) == 1:
        all_entries = sorted(self.get_all(), key=lambda s: s.get(sorted_by[0], data_set), reverse=maximize[0])
        if only_best:
            # keep all leading entries that share the best value
            i, best_entries = 1, [all_entries[0]]
            while i < len(all_entries) and all_entries[0].get(sorted_by[0], data_set) == all_entries[i].get(sorted_by[0], data_set):
                best_entries.append(all_entries[i])
                i += 1
            all_entries = best_entries
        return all_entries

    # multiple objectives
    else:
        nds = NonDominatedSorting()

        # first get all values, cast as a minimization problem
        entries = self.get_all()
        all_values = np.zeros(shape=(len(entries), len(sorted_by)))
        for i, entry in enumerate(entries):
            all_values[i] = [entry.get(s, data_set) * (-1 if m else 1) for s, m in zip(sorted_by, maximize)]

        # if we want only the best rank, we can sort smaller groups and concat them for speed
        if only_best:
            # find the best within small subgroups
            group_size, best_idx = 200, []
            for start in range(0, len(all_values), group_size):
                end = min(start + group_size, len(all_values))
                group = all_values[start:end]
                pareto_idx = nds.do(group, only_non_dominated_front=True) + start
                best_idx.append(pareto_idx)

            # concat the subgroups and find the best among their candidates
            best_idx = np.concatenate(best_idx, axis=0)
            pareto_idx = nds.do(all_values[best_idx], only_non_dominated_front=True)
            best_idx = best_idx[pareto_idx]
            all_entries = [entries[i] for i in best_idx]

        # we want all entries, sorted by their non-domination rank
        else:
            fronts = nds.do(all_values)
            sorted_idx = np.concatenate(fronts, axis=0)
            all_entries = [entries[i] for i in sorted_idx]

        return sorted(all_entries, key=lambda s: s.get(sorted_by[0], data_set), reverse=maximize[0])
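# ---------------------------------------------------------------------------
# Illustration only: the chunked speed-up in get_all_sorted() above relies on
# the fact that the pareto front of the full set is a subset of the union of
# the per-chunk pareto fronts, so most points can be discarded cheaply before
# the final (quadratic) filtering. A plain-NumPy sketch, where non_dominated()
# is a made-up helper standing in for NonDominatedSorting:
# ---------------------------------------------------------------------------
import numpy as np


def non_dominated(F: np.ndarray) -> np.ndarray:
    """ indices of the non-dominated points of F (minimization) """
    keep = []
    for i in range(len(F)):
        dominated = np.any(np.all(F <= F[i], axis=1) & np.any(F < F[i], axis=1))
        if not dominated:
            keep.append(i)
    return np.array(keep, dtype=int)


values = np.random.rand(1000, 2)
group_size, best_idx = 200, []

# pass 1: pareto-filter each chunk independently
for start in range(0, len(values), group_size):
    end = min(start + group_size, len(values))
    best_idx.append(non_dominated(values[start:end]) + start)

# pass 2: pareto-filter the (much smaller) union of the chunk fronts
best_idx = np.concatenate(best_idx)
front = best_idx[non_dominated(values[best_idx])]

# same result as filtering everything at once, but cheaper on large sets
assert set(front) == set(non_dominated(values))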
class AGEMOEASurvival(Survival):

    def __init__(self) -> None:
        super().__init__(filter_infeasible=True)
        self.nds = NonDominatedSorting()

    def _do(self, problem, pop, *args, n_survive=None, **kwargs):

        # get the objective values
        F = pop.get("F")
        N = n_survive

        # non-dominated sorting
        fronts = self.nds.do(F, n_stop_if_ranked=N)

        # get the max int value
        max_val = np.iinfo(int).max

        # initialize population ranks with the max int value
        front_no = np.full(F.shape[0], max_val, dtype=int)

        # assign the rank to each individual
        for i, fr in enumerate(fronts):
            front_no[fr] = i
        pop.set("rank", front_no)

        # get the index of the front to be sorted and cut
        max_f_no = np.max(front_no[front_no != max_val])

        # keep all fronts that have a lower rank than the front to cut
        selected: np.ndarray = front_no < max_f_no

        n_ind, _ = F.shape

        # crowding distance is positive and has to be maximized
        crowd_dist = np.zeros(n_ind)

        # get the first front for normalization
        front1 = F[front_no == 0, :]

        # follows from the definition of the ideal point, but with the current non-dominated solutions
        ideal_point = np.min(front1, axis=0)

        # calculate the crowding distance of the first front, as well as p and the normalization constants
        crowd_dist[front_no == 0], p, normalization = survival_score(front1, ideal_point)

        # score the later fronts, including the one to cut; the first front is already scored by survival_score
        for i in range(1, max_f_no + 1):
            front = F[front_no == i, :] / normalization
            crowd_dist[front_no == i] = 1. / minkowski_matrix(front, ideal_point[None, :], p=p).squeeze()

        # select the solutions in the last front based on their crowding distances
        last = np.arange(selected.shape[0])[front_no == max_f_no]
        rank = np.argsort(crowd_dist[last])[::-1]
        selected[last[rank[:N - np.sum(selected)]]] = True

        pop.set("crowding", crowd_dist)

        # return the selected solutions, their number should equal the population size
        return pop[selected]
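# ---------------------------------------------------------------------------
# Illustration only, not part of AGEMOEASurvival: for every front after the
# first one, the class above scores solutions by the inverse Minkowski distance
# to the ideal point (closer = better). The exponent p and the per-objective
# normalization are estimated from the first front by survival_score(); the
# values below are assumed purely for this sketch.
# ---------------------------------------------------------------------------
import numpy as np

p = 2.0                                  # assumed, survival_score() estimates this
normalization = np.array([10.0, 0.5])    # assumed per-objective scale
ideal_point = np.zeros(2)                # component-wise minimum of the first front

front = np.array([[4.0, 0.2],
                  [8.0, 0.1],
                  [2.0, 0.4]]) / normalization

# Minkowski distance of each solution to the ideal point, inverted so that
# solutions closer to the ideal point receive a larger crowding score
dist = np.sum(np.abs(front - ideal_point[None, :]) ** p, axis=1) ** (1.0 / p)
crowd_dist = 1.0 / dist
print(crowd_dist)  # the solution closest to the ideal point scores highest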
class AbstractPbtSelector(ArgsInterface):
    """
    to figure out which clients should save, which should load from where, ...
    """

    def __init__(self, weights_dir: str, logger: Logger, targets: [OptimizationTarget],
                 mutations: [AbstractPbtMutation], each_epochs: int, grace_epochs: int,
                 save_ema: bool, elitist: bool):
        super().__init__()
        self._nds = NonDominatedSorting()
        self._data = {
            'saves': dict(),     # {key: PbtSave}
            'replacements': [],  # ReplacementPbtEvent
        }
        self.weights_dir = weights_dir
        self.logger = logger
        self.targets = targets
        self.mutations = mutations
        self.each_epochs = each_epochs
        self.grace_epochs = grace_epochs
        self.save_ema = save_ema
        self.elitist = elitist

    @classmethod
    def meta_args_to_add(cls) -> [MetaArgument]:
        return [
            MetaArgument('cls_pbt_targets', Register.optimization_targets, allow_duplicates=True,
                         help_name='optimization target(s)'),
            MetaArgument('cls_pbt_mutations', Register.pbt_mutations,
                         help_name='mutations to the copied checkpoint training state'),
        ]

    @classmethod
    def args_to_add(cls, index=None) -> [Argument]:
        """ list arguments to add to argparse when this class (or a child class) is chosen """
        return super().args_to_add(index) + [
            Argument('each_epochs', default=1, type=int, help="only synchronize each n epochs"),
            Argument('grace_epochs', default=0, type=int, help="skip synchronization for the first n epochs"),
            Argument('save_ema', default="True", type=str, is_bool=True,
                     help="save the EMA-model weights if available, otherwise save the trained model's weights"),
            Argument('elitist', default="True", type=str, is_bool=True,
                     help="elitist: keep old checkpoints if they are better"),
        ]

    @classmethod
    def from_args(cls, save_dir: str, logger: Logger, args: Namespace, index=None) -> 'AbstractPbtSelector':
        targets = [cls_target.from_args(args, index=i) for i, cls_target in
                   enumerate(cls._parsed_meta_arguments(Register.optimization_targets, 'cls_pbt_targets', args, index))]
        mutations = [cls_mutation.from_args(args, index=i) for i, cls_mutation in
                     enumerate(cls._parsed_meta_arguments(Register.pbt_mutations, 'cls_pbt_mutations', args, index))]
        return cls(save_dir, logger, targets, mutations, **cls._all_parsed_arguments(args, index=index))

    @classmethod
    def _meta_name(cls, save_dir: str) -> str:
        return "%s/%s.meta.pt" % (save_dir, cls.__name__)

    def save(self, save_dir: str):
        torch.save(self._data, self._meta_name(save_dir))

    def load(self, save_dir: str):
        self._data = torch.load(self._meta_name(save_dir))

    # keep track of saves/checkpoints

    def get_saves(self) -> {str: PbtSave}:
        return self._data['saves']

    def get_saves_list(self) -> [PbtSave]:
        return list(self.get_saves().values())

    def get_save(self, epoch: int, client_id: int) -> Union[None, PbtSave]:
        for save in self.get_saves_list():
            if save.epoch == epoch and save.client_id == client_id:
                return save

    def save_client(self, epoch: int, client_id: int, log_dict: dict) -> PbtSave:
        """ create a save entry for the client's current state """
        path = "%s/%d-%d-weights.pt" % (self.weights_dir, epoch, client_id)
        save = PbtSave(epoch, client_id, log_dict, path=path)
        self._data['saves'][save.key] = save
        return save

    def remove_saves_by_keys(self, keys: list):
        """ remove specific saves """
        if len(keys) > 0:
            self.logger.info("Removing saves:")
            saves = [self.get_saves()[k] for k in keys]
            self.log_saves(saves)
            for s in saves:
                s.remove_file()
                del self._data['saves'][s.key]

    def remove_unused_saves(self):
        """ remove all saves that are not marked as used """
        self.remove_saves_by_keys([save.key for save in self.get_saves_list() if not save.is_used()])

    def get_best(self, saves: [PbtSave], epoch: int = None, exclude_old=False) -> List[List[PbtSave]]:
        """
        get the saves in order from best to worst, optionally preferring a specific epoch

        :param saves: the saves to rank
        :param epoch: prefer checkpoints of this epoch, append the others at the back
        :param exclude_old: do not append any old checkpoints, only functions if epoch is set
        """
        # get the ranking
        log_dicts = [c.log_dict for c in saves]
        values = np.array([[target.sort_value(ld) for target in self.targets] for ld in log_dicts])
        _, rank = self._nds.do(values, return_rank=True)

        # sort the saves, possibly moving older ones to the back
        best, old_best = [], []
        for i in range(max(rank) + 1):
            # all saves of rank i
            best_ = [c for c, r in zip(saves, rank) if r == i]
            # old ones separated?
            if isinstance(epoch, int):
                best.append([s for s in best_ if s.epoch == epoch])
                if not exclude_old:
                    old_best.append([s for s in best_ if s.epoch != epoch])
            else:
                best.append(best_)

        # add old ones (maybe empty), remove empty lists, return
        best += old_best
        best = [b for b in best if len(b) > 0]
        return best

    def cleanup(self, epoch: int = None):
        """
        keep the best saves (pareto front), remove the rest

        :param epoch: if the selector is not elitist, remove older saves regardless of performance
        """
        best = self.get_best(self.get_saves_list(), epoch=None if self.elitist else epoch)
        keys = []
        for i, rank_i in enumerate(best):
            if i == 0:
                continue
            keys.extend([c.key for c in rank_i])
        self.remove_saves_by_keys(keys)
        assert len(self.get_saves()) > 0

    def log_saves(self, saves: [PbtSave] = None):
        if saves is None:
            self.logger.info("All saves:")
            saves = flatten(self.get_best(self.get_saves_list()))
        saves = sorted(saves, key=lambda save: self.targets[0].sort_value(save.log_dict))
        lines = [["epoch=%d" % save.epoch]
                 + ["client=%d" % save.client_id]
                 + [target.as_str(save.log_dict) for target in self.targets]
                 + [save.get_path() if save.is_used() else ""]
                 for save in saves]
        log_in_columns(self.logger, lines, add_bullets=True)

    # keep track of events

    def add_replacement_event(self, r: ReplacementPbtEvent):
        self._data['replacements'].append(r)

    def get_replacement_events(self, epoch: int = None, client_id: int = None):
        """
        get a list of all replacements that happened, optionally filtered

        :param epoch: only keep events of this epoch
        :param client_id: only keep events of this client
        """
        events = self._data['replacements']
        if isinstance(epoch, int):
            events = [e for e in events if e.epoch == epoch]
        if isinstance(client_id, int):
            events = [e for e in events if e.client_id == client_id]
        return events

    def log_events(self, events: [AbstractPbtEvent], text: str = None):
        if len(events) > 0:
            if isinstance(text, str):
                self.logger.info(text)
            lines = [e.str_columns() for e in events]
            log_in_columns(self.logger, lines, add_bullets=True)

    # selecting, mutating, responses

    @classmethod
    def empty_response(cls, client_id: int) -> PbtServerResponse:
        return PbtServerResponse(client_id=client_id)

    def is_interesting(self, epoch: int, log_dict: dict) -> bool:
        """ if a watched key is not in the log dict, there is no need to synchronize """
        if epoch < self.grace_epochs:
            return False
        if (epoch - self.grace_epochs + 1) % self.each_epochs > 0:
            return False
        return all([target.is_interesting(log_dict) for target in self.targets])

    def first_use(self, log_dicts: {int: dict}) -> {int: (dict, PbtServerResponse)}:
        """ create the changed log dict and response for each client """
        ret = {}
        for client_id, log_dict in log_dicts.items():
            ld, r = log_dict, self.empty_response(client_id)
            for m in self.mutations:
                ld, r = m.initial_mutate(r, ld, len(log_dicts))
            ret[client_id] = (ld, r)
        return ret
    def select(self, epoch: int, log_dicts: {int: dict}) -> {int: PbtServerResponse}:
        """ create the responses for each client """
        # reset mutations
        for m in self.mutations:
            m.reset()

        # reset usage of all saves
        for save in self.get_saves_list():
            save.reset()

        # add all to saves, without actually saving states yet (not necessary for all)
        for client_id, log_dict in log_dicts.items():
            self.save_client(epoch, client_id, log_dict)

        # empty responses
        responses = {}
        for client_id, log_dict in log_dicts.items():
            responses[client_id] = PbtServerResponse(client_id=client_id, save_ema=self.save_ema)

        # get replacements
        for (replace, replace_with) in self._select(responses, epoch, log_dicts):
            replace_with.add_usage()
            e = ReplacementPbtEvent(replace.epoch, replace.client_id,
                                    replace_with.epoch, replace_with.client_id, replace_with.get_path())
            self.add_replacement_event(e)
            responses[replace.client_id].load_path = replace_with.get_path()

            # apply mutations
            for m in self.mutations:
                responses[replace.client_id] = m.mutate(responses[replace.client_id], replace_with.log_dict)

        # remove all unused saves
        self.remove_unused_saves()

        # log replacements
        self.log_events(self.get_replacement_events(epoch=epoch), text="Replacements:")

        # log remaining saves, return responses
        self.log_saves()
        return responses

    def _select(self, responses: {int: PbtServerResponse}, epoch: int, log_dicts: {int: dict}) -> [(PbtSave, PbtSave)]:
        """
        create the responses for each client, mark all saves that should be kept

        :param responses: {client_id: PbtServerResponse}
        :param epoch: current epoch
        :param log_dicts: {client_id: dict}
        :return: replacements [(to_replace, replace_with)]
        """
        raise NotImplementedError

    # plotting results

    def plot_results(self, save_dir: str, log_dicts: dict):
        """ task complete, plot the results, {epoch: {client_id: log_dict}} """
        # reshape into {key: {client_id: list of entries}}
        epochs = sorted(list(log_dicts.keys()))
        clients = list(log_dicts[epochs[0]].keys())
        keys = list(log_dicts[epochs[-1]][clients[0]].keys())
        reshaped_log_dicts = {k: {c: list() for c in clients} for k in keys}
        for e in epochs:
            for c in clients:
                for k in keys:
                    reshaped_log_dicts[k][c].append(log_dicts[e][c].get(k, None))
        # actually plot
        self._plot_results(save_dir, log_dicts, reshaped_log_dicts, epochs, clients, keys)

    def _plot_results(self, save_dir: str, log_dicts: dict, reshaped_log_dicts: dict,
                      epochs: [int], clients: [int], keys: [str]):
        """
        :param save_dir: where to save
        :param log_dicts: {epoch: {client_id: {key: value}}}
        :param reshaped_log_dicts: {key: {client_id: list of values}}
        :param epochs: list of epochs
        :param clients: list of client ids
        :param keys: list of keys in the log dicts
        """
        # plot all values by their keys
        for key in keys:
            # metric keys (train/val/test) are plotted as plain lines without markers or legend
            is_id_relevant = not any([key.startswith(n) for n in ['train', 'val', 'test']])
            fmt = 'o--' if is_id_relevant else '-'
            for c in clients:
                plt.plot(epochs, reshaped_log_dicts[key][c], fmt, label="%d" % c)
            if is_id_relevant:
                plt.legend()
            plt.xlabel("epoch")
            plt.ylabel(key)
            plt.savefig("%s/key_%s.pdf" % (save_dir, key.replace('/', '_')))
            plt.clf()

        # lineage
        for e0, e1 in zip(epochs[:-1], epochs[1:]):
            for client_id in clients:
                if e0 == 0:
                    plt.plot((e0, e0), (client_id, client_id), 'bo-', linewidth=1)
                r = self.get_replacement_events(epoch=e0, client_id=client_id)
                if len(r) == 0:
                    # no replacement happened, the client continued from its own state
                    plt.plot((e0, e1), (client_id, client_id), 'bo-', linewidth=1)
                else:
                    # a replacement happened, the client continued from another client's checkpoint
                    if len(r) > 1:
                        self.logger.warning("Have %d replacements for client_id=%d, epoch1=%d, only taking the first"
                                            % (len(r), client_id, e1))
                    r = r[0]
                    plt.plot((r.epoch_replaced_with, e1), (r.client_replaced_with, client_id), 'ro--', linewidth=2)
        plt.xlabel("epoch")
        plt.ylabel("client id")
        plt.savefig("%s/lineage.pdf" % save_dir)
        plt.clf()
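# ---------------------------------------------------------------------------
# Illustration only, not part of AbstractPbtSelector: get_best() above turns
# each save's log_dict into one to-be-minimized value per optimization target,
# ranks the rows by non-domination, and groups the saves per rank (rank 0 is
# the pareto front that cleanup() keeps). A self-contained sketch with made-up
# targets (maximize accuracy, hence negated, and minimize latency), using a
# naive front-peeling loop in place of NonDominatedSorting(..., return_rank=True):
# ---------------------------------------------------------------------------
import numpy as np

log_dicts = [
    {'client': 0, 'val/accuracy': 0.91, 'latency': 20.0},
    {'client': 1, 'val/accuracy': 0.93, 'latency': 35.0},
    {'client': 2, 'val/accuracy': 0.90, 'latency': 25.0},  # dominated by client 0
    {'client': 3, 'val/accuracy': 0.89, 'latency': 10.0},
]

# one row per save, one column per target, everything cast as minimization
values = np.array([[-ld['val/accuracy'], ld['latency']] for ld in log_dicts])

# peel off one pareto front after another, assigning rank 0, 1, ...
rank = np.zeros(len(values), dtype=int)
remaining, r = np.arange(len(values)), 0
while len(remaining) > 0:
    sub = values[remaining]
    dominated = np.array([
        np.any(np.all(sub <= sub[i], axis=1) & np.any(sub < sub[i], axis=1))
        for i in range(len(sub))
    ])
    rank[remaining[~dominated]] = r
    remaining = remaining[dominated]
    r += 1

# group the clients from best to worst rank, as get_best() does with saves
best = [[ld['client'] for ld, q in zip(log_dicts, rank) if q == i] for i in range(r)]
print(best)  # [[0, 1, 3], [2]]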