Example #1
def create_communicator(ignore_error=False):
    global _current_communicator

    import nnabla.communicators as C
    import nnabla_ext.cudnn
    from nnabla import logger
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    context = get_extension_context(extension_module)
    try:
        logger.log(99, 'Create communicator with contexts {}'.format(context))
        _current_communicator = C.MultiProcessDataParalellCommunicator(context)
        _current_communicator.init()
        context.device_id = str(_current_communicator.rank %
                                _current_communicator.size)
        if _current_communicator.size == 1:
            _current_communicator = None
    except:
        if not ignore_error:
            raise
        logger.warning("Failed to initialize nnabla.communicators.")
        _current_communicator = None

    return _current_communicator
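The helper above returns None both when only one process is running (size == 1) and when initialization fails with ignore_error=True, so a caller can branch on the return value. A minimal usage sketch, assuming the script is launched with several MPI processes (e.g. "mpirun -n 4 python train.py"):

# Hypothetical caller (not part of the original module): fall back to
# single-device behaviour when no communicator is available.
comm = create_communicator(ignore_error=True)
if comm is None:
    print("Running on a single device.")
else:
    print("Running rank {} of {} devices.".format(comm.rank, comm.size))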
Example #2
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forward propagation.
      * Set parameter gradients to zero.
      * Execute backpropagation.
      * AllReduce the gradients.
      * The solver updates parameters using the gradients computed by backprop and all-reduce.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size

    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=32,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    _, tdata = data_iterator(args.batch_size, True, rng)
    vsource, vdata = data_iterator(args.batch_size, False)

    # loss_error_train.forward()

    # Training-loop
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:  # note that smaller batch is ignored
                    continue
                input_image_valid["image"].d = image
                input_image_valid["label"].d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            # Save model
            if device_id == 0:
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(
                        os.path.join(args.model_save_path,
                                     'params_%06d.h5' % i))

        # Forward/Zerograd
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:  # loss and error locally, and elapsed time
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

        # exit(0)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
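The backward_and_all_reduce helper called in the loop above is defined elsewhere in this example's repository and is not shown here. As a rough sketch of the pattern it stands for, a plain backward pass followed by summing gradients across devices (the callback-based variant that overlaps communication with backprop is omitted; names below are illustrative):

import nnabla as nn

def backward_and_all_reduce_sketch(loss, comm, with_all_reduce_callback=False):
    # Hypothetical stand-in, not the actual helper used above.
    loss.backward(clear_buffer=True)
    if comm is not None and comm.size > 1:
        grads = [p.grad for p in nn.get_parameters().values()]
        # Sum gradients over devices; the loss above is already divided by
        # n_devices, so no extra division is needed here.
        comm.all_reduce(grads, division=False, inplace=False)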
Example #3
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forward propagation.
      * Set parameter gradients to zero.
      * Execute backpropagation.
      * The solver updates parameters using the gradients computed by backprop.
      * Compute the training error.
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size
    rng = np.random.RandomState(313)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100

    # Communicator and Context
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx = extension_context(extension_module, device_id=device_id)
    nn.set_default_context(ctx)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # add parameters to communicator
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    tdata = data_iterator(args.batch_size, True, rng)
    vdata = data_iterator(args.batch_size, False)

    # Training-loop
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if device_id == 0:
            if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                ve = 0.
                for j in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % int(args.model_save_interval / n_devices) == 0:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Allreduce
        comm.allreduce(division=False, inplace=False)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:
            e = categorical_error(pred_train.d, input_image_train["label"].d)
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
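This example uses the older nnabla.contrib.context.extension_context helper with the "cuda.cudnn" extension name. In recent nnabla versions the equivalent setup uses nnabla.ext_utils.get_extension_context with the "cudnn" extension, as in the other examples on this page. A minimal sketch of the modern form:

import nnabla as nn
from nnabla.ext_utils import get_extension_context

# Equivalent context setup with the current API; the device id is a string.
ctx = get_extension_context("cudnn", device_id="0")
nn.set_default_context(ctx)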
Example #4
import nnabla.communicators as C
import numpy as np
from functools import reduce
from nbla_test_utils import list_context
from nnabla.contrib.context import extension_context

############################################
# Communicator has to be instantiated here,
# otherwise, mpirun fails.
############################################

# Communicator
comm = None
try:
    extension_module = "cuda"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
except:
    pass

############################################


def ref_reduce(x_data_list, size, division):
    f = reduce(lambda x, y: x + y, np.arange(size)) + size
    results = []
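The communicator above is created at module import time so that every MPI process reaches comm.init() before any test runs. A short sketch (illustrative only, not part of the original test file) of how the module-level comm can be exercised:

def sketch_all_reduce_sum():
    # Skip when the module-level communicator could not be created.
    if comm is None or comm.size < 2:
        return
    import nnabla as nn
    # Each rank contributes (rank + 1); after all_reduce every rank holds the sum.
    x = nn.NdArray.from_numpy_array(
        np.full((4,), comm.rank + 1, dtype=np.float32))
    comm.all_reduce([x], division=False, inplace=True)
    expected = comm.size * (comm.size + 1) / 2.0
    assert np.allclose(x.data, expected)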
Example #5
def train():
    """
    Main script.
    """

    args = get_args()

    _ = nn.load_parameters(args.pretrained_model_path)
    if args.fine_tune:
        nnabla.parameter.pop_parameter('decoder/logits/affine/conv/W')
        nnabla.parameter.pop_parameter('decoder/logits/affine/conv/b')

    n_train_samples = args.train_samples
    n_val_samples = args.val_samples
    distributed = args.distributed
    compute_acc = args.compute_acc

    if distributed:
        # Communicator and Context
        from nnabla.ext_utils import get_extension_context
        extension_module = "cudnn"
        ctx = get_extension_context(
            extension_module, type_config=args.type_config)
        comm = C.MultiProcessDataParalellCommunicator(ctx)
        comm.init()
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = mpi_rank
        ctx.device_id = str(device_id)
        nn.set_default_context(ctx)
    else:
        # Get context.
        from nnabla.ext_utils import get_extension_context
        extension_module = args.context
        if args.context is None:
            extension_module = 'cpu'
        logger.info("Running in %s" % extension_module)
        ctx = get_extension_context(
            extension_module, device_id=args.device_id, type_config=args.type_config)
        nn.set_default_context(ctx)
        n_devices = 1
        device_id = 0

    # training data
    data = data_iterator_segmentation(
            args.train_samples, args.batch_size, args.train_dir, args.train_label_dir, target_width=args.image_width, target_height=args.image_height)
    # validation data
    vdata = data_iterator_segmentation(args.val_samples, args.batch_size, args.val_dir,
                                       args.val_label_dir, target_width=args.image_width, target_height=args.image_height)

    if distributed:
        data = data.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
        vdata = vdata.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
    num_classes = args.num_class

    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(
        args, test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    t_e = F.sum(F.top_n_error(t_pred2, t_model.label, axis=1)
                * t_model.mask) / F.sum(t_model.mask)

    v_model = get_model(
        args, test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    v_e = F.sum(F.top_n_error(v_pred2, v_model.label, axis=1)
                * v_model.mask) / F.sum(v_model.mask)

    # Create Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Load checkpoint
    start_point = 0
    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Setting warmup.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples /
                      args.batch_size / args.accum_grad / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_miou = M.MonitorSeries("mean IOU", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed(
        "Validation time", monitor, interval=1)

    # save_nnp
    contents = save_nnp({'x': v_model.image}, {
                        'y': v_model.pred}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Deeplabv3plus_result_epoch0.nnp'), contents, variable_batch_size=False)

    # Training loop
    for i in range(start_point, int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            save_checkpoint(args.model_save_path, i, solver)
        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            vmiou_local = 0.
            val_iter_local = n_val_samples // args.batch_size
            vl_local = nn.NdArray()
            vl_local.zero()
            ve_local = nn.NdArray()
            ve_local.zero()
            for j in range(val_iter_local):
                images, labels, masks = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.mask.d = masks
                v_model.image.data.cast(np.float32, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.data
                ve_local += v_e.data
                # Mean IOU computation
                if compute_acc:
                    vmiou_local += compute_miou(num_classes, labels,
                                                np.argmax(v_model.pred.d, axis=1), masks)

            vl_local /= val_iter_local
            ve_local /= val_iter_local
            if compute_acc:
                vmiou_local /= val_iter_local
                vmiou_ndarray = nn.NdArray.from_numpy_array(
                    np.array(vmiou_local))
            if distributed:
                comm.all_reduce(vl_local, division=True, inplace=True)
                comm.all_reduce(ve_local, division=True, inplace=True)
                if compute_acc:
                    comm.all_reduce(vmiou_ndarray, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl_local.data.copy())
                monitor_verr.add(i * n_devices, ve_local.data.copy())
                if compute_acc:
                    monitor_miou.add(i * n_devices, vmiou_local)
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        e_acc = nn.NdArray(t_e.shape)
        e_acc.zero()
        l_acc = nn.NdArray(t_model.loss.shape)
        l_acc.zero()
        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels, masks = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.mask.d = masks
            t_model.image.data.cast(np.float32, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            e_acc += t_e.data
            l_acc += t_model.loss.data

        # AllReduce
        if distributed:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=False)
            comm.all_reduce(l_acc, division=True, inplace=True)
            comm.all_reduce(e_acc, division=True, inplace=True)
        solver.scale_grad(1./args.accum_grad)
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if distributed:
            # Synchronize by averaging the weights over devices using allreduce
            if (i+1) % args.sync_weight_every_itr == 0:
                weights = [x.data for x in nn.get_parameters().values()]
                comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(
                i * n_devices, (l_acc / args.accum_grad).data.copy())
            monitor_err.add(
                i * n_devices, (e_acc / args.accum_grad).data.copy())
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter --> changed to poly learning rate decay policy
        # if i in args.learning_rate_decay_at:
        solver.set_learning_rate(base_lr * ((1 - i / args.max_iter)**0.1))

    if device_id == 0:
        nn.save_parameters(os.path.join(args.model_save_path,
                                        'param_%06d.h5' % args.max_iter))

    contents = save_nnp({'x': v_model.image}, {
                        'y': v_model.pred}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Deeplabv3plus_result.nnp'), contents, variable_batch_size=False)
Example #6
def comm_nccl_opts(request):
    """Common resources for communicator tests.
    """
    if not request.config.getoption('--test-communicator'):
        return None

    import nnabla.communicators as C
    from nnabla.ext_utils import get_extension_context

    try:
        from nnabla_ext import cuda
    except Exception as e:
        raise ImportError(
            "Communicator test requires CUDA extension.\n{}".format(e))

    gpus = request.config.getoption('--communicator-gpus')
    n_devices = cuda.get_device_count()
    if gpus is None:
        devices = list(map(str, range(n_devices)))
    else:
        devices = gpus.split(',')
        # Check numbers
        try:
            for d in devices:
                gid = int(d)
                if gid >= n_devices:
                    raise ValueError('')
        except ValueError as e:
            raise ValueError(
                "GPU IDs must be comma-separated integers of available GPUs. Given {}. Number of available GPUs is {}.".format(gpus, n_devices))

    extension_module = "cuda"
    ctx = get_extension_context(extension_module)
    try:
        comm = C.MultiProcessDataParalellCommunicator(ctx)
    except Exception as e:
        raise RuntimeError(
            "Communicator could not be created. You may not have built nnabla with distributed support.\n{}".format(e))
    try:
        comm.init()
    except Exception as e:
        raise RuntimeError(
            "Communicator initialization failed. (Maybe MPI init failure.)\n{}".format(e))

    assert len(
        devices) == comm.size, "Number of CUDA devices used is not the same as the number of processes."
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    ctx.device_id = devices[mpi_local_rank]

    class CommOpts:
        pass

    c = CommOpts()
    c.comm = comm
    c.device_id = ctx.device_id
    c.devices = devices
    c.mpi_rank = mpi_rank
    c.mpi_local_rank = mpi_local_rank
    return c
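comm_nccl_opts is used as a pytest fixture (its registration is not shown here); it returns None when the --test-communicator option is not given. A hypothetical test consuming it could look like the following (illustrative names, not part of the original file):

import numpy as np
import pytest
import nnabla as nn

def test_all_reduce_mean(comm_nccl_opts):
    if comm_nccl_opts is None:
        pytest.skip("Communicator tests are disabled.")
    comm = comm_nccl_opts.comm
    # Averaging identical arrays over devices leaves them unchanged.
    x = nn.NdArray.from_numpy_array(np.ones(8, dtype=np.float32))
    comm.all_reduce([x], division=True, inplace=True)
    assert np.allclose(x.data, 1.0)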
Example #7
def train(args):
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = comm.local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model
    # workaround to start with the same weights in the distributed system.
    np.random.seed(412)
    # generator loss
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       sn=not_sn).apply(persistent=True)
    p_fake = discriminator(x_fake, y_fake, maps=maps //
                           16, n_classes=n_classes, sn=not_sn)
    loss_gen = gan_loss(p_fake)
    # discriminator loss
    y_real = nn.Variable([batch_size])
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, y_real, maps=maps //
                           16, n_classes=n_classes, sn=not_sn)
    loss_dis = gan_loss(p_fake, p_real)
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    y_test = nn.Variable.from_numpy_array(
        generate_random_class(n_classes, batch_size))
    x_test = generator(z_test, y_test, maps=maps,
                       n_classes=n_classes, test=True, sn=not_sn)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    if comm.rank == 0:
        monitor = Monitor(args.monitor_path)
        monitor_loss_gen = MonitorSeries(
            "Generator Loss", monitor, interval=10)
        monitor_loss_dis = MonitorSeries(
            "Discriminator Loss", monitor, interval=10)
        monitor_time = MonitorTimeElapsed(
            "Training Time", monitor, interval=10)
        monitor_image_tile_train = MonitorImageTile("Image Tile Train", monitor,
                                                    num_images=args.batch_size,
                                                    interval=1,
                                                    normalize_method=normalize_method)
        monitor_image_tile_test = MonitorImageTile("Image Tile Test", monitor,
                                                   num_images=args.batch_size,
                                                   interval=1,
                                                   normalize_method=normalize_method)
    # DataIterator
    rng = np.random.RandomState(device_id)
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                args.batch_size, n_classes=args.n_classes,
                                rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need for discriminator backward
        solver_dis.zero_grad()
        for _ in range(args.accum_grad):
            # feed x_real and y_real
            x_data, y_data = di.next()
            x_real.d, y_real.d = x_data, y_data.flatten()
            # feed z and y_fake
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_dis.values()])
        solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need for generator backward
        solver_gen.zero_grad()
        for _ in range(args.accum_grad):
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_gen.forward(clear_no_need_grad=True)
            loss_gen.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_gen.values()])
        solver_gen.update()

        # Synchronize by averaging the weights over devices using allreduce
        if i % args.sync_weight_every_itr == 0:
            weights = [v.data for v in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        # Save model and image
        if i % args.save_interval == 0 and comm.rank == 0:
            x_test.forward(clear_buffer=True)
            nn.save_parameters(os.path.join(
                args.monitor_path, "params_{}.h5".format(i)))
            monitor_image_tile_train.add(i, x_fake.d)
            monitor_image_tile_test.add(i, x_test.d)

        # Monitor
        if comm.rank == 0:
            monitor_loss_gen.add(i, loss_gen.d.copy())
            monitor_loss_dis.add(i, loss_dis.d.copy())
            monitor_time.add(i)

    if comm.rank == 0:
        x_test.forward(clear_buffer=True)
        nn.save_parameters(os.path.join(
            args.monitor_path, "params_{}.h5".format(i)))
        monitor_image_tile_train.add(i, x_fake.d)
        monitor_image_tile_test.add(i, x_test.d)
Example #8
def train(args):
    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = args.batch_size, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    # workaround for starting with the same model among devices.
    np.random.seed(412)
    maps = args.maps
    # within-domain reconstruction (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    x_recon_a = decoder(x_content_a, x_style_a, name="decoder-a")
    # within-domain reconstruction (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    x_recon_b = decoder(x_content_b, x_style_b, name="decoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = F.randn(shape=x_style_a.shape)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    x_content_rec_b = content_encoder(x_fake_a, maps, name="content-encoder-a")
    x_style_rec_a = style_encoder(x_fake_a, maps, name="style-encoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(shape=x_style_b.shape)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")
    x_content_rec_a = content_encoder(x_fake_b, maps, name="content-encoder-b")
    x_style_rec_b = style_encoder(x_fake_b, maps, name="style-encoder-b")
    # discriminate (domains A and B)
    p_x_fake_a_list = discriminators(x_fake_a)
    p_x_real_a_list = discriminators(x_real_a)
    p_x_fake_b_list = discriminators(x_fake_b)
    p_x_real_b_list = discriminators(x_real_b)

    # Loss
    # within-domain reconstruction
    loss_recon_x_a = recon_loss(x_recon_a, x_real_a).apply(persistent=True)
    loss_recon_x_b = recon_loss(x_recon_b, x_real_b).apply(persistent=True)
    # content and style reconstruction
    loss_recon_x_style_a = recon_loss(x_style_rec_a,
                                      z_style_a).apply(persistent=True)
    loss_recon_x_content_b = recon_loss(x_content_rec_b,
                                        x_content_b).apply(persistent=True)
    loss_recon_x_style_b = recon_loss(x_style_rec_b,
                                      z_style_b).apply(persistent=True)
    loss_recon_x_content_a = recon_loss(x_content_rec_a,
                                        x_content_a).apply(persistent=True)

    # adversarial

    def f(x, y):
        return x + y

    loss_gen_a = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_a_list]).apply(persistent=True)
    loss_dis_a = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_a_list, p_x_real_a_list)
    ]).apply(persistent=True)
    loss_gen_b = reduce(f, [lsgan_loss(p_f)
                            for p_f in p_x_fake_b_list]).apply(persistent=True)
    loss_dis_b = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_b_list, p_x_real_b_list)
    ]).apply(persistent=True)
    # loss for generator-related models
    loss_gen = loss_gen_a + loss_gen_b \
        + args.lambda_x * (loss_recon_x_a + loss_recon_x_b) \
        + args.lambda_c * (loss_recon_x_content_a + loss_recon_x_content_b) \
        + args.lambda_s * (loss_recon_x_style_a + loss_recon_x_style_b)
    # loss for discriminators
    loss_dis = loss_dis_a + loss_dis_b

    # Solver
    lr_g, lr_d, beta1, beta2 = args.lr_g, args.lr_d, args.beta1, args.beta2
    # solver for generator-related models
    solver_gen = S.Adam(lr_g, beta1, beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
    solver_gen.set_parameters(params_gen)
    # solver for discriminators
    solver_dis = S.Adam(lr_d, beta1, beta2)
    with nn.parameter_scope("discriminators"):
        params_dis = nn.get_parameters()
    solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    # time
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    # reconstruction
    monitor_loss_recon_x_a = MonitorSeries("Recon Loss Image A",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_b = MonitorSeries("Recon Loss Content B",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_a = MonitorSeries("Recon Loss Style A",
                                                 monitor,
                                                 interval=10)
    monitor_loss_recon_x_b = MonitorSeries("Recon Loss Image B",
                                           monitor,
                                           interval=10)
    monitor_loss_recon_x_content_a = MonitorSeries("Recon Loss Content A",
                                                   monitor,
                                                   interval=10)
    monitor_loss_recon_x_style_b = MonitorSeries("Recon Loss Style B",
                                                 monitor,
                                                 interval=10)
    # adversarial
    monitor_loss_gen_a = MonitorSeries("Gen Loss A", monitor, interval=10)
    monitor_loss_dis_a = MonitorSeries("Dis Loss A", monitor, interval=10)
    monitor_loss_gen_b = MonitorSeries("Gen Loss B", monitor, interval=10)
    monitor_loss_dis_b = MonitorSeries("Dis Loss B", monitor, interval=10)
    monitor_losses = [
        # reconstruction
        (monitor_loss_recon_x_a, loss_recon_x_a),
        (monitor_loss_recon_x_content_b, loss_recon_x_content_b),
        (monitor_loss_recon_x_style_a, loss_recon_x_style_a),
        (monitor_loss_recon_x_b, loss_recon_x_b),
        (monitor_loss_recon_x_content_a, loss_recon_x_content_a),
        (monitor_loss_recon_x_style_b, loss_recon_x_style_b),
        # adversarial
        (monitor_loss_gen_a, loss_gen_a),
        (monitor_loss_dis_a, loss_dis_a),
        (monitor_loss_gen_b, loss_gen_b),
        (monitor_loss_dis_b, loss_dis_b)
    ]
    # image
    monitor_image_a = MonitorImage("Fake Image B to A Train",
                                   monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B Train",
                                   monitor,
                                   interval=1)
    monitor_images = [
        (monitor_image_a, x_fake_a),
        (monitor_image_b, x_fake_b),
    ]

    # DataIterator
    rng_a = np.random.RandomState(device_id)
    rng_b = np.random.RandomState(device_id + n_devices)
    di_a = munit_data_iterator(args.img_path_a, args.batch_size, rng=rng_a)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size, rng=rng_b)

    # Train
    for i in range(args.max_iter // n_devices):
        ii = i * n_devices
        # Train generator-related models
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_gen.values()])
        solver_gen.weight_decay(args.weight_decay_rate)
        solver_gen.update()

        # Train discriminators
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        x_fake_a.need_grad, x_fake_b.need_grad = False, False
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_dis.values()])
        solver_dis.weight_decay(args.weight_decay_rate)
        solver_dis.update()
        x_fake_a.need_grad, x_fake_b.need_grad = True, True

        # LR schedule
        if (i + 1) % (args.lr_decay_at_every // n_devices) == 0:
            lr_d = solver_dis.learning_rate() * args.lr_decay_rate
            lr_g = solver_gen.learning_rate() * args.lr_decay_rate
            solver_dis.set_learning_rate(lr_d)
            solver_gen.set_learning_rate(lr_g)

        if mpi_local_rank == 0:
            # Monitor
            monitor_time.add(ii)
            for mon, loss in monitor_losses:
                mon.add(ii, loss.d)
            # Save
            if (i + 1) % (args.model_save_interval // n_devices) == 0:
                for mon, x in monitor_images:
                    mon.add(ii, x.d)
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "param_{:05d}.h5".format(i)))

    if mpi_local_rank == 0:
        # Monitor
        for mon, loss in monitor_losses:
            mon.add(ii, loss.d)
        # Save
        for mon, x in monitor_images:
            mon.add(ii, x.d)
        nn.save_parameters(
            os.path.join(args.monitor_path, "param_{:05d}.h5".format(i)))
Example #9
def train():
    """
    Main script.

    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forward propagation.
      * Set parameter gradients to zero.
      * Execute backpropagation.
      * In-place allreduce (THIS IS THE MAIN difference from single-device training).
      * The solver updates parameters using the gradients computed by backprop.
      * Compute the training error.

    """

    args = get_args()
    if args.tiny_mode:
        n_train_samples = 100000
    else:
        n_train_samples = 1281167

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Workaround to start with the same parameters.
    rng = np.random.RandomState(device_id)
    if args.tiny_mode:
        # We use Tiny ImageNet from Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories, each with 500 training
        # images. The image size is 64x64. To adapt ResNet to 64x64 inputs,
        # the input image size of ResNet is set to 56x56, and the stride in
        # the first conv and the first max pooling are removed.
        # Please check README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories, each with 1280 training
        # images. The image sizes vary. To adapt ResNet to 320x320 image
        # inputs, the input image size of ResNet is set to 224x224. We need
        # to get the tar file and create cache files (320x320 images).
        # Please check README.
        data = data_iterator_imagenet(args.batch_size,
                                      args.train_cachefile_dir,
                                      rng=rng)
        vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir)
        vdata = vdata.slice(rng=None,
                            num_of_slices=n_devices,
                            slice_pos=device_id)
        num_classes = 1000
    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Add parameters to communicator.
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Setting warmup.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples / args.batch_size /
                      args.accum_grad / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=1)

    # Training loop.
    vl = nn.Variable()
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            ve_local = 0.
            vl_local = 0.
            val_iter_local = args.val_iter // n_devices
            for j in range(val_iter_local):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.d.copy()
                ve_local += v_e.d.copy()
            vl_local /= val_iter_local
            vl.d = vl_local
            comm.all_reduce(vl.data, division=True, inplace=True)
            ve_local /= val_iter_local
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl.d.copy())
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels = data.next()
            if j != 0:
                # Update e and l according to previous results of forward
                # propagation.
                # The update of last iteration is performed
                # after solver update to avoid unnecessary CUDA synchronization.
                # This is performed after data.next() in order to overlap
                # the data loading and graph execution.
                # TODO: Move this to the bottom of the loop when prefetch
                # data loader is available.
                l, e = accumulate_error(l, e, t_model, t_e)
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)

        # AllReduce
        params = [x.grad for x in nn.get_parameters().values()]
        comm.all_reduce(params, division=False, inplace=False)

        # Update
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Accumulate errors after solver update
        l, e = accumulate_error(l, e, t_model, t_e)

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Synchronize by averaging the weights over devices using allreduce
        if (i + 1) % args.sync_weight_every_itr == 0:
            weights = [x.data for x in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(i * n_devices, l / args.accum_grad)
            monitor_err.add(i * n_devices, e / args.accum_grad)
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter
        if i * n_devices in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'param_%06d.h5' % (args.max_iter / n_devices)))
Example #10
def train():
    """
    Main script.

    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute the error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forward propagation.
      * Set parameter gradients to zero.
      * Execute backpropagation.
      * In-place allreduce (THIS IS THE MAIN difference from single-device training).
      * The solver updates parameters using the gradients computed by backprop.
      * Compute the training error.

    """

    args = get_args()
    n_train_samples = 1281167
    num_classes = 1000

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Pipelines and Iterators for training
    train_pipes = [
        TrainPipeline(args.batch_size,
                      args.num_threads,
                      device_id,
                      args.train_cachefile_dir,
                      args.train_list,
                      seed=device_id + 1,
                      num_gpu=n_devices,
                      random_area=args.random_area)
    ]
    train_pipes[0].build()
    data = DALIClassificationIterator(train_pipes,
                                      train_pipes[0].epoch_size("Reader") //
                                      n_devices,
                                      auto_reset=True,
                                      stop_at_epoch=False)
    # Pipelines and Iterators for validation
    val_pipes = [
        ValPipeline(args.batch_size,
                    args.num_threads,
                    device_id,
                    args.val_cachefile_dir,
                    args.val_list,
                    seed=device_id + 1,
                    num_gpu=n_devices)
    ]
    val_pipes[0].build()
    vdata = DALIClassificationIterator(val_pipes,
                                       val_pipes[0].epoch_size("Reader") //
                                       n_devices,
                                       auto_reset=True,
                                       stop_at_epoch=False)
    # Network for training
    t_model = get_model(args,
                        num_classes,
                        n_devices,
                        args.accum_grad,
                        test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.get_unlinked_variable(need_grad=False)
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    # Network for validation
    v_model = get_model(args,
                        num_classes,
                        n_devices,
                        args.accum_grad,
                        test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.get_unlinked_variable(need_grad=False)
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_learning_rate(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=1)

    # Training loop
    vl = nn.Variable()
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            ve_local = 0.
            vl_local = 0.
            val_iter_local = args.val_iter // n_devices
            for j in range(val_iter_local):
                nextImage, nextLabel = vdata.next()
                v_model.image.data = nextImage
                v_model.label.data = nextLabel
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.d.copy()
                ve_local += v_e.d.copy()
            vl_local /= val_iter_local
            vl.d = vl_local
            comm.all_reduce(vl.data, division=True, inplace=True)
            ve_local /= val_iter_local
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl.d.copy())
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            nextImage, nextLabel = data.next()
            t_model.image.data = nextImage
            t_model.label.data = nextLabel
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            l, e = accumulate_error(l, e, t_model, t_e)

        # AllReduce
        params = [x.grad for x in nn.get_parameters().values()]
        comm.all_reduce(params, division=False, inplace=False)

        # Update
        solver.weight_decay(args.weight_decay)
        solver.update()

        if device_id == 0:
            monitor_loss.add(i * n_devices, l / args.accum_grad)
            monitor_err.add(i * n_devices, e / args.accum_grad)
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter
        if i * n_devices in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'param_%06d.h5' % (args.max_iter / n_devices)))
Example #11
def train(args):

    # get context

    ctx = get_extension_context(args.context)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    config = read_yaml(args.config)

    if args.info:
        config.monitor_params.info = args.info

    if comm.size == 1:
        comm = None
    else:
        # disable outputs from logger except its rank = 0
        if comm.rank > 0:
            import logging
            logger.setLevel(logging.ERROR)

    test = False
    train_params = config.train_params
    dataset_params = config.dataset_params
    model_params = config.model_params

    loss_flags = get_loss_flags(train_params)

    start_epoch = 0

    rng = np.random.RandomState(device_id)
    data_iterator = frame_data_iterator(
        root_dir=dataset_params.root_dir,
        frame_shape=dataset_params.frame_shape,
        id_sampling=dataset_params.id_sampling,
        is_train=True,
        random_seed=rng,
        augmentation_params=dataset_params.augmentation_params,
        batch_size=train_params['batch_size'],
        shuffle=True,
        with_memory_cache=False,
        with_file_cache=False)

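    # Shard the dataset across devices so that each rank sees a distinct slice.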
    if n_devices > 1:
        data_iterator = data_iterator.slice(rng=rng,
                                            num_of_slices=comm.size,
                                            slice_pos=comm.rank)
        # workaround to keep the sliced data source from using the on-memory cache
        data_iterator._data_source._on_memory = False
        logger.info("Disabled on memory data cache.")

    bs, h, w, c = [train_params.batch_size] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        # kp_X = {"value": Variable((bs, 10, 2)), "jacobian": Variable((bs, 10, 2, 2))}

        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=test,
                                    comm=comm)
        persistent_all(kp_source)

        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=test,
                                     comm=comm)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=kp_source,
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=test,
                                              comm=comm)
        # generated is a dictionary containing:
        # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
        # 'sparse_deformed': Variable((bs, num_kp + 1, num_channel, h/4, w/4))
        # 'occlusion_map': Variable((bs, 1, h/4, w/4))
        # 'deformed': Variable((bs, c, h, w))
        # 'prediction': Variable((bs, c, h, w)); only this one is fed to the discriminator.

    generated["prediction"].persistent = True

    pyramide_real = get_image_pyramid(driving, train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_real)

    pyramide_fake = get_image_pyramid(generated['prediction'],
                                      train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_fake)
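    # The real and fake image pyramids are reused by the multi-scale perceptual
    # and GAN losses below.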

    total_loss_G = None  # dummy; replaced by the first loss term added below
    loss_var_dict = {}

    # perceptual loss using VGG19 (always applied)
    if loss_flags.use_perceptual_loss:
        logger.info("Use Perceptual Loss.")
        scales = train_params.scales
        weights = train_params.loss_weights.perceptual
        vgg_param_path = train_params.vgg_param_path
        percep_loss = perceptual_loss(pyramide_real, pyramide_fake, scales,
                                      weights, vgg_param_path)
        percep_loss.persistent = True
        loss_var_dict['perceptual_loss'] = percep_loss
        total_loss_G = percep_loss

    # (LS)GAN loss and feature matching loss
    if loss_flags.use_gan_loss:
        logger.info("Use GAN Loss.")
        with nn.parameter_scope("discriminator"):
            discriminator_maps_generated = multiscale_discriminator(
                pyramide_fake,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

            discriminator_maps_real = multiscale_discriminator(
                pyramide_real,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

        for v in discriminator_maps_generated["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_generated["prediction_map_1"].persistent = True

        for v in discriminator_maps_real["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_real["prediction_map_1"].persistent = True

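        # Accumulate the LSGAN losses over every discriminator scale; generator
        # and discriminator receive separate loss terms.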
        for i, scale in enumerate(model_params.discriminator_params.scales):
            key = f'prediction_map_{scale}'.replace('.', '-')
            lsgan_loss_weight = train_params.loss_weights.generator_gan
            # LSGAN loss for Generator
            if i == 0:
                gan_loss_gen = lsgan_loss(discriminator_maps_generated[key],
                                          lsgan_loss_weight)
            else:
                gan_loss_gen += lsgan_loss(discriminator_maps_generated[key],
                                           lsgan_loss_weight)
            # LSGAN loss for Discriminator
            if i == 0:
                gan_loss_dis = lsgan_loss(discriminator_maps_real[key],
                                          lsgan_loss_weight,
                                          discriminator_maps_generated[key])
            else:
                gan_loss_dis += lsgan_loss(discriminator_maps_real[key],
                                           lsgan_loss_weight,
                                           discriminator_maps_generated[key])
        gan_loss_dis.persistent = True
        loss_var_dict['gan_loss_dis'] = gan_loss_dis
        total_loss_D = gan_loss_dis
        total_loss_D.persistent = True

        gan_loss_gen.persistent = True
        loss_var_dict['gan_loss_gen'] = gan_loss_gen
        total_loss_G += gan_loss_gen

        if loss_flags.use_feature_matching_loss:
            logger.info("Use Feature Matching Loss.")
            fm_weights = train_params.loss_weights.feature_matching
            fm_loss = feature_matching_loss(discriminator_maps_real,
                                            discriminator_maps_generated,
                                            model_params, fm_weights)
            fm_loss.persistent = True
            loss_var_dict['feature_matching_loss'] = fm_loss
            total_loss_G += fm_loss

    # transform loss
    if loss_flags.use_equivariance_value_loss or loss_flags.use_equivariance_jacobian_loss:
        transform = Transform(bs, **config.train_params.transform_params)
        transformed_frame = transform.transform_frame(driving)
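        # Keypoints detected on a randomly warped driving frame should map back
        # onto the original keypoints; the equivariance losses below penalize
        # any mismatch.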

        with nn.parameter_scope("kp_detector"):
            transformed_kp = detect_keypoint(transformed_frame,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=test,
                                             comm=comm)
        persistent_all(transformed_kp)

        # Value loss part
        if loss_flags.use_equivariance_value_loss:
            logger.info("Use Equivariance Value Loss.")
            warped_kp_value = transform.warp_coordinates(
                transformed_kp['value'])
            eq_value_weight = train_params.loss_weights.equivariance_value

            eq_value_loss = equivariance_value_loss(kp_driving['value'],
                                                    warped_kp_value,
                                                    eq_value_weight)
            eq_value_loss.persistent = True
            loss_var_dict['equivariance_value_loss'] = eq_value_loss
            total_loss_G += eq_value_loss

        # Jacobian loss part
        if loss_flags.use_equivariance_jacobian_loss:
            logger.info("Use Equivariance Jacobian Loss.")
            arithmetic_jacobian = transform.jacobian(transformed_kp['value'])
            eq_jac_weight = train_params.loss_weights.equivariance_jacobian
            eq_jac_loss = equivariance_jacobian_loss(
                kp_driving['jacobian'], arithmetic_jacobian,
                transformed_kp['jacobian'], eq_jac_weight)
            eq_jac_loss.persistent = True
            loss_var_dict['equivariance_jacobian_loss'] = eq_jac_loss
            total_loss_G += eq_jac_loss

    assert total_loss_G is not None
    total_loss_G.persistent = True
    loss_var_dict['total_loss_gen'] = total_loss_G

    # -------------------- Create Monitors --------------------
    monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = get_monitors(
        config, loss_flags, loss_var_dict)

    if device_id == 0:
        # Dump training info .yaml
        _ = shutil.copy(args.config, log_dir)  # copy the config yaml
        training_info_yaml = os.path.join(log_dir, "training_info.yaml")
        os.rename(os.path.join(log_dir, os.path.basename(args.config)),
                  training_info_yaml)
        # then add additional information
        with open(training_info_yaml, "a", encoding="utf-8") as f:
            f.write(f"\nlog_dir: {log_dir}\nsaved_parameter: None")

    # -------------------- Solver Setup --------------------
    solvers = setup_solvers(train_params)
    solver_generator = solvers["generator"]
    solver_discriminator = solvers["discriminator"]
    solver_kp_detector = solvers["kp_detector"]

    # max epochs
    num_epochs = train_params['num_epochs']

    # iterations per epoch
    num_iter_per_epoch = data_iterator.size // bs
    # scaled up when the dataset is repeated within an epoch
    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        num_iter_per_epoch *= config.train_params.num_repeats

    # decay the learning rate when the current epoch reaches one of the milestones below
    lr_decay_at_epochs = train_params['epoch_milestones']  # e.g. [60, 90]
    gamma = 0.1  # decay rate

    # -------------------- For finetuning ---------------------
    if args.ft_params:
        assert os.path.isfile(args.ft_params)
        logger.info(f"load {args.ft_params} for finetuning.")
        nn.load_parameters(args.ft_params)
        start_epoch = int(
            os.path.splitext(os.path.basename(
                args.ft_params))[0].split("epoch_")[1])

        # set solver's state
        for name, solver in solvers.items():
            saved_states = os.path.join(
                os.path.dirname(args.ft_params),
                f"state_{name}_at_epoch_{start_epoch}.h5")
            solver.load_states(saved_states)

        start_epoch += 1
        logger.info(f"Resuming from epoch {start_epoch}.")

    logger.info(
        f"Start training. Total epoch: {num_epochs - start_epoch}, {num_iter_per_epoch * n_devices} iter/epoch."
    )

    for e in range(start_epoch, num_epochs):
        logger.info(f"Epoch: {e} / {num_epochs}.")
        data_iterator._reset()  # rewind the iterator at the beginning of each epoch

        # learning rate scheduler
        if e in lr_decay_at_epochs:
            logger.info("Learning rate decayed.")
            learning_rate_decay(solvers, gamma=gamma)

        for i in range(num_iter_per_epoch):
            _driving, _source = data_iterator.next()
            source.d = _source
            driving.d = _driving

            # update generator and keypoint detector
            total_loss_G.forward()

            if device_id == 0:
                monitors_gen.add((e * num_iter_per_epoch + i) * n_devices)

            solver_generator.zero_grad()
            solver_kp_detector.zero_grad()

            callback = None
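            # On multiple devices, register an all-reduce callback so gradient
            # exchange overlaps with backprop; 2 << 20 is the pack size passed
            # to all_reduce_callback.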
            if n_devices > 1:
                params = [x.grad for x in solver_generator.get_parameters().values()] + \
                         [x.grad for x in solver_kp_detector.get_parameters().values()]
                callback = comm.all_reduce_callback(params, 2 << 20)
            total_loss_G.backward(clear_buffer=True,
                                  communicator_callbacks=callback)

            solver_generator.update()
            solver_kp_detector.update()

            if loss_flags.use_gan_loss:
                # update discriminator

                total_loss_D.forward(clear_no_need_grad=True)
                if device_id == 0:
                    monitors_dis.add((e * num_iter_per_epoch + i) * n_devices)

                solver_discriminator.zero_grad()

                callback = None
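                # Same pattern for the discriminator: overlap the gradient
                # all-reduce with backprop when running on multiple devices.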
                if n_devices > 1:
                    params = [
                        x.grad for x in
                        solver_discriminator.get_parameters().values()
                    ]
                    callback = comm.all_reduce_callback(params, 2 << 20)
                total_loss_D.backward(clear_buffer=True,
                                      communicator_callbacks=callback)

                solver_discriminator.update()

            if device_id == 0:
                monitor_time.add((e * num_iter_per_epoch + i) * n_devices)

            if device_id == 0 and (
                (e * num_iter_per_epoch + i) *
                    n_devices) % config.monitor_params.visualize_freq == 0:
                images_to_visualize = [
                    source.d, driving.d, generated["prediction"].d
                ]
                visuals = combine_images(images_to_visualize)
                monitor_vis.add((e * num_iter_per_epoch + i) * n_devices,
                                visuals)

        if device_id == 0:
            if e % train_params.checkpoint_freq == 0 or e == num_epochs - 1:
                save_parameters(e, log_dir, solvers)

    return