def run_testing(opts, transformer):
    """Evaluate a trained model on the test dataset and log its metrics.

    A dedicated test graph is built (dataset, infeed/outfeed queues and the
    compiled forward pass), weights are restored from a checkpoint under
    ``opts.train_checkpoint_path`` and the whole test set is processed in a
    single session call.

    Args:
        opts: Run options. Read here: ``use_synthetic_data``, ``batch_size``,
            ``pipeline``, ``gradient_accumulation_count``, ``restore_epoch``,
            ``train_checkpoint_path``, ``decode`` and ``log_level``.
        transformer: Model object handed through to ``forward_pass``.

    Returns:
        str: The multi-line metrics description that was logged.

    Raises:
        ValueError: If ``opts.restore_epoch`` is None and no checkpoint can
            be found in ``opts.train_checkpoint_path``.
    """
    testing_graph = tf.Graph()
    with testing_graph.as_default():
        with tf.device("cpu"):
            logger.info("Creating test dataset")
            dataset, num_test, vocab = data_utils.make_dataset(
                opts,
                use_synthetic_data=opts.use_synthetic_data,
                training=False)

            # With pipelining the effective (global) batch is the configured
            # batch size times the gradient accumulation count.
            batch_size = opts.batch_size
            if opts.pipeline:
                batch_size *= opts.gradient_accumulation_count
            batches_per_epoch = num_test // batch_size
            logger.info(f"Effective batch-size (global batch): {batch_size}")

            logger.info("Creating infeed and outfeed queues")
            test_infeed = IPUInfeedQueue(dataset, feed_name="test_infeed")
            test_outfeed = IPUOutfeedQueue(feed_name="test_outfeed")

        # Compile the forward pass for testing
        with scopes.ipu_scope("/device:IPU:0"):
            # Helper function: repeat `builder_func` `iterations` times,
            # pulling its inputs from `infeed`.
            def loop_builder(iterations, builder_func, infeed):
                return loops.repeat(iterations, builder_func, [], infeed)

            if opts.pipeline:
                # The infeed is handed to forward_pass directly and no
                # explicit repeat loop is built in this branch.
                logger.info("Creating pipelined test graph")
                test_loop = partial(forward_pass,
                                    opts,
                                    transformer,
                                    batches_per_epoch,
                                    False,
                                    test_outfeed,
                                    dense_queue=None,
                                    infeed=test_infeed)
            else:
                logger.info("Creating test graph")
                test_loop = partial(forward_pass, opts, transformer,
                                    batches_per_epoch, False, test_outfeed,
                                    None)
                test_loop = partial(loop_builder, batches_per_epoch, test_loop,
                                    test_infeed)
            test_loop = ipu_compiler.compile(test_loop, inputs=[])

        # Metrics
        with tf.device("cpu"):
            metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="metrics")
            metrics_initializer = tf.variables_initializer(
                var_list=metrics_vars)
            saver = tf.train.Saver()

    # Resolve the checkpoint before opening a session so that a missing
    # checkpoint is reported with a message naming the directory searched.
    if opts.restore_epoch is None:
        checkpoint = tf.train.latest_checkpoint(opts.train_checkpoint_path)
        if checkpoint is None:
            raise ValueError(
                f"No checkpoint found in '{opts.train_checkpoint_path}'.")
    else:
        checkpoint = os.path.join(opts.train_checkpoint_path,
                                  f"model_{opts.restore_epoch}.ckpt")

    with tf.Session(graph=testing_graph) as sess:
        # The sparsity will also be streamed from the checkpoint
        logger.info("Restoring weights")
        saver.restore(sess, checkpoint)
        sess.run(test_infeed.initializer)
        sess.run(metrics_initializer)

        # Run inference (whole dataset in one session call)
        logger.info("Testing...")
        sess.run(test_loop)
        session_outputs = sess.run(test_outfeed.dequeue())[-1]

        # Test set performance: report the last entry of each metric.
        nll_loss = session_outputs['nll_loss'][-1]
        training_loss = session_outputs['training_loss'][-1]
        perplexity = session_outputs["perplexity"][-1]
        token_accuracy = session_outputs['token_accuracy'][-1]
        desc = (f"\nTraining loss : {training_loss:.4f}"
                f"\nXentropy loss : {nll_loss:.4f}"
                f"\nPerplexity : {perplexity:.3f}"
                f"\nToken accuracy: {token_accuracy:.2f}")
        logger.info(desc)

        if opts.decode and opts.log_level == 'INFO':
            text_pred, text_target = data_utils.decode_prediction(
                prediction=session_outputs['predictions'][-1],
                target=session_outputs['target'][-1],
                vocab=vocab)
            logger.info(f"Target: {text_target}\n"
                        f"Prediction: {text_pred}\n")
        os.sys.stdout.flush()

        logger.info("Test complete.")

    return desc
def run_testing(opts, transformer, x_test, y_test):
    """Evaluate the latest checkpoint on in-memory test data and gate on accuracy.

    Builds a test graph fed from placeholders, restores the most recent
    checkpoint in ``opts.train_checkpoint_path``, runs every full batch of
    the test set in one session call and logs loss, accuracy and throughput.

    Args:
        opts: Run options. Read here: ``batch_size``, ``dtype`` and
            ``train_checkpoint_path``.
        transformer: Model object handed through to ``forward_pass``; its
            ``source_sequence_length`` is used for the throughput figure.
        x_test: Test inputs; ``x_test.shape[1:]`` defines the per-example
            input shape.
        y_test: Test labels, fed to an int32 placeholder of shape ``[None]``.

    Raises:
        ValueError: If no checkpoint is found in
            ``opts.train_checkpoint_path``.
        AssertionError: If the final test accuracy is not at least 0.85.
    """
    batches_per_epoch = len(y_test) // opts.batch_size
    testing_graph = tf.Graph()
    with testing_graph.as_default():
        with tf.device("cpu"):
            input_shape = [None, *x_test.shape[1:]]
            place_x = tf.placeholder(dtype=opts.dtype,
                                     shape=input_shape,
                                     name="input")
            place_y = tf.placeholder(dtype=tf.int32,
                                     shape=[None],
                                     name="label")

            # Create dataset and IPU feeds:
            dataset = tf.data.Dataset.from_tensor_slices(
                (place_x, place_y)).cache()
            dataset = dataset.batch(opts.batch_size, drop_remainder=True)
            test_infeed = IPUInfeedQueue(dataset, feed_name="test_infeed")
            test_outfeed = IPUOutfeedQueue(feed_name="test_outfeed")

            # Helper function: repeat `builder_func` `iterations` times,
            # pulling its inputs from `infeed`.
            def loop_builder(iterations, builder_func, infeed):
                return loops.repeat(iterations, builder_func, [], infeed)

            # Compile the forward pass for testing
            with scopes.ipu_scope("/device:IPU:0"):
                test_loop = partial(forward_pass, opts, transformer, None,
                                    batches_per_epoch, False, test_outfeed,
                                    None)
                test_loop = partial(loop_builder, batches_per_epoch, test_loop,
                                    test_infeed)
                test_loop = ipu_compiler.compile(test_loop, inputs=[])

            # Metrics
            with tf.device("cpu"):
                metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                                 scope="metrics")
                metrics_initializer = tf.variables_initializer(
                    var_list=metrics_vars)
                saver = tf.train.Saver()

                # Created here so the graph can be finalized before running.
                test_outfeed_dequeue = test_outfeed.dequeue()

    # Setup and acquire an IPU device:
    config = utils.auto_select_ipus(utils.create_ipu_config(), 1)
    utils.configure_ipu_system(config)

    logpath = os.path.join(opts.train_checkpoint_path, "test")
    checkpoint = tf.train.latest_checkpoint(opts.train_checkpoint_path)
    if checkpoint is None:
        # saver.restore would raise the same exception type, but without
        # saying which directory was searched.
        raise ValueError(
            f"No checkpoint found in '{opts.train_checkpoint_path}'.")
    summary_writer = tf.summary.FileWriter(logpath)

    testing_graph.finalize()  # no more new ops added from here on out
    try:
        with tf.Session(graph=testing_graph) as sess:
            logger.info("Testing...")
            # The sparsity will also be streamed from the checkpoint
            # The host and device sparsity are not in sync here
            saver.restore(sess, checkpoint)
            sess.run(test_infeed.initializer,
                     feed_dict={
                         place_x: x_test,
                         place_y: y_test
                     })
            sess.run(metrics_initializer)

            # Run inference (whole dataset in one session call)
            dt = time.perf_counter()
            sess.run(test_loop)
            dt = time.perf_counter() - dt
            session_outputs = sess.run(test_outfeed_dequeue)

            # Test set performance
            throughput = transformer.source_sequence_length * len(y_test) / dt
            test_loss = session_outputs['mean_loss'].mean()
            test_acc = session_outputs['acc'][-1]
            desc = f"Test loss: {test_loss:.8f} Test accuracy: {test_acc:.8f}"
            logger.info(desc + f" Throughput {throughput:.1f} token/s")
    finally:
        # The writer is opened above; release its file handle even when
        # the evaluation fails.
        summary_writer.close()

    # Regression tests. Raised explicitly (rather than via `assert`) so the
    # check is not stripped under `python -O`; `not >=` keeps NaN failing.
    accuracy_threshold = 0.85
    if not test_acc >= accuracy_threshold:
        raise AssertionError(
            f"Test accuracy ({test_acc:3.2f}) is below threshold of ({accuracy_threshold:3.2f})"
        )
    print("All asserts pass.")
def _stat_summary_values(tag_prefix, data):
    """Return stddev/mean/min/max scalar summaries of `data` under `tag_prefix`."""
    return [
        tf.Summary.Value(tag=tag_prefix + "stddev",
                         simple_value=np.std(data)),
        tf.Summary.Value(tag=tag_prefix + "mean",
                         simple_value=np.mean(data)),
        tf.Summary.Value(tag=tag_prefix + "min", simple_value=np.min(data)),
        tf.Summary.Value(tag=tag_prefix + "max", simple_value=np.max(data)),
    ]


def _prune_and_grow_summary_values(opts, transformer, prune_and_grow_data):
    """Build per-layer dense-gradient and slot-variable summaries from `prune_and_grow_data`."""
    values = []
    for layer_name, sparse_layer in transformer.sparse_layers.items():
        # The dense gradient is stored under the name of the layer's values
        # variable with 'nz_values:0' replaced by 'grad_w'.
        values_var = sparse_layer.get_values_var()
        grad_w_name = values_var.name.replace('nz_values:0', 'grad_w')
        grad_w = np.array(prune_and_grow_data[grad_w_name])
        if opts.log_histograms:
            histogram = tf_utils.make_histogram_proto(
                grad_w, bins_count=opts.bins_count)
            values.append(
                tf.Summary.Value(tag=layer_name + "/dense_grad_w",
                                 histo=histogram))
        values.extend(
            _stat_summary_values(layer_name + "/dense_grad_w_", grad_w))

        for slot_name, slot in sparse_layer.get_slot_var_dict().items():
            slot_val = prune_and_grow_data[slot.tf_variable.name]
            if opts.log_histograms:
                histogram = tf_utils.make_histogram_proto(
                    slot_val, bins_count=opts.bins_count)
                values.append(
                    tf.Summary.Value(tag=slot_name, histo=histogram))
            values.extend(_stat_summary_values(slot_name + "/", slot_val))
    return values


def run_training(opts, transformer):
    """Train the model, logging metrics and saving a checkpoint per epoch.

    Builds the training graph (dataset, infeed/outfeed queues, compiled
    training loop and sparsity update ops), then runs ``opts.nepochs``
    epochs. After every IO step the metrics are logged and written to
    TensorBoard (and to wandb when ``opts.use_wandb`` is set). On every IO
    step except the last of an epoch, and only when the transformer's prune
    ratio is positive, the prune-and-grow data is dequeued, the sparsity
    pattern is updated on the host and streamed back to the device.

    Args:
        opts: Run options. Read here: ``use_synthetic_data``, ``batch_size``,
            ``pipeline``, ``gradient_accumulation_count``, ``repeat_count``,
            ``nepochs``, ``prune_ratio``, ``cooldown_steps``,
            ``warmup_steps``, ``cosine_prune_schedule``,
            ``train_checkpoint_path``, ``decode``, ``log_histograms``,
            ``bins_count`` and ``use_wandb``.
        transformer: Model object handed through to ``forward_pass``; also
            provides the sparsity update and host/device sync methods.
    """
    # Construct the training graph
    training_graph = tf.Graph()
    with training_graph.as_default():
        with tf.device("cpu"):
            dataset, num_train, vocab = data_utils.make_dataset(
                opts,
                use_synthetic_data=opts.use_synthetic_data,
                training=True)

        # Calculate dataset length
        batch_size = opts.batch_size
        if opts.pipeline:
            batch_size *= opts.gradient_accumulation_count
        batches_per_epoch = num_train // batch_size
        io_steps_per_epoch = batches_per_epoch // opts.repeat_count
        total_io_steps = opts.nepochs * io_steps_per_epoch
        total_global_steps = opts.nepochs * io_steps_per_epoch * opts.repeat_count
        logger.info(f"Effective batch-size (global batch): {batch_size}, "
                    f"IO steps per epoch: {io_steps_per_epoch}, "
                    f"Total IO steps: {total_io_steps} "
                    f"Total global steps: {total_global_steps}")

        if opts.prune_ratio is not None and opts.prune_ratio > 0:
            # Compute the pruning ratio when the learning rate will reach a minimum
            lr_decay_steps = opts.cooldown_steps + opts.warmup_steps
            lr_min_epochs = lr_decay_steps / (io_steps_per_epoch *
                                              opts.repeat_count)
            remaining_prune_ratio = opts.prune_ratio * sparse_training.cosine_prune_function(
                lr_decay_steps, total_global_steps, opts.cosine_prune_schedule)
            logger.warning(
                f"\n\nThe learning rate schedule will reach a minimum after {lr_min_epochs:0.2f} epochs, "
                f"at which point the pruning ratio will be {remaining_prune_ratio:0.3f}\n\n"
            )
            logger.info(
                f"Cosine prune schedule options: {opts.cosine_prune_schedule}")

        logger.info("Creating infeed and outfeed queues")
        # Queues for streaming from host to device and back
        train_infeed = IPUInfeedQueue(dataset, feed_name="train_infeed")
        train_outfeed = IPUOutfeedQueue(feed_name="train_outfeed")
        prune_and_grow_outfeed = IPUOutfeedQueue(
            feed_name="prune_and_grow_outfeed")

        # Helper function: repeat `builder_func` `iterations` times, pulling
        # its inputs from `infeed`.
        def loop_builder(iterations, builder_func, infeed):
            return loops.repeat(iterations, builder_func, [], infeed)

        # Compile the forward and backward pass for training
        with scopes.ipu_scope("/device:IPU:0"):
            if opts.pipeline:
                # The infeed is handed to forward_pass directly and no
                # explicit repeat loop is built in this branch.
                logger.info("Creating pipelined training graph")
                train_loop = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed, train_infeed)
            else:
                logger.info("Creating training graph")
                train_body = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed)
                train_loop = partial(loop_builder, opts.repeat_count,
                                     train_body, train_infeed)
            train_loop = ipu_compiler.compile(train_loop, inputs=[])
            transformer.buildSparsityUpdateOps()

        # Metrics
        with tf.device("cpu"):
            metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="metrics")
            metrics_initializer = tf.variables_initializer(
                var_list=metrics_vars)
            saver = tf.train.Saver()

            # These ops are declared here so that the graph can be frozen afterwards
            global_initializer = tf.global_variables_initializer()
            train_outfeed_dequeue = train_outfeed.dequeue()
            # NOTE(review): this dequeue op is gated on opts.prune_ratio but
            # used below when transformer.prune_ratio is positive; this
            # assumes the two agree - confirm.
            if opts.prune_ratio is not None and opts.prune_ratio > 0:
                prune_and_grow_dequeue = prune_and_grow_outfeed.dequeue()
            utils.move_variable_initialization_to_cpu()

            # Tensorboard
            log_name = "logs/" + datetime.now().isoformat()
            summary_writer = tf.summary.FileWriter(logdir=os.path.join(
                opts.train_checkpoint_path, log_name),
                                                   flush_secs=5)

    # Metrics are averaged over the last accumulation step only: the final
    # entry normally, or the last `gradient_accumulation_count` entries when
    # pipelining.
    start_point = -1 if not opts.pipeline else -opts.gradient_accumulation_count

    # Run the model:
    training_graph.finalize()  # no more new ops added from here on out
    try:
        with tf.Session(graph=training_graph) as sess:
            logger.info("Initializing training session")
            sess.run(global_initializer)
            sess.run(train_infeed.initializer)
            logger.info("Training...")
            progress = tqdm(range(opts.nepochs))
            for e in progress:
                sess.run(metrics_initializer)
                for io_step in range(io_steps_per_epoch):
                    # Train the model
                    step_start_time = time.perf_counter()
                    sess.run(train_loop)
                    ipu_train_time = time.perf_counter() - step_start_time

                    session_outputs = sess.run(train_outfeed_dequeue)[-1]
                    logger.debug(f"Train outputs: {session_outputs.keys()}")

                    # Calculate avg throughput
                    num_tokens = transformer.source_sequence_length * opts.repeat_count * batch_size
                    throughput = num_tokens / ipu_train_time

                    # Log progress - average stats over the last accumulation step only:
                    lr = np.mean(session_outputs["learning_rate"][start_point:])
                    training_loss = np.mean(
                        session_outputs['training_loss'][start_point:])
                    std_training_loss = np.std(
                        session_outputs['training_loss'][start_point:])
                    nll_loss = np.mean(
                        session_outputs['nll_loss'][start_point:])
                    perplexity = np.mean(
                        session_outputs["perplexity"][start_point:])
                    token_accuracy = np.mean(
                        session_outputs['token_accuracy'][start_point:])
                    global_step = session_outputs['global_step'][start_point:][-1]
                    logger.info(
                        f"\nEpoch {e}: io_step {io_step+1}/{io_steps_per_epoch}"
                        f"\nGlobal step: {global_step}/{total_global_steps}"
                        f"\nTraining loss : {training_loss:.4f}"
                        f"\nTraining loss standard deviation: {std_training_loss:.4f}"
                        f"\nXentropy loss : {nll_loss:.4f}"
                        f"\nPerplexity : {perplexity:.3f}"
                        f"\nToken accuracy: {token_accuracy:.2f}"
                        f"\nLearning rate: {lr:3.4e}"
                        f"\nThroughput {throughput:.1f} token/s")

                    if opts.decode and logger.level <= logging.INFO:
                        # Decoding is best-effort: a failure is reported but
                        # must not interrupt training.
                        try:
                            text_pred, text_target = data_utils.decode_prediction(
                                prediction=session_outputs['predictions'][-1],
                                target=session_outputs['target'][-1],
                                vocab=vocab)
                            logger.info(
                                f"\nTarget: {text_target}\n\nPrediction: {text_pred}\n"
                            )
                        except Exception as ex:
                            logger.warning(f"Decoding failed: {ex}")

                    summary_value = [
                        tf.Summary.Value(tag="perplexity",
                                         simple_value=perplexity),
                        tf.Summary.Value(tag="training_loss",
                                         simple_value=training_loss),
                        tf.Summary.Value(tag="stddev_training_loss",
                                         simple_value=std_training_loss),
                        tf.Summary.Value(tag="xentropy_loss",
                                         simple_value=nll_loss),
                        tf.Summary.Value(tag="token_accuracy",
                                         simple_value=token_accuracy),
                        tf.Summary.Value(tag="learning_rate",
                                         simple_value=lr),
                        tf.Summary.Value(tag="throughput",
                                         simple_value=throughput),
                        tf.Summary.Value(tag="epoch", simple_value=e)
                    ]

                    # If we just completed the last io step we do not
                    # prune and grow regardless, otherwise check the prune ratio:
                    if (io_step + 1 < io_steps_per_epoch
                            and transformer.prune_ratio is not None
                            and transformer.prune_ratio > 0):
                        # Retrieve p and g results from the conditional queue,
                        # keeping only the last entry of each output:
                        prune_and_grow_data = sess.run(prune_and_grow_dequeue)
                        for k in prune_and_grow_data:
                            prune_and_grow_data[k] = prune_and_grow_data[k][-1]
                        logger.debug(
                            f"Prune and grow outputs: {prune_and_grow_data.keys()}"
                        )

                        prune_and_grow_time, cosine_schedule_factor = transformer.syncPruneAndRegrowOnHost(
                            opts.cosine_prune_schedule, global_step,
                            total_global_steps, prune_and_grow_data)
                        transformer.streamSparsityFromHostToDevice()
                        summary_value.extend([
                            tf.Summary.Value(tag="prune+grow_time",
                                             simple_value=prune_and_grow_time),
                            tf.Summary.Value(
                                tag="cosine_schedule_factor",
                                simple_value=cosine_schedule_factor)
                        ])
                        summary_value.extend(
                            _prune_and_grow_summary_values(
                                opts, transformer, prune_and_grow_data))

                    # Log to tensorboard (outside any graph)
                    summary = tf.Summary(value=summary_value)
                    summary_writer.add_summary(summary, np.mean(global_step))
                    if opts.use_wandb:
                        wandb.tensorflow.log(summary.SerializeToString())
                    logger.info(
                        f"Total time for step {time.perf_counter() - step_start_time}"
                    )
                    logger.info(f"IPU train time for step {ipu_train_time}")

                logger.info(f"Saving model after epoch {e}")
                saver.save(
                    sess,
                    os.path.join(opts.train_checkpoint_path,
                                 'model_' + str(e) + '.ckpt'))
                os.sys.stdout.flush()
            logger.info("Training complete.")
    finally:
        # Flush any queued events and release the event file, including
        # when training is interrupted by an exception.
        summary_writer.close()
def run_training(opts, transformer, x_train, y_train):
    """Train the model on in-memory data, saving a checkpoint after each epoch.

    Builds a training graph fed from placeholders, then runs
    ``opts.nepochs`` epochs of ``opts.steps_per_epoch`` steps each. Every
    step feeds the learning rate from ``learning_rate_schedule(e, opts)``,
    reports loss/accuracy/throughput on the progress bar and, when the
    transformer has a prune ratio, performs a prune-and-regrow update on the
    host and streams the new sparsity pattern to the device.

    Args:
        opts: Run options. Read here: ``batch_size``, ``steps_per_epoch``,
            ``nepochs``, ``dtype``, ``random_seed``, ``num_shards``,
            ``train_checkpoint_path`` and ``cosine_prune_schedule``.
        transformer: Model object handed through to ``forward_pass``; also
            provides the sparsity update and host/device sync methods.
        x_train: Training inputs; ``x_train.shape[1:]`` defines the
            per-example input shape.
        y_train: Training labels, fed to an int32 placeholder of shape
            ``[None]``.

    Raises:
        ValueError: If ``opts.steps_per_epoch`` does not divide the number
            of batches per epoch exactly.
    """
    # Calculate dataset length
    num_train = len(y_train)
    batches_per_epoch = num_train // opts.batch_size
    batches_per_step = batches_per_epoch // opts.steps_per_epoch
    total_steps = opts.steps_per_epoch * opts.nepochs
    logger.info(
        f"Batches per epoch: {batches_per_epoch} Batches per step: {batches_per_step}"
    )

    if batches_per_epoch % opts.steps_per_epoch != 0:
        raise ValueError(
            f"IPU steps per epoch {opts.steps_per_epoch} must divide batches per epoch {batches_per_epoch} exactly."
        )

    # Construct the training graph
    training_graph = tf.Graph()
    with training_graph.as_default():
        with tf.device("cpu"):
            input_shape = [None, *x_train.shape[1:]]
            place_x = tf.placeholder(dtype=opts.dtype,
                                     shape=input_shape,
                                     name="input")
            place_y = tf.placeholder(dtype=tf.int32,
                                     shape=[None],
                                     name="label")
            lr_placeholder = tf.placeholder(opts.dtype, shape=[])

            # Create dataset and IPU feeds. The cache must come before the
            # shuffle: caching a shuffled dataset replays the first epoch's
            # order on every repeat, which would defeat
            # reshuffle_each_iteration.
            dataset = tf.data.Dataset.from_tensor_slices((place_x, place_y))
            dataset = dataset.cache().shuffle(buffer_size=num_train,
                                              reshuffle_each_iteration=True,
                                              seed=opts.random_seed)
            dataset = dataset.repeat().batch(opts.batch_size,
                                             drop_remainder=True)

            # Queues for streaming from host to device and back
            train_infeed = IPUInfeedQueue(dataset, feed_name="train_infeed")
            train_outfeed = IPUOutfeedQueue(feed_name="train_outfeed")
            png_outfeed = IPUOutfeedQueue(feed_name="prune_and_grow_outfeed")

            # Helper function: repeat `builder_func` `iterations` times,
            # pulling its inputs from `infeed`.
            def loop_builder(iterations, builder_func, infeed):
                return loops.repeat(iterations, builder_func, [], infeed)

            # Compile the forward and backward pass for training
            with scopes.ipu_scope("/device:IPU:0"):
                train_loop = partial(forward_pass, opts, transformer,
                                     lr_placeholder, batches_per_step, True,
                                     train_outfeed, png_outfeed)
                train_loop = partial(loop_builder, batches_per_step,
                                     train_loop, train_infeed)
                train_loop = ipu_compiler.compile(train_loop, inputs=[])
                transformer.buildSparsityUpdateOps()

            # Metrics
            with tf.device("cpu"):
                metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                                 scope="metrics")
                metrics_initializer = tf.variables_initializer(
                    var_list=metrics_vars)
                saver = tf.train.Saver(max_to_keep=5)

                # These ops are declared here so that the graph can be frozen afterwards
                global_initializer = tf.global_variables_initializer()
                train_outfeed_dequeue = train_outfeed.dequeue()
                png_outfeed_dequeue = png_outfeed.dequeue()

    # Setup and acquire an IPU device:
    config = utils.auto_select_ipus(utils.create_ipu_config(), opts.num_shards)
    utils.configure_ipu_system(config)

    logpath = os.path.join(opts.train_checkpoint_path, "train")
    summary_writer = tf.summary.FileWriter(logpath)

    # Run the model:
    training_graph.finalize()  # no more new ops added from here on out
    try:
        with tf.Session(graph=training_graph) as sess:
            logger.info("Creating training session")
            sess.run(global_initializer)
            sess.run(train_infeed.initializer,
                     feed_dict={
                         place_x: x_train,
                         place_y: y_train
                     })

            progress = tqdm(
                range(opts.nepochs),
                bar_format='{desc} Epoch: {n_fmt}/{total_fmt} {bar}')
            for e in progress:
                for i in range(opts.steps_per_epoch):
                    # Train the model
                    sess.run(metrics_initializer)
                    dt = time.perf_counter()
                    sess.run(train_loop,
                             feed_dict={
                                 lr_placeholder:
                                 learning_rate_schedule(e, opts)
                             })
                    dt = time.perf_counter() - dt
                    session_outputs = sess.run(train_outfeed_dequeue)
                    logger.debug(f"Train outputs: {session_outputs}")

                    # Calculate avg throughput
                    num_tokens = transformer.source_sequence_length * batches_per_step * opts.batch_size
                    throughput = num_tokens / dt
                    desc = (f"Loss {session_outputs['mean_loss'][-1]:.5f} "
                            f"Accuracy {session_outputs['acc'][-1]:.5f} "
                            f"Iteration: {session_outputs['iteration'][-1]}")
                    progress.set_description(
                        desc + f" Throughput {throughput:.1f} token/s")

                    # Perform pruning (if using RigL the dense grads from session_outputs are used)
                    # `step` is the 1-based index of this step over the whole run.
                    step = 1 + i + e * opts.steps_per_epoch
                    if transformer.prune_ratio is not None:
                        t0 = time.perf_counter()
                        png_results = sess.run(png_outfeed_dequeue)
                        t1 = time.perf_counter()
                        # Keep only the last entry of each output.
                        for k in png_results:
                            png_results[k] = png_results[k][-1]
                        logger.debug(
                            f"Prune and grow outputs: {png_results.keys()}")
                        logger.info(
                            f"Downloaded the prune and grow data from Device to Host in {t1-t0:0.3f} seconds"
                        )

                        transformer.syncPruneAndRegrowOnHost(
                            opts.cosine_prune_schedule, step, total_steps,
                            png_results)
                        transformer.streamSparsityFromHostToDevice()

                # Save at the end of each epoch
                logger.info("Saving model")
                saver.save(
                    sess,
                    os.path.join(opts.train_checkpoint_path, 'model.ckpt'))
    finally:
        # The writer is opened above; release its file handle even when
        # training fails.
        summary_writer.close()