Example #1
# assumed imports, following pymoo's 0.4.x module layout (pymoo 0.5+ renamed
# pymoo.model to pymoo.core and moved calc_crowding_distance):
import numpy as np

from pymoo.algorithms.nsga2 import calc_crowding_distance
from pymoo.model.survival import Survival
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting
from pymoo.util.randomized_argsort import randomized_argsort


class RankAndCrowdingSurvival(Survival):
    def __init__(self) -> None:
        super().__init__(filter_infeasible=True)
        self.nds = NonDominatedSorting()

    def _do(self, problem, pop, n_survive, D=None, **kwargs):

        # get the objective space values and objects
        F = pop.get("F").astype(np.float, copy=False)

        # the final indices of surviving individuals
        survivors = []

        # do the non-dominated sorting until splitting front
        fronts = self.nds.do(F, n_stop_if_ranked=n_survive)

        for k, front in enumerate(fronts):

            # calculate the crowding distance of the front
            crowding_of_front = calc_crowding_distance(F[front, :])

            # save rank and crowding in the individual class
            for j, i in enumerate(front):
                pop[i].set("rank", k)
                pop[i].set("crowding", crowding_of_front[j])

            # current front sorted by crowding distance if splitting
            if len(survivors) + len(front) > n_survive:
                I = randomized_argsort(crowding_of_front,
                                       order='descending',
                                       method='numpy')
                I = I[:(n_survive - len(survivors))]

            # otherwise take the whole front unsorted
            else:
                I = np.arange(len(front))

            # extend the survivors by all or selected individuals
            survivors.extend(front[I])

        return pop[survivors]
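
For context, a minimal sketch of the non-dominated sort this operator builds on, assuming the pymoo import path above; `n_stop_if_ranked` stops ranking as soon as enough individuals are covered:

import numpy as np
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting

# four points in objective space, both objectives minimized
F = np.array([[1., 4.], [2., 2.], [4., 1.], [3., 3.]])
fronts = NonDominatedSorting().do(F, n_stop_if_ranked=3)
# fronts -> [array([0, 1, 2])]: sorting stops once 3 individuals are ranked,
# so point 3 (dominated by point 1) is never assigned a front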
Example #2
    def get_all_sorted(self,
                       sorted_by: [str],
                       maximize: [bool],
                       data_set: str = None,
                       only_best=False) -> [MiniResult]:
        """
        get all / the best entries sorted by 1 to N objectives

        :param sorted_by: sort by 'acc1', 'flops', ... single or multi-objective
        :param maximize: for each sorted_by key, whether to maximize the value
        :param data_set: use a specific data set, otherwise default
        :param only_best: return all or only the pareto front
        """
        assert len(sorted_by) == len(maximize) > 0
        # only one objective
        if len(sorted_by) == 1:
            all_entries = sorted(self.get_all(),
                                 key=lambda s: s.get(sorted_by[0], data_set),
                                 reverse=maximize[0])
            if only_best:
                # keep every entry that is tied with the best value
                best_value = all_entries[0].get(sorted_by[0], data_set)
                all_entries = [e for e in all_entries
                               if e.get(sorted_by[0], data_set) == best_value]
            return all_entries
        # multiple objectives
        else:
            nds = NonDominatedSorting()
            # first get all values, minimization problem
            entries = self.get_all()
            all_values = np.zeros(shape=(len(entries), len(sorted_by)))
            for i, entry in enumerate(entries):
                all_values[i] = [
                    entry.get(s, data_set) * (-1 if m else 1)
                    for s, m in zip(sorted_by, maximize)
                ]
            # if we only want the best rank, we can filter smaller groups first for speed:
            # the global pareto front is a subset of the union of the subgroup fronts
            if only_best:
                # find best in small subgroups
                group_size, best_idx = 200, []
                num_groups = (len(all_values) + group_size - 1) // group_size
                for i in range(num_groups):
                    start, end = group_size * i, min(group_size * (i + 1),
                                                     len(all_values))
                    idx = np.arange(start, end)
                    group = all_values[idx]
                    pareto_idx = nds.do(group,
                                        only_non_dominated_front=True) + start
                    best_idx.append(pareto_idx)
                # concat subgroups and find best
                best_idx = np.concatenate(best_idx, axis=0)
                pareto_idx = nds.do(all_values[best_idx],
                                    only_non_dominated_front=True)
                best_idx = best_idx[pareto_idx]
                all_entries = [entries[i] for i in best_idx]
            # we want all entries, sorted by domination rank
            else:
                fronts = nds.do(all_values)
                all_entries = [entries[i] for i in np.concatenate(fronts)]
        return sorted(all_entries,
                      key=lambda s: s.get(sorted_by[0], data_set),
                      reverse=maximize[0])
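
The `* (-1 if m else 1)` factor above turns every maximization objective into a minimization one, which is what `NonDominatedSorting` expects; a self-contained illustration (same pymoo import assumption as above, data values made up):

import numpy as np
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting

# maximize accuracy, minimize FLOPs -> negate the accuracy column
raw = np.array([[0.75, 400.], [0.74, 200.], [0.70, 100.], [0.74, 300.]])
values = raw * np.array([-1., 1.])
front = NonDominatedSorting().do(values, only_non_dominated_front=True)
# front -> [0, 1, 2]; entry 3 is dominated by entry 1 (equal accuracy, more FLOPs)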
Example #3
# numpy plus the same Survival / NonDominatedSorting imports as in Example #1 are
# assumed; survival_score and minkowski_matrix are helpers from the source module
import numpy as np


class AGEMOEASurvival(Survival):
    def __init__(self) -> None:
        super().__init__(filter_infeasible=True)
        self.nds = NonDominatedSorting()

    def _do(self, problem, pop, *args, n_survive=None, **kwargs):

        # get the objective values
        F = pop.get("F")

        N = n_survive

        # Non-dominated sorting
        fronts = self.nds.do(F, n_stop_if_ranked=N)

        # get max int value (np.int was removed in NumPy 1.24, use an explicit dtype)
        max_val = np.iinfo(np.int64).max

        # initialize population ranks with max int value
        front_no = np.full(F.shape[0], max_val, dtype=np.int64)

        # assign the rank to each individual
        for i, fr in enumerate(fronts):
            front_no[fr] = i

        pop.set("rank", front_no)

        # get the index of the front to be sorted and cut
        max_f_no = np.max(front_no[front_no != max_val])

        # keep fronts that have lower rank than the front to cut
        selected: np.ndarray = front_no < max_f_no

        n_ind, _ = F.shape

        # crowding distance is positive and has to be maximized
        crowd_dist = np.zeros(n_ind)

        # get the first front for normalization
        front1 = F[front_no == 0, :]

        # follows from the definition of the ideal point but with current non dominated solutions
        ideal_point = np.min(front1, axis=0)

        # Calculate the crowding distance of the first front as well as p and the normalization constants
        crowd_dist[front_no == 0], p, normalization = survival_score(
            front1, ideal_point)
        # skip the first front, it was already scored by survival_score; include the
        # last front (i == max_f_no), since its distances decide who survives the cut
        for i in range(1, max_f_no + 1):
            front = F[front_no == i, :]
            front = front / normalization
            crowd_dist[front_no == i] = 1. / minkowski_matrix(
                front, ideal_point[None, :], p=p).squeeze()

        # Select the solutions in the last front based on their crowding distances
        last = np.arange(selected.shape[0])[front_no == max_f_no]
        rank = np.argsort(crowd_dist[last])[::-1]
        selected[last[rank[:N - np.sum(selected)]]] = True

        pop.set("crowding", crowd_dist)

        # return selected solutions, number of selected should be equal to population size
        return pop[selected]
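
This survival operator is the selection step of AGE-MOEA; a minimal end-to-end run, assuming pymoo >= 0.6 module paths (older releases place these modules elsewhere):

from pymoo.algorithms.moo.age import AGEMOEA
from pymoo.optimize import minimize
from pymoo.problems import get_problem

res = minimize(get_problem("zdt1"), AGEMOEA(pop_size=100), ("n_gen", 100),
               seed=1, verbose=False)
print(res.F.shape)  # objective values of the final non-dominated set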
Example #4
class AbstractPbtSelector(ArgsInterface):
    """
    to figure out which clients should save, which should load from where, ...
    """

    def __init__(self, weights_dir: str, logger: Logger, targets: [OptimizationTarget], mutations: [AbstractPbtMutation],
                 each_epochs: int, grace_epochs: int, save_ema: bool, elitist: bool):
        super().__init__()
        self._nds = NonDominatedSorting()
        self._data = {
            'saves': dict(),         # {key: PbtSave}
            'replacements': [],      # ReplacementPbtEvent
        }
        self.weights_dir = weights_dir
        self.logger = logger
        self.targets = targets
        self.mutations = mutations
        self.each_epochs = each_epochs
        self.grace_epochs = grace_epochs
        self.save_ema = save_ema
        self.elitist = elitist

    @classmethod
    def meta_args_to_add(cls) -> [MetaArgument]:
        return [
            MetaArgument('cls_pbt_targets', Register.optimization_targets, allow_duplicates=True,
                         help_name='optimization target(s)'),
            MetaArgument('cls_pbt_mutations', Register.pbt_mutations,
                         help_name='mutations to the copied checkpoint training state'),
        ]

    @classmethod
    def args_to_add(cls, index=None) -> [Argument]:
        """ list arguments to add to argparse when this class (or a child class) is chosen """
        return super().args_to_add(index) + [
            Argument('each_epochs', default=1, type=int, help="only synchronize each n epochs"),
            Argument('grace_epochs', default=0, type=int, help="skip synchronization for the first n epochs"),
            Argument('save_ema', default="True", type=str, is_bool=True,
                     help="save the EMA-model weights if available, otherwise save the trained model's weights"),
            Argument('elitist', default="True", type=str, is_bool=True,
                     help="elitist: keep old checkpoints if they are better"),
        ]

    @classmethod
    def from_args(cls, save_dir: str, logger: Logger, args: Namespace, index=None) -> 'AbstractPbtSelector':
        targets = [cls_target.from_args(args, index=i) for
                   i, cls_target in enumerate(cls._parsed_meta_arguments(Register.optimization_targets, 'cls_pbt_targets', args, index))]
        mutations = [cls_mutations.from_args(args, index=i) for
                     i, cls_mutations in enumerate(cls._parsed_meta_arguments(Register.pbt_mutations, 'cls_pbt_mutations', args, index))]
        return cls(save_dir, logger, targets, mutations, **cls._all_parsed_arguments(args, index=index))

    @classmethod
    def _meta_name(cls, save_dir: str) -> str:
        return "%s/%s.meta.pt" % (save_dir, cls.__name__)

    def save(self, save_dir: str):
        torch.save(self._data, self._meta_name(save_dir))

    def load(self, save_dir: str):
        self._data = torch.load(self._meta_name(save_dir))

    # keep track of saves/checkpoints

    def get_saves(self) -> {str: PbtSave}:
        return self._data['saves']

    def get_saves_list(self) -> [PbtSave]:
        return list(self.get_saves().values())

    def get_save(self, epoch: int, client_id: int) -> Union[None, PbtSave]:
        for save in self.get_saves_list():
            if save.epoch == epoch and save.client_id == client_id:
                return save

    def save_client(self, epoch: int, client_id: int, log_dict: dict) -> PbtSave:
        """ save """
        path = "%s/%d-%d-weights.pt" % (self.weights_dir, epoch, client_id)
        save = PbtSave(epoch, client_id, log_dict, path=path)
        self._data['saves'][save.key] = save
        return save

    def remove_saves_by_keys(self, keys: list):
        """ specific saves """
        if len(keys) > 0:
            self.logger.info("Removing saves:")
            saves = [self.get_saves()[k] for k in keys]
            self.log_saves(saves)
            for s in saves:
                s.remove_file()
                del self._data['saves'][s.key]

    def remove_unused_saves(self):
        """ remove all saves that are not marked as used """
        self.remove_saves_by_keys([save.key for save in self.get_saves_list() if not save.is_used()])

    def get_best(self, saves: [PbtSave], epoch: int = None, exclude_old=False) -> List[List[PbtSave]]:
        """
        get the saves in order from best to worst, optionally of a specific epoch
        :param saves: the saves to rank
        :param epoch: prefer checkpoints of this epoch, append the others at the back
        :param exclude_old: do not append any old checkpoints, only functions if epoch is set
        """
        # get the ranking
        log_dicts = [c.log_dict for c in saves]
        values = np.array([[target.sort_value(ld) for target in self.targets] for ld in log_dicts])
        _, rank = self._nds.do(values, return_rank=True)
        # sort the saves, possibly older ones to the back
        best = []
        old_best = []
        for i in range(max(rank)+1):
            # rank i
            best_ = []
            for c, r in zip(saves, rank):
                if r == i:
                    best_.append(c)
            # old ones separated?
            if isinstance(epoch, int):
                best.append([s for s in best_ if s.epoch == epoch])
                if not exclude_old:
                    old_best.append([s for s in best_ if s.epoch != epoch])
            else:
                best.append(best_)
        # add old ones (maybe empty), remove empty lists, return
        best += old_best
        best = [b for b in best if len(b) > 0]
        return best

    def cleanup(self, epoch: int = None):
        """
        keep the best saves (pareto front), remove the rest
        :param epoch: if the selector is not elitist, remove older saves regardless of performance
        """
        best = self.get_best(self.get_saves_list(), epoch=None if self.elitist else epoch)
        keys = []
        for i, rank_i in enumerate(best):
            if i == 0:
                continue
            keys.extend([c.key for c in rank_i])
        self.remove_saves_by_keys(keys)
        assert len(self.get_saves()) > 0

    def log_saves(self, saves: [PbtSave] = None):
        if saves is None:
            self.logger.info("All saves:")
            saves = flatten(self.get_best(self.get_saves_list()))
        saves = sorted(saves, key=lambda save: self.targets[0].sort_value(save.log_dict))
        lines = [["epoch=%d" % save.epoch]
                 + ["client=%d" % save.client_id]
                 + [target.as_str(save.log_dict) for target in self.targets]
                 + [save.get_path() if save.is_used() else ""]
                 for save in saves]
        log_in_columns(self.logger, lines, add_bullets=True)

    # keep track of events

    def add_replacement_event(self, r: ReplacementPbtEvent):
        self._data['replacements'].append(r)

    def get_replacement_events(self, epoch: int = None, client_id: int = None):
        """
        get list of all replacements that happened, optionally filter
        :param epoch:
        :param client_id:
        """
        events = self._data['replacements']
        if isinstance(epoch, int):
            events = [e for e in events if e.epoch == epoch]
        if isinstance(client_id, int):
            events = [e for e in events if e.client_id == client_id]
        return events

    def log_events(self, events: [AbstractPbtEvent], text: str = None):
        if len(events) > 0:
            if isinstance(text, str):
                self.logger.info(text)
            lines = [e.str_columns() for e in events]
            log_in_columns(self.logger, lines, add_bullets=True)

    # selecting, mutating, responses

    @classmethod
    def empty_response(cls, client_id: int) -> PbtServerResponse:
        return PbtServerResponse(client_id=client_id)

    def is_interesting(self, epoch: int, log_dict: dict) -> bool:
        """ if the watched key is not in the log dict, no need to synchronize """
        if epoch < self.grace_epochs:
            return False
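        # first synchronization at epoch (grace_epochs + each_epochs - 1), then every each_epochs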
        if (epoch - self.grace_epochs + 1) % self.each_epochs > 0:
            return False
        return all([target.is_interesting(log_dict) for target in self.targets])

    def first_use(self, log_dicts: {int: dict}) -> {int: (dict, PbtServerResponse)}:
        """ create the changed log dict and response for each client """
        ret = {}
        for client_id, log_dict in log_dicts.items():
            ld, r = log_dict, self.empty_response(client_id)
            for m in self.mutations:
                ld, r = m.initial_mutate(r, ld, len(log_dicts))
            ret[client_id] = (ld, r)
        return ret

    def select(self, epoch: int, log_dicts: {int: dict}) -> {int: PbtServerResponse}:
        """ create the responses for each client """
        # reset mutations
        for m in self.mutations:
            m.reset()
        # reset usage of all saves
        for save in self.get_saves_list():
            save.reset()
        # add all to saves, without actually saving states yet (not necessary for all)
        for client_id, log_dict in log_dicts.items():
            self.save_client(epoch, client_id, log_dict)
        # empty responses
        responses = {}
        for client_id, log_dict in log_dicts.items():
            responses[client_id] = PbtServerResponse(client_id=client_id, save_ema=self.save_ema)
        # get replacements
        for (replace, replace_with) in self._select(responses, epoch, log_dicts):
            replace_with.add_usage()
            e = ReplacementPbtEvent(replace.epoch, replace.client_id,
                                    replace_with.epoch, replace_with.client_id, replace_with.get_path())
            self.add_replacement_event(e)
            responses[replace.client_id].load_path = replace_with.get_path()
            # apply mutations
            for m in self.mutations:
                responses[replace.client_id] = m.mutate(responses[replace.client_id], replace_with.log_dict)
        # remove all unused saves
        self.remove_unused_saves()
        # log replacements
        self.log_events(self.get_replacement_events(epoch=epoch), text="Replacements:")
        # log remaining saves, return responses
        self.log_saves()
        return responses

    def _select(self, responses: {int: PbtServerResponse}, epoch: int, log_dicts: {int: dict}) -> [(PbtSave, PbtSave)]:
        """
        create the responses for each client, {client_id: log_dict}, mark all saves that should be kept
        :param responses: {client_id: PbtServerResponse}
        :param epoch: current epoch
        :param log_dicts: {client_id: dict}
        :return replacements [(to_replace, replace_with)]
        """
        raise NotImplementedError

    # plotting results

    def plot_results(self, save_dir: str, log_dicts: dict):
        """ task complete, plot results, {epoch: {client_id: log_dict}} """
        # reshape into {key: {client_id: list of entries}}
        epochs = sorted(list(log_dicts.keys()))
        clients = list(log_dicts[epochs[0]].keys())
        keys = log_dicts[epochs[-1]][clients[0]]
        reshaped_log_dicts = {k: {c: list() for c in clients} for k in keys}
        for e in epochs:
            for c in clients:
                for k in keys:
                    reshaped_log_dicts[k][c].append(log_dicts[e][c].get(k, None))
        # actually plot
        self._plot_results(save_dir, log_dicts, reshaped_log_dicts, epochs, clients, keys)

    def _plot_results(self, save_dir: str, log_dicts: dict, reshaped_log_dicts: dict,
                      epochs: [int], clients: [int], keys: [str]):
        """

        :param save_dir: where to save
        :param log_dicts: {epoch: {client_id: {key: value}}}
        :param reshaped_log_dicts: {key: {client_id: list of values}}
        :param epochs: list of epochs
        :param clients: list of client ids
        :param keys: list of keys in the log dicts
        """
        # plot all by keys
        for key in keys:
            is_id_relevant = True
            for n in ['train', 'val', 'test']:
                if key.startswith(n):
                    is_id_relevant = False
                    break
            fmt = 'o--' if is_id_relevant else '-'
            for c in clients:
                plt.plot(epochs, reshaped_log_dicts[key][c], fmt, label="%d" % c)
            if is_id_relevant:
                plt.legend()
            plt.xlabel("epoch")
            plt.ylabel(key)
            plt.savefig("%s/key_%s.pdf" % (save_dir, key.replace('/', '_')))
            plt.clf()

        # lineage
        for e0, e1 in zip(epochs[:-1], epochs[1:]):
            for client_id in clients:
                if e0 == 0:
                    plt.plot((e0, e0), (client_id, client_id), 'bo-', linewidth=1)
                r = self.get_replacement_events(epoch=e0, client_id=client_id)
                if len(r) == 0:
                    # no replacement happened
                    plt.plot((e0, e1), (client_id, client_id), 'bo-', linewidth=1)
                else:
                    # a replacement happened, draw the lineage edge from the source checkpoint
                    if len(r) > 1:
                        self.logger.warning("Have %d replacements for client_id=%d, epoch1=%d, only taking the first" %
                                            (len(r), client_id, e1))
                    r = r[0]
                    plt.plot((r.epoch_replaced_with, e1), (r.client_replaced_with, client_id), 'ro--', linewidth=2)
        plt.xlabel("epoch")
        plt.ylabel("client id")
        plt.savefig("%s/lineage.pdf" % save_dir)
        plt.clf()
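
`get_best` above relies on `return_rank=True`, which yields one domination rank per row instead of a list of index fronts; a small sketch (same pymoo import assumption as in the earlier examples):

import numpy as np
from pymoo.util.nds.non_dominated_sorting import NonDominatedSorting

values = np.array([[1., 2.], [2., 1.], [2., 2.]])
fronts, rank = NonDominatedSorting().do(values, return_rank=True)
# rank -> [0, 0, 1]: the third row is dominated by both others

Since `_select` is abstract, a concrete selector only has to turn this ranking into (to_replace, replace_with) pairs. A hypothetical minimal subclass, sketched here for illustration (the class name and its one-line policy are not part of the source):

class GreedyPbtSelector(AbstractPbtSelector):
    """ hypothetical selector: the worst client loads the best client's checkpoint """

    def _select(self, responses, epoch, log_dicts):
        saves = [self.get_save(epoch, client_id) for client_id in log_dicts.keys()]
        ranked = self.get_best(saves, epoch=epoch)
        best, worst = ranked[0][0], ranked[-1][-1]
        # select() marks the 'replace_with' save as used and logs the event
        return [] if best is worst else [(worst, best)]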