def train(data_folderpath='data/edges2shoes', image_size=256, ndf=64, ngf=64,
          lr_d=2e-4, lr_g=2e-4, n_iterations=int(1e6),
          batch_size=64, iters_per_checkpoint=100, n_checkpoint_samples=16,
          reconstruction_weight=100, out_dir='gan'):

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)

    data_iterator = iterate_minibatches(
        data_folderpath + "/train/*.jpg", batch_size, image_size)
    val_data_iterator = iterate_minibatches(
        data_folderpath + "/val/*.jpg", n_checkpoint_samples, image_size)
    img_ab_fixed, _ = next(val_data_iterator)
    img_a_fixed, img_b_fixed = img_ab_fixed[:, 0], img_ab_fixed[:, 1]

    img_a_shape = img_a_fixed.shape[1:]
    img_b_shape = img_b_fixed.shape[1:]
    patch = int(img_a_shape[0] / 2**4)  # D patch size after 2**4 downsampling (n_layers=4)
    disc_patch = (patch, patch, 1)
    print("img a shape ", img_a_shape)
    print("img b shape ", img_b_shape)
    print("disc_patch ", disc_patch)

    # plot real images for reference
    log_images(img_a_fixed, 'real_a', '0', logger)
    log_images(img_b_fixed, 'real_b', '0', logger)

    # build models
    D = build_discriminator(
        img_a_shape, img_b_shape, ndf, activation='sigmoid')
    G = build_generator(img_a_shape, ngf)

    # build model outputs
    img_a_input = Input(shape=img_a_shape)
    img_b_input = Input(shape=img_b_shape)

    fake_samples = G(img_a_input)
    D_real = D([img_a_input, img_b_input])
    D_fake = D([img_a_input, fake_samples])

    loss_reconstruction = partial(mean_absolute_error,
                                  real_samples=img_b_input,
                                  fake_samples=fake_samples)
    loss_reconstruction.__name__ = 'loss_reconstruction'

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[D_real, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                    loss='binary_crossentropy')

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[D_fake, fake_samples])
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.999),
                    loss=['binary_crossentropy', loss_reconstruction],
                    loss_weights=[1, reconstruction_weight])

    ones = np.ones((batch_size, ) + disc_patch, dtype=np.float32)
    zeros = np.zeros((batch_size, ) + disc_patch, dtype=np.float32)
    dummy = zeros

    for i in range(n_iterations):
        D.trainable = True
        G.trainable = False

        image_ab_batch, _ = next(data_iterator)
        loss_d = D_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones, zeros])

        D.trainable = False
        G.trainable = True
        image_ab_batch, _ = next(data_iterator)
        loss_g = G_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones, dummy])

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict(img_a_fixed)
            log_images(fake_image, 'val_fake', i, logger)
            save_model(G, out_dir)

        log_losses(loss_d, loss_g, i, logger)
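
# A minimal sketch of the mean_absolute_error helper that loss_reconstruction
# wraps above. Keras calls a loss as fn(y_true, y_pred); the partial binds the
# ground-truth and generated image tensors directly, so the target passed to
# train_on_batch for that output (the "dummy" array) is only a placeholder.
# The name matches the call above, but this body is an assumption, not the
# original implementation.
import keras.backend as K

def mean_absolute_error(y_true, y_pred, real_samples=None, fake_samples=None):
    # L1 reconstruction term between ground-truth and generated images
    return K.mean(K.abs(real_samples - fake_samples))
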
def train(n_channels=3,
          resolution=32,
          z_dim=128,
          n_labels=0,
          lr=1e-3,
          e_drift=1e-3,
          wgp_target=750,
          initial_resolution=4,
          total_kimg=25000,
          training_kimg=500,
          transition_kimg=500,
          iters_per_checkpoint=500,
          n_checkpoint_images=16,
          glob_str='cifar10',
          out_dir='cifar10'):

    # instantiate logger
    logger = SummaryWriter(out_dir)

    # load data
    batch_size = MINIBATCH_OVERWRITES[0]
    train_iterator = iterate_minibatches(glob_str, batch_size, resolution)

    # build models
    G = Generator(n_channels, resolution, z_dim, n_labels)
    D = Discriminator(n_channels, resolution, n_labels)

    G_train, D_train = GAN(G, D, z_dim, n_labels, resolution, n_channels)

    D_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)
    G_opt = Adam(lr=lr, beta_1=0.0, beta_2=0.99, epsilon=1e-8)

    # define loss functions
    D_loss = [loss_mean, loss_gradient_penalty, 'mse']
    G_loss = [loss_wasserstein]

    # compile graphs used during training
    G.compile(G_opt, loss=loss_wasserstein)
    D.trainable = False
    G_train.compile(G_opt, loss=G_loss)
    D.trainable = True
    D_train.compile(D_opt, loss=D_loss, loss_weights=[1, GP_WEIGHT, e_drift])

    # for computing the loss
    ones = np.ones((batch_size, 1), dtype=np.float32)
    zeros = ones * 0.0

    # fix a z vector for training evaluation
    z_fixed = np.random.normal(0, 1, size=(n_checkpoint_images, z_dim))

    # vars
    resolution_log2 = int(np.log2(resolution))
    starting_block = resolution_log2
    starting_block -= np.floor(np.log2(initial_resolution))
    cur_block = starting_block
    cur_nimg = 0

    # compute duration of each phase and use proxy to update minibatch size
    phase_kdur = training_kimg + transition_kimg
    phase_idx_prev = 0

    # offset variable for transitioning between blocks
    offset = 0
    i = 0
    while cur_nimg < total_kimg * 1000:
        # block processing
        kimg = cur_nimg / 1000.0
        phase_idx = int(np.floor((kimg + transition_kimg) / phase_kdur))
        phase_idx = max(phase_idx, 0)
        phase_kimg = phase_idx * phase_kdur

        # update batch size and ones vector if we switched phases
        if phase_idx_prev < phase_idx:
            batch_size = MINIBATCH_OVERWRITES[phase_idx]
            train_iterator = iterate_minibatches(glob_str, batch_size, resolution)
            ones = np.ones((batch_size, 1), dtype=np.float32)
            zeros = ones * 0.0
            phase_idx_prev = phase_idx

        # possibly gradually update current level of detail
        if transition_kimg > 0 and phase_idx > 0:
            offset = (kimg + transition_kimg - phase_kimg) / transition_kimg
            offset = min(offset, 1.0)
            offset = offset + phase_idx - 1
            cur_block = max(starting_block - offset, 0.0)

        # update level of detail
        K.set_value(G_train.cur_block, np.float32(cur_block))
        K.set_value(D_train.cur_block, np.float32(cur_block))

        # train D
        for j in range(N_CRITIC_ITERS):
            z = np.random.normal(0, 1, size=(batch_size, z_dim))
            real_batch = next(train_iterator)
            fake_batch = G.predict_on_batch([z])
            interpolated_batch = get_interpolated_images(
                real_batch, fake_batch)
            losses_d = D_train.train_on_batch(
                [real_batch, fake_batch, interpolated_batch],
                [ones, ones * wgp_target, zeros])
            cur_nimg += batch_size

        # train G
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        loss_g = G_train.train_on_batch(z, -1 * ones)

        logger.add_scalar("cur_block", cur_block, i)
        logger.add_scalar("learning_rate", lr, i)
        logger.add_scalar("batch_size", z.shape[0], i)
        print("iter", i, "cur_block", cur_block, "lr", lr, "kimg", kimg,
              "losses_d", losses_d, "loss_g", loss_g)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_images = G.predict(z_fixed)
            # log fake images
            log_images(fake_images, 'fake', i, logger, fake_images.shape[1],
                       fake_images.shape[2], int(np.sqrt(n_checkpoint_images)))

            # plot real images for reference
            log_images(real_batch[:n_checkpoint_images], 'real', i, logger,
                       real_batch.shape[1], real_batch.shape[2],
                       int(np.sqrt(n_checkpoint_images)))

            # save the model to eventually resume training or do inference
            save_model(G, out_dir + "/model.json", out_dir + "/model.h5")

        log_losses(losses_d, loss_g, i, logger)
        i += 1
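
# A minimal sketch of the get_interpolated_images helper used in the critic
# loop above, assuming NHWC image batches. WGAN-GP evaluates its gradient
# penalty on random convex combinations of real and generated samples, with
# one interpolation weight per image. The function name matches the call
# above; the body is an assumption, not the original implementation.
import numpy as np

def get_interpolated_images(real_batch, fake_batch):
    # one mixing coefficient per sample, broadcast over height, width, channels
    alpha = np.random.uniform(0, 1, size=(real_batch.shape[0], 1, 1, 1))
    return (alpha * real_batch + (1.0 - alpha) * fake_batch).astype(np.float32)
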
# Example 3
def train(data_filepath='data/flowers.hdf5',
          ndf=64,
          ngf=128,
          z_dim=128,
          emb_dim=128,
          lr_d=2e-4,
          lr_g=2e-4,
          n_iterations=int(1e6),
          batch_size=64,
          iters_per_checkpoint=500,
          n_checkpoint_samples=16,
          out_dir='gan'):

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)
    train_data = get_data(data_filepath, 'train')
    val_data = get_data(data_filepath, 'valid')
    data_iterator = iterate_minibatches(train_data, batch_size)
    val_data_iterator = iterate_minibatches(val_data, n_checkpoint_samples)
    val_data = next(val_data_iterator)
    img_fixed = images_from_bytes(val_data[0])
    emb_fixed = val_data[1]
    txt_fixed = val_data[2]

    img_shape = img_fixed[0].shape
    emb_shape = emb_fixed[0].shape
    print("emb shape {}".format(img_shape))
    print("img shape {}".format(emb_shape))
    z_shape = (z_dim, )

    # plot real images and text for reference
    log_images(img_fixed, 'real', '0', logger)
    log_text(txt_fixed, 'real', '0', logger)

    # build models
    D = build_discriminator(img_shape,
                            emb_shape,
                            emb_dim,
                            ndf,
                            activation='sigmoid')
    G = build_generator(z_shape, emb_shape, emb_dim, ngf)

    # build model outputs
    real_inputs = Input(shape=img_shape)
    txt_inputs = Input(shape=emb_shape)
    txt_shuf_inputs = Input(shape=emb_shape)
    z_inputs = Input(shape=(z_dim, ))

    fake_samples = G([z_inputs, txt_inputs])
    D_real = D([real_inputs, txt_inputs])
    D_wrong = D([real_inputs, txt_shuf_inputs])
    D_fake = D([fake_samples, txt_inputs])

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(
        inputs=[real_inputs, txt_inputs, txt_shuf_inputs, z_inputs],
        outputs=[D_real, D_wrong, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.9),
                    loss='binary_crossentropy',
                    loss_weights=[1, 0.5, 0.5])

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[z_inputs, txt_inputs], outputs=D_fake)
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.9),
                    loss='binary_crossentropy')

    ones = np.ones((batch_size, 1, 1, 1), dtype=np.float32)
    zeros = np.zeros((batch_size, 1, 1, 1), dtype=np.float32)

    # fix a z vector for training evaluation
    z_fixed = np.random.uniform(-1, 1, size=(n_checkpoint_samples, z_dim))

    for i in range(n_iterations):
        start = clock()
        D.trainable = True
        G.trainable = False
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        images_batch = images_from_bytes(real_batch[0])
        emb_text_batch = real_batch[1]
        ids = np.arange(len(emb_text_batch))
        np.random.shuffle(ids)
        emb_text_batch_shuffle = emb_text_batch[ids]
        loss_d = D_model.train_on_batch(
            [images_batch, emb_text_batch, emb_text_batch_shuffle, z],
            [ones, zeros, zeros])

        D.trainable = False
        G.trainable = True
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_g = G_model.train_on_batch([z, real_batch[1]], ones)

        print("iter", i, "time", clock() - start)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict([z_fixed, emb_fixed])
            log_images(fake_image, 'val_fake', i, logger)
            save_model(G, out_dir)

        log_losses(loss_d, loss_g, i, logger)
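
# A minimal sketch of the images_from_bytes helper assumed above: the flowers
# HDF5 file is taken to store encoded image byte strings, which are decoded
# and rescaled to [-1, 1], the usual range for a tanh generator output. The
# actual storage format and scaling in the original repository may differ.
import io

import numpy as np
from PIL import Image

def images_from_bytes(byte_strings):
    # decode each byte string into an RGB array and stack into one batch
    images = [np.asarray(Image.open(io.BytesIO(b)).convert('RGB'))
              for b in byte_strings]
    return np.stack(images).astype(np.float32) / 127.5 - 1.0
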
# Example 4
# log experiment settings:
for key in sorted(config_init):
    conf = key + " -> " + str(config_init[key])
    log(conf, config_init['out_path'])

# load datasets:
trainloader_source, testloader_source, trainloader_target, testloader_target = load_datasets(config_init)

# load network:
net = load_net(config_init)

# load optimizer:
my_optimizer = torch_optimizer(config_init, net, is_encoder=False)

# to keep initial loss weights (lambda):
config = deepcopy(config_init)

for epoch in range(config["num_epochs"]):
    st_time = time.time()

    lr, optimizer = my_optimizer.update_lr(epoch)
    config = lambda_schedule(config, config_init, epoch)

    config = training_loop(net, optimizer, trainloader_source, trainloader_target, config, epoch)

    config = testing_loop(net, testloader_source, config, is_source=True)
    config = testing_loop(net, testloader_target, config, is_source=False)

    print(out_name)
    log_losses(config, epoch, lr, st_time)
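
# A hypothetical sketch of the lambda_schedule helper called in the loop above:
# it ramps a domain-adaptation loss weight from 0 towards its configured value
# as training progresses, in the spirit of DANN-style schedules. The config
# keys used here ("lambda", "num_epochs") are assumptions about config_init,
# not confirmed by the source.
import numpy as np

def lambda_schedule(config, config_init, epoch):
    progress = float(epoch) / max(config_init["num_epochs"] - 1, 1)
    ramp = 2.0 / (1.0 + np.exp(-10.0 * progress)) - 1.0  # smooth ramp from 0 to ~1
    config["lambda"] = config_init["lambda"] * ramp
    return config
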
# Example 5
def train(data_folderpath='data/edges2shoes',
          image_size=256,
          ndf=64,
          ngf=64,
          lr_d=2e-4,
          lr_g=2e-4,
          n_iterations=int(1e6),
          batch_size=64,
          iters_per_checkpoint=100,
          n_checkpoint_samples=16,
          feature_matching_weight=10,
          out_dir='lsgan'):

    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)

    data_iterator = iterate_minibatches(data_folderpath + "/train/*.jpg",
                                        batch_size, image_size)
    val_data_iterator = iterate_minibatches(data_folderpath + "/val/*.jpg",
                                            n_checkpoint_samples, image_size)
    img_ab_fixed, _ = next(val_data_iterator)
    img_a_fixed, img_b_fixed = img_ab_fixed[:, 0], img_ab_fixed[:, 1]

    img_a_shape = img_a_fixed.shape[1:]
    img_b_shape = img_b_fixed.shape[1:]
    patch = int(img_a_shape[0] / 2**3)  # D patch size after 2**3 downsampling (n_layers=3)
    disc_patch_0 = (patch, patch, 1)
    disc_patch_1 = (int(patch / 2), int(patch / 2), 1)
    disc_patch_2 = (int(patch / 4), int(patch / 4), 1)
    print("img a shape ", img_a_shape)
    print("img b shape ", img_b_shape)
    print("disc_patch ", disc_patch_0)
    print("disc_patch ", disc_patch_1)
    print("disc_patch ", disc_patch_2)

    # plot real images for reference
    log_images(img_a_fixed, 'real_a', '0', logger)
    log_images(img_b_fixed, 'real_b', '0', logger)

    # build models
    D0 = build_discriminator(img_a_shape,
                             img_b_shape,
                             ndf,
                             activation='linear',
                             n_downsampling=0,
                             name='Discriminator0')
    D1 = build_discriminator(img_a_shape,
                             img_b_shape,
                             ndf,
                             activation='linear',
                             n_downsampling=1,
                             name='Discriminator1')
    D2 = build_discriminator(img_a_shape,
                             img_b_shape,
                             ndf,
                             activation='linear',
                             n_downsampling=2,
                             name='Discriminator2')
    G = build_global_generator(img_a_shape, ngf)

    # build model outputs
    img_a_input = Input(shape=img_a_shape)
    img_b_input = Input(shape=img_b_shape)

    fake_samples = G(img_a_input)[0]
    D0_real = D0([img_a_input, img_b_input])[0]
    D0_fake = D0([img_a_input, fake_samples])[0]

    D1_real = D1([img_a_input, img_b_input])[0]
    D1_fake = D1([img_a_input, fake_samples])[0]

    D2_real = D2([img_a_input, img_b_input])[0]
    D2_fake = D2([img_a_input, fake_samples])[0]

    # define D graph and optimizer
    G.trainable = False
    D0.trainable = True
    D1.trainable = True
    D2.trainable = True
    D0_model = Model([img_a_input, img_b_input], [D0_real, D0_fake],
                     name='Discriminator0_model')
    D1_model = Model([img_a_input, img_b_input], [D1_real, D1_fake],
                     name='Discriminator1_model')
    D2_model = Model([img_a_input, img_b_input], [D2_real, D2_fake],
                     name='Discriminator2_model')
    D0_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                     loss=['mse', 'mse'])
    D1_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                     loss=['mse', 'mse'])
    D2_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                     loss=['mse', 'mse'])

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D0.trainable = False
    D1.trainable = False
    D2.trainable = False

    loss_fm0 = partial(loss_feature_matching,
                       image_input=img_a_input,
                       real_samples=img_b_input,
                       D=D0,
                       feature_matching_weight=feature_matching_weight)
    loss_fm1 = partial(loss_feature_matching,
                       image_input=img_a_input,
                       real_samples=img_b_input,
                       D=D1,
                       feature_matching_weight=feature_matching_weight)
    loss_fm2 = partial(loss_feature_matching,
                       image_input=img_a_input,
                       real_samples=img_b_input,
                       D=D2,
                       feature_matching_weight=feature_matching_weight)

    G_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[
                        D0_fake, D1_fake, D2_fake, fake_samples, fake_samples,
                        fake_samples
                    ])
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.999),
                    loss=['mse', 'mse', 'mse', loss_fm0, loss_fm1, loss_fm2])

    ones_0 = np.ones((batch_size, ) + disc_patch_0, dtype=np.float32)
    ones_1 = np.ones((batch_size, ) + disc_patch_1, dtype=np.float32)
    ones_2 = np.ones((batch_size, ) + disc_patch_2, dtype=np.float32)
    zeros_0 = np.zeros((batch_size, ) + disc_patch_0, dtype=np.float32)
    zeros_1 = np.zeros((batch_size, ) + disc_patch_1, dtype=np.float32)
    zeros_2 = np.zeros((batch_size, ) + disc_patch_2, dtype=np.float32)
    dummy = np.ones((batch_size, ), dtype=np.float32)

    for i in range(n_iterations):
        D0.trainable = True
        D1.trainable = True
        D2.trainable = True
        G.trainable = False

        image_ab_batch, _ = next(data_iterator)
        fake_image = G.predict(image_ab_batch[:, 0])[0]
        loss_d0 = D0_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]], [ones_0, zeros_0])
        loss_d1 = D1_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]], [ones_1, zeros_1])
        loss_d2 = D2_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]], [ones_2, zeros_2])

        D0.trainable = False
        D1.trainable = False
        D2.trainable = False
        G.trainable = True
        image_ab_batch, _ = next(data_iterator)
        loss_g = G_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones_0, ones_1, ones_2, dummy, dummy, dummy])

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict(img_a_fixed)
            log_images(fake_image, 'val_fake', i, logger)
            save_model(G, out_dir)

        # log the three discriminator losses together with the generator loss
        log_losses(loss_d0 + loss_d1 + loss_d2, loss_g, i, logger)
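
# A minimal sketch of the loss_feature_matching helper bound with partial
# above. It assumes each discriminator returns [patch_logits, feat_1, ...,
# feat_n], consistent with the [0] indexing on the discriminator outputs
# earlier, and penalizes the mean absolute difference between intermediate
# features of (input, real) and (input, generated) pairs. Signature and
# output layout are assumptions, not the original implementation.
import keras.backend as K

def loss_feature_matching(y_true, y_pred, image_input=None, real_samples=None,
                          D=None, feature_matching_weight=10):
    real_features = D([image_input, real_samples])[1:]
    fake_features = D([image_input, y_pred])[1:]
    loss = 0.0
    for real_f, fake_f in zip(real_features, fake_features):
        loss += K.mean(K.abs(real_f - fake_f))
    return feature_matching_weight * loss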