Example #1
# Assumed imports for these TF1-style snippets; the surrounding module must
# also provide the helper ops (nchw_conv, lrelu, batchnorm, conv2d,
# fully_connected, mru_conv, mru_deconv, mean_pool, image_resize,
# embed_labels, ...) and globals (SIZE, NUM_BLOCKS, Config,
# activation/normalizer settings).
import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as ly


def discriminate_pix2pix(discrim_inputs, discrim_targets, num_classes, labels=None, reuse=False,
                         data_format='NCHW', scope_name=None):
    print("Pix2pix Discriminator")
    assert data_format == 'NCHW'
    size = SIZE
    sn = Config.sn

    if data_format == 'NCHW':
        channel_axis = 1
    else:
        channel_axis = 3
    # If a list of feature maps was passed in, discriminate only the final output.
    if isinstance(discrim_targets, list):
        discrim_targets = discrim_targets[-1]

    output_dim = 1
    
    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        n_layers = 3
        layers = []

        # 2x [batch, 3, height, width] => [batch, 3 * 2, height, width]
        input_fusion = tf.concat([discrim_inputs, discrim_targets], axis=channel_axis)

        # layer_1: [batch, 6, 192, 192] => [batch, 64, 96, 96]
        with tf.variable_scope("layer_1"):
            convolved = nchw_conv(input_fusion, size, stride=2)
            rectified = lrelu(convolved, 0.2)
            layers.append(rectified)

        # layer_2: [batch, 64, 96, 96] => [batch, 128, 48, 48]
        # layer_3: [batch, 128, 48, 48] => [batch, 256, 24, 24]
        # layer_4: [batch, 256, 24, 24] => [batch, 512, 23, 23]
        for i in range(n_layers):
            with tf.variable_scope("layer_%d" % (len(layers) + 1)):
                out_channels = size * min(2 ** (i + 1), 8)
                stride = 1 if i == n_layers - 1 else 2  # last layer here has stride 1
                convolved = nchw_conv(layers[-1], out_channels, stride=stride)
                normalized = batchnorm(convolved, data_format=data_format)
                rectified = lrelu(normalized, 0.2)
                layers.append(rectified)

        # layer_5: [batch, 512, 23, 23] => [batch, 1, 22, 22] => discriminator end
        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
            disc = nchw_conv(layers[-1], out_channels=output_dim, stride=1)

        # classification end: global average pooling over the final feature map
        img = tf.reduce_mean(layers[-1], axis=(2, 3) if data_format == 'NCHW' else (1, 2))
        logits = fully_connected(img, num_classes, sn=sn, activation_fn=None, normalizer_fn=None)

    return disc, logits
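
The discriminator returns two heads: disc, a PatchGAN-style per-patch real/fake map, and logits for an auxiliary classifier over num_classes. A hypothetical call sketch (batch size, image size, and class count are made-up values; SIZE, Config, and the conv helpers must already be defined in the module):

# Hypothetical usage; 192x192 RGB tensors in NCHW layout, per the shape comments.
sketches = tf.placeholder(tf.float32, [8, 3, 192, 192])  # conditioning input
photos = tf.placeholder(tf.float32, [8, 3, 192, 192])    # real or generated target
disc_map, cls_logits = discriminate_pix2pix(sketches, photos, num_classes=125,
                                            scope_name='discriminator')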
Example #2
def generate_residual(z, text_vocab_indices, LSTM_hybrid, output_channel, num_classes, vocab_size, reuse=False,
                      data_format='NCHW', labels=None, scope_name=None):
    print("Residual Generator")
    size = SIZE
    sn = False

    input_dims = z.get_shape().as_list()

    if data_format == 'NCHW':
        height = input_dims[2]
        width = input_dims[3]
    else:
        height = input_dims[1]
        width = input_dims[2]

    if data_format == 'NCHW':
        concat_axis = 1
    else:
        concat_axis = 3

    # Conditional normalizers (anything other than plain batch/layer norm)
    # need the class labels passed through their params.
    if normalizer_params_g is not None and normalizer_fn_g != ly.batch_norm and normalizer_fn_g != ly.layer_norm:
        normalizer_params_g['labels'] = labels
        normalizer_params_g['n_labels'] = num_classes

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        num_residual_units = [3, 4, 6, 3]

        z_encoded = image_encoder_residual(z, num_residual_units, num_classes=num_classes, reuse=reuse,
                                           data_format=data_format,
                                           labels=labels, scope_name=scope_name)  # list of hidden state

        input_e_dims = z_encoded[-1].get_shape().as_list()
        batch_size = input_e_dims[0]

        # z_encoded[-1].shape = [N, 512, 6, 6], text_vocab_indices.shape = [N, 15]

        if LSTM_hybrid:
            ## Add text LSTM
            lstm_output = encode_feat_with_text(z_encoded[-1], text_vocab_indices, input_e_dims, vocab_size)
            feat_encoded_final = lstm_output  # [N, 512, 6, 6]
        else:
            feat_encoded_final = z_encoded[-1]

        channel_depth = int(input_e_dims[concat_axis] / 8.)
        if data_format == 'NCHW':
            noise_dims = [batch_size, channel_depth, int(input_e_dims[2]), int(input_e_dims[3])]
        else:
            noise_dims = [batch_size, int(input_e_dims[1]), int(input_e_dims[2]), channel_depth]

        noise_vec = tf.random_normal(shape=(batch_size, 256), dtype=tf.float32)
        noise = fully_connected(noise_vec, int(np.prod(noise_dims[1:])), sn=sn,
                                activation_fn=activation_fn_g,
                                # normalizer_fn=normalizer_fn_g,
                                # normalizer_params=normalizer_params_g
                                )
        noise = tf.reshape(noise, shape=noise_dims)

        ## decoder
        layer_specs = [
            (size * 8, 0.0),  # decoder_5: [batch, 512 * 2, 6, 6] => [batch, 512, 12, 12]
            (size * 4, 0.0),  # decoder_4: [batch, 512 * 2, 12, 12] => [batch, 256, 24, 24]
            (size * 2, 0.0),  # decoder_3: [batch, 256 * 2, 24, 24] => [batch, 128, 48, 48]
            (size, 0.0),  # decoder_2: [batch, 128 * 2, 48, 48] => [batch, 64, 96, 96]
        ]

        num_encoder_layers = len(z_encoded)
        for decoder_layer, (out_channels, dropout) in enumerate(layer_specs):
            skip_layer = num_encoder_layers - decoder_layer - 1
            with tf.variable_scope("decoder_%d_0" % (skip_layer + 1)):
                if decoder_layer == 0:
                    # Deepest stage: fuse the encoded features with the noise map.
                    input = tf.concat([feat_encoded_final, noise], axis=concat_axis)
                else:
                    # Decoder outputs are appended to z_encoded below, so z_encoded[-1]
                    # is the previous decoder stage; concat it with the matching
                    # encoder feature (U-Net style skip connection).
                    input = tf.concat([z_encoded[-1], z_encoded[skip_layer]], axis=concat_axis)
                output = bottleneck_residual_de(input, out_channels)
            for uId in range(1, num_residual_units[skip_layer - 1]):
                with tf.variable_scope("decoder_%d_%d" % (skip_layer + 1, uId)):
                    output = bottleneck_residual_pu(output, out_channels, False)

            z_encoded.append(output)

        # decoder_1: [batch, 64 * 2, 96, 96] => [batch, 3, 192, 192]
        with tf.variable_scope("decoder_1"):
            input = tf.concat([z_encoded[-1], z_encoded[0]], axis=concat_axis)
            output = nchw_deconv(input, output_channel)
            output = batchnorm(output, data_format=data_format)
            output = tf.tanh(output)
            z_encoded.append(output)

        if output.get_shape().as_list()[2] != height:
            raise ValueError('Output spatial size %d does not match input height %d'
                             % (output.get_shape().as_list()[2], height))
        return output, noise_vec
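
generate_residual returns the synthesized image (tanh output, so values in [-1, 1]) together with the sampled 256-dim noise vector. A hypothetical call (shapes follow the in-code comments; the 15-token caption length and the vocabulary size are assumptions):

# Hypothetical usage; the input sketch is NCHW and captions are token indices.
z = tf.placeholder(tf.float32, [8, 3, 192, 192])
captions = tf.placeholder(tf.int32, [8, 15])
image, noise_vec = generate_residual(z, captions, LSTM_hybrid=True,
                                     output_channel=3, num_classes=125,
                                     vocab_size=5000, scope_name='generator')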
Example #3
def discriminate_mru(discrim_inputs, discrim_targets, num_classes, labels=None, reuse=False,
                     data_format='NCHW', scope_name=None):
    print("MRU Discriminator")
    assert data_format == 'NCHW'
    size = SIZE
    num_blocks = NUM_BLOCKS
    resize_func = tf.image.resize_bilinear  # unused in this variant
    sn = Config.sn

    if data_format == 'NCHW':
        channel_axis = 1
    else:
        channel_axis = 3
    # If a list of feature maps was passed in, discriminate only the final output.
    if isinstance(discrim_targets, list):
        discrim_targets = discrim_targets[-1]

    if data_format == 'NCHW':
        # Build a 6-level image pyramid of the target by repeated mean pooling,
        # reversed so the coarsest resolution comes first.
        x_list = []
        resized_ = discrim_targets
        x_list.append(resized_)

        for i in range(5):
            resized_ = mean_pool(resized_, data_format=data_format)
            x_list.append(resized_)
        x_list = x_list[::-1]
    else:
        raise NotImplementedError

    output_dim = 1

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        h0 = conv2d(x_list[-1], 8, kernel_size=7, sn=sn, stride=1, data_format=data_format,
                    activation_fn=activation_fn_d,
                    normalizer_fn=normalizer_fn_d,
                    normalizer_params=normalizer_params_d,
                    weights_initializer=weight_initializer)

        # Initial memory state
        hidden_state_shape = h0.get_shape().as_list()
        batch_size = hidden_state_shape[0]
        hidden_state_shape[0] = 1
        hts_0 = [h0]
        for i in range(1, num_blocks):
            h0 = tf.tile(tf.get_variable("initial_hidden_state_%d" % i, shape=hidden_state_shape, dtype=tf.float32,
                                         initializer=tf.zeros_initializer()), [batch_size, 1, 1, 1])
            hts_0.append(h0)

        hts_1 = mru_conv(x_list[-1], hts_0,
                         size * 2, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks,
                         last_unit=False,
                         activation_fn=activation_fn_d,
                         normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer,
                         unit_num=1)
        hts_2 = mru_conv(x_list[-2], hts_1,
                         size * 4, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks,
                         last_unit=False,
                         activation_fn=activation_fn_d,
                         normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer,
                         unit_num=2)
        hts_3 = mru_conv(x_list[-3], hts_2,
                         size * 8, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks,
                         last_unit=False,
                         activation_fn=activation_fn_d,
                         normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer,
                         unit_num=3)
        hts_4 = mru_conv(x_list[-4], hts_3,
                         size * 12, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks,
                         last_unit=True,
                         activation_fn=activation_fn_d,
                         normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer,
                         unit_num=4)

        img = hts_4[-1]
        img_shape = img.get_shape().as_list()

        # discriminator end
        disc = conv2d(img, output_dim, kernel_size=1, sn=sn, stride=1, data_format=data_format,
                      activation_fn=None, normalizer_fn=None,
                      weights_initializer=weight_initializer)

        if Config.proj_d:
            # Projection discriminator
            assert labels is not None and (len(labels.get_shape()) == 1 or labels.get_shape().as_list()[-1] == 1)

            class_embeddings = embed_labels(labels, num_classes, img_shape[channel_axis], sn=sn)
            class_embeddings = tf.reshape(class_embeddings, (img_shape[0], img_shape[channel_axis], 1, 1))  # NCHW

            disc += tf.reduce_sum(img * class_embeddings, axis=1, keep_dims=True)

            logits = None
        else:
            # classification end
            img = tf.reduce_mean(img, axis=(2, 3) if data_format == 'NCHW' else (1, 2))
            logits = fully_connected(img, num_classes, sn=sn, activation_fn=None, normalizer_fn=None)

    return disc, logits
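
When Config.proj_d is enabled, class information enters through the projection term (the inner product between a learned class embedding and the final feature map, in the style of a projection discriminator) and logits is returned as None, so callers must skip any auxiliary classification loss. A hypothetical call handling both modes (names and sizes are made up):

# Hypothetical usage; labels are integer class ids, one per image.
sketches = tf.placeholder(tf.float32, [8, 3, 192, 192])
photos = tf.placeholder(tf.float32, [8, 3, 192, 192])
class_ids = tf.placeholder(tf.int32, [8])
disc, logits = discriminate_mru(sketches, photos, num_classes=125,
                                labels=class_ids, scope_name='discriminator')
if logits is not None:  # auxiliary classifier branch (non-projection mode)
    aux_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=class_ids, logits=logits))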
Example #4
def generate_mru(z, text_vocab_indices, LSTM_hybrid, output_channel, num_classes, vocab_size, reuse=False,
                 data_format='NCHW', labels=None, scope_name=None):
    print("MRU Generator")
    size = SIZE
    num_blocks = NUM_BLOCKS
    sn = False

    input_dims = z.get_shape().as_list()
    resize_method = tf.image.ResizeMethod.AREA

    if data_format == 'NCHW':
        height = input_dims[2]
        width = input_dims[3]
    else:
        height = input_dims[1]
        width = input_dims[2]
    # Build a 6-level pyramid of the input (coarsest first) for the MRU stages.
    resized_z = [tf.identity(z)]
    for i in range(5):
        resized_z.append(image_resize(z, [int(height / 2 ** (i + 1)), int(width / 2 ** (i + 1))],
                                      resize_method, data_format))
    resized_z = resized_z[::-1]

    if data_format == 'NCHW':
        concat_axis = 1
    else:
        concat_axis = 3

    # Conditional normalizers need the class labels (see the residual generator above).
    if normalizer_params_g is not None and normalizer_fn_g != ly.batch_norm and normalizer_fn_g != ly.layer_norm:
        normalizer_params_g['labels'] = labels
        normalizer_params_g['n_labels'] = num_classes

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        z_encoded = image_encoder_mru(z, num_classes=num_classes, reuse=reuse, data_format=data_format,
                                      labels=labels, scope_name=scope_name)

        input_e_dims = z_encoded[-1].get_shape().as_list()
        batch_size = input_e_dims[0]

        # z_encoded[-1].shape = [N, 512, 6, 6], text_vocab_indices.shape = [N, 15]

        if LSTM_hybrid:
            ## Add text LSTM
            lstm_output = encode_feat_with_text(z_encoded[-1], text_vocab_indices, input_e_dims, vocab_size)
            feat_encoded_final = lstm_output  # [N, 512, 6, 6]
        else:
            feat_encoded_final = z_encoded[-1]

        channel_depth = int(input_e_dims[concat_axis] / 8.)
        if data_format == 'NCHW':
            noise_dims = [batch_size, channel_depth, int(input_e_dims[2] * 2), int(input_e_dims[3] * 2)]
        else:
            noise_dims = [batch_size, int(input_e_dims[1] * 2), int(input_e_dims[2] * 2), channel_depth]

        noise_vec = tf.random_normal(shape=(batch_size, 256), dtype=tf.float32)
        noise = fully_connected(noise_vec, int(np.prod(noise_dims[1:])), sn=sn,
                                activation_fn=activation_fn_g,
                                # normalizer_fn=normalizer_fn_g,
                                # normalizer_params=normalizer_params_g
                                )
        noise = tf.reshape(noise, shape=noise_dims)

        # Initial memory state, seeded directly from the encoded features
        hidden_state_shape = z_encoded[-1].get_shape().as_list()
        hidden_state_shape[0] = 1
        hts_0 = [feat_encoded_final]

        input_0 = tf.concat([resized_z[1], noise], axis=concat_axis)
        hts_1 = mru_deconv(input_0, hts_0,
                           size * 6, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=False,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=0)
        input_1 = tf.concat([resized_z[2], z_encoded[-3]], axis=concat_axis)
        hts_2 = mru_deconv(input_1, hts_1,
                           size * 4, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=False,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=2)
        input_2 = tf.concat([resized_z[3], z_encoded[-4]], axis=concat_axis)
        hts_3 = mru_deconv(input_2, hts_2,
                           size * 2, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=False,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=4)
        input_3 = tf.concat([resized_z[4], z_encoded[-5]], axis=concat_axis)
        hts_4 = mru_deconv(input_3, hts_3,
                           size * 2, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=False,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=6)
        hts_5 = mru_deconv(resized_z[5], hts_4,
                           size * 1, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=True,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=8)
        out = conv2d(hts_5[-1], output_channel, 7, sn=sn, stride=1, data_format=data_format,
                     normalizer_fn=None, activation_fn=tf.nn.tanh,
                     weights_initializer=weight_initializer)
        if out.get_shape().as_list()[2] != height:
            raise ValueError('Output spatial size %d does not match input height %d'
                             % (out.get_shape().as_list()[2], height))
        return out, noise_vec
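
A sketch of wiring the MRU generator and discriminator together for one fake pass (names and sizes are assumptions; real training code would add a second, reused discriminator call on real images plus the actual losses):

# Hypothetical usage; the generator output feeds the discriminator directly.
z = tf.placeholder(tf.float32, [8, 3, 192, 192])
captions = tf.placeholder(tf.int32, [8, 15])
fake, noise_vec = generate_mru(z, captions, LSTM_hybrid=True, output_channel=3,
                               num_classes=125, vocab_size=5000,
                               scope_name='generator')
disc_fake, logits_fake = discriminate_mru(z, fake, num_classes=125,
                                          scope_name='discriminator')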