Beispiel #1
0
def run_testing(opts, transformer):
    """Restore a checkpoint and evaluate ``transformer`` on the test set.

    Builds a forward-only graph (pipelined when ``opts.pipeline`` is set,
    otherwise ``forward_pass`` wrapped in ``loops.repeat``), restores the
    weights and runs the whole test set in a single session call. The last
    outfeed entry of each metric is logged and returned.

    Args:
        opts: parsed run options; reads ``batch_size``, ``pipeline``,
            ``gradient_accumulation_count``, ``use_synthetic_data``,
            ``restore_epoch``, ``train_checkpoint_path`` and ``decode``.
        transformer: model object forwarded to ``forward_pass``.

    Returns:
        str: the multi-line metrics summary that was logged.

    Raises:
        ValueError: if the test set is smaller than one global batch, or if
            ``opts.restore_epoch`` is None and no checkpoint can be found in
            ``opts.train_checkpoint_path``.
    """
    import logging
    import sys

    testing_graph = tf.Graph()
    with testing_graph.as_default():
        with tf.device("cpu"):
            logger.info("Creating test dataset")
            dataset, num_test, vocab = data_utils.make_dataset(
                opts,
                use_synthetic_data=opts.use_synthetic_data,
                training=False)

            batch_size = opts.batch_size
            if opts.pipeline:
                batch_size *= opts.gradient_accumulation_count
            batches_per_epoch = num_test // batch_size
            logger.info(f"Effective batch-size (global batch): {batch_size}")
            if batches_per_epoch == 0:
                # With zero iterations nothing would be enqueued to the
                # outfeed; fail here with a clear message instead of later
                # when indexing the empty results.
                raise ValueError(
                    f"Test set has {num_test} samples, fewer than one global "
                    f"batch of {batch_size}; nothing to evaluate.")

            logger.info("Creating infeed and outfeed queues")
            test_infeed = IPUInfeedQueue(dataset, feed_name="test_infeed")
            test_outfeed = IPUOutfeedQueue(feed_name="test_outfeed")

        # Compile the forward pass for testing
        with scopes.ipu_scope("/device:IPU:0"):
            # Repeat `builder_func` `iterations` times, fed from `infeed`.
            def loop_builder(iterations, builder_func, infeed):
                return loops.repeat(iterations, builder_func, [], infeed)

            if opts.pipeline:
                logger.info("Creating pipelined test graph")
                test_loop = partial(forward_pass,
                                    opts,
                                    transformer,
                                    batches_per_epoch,
                                    False,
                                    test_outfeed,
                                    dense_queue=None,
                                    infeed=test_infeed)
            else:
                logger.info("Creating test graph")
                test_body = partial(forward_pass, opts, transformer,
                                    batches_per_epoch, False, test_outfeed,
                                    None)
                test_loop = partial(loop_builder, batches_per_epoch,
                                    test_body, test_infeed)
            test_loop = ipu_compiler.compile(test_loop, inputs=[])

        # Metrics
        with tf.device("cpu"):
            metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="metrics")
            metrics_initializer = tf.variables_initializer(
                var_list=metrics_vars)
            saver = tf.train.Saver()
            # Build the dequeue op with the rest of the graph (after the loop
            # that enqueues to the outfeed has been compiled) rather than
            # adding ops while the session is running.
            test_outfeed_dequeue = test_outfeed.dequeue()

    if opts.restore_epoch is None:
        checkpoint = tf.train.latest_checkpoint(opts.train_checkpoint_path)
        if checkpoint is None:
            raise ValueError(
                f"No checkpoint found in {opts.train_checkpoint_path}")
    else:
        checkpoint = os.path.join(opts.train_checkpoint_path,
                                  "model_" + str(opts.restore_epoch) + ".ckpt")

    with tf.Session(graph=testing_graph) as sess:
        # The sparsity will also be streamed from the checkpoint
        logger.info("Restoring weights")
        saver.restore(sess, checkpoint)
        sess.run(test_infeed.initializer)
        sess.run(metrics_initializer)

        # Run inference (whole dataset in one session call)
        logger.info("Testing...")
        start_time = time.perf_counter()
        sess.run(test_loop)
        test_time = time.perf_counter() - start_time
        logger.info(f"Test time: {test_time:.2f} s")
        session_outputs = sess.run(test_outfeed_dequeue)[-1]

        # Test set performance: report the last outfeed entry of each metric.
        nll_loss = session_outputs['nll_loss'][-1]
        training_loss = session_outputs['training_loss'][-1]
        perplexity = session_outputs["perplexity"][-1]
        token_accuracy = session_outputs['token_accuracy'][-1]
        desc = (f"\nTraining loss : {training_loss:.4f}"
                f"\nXentropy loss : {nll_loss:.4f}"
                f"\nPerplexity : {perplexity:.3f}"
                f"\nToken accuracy: {token_accuracy:.2f}")
        logger.info(desc)

        # Decoding only produces INFO output, so skip the work when INFO is
        # disabled; it is diagnostic, so a failure must not abort the test.
        if opts.decode and logger.isEnabledFor(logging.INFO):
            try:
                text_pred, text_target = data_utils.decode_prediction(
                    prediction=session_outputs['predictions'][-1],
                    target=session_outputs['target'][-1],
                    vocab=vocab)
                logger.info(f"Target: {text_target}\n"
                            f"Prediction: {text_pred}\n")
            except Exception as ex:
                logger.warning(f"Decoding failed: {ex}")
        sys.stdout.flush()

        logger.info("Test complete.")

    return desc
Beispiel #2
0
def _stat_summaries(tag_prefix, values):
    """Return stddev/mean/min/max ``tf.Summary.Value`` entries for ``values``.

    The tags are ``tag_prefix`` followed by ``stddev``, ``mean``, ``min`` and
    ``max``, in that order.
    """
    return [
        tf.Summary.Value(tag=tag_prefix + "stddev",
                         simple_value=np.std(values)),
        tf.Summary.Value(tag=tag_prefix + "mean",
                         simple_value=np.mean(values)),
        tf.Summary.Value(tag=tag_prefix + "min",
                         simple_value=np.min(values)),
        tf.Summary.Value(tag=tag_prefix + "max",
                         simple_value=np.max(values)),
    ]


def _sparse_layer_summaries(opts, transformer, prune_and_grow_data):
    """Build TensorBoard values for every sparse layer's dense gradient and slot variables.

    ``prune_and_grow_data`` maps tensor names to the host-side values returned
    by the prune-and-grow outfeed. Histograms are added only when
    ``opts.log_histograms`` is set.
    """
    summary_value = []
    for layer_name, sparse_layer in transformer.sparse_layers.items():
        values_var = sparse_layer.get_values_var()
        # The dense gradient is outfed under the values variable's name with
        # 'nz_values:0' replaced by 'grad_w'.
        grad_w_name = values_var.name.replace('nz_values:0', 'grad_w')
        grad_w = np.array(prune_and_grow_data[grad_w_name])
        if opts.log_histograms:
            histogram = tf_utils.make_histogram_proto(
                grad_w, bins_count=opts.bins_count)
            summary_value.append(
                tf.Summary.Value(tag=layer_name + "/dense_grad_w",
                                 histo=histogram))
        summary_value.extend(
            _stat_summaries(layer_name + "/dense_grad_w_", grad_w))

        for slot_name, slot in sparse_layer.get_slot_var_dict().items():
            slot_val = prune_and_grow_data[slot.tf_variable.name]
            if opts.log_histograms:
                histogram = tf_utils.make_histogram_proto(
                    slot_val, bins_count=opts.bins_count)
                summary_value.append(
                    tf.Summary.Value(tag=slot_name, histo=histogram))
            summary_value.extend(_stat_summaries(slot_name + "/", slot_val))
    return summary_value


def run_training(opts, transformer):
    """Build the IPU training graph and train ``transformer``.

    Each epoch runs ``io_steps_per_epoch`` host-side IO steps, each of which
    executes ``opts.repeat_count`` iterations of ``forward_pass``. After every
    IO step the metrics are logged to the console and TensorBoard (and to
    Weights & Biases when ``opts.use_wandb`` is set). When
    ``transformer.prune_ratio > 0``, a prune-and-grow step is synced on the
    host after every IO step except the last one of each epoch. A checkpoint
    is saved to ``opts.train_checkpoint_path`` at the end of every epoch.

    Args:
        opts: parsed run options (batching, schedule, pruning, logging and
            checkpoint settings).
        transformer: model forwarded to ``forward_pass``; also provides the
            host-side sparsity update methods and ``sparse_layers``.

    Raises:
        ValueError: if the training set cannot fill a single IO step, i.e.
            has fewer than ``repeat_count`` global batches.
    """
    import sys

    # Construct the training graph
    training_graph = tf.Graph()
    with training_graph.as_default():
        with tf.device("cpu"):
            dataset, num_train, vocab = data_utils.make_dataset(
                opts,
                use_synthetic_data=opts.use_synthetic_data,
                training=True)

        # Calculate dataset length
        batch_size = opts.batch_size
        if opts.pipeline:
            batch_size *= opts.gradient_accumulation_count
        batches_per_epoch = num_train // batch_size
        io_steps_per_epoch = batches_per_epoch // opts.repeat_count
        total_io_steps = opts.nepochs * io_steps_per_epoch
        total_global_steps = total_io_steps * opts.repeat_count
        logger.info(f"Effective batch-size (global batch): {batch_size}, "
                    f"IO steps per epoch: {io_steps_per_epoch}, "
                    f"Total IO steps: {total_io_steps} "
                    f"Total global steps: {total_global_steps}")
        if io_steps_per_epoch == 0:
            # Otherwise the epoch loop silently trains nothing (while still
            # saving checkpoints) and the pruning schedule divides by zero.
            raise ValueError(
                f"The training set ({num_train} samples) is too small for a "
                f"single IO step of {opts.repeat_count} x {batch_size} samples.")

        pruning_enabled = opts.prune_ratio is not None and opts.prune_ratio > 0
        if pruning_enabled:
            # Compute the pruning ratio when the learning rate will reach a minimum
            lr_decay_steps = opts.cooldown_steps + opts.warmup_steps
            lr_min_epochs = lr_decay_steps / (io_steps_per_epoch *
                                              opts.repeat_count)
            remaining_prune_ratio = opts.prune_ratio * sparse_training.cosine_prune_function(
                lr_decay_steps, total_global_steps, opts.cosine_prune_schedule)
            logger.warning(
                f"\n\nThe learning rate schedule will reach a minimum after {lr_min_epochs:0.2f} epochs, "
                f"at which point the pruning ratio will be {remaining_prune_ratio:0.3f}\n\n"
            )
            logger.info(
                f"Cosine prune schedule options: {opts.cosine_prune_schedule}")

        logger.info("Creating infeed and outfeed queues")
        # Queues for streaming from host to device and back
        train_infeed = IPUInfeedQueue(dataset, feed_name="train_infeed")
        train_outfeed = IPUOutfeedQueue(feed_name="train_outfeed")
        prune_and_grow_outfeed = IPUOutfeedQueue(
            feed_name="prune_and_grow_outfeed")

        # Repeat `builder_func` `iterations` times, fed from `infeed`.
        def loop_builder(iterations, builder_func, infeed):
            return loops.repeat(iterations, builder_func, [], infeed)

        # Compile the forward and backward pass for training
        with scopes.ipu_scope("/device:IPU:0"):
            if opts.pipeline:
                logger.info("Creating pipelined training graph")
                train_loop = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed, train_infeed)
            else:
                logger.info("Creating training graph")
                train_body = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed)
                train_loop = partial(loop_builder, opts.repeat_count,
                                     train_body, train_infeed)
            train_loop = ipu_compiler.compile(train_loop, inputs=[])
            transformer.buildSparsityUpdateOps()

        # Metrics
        with tf.device("cpu"):
            metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="metrics")
            metrics_initializer = tf.variables_initializer(
                var_list=metrics_vars)
            saver = tf.train.Saver()

            # These ops are declared here so that the graph can be frozen afterwards
            global_initializer = tf.global_variables_initializer()
            train_outfeed_dequeue = train_outfeed.dequeue()
            if pruning_enabled:
                prune_and_grow_dequeue = prune_and_grow_outfeed.dequeue()
            utils.move_variable_initialization_to_cpu()

            # Tensorboard
            log_name = "logs/" + datetime.now().isoformat()
            summary_writer = tf.summary.FileWriter(logdir=os.path.join(
                opts.train_checkpoint_path, log_name),
                                                   flush_secs=5)

    # Run the model:
    training_graph.finalize()  # no more new ops added from here on out
    # Tokens processed per IO step (repeat_count global batches); loop-invariant.
    tokens_per_io_step = (transformer.source_sequence_length *
                          opts.repeat_count * batch_size)
    try:
        with tf.Session(graph=training_graph) as sess:
            logger.info("Initializing training session")
            sess.run(global_initializer)
            sess.run(train_infeed.initializer)
            logger.info("Training...")
            progress = tqdm(range(opts.nepochs))
            for e in progress:
                sess.run(metrics_initializer)
                for io_step in range(io_steps_per_epoch):
                    # Train the model
                    step_start_time = time.perf_counter()
                    sess.run(train_loop)
                    ipu_train_time = time.perf_counter() - step_start_time

                    session_outputs = sess.run(train_outfeed_dequeue)[-1]
                    logger.debug(f"Train outputs: {session_outputs.keys()}")

                    # Calculate avg throughput
                    throughput = tokens_per_io_step / ipu_train_time

                    # Log progress - average stats over the last accumulation step only:
                    start_point = -1 if not opts.pipeline else -opts.gradient_accumulation_count
                    lr = np.mean(session_outputs["learning_rate"][start_point:])
                    training_loss = np.mean(
                        session_outputs['training_loss'][start_point:])
                    std_training_loss = np.std(
                        session_outputs['training_loss'][start_point:])
                    nll_loss = np.mean(session_outputs['nll_loss'][start_point:])
                    perplexity = np.mean(
                        session_outputs["perplexity"][start_point:])
                    token_accuracy = np.mean(
                        session_outputs['token_accuracy'][start_point:])
                    global_step = session_outputs['global_step'][-1]
                    logger.info(
                        f"\nEpoch {e}: io_step {io_step+1}/{io_steps_per_epoch}"
                        f"\nGlobal step: {global_step}/{total_global_steps}"
                        f"\nTraining loss : {training_loss:.4f}"
                        f"\nTraining loss standard deviation: {std_training_loss:.4f}"
                        f"\nXentropy loss : {nll_loss:.4f}"
                        f"\nPerplexity : {perplexity:.3f}"
                        f"\nToken accuracy: {token_accuracy:.2f}"
                        f"\nLearning rate: {lr:3.4e}"
                        f"\nThroughput {throughput:.1f} token/s")

                    # Decoding is diagnostic INFO output only; skip it when
                    # INFO is disabled and never let it abort training.
                    if opts.decode and logger.isEnabledFor(logging.INFO):
                        try:
                            text_pred, text_target = data_utils.decode_prediction(
                                prediction=session_outputs['predictions'][-1],
                                target=session_outputs['target'][-1],
                                vocab=vocab)
                            logger.info(
                                f"\nTarget: {text_target}\n\nPrediction: {text_pred}\n"
                            )
                        except Exception as ex:
                            logger.warning(f"Decoding failed: {ex}")

                    summary_value = [
                        tf.Summary.Value(tag="perplexity",
                                         simple_value=perplexity),
                        tf.Summary.Value(tag="training_loss",
                                         simple_value=training_loss),
                        tf.Summary.Value(tag="stddev_training_loss",
                                         simple_value=std_training_loss),
                        tf.Summary.Value(tag="xentropy_loss",
                                         simple_value=nll_loss),
                        tf.Summary.Value(tag="token_accuracy",
                                         simple_value=token_accuracy),
                        tf.Summary.Value(tag="learning_rate", simple_value=lr),
                        tf.Summary.Value(tag="throughput",
                                         simple_value=throughput),
                        tf.Summary.Value(tag="epoch", simple_value=e)
                    ]

                    # If we just completed the last io step we do not
                    # prune and grow regardless, otherwise check the prune ratio.
                    # NOTE(review): the dequeue op is only built when
                    # opts.prune_ratio > 0, while this checks
                    # transformer.prune_ratio; presumably they agree — confirm.
                    if io_step + 1 < io_steps_per_epoch and transformer.prune_ratio is not None and transformer.prune_ratio > 0:
                        # Retrieve p and g results from the conditional queue,
                        # keeping only the most recent entry of each tensor.
                        prune_and_grow_data = {
                            name: values[-1]
                            for name, values in sess.run(
                                prune_and_grow_dequeue).items()
                        }
                        logger.debug(
                            f"Prune and grow outputs: {prune_and_grow_data.keys()}"
                        )

                        prune_and_grow_time, cosine_schedule_factor = transformer.syncPruneAndRegrowOnHost(
                            opts.cosine_prune_schedule, global_step,
                            total_global_steps, prune_and_grow_data)
                        transformer.streamSparsityFromHostToDevice()
                        summary_value.extend([
                            tf.Summary.Value(tag="prune+grow_time",
                                             simple_value=prune_and_grow_time),
                            tf.Summary.Value(tag="cosine_schedule_factor",
                                             simple_value=cosine_schedule_factor)
                        ])
                        summary_value.extend(
                            _sparse_layer_summaries(opts, transformer,
                                                    prune_and_grow_data))

                    # Log to tensorboard (outside any graph)
                    summary = tf.Summary(value=summary_value)
                    summary_writer.add_summary(summary, int(global_step))
                    if opts.use_wandb:
                        wandb.tensorflow.log(summary.SerializeToString())
                    logger.info(
                        f"Total time for step {time.perf_counter() - step_start_time}"
                    )
                    logger.info(f"IPU train time for step {ipu_train_time}")

                logger.info(f"Saving model after epoch {e}")
                saver.save(
                    sess,
                    os.path.join(opts.train_checkpoint_path,
                                 'model_' + str(e) + '.ckpt'))
                sys.stdout.flush()
            logger.info("Training complete.")
    finally:
        # The writer otherwise only flushes every flush_secs; close it so the
        # final events are written even if training raises.
        summary_writer.close()