# Imports assumed by these snippets; the originals omit them.
# `iterator` (the data-iterator helper) and `load_kanji_data` are
# project-local modules that are not shown in the source.
import os

import numpy as np

import nnabla as nn
import nnabla.functions as F
import nnabla.solvers as S
from nnabla import logger


def train(args):
    """
    Main script.
    """

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.

    # Fake path (dead code from the original, kept commented out):
    #nn.load_parameters("/home/mizuochi/programing/font/dcgan_model_0220/generator_param_522000.h5")
    #z.d = np.random.randn(*z.shape)
    #gen.forward()
    #for i in range(40):
    #    Image.fromarray(np.uint8((gen.d[i][0]+1)*255/2.0)).save("./test/"+str(i)+".png")

    # Real path
    x = nn.Variable([args.batch_size, 1, 28, 28])
    x_ref = nn.Variable([args.batch_size, 1, 28, 28])
    #vec = nn.Variable([args.batch_size, 100])
    pred_vec = vectorizer(x, test=False)
    print(pred_vec.shape)
    #z = pred_vec.reshape((args.batch_size, 100, 1, 1))
    #gen = generator(z, test=True)
    gen = generator(pred_vec, test=False)
    gen.persistent = True  # Keep the buffer so backward does not clear it.
    with nn.parameter_scope("gen"):
        nn.load_parameters("/home/mizuochi/programing/font/dcgan_model_0220/generator_param_290000.h5")
    #loss_dis = F.mean(F.sigmoid_cross_entropy(pred_vec, vec))
    print("x_ref shape", x_ref.shape)
    print("gen shape", gen.shape)
    #loss_dis = F.mean(F.squared_error(x_ref.reshape((64, 28*28)), gen.reshape((64, 28*28))))
    loss_dis = F.mean(F.squared_error(x_ref, gen))
    print(loss_dis.d)  # Uninitialized until the first forward pass.

    # Create Solver.
    solver_dis = S.Adam(args.learning_rate, beta1=0.5)
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())
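    # Note: the weights are saved below as "vectorizer_param_*" from the
    # "dis" scope, so the (unshown) vectorizer() evidently registers its
    # parameters under "dis" in this variant.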

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss_dis = M.MonitorSeries(
        "Discriminator loss", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100)

    #data = data_iterator_mnist(args.batch_size, True)
    data = iterator.simple_data_iterator(load_kanji_data(), args.batch_size, True)

    # Training loop.
    for i in range(args.max_iter):
        if i % args.model_save_interval == 0:
            with nn.parameter_scope("dis"):
                nn.save_parameters(os.path.join(
                    args.model_save_path, "vectorizer_param_%06d.h5" % i))

        # Training forward
        buf, _ = data.next()
        x.d = buf * 2.0 / 255. - 1.0  # [0, 255] -> [-1, 1]
        x_ref.d = buf * 2.0 / 255. - 1.0

        # Discriminator update.
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        solver_dis.weight_decay(args.weight_decay)
        solver_dis.update()
        monitor_loss_dis.add(i, loss_dis.d.copy())
        monitor_time.add(i)

    with nn.parameter_scope("dis"):
        nn.save_parameters(os.path.join(
            args.model_save_path, "vectorizer_param_%06d.h5" % i))
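
# The snippets in this listing call vectorizer(), generator(),
# discriminator(), and load_kanji_data() without defining them. The sketch
# below is an illustrative assumption (loosely following the nnabla DCGAN
# example), not the author's actual code: generator() and discriminator()
# would be the usual DCGAN deconvolution/convolution stacks, and
# load_kanji_data() would return uint8 images of shape (N, 1, 28, 28)
# with labels.
import nnabla.parametric_functions as PF


def vectorizer(x, test=False):
    # Hypothetical encoder: (B, 1, 28, 28) image -> (B, 100, 1, 1) code.
    with nn.parameter_scope("vec"):
        with nn.parameter_scope("conv1"):
            h = F.elu(PF.batch_normalization(
                PF.convolution(x, 64, (3, 3), pad=(1, 1), stride=(2, 2)),
                batch_stat=not test))
        with nn.parameter_scope("conv2"):
            h = F.elu(PF.batch_normalization(
                PF.convolution(h, 128, (3, 3), pad=(1, 1), stride=(2, 2)),
                batch_stat=not test))
        with nn.parameter_scope("fc1"):
            z = PF.affine(h, (100, 1, 1))
    return z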
Example no. 2
def train(args):
    """
    Main script.
    """

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.

    # Fake path
    x1 = nn.Variable([args.batch_size, 1, 28, 28])

    #z = nn.Variable([args.batch_size, VEC_SIZE, 1, 1])
    #z = vectorizer(x1,maxh = 1024)
    #fake = generator(z,maxh= 1024)
    z = vectorizer(x1)
    fake = generator(z)
    fake.persistent = True  # Keep the buffer so backward does not clear it.
    pred_fake = discriminator(fake)
    loss_gen = F.mean(
        F.sigmoid_cross_entropy(pred_fake, F.constant(1, pred_fake.shape)))
    loss_vec = F.mean(F.squared_error(fake, x1))
    fake_dis = fake.unlinked()
    pred_fake_dis = discriminator(fake_dis)
    loss_dis = F.mean(
        F.sigmoid_cross_entropy(pred_fake_dis,
                                F.constant(0, pred_fake_dis.shape)))

    # Real path
    x = nn.Variable([args.batch_size, 1, 28, 28])
    pred_real = discriminator(x)
    loss_dis += F.mean(
        F.sigmoid_cross_entropy(pred_real, F.constant(1, pred_real.shape)))
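    # loss_gen pushes generated images toward the "real" label, loss_vec is
    # the autoencoder reconstruction term, and the two loss_dis parts train
    # the discriminator to separate generated from real images.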

    # Create Solver.
    solver_gen = S.Adam(args.learning_rate, beta1=0.5)
    solver_dis = S.Adam(args.learning_rate, beta1=0.5)
    solver_vec = S.Adam(args.learning_rate, beta1=0.5)
    with nn.parameter_scope("vec"):
        solver_vec.set_parameters(nn.get_parameters())
    with nn.parameter_scope("gen"):
        solver_vec.set_parameters(nn.get_parameters())
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss_gen = M.MonitorSeries("Generator loss", monitor, interval=10)
    monitor_loss_dis = M.MonitorSeries("Discriminator loss",
                                       monitor,
                                       interval=10)
    monitor_loss_vec = M.MonitorSeries("Vectorizer loss", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100)
    monitor_fake = M.MonitorImageTile("Fake images",
                                      monitor,
                                      normalize_method=lambda x: (x + 1) / 2.)
    monitor_vec1 = M.MonitorImageTile("vec images1",
                                      monitor,
                                      normalize_method=lambda x: (x + 1) / 2.)
    monitor_vec2 = M.MonitorImageTile("vec images2",
                                      monitor,
                                      normalize_method=lambda x: (x + 1) / 2.)

    #data = data_iterator_mnist(args.batch_size, True)
    data = iterator.simple_data_iterator(load_kanji_data(), args.batch_size,
                                         True)

    # Training loop.
    for i in range(args.max_iter):
        if i % args.model_save_interval == 0:
            with nn.parameter_scope("gen"):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 "generator_param_%06d.h5" % i))
            with nn.parameter_scope("dis"):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 "discriminator_param_%06d.h5" % i))

        # Training forward
        image, _ = data.next()

        x1.d = image / 255. * 2 - 1.0  # [0, 255] -> [-1, 1]
        # Vectorizer (autoencoder) update.
        solver_vec.zero_grad()
        loss_vec.forward(clear_no_need_grad=True)
        loss_vec.backward(clear_buffer=True)
        solver_vec.weight_decay(args.weight_decay)
        solver_vec.update()
        monitor_vec1.add(i, fake)
        monitor_vec2.add(i, x1)
        monitor_loss_vec.add(i, loss_vec.d.copy())

        x.d = image / 255. * 2 - 1.0  # [0, 255] -> [-1, 1]
        # Note: z here is the vectorizer output, so the forward pass below
        # recomputes it from x1 and overwrites this random draw (Example
        # no. 3 avoids this by training on an unlinked copy of z).
        z.d = np.random.randn(*z.shape)

        # Generator update.
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.weight_decay(args.weight_decay)
        solver_gen.update()
        monitor_fake.add(i, fake)
        monitor_loss_gen.add(i, loss_gen.d.copy())

        # Discriminator update.
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        solver_dis.weight_decay(args.weight_decay)
        solver_dis.update()
        monitor_loss_dis.add(i, loss_dis.d.copy())
        monitor_time.add(i)

    with nn.parameter_scope("gen"):
        nn.save_parameters(
            os.path.join(args.model_save_path, "generator_param_%06d.h5" % i))
    with nn.parameter_scope("dis"):
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         "discriminator_param_%06d.h5" % i))
Example no. 3
def train(args):
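    # This variant splits the latent code: z_vec stays linked to the
    # vectorizer for the reconstruction loss, while z is a detached
    # (unlinked) copy that can be fed random noise for the GAN updates
    # without being overwritten when the graph re-executes.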
    x1 = nn.Variable([args.batch_size, 1, 28, 28])
    z_vec = vectorizer(x1)
    z = z_vec.unlinked()
    fake2 = generator(z_vec)
    fake = generator(z)
    fake.persistent = True  # Keep the buffer so backward does not clear it.
    pred_fake = discriminator(fake)
    loss_gen = F.mean(
        F.sigmoid_cross_entropy(pred_fake, F.constant(1, pred_fake.shape)))
    loss_vec = F.mean(F.squared_error(fake2, x1))
    fake_dis = fake.unlinked()
    pred_fake_dis = discriminator(fake_dis)
    loss_dis = F.mean(
        F.sigmoid_cross_entropy(pred_fake_dis,
                                F.constant(0, pred_fake_dis.shape)))

    # Real path
    x = nn.Variable([args.batch_size, 1, 28, 28])
    pred_real = discriminator(x)
    loss_dis += F.mean(
        F.sigmoid_cross_entropy(pred_real, F.constant(1, pred_real.shape)))

    # Create Solver.
    solver_gen = S.Adam(args.learning_rate, beta1=0.5)
    solver_dis = S.Adam(args.learning_rate, beta1=0.5)
    solver_vec = S.Adam(args.learning_rate, beta1=0.5)
    with nn.parameter_scope("vec"):
        solver_vec.set_parameters(nn.get_parameters())
    with nn.parameter_scope("gen"):
        solver_vec.set_parameters(nn.get_parameters())
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss_gen = M.MonitorSeries("Generator loss", monitor, interval=10)
    monitor_loss_dis = M.MonitorSeries("Discriminator loss",
                                       monitor,
                                       interval=10)
    monitor_loss_vec = M.MonitorSeries("Vectorizer loss", monitor, interval=10)
    monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100)
    monitor_fake = M.MonitorImageTile("Fake images",
                                      monitor,
                                      normalize_method=lambda x: (x + 1) / 2.)
    monitor_vec1 = M.MonitorImageTile("vec images1",
                                      monitor,
                                      normalize_method=lambda x: (x + 1) / 2.)
    monitor_vec2 = M.MonitorImageTile("vec images2",
                                      monitor,
                                      normalize_method=lambda x: (x + 1) / 2.)

    #data = data_iterator_mnist(args.batch_size, True)
    data = iterator.simple_data_iterator(load_kanji_data(), args.batch_size,
                                         True)

    # Training loop.
    for i in range(args.max_iter):
        if i % args.model_save_interval == 0:
            with nn.parameter_scope("gen"):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 "generator_param_%06d.h5" % i))
            with nn.parameter_scope("dis"):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 "discriminator_param_%06d.h5" % i))
            with nn.parameter_scope("vec"):
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 "vectorizer_param_%06d.h5" % i))

        # Training forward
        image, _ = data.next()

        x1.d = image / 255. * 2 - 1.0
        # Vectorizer (autoencoder) update.
        solver_vec.zero_grad()
        loss_vec.forward(clear_no_need_grad=True)
        loss_vec.backward(clear_buffer=True)
        solver_vec.weight_decay(args.weight_decay)
        solver_vec.update()
        fake2.forward()
        monitor_vec1.add(i, fake2)
        monitor_vec2.add(i, x1)
        monitor_loss_vec.add(i, loss_vec.d.copy())

        x.d = image / 255. * 2 - 1.0  # [0, 255] to [-1, 1]
        z.d = np.random.randn(*z.shape)

        # Generator update.
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.weight_decay(args.weight_decay)
        solver_gen.update()
        monitor_fake.add(i, fake)
        monitor_loss_gen.add(i, loss_gen.d.copy())

        # Discriminator update.
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        solver_dis.weight_decay(args.weight_decay)
        solver_dis.update()
        monitor_loss_dis.add(i, loss_dis.d.copy())
        monitor_time.add(i)
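

# A minimal sketch of how these train() functions might be driven. The
# argument names are inferred from the attribute accesses above; the
# defaults are illustrative assumptions, not values from the source.
def get_args():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--context", type=str, default=None)
    parser.add_argument("--device-id", type=int, default=0)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--learning-rate", type=float, default=0.0002)
    parser.add_argument("--weight-decay", type=float, default=0.0001)
    parser.add_argument("--max-iter", type=int, default=20000)
    parser.add_argument("--model-save-interval", type=int, default=1000)
    parser.add_argument("--model-save-path", type=str, default="tmp.models")
    parser.add_argument("--monitor-path", type=str, default="tmp.monitor")
    return parser.parse_args()


if __name__ == '__main__':
    train(get_args())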