예제 #1
0
파일: train.py 프로젝트: remmarp/TF2_MNIST
    def validation_step(_x, _y):
        """Evaluate the label-conditioned adversarial autoencoder on one batch.

        Nothing is updated: every network runs with ``training=False``. The
        discriminator scores the label-conditioned prior sample as "real" and
        the label-conditioned encoder output as "fake".

        Args:
            _x: Batch of input images.
            _y: Labels prepended to every latent code (presumably one-hot;
                TODO confirm against the caller).

        Returns:
            Tuple ``(x_tilde, x_bar, loss_dis, loss_gen, loss_ae, w_dist)``:
            images decoded from the prior sample, reconstructions of ``_x``,
            the discriminator (WGAN-GP) loss, the generator loss, the L1
            reconstruction loss, and ``mean(D(real)) - mean(D(fake))``; the
            last four as numpy values.
        """
        prior_z = tf.random.normal(shape=(param.batch_size, param.latent_dim),
                                   mean=0.0,
                                   stddev=param.prior_noise_std,
                                   dtype=tf.float32)
        encoded_z = encoder(_x, training=False)

        # Prepend the labels so every downstream network sees the condition.
        real_code = tf.concat([_y, prior_z], axis=-1, name='z_input')
        fake_code = tf.concat([_y, encoded_z], axis=-1, name='z_tilde_input')

        reconstruction = decoder(fake_code, training=False)
        sample = decoder(real_code, training=False)

        score_real = discriminator(real_code, training=False)
        score_fake = discriminator(fake_code, training=False)

        real_term = -tf.reduce_mean(score_real)
        fake_term = tf.reduce_mean(score_fake)
        penalty = gradient_penalty(partial(discriminator, training=False),
                                   real_code, fake_code)

        gen_loss = -tf.reduce_mean(score_fake)
        dis_loss = (real_term + fake_term) + penalty * param.w_gp_lambda
        ae_loss = tf.reduce_mean(tf.abs(tf.subtract(_x, reconstruction)))

        # Estimated Wasserstein gap: mean(D(real)) - mean(D(fake)).
        w_dist = -fake_term.numpy() - real_term.numpy()
        return (sample, reconstruction, dis_loss.numpy(), gen_loss.numpy(),
                ae_loss.numpy(), w_dist)
예제 #2
0
    def training_step(_x):
        """Run one joint update of the generator and the discriminator.

        When ``w_gp`` (read from the enclosing scope) is truthy, the WGAN-GP
        objective is used:
        - critic: ``-mean(D(x)) + mean(D(G(z)))`` plus ``param.w_gp_lambda``
          times the gradient penalty;
        - generator: ``-mean(D(G(z)))``.
        Otherwise the standard GAN losses built from ``cross_entropy`` are
        used (real -> 1, fake -> 0 for the discriminator; fake -> 1 for the
        generator).

        Args:
            _x: Batch of real images.
        """
        with tf.GradientTape() as _gen_tape, tf.GradientTape() as _dis_tape:
            _noise = tf.random.normal(shape=(param.batch_size,
                                             param.latent_dim),
                                      mean=0.0,
                                      stddev=param.prior_noise_std,
                                      dtype=tf.float32)

            _x_tilde = generator(_noise, training=True)

            _dis_real = discriminator(_x, training=True)
            _dis_fake = discriminator(_x_tilde, training=True)

            # PEP 8: test the flag's truth value instead of `is True`, so any
            # truthy flag selects the WGAN-GP objective.
            if w_gp:
                _real_loss = -tf.reduce_mean(_dis_real)
                _fake_loss = tf.reduce_mean(_dis_fake)
                _gp = gradient_penalty(partial(discriminator, training=True),
                                       _x, _x_tilde)

                _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
                _loss_gen = -tf.reduce_mean(_dis_fake)
            else:
                # cross_entropy comes from the enclosing scope; presumably
                # BinaryCrossentropy(from_logits=True) - TODO confirm.
                _loss_dis = cross_entropy(
                    tf.ones_like(_dis_real), _dis_real) + cross_entropy(
                        tf.zeros_like(_dis_fake), _dis_fake)
                _loss_gen = cross_entropy(tf.ones_like(_dis_fake), _dis_fake)

        # Compute both gradients before applying either update.
        _grad_gen = _gen_tape.gradient(_loss_gen, var_gen)
        _grad_dis = _dis_tape.gradient(_loss_dis, var_dis)

        opt_gen.apply_gradients(zip(_grad_gen, var_gen))
        opt_dis.apply_gradients(zip(_grad_dis, var_dis))
예제 #3
0
    def validation_step(_x):
        """Evaluate the generator and the discriminator on one batch.

        No weights are updated; both networks run with ``training=False``.
        Uses the WGAN-GP losses when ``w_gp`` (read from the enclosing scope)
        is truthy, otherwise the standard GAN losses built from
        ``cross_entropy``.

        Args:
            _x: Batch of real images.

        Returns:
            ``(x_tilde, loss_dis, loss_gen)``: the generated batch and the
            discriminator and generator losses as numpy values.
        """
        _noise = tf.random.normal(shape=(param.batch_size, param.latent_dim),
                                  mean=0.0,
                                  stddev=param.prior_noise_std,
                                  dtype=tf.float32)

        _x_tilde = generator(_noise, training=False)

        _dis_real = discriminator(_x, training=False)
        _dis_fake = discriminator(_x_tilde, training=False)

        # PEP 8: test the flag's truth value instead of `is True`.
        if w_gp:
            _real_loss = -tf.reduce_mean(_dis_real)
            _fake_loss = tf.reduce_mean(_dis_fake)
            _gp = gradient_penalty(partial(discriminator, training=False), _x,
                                   _x_tilde)

            _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda
            _loss_gen = -tf.reduce_mean(_dis_fake)
        else:
            _loss_dis = cross_entropy(tf.ones_like(_dis_real),
                                      _dis_real) + cross_entropy(
                                          tf.zeros_like(_dis_fake), _dis_fake)
            _loss_gen = cross_entropy(tf.ones_like(_dis_fake), _dis_fake)

        return _x_tilde, _loss_dis.numpy(), _loss_gen.numpy()
예제 #4
0
def D_graph(sess, phd, gp_weight=10.0):
    """Build the discriminator's training loss and its Adam update op.

    Graph-mode (TF1 API) construction:
    - The generator (``Genc`` then ``Gdec``) re-renders ``train_images``
      with every label negated.
    - ``D`` is trained on ``-mean(D(real)) + mean(D(fake))`` plus a weighted
      gradient penalty.
    - It also gets a sigmoid cross-entropy loss against ``train_labels``
      on the real images, plus the sum of ``D.reg_loss``.

    Args:
        sess: Not used by this function; kept for call-site compatibility.
        phd: Mapping of placeholders; ``'is_training_d'`` and ``'lr_d'`` are
            read.
        gp_weight: Coefficient of the gradient-penalty term. Defaults to
            10.0, the value previously hard-coded.

    Returns:
        ``(final_loss, update_op)``; ``update_op`` minimises ``final_loss``
        over ``D.vars`` only, with Adam (``beta1=0.5``).
    """
    # Presumably train_labels are binary attributes in {0, 1}: map them to
    # {-1, 1} and negate to get the target attributes for the fakes
    # (TODO confirm).
    real_labels = train_labels * 2 - 1
    fake_labels = -real_labels

    latent = Genc.build(train_images, phd['is_training_d'])
    fake_images = Gdec.build(latent, fake_labels, phd['is_training_d'])
    train_gan_logit, train_cls_logit = D.build(train_images,
                                               phd['is_training_d'])
    fake_gan_logit, fake_cls_logit = D.build(fake_images, phd['is_training_d'])

    train_gan_loss = -tf.reduce_mean(train_gan_logit)
    fake_gan_loss = tf.reduce_mean(fake_gan_logit)
    # '1-gp' and 'line' are passed straight through to util.gradient_penalty.
    # Their meaning is defined there (presumably penalty mode and
    # interpolation scheme; TODO confirm).
    grad_penalty = util.gradient_penalty(
        lambda x: D.build(x, phd['is_training_d'])[0], train_images,
        fake_images, '1-gp', 'line')
    cls_loss = tf.losses.sigmoid_cross_entropy(
        tf.expand_dims(train_labels, axis=-1), train_cls_logit)
    reg_loss = tf.reduce_sum(D.reg_loss)

    final_loss = (train_gan_loss + fake_gan_loss + gp_weight * grad_penalty +
                  cls_loss + reg_loss)
    update_op = tf.train.AdamOptimizer(phd['lr_d'],
                                       beta1=0.5).minimize(final_loss,
                                                           var_list=D.vars)

    return final_loss, update_op
예제 #5
0
파일: train.py 프로젝트: remmarp/TF2_MNIST
    def training_step(_x, _y):
        """Run one update of the generator, the discriminator and the classifier.

        How each network is trained:
        - Generator: conditioned on ``_y`` by prepending it to a uniform
          sample in [-1, 1).
        - Discriminator: returns a pair ``(classifier input, critic score)``;
          the critic is trained with the WGAN-GP objective.
        - Classifier: trained with softmax cross-entropy on both the real and
          the generated images.

        Args:
            _x: Batch of real images.
            _y: Labels for the batch; ``tf.argmax(_y, axis=1)`` is used as the
                class index, so presumably one-hot (TODO confirm).
        """
        gen_tape = tf.GradientTape()
        dis_tape = tf.GradientTape()
        cla_tape = tf.GradientTape()
        # Entering the tapes starts recording; all three watch the same pass.
        with gen_tape, dis_tape, cla_tape:
            latent = tf.random.uniform(shape=(param.batch_size,
                                              param.latent_dim),
                                       minval=-1.0,
                                       maxval=1.0,
                                       dtype=tf.float32)

            conditioned = tf.concat([_y, latent], axis=-1, name='gen_input')
            fakes = generator(conditioned, training=True)

            real_head, real_score = discriminator(_x, training=True)
            fake_head, fake_score = discriminator(fakes, training=True)

            real_class = classifier(real_head, training=True)
            fake_class = classifier(fake_head, training=True)

            class_ids = tf.argmax(_y, axis=1)
            cla_real = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=class_ids, logits=real_class))
            cla_fake = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=class_ids, logits=fake_class))

            critic_real = -tf.reduce_mean(real_score)
            critic_fake = tf.reduce_mean(fake_score)
            penalty = gradient_penalty(partial(discriminator, training=True),
                                       _x, fakes)

            cla_loss = cla_real + cla_fake
            # NOTE(review): the generator loss has no classification term, so
            # the generator gets no gradient from the classifier. Confirm this
            # is intended (AC-GAN variants usually add one).
            gen_loss = -tf.reduce_mean(fake_score)
            dis_loss = (critic_real + critic_fake) + penalty * param.w_gp_lambda

        gen_grads = gen_tape.gradient(gen_loss, var_gen)
        dis_grads = dis_tape.gradient(dis_loss, var_dis)
        cla_grads = cla_tape.gradient(cla_loss, var_cla)

        opt_dis.apply_gradients(zip(dis_grads, var_dis))
        opt_gen.apply_gradients(zip(gen_grads, var_gen))
        opt_cla.apply_gradients(zip(cla_grads, var_cla))
예제 #6
0
파일: train.py 프로젝트: remmarp/TF2_MNIST
    def training_step(_x, _y):
        """Run one update of the autoencoder, the generator and the discriminator.

        Losses:
        - Reconstruction: L1 between ``_x`` and ``decoder(encoder(_x))``,
          conditioned on ``_y``.
        - Discriminator: WGAN-GP objective, scoring label-conditioned prior
          samples as "real" and label-conditioned encoder codes as "fake".
        - Generator: ``-mean(D(fake))``, applied to ``var_gen`` (presumably
          the encoder's weights; TODO confirm).

        Args:
            _x: Batch of input images.
            _y: Labels prepended to every latent code (presumably one-hot;
                TODO confirm against the caller).
        """
        ae_tape = tf.GradientTape()
        gen_tape = tf.GradientTape()
        dis_tape = tf.GradientTape()
        # Entering the tapes starts recording; all three watch the same pass.
        with ae_tape, gen_tape, dis_tape:
            prior_z = tf.random.normal(shape=(param.batch_size,
                                              param.latent_dim),
                                       mean=0.0,
                                       stddev=param.prior_noise_std,
                                       dtype=tf.float32)

            encoded_z = encoder(_x, training=True)

            real_code = tf.concat([_y, prior_z], axis=-1, name='z_input')
            fake_code = tf.concat([_y, encoded_z],
                                  axis=-1,
                                  name='z_tilde_input')

            reconstruction = decoder(fake_code, training=True)

            score_real = discriminator(real_code, training=True)
            score_fake = discriminator(fake_code, training=True)

            real_term = -tf.reduce_mean(score_real)
            fake_term = tf.reduce_mean(score_fake)
            penalty = gradient_penalty(partial(discriminator, training=True),
                                       real_code, fake_code)

            gen_loss = -tf.reduce_mean(score_fake)
            dis_loss = (real_term + fake_term) + penalty * param.w_gp_lambda
            ae_loss = tf.reduce_mean(tf.abs(tf.subtract(_x, reconstruction)))

        # All gradients are taken before any optimizer step.
        ae_grads = ae_tape.gradient(ae_loss, var_ae)
        gen_grads = gen_tape.gradient(gen_loss, var_gen)
        dis_grads = dis_tape.gradient(dis_loss, var_dis)

        opt_ae.apply_gradients(zip(ae_grads, var_ae))
        opt_gen.apply_gradients(zip(gen_grads, var_gen))
        opt_dis.apply_gradients(zip(dis_grads, var_dis))
예제 #7
0
파일: train.py 프로젝트: remmarp/TF2_MNIST
    def validation_step(_x, _y):
        """Evaluate the conditional generator, critic and classifier on a batch.

        No weights are updated; every network runs with ``training=False``.

        Args:
            _x: Batch of real images.
            _y: Labels for the batch; ``tf.argmax(_y, axis=1)`` is used as the
                class index, so presumably one-hot (TODO confirm).

        Returns:
            ``(x_tilde, loss_gen, loss_dis, loss_cla_real, loss_cla_fake,
            w_dist)``:
            - the generated batch;
            - the generator loss and the critic (WGAN-GP) loss;
            - the classifier's cross-entropy on real and on generated images;
            - ``mean(D(real)) - mean(D(fake))``.
            All entries after ``x_tilde`` are numpy values.
        """
        _z = tf.random.uniform(shape=(param.batch_size, param.latent_dim),
                               minval=-1.0,
                               maxval=1.0,
                               dtype=tf.float32)

        # Condition the generator by prepending the labels to the latent sample.
        _gen_input = tf.concat([_y, _z], axis=-1, name='gen_input')
        _x_tilde = generator(_gen_input, training=False)

        # The discriminator returns (classifier input, critic score).
        _cla_real_logits, _dis_real = discriminator(_x, training=False)
        _cla_fake_logits, _dis_fake = discriminator(_x_tilde, training=False)

        _cla_real = classifier(_cla_real_logits, training=False)
        _cla_fake = classifier(_cla_fake_logits, training=False)

        _loss_gen = -tf.reduce_mean(_dis_fake)

        _real_loss = -tf.reduce_mean(_dis_real)
        _fake_loss = tf.reduce_mean(_dis_fake)
        _gp = gradient_penalty(partial(discriminator, training=False), _x,
                               _x_tilde)
        _loss_dis = (_real_loss + _fake_loss) + _gp * param.w_gp_lambda

        # The real and generated batches share the same class indices.
        _labels = tf.argmax(_y, axis=1)
        _loss_cla_real = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=_labels,
                                                           logits=_cla_real))
        _loss_cla_fake = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=_labels,
                                                           logits=_cla_fake))

        return (_x_tilde, _loss_gen.numpy(), _loss_dis.numpy(),
                _loss_cla_real.numpy(), _loss_cla_fake.numpy(),
                -_fake_loss.numpy() - _real_loss.numpy())
예제 #8
0
def inference(w_gp=False):
    """Evaluate saved GAN weights on MNIST and save one batch of samples.

    Steps:
    - Load the generator and discriminator weights from
      ``<param.cur_dir>/model``.
    - Print the mean discriminator and generator losses on the training
      set and on the test set. The first test batches are reported as
      "Validation", the rest as "Test".
    - Write the last test batch and the images generated alongside it to
      ``<param.cur_dir>/graph``.

    Args:
        w_gp: If truthy, load the ``gan_w_gp``/``dis_w_gp`` checkpoints and
            score them with the WGAN-GP losses. Otherwise load ``gan``/``dis``
            and score them with binary cross-entropy.
    """
    # Normalise once so the checkpoint choice and the loss choice below can
    # never disagree. Mixing `is True` and `is False` tests would load the
    # plain-GAN weights but score them with WGAN-GP losses for e.g. w_gp=1.
    w_gp = bool(w_gp)
    param = Parameter()

    # 1. Build models
    generator = Generator(param).model()
    discriminator = Discriminator(param).model()

    # 2. Load data
    data_loader = MNISTLoader(one_hot=False)
    train_set = data_loader.train.batch(batch_size=param.batch_size,
                                        drop_remainder=True)
    test_set = data_loader.test.batch(batch_size=param.batch_size,
                                      drop_remainder=True)

    # 3. Output directories
    graph_path = os.path.join(param.cur_dir, 'graph')
    os.makedirs(graph_path, exist_ok=True)

    model_path = os.path.join(param.cur_dir, 'model')
    os.makedirs(model_path, exist_ok=True)

    # 4. Load model
    if w_gp:
        gen_name, dis_name, graph = 'gan_w_gp', 'dis_w_gp', 'gan_w_gp'
    else:
        gen_name, dis_name, graph = 'gan', 'dis', 'gan'

    generator.load_weights(os.path.join(model_path, gen_name))
    discriminator.load_weights(os.path.join(model_path, dis_name))

    # 5. Define loss
    cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    def _evaluate_batch(x_real):
        """Generate one batch; return (x_tilde, loss_dis, loss_gen) as numpy."""
        # Sample from the same prior that the training/validation steps use.
        noise = tf.random.normal(shape=(param.batch_size, param.latent_dim),
                                 mean=0.0,
                                 stddev=param.prior_noise_std,
                                 dtype=tf.float32)
        x_tilde = generator(noise, training=False)

        dis_real = discriminator(x_real, training=False)
        dis_fake = discriminator(x_tilde, training=False)

        if w_gp:
            real_loss = -tf.reduce_mean(dis_real)
            fake_loss = tf.reduce_mean(dis_fake)
            gp = gradient_penalty(partial(discriminator, training=False),
                                  x_real, x_tilde)

            loss_dis = (real_loss + fake_loss) + gp * param.w_gp_lambda
            loss_gen = -tf.reduce_mean(dis_fake)
        else:
            loss_dis = cross_entropy(tf.ones_like(dis_real),
                                     dis_real) + cross_entropy(
                                         tf.zeros_like(dis_fake), dis_fake)
            loss_gen = cross_entropy(tf.ones_like(dis_fake), dis_fake)

        return x_tilde, loss_dis.numpy(), loss_gen.numpy()

    # 6. Inference
    train_dis_loss, train_gen_loss = [], []
    for x_train, _ in train_set:
        _, loss_dis, loss_gen = _evaluate_batch(x_train)
        train_dis_loss.append(loss_dis)
        train_gen_loss.append(loss_gen)

    valid_dis_loss, valid_gen_loss = [], []
    test_dis_loss, test_gen_loss = [], []
    x_test = x_tilde = None
    for num_test, (x_test, _) in enumerate(test_set):
        x_tilde, loss_dis, loss_gen = _evaluate_batch(x_test)

        # NOTE(review): `<=` sends the first param.valid_step + 1 test batches
        # to validation. Confirm whether `<` was intended.
        if num_test <= param.valid_step:
            valid_dis_loss.append(loss_dis)
            valid_gen_loss.append(loss_gen)
        else:
            test_dis_loss.append(loss_dis)
            test_gen_loss.append(loss_gen)

    # 7. Report (an empty list yields nan, as before)
    train_dis_loss = np.mean(train_dis_loss)
    valid_dis_loss = np.mean(valid_dis_loss)
    test_dis_loss = np.mean(test_dis_loss)

    train_gen_loss = np.mean(train_gen_loss)
    valid_gen_loss = np.mean(valid_gen_loss)
    test_gen_loss = np.mean(test_gen_loss)

    print("[Loss dis] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".
          format(train_dis_loss, valid_dis_loss, test_dis_loss))
    print("[Loss gen] Train: {:.06f}\t Validation: {:.06f}\t Test: {:.06f}".
          format(train_gen_loss, valid_gen_loss, test_gen_loss))

    # 8. Draw some samples (the last test batch and its generated images)
    if x_test is None:
        # An empty test set previously crashed here with a NameError.
        print("No test batches available; skipping sample images.")
        return

    save_decode_image_array(x_test.numpy(),
                            path=os.path.join(graph_path,
                                              '{}_original.png'.format(graph)))
    save_decode_image_array(x_tilde.numpy(),
                            path=os.path.join(
                                graph_path, '{}_generated.png'.format(graph)))