Example #1
def check_min_tfversion_marker(item):
  supported_versions = get_marker_values(item, 'min_tfversion')
  if len(supported_versions) > 1:
    raise ValueError(f"Minimum version specified incorrectly as `{supported_versions}`, "
                      "it should be a single version number")
  if supported_versions:
    return utils.tf_version_above_equal(supported_versions[0])
  else: # marker not specified
    return True
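A minimal sketch of how this helper might be wired into a pytest hook. The `min_tfversion` marker name and `check_min_tfversion_marker` come from the example above; the hook body, the conftest.py placement, and the skip reason are assumptions for illustration only.

# conftest.py (illustrative sketch, not the project's actual hook)
import pytest

def pytest_collection_modifyitems(config, items):
  # assumes check_min_tfversion_marker (defined above) is importable here
  for item in items:
    # skip tests whose `min_tfversion` marker requires a newer TF than the installed one
    if not check_min_tfversion_marker(item):
      item.add_marker(pytest.mark.skip(reason="requires a newer TF version (`min_tfversion` marker)"))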
Example #2
def _add_default_ProgbarLogger_callback_if_necessary(callbacks, exec_type,
                                                     verbose):
    for callback in callbacks:
        if isinstance(callback, tf_callbacks.ProgbarLogger):
            return
    progbar_necessary = _is_progbar_necessary(exec_type, verbose)
    if progbar_necessary and version_utils.tf_version_above_equal('2.4'):
        # Always need to set `count_mode` to `steps`
        callbacks.append(tf_callbacks.ProgbarLogger(count_mode='steps'))
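A hedged sketch of a possible call site for the helper above; `exec_type='fit'` and the behaviour of `_is_progbar_necessary` are assumptions based on the snippet, not a confirmed API.

callbacks = [tf.keras.callbacks.EarlyStopping(patience=3)]
_add_default_ProgbarLogger_callback_if_necessary(callbacks, exec_type='fit', verbose=1)
# on TF >= 2.4, if a progress bar is deemed necessary, this appends a
# ProgbarLogger(count_mode='steps') unless one is already in the list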
Example #3
def set_tf_random_seed(seed=42):
    np.random.seed(seed)
    tf.random.set_seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

    # from TF 2.7, 'TF_DETERMINISTIC_OPS' was replaced with `enable_op_determinism`
    # https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism
    if version_utils.tf_version_below_equal('2.6'):
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
    if version_utils.tf_version_above_equal('2.7'):
        tf.keras.utils.set_random_seed(seed)
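A minimal usage sketch: call the seeding helper once, before any model or dataset is built, so repeated runs are reproducible. The layer sizes below are illustrative only.

set_tf_random_seed(seed=42)   # fix all RNGs before building the model

model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,)),
                             tf.keras.layers.Dense(1)])
model.compile(optimizer='sgd', loss='mse')
# with identical seeds and deterministic ops enabled, repeated runs of this
# script should start from the same initial weights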
Example #4
    def _get_dataset_slice_per_rank(self, dataset, batch_size,
                                    micro_batch_size):
        if ds_helpers._is_batch_multiple_num_ranks(self.num_ranks, batch_size):
            dataset = dataset.shard(num_shards=self.num_ranks, index=self.rank)
        else:
            dataset = dataset.skip(
                self.rank)  # skip samples up to the starting point for `rank`
            dataset = dataset.window(size=micro_batch_size,
                                     shift=batch_size,
                                     stride=self.num_ranks,
                                     drop_remainder=False)

            kwargs = {}
            if version_utils.tf_version_above_equal('2.2'):
                kwargs['deterministic'] = True
            dataset = dataset.interleave(
                ds_helpers._window_datasets_to_tuples,
                num_parallel_calls=ds_helpers.autotune_flag(),
                block_length=micro_batch_size,
                **kwargs)
        return dataset
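A self-contained sketch (made-up numbers) of what the two slicing strategies above produce on a toy dataset; plain `flat_map` is used as a simplified stand-in for the interleave over `_window_datasets_to_tuples`.

samples = tf.data.Dataset.range(8)

# shard path: when the batch size is a multiple of the number of ranks,
# each rank simply takes every `num_ranks`-th sample
rank0 = samples.shard(num_shards=2, index=0)                    # 0, 2, 4, 6

# window path: otherwise, each rank skips up to its starting offset and takes
# one strided window of `micro_batch_size` samples per global batch
windows = samples.skip(1).window(size=3, shift=7, stride=2,
                                 drop_remainder=False)
rank1 = windows.flat_map(lambda w: w)                           # 1, 3, 5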
Example #5
import tarantella as tnt
import tarantella.utilities.tf_version as version_utils
from tarantella import logger

import tarantella.strategy.data_parallel.data_parallel_model as dpm
import tarantella.strategy.pipelining.partitioned_model as pm
import tarantella.strategy.pipelining.partition_generator as pgen
import tarantella.strategy.pipelining.rank_mapper as rmapper

import tensorflow as tf

# Model parallelism not supported for TF versions < 2.3
TF_DEFAULT_PIPELINING_FLAG = (version_utils.tf_version_above_equal('2.3'))


class ModelMeta(type):
    def __call__(cls, *args, **kwargs):
        obj = cls._create_tnt_model(*args, **kwargs)
        return obj

    def _create_tnt_model(cls, model: tf.keras.Model,
                          parallel_strategy: tnt.ParallelStrategy = tnt.ParallelStrategy.ALL if TF_DEFAULT_PIPELINING_FLAG \
                                                                                             else tnt.ParallelStrategy.DATA,
                          num_pipeline_stages: int = 1):
        replica_group = tnt.Group()

        if (tnt.ParallelStrategy.PIPELINING
                in parallel_strategy) and isinstance(model,
                                                     tf.keras.Sequential):
            logger.warn(
                f"Cannot pipeline a `tf.keras.Sequential` model; disabling model parallelism."
Example #6
  def to_json(self, **kwargs):
    model_config = self.model._updated_config()
    model_config['config'] = self.get_config()
    # TF >= 2.4: let `json.dumps` fall back to the Keras JSON helper for
    # objects it cannot serialize natively
    if tf_version_utils.tf_version_above_equal('2.4'):
      kwargs['default'] = json_utils.get_json_type
    return json.dumps(model_config, **kwargs)
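The `default` hook is what lets `json.dumps` handle non-JSON-native objects (NumPy scalars and similar values that can appear in a Keras model config). A minimal sketch of the mechanism with a hand-written converter standing in for `json_utils.get_json_type`:

import json
import numpy as np

def to_json_type(obj):
  # stand-in for `json_utils.get_json_type`: convert NumPy scalars to plain Python types
  if isinstance(obj, np.integer):
    return int(obj)
  if isinstance(obj, np.floating):
    return float(obj)
  raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

config = {'units': np.int64(32), 'rate': np.float32(0.5)}
print(json.dumps(config, default=to_json_type))   # {"units": 32, "rate": 0.5}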