Example #1
0
def sigmoid_gan_loss(logits, real):
    """Sigmoid cross-entropy of ``logits`` against an all-ones or all-zeros target.

    Args:
        logits: logits to score; the target is built from them with
            ``O.ones_like`` / ``O.zeros_like``.
        real: truthy to compare against ones, falsy to compare against zeros.

    Returns:
        The result of ``O.sigmoid_cross_entropy_with_logits`` for these logits
        and labels; no reduction is applied here.
    """
    make_target = O.ones_like if real else O.zeros_like
    target = make_target(logits)
    return O.sigmoid_cross_entropy_with_logits(logits=logits, labels=target)
Example #2
0
            def forward(img_a, img_b):
                """Build the a<->b translation graph and, when training, its losses.

                Both inputs are divided by 255 and passed through the 'atob' and
                'btoa' generators, once directly and once as a round trip
                (a -> b -> a and b -> a -> b).  Discriminators 'a' and 'b' score the
                translated images.  The images and the fake scores are always
                registered with ``dpc``; the per-pair and total losses are only
                built and registered in the TRAIN phase.

                Args:
                    img_a: input fed to the 'atob' generator and, in training,
                        to discriminator 'a' as the real sample.
                    img_b: input fed to the 'btoa' generator and, in training,
                        to discriminator 'b' as the real sample.
                """
                img_a /= 255.
                img_b /= 255.

                img_ab = generator(img_a, name='atob', reuse=False)
                img_ba = generator(img_b, name='btoa', reuse=False)
                # Round trips reuse the generator variables created just above.
                img_aba = generator(img_ab, name='btoa', reuse=True)
                img_bab = generator(img_ba, name='atob', reuse=True)

                logit_fake_a = discriminator(img_ba, name='a', reuse=False)
                logit_fake_b = discriminator(img_ab, name='b', reuse=False)

                score_fake_a = O.sigmoid(logit_fake_a)
                score_fake_b = O.sigmoid(logit_fake_b)

                # Explicit name/value pairs instead of a locals() lookup, so a
                # renamed variable fails loudly at this line.
                outputs = [
                    ('img_a', img_a),
                    ('img_b', img_b),
                    ('img_ab', img_ab),
                    ('img_ba', img_ba),
                    ('img_aba', img_aba),
                    ('img_bab', img_bab),
                    ('score_fake_a', score_fake_a),
                    ('score_fake_b', score_fake_b),
                ]
                for name, value in outputs:
                    dpc.add_output(value, name=name)

                if env.phase is env.Phase.TRAIN:
                    logit_real_a = discriminator(img_a, name='a', reuse=True)
                    logit_real_b = discriminator(img_b, name='b', reuse=True)
                    score_real_a = O.sigmoid(logit_real_a)
                    score_real_b = O.sigmoid(logit_real_b)

                    all_g_loss = 0.
                    all_d_loss = 0.
                    # Weight of the reconstruction term in the generator loss;
                    # the adversarial term gets (1 - r_loss_ratio).
                    r_loss_ratio = 0.9

                    for pair_name, (real, fake), (logit_real, logit_fake), (score_real, score_fake) in zip(
                            ['lossa', 'lossb'],
                            [(img_a, img_aba), (img_b, img_bab)],
                            [(logit_real_a, logit_fake_a), (logit_real_b, logit_fake_b)],
                            [(score_real_a, score_fake_a), (score_real_b, score_fake_b)]):

                        with env.name_scope(pair_name):
                            d_loss_real = O.sigmoid_cross_entropy_with_logits(logits=logit_real, labels=O.ones_like(logit_real)).mean(name='d_loss_real')
                            d_loss_fake = O.sigmoid_cross_entropy_with_logits(logits=logit_fake, labels=O.zeros_like(logit_fake)).mean(name='d_loss_fake')
                            g_loss = O.sigmoid_cross_entropy_with_logits(logits=logit_fake, labels=O.ones_like(logit_fake)).mean(name='g_loss')

                            d_acc_real = (score_real > 0.5).astype('float32').mean(name='d_acc_real')
                            d_acc_fake = (score_fake < 0.5).astype('float32').mean(name='d_acc_fake')
                            g_accuracy = (score_fake > 0.5).astype('float32').mean(name='g_accuracy')

                            d_accuracy = O.identity(.5 * (d_acc_real + d_acc_fake), name='d_accuracy')
                            d_loss = O.identity(.5 * (d_loss_real + d_loss_fake), name='d_loss')

                            # Reconstruction loss between the input and its round trip.
                            # Alternatives that were tried and left disabled:
                            # r_loss = O.raw_l2_loss('raw_r_loss', real, fake).flatten2().sum(axis=1).mean(name='r_loss')
                            # r_loss = O.raw_cross_entropy_prob('raw_r_loss', real, fake).flatten2().sum(axis=1).mean(name='r_loss')
                            r_loss = O.raw_l2_loss('raw_r_loss', real, fake).mean(name='r_loss')

                            # Unweighted alternative: all_g_loss += g_loss + r_loss
                            all_g_loss += (1 - r_loss_ratio) * g_loss + r_loss_ratio * r_loss
                            all_d_loss += d_loss

                        for v in [d_loss_real, d_loss_fake, g_loss, d_acc_real, d_acc_fake, g_accuracy, d_accuracy, d_loss, r_loss]:
                            # Strip a leading 'tower/<n>/' prefix and the last two
                            # characters of the tensor name (presumably the ':0'
                            # output suffix -- confirm).  Raw string: '\d' is an
                            # invalid escape in a normal string literal.
                            dpc.add_output(v, name=re.sub(r'^tower/\d+/', '', v.name)[:-2], reduce_method='sum')

                    dpc.add_output(all_g_loss, name='g_loss', reduce_method='sum')
                    dpc.add_output(all_d_loss, name='d_loss', reduce_method='sum')
Example #3
0
            def forward(img):
                """Build a fully-connected GAN graph and register its outputs on ``dpc``.

                The generator maps a normal noise batch through fc(500, tanh) and
                fc(784, sigmoid) and reshapes the result to (-1, 28, 28, 1).  The
                discriminator is fc(500, tanh) followed by fc(1) producing logits.
                ``is_train`` and ``code_length`` are read from the enclosing scope.

                Args:
                    img: real samples; only passed to the discriminator when
                        ``is_train`` is true.
                """
                if env.phase is env.Phase.TRAIN:
                    g_batch_size = get_env('trainer.batch_size')
                else:
                    g_batch_size = 1
                noise = tf.random_normal([g_batch_size, code_length])
                z = O.as_varnode(noise)

                with env.variable_scope(GANGraphKeys.GENERATOR_VARIABLES):
                    with O.argscope(O.fc, nonlin=O.tanh):
                        gen_hidden = O.fc('fc1', z, 500)
                    gen_flat = O.fc('fc3', gen_hidden, 784, nonlin=O.sigmoid)
                    x_given_z = gen_flat.reshape(-1, 28, 28, 1)

                def discriminator(x):
                    # One tanh hidden layer, then a single linear unit as the logit.
                    with O.argscope(O.fc, nonlin=O.tanh):
                        hidden = O.fc('fc1', x, 500)
                    return O.fc('fc3', hidden, 1)

                def mean_xent(logits, make_labels):
                    # Mean sigmoid cross-entropy against a constant target built
                    # from the logits themselves.
                    labels = make_labels(logits)
                    return O.sigmoid_cross_entropy_with_logits(
                        logits=logits, labels=labels).mean()

                def fraction(mask):
                    # Share of true entries in a boolean tensor.
                    return mask.astype('float32').mean()

                if is_train:
                    # The real branch goes first so that it creates the
                    # discriminator variables; the fake branch then reuses them.
                    with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES):
                        logits_real = discriminator(img).flatten()
                        score_real = O.sigmoid(logits_real)

                with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES, reuse=is_train):
                    logits_fake = discriminator(x_given_z).flatten()
                    score_fake = O.sigmoid(logits_fake)

                if is_train:
                    with env.variable_scope('loss'):
                        d_loss_real = mean_xent(logits_real, O.ones_like)
                        d_loss_fake = mean_xent(logits_fake, O.zeros_like)
                        g_loss = mean_xent(logits_fake, O.ones_like)

                    d_acc_real = fraction(score_real > 0.5)
                    d_acc_fake = fraction(score_fake < 0.5)
                    g_accuracy = fraction(score_fake > 0.5)

                    d_accuracy = .5 * (d_acc_real + d_acc_fake)
                    d_loss = .5 * (d_loss_real + d_loss_fake)

                    train_outputs = (
                        ('d_loss', d_loss),
                        ('d_accuracy', d_accuracy),
                        ('d_acc_real', d_acc_real),
                        ('d_acc_fake', d_acc_fake),
                        ('g_loss', g_loss),
                        ('g_accuracy', g_accuracy),
                    )
                    for out_name, tensor in train_outputs:
                        dpc.add_output(tensor, name=out_name, reduce_method='sum')

                dpc.add_output(x_given_z, name='output')
                dpc.add_output(score_fake, name='score')
Example #4
0
            def forward(x, zc):
                """Build a GAN graph whose latent is a code ``zc`` concatenated with noise.

                The generator sees ``concat([zc, zn])``; the discriminator returns
                both a real/fake logit and a code prediction.  In the TRAIN phase
                the term ``entropy - cond_entropy`` (``info_gain``) is subtracted
                from both the discriminator and the generator loss.
                ``g_batch_size``, ``zn_size``, ``prior``, ``batch_prior`` and
                ``zc_distrib`` are read from the enclosing scope.

                Args:
                    x: real samples; only used in the TRAIN phase.
                    zc: code input; in the TRAIN phase it is replaced by a sample
                        from ``zc_distrib``, otherwise it is used as passed.
                """
                training = env.phase is env.Phase.TRAIN
                if training:
                    zc = zc_distrib.sample(g_batch_size, prior)

                noise_shape = [g_batch_size, zn_size]
                zn = O.random_normal(noise_shape, -1, 1)
                z = O.concat([zc, zn], axis=1, name='z')

                with env.variable_scope(GANGraphKeys.GENERATOR_VARIABLES):
                    x_given_z = generator(z)

                with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES):
                    logits_fake, code_fake = discriminator(x_given_z)
                    score_fake = O.sigmoid(logits_fake)

                for out_name, tensor in (('output', x_given_z),
                                         ('score', score_fake),
                                         ('code', code_fake)):
                    dpc.add_output(tensor, name=out_name)

                if training:
                    with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES, reuse=True):
                        logits_real, code_real = discriminator(x)
                        score_real = O.sigmoid(logits_real)

                    def mean_xent(logits, make_labels):
                        # Mean sigmoid cross-entropy against a constant target.
                        labels = make_labels(logits)
                        return O.sigmoid_cross_entropy_with_logits(
                            logits=logits, labels=labels).mean()

                    def fraction(mask):
                        # Share of true entries in a boolean tensor.
                        return mask.astype('float32').mean()

                    with env.variable_scope('loss'):
                        d_loss_real = mean_xent(logits_real, O.ones_like)
                        d_loss_fake = mean_xent(logits_fake, O.zeros_like)
                        g_loss = mean_xent(logits_fake, O.ones_like)

                        entropy = zc_distrib.cross_entropy(zc, batch_prior)
                        cond_entropy = zc_distrib.cross_entropy(zc, code_fake, process_theta=True)
                        info_gain = entropy - cond_entropy

                    d_acc_real = fraction(score_real > 0.5)
                    d_acc_fake = fraction(score_fake < 0.5)
                    g_accuracy = fraction(score_fake > 0.5)

                    d_accuracy = .5 * (d_acc_real + d_acc_fake)
                    d_loss = .5 * (d_loss_real + d_loss_fake)

                    # Both players are rewarded for a larger info_gain.
                    d_loss -= info_gain
                    g_loss -= info_gain

                    summed_outputs = [
                        (d_loss, 'd_loss'),
                        (d_accuracy, 'd_accuracy'),
                        (d_acc_real, 'd_acc_real'),
                        (d_acc_fake, 'd_acc_fake'),
                        (g_loss, 'g_loss'),
                        (g_accuracy, 'g_accuracy'),
                        (info_gain, 'g_info_gain'),
                    ]
                    for tensor, out_name in summed_outputs:
                        dpc.add_output(tensor, name=out_name, reduce_method='sum')
Example #5
0
            def forward(x):
                """Build the generator/discriminator graph and register outputs on ``dpc``.

                A normal noise batch of width ``z_dim`` (from the enclosing scope)
                is fed to ``generator``; ``discriminator`` scores the generated
                images.  The generated images and their scores are always
                registered.  In the TRAIN phase the discriminator is also applied
                to ``x`` with reused variables, and the losses and accuracies are
                registered with ``reduce_method='sum'``.

                Args:
                    x: real samples; only used in the TRAIN phase.
                """
                is_training = env.phase is env.Phase.TRAIN
                g_batch_size = get_env('trainer.batch_size') if is_training else 1
                z = O.random_normal([g_batch_size, z_dim])

                with env.variable_scope(GANGraphKeys.GENERATOR_VARIABLES):
                    img_gen = generator(z)
                # tf.summary.image('generated-samples', img_gen, max_outputs=30)

                with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES):
                    logits_fake = discriminator(img_gen)
                    score_fake = O.sigmoid(logits_fake)
                dpc.add_output(img_gen, name='output')
                dpc.add_output(score_fake, name='score')

                # Everything below is only needed for training.
                if not is_training:
                    return

                with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES, reuse=True):
                    logits_real = discriminator(x)
                    score_real = O.sigmoid(logits_real)

                with env.variable_scope('loss'):
                    target_real = O.ones_like(logits_real)
                    d_loss_real = O.sigmoid_cross_entropy_with_logits(
                        logits=logits_real, labels=target_real).mean()

                    target_fake = O.zeros_like(logits_fake)
                    d_loss_fake = O.sigmoid_cross_entropy_with_logits(
                        logits=logits_fake, labels=target_fake).mean()

                    # Generator objective: the fake logits scored against ones.
                    target_fooled = O.ones_like(logits_fake)
                    g_loss = O.sigmoid_cross_entropy_with_logits(
                        logits=logits_fake, labels=target_fooled).mean()

                real_above = score_real > 0.5
                d_acc_real = real_above.astype('float32').mean()
                fake_below = score_fake < 0.5
                d_acc_fake = fake_below.astype('float32').mean()
                fake_above = score_fake > 0.5
                g_accuracy = fake_above.astype('float32').mean()

                d_accuracy = .5 * (d_acc_real + d_acc_fake)
                d_loss = .5 * (d_loss_real + d_loss_fake)

                summed_outputs = (
                    ('d_loss', d_loss),
                    ('d_accuracy', d_accuracy),
                    ('d_acc_real', d_acc_real),
                    ('d_acc_fake', d_acc_fake),
                    ('g_loss', g_loss),
                    ('g_accuracy', g_accuracy),
                )
                for out_name, tensor in summed_outputs:
                    dpc.add_output(tensor, name=out_name, reduce_method='sum')