def test_train_discriminator():
    """Make sure that the discriminator can achieve low loss, when not
    training the generator.

    Real JPEGs are paired with uniform-noise fakes; the discriminator is
    fitted alone for two epochs on a 50/50 train/test split.
    """
    path = r'../fauxtograph/images/'
    paths = glob.glob(os.path.join(path, "*.jpg"))

    # Load and shuffle the real images; shape is (N, channels, H, W).
    real_images = np.array([train.load_image(p) for p in paths])
    np.random.shuffle(real_images)
    total_samples, c_dim, x_dim, y_dim = real_images.shape
    half = total_samples // 2

    # Slices of an ndarray are already sub-arrays; the element-by-element
    # `np.array([im for im in ...])` copies were redundant.
    train_real_images = real_images[:half]
    test_real_images = real_images[half:]

    # One vectorized draw replaces the per-image noise loop; fakes share
    # the real images' 3x64x64 layout.
    fake_images = np.random.uniform(-1, 1, (total_samples, 3, 64, 64))
    train_fake_images = fake_images[:half]
    test_fake_images = fake_images[half:]

    assert len(train_fake_images) == len(train_real_images)
    assert len(test_fake_images) == len(test_real_images)

    X_train = np.concatenate((train_real_images, train_fake_images))
    y_train = [1] * len(train_real_images) + [0] * len(train_fake_images)  # labels
    X_test = np.concatenate((test_real_images, test_fake_images))
    y_test = [1] * len(test_real_images) + [0] * len(test_fake_images)  # labels

    discriminator = model.discriminator_model()
    adam = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
    discriminator.compile(loss='binary_crossentropy', optimizer=adam,
                          metrics=['accuracy'])
    # NOTE: `nb_epoch` is the legacy Keras 1.x spelling of `epochs`; kept
    # for compatibility with the Keras version this project uses.
    discriminator.fit(X_train, y_train, batch_size=128, nb_epoch=2,
                      verbose=1, validation_data=(X_test, y_test))
def __init__(self, args):
    """Build and compile the GAN: discriminator, generator, and the stacked
    discriminator-on-generator model.

    Parameters
    ----------
    args : argparse.Namespace
        Expects `imgsize`, `channels`, `zdims`, `epoch`, `batchsize`.
    """
    self.img_size = args.imgsize
    self.channels = args.channels
    self.z_dim = args.zdims
    self.epochs = args.epoch
    self.batch_size = args.batchsize
    # Discriminator uses a much smaller learning rate / momentum than the
    # generator.
    self.d_opt = Adam(lr=1e-5, beta_1=0.1)
    self.g_opt = Adam(lr=2e-4, beta_1=0.5)

    # exist_ok replaces the explicit os.path.exists checks.
    os.makedirs('./result/', exist_ok=True)
    os.makedirs('./model_images/', exist_ok=True)

    """ build discriminator model """
    self.d = model.discriminator_model(self.img_size, self.channels)
    plot_model(self.d, to_file='./model_images/discriminator.png',
               show_shapes=True)
    """ build generator model """
    self.g = model.generator_model(self.z_dim, self.img_size, self.channels)
    # Fix: give the generator/d_on_g diagram files an explicit .png
    # extension, consistent with the discriminator diagram above.
    plot_model(self.g, to_file='./model_images/generator.png',
               show_shapes=True)
    """ discriminator on generator model """
    self.d_on_g = model.generator_containg_discriminator(self.g, self.d,
                                                         self.z_dim)
    plot_model(self.d_on_g, to_file='./model_images/d_on_g.png',
               show_shapes=True)

    self.g.compile(loss='mse', optimizer=self.g_opt)
    # NOTE(review): d is never frozen before compiling d_on_g, so updates
    # through the stacked model will also train d — confirm intended.
    self.d_on_g.compile(loss='mse', optimizer=self.g_opt)
    self.d.trainable = True
    self.d.compile(loss='mse', optimizer=self.d_opt)
def test_discriminator_model():
    """Smoke-test the discriminator: compile it and run one random
    (1, 3, 64, 64) input through predict()."""
    # Removed unused locals `epochs` and `input_shape`.
    input_data = np.random.rand(1, 3, 64, 64)
    discriminator = model.discriminator_model()
    adam = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
    discriminator.compile(loss='binary_crossentropy', optimizer=adam)
    pred = discriminator.predict(input_data)
    # Parenthesised print is valid under both Python 2 and 3.
    print(pred)
def test_discriminator_model():
    """Smoke-test the discriminator: compile it and run one random
    (1, 3, 64, 64) input through predict()."""
    # Removed unused locals `epochs` and `input_shape`.
    input_data = np.random.rand(1, 3, 64, 64)
    discriminator = model.discriminator_model()
    adam = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
    discriminator.compile(loss='binary_crossentropy', optimizer=adam)
    pred = discriminator.predict(input_data)
    # Parenthesised print is valid under both Python 2 and 3.
    print(pred)
def __init__(self, hparams):
    """Store hyper-parameters and build the colorization GAN parts:
    generator, discriminator, and a pretrained VGG16 feature network."""
    super(GAN, self).__init__()
    self.hparams = hparams
    # Generator and discriminator come from the project's model module.
    self.netG = model.colorization_model()
    self.netD = model.discriminator_model()
    # Pretrained VGG16, used as a feature extractor.
    self.VGG_MODEL = torchvision.models.vgg16(pretrained=True)
    # Both image buffers start empty; they are populated during training.
    self.generated_imgs = self.last_imgs = None
def train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates=5):
    """Train the deblurring GAN.

    Per batch: `critic_updates` critic (discriminator) steps on real vs
    generated images, then one generator step through the frozen critic
    combining perceptual and adversarial losses.
    """
    # 'A' holds the blurred inputs, 'B' the sharp targets.
    data = load_images('./images/train', n_images)
    y_train, x_train = data['B'], data['A']

    g = generator_model()
    d = discriminator_model()
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    d.trainable = True
    d.compile(optimizer=d_opt, loss=wasserstein_loss)
    # Freeze the critic while compiling the stacked model so generator
    # updates do not move its weights.
    d.trainable = False
    loss = [perceptual_loss, wasserstein_loss]
    loss_weights = [100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    # Wasserstein targets: +1 for real, -1 for generated.
    output_true_batch, output_false_batch = np.ones((batch_size, 1)), -np.ones((batch_size, 1))

    for epoch in range(epoch_num):
        print('epoch: {}/{}'.format(epoch, epoch_num))
        print('batches: {}'.format(x_train.shape[0] / batch_size))

        # Fresh shuffle of sample indices every epoch.
        permutated_indexes = np.random.permutation(x_train.shape[0])

        d_losses = []
        d_on_g_losses = []
        for index in range(int(x_train.shape[0] / batch_size)):
            batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
            image_blur_batch = x_train[batch_indexes]
            image_full_batch = y_train[batch_indexes]

            generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

            # Critic: several updates on real vs generated batches.
            for _ in range(critic_updates):
                d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
                d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                d_losses.append(d_loss)

            print('batch {} d_loss : {}'.format(index+1, np.mean(d_losses)))

            # Generator: one step through the frozen critic.
            d.trainable = False

            d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])
            d_on_g_losses.append(d_on_g_loss)

            print('batch {} d_on_g_loss : {}'.format(index+1, d_on_g_loss))

            d.trainable = True

        # Append per-epoch mean losses to a running log file.
        with open('log.txt', 'a') as f:
            f.write('{} - {} - {}\n'.format(epoch, np.mean(d_losses), np.mean(d_on_g_losses)))

        save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
def test_train_discriminator():
    """Make sure that the discriminator can achieve low loss, when not
    training the generator.

    Real JPEGs are paired with uniform-noise fakes; the discriminator is
    trained alone for two epochs on a 50/50 train/test split.
    """
    path = r'../fauxtograph/images/'
    paths = glob.glob(os.path.join(path, "*.jpg"))

    # Load and shuffle the real images; shape is (N, channels, H, W).
    real_images = np.array([train.load_image(p) for p in paths])
    np.random.shuffle(real_images)
    total_samples, c_dim, x_dim, y_dim = real_images.shape
    half = total_samples // 2

    # Slices of an ndarray are already sub-arrays; the element-by-element
    # comprehension copies were redundant.
    train_real_images = real_images[:half]
    test_real_images = real_images[half:]

    # One vectorized draw replaces the per-image noise loop; fakes share
    # the real images' 3x64x64 layout.
    fake_images = np.random.uniform(-1, 1, (total_samples, 3, 64, 64))
    train_fake_images = fake_images[:half]
    test_fake_images = fake_images[half:]

    assert len(train_fake_images) == len(train_real_images)
    assert len(test_fake_images) == len(test_real_images)

    X_train = np.concatenate((train_real_images, train_fake_images))
    y_train = [1] * len(train_real_images) + [0] * len(
        train_fake_images)  # labels
    X_test = np.concatenate((test_real_images, test_fake_images))
    y_test = [1] * len(test_real_images) + [0] * len(
        test_fake_images)  # labels

    discriminator = model.discriminator_model()
    adam = Adam(lr=0.0002, beta_1=0.5, beta_2=0.999, epsilon=1e-08)
    discriminator.compile(loss='binary_crossentropy',
                          optimizer=adam,
                          metrics=['accuracy'])
    # `nb_epoch` is the legacy Keras 1.x spelling of `epochs`; kept for
    # compatibility with the Keras version this project uses.
    discriminator.fit(X_train,
                      y_train,
                      batch_size=128,
                      nb_epoch=2,
                      verbose=1,
                      validation_data=(X_test, y_test))
g_model = generator_model(vocab_size=len(reader.d), embedding_size=128, lstm_size=128, num_layer=4, max_length_encoder=40, max_length_decoder=40, max_gradient_norm=2, batch_size_num=20, learning_rate=0.001, beam_width=5) d_model = discriminator_model(vocab_size=len(reader.d), embedding_size=128, lstm_size=128, num_layer=4, max_post_length=40, max_resp_length=40, max_gradient_norm=2, batch_size_num=20, learning_rate=0.001) saver = tf.train.Saver(tf.global_variables(), keep_checkpoint_every_n_hours=1.0) config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) try: loader = tf.train.import_meta_graph('saved/model.ckpt.meta') loader.restore(sess, tf.train.latest_checkpoint('saved/')) print('load finished') except:
def train(path, batch_size, EPOCHS):
    """Adversarially train generator and discriminator on the JPEGs in
    `path`.

    After one update each per batch, whichever model's loss lags by more
    than `inter_model_margin` receives extra updates until the gap closes.
    Weights are checkpointed every 20 batches with a grid of samples.
    """
    # reproducibility
    # np.random.seed(42)
    # fig = plt.figure()
    # Get image paths
    print("Loading paths..")
    paths = glob.glob(os.path.join(path, "*.jpg"))
    print("Got paths..")
    print(paths)

    # Load images and shuffle them in place.
    IMAGES = np.array([load_image(p) for p in paths])
    np.random.shuffle(IMAGES)
    print(IMAGES[0])
    # IMAGES, labels = load_mnist(dataset="training", digits=np.arange(10), path=path)
    # IMAGES = np.array( [ np.array( [ scipy.misc.imresize(p, (64, 64)) / 256 ] * 3 ) for p in IMAGES ] )
    # np.random.shuffle( IMAGES )

    BATCHES = [b for b in chunks(IMAGES, batch_size)]

    discriminator = model.discriminator_model()
    generator = model.generator_model()
    discriminator_on_generator = model.generator_containing_discriminator(
        generator, discriminator)
    # adam_gen=Adam(lr=0.0002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_gen = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_dis = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    # opt = RMSprop()
    generator.compile(loss='binary_crossentropy', optimizer=adam_gen)
    discriminator_on_generator.compile(loss='binary_crossentropy',
                                       optimizer=adam_gen)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=adam_dis)
    print("Number of batches", len(BATCHES))
    print("Batch size is", batch_size)

    # margin = 0.25
    # equilibrium = 0.6931
    # Maximum tolerated generator/discriminator loss gap before the
    # trailing model gets extra updates.
    inter_model_margin = 0.10

    for epoch in range(EPOCHS):
        print()
        print("Epoch", epoch)
        print()
        # load weights on first try (i.e. if process failed previously and
        # we are attempting to recapture lost data)
        if epoch == 0:
            if os.path.exists('generator_weights') and os.path.exists(
                    'discriminator_weights'):
                print("Loading saves weights..")
                generator.load_weights('generator_weights')
                discriminator.load_weights('discriminator_weights')
                print("Finished loading")
            else:
                pass
        for index, image_batch in enumerate(BATCHES):
            print("Epoch", epoch, "Batch", index)
            Noise_batch = np.array(
                [noise_image() for n in range(len(image_batch))])
            generated_images = generator.predict(Noise_batch)
            # print generated_images[0][-1][-1]

            # Dump current generator output; values are in [-1, 1] and are
            # rescaled to [0, 255] before writing.
            for i, img in enumerate(generated_images):
                rolled = np.rollaxis(img, 0, 3)
                cv2.imwrite('results/' + str(i) + ".jpg",
                            np.uint8(255 * 0.5 * (rolled + 1.0)))

            # Discriminator batch: real images labelled 1, generated 0.
            Xd = np.concatenate((image_batch, generated_images))
            yd = [1] * len(image_batch) + [0] * len(image_batch)  # labels
            print("Training first discriminator..")
            d_loss = discriminator.train_on_batch(Xd, yd)

            # Generator batch: noise inputs all labelled 1 (fool the critic).
            Xg = Noise_batch
            yg = [1] * len(image_batch)
            print("Training first generator..")
            g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
            print("Initial batch losses : ", "Generator loss", g_loss,
                  "Discriminator loss", d_loss, "Total:", g_loss + d_loss)
            # print "equilibrium - margin", equilibrium - margin

            # Loss balancing: repeatedly update whichever model is behind
            # until the two losses are within the margin.
            if g_loss < d_loss and abs(d_loss - g_loss) > inter_model_margin:
                # for j in range(handicap):
                while abs(d_loss - g_loss) > inter_model_margin:
                    print("Updating discriminator..")
                    # g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
                    d_loss = discriminator.train_on_batch(Xd, yd)
                    print("Generator loss", g_loss, "Discriminator loss",
                          d_loss)
                    if d_loss < g_loss:
                        break
            elif d_loss < g_loss and abs(d_loss - g_loss) > inter_model_margin:
                # for j in range(handicap):
                while abs(d_loss - g_loss) > inter_model_margin:
                    print("Updating generator..")
                    # d_loss = discriminator.train_on_batch(Xd, yd)
                    g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
                    print("Generator loss", g_loss, "Discriminator loss",
                          d_loss)
                    if g_loss < d_loss:
                        break
            else:
                pass
            print("Final batch losses (after updates) : ", "Generator loss",
                  g_loss, "Discriminator loss", d_loss, "Total:",
                  g_loss + d_loss)
            print()

            # Periodic checkpoint plus a grid of up to five sample outputs.
            if index % 20 == 0:
                print('Saving weights..')
                generator.save_weights('generator_weights', True)
                discriminator.save_weights('discriminator_weights', True)
                plt.clf()
                for i, img in enumerate(generated_images[:5]):
                    i = i + 1
                    plt.subplot(3, 3, i)
                    rolled = np.rollaxis(img, 0, 3)
                    # plt.imshow(rolled, cmap='gray')
                    plt.imshow(rolled)
                    plt.axis('off')
                # fig.canvas.draw()
                plt.savefig('Epoch_' + str(epoch) + '.png')
def train(path, batch_size, EPOCHS):
    """Adversarially train generator and discriminator on the JPEGs in
    `path` (Python 2 version).

    After one update each per batch, whichever model's loss lags by more
    than `inter_model_margin` receives extra updates until the gap closes.
    Weights are checkpointed every 20 batches with a grid of samples.
    """
    #reproducibility
    #np.random.seed(42)
    fig = plt.figure()
    # Get image paths
    print "Loading paths.."
    paths = glob.glob(os.path.join(path, "*.jpg"))
    print "Got paths.."

    # Load images and shuffle them in place.
    IMAGES = np.array([load_image(p) for p in paths])
    np.random.shuffle(IMAGES)
    #IMAGES, labels = load_mnist(dataset="training", digits=np.arange(10), path=path)
    #IMAGES = np.array( [ np.array( [ scipy.misc.imresize(p, (64, 64)) / 256 ] * 3 ) for p in IMAGES ] )
    #np.random.shuffle( IMAGES )

    BATCHES = [b for b in chunks(IMAGES, batch_size)]

    discriminator = model.discriminator_model()
    generator = model.generator_model()
    discriminator_on_generator = model.generator_containing_discriminator(generator, discriminator)
    #adam_gen=Adam(lr=0.0002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_gen = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    adam_dis = Adam(lr=0.00002, beta_1=0.0005, beta_2=0.999, epsilon=1e-08)
    #opt = RMSprop()
    generator.compile(loss='binary_crossentropy', optimizer=adam_gen)
    discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=adam_gen)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=adam_dis)
    print "Number of batches", len(BATCHES)
    print "Batch size is", batch_size

    #margin = 0.25
    #equilibrium = 0.6931
    # Maximum tolerated generator/discriminator loss gap before the
    # trailing model gets extra updates.
    inter_model_margin = 0.10

    for epoch in range(EPOCHS):
        print
        print "Epoch", epoch
        print
        # load weights on first try (i.e. if process failed previously and
        # we are attempting to recapture lost data)
        if epoch == 0:
            if os.path.exists('generator_weights') and os.path.exists('discriminator_weights'):
                print "Loading saves weights.."
                generator.load_weights('generator_weights')
                discriminator.load_weights('discriminator_weights')
                print "Finished loading"
            else:
                pass

        for index, image_batch in enumerate(BATCHES):
            print "Epoch", epoch, "Batch", index
            Noise_batch = np.array([noise_image() for n in range(len(image_batch))])
            generated_images = generator.predict(Noise_batch)
            #print generated_images[0][-1][-1]

            # Dump current generator output; values are in [-1, 1] and are
            # rescaled to [0, 255] before writing.
            for i, img in enumerate(generated_images):
                rolled = np.rollaxis(img, 0, 3)
                cv2.imwrite('results/' + str(i) + ".jpg", np.uint8(255 * 0.5 * (rolled + 1.0)))

            # Discriminator batch: real images labelled 1, generated 0.
            Xd = np.concatenate((image_batch, generated_images))
            yd = [1] * len(image_batch) + [0] * len(image_batch) # labels
            print "Training first discriminator.."
            d_loss = discriminator.train_on_batch(Xd, yd)

            # Generator batch: noise inputs all labelled 1 (fool the critic).
            Xg = Noise_batch
            yg = [1] * len(image_batch)
            print "Training first generator.."
            g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
            print "Initial batch losses : ", "Generator loss", g_loss, "Discriminator loss", d_loss, "Total:", g_loss + d_loss
            #print "equilibrium - margin", equilibrium - margin

            # Loss balancing: repeatedly update whichever model is behind
            # until the two losses are within the margin.
            if g_loss < d_loss and abs(d_loss - g_loss) > inter_model_margin:
                #for j in range(handicap):
                while abs(d_loss - g_loss) > inter_model_margin:
                    print "Updating discriminator.."
                    #g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
                    d_loss = discriminator.train_on_batch(Xd, yd)
                    print "Generator loss", g_loss, "Discriminator loss", d_loss
                    if d_loss < g_loss:
                        break
            elif d_loss < g_loss and abs(d_loss - g_loss) > inter_model_margin:
                #for j in range(handicap):
                while abs(d_loss - g_loss) > inter_model_margin:
                    print "Updating generator.."
                    #d_loss = discriminator.train_on_batch(Xd, yd)
                    g_loss = discriminator_on_generator.train_on_batch(Xg, yg)
                    print "Generator loss", g_loss, "Discriminator loss", d_loss
                    if g_loss < d_loss:
                        break
            else:
                pass
            print "Final batch losses (after updates) : ", "Generator loss", g_loss, "Discriminator loss", d_loss, "Total:", g_loss + d_loss
            print

            # Periodic checkpoint plus a grid of up to five sample outputs.
            if index % 20 == 0:
                print 'Saving weights..'
                generator.save_weights('generator_weights', True)
                discriminator.save_weights('discriminator_weights', True)
                plt.clf()
                for i, img in enumerate(generated_images[:5]):
                    i = i+1
                    plt.subplot(3, 3, i)
                    rolled = np.rollaxis(img, 0, 3)
                    #plt.imshow(rolled, cmap='gray')
                    plt.imshow(rolled)
                    plt.axis('off')
                fig.canvas.draw()
                plt.savefig('Epoch_' + str(epoch) + '.png')
def train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates=5):
    """Fine-tune the deblurring GAN from saved weights.

    Per batch: `critic_updates` critic steps on real vs generated images,
    then one generator step through the frozen critic (perceptual +
    adversarial loss).
    """
    # data = load_images('/home/turing/td/', n_images)
    # Sorted so the blurred/sharp file lists stay aligned pairwise.
    y_train = sorted(glob.glob('/home/turing/td/data/*.png'))
    x_train = sorted(glob.glob('/home/turing/td/blur/*.png'))
    print('loaded_data')

    g = generator_model()
    g.load_weights('weights/424/generator_19_290.h5')
    d = discriminator_model()
    d.load_weights('weights/424/discriminator_19.h5')
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    d.trainable = True
    d.compile(optimizer=d_opt, loss=wasserstein_loss)
    # Freeze the critic inside the stacked model so generator updates do
    # not move its weights.
    d.trainable = False
    loss = [perceptual_loss, wasserstein_loss]
    loss_weights = [100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    # Fix: Wasserstein targets are +1 (real) / -1 (fake). The previous
    # np.zeros fake label made the fake-batch critic loss identically
    # zero under wasserstein_loss (matches the sibling training loop).
    output_true_batch = np.ones((batch_size, 1))
    output_false_batch = -np.ones((batch_size, 1))

    for epoch in range(epoch_num):
        print('epoch: {}/{}'.format(epoch, epoch_num))
        print('batches: {}'.format(len(x_train) / batch_size))

        # Fresh shuffle of sample indices every epoch.
        permutated_indexes = np.random.permutation(len(x_train))

        d_losses = []
        d_on_g_losses = []
        for index in range(int(len(x_train) / batch_size)):
            batch_indexes = permutated_indexes[index * batch_size:(index + 1) * batch_size]
            x_t = []
            y_t = []
            for i in batch_indexes:
                x_t.append(x_train[i])
                y_t.append(y_train[i])
            # Decode only this batch's files from disk.
            image_blur_batch = load_batch(x_t)
            image_full_batch = load_batch(y_t)

            generated_images = g.predict(x=image_blur_batch,
                                         batch_size=batch_size)

            # Critic: several updates on real vs generated batches.
            for _ in range(critic_updates):
                d_loss_real = d.train_on_batch(image_full_batch,
                                               output_true_batch)
                d_loss_fake = d.train_on_batch(generated_images,
                                               output_false_batch)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                d_losses.append(d_loss)

            print('batch {} d_loss : {}'.format(index + 1, np.mean(d_losses)))

            # Generator: one step through the frozen critic.
            d.trainable = False

            d_on_g_loss = d_on_g.train_on_batch(
                image_blur_batch, [image_full_batch, output_true_batch])
            d_on_g_losses.append(d_on_g_loss)

            print('batch {} d_on_g_loss : {}'.format(index + 1, d_on_g_loss))

            d.trainable = True

        # Append per-epoch mean losses to a running log file.
        with open('log.txt', 'a') as f:
            f.write('{} - {} - {}\n'.format(epoch, np.mean(d_losses),
                                            np.mean(d_on_g_losses)))

        save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
# Scale pixel values from [0, 255] to [-1, 1] and build a shuffled,
# batched tf.data pipeline over the training images.
train_images = train_images.reshape(-1, 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5
train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                 .shuffle(len(train_images))
                 .batch(batch_size))


def discriminator_loss(real_output, fake_output):
    """Binary cross-entropy: real logits scored against 1, fakes against 0."""
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_loss + fake_loss


def generator_loss(fake_output):
    """The generator wants the discriminator to output 1 for its fakes."""
    return cross_entropy(tf.ones_like(fake_output), fake_output)


generator = model.generator_model()
discriminator = model.discriminator_model()

# Logit-space BCE shared by both loss functions above.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Running means of the two losses for reporting.
train_generator_loss = tf.keras.metrics.Mean()
train_discriminator_loss = tf.keras.metrics.Mean()

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# Checkpoint both networks and their optimizer states together.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator)

# Fixed latent vectors so sample images are comparable across epochs.
seed = tf.random.normal([num_examples_to_generate, noise_dim])
def train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates=5):
    """Resume training the deblur GAN from saved weights.

    Per batch: `critic_updates` critic steps on real vs generated images,
    then one generator step through the frozen critic. Weights are saved
    every 300 batches and once more at the end of each epoch.
    """
    g = generator_model()
    d = discriminator_model()
    g.load_weights('generator.h5')
    d.load_weights('discriminator.h5')
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

    d.trainable = True
    d.compile(optimizer=d_opt, loss=wasserstein_loss)
    # Freeze the critic inside the stacked model so generator updates do
    # not move its weights.
    d.trainable = False
    loss = [perceptual_loss, wasserstein_loss]
    loss_weights = [100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    # Fix: Wasserstein targets are +1 (real) / -1 (fake). The previous
    # np.zeros fake label made the fake-batch critic loss identically
    # zero under wasserstein_loss.
    output_true_batch = np.ones((batch_size, 1))
    output_false_batch = -np.ones((batch_size, 1))

    for epoch in range(epoch_num):
        print('epoch: {}/{}'.format(epoch, epoch_num))
        print('batches: {}'.format(batch_size))
        start = 0
        d_losses = []
        d_on_g_losses = []
        shuffle()
        for index in range(int(25000 // batch_size)):
            # Load the next batch of (blurred 'A', sharp 'B') pairs.
            data = load_images(start, batch_size)
            y_train, x_train = data['B'], data['A']
            image_blur_batch = x_train
            image_full_batch = y_train
            generated_images = g.predict(x=image_blur_batch,
                                         batch_size=batch_size)

            # Critic: several updates on real vs generated batches.
            for _ in range(critic_updates):
                d_loss_real = d.train_on_batch(image_full_batch,
                                               output_true_batch)
                d_loss_fake = d.train_on_batch(generated_images,
                                               output_false_batch)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                d_losses.append(d_loss)
            print('batch {} d_loss : {}'.format(index + 1, np.mean(d_losses)))

            # Generator: one step through the frozen critic.
            d.trainable = False
            d_on_g_loss = d_on_g.train_on_batch(
                image_blur_batch, [image_full_batch, output_true_batch])
            d_on_g_losses.append(d_on_g_loss)
            print('batch {} d_on_g_loss : {}'.format(index + 1, d_on_g_loss))
            d.trainable = True

            # Fix: checkpoint every 300th batch; the previous condition
            # `if (index % 300):` fired on every batch NOT divisible by 300.
            if index % 300 == 0:
                save_all_weights(d, g, epoch, int(index * 10))
            start += batch_size

        # Append per-epoch mean losses to a running log file.
        with open('log.txt', 'a') as f:
            f.write('{} - {} - {}\n'.format(epoch, np.mean(d_losses),
                                            np.mean(d_on_g_losses)))

        save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
def train(gen, disc, cGAN, gray, rgb, gray_val, rgb_val, batch):
    """One round of cGAN colorisation training: fit the discriminator on
    real vs generated (gray, rgb) pairs, then fit the combined model with
    the discriminator frozen.

    NOTE(review): relies on a module-level `val_samples` for the
    validation labels — confirm it equals len(rgb_val).
    """
    samples = len(rgb)
    gen_image = gen.predict(gray, batch_size=16)
    gen_image_val = gen.predict(gray_val, batch_size=8)
    # Discriminator inputs are (condition, output) pairs: real pairs
    # labelled 1, generated pairs labelled 0.
    inputs = np.concatenate([gray, gray])
    outputs = np.concatenate([rgb, gen_image])
    y = np.concatenate([np.ones((samples, 1)), np.zeros((samples, 1))])
    disc.fit([inputs, outputs], y, epochs=1, batch_size=4)
    # Freeze the discriminator while the combined model updates the
    # generator, then unfreeze for the next round.
    disc.trainable = False
    cGAN.fit(gray, [np.ones((samples, 1)), rgb], epochs=1, batch_size=batch,
             validation_data=[gray_val,
                              [np.ones((val_samples, 1)), rgb_val]])
    disc.trainable = True


# Build the models and compile: the discriminator alone, then the
# combined cGAN (adversarial + custom reconstruction loss).
gen = generator_model(x_shape, y_shape)
disc = discriminator_model(x_shape, y_shape)
cGAN = cGAN_model(gen, disc)
# cGAN.load_weights('sketchColorisation/result/store/9950.h5')
disc.compile(loss=['binary_crossentropy'],
             optimizer=tf.keras.optimizers.Adam(lr=1E-4, beta_1=0.9,
                                                beta_2=0.999, epsilon=1e-08),
             metrics=['accuracy'])
cGAN.compile(loss=['binary_crossentropy', custom_loss_2],
             loss_weights=[5, 100],
             optimizer=tf.keras.optimizers.Adam(lr=1E-4, beta_1=0.9,
                                                beta_2=0.999, epsilon=1e-08))
tensorboard = tf.keras.callbacks.TensorBoard(log_dir="logs/{}".format(time()))

# Dataset and output directory layout.
dataset = 'sketchColorisation/Images/'
graystore = 'sketchColorisation/grayScale/'
rgbstore = 'sketchColorisation/colored/'
val_data = 'sketchColorisation/validation/'
store = 'sketchColorisation/result/store/'
store2 = 'sketchColorisation/result/store2/'
def train(batch_size, epoch_num):
    """Train the deblur GAN: per batch, update the discriminator on
    real/generated pairs, the stacked adversarial net, and the generator
    directly on the sharp targets.

    NOTE(review): toggling `d.trainable` after compile may not affect the
    already-compiled models without recompiling — confirm for the Keras
    version in use.
    """
    # Note the x(blur) in the second, the y(full) in the first
    y_train, x_train = data_utils.load_data(data_type='train')

    # GAN
    g = generator_model()
    d = discriminator_model()
    d_on_g = generator_containing_discriminator(g, d)

    # compile the models, use default optimizer parameters
    # generator use adversarial loss
    g.compile(optimizer='adam', loss=generator_loss)
    # discriminator use binary cross entropy loss
    d.compile(optimizer='adam', loss='binary_crossentropy')
    # adversarial net use adversarial loss
    d_on_g.compile(optimizer='adam', loss=adversarial_loss)

    for epoch in range(epoch_num):
        print('epoch: ', epoch + 1, '/', epoch_num)
        print('batches: ', int(x_train.shape[0] / batch_size))

        for index in range(int(x_train.shape[0] / batch_size)):
            # select a batch data
            image_blur_batch = x_train[index * batch_size:(index + 1) * batch_size]
            image_full_batch = y_train[index * batch_size:(index + 1) * batch_size]
            generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)

            # output generated images for each 30 iters
            if (index % 30 == 0) and (index != 0):
                data_utils.generate_image(image_full_batch, image_blur_batch,
                                          generated_images, 'result/interim/',
                                          epoch, index)

            # concatenate the full and generated images,
            # the full images at top, the generated images at bottom
            x = np.concatenate((image_full_batch, generated_images))

            # generate labels for the full and generated images
            y = [1] * batch_size + [0] * batch_size

            # train discriminator
            d_loss = d.train_on_batch(x, y)
            print('batch %d d_loss : %f' % (index + 1, d_loss))

            # let discriminator can't be trained
            d.trainable = False

            # train adversarial net (all-ones labels: fool the critic)
            d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [1] * batch_size)
            print('batch %d d_on_g_loss : %f' % (index + 1, d_on_g_loss))

            # train generator directly on the sharp targets
            g_loss = g.train_on_batch(image_blur_batch, image_full_batch)
            print('batch %d g_loss : %f' % (index + 1, g_loss))

            # let discriminator can be trained
            d.trainable = True

            # output weights for generator and discriminator each 30 iters
            if (index % 30 == 0) and (index != 0):
                g.save_weights('weight/generator_weights.h5', True)
                d.save_weights('weight/discriminator_weights.h5', True)
def train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates=5):
    """Train an inpainting GAN: critic updates on real vs generated
    images, then one generator step optimising masked-reconstruction,
    VGG-feature, and adversarial losses. Writes sample triptychs every
    100 epochs and checkpoints late in training.
    """
    g = generator_model()
    d = discriminator_model()
    vgg = build_vgg()
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d, vgg)

    d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    optimizer = Adam(1E-4, 0.5)

    # VGG acts as a frozen feature extractor for the perceptual term.
    vgg.trainable = False
    vgg.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
    d.trainable = True
    d.compile(optimizer=d_opt, loss='binary_crossentropy')
    # Freeze the critic while compiling the stacked model.
    d.trainable = False
    loss = ['mae', 'mse', 'binary_crossentropy']
    loss_weights = [0.1, 100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True

    # Labels: 1 for real batches, 0 for generated ones.
    output_true_batch, output_false_batch = np.ones((batch_size, 1)), np.zeros(
        (batch_size, 1))

    for epoch in range(epoch_num):
        print('epoch: {}/{}'.format(epoch, epoch_num))
        # y: ground truth, x: occluded input, mask: region to reconstruct.
        y_pre, x_pre, mask = load_data(batch_size)
        d_losses = []
        d_on_g_losses = []
        generated_images = g.predict(x=x_pre, batch_size=batch_size)

        # Critic: several updates on real vs generated batches.
        for _ in range(critic_updates):
            d_loss_real = d.train_on_batch(y_pre, output_true_batch)
            d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
            d_losses.append(d_loss)
        print('batch {} d_loss : {}'.format(epoch, np.mean(d_losses)))

        # Generator: one step through the frozen critic; the
        # reconstruction target is the masked ground truth.
        d.trainable = False
        real_result = mask * y_pre
        y_features = vgg.predict(y_pre)
        d_on_g_loss = d_on_g.train_on_batch(
            [x_pre, mask], [real_result, y_features, output_true_batch])
        d_on_g_losses.append(d_on_g_loss)
        print('batch {} d_on_g_loss : {}'.format(epoch, d_on_g_loss))
        d.trainable = True

        # Every 100 epochs, write side-by-side generated/full/input images
        # (rescaled from [-1, 1] back to [0, 255]).
        if epoch % 100 == 0:
            generated = np.array([(img + 1) * 127.5 for img in generated_images])
            full = np.array([(img + 1) * 127.5 for img in y_pre])
            blur = np.array([(img + 1) * 127.5 for img in x_pre])
            for i in range(3):
                img_ge = generated[i, :, :, :]
                img_fu = full[i, :, :, :]
                img_bl = blur[i, :, :, :]
                output = np.concatenate((img_ge, img_fu, img_bl), axis=1)
                cv2.imwrite(
                    '/home/alyssa/PythonProjects/occluded/key_code/img_inpainting/out/'
                    + str(epoch) + '_' + str(i) + '.jpg', output)
        # Checkpoint only late in training, once per 1000 epochs.
        if (epoch > 10000 and epoch % 1000 == 0):
            save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
def map_fn(index=None, flags=None):
    """Per-process training entry point (TPU via torch_xla when
    config.MULTI_CORE, otherwise a single device from config).

    Builds the data loader, the colorization GAN (generator,
    discriminator, frozen pretrained VGG16), optimizers, resumes from the
    latest checkpoint, then trains epoch by epoch with per-epoch
    checkpointing and sample plotting.
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    torch.manual_seed(1234)

    train_data = dataset.DATA(config.TRAIN_DIR)
    if config.MULTI_CORE:
        # Shard the dataset across TPU replicas.
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_data,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=True)
    else:
        train_sampler = torch.utils.data.RandomSampler(train_data)
    train_loader = torch.utils.data.DataLoader(
        train_data,
        batch_size=flags['batch_size'] if config.MULTI_CORE else config.BATCH_SIZE,
        sampler=train_sampler,
        num_workers=flags['num_workers'] if config.MULTI_CORE else 4,
        drop_last=True,
        pin_memory=True)

    if config.MULTI_CORE:
        DEVICE = xm.xla_device()
    else:
        DEVICE = config.DEVICE

    # Models run in double precision.
    netG = model.colorization_model().double()
    netD = model.discriminator_model().double()
    # Pretrained VGG16, frozen, used as a fixed feature extractor.
    VGG_modelF = torchvision.models.vgg16(pretrained=True).double()
    VGG_modelF.requires_grad_(False)
    netG = netG.to(DEVICE)
    netD = netD.to(DEVICE)
    VGG_modelF = VGG_modelF.to(DEVICE)

    optD = torch.optim.Adam(netD.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optG = torch.optim.Adam(netG.parameters(), lr=2e-4, betas=(0.5, 0.999))

    ## Trains
    train_start = time.time()
    losses = {
        'G_losses': [],
        'D_losses': [],
        'EPOCH_G_losses': [],
        'EPOCH_D_losses': [],
        'G_losses_eval': []
    }

    # Resume: epoch_checkpoint is the epoch to continue from.
    netG, optG, netD, optD, epoch_checkpoint = utils.load_checkpoint(
        config.CHECKPOINT_DIR, netG, optG, netD, optD, DEVICE)
    netGAN = model.GAN(netG, netD)
    for epoch in range(
            epoch_checkpoint,
            flags['num_epochs'] + 1 if config.MULTI_CORE else config.NUM_EPOCHS + 1):
        print('\n')
        print('#' * 8, f'EPOCH-{epoch}', '#' * 8)
        # Per-epoch loss accumulators are reset each epoch.
        losses['EPOCH_G_losses'] = []
        losses['EPOCH_D_losses'] = []
        if config.MULTI_CORE:
            # Wrap the loader for this TPU device.
            para_train_loader = pl.ParallelLoader(
                train_loader, [DEVICE]).per_device_loader(DEVICE)
            engine.train(para_train_loader,
                         netGAN,
                         netD,
                         VGG_modelF,
                         optG,
                         optD,
                         device=DEVICE,
                         losses=losses)
            elapsed_train_time = time.time() - train_start
            print("Process", index, "finished training. Train time was:",
                  elapsed_train_time)
        else:
            engine.train(train_loader,
                         netGAN,
                         netD,
                         VGG_modelF,
                         optG,
                         optD,
                         device=DEVICE,
                         losses=losses)
        #########################CHECKPOINTING#################################
        utils.create_checkpoint(epoch,
                                netG,
                                optG,
                                netD,
                                optD,
                                max_checkpoint=config.KEEP_CKPT,
                                save_path=config.CHECKPOINT_DIR)
        ########################################################################
        utils.plot_some(train_data, netG, DEVICE, epoch)
        gc.collect()