import imageio
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras.backend as K
import tensorflow.keras.optimizers as O
from tqdm import tqdm

# NOTE: project-local helpers (expman / Experiment, get_train_data, get_test_data,
# make_generator, make_encoder, make_discriminator, get_discriminator_features_model,
# train_step, evaluate, VideoSaver, objects) are assumed to be imported from the
# package's own modules.


def main(args):
    exp = expman.from_dir(args.run)
    params = exp.params
    batch_size = args.batch_size if args.batch_size else params.batch_size
    is_object = params.category in objects

    # get data
    test_dataset, test_labels = get_test_data(params.category,
                                              image_size=params.image_size,
                                              patch_size=params.patch_size,
                                              batch_size=batch_size)

    # build models
    generator = make_generator(params.latent_size,
                               channels=params.channels,
                               upsample_first=is_object,
                               upsample_type=params.ge_up,
                               bn=params.ge_bn,
                               act=params.ge_act)
    encoder = make_encoder(params.patch_size,
                           params.latent_size,
                           channels=params.channels,
                           bn=params.ge_bn,
                           act=params.ge_act)
    discriminator = make_discriminator(params.patch_size,
                                       params.latent_size,
                                       channels=params.channels,
                                       bn=params.d_bn,
                                       act=params.d_act)

    # checkpointer: restore only the models, not the optimizers
    checkpoint = tf.train.Checkpoint(generator=generator,
                                     encoder=encoder,
                                     discriminator=discriminator)
    ckpt_suffix = 'best' if args.best else 'last'
    ckpt_path = exp.path_to(f'ckpt/ckpt_{params.category}_{ckpt_suffix}')
    checkpoint.read(ckpt_path).expect_partial()

    discriminator_features = get_discriminator_features_model(discriminator)

    # evaluate for each value of lambda and report the results as a table
    auc, balanced_accuracy = evaluate(generator, encoder, discriminator_features,
                                      test_dataset, test_labels,
                                      patch_size=params.patch_size,
                                      lambda_=args.lambda_)

    index = pd.Index(args.lambda_, name='lambda')
    table = pd.DataFrame({'auc': auc, 'balanced_accuracy': balanced_accuracy}, index=index)
    print(table)
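# A minimal sketch of a CLI entry point that would drive the evaluation `main`
# above. The flag names and defaults are assumptions inferred from the `args.*`
# attributes it reads; the project's actual argument parser may differ.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Evaluate a trained run')
    parser.add_argument('run', help='path to the experiment run directory')
    parser.add_argument('-b', '--batch-size', type=int, default=None,
                        help='override the batch size stored in the run params')
    parser.add_argument('--best', action='store_true',
                        help='load the best checkpoint instead of the last one')
    parser.add_argument('-l', '--lambda_', type=float, nargs='+', default=(0.1,),
                        help='weight(s) of the reconstruction term in the anomaly score')
    main(parser.parse_args())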
def main(args):
    # do not track the lambda param: it can be changed after training
    exp = Experiment(args, ignore=('lambda_',))
    print(exp)
    if exp.found:
        print('Already exists: SKIPPING')
        exit(0)

    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    # get data
    train_dataset = get_train_data(args.category,
                                   image_size=args.image_size,
                                   patch_size=args.patch_size,
                                   batch_size=args.batch_size,
                                   n_batches=args.n_batches,
                                   rotation_range=args.rotation_range,
                                   seed=args.seed)
    test_dataset, test_labels = get_test_data(args.category,
                                              image_size=args.image_size,
                                              patch_size=args.patch_size,
                                              batch_size=args.batch_size)
    is_object = args.category in objects

    # build models
    generator = make_generator(args.latent_size,
                               channels=args.channels,
                               upsample_first=is_object,
                               upsample_type=args.ge_up,
                               bn=args.ge_bn,
                               act=args.ge_act)
    encoder = make_encoder(args.patch_size,
                           args.latent_size,
                           channels=args.channels,
                           bn=args.ge_bn,
                           act=args.ge_act)
    discriminator = make_discriminator(args.patch_size,
                                       args.latent_size,
                                       channels=args.channels,
                                       bn=args.d_bn,
                                       act=args.d_act)

    # feature extractor model for evaluation
    discriminator_features = get_discriminator_features_model(discriminator)

    # build optimizers
    generator_encoder_optimizer = O.Adam(args.lr, beta_1=args.ge_beta1, beta_2=args.ge_beta2)
    discriminator_optimizer = O.Adam(args.lr, beta_1=args.d_beta1, beta_2=args.d_beta2)

    # references to the models used in eval
    generator_eval = generator
    encoder_eval = encoder

    # exponential moving average for smoothing the generator and encoder evolution
    if args.ge_decay > 0:
        ema = tf.train.ExponentialMovingAverage(decay=args.ge_decay)
        generator_ema = tf.keras.models.clone_model(generator)
        encoder_ema = tf.keras.models.clone_model(encoder)
        generator_eval = generator_ema
        encoder_eval = encoder_ema

    # checkpointer
    checkpoint = tf.train.Checkpoint(
        generator=generator,
        encoder=encoder,
        discriminator=discriminator,
        generator_encoder_optimizer=generator_encoder_optimizer,
        discriminator_optimizer=discriminator_optimizer)
    best_ckpt_path = exp.ckpt(f'ckpt_{args.category}_best')
    last_ckpt_path = exp.ckpt(f'ckpt_{args.category}_last')

    # log files
    log, log_file = exp.require_csv(f'log_{args.category}.csv.gz')
    metrics, metrics_file = exp.require_csv(f'metrics_{args.category}.csv')

    best_metric = 0.
    best_recon = float('inf')
    best_recon_file = exp.path_to(f'best_recon_{args.category}.png')
    last_recon_file = exp.path_to(f'last_recon_{args.category}.png')

    # animate generation during training
    n_preview = 6
    train_batch = next(iter(train_dataset))[:n_preview]
    test_batch = next(iter(test_dataset))[0][:n_preview]
    latent_batch = tf.random.normal([n_preview, args.latent_size])

    if not is_object:  # take random patches from test images
        patch_location = np.random.randint(0, args.image_size - args.patch_size, (n_preview, 2))
        test_batch = [x[i:i + args.patch_size, j:j + args.patch_size, :]
                      for x, (i, j) in zip(test_batch, patch_location)]
        test_batch = K.stack(test_batch)

    video_out = exp.path_to(f'{args.category}.mp4')
    video_options = dict(fps=30, codec='libx265', quality=4)  # see imageio FFMPEG options
    video_saver = VideoSaver(train_batch, test_batch, latent_batch, video_out, **video_options)
    video_saver.generate_and_save(generator, encoder)

    # train loop
    progress = tqdm(train_dataset, desc=args.category, dynamic_ncols=True)
    try:
        for step, image_batch in enumerate(progress, start=1):
            if step == 1 or args.d_iter == 0:
                # train both on the first step so tf.function traces both branches
                d_train = True
                ge_train = True
            elif args.d_iter:
                n_iter = step % (abs(args.d_iter) + 1)  # in [0, d_iter]
                # d_iter > 0: train D for d_iter steps, then G+E for one step;
                # d_iter < 0: the other way around
                d_train = (n_iter != 0) if (args.d_iter > 0) else (n_iter == 0)
                ge_train = not d_train
            else:  # d_iter is None: adjust dynamically using the previous step's scores
                d_train = (scores['fake_score'] > 0) or (scores['real_score'] < 0)
                ge_train = (scores['real_score'] > 0) or (scores['fake_score'] < 0)

            losses, scores = train_step(image_batch, generator, encoder, discriminator,
                                        generator_encoder_optimizer, discriminator_optimizer,
                                        d_train, ge_train,
                                        alpha=args.alpha, gp_weight=args.gp_weight)

            if (args.ge_decay > 0) and (step % 10 == 0):
                ge_vars = generator.variables + encoder.variables
                ema.apply(ge_vars)  # update exponential moving average

            # tensor to numpy
            losses = {n: l.numpy() if l is not None else l for n, l in losses.items()}
            scores = {n: s.numpy() if s is not None else s for n, s in scores.items()}

            # log step metrics (DataFrame.append was removed in pandas 2.x, use concat)
            entry = {'step': step, 'timestamp': pd.to_datetime('now'), **losses, **scores}
            log = pd.concat([log, pd.DataFrame([entry])], ignore_index=True)

            if step % 100 == 0:
                if args.ge_decay > 0:
                    # copy the EMA weights into the eval copies of G and E
                    ge_ema_vars = generator_ema.variables + encoder_ema.variables
                    for v_ema, v in zip(ge_ema_vars, ge_vars):
                        v_ema.assign(ema.average(v))

                preview = video_saver.generate_and_save(generator_eval, encoder_eval)

            if step % 1000 == 0:
                log.to_csv(log_file, index=False)
                checkpoint.write(file_prefix=last_ckpt_path)

                auc, balanced_accuracy = evaluate(generator_eval, encoder_eval,
                                                  discriminator_features,
                                                  test_dataset, test_labels,
                                                  patch_size=args.patch_size,
                                                  lambda_=args.lambda_)
                entry = {'step': step, 'auc': auc, 'balanced_accuracy': balanced_accuracy}
                metrics = pd.concat([metrics, pd.DataFrame([entry])], ignore_index=True)
                metrics.to_csv(metrics_file, index=False)

                if auc > best_metric:
                    best_metric = auc
                    checkpoint.write(file_prefix=best_ckpt_path)

                # save the last preview image to inspect it during training
                imageio.imwrite(last_recon_file, preview)

                recon = losses['images_reconstruction_loss']
                if recon < best_recon:
                    best_recon = recon
                    imageio.imwrite(best_recon_file, preview)

                progress.set_postfix({
                    'AUC': f'{auc:.1%}',
                    'BalAcc': f'{balanced_accuracy:.1%}',
                    'BestAUC': f'{best_metric:.1%}',
                })

    except KeyboardInterrupt:
        checkpoint.write(file_prefix=last_ckpt_path)
    finally:
        log.to_csv(log_file, index=False)
        video_saver.close()

    # score the test set with the best checkpoint
    checkpoint.read(best_ckpt_path)
    auc, balanced_accuracy = evaluate(generator, encoder, discriminator_features,
                                      test_dataset, test_labels,
                                      patch_size=args.patch_size,
                                      lambda_=args.lambda_)
    print(f'{args.category}: AUC={auc}, BalAcc={balanced_accuracy}')
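# A minimal sketch of the `train_step` used above, reconstructed from its call
# site: a BiGAN-style step in which the critic scores (image, latent) pairs,
# trained with Wasserstein losses, a gradient penalty weighted by `gp_weight`,
# and an `alpha`-weighted image reconstruction term. All names and loss details
# below are assumptions, not the project's actual implementation.
@tf.function
def train_step(images, generator, encoder, discriminator,
               ge_optimizer, d_optimizer, d_train, ge_train,
               alpha=1., gp_weight=10.):
    batch_size = tf.shape(images)[0]
    latent_size = encoder.output_shape[-1]
    z = tf.random.normal([batch_size, latent_size])

    with tf.GradientTape(persistent=True) as tape:
        fake_images = generator(z, training=True)
        e_z = encoder(images, training=True)

        # the critic scores joint (image, latent) pairs as in BiGAN/ALI
        real_score = discriminator([images, e_z], training=True)
        fake_score = discriminator([fake_images, z], training=True)

        # cycle consistency: encode then decode the real images
        reconstructions = generator(e_z, training=True)
        images_reconstruction_loss = tf.reduce_mean(tf.abs(images - reconstructions))

        # WGAN-GP penalty on interpolated images
        eps = tf.random.uniform([batch_size, 1, 1, 1], 0., 1.)
        interpolated = eps * images + (1. - eps) * fake_images
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            interp_score = discriminator([interpolated, z], training=True)
        grads = gp_tape.gradient(interp_score, interpolated)
        grad_norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        gradient_penalty = tf.reduce_mean(tf.square(grad_norm - 1.))

        discriminator_loss = (tf.reduce_mean(fake_score) - tf.reduce_mean(real_score)
                              + gp_weight * gradient_penalty)
        generator_encoder_loss = (tf.reduce_mean(real_score) - tf.reduce_mean(fake_score)
                                  + alpha * images_reconstruction_loss)

    if d_train:
        d_grads = tape.gradient(discriminator_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_variables))
    if ge_train:
        ge_vars = generator.trainable_variables + encoder.trainable_variables
        ge_grads = tape.gradient(generator_encoder_loss, ge_vars)
        ge_optimizer.apply_gradients(zip(ge_grads, ge_vars))
    del tape  # the persistent tape must be released explicitly

    losses = {'generator_encoder_loss': generator_encoder_loss,
              'discriminator_loss': discriminator_loss,
              'images_reconstruction_loss': images_reconstruction_loss}
    scores = {'real_score': tf.reduce_mean(real_score),
              'fake_score': tf.reduce_mean(fake_score)}
    return losses, scores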
import time

import numpy as np
from keras.layers import Input
from keras.models import Model
from keras.optimizers import RMSprop

import model  # project-local module providing model builders, data loading, and losses


def train_wgan(batch_size, epochs, image_shape):
    # encoder: image -> latent code
    enc_model_1 = model.make_encoder()
    img = Input(shape=image_shape)
    z = enc_model_1(img)
    encoder1 = Model(img, z)

    # generator: latent code -> image (latent_dim is assumed to be a module-level global)
    z = Input(shape=(latent_dim,))
    modelG = model.construct_generator()
    gen_img = modelG(z)
    generator = Model(z, gen_img)

    # critic: frozen inside the combined model, trained separately below
    critic = model.construct_critic(image_shape)
    critic.trainable = False

    # combined model: reconstruct the input and fool the critic
    img = Input(shape=image_shape)
    z = encoder1(img)
    img_ = generator(z)
    real = critic(img_)

    optimizer = RMSprop(0.0002)
    gan = Model(img, [real, img_])
    gan.compile(loss=[model.wasserstein_loss, 'mean_absolute_error'],
                optimizer=optimizer,
                metrics=None)

    X_train = model.load_data(168, 224)
    number_of_batches = int(X_train.shape[0] / batch_size)
    generator_iterations = 0
    d_loss = 0

    for epoch in range(epochs):
        current_batch = 0
        while current_batch < number_of_batches:
            start_time = time.time()

            # For the first 25 generator iterations (and every 500th iteration
            # after that), the critic is updated 100 times per generator update;
            # otherwise the default is 5 critic updates.
            if generator_iterations < 25 or (generator_iterations + 1) % 500 == 0:
                critic_iterations = 100
            else:
                critic_iterations = 5

            # update the critic for a number of iterations
            for critic_iteration in range(critic_iterations):
                if current_batch > number_of_batches:
                    break

                # sample a random batch of real images
                it_index = np.random.randint(0, number_of_batches - 1)
                real_images = X_train[it_index * batch_size:(it_index + 1) * batch_size]
                current_batch += 1

                # the last batch may be smaller than the others
                current_batch_size = real_images.shape[0]

                # generate reconstructions from the encoded real images
                z = encoder1.predict(real_images)
                generated_images = generator.predict(z)

                # labels fed to the critic: +1 for real, -1 for generated
                real_y = np.ones(current_batch_size)
                fake_y = np.ones(current_batch_size) * -1

                # train the critic
                critic.trainable = True

                # clip the weights to small numbers near zero (vanilla WGAN)
                for layer in critic.layers:
                    weights = layer.get_weights()
                    weights = [np.clip(w, -0.01, 0.01) for w in weights]
                    layer.set_weights(weights)

                d_real = critic.train_on_batch(real_images, real_y)
                d_fake = critic.train_on_batch(generated_images, fake_y)
                d_loss = d_real - d_fake

            # update the generator (and encoder) through the combined model
            critic.trainable = False
            itt_index = np.random.randint(0, number_of_batches - 1)
            imgs = X_train[itt_index * batch_size:(itt_index + 1) * batch_size]

            # mislead the critic by labeling the reconstructions as real
            fake_yy = np.ones(current_batch_size)
            g_loss = gan.train_on_batch(imgs, [fake_yy, imgs])

            time_elapsed = time.time() - start_time
            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_G_imgs: %f -> %f s'
                  % (epoch, epochs, current_batch, number_of_batches,
                     generator_iterations, d_loss, g_loss[0], g_loss[1], time_elapsed))
            generator_iterations += 1
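# For reference, a minimal sketch of the `model.wasserstein_loss` referenced in
# the compile call above, assuming it follows the usual Keras convention for
# WGANs: with labels +1 (real) and -1 (fake), the loss is the mean of
# label * score. This is an assumption, not the project's actual definition.
import keras.backend as K


def wasserstein_loss(y_true, y_pred):
    # minimizing mean(y * score) pushes scores down for +1 labels and up for
    # -1 labels; only the difference between the two expectations matters
    return K.mean(y_true * y_pred)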