Example #1
def run_model(mode,
              learning_rate=2e-4,
              beta1=0.5,
              l1_lambda=100,
              max_epochs=200,
              summary_freq=200,
              display_freq=50,
              save_freq=400,
              checkpoint_dir="summary/conGAN.ckpt"):
    if mode == "train":
        xs_train, ys_train = get_input("train")
        xs_val, ys_val = get_input("val")
        print("load train data successfully")
        print("input x shape is {}".format(xs_train.shape))
        print("input y shape is {}".format(ys_train.shape))
    else:
        xs_test, ys_test = get_input("test")
        print("load test data successfully")
        print("input x shape is {}".format(xs_test.shape))
        print("input y shape is {}".format(ys_test.shape))

    # build model
    # -----------
    with tf.name_scope("input"):
        x = tf.placeholder(tf.float32, [None, 256, 256, 3], name="x-input")
        y_ = tf.placeholder(tf.float32, [None, 256, 256, 3], name="y-input")

    G_sample = models.generator(x)

    logits_fake = models.con_discriminator(x, G_sample)
    logits_real = models.con_discriminator(x, y_)

    # get loss
    D_loss, G_loss_gan = loss.gan_loss(logits_fake=logits_fake,
                                       logits_real=logits_real)
    l1_loss = loss.l1_loss(y_, G_sample)
    with tf.variable_scope("G_loss"):
        G_loss = G_loss_gan + l1_lambda * l1_loss
    tf.summary.scalar("D_loss", D_loss)
    tf.summary.scalar("G_loss", G_loss)
    merged = tf.summary.merge_all()

    # get weights list
    D_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                               "discriminator")
    G_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "generator")

    # get solver
    D_solver, G_solver = get_solver(learning_rate=learning_rate, beta1=beta1)

    # get training steps
    D_train_step = D_solver.minimize(D_loss, var_list=D_vars)
    G_train_step = G_solver.minimize(G_loss, var_list=G_vars)
    # -----------

    # get session
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.InteractiveSession(config=config)

    # get saver
    saver = tf.train.Saver()

    # training phase
    if mode == "train":
        train_writer = tf.summary.FileWriter("summary/", sess.graph)
        # init
        sess.run(tf.global_variables_initializer())

        # iterations
        for step in range(max_epochs * NUM_TRAIN_IMAGES):
            if step % NUM_TRAIN_IMAGES == 0:
                print("Epoch: {}".format(step / NUM_TRAIN_IMAGES))

            mask = np.random.choice(NUM_TRAIN_IMAGES, 1)
            _, D_loss_curr = sess.run([D_train_step, D_loss],
                                      feed_dict={
                                          x: xs_train[mask],
                                          y_: ys_train[mask]
                                      })
            # the generator is updated twice for every discriminator update
            _, G_loss_curr = sess.run([G_train_step, G_loss],
                                      feed_dict={
                                          x: xs_train[mask],
                                          y_: ys_train[mask]
                                      })
            _, G_loss_curr = sess.run([G_train_step, G_loss],
                                      feed_dict={
                                          x: xs_train[mask],
                                          y_: ys_train[mask]
                                      })

            if step % display_freq == 0:
                print("step {}: D_loss: {}, G_loss: {}".format(
                    step, D_loss_curr, G_loss_curr))

            # save summary and checkpoint
            if step % summary_freq == 0:
                mask = np.random.choice(NUM_TRAIN_IMAGES, 30)
                summary = sess.run(merged,
                                   feed_dict={
                                       x: xs_train[mask],
                                       y_: ys_train[mask]
                                   })
                train_writer.add_summary(summary, step)
                saver.save(sess, checkpoint_dir)

            # save 5 sample images
            if step % save_freq == 0:
                samples_train = sess.run(G_sample,
                                         feed_dict={
                                             x: xs_train[0:5],
                                             y_: ys_train[0:5]
                                         })
                save_sample_img(samples_train, step=step, mode="train")
                samples_val = sess.run(G_sample,
                                       feed_dict={
                                           x: xs_val[0:5],
                                           y_: ys_val[0:5]
                                       })
                save_sample_img(samples_val, step=step, mode="val")

    # testing phase
    if mode == "test":
        saver.restore(sess, checkpoint_dir)
        for i in range(20):
            samples_test = sess.run(G_sample,
                                    feed_dict={
                                        x: xs_test[5 * i:5 * (i + 1)],
                                        y_: ys_test[5 * i:5 * (i + 1)]
                                    })
            save_sample_img(samples_test, step=i, mode="test")

    # close sess
    sess.close()

    return 0
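The listing above relies on helpers (get_solver, get_input, save_sample_img) that are defined elsewhere. Below is a minimal sketch of get_solver and a possible command-line entry point, assuming the usual pix2pix choice of Adam with beta1=0.5; both are illustrative guesses, not the original implementations.

import argparse
import tensorflow as tf


def get_solver(learning_rate=2e-4, beta1=0.5):
    # Assumed implementation: both networks are optimized with Adam,
    # which matches the (learning_rate, beta1) arguments used above.
    D_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    G_solver = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=beta1)
    return D_solver, G_solver


if __name__ == "__main__":
    # Hypothetical entry point: pick the mode from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["train", "test"], default="train")
    args = parser.parse_args()
    run_model(args.mode)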
Example #2
def train(args):
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim, num_downsample=args.num_downsample, num_upsample=args.num_upsample, repeat_num=args.g_repeat_num)
    discriminator = functools.partial(model.discriminator, image_size=args.image_size,
                                      conv_dim=args.d_conv_dim, c_dim=args.c_dim, repeat_num=args.d_repeat_num)

    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images.
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())

    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func, dataset=test_dataset,
                                       image_dir=args.celeba_image_dir, image_size=args.image_size, crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False, with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()

        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator --------------
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            # x_fake shares its grad buffer with x_fake_unlinked, so grad=None
            # backpropagates the gradient accumulated there through the generator.
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

            if (i + 1) % args.sample_step == 0:
                # save image.
                save_results(i, args, x_real, x_fake,
                             label_org, label_trg, x_recon)
                if args.test_during_training:
                    # translate images from test dataset.
                    x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                    label_trg.d = test_label_ndarray[rand_idx]
                    x_fake.forward(clear_no_need_grad=True)
                    save_results(i, args, x_real, x_fake, label_org,
                                 label_trg, None, is_training=False)

        # Decay the learning rates linearly over the second half of training.
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
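Example #2 also references a loss module whose definitions are not shown. The sketch below is a plausible reconstruction assuming the standard StarGAN/WGAN-GP formulation (negated mean critic output for real samples, sigmoid cross-entropy attribute classification, L1 reconstruction); it is an assumption, not the original module.

import nnabla.functions as F


def gan_loss(dis_out):
    # WGAN-style critic output averaged over the batch; the caller negates it
    # for the real term and for the generator's adversarial term.
    return F.mean(dis_out)


def classification_loss(cls_out, label):
    # Multi-label attribute classification with sigmoid cross-entropy.
    return F.mean(F.sigmoid_cross_entropy(cls_out, label))


def recon_loss(x_real, x_recon):
    # L1 cycle-reconstruction loss between the input and its reconstruction.
    return F.mean(F.abs(x_real - x_recon))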
Example #3
            fake_target = netG_A(real_src, training=True)
            rec_source = netG_B(fake_target, training=True)

            fake_src = netG_B(real_target, training=True)
            rec_target = netG_A(fake_src, training=True)

            fake_tar_seg = netG_A_map(fake_target, training=True)
            rec_src_seg = netG_B_map(rec_source, training=True)

            disc_target_fake = netD_A(fake_target, training=True)
            disc_source_fake = netD_B(fake_src, training=True)
            disc_target_real = netD_A(real_target, training=True)
            disc_source_real = netD_B(real_src, training=True)

            loss_G_A = gan_loss(disc_target_fake,
                                tf.ones_like(disc_target_fake))
            loss_G_B = gan_loss(disc_source_fake,
                                tf.ones_like(disc_source_fake))

            loss_seg_A = categorical_loss(real_src_label, fake_tar_seg, 5)
            loss_seg_B = categorical_loss(real_src_label, rec_src_seg, 5)

            loss_cycle_A = cycle_loss(real_src, rec_source, lambda_para)
            loss_cycle_B = cycle_loss(real_target, rec_target, lambda_para)

            loss_G_A = loss_G_A + loss_cycle_A + loss_cycle_B + loss_seg_A
            loss_G_B = loss_G_B + loss_cycle_A + loss_cycle_B + loss_seg_B

            loss_D_A_real = gan_loss(disc_target_real,
                                     tf.ones_like(disc_target_real))
            loss_D_A_fake = gan_loss(disc_target_fake,