示例#1
0
def train(path, batch_size, EPOCHS, inter_model_margin=0.10,
          max_balance_steps=100):
    """Train a DCGAN on every ``*.jpg`` image found directly inside *path*.

    Each batch first trains the discriminator once on real + generated
    images and the generator once through the stacked model. If the
    discriminator loss then exceeds the generator loss by more than
    *inter_model_margin* (or vice versa), the model with the higher loss is
    trained again on the same batch until the gap is within the margin or
    *max_balance_steps* extra updates have been made.

    If both ``generator_weights`` and ``discriminator_weights`` exist in the
    working directory they are loaded before training; weights are saved to
    the same files every 20 batches. Every generated image of a batch is
    written to ``results/<i>.jpg``, and up to five samples of the last batch
    of each epoch are plotted to ``Epoch_<epoch>.png``.

    Args:
        path: Directory containing the training ``.jpg`` images.
        batch_size: Batch size handed to ``chunks`` when splitting the images.
        EPOCHS: Number of passes over the data set.
        inter_model_margin: Non-negative loss gap tolerated between the two
            models before the weaker one gets extra updates.
        max_balance_steps: Maximum number of extra updates per batch, so a
            gap that never closes cannot stall training on a single batch.

    Raises:
        ValueError: If *path* contains no ``.jpg`` files.
    """
    print("Loading paths..")
    paths = glob.glob(os.path.join(path, "*.jpg"))
    print("Got paths..")
    if not paths:
        # With no data there are no batches, and the per-epoch preview below
        # would fail on an undefined ``generated_images``.
        raise ValueError("No .jpg images found in %r" % path)

    images = np.array([load_image(p) for p in paths])
    np.random.shuffle(images)
    batches = list(chunks(images, batch_size))

    discriminator = model.discriminator_model()
    generator = model.generator_model()
    discriminator_on_generator = model.generator_containing_discriminator(
        generator, discriminator)
    # NOTE(review): beta_1=0.0005 is unusually small (DCGAN setups commonly
    # use 0.5) -- confirm this is intended.
    adam_gen = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_dis = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    generator.compile(loss='binary_crossentropy', optimizer=adam_gen)
    # Keras fixes the set of trainable weights when a model is compiled, so
    # freeze the discriminator while compiling the stacked model: generator
    # updates must not also push the discriminator to call fakes real.
    discriminator.trainable = False
    discriminator_on_generator.compile(loss='binary_crossentropy',
                                       optimizer=adam_gen)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=adam_dis)

    print("Number of batches", len(batches))
    print("Batch size is", batch_size)

    # Resume from a previous run if both checkpoints are present.
    if os.path.exists('generator_weights') and os.path.exists(
            'discriminator_weights'):
        print("Loading saves weights..")
        generator.load_weights('generator_weights')
        discriminator.load_weights('discriminator_weights')
        print("Finished loading")

    for epoch in range(EPOCHS):
        print()
        print("Epoch", epoch)
        print()

        for index, image_batch in enumerate(batches):
            print("Epoch", epoch, "Batch", index)
            n_real = len(image_batch)

            noise_batch = np.array([noise_image() for _ in range(n_real)])
            generated_images = generator.predict(noise_batch)

            # Dump this batch's samples (overwritten every batch). The
            # arithmetic maps [-1, 1] to [0, 255]; rollaxis moves axis 0 to
            # the end -- presumably channels-first -> channels-last, confirm.
            for i, img in enumerate(generated_images):
                rolled = np.rollaxis(img, 0, 3)
                cv2.imwrite('results/' + str(i) + ".jpg",
                            np.uint8(255 * 0.5 * (rolled + 1.0)))

            # Real images are labelled 1, generated ones 0.
            xd = np.concatenate((image_batch, generated_images))
            yd = [1] * n_real + [0] * n_real

            print("Training first discriminator..")
            d_loss = discriminator.train_on_batch(xd, yd)

            # The generator is trained to have its fakes classified as real.
            xg = noise_batch
            yg = [1] * n_real

            print("Training first generator..")
            g_loss = discriminator_on_generator.train_on_batch(xg, yg)

            print("Initial batch losses : ", "Generator loss", g_loss,
                  "Discriminator loss", d_loss, "Total:", g_loss + d_loss)

            # Give the lagging model extra updates on this batch until the
            # gap is within the margin; bounded so it cannot loop forever.
            if d_loss - g_loss > inter_model_margin:
                for _ in range(max_balance_steps):
                    print("Updating discriminator..")
                    d_loss = discriminator.train_on_batch(xd, yd)
                    print("Generator loss", g_loss, "Discriminator loss",
                          d_loss)
                    if not (d_loss - g_loss > inter_model_margin):
                        break
            elif g_loss - d_loss > inter_model_margin:
                for _ in range(max_balance_steps):
                    print("Updating generator..")
                    g_loss = discriminator_on_generator.train_on_batch(xg, yg)
                    print("Generator loss", g_loss, "Discriminator loss",
                          d_loss)
                    if not (g_loss - d_loss > inter_model_margin):
                        break

            print("Final batch losses (after updates) : ", "Generator loss",
                  g_loss, "Discriminator loss", d_loss, "Total:",
                  g_loss + d_loss)
            print()
            if index % 20 == 0:
                print('Saving weights..')
                generator.save_weights('generator_weights', True)
                discriminator.save_weights('discriminator_weights', True)

        # Preview up to five samples from the last batch of this epoch.
        plt.clf()
        for i, img in enumerate(generated_images[:5], start=1):
            plt.subplot(3, 3, i)
            plt.imshow(np.rollaxis(img, 0, 3))
            plt.axis('off')
        plt.savefig('Epoch_' + str(epoch) + '.png')
示例#2
0
def train(batch_size, epoch_num, save_interval=30):
    """Train the deblurring GAN on the ``'train'`` split of ``data_utils``.

    For every full batch of *batch_size* (sharp, blurred) pairs:

    1. The discriminator ``d`` is trained to label sharp images 1 and
       generated images 0.
    2. The stacked model ``d_on_g`` is trained on the blurred inputs with
       target 1 (``adversarial_loss``), with ``d`` frozen.
    3. The generator ``g`` alone is trained to map blurred inputs to the
       sharp images (``generator_loss``).

    Every *save_interval* batches (batch 0 excluded) sample images are
    written to ``result/interim/`` and weights to ``weight/``. Weights are
    also saved at the end of every epoch, so epochs with at most
    *save_interval* batches still produce a checkpoint. Samples beyond the
    last full batch are not used.

    Args:
        batch_size: Number of image pairs per training step.
        epoch_num: Number of passes over the training data.
        save_interval: Positive number of batches between interim image and
            weight dumps.
    """
    # load_data returns the sharp (full) images first, the blurred second.
    y_train, x_train = data_utils.load_data(data_type='train')
    num_batches = x_train.shape[0] // batch_size

    g = generator_model()
    d = discriminator_model()
    d_on_g = generator_containing_discriminator(g, d)

    # Default optimizer parameters throughout.
    g.compile(optimizer='adam', loss=generator_loss)
    d.compile(optimizer='adam', loss='binary_crossentropy')
    # Keras fixes which weights are trainable when a model is compiled, so d
    # must be frozen while compiling the stacked model; toggling
    # ``d.trainable`` only at run time would let the adversarial step update
    # d as well.
    d.trainable = False
    d_on_g.compile(optimizer='adam', loss=adversarial_loss)
    d.trainable = True

    weights_g = 'weight/generator_weights.h5'
    weights_d = 'weight/discriminator_weights.h5'

    for epoch in range(epoch_num):
        print('epoch: ', epoch + 1, '/', epoch_num)
        print('batches: ', num_batches)

        for index in range(num_batches):
            start = index * batch_size
            stop = start + batch_size
            image_blur_batch = x_train[start:stop]
            image_full_batch = y_train[start:stop]
            generated_images = g.predict(x=image_blur_batch,
                                         batch_size=batch_size)

            checkpoint = index != 0 and index % save_interval == 0
            if checkpoint:
                data_utils.generate_image(image_full_batch, image_blur_batch,
                                          generated_images, 'result/interim/',
                                          epoch, index)

            # Sharp images first (label 1), generated images after (label 0).
            x = np.concatenate((image_full_batch, generated_images))
            y = [1] * batch_size + [0] * batch_size

            d_loss = d.train_on_batch(x, y)
            print('batch %d d_loss : %f' % (index + 1, d_loss))

            # Run-time toggle kept from the original; presumably for Keras
            # versions that collect trainable weights lazily on the first
            # train_on_batch call -- the compile-time freeze above is what
            # current Keras honours.
            d.trainable = False

            d_on_g_loss = d_on_g.train_on_batch(image_blur_batch,
                                                [1] * batch_size)
            print('batch %d d_on_g_loss : %f' % (index + 1, d_on_g_loss))

            g_loss = g.train_on_batch(image_blur_batch, image_full_batch)
            print('batch %d g_loss : %f' % (index + 1, g_loss))

            d.trainable = True

            if checkpoint:
                g.save_weights(weights_g, True)
                d.save_weights(weights_d, True)

        # End-of-epoch checkpoint; skipped when no batch was trained so an
        # existing checkpoint is not overwritten with untrained weights.
        if num_batches:
            g.save_weights(weights_g, True)
            d.save_weights(weights_d, True)
示例#3
0
文件: train.py 项目: jhayes14/GAN
def train(path, batch_size, EPOCHS, inter_model_margin=0.10,
          max_balance_steps=100):
    """Train a DCGAN on every ``*.jpg`` image found directly inside *path*.

    Each batch first trains the discriminator once on real + generated
    images and the generator once through the stacked model. If the
    discriminator loss then exceeds the generator loss by more than
    *inter_model_margin* (or vice versa), the model with the higher loss is
    trained again on the same batch until the gap is within the margin or
    *max_balance_steps* extra updates have been made.

    If both ``generator_weights`` and ``discriminator_weights`` exist in the
    working directory they are loaded before training; weights are saved to
    the same files every 20 batches. Every generated image of a batch is
    written to ``results/<i>.jpg``, and up to five samples of the last batch
    of each epoch are plotted to ``Epoch_<epoch>.png``.

    Args:
        path: Directory containing the training ``.jpg`` images.
        batch_size: Batch size handed to ``chunks`` when splitting the images.
        EPOCHS: Number of passes over the data set.
        inter_model_margin: Non-negative loss gap tolerated between the two
            models before the weaker one gets extra updates.
        max_balance_steps: Maximum number of extra updates per batch, so a
            gap that never closes cannot stall training on a single batch.

    Raises:
        ValueError: If *path* contains no ``.jpg`` files.
    """
    # Every message is printed as ONE pre-formatted string, so the output is
    # the same under the Python 2 print statement and the Python 3 print
    # function (a bare ``print()`` would show "()" on Python 2).
    fig = plt.figure()

    print("Loading paths..")
    paths = glob.glob(os.path.join(path, "*.jpg"))
    print("Got paths..")
    if not paths:
        # With no data there are no batches, and the per-epoch preview below
        # would fail on an undefined ``generated_images``.
        raise ValueError("No .jpg images found in %r" % path)

    images = np.array([load_image(p) for p in paths])
    np.random.shuffle(images)
    batches = list(chunks(images, batch_size))

    discriminator = model.discriminator_model()
    generator = model.generator_model()
    discriminator_on_generator = model.generator_containing_discriminator(
        generator, discriminator)
    # NOTE(review): beta_1=0.0005 is unusually small (DCGAN setups commonly
    # use 0.5) -- confirm this is intended.
    adam_gen = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_dis = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    generator.compile(loss='binary_crossentropy', optimizer=adam_gen)
    # Keras fixes the set of trainable weights when a model is compiled, so
    # freeze the discriminator while compiling the stacked model: generator
    # updates must not also push the discriminator to call fakes real.
    discriminator.trainable = False
    discriminator_on_generator.compile(loss='binary_crossentropy',
                                       optimizer=adam_gen)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=adam_dis)

    print("Number of batches %s" % len(batches))
    print("Batch size is %s" % batch_size)

    # Resume from a previous run if both checkpoints are present.
    if os.path.exists('generator_weights') and os.path.exists(
            'discriminator_weights'):
        print("Loading saves weights..")
        generator.load_weights('generator_weights')
        discriminator.load_weights('discriminator_weights')
        print("Finished loading")

    for epoch in range(EPOCHS):
        print("")
        print("Epoch %s" % epoch)
        print("")

        for index, image_batch in enumerate(batches):
            print("Epoch %s Batch %s" % (epoch, index))
            n_real = len(image_batch)

            noise_batch = np.array([noise_image() for _ in range(n_real)])
            generated_images = generator.predict(noise_batch)

            # Dump this batch's samples (overwritten every batch). The
            # arithmetic maps [-1, 1] to [0, 255]; rollaxis moves axis 0 to
            # the end -- presumably channels-first -> channels-last, confirm.
            for i, img in enumerate(generated_images):
                rolled = np.rollaxis(img, 0, 3)
                cv2.imwrite('results/' + str(i) + ".jpg",
                            np.uint8(255 * 0.5 * (rolled + 1.0)))

            # Real images are labelled 1, generated ones 0.
            xd = np.concatenate((image_batch, generated_images))
            yd = [1] * n_real + [0] * n_real

            print("Training first discriminator..")
            d_loss = discriminator.train_on_batch(xd, yd)

            # The generator is trained to have its fakes classified as real.
            xg = noise_batch
            yg = [1] * n_real

            print("Training first generator..")
            g_loss = discriminator_on_generator.train_on_batch(xg, yg)

            # The double space after the colon reproduces Python 2's
            # softspace behaviour after a string ending in ' '.
            print("Initial batch losses :  Generator loss %s "
                  "Discriminator loss %s Total: %s"
                  % (g_loss, d_loss, g_loss + d_loss))

            # Give the lagging model extra updates on this batch until the
            # gap is within the margin; bounded so it cannot loop forever.
            if d_loss - g_loss > inter_model_margin:
                for _ in range(max_balance_steps):
                    print("Updating discriminator..")
                    d_loss = discriminator.train_on_batch(xd, yd)
                    print("Generator loss %s Discriminator loss %s"
                          % (g_loss, d_loss))
                    if not (d_loss - g_loss > inter_model_margin):
                        break
            elif g_loss - d_loss > inter_model_margin:
                for _ in range(max_balance_steps):
                    print("Updating generator..")
                    g_loss = discriminator_on_generator.train_on_batch(xg, yg)
                    print("Generator loss %s Discriminator loss %s"
                          % (g_loss, d_loss))
                    if not (g_loss - d_loss > inter_model_margin):
                        break

            print("Final batch losses (after updates) :  Generator loss %s "
                  "Discriminator loss %s Total: %s"
                  % (g_loss, d_loss, g_loss + d_loss))
            print("")

            if index % 20 == 0:
                print('Saving weights..')
                generator.save_weights('generator_weights', True)
                discriminator.save_weights('discriminator_weights', True)

        # Preview up to five samples from the last batch of this epoch.
        plt.clf()
        for i, img in enumerate(generated_images[:5], start=1):
            plt.subplot(3, 3, i)
            plt.imshow(np.rollaxis(img, 0, 3))
            plt.axis('off')
        fig.canvas.draw()
        plt.savefig('Epoch_' + str(epoch) + '.png')