def train(self, x, num_iter):
    for i in range(self.num_epochs):
        print("Epoch no :" + str(i + 1) + "/" + str(self.num_epochs))
        for j in tqdm(range(num_iter)):
            x1, y = self.gen_data(x, self.batch_size // 2)
            # train the discriminator
            self.discriminator.train_on_batch(x1, y)
            # Freeze the discriminator to train the GAN model
            utils.make_trainable(self.discriminator, False)
            # train the gan model
            inp = utils.gen_noise(self.batch_size // 2)
            labels = np.zeros((self.batch_size // 2, 1))
            self.gan_model.train_on_batch(inp, labels)
            # make the discriminator params back to trainable for the next iteration
            utils.make_trainable(self.discriminator, True)
        # save the weights and plot the results every 10 epochs
        if i % 10 == 0:
            self.gan_model.save_weights(self.save_path + str(i + 1) + ".h5")
            utils.plot(self.generator)
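# Every snippet in this collection toggles discriminator/generator weights with a
# make_trainable helper (sometimes reached via utils). Its definition is not shown
# in these excerpts; the following is a minimal sketch of the usual Keras pattern,
# given as an assumption rather than the original utils code:
def make_trainable(net, val):
    """Set the trainable flag on a Keras model and on each of its layers."""
    net.trainable = val
    for layer in net.layers:
        layer.trainable = val
    # Note: in Keras the new flags only take effect the next time a model that
    # uses these layers is compiled.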
def __init__(self):
    # Input shape
    self.channels = 3
    self.img_size = 64
    self.latent_dim = 100
    self.time = time()
    self.dataset_name = 'vdsr'
    self.learning_rate = 1e-4
    optimizer = Adam(self.learning_rate, beta_1=0.5, decay=0.00005)

    self.gf = 64  # filter size of generator's last layer
    self.df = 64  # filter size of discriminator's first layer

    # Configure data loader
    self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                  img_res=(self.img_size, self.img_size),
                                  mem_load=True)
    self.n_data = self.data_loader.get_n_data()

    # Build and compile the generator
    self.generator = self.build_generator()
    print("---------------------generator summary----------------------------")
    self.generator.summary()
    self.generator.compile(loss='mse', optimizer=optimizer, metrics=['mse'])

    # Build and compile the discriminator
    self.discriminator = self.build_discriminator()
    print("\n---------------------discriminator summary----------------------------")
    self.discriminator.summary()
    self.discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
    make_trainable(self.discriminator, False)

    z = Input(shape=(self.latent_dim,))
    fake_img = self.generator(z)

    # for the combined model, we only train the generator
    self.discriminator.trainable = False
    validity = self.discriminator(fake_img)

    self.combined = Model([z], [validity])
    print("\n---------------------combined summary----------------------------")
    self.combined.summary()
    self.combined.compile(loss=['binary_crossentropy'], optimizer=optimizer)
def train(self, epochs, batch_size, sample_interval):
    def named_logs(model, logs):
        result = {}
        for l in zip(model.metrics_names, logs):
            result[l[0]] = l[1]
        return result

    start_time = datetime.datetime.now()

    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    max_iter = int(self.n_data / batch_size)
    os.makedirs('./logs/%s' % self.time, exist_ok=True)
    tensorboard = TensorBoard('./logs/%s' % self.time)
    tensorboard.set_model(self.generator)

    os.makedirs('models/%s' % self.time, exist_ok=True)
    with open('models/%s/%s_architecture.json' % (self.time, 'generator'), 'w') as f:
        f.write(self.generator.to_json())

    print("\nbatch size : %d | num_data : %d | max iteration : %d | time : %s \n"
          % (batch_size, self.n_data, max_iter, self.time))

    for epoch in range(1, epochs + 1):
        for iter in range(max_iter):
            # ------------------
            #  Train Discriminator
            # ------------------
            ref_imgs = self.data_loader.load_data(batch_size)
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_imgs = self.generator.predict(noise)

            make_trainable(self.discriminator, True)
            d_loss_real = self.discriminator.train_on_batch(ref_imgs, valid * 0.9)  # label smoothing
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ------------------
            #  Train Generator
            # ------------------
            make_trainable(self.discriminator, False)
            logs = self.combined.train_on_batch([noise], [valid])

            tensorboard.on_epoch_end(iter, named_logs(self.combined, [logs]))

            if iter % (sample_interval // 10) == 0:
                elapsed_time = datetime.datetime.now() - start_time
                print("epoch:%d | iter : %d / %d | time : %10s | g_loss : %15s | d_loss : %s "
                      % (epoch, iter, max_iter, elapsed_time, logs, d_loss))

            if (iter + 1) % sample_interval == 0:
                self.sample_images(epoch, iter + 1)

        # save weights after every epoch
        self.generator.save_weights('models/%s/%s_epoch%d_weights.h5'
                                    % (self.time, 'generator', epoch))
def train(self, data, batch_size=250, num_epochs=25, eval_size=200):
    losses = []
    train, test = train_test_split(data)

    for epoch in range(num_epochs):
        for i in range(len(train) // batch_size):
            # -------------------
            # Train Discriminator
            # -------------------
            make_trainable(self.discriminator, True)

            # Get some real conformations from the train data
            real_confs = train[i * batch_size:(i + 1) * batch_size]
            real_confs = real_confs.reshape(-1, self.n_atoms, 3, 1)

            # Sample high-dimensional noise and generate fake conformations
            noise = make_latent_samples(batch_size, self.noise_dim)
            fake_confs = self.generator.predict_on_batch(noise)

            # Label the conformations accordingly
            real_confs_labels, fake_confs_labels = make_labels(batch_size)

            self.discriminator.train_on_batch(real_confs, real_confs_labels)
            self.discriminator.train_on_batch(fake_confs, fake_confs_labels)

            # --------------------------------------------------
            # Train Generator via GAN (switch off discriminator)
            # --------------------------------------------------
            noise = make_latent_samples(batch_size, self.noise_dim)
            make_trainable(self.discriminator, False)
            g_loss = self.gan.train_on_batch(noise, real_confs_labels)

        # Evaluate performance after each epoch
        conf_eval_real = test[np.random.choice(len(test), eval_size, replace=False)]
        conf_eval_real = conf_eval_real.reshape(-1, self.n_atoms, 3, 1)

        noise = make_latent_samples(eval_size, self.noise_dim)
        conf_eval_fake = self.generator.predict_on_batch(noise)

        eval_real_labels, eval_fake_labels = make_labels(eval_size)

        d_loss_r = self.discriminator.test_on_batch(conf_eval_real, eval_real_labels)
        d_loss_f = self.discriminator.test_on_batch(conf_eval_fake, eval_fake_labels)
        d_loss = (d_loss_r + d_loss_f) / 2

        # we want the fake conformations to look realistic!
        g_loss = self.gan.test_on_batch(noise, eval_real_labels)

        print("Epoch: {:>3}/{} Discriminator Loss: {:>6.4f} Generator Loss: {:>6.4f}"
              .format(epoch + 1, num_epochs, d_loss, g_loss))
        losses.append((d_loss, g_loss))

    return losses
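# The snippet above calls make_latent_samples and make_labels, which are defined
# elsewhere in that project. A minimal sketch of what they plausibly do, inferred
# only from how they are called here (an assumption, not the original helpers):
import numpy as np

def make_latent_samples(n_samples, sample_size):
    """Draw n_samples latent vectors from a standard normal distribution."""
    return np.random.normal(loc=0.0, scale=1.0, size=(n_samples, sample_size))

def make_labels(size):
    """Return (real_labels, fake_labels): ones for real samples, zeros for fakes."""
    return np.ones([size, 1]), np.zeros([size, 1])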
def create_model(**kwargs):
    image_width = kwargs.get('image_width', settings.image_width)
    image_height = kwargs.get('image_height', settings.image_height)
    num_input_channels = kwargs.get('num_input_channels', settings.num_input_channels)
    loss = kwargs.get('loss', settings.loss)
    metrics = kwargs.get('metrics', settings.metrics)
    optimizer = kwargs.get('optimizer', settings.optimizer)
    learning_rate = kwargs.get('learning_rate', settings.learning_rate)

    # For custom loss and metrics functions
    if loss == 'dice_coef_loss':
        loss = dice_coef_loss
    if metrics == ['dice_coef']:
        metrics = [dice_coef]
    if optimizer == 'adam':
        optimizer = Adam(lr=learning_rate, beta_1=kwargs.get('beta_1', settings.beta_1))

    generator = create_generator(**kwargs)
    generator.compile(loss=loss, optimizer=optimizer)

    discriminator = create_discriminator(**kwargs)
    discriminator.compile(loss=loss, optimizer=optimizer)

    # Make discriminator untrainable for stacked training
    utils.make_trainable(discriminator, False)

    GAN_input = Input(shape=(num_input_channels, image_width, image_height), name='GAN_input')
    H = generator(GAN_input)
    GAN_V = discriminator(H)
    GAN = Model(GAN_input, GAN_V)
    GAN.compile(loss=loss, optimizer=optimizer)

    print('~~~~~~~~~~~GAN~~~~~~~~~~~')
    GAN.summary()
    print('~~~~~~~~GENERATOR~~~~~~~~')
    generator.summary()
    print('~~~~~~DISCRIMINATOR~~~~~~')
    discriminator.summary()

    return GAN
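# create_model maps the strings 'dice_coef_loss' and 'dice_coef' to custom
# functions defined elsewhere in that project. A common Keras-backend formulation
# of the soft Dice coefficient, shown only as a sketch of what such functions
# typically look like (not necessarily the project's exact definitions):
from keras import backend as K

def dice_coef(y_true, y_pred, smooth=1.0):
    """Soft Dice coefficient: 2*intersection / (sum(y_true) + sum(y_pred)), smoothed."""
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    """Loss to minimize: the negated Dice coefficient."""
    return -dice_coef(y_true, y_pred)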
test_imgs, test_vessels, test_masks = utils.get_imgs(test_dir,
                                                     augmentation=False,
                                                     img_size=img_size,
                                                     dataset=dataset,
                                                     mask=True)

# create networks
g = generator(img_size, n_filters_g)
if FLAGS.discriminator == 'pixel':
    d, d_out_shape = discriminator_pixel(img_size, n_filters_d, init_lr)
elif FLAGS.discriminator == 'patch1':
    d, d_out_shape = discriminator_patch1(img_size, n_filters_d, init_lr)
elif FLAGS.discriminator == 'patch2':
    d, d_out_shape = discriminator_patch2(img_size, n_filters_d, init_lr)
elif FLAGS.discriminator == 'image':
    d, d_out_shape = discriminator_image(img_size, n_filters_d, init_lr)
else:
    d, d_out_shape = discriminator_dummy(img_size, n_filters_d, init_lr)

utils.make_trainable(d, False)
gan = GAN(g, d, img_size, n_filters_g, n_filters_d, alpha_recip, init_lr)
generator = pretrain_g(g, img_size, n_filters_g, init_lr)
g.summary()
d.summary()
gan.summary()
with open(os.path.join(model_out_dir, "g_{}_{}.json".format(FLAGS.discriminator, FLAGS.ratio_gan2seg)), 'w') as f:
    f.write(g.to_json())

# start training
scheduler = (utils.Scheduler(n_train_imgs // batch_size, n_train_imgs // batch_size, schedules, init_lr)
             if alpha_recip > 0
             else utils.Scheduler(0, n_train_imgs // batch_size, schedules, init_lr))
print("training {} images :".format(n_train_imgs))
for n_round in range(n_rounds):
    # train D
    utils.make_trainable(d, True)
def train_for_n(nb_epoch=5000, BATCH_SIZE=32):
    for e in tqdm(range(nb_epoch)):

        ### Shuffle and Batch the data
        _random = np.random.randint(0, emb_cs.shape[0], size=BATCH_SIZE)
        _random2 = np.random.randint(0, emb_zh.shape[0], size=BATCH_SIZE)
        if not WORD_ONLY:
            pos_seq_cs_batch = pos_seq_cs[_random]
            pos_seq_zh_batch = pos_seq_zh[_random2]
        emb_cs_batch = emb_cs[_random]
        emb_zh_batch = emb_zh[_random2]

        noise_g = np.random.normal(0, 1, size=(BATCH_SIZE, MAX_SEQUENCE_LENGTH, NOISE_SIZE))
        reward_batch = np.zeros((BATCH_SIZE, 1))

        #############################################
        ### Train generator
        #############################################
        for ep in range(1):  # G v.s. D training ratio
            if not WORD_ONLY:
                output_g = generator.predict(
                    [emb_zh_batch, pos_seq_zh_batch, noise_g, reward_batch])
            else:
                output_g = generator.predict([emb_zh_batch, noise_g, reward_batch])

            action_g, action_one_hot_g = get_action(output_g)
            emb_g = translate(emb_zh_batch, action_g)
            text_g = translate_output(emb_zh_batch, action_g)

            # tag POS
            if not WORD_ONLY:
                pos_seq_g = []
                for line in text_g:
                    words = pseg.cut(line)
                    sub_data = []
                    idx = 0
                    for w in words:
                        if w.flag == "x":
                            idx = 0
                        elif idx == 0:
                            sub_data.append(postag[w.flag])
                            idx = 1
                    pos_seq_g.append(sub_data)
                pos_seq_g = pad_sequences(pos_seq_g,
                                          maxlen=MAX_SEQUENCE_LENGTH,
                                          padding='post',
                                          truncating='post',
                                          value=0)

            one_hot_action = action_one_hot_g.reshape(BATCH_SIZE, MAX_SEQUENCE_LENGTH, 2)

            make_trainable(generator, True)
            if not WORD_ONLY:
                reward_batch = discriminator.predict([emb_g, pos_seq_g])[:, 0]
                g_loss = generator.train_on_batch(
                    [emb_zh_batch, pos_seq_zh_batch, noise_g, reward_batch],
                    one_hot_action)
            else:
                reward_batch = discriminator.predict([emb_g])[:, 0]
                g_loss = generator.train_on_batch(
                    [emb_zh_batch, noise_g, reward_batch], one_hot_action)

            losses["g"].append(g_loss)
            write_log(callbacks, log_g, g_loss, len(losses["g"]))
            if g_loss < 0.15:  # early stop
                break

        #############################################
        ### Train discriminator on generated sentence
        #############################################
        X_emb = np.concatenate((emb_cs_batch, emb_g))
        if not WORD_ONLY:
            X_pos = np.concatenate((pos_seq_cs_batch, pos_seq_g))
        y = np.zeros([2 * BATCH_SIZE])
        y[0:BATCH_SIZE] = 0.7 + np.random.random([BATCH_SIZE]) * 0.3
        y[BATCH_SIZE:] = 0 + np.random.random([BATCH_SIZE]) * 0.3

        make_trainable(discriminator, True)
        model.embedding_word.trainable = False
        if not WORD_ONLY:
            model.embedding_pos.trainable = False
        model.g_bi.trainable = False

        for ep in range(1):  # G v.s. D training ratio
            if not WORD_ONLY:
                d_loss = discriminator.train_on_batch([X_emb, X_pos], y)
            else:
                d_loss = discriminator.train_on_batch([X_emb], y)
            losses["d"].append(d_loss)
            write_log(callbacks, log_d, d_loss, len(losses["d"]))
            if d_loss < 0.6:  # early stop
                break

        ### Save model
        generator.save_weights(MODEL_PATH + "gen.mdl")
        discriminator.save_weights(MODEL_PATH + "dis.mdl")
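# write_log is not defined in this excerpt. With the TF1.x-era Keras TensorBoard
# callback these snippets appear to use, a commonly seen sketch looks like the
# following; treat it as an assumption about the helper, not this project's code:
import tensorflow as tf

def write_log(callback, name, value, batch_no):
    """Write one scalar value to the TensorBoard callback's summary writer."""
    summary = tf.Summary()
    summary_value = summary.value.add()
    summary_value.simple_value = value
    summary_value.tag = name
    callback.writer.add_summary(summary, batch_no)
    callback.writer.flush()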
                                value=0)

# Mix real sequences with generated ones
X_pos = np.concatenate((XT_pos, pos_seq_g))
X_emb = np.concatenate((XT_emb, emb_g))

# Smoothed labels: real sequences near 1, generated sequences near 0
n = XT_emb.shape[0]
y = np.zeros([2 * n])
y[:n] = 0.7 + np.random.random([n]) * 0.3
y[n:] = 0 + np.random.random([n]) * 0.3

# Subsample a training set for the discriminator
random_id = np.random.randint(0, X_emb.shape[0], size=BATCH_SIZE * 10)
XX_emb = X_emb[random_id]
y = y[random_id]
if not WORD_ONLY:
    XX_pos = X_pos[random_id]

make_trainable(discriminator, True)
K.set_value(discriminator.optimizer.lr, dopt)
K.set_value(discriminator.optimizer.decay, dopt / 100)

if not WORD_ONLY:
    discriminator.fit([XX_emb, XX_pos], y,
                      epochs=10,
                      batch_size=BATCH_SIZE,
                      validation_split=0.1,
                      callbacks=[earlystopper])
else:
    discriminator.fit([XX_emb], y,
                      epochs=10,
                      batch_size=BATCH_SIZE,
                      validation_split=0.1,
def train(model, training_data, validation_data, **kwargs):
    x_train = training_data[0]
    y_train = training_data[1]
    batch_size = kwargs.get('batch_size', settings.batch_size)

    # model is [GAN, generator, discriminator]
    GAN = model
    generator = GAN.layers[1]
    discriminator = GAN.layers[2]

    # Pre-train the discriminator
    print('pre-training the discriminator')
    num_images = batch_size
    splice = random.sample(range(0, x_train.shape[0]), num_images)
    XT = x_train[splice, :, :, :]
    YT = y_train[splice, :, :, :]

    generated_images = generator.predict(XT)
    y_fake = np.array([0] * num_images).astype('float')
    y_real = np.array([1] * batch_size).astype('float')
    y = np.concatenate((y_fake, y_real))
    x = np.concatenate((generated_images, YT))

    utils.make_trainable(discriminator, True)
    utils.make_trainable(generator, True)
    GAN.fit(x, y, nb_epoch=1, batch_size=batch_size, verbose=1)

    # Train the GAN
    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(x_train.shape[0] / batch_size))
        for index in range(int(x_train.shape[0] / batch_size)):
            x_image_batch = x_train[index * batch_size:(index + 1) * batch_size]
            y_image_batch = y_train[index * batch_size:(index + 1) * batch_size]

            # if conditional GAN:
            #     generated_images = concat(generated_images, y_image_batch)

            # Train discriminator, generator weights are frozen
            generated_images = generator.predict(x_image_batch, verbose=1)
            utils.make_trainable(discriminator, True)
            utils.make_trainable(generator, False)
            y_fake = np.array([0] * batch_size).astype('float')
            y_real = np.array([1] * batch_size).astype('float')
            d_loss = GAN.train_on_batch(generated_images, y_fake)
            d_loss += GAN.train_on_batch(y_image_batch, y_real)

            # Train generator, discriminator weights are frozen
            utils.make_trainable(discriminator, False)
            utils.make_trainable(generator, True)
            # g_loss = GAN.train_on_batch(generated_images, y_fake)
            g_loss = GAN.train_on_batch(x_image_batch, y_real)
            # g_loss /= 2
            # g_loss = generator.train_on_batch(x_image_batch, y_image_batch)

            print("batch %d \t d_loss %f \t g_loss : %f" % (index, d_loss, g_loss))

            # Save weights every 10 batches
            if index % 10 == 9:
                generator.save_weights('generator_weights', True)
                discriminator.save_weights('discriminator_weights', True)

        # Save a generated image every epoch
        image = combine_images(generated_images)
        image = deprocess(image)
        Image.fromarray(image.astype(np.uint8)).save(str(epoch) + "_" + str(index) + ".png")
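# deprocess (and combine_images) are not defined in this excerpt. Assuming the
# generator outputs images scaled to [-1, 1], as is common with a tanh output
# layer, a minimal deprocess sketch would be:
def deprocess(image):
    """Map a [-1, 1] float image back to the [0, 255] range for saving."""
    return (image + 1.0) * 127.5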
def main():
    """
    Creates and trains the GAN.
    """
    # Creates directories
    makedirs()

    # Initializes and trains the generator if necessary
    generator = train_generator()

    # Loads discriminator
    if FLAGS.disc_path is None:
        discriminator = create_discriminator((15, 26, 512))
    else:
        discriminator = load_model(FLAGS.disc_path)

    # Compiles discriminator
    discriminator.compile(optimizer="adam",
                          loss="binary_crossentropy",
                          metrics=["acc"])

    # Sets initial discriminator trainability to False
    make_trainable(discriminator, False)

    # Loads VGG
    vgg = keras.applications.VGG16(include_top=False)

    # VGG should never be trained
    vgg.trainable = False

    # Creates GAN Model
    input_layer = Input(shape=(240, 426, 3))
    out_gen = generator(input_layer)
    out_vgg = vgg(out_gen)
    out_disc = discriminator(out_vgg)

    # Compiles the GAN
    # Notice the loss_weights argument: this can be changed to achieve different
    # results. In general, binary cross-entropy controls resolution accuracy
    # and RMSE controls color accuracy.
    model = Model(inputs=input_layer, outputs=[out_disc, out_gen])
    model.compile(optimizer="adam",
                  loss=["binary_crossentropy", root_mean_squared_error],
                  loss_weights=[0.9, 0.1],
                  metrics=["acc"])

    # Gets filepaths of training set
    files = os.listdir(FLAGS.x_dir)

    # Set to False to train the discriminator first
    train_gen = False

    # Trains the model
    for epoch in range(FLAGS.start_epoch, FLAGS.epochs):
        # Shuffles the training set and divides it into batches
        np.random.shuffle(files)
        batches = chunks(files, FLAGS.batch_size)

        # Trains a batch
        for batch in batches:
            # Creates batch tensors
            x_train, y_train = create_batch_tensors(batch)

            # Generator's turn to train
            if train_gen:
                print("Training generator: epoch {0}".format(epoch))

                # Set discriminator trainability to False
                make_trainable(discriminator, False)

                # Trains the generator
                metrics = model.fit(x_train, [np.ones([len(x_train)]), y_train])

                # If the generator is good enough, switch to the discriminator
                if metrics.history["generator_acc"][0] > .9:
                    train_gen = False

            # Discriminator's turn to train
            else:
                print("Training discriminator: epoch {0}".format(epoch))

                # Set discriminator trainability to True
                make_trainable(discriminator, True)

                # Get generated data from inputs
                gen_input = x_train
                gen_output = generator.predict(gen_input)

                # Combines x data with their corresponding y labels
                disc_input = np.concatenate([gen_output, y_train])
                ground_truth = np.concatenate(
                    [np.zeros([len(gen_output)]), np.ones([len(y_train)])])

                # Shuffles generated images and real images
                disc_input, ground_truth = shuffle(disc_input, ground_truth)

                # Run VGG embeddings on generated images
                vgg_output = vgg.predict(disc_input, batch_size=1)

                # Trains the discriminator
                metrics = discriminator.fit(vgg_output, ground_truth)

                # If the discriminator is good enough, switch to the generator
                if metrics.history["acc"][0] > .9:
                    train_gen = True

        print("\n \n \n \n COMPLETED EPOCH {0} \n \n \n \n".format(epoch))

        # Get samples for visual verification
        out = generator.predict(x_train)
        base_filepath = FLAGS.out_dir + "/samples/gen/epoch_{0}".format(epoch)

        # Save samples
        for i, sample in enumerate(out):
            resized_train = cv2.resize(x_train[i], (0, 0),
                                       fx=2, fy=2,
                                       interpolation=cv2.INTER_CUBIC)
            cv2.imwrite(base_filepath + "_img_{0}_input.png".format(i), resized_train)
            cv2.imwrite(base_filepath + "_img_{0}_pred.png".format(i), sample)
            cv2.imwrite(base_filepath + "_img_{0}_true.png".format(i), y_train[i])

        # Save generator and discriminator weights
        savepath = FLAGS.out_dir + "/weights"
        generator.save(savepath + "/gen/epoch_{0}.h5py".format(epoch))
        make_trainable(discriminator, True)
        discriminator.save(savepath + "/disc/epoch_{0}.h5py".format(epoch))
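# root_mean_squared_error, used as the second loss above, is a custom function
# defined elsewhere in that project. A minimal Keras-backend sketch of an RMSE
# loss (an assumption, not necessarily the project's exact definition):
from keras import backend as K

def root_mean_squared_error(y_true, y_pred):
    """Root-mean-squared error between prediction and target."""
    return K.sqrt(K.mean(K.square(y_pred - y_true)))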