def __init__(self, tensors, log_dir=None, metric_logger=None,
             every_n_iter=None, every_n_secs=None, at_end=False):
  """Initializer for LoggingMetricHook.

  Args:
    tensors: `dict` that maps string-valued tags to tensors/tensor names,
      or `iterable` of tensors/tensor names.
    log_dir: `string`, directory path that the metric hook should write
      logs to.
    metric_logger: instance of `BenchmarkLogger`, the benchmark logger that
      the hook should use to write the log. Exactly one of `log_dir` and
      `metric_logger` should be provided.
    every_n_iter: `int`, print the values of `tensors` once every N local
      steps taken on the current worker.
    every_n_secs: `int` or `float`, print the values of `tensors` once every
      N seconds. Exactly one of `every_n_iter` and `every_n_secs` should be
      provided.
    at_end: `bool` specifying whether to print the values of `tensors` at the
      end of the run.

  Raises:
    ValueError: if
      1. `every_n_iter` is non-positive, or
      2. neither or both of `every_n_iter` and `every_n_secs` are provided, or
      3. neither or both of `log_dir` and `metric_logger` are provided.
  """
  super(LoggingMetricHook, self).__init__(
      tensors=tensors,
      every_n_iter=every_n_iter,
      every_n_secs=every_n_secs,
      at_end=at_end)

  if (log_dir is None) == (metric_logger is None):
    raise ValueError(
        "Exactly one of log_dir and metric_logger should be provided.")

  if log_dir is not None:
    self._logger = logger.BenchmarkLogger(log_dir)
  else:
    self._logger = metric_logger
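# A minimal usage sketch (not part of the original module): it assumes the
# hook is constructed with an existing `BenchmarkLogger` instance, and the
# tensor name "cross_entropy" is hypothetical.
def _example_logging_metric_hook(metric_logger):
  """Builds a LoggingMetricHook that logs `cross_entropy` every 600 seconds."""
  return LoggingMetricHook(
      tensors=["cross_entropy"],    # hypothetical tensor name in the graph
      metric_logger=metric_logger,  # exactly one of metric_logger / log_dir
      every_n_secs=600,             # exactly one of every_n_secs / every_n_iter
      at_end=True)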
def resnet_main(seed, flags, model_function, input_function, shape=None):
  """Shared main loop for ResNet Models.

  Args:
    seed: `int`, random seed passed to the estimator `RunConfig` and logged
      for MLPerf compliance.
    flags: FLAGS object that contains the params for running. See
      ResnetArgParser for created flags.
    model_function: the function that instantiates the Model and builds the
      ops for train/eval. This will be passed directly into the estimator.
    input_function: the function that processes the dataset and returns a
      dataset that the estimator can train on. This will be wrapped with all
      the relevant flags for running and passed to the estimator.
    shape: list of ints representing the shape of the images used for
      training. This is only used if flags.export_dir is passed.
  """
  # Using the Winograd non-fused algorithms provides a small performance boost.
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

  # Create session config based on values of inter_op_parallelism_threads and
  # intra_op_parallelism_threads. Note that we default to having
  # allow_soft_placement = True, which is required for multi-GPU and not
  # harmful for other modes.
  session_config = tf.compat.v1.ConfigProto(
      inter_op_parallelism_threads=flags.inter_op_parallelism_threads,
      intra_op_parallelism_threads=flags.intra_op_parallelism_threads,
      allow_soft_placement=True)

  if flags.num_gpus == 0:
    distribution = tf.distribute.OneDeviceStrategy('device:CPU:0')
  elif flags.num_gpus == 1:
    distribution = tf.distribute.OneDeviceStrategy('device:GPU:0')
  else:
    # tf.distribute.MirroredStrategy takes an explicit device list rather
    # than a GPU count.
    distribution = tf.distribute.MirroredStrategy(
        devices=['/device:GPU:%d' % d for d in range(flags.num_gpus)])

  mllogger.event(key=mllog.constants.SEED, value=seed)
  run_config = tf.estimator.RunConfig(
      train_distribute=distribution,
      session_config=session_config,
      log_step_count_steps=20,  # output logs more frequently
      save_checkpoints_steps=2502,
      keep_checkpoint_max=1,
      tf_random_seed=seed)

  # The effective global batch spans all Horovod workers when running under
  # MPI; otherwise it is the single-process batch size.
  global_batch_size = (
      flags.batch_size * hvd.size() if is_mpi else flags.batch_size)
  mllogger.event(key=mllog.constants.GLOBAL_BATCH_SIZE,
                 value=global_batch_size)

  if is_mpi:
    if hvd.rank() == 0:
      model_dir = os.path.join(flags.model_dir, "main")
    else:
      model_dir = os.path.join(flags.model_dir, "tmp{}".format(hvd.rank()))
    benchmark_log_dir = flags.benchmark_log_dir if hvd.rank() == 0 else None
  else:
    model_dir = flags.model_dir
    benchmark_log_dir = flags.benchmark_log_dir

  classifier = tf.estimator.Estimator(
      model_fn=model_function, model_dir=model_dir, config=run_config,
      params={
          'resnet_size': flags.resnet_size,
          'data_format': flags.data_format,
          'batch_size': flags.batch_size,
          'version': flags.version,
          'loss_scale': flags.loss_scale,
          'dtype': flags.dtype,
          'label_smoothing': flags.label_smoothing,
          'enable_lars': flags.enable_lars,
          'weight_decay': flags.weight_decay,
          'fine_tune': flags.fine_tune,
          'use_bfloat16': flags.use_bfloat16
      })

  if benchmark_log_dir is not None:
    benchmark_logger = logger.BenchmarkLogger(benchmark_log_dir)
    benchmark_logger.log_run_info('resnet')
  else:
    benchmark_logger = None

  # For MPI only: figure out the steps per epoch and per eval, per worker.
  if is_mpi:
    num_eval_steps = _NUM_IMAGES['validation'] // flags.batch_size
    steps_per_epoch = _NUM_IMAGES['train'] // flags.batch_size
    steps_per_epoch_per_worker = steps_per_epoch // hvd.size()
    steps_per_eval_per_worker = (
        steps_per_epoch_per_worker * flags.epochs_between_evals)

  # The reference performs the first evaluation on the fourth epoch. (offset
  # eval by 3 epochs)
  success = False
  for i in range(flags.train_epochs // flags.epochs_between_evals):
    # Data for epochs_between_evals (i.e. 4 epochs between evals) worth of
    # epochs is concatenated and run as a single block inside a session. For
    # this reason we declare all of the epochs that will be run at the start.
    # Submitters may report in a way which is reasonable for their control
    # flow.
    mllogger.start(key=mllog.constants.BLOCK_START, value=i + 1)
    mllogger.event(key=mllog.constants.FIRST_EPOCH_NUM,
                   value=i * flags.epochs_between_evals)
    mllogger.event(key=mllog.constants.EPOCH_COUNT,
                   value=flags.epochs_between_evals)

    for j in range(flags.epochs_between_evals):
      mllogger.event(key=mllog.constants.EPOCH_NUM,
                     value=i * flags.epochs_between_evals + j)

    # Avoid appending a duplicate hook name on every pass through the loop.
    if "examplespersecondhook" not in flags.hooks:
      flags.hooks += ["examplespersecondhook"]

    if is_mpi:
      train_hooks = [hvd.BroadcastGlobalVariablesHook(0)]
      train_hooks = train_hooks + hooks_helper.get_train_hooks(
          flags.hooks,
          batch_size=flags.batch_size * hvd.size(),
          benchmark_log_dir=flags.benchmark_log_dir)
    else:
      train_hooks = hooks_helper.get_train_hooks(
          flags.hooks,
          batch_size=flags.batch_size,
          benchmark_log_dir=flags.benchmark_log_dir)

    _log_cache = []

    def formatter(x):
      """Abuse side effects to get tensors out of the model_fn."""
      if _log_cache:
        _log_cache.pop()
      _log_cache.append(x.copy())
      return str(x)

    compliance_hook = tf.estimator.LoggingTensorHook(
        tensors={_NUM_EXAMPLES_NAME: _NUM_EXAMPLES_NAME},
        every_n_iter=int(1e10),
        at_end=True,
        formatter=formatter)

    print('Starting a training cycle.')

    def input_fn_train():
      return input_function(
          is_training=True,
          data_dir=flags.data_dir,
          batch_size=per_device_batch_size(flags.batch_size, flags.num_gpus),
          num_epochs=flags.epochs_between_evals,
          num_gpus=flags.num_gpus,
          dtype=flags.dtype)

    if is_mpi:
      # If max_train_steps is set, use it instead of steps_per_eval_per_worker,
      # assuming max_train_steps is smaller than steps_per_eval_per_worker.
      # Also assuming that when --steps is specified, train_epochs is set equal
      # to epochs_between_evals so that
      # range(flags.train_epochs // flags.epochs_between_evals) has length 1.
      if flags.max_train_steps and (
          flags.max_train_steps < steps_per_eval_per_worker):
        train_steps = flags.max_train_steps
      else:
        train_steps = steps_per_eval_per_worker
      classifier.train(input_fn=input_fn_train,
                       hooks=train_hooks + [compliance_hook],
                       steps=train_steps)
    else:
      classifier.train(input_fn=input_fn_train,
                       hooks=train_hooks + [compliance_hook],
                       max_steps=flags.max_train_steps)

    # train_examples = int(_log_cache.pop()[_NUM_EXAMPLES_NAME])
    # mlperf_log.resnet_print(key=mlperf_log.INPUT_SIZE, value=train_examples)
    mllogger.end(key=mllog.constants.BLOCK_STOP, value=i + 1)

    print('Starting to evaluate.')

    # Evaluate the model and print results
    def input_fn_eval():
      return input_function(
          is_training=False,
          data_dir=flags.data_dir,
          batch_size=per_device_batch_size(flags.batch_size, flags.num_gpus),
          num_epochs=1,
          dtype=flags.dtype)

    mllogger.start(key=mllog.constants.EVAL_START)
    # flags.max_train_steps is generally associated with testing and profiling.
    # As a result it is frequently called with synthetic data, which will
    # iterate forever. Passing steps=flags.max_train_steps allows the eval
    # (which is generally unimportant in those circumstances) to terminate.
    # Note that eval will run for max_train_steps each loop, regardless of the
    # global_step count.
    eval_results = classifier.evaluate(input_fn=input_fn_eval,
                                       steps=flags.max_train_steps)
    mllogger.event(key=mllog.constants.EVAL_SAMPLES,
                   value=int(eval_results[_NUM_EXAMPLES_NAME]))
    mllogger.event(key=mllog.constants.EVAL_ACCURACY,
                   value=float(eval_results['accuracy']))
    mllogger.end(key=mllog.constants.EVAL_STOP)
    print(eval_results)

    if benchmark_logger:
      benchmark_logger.log_estimator_evaluation_result(eval_results)

    if model_helpers.past_stop_threshold(
        flags.stop_threshold, eval_results['accuracy']):
      success = True
      break

  mllogger.event(key=mllog.constants.RUN_STOP, value={"success": success})
  mllogger.end(key=mllog.constants.RUN_STOP)
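# A minimal sketch of the per-device batch-size helper referenced above, under
# the assumption (not confirmed here) that it evenly splits the global batch
# size across GPUs; the actual helper is defined elsewhere in the repository.
def _per_device_batch_size_sketch(batch_size, num_gpus):
  """Per-device batch size, assuming an even split across GPUs."""
  if num_gpus <= 1:
    return batch_size
  if batch_size % num_gpus:
    raise ValueError(
        'batch_size (%d) must be divisible by num_gpus (%d) for multi-GPU '
        'runs.' % (batch_size, num_gpus))
  return batch_size // num_gpus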