Пример #1
0
def train():
    """Train the label-conditioned adversarial autoencoder (AAE) on MNIST.

    The encoder maps images to latent codes ``z_tilde``.  The decoder
    reconstructs images from ``[one-hot label, latent code]``.  The
    discriminator scores ``[one-hot label, latent code]`` vectors, treating
    codes drawn from a zero-mean Gaussian prior as real and encoder outputs
    as fake, with a Wasserstein critic loss plus a gradient penalty weighted
    by ``param.w_gp_lambda``.

    Every epoch trains over the whole training set, evaluates on the first
    ``param.valid_step + 1`` test batches and prints averaged validation
    metrics.  Every ``param.save_frequency`` epochs (after epoch 1) it writes
    sample images to ``<cur_dir>/graph`` and a checkpoint.  Final weights are
    written to ``<cur_dir>/model``.
    """
    param = Parameter()

    # 1. Build models
    encoder = Encoder(param).model()
    decoder = Decoder(param).model()
    discriminator = Discriminator(param).model()

    # 2. Set optimizers
    opt_ae = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_ae,
                                      beta_1=0.5,
                                      beta_2=0.999,
                                      epsilon=0.01)
    opt_gen = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_gen,
                                       beta_1=0.5,
                                       beta_2=0.999,
                                       epsilon=0.01)
    opt_dis = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_dis,
                                       beta_1=0.5,
                                       beta_2=0.999,
                                       epsilon=0.01)

    # 3. Set trainable variables
    var_ae = encoder.trainable_variables + decoder.trainable_variables
    # The encoder doubles as the generator of the latent-space GAN.
    var_gen = encoder.trainable_variables
    var_dis = discriminator.trainable_variables

    # 4. Load data
    data_loader = MNISTLoader(one_hot=True)
    # Shuffle individual examples *before* batching so that batch
    # composition changes every epoch; shuffling after `batch` would only
    # reorder the same fixed batches.
    train_set = data_loader.train.shuffle(
        buffer_size=data_loader.num_train,
        reshuffle_each_iteration=True).batch(batch_size=param.batch_size,
                                             drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size,
                                      drop_remainder=True)

    # 5. Output locations and checkpointing
    check_point_dir = os.path.join(param.cur_dir, 'training_checkpoints')

    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)

    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    check_point_prefix = os.path.join(check_point_dir, 'aae')

    enc_name = 'aae_enc'
    dec_name = 'aae_dec'
    dis_name = 'aae_dis'

    graph = 'aae'

    check_point = tf.train.Checkpoint(opt_gen=opt_gen,
                                      opt_dis=opt_dis,
                                      opt_ae=opt_ae,
                                      encoder=encoder,
                                      decoder=decoder,
                                      discriminator=discriminator)
    ckpt_manager = tf.train.CheckpointManager(
        check_point,
        check_point_dir,
        max_to_keep=5,
        checkpoint_name=check_point_prefix)

    # 6. Define train / validation step ################################################################################
    def training_gen_step(_x, _y):
        """Update only the encoder so the critic scores its codes higher."""
        with tf.GradientTape() as _gen_tape:
            _z_tilde = encoder(_x, training=True)

            _z_tilde_input = tf.concat([_y, _z_tilde],
                                       axis=-1,
                                       name='z_tilde_input')

            _dis_fake = discriminator(_z_tilde_input, training=True)

            _loss_gen = -tf.reduce_mean(_dis_fake)

        _grad_gen = _gen_tape.gradient(_loss_gen, var_gen)

        opt_gen.apply_gradients(zip(_grad_gen, var_gen))

    def training_step(_x, _y):
        """Run one joint update of autoencoder, generator and critic.

        * autoencoder (encoder + decoder): mean absolute reconstruction error;
        * generator (encoder): ``-mean(critic(fake codes))``;
        * critic: ``mean(fake) - mean(real)`` plus the weighted gradient
          penalty.
        All gradients are taken before any optimizer step is applied.
        """
        with tf.GradientTape() as _ae_tape, tf.GradientTape(
        ) as _gen_tape, tf.GradientTape() as _dis_tape:
            _z = tf.random.normal(shape=(param.batch_size, param.latent_dim),
                                  mean=0.0,
                                  stddev=param.prior_noise_std,
                                  dtype=tf.float32)

            _z_tilde = encoder(_x, training=True)

            _z_input = tf.concat([_y, _z], axis=-1, name='z_input')
            _z_tilde_input = tf.concat([_y, _z_tilde],
                                       axis=-1,
                                       name='z_tilde_input')

            _x_bar = decoder(_z_tilde_input, training=True)

            _dis_real = discriminator(_z_input, training=True)
            _dis_fake = discriminator(_z_tilde_input, training=True)

            _real_loss = -tf.reduce_mean(_dis_real)
            _fake_loss = tf.reduce_mean(_dis_fake)
            _gp = gradient_penalty(partial(discriminator, training=True),
                                   _z_input, _z_tilde_input)

            _loss_gen = -tf.reduce_mean(_dis_fake)
            _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
            _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))

        _grad_ae = _ae_tape.gradient(_loss_ae, var_ae)
        _grad_gen = _gen_tape.gradient(_loss_gen, var_gen)
        _grad_dis = _dis_tape.gradient(_loss_dis, var_dis)

        opt_ae.apply_gradients(zip(_grad_ae, var_ae))
        opt_gen.apply_gradients(zip(_grad_gen, var_gen))
        opt_dis.apply_gradients(zip(_grad_dis, var_dis))

    def validation_step(_x, _y):
        """Evaluate one batch without updating weights.

        Returns:
            ``(x_tilde, x_bar, loss_dis, loss_gen, loss_ae, was_x)``:
            ``x_tilde`` are images decoded from prior samples, ``x_bar`` are
            reconstructions of ``_x``, and ``was_x`` is
            ``mean(critic(real)) - mean(critic(fake))``.
        """
        _z = tf.random.normal(shape=(param.batch_size, param.latent_dim),
                              mean=0.0,
                              stddev=param.prior_noise_std,
                              dtype=tf.float32)

        _z_tilde = encoder(_x, training=False)

        _z_input = tf.concat([_y, _z], axis=-1, name='z_input')
        _z_tilde_input = tf.concat([_y, _z_tilde],
                                   axis=-1,
                                   name='z_tilde_input')

        _x_bar = decoder(_z_tilde_input, training=False)
        _x_tilde = decoder(_z_input, training=False)

        _dis_real = discriminator(_z_input, training=False)
        _dis_fake = discriminator(_z_tilde_input, training=False)

        _real_loss = -tf.reduce_mean(_dis_real)
        _fake_loss = tf.reduce_mean(_dis_fake)
        _gp = gradient_penalty(partial(discriminator, training=False),
                               _z_input, _z_tilde_input)

        _loss_gen = -tf.reduce_mean(_dis_fake)
        _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
        _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))

        return (_x_tilde, _x_bar, _loss_dis.numpy(), _loss_gen.numpy(),
                _loss_ae.numpy(), -_fake_loss.numpy() - _real_loss.numpy())

    ####################################################################################################################

    # 7. Train
    start_time = time.time()
    for epoch in range(0, param.max_epoch):
        # 7-1. Train AAE: even steps run the joint update, odd steps run two
        # extra encoder-only generator updates.
        for step, (x_train, y_train) in enumerate(train_set):
            y_train = tf.cast(y_train, dtype=tf.float32)
            if step % 2 == 0:
                training_step(x_train, y_train)
            else:
                training_gen_step(x_train, y_train)
                training_gen_step(x_train, y_train)

        # 7-2. Validation on the first `param.valid_step + 1` test batches
        num_valid = 0
        val_loss_dis, val_loss_gen, val_loss_ae, val_was_x = [], [], [], []
        for x_valid, y_valid in test_set:
            x_tilde, x_bar, loss_dis, loss_gen, loss_ae, was_x = validation_step(
                x_valid, tf.cast(y_valid, dtype=tf.float32))

            val_loss_dis.append(loss_dis)
            val_loss_gen.append(loss_gen)
            val_loss_ae.append(loss_ae)
            val_was_x.append(was_x)

            num_valid += 1
            if num_valid > param.valid_step:
                break

        # 7-3. Report in training
        elapsed_time = (time.time() - start_time) / 60.
        print(
            "[Epoch: {:04d}] {:.01f}m.\tdis: {:.6f}\tgen: {:.6f}\tae: {:.6f}\tw_x: {:.6f}"
            .format(epoch, elapsed_time, np.mean(val_loss_dis),
                    np.mean(val_loss_gen), np.mean(val_loss_ae),
                    np.mean(val_was_x)))

        # 7-4. Periodically dump sample images from the last validation batch
        # and a checkpoint.
        if epoch % param.save_frequency == 0 and epoch > 1:
            save_decode_image_array(x_valid.numpy(),
                                    path=os.path.join(
                                        graph_path,
                                        '{}_original-{:04d}.png'.format(
                                            graph, epoch)))
            save_decode_image_array(
                x_bar.numpy(),
                path=os.path.join(graph_path,
                                  '{}_decode-{:04d}.png'.format(graph, epoch)))
            save_decode_image_array(x_tilde.numpy(),
                                    path=os.path.join(
                                        graph_path,
                                        '{}_generated-{:04d}.png'.format(
                                            graph, epoch)))
            ckpt_manager.save(checkpoint_number=epoch)

    # 8. Save final weights and report.  The total time is measured here so
    # it is defined even when param.max_epoch == 0.
    elapsed_time = (time.time() - start_time) / 60.
    save_message = "\tSave model: End of training"

    encoder.save_weights(os.path.join(model_path, enc_name))
    decoder.save_weights(os.path.join(model_path, dec_name))
    discriminator.save_weights(os.path.join(model_path, dis_name))

    print("[Epoch: {:04d}] {:.01f} min.".format(param.max_epoch, elapsed_time))
    print(save_message)
Пример #2
0
def train(w_gp=False):
    """Train an unconditional GAN on MNIST images.

    Args:
        w_gp: When truthy, train with the Wasserstein critic loss plus a
            gradient penalty weighted by ``param.w_gp_lambda``.  Otherwise,
            use binary cross-entropy on discriminator logits.  The flag also
            selects the names used for checkpoints, saved weights and sample
            images.

    Each epoch trains over the whole training set, evaluates on the first
    ``param.valid_step`` test batches and prints the mean validation losses.
    Every ``param.save_frequency`` epochs (after epoch 1) it writes sample
    images and a checkpoint.  Final weights go to ``<cur_dir>/model``.
    """
    param = Parameter()
    # Normalise once so that any truthy value (not only the `True`
    # singleton) selects the WGAN-GP variant.
    use_gp = bool(w_gp)

    # 1. Build models
    generator = Generator(param).model()
    discriminator = Discriminator(param).model()

    # 2. Set optimizers
    opt_gen = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_gen)
    opt_dis = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_dis)

    # 3. Set trainable variables
    var_gen = generator.trainable_variables
    var_dis = discriminator.trainable_variables

    # 4. Load data (labels are ignored: the GAN is unconditional)
    data_loader = MNISTLoader(one_hot=False)
    # Shuffle individual examples *before* batching so that batch
    # composition changes every epoch; shuffling after `batch` would only
    # reorder the same fixed batches.
    train_set = data_loader.train.shuffle(
        buffer_size=data_loader.num_train,
        reshuffle_each_iteration=True).batch(batch_size=param.batch_size,
                                             drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size,
                                      drop_remainder=True)

    # 5. Define loss (non-GP variant only; the discriminator output is
    # treated as raw logits).
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    # 6. Output locations and checkpointing
    check_point_dir = os.path.join(param.cur_dir, 'training_checkpoints')

    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)

    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    if use_gp:
        check_point_prefix = os.path.join(check_point_dir, 'gan_w_gp')
        gen_name = 'gan_w_gp'
        dis_name = 'dis_w_gp'
        graph = 'gan_w_gp'
    else:
        check_point_prefix = os.path.join(check_point_dir, 'gan')
        gen_name = 'gan'
        dis_name = 'dis'
        graph = 'gan'

    check_point = tf.train.Checkpoint(opt_gen=opt_gen,
                                      opt_dis=opt_dis,
                                      generator=generator,
                                      discriminator=discriminator)
    ckpt_manager = tf.train.CheckpointManager(
        check_point,
        check_point_dir,
        max_to_keep=5,
        checkpoint_name=check_point_prefix)

    # 7. Define train / validation step ################################################################################
    def sample_noise():
        """Draw one batch of Gaussian prior noise for the generator."""
        return tf.random.normal(shape=(param.batch_size, param.latent_dim),
                                mean=0.0,
                                stddev=param.prior_noise_std,
                                dtype=tf.float32)

    def compute_losses(_x, _x_tilde, _dis_real, _dis_fake, training):
        """Return ``(loss_dis, loss_gen)`` for the selected GAN objective.

        ``training`` is forwarded to the discriminator call made inside the
        gradient penalty.
        """
        if use_gp:
            _real_loss = -tf.reduce_mean(_dis_real)
            _fake_loss = tf.reduce_mean(_dis_fake)
            _gp = gradient_penalty(partial(discriminator, training=training),
                                   _x, _x_tilde)
            _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
            _loss_gen = -tf.reduce_mean(_dis_fake)
        else:
            _loss_dis = (cross_entropy(tf.ones_like(_dis_real), _dis_real) +
                         cross_entropy(tf.zeros_like(_dis_fake), _dis_fake))
            _loss_gen = cross_entropy(tf.ones_like(_dis_fake), _dis_fake)
        return _loss_dis, _loss_gen

    def training_step(_x):
        """Update the generator and the discriminator once on batch ``_x``."""
        with tf.GradientTape() as _gen_tape, tf.GradientTape() as _dis_tape:
            _x_tilde = generator(sample_noise(), training=True)

            _dis_real = discriminator(_x, training=True)
            _dis_fake = discriminator(_x_tilde, training=True)

            _loss_dis, _loss_gen = compute_losses(_x, _x_tilde, _dis_real,
                                                  _dis_fake, True)

        _grad_gen = _gen_tape.gradient(_loss_gen, var_gen)
        _grad_dis = _dis_tape.gradient(_loss_dis, var_dis)

        opt_gen.apply_gradients(zip(_grad_gen, var_gen))
        opt_dis.apply_gradients(zip(_grad_dis, var_dis))

    def validation_step(_x):
        """Return ``(x_tilde, loss_dis, loss_gen)`` for batch ``_x``."""
        _x_tilde = generator(sample_noise(), training=False)

        _dis_real = discriminator(_x, training=False)
        _dis_fake = discriminator(_x_tilde, training=False)

        _loss_dis, _loss_gen = compute_losses(_x, _x_tilde, _dis_real,
                                              _dis_fake, False)

        return _x_tilde, _loss_dis.numpy(), _loss_gen.numpy()

    ####################################################################################################################

    # 8. Train
    start_time = time.time()
    for epoch in range(0, param.max_epoch):
        # 8-1. Train GANs
        for x_train, _ in train_set:
            training_step(x_train)

        # 8-2. Validation on the first `param.valid_step` test batches.
        # `take` ensures `x_valid` is the batch that was actually evaluated.
        val_loss_dis, val_loss_gen = [], []
        for x_valid, _ in test_set.take(param.valid_step):
            x_tilde, loss_dis, loss_gen = validation_step(x_valid)

            val_loss_dis.append(loss_dis)
            val_loss_gen.append(loss_gen)

        if epoch % param.save_frequency == 0 and epoch > 1:
            save_decode_image_array(x_valid.numpy(),
                                    path=os.path.join(
                                        graph_path,
                                        '{}_original-{:04d}.png'.format(
                                            graph, epoch)))
            save_decode_image_array(x_tilde.numpy(),
                                    path=os.path.join(
                                        graph_path,
                                        '{}_generated-{:04d}.png'.format(
                                            graph, epoch)))

            ckpt_manager.save(checkpoint_number=epoch)

        # 8-3. Report in training
        elapsed_time = (time.time() - start_time) / 60.
        print(
            "[Epoch: {:04d}] {:.01f} min.\t loss dis: {:.6f}\t loss gen: {:.6f}"
            .format(epoch, elapsed_time, np.mean(val_loss_dis),
                    np.mean(val_loss_gen)))

    # 9. Save final weights and report.  The total time is measured here so
    # it is defined even when param.max_epoch == 0.
    elapsed_time = (time.time() - start_time) / 60.
    save_message = "\tSave model: End of training"

    generator.save_weights(os.path.join(model_path, gen_name))
    discriminator.save_weights(os.path.join(model_path, dis_name))

    print("[Epoch: {:04d}] {:.01f} min.".format(param.max_epoch, elapsed_time))
    print(save_message)
Пример #3
0
def inference():
    """Evaluate a trained label-conditioned AAE and render sample images.

    Loads encoder, decoder and discriminator weights from ``<cur_dir>/model``
    and evaluates the whole training set and the test set.  The first
    ``param.valid_step + 1`` test batches are reported as "Validation" and the
    rest as "Test".  It then prints the mean losses per split and writes
    images to ``<cur_dir>/graph``:

    * one grid per class, decoded from a shared batch of prior noise;
    * the last test batch, its reconstruction, and images decoded from
      prior samples with that batch's labels.
    """
    param = Parameter()

    # 1. Build models
    encoder = Encoder(param).model()
    decoder = Decoder(param).model()
    discriminator = Discriminator(param).model()

    # 2. Load data (unshuffled)
    data_loader = MNISTLoader(one_hot=True)
    train_set = data_loader.train.batch(batch_size=param.batch_size,
                                        drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size,
                                      drop_remainder=True)

    # 3. Output locations
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)

    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    # 4. Load model weights
    enc_name = 'aae_enc'
    dec_name = 'aae_dec'
    dis_name = 'aae_dis'

    graph = 'aae'

    encoder.load_weights(os.path.join(model_path, enc_name))
    discriminator.load_weights(os.path.join(model_path, dis_name))
    decoder.load_weights(os.path.join(model_path, dec_name))

    # 5. Define testing step ###########################################################################################
    def testing_step(_x, _y):
        """Evaluate one batch without updating weights.

        Returns:
            ``(x_tilde, x_bar, loss_dis, loss_gen, loss_ae, was_x)``:
            ``x_tilde`` are images decoded from prior samples, ``x_bar`` are
            reconstructions of ``_x``, and ``was_x`` is
            ``mean(critic(real)) - mean(critic(fake))``.
        """
        _z = tf.random.normal(shape=(param.batch_size, param.latent_dim),
                              mean=0.0,
                              stddev=param.prior_noise_std,
                              dtype=tf.float32)

        _z_tilde = encoder(_x, training=False)

        _z_input = tf.concat([_y, _z], axis=-1, name='z_input')
        _z_tilde_input = tf.concat([_y, _z_tilde],
                                   axis=-1,
                                   name='z_tilde_input')

        _x_bar = decoder(_z_tilde_input, training=False)
        _x_tilde = decoder(_z_input, training=False)

        _dis_real = discriminator(_z_input, training=False)
        _dis_fake = discriminator(_z_tilde_input, training=False)

        _real_loss = -tf.reduce_mean(_dis_real)
        _fake_loss = tf.reduce_mean(_dis_fake)
        _gp = gradient_penalty(partial(discriminator, training=False),
                               _z_input, _z_tilde_input)

        _loss_gen = -tf.reduce_mean(_dis_fake)
        _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
        _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))

        return (_x_tilde, _x_bar, _loss_dis.numpy(), _loss_gen.numpy(),
                _loss_ae.numpy(), -_fake_loss.numpy() - _real_loss.numpy())

    ####################################################################################################################

    # 6. Evaluate the whole training set
    train_loss_dis, train_loss_gen, train_loss_ae, train_was_x = [], [], [], []
    for x_train, y_train in train_set:
        y = tf.cast(y_train, dtype=tf.float32)
        x_tilde, x_bar, loss_dis, loss_gen, loss_ae, was_x = testing_step(
            x_train, y)

        train_loss_dis.append(loss_dis)
        train_loss_gen.append(loss_gen)
        train_loss_ae.append(loss_ae)
        train_was_x.append(was_x)

    # 7. Evaluate the test set: the first `param.valid_step + 1` batches are
    # reported as validation, the remainder as test.
    num_test = 0
    val_loss_dis, val_loss_gen, val_loss_ae, val_was_x = [], [], [], []
    test_loss_dis, test_loss_gen, test_loss_ae, test_was_x = [], [], [], []

    for x_test, y_test in test_set:
        y = tf.cast(y_test, dtype=tf.float32)
        x_tilde, x_bar, loss_dis, loss_gen, loss_ae, was_x = testing_step(
            x_test, y)

        if num_test <= param.valid_step:
            val_loss_dis.append(loss_dis)
            val_loss_gen.append(loss_gen)
            val_loss_ae.append(loss_ae)
            val_was_x.append(was_x)
        else:
            test_loss_dis.append(loss_dis)
            test_loss_gen.append(loss_gen)
            test_loss_ae.append(loss_ae)
            test_was_x.append(was_x)

        num_test += 1

    # 8. Decode one grid per class from the *same* prior noise, so the grids
    # differ only in the conditioning label.
    _z = tf.random.normal(shape=(param.batch_size, param.latent_dim),
                          mean=0.0,
                          stddev=param.prior_noise_std,
                          dtype=tf.float32)
    for class_idx in range(0, param.num_class):
        # tf.one_hot expects integer indices.  The removed `np.float` alias
        # (NumPy >= 1.24) must not be used here.
        _indices = np.full(param.batch_size, class_idx, dtype=np.int32)
        _y = tf.one_hot(indices=_indices,
                        depth=param.num_class,
                        dtype=tf.float32)

        _gen_input = tf.concat([_y, _z], axis=-1, name='gen_input')
        _x_tilde = decoder(_gen_input, training=False)

        save_decode_image_array(
            _x_tilde.numpy(),
            path=os.path.join(graph_path,
                              '{}_c{}_generated.png'.format(graph, class_idx)))

    # 9. Report mean losses per split
    print("[Loss dis] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".
          format(np.mean(train_loss_dis), np.mean(val_loss_dis),
                 np.mean(test_loss_dis)))
    print("[Loss gen] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".
          format(np.mean(train_loss_gen), np.mean(val_loss_gen),
                 np.mean(test_loss_gen)))
    print("[Loss ae] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".
          format(np.mean(train_loss_ae), np.mean(val_loss_ae),
                 np.mean(test_loss_ae)))
    print(
        "[Was X] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
            np.mean(train_was_x), np.mean(val_was_x), np.mean(test_was_x)))

    # 10. Draw samples from the last test batch
    save_decode_image_array(x_test.numpy(),
                            path=os.path.join(graph_path,
                                              '{}_original.png'.format(graph)))
    save_decode_image_array(x_bar.numpy(),
                            path=os.path.join(graph_path,
                                              '{}_decoded.png'.format(graph)))
    save_decode_image_array(x_tilde.numpy(),
                            path=os.path.join(
                                graph_path, '{}_generated.png'.format(graph)))
Пример #4
0
def inference(denoise=True):
    """Evaluate a trained (optionally denoising) autoencoder on MNIST.

    Args:
        denoise: When truthy, load the denoising-autoencoder weights and add
            Gaussian noise (std ``param.white_noise_std``) to every input
            before encoding.  Reconstructions are still scored against the
            clean images.

    Prints the mean absolute reconstruction error (L1, not MSE) for three
    splits:

    * the whole training set;
    * the first ``param.valid_step + 1`` test batches ("Validation");
    * the remaining test batches ("Test").

    It then saves images of the last test batch (its noisy input when
    denoising, the original, and the reconstruction).
    """
    param = Parameter()

    # 1. Build models
    encoder = Encoder(param).model()
    decoder = Decoder(param).model()

    # 2. Load data (unshuffled)
    data_loader = MNISTLoader()
    train_set = data_loader.train.batch(batch_size=param.batch_size,
                                        drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size,
                                      drop_remainder=True)

    # 3. Output locations
    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)

    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    # 4. Load model weights
    if denoise:
        graph = 'denoising_ae'
        enc_name = 'encoder_denoise'
        dec_name = 'decoder_denoise'
    else:
        graph = 'ae'
        enc_name = 'encoder'
        dec_name = 'decoder'

    encoder.load_weights(os.path.join(model_path, enc_name))
    decoder.load_weights(os.path.join(model_path, dec_name))

    # 5. Define test step ##############################################################################################
    def testing_step(_x):
        """Reconstruct one batch.

        Returns:
            ``(x_input, x_bar, loss_ae)``: the encoder input (noisy when
            denoising, otherwise ``_x`` itself), the reconstruction, and the
            mean absolute error between the clean ``_x`` and ``x_bar``.
        """
        if denoise:
            # param.input_dim must be a tuple: it is appended to the
            # batch-size tuple to form the noise shape.
            _noise = tf.random.normal(shape=(param.batch_size, ) +
                                      param.input_dim,
                                      mean=0.0,
                                      stddev=param.white_noise_std,
                                      dtype=tf.float32)
            _x_test = _x + _noise
        else:
            _x_test = _x

        _z_tilde = encoder(_x_test, training=False)
        _x_bar = decoder(_z_tilde, training=False)

        # Pixel-wise L1 (mean absolute error) against the *clean* input.
        _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))

        return _x_test, _x_bar, _loss_ae.numpy()

    ####################################################################################################################

    # 6. Inference (all accumulators hold mean-absolute-error values)
    train_loss = []
    for x_train, _ in train_set:
        _, _, loss_ae = testing_step(x_train)
        train_loss.append(loss_ae)

    num_test = 0
    valid_loss = []
    test_loss = []
    for x_test, _ in test_set:
        x_noise, x_bar, loss_ae = testing_step(x_test)

        if num_test <= param.valid_step:
            valid_loss.append(loss_ae)
        else:
            test_loss.append(loss_ae)
        num_test += 1

    # 7. Report
    print("[Loss] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        np.mean(train_loss), np.mean(valid_loss), np.mean(test_loss)))

    # 8. Draw samples from the last test batch
    if denoise:
        save_decode_image_array(x_noise.numpy(),
                                path=os.path.join(
                                    graph_path, '{}_noise.png'.format(graph)))
    save_decode_image_array(x_test.numpy(),
                            path=os.path.join(graph_path,
                                              '{}_original.png'.format(graph)))
    save_decode_image_array(x_bar.numpy(),
                            path=os.path.join(graph_path,
                                              '{}_decoded.png'.format(graph)))
Пример #5
0
def train():
    """Train an auxiliary-classifier GAN on MNIST.

    The generator maps ``[one-hot label, uniform noise in [-1, 1)]`` to
    images.  The discriminator returns a pair
    ``(classification logits, critic score)``.  The critic score is trained
    with a Wasserstein loss plus a gradient penalty weighted by
    ``param.w_gp_lambda``.  The classifier maps the discriminator's
    classification output to class logits, trained with softmax
    cross-entropy on both real and generated images.

    Each epoch trains over the whole training set, evaluates on the first
    ``param.valid_step + 1`` test batches and prints the mean validation
    losses.  Every ``param.save_frequency`` epochs (after epoch 1) it writes
    sample images and a checkpoint.  Final weights go to ``<cur_dir>/model``.
    """
    param = Parameter()

    # 1. Build models
    generator = Generator(param).model()
    discriminator = Discriminator(param).model()
    classifier = Classifier(param).model()

    # 2. Set optimizers
    opt_gen = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_gen,
                                       beta_1=0.5,
                                       beta_2=0.999)
    opt_dis = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_dis,
                                       beta_1=0.5,
                                       beta_2=0.999)
    opt_cla = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_cla,
                                       beta_1=0.5,
                                       beta_2=0.999)

    # 3. Set trainable variables
    var_gen = generator.trainable_variables
    var_dis = discriminator.trainable_variables
    # The classification loss also trains the generator and every
    # discriminator variable except the last two.  NOTE(review): presumably
    # the last two are the critic head's kernel/bias; confirm against
    # Discriminator.
    var_cla = (generator.trainable_variables +
               classifier.trainable_variables +
               discriminator.trainable_variables[:-2])

    # 4. Load data
    data_loader = MNISTLoader(one_hot=True)
    # Shuffle individual examples *before* batching so that batch
    # composition changes every epoch; shuffling after `batch` would only
    # reorder the same fixed batches.
    train_set = data_loader.train.shuffle(
        buffer_size=data_loader.num_train,
        reshuffle_each_iteration=True).batch(batch_size=param.batch_size,
                                             drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size,
                                      drop_remainder=True)

    # 5. Output locations and checkpointing
    check_point_dir = os.path.join(param.cur_dir, 'training_checkpoints')

    graph_path = os.path.join(param.cur_dir, 'graph')
    if not os.path.isdir(graph_path):
        os.mkdir(graph_path)

    model_path = os.path.join(param.cur_dir, 'model')
    if not os.path.isdir(model_path):
        os.mkdir(model_path)

    check_point_prefix = os.path.join(check_point_dir, 'acgans')

    gen_name = 'acgans_gen'
    dis_name = 'acgans_dis'
    cla_name = 'acgans_cla'

    graph = 'acgans'

    # opt_cla is tracked as well, so resuming restores the classifier
    # optimizer's state together with the other two optimizers.
    check_point = tf.train.Checkpoint(opt_gen=opt_gen,
                                      opt_dis=opt_dis,
                                      opt_cla=opt_cla,
                                      generator=generator,
                                      discriminator=discriminator,
                                      classifier=classifier)
    ckpt_manager = tf.train.CheckpointManager(
        check_point,
        check_point_dir,
        max_to_keep=5,
        checkpoint_name=check_point_prefix)

    # 6. Define train / validation step ################################################################################
    # NOTE(review): the discriminator returns a (logits, score) pair, and the
    # same callable is handed to gradient_penalty.  Confirm gradient_penalty
    # handles a tuple-returning discriminator.
    def training_step(_x, _y):
        """Run one update of generator, critic and classifier on a batch.

        All gradients are taken before any optimizer step is applied.
        """
        with tf.GradientTape() as _gen_tape, tf.GradientTape(
        ) as _dis_tape, tf.GradientTape() as _cla_tape:
            _z = tf.random.uniform(shape=(param.batch_size, param.latent_dim),
                                   minval=-1.0,
                                   maxval=1.0,
                                   dtype=tf.float32)

            _gen_input = tf.concat([_y, _z], axis=-1, name='gen_input')
            _x_tilde = generator(_gen_input, training=True)

            _cla_real_logits, _dis_real = discriminator(_x, training=True)
            _cla_fake_logits, _dis_fake = discriminator(_x_tilde,
                                                        training=True)

            _cla_real = classifier(_cla_real_logits, training=True)
            _cla_fake = classifier(_cla_fake_logits, training=True)

            _labels = tf.argmax(_y, axis=1)
            _loss_cla_real = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=_labels, logits=_cla_real))
            _loss_cla_fake = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=_labels, logits=_cla_fake))

            _real_loss = -tf.reduce_mean(_dis_real)
            _fake_loss = tf.reduce_mean(_dis_fake)
            _gp = gradient_penalty(partial(discriminator, training=True), _x,
                                   _x_tilde)

            _loss_cla = _loss_cla_real + _loss_cla_fake
            _loss_gen = -tf.reduce_mean(_dis_fake)
            _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda

        _grad_gen = _gen_tape.gradient(_loss_gen, var_gen)
        _grad_dis = _dis_tape.gradient(_loss_dis, var_dis)
        _grad_cla = _cla_tape.gradient(_loss_cla, var_cla)

        opt_dis.apply_gradients(zip(_grad_dis, var_dis))
        opt_gen.apply_gradients(zip(_grad_gen, var_gen))
        opt_cla.apply_gradients(zip(_grad_cla, var_cla))

    def validation_step(_x, _y):
        """Evaluate one batch without updating weights.

        Returns:
            ``(x_tilde, loss_gen, loss_dis, loss_cla_real, loss_cla_fake,
            was_x)``: ``x_tilde`` are images generated for labels ``_y`` and
            ``was_x`` is ``mean(critic(real)) - mean(critic(fake))``.
        """
        _z = tf.random.uniform(shape=(param.batch_size, param.latent_dim),
                               minval=-1.0,
                               maxval=1.0,
                               dtype=tf.float32)

        _gen_input = tf.concat([_y, _z], axis=-1, name='gen_input')
        _x_tilde = generator(_gen_input, training=False)

        _cla_real_logits, _dis_real = discriminator(_x, training=False)
        _cla_fake_logits, _dis_fake = discriminator(_x_tilde, training=False)

        _cla_real = classifier(_cla_real_logits, training=False)
        _cla_fake = classifier(_cla_fake_logits, training=False)

        _loss_gen = -tf.reduce_mean(_dis_fake)

        _real_loss = -tf.reduce_mean(_dis_real)
        _fake_loss = tf.reduce_mean(_dis_fake)
        _gp = gradient_penalty(partial(discriminator, training=False), _x,
                               _x_tilde)
        _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda

        _labels = tf.argmax(_y, axis=1)
        _loss_cla_real = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=_labels,
                                                           logits=_cla_real))
        _loss_cla_fake = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=_labels,
                                                           logits=_cla_fake))

        return (_x_tilde, _loss_gen.numpy(), _loss_dis.numpy(),
                _loss_cla_real.numpy(), _loss_cla_fake.numpy(),
                -_fake_loss.numpy() - _real_loss.numpy())

    ####################################################################################################################

    # 7. Train
    start_time = time.time()
    for epoch in range(0, param.max_epoch):
        # 7-1. Train ACGANs
        for x_train, y_train in train_set:
            training_step(x_train, tf.cast(y_train, dtype=tf.float32))

        # 7-2. Validation on the first `param.valid_step + 1` test batches
        num_valid = 0
        val_loss_dis, val_loss_gen = [], []
        val_loss_cla_real, val_loss_cla_fake, val_was_x = [], [], []
        for x_valid, y_valid in test_set:
            # Unpack in validation_step's return order: generator loss
            # first, then discriminator loss.
            (x_tilde, loss_gen, loss_dis, loss_cla_real, loss_cla_fake,
             was_x) = validation_step(x_valid,
                                      tf.cast(y_valid, dtype=tf.float32))

            val_loss_dis.append(loss_dis)
            val_loss_gen.append(loss_gen)
            val_loss_cla_real.append(loss_cla_real)
            val_loss_cla_fake.append(loss_cla_fake)
            val_was_x.append(was_x)

            num_valid += 1
            if num_valid > param.valid_step:
                break

        # 7-3. Report in training
        elapsed_time = (time.time() - start_time) / 60.
        print(
            "[{:04d}] {:.01f} m.\tdis: {:.6f}\tgen: {:.6f}\tcla_r: {:.6f}\tcla_f: {:.6f}\t w_x: {:.4f}"
            .format(epoch, elapsed_time, np.mean(val_loss_dis),
                    np.mean(val_loss_gen), np.mean(val_loss_cla_real),
                    np.mean(val_loss_cla_fake), np.mean(val_was_x)))

        # 7-4. Periodically dump sample images and a checkpoint
        if epoch % param.save_frequency == 0 and epoch > 1:
            save_decode_image_array(x_valid.numpy(),
                                    path=os.path.join(
                                        graph_path,
                                        '{}_original-{:04d}.png'.format(
                                            graph, epoch)))
            save_decode_image_array(x_tilde.numpy(),
                                    path=os.path.join(
                                        graph_path,
                                        '{}_generated-{:04d}.png'.format(
                                            graph, epoch)))
            ckpt_manager.save(checkpoint_number=epoch)

    # 8. Save final weights and report.  The total time is measured here so
    # it is defined even when param.max_epoch == 0.
    elapsed_time = (time.time() - start_time) / 60.
    save_message = "\tSave model: End of training"

    generator.save_weights(os.path.join(model_path, gen_name))
    discriminator.save_weights(os.path.join(model_path, dis_name))
    classifier.save_weights(os.path.join(model_path, cla_name))

    print("[Epoch: {:04d}] {:.01f} min.".format(param.max_epoch, elapsed_time))
    print(save_message)
Пример #6
0
def train(denoise=False):
    """Train a (optionally denoising) auto-encoder on MNIST with early stopping.

    The encoder/decoder pair is optimised with Adam on a pixel-wise L1
    reconstruction loss. After every epoch the first ``param.valid_step + 1``
    test batches are used for validation; whenever the mean validation loss
    improves, the encoder and decoder weights are saved to ``<cur_dir>/model``.
    Every ``param.save_frequency`` epochs (for epochs > 1) sample images are
    written to ``<cur_dir>/graph`` and a checkpoint is saved. Training stops
    early after ``param.num_early_stopping`` epochs without improvement.

    Args:
        denoise: If truthy, Gaussian noise with std ``param.white_noise_std``
            is added to the encoder input while the loss is still measured
            against the clean image, and all artifacts use ``*_denoise`` names.
    """
    param = Parameter()

    # 1. Build models
    encoder = Encoder(param).model()
    decoder = Decoder(param).model()

    # 2. Set optimizers
    opt_ae = tf.keras.optimizers.Adam(learning_rate=param.learning_rate_ae)

    # 3. Set trainable variables
    var_ae = encoder.trainable_variables + decoder.trainable_variables

    # 4. Load data
    data_loader = MNISTLoader()
    # Shuffle individual examples *before* batching so batch composition changes
    # every epoch (shuffling after .batch() only permutes fixed batches).
    train_set = data_loader.train.shuffle(buffer_size=data_loader.num_train,
                                          reshuffle_each_iteration=True).batch(
                                              batch_size=param.batch_size, drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size, drop_remainder=True)

    # 5. Define loss: pixel-wise L1, computed inside the train / validation steps.

    # 6. Etc.
    check_point_dir = os.path.join(param.cur_dir, 'training_checkpoints')

    model_path = os.path.join(param.cur_dir, 'model')
    os.makedirs(model_path, exist_ok=True)

    graph_path = os.path.join(param.cur_dir, 'graph')
    os.makedirs(graph_path, exist_ok=True)

    # Lowest mean validation L1 loss seen so far; +inf guarantees the first
    # evaluated epoch is always saved, whatever the scale of the loss.
    minimum_loss = np.inf
    # Number of epochs since the validation loss last improved.
    num_effective_epoch = 0

    if denoise:
        check_point_name = 'ae_denoise'
        graph = 'ae_denoise'
        enc_name = 'encoder_denoise'
        dec_name = 'decoder_denoise'
    else:
        check_point_name = 'ae'
        graph = 'ae'
        enc_name = 'encoder'
        dec_name = 'decoder'

    check_point = tf.train.Checkpoint(opt_ae=opt_ae, encoder=encoder, decoder=decoder)
    # CheckpointManager joins `checkpoint_name` onto `directory` itself, so pass
    # only the bare name (a full path would be doubled for a relative cur_dir).
    ckpt_manager = tf.train.CheckpointManager(check_point, check_point_dir, max_to_keep=5,
                                              checkpoint_name=check_point_name)

    # 7. Define train / validation step ################################################################################
    def corrupt(_x):
        """Return ``_x`` plus white Gaussian noise when denoising, else ``_x`` unchanged."""
        if not denoise:
            return _x
        _noise = tf.random.normal(shape=tf.shape(_x), mean=0.0,
                                  stddev=param.white_noise_std, dtype=tf.float32)
        return _x + _noise

    def training_step(_x):
        """Apply one Adam update to encoder + decoder for the clean batch ``_x``."""
        _x_train = corrupt(_x)

        with tf.GradientTape() as _ae_tape:
            _z_tilde = encoder(_x_train, training=True)
            _x_bar = decoder(_z_tilde, training=True)
            # Target is always the clean input, even if the encoder saw a noisy one.
            _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))

        _grad_ae = _ae_tape.gradient(_loss_ae, var_ae)
        opt_ae.apply_gradients(zip(_grad_ae, var_ae))

    def validation_step(_x):
        """Return (encoder input, reconstruction, L1 loss as NumPy scalar) for ``_x``."""
        _x_valid = corrupt(_x)

        _z_tilde = encoder(_x_valid, training=False)
        _x_bar = decoder(_z_tilde, training=False)

        _loss_ae = tf.reduce_mean(tf.abs(tf.subtract(_x, _x_bar)))

        return _x_valid, _x_bar, _loss_ae.numpy()
    ####################################################################################################################

    # 8. Train
    start_time = time.time()
    last_epoch = None  # stays None when param.max_epoch == 0
    for epoch in range(param.max_epoch):
        last_epoch = epoch

        # 8-1. Train auto encoder
        for x_train, _ in train_set:
            training_step(x_train)

        # 8-2. Validation on the first param.valid_step + 1 test batches
        loss_valid = []
        for num_valid, (x_valid, _) in enumerate(test_set):
            x_noise, x_bar, loss_ae = validation_step(x_valid)
            loss_valid.append(loss_ae)

            if num_valid == param.valid_step:
                break

        valid_loss = np.mean(loss_valid)

        save_message = ''
        if valid_loss < minimum_loss:
            num_effective_epoch = 0
            minimum_loss = valid_loss
            save_message = "\tSave model: detecting lowest L1: {:.6f} at epoch {:04d}".format(minimum_loss, epoch)

            encoder.save_weights(os.path.join(model_path, enc_name))
            decoder.save_weights(os.path.join(model_path, dec_name))

        elapsed_time = (time.time() - start_time) / 60.

        # 8-3. Report ("Effective" is True when this epoch improved the best loss)
        print("[Epoch: {:04d}] {:.01f} min. ae loss: {:.6f} Effective: {}".format(epoch, elapsed_time,
                                                                                  valid_loss,
                                                                                  (num_effective_epoch == 0)))
        print(save_message)

        if epoch % param.save_frequency == 0 and epoch > 1:
            if denoise:
                save_decode_image_array(x_noise.numpy(),
                                        path=os.path.join(graph_path, '{}_noise_{:04d}.png'.format(graph, epoch)))
            save_decode_image_array(x_valid.numpy(),
                                    path=os.path.join(graph_path, '{}_original_{:04d}.png'.format(graph, epoch)))
            save_decode_image_array(x_bar.numpy(),
                                    path=os.path.join(graph_path, '{}_decoded_{:04d}.png'.format(graph, epoch)))
            ckpt_manager.save(checkpoint_number=epoch)

        if num_effective_epoch >= param.num_early_stopping:
            print("\t Early stopping at epoch {:04d}!".format(epoch))
            break

        num_effective_epoch += 1

    if last_epoch is None:
        print("\t No training epochs were run (param.max_epoch = {})".format(param.max_epoch))
    else:
        print("\t Stopping at epoch {:04d}!".format(last_epoch))
Пример #7
0
def inference(w_gp=False):
    """Evaluate a trained GAN (or WGAN-GP) on MNIST and save sample images.

    Loads generator and discriminator weights from ``<cur_dir>/model`` and
    computes discriminator / generator losses on the whole training set, on the
    first ``param.valid_step + 1`` test batches ("Validation") and on the
    remaining test batches ("Test"), printing the mean of each. Finally the
    last test batch and the matching generated batch are saved to
    ``<cur_dir>/graph``.

    Args:
        w_gp: If truthy, load the ``*_w_gp`` weights and score them with the
            Wasserstein losses plus gradient penalty; otherwise load the plain
            GAN weights and use binary cross-entropy losses.
    """
    param = Parameter()
    # Normalise once so the weight names and the loss formulation can never
    # disagree (mixing `is True` / `is False` checks split non-bool flags).
    w_gp = bool(w_gp)

    # 1. Build models
    generator = Generator(param).model()
    discriminator = Discriminator(param).model()

    # 2. Load data
    data_loader = MNISTLoader(one_hot=False)
    train_set = data_loader.train.batch(batch_size=param.batch_size,
                                        drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size,
                                      drop_remainder=True)

    # 3. Etc.
    graph_path = os.path.join(param.cur_dir, 'graph')
    os.makedirs(graph_path, exist_ok=True)

    model_path = os.path.join(param.cur_dir, 'model')
    os.makedirs(model_path, exist_ok=True)

    # 4. Load model
    if w_gp:
        gen_name, dis_name, graph = 'gan_w_gp', 'dis_w_gp', 'gan_w_gp'
    else:
        gen_name, dis_name, graph = 'gan', 'dis', 'gan'

    generator.load_weights(os.path.join(model_path, gen_name))
    discriminator.load_weights(os.path.join(model_path, dis_name))

    # 5. Define loss
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def evaluate_batch(_x_real):
        """Score one real batch against a freshly generated fake batch.

        Returns:
            (x_tilde, loss_dis, loss_gen): the generated images and the
            discriminator / generator losses as NumPy scalars.
        """
        _noise = tf.random.normal(shape=(param.batch_size, param.latent_dim),
                                  mean=0.0,
                                  stddev=0.3,
                                  dtype=tf.float32)
        _x_tilde = generator(_noise, training=False)

        _dis_real = discriminator(_x_real, training=False)
        _dis_fake = discriminator(_x_tilde, training=False)

        if w_gp:
            _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
            _gp = gradient_penalty(partial(discriminator, training=False),
                                   _x_real, _x_tilde)
            _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
            _loss_gen = -tf.reduce_mean(_dis_fake)
        else:
            _loss_dis = (cross_entropy(tf.ones_like(_dis_real), _dis_real)
                         + cross_entropy(tf.zeros_like(_dis_fake), _dis_fake))
            _loss_gen = cross_entropy(tf.ones_like(_dis_fake), _dis_fake)

        return _x_tilde, _loss_dis.numpy(), _loss_gen.numpy()

    # 6. Inference
    train_dis_loss, train_gen_loss = [], []
    for x_train, _ in train_set:
        _, loss_dis, loss_gen = evaluate_batch(x_train)
        train_dis_loss.append(loss_dis)
        train_gen_loss.append(loss_gen)

    valid_dis_loss, valid_gen_loss = [], []
    test_dis_loss, test_gen_loss = [], []
    x_test = x_tilde = None
    for num_test, (x_test, _) in enumerate(test_set):
        x_tilde, loss_dis, loss_gen = evaluate_batch(x_test)

        # The first param.valid_step + 1 test batches act as the validation split.
        if num_test <= param.valid_step:
            valid_dis_loss.append(loss_dis)
            valid_gen_loss.append(loss_gen)
        else:
            test_dis_loss.append(loss_dis)
            test_gen_loss.append(loss_gen)

    # 7. Report (an empty split yields nan, as before)
    print("[Loss dis] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        np.mean(train_dis_loss), np.mean(valid_dis_loss), np.mean(test_dis_loss)))
    print("[Loss gen] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
        np.mean(train_gen_loss), np.mean(valid_gen_loss), np.mean(test_gen_loss)))

    # 8. Draw some samples: the last test batch and its generated counterpart.
    if x_test is not None:
        save_decode_image_array(x_test.numpy(),
                                path=os.path.join(graph_path,
                                                  '{}_original.png'.format(graph)))
        save_decode_image_array(x_tilde.numpy(),
                                path=os.path.join(
                                    graph_path, '{}_generated.png'.format(graph)))
Пример #8
0
def inference():
    """Evaluate a trained AC-GAN on MNIST and save generated samples.

    Loads the ``acgans_*`` generator, discriminator and classifier weights
    from ``<cur_dir>/model``. For the whole training set, the first
    ``param.valid_step + 1`` test batches ("Validation") and the remaining test
    batches ("Test") it prints the mean discriminator (WGAN-GP) loss, generator
    loss, real / fake classification losses, the critic gap
    ``mean(D(x)) - mean(D(G(z)))`` ("Was X"), and the classification accuracy
    on real images. It also saves one generated batch per class plus the last
    test batch next to its generated counterpart under ``<cur_dir>/graph``.
    """
    param = Parameter()

    # 1. Build models
    generator = Generator(param).model()
    discriminator = Discriminator(param).model()
    classifier = Classifier(param).model()

    # 2. Load data
    data_loader = MNISTLoader(one_hot=True)
    train_set = data_loader.train.batch(batch_size=param.batch_size,
                                        drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size,
                                      drop_remainder=True)

    # 3. Etc.
    graph_path = os.path.join(param.cur_dir, 'graph')
    os.makedirs(graph_path, exist_ok=True)

    model_path = os.path.join(param.cur_dir, 'model')
    os.makedirs(model_path, exist_ok=True)

    # 4. Load model
    gen_name = 'acgans_gen'
    dis_name = 'acgans_dis'
    cla_name = 'acgans_cla'

    graph = 'acgans'

    generator.load_weights(os.path.join(model_path, gen_name))
    discriminator.load_weights(os.path.join(model_path, dis_name))
    classifier.load_weights(os.path.join(model_path, cla_name))

    def sample_latent():
        """Draw a (batch_size, latent_dim) latent batch uniformly from [-1, 1)."""
        return tf.random.uniform(shape=(param.batch_size, param.latent_dim),
                                 minval=-1.0,
                                 maxval=1.0,
                                 dtype=tf.float32)

    # 5. Define testing step ###########################################################################################
    def testing_step(_x, _y):
        """Evaluate one batch of real images ``_x`` with one-hot float labels ``_y``.

        Returns:
            (x_tilde, cla_real, loss_dis, loss_gen, loss_cla_real,
            loss_cla_fake, was_x) — the generated images, the classifier output
            on real images, and the scalar metrics as NumPy values.
        """
        _gen_input = tf.concat([_y, sample_latent()], axis=-1, name='gen_input')
        _x_tilde = generator(_gen_input, training=False)

        # The discriminator returns (features for the classifier, critic score).
        _cla_real_logits, _dis_real = discriminator(_x, training=False)
        _cla_fake_logits, _dis_fake = discriminator(_x_tilde, training=False)

        _cla_real = classifier(_cla_real_logits, training=False)
        _cla_fake = classifier(_cla_fake_logits, training=False)

        _loss_gen = -tf.reduce_mean(_dis_fake)

        _real_loss, _fake_loss = -tf.reduce_mean(_dis_real), tf.reduce_mean(_dis_fake)
        _gp = gradient_penalty(partial(discriminator, training=False), _x,
                               _x_tilde)
        _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda

        # NOTE(review): the classifier output is fed to the loss as logits —
        # confirm Classifier does not already apply a softmax.
        _labels = tf.argmax(_y, axis=1)
        _loss_cla_real = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=_labels, logits=_cla_real))
        _loss_cla_fake = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=_labels, logits=_cla_fake))

        _was_x = -_fake_loss.numpy() - _real_loss.numpy()

        # Order must match the unpacking in `record` below (dis before gen).
        return (_x_tilde, _cla_real, _loss_dis.numpy(), _loss_gen.numpy(),
                _loss_cla_real.numpy(), _loss_cla_fake.numpy(), _was_x)

    ####################################################################################################################

    # 6. Inference
    metric_keys = ('dis', 'gen', 'cla_real', 'cla_fake', 'was_x', 'label', 'pred')
    results = {split: {key: [] for key in metric_keys}
               for split in ('train', 'valid', 'test')}

    def record(split, _x, _y_onehot):
        """Run `testing_step` on one batch, store its metrics under ``split``, return x_tilde."""
        _y = tf.cast(_y_onehot, dtype=tf.float32)
        (_x_tilde, _prediction, _loss_dis, _loss_gen,
         _loss_cla_real, _loss_cla_fake, _was_x) = testing_step(_x, _y)

        store = results[split]
        store['dis'].append(_loss_dis)
        store['gen'].append(_loss_gen)
        store['cla_real'].append(_loss_cla_real)
        store['cla_fake'].append(_loss_cla_fake)
        store['was_x'].append(_was_x)
        store['label'].append(tf.argmax(_y_onehot, axis=1).numpy())
        store['pred'].append(tf.argmax(_prediction, axis=1).numpy())
        return _x_tilde

    for x_train, y_train in train_set:
        record('train', x_train, y_train)

    x_test = x_tilde = None
    for num_test, (x_test, y_test) in enumerate(test_set):
        # The first param.valid_step + 1 test batches act as the validation split.
        split = 'valid' if num_test <= param.valid_step else 'test'
        x_tilde = record(split, x_test, y_test)

    # One generated batch per class.
    for class_idx in range(param.num_class):
        # tf.one_hot needs integer indices (and np.float no longer exists in NumPy >= 1.24).
        _indices = np.full(param.batch_size, class_idx, dtype=np.int64)
        _y = tf.one_hot(indices=_indices,
                        depth=param.num_class,
                        dtype=tf.float32)

        _gen_input = tf.concat([_y, sample_latent()], axis=-1, name='gen_input')
        _x_tilde = generator(_gen_input, training=False)

        save_decode_image_array(
            _x_tilde.numpy(),
            path=os.path.join(graph_path,
                              '{}_c{}_generated.png'.format(graph, class_idx)))

    # 7. Report (an empty split yields nan, as before)
    splits = ('train', 'valid', 'test')

    def mean_of(split, key):
        return np.mean(np.reshape(results[split][key], (-1)))

    def accuracy_of(split):
        labels = np.reshape(results[split]['label'], (-1))
        predictions = np.reshape(results[split]['pred'], (-1))
        return metrics.accuracy_score(labels, predictions)

    print(
        "[Loss cla fake] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".
        format(*(mean_of(s, 'cla_fake') for s in splits)))
    print(
        "[Loss cla real] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".
        format(*(mean_of(s, 'cla_real') for s in splits)))
    print("[Loss dis] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".
          format(*(mean_of(s, 'dis') for s in splits)))
    print("[Loss gen] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".
          format(*(mean_of(s, 'gen') for s in splits)))

    print(
        "[Was X] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".format(
            *(mean_of(s, 'was_x') for s in splits)))
    print("[Accuracy] Train: {:.05f}\t Validation: {:.05f}\t Test: {:.05f}".
          format(*(accuracy_of(s) for s in splits)))

    # 8. Draw some samples: the last test batch and its generated counterpart.
    if x_test is not None:
        save_decode_image_array(x_test.numpy(),
                                path=os.path.join(graph_path,
                                                  '{}_original.png'.format(graph)))
        save_decode_image_array(x_tilde.numpy(),
                                path=os.path.join(
                                    graph_path, '{}_generated.png'.format(graph)))