Exemplo n.º 1
0
def prune_and_grow(name, fc, prune_and_grow_outputs, random_gen, step,
                   total_steps, opts, metainfo):
    """Run one host-side prune-and-grow update for a single sparse layer.

    The layer's host-side state is first synchronised with the values in
    ``prune_and_grow_outputs``. ``sparse_training.prune_and_grow`` is then
    called and, if it returns a result, the layer's triplets and slots are
    updated from it.

    Args:
        name: Layer name. The weights are saved only when this equals
            ``'fc1'`` and ``opts.records_path`` is set.
        fc: Sparse layer object to synchronise and update.
        prune_and_grow_outputs: Mapping indexed by the layer's values
            variable name, its slot names and its ``grad_w`` name.
        random_gen: Random generator forwarded to
            ``sparse_training.prune_and_grow``. When ``None`` a generator
            seeded with ``opts.seed`` is created instead.
        step: Current step, passed to the cosine schedule as ``t``.
        total_steps: Total number of steps, passed to the schedule as ``T``.
        opts: Options object; ``cosine_prune_schedule``, ``prune_ratio``,
            ``regrow``, ``seed`` and ``records_path`` are read from it.
        metainfo: Forwarded to ``fc.sync_internal_representation`` under
            the ``"metainfo"`` key.

    Returns:
        ``opts.prune_ratio`` multiplied by the cosine schedule value for
        ``step``.

    Raises:
        Exception: Any error raised while updating the layer from the
            prune-and-grow results is logged and re-raised unchanged.
    """
    def cosine_prune_schedule(t, T, max_pruned):
        # Scale the maximum prune count by the schedule value and round up.
        s = sparse_training.cosine_prune_function(t, T,
                                                  opts.cosine_prune_schedule)
        logger.info(f"t/T: {t}/{T} max:{max_pruned} sched: {s}")
        return int(np.ceil(max_pruned * s))

    # Use the caller's generator so that its state advances across calls.
    # Re-creating a generator from opts.seed on every call would restart the
    # same random sequence each time, so only do that as a fallback.
    if random_gen is None:
        random_gen = np.random.default_rng(seed=opts.seed)

    # Sync the layer's internal host-side state with prune_and_grow_outputs results
    # (both weights and slots need to be kept in sync):
    fc.sync_internal_representation(
        {"nz": prune_and_grow_outputs[fc.get_values_var().name]}, {
            slot_name: prune_and_grow_outputs[slot_name]
            for slot_name in fc.sparse_slots
        }, {"metainfo": metainfo})

    # The gradient output is keyed by the values variable name with its
    # 'nz_values:0' suffix swapped for 'grad_w'.
    grad_w_name = fc.get_values_var().name.replace('nz_values:0', 'grad_w')
    grow_results = sparse_training.prune_and_grow(
        name=fc.name,
        triplets=fc.get_triplets(),
        shape=fc.get_shape(),
        spec=fc.weights.spec,
        max_non_zeros=fc.get_max_non_zeros(),
        slot_triplets=fc.extract_slot_triplets(),
        prune_schedule=partial(cosine_prune_schedule, t=step, T=total_steps),
        prune_ratio=opts.prune_ratio,
        grad_w=np.array(prune_and_grow_outputs[grad_w_name]),
        grow_method=opts.regrow,
        random_gen=random_gen,
        ipu_pooling_type=fc.pooling_type)

    if grow_results is not None:
        try:
            fc.update_triplets(grow_results['gt'])
            fc.update_slots_from_triplets(grow_results['gs'])
        except Exception:
            # Log the offending triplets for diagnosis, then propagate.
            # KeyboardInterrupt/SystemExit are deliberately not intercepted.
            logger.info(
                f"Failed to update representation with triplets:\n{grow_results['gt'][0]}\n{grow_results['gt'][1]}\n{grow_results['gt'][2]}"
            )
            logger.info(f"Non-zeros: {len(grow_results['gt'][0])}")
            logger.info(f"Layer spec: {fc.weights.spec}")
            raise

    if opts.records_path and name == 'fc1':
        # Save the first hidden layer's weight mask for later analysis:
        save_weights(opts, name, fc, step)

    return opts.prune_ratio * sparse_training.cosine_prune_function(
        step, total_steps, opts.cosine_prune_schedule)
Exemplo n.º 2
0
 def cosine_prune_schedule(t, T, max_pruned):
     """Return ``max_pruned`` scaled by the cosine schedule, rounded up.

     The schedule value is obtained from
     ``sparse_training.cosine_prune_function`` using the enclosing scope's
     ``opts.cosine_prune_schedule`` and is logged at INFO level.
     """
     s = sparse_training.cosine_prune_function(
         t, T, opts.cosine_prune_schedule)
     logger.info(f"t/T: {t}/{T} max:{max_pruned} sched: {s}")
     scaled_count = np.ceil(max_pruned * s)
     return int(scaled_count)
Exemplo n.º 3
0
    def syncPruneAndRegrowOnHost(self, cosine_options, step, total_steps,
                                 session_outputs):
        """Prune and regrow every sparse layer on the host.

        For each entry in ``self.sparse_layers`` the layer's host-side state
        is synchronised with ``session_outputs``,
        ``sparse_training.prune_and_grow`` is run, and the layer's triplets
        and slots are updated from the result when one is returned.

        Args:
            cosine_options: Options forwarded to
                ``sparse_training.cosine_prune_function``.
            step: Current step, passed to the schedule as ``t``.
            total_steps: Total number of steps, passed to the schedule as
                ``T``.
            session_outputs: Mapping indexed by each layer's values variable
                name, its slot variable names and
                ``<layer_name>/sparse_layer/grad_w``.

        Returns:
            A ``(prune_and_grow_time, schedule)`` tuple. ``schedule`` is
            ``self.prune_ratio`` times the cosine schedule value for
            ``step``. ``prune_and_grow_time`` is the elapsed time in seconds,
            or ``None`` when pruning was skipped (final step, or a schedule
            value of zero).
        """
        # Pruning schedule
        def cosine_prune_schedule(t, T, max_pruned):
            return int(
                np.ceil(max_pruned * sparse_training.cosine_prune_function(
                    t, T, cosine_options)))

        # Evaluate the schedule once up front: every use below passes the
        # same arguments. NOTE(review): this assumes cosine_prune_function is
        # a pure function of its arguments - confirm.
        schedule_factor = sparse_training.cosine_prune_function(
            step, total_steps, cosine_options)
        sched = self.prune_ratio * schedule_factor

        if step == total_steps:
            logger.debug("Final step: pruning will be skipped.")
            return None, sched

        if schedule_factor == 0:
            logger.debug(
                f"Nothing to prune at step {step}/{total_steps}: schedule is {sched}"
            )
            return None, sched

        t0 = time.perf_counter()
        # Prune and grow each sparse layer
        for layer_name, sparse_layer in self.sparse_layers.items():
            values_var_name = sparse_layer.get_values_var().name
            slots = {
                slot_name: session_outputs[slot.tf_variable.name]
                for slot_name, slot in
                sparse_layer.get_slot_var_dict().items()
            }
            nz = session_outputs[values_var_name]

            # Weights and slots must both be synced before pruning.
            sparse_layer.sync_internal_representation({"nz": nz}, slots)

            # run prune and grow
            grow_results = sparse_training.prune_and_grow(
                name=layer_name + "/" + sparse_layer.name,
                triplets=sparse_layer.get_triplets(),
                shape=sparse_layer.get_shape(),
                spec=sparse_layer.weights.spec,
                max_non_zeros=sparse_layer.get_max_non_zeros(),
                slot_triplets=sparse_layer.extract_slot_triplets(),
                prune_schedule=partial(cosine_prune_schedule,
                                       t=step,
                                       T=total_steps),
                prune_ratio=self.prune_ratio,
                grad_w=np.array(session_outputs[layer_name +
                                                "/sparse_layer/grad_w"]),
                grow_method=self.regrow_type,
                random_gen=self.random,
                ipu_pooling_type=self.pooling_type)

            if grow_results is not None:
                sparse_layer.update_triplets(grow_results['gt'])
                sparse_layer.update_slots_from_triplets(grow_results['gs'])

        t1 = time.perf_counter()
        prune_and_grow_time = t1 - t0
        logger.info(
            f"Prune and grow for step {step} completed in "
            f"{prune_and_grow_time:0.3f} seconds for {len(self.sparse_layers)} layers"
        )
        # return the time it took to perform the prune and grow as well as the current
        # factor of the cosine schedule for monitoring
        return prune_and_grow_time, sched
Exemplo n.º 4
0
 def cosine_prune_schedule(t, T, max_pruned):
     """Return ``max_pruned`` scaled by the cosine schedule, rounded up.

     The schedule value comes from ``sparse_training.cosine_prune_function``
     called with the enclosing scope's ``cosine_options``.
     """
     schedule_value = sparse_training.cosine_prune_function(
         t, T, cosine_options)
     scaled_count = np.ceil(max_pruned * schedule_value)
     return int(scaled_count)
Exemplo n.º 5
0
def _stat_summaries(tag_prefix, values):
    """Build stddev/mean/min/max scalar summaries for ``values``.

    Each summary's tag is ``tag_prefix`` followed by the statistic's name.
    """
    return [
        tf.Summary.Value(tag=tag_prefix + "stddev",
                         simple_value=np.std(values)),
        tf.Summary.Value(tag=tag_prefix + "mean",
                         simple_value=np.mean(values)),
        tf.Summary.Value(tag=tag_prefix + "min",
                         simple_value=np.min(values)),
        tf.Summary.Value(tag=tag_prefix + "max",
                         simple_value=np.max(values)),
    ]


def _sparse_layer_summaries(opts, transformer, prune_and_grow_data):
    """Collect summaries of each sparse layer's dense gradient and slots.

    Histograms are included only when ``opts.log_histograms`` is set.
    """
    summaries = []
    for layer_name, sparse_layer in transformer.sparse_layers.items():
        values_var = sparse_layer.get_values_var()
        # The gradient output is keyed by the values variable name with its
        # 'nz_values:0' suffix swapped for 'grad_w'.
        grad_w_name = values_var.name.replace('nz_values:0', 'grad_w')
        grad_w = np.array(prune_and_grow_data[grad_w_name])
        if opts.log_histograms:
            histogram = tf_utils.make_histogram_proto(
                grad_w, bins_count=opts.bins_count)
            summaries.append(
                tf.Summary.Value(tag=layer_name + "/dense_grad_w",
                                 histo=histogram))
        summaries.extend(
            _stat_summaries(layer_name + "/dense_grad_w_", grad_w))

        for slot_name, slot in sparse_layer.get_slot_var_dict().items():
            slot_val = prune_and_grow_data[slot.tf_variable.name]
            if opts.log_histograms:
                histogram = tf_utils.make_histogram_proto(
                    slot_val, bins_count=opts.bins_count)
                summaries.append(
                    tf.Summary.Value(tag=slot_name, histo=histogram))
            summaries.extend(_stat_summaries(slot_name + "/", slot_val))
    return summaries


def run_training(opts, transformer):
    """Build the training graph and run the training loop.

    The graph is constructed and compiled for the IPU, then for every epoch
    the compiled loop is run ``io_steps_per_epoch`` times. After each IO step
    the training metrics are logged and written as summaries, and - except
    after the last IO step of an epoch - a host-side prune-and-grow is
    performed when ``transformer.prune_ratio`` is set and positive. A
    checkpoint is saved at the end of every epoch.

    Args:
        opts: Options object carrying the training configuration.
        transformer: Model object. It is passed to ``forward_pass`` and its
            ``buildSparsityUpdateOps``, ``syncPruneAndRegrowOnHost`` and
            ``streamSparsityFromHostToDevice`` methods are called.
    """
    # Construct the training graph
    training_graph = tf.Graph()
    with training_graph.as_default():
        with tf.device("cpu"):
            dataset, num_train, vocab = data_utils.make_dataset(
                opts,
                use_synthetic_data=opts.use_synthetic_data,
                training=True)

        # Calculate dataset length
        batch_size = opts.batch_size
        if opts.pipeline:
            batch_size *= opts.gradient_accumulation_count
        batches_per_epoch = num_train // batch_size
        io_steps_per_epoch = batches_per_epoch // opts.repeat_count
        total_io_steps = opts.nepochs * io_steps_per_epoch
        total_global_steps = opts.nepochs * io_steps_per_epoch * opts.repeat_count
        logger.info(f"Effective batch-size (global batch): {batch_size}, "
                    f"IO steps per epoch: {io_steps_per_epoch}, "
                    f"Total IO steps: {total_io_steps} "
                    f"Total global steps: {total_global_steps}")

        if opts.prune_ratio is not None and opts.prune_ratio > 0:
            # Compute the pruning ratio when the learning rate will reach a minimum
            lr_decay_steps = opts.cooldown_steps + opts.warmup_steps
            lr_min_epochs = lr_decay_steps / (io_steps_per_epoch *
                                              opts.repeat_count)
            remaining_prune_ratio = opts.prune_ratio * sparse_training.cosine_prune_function(
                lr_decay_steps, total_global_steps, opts.cosine_prune_schedule)
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                f"\n\nThe learning rate schedule will reach a minimum after {lr_min_epochs:0.2f} epochs, "
                f"at which point the pruning ratio will be {remaining_prune_ratio:0.3f}\n\n"
            )
            logger.info(
                f"Cosine prune schedule options: {opts.cosine_prune_schedule}")

        logger.info("Creating infeed and outfeed queues")
        # Queues for streaming from host to device and back
        train_infeed = IPUInfeedQueue(dataset, feed_name="train_infeed")
        train_outfeed = IPUOutfeedQueue(feed_name="train_outfeed")
        prune_and_grow_outfeed = IPUOutfeedQueue(
            feed_name="prune_and_grow_outfeed")

        # Helper function
        def loop_builder(iterations, builder_func, infeed):
            return loops.repeat(iterations, builder_func, [], infeed)

        # Compile the forward and backward pass for training
        with scopes.ipu_scope("/device:IPU:0"):
            if opts.pipeline:
                logger.info("Creating pipelined training graph")
                train_loop = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed, train_infeed)
            else:
                logger.info("Creating training graph")
                train_body = partial(forward_pass, opts, transformer,
                                     opts.repeat_count, True, train_outfeed,
                                     prune_and_grow_outfeed)
                train_loop = partial(loop_builder, opts.repeat_count,
                                     train_body, train_infeed)
            train_loop = ipu_compiler.compile(train_loop, inputs=[])
            transformer.buildSparsityUpdateOps()

        # Metrics
        with tf.device("cpu"):
            metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                             scope="metrics")
            metrics_initializer = tf.variables_initializer(
                var_list=metrics_vars)
            saver = tf.train.Saver()

            # These ops are declared here so that the graph can be frozen afterwards
            global_initializer = tf.global_variables_initializer()
            train_outfeed_dequeue = train_outfeed.dequeue()
            if opts.prune_ratio is not None and opts.prune_ratio > 0:
                prune_and_grow_dequeue = prune_and_grow_outfeed.dequeue()
            utils.move_variable_initialization_to_cpu()

            # Tensorboard
            log_name = "logs/" + datetime.now().isoformat()
            summary_writer = tf.summary.FileWriter(logdir=os.path.join(
                opts.train_checkpoint_path, log_name),
                                                   flush_secs=5)

    # Run the model:
    training_graph.finalize()  # no more new ops added from here on out
    try:
        with tf.Session(graph=training_graph) as sess:
            logger.info("Initializing training session")
            sess.run(global_initializer)
            sess.run(train_infeed.initializer)
            logger.info("Training...")
            progress = tqdm(range(opts.nepochs))
            for e in progress:
                sess.run(metrics_initializer)
                for io_step in range(io_steps_per_epoch):
                    # Train the model
                    step_start_time = time.perf_counter()
                    sess.run(train_loop)
                    ipu_train_time = time.perf_counter() - step_start_time

                    session_outputs = sess.run(train_outfeed_dequeue)[-1]
                    logger.debug(f"Train outputs: {session_outputs.keys()}")

                    # Calculate avg throughput
                    num_tokens = transformer.source_sequence_length * opts.repeat_count * batch_size
                    throughput = num_tokens / ipu_train_time

                    # Log progress - average stats over the last accumulation step only:
                    start_point = -1 if not opts.pipeline else -opts.gradient_accumulation_count
                    lr = np.mean(session_outputs["learning_rate"][start_point:])
                    training_loss = np.mean(
                        session_outputs['training_loss'][start_point:])
                    std_training_loss = np.std(
                        session_outputs['training_loss'][start_point:])
                    nll_loss = np.mean(
                        session_outputs['nll_loss'][start_point:])
                    perplexity = np.mean(
                        session_outputs["perplexity"][start_point:])
                    token_accuracy = np.mean(
                        session_outputs['token_accuracy'][start_point:])
                    global_step = session_outputs['global_step'][
                        start_point:][-1]
                    logger.info(
                        f"\nEpoch {e}: io_step {io_step+1}/{io_steps_per_epoch}"
                        f"\nGlobal step: {global_step}/{total_global_steps}"
                        f"\nTraining loss : {training_loss:.4f}"
                        f"\nTraining loss standard deviation: {std_training_loss:.4f}"
                        f"\nXentropy loss : {nll_loss:.4f}"
                        f"\nPerplexity : {perplexity:.3f}"
                        f"\nToken accuracy: {token_accuracy:.2f}"
                        f"\nLearning rate: {lr:3.4e}"
                        f"\nThroughput {throughput:.1f} token/s")

                    if opts.decode and logger.level <= logging.INFO:
                        # Decoding is best-effort: a failure is reported but
                        # must not interrupt training.
                        try:
                            text_pred, text_target = data_utils.decode_prediction(
                                prediction=session_outputs['predictions'][-1],
                                target=session_outputs['target'][-1],
                                vocab=vocab)
                            logger.info(
                                f"\nTarget: {text_target}\n\nPrediction: {text_pred}\n"
                            )
                        except Exception as ex:
                            logger.warning(f"Decoding failed: {ex}")

                    summary_value = [
                        tf.Summary.Value(tag="perplexity",
                                         simple_value=perplexity),
                        tf.Summary.Value(tag="training_loss",
                                         simple_value=training_loss),
                        tf.Summary.Value(tag="stddev_training_loss",
                                         simple_value=std_training_loss),
                        tf.Summary.Value(tag="xentropy_loss",
                                         simple_value=nll_loss),
                        tf.Summary.Value(tag="token_accuracy",
                                         simple_value=token_accuracy),
                        tf.Summary.Value(tag="learning_rate",
                                         simple_value=lr),
                        tf.Summary.Value(tag="throughput",
                                         simple_value=throughput),
                        tf.Summary.Value(tag="epoch", simple_value=e)
                    ]

                    # If we just completed the last io step we do not
                    # prune and grow regardless, otherwise check the prune ratio:
                    if (io_step + 1 < io_steps_per_epoch
                            and transformer.prune_ratio is not None
                            and transformer.prune_ratio > 0):
                        # Retrieve p and g results from the conditional queue:
                        prune_and_grow_data = sess.run(prune_and_grow_dequeue)
                        # Keep only the last entry of each dequeued output
                        for k in prune_and_grow_data:
                            prune_and_grow_data[k] = prune_and_grow_data[k][-1]
                        logger.debug(
                            f"Prune and grow outputs: {prune_and_grow_data.keys()}"
                        )

                        prune_and_grow_time, cosine_schedule_factor = transformer.syncPruneAndRegrowOnHost(
                            opts.cosine_prune_schedule, global_step,
                            total_global_steps, prune_and_grow_data)
                        transformer.streamSparsityFromHostToDevice()
                        summary_value.extend([
                            tf.Summary.Value(
                                tag="prune+grow_time",
                                simple_value=prune_and_grow_time),
                            tf.Summary.Value(
                                tag="cosine_schedule_factor",
                                simple_value=cosine_schedule_factor)
                        ])
                        summary_value.extend(
                            _sparse_layer_summaries(opts, transformer,
                                                    prune_and_grow_data))

                    # Log to tensorboard (outside any graph)
                    summary = tf.Summary(value=summary_value)
                    summary_writer.add_summary(summary, np.mean(global_step))
                    if opts.use_wandb:
                        wandb.tensorflow.log(summary.SerializeToString())
                    logger.info(
                        f"Total time for step {time.perf_counter() - step_start_time}"
                    )
                    logger.info(f"IPU train time for step {ipu_train_time}")

                logger.info(f"Saving model after epoch {e}")
                saver.save(
                    sess,
                    os.path.join(opts.train_checkpoint_path,
                                 'model_' + str(e) + '.ckpt'))
                os.sys.stdout.flush()
            logger.info("Training complete.")
    finally:
        # Flush pending events and release the event file, even when
        # training is interrupted by an exception.
        summary_writer.close()