def _calculate_mean_and_var(self, x, axes, keep_dims):
  with ops.name_scope('moments', values=[x, axes]):
    # The dynamic range of fp16 is too limited to support the collection of
    # sufficient statistics. As a workaround we simply perform the operations
    # on 32-bit floats before converting the mean and variance back to fp16
    y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
    if horovod_enabled():
      num_shards = hvd.size()
    else:
      num_shards = 1

    if num_shards > 1:
      local_sum = math_ops.reduce_sum(y, axis=axes, keepdims=True)
      local_squared_sum = math_ops.reduce_sum(math_ops.square(y), axis=axes, keepdims=True)
      batch_size = math_ops.cast(array_ops.shape_v2(y)[0], dtypes.float32)
      # y_sum, y_squared_sum, global_batch_size = (
      #     replica_ctx.all_reduce(reduce_util.ReduceOp.SUM, [
      #         local_sum, local_squared_sum, batch_size]))
      # hvd_info(f'local_sum {local_sum.shape}, local_squared_sum {local_squared_sum.shape}')
      y_sum = hvd.allreduce(local_sum, average=False)
      y_squared_sum = hvd.allreduce(local_squared_sum, average=False)
      global_batch_size = batch_size * num_shards

      axes_vals = [(array_ops.shape_v2(y))[i] for i in range(1, len(axes))]
      multiplier = math_ops.cast(math_ops.reduce_prod(axes_vals), dtypes.float32)
      multiplier = multiplier * global_batch_size

      mean = y_sum / multiplier
      y_squared_mean = y_squared_sum / multiplier
      # var = E(x^2) - E(x)^2
      variance = y_squared_mean - math_ops.square(mean)
    else:
      # Compute true mean while keeping the dims for proper broadcasting.
      mean = math_ops.reduce_mean(y, axes, keepdims=True, name='mean')
      # Sample variance, not unbiased variance.
      # Note: stop_gradient does not change the gradient that gets
      # backpropagated to the mean from the variance calculation,
      # because that gradient is zero.
      variance = math_ops.reduce_mean(
          math_ops.squared_difference(y, array_ops.stop_gradient(mean)),
          axes,
          keepdims=True,
          name='variance')

    if not keep_dims:
      mean = array_ops.squeeze(mean, axes)
      variance = array_ops.squeeze(variance, axes)
    if x.dtype == dtypes.float16:
      return (math_ops.cast(mean, dtypes.float16),
              math_ops.cast(variance, dtypes.float16))
    else:
      return (mean, variance)

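
# A minimal, self-contained sketch (not part of the model code) of the sharded
# moments math used above. The summed allreduce (hvd.allreduce(..., average=False))
# is simulated by summing per-shard partial sums in numpy; shard count and shapes
# are illustrative assumptions, and only the batch axis is reduced here.
import numpy as np

def simulated_global_moments(shards):
  """shards: list of [batch, features] arrays; moments are taken over the batch axis."""
  y_sum = sum(s.sum(axis=0, keepdims=True) for s in shards)                # summed allreduce
  y_squared_sum = sum((s ** 2).sum(axis=0, keepdims=True) for s in shards)
  global_batch = sum(s.shape[0] for s in shards)
  mean = y_sum / global_batch
  variance = y_squared_sum / global_batch - mean ** 2                      # Var[X] = E[X^2] - E[X]^2
  return mean, variance

rng = np.random.default_rng(0)
shards = [rng.normal(size=(8, 4)) for _ in range(4)]                       # 4 simulated workers
mean, var = simulated_global_moments(shards)
full = np.concatenate(shards, axis=0)
assert np.allclose(mean, full.mean(axis=0, keepdims=True))
assert np.allclose(var, full.var(axis=0, keepdims=True))
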
def update(accum_vars):
  with tf.control_dependencies([global_step.assign(new_global_step)]):
    if allreduce_post_accumulation and horovod_enabled():
      accum_vars = [
          hvd.allreduce(tf.convert_to_tensor(value=accum_var))
          if isinstance(accum_var, tf.IndexedSlices)
          else hvd.allreduce(accum_var)
          for accum_var in accum_vars
      ]
    return optimizer.apply_gradients(list(zip(accum_vars, tvars)),
                                     global_step=global_step)

def _moments(self, inputs, reduction_axes, keep_dims):
  """Compute the mean and variance: it overrides the original _moments."""
  shard_mean, shard_variance = super(SyncBatchNormalization, self)._moments(
      inputs, reduction_axes, keep_dims=keep_dims)

  num_shards = hvd.size() if horovod_enabled() else 1
  if num_shards > 1:
    # Compute variance using: Var[X] = E[X^2] - E[X]^2.
    shard_square_of_mean = tf.math.square(shard_mean)
    shard_mean_of_square = shard_variance + shard_square_of_mean
    group_mean = hvd.allreduce(shard_mean)
    group_mean_of_square = hvd.allreduce(shard_mean_of_square)
    group_variance = group_mean_of_square - tf.math.square(group_mean)
    return (group_mean, group_variance)
  else:
    return (shard_mean, shard_variance)

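
# Standalone sketch (illustrative, not part of SyncBatchNormalization) of the
# group-statistics recombination above. hvd.allreduce defaults to averaging, so
# averaging per-shard E[X] and E[X^2] reproduces the global moments provided
# every shard contributes the same number of elements (assumed here).
import numpy as np

rng = np.random.default_rng(1)
shards = [rng.normal(loc=i, size=(16, 3)) for i in range(4)]           # equal-sized shards

shard_means = [s.mean(axis=0) for s in shards]
shard_vars = [s.var(axis=0) for s in shards]
shard_mean_of_square = [v + m ** 2 for m, v in zip(shard_means, shard_vars)]  # per-shard E[X^2]

group_mean = np.mean(shard_means, axis=0)                              # averaging allreduce
group_mean_of_square = np.mean(shard_mean_of_square, axis=0)
group_variance = group_mean_of_square - group_mean ** 2                # Var[X] = E[X^2] - E[X]^2

full = np.concatenate(shards, axis=0)
assert np.allclose(group_mean, full.mean(axis=0))
assert np.allclose(group_variance, full.var(axis=0))
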
def __init__(self, num_accumulation_steps=1):
  super(TrainableVarsAllreducingHookPreOpt, self).__init__()
  # Modify this collection in order to allreduce a different set of variables.
  trainable_vars = tf.compat.v1.trainable_variables()
  allreduced_trainable_var_ops = [
      v.assign(hvd.allreduce(v)) for v in trainable_vars
  ]
  self.allreduce_trainable_vars_op = tf.group(*allreduced_trainable_var_ops)
  self.num_accumulation_steps = num_accumulation_steps
  self.current_iteration = 1

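
# Toy numpy sketch (hypothetical 4 workers, not part of the hook) of what the
# v.assign(hvd.allreduce(v)) ops above do: each trainable variable is replaced by
# the element-wise average of its copies across workers, since hvd.allreduce
# averages by default.
import numpy as np

worker_copies = [np.array([1.0, 2.0, 3.0]) + i for i in range(4)]   # diverged replicas
averaged = np.mean(worker_copies, axis=0)                            # simulated allreduce
print(averaged)   # [2.5, 3.5, 4.5] -> value assigned back to every worker's variable
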
def eval_end(self):
  """See base class."""
  if self.flags_obj.use_distributed_eval and horovod_enabled():
    test_accuracy = hvd.allreduce(self.test_accuracy.result())
  else:
    test_accuracy = self.test_accuracy.result()

  return {
      'test_loss': self.test_loss.result(),
      'test_accuracy': test_accuracy
  }

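
# Illustrative sketch (assumed per-worker numbers, not from the training code):
# with a default averaging allreduce, the aggregated metric is the mean of the
# per-worker accuracies, which equals the true global accuracy only when every
# worker evaluates the same number of examples.
hits   = [480, 470, 500, 455]      # correct predictions per worker (hypothetical)
counts = [625, 625, 625, 625]      # examples evaluated per worker (equal split)

per_worker_acc = [h / c for h, c in zip(hits, counts)]
averaged = sum(per_worker_acc) / len(per_worker_acc)   # what an averaging allreduce yields
exact = sum(hits) / sum(counts)                        # global accuracy
assert abs(averaged - exact) < 1e-12
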
def begin(self):
  if self._use_all_reduce:
    self._avg_ops = OrderedDict({
        '{}'.format(tag): hvd.allreduce(
            basic_session_run_hooks._as_graph_element(tensor))
        for (tag, tensor) in self._named_tensor.items()
    })
  else:
    self._avg_ops = OrderedDict({
        '{}'.format(tag): basic_session_run_hooks._as_graph_element(tensor)
        for (tag, tensor) in self._named_tensor.items()
    })

  self._global_step_tensor = tf.train.get_or_create_global_step()
  self._avg_ops['step'] = self._global_step_tensor

def resnet_main(
    flags_obj, model_function, input_function, dataset_name, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    flags_obj: An object containing parsed flags. See define_resnet_flags()
      for details.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with
      all the relevant flags for running and passed to estimator.
    dataset_name: the name of the dataset for training and evaluation. This
      is used for logging purposes.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags_obj.export_dir is passed.

  Returns:
    Dict of results of the run. Contains the keys `eval_results` and
      `train_hooks`. `eval_results` contains accuracy (top_1) and
      accuracy_top_5. `train_hooks` is a list of the hook instances used
      during training.
  """
  experimental_preloading = flags_obj.experimental_preloading

  model_helpers.apply_clean(flags.FLAGS)

  # Ensures flag override logic is only executed if explicitly triggered.
  if flags_obj.tf_gpu_thread_mode:
    override_flags_and_set_envars_for_gpu_thread_pool(flags_obj)

  # Configures cluster spec for distribution strategy.
  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
                                                     flags_obj.task_index)

  # Creates session config. allow_soft_placement = True is required for
  # multi-GPU and is not harmful for other modes.
  session_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags_obj.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags_obj.intra_op_parallelism_threads,
      allow_soft_placement=not experimental_preloading)

  if horovod_enabled():
    # The Scoped Allocator Optimization is enabled by default unless disabled
    # by a flag.
    if not condition_env_var('TF_DISABLE_SCOPED_ALLOCATOR', default=False):
      from tensorflow.core.protobuf import rewriter_config_pb2  # pylint: disable=import-error
      session_config.graph_options.rewrite_options.scoped_allocator_optimization = rewriter_config_pb2.RewriterConfig.ON

      enable_op = session_config.graph_options.rewrite_options.scoped_allocator_opts.enable_op
      del enable_op[:]
      enable_op.append("HorovodAllreduce")

  distribution_strategy = distribution_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_core.get_num_gpus(flags_obj),
      num_workers=num_workers,
      all_reduce_alg=flags_obj.all_reduce_alg,
      num_packs=flags_obj.num_packs)

  # Creates a `RunConfig` that checkpoints every 24 hours which essentially
  # results in checkpoints determined only by `epochs_between_evals`.
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution_strategy,
      session_config=session_config,
      log_step_count_steps=flags_obj.display_steps,
      save_checkpoints_secs=None,
      save_checkpoints_steps=flags_obj.save_checkpoint_steps)

  # Initializes model with all but the dense layer from pretrained ResNet.
  # if flags_obj.pretrained_model_checkpoint_path is not None:
  #   warm_start_settings = tf.estimator.WarmStartSettings(
  #       flags_obj.pretrained_model_checkpoint_path,
  #       vars_to_warm_start='^(?!.*dense)')
  # else:
  #   warm_start_settings = None
  warm_start_settings = None

  model_dir = flags_obj.model_dir
  if horovod_enabled():
    model_dir = "{}/rank_{}".format(flags_obj.model_dir, hvd.rank())

  if experimental_preloading:
    SelectedEstimator = HabanaEstimator
  else:
    SelectedEstimator = tf.estimator.Estimator

  if flags.FLAGS.is_mlperf_enabled:
    for eval_batch_size in range(flags_obj.batch_size, 1, -1):
      if imagenet_main.NUM_IMAGES['validation'] % eval_batch_size == 0:
        break
  else:
    eval_batch_size = flags_obj.batch_size

  classifier = SelectedEstimator(
      model_fn=model_function,
      model_dir=model_dir,
      config=run_config,
      warm_start_from=warm_start_settings,
      params={
          'resnet_size': int(flags_obj.resnet_size),
          'data_format': flags_obj.data_format,
          'batch_size': flags_obj.batch_size,
          'resnet_version': int(flags_obj.resnet_version),
          'model_type': flags_obj.model_type,
          'loss_scale': flags_core.get_loss_scale(flags_obj,
                                                  default_for_fp16=128),
          'dtype': flags_core.get_tf_dtype(flags_obj),
          'fine_tune': flags_obj.fine_tune,
          'num_workers': num_workers,
          'train_epochs': flags_obj.train_epochs,
          'warmup_epochs': flags_obj.warmup_epochs,
          'use_cosine_lr': flags_obj.use_cosine_lr,
      })

  run_params = {
      'batch_size': flags_obj.batch_size,
      'dtype': flags_core.get_tf_dtype(flags_obj),
      'resnet_size': flags_obj.resnet_size,
      'resnet_version': flags_obj.resnet_version,
      'model_type': flags_obj.model_type,
      'synthetic_data': flags_obj.use_synthetic_data,
      'train_epochs': flags_obj.train_epochs,
      'num_workers': num_workers,
  }
  if flags.FLAGS.is_mlperf_enabled:
    run_params['eval_batch_size'] = eval_batch_size

  if flags_obj.use_synthetic_data:
    dataset_name = dataset_name + '-synthetic'

  benchmark_logger = logger.get_benchmark_logger()
  benchmark_logger.log_run_info('resnet', dataset_name, run_params,
                                test_id=flags_obj.benchmark_test_id)

  train_hooks = hooks_helper.get_train_hooks(
      flags_obj.hooks,
      model_dir=model_dir,
      batch_size=flags_obj.batch_size)

  if flags.FLAGS.is_mlperf_enabled:
    _log_cache = []

    def formatter(x):
      """Abuse side effects to get tensors out of the model_fn."""
      if _log_cache:
        _log_cache.pop()
      _log_cache.append(x.copy())
      return str(x)

    compliance_hook = tf.estimator.LoggingTensorHook(
        tensors={_NUM_EXAMPLES_NAME: _NUM_EXAMPLES_NAME},
        every_n_iter=int(1e10),
        at_end=True,
        formatter=formatter)
  else:
    compliance_hook = None

  if horovod_enabled():
    if "tf_profiler_hook" not in flags_obj.hooks and os.environ.get("TF_RANGE_TRACE", False):
      from TensorFlow.common.utils import RangeTFProfilerHook
      begin = (imagenet_main.NUM_IMAGES["train"] // (flags_obj.batch_size * hvd.size()) + 100)
      train_hooks.append(RangeTFProfilerHook(begin, 20, "./rank-{}".format(hvd.rank())))

    if "synapse_logger_hook" not in flags_obj.hooks and "range" == os.environ.get("HABANA_SYNAPSE_LOGGER", "False").lower():
      from TensorFlow.common.horovod_helpers import SynapseLoggerHook
      begin = (imagenet_main.NUM_IMAGES["train"] // (flags_obj.batch_size * hvd.size()) + 100)
      end = begin + 100
      print("Begin: {}".format(begin))
      print("End: {}".format(end))
      train_hooks.append(SynapseLoggerHook(list(range(begin, end)), False))
    train_hooks.append(hvd.BroadcastGlobalVariablesHook(0))

  def input_fn_train(num_epochs, input_context=None):
    return input_function(
        is_training=True,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            flags_obj.batch_size,
            flags_core.get_num_gpus(flags_obj)),
        num_epochs=num_epochs,
        dtype=flags_core.get_dl_type(flags_obj),
        datasets_num_private_threads=flags_obj.datasets_num_private_threads,
        input_context=input_context,
        experimental_preloading=experimental_preloading)

  def input_fn_eval():
    return input_function(
        is_training=False,
        data_dir=flags_obj.data_dir,
        batch_size=distribution_utils.per_replica_batch_size(
            eval_batch_size, flags_core.get_num_gpus(flags_obj)),
        num_epochs=1,
        dtype=flags_core.get_dl_type(flags_obj),
        experimental_preloading=experimental_preloading)

  train_epochs = (0 if flags_obj.eval_only or not flags_obj.train_epochs else
                  flags_obj.train_epochs)

  max_train_steps = flags_obj.max_train_steps
  global_batch_size = flags_obj.batch_size * (hvd.size() if horovod_enabled() else 1)
  steps_per_epoch = (imagenet_main.NUM_IMAGES['train'] // global_batch_size)
  if max_train_steps is None:
    max_train_steps = steps_per_epoch * (train_epochs + flags_obj.train_offset)

  max_eval_steps = flags_obj.max_eval_steps
  if max_eval_steps is None:
    max_eval_steps = (imagenet_main.NUM_IMAGES['validation'] + eval_batch_size - 1) // eval_batch_size

  use_train_and_evaluate = flags_obj.use_train_and_evaluate or num_workers > 1
  if use_train_and_evaluate:
    train_spec = tf.estimator.TrainSpec(
        input_fn=lambda input_context=None: input_fn_train(
            train_epochs, input_context=input_context),
        hooks=train_hooks,
        max_steps=max_train_steps)
    eval_spec = tf.estimator.EvalSpec(input_fn=input_fn_eval)
    tf.compat.v1.logging.info('Starting to train and evaluate.')
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
    # tf.estimator.train_and_evaluate doesn't return anything in the
    # multi-worker case.
    eval_results = {}
  else:
    if train_epochs == 0:
      # If --eval_only is set, perform a single loop with zero train epochs.
      schedule, n_loops = [0], 1
    else:
      # Compute the number of times to loop while training. All but the last
      # pass will train for `epochs_between_evals` epochs, while the last will
      # train for the number needed to reach `training_epochs`. For instance if
      #   train_epochs = 25 and epochs_between_evals = 10
      # schedule will be set to [10, 10, 5]. That is to say, the loop will:
      #   Train for 10 epochs and then evaluate.
      #   Train for another 10 epochs and then evaluate.
      #   Train for a final 5 epochs (to reach 25 epochs) and then evaluate.
      n_loops = math.ceil(train_epochs / flags_obj.epochs_between_evals)
      schedule = [flags_obj.epochs_between_evals for _ in range(int(n_loops))]
      schedule[-1] = train_epochs - sum(schedule[:-1])  # over counting.
    if flags.FLAGS.is_mlperf_enabled:
      mllogger.event(key=mllog.constants.CACHE_CLEAR)
      mllogger.start(key=mllog.constants.RUN_START)
      mllogger.event(key=mllog.constants.GLOBAL_BATCH_SIZE, value=global_batch_size)

    final_step = 0

    if flags.FLAGS.is_mlperf_enabled:
      success = False
      if flags_obj.train_offset > 0:
        final_step += flags_obj.train_offset * steps_per_epoch
        mllogger.event(key=mllog.constants.FIRST_EPOCH_NUM, value=1,
                       metadata={'number of epochs before main loop: ': flags_obj.train_offset})
        for i in range(flags_obj.train_offset):
          mllogger.event(key=mllog.constants.EPOCH_NUM, value=i + 1)
        classifier.train(
            input_fn=lambda input_context=None: input_fn_train(
                flags_obj.train_offset, input_context=input_context),
            hooks=train_hooks + [compliance_hook],
            max_steps=max_train_steps if max_train_steps < final_step else final_step)

    for cycle_index, num_train_epochs in enumerate(schedule):
      tf.compat.v1.logging.info('Starting cycle: %d/%d', cycle_index, int(n_loops))

      if flags.FLAGS.is_mlperf_enabled:
        mllogger.start(key=mllog.constants.BLOCK_START, value=cycle_index + 1)
        mllogger.event(key=mllog.constants.FIRST_EPOCH_NUM,
                       value=cycle_index * flags_obj.epochs_between_evals + flags_obj.train_offset + 1)
        mllogger.event(key=mllog.constants.EPOCH_COUNT, value=flags_obj.epochs_between_evals)

        for j in range(flags_obj.epochs_between_evals):
          mllogger.event(key=mllog.constants.EPOCH_NUM,
                         value=cycle_index * flags_obj.epochs_between_evals + j + flags_obj.train_offset + 1)

      if num_train_epochs:
        # Since we are calling classifier.train immediately in each loop, the
        # value of num_train_epochs in the lambda function will not be changed
        # before it is used. So it is safe to ignore the pylint error here
        # pylint: disable=cell-var-from-loop
        final_step += num_train_epochs * steps_per_epoch
        classifier.train(
            input_fn=lambda input_context=None: input_fn_train(
                num_train_epochs, input_context=input_context),
            hooks=train_hooks + [compliance_hook] if compliance_hook is not None else train_hooks,
            max_steps=max_train_steps if max_train_steps < final_step else final_step)
        if flags.FLAGS.is_mlperf_enabled:
          mllogger.end(key=mllog.constants.BLOCK_STOP, value=cycle_index + 1)

      if flags.FLAGS.is_mlperf_enabled:
        mllogger.start(key=mllog.constants.EVAL_START)
      # max_eval_steps is associated with testing and profiling.
      # As a result it is frequently called with synthetic data,
      # which will iterate forever. Passing steps=max_eval_steps
      # allows the eval (which is generally unimportant in those circumstances)
      # to terminate. Note that eval will run for max_eval_steps each loop,
      # regardless of the global_step count.
      if flags_obj.get_flag_value("return_before_eval", False):
        return {}
      if flags_obj.get_flag_value("disable_eval", False):
        eval_results = None
        continue
      tf.compat.v1.logging.info('Starting to evaluate.')
      eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                         steps=max_eval_steps)

      if flags.FLAGS.is_mlperf_enabled:
        mllogger.event(key=mllog.constants.EVAL_SAMPLES, value=int(eval_results[_NUM_EXAMPLES_NAME]))
        validation_epoch = (cycle_index + 1) * flags_obj.epochs_between_evals + flags_obj.train_offset
        mllogger.event(key=mllog.constants.EVAL_ACCURACY,
                       value=float(eval_results['accuracy']),
                       metadata={'epoch_num: ': validation_epoch})
        mllogger.end(key=mllog.constants.EVAL_STOP,
                     metadata={'epoch_num: ': validation_epoch})
        if flags_obj.stop_threshold:
          success = bool(eval_results['accuracy'] >= flags_obj.stop_threshold)

      benchmark_logger.log_evaluation_result(eval_results)

      if flags_obj.stop_threshold:
        if horovod_enabled():
          past_threshold = tf.cast(model_helpers.past_stop_threshold(
              flags_obj.stop_threshold, eval_results['accuracy']), tf.float32)
          global_past_threshold = tf.math.greater(
              hvd.allreduce(past_threshold, op=hvd.Sum), tf.zeros(1, tf.float32))
          if global_past_threshold.eval(session=tf.compat.v1.Session()):
            break
        else:
          if model_helpers.past_stop_threshold(
              flags_obj.stop_threshold, eval_results['accuracy']):
            break

  if flags_obj.export_dir is not None:
    # Exports a saved model for the given classifier.
    export_dtype = flags_core.get_tf_dtype(flags_obj)
    if flags_obj.image_bytes_as_serving_input:
      input_receiver_fn = functools.partial(
          image_bytes_serving_input_fn, shape, dtype=export_dtype)
    else:
      input_receiver_fn = export.build_tensor_serving_input_receiver_fn(
          shape, batch_size=flags_obj.batch_size, dtype=export_dtype)
    classifier.export_savedmodel(flags_obj.export_dir, input_receiver_fn,
                                 strip_default_attrs=True)

  stats = {}
  stats['eval_results'] = eval_results
  stats['train_hooks'] = train_hooks

  if flags.FLAGS.is_mlperf_enabled:
    mllogger.event(key=mllog.constants.RUN_STOP, value={"success": success})
    mllogger.end(key=mllog.constants.RUN_STOP)

  return stats

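
# Worked sketch (hypothetical flag values, pure Python) of the epoch schedule and
# step budget computed in resnet_main above: train_epochs=25 with
# epochs_between_evals=10 yields the [10, 10, 5] schedule described in the
# comments, and max_train_steps defaults to steps_per_epoch * (train_epochs + train_offset).
import math

train_epochs = 25
epochs_between_evals = 10
train_offset = 0
batch_size, num_workers = 256, 8
num_train_images = 1281167                      # ImageNet-1k training set size

global_batch_size = batch_size * num_workers
steps_per_epoch = num_train_images // global_batch_size
max_train_steps = steps_per_epoch * (train_epochs + train_offset)

n_loops = math.ceil(train_epochs / epochs_between_evals)
schedule = [epochs_between_evals for _ in range(int(n_loops))]
schedule[-1] = train_epochs - sum(schedule[:-1])
assert schedule == [10, 10, 5]
print(steps_per_epoch, max_train_steps, schedule)   # 625 15625 [10, 10, 5]
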
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps,
                     manual_fp16=False, use_fp16=False, num_accumulation_steps=1,
                     optimizer_type="adam", allreduce_post_accumulation=False,
                     init_loss_scale=2**32, weight_decay_rate=0.01, beta_1=0.9,
                     beta_2=0.999, epsilon=1e-6, power=0.5, use_tpu=False):
  """Creates an optimizer training op."""
  global_step = tf.compat.v1.train.get_or_create_global_step()

  # avoid step change in learning rate at end of warmup phase
  if optimizer_type == "adam":
    power = 1.0
    decayed_learning_rate_at_crossover_point = init_lr * (
        (1.0 - float(num_warmup_steps) / float(num_train_steps)) ** power)
  else:
    power = power
    decayed_learning_rate_at_crossover_point = init_lr

  adjusted_init_lr = init_lr * (init_lr / decayed_learning_rate_at_crossover_point)
  print('decayed_learning_rate_at_crossover_point = %e, adjusted_init_lr = %e' %
        (decayed_learning_rate_at_crossover_point, adjusted_init_lr))

  learning_rate = tf.constant(value=adjusted_init_lr, shape=[], dtype=tf.float32)

  # Implements linear decay of the learning rate.
  learning_rate = tf.compat.v1.train.polynomial_decay(
      learning_rate,
      global_step - 1,  ## We first update global_step, then apply_grad and thus we use global_step-1.
      num_train_steps,
      end_learning_rate=0.0,
      power=power,
      cycle=False)

  # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
  # learning rate will be `global_step/num_warmup_steps * init_lr`.
  if num_warmup_steps:
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = init_lr * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = ((1.0 - is_warmup) * learning_rate +
                     is_warmup * warmup_learning_rate)

  if optimizer_type == "lamb":
    print("Initializing LAMB Optimizer")
    optimizer = LAMBOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=weight_decay_rate,
        beta_1=beta_1,
        beta_2=beta_2,
        epsilon=epsilon,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
  else:
    print("Initializing ADAM Weight Decay Optimizer")
    # It is recommended that you use this optimizer for fine tuning, since this
    # is how the model was trained (note that the Adam m/v variables are NOT
    # loaded from init_checkpoint.)
    optimizer = AdamWeightDecayOptimizer(
        learning_rate=learning_rate,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

  if horovod_enabled() and (num_accumulation_steps == 1 or (not allreduce_post_accumulation)):
    optimizer = hvd.DistributedOptimizer(optimizer, sparse_as_dense=True)
  if use_fp16:
    loss_scaler = tf.train.experimental.DynamicLossScale(
        initial_loss_scale=init_loss_scale, increment_period=1000, multiplier=2.0)
    optimizer = tf.train.experimental.enable_mixed_precision_graph_rewrite(optimizer, loss_scaler)
    loss_scale_value = tf.identity(loss_scaler(), name="loss_scale")
  if manual_fp16:
    assert False, "No support for ExponentialUpdateLossScaleManager and LossScaleOptimizer in TF2.0"
    loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
        init_loss_scale=init_loss_scale,
        incr_every_n_steps=1000,
        decr_every_n_nan_or_inf=2,
        decr_ratio=0.5)
    optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(optimizer, loss_scale_manager)
  if use_tpu:
    optimizer = tf.compat.v1.tpu.CrossShardOptimizer(optimizer)

  tvars = tf.compat.v1.trainable_variables()

  if num_accumulation_steps > 1:
    # grads_and_vars = optimizer.compute_gradients(loss * 1.0 / num_accumulation_steps, tvars)
    ## to match mlcomm ref we need to clip before scaling
    grads_and_vars = optimizer.compute_gradients(loss, tvars,
                                                 gate_gradients=tf.compat.v1.train.Optimizer.GATE_NONE)
    local_step = tf.compat.v1.get_variable(name="local_step", shape=[], dtype=tf.int32,
                                           trainable=False,
                                           initializer=tf.compat.v1.zeros_initializer)
    batch_finite = tf.compat.v1.get_variable(name="batch_finite", shape=[], dtype=tf.bool,
                                             trainable=False,
                                             initializer=tf.compat.v1.ones_initializer)
    accum_vars = [
        tf.compat.v1.get_variable(
            name=tvar.name.split(":")[0] + "/accum",
            shape=tvar.shape.as_list(),
            dtype=tf.float32,
            trainable=False,
            initializer=tf.compat.v1.zeros_initializer())
        for tvar in tf.compat.v1.trainable_variables()
    ]

    reset_step = tf.cast(tf.math.equal(local_step % num_accumulation_steps, 0), dtype=tf.bool)
    local_step = tf.cond(pred=reset_step,
                         true_fn=lambda: local_step.assign(tf.ones_like(local_step)),
                         false_fn=lambda: local_step.assign_add(1))

    grads_and_vars_and_accums = [(gv[0], gv[1], accum_vars[i])
                                 for i, gv in enumerate(grads_and_vars) if gv[0] is not None]
    grads, tvars, accum_vars = list(zip(*grads_and_vars_and_accums))

    all_are_finite = tf.reduce_all(
        input_tensor=[tf.reduce_all(input_tensor=tf.math.is_finite(g)) for g in grads]
    ) if manual_fp16 or use_fp16 else tf.constant(True, dtype=tf.bool)
    batch_finite = tf.cond(pred=reset_step,
                           true_fn=lambda: batch_finite.assign(
                               tf.math.logical_and(tf.constant(True, dtype=tf.bool), all_are_finite)),
                           false_fn=lambda: batch_finite.assign(
                               tf.math.logical_and(batch_finite, all_are_finite)))

    # This is how the model was pre-trained.
    # Ensure the global norm is a finite number
    # to prevent clip_by_global_norm from having a hissy fit.
    (clipped_grads, _) = tf.clip_by_global_norm(
        grads, clip_norm=1.0,
        use_norm=tf.cond(
            pred=all_are_finite,
            true_fn=lambda: tf.linalg.global_norm(grads),
            false_fn=lambda: tf.constant(1.0)))

    ## divide grad by acc_steps before accumulating
    accum_vars = tf.cond(pred=reset_step,
                         true_fn=lambda: [accum_vars[i].assign(grad)
                                          for i, grad in enumerate(clipped_grads)],
                         false_fn=lambda: [accum_vars[i].assign_add(grad)
                                           for i, grad in enumerate(clipped_grads)])

    update_step = tf.identity(tf.cast(tf.math.equal(local_step % num_accumulation_steps, 0),
                                      dtype=tf.bool),
                              name="update_step")

    def allreduce_of_batch_finite_required():
      # In the bf16 and fp32 cases batch_finite is tf.constant(True, dtype=tf.bool).
      return horovod_enabled() and manual_fp16 and use_fp16

    # TODO: if in the future we want to enable skipping of non-finite batch
    # iterations, this allreduce will need to change.
    new_global_step = tf.cond(
        pred=tf.math.logical_and(
            update_step,
            tf.cast(hvd.allreduce(tf.cast(batch_finite, tf.int32)), tf.bool)
            if allreduce_of_batch_finite_required() else batch_finite),
        true_fn=lambda: global_step + 1,
        false_fn=lambda: global_step)
    new_global_step = tf.identity(new_global_step, name='step_update')

    def update(accum_vars):
      with tf.control_dependencies([global_step.assign(new_global_step)]):
        if allreduce_post_accumulation and horovod_enabled():
          accum_vars = [
              hvd.allreduce(tf.convert_to_tensor(value=accum_var) * 1.0 / num_accumulation_steps,
                            op=hvd.Sum)
              if isinstance(accum_var, tf.IndexedSlices)
              else hvd.allreduce(accum_var * 1.0 / num_accumulation_steps, op=hvd.Sum)
              for accum_var in accum_vars
          ]
        return optimizer.apply_gradients(list(zip(accum_vars, tvars)),
                                         global_step=global_step)

    train_op = tf.cond(pred=update_step,
                       true_fn=lambda: update(accum_vars),
                       false_fn=lambda: tf.no_op())
  else:
    grads_and_vars = optimizer.compute_gradients(loss, tvars,
                                                 gate_gradients=tf.compat.v1.train.Optimizer.GATE_NONE)
    grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
    grads, tvars = list(zip(*grads_and_vars))
    all_are_finite = tf.reduce_all(
        input_tensor=[tf.reduce_all(input_tensor=tf.math.is_finite(g)) for g in grads]
    ) if use_fp16 or manual_fp16 else tf.constant(True, dtype=tf.bool)

    # This is how the model was pre-trained.
    # Ensure the global norm is a finite number
    # to prevent clip_by_global_norm from having a hissy fit.
    (clipped_grads, _) = tf.clip_by_global_norm(
        grads, clip_norm=1.0,
        use_norm=tf.cond(
            pred=all_are_finite,
            true_fn=lambda: tf.linalg.global_norm(grads),
            false_fn=lambda: tf.constant(1.0)))

    new_global_step = tf.cond(pred=all_are_finite,
                              true_fn=lambda: global_step + 1,
                              false_fn=lambda: global_step)
    new_global_step = tf.identity(new_global_step, name='step_update')

    with tf.control_dependencies([global_step.assign(new_global_step)]):
      train_op = optimizer.apply_gradients(
          list(zip(clipped_grads, tvars)), global_step=global_step)
  return train_op

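
# Standalone sketch (pure Python, illustrative hyperparameters) of the learning-rate
# schedule built in create_optimizer above: linear warmup to init_lr over
# num_warmup_steps, then polynomial decay of the adjusted initial rate to zero over
# num_train_steps. The "adam" branch forces power=1.0, as in the code above.
def lr_at(step, init_lr=4e-4, num_train_steps=1000, num_warmup_steps=100, power=1.0):
  crossover = init_lr * (1.0 - num_warmup_steps / num_train_steps) ** power
  adjusted_init_lr = init_lr * (init_lr / crossover)
  decayed = adjusted_init_lr * (1.0 - min(step, num_train_steps) / num_train_steps) ** power
  if step < num_warmup_steps:
    return init_lr * step / num_warmup_steps     # linear warmup
  return decayed

# The adjusted initial rate makes the decayed value meet init_lr exactly at the end
# of warmup, so there is no step change (the real graph passes global_step - 1 to
# polynomial_decay, which shifts this by one step).
assert abs(lr_at(100) - 4e-4) < 1e-12
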
def eval_end(self):
  """See base class."""
  epoch_num = int(self.epoch_helper.current_epoch)
  self.mlperf_mlloger.end(
      key=self.mlperf_mllog.constants.EVAL_STOP,
      value=None,
      metadata={'epoch_num': epoch_num + 1})

  local_hit = self.test_accuracy.total
  local_count = self.test_accuracy.count
  global_hit = local_hit
  global_count = local_count
  if horovod_enabled() and self.dist_eval:
    global_hit = hvd.allreduce(local_hit, op=hvd.Sum)
    global_count = hvd.allreduce(local_count, op=hvd.Sum)
  global_accuracy = float(global_hit / global_count)

  # assign to self
  self.test_accuracy.total.assign(global_hit)
  self.test_accuracy.count.assign(global_count)

  eval_accuracy = global_accuracy
  self.eval_accuracy = eval_accuracy
  self.mlperf_mlloger.event(
      key=self.mlperf_mllog.constants.EVAL_ACCURACY,
      value=eval_accuracy,
      metadata={'epoch_num': epoch_num + 1})

  first_epoch_num = max(epoch_num - self.flags_obj.epochs_between_evals + 1, 0)
  epoch_count = self.flags_obj.epochs_between_evals
  if first_epoch_num == 0:
    epoch_count = self.flags_obj.eval_offset_epochs
    if epoch_count == 0:
      epoch_count = self.flags_obj.epochs_between_evals
  self.mlperf_mlloger.end(
      key=self.mlperf_mllog.constants.BLOCK_STOP,
      value=None,
      metadata={
          'first_epoch_num': first_epoch_num + 1,
          'epoch_count': epoch_count
      })

  past_threshold = False
  if self.flags_obj.target_accuracy is not None:
    past_threshold = eval_accuracy >= self.flags_obj.target_accuracy
    if (horovod_enabled() and (not self.dist_eval)):
      past_threshold = hvd.allreduce(
          tf.cast(past_threshold, tf.float32), op=hvd.Sum) > 0

  continue_training = True
  if past_threshold:
    continue_training = False
  elif ((not self.profile) and eval_accuracy <= 0.002):
    continue_training = False
  elif self.global_step.numpy() < self.train_steps:
    self.mlperf_mlloger.start(
        key=self.mlperf_mllog.constants.BLOCK_START,
        value=None,
        metadata={
            'first_epoch_num': epoch_num + 2,
            'epoch_count': self.flags_obj.epochs_between_evals
        })

  metrics = {
      'test_accuracy': eval_accuracy,
      'continue_training': continue_training,
  }
  if self.test_loss:
    metrics['test_loss'] = self.test_loss.result()
  return metrics

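
# Small sketch (made-up per-worker numbers) of the hit/count aggregation used in
# eval_end above: summing hits and counts across workers (op=hvd.Sum) gives the
# exact global accuracy even when workers evaluate different numbers of examples,
# unlike averaging per-worker accuracies.
hits   = [480, 470, 500, 300]      # correct predictions per worker (hypothetical)
counts = [625, 625, 625, 400]      # last worker received a smaller shard

exact = sum(hits) / sum(counts)                                    # ratio of sums
averaged = sum(h / c for h, c in zip(hits, counts)) / len(hits)    # mean of ratios
print(exact, averaged)             # the two differ when counts are unequal
assert abs(exact - averaged) > 1e-3
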