Example 1
def test_train_discriminator():
    """
    Make sure that the discriminator can achieve low loss, when
    not training the generator.
    """
    path = r'../fauxtograph/images/'
    paths = glob.glob(os.path.join(path, "*.jpg"))
    # Load images
    real_images = np.array([train.load_image(p) for p in paths])
    np.random.shuffle(real_images)
    total_samples, c_dim, x_dim, y_dim = real_images.shape

    train_real_images = np.array([im for im in real_images[:int(total_samples / 2)]])
    test_real_images = np.array([im for im in real_images[int(total_samples / 2):]])

    fake_images = np.array([np.random.uniform(-1, 1, (3, 64, 64)) for n in range(len(real_images))])

    train_fake_images = np.array([im for im in fake_images[:int(total_samples / 2)]])
    test_fake_images = np.array([im for im in fake_images[int(total_samples / 2):]])

    assert len(train_fake_images) == len(train_real_images)
    assert len(test_fake_images) == len(test_real_images)

    X_train = np.concatenate((train_real_images, train_fake_images))
    y_train = [1] * len(train_real_images) + [0] * len(train_fake_images) # labels

    X_test = np.concatenate((test_real_images, test_fake_images))
    y_test = [1] * len(test_real_images) + [0] * len(test_fake_images) # labels

    discriminator = model.discriminator_model()
    adam = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
    discriminator.compile(loss='binary_crossentropy', optimizer=adam, metrics=['accuracy'])

    discriminator.fit(X_train, y_train, batch_size=128, nb_epoch=2, verbose=1, validation_data=(X_test, y_test))
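
The test above assumes a `train.load_image` helper that returns channels-first 64×64 arrays in the same [-1, 1] range as the fake images. That helper is not shown here; a minimal sketch of what it would need to do (function name and preprocessing are assumptions, not the repo's code):

import numpy as np
from PIL import Image

def load_image(path):
    # Hypothetical loader: resize to 64x64, scale to [-1, 1], return CHW, i.e. (3, 64, 64).
    img = Image.open(path).convert('RGB').resize((64, 64))
    arr = np.asarray(img, dtype=np.float32) / 127.5 - 1.0
    return arr.transpose(2, 0, 1)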
Example 2
    def __init__(self, args):

        self.img_size = args.imgsize
        self.channels = args.channels
        self.z_dim = args.zdims
        self.epochs = args.epoch
        self.batch_size = args.batchsize

        self.d_opt = Adam(lr=1e-5, beta_1=0.1)
        self.g_opt = Adam(lr=2e-4, beta_1=0.5)

        if not os.path.exists('./result/'):
            os.makedirs('./result/')
        if not os.path.exists('./model_images/'):
            os.makedirs('./model_images/')

        """ build discriminator model """
        self.d = model.discriminator_model(self.img_size, self.channels)
        plot_model(self.d, to_file='./model_images/discriminator.png', show_shapes=True)

        """ build generator model """
        self.g = model.generator_model(self.z_dim, self.img_size, self.channels)
        plot_model(self.g, to_file='./model_images/generator.png', show_shapes=True)

        """ discriminator on generator model """
        self.d_on_g = model.generator_containg_discriminator(self.g, self.d, self.z_dim)
        plot_model(self.d_on_g, to_file='./model_images/d_on_g.png', show_shapes=True)

        self.g.compile(loss='mse', optimizer=self.g_opt)
        self.d_on_g.compile(loss='mse', optimizer=self.g_opt)
        self.d.trainable = True
        self.d.compile(loss='mse', optimizer=self.d_opt)
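
The `model.generator_containg_discriminator` helper used above is not shown. A typical Keras composition stacks the generator and a frozen discriminator so that compiling `d_on_g` only updates G; a minimal sketch (name spelling and API details are assumptions, not this repo's code):

from keras.layers import Input
from keras.models import Model

def generator_containing_discriminator(g, d, z_dim):
    # Freeze D inside the stacked model so adversarial updates only reach G.
    z = Input(shape=(z_dim,))
    fake = g(z)
    d.trainable = False
    validity = d(fake)
    return Model(inputs=z, outputs=validity)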
Example 3
def test_discriminator_model():
    epochs = 1
    input_data = np.random.rand(1, 3, 64, 64)
    input_shape = input_data.shape

    discriminator = model.discriminator_model()
    adam = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
    discriminator.compile(loss='binary_crossentropy', optimizer=adam)
    pred = discriminator.predict(input_data)
    print(pred)
Example 4
def test_discriminator_model():
    epochs = 1
    input_data = np.random.rand(1, 3, 64, 64)
    input_shape = input_data.shape

    discriminator = model.discriminator_model()
    adam = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
    discriminator.compile(loss='binary_crossentropy', optimizer=adam)
    pred = discriminator.predict(input_data)
    print(pred)
Example 5
  def __init__(self, hparams):
    super(GAN, self).__init__()
    self.hparams = hparams

    self.netG = model.colorization_model()
    self.netD = model.discriminator_model()
    self.VGG_MODEL = torchvision.models.vgg16(pretrained=True)

    self.generated_imgs = None
    self.last_imgs = None
Example 6
def train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates=5):
    data = load_images('./images/train', n_images)
    y_train, x_train = data['B'], data['A']

    g = generator_model()
    d = discriminator_model()
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    d.trainable = True
    d.compile(optimizer=d_opt, loss=wasserstein_loss)
    d.trainable = False
    loss = [perceptual_loss, wasserstein_loss]
    loss_weights = [100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    output_true_batch, output_false_batch = np.ones((batch_size, 1)), -np.ones((batch_size, 1))

    for epoch in range(epoch_num):
        print('epoch: {}/{}'.format(epoch, epoch_num))
        print('batches: {}'.format(x_train.shape[0] / batch_size))

        permutated_indexes = np.random.permutation(x_train.shape[0])

        d_losses = []
        d_on_g_losses = []
        for index in range(int(x_train.shape[0] / batch_size)):
            batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
            image_blur_batch = x_train[batch_indexes]
            image_full_batch = y_train[batch_indexes]

            generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

            for _ in range(critic_updates):
                d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
                d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                d_losses.append(d_loss)
            print('batch {} d_loss : {}'.format(index+1, np.mean(d_losses)))

            d.trainable = False

            d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])
            d_on_g_losses.append(d_on_g_loss)
            print('batch {} d_on_g_loss : {}'.format(index+1, d_on_g_loss))

            d.trainable = True

        with open('log.txt', 'a') as f:
            f.write('{} - {} - {}\n'.format(epoch, np.mean(d_losses), np.mean(d_on_g_losses)))

        save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
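
`wasserstein_loss` and `perceptual_loss` are defined elsewhere in that project. A common DeblurGAN-style pair, shown here as a hedged sketch (the 256×256 VGG input size and the `block3_conv3` layer are assumptions):

import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

def wasserstein_loss(y_true, y_pred):
    # Labels are +1 for real and -1 for fake, so this is the mean critic score.
    return K.mean(y_true * y_pred)

def perceptual_loss(y_true, y_pred):
    # Feature-space MSE on a frozen VGG16.
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=(256, 256, 3))
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))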
Example 7
def test_train_discriminator():
    """
    Make sure that the discriminator can achieve low loss, when
    not training the generator.
    """
    path = r'../fauxtograph/images/'
    paths = glob.glob(os.path.join(path, "*.jpg"))
    # Load images
    real_images = np.array([train.load_image(p) for p in paths])
    np.random.shuffle(real_images)
    total_samples, c_dim, x_dim, y_dim = real_images.shape

    train_real_images = np.array(
        [im for im in real_images[:int(total_samples / 2)]])
    test_real_images = np.array(
        [im for im in real_images[int(total_samples / 2):]])

    fake_images = np.array([
        np.random.uniform(-1, 1, (3, 64, 64)) for n in range(len(real_images))
    ])

    train_fake_images = np.array(
        [im for im in fake_images[:int(total_samples / 2)]])
    test_fake_images = np.array(
        [im for im in fake_images[int(total_samples / 2):]])

    assert len(train_fake_images) == len(train_real_images)
    assert len(test_fake_images) == len(test_real_images)

    X_train = np.concatenate((train_real_images, train_fake_images))
    y_train = [1] * len(train_real_images) + [0] * len(
        train_fake_images)  # labels

    X_test = np.concatenate((test_real_images, test_fake_images))
    y_test = [1] * len(test_real_images) + [0] * len(
        test_fake_images)  # labels

    discriminator = model.discriminator_model()
    adam = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=adam,
                          metrics=['accuracy'])

    discriminator.fit(X_train,
                      y_train,
                      batch_size=128,
                      nb_epoch=2,
                      verbose=1,
                      validation_data=(X_test, y_test))
Example 8
g_model = generator_model(vocab_size=len(reader.d),
                          embedding_size=128,
                          lstm_size=128,
                          num_layer=4,
                          max_length_encoder=40,
                          max_length_decoder=40,
                          max_gradient_norm=2,
                          batch_size_num=20,
                          learning_rate=0.001,
                          beam_width=5)
d_model = discriminator_model(vocab_size=len(reader.d),
                              embedding_size=128,
                              lstm_size=128,
                              num_layer=4,
                              max_post_length=40,
                              max_resp_length=40,
                              max_gradient_norm=2,
                              batch_size_num=20,
                              learning_rate=0.001)

saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=1.0)

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
try:
    loader = tf.train.import_meta_graph('saved/model.ckpt.meta')
    loader.restore(sess, tf.train.latest_checkpoint('saved/'))
    print('load finished')
except:
Example 9
def train(path, batch_size, EPOCHS):
    # reproducibility
    # np.random.seed(42)

    # fig = plt.figure()

    # Get image paths
    print("Loading paths..")
    paths = glob.glob(os.path.join(path, "*.jpg"))
    print("Got paths..")
    print(paths)

    # Load images
    IMAGES = np.array([load_image(p) for p in paths])
    np.random.shuffle(IMAGES)

    print(IMAGES[0])

    # IMAGES, labels = load_mnist(dataset="training", digits=np.arange(10), path=path)
    # IMAGES = np.array( [ np.array( [ scipy.misc.imresize(p, (64, 64)) / 256 ] * 3 ) for p in IMAGES ] )

    # np.random.shuffle( IMAGES )

    BATCHES = [b for b in chunks(IMAGES, batch_size)]

    discriminator = model.discriminator_model()
    generator = model.generator_model()
    discriminator_on_generator = model.generator_containing_discriminator(
        generator, discriminator)
    # adam_gen=Adam(lr=0.0002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_gen = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_dis = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    # opt = RMSprop()
    generator.compile(loss='binary_crossentropy', optimizer=adam_gen)
    discriminator_on_generator.compile(loss='binary_crossentropy',
                                       optimizer=adam_gen)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=adam_dis)

    print("Number of batches", len(BATCHES))
    print("Batch size is", batch_size)

    # margin = 0.25
    # equilibrium = 0.6931
    inter_model_margin = 0.10

    for epoch in range(EPOCHS):
        print()
        print("Epoch", epoch)
        print()

        # load weights on first try (i.e. if process failed previously and we are attempting to recapture lost data)
        if epoch == 0:
            if os.path.exists('generator_weights') and os.path.exists(
                    'discriminator_weights'):
                print("Loading saves weights..")
                generator.load_weights('generator_weights')
                discriminator.load_weights('discriminator_weights')
                print("Finished loading")
            else:
                pass

        for index, image_batch in enumerate(BATCHES):
            print("Epoch", epoch, "Batch", index)

            Noise_batch = np.array(
                [noise_image() for n in range(len(image_batch))])
            generated_images = generator.predict(Noise_batch)
            # print generated_images[0][-1][-1]

            for i, img in enumerate(generated_images):
                rolled = np.rollaxis(img, 0, 3)
                cv2.imwrite('results/' + str(i) + ".jpg",
                            np.uint8(255 * 0.5 * (rolled + 1.0)))

            Xd = np.concatenate((image_batch, generated_images))
            yd = [1] * len(image_batch) + [0] * len(image_batch)  # labels

            print("Training first discriminator..")
            d_loss = discriminator.train_on_batch(Xd, yd)

            Xg = Noise_batch
            yg = [1] * len(image_batch)

            print("Training first generator..")
            g_loss = discriminator_on_generator.train_on_batch(Xg, yg)

            print("Initial batch losses : ", "Generator loss", g_loss,
                  "Discriminator loss", d_loss, "Total:", g_loss + d_loss)

            # print "equilibrium - margin", equilibrium - margin

            if g_loss < d_loss and abs(d_loss - g_loss) > inter_model_margin:
                # for j in range(handicap):
                while abs(d_loss - g_loss) > inter_model_margin:
                    print("Updating discriminator..")
                    # g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
                    d_loss = discriminator.train_on_batch(Xd, yd)
                    print("Generator loss", g_loss, "Discriminator loss",
                          d_loss)
                    if d_loss < g_loss:
                        break
            elif d_loss < g_loss and abs(d_loss - g_loss) > inter_model_margin:
                # for j in range(handicap):
                while abs(d_loss - g_loss) > inter_model_margin:
                    print("Updating generator..")
                    # d_loss = discriminator.train_on_batch(Xd, yd)
                    g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
                    print("Generator loss", g_loss, "Discriminator loss",
                          d_loss)
                    if g_loss < d_loss:
                        break
            else:
                pass

            print("Final batch losses (after updates) : ", "Generator loss",
                  g_loss, "Discriminator loss", d_loss, "Total:",
                  g_loss + d_loss)
            print()
            if index % 20 == 0:
                print('Saving weights..')
                generator.save_weights('generator_weights', True)
                discriminator.save_weights('discriminator_weights', True)

        plt.clf()
        for i, img in enumerate(generated_images[:5]):
            i = i + 1
            plt.subplot(3, 3, i)
            rolled = np.rollaxis(img, 0, 3)
            # plt.imshow(rolled, cmap='gray')
            plt.imshow(rolled)
            plt.axis('off')
        # fig.canvas.draw()
        plt.savefig('Epoch_' + str(epoch) + '.png')
Example 10
def train(path, batch_size, EPOCHS):

    # reproducibility
    # np.random.seed(42)

    fig = plt.figure()

    # Get image paths
    print("Loading paths..")
    paths = glob.glob(os.path.join(path, "*.jpg"))
    print("Got paths..")

    # Load images
    IMAGES = np.array([load_image(p) for p in paths])
    np.random.shuffle(IMAGES)

    # IMAGES, labels = load_mnist(dataset="training", digits=np.arange(10), path=path)
    # IMAGES = np.array( [ np.array( [ scipy.misc.imresize(p, (64, 64)) / 256 ] * 3 ) for p in IMAGES ] )

    # np.random.shuffle( IMAGES )

    BATCHES = [b for b in chunks(IMAGES, batch_size)]

    discriminator = model.discriminator_model()
    generator = model.generator_model()
    discriminator_on_generator = model.generator_containing_discriminator(generator, discriminator)
    # adam_gen = Adam(lr=0.0002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_gen = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_dis = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    # opt = RMSprop()
    generator.compile(loss='binary_crossentropy', optimizer=adam_gen)
    discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=adam_gen)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=adam_dis)

    print("Number of batches", len(BATCHES))
    print("Batch size is", batch_size)

    # margin = 0.25
    # equilibrium = 0.6931
    inter_model_margin = 0.10

    for epoch in range(EPOCHS):
        print()
        print("Epoch", epoch)
        print()

        # load weights on first try (i.e. if process failed previously and we are attempting to recapture lost data)
        if epoch == 0:
            if os.path.exists('generator_weights') and os.path.exists('discriminator_weights'):
                print("Loading saved weights..")
                generator.load_weights('generator_weights')
                discriminator.load_weights('discriminator_weights')
                print("Finished loading")
            else:
                pass

        for index, image_batch in enumerate(BATCHES):
            print("Epoch", epoch, "Batch", index)

            Noise_batch = np.array([noise_image() for n in range(len(image_batch))])
            generated_images = generator.predict(Noise_batch)
            # print(generated_images[0][-1][-1])

            for i, img in enumerate(generated_images):
                rolled = np.rollaxis(img, 0, 3)
                cv2.imwrite('results/' + str(i) + ".jpg", np.uint8(255 * 0.5 * (rolled + 1.0)))

            Xd = np.concatenate((image_batch, generated_images))
            yd = [1] * len(image_batch) + [0] * len(image_batch)  # labels

            print("Training first discriminator..")
            d_loss = discriminator.train_on_batch(Xd, yd)

            Xg = Noise_batch
            yg = [1] * len(image_batch)

            print("Training first generator..")
            g_loss = discriminator_on_generator.train_on_batch(Xg, yg)

            print("Initial batch losses : ", "Generator loss", g_loss, "Discriminator loss", d_loss, "Total:", g_loss + d_loss)

            # print("equilibrium - margin", equilibrium - margin)

            if g_loss < d_loss and abs(d_loss - g_loss) > inter_model_margin:
                # for j in range(handicap):
                while abs(d_loss - g_loss) > inter_model_margin:
                    print("Updating discriminator..")
                    # g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
                    d_loss = discriminator.train_on_batch(Xd, yd)
                    print("Generator loss", g_loss, "Discriminator loss", d_loss)
                    if d_loss < g_loss:
                        break
            elif d_loss < g_loss and abs(d_loss - g_loss) > inter_model_margin:
                # for j in range(handicap):
                while abs(d_loss - g_loss) > inter_model_margin:
                    print("Updating generator..")
                    # d_loss = discriminator.train_on_batch(Xd, yd)
                    g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
                    print("Generator loss", g_loss, "Discriminator loss", d_loss)
                    if g_loss < d_loss:
                        break
            else:
                pass

            print("Final batch losses (after updates) : ", "Generator loss", g_loss, "Discriminator loss", d_loss, "Total:", g_loss + d_loss)
            print()

            if index % 20 == 0:
                print('Saving weights..')
                generator.save_weights('generator_weights', True)
                discriminator.save_weights('discriminator_weights', True)

        plt.clf()
        for i, img in enumerate(generated_images[:5]):
            i = i + 1
            plt.subplot(3, 3, i)
            rolled = np.rollaxis(img, 0, 3)
            # plt.imshow(rolled, cmap='gray')
            plt.imshow(rolled)
            plt.axis('off')
        fig.canvas.draw()
        plt.savefig('Epoch_' + str(epoch) + '.png')
Example 11
def train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates=5):
    #data = load_images('/home/turing/td/', n_images)
    y_train = sorted(glob.glob('/home/turing/td/data/*.png'))
    x_train = sorted(glob.glob('/home/turing/td/blur/*.png'))
    print('loaded_data')
    g = generator_model()
    g.load_weights('weights/424/generator_19_290.h5')
    d = discriminator_model()
    d.load_weights('weights/424/discriminator_19.h5')

    d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    d.trainable = True
    d.compile(optimizer=d_opt, loss=wasserstein_loss)
    d.trainable = False
    loss = [perceptual_loss, wasserstein_loss]
    loss_weights = [100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    output_true_batch, output_false_batch = np.ones((batch_size, 1)), np.zeros(
        (batch_size, 1))

    for epoch in range(epoch_num):
        print('epoch: {}/{}'.format(epoch, epoch_num))
        print('batches: {}'.format(len(x_train) / batch_size))

        permutated_indexes = np.random.permutation(len(x_train))

        d_losses = []
        d_on_g_losses = []
        for index in range(int(len(x_train) / batch_size)):
            batch_indexes = permutated_indexes[index * batch_size:(index + 1) *
                                               batch_size]
            x_t = []
            y_t = []
            for i in batch_indexes:
                x_t.append(x_train[i])
                y_t.append(y_train[i])
            image_blur_batch = load_batch(x_t)
            image_full_batch = load_batch(y_t)

            generated_images = g.predict(x=image_blur_batch,
                                         batch_size=batch_size)

            for _ in range(critic_updates):
                d_loss_real = d.train_on_batch(image_full_batch,
                                               output_true_batch)
                d_loss_fake = d.train_on_batch(generated_images,
                                               output_false_batch)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                d_losses.append(d_loss)
            print('batch {} d_loss : {}'.format(index + 1, np.mean(d_losses)))

            d.trainable = False

            d_on_g_loss = d_on_g.train_on_batch(
                image_blur_batch, [image_full_batch, output_true_batch])
            d_on_g_losses.append(d_on_g_loss)
            print('batch {} d_on_g_loss : {}'.format(index + 1, d_on_g_loss))

            d.trainable = True

        with open('log.txt', 'a') as f:
            f.write('{} - {} - {}\n'.format(epoch, np.mean(d_losses),
                                            np.mean(d_on_g_losses)))

        save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
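
The `load_batch` helper above takes a list of file paths and returns an image batch; its implementation is not included. A plausible version (image size and scaling are assumptions):

import numpy as np
from PIL import Image

def load_batch(paths, size=(256, 256)):
    # Stack the images into an (N, H, W, 3) float array scaled to [-1, 1].
    imgs = [np.asarray(Image.open(p).convert('RGB').resize(size), dtype=np.float32)
            for p in paths]
    return np.stack(imgs) / 127.5 - 1.0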
Example 12
train_images = train_images.reshape(-1, 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5

train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(len(train_images)).batch(batch_size)

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

generator = model.generator_model()
discriminator = model.discriminator_model()

cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

train_generator_loss = tf.keras.metrics.Mean()
train_discriminator_loss = tf.keras.metrics.Mean()

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

seed = tf.random.normal([num_examples_to_generate, noise_dim])
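
This setup is usually paired with a GradientTape training step in the style of the TF2 DCGAN tutorial. The script's own step is not shown; a sketch that fits the names defined above (and assumes `batch_size` and `noise_dim` are set earlier in the script):

import tensorflow as tf

@tf.function
def train_step(images):
    noise = tf.random.normal([batch_size, noise_dim])
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    train_generator_loss(gen_loss)
    train_discriminator_loss(disc_loss)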
Example 13
def train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates=5):

    g = generator_model()
    d = discriminator_model()
    g.load_weights('generator.h5')
    d.load_weights('discriminator.h5')
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    d.trainable = True
    d.compile(optimizer=d_opt, loss=wasserstein_loss)
    d.trainable = False
    loss = [perceptual_loss, wasserstein_loss]
    loss_weights = [100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    output_true_batch, output_false_batch = np.ones((batch_size, 1)), np.zeros(
        (batch_size, 1))

    for epoch in range(epoch_num):
        print('epoch: {}/{}'.format(epoch, epoch_num))
        print('batches: {}'.format(25000 // batch_size))
        start = 0

        d_losses = []
        d_on_g_losses = []
        shuffle()
        for index in range(int(25000 // batch_size)):
            data = load_images(start, batch_size)
            y_train, x_train = data['B'], data['A']
            image_blur_batch = x_train
            image_full_batch = y_train
            generated_images = g.predict(x=image_blur_batch,
                                         batch_size=batch_size)

            for _ in range(critic_updates):
                d_loss_real = d.train_on_batch(image_full_batch,
                                               output_true_batch)
                d_loss_fake = d.train_on_batch(generated_images,
                                               output_false_batch)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                d_losses.append(d_loss)
            print('batch {} d_loss : {}'.format(index + 1, np.mean(d_losses)))

            d.trainable = False

            d_on_g_loss = d_on_g.train_on_batch(
                image_blur_batch, [image_full_batch, output_true_batch])
            d_on_g_losses.append(d_on_g_loss)
            print('batch {} d_on_g_loss : {}'.format(index + 1, d_on_g_loss))

            d.trainable = True
            if index % 300 == 0:  # periodic checkpoint
                save_all_weights(d, g, epoch, int(index * 10))
            start += batch_size

        with open('log.txt', 'a') as f:
            f.write('{} - {} - {}\n'.format(epoch, np.mean(d_losses),
                                            np.mean(d_on_g_losses)))

        save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
Example 14
def train(gen, disc, cGAN, gray, rgb, gray_val, rgb_val, batch):
    samples = len(rgb)
    gen_image = gen.predict(gray, batch_size=16)
    gen_image_val = gen.predict(gray_val, batch_size=8)
    inputs = np.concatenate([gray, gray])
    outputs = np.concatenate([rgb, gen_image])
    y = np.concatenate([np.ones((samples, 1)), np.zeros((samples, 1))])
    disc.fit([inputs, outputs], y, epochs=1, batch_size=4)
    disc.trainable = False
    cGAN.fit(gray, [np.ones((samples, 1)), rgb], epochs=1, batch_size=batch,
             validation_data=[gray_val, [np.ones((val_samples, 1)), rgb_val]])
    disc.trainable = True

gen = generator_model(x_shape, y_shape)

disc = discriminator_model(x_shape, y_shape)

cGAN = cGAN_model(gen, disc)
# cGAN.load_weights('sketchColorisation/result/store/9950.h5')

disc.compile(loss=['binary_crossentropy'],
             optimizer=tf.keras.optimizers.Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08),
             metrics=['accuracy'])

cGAN.compile(loss=['binary_crossentropy', custom_loss_2],
             loss_weights=[5, 100],
             optimizer=tf.keras.optimizers.Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08))
tensorboard = tf.keras.callbacks.TensorBoard(log_dir="logs/{}".format(time()))

dataset = 'sketchColorisation/Images/' 
graystore = 'sketchColorisation/grayScale/'
rgbstore = 'sketchColorisation/colored/'
val_data = 'sketchColorisation/validation/'
store = 'sketchColorisation/result/store/'
store2 = 'sketchColorisation/result/store2/'
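
The `cGAN_model` wrapper is not shown. Its outputs must line up with the `['binary_crossentropy', custom_loss_2]` losses and the `[np.ones(...), rgb]` targets used above, i.e. [discriminator verdict, colored image]. A hedged sketch that reuses the script's global `x_shape` and assumes a functional-API build:

import tensorflow as tf

def cGAN_model(gen, disc):
    sketch = tf.keras.layers.Input(shape=x_shape)
    colored = gen(sketch)
    disc.trainable = False              # only the generator trains through cGAN
    validity = disc([sketch, colored])  # disc takes [input, output] pairs, as in disc.fit above
    return tf.keras.models.Model(inputs=sketch, outputs=[validity, colored])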
Example 15
def train(batch_size, epoch_num):
    # Note the x(blur) in the second, the y(full) in the first
    y_train, x_train = data_utils.load_data(data_type='train')

    # GAN
    g = generator_model()
    d = discriminator_model()
    d_on_g = generator_containing_discriminator(g, d)

    # compile the models, use default optimizer parameters
    # generator use adversarial loss
    g.compile(optimizer='adam', loss=generator_loss)
    # discriminator use binary cross entropy loss
    d.compile(optimizer='adam', loss='binary_crossentropy')
    # adversarial net use adversarial loss
    d_on_g.compile(optimizer='adam', loss=adversarial_loss)

    for epoch in range(epoch_num):
        print('epoch: ', epoch + 1, '/', epoch_num)
        print('batches: ', int(x_train.shape[0] / batch_size))

        for index in range(int(x_train.shape[0] / batch_size)):
            # select a batch data
            image_blur_batch = x_train[index * batch_size:(index + 1) *
                                       batch_size]
            image_full_batch = y_train[index * batch_size:(index + 1) *
                                       batch_size]
            generated_images = g.predict(x=image_blur_batch,
                                         batch_size=batch_size)

            # output generated images for each 30 iters
            if (index % 30 == 0) and (index != 0):
                data_utils.generate_image(image_full_batch, image_blur_batch,
                                          generated_images, 'result/interim/',
                                          epoch, index)

            # concatenate the full and generated images,
            # the full images at top, the generated images at bottom
            x = np.concatenate((image_full_batch, generated_images))

            # generate labels for the full and generated images
            y = [1] * batch_size + [0] * batch_size

            # train discriminator
            d_loss = d.train_on_batch(x, y)
            print('batch %d d_loss : %f' % (index + 1, d_loss))

            # let discriminator can't be trained
            d.trainable = False

            # train adversarial net
            d_on_g_loss = d_on_g.train_on_batch(image_blur_batch,
                                                [1] * batch_size)
            print('batch %d d_on_g_loss : %f' % (index + 1, d_on_g_loss))

            # train generator
            g_loss = g.train_on_batch(image_blur_batch, image_full_batch)
            print('batch %d g_loss : %f' % (index + 1, g_loss))

            # let discriminator can be trained
            d.trainable = True

            # output weights for generator and discriminator each 30 iters
            if (index % 30 == 0) and (index != 0):
                g.save_weights('weight/generator_weights.h5', True)
                d.save_weights('weight/discriminator_weights.h5', True)
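
`generator_loss` and `adversarial_loss` come from the project's loss module, which is not reproduced here. One plausible pair consistent with how they are compiled above (a guess, not the repo's definitions): a pixel-wise L2 content loss for the generator and binary cross-entropy for the adversarial head.

import keras.backend as K

def generator_loss(y_true, y_pred):
    # Content loss between restored and sharp images (assumed).
    return K.mean(K.square(y_pred - y_true))

def adversarial_loss(y_true, y_pred):
    # Standard GAN objective on the discriminator's sigmoid output (assumed).
    return K.mean(K.binary_crossentropy(y_true, y_pred))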
Example 16
def train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates=5):
    g = generator_model()
    d = discriminator_model()
    vgg = build_vgg()
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d, vgg)

    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    optimizer = Adam(1E-4, 0.5)
    vgg.trainable = False
    vgg.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

    d.trainable = True
    d.compile(optimizer=d_opt, loss='binary_crossentropy')
    d.trainable = False
    loss = ['mae', 'mse', 'binary_crossentropy']
    loss_weights = [0.1, 100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    output_true_batch, output_false_batch = np.ones((batch_size, 1)), np.zeros(
        (batch_size, 1))

    for epoch in range(epoch_num):
        print('epoch: {}/{}'.format(epoch, epoch_num))

        y_pre, x_pre, mask = load_data(batch_size)

        d_losses = []
        d_on_g_losses = []

        generated_images = g.predict(x=x_pre, batch_size=batch_size)

        for _ in range(critic_updates):
            d_loss_real = d.train_on_batch(y_pre, output_true_batch)
            d_loss_fake = d.train_on_batch(generated_images,
                                           output_false_batch)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
            d_losses.append(d_loss)
        print('batch {} d_loss : {}'.format(epoch, np.mean(d_losses)))

        d.trainable = False

        real_result = mask * y_pre
        y_features = vgg.predict(y_pre)

        d_on_g_loss = d_on_g.train_on_batch(
            [x_pre, mask], [real_result, y_features, output_true_batch])
        d_on_g_losses.append(d_on_g_loss)
        print('batch {} d_on_g_loss : {}'.format(epoch, d_on_g_loss))

        d.trainable = True

        if epoch % 100 == 0:
            generated = np.array([(img + 1) * 127.5
                                  for img in generated_images])
            full = np.array([(img + 1) * 127.5 for img in y_pre])
            blur = np.array([(img + 1) * 127.5 for img in x_pre])

            for i in range(3):
                img_ge = generated[i, :, :, :]
                img_fu = full[i, :, :, :]
                img_bl = blur[i, :, :, :]
                output = np.concatenate((img_ge, img_fu, img_bl), axis=1)
                cv2.imwrite(
                    '/home/alyssa/PythonProjects/occluded/key_code/img_inpainting/out/'
                    + str(epoch) + '_' + str(i) + '.jpg', output)

        if (epoch > 10000 and epoch % 1000 == 0):
            save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
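
`build_vgg` returns the feature extractor whose output (`y_features = vgg.predict(y_pre)`) is matched by the `'mse'` term of `d_on_g`. A hedged sketch (input size and feature layer are assumptions):

from keras.applications.vgg16 import VGG16
from keras.models import Model

def build_vgg(img_shape=(256, 256, 3)):
    # Frozen VGG16 feature extractor.
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=img_shape)
    model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    model.trainable = False
    return model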
Example 17
def map_fn(index=None, flags=None):
    torch.set_default_tensor_type('torch.FloatTensor')
    torch.manual_seed(1234)

    train_data = dataset.DATA(config.TRAIN_DIR)

    if config.MULTI_CORE:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_data)

    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=flags['batch_size']
        if config.MULTI_CORE else config.BATCH_SIZE,
        sampler=train_sampler,
        num_workers=flags['num_workers'] if config.MULTI_CORE else 4,
        drop_last=True,
        pin_memory=True)

    if config.MULTI_CORE:
        DEVICE = xm.xla_device()
    else:
        DEVICE = config.DEVICE

    netG = model.colorization_model().double()
    netD = model.discriminator_model().double()

    VGG_modelF = torchvision.models.vgg16(pretrained=True).double()
    VGG_modelF.requires_grad_(False)

    netG = netG.to(DEVICE)
    netD = netD.to(DEVICE)

    VGG_modelF = VGG_modelF.to(DEVICE)

    optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))

    ## Trains
    train_start = time.time()
    losses = {
        'G_losses': [],
        'D_losses': [],
        'EPOCH_G_losses': [],
        'EPOCH_D_losses': [],
        'G_losses_eval': []
    }

    netG, optG, netD, optD, epoch_checkpoint = utils.load_checkpoint(
        config.CHECKPOINT_DIR, netG, optG, netD, optD, DEVICE)
    netGAN = model.GAN(netG, netD)
    for epoch in range(
            epoch_checkpoint, flags['num_epochs'] +
            1 if config.MULTI_CORE else config.NUM_EPOCHS + 1):
        print('\n')
        print('#' * 8, f'EPOCH-{epoch}', '#' * 8)
        losses['EPOCH_G_losses'] = []
        losses['EPOCH_D_losses'] = []
        if config.MULTI_CORE:
            para_train_loader = pl.ParallelLoader(
                train_loader, [DEVICE]).per_device_loader(DEVICE)
            engine.train(para_train_loader,
                         netGAN,
                         netD,
                         VGG_modelF,
                         optG,
                         optD,
                         device=DEVICE,
                         losses=losses)
            elapsed_train_time = time.time() - train_start
            print("Process", index, "finished training. Train time was:",
                  elapsed_train_time)
        else:
            engine.train(train_loader,
                         netGAN,
                         netD,
                         VGG_modelF,
                         optG,
                         optD,
                         device=DEVICE,
                         losses=losses)
        #########################CHECKPOINTING#################################
        utils.create_checkpoint(epoch,
                                netG,
                                optG,
                                netD,
                                optD,
                                max_checkpoint=config.KEEP_CKPT,
                                save_path=config.CHECKPOINT_DIR)
        ########################################################################
        utils.plot_some(train_data, netG, DEVICE, epoch)
        gc.collect()