示例#1
0
def load_model(filepath, compile=True, **kwargs):
    """Load a saved Keras model and wrap it in a data-parallel `tnt.Model`.

    Args:
        filepath: Location of the saved model; forwarded unchanged to
            `tf.keras.models.load_model`.
        compile: Forwarded to `tf.keras.models.load_model`. When truthy, the
            optimizer of the loaded Keras model is additionally wrapped in a
            `SynchDistributedOptimizer` and the returned model is flagged as
            compiled.
        **kwargs: Extra keyword arguments forwarded to
            `tf.keras.models.load_model`.

    Returns:
        The `tnt.Model` wrapping the loaded Keras model.
    """
    logger.debug("Load model from file: {}".format(filepath))
    keras_model = tf.keras.models.load_model(filepath,
                                             compile=compile,
                                             **kwargs)
    # FIXME load models with any type of parallelization strategy
    logger.warning("Loading model with the default `data parallel` strategy.")
    tnt_model = tnt.Model(keras_model,
                          parallel_strategy=tnt.ParallelStrategy.DATA)
    if compile:
        try:
            tnt_optimizer = tnt.distributed_optimizers.SynchDistributedOptimizer(
                keras_model.optimizer, group=tnt_model.group)
            tnt_model.dist_optimizer = tnt_optimizer
            tnt_model._set_internal_optimizer(tnt_model.dist_optimizer)
            tnt_model.compiled = True
            # NOTE(review): presumably the weights come from the same file on
            # every rank, so no initial broadcast is needed - confirm.
            tnt_model.done_broadcast = True

            if version_utils.tf_version_below_equal('2.1'):
                tnt_model.model._experimental_run_tf_function = False
                logger.info("Set `experimental_run_tf_function` to False.")
        except Exception:
            # Best effort: any failure while wiring up the distributed
            # optimizer is treated as "model was saved without being
            # compiled". Catch `Exception` rather than using a bare `except`
            # so that KeyboardInterrupt/SystemExit still propagate.
            logger.info("The loaded model was not pre-compiled.")
    # NOTE(review): presumably synchronizes all ranks before returning - confirm.
    tnt_model.barrier.execute()
    return tnt_model
示例#2
0
def model_from_yaml(yaml_string, **kwargs):
    """Build a data-parallel `tnt.Model` from a YAML model description.

    Args:
        yaml_string: YAML description of the model; forwarded unchanged to
            `tf.keras.models.model_from_yaml`.
        **kwargs: Extra keyword arguments forwarded to
            `tf.keras.models.model_from_yaml`.

    Returns:
        A `tnt.Model` wrapping the deserialized Keras model, using the
        data-parallel strategy.

    Raises:
        RuntimeError: If the Keras model cannot be created from
            `yaml_string` or cannot be wrapped in a `tnt.Model`. The
            underlying error is attached as `__cause__`.
    """
    logger.debug("Load model from yaml")
    try:
        keras_model = tf.keras.models.model_from_yaml(yaml_string, **kwargs)
        # FIXME load models with any type of parallelization strategy
        logger.warning(
            "Loading model with the default `data parallel` strategy.")
        return tnt.Model(keras_model,
                         parallel_strategy=tnt.ParallelStrategy.DATA)
    except Exception as err:
        # Catch `Exception` only: a bare `except` would also turn
        # KeyboardInterrupt/SystemExit into a RuntimeError.
        raise RuntimeError(
            "[tnt.models.model_from_yaml] Cannot load model") from err
示例#3
0
def clone_model(model, **kwargs):
    """Clone a model and wrap the clone in a data-parallel `tnt.Model`.

    Args:
        model: Either a `tnt.strategy.parallel_model.ParallelModel` (its
            underlying `model.model` is cloned) or a `tf.keras.Model`.
        **kwargs: Extra keyword arguments forwarded to
            `tf.keras.models.clone_model`.

    Returns:
        A `tnt.Model` wrapping the cloned Keras model.

    Raises:
        ValueError: If `model` is neither of the supported types.
    """
    if isinstance(model, tnt.strategy.parallel_model.ParallelModel):
        keras_model = tf.keras.models.clone_model(model.model, **kwargs)
        logger.info("clone model from instance of tnt.Model")
    elif isinstance(model, tf.keras.Model):
        keras_model = tf.keras.models.clone_model(model, **kwargs)
        logger.info("clone model from instance of tf.keras.Model")
    else:
        # One message string: passing the two halves as separate arguments
        # made the exception text render as a tuple.
        raise ValueError("[tnt.models.clone_model] `model` needs to be either "
                         "a `tf.keras.Model`, or a `tnt.Model`")
    # FIXME load models with any type of parallelization strategy
    logger.warning("Loading model with the default `data parallel` strategy.")
    return tnt.Model(keras_model, parallel_strategy=tnt.ParallelStrategy.DATA)
示例#4
0
 def from_config(cls, config, **kwargs):
     """Build a data-parallel `tnt.Model` from a `keras.Sequential` config.

     Args:
         config: Model configuration; forwarded unchanged to
             `tf.keras.Sequential.from_config`.
         **kwargs: Extra keyword arguments forwarded to
             `tf.keras.Sequential.from_config`.

     Returns:
         A `tnt.Model` wrapping the reconstructed Keras model. `cls` is not
         used by the body.

     Raises:
         RuntimeError: If the Keras model cannot be rebuilt from `config`.
             The underlying error is attached as `__cause__`.
     """
     try:
         keras_model = tf.keras.Sequential.from_config(config, **kwargs)
         logger.info("Loaded model from `keras.Sequential`.")
     except Exception as err:
         # Catch `Exception` only so KeyboardInterrupt/SystemExit propagate.
         # The message is a single line; the previous triple-quoted literal
         # embedded a newline and source indentation in the error text.
         raise RuntimeError(
             "[tnt.keras.Sequential.from_config] Cannot load model; "
             "provided configuration is not a `keras.Sequential` model."
         ) from err
     # FIXME load models with any type of parallelization strategy
     logger.warning(
         "Loading model with the default `data parallel` strategy.")
     return tnt.Model(keras_model,
                      parallel_strategy=tnt.ParallelStrategy.DATA)
示例#5
0
 def from_config(cls, *args, **kwargs):
     """Create a model by delegating to `dpm.DataParallelModel.from_config`.

     All positional and keyword arguments are passed through unchanged. A
     warning is logged first, because the data-parallel implementation is
     always the one used. `cls` is not used by the body.
     """
     # FIXME load models with any type of parallelization strategy
     strategy_notice = (
         "Loading model with the default `data parallel` strategy.")
     logger.warning(strategy_notice)
     data_parallel_model = dpm.DataParallelModel.from_config(*args, **kwargs)
     return data_parallel_model
示例#6
0
def model_from_json(json_string, **kwargs):
    """Build a data-parallel `tnt.Model` from a JSON model description.

    Args:
        json_string: JSON description of the model; forwarded unchanged to
            `tf.keras.models.model_from_json`.
        **kwargs: Extra keyword arguments forwarded to
            `tf.keras.models.model_from_json`.

    Returns:
        A `tnt.Model` wrapping the deserialized Keras model.
    """
    logger.debug("Load model from json")
    deserialized = tf.keras.models.model_from_json(json_string, **kwargs)

    # FIXME load models with any type of parallelization strategy
    logger.warning("Loading model with the default `data parallel` strategy.")
    wrapped = tnt.Model(
        deserialized,
        parallel_strategy=tnt.ParallelStrategy.DATA,
    )
    return wrapped