Example #1
    def train(self, train_dataset, valid_dataset, steps, evaluate_every=1000, save_best_only=False):
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        self.now = time.perf_counter()

        for lr, hr in train_dataset.take(steps - ckpt.step.numpy()):
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()

            loss = self.train_step(lr, hr)
            loss_mean(loss)

            print("Currently in the train step ",step)

            if step % evaluate_every == 0:
                loss_value = loss_mean.result()
                loss_mean.reset_states()

                # Compute PSNR on validation dataset
                psnr_value = self.evaluate(valid_dataset)

                duration = time.perf_counter() - self.now
                print(f'{step}/{steps}: loss = {loss_value.numpy():.3f}, PSNR = {psnr_value.numpy():.3f} ({duration:.2f}s)')

                if save_best_only and psnr_value <= ckpt.psnr:
                    self.now = time.perf_counter()
                    continue

                ckpt.psnr = psnr_value
                ckpt_mgr.save()

                self.now = time.perf_counter()
Example #2
    def train(self, train_dataset, steps=200000):
        pls_metric = Mean()
        dls_metric = Mean()
        step = 0

        for lr, hr in train_dataset.take(steps):
            step += 1

            pl, dl = self.train_step(lr, hr)
            pls_metric(pl)
            dls_metric(dl)

            # if step % 1 == 0:
            if step % 50 == 0:
                print(
                    f'{step}/{steps}, perceptual loss = {pls_metric.result():.4f}, discriminator loss = {dls_metric.result():.4f}'
                )

                log_metric("GAN perceptual loss",
                           float(f'{pls_metric.result():.4f}'))
                log_metric("GAN discriminator loss",
                           float(f'{dls_metric.result():.4f}'))

                pls_metric.reset_states()
                dls_metric.reset_states()
Example #3
def pre_train(generator, train_dataset, valid_dataset, steps, evaluate_every=1,lr_rate=1e-4):
    loss_mean = Mean()
    pre_train_loss = MeanSquaredError()
    pre_train_optimizer = Adam(lr_rate)

    now = time.perf_counter()

    step = 0
    for lr, hr in train_dataset.take(steps):
        step = step+1

        with tf.GradientTape() as tape:
            lr = tf.cast(lr, tf.float32)
            hr = tf.cast(hr, tf.float32)

            sr = generator(lr, training=True)
            loss_value = pre_train_loss(hr, sr)

        gradients = tape.gradient(loss_value, generator.trainable_variables)
        pre_train_optimizer.apply_gradients(zip(gradients, generator.trainable_variables))
        loss_mean(loss_value)

        if step % evaluate_every == 0:
            loss_value = loss_mean.result()
            loss_mean.reset_states()

            psnr_value = evaluate(generator, valid_dataset)

            duration = time.perf_counter() - now
            print(
                f'{step}/{steps}: loss = {loss_value.numpy():.3f}, PSNR = {psnr_value.numpy():.3f} ({duration:.2f}s)')

            now = time.perf_counter()
Example #4
    def train(self, train_dataset, valid_dataset, save_best_only=False):
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        self.now = time.perf_counter()

        for lr, hr in train_dataset.take(self.args.num_iter -
                                         ckpt.step.numpy()):
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()

            loss = self.train_step(lr, hr)
            loss_mean(loss)

            loss_value = loss_mean.result()
            loss_mean.reset_states()

            lr_value = ckpt.optimizer._decayed_lr('float32').numpy()

            duration = time.perf_counter() - self.now
            self.now = time.perf_counter()

            if step % self.args.log_freq == 0:
                tf.summary.scalar('loss', loss_value, step=step)
                tf.summary.scalar('lr', lr_value, step=step)

            if step % self.args.print_freq == 0:
                print(
                    f'{step}/{self.args.num_iter}: loss = {loss_value.numpy():.3f} , lr = {lr_value:.6f} ({duration:.2f}s)'
                )

            if step % self.args.valid_freq == 0:
                psnr_value = self.evaluate(valid_dataset)
                ckpt.psnr = psnr_value
                tf.summary.scalar('psnr', psnr_value, step=step)

                print(
                    f'{step}/{self.args.num_iter}: loss = {loss_value.numpy():.3f}, lr = {lr_value:.6f}, PSNR = {psnr_value.numpy():.3f}'
                )

            if step % self.args.save_freq == 0:
                # save weights only
                save_path = self.ckpt_path + '/weights-' + str(step) + '.h5'
                self.checkpoint.model.save_weights(filepath=save_path,
                                                   save_format='h5')

                # save ckpt (weights + other train status)
                ckpt_mgr.save(checkpoint_number=step)
Example #5
class Leaner:
    def __init__(self, config: MuZeroConfig, storage: SharedStorage,
                 replay_buffer: ReplayBuffer):
        self.config = config
        self.storage = storage
        self.replay_buffer = replay_buffer
        self.summary = create_summary(name="leaner")
        self.metrics_loss = Mean(f'leaner-loss', dtype=tf.float32)
        self.network = Network(self.config)
        self.lr_schedule = ExponentialDecay(
            initial_learning_rate=self.config.lr_init,
            decay_steps=self.config.lr_decay_steps,
            decay_rate=self.config.lr_decay_rate)
        self.optimizer = Adam(learning_rate=self.lr_schedule)

    def start(self):
        while self.network.training_steps() < self.config.training_steps:
            if ray.get(self.replay_buffer.size.remote()) > 0:

                self.train()

                if self.network.training_steps(
                ) % self.config.checkpoint_interval == 0:
                    weights = self.network.get_weights()
                    self.storage.update_network.remote(weights)

                if self.network.training_steps(
                ) % self.config.save_interval == 0:
                    self.network.save()

        print("Finished")

    def train(self):
        batch = ray.get(self.replay_buffer.sample_batch.remote())

        with tf.GradientTape() as tape:
            loss = self.network.loss_function(batch)

        grads = tape.gradient(loss, self.network.get_variables())
        self.optimizer.apply_gradients(zip(grads,
                                           self.network.get_variables()))

        self.metrics_loss(loss)
        with self.summary.as_default():
            tf.summary.scalar(f'loss', self.metrics_loss.result(),
                              self.network.training_steps())
        self.metrics_loss.reset_states()

        self.network.update_training_steps()
Example #6
    def train_gan(self, train_ds, epochs, print_every, save_every,
                  log_filename, model_save_name):
        pls_metric = Mean()
        dls_metric = Mean()

        log_file = open(os.path.join(LOG_DIR, '{}.txt'.format(log_filename)),
                        'w+')
        log_file.close()

        print('----- Start training -----')
        epoch = 0
        for lr, hr in train_ds.take(epochs):
            epoch += 1
            step_time = time.time()

            generator_loss, discriminator_loss = self.train_step(lr, hr)

            # Apply metrics
            pls_metric(generator_loss)
            dls_metric(discriminator_loss)

            # Update log every 100 epochs
            if epoch == 1 or epoch % print_every == 0:
                print(
                    'Epoch {}/{}, time: {:.3f}s, generator loss = {:.4f}, discriminator loss = {:.4f}'
                    .format(epoch, epochs,
                            time.time() - step_time, pls_metric.result(),
                            dls_metric.result()))

                log_file = open(
                    os.path.join(LOG_DIR, '{}.txt'.format(log_filename)), 'a')
                log_file.write(
                    'Epoch {}/{}, time: {:.3f}s, generator loss = {:.4f}, discriminator loss = {:.4f}\n'
                    .format(epoch, epochs,
                            time.time() - step_time, pls_metric.result(),
                            dls_metric.result()))
                log_file.close()

                pls_metric.reset_states()
                dls_metric.reset_states()

            # Save model every 500 epochs
            if epoch % save_every == 0:
                generator.save(model_save_dir +
                               '/gen_{}_{}.h5'.format(model_save_name, epoch))
                discriminator.save(
                    model_save_dir +
                    '/dis_{}_{}.h5'.format(model_save_name, epoch))
Example #7
    def train(self,
              train_dataset,
              valid_dataset,
              steps,
              evaluate_every=1000,
              save_best_only=False):
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        self.now = time.perf_counter()

        for lr, hr in train_dataset.take(steps - ckpt.step.numpy(
        )):  # for low_resolution+high_resolution image pair in dataset
            t_start = time.time()
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()
            loss = self.train_step(lr, hr)
            loss_mean(loss)
            t_end = time.time()
            print("epoch:%3d step:%2d loss:%.5f time:%.3f" %
                  (step / 50, step % 50, loss, t_end - t_start))

            # evaluate
            if step % evaluate_every == 0:
                loss_value = loss_mean.result()
                loss_mean.reset_states()

                # Compute PSNR on validation dataset
                psnr_value = self.evaluate(valid_dataset)

                duration = time.perf_counter() - self.now
                print(
                    f'{step}/{steps}: loss = {loss_value.numpy():.3f}, PSNR = {psnr_value.numpy():.3f} ({duration:.2f}s)'
                )

                if save_best_only and psnr_value <= ckpt.psnr:  # if no PSNR improvement
                    self.now = time.perf_counter()
                    # skip saving checkpoint
                    continue

                ckpt.psnr = psnr_value
                ckpt_mgr.save()
                print("checkpoint saved!")

                self.now = time.perf_counter()
Example #8
def train(
    model: CAE,
    dataset,
    output_path: str,
    epochs: Optional[int],
    image_width: int,
    image_height: int,
    log_freq: int,
    save_freq: int,
) -> None:
    @tf.function
    def train_step(image):
        with tf.GradientTape() as tape:
            pred_image = model(image)
            loss = MSE(image, pred_image)
        # Compute and apply gradients outside the tape context
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss(loss)

    optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
    ckpt = tf.train.Checkpoint(optimizer=optimizer, transformer=model)
    manager = tf.train.CheckpointManager(ckpt, output_path, max_to_keep=1)
    train_loss = Mean(name='train_loss')

    epochs = epochs or len(dataset)
    section_size = 128
    for step, train_image in enumerate(dataset):
        train_image = train_image.numpy()
        for c in range(image_height // section_size):
            for j in range(image_width // section_size):
                cc = section_size * c
                jj = section_size * j
                train_image_batch = train_image[:, cc:cc + section_size,
                                                jj:jj + section_size, :]
                train_image_tensor = tf.convert_to_tensor(train_image_batch)
                train_step(train_image_tensor)

        if step % log_freq == 0:
            print(f'Step {step}/{epochs}, ' f'Loss: {train_loss.result()}, ')
        if step % save_freq == 0 or step == epochs - 1:
            print(f'Saved checkpoint: {manager.save()}')
            train_loss.reset_states()

            if epochs and step == epochs:
                break
Example #9
    def train(self,
              train_dataset,
              valid_dataset,
              steps,
              evaluate_every=1000,
              save_best_only=False):
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        vis_list = []

        for lr, hr in train_dataset.take(steps - ckpt.step.numpy()):
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()

            loss = self.train_step(lr, hr)
            loss_mean(loss)

            if step % evaluate_every == 0:
                loss_value = loss_mean.result()
                loss_mean.reset_states()

                # Compute PSNR on validation dataset
                psnr_value = self.evaluate(valid_dataset)

                print(
                    f'{step}/{steps}: loss = {loss_value.numpy():.3f}, PSNR = {psnr_value.numpy():.3f}'
                )

                vis_list.append((step, loss_value, psnr_value))

                if save_best_only and psnr_value <= ckpt.psnr:
                    # skip saving checkpoint, no PSNR improvement
                    continue

                ckpt.psnr = psnr_value
                ckpt_mgr.save()

        # saving progress data to make graphs
        csv = open('./visLoss.csv', 'w')
        csv.write('step, loss, psnr\n')
        for vals in vis_list:
            csv.write('{},{},{}\n'.format(vals[0], vals[1], vals[2]))
        csv.close()
Example #10
    def train(self, train_dataset, steps=200000):
        pls_metric = Mean()
        dls_metric = Mean()
        step = 0

        for lr, hr in train_dataset.take(steps):
            step += 1

            pl, dl = self.train_step(lr, hr)
            print("Currently in the sr-train step ",step)
            pls_metric(pl)
            dls_metric(dl)

            if step % 50 == 0:
                print(f'{step}/{steps}, perceptual loss = {pls_metric.result():.4f}, discriminator loss = {dls_metric.result():.4f}')
                pls_metric.reset_states()
                dls_metric.reset_states()
Example #11
    def train(self,
              train_dataset,
              valid_dataset,
              steps,
              evaluate_every=1000,
              save_best_only=False):
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        self.now = time.perf_counter()

        for lr, hr in train_dataset.take(steps - ckpt.step.numpy()):
            #print('check1..', steps, ckpt.step.numpy())
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()

            loss = self.train_step(lr, hr)
            loss_mean(loss)

            if step % evaluate_every == 0:
                loss_value = loss_mean.result()
                loss_mean.reset_states()

                # Compute PSNR on validation dataset
                psnr_value = self.evaluate(valid_dataset)

                duration = time.perf_counter() - self.now
                print(
                    f'{step}/{steps}: loss = {loss_value.numpy():.3f}, PSNR = {psnr_value.numpy():.3f} ({duration:.2f}s)'
                )
                #########
                self.resolve_and_plot('demo/img_0', step)
                #########

                if save_best_only and psnr_value <= ckpt.psnr:
                    self.now = time.perf_counter()
                    # skip saving checkpoint, no PSNR improvement
                    continue

                ckpt.psnr = psnr_value
                ckpt_mgr.save()

                self.now = time.perf_counter()
Example #12
class MeanBasedMetric(Metric):
    def __init__(self, name, dtype):
        super().__init__(name, dtype=dtype)
        self._mean = Mean(dtype=dtype)

    @abstractmethod
    def _objective_function(self, y_true, y_pred):
        pass

    def update_state(self, y_true, y_pred, sample_weight=None):
        values = self._objective_function(y_true, y_pred)
        self._mean.update_state(values=values, sample_weight=sample_weight)

    def result(self):
        return self._mean.result()

    def reset_states(self):
        self._mean.reset_states()
Example #13
    def train(self, train_dataset, steps=100000):
        pls_metric = Mean()
        dls_metric = Mean()
        step = 0

        for lr, hr in train_dataset.take(steps):
            step += 1

            pl, dl = self.train_step(lr, hr)
            pls_metric(pl)
            dls_metric(dl)

            if step % 10 == 0:
                print(
                    f'{step}/{steps}, Adv loss = {pls_metric.result():.4f}, D loss = {dls_metric.result():.4f}'
                )
                pls_metric.reset_states()
                dls_metric.reset_states()
Example #14
    def train_generator(self,
                        train_dataset,
                        valid_dataset,
                        epochs=20000,
                        valid_lr=None,
                        valid_hr=None):
        evaluate_size = epochs // 10

        loss_mean = Mean()

        start_time = time.time()
        epoch = 0

        for lr, hr in train_dataset.take(epochs):
            epoch += 1
            step = tf.convert_to_tensor(epoch, dtype=tf.int64)
            generator_loss = self.train_generator_step(lr, hr)
            loss_mean(generator_loss)

            if epoch % 50 == 0:
                loss_value = loss_mean.result()
                loss_mean.reset_states()

                psnr_value = self.evaluate(valid_dataset.take(1))

                print(
                    f'Time for epoch {epoch}/{epochs} is {(time.time() - start_time):.4f} sec, '
                    f'gan loss = {loss_value:.4f}, psnr = {psnr_value:.4f}')
                start_time = time.time()

                if self.summary_writer is not None:
                    with self.summary_writer.as_default():
                        tf.summary.scalar('generator_loss',
                                          loss_value,
                                          step=epoch)
                        tf.summary.scalar('psnr', psnr_value, step=epoch)

            if epoch % evaluate_size == 0:
                self.util.save_checkpoint(self.checkpoint, epoch)

            if epoch % 5000 == 0:
                self.generate_and_save_images(step, valid_lr, valid_hr)
Example #15
    def train(self,
              train_dataset,
              valid_dataset,
              steps,
              evaluate_every=1000,
              save_best_only=False):
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        self.now = timeit.default_timer()

        for lr, hr in train_dataset.take(steps - ckpt.step.numpy()):
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()

            loss = self.train_step(lr, hr)
            loss_mean(loss)

            if step % evaluate_every == 0:
                loss_value = loss_mean.result()
                loss_mean.reset_states()

                # Compute PSNR on validation dataset
                psnr_value, ssim_value = self.evaluate(valid_dataset)

                duration = timeit.default_timer() - self.now
                print('%d/%d: loss = %.3f, PSNR = %.3f (%.2fs)' %
                      (step, steps, loss_value.numpy(), psnr_value.numpy(),
                       duration))

                if save_best_only and psnr_value <= ckpt.psnr:
                    self.now = timeit.default_timer()
                    # skip saving checkpoint, no PSNR improvement
                    continue

                ckpt.psnr = psnr_value
                ckpt_mgr.save()

                self.now = timeit.default_timer()
Example #16
class StandardVarianceBasedMetric(Metric):
    def __init__(self, name, dtype):
        super().__init__(name, dtype=dtype)
        self._mean = Mean(dtype=dtype)
        self._square_mean = Mean(dtype=dtype)

    @abstractmethod
    def _objective_function(self, y_true, y_pred):
        pass

    def update_state(self, y_true, y_pred, sample_weight=None):
        values = self._objective_function(y_true, y_pred)
        self._mean.update_state(values=values, sample_weight=sample_weight)
        self._square_mean.update_state(values=tf.square(values),
                                       sample_weight=sample_weight)

    def result(self):
        return tf.sqrt(self._square_mean.result() -
                       tf.square(self._mean.result()))

    def reset_states(self):
        self._mean.reset_states()
        self._square_mean.reset_states()
Example #17
    def train(self, train_ds, valid_ds, steps, evaluate_every=1000, save_best_only=False):
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        self.now = time.perf_counter()

        for lr, hr in train_ds.take(steps - ckpt.step.numpy()):
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()

            loss = self.train_step(lr, hr)
            loss_mean(loss)

            if step % evaluate_every == 0:
                # Record loss value
                loss_value = loss_mean.result()
                loss_mean.reset_states()
                
                # Compute PSNR on validation set
                psnr_value = self.evaluate(valid_ds)
                
                # Calculate time consumed
                duration = time.perf_counter() - self.now
                print('{}/{}: loss = {:.3f}, PSNR = {:.3f} ({:.2f}s)'.format(step, steps, loss_value.numpy(), psnr_value.numpy(), duration))

                # Skip checkpoint if PSNR does not improve
                if save_best_only and psnr_value <= ckpt.psnr:
                    self.now = time.perf_counter()
                    continue
                
                # Save checkpoint
                ckpt.psnr = psnr_value
                ckpt_mgr.save()

                self.now = time.perf_counter()
Example #18
File: train.py  Project: Aksh97/PRS_prj
    def train(self, train_dataset, steps=200000):
        pls_metric = Mean()
        dls_metric = Mean()
        step = 0

        for lr, hr in train_dataset.take(steps):
            step += 1

            pl, dl = self.train_step(lr, hr)
            pls_metric(pl)
            dls_metric(dl)

            if step % 50 == 0:
                print(
                    f'{step}/{steps}, perceptual loss = {pls_metric.result():.4f}, discriminator loss = {dls_metric.result():.4f}'
                )
                pls_metric.reset_states()
                dls_metric.reset_states()

            if step % 10000 == 0:
                self.generator.save_weights('weights/srgan/gan_generator' +
                                            str(step) + '.h5')
                self.discriminator.save_weights(
                    'weights/srgan/gan_discriminator' + str(step) + '.h5')
Example #19
@tf.function
def testing(images, labels):
    predicts = model(images)
    t_loss = loss_(labels, predicts)

    test_loss(t_loss)
    test_accuracy(labels, predicts)


# TRAINING
for epoch in range(EPOCHS):
    for train_images, train_labels in train:
        training(train_images, train_labels)

    for test_images, test_labels in test:
        testing(test_images, test_labels)

    to_print = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(
        to_print.format(epoch + 1, train_loss.result(),
                        train_accuracy.result() * 100, test_loss.result(),
                        test_accuracy.result() * 100))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    model.save_weights('model', save_format='tf')
Example #20
class ModelTrainer:
    """
    Note:
    Having this model keeps the trainStep and testStep instance new every time you call it.
    Implementing those functions outside a class will return an error
    ValueError: Creating variables on a non-first call to a function decorated with tf.function.
    """
    def __init__(self,
                 model,
                 loss,
                 metric,
                 optimizer,
                 ckptDir,
                 logDir,
                 multiGPU=True,
                 evalStep=1000):

        # Safety checks
        self.logDirTrain = os.path.join(logDir, 'Train')
        self.logDirTest = os.path.join(logDir, 'Test')

        if not os.path.exists(ckptDir):
            os.makedirs(ckptDir)
        if not os.path.exists(self.logDirTrain):
            os.makedirs(self.logDirTrain)
        if not os.path.exists(self.logDirTest):
            os.makedirs(self.logDirTest)

        self.trainWriter = tf.summary.create_file_writer(self.logDirTrain)
        self.testWriter = tf.summary.create_file_writer(self.logDirTest)

        self.ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                                        psnr=tf.Variable(1.0),
                                        optimizer=optimizer,
                                        model=model)
        self.ckptMngr = tf.train.CheckpointManager(checkpoint=self.ckpt,
                                                   directory=ckptDir,
                                                   max_to_keep=5)

        self.loss = loss
        self.metric = metric

        self.accTestLoss = Mean(name='accTestLoss')
        self.accTestPSNR = Mean(name='accTestPSNR')
        self.accTrainLoss = Mean(name='accTrainLoss')
        self.accTrainPSNR = Mean(name='accTrainPSNR')
        self.evalStep = evalStep
        self.multiGPU = multiGPU
        self.strategy = None
        self.restore()

    @property
    def model(self):
        return self.ckpt.model

    def restore(self):
        if self.ckptMngr.latest_checkpoint:
            self.ckpt.restore(self.ckptMngr.latest_checkpoint)
            print(
                f'[ INFO ] Model restored from checkpoint at step {self.ckpt.step.numpy()}.'
            )

    def fitTrainData(self,
                     X: tf.Tensor,
                     y: tf.Tensor,
                     globalBatchSize: int,
                     epochs: int,
                     valData: List[np.ma.array],
                     bufferSize: int = 128,
                     valSteps: int = 64,
                     saveBestOnly: bool = True,
                     initEpoch: int = 0):

        logger.info('[ INFO ] Loading data set to buffer cache...')
        trainSet = loadTrainDataAsTFDataSet(X, y[0], y[1], epochs,
                                            globalBatchSize, bufferSize)
        valSet = loadValDataAsTFDataSet(valData[0], valData[1], valData[2],
                                        valSteps, globalBatchSize, bufferSize)
        logger.info('[ INFO ] Loading success...')

        dataSetLength = len(X)
        totalSteps = tf.cast(dataSetLength / globalBatchSize, tf.int64)
        globalStep = tf.cast(self.ckpt.step, tf.int64)
        step = globalStep % totalSteps
        epoch = initEpoch

        logger.info('[ INFO ] Begin training...')

        for x_batch_train, y_batch_train, y_mask_batch_train in trainSet:
            if (totalSteps - step) == 0:
                epoch += 1
                step = tf.cast(self.ckpt.step, tf.int64) % totalSteps
                logger.info(
                    f'[ ***************  NEW EPOCH  *************** ] Epoch number {epoch}'
                )
                # Reset metrics
                self.accTrainLoss.reset_states()
                self.accTrainPSNR.reset_states()
                self.accTestLoss.reset_states()
                self.accTestPSNR.reset_states()

            step += 1
            globalStep += 1
            self.trainStep(x_batch_train, y_batch_train, y_mask_batch_train)
            self.ckpt.step.assign_add(1)

            t = f"[ EPOCH {epoch}/{epochs} ] - [ STEP {step}/{int(totalSteps)} ] Loss: {self.accTrainLoss.result():.3f}, cPSNR: {self.accTrainPSNR.result():.3f}"
            logger.info(t)

            self.saveLog('Train', globalStep)

            if step != 0 and (step % self.evalStep) == 0:
                # Reset states for test
                self.accTestLoss.reset_states()
                self.accTestPSNR.reset_states()
                for x_batch_val, y_batch_val, y_mask_batch_val in valSet:
                    self.testStep(x_batch_val, y_batch_val, y_mask_batch_val)
                self.saveLog('Test', globalStep)
                t = f"[ *************** VAL INFO *************** ] Validation Loss: {self.accTestLoss.result():.3f}, Validation PSNR: {self.accTestPSNR.result():.3f}"
                logger.info(t)

                if saveBestOnly and (self.accTestPSNR.result() <=
                                     self.ckpt.psnr):
                    continue

                logger.info('[ SAVE ] Saving checkpoint...')
                self.ckpt.psnr = self.accTestPSNR.result()
                self.ckptMngr.save()

    @tf.function
    def trainStep(self, patchLR, patchHR, maskHR):
        with tf.GradientTape() as tape:
            predPatchHR = self.ckpt.model(patchLR, training=True)
            # Loss(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor)
            loss = self.loss(patchHR, maskHR, predPatchHR)

        gradients = tape.gradient(loss, self.ckpt.model.trainable_variables)
        self.ckpt.optimizer.apply_gradients(
            zip(gradients, self.ckpt.model.trainable_variables))
        metric = self.metric(patchHR, maskHR, predPatchHR)
        self.accTrainLoss(loss)
        self.accTrainPSNR(metric)

    @tf.function
    def testStep(self, patchLR, patchHR, maskHR):
        predPatchHR = self.ckpt.model(patchLR, training=False)
        loss = self.loss(patchHR, maskHR, predPatchHR)
        metric = self.metric(patchHR, maskHR, predPatchHR)
        self.accTestLoss(loss)
        self.accTestPSNR(metric)

    def saveLog(self, testOrTrain, globalStep):
        w = self.trainWriter if testOrTrain == 'Train' else self.testWriter
        with w.as_default():
            if testOrTrain == 'Train':
                tf.summary.scalar('PSNR',
                                  self.accTrainPSNR.result(),
                                  step=globalStep)
                tf.summary.scalar('Loss',
                                  self.accTrainLoss.result(),
                                  step=globalStep)
            else:
                tf.summary.scalar('PSNR',
                                  self.accTestPSNR.result(),
                                  step=globalStep)
                tf.summary.scalar('Loss',
                                  self.accTestLoss.result(),
                                  step=globalStep)
            w.flush()
Example #21
    def train(self,
              train_dataset,
              valid_dataset,
              args,
              shapes,
              save_best_only=False):
        loss_mean = Mean()

        ckpt_mgr = self.checkpoint_manager
        ckpt = self.checkpoint

        self.now = time.perf_counter()

        lr_train_shape, hr_train_shape = shapes[0]
        lr_valid_shape, hr_valid_shape = shapes[1]

        info = {
            "losses": [],
            "psnr": [],
            "time": [],
            "every": args.every,
            "total": args.steps,
            "avg_lr_train_shape": lr_train_shape,
            "avg_hr_train_shape": hr_train_shape,
            "avg_lr_valid_shape": lr_valid_shape,
            "avg_hr_valid_shape": hr_valid_shape,
        }

        for lr, hr in train_dataset.take(args.steps - ckpt.step.numpy()):
            ckpt.step.assign_add(1)
            step = ckpt.step.numpy()

            loss = self.train_step(lr, hr)
            loss_mean(loss)

            if step % args.every == 0:
                loss_value = loss_mean.result()
                loss_mean.reset_states()

                # Compute PSNR on validation dataset
                psnr_value = self.evaluate(valid_dataset)

                duration = time.perf_counter() - self.now

                info["losses"].append(float(loss_value.numpy()))
                info["psnr"].append(float(psnr_value.numpy()))
                info["time"].append(float(duration))

                filename = (
                    f"{args.dataset}_edsr_lr_x{args.scale * 2}_hr_x{args.scale}_"
                    f"res_{args.nb_res}_filt_{args.nb_filters}_batch_{args.batch_size}_"
                    f"transform_{args.transform}_every_{args.every}_"
                    f"steps_{args.steps}.json")

                with open(filename, "w") as f:
                    json.dump(info, f)

                print(
                    f"{step}/{args.steps}: loss = {loss_value.numpy():.3f}, PSNR = {psnr_value.numpy():3f} ({duration:.2f}s)"
                )

                if save_best_only and psnr_value <= ckpt.psnr:
                    self.now = time.perf_counter()
                    # skip saving checkpoint, no PSNR improvement
                    continue

                ckpt.psnr = psnr_value
                ckpt_mgr.save()

                self.now = time.perf_counter()
Example #22
    #     validation_acc(y, predictions)

    train_reporter()
    # print(colored('Epoch: ', 'red', 'on_white'), epoch + 1)
    # template = 'Train Loss: {:.4f}\t Train Accuracy: {:.2f}%\n' + \
    #            'Validation Loss: {:.4f}\t Validation Accuracy: {:.2f}%\n'
    # print(template.format(train_loss.result(), train_acc.result()*100,
    #                       validation_loss.result(), validation_acc.result()*100))

    # metric_resetter()
    train_losses.append(train_loss.result())
    train_accs.append(train_acc.result())
    validation_losses.append(validation_loss.result() * 100)
    validation_accs.append(validation_acc.result() * 100)

    train_loss.reset_states()
    train_acc.reset_states()
    validation_loss.reset_states()
    validation_acc.reset_states()

for x, y in test_ds:
    predictions = model(x)
    loss = loss_object(y, predictions)

    test_loss(loss)
    test_acc(y, predictions)

print(colored('Final Result: ', 'red', 'on_white'), epoch + 1)
template = 'Test Loss: {:.4f}\t Test Accuracy: {:.2f}%\n'
print(template.format(test_loss.result(), test_acc.result() * 100))
Example #23
class Trainer:
    def __init__(self, num_scales, num_iters, max_size, min_size, scale_factor,
                 learning_rate, checkpoint_dir, debug):

        self.num_scales = num_scales
        self.num_iters = num_iters
        self.num_filters = [
            32 * pow(2, (scale // 4)) for scale in range(self.num_scales)
        ]  # num_filters double for every 4 scales
        self.max_size = max_size
        self.min_size = min_size
        self.scale_factor = scale_factor
        self.noise_amp_init = 0.1

        self.checkpoint_dir = checkpoint_dir
        self.G_dir = self.checkpoint_dir + '/G'
        self.D_dir = self.checkpoint_dir + '/D'

        self.learning_schedule = ExponentialDecay(
            learning_rate, decay_steps=4800, decay_rate=0.1,
            staircase=True)  # 1600 * 3 steps
        self.build_model()

        self.debug = debug
        if self.debug:
            self.create_summary_writer()
            self.create_metrics()

    def build_model(self):
        """ Build initial model """
        create_dir(self.checkpoint_dir)
        self.generators = []
        self.discriminators = []
        for scale in range(self.num_scales):
            self.generators.append(
                Generator(num_filters=self.num_filters[scale]))
            self.discriminators.append(
                Discriminator(num_filters=self.num_filters[scale]))

    def save_model(self, scale):
        """ Save weights and NoiseAmp """
        G_dir = self.G_dir + f'{scale}'
        D_dir = self.D_dir + f'{scale}'
        if not os.path.exists(G_dir):
            os.makedirs(G_dir)
        if not os.path.exists(D_dir):
            os.makedirs(D_dir)

        self.generators[scale].save_weights(G_dir + '/G', save_format='tf')
        self.discriminators[scale].save_weights(D_dir + '/D', save_format='tf')
        np.save(self.checkpoint_dir + '/NoiseAmp', self.NoiseAmp)

    def init_from_previous_model(self, scale):
        """ Initialize current model from the previous trained model """
        if self.num_filters[scale] == self.num_filters[scale - 1]:
            self.generators[scale].load_weights(self.G_dir + f'{scale-1}/G')
            self.discriminators[scale].load_weights(self.D_dir +
                                                    f'{scale-1}/D')

    def train(self, training_image):
        """ Training """
        real_image = functions.read_image(training_image)
        #real_image = normalize_m11(real_image)

        self.num_scales, stop_scale, scale1, self.scale_factor, reals =\
            functions.adjust_scales2image(real_image, self.scale_factor, self.min_size, self.max_size)

        real_image = functions.read_image(training_image)
        real_image = my_imresize(real_image, scale1)
        reals = []
        reals = functions.creat_reals_pyramid(real_image, reals,
                                              self.scale_factor, stop_scale)
        reals = [
            tf.convert_to_tensor(real.permute((0, 2, 3, 1)).numpy(),
                                 dtype=tf.float32) for real in reals
        ]

        self.Z_fixed = []
        self.NoiseAmp = []
        noise_amp = tf.constant(0.1)

        for scale in range(stop_scale + 1):
            print(scale)
            start = time.perf_counter()

            if scale > 0:
                self.init_from_previous_model(scale)
            g_opt = Adam(learning_rate=self.learning_schedule,
                         beta_1=0.5,
                         beta_2=0.999)
            d_opt = Adam(learning_rate=self.learning_schedule,
                         beta_1=0.5,
                         beta_2=0.999)
            """ Build with shape """
            prev_rec = tf.zeros_like(reals[scale])
            self.discriminators[scale](prev_rec)
            self.generators[scale](prev_rec, prev_rec)

            train_step = self.wrapper()
            print(tf.get_collection('checkpoints'))
            for step in tf.range(self.num_iters):
                z_fixed, prev_rec, noise_amp, metrics = train_step(
                    reals, prev_rec, noise_amp, scale, step, g_opt, d_opt)

            print(tf.get_collection('checkpoints'))

            self.Z_fixed.append(z_fixed)
            self.NoiseAmp.append(noise_amp)
            self.save_model(scale)

            if self.debug:
                self.write_summaries(metrics, scale)
                self.update_metrics(metrics, scale)
                print(
                    f'Time taken for scale {scale} is {time.perf_counter()-start:.2f} sec\n'
                )

    def wrapper(self):
        @tf.function
        def train_step(reals, prev_rec, noise_amp, scale, step, g_opt, d_opt):
            real = reals[scale]
            z_rand = tf.random.normal(real.shape)

            if scale == 0:
                z_rec = tf.random.normal(real.shape)
            else:
                z_rec = tf.zeros_like(real)

            for i in range(6):
                if i == 0 and tf.get_static_value(step) == 0:
                    if scale == 0:
                        """ Coarsest scale is purely generative """
                        prev_rand = tf.zeros_like(real)
                        prev_rec = tf.zeros_like(real)
                        noise_amp = 1.0
                    else:
                        """ Finer scale takes noise and image generated from previous scale as input """
                        prev_rand = self.generate_from_coarsest(
                            scale, reals, 'rand')
                        prev_rec = self.generate_from_coarsest(
                            scale, reals, 'rec')
                        """ Compute the standard deviation of noise """
                        RMSE = tf.sqrt(
                            tf.reduce_mean(tf.square(real - prev_rec)))
                        noise_amp = self.noise_amp_init * RMSE
                else:
                    prev_rand = self.generate_from_coarsest(
                        scale, reals, 'rand')

                Z_rand = z_rand if scale == 0 else noise_amp * z_rand
                Z_rec = noise_amp * z_rec

                if i < 3:
                    with tf.GradientTape() as tape:
                        """ Only record the training variables """
                        fake_rand = self.generators[scale](prev_rand, Z_rand)

                        dis_loss = self.dicriminator_wgan_loss(
                            self.discriminators[scale], real, fake_rand, 1)

                    dis_gradients = tape.gradient(
                        dis_loss,
                        self.discriminators[scale].trainable_variables)
                    d_opt.apply_gradients(
                        zip(dis_gradients,
                            self.discriminators[scale].trainable_variables))
                else:
                    with tf.GradientTape() as tape:
                        """ Only record the training variables """
                        fake_rand = self.generators[scale](prev_rand, Z_rand)
                        fake_rec = self.generators[scale](prev_rec, Z_rec)

                        gen_loss = self.generator_wgan_loss(
                            self.discriminators[scale], fake_rand)
                        rec_loss = self.reconstruction_loss(real, fake_rec)
                        gen_loss = gen_loss + 10 * rec_loss

                    gen_gradients = tape.gradient(
                        gen_loss, self.generators[scale].trainable_variables)
                    g_opt.apply_gradients(
                        zip(gen_gradients,
                            self.generators[scale].trainable_variables))

            metrics = (dis_loss, gen_loss, rec_loss)
            return z_rec, prev_rec, noise_amp, metrics

        return train_step

    def generate_from_coarsest(self, scale, reals, mode='rand'):
        """ Use random/fixed noise to generate from coarsest scale"""
        fake = tf.zeros_like(reals[0])
        if scale > 0:
            if mode == 'rand':
                for i in range(scale):
                    z_rand = tf.random.normal(reals[i].shape)
                    z_rand = self.NoiseAmp[i] * z_rand
                    fake = self.generators[i](fake, z_rand)
                    fake = imresize(fake, new_shapes=reals[i + 1].shape)

            if mode == 'rec':
                for i in range(scale):
                    z_fixed = self.NoiseAmp[i] * self.Z_fixed[i]
                    fake = self.generators[i](fake, z_fixed)
                    fake = imresize(fake, new_shapes=reals[i + 1].shape)
        return fake

    def create_real_pyramid(self, real_image):
        """ Create the pyramid of scales """
        reals = [real_image]
        for i in range(1, self.num_scales):
            reals.append(
                imresize(real_image,
                         min_size=self.min_size,
                         scale_factor=pow(0.75, i)))
        """ Reverse it to coarse-fine scales """
        reals.reverse()
        for real in reals:
            print(real.shape)
        return reals

    def generator_wgan_loss(self, discriminator, fake):
        """ Ladv(G) = -E[D(fake)] """
        return -tf.reduce_mean(discriminator(fake))

    def reconstruction_loss(self, real, fake_rec):
        """ Lrec = || G(z*) - real ||^2 """
        return tf.reduce_mean(tf.square(fake_rec - real))

    def dicriminator_wgan_loss(self, discriminator, real, fake, batch_size=1):
        """ Ladv(D) = E[D(fake)] - E[D(real)] + GradientPenalty"""
        dis_loss = tf.reduce_mean(discriminator(fake)) - tf.reduce_mean(
            discriminator(real))

        alpha = tf.random.uniform(shape=[batch_size, 1, 1, 1],
                                  minval=0.,
                                  maxval=1.)  # real.shape
        interpolates = alpha * real + ((1 - alpha) * fake)
        with tf.GradientTape() as tape:
            tape.watch(interpolates)
            dis_interpolates = discriminator(interpolates)
        gradients = tape.gradient(dis_interpolates, [interpolates])[0]

        slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[
            3
        ]))  # compute pixelwise gradient norm; per image use [1, 2, 3]
        gradient_penalty = tf.reduce_mean((slopes - 1.)**2)

        dis_loss = dis_loss + 0.1 * gradient_penalty
        return dis_loss

    def create_metrics(self):
        self.dis_metric = Mean()
        self.gen_metric = Mean()
        self.rec_metric = Mean()

    def update_metrics(self, metrics, step):
        dis_loss, gen_loss, rec_loss = metrics

        self.dis_metric(dis_loss)
        self.gen_metric(gen_loss)
        self.rec_metric(rec_loss)

        print(f' dis_loss = {self.dis_metric.result():.3f}')
        print(f' gen_loss = {self.gen_metric.result():.3f}')
        print(f' rec_loss = {self.rec_metric.result():.3f}')

        self.dis_metric.reset_states()
        self.gen_metric.reset_states()
        self.rec_metric.reset_states()

    def create_summary_writer(self):
        import datetime
        self.summary_writer = tf.contrib.summary.create_file_writer(
            'log/fit/' + datetime.datetime.now().strftime('%Y%m%d-%H%M%S'))

    def write_summaries(self, metrics, scale):
        dis_loss, gen_loss, rec_loss = metrics

        with self.summary_writer.as_default():
            tf.summary.scalar('dis_loss', dis_loss)
            tf.summary.scalar('gen_loss', gen_loss)
            tf.summary.scalar('rec_loss', rec_loss)
Example #24
    for val_images, val_labels in val:
        testing(val_images, val_labels)

    to_print = 'Epoch {}, Loss: {}, Accuracy: {}, Valid Loss: {}, Valid Accuracy: {}'
    print(
        to_print.format(epoch + 1, train_loss.result(),
                        train_accuracy.result() * 100, val_loss.result(),
                        val_accuracy.result() * 100))
    train_l.append(train_loss.result())
    train_a.append(train_accuracy.result())
    val_l.append(val_loss.result())
    val_a.append(val_accuracy.result())
    epochs.append(epoch)

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accuracy.reset_states()
    val_loss.reset_states()
    val_accuracy.reset_states()

    model.save_weights('model', save_format='tf')

plt.figure(figsize=(24, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs, val_a, label="validation_accuracy", c="red")
plt.plot(epochs, train_a, label="training_accuracy", c="green")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs, val_l, label="validation_loss", c="red")
plt.plot(epochs, train_l, label="training_loss", c="green")
Example #25
class POTrainer:
    def __init__(self, actor, env, config):
        self.actor = actor
        self.env = env

        self.config = config
        self.T = config.unroll_steps
        self.input_dim = config.channel_cardinality - 1
        self.model = self._build_training_model()

        lr_sche = keras.optimizers.schedules.ExponentialDecay(
            self.config.learning_rate,
            decay_steps=self.config.num_epochs //
            self.config.learning_rate_decay_steps,
            decay_rate=self.config.learning_rate_decay,
            staircase=False)
        self.optimizer = Adam(learning_rate=lr_sche)

        self.model.compile(optimizer=self.optimizer, loss=loss)

        self.mean_metric = Mean()

    @tf.function
    def split_rewards_and_state(self, x):
        return tf.split(x, axis=-1, num_or_size_splits=[1, self.input_dim])

    def _build_training_model(self):
        def name(title, idx):
            return "{}_{:d}".format(title, idx)

        rewards = list()
        states = list()
        self.input = z = Input(shape=[self.input_dim])
        for t in range(self.T):
            u = self.actor.model(z, training=True)
            z_u = Concatenate(axis=-1, name=name("concat_z_u", t))([z, u])
            r, z_prime = self.env.model(z_u)
            # size = K.cast(tf.shape(probs), dtype=tf.int64)[0]
            #
            # disturbance = K.squeeze(tf.random.categorical(tf.math.log(probs), 1), axis=-1)
            # next_state_indices = K.stack([tf.range(size), disturbance], axis=1)
            # z_prime = tf.gather_nd(z_primes, next_state_indices)

            rewards.append(r)
            z = z_prime
            states.append(z)

        self.rewards = Concatenate(axis=-1, name="concat_rewards")(rewards)
        self.states = Lambda(tf.stack,
                             arguments={
                                 'axis': 1,
                                 'name': "concat_states"
                             })(states)

        return Model(inputs=self.input, outputs=[self.rewards, self.states])

    # def train_epoch(self):
    #     loop = tqdm(range(self.config.num_iter_per_epoch))
    #     losses = []
    #     accs = []
    #     for _ in loop:
    #         loss, acc = self.train_step()
    #         losses.append(loss)
    #         accs.append(acc)
    #     loss = np.mean(losses)
    #     acc = np.mean(accs)
    #
    #     cur_it = self.model.global_step_tensor.eval(self.sess)
    #     summaries_dict = {
    #         'loss': loss,
    #         'acc': acc,
    #     }
    #     self.logger.summarize(cur_it, summaries_dict=summaries_dict)
    #     self.model.save(self.sess)

    # @tf.function
    @tf.function
    def train_step(self, z):
        with tf.GradientTape() as tape:
            # training=True is only needed if there are layers with different
            # behavior during training versus inference (e.g. Dropout).
            rewards, states = self.model(z)
            loss = -K.mean(rewards)
        gradients = tape.gradient(loss, self.actor.model.trainable_weights)
        self.optimizer.apply_gradients(
            zip(gradients, self.actor.model.trainable_weights))

        return -loss, states[:, -1, :]
        # train_loss(loss)
        # train_accuracy(labels, predictions)

    def train(self):
        def lr(k):
            return self.config.learning_rate * (
                self.config.learning_rate_decay
                **(k // (self.config.num_epochs //
                         self.config.learning_rate_decay_steps)))

        template = "Epoch: {:05d}\tLearning Rate: {:2.2e}\tAverage Reward: {:8.5f} "

        z = tf.random.uniform([self.config.batch_size, self.input_dim])
        for k in range(self.config.num_epochs):
            I, z = self.train_step(z)
            if k % self.config.eval_freq == 0:
                average_reward, state_histogram = self.test(
                    self.config.eval_len)
                with open(os.path.join(self.config.summary_dir, "log.txt"),
                          'a') as f:
                    f.write(template.format(k, lr(k), average_reward) + "\n")
                print(template.format(k, lr(k), average_reward))

        average_reward, state_histogram = self.test(self.config.eval_long_len)
        with open(os.path.join(self.config.summary_dir, "log.txt"), 'a') as f:
            f.write("Epoch: {}\tAverage Reward: {:8.5f} \n".format(
                "Final", average_reward))
        print("Epoch: {}\tAverage Reward: {:8.5f} ".format(
            "Final", average_reward))

        state_clusters = KMeans(
            n_clusters=self.config.n_clusters).fit(state_histogram)
        with open(os.path.join(self.config.summary_dir, "log.txt"), 'a') as f:
            f.write("Clusters:\n")
            f.writelines(
                ['{}\n'.format(x) for x in state_clusters.cluster_centers_])
        print(*['{}\n'.format(x) for x in state_clusters.cluster_centers_])

    def test(self, eval_len):
        state_histogram = list()
        self.mean_metric.reset_states()
        z = tf.random.uniform([self.config.batch_size_eval, self.input_dim])
        for k in range(eval_len):
            r, next_states = self.model.predict(z)
            z = tf.squeeze(next_states[:, -1, :])
            if k > eval_len // 10:
                self.mean_metric(r)
                state_histogram.append(next_states)

        return self.mean_metric.result(), tf.reshape(
            tf.concat(state_histogram, axis=0), [-1, self.input_dim])
Example #26
def low_level_train(optimizer, yolo_loss, train_datasets, valid_datasets, train_steps, valid_steps):
    """
    以底层的方式训练,这种方式更好地观察训练过程,监视变量的变化
    :param optimizer: 优化器
    :param yolo_loss: 自定义的loss function
    :param train_datasets: 以tf.data封装好的训练集数据
    :param valid_datasets: 验证集数据
    :param train_steps: 迭代一个epoch的轮次
    :param valid_steps: 同上
    :return: None
    """
    # Build the model
    model = yolo_body()

    # Define the evaluation metrics
    train_loss = Mean(name='train_loss')
    valid_loss = Mean(name='valid_loss')

    # Criteria for saving the best model and for early stopping
    best_test_loss = float('inf')
    patience = 10
    min_delta = 1e-3
    patience_cnt = 0
    history_loss = []

    # Create the summary writer
    summary_writer = tf.summary.create_file_writer(logdir=cfg.log_dir)

    # Compute the loss with a low-level training loop
    for epoch in range(1, cfg.epochs + 1):
        train_loss.reset_states()
        valid_loss.reset_states()
        step = 0
        print("Epoch {}/{}".format(epoch, cfg.epochs))

        # Iterate over the training set
        for batch, (images, labels) in enumerate(train_datasets.take(train_steps)):
            with tf.GradientTape() as tape:
                # Forward pass
                outputs = model(images, training=True)
                # Compute the loss (model.losses is only populated when Conv2D layers set kernel_regularizer)
                regularization_loss = tf.reduce_sum(model.losses)
                pred_loss = []
                # yolo_loss, labels and outputs each cover the 3 feature levels; unpacking
                # them in the loop gives one loss_fn per level, computed in turn
                for output, label, loss_fn in zip(outputs, labels, yolo_loss):
                    pred_loss.append(loss_fn(label, output))

                # Total loss = YOLO loss + regularization loss
                total_train_loss = tf.reduce_sum(pred_loss) + regularization_loss

            # Backpropagation / gradient descent:
            # compute the gradient of the loss with respect to every trainable variable
            grads = tape.gradient(total_train_loss, model.trainable_variables)
            # Apply each gradient to update its corresponding trainable variable
            # (zip pairs gradients with trainable variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            # Update train_loss
            train_loss.update_state(total_train_loss)
            # Print training progress
            rate = (step + 1) / train_steps
            a = "*" * int(rate * 70)
            b = "." * int((1 - rate) * 70)
            loss = train_loss.result().numpy()

            print("\r{}/{} {:^3.0f}%[{}->{}] - loss:{:.4f}".
                  format(batch, train_steps, int(rate * 100), a, b, loss), end='')
            step += 1

        # Evaluate on the validation set
        for batch, (images, labels) in enumerate(valid_datasets.take(valid_steps)):
            # Forward pass, without training mode
            outputs = model(images)
            regularization_loss = tf.reduce_sum(model.losses)
            pred_loss = []
            for output, label, loss_fn in zip(outputs, labels, yolo_loss):
                pred_loss.append(loss_fn(label, output))

            total_valid_loss = tf.reduce_sum(pred_loss) + regularization_loss

            # Update valid_loss
            valid_loss.update_state(total_valid_loss)

        print('\nLoss: {:.4f}, Test Loss: {:.4f}\n'.format(train_loss.result(), valid_loss.result()))
        # Record the loss (the training loss could be used instead)
        history_loss.append(valid_loss.result().numpy())

        # Write to TensorBoard
        with summary_writer.as_default():
            tf.summary.scalar('train_loss', train_loss.result(), step=optimizer.iterations)
            tf.summary.scalar('valid_loss', valid_loss.result(), step=optimizer.iterations)

        # Save only the best model
        if valid_loss.result() < best_test_loss:
            best_test_loss = valid_loss.result()
            model.save_weights(cfg.model_path, save_format='tf')

        # EarlyStopping
        if epoch > 1 and history_loss[epoch - 2] - history_loss[epoch - 1] > min_delta:
            patience_cnt = 0
        else:
            patience_cnt += 1

        if patience_cnt >= patience:
            tf.print("No improvement for {} times, early stopping optimization.".format(patience))
            break
Example #27
class Trainer(object):
    """
    Train a network and manage weights loading and saving
    
    ...
    
    Attributes
    ----------
    model: obj
        model to be trained
    band: string
        band to train with
    image_hr_size: int
        size of the HR image
    name_net: string
        name of the network
    loss: obj
        loss function
    metric: obj
        metric function
    optimizer: obj
        optimizer of the training
    checkpoint_dir: string
        weights path
    log_dir: string
        logs path
 
    Methods
    -------
    restore()
        Restore a previous version found in 'checkpoint_dir' path
    fit(self, x=None, y=None, batch_size=None, buffer_size=512, epochs=100,
            verbose=1, evaluate_every=100, val_steps=100,
            validation_data=None, shuffle=True, initial_epoch=0, save_best_only=True,
           data_aug = False)
        Train the network with the configuration passed to the function
    train_step(self, lr, hr, mask)
        A single training step
    test_step(self, lr, hr, mask)
        A single testing step
    """
    def __init__(self,
                 model,
                 band,
                 image_hr_size,
                 name_net,
                 loss,
                 metric,
                 optimizer,
                 checkpoint_dir='./checkpoint',
                 log_dir='logs'):

        self.now = None
        self.band = band
        self.name_net = name_net
        self.loss = loss
        self.image_hr_size = image_hr_size
        self.metric = metric
        self.log_dir = log_dir
        self.train_loss = Mean(name='train_loss')
        self.train_psnr = Mean(name='train_psnr')

        self.test_loss = Mean(name='test_loss')
        self.test_psnr = Mean(name='test_psnr')
        self.checkpoint = tf.train.Checkpoint(step=tf.Variable(0),
                                              psnr=tf.Variable(1.0),
                                              optimizer=optimizer,
                                              model=model)
        self.checkpoint_manager = tf.train.CheckpointManager(
            checkpoint=self.checkpoint,
            directory=checkpoint_dir,
            max_to_keep=3)

        self.restore()

    def restore(self):
        if self.checkpoint_manager.latest_checkpoint:
            self.checkpoint.restore(self.checkpoint_manager.latest_checkpoint)
            print(
                f'Model restored from checkpoint at step {self.checkpoint.step.numpy()}.'
            )

    @property
    def model(self):
        return self.checkpoint.model

    def fit(self,
            x=None,
            y=None,
            batch_size=None,
            buffer_size=512,
            epochs=100,
            verbose=1,
            evaluate_every=100,
            val_steps=100,
            validation_data=None,
            shuffle=True,
            initial_epoch=0,
            save_best_only=True,
            data_aug=False):

        ds_len = x.shape[0]
        # Create dataset from slices
        train_ds = tf.data.Dataset.from_tensor_slices((x, *y)).shuffle(
            buffer_size,
            reshuffle_each_iteration=True).batch(batch_size).prefetch(
                tf.data.experimental.AUTOTUNE)

        if data_aug:
            # tf.data transformations return new datasets, so the results must be reassigned
            train_ds = train_ds.map(random_rotate,
                                    num_parallel_calls=tf.data.experimental.AUTOTUNE)
            train_ds = train_ds.map(random_flip,
                                    num_parallel_calls=tf.data.experimental.AUTOTUNE)

        val_ds = tf.data.Dataset.from_tensor_slices(
            (validation_data[0], *validation_data[1]
             )).shuffle(buffer_size).batch(batch_size).prefetch(
                 tf.data.experimental.AUTOTUNE).take(val_steps)

        # Tensorboard logger
        writer_train = tf.summary.create_file_writer(
            os.path.join(self.log_dir, f'train_{self.band}_{self.name_net}'))
        writer_test = tf.summary.create_file_writer(
            os.path.join(self.log_dir, f'test_{self.band}_{self.name_net}'))

        global_step = tf.cast(self.checkpoint.step, tf.int64)
        total_steps = tf.cast(ds_len / batch_size, tf.int64)
        step = tf.cast(self.checkpoint.step, tf.int64) % total_steps

        for epoch in range(epochs - initial_epoch):
            # Iterate over the batches of the dataset.
            print("\nEpoch {}/{}".format(epoch + 1 + initial_epoch, epochs))
            pb_i = Progbar(
                ds_len,
                stateful_metrics=['Loss', 'PSNR', 'Val Loss', 'Val PSNR'])

            for x_batch_train, y_batch_train, y_mask_batch_train in train_ds:
                if (total_steps - step) == 0:
                    step = tf.cast(self.checkpoint.step,
                                   tf.int64) % total_steps

                    # Reset metrics
                    self.train_loss.reset_states()
                    self.train_psnr.reset_states()

                step += 1
                global_step += 1
                self.train_step(x_batch_train, y_batch_train,
                                y_mask_batch_train)

                self.checkpoint.step.assign_add(1)

                with writer_train.as_default():
                    tf.summary.scalar('PSNR',
                                      self.train_psnr.result(),
                                      step=global_step)

                    tf.summary.scalar('Loss',
                                      self.train_loss.result(),
                                      step=global_step)

                if step != 0 and (step % evaluate_every) == 0:
                    # Reset states for test
                    self.test_loss.reset_states()
                    self.test_psnr.reset_states()

                    for x_batch_val, y_batch_val, y_mask_batch_val in val_ds:
                        self.test_step(x_batch_val, y_batch_val,
                                       y_mask_batch_val)

                    with writer_test.as_default():
                        tf.summary.scalar('Loss',
                                          self.test_loss.result(),
                                          step=global_step)
                        tf.summary.scalar('PSNR',
                                          self.test_psnr.result(),
                                          step=global_step)

                    writer_train.flush()
                    writer_test.flush()

                    # Only save a checkpoint when the validation PSNR improved
                    # (or when save_best_only is disabled); the progress bar below
                    # is updated either way
                    if not save_best_only or self.test_psnr.result() > self.checkpoint.psnr:
                        self.checkpoint.psnr = self.test_psnr.result()
                        self.checkpoint_manager.save()

                values = [('Loss', self.train_loss.result()),
                          ('PSNR', self.train_psnr.result()),
                          ('Val Loss', self.test_loss.result()),
                          ('Val PSNR', self.test_psnr.result())]
                pb_i.add(batch_size, values=values)

    @tf.function
    def train_step(self, lr, hr, mask):
        lr = tf.cast(lr, tf.float32)

        with tf.GradientTape() as tape:

            sr = self.checkpoint.model(lr, training=True)
            loss = self.loss(hr, sr, mask, self.image_hr_size)

        gradients = tape.gradient(loss,
                                  self.checkpoint.model.trainable_variables)
        self.checkpoint.optimizer.apply_gradients(
            zip(gradients, self.checkpoint.model.trainable_variables))

        metric = self.metric(hr, sr, mask)
        self.train_loss(loss)
        self.train_psnr(metric)

    @tf.function
    def test_step(self, lr, hr, mask):
        lr = tf.cast(lr, tf.float32)

        sr = self.checkpoint.model(lr, training=False)
        t_loss = self.loss(hr, sr, mask, self.image_hr_size)
        t_metric = self.metric(hr, sr, mask)

        self.test_loss(t_loss)
        self.test_psnr(t_metric)
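A hedged usage sketch for the Trainer above; every right-hand-side name (model, loss, metric, arrays) is a placeholder assumed for illustration, not part of the original code:

import tensorflow as tf

# Hypothetical objects: any tf.keras.Model producing SR patches, a masked loss with
# signature (hr, sr, mask, hr_size) and a masked PSNR metric with signature (hr, sr, mask).
trainer = Trainer(model=my_model,
                  band='RED',
                  image_hr_size=384,
                  name_net='srnet',
                  loss=my_masked_loss,
                  metric=my_masked_psnr,
                  optimizer=tf.keras.optimizers.Adam(1e-4))

# y and validation_data[1] are (targets, masks) tuples, matching from_tensor_slices((x, *y)).
trainer.fit(x=x_train,
            y=(y_train, y_train_mask),
            batch_size=8,
            epochs=100,
            evaluate_every=100,
            validation_data=(x_val, (y_val, y_val_mask)),
            save_best_only=True)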
示例#28
0
class ModelTrainer:
    """
    Note:
    Keeping trainStep and testStep as methods of this class ties each tf.function to a
    single trainer instance. Implementing those functions at module level and reusing
    them across models raises
    ValueError: Creating variables on a non-first call to a function decorated with tf.function.
    (A short illustrative sketch of this failure mode follows the class definition.)
    """
    def __init__(self,
                 model,
                 loss,
                 metric,
                 optimizer,
                 ckptDir,
                 logDir,
                 strategy,
                 multiGPU=True,
                 evalStep=10):

        # Safety checks
        if not os.path.exists(ckptDir):
            os.makedirs(ckptDir)
        if not os.path.exists(logDir):
            os.makedirs(logDir)

        self.ckpt = tf.train.Checkpoint(step=tf.Variable(0),
                                        psnr=tf.Variable(1.0),
                                        optimizer=optimizer,
                                        model=model)
        self.ckptMngr = tf.train.CheckpointManager(checkpoint=self.ckpt,
                                                   directory=ckptDir,
                                                   max_to_keep=5)
        self.loss = loss
        self.metric = metric
        self.logDir = logDir
        self.trainLoss = Mean(name='trainLoss')
        self.trainPSNR = Mean(name='trainPSNR')
        self.testLoss = Mean(name='testLoss')
        self.testPSNR = Mean(name='testPSNR')
        self.evalStep = evalStep
        self.multiGPU = multiGPU
        self.strategy = strategy
        self.restore()

    @property
    def model(self):
        return self.ckpt.model

    def restore(self):
        if self.ckptMngr.latest_checkpoint:
            self.ckpt.restore(self.ckptMngr.latest_checkpoint)
            print(
                f'[ INFO ] Model restored from checkpoint at step {self.ckpt.step.numpy()}.'
            )

    def fitTrainData(self,
                     X: tf.Tensor,
                     y: tf.Tensor,
                     batchSize: int,
                     epochs: int,
                     valData: List[np.ma.array],
                     bufferSize: int = 256,
                     valSteps: int = 128,
                     saveBestOnly: bool = True,
                     initEpoch: int = 0):
        if self.multiGPU:
            logger.info('[ INFO ] Multi-GPU mode selected...')
            logger.info('[ INFO ] Instantiate strategy...')
            batchSizePerReplica = batchSize
            globalBatchSize = batchSizePerReplica * self.strategy.num_replicas_in_sync
        else:
            globalBatchSize = batchSize

        logger.info('[ INFO ] Loading data set to buffer cache...')
        trainSet = loadTrainDataAsTFDataSet(X, y[0], y[1], epochs,
                                            globalBatchSize, bufferSize)
        valSet = loadValDataAsTFDataSet(valData[0], valData[1], valData[2],
                                        valSteps, globalBatchSize, bufferSize)
        logger.info('[ INFO ] Loading success...')

        if self.multiGPU:
            logger.info('[ INFO ] Distributing train set...')
            trainSet = self.strategy.experimental_distribute_dataset(trainSet)
            logger.info('[ INFO ] Distributing test set...')
            valSet = self.strategy.experimental_distribute_dataset(valSet)

        w = tf.summary.create_file_writer(self.logDir)

        dataSetLength = len(X)
        totalSteps = tf.cast(dataSetLength / globalBatchSize, tf.int64)
        globalStep = tf.cast(self.ckpt.step, tf.int64)
        step = globalStep % totalSteps
        epoch = initEpoch

        logger.info('[ INFO ] Begin training...')
        with w.as_default():
            for x_batch_train, y_batch_train, y_mask_batch_train in trainSet:
                if (totalSteps - step) == 0:
                    epoch += 1
                    step = tf.cast(self.ckpt.step, tf.int64) % totalSteps
                    logger.info(f'[ NEW EPOCH ] Epoch number {epoch}')
                    # Reset metrics
                    self.trainLoss.reset_states()
                    self.trainPSNR.reset_states()
                    self.testLoss.reset_states()
                    self.testPSNR.reset_states()

                step += 1
                globalStep += 1
                self.trainDistStep(x_batch_train, y_batch_train,
                                   y_mask_batch_train)
                self.ckpt.step.assign_add(1)

                t = f"[ EPOCH {epoch}/{epochs} ] Step {step}/{int(totalSteps)}, Loss: {self.trainLoss.result():.3f}, cPSNR: {self.trainPSNR.result():.3f}"
                logger.info(t)

                tf.summary.scalar('Train PSNR',
                                  self.trainPSNR.result(),
                                  step=globalStep)
                tf.summary.scalar('Train loss',
                                  self.trainLoss.result(),
                                  step=globalStep)

                if step != 0 and (step % self.evalStep) == 0:
                    # Reset states for test
                    self.testLoss.reset_states()
                    self.testPSNR.reset_states()
                    for x_batch_val, y_batch_val, y_mask_batch_val in valSet:
                        self.testDistStep(x_batch_val, y_batch_val,
                                          y_mask_batch_val)
                    tf.summary.scalar('Test loss',
                                      self.testLoss.result(),
                                      step=globalStep)
                    tf.summary.scalar('Test PSNR',
                                      self.testPSNR.result(),
                                      step=globalStep)
                    t = f"[ VAL INFO ] Validation Loss: {self.testLoss.result():.3f}, Validation PSNR: {self.testPSNR.result():.3f}"
                    logger.info(t)
                    w.flush()

                    if saveBestOnly and (self.testPSNR.result() <=
                                         self.ckpt.psnr):
                        continue

                    logger.info('[ SAVE ] Saving checkpoint...')
                    self.ckpt.psnr = self.testPSNR.result()
                    self.ckptMngr.save()

    def computeLoss(self, patchHR, maskHR, predPatchHR):
        loss = tf.reduce_sum(self.loss(patchHR, maskHR,
                                       predPatchHR)) * (1.0 / self.batchSize)
        loss += (sum(self.ckpt.model.losses) * 1.0 /
                 self.strategy.num_replicas_in_sync)
        return loss
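    # Note on the scaling above: dividing the summed per-example loss by the global batch
    # size and spreading the model's regularization losses across replicas keeps the
    # effective loss magnitude consistent when per-replica gradients are summed, which is
    # the usual loss-scaling convention for tf.distribute custom training loops.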

    def calcMetric(self, patchHR, maskHR, predPatchHR):
        return self.metric(patchHR, maskHR, predPatchHR)

    @tf.function
    def trainStep(self, patchLR, patchHR, maskHR):
        with tf.GradientTape() as tape:
            predPatchHR = self.ckpt.model(patchLR, training=True)
            # Loss(patchHR: tf.Tensor, maskHR: tf.Tensor, predPatchHR: tf.Tensor)
            loss = self.loss(patchHR, maskHR, predPatchHR)

        gradients = tape.gradient(loss, self.ckpt.model.trainable_variables)
        self.ckpt.optimizer.apply_gradients(
            zip(gradients, self.ckpt.model.trainable_variables))
        return loss

    @tf.function
    def testStep(self, patchLR, patchHR, maskHR):
        predPatchHR = self.ckpt.model(patchLR, training=False)
        loss = self.loss(patchHR, maskHR, predPatchHR)
        return loss

    @tf.function
    def trainDistStep(self, patchLR, patchHR, maskHR):
        perExampleLosses = self.strategy.experimental_run_v2(self.trainStep,
                                                             args=(patchLR,
                                                                   patchHR,
                                                                   maskHR))
        perExampleMetric = self.strategy.experimental_run_v2(self.calcMetric,
                                                             args=(patchLR,
                                                                   patchHR,
                                                                   maskHR))
        meanLoss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                        perExampleLosses,
                                        axis=0)
        meanMetric = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                          perExampleMetric,
                                          axis=0)
        self.trainLoss(meanLoss)
        self.trainPSNR(meanMetric)

    @tf.function
    def testDistStep(self, patchLR, patchHR, maskHR):
        perExampleLosses = self.strategy.experimental_run_v2(self.testStep,
                                                             args=(patchLR,
                                                                   patchHR,
                                                                   maskHR))
        perExampleMetric = self.strategy.experimental_run_v2(self.calcMetric,
                                                             args=(patchLR,
                                                                   patchHR,
                                                                   maskHR))
        meanLoss = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                        perExampleLosses,
                                        axis=0)
        meanMetric = self.strategy.reduce(tf.distribute.ReduceOp.MEAN,
                                          perExampleMetric,
                                          axis=0)
        self.testLoss(meanLoss)
        self.testPSNR(meanMetric)
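The note in the ModelTrainer docstring refers to tf.function's rule that variables may only be created during the first trace; a minimal, illustrative sketch of the failure mode (not part of the original code):

import tensorflow as tf

@tf.function
def bad_step(x):
    # A new tf.Variable would be created on every trace; passing distinct Python scalars
    # forces a retrace, so the second call raises a ValueError like the one quoted in the
    # ModelTrainer docstring above.
    v = tf.Variable(1.0)
    return x * v

bad_step(1.0)      # first trace: works
# bad_step(2.0)    # second trace: raises the ValueError

# Keeping the step as a method and the variables on the object (as ModelTrainer does via
# self.ckpt.model and self.ckpt.optimizer) means the variables are created once, outside
# the traced function, so retracing is safe.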
class MNIST2MNIST_M_DANN(object):

    def __init__(self,config):
        """
        这是MNINST与MNIST_M域适配网络的初始化函数
        :param config: 参数配置类
        """
        # Store the configuration
        self.cfg = config

        # Training-time hyperparameters
        self.grl_lambd = 1.0              # GRL (gradient reversal layer) coefficient

        # Build the deep domain adaptation network
        self.build_DANN()

        # Define the training and validation losses and metrics
        self.loss = categorical_crossentropy
        self.acc = categorical_accuracy

        self.train_loss = Mean("train_loss", dtype=tf.float32)
        self.train_image_cls_loss = Mean("train_image_cls_loss", dtype=tf.float32)
        self.train_domain_cls_loss = Mean("train_domain_cls_loss", dtype=tf.float32)
        self.train_image_cls_acc = Mean("train_image_cls_acc", dtype=tf.float32)
        self.train_domain_cls_acc = Mean("train_domain_cls_acc", dtype=tf.float32)
        self.val_loss = Mean("val_loss", dtype=tf.float32)
        self.val_image_cls_loss = Mean("val_image_cls_loss", dtype=tf.float32)
        self.val_domain_cls_loss = Mean("val_domain_cls_loss", dtype=tf.float32)
        self.val_image_cls_acc = Mean("val_image_cls_acc", dtype=tf.float32)
        self.val_domain_cls_acc = Mean("val_domain_cls_acc", dtype=tf.float32)

        # Define the optimizer
        self.optimizer = tf.keras.optimizers.SGD(self.cfg.init_learning_rate,
                                                 momentum=self.cfg.momentum_rate)

        '''
        # Initialize the early-stopping policy
        self.early_stopping = EarlyStopping(min_delta=1e-5, patience=100, verbose=1)
        '''

    def build_DANN(self):
        """
        这是搭建域适配网络的函数
        :return:
        """
        # Define the image input shared by the source/target domains and the DANN model
        self.image_input = Input(shape=self.cfg.image_input_shape,name="image_input")

        # Feature extractor shared by the domain classifier and the image classifier
        self.feature_encoder = build_feature_extractor()
        # Heads producing the image classification and domain classification outputs
        self.image_cls_encoder = build_image_classify_extractor()
        self.domain_cls_encoder = build_domain_classify_extractor()

        self.grl = GradientReversalLayer()

        self.dann_model = Model(self.image_input,
                                [self.image_cls_encoder(self.feature_encoder(self.image_input)),
                                 self.domain_cls_encoder(self.grl(self.feature_encoder(self.image_input)))])
        self.dann_model.summary()

        # Load pretrained weights if a path is provided
        if self.cfg.pre_model_path is not None:
            self.dann_model.load_weights(self.cfg.pre_model_path,by_name=True,skip_mismatch=True)

    def train(self,train_source_datagen,train_target_datagen,
              val_target_datagen,train_iter_num,val_iter_num):
        """
        这是DANN的训练函数
        :param train_source_datagen: 源域训练数据集生成器
        :param train_target_datagen: 目标域训练数据集生成器
        :param val_datagen: 验证数据集生成器
        :param train_iter_num: 每个epoch的训练次数
        :param val_iter_num: 每次验证过程的验证次数
        """
        # Set up run-specific output directories (avoid shadowing the time module)
        time_str = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        checkpoint_dir = os.path.join(self.cfg.checkpoints_dir, time_str)
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        log_dir = os.path.join(self.cfg.logs_dir, time_str)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)

        self.cfg.save_config(time_str)

        self.writer_hyperparameter = tf.summary.create_file_writer(os.path.join(log_dir,"hyperparameter"))
        self.writer_train = tf.summary.create_file_writer(os.path.join(log_dir,"train"))
        self.writer_val = tf.summary.create_file_writer(os.path.join(log_dir,'validation'))

        print('\n----------- start to train -----------\n')
        with open(os.path.join(log_dir,'log.txt'),'w') as f:
            for ep in np.arange(1,self.cfg.epoch+1,1):
                # Initialize the progress bar
                self.progbar = Progbar(train_iter_num+1)
                print('Epoch {}/{}'.format(ep, self.cfg.epoch))

                # Train the model for one epoch
                train_loss,train_image_cls_acc = self.train_one_epoch\
                    (train_source_datagen,train_target_datagen,train_iter_num,ep)
                # Validate the model for one epoch
                val_loss,val_image_cls_acc = self.eval_one_epoch(val_target_datagen,val_iter_num,ep)
                # Update the progress bar
                self.progbar.update(train_iter_num+1, [('val_loss', val_loss),
                                                       ("val_image_acc", val_image_cls_acc)])
                # Reset the losses and metrics
                self.train_loss.reset_states()
                self.train_image_cls_loss.reset_states()
                self.train_domain_cls_loss.reset_states()
                self.train_image_cls_acc.reset_states()
                self.train_domain_cls_acc.reset_states()
                self.val_loss.reset_states()
                self.val_image_cls_loss.reset_states()
                self.val_domain_cls_loss.reset_states()
                self.val_image_cls_acc.reset_states()
                self.val_domain_cls_acc.reset_states()

                # Save the model for this epoch
                log_str = "Epoch{:03d}-train_loss-{:.3f}-val_loss-{:.3f}-train_image_cls_acc-{:.3f}-val_image_cls_acc-{:.3f}"\
                    .format(ep, train_loss, val_loss, train_image_cls_acc, val_image_cls_acc)
                print(log_str)
                f.write(log_str + "\n")           # Append to the log file
                self.dann_model.save(os.path.join(checkpoint_dir, log_str + ".h5"))

                '''
                # Decide whether to stop training early; the monitored metric is the
                # image classification accuracy on the target domain
                stop_training = self.early_stopping.on_epoch_end(ep, val_image_cls_acc)
                if stop_training:
                    break
                '''
        self.dann_model.save(os.path.join(checkpoint_dir, "trained_dann_mnist2mnist_m.h5"))
        print('\n----------- end to train -----------\n')

    def train_one_epoch(self,train_source_datagen,train_target_datagen,train_iter_num,ep):
        """
        这是一个周期模型训练的函数
        :param train_source_datagen: 源域训练数据集生成器
        :param train_target_datagen: 目标域训练数据集生成器
        :param train_iter_num: 一个训练周期的迭代次数
        :param ep: 当前训练周期
        :return:
        """
        for i in np.arange(1, train_iter_num + 1):
            # Fetch a mini-batch with its image labels and domain labels
            batch_mnist_image_data, batch_mnist_labels = train_source_datagen.__next__()  # train_source_datagen.next_batch()
            batch_mnist_m_image_data, batch_mnist_m_labels = train_target_datagen.__next__()  # train_target_datagen.next_batch()
            batch_domain_labels = np.vstack([np.tile([1., 0.], [len(batch_mnist_labels), 1]),
                                             np.tile([0., 1.], [len(batch_mnist_m_labels), 1])]).astype(np.float32)
            batch_image_data = np.concatenate([batch_mnist_image_data, batch_mnist_m_image_data], axis=0)
            # Update the learning rate and the GRL lambda for this step, then log both
            # (a sketch of these schedule helpers follows the class definition)
            global_iter = (ep - 1) * train_iter_num + i
            process = global_iter * 1.0 / (self.cfg.epoch * train_iter_num)
            self.grl_lambd = grl_lambda_schedule(process)
            learning_rate = learning_rate_schedule(process, init_learning_rate=self.cfg.init_learning_rate)
            tf.keras.backend.set_value(self.optimizer.lr, learning_rate)
            with self.writer_hyperparameter.as_default():
                tf.summary.scalar("hyperparameter/learning_rate", tf.convert_to_tensor(learning_rate), global_iter)
                tf.summary.scalar("hyperparameter/grl_lambda", tf.convert_to_tensor(self.grl_lambd), global_iter)

            # Record gradients of the image and domain classification losses
            with tf.GradientTape() as tape:
                # Image classification predictions, loss and accuracy (source domain only)
                image_cls_feature = self.feature_encoder(batch_mnist_image_data)
                image_cls_pred = self.image_cls_encoder(image_cls_feature,training=True)
                image_cls_loss = self.loss(batch_mnist_labels,image_cls_pred)
                image_cls_acc = self.acc(batch_mnist_labels, image_cls_pred)

                # Domain classification predictions, loss and accuracy (source + target)
                domain_cls_feature = self.feature_encoder(batch_image_data)
                domain_cls_pred = self.domain_cls_encoder(self.grl(domain_cls_feature, self.grl_lambd),
                                                          training=True)
                domain_cls_loss = self.loss(batch_domain_labels, domain_cls_pred)
                domain_cls_acc = self.acc(batch_domain_labels, domain_cls_pred)

                # Total training loss = image classification loss + domain classification loss
                loss = tf.reduce_mean(image_cls_loss) + tf.reduce_mean(domain_cls_loss)
            # Custom optimization step
            variables = tape.watched_variables()
            grads = tape.gradient(loss, variables)
            self.optimizer.apply_gradients(zip(grads, variables))

            # Accumulate the running losses and accuracies
            self.train_loss(loss)
            self.train_image_cls_loss(image_cls_loss)
            self.train_domain_cls_loss(domain_cls_loss)
            self.train_image_cls_acc(image_cls_acc)
            self.train_domain_cls_acc(domain_cls_acc)

            # Update the progress bar
            self.progbar.update(i, [('loss', loss),
                               ('image_cls_loss', image_cls_loss),
                               ('domain_cls_loss', domain_cls_loss),
                               ("image_acc", image_cls_acc),
                               ("domain_acc", domain_cls_acc)])
        # Log the training losses and metrics
        with self.writer_train.as_default():
            tf.summary.scalar("loss/loss", self.train_loss.result(), ep)
            tf.summary.scalar("loss/image_cls_loss", self.train_image_cls_loss.result(), ep)
            tf.summary.scalar("loss/domain_cls_loss", self.train_domain_cls_loss.result(), ep)
            tf.summary.scalar("acc/image_cls_acc", self.train_image_cls_acc.result(), ep)
            tf.summary.scalar("acc/domain_cls_acc", self.train_domain_cls_acc.result(), ep)

        return self.train_loss.result(),self.train_image_cls_acc.result()

    def eval_one_epoch(self,val_target_datagen,val_iter_num,ep):
        """
        这是一个周期的模型验证函数
        :param val_target_datagen: 目标域验证数据集生成器
        :param val_iter_num: 一个验证周期的迭代次数
        :param ep: 当前验证周期
        :return:
        """
        for i in np.arange(1, val_iter_num + 1):
            # Fetch a mini-batch with its image labels and domain labels
            batch_mnist_m_image_data, batch_mnist_m_labels = val_target_datagen.__next__()
            batch_mnist_m_domain_labels = np.tile([0., 1.], [len(batch_mnist_m_labels), 1]).astype(np.float32)

            # Image and domain classification predictions for the target-domain data
            target_image_feature = self.feature_encoder(batch_mnist_m_image_data)
            target_image_cls_pred = self.image_cls_encoder(target_image_feature, training=False)
            target_domain_cls_pred = self.domain_cls_encoder(target_image_feature, training=False)

            # Losses on the target-domain predictions
            target_image_cls_loss = self.loss(batch_mnist_m_labels,target_image_cls_pred)
            target_domain_cls_loss = self.loss(batch_mnist_m_domain_labels,target_domain_cls_pred)
            target_loss = tf.reduce_mean(target_image_cls_loss) + tf.reduce_mean(target_domain_cls_loss)
            # Accuracies on the target-domain predictions
            image_cls_acc = self.acc(batch_mnist_m_labels, target_image_cls_pred)
            domain_cls_acc = self.acc(batch_mnist_m_domain_labels, target_domain_cls_pred)

            # Update the validation losses and accuracies
            self.val_loss(target_loss)
            self.val_image_cls_loss(target_image_cls_loss)
            self.val_domain_cls_loss(target_domain_cls_loss)
            self.val_image_cls_acc(image_cls_acc)
            self.val_domain_cls_acc(domain_cls_acc)

        # Log the validation losses and metrics
        with self.writer_val.as_default():
            tf.summary.scalar("loss/loss", self.val_loss.result(), ep)
            tf.summary.scalar("loss/image_cls_loss", self.val_image_cls_loss.result(), ep)
            tf.summary.scalar("loss/domain_cls_loss", self.val_domain_cls_loss.result(), ep)
            tf.summary.scalar("acc/image_cls_acc", self.val_image_cls_acc.result(), ep)
            tf.summary.scalar("acc/domain_cls_acc", self.val_domain_cls_acc.result(), ep)
        return self.val_loss.result(), self.val_image_cls_acc.result()
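grl_lambda_schedule, learning_rate_schedule and GradientReversalLayer are used above but defined elsewhere; a hedged sketch of what such helpers typically look like under the standard DANN formulation (these bodies are assumptions, not the original implementations):

import math
import tensorflow as tf

def grl_lambda_schedule(process, gamma=10.0):
    # lambda_p = 2 / (1 + exp(-gamma * p)) - 1, ramping from 0 to 1 as training progresses
    return 2.0 / (1.0 + math.exp(-gamma * process)) - 1.0

def learning_rate_schedule(process, init_learning_rate=0.01, alpha=10.0, beta=0.75):
    # mu_p = mu_0 / (1 + alpha * p) ** beta, annealing the learning rate over training
    return init_learning_rate / (1.0 + alpha * process) ** beta

class GradientReversalLayer(tf.keras.layers.Layer):
    # Identity in the forward pass; flips and scales the gradient in the backward pass.
    def call(self, x, lambd=1.0):
        @tf.custom_gradient
        def _reverse(inputs):
            def grad(dy):
                return -lambd * dy
            return tf.identity(inputs), grad
        return _reverse(x)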
示例#30
0
    def train_gan(self,
                  train_dataset,
                  valid_dataset,
                  epochs=200000,
                  valid_lr=None,
                  valid_hr=None):
        evaluate_size = epochs // 10
        start = time.time()
        vgg_metric = Mean()
        dls_metric = Mean()
        g_metric = Mean()
        c_metric = Mean()
        epoch = 0

        for lr, hr in train_dataset.take(epochs):
            epoch += 1
            step = tf.convert_to_tensor(epoch, tf.int64)
            vgg_loss, discriminator_loss, generator_loss, content_loss = self.train_gan_step(
                lr, hr)
            vgg_metric(vgg_loss)
            dls_metric(discriminator_loss)
            g_metric(generator_loss)
            c_metric(content_loss)

            if epoch % 50 == 0:
                vgg = vgg_metric.result()
                discriminator_loss_metric = dls_metric.result()
                generator_loss_metric = g_metric.result()
                content_loss_metric = c_metric.result()

                vgg_metric.reset_states()
                dls_metric.reset_states()
                g_metric.reset_states()
                c_metric.reset_states()

                psnr_value = self.evaluate(valid_dataset.take(1))

                print(
                    f'Time for epoch {epoch}/{epochs} is {(time.time() - start):.4f} sec, '
                    f' perceptual loss = {vgg:.4f},'
                    f' generator loss = {generator_loss_metric:.4f},'
                    f' discriminator loss = {discriminator_loss_metric:.4f},'
                    f' content loss = {content_loss_metric:.4f},'
                    f' psnr = {psnr_value:.4f}')

                start = time.time()

                if self.summary_writer is not None:
                    with self.summary_writer.as_default():
                        tf.summary.scalar('generator_loss',
                                          generator_loss_metric,
                                          step=epoch)
                        tf.summary.scalar('content loss',
                                          content_loss_metric,
                                          step=epoch)
                        tf.summary.scalar(
                            'vgg loss = content loss + 0.0001 * gan loss',
                            vgg,
                            step=epoch)
                        tf.summary.scalar('discriminator_loss',
                                          discriminator_loss_metric,
                                          step=epoch)
                        tf.summary.scalar('psnr', psnr_value, step=epoch)

            if epoch % evaluate_size == 0:
                self.util.save_checkpoint(self.checkpoint, epoch)

            if epoch % 5000 == 0:
                self.generate_and_save_images(step, valid_lr, valid_hr)
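The scalar label above spells out how the perceptual (vgg) loss is assumed to be assembled inside train_gan_step: a content loss on VGG feature maps plus a small adversarial term. A minimal sketch of that combination (the function and argument names are illustrative, not part of the original code):

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=False)
mse = tf.keras.losses.MeanSquaredError()

def perceptual_loss(vgg, hr, sr, sr_disc_output):
    # Content loss: MSE between VGG feature maps of the HR and SR images
    content_loss = mse(vgg(hr), vgg(sr))
    # Adversarial (generator) loss: how well SR images fool the discriminator
    generator_loss = bce(tf.ones_like(sr_disc_output), sr_disc_output)
    # Perceptual loss as logged above: content loss + 0.0001 * gan loss
    return content_loss + 1e-4 * generator_loss, content_loss, generator_loss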