Exemplo n.º 1
0
def main():
    """Train a generator/discriminator pair on the face images under ./faces.

    Each loop iteration performs one discriminator update; the generator is
    updated on every fifth iteration. Every 100 iterations the current losses
    are printed and a 10x10 grid of generated samples is written to
    ``images/wgan-<iteration>.png``.

    Note: ``epochs`` counts optimisation steps (one batch each), not passes
    over the dataset.
    """
    tf.random.set_seed(22)
    np.random.seed(22)
    # NOTE(review): TensorFlow is already imported at this point, so this may
    # be too late to silence import-time logging — confirm.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    assert tf.__version__.startswith('2.')

    z_dim = 100
    epochs = 3000000
    batch_size = 512  # same as the batch_size of real images
    d_learning_rate = 0.005
    g_learning_rate = 0.002
    training = True
    output_dir = 'images'

    # Build the pattern with os.path.join: the previous hard-coded
    # backslashes only matched on Windows and yielded no files elsewhere.
    img_path = glob.glob(os.path.join('.', 'faces', '*.jpg'))
    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
    print(dataset, img_shape)
    sample_picture = next(iter(dataset))
    print(sample_picture.shape, tf.reduce_max(sample_picture).numpy(), tf.reduce_min(sample_picture).numpy())
    # Repeat forever so next(ds_iter) never runs out during the long loop.
    dataset = dataset.repeat()
    ds_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(None, z_dim))

    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    g_optimizer = tf.optimizers.RMSprop(learning_rate=g_learning_rate)
    d_optimizer = tf.optimizers.RMSprop(learning_rate=d_learning_rate)

    # Make sure the sample directory exists before the first image is saved.
    os.makedirs(output_dir, exist_ok=True)

    for epoch in range(epochs):
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_r = next(ds_iter)

        # discriminator training
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_func(generator, discriminator, batch_z, batch_r, training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # generator training: one update per five discriminator updates
        if epoch % 5 == 0:
            with tf.GradientTape() as tape:
                g_loss = g_loss_func(generator, discriminator, batch_z, training)
            grads = tape.gradient(g_loss, generator.trainable_variables)
            g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        # Multiples of 100 are also multiples of 5, so g_loss is always fresh
        # (and defined) when it is reported here.
        if epoch % 100 == 0:
            print('Current epoch:', epoch,
                  'd_loss:', float(d_loss), 'g_loss:', float(g_loss),
                  'gp:', float(gp))

            # Sample previews from the same U(-1, 1) latent distribution the
            # generator is trained on; the default U(0, 1) range would only
            # exercise part of the latent space seen during training.
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
            g_imgs = generator(z, training=False)
            save_path = os.path.join(output_dir, 'wgan-%d.png' % epoch)
            save_result(g_imgs.numpy(), 10, save_path, color_mode='P')
Exemplo n.º 2
0
def main(batch_size=64, num_epochs=50, save_preds=True, train_method='ganhacks'):
    """Train a GAN on MNIST and optionally save one batch of generator output.

    Args:
        batch_size: Batch size for both the MNIST loader and the noise maker.
        num_epochs: Number of times the per-epoch training routine is run.
        save_preds: When true, one batch of de-normalised generator output is
            saved to ``<train_method>_generator_output.npy`` after training.
        train_method: Key selecting the per-epoch routine: ``'wgan'`` or
            ``'ganhacks'``.

    Raises:
        KeyError: If ``train_method`` is not a known method name.
    """
    str2method = {
        'wgan': train_wgan_epoch,
        'ganhacks': train_ganhacks_epoch
    }
    # Resolve the routine once, up front: an unknown name still raises
    # KeyError, but now before the dataset download and model construction.
    train_epoch = str2method[train_method]

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"Using {device}")

    data_dir = 'mnist_data'
    mean, std = get_mnist_vals(data_location=data_dir)
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(data_dir, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize([mean], [std])
                       ])),
        batch_size=batch_size, shuffle=True)
    noisemaker = NoiseMaker(batch_size=batch_size, device=device)

    # .to(device) places the models on exactly the device chosen above and is
    # a no-op on CPU, so no separate CPU/GPU branch is needed.
    discriminator = Discriminator().to(device)
    generator = Generator().to(device)

    for i in range(num_epochs):
        print(f"Epoch number {i + 1}")
        # ncritic is only meaningful if train_method == 'wgan'; it is raised
        # to 100 on the 1st and 11th epochs and is 5 otherwise.
        train_epoch(discriminator, generator, dataloader, noisemaker,
                    ncritic=100 if i in (0, 10) else 5,
                    device=device)

    # save some predictions
    if save_preds:
        with torch.no_grad():
            generator.eval()
            noise = noisemaker()
            output = generator(noise).cpu().numpy()
            # denormalize (inverse of the Normalize transform applied above)
            output = (output * std) + mean
            np.save(f"{train_method}_generator_output.npy", output)
def _train_initialize_variables(model_str, model_params, opt_params, cuda):
    """Helper function that just initializes everything at the beginning of the train function"""
    # Params passed in as dict to model.
    if model_str == 'WGAN':
        D = WDisc(model_params)
        G = WGen(model_params)
        D.train()
        G.train()
    elif model_str == 'WDCGAN':
        D = WDCDisc(model_params)
        G = WDCGen(model_params)
        D.train()
        G.train()
    else:
        raise ValueError('Name unknown: %s' % model_str)

    d_optimizer = init_optimizer(opt_params, D)
    g_optimizer = init_optimizer(opt_params, G)

    if cuda:
        D = D.cuda()
        G = G.cuda()
    return G, D, g_optimizer, d_optimizer
Exemplo n.º 4
0
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])
# CIFAR-10 train and test splits, both run through the same `my_trans`
# preprocessing pipeline and downloaded on demand.
train_dataset = datasets.CIFAR10(
    train_data_path, train=True, download=True, transform=my_trans)
test_dataset = datasets.CIFAR10(
    test_data_path, train=False, download=True, transform=my_trans)

# Only the training split gets a loader here; the test split is loaded above
# but no DataLoader is built for it.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batchs,
    shuffle=True,
)

# Networks: build both, then move them to the GPU when `cuda` is set.
generator = Generator(latent, img_shape)
discriminator = Discriminator(img_shape)
if cuda:
    generator.cuda()
    discriminator.cuda()

# One RMSprop optimizer per network, sharing a single learning rate.
optimizer_G = optim.RMSprop(generator.parameters(), lr=learning_rate)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=learning_rate)

# Training bookkeeping: a two-column table with one row per epoch, plus
# running counters that all start at zero.
score_tr = np.zeros((epochs, 2))
batch_done = count = 0
total_errG = total_errD = 0
Exemplo n.º 5
0
                                                            (0.5, 0.5, 0.5)),
                                   ]))
        # Create the data loader
        dataloader = torch.utils.data.DataLoader(dataset,
                                                 batch_size=cfg.batch_size,
                                                 shuffle=True,
                                                 num_workers=cfg.workers)
        # Save to cache
        torch.save(dataloader, dataloader_path)

    # Decide which device we want to run on
    device = torch.device("cuda:0" if (
        torch.cuda.is_available() and cfg.ngpu > 0) else "cpu")

    # Create the generator
    netG = Generator(cfg.ngpu).to(device)

    # Handle multi-gpu if desired
    if (device.type == 'cuda') and (cfg.ngpu > 1):
        netG = nn.DataParallel(netG, list(range(cfg.ngpu)))

    # Initialize all weights to mean=0, std=0.2.
    netG.apply(weights_init)

    # Print the model
    print(netG)

    # Create the Discriminator
    netD = Discriminator(cfg.ngpu).to(device)

    # Handle multi-gpu if desired
Exemplo n.º 6
0
else: dtype = torch.FloatTensor

# Models
file_base = "../../outputs/wgan_content_loss/"
load = True  # set to False to build fresh networks instead of loading saved ones

device_1 = torch.device("cuda:0")
device_2 = torch.device("cuda:1")
device_ids = [0]

# Either restore both networks from `file_base` or construct new ones.
generator = (torch.load(file_base + "generator_2loss.pt")
             if load else Generator(img_shape))
discriminator = (torch.load(file_base + "discriminator_2loss.pt")
                 if load else Discriminator(img_shape))

# Wrap each network for data-parallel execution, then cast it to `dtype`.
generator = nn.DataParallel(generator, device_ids=device_ids)
generator = generator.type(dtype)
discriminator = nn.DataParallel(discriminator, device_ids=device_ids)
discriminator = discriminator.type(dtype)

feature_extractor = FeatureExtractor().type(dtype)

# Optimizers: Adam for both networks with identical hyper-parameters.
adam_betas = (b1, b2)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)

'''
Load Data 
'''
    
Exemplo n.º 7
0
def main():
    """Train a generator/discriminator pair on the face images under ./faces.

    Each outer iteration runs five discriminator updates followed by one
    generator update. Every 100 iterations the losses are printed, a 10x10
    grid of generated samples is written to ``images/wgan-<iteration>.png``,
    and the generator is saved to ``./model.tf``.

    Note: ``epochs`` counts outer training iterations, not passes over the
    dataset.
    """
    tf.random.set_seed(233)
    np.random.seed(233)
    assert tf.__version__.startswith('2.')

    # hyper parameters
    z_dim = 100
    epochs = 3000000
    batch_size = 512
    learning_rate = 0.0005
    is_training = True
    n_critic = 5  # discriminator updates per generator update

    model_path = './model.tf'
    output_dir = 'images'

    img_paths = glob.glob(r'./faces/*.jpg')
    assert len(img_paths) > 0, 'no training images found under ./faces'

    dataset, img_shape, _ = make_anime_dataset(img_paths, batch_size)
    print(dataset, img_shape)
    sample = next(iter(dataset))
    print(sample.shape,
          tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    # Repeat forever so next(db_iter) never runs out during the long loop.
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    # Flip USE_LOADED to resume from a previously saved generator.
    USE_LOADED = False
    if USE_LOADED and os.path.exists(model_path):
        print('加载模型,继续上次训练')
        generator = tf.keras.models.load_model(model_path, compile=True)
    else:
        print('未找到保存的模型,重新开始训练')
        generator = Generator()
        generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    g_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                           beta_1=0.5)
    d_optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate,
                                           beta_1=0.5)

    # Make sure the sample directory exists before the first image is saved.
    os.makedirs(output_dir, exist_ok=True)

    for epoch in range(epochs):

        for _ in range(n_critic):
            batch_z = tf.random.normal([batch_size, z_dim])
            batch_x = next(db_iter)

            # train D
            with tf.GradientTape() as tape:
                d_loss, gp = d_loss_fn(generator, discriminator, batch_z,
                                       batch_x, is_training)
            grads = tape.gradient(d_loss, discriminator.trainable_variables)
            d_optimizer.apply_gradients(
                zip(grads, discriminator.trainable_variables))

        # train G on a fresh latent batch
        batch_z = tf.random.normal([batch_size, z_dim])

        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss),
                  'gp:', float(gp))

            z = tf.random.normal([100, z_dim])
            fake_image = generator(z, training=False)
            # Separate name for the output file so the list of input image
            # paths above is not overwritten.
            sample_path = os.path.join(output_dir, 'wgan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, sample_path, color_mode='P')
            # Saving a subclassed network in h5 format is not supported, see
            # https://github.com/tensorflow/tensorflow/issues/29545
            # Saving directly without calling the model once first raises an
            # error, see https://github.com/tensorflow/tensorflow/issues/31057
            generator.predict(z)
            generator.save(model_path, overwrite=True, save_format="tf")