Beispiel #1
0
def test_graph_logreg(seed):
    """Numerical-vs-analytical gradient check for a small logistic-regression graph."""
    rng = np.random.RandomState(seed)

    # Inputs: data, weights and bias need gradients; labels do not.
    x = nn.Variable([2, 3, 4], need_grad=True)
    w = nn.Variable([12, 5], need_grad=True)
    b = nn.Variable([5], need_grad=True)
    t = nn.Variable([2, 1])
    for var in (x, w, b):
        var.d = rng.randn(*var.shape)
    t.d = rng.randint(0, 5, size=t.shape)

    nn.set_default_context(nn.Context())

    # Build the graph eagerly (auto-forward executes each op on creation).
    with nn.auto_forward():
        logits = F.affine(x, w, b, 1)
        per_sample_loss = F.softmax_cross_entropy(logits, t, 1)
        L = F.mean(per_sample_loss)

    # Gradients are accumulated, so they must be zeroed before backprop.
    x.g = 0
    w.g = 0
    b.g = 0
    L.backward(clear_buffer=True)
    # Overwrite x.g with garbage; the gradient check below must not depend on it.
    x.g = rng.randn(*x.shape)

    inputs = [x, w, b]

    from nbla_test_utils import \
        compute_analytical_and_numerical_grad_graph as grads
    agrad, ngrad = grads(L, inputs, 1e-3)
    assert np.allclose(ngrad, agrad, atol=1e-2)
Beispiel #2
0
def test_graph_model(model, seed):
    """Gradient check for small MLP / recurrent / convolutional graphs."""
    np.random.seed(313)
    rng = np.random.RandomState(seed)
    x = nn.Variable([2, 3, 4, 4], need_grad=True)
    t = nn.Variable([2, 1])
    x.d = rng.randn(*x.shape)
    t.d = rng.randint(0, 5, size=t.shape)

    nn.set_default_context(nn.Context())

    # Build the network requested by `model`.
    nn.clear_parameters()
    if model == "mlp":
        with nn.parameter_scope('fc1'):
            h = PF.affine(x, 3)
        h = F.relu(h, inplace=True)
        with nn.parameter_scope('fc2'):
            logits = PF.affine(h, 5)
    elif model == "recurrent":
        with nn.parameter_scope('fc1'):
            h = PF.affine(x, 3)
            h = F.relu(h, inplace=True)
        # Two steps sharing the 'fc2' parameters (weight tying).
        for _ in range(2):
            with nn.parameter_scope('fc2'):
                h = PF.affine(h, 3)
                h = F.relu(h, inplace=True)
        with nn.parameter_scope('fc3'):
            logits = PF.affine(h, 5)
    elif model == "convolution":
        with nn.parameter_scope('conv1'):
            h = PF.convolution(x, 3, (2, 2))
            h = F.relu(h, inplace=True)
        with nn.parameter_scope('fc2'):
            logits = PF.affine(h, 5)
    else:
        raise ValueError()
    L = F.mean(F.softmax_cross_entropy(logits, t, 1))

    # Forward pass.
    L.forward(clear_no_need_grad=True)

    # Gradients are accumulated, so zero x's before backprop.
    x.grad.zero()
    L.backward(clear_buffer=True)
    # Fill x.g with garbage; the check below must not depend on it.
    x.g = rng.randn(*x.shape)
    parameters = nn.get_parameters()
    for param in parameters.values():
        param.grad.zero()
    inputs = [x] + list(parameters.values())

    from nbla_test_utils import \
        compute_analytical_and_numerical_grad_graph as grads
    agrad, ngrad = grads(L, inputs, 1e-3)
    assert np.allclose(ngrad, agrad, atol=1.05e-2)
Beispiel #3
0
def test_graph_clear_buffer(seed):
    """Gradients must be identical regardless of the buffer-clearing flags."""
    np.random.seed(313)
    rng = np.random.RandomState(seed)
    x = nn.Variable([2, 3, 4, 4])
    t = nn.Variable([2, 1])
    x.d = rng.randn(*x.shape)
    t.d = rng.randint(0, 5, size=t.shape)

    # Network definition.
    nn.set_default_context(nn.Context())
    nn.clear_parameters()
    x1 = x + 1
    x2 = x1 - 1
    with nn.parameter_scope('conv1'):
        h = PF.convolution(x2, 3, (2, 2))
        h = F.relu(h, inplace=True)
    with nn.parameter_scope('fc2'):
        logits = PF.affine(h, 5)
    L = F.mean(F.softmax_cross_entropy(logits, t, 1))

    # Snapshot the parameters so every flag combination starts identically.
    import tempfile
    import os
    tmpd = tempfile.mkdtemp()
    nn.save_parameters(os.path.join(tmpd, 'parameter.h5'))
    baseline = None
    for cnng in [False, True]:
        for cb in [False, True]:
            _ = nn.load_parameters(os.path.join(tmpd, 'parameter.h5'))
            for v in nn.get_parameters().values():
                v.grad.zero()
            L.forward(clear_no_need_grad=cnng)
            L.backward(clear_buffer=cb)
            grad = list(nn.get_parameters().values())[0].g.copy()
            if baseline is None:
                # First combination establishes the reference gradient.
                baseline = grad
            else:
                assert np.all(baseline == grad)
Beispiel #4
0
def train(args):
    """
    Train an MNIST LeNet siamese network with contrastive loss.

    Builds separate train/validation graphs sharing parameters, runs an
    Adam training loop with periodic validation and checkpointing, saves
    the final parameters, exports the validation network to an .nnp file
    and cross-checks the C++ forward path against the Python result.
    """

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    margin = 1.0  # Margin for contrastive loss.

    # TRAIN
    # Create input variables.
    image0 = nn.Variable([args.batch_size, 1, 28, 28])
    image1 = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size])
    # Create prediction graph.
    pred = mnist_lenet_siamese(image0, image1, test=False)
    # Create loss function.
    loss = F.mean(contrastive_loss(pred, label, margin))

    # TEST
    # Create input variables (parameters are shared with the train graph).
    vimage0 = nn.Variable([args.batch_size, 1, 28, 28])
    vimage1 = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size])
    # Create prediction graph.
    vpred = mnist_lenet_siamese(vimage0, vimage1, test=True)
    vloss = F.mean(contrastive_loss(vpred, vlabel, margin))

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_vloss = M.MonitorSeries("Test loss", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    rng = np.random.RandomState(313)
    data = siamese_data_iterator(args.batch_size, True, rng)
    vdata = siamese_data_iterator(args.batch_size, False, rng)
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation: average the test loss over val_iter batches.
            ve = 0.0
            for j in range(args.val_iter):
                vimage0.d, vimage1.d, vlabel.d = vdata.next()
                vloss.forward(clear_buffer=True)
                ve += vloss.d
            monitor_vloss.add(i, ve / args.val_iter)
        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))
        image0.d, image1.d, label.d = data.next()
        solver.zero_grad()
        # Training forward, backward and update.
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        monitor_loss.add(i, loss.d.copy())
        monitor_time.add(i)

    # Save the final parameters.
    parameter_file = os.path.join(
        args.model_save_path, 'params_%06d.h5' % args.max_iter)
    nn.save_parameters(parameter_file)

    # Export the validation network plus an executor definition as .nnp.
    nnp_file = os.path.join(
        args.model_save_path, 'siamese_%06d.nnp' % (args.max_iter))
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': args.batch_size,
             'outputs': {'y': vpred},
             'names': {'x0': vimage0, 'x1': vimage1}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x0', 'x1'],
             'output': ['y']}]}
    save.save(nnp_file, runtime_contents)

    # Verify the exported model: C++ forward output must match Python's.
    from cpp_forward_check import check_cpp_forward
    check_cpp_forward(args.model_save_path, [vimage0.d, vimage1.d], [
                      vimage0, vimage1], vpred, nnp_file)
Beispiel #5
0
    return h


def gan_loss(p_fake, p_real=None):
    """WGAN-style loss: critic loss when p_real is given, else generator loss."""
    if p_real is None:
        # Generator objective: maximize the critic score on fakes.
        return -F.mean(p_fake)
    # Critic objective: separate fake scores from real scores.
    return F.mean(p_fake) - F.mean(p_real)


if __name__ == '__main__':
    # Smoke-test the generator and discriminator graph construction.
    # Config
    b, c, h, w = 4, 3, 32, 32
    latent = 128
    eps = np.random.rand()  # NOTE(review): unused below — confirm intent.
    ctx = get_extension_context("cudnn")
    nn.set_default_context(ctx)

    z = nn.Variable.from_numpy_array(np.random.randn(b, latent))
    # NOTE(review): the / 127.5 - 1.0 applies graph ops to the Variable rather
    # than normalizing the numpy array, and randn is not in [0, 255] — verify
    # whether image-style normalization was actually intended here.
    x_real = nn.Variable.from_numpy_array(np.random.randn(b, c, h,
                                                          w)) / 127.5 - 1.0

    # Fake sample
    print("# Fake sample")
    x_fake = generator(z, test=False)
    print(x_fake)

    # Prob for fake sample
    print("# Prob for fake sample")
    p_fake = discriminator(x_fake)
    print(p_fake)
Beispiel #6
0
def train(args):
    """
    Train a vectorizer that inverts a pre-trained DCGAN generator.

    A frozen generator maps latent vectors ``z`` to images; the vectorizer
    is trained with a squared-error loss to recover ``z`` from the
    generated image. Checkpoints are written every
    ``args.model_save_interval`` iterations and once at the end.
    """

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Fake path: frozen generator producing training images from latents.
    z = nn.Variable([args.batch_size, 100, 1, 1])
    gen = generator(z, test=True)
    gen.persistent = True
    with nn.parameter_scope("gen"):
        nn.load_parameters(
            "/home/mizuochi/programing/font/dcgan_model_0220/generator_param_290000.h5"
        )

    # Real path: vectorizer predicting the latent vector from an image.
    x = nn.Variable([args.batch_size, 1, 28, 28])
    vec = nn.Variable([args.batch_size, 100])
    pred_vec = vectorizer(x, test=False)
    loss_dis = F.mean(F.squared_error(pred_vec, vec))

    # Create Solver; only the vectorizer ("dis" scope) parameters are updated.
    solver_dis = S.Adam(args.learning_rate, beta1=0.5)
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss_dis = M.MonitorSeries("Discriminator loss",
                                       monitor,
                                       interval=10)
    monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100)

    # Training loop.
    for i in range(args.max_iter):
        if i % args.model_save_interval == 0:
            with nn.parameter_scope("dis"):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 "vectorizer_param_%06d.h5" % i))

        # Training forward: sample a fresh latent batch and run the generator
        # on EVERY iteration. (Bug fix: these two lines were nested under the
        # model_save_interval branch above, so the same generated batch was
        # silently reused for every iteration between checkpoints.)
        z.d = np.random.randn(*z.shape)
        gen.forward()
        x.d = gen.d
        vec.d = z.d.reshape((args.batch_size, 100))

        # Vectorizer update.
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        solver_dis.weight_decay(args.weight_decay)
        solver_dis.update()
        monitor_loss_dis.add(i, loss_dis.d.copy())
        monitor_time.add(i)

    # Final snapshot. NOTE(review): the file prefix says "discriminator" while
    # periodic snapshots use "vectorizer" — kept as-is for compatibility.
    with nn.parameter_scope("dis"):
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         "discriminator_param_%06d.h5" % i))
Beispiel #7
0
def main():
    """
    Run DeepLab v3+ semantic-segmentation inference on a single image.

    Loads class labels and pretrained parameters, builds the network,
    preprocesses the input image, times one forward pass (including device
    synchronization), post-processes the argmax mask back to the original
    image size, prints the segmented classes and visualizes the result.
    """
    args = get_args()

    # Get context
    from nnabla.ext_utils import get_extension_context, import_extension_module
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    ext = import_extension_module(args.context)

    # Read the label file: one class name per line, indexed by class id.
    # Use a context manager so the handle is closed even on error
    # (the original code leaked the open file).
    with open(args.label_file_path, "r") as f:
        labels_dict = f.readlines()

    # Load parameters
    _ = nn.load_parameters(args.model_load_path)

    # Build a Deeplab v3+ network
    x = nn.Variable((1, 3, args.image_width, args.image_width),
                    need_grad=False)
    y = net.deeplabv3plus_model(x,
                                args.output_stride,
                                args.num_class,
                                test=True)

    # Preprocess image: resize/normalize, then HWC -> NCHW with a batch dim.
    image = imageio.imread(args.test_image_file, as_gray=False, pilmode="RGB")
    orig_h, orig_w, orig_c = image.shape
    old_size = (orig_h, orig_w)

    input_array = image_preprocess.preprocess_image_and_label(
        image, label=None, target_width=args.image_width, train=False)
    print('Input', input_array.shape)
    input_array = np.transpose(input_array, (2, 0, 1))
    input_array = np.reshape(
        input_array,
        (1, input_array.shape[0], input_array.shape[1], input_array.shape[2]))

    # Compute inference and inference time (including device sync).
    t = time.time()

    x.d = input_array
    y.forward(clear_buffer=True)
    print("done")
    available_devices = ext.get_devices()
    ext.device_synchronize(available_devices[0])
    ext.clear_memory_cache()

    elapsed = time.time() - t
    print('Inference time : %s seconds' % (elapsed))

    output = np.argmax(y.d, axis=1)  # (batch, h, w) class-id map

    # Apply post processing
    post_processed = post_process(output[0], old_size, args.image_width)

    # Get the classes predicted
    predicted_classes = np.unique(post_processed)
    for i in range(predicted_classes.shape[0]):
        print('Classes Segmented: ', labels_dict[predicted_classes[i]])

    # Visualize inference result
    visualize(post_processed)
Beispiel #8
0
def test(opt):
    """ Validate opt.checkpoint if opt.checkpoint is valid file.
    Otherwise, if opt.checkpoint_dir is set, it will search all params.h5 files and validate all of them.
    The mAP of each checkpoint will be monitored and output to opt.checkpoint_dir.

    Args:
        opt: Options

    Returns:
        None. mAP values are written to a monitor in opt.checkpoint_dir.
    """
    def test_cur_checkpoint(opt):
        # Run one full validation pass with a detector built from
        # opt.checkpoint and return the resulting mAP.
        # NOTE: val_loader / val_source are closure variables assigned
        # below, before this function is first called.
        if opt.checkpoint == '':
            print("Please provide trained model")
            return

        Detector = detector_factory[opt.task]
        detector = Detector(opt)

        results = {}
        num_iters = val_loader.size
        pbar = trange(num_iters, desc="[Test]")
        for ind in pbar:
            img_id = val_source.images[ind]
            img_info = val_source.coco.loadImgs(ids=[img_id])[0]
            img_path = os.path.join(val_source.img_dir, img_info['file_name'])
            ret = detector.run(img_path)
            results[img_id] = ret['results']
        mAP = val_source.run_eval(results, opt.save_dir, opt.data_dir)
        del detector  # release the detector before the next checkpoint
        return mAP

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpus_str
    # Set up the default context unless running on CPU.
    if opt.extension_module != 'cpu':
        if opt.mixed_precision:
            ctx = get_extension_context(
                opt.extension_module, device_id="0", type_config="half")
        else:
            ctx = get_extension_context(opt.extension_module, device_id="0")
        nn.set_default_context(ctx)

    # Inference runs eagerly (define-by-run).
    nn.set_auto_forward(True)
    source_factory = get_data_source(opt.dataset)
    val_source = source_factory(opt, 'val', shuffle=False)
    batch_size = 1
    val_loader = data_iterator(val_source,
                               batch_size,
                               with_memory_cache=True,
                               with_file_cache=False
                               )

    if os.path.isdir(opt.checkpoint_dir) and os.path.exists(opt.checkpoint_dir):
        # Validate every params.h5 under checkpoint_dir and log per-epoch mAP.
        dir_path = opt.checkpoint_dir
        checkpoints_to_run = recursive_glob(dir_path, "params.h5")
        monitor = Monitor(dir_path)
        monitor_map = MonitorSeries(
            "Val mAP", monitor, interval=1, verbose=False)

        for cur_file in checkpoints_to_run:
            opt.checkpoint = cur_file
            mAP = test_cur_checkpoint(opt)
            folder_name = os.path.basename(os.path.dirname(cur_file))
            # The folder name format is defined in trains/ctdet.py.
            # Format: file_name = os.path.join(path, "epoch_" + str(epoch).zfill(3))
            epoch_num = int(folder_name.replace("epoch_", ""))
            monitor_map.add(epoch_num, mAP)

    else:
        # Single-checkpoint validation.
        test_cur_checkpoint(opt)
Beispiel #9
0
def train():
    """
    Fine-tune a slimmed CIFAR-10 ResNet-23 under a pruning mask.

    Loads pretrained parameters, applies a reduction-rate mask to the
    weights, then trains with Adam while monitoring training/test error,
    checkpointing whenever the validation error improves, and saving the
    final parameters.
    """
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.model_load_path == "":
        raise Exception("Set `model_load_path`")
    nn.load_parameters(args.model_load_path)
    model_prediction = cifar10_resnet23_slim_prediction
    # TRAIN
    maps = 64  # base number of feature maps
    data_iterator = data_iterator_cifar10
    c = 3  # input channels
    h = w = 32  # input height/width
    n_train = 50000  # CIFAR-10 train-set size (currently unused below)
    n_valid = 10000  # CIFAR-10 test-set size

    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create prediction graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Set mask (grad_only=False also covers non-trainable parameters).
    create_and_set_mask(nn.get_parameters(grad_only=False),
                        rrate=args.reduction_rate)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    best_ve = 1.0
    ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation: error averaged over the whole test set.
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
        # Checkpoint whenever the validation error improves.
        if ve < best_ve:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
            best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation pass after training completes.
    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    # Save the final parameters.
    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
Beispiel #10
0
def train():
    """
    Main script for training.

    Distributed ImageNet-style training: builds train/validation models,
    a momentum solver with loss scaling and scheduled learning rate, then
    runs epoch-based training with periodic validation and checkpointing.
    Only rank 0 writes monitors and parameter files.
    """

    args = get_args()

    num_classes = 1000  # ImageNet class count

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"  # TODO: Hard coded!!!
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    from nnabla_ext.cuda import StreamEventHandler
    stream_event_handler = StreamEventHandler(int(comm.ctx.device_id))

    # Create data iterater
    data, vdata = get_data_iterators(args, comm, stream_event_handler)

    # Network for training
    t_model = get_model(args,
                        num_classes,
                        test=False,
                        channel_last=args.channel_last)

    # Network for validation
    v_model = get_model(args,
                        num_classes,
                        test=True,
                        channel_last=args.channel_last)

    # Solver
    loss_scaling = args.loss_scaling if args.type_config == 'half' else 1
    # To cancel loss scaling, learning rate is divided by loss_scaling.
    # Note this assumes legacy SGD w/ momentum implementation,
    # otherwise, it is recommended to apply division at gradient itself
    # using scale_grad for example.
    base_learning_rate = args.learning_rate / loss_scaling

    # Weight decay is multiplied by loss_scaling to cancel the effect of loss_scaling
    # cancelling at learning rate.
    # Also, note that it is multiplied by number of GPUs (processes),
    # because all-reduce sum over GPUs is performed before applying weight decay.
    weight_decay = args.weight_decay * loss_scaling * comm.n_procs
    solver = MomentumNoWeightDecayBn(base_learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Learning rate scheduler
    decay_rate = 0.1
    learning_rate_scheduler = LearningRateScheduler(
        base_learning_rate, args.learning_rate_decay_at, decay_rate,
        args.warmup_epochs)

    # Monitors: only rank 0 writes monitor output.
    monitor = None
    if comm.rank == 0:
        if not os.path.isdir(args.monitor_path):
            os.makedirs(args.monitor_path)
        monitor = M.Monitor(args.monitor_path)

    # Epoch runner
    train_epoch = EpochTrainer(t_model, solver, learning_rate_scheduler, data,
                               comm, monitor, loss_scaling, weight_decay,
                               stream_event_handler)
    val_epoch = None
    if args.val_interval > 0:
        val_epoch = EpochValidator(v_model, vdata, comm, monitor,
                                   stream_event_handler)

    # Epoch loop
    for epoch in range(args.max_epochs):
        # Save parameters (rank 0 only, every model_save_interval epochs).
        if epoch > 0 and epoch % (
                args.model_save_interval) == 0 and comm.rank == 0:
            nn.save_parameters(
                os.path.join(args.monitor_path, 'param_%03d.h5' % epoch))

        # Run validation for examples in an epoch
        if val_epoch is not None \
           and epoch > 0 \
           and epoch % args.val_interval == 0:
            val_epoch.run(epoch)

        # Run training for examples in an epoch
        train_epoch.run(epoch)

    # Run final validation
    if val_epoch is not None:
        val_epoch.run(args.max_epochs)

    # Save the final model.
    if comm.rank == 0:
        nn.save_parameters(
            os.path.join(args.monitor_path,
                         'param_%03d.h5' % (args.max_epochs)))
Beispiel #11
0
def train(args):
    """
    Train a GAN with data augmentation and an EMA copy of the generator.

    Builds generator/discriminator graphs with augmented inputs, an extra
    fixed-latent test graph, and an exponential-moving-average (EMA)
    generator updated after each step. Periodically saves parameters and
    monitors losses and image tiles.
    """
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    aug_list = args.aug_list

    # Model
    scope_gen = "Generator"
    scope_dis = "Discriminator"
    # generator loss
    z = nn.Variable([args.batch_size, args.latent, 1, 1])
    x_fake = Generator(z, scope_name=scope_gen, img_size=args.image_size)
    p_fake = Discriminator([augment(xf, aug_list) for xf in x_fake],
                           label="fake",
                           scope_name=scope_dis)
    lossG = loss_gen(p_fake)
    # discriminator loss
    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    x_real_aug = augment(x_real, aug_list)
    p_real, rec_imgs, part = Discriminator(x_real_aug,
                                           label="real",
                                           scope_name=scope_dis)
    lossD_fake = loss_dis_fake(p_fake)
    lossD_real = loss_dis_real(p_real, rec_imgs, part, x_real_aug)
    lossD = lossD_fake + lossD_real
    # generator with fixed latent values for test
    # Use train=True even in an inference phase
    z_test = nn.Variable.from_numpy_array(
        np.random.randn(args.batch_size, args.latent, 1, 1))
    x_test = Generator(z_test,
                       scope_name=scope_gen,
                       train=True,
                       img_size=args.image_size)[0]

    # Exponential Moving Average (EMA) model
    # Use train=True even in an inference phase
    scope_gen_ema = "Generator_EMA"
    x_test_ema = Generator(z_test,
                           scope_name=scope_gen_ema,
                           train=True,
                           img_size=args.image_size)[0]
    copy_params(scope_gen, scope_gen_ema)
    update_ema_var = make_ema_updater(scope_gen_ema, scope_gen, 0.999)

    # Solver: separate optimizers over each scope's parameter set.
    solver_gen = S.Adam(args.lr, beta1=0.5)
    solver_dis = S.Adam(args.lr, beta1=0.5)
    with nn.parameter_scope(scope_gen):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope(scope_dis):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_dis_real = MonitorSeries("Discriminator Loss Real",
                                          monitor,
                                          interval=10)
    monitor_loss_dis_fake = MonitorSeries("Discriminator Loss Fake",
                                          monitor,
                                          interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    # Image tiles are rescaled from [-1, 1] to [0, 1] for display.
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=args.batch_size,
                                                interval=1,
                                                normalize_method=lambda x:
                                                (x + 1.) / 2.)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=args.batch_size,
                                               interval=1,
                                               normalize_method=lambda x:
                                               (x + 1.) / 2.)
    monitor_image_tile_test_ema = MonitorImageTile("Image Tile Test EMA",
                                                   monitor,
                                                   num_images=args.batch_size,
                                                   interval=1,
                                                   normalize_method=lambda x:
                                                   (x + 1.) / 2.)

    # Data Iterator
    rng = np.random.RandomState(141)
    di = data_iterator(args.img_path,
                       args.batch_size,
                       imsize=(args.image_size, args.image_size),
                       num_samples=args.train_samples,
                       rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator.
        # Toggling need_grad on the fake samples stops gradients from
        # flowing back into the generator during the discriminator step.
        x_fake[0].need_grad = False  # no need backward to generator
        x_fake[1].need_grad = False  # no need backward to generator
        solver_dis.zero_grad()
        x_real.d = di.next()[0]
        z.d = np.random.randn(args.batch_size, args.latent, 1, 1)
        lossD.forward()
        lossD.backward()
        solver_dis.update()

        # Train generator
        x_fake[0].need_grad = True  # need backward to generator
        x_fake[1].need_grad = True  # need backward to generator
        solver_gen.zero_grad()
        lossG.forward()
        lossG.backward()
        solver_gen.update()

        # Update EMA model (after the generator step, so EMA tracks it).
        update_ema_var.forward()

        # Monitor
        monitor_loss_gen.add(i, lossG.d)
        monitor_loss_dis_real.add(i, lossD_real.d)
        monitor_loss_dis_fake.add(i, lossD_fake.d)
        monitor_time.add(i)

        # Save
        if (i + 1) % args.save_interval == 0:
            with nn.parameter_scope(scope_gen):
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "Gen_iter{}.h5".format(i + 1)))
            with nn.parameter_scope(scope_gen_ema):
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "GenEMA_iter{}.h5".format(i + 1)))
            with nn.parameter_scope(scope_dis):
                nn.save_parameters(
                    os.path.join(args.monitor_path,
                                 "Dis_iter{}.h5".format(i + 1)))
        if (i + 1) % args.test_interval == 0:
            # Generate with the fixed test latent for both models.
            x_test.forward(clear_buffer=True)
            x_test_ema.forward(clear_buffer=True)
            monitor_image_tile_train.add(i + 1, x_fake[0])
            monitor_image_tile_test.add(i + 1, x_test)
            monitor_image_tile_test_ema.add(i + 1, x_test_ema)

    # Last: final sample grids and parameter snapshots.
    x_test.forward(clear_buffer=True)
    x_test_ema.forward(clear_buffer=True)
    monitor_image_tile_train.add(args.max_iter, x_fake[0])
    monitor_image_tile_test.add(args.max_iter, x_test)
    monitor_image_tile_test_ema.add(args.max_iter, x_test_ema)
    with nn.parameter_scope(scope_gen):
        nn.save_parameters(
            os.path.join(args.monitor_path,
                         "Gen_iter{}.h5".format(args.max_iter)))
    with nn.parameter_scope(scope_gen_ema):
        nn.save_parameters(
            os.path.join(args.monitor_path,
                         "GenEMA_iter{}.h5".format(args.max_iter)))
    with nn.parameter_scope(scope_dis):
        nn.save_parameters(
            os.path.join(args.monitor_path,
                         "Dis_iter{}.h5".format(args.max_iter)))
Beispiel #12
0
def train():
    """Train a graph network to predict node degrees of random graphs.

    Builds synthetic train/validation datasets of random graphs, trains
    with gradient accumulation (one graph per step, ``--accum-grad``
    steps per update), and periodically reports a validation error: the
    fraction of nodes whose predicted degree is off by 0.5 or more.

    NOTE(review): relies on module-level names defined elsewhere in this
    file (`rng`, `random_graph`, `degrees`, `utils`, `predict`,
    `SimpleDataSource2`, `data_iterator`, `nn`, `F`, `M`, `S`) — in
    particular `rng` is not defined here; confirm it exists at module
    scope.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-train-examples", type=int, default=1600)
    parser.add_argument("--num-valid-examples", type=int, default=100)
    parser.add_argument("--accum-grad", type=int, default=32)
    parser.add_argument("--max-iter", type=int, default=6400)
    parser.add_argument("--valid-interval", type=int, default=100)
    parser.add_argument("--context", type=str, default="cpu")
    parser.add_argument("--device-id", type=int, default=0)

    args = parser.parse_args()

    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    ctx = get_extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # prepare dataset
    # Each sample is ([vertices], [adjacency matrix], [per-node degrees]).
    tdataset = []
    for i in range(args.num_train_examples):
        V, E = random_graph(rng)
        deg = degrees(V, E)
        tdataset.append(([V], [utils.from_adjacency_list(E)], [deg]))

    vdataset = []
    for i in range(args.num_valid_examples):
        V, E = random_graph(rng)
        deg = degrees(V, E)
        vdataset.append(([V], [utils.from_adjacency_list(E)], [deg]))

    # prepare data iterator
    # batch_size=1: graphs have varying sizes, so one graph per step.
    tdata = data_iterator(SimpleDataSource2(tdataset, shuffle=True), 1, False,
                          False, False)
    vdata = data_iterator(SimpleDataSource2(vdataset, shuffle=False), 1, False,
                          False, False)

    # prepare monitors
    monitor = M.Monitor("./degree")
    tloss = M.MonitorSeries("Training Loss", monitor, interval=10)

    verror = M.MonitorSeries("Validation Error", monitor, interval=10)

    # prepare solver
    solver = S.Adam()

    # training loop
    for i in range(args.max_iter):
        l = 0
        # Accumulate gradients over `accum_grad` graphs before updating.
        for b in range(args.accum_grad):
            # read data
            V, E, degree = tdata.next()
            V = V[0][0]
            E = E[0][0]
            degree = degree[0][0]

            # predict
            output = predict(V, E)

            # initialize solver
            # Parameters are created lazily by `predict`, so the solver can
            # only be bound to them after the very first forward graph build.
            if i == 0 and b == 0:
                solver.set_parameters(nn.get_parameters())

            # calculate loss
            label = nn.Variable(degree.shape)
            label.data.data = degree
            label = F.reshape(label, (len(V), 1))
            loss = F.mean(F.squared_error(output, label))

            # training
            # backward() without zero_grad accumulates gradients across
            # the inner loop; solver.update() consumes the accumulated sum.
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            l += loss.data.data

        solver.update()

        tloss.add(i, l / args.accum_grad)
        l = 0

        if i % args.valid_interval == 0:
            # validation
            # read data
            e = 0  # count of nodes with absolute error >= 0.5... NOTE(review):
            # `less_scalar(|diff|, 0.5)` actually counts errors *below* 0.5,
            # so `e / n` is an accuracy-like ratio despite the monitor name.
            n = 0  # total number of nodes seen
            for b in range(vdata.size):
                V, E, degree = vdata.next()
                V = V[0][0]
                E = E[0][0]
                degree = degree[0][0]

                output = predict(V, E)

                label = nn.Variable(degree.shape)
                label.data.data = degree
                label = F.reshape(label, (len(V), 1))
                error = F.sum(F.less_scalar(F.abs(F.sub2(output, label)), 0.5))

                error.forward()

                e += error.data.data
                n += len(V)
            verror.add(i, e / n)
Beispiel #13
0
def train():
    """
    Train a ResNet-18/34 classifier on CIFAR-10 with mixed-sample data
    augmentation (mixup / cutmix / vhmixup).

    Steps:

    * Parse command line arguments.
    * Specify contexts for computation.
    * Construct a training graph (with sample mixing) and a validation graph.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Execute forward/backward on a mixed minibatch and update.
      * Periodically compute the error rate over the full validation set.
      * Periodically save parameter snapshots.

    Raises:
        ValueError: if ``args.net`` or ``args.solver`` is not recognized
            (previously this fell through and crashed later with a
            NameError / UnboundLocalError).
    """
    # Parse args
    args = get_args()
    n_valid_samples = 10000
    bs_valid = args.batch_size
    extension_module = args.context
    ctx = get_extension_context(
        extension_module, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Dataset
    data_iterator = data_iterator_cifar10
    n_class = 10

    # Model architecture: fail fast on an unknown network name instead of
    # hitting a NameError when `prediction` is first used below.
    if args.net == "resnet18":
        prediction = functools.partial(
            resnet18_prediction, ncls=n_class, nmaps=64, act=F.relu)
    elif args.net == "resnet34":
        prediction = functools.partial(
            resnet34_prediction, ncls=n_class, nmaps=64, act=F.relu)
    else:
        raise ValueError("Unknown network: " + args.net)

    # Create training graphs
    test = False
    if args.mixtype == "mixup":
        mdl = MixupLearning(args.batch_size, alpha=args.alpha)
    elif args.mixtype == "cutmix":
        mdl = CutmixLearning((args.batch_size, 3, 32, 32),
                             alpha=args.alpha, cutmix_prob=1.0)
    elif args.mixtype == "vhmixup":
        mdl = VHMixupLearning((args.batch_size, 3, 32, 32), alpha=args.alpha)
    else:
        print("[ERROR] Unknown mixtype: " + args.mixtype)
        return
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    # Mixing operates on augmented images and one-hot labels.
    mix_image, mix_label = mdl.mix_data(single_image_augment(
        image_train), F.one_hot(label_train, (n_class, )))
    pred_train = prediction(mix_image, test)
    loss_train = mdl.loss(pred_train, mix_label)
    input_train = {"image": image_train, "label": label_train}

    # Create validation graph (no mixing, test-mode network).
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_valid = {"image": image_valid}

    # Solvers: fail fast on an unknown solver name instead of an
    # UnboundLocalError at `solver.set_parameters` below.
    if args.solver == "Adam":
        solver = S.Adam()
    elif args.solver == "Momentum":
        solver = S.Momentum(lr=args.learning_rate)
    else:
        raise ValueError("Unknown solver: " + args.solver)
    solver.set_parameters(nn.get_parameters())

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.save_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Data Iterator
    tdata = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    print("Size of the training data: %d " % tdata.size)
    # Training-loop
    for i in range(args.max_iter):
        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_train["image"].d = image
        input_train["label"].d = label
        mdl.set_mix_ratio()  # draw a fresh mixing coefficient per step
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Model update by solver; Momentum uses a two-stage LR decay.
        if args.solver == "Momentum":
            if i == args.max_iter / 2:
                solver.set_learning_rate(args.learning_rate / 10.0)
            if i == args.max_iter / 4 * 3:
                solver.set_learning_rate(args.learning_rate / 10.0**2)
        solver.update()

        # Validation over the whole validation set.
        if (i+1) % args.val_interval == 0 or i == 0:
            vdata._reset()
            vdata_pred = np.zeros((n_valid_samples, n_class))
            vdata_label = np.zeros((n_valid_samples, 1), dtype=np.int32)
            for j in range(0, n_valid_samples, args.batch_size):
                image, label = vdata.next()
                input_valid["image"].d = image
                pred_valid.forward()
                # The last batch may overrun n_valid_samples; clip it.
                vdata_pred[j:min(j+args.batch_size, n_valid_samples)
                           ] = pred_valid.d[:min(args.batch_size, n_valid_samples-j)]
                vdata_label[j:min(j+args.batch_size, n_valid_samples)
                            ] = label[:min(args.batch_size, n_valid_samples-j)]
            ve = categorical_error(vdata_pred, vdata_label)
            monitor_verr.add(i+1, ve)

        if int((i+1) % args.model_save_interval) == 0:
            nn.save_parameters(os.path.join(
                args.save_path, 'params_%06d.h5' % (i+1)))

        # Monitering
        monitor_loss.add(i+1, loss_train.d.copy())
        monitor_time.add(i+1)

    nn.save_parameters(os.path.join(args.save_path,
                                    'params_%06d.h5' % (args.max_iter)))
Beispiel #14
0
def main():
    """
    Entry point: parse options, build data iterators, generator /
    discriminator factories and their Adam solvers, then start
    transformer training.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default=None, type=str)
    parser.add_argument('--info', default=None, type=str)
    args = parser.parse_args()

    config = load_transformer_config(args.config)
    if args.info:
        config["experiment_name"] += args.info

    pprint.pprint(config)

    #########################
    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {config["context"]}.')
    ctx = get_extension_context(config["context"],
                                device_id=config["device_id"])
    nn.set_default_context(ctx)
    #########################

    # Data Loading
    logger.info('Initialing Datasource')
    # Source and target iterators differ only in the celeb name; share
    # every other option through one keyword dictionary.
    iterator_opts = dict(
        dataset_mode="transformer",
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode=config["mode"],
        batch_size=config["train"]["batch_size"],
        shuffle=config["train"]["shuffle"],
        with_memory_cache=config["train"]["with_memory_cache"],
        with_file_cache=config["train"]["with_file_cache"],
        resize_size=config["preprocess"]["resize_size"],
        line_thickness=config["preprocess"]["line_thickness"],
        gaussian_kernel=config["preprocess"]["gaussian_kernel"],
        gaussian_sigma=config["preprocess"]["gaussian_sigma"])
    train_iterators = (
        data.celebv_data_iterator(celeb_name=config["src_celeb_name"],
                                  **iterator_opts),
        data.celebv_data_iterator(celeb_name=config["trg_celeb_name"],
                                  **iterator_opts))

    # monitor rooted at logdir/transformer/<src>2<trg>/<experiment_name>
    monitor = nm.Monitor(
        os.path.join(config["logdir"], "transformer",
                     f'{config["src_celeb_name"]}2{config["trg_celeb_name"]}',
                     config["experiment_name"]))

    # Network factories (both directions share the same architectures).
    netG = {
        'netG_A2B': models.netG_transformer,
        'netG_B2A': models.netG_transformer
    }
    netD = {
        'netD_A': models.netD_transformer,
        'netD_B': models.netD_transformer
    }

    # Optimizers: one Adam per network; discriminators run at half the LR.
    def _adam(alpha):
        return S.Adam(alpha=alpha,
                      beta1=config["train"]["beta1"],
                      beta2=config["train"]["beta2"])

    solver_netG = {
        'netG_A2B': _adam(config["train"]["lr"]),
        'netG_B2A': _adam(config["train"]["lr"])
    }
    solver_netD = {
        'netD_A': _adam(0.5 * config["train"]["lr"]),
        'netD_B': _adam(0.5 * config["train"]["lr"])
    }

    train_transformer(config, netG, netD, solver_netG, solver_netD,
                      train_iterators, monitor)
def train():
    """
    Train a ResNet classifier on ImageNet or Tiny ImageNet.

    Builds a training graph and a test-mode validation graph with shared
    parameters, optimizes with Momentum + weight decay using gradient
    accumulation (``args.accum_grad`` forward/backward passes per solver
    update), and periodically records validation loss/error and saves
    parameter snapshots.
    """

    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    if args.tiny_mode:
        # We use Tiny ImageNet from Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories, each category has 500 images
        # in training set. The image size is 64x64. To adapt ResNet into 64x64
        # image inputs, the input image size of ResNet is set as 56x56, and
        # the stride in the first conv and the first max pooling are removed.
        # Please check README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories, each category has 1280 images
        # in training set. The image size is various. To adapt ResNet into
        # 320x320 image inputs, the input image size of ResNet is set as
        # 224x224. We need to get tar file and create cache file(320x320 images).
        # Please check README.
        data = data_iterator_imagenet(args.batch_size,
                                      args.train_cachefile_dir)
        vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir)
        num_classes = 1000
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    # Compute the top-n error on a detached copy of pred so the error
    # metric is evaluated outside the training graph.
    t_pred2 = t_model.pred.unlinked()
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=10)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time",
                                         monitor,
                                         interval=10)

    # Training loop.
    for i in range(args.max_iter):
        # Save parameters
        if i % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % args.val_interval == 0 and i != 0:

            # Clear all intermediate memory to save memory.
            # t_model.loss.clear_recursive()

            l = 0.0
            e = 0.0
            for j in range(args.val_iter):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                # Cast inputs to compact dtypes on the device context —
                # presumably to reduce memory/transfer cost; confirm.
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                l += v_model.loss.d
                e += v_e.d
            monitor_vloss.add(i, l / args.val_iter)
            monitor_verr.add(i, e / args.val_iter)
            monitor_vtime.add(i)

            # Clear all intermediate memory to save memory.
            # v_model.loss.clear_recursive()

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        # Accumulate the scalar loss/error of one finished forward pass
        # into the running sums (reading .d synchronizes with the device).
        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels = data.next()
            if j != 0:
                # Update e and l according to previous results of forward
                # propagation.
                # The update of last iteration is performed
                # after solver update to avoid unnecessary CUDA synchronization.
                # This is performed after data.next() in order to overlap
                # the data loading and graph execution.
                # TODO: Move this to the bottom of the loop when prefetch
                # data loader is available.
                l, e = accumulate_error(l, e, t_model, t_e)
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)

        solver.weight_decay(args.weight_decay)
        solver.update()

        # Accumulate errors after solver update
        l, e = accumulate_error(l, e, t_model, t_e)

        monitor_loss.add(i, l / args.accum_grad)
        monitor_err.add(i, e / args.accum_grad)
        monitor_time.add(i)

        # Learning rate decay at scheduled iter
        if i in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)
    nn.save_parameters(
        os.path.join(args.model_save_path, 'param_%06d.h5' % args.max_iter))
Beispiel #16
0
def main():
    """
    Train skip-gram word embeddings with negative sampling on the PTB
    corpus, save the model, then report the words most similar to the
    first ``args.max_check_words`` vocabulary entries.
    """

    # Get arguments
    args = get_args()
    data_file = "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt"
    model_file = args.work_dir + "model.h5"

    # Load Dataset
    itow, wtoi, dataset = load_ptbset(data_file)

    # Computation environment settings
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create data provider
    n_word = len(wtoi)
    n_dim = args.embed_dim
    batchsize = args.batchsize
    half_window = args.half_window_length
    n_negative = args.n_negative_sample

    di = DataIteratorForEmbeddingLearning(
        batchsize=batchsize,
        half_window=half_window,
        n_negative=n_negative,
        dataset=dataset)

    # Create model
    # - Real batch size including context samples and negative samples
    # NOTE(review): uses (half_window - 1), not half_window — confirm this
    # matches the layout produced by DataIteratorForEmbeddingLearning.
    size = batchsize * (1 + n_negative) * (2 * (half_window - 1))

    # Model for learning
    # - input variables
    xl = nn.Variable((size,))  # variable for word
    yl = nn.Variable((size,))  # variable for context

    # Embed layers for word embedding function
    # - f_embed : word index x to get y, the n_dim vector
    # --  for each sample in a minibatch
    # Both lookups share the same parameter scope "e1".
    hx = PF.embed(xl, n_word, n_dim, name="e1")  # feature vector for word
    hy = PF.embed(yl, n_word, n_dim, name="e1")  # feature vector for context
    hl = F.sum(hx * hy, axis=1)

    # -- Approximated likelihood of context prediction
    # pos: word context, neg negative samples
    tl = nn.Variable([size, ], need_grad=False)
    loss = F.sigmoid_cross_entropy(hl, tl)
    loss = F.mean(loss)

    # Model for test of searching similar words
    xr = nn.Variable((1,), need_grad=False)
    hr = PF.embed(xr, n_word, n_dim, name="e1")  # feature vector for test

    # Create solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    monitor = M.Monitor(args.work_dir)
    monitor_loss = M.MonitorSeries(
        "Training loss", monitor, interval=args.monitor_interval)
    monitor_time = M.MonitorTimeElapsed(
        "Training time", monitor, interval=args.monitor_interval)

    # Do training
    max_epoch = args.max_epoch
    for epoch in range(max_epoch):

        # iteration per epoch
        for i in range(di.n_batch):

            # get minibatch
            xi, yi, ti = di.next()

            # learn
            solver.zero_grad()
            xl.d, yl.d, tl.d = xi, yi, ti
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()

            # monitor
            itr = epoch * di.n_batch + i
            monitor_loss.add(itr, loss.d)
            monitor_time.add(itr)

    # Save model
    nn.save_parameters(model_file)

    # Evaluate by similarity
    max_check_words = args.max_check_words
    # BUG FIX: the original divided the live parameter array in place
    # (`w /= ...` on `.d`) inside the loop, corrupting the in-memory
    # weights and re-normalizing already-normalized rows on every
    # iteration. Normalize once, on a copy, outside the loop.
    w = nn.get_parameters()['e1/embed/W'].d.copy()
    s = np.sqrt((w * w).sum(1))        # per-word L2 norms
    w /= s.reshape((s.shape[0], 1))    # row-normalized embedding matrix
    for i in range(max_check_words):

        # prediction
        xr.d = i
        hr.forward(clear_buffer=True)
        h = hr.d

        # similarity calculation against the normalized embedding matrix
        similarity = w.dot(h[0]) / s[i]

        # for understanding
        output_similar_words(itow, i, similarity)
Beispiel #17
0
def main():
    """
    Generate an image with a pretrained StyleGAN2 generator.

    Downloads the pretrained weights on first run, samples a primary
    style noise from ``--seed`` (optionally mixing in a secondary style
    from ``--seed-mix`` after layer ``--mix-after``), applies the
    truncation trick, and saves the generated image.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--output-filename',
                        '-o',
                        type=str,
                        default=None,
                        help="name of an output image file.")
    parser.add_argument('--output-dir',
                        '-d',
                        type=str,
                        default="results",
                        help="directory where the generated image is saved.")

    parser.add_argument('--seed',
                        type=int,
                        required=True,
                        help="seed for primal style noise.")
    parser.add_argument('--stochastic-seed',
                        type=int,
                        default=1,
                        help="seed for noises added to intermediate features.")

    parser.add_argument('--truncation-psi',
                        default=0.5,
                        type=float,
                        help="value for truncation trick.")

    parser.add_argument(
        '--mixing',
        action='store_true',
        help="if specified, apply style mixing with additional seed.")
    parser.add_argument('--seed-mix',
                        type=int,
                        default=None,
                        help="seed for another / secondary style noise.")
    parser.add_argument('--mix-after',
                        type=int,
                        default=7,
                        help="after this layer, style mixing is applied.")

    parser.add_argument('--context',
                        '-c',
                        type=str,
                        default="cudnn",
                        help="context. cudnn is recommended.")

    args = parser.parse_args()

    assert 0 < args.mix_after < 17, "specify --mix-after from 1 to 16."

    if not os.path.isfile("styleGAN2_G_params.h5"):
        print("Downloading the pretrained weight. Please wait...")
        url = "https://nnabla.org/pretrained-models/nnabla-examples/GANs/stylegan2/styleGAN2_G_params.h5"
        from nnabla.utils.data_source_loader import download
        download(url, url.split('/')[-1], False)

    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)

    batch_size = 1
    num_layers = 18

    rnd = np.random.RandomState(args.seed)
    z = rnd.randn(batch_size, 512)

    print("Generation started...")
    print(f"truncation value: {args.truncation_psi}")
    print(f"seed for additional noise: {args.stochastic_seed}")
    if args.mixing:
        # apply style mixing
        # BUG FIX: `assert args.seed_mix` rejected the valid seed 0
        # (falsy int); only a missing value should be an error.
        assert args.seed_mix is not None
        print(
            f"using style noise seed {args.seed} for layers 0-{args.mix_after - 1}"
        )
        print(
            f"using style noise seed {args.seed_mix} for layers {args.mix_after}-{num_layers}."
        )
        rnd = np.random.RandomState(args.seed_mix)
        z2 = rnd.randn(batch_size, 512)
        # Primary style for the first `mix_after` layers, secondary for the rest.
        style_noises = [
            nn.Variable((batch_size, 512)).apply(d=z)
            for _ in range(args.mix_after)
        ]
        style_noises += [
            nn.Variable((batch_size, 512)).apply(d=z2)
            for _ in range(num_layers - args.mix_after)
        ]
    else:
        # no style mixing (single noise / style is used)
        print(f"using style noise seed {args.seed} for entire layers.")
        style_noise = nn.Variable((batch_size, 512)).apply(d=z)
        style_noises = [style_noise for _ in range(num_layers)]

    nn.load_parameters("styleGAN2_G_params.h5")
    rgb_output = generate(batch_size, style_noises, args.stochastic_seed,
                          args.truncation_psi)
    rgb_output.forward()

    # convert to uint8 to save an image file
    image = convert_images_to_uint8(rgb_output, drange=[-1, 1])
    if args.output_filename is None:
        if not args.mixing:
            filename = f"seed{args.seed}.png"
        else:
            filename = f"seed{args.seed}_{args.seed_mix}.png"
    else:
        filename = args.output_filename

    os.makedirs(args.output_dir, exist_ok=True)
    filepath = os.path.join(args.output_dir, filename)
    imsave(filepath, image, channel_first=True)
    # BUG FIX: typo "Genetation" in the user-facing message.
    print(f"Generation completed. Saved {filepath}.")
Beispiel #18
0
def generate(args):
    """
    Run unpaired image-to-image translation with a trained model and save
    tiled results.

    For every image in domain A, produces ``args.num_repeats``
    translations into domain B (and symmetrically B -> A). Style codes
    are either sampled randomly per forward pass or, when
    ``args.example_guided`` is set, encoded once from an example image of
    the target domain. Each monitor row tiles the input followed by its
    translations along the width axis.

    Args:
        args: parsed options (model_load_path, image_size, maps,
            img_path_a/b, batch_size, num_repeats, example_guided,
            monitor_path, type_config).
    """
    # Load model
    nn.load_parameters(args.model_load_path)

    # Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = 1, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    maps = args.maps
    # content/style (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    # content/style (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = F.randn(
        shape=x_style_a.shape) if not args.example_guided else x_style_a
    z_style_a = z_style_a.apply(persistent=True)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(
        shape=x_style_b.shape) if not args.example_guided else x_style_b
    z_style_b = z_style_b.apply(persistent=True)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")

    # Monitor
    suffix = "Stochastic" if not args.example_guided else "Example-guided"
    monitor = Monitor(args.monitor_path)
    monitor_image_a = MonitorImage("Fake Image B to A {} Valid".format(suffix),
                                   monitor,
                                   interval=1)
    monitor_image_b = MonitorImage("Fake Image A to B {} Valid".format(suffix),
                                   monitor,
                                   interval=1)

    # DataIterator
    di_a = munit_data_iterator(args.img_path_a, args.batch_size)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size)

    # Generate all
    # generate (A -> B)
    if args.example_guided:
        x_real_b.d = di_b.next()[0]
    for i in range(di_a.size):
        x_real_a.d = di_a.next()[0]
        # BUG FIX: the original appended the undefined name `x_data_a`
        # (guaranteed NameError); tile the input image instead.
        images = [x_real_a.d.copy()]
        for _ in range(args.num_repeats):
            x_fake_b.forward(clear_buffer=True)
            images.append(x_fake_b.d.copy())
        monitor_image_b.add(i, np.concatenate(images, axis=3))

    # generate (B -> A)
    if args.example_guided:
        x_real_a.d = di_a.next()[0]
    for i in range(di_b.size):
        x_real_b.d = di_b.next()[0]
        # BUG FIX: `x_data_b` was also undefined, and a stray extra
        # forward() before the loop discarded one style sample; this
        # branch now mirrors the A -> B branch exactly.
        images = [x_real_b.d.copy()]
        for _ in range(args.num_repeats):
            x_fake_a.forward(clear_buffer=True)
            images.append(x_fake_a.d.copy())
        monitor_image_a.add(i, np.concatenate(images, axis=3))
Beispiel #19
0
def train():
    """
    Train a CNN classifier (LeNet or ResNet) on MNIST.

    Builds a training graph and a separate test-mode graph that share
    parameters, optimizes with Adam plus weight decay, and periodically
    records the validation error and saves parameter snapshots.
    """
    args = get_args()

    # Computation context: fall back to CPU when none is specified.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    nn.set_default_context(
        extension_context(extension_module, device_id=args.device_id))

    # Network builder shared by the training and test graphs.
    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # Training graph.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    pred = mnist_cnn_prediction(image, test=False)
    # Keep the prediction buffer alive so the training error can be
    # computed after backprop has cleared intermediate buffers.
    pred.persistent = True
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # Test graph.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vpred = mnist_cnn_prediction(vimage, test=True)

    # Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data iterators.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)

    for step in range(args.max_iter):
        if step % args.val_interval == 0:
            # Validation pass over `val_iter` minibatches.
            err_sum = 0.0
            for _ in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                err_sum += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(step, err_sum / args.val_iter)
        if step % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % step))
        # One training step: forward, backward, regularize, update.
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        monitor_loss.add(step, loss.d.copy())
        monitor_err.add(step, categorical_error(pred.d, label.d))
        monitor_time.add(step)
    nn.save_parameters(os.path.join(args.model_save_path,
                                    'params_%06d.h5' % args.max_iter))
Beispiel #20
0
def test_nan_inf_tracer(batch_size, n_class, ext_name, trace_nan, trace_inf):
    """Verify NanInfTracer raises on non-finite values during forward/backward.

    Three graphs are built from a simple CNN loss ``y``:
      * ``y`` itself — expected to be finite and trace cleanly;
      * ``must_be_inf`` = y / 0 — every element becomes inf;
      * ``must_be_nan`` = inf / inf — every element becomes nan.

    With ``trace_nan`` / ``trace_inf`` enabled, tracing the corresponding
    poisoned graph must raise ``ValueError`` both inside the ``nit.trace()``
    context and on an explicit ``nit.check()`` after an untraced run.
    """
    nn.clear_parameters()

    # Run on the backend under test (e.g. cpu / cudnn).
    ctx = get_extension_context(ext_name)
    nn.set_default_context(ctx)

    x = nn.Variable.from_numpy_array(
        np.random.normal(size=(batch_size, 3, 16, 16)))
    t = nn.Variable.from_numpy_array(
        np.random.randint(low=0, high=n_class, size=(batch_size, 1)))

    # Finite baseline graph.
    y = simple_cnn(x, t, n_class)

    # Division by a zero constant yields inf; inf / inf yields nan.
    must_be_inf = y / F.constant(0, shape=y.shape)
    must_be_nan = must_be_inf / must_be_inf

    # Refresh all arrays once so as to ensure all grad values are 0.
    must_be_nan.visit(_refresh_inputs_grad)

    nit = NanInfTracer(trace_nan=trace_nan, trace_inf=trace_inf)

    # A finite graph must pass regardless of the trace flags.
    # can be run at any cases without exception.
    with nit.trace():
        y.forward(clear_no_need_grad=True,
                  function_post_hook=nit.forward_post_hook)
        y.backward(clear_buffer=True,
                   function_post_hook=nit.backward_post_hook)

    nit.check()  # this call can also work without exception.

    # check nan
    if trace_nan:
        # Raises inside the tracing context on forward...
        with pytest.raises(ValueError):
            with nit.trace():
                must_be_nan.forward(clear_buffer=True,
                                    function_post_hook=nit.forward_post_hook)

        # ...and on backward.
        with pytest.raises(ValueError):
            with nit.trace():
                must_be_nan.backward(clear_buffer=True,
                                     function_post_hook=nit.backward_post_hook)

        # Outside the context the hooks only record; the deferred check()
        # is what raises.
        must_be_nan.forward(clear_buffer=True,
                            function_post_hook=nit.forward_post_hook)
        with pytest.raises(ValueError):
            nit.check()

        must_be_nan.backward(clear_buffer=True,
                             function_post_hook=nit.backward_post_hook)

        with pytest.raises(ValueError):
            nit.check()

    # check inf
    if trace_inf:
        with pytest.raises(ValueError):
            with nit.trace():
                must_be_inf.forward(clear_buffer=True,
                                    function_post_hook=nit.forward_post_hook)

        # Deferred check() must also flag the inf recorded by the hook.
        must_be_inf.forward(clear_buffer=True,
                            function_post_hook=nit.forward_post_hook)
        with pytest.raises(ValueError):
            nit.check()
Beispiel #21
0
def main(args):
    """Render a shaded sphere by sphere-tracing an implicit SDF network.

    Builds one camera ray per pixel of an H x W image, intersects the rays
    with the zero level set of ``implicit_network`` via ``ray_trace``,
    shades the hit points with a distant light using the SDF gradient as
    the surface normal, and writes the rendered image to a PNG file.

    Fix over the previous revision: the dead assignment
    ``sdf_net0 = sdf_net`` (immediately shadowed by the ``def sdf_net0``
    below it) is removed.
    """
    from network import implicit_network

    # Setting
    # nn.set_auto_forward(True)
    ctx = get_extension_context('cudnn', device_id=args.device_id)
    nn.set_default_context(ctx)
    D = args.depth
    L = args.layers
    W = args.width
    H = args.height
    R = H * W  # one ray per pixel
    z_orientation = 1

    # Camera parameters: normalize the eye position onto a sphere of
    # radius 2, looking at the origin.
    camera = Camera(image_width=W, image_height=H, z_orientation=z_orientation)
    camloc = np.array([0.75, 0.5, 1])
    camloc = (camloc / np.sum(camloc**2)**0.5) * 2
    to = np.array([0, 0, 0])
    Rt_inv = look_at(camloc, to, z_orientation=z_orientation)
    R_inv = Rt_inv[:3, :3]
    fov = 90
    K_inv = camera.compute_intrinsic_inv(fov)

    # Rays: unproject every pixel (homogeneous coords) through the inverse
    # intrinsics/rotation, then normalize to unit directions.
    x, y = np.meshgrid(np.arange(W), np.arange(H), indexing="xy")
    xy = np.asarray([x.flatten(), y.flatten()])
    xy1 = np.concatenate([xy, np.ones(R)[np.newaxis, :]])
    raydir = R_inv.dot(K_inv.dot(xy1))
    raydir = raydir / np.sum(raydir**2, axis=0)**0.5
    raydir = raydir.transpose((1, 0))  # -> (R, 3)

    # Network: wrap the implicit network so callers see only the SDF value.
    camloc = nn.Variable.from_numpy_array(camloc[np.newaxis, ...])
    raydir = nn.Variable.from_numpy_array(raydir[np.newaxis, ...])
    sdf_net = partial(implicit_network,
                      D=D,
                      L=L,
                      initial_sphere_radius=args.initial_sphere_radius)

    def sdf_net0(x):
        # The first output channel of the implicit network is the signed
        # distance; keep a trailing singleton axis.
        out = sdf_net(x)
        sdf = out[..., 0][..., np.newaxis]
        return sdf

    # Sphere trace
    t_near = args.t_near
    t_far = args.t_far
    sphere_trace_itr = args.sphere_trace_itr
    ray_march_points = args.ray_march_points
    n_chunks = args.n_chunks
    max_post_itr = args.max_post_itr
    post_method = args.post_method
    eps = args.eps
    st = time.time()  # wall-clock start for the timing printout below
    x_hit, mask_hit, dists, _, _ = ray_trace(sdf_net0,
                                             camloc,
                                             raydir,
                                             test=True,
                                             t_near=t_near,
                                             t_far=t_far,
                                             sphere_trace_itr=sphere_trace_itr,
                                             ray_march_points=ray_march_points,
                                             n_chunks=n_chunks,
                                             max_post_itr=max_post_itr,
                                             post_method=post_method,
                                             eps=eps)

    # Rendering only: no gradients flow into the tracing outputs.
    x_hit.need_grad = False
    dists.need_grad = False
    mask_hit.need_grad = False

    # Evaluate hit points and mask once, keeping buffers for reuse below.
    x_curr = x_hit
    F.sink(*[x_curr, mask_hit]).forward(clear_buffer=False)
    # Lighting: surface normal = normalized gradient of the SDF at the hit.
    x_curr = x_curr.get_unlinked_variable(need_grad=True)
    sdf = sdf_net0(x_curr)
    normal = nn.grad([sdf], [x_curr])[0]
    normal = F.norm_normalization(normal, axes=normal.ndim - 1, eps=1e-24)
    dlight = DistantLight()
    cos = lambert(normal, dlight.direction.reshape([3, 1])).reshape((1, H, W))
    # Mask out rays that never hit the surface; broadcast to 3 channels.
    mask_hit = mask_hit.get_unlinked_variable(need_grad=False)
    mask_hit = F.reshape(mask_hit, (1, H, W))
    mask_hit = F.broadcast(mask_hit, (3, H, W))
    image = mask_hit * 255.0 * cos
    image.forward(clear_buffer=True)

    cv2.imwrite(
        f"sphere_{W}x{H}_sti{sphere_trace_itr:03d}_mpi{max_post_itr:03d}_{args.post_method}.png",
        image.d.transpose(1, 2, 0))
    print(
        f"Bidirectional sphere trace/ray march (W={W}, H={H}): {time.time() - st} [s]"
    )
Beispiel #22
0
def train(args):
    """Train a siamese LeNet on MNIST pairs with a contrastive loss.

    Builds separate train/test graphs sharing parameters, optimizes with
    Adam, periodically evaluates on the validation iterator, checkpoints
    weights and solver state, and finally saves the learned parameters.
    """
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    nn.set_default_context(get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config))

    margin = 1.0  # Margin for contrastive loss.

    # Training graph.
    x0 = nn.Variable([args.batch_size, 1, 28, 28])
    x1 = nn.Variable([args.batch_size, 1, 28, 28])
    sim = nn.Variable([args.batch_size])
    dist = mnist_lenet_siamese(x0, x1, test=False)
    train_loss = F.mean(contrastive_loss(dist, sim, margin))

    # Validation graph (test-mode network, shared parameters).
    vx0 = nn.Variable([args.batch_size, 1, 28, 28])
    vx1 = nn.Variable([args.batch_size, 1, 28, 28])
    vsim = nn.Variable([args.batch_size])
    vdist = mnist_lenet_siamese(vx0, vx1, test=True)
    valid_loss = F.mean(contrastive_loss(vdist, vsim, margin))

    # Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Optionally resume weights and solver state from a checkpoint file.
    start_point = 0
    if args.checkpoint is not None:
        start_point = load_checkpoint(args.checkpoint, solver)

    # Monitors.
    import nnabla.monitor as M
    mon = M.Monitor(args.monitor_path)
    mon_loss = M.MonitorSeries("Training loss", mon, interval=10)
    mon_time = M.MonitorTimeElapsed("Training time", mon, interval=100)
    mon_vloss = M.MonitorSeries("Test loss", mon, interval=10)

    # Pair iterators for MNIST.
    rng = np.random.RandomState(313)
    tr_pairs = siamese_data_iterator(args.batch_size, True, rng)
    va_pairs = siamese_data_iterator(args.batch_size, False, rng)

    # Training loop.
    for step in range(start_point, args.max_iter):
        if step % args.val_interval == 0:
            # Validation: mean loss over args.val_iter minibatches.
            accum = 0.0
            for _ in range(args.val_iter):
                vx0.d, vx1.d, vsim.d = va_pairs.next()
                valid_loss.forward(clear_buffer=True)
                accum += valid_loss.d
            mon_vloss.add(step, accum / args.val_iter)
        if step % args.model_save_interval == 0:
            # Persist weights + solver state.
            save_checkpoint(args.model_save_path, step, solver)
        # One optimization step: forward, backward, decay, update.
        x0.d, x1.d, sim.d = tr_pairs.next()
        solver.zero_grad()
        train_loss.forward(clear_no_need_grad=True)
        train_loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        mon_loss.add(step, train_loss.d.copy())
        mon_time.add(step)

    # Final parameter snapshot.
    nn.save_parameters(os.path.join(
        args.model_save_path, 'params_%06d.h5' % args.max_iter))
Beispiel #23
0
def main(args):
    """Train a depth-estimation CNN with an L1 loss.

    Builds the training graph, optimizes with Adam or momentum SGD,
    logs per-iteration and per-epoch losses plus an image-tile
    visualization, and saves a parameter checkpoint after every epoch.
    """
    from numpy.random import seed
    seed(46)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    ctx = get_extension_context('cudnn', device_id='0', type_config='float')
    nn.set_default_context(ctx)

    # Create CNN network
    # === TRAIN ===
    # Create input variables.
    image = nn.Variable([args.batch_size, 3, args.img_height, args.img_width])
    label = nn.Variable([args.batch_size, 1, args.img_height, args.img_width])
    # Create prediction graph.
    pred = depth_cnn_model(image, test=False)
    # Keep pred's buffer alive across backward so it can be re-forwarded
    # for visualization below.
    pred.persistent = True
    # Create loss function.
    loss = l1_loss(pred, label)
    # === VAL ===
    #vimage = nn.Variable([args.batch_size, 3, args.img_height, args.img_width])
    #vlabel = nn.Variable([args.batch_size, 1, args.img_height, args.img_width])
    #vpred = depth_cnn_model(vimage, test=True)
    #vloss = l1_loss(vpred, vlabel)

    # Prepare monitors.
    monitor = Monitor(os.path.join(args.log_dir, 'nnmonitor'))
    monitors = {
        'train_epoch_loss':
        MonitorSeries('Train epoch loss', monitor, interval=1),
        'train_itr_loss':
        MonitorSeries('Train itr loss', monitor, interval=100),
        # 'val_epoch_loss': MonitorSeries('Val epoch loss', monitor, interval=1),
        'train_viz':
        MonitorImageTile('Train images', monitor, interval=1000, num_images=4)
    }

    # Create Solver. If training from checkpoint, load the info.
    if args.optimizer == "adam":
        solver = S.Adam(alpha=args.learning_rate, beta1=0.9, beta2=0.999)
    elif args.optimizer == "sgd":
        solver = S.Momentum(lr=args.learning_rate, momentum=0.9)
    solver.set_parameters(nn.get_parameters())

    # Initialize DataIterator
    data_dic = prepare_dataloader(args.dataset_path,
                                  datatype_list=['train', 'val'],
                                  batch_size=args.batch_size,
                                  img_size=(args.img_height, args.img_width))

    # Training loop.
    logger.info("Start training!!!")
    total_itr_index = 0
    for epoch in range(1, args.epochs + 1):
        ## === training === ##
        total_train_loss = 0
        index = 0
        while index < data_dic['train']['size']:
            # Preprocess
            image.d, label.d = data_dic['train']['itr'].next()
            loss.forward(clear_no_need_grad=True)
            # Initialize gradients
            solver.zero_grad()
            # Backward execution
            loss.backward(clear_buffer=True)
            # Update parameters by computed gradients
            if args.optimizer == 'sgd':
                # Weight decay is applied only for the SGD path.
                solver.weight_decay(1e-4)
            solver.update()

            # Update log
            index += 1
            total_itr_index += 1
            total_train_loss += loss.d

            # Pass to monitor
            monitors['train_itr_loss'].add(total_itr_index, loss.d)

            # Visualization
            # NOTE(review): pred is re-forwarded every iteration although the
            # image monitor only writes every 1000 adds — extra compute.
            pred.forward(clear_buffer=True)
            # Tile input / ground-truth depth / predicted depth side by side.
            train_viz = np.concatenate([
                image.d,
                convert_depth2colormap(label.d),
                convert_depth2colormap(pred.d)
            ],
                                       axis=3)
            monitors['train_viz'].add(total_itr_index, train_viz)

            # Logger
            logger.info("[{}] {}/{} Train Loss {} ({})".format(
                epoch, index, data_dic['train']['size'],
                total_train_loss / index, loss.d))

        # Pass training loss to a monitor.
        train_error = total_train_loss / data_dic['train']['size']
        monitors['train_epoch_loss'].add(epoch, train_error)

        # Save Parameter
        out_param_file = os.path.join(args.log_dir,
                                      'checkpoint' + str(epoch) + '.h5')
        nn.save_parameters(out_param_file)
Beispiel #24
0
def train():
    """Train VGG-16 on CIFAR-10, optionally with shuffled training labels.

    Dumps the exact input arrays for reproducibility, builds separate
    train/validation graphs, optimizes with momentum SGD under a YAML-driven
    learning-rate schedule, checkpoints periodically, and exports the final
    network as an NNP file. Relies on module-level ``args`` and
    ``monitor_path``.

    Fixes over the previous revision:
    * validation error is averaged over ``val_iter`` (the number of
      validation minibatches actually run) instead of ``args.val_iter``;
    * the epoch-averaged training error is no longer overwritten by the
      last minibatch's error just before monitoring.
    """
    bs_train, bs_valid = args.train_batch_size, args.val_batch_size
    extension_module = args.context
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    train_data_source = data_source_cifar10(train=True,
                                            shuffle=True,
                                            label_shuffle=args.shuffle_label)
    val_data_source = data_source_cifar10(train=False,
                                          shuffle=False,
                                          label_shuffle=False)
    n_train_samples = len(train_data_source.labels)
    n_val_samples = len(val_data_source.labels)
    # Data Iterator
    train_loader = data_iterator(train_data_source, bs_train, None, False,
                                 False)
    val_loader = data_iterator(val_data_source, bs_valid, None, False, False)

    # Dump the exact arrays used for training/validation so runs can be
    # reproduced (including the shuffled-label assignment, if any).
    input_save_dir = ("./data/input/shuffle"
                      if args.shuffle_label else "./data/input/no_shuffle")
    if not os.path.exists(input_save_dir):
        os.makedirs(input_save_dir)
    np.save(os.path.join(input_save_dir, "x_train.npy"),
            train_data_source.images)
    np.save(os.path.join(input_save_dir, "y_train.npy"),
            train_data_source.raw_label)
    np.save(os.path.join(input_save_dir, "x_val.npy"), val_data_source.images)
    np.save(os.path.join(input_save_dir, "y_val.npy"), val_data_source.labels)
    if args.shuffle_label:
        np.save(
            os.path.join(input_save_dir, "y_shuffle_train.npy"),
            train_data_source.labels,
        )

    model_prediction = vgg16_prediction
    prediction = functools.partial(model_prediction, ncls=10, seed=args.seed)

    # Create training graphs
    test = False
    image_train = nn.Variable((bs_train, 3, 32, 32))
    label_train = nn.Variable((bs_train, 1))
    pred_train, _ = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid, _ = prediction(image_valid, test)
    loss_val = loss_function(pred_valid, label_valid)

    # Start from zeroed gradients so the first backward accumulates cleanly.
    for param in nn.get_parameters().values():
        param.grad.zero()

    cfg = read_yaml("./learning_rate.yaml")
    print(cfg)
    lr_sched = create_learning_rate_scheduler(cfg.learning_rate_config)
    solver = S.Momentum(momentum=0.9, lr=lr_sched.get_lr())
    solver.set_parameters(nn.get_parameters())
    start_point = 0

    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed

    monitor = Monitor(monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_err = MonitorSeries("Training error", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_vloss = MonitorSeries("Test loss", monitor, interval=1)

    # Minibatches per epoch for each split.
    train_iter = math.ceil(n_train_samples / bs_train)
    val_iter = math.ceil(n_val_samples / bs_valid)

    # Training-loop (one `i` == one epoch; start_point allows resuming).
    for i in range(start_point, args.train_epochs):
        lr_sched.set_epoch(i)
        solver.set_learning_rate(lr_sched.get_lr())
        print("Learning Rate: ", lr_sched.get_lr())
        # Validation
        ve = 0.0
        vloss = 0.0
        print("## Validation")
        for j in range(val_iter):
            image, label, _ = val_loader.next()
            image_valid.d = image
            label_valid.d = label
            loss_val.forward()
            vloss += loss_val.data.data.copy() * bs_valid
            ve += categorical_error(pred_valid.d, label)
        # Average over the number of validation minibatches actually run
        # (was erroneously args.val_iter).
        ve /= val_iter
        vloss /= n_val_samples

        monitor_verr.add(i, ve)
        monitor_vloss.add(i, vloss)

        if int((i + 1) % args.model_save_interval) == 0:
            # save checkpoint file
            save_checkpoint(monitor_path, i + 1, solver)

        # Forward/Zerograd/Backward
        print("## Training")
        e = 0.0
        loss = 0.0
        for k in range(train_iter):
            image, label, shuffle = train_loader.next()
            image_train.d = image
            # Train against the (possibly shuffled) labels.
            label_train.d = shuffle
            loss_train.forward()
            solver.zero_grad()
            loss_train.backward()
            solver.update()
            e += categorical_error(pred_train.d, label_train.d)
            loss += loss_train.data.data.copy() * bs_train
        # Epoch averages (previously `e` was clobbered by the last batch's
        # error right before monitoring).
        e /= train_iter
        loss /= n_train_samples

        monitor_loss.add(i, loss)
        monitor_err.add(i, e)
        monitor_time.add(i)

    # save_nnp_lastepoch
    contents = save_nnp({"x": image_valid}, {"y": pred_valid}, bs_valid)
    save.save(os.path.join(monitor_path, ("vgg16_result.nnp")), contents)
Beispiel #25
0
def train():
    """
    Main script.

    Trains a ResNet on Tiny ImageNet (200 classes, 64x64 images) with
    momentum SGD, gradient accumulation, periodic validation, parameter
    snapshots, and step-wise learning-rate decay.
    """

    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Dataset
    # We use Tiny ImageNet from Stanford CS231N class.
    # https://tiny-imagenet.herokuapp.com/
    # Tiny ImageNet consists of 200 categories, each category has 500 images
    # in training set. The image size is 64x64. To adapt ResNet into 64x64
    # image inputs, the input image size of ResNet is set as 56x56, and
    # the stride in the first conv and the first max pooling are removed.
    data = data_iterator_tiny_imagenet(args.batch_size, 'train')
    vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')

    num_classes = 200
    tiny = True  # TODO: Switch ILSVRC2012 dataset and TinyImageNet.
    t_model = get_model(
        args, num_classes, test=False, tiny=tiny)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    v_model = get_model(
        args, num_classes, test=True, tiny=tiny)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=10)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)

    # Training loop.
    for i in range(args.max_iter):
        # Save parameters
        if i % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % args.val_interval == 0:

            # Clear all intermediate memory to save memory.
            # t_model.loss.clear_recursive()

            # Average loss/error over args.val_iter validation minibatches.
            l = 0.0
            e = 0.0
            for j in range(args.val_iter):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                # Cast inputs to compact dtypes on the device to save memory.
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                l += v_model.loss.d
                e += categorical_error(v_model.pred.d, v_model.label.d)
            monitor_vloss.add(i, l / args.val_iter)
            monitor_verr.add(i, e / args.val_iter)

            # Clear all intermediate memory to save memory.
            # v_model.loss.clear_recursive()

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        # Gradient accumulation loop: backward is called per minibatch and
        # gradients accumulate across args.accum_grad minibatches before
        # the single solver.update() below.
        for j in range(args.accum_grad):
            images, labels = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            l += t_model.loss.d
            e += categorical_error(t_model.pred.d, t_model.label.d)
        solver.weight_decay(args.weight_decay)
        solver.update()
        monitor_loss.add(i, l / args.accum_grad)
        monitor_err.add(i, e / args.accum_grad)
        monitor_time.add(i)

        # Learning rate decay at scheduled iter
        if i in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)
    nn.save_parameters(os.path.join(args.model_save_path,
                                    'param_%06d.h5' % args.max_iter))
Beispiel #26
0
def train(args):

    # get context

    ctx = get_extension_context(args.context)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    config = read_yaml(args.config)

    if args.info:
        config.monitor_params.info = args.info

    if comm.size == 1:
        comm = None
    else:
        # disable outputs from logger except its rank = 0
        if comm.rank > 0:
            import logging
            logger.setLevel(logging.ERROR)

    test = False
    train_params = config.train_params
    dataset_params = config.dataset_params
    model_params = config.model_params

    loss_flags = get_loss_flags(train_params)

    start_epoch = 0

    rng = np.random.RandomState(device_id)
    data_iterator = frame_data_iterator(
        root_dir=dataset_params.root_dir,
        frame_shape=dataset_params.frame_shape,
        id_sampling=dataset_params.id_sampling,
        is_train=True,
        random_seed=rng,
        augmentation_params=dataset_params.augmentation_params,
        batch_size=train_params['batch_size'],
        shuffle=True,
        with_memory_cache=False,
        with_file_cache=False)

    if n_devices > 1:
        data_iterator = data_iterator.slice(rng=rng,
                                            num_of_slices=comm.size,
                                            slice_pos=comm.rank)
        # workaround not to use memory cache
        data_iterator._data_source._on_memory = False
        logger.info("Disabled on memory data cache.")

    bs, h, w, c = [train_params.batch_size] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        # kp_X = {"value": Variable((bs, 10, 2)), "jacobian": Variable((bs, 10, 2, 2))}

        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=test,
                                    comm=comm)
        persistent_all(kp_source)

        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=test,
                                     comm=comm)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=kp_source,
                                              kp_driving=kp_driving,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=test,
                                              comm=comm)
        # generated is a dictionary containing;
        # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
        # 'sparse_deformed': Variable((bs, num_kp + 1, num_channel, h/4, w/4))
        # 'occlusion_map': Variable((bs, 1, h/4, w/4))
        # 'deformed': Variable((bs, c, h, w))
        # 'prediction': Variable((bs, c, h, w)) Only this is fed to discriminator.

    generated["prediction"].persistent = True

    pyramide_real = get_image_pyramid(driving, train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_real)

    pyramide_fake = get_image_pyramid(generated['prediction'],
                                      train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_fake)

    total_loss_G = None  # dammy. defined temporarily
    loss_var_dict = {}

    # perceptual loss using VGG19 (always applied)
    if loss_flags.use_perceptual_loss:
        logger.info("Use Perceptual Loss.")
        scales = train_params.scales
        weights = train_params.loss_weights.perceptual
        vgg_param_path = train_params.vgg_param_path
        percep_loss = perceptual_loss(pyramide_real, pyramide_fake, scales,
                                      weights, vgg_param_path)
        percep_loss.persistent = True
        loss_var_dict['perceptual_loss'] = percep_loss
        total_loss_G = percep_loss

    # (LS)GAN loss and feature matching loss
    if loss_flags.use_gan_loss:
        logger.info("Use GAN Loss.")
        with nn.parameter_scope("discriminator"):
            discriminator_maps_generated = multiscale_discriminator(
                pyramide_fake,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

            discriminator_maps_real = multiscale_discriminator(
                pyramide_real,
                kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test,
                comm=comm)

        for v in discriminator_maps_generated["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_generated["prediction_map_1"].persistent = True

        for v in discriminator_maps_real["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_real["prediction_map_1"].persistent = True

        for i, scale in enumerate(model_params.discriminator_params.scales):
            key = f'prediction_map_{scale}'.replace('.', '-')
            lsgan_loss_weight = train_params.loss_weights.generator_gan
            # LSGAN loss for Generator
            if i == 0:
                gan_loss_gen = lsgan_loss(discriminator_maps_generated[key],
                                          lsgan_loss_weight)
            else:
                gan_loss_gen += lsgan_loss(discriminator_maps_generated[key],
                                           lsgan_loss_weight)
            # LSGAN loss for Discriminator
            if i == 0:
                gan_loss_dis = lsgan_loss(discriminator_maps_real[key],
                                          lsgan_loss_weight,
                                          discriminator_maps_generated[key])
            else:
                gan_loss_dis += lsgan_loss(discriminator_maps_real[key],
                                           lsgan_loss_weight,
                                           discriminator_maps_generated[key])
        gan_loss_dis.persistent = True
        loss_var_dict['gan_loss_dis'] = gan_loss_dis
        total_loss_D = gan_loss_dis
        total_loss_D.persistent = True

        gan_loss_gen.persistent = True
        loss_var_dict['gan_loss_gen'] = gan_loss_gen
        total_loss_G += gan_loss_gen

        if loss_flags.use_feature_matching_loss:
            logger.info("Use Feature Matching Loss.")
            fm_weights = train_params.loss_weights.feature_matching
            fm_loss = feature_matching_loss(discriminator_maps_real,
                                            discriminator_maps_generated,
                                            model_params, fm_weights)
            fm_loss.persistent = True
            loss_var_dict['feature_matching_loss'] = fm_loss
            total_loss_G += fm_loss

    # transform loss
    if loss_flags.use_equivariance_value_loss or loss_flags.use_equivariance_jacobian_loss:
        transform = Transform(bs, **config.train_params.transform_params)
        transformed_frame = transform.transform_frame(driving)

        with nn.parameter_scope("kp_detector"):
            transformed_kp = detect_keypoint(transformed_frame,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=test,
                                             comm=comm)
        persistent_all(transformed_kp)

        # Value loss part
        if loss_flags.use_equivariance_value_loss:
            logger.info("Use Equivariance Value Loss.")
            warped_kp_value = transform.warp_coordinates(
                transformed_kp['value'])
            eq_value_weight = train_params.loss_weights.equivariance_value

            eq_value_loss = equivariance_value_loss(kp_driving['value'],
                                                    warped_kp_value,
                                                    eq_value_weight)
            eq_value_loss.persistent = True
            loss_var_dict['equivariance_value_loss'] = eq_value_loss
            total_loss_G += eq_value_loss

        # jacobian loss part
        if loss_flags.use_equivariance_jacobian_loss:
            logger.info("Use Equivariance Jacobian Loss.")
            arithmetic_jacobian = transform.jacobian(transformed_kp['value'])
            eq_jac_weight = train_params.loss_weights.equivariance_jacobian
            eq_jac_loss = equivariance_jacobian_loss(
                kp_driving['jacobian'], arithmetic_jacobian,
                transformed_kp['jacobian'], eq_jac_weight)
            eq_jac_loss.persistent = True
            loss_var_dict['equivariance_jacobian_loss'] = eq_jac_loss
            total_loss_G += eq_jac_loss

    assert total_loss_G is not None
    total_loss_G.persistent = True
    loss_var_dict['total_loss_gen'] = total_loss_G

    # -------------------- Create Monitors --------------------
    monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = get_monitors(
        config, loss_flags, loss_var_dict)

    if device_id == 0:
        # Dump training info .yaml
        _ = shutil.copy(args.config, log_dir)  # copy the config yaml
        training_info_yaml = os.path.join(log_dir, "training_info.yaml")
        os.rename(os.path.join(log_dir, os.path.basename(args.config)),
                  training_info_yaml)
        # then add additional information
        with open(training_info_yaml, "a", encoding="utf-8") as f:
            f.write(f"\nlog_dir: {log_dir}\nsaved_parameter: None")

    # -------------------- Solver Setup --------------------
    solvers = setup_solvers(train_params)
    solver_generator = solvers["generator"]
    solver_discriminator = solvers["discriminator"]
    solver_kp_detector = solvers["kp_detector"]

    # max epochs
    num_epochs = train_params['num_epochs']

    # iteration per epoch
    num_iter_per_epoch = data_iterator.size // bs
    # will be increased by num_repeat
    if 'num_repeats' in train_params or train_params['num_repeats'] != 1:
        num_iter_per_epoch *= config.train_params.num_repeats

    # modify learning rate if current epoch exceeds the number defined in
    lr_decay_at_epochs = train_params['epoch_milestones']  # ex. [60, 90]
    gamma = 0.1  # decay rate

    # -------------------- For finetuning ---------------------
    if args.ft_params:
        assert os.path.isfile(args.ft_params)
        logger.info(f"load {args.ft_params} for finetuning.")
        nn.load_parameters(args.ft_params)
        start_epoch = int(
            os.path.splitext(os.path.basename(
                args.ft_params))[0].split("epoch_")[1])

        # set solver's state
        for name, solver in solvers.items():
            saved_states = os.path.join(
                os.path.dirname(args.ft_params),
                f"state_{name}_at_epoch_{start_epoch}.h5")
            solver.load_states(saved_states)

        start_epoch += 1
        logger.info(f"Resuming from epoch {start_epoch}.")

    logger.info(
        f"Start training. Total epoch: {num_epochs - start_epoch}, {num_iter_per_epoch * n_devices} iter/epoch."
    )

    for e in range(start_epoch, num_epochs):
        logger.info(f"Epoch: {e} / {num_epochs}.")
        data_iterator._reset()  # rewind the iterator at the beginning

        # learning rate scheduler
        if e in lr_decay_at_epochs:
            logger.info("Learning rate decayed.")
            learning_rate_decay(solvers, gamma=gamma)

        for i in range(num_iter_per_epoch):
            _driving, _source = data_iterator.next()
            source.d = _source
            driving.d = _driving

            # update generator and keypoint detector
            total_loss_G.forward()

            if device_id == 0:
                monitors_gen.add((e * num_iter_per_epoch + i) * n_devices)

            solver_generator.zero_grad()
            solver_kp_detector.zero_grad()

            callback = None
            if n_devices > 1:
                params = [x.grad for x in solver_generator.get_parameters().values()] + \
                         [x.grad for x in solver_kp_detector.get_parameters().values()]
                callback = comm.all_reduce_callback(params, 2 << 20)
            total_loss_G.backward(clear_buffer=True,
                                  communicator_callbacks=callback)

            solver_generator.update()
            solver_kp_detector.update()

            if loss_flags.use_gan_loss:
                # update discriminator

                total_loss_D.forward(clear_no_need_grad=True)
                if device_id == 0:
                    monitors_dis.add((e * num_iter_per_epoch + i) * n_devices)

                solver_discriminator.zero_grad()

                callback = None
                if n_devices > 1:
                    params = [
                        x.grad for x in
                        solver_discriminator.get_parameters().values()
                    ]
                    callback = comm.all_reduce_callback(params, 2 << 20)
                total_loss_D.backward(clear_buffer=True,
                                      communicator_callbacks=callback)

                solver_discriminator.update()

            if device_id == 0:
                monitor_time.add((e * num_iter_per_epoch + i) * n_devices)

            if device_id == 0 and (
                (e * num_iter_per_epoch + i) *
                    n_devices) % config.monitor_params.visualize_freq == 0:
                images_to_visualize = [
                    source.d, driving.d, generated["prediction"].d
                ]
                visuals = combine_images(images_to_visualize)
                monitor_vis.add((e * num_iter_per_epoch + i) * n_devices,
                                visuals)

        if device_id == 0:
            if e % train_params.checkpoint_freq == 0 or e == num_epochs - 1:
                save_parameters(e, log_dir, solvers)

    return
Beispiel #27
0
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * AllReduce for gradients
      * Solver updates parameters by using gradients computed by backprop and all reduce.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size
    rng = np.random.RandomState(313)

    # Select network/data-iterator pair by dataset name.
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    elif args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100
    else:
        # Fail fast with a clear message instead of hitting a NameError
        # on `prediction` further below.
        raise ValueError("Unknown network: {}".format(args.net))

    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    # NOTE: "Paralell" is the (misspelled) class name exposed by the NNabla API.
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    pred_train.persistent = True  # keep activations for the error computation
    loss_train = loss_function(pred_train, label_train)
    error_train = F.mean(F.top_n_error(pred_train, label_train, axis=1))
    loss_error_train = F.sink(loss_train, error_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph (shares parameters with the training graph).
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    # BUGFIX: was `args.batch_size`; use bs_valid so the label shape always
    # matches the validation image batch, even if bs_valid diverges later.
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    # Linear LR warmup schedule over the first `warmup_epoch` epochs.
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator (seeded per device so each rank sees a different stream).
    rng = np.random.RandomState(device_id)
    _, tdata = data_iterator(args.batch_size, True, rng)
    vsource, vdata = data_iterator(args.batch_size, False)

    # Training-loop
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Validation: each rank evaluates a disjoint shard, then the local
        # error means are averaged across devices with all_reduce.
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image
                       ) != bs_valid:  # note that smaller batch is ignored
                    continue
                input_image_valid["image"].d = image
                input_image_valid["label"].d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            # Save model (rank-0 only, to avoid concurrent writes).
            if device_id == 0:
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(
                        os.path.join(args.model_save_path,
                                     'params_%06d.h5' % i))

        # Forward/Zerograd
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce (gradients are averaged across devices).
        backward_and_all_reduce(
            loss_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Synchronize by averaging the weights over devices using allreduce
        if (i + 1) % args.sync_weight_every_itr == 0:
            weights = [x.data for x in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:  # loss and error locally, and elapsed time
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    # Final checkpoint from rank 0.
    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
Beispiel #28
0
def animate(args):
    """Animate a still source image with the motion of a driving video.

    Uses a trained first-order-motion model: keypoints are detected on the
    source image, the first driving frame, and each subsequent driving frame;
    the occlusion-aware generator then synthesizes one output frame per
    driving frame.  The result is saved as an mp4 (or PNG sequence when
    ``args.output_png`` is set).
    """

    # get context
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    logger.setLevel(logging.ERROR)  # to suppress minor messages

    # When no config is given, download the pretrained voxceleb config and
    # weights.  Giving params without a config is rejected.
    if not args.config:
        assert not args.params, "pretrained weights file is given, but corresponding config file is not. Please give both."
        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/voxceleb_trained_info.yaml"
        )
        args.config = 'voxceleb_trained_info.yaml'

        download_provided_file(
            "https://nnabla.org/pretrained-models/nnabla-examples/GANs/first-order-model/pretrained_fomm_params.h5"
        )

    config = read_yaml(args.config)

    dataset_params = config.dataset_params
    model_params = config.model_params

    # Detailed mode renders keypoints/occlusion maps via the Visualizer.
    if args.detailed:
        vis_params = config.visualizer_params
        visualizer = Visualizer(**vis_params)

    # Resolve the parameter file: explicit --params wins, otherwise locate it
    # from the training log directory recorded in the config.
    if not args.params:
        assert "log_dir" in config, "no log_dir found in config. therefore failed to locate pretrained parameters."
        param_file = os.path.join(config.log_dir, config.saved_parameters)
    else:
        param_file = args.params
    print(f"Loading {param_file} for image animation...")
    nn.load_parameters(param_file)

    # Input variables: batch size 1, channel-first frames.
    bs, h, w, c = [1] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving_initial = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    filename = args.driving

    # process repeated until all the test data is used
    driving_video = read_video(
        filename, dataset_params.frame_shape)  # (#frames, h, w, 3)
    driving_video = np.transpose(driving_video,
                                 (0, 3, 1, 2))  # (#frames, 3, h, w)

    # Source image normalized to [0, 1]; keep only RGB channels.
    source_img = imread(args.source, channel_first=True,
                        size=(256, 256)) / 255.
    source_img = source_img[:3]

    source.d = np.expand_dims(source_img, 0)
    driving_initial.d = driving_video[0][:3, ]

    # Keypoint graphs for source, first driving frame, and current frame.
    # All three share the "kp_detector" parameter scope.
    with nn.parameter_scope("kp_detector"):
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=True,
                                    comm=False)
        persistent_all(kp_source)

    with nn.parameter_scope("kp_detector"):
        kp_driving_initial = detect_keypoint(driving_initial,
                                             **model_params.kp_detector_params,
                                             **model_params.common_params,
                                             test=True,
                                             comm=False)
        persistent_all(kp_driving_initial)

    with nn.parameter_scope("kp_detector"):
        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=True,
                                     comm=False)
        persistent_all(kp_driving)

    # Scale the driving motion by the ratio of convex-hull areas of the
    # source vs. initial-driving keypoints (per the FOMM paper).
    if args.adapt_movement_scale:
        nn.forward_all([
            kp_source["value"], kp_source["jacobian"],
            kp_driving_initial["value"], kp_driving_initial["jacobian"]
        ])
        source_area = ConvexHull(kp_source['value'].d[0]).volume
        driving_area = ConvexHull(kp_driving_initial['value'].d[0]).volume
        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
    else:
        adapt_movement_scale = 1

    # unlink_all stops gradients/graph growth through the static keypoints;
    # only kp_driving stays linked so it is recomputed per frame.
    kp_norm = adjust_kp(kp_source=unlink_all(kp_source),
                        kp_driving=kp_driving,
                        kp_driving_initial=unlink_all(kp_driving_initial),
                        adapt_movement_scale=adapt_movement_scale,
                        use_relative_movement=args.unuse_relative_movement,
                        use_relative_jacobian=args.unuse_relative_jacobian)
    persistent_all(kp_norm)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(source,
                                              kp_source=unlink_all(kp_source),
                                              kp_driving=kp_norm,
                                              **model_params.generator_params,
                                              **model_params.common_params,
                                              test=True,
                                              comm=False)

    if not args.full and 'sparse_deformed' in generated:
        del generated['sparse_deformed']  # remove needless info

    persistent_all(generated)

    generated['kp_driving'] = kp_driving
    generated['kp_source'] = kp_source
    generated['kp_norm'] = kp_norm

    # generated contains these values;
    # 'mask': <Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': <Variable((bs, num_kp+1, num_channel, h/4, w/4))  # (bs, num_kp + 1, c, h, w)
    # 'occlusion_map': <Variable((bs, 1, h/4, w/4))
    # 'deformed': <Variable((bs, c, h, w))
    # 'prediction': <Variable((bs, c, h, w))

    mode = "arbitrary"
    if "log_dir" in config:
        result_dir = os.path.join(args.out_dir,
                                  os.path.basename(config.log_dir), f"{mode}")
    else:
        result_dir = os.path.join(args.out_dir, "test_result", f"{mode}")

    # create an empty directory to save generated results
    _ = nm.Monitor(result_dir)

    # load the header images.
    header = imread("imgs/header_combined.png", channel_first=True)
    generated_images = list()

    # compute these in advance and reuse
    nn.forward_all([kp_source["value"], kp_source["jacobian"]],
                   clear_buffer=True)
    nn.forward_all(
        [kp_driving_initial["value"], kp_driving_initial["jacobian"]],
        clear_buffer=True)

    num_of_driving_frames = driving_video.shape[0]

    # Per-frame generation: feed the next driving frame, run the generator.
    for frame_idx in tqdm(range(num_of_driving_frames)):
        driving.d = driving_video[frame_idx][:3, ]
        nn.forward_all([generated["prediction"], generated["deformed"]],
                       clear_buffer=True)

        if args.detailed:
            # visualize source w/kp, driving w/kp, deformed source, generated w/kp, generated image, occlusion map
            visualization = visualizer.visualize(source=source.d,
                                                 driving=driving.d,
                                                 out=generated)
            if args.full:
                visualization = reshape_result(visualization)  # (H, W, C)
            combined_image = visualization.transpose(2, 0, 1)  # (C, H, W)

        elif args.only_generated:
            combined_image = np.clip(generated["prediction"].d[0], 0.0, 1.0)
            combined_image = (255 * combined_image).astype(
                np.uint8)  # (C, H, W)

        else:
            # visualize source, driving, and generated image as a 2x2 tile
            # (header+source on top, driving+prediction below).
            driving_fake = np.concatenate([
                np.clip(driving.d[0], 0.0, 1.0),
                np.clip(generated["prediction"].d[0], 0.0, 1.0)
            ],
                                          axis=2)
            header_source = np.concatenate([
                np.clip(header / 255., 0.0, 1.0),
                np.clip(source.d[0], 0.0, 1.0)
            ],
                                           axis=2)
            combined_image = np.concatenate([header_source, driving_fake],
                                            axis=1)
            combined_image = (255 * combined_image).astype(np.uint8)

        generated_images.append(combined_image)

    # once each video is generated, save it.
    output_filename = f"{os.path.splitext(os.path.basename(filename))[0]}.mp4"
    output_filename = f"{os.path.basename(args.source)}_by_{output_filename}"
    output_filename = output_filename.replace("#", "_")
    if args.output_png:
        monitor_vis = nm.MonitorImage(output_filename,
                                      nm.Monitor(result_dir),
                                      interval=1,
                                      num_images=1,
                                      normalize_method=lambda x: x)
        for frame_idx, img in enumerate(generated_images):
            monitor_vis.add(frame_idx, img)
    else:
        generated_images = [_.transpose(1, 2, 0) for _ in generated_images]
        # you might need to change ffmpeg_params according to your environment.
        mimsave(f'{os.path.join(result_dir, output_filename)}',
                generated_images,
                fps=args.fps,
                ffmpeg_params=[
                    "-pix_fmt", "yuv420p", "-vcodec", "libx264", "-f", "mp4",
                    "-q", "0"
                ])

    return
Beispiel #29
0
def train():
    """Train a CNN classifier on MNIST and periodically validate and save.

    Pipeline:

    * Parse command line arguments and select a compute context.
    * Build separate training/validation graphs that share parameters.
    * Run the optimization loop with periodic validation, checkpointing
      and stat monitoring, then save the final parameters.
    """
    args = get_args()

    # Select the computation context ('cpu' unless one is specified).
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    nn.set_default_context(
        extension_context(extension_module, device_id=args.device_id))

    # Network architecture: LeNet by default, ResNet on request.
    if args.net == 'resnet':
        network = mnist_resnet_prediction
    else:
        network = mnist_lenet_prediction

    # Training graph.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    pred = network(image, test=False)
    pred.persistent = True  # keep prediction values for the error metric
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # Validation graph (shares parameters with the training graph).
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vpred = network(vimage, test=True)

    # Optimizer over all registered parameters.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors for loss/error/time on training and validation.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # MNIST data iterators (train / validation splits).
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)

    def run_validation():
        """Return the mean categorical error over args.val_iter batches."""
        total = 0.0
        for _ in range(args.val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            total += categorical_error(vpred.d, vlabel.d)
        return total / args.val_iter

    for it in range(args.max_iter):
        if it % args.val_interval == 0:
            monitor_verr.add(it, run_validation())
        if it % args.model_save_interval == 0:
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % it))

        # One optimization step: forward, backward, decay, update.
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        monitor_loss.add(it, loss.d.copy())
        monitor_err.add(it, categorical_error(pred.d, label.d))
        monitor_time.add(it)

    # Final validation pass, logged at the last training iteration index.
    monitor_verr.add(it, run_validation())

    # Persist the trained parameters.
    nn.save_parameters(os.path.join(
        args.model_save_path,
        '{}_params_{:06}.h5'.format(args.net, args.max_iter)))
Beispiel #30
0
def train(args):
    """Train a GAN on CIFAR-10 with a WGAN-GP-style gradient penalty.

    Builds generator/discriminator graphs, alternates n_critic discriminator
    updates with one generator update, and periodically saves parameters and
    image tiles.
    """
    # Context
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_  # gradient-penalty weight

    # Model
    # generator loss
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(p_fake).apply(persistent=True)
    # discriminator loss (fresh discriminator pass on the same x_fake)
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)
    # gradient penalty on a random real/fake interpolation
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    p_rmix = discriminator(x_rmix, maps=maps)
    # NOTE(review): need_grad is set after x_rmix was already fed to the
    # discriminator; presumably the flag is honored at backward time here —
    # confirm against the nnabla double-backward semantics.
    x_rmix.need_grad = True  # Enabling gradient computation for double backward
    grads = nn.grad([p_rmix], [x_rmix])
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum([F.mean((l - 1.0)**2.0) for l in l2norms])
    loss_dis += lambda_ * gp
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solver: one Adam optimizer per sub-network, scoped by parameter scope.
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)

    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries("Negative Critic Loss",
                                     monitor,
                                     interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=batch_size,
                                                interval=1,
                                                normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Data Iterator
    di = data_iterator_cifar10(batch_size, True)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator: n_critic steps per generator step (WGAN).
        x_fake.need_grad = False  # no need backward to generator
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            # Real images rescaled from [0, 255] to [-1, 1].
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need backward to generator
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()
        # Monitor
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Save
        if i % args.save_interval == 0:
            monitor_image_tile_train.add(i, x_fake)
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Last: render the fixed test batch and save final parameters/tiles.
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
Beispiel #31
0
def train(args):
    """
    Main script.

    Trains a DCGAN on MNIST: builds generator/discriminator graphs with
    separate solvers per parameter scope, alternates generator and
    discriminator updates, and finally exports an .nnp runtime bundle.
    """

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    # TRAIN

    # Fake path
    z = nn.Variable([args.batch_size, 100, 1, 1])
    fake = generator(z)
    fake.persistent = True  # Not to clear at backward
    pred_fake = discriminator(fake)
    # Generator wants D to label fakes as real (target = 1).
    loss_gen = F.mean(F.sigmoid_cross_entropy(
        pred_fake, F.constant(1, pred_fake.shape)))
    # Unlinked copy stops the discriminator loss from backpropagating
    # into the generator.
    fake_dis = fake.unlinked()
    pred_fake_dis = discriminator(fake_dis)
    loss_dis = F.mean(F.sigmoid_cross_entropy(
        pred_fake_dis, F.constant(0, pred_fake_dis.shape)))

    # Real path
    x = nn.Variable([args.batch_size, 1, 28, 28])
    pred_real = discriminator(x)
    loss_dis += F.mean(F.sigmoid_cross_entropy(pred_real,
                                               F.constant(1, pred_real.shape)))

    # Create Solver: one Adam per sub-network, scoped by parameter scope.
    solver_gen = S.Adam(args.learning_rate, beta1=0.5)
    solver_dis = S.Adam(args.learning_rate, beta1=0.5)
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss_gen = M.MonitorSeries("Generator loss", monitor, interval=10)
    monitor_loss_dis = M.MonitorSeries(
        "Discriminator loss", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100)
    # NOTE(review): "x + 1 / 2." parses as x + 0.5, which maps the
    # [-0.5, 0.5] inputs (see scaling below) into [0, 1] — the precedence
    # appears intentional here, but "(x + 1) / 2." may have been meant.
    monitor_fake = M.MonitorImageTile(
        "Fake images", monitor, normalize_method=lambda x: x + 1 / 2.)

    data = data_iterator_mnist(args.batch_size, True)
    # Training loop.
    for i in range(args.max_iter):
        # Periodic per-scope checkpoints.
        if i % args.model_save_interval == 0:
            with nn.parameter_scope("gen"):
                nn.save_parameters(os.path.join(
                    args.model_save_path, "generator_param_%06d.h5" % i))
            with nn.parameter_scope("dis"):
                nn.save_parameters(os.path.join(
                    args.model_save_path, "discriminator_param_%06d.h5" % i))

        # Training forward
        image, _ = data.next()
        x.d = image / 255. - 0.5  # scale [0, 255] into [-0.5, 0.5]
        z.d = np.random.randn(*z.shape)

        # Generator update.
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.weight_decay(args.weight_decay)
        solver_gen.update()
        monitor_fake.add(i, fake)
        monitor_loss_gen.add(i, loss_gen.d.copy())

        # Discriminator update.
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        solver_dis.weight_decay(args.weight_decay)
        solver_dis.update()
        monitor_loss_dis.add(i, loss_dis.d.copy())
        monitor_time.add(i)

    # Export both networks as an .nnp runtime bundle and sanity-check the
    # generator with the C++ forward runner.
    nnp = os.path.join(
        args.model_save_path, 'dcgan_%06d.nnp' % args.max_iter)
    runtime_contents = {
        'networks': [
            {'name': 'Generator',
             'batch_size': args.batch_size,
             'outputs': {'G': fake},
             'names': {'z': z}},
            {'name': 'Discriminator',
             'batch_size': args.batch_size,
             'outputs': {'D': pred_real},
             'names': {'x': x}}],
        'executors': [
            {'name': 'Generator',
             'network': 'Generator',
             'data': ['z'],
             'output': ['G']},
            {'name': 'Discriminator',
             'network': 'Discriminator',
             'data': ['x'],
             'output': ['D']}]}

    save.save(nnp, runtime_contents)
    from cpp_forward_check import check_cpp_forward
    check_cpp_forward(args.model_save_path, [z.d], [z], fake, nnp, "Generator")
Beispiel #32
0
def test_save_load_multi_datasets(tmpdir, datasets_o, datasets_m):
    """Round-trip an NNP file whose contents reference two datasets.

    Builds a minimal affine -> sigmoid -> binary-cross-entropy graph, saves
    it with two dataset entries (plus optimizer/monitor/executor sections),
    then reloads the saved file to verify save/load compatibility.
    """
    nn.clear_parameters()
    ctx = get_extension_context('cpu', device_id=0, type_config='float')
    nn.set_default_context(ctx)

    batch_size = 64

    # Tiny graph: single affine layer, sigmoid output, BCE against an
    # all-ones target.
    x = nn.Variable([batch_size, 1, 28, 28])
    Affine = PF.affine(x, 1, name='Affine')
    Sigmoid = F.sigmoid(Affine)

    target = nn.Variable([batch_size, 1])
    target.data.fill(1)
    BinaryCrossEntropy = F.binary_cross_entropy(Sigmoid, target)

    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    solver.set_learning_rate(5e-4)

    def _network(name, outputs):
        # All three networks bind the same single input variable.
        return {
            'name': name,
            'batch_size': batch_size,
            'outputs': outputs,
            'names': {'x': x},
        }

    def _dataset(name, uri):
        # The two dataset entries differ only in name and URI.
        return {
            'name': name,
            'uri': uri,
            'cache_dir': 'here_it_is',
            'shuffle': True,
            'batch_size': batch_size,
            'no_image_normalization': False,
            'variables': {'x': x, 'BinaryCrossEntropy': BinaryCrossEntropy},
        }

    def _monitor(name):
        # Both monitors evaluate MainValidation on the monitor dataset(s).
        return {
            'name': name,
            'network': 'MainValidation',
            'dataset': datasets_m,
        }

    contents = {
        'global_config': {'default_context': ctx},
        'training_config': {
            'max_epoch': 100,
            'iter_per_epoch': 23,
            'save_best': True,
            'monitor_interval': 10,
        },
        'networks': [
            _network('Main', {'BinaryCrossEntropy': BinaryCrossEntropy}),
            _network('MainValidation',
                     {'BinaryCrossEntropy': BinaryCrossEntropy}),
            _network('MainRuntime', {'Sigmoid': Sigmoid}),
        ],
        'datasets': [
            _dataset('dataset1', 'DATASET_TRAINING1'),
            _dataset('dataset2', 'DATASET_TRAINING2'),
        ],
        'optimizers': [{
            'name': 'optimizer',
            'solver': solver,
            'network': 'Main',
            'dataset': datasets_o,
            'weight_decay': 0,
            'lr_decay': 1,
            'lr_decay_interval': 1,
            'update_interval': 1,
        }],
        'monitors': [_monitor('train_error'), _monitor('valid_error')],
        'executors': [{
            'name': 'Executor',
            'network': 'MainRuntime',
            'data': ['x'],
            'output': ['Sigmoid'],
        }],
    }

    tmpdir.ensure(dir=True)
    nnp_file = tmpdir.join('testsave.nnp').strpath
    nnabla.utils.save.save(nnp_file, contents)
    # Reload to confirm the written file is consumable.
    nnabla.utils.load.load([nnp_file])
Beispiel #33
0
def main():
    """
    Main script for Virtual Adversarial Training (VAT) on MNIST.

    Steps:
    * Get and set context.
    * Load dataset.
    * Initialize DataIterators.
    * Create networks:
    *   net for labeled data,
    *   net for unlabeled data,
    *   net for test data.
    * Create solver.
    * Training loop:
    *   test,
    *   training
    *     by labeled data (cross-entropy loss),
    *     by unlabeled data (estimate adversarial direction, LDS loss).
    """

    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    shape_x = (1, 28, 28)
    n_h = args.n_units
    n_y = args.n_class

    # Load MNIST dataset and rescale pixel values into [0, 1).
    from mnist_data import MnistDataSource
    with MnistDataSource(train=True) as d:
        x_t = d.images
        t_t = d.labels
    with MnistDataSource(train=False) as d:
        x_v = d.images
        t_v = d.labels
    x_t = np.array(x_t / 256.0).astype(np.float32)
    x_t, t_t = x_t[:args.n_train], t_t[:args.n_train]
    x_v, t_v = x_v[:args.n_valid], t_v[:args.n_valid]

    # Create semi-supervised datasets. The labeled samples are also reused
    # (without their labels) as part of the unlabeled pool.
    x_l, t_l, x_u, _ = split_dataset(x_t, t_t, args.n_labeled, args.n_class)
    x_u = np.r_[x_l, x_u]
    x_v = np.array(x_v / 256.0).astype(np.float32)

    # Create DataIterators for labeled, unlabeled and validation datasets.
    di_l = DataIterator(args.batchsize_l, [x_l, t_l])
    di_u = DataIterator(args.batchsize_u, [x_u])
    di_v = DataIterator(args.batchsize_v, [x_v, t_v])

    # Feed-forward-net building function shared by all three networks.
    def forward(x, test=False):
        return mlp_net(x, n_h, n_y, test)

    # Net for learning labeled data (standard softmax cross-entropy).
    xl = nn.Variable((args.batchsize_l,) + shape_x, need_grad=False)
    hl = forward(xl, test=False)
    tl = nn.Variable((args.batchsize_l, 1), need_grad=False)
    loss_l = F.mean(F.softmax_cross_entropy(hl, tl))

    # Net for learning unlabeled data via VAT:
    #   r   - adversarial direction, refined by the power method below;
    #   eps - noise scale (xi during power iteration, epsilon for the loss).
    xu = nn.Variable((args.batchsize_u,) + shape_x, need_grad=False)
    r = nn.Variable((args.batchsize_u,) + shape_x, need_grad=True)
    eps = nn.Variable((args.batchsize_u,) + shape_x, need_grad=False)
    loss_u, yu = vat(xu, r, eps, forward, distance)

    # Net for evaluating validation data.
    xv = nn.Variable((args.batchsize_v,) + shape_x, need_grad=False)
    hv = forward(xv, test=True)
    tv = nn.Variable((args.batchsize_v, 1), need_grad=False)

    # Create solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitor training and validation stats.
    import nnabla.monitor as M
    monitor = M.Monitor(args.model_save_path)
    monitor_verr = M.MonitorSeries("Test error", monitor, interval=240)
    monitor_time = M.MonitorTimeElapsed("Elapsed time", monitor, interval=240)

    # Training loop.
    for i in range(args.max_iter):

        # Validation test.
        if i % args.val_interval == 0:
            n_error = calc_validation_error(
                di_v, xv, tv, hv, args.val_iter)
            monitor_verr.add(i, n_error)

        #################################
        ## Training by Labeled Data #####
        #################################

        # Input minibatch of labeled data into variables.
        xl.d, tl.d = di_l.next()

        # Initialize gradients.
        solver.zero_grad()

        # Forward, backward and update.
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        #################################
        ## Training by Unlabeled Data ###
        #################################

        # Input minibatch of unlabeled data into variables.
        xu.d, = di_u.next()

        ##### Calculate Adversarial Noise #####

        # Sample random noise.
        n = np.random.normal(size=xu.shape).astype(np.float32)

        # Normalize noise vector and input to variable.
        r.d = get_direction(n)

        # Set xi, the power-method scaling parameter.
        eps.data.fill(args.xi_for_vat)

        # Calculate y without noise, only once.
        yu.forward(clear_buffer=True)

        # Do power-method iteration to refine the adversarial direction.
        for k in range(args.n_iter_for_power_method):
            # Initialize gradient to receive value.
            r.grad.zero()

            # Forward, backward, without update.
            loss_u.forward(clear_no_need_grad=True)
            loss_u.backward(clear_buffer=True)

            # Normalize gradient vector and input to variable.
            r.d = get_direction(r.g)

        ##### Calculate loss for unlabeled data #####

        # Clear remained gradients.
        solver.zero_grad()

        # Set epsilon, the adversarial noise scaling parameter.
        eps.data.fill(args.eps_for_vat)

        # Forward, backward and update.
        loss_u.forward(clear_no_need_grad=True)
        loss_u.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()

        ##### Learning rate update #####
        if i % args.iter_per_epoch == 0:
            solver.set_learning_rate(
                solver.learning_rate() * args.learning_rate_decay)
        monitor_time.add(i)

    # Evaluate the final model by the error rate with the validation dataset.
    valid_error = calc_validation_error(di_v, xv, tv, hv, args.val_iter)
    monitor_verr.add(i, valid_error)
    monitor_time.add(i)

    # Save the model as an NNP file (validation network only).
    nnp_file = os.path.join(
        args.model_save_path, 'vat_%06d.nnp' % args.max_iter)
    runtime_contents = {
        'networks': [
            {'name': 'Validation',
             'batch_size': args.batchsize_v,
             'outputs': {'y': hv},
             'names': {'x': xv}}],
        'executors': [
            {'name': 'Runtime',
             'network': 'Validation',
             'data': ['x'],
             'output': ['y']}]}
    save.save(nnp_file, runtime_contents)

    # Cross-check the saved network through the C++ forward path.
    from cpp_forward_check import check_cpp_forward
    check_cpp_forward(args.model_save_path, [xv.d], [xv], hv, nnp_file)
def main():
    """Build a LeNet/ResNet MNIST training graph and save it as an NNP file.

    Locates the local nnabla-examples checkout (via ``NNABLA_EXAMPLES_ROOT``),
    imports the MNIST model builders from it, constructs training/validation
    graphs plus a solver, and writes everything into an initialized NNP file.
    """

    # Read envvar `NNABLA_EXAMPLES_ROOT` to identify the path to your local
    # nnabla-examples directory.
    HERE = os.path.dirname(__file__)
    nnabla_examples_root = os.environ.get('NNABLA_EXAMPLES_ROOT', os.path.join(
        HERE, '../../../../nnabla-examples'))
    mnist_examples_root = os.path.realpath(
        os.path.join(nnabla_examples_root, 'mnist-collection'))
    sys.path.append(mnist_examples_root)
    nnabla_examples_git_url = 'https://github.com/sony/nnabla-examples'

    # Check if nnabla-examples found.
    try:
        from args import get_args
    except ImportError:
        print(
            'An envvar `NNABLA_EXAMPLES_ROOT`'
            ' which locates the local path to '
            '[nnabla-examples]({})'
            ' repository must be set correctly.'.format(
                nnabla_examples_git_url),
            file=sys.stderr)
        raise

    # Import MNIST data
    from mnist_data import data_iterator_mnist
    from classification import mnist_lenet_prediction, mnist_resnet_prediction

    import argparse
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--max_epoch", "-me", type=int, default=100)
    parser.add_argument("--iter_per_epoch", "-ipe", type=int, default=937)
    parser.add_argument("--cache_dir", "-cd", type=str, default='cache')
    parser.add_argument("--batch-size", "-b", type=int, default=128)
    parser.add_argument("--learning-rate", "-l", type=float, default=1e-3)
    parser.add_argument("--weight-decay", "-w", type=float, default=0)
    parser.add_argument("--device-id", "-d", type=str, default='0')
    parser.add_argument("--type-config", "-t", type=str, default='float')
    parser.add_argument("--net", "-n", type=str, default='lenet')
    parser.add_argument('--context', '-c', type=str,
                        default='cpu', help="Extension modules. ex) 'cpu', 'cudnn'.")
    args = parser.parse_args()

    # NOTE(review): the arguments are parsed twice; `args_added` duplicates
    # `args` and is only consulted for `iter_per_epoch` below — likely a
    # leftover, verify before cleaning up.
    args_added = parser.parse_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Select the prediction network builder; LeNet by default.
    mnist_cnn_prediction = mnist_lenet_prediction
    if args.net == 'resnet':
        mnist_cnn_prediction = mnist_resnet_prediction

    # Create a computation graph to be saved.
    # Training and validation graphs share the same input variables `x`/`t`.
    x = nn.Variable([args.batch_size, 1, 28, 28])
    t = nn.Variable([args.batch_size, 1])
    h_t = mnist_cnn_prediction(x, test=False, aug=False)
    loss_t = F.mean(F.softmax_cross_entropy(h_t, t))
    h_v = mnist_cnn_prediction(x, test=True, aug=False)
    loss_v = F.mean(F.softmax_cross_entropy(h_v, t))

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Save NNP file (used in C++ inference later.).
    nnp_file = '{}_initialized.nnp'.format(args.net)
    training_contents = {
        'global_config': {'default_context': ctx},
        'training_config':
            {'max_epoch': args.max_epoch,
             'iter_per_epoch': args_added.iter_per_epoch,
             'save_best': True},
        'networks': [
            {'name': 'training',
             'batch_size': args.batch_size,
             'outputs': {'loss': loss_t},
             'names': {'x': x, 'y': t, 'loss': loss_t}},
            {'name': 'validation',
             'batch_size': args.batch_size,
             'outputs': {'loss': loss_v},
             'names': {'x': x, 'y': t, 'loss': loss_v}}],
        'optimizers': [
            {'name': 'optimizer',
             'solver': solver,
             'network': 'training',
             'dataset': 'mnist_training',
             'weight_decay': 0,
             'lr_decay': 1,
             'lr_decay_interval': 1,
             'update_interval': 1}],
        'datasets': [
            {'name': 'mnist_training',
             'uri': 'MNIST_TRAINING',
             'cache_dir': args.cache_dir + '/mnist_training.cache/',
             'variables': {'x': x, 'y': t},
             'shuffle': True,
             'batch_size': args.batch_size,
             'no_image_normalization': True},
            {'name': 'mnist_validation',
             'uri': 'MNIST_VALIDATION',
             'cache_dir': args.cache_dir + '/mnist_test.cache/',
             'variables': {'x': x, 'y': t},
             'shuffle': False,
             'batch_size': args.batch_size,
             'no_image_normalization': True
             }],
        'monitors': [
            {'name': 'training_loss',
             'network': 'validation',
             'dataset': 'mnist_training'},
            {'name': 'validation_loss',
             'network': 'validation',
             'dataset': 'mnist_validation'}],
    }
    nn.utils.save.save(nnp_file, training_contents)