Example #1
0
def main():
    """Train skip-gram word embeddings with negative sampling on the
    Penn Treebank training corpus and save the resulting model.

    Steps: load the corpus, build training/inference graphs, run the
    training loop with loss/time monitoring, export an .nnp file and
    run a C++ forward check, then (currently unreachable, see the NOTE
    near exit()) evaluate embeddings by nearest-word similarity.
    """

    # Get arguments
    args = get_args()
    data_file = "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt"
    model_file = args.work_dir + "model.h5"

    # Load Dataset
    # itow/wtoi: index <-> word mappings; dataset: the corpus as word ids.
    itow, wtoi, dataset = load_ptbset(data_file)

    # Computation environment settings (default to CPU when unspecified).
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create data provider
    n_word = len(wtoi)              # vocabulary size
    n_dim = args.embed_dim          # embedding dimensionality
    batchsize = args.batchsize
    half_window = args.half_window_length
    n_negative = args.n_negative_sample

    di = DataIteratorForEmbeddingLearning(
        batchsize=batchsize,
        half_window=half_window,
        n_negative=n_negative,
        dataset=dataset)

    # Create model
    # - Real batch size including context samples and negative samples
    size = batchsize * (1 + n_negative) * (2 * (half_window - 1))

    # Model for learning
    # - input variables
    xl = nn.Variable((size,))  # variable for word
    yl = nn.Variable((size,))  # variable for context

    # Embed layers for word embedding function
    # - f_embed : word index x to get y, the n_dim vector
    # --  for each sample in a minibatch
    # NOTE(review): both embeds use the same parameter name "e1", so the
    # word and context vectors share one embedding table — word2vec
    # usually keeps separate tables; confirm this sharing is intended.
    hx = PF.embed(xl, n_word, n_dim, name="e1")  # feature vector for word
    hy = PF.embed(yl, n_word, n_dim, name="e1")  # feature vector for context
    hl = F.sum(hx * hy, axis=1)  # per-sample dot-product score

    # -- Approximated likelihood of context prediction
    # pos: word context, neg negative samples
    tl = nn.Variable([size, ], need_grad=False)  # 1/0 targets (pos/neg)
    loss = F.sigmoid_cross_entropy(hl, tl)
    loss = F.mean(loss)

    # Model for test of searching similar words
    xr = nn.Variable((size,), need_grad=False)
    hr = PF.embed(xr, n_word, n_dim, name="e1")  # feature vector for test

    # Create solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    monitor = M.Monitor(args.work_dir)
    monitor_loss = M.MonitorSeries(
        "Training loss", monitor, interval=args.monitor_interval)
    monitor_time = M.MonitorTimeElapsed(
        "Training time", monitor, interval=args.monitor_interval)

    # Do training
    max_epoch = args.max_epoch
    for epoch in range(max_epoch):

        # iteration per epoch
        for i in range(di.n_batch):

            # get minibatch
            xi, yi, ti = di.next()

            # learn: standard forward / backward / update step.
            solver.zero_grad()
            xl.d, yl.d, tl.d = xi, yi, ti
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()

            # monitor (global iteration index across epochs)
            itr = epoch * di.n_batch + i
            monitor_loss.add(itr, loss.d)
            monitor_time.add(itr)

    # Save model
    nn.save_parameters(model_file)
    nnp_file = os.path.join(
        args.work_dir, 'wtov_%06d.nnp' % (args.max_epoch))
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': size,
             'outputs': {'e': hr},
             'names': {'w': xr}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['w'],
             'output': ['e']}]}
    save.save(nnp_file, runtime_contents)

    # Cross-check the exported network with a C++ forward pass.
    # NOTE(review): [xi] reuses the last training minibatch from the loop
    # above as test input — confirm that is intentional.
    from cpp_forward_check import check_cpp_forward
    check_cpp_forward(args.work_dir, [xi], [xr], hr, nnp_file)
    # NOTE(review): exit() terminates the script here, so the similarity
    # evaluation below is unreachable — looks like a debugging leftover;
    # confirm before relying on the evaluation output.
    exit()

    # Evaluate by similarity
    max_check_words = args.max_check_words
    for i in range(max_check_words):

        # prediction (xr.d = i broadcasts word id i over the whole batch;
        # only h[0] is used below)
        xr.d = i
        hr.forward(clear_buffer=True)
        h = hr.d

        # similarity calculation
        # NOTE(review): `w /= ...` normalizes the parameter array IN PLACE,
        # mutating the trained embedding table — acceptable post-training
        # but worth confirming.
        w = nn.get_parameters()['e1/embed/W'].d
        s = np.sqrt((w * w).sum(1))
        w /= s.reshape((s.shape[0], 1))
        similarity = w.dot(h[0]) / s[i]

        # for understanding
        output_similar_words(itow, i, similarity)
Example #2
0
def train(args):
    """Train a siamese LeNet on MNIST with a contrastive loss.

    Builds a training graph and a parameter-sharing evaluation graph,
    then runs the optimization loop with periodic validation loss
    reporting and parameter snapshots.
    """

    # Computation context (fall back to CPU when none is requested).
    from nnabla.contrib.context import extension_context
    ext = args.context if args.context is not None else 'cpu'
    logger.info("Running in %s" % ext)
    nn.set_default_context(extension_context(ext, device_id=args.device_id))

    contrastive_margin = 1.0  # margin hyper-parameter of the contrastive loss

    # ---- Training graph ----
    img_a = nn.Variable([args.batch_size, 1, 28, 28])
    img_b = nn.Variable([args.batch_size, 1, 28, 28])
    pair_label = nn.Variable([args.batch_size])
    distance = mnist_lenet_siamese(img_a, img_b, test=False)
    loss = F.mean(contrastive_loss(distance, pair_label, contrastive_margin))

    # ---- Evaluation graph (shares parameters with the training graph) ----
    v_img_a = nn.Variable([args.batch_size, 1, 28, 28])
    v_img_b = nn.Variable([args.batch_size, 1, 28, 28])
    v_pair_label = nn.Variable([args.batch_size])
    v_distance = mnist_lenet_siamese(v_img_a, v_img_b, test=True)
    vloss = F.mean(
        contrastive_loss(v_distance, v_pair_label, contrastive_margin))

    # Optimizer over all parameters created above.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors for logging progress.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_vloss = M.MonitorSeries("Test loss", monitor, interval=10)

    # Data iterators (fixed seed for reproducibility).
    rng = np.random.RandomState(313)
    data = siamese_data_iterator(args.batch_size, True, rng)
    vdata = siamese_data_iterator(args.batch_size, False, rng)

    for it in range(args.max_iter):
        if it % args.val_interval == 0:
            # Periodic validation: mean loss over val_iter minibatches.
            vloss_sum = 0.0
            for _ in range(args.val_iter):
                v_img_a.d, v_img_b.d, v_pair_label.d = vdata.next()
                vloss.forward(clear_buffer=True)
                vloss_sum += vloss.d
            monitor_vloss.add(it, vloss_sum / args.val_iter)
        if it % args.model_save_interval == 0:
            # Periodic parameter snapshot.
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % it))
        # One optimization step: forward, backward, decay, update.
        img_a.d, img_b.d, pair_label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        monitor_loss.add(it, loss.d.copy())
        monitor_time.add(it)

    # Final snapshot at max_iter.
    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_%06d.h5' % args.max_iter))
Example #3
0
def train(args):
    """
    Main script.

    Train a GAN on kanji image data: a generator mapping 1000-dim noise
    to 28x28 images and a discriminator separating fakes from real
    samples. Each iteration updates the generator first, then the
    discriminator, with periodic per-scope parameter snapshots and a
    final save at the end.
    """

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    # TRAIN

    # Fake path
    z = nn.Variable([args.batch_size, 1000, 1, 1])  # noise input
    fake = generator(z, maxh=1024)
    fake.persistent = True  # Not to clear at backward
    pred_fake = discriminator(fake)
    # Generator loss: make the discriminator label fakes as real (1).
    loss_gen = F.mean(
        F.sigmoid_cross_entropy(pred_fake, F.constant(1, pred_fake.shape)))
    # Unlinked copy stops discriminator gradients flowing into the generator.
    fake_dis = fake.unlinked()
    pred_fake_dis = discriminator(fake_dis)
    # Discriminator loss: fakes -> 0 ...
    loss_dis = F.mean(
        F.sigmoid_cross_entropy(pred_fake_dis,
                                F.constant(0, pred_fake_dis.shape)))

    # Real path
    x = nn.Variable([args.batch_size, 1, 28, 28])
    pred_real = discriminator(x)
    # ... and real images -> 1.
    loss_dis += F.mean(
        F.sigmoid_cross_entropy(pred_real, F.constant(1, pred_real.shape)))

    # Create one solver per sub-network parameter scope.
    solver_gen = S.Adam(args.learning_rate, beta1=0.5)
    solver_dis = S.Adam(args.learning_rate, beta1=0.5)
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss_gen = M.MonitorSeries("Generator loss", monitor, interval=10)
    monitor_loss_dis = M.MonitorSeries("Discriminator loss",
                                       monitor,
                                       interval=10)
    monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100)
    # BUG FIX: the lambda was `x + 1 / 2.`, which by operator precedence
    # computes x + 0.5 rather than the intended (x + 1) / 2 mapping of
    # [-1, 1] image values into [0, 1] for display (the DCGAN example in
    # this file uses the parenthesized form).
    monitor_fake = M.MonitorImageTile("Fake images",
                                      monitor,
                                      normalize_method=lambda x: (x + 1) / 2.)

    #data = data_iterator_mnist(args.batch_size, True)
    data = iterator.simple_data_iterator(load_kanji_data(), args.batch_size,
                                         True)

    # Training loop.
    for i in range(args.max_iter):
        if i % args.model_save_interval == 0:
            # Periodic per-scope parameter snapshots.
            with nn.parameter_scope("gen"):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 "generator_param_%06d.h5" % i))
            with nn.parameter_scope("dis"):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 "discriminator_param_%06d.h5" % i))

        # Training forward
        image, _ = data.next()
        # Scale pixels from [0, 255] to [-0.5, 0.5].
        # NOTE(review): the original comment claimed [-1, 1]; if the
        # generator emits tanh outputs in [-1, 1], the intended scaling
        # may have been image / 127.5 - 1 — confirm against the generator.
        x.d = image / 255. - 0.5
        z.d = np.random.randn(*z.shape)

        # Generator update.
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.weight_decay(args.weight_decay)
        solver_gen.update()
        monitor_fake.add(i, fake)
        monitor_loss_gen.add(i, loss_gen.d.copy())

        # Discriminator update.
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        solver_dis.weight_decay(args.weight_decay)
        solver_dis.update()
        monitor_loss_dis.add(i, loss_dis.d.copy())
        monitor_time.add(i)

    # Final per-scope snapshots at the last iteration index.
    with nn.parameter_scope("gen"):
        nn.save_parameters(
            os.path.join(args.model_save_path, "generator_param_%06d.h5" % i))
    with nn.parameter_scope("dis"):
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         "discriminator_param_%06d.h5" % i))
Example #4
0
def train():
    """
    Main script.

    Train a ResNet classifier on Tiny ImageNet (200 classes, 64x64
    images) using gradient accumulation, with periodic validation,
    parameter snapshots, scheduled learning-rate decay, and a final
    .nnp export verified by a C++ forward check.
    """

    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Dataset
    # We use Tiny ImageNet from Stanford CS231N class.
    # https://tiny-imagenet.herokuapp.com/
    # Tiny ImageNet consists of 200 categories, each category has 500 images
    # in training set. The image size is 64x64. To adapt ResNet into 64x64
    # image inputs, the input image size of ResNet is set as 56x56, and
    # the stride in the first conv and the first max pooling are removed.
    data = data_iterator_tiny_imagenet(args.batch_size, 'train')
    vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')

    num_classes = 200
    tiny = True  # TODO: Switch ILSVRC2012 dataset and TinyImageNet.
    # Separate training/inference graphs sharing the same parameters.
    t_model = get_model(
        args, num_classes, test=False, tiny=tiny)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    v_model = get_model(
        args, num_classes, test=True, tiny=tiny)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=10)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)

    # Training loop.
    for i in range(args.max_iter):
        # Save parameters
        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % args.val_interval == 0:

            # Clear all intermediate memory to save memory.
            # t_model.loss.clear_recursive()

            # `l` / `e` accumulate validation loss and error over val_iter
            # minibatches (NOTE: `l` is easy to misread as `1`).
            l = 0.0
            e = 0.0
            for j in range(args.val_iter):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                # Cast the input buffers down to uint8 pixels / int32
                # labels on the device — presumably to reduce memory
                # traffic; assumes the graph converts internally.
                # TODO(review): confirm.
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                l += v_model.loss.d
                e += categorical_error(v_model.pred.d, v_model.label.d)
            monitor_vloss.add(i, l / args.val_iter)
            monitor_verr.add(i, e / args.val_iter)

            # Clear all intermediate memory to save memory.
            # v_model.loss.clear_recursive()

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        # Gradient accumulation loop: grads are zeroed once per outer
        # iteration, so each backward call accumulates into them.
        for j in range(args.accum_grad):
            images, labels = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            l += t_model.loss.d
            e += categorical_error(t_model.pred.d, t_model.label.d)
        solver.weight_decay(args.weight_decay)
        solver.update()
        monitor_loss.add(i, l / args.accum_grad)
        monitor_err.add(i, e / args.accum_grad)
        monitor_time.add(i)

        # Learning rate decay at scheduled iter
        if i in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)
    nn.save_parameters(os.path.join(args.model_save_path,
                                    'param_%06d.h5' % args.max_iter))

    # Export the validation network to .nnp for runtime inference.
    nnp_file = os.path.join(
        args.model_save_path, 'resnet_%06d.nnp' % (args.max_iter))
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': v_model.pred.shape[0],
             'outputs': {'y': v_model.pred},
             'names': {'x': v_model.image}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x'],
             'output': ['y']}]}
    save.save(nnp_file, runtime_contents)

    # Cross-check the export with a C++ forward pass (uses the last
    # validation input left in v_model.image.d).
    from cpp_forward_check import check_cpp_forward
    check_cpp_forward(args.model_save_path, [v_model.image.d], [
                      v_model.image], v_model.pred, nnp_file)
def train():
    """
    Main script.

    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Inplace allreduce (THIS IS THE MAIN difference from a single device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error

    """

    args = get_args()
    # ILSVRC2012 training-set size.
    # NOTE(review): n_train_samples is currently unused in this function.
    n_train_samples = 1281167
    num_classes = 1000

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    # (sic: "Paralell" is the spelling used by the library API.)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank  # one device per MPI rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Pipelines and Iterators for training
    # Each rank gets its own DALI pipeline, seeded per device.
    train_pipes = [
        TrainPipeline(args.batch_size,
                      args.num_threads,
                      device_id,
                      args.train_cachefile_dir,
                      args.train_list,
                      seed=device_id + 1,
                      num_gpu=n_devices,
                      random_area=args.random_area)
    ]
    train_pipes[0].build()
    data = DALIClassificationIterator(train_pipes,
                                      train_pipes[0].epoch_size("Reader") //
                                      n_devices,
                                      auto_reset=True,
                                      stop_at_epoch=False)
    # Pipelines and Iterators for validation
    val_pipes = [
        ValPipeline(args.batch_size,
                    args.num_threads,
                    device_id,
                    args.val_cachefile_dir,
                    args.val_list,
                    seed=device_id + 1,
                    num_gpu=n_devices)
    ]
    val_pipes[0].build()
    vdata = DALIClassificationIterator(val_pipes,
                                       val_pipes[0].epoch_size("Reader") //
                                       n_devices,
                                       auto_reset=True,
                                       stop_at_epoch=False)
    # Network for training
    t_model = get_model(args,
                        num_classes,
                        n_devices,
                        args.accum_grad,
                        test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    # Top-n error computed on an unlinked copy so it adds no gradients.
    t_pred2 = t_model.pred.get_unlinked_variable(need_grad=False)
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    # Network for validation
    v_model = get_model(args,
                        num_classes,
                        n_devices,
                        args.accum_grad,
                        test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.get_unlinked_variable(need_grad=False)
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_learning_rate(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=1)

    # Training loop
    # Scalar variables used as device buffers so validation stats can be
    # all-reduced across ranks.
    vl = nn.Variable()
    ve = nn.Variable()
    # Iteration counts are divided by n_devices: each rank processes
    # 1/n_devices of the global work.
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters (rank 0 only)
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            ve_local = 0.
            vl_local = 0.
            val_iter_local = args.val_iter // n_devices
            for j in range(val_iter_local):
                nextImage, nextLabel = vdata.next()
                # Bind the iterator's output arrays to the graph inputs.
                v_model.image.data = nextImage
                v_model.label.data = nextLabel
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.d.copy()
                ve_local += v_e.d.copy()
            # Average locally, then average across devices via allreduce.
            vl_local /= val_iter_local
            vl.d = vl_local
            comm.all_reduce(vl.data, division=True, inplace=True)
            ve_local /= val_iter_local
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl.d.copy())
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            # Accumulate scalar loss/error from the current minibatch.
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop (grads zeroed once per outer step,
        # so backward accumulates).
        for j in range(args.accum_grad):
            nextImage, nextLabel = data.next()
            t_model.image.data = nextImage
            t_model.label.data = nextLabel
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            l, e = accumulate_error(l, e, t_model, t_e)

        # AllReduce gradients across devices before the update.
        params = [x.grad for x in nn.get_parameters().values()]
        comm.all_reduce(params, division=False, inplace=False)

        # Update
        solver.weight_decay(args.weight_decay)
        solver.update()

        if device_id == 0:
            monitor_loss.add(i * n_devices, l / args.accum_grad)
            monitor_err.add(i * n_devices, e / args.accum_grad)
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter
        if i * n_devices in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    # Final snapshot (rank 0).
    # NOTE(review): args.max_iter / n_devices is a float under Python 3;
    # '%06d' truncates it — consider integer division (//).
    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'param_%06d.h5' % (args.max_iter / n_devices)))
Example #6
0
def train(args):
    """
    Main script.

    Train a DCGAN on MNIST: the generator maps 100-dim noise to 28x28
    images; the discriminator separates fakes from real digits.
    Supports resuming from a checkpoint, periodic checkpointing, and
    exports both networks as .nnp before and after training.
    """

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    # TRAIN

    # Fake path
    z = nn.Variable([args.batch_size, 100, 1, 1])  # noise input
    fake = generator(z)
    fake.persistent = True  # Not to clear at backward
    pred_fake = discriminator(fake)
    # Generator loss: make the discriminator label fakes as real (1).
    loss_gen = F.mean(
        F.sigmoid_cross_entropy(pred_fake, F.constant(1, pred_fake.shape)))
    # Unlinked copy stops discriminator gradients flowing into the generator.
    fake_dis = fake.get_unlinked_variable(need_grad=True)
    fake_dis.need_grad = True  # TODO: Workaround until v1.0.2
    pred_fake_dis = discriminator(fake_dis)
    # Discriminator loss: fakes -> 0 ...
    loss_dis = F.mean(
        F.sigmoid_cross_entropy(pred_fake_dis,
                                F.constant(0, pred_fake_dis.shape)))

    # Real path
    x = nn.Variable([args.batch_size, 1, 28, 28])
    pred_real = discriminator(x)
    # ... and real images -> 1.
    loss_dis += F.mean(
        F.sigmoid_cross_entropy(pred_real, F.constant(1, pred_real.shape)))

    # Create one solver per sub-network parameter scope.
    solver_gen = S.Adam(args.learning_rate, beta1=0.5)
    solver_dis = S.Adam(args.learning_rate, beta1=0.5)
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())
    start_point = 0

    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint files.
        start_point = load_checkpoint(args.checkpoint, {
            "gen": solver_gen,
            "dis": solver_dis
        })

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss_gen = M.MonitorSeries("Generator loss", monitor, interval=10)
    monitor_loss_dis = M.MonitorSeries("Discriminator loss",
                                       monitor,
                                       interval=10)
    monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100)
    # Map [-1, 1] image values into [0, 1] for display.
    monitor_fake = M.MonitorImageTile("Fake images",
                                      monitor,
                                      normalize_method=lambda x: (x + 1) / 2.)

    data = data_iterator_mnist(args.batch_size, True)

    # Save_nnp: export initial (epoch-0) networks.
    contents = save_nnp({'x': z}, {'y': fake}, args.batch_size)
    save.save(
        os.path.join(args.model_save_path, 'Generator_result_epoch0.nnp'),
        contents)
    contents = save_nnp({'x': x}, {'y': pred_real}, args.batch_size)
    save.save(
        os.path.join(args.model_save_path, 'Discriminator_result_epoch0.nnp'),
        contents)

    # Training loop (resumes from start_point when a checkpoint was loaded).
    for i in range(start_point, args.max_iter):
        if i % args.model_save_interval == 0:
            save_checkpoint(args.model_save_path, i, {
                "gen": solver_gen,
                "dis": solver_dis
            })

        # Training forward
        image, _ = data.next()
        # Scale pixels from [0, 255] to [-0.5, 0.5].
        # NOTE(review): the original comment said [-1, 1]; if the generator
        # outputs tanh values in [-1, 1], image / 127.5 - 1 may have been
        # intended — confirm against the generator.
        x.d = image / 255. - 0.5
        z.d = np.random.randn(*z.shape)

        # Generator update.
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.weight_decay(args.weight_decay)
        solver_gen.update()
        monitor_fake.add(i, fake)
        monitor_loss_gen.add(i, loss_gen.d.copy())

        # Discriminator update.
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        solver_dis.weight_decay(args.weight_decay)
        solver_dis.update()
        monitor_loss_dis.add(i, loss_dis.d.copy())
        monitor_time.add(i)

    # Final per-scope parameter snapshots.
    with nn.parameter_scope("gen"):
        nn.save_parameters(
            os.path.join(args.model_save_path, "generator_param_%06d.h5" % i))
    with nn.parameter_scope("dis"):
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         "discriminator_param_%06d.h5" % i))

    # Save_nnp: export final networks.
    contents = save_nnp({'x': z}, {'y': fake}, args.batch_size)
    save.save(os.path.join(args.model_save_path, 'Generator_result.nnp'),
              contents)
    contents = save_nnp({'x': x}, {'y': pred_real}, args.batch_size)
    save.save(os.path.join(args.model_save_path, 'Discriminator_result.nnp'),
              contents)
Example #7
0
def train():
    """Train a binarized-network classifier on MNIST.

    Parses arguments, builds training and evaluation graphs for the
    selected binary-network variant, then runs the optimization loop
    with periodic validation error reporting and parameter snapshots.
    """
    args = get_args(monitor_path='tmp.monitor.bnn')

    # Computation context from the command-line arguments.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # MNIST data iterators (training / validation splits).
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)

    # Choose the network builder; unrecognized names fall back to the
    # binary-connect LeNet (same default as before).
    predictors = {
        'bincon': mnist_binary_connect_lenet_prediction,
        'binnet': mnist_binary_net_lenet_prediction,
        'bwn': mnist_binary_weight_lenet_prediction,
        'bincon_resnet': mnist_binary_connect_resnet_prediction,
        'binnet_resnet': mnist_binary_net_resnet_prediction,
        'bwn_resnet': mnist_binary_weight_resnet_prediction,
    }
    mnist_cnn_prediction = predictors.get(
        args.net, mnist_binary_connect_lenet_prediction)

    # Training graph.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    pred = mnist_cnn_prediction(image / 255, test=False)
    pred.persistent = True  # keep pred.d readable after backward
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # Evaluation graph (shares parameters with the training graph).
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vpred = mnist_cnn_prediction(vimage / 255, test=True)

    # Optimizer over all current parameters.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Progress monitors.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = M.MonitorSeries("Test error", monitor, interval=10)

    for step in range(args.max_iter):
        if step % args.val_interval == 0:
            # Periodic validation: mean classification error.
            err_sum = 0.0
            for _ in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                err_sum += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(step, err_sum / args.val_iter)
        if step % args.model_save_interval == 0:
            # Periodic parameter snapshot.
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % step))
        # One optimization step: forward, backward, decay, update.
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        # Report training statistics.
        monitor_loss.add(step, loss.d.copy())
        monitor_err.add(step, categorical_error(pred.d, label.d))
        monitor_time.add(step)

    # Final parameter dump at max_iter.
    parameter_file = os.path.join(args.model_save_path,
                                  'params_%06d.h5' % args.max_iter)
    nn.save_parameters(parameter_file)
Example #8
0
def train():
    """Train a graph network to predict per-node degrees of random graphs.

    Builds random train/validation graph datasets, accumulates gradients
    over ``--accum-grad`` examples per solver update, and periodically
    reports a validation "error" defined as the fraction of nodes whose
    predicted degree is within 0.5 of the true degree.

    NOTE(review): relies on module-level names ``rng``, ``random_graph``,
    ``degrees``, ``utils``, ``predict``, ``data_iterator``,
    ``SimpleDataSource2`` and ``M`` -- confirm they are defined/imported
    at file scope.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-train-examples", type=int, default=1600)
    parser.add_argument("--num-valid-examples", type=int, default=100)
    parser.add_argument("--accum-grad", type=int, default=32)
    parser.add_argument("--max-iter", type=int, default=6400)
    parser.add_argument("--valid-interval", type=int, default=100)
    parser.add_argument("--context", type=str, default="cpu")
    parser.add_argument("--device-id", type=int, default=0)

    args = parser.parse_args()

    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    ctx = get_extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # prepare dataset: each entry is ([vertices], [adjacency], [degrees])
    tdataset = []
    for i in range(args.num_train_examples):
        V, E = random_graph(rng)
        deg = degrees(V, E)
        tdataset.append(([V], [utils.from_adjacency_list(E)], [deg]))

    vdataset = []
    for i in range(args.num_valid_examples):
        V, E = random_graph(rng)
        deg = degrees(V, E)
        vdataset.append(([V], [utils.from_adjacency_list(E)], [deg]))

    # prepare data iterators (batch size 1; one graph per step)
    tdata = data_iterator(SimpleDataSource2(tdataset, shuffle=True), 1, False,
                          False, False)
    vdata = data_iterator(SimpleDataSource2(vdataset, shuffle=False), 1, False,
                          False, False)

    # prepare monitors
    monitor = M.Monitor("./degree")
    tloss = M.MonitorSeries("Training Loss", monitor, interval=10)

    verror = M.MonitorSeries("Validation Error", monitor, interval=10)

    # prepare solver
    solver = S.Adam()

    # training loop
    for i in range(args.max_iter):
        l = 0
        for b in range(args.accum_grad):
            # read data
            V, E, degree = tdata.next()
            V = V[0][0]
            E = E[0][0]
            degree = degree[0][0]

            # predict (builds a fresh graph for this example)
            output = predict(V, E)

            # initialize solver once the parameters exist (they are only
            # created by the first predict() call above)
            if i == 0 and b == 0:
                solver.set_parameters(nn.get_parameters())
                solver.zero_grad()  # start from clean gradient buffers

            # calculate loss
            label = nn.Variable(degree.shape)
            label.data.data = degree
            label = F.reshape(label, (len(V), 1))
            loss = F.mean(F.squared_error(output, label))

            # training: backward accumulates gradients across the
            # accum-grad inner loop
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            l += loss.data.data

        solver.update()
        # BUGFIX: gradients were never zeroed, so every iteration kept
        # accumulating on top of all previous iterations' gradients,
        # corrupting each update after the first.
        solver.zero_grad()

        tloss.add(i, l / args.accum_grad)

        if i % args.valid_interval == 0:
            # validation: count nodes whose predicted degree is within
            # 0.5 of the ground truth
            e = 0
            n = 0
            for b in range(vdata.size):
                V, E, degree = vdata.next()
                V = V[0][0]
                E = E[0][0]
                degree = degree[0][0]

                output = predict(V, E)

                label = nn.Variable(degree.shape)
                label.data.data = degree
                label = F.reshape(label, (len(V), 1))
                error = F.sum(F.less_scalar(F.abs(F.sub2(output, label)), 0.5))

                error.forward()

                e += error.data.data
                n += len(V)
            verror.add(i, e / n)
Пример #9
0
def train():
    """
    Main script.

    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Inplace allreduce (THIS IS THE MAIN difference from a single device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error

    """

    args = get_args()
    if args.tiny_mode:
        n_train_samples = 100000
    else:
        n_train_samples = 1281167
    # NOTE(review): n_train_samples is never read below in this function --
    # confirm whether it should feed an LR schedule or can be removed.

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    # NOTE: "Paralell" is the spelling used by the nnabla API itself.
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank  # one process per device: rank doubles as device id
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Per-device RNG; used to shuffle/slice the training data below.
    rng = np.random.RandomState(device_id)
    if args.tiny_mode:
        # We use Tiny ImageNet from Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories, each category has 500 images
        # in training set. The image size is 64x64. To adapt ResNet into 64x64
        # image inputs, the input image size of ResNet is set as 56x56, and
        # the stride in the first conv and the first max pooling are removed.
        # Please check README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories, each category has 1280 images
        # in training set. The image size is various. To adapt ResNet into
        # 320x320 image inputs, the input image size of ResNet is set as
        # 224x224. We need to get tar file and create cache file(320x320 images).
        # Please check README.
        # NOTE(review): tiny mode above does NOT slice its iterators per
        # device -- confirm whether that is intentional.
        data = data_iterator_imagenet(args.batch_size,
                                      args.train_cachefile_dir,
                                      rng=rng)
        # Each worker reads a disjoint slice of the dataset.
        data = data.slice(rng=rng,
                          num_of_slices=n_devices,
                          slice_pos=device_id)
        vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir)
        vdata = vdata.slice(rng=None,
                            num_of_slices=n_devices,
                            slice_pos=device_id)
        num_classes = 1000
    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward

    # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix.
    t_pred2 = t_model.pred.get_unlinked_variable()
    t_pred2.need_grad = False  # error metric must not backprop into the net

    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward

    # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix.
    v_pred2 = v_model.pred.get_unlinked_variable()
    v_pred2.need_grad = False

    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=1)

    # Training loop.
    # Scalar buffers used only to all-reduce validation metrics over devices.
    vl = nn.Variable()
    ve = nn.Variable()
    # Iteration counts and intervals are divided by n_devices because each
    # device processes 1/n_devices of the global work per step.
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters (root process only)
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            ve_local = 0.
            vl_local = 0.
            val_iter_local = args.val_iter // n_devices
            for j in range(val_iter_local):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                # Keep inputs as compact types on the device.
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.d.copy()
                ve_local += v_e.d.copy()
            # Average the local metrics, then average them across devices.
            vl_local /= val_iter_local
            vl.d = vl_local
            comm.all_reduce(vl.data, division=True, inplace=True)
            ve_local /= val_iter_local
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl.d.copy())
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            # Running sums of loss and error over the accum-grad loop.
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            l, e = accumulate_error(l, e, t_model, t_e)

        # AllReduce: sum gradients across devices before the update.
        params = [x.grad for x in nn.get_parameters().values()]
        comm.all_reduce(params, division=False, inplace=False)

        # Update
        solver.weight_decay(args.weight_decay)
        solver.update()

        if device_id == 0:
            monitor_loss.add(i * n_devices, l / args.accum_grad)
            monitor_err.add(i * n_devices, e / args.accum_grad)
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter (in global-iteration units)
        if i * n_devices in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    if device_id == 0:
        # NOTE(review): args.max_iter / n_devices is a float under Python 3;
        # %d truncates it -- confirm the intended final filename.
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'param_%06d.h5' % (args.max_iter / n_devices)))
Пример #10
0
def train():
    """
    Main script for (optionally distributed) semantic-segmentation training.

    Loads pretrained weights (optionally dropping the logits head for
    fine-tuning), builds training/validation graphs with a mask-weighted
    top-n error, and runs a gradient-accumulation loop with optional
    multi-process data parallelism (allreduce of gradients and metrics),
    linear LR warmup and a poly LR decay schedule.
    """

    args = get_args()

    _ = nn.load_parameters(args.pretrained_model_path)
    if args.fine_tune:
        # Drop the classification head so it is re-initialized for the
        # target number of classes.
        nn.parameter.pop_parameter('decoder/logits/affine/conv/W')
        nn.parameter.pop_parameter('decoder/logits/affine/conv/b')

    n_train_samples = args.train_samples
    n_val_samples = args.val_samples
    distributed = args.distributed
    compute_acc = args.compute_acc

    if distributed:
        # Communicator and Context
        from nnabla.ext_utils import get_extension_context
        extension_module = "cudnn"
        ctx = get_extension_context(
            extension_module, type_config=args.type_config)
        # NOTE: "Paralell" is the spelling used by the nnabla API itself.
        comm = C.MultiProcessDataParalellCommunicator(ctx)
        comm.init()
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = mpi_rank
        ctx.device_id = str(device_id)
        nn.set_default_context(ctx)
    else:
        # Get context.
        from nnabla.ext_utils import get_extension_context
        extension_module = args.context
        if args.context is None:
            extension_module = 'cpu'
        logger.info("Running in %s" % extension_module)
        ctx = get_extension_context(
            extension_module, device_id=args.device_id, type_config=args.type_config)
        nn.set_default_context(ctx)
        n_devices = 1
        device_id = 0

    # training data
    data = data_iterator_segmentation(
            args.train_samples, args.batch_size, args.train_dir, args.train_label_dir)
    # validation data
    vdata = data_iterator_segmentation(
        args.val_samples, args.batch_size, args.val_dir, args.val_label_dir)

    if distributed:
        # Each worker reads a disjoint slice of the dataset.
        data = data.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
        vdata = vdata.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
    num_classes = args.num_class

    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(
        args, test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    # Mask-weighted error: only labeled pixels (mask == 1) contribute.
    t_e = F.sum(F.top_n_error(t_pred2, t_model.label, axis=1)
                * t_model.mask) / F.sum(t_model.mask)

    v_model = get_model(
        args, test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    # BUGFIX: the denominator previously used t_model.mask (copy-paste from
    # the training metric above); validation error must be normalized by the
    # validation batch's own mask.
    v_e = F.sum(F.top_n_error(v_pred2, v_model.label, axis=1)
                * v_model.mask) / F.sum(v_model.mask)

    # Create Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Setting warmup: linearly ramp the LR from base_lr up to the full
    # (n_devices-scaled) rate over args.warmup_epoch epochs.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples /
                      args.batch_size / args.accum_grad / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_miou = M.MonitorSeries("mean IOU", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed(
        "Validation time", monitor, interval=1)

    # Training loop (per-device iteration count)
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters (root process only)
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            vmiou_local = 0.
            val_iter_local = n_val_samples // args.batch_size
            vl_local = nn.NdArray()
            vl_local.zero()
            ve_local = nn.NdArray()
            ve_local.zero()
            for j in range(val_iter_local):
                images, labels, masks = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.mask.d = masks
                v_model.image.data.cast(np.float32, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.data
                ve_local += v_e.data
                # Mean IOU computation
                if compute_acc:
                    vmiou_local += compute_miou(num_classes, labels,
                                                np.argmax(v_model.pred.d, axis=1), masks)

            vl_local /= val_iter_local
            ve_local /= val_iter_local
            if compute_acc:
                vmiou_local /= val_iter_local
                vmiou_ndarray = nn.NdArray.from_numpy_array(
                    np.array(vmiou_local))
            if distributed:
                comm.all_reduce(vl_local, division=True, inplace=True)
                comm.all_reduce(ve_local, division=True, inplace=True)
                if compute_acc:
                    comm.all_reduce(vmiou_ndarray, division=True, inplace=True)
                    # BUGFIX: monitor the all-reduced mIOU; previously the
                    # reduced value was computed but the worker-local value
                    # was logged.
                    vmiou_local = vmiou_ndarray.data.copy()

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl_local.data.copy())
                monitor_verr.add(i * n_devices, ve_local.data.copy())
                if compute_acc:
                    monitor_miou.add(i * n_devices, vmiou_local)
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        # Running sums of the error/loss over the accumulation loop.
        e_acc = nn.NdArray(t_e.shape)
        e_acc.zero()
        l_acc = nn.NdArray(t_model.loss.shape)
        l_acc.zero()
        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels, masks = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.mask.d = masks
            t_model.image.data.cast(np.float32, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            e_acc += t_e.data
            l_acc += t_model.loss.data

        # AllReduce: sum gradients, average metrics across devices.
        if distributed:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=False)
            comm.all_reduce(l_acc, division=True, inplace=True)
            comm.all_reduce(e_acc, division=True, inplace=True)
        solver.scale_grad(1./args.accum_grad)
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if distributed:
            # Synchronize by averaging the weights over devices using allreduce
            if (i+1) % args.sync_weight_every_itr == 0:
                weights = [x.data for x in nn.get_parameters().values()]
                comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(
                i * n_devices, (l_acc / args.accum_grad).data.copy())
            monitor_err.add(
                i * n_devices, (e_acc / args.accum_grad).data.copy())
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter --> changed to poly learning rate decay policy
        # if i in args.learning_rate_decay_at:
        solver.set_learning_rate(base_lr * ((1 - i / args.max_iter)**0.1))

    if device_id == 0:
        nn.save_parameters(os.path.join(args.model_save_path,
                                        'param_%06d.h5' % args.max_iter))
Пример #11
0
def main():
    """
    Main script.
    Steps:
    * Setup calculation environment
    * Initialize data iterator.
    * Create Networks
    * Create Solver.
    * Training Loop.
    *   Training
    *   Test
    * Save
    """

    # Parse command line arguments.
    args = get_args(monitor_path='tmp.monitor.vae', max_iter=60000,
                    model_save_path=None, learning_rate=3e-4, batch_size=100, weight_decay=0)

    # Select the computation backend (falls back to CPU when unspecified).
    from nnabla.contrib.context import extension_context
    backend = args.context if args.context is not None else 'cpu'
    logger.info("Running in %s" % backend)
    nn.set_default_context(extension_context(backend, device_id=args.device_id))

    # Data iterators: one for training, one for testing.
    train_iter = data_iterator_mnist(args.batch_size, True)
    test_iter = data_iterator_mnist(args.batch_size, False)

    # Build the VAE graphs; the train and test graphs share one input.
    input_shape = (1, 28, 28)
    latent_shape = (50,)
    x = nn.Variable((args.batch_size,) + input_shape)
    train_loss = vae(x, latent_shape, test=False)
    test_loss = vae(x, latent_shape, test=True)

    # Adam solver over all parameters.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors for training loss, test loss and elapsed time.
    monitor = M.Monitor(args.model_save_path)
    mon_train = M.MonitorSeries("Training loss", monitor, interval=600)
    mon_test = M.MonitorSeries("Test loss", monitor, interval=600)
    mon_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=600)

    for step in range(args.max_iter):

        # One optimization step on a training minibatch.
        solver.zero_grad()
        x.d, _ = train_iter.next()
        train_loss.forward(clear_no_need_grad=True)
        train_loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Evaluate on a test minibatch.
        x.d, _ = test_iter.next()
        test_loss.forward(clear_no_need_grad=True)

        # Log both losses and the wall-clock time.
        mon_train.add(step, train_loss.d.copy())
        mon_test.add(step, test_loss.d.copy())
        mon_time.add(step)

    # Persist the trained parameters.
    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_%06d.h5' % args.max_iter))
Пример #12
0
def train(args):
    """
    Train a "vectorizer" network to invert a frozen DCGAN generator:
    given an image produced by the generator, regress the latent vector
    that generated it (MSE loss between predicted and true latents).
    """

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Frozen generator path: latent z -> fake image.
    z = nn.Variable([args.batch_size, 100, 1, 1])
    gen = generator(z, test=True)
    gen.persistent = True
    with nn.parameter_scope("gen"):
        # NOTE(review): hard-coded absolute path to pretrained generator
        # weights -- consider making this a command-line argument.
        nn.load_parameters(
            "/home/mizuochi/programing/font/dcgan_model_0220/generator_param_290000.h5"
        )

    # Vectorizer path: image -> predicted latent vector.
    x = nn.Variable([args.batch_size, 1, 28, 28])
    vec = nn.Variable([args.batch_size, 100])
    pred_vec = vectorizer(x, test=False)
    loss_dis = F.mean(F.squared_error(pred_vec, vec))

    # Create Solver.
    solver_dis = S.Adam(args.learning_rate, beta1=0.5)
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss_dis = M.MonitorSeries("Discriminator loss",
                                       monitor,
                                       interval=10)
    monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100)

    # Training loop.
    for i in range(args.max_iter):
        # Periodic snapshot of the vectorizer parameters.
        if i % args.model_save_interval == 0:
            with nn.parameter_scope("dis"):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 "vectorizer_param_%06d.h5" % i))

        # BUGFIX: the latent sampling and generator forward pass were
        # accidentally indented under the model-save branch above, so new
        # training data was only produced every model_save_interval
        # iterations (the vectorizer trained on a stale batch otherwise).
        # They must run on every iteration.
        z.d = np.random.randn(*z.shape)
        gen.forward()
        x.d = gen.d
        vec.d = z.d.reshape((args.batch_size, 100))

        # Vectorizer update.
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        solver_dis.weight_decay(args.weight_decay)
        solver_dis.update()
        monitor_loss_dis.add(i, loss_dis.d.copy())
        monitor_time.add(i)

    # Final snapshot.
    # NOTE(review): the final file is named "discriminator_param_..." while
    # periodic snapshots use "vectorizer_param_..." -- kept as-is for
    # backward compatibility; confirm which name downstream tooling expects.
    with nn.parameter_scope("dis"):
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         "discriminator_param_%06d.h5" % i))
Пример #13
0
def train():
    """
    Main script for training.
    """

    args = get_args()

    num_classes = 1000

    # Communicator and Context (the cudnn backend is currently hard-coded).
    from nnabla.ext_utils import get_extension_context
    ctx = get_extension_context("cudnn",
                                device_id=args.device_id,
                                type_config=args.type_config)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    from nnabla_ext.cuda import StreamEventHandler
    stream_event_handler = StreamEventHandler(int(comm.ctx.device_id))

    # Data iterators for training and validation.
    data, vdata = get_data_iterators(args, comm, stream_event_handler)

    # Build the training network and the validation network.
    t_model = get_model(args,
                        num_classes,
                        test=False,
                        channel_last=args.channel_last)
    v_model = get_model(args,
                        num_classes,
                        test=True,
                        channel_last=args.channel_last)

    # Solver
    loss_scaling = args.loss_scaling if args.type_config == 'half' else 1
    # To cancel loss scaling, learning rate is divided by loss_scaling.
    # Note this assumes legacy SGD w/ momentum implementation,
    # otherwise, it is recommended to apply division at gradient itself
    # using scale_grad for example.
    base_learning_rate = args.learning_rate / loss_scaling

    # Weight decay is multiplied by loss_scaling to cancel the effect of
    # loss scaling cancelled at the learning rate, and by the number of
    # processes because all-reduce sums gradients over GPUs before the
    # decay is applied.
    weight_decay = args.weight_decay * loss_scaling * comm.n_procs
    solver = MomentumNoWeightDecayBn(base_learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Stepwise learning rate schedule (decay factor 0.1) with warmup.
    learning_rate_scheduler = LearningRateScheduler(
        base_learning_rate, args.learning_rate_decay_at, 0.1,
        args.warmup_epochs)

    # Monitoring happens only on the root process.
    monitor = None
    if comm.rank == 0:
        if not os.path.isdir(args.monitor_path):
            os.makedirs(args.monitor_path)
        monitor = M.Monitor(args.monitor_path)

    # Per-epoch runners for training and (optional) validation.
    train_epoch = EpochTrainer(t_model, solver, learning_rate_scheduler, data,
                               comm, monitor, loss_scaling, weight_decay,
                               stream_event_handler)
    val_epoch = None
    if args.val_interval > 0:
        val_epoch = EpochValidator(v_model, vdata, comm, monitor,
                                   stream_event_handler)

    # Epoch loop
    for epoch in range(args.max_epochs):
        # Periodic parameter snapshot (root process only).
        if comm.rank == 0 \
           and epoch > 0 \
           and epoch % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.monitor_path, 'param_%03d.h5' % epoch))

        # Periodic validation before this epoch's training pass.
        if val_epoch is not None \
           and epoch > 0 \
           and epoch % args.val_interval == 0:
            val_epoch.run(epoch)

        # Train over one epoch of examples.
        train_epoch.run(epoch)

    # Final validation pass.
    if val_epoch is not None:
        val_epoch.run(args.max_epochs)

    # Save the final model (root process only).
    if comm.rank == 0:
        nn.save_parameters(
            os.path.join(args.monitor_path,
                         'param_%03d.h5' % (args.max_epochs)))
Пример #14
0
def animate(args):
    """Animate a still source image with the motion of a driving video.

    Uses a pretrained First Order Motion Model: a keypoint detector
    extracts keypoints from the source image, the first driving frame,
    and the current driving frame; the occlusion-aware generator then
    warps the source image toward each driving frame's keypoints.
    The result is saved as an mp4 video (or as png frames).

    Args:
        args: parsed command-line namespace. Fields read here include
            context, config, params, source, driving, out_dir, fps,
            detailed, full, only_generated, output_png,
            adapt_movement_scale, unuse_relative_movement and
            unuse_relative_jacobian.
    """

    # get context
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to suppress minor messages

    # No config given: fall back to the officially distributed pretrained
    # model; its config and weights are downloaded together, so giving
    # --params without --config is rejected.
    if not args.config:
        assert not args.params, "pretrained weights file is given, but corresponding config file is not. Please give both."
        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/voxceleb_trained_info.yaml")
        args.config = 'voxceleb_trained_info.yaml'

        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/pretrained_fomm_params.h5")

    config = read_yaml(args.config)

    dataset_params = config.dataset_params
    model_params = config.model_params

    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    # Locate parameters: explicit --params wins; otherwise use the training
    # log directory recorded in the config.
    if not args.params:
        assert "log_dir" in config, "no log_dir found in config. therefore failed to locate pretrained parameters."
        param_file = os.path.join(
            config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    print(f"Loading {param_file} for image animation...")
    nn.load_parameters(param_file)

    # Input variables: one (1, C, H, W) image each for the source image,
    # the first driving frame, and the current driving frame.
    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    filename = args.driving

    # process repeated until all the test data is used
    driving_video = read_video(
        filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
    driving_video = np.transpose(
        driving_video, (0, 3, 1, 2))  # (#frames, 3, h, w)

    source_img = imread(args.source, channel_first=True,
                        size=(256, 256)) / 255.
    source_img = source_img[:3]  # keep the first three channels only

    source.d = np.expand_dims(source_img, 0)
    driving_initial.d = driving_video[0][:3, ]

    # The three keypoint-detection graphs share parameters under the
    # "kp_detector" scope; outputs are made persistent so their buffers
    # survive clear_buffer forward passes.
    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True, comm=False)
        persistent_all(kp_source)

    with nn.parameter_scope("kp_detector"):
        kp_driving_initial = detect_keypoint(driving_initial,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=True, comm=False)
        persistent_all(kp_driving_initial)

    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True, comm=False)
        persistent_all(kp_driving)

    # Optionally rescale the driving motion by the sqrt ratio of the
    # convex-hull areas of the source vs. initial-driving keypoints.
    if args.adapt_movement_scale:
        nn.forward_all([kp_source["value"],
                        kp_source["jacobian"],
                        kp_driving_initial["value"],
                        kp_driving_initial["jacobian"]])
        source_area = ConvexHull(kp_source['value'][0].d).volume
        driving_area = ConvexHull(kp_driving_initial['value'][0].d).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    # Normalize driving keypoints relative to the first driving frame.
    # NOTE(review): unlink_all presumably detaches the frame-constant
    # keypoint graphs so they are not recomputed per frame — confirm.
    kp_norm = adjust_kp(kp_source=unlink_all(kp_source), kp_driving=kp_driving,
                        kp_driving_initial=unlink_all(kp_driving_initial),
                        adapt_movement_scale=adapt_movement_scale,
                        use_relative_movement=args.unuse_relative_movement,
                        use_relative_jacobian=args.unuse_relative_jacobian)
    persistent_all(kp_norm)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_norm,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True, comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info

    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source
    generated['kp_norm'] = kp_norm

    # generated contains these values;
    # 'mask': <Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': <Variable((bs, num_kp+1, num_channel, h/4, w/4))  # (bs, num_kp + 1, c, h, w)
    # 'occlusion_map': <Variable((bs, 1, h/4, w/4))
    # 'deformed': <Variable((bs, c, h, w))
    # 'prediction': <Variable((bs, c, h, w))

    mode = "arbitrary"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir, os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)

    # load the header images.
    header = imread("imgs/header_combined.png", channel_first=True)
    generated_images = list()

    # compute these in advance and reuse
    nn.forward_all([kp_source["value"],
                    kp_source["jacobian"]],
                   clear_buffer=True)
    nn.forward_all([kp_driving_initial["value"],
                    kp_driving_initial["jacobian"]],
                   clear_buffer=True)

    num_of_driving_frames = driving_video.shape[0]

    # Per-frame generation: feed each driving frame and run the generator.
    for frame_idx in tqdm(range(num_of_driving_frames)):
        driving.d = driving_video[frame_idx][:3, ]
        nn.forward_all([generated["prediction"],
                        generated["deformed"]], clear_buffer=True)

        if args.detailed:
            # visualize source w/kp, driving w/kp, deformed source, generated w/kp, generated image, occlusion map
            visualization = visualizer.visualize(
                source=source.d, driving=driving.d, out=generated)
            if args.full:
                visualization = reshape_result(visualization)  # (H, W, C)
            combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)

        elif args.only_generated:
            combined_image = np.clip(generated["prediction"].d[0], 0.0, 1.0)
            combined_image = (255*combined_image).astype(np.uint8)  # (C, H, W)

        else:
            # visualize source, driving, and generated image
            driving_fake = np.concatenate([np.clip(driving.d[0], 0.0, 1.0),
                                           np.clip(generated["prediction"].d[0], 0.0, 1.0)], axis=2)
            header_source = np.concatenate([np.clip(header / 255., 0.0, 1.0),
                                            np.clip(source.d[0], 0.0, 1.0)], axis=2)
            combined_image = np.concatenate(
                [header_source, driving_fake], axis=1)
            combined_image = (255*combined_image).astype(np.uint8)

        generated_images.append(combined_image)

    # once each video is generated, save it.
    output_filename = f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
    output_filename = f"{os.path.basename(args.source)}_by_{output_filename}"
    output_filename = output_filename.replace("#", "_")
    if args.output_png:
        monitor_vis = nm.MonitorImage(output_filename, nm.Monitor(result_dir),
                                      interval=1, num_images=1,
                                      normalize_method=lambda x: x)
        for frame_idx, img in enumerate(generated_images):
            monitor_vis.add(frame_idx, img)
    else:
        generated_images = [_.transpose(1, 2, 0) for _ in generated_images]
        # you might need to change ffmpeg_params according to your environment.
        mimsave(f'{os.path.join(result_dir, output_filename)}', generated_images,
                fps=args.fps,
                ffmpeg_params=["-pix_fmt", "yuv420p",
                               "-vcodec", "libx264",
                               "-f", "mp4",
                               "-q", "0"])

    return
Пример #15
0
def main():
    """Test entry point for the transformer networks.

    Parses the CLI, sets up the nnabla context, builds test-mode data
    iterators for the source and target celebrities, then runs
    ``test_transformer`` with the latest (or explicitly given) A2B
    generator parameters.
    """
    # CLI: positional config path, optional parameter file, optional test size.
    parser = argparse.ArgumentParser()
    parser.add_argument('config', default=None, type=str)
    parser.add_argument('--param-file', default=None, type=str)
    parser.add_argument('--num-test', '-n', default=None, type=int)
    args = parser.parse_args()
    param_file = args.param_file

    config = load_transformer_config(args.config)
    config["num_test"] = args.num_test

    #########################
    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {config["context"]}.')
    nn.set_default_context(
        get_extension_context(config["context"],
                              device_id=config["device_id"]))
    #########################

    # Data Loading: one test-mode iterator per celebrity; everything but
    # the celebrity name is shared between the two.
    logger.info('Initializing Datasource')
    iterator_kwargs = dict(
        dataset_mode="transformer",
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode="test",
        batch_size=config["test"]["batch_size"],
        shuffle=False,
        with_memory_cache=config["test"]["with_memory_cache"],
        with_file_cache=config["test"]["with_file_cache"],
        resize_size=config["preprocess"]["resize_size"],
        line_thickness=config["preprocess"]["line_thickness"],
        gaussian_kernel=config["preprocess"]["gaussian_kernel"],
        gaussian_sigma=config["preprocess"]["gaussian_sigma"])
    train_iterators = (
        data.celebv_data_iterator(celeb_name=config["src_celeb_name"],
                                  **iterator_kwargs),
        data.celebv_data_iterator(celeb_name=config["trg_celeb_name"],
                                  **iterator_kwargs))

    # monitor: results go under <logdir>/transformer/<src>2<trg>/<experiment>.
    monitor = nm.Monitor(
        os.path.join(config["test"]["logdir"], "transformer",
                     f'{config["src_celeb_name"]}2{config["trg_celeb_name"]}',
                     config["experiment_name"]))

    # Network: both directions use the same generator architecture.
    netG = {
        'netG_A2B': models.netG_transformer,
        'netG_B2A': models.netG_transformer
    }

    # Parameter file: explicit --param-file wins; otherwise pick the most
    # recently modified A2B snapshot from the training log directory.
    if param_file:
        param_file_A2B = param_file
    else:
        snapshots = glob.glob(
            os.path.join(
                config["logdir"], "transformer",
                f'{config["src_celeb_name"]}2{config["trg_celeb_name"]}',
                config["experiment_name"], "netG_transformer_A2B_*"))
        param_file_A2B = sorted(snapshots, key=os.path.getmtime)[-1]

    test_transformer(config, netG, train_iterators, monitor, param_file_A2B)
Пример #16
0
def train():
    """
    Main script.

    Trains a ResNet image classifier on ImageNet (or Tiny ImageNet with
    --tiny-mode) using momentum SGD, gradient accumulation over
    args.accum_grad mini-batches, periodic validation, periodic parameter
    snapshots, and step learning-rate decay at scheduled iterations.
    """

    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    if args.tiny_mode:
        # We use Tiny ImageNet from Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories, each category has 500 images
        # in training set. The image size is 64x64. To adapt ResNet into 64x64
        # image inputs, the input image size of ResNet is set as 56x56, and
        # the stride in the first conv and the first max pooling are removed.
        # Please check README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories, each category has 1280 images
        # in training set. The image size is various. To adapt ResNet into
        # 320x320 image inputs, the input image size of ResNet is set as
        # 224x224. We need to get tar file and create cache file(320x320 images).
        # Please check README.
        data = data_iterator_imagenet(args.batch_size,
                                      args.train_cachefile_dir)
        vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir)
        num_classes = 1000

    # Training model graph; pred kept persistent so the top-n error graph
    # below can reuse its buffer.
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward

    # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix.
    t_pred2 = t_model.pred.get_unlinked_variable()
    t_pred2.need_grad = False

    # Training top-n error (computed on the unlinked pred so no gradient flows).
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))

    # Validation model graph and its top-n error, built the same way.
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward

    # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix.
    v_pred2 = v_model.pred.get_unlinked_variable()
    v_pred2.need_grad = False

    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=10)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=10)

    # Training loop.
    for i in range(args.max_iter):
        # Save parameters
        if i % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % args.val_interval == 0 and i != 0:

            # Clear all intermediate memory to save memory.
            # t_model.loss.clear_recursive()

            l = 0.0
            e = 0.0
            for j in range(args.val_iter):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                # Cast inputs to compact dtypes on the device.
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                l += v_model.loss.d
                e += v_e.d
            monitor_vloss.add(i, l / args.val_iter)
            monitor_verr.add(i, e / args.val_iter)
            monitor_vtime.add(i)

            # Clear all intermediate memory to save memory.
            # v_model.loss.clear_recursive()

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        # Accumulate host-side loss/error; reading .d synchronizes with the
        # device, which is why calls to this are deliberately deferred below.
        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels = data.next()
            if j != 0:
                # Update e and l according to previous results of forward
                # propagation.
                # The update of last iteration is performed
                # after solver update to avoid unnecessary CUDA synchronization.
                # This is performed after data.next() in order to overlap
                # the data loading and graph execution.
                # TODO: Move this to the bottom of the loop when prefetch
                # data loader is available.
                l, e = accumulate_error(l, e, t_model, t_e)
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)

        solver.weight_decay(args.weight_decay)
        solver.update()

        # Accumulate errors after solver update
        l, e = accumulate_error(l, e, t_model, t_e)

        monitor_loss.add(i, l / args.accum_grad)
        monitor_err.add(i, e / args.accum_grad)
        monitor_time.add(i)

        # Learning rate decay at scheduled iter
        if i in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    # Save the final parameters after the last iteration.
    nn.save_parameters(
        os.path.join(args.model_save_path, 'param_%06d.h5' % args.max_iter))
Пример #17
0
def main():
    """Train the transformer networks (A2B and B2A) on CelebV data.

    Loads a config, sets up the nnabla context, builds training data
    iterators for the source and target celebrities, then launches
    training with Adam solvers. The discriminator solvers use half the
    generator learning rate (the 0.5 factor below).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default=None, type=str)
    parser.add_argument('--info', default=None, type=str)
    args = parser.parse_args()

    config = load_transformer_config(args.config)
    if args.info:
        # Suffix the experiment name so annotated runs get their own logdir.
        config["experiment_name"] += args.info

    pprint.pprint(config)

    #########################
    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {config["context"]}.')
    ctx = get_extension_context(config["context"],
                                device_id=config["device_id"])
    nn.set_default_context(ctx)
    #########################

    # Data Loading
    # Fixed log message typo ("Initialing" -> "Initializing"); now
    # consistent with the test script.
    logger.info('Initializing Datasource')
    train_iterator_src = data.celebv_data_iterator(
        dataset_mode="transformer",
        celeb_name=config["src_celeb_name"],
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode=config["mode"],
        batch_size=config["train"]["batch_size"],
        shuffle=config["train"]["shuffle"],
        with_memory_cache=config["train"]["with_memory_cache"],
        with_file_cache=config["train"]["with_file_cache"],
        resize_size=config["preprocess"]["resize_size"],
        line_thickness=config["preprocess"]["line_thickness"],
        gaussian_kernel=config["preprocess"]["gaussian_kernel"],
        gaussian_sigma=config["preprocess"]["gaussian_sigma"])

    train_iterator_trg = data.celebv_data_iterator(
        dataset_mode="transformer",
        celeb_name=config["trg_celeb_name"],
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode=config["mode"],
        batch_size=config["train"]["batch_size"],
        shuffle=config["train"]["shuffle"],
        with_memory_cache=config["train"]["with_memory_cache"],
        with_file_cache=config["train"]["with_file_cache"],
        resize_size=config["preprocess"]["resize_size"],
        line_thickness=config["preprocess"]["line_thickness"],
        gaussian_kernel=config["preprocess"]["gaussian_kernel"],
        gaussian_sigma=config["preprocess"]["gaussian_sigma"])
    train_iterators = (train_iterator_src, train_iterator_trg)

    # monitor: logs go under <logdir>/transformer/<src>2<trg>/<experiment>.
    monitor = nm.Monitor(
        os.path.join(config["logdir"], "transformer",
                     f'{config["src_celeb_name"]}2{config["trg_celeb_name"]}',
                     config["experiment_name"]))

    # Network: generators and discriminators for both directions.
    netG = {
        'netG_A2B': models.netG_transformer,
        'netG_B2A': models.netG_transformer
    }
    netD = {
        'netD_A': models.netD_transformer,
        'netD_B': models.netD_transformer
    }

    # Optimizer: one Adam solver per network.
    solver_netG = {
        'netG_A2B':
        S.Adam(alpha=config["train"]["lr"],
               beta1=config["train"]["beta1"],
               beta2=config["train"]["beta2"]),
        'netG_B2A':
        S.Adam(alpha=config["train"]["lr"],
               beta1=config["train"]["beta1"],
               beta2=config["train"]["beta2"])
    }

    solver_netD = {
        'netD_A':
        S.Adam(alpha=0.5 * config["train"]["lr"],
               beta1=config["train"]["beta1"],
               beta2=config["train"]["beta2"]),
        'netD_B':
        S.Adam(alpha=0.5 * config["train"]["lr"],
               beta1=config["train"]["beta1"],
               beta2=config["train"]["beta2"])
    }

    train_transformer(config, netG, netD, solver_netG, solver_netD,
                      train_iterators, monitor)