Example #1
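These examples rely on helpers and settings defined elsewhere in the same module (e.g. conv2d, mru_conv, mru_deconv, mean_pool, fully_connected, SIZE, NUM_BLOCKS, weight_initializer, Config, and the activation_fn_* / normalizer_fn_* / normalizer_params_* globals), so they are not standalone scripts.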
def image_encoder_mru(x, num_classes, reuse=False, data_format='NCHW', labels=None, scope_name=None):
    assert data_format == 'NCHW'
    size = SIZE
    num_blocks = NUM_BLOCKS
    resize_func = tf.image.resize_bilinear
    sn = False

    if normalizer_params_e is not None and normalizer_fn_e != ly.batch_norm and normalizer_fn_e != ly.layer_norm:
        normalizer_params_e['labels'] = labels
        normalizer_params_e['n_labels'] = num_classes

    if data_format == 'NCHW':
        x_list = []
        resized_ = x
        x_list.append(resized_)

        for i in range(4):
            resized_ = mean_pool(resized_, data_format=data_format)
            x_list.append(resized_)
        x_list = x_list[::-1]
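        # x_list now runs from the 1/16-scale input up to the full-resolution input (coarsest first).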
    else:
        raise NotImplementedError

    output_list = []

    h0 = conv2d(x_list[-1], 8, kernel_size=7, sn=sn, stride=2, data_format=data_format,
                activation_fn=None,
                normalizer_fn=None,
                normalizer_params=None,
                weights_initializer=weight_initializer)

    output_list.append(h0)

    # Initial memory state
    hidden_state_shape = h0.get_shape().as_list()
    hidden_state_shape[0] = 1
    hts_0 = [h0]
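    # Four stride-2 MRU stages follow; each stage reads the pyramid level whose resolution
    # matches its incoming hidden state, and its last output is collected in output_list.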

    hts_1 = mru_conv(x_list[-2], hts_0,
                     size * 1, sn=sn, stride=2, dilate_rate=1,
                     data_format=data_format, num_blocks=num_blocks,
                     last_unit=False,
                     activation_fn=activation_fn_e,
                     normalizer_fn=normalizer_fn_e,
                     normalizer_params=normalizer_params_e,
                     weights_initializer=weight_initializer,
                     unit_num=1)
    output_list.append(hts_1[-1])
    hts_2 = mru_conv(x_list[-3], hts_1,
                     size * 2, sn=sn, stride=2, dilate_rate=1,
                     data_format=data_format, num_blocks=num_blocks,
                     last_unit=False,
                     activation_fn=activation_fn_e,
                     normalizer_fn=normalizer_fn_e,
                     normalizer_params=normalizer_params_e,
                     weights_initializer=weight_initializer,
                     unit_num=2)
    output_list.append(hts_2[-1])
    hts_3 = mru_conv(x_list[-4], hts_2,
                     size * 4, sn=sn, stride=2, dilate_rate=1,
                     data_format=data_format, num_blocks=num_blocks,
                     last_unit=False,
                     activation_fn=activation_fn_e,
                     normalizer_fn=normalizer_fn_e,
                     normalizer_params=normalizer_params_e,
                     weights_initializer=weight_initializer,
                     unit_num=3)
    output_list.append(hts_3[-1])
    hts_4 = mru_conv(x_list[-5], hts_3,
                     size * 8, sn=sn, stride=2, dilate_rate=1,
                     data_format=data_format, num_blocks=num_blocks,
                     last_unit=True,
                     activation_fn=activation_fn_e,
                     normalizer_fn=normalizer_fn_e,
                     normalizer_params=normalizer_params_e,
                     weights_initializer=weight_initializer,
                     unit_num=4)
    output_list.append(hts_4[-1])

    return output_list
Example #2
def discriminate_mru(discrim_inputs, discrim_targets, num_classes, labels=None, reuse=False,
                     data_format='NCHW', scope_name=None):
    print("MRU Discriminator")
    assert data_format == 'NCHW'
    size = SIZE
    num_blocks = NUM_BLOCKS
    resize_func = tf.image.resize_bilinear
    sn = Config.sn

    if data_format == 'NCHW':
        channel_axis = 1
    else:
        channel_axis = 3
    if isinstance(discrim_targets, list):
        discrim_targets = discrim_targets[-1]

    if data_format == 'NCHW':
        x_list = []
        resized_ = discrim_targets
        x_list.append(resized_)

        for i in range(5):
            resized_ = mean_pool(resized_, data_format=data_format)
            x_list.append(resized_)
        x_list = x_list[::-1]
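        # x_list now runs from the 1/32-scale target image up to the full-resolution target (coarsest first).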
    else:
        raise NotImplementedError

    output_dim = 1

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        h0 = conv2d(x_list[-1], 8, kernel_size=7, sn=sn, stride=1, data_format=data_format,
                    activation_fn=activation_fn_d,
                    normalizer_fn=normalizer_fn_d,
                    normalizer_params=normalizer_params_d,
                    weights_initializer=weight_initializer)

        # Initial memory state
        hidden_state_shape = h0.get_shape().as_list()
        batch_size = hidden_state_shape[0]
        hidden_state_shape[0] = 1
        hts_0 = [h0]
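        # The remaining initial hidden states are learned variables (initialized to zero), tiled across the batch.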
        for i in range(1, num_blocks):
            h0 = tf.tile(tf.get_variable("initial_hidden_state_%d" % i, shape=hidden_state_shape, dtype=tf.float32,
                                         initializer=tf.zeros_initializer()), [batch_size, 1, 1, 1])
            hts_0.append(h0)
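        # Four stride-2 MRU stages; each stage reads the pyramid level matching the
        # resolution of its incoming hidden state.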

        hts_1 = mru_conv(x_list[-1], hts_0,
                         size * 2, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks,
                         last_unit=False,
                         activation_fn=activation_fn_d,
                         normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer,
                         unit_num=1)
        hts_2 = mru_conv(x_list[-2], hts_1,
                         size * 4, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks,
                         last_unit=False,
                         activation_fn=activation_fn_d,
                         normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer,
                         unit_num=2)
        hts_3 = mru_conv(x_list[-3], hts_2,
                         size * 8, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks,
                         last_unit=False,
                         activation_fn=activation_fn_d,
                         normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer,
                         unit_num=3)
        hts_4 = mru_conv(x_list[-4], hts_3,
                         size * 12, sn=sn, stride=2, dilate_rate=1,
                         data_format=data_format, num_blocks=num_blocks,
                         last_unit=True,
                         activation_fn=activation_fn_d,
                         normalizer_fn=normalizer_fn_d,
                         normalizer_params=normalizer_params_d,
                         weights_initializer=weight_initializer,
                         unit_num=4)

        img = hts_4[-1]
        img_shape = img.get_shape().as_list()

        # discriminator end
        disc = conv2d(img, output_dim, kernel_size=1, sn=sn, stride=1, data_format=data_format,
                      activation_fn=None, normalizer_fn=None,
                      weights_initializer=weight_initializer)

        if Config.proj_d:
            # Projection discriminator
            assert labels is not None and (len(labels.get_shape()) == 1 or labels.get_shape().as_list()[-1] == 1)

            class_embeddings = embed_labels(labels, num_classes, img_shape[channel_axis], sn=sn)
            class_embeddings = tf.reshape(class_embeddings, (img_shape[0], img_shape[channel_axis], 1, 1))  # NCHW

            disc += tf.reduce_sum(img * class_embeddings, axis=1, keep_dims=True)

            logits = None
        else:
            # classification end
            img = tf.reduce_mean(img, axis=(2, 3) if data_format == 'NCHW' else (1, 2))
            logits = fully_connected(img, num_classes, sn=sn, activation_fn=None, normalizer_fn=None)

    return disc, logits
Example #3
def generate_mru(z, text_vocab_indices, LSTM_hybrid, output_channel, num_classes, vocab_size, reuse=False,
                 data_format='NCHW', labels=None, scope_name=None):
    print("MRU Generator")
    size = SIZE
    num_blocks = NUM_BLOCKS
    sn = False

    input_dims = z.get_shape().as_list()
    resize_method = tf.image.ResizeMethod.AREA

    if data_format == 'NCHW':
        height = input_dims[2]
        width = input_dims[3]
    else:
        height = input_dims[1]
        width = input_dims[2]
    resized_z = [tf.identity(z)]
    for i in range(5):
        resized_z.append(image_resize(z, [int(height / 2 ** (i + 1)), int(width / 2 ** (i + 1))],
                                      resize_method, data_format))
    resized_z = resized_z[::-1]
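    # resized_z now runs from the 1/32-scale input up to the full-resolution input (coarsest first).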

    if data_format == 'NCHW':
        concat_axis = 1
    else:
        concat_axis = 3

    if normalizer_params_g is not None and normalizer_fn_g != ly.batch_norm and normalizer_fn_g != ly.layer_norm:
        normalizer_params_g['labels'] = labels
        normalizer_params_g['n_labels'] = num_classes

    with tf.variable_scope(scope_name) as scope:
        if reuse:
            scope.reuse_variables()

        z_encoded = image_encoder_mru(z, num_classes=num_classes, reuse=reuse, data_format=data_format,
                                      labels=labels, scope_name=scope_name)

        input_e_dims = z_encoded[-1].get_shape().as_list()
        batch_size = input_e_dims[0]

        # z_encoded[-1].shape = [N, 512, 6, 6], text_vocab_indices.shape = [N, 15]

        if LSTM_hybrid:

            ## Add text LSTM
            lstm_output = encode_feat_with_text(z_encoded[-1], text_vocab_indices, input_e_dims, vocab_size)
            feat_encoded_final = lstm_output  # [N, 512, 6, 6]

        else:
            feat_encoded_final = z_encoded[-1]
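        # Draw a per-image noise vector, project it with a fully connected layer, and
        # reshape it into a feature map at twice the bottleneck resolution.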

        channel_depth = int(input_e_dims[concat_axis] / 8.)
        if data_format == 'NCHW':
            noise_dims = [batch_size, channel_depth, int(input_e_dims[2] * 2), int(input_e_dims[3] * 2)]
        else:
            noise_dims = [batch_size, int(input_e_dims[1] * 2), int(input_e_dims[2] * 2), channel_depth]

        noise_vec = tf.random_normal(shape=(batch_size, 256), dtype=tf.float32)
        noise = fully_connected(noise_vec, int(np.prod(noise_dims[1:])), sn=sn,
                                activation_fn=activation_fn_g,
                                # normalizer_fn=normalizer_fn_g,
                                # normalizer_params=normalizer_params_g
                                )
        noise = tf.reshape(noise, shape=noise_dims)

        # Initial memory state
        hidden_state_shape = z_encoded[-1].get_shape().as_list()
        hidden_state_shape[0] = 1
        hts_0 = [feat_encoded_final]
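        # Five stride-2 MRU deconvolution stages upsample back to the input resolution;
        # each concatenates the matching-scale resized input with an encoder feature map
        # (skip connection), except the first stage (noise instead) and the last (resized input only).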

        input_0 = tf.concat([resized_z[1], noise], axis=concat_axis)
        hts_1 = mru_deconv(input_0, hts_0,
                           size * 6, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=False,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=0)
        input_1 = tf.concat([resized_z[2], z_encoded[-3]], axis=concat_axis)
        hts_2 = mru_deconv(input_1, hts_1,
                           size * 4, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=False,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=2)
        input_2 = tf.concat([resized_z[3], z_encoded[-4]], axis=concat_axis)
        hts_3 = mru_deconv(input_2, hts_2,
                           size * 2, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=False,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=4)
        input_3 = tf.concat([resized_z[4], z_encoded[-5]], axis=concat_axis)
        hts_4 = mru_deconv(input_3, hts_3,
                           size * 2, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=False,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=6)
        hts_5 = mru_deconv(resized_z[5], hts_4,
                           size * 1, sn=sn, stride=2, data_format=data_format,
                           num_blocks=num_blocks,
                           last_unit=True,
                           activation_fn=activation_fn_g,
                           normalizer_fn=normalizer_fn_g,
                           normalizer_params=normalizer_params_g,
                           weights_initializer=weight_initializer,
                           unit_num=8)
        out = conv2d(hts_5[-1], output_channel, 7, sn=sn, stride=1, data_format=data_format,
                     normalizer_fn=None, activation_fn=tf.nn.tanh,
                     weights_initializer=weight_initializer)
        if out.get_shape().as_list()[2] != height:
            raise ValueError('Output height %s does not match input height %s'
                             % (out.get_shape().as_list()[2], height))
        return out, noise_vec
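For orientation, here is a minimal, hypothetical sketch of how these functions might be wired together in a TF1 graph. The batch size, the 192x192 resolution (chosen to be consistent with the [N, 512, 6, 6] bottleneck noted in Example #3), the channel counts, the class count, and the vocabulary size are illustrative assumptions only, and the module-level configuration (SIZE, NUM_BLOCKS, Config, the activation/normalizer settings) is assumed to be set up beforehand.

import tensorflow as tf

# Illustrative shapes only: 8 sketches of 192x192 with 3 channels (NCHW) and 15-token captions.
sketch = tf.placeholder(tf.float32, [8, 3, 192, 192])
text_ids = tf.placeholder(tf.int32, [8, 15])
labels = tf.placeholder(tf.int32, [8])

# num_classes and vocab_size are placeholder values, not taken from the original repository.
fake, noise_vec = generate_mru(sketch, text_ids, LSTM_hybrid=True, output_channel=3,
                               num_classes=125, vocab_size=5000, data_format='NCHW',
                               labels=labels, scope_name='generator')
disc_fake, cls_fake = discriminate_mru(sketch, fake, num_classes=125, labels=labels,
                                       data_format='NCHW', scope_name='discriminator')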