Code Example #1
    def validate_step(self,
                      x,
                      y):
        """Perform a validation step on an ensemble of models
        without using bootstrapping weights

        Args:

        x: tf.Tensor
            a batch of validation inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of validation labels shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        # corrupt the inputs with noise
        x = cont_noise(x, self.continuous_noise_std)
        statistics = dict()

        # calculate the prediction error and accuracy of the model
        d = self.forward_model(x, training=False)
        nll = tf.keras.losses.mean_squared_error(y, d)

        # evaluate how accurate the ranking of the model predictions is
        rank_correlation = spearman(y[:, 0], d[:, 0])

        statistics[f'{self.logger_prefix}/validate/nll'] = nll
        statistics[f'{self.logger_prefix}/validate/rank_corr'] = rank_correlation

        return statistics
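
These snippets call helper functions such as cont_noise and spearman that are not shown here. The following is a minimal sketch of plausible implementations, assuming cont_noise adds zero-mean Gaussian noise with the given standard deviation and spearman returns the Spearman rank correlation of two 1-D tensors (ties ignored); the actual helpers in the source code may differ.

    import tensorflow as tf

    def cont_noise(x, noise_std):
        # assumed behavior: corrupt continuous inputs with additive Gaussian noise
        return x + noise_std * tf.random.normal(tf.shape(x))

    def spearman(a, b):
        # assumed behavior: Spearman correlation computed as the Pearson
        # correlation of the ranks (double argsort), ignoring ties
        ra = tf.cast(tf.argsort(tf.argsort(a)), tf.float32)
        rb = tf.cast(tf.argsort(tf.argsort(b)), tf.float32)
        ra -= tf.reduce_mean(ra)
        rb -= tf.reduce_mean(rb)
        return tf.reduce_sum(ra * rb) / (tf.norm(ra) * tf.norm(rb) + 1e-9)
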
Code Example #2
    def validate_step(self, x, y):
        """Perform a validation step on an ensemble of models
        without using bootstrapping weights

        Args:

        x: tf.Tensor
            a batch of validation inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of validation labels shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        # corrupt the inputs with noise
        x0 = cont_noise(x, self.noise_std)

        statistics = dict()

        # calculate the prediction error and accuracy of the model
        d = self.fm.get_distribution(x0, training=False)
        nll = -d.log_prob(y)

        # evaluate how accurate the ranking of the model predictions is
        rank_correlation = spearman(y[:, 0], d.mean()[:, 0])

        statistics[f'validate/nll'] = nll
        statistics[f'validate/rank_corr'] = rank_correlation

        return statistics
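
Code Example #2 scores predictions with -d.log_prob(y), so self.fm.get_distribution is assumed to return a probability distribution over y. A hypothetical Gaussian head that satisfies this interface (the class name and layer sizes are illustrative and not taken from the source):

    import tensorflow as tf
    import tensorflow_probability as tfp

    class GaussianForwardModel(tf.keras.Model):

        def __init__(self, hidden_size=256):
            super().__init__()
            # predict a mean and a log standard deviation for every input
            self.body = tf.keras.Sequential([
                tf.keras.layers.Dense(hidden_size, activation='relu'),
                tf.keras.layers.Dense(2)])

        def get_distribution(self, x, training=False):
            mean, log_std = tf.split(self.body(x, training=training), 2, axis=-1)
            return tfp.distributions.Normal(loc=mean, scale=tf.exp(log_std))

With such a head, d.mean() in the snippet is the predicted score and -d.log_prob(y) is the Gaussian negative log-likelihood.
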
Code Example #3
    def train_step(self,
                   x,
                   y):
        """Perform a training step of gradient descent on an ensemble
        using bootstrap weights for each model in the ensemble

        Args:

        x: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of training labels shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        # corrupt the inputs with noise
        x = cont_noise(x, self.continuous_noise_std)
        statistics = dict()

        with tf.GradientTape() as tape:

            # calculate the prediction error and accuracy of the model
            d = self.forward_model(x, training=True)
            nll = tf.keras.losses.mean_squared_error(y, d)

            # evaluate how accurate the ranking of the model predictions is
            rank_correlation = spearman(y[:, 0], d[:, 0])

            multiplier_loss = 0.0
            last_weight = self.forward_model.trainable_variables[-1]
            if tf.shape(tf.reshape(last_weight, [-1]))[0] == 1:
                statistics[f'{self.logger_prefix}/train/tanh_multiplier'] = \
                    self.forward_model.trainable_variables[-1]

            # build the total loss
            total_loss = tf.reduce_mean(nll) + multiplier_loss

        grads = tape.gradient(total_loss,
                              self.forward_model.trainable_variables)
        self.forward_model_optim.apply_gradients(
            zip(grads, self.forward_model.trainable_variables))

        statistics[f'{self.logger_prefix}/train/nll'] = nll
        statistics[f'{self.logger_prefix}/train/rank_corr'] = rank_correlation

        return statistics
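
The train_step and validate_step methods above are presumably driven by an outer loop over a dataset. A hypothetical driver, where trainer, train_ds, val_ds, and num_epochs are placeholder names used only for illustration:

    import tensorflow as tf

    for epoch in range(num_epochs):
        for x, y in train_ds:
            train_stats = trainer.train_step(x, y)
        for x, y in val_ds:
            val_stats = trainer.validate_step(x, y)
            # each statistic is a per-example tensor, so reduce it
            # (e.g. with tf.reduce_mean) before logging
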
Code Example #4
    def validate_step(self,
                      x,
                      y):
        """Perform a validation step on the loss function
        of a conservative objective model

        Args:

        x: tf.Tensor
            a batch of validation inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of validation labels shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        statistics = dict()
        batch_dim = tf.shape(y)[0]

        # corrupt the inputs with noise
        x = cont_noise(x, self.continuous_noise_std)

        # calculate the prediction error and accuracy of the model
        d = self.forward_model(x, training=False)
        mse = tf.keras.losses.mean_squared_error(y, d)
        statistics[f'validate/mse'] = mse

        # evaluate how accurate the ranking of the model predictions is
        rank_corr = spearman(y[:, 0], d[:, 0])
        statistics[f'validate/rank_corr'] = rank_corr

        # calculate negative samples starting from the dataset
        x_pos = x
        x_pos = tf.where(tf.random.uniform([batch_dim] + [1 for _ in x.shape[1:]])
                         < self.negatives_fraction, x_pos, self.solution[:batch_dim])
        x_neg = self.lookahead(x_pos, self.lookahead_steps, training=False)
        if not self.lookahead_backprop:
            x_neg = tf.stop_gradient(x_neg)

        # calculate the prediction error and accuracy of the model
        d_pos = self.forward_model(
            {"dataset": x, "mix": x_pos, "solution": self.solution[:batch_dim]}
            [self.constraint_type], training=False)
        d_neg = self.forward_model(x_neg, training=False)
        conservatism = d_neg[:, 0] - d_pos[:, 0]
        statistics[f'validate/conservatism'] = conservatism
        return statistics
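
Code Example #4 (and Code Example #9 below) generate negative samples with self.lookahead, which is not shown. A plausible sketch of such a method, assuming it performs a few steps of gradient ascent on the model's prediction with a step-size attribute called lookahead_lr (a hypothetical name, not taken from the source):

    import tensorflow as tf

    def lookahead(self, x, steps, training=False):
        # gradient-ascent sketch: repeatedly move the inputs uphill on the
        # learned objective to produce adversarial negative samples
        for _ in range(steps):
            with tf.GradientTape() as tape:
                tape.watch(x)
                score = self.forward_model(x, training=training)
            x = x + self.lookahead_lr * tape.gradient(score, x)
        return x
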
Code Example #5
    def train_step(self, x, y, b):
        """Perform a training step of gradient descent on an ensemble
        using bootstrap weights for each model in the ensemble

        Args:

        x: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of training labels shaped like [batch_size, 1]
        b: tf.Tensor
            bootstrap indicators shaped like [batch_size, num_oracles]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        statistics = dict()

        for i in range(self.bootstraps):
            model = self.forward_models[i]
            optim = self.forward_model_optims[i]

            # corrupt the inputs with noise
            if self.is_discrete:
                x0 = disc_noise(x, keep=self.keep, temp=self.temp)
            else:
                x0 = cont_noise(x, self.noise_std)

            with tf.GradientTape(persistent=True) as tape:
                # calculate the prediction error and accuracy of the model
                d = model.get_distribution(x0, training=True)
                nll = -d.log_prob(y)
                statistics[f'oracle_{i}/train/nll'] = nll

                # evaluate how accurate the ranking of the model predictions is
                rank_correlation = spearman(y[:, 0], d.mean()[:, 0])
                statistics[f'oracle_{i}/train/rank_corr'] = rank_correlation

                # build the total loss
                total_loss = tf.math.divide_no_nan(
                    tf.reduce_sum(b[:, i] * nll), tf.reduce_sum(b[:, i]))

            var_list = model.trainable_variables
            optim.apply_gradients(
                zip(tape.gradient(total_loss, var_list), var_list))

        return statistics
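
The bootstrap indicators b consumed above are assumed to be per-oracle resampling counts shaped [batch_size, num_oracles]. A minimal sketch of how such a tensor could be built (the helper name is hypothetical):

    import tensorflow as tf

    def make_bootstrap_indicators(batch_size, num_oracles):
        # resample the batch with replacement once per oracle and count how
        # many times each example was drawn for that oracle
        columns = []
        for _ in range(num_oracles):
            idx = tf.random.uniform([batch_size], 0, batch_size, dtype=tf.int32)
            columns.append(tf.math.bincount(idx, minlength=batch_size,
                                            maxlength=batch_size))
        return tf.cast(tf.stack(columns, axis=1), tf.float32)
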
Code Example #6
    def train_step(self, x, y, b):
        """Perform a training step of gradient descent on an ensemble
        using bootstrap weights for each model in the ensemble

        Args:

        x: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of training labels shaped like [batch_size, 1]
        b: tf.Tensor
            bootstrap indicators shaped like [batch_size, num_oracles]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        # corrupt the inputs with noise
        x0 = cont_noise(x, self.noise_std)

        statistics = dict()
        with tf.GradientTape(persistent=True) as tape:

            # calculate the prediction error and accuracy of the model
            d = self.fm.get_distribution(x0, training=True)
            nll = -d.log_prob(y)

            # evaluate how accurate the ranking of the model predictions is
            rank_correlation = spearman(y[:, 0], d.mean()[:, 0])

            # the model loss is the maximum-likelihood objective (negative log-likelihood)
            model_loss = nll

            # build the total loss, weighted by the bootstrap indicators
            denom = tf.reduce_sum(b)
            total_loss = tf.math.divide_no_nan(tf.reduce_sum(b * model_loss),
                                               denom)

        grads = tape.gradient(total_loss, self.fm.trainable_variables)
        self.optim.apply_gradients(zip(grads, self.fm.trainable_variables))

        statistics[f'train/nll'] = nll
        statistics[f'train/rank_corr'] = rank_correlation

        return statistics
Code Example #7
    def validate_step(self, x, y):
        """Perform a validation step on an ensemble of models
        without using bootstrapping weights

        Args:

        x: tf.Tensor
            a batch of validation inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of validation labels shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        statistics = dict()

        for i in range(self.bootstraps):
            model = self.forward_models[i]

            # corrupt the inputs with noise
            if self.is_discrete:
                x0 = disc_noise(x, keep=self.keep, temp=self.temp)
            else:
                x0 = cont_noise(x, self.noise_std)

            # calculate the prediction error and accuracy of the model
            d = model.get_distribution(x0, training=False)
            nll = -d.log_prob(y)
            statistics[f'oracle_{i}/validate/nll'] = nll

            # evaluate how accurate the ranking of the model predictions is
            rank_correlation = spearman(y[:, 0], d.mean()[:, 0])
            statistics[f'oracle_{i}/validate/rank_corr'] = rank_correlation

        return statistics
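
Code Examples #5 and #7 call disc_noise(x, keep=..., temp=...) for discrete inputs, which is not shown. The sketch below assumes it smooths one-hot inputs toward the uniform distribution with weight 1 - keep and then draws a relaxed categorical (Gumbel-softmax) sample at temperature temp; the real implementation may differ.

    import tensorflow as tf

    def disc_noise(x, keep=0.9, temp=5.0):
        # assumed behavior: mix one-hot inputs with uniform noise, then sample
        # a relaxed categorical via the Gumbel-softmax trick
        num_classes = tf.cast(tf.shape(x)[-1], tf.float32)
        probs = keep * x + (1.0 - keep) / num_classes
        logits = tf.math.log(probs + 1e-9)
        uniform = tf.random.uniform(tf.shape(logits), 1e-6, 1.0 - 1e-6)
        gumbel = -tf.math.log(-tf.math.log(uniform))
        return tf.nn.softmax((logits + gumbel) / temp, axis=-1)
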
Code Example #8
    def train_step(self,
                   x,
                   y):
        """Perform a training step of gradient descent on an ensemble
        using bootstrap weights for each model in the ensemble

        Args:

        x: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of training labels shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        # corrupt the inputs with noise
        x = cont_noise(x, self.continuous_noise_std)

        statistics = dict()
        with tf.GradientTape(persistent=True) as tape:

            # calculate the prediction error and accuracy of the model
            d_pos = self.forward_model(x, training=True)
            mse = tf.keras.losses.mean_squared_error(y, d_pos)
            statistics[f'train/mse'] = mse

            # evaluate how accurate the ranking of the model predictions is
            rank_corr = spearman(y[:, 0], d_pos[:, 0])
            statistics[f'train/rank_corr'] = rank_corr

            # calculate negative samples starting from the dataset
            x_neg = self.outer_optimize(
                x, self.beta, self.outer_gradient_steps, training=False)
            x_neg = tf.stop_gradient(x_neg)

            # calculate the prediction error and accuracy of the model
            d_neg = self.forward_model(x_neg, training=False)
            conservatism = d_pos[:, 0] - d_neg[:, 0]
            statistics[f'train/conservatism'] = conservatism

            # build a lagrangian for dual descent
            alpha_loss = -(self.alpha * self.target_conservatism -
                           self.alpha * conservatism)
            statistics[f'train/alpha'] = self.alpha

            # loss that combines maximum likelihood with a constraint
            model_loss = mse - self.alpha * conservatism
            total_loss = tf.reduce_mean(model_loss)
            alpha_loss = tf.reduce_mean(alpha_loss)

        # calculate gradients using the model
        alpha_grads = tape.gradient(alpha_loss, self.log_alpha)
        model_grads = tape.gradient(
            total_loss, self.forward_model.trainable_variables)

        # take gradient steps on the model
        self.alpha_opt.apply_gradients([[alpha_grads, self.log_alpha]])
        self.forward_model_opt.apply_gradients(zip(
            model_grads, self.forward_model.trainable_variables))

        return statistics
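
The dual descent above takes gradients of alpha_loss with respect to self.log_alpha while the loss itself is written in terms of self.alpha, which suggests alpha is the exponential of a trainable log-multiplier. A self-contained sketch of that parameterization (the class name and defaults are illustrative):

    import tensorflow as tf

    class ConservativeTrainerFragment:

        def __init__(self, initial_alpha=1.0, alpha_lr=0.01):
            # keep the multiplier positive by optimizing its logarithm
            self.log_alpha = tf.Variable(tf.math.log(initial_alpha))
            self.alpha_opt = tf.keras.optimizers.Adam(alpha_lr)

        @property
        def alpha(self):
            return tf.exp(self.log_alpha)

With this parameterization, tape.gradient(alpha_loss, self.log_alpha) flows through the exponential, and the apply_gradients call on log_alpha keeps alpha strictly positive throughout training.
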
Code Example #9
    def train_step(self,
                   x,
                   y):
        """Perform a training step of gradient descent on the loss function
        of a conservative objective model

        Args:

        x: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of training labels shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        self.step.assign_add(1)
        statistics = dict()
        batch_dim = tf.shape(y)[0]

        # corrupt the inputs with noise
        x = cont_noise(x, self.continuous_noise_std)

        with tf.GradientTape(persistent=True) as tape:

            # calculate the prediction error and accuracy of the model
            d = self.forward_model(x, training=True)
            mse = tf.keras.losses.mean_squared_error(y, d)
            statistics[f'train/mse'] = mse

            # evaluate how accurate the ranking of the model predictions is
            rank_corr = spearman(y[:, 0], d[:, 0])
            statistics[f'train/rank_corr'] = rank_corr

            # calculate negative samples starting from the dataset
            x_pos = x
            x_pos = tf.where(tf.random.uniform([batch_dim] + [1 for _ in x.shape[1:]])
                             < self.negatives_fraction, x_pos, self.solution[:batch_dim])
            x_neg = self.lookahead(x_pos, self.lookahead_steps, training=False)
            if not self.lookahead_backprop:
                x_neg = tf.stop_gradient(x_neg)

            # calculate the prediction error and accuracy of the model
            d_pos = self.forward_model(
                {"dataset": x, "mix": x_pos, "solution": self.solution[:batch_dim]}
                [self.constraint_type], training=False)
            d_neg = self.forward_model(x_neg, training=False)
            conservatism = d_neg[:, 0] - d_pos[:, 0]
            statistics[f'train/conservatism'] = conservatism

            # build a lagrangian for dual descent
            alpha_loss = (self.alpha * self.target_conservatism -
                          self.alpha * conservatism)
            statistics[f'train/alpha'] = self.alpha

            multiplier_loss = 0.0
            last_weight = self.forward_model.trainable_variables[-1]
            if tf.shape(tf.reshape(last_weight, [-1]))[0] == 1:
                statistics[f'train/tanh_multiplier'] = \
                    self.forward_model.trainable_variables[-1]

            # loss that combines maximum likelihood with a constraint
            model_loss = mse + self.alpha * conservatism + multiplier_loss
            total_loss = tf.reduce_mean(model_loss)
            alpha_loss = tf.reduce_mean(alpha_loss)

        # initialize stateful variables at the first iteration
        if self.particle_loss is None:
            initialization = tf.zeros_like(conservatism)
            self.particle_loss = tf.Variable(initialization)
            self.particle_constraint = tf.Variable(initialization)

        # calculate gradients using the model
        alpha_grads = tape.gradient(alpha_loss, self.log_alpha)
        model_grads = tape.gradient(
            total_loss, self.forward_model.trainable_variables)

        # occasionally take gradient ascent steps on the solution
        if tf.logical_and(
                tf.equal(tf.math.mod(self.step, self.solver_interval), 0),
                tf.math.greater_equal(self.step, self.solver_warmup)):
            with tf.GradientTape() as tape:

                # take gradient steps on the model
                self.alpha_opt.apply_gradients([[alpha_grads, self.log_alpha]])
                self.forward_model_opt.apply_gradients(
                    zip(model_grads, self.forward_model.trainable_variables))

                # calculate the predicted score of the current solution
                current_score_new_model = self.forward_model(
                    self.solution, training=False)[:, 0]

                # look into the future and evaluate future solutions
                future_new_model = self.lookahead(
                    self.solution, self.solver_steps, training=False)
                future_score_new_model = self.forward_model(
                    future_new_model, training=False)[:, 0]

                # evaluate the conservatism of the current solution
                particle_loss = (self.solver_beta * future_score_new_model -
                                 current_score_new_model)
                update = (self.solution - self.solver_lr *
                          tape.gradient(particle_loss, self.solution))

            # once a solution's conservatism passes the threshold, stop optimizing it
            self.solution.assign(tf.where(self.done, self.solution, update))
            self.particle_loss.assign(particle_loss)
            self.particle_constraint.assign(
                future_score_new_model - current_score_new_model)

        else:

            # take gradient steps on the model
            self.alpha_opt.apply_gradients([[alpha_grads, self.log_alpha]])
            self.forward_model_opt.apply_gradients(
                zip(model_grads, self.forward_model.trainable_variables))

        statistics[f'train/done'] = tf.cast(self.done, tf.float32)
        statistics[f'train/particle_loss'] = self.particle_loss
        statistics[f'train/particle_constraint'] = self.particle_constraint

        return statistics
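
Code Example #9 references several stateful attributes (self.step, self.solution, self.done, and the lazily created particle variables). A hypothetical initialization consistent with how they are used above; shapes and dtypes are assumptions:

    import tensorflow as tf

    def init_solver_state(initial_designs):
        # iteration counter used to schedule the periodic solver updates
        step = tf.Variable(0, dtype=tf.int32)
        # candidate designs that the solver updates in place
        solution = tf.Variable(initial_designs)
        # per-design mask marking solutions that should no longer be updated
        done = tf.Variable(tf.zeros(tf.shape(initial_designs)[:1], dtype=tf.bool))
        return step, solution, done

The particle_loss and particle_constraint variables start as None and are created on the first call to train_step, as shown in the snippet.
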
Code Example #10
    def validate_step(self, x_real, y_real):
        """Perform a validation step for a generator and a discriminator
        using a least squares objective function

        Args:

        x_real: tf.Tensor
            a batch of validation inputs shaped like [batch_size, channels]
        y_real: tf.Tensor
            a batch of validation labels shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        statistics = dict()
        batch_dim = tf.shape(y_real)[0]

        # corrupt the inputs with noise
        if self.is_discrete:
            x_real = disc_noise(x_real, keep=self.keep, temp=self.temp)
        else:
            x_real = cont_noise(x_real, self.noise_std)

        # evaluate the discriminator on generated samples
        x_fake = self.generator.sample(y_real, temp=self.temp, training=False)
        p_fake, d_fake, acc_fake = self.discriminator.loss(x_fake,
                                                           y_real,
                                                           tf.zeros(
                                                               [batch_dim, 1]),
                                                           training=False)

        statistics[f'generator/validate/x_fake'] = x_fake
        statistics[f'generator/validate/y_real'] = y_real
        statistics[f'discriminator/validate/p_fake'] = p_fake
        statistics[f'discriminator/validate/d_fake'] = d_fake
        statistics[f'discriminator/validate/acc_fake'] = acc_fake

        x_pair = tf.zeros_like(x_fake)
        p_pair = tf.zeros_like(p_fake)
        d_pair = tf.zeros_like(d_fake)
        acc_pair = tf.zeros_like(acc_fake)

        if self.fake_pair_frac > 0:

            # evaluate the discriminator on fake pairs of real inputs
            x_pair = tf.random.shuffle(x_real)
            p_pair, d_pair, acc_pair = self.discriminator.loss(
                x_pair, y_real, tf.zeros([batch_dim, 1]), training=False)

        statistics[f'generator/validate/x_pair'] = x_pair
        statistics[f'discriminator/validate/p_pair'] = p_pair
        statistics[f'discriminator/validate/d_pair'] = d_pair
        statistics[f'discriminator/validate/acc_pair'] = acc_pair

        x_pool = tf.zeros_like(x_fake)
        p_pool = tf.zeros_like(p_fake)
        d_pool = tf.zeros_like(d_fake)
        acc_pool = tf.zeros_like(acc_fake)

        if self.pool.size > batch_dim and self.pool_frac > 0:

            # evaluate discriminator on samples from a replay buffer
            x_pool, y_pool = self.pool.sample(batch_dim)
            p_pool, d_pool, acc_pool = self.discriminator.loss(
                x_pool, y_pool, tf.zeros([batch_dim, 1]), training=False)

        statistics[f'generator/validate/x_pool'] = x_pool
        statistics[f'discriminator/validate/p_pool'] = p_pool
        statistics[f'discriminator/validate/d_pool'] = d_pool
        statistics[f'discriminator/validate/acc_pool'] = acc_pool

        # evaluate the discriminator on real inputs
        p_real, d_real, acc_real = self.discriminator.loss(x_real,
                                                           y_real,
                                                           tf.ones(
                                                               [batch_dim, 1]),
                                                           training=False)

        statistics[f'generator/validate/x_real'] = x_real
        statistics[f'discriminator/validate/p_real'] = p_real
        statistics[f'discriminator/validate/d_real'] = d_real
        statistics[f'discriminator/validate/acc_real'] = acc_real

        # evaluate a gradient penalty on interpolations
        e = tf.random.uniform([batch_dim] + [1] * (len(x_fake.shape) - 1))
        x_interp = x_real * e + x_fake * (1 - e)
        penalty = self.discriminator.penalty(x_interp, y_real, training=False)

        statistics[f'discriminator/validate/neg_critic_loss'] = -(d_real +
                                                                  d_fake)
        statistics[f'discriminator/validate/penalty'] = penalty

        return statistics
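
Code Examples #10 and #11 read from a replay buffer through self.pool.size, self.pool.sample, and self.pool.insert_many. A minimal in-memory buffer with that interface (an assumption for illustration; the source implementation may store samples differently):

    import tensorflow as tf

    class ReplayPool:

        def __init__(self, capacity=10000):
            self.capacity = capacity
            self.xs, self.ys = [], []

        @property
        def size(self):
            return len(self.xs)

        def insert_many(self, x, y):
            # store individual examples and keep only the most recent `capacity`
            for xi, yi in zip(tf.unstack(x), tf.unstack(y)):
                self.xs.append(xi)
                self.ys.append(yi)
            self.xs = self.xs[-self.capacity:]
            self.ys = self.ys[-self.capacity:]

        def sample(self, n):
            # sample n stored examples uniformly with replacement
            n = tf.convert_to_tensor(n, dtype=tf.int32)
            idx = tf.random.uniform(
                tf.expand_dims(n, 0), 0, self.size, dtype=tf.int32)
            return (tf.gather(tf.stack(self.xs), idx),
                    tf.gather(tf.stack(self.ys), idx))
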
Code Example #11
    def train_step(self, i, x_real, y_real, w):
        """Perform a training step for a generator and a discriminator
        using a least squares objective function

        Args:

        i: tf.Tensor
            the index of the current training iteration, used to decide when
            the generator should be updated
        x_real: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]
        y_real: tf.Tensor
            a batch of training labels shaped like [batch_size, 1]
        w: tf.Tensor
            importance sampling weights shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        statistics = dict()
        batch_dim = tf.shape(y_real)[0]

        # corrupt the inputs with noise
        if self.is_discrete:
            x_real = disc_noise(x_real, keep=self.keep, temp=self.temp)
        else:
            x_real = cont_noise(x_real, self.noise_std)

        with tf.GradientTape() as tape:

            # evaluate the discriminator on generated samples
            x_fake = self.generator.sample(y_real,
                                           temp=self.temp,
                                           training=False)
            p_fake, d_fake, acc_fake = self.discriminator.loss(
                x_fake, y_real, tf.zeros([batch_dim, 1]), training=False)

            statistics[f'generator/train/x_fake'] = x_fake
            statistics[f'generator/train/y_real'] = y_real
            statistics[f'discriminator/train/p_fake'] = p_fake
            statistics[f'discriminator/train/d_fake'] = d_fake
            statistics[f'discriminator/train/acc_fake'] = acc_fake

            # normalize the fake evaluation metrics
            d_fake = d_fake * (1.0 - self.fake_pair_frac - self.pool_frac)

            x_pair = tf.zeros_like(x_fake)
            p_pair = tf.zeros_like(p_fake)
            d_pair = tf.zeros_like(d_fake)
            acc_pair = tf.zeros_like(acc_fake)

            if self.fake_pair_frac > 0:

                # evaluate the discriminator on fake pairs of real inputs
                x_pair = tf.random.shuffle(x_real)
                p_pair, d_pair, acc_pair = self.discriminator.loss(
                    x_pair, y_real, tf.zeros([batch_dim, 1]), training=False)

                # average the metrics between fake samples
                d_fake = d_pair * self.fake_pair_frac + d_fake

            statistics[f'generator/train/x_pair'] = x_pair
            statistics[f'discriminator/train/p_pair'] = p_pair
            statistics[f'discriminator/train/d_pair'] = d_pair
            statistics[f'discriminator/train/acc_pair'] = acc_pair

            x_pool = tf.zeros_like(x_fake)
            p_pool = tf.zeros_like(p_fake)
            d_pool = tf.zeros_like(d_fake)
            acc_pool = tf.zeros_like(acc_fake)

            if self.pool.size > batch_dim and self.pool_frac > 0:

                # evaluate discriminator on samples from a replay buffer
                x_pool, y_pool = self.pool.sample(batch_dim)
                p_pool, d_pool, acc_pool = self.discriminator.loss(
                    x_pool, y_pool, tf.zeros([batch_dim, 1]), training=False)

                # average the metrics between fake samples
                d_fake = d_pool * self.pool_frac + d_fake

            statistics[f'generator/train/x_pool'] = x_pool
            statistics[f'discriminator/train/p_pool'] = p_pool
            statistics[f'discriminator/train/d_pool'] = d_pool
            statistics[f'discriminator/train/acc_pool'] = acc_pool

            if self.pool_save > 0:

                # possibly add more generated samples to the replay pool
                self.pool.insert_many(x_fake[:self.pool_save],
                                      y_real[:self.pool_save])

            # evaluate the discriminator on real inputs
            labels = tf.cast(
                self.flip_frac <= tf.random.uniform([batch_dim, 1]),
                tf.float32)
            p_real, d_real, acc_real = self.discriminator.loss(x_real,
                                                               y_real,
                                                               labels,
                                                               training=True)

            statistics[f'generator/train/x_real'] = x_real
            statistics[f'discriminator/train/p_real'] = p_real
            statistics[f'discriminator/train/d_real'] = d_real
            statistics[f'discriminator/train/acc_real'] = acc_real

            # evaluate a gradient penalty on interpolations
            e = tf.random.uniform([batch_dim] + [1] * (len(x_fake.shape) - 1))
            x_interp = x_real * e + x_fake * (1 - e)
            penalty = self.discriminator.penalty(x_interp,
                                                 y_real,
                                                 training=False)

            statistics[f'discriminator/train/neg_critic_loss'] = -(d_real +
                                                                   d_fake)
            statistics[f'discriminator/train/penalty'] = penalty

            # build the total loss
            total_loss = tf.reduce_mean(
                w * (d_real + d_fake + self.penalty_weight * penalty))

        var_list = self.discriminator.trainable_variables
        self.discriminator_optim.apply_gradients(
            zip(tape.gradient(total_loss, var_list), var_list))

        if tf.equal(tf.math.floormod(i, self.critic_frequency), 0):

            with tf.GradientTape() as tape:

                # evaluate the discriminator on generated samples
                x_fake = self.generator.sample(y_real,
                                               temp=self.temp,
                                               training=True)
                p_fake, d_fake, acc_fake = self.discriminator.loss(
                    x_fake, y_real, tf.ones([batch_dim, 1]), training=False)

                # build the total loss
                total_loss = tf.reduce_mean(w * d_fake)

            var_list = self.generator.trainable_variables
            self.generator_optim.apply_gradients(
                zip(tape.gradient(total_loss, var_list), var_list))

        return statistics
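
A hypothetical outer loop for the GAN trainer above, showing how the iteration index i and the importance weights w might be supplied; gan and train_ds are placeholder names, and the weights here are simply uniform:

    import tensorflow as tf

    for i, (x_real, y_real) in enumerate(train_ds):
        # uniform importance weights shaped [batch_size, 1]; a reweighted
        # sampler could supply non-uniform weights instead
        w = tf.ones_like(y_real)
        stats = gan.train_step(tf.constant(i, tf.int32), x_real, y_real, w)
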