def train(self, dataset):
        """Build the pixel / latent-consistency losses and run the epoch loop.

        The optimised objective is the batch mean of the per-sample pixel MSE
        plus ``self.rho`` times the per-sample MSE between ``self.z`` and
        ``self.z_rec``. Every epoch trains via ``self.process``, writes a
        checkpoint, validates, and stops once ``indicate_early_stopping``
        reports that validation ``loss`` has stopped improving.

        Args:
            dataset: data source handed unchanged to ``self.process``.
        """
        # Every trainable variable in the default graph is optimised.
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Element-wise L1 map; its per-sample sum, averaged over the batch, is
        # reported as the reconstruction loss.
        abs_err = tf.losses.absolute_difference(
            self.x, self.reconstruction, reduction=Reduction.NONE)
        self.losses['L1'] = abs_err
        self.losses['reconstructionLoss'] = tf.reduce_mean(
            tf.reduce_sum(abs_err, axis=[1, 2, 3]))

        # Per-sample pixel MSE (mean over axes 1-3).
        pixel_sq_err = tf.losses.mean_squared_error(
            self.x, self.reconstruction, reduction=Reduction.NONE)
        pixel_mse = tf.reduce_mean(pixel_sq_err, axis=[1, 2, 3])
        self.losses['L2'] = pixel_mse

        # Per-sample latent-consistency MSE (mean over axis 1).
        latent_sq_err = tf.losses.mean_squared_error(
            self.z, self.z_rec, reduction=Reduction.NONE)
        latent_mse = tf.reduce_mean(latent_sq_err, axis=[1])
        self.losses['Rec_z'] = latent_mse

        self.losses['loss'] = tf.reduce_mean(pixel_mse + self.rho * latent_mse)

        optimizer = self.create_optimizer(
            self.losses['loss'],
            var_list=self.variables,
            learningrate=self.config.learningrate,
            beta1=self.config.beta1,
            type=self.config.optimizer,
        )

        # Initialise first; load_checkpoint may then overwrite with saved values.
        tf.global_variables_initializer().run(session=self.sess)

        best_cost, last_improvement = inf, 0
        epochs_done = self.load_checkpoint()

        for epoch in range(epochs_done, self.config.numEpochs):
            # Training pass, then checkpoint under the count of finished epochs.
            self.process(dataset, epoch, Phase.TRAIN, optimizer)
            epochs_done += 1
            self.save(self.checkpointDir, epochs_done)

            # Validation pass drives early stopping.
            val_scalars = self.process(dataset, epoch, Phase.VAL)
            best_cost, last_improvement, stop = indicate_early_stopping(
                val_scalars['loss'], best_cost, last_improvement)
            if not stop:
                continue
            print('Early stopping was triggered due to no improvement over the last 5 epochs')
            break
Beispiel #2
0
    def train(self, dataset):
        """Build the two-branch VAE losses and the anomaly map, then run the epoch loop.

        Two L1 reconstruction maps are built: ``x`` vs ``reconstruction`` and
        ``x_ce`` vs ``reconstruction_ce``. The optimised ``loss`` is the batch
        mean of both per-sample summed L1 errors plus the per-sample KL of
        ``N(z_mu, z_sigma^2)`` against ``N(0, I)``. ``anomaly`` is the VAE L1
        map weighted element-wise by ``|d loss_vae / d x|``. Every epoch
        trains, checkpoints, validates and checks early stopping on the
        validation ``loss``.

        Args:
            dataset: data source handed unchanged to ``self.process``.
        """
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Build losses
        self.losses['L1_vae'] = tf.losses.absolute_difference(self.x, self.reconstruction, reduction=Reduction.NONE)
        self.losses['L1_ce'] = tf.losses.absolute_difference(self.x_ce, self.reconstruction_ce, reduction=Reduction.NONE)
        self.losses['L1'] = 0.5 * (self.losses['L1_vae'] + self.losses['L1_ce'])
        rec_vae = tf.reduce_sum(self.losses['L1_vae'], axis=[1, 2, 3])
        rec_ce = tf.reduce_sum(self.losses['L1_ce'], axis=[1, 2, 3])
        # Closed-form KL(N(z_mu, z_sigma^2) || N(0, I)) per sample, summed over axis 1.
        kl = 0.5 * tf.reduce_sum(tf.square(self.z_mu) + tf.square(self.z_sigma) - tf.log(tf.square(self.z_sigma)) - 1, axis=1)

        self.losses['Rec_ce'] = tf.reduce_mean(rec_ce)
        self.losses['Rec_vae'] = tf.reduce_mean(rec_vae)
        self.losses['reconstructionLoss'] = 0.5 * tf.reduce_mean(rec_vae + rec_ce)
        self.losses['kl'] = tf.reduce_mean(kl)
        self.losses['loss'] = tf.reduce_mean(rec_vae + kl + rec_ce)
        self.losses['loss_vae'] = tf.reduce_mean(rec_vae + kl)

        # tf.gradients returns a list (one entry per `xs`). Take the single
        # gradient tensor *before* tf.abs; applying tf.abs to the list would first
        # pack it into an extra leading dimension that is then indexed away.
        vae_grad = tf.gradients(self.losses['loss_vae'], self.x)[0]
        self.losses['anomaly'] = self.losses['L1_vae'] * tf.abs(vae_grad)

        # Set the optimizer
        optim = self.create_optimizer(self.losses['loss'], var_list=self.variables, learningrate=self.config.learningrate,
                                      beta1=self.config.beta1, type=self.config.optimizer)

        # initialize all variables; load_checkpoint may then restore saved values
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        visualization_keys = ['reconstruction', 'reconstruction_ce', 'anomaly']
        # Go go go!
        for epoch in range(last_epoch, self.config.numEpochs):
            ############
            # TRAINING #
            ############
            self.process(dataset, epoch, Phase.TRAIN, optim, visualization_keys=visualization_keys)

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            ##############
            # VALIDATION #
            ##############
            val_scalars = self.process(dataset, epoch, Phase.VAL, visualization_keys=visualization_keys)

            best_cost, last_improvement, stop = indicate_early_stopping(val_scalars['loss'], best_cost, last_improvement)
            if stop:
                print('Early stopping was triggered due to no improvement over the last 5 epochs')
                break
Beispiel #3
0
    def train(self, dataset):
        """Train the Encoder / Generator / Discriminator model with a WGAN-GP critic.

        Builds the critic loss (Wasserstein estimate plus gradient penalty), the
        KL and L1 reconstruction losses, and three Adam optimizers over the
        'Discriminator', 'Generator' and 'Encoder' variable scopes. For each
        training batch it runs:

        * one update of encoder + generator variables on ``enc_loss``;
        * one generator update on ``gen_loss``;
        * ``d_iters`` critic updates on ``disc_loss``.

        After every epoch a checkpoint is written and the validation split is
        evaluated; early stopping is decided on validation ``reconstructionLoss``.

        Args:
            dataset: object providing ``num_batches(batchsize, set=...)`` and
                ``next_batch(batchsize, set=...)``; only the first element of
                each ``next_batch`` result is used.
        """
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Critic loss: mean score on fake inputs minus mean score on real inputs.
        self.losses['disc_fake'] = disc_fake = tf.reduce_mean(self.d_)
        self.losses['disc_real'] = disc_real = tf.reduce_mean(self.d)
        disc_loss = disc_fake - disc_real

        # Gradient penalty: scale * mean((||dD(x_hat)/dx_hat||_2 - 1)^2), one norm
        # per sample. The gradient is flattened per sample first so the norm spans
        # every non-batch axis; reducing over axis=1 alone is only a per-sample
        # norm for rank-2 inputs (for which this reshape is a no-op).
        grad = tf.gradients(self.d_hat, self.x_hat)[0]
        grad = tf.reshape(grad, [tf.shape(grad)[0], -1])
        slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1))
        gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0)) * self.scale
        self.losses['disc_loss'] = disc_loss = disc_loss + gradient_penalty

        # Closed-form KL(N(z_mu, z_sigma^2) || N(0, I)) per sample, summed over axis 1.
        kl = 0.5 * tf.reduce_sum(tf.square(self.z_mu) + tf.square(self.z_sigma) - tf.log(tf.square(self.z_sigma)) - 1,
                                 axis=1)
        self.losses['kl'] = loss_kl = tf.reduce_mean(kl)

        # 'loss_img' and 'loss_fts' are stored in self.losses but are not part of
        # any loss optimised below and are not fetched by the loops in this method.
        self.losses['loss_img'] = tf.reduce_mean(
            tf.reduce_mean(tf.losses.mean_squared_error(self.x, self.reconstruction, reduction=Reduction.NONE), axis=[1, 2, 3]))
        self.losses['loss_fts'] = tf.reduce_mean(
            tf.reduce_mean(tf.losses.mean_squared_error(self.d_fake_features, self.d_features, reduction=Reduction.NONE), axis=[1, 2, 3]))
        self.losses['L1'] = tf.losses.absolute_difference(self.x, self.reconstruction, reduction=Reduction.NONE)
        self.losses['reconstructionLoss'] = self.losses['loss'] = tf.reduce_mean(tf.reduce_sum(self.losses['L1'], axis=[1, 2, 3]))

        self.losses['gen_loss'] = gen_loss = - disc_fake
        self.losses['enc_loss'] = enc_loss = self.losses['reconstructionLoss'] + self.kl_weight * loss_kl

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            # Set the optimizer (creation order kept: it determines Adam variable names)
            t_vars = tf.trainable_variables()
            dis_vars = [var for var in t_vars if 'Discriminator' in var.name]
            gen_vars = [var for var in t_vars if 'Generator' in var.name]
            enc_vars = [var for var in t_vars if 'Encoder' in var.name]

            optim_dis = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(disc_loss, var_list=dis_vars)
            optim_gen = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(gen_loss, var_list=gen_vars)
            optim_vae = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(enc_loss, var_list=enc_vars + gen_vars)

        # initialize all variables; load_checkpoint may then restore saved values
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        d_iters = 5  # critic updates per training batch

        # Go go go!
        for epoch in range(last_epoch, self.config.numEpochs):
            #################
            # TRAINING WGAN #
            #################
            phase = Phase.TRAIN
            scalars = defaultdict(list)
            visuals = []
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                # Every update on this batch uses exactly the same feed, so build it once.
                feed_dict = {
                    self.x: batch,
                    self.dropout: phase == Phase.TRAIN,
                    self.dropout_rate: self.config.dropout_rate
                }

                # Encoder (+ generator) optimization
                fetches = {
                    'reconstruction': self.reconstruction,
                    'reconstructionLoss': self.losses['reconstructionLoss'],
                    'L1': self.losses['L1'],
                    'enc_loss': self.losses['enc_loss'],
                    'optimizer_e': optim_vae,
                }
                run = self.sess.run(fetches, feed_dict=feed_dict)

                # Generator optimization
                fetches = {
                    'gen_loss': self.losses['gen_loss'],
                    'optimizer_g': optim_gen,
                }
                run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                # Discriminator (critic) optimization; only the last iteration's
                # values survive in `run`.
                fetches = {
                    'disc_loss': self.losses['disc_loss'],
                    'disc_fake': self.losses['disc_fake'],
                    'disc_real': self.losses['disc_real'],
                    'optimizer_d': optim_dis,
                }
                for _ in range(0, d_iters):
                    run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                # Print to console
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}]'
                      f' gen_loss: {run["gen_loss"]:.8f}, disc_loss: {run["disc_loss"]:.8f}, reconstructionLoss: {run["reconstructionLoss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            ##############
            # VALIDATION #
            ##############
            phase = Phase.VAL
            scalars = defaultdict(list)
            visuals = []
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                # Evaluation only: no optimizer ops are fetched here.
                fetches = {
                    'reconstruction': self.reconstruction,
                    'reconstructionLoss': self.losses['reconstructionLoss'],
                    'L1': self.losses['L1'],
                    'enc_loss': self.losses['enc_loss'],
                }

                feed_dict = {
                    self.x: batch,
                    self.dropout: phase == Phase.TRAIN,
                    self.dropout_rate: self.config.dropout_rate
                }
                run = self.sess.run(fetches, feed_dict=feed_dict)
                # Print to console
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] reconstructionLoss: {run["reconstructionLoss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            # NOTE(review): scalars['reconstructionLoss'] is whatever
            # update_log_dicts / log_to_tensorboard left there (a per-batch list
            # unless log_to_tensorboard reduces it) -- confirm it is a scalar.
            best_cost, last_improvement, stop = indicate_early_stopping(scalars['reconstructionLoss'], best_cost, last_improvement)
            if stop:
                print('Early stopping was triggered due to no improvement over the last 5 epochs')
                break
Beispiel #4
0
    def train(self, dataset):
        """Build the w / z / c prior losses and restoration gradient, then train.

        The optimised ``loss`` is the sum of four batch means:

        1. the per-sample summed L1 error between ``self.x`` and ``self.xz_mu``;
        2. a KL between ``q(z|x)`` and the per-cluster priors, weighted by ``self.pc``;
        3. the KL of ``N(w_mu, exp(w_log_sigma))`` against ``N(0, I)``;
        4. a floored KL of ``self.pc`` against a uniform prior over ``self.dim_c``.

        ``grads`` is the gradient of ``loss + restore`` w.r.t. ``self.x``. Each
        epoch trains, checkpoints and validates with early stopping. After the
        loop, ``self.determine_best_lambda(dataset)`` runs when
        ``self.tv_lambda_value == -1`` and ``self.restore_steps > 0``.

        Args:
            dataset: data source passed unchanged to ``self.process`` and
                ``self.determine_best_lambda``.
        """
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Build losses
        # 1. the reconstruction loss (L1 per element, summed per sample over axes 1-3)
        self.losses['L1'] = tf.losses.absolute_difference(
            self.x, self.xz_mu, reduction=Reduction.NONE)
        self.losses['L1_sum'] = tf.reduce_sum(self.losses['L1'],
                                              axis=[1, 2, 3])
        self.losses['reconstructionLoss'] = self.losses[
            'mean_p_loss'] = mean_p_loss = tf.reduce_mean(
                self.losses['L1_sum'])
        self.losses['L2'] = tf.losses.mean_squared_error(
            self.x, self.xz_mu, reduction=Reduction.NONE)
        # NOTE(review): unlike 'L1_sum' (summed over axes 1-3, one value per
        # sample) this sums over every axis including the batch axis, giving a
        # single scalar. It is not part of the optimised loss; confirm intended.
        self.losses['L2_sum'] = tf.reduce_sum(self.losses['L2'])

        # 2. E_c_w[KL(q(z|x)|| p(z|w, c))]
        # Per-cluster Gaussian KL:
        #   KL = 1/2 * (logvar2 - logvar1 + (var1 + (m1-m2)^2)/var2 - 1)
        # z_mu / z_log_sigma get a trailing axis tiled dim_c times (the tile
        # multiples have length 5, so the inputs are rank 4 and the tiled tensors
        # rank 5 with clusters on the last axis). z_wc_log_sigma_inv is treated as
        # log(1/var2): 1/var2 = exp(.) and logvar2 = -(.).
        # matmul(kl, expand_dims(pc, -1)) then weights each cluster's KL by its
        # probability in self.pc, collapsing the cluster axis; the result is
        # summed over axes 1-3 to one value per sample.
        # NOTE(review): this rebinds self.z_mu to the tiled tensor, so any code
        # reading self.z_mu after train() sees the rank-5 version -- confirm intended.
        self.z_mu = tf.tile(tf.expand_dims(self.z_mu, -1),
                            [1, 1, 1, 1, self.dim_c])
        z_logvar = tf.tile(tf.expand_dims(self.z_log_sigma, -1),
                           [1, 1, 1, 1, self.dim_c])
        d_mu_2 = tf.squared_difference(self.z_mu, self.z_wc_mu)
        # (var1 + (m1-m2)^2) / var2, with 1e-6 added to 1/var2
        d_var = (tf.exp(z_logvar) +
                 d_mu_2) * (tf.exp(self.z_wc_log_sigma_inv) + 1e-6)
        # logvar2 - logvar1
        d_logvar = -1 * (self.z_wc_log_sigma_inv + z_logvar)
        kl = (d_var + d_logvar - 1) * 0.5
        con_prior_loss = tf.reduce_sum(
            tf.squeeze(tf.matmul(kl, tf.expand_dims(self.pc, -1)), -1),
            [1, 2, 3])
        self.losses['conditional_prior_loss'] = mean_con_loss = tf.reduce_mean(
            con_prior_loss)

        # 3. KL(q(w|x)|| p(w) ~ N(0, I)), with w_log_sigma used as log-variance
        # KL = 1/2 sum( mu^2 + var - logvar -1 ), summed over axes 1-3
        w_loss = 0.5 * tf.reduce_sum(
            tf.square(self.w_mu) + tf.exp(self.w_log_sigma) -
            self.w_log_sigma - 1, [1, 2, 3])
        self.losses['w_prior_loss'] = mean_w_loss = tf.reduce_mean(w_loss)

        # 4. KL(q(c|z)||p(c)) =  - sum_k q(k) log p(k)/q(k) , k = dim_c
        # With p(k) = 1/K this is sum_k q(k) log(K q(k)), summed over axis 3
        # (the cluster axis of self.pc); 1e-8 guards log(0).
        closs1 = tf.reduce_sum(
            tf.multiply(self.pc, tf.log(self.pc * self.dim_c + 1e-8)), [3])
        # Element-wise floor at self.c_lambda before summing over axes 1 and 2.
        c_lambda = tf.cast(tf.fill(tf.shape(closs1), self.c_lambda),
                           dtype=tf.float32)
        c_loss = tf.maximum(closs1, c_lambda)
        c_loss = tf.reduce_sum(c_loss, [1, 2])
        self.losses['c_prior_loss'] = mean_c_loss = tf.reduce_mean(c_loss)

        self.losses[
            'loss'] = mean_p_loss + mean_con_loss + mean_w_loss + mean_c_loss

        # Restoration terms: tv_lambda-weighted total variation of the residual,
        # and the gradient of (loss + restore) w.r.t. the input image.
        # NOTE(review): the residual uses self.reconstruction while the losses
        # above use self.xz_mu -- confirm these are meant to differ.
        self.losses['restore'] = self.tv_lambda * tf.image.total_variation(
            tf.subtract(self.x, self.reconstruction))
        self.losses['grads'] = tf.gradients(
            self.losses['loss'] + self.losses['restore'], self.x)[0]

        # Set the optimizer
        optim = self.create_optimizer(self.losses['loss'],
                                      var_list=self.variables,
                                      learningrate=self.config.learningrate,
                                      beta1=self.config.beta1,
                                      type=self.config.optimizer)

        # initialize all variables; load_checkpoint may then restore saved values
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        # Go go go!
        for epoch in range(last_epoch, self.config.numEpochs):
            ############
            # TRAINING #
            ############
            self.process(dataset,
                         epoch,
                         Phase.TRAIN,
                         optim,
                         visualization_keys=['reconstruction', 'L1', 'L2'])

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            ##############
            # VALIDATION #
            ##############
            val_scalars = self.process(
                dataset,
                epoch,
                Phase.VAL,
                visualization_keys=['reconstruction', 'L1', 'L2'])

            best_cost, last_improvement, stop = indicate_early_stopping(
                val_scalars['loss'], best_cost, last_improvement)
            if stop:
                print(
                    'Early stopping was triggered due to no improvement over the last 5 epochs'
                )
                break

        # Only search for the TV weight when it was not fixed (-1 sentinel) and
        # restoration steps are enabled.
        if self.tv_lambda_value == -1 and self.restore_steps > 0:
            ##############
            # Determine lambda #
            ##############
            print('Determining best lambda')
            self.determine_best_lambda(dataset)
Beispiel #5
0
    def train(self, dataset):
        """Build the VAE losses and restoration gradient, then run the epoch loop.

        ``loss`` is the batch mean of the per-sample summed L1 error plus the
        per-sample KL of ``N(z_mu, z_sigma^2)`` against ``N(0, I)``. ``grads``
        is the gradient w.r.t. ``self.x`` of ``pixel_loss`` plus the
        ``self.tv_lambda``-weighted total variation of ``x - reconstruction``.
        After the loop, ``self.determine_best_lambda(dataset)`` runs when
        ``self.tv_lambda_value == -1`` and ``self.restore_steps > 0``.

        Args:
            dataset: data source passed unchanged to ``self.process`` and
                ``self.determine_best_lambda``.
        """
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Reconstruction: element-wise L1 map and its per-sample sum.
        abs_err = tf.losses.absolute_difference(
            self.x, self.reconstruction, reduction=Reduction.NONE)
        self.losses['L1'] = abs_err
        per_sample_rec = tf.reduce_sum(abs_err, axis=[1, 2, 3])

        # Closed-form KL(N(z_mu, z_sigma^2) || N(0, I)), summed over axis 1.
        per_sample_kl = 0.5 * tf.reduce_sum(
            tf.square(self.z_mu) + tf.square(self.z_sigma)
            - tf.log(tf.square(self.z_sigma)) - 1,
            axis=1)

        self.losses['pixel_loss'] = per_sample_rec + per_sample_kl
        self.losses['reconstructionLoss'] = tf.reduce_mean(per_sample_rec)
        self.losses['kl'] = tf.reduce_mean(per_sample_kl)
        self.losses['loss'] = tf.reduce_mean(per_sample_rec + per_sample_kl)

        # Restoration objective and its gradient w.r.t. the input image.
        residual = tf.subtract(self.x, self.reconstruction)
        self.losses['restore'] = self.tv_lambda * tf.image.total_variation(residual)
        restoration_objective = self.losses['pixel_loss'] + self.losses['restore']
        self.losses['grads'] = tf.gradients(restoration_objective, self.x)[0]

        optimizer = self.create_optimizer(
            self.losses['loss'],
            var_list=self.variables,
            learningrate=self.config.learningrate,
            beta1=self.config.beta1,
            type=self.config.optimizer,
        )

        # Initialise first; load_checkpoint may then overwrite with saved values.
        tf.global_variables_initializer().run(session=self.sess)

        best_cost, last_improvement = inf, 0
        epochs_done = self.load_checkpoint()

        for epoch in range(epochs_done, self.config.numEpochs):
            # Training pass, then checkpoint under the count of finished epochs.
            self.process(dataset, epoch, Phase.TRAIN, optimizer)
            epochs_done += 1
            self.save(self.checkpointDir, epochs_done)

            # Validation pass drives early stopping.
            val_scalars = self.process(dataset, epoch, Phase.VAL)
            best_cost, last_improvement, stop = indicate_early_stopping(
                val_scalars['loss'], best_cost, last_improvement)
            if not stop:
                continue
            print('Early stopping was triggered due to no improvement over the last 5 epochs')
            break

        # -1 marks "TV weight not fixed"; search only when restoration is enabled.
        should_tune_lambda = self.tv_lambda_value == -1 and self.restore_steps > 0
        if should_tune_lambda:
            print('Determining best lambda')
            self.determine_best_lambda(dataset)
Beispiel #6
0
Datei: AAE.py Projekt: irfixq/AE
    def train(self, dataset):
        """Train the adversarial autoencoder with a WGAN-GP critic on the latent.

        Builds the critic loss (Wasserstein estimate plus gradient penalty on
        ``self.z_hat``), the generator loss (``-disc_fake``), L1/L2
        reconstruction losses, and three Adam optimizers. For each training batch
        it runs:

        * ``d_iters`` autoencoder updates while ``epoch <= 5``, else one;
        * ``d_iters`` critic updates;
        * one generator update (variables in the 'Encoder' scope).

        After every epoch a checkpoint is written and the validation split is
        evaluated; early stopping is decided on validation ``reconstructionLoss``.

        Args:
            dataset: object providing ``num_batches(batchsize, set=...)`` and
                ``next_batch(batchsize, set=...)``; only the first element of
                each ``next_batch`` result is used.
        """
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Build losses
        self.losses['disc_real'] = disc_real = tf.reduce_mean(self.d)
        self.losses['disc_fake'] = disc_fake = tf.reduce_mean(self.d_)
        self.losses['gen_loss'] = gen_loss = -disc_fake
        self.losses['disc_loss_without_grad'] = disc_loss = disc_fake - disc_real

        # Gradient penalty: mean(scale * (||dD(z_hat)/dz_hat||_2 - 1)^2), one norm
        # per sample. The gradient is flattened per sample so the norm spans every
        # non-batch axis; for a rank-2 z_hat the reshape is a no-op and the result
        # is identical to reducing over axis=1.
        grad = tf.gradients(self.d_hat, self.z_hat)[0]
        grad = tf.reshape(grad, [tf.shape(grad)[0], -1])
        slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1))
        gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0) * self.scale)
        self.losses['disc_loss'] = disc_loss = disc_loss + gradient_penalty

        self.losses['L1'] = tf.losses.absolute_difference(self.x, self.reconstruction, reduction=Reduction.NONE)
        self.losses['reconstructionLoss'] = tf.reduce_mean(tf.reduce_sum(self.losses['L1'], axis=[1, 2, 3]))

        self.losses['L2'] = l2 = tf.reduce_mean(tf.losses.mean_squared_error(self.x, self.reconstruction, reduction=Reduction.NONE), axis=[1, 2, 3])

        # The autoencoder is optimised on the batch mean of the per-sample MSE.
        self.losses['loss'] = ae_loss = tf.reduce_mean(l2)

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            # Set the optimizer (creation order kept: it determines Adam variable names)
            t_vars = tf.trainable_variables()
            dis_vars = [var for var in t_vars if 'Discriminator' in var.name]
            gen_vars = [var for var in t_vars if 'Encoder' in var.name]
            # NOTE(review): this also lists the Discriminator variables; presumably
            # ae_loss does not depend on them, so they receive no AE update -- confirm.
            ae_vars = t_vars

            optim_dis = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(disc_loss, var_list=dis_vars)
            optim_gen = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(gen_loss, var_list=gen_vars)
            optim_ae = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(ae_loss, var_list=ae_vars)

        # initialize all variables; load_checkpoint may then restore saved values
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        d_iters = 20  # critic updates per batch (and AE updates during epochs <= 5)

        # Go go go!
        for epoch in range(last_epoch, self.config.numEpochs):
            ############
            # TRAINING #
            ############
            phase = Phase.TRAIN
            scalars = defaultdict(list)
            visuals = []
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                run = {}
                for _ in range(d_iters if epoch <= 5 else 1):
                    # AE optimization
                    fetches = {
                        'reconstruction': self.reconstruction,
                        'L1': self.losses['L1'],
                        'loss': self.losses['loss'],
                        'reconstructionLoss': self.losses['reconstructionLoss'],
                        'optimizer_ae': optim_ae
                    }

                    # get_feed_dict is called per update (not hoisted): it is not
                    # visible here whether it returns the same feed every call.
                    feed_dict = self.get_feed_dict(batch, phase)

                    run = self.sess.run(fetches, feed_dict=feed_dict)

                for _ in range(d_iters):
                    # Discriminator optimization
                    fetches = {
                        'disc_loss': self.losses['disc_loss'],
                        'optimizer_d': optim_dis,
                    }

                    feed_dict = self.get_feed_dict(batch, phase)

                    run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                # Generator optimization
                fetches = {
                    'gen_loss': self.losses['gen_loss'],
                    'optimizer_g': optim_gen,
                }

                feed_dict = self.get_feed_dict(batch, phase)

                run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                # Print to console
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] loss: {run["reconstructionLoss"]:.8f},'
                      f' gen_loss: {run["gen_loss"]:.8f}, disc_loss: {run["disc_loss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            ##############
            # VALIDATION #
            ##############
            scalars = defaultdict(list)
            visuals = []
            phase = Phase.VAL
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                # Evaluate the reconstruction and every registered loss (no optimizers).
                fetches = {
                    'reconstruction': self.reconstruction,
                    **self.losses
                }

                feed_dict = self.get_feed_dict(batch, phase)
                run = self.sess.run(fetches, feed_dict=feed_dict)

                # Print to console
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] loss: {run["loss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            # NOTE(review): scalars['reconstructionLoss'] is whatever
            # update_log_dicts / log_to_tensorboard left there (a per-batch list
            # unless log_to_tensorboard reduces it) -- confirm it is a scalar.
            best_cost, last_improvement, stop = indicate_early_stopping(scalars['reconstructionLoss'], best_cost, last_improvement)
            if stop:
                print('Early stopping was triggered due to no improvement over the last 5 epochs')
                break