def prune_and_grow(name, fc, prune_and_grow_outputs, random_gen, step, total_steps, opts, metainfo):
    """Run one host-side prune-and-grow update for a single sparse layer.

    The layer's host-side representation is first synchronised with the values
    in ``prune_and_grow_outputs``; then ``sparse_training.prune_and_grow`` is
    called and, if it returns a result, the layer's triplets and slots are
    updated from it.

    Args:
        name: Layer name; the weight mask is saved only when it equals 'fc1'
            and ``opts.records_path`` is set.
        fc: The sparse layer object to update.
        prune_and_grow_outputs: Mapping from variable/slot names to the values
            fetched for this update (indexed by the layer's values-variable
            name, its slot names, and the derived ``grad_w`` name).
        random_gen: Random generator forwarded to the prune-and-grow routine.
            If None, a generator seeded with ``opts.seed`` is created.
        step: Current training step.
        total_steps: Total number of training steps.
        opts: Options object; reads ``cosine_prune_schedule``, ``prune_ratio``,
            ``regrow``, ``seed`` and ``records_path``.
        metainfo: Meta-information passed to the layer's sync call.

    Returns:
        ``opts.prune_ratio`` scaled by the cosine schedule value at ``step``.

    Raises:
        Re-raises any exception raised while updating the layer from the
        prune-and-grow results, after logging diagnostic details.
    """
    def cosine_prune_schedule(t, T, max_pruned):
        # Parameter names matter: t and T are bound by keyword via partial().
        s = sparse_training.cosine_prune_function(t, T, opts.cosine_prune_schedule)
        logger.info(f"t/T: {t}/{T} max:{max_pruned} sched: {s}")
        return int(np.ceil(max_pruned * s))

    # Sync the layer's internal host-side state with prune_and_grow_outputs results
    # (both weights and slots need to be kept in sync):
    values_var_name = fc.get_values_var().name
    fc.sync_internal_representation(
        {"nz": prune_and_grow_outputs[values_var_name]},
        {
            slot_name: prune_and_grow_outputs[slot_name]
            for slot_name in fc.sparse_slots
        },
        {"metainfo": metainfo})

    # Use the caller's generator so successive calls draw fresh random numbers;
    # re-seeding on every call would replay the same random stream each time.
    if random_gen is None:
        random_gen = np.random.default_rng(seed=opts.seed)

    grad_w_name = values_var_name.replace('nz_values:0', 'grad_w')
    grow_results = sparse_training.prune_and_grow(
        name=fc.name,
        triplets=fc.get_triplets(),
        shape=fc.get_shape(),
        spec=fc.weights.spec,
        max_non_zeros=fc.get_max_non_zeros(),
        slot_triplets=fc.extract_slot_triplets(),
        prune_schedule=partial(cosine_prune_schedule, t=step, T=total_steps),
        prune_ratio=opts.prune_ratio,
        grad_w=np.array(prune_and_grow_outputs[grad_w_name]),
        grow_method=opts.regrow,
        random_gen=random_gen,
        ipu_pooling_type=fc.pooling_type)

    if grow_results is not None:
        try:
            fc.update_triplets(grow_results['gt'])
            fc.update_slots_from_triplets(grow_results['gs'])
        except Exception:
            # Log what we tried to apply, then let the error propagate.
            logger.error(
                f"Failed to update representation with triplets:\n{grow_results['gt'][0]}\n{grow_results['gt'][1]}\n{grow_results['gt'][2]}"
            )
            logger.error(f"Non-zeros: {len(grow_results['gt'][0])}")
            logger.error(f"Layer spec: {fc.weights.spec}")
            raise

    if opts.records_path and name == 'fc1':
        # Save the first hidden layer's weight mask for later analysis:
        save_weights(opts, name, fc, step)

    return opts.prune_ratio * sparse_training.cosine_prune_function(
        step, total_steps, opts.cosine_prune_schedule)
def cosine_prune_schedule(t, T, max_pruned):
    """Scale ``max_pruned`` by the cosine schedule at step ``t`` of ``T``.

    Reads ``opts.cosine_prune_schedule`` from the enclosing scope, logs the
    schedule value, and returns the scaled count rounded up to an int.
    """
    fraction = sparse_training.cosine_prune_function(
        t, T, opts.cosine_prune_schedule)
    logger.info(f"t/T: {t}/{T} max:{max_pruned} sched: {fraction}")
    prune_count = np.ceil(max_pruned * fraction)
    return int(prune_count)
def syncPruneAndRegrowOnHost(self, cosine_options, step, total_steps, session_outputs):
    """Prune and regrow every sparse layer on the host for the given step.

    Args:
        cosine_options: Options forwarded to
            ``sparse_training.cosine_prune_function``.
        step: Current global training step.
        total_steps: Total number of global training steps.
        session_outputs: Mapping from variable names to fetched values; it is
            indexed by each layer's values-variable name, its slot variable
            names and ``<layer_name>/sparse_layer/grad_w``.

    Returns:
        A tuple ``(elapsed_seconds, schedule_factor)``. ``elapsed_seconds`` is
        None when pruning is skipped (final step, or the schedule evaluates to
        zero). ``schedule_factor`` is ``self.prune_ratio`` multiplied by the
        cosine schedule value at ``step``.
    """
    # The schedule is queried with identical arguments on every path, so
    # evaluate it once (assumes cosine_prune_function is deterministic).
    schedule_value = sparse_training.cosine_prune_function(
        step, total_steps, cosine_options)
    schedule_factor = self.prune_ratio * schedule_value

    # Pruning schedule
    def cosine_prune_schedule(t, T, max_pruned):
        # Parameter names matter: t and T are bound by keyword via partial().
        return int(
            np.ceil(max_pruned * sparse_training.cosine_prune_function(
                t, T, cosine_options)))

    if step == total_steps:
        logger.debug("Final step: pruning will be skipped.")
        return None, schedule_factor

    if schedule_value == 0:
        logger.debug(
            f"Nothing to prune at step {step}/{total_steps}: schedule is {schedule_factor}"
        )
        return None, schedule_factor

    t0 = time.perf_counter()

    # Prune and grow each sparse layer
    for layer_name, sparse_layer in self.sparse_layers.items():
        values_var_name = sparse_layer.get_values_var().name
        slots = {
            slot_name: session_outputs[slot.tf_variable.name]
            for slot_name, slot in sparse_layer.get_slot_var_dict().items()
        }
        nz = session_outputs[values_var_name]
        # Bring the host-side representation up to date before modifying it.
        sparse_layer.sync_internal_representation({"nz": nz}, slots)

        # run prune and grow
        grow_results = sparse_training.prune_and_grow(
            name=layer_name + "/" + sparse_layer.name,
            triplets=sparse_layer.get_triplets(),
            shape=sparse_layer.get_shape(),
            spec=sparse_layer.weights.spec,
            max_non_zeros=sparse_layer.get_max_non_zeros(),
            slot_triplets=sparse_layer.extract_slot_triplets(),
            prune_schedule=partial(cosine_prune_schedule, t=step, T=total_steps),
            prune_ratio=self.prune_ratio,
            grad_w=np.array(session_outputs[layer_name + "/sparse_layer/grad_w"]),
            grow_method=self.regrow_type,
            random_gen=self.random,
            ipu_pooling_type=self.pooling_type)

        if grow_results is not None:
            sparse_layer.update_triplets(grow_results['gt'])
            sparse_layer.update_slots_from_triplets(grow_results['gs'])

    t1 = time.perf_counter()
    prune_and_grow_time = t1 - t0
    logger.info(
        f"Prune and grow for step {step} completed in "
        f"{prune_and_grow_time:0.3f} seconds for "
        f"{len(self.sparse_layers)} layers"
    )

    # return the time it took to perform the prune and grow as well as the current
    # factor of the cosine schedule for monitoring
    return prune_and_grow_time, schedule_factor
def cosine_prune_schedule(t, T, max_pruned):
    """Scale ``max_pruned`` by the cosine schedule at step ``t`` of ``T``.

    Reads ``cosine_options`` from the enclosing scope and returns the scaled
    count rounded up to an int.
    """
    fraction = sparse_training.cosine_prune_function(t, T, cosine_options)
    scaled = max_pruned * fraction
    return int(np.ceil(scaled))
def run_training(opts, transformer):
    """Build the training graph and run the training loop.

    The graph (dataset, infeed/outfeed queues, compiled training loop, metric
    initialisers, saver and summary writer) is constructed and finalised, then
    ``opts.nepochs`` epochs are run in a single session. After every IO step
    the training outputs are logged and written as summaries; when pruning is
    enabled and the step is not the last of its epoch, the prune-and-grow
    outfeed is drained, the host-side prune/regrow is run and the new sparsity
    pattern is streamed back to the device. A checkpoint is saved after each
    epoch.

    Args:
        opts: Options object holding the training configuration.
        transformer: Model object providing ``buildSparsityUpdateOps``,
            ``syncPruneAndRegrowOnHost``, ``streamSparsityFromHostToDevice``,
            ``sparse_layers``, ``prune_ratio`` and ``source_sequence_length``.
    """
    import sys  # local import: only needed to flush stdout after each epoch

    def stat_summaries(tag_prefix, values):
        # Summary values for stddev/mean/min/max of `values`, tagged with tag_prefix.
        return [
            tf.Summary.Value(tag=tag_prefix + "stddev", simple_value=np.std(values)),
            tf.Summary.Value(tag=tag_prefix + "mean", simple_value=np.mean(values)),
            tf.Summary.Value(tag=tag_prefix + "min", simple_value=np.min(values)),
            tf.Summary.Value(tag=tag_prefix + "max", simple_value=np.max(values)),
        ]

    # Construct the training graph
    training_graph = tf.Graph()
    with training_graph.as_default():
        with tf.device("cpu"):
            dataset, num_train, vocab = data_utils.make_dataset(
                opts, use_synthetic_data=opts.use_synthetic_data, training=True)

        # Calculate dataset length
        batch_size = opts.batch_size
        if opts.pipeline:
            batch_size *= opts.gradient_accumulation_count
        batches_per_epoch = num_train // batch_size
        io_steps_per_epoch = batches_per_epoch // opts.repeat_count
        total_io_steps = opts.nepochs * io_steps_per_epoch
        total_global_steps = opts.nepochs * io_steps_per_epoch * opts.repeat_count
        logger.info(f"Effective batch-size (global batch): {batch_size}, "
                    f"IO steps per epoch: {io_steps_per_epoch}, "
                    f"Total IO steps: {total_io_steps} "
                    f"Total global steps: {total_global_steps}")

        if opts.prune_ratio is not None and opts.prune_ratio > 0:
            # Compute the pruning ratio when the learning rate will reach a minimum
            lr_decay_steps = opts.cooldown_steps + opts.warmup_steps
            lr_min_epochs = lr_decay_steps / (io_steps_per_epoch * opts.repeat_count)
            remaining_prune_ratio = opts.prune_ratio * sparse_training.cosine_prune_function(
                lr_decay_steps, total_global_steps, opts.cosine_prune_schedule)
            logger.warning(
                f"\n\nThe learning rate schedule will reach a minimum after {lr_min_epochs:0.2f} epochs, "
                f"at which point the pruning ratio will be {remaining_prune_ratio:0.3f}\n\n"
            )
            logger.info(
                f"Cosine prune schedule options: {opts.cosine_prune_schedule}")

        logger.info("Creating infeed and outfeed queues")
        # Queues for streaming from host to device and back
        train_infeed = IPUInfeedQueue(dataset, feed_name="train_infeed")
        train_outfeed = IPUOutfeedQueue(feed_name="train_outfeed")
        prune_and_grow_outfeed = IPUOutfeedQueue(
            feed_name="prune_and_grow_outfeed")

        # Helper function
        def loop_builder(iterations, builder_func, infeed):
            return loops.repeat(iterations, builder_func, [], infeed)

        # Compile the forward and backward pass for training
        with scopes.ipu_scope("/device:IPU:0"):
            if opts.pipeline:
                logger.info("Creating pipelined training graph")
                train_loop = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed, train_infeed)
            else:
                logger.info("Creating training graph")
                train_body = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed)
                train_loop = partial(loop_builder, opts.repeat_count,
                                     train_body, train_infeed)
            train_loop = ipu_compiler.compile(train_loop, inputs=[])
            transformer.buildSparsityUpdateOps()

        # Metrics
        with tf.device("cpu"):
            metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="metrics")
            metrics_initializer = tf.variables_initializer(
                var_list=metrics_vars)
            saver = tf.train.Saver()

        # These ops are declared here so that the graph can be frozen afterwards
        global_initializer = tf.global_variables_initializer()
        train_outfeed_dequeue = train_outfeed.dequeue()
        # NOTE(review): the dequeue op is created from opts.prune_ratio but the
        # training loop below tests transformer.prune_ratio; the two are
        # presumably kept equal by the caller -- confirm.
        if opts.prune_ratio is not None and opts.prune_ratio > 0:
            prune_and_grow_dequeue = prune_and_grow_outfeed.dequeue()
        utils.move_variable_initialization_to_cpu()

        # Tensorboard
        log_name = "logs/" + datetime.now().isoformat()
        summary_writer = tf.summary.FileWriter(logdir=os.path.join(
            opts.train_checkpoint_path, log_name),
                                               flush_secs=5)

    # Run the model:
    training_graph.finalize()  # no more new ops added from here on out
    try:
        with tf.Session(graph=training_graph) as sess:
            logger.info("Initializing training session")
            sess.run(global_initializer)
            sess.run(train_infeed.initializer)
            logger.info("Training...")
            progress = tqdm(range(opts.nepochs))
            for e in progress:
                sess.run(metrics_initializer)
                for io_step in range(io_steps_per_epoch):
                    # Train the model
                    step_start_time = time.perf_counter()
                    sess.run(train_loop)
                    ipu_train_time = time.perf_counter() - step_start_time

                    session_outputs = sess.run(train_outfeed_dequeue)[-1]
                    logger.debug(f"Train outputs: {session_outputs.keys()}")

                    # Calculate avg throughput
                    num_tokens = transformer.source_sequence_length * opts.repeat_count * batch_size
                    throughput = num_tokens / ipu_train_time

                    # Log progress - average stats over the last accumulation step only:
                    start_point = -1 if not opts.pipeline else -opts.gradient_accumulation_count
                    lr = np.mean(session_outputs["learning_rate"][start_point:])
                    training_loss = np.mean(
                        session_outputs['training_loss'][start_point:])
                    std_training_loss = np.std(
                        session_outputs['training_loss'][start_point:])
                    nll_loss = np.mean(session_outputs['nll_loss'][start_point:])
                    perplexity = np.mean(
                        session_outputs["perplexity"][start_point:])
                    token_accuracy = np.mean(
                        session_outputs['token_accuracy'][start_point:])
                    global_step = session_outputs['global_step'][start_point:][-1]
                    logger.info(
                        f"\nEpoch {e}: io_step {io_step+1}/{io_steps_per_epoch}"
                        f"\nGlobal step: {global_step}/{total_global_steps}"
                        f"\nTraining loss : {training_loss:.4f}"
                        f"\nTraining loss standard deviation: {std_training_loss:.4f}"
                        f"\nXentropy loss : {nll_loss:.4f}"
                        f"\nPerplexity : {perplexity:.3f}"
                        f"\nToken accuracy: {token_accuracy:.2f}"
                        f"\nLearning rate: {lr:3.4e}"
                        f"\nThroughput {throughput:.1f} token/s")

                    if opts.decode and logger.level <= logging.INFO:
                        try:
                            text_pred, text_target = data_utils.decode_prediction(
                                prediction=session_outputs['predictions'][-1],
                                target=session_outputs['target'][-1],
                                vocab=vocab)
                            logger.info(
                                f"\nTarget: {text_target}\n\nPrediction: {text_pred}\n"
                            )
                        except Exception as ex:
                            # Decoding is best-effort; training carries on.
                            logger.warning(f"Decoding failed: {ex}")

                    summary_value = [
                        tf.Summary.Value(tag="perplexity",
                                         simple_value=perplexity),
                        tf.Summary.Value(tag="training_loss",
                                         simple_value=training_loss),
                        tf.Summary.Value(tag="stddev_training_loss",
                                         simple_value=std_training_loss),
                        tf.Summary.Value(tag="xentropy_loss",
                                         simple_value=nll_loss),
                        tf.Summary.Value(tag="token_accuracy",
                                         simple_value=token_accuracy),
                        tf.Summary.Value(tag="learning_rate",
                                         simple_value=lr),
                        tf.Summary.Value(tag="throughput",
                                         simple_value=throughput),
                        tf.Summary.Value(tag="epoch", simple_value=e)
                    ]

                    # If we just completed the last io step we do not
                    # prune and grow regardless, otherwise check the prune ratio:
                    if io_step + 1 < io_steps_per_epoch and transformer.prune_ratio is not None and transformer.prune_ratio > 0:
                        # Retrieve p and g results from the conditional queue,
                        # keeping only the last entry of each output:
                        prune_and_grow_data = sess.run(prune_and_grow_dequeue)
                        for k in prune_and_grow_data:
                            prune_and_grow_data[k] = prune_and_grow_data[k][-1]
                        logger.debug(
                            f"Prune and grow outputs: {prune_and_grow_data.keys()}"
                        )

                        prune_and_grow_time, cosine_schedule_factor = transformer.syncPruneAndRegrowOnHost(
                            opts.cosine_prune_schedule, global_step,
                            total_global_steps, prune_and_grow_data)
                        transformer.streamSparsityFromHostToDevice()
                        summary_value.extend([
                            tf.Summary.Value(tag="prune+grow_time",
                                             simple_value=prune_and_grow_time),
                            tf.Summary.Value(tag="cosine_schedule_factor",
                                             simple_value=cosine_schedule_factor)
                        ])

                        for layer_name, sparse_layer in transformer.sparse_layers.items():
                            values_var = sparse_layer.get_values_var()
                            grad_w_name = values_var.name.replace(
                                'nz_values:0', 'grad_w')
                            grad_w = np.array(prune_and_grow_data[grad_w_name])
                            if opts.log_histograms:
                                histogram = tf_utils.make_histogram_proto(
                                    grad_w, bins_count=opts.bins_count)
                                summary_value.extend([
                                    tf.Summary.Value(tag=layer_name + "/dense_grad_w",
                                                     histo=histogram)
                                ])

                            summary_value.extend(
                                stat_summaries(layer_name + "/dense_grad_w_", grad_w))

                            for slot_name, slot in sparse_layer.get_slot_var_dict().items():
                                slot_val = prune_and_grow_data[
                                    slot.tf_variable.name]
                                if opts.log_histograms:
                                    histogram = tf_utils.make_histogram_proto(
                                        slot_val, bins_count=opts.bins_count)
                                    summary_value.extend([
                                        tf.Summary.Value(tag=slot_name,
                                                         histo=histogram)
                                    ])
                                summary_value.extend(
                                    stat_summaries(slot_name + "/", slot_val))

                    # Log to tensorboard (outside any graph)
                    summary = tf.Summary(value=summary_value)
                    summary_writer.add_summary(summary, np.mean(global_step))
                    if opts.use_wandb:
                        wandb.tensorflow.log(summary.SerializeToString())
                    logger.info(
                        f"Total time for step {time.perf_counter() - step_start_time}"
                    )
                    logger.info(f"IPU train time for step {ipu_train_time}")

                logger.info(f"Saving model after epoch {e}")
                saver.save(
                    sess,
                    os.path.join(opts.train_checkpoint_path,
                                 'model_' + str(e) + '.ckpt'))
                sys.stdout.flush()
    finally:
        # Flush pending events and release the event file even if training fails.
        summary_writer.close()
    logger.info("Training complete.")