def check_min_tfversion_marker(item):
    """Decide whether `item` may run given its `min_tfversion` marker.

    The marker values are obtained via `get_marker_values(item, 'min_tfversion')`.

    Returns:
        True when no minimum version is specified; otherwise the result of
        `utils.tf_version_above_equal(<the single specified version>)`.

    Raises:
        ValueError: if more than one minimum version is specified.
    """
    supported_versions = get_marker_values(item, 'min_tfversion')

    # Exactly one version may be given; more than one is a usage error.
    if len(supported_versions) > 1:
        raise ValueError(f"Minimum version specified incorrectly as `{supported_versions}`, "
                         "it should be a single version number")

    # Marker not specified: nothing restricts this item.
    if not supported_versions:
        return True

    return utils.tf_version_above_equal(supported_versions[0])
def _add_default_ProgbarLogger_callback_if_necessary(callbacks, exec_type, verbose):
    """Append a step-counting `ProgbarLogger` to `callbacks` when one is needed.

    `callbacks` is modified in place. Nothing is added if it already holds a
    `tf_callbacks.ProgbarLogger`. Otherwise `ProgbarLogger(count_mode='steps')`
    is appended only when `_is_progbar_necessary(exec_type, verbose)` is true
    and the TF version is at least 2.4.
    """
    # A user-provided progress bar always takes precedence.
    if any(isinstance(existing, tf_callbacks.ProgbarLogger) for existing in callbacks):
        return

    if not _is_progbar_necessary(exec_type, verbose):
        return

    if version_utils.tf_version_above_equal('2.4'):
        # Always need to use `count_mode` to `steps`
        callbacks.append(tf_callbacks.ProgbarLogger(count_mode='steps'))
def set_tf_random_seed(seed=42):
    """Seed the random number generators and request deterministic TF execution.

    Seeds NumPy, TensorFlow and Python's `random` module with `seed`. It also
    sets the environment variables / TF switches that request deterministic
    (cuDNN and op-level) behaviour, choosing the mechanism by TF version.

    Args:
        seed: integer seed applied to every generator (default 42).
    """
    np.random.seed(seed)
    tf.random.set_seed(seed)
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
    # Python processes launched afterwards (e.g. subprocesses), not the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'
    # `TF_DETERMINISTIC_OPS` was superseded by `enable_op_determinism`, which only
    # exists from TF 2.8 on. Previously TF 2.7 fell between the two branches and
    # got neither mechanism.
    # https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism
    if version_utils.tf_version_above_equal('2.8'):
        tf.config.experimental.enable_op_determinism()
    else:
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
    # `tf.keras.utils.set_random_seed` is only used from TF 2.7 on.
    if version_utils.tf_version_above_equal('2.7'):
        tf.keras.utils.set_random_seed(seed)
def _get_dataset_slice_per_rank(self, dataset, batch_size, micro_batch_size):
    """Return the part of `dataset` that this rank (`self.rank`) should consume.

    If `ds_helpers._is_batch_multiple_num_ranks(self.num_ranks, batch_size)`
    holds, the dataset is sharded across `self.num_ranks` ranks. Otherwise:
      - the first `self.rank` samples are skipped;
      - the remainder is cut into windows of `micro_batch_size` elements
        (window starts advance by `batch_size`, elements taken with stride
        `self.num_ranks`, incomplete trailing windows kept);
      - the windows are flattened back with `interleave` using
        `ds_helpers._window_datasets_to_tuples`.
    """
    if ds_helpers._is_batch_multiple_num_ranks(self.num_ranks, batch_size):
        return dataset.shard(num_shards=self.num_ranks, index=self.rank)

    # skip samples up to the starting point for `rank`
    rank_offset = dataset.skip(self.rank)
    windowed = rank_offset.window(size=micro_batch_size,
                                  shift=batch_size,
                                  stride=self.num_ranks,
                                  drop_remainder=False)

    # The `deterministic` argument is only passed for TF >= 2.2.
    interleave_options = {}
    if version_utils.tf_version_above_equal('2.2'):
        interleave_options['deterministic'] = True

    return windowed.interleave(ds_helpers._window_datasets_to_tuples,
                               num_parallel_calls=ds_helpers.autotune_flag(),
                               block_length=micro_batch_size,
                               **interleave_options)
import tarantella as tnt import tarantella.utilities.tf_version as version_utils from tarantella import logger import tarantella.strategy.data_parallel.data_parallel_model as dpm import tarantella.strategy.pipelining.partitioned_model as pm import tarantella.strategy.pipelining.partition_generator as pgen import tarantella.strategy.pipelining.rank_mapper as rmapper import tensorflow as tf # Model parallelism not supportted for TF version < 2.3 TF_DEFAULT_PIPELINING_FLAG = (version_utils.tf_version_above_equal('2.3')) class ModelMeta(type): def __call__(cls, *args, **kwargs): obj = cls._create_tnt_model(*args, **kwargs) return obj def _create_tnt_model(cls, model: tf.keras.Model, parallel_strategy: tnt.ParallelStrategy = tnt.ParallelStrategy.ALL if TF_DEFAULT_PIPELINING_FLAG \ else tnt.ParallelStrategy.DATA, num_pipeline_stages: int = 1): replica_group = tnt.Group() if (tnt.ParallelStrategy.PIPELINING in parallel_strategy) and isinstance(model, tf.keras.Sequential): logger.warn( f"Cannot pipeline a `tf.keras.Sequential` model; disabling model parallelism."
def to_json(self, **kwargs):
    """Serialize the model configuration to a JSON string.

    Starts from `self.model._updated_config()` and replaces its `'config'`
    entry with `self.get_config()`.

    Args:
        **kwargs: forwarded to `json.dumps`. On TF >= 2.4,
            `json_utils.get_json_type` is used as the `default` serializer
            unless the caller passes their own `default`.

    Returns:
        The JSON-encoded model configuration.
    """
    model_config = self.model._updated_config()
    model_config['config'] = self.get_config()
    if tf_version_utils.tf_version_above_equal('2.4'):
        # Previously a caller-supplied `default` was silently overwritten here.
        kwargs.setdefault('default', json_utils.get_json_type)
    return json.dumps(model_config, **kwargs)