Example #1
    def _(self, keras_callback: tf.keras.callbacks.ModelCheckpoint):
      logger.debug("[DataParallel] ModelCheckpoint callback")
      # only master rank should save and thus print messages
      self.verbose = keras_callback.verbose if tnt.is_group_master_rank(self.group) \
                                            else utilities.TF_verbose.SILENT.value
      self.run_on_all_ranks = False # only one checkpoint is needed (models are identical in a data parallel setting)

      # disable checkpointing for all ranks except the master rank
      if not tnt.is_group_master_rank(self.group):
        self._supports_tf_logs = False
        self.save_freq = 1e20 # very large value to avoid triggering checkpointing
        self.epochs_since_last_save = 0
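Note: the `def _(self, keras_callback: ...)` definitions in these examples are registrations on a type-dispatched customization method. A minimal, self-contained sketch of that pattern follows; the names `CallbackCustomizerSketch` and `_customize_callback` are assumptions for illustration only, not the library's actual API.

import functools
import tensorflow as tf

class CallbackCustomizerSketch:
  # Hypothetical dispatcher: the customization is selected by the type of the
  # wrapped Keras callback, which is how the `def _` registrations above work.
  @functools.singledispatchmethod
  def _customize_callback(self, keras_callback):
    pass  # fallback: no customization for unknown callback types

  @_customize_callback.register
  def _(self, keras_callback: tf.keras.callbacks.ModelCheckpoint):
    print("customizing ModelCheckpoint")

  @_customize_callback.register
  def _(self, keras_callback: tf.keras.callbacks.EarlyStopping):
    print("customizing EarlyStopping")

# usage: the correct overload is picked from the callback instance's type
customizer = CallbackCustomizerSketch()
customizer._customize_callback(tf.keras.callbacks.ModelCheckpoint(filepath="ckpt"))
customizer._customize_callback(tf.keras.callbacks.EarlyStopping())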
Example #2
 def _distribute_callback_default(self, callback_func: Callable, **kwargs: Any) -> Any:
   if self.run_on_all_ranks:
     kwargs_copy = self._average_callback_logs(kwargs)    
     return callback_func(**kwargs_copy)
   else:
     if tnt.is_group_master_rank(self.group):
       return callback_func(**kwargs)
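`_average_callback_logs` itself is not shown in these examples. The sketch below illustrates what averaging the `logs` entry of the callback kwargs across ranks could look like; `allreduce_average` is a hypothetical stand-in for the framework's actual cross-rank reduction primitive.

from typing import Any, Dict

def allreduce_average(value: float) -> float:
  # Hypothetical placeholder for an allreduce-and-average over all ranks in
  # the group; a real implementation would call the communication backend.
  return value

def average_callback_logs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
  # Average every numeric metric in `logs` across ranks and leave the other
  # callback arguments (epoch index, batch index, ...) untouched.
  kwargs_copy = dict(kwargs)
  logs = kwargs_copy.get("logs") or {}
  kwargs_copy["logs"] = {name: allreduce_average(float(value))
                         for name, value in logs.items()}
  return kwargs_copy

# usage: average_callback_logs({"epoch": 3, "logs": {"loss": 0.42}})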
Example #3
 def progbar_logger_distribute_callback(callback_func: Callable,
                                        **kwargs: Any) -> Any:
   if progbar_logger.run_on_all_ranks:
     kwargs_copy = progbar_logger._average_callback_logs(kwargs)    
     if progbar_logger.should_print_progbar:
       return callback_func(**kwargs_copy)
   else:
     if tnt.is_group_master_rank(progbar_logger.group) and progbar_logger.should_print_progbar:
       return callback_func(**kwargs)
Example #4
    def _(self, keras_callback: tf.keras.callbacks.EarlyStopping):
      logger.debug("[DataParallel] EarlyStopping callback")
      # only master rank should print messages
      self.verbose = keras_callback.verbose if tnt.is_group_master_rank(self.group) \
                                            else utilities.TF_verbose.SILENT.value

      # evaluate the monitored metric on logs averaged across all ranks
      def _get_monitor_value(logs):
        averaged_logs = self.average_logs(logs)
        return keras_callback.get_monitor_value(averaged_logs)
      self.get_monitor_value = _get_monitor_value
Example #5
    def _(self, keras_callback: tf.keras.callbacks.TensorBoard):
      logger.debug("[PipeliningParallel] TensorBoard callback")
      if tnt.global_tnt_config.tensorboard_on_all_devices:
        self.log_dir += f"/rank_{tnt.get_rank()}"
      else:
        # disable any data logging for all ranks except the last partition
        if not tnt.is_group_master_rank(self.group):
          self.histogram_freq = 0
          self.write_graph = False
          self.write_images = False
          self.write_steps_per_second = False
          self.update_freq = 0
          self.embeddings_freq = 0
          self.embeddings_metadata = None
          self.profile_batch = None
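For the `tensorboard_on_all_devices` branch above, each rank appends its rank id to the user-provided `log_dir`, so the ranks write to separate TensorBoard directories. A small sketch with hypothetical values:

base_log_dir = "logs/experiment"   # hypothetical user-provided log_dir
num_ranks = 4                      # hypothetical group size
per_rank_dirs = [f"{base_log_dir}/rank_{rank}" for rank in range(num_ranks)]
print(per_rank_dirs)
# ['logs/experiment/rank_0', 'logs/experiment/rank_1',
#  'logs/experiment/rank_2', 'logs/experiment/rank_3']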
Example #6
def _customize_progbar_logger(progbar_logger: tf.keras.callbacks.ProgbarLogger) -> None:
  if version_utils.tf_version_below_equal('2.2'):
    raise EnvironmentError("[tnt.callbacks.ProgbarLogger] "
                           "`ProgbarLogger` is only supported in TF 2.3 and newer")
  # the other ranks only need to participate in averaging logs
  progbar_logger.should_print_progbar = tnt.is_group_master_rank(progbar_logger.group)

  def progbar_logger_distribute_callback(callback_func: Callable,
                                         **kwargs: Any) -> Any:
    if progbar_logger.run_on_all_ranks:
      kwargs_copy = progbar_logger._average_callback_logs(kwargs)    
      if progbar_logger.should_print_progbar:
        return callback_func(**kwargs_copy)
    else:
      if tnt.is_group_master_rank(progbar_logger.group) and progbar_logger.should_print_progbar:
        return callback_func(**kwargs)
  progbar_logger._distribute_callback = progbar_logger_distribute_callback
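Patching `_distribute_callback` is sufficient because every Keras hook of the wrapped callback is routed through that single method. A minimal sketch of this indirection; `DistributedCallbackSketch` and its hook methods are illustrative names, not the library's actual classes.

from typing import Any, Callable, Optional

class DistributedCallbackSketch:
  def __init__(self, run_on_all_ranks: bool = True) -> None:
    self.run_on_all_ranks = run_on_all_ranks

  def _distribute_callback(self, callback_func: Callable, **kwargs: Any) -> Any:
    # default routing; the examples above replace this method per callback type
    return callback_func(**kwargs)

  def on_epoch_end(self, epoch: int, logs: Optional[dict] = None) -> None:
    # every hook funnels through `_distribute_callback`, so one patch point
    # controls whether and where the underlying callback logic runs
    self._distribute_callback(self._on_epoch_end_impl, epoch=epoch, logs=logs)

  def _on_epoch_end_impl(self, epoch: int, logs: Optional[dict] = None) -> None:
    print(f"epoch {epoch} finished, logs: {logs}")

# usage
cb = DistributedCallbackSketch()
cb.on_epoch_end(epoch=0, logs={"loss": 0.5})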
Example #7
  def predict(self,
              x = None,
              callbacks = None,
              **kwargs):
    self._configure_rebuild(dataset = x)
    self._build_model_and_compile_if_necessary()

    processed_callbacks = utilities._preprocess_callbacks(callbacks, self.group,
                                                          parallel_strategy = tnt.ParallelStrategy.PIPELINING,
                                                          exec_type = 'predict',
                                                          verbose = kwargs.get('verbose', None))

    ds = self._get_microbatched_dataset(dataset = x, nano_batch_size = self.nano_batch_size,
                                        num_pipeline_stages = self.num_pipeline_stages)
    predictions = self.model.predict(x = ds, callbacks = processed_callbacks, **kwargs)
    if tnt.is_group_master_rank(self.group):  # only the last partition returns results
      return predictions
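Since `predict` only returns results on the group master rank (the last pipeline partition), callers on the other ranks receive `None` and must handle that case. A minimal sketch of the pattern, with `is_group_master_rank` as a hypothetical stand-in for `tnt.is_group_master_rank(group)`:

from typing import Any, Optional

def is_group_master_rank() -> bool:
  # Hypothetical stand-in for tnt.is_group_master_rank(group)
  return True

def return_on_master_only(results: Any) -> Optional[Any]:
  # Mirrors the predict() pattern above: results exist only on the master
  # rank; all other ranks implicitly return None.
  return results if is_group_master_rank() else None

# usage: callers must guard against None on non-master ranks
predictions = return_on_master_only([0.1, 0.9])
if predictions is not None:
  print("master rank owns the predictions:", predictions)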
Example #8
 def summary(self, *args, **kwargs):
   if tnt.global_tnt_config.output_on_all_devices:
     self.model.summary(*args, **kwargs)
   else:
     if tnt.is_group_master_rank(self.group):
       self.model.summary(*args, **kwargs)
Example #9
 def _(self, keras_callback: tf.keras.callbacks.LearningRateScheduler):
   logger.debug("[DataParallel] LearningRateScheduler callback")
   if not tnt.global_tnt_config.output_on_all_devices:
     if not tnt.is_group_master_rank(self.group):
       self.verbose = 0
Example #10
 def _(self, keras_callback: tf.keras.callbacks.ReduceLROnPlateau):
   logger.debug("[DataParallel] ReduceLROnPlateau callback")
   # only master rank should print messages
   self.verbose = keras_callback.verbose if tnt.is_group_master_rank(self.group) \
                                         else utilities.TF_verbose.SILENT.value
Example #11
 def _distribute_callback(self, callback_func: Callable,
                          **kwargs: Any) -> Any:
     if tnt.is_group_master_rank(self.group):
         processed_kwargs = self._process_callback_logs(kwargs)
         return callback_func(**processed_kwargs)