def load_model(filepath, compile=True, **kwargs):
    """Load a Keras model from ``filepath`` and wrap it in a data-parallel ``tnt.Model``.

    Args:
        filepath: Location of the saved model, forwarded to
            ``tf.keras.models.load_model``.
        compile: Forwarded to ``tf.keras.models.load_model``. When True, the
            function also tries to wrap the loaded model's optimizer in a
            ``tnt.distributed_optimizers.SynchDistributedOptimizer``.
        **kwargs: Extra keyword arguments forwarded to
            ``tf.keras.models.load_model``.

    Returns:
        A ``tnt.Model`` built with ``tnt.ParallelStrategy.DATA``.
    """
    logger.debug("Load model from file: {}".format(filepath))
    keras_model = tf.keras.models.load_model(filepath, compile=compile, **kwargs)
    # FIXME load models with any type of parallelization strategy
    logger.warning("Loading model with the default `data parallel` strategy.")
    tnt_model = tnt.Model(keras_model, parallel_strategy=tnt.ParallelStrategy.DATA)
    if compile:
        try:
            tnt_optimizer = tnt.distributed_optimizers.SynchDistributedOptimizer(
                keras_model.optimizer, group=tnt_model.group)
            tnt_model.dist_optimizer = tnt_optimizer
            tnt_model._set_internal_optimizer(tnt_model.dist_optimizer)
            tnt_model.compiled = True
            tnt_model.done_broadcast = True
            if version_utils.tf_version_below_equal('2.1'):
                tnt_model.model._experimental_run_tf_function = False
                logger.info("Set `experimental_run_tf_function` to False.")
        except Exception:
            # Best effort: any failure while wrapping the optimizer is treated
            # as "the saved model carried no compiled state". `Exception`
            # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
            logger.info("The loaded model was not pre-compiled.")
    tnt_model.barrier.execute()
    return tnt_model
def model_from_yaml(yaml_string, **kwargs):
    """Deserialize a Keras model from YAML and wrap it in a data-parallel ``tnt.Model``.

    Args:
        yaml_string: YAML model description, forwarded to
            ``tf.keras.models.model_from_yaml``.
        **kwargs: Extra keyword arguments forwarded to
            ``tf.keras.models.model_from_yaml``.

    Returns:
        A ``tnt.Model`` built with ``tnt.ParallelStrategy.DATA``.

    Raises:
        RuntimeError: if deserialization or wrapping fails. The original
            exception is chained as ``__cause__`` so its message is preserved.
    """
    logger.debug("Load model from yaml")
    try:
        keras_model = tf.keras.models.model_from_yaml(yaml_string, **kwargs)
        # FIXME load models with any type of parallelization strategy
        logger.warning(
            "Loading model with the default `data parallel` strategy.")
        return tnt.Model(keras_model, parallel_strategy=tnt.ParallelStrategy.DATA)
    except Exception as exc:
        # Chain the cause: the underlying TF error (e.g. unsupported YAML
        # loading in newer TF versions) is far more informative than this message.
        raise RuntimeError("[tnt.models.model_from_yaml] Cannot load model") from exc
def clone_model(model, **kwargs):
    """Clone a ``tnt.Model`` or ``tf.keras.Model`` into a new data-parallel ``tnt.Model``.

    Args:
        model: Either a ``tnt.strategy.parallel_model.ParallelModel`` (whose
            wrapped ``.model`` is cloned) or a plain ``tf.keras.Model``.
        **kwargs: Extra keyword arguments forwarded to
            ``tf.keras.models.clone_model``.

    Returns:
        A ``tnt.Model`` built with ``tnt.ParallelStrategy.DATA`` around the clone.

    Raises:
        ValueError: if ``model`` is neither of the supported types.
    """
    # Check ParallelModel first so a tnt wrapper is unwrapped before cloning.
    if isinstance(model, tnt.strategy.parallel_model.ParallelModel):
        keras_model = tf.keras.models.clone_model(model.model, **kwargs)
        logger.info("clone model from instance of tnt.Model")
    elif isinstance(model, tf.keras.Model):
        keras_model = tf.keras.models.clone_model(model, **kwargs)
        logger.info("clone model from instance of tf.keras.Model")
    else:
        # A single message string; passing two positional args would render
        # the message as a tuple.
        raise ValueError("[tnt.models.clone_model] `model` needs to be either "
                         "a `tf.keras.Model`, or a `tnt.Model`")
    # FIXME load models with any type of parallelization strategy
    logger.warning("Loading model with the default `data parallel` strategy.")
    return tnt.Model(keras_model, parallel_strategy=tnt.ParallelStrategy.DATA)
def from_config(cls, config, **kwargs):
    """Rebuild a ``keras.Sequential`` model from ``config`` and wrap it in a ``tnt.Model``.

    Args:
        cls: The class this constructor is invoked on (unused in the body).
        config: Model configuration forwarded to
            ``tf.keras.Sequential.from_config``.
        **kwargs: Extra keyword arguments forwarded to
            ``tf.keras.Sequential.from_config``.

    Returns:
        A ``tnt.Model`` built with ``tnt.ParallelStrategy.DATA``.

    Raises:
        RuntimeError: if ``config`` cannot be loaded as a ``keras.Sequential``
            model. The underlying exception is chained as ``__cause__``.
    """
    try:
        keras_model = tf.keras.Sequential.from_config(config, **kwargs)
        logger.info("Loaded model from `keras.Sequential`.")
    except Exception as exc:
        raise RuntimeError(
            """[tnt.keras.Sequential.from_config] Cannot load model; provided configuration is not a `keras.Sequential` model."""
        ) from exc
    # FIXME load models with any type of parallelization strategy
    logger.warning(
        "Loading model with the default `data parallel` strategy.")
    return tnt.Model(keras_model, parallel_strategy=tnt.ParallelStrategy.DATA)
def from_config(cls, *args, **kwargs):
    """Construct a model from a config by delegating to the data-parallel model.

    Every call is forwarded unchanged to ``dpm.DataParallelModel.from_config``;
    ``cls`` itself is not consulted.

    Returns:
        Whatever ``dpm.DataParallelModel.from_config`` returns.
    """
    # FIXME load models with any type of parallelization strategy
    logger.warning(
        "Loading model with the default `data parallel` strategy.")
    build_from_config = dpm.DataParallelModel.from_config
    return build_from_config(*args, **kwargs)
def model_from_json(json_string, **kwargs):
    """Deserialize a Keras model from JSON and wrap it in a data-parallel ``tnt.Model``.

    Args:
        json_string: JSON model description, forwarded to
            ``tf.keras.models.model_from_json``.
        **kwargs: Extra keyword arguments forwarded to
            ``tf.keras.models.model_from_json``.

    Returns:
        A ``tnt.Model`` built with ``tnt.ParallelStrategy.DATA``.
    """
    logger.debug("Load model from json")
    restored = tf.keras.models.model_from_json(json_string, **kwargs)
    # FIXME load models with any type of parallelization strategy
    logger.warning("Loading model with the default `data parallel` strategy.")
    strategy = tnt.ParallelStrategy.DATA
    wrapped = tnt.Model(restored, parallel_strategy=strategy)
    return wrapped