Code Example #1
File: models.py Project: cc-hpc-itwm/tarantella
def load_model(filepath, compile=True, **kwargs):
    logger.debug("Load model from file: {}".format(filepath))
    keras_model = tf.keras.models.load_model(filepath,
                                             compile=compile,
                                             **kwargs)
    # FIXME load models with any type of parallelization strategy
    logger.warning("Loading model with the default `data parallel` strategy.")
    tnt_model = tnt.Model(keras_model,
                          parallel_strategy=tnt.ParallelStrategy.DATA)
    if compile:
        try:
            tnt_optimizer = tnt.distributed_optimizers.SynchDistributedOptimizer(
                keras_model.optimizer, group=tnt_model.group)
            tnt_model.dist_optimizer = tnt_optimizer
            tnt_model._set_internal_optimizer(tnt_model.dist_optimizer)
            tnt_model.compiled = True
            tnt_model.done_broadcast = True

            if version_utils.tf_version_below_equal('2.1'):
                tnt_model.model._experimental_run_tf_function = False
                logger.info("Set `experimental_run_tf_function` to False.")
        except Exception:
            # the saved model carries no optimizer state, i.e., it was not compiled
            logger.info("The loaded model was not pre-compiled.")
    tnt_model.barrier.execute()
    return tnt_model
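A minimal usage sketch (hypothetical checkpoint path and dataset; since the
function ends in a barrier, it is assumed to be called collectively on all
ranks under the tarantella runtime):

tnt_model = load_model("checkpoints/model.h5")  # hypothetical path
tnt_model.fit(train_dataset, epochs=2)          # `train_dataset` is assumed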
Code Example #2
File: conftest.py Project: cc-hpc-itwm/tarantella
def check_max_tfversion_marker(item):
  supported_versions = get_marker_values(item, 'max_tfversion')
  if len(supported_versions) > 1:
    raise ValueError(f"Maximum version specified incorrectly as `{supported_versions}`, "
                      "it should be a single version number")
  if supported_versions:
    return utils.tf_version_below_equal(supported_versions[0])
  else: # marker not specified
    return True
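A hypothetical test showing how the `max_tfversion` marker is meant to be
used (test name and body are made up; the marker takes a single version
string, which this helper validates):

import pytest

@pytest.mark.max_tfversion('2.4')  # deselected when running with TF > 2.4
def test_legacy_behaviour():
    ...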
Code Example #3
File: utilities.py Project: cc-hpc-itwm/tarantella
def set_tf_random_seed(seed=42):
    np.random.seed(seed)
    tf.random.set_seed(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

    # from TF 2.7 on, `TF_DETERMINISTIC_OPS` was replaced by `tf.config.experimental.enable_op_determinism`
    # https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism
    if version_utils.tf_version_below_equal('2.6'):
        os.environ['TF_DETERMINISTIC_OPS'] = '1'
    if version_utils.tf_version_above_equal('2.7'):
        tf.keras.utils.set_random_seed(seed)
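A minimal sketch of intended use (arbitrary seed, hypothetical entry point):
call the helper once at startup, before any model or dataset is built, so
that the environment variables take effect before TensorFlow initializes:

if __name__ == '__main__':
    set_tf_random_seed(seed=1234)
    main()  # hypothetical training entry point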
Code Example #4
def validate_local_dataset(ref_dataset, local_dataset, micro_batch_size, rank,
                           comm_size):
    # the padding implementation is not supported in TF versions <= 2.1
    use_padding = not version_utils.tf_version_below_equal('2.1')
    return ds_utils.validate_local_dataset(ref_dataset,
                                           local_dataset,
                                           micro_batch_size,
                                           rank,
                                           comm_size,
                                           padded=use_padding)
Code Example #5
def _customize_progbar_logger(progbar_logger: tf.keras.callbacks.ProgbarLogger) -> None:
  if version_utils.tf_version_below_equal('2.2'):
    raise EnvironmentError("[tnt.callbacks.ProgbarLogger] "
                           "`ProgbarLogger` is only supported from TF 2.3 onwards")
  # the other ranks only need to participate in averaging logs
  progbar_logger.should_print_progbar = tnt.is_group_master_rank(progbar_logger.group)

  def progbar_logger_distribute_callback(callback_func: Callable,
                                         **kwargs: Any) -> Any:
    if progbar_logger.run_on_all_ranks:
      kwargs_copy = progbar_logger._average_callback_logs(kwargs)    
      if progbar_logger.should_print_progbar:
        return callback_func(**kwargs_copy)
    else:
      if tnt.is_group_master_rank(progbar_logger.group) and progbar_logger.should_print_progbar:
        return callback_func(**kwargs)
  progbar_logger._distribute_callback = progbar_logger_distribute_callback
Code Example #6
def _pad_dataset_if_necessary(dataset, num_samples, batch_size,
                              min_last_batch_size):
    last_batch_size = _get_last_incomplete_batch_size(num_samples, batch_size)
    if last_batch_size == 0:
        logger.debug(f"No padding required: number of samples {num_samples} is a multiple " \
                     f"of the batch size {batch_size}.")
        return dataset

    logger.info(f"Incomplete last batch in the dataset: number of samples is " \
                f"{last_batch_size} ( != batch size {batch_size}).")

    if version_utils.tf_version_below_equal('2.1'):
        num_samples_multiple = num_samples - last_batch_size
        logger.warn(f"Number of samples ({num_samples}) is not a multiple of batch size. " \
                    f"This use case is not supported in TF v{version_utils.current_version()}. " \
                    f"Dropping the last incomplete batch from the dataset, "\
                    f"and proceeding with {num_samples_multiple} samples.")
        return dataset.take(num_samples_multiple)

    if last_batch_size < min_last_batch_size:
        logger.debug(f"Padding required for the last batch: number of samples is " \
                     f"{last_batch_size} ( < min_batch_size {min_last_batch_size}).")

        # Create helper dataset that contains one full batch and one incomplete batch
        helper_dataset = dataset.take(min_last_batch_size + last_batch_size)
        helper_dataset = helper_dataset.batch(min_last_batch_size,
                                              drop_remainder=False)

        # If `padded_shapes` is unspecified, all dimensions of all components
        # are padded to the maximum size in the batch.
        # The second batch in `helper_dataset` will now contain `min_last_batch_size - last_batch_size`
        # default-initialized samples.
        helper_dataset = helper_dataset.padded_batch(2)

        # Switch back to a list of samples instead of batches
        helper_dataset = helper_dataset.unbatch().unbatch()

        # Remaining samples in the dataset are those generated through padding
        padding_samples = helper_dataset.skip(min_last_batch_size +
                                              last_batch_size)
        dataset = dataset.concatenate(padding_samples)
        logger.info(f"[Rank {tnt.get_rank()}] Dataset padded with " \
                    f"{min_last_batch_size - last_batch_size} samples.")
    return dataset
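To see the padding trick in isolation, here is a self-contained sketch with
illustrative constants: 10 samples with a batch size of 4 leave a last batch
of 2 samples, so 2 padding samples are generated to reach a minimum last
batch size of 4 (all names below are local to the sketch):

import tensorflow as tf

num_samples, last_batch_size, min_last_batch_size = 10, 2, 4
dataset = tf.data.Dataset.range(num_samples)

# one full batch [0..3] plus the incomplete batch [4, 5]
helper = dataset.take(min_last_batch_size + last_batch_size)
helper = helper.batch(min_last_batch_size, drop_remainder=False)

# `padded_batch(2)` pads the short inner batch with zeros up to length 4
helper = helper.padded_batch(2).unbatch().unbatch()

# everything past the first 6 samples is padding: two zero-valued samples
padding_samples = helper.skip(min_last_batch_size + last_batch_size)
padded = dataset.concatenate(padding_samples)
print(list(padded.as_numpy_iterator()))  # [0, 1, ..., 9, 0, 0]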
Code Example #7
    def reduce_gradients(self, gradients_and_weights):
        gradients_to_reduce = list()
        for grad, weight in gradients_and_weights:
            # add an Allreduce operation for each gradient
            grad_id = self.weight_to_index[weight.name]
            number_partial_sums = tnt.get_size()
            grad = grad / number_partial_sums
            output_grad = tnt_ops.start_allreduce_op(
                grad, tensor_id=grad_id, tnt_synchcomm=self.comm.get_raw_ptr())
            gradients_to_reduce.append(output_grad)

        # Create barrier op in the Tensorflow graph to make sure all
        # the Allreduce operations on gradients have started.
        # This ensures that the graph execution does not get delayed by waiting
        # for gradients to be reduced as long as there are remaining computations
        # in the backward pass.
        temp_gradients = tnt_ops.barrier_op(gradients_to_reduce,
                                            Tout=[tf.float32] * len(gradients_to_reduce))

        # Add individual ops that wait for each gradient to be reduced before updating
        # the weights.
        # These ops are executed only after the backward pass has been completed.
        reduced_gradients = list()
        for idx, (_, weight) in enumerate(gradients_and_weights):
            # gradient tensors obtained after barrier are listed in the same order
            # as the initial `gradients_and_weights`
            gradient = temp_gradients[idx]
            grad_id = self.weight_to_index[weight.name]

            output_grad = tnt_ops.finish_allreduce_op(
                gradient,
                tensor_id=grad_id,
                Tout=tf.float32,
                tnt_synchcomm=self.comm.get_raw_ptr())
            if version_utils.tf_version_below_equal('2.3'):
                reduced_gradients.append(output_grad)
            else:
                reduced_gradients.append((output_grad, weight))
        return reduced_gradients
Code Example #8
File: metrics.py Project: cc-hpc-itwm/tarantella
def _add_support_for_deprecated_methods(self):
    if version_utils.tf_version_below_equal('2.4'):
        # Keras renamed `reset_states` to `reset_state` in TF 2.5; on older
        # versions, alias the old name so calls to `reset_states` reach this
        # class's `reset_state` implementation
        self.reset_states = self.reset_state
Code Example #9
def _preprocess_compile_kwargs(self, kwargs):
  if version_utils.tf_version_below_equal('2.1'):
    kwargs['experimental_run_tf_function'] = False
    logger.info("Set `experimental_run_tf_function` to False.")
  return kwargs
Code Example #10
def autotune_flag():
    if version_utils.tf_version_below_equal('2.3'):
        return tf.data.experimental.AUTOTUNE
    else:
        return tf.data.AUTOTUNE
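A short usage sketch (`preprocess` and `dataset` are assumed): the returned
constant lets tf.data choose the degree of parallelism automatically on any
supported TF version:

dataset = dataset.map(preprocess, num_parallel_calls=autotune_flag())
dataset = dataset.prefetch(autotune_flag())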