def validation_run(valid, filepath, i, epoch, first_run, opts):
    if filepath:
        valid.saver.restore(valid.session, filepath)

    # Gather accuracy statistics
    accuracy = 0.0
    start = time.time()
    for __ in range(opts["validation_iterations"]):
        try:
            a = valid.session.run(valid.ops)[0]
        except tf.errors.OpError as e:
            raise tf.errors.ResourceExhaustedError(e.node_def, e.op, e.message)
        accuracy += a
    val_time = time.time() - start
    accuracy /= opts["validation_iterations"]

    valid_format = (
        "Validation top-1 accuracy (iteration: {iteration:6d}, epoch: {epoch:6.2f}, img/sec: {img_per_sec:6.2f},"
        " time: {val_time:8.6f}): {val_acc:6.3f}%")

    stats = OrderedDict([
        ('iteration', i),
        ('epoch', epoch),
        ('val_acc', accuracy),
        ('val_time', val_time),
        ('img_per_sec', (opts["validation_iterations"] *
                         opts["validation_batches_per_step"] *
                         opts['validation_total_batch_size']) / val_time),
    ])

    logging.print_to_file_and_screen(valid_format.format(**stats), opts)
    logging.write_to_csv(stats, first_run, False, opts)
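
# Illustrative usage only: a minimal sketch of driving validation_run on a restored
# checkpoint, assuming initialise_validation (the helper train_process uses below)
# returns the `valid` handle (session/saver/ops) expected above. The checkpoint path
# and iteration/epoch values here are hypothetical.
def example_validate_checkpoint(model, opts):
    valid = initialise_validation(model, opts)
    # first_run=True makes logging.write_to_csv emit the CSV header on this call
    validation_run(valid, "/path/to/ckpt-100", i=100, epoch=1.0,
                   first_run=True, opts=opts)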
def train_process(model, LR_Class, opts):
    # --------------- OPTIONS ---------------------
    epochs = opts["epochs"]
    iterations_per_epoch = DATASET_CONSTANTS[
        opts['dataset']]['NUM_IMAGES'] // opts["total_batch_size"]
    if not opts['iterations']:
        iterations = epochs * iterations_per_epoch
        log_freq = iterations_per_epoch // opts['logs_per_epoch']
    else:
        iterations = opts['iterations']
        log_freq = opts['log_freq']

    if log_freq < opts['batches_per_step']:
        iterations_per_step = log_freq
    else:
        iterations_per_step = log_freq // int(
            round(log_freq / opts['batches_per_step']))

    iterations_per_valid = iterations_per_epoch
    iterations_per_ckpt = iterations_per_epoch // opts[
        'ckpts_per_epoch'] if opts['ckpts_per_epoch'] else np.inf

    LR = LR_Class(opts, iterations)

    batch_accs = deque(maxlen=iterations_per_epoch // iterations_per_step)
    batch_losses = deque(maxlen=iterations_per_epoch // iterations_per_step)
    batch_times = deque(maxlen=iterations_per_epoch // iterations_per_step)
    start_all = None

    # -------------- BUILD TRAINING GRAPH ----------------
    train = training_graph(
        model, opts, iterations_per_step * opts["gradients_to_accumulate"])
    train.session.run(train.init)
    train.session.run(train.iterator.initializer)

    # -------------- BUILD VALIDATION GRAPH ----------------
    if opts['validation']:
        valid = validation.initialise_validation(model, opts)

    # -------------- SAVE AND RESTORE --------------
    if opts['ckpts_per_epoch']:
        filepath = train.saver.save(train.session,
                                    opts["checkpoint_path"],
                                    global_step=0)
        print("Saved checkpoint to {}".format(filepath))

    if opts.get('restoring'):
        filename_pattern = re.compile(".*ckpt-[0-9]+$")
        ckpt_pattern = re.compile(".*ckpt-([0-9]+)$")
        filenames = sorted(
            [
                os.path.join(opts['logs_path'], f[:-len(".index")])
                for f in os.listdir(opts['logs_path'])
                if filename_pattern.match(f[:-len(".index")])
                and f[-len(".index"):] == ".index"
            ],
            key=lambda x: int(ckpt_pattern.match(x).groups()[0]))
        latest_checkpoint = filenames[-1]
        logging.print_to_file_and_screen(
            "Restoring training from latest checkpoint: {}".format(
                latest_checkpoint), opts)
        i = int(ckpt_pattern.match(latest_checkpoint).groups()[0]) + 1
        train.saver.restore(train.session, latest_checkpoint)
        epoch = float(opts["total_batch_size"] *
                      (i + iterations_per_step)) / DATASET_CONSTANTS[
                          opts['dataset']]['NUM_IMAGES']
    else:
        i = 0

    # ------------- TRAINING LOOP ----------------
    print_format = (
        "step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.2f}, lr: {lr:6.4g}, loss: {loss_avg:6.3f}, accuracy: {train_acc_avg:6.3f}%"
        ", img/sec: {img_per_sec:6.2f}, time: {it_time:8.6f}, total_time: {total_time:8.1f}"
    )

    step = 0
    start_all = time.time()

    while i < iterations:
        step += opts["gradients_to_accumulate"]
        log_this_step = ((i // log_freq) <
                         ((i + iterations_per_step) // log_freq)
                         or (i == 0)
                         or ((i + (2 * iterations_per_step)) >= iterations))
        ckpt_this_step = ((i // iterations_per_ckpt) <
                          ((i + iterations_per_step) // iterations_per_ckpt)
                          or (i == 0)
                          or ((i + (2 * iterations_per_step)) >= iterations))
        valid_this_step = (opts['validation'] and
                           ((i // iterations_per_valid) <
                            ((i + iterations_per_step) // iterations_per_valid)
                            or (i == 0)
                            or ((i + (2 * iterations_per_step)) >= iterations)))

        # Run Training
        try:
            batch_loss, batch_acc, batch_time, current_lr, scaled_lr = training_step(
                train, i + 1, LR.feed_dict_lr(i))
        except tf.errors.OpError as e:
            raise tf.errors.ResourceExhaustedError(e.node_def, e.op, e.message)
        batch_time /= iterations_per_step

        # Calculate Stats
        batch_accs.append([batch_acc])
        batch_losses.append([batch_loss])

        if i != 0:
            batch_times.append([batch_time])

        # Print loss
        if log_this_step:
            train_acc = np.mean(batch_accs)
            train_loss = np.mean(batch_losses)

            if len(batch_times) != 0:
                avg_batch_time = np.mean(batch_times)
            else:
                avg_batch_time = batch_time

            # flush times every time it is reported
            batch_times.clear()

            total_time = time.time() - start_all
            epoch = float(opts["total_batch_size"] *
                          (i + iterations_per_step)) / DATASET_CONSTANTS[
                              opts['dataset']]['NUM_IMAGES']

            stats = OrderedDict([
                ('step', step),
                ('iteration', i + iterations_per_step),
                ('epoch', epoch),
                ('lr', current_lr),
                ('scaled_lr', scaled_lr),
                ('loss_batch', batch_loss),
                ('loss_avg', train_loss),
                ('train_acc_batch', batch_acc),
                ('train_acc_avg', train_acc),
                ('it_time', avg_batch_time),
                ('img_per_sec', opts['total_batch_size'] / avg_batch_time),
                ('total_time', total_time),
            ])

            logging.print_to_file_and_screen(print_format.format(**stats), opts)
            logging.write_to_csv(stats, i == 0, True, opts)

        if ckpt_this_step:
            filepath = train.saver.save(train.session,
                                        opts["checkpoint_path"],
                                        global_step=i + iterations_per_step)
            print("Saved checkpoint to {}".format(filepath))

        # Eval
        if valid_this_step and opts['validation']:
            if 'validation_points' not in locals():
                validation_points = []
            validation_points.append(
                (i + iterations_per_step, epoch, i == 0, filepath))

        i += iterations_per_step

    # ------------ RUN VALIDATION ------------
    if opts['validation']:
        for iteration, epoch, first_run, filepath in validation_points:
            validation.validation_run(valid, filepath, iteration, epoch,
                                      first_run, opts)

    # --------------- CLEANUP ----------------
    train.session.close()
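
# Illustrative only: the interface train_process expects from LR_Class, inferred from
# the calls above (constructed as LR_Class(opts, iterations) and queried via
# LR.feed_dict_lr(i), whose result is passed straight into training_step). The
# 'learning_rate' key, the 'base_learning_rate' option and the stepped-decay rule
# below are assumptions, not the repository's actual schedules.
class ExampleStepLR:
    def __init__(self, opts, iterations):
        self.initial_lr = opts.get('base_learning_rate', 0.1)
        self.iterations = iterations

    def feed_dict_lr(self, iteration):
        # Assumed policy: scale the learning rate down at 50% and 75% of training.
        progress = iteration / max(self.iterations, 1)
        scale = 0.01 if progress >= 0.75 else (0.1 if progress >= 0.5 else 1.0)
        return {'learning_rate': self.initial_lr * scale}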
def inference_run(exec_filename, ckpt_name, iteration, epoch, first_run, opts):
    """Run inference for multiple iterations and collect latency values."""
    logging.mlperf_logging(key="EVAL_START",
                           log_type="start",
                           metadata={"epoch_num": round(epoch)})

    engine_name = "my_engine"
    ctx = embedded_runtime.embedded_runtime_start(exec_filename, [],
                                                  engine_name,
                                                  timeout=1000)

    input_placeholder = tf.placeholder(
        tf.uint8,
        (opts['micro_batch_size'], opts['image_size'], opts['image_size'], 3))

    num_iters = opts['iterations']
    if opts['generated_data']:
        placeholders = [input_placeholder]
        images = np.random.normal(size=(opts['micro_batch_size'],
                                        opts['image_size'],
                                        opts['image_size'],
                                        3)).astype(np.uint8)
        labels = None
    else:
        label_placeholder = tf.placeholder(tf.int32, (opts['micro_batch_size']))
        placeholders = [input_placeholder, label_placeholder]
        with tf.Graph().as_default():
            inference_dataset = dataset.data(
                opts, is_training=False).map(lambda x: {'data_dict': x})
            images, labels = dataset_to_list(
                inference_dataset, num_iters * opts['micro_batch_size'])

    call_result = embedded_runtime.embedded_runtime_call(placeholders, ctx)

    ipu.config.reset_ipu_configuration()
    gc.collect()

    thread_queue = Queue()
    with tf.Session() as session:
        # do not include time of the first iteration in stats
        initial_feed_dict = prepare_feed_dict(placeholders, images, labels,
                                              opts['micro_batch_size'],
                                              opts['generated_data'], 0)
        session.run(call_result, initial_feed_dict)

        def runner(session, thread_idx):
            thread_channel = pvti.createTraceChannel(f"Thread {thread_idx}")
            latencies = []
            accuracies = []
            for iter_idx in range(num_iters):
                feed_dict = prepare_feed_dict(placeholders, images, labels,
                                              opts['micro_batch_size'],
                                              opts['generated_data'], iter_idx)
                with pvti.Tracepoint(thread_channel, f"Iteration {iter_idx}"):
                    start_iter = time.time()
                    _, predictions = session.run(call_result, feed_dict)
                    end_iter = time.time()
                latencies.append(end_iter - start_iter)
                if not opts['generated_data']:
                    expected = feed_dict[label_placeholder]
                    accuracy = np.mean(
                        np.equal(predictions, expected).astype(np.float32))
                    accuracies.append(accuracy)
            thread_queue.put((latencies, accuracies), timeout=10)

        thp = [
            Thread(target=runner, args=(session, thread_idx))
            for thread_idx in range(opts['num_inference_thread'])
        ]
        inference_start = time.time()
        for idx, _thread in enumerate(thp):
            _thread.start()
            print(f"Thread {idx} started")

        for idx, _thread in enumerate(thp):
            _thread.join()
            print(f"Thread {idx} joined")
        val_time = time.time() - inference_start

    latencies, accuracies = [], []
    while not thread_queue.empty():
        lat_acc = thread_queue.get()
        latencies.extend(lat_acc[0])
        accuracies.extend(lat_acc[1])

    if opts['generated_data']:
        total_accuracy = -1
    else:
        total_accuracy = sum(accuracies) / len(accuracies)
        total_accuracy *= 100

    # convert latencies to milliseconds
    latencies = [1000 * latency_s for latency_s in latencies]
    max_latency = max(latencies)
    mean_latency = np.mean(latencies)
    perc_99 = np.percentile(latencies, 99)
    perc_99_9 = np.percentile(latencies, 99.9)

    print(
        f"Latencies - avg: {mean_latency:8.4f}, 99th percentile: {perc_99:8.4f}, "
        f"99.9th percentile: {perc_99_9:8.4f}, max: {max_latency:8.4f}")

    valid_format = (
        "Validation top-1 accuracy [{name}] (iteration: {iteration:6d}, epoch: {epoch:6.2f}, "
        "img/sec: {img_per_sec:6.2f}, time: {val_time:8.6f}, "
        "latency (ms): {latency:8.4f}: {val_acc:6.3f}%")

    val_size = (num_iters * opts['num_inference_thread'] *
                opts['validation_total_batch_size'])

    stats = OrderedDict([
        ('name', ckpt_name),
        ('iteration', iteration),
        ('epoch', epoch),
        ('val_acc', total_accuracy),
        ('val_time', val_time),
        ('val_size', val_size),
        ('img_per_sec', val_size / val_time),
        ('latency', mean_latency),
    ])

    logging.print_to_file_and_screen(valid_format.format(**stats), opts)
    logging.write_to_csv(stats, first_run, False, opts)
    if opts['wandb'] and opts['distributed_worker_index'] == 0:
        logging.log_to_wandb(stats)

    logging.mlperf_logging(key="EVAL_STOP",
                           log_type="stop",
                           metadata={"epoch_num": round(epoch)})
    logging.mlperf_logging(key="EVAL_ACCURACY",
                           value=float(stats['val_acc']) / 100,
                           metadata={"epoch_num": round(epoch)})
    return stats
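
# Illustrative only: inference_run relies on a prepare_feed_dict helper that is not
# shown in this file. A minimal sketch consistent with how it is called above
# (placeholders, images, labels, micro_batch_size, generated_data flag, iteration
# index) could slice one micro-batch per iteration; the real helper may differ.
def prepare_feed_dict_example(placeholders, images, labels, micro_batch_size,
                              generated_data, iter_idx):
    if generated_data:
        # Generated data: the same random micro-batch is reused for every iteration.
        return {placeholders[0]: images}
    # Real data: take the iter_idx-th micro-batch out of the pre-extracted arrays.
    start = iter_idx * micro_batch_size
    end = start + micro_batch_size
    return {placeholders[0]: images[start:end],
            placeholders[1]: labels[start:end]}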
def evaluate(opts):
    epochs = opts["epochs"]
    total_samples = dataset.get_dataset_files_count(opts, is_training=True)
    logger.info("[evaluation] Total samples with duplications {}".format(
        total_samples))
    total_independent_samples = total_samples // opts['duplication_factor']
    logger.info("[evaluation] Total samples without duplications {}".format(
        total_independent_samples))
    steps_per_epoch = total_independent_samples // (opts['batches_per_step'] *
                                                    opts["total_batch_size"])
    iterations_per_epoch = total_independent_samples // (
        opts["total_batch_size"])

    # total iterations
    if opts['steps']:
        logger.warn(
            "[evaluation] Ignoring the epochs flag and using the steps flag")
        steps = opts['steps']
    else:
        steps = epochs * steps_per_epoch
    logger.info(
        "[evaluation] Total evaluation steps equal to {}, total number of samples being analyzed equal to {}"
        .format(steps,
                steps * opts['batches_per_step'] * opts['total_batch_size']))

    iterations_per_step = opts['batches_per_step']
    ckpt_per_step = opts['steps_per_ckpts']

    logger.info(
        "################################################################################"
    )
    logger.info("Start evaluation......")
    print_format = (
        "[evaluation] step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.3f}, lr: {lr:10.3g}, mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
        "samples/sec: {samples_per_sec:6.2f}, time: {iter_time:8.6f}, total_time: {total_time:8.1f}, mlm_acc: {mlm_acc:8.5f}, nsp_acc: {nsp_acc:8.5f}"
    )

    # avoid a NaN issue caused by the queue length being zero
    queue_len = iterations_per_epoch // iterations_per_step
    if queue_len == 0:
        queue_len = 1
    batch_times = deque(maxlen=queue_len)

    # best_saver = train.saver["best_saver"]

    # `bert_config` is expected to be available at module scope here.
    evals = build_graph(bert_config,
                        opts,
                        iterations_per_step,
                        is_training=False,
                        feed_name="trainfeed")
    evals.session.run(evals.init)
    evals.session.run(evals.iterator.initializer)

    evals_saver = evals.saver["train_saver"]
    # NOTE: the checkpoint path below is currently hardcoded.
    evals_saver.restore(
        evals.session,
        "/localdata/yongxiy/Desktop/examples-ipu/applications/tensorflow/bert/checkpoint/phase1/BERT_pretraining_2021-03-15 08:49:29.404/"
        + f'ckpt_last-{100}')

    step = 0
    i = 0
    start_all = time.time()
    while step < steps:
        try:
            batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = eval_step(
                evals, 1.0)
        except tf.errors.OpError as e:
            raise tf.errors.ResourceExhaustedError(e.node_def, e.op, e.message)

        epoch = float(opts["total_batch_size"] * i) / total_independent_samples

        batch_time /= iterations_per_step

        if step != 0:
            batch_times.append([batch_time])

        if step == 1:
            poplar_compile_time = time.time() - start_all
            logger.info(
                f"[evaluation] Poplar compile time: {poplar_compile_time}")

        # Print loss
        if step % opts['steps_per_logs'] == 0:
            if len(batch_times) != 0:
                avg_batch_time = np.mean(batch_times)
            else:
                avg_batch_time = batch_time

            samples_per_sec = opts['total_batch_size'] / avg_batch_time

            # flush times every time it is reported
            batch_times.clear()

            total_time = time.time() - start_all

            stats = OrderedDict([
                ('step', step),
                ('iteration', i),
                ('epoch', epoch),
                ('lr', 1.0),
                ('mlm_loss', mlm_loss),
                ('nsp_loss', nsp_loss),
                ('mlm_acc', mlm_acc),
                ('nsp_acc', nsp_acc),
                ('iter_time', avg_batch_time),
                ('samples_per_sec', samples_per_sec),
                ('total_time', total_time),
            ])

            logger.info(print_format.format(**stats))
            bert_logging.write_to_csv(stats, i == 0, True, opts['logs_path'])

            logger.info(
                f"[evaluation] throughput samples per second: {samples_per_sec}"
            )
            logger.info(f"[evaluation] average batch time: {avg_batch_time}")

            # sys_summary = tf.Summary()
            # sys_summary.value.add(tag='perf/throughput_samples_per_second',
            #                       simple_value=samples_per_sec)
            # sys_summary.value.add(tag='perf/average_batch_time',
            #                       simple_value=avg_batch_time)
            # summary_writer.add_summary(sys_summary, step)

        i += iterations_per_step
        step += 1
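
# Illustrative only: evaluate() calls an eval_step helper that is not shown here.
# A minimal sketch consistent with its usage above (returning batch_time plus the
# MLM/NSP losses and accuracies) might time a single session.run over the infeed.
# The op names in evals.ops and the learning-rate placeholder are assumptions.
def eval_step_example(evals, learning_rate):
    start = time.time()
    mlm_loss, nsp_loss, mlm_acc, nsp_acc = evals.session.run(
        [evals.ops['mlm_loss'], evals.ops['nsp_loss'],
         evals.ops['mlm_acc'], evals.ops['nsp_acc']],
        feed_dict={evals.learning_rate: learning_rate})
    batch_time = time.time() - start
    return batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc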
def train(bert_config, opts):
    # --------------- OPTIONS ---------------------
    epochs = opts["epochs"]
    total_samples = dataset.get_dataset_files_count(opts, is_training=True)
    logger.info("Total samples with duplications {}".format(total_samples))
    total_independent_samples = total_samples // opts['duplication_factor']
    logger.info("Total samples without duplications {}".format(
        total_independent_samples))
    steps_per_epoch = total_independent_samples // (opts['batches_per_step'] *
                                                    opts["total_batch_size"])
    iterations_per_epoch = total_independent_samples // (
        opts["total_batch_size"])

    # total iterations
    if opts['steps']:
        logger.warn("Ignoring the epochs flag and using the steps flag")
        steps = opts['steps']
    else:
        steps = epochs * steps_per_epoch
    logger.info(
        "Total training steps equal to {}, total number of samples being analyzed equal to {}"
        .format(steps,
                steps * opts['batches_per_step'] * opts['total_batch_size']))

    iterations_per_step = opts['batches_per_step']
    ckpt_per_step = opts['steps_per_ckpts']

    # avoid a NaN issue caused by the queue length being zero
    queue_len = iterations_per_epoch // iterations_per_step
    if queue_len == 0:
        queue_len = 1
    batch_times = deque(maxlen=queue_len)

    # learning rate strategy
    lr_schedule_name = opts['lr_schedule']
    logger.info(f"Using learning rate schedule {lr_schedule_name}")
    LR = make_lr_schedule(lr_schedule_name, opts, steps)

    if opts['do_train']:
        # -------------- BUILD TRAINING GRAPH ----------------
        train = build_graph(bert_config,
                            opts,
                            iterations_per_step,
                            is_training=True,
                            feed_name="trainfeed")
        train.session.run(train.init)
        train.session.run(train.iterator.initializer)

        step = 0
        i = 0

        if opts['restore_path'] is not None:
            if os.path.isdir(opts['restore_path']):
                ckpt_file_path = tf.train.latest_checkpoint(
                    opts['restore_path'])
                logger.info("Restoring training from latest checkpoint")
            else:
                # Assume restore_path points directly at a checkpoint file
                ckpt_file_path = opts['restore_path']
            logger.info(
                f"Restoring training from checkpoint: {ckpt_file_path}")
            train.restore.restore(train.session, ckpt_file_path)
            ckpt_pattern = re.compile(".*ckpt-([0-9]+)$")
            i = int(ckpt_pattern.match(ckpt_file_path).groups()[0])
            step = int(i // iterations_per_step)

        if opts['start_from_ckpt']:
            # We use a checkpoint to initialise our model
            train.restore.restore(train.session, opts['start_from_ckpt'])
            logger.info("Starting the training from the checkpoint {}".format(
                opts['start_from_ckpt']))

        # Initialise Weights & Biases if available
        if opts['wandb']:
            import wandb
            wandb.init(project="tf-bert", sync_tensorboard=True)
            wandb.config.update(opts)

        # Tensorboard logs path
        log_path = os.path.join(opts["logs_path"], 'event')
        logger.info("Tensorboard event file path {}".format(log_path))
        summary_writer = tf.summary.FileWriter(log_path,
                                               train.graph,
                                               session=train.session)

        # ------------- TRAINING LOOP ----------------
        logger.info(
            "################################################################################"
        )
        logger.info("Start training......")
        print_format = (
            "step: {step:6d}, iteration: {iteration:6d}, epoch: {epoch:6.3f}, lr: {lr:10.3g}, mlm_loss: {mlm_loss:6.3f}, nsp_loss: {nsp_loss:6.3f}, "
            "samples/sec: {samples_per_sec:6.2f}, time: {iter_time:8.6f}, total_time: {total_time:8.1f}, mlm_acc: {mlm_acc:8.5f}, nsp_acc: {nsp_acc:8.5f}"
        )

        start_all = time.time()

        train_saver = train.saver["train_saver"]
        best_saver = train.saver["best_saver"]
        # We initialise the best loss to a very large value
        best_total_loss = 1e10
        best_step = 0

        while step < steps:
            # Run Training
            learning_rate = LR.feed_dict_lr(step)
            try:
                batch_time, mlm_loss, nsp_loss, mlm_acc, nsp_acc = training_step(
                    train, learning_rate)
            except tf.errors.OpError as e:
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)

            epoch = float(
                opts["total_batch_size"] * i) / total_independent_samples

            batch_time /= iterations_per_step

            if step != 0:
                batch_times.append([batch_time])

            if step == 1:
                poplar_compile_time = time.time() - start_all
                poplar_summary = tf.Summary()
                poplar_summary.value.add(tag='poplar/compile_time',
                                         simple_value=poplar_compile_time)
                summary_writer.add_summary(poplar_summary)

            # Print loss
            if step % opts['steps_per_logs'] == 0:
                if len(batch_times) != 0:
                    avg_batch_time = np.mean(batch_times)
                else:
                    avg_batch_time = batch_time

                samples_per_sec = opts['total_batch_size'] / avg_batch_time

                # flush times every time it is reported
                batch_times.clear()

                total_time = time.time() - start_all

                stats = OrderedDict([
                    ('step', step),
                    ('iteration', i),
                    ('epoch', epoch),
                    ('lr', learning_rate),
                    ('mlm_loss', mlm_loss),
                    ('nsp_loss', nsp_loss),
                    ('mlm_acc', mlm_acc),
                    ('nsp_acc', nsp_acc),
                    ('iter_time', avg_batch_time),
                    ('samples_per_sec', samples_per_sec),
                    ('total_time', total_time),
                ])

                logger.info(print_format.format(**stats))
                bert_logging.write_to_csv(stats, i == 0, True,
                                          opts['logs_path'])

                sys_summary = tf.Summary()
                sys_summary.value.add(tag='perf/throughput_samples_per_second',
                                      simple_value=samples_per_sec)
                sys_summary.value.add(tag='perf/average_batch_time',
                                      simple_value=avg_batch_time)
                summary_writer.add_summary(sys_summary, step)

                # Log training statistics
                train_summary = tf.Summary()
                train_summary.value.add(tag='epoch', simple_value=epoch)
                train_summary.value.add(tag='loss/MLM', simple_value=mlm_loss)
                train_summary.value.add(tag='loss/NSP', simple_value=nsp_loss)
                train_summary.value.add(tag='accuracy/MLM',
                                        simple_value=mlm_acc)
                train_summary.value.add(tag='accuracy/NSP',
                                        simple_value=nsp_acc)
                train_summary.value.add(tag='defaultLearningRate',
                                        simple_value=learning_rate)
                train_summary.value.add(tag='samples',
                                        simple_value=step *
                                        opts['batches_per_step'] *
                                        opts['total_batch_size'])
                summary_writer.add_summary(train_summary, step)
                summary_writer.flush()

            if step % ckpt_per_step == 0 and step:
                filepath = train_saver.save(train.session,
                                            save_path=opts["checkpoint_path"],
                                            global_step=step)
                logger.info("Saved checkpoint to {}".format(filepath))
                if not opts['wandb']:
                    bert_logging.save_model_statistics(filepath,
                                                       summary_writer, step)

            # Mechanism to checkpoint the best model.
            # Set opts["best_ckpt_min_steps"] to 0 to disable.
            if best_total_loss > mlm_loss + nsp_loss and step - best_step > opts[
                    "best_ckpt_min_steps"] and opts["best_ckpt_min_steps"]:
                best_total_loss = mlm_loss + nsp_loss
                best_step = step
                filepath = best_saver.save(train.session,
                                           save_path=opts["checkpoint_path"] +
                                           '_best',
                                           global_step=step)
                logger.info("Saved best checkpoint to {}".format(filepath))

            i += iterations_per_step
            step += 1

        # --------------- LAST CHECKPOINT ----------------
        filepath = train_saver.save(train.session,
                                    save_path=opts["checkpoint_path"] +
                                    '_last',
                                    global_step=step)
        logger.info("Final model saved to {}".format(filepath))

        # --------------- CLEANUP ----------------
        train.session.close()
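
# Illustrative only: bert_logging.write_to_csv is not defined in this file. A minimal
# sketch consistent with how it is called above (stats dict, write-header flag,
# is_training flag, logs path) could append one row per call; the file names used
# below are assumptions, not the repository's actual behaviour.
def write_to_csv_example(stats, write_header, is_training, logs_path):
    import csv  # stdlib; imported locally in case the module does not already import it
    filename = 'training.csv' if is_training else 'validation.csv'
    with open(os.path.join(logs_path, filename), 'a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(stats.keys()))
        if write_header:
            writer.writeheader()
        writer.writerow(stats)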
def validation_run(valid, filepath, i, epoch, first_run, opts, latency_thread):
    run = True
    if filepath:
        valid.saver.restore(valid.session, filepath)
        name = filepath.split('/')[-1]

        csv_path = os.path.join(opts['logs_path'], 'validation.csv')
        if os.path.exists(csv_path):
            with open(csv_path, 'rU') as infile:
                # read the file as a dictionary for each row ({header: value})
                reader = csv.DictReader(infile)
                for row in reader:
                    if row['name'] == name:
                        run = False
                        print('Skipping validation run on checkpoint: {}'.format(
                            name))
                        break
    else:
        name = None

    if run:
        if opts['use_popdist']:
            # synchronise the model weights across all instances
            valid.session.run(valid.ops['broadcast_weights'])

        logging.mlperf_logging(key="EVAL_START",
                               log_type="start",
                               metadata={"epoch_num": round(epoch)})

        # Gather accuracy statistics
        accuracy = 0.0

        # start latency thread
        latency_thread.start()

        start = relative_timer.now()
        for __ in range(opts["validation_iterations"]):
            try:
                a = valid.session.run(valid.ops['accuracy'])
            except tf.errors.OpError as e:
                if opts['compile_only'] and 'compilation only' in e.message:
                    print("Validation graph successfully compiled")
                    print("Exiting...")
                    sys.exit(0)
                raise tf.errors.ResourceExhaustedError(e.node_def, e.op,
                                                       e.message)
            accuracy += a
        val_time = relative_timer.now() - start
        accuracy /= opts["validation_iterations"]

        # wait for all dequeues and latency computation
        latency_thread.join()
        latency = latency_thread.get_latency()

        valid_format = (
            "Validation top-1 accuracy [{name}] (iteration: {iteration:6d}, epoch: {epoch:6.2f}, img/sec: {img_per_sec:6.2f},"
            " time: {val_time:8.6f}, latency (ms): {latency:8.4f}): {val_acc:6.3f}%"
        )

        val_size = (opts["validation_iterations"] *
                    opts["validation_batches_per_step"] *
                    opts["validation_global_batch_size"])

        count = int(
            DATASET_CONSTANTS[opts['dataset']]['NUM_VALIDATION_IMAGES'])
        raw_accuracy = accuracy
        if count < val_size:
            accuracy = accuracy * val_size / count

        stats = OrderedDict([
            ('name', name),
            ('iteration', i),
            ('epoch', epoch),
            ('val_acc', accuracy),
            ('raw_acc', raw_accuracy),
            ('val_time', val_time),
            ('val_size', val_size),
            ('img_per_sec', val_size / val_time),
            ('latency', latency * 1000),
        ])

        logging.print_to_file_and_screen(valid_format.format(**stats), opts)
        logging.write_to_csv(stats, first_run, False, opts)
        if opts["wandb"] and opts["distributed_worker_index"] == 0:
            logging.log_to_wandb(stats)

        logging.mlperf_logging(key="EVAL_STOP",
                               log_type="stop",
                               metadata={"epoch_num": round(epoch)})
        logging.mlperf_logging(key="EVAL_ACCURACY",
                               value=float(stats["val_acc"]) / 100,
                               metadata={"epoch_num": round(epoch)})
        return stats
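
# Illustrative only: validation_run above expects a latency_thread object exposing
# start(), join() and get_latency(). A minimal sketch, assuming the thread pulls
# per-batch timings (in seconds) from a callable `dequeue_fn` and reports their mean;
# the repository's real latency thread may compute latency differently.
class ExampleLatencyThread(Thread):
    def __init__(self, dequeue_fn, num_batches):
        super().__init__()
        self._dequeue_fn = dequeue_fn
        self._num_batches = num_batches
        self._latency = 0.0

    def run(self):
        samples = []
        for _ in range(self._num_batches):
            # each dequeue is assumed to return the end-to-end time of one batch
            samples.append(self._dequeue_fn())
        if samples:
            self._latency = float(np.mean(samples))

    def get_latency(self):
        # mean per-batch latency in seconds (validation_run converts to ms)
        return self._latency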