import os
from multiprocessing.pool import ThreadPool

import numpy as np

import nnabla as nn
import nnabla.communicators as C
import nnabla.solvers as S
from nnabla.contrib.context import extension_context

# Helper functions used below (get_args, data_iterator_cifar10, categorical_error,
# and the cifar10_resnet* graph builders) are defined elsewhere in the example;
# a sketch of forward_backward is given after train().


def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct computation graphs: one training graph per device and one for validation.
    * Initialize a solver per device and set that device's parameter variables to it.
    * Instantiate a communicator and set parameter variables.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate on validation data (periodically).
      * Get the next minibatch.
      * Execute forward propagation on each device.
      * Set parameter gradients to zero.
      * Execute backprop on each device.
      * Allreduce gradients over devices (THIS IS THE MAIN difference from single-device training).
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size

    # Create contexts
    extension_module = args.context
    if extension_module not in ("cuda", "cuda.cudnn"):
        raise ValueError("Use the `cuda` or `cuda.cudnn` extension_module.")
    n_devices = args.n_devices
    ctxs = []
    for i in range(n_devices):
        ctx = extension_context(extension_module, device_id=i)
        ctxs.append(ctx)
    ctx = ctxs[-1]

    # Create training graphs
    input_image_train = []
    preds_train = []
    losses_train = []
    test = False
    for i in range(n_devices):
        image = nn.Variable((args.batch_size, 3, 32, 32))
        label = nn.Variable((args.batch_size, 1))
        device_scope_name = "device{}".format(i)

        pred = cifar10_resnet23_prediction(image, ctxs[i], device_scope_name,
                                           test)
        loss = cifar10_resnet32_loss(pred, label)

        input_image_train.append({"image": image, "label": label})
        preds_train.append(pred)
        losses_train.append(loss)

    # Create validation graph
    test = True
    device_scope_name = "device{}".format(0)
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = cifar10_resnet23_prediction(image_valid, ctxs[0],
                                             device_scope_name, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solvers = []
    for i in range(n_devices):
        with nn.context_scope(ctxs[i]):
            solver = S.Adam()
            device_scope_name = "device{}".format(i)
            with nn.parameter_scope(device_scope_name):
                params = nn.get_parameters()
                solver.set_parameters(params)
            solvers.append(solver)

    # Communicator
    comm = C.DataParalellCommunicator(ctx)
    for i in range(n_devices):
        device_scope_name = "device{}".format(i)
        with nn.parameter_scope(device_scope_name):
            ctx = ctxs[i]
            params = nn.get_parameters()
            comm.add_context_and_parameters((ctx, params))
    comm.init()

    # Create a single-thread pool per device so each device's forward/backward
    # runs in its own thread.
    pools = []
    for _ in range(n_devices):
        pool = ThreadPool(processes=1)
        pools.append(pool)

    # Run forward/backward once per device up front to safely allocate memory
    # (see the forward_backward sketch after this function).
    for device_id in range(n_devices):
        data, label = \
            (np.random.randn(*input_image_train[device_id]["image"].shape),
             (np.random.rand(*input_image_train[device_id]["label"].shape) * 10).astype(np.int32))

        ret = pools[device_id].apply_async(
            forward_backward, (input_image_train[device_id]["image"], data,
                               input_image_train[device_id]["label"], label,
                               losses_train[device_id], solvers[device_id]))
        ret.get()
        losses_train[device_id].d  # sync to host

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator
    rng = np.random.RandomState(device_id)  # device_id == n_devices - 1 after the loop above
    tdata = data_iterator_cifar10(args.batch_size, True, rng)
    vdata = data_iterator_cifar10(args.batch_size, False)

    # Training-loop
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve = 0.
            for j in range(args.val_iter):
                image, label = vdata.next()
                input_image_valid["image"].d = image
                pred_valid.forward()
                ve += categorical_error(pred_valid.d, label)
            ve /= args.val_iter
            monitor_verr.add(i * n_devices, ve)
        if i % int(args.model_save_interval / n_devices) == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward / zero-grad / backward on each device, dispatched to the
        # per-device thread pools.
        fb_results = []
        for device_id in range(n_devices):
            image, label = tdata.next()

            res = pools[device_id].apply_async(
                forward_backward,
                (input_image_train[device_id]["image"], image,
                 input_image_train[device_id]["label"], label,
                 losses_train[device_id], solvers[device_id]))
            fb_results.append(res)
        for device_id in range(n_devices):
            fb_results[device_id].get()

        # Allreduce gradients over devices (averaged, since division=True).
        comm.allreduce(division=True, inplace=False)

        # Solvers update
        for device_id in range(n_devices):
            solvers[device_id].update()

        e = categorical_error(preds_train[-1].d,
                              input_image_train[-1]["label"].d)
        monitor_loss.add(i * n_devices, losses_train[-1].d.copy())
        monitor_err.add(i * n_devices, e)
        monitor_time.add(i * n_devices)

    nn.save_parameters(
        os.path.join(args.model_save_path,
                     'params_%06d.h5' % (args.max_iter // n_devices)))
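

# The forward_backward helper dispatched to the thread pools above is not shown
# in this excerpt. The sketch below is an assumption about what it does per
# device, matching the call sites in train(): bind the minibatch to the input
# variables, clear gradients, then run forward and backward on that device's
# loss. It is illustrative, not the example's actual implementation.
def forward_backward(x, x_data, y, y_data, loss, solver):
    # Bind this device's minibatch to the graph inputs.
    x.d = x_data
    y.d = y_data
    # Clear the gradients of the parameters registered to this device's solver.
    solver.zero_grad()
    # Forward and backward passes on this device's subgraph.
    loss.forward(clear_no_need_grad=True)
    loss.backward(clear_buffer=True)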


# Example 2
import pytest

import nnabla.parametric_functions as PF


def test_data_parallel_communicator():
    try:
        import nnabla_ext
        import nnabla_ext.cuda
        from nnabla.contrib.context import extension_context

    except Exception:
        pytest.skip("DataParallelCommunicator is only supported with CUDA for now.")

    n_devices = nnabla_ext.cuda.init.get_device_count()
    if n_devices < 2:
        pytest.skip("Number of cuda devices is less than 2.")

    # Contexts and Computation Graph
    extension_module = "cuda"
    ctxs = []
    for d in range(n_devices):
        ctx = extension_context(extension_module, device_id="{}".format(d))
        ctxs.append(ctx)
        with nn.context_scope(ctx):
            x_data = np.random.rand(4, 5)
            x = nn.Variable(x_data.shape)
            with nn.parameter_scope("gpu{}".format(d)):
                with nn.parameter_scope("affine1"):
                    z = PF.affine(x, 10, with_bias=True)
                with nn.parameter_scope("affine2"):
                    y = PF.affine(z, 5)

    # Initialize each parameter's gradient (v.g) with random values on every device.
    grads = []
    for d in range(n_devices):
        with nn.parameter_scope("gpu{}".format(d)):
            params = nn.get_parameters()
            grad = []
            for i, elm in enumerate(params.items()):
                k, v = elm
                grad_ = np.random.randn(*v.shape)
                v.g = grad_
                v.grad.cast(np.float32, ctxs[d])
                grad.append(grad_)
            grads.append(grad)

    # Reference: allreduce(division=True) should leave every device holding the
    # element-wise average of the gradients over devices (see the sketch after
    # this test). The parameter structure is identical across devices, so the
    # last device's scope is reused here just to enumerate the parameters.
    ref_grads = []
    with nn.parameter_scope("gpu{}".format(d)):
        params = nn.get_parameters()
        for i in range(len(params)):
            ave_grad = 0
            for d in range(n_devices):
                ave_grad += grads[d][i]
            ave_grad /= n_devices
            ref_grads.append(ave_grad)

    # Communicator
    try:
        comm = C.DataParalellCommunicator(ctxs[0])
    except Exception:
        pytest.skip(
            "DataParalellCommunicator is not supported on CPU-only builds or non-Linux platforms."
        )

    for d in range(n_devices):
        with nn.parameter_scope("gpu{}".format(d)):
            comm.add_context_and_parameters((ctxs[d], nn.get_parameters()))
    comm.init()
    comm.allreduce(division=True)

    # Check
    atol = 1e-6
    for d in range(n_devices):
        with nn.parameter_scope("gpu{}".format(d)):
            params = nn.get_parameters()
            for i, elm in enumerate(params.items()):
                k, v = elm
                assert np.allclose(ref_grads[i], v.g, atol=atol)
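

# For reference, a NumPy-only sketch of the quantity the check above verifies:
# allreduce(division=True) is expected to leave every device holding the
# element-wise mean of the per-device gradients. The device count and shape
# below are arbitrary placeholders, not values taken from the test.
def _allreduce_average_sketch(n_devices=4, shape=(10, 5)):
    # One random "gradient" per device for a single parameter.
    per_device = [np.random.randn(*shape) for _ in range(n_devices)]
    # The reference value the test computes: sum over devices divided by n_devices.
    averaged = sum(per_device) / n_devices
    # After allreduce(division=True), each device's v.g should equal `averaged`
    # up to floating point tolerance.
    return averaged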