def run_testing(opts, transformer):
    testing_graph = tf.Graph()
    with testing_graph.as_default():
        with tf.device("cpu"):
            logger.info("Creating test dataset")
            dataset, num_test, vocab = data_utils.make_dataset(
                opts, use_synthetic_data=opts.use_synthetic_data, training=False)

            batch_size = opts.batch_size
            if opts.pipeline:
                batch_size *= opts.gradient_accumulation_count
            batches_per_epoch = num_test // batch_size
            logger.info(f"Effective batch-size (global batch): {batch_size}")

            logger.info("Creating infeed and outfeed queues")
            test_infeed = IPUInfeedQueue(dataset, feed_name="test_infeed")
            test_outfeed = IPUOutfeedQueue(feed_name="test_outfeed")

        # Compile the forward pass for testing
        with scopes.ipu_scope("/device:IPU:0"):
            # Helper function
            def loop_builder(iterations, builder_func, infeed):
                return loops.repeat(iterations, builder_func, [], infeed)

            if opts.pipeline:
                logger.info("Creating pipelined test graph")
                test_loop = partial(forward_pass, opts, transformer,
                                    batches_per_epoch, False, test_outfeed,
                                    dense_queue=None, infeed=test_infeed)
            else:
                logger.info("Creating test graph")
                test_loop = partial(forward_pass, opts, transformer,
                                    batches_per_epoch, False, test_outfeed, None)
                test_loop = partial(loop_builder, batches_per_epoch,
                                    test_loop, test_infeed)
            test_loop = ipu_compiler.compile(test_loop, inputs=[])

        # Metrics
        with tf.device("cpu"):
            metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="metrics")
            metrics_initializer = tf.variables_initializer(var_list=metrics_vars)
            saver = tf.train.Saver()

        if opts.restore_epoch is None:
            checkpoint = tf.train.latest_checkpoint(opts.train_checkpoint_path)
        else:
            checkpoint = opts.train_checkpoint_path + "/model_" + str(
                opts.restore_epoch) + ".ckpt"

    with tf.Session(graph=testing_graph) as sess:
        # The sparsity will also be streamed from the checkpoint
        logger.info("Restoring weights")
        saver.restore(sess, checkpoint)
        sess.run(test_infeed.initializer)
        sess.run(metrics_initializer)

        # Run inference (whole dataset in one session call)
        logger.info("Testing...")
        dt = time.perf_counter()
        sess.run(test_loop)
        dt = time.perf_counter() - dt
        session_outputs = sess.run(test_outfeed.dequeue())[-1]

        # Test set performance
        # Log progress
        nll_loss = session_outputs['nll_loss'][-1]
        training_loss = session_outputs['training_loss'][-1]
        perplexity = session_outputs["perplexity"][-1]
        token_accuracy = session_outputs['token_accuracy'][-1]
        desc = (f"\nTraining loss : {training_loss:.4f}"
                f"\nXentropy loss : {nll_loss:.4f}"
                f"\nPerplexity : {perplexity:.3f}"
                f"\nToken accuracy: {token_accuracy:.2f}")
        logger.info(desc)

        if opts.decode and opts.log_level == 'INFO':
            text_pred, text_target = data_utils.decode_prediction(
                prediction=session_outputs['predictions'][-1],
                target=session_outputs['target'][-1],
                vocab=vocab)
            logger.info(f"Target: {text_target}\n"
                        f"Prediction: {text_pred}\n")

        os.sys.stdout.flush()
        logger.info("Test complete.")
    return desc
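
# ---------------------------------------------------------------------------
# Illustrative sketch (never called): the minimal infeed -> loops.repeat ->
# outfeed pattern that run_testing() and run_training() above are built on.
# A sketch only, assuming a TF1 Poplar SDK in which IPUInfeedQueue /
# IPUOutfeedQueue still take `feed_name`, the module-level imports used by the
# rest of this file, and an IPU system configured elsewhere. The dataset, body
# and feed names below are made up.
def _ipu_loop_pattern_sketch():
    dataset = tf.data.Dataset.from_tensor_slices(
        tf.random.uniform([32, 4])).repeat().batch(8, drop_remainder=True)
    infeed = IPUInfeedQueue(dataset, feed_name="sketch_infeed")
    outfeed = IPUOutfeedQueue(feed_name="sketch_outfeed")

    def body(x):
        # Per-iteration compute stays on the IPU; results go to the outfeed.
        return outfeed.enqueue({"mean": tf.reduce_mean(x)})

    def net():
        # Run `body` 4 times on-device, fed directly from the infeed.
        return loops.repeat(4, body, [], infeed)

    with scopes.ipu_scope("/device:IPU:0"):
        run_loop = ipu_compiler.compile(net, inputs=[])

    dequeued = outfeed.dequeue()
    with tf.Session() as sess:
        sess.run(infeed.initializer)
        sess.run(run_loop)          # executes all 4 iterations in one call
        return sess.run(dequeued)   # host sees one entry per iteration
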
def run_training(opts, transformer):
    # Construct the training graph
    training_graph = tf.Graph()
    with training_graph.as_default():
        with tf.device("cpu"):
            dataset, num_train, vocab = data_utils.make_dataset(
                opts, use_synthetic_data=opts.use_synthetic_data, training=True)

        # Calculate dataset length
        batch_size = opts.batch_size
        if opts.pipeline:
            batch_size *= opts.gradient_accumulation_count
        batches_per_epoch = num_train // batch_size
        io_steps_per_epoch = batches_per_epoch // opts.repeat_count
        total_io_steps = opts.nepochs * io_steps_per_epoch
        total_global_steps = opts.nepochs * io_steps_per_epoch * opts.repeat_count
        logger.info(f"Effective batch-size (global batch): {batch_size}, "
                    f"IO steps per epoch: {io_steps_per_epoch}, "
                    f"Total IO steps: {total_io_steps} "
                    f"Total global steps: {total_global_steps}")

        if opts.prune_ratio is not None and opts.prune_ratio > 0:
            # Compute the pruning ratio when the learning rate will reach a minimum
            lr_decay_steps = opts.cooldown_steps + opts.warmup_steps
            lr_min_epochs = lr_decay_steps / (io_steps_per_epoch * opts.repeat_count)
            remaining_prune_ratio = opts.prune_ratio * sparse_training.cosine_prune_function(
                lr_decay_steps, total_global_steps, opts.cosine_prune_schedule)
            logger.warning(
                f"\n\nThe learning rate schedule will reach a minimum after {lr_min_epochs:0.2f} epochs, "
                f"at which point the pruning ratio will be {remaining_prune_ratio:0.3f}\n\n")
            logger.info(f"Cosine prune schedule options: {opts.cosine_prune_schedule}")

        logger.info("Creating infeed and outfeed queues")
        # Queues for streaming from host to device and back
        train_infeed = IPUInfeedQueue(dataset, feed_name="train_infeed")
        train_outfeed = IPUOutfeedQueue(feed_name="train_outfeed")
        prune_and_grow_outfeed = IPUOutfeedQueue(feed_name="prune_and_grow_outfeed")

        # Helper function
        def loop_builder(iterations, builder_func, infeed):
            return loops.repeat(iterations, builder_func, [], infeed)

        # Compile the forward and backward pass for training
        with scopes.ipu_scope("/device:IPU:0"):
            if opts.pipeline:
                logger.info("Creating pipelined training graph")
                train_loop = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed, train_infeed)
            else:
                logger.info("Creating training graph")
                train_body = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed)
                train_loop = partial(loop_builder, opts.repeat_count,
                                     train_body, train_infeed)
            train_loop = ipu_compiler.compile(train_loop, inputs=[])

        transformer.buildSparsityUpdateOps()

        # Metrics
        with tf.device("cpu"):
            metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="metrics")
            metrics_initializer = tf.variables_initializer(var_list=metrics_vars)
            saver = tf.train.Saver()

            # These ops are declared here so that the graph can be frozen afterwards
            global_initializer = tf.global_variables_initializer()
            train_outfeed_dequeue = train_outfeed.dequeue()
            if opts.prune_ratio is not None and opts.prune_ratio > 0:
                prune_and_grow_dequeue = prune_and_grow_outfeed.dequeue()

        utils.move_variable_initialization_to_cpu()

        # Tensorboard
        log_name = "logs/" + datetime.now().isoformat()
        summary_writer = tf.summary.FileWriter(
            logdir=os.path.join(opts.train_checkpoint_path, log_name),
            flush_secs=5)

    # Run the model:
    training_graph.finalize()  # no more new ops added from here on out
    with tf.Session(graph=training_graph) as sess:
        logger.info("Initializing training session")
        sess.run(global_initializer)
        sess.run(train_infeed.initializer)
        logger.info("Training...")
        progress = tqdm(range(opts.nepochs))
        for e in progress:
            sess.run(metrics_initializer)

            for io_step in range(io_steps_per_epoch):
                # Train the model
                step_start_time = time.perf_counter()
                sess.run(train_loop)
                ipu_train_time = time.perf_counter() - step_start_time
                session_outputs = sess.run(train_outfeed_dequeue)[-1]
                logger.debug(f"Train outputs: {session_outputs.keys()}")

                # Calculate avg throughput
                num_tokens = transformer.source_sequence_length * opts.repeat_count * batch_size
                throughput = num_tokens / ipu_train_time

                # Log progress - average stats over the last accumulation step only:
                start_point = -1 if not opts.pipeline else -opts.gradient_accumulation_count
                lr = np.mean(session_outputs["learning_rate"][start_point:])
                training_loss = np.mean(session_outputs['training_loss'][start_point:])
                std_training_loss = np.std(session_outputs['training_loss'][start_point:])
                nll_loss = np.mean(session_outputs['nll_loss'][start_point:])
                perplexity = np.mean(session_outputs["perplexity"][start_point:])
                token_accuracy = np.mean(session_outputs['token_accuracy'][start_point:])
                global_step = session_outputs['global_step'][start_point:][-1]
                logger.info(
                    f"\nEpoch {e}: io_step {io_step+1}/{io_steps_per_epoch}"
                    f"\nGlobal step: {global_step}/{total_global_steps}"
                    f"\nTraining loss : {training_loss:.4f}"
                    f"\nTraining loss standard deviation: {std_training_loss:.4f}"
                    f"\nXentropy loss : {nll_loss:.4f}"
                    f"\nPerplexity : {perplexity:.3f}"
                    f"\nToken accuracy: {token_accuracy:.2f}"
                    f"\nLearning rate: {lr:3.4e}"
                    f"\nThroughput {throughput:.1f} token/s")

                if opts.decode and logger.level <= logging.INFO:
                    try:
                        text_pred, text_target = data_utils.decode_prediction(
                            prediction=session_outputs['predictions'][-1],
                            target=session_outputs['target'][-1],
                            vocab=vocab)
                        logger.info(f"\nTarget: {text_target}\n\nPrediction: {text_pred}\n")
                    except Exception as ex:
                        logger.warning(f"Decoding failed: {ex}")

                summary_value = [
                    tf.Summary.Value(tag="perplexity", simple_value=perplexity),
                    tf.Summary.Value(tag="training_loss", simple_value=training_loss),
                    tf.Summary.Value(tag="stddev_training_loss", simple_value=std_training_loss),
                    tf.Summary.Value(tag="xentropy_loss", simple_value=nll_loss),
                    tf.Summary.Value(tag="token_accuracy", simple_value=token_accuracy),
                    tf.Summary.Value(tag="learning_rate", simple_value=lr),
                    tf.Summary.Value(tag="throughput", simple_value=throughput),
                    tf.Summary.Value(tag="epoch", simple_value=e)
                ]

                # If we just completed the last io step we do not
                # prune and grow regardless, otherwise check the prune ratio:
                if io_step + 1 < io_steps_per_epoch and transformer.prune_ratio is not None and transformer.prune_ratio > 0:
                    # Retrieve prune-and-grow results from the conditional queue:
                    prune_and_grow_data = sess.run(prune_and_grow_dequeue)
                    for k in prune_and_grow_data:
                        prune_and_grow_data[k] = prune_and_grow_data[k][-1]
                    logger.debug(f"Prune and grow outputs: {prune_and_grow_data.keys()}")

                    prune_and_grow_time, cosine_schedule_factor = transformer.syncPruneAndRegrowOnHost(
                        opts.cosine_prune_schedule, global_step,
                        total_global_steps, prune_and_grow_data)
                    transformer.streamSparsityFromHostToDevice()

                    summary_value.extend([
                        tf.Summary.Value(tag="prune+grow_time",
                                         simple_value=prune_and_grow_time),
                        tf.Summary.Value(tag="cosine_schedule_factor",
                                         simple_value=cosine_schedule_factor)
                    ])

                    for layer_name, sparse_layer in transformer.sparse_layers.items():
                        values_var = sparse_layer.get_values_var()
                        grad_w_name = values_var.name.replace('nz_values:0', 'grad_w')
                        grad_w = np.array(prune_and_grow_data[grad_w_name])
                        if opts.log_histograms:
                            histogram = tf_utils.make_histogram_proto(
                                grad_w, bins_count=opts.bins_count)
                            summary_value.extend([
                                tf.Summary.Value(tag=layer_name + "/dense_grad_w",
                                                 histo=histogram)
                            ])
                        summary_value.extend([
                            tf.Summary.Value(tag=layer_name + "/dense_grad_w_stddev",
                                             simple_value=np.std(grad_w)),
                            tf.Summary.Value(tag=layer_name + "/dense_grad_w_mean",
                                             simple_value=np.mean(grad_w)),
                            tf.Summary.Value(tag=layer_name + "/dense_grad_w_min",
                                             simple_value=np.min(grad_w)),
                            tf.Summary.Value(tag=layer_name + "/dense_grad_w_max",
                                             simple_value=np.max(grad_w))
                        ])

                        for slot_name, slot in sparse_layer.get_slot_var_dict().items():
                            slot_val = prune_and_grow_data[slot.tf_variable.name]
                            if opts.log_histograms:
                                histogram = tf_utils.make_histogram_proto(
                                    slot_val, bins_count=opts.bins_count)
                                summary_value.extend([
                                    tf.Summary.Value(tag=slot_name, histo=histogram)
                                ])
                            summary_value.extend([
                                tf.Summary.Value(tag=slot_name + "/stddev",
                                                 simple_value=np.std(slot_val)),
                                tf.Summary.Value(tag=slot_name + "/mean",
                                                 simple_value=np.mean(slot_val)),
                                tf.Summary.Value(tag=slot_name + "/min",
                                                 simple_value=np.min(slot_val)),
                                tf.Summary.Value(tag=slot_name + "/max",
                                                 simple_value=np.max(slot_val))
                            ])

                # Log to tensorboard (outside any graph)
                summary = tf.Summary(value=summary_value)
                summary_writer.add_summary(summary, np.mean(global_step))
                if opts.use_wandb:
                    wandb.tensorflow.log(summary.SerializeToString())
                logger.info(f"Total time for step {time.perf_counter() - step_start_time}")
                logger.info(f"IPU train time for step {ipu_train_time}")

            logger.info(f"Saving model after epoch {e}")
            saver.save(sess, os.path.join(opts.train_checkpoint_path,
                                          'model_' + str(e) + '.ckpt'))
            os.sys.stdout.flush()
        logger.info("Training complete.")
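
# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not this repo's implementation): a typical
# cosine prune schedule of the kind sparse_training.cosine_prune_function is
# used for in run_training() above. The real function takes an options dict
# (opts.cosine_prune_schedule) whose knobs may differ from `zero_steps` here;
# this only shows the shape of the decay from 1 down to 0.
def _cosine_prune_schedule_sketch(step, total_steps, zero_steps=0):
    # Fraction of opts.prune_ratio applied at `step`.
    if step >= total_steps - zero_steps:
        return 0.0
    progress = step / max(1, total_steps - zero_steps)
    return 0.5 * (1.0 + np.cos(np.pi * progress))
# Example: with prune_ratio=0.3 and total_global_steps=1000, roughly
# 0.3 * 0.5 = 0.15 of each sparse layer's non-zeros would be pruned and
# regrown at step 500, falling to 0 as training ends.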