Example #1
    def compute_loss_against(self, opponent, input, training_epoch=None):

        # If HeuristicLoss is applied in the Generator, the Discriminator applies BCELoss
        if self.loss_function.__class__.__name__ == 'MustangsLoss':
            if 'HeuristicLoss' in self.loss_function.get_applied_loss_name():
                self.loss_function.set_applied_loss(torch.nn.BCELoss())

        # Compute loss using real images
        # Second term of the loss is always zero since real_labels == 1
        batch_size = input.size(0)

        real_labels = to_pytorch_variable(torch.ones(batch_size))
        fake_labels = to_pytorch_variable(torch.zeros(batch_size))

        outputs = self.net(input)  #.view(-1)
        d_loss_real = self.loss_function(outputs, real_labels)

        # Compute loss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = noise(batch_size, self.data_size)
        fake_images = opponent.net(z)
        outputs = self.net(fake_images).view(-1)
        d_loss_fake = self.loss_function(outputs, fake_labels)

        return d_loss_real + d_loss_fake, None
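
All of the snippets on this page lean on small project helpers such as to_pytorch_variable and noise, whose definitions are not shown here. A minimal sketch of what they plausibly do, inferred only from how the snippets call them (the standard-normal latent space is an assumption):

import torch

def to_pytorch_variable(tensor):
    # Stand-in for the project's helper: move the tensor to the GPU when one
    # is available (older PyTorch versions also wrapped it in autograd.Variable).
    return tensor.cuda() if torch.cuda.is_available() else tensor

def noise(batch_size, data_size):
    # Latent input for the generator: one data_size-dimensional vector per
    # sample, assumed here to be drawn from a standard normal distribution.
    return to_pytorch_variable(torch.randn(batch_size, data_size))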
Example #2
    def compute_loss_against(self, opponent, input):
        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1

        batch_size = input.size(0)
        sequence_length = input.size(1)
        num_inputs = input.size(2)

        real_labels = to_pytorch_variable(torch.ones(batch_size))
        fake_labels = to_pytorch_variable(torch.zeros(batch_size))

        outputs_intermediate = self.net(input)
        sm = Softmax(dim=0)  # explicit dim for the 1-D input; the implicit-dim Softmax() is deprecated

        outputs = sm(outputs_intermediate[:, -1, :].contiguous().view(-1))
        d_loss_real = self.loss_function(outputs, real_labels)

        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = noise(batch_size, self.data_size)
        new_z = z.unsqueeze(1).repeat(1, sequence_length, 1)

        fake_images = opponent.net(new_z)
        outputs_full = self.net(fake_images)
        sm = Softmax(dim=0)  # explicit dim for the 1-D input; the implicit-dim Softmax() is deprecated

        outputs = sm(outputs_full[:, -1, :].contiguous().view(-1))
        d_loss_fake = self.loss_function(outputs, fake_labels)

        return d_loss_real + d_loss_fake, None
Example #3
    def train(self, n_iterations, stop_event=None):
        loader = self.dataloader.load()

        self.lw_cache.init_session(n_iterations, self.population_gen, self.population_dis)

        for i in range(n_iterations):
            for j, (input_data, labels) in enumerate(loader):

                batch_size = input_data.size(0)
                input_var = to_pytorch_variable(input_data.view(batch_size, -1))

                self.evolve_generation(self.population_gen, self.population_dis, input_var)

                self.lw_cache.append_stepsizes(self.population_gen, self.population_dis)

                # If n_batches is set to 0, all batches will be used
                if self.dataloader.n_batches != 0 and self.dataloader.n_batches - 1 == j:
                    break

            self.lw_cache.log_best_individuals(i, self.population_gen, self.population_dis)
            self.log_results(batch_size, i, input_var, loader)

        self.lw_cache.end_session()

        return (self.population_gen.individuals[0].genome, self.population_gen.individuals[0].fitness), (
            self.population_dis.individuals[0].genome, self.population_dis.individuals[0].fitness)
Example #4
    def train(self, n_iterations, stop_event=None):
        loader = self.dataloader.load()

        distributor = Distributor()

        if self._population_size % distributor.n_procs_overall != 0:
            self._logger.error("Number of overall processes should be a factor of population size, "
                               "distribution may not work.")

        self.lw_cache.init_session(n_iterations, self.population_gen, self.population_dis)

        for i in range(n_iterations):
            for j, (input_data, labels) in enumerate(loader):

                batch_size = input_data.size(0)
                input_var = to_pytorch_variable(input_data.view(batch_size, -1))

                populations = self.population_gen, self.population_dis
                self.evolve_generation(distributor, populations, input_var)

                self.lw_cache.append_stepsizes(self.population_gen, self.population_dis)

                # If n_batches is set to 0, all batches will be used
                if self.dataloader.n_batches != 0 and self.dataloader.n_batches - 1 == j:
                    break

            self.lw_cache.log_best_individuals(i, self.population_gen, self.population_dis)
            self.log_results(batch_size, i, input_var, loader)

        self.lw_cache.end_session()

        return (self.population_gen.individuals[0].genome, self.population_gen.individuals[0].fitness), (
            self.population_dis.individuals[0].genome, self.population_dis.individuals[0].fitness)
Example #5
        def _plot_discriminator(discriminator, ax):
            if discriminator is not None:
                alphas = []
                for x in np.linspace(-1, 1, 8, endpoint=False):
                    for y in np.linspace(-1, 1, 8, endpoint=False):
                        center = torch.zeros(2)
                        center[0] = x + 0.125
                        center[1] = y + 0.125
                        alphas.append(
                            float(
                                discriminator.net(
                                    to_pytorch_variable(center))))

                alphas = np.asarray(alphas)
                normalized = (alphas - min(alphas)) / (max(alphas) -
                                                       min(alphas))
                plt.text(0.1,
                         0.9,
                         'Min: {}\nMax: {}'.format(min(alphas), max(alphas)),
                         transform=ax.transAxes)

                k = 0
                for x in np.linspace(-1, 1, 8, endpoint=False):
                    for y in np.linspace(-1, 1, 8, endpoint=False):
                        center = torch.zeros(2)
                        center[0] = x + 0.125
                        center[1] = y + 0.125
                        ax.fill([x, x + 0.25, x + 0.25, x],
                                [y, y, y + 0.25, y + 0.25],
                                'r',
                                alpha=normalized[k],
                                zorder=0)
                        k += 1
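
The fill opacity of each grid cell comes from min-max normalizing the discriminator outputs collected over the 8x8 grid. That step in isolation:

import numpy as np

alphas = np.array([0.2, 0.5, 0.9])  # made-up discriminator outputs
normalized = (alphas - alphas.min()) / (alphas.max() - alphas.min())
print(normalized)  # [0.         0.42857143 1.        ] -- all values mapped into [0, 1]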
Example #6
    def test_majority_voting_discriminators(self,
                                            models,
                                            test_loader,
                                            train=False):
        correct = 0
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        train_or_test = "Train" if train else "Test"
        with torch.no_grad():
            for data, target in test_loader:
                data, target = data.to(device), target.to(device)
                if self.cc.settings["network"]["name"] == "ssgan_perceptron":
                    data = data.view(-1, 784)
                elif self.cc.settings["network"][
                        "name"] == "ssgan_conv_mnist_28x28":
                    data = data.view(-1, 1, 28, 28)
                elif self.cc.settings["network"]["name"] == "ssgan_svhn":
                    data = data.view(-1, 3, 32, 32)
                else:
                    if self.cc.settings["dataloader"][
                            "dataset_name"] == "cifar":
                        data = data.view(-1, 3, 64, 64)
                    else:
                        data = data.view(-1, 1, 64, 64)
                pred_accumulator = []
                for model in models:
                    output = model.classification_layer(model.net(data))
                    output = output.view(-1, 11)
                    pred = output.argmax(dim=1, keepdim=True)
                    pred_accumulator.append(pred.view(-1))
                label_votes = to_pytorch_variable(
                    torch.tensor(list(zip(*pred_accumulator))))
                prediction = to_pytorch_variable(
                    torch.tensor([
                        labels.bincount(minlength=11).argmax()
                        for labels in label_votes
                    ]))
                correct += prediction.eq(
                    target.view_as(prediction)).sum().item()

        num_samples = len(test_loader.dataset)
        accuracy = 100.0 * float(correct / num_samples)
        self._logger.info(
            f"Majority Voting {train_or_test} Accuracy: {correct}/{num_samples} ({accuracy}%)"
        )
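
The majority vote reduces, per sample, to a bincount over the per-model predictions followed by argmax. A self-contained illustration with made-up predictions from three hypothetical models:

import torch

# One row per model, one column per sample.
preds = torch.tensor([[3, 0, 7, 1],
                      [3, 2, 7, 1],
                      [5, 2, 7, 4]])

# Transpose so each row holds one sample's votes, then pick the most frequent
# label; minlength=11 matches the snippet's 10 real classes plus 1 fake class.
majority = torch.stack([votes.bincount(minlength=11).argmax()
                        for votes in preds.t()])
print(majority)  # tensor([3, 2, 7, 1])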
Example #7
    def compute_loss_against(self, opponent, input):
        # Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))
        # Second term of the loss is always zero since real_labels == 1
        batch_size = input.size(0)
        real_labels = to_pytorch_variable(torch.ones(batch_size))
        fake_labels = to_pytorch_variable(torch.zeros(batch_size))

        outputs = self.net(input).view(-1)
        d_loss_real = torch.nn.functional.binary_cross_entropy(outputs, real_labels)

        # Compute BCELoss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = noise(batch_size, self.data_size)
        fake_images = opponent.net(z)
        outputs = self.net(fake_images).view(-1)
        d_loss_fake = torch.nn.functional.binary_cross_entropy(outputs, fake_labels)

        return d_loss_real + d_loss_fake, None
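
The BCE definition quoted in the comment can be checked numerically: for targets of 1 only the -y * log(D(x)) term survives, and for targets of 0 only -(1-y) * log(1 - D(x)). A quick sanity check:

import torch
import torch.nn.functional as F

outputs = torch.tensor([0.9, 0.2, 0.7])  # made-up discriminator outputs

# Real labels (y == 1): BCE reduces to the mean of -log(D(x))
real = torch.ones(3)
assert torch.allclose(F.binary_cross_entropy(outputs, real), (-outputs.log()).mean())

# Fake labels (y == 0): BCE reduces to the mean of -log(1 - D(x))
fake = torch.zeros(3)
assert torch.allclose(F.binary_cross_entropy(outputs, fake), (-(1 - outputs).log()).mean())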
Example #8
    def compute_loss_against(self, opponent, input, training_epoch=None):
        batch_size = input.size(0)

        real_labels = to_pytorch_variable(torch.ones(batch_size))

        z = noise(batch_size, self.data_size)

        fake_images = self.net(z)
        outputs = opponent.net(fake_images).view(-1)

        return self.loss_function(outputs, real_labels), fake_images, None
Example #9
    def compute_loss_against(self, opponent, input, labels=None, alpha=None, beta=None, iter=None, log_class_distribution=False):
        FloatTensor = torch.cuda.FloatTensor if is_cuda_enabled() else torch.FloatTensor
        LongTensor = torch.cuda.LongTensor if is_cuda_enabled() else torch.LongTensor
        batch_size = input.size(0)

        real_labels = to_pytorch_variable(torch.ones(batch_size))  # label all generator images 1 (real)

        z = noise(batch_size, self.data_size)  # dims: batch_size x data_size

        # Random class labels between 0 and num_classes - 1, one per sample, encoded one-hot
        labels = LongTensor(np.random.randint(0, self.num_classes, batch_size))
        labels = labels.view(-1, 1)
        labels_onehot = torch.FloatTensor(batch_size, self.num_classes)
        labels_onehot.zero_()
        labels_onehot.scatter_(1, labels, 1)

        labels = to_pytorch_variable(labels_onehot.type(FloatTensor))

        # Concatenate labels and z before passing them into the generator net
        gen_input = torch.cat((labels, z), -1)
        fake_images = self.net(gen_input)

        # Concatenate fake_images and labels before passing them into the discriminator net
        dis_input = torch.cat((fake_images, labels), -1)
        outputs = opponent.net(dis_input).view(-1)  # view(-1) flattens the tensor

        # The loss compares the discriminator output against 1, since the
        # generator tries to make the discriminator output 1 for fake images
        return self.loss_function(outputs, real_labels), fake_images
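
The one-hot encoding above (also used in Example #17) relies on zero_() followed by scatter_. The pattern in isolation:

import torch

batch_size, num_classes = 4, 10
labels = torch.tensor([2, 0, 9, 2]).view(-1, 1)   # one class index per sample

labels_onehot = torch.FloatTensor(batch_size, num_classes)
labels_onehot.zero_()
labels_onehot.scatter_(1, labels, 1)  # write a 1 at each row's class index

print(labels_onehot[0])  # tensor([0., 0., 1., 0., 0., 0., 0., 0., 0., 0.])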
Example #10
    def compute_loss_against(self, opponent, input, training_epoch=None):

        print('ITERATION: {}'.format(training_epoch))
        # If HeuristicLoss is applied in the Generator, the Discriminator applies BCELoss
        if self.loss_function.__class__.__name__ == 'MustangsLoss':
            if 'HeuristicLoss' in self.loss_function.get_applied_loss_name():
                self.loss_function.set_applied_loss(torch.nn.BCELoss())

        # Compute loss using real images
        # Second term of the loss is always zero since real_labels == 1
        batch_size = input.size(0)

        # Adding noise to prevent Discriminator from getting too strong
        if training_epoch is not None:
            std = max(self.in_std_min,
                      self.in_std - training_epoch * self.in_std_decay_rate)
        else:
            std = self.in_std
        print('Perturbation Std: {}'.format(std))
        input_perturbation = to_pytorch_variable(
            torch.empty(input.shape).normal_(mean=self.in_mean, std=std))
        input = input + input_perturbation

        input = input.view(-1, 1, self.image_length, self.image_width)

        real_labels = to_pytorch_variable(torch.ones(batch_size))
        fake_labels = to_pytorch_variable(torch.zeros(batch_size))

        outputs = self.net(input).view(-1)
        d_loss_real = self.loss_function(outputs, real_labels)

        # Compute loss using fake images
        # First term of the loss is always zero since fake_labels == 0
        z = noise(batch_size, self.data_size)
        fake_images = opponent.net(z)
        outputs = self.net(fake_images).view(-1)
        d_loss_fake = self.loss_function(outputs, fake_labels)

        return d_loss_real + d_loss_fake, None, None
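
The instance-noise schedule decays the perturbation's standard deviation linearly with the training epoch, clamped from below by in_std_min. Sketched in isolation (the concrete values are placeholders, not the project's configuration):

# Hypothetical settings: start at std 1.0, decay by 0.001 per epoch, floor at 0.1.
in_std, in_std_decay_rate, in_std_min = 1.0, 0.001, 0.1

for epoch in (0, 500, 2000):
    std = max(in_std_min, in_std - epoch * in_std_decay_rate)
    print(epoch, std)  # 0 -> 1.0, 500 -> 0.5, 2000 -> 0.1 (clamped at the floor)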
Example #11
    def train(self, n_iterations, stop_event=None):

        loader = self.dataloader.load()

        self.lw_cache.init_session(n_iterations, self.population_gen,
                                   self.population_dis)

        for i in range(n_iterations):
            for j, (input_data, labels) in enumerate(loader):

                batch_size = input_data.size(0)
                input_var = to_pytorch_variable(input_data.view(
                    batch_size, -1))

                if i == 0:
                    self.evaluate_fitness_against_population(
                        self.population_gen, self.population_dis, input_var)

                for attacker, defender in ((self.population_gen,
                                            self.population_dis),
                                           (self.population_dis,
                                            self.population_gen)):
                    new_population = self.tournament_selection(attacker)
                    self.mutate_gaussian(new_population)
                    self.evaluate_fitness_against_population(
                        new_population, defender, input_var)

                    # Replace the worst with the best new
                    attacker.replacement(new_population, self._n_replacements)
                    attacker.sort_population()

                self.lw_cache.append_stepsizes(self.population_gen,
                                               self.population_dis)

                # If n_batches is set to 0, all batches will be used
                if self.dataloader.n_batches != 0 and self.dataloader.n_batches - 1 == j:
                    break

            self.lw_cache.log_best_individuals(i, self.population_gen,
                                               self.population_dis)
            self.log_results(batch_size, i, input_var, loader)

        self.lw_cache.end_session()

        return (self.population_gen.individuals[0].genome,
                self.population_gen.individuals[0].fitness), (
                    self.population_dis.individuals[0].genome,
                    self.population_dis.individuals[0].fitness)
Example #12
    def _update_discriminators(self, population_attacker, population_defender, input_var, loaded, data_iterator, defender_weights):

        batch_size = input_var.size(0)

        for discriminator in population_attacker.individuals:
            weights = [self.get_weight(defender, defender_weights) for defender in population_defender.individuals]
            weights = np.asarray(weights) / np.sum(weights)  # normalize into a probability distribution
            generator = np.random.choice(population_defender.individuals, p=weights)
            optimizer = self._get_optimizer(discriminator)

            # Train the discriminator Diters times
            if self.gen_iterations < 25 or self.gen_iterations % 500 == 0:
                discriminator_iterations = 100
            else:
                discriminator_iterations = DISCRIMINATOR_STEPS

            j = 0
            while j < discriminator_iterations and self.batch_number < len(loaded):
                if j > 0:
                    input_var = to_pytorch_variable(self.dataloader.transpose_data(next(data_iterator)[0]))
                j += 1

                # Train with real data
                discriminator.genome.net.zero_grad()
                error_real = discriminator.genome.net(input_var)
                error_real = error_real.mean(0).view(1)
                error_real.backward(self.real_labels)

                # Train with fake data
                z = noise(batch_size, generator.genome.data_size)
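                # volatile marked Variables as inference-only in pre-0.4 PyTorch; newer code would use torch.no_grad()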
                z.volatile = True
                fake_data = Variable(generator.genome.net(z).data)
                loss = discriminator.genome.net(fake_data).mean(0).view(1)
                loss.backward(self.fake_labels)
                optimizer.step()

                # Clamp parameters to a cube
                for p in discriminator.genome.net.parameters():
                    p.data.clamp_(CLAMP_LOWER, CLAMP_UPPER)

                self.batch_number += 1

            discriminator.optimizer_state = optimizer.state_dict()

        return input_var
Example #13
    def compute_loss_against(self, opponent, input):
        batch_size = input.size(0)

        real_labels = to_pytorch_variable(torch.ones(batch_size))

        z = noise(batch_size, self.data_size)
        fake_images = self.net(z)
        outputs = opponent.net(fake_images).view(-1)

        # Compute BCELoss using D(G(z))
        if self.loss_function.__class__.__name__ == 'SMuGANLoss':
            prob = np.random.uniform()
            if prob < 0.33:
                loss = self.bceloss(outputs, real_labels)
            elif prob < 0.66 :
                loss = self.mseloss(outputs, real_labels)
            else:
                loss = self.heuristicloss(outputs, real_labels)
            return loss, fake_images
        else:
            return self.loss_function(outputs, real_labels), fake_images
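
SMuGANLoss mixes three generator objectives at random: standard BCE, least squares (MSE), and a heuristic loss. The heuristic loss is not defined on this page; assuming the usual non-saturating generator objective -log(D(G(z))), a minimal sketch:

import torch

class HeuristicLoss(torch.nn.Module):
    # Assumed form of the non-saturating generator loss: -mean(log(D(G(z)))).
    # labels is accepted only to match the (outputs, labels) call signature
    # used by the snippets above; it is ignored.
    def forward(self, outputs, labels=None):
        eps = 1e-12  # guard against log(0)
        return -torch.mean(torch.log(outputs + eps))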
Example #14
    def train(self, n_iterations, stop_event=None):

        loader = self.dataloader.load()

        self.lw_cache.init_session(n_iterations, self.population_gen, self.population_dis)

        for i in range(n_iterations):
            for j, (input_data, labels) in enumerate(loader):

                batch_size = input_data.size(0)
                input_var = to_pytorch_variable(input_data.view(batch_size, -1))

                if i == 0 and j == 0:
                    self.evaluate_fitness_against_population(self.population_gen, self.population_dis, input_var)

                new_population_generator = self.tournament_selection(self.population_gen, TYPE_GENERATOR)
                new_population_discriminator = self.tournament_selection(self.population_dis, TYPE_DISCRIMINATOR)
                self.mutate_gaussian(new_population_generator)
                self.mutate_gaussian(new_population_discriminator)

                self.evaluate_fitness_against_population(new_population_generator, new_population_discriminator, input_var)

                self.population_gen.replacement(new_population_generator, self._n_replacements)
                self.population_dis.replacement(new_population_discriminator, self._n_replacements)
                self.population_gen.sort_population()
                self.population_dis.sort_population()

                self.lw_cache.append_stepsizes(self.population_gen, self.population_dis)

                if self.dataloader.n_batches != 0 and self.dataloader.n_batches - 1 == j:
                    break

            self.lw_cache.log_best_individuals(i, self.population_gen, self.population_dis)
            self.log_results(batch_size, i, input_var, loader)

        self.lw_cache.end_session()

        return (self.population_gen.individuals[0].genome, self.population_gen.individuals[0].fitness), (
            self.population_dis.individuals[0].genome, self.population_dis.individuals[0].fitness)
Example #15
    def compute_loss_against(self, opponent, input):
        batch_size = input.size(0)
        sequence_length = input.size(1)
        num_inputs = input.size(2)

        # Defined differently depending on whether entire sequences, rather than individual messages, are evaluated as real or fake.
        real_labels = to_pytorch_variable(torch.ones(batch_size))

        z = noise(batch_size, self.data_size)

        # Repeats the noise to match the shape
        new_z = z.unsqueeze(1).repeat(1, sequence_length, 1)
        fake_sequences = self.net(new_z)

        outputs_intermediate = opponent.net(fake_sequences)

        # Compute BCELoss using D(G(z))
        sm = Softmax(dim=0)  # explicit dim for the 1-D input; the implicit-dim Softmax() is deprecated
        outputs = sm(outputs_intermediate[:, -1, :].contiguous().view(-1))

        return self.loss_function(outputs, real_labels), fake_sequences
Example #16
    def train(self, n_iterations, stop_event=None):

        loaded = self.dataloader.load()

        for iteration in range(n_iterations):
            self._logger.debug('Iteration {} started'.format(iteration))
            start_time = time()

            all_generators = self.neighbourhood.all_generators
            all_discriminators = self.neighbourhood.all_discriminators
            local_generators = self.neighbourhood.local_generators
            local_discriminators = self.neighbourhood.local_discriminators

            new_populations = {}

            self.batch_number = 0
            data_iterator = iter(loaded)
            while self.batch_number < len(loaded):
                input_data = next(data_iterator)[0]
                batch_size = input_data.size(0)
                input_data = to_pytorch_variable(
                    self.dataloader.transpose_data(input_data))

                if iteration == 0 and self.batch_number == 0:
                    self._logger.debug('Evaluating first fitness')
                    self.evaluate_fitness(local_generators, all_discriminators,
                                          input_data)
                    self._logger.debug('Finished evaluating first fitness')

                if self.batch_number == 0 and self._enable_selection:
                    self._logger.debug('Started tournament selection')
                    new_populations[
                        TYPE_GENERATOR] = self.tournament_selection(
                            all_generators, TYPE_GENERATOR)
                    new_populations[
                        TYPE_DISCRIMINATOR] = self.tournament_selection(
                            all_discriminators, TYPE_DISCRIMINATOR)
                    self._logger.debug('Finished tournament selection')

                # Quit if requested
                if stop_event is not None and stop_event.is_set():
                    self._logger.warning('External stop requested.')
                    return self.result()

                attackers = new_populations[
                    TYPE_GENERATOR] if self._enable_selection else local_generators
                defenders = new_populations[
                    TYPE_DISCRIMINATOR] if self._enable_selection else all_discriminators
                input_data = self.step(
                    local_generators, attackers, defenders, input_data, loaded,
                    data_iterator,
                    self.neighbourhood.mixture_weights_discriminators)

                if self._discriminator_skip_each_nth_step == 0 or self.batch_number % (
                        self._discriminator_skip_each_nth_step + 1) == 0:
                    attackers = new_populations[
                        TYPE_DISCRIMINATOR] if self._enable_selection else local_discriminators
                    defenders = new_populations[
                        TYPE_GENERATOR] if self._enable_selection else all_generators

                    input_data = self.step(
                        local_discriminators, attackers, defenders, input_data,
                        loaded, data_iterator,
                        self.neighbourhood.mixture_weights_generators)

                self._logger.info('Batch {}/{} done'.format(
                    self.batch_number, len(loaded)))

                # If n_batches is set to 0, all batches will be used
                if self.is_last_batch(self.batch_number):
                    break

                self.batch_number += 1

            # Mutate mixture weights
            weights_generators = self.neighbourhood.mixture_weights_generators
            weights_discriminators = self.neighbourhood.mixture_weights_discriminators
            generators = new_populations[
                TYPE_GENERATOR] if self._enable_selection else all_generators
            discriminators = new_populations[
                TYPE_DISCRIMINATOR] if self._enable_selection else all_discriminators
            self.mutate_mixture_weights(weights_generators,
                                        weights_discriminators, generators,
                                        discriminators, input_data)
            self.mutate_mixture_weights(weights_discriminators,
                                        weights_generators, discriminators,
                                        generators, input_data)

            # Replace the worst with the best new
            if self._enable_selection:
                self.evaluate_fitness(new_populations[TYPE_GENERATOR],
                                      new_populations[TYPE_DISCRIMINATOR],
                                      input_data)
                self.concurrent_populations.lock()
                local_generators.replacement(new_populations[TYPE_GENERATOR],
                                             self._n_replacements)
                local_generators.sort_population()
                local_discriminators.replacement(
                    new_populations[TYPE_DISCRIMINATOR], self._n_replacements)
                local_discriminators.sort_population()
                self.concurrent_populations.unlock()
            else:
                self.evaluate_fitness(all_generators, all_discriminators,
                                      input_data)

            if self.score_calc is not None:
                self._logger.info('Calculating FID/inception score.')
                self.calculate_score()

            stop_time = time()

            path_real_images, path_fake_images = \
                self.log_results(batch_size, iteration, input_data, loaded,
                                 lr_gen=self.concurrent_populations.generator.individuals[0].learning_rate,
                                 lr_dis=self.concurrent_populations.discriminator.individuals[0].learning_rate,
                                 score=self.score, mixture_gen=self.neighbourhood.mixture_weights_generators,
                                 mixture_dis=self.neighbourhood.mixture_weights_discriminators)

            if self.db_logger.is_enabled:
                self.db_logger.log_results(iteration, self.neighbourhood,
                                           self.concurrent_populations,
                                           self.score, stop_time - start_time,
                                           path_real_images, path_fake_images)

        return self.result()
Example #17
    def compute_loss_against(self, opponent, input, labels=None, alpha=None, beta=None, iter=None, log_class_distribution=False):

        # need to pass in the labels from dataloader too in lipizzaner_gan_trainer.py
        # Compute loss using real images
        # Second term of the loss is always zero since real_labels == 1
        batch_size = input.size(0)

        FloatTensor = torch.cuda.FloatTensor if is_cuda_enabled() else torch.FloatTensor
        LongTensor = torch.cuda.LongTensor if is_cuda_enabled() else torch.LongTensor

        real_labels = torch.Tensor(batch_size)
        real_labels.fill_(0.9)
        real_labels = to_pytorch_variable(real_labels)

        fake_labels = to_pytorch_variable(torch.zeros(batch_size))

        labels = labels.view(-1, 1).cuda() if is_cuda_enabled() else labels.view(-1, 1)
        labels_onehot = torch.FloatTensor(batch_size, self.num_classes)
        labels_onehot.zero_()
        labels_onehot.scatter_(1, labels, 1)

        labels = to_pytorch_variable(labels_onehot.type(FloatTensor))

        instance_noise_std_dev_min = 0.5
        instance_noise_std_dev_max = 5.0
        instance_noise_std_dev = 2.5
        instance_noise_mean = 0


        # Adding instance noise to prevent the Discriminator from getting too strong
        if iter is not None:
            std = max(
                instance_noise_std_dev_min,
                instance_noise_std_dev_max - iter * 0.001,
            )
        else:
            std = instance_noise_std_dev

        input_perturbation = to_pytorch_variable(
            torch.empty(input.shape).normal_(mean=instance_noise_mean, std=std)
        )

        input = input + input_perturbation

        dis_input = torch.cat((input, labels), -1)  # discriminator training data input

        outputs = self.net(dis_input).view(-1)  # pass in training data input and respective labels to discriminator
        d_loss_real = self.loss_function(outputs, real_labels)  # get real image loss of discriminator (output vs. 1)

        # Compute loss using fake images
        # First term of the loss is always zero since fake_labels == 0
        gen_labels = LongTensor(np.random.randint(0, self.num_classes, batch_size))  # random labels for generator input

        z = noise(batch_size, self.data_size)  # noise for generator input

        gen_labels = gen_labels.view(-1, 1)
        labels_onehot = torch.FloatTensor(batch_size, self.num_classes)
        labels_onehot.zero_()
        labels_onehot.scatter_(1, gen_labels, 1)

        gen_labels = to_pytorch_variable(labels_onehot.type(FloatTensor))

        gen_input = torch.cat((gen_labels, z), -1)

        fake_images = opponent.net(gen_input)
        dis_input = torch.cat((fake_images, gen_labels), -1)  # discriminator training data input
        outputs = self.net(dis_input).view(-1)
        d_loss_fake = self.loss_function(outputs, fake_labels)  # get fake image loss of discriminator (output vs. 0)

        return (d_loss_real + d_loss_fake), None
Example #18
    def train(self, n_iterations, stop_event=None):
        loaded = self.dataloader.load()

        alpha = None  #self.neighbourhood.alpha
        beta = None  #self.neighbourhood.beta
        if alpha is not None:
            self._logger.info(f"Alpha is {alpha} and Beta is {beta}")
        else:
            self._logger.debug("Alpha and Beta are not set")

        for iteration in range(n_iterations):
            self._logger.debug("Iteration {} started".format(iteration + 1))
            start_time = time()

            all_generators = self.neighbourhood.all_generators
            all_discriminators = self.neighbourhood.all_discriminators
            local_generators = self.neighbourhood.local_generators
            local_discriminators = self.neighbourhood.local_discriminators

            # Log the name of individuals in entire neighborhood and local individuals for every iteration
            # (to help tracing because individuals from adjacent cells might be from different iterations)
            self._logger.info(
                "Neighborhood located in possition {} of the grid".format(
                    self.neighbourhood.grid_position))
            self._logger.info(
                "Generators in current neighborhood are {}".format([
                    individual.name
                    for individual in all_generators.individuals
                ]))
            self._logger.info(
                "Discriminators in current neighborhood are {}".format([
                    individual.name
                    for individual in all_discriminators.individuals
                ]))
            self._logger.info(
                "Local generators in current neighborhood are {}".format([
                    individual.name
                    for individual in local_generators.individuals
                ]))
            self._logger.info(
                "Local discriminators in current neighborhood are {}".format([
                    individual.name
                    for individual in local_discriminators.individuals
                ]))

            self._logger.info(
                "L2 distance between all generators weights: {}".format(
                    all_generators.net_weights_dist))
            self._logger.info(
                "L2 distance between all discriminators weights: {}".format(
                    all_discriminators.net_weights_dist))

            new_populations = {}

            # Create a random dataset to evaluate fitness in each iteration
            (
                fitness_samples,
                fitness_labels,
            ) = self.generate_random_fitness_samples(self.fitness_sample_size)
            if (self.cc.settings["dataloader"]["dataset_name"] == "celeba" or
                    self.cc.settings["dataloader"]["dataset_name"] == "cifar"
                    or self.cc.settings["network"]["name"]
                    == "ssgan_convolutional_mnist"):
                fitness_samples = to_pytorch_variable(fitness_samples)
                fitness_labels = to_pytorch_variable(fitness_labels)
            elif self.cc.settings["dataloader"][
                    "dataset_name"] == "network_traffic":
                fitness_samples = to_pytorch_variable(
                    generate_random_sequences(self.fitness_sample_size))
            else:
                fitness_samples = to_pytorch_variable(
                    fitness_samples.view(self.fitness_sample_size, -1))
                fitness_labels = to_pytorch_variable(
                    fitness_labels.view(self.fitness_sample_size, -1))

            fitness_labels = torch.squeeze(fitness_labels)

            # Fitness evaluation
            self._logger.debug("Evaluating fitness")
            self.evaluate_fitness(
                all_generators,
                all_discriminators,
                fitness_samples,
                self.fitness_mode,
            )
            self.evaluate_fitness(
                all_discriminators,
                all_generators,
                fitness_samples,
                self.fitness_mode,
                labels=fitness_labels,
                logger=self._logger,
                alpha=alpha,
                beta=beta,
                iter=iteration,
                log_class_distribution=True,
            )
            self._logger.debug("Finished evaluating fitness")

            # Tournament selection
            if self._enable_selection:
                self._logger.debug("Started tournament selection")
                new_populations[TYPE_GENERATOR] = self.tournament_selection(
                    all_generators, TYPE_GENERATOR, is_logging=True)
                new_populations[
                    TYPE_DISCRIMINATOR] = self.tournament_selection(
                        all_discriminators,
                        TYPE_DISCRIMINATOR,
                        is_logging=True)
                self._logger.debug("Finished tournament selection")

            self.batch_number = 0
            data_iterator = iter(loaded)
            while self.batch_number < len(loaded):
                if self.cc.settings["dataloader"][
                        "dataset_name"] == "network_traffic":
                    input_data = to_pytorch_variable(next(data_iterator))
                    batch_size = input_data.size(0)
                else:
                    input_data, labels = next(data_iterator)
                    batch_size = input_data.size(0)
                    input_data = to_pytorch_variable(
                        self.dataloader.transpose_data(input_data))
                    labels = to_pytorch_variable(
                        self.dataloader.transpose_data(labels))
                    labels = torch.squeeze(labels)

                # Quit if requested
                if stop_event is not None and stop_event.is_set():
                    self._logger.warning("External stop requested.")
                    return self.result()

                attackers = new_populations[
                    TYPE_GENERATOR] if self._enable_selection else local_generators
                defenders = new_populations[
                    TYPE_DISCRIMINATOR] if self._enable_selection else all_discriminators
                input_data = self.step(
                    local_generators,
                    attackers,
                    defenders,
                    input_data,
                    self.batch_number,
                    loaded,
                    data_iterator,
                    iter=iteration,
                )

                if (self._discriminator_skip_each_nth_step == 0
                        or self.batch_number %
                    (self._discriminator_skip_each_nth_step + 1) == 0):
                    self._logger.debug("Skipping discriminator step")

                    attackers = new_populations[
                        TYPE_DISCRIMINATOR] if self._enable_selection else local_discriminators
                    defenders = new_populations[
                        TYPE_GENERATOR] if self._enable_selection else all_generators
                    input_data = self.step(
                        local_discriminators,
                        attackers,
                        defenders,
                        input_data,
                        self.batch_number,
                        loaded,
                        data_iterator,
                        labels=labels,
                        alpha=alpha,
                        beta=beta,
                        iter=iteration,
                    )

                self._logger.info("Iteration {}, Batch {}/{}".format(
                    iteration + 1, self.batch_number, len(loaded)))

                # If n_batches is set to 0, all batches will be used
                if self.is_last_batch(self.batch_number):
                    break

                self.batch_number += 1

            # Perform selection first before mutation of mixture_weights
            # Replace the worst with the best new
            if self._enable_selection:
                # Evaluate fitness of new_populations against neighborhood
                self.evaluate_fitness(
                    new_populations[TYPE_GENERATOR],
                    all_discriminators,
                    fitness_samples,
                    self.fitness_mode,
                )
                self.evaluate_fitness(
                    new_populations[TYPE_DISCRIMINATOR],
                    all_generators,
                    fitness_samples,
                    self.fitness_mode,
                    labels=fitness_labels,
                    alpha=alpha,
                    beta=beta,
                    iter=iteration,
                )
                self.concurrent_populations.lock()
                local_generators.replacement(
                    new_populations[TYPE_GENERATOR],
                    self._n_replacements,
                    is_logging=True,
                )
                local_generators.sort_population(is_logging=True)
                local_discriminators.replacement(
                    new_populations[TYPE_DISCRIMINATOR],
                    self._n_replacements,
                    is_logging=True,
                )
                local_discriminators.sort_population(is_logging=True)
                self.concurrent_populations.unlock()

                # Update individuals' iteration and id after replacement and logging to ease tracing
                for i, individual in enumerate(local_generators.individuals):
                    individual.id = "{}/G{}".format(
                        self.neighbourhood.cell_number, i)
                    individual.iteration = iteration + 1
                for i, individual in enumerate(
                        local_discriminators.individuals):
                    individual.id = "{}/D{}".format(
                        self.neighbourhood.cell_number, i)
                    individual.iteration = iteration + 1
            else:
                # Re-evaluate fitness of local_generators and local_discriminators against neighborhood
                self.evaluate_fitness(
                    local_generators,
                    all_discriminators,
                    fitness_samples,
                    self.fitness_mode,
                )
                self.evaluate_fitness(
                    local_discriminators,
                    all_generators,
                    fitness_samples,
                    self.fitness_mode,
                    labels=fitness_labels,
                    alpha=alpha,
                    beta=beta,
                    iter=iteration,
                )

            self.compute_mixture_generative_score(iteration)

            stop_time = time()

            path_real_images, path_fake_images = self.log_results(
                batch_size,
                iteration,
                input_data,
                loaded,
                lr_gen=self.concurrent_populations.generator.individuals[0].
                learning_rate,
                lr_dis=self.concurrent_populations.discriminator.
                individuals[0].learning_rate,
                score=self.score,
                mixture_gen=self.neighbourhood.mixture_weights_generators,
                mixture_dis=None,
            )

            if self.db_logger.is_enabled:
                self.db_logger.log_results(
                    iteration,
                    self.neighbourhood,
                    self.concurrent_populations,
                    self.score,
                    stop_time - start_time,
                    path_real_images,
                    path_fake_images,
                )

            if self.checkpoint_period > 0 and (
                    iteration + 1) % self.checkpoint_period == 0:
                self.save_checkpoint(
                    all_generators.individuals,
                    all_discriminators.individuals,
                    self.neighbourhood.cell_number,
                    self.neighbourhood.grid_position,
                )

        # Evaluate the discriminators when addressing Semi-supervised Learning
        if "ssgan" in self.cc.settings["network"]["name"]:
            discriminator = self.concurrent_populations.discriminator.individuals[
                0].genome
            batch_size = self.dataloader.batch_size
            dataloader_loaded = self.dataloader.load(train=True)
            self.test_accuracy_discriminators(discriminator,
                                              dataloader_loaded,
                                              train=True)

            self.dataloader.batch_size = 100
            dataloader_loaded = self.dataloader.load(train=False)
            self.test_accuracy_discriminators(discriminator,
                                              dataloader_loaded,
                                              train=False)

            discriminators = [
                individual.genome for individual in
                self.neighbourhood.all_discriminators.individuals
            ]

            for model in discriminators:
                dataloader_loaded = self.dataloader.load(train=False)
                self.test_accuracy_discriminators(model,
                                                  dataloader_loaded,
                                                  train=False)

            dataloader_loaded = self.dataloader.load(train=False)
            self.test_majority_voting_discriminators(discriminators,
                                                     dataloader_loaded,
                                                     train=False)
            self.dataloader.batch_size = batch_size

        if self.optimize_weights_at_the_end:
            self.optimize_generator_mixture_weights()

            path_real_images, path_fake_images = self.log_results(
                batch_size,
                iteration + 1,
                input_data,
                loaded,
                lr_gen=self.concurrent_populations.generator.individuals[0].
                learning_rate,
                lr_dis=self.concurrent_populations.discriminator.
                individuals[0].learning_rate,
                score=self.score,
                mixture_gen=self.neighbourhood.mixture_weights_generators,
                mixture_dis=self.neighbourhood.mixture_weights_discriminators,
            )

            if self.db_logger.is_enabled:
                self.db_logger.log_results(
                    iteration + 1,
                    self.neighbourhood,
                    self.concurrent_populations,
                    self.score,
                    stop_time - start_time,
                    path_real_images,
                    path_fake_images,
                )

        return self.result()
Example #19
    def train(self, n_iterations, stop_event=None):

        loaded = self.dataloader.load()

        for iteration in range(n_iterations):
            self._logger.debug('Iteration {} started'.format(iteration + 1))
            start_time = time()

            local_generators = self.neighbourhood.local_generators
            local_discriminators = self.neighbourhood.local_discriminators
            all_generators, all_discriminators = self.neighbourhood.all_disc_gen_local(
            )

            # Local functions

            # Log the name of individuals in entire neighborhood for every iteration
            # (to help tracing because individuals from adjacent cells might be from different iterations)
            self._logger.info(
                'Generators in current neighborhood are {}'.format([
                    individual.name
                    for individual in all_generators.individuals
                ]))
            self._logger.info(
                'Discriminators in current neighborhood are {}'.format([
                    individual.name
                    for individual in all_discriminators.individuals
                ]))

            # TODO: Fixme
            # self._logger.info('L2 distance between all generators weights: {}'.format(all_generators.net_weights_dist))
            # self._logger.info('L2 distance between all discriminators weights: {}'.format(all_discriminators.net_weights_dist))

            new_populations = {}

            # Create a random dataset to evaluate fitness in each iteration
            fitness_samples = self.generate_random_fitness_samples(
                self.fitness_sample_size)
            if self.cc.settings['dataloader']['dataset_name'] == 'celeba' \
                or self.cc.settings['dataloader']['dataset_name'] == 'cifar':
                fitness_samples = to_pytorch_variable(fitness_samples)
            else:
                fitness_samples = to_pytorch_variable(
                    fitness_samples.view(self.fitness_sample_size, -1))

            # Fitness evaluation
            self._logger.debug('Evaluating fitness')
            self.evaluate_fitness(all_generators, all_discriminators,
                                  fitness_samples, self.fitness_mode)
            self.evaluate_fitness(all_discriminators, all_generators,
                                  fitness_samples, self.fitness_mode)
            self._logger.debug('Finished evaluating fitness')

            # Tournament selection
            if self._enable_selection:
                self._logger.debug('Started tournament selection')
                new_populations[TYPE_GENERATOR] = self.tournament_selection(
                    all_generators, TYPE_GENERATOR, is_logging=True)
                new_populations[
                    TYPE_DISCRIMINATOR] = self.tournament_selection(
                        all_discriminators,
                        TYPE_DISCRIMINATOR,
                        is_logging=True)
                self._logger.debug('Finished tournament selection')

            self.batch_number = 0
            data_iterator = iter(loaded)
            while self.batch_number < len(loaded):
                # for i, (input_data, labels) in enumerate(loaded):
                input_data = next(data_iterator)[0]
                batch_size = input_data.size(0)
                input_data = to_pytorch_variable(
                    self.dataloader.transpose_data(input_data))

                # Quit if requested
                if stop_event is not None and stop_event.is_set():
                    self._logger.warning('External stop requested.')
                    return self.result()

                attackers = new_populations[
                    TYPE_GENERATOR] if self._enable_selection else local_generators
                defenders = new_populations[
                    TYPE_DISCRIMINATOR] if self._enable_selection else all_discriminators
                input_data = self.step(local_generators, attackers, defenders,
                                       input_data, self.batch_number, loaded,
                                       data_iterator)

                if self._discriminator_skip_each_nth_step == 0 or self.batch_number % (
                        self._discriminator_skip_each_nth_step + 1) == 0:
                    self._logger.debug('Executing discriminator step')

                    attackers = new_populations[
                        TYPE_DISCRIMINATOR] if self._enable_selection else local_discriminators
                    defenders = new_populations[
                        TYPE_GENERATOR] if self._enable_selection else all_generators
                    input_data = self.step(local_discriminators, attackers,
                                           defenders, input_data,
                                           self.batch_number, loaded,
                                           data_iterator)

                self._logger.info('Iteration {}, Batch {}/{}'.format(
                    iteration + 1, self.batch_number, len(loaded)))

                # If n_batches is set to 0, all batches will be used
                if self.is_last_batch(self.batch_number):
                    break

                self.batch_number += 1

            # Perform selection first before mutation of mixture_weights
            # Replace the worst with the best new
            if self._enable_selection:
                # Evaluate fitness of new_populations against neighborhood
                self.evaluate_fitness(new_populations[TYPE_GENERATOR],
                                      all_discriminators, fitness_samples,
                                      self.fitness_mode)
                self.evaluate_fitness(new_populations[TYPE_DISCRIMINATOR],
                                      all_generators, fitness_samples,
                                      self.fitness_mode)
                self.concurrent_populations.lock()
                local_generators.replacement(new_populations[TYPE_GENERATOR],
                                             self._n_replacements,
                                             is_logging=True)
                local_generators.sort_population(is_logging=True)
                local_discriminators.replacement(
                    new_populations[TYPE_DISCRIMINATOR],
                    self._n_replacements,
                    is_logging=True)
                local_discriminators.sort_population(is_logging=True)
                self.concurrent_populations.unlock()

                # Update individuals' iteration and id after replacement and logging to ease tracing
                for i, individual in enumerate(local_generators.individuals):
                    individual.id = '{}/G{}'.format(
                        self.neighbourhood.cell_number, i)
                    individual.iteration = iteration + 1
                for i, individual in enumerate(
                        local_discriminators.individuals):
                    individual.id = '{}/D{}'.format(
                        self.neighbourhood.cell_number, i)
                    individual.iteration = iteration + 1
            else:
                # Re-evaluate fitness of local_generators and local_discriminators against neighborhood
                self.evaluate_fitness(local_generators, all_discriminators,
                                      fitness_samples, self.fitness_mode)
                self.evaluate_fitness(local_discriminators, all_generators,
                                      fitness_samples, self.fitness_mode)

            # Mutate mixture weights after selection
            self.mutate_mixture_weights_with_score(
                input_data)  # self.score is updated here

            stop_time = time()

            path_real_images, path_fake_images = \
                self.log_results(batch_size, iteration, input_data, loaded,
                                 lr_gen=self.concurrent_populations.generator.individuals[0].learning_rate,
                                 lr_dis=self.concurrent_populations.discriminator.individuals[0].learning_rate,
                                 score=self.score, mixture_gen=self.neighbourhood.mixture_weights_generators,
                                 mixture_dis=None)

            if self.db_logger.is_enabled:
                self.db_logger.log_results(iteration, self.neighbourhood,
                                           self.concurrent_populations,
                                           self.score, stop_time - start_time,
                                           path_real_images, path_fake_images)

        return self.result()
Example #20
    def _update_discriminators(self, population_attacker, population_defender,
                               input_var, loaded, data_iterator):

        batch_size = input_var.size(0)
        # Randomly pick one only, referred from asynchronous_ea_trainer
        generator = random.choice(population_defender.individuals)

        for i, discriminator in enumerate(population_attacker.individuals):
            if i < len(population_attacker.individuals) - 1:
                # https://stackoverflow.com/a/42132767
                # Perform deep copy first instead of directly updating iterator passed in
                data_iterator, curr_iterator = tee(data_iterator)
            else:
                # Directly update the iterator with the last individual only, so that
                # every individual can learn from the full batch
                curr_iterator = data_iterator

            # Use temporary batch variable for each individual
            # so that every individual can learn from the full batch
            curr_batch_number = self.batch_number
            optimizer = self._get_optimizer(discriminator)

            # Train the discriminator Diters times
            if self.gen_iterations < 25 or self.gen_iterations % 500 == 0:
                discriminator_iterations = 100
            else:
                discriminator_iterations = DISCRIMINATOR_STEPS

            j = 0
            while j < discriminator_iterations and curr_batch_number < len(
                    loaded):
                if j > 0:
                    input_var = to_pytorch_variable(
                        self.dataloader.transpose_data(next(curr_iterator)[0]))
                j += 1

                # Train with real data
                discriminator.genome.net.zero_grad()
                error_real = discriminator.genome.net(input_var)
                error_real = error_real.mean(0).view(1)
                error_real.backward(self.real_labels)

                # Train with fake data
                z = noise(batch_size, generator.genome.data_size)
                # `volatile` and `Variable` are deprecated since PyTorch 0.4;
                # run the generator without gradient tracking instead
                with torch.no_grad():
                    fake_data = generator.genome.net(z)
                loss = discriminator.genome.net(fake_data).mean(0).view(1)
                loss.backward(self.fake_labels)
                optimizer.step()

                # Clamp parameters to a cube
                for p in discriminator.genome.net.parameters():
                    p.data.clamp_(CLAMP_LOWER, CLAMP_UPPER)

                curr_batch_number += 1

            discriminator.optimizer_state = optimizer.state_dict()
        # Update the final batch_number to class variable after all individuals are updated
        self.batch_number = curr_batch_number

        return input_var
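
The update above follows the standard WGAN recipe: train the critic more often early in training, clamp its weights after each step, and hand every individual but the last a tee'd copy of the batch iterator so all of them see the same data. A condensed sketch of those three mechanics; the clamp bounds and step count are assumptions standing in for CLAMP_LOWER, CLAMP_UPPER and DISCRIMINATOR_STEPS:

from itertools import tee

CLAMP_LOWER, CLAMP_UPPER = -0.01, 0.01  # assumed WGAN-style bounds
DISCRIMINATOR_STEPS = 5                 # assumed critic steps per generator step

def critic_iterations(gen_iterations):
    # Train the critic heavily at the start and periodically thereafter,
    # mirroring the schedule in the snippet above
    if gen_iterations < 25 or gen_iterations % 500 == 0:
        return 100
    return DISCRIMINATOR_STEPS

def clamp_weights(net):
    # Weight clipping keeps the critic approximately Lipschitz-bounded
    for p in net.parameters():
        p.data.clamp_(CLAMP_LOWER, CLAMP_UPPER)

def fork_iterator(data_iterator, is_last):
    # All but the last individual read from an independent copy, so the
    # shared iterator is only consumed once (see the snippet's tee() usage)
    return (data_iterator, data_iterator) if is_last else tee(data_iterator)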
Ejemplo n.º 21
0
    def __init__(self, generator_population, weights, n_samples, mixture_generator_samples_mode, z=None):
        """
        Creates samples from a mixture of generators, with sample probability defined given a random noise vector
        sampled from the latent space by a weights vector

        :param generator_population: Population of generators that will be used to create the images
        :param weights: Dictionary that maps generator IDs to weights, e.g. {'127.0.0.1:5000': 0.8, '127.0.0.1:5001': 0.2}
        :param n_samples: Number of samples that will be generated
        :param mixture_generator_samples_mode:
        :param z: Noise vector from latent space. If it is not given it generates a new one
        """
        self.n_samples = n_samples
        self.individuals = sorted(generator_population.individuals, key=lambda x: x.source)
        for individual in self.individuals:
            individual.genome.net.eval()
        self.data = []

        self.cc = ConfigurationContainer.instance()

        weights = collections.OrderedDict(sorted(weights.items()))
        weights = {k: v for k, v in weights.items() if any([i for i in self.individuals if i.source == k])}
        weights_np = np.asarray(list(weights.values()))

        if np.sum(weights_np) != 1:
            weights_np = weights_np / np.sum(weights_np)  # Safeguard: renormalize weights that do not sum to 1

        if mixture_generator_samples_mode == 'independent_probability':
            self.gen_indices = np.random.choice(len(self.individuals), n_samples, p=weights_np.tolist())
        elif mixture_generator_samples_mode == 'exact_proportion':
            # No check here that the weights sum to one; a small rounding error
            # is introduced whenever prob * n_samples is not an integer
            self.gen_indices = [
                i for gen_idx, prob in enumerate(weights_np.tolist()) for i in [gen_idx] * math.ceil(n_samples * prob)
            ]
            np.random.shuffle(self.gen_indices)
            self.gen_indices = self.gen_indices[:n_samples]
        else:
            raise NotImplementedError(
                "Invalid argument for mixture_generator_samples_mode: {}".format(mixture_generator_samples_mode)
            )

        # Defaults to 0 when the genome has no num_classes attribute (unconditional model)
        num_classes = getattr(self.individuals[0].genome, 'num_classes', 0)

        if z is None:
            z = noise(n_samples, self.individuals[0].genome.data_size)

            if num_classes != 0 and self.cc.settings["network"]["name"] == 'conditional_four_layer_perceptron':
                FloatTensor = torch.cuda.FloatTensor if is_cuda_enabled() else torch.FloatTensor
                LongTensor = torch.cuda.LongTensor if is_cuda_enabled() else torch.LongTensor
                labels = LongTensor(np.random.randint(0, num_classes, n_samples))  # random labels in [0, num_classes), shape (n_samples,)

                labels = labels.view(-1, 1)
                labels_onehot = torch.FloatTensor(n_samples, num_classes)
                labels_onehot.zero_()
                labels_onehot.scatter_(1, labels, 1)

                input_labels = to_pytorch_variable(labels_onehot.type(FloatTensor))

                self.z = torch.cat((input_labels, z), -1)
            else:
                self.z = z

        else:
            self.z = z

        # HACK: if it's a sequential model, add a time dimension to the noise input.
        # TODO: the sequence length is currently fixed for sequence generation;
        # make it configurable by the user.
        if self.individuals[0].genome.name in ["DiscriminatorSequential", "GeneratorSequential"]:
            sequence_length = 100
            self.z = self.z.unsqueeze(1).repeat(1, sequence_length, 1)
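
The two mixture_generator_samples_mode branches above differ in how sample counts track the weights: 'independent_probability' draws each generator index i.i.d. from the weights, so counts match the proportions only in expectation, while 'exact_proportion' allocates ceil(n_samples * w) slots per generator, shuffles, and trims. A standalone sketch of both branches in plain numpy (names are illustrative; the logic mirrors the snippet):

import math
import numpy as np

def pick_generator_indices(weights, n_samples, mode):
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()  # renormalize, as the constructor above does
    if mode == 'independent_probability':
        # i.i.d. categorical draws; per-generator counts are random
        return np.random.choice(len(w), n_samples, p=w)
    if mode == 'exact_proportion':
        # Deterministic allocation; ceil() introduces a small rounding
        # error that the final trim to n_samples absorbs
        indices = [g for g, p in enumerate(w)
                   for _ in range(math.ceil(n_samples * p))]
        np.random.shuffle(indices)
        return np.asarray(indices[:n_samples])
    raise NotImplementedError(mode)

# pick_generator_indices([0.8, 0.2], 10, 'exact_proportion')
# -> a shuffled array containing eight 0s and two 1s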
Ejemplo n.º 22
0
    def train(self, n_iterations, stop_event=None):

        cc = ConfigurationContainer.instance()
        session = None
        graph_loss = None
        graph_step_size = None

        generator = self.network_factory.create_generator()
        discriminator = self.network_factory.create_discriminator()

        g_optimizer = torch.optim.Adam(generator.net.parameters(), lr=0.0003)
        d_optimizer = torch.optim.Adam(discriminator.net.parameters(),
                                       lr=0.0003)

        if cc.is_losswise_enabled:
            session = losswise.Session(tag=self.__class__.__name__,
                                       max_iter=n_iterations)
            graph_loss = session.graph('loss', kind='min')
            graph_step_size = session.graph('step_size')

        loaded = self.dataloader.load()
        for epoch in range(n_iterations):

            step_sizes_gen = []
            step_sizes_dis = []

            for i, (images, labels) in enumerate(loaded):
                # Store previous parameters for step size computation
                w_gen_previous = generator.parameters
                w_dis_previous = discriminator.parameters

                # Build mini-batch dataset
                batch_size = images.size(0)
                images = to_pytorch_variable(images.view(batch_size, -1))

                # ============= Train the discriminator =============#
                d_loss = discriminator.compute_loss_against(generator,
                                                            images)[0]

                discriminator.net.zero_grad()
                d_loss.backward()
                d_optimizer.step()

                # =============== Train the generator ===============#
                g_loss, fake_images = generator.compute_loss_against(
                    discriminator, images)

                discriminator.net.zero_grad()
                generator.net.zero_grad()
                g_loss.backward()
                g_optimizer.step()

                step_sizes_gen.append(
                    np.linalg.norm(generator.parameters - w_gen_previous))
                step_sizes_dis.append(
                    np.linalg.norm(discriminator.parameters - w_dis_previous))

                if (i + 1) % 300 == 0:
                    # .item() replaces the deprecated .data[0]; len(loaded)
                    # replaces the previously hardcoded batch count
                    self._logger.info(
                        'Epoch [%d/%d], Step[%d/%d], d_loss: %.4f, '
                        'g_loss: %.4f' % (epoch, n_iterations, i + 1,
                                          len(loaded), d_loss.item(),
                                          g_loss.item()))

            if graph_loss is not None:
                graph_loss.append(
                    epoch, {
                        'L(gen(x)) - Backprop': float(g_loss),
                        'L(disc(x)) - Backprop': float(d_loss)
                    })
                graph_step_size.append(
                    epoch, {
                        'avg_step_size(g(x)) - Backprop':
                        np.mean(step_sizes_gen),
                        'avg_step_size(d(x)) - Backprop':
                        np.mean(step_sizes_dis)
                    })

            # Save real images once
            if epoch == 0:
                self.dataloader.save_images(images,
                                            loaded.dataset.train_data.shape,
                                            'real_images.png')

            z = to_pytorch_variable(
                torch.randn(min(batch_size, 100),
                            self.network_factory.gen_input_size))
            generated_output = generator.net(z)
            self.dataloader.save_images(generated_output,
                                        loaded.dataset.train_data.shape,
                                        'fake_images-%d.png' % (epoch + 1))
            self._logger.info(
                'Epoch [%d/%d], d_loss: %.4f, g_loss: %.4f' %
                (epoch, n_iterations, d_loss.item(), g_loss.item()))


        if session is not None:
            session.done()

        return (generator, g_loss), (discriminator, d_loss)
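
The step-size bookkeeping in this last example assumes a parameters property that flattens a network's weights into a single numpy vector, so that np.linalg.norm of the difference between two snapshots measures the magnitude of one optimizer update. A plausible sketch of that flattening, assuming ordinary PyTorch modules (the helper name is illustrative):

import numpy as np
import torch

def flat_parameters(net: torch.nn.Module) -> np.ndarray:
    # Concatenate every parameter tensor into one 1-D numpy vector
    return torch.cat([p.detach().view(-1) for p in net.parameters()]).cpu().numpy()

# Measuring one update's step size, as the loop above does:
#   w_prev = flat_parameters(generator.net)
#   ...backward() + optimizer.step()...
#   step_size = np.linalg.norm(flat_parameters(generator.net) - w_prev)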