Beispiel #1
0
def save_checkpoint(checkpoint, training_phase):
    """Save ``checkpoint`` into the directory configured for ``training_phase``.

    The configured directory comes from ``settings.Settings()["checkpoint_path"]``.
    The file prefix handed to ``checkpoint.save`` is
    ``<checkpoint_dir>/<basename of checkpoint_dir>``.

    Args:
      checkpoint: tf.train.Checkpoint object
      training_phase: The training phase of the model to load/store the
        checkpoint for; can be one of the two "phase_1" or "phase_2".
    """
    checkpoint_dir = settings.Settings()["checkpoint_path"][training_phase]
    prefix_name = os.path.basename(checkpoint_dir)
    file_prefix = os.path.join(checkpoint_dir, prefix_name)
    checkpoint.save(file_prefix=file_prefix)
Beispiel #2
0
def save_checkpoint(checkpoint, training_phase, basepath=""):
    """ Saves checkpoint.
      Args:
        checkpoint: tf.train.Checkpoint object
        training_phase: The training phase of the model to load/store the checkpoint for.
                        can be one of the two "phase_1" or "phase_2"
        basepath: Base path to save checkpoints under. When non-empty it is
                  prepended to the configured checkpoint directory.
    """
    dir_ = settings.Settings()["checkpoint_path"][training_phase]
    if basepath:
        dir_ = os.path.join(basepath, dir_)
    # The file prefix reuses the directory's own name,
    # i.e. <checkpoint_dir>/<basename of checkpoint_dir>.
    dir_ = os.path.join(dir_, os.path.basename(dir_))
    checkpoint.save(file_prefix=dir_)
    # Lazy %-args: the message is only formatted if DEBUG is enabled.
    logging.debug("Prefix: %s. checkpoint saved successfully!", dir_)
Beispiel #3
0
def load_checkpoint(checkpoint, training_phase):
    """ Loads the latest checkpoint for a training phase, if one exists.
      Args:
        checkpoint: tf.train.Checkpoint object to restore into.
        training_phase: The training phase of the model to load/store the checkpoint for.
                        can be one of the two "phase_1" or "phase_2"
      Returns:
        The status object returned by `checkpoint.restore` when a "checkpoint"
        file exists in the configured directory; otherwise None.
    """
    logging.info("Loading check point for: %s", training_phase)
    dir_ = settings.Settings()["checkpoint_path"][training_phase]
    if not tf.io.gfile.exists(os.path.join(dir_, "checkpoint")):
        # Nothing has been saved for this phase yet; callers receive None.
        return None
    logging.info("Found checkpoint at: %s", dir_)
    return checkpoint.restore(tf.train.latest_checkpoint(dir_))
Beispiel #4
0
def load_checkpoint(checkpoint, training_phase, basepath=""):
    """ Loads the latest checkpoint for a training phase, if one exists.
      Args:
        checkpoint: tf.train.Checkpoint object to restore into.
        training_phase: The training phase of the model to load/store the checkpoint for.
                        can be one of the two "phase_1" or "phase_2"
        basepath: Base Path to load checkpoints from. When non-empty it is
                  prepended to the configured checkpoint directory.
      Returns:
        The status object returned by `checkpoint.restore` when a "checkpoint"
        file exists in the resolved directory; otherwise None.
    """
    logging.info("Loading check point for: %s", training_phase)
    dir_ = settings.Settings()["checkpoint_path"][training_phase]
    if basepath:
        dir_ = os.path.join(basepath, dir_)
    if not tf.io.gfile.exists(os.path.join(dir_, "checkpoint")):
        # Nothing has been saved for this phase yet; callers receive None.
        return None
    logging.info("Found checkpoint at: %s", dir_)
    return checkpoint.restore(tf.train.latest_checkpoint(dir_))
Beispiel #5
0
def main(**kwargs):
  """ Main function for training ESRGAN model and exporting it as a SavedModel2.0
      Args:
        config: path to config yaml file.
        log_dir: directory to store summary for tensorboard.
        data_dir: directory to store / access the dataset.
        manual: boolean to denote if data_dir is a manual directory.
        model_dir: directory to store the model into.
        phases: "_"-separated names of the phases to run, e.g. "phase1_phase2".
                Matching is case-insensitive and surrounding whitespace of
                each name is ignored.
  """

  # Let TensorFlow allocate GPU memory on demand rather than all up front.
  for physical_device in tf.config.experimental.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(physical_device, True)

  sett = settings.Settings(kwargs["config"])
  # Per-phase completion flags backed by <sett.path>/stats.yaml.
  # NOTE(review): presumably persisted so reruns skip finished phases — confirm.
  stats = settings.Stats(os.path.join(sett.path, "stats.yaml"))
  summary_writer = tf.summary.create_file_writer(kwargs["log_dir"])
  profiler.start_profiler_server(6009)
  generator = model.RRDBNet(out_channel=3)
  discriminator = model.VGGArch()

  training = train.Trainer(
      summary_writer=summary_writer,
      settings=sett,
      data_dir=kwargs["data_dir"],
      manual=kwargs["manual"])
  phases = [phase.strip() for phase in kwargs["phases"].lower().split("_")]

  if not stats["train_step_1"] and "phase1" in phases:
    logging.info("starting phase 1")
    training.warmup_generator(generator)
    stats["train_step_1"] = True
  if not stats["train_step_2"] and "phase2" in phases:
    logging.info("starting phase 2")
    training.train_gan(generator, discriminator)
    stats["train_step_2"] = True

  # Export only when both phase-completion flags are set.
  if stats["train_step_1"] and stats["train_step_2"]:
    # Attempting to save "Interpolated" Model as SavedModel2.0
    interpolated_generator = utils.interpolate_generator(
        partial(model.RRDBNet, out_channel=3),
        discriminator,
        sett["interpolation_parameter"],
        sett["dataset"]["hr_dimension"])
    tf.saved_model.save(interpolated_generator, kwargs["model_dir"])
Beispiel #6
0
    def __init__(self, out_features=32, bias=True):
        """Build the block's five 3x3 convolutions, activation and residual scale.

        Args:
          out_features: number of filters used by every convolution.
          bias: whether every convolution uses a bias term.
        """
        super(RDB, self).__init__()

        def _new_conv():
            # All five convolutions share one 3x3 / stride-1 / "same" config.
            return tf.keras.layers.Conv2D(out_features,
                                          kernel_size=[3, 3],
                                          strides=[1, 1],
                                          padding="same",
                                          use_bias=bias)

        layer_keys = ("conv_1", "conv_2", "conv_3", "conv_4", "conv_5")
        # Layers are created in key order, conv_1 first.
        self._conv2d_layers = {key: _new_conv() for key in layer_keys}
        self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
        # Residual scale read from config, falling back to 0.2 when unset.
        self._beta = settings.Settings()["RDB"].get("residual_scale_beta", 0.2)
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import sys

PATH = os.path.dirname(os.path.realpath(__file__))
sys.path.insert(1, os.path.join(PATH, 'lib'))

from lib import settings, om_manager, configure_opsman_director, configure_ert, sqs, wait_condition
from lib import util, accept_eula, download_and_import

# Settings object created once at import time; used by the helpers below
# (e.g. passed to the SQS failure report).
my_settings = settings.Settings()
# NOTE(review): presumably the local directory holding tile artifacts — confirm.
asset_path = '/home/ubuntu/tiles'

# NOTE(review): presumably the retry limit for flaky steps — confirm with callers.
max_retries = 5


def check_exit_code_success(exit_code):
    """Print ``exit_code`` and return whether it equals zero (success)."""
    message = "exit_code {}".format(exit_code)
    print(message)
    return exit_code == 0


def check_cr_return_code(out, err, return_code, step_name):
    print("Ran: {}; exit code: {}".format(step_name, exit_code))
    if return_code != 0:
        util.exponential_backoff(
            functools.partial(sqs.report_cr_creation_failure, my_settings,
Beispiel #8
0
def main(**kwargs):
    """ Main function for training ESRGAN model and exporting it as a SavedModel2.0
      Args:
        config: path to config yaml file.
        log_dir: directory to store summary for tensorboard; per-phase
                 writers log to its "phase1" and "phase2" subdirectories.
        data_dir: directory to store / access the dataset.
        manual: boolean to denote if data_dir is a manual directory.
        model_dir: directory to store the model into; the exported
                   SavedModel goes to <model_dir>/esrgan.
        tpu: TPU passed to TPUClusterResolver. When falsy, the
             utils.SingleDeviceStrategy is used instead.
        phases: "_"-separated names of the phases to run, e.g. "phase1_phase2"
                (case-insensitive; surrounding whitespace is ignored).
        export_only: when truthy, skip training entirely and only attempt
                     the export step.
  """

    # Let TensorFlow allocate GPU memory on demand rather than all up front.
    for physical_device in tf.config.experimental.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(physical_device, True)
    strategy = utils.SingleDeviceStrategy()
    scope = utils.assign_to_worker(kwargs["tpu"])
    sett = settings.Settings(kwargs["config"])
    # Per-phase completion flags backed by <sett.path>/stats.yaml.
    # NOTE(review): presumably persisted so reruns skip finished phases — confirm.
    stats = settings.Stats(os.path.join(sett.path, "stats.yaml"))
    tf.random.set_seed(10)
    if kwargs["tpu"]:
        cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            kwargs["tpu"])
        tf.config.experimental_connect_to_host(cluster_resolver.get_master())
        tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    with tf.device(scope), strategy.scope():
        summary_writer_1 = tf.summary.create_file_writer(
            os.path.join(kwargs["log_dir"], "phase1"))
        summary_writer_2 = tf.summary.create_file_writer(
            os.path.join(kwargs["log_dir"], "phase2"))
        # profiler.start_profiler_server(6009)
        # Created unconditionally: the export step below also needs it.
        discriminator = model.VGGArch(batch_size=sett["batch_size"],
                                      num_features=64)
        if not kwargs["export_only"]:
            generator = model.RRDBNet(out_channel=3)
            logging.debug("Initiating Convolutions")
            # Dummy forward pass so the generator's variables get created.
            generator.unsigned_call(tf.random.normal([1, 128, 128, 3]))
            training = train.Trainer(summary_writer=summary_writer_1,
                                     summary_writer_2=summary_writer_2,
                                     settings=sett,
                                     model_dir=kwargs["model_dir"],
                                     data_dir=kwargs["data_dir"],
                                     manual=kwargs["manual"],
                                     strategy=strategy)
            phases = [
                phase.strip() for phase in kwargs["phases"].lower().split("_")
            ]
            if not stats["train_step_1"] and "phase1" in phases:
                logging.info("starting phase 1")
                training.warmup_generator(generator)
                stats["train_step_1"] = True
            if not stats["train_step_2"] and "phase2" in phases:
                logging.info("starting phase 2")
                training.train_gan(generator, discriminator)
                stats["train_step_2"] = True

    # Export only when both phase-completion flags are set.
    if stats["train_step_1"] and stats["train_step_2"]:
        # Attempting to save "Interpolated" Model as SavedModel2.0
        # NOTE(review): the [720, 1080] dimension is hard-coded here; its
        # meaning is defined by utils.interpolate_generator — confirm.
        interpolated_generator = utils.interpolate_generator(
            partial(model.RRDBNet, out_channel=3, first_call=False),
            discriminator,
            sett["interpolation_parameter"], [720, 1080],
            basepath=kwargs["model_dir"])
        tf.saved_model.save(interpolated_generator,
                            os.path.join(kwargs["model_dir"], "esrgan"))
Beispiel #9
0
 def __init__(self, out_features=32, first_call=True):
     """Create the three RDB sub-blocks and read the residual scale.

     Args:
       out_features: forwarded to every RDB as its feature count.
       first_call: forwarded unchanged to every RDB.
     """
     super(RRDB, self).__init__()

     def _new_rdb():
         return RDB(out_features, first_call=first_call)

     # Sub-blocks are created in order RDB1, RDB2, RDB3.
     self.RDB1 = _new_rdb()
     self.RDB2 = _new_rdb()
     self.RDB3 = _new_rdb()
     # Residual scale from config, falling back to 0.2 when unset.
     self.beta = settings.Settings()["RDB"].get("residual_scale_beta", 0.2)
def cli(ctx):
    """Store a newly constructed ``settings.Settings`` in ``ctx.obj['settings']``."""
    ctx.obj['settings'] = settings.Settings()
Beispiel #11
0
""" Module to load Teacher Models from Teacher Directory """
import os
import sys
from libs import teacher_imports
from lib import settings
# Initialize Settings from the teacher's config/config.yaml before lib.model is
# imported below. NOTE(review): presumably lib.model relies on this configuration
# being loaded first, which is why this call sits between imports — confirm.
settings.Settings("%s/config/config.yaml" % teacher_imports.TEACHER_DIR)
from lib.model import VGGArch as discriminator
from lib.model import RRDBNet as generator