Example #1
            def forward(img_a, img_b):
                img_a /= 255.
                img_b /= 255.

                img_ab = generator(img_a, name='atob', reuse=False)
                img_ba = generator(img_b, name='btoa', reuse=False)
                img_aba = generator(img_ab, name='btoa', reuse=True)
                img_bab = generator(img_ba, name='atob', reuse=True)

                logit_fake_a = discriminator(img_ba, name='a', reuse=False)
                logit_fake_b = discriminator(img_ab, name='b', reuse=False)

                score_fake_a = O.sigmoid(logit_fake_a)
                score_fake_b = O.sigmoid(logit_fake_b)

                # Export the intermediate images and scores, looking each
                # tensor up by name in the local namespace.
                for name in ['img_a', 'img_b', 'img_ab', 'img_ba', 'img_aba', 'img_bab', 'score_fake_a', 'score_fake_b']:
                    dpc.add_output(locals()[name], name=name)

                if env.phase is env.Phase.TRAIN:
                    logit_real_a = discriminator(img_a, name='a', reuse=True)
                    logit_real_b = discriminator(img_b, name='b', reuse=True)
                    score_real_a = O.sigmoid(logit_real_a)
                    score_real_b = O.sigmoid(logit_real_b)

                    all_g_loss = 0.
                    all_d_loss = 0.
                    r_loss_ratio = 0.9

                    for pair_name, (real, fake), (logit_real, logit_fake), (score_real, score_fake) in zip(
                            ['lossa', 'lossb'],
                            [(img_a, img_aba), (img_b, img_bab)],
                            [(logit_real_a, logit_fake_a), (logit_real_b, logit_fake_b)],
                            [(score_real_a, score_fake_a), (score_real_b, score_fake_b)]):

                        with env.name_scope(pair_name):
                            d_loss_real = O.sigmoid_cross_entropy_with_logits(logits=logit_real, labels=O.ones_like(logit_real)).mean(name='d_loss_real')
                            d_loss_fake = O.sigmoid_cross_entropy_with_logits(logits=logit_fake, labels=O.zeros_like(logit_fake)).mean(name='d_loss_fake')
                            g_loss = O.sigmoid_cross_entropy_with_logits(logits=logit_fake, labels=O.ones_like(logit_fake)).mean(name='g_loss')

                            d_acc_real = (score_real > 0.5).astype('float32').mean(name='d_acc_real')
                            d_acc_fake = (score_fake < 0.5).astype('float32').mean(name='d_acc_fake')
                            g_accuracy = (score_fake > 0.5).astype('float32').mean(name='g_accuracy')

                            d_accuracy = O.identity(.5 * (d_acc_real + d_acc_fake), name='d_accuracy')
                            d_loss = O.identity(.5 * (d_loss_real + d_loss_fake), name='d_loss')

                            # r_loss = O.raw_l2_loss('raw_r_loss', real, fake).flatten2().sum(axis=1).mean(name='r_loss')
                            r_loss = O.raw_l2_loss('raw_r_loss', real, fake).mean(name='r_loss')
                            # r_loss = O.raw_cross_entropy_prob('raw_r_loss', real, fake).flatten2().sum(axis=1).mean(name='r_loss')

                            # all_g_loss += g_loss + r_loss
                            all_g_loss += (1 - r_loss_ratio) * g_loss + r_loss_ratio * r_loss
                            all_d_loss += d_loss

                        for v in [d_loss_real, d_loss_fake, g_loss, d_acc_real, d_acc_fake, g_accuracy, d_accuracy, d_loss, r_loss]:
                            # Strip the 'tower/<i>/' prefix and the ':0' suffix from the tensor name.
                            dpc.add_output(v, name=re.sub(r'^tower/\d+/', '', v.name)[:-2], reduce_method='sum')

                    dpc.add_output(all_g_loss, name='g_loss', reduce_method='sum')
                    dpc.add_output(all_d_loss, name='d_loss', reduce_method='sum')
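
The O.sigmoid_cross_entropy_with_logits calls above implement the standard non-saturating GAN losses: the discriminator pushes real logits toward label 1 and fake logits toward label 0, while the generator reuses the fake logits with all-ones labels; the cycle-reconstruction term r_loss is then mixed in with weight r_loss_ratio. As a minimal sketch (NumPy standing in for the O wrapper), the numerically stable form of this loss is:

    import numpy as np

    def sigmoid_cross_entropy_with_logits(logits, labels):
        # Stable evaluation of -z*log(sigmoid(x)) - (1-z)*log(1 - sigmoid(x)),
        # i.e. max(x, 0) - x*z + log(1 + exp(-|x|)).
        return (np.maximum(logits, 0) - logits * labels
                + np.log1p(np.exp(-np.abs(logits))))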
Example #2
            def forward(img):
                g_batch_size = get_env('trainer.batch_size') if env.phase is env.Phase.TRAIN else 1
                z = O.as_varnode(tf.random_normal([g_batch_size, code_length]))
                with env.variable_scope(GANGraphKeys.GENERATOR_VARIABLES):
                    _ = z
                    with O.argscope(O.fc, nonlin=O.tanh):
                        _ = O.fc('fc1', _, 500)
                    _ = O.fc('fc3', _, 784, nonlin=O.sigmoid)
                    x_given_z = _.reshape(-1, 28, 28, 1)

                def discriminator(x):
                    _ = x
                    with O.argscope(O.fc, nonlin=O.tanh):
                        _ = O.fc('fc1', _, 500)
                    _ = O.fc('fc3', _, 1)
                    logits = _
                    return logits

                if is_train:
                    with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES):
                        logits_real = discriminator(img).flatten()
                        score_real = O.sigmoid(logits_real)

                with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES, reuse=is_train):
                    logits_fake = discriminator(x_given_z).flatten()
                    score_fake = O.sigmoid(logits_fake)

                if is_train:
                    # build loss
                    with env.variable_scope('loss'):
                        d_loss_real = O.sigmoid_cross_entropy_with_logits(
                            logits=logits_real, labels=O.ones_like(logits_real)).mean()
                        d_loss_fake = O.sigmoid_cross_entropy_with_logits(
                            logits=logits_fake, labels=O.zeros_like(logits_fake)).mean()
                        g_loss = O.sigmoid_cross_entropy_with_logits(
                            logits=logits_fake, labels=O.ones_like(logits_fake)).mean()

                    d_acc_real = (score_real > 0.5).astype('float32').mean()
                    d_acc_fake = (score_fake < 0.5).astype('float32').mean()
                    g_accuracy = (score_fake > 0.5).astype('float32').mean()

                    d_accuracy = .5 * (d_acc_real + d_acc_fake)
                    d_loss = .5 * (d_loss_real + d_loss_fake)

                    dpc.add_output(d_loss, name='d_loss', reduce_method='sum')
                    dpc.add_output(d_accuracy, name='d_accuracy', reduce_method='sum')
                    dpc.add_output(d_acc_real, name='d_acc_real', reduce_method='sum')
                    dpc.add_output(d_acc_fake, name='d_acc_fake', reduce_method='sum')
                    dpc.add_output(g_loss, name='g_loss', reduce_method='sum')
                    dpc.add_output(g_accuracy, name='g_accuracy', reduce_method='sum')

                dpc.add_output(x_given_z, name='output')
                dpc.add_output(score_fake, name='score')
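
The reuse=is_train flag on the second discriminator scope makes the real and fake branches share one set of weights. A minimal sketch of the same pattern in plain TensorFlow 1.x (illustrative only, not the wrapper's API; img and x_given_z as above):

    import tensorflow as tf  # TF 1.x API

    def discriminator(x):
        h = tf.layers.dense(x, 500, activation=tf.tanh, name='fc1')
        return tf.layers.dense(h, 1, name='fc3')

    with tf.variable_scope('discriminator'):               # creates fc1/fc3
        logits_real = discriminator(img)
    with tf.variable_scope('discriminator', reuse=True):   # shares them
        logits_fake = discriminator(x_given_z)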
Example #3
            def forward(x, zc):
                if env.phase is env.Phase.TRAIN:
                    zc = zc_distrib.sample(g_batch_size, prior)

                zn = O.random_normal([g_batch_size, zn_size], -1, 1)
                z = O.concat([zc, zn], axis=1, name='z')
                
                with env.variable_scope(GANGraphKeys.GENERATOR_VARIABLES):
                    x_given_z = generator(z)

                with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES):
                    logits_fake, code_fake = discriminator(x_given_z)
                    score_fake = O.sigmoid(logits_fake)

                dpc.add_output(x_given_z, name='output')
                dpc.add_output(score_fake, name='score')
                dpc.add_output(code_fake, name='code')

                if env.phase is env.Phase.TRAIN:
                    with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES, reuse=True):
                        logits_real, code_real = discriminator(x)
                        score_real = O.sigmoid(logits_real)

                    # build loss
                    with env.variable_scope('loss'):
                        d_loss_real = O.sigmoid_cross_entropy_with_logits(
                            logits=logits_real, labels=O.ones_like(logits_real)).mean()
                        d_loss_fake = O.sigmoid_cross_entropy_with_logits(
                            logits=logits_fake, labels=O.zeros_like(logits_fake)).mean()
                        g_loss = O.sigmoid_cross_entropy_with_logits(
                            logits=logits_fake, labels=O.ones_like(logits_fake)).mean()

                        entropy = zc_distrib.cross_entropy(zc, batch_prior)
                        cond_entropy = zc_distrib.cross_entropy(zc, code_fake, process_theta=True)
                        info_gain = entropy - cond_entropy

                    d_acc_real = (score_real > 0.5).astype('float32').mean()
                    d_acc_fake = (score_fake < 0.5).astype('float32').mean()
                    g_accuracy = (score_fake > 0.5).astype('float32').mean()

                    d_accuracy = .5 * (d_acc_real + d_acc_fake)
                    d_loss = .5 * (d_loss_real + d_loss_fake)

                    d_loss -= info_gain
                    g_loss -= info_gain

                    dpc.add_output(d_loss, name='d_loss', reduce_method='sum')
                    dpc.add_output(d_accuracy, name='d_accuracy', reduce_method='sum')
                    dpc.add_output(d_acc_real, name='d_acc_real', reduce_method='sum')
                    dpc.add_output(d_acc_fake, name='d_acc_fake', reduce_method='sum')
                    dpc.add_output(g_loss, name='g_loss', reduce_method='sum')
                    dpc.add_output(g_accuracy, name='g_accuracy', reduce_method='sum')
                    dpc.add_output(info_gain, name='g_info_gain', reduce_method='sum')
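
The info_gain term is the variational lower bound on mutual information from InfoGAN:

    I(zc; G(z, zc)) >= H(zc) - H(zc | G(z, zc)) ~= entropy - cond_entropy

Subtracting it from both g_loss and d_loss (the recognition head producing code_fake lives inside the discriminator scope) trains the generator and that head to keep the latent code zc recoverable from the generated sample.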
Example #4
            def forward(x):
                g_batch_size = get_env('trainer.batch_size') if env.phase is env.Phase.TRAIN else 1
                z = O.random_normal([g_batch_size, z_dim])

                with env.variable_scope(GANGraphKeys.GENERATOR_VARIABLES):
                    img_gen = generator(z)
                # tf.summary.image('generated-samples', img_gen, max_outputs=30)

                with env.variable_scope(GANGraphKeys.DISCRIMINATOR_VARIABLES):
                    logits_fake = discriminator(img_gen)
                    score_fake = O.sigmoid(logits_fake)
                dpc.add_output(img_gen, name='output')
                dpc.add_output(score_fake, name='score')

                if env.phase is env.Phase.TRAIN:
                    with env.variable_scope(
                            GANGraphKeys.DISCRIMINATOR_VARIABLES, reuse=True):
                        logits_real = discriminator(x)
                        score_real = O.sigmoid(logits_real)
                    # build loss
                    with env.variable_scope('loss'):
                        d_loss_real = O.sigmoid_cross_entropy_with_logits(
                            logits=logits_real,
                            labels=O.ones_like(logits_real)).mean()
                        d_loss_fake = O.sigmoid_cross_entropy_with_logits(
                            logits=logits_fake,
                            labels=O.zeros_like(logits_fake)).mean()
                        g_loss = O.sigmoid_cross_entropy_with_logits(
                            logits=logits_fake,
                            labels=O.ones_like(logits_fake)).mean()

                    d_acc_real = (score_real > 0.5).astype('float32').mean()
                    d_acc_fake = (score_fake < 0.5).astype('float32').mean()
                    g_accuracy = (score_fake > 0.5).astype('float32').mean()

                    d_accuracy = .5 * (d_acc_real + d_acc_fake)
                    d_loss = .5 * (d_loss_real + d_loss_fake)

                    dpc.add_output(d_loss, name='d_loss', reduce_method='sum')
                    dpc.add_output(d_accuracy,
                                   name='d_accuracy',
                                   reduce_method='sum')
                    dpc.add_output(d_acc_real,
                                   name='d_acc_real',
                                   reduce_method='sum')
                    dpc.add_output(d_acc_fake,
                                   name='d_acc_fake',
                                   reduce_method='sum')
                    dpc.add_output(g_loss, name='g_loss', reduce_method='sum')
                    dpc.add_output(g_accuracy,
                                   name='g_accuracy',
                                   reduce_method='sum')
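
GANGraphKeys.GENERATOR_VARIABLES and GANGraphKeys.DISCRIMINATOR_VARIABLES place the two sub-networks in separate variable scopes so that g_loss and d_loss can be minimized over disjoint parameter sets. A minimal sketch of that split in TF 1.x terms (scope names and hyper-parameters are illustrative, not the wrapper's defaults):

    import tensorflow as tf  # TF 1.x API

    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator')

    # Alternate the two updates during training.
    g_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(g_loss, var_list=g_vars)
    d_opt = tf.train.AdamOptimizer(2e-4, beta1=0.5).minimize(d_loss, var_list=d_vars)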
Example #5
def image_diff(origin, canvas_logits):
    """
    Get the difference between original image and the canvas,
    note that the canvas is given without sigmoid activation (we will do it inside)

    :param origin: original image: batch_size, h, w, c
    :param canvas_logits: canvas logits: batch_size, h, w, c
    :return: the difference: origin - sigmoid(logits)
    """
    sigmoid_canvas = O.sigmoid(canvas_logits)
    return origin - sigmoid_canvas
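
A quick sanity check with NumPy standing in for O.sigmoid: a zero canvas has logits 0, so the sigmoid is 0.5 everywhere and the diff against a black image is -0.5:

    import numpy as np

    def sigmoid(x):
        return 1.0 / (1.0 + np.exp(-x))

    origin = np.zeros((1, 2, 2, 1), dtype='float32')         # black image
    canvas_logits = np.zeros((1, 2, 2, 1), dtype='float32')  # sigmoid(0) == 0.5
    diff = origin - sigmoid(canvas_logits)                   # all entries == -0.5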
Example #6
            def generator(z):
                w_init = O.truncated_normal_initializer(stddev=0.02)
                with O.argscope(O.conv2d, O.deconv2d, kernel=4, stride=2, W=w_init),\
                     O.argscope(O.fc, W=w_init):

                    _ = z
                    _ = O.fc('fc1', _, 1024, nonlin=O.bn_relu)
                    _ = O.fc('fc2', _, 128 * 7 * 7, nonlin=O.bn_relu)
                    _ = O.reshape(_, [-1, 7, 7, 128])
                    _ = O.deconv2d('deconv1', _, 64, nonlin=O.bn_relu)
                    _ = O.deconv2d('deconv2', _, 1)
                    _ = O.sigmoid(_, 'out')
                return _

            def decoder(z):
                w_init = O.truncated_normal_initializer(stddev=0.02)
                with O.argscope(O.conv2d, O.deconv2d, kernel=4, stride=2, W=w_init),\
                     O.argscope(O.fc, W=w_init):

                    _ = z
                    _ = O.deconv2d('deconv1', _, 256, nonlin=O.bn_relu)
                    _ = O.deconv2d('deconv2', _, 128, nonlin=O.bn_relu)
                    _ = O.deconv2d('deconv3', _, 64, nonlin=O.bn_relu)
                    _ = O.deconv2d('deconv4', _, c)
                    _ = O.sigmoid(_, name='out')
                x = _
                return x
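
With kernel 4 and stride 2, each deconvolution doubles the spatial resolution (assuming 'SAME'-style padding in O.deconv2d, which the 7 -> 14 -> 28 progression implies). A shape trace for generator(z):

    # z       -> fc1     -> (batch, 1024)
    #         -> fc2     -> (batch, 128 * 7 * 7)
    #         -> reshape -> (batch, 7, 7, 128)
    #         -> deconv1 -> (batch, 14, 14, 64)
    #         -> deconv2 -> (batch, 28, 28, 1)
    #         -> sigmoid -> pixel values in (0, 1)

The final sigmoid keeps the output in (0, 1), matching images normalized to that range.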
Example #8
            def forward(img=None):
                encoder = O.BasicLSTMCell(256)
                decoder = O.BasicLSTMCell(256)

                batch_size = img.shape[0] if is_train else 1

                canvas = O.zeros(shape=O.canonize_sym_shape([batch_size, h, w, c]), dtype='float32')
                enc_state = encoder.zero_state(batch_size, dtype='float32')
                dec_state = decoder.zero_state(batch_size, dtype='float32')
                enc_h, dec_h = enc_state[1], dec_state[1]

                def encode(x, state, reuse):
                    with env.variable_scope('read_encoder', reuse=reuse):
                        return encoder(x, state)

                def decode(x, state, reuse):
                    with env.variable_scope('write_decoder', reuse=reuse):
                        return decoder(x, state)

                all_sqr_mus, all_vars, all_log_vars = 0., 0., 0.

                for step in range(nr_glimpse):
                    reuse = (step != 0)
                    if is_reconstruct or env.phase is env.Phase.TRAIN:
                        img_hat = draw_opr.image_diff(img, canvas)  # eq. 3

                        # Note: here the input should be dec_h
                        with env.variable_scope('read', reuse=reuse):
                            read_param = O.fc('fc_param', dec_h, 5)

                        with env.name_scope('read_step{}'.format(step)):
                            cx, cy, delta, var, gamma = draw_opr.split_att_params(h, w, att_dim, read_param)
                            read_inp = O.concat([img, img_hat], axis=3)  # of shape: batch_size x h x w x (2c)
                            read_out = draw_opr.att_read(att_dim, read_inp, cx, cy, delta, var)  # eq. 4
                            enc_inp = O.concat([gamma * read_out.flatten2(), dec_h], axis=1)
                        enc_h, enc_state = encode(enc_inp, enc_state, reuse)  # eq. 5

                        with env.variable_scope('sample', reuse=reuse):
                            _ = enc_h
                            sample_mu = O.fc('fc_mu', _, code_length)
                            sample_log_var = O.fc('fc_sigma', _, code_length)

                        with env.name_scope('sample_step{}'.format(step)):
                            sample_var = O.exp(sample_log_var)
                            sample_std = O.sqrt(sample_var)
                            sample_epsilon = O.random_normal([batch_size, code_length])
                            z = sample_mu + sample_std * sample_epsilon  # eq. 6

                        # accumulate for losses
                        all_sqr_mus += sample_mu ** 2.
                        all_vars += sample_var
                        all_log_vars += sample_log_var
                    else:
                        z = O.random_normal([1, code_length])

                    # z = O.callback_injector(z)

                    dec_h, dec_state = decode(z, dec_state, reuse)  # eq. 7
                    with env.variable_scope('write', reuse=reuse):
                        write_param = O.fc('fc_param', dec_h, 5)
                        write_in = O.fc('fc', dec_h, (att_dim * att_dim * c)).reshape(-1, att_dim, att_dim, c)

                    with env.name_scope('write_step{}'.format(step)):
                        cx, cy, delta, var, gamma = draw_opr.split_att_params(h, w, att_dim, write_param)
                        write_out = draw_opr.att_write(h, w, write_in, cx, cy, delta, var)  # eq. 8

                    canvas += write_out

                    if env.phase is env.Phase.TEST:
                        dpc.add_output(O.sigmoid(canvas), name='canvas_step{}'.format(step))

                canvas = O.sigmoid(canvas)

                if env.phase is env.Phase.TRAIN:
                    with env.variable_scope('loss'):
                        img, canvas = img.flatten2(), canvas.flatten2()
                        content_loss = O.raw_cross_entropy_prob('raw_content', canvas, img)
                        content_loss = content_loss.sum(axis=1).mean(name='content')
                        # distrib_loss = 0.5 * (O.sqr(mu) + O.sqr(std) - 2. * O.log(std + 1e-8) - 1.0).sum(axis=1)
                        distrib_loss = -0.5 * (float(nr_glimpse) + all_log_vars - all_sqr_mus - all_vars).sum(axis=1)
                        distrib_loss = distrib_loss.mean(name='distrib')

                        summary.scalar('content_loss', content_loss)
                        summary.scalar('distrib_loss', distrib_loss)

                        loss = content_loss + distrib_loss
                    dpc.add_output(loss, name='loss', reduce_method='sum')

                dpc.add_output(canvas, name='output')
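
The distrib_loss above is the closed-form KL divergence between the per-glimpse posteriors N(mu_t, sigma_t^2) and the unit Gaussian prior, accumulated over the T = nr_glimpse steps:

    KL = 0.5 * sum_t sum_i (mu_{t,i}^2 + sigma_{t,i}^2 - log sigma_{t,i}^2 - 1)

which equals -0.5 * (T + sum log sigma^2 - sum mu^2 - sum sigma^2): the accumulators all_sqr_mus, all_vars and all_log_vars carry the sums over steps, and the constant float(nr_glimpse) broadcasts across the code_length dimensions before the .sum(axis=1).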