Example #1
0
    def _start_experiments(self):
        """Push the current configuration to every client node and start the run.

        Records the experiment start time in the shared settings, registers the
        experiment in the DB (when logging is enabled) so clients can attach
        their results to it, then POSTs the settings to each client's
        ``/experiments`` endpoint. If any client cannot be started, the whole
        run is aborted via ``self._terminate()``.
        """
        self.cc.settings['general']['distribution'][
            'start_time'] = time.strftime('%Y-%m-%d_%H-%M-%S')

        # If DB logging is enabled, create a new experiment and attach its ID to settings for clients
        db_logger = DbLogger()
        if db_logger.is_enabled:
            self.experiment_id = db_logger.create_experiment(self.cc.settings)
            self.cc.settings['general']['logging'][
                'experiment_id'] = self.experiment_id

        for client in self.cc.settings['general']['distribution'][
                'client_nodes']:
            address = 'http://{}:{}/experiments'.format(
                client['address'], client['port'])
            try:
                resp = requests.post(address, json=self.cc.settings)
                # Explicit status check instead of `assert`: assertions are
                # stripped under `python -O`, which would silently skip it.
                if resp.status_code != 200:
                    raise RuntimeError(resp.text)
                self._logger.info(
                    'Successfully started experiment on {}'.format(address))
            except (RuntimeError, requests.RequestException) as err:
                # RequestException covers connection refusals and timeouts,
                # which the previous AssertionError-only handler let
                # propagate as an uncaught traceback.
                self._logger.critical(
                    'Could not start experiment on {}: {}'.format(
                        address, err))
                self._terminate()
Example #2
0
    def __init__(self,
                 dataloader,
                 network_factory,
                 population_size=10,
                 tournament_size=2,
                 mutation_probability=0.9,
                 n_replacements=1,
                 sigma=0.25,
                 alpha=0.25,
                 default_adam_learning_rate=0.001,
                 calc_mixture=False,
                 mixture_sigma=0.01,
                 score_sample_size=10000,
                 discriminator_skip_each_nth_step=0,
                 enable_selection=True):
        """Initialize the trainer and wire it into the shared singletons.

        Keyword arguments act only as fallbacks: each value is first looked
        up in the trainer's configuration section (``self.settings``) and the
        constructor default is used when the key is absent.
        """

        super().__init__(dataloader, network_factory, population_size,
                         tournament_size, mutation_probability, n_replacements,
                         sigma, alpha)

        # Index of the batch currently being processed within an iteration.
        self.batch_number = 0

        # Configuration values override the constructor defaults when present.
        self._default_adam_learning_rate = self.settings.get(
            'default_adam_learning_rate', default_adam_learning_rate)
        self._discriminator_skip_each_nth_step = self.settings.get(
            'discriminator_skip_each_nth_step',
            discriminator_skip_each_nth_step)
        self._enable_selection = self.settings.get('enable_selection',
                                                   enable_selection)
        self.mixture_sigma = self.settings.get('mixture_sigma', mixture_sigma)

        self.neighbourhood = Neighbourhood.instance()

        # Tag every individual with its learning rate and a traceable id of
        # the form "<cell>/G<i>" (generators) or "<cell>/D<i>" (discriminators).
        for i, individual in enumerate(self.population_gen.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/G{}'.format(self.neighbourhood.cell_number, i)
        for i, individual in enumerate(self.population_dis.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/D{}'.format(self.neighbourhood.cell_number, i)

        # Publish both populations through the shared singleton and release
        # its lock so other threads can read them.
        self.concurrent_populations = ConcurrentPopulations.instance()
        self.concurrent_populations.generator = self.population_gen
        self.concurrent_populations.discriminator = self.population_dis
        self.concurrent_populations.unlock()

        # Attach DB logging to the experiment the master created (if any).
        experiment_id = ConfigurationContainer.instance(
        ).settings['general']['logging'].get('experiment_id', None)
        self.db_logger = DbLogger(current_experiment=experiment_id)

        # Score calculation is active when the config enables it (the
        # `calc_mixture` argument is only the default for the 'enabled' flag).
        if 'score' in self.settings and self.settings['score'].get(
                'enabled', calc_mixture):
            self.score_calc = ScoreCalculatorFactory.create()
            self.score_sample_size = self.settings['score'].get(
                'sample_size', score_sample_size)
            # Seed with the worst possible value for the chosen metric
            # (+inf when lower is better, -inf otherwise).
            self.score = float(
                'inf') if self.score_calc.is_reversed else float('-inf')
        else:
            self.score_calc = None
            self.score = 0
Example #3
0
    def _terminate(self, stop_clients=True, return_code=-1):
        """Stop the experiment and exit the process.

        :param stop_clients: when True, close all client communication first.
        :param return_code: process exit status (defaults to -1, i.e. failure).
        """
        import sys

        try:
            if stop_clients:
                self._logger.info('Stopping clients...')
                self.comms.close_all()  #stop_running_experiments()
        finally:
            # Mark the experiment as finished in the DB even when stopping
            # the clients failed.
            db_logger = DbLogger()
            # getattr guard: experiment_id is only assigned when DB logging
            # was enabled at start-up, so the attribute may not exist; an
            # AttributeError here would mask the original failure.
            if db_logger.is_enabled and getattr(self, 'experiment_id',
                                                None) is not None:
                db_logger.finish_experiment(self.experiment_id)

            # sys.exit raises SystemExit cleanly; the builtin exit() is meant
            # for interactive sessions and may be unavailable under `-S`.
            sys.exit(return_code)
Example #4
0
    def _terminate(self, stop_clients=True, return_code=-1):
        """Stop heartbeat and clients, close the experiment, and exit.

        :param stop_clients: when True, ask all nodes to stop running
            experiments before exiting.
        :param return_code: process exit status (defaults to -1, i.e. failure).
        """
        import sys

        try:
            if self.heartbeat_thread:
                self._logger.info('Stopping heartbeat...')
                # Signal the thread's stop event, then wait for it to finish.
                self.heartbeat_thread.stopped.set()
                self.heartbeat_thread.join()

            if stop_clients:
                self._logger.info('Stopping clients...')
                node_client = NodeClient(None)
                node_client.stop_running_experiments()
        finally:
            # Mark the experiment as finished in the DB even when shutdown
            # of heartbeat/clients failed.
            db_logger = DbLogger()
            # getattr guard: experiment_id is only assigned when DB logging
            # was enabled at start-up, so the attribute may not exist; an
            # AttributeError here would mask the original failure.
            if db_logger.is_enabled and getattr(self, 'experiment_id',
                                                None) is not None:
                db_logger.finish_experiment(self.experiment_id)

            # sys.exit raises SystemExit cleanly; the builtin exit() is meant
            # for interactive sessions and may be unavailable under `-S`.
            sys.exit(return_code)
Example #5
0
    def _start_experiments(self):
        """Assign workers to the grid, set up communicators, and start them.

        Records the experiment start time, registers the experiment in the DB
        (when logging is enabled), fills every empty grid space with the best
        available worker, parks the remaining processing units in a separate
        communicator, and finally starts each assigned worker with the shared
        settings.
        """
        self.cc.settings['general']['distribution'][
            'start_time'] = time.strftime('%Y-%m-%d_%H-%M-%S')

        # If DB logging is enabled, create a new experiment and attach its ID to settings for clients
        db_logger = DbLogger()
        if db_logger.is_enabled:
            self.experiment_id = db_logger.create_experiment(self.cc.settings)
            self.cc.settings['general']['logging'][
                'experiment_id'] = self.experiment_id

        # Fill every empty grid cell with the best remaining worker; color 0
        # marks members of the active-grid communicator.
        while self.grid.empty_spaces() > 0:
            worker = self.topology.get_best_worker()
            w_id = self.grid.assign_worker(worker)
            self.topology.assign_worker(worker)

            self._logger.info("Assigned worker {} to wid {}".format(
                worker, w_id))
            self.comms.send_task("new_comm",
                                 worker,
                                 data={
                                     "color": 0,
                                     "key": w_id
                                 })

        # Unused processing units join a separate communicator (color 1).
        for off_pu in self.topology.inactive_pu:
            self._logger.info("Assigned worker {} to rest".format(off_pu))
            self.comms.send_task("new_comm",
                                 off_pu,
                                 data={
                                     "color": 1,
                                     "key": off_pu
                                 })

        self.comms.new_comm(1, 0)

        # The grid configuration is identical for every processing unit, so
        # write it into the settings once instead of on every loop iteration.
        self.cc.settings["general"]["distribution"]["grid"][
            "config"] = self.grid.grid
        for proc_unit in self.grid.grid_to_list():
            self.comms.start_worker(proc_unit, self.cc.settings)
Example #6
0
    def _gather_results(self):
        """Collect populations from all clients, persist them, and score them.

        For every client node: saves each generator/discriminator network to
        the node's output directory, appends the generator mixture weights to
        ``mixture.yml``, renders sample images from the weighted generator
        mixture, optionally computes an inception/FID-style score, and logs
        results to the DB. Per-node failures are logged and skipped so one
        bad node does not abort result collection.
        """
        self._logger.info('Collecting results from clients...')

        # Initialize node client
        dataloader = self.cc.create_instance(
            self.cc.settings['dataloader']['dataset_name'])
        network_factory = self.cc.create_instance(
            self.cc.settings['network']['name'], dataloader.n_input_neurons)
        node_client = NodeClient(network_factory)
        db_logger = DbLogger()

        results = node_client.gather_results(
            self.cc.settings['general']['distribution']['client_nodes'], 120)

        scores = []
        for (node, generator_pop, discriminator_pop, weights_generator,
             weights_discriminator) in results:
            node_name = '{}:{}'.format(node['address'], node['port'])
            try:
                output_dir = self.get_and_create_output_dir(node)

                for generator in generator_pop.individuals:
                    # ':' is not valid in filenames on all platforms.
                    source = generator.source.replace(':', '-')
                    filename = '{}{}.pkl'.format(GENERATOR_PREFIX, source)
                    # Save under the same name that is recorded in
                    # mixture.yml (previously a second, hand-built
                    # 'generator-{}.pkl' name was used here).
                    torch.save(generator.genome.net.state_dict(),
                               os.path.join(output_dir, filename))

                    with open(os.path.join(output_dir, 'mixture.yml'),
                              "a") as file:
                        file.write('{}: {}\n'.format(
                            filename, weights_generator[generator.source]))

                for discriminator in discriminator_pop.individuals:
                    source = discriminator.source.replace(':', '-')
                    filename = '{}{}.pkl'.format(DISCRIMINATOR_PREFIX, source)
                    torch.save(discriminator.genome.net.state_dict(),
                               os.path.join(output_dir, filename))

                # Save images
                dataset = MixedGeneratorDataset(
                    generator_pop, weights_generator,
                    self.cc.settings['master']['score_sample_size'],
                    self.cc.settings['trainer']
                    ['mixture_generator_samples_mode'])
                image_paths = self.save_samples(dataset, output_dir,
                                                dataloader)
                self._logger.info(
                    'Saved mixture result images of client {} to target directory {}.'
                    .format(node_name, output_dir))

                # Calculate inception or FID score
                score = float('-inf')
                if self.cc.settings['master']['calculate_score']:
                    calc = ScoreCalculatorFactory.create()
                    self._logger.info('Score calculator: {}'.format(
                        type(calc).__name__))
                    self._logger.info(
                        'Calculating score of {}. Depending on the type, this may take very long.'
                        .format(node_name))

                    score = calc.calculate(dataset)
                    self._logger.info(
                        'Node {} with weights {} yielded a score of {}'.format(
                            node_name, weights_generator, score))
                    scores.append((node, score))

                if db_logger.is_enabled and self.experiment_id is not None:
                    db_logger.add_experiment_results(self.experiment_id,
                                                     node_name, image_paths,
                                                     score)
            except Exception as ex:
                # Best-effort collection: log and continue with other nodes.
                self._logger.error(
                    'An error occurred while trying to gather results from {}: {}'
                    .format(node_name, ex))
                traceback.print_exc()

        if self.cc.settings['master']['calculate_score'] and scores:
            # Sort ascending (or descending when the metric is reversed,
            # i.e. lower-is-better) and take the last element as the winner.
            best_node = sorted(
                scores,
                key=lambda x: x[1],
                reverse=ScoreCalculatorFactory.create().is_reversed)[-1]
            self._logger.info('Best result: {}:{} = {}'.format(
                best_node[0]['address'], best_node[0]['port'], best_node[1]))
Example #7
0
    def __init__(
        self,
        dataloader,
        network_factory,
        population_size=10,
        tournament_size=2,
        mutation_probability=0.9,
        n_replacements=1,
        sigma=0.25,
        alpha=0.25,
        default_adam_learning_rate=0.001,
        calc_mixture=False,
        mixture_sigma=0.01,
        score_sample_size=10000,
        discriminator_skip_each_nth_step=0,
        enable_selection=True,
        fitness_sample_size=10000,
        calculate_net_weights_dist=False,
        fitness_mode="worst",
        es_score_sample_size=10000,
        es_random_init=False,
        checkpoint_period=0,
    ):
        """Configure the trainer from constructor defaults and the config file.

        Keyword arguments act only as fallbacks: each value is first looked
        up in the trainer's configuration section (``self.settings``) or the
        global ``ConfigurationContainer``. Raises ``KeyError`` when the
        configuration has no ``fitness`` section, and ``NotImplementedError``
        when the configured fitness mode is unknown.
        """

        super().__init__(
            dataloader,
            network_factory,
            population_size,
            tournament_size,
            mutation_probability,
            n_replacements,
            sigma,
            alpha,
        )

        # Index of the batch currently being processed within an iteration.
        self.batch_number = 0
        self.cc = ConfigurationContainer.instance()

        # Configuration values override the constructor defaults when present.
        self._default_adam_learning_rate = self.settings.get(
            "default_adam_learning_rate", default_adam_learning_rate)
        self._discriminator_skip_each_nth_step = self.settings.get(
            "discriminator_skip_each_nth_step",
            discriminator_skip_each_nth_step,
        )
        self._enable_selection = self.settings.get("enable_selection",
                                                   enable_selection)
        self.mixture_sigma = self.settings.get("mixture_sigma", mixture_sigma)

        self.neighbourhood = Neighbourhood.instance()

        # Tag every individual with a (possibly role-specific) learning rate
        # and a traceable id of the form "<cell>/G<i>" or "<cell>/D<i>".
        for i, individual in enumerate(self.population_gen.individuals):
            individual.learning_rate = self.settings.get(
                "default_g_adam_learning_rate",
                self._default_adam_learning_rate,
            )
            individual.id = "{}/G{}".format(self.neighbourhood.cell_number, i)
        for i, individual in enumerate(self.population_dis.individuals):
            individual.learning_rate = self.settings.get(
                "default_d_adam_learning_rate",
                self._default_adam_learning_rate,
            )
            individual.id = "{}/D{}".format(self.neighbourhood.cell_number, i)

        # Publish both populations through the shared singleton and release
        # its lock so other threads can read them.
        self.concurrent_populations = ConcurrentPopulations.instance()
        self.concurrent_populations.generator = self.population_gen
        self.concurrent_populations.discriminator = self.population_dis
        self.concurrent_populations.unlock()

        # Attach DB logging to the experiment the master created (if any).
        experiment_id = self.cc.settings["general"]["logging"].get(
            "experiment_id", None)
        self.db_logger = DbLogger(current_experiment=experiment_id)

        if "fitness" in self.settings:
            self.fitness_sample_size = self.settings["fitness"].get(
                "fitness_sample_size", fitness_sample_size)
            self.fitness_loaded = self.dataloader.load()
            self.fitness_iterator = iter(
                self.fitness_loaded)  # Create iterator for fitness loader

            # Determine how to aggregate fitness calculated among neighbourhood
            self.fitness_mode = self.settings["fitness"].get(
                "fitness_mode", fitness_mode)
            # NOTE(review): the constructor default "worst" is not in this
            # accepted list ("worse"/"best"/"average"), so the fallback value
            # can never pass validation — confirm whether "worst" or "worse"
            # is intended; the config must currently supply a valid mode.
            if self.fitness_mode not in ["worse", "best", "average"]:
                raise NotImplementedError(
                    "Invalid argument for fitness_mode: {}".format(
                        self.fitness_mode))
        else:
            # TODO: Add code for safe implementation & error handling
            raise KeyError(
                "Fitness section must be defined in configuration file")

        n_iterations = self.cc.settings["trainer"].get("n_iterations", 0)

        # Score calculation is active when the 'score' section enables it
        # (calc_mixture is only the default for the 'enabled' flag) or when
        # mixture-weight optimization is configured.
        if ("score" in self.settings and self.settings["score"].get(
                "enabled",
                calc_mixture)) or "optimize_mixture" in self.settings:
            self.score_calc = ScoreCalculatorFactory.create()
            self.score_sample_size = self.settings["score"].get(
                "sample_size", score_sample_size)
            # Seed with the worst possible value for the chosen metric
            # (+inf when lower is better, -inf otherwise).
            self.score = float(
                "inf") if self.score_calc.is_reversed else float("-inf")
            self.mixture_generator_samples_mode = self.cc.settings["trainer"][
                "mixture_generator_samples_mode"]
        else:
            self.score_sample_size = score_sample_size
            self.score_calc = None
            self.score = 0

        # NOTE(review): this section overwrites score_sample_size,
        # mixture_sigma and mixture_generator_samples_mode assigned above —
        # the optimize_mixture values (or their ES defaults) always win.
        if "optimize_mixture" in self.settings:
            self.optimize_weights_at_the_end = self.settings[
                "optimize_mixture"].get("enabled", True)
            self.score_sample_size = self.settings["optimize_mixture"].get(
                "sample_size", es_score_sample_size)
            self.es_generations = self.settings["optimize_mixture"].get(
                "es_generations", n_iterations)
            self.es_random_init = self.settings["optimize_mixture"].get(
                "es_random_init", es_random_init)
            self.mixture_sigma = self.settings["optimize_mixture"].get(
                "mixture_sigma", mixture_sigma)
            self.mixture_generator_samples_mode = self.cc.settings["trainer"][
                "mixture_generator_samples_mode"]
        else:
            self.optimize_weights_at_the_end = True
            self.score_sample_size = es_score_sample_size
            self.es_generations = n_iterations
            self.es_random_init = es_random_init
            self.mixture_sigma = mixture_sigma
            self.mixture_generator_samples_mode = self.cc.settings["trainer"][
                "mixture_generator_samples_mode"]

        # NOTE(review): the assert validates the constructor argument, but
        # the attribute below is read from the config, which bypasses this
        # check (and the assert is stripped under `python -O`).
        assert 0 <= checkpoint_period <= n_iterations, (
            "Checkpoint period paramenter (checkpoint_period) should be "
            "between 0 and the number of iterations (n_iterations).")
        self.checkpoint_period = self.cc.settings["general"].get(
            "checkpoint_period", checkpoint_period)
Example #8
0
class LipizzanerGANTrainer(EvolutionaryAlgorithmTrainer):
    """
    Distributed, asynchronous trainer for coevolutionary GANs. Uses the standard Goodfellow GAN approach.
    (Without discriminator mixture)
    """
    def __init__(
        self,
        dataloader,
        network_factory,
        population_size=10,
        tournament_size=2,
        mutation_probability=0.9,
        n_replacements=1,
        sigma=0.25,
        alpha=0.25,
        default_adam_learning_rate=0.001,
        calc_mixture=False,
        mixture_sigma=0.01,
        score_sample_size=10000,
        discriminator_skip_each_nth_step=0,
        enable_selection=True,
        fitness_sample_size=10000,
        calculate_net_weights_dist=False,
        fitness_mode="worst",
        es_score_sample_size=10000,
        es_random_init=False,
        checkpoint_period=0,
    ):
        """Configure the trainer from constructor defaults and the config file.

        Keyword arguments act only as fallbacks: each value is first looked
        up in the trainer's configuration section (``self.settings``) or the
        global ``ConfigurationContainer``. Raises ``KeyError`` when the
        configuration has no ``fitness`` section, and ``NotImplementedError``
        when the configured fitness mode is unknown.
        """

        super().__init__(
            dataloader,
            network_factory,
            population_size,
            tournament_size,
            mutation_probability,
            n_replacements,
            sigma,
            alpha,
        )

        # Index of the batch currently being processed within an iteration.
        self.batch_number = 0
        self.cc = ConfigurationContainer.instance()

        # Configuration values override the constructor defaults when present.
        self._default_adam_learning_rate = self.settings.get(
            "default_adam_learning_rate", default_adam_learning_rate)
        self._discriminator_skip_each_nth_step = self.settings.get(
            "discriminator_skip_each_nth_step",
            discriminator_skip_each_nth_step,
        )
        self._enable_selection = self.settings.get("enable_selection",
                                                   enable_selection)
        self.mixture_sigma = self.settings.get("mixture_sigma", mixture_sigma)

        self.neighbourhood = Neighbourhood.instance()

        # Tag every individual with a (possibly role-specific) learning rate
        # and a traceable id of the form "<cell>/G<i>" or "<cell>/D<i>".
        for i, individual in enumerate(self.population_gen.individuals):
            individual.learning_rate = self.settings.get(
                "default_g_adam_learning_rate",
                self._default_adam_learning_rate,
            )
            individual.id = "{}/G{}".format(self.neighbourhood.cell_number, i)
        for i, individual in enumerate(self.population_dis.individuals):
            individual.learning_rate = self.settings.get(
                "default_d_adam_learning_rate",
                self._default_adam_learning_rate,
            )
            individual.id = "{}/D{}".format(self.neighbourhood.cell_number, i)

        # Publish both populations through the shared singleton and release
        # its lock so other threads can read them.
        self.concurrent_populations = ConcurrentPopulations.instance()
        self.concurrent_populations.generator = self.population_gen
        self.concurrent_populations.discriminator = self.population_dis
        self.concurrent_populations.unlock()

        # Attach DB logging to the experiment the master created (if any).
        experiment_id = self.cc.settings["general"]["logging"].get(
            "experiment_id", None)
        self.db_logger = DbLogger(current_experiment=experiment_id)

        if "fitness" in self.settings:
            self.fitness_sample_size = self.settings["fitness"].get(
                "fitness_sample_size", fitness_sample_size)
            self.fitness_loaded = self.dataloader.load()
            self.fitness_iterator = iter(
                self.fitness_loaded)  # Create iterator for fitness loader

            # Determine how to aggregate fitness calculated among neighbourhood
            self.fitness_mode = self.settings["fitness"].get(
                "fitness_mode", fitness_mode)
            # NOTE(review): the constructor default "worst" is not in this
            # accepted list ("worse"/"best"/"average"), so the fallback value
            # can never pass validation — confirm whether "worst" or "worse"
            # is intended; the config must currently supply a valid mode.
            if self.fitness_mode not in ["worse", "best", "average"]:
                raise NotImplementedError(
                    "Invalid argument for fitness_mode: {}".format(
                        self.fitness_mode))
        else:
            # TODO: Add code for safe implementation & error handling
            raise KeyError(
                "Fitness section must be defined in configuration file")

        n_iterations = self.cc.settings["trainer"].get("n_iterations", 0)

        # Score calculation is active when the 'score' section enables it
        # (calc_mixture is only the default for the 'enabled' flag) or when
        # mixture-weight optimization is configured.
        if ("score" in self.settings and self.settings["score"].get(
                "enabled",
                calc_mixture)) or "optimize_mixture" in self.settings:
            self.score_calc = ScoreCalculatorFactory.create()
            self.score_sample_size = self.settings["score"].get(
                "sample_size", score_sample_size)
            # Seed with the worst possible value for the chosen metric
            # (+inf when lower is better, -inf otherwise).
            self.score = float(
                "inf") if self.score_calc.is_reversed else float("-inf")
            self.mixture_generator_samples_mode = self.cc.settings["trainer"][
                "mixture_generator_samples_mode"]
        else:
            self.score_sample_size = score_sample_size
            self.score_calc = None
            self.score = 0

        # NOTE(review): this section overwrites score_sample_size,
        # mixture_sigma and mixture_generator_samples_mode assigned above —
        # the optimize_mixture values (or their ES defaults) always win.
        if "optimize_mixture" in self.settings:
            self.optimize_weights_at_the_end = self.settings[
                "optimize_mixture"].get("enabled", True)
            self.score_sample_size = self.settings["optimize_mixture"].get(
                "sample_size", es_score_sample_size)
            self.es_generations = self.settings["optimize_mixture"].get(
                "es_generations", n_iterations)
            self.es_random_init = self.settings["optimize_mixture"].get(
                "es_random_init", es_random_init)
            self.mixture_sigma = self.settings["optimize_mixture"].get(
                "mixture_sigma", mixture_sigma)
            self.mixture_generator_samples_mode = self.cc.settings["trainer"][
                "mixture_generator_samples_mode"]
        else:
            self.optimize_weights_at_the_end = True
            self.score_sample_size = es_score_sample_size
            self.es_generations = n_iterations
            self.es_random_init = es_random_init
            self.mixture_sigma = mixture_sigma
            self.mixture_generator_samples_mode = self.cc.settings["trainer"][
                "mixture_generator_samples_mode"]

        # NOTE(review): the assert validates the constructor argument, but
        # the attribute below is read from the config, which bypasses this
        # check (and the assert is stripped under `python -O`).
        assert 0 <= checkpoint_period <= n_iterations, (
            "Checkpoint period paramenter (checkpoint_period) should be "
            "between 0 and the number of iterations (n_iterations).")
        self.checkpoint_period = self.cc.settings["general"].get(
            "checkpoint_period", checkpoint_period)

    def train(self, n_iterations, stop_event=None):
        loaded = self.dataloader.load()

        alpha = None  #self.neighbourhood.alpha
        beta = None  #self.neighbourhood.beta
        if alpha is not None:
            self._logger.info(f"Alpha is {alpha} and Beta is {beta}")
        else:
            self._logger.debug("Alpha and Beta are not set")

        for iteration in range(n_iterations):
            self._logger.debug("Iteration {} started".format(iteration + 1))
            start_time = time()

            all_generators = self.neighbourhood.all_generators
            all_discriminators = self.neighbourhood.all_discriminators
            local_generators = self.neighbourhood.local_generators
            local_discriminators = self.neighbourhood.local_discriminators

            # Log the name of individuals in entire neighborhood and local individuals for every iteration
            # (to help tracing because individuals from adjacent cells might be from different iterations)
            self._logger.info(
                "Neighborhood located in possition {} of the grid".format(
                    self.neighbourhood.grid_position))
            self._logger.info(
                "Generators in current neighborhood are {}".format([
                    individual.name
                    for individual in all_generators.individuals
                ]))
            self._logger.info(
                "Discriminators in current neighborhood are {}".format([
                    individual.name
                    for individual in all_discriminators.individuals
                ]))
            self._logger.info(
                "Local generators in current neighborhood are {}".format([
                    individual.name
                    for individual in local_generators.individuals
                ]))
            self._logger.info(
                "Local discriminators in current neighborhood are {}".format([
                    individual.name
                    for individual in local_discriminators.individuals
                ]))

            self._logger.info(
                "L2 distance between all generators weights: {}".format(
                    all_generators.net_weights_dist))
            self._logger.info(
                "L2 distance between all discriminators weights: {}".format(
                    all_discriminators.net_weights_dist))

            new_populations = {}

            # Create random dataset to evaluate fitness in each iterations
            (
                fitness_samples,
                fitness_labels,
            ) = self.generate_random_fitness_samples(self.fitness_sample_size)
            if (self.cc.settings["dataloader"]["dataset_name"] == "celeba" or
                    self.cc.settings["dataloader"]["dataset_name"] == "cifar"
                    or self.cc.settings["network"]["name"]
                    == "ssgan_convolutional_mnist"):
                fitness_samples = to_pytorch_variable(fitness_samples)
                fitness_labels = to_pytorch_variable(fitness_labels)
            elif self.cc.settings["dataloader"][
                    "dataset_name"] == "network_traffic":
                fitness_samples = to_pytorch_variable(
                    generate_random_sequences(self.fitness_sample_size))
            else:
                fitness_samples = to_pytorch_variable(
                    fitness_samples.view(self.fitness_sample_size, -1))
                fitness_labels = to_pytorch_variable(
                    fitness_labels.view(self.fitness_sample_size, -1))

            fitness_labels = torch.squeeze(fitness_labels)

            # Fitness evaluation
            self._logger.debug("Evaluating fitness")
            self.evaluate_fitness(
                all_generators,
                all_discriminators,
                fitness_samples,
                self.fitness_mode,
            )
            self.evaluate_fitness(
                all_discriminators,
                all_generators,
                fitness_samples,
                self.fitness_mode,
                labels=fitness_labels,
                logger=self._logger,
                alpha=alpha,
                beta=beta,
                iter=iteration,
                log_class_distribution=True,
            )
            self._logger.debug("Finished evaluating fitness")

            # Tournament selection
            if self._enable_selection:
                self._logger.debug("Started tournament selection")
                new_populations[TYPE_GENERATOR] = self.tournament_selection(
                    all_generators, TYPE_GENERATOR, is_logging=True)
                new_populations[
                    TYPE_DISCRIMINATOR] = self.tournament_selection(
                        all_discriminators,
                        TYPE_DISCRIMINATOR,
                        is_logging=True)
                self._logger.debug("Finished tournament selection")

            self.batch_number = 0
            data_iterator = iter(loaded)
            while self.batch_number < len(loaded):
                if self.cc.settings["dataloader"][
                        "dataset_name"] == "network_traffic":
                    input_data = to_pytorch_variable(next(data_iterator))
                    batch_size = input_data.size(0)
                else:
                    input_data, labels = next(data_iterator)
                    batch_size = input_data.size(0)
                    input_data = to_pytorch_variable(
                        self.dataloader.transpose_data(input_data))
                    labels = to_pytorch_variable(
                        self.dataloader.transpose_data(labels))
                    labels = torch.squeeze(labels)

                # Quit if requested
                if stop_event is not None and stop_event.is_set():
                    self._logger.warning("External stop requested.")
                    return self.result()

                attackers = new_populations[
                    TYPE_GENERATOR] if self._enable_selection else local_generators
                defenders = new_populations[
                    TYPE_DISCRIMINATOR] if self._enable_selection else all_discriminators
                input_data = self.step(
                    local_generators,
                    attackers,
                    defenders,
                    input_data,
                    self.batch_number,
                    loaded,
                    data_iterator,
                    iter=iteration,
                )

                if (self._discriminator_skip_each_nth_step == 0
                        or self.batch_number %
                    (self._discriminator_skip_each_nth_step + 1) == 0):
                    self._logger.debug("Skipping discriminator step")

                    attackers = new_populations[
                        TYPE_DISCRIMINATOR] if self._enable_selection else local_discriminators
                    defenders = new_populations[
                        TYPE_GENERATOR] if self._enable_selection else all_generators
                    input_data = self.step(
                        local_discriminators,
                        attackers,
                        defenders,
                        input_data,
                        self.batch_number,
                        loaded,
                        data_iterator,
                        labels=labels,
                        alpha=alpha,
                        beta=beta,
                        iter=iteration,
                    )

                self._logger.info("Iteration {}, Batch {}/{}".format(
                    iteration + 1, self.batch_number, len(loaded)))

                # If n_batches is set to 0, all batches will be used
                if self.is_last_batch(self.batch_number):
                    break

                self.batch_number += 1

            # Perform selection first before mutation of mixture_weights
            # Replace the worst with the best new
            if self._enable_selection:
                # Evaluate fitness of new_populations against neighborhood
                self.evaluate_fitness(
                    new_populations[TYPE_GENERATOR],
                    all_discriminators,
                    fitness_samples,
                    self.fitness_mode,
                )
                self.evaluate_fitness(
                    new_populations[TYPE_DISCRIMINATOR],
                    all_generators,
                    fitness_samples,
                    self.fitness_mode,
                    labels=fitness_labels,
                    alpha=alpha,
                    beta=beta,
                    iter=iteration,
                )
                self.concurrent_populations.lock()
                local_generators.replacement(
                    new_populations[TYPE_GENERATOR],
                    self._n_replacements,
                    is_logging=True,
                )
                local_generators.sort_population(is_logging=True)
                local_discriminators.replacement(
                    new_populations[TYPE_DISCRIMINATOR],
                    self._n_replacements,
                    is_logging=True,
                )
                local_discriminators.sort_population(is_logging=True)
                self.concurrent_populations.unlock()

                # Update individuals' iteration and id after replacement and logging to ease tracing
                for i, individual in enumerate(local_generators.individuals):
                    individual.id = "{}/G{}".format(
                        self.neighbourhood.cell_number, i)
                    individual.iteration = iteration + 1
                for i, individual in enumerate(
                        local_discriminators.individuals):
                    individual.id = "{}/D{}".format(
                        self.neighbourhood.cell_number, i)
                    individual.iteration = iteration + 1
            else:
                # Re-evaluate fitness of local_generators and local_discriminators against neighborhood
                self.evaluate_fitness(
                    local_generators,
                    all_discriminators,
                    fitness_samples,
                    self.fitness_mode,
                )
                self.evaluate_fitness(
                    local_discriminators,
                    all_generators,
                    fitness_samples,
                    self.fitness_mode,
                    labels=fitness_labels,
                    alpha=alpha,
                    beta=beta,
                    iter=iteration,
                )

            self.compute_mixture_generative_score(iteration)

            stop_time = time()

            path_real_images, path_fake_images = self.log_results(
                batch_size,
                iteration,
                input_data,
                loaded,
                lr_gen=self.concurrent_populations.generator.individuals[0].
                learning_rate,
                lr_dis=self.concurrent_populations.discriminator.
                individuals[0].learning_rate,
                score=self.score,
                mixture_gen=self.neighbourhood.mixture_weights_generators,
                mixture_dis=None,
            )

            if self.db_logger.is_enabled:
                self.db_logger.log_results(
                    iteration,
                    self.neighbourhood,
                    self.concurrent_populations,
                    self.score,
                    stop_time - start_time,
                    path_real_images,
                    path_fake_images,
                )

            if self.checkpoint_period > 0 and (
                    iteration + 1) % self.checkpoint_period == 0:
                self.save_checkpoint(
                    all_generators.individuals,
                    all_discriminators.individuals,
                    self.neighbourhood.cell_number,
                    self.neighbourhood.grid_position,
                )

        # Evaluate the discriminators when addressing Semi-supervised Learning
        if "ssgan" in self.cc.settings["network"]["name"]:
            discriminator = self.concurrent_populations.discriminator.individuals[
                0].genome
            batch_size = self.dataloader.batch_size
            dataloader_loaded = self.dataloader.load(train=True)
            self.test_accuracy_discriminators(discriminator,
                                              dataloader_loaded,
                                              train=True)

            self.dataloader.batch_size = 100
            dataloader_loaded = self.dataloader.load(train=False)
            self.test_accuracy_discriminators(discriminator,
                                              dataloader_loaded,
                                              train=False)

            discriminators = [
                individual.genome for individual in
                self.neighbourhood.all_discriminators.individuals
            ]

            for model in discriminators:
                dataloader_loaded = self.dataloader.load(train=False)
                self.test_accuracy_discriminators(model,
                                                  dataloader_loaded,
                                                  train=False)

            dataloader_loaded = self.dataloader.load(train=False)
            self.test_majority_voting_discriminators(discriminators,
                                                     dataloader_loaded,
                                                     train=False)
            self.dataloader.batch_size = batch_size

        if self.optimize_weights_at_the_end:
            self.optimize_generator_mixture_weights()

            path_real_images, path_fake_images = self.log_results(
                batch_size,
                iteration + 1,
                input_data,
                loaded,
                lr_gen=self.concurrent_populations.generator.individuals[0].
                learning_rate,
                lr_dis=self.concurrent_populations.discriminator.
                individuals[0].learning_rate,
                score=self.score,
                mixture_gen=self.neighbourhood.mixture_weights_generators,
                mixture_dis=self.neighbourhood.mixture_weights_discriminators,
            )

            if self.db_logger.is_enabled:
                self.db_logger.log_results(
                    iteration + 1,
                    self.neighbourhood,
                    self.concurrent_populations,
                    self.score,
                    stop_time - start_time,
                    path_real_images,
                    path_fake_images,
                )

        return self.result()

    def test_majority_voting_discriminators(self,
                                            models,
                                            test_loader,
                                            train=False):
        """
        Evaluate ensemble accuracy of several discriminators via majority voting.

        Each model votes with its argmax class over 11 outputs; the most
        frequent label per sample wins (ties resolved towards the lowest
        label, since bincount().argmax() returns the first maximum).

        :param models: Iterable of discriminators exposing ``net`` and
            ``classification_layer``.
        :param test_loader: Loader yielding ``(data, target)`` batches;
            ``test_loader.dataset`` must support ``len()``.
        :param train: Affects only the "Train"/"Test" log prefix.
        """
        correct = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        train_or_test = "Train" if train else "Test"
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                # Reshape inputs to the layout the configured network expects
                # (flattened to one elif chain, consistent with
                # test_accuracy_discriminators)
                if self.cc.settings["network"]["name"] == "ssgan_perceptron":
                    data = data.view(-1, 784)
                elif self.cc.settings["network"][
                        "name"] == "ssgan_conv_mnist_28x28":
                    data = data.view(-1, 1, 28, 28)
                elif self.cc.settings["network"]["name"] == "ssgan_svhn":
                    data = data.view(-1, 3, 32, 32)
                elif self.cc.settings["dataloader"]["dataset_name"] == "cifar":
                    data = data.view(-1, 3, 64, 64)
                else:
                    data = data.view(-1, 1, 64, 64)
                pred_accumulator = []
                for model in models:
                    output = model.classification_layer(model.net(data))
                    output = output.view(-1, 11)
                    pred = output.argmax(dim=1, keepdim=True)
                    pred_accumulator.append(pred.view(-1))
                # Stack so each row holds all models' votes for one sample;
                # torch.stack avoids the slow tensor-from-list-of-tensors path
                label_votes = to_pytorch_variable(
                    torch.stack(pred_accumulator, dim=1))
                prediction = to_pytorch_variable(
                    torch.tensor([
                        labels.bincount(minlength=11).argmax()
                        for labels in label_votes
                    ]))
                correct += prediction.eq(
                    target.view_as(prediction)).sum().item()

        num_samples = len(test_loader.dataset)
        accuracy = 100.0 * float(correct / num_samples)
        self._logger.info(
            f"Majority Voting {train_or_test} Accuracy: {correct}/{num_samples} ({accuracy}%)"
        )

    def test_accuracy_discriminators(self, model, test_loader, train=False):
        """
        Measure the classification accuracy of a single discriminator.

        :param model: Discriminator exposing ``net`` and
            ``classification_layer`` (11 output classes).
        :param test_loader: Loader yielding ``(data, target)`` batches;
            ``test_loader.dataset`` must support ``len()``.
        :param train: Affects only the "Train"/"Test" log prefix.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        phase = "Train" if train else "Test"
        num_correct = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                # Reshape inputs to the layout the configured network expects
                network_name = self.cc.settings["network"]["name"]
                if network_name == "ssgan_perceptron":
                    inputs = inputs.view(-1, 784)
                elif network_name == "ssgan_conv_mnist_28x28":
                    inputs = inputs.view(-1, 1, 28, 28)
                elif network_name == "ssgan_svhn":
                    inputs = inputs.view(-1, 3, 32, 32)
                elif self.cc.settings["dataloader"]["dataset_name"] == "cifar":
                    inputs = inputs.view(-1, 3, 64, 64)
                else:
                    inputs = inputs.view(-1, 1, 64, 64)
                logits = model.classification_layer(model.net(inputs))
                logits = logits.view(-1, 11)
                predictions = logits.argmax(dim=1, keepdim=True)
                num_correct += predictions.eq(
                    targets.view_as(predictions)).sum().item()

        total = len(test_loader.dataset)
        accuracy = 100.0 * float(num_correct / total)
        self._logger.info(
            f"{phase} Accuracy: {num_correct}/{total} ({accuracy}%)")

    def optimize_generator_mixture_weights(self):
        """
        Optimize the generator mixture weights with an evolution-strategy loop:
        repeatedly perturb the weights with Gaussian noise and keep a mutation
        whenever it improves the generative score.

        Updates ``self.neighbourhood.mixture_weights_generators`` and stores
        the best score in ``self.score``.
        """
        # Nothing to optimize without a score calculator. The original code
        # only guarded the in-loop evaluations, so the unconditional initial
        # calculate() below would have crashed when score_calc was None.
        if self.score_calc is None:
            return

        generators = self.neighbourhood.best_generators
        weights_generators = self.neighbourhood.mixture_weights_generators

        # Not necessary for single-cell grids, as mixture must always be [1]
        if self.neighbourhood.grid_size == 1:
            return

        # Create random vector from latent space; reused for every evaluation
        # so score differences come from the weights alone
        z_noise = noise(self.score_sample_size,
                        generators.individuals[0].genome.data_size)

        # Include option to start from random weights
        if self.es_random_init:
            aux_weights = np.random.rand(len(weights_generators))
            aux_weights /= np.sum(aux_weights)
            weights_generators = OrderedDict(
                zip(weights_generators.keys(), aux_weights))
            self.neighbourhood.mixture_weights_generators = weights_generators

        dataset = MixedGeneratorDataset(
            generators,
            weights_generators,
            self.score_sample_size,
            self.mixture_generator_samples_mode,
            z_noise,
        )

        self.score = self.score_calc.calculate(dataset)[0]
        init_score = self.score

        self._logger.info(
            "Mixture weight mutation - Starting mixture weights optimization ..."
        )
        self._logger.info("Init score: {}\tInit weights: {}.".format(
            init_score, weights_generators))

        for g in range(self.es_generations):

            # Mutate mixture weights with zero-mean Gaussian noise
            z = np.random.normal(loc=0,
                                 scale=self.mixture_sigma,
                                 size=len(weights_generators))
            transformed = np.asarray(
                [value for _, value in weights_generators.items()])
            transformed += z

            # Don't allow negative values, normalize to sum of 1.0
            transformed = np.clip(transformed, 0, None)
            transformed /= np.sum(transformed)
            new_mixture_weights = OrderedDict(
                zip(weights_generators.keys(), transformed))

            # TODO: Testing the idea of not generating the images again
            dataset = MixedGeneratorDataset(
                generators,
                new_mixture_weights,
                self.score_sample_size,
                self.mixture_generator_samples_mode,
                z_noise,
                epoch=g)

            score_after_mutation = self.score_calc.calculate(dataset)[0]
            self._logger.info(
                "Mixture weight mutation - Generation: {} \tScore of new weights: {}\tNew weights: {}."
                .format(g, score_after_mutation, new_mixture_weights))

            # For fid the lower the better, for inception_score, the higher the better
            improved = (score_after_mutation < self.score
                        if self.score_calc.is_reversed else
                        score_after_mutation > self.score)
            if improved:
                weights_generators = new_mixture_weights
                self.score = score_after_mutation
                self._logger.info(
                    "Mixture weight mutation - Generation: {} \tNew score: {}\tWeights changed to: {}."
                    .format(g, self.score, weights_generators))
        self.neighbourhood.mixture_weights_generators = weights_generators

        self._logger.info(
            "Mixture weight mutation - Score before mixture weight optimzation: {}\tScore after mixture weight optimzation: {}."
            .format(init_score, self.score))

    def step(
        self,
        original,
        attacker,
        defender,
        input_data,
        i,
        loaded,
        data_iterator,
        labels=None,
        alpha=None,
        beta=None,
        iter=None,
    ):
        """
        Run one training step: mutate the attackers' hyperparameters, then
        train every attacker genome against the defender population.

        Returns whatever ``update_genomes`` returns (the input batch).
        Note: ``original`` and ``i`` are accepted for interface compatibility
        but are not used here.
        """
        self.mutate_hyperparams(attacker)
        result = self.update_genomes(attacker,
                                     defender,
                                     input_data,
                                     loaded,
                                     data_iterator,
                                     labels=labels,
                                     alpha=alpha,
                                     beta=beta,
                                     iter=iter)
        return result

    def is_last_batch(self, i):
        """Return True when batch index ``i`` is the final one; n_batches == 0
        means "use every batch", so no batch is ever considered last."""
        n_batches = self.dataloader.n_batches
        return n_batches != 0 and i == n_batches - 1

    def result(self):
        """Return ((genome, fitness) of the best generator,
        (genome, fitness) of the best discriminator) — index 0 of each
        locally sorted population."""
        best_generator = self.concurrent_populations.generator.individuals[0]
        best_discriminator = \
            self.concurrent_populations.discriminator.individuals[0]
        return (
            (best_generator.genome, best_generator.fitness),
            (best_discriminator.genome, best_discriminator.fitness),
        )

    def mutate_hyperparams(self, population):
        """
        Gaussian-perturb each individual's Adam learning rate.

        Deltas are drawn around a slightly negative mean, zeroed out with
        probability ``1 - mutation_probability``, scaled by ``alpha``, and
        the resulting learning rate is clamped at 0.
        """
        mean = -(self._default_adam_learning_rate / 10)
        deltas = np.random.normal(loc=mean,
                                  scale=self._default_adam_learning_rate,
                                  size=len(population.individuals))
        # Individuals not selected for mutation get a zero delta
        skip_mask = np.random.rand(*deltas.shape) < 1 - self._mutation_probability
        deltas[skip_mask] = 0
        for delta, individual in zip(deltas, population.individuals):
            individual.learning_rate = max(
                0, individual.learning_rate + delta * self._alpha)

    def update_genomes(
        self,
        population_attacker,
        population_defender,
        input_var,
        loaded,
        data_iterator,
        labels=None,
        alpha=None,
        beta=None,
        iter=None,
    ):
        """
        Train every attacker genome for one Adam step against a randomly
        picked defender genome, persisting each attacker's optimizer state
        on its individual between calls.

        Returns ``input_var`` unchanged. ``loaded`` and ``data_iterator``
        are accepted for interface compatibility but not used here.
        """
        # TODO Currently picking random opponent, introduce parameter for this
        defender = random.choice(population_defender.individuals).genome

        for attacker_individual in population_attacker.individuals:
            attacker = attacker_individual.genome
            optimizer = torch.optim.Adam(
                attacker.net.parameters(),
                lr=attacker_individual.learning_rate,
                betas=(0.5, 0.999),
            )

            # Resume from this individual's previous optimizer state, if any
            previous_state = attacker_individual.optimizer_state
            if previous_state is not None:
                optimizer.load_state_dict(previous_state)

            if labels is None:
                loss = attacker.compute_loss_against(defender, input_var)[0]
            else:
                loss = attacker.compute_loss_against(
                    defender,
                    input_var,
                    labels=labels,
                    alpha=alpha,
                    beta=beta,
                    iter=iter,
                )[0]

            attacker.net.zero_grad()
            defender.net.zero_grad()
            loss.backward()
            optimizer.step()

            # Persist state so momentum/variance survive to the next step
            attacker_individual.optimizer_state = optimizer.state_dict()

        return input_var

    @staticmethod
    def evaluate_fitness(
        population_attacker,
        population_defender,
        input_var,
        fitness_mode,
        labels=None,
        logger=None,
        alpha=None,
        beta=None,
        iter=None,
        log_class_distribution=False,
    ):
        # Single direction only: Evaluate fitness of attacker based on defender
        # TODO: Simplify and refactor this function
        def compare_fitness(curr_fitness, fitness, mode):
            # The initial fitness value is -inf before evaluation started, so we
            # directly adopt the curr_fitness when -inf is encountered
            if mode == "best":
                if curr_fitness < fitness or fitness == float("-inf"):
                    return curr_fitness
            elif mode == "worse":
                if curr_fitness > fitness or fitness == float("-inf"):
                    return curr_fitness
            elif mode == "average":
                if fitness == float("-inf"):
                    return curr_fitness
                else:
                    return fitness + curr_fitness

            return fitness

        for individual_attacker in population_attacker.individuals:
            individual_attacker.fitness = float(
                "-inf"
            )  # Reinitalize before evaluation started (Needed for average fitness)
            for individual_defender in population_defender.individuals:
                if labels is None:
                    fitness_attacker = float(
                        individual_attacker.genome.compute_loss_against(
                            individual_defender.genome, input_var)[0])
                else:
                    fitness_attacker = float(
                        individual_attacker.genome.compute_loss_against(
                            individual_defender.genome,
                            input_var,
                            labels=labels,
                            alpha=alpha,
                            beta=beta,
                            iter=iter,
                        )[0])

                individual_attacker.fitness = compare_fitness(
                    fitness_attacker, individual_attacker.fitness,
                    fitness_mode)

            if fitness_mode == "average":
                individual_attacker.fitness /= len(
                    population_defender.individuals)

        if labels is not None and logger is not None and population_defender.individuals[
                0].name == "SemiSupervised":
            gen = None
            for g in population_defender.individuals:
                if g.is_local:
                    gen = g
                    break
            dis = None
            for d in population_attacker.individuals:
                if d.is_local:
                    dis = d
                    break
            generator = gen.genome
            discriminator = dis.genome
            discriminator_output = discriminator.compute_loss_against(
                generator,
                input_var,
                labels=labels,
                alpha=alpha,
                beta=beta,
                iter=iter,
                log_class_distribution=log_class_distribution,
            )
            accuracy = discriminator_output[2]
            if discriminator.name == "SemiSupervisedDiscriminator" and accuracy is not None:
                logger.info(
                    f"Iteration {iter},  Label Prediction Accuracy: {100 * accuracy}% "
                )

    def compute_mixture_generative_score(self, epoch=None):
        """
        Compute the generative score of the current best-generator mixture
        and store it in ``self.score``.

        :param epoch: Optional iteration number, used only in the log message.
        """
        self._logger.info("Calculating score {}.".format(epoch))
        best_generators = self.neighbourhood.best_generators

        # NOTE(review): the previous `if True or grid_size == 1` guard was
        # always true, so the single-cell short-circuit was dead code; the
        # always-true condition has been removed without changing behavior —
        # the score is computed for every grid size whenever a calculator
        # is configured.
        if self.score_calc is not None:
            dataset = MixedGeneratorDataset(
                best_generators,
                self.neighbourhood.mixture_weights_generators,
                self.score_sample_size,
                self.cc.settings["trainer"]
                ["mixture_generator_samples_mode"],
            )
            self.score = self.score_calc.calculate(dataset)[0]

    def generate_random_fitness_samples(self, fitness_sample_size):
        """
        Generate random samples for fitness evaluation according to fitness_sample_size

        A bit of a hack: reuses the batch-sized iterator to accumulate
        fitness_sample_size samples, restarting it whenever it is exhausted.
        TODO Implement another iterator (and dataloader) of fitness_sample_size

        :param fitness_sample_size: Number of samples to collect.
        :return: Tuple of (data, labels) tensors, each of length
            fitness_sample_size.
        """
        def get_next_batch(iterator, loaded):
            # Handle if the end of iterator is reached: restart with a fresh
            # iterator over the same loaded data
            try:
                batch_data, batch_labels = next(iterator)
            except StopIteration:
                iterator = iter(loaded)
                batch_data, batch_labels = next(iterator)
            return batch_data, batch_labels, iterator

        sampled_data, sampled_labels, self.fitness_iterator = get_next_batch(
            self.fitness_iterator, self.fitness_loaded)
        batch_size = sampled_data.size(0)

        if fitness_sample_size < batch_size:
            # A single batch already contains enough samples
            return (
                sampled_data[:fitness_sample_size],
                sampled_labels[:fitness_sample_size],
            )

        fitness_sample_size -= batch_size
        while fitness_sample_size >= batch_size:
            # Keep concatenating full batches while at least one is needed
            curr_data, curr_labels, self.fitness_iterator = get_next_batch(
                self.fitness_iterator, self.fitness_loaded)
            sampled_data = torch.cat((sampled_data, curr_data), 0)
            sampled_labels = torch.cat((sampled_labels, curr_labels), 0)
            fitness_sample_size -= batch_size

        if fitness_sample_size > 0:
            # Concatenate the remaining partial batch
            curr_data, curr_labels, self.fitness_iterator = get_next_batch(
                self.fitness_iterator, self.fitness_loaded)
            sampled_data = torch.cat(
                (sampled_data, curr_data[:fitness_sample_size]), 0)
            sampled_labels = torch.cat(
                (sampled_labels, curr_labels[:fitness_sample_size]), 0)

        return sampled_data, sampled_labels
Example #9
0
    def __init__(self,
                 dataloader,
                 network_factory,
                 population_size=10,
                 tournament_size=2,
                 mutation_probability=0.9,
                 n_replacements=1,
                 sigma=0.25,
                 alpha=0.25,
                 default_adam_learning_rate=0.001,
                 calc_mixture=False,
                 mixture_sigma=0.01,
                 score_sample_size=10000,
                 discriminator_skip_each_nth_step=0,
                 enable_selection=True,
                 fitness_sample_size=10000,
                 calculate_net_weights_dist=False,
                 fitness_mode='worse'):
        """
        Initialize the coevolutionary GAN trainer.

        Configuration-file values (from the 'trainer' settings) override the
        keyword defaults. A 'fitness' settings section is mandatory.

        Note: ``fitness_mode`` previously defaulted to 'worst', which the
        validation below rejects (only 'worse', 'best' and 'average' are
        accepted) and which therefore raised whenever the settings omitted
        'fitness_mode'; the default is now the valid 'worse'.
        ``calculate_net_weights_dist`` is accepted for interface
        compatibility but is not stored by this constructor.
        """
        super().__init__(dataloader, network_factory, population_size,
                         tournament_size, mutation_probability, n_replacements,
                         sigma, alpha)

        self.batch_number = 0
        self.cc = ConfigurationContainer.instance()

        # Settings values take precedence over constructor defaults
        self._default_adam_learning_rate = self.settings.get(
            'default_adam_learning_rate', default_adam_learning_rate)
        self._discriminator_skip_each_nth_step = self.settings.get(
            'discriminator_skip_each_nth_step',
            discriminator_skip_each_nth_step)
        self._enable_selection = self.settings.get('enable_selection',
                                                   enable_selection)
        self.mixture_sigma = self.settings.get('mixture_sigma', mixture_sigma)

        self.neighbourhood = Grid.instance()

        # Tag each individual with its cell-local id and the default LR
        for i, individual in enumerate(self.population_gen.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/G{}'.format(self.neighbourhood.cell_number, i)
        for i, individual in enumerate(self.population_dis.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/D{}'.format(self.neighbourhood.cell_number, i)

        # TRACE: a lock is acquired to set the population (multithread lock);
        # NOTE(review): whether the singleton itself is thread safe still
        # needs to be verified
        self.concurrent_populations = ConcurrentPopulations.instance()
        self.concurrent_populations.generator = self.population_gen
        self.concurrent_populations.discriminator = self.population_dis
        if self.concurrent_populations.locked():
            self.concurrent_populations.unlock()

        experiment_id = self.cc.settings['general']['logging'].get(
            'experiment_id', None)
        self.db_logger = DbLogger(current_experiment=experiment_id)

        # Score calculation is optional; score starts at the worst possible
        # value for the chosen metric direction
        if 'score' in self.settings and self.settings['score'].get(
                'enabled', calc_mixture):
            self.score_calc = ScoreCalculatorFactory.create()
            self.score_sample_size = self.settings['score'].get(
                'sample_size', score_sample_size)
            self.score = float(
                'inf') if self.score_calc.is_reversed else float('-inf')
        else:
            self.score_calc = None
            self.score = 0

        if 'fitness' in self.settings:
            self.fitness_sample_size = self.settings['fitness'].get(
                'fitness_sample_size', fitness_sample_size)
            self.fitness_loaded = self.dataloader.load()
            self.fitness_iterator = iter(
                self.fitness_loaded)  # Create iterator for fitness loader

            # Determine how to aggregate fitness calculated among neighbourhood
            self.fitness_mode = self.settings['fitness'].get(
                'fitness_mode', fitness_mode)
            if self.fitness_mode not in ['worse', 'best', 'average']:
                raise NotImplementedError(
                    "Invalid argument for fitness_mode: {}".format(
                        self.fitness_mode))
        else:
            # TODO: Add code for safe implementation & error handling
            raise KeyError(
                "Fitness section must be defined in configuration file")
Example #10
0
class LipizzanerGANTrainer(EvolutionaryAlgorithmTrainer):
    """
    Distributed, asynchronous trainer for coevolutionary GANs. Uses the standard Goodfellow GAN approach.
    (Without discriminator mixture)
    """
    def __init__(self,
                 dataloader,
                 network_factory,
                 population_size=10,
                 tournament_size=2,
                 mutation_probability=0.9,
                 n_replacements=1,
                 sigma=0.25,
                 alpha=0.25,
                 default_adam_learning_rate=0.001,
                 calc_mixture=False,
                 mixture_sigma=0.01,
                 score_sample_size=10000,
                 discriminator_skip_each_nth_step=0,
                 enable_selection=True,
                 fitness_sample_size=10000,
                 calculate_net_weights_dist=False,
                 fitness_mode='worse'):
        """
        Initialize the trainer from constructor arguments, with every value
        overridable by the corresponding key in the trainer settings section.

        Note: the ``fitness_mode`` default was previously ``'worst'``, which is
        rejected by the validation below (only 'worse', 'best' and 'average'
        are accepted); it now defaults to the valid ``'worse'``.

        :raises KeyError: if the configuration has no 'fitness' section.
        :raises NotImplementedError: if the configured fitness_mode is invalid.
        """
        super().__init__(dataloader, network_factory, population_size,
                         tournament_size, mutation_probability, n_replacements,
                         sigma, alpha)

        self.batch_number = 0
        self.cc = ConfigurationContainer.instance()

        # Settings values take precedence over the constructor defaults.
        self._default_adam_learning_rate = self.settings.get(
            'default_adam_learning_rate', default_adam_learning_rate)
        self._discriminator_skip_each_nth_step = self.settings.get(
            'discriminator_skip_each_nth_step',
            discriminator_skip_each_nth_step)
        self._enable_selection = self.settings.get('enable_selection',
                                                   enable_selection)
        self.mixture_sigma = self.settings.get('mixture_sigma', mixture_sigma)

        self.neighbourhood = Grid.instance()

        # Tag every individual with the cell-local id '<cell>/G<i>' or
        # '<cell>/D<i>' and the shared initial learning rate.
        for i, individual in enumerate(self.population_gen.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/G{}'.format(self.neighbourhood.cell_number, i)
        for i, individual in enumerate(self.population_dis.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/D{}'.format(self.neighbourhood.cell_number, i)

        # TRACE: A lock is created to set the population; uses a multithread lock.
        # NOTE(review): thread-safety of the singleton itself still needs to be verified.
        self.concurrent_populations = ConcurrentPopulations.instance()
        self.concurrent_populations.generator = self.population_gen
        self.concurrent_populations.discriminator = self.population_dis
        if self.concurrent_populations.locked():
            self.concurrent_populations.unlock()

        experiment_id = self.cc.settings['general']['logging'].get(
            'experiment_id', None)
        self.db_logger = DbLogger(current_experiment=experiment_id)

        # Score (FID/inception) calculation is optional; when disabled the
        # reported score stays at the neutral value 0.
        if 'score' in self.settings and self.settings['score'].get(
                'enabled', calc_mixture):
            self.score_calc = ScoreCalculatorFactory.create()
            self.score_sample_size = self.settings['score'].get(
                'sample_size', score_sample_size)
            # Lower-is-better metrics (is_reversed) start at +inf, others at -inf.
            self.score = float(
                'inf') if self.score_calc.is_reversed else float('-inf')
        else:
            self.score_calc = None
            self.score = 0

        if 'fitness' in self.settings:
            self.fitness_sample_size = self.settings['fitness'].get(
                'fitness_sample_size', fitness_sample_size)
            self.fitness_loaded = self.dataloader.load()
            self.fitness_iterator = iter(
                self.fitness_loaded)  # Create iterator for fitness loader

            # Determine how to aggregate fitness calculated among neighbourhood
            self.fitness_mode = self.settings['fitness'].get(
                'fitness_mode', fitness_mode)
            if self.fitness_mode not in ['worse', 'best', 'average']:
                raise NotImplementedError(
                    "Invalid argument for fitness_mode: {}".format(
                        self.fitness_mode))
        else:
            # TODO: Add code for safe implementation & error handling
            raise KeyError(
                "Fitness section must be defined in configuration file")

    def train(self, n_iterations, stop_event=None):
        """
        Run the coevolutionary training loop for n_iterations iterations.

        Each iteration: evaluates fitness over the neighbourhood, optionally
        performs tournament selection, trains generators and discriminators
        batch-by-batch, replaces the worst local individuals, mutates the
        mixture weights, and logs results (to DB if enabled).

        :param stop_event: optional threading.Event; when set, training stops
            early and the current best result is returned.
        :return: same tuple as :meth:`result`.
        """
        loaded = self.dataloader.load()

        for iteration in range(n_iterations):
            self._logger.debug('Iteration {} started'.format(iteration + 1))
            start_time = time()

            local_generators = self.neighbourhood.local_generators
            local_discriminators = self.neighbourhood.local_discriminators
            all_generators, all_discriminators = self.neighbourhood.all_disc_gen_local(
            )

            # Local functions

            # Log the name of individuals in entire neighborhood for every iteration
            # (to help tracing because individuals from adjacent cells might be from different iterations)
            self._logger.info(
                'Generators in current neighborhood are {}'.format([
                    individual.name
                    for individual in all_generators.individuals
                ]))
            self._logger.info(
                'Discriminators in current neighborhood are {}'.format([
                    individual.name
                    for individual in all_discriminators.individuals
                ]))

            # TODO: Fixme
            # self._logger.info('L2 distance between all generators weights: {}'.format(all_generators.net_weights_dist))
            # self._logger.info('L2 distance between all discriminators weights: {}'.format(all_discriminators.net_weights_dist))

            new_populations = {}

            # Create random dataset to evaluate fitness in each iterations
            fitness_samples = self.generate_random_fitness_samples(
                self.fitness_sample_size)
            # Image datasets keep their spatial layout; everything else is
            # flattened to (sample_size, -1).
            if self.cc.settings['dataloader']['dataset_name'] == 'celeba' \
                or self.cc.settings['dataloader']['dataset_name'] == 'cifar':
                fitness_samples = to_pytorch_variable(fitness_samples)
            else:
                fitness_samples = to_pytorch_variable(
                    fitness_samples.view(self.fitness_sample_size, -1))

            # Fitness evaluation
            self._logger.debug('Evaluating fitness')
            self.evaluate_fitness(all_generators, all_discriminators,
                                  fitness_samples, self.fitness_mode)
            self.evaluate_fitness(all_discriminators, all_generators,
                                  fitness_samples, self.fitness_mode)
            self._logger.debug('Finished evaluating fitness')

            # Tournament selection
            if self._enable_selection:
                self._logger.debug('Started tournament selection')
                new_populations[TYPE_GENERATOR] = self.tournament_selection(
                    all_generators, TYPE_GENERATOR, is_logging=True)
                new_populations[
                    TYPE_DISCRIMINATOR] = self.tournament_selection(
                        all_discriminators,
                        TYPE_DISCRIMINATOR,
                        is_logging=True)
                self._logger.debug('Finished tournament selection')

            self.batch_number = 0
            data_iterator = iter(loaded)
            # NOTE(review): if the loader yields no batches, `input_data` and
            # `batch_size` below would be unbound — assumes a non-empty dataset.
            while self.batch_number < len(loaded):
                # for i, (input_data, labels) in enumerate(loaded):
                input_data = next(data_iterator)[0]
                batch_size = input_data.size(0)
                input_data = to_pytorch_variable(
                    self.dataloader.transpose_data(input_data))

                # Quit if requested
                if stop_event is not None and stop_event.is_set():
                    self._logger.warning('External stop requested.')
                    return self.result()

                # Generator step: attackers are (selected) generators,
                # defenders the discriminators.
                attackers = new_populations[
                    TYPE_GENERATOR] if self._enable_selection else local_generators
                defenders = new_populations[
                    TYPE_DISCRIMINATOR] if self._enable_selection else all_discriminators
                input_data = self.step(local_generators, attackers, defenders,
                                       input_data, self.batch_number, loaded,
                                       data_iterator)

                # Discriminator step: executed every batch when skipping is
                # disabled (== 0), otherwise only every (n+1)-th batch.
                if self._discriminator_skip_each_nth_step == 0 or self.batch_number % (
                        self._discriminator_skip_each_nth_step + 1) == 0:
                    attackers = new_populations[
                        TYPE_DISCRIMINATOR] if self._enable_selection else local_discriminators
                    defenders = new_populations[
                        TYPE_GENERATOR] if self._enable_selection else all_generators
                    input_data = self.step(local_discriminators, attackers,
                                           defenders, input_data,
                                           self.batch_number, loaded,
                                           data_iterator)
                else:
                    # Bug fix: this message used to be logged on the branch
                    # that *executes* the discriminator step.
                    self._logger.debug('Skipping discriminator step')

                self._logger.info('Iteration {}, Batch {}/{}'.format(
                    iteration + 1, self.batch_number, len(loaded)))

                # If n_batches is set to 0, all batches will be used
                if self.is_last_batch(self.batch_number):
                    break

                self.batch_number += 1

            # Perform selection first before mutation of mixture_weights
            # Replace the worst with the best new
            if self._enable_selection:
                # Evaluate fitness of new_populations against neighborhood
                self.evaluate_fitness(new_populations[TYPE_GENERATOR],
                                      all_discriminators, fitness_samples,
                                      self.fitness_mode)
                self.evaluate_fitness(new_populations[TYPE_DISCRIMINATOR],
                                      all_generators, fitness_samples,
                                      self.fitness_mode)
                self.concurrent_populations.lock()
                local_generators.replacement(new_populations[TYPE_GENERATOR],
                                             self._n_replacements,
                                             is_logging=True)
                local_generators.sort_population(is_logging=True)
                local_discriminators.replacement(
                    new_populations[TYPE_DISCRIMINATOR],
                    self._n_replacements,
                    is_logging=True)
                local_discriminators.sort_population(is_logging=True)
                self.concurrent_populations.unlock()

                # Update individuals' iteration and id after replacement and logging to ease tracing
                for i, individual in enumerate(local_generators.individuals):
                    individual.id = '{}/G{}'.format(
                        self.neighbourhood.cell_number, i)
                    individual.iteration = iteration + 1
                for i, individual in enumerate(
                        local_discriminators.individuals):
                    individual.id = '{}/D{}'.format(
                        self.neighbourhood.cell_number, i)
                    individual.iteration = iteration + 1
            else:
                # Re-evaluate fitness of local_generators and local_discriminators against neighborhood
                self.evaluate_fitness(local_generators, all_discriminators,
                                      fitness_samples, self.fitness_mode)
                self.evaluate_fitness(local_discriminators, all_generators,
                                      fitness_samples, self.fitness_mode)

            # Mutate mixture weights after selection
            self.mutate_mixture_weights_with_score(
                input_data)  # self.score is updated here

            stop_time = time()

            path_real_images, path_fake_images = \
                self.log_results(batch_size, iteration, input_data, loaded,
                                 lr_gen=self.concurrent_populations.generator.individuals[0].learning_rate,
                                 lr_dis=self.concurrent_populations.discriminator.individuals[0].learning_rate,
                                 score=self.score, mixture_gen=self.neighbourhood.mixture_weights_generators,
                                 mixture_dis=None)

            if self.db_logger.is_enabled:
                self.db_logger.log_results(iteration, self.neighbourhood,
                                           self.concurrent_populations,
                                           self.score, stop_time - start_time,
                                           path_real_images, path_fake_images)

        return self.result()

    def step(self, original, attacker, defender, input_data, i, loaded,
             data_iterator):
        """Mutate the attacker's hyperparameters, then train it against the defender."""
        # Don't execute for remote populations - needed if generator and discriminator are on different node
        #         if any(not ind.is_local for ind in original.individuals):
        #             return

        self.mutate_hyperparams(attacker)
        return self.update_genomes(attacker, defender, input_data, loaded,
                                   data_iterator)

    def is_last_batch(self, i):
        """Return True if batch i is the last one allowed by dataloader.n_batches (0 means no limit)."""
        return self.dataloader.n_batches != 0 and self.dataloader.n_batches - 1 == i

    def result(self):
        """Return ((best_gen_genome, best_gen_fitness), (best_dis_genome, best_dis_fitness))."""
        return (
            (self.concurrent_populations.generator.individuals[0].genome,
             self.concurrent_populations.generator.individuals[0].fitness),
            (self.concurrent_populations.discriminator.individuals[0].genome,
             self.concurrent_populations.discriminator.individuals[0].fitness))

    def mutate_hyperparams(self, population):
        """
        Gaussian-mutate each individual's learning rate with probability
        self._mutation_probability, scaled by self._alpha and clamped at 0.
        """
        loc = -(self._default_adam_learning_rate / 10)
        deltas = np.random.normal(loc=loc,
                                  scale=self._default_adam_learning_rate,
                                  size=len(population.individuals))
        # Zero out deltas for individuals that are not selected for mutation.
        deltas[np.random.rand(*deltas.shape) < 1 -
               self._mutation_probability] = 0
        for i, individual in enumerate(population.individuals):
            individual.learning_rate = max(
                0, individual.learning_rate + deltas[i] * self._alpha)

    def update_genomes(self, population_attacker, population_defender,
                       input_var, loaded, data_iterator):
        """
        Run one Adam step for every attacker against a randomly chosen
        defender, restoring and saving each attacker's optimizer state.

        :return: input_var unchanged (kept for interface symmetry with step()).
        """
        # TODO Currently picking random opponent, introduce parameter for this
        defender = random.choice(population_defender.individuals).genome

        for individual_attacker in population_attacker.individuals:
            attacker = individual_attacker.genome
            optimizer = torch.optim.Adam(attacker.net.parameters(),
                                         lr=individual_attacker.learning_rate,
                                         betas=(0.5, 0.999))

            # Restore previous state dict, if available
            if individual_attacker.optimizer_state is not None:
                optimizer.load_state_dict(individual_attacker.optimizer_state)
            loss = attacker.compute_loss_against(defender, input_var)[0]

            attacker.net.zero_grad()
            defender.net.zero_grad()
            loss.backward()
            optimizer.step()

            individual_attacker.optimizer_state = optimizer.state_dict()

        return input_var

    @staticmethod
    def evaluate_fitness(population_attacker, population_defender, input_var,
                         fitness_mode):
        """
        Evaluate each attacker's fitness against all defenders, aggregating
        per fitness_mode: 'best' keeps the minimum loss, 'worse' the maximum,
        'average' the mean. Only the attacker population is updated.
        """
        # Single direction only: Evaluate fitness of attacker based on defender
        # TODO: Simplify and refactor this function
        def compare_fitness(curr_fitness, fitness, mode):
            # The initial fitness value is -inf before evaluation started, so we
            # directly adopt the curr_fitness when -inf is encountered
            if mode == 'best':
                if curr_fitness < fitness or fitness == float('-inf'):
                    return curr_fitness
            elif mode == 'worse':
                if curr_fitness > fitness or fitness == float('-inf'):
                    return curr_fitness
            elif mode == 'average':
                if fitness == float('-inf'):
                    return curr_fitness
                else:
                    return fitness + curr_fitness

            return fitness

        for individual_attacker in population_attacker.individuals:
            individual_attacker.fitness = float(
                '-inf'
            )  # Reinitalize before evaluation started (Needed for average fitness)
            for individual_defender in population_defender.individuals:
                # TRACE: Evaluate the attacker against the defender.
                fitness_attacker = float(
                    individual_attacker.genome.compute_loss_against(
                        individual_defender.genome, input_var)[0])
                # TRACE: Keep the fitness according to the aggregation mode.
                individual_attacker.fitness = compare_fitness(
                    fitness_attacker, individual_attacker.fitness,
                    fitness_mode)

            if fitness_mode == 'average':
                individual_attacker.fitness /= len(
                    population_defender.individuals)

    def mutate_mixture_weights_with_score(self, input_data):
        """
        Mutate the generator mixture weights with Gaussian noise and keep the
        mutation only if it improves the FID/inception score; updates
        self.score. On single-cell grids only the score is (re)calculated,
        since the mixture must always be [1].
        """
        # Not necessary for single-cell grids, as mixture must always be [1]
        if self.neighbourhood.grid_size == 1:
            if self.score_calc is not None:
                self._logger.info('Calculating FID/inception score.')
                best_generators = self.neighbourhood.best_generators_local()

                dataset = MixedGeneratorDataset(
                    best_generators,
                    self.neighbourhood.mixture_weights_generators,
                    self.score_sample_size, self.cc.settings['trainer']
                    ['mixture_generator_samples_mode'])
                self.score = self.score_calc.calculate(dataset)[0]
        else:
            # Mutate mixture weights
            z = np.random.normal(
                loc=0,
                scale=self.mixture_sigma,
                size=len(self.neighbourhood.mixture_weights_generators))
            transformed = np.asarray([
                value for _, value in
                self.neighbourhood.mixture_weights_generators.items()
            ])
            transformed += z
            # Don't allow negative values, normalize to sum of 1.0
            transformed = np.clip(transformed, 0, None)
            transformed /= np.sum(transformed)

            new_mixture_weights_generators = OrderedDict(
                zip(self.neighbourhood.mixture_weights_generators.keys(),
                    transformed))

            best_generators = self.neighbourhood.best_generators_local()
            dataset_before_mutation = MixedGeneratorDataset(
                best_generators, self.neighbourhood.mixture_weights_generators,
                self.score_sample_size,
                self.cc.settings['trainer']['mixture_generator_samples_mode'])
            dataset_after_mutation = MixedGeneratorDataset(
                best_generators, new_mixture_weights_generators,
                self.score_sample_size,
                self.cc.settings['trainer']['mixture_generator_samples_mode'])

            if self.score_calc is not None:
                self._logger.info('Calculating FID/inception score.')

                score_before_mutation = self.score_calc.calculate(
                    dataset_before_mutation)[0]
                score_after_mutation = self.score_calc.calculate(
                    dataset_after_mutation)[0]

                # For fid the lower the better, for inception_score, the higher the better
                if (score_after_mutation < score_before_mutation and self.score_calc.is_reversed) \
                     or (score_after_mutation > score_before_mutation and (not self.score_calc.is_reversed)):
                    # Adopt the mutated mixture_weights only if the performance after mutation is better
                    self.neighbourhood.mixture_weights_generators = new_mixture_weights_generators
                    self.score = score_after_mutation
                else:
                    # Do not adopt the mutated mixture_weights here
                    self.score = score_before_mutation

    def generate_random_fitness_samples(self, fitness_sample_size):
        """
        Generate random samples for fitness evaluation according to fitness_sample_size

        Abit of hack, use iterator of batch_size to sample data of fitness_sample_size
        TODO Implement another iterator (and dataloader) of fitness_sample_size
        """
        def get_next_batch(iterator, loaded):
            # Handle if the end of iterator is reached
            try:
                return next(iterator)[0], iterator
            except StopIteration:
                # Use a new iterator
                iterator = iter(loaded)
                return next(iterator)[0], iterator

        sampled_data, self.fitness_iterator = get_next_batch(
            self.fitness_iterator, self.fitness_loaded)
        batch_size = sampled_data.size(0)

        if fitness_sample_size < batch_size:
            return sampled_data[:fitness_sample_size]
        else:
            fitness_sample_size -= batch_size
            while fitness_sample_size >= batch_size:
                # Keep concatenate a full batch of data
                curr_data, self.fitness_iterator = get_next_batch(
                    self.fitness_iterator, self.fitness_loaded)
                sampled_data = torch.cat((sampled_data, curr_data), 0)
                fitness_sample_size -= batch_size

            if fitness_sample_size > 0:
                # Concatenate partial batch of data
                curr_data, self.fitness_iterator = get_next_batch(
                    self.fitness_iterator, self.fitness_loaded)
                sampled_data = torch.cat(
                    (sampled_data, curr_data[:fitness_sample_size]), 0)

            return sampled_data
Example #11
0
    def __init__(self,
                 dataloader,
                 network_factory,
                 population_size=10,
                 tournament_size=2,
                 mutation_probability=0.9,
                 n_replacements=1,
                 sigma=0.25,
                 alpha=0.25,
                 default_adam_learning_rate=0.001,
                 calc_mixture=False,
                 mixture_sigma=0.01,
                 score_sample_size=10000,
                 discriminator_skip_each_nth_step=0,
                 enable_selection=True,
                 fitness_sample_size=10000,
                 calculate_net_weights_dist=False,
                 fitness_mode='worse',
                 es_generations=10,
                 es_score_sample_size=10000,
                 es_random_init=False,
                 checkpoint_period=0):
        """
        Initialize the trainer; constructor defaults act as fallbacks for the
        corresponding keys in the trainer settings section.

        Fixes over the previous revision:
        - ``fitness_mode`` defaulted to ``'worst'``, which the validation below
          rejects (accepted values are 'worse', 'best', 'average'); it now
          defaults to the valid ``'worse'``.
        - The checkpoint-period range check used to validate the constructor
          parameter instead of the value actually read from the configuration;
          it now validates the effective configured value.

        :raises KeyError: if the configuration has no 'fitness' section.
        :raises NotImplementedError: if the configured fitness_mode is invalid.
        :raises AssertionError: if the effective checkpoint_period is out of range.
        """
        super().__init__(dataloader, network_factory, population_size,
                         tournament_size, mutation_probability, n_replacements,
                         sigma, alpha)

        self.batch_number = 0
        self.cc = ConfigurationContainer.instance()

        # Settings values take precedence over the constructor defaults.
        self._default_adam_learning_rate = self.settings.get(
            'default_adam_learning_rate', default_adam_learning_rate)
        self._discriminator_skip_each_nth_step = self.settings.get(
            'discriminator_skip_each_nth_step',
            discriminator_skip_each_nth_step)
        self._enable_selection = self.settings.get('enable_selection',
                                                   enable_selection)
        self.mixture_sigma = self.settings.get('mixture_sigma', mixture_sigma)

        self.neighbourhood = Neighbourhood.instance()

        # Tag every individual with the cell-local id '<cell>/G<i>' or
        # '<cell>/D<i>' and the shared initial learning rate.
        for i, individual in enumerate(self.population_gen.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/G{}'.format(self.neighbourhood.cell_number, i)
        for i, individual in enumerate(self.population_dis.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/D{}'.format(self.neighbourhood.cell_number, i)

        self.concurrent_populations = ConcurrentPopulations.instance()
        self.concurrent_populations.generator = self.population_gen
        self.concurrent_populations.discriminator = self.population_dis
        self.concurrent_populations.unlock()

        experiment_id = self.cc.settings['general']['logging'].get(
            'experiment_id', None)
        self.db_logger = DbLogger(current_experiment=experiment_id)

        if 'fitness' in self.settings:
            self.fitness_sample_size = self.settings['fitness'].get(
                'fitness_sample_size', fitness_sample_size)
            self.fitness_loaded = self.dataloader.load()
            self.fitness_iterator = iter(
                self.fitness_loaded)  # Create iterator for fitness loader

            # Determine how to aggregate fitness calculated among neighbourhood
            self.fitness_mode = self.settings['fitness'].get(
                'fitness_mode', fitness_mode)
            if self.fitness_mode not in ['worse', 'best', 'average']:
                raise NotImplementedError(
                    "Invalid argument for fitness_mode: {}".format(
                        self.fitness_mode))
        else:
            # TODO: Add code for safe implementation & error handling
            raise KeyError(
                "Fitness section must be defined in configuration file")

        # Score (FID/inception) calculation: enabled explicitly via the
        # 'score' section, implicitly via 'optimize_mixture', else disabled.
        if 'score' in self.settings and self.settings['score'].get(
                'enabled', calc_mixture):
            self.score_calc = ScoreCalculatorFactory.create()
            self.score_sample_size = self.settings['score'].get(
                'sample_size', score_sample_size)
            # Lower-is-better metrics (is_reversed) start at +inf, others at -inf.
            self.score = float(
                'inf') if self.score_calc.is_reversed else float('-inf')
            self.mixture_generator_samples_mode = self.cc.settings['trainer'][
                'mixture_generator_samples_mode']
        elif 'optimize_mixture' in self.settings:
            self.score_calc = ScoreCalculatorFactory.create()
            self.score = float(
                'inf') if self.score_calc.is_reversed else float('-inf')
        else:
            self.score_sample_size = score_sample_size
            self.score_calc = None
            self.score = 0

        if 'optimize_mixture' in self.settings:
            # Evolution-strategy parameters for the final mixture-weight
            # optimization; these overwrite score_sample_size/mixture_sigma.
            self.optimize_weights_at_the_end = True
            self.score_sample_size = self.settings['optimize_mixture'].get(
                'sample_size', es_score_sample_size)
            self.es_generations = self.settings['optimize_mixture'].get(
                'es_generations', es_generations)
            self.es_random_init = self.settings['optimize_mixture'].get(
                'es_random_init', es_random_init)
            self.mixture_sigma = self.settings['optimize_mixture'].get(
                'mixture_sigma', mixture_sigma)
            self.mixture_generator_samples_mode = self.cc.settings['trainer'][
                'mixture_generator_samples_mode']
        else:
            self.optimize_weights_at_the_end = False

        n_iterations = self.cc.settings['trainer'].get('n_iterations', 0)
        # Read the effective checkpoint period first, then validate it.
        # (Previously the assert checked the constructor parameter, so a bad
        # configured value slipped through whenever the parameter was valid.)
        self.checkpoint_period = self.cc.settings['general'].get(
            'checkpoint_period', checkpoint_period)
        assert 0 <= self.checkpoint_period <= n_iterations, \
            'Checkpoint period parameter (checkpoint_period) should be ' \
            'between 0 and the number of iterations (n_iterations).'
Example #12
0
class LipizzanerGANTrainer(EvolutionaryAlgorithmTrainer):
    """
    Distributed, asynchronous trainer for coevolutionary GANs. Uses the standard Goodfellow GAN approach.
    """
    def __init__(self,
                 dataloader,
                 network_factory,
                 population_size=10,
                 tournament_size=2,
                 mutation_probability=0.9,
                 n_replacements=1,
                 sigma=0.25,
                 alpha=0.25,
                 default_adam_learning_rate=0.001,
                 calc_mixture=False,
                 mixture_sigma=0.01,
                 score_sample_size=10000,
                 discriminator_skip_each_nth_step=0,
                 enable_selection=True):
        """Set up the distributed GAN trainer for one grid cell.

        All keyword arguments act as fall-back defaults: each one is
        overridden by the matching key in ``self.settings`` (the trainer
        section of the configuration) when present.

        Args:
            dataloader: Data source forwarded to the base trainer.
            network_factory: Factory for generator/discriminator networks.
            population_size, tournament_size, mutation_probability,
            n_replacements, sigma, alpha: Evolutionary-algorithm parameters
                forwarded to ``EvolutionaryAlgorithmTrainer``.
            default_adam_learning_rate: Initial Adam LR assigned to every
                individual in both populations.
            calc_mixture: Default for the ``score.enabled`` setting.
            mixture_sigma: Std-dev for mixture-weight mutation noise.
            score_sample_size: Number of samples used for score calculation.
            discriminator_skip_each_nth_step: If > 0, the discriminator step
                is executed only every (n+1)-th batch (see ``train``).
            enable_selection: Whether tournament selection/replacement is on.
        """
        super().__init__(dataloader, network_factory, population_size,
                         tournament_size, mutation_probability, n_replacements,
                         sigma, alpha)

        # Index of the batch currently being processed within an iteration.
        self.batch_number = 0

        self._default_adam_learning_rate = self.settings.get(
            'default_adam_learning_rate', default_adam_learning_rate)
        self._discriminator_skip_each_nth_step = self.settings.get(
            'discriminator_skip_each_nth_step',
            discriminator_skip_each_nth_step)
        self._enable_selection = self.settings.get('enable_selection',
                                                   enable_selection)
        self.mixture_sigma = self.settings.get('mixture_sigma', mixture_sigma)

        self.neighbourhood = Neighbourhood.instance()

        # Tag every individual with this cell's number and an index so logs
        # can attribute results, and seed each one with the default Adam LR
        # (later perturbed per-individual by mutate_hyperparams).
        for i, individual in enumerate(self.population_gen.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/G{}'.format(self.neighbourhood.cell_number, i)
        for i, individual in enumerate(self.population_dis.individuals):
            individual.learning_rate = self._default_adam_learning_rate
            individual.id = '{}/D{}'.format(self.neighbourhood.cell_number, i)

        # Publish both populations through the shared singleton so other
        # threads/requests on this node can read them; unlock for access.
        self.concurrent_populations = ConcurrentPopulations.instance()
        self.concurrent_populations.generator = self.population_gen
        self.concurrent_populations.discriminator = self.population_dis
        self.concurrent_populations.unlock()

        # Attach to the experiment created by the master (if DB logging is
        # enabled); experiment_id is None when running without a database.
        experiment_id = ConfigurationContainer.instance(
        ).settings['general']['logging'].get('experiment_id', None)
        self.db_logger = DbLogger(current_experiment=experiment_id)

        if 'score' in self.settings and self.settings['score'].get(
                'enabled', calc_mixture):
            self.score_calc = ScoreCalculatorFactory.create()
            self.score_sample_size = self.settings['score'].get(
                'sample_size', score_sample_size)
            # Reversed scores (e.g. FID) improve downwards, so start at +inf;
            # otherwise (e.g. inception score) start at -inf.
            self.score = float(
                'inf') if self.score_calc.is_reversed else float('-inf')
        else:
            self.score_calc = None
            self.score = 0

    def train(self, n_iterations, stop_event=None):
        """Run the distributed coevolutionary training loop.

        Each iteration walks once over the dataset: optionally performs
        tournament selection (first batch only), alternates generator and
        discriminator update steps per batch, then mutates mixture weights,
        replaces the worst local individuals, computes the score, and logs
        results.

        Args:
            n_iterations: Number of full passes over the data loader.
            stop_event: Optional ``threading.Event``; when set, training
                aborts after the current batch and the best result so far
                is returned.

        Returns:
            ``((best_gen_genome, fitness), (best_dis_genome, fitness))``
            as produced by ``self.result()``.
        """
        loaded = self.dataloader.load()

        for iteration in range(n_iterations):
            self._logger.debug('Iteration {} started'.format(iteration))
            start_time = time()

            all_generators = self.neighbourhood.all_generators
            all_discriminators = self.neighbourhood.all_discriminators
            local_generators = self.neighbourhood.local_generators
            local_discriminators = self.neighbourhood.local_discriminators

            # Holds the offspring selected by tournament selection; only
            # populated when selection is enabled.
            new_populations = {}

            self.batch_number = 0
            data_iterator = iter(loaded)
            while self.batch_number < len(loaded):
                input_data = next(data_iterator)[0]
                batch_size = input_data.size(0)
                input_data = to_pytorch_variable(
                    self.dataloader.transpose_data(input_data))

                # Populations start without fitness values; evaluate once
                # before the very first selection/step.
                if iteration == 0 and self.batch_number == 0:
                    self._logger.debug('Evaluating first fitness')
                    self.evaluate_fitness(local_generators, all_discriminators,
                                          input_data)
                    self._logger.debug('Finished evaluating first fitness')

                # Selection happens once per iteration, on the first batch.
                if self.batch_number == 0 and self._enable_selection:
                    # BUGFIX: log messages previously misspelled
                    # 'tournamend selection'.
                    self._logger.debug('Started tournament selection')
                    new_populations[
                        TYPE_GENERATOR] = self.tournament_selection(
                            all_generators, TYPE_GENERATOR)
                    new_populations[
                        TYPE_DISCRIMINATOR] = self.tournament_selection(
                            all_discriminators, TYPE_DISCRIMINATOR)
                    self._logger.debug('Finished tournament selection')

                # Quit if requested
                if stop_event is not None and stop_event.is_set():
                    self._logger.warning('External stop requested.')
                    return self.result()

                # Generator step: attackers are (selected) generators,
                # defenders the discriminators.
                attackers = new_populations[
                    TYPE_GENERATOR] if self._enable_selection else local_generators
                defenders = new_populations[
                    TYPE_DISCRIMINATOR] if self._enable_selection else all_discriminators
                input_data = self.step(
                    local_generators, attackers, defenders, input_data, loaded,
                    data_iterator,
                    self.neighbourhood.mixture_weights_discriminators)

                # Discriminator step, optionally skipped every n-th batch to
                # slow the discriminator down relative to the generator.
                if self._discriminator_skip_each_nth_step == 0 or self.batch_number % (
                        self._discriminator_skip_each_nth_step + 1) == 0:
                    attackers = new_populations[
                        TYPE_DISCRIMINATOR] if self._enable_selection else local_discriminators
                    defenders = new_populations[
                        TYPE_GENERATOR] if self._enable_selection else all_generators

                    input_data = self.step(
                        local_discriminators, attackers, defenders, input_data,
                        loaded, data_iterator,
                        self.neighbourhood.mixture_weights_generators)

                self._logger.info('Batch {}/{} done'.format(
                    self.batch_number, len(loaded)))

                # If n_batches is set to 0, all batches will be used
                if self.is_last_batch(self.batch_number):
                    break

                self.batch_number += 1

            # Mutate mixture weights of both populations, each evaluated
            # against the other side's weights.
            weights_generators = self.neighbourhood.mixture_weights_generators
            weights_discriminators = self.neighbourhood.mixture_weights_discriminators
            generators = new_populations[
                TYPE_GENERATOR] if self._enable_selection else all_generators
            discriminators = new_populations[
                TYPE_DISCRIMINATOR] if self._enable_selection else all_discriminators
            self.mutate_mixture_weights(weights_generators,
                                        weights_discriminators, generators,
                                        discriminators, input_data)
            self.mutate_mixture_weights(weights_discriminators,
                                        weights_generators, discriminators,
                                        generators, input_data)

            # Replace the worst with the best new
            if self._enable_selection:
                self.evaluate_fitness(new_populations[TYPE_GENERATOR],
                                      new_populations[TYPE_DISCRIMINATOR],
                                      input_data)
                self.concurrent_populations.lock()
                local_generators.replacement(new_populations[TYPE_GENERATOR],
                                             self._n_replacements)
                local_generators.sort_population()
                local_discriminators.replacement(
                    new_populations[TYPE_DISCRIMINATOR], self._n_replacements)
                local_discriminators.sort_population()
                self.concurrent_populations.unlock()
            else:
                self.evaluate_fitness(all_generators, all_discriminators,
                                      input_data)

            if self.score_calc is not None:
                self._logger.info('Calculating FID/inception score.')
                self.calculate_score()

            stop_time = time()

            # NOTE(review): batch_size/input_data are unbound if the data
            # loader yields no batches — presumably guaranteed non-empty
            # upstream; confirm.
            path_real_images, path_fake_images = \
                self.log_results(batch_size, iteration, input_data, loaded,
                                 lr_gen=self.concurrent_populations.generator.individuals[0].learning_rate,
                                 lr_dis=self.concurrent_populations.discriminator.individuals[0].learning_rate,
                                 score=self.score, mixture_gen=self.neighbourhood.mixture_weights_generators,
                                 mixture_dis=self.neighbourhood.mixture_weights_discriminators)

            if self.db_logger.is_enabled:
                self.db_logger.log_results(iteration, self.neighbourhood,
                                           self.concurrent_populations,
                                           self.score, stop_time - start_time,
                                           path_real_images, path_fake_images)

        return self.result()

    def step(self, original, attacker, defender, input_data, loaded,
             data_iterator, defender_weights):
        """Mutate the attacker's hyperparameters, then update its genomes
        against the defender population. Returns the (possibly advanced)
        input data, or None when the step is skipped."""
        # Skip remote populations — required when generator and
        # discriminator live on different nodes.
        if not all(member.is_local for member in original.individuals):
            return None

        self.mutate_hyperparams(attacker)
        return self.update_genomes(attacker, defender, input_data, loaded,
                                   data_iterator, defender_weights)

    def is_last_batch(self, i):
        """Return True when batch index *i* is the configured final batch.

        A configured ``n_batches`` of 0 means "use every batch", in which
        case no index counts as the last one here.
        """
        configured = self.dataloader.n_batches
        return configured != 0 and i == configured - 1

    def result(self):
        """Return the current best individuals as
        ((generator_genome, fitness), (discriminator_genome, fitness))."""
        best_generator = self.concurrent_populations.generator.individuals[0]
        best_discriminator = self.concurrent_populations.discriminator.individuals[0]
        return ((best_generator.genome, best_generator.fitness),
                (best_discriminator.genome, best_discriminator.fitness))

    def mutate_hyperparams(self, population):
        """Gaussian-perturb each individual's learning rate.

        Draws one delta per individual from N(-lr/10, lr) where lr is the
        default Adam learning rate, zeroes deltas with probability
        (1 - mutation_probability), and applies them scaled by alpha,
        clamping resulting learning rates at zero.
        """
        center = -(self._default_adam_learning_rate / 10)
        deltas = np.random.normal(loc=center,
                                  scale=self._default_adam_learning_rate,
                                  size=len(population.individuals))
        # Individuals that lose the mutation coin flip keep their LR.
        skip_mask = np.random.rand(*deltas.shape) < 1 - \
            self._mutation_probability
        deltas[skip_mask] = 0
        for delta, individual in zip(deltas, population.individuals):
            individual.learning_rate = max(
                0, individual.learning_rate + delta * self._alpha)

    def update_genomes(self, population_attacker, population_defender,
                       input_var, loaded, data_iterator, defender_weights):
        """Run one Adam gradient step for every attacker individual.

        Each attacker picks a single defender at random, with probability
        proportional to the defender's mixture weight, and minimizes its
        loss against that defender. Optimizer state is persisted on the
        individual so Adam's moment estimates survive across calls.

        Returns the (unchanged) input batch; ``loaded`` and
        ``data_iterator`` are accepted for interface compatibility with
        ``step`` but not used here.
        """
        for individual_attacker in population_attacker.individuals:
            attacker = individual_attacker.genome
            # Normalize the defenders' mixture weights into a sampling
            # distribution (list / np scalar yields an ndarray).
            weights = [
                self.get_weight(defender, defender_weights)
                for defender in population_defender.individuals
            ]
            weights /= np.sum(weights)
            defender = np.random.choice(population_defender.individuals,
                                        p=weights).genome
            # Per-individual learning rate (mutated by mutate_hyperparams).
            optimizer = torch.optim.Adam(attacker.net.parameters(),
                                         lr=individual_attacker.learning_rate,
                                         betas=(0.5, 0.999))

            # Restore previous state dict, if available
            if individual_attacker.optimizer_state is not None:
                optimizer.load_state_dict(individual_attacker.optimizer_state)

            loss = attacker.compute_loss_against(defender, input_var)[0]

            # Clear both nets' gradients: backward() also populates the
            # defender's grads, which must not accumulate.
            attacker.net.zero_grad()
            defender.net.zero_grad()
            loss.backward()
            optimizer.step()

            # Persist Adam state for the next call on this individual.
            individual_attacker.optimizer_state = optimizer.state_dict()

        return input_var

    @staticmethod
    def evaluate_fitness(population_attacker, population_defender, input_var):
        """Evaluate every local individual against all opponents.

        For each attacker/defender pairing, computes the loss of each local
        individual's genome against the other's and keeps the highest value
        seen as that individual's fitness (worst-case opponent).
        """
        for attacker in population_attacker.individuals:
            for defender in population_defender.individuals:

                if attacker.is_local:
                    loss_vs_defender = float(
                        attacker.genome.compute_loss_against(
                            defender.genome, input_var)[0])

                    # Fitness is the maximum loss over all opponents.
                    if loss_vs_defender > attacker.fitness:
                        attacker.fitness = loss_vs_defender

                if defender.is_local:
                    loss_vs_attacker = float(
                        defender.genome.compute_loss_against(
                            attacker.genome, input_var)[0])

                    if loss_vs_attacker > defender.fitness:
                        defender.fitness = loss_vs_attacker

    def mutate_mixture_weights(self, weights_attacker, weights_defender,
                               population_attacker, population_defender,
                               input_data):
        """Gaussian-mutate the attacker's mixture weights in place.

        A mutated weight is adopted for an individual's cell only if it
        lowers that individual's weighted loss against the defender
        population; finally all weights are clipped to be non-negative and
        renormalized to sum to 1. ``weights_attacker`` /
        ``weights_defender`` are dicts keyed by an individual's ``source``.
        """
        # Not necessary for single-cell grids, as mixture must always be [1]
        if self.neighbourhood.grid_size == 1:
            return

        # Mutate mixture weights
        z = np.random.normal(loc=0,
                             scale=self.mixture_sigma,
                             size=len(weights_attacker))
        transformed = np.asarray(
            [value for _, value in weights_attacker.items()])
        transformed += z

        # Candidate weights, keyed like the originals (dict order is
        # preserved, so index i matches key i).
        new_mixture_weights = {}
        for i, key in enumerate(weights_attacker):
            new_mixture_weights[key] = transformed[i]

        # Greedy acceptance: keep a mutated weight only if it improves the
        # corresponding individual's weighted loss.
        for attacker in population_attacker.individuals:
            loss_prev = self.weights_loss(attacker, population_defender,
                                          weights_attacker, weights_defender,
                                          input_data)
            loss_new = self.weights_loss(attacker, population_defender,
                                         new_mixture_weights, weights_defender,
                                         input_data)

            if loss_new < loss_prev:
                weights_attacker[attacker.source] = self.get_weight(
                    attacker, new_mixture_weights)

        # Don't allow negative values, normalize to sum of 1.0
        clipped = np.clip(list(weights_attacker.values()), 0, None)
        clipped /= np.sum(clipped)
        for i, key in enumerate(weights_attacker):
            weights_attacker[key] = clipped[i]

    def calculate_score(self):
        """Sample from the best-generator mixture and store its score
        (e.g. FID or inception score) in ``self.score``."""
        generators = self.neighbourhood.best_generators
        mixture_weights = self.neighbourhood.mixture_weights_generators

        samples = MixedGeneratorDataset(generators, mixture_weights,
                                        self.score_sample_size)
        self.score = self.score_calc.calculate(samples)[0]

    @staticmethod
    def get_weight(individual, weights):
        """Return the mixture weight whose key equals *individual.source*.

        Raises IndexError if no key matches (first match is returned when
        several compare equal).
        """
        matches = [weight for source, weight in weights.items()
                   if source == individual.source]
        return matches[0]

    @staticmethod
    def weights_loss(attacker, population_defender, weights_attacker,
                     weights_defender, input_data):
        """Return the attacker's loss against the whole defender population,
        with each pairwise loss scaled by both individuals' mixture
        weights."""
        attacker_weight = LipizzanerGANTrainer.get_weight(attacker,
                                                          weights_attacker)
        total = 0
        for defender in population_defender.individuals:
            defender_weight = LipizzanerGANTrainer.get_weight(
                defender, weights_defender)
            pair_loss = float(
                attacker.genome.compute_loss_against(defender.genome,
                                                     input_data)[0])
            total += attacker_weight * defender_weight * pair_loss
        return total