Exemplo n.º 1
0
def get_ipu_config(fp_exceptions=True,
                   stochastic_rounding=True,
                   xla_recompute=False,
                   available_memory_proportion=None,
                   disable_graph_outlining=False,
                   num_ipus_required=0,
                   max_cross_replica_sum_buffer_size=0,
                   scheduler_selection='',
                   compile_only=False,
                   partials_type="half"):
    """Builds ipu_options.

    Args:
        fp_exceptions: Value used for the ``inv``, ``div0``, ``oflo`` and
            ``nanoo`` floating-point behaviour flags.
        stochastic_rounding: Value used for the ``esr`` floating-point flag.
        xla_recompute: Passed as ``allow_recompute`` to
            ``utils.set_recomputation_options``.
        available_memory_proportion: If not None, set (as a string) on both
            the convolution and matmul options, together with
            ``partials_type``.
        disable_graph_outlining: Forwarded to ``utils.create_ipu_config``.
        num_ipus_required: IPU count passed to ``utils.auto_select_ipus``.
        max_cross_replica_sum_buffer_size: Forwarded to
            ``utils.create_ipu_config``.
        scheduler_selection: Forwarded to ``utils.create_ipu_config``.
        compile_only: If True, the connection type is set to
            ``DeviceConnectionType.NEVER`` with ``ipu_version=2`` and remote
            buffers enabled.
        partials_type: ``partialsType`` for convolutions and matmuls; only
            applied when ``available_memory_proportion`` is set.

    Returns:
        The configured IPU options object.
    """
    config = utils.create_ipu_config(
        max_report_size=3001819596000,
        merge_infeed_io_copies=True,
        always_rearrange_copies_on_the_host=False,
        selection_order=utils.SelectionOrder.AUTO,
        disable_graph_outlining=disable_graph_outlining,
        max_cross_replica_sum_buffer_size=max_cross_replica_sum_buffer_size,
        scheduler_selection=scheduler_selection)

    config = utils.auto_select_ipus(config, num_ipus_required)

    config = utils.set_matmul_options(config, clear_pass_type=True)

    if available_memory_proportion is not None:
        config = utils.set_convolution_options(
            config, {
                "availableMemoryProportion": str(available_memory_proportion),
                "partialsType": partials_type
            })
        # ``clear_pass_type`` is a keyword of set_matmul_options with a
        # default of False; pass it again so this second call cannot undo the
        # ``clear_pass_type=True`` requested above.
        config = utils.set_matmul_options(
            config, {
                "availableMemoryProportion": str(available_memory_proportion),
                "partialsType": partials_type
            },
            clear_pass_type=True)

    config = utils.set_norm_options(config, use_stable_statistics=True)

    config = utils.set_recomputation_options(config,
                                             allow_recompute=xla_recompute)

    if compile_only:
        # Do not acquire a device; compile for an IPU version 2 target only.
        config = utils.set_ipu_connection_type(
            config,
            utils.DeviceConnectionType.NEVER,
            ipu_version=2,
            enable_remote_buffers=True)

    config = utils.set_floating_point_behaviour_options(
        config,
        inv=fp_exceptions,
        div0=fp_exceptions,
        oflo=fp_exceptions,
        esr=stochastic_rounding,
        nanoo=fp_exceptions)
    return config
Exemplo n.º 2
0
    def testResetSeed(self):
        """Check that all outfed dropout results are distinct.

        Each loop iteration applies ``rand_ops.dropout`` to two copies of the
        same input and outfeeds both. Across ``EXECS`` session runs, the test
        asserts that exactly ``2 * REPLICAS * REPEATS * EXECS`` rows of length
        ``SIZE`` were outfed and that their byte representations are all
        unique.
        """
        # The dataset for feeding the graphs
        ds = dataset_ops.Dataset.from_tensors(
            array_ops.constant(1.0, shape=[SIZE]))
        ds = ds.map(lambda x: [x, x])
        ds = ds.repeat()

        # The host side queues
        infeed_queue = ipu_infeed_queue.IPUInfeedQueue(
            ds, feed_name="infeed", replication_factor=REPLICAS)
        outfeed_queue = ipu_outfeed_queue.IPUOutfeedQueue(
            feed_name="outfeed", replication_factor=REPLICAS)

        # The device side
        def body(x1, x2):
            d1 = rand_ops.dropout(x1)
            d2 = rand_ops.dropout(x2)
            outfeed = outfeed_queue.enqueue({'d1': d1, 'd2': d2})
            return outfeed

        def my_net():
            r = loops.repeat(REPEATS, body, [], infeed_queue)
            return r

        with scopes.ipu_scope('/device:IPU:0'):
            res = ipu_compiler.compile(my_net, inputs=[])

        # The outfeed dequeue has to happen after the outfeed enqueue
        dequeue_outfeed = outfeed_queue.dequeue()

        # Configure the hardware
        config = utils.create_ipu_config(profiling=True)
        config = utils.auto_select_ipus(config, REPLICAS)
        config = utils.set_floating_point_behaviour_options(config)
        utils.configure_ipu_system(config)

        with session.Session() as sess:
            res_all = set()
            total = 0

            sess.run(infeed_queue.initializer)

            for _ in range(EXECS):
                sess.run(res)
                outfed_result = sess.run(dequeue_outfeed)
                for row in np.array(list(outfed_result.values())).reshape(
                    [-1, SIZE]):
                    total += 1
                    # tobytes() is the non-deprecated equivalent of
                    # ndarray.tostring().
                    res_all.add(row.tobytes())

            # 2 dropouts per replica * REPLICAS * REPEATS * EXECS
            expected = 2 * REPLICAS * REPEATS * EXECS
            self.assertEqual(total, expected)
            self.assertEqual(len(res_all), expected)
Exemplo n.º 3
0
def get_config(prng=False,
               ipu_id=-1,
               shards=1,
               number_of_replicas=1,
               max_cross_replica_buffer_size=10*1024*1024,
               merge_infeed_io_copies=True,
               fp_exceptions=True,
               xla_recompute=False,
               seed=None,
               profile=None,
               availableMemoryProportion=None,
               stable_norm=False):
    """Builds ipu_options.

    Args:
        prng: Enables ``prng.enable`` and the ``esr`` floating-point flag.
        ipu_id: ``-1`` auto-selects ``number_of_replicas * shards`` IPUs;
            any other value selects that single IPU id.
        shards: Shards per replica (used only for auto-selection).
        number_of_replicas: Replica count (used only for auto-selection).
        max_cross_replica_buffer_size: Forwarded to ``create_ipu_config`` as
            ``max_cross_replica_sum_buffer_size``.
        merge_infeed_io_copies: Forwarded to ``create_ipu_config``.
        fp_exceptions: Value for the ``inv``, ``div0`` and ``oflo`` flags
            (``nanoo`` is always True).
        xla_recompute: If True, enables ``allow_recompute``.
        seed: Only whether it is None matters here: a non-None seed sets
            ``target.deterministicWorkers`` to ``"true"``.
        profile: None disables profiling; otherwise one of ``"NO_PROFILE"``,
            ``"TILE_PROFILE"``, ``"DEVICE_PROFILE"``, ``"IPU_PROFILE"``.
        availableMemoryProportion: If not None, set (as a string) on the
            convolution options.
        stable_norm: If True, enables stable norm statistics.

    Returns:
        The configured IPU options object.

    Raises:
        KeyError: If ``profile`` is truthy but not one of the keys above.
    """

    profile_exec_modes = {"NO_PROFILE": ExecutionProfileType.NO_PROFILE,
                          "TILE_PROFILE": ExecutionProfileType.TILE_PROFILE,
                          "DEVICE_PROFILE": ExecutionProfileType.DEVICE_PROFILE,
                          "IPU_PROFILE": ExecutionProfileType.IPU_PROFILE}

    config = utils.create_ipu_config(max_cross_replica_sum_buffer_size=max_cross_replica_buffer_size,
                                     merge_infeed_io_copies=merge_infeed_io_copies,
                                     always_rearrange_copies_on_the_host=False,
                                     profiling=profile is not None,
                                     profile_execution=profile_exec_modes[profile] if profile else None)

    if "GCL_REAL_COLLECTIVES" in os.environ:
        config = utils.set_gcl_options(config, num_io_tiles=128, gcl_options={"useGclCollectives": "true", })

    if ipu_id == -1:
        config = utils.auto_select_ipus(config, number_of_replicas*shards)
    else:
        config = utils.select_ipus(config, [ipu_id])
    config = utils.set_compilation_options(config, {
        "device.clearAtomicFlagAfterExchange": "false",
        "prng.enable": "true" if prng else "false",
        "target.deterministicWorkers": "false" if seed is None else "true",
    })

    if availableMemoryProportion is not None:
        config = utils.set_convolution_options(config, {
            "availableMemoryProportion": str(availableMemoryProportion)
        })

    if stable_norm:
        config = utils.set_norm_options(config, use_stable_statistics=True)

    if xla_recompute:
        # Keep the returned options, as with every other utils.set_* call.
        config = utils.set_recomputation_options(config, allow_recompute=True)

    config = utils.set_floating_point_behaviour_options(config, inv=fp_exceptions, div0=fp_exceptions,
                                                        oflo=fp_exceptions, esr=prng, nanoo=True)

    return config
Exemplo n.º 4
0
def run_language_model(opts):
    """Configure the IPU system and run training and/or testing.

    Args:
        opts: Parsed options. Fields read here include ``random_seed``,
            ``pipeline``, ``num_shards`` (overwritten to 1 when not
            pipelining), ``compile_only``, ``compile_only_ipu_version``,
            ``recompute``, ``debug_dense_grad`` and ``mode`` (``"all"``,
            ``"train"`` or ``"test"``).

    Raises:
        AttributeError: If ``compile_only`` is set without
            ``compile_only_ipu_version``.
    """
    if opts.random_seed is not None:
        utils.reset_ipu_seed(opts.random_seed)

    # Setup and acquire an IPU device:
    logger.info("Acquiring devices")
    if not opts.pipeline:
        opts.num_shards = 1  # FIX-ME enable sparse models using multiple shards

    # Make sure that no matter the number of shards/stages required, we always
    # acquire a power of 2 ipus (else attachment will fail). This is the
    # smallest power of two >= num_shards (and at least 1).
    num_ipus = 1 << max(opts.num_shards - 1, 0).bit_length()
    logger.info(f"Need {opts.num_shards} IPUs, requesting {num_ipus}")
    config = utils.create_ipu_config()

    if opts.compile_only:
        if opts.compile_only_ipu_version is None:
            raise AttributeError(
                "Must provide --compile-only-ipu-version if --compile-only is set."
            )

        config = utils.set_ipu_connection_type(
            config,
            utils.DeviceConnectionType.NEVER,
            ipu_version=opts.compile_only_ipu_version,
            enable_remote_buffers=True)

    config = utils.auto_select_ipus(config, num_ipus)
    config = utils.set_recomputation_options(config,
                                             allow_recompute=opts.recompute)
    # Enable stochastic rounding
    config = utils.set_floating_point_behaviour_options(config,
                                                        inv=False,
                                                        div0=False,
                                                        oflo=False,
                                                        esr=True,
                                                        nanoo=False)
    config = sparse.set_system_config(
        config, custom_op_debug_printing=opts.debug_dense_grad)
    utils.configure_ipu_system(config)

    transformer = DynsparseTransformer(opts)
    if opts.mode in ["all", "train"]:
        run_training(opts, transformer)

    if opts.mode in ["all", "test"]:
        run_testing(opts, transformer)
Exemplo n.º 5
0
def get_config(fp_exceptions,
               xla_recompute,
               disable_graph_outlining,
               num_required_ipus,
               enable_stochastic_rounding,
               max_cross_replica_sum_buffer_size,
               scheduler_selection,
               compile_only,
               ipu_id):
    """Build the IPU options object for this run.

    Args:
        fp_exceptions: Value for the ``inv``, ``div0``, ``oflo`` and
            ``nanoo`` floating-point flags.
        xla_recompute: Passed as ``allow_recompute``.
        disable_graph_outlining: Forwarded to ``create_ipu_config``.
        num_required_ipus: IPU count used when auto-selecting.
        enable_stochastic_rounding: Value for the ``esr`` flag.
        max_cross_replica_sum_buffer_size: Forwarded to
            ``set_optimization_options``.
        scheduler_selection: Forwarded to ``create_ipu_config``.
        compile_only: If True, never attach to hardware (connection type
            ``NEVER``, IPU version 2, remote buffers enabled).
        ipu_id: A truthy id selects that single IPU; otherwise
            ``num_required_ipus`` IPUs are auto-selected.

    Returns:
        The configured IPU options object.
    """
    ipu_options = utils.create_ipu_config(
        merge_infeed_io_copies=True,
        always_rearrange_copies_on_the_host=False,
        disable_graph_outlining=disable_graph_outlining,
        selection_order=utils.SelectionOrder.AUTO,
        scheduler_selection=scheduler_selection)

    # NOTE(review): this is a truthiness test, so ipu_id=0 (like None) falls
    # through to auto-selection rather than pinning device 0 — confirm intent.
    ipu_options = (utils.select_ipus(ipu_options, [ipu_id]) if ipu_id
                   else utils.auto_select_ipus(ipu_options, num_required_ipus))

    ipu_options = utils.set_recomputation_options(
        ipu_options, allow_recompute=xla_recompute)
    # Deliberately disabled: clearing the matmul pass type was once used to
    # skip a large `Transpose` caused by a bad allocation.
    ipu_options = utils.set_norm_options(ipu_options,
                                         use_stable_statistics=True)

    fp_flags = dict(inv=fp_exceptions,
                    div0=fp_exceptions,
                    oflo=fp_exceptions,
                    nanoo=fp_exceptions)
    ipu_options = utils.set_floating_point_behaviour_options(
        ipu_options, esr=enable_stochastic_rounding, **fp_flags)

    ipu_options = utils.set_optimization_options(
        ipu_options,
        merge_remote_buffers=True,
        max_cross_replica_sum_buffer_size=max_cross_replica_sum_buffer_size)

    if not compile_only:
        return ipu_options

    # Compile only: never acquire a device.
    return utils.set_ipu_connection_type(
        ipu_options,
        utils.DeviceConnectionType.NEVER,
        ipu_version=2,
        enable_remote_buffers=True)
Exemplo n.º 6
0
def get_config(prng=False,
               ipu_id=-1,
               shards=1,
               number_of_replicas=1,
               max_cross_replica_buffer_size=10 * 1024 * 1024,
               merge_infeed_io_copies=True,
               fp_exceptions=True,
               xla_recompute=False,
               seed=None,
               profile=False,
               availableMemoryProportion=None):
    """Builds ipu_options.

    Args:
        prng: Enables ``prng.enable`` and the ``esr`` floating-point flag.
        ipu_id: ``-1`` auto-selects ``number_of_replicas * shards`` IPUs;
            any other value selects that single IPU id.
        shards: Shards per replica (used only for auto-selection).
        number_of_replicas: Replica count (used only for auto-selection).
        max_cross_replica_buffer_size: Forwarded to ``create_ipu_config`` as
            ``max_cross_replica_sum_buffer_size``.
        merge_infeed_io_copies: Forwarded to ``create_ipu_config``.
        fp_exceptions: Value for the ``inv``, ``div0`` and ``oflo`` flags
            (``nanoo`` is always True).
        xla_recompute: If True, enables ``allow_recompute``.
        seed: Only whether it is None matters here: a non-None seed sets
            ``target.deterministicWorkers`` to ``"true"``.
        profile: Passed as both ``profiling`` and ``profile_execution``.
        availableMemoryProportion: If not None, set (as a string) on the
            convolution options.

    Returns:
        The configured IPU options object.
    """
    config = utils.create_ipu_config(
        max_cross_replica_sum_buffer_size=max_cross_replica_buffer_size,
        merge_infeed_io_copies=merge_infeed_io_copies,
        always_rearrange_copies_on_the_host=False,
        profiling=profile,
        profile_execution=profile)
    if ipu_id == -1:
        config = utils.auto_select_ipus(config, number_of_replicas * shards)
    else:
        config = utils.select_ipus(config, [ipu_id])
    config = utils.set_compilation_options(
        config, {
            "device.clearAtomicFlagAfterExchange": "false",
            "prng.enable": "true" if prng else "false",
            "target.deterministicWorkers": "false" if seed is None else "true",
        })

    if availableMemoryProportion is not None:
        config = utils.set_convolution_options(
            config,
            {"availableMemoryProportion": str(availableMemoryProportion)})

    if xla_recompute:
        # Keep the returned options, as with every other utils.set_* call.
        config = utils.set_recomputation_options(config, allow_recompute=True)

    config = utils.set_floating_point_behaviour_options(config,
                                                        inv=fp_exceptions,
                                                        div0=fp_exceptions,
                                                        oflo=fp_exceptions,
                                                        esr=prng,
                                                        nanoo=True)

    return config
def training_graph(opts, training_data, device_index=0, learning_rate=0.001):
    """Build the IPU training graph together with its session and writer.

    The compiled loop runs ``opts.batches_per_step`` iterations of
    ``graph_builder`` fed from an IPU infeed, accumulating loss and RMSE,
    which are then averaged over the step.

    Args:
        opts: Options object; this function reads ``batches_per_step``,
            ``logs_path``, ``fp_exceptions`` and ``prng``.
        training_data: Object whose ``get_dataset(opts, is_training=True)``
            returns ``(dataset, _, placeholders)``.
        device_index: Suffix used in the infeed name and the log directory.
        learning_rate: Value written to the ``learning_rate`` summary.

    Returns:
        ``GraphOps`` built from the graph, session, initializer op,
        ``[loss, summary, rmse]`` fetches, placeholders, infeed, saver and
        summary writer.
    """
    graph = tf.Graph()

    with graph.as_default():

        dataset, _, placeholders = training_data.get_dataset(opts,
                                                             is_training=True)
        infeed_name = "training_dataset_infeed{0}".format(device_index)
        infeed = ipu_infeed_queue.IPUInfeedQueue(dataset, infeed_name, 0)

        with ipu_scope('/device:IPU:0'):

            def comp_fn():
                def body(loss_acc, rmse_acc, *batch):
                    step_loss, step_rmse, apply_grads = graph_builder(
                        opts,
                        observed_ratings=batch[0],
                        learning_rate=placeholders["learning_rate"])
                    # Build the accumulator adds under a control dependency on
                    # the update op so it is executed every iteration.
                    with tf.control_dependencies([apply_grads]):
                        return loss_acc + step_loss, rmse_acc + step_rmse

                initial = [tf.constant(0, tf.float32),
                           tf.constant(0, tf.float32)]
                return loops.repeat(opts.batches_per_step, body, initial,
                                    infeed)

            total_loss, sum_rmse_metric = ipu_compiler.compile(comp_fn, [])

        steps = opts.batches_per_step
        rmse = sum_rmse_metric / steps
        loss = total_loss / steps

        # NOTE(review): the summary logs the `learning_rate` argument, while
        # the model reads placeholders["learning_rate"] — confirm they agree.
        tf.summary.scalar("loss", loss)
        tf.summary.scalar("learning_rate", learning_rate)
        tf.summary.scalar("RMSE/train", rmse)

        train_summary = tf.summary.merge_all()
        train_saver = tf.train.Saver()

        ipu_utils.move_variable_initialization_to_cpu()
        train_init = tf.global_variables_initializer()

    log_dir = opts.logs_path + '/train{0}'.format(device_index)
    train_writer = tf.summary.FileWriter(log_dir,
                                         graph=graph,
                                         flush_secs=30)

    fp_flag = opts.fp_exceptions
    ipu_options = ipu_utils.create_ipu_config(profiling=False)
    ipu_options = ipu_utils.set_floating_point_behaviour_options(
        ipu_options,
        inv=fp_flag,
        div0=fp_flag,
        oflo=fp_flag,
        esr=opts.prng,
        nanoo=True)
    ipu_options = ipu_utils.auto_select_ipus(ipu_options, 1)
    ipu_utils.configure_ipu_system(ipu_options)

    train_sess = tf.Session(graph=graph)

    fetches = [loss, train_summary, rmse]
    return GraphOps(graph, train_sess, train_init, fetches, placeholders,
                    infeed, train_saver, train_writer)
Exemplo n.º 8
0
def get_config(prng=False,
               ipu_id=-1,
               shards=1,
               number_of_replicas=1,
               max_cross_replica_buffer_size=10 * 1024 * 1024,
               merge_infeed_io_copies=True,
               fp_exceptions=True,
               half_partials=False,
               conv_dithering=False,
               xla_recompute=False,
               seed=None,
               profile=None,
               availableMemoryProportion=None,
               stable_norm=False,
               internalExchangeOptimisationTarget=None,
               limitVertexState=None):
    """Builds ipu_options.

    Args:
        prng: Enables ``prng.enable`` and the ``esr`` floating-point flag.
        ipu_id: ``-1`` auto-selects ``number_of_replicas * shards`` IPUs;
            any other value selects that single IPU id.
        shards: Shards per replica (used only for auto-selection).
        number_of_replicas: Replica count (used only for auto-selection).
        max_cross_replica_buffer_size: Passed to ``set_optimization_options``
            as ``max_cross_replica_sum_buffer_size``.
        merge_infeed_io_copies: Forwarded to ``create_ipu_config``.
        fp_exceptions: Value for the ``inv``, ``div0`` and ``oflo`` flags
            (``nanoo`` is always True).
        half_partials: If True, sets ``partialsType`` to ``'half'`` for
            convolutions and matmuls.
        conv_dithering: If True, sets ``enableConvDithering``.
        xla_recompute: If True, enables ``allow_recompute``.
        seed: Only whether it is None matters here: a non-None seed sets
            ``target.deterministicWorkers`` to ``"portable"``.
        profile: None disables profiling; otherwise one of ``"NO_PROFILE"``,
            ``"TILE_PROFILE"``, ``"DEVICE_PROFILE"``, ``"IPU_PROFILE"``.
        availableMemoryProportion: If not None, set (as a string) on the
            convolution options.
        stable_norm: If True, enables stable norm statistics.
        internalExchangeOptimisationTarget: If not None, passed through as
            ``opt.internalExchangeOptimisationTarget``.
        limitVertexState: If not None, sets
            ``opt.limitVertexStateToLower256K`` to ``"true"``/``"false"``.

    Returns:
        The configured IPU options object.

    Raises:
        KeyError: If ``profile`` is truthy but not one of the keys above.
    """

    profile_exec_modes = {
        "NO_PROFILE": ExecutionProfileType.NO_PROFILE,
        "TILE_PROFILE": ExecutionProfileType.TILE_PROFILE,
        "DEVICE_PROFILE": ExecutionProfileType.DEVICE_PROFILE,
        "IPU_PROFILE": ExecutionProfileType.IPU_PROFILE
    }

    config = utils.create_ipu_config(
        merge_infeed_io_copies=merge_infeed_io_copies,
        always_rearrange_copies_on_the_host=False,
        profiling=profile is not None,
        profile_execution=profile_exec_modes[profile] if profile else None)

    config = utils.set_optimization_options(
        config,
        max_cross_replica_sum_buffer_size=max_cross_replica_buffer_size)

    if ipu_id == -1:
        config = utils.auto_select_ipus(config, number_of_replicas * shards)
    else:
        config = utils.select_ipus(config, [ipu_id])
    config = utils.set_compilation_options(
        config, {
            "device.clearAtomicFlagAfterExchange": "false",
            "prng.enable": "true" if prng else "false",
            "target.deterministicWorkers":
            "false" if seed is None else "portable",
        })

    if internalExchangeOptimisationTarget is not None:
        # Keep the returned options, as with every other utils.set_* call.
        config = utils.set_compilation_options(
            config, {
                "opt.internalExchangeOptimisationTarget":
                internalExchangeOptimisationTarget
            })

    if limitVertexState is not None:
        config = utils.set_compilation_options(
            config, {
                "opt.limitVertexStateToLower256K":
                "true" if limitVertexState else "false"
            })

    if availableMemoryProportion is not None:
        config = utils.set_convolution_options(
            config,
            {"availableMemoryProportion": str(availableMemoryProportion)})

    if half_partials:
        config = utils.set_convolution_options(config,
                                               {"partialsType": 'half'})
        config = utils.set_matmul_options(config, {"partialsType": 'half'})

    if conv_dithering:
        config = utils.set_convolution_options(config,
                                               {"enableConvDithering": "true"})

    if stable_norm:
        config = utils.set_norm_options(config, use_stable_statistics=True)

    if xla_recompute:
        config = utils.set_recomputation_options(config, allow_recompute=True)

    config = utils.set_floating_point_behaviour_options(config,
                                                        inv=fp_exceptions,
                                                        div0=fp_exceptions,
                                                        oflo=fp_exceptions,
                                                        esr=prng,
                                                        nanoo=True)

    return config
Exemplo n.º 9
0
def get_config(prng=False,
               ipu_id=-1,
               shards=1,
               number_of_replicas=1,
               max_cross_replica_buffer_size=10*1024*1024,
               merge_infeed_io_copies=True,
               fp_exceptions=True,
               half_partials=False,
               conv_dithering=False,
               xla_recompute=False,
               seed=None,
               profile=None,
               availableMemoryProportion=None,
               stable_norm=False,
               internalExchangeOptimisationTarget=None):
    """Builds ipu_options.

    Args:
        prng: Enables ``prng.enable`` and the ``esr`` floating-point flag.
        ipu_id: ``-1`` auto-selects ``number_of_replicas * shards`` IPUs;
            any other value selects that single IPU id.
        shards: Shards per replica (used only for auto-selection).
        number_of_replicas: Replica count (used only for auto-selection).
        max_cross_replica_buffer_size: Passed to ``set_optimization_options``
            as ``max_cross_replica_sum_buffer_size``.
        merge_infeed_io_copies: Forwarded to ``create_ipu_config``.
        fp_exceptions: Value for the ``inv``, ``div0`` and ``oflo`` flags
            (``nanoo`` is always True).
        half_partials: If True, sets ``partialsType`` to ``'half'`` for
            convolutions and matmuls.
        conv_dithering: If True, sets ``enableConvDithering``.
        xla_recompute: If True, enables ``allow_recompute``.
        seed: Only whether it is None matters here: a non-None seed sets
            ``target.deterministicWorkers`` to ``"portable"``.
        profile: None disables profiling; otherwise one of ``"NO_PROFILE"``,
            ``"TILE_PROFILE"``, ``"DEVICE_PROFILE"``, ``"IPU_PROFILE"``.
        availableMemoryProportion: If not None, set (as a string) on the
            convolution options.
        stable_norm: If True, enables stable norm statistics.
        internalExchangeOptimisationTarget: If not None, passed through as
            ``opt.internalExchangeOptimisationTarget``.

    Returns:
        The configured IPU options object.

    Raises:
        KeyError: If ``profile`` is truthy but not one of the keys above, or
            if ``GCL_REAL_COLLECTIVES`` is set without ``GCL_NUM_IO_TILES``.
        ValueError: If ``GCL_NUM_IO_TILES`` is odd or outside [32, 192].
    """

    profile_exec_modes = {"NO_PROFILE": ExecutionProfileType.NO_PROFILE,
                          "TILE_PROFILE": ExecutionProfileType.TILE_PROFILE,
                          "DEVICE_PROFILE": ExecutionProfileType.DEVICE_PROFILE,
                          "IPU_PROFILE": ExecutionProfileType.IPU_PROFILE}

    config = utils.create_ipu_config(merge_infeed_io_copies=merge_infeed_io_copies,
                                     always_rearrange_copies_on_the_host=False,
                                     profiling=profile is not None,
                                     profile_execution=profile_exec_modes[profile] if profile else None)

    config = utils.set_optimization_options(config,
                                            max_cross_replica_sum_buffer_size=max_cross_replica_buffer_size)

    if "GCL_REAL_COLLECTIVES" in os.environ:
        # The GCL_NUM_IO_TILES environment variable sets how many tiles in the
        # IPU are reserved for Graphcore Communication Library (GCL)
        # collectives.
        iotiles = int(os.environ['GCL_NUM_IO_TILES'])
        if iotiles % 2 or iotiles < 32 or iotiles > 192:
            # Report the rejected value; the old message called .format() on a
            # string with no placeholder, so the value was silently dropped.
            raise ValueError(
                f"GCL IO Tiles must be a multiple of 2 in between 32 and 192, got {iotiles}.")

        config = utils.set_gcl_options(config, num_io_tiles=iotiles, gcl_options={
                                       "useGclCollectives": "true", })

    if ipu_id == -1:
        config = utils.auto_select_ipus(config, number_of_replicas*shards)
    else:
        config = utils.select_ipus(config, [ipu_id])
    config = utils.set_compilation_options(config, {
        "device.clearAtomicFlagAfterExchange": "false",
        "prng.enable": "true" if prng else "false",
        "target.deterministicWorkers": "false" if seed is None else "portable",
    })

    if internalExchangeOptimisationTarget is not None:
        # Keep the returned options, as with every other utils.set_* call.
        config = utils.set_compilation_options(config, {
            "opt.internalExchangeOptimisationTarget": internalExchangeOptimisationTarget
        })

    if availableMemoryProportion is not None:
        config = utils.set_convolution_options(config, {
            "availableMemoryProportion": str(availableMemoryProportion)
        })

    if half_partials:
        config = utils.set_convolution_options(config, {
            "partialsType": 'half'
        })
        config = utils.set_matmul_options(config, {
            "partialsType": 'half'
        })

    if conv_dithering:
        config = utils.set_convolution_options(config, {
            "enableConvDithering": "true"
        })

    if stable_norm:
        config = utils.set_norm_options(config, use_stable_statistics=True)

    if xla_recompute:
        config = utils.set_recomputation_options(config, allow_recompute=True)

    config = utils.set_floating_point_behaviour_options(config, inv=fp_exceptions, div0=fp_exceptions,
                                                        oflo=fp_exceptions, esr=prng, nanoo=True)

    return config
Exemplo n.º 10
0
def run_mnist(opts):
    if opts.pipelining and opts.gradient_accumulation_count < 4:
        raise ValueError(
            "Pipelining requires at least 4 gradient accumulation steps.")
    if opts.seed is not None:
        utils.reset_ipu_seed(opts.seed)
    random_gen = np.random.default_rng(seed=opts.seed)

    # Use Keras to get the dataset:
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Sizes/shapes for the dataset:
    image_shape = x_train.shape[1:]
    num_pixels = image_shape[0] * image_shape[1]
    batch_size = opts.batch_size // opts.gradient_accumulation_count
    batch_shape = [batch_size, num_pixels]
    num_train = y_train.shape[0]
    num_test = y_test.shape[0]
    dtype = tf.float16 if opts.data_type == 'fp16' else tf.float32

    # Flatten the images and cast the labels:
    permutation = make_pixel_permutation_matrix(opts, image_shape)

    x_train_flat = x_train.astype(dtype.as_numpy_dtype()).reshape(
        -1, num_pixels)
    x_test_flat = x_test.astype(dtype.as_numpy_dtype()).reshape(-1, num_pixels)

    x_train_flat[:, ...] = x_train_flat[:, permutation]
    x_test_flat[:, ...] = x_test_flat[:, permutation]

    if opts.records_path:
        os.makedirs(opts.records_path, exist_ok=True)
        filename = os.path.join(opts.records_path, "pixel_permutation")
        np.save(filename, permutation)

    y_train = y_train.astype(np.int32)
    y_test = y_test.astype(np.int32)

    # Decide how to split epochs into loops up front:
    if opts.pipelining:
        logger.info(
            f"Pipelined: micro-batch-size: {batch_size} accumulation-count: {opts.gradient_accumulation_count}"
        )
    batches_per_epoch = num_train // (batch_size *
                                      opts.gradient_accumulation_count)
    test_batches = num_test // (batch_size * opts.gradient_accumulation_count)

    batches_per_step = opts.batches_per_step_override
    if batches_per_step is None:
        batches_per_step = batches_per_epoch // opts.steps_per_epoch

    if not (batches_per_epoch % opts.steps_per_epoch) == 0:
        raise ValueError(
            f"IPU steps per epoch {opts.steps_per_epoch} must divide batches per epoch {batches_per_epoch} exactly."
        )

    # Create FC layer descriptions:
    fc_layers = create_fc_layers(opts, batch_shape, random_gen)
    for name, fc in fc_layers.items():
        logger.info(f"Layer Config: {name}: {type(fc)}")

    # Put placeholders on the CPU host:
    with tf.device("cpu"):
        lr_placeholder = tf.placeholder(dtype, shape=[])

    # Create dataset and IPU feeds:
    def make_generator(features, labels):
        return lambda: zip(features, labels)

    # Input pipeline
    def make_dataset(features, labels, is_training: bool):
        dataset = tf.data.Dataset.from_generator(
            generator=make_generator(features, labels),
            output_types=(features.dtype, labels.dtype),
            output_shapes=(features.shape[1:], labels.shape[1:]))

        if is_training:
            dataset = dataset.shuffle(buffer_size=num_train,
                                      seed=opts.seed).cache()

        dataset = dataset.repeat().batch(batch_size, drop_remainder=True)
        return dataset

    train_dataset = make_dataset(features=x_train_flat,
                                 labels=y_train,
                                 is_training=True)

    test_dataset = make_dataset(features=x_test_flat,
                                labels=y_test,
                                is_training=False)

    infeed_train_queue = ipu_infeed_queue.IPUInfeedQueue(
        train_dataset, feed_name="train_infeed")
    outfeed_train_queue = ipu_outfeed_queue.IPUOutfeedQueue(
        feed_name="train_outfeed")
    outfeed_prune_and_grow_queue = ipu_outfeed_queue.IPUOutfeedQueue(
        feed_name="train_prune_and_grow_outfeed")
    infeed_test_queue = ipu_infeed_queue.IPUInfeedQueue(
        test_dataset, feed_name="test_infeed")
    outfeed_test_queue = ipu_outfeed_queue.IPUOutfeedQueue(
        feed_name="test_outfeed")

    # Get optimiser
    opt_cls, opt_kws = build_optimizer(opts.optimizer, opts.optimizer_arg)
    logger.info('Optimiser %s, optimiser keywords %s', opt_cls.__name__,
                opt_kws)

    # Get the bound model functions
    bound_model_fn = make_bound_model_pipelining if opts.pipelining else make_bound_model
    (bound_train_loop, bound_test_loop), train_inputs = bound_model_fn(
        fc_layers=fc_layers,
        opts=opts,
        lr_placeholder=lr_placeholder,
        opt_cls=opt_cls,
        opt_kws=opt_kws,
        train_batches_per_step=batches_per_step,
        test_batches_per_step=test_batches,
        train_queues=(outfeed_train_queue, infeed_train_queue),
        test_queues=(outfeed_test_queue, infeed_test_queue),
        png_queue=outfeed_prune_and_grow_queue,
        disable_dense_grad=opts.disable_dense_grad_override)

    # Use the bound builder functions to place the model on the IPU:
    with scopes.ipu_scope("/device:IPU:0"):
        train_loop = ipu_compiler.compile(bound_train_loop,
                                          inputs=train_inputs)
        test_loop = ipu_compiler.compile(bound_test_loop)

    # Placeholders can only be created on cpu after all the slots have registered:
    with tf.device("cpu"):
        for fc in fc_layers.values():
            fc.create_placeholders()

    # Create update op on IPU:
    with scopes.ipu_scope("/device:IPU:0"):
        update_representation = build_update_op(fc_layers)

    # Initialisers should go on the CPU:
    with tf.device("cpu"):
        metrics_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                         scope="metrics")
        metrics_initializer = tf.variables_initializer(var_list=metrics_vars)
        saver = tf.train.Saver()

    # Setup and acquire an IPU device:
    utils.move_variable_initialization_to_cpu()
    config = utils.create_ipu_config()
    config = utils.auto_select_ipus(config, 1)
    config = utils.set_floating_point_behaviour_options(config,
                                                        inv=False,
                                                        div0=False,
                                                        oflo=False,
                                                        esr=True,
                                                        nanoo=False)
    utils.configure_ipu_system(config)

    # These allow us to retrieve the results of IPU feeds:
    dequeue_test_outfeed = outfeed_test_queue.dequeue()
    dequeue_train_outfeed = outfeed_train_queue.dequeue()

    # Add dense gradient outfeed if we have sparse layers
    dequeue_prune_and_grow_outfeed = None
    if not opts.disable_dense_grad_override and any(
            fc.is_sparse() for fc in fc_layers.values()):
        dequeue_prune_and_grow_outfeed = outfeed_prune_and_grow_queue.dequeue()

    logger.info(
        f"Image shape: {image_shape} Training examples: {num_train} Test examples: {num_test}"
    )
    logger.info(
        f"Epochs: {opts.epochs} Batch-size: {batch_size} Steps-per-epoch: {opts.steps_per_epoch} Batches-per-step: {batches_per_step}"
    )
    total_steps = opts.steps_per_epoch * opts.epochs
    logger.info(f"Total steps: {total_steps}")

    if opts.log:
        # Open log and write header fields:
        log_file = open(opts.log, 'w')
        d1, d2 = opts.densities
        log_file.write(f"Iteration Density_{d1}_{d2}\n")

    if opts.restore:
        logpath = os.path.join(opts.checkpoint_path, opts.restore)
    else:
        logpath = os.path.join(opts.checkpoint_path,
                               datetime.now().strftime("%Y%m%d-%H%M%S"))
    summary_writer = tf.summary.FileWriter(logpath)

    if opts.records_path:
        # Save the first hidden layer's weight mask for later analysis:
        save_weights(opts, 'fc1', fc_layers['fc1'], 0)

    # Run the model:
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(infeed_train_queue.initializer)

        if opts.restore:
            saver.restore(sess, logpath + '/model.ckpt')

        if opts.test_mode in ["all", "training"]:
            logger.info(f"Training...")
            start = opts.start_epoch if opts.restore else 0
            progress = tqdm(
                range(start, opts.epochs),
                bar_format='{desc} Epoch: {n_fmt}/{total_fmt} {bar}')
            for e in progress:
                for i in range(opts.steps_per_epoch):
                    sess.run(metrics_initializer)

                    t1 = time.perf_counter()
                    sess.run(train_loop,
                             feed_dict={lr_placeholder: scheduler(e, opts)})
                    t2 = time.perf_counter()
                    sess_time = t2 - t1
                    batch_time = sess_time / batches_per_step
                    throughput = batch_size / batch_time
                    logger.info(f"Time for sess.run: {sess_time:0.3f} "
                                f"Time per batch: {batch_time:0.6f} "
                                f"Throughput: {throughput}")

                    if opts.single_train_step_only:
                        return

                    train_outputs = sess.run(dequeue_train_outfeed)
                    if opts.pipelining:
                        train_outputs = train_outputs[-1]

                    # Get the last value for all items:
                    for k, v in train_outputs.items():
                        train_outputs[k] = v[-1]
                    logger.debug(f"Train outputs: {train_outputs.keys()}")

                    # Merge prune and grow fetches with last fetches:
                    if dequeue_prune_and_grow_outfeed is not None:
                        png_data = sess.run(dequeue_prune_and_grow_outfeed)
                        for k in png_data:
                            png_data[k] = png_data[k][-1]
                        logger.debug(
                            f"Prune and grow outputs: {png_data.keys()}")

                    steps = 1 + i + e * opts.steps_per_epoch
                    batches_processed = batches_per_step * steps
                    for name, fc in fc_layers.items():
                        if fc.is_sparse():
                            var_name = fc.get_values_var().name
                            logger.info(
                                f"Average weights for layer {name}: {np.mean(png_data[var_name])}"
                            )
                            for slot_name in fc.sparse_slots:
                                logger.info(
                                    f"Average {slot_name} for layer {name} : {np.mean(png_data[slot_name])}"
                                )
                            if i == 0 and e == opts.start_epoch:
                                metainfo = sess.run(fc.get_metainfo_var())
                            else:
                                metainfo = None
                            if not opts.disable_pruning:
                                logger.info(
                                    f"Starting prune and grow for layer {name}"
                                )
                                t0 = time.perf_counter()
                                prune_sched = prune_and_grow(name,
                                                             fc,
                                                             png_data,
                                                             random_gen,
                                                             steps,
                                                             total_steps,
                                                             opts,
                                                             metainfo=metainfo)
                                t1 = time.perf_counter()
                                logger.info(
                                    f"Prune and grow for layer {name} complete in {t1-t0:0.3f} seconds"
                                )
                                logger.info(
                                    f"Pruned proportion: {prune_sched}")
                                if opts.use_wandb:
                                    wandb.log({'Prune Schedule': prune_sched},
                                              commit=False)

                    if opts.log:
                        log_file.write(
                            f"{batches_processed} {train_outputs['acc']}\n")
                    if opts.use_wandb:
                        wandb.log(
                            {
                                'Loss': train_outputs['mean_loss'],
                                'Accuracy': train_outputs['acc'],
                                'Throughput': throughput
                            },
                            commit=True)
                    progress.set_description(
                        f"Loss {train_outputs['mean_loss']:.5f} Accuracy {train_outputs['acc']:.5f}"
                    )

                    # Only need to feed an updated sparsity representation if we are running rig-L:
                    if not opts.disable_pruning:
                        # Merge the feeds needed for all layers:
                        sparse_feed = {}
                        for fc in fc_layers.values():
                            if fc.is_sparse():
                                sparse_feed.update(fc.feed_dict())
                        sess.run(update_representation, feed_dict=sparse_feed)

                if e % opts.checkpoint_freq == 0:
                    logger.info(f"Saving...")
                    saver.save(sess, os.path.join(logpath, 'model.ckpt'))

        if opts.test_mode in ["all", "tests"]:
            logger.info(f"Testing...")
            sess.run(metrics_initializer)
            sess.run(infeed_test_queue.initializer)
            sess.run(test_loop)
            result = sess.run(dequeue_test_outfeed)

            test_loss = result['mean_loss'][-1]
            test_acc = result['acc'][-1]
            logger.info(
                f"Test loss: {test_loss:.8f} Test accuracy: {test_acc:.8f} Name: {opts.log}"
            )
            if opts.use_wandb:
                wandb.run.summary["Test Loss"] = test_loss
                wandb.run.summary["Test Accuracy"] = test_acc
                wandb.log(commit=True)