Example #1
# Imports assumed by this snippet (Graphcore IPU extensions to TensorFlow 1.x):
import tensorflow as tf
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.ipu import autoshard, scopes, sharding

def body(img):
     with scopes.ipu_scope('/device:IPU:0'):
         if mode == 'sharded':
             with autoshard.ipu_autoshard():
                 probs = tf.import_graph_def(
                     network.optimized_graph,
                     input_map={network.graph_input: img},
                     name="optimized",
                     return_elements=[network.graph_output])[0]
             autoshard.automatic_sharding(num_shards=num_ipus,
                                          input_ts=img,
                                          loss_ts=probs,
                                          frozen_inference=True)
             outfeed_op = outfeed_queue.enqueue(probs)
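            # Copy the XLA sharding attribute from probs onto the enqueue
            # op so that the outfeed is assigned to the same shard.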
             outfeed_op._set_attr(
                 sharding._XLA_SHARDING,
                 attr_value_pb2.AttrValue(
                     s=probs.op.get_attr('_XlaSharding')))
         else:
             probs = tf.import_graph_def(
                 network.optimized_graph,
                 input_map={network.graph_input: img},
                 name="optimized",
                 return_elements=[network.graph_output])[0]
             outfeed_op = outfeed_queue.enqueue(probs)
         # Note that enqueue happens on the IPU.
         return outfeed_op
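Note that the snippet above only defines the loop body. A minimal driver sketch for it, assuming the Graphcore TensorFlow 1.x API (ipu_compiler, loops, ipu_infeed_queue); dataset, NUM_ITERATIONS and the feed name are placeholders, and outfeed_queue is the queue the body enqueues to:

from tensorflow.python.ipu import ipu_compiler, ipu_infeed_queue, loops

# Feed images to the device; dataset is a placeholder tf.data.Dataset.
infeed_queue = ipu_infeed_queue.IPUInfeedQueue(dataset, feed_name="infeed")

def my_net():
    # Run body() NUM_ITERATIONS times, pulling each image from the infeed.
    return loops.repeat(NUM_ITERATIONS, body, [], infeed_queue)

with scopes.ipu_scope('/device:IPU:0'):
    run_inference = ipu_compiler.compile(my_net, inputs=[])

dequeue_outfeed = outfeed_queue.dequeue()
with tf.Session() as sess:
    sess.run(infeed_queue.initializer)
    sess.run(run_inference)
    all_probs = sess.run(dequeue_outfeed)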
Example #2
# automatic_sharding here is presumably tensorflow.python.ipu.autoshard's;
# the other helpers are defined elsewhere in this project's train.py.
def basic_training_step(image, label, model, opts, learning_rate):
    """
    A basic training step that will work on all hardware
    """
    if opts['no_hostside_norm']:
        image = imagenet_dataset.accelerator_side_preprocessing(image,
                                                                opts=opts)

    logits = model(opts, training=True, image=image)
    loss, cross_entropy, accuracy = calculate_loss(logits, label, opts)

    learning_rate, train_op = calculate_and_apply_gradients(
        loss, opts, learning_rate=learning_rate)
    if opts['shards'] > 1:

        def edge_filter(edge):
            # Reject an edge as a split candidate if any of its op names
            # matches the exclude filter, or if none matches the include
            # filter.
            return (any(f in e for e in edge
                        for f in opts["sharding_exclude_filter"])
                    or not any(f in e for e in edge
                               for f in opts["sharding_include_filter"]))

        automatic_sharding(opts['shards'],
                           image,
                           cross_entropy,
                           edge_filter=edge_filter)

    return loss, cross_entropy, accuracy, learning_rate, train_op
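To see what the predicate above rejects, here is a quick host-side check; the op names and filter values are made up, not taken from the project's configuration:

opts = {"sharding_exclude_filter": ["gradients"],
        "sharding_include_filter": ["conv"]}

def edge_filter(edge):
    return (any(f in e for e in edge
                for f in opts["sharding_exclude_filter"])
            or not any(f in e for e in edge
                       for f in opts["sharding_include_filter"]))

print(edge_filter(("net/conv1/Conv2D", "net/conv2/Conv2D")))    # False: may be split
print(edge_filter(("net/conv1/Conv2D", "gradients/net/conv1"))) # True: excluded
print(edge_filter(("net/relu", "net/pool")))                    # True: not included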
Example #3
def my_graph(pa, pb, pc):
    # auto_sharding() and manual_sharding() are graph-building helpers,
    # presumably defined elsewhere in the same file.
    if opts.autoshard:
        result = auto_sharding(pa, pb, pc)
        # The first argument to automatic_sharding is the number of
        # shards.  The second argument is the tensor closest to the
        # input data source in the graph; in this case it could be
        # pa, pb or pc.  The third argument is the tensor closest to
        # the loss of the graph; there is no loss function here, so
        # the graph's output is the closest thing to it.  Given these
        # two extremities, the automatic sharding mechanism can
        # calculate which edges it can split across.
        autoshard.automatic_sharding(NUM_SHARDS, pa, result)
    else:
        result = manual_sharding(pa, pb, pc)
    return result
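For contrast, a minimal sketch of what the manual_sharding() helper could look like, assuming scopes is tensorflow.python.ipu.scopes; each ipu_shard() scope pins the ops created inside it to one shard, and the split chosen here is only illustrative:

def manual_sharding(pa, pb, pc):
    with scopes.ipu_shard(0):
        o1 = pa + pb
    with scopes.ipu_shard(1):
        o2 = pa + pc
        result = o1 + o2
    return result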
Example #4
File: train.py Project: hubayirp/examples
def basic_training_step(image, label, model, opts, learning_rate):
    """
    A basic training step that will work on all hardware
    """
    logits = model(opts, training=True, image=image)
    loss, cross_entropy, accuracy = calculate_loss(logits, label, opts)

    learning_rate, scaled_learning_rate, train_op = calculate_and_apply_gradients(
        loss, opts, learning_rate=learning_rate)
    if opts['shards'] > 1:

        def edge_filter(edge):
            # Same predicate as in Example #2: reject edges matching the
            # exclude filter or missing the include filter.
            return (any(f in e for e in edge
                        for f in opts["sharding_exclude_filter"])
                    or not any(f in e for e in edge
                               for f in opts["sharding_include_filter"]))

        automatic_sharding(opts['shards'],
                           image,
                           cross_entropy,
                           edge_filter=edge_filter)

    # The loss is computed with loss scaling applied (this guards against
    # FP16 underflow in the gradients), so divide the scaling factor back
    # out before reporting it.
    return (loss / opts["loss_scaling"], cross_entropy, accuracy,
            learning_rate, scaled_learning_rate, train_op)