def save_checkpoint(checkpoint, training_phase):
  """Saves a checkpoint for the given training phase.

  The target directory is the settings entry
  ``Settings()["checkpoint_path"][training_phase]``. Checkpoint files are
  written inside that directory, using the directory's own name as the
  file prefix.

  Args:
    checkpoint: tf.train.Checkpoint object to save.
    training_phase: The training phase of the model to store the checkpoint
      for. Can be one of the two "phase_1" or "phase_2".
  """
  dir_ = settings.Settings()["checkpoint_path"][training_phase]
  # Strip trailing separators before taking the basename. For a configured
  # path such as "ckpt/phase_1/", os.path.basename() would return "" and the
  # file prefix would end up with an empty name component.
  prefix = os.path.join(dir_, os.path.basename(dir_.rstrip("/" + os.sep)))
  checkpoint.save(file_prefix=prefix)
def save_checkpoint(checkpoint, training_phase, basepath=""):
  """Saves a checkpoint for the given training phase.

  The target directory is the settings entry
  ``Settings()["checkpoint_path"][training_phase]``, optionally joined onto
  ``basepath``. Checkpoint files are written inside that directory, using
  the directory's own name as the file prefix.

  Args:
    checkpoint: tf.train.Checkpoint object to save.
    training_phase: The training phase of the model to store the checkpoint
      for. Can be one of the two "phase_1" or "phase_2".
    basepath: Base path to store checkpoints under. When empty, the
      configured checkpoint path is used as-is.
  """
  dir_ = settings.Settings()["checkpoint_path"][training_phase]
  if basepath:
    dir_ = os.path.join(basepath, dir_)
  # Strip trailing separators before taking the basename. For a path such as
  # "ckpt/phase_1/", os.path.basename() would return "" and the file prefix
  # would end up with an empty name component.
  prefix = os.path.join(dir_, os.path.basename(dir_.rstrip("/" + os.sep)))
  checkpoint.save(file_prefix=prefix)
  # Lazy %-args: the message is only formatted if DEBUG is enabled.
  logging.debug("Prefix: %s. checkpoint saved successfully!", prefix)
def load_checkpoint(checkpoint, training_phase):
  """Restores the latest checkpoint for the given training phase, if any.

  The checkpoint directory is the settings entry
  ``Settings()["checkpoint_path"][training_phase]``.

  Args:
    checkpoint: tf.train.Checkpoint object to restore into.
    training_phase: The training phase of the model to load the checkpoint
      for. Can be one of the two "phase_1" or "phase_2".

  Returns:
    The status object returned by ``checkpoint.restore(...)`` when a
    checkpoint was found in the directory, otherwise None.
  """
  logging.info("Loading check point for: %s", training_phase)
  dir_ = settings.Settings()["checkpoint_path"][training_phase]
  # tf.train.latest_checkpoint() resolves the newest checkpoint through the
  # "checkpoint" state file, so its absence means nothing has been saved yet.
  if not tf.io.gfile.exists(os.path.join(dir_, "checkpoint")):
    # Return None explicitly instead of reaching a `return status` whose
    # name was never bound.
    logging.info("No checkpoint found at: %s", dir_)
    return None
  logging.info("Found checkpoint at: %s", dir_)
  return checkpoint.restore(tf.train.latest_checkpoint(dir_))
def load_checkpoint(checkpoint, training_phase, basepath=""):
  """Restores the latest checkpoint for the given training phase, if any.

  The checkpoint directory is the settings entry
  ``Settings()["checkpoint_path"][training_phase]``, optionally joined onto
  ``basepath``.

  Args:
    checkpoint: tf.train.Checkpoint object to restore into.
    training_phase: The training phase of the model to load the checkpoint
      for. Can be one of the two "phase_1" or "phase_2".
    basepath: Base path to load checkpoints from. When empty, the configured
      checkpoint path is used as-is.

  Returns:
    The status object returned by ``checkpoint.restore(...)`` when a
    checkpoint was found in the directory, otherwise None.
  """
  logging.info("Loading check point for: %s", training_phase)
  dir_ = settings.Settings()["checkpoint_path"][training_phase]
  if basepath:
    dir_ = os.path.join(basepath, dir_)
  # tf.train.latest_checkpoint() resolves the newest checkpoint through the
  # "checkpoint" state file, so its absence means nothing has been saved yet.
  if not tf.io.gfile.exists(os.path.join(dir_, "checkpoint")):
    # Return None explicitly instead of reaching a `return status` whose
    # name was never bound.
    logging.info("No checkpoint found at: %s", dir_)
    return None
  logging.info("Found checkpoint at: %s", dir_)
  return checkpoint.restore(tf.train.latest_checkpoint(dir_))
def main(**kwargs):
  """Trains the ESRGAN model and exports it as a SavedModel 2.0.

  All arguments are passed as keyword arguments:
    config: path to the config yaml file.
    log_dir: directory where TensorBoard summaries are written.
    data_dir: directory used to store / access the dataset.
    manual: whether data_dir is a manual directory.
    model_dir: directory the exported SavedModel is written to.
    phases: underscore-separated phases to run, e.g. "phase1_phase2".
  """
  # Let each GPU allocate memory on demand.
  for gpu in tf.config.experimental.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

  run_settings = settings.Settings(kwargs["config"])
  # Phase-completion flags; presumably persisted in stats.yaml so finished
  # phases are skipped on later runs.
  stats = settings.Stats(os.path.join(run_settings.path, "stats.yaml"))
  writer = tf.summary.create_file_writer(kwargs["log_dir"])
  profiler.start_profiler_server(6009)

  generator = model.RRDBNet(out_channel=3)
  discriminator = model.VGGArch()
  trainer = train.Trainer(
      summary_writer=writer,
      settings=run_settings,
      data_dir=kwargs["data_dir"],
      manual=kwargs["manual"])

  requested_phases = [
      phase.strip() for phase in kwargs["phases"].lower().split("_")]

  if not stats["train_step_1"] and "phase1" in requested_phases:
    logging.info("starting phase 1")
    trainer.warmup_generator(generator)
    stats["train_step_1"] = True

  if not stats["train_step_2"] and "phase2" in requested_phases:
    logging.info("starting phase 2")
    trainer.train_gan(generator, discriminator)
    stats["train_step_2"] = True

  # Export only once both phases are marked complete.
  if stats["train_step_1"] and stats["train_step_2"]:
    exported_generator = utils.interpolate_generator(
        partial(model.RRDBNet, out_channel=3),
        discriminator,
        run_settings["interpolation_parameter"],
        run_settings["dataset"]["hr_dimension"])
    tf.saved_model.save(exported_generator, kwargs["model_dir"])
def __init__(self, out_features=32, bias=True):
  super(RDB, self).__init__()
  # Every convolution in this block shares one configuration:
  # `out_features` filters, 3x3 kernel, stride 1, "same" padding.
  conv_factory = partial(
      tf.keras.layers.Conv2D,
      out_features,
      kernel_size=[3, 3],
      strides=[1, 1],
      padding="same",
      use_bias=bias)
  # Layers are created in key order, conv_1 through conv_5.
  self._conv2d_layers = {
      layer_name: conv_factory()
      for layer_name in ("conv_1", "conv_2", "conv_3", "conv_4", "conv_5")
  }
  self._lrelu = tf.keras.layers.LeakyReLU(alpha=0.2)
  # Residual scaling factor. Falls back to 0.2 when the "RDB" settings
  # section has no residual_scale_beta entry.
  self._beta = settings.Settings()["RDB"].get("residual_scale_beta", 0.2)
# distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import functools import os import sys PATH = os.path.dirname(os.path.realpath(__file__)) sys.path.insert(1, os.path.join(PATH, 'lib')) from lib import settings, om_manager, configure_opsman_director, configure_ert, sqs, wait_condition from lib import util, accept_eula, download_and_import my_settings = settings.Settings() asset_path = '/home/ubuntu/tiles' max_retries = 5 def check_exit_code_success(exit_code): print("exit_code {}".format(exit_code)) return exit_code == 0 def check_cr_return_code(out, err, return_code, step_name): print("Ran: {}; exit code: {}".format(step_name, exit_code)) if return_code != 0: util.exponential_backoff( functools.partial(sqs.report_cr_creation_failure, my_settings,
def main(**kwargs):
  """Trains the ESRGAN model and exports it as a SavedModel 2.0.

  All arguments are passed as keyword arguments:
    config: path to the config yaml file.
    log_dir: directory for TensorBoard summaries. Phase 1 and phase 2
      summaries go to its "phase1" and "phase2" subdirectories.
    data_dir: directory used to store / access the dataset.
    manual: whether data_dir is a manual directory.
    model_dir: base directory of the run. It is passed to the Trainer and as
      `basepath` to utils.interpolate_generator, and the SavedModel is
      exported to its "esrgan" subdirectory.
    tpu: TPU target. It is passed to utils.assign_to_worker and, when
      truthy, to TPUClusterResolver to run under a TPUStrategy.
    export_only: when truthy, skip model construction and training and only
      attempt the export.
    phases: underscore-separated phases to run, e.g. "phase1_phase2".
  """
  # Let each GPU allocate memory on demand.
  for gpu in tf.config.experimental.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

  strategy = utils.SingleDeviceStrategy()
  device_scope = utils.assign_to_worker(kwargs["tpu"])
  run_settings = settings.Settings(kwargs["config"])
  # Phase-completion flags; presumably persisted in stats.yaml so finished
  # phases are skipped on later runs.
  stats = settings.Stats(os.path.join(run_settings.path, "stats.yaml"))
  tf.random.set_seed(10)

  tpu_target = kwargs["tpu"]
  if tpu_target:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_target)
    tf.config.experimental_connect_to_host(resolver.get_master())
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

  # NOTE(review): the original was collapsed onto one line; the export step
  # is assumed to sit inside this device/strategy scope — confirm.
  with tf.device(device_scope), strategy.scope():
    phase1_writer = tf.summary.create_file_writer(
        os.path.join(kwargs["log_dir"], "phase1"))
    phase2_writer = tf.summary.create_file_writer(
        os.path.join(kwargs["log_dir"], "phase2"))
    discriminator = model.VGGArch(
        batch_size=run_settings["batch_size"], num_features=64)

    if not kwargs["export_only"]:
      generator = model.RRDBNet(out_channel=3)
      # One forward pass on random input, presumably so the generator's
      # convolution variables are created before training.
      logging.debug("Initiating Convolutions")
      generator.unsigned_call(tf.random.normal([1, 128, 128, 3]))
      trainer = train.Trainer(
          summary_writer=phase1_writer,
          summary_writer_2=phase2_writer,
          settings=run_settings,
          model_dir=kwargs["model_dir"],
          data_dir=kwargs["data_dir"],
          manual=kwargs["manual"],
          strategy=strategy)
      requested_phases = [
          phase.strip() for phase in kwargs["phases"].lower().split("_")]

      if not stats["train_step_1"] and "phase1" in requested_phases:
        logging.info("starting phase 1")
        trainer.warmup_generator(generator)
        stats["train_step_1"] = True

      if not stats["train_step_2"] and "phase2" in requested_phases:
        logging.info("starting phase 2")
        trainer.train_gan(generator, discriminator)
        stats["train_step_2"] = True

    # Export only once both phases are marked complete. This is reachable on
    # the export-only path as well.
    if stats["train_step_1"] and stats["train_step_2"]:
      exported_generator = utils.interpolate_generator(
          partial(model.RRDBNet, out_channel=3, first_call=False),
          discriminator,
          run_settings["interpolation_parameter"],
          [720, 1080],
          basepath=kwargs["model_dir"])
      tf.saved_model.save(
          exported_generator, os.path.join(kwargs["model_dir"], "esrgan"))
def __init__(self, out_features=32, first_call=True):
  super(RRDB, self).__init__()
  # Three identically configured RDB sub-blocks. They are constructed in
  # order and assigned left to right to RDB1, RDB2 and RDB3.
  self.RDB1, self.RDB2, self.RDB3 = (
      RDB(out_features, first_call=first_call) for _ in range(3))
  # Residual scaling factor. Falls back to 0.2 when the "RDB" settings
  # section has no residual_scale_beta entry.
  self.beta = settings.Settings()["RDB"].get("residual_scale_beta", 0.2)
def cli(ctx):
  # Store the project settings on the context object under 'settings'.
  # A comment rather than a docstring: if this is a click command, its
  # docstring would become the --help text.
  ctx.obj['settings'] = settings.Settings()
"""Load the teacher models from the teacher directory.

Exposes ``generator`` (``lib.model.RRDBNet``) and ``discriminator``
(``lib.model.VGGArch``) as module-level names.
"""
import os
import sys

from libs import teacher_imports
from lib import settings

# Initialize settings with the teacher's config.yaml *before* importing
# lib.model.
# NOTE(review): this ordering presumably matters because lib.model reads
# settings.Settings() without a path, and Settings keeps the first config it
# was given — confirm against lib.settings.
settings.Settings("%s/config/config.yaml" % teacher_imports.TEACHER_DIR)

from lib.model import VGGArch as discriminator
from lib.model import RRDBNet as generator