Пример #1
0
def main(args):
    """Evaluate a previously trained run and print its test metrics.

    Reads from ``args``:
        run: directory of the experiment to load (passed to ``expman.from_dir``).
        batch_size: optional override; when falsy the run's stored batch size
            is used.
        best: when truthy the ``best`` checkpoint is loaded, otherwise ``last``.
        lambda_: forwarded to ``evaluate`` and used as the index of the
            printed result table.

    All model hyper-parameters come from the parameters stored with the run.
    """
    run = expman.from_dir(args.run)
    hp = run.params

    # a batch size given by the caller wins over the one stored with the run
    eval_batch_size = args.batch_size or hp.batch_size
    upsample_first = hp.category in objects

    # test data
    test_dataset, test_labels = get_test_data(
        hp.category,
        image_size=hp.image_size,
        patch_size=hp.patch_size,
        batch_size=eval_batch_size,
    )

    # generator and encoder are built with the same channels / bn / activation
    ge_common = dict(channels=hp.channels, bn=hp.ge_bn, act=hp.ge_act)

    generator = make_generator(
        hp.latent_size,
        upsample_first=upsample_first,
        upsample_type=hp.ge_up,
        **ge_common,
    )
    encoder = make_encoder(hp.patch_size, hp.latent_size, **ge_common)
    discriminator = make_discriminator(
        hp.patch_size,
        hp.latent_size,
        channels=hp.channels,
        bn=hp.d_bn,
        act=hp.d_act,
    )

    # restore the three networks; only these objects are registered here, so
    # any other value stored in the checkpoint file is deliberately ignored
    checkpoint = tf.train.Checkpoint(
        generator=generator,
        encoder=encoder,
        discriminator=discriminator,
    )
    which = 'best' if args.best else 'last'
    ckpt_path = run.path_to(f'ckpt/ckpt_{hp.category}_{which}')
    restore_status = checkpoint.read(ckpt_path)
    restore_status.expect_partial()

    # score the test set
    feature_extractor = get_discriminator_features_model(discriminator)
    auc, balanced_accuracy = evaluate(
        generator,
        encoder,
        feature_extractor,
        test_dataset,
        test_labels,
        patch_size=hp.patch_size,
        lambda_=args.lambda_,
    )

    # one row per lambda value
    lambda_index = pd.Index(args.lambda_, name='lambda')
    results = pd.DataFrame(
        {'auc': auc, 'balanced_accuracy': balanced_accuracy},
        index=lambda_index,
    )
    print(results)
Пример #2
0
def _flush_log(log, pending, log_file):
    """Merge buffered step rows into ``log``, write it to ``log_file``, return it.

    ``pending`` is emptied in place once its rows have been merged.
    """
    if pending:
        log = pd.concat([log, pd.DataFrame(pending)], ignore_index=True)
        pending.clear()
    log.to_csv(log_file, index=False)
    return log


def main(args):
    """Train generator, encoder and discriminator for one category and score them.

    The function
      * registers the run as an ``Experiment`` (skipping it when it already
        exists) and seeds NumPy and TensorFlow with ``args.seed``;
      * builds the data pipelines, the three networks and two Adam optimizers;
      * optionally keeps exponential-moving-average copies of generator and
        encoder (``args.ge_decay > 0``) that are used for previews/evaluation;
      * runs the training loop, logging every step, rendering a preview frame
        every 100 steps and checkpointing + evaluating every 1000 steps;
      * finally evaluates the test set and prints AUC and balanced accuracy.

    ``args.d_iter`` selects the update schedule: ``0`` trains both sides every
    step, a positive value trains the discriminator ``d_iter`` steps for each
    generator/encoder step, a negative value inverts that ratio, and ``None``
    decides per step from the previous step's real/fake scores.
    """
    # do not track lambda param, it can be changed after train
    exp = Experiment(args, ignore=('lambda_', ))
    print(exp)

    if exp.found:
        print('Already exists: SKIPPING')
        # equivalent to exit(0) but does not rely on the `site` helper
        raise SystemExit(0)

    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    # get data
    train_dataset = get_train_data(args.category,
                                   image_size=args.image_size,
                                   patch_size=args.patch_size,
                                   batch_size=args.batch_size,
                                   n_batches=args.n_batches,
                                   rotation_range=args.rotation_range,
                                   seed=args.seed)

    test_dataset, test_labels = get_test_data(args.category,
                                              image_size=args.image_size,
                                              patch_size=args.patch_size,
                                              batch_size=args.batch_size)

    is_object = args.category in objects

    # build models
    generator = make_generator(args.latent_size,
                               channels=args.channels,
                               upsample_first=is_object,
                               upsample_type=args.ge_up,
                               bn=args.ge_bn,
                               act=args.ge_act)
    encoder = make_encoder(args.patch_size,
                           args.latent_size,
                           channels=args.channels,
                           bn=args.ge_bn,
                           act=args.ge_act)
    discriminator = make_discriminator(args.patch_size,
                                       args.latent_size,
                                       channels=args.channels,
                                       bn=args.d_bn,
                                       act=args.d_act)
    # feature extractor model for evaluation
    discriminator_features = get_discriminator_features_model(discriminator)

    # build optimizers
    generator_encoder_optimizer = O.Adam(args.lr,
                                         beta_1=args.ge_beta1,
                                         beta_2=args.ge_beta2)
    discriminator_optimizer = O.Adam(args.lr,
                                     beta_1=args.d_beta1,
                                     beta_2=args.d_beta2)

    # reference to the models to use in eval
    generator_eval = generator
    encoder_eval = encoder

    # for smoothing generator and encoder evolution
    if args.ge_decay > 0:
        ema = tf.train.ExponentialMovingAverage(decay=args.ge_decay)
        generator_ema = tf.keras.models.clone_model(generator)
        encoder_ema = tf.keras.models.clone_model(encoder)

        generator_eval = generator_ema
        encoder_eval = encoder_ema

    # checkpointer
    checkpoint = tf.train.Checkpoint(
        generator=generator,
        encoder=encoder,
        discriminator=discriminator,
        generator_encoder_optimizer=generator_encoder_optimizer,
        discriminator_optimizer=discriminator_optimizer)
    best_ckpt_path = exp.ckpt(f'ckpt_{args.category}_best')
    last_ckpt_path = exp.ckpt(f'ckpt_{args.category}_last')
    # set once this run has actually written a "best" checkpoint
    best_ckpt_written = False

    # log stuff
    log, log_file = exp.require_csv(f'log_{args.category}.csv.gz')
    metrics, metrics_file = exp.require_csv(f'metrics_{args.category}.csv')
    # per-step rows are buffered here and merged into `log` only when it is
    # written out; growing the DataFrame row by row copies it on every step
    pending_log = []
    best_metric = 0.
    best_recon = float('inf')
    best_recon_file = exp.path_to(f'best_recon_{args.category}.png')
    last_recon_file = exp.path_to(f'last_recon_{args.category}.png')

    # animate generation during training
    n_preview = 6
    train_batch = next(iter(train_dataset))[:n_preview]
    test_batch = next(iter(test_dataset))[0][:n_preview]
    latent_batch = tf.random.normal([n_preview, args.latent_size])

    if not is_object:  # take random patches from test images
        # randint's upper bound is exclusive: +1 makes the last offset that
        # still fits a full patch reachable and keeps the call valid when
        # image_size == patch_size
        # (assumes test images are image_size x image_size -- TODO confirm)
        patch_location = np.random.randint(
            0, args.image_size - args.patch_size + 1, (n_preview, 2))
        test_batch = [
            x[i:i + args.patch_size, j:j + args.patch_size, :]
            for x, (i, j) in zip(test_batch, patch_location)
        ]
        test_batch = K.stack(test_batch)

    video_out = exp.path_to(f'{args.category}.mp4')
    video_options = dict(fps=30, codec='libx265',
                         quality=4)  # see imageio FFMPEG options
    video_saver = VideoSaver(train_batch, test_batch, latent_batch, video_out,
                             **video_options)
    video_saver.generate_and_save(generator, encoder)

    # train loop
    progress = tqdm(train_dataset, desc=args.category, dynamic_ncols=True)
    try:
        for step, image_batch in enumerate(progress, start=1):
            if step == 1 or args.d_iter == 0:  # only for JIT compilation (tf.function) to work
                d_train = True
                ge_train = True
            elif args.d_iter:
                n_iter = step % (abs(args.d_iter) + 1)  # can be in [0, |d_iter|]
                # d_iter > 0: discriminator trains whenever n_iter != 0;
                # d_iter < 0: discriminator trains only when n_iter == 0
                d_train = (n_iter != 0) if (args.d_iter > 0) else (
                    n_iter == 0)
                ge_train = not d_train  # exactly one side trains per step
            else:  # d_iter == None: dynamic adjustment
                # `scores` comes from the previous step (step 1 never gets here)
                d_train = (scores['fake_score'] > 0) or (scores['real_score'] <
                                                         0)
                ge_train = (scores['real_score'] > 0) or (scores['fake_score']
                                                          < 0)

            losses, scores = train_step(image_batch,
                                        generator,
                                        encoder,
                                        discriminator,
                                        generator_encoder_optimizer,
                                        discriminator_optimizer,
                                        d_train,
                                        ge_train,
                                        alpha=args.alpha,
                                        gp_weight=args.gp_weight)

            if (args.ge_decay > 0) and (step % 10 == 0):
                ge_vars = generator.variables + encoder.variables
                ema.apply(ge_vars)  # update exponential moving average

            # tensor to numpy
            losses = {
                n: l.numpy() if l is not None else l
                for n, l in losses.items()
            }
            scores = {
                n: s.numpy() if s is not None else s
                for n, s in scores.items()
            }

            # log step metrics
            pending_log.append({
                'step': step,
                'timestamp': pd.to_datetime('now'),
                **losses,
                **scores
            })

            if step % 100 == 0:
                if args.ge_decay > 0:
                    # copy the averaged weights into the EMA models; `ge_vars`
                    # is set because every multiple of 100 is a multiple of 10
                    ge_ema_vars = generator_ema.variables + encoder_ema.variables
                    for v_ema, v in zip(ge_ema_vars, ge_vars):
                        v_ema.assign(ema.average(v))

                preview = video_saver.generate_and_save(
                    generator_eval, encoder_eval)

            if step % 1000 == 0:
                log = _flush_log(log, pending_log, log_file)
                checkpoint.write(file_prefix=last_ckpt_path)

                auc, balanced_accuracy = evaluate(generator_eval,
                                                  encoder_eval,
                                                  discriminator_features,
                                                  test_dataset,
                                                  test_labels,
                                                  patch_size=args.patch_size,
                                                  lambda_=args.lambda_)

                entry = {
                    'step': step,
                    'auc': auc,
                    'balanced_accuracy': balanced_accuracy
                }
                metrics = pd.concat([metrics, pd.DataFrame([entry])],
                                    ignore_index=True)
                metrics.to_csv(metrics_file, index=False)

                if auc > best_metric:
                    best_metric = auc
                    checkpoint.write(file_prefix=best_ckpt_path)
                    best_ckpt_written = True

                # save last image to inspect it during training
                # (`preview` was refreshed by the step % 100 branch above)
                imageio.imwrite(last_recon_file, preview)

                recon = losses['images_reconstruction_loss']
                if recon < best_recon:
                    best_recon = recon
                    imageio.imwrite(best_recon_file, preview)

                progress.set_postfix({
                    'AUC': f'{auc:.1%}',
                    'BalAcc': f'{balanced_accuracy:.1%}',
                    'BestAUC': f'{best_metric:.1%}',
                })

    except KeyboardInterrupt:
        # keep the current weights so an interrupted run is not lost
        checkpoint.write(file_prefix=last_ckpt_path)
    finally:
        try:
            log = _flush_log(log, pending_log, log_file)
        finally:
            # release the video writer even if writing the log failed
            video_saver.close()

    # score the test set
    if best_ckpt_written:
        checkpoint.read(best_ckpt_path)
    # otherwise no "best" checkpoint was written by this run (stopped before
    # the first evaluation): score the weights currently in memory instead of
    # failing on a missing checkpoint file

    auc, balanced_accuracy = evaluate(generator,
                                      encoder,
                                      discriminator_features,
                                      test_dataset,
                                      test_labels,
                                      patch_size=args.patch_size,
                                      lambda_=args.lambda_)
    print(f'{args.category}: AUC={auc}, BalAcc={balanced_accuracy}')
Пример #3
0
def train_wgan(batch_size, epochs, image_shape):
    """Train an encoder -> generator pipeline against a weight-clipped critic.

    Args:
        batch_size: number of images per training batch.
        epochs: number of passes of the outer loop.
        image_shape: shape handed to ``model.construct_critic``.

    The combined model maps an image through encoder and generator and is
    trained with two losses: ``model.wasserstein_loss`` on the critic output
    and mean absolute error between the input image and its reconstruction.
    Progress is printed after every generator update; nothing is returned.

    NOTE(review): ``input_shape`` and ``latent_dim`` are not defined in this
    function; presumably they are module-level names -- confirm, and check
    whether ``input_shape`` is meant to be the ``image_shape`` argument.
    """
    enc_model_1 = model.make_encoder()
    img = Input(shape=input_shape)
    z = enc_model_1(img)
    encoder1 = Model(img, z)

    z = Input(shape=(latent_dim,))
    modelG = model.construct_generator()
    gen_img = modelG(z)
    generator = Model(z, gen_img)
    critic = model.construct_critic(image_shape)

    # combined model image -> encoder -> generator -> critic, critic frozen
    critic.trainable = False
    img = Input(shape=input_shape)
    z = encoder1(img)

    img_ = generator(z)
    real = critic(img_)
    optimizer = RMSprop(0.0002)
    gan = Model(img, [real, img_])
    gan.compile(loss=[model.wasserstein_loss, 'mean_absolute_error'], optimizer=optimizer, metrics=None)

    X_train = model.load_data(168, 224)
    # only full batches are used; a trailing partial batch is never sampled
    number_of_batches = X_train.shape[0] // batch_size

    generator_iterations = 0
    d_loss = 0

    for epoch in range(epochs):

        current_batch = 0

        while current_batch < number_of_batches:

            start_time = time.time()
            # During the first 25 generator iterations, and on every 500th
            # one, the critic is updated 100 times per generator update.
            # Otherwise the default value is 5.
            if generator_iterations < 25 or (generator_iterations + 1) % 500 == 0:
                critic_iterations = 100
            else:
                critic_iterations = 5

            # Update the critic a number of critic iterations
            for critic_iteration in range(critic_iterations):

                if current_batch > number_of_batches:
                    break

                # Pick a random full batch. randint's upper bound is
                # exclusive, so `number_of_batches` makes every batch
                # reachable and stays valid when there is only one batch.
                it_index = np.random.randint(0, number_of_batches)
                real_images = X_train[it_index * batch_size:(it_index + 1) * batch_size]

                current_batch += 1

                # Size the labels from the batch actually drawn
                current_batch_size = real_images.shape[0]
                # Reconstruct the real images through encoder and generator
                z = encoder1.predict(real_images)
                generated_images = generator.predict(z)

                # Critic targets: +1 for real images, -1 for reconstructions
                real_y = np.ones(current_batch_size)
                fake_y = np.ones(current_batch_size) * -1

                # Let's train the critic
                critic.trainable = True

                # Clip the weights to small numbers near zero
                for layer in critic.layers:
                    weights = layer.get_weights()
                    weights = [np.clip(w, -0.01, 0.01) for w in weights]
                    layer.set_weights(weights)

                d_real = critic.train_on_batch(real_images, real_y)
                d_fake = critic.train_on_batch(generated_images, fake_y)

                d_loss = d_real - d_fake

            # Update the generator
            critic.trainable = False
            itt_index = np.random.randint(0, number_of_batches)
            imgs = X_train[itt_index * batch_size:(itt_index + 1) * batch_size]
            # We try to mislead the critic by giving the opposite labels;
            # the label vector is sized from this batch rather than from a
            # variable left over from the critic loop
            fake_yy = np.ones(imgs.shape[0])
            g_loss = gan.train_on_batch(imgs, [fake_yy, imgs])

            time_elapsed = time.time() - start_time
            # NOTE(review): for a two-output model Keras usually returns
            # [total, loss_1, loss_2]; g_loss[1] may therefore be the critic
            # term rather than the image term -- confirm the label below.
            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_G_imgs: %f -> %f s'
                  % (epoch, epochs, current_batch, number_of_batches, generator_iterations,
                     d_loss, g_loss[0], g_loss[1], time_elapsed))

            generator_iterations += 1