Esempio n. 1
0
    def load_checkpoint(self, pretrained_path):
        """Restore trainer state from a checkpoint and reload both networks' weights.

        Restores ``self.ckpt`` from ``pretrained_path``, copies the saved step
        and epoch counters onto the trainer, and re-attaches the generator and
        discriminator optimizers. It then re-synchronizes their ``iterations``
        counters and loads the ``generator-<steps>.h5`` /
        ``discriminator-<steps>.h5`` weight files from ``self.saved_path``.

        Args:
            pretrained_path: Checkpoint path handed to ``self.ckpt.restore``.
        """
        self.ckpt.restore(pretrained_path)
        self.steps = self.ckpt.steps.numpy()
        self.epochs = self.ckpt.epochs.numpy()

        # Re-assign iterations (global steps) for gen_optimizer so it matches
        # the restored step counter.
        self._gen_optimizer = self.ckpt.gen_optimizer
        self._gen_optimizer.iterations.assign(tf.cast(self.steps, tf.int64))

        # Re-assign iterations (global steps) for dis_optimizer. Only the config
        # lookup is guarded. A missing key means "no offset". A bad value or a
        # failing computation must surface instead of silently falling back to
        # ``self.steps`` (which a broad ``except Exception`` used to do).
        try:
            dis_start_steps = self.config["discriminator_train_start_steps"]
        except KeyError:
            dis_iterations = self.steps
        else:
            # Presumably the discriminator is only updated once training passes
            # ``dis_start_steps``; clamp at 0 so checkpoints saved before that
            # point don't yield a negative iteration count. Plain ``max`` on host
            # scalars avoids TF dtype promotion between int32 and int64 inputs.
            dis_iterations = max(0, self.steps - dis_start_steps)
        self._dis_optimizer = self.ckpt.dis_optimizer
        self._dis_optimizer.iterations.assign(tf.cast(dis_iterations, tf.int64))

        # Load weights. ``self.saved_path`` is concatenated directly, so it is
        # assumed to already end with a path separator — TODO confirm.
        utils.load_weights(
            self._generator,
            self.saved_path + "generator-{}.h5".format(self.steps))
        utils.load_weights(
            self._discriminator,
            self.saved_path + "discriminator-{}.h5".format(self.steps))
Esempio n. 2
0
    def load_checkpoint(self, pretrained_path):
        """Restore trainer state from ``pretrained_path`` and reload model weights.

        After restoring ``self.ckpt``, the saved step/epoch counters are copied
        onto the trainer. The checkpoint's optimizer is re-attached and its
        ``iterations`` counter is set to the restored step. Finally the
        ``model-<steps>.h5`` weights are loaded from ``self.saved_path``.

        Args:
            pretrained_path: Checkpoint path handed to ``self.ckpt.restore``.
        """
        checkpoint = self.ckpt
        checkpoint.restore(pretrained_path)
        self.steps = checkpoint.steps.numpy()
        self.epochs = checkpoint.epochs.numpy()

        # Keep the optimizer's global-step counter in sync with the restored steps.
        restored_optimizer = checkpoint.optimizer
        self._optimizer = restored_optimizer
        restored_optimizer.iterations.assign(tf.cast(self.steps, tf.int64))

        # Weight file for this step; saved_path is used as a plain string prefix.
        weights_name = "model-{}.h5".format(self.steps)
        utils.load_weights(self._model, self.saved_path + weights_name)