예제 #1
0
파일: baseline.py 프로젝트: zlpsls/pose-gan
def make_discriminator():
    """Build the baseline discriminator as a Keras ``Model``.

    The model takes a ``(128, 64, 3)`` image and returns a single value per
    sample, used to tell real inputs from generated ones. The output comes
    from a one-unit dense layer with no bias and no activation argument.

    Pipeline: a 3x3 stem convolution with 64 filters, four ``'DOWN'``
    residual blocks that are handed ``InstanceNormalization``, a flatten,
    and the dense head.
    """
    image = Input((128, 64, 3))
    features = Conv2D(64, (3, 3),
                      padding='same',
                      use_bias=True,
                      kernel_initializer='he_uniform')(image)

    # Fourth positional argument of each successive residual block
    # (presumably its filter count -- confirm against resblock's signature).
    for width in (128, 256, 512, 512):
        features = resblock(features, (3, 3), 'DOWN', width,
                            InstanceNormalization)

    flat = Flatten()(features)
    score = Dense(1, use_bias=False)(flat)
    return Model(inputs=image, outputs=score)
예제 #2
0
파일: baseline.py 프로젝트: zlpsls/pose-gan
def make_generator():
    """Build the baseline generator as a Keras ``Model``.

    The model maps a 128-dimensional noise vector to an image produced by a
    3-filter, ``tanh``-activated convolution (intended output size
    128x64x3).

    Pipeline: a dense projection reshaped to ``(8, 4, 512)``, four ``'UP'``
    residual blocks, batch normalisation over the last axis, ReLU, and the
    final 3x3 convolution.
    """
    noise = Input((128, ))
    hidden = Dense(8 * 4 * 512)(noise)
    hidden = Reshape((8, 4, 512))(hidden)

    # Fourth positional argument of each successive residual block
    # (presumably its filter count -- confirm against resblock's signature).
    for width in (512, 256, 128, 64):
        hidden = resblock(hidden, (3, 3), 'UP', width)

    hidden = BatchNormalization(axis=-1)(hidden)
    hidden = Activation('relu')(hidden)
    image = Conv2D(3, (3, 3),
                   activation='tanh',
                   padding='same',
                   use_bias=False,
                   kernel_initializer='he_uniform')(hidden)
    return Model(inputs=noise, outputs=image)
예제 #3
0
def make_generator(input_noise_shape=(128,),
                   output_channels=3,
                   input_cls_shape=(1, ),
                   block_sizes=(128, 128, 128),
                   resamples=("UP", "UP", "UP"),
                   first_block_shape=(4, 4, 128),
                   number_of_classes=10,
                   concat_cls=False,
                   block_norm='u',
                   block_coloring='cs',
                   filters_emb=10,
                   last_norm='u',
                   last_coloring='cs',
                   decomposition='cholesky',
                   whitten_group=1,
                   coloring_group=1,
                   iter_num=5,
                   instance_norm=0,
                   gan_type=None,
                   arch='res',
                   spectral=False,
                   fully_diff_spectral=False,
                   spectral_iterations=1,
                   conv_singular=True):
    """Assemble a configurable generator as a Keras ``Model``.

    The noise input (with a label embedding concatenated in front of it when
    ``concat_cls`` is true) is projected by a dense layer to
    ``first_block_shape``, passed through one block per
    ``(block_sizes, resamples)`` pair, normalised, sent through ReLU, and
    mapped to ``output_channels`` by a 3x3 convolution followed by ``tanh``.

    Args:
        input_noise_shape: Shape of the noise input.
        output_channels: Number of filters of the final convolution.
        input_cls_shape: Shape of the int32 label input.
        block_sizes: ``nfilters`` for each block.
        resamples: ``resample`` mode for each block; paired with
            ``block_sizes`` via ``zip``.
        first_block_shape: Shape the dense projection is reshaped to; its
            last entry is also the label-embedding width.
        number_of_classes: Embedding ``input_dim``; also forwarded to
            ``create_norm``.
        concat_cls: When true, concatenate the label embedding with the noise.
        block_norm, block_coloring: Passed positionally to ``create_norm``
            for the per-block normalisation.
        last_norm, last_coloring: Same, for the final normalisation.
        filters_emb, decomposition, whitten_group, coloring_group, iter_num,
        instance_norm: Forwarded unchanged to both ``create_norm`` calls.
        gan_type: ``None`` yields a model whose only input is the noise; any
            other value yields a model taking ``[noise, label]``.
        arch: ``'res'`` (``resblock``) or ``'dcgan'`` (``dcblock``).
        spectral: When true, use the ``SN*`` layer variants.
        fully_diff_spectral, spectral_iterations, conv_singular: Options bound
            onto the ``SN*`` layers when ``spectral`` is true.

    Returns:
        A Keras ``Model`` producing the ``tanh``-activated output tensor.
    """
    assert arch in ['res', 'dcgan']
    inp = Input(input_noise_shape, name='GInputImage')
    cls = Input(input_cls_shape, dtype='int32', name='GLabel')

    if not spectral:
        conv_layer = Conv2D
        cond_conv_layer = ConditionalConv11
        dense_layer = Dense
        emb_layer = Embedding
        factor_conv_layer = FactorizedConv11
    else:
        # Options shared by every spectrally-normalised layer factory.
        sn_options = dict(fully_diff_spectral=fully_diff_spectral,
                          spectral_iterations=spectral_iterations)
        conv_layer = partial(SNConv2D, conv_singular=conv_singular,
                             **sn_options)
        cond_conv_layer = partial(SNConditionalConv11, **sn_options)
        dense_layer = partial(SNDense, **sn_options)
        emb_layer = partial(SNEmbeding, **sn_options)
        factor_conv_layer = partial(SNFactorizedConv11, **sn_options)

    y = inp
    if concat_cls:
        emb_width = first_block_shape[-1]
        label_emb = emb_layer(input_dim=number_of_classes,
                              output_dim=emb_width)(cls)
        label_emb = Reshape((emb_width, ))(label_emb)
        y = Concatenate(axis=-1)([label_emb, inp])

    y = dense_layer(units=np.prod(first_block_shape),
                    kernel_initializer=glorot_init)(y)
    y = Reshape(first_block_shape)(y)

    # Keyword arguments common to the block-level and final norm factories.
    norm_options = dict(decomposition=decomposition,
                        whitten_group=whitten_group,
                        coloring_group=coloring_group,
                        iter_num=iter_num,
                        instance_norm=instance_norm,
                        cls=cls,
                        number_of_classes=number_of_classes,
                        filters_emb=filters_emb,
                        uncoditional_conv_layer=conv_layer,
                        conditional_conv_layer=cond_conv_layer,
                        factor_conv_layer=factor_conv_layer)
    block_norm_layer = create_norm(block_norm, block_coloring, **norm_options)
    last_norm_layer = create_norm(last_norm, last_coloring, **norm_options)

    for index, (nfilters, resample) in enumerate(zip(block_sizes, resamples)):
        block_name = 'Generator.' + str(index)
        if arch == 'res':
            y = resblock(y, kernel_size=(3, 3), resample=resample,
                         nfilters=nfilters, name=block_name,
                         norm=block_norm_layer, is_first=False,
                         conv_layer=conv_layer)
        else:
            # TODO: SN DECONV
            y = dcblock(y, kernel_size=(4, 4), resample=resample,
                        nfilters=nfilters, name=block_name,
                        norm=block_norm_layer, is_first=False,
                        conv_layer=Conv2DTranspose)

    y = last_norm_layer(axis=-1, name='Generator.BN.Final')(y)
    y = Activation('relu')(y)
    output = conv_layer(filters=output_channels, kernel_size=(3, 3),
                        name='Generator.Final',
                        kernel_initializer=glorot_init, use_bias=True,
                        padding='same')(y)
    output = Activation('tanh')(output)

    # NOTE(review): with gan_type=None the label input is not a model input,
    # even though concat_cls / create_norm may have wired it into the graph --
    # confirm callers never combine those options.
    model_inputs = [inp] if gan_type is None else [inp, cls]
    return Model(inputs=model_inputs, outputs=output)
예제 #4
0
def make_discriminator(input_image_shape,
                       input_cls_shape=(1, ),
                       block_sizes=(128, 128, 128, 128),
                       resamples=('DOWN', "DOWN", "SAME", "SAME"),
                       number_of_classes=10,
                       type='AC_GAN',
                       norm='n',
                       after_norm='n',
                       spectral=False,
                       fully_diff_spectral=False,
                       spectral_iterations=1,
                       conv_singular=True,
                       sum_pool=False,
                       dropout=False,
                       arch='res',
                       filters_emb=10):
    """Assemble a configurable discriminator as a Keras ``Model``.

    The image passes through one block per ``(block_sizes, resamples)`` pair,
    an activation, a pooling or flatten step, optional dropout, and finally
    one of three output heads selected by ``type``.

    Args:
        input_image_shape: Shape of the image input.
        input_cls_shape: Shape of the int32 label input.
        block_sizes: ``nfilters`` for each block.
        resamples: ``resample`` mode for each block; must have the same
            length as ``block_sizes``.
        number_of_classes: Units of the AC_GAN class head and ``input_dim``
            of the projection embedding; also forwarded to ``create_norm``.
        type: Output head. ``'AC_GAN'`` returns a model from the image to
            ``[out, cls_out]``; ``'PROJECTIVE'`` returns a model from
            ``[image, label]`` to ``[out]`` where ``out`` adds a label
            projection term to the one-unit dense output; ``None`` returns a
            model from ``[image]`` to ``[out]``.
        norm, after_norm, filters_emb: Forwarded to ``create_norm``.
        spectral: When true, use the ``SN*`` layer variants.
        fully_diff_spectral, spectral_iterations, conv_singular: Options bound
            onto the ``SN*`` layers when ``spectral`` is true.
        sum_pool: For ``arch='res'``, use ``GlobalSumPooling2D`` instead of
            ``GlobalAveragePooling2D``.
        dropout: Passed to ``Dropout`` when it compares unequal to 0.
        arch: ``'res'`` (``resblock``, ReLU, global pooling) or ``'dcgan'``
            (``dcblock``, LeakyReLU, flatten).

    Returns:
        A Keras ``Model`` whose inputs and outputs depend on ``type``.

    Raises:
        AssertionError: If ``arch`` is unknown or ``block_sizes`` and
            ``resamples`` differ in length.
        ValueError: If ``type`` is not ``'AC_GAN'``, ``'PROJECTIVE'`` or
            ``None``.
    """
    assert arch in ['res', 'dcgan']
    assert len(block_sizes) == len(resamples)
    # Reject unknown head types before building anything; previously the
    # whole graph was constructed and the function then returned None.
    if type not in ('AC_GAN', 'PROJECTIVE', None):
        raise ValueError(
            "Unknown discriminator type: %r "
            "(expected 'AC_GAN', 'PROJECTIVE' or None)" % (type,))

    x = Input(input_image_shape)
    cls = Input(input_cls_shape, dtype='int32')

    if spectral:
        conv_layer = partial(SNConv2D,
                             conv_singular=conv_singular,
                             fully_diff_spectral=fully_diff_spectral,
                             spectral_iterations=spectral_iterations)
        cond_conv_layer = partial(SNConditionalConv11,
                                  fully_diff_spectral=fully_diff_spectral,
                                  spectral_iterations=spectral_iterations)
        dense_layer = partial(SNDense,
                              fully_diff_spectral=fully_diff_spectral,
                              spectral_iterations=spectral_iterations)
        emb_layer = partial(SNEmbeding,
                            fully_diff_spectral=fully_diff_spectral,
                            spectral_iterations=spectral_iterations)
    else:
        conv_layer = Conv2D
        cond_conv_layer = ConditionalConv11
        dense_layer = Dense
        emb_layer = Embedding

    norm_layer = create_norm(norm=norm,
                             after_norm=after_norm,
                             cls=cls,
                             number_of_classes=number_of_classes,
                             conditional_conv_layer=cond_conv_layer,
                             uncoditional_conv_layer=conv_layer,
                             filters_emb=filters_emb)

    y = x
    for i, (block_size, resample) in enumerate(zip(block_sizes, resamples)):
        # Arguments shared by both block flavours; only the first block is
        # flagged with is_first=True.
        block_kwargs = dict(resample=resample,
                            nfilters=block_size,
                            name='Discriminator.' + str(i),
                            norm=norm_layer,
                            is_first=(i == 0),
                            conv_layer=conv_layer)
        if arch == 'res':
            y = resblock(y, kernel_size=(3, 3), **block_kwargs)
        else:
            # dcgan blocks use a 4x4 kernel unless the block is "SAME".
            kernel_size = (3, 3) if resample == "SAME" else (4, 4)
            y = dcblock(y, kernel_size=kernel_size, **block_kwargs)

    if arch == 'res':
        y = Activation('relu')(y)
        pooling = GlobalSumPooling2D if sum_pool else GlobalAveragePooling2D
        y = pooling()(y)
    else:
        y = LeakyReLU()(y)
        y = Flatten()(y)

    # The default dropout=False compares equal to 0, so no layer is added.
    if dropout != 0:
        y = Dropout(dropout)(y)

    if type == 'AC_GAN':
        # NOTE(review): the class head uses plain Dense even when
        # spectral=True -- confirm this is intentional.
        cls_out = Dense(units=number_of_classes,
                        use_bias=True,
                        kernel_initializer=glorot_init)(y)
        out = dense_layer(units=1,
                          use_bias=True,
                          kernel_initializer=glorot_init)(y)
        return Model(inputs=x, outputs=[out, cls_out])

    if type == "PROJECTIVE":
        emb = emb_layer(input_dim=number_of_classes,
                        output_dim=block_sizes[-1])(cls)
        # phi: multiply the label embedding by the features (broadcast over
        # a new axis 1) and sum over axis 2.
        phi = Lambda(
            lambda inp: K.sum(inp[1] * K.expand_dims(inp[0], axis=1), axis=2),
            output_shape=(1, ))([y, emb])
        psi = dense_layer(units=1,
                          use_bias=True,
                          kernel_initializer=glorot_init)(y)
        out = Add()([phi, psi])
        return Model(inputs=[x, cls], outputs=[out])

    # type is None: a single one-unit head with no label input.
    out = dense_layer(units=1,
                      use_bias=True,
                      kernel_initializer=glorot_init)(y)
    return Model(inputs=[x], outputs=[out])