Пример #1
0
    def process(self, dataset, epoch, phase: Phase, optim=None):
        """Run every batch of one dataset split through the model and log it.

        Each step evaluates ``self.reconstruction`` plus every tensor in
        ``self.losses``. When ``phase`` is ``Phase.TRAIN``, ``optim`` is added
        to the same ``sess.run`` call and the dropout flag is fed as True.

        Args:
            dataset: batch source providing ``num_batches`` and ``next_batch``;
                ``next_batch`` must return a 3-tuple whose first item is fed to
                ``self.x``.
            epoch: epoch index, used for console output and TensorBoard logging.
            phase: split to iterate; ``phase.value`` is passed as ``set=``.
            optim: op fetched alongside the losses during training.

        Returns:
            The ``defaultdict(list)`` of scalars filled by ``update_log_dicts``.
        """
        split = phase.value
        training = phase == Phase.TRAIN
        scalars, visuals = defaultdict(list), []

        total = dataset.num_batches(self.config.batchsize, set=split)
        for step in range(total):
            batch, _, _ = dataset.next_batch(self.config.batchsize, set=split)

            # All registered losses are fetched; 'loss' must be among them
            # because it is printed below.
            fetches = {'reconstruction': self.reconstruction}
            fetches.update(self.losses)
            if training:
                fetches['optimizer'] = optim

            results = self.sess.run(fetches, feed_dict={
                self.x: batch,
                self.dropout: training,
                self.dropout_rate: self.config.dropout_rate,
            })

            print(f'Epoch ({split}): [{epoch:2d}] [{step:4d}/{total:4d}] loss: {results["loss"]:.8f}')
            summary = trainer_utils.get_summary_dict(batch, results)
            update_log_dicts(*summary, scalars, visuals)

        self.log_to_tensorboard(epoch, scalars, visuals, phase)
        return scalars
Пример #2
0
    def train(self, dataset):
        """Build the VAE-GAN losses and optimizers, then train with validation.

        Losses registered in ``self.losses``:
          * ``disc_fake`` / ``disc_real``: mean critic output on ``self.d_`` /
            ``self.d``;
          * ``disc_loss``: ``disc_fake - disc_real`` plus a gradient penalty on
            ``self.x_hat`` scaled by ``self.scale``;
          * ``kl``: batch mean of the closed-form KL of N(z_mu, z_sigma^2)
            against N(0, 1);
          * ``loss_img`` / ``loss_fts``: mean squared errors on images and on
            discriminator features (not used by any optimizer in this method);
          * ``L1``: element-wise absolute error; ``reconstructionLoss`` (alias
            ``loss``): its per-sample sum averaged over the batch;
          * ``gen_loss``: ``-disc_fake``;
          * ``enc_loss``: ``reconstructionLoss + kl_weight * kl``.

        Each training batch runs one VAE step (``enc_loss`` over Encoder and
        Generator variables), one generator step (``gen_loss``) and ``d_iters``
        critic steps (``disc_loss``), all on the same batch. After every epoch
        the model is checkpointed, the validation split is evaluated and
        ``indicate_early_stopping`` is consulted with the validation
        ``reconstructionLoss`` values.

        Args:
            dataset: provides ``num_batches(batchsize, set=...)`` and
                ``next_batch(batchsize, set=...)``; the latter returns a 3-tuple
                whose first element is fed to ``self.x``.
        """
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Critic (discriminator) losses
        self.losses['disc_fake'] = disc_fake = tf.reduce_mean(self.d_)
        self.losses['disc_real'] = disc_real = tf.reduce_mean(self.d)
        disc_loss = disc_fake - disc_real

        # Gradient penalty: drive the per-sample L2 norm of d(d_hat)/d(x_hat)
        # towards 1. The norm must span every non-batch axis; summing over
        # axis 1 only (MLP-style) yields per-row norms for image tensors and
        # penalizes the wrong quantity. For rank-2 inputs both forms agree.
        # NOTE(review): assumes axis 0 of self.x_hat is the batch axis — confirm.
        grads = tf.gradients(self.d_hat, self.x_hat)[0]
        grad_rank = grads.shape.ndims
        norm_axes = list(range(1, grad_rank)) if grad_rank is not None else 1
        slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=norm_axes))
        gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0)) * self.scale
        self.losses['disc_loss'] = disc_loss = disc_loss + gradient_penalty

        # VAE losses: closed-form KL of N(z_mu, z_sigma^2) against N(0, 1)
        kl = 0.5 * tf.reduce_sum(tf.square(self.z_mu) + tf.square(self.z_sigma) - tf.log(tf.square(self.z_sigma)) - 1,
                                 axis=1)
        self.losses['kl'] = loss_kl = tf.reduce_mean(kl)

        # Auxiliary MSE terms, only logged through self.losses
        self.losses['loss_img'] = tf.reduce_mean(
            tf.reduce_mean(tf.losses.mean_squared_error(self.x, self.reconstruction, reduction=Reduction.NONE), axis=[1, 2, 3]))
        self.losses['loss_fts'] = tf.reduce_mean(
            tf.reduce_mean(tf.losses.mean_squared_error(self.d_fake_features, self.d_features, reduction=Reduction.NONE), axis=[1, 2, 3]))
        # Per-sample summed L1 error, averaged over the batch
        self.losses['L1'] = tf.losses.absolute_difference(self.x, self.reconstruction, reduction=Reduction.NONE)
        self.losses['reconstructionLoss'] = self.losses['loss'] = tf.reduce_mean(tf.reduce_sum(self.losses['L1'], axis=[1, 2, 3]))

        self.losses['gen_loss'] = gen_loss = - disc_fake
        self.losses['enc_loss'] = enc_loss = self.losses['reconstructionLoss'] + self.kl_weight * loss_kl

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            # Set the optimizer; variables are partitioned by name scope
            t_vars = tf.trainable_variables()
            dis_vars = [var for var in t_vars if 'Discriminator' in var.name]
            gen_vars = [var for var in t_vars if 'Generator' in var.name]
            enc_vars = [var for var in t_vars if 'Encoder' in var.name]

            optim_dis = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(disc_loss, var_list=dis_vars)
            optim_gen = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(gen_loss, var_list=gen_vars)
            optim_vae = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(enc_loss, var_list=enc_vars + gen_vars)

        # initialize all variables (a checkpoint, if any, is restored afterwards)
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        # Go go go!
        for epoch in range(last_epoch, self.config.numEpochs):
            #################
            # TRAINING WGAN #
            #################
            phase = Phase.TRAIN
            scalars = defaultdict(list)
            visuals = []
            d_iters = 5  # critic updates per batch
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                # Identical inputs are fed to every update step of this batch
                feed_dict = {
                    self.x: batch,
                    self.dropout: phase == Phase.TRAIN,
                    self.dropout_rate: self.config.dropout_rate
                }

                # Encoder + generator (VAE) optimization
                fetches = {
                    'reconstruction': self.reconstruction,
                    'reconstructionLoss': self.losses['reconstructionLoss'],
                    'L1': self.losses['L1'],
                    'enc_loss': self.losses['enc_loss'],
                    'optimizer_e': optim_vae,
                }
                run = self.sess.run(fetches, feed_dict=feed_dict)

                # Generator optimization
                fetches = {
                    'gen_loss': self.losses['gen_loss'],
                    'optimizer_g': optim_gen,
                }
                run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                for _ in range(0, d_iters):
                    # Discriminator optimization
                    fetches = {
                        'disc_loss': self.losses['disc_loss'],
                        'disc_fake': self.losses['disc_fake'],
                        'disc_real': self.losses['disc_real'],
                        'optimizer_d': optim_dis,
                    }
                    run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                # Print to console
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}]'
                      f' gen_loss: {run["gen_loss"]:.8f}, disc_loss: {run["disc_loss"]:.8f}, reconstructionLoss: {run["reconstructionLoss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            ##############
            # VALIDATION #
            ##############
            phase = Phase.VAL
            scalars = defaultdict(list)
            visuals = []
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                # Evaluate reconstruction losses only (no optimizer step)
                fetches = {
                    'reconstruction': self.reconstruction,
                    'reconstructionLoss': self.losses['reconstructionLoss'],
                    'L1': self.losses['L1'],
                    'enc_loss': self.losses['enc_loss'],
                }

                feed_dict = {
                    self.x: batch,
                    self.dropout: phase == Phase.TRAIN,
                    self.dropout_rate: self.config.dropout_rate
                }
                run = self.sess.run(fetches, feed_dict=feed_dict)
                # Print to console
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] reconstructionLoss: {run["reconstructionLoss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            best_cost, last_improvement, stop = indicate_early_stopping(scalars['reconstructionLoss'], best_cost, last_improvement)
            if stop:
                print('Early stopping was triggered due to no improvement over the last 5 epochs')
                break
Пример #3
0
Файл: AAE.py Проект: irfixq/AE
    def train(self, dataset):
        """Build the AAE losses and optimizers, then train with validation.

        Losses registered in ``self.losses``:
          * ``disc_real`` / ``disc_fake``: mean critic output on ``self.d`` /
            ``self.d_``;
          * ``gen_loss``: ``-disc_fake``;
          * ``disc_loss_without_grad``: ``disc_fake - disc_real``;
          * ``disc_loss``: the above plus a gradient penalty on ``self.z_hat``
            scaled by ``self.scale``;
          * ``L1`` / ``reconstructionLoss``: element-wise absolute error and its
            per-sample sum averaged over the batch;
          * ``L2``: per-sample mean squared error; ``loss``: its batch mean,
            which is the objective minimized by the autoencoder optimizer.

        Per training batch: ``d_iters`` autoencoder steps while ``epoch <= 5``
        (one step afterwards), ``d_iters`` critic steps, then one ``gen_loss``
        step on variables whose names contain ``'Encoder'``. After each epoch
        the model is checkpointed, the validation split is evaluated on all of
        ``self.losses``, and ``indicate_early_stopping`` is consulted with the
        validation ``reconstructionLoss`` values.

        Args:
            dataset: provides ``num_batches(batchsize, set=...)`` and
                ``next_batch(batchsize, set=...)``; the latter returns a 3-tuple
                whose first element is passed to ``self.get_feed_dict``.
        """
        # Determine trainable variables
        self.variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

        # Adversarial losses
        self.losses['disc_real'] = disc_real = tf.reduce_mean(self.d)
        self.losses['disc_fake'] = disc_fake = tf.reduce_mean(self.d_)
        self.losses['gen_loss'] = gen_loss = -disc_fake
        self.losses['disc_loss_without_grad'] = disc_loss = disc_fake - disc_real

        # Gradient penalty: drive the per-sample L2 norm of d(d_hat)/d(z_hat)
        # towards 1. The norm spans every non-batch axis, so it stays correct
        # for spatial latents; for a rank-2 z_hat this equals summing axis 1.
        # NOTE(review): assumes axis 0 of self.z_hat is the batch axis — confirm.
        grads = tf.gradients(self.d_hat, self.z_hat)[0]
        grad_rank = grads.shape.ndims
        norm_axes = list(range(1, grad_rank)) if grad_rank is not None else 1
        slopes = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=norm_axes))
        gradient_penalty = tf.reduce_mean(tf.square(slopes - 1.0) * self.scale)
        self.losses['disc_loss'] = disc_loss = disc_loss + gradient_penalty

        # Reconstruction losses: summed L1 (logged / early stopping) and mean L2 (optimized)
        self.losses['L1'] = tf.losses.absolute_difference(self.x, self.reconstruction, reduction=Reduction.NONE)
        self.losses['reconstructionLoss'] = tf.reduce_mean(tf.reduce_sum(self.losses['L1'], axis=[1, 2, 3]))

        self.losses['L2'] = l2 = tf.reduce_mean(tf.losses.mean_squared_error(self.x, self.reconstruction, reduction=Reduction.NONE), axis=[1, 2, 3])

        self.losses['loss'] = ae_loss = tf.reduce_mean(l2)

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            # Set the optimizer
            t_vars = tf.trainable_variables()
            dis_vars = [var for var in t_vars if 'Discriminator' in var.name]
            # The encoder plays the generator role: its variables are trained on gen_loss
            gen_vars = [var for var in t_vars if 'Encoder' in var.name]
            # NOTE(review): includes Discriminator variables; ae_loss presumably
            # does not depend on them, so they receive no gradient from it.
            ae_vars = t_vars

            optim_dis = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(disc_loss, var_list=dis_vars)
            optim_gen = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(gen_loss, var_list=gen_vars)
            optim_ae = tf.train.AdamOptimizer(learning_rate=self.config.learningrate, beta1=0.5, beta2=0.9).minimize(ae_loss, var_list=ae_vars)

        # initialize all variables (a checkpoint, if any, is restored afterwards)
        tf.global_variables_initializer().run(session=self.sess)

        best_cost = inf
        last_improvement = 0
        last_epoch = self.load_checkpoint()

        # Go go go!
        for epoch in range(last_epoch, self.config.numEpochs):
            ############
            # TRAINING #
            ############
            phase = Phase.TRAIN
            scalars = defaultdict(list)
            visuals = []
            d_iters = 20
            # AE warm-up: d_iters AE steps per batch for epochs 0..5, then one step
            ae_iters = d_iters if epoch <= 5 else 1
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                # get_feed_dict is called per step, as it may not be deterministic.
                run = {}
                for _ in range(ae_iters):
                    # AE optimization (only the last step's results are kept)
                    fetches = {
                        'reconstruction': self.reconstruction,
                        'L1': self.losses['L1'],
                        'loss': self.losses['loss'],
                        'reconstructionLoss': self.losses['reconstructionLoss'],
                        'optimizer_ae': optim_ae
                    }

                    feed_dict = self.get_feed_dict(batch, phase)

                    run = self.sess.run(fetches, feed_dict=feed_dict)

                for _ in range(d_iters):
                    # Discriminator optimization
                    fetches = {
                        'disc_loss': self.losses['disc_loss'],
                        'optimizer_d': optim_dis,
                    }

                    feed_dict = self.get_feed_dict(batch, phase)

                    run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                # Generator (encoder) optimization
                fetches = {
                    'gen_loss': self.losses['gen_loss'],
                    'optimizer_g': optim_gen,
                }

                feed_dict = self.get_feed_dict(batch, phase)

                run = {**run, **self.sess.run(fetches, feed_dict=feed_dict)}

                # Print to console; 'loss' is the optimized L2 objective, matching
                # what the validation phase prints under the same label.
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] loss: {run["loss"]:.8f},'
                      f' reconstructionLoss: {run["reconstructionLoss"]:.8f},'
                      f' gen_loss: {run["gen_loss"]:.8f}, disc_loss: {run["disc_loss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            # Increment last_epoch counter and save model
            last_epoch += 1
            self.save(self.checkpointDir, last_epoch)

            ##############
            # VALIDATION #
            ##############
            scalars = defaultdict(list)
            visuals = []
            phase = Phase.VAL
            num_batches = dataset.num_batches(self.config.batchsize, set=phase.value)
            for idx in range(0, num_batches):
                batch, _, _ = dataset.next_batch(self.config.batchsize, set=phase.value)

                # Evaluate every registered loss (no optimizer step)
                fetches = {
                    'reconstruction': self.reconstruction,
                    **self.losses
                }

                feed_dict = self.get_feed_dict(batch, phase)
                run = self.sess.run(fetches, feed_dict=feed_dict)

                # Print to console
                print(f'Epoch ({phase.value}): [{epoch:2d}] [{idx:4d}/{num_batches:4d}] loss: {run["loss"]:.8f}')
                update_log_dicts(*trainer_utils.get_summary_dict(batch, run), scalars, visuals)

            self.log_to_tensorboard(epoch, scalars, visuals, phase)

            best_cost, last_improvement, stop = indicate_early_stopping(scalars['reconstructionLoss'], best_cost, last_improvement)
            if stop:
                print('Early stopping was triggered due to no improvement over the last 5 epochs')
                break