예제 #1
0
def set_popdist_args(args):
    """Fill in the popdist-related attributes of ``args`` in place.

    If ``popdist.isPopdistEnvSet()`` is false, ``args.use_popdist`` is set to
    False, ``args.popdist_size`` to 1 and ``args.popdist_rank`` to 0, and the
    function returns without touching anything else.

    Otherwise the PopART horovod extension is imported and initialised, and:

    * ``args.use_popdist`` is set to True;
    * ``args.replication_factor`` is overwritten with the local replica count
      reported by popdist (a warning is logged if a different value > 1 was
      requested);
    * ``args.popdist_size`` becomes total replicas // local replicas;
    * ``args.popdist_rank`` becomes replica index offset // local replicas;
    * ``args.checkpoint_dir`` gets a ``_rank_<popdist_rank>`` suffix;
    * ``setup_comm`` is called with ``MPI.COMM_WORLD``.

    Args:
        args: Mutable options object. Reads ``inference``,
            ``replication_factor`` and ``checkpoint_dir``.

    Raises:
        RuntimeError: If a popdist environment is set and ``args.inference``
            is truthy.
        ImportError: If ``horovod.popart`` cannot be imported; the original
            import error is attached as the cause.
    """
    if not popdist.isPopdistEnvSet():
        args.use_popdist = False
        args.popdist_size = 1
        args.popdist_rank = 0
        return

    if args.inference:
        raise RuntimeError("Distributed execution is only supported for training")

    try:
        import horovod.popart as hvd
        hvd.init()
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError("Could not find the PopART horovod extension. "
                          "Please install the horovod .whl provided in the Poplar SDK.") from err

    args.use_popdist = True
    popdist_local_factor = popdist.getNumLocalReplicas()
    if args.replication_factor > 1 and args.replication_factor != popdist_local_factor:
        logger.warning(f"Overwriting the local replication factor {args.replication_factor} to {popdist_local_factor}")
    args.replication_factor = popdist_local_factor

    # Reuse the local replica count fetched above instead of re-querying popdist.
    args.popdist_size = popdist.getNumTotalReplicas() // popdist_local_factor
    args.popdist_rank = popdist.getReplicaIndexOffset() // popdist_local_factor
    args.checkpoint_dir = args.checkpoint_dir + "_rank_" + str(args.popdist_rank)

    from mpi4py import MPI
    setup_comm(MPI.COMM_WORLD)
예제 #2
0
def try_import_horovod():
    """Import and initialise the PopART horovod extension.

    Returns:
        The ``horovod.popart`` module, after ``init()`` has been called on it.

    Raises:
        ImportError: If ``horovod.popart`` cannot be imported. The message
            points at the horovod wheel in the Poplar SDK and the original
            import error is attached as the cause.
    """
    try:
        import horovod.popart as hvd
        hvd.init()
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError("Could not find the PopART horovod extension. "
                          "Please install the horovod .whl provided in the Poplar SDK.") from err
    return hvd
예제 #3
0
def bert_distributed_training_session(args, **kwargs):
    """Create, compile and weight-synchronise a horovod distributed training session.

    The PopART horovod extension is imported and initialised, a
    ``hvd.DistributedTrainingSession`` is built from ``kwargs``, the graph is
    compiled via ``compile_graph_checked(args, session)``, and the weights are
    broadcast with ``hvd.broadcast_weights``.

    Args:
        args: Options object forwarded to ``compile_graph_checked``.
        **kwargs: Forwarded unchanged to ``hvd.DistributedTrainingSession``.

    Returns:
        The compiled distributed training session.

    Raises:
        ImportError: If ``horovod.popart`` cannot be imported; the original
            import error is attached as the cause.
    """
    try:
        import horovod.popart as hvd
        hvd.init()
    except ImportError as err:
        # Chain explicitly so the underlying import failure stays visible.
        raise ImportError(
            "Could not find the PopART horovod extension. "
            "Please install the horovod .whl provided in the Poplar SDK.") from err

    session = hvd.DistributedTrainingSession(**kwargs)
    logger.info("Compiling Training Graph")
    compile_graph_checked(args, session)

    logger.info("Broadcasting weights to all instances")
    hvd.broadcast_weights(session)

    return session
예제 #4
0
def train():
    """Run a short host-allreduce training loop on synthetic data.

    Builds the model via ``create_model``, creates a training session with
    ``hostAllReduce`` enabled, wraps the optimizer in a horovod
    ``DistributedOptimizer``, broadcasts the weights from rank 0, and then
    runs ten steps feeding the same randomly generated batch each time.
    """
    _builder, proto, data_in, labels_in, output, loss = create_model()

    batches_per_step = 32
    # Return every value of both anchors.
    anchors = {
        output: popart.AnchorReturnType("All"),
        loss: popart.AnchorReturnType("All"),
    }
    data_flow = popart.DataFlow(batches_per_step, anchors)

    session_opts = popart.SessionOptions()
    device = get_device()

    # Enable host side AllReduce operations in the graph
    session_opts.hostAllReduce = True
    training, optimizer = init_session(
        proto, loss, data_flow, session_opts, device)
    if session_opts.hostAllReduce:
        hvd.init()

        dist_opt = hvd.DistributedOptimizer(
            optimizer, training.session, session_opts)
        dist_opt.insert_host_allreduce()

        # Broadcast weights to all the other processes
        hvd.broadcast_weights(training.session, root_rank=0)

    training.session.weightsFromHost()

    # Synthetic data: random-normal inputs, all-zero labels
    input_shape = (batches_per_step, batch_size, 784)
    label_shape = (batches_per_step, batch_size, 1)
    synthetic_inputs = np.random.normal(size=input_shape).astype(np.float32)
    synthetic_labels = np.zeros(label_shape).astype(np.int32)

    total_steps = 10

    for _step in range(total_steps):
        feed = {data_in: synthetic_inputs, labels_in: synthetic_labels}
        step_io = popart.PyStepIO(feed, training.anchors)
        training.session.run(step_io)
예제 #5
0
def train(opts):
    """Train the MNIST model across Horovod processes; rank 0 also evaluates.

    Every process builds a training session with host-side allreduce enabled.
    At the start of each epoch the rank 0 weights are broadcast to all
    processes. After each epoch, rank 0 writes its weights to a temporary
    ONNX file, loads them into a validation session and prints the loss and
    accuracy over the test set. The temporary file is removed at the end.

    Args:
        opts: Options object. Reads ``syn_data_type``, ``test_mode``,
            ``batch_size``, ``batches_per_step`` (may be reduced in place),
            ``num_ipus``, ``pipeline``, ``simulation`` and ``epochs``.

    Raises:
        ValueError: If ``opts.num_ipus > 1`` without ``opts.pipeline``.
    """
    # Initialize the Horovod runtime
    hvd.init()

    # Do not require the mnist data to be present if running with synthetic data
    if opts.syn_data_type in ["random_normal", "zeros"]:
        train_data, train_labels, test_data, test_labels = load_dummy(opts)
    else:
        train_data, train_labels, test_data, test_labels = load_mnist()

    if not opts.test_mode:
        max_value = len(test_data) // opts.batch_size
        if max_value < opts.batches_per_step:
            print("(batches-per-step * batch-size) is larger than test set!\n"
                  " Reduced batches-per-step to: {}\n".format(max_value))
            opts.batches_per_step = max_value

    # Only shard the training set when synthetic data is off. Compare by
    # value: an identity test against a string literal is unreliable and
    # would leave sharding disabled for a runtime-built "off" string.
    shard_dataset = opts.syn_data_type == "off"
    training_set = DataSet(
        opts.batch_size, opts.batches_per_step, train_data, train_labels, shard=shard_dataset)

    if hvd.rank() == 0:
        test_set = DataSet(opts.batch_size, opts.batches_per_step,
                           test_data, test_labels, shard=False)

    print("Creating ONNX model.")
    proto, data_in, labels_in, output, loss = create_model(opts.batch_size)

    # Describe how to run the model
    anchor_desc = {output: popart.AnchorReturnType("ALL"),
                   loss: popart.AnchorReturnType("ALL")}
    dataFlow = popart.DataFlow(opts.batches_per_step, anchor_desc)

    # Options
    userOpts = popart.SessionOptions()

    # The validation graph by default will be optimized to change all variables to constants
    # This prevents that, which allows for checkpoints to be loaded into the model without recompiling
    userOpts.constantWeights = False

    # If requested, setup synthetic data
    if opts.syn_data_type in ["random_normal", "zeros"]:
        print(
            "Running with Synthetic Data Type '{}'".format(opts.syn_data_type)
        )
        if opts.syn_data_type == "random_normal":
            userOpts.syntheticDataMode = popart.SyntheticDataMode.RandomNormal
        elif opts.syn_data_type == "zeros":
            userOpts.syntheticDataMode = popart.SyntheticDataMode.Zeros

    # Enable auto-sharding
    if opts.num_ipus > 1:
        if not opts.pipeline:
            raise ValueError("Auto sharded graph only supported with pipelining")
        userOpts.virtualGraphMode = popart.VirtualGraphMode.Auto

    # Enable pipelining
    if opts.pipeline:
        userOpts.enablePipelining = True

    # Enable host side all reduce operations from the PopART program
    userOpts.hostAllReduce = True

    # A single device is shared between training and validation sessions
    device = get_device(opts.num_ipus, opts.simulation)

    rank = hvd.rank()
    training = init_session(proto, loss, dataFlow,
                            userOpts, device, training=True)
    if rank == 0:
        validation = init_session(
            proto, loss, dataFlow, userOpts, device, training=False)

    # Create the Horovod distributed optimizer and insert the allreduce
    # collective operation across the instances
    distributed_optimizer = hvd.DistributedOptimizer(
        training.optimizer, training.session, userOpts)
    distributed_optimizer.insert_host_allreduce()

    # Make weight transfer file. The rank 0 process will checkpoint its weights
    onnx_file_name = 'mnist_tmp.onnx'

    print("Running training loop.")
    for i in range(opts.epochs):
        # Broadcast the weights from the rank 0 process to the other processes
        hvd.broadcast_weights(training.session, root_rank=0)

        training.session.weightsFromHost()

        # Training
        for step, (data, labels) in enumerate(training_set):
            stepio = popart.PyStepIO(
                {data_in: data, labels_in: labels}, training.anchors)

            start = time()
            training.session.run(stepio, 'Epoch ' + str(i) + ' training step' + str(step))
            if opts.test_mode == "training":
                log_run_info(training, start, opts)

        training.session.weightsToHost()

        if rank == 0:
            aggregated_loss = 0
            aggregated_accuracy = 0
            # Hand the freshly trained weights to the validation session
            # through the temporary ONNX file.
            training.session.modelToHost(onnx_file_name)
            validation.session.resetHostWeights(onnx_file_name)
            validation.session.weightsFromHost()

            # Evaluation
            for step, (data, labels) in enumerate(test_set):
                stepio = popart.PyStepIO(
                    {data_in: data, labels_in: labels}, validation.anchors)
                start = time()
                validation.session.run(stepio, 'Epoch ' + str(i) + ' evaluation step ' + str(step))
                if opts.test_mode == "inference":
                    log_run_info(validation, start, opts)

                # Loss
                aggregated_loss += np.mean(validation.anchors[loss])
                # Accuracy
                results = np.argmax(validation.anchors[output].reshape(
                    [test_set.inputs_per_step, 10]), 1)
                num_correct = np.sum(results == labels.reshape(
                    [test_set.inputs_per_step]))
                aggregated_accuracy += num_correct / test_set.inputs_per_step

            # Log statistics
            aggregated_loss /= len(test_set)
            aggregated_accuracy /= len(test_set)
            print("Epoch #{}".format(i + 1))
            print("   Loss={0:.4f}".format(aggregated_loss))
            print("   Accuracy={0:.2f}%".format(aggregated_accuracy * 100))

    # Remove weight transfer file. It is only written inside the epoch loop,
    # so it does not exist when opts.epochs is 0.
    if rank == 0 and os.path.exists(onnx_file_name):
        os.remove(onnx_file_name)