def _(self, keras_callback: tf.keras.callbacks.ModelCheckpoint):
  logger.debug("[DataParallel] ModelCheckpoint callback")
  # only master rank should save and thus print messages
  self.verbose = keras_callback.verbose if tnt.is_group_master_rank(self.group) \
                 else utilities.TF_verbose.SILENT.value
  self.run_on_all_ranks = False
  # only one checkpoint is needed (models are identical in a data parallel setting)
  # disable checkpointing for all ranks except the master rank
  if not tnt.is_group_master_rank(self.group):
    self._supports_tf_logs = False
    self.save_freq = 1e20  # very large value to avoid triggering checkpointing
    self.epochs_since_last_save = 0

def _distribute_callback_default(self, callback_func: Callable, **kwargs: Any) -> Any:
  # if the callback runs on all ranks, average the logs across ranks before
  # invoking it; otherwise only the master rank invokes it
  if self.run_on_all_ranks:
    kwargs_copy = self._average_callback_logs(kwargs)
    return callback_func(**kwargs_copy)
  else:
    if tnt.is_group_master_rank(self.group):
      return callback_func(**kwargs)

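# --- illustrative sketch, not part of the framework code above ----------------
# `_average_callback_logs` is assumed to average the numeric metric values in
# kwargs["logs"] over all ranks of the group before the wrapped Keras callback
# sees them. A minimal, framework-agnostic sketch of that idea; the helper name
# `allreduce_average` is hypothetical and stands in for the real communication
# primitive.
from typing import Any, Callable, Dict

def average_callback_logs_sketch(kwargs: Dict[str, Any],
                                 allreduce_average: Callable[[float], float]) -> Dict[str, Any]:
  kwargs_copy = dict(kwargs)
  logs = kwargs_copy.get("logs")
  if logs:
    # average each scalar metric (e.g. loss, accuracy) across ranks
    kwargs_copy["logs"] = {name: allreduce_average(value)
                           for name, value in logs.items()}
  return kwargs_copy
# -------------------------------------------------------------------------------
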
def _(self, keras_callback: tf.keras.callbacks.EarlyStopping):
  logger.debug("[DataParallel] EarlyStopping callback")
  # only master rank should print messages
  self.verbose = keras_callback.verbose if tnt.is_group_master_rank(self.group) \
                 else utilities.TF_verbose.SILENT.value

  # base early-stopping decisions on metrics averaged over all ranks;
  # the nested function closes over `self`, so it can be assigned directly
  # as an instance attribute
  def _get_monitor_value(logs):
    averaged_logs = self.average_logs(logs)
    return super(type(self), self).get_monitor_value(averaged_logs)
  self.get_monitor_value = _get_monitor_value

def _(self, keras_callback: tf.keras.callbacks.TensorBoard):
  logger.debug("[PipeliningParallel] TensorBoard callback")
  if tnt.global_tnt_config.tensorboard_on_all_devices:
    self.log_dir += f"/rank_{tnt.get_rank()}"
  else:
    # disable any data logging for all ranks except the last partition
    if not tnt.is_group_master_rank(self.group):
      self.histogram_freq = 0
      self.write_graph = False
      self.write_images = False
      self.write_steps_per_second = False
      self.update_freq = 0
      self.embeddings_freq = 0
      self.embeddings_metadata = None
      self.profile_batch = None

def _customize_progbar_logger(progbar_logger: tf.keras.callbacks.ProgbarLogger) -> None:
  if version_utils.tf_version_below_equal('2.2'):
    raise EnvironmentError("[tnt.callbacks.ProgbarLogger] "
                           "`ProgbarLogger` is only supported for TF versions >= 2.3")
  # only the master rank prints the progress bar;
  # the other ranks only need to participate in averaging logs
  progbar_logger.should_print_progbar = tnt.is_group_master_rank(progbar_logger.group)

  def progbar_logger_distribute_callback(callback_func: Callable, **kwargs: Any) -> Any:
    if progbar_logger.run_on_all_ranks:
      kwargs_copy = progbar_logger._average_callback_logs(kwargs)
      if progbar_logger.should_print_progbar:
        return callback_func(**kwargs_copy)
    else:
      if tnt.is_group_master_rank(progbar_logger.group) and progbar_logger.should_print_progbar:
        return callback_func(**kwargs)
  progbar_logger._distribute_callback = progbar_logger_distribute_callback

def predict(self, x = None, callbacks = None, **kwargs):
  self._configure_rebuild(dataset = x)
  self._build_model_and_compile_if_necessary()
  processed_callbacks = utilities._preprocess_callbacks(callbacks, self.group,
                                                        parallel_strategy = tnt.ParallelStrategy.PIPELINING,
                                                        exec_type = 'predict',
                                                        verbose = kwargs.get('verbose', None))
  ds = self._get_microbatched_dataset(dataset = x,
                                      nano_batch_size = self.nano_batch_size,
                                      num_pipeline_stages = self.num_pipeline_stages)
  predictions = self.model.predict(x = ds, callbacks = processed_callbacks, **kwargs)
  if tnt.is_group_master_rank(self.group):  # the master rank runs the last partition
    return predictions

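# --- illustrative usage, under assumptions: `tarantella` is importable as
# `tnt`, and `tnt.Model` accepts a `parallel_strategy` argument as suggested by
# the code above (the exact constructor signature may differ) ------------------
import numpy as np
import tensorflow as tf
import tarantella as tnt

keras_model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(4,))])
tnt_model = tnt.Model(keras_model,
                      parallel_strategy = tnt.ParallelStrategy.PIPELINING)
tnt_model.compile(optimizer = "sgd", loss = "mse")

samples = np.random.rand(32, 4).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices(samples).batch(8)
predictions = tnt_model.predict(dataset)

# only the group master rank (the last pipeline partition) returns predictions;
# every other rank falls through the final `if` above and returns None
if predictions is not None:
  print(predictions.shape)
# -------------------------------------------------------------------------------
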
def summary(self, *args, **kwargs):
  # print the model summary either on every rank (if configured) or on the
  # group master rank only
  if tnt.global_tnt_config.output_on_all_devices \
     or tnt.is_group_master_rank(self.group):
    self.model.summary(*args, **kwargs)

def _(self, keras_callback: tf.keras.callbacks.LearningRateScheduler):
  logger.debug("[DataParallel] LearningRateScheduler callback")
  # only master rank should print messages
  if not tnt.global_tnt_config.output_on_all_devices:
    if not tnt.is_group_master_rank(self.group):
      self.verbose = 0

def _(self, keras_callback: tf.keras.callbacks.ReduceLROnPlateau):
  logger.debug("[DataParallel] ReduceLROnPlateau callback")
  # only master rank should print messages
  self.verbose = keras_callback.verbose if tnt.is_group_master_rank(self.group) \
                 else utilities.TF_verbose.SILENT.value

def _distribute_callback(self, callback_func: Callable, **kwargs: Any) -> Any:
  # execute the callback only on the group master rank, using post-processed logs
  if tnt.is_group_master_rank(self.group):
    processed_kwargs = self._process_callback_logs(kwargs)
    return callback_func(**processed_kwargs)