Exemplo n.º 1
0
def mapping_network_fixer(style_model, style_gan_name, trainable=False):
    '''
    Copy the mapping-network weights of a saved style GAN into ``style_model``.

    A full ``mu.UNET_STYLE`` generator is rebuilt from the module-level
    hyper-parameters (``N``, ``input_size``, ``latent_lev``, ``latent_size``,
    ``mapping_size``, ``pool``, ``activation``) and loaded with the weights
    returned by ``tu.dummy_loader(style_gan_name)``. The weights of every layer
    whose name starts with ``'map'`` are concatenated in layer order and
    assigned to ``style_model``.

    Assumes layer names (and number of nodes) match: ``style_model``'s weight
    list must be exactly the concatenation of the ``'map*'`` layers' weights,
    in the same order.

    NOTE(review): ``input_size`` is read as a module-level global; if the
    training script has since rebound it (e.g. to a discriminator input
    shape), the rebuilt generator will have the wrong input — confirm at the
    call site.

    Parameters
    ----------
    style_model : keras.Model
        Target model that receives the mapping-network weights.
    style_gan_name :
        Identifier passed to ``tu.dummy_loader`` to obtain the saved weights
        (presumably a file path/name — confirm).
    trainable : bool, default False
        If False, ``style_model`` and all of its layers are frozen before
        compilation. If True, the trainable flags are left unchanged.

    Returns
    -------
    keras.Model
        The same ``style_model`` object, compiled (MSE loss, Adam with lr=0)
        and carrying the copied mapping-network weights.
    '''
    W = tu.dummy_loader(style_gan_name)
    model = mu.UNET_STYLE(N,
                          input_size,
                          latent_lev,
                          latent_size,
                          mapping_size,
                          pool=pool,
                          activation=activation,
                          noise=[0.2, 0.1])
    model.set_weights(W)

    # Gather the weights of the mapping-network layers, in layer order.
    W_mapping_network = []
    for layer in model.layers:
        if layer.name.startswith('map'):
            W_mapping_network.extend(layer.get_weights())

    # Must be a logical `not`: the previous `~trainable` was bitwise inversion
    # (~False == -1, ~True == -2), which is always truthy, so the model was
    # frozen even when trainable=True was requested.
    if not trainable:
        # Freeze before compile() so that Keras honours the flags.
        style_model.trainable = False
        for layer in style_model.layers:
            layer.trainable = False

    # lr=0: Adam steps are scaled by the learning rate, so this optimizer
    # makes no weight updates.
    opt_G = keras.optimizers.Adam(lr=0)
    style_model.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
    style_model.set_weights(W_mapping_network)
    return style_model
Exemplo n.º 2
0
def mix_network_fixer(dscale_model, dscale_gan_name):
    '''
    Freeze the ``'unet_left'`` layers of ``dscale_model`` and load their
    weights from a saved style GAN.

    A reference ``mu.UNET_STYLE`` generator (noise disabled) is rebuilt from
    the module-level hyper-parameters and loaded with
    ``tu.dummy_loader(dscale_gan_name)``. Every layer of ``dscale_model``
    whose name starts with ``'unet_left'`` is marked non-trainable. Then
    ``dscale_model`` is compiled (MSE loss, Adam with lr=0), and each such
    layer receives the weights of the same-named layer in the reference
    model. Layers without a same-named counterpart keep their current weights.

    Parameters
    ----------
    dscale_model : keras.Model
        Model whose ``'unet_left*'`` layers are frozen and overwritten.
    dscale_gan_name :
        Identifier passed to ``tu.dummy_loader`` (presumably a weight-file
        path/name — confirm).

    Returns
    -------
    keras.Model
        The same ``dscale_model`` object, compiled and updated in place.
    '''
    prefix = 'unet_left'

    saved_weights = tu.dummy_loader(dscale_gan_name)
    reference = mu.UNET_STYLE(N,
                              input_size,
                              latent_lev,
                              latent_size,
                              mapping_size,
                              pool=pool,
                              activation=activation,
                              noise=[False, False])
    reference.set_weights(saved_weights)

    # The encoder ('unet_left*') layers to be frozen and overwritten.
    targets = [lyr for lyr in dscale_model.layers if lyr.name.startswith(prefix)]

    # Flags must be set before compile() for Keras to respect them.
    for lyr in targets:
        lyr.trainable = False

    dscale_model.compile(loss=keras.losses.mean_squared_error,
                         optimizer=keras.optimizers.Adam(lr=0))

    # Match layers by name instead of scanning every pair of layers.
    donors = {lyr.name: lyr for lyr in reference.layers
              if lyr.name.startswith(prefix)}
    for lyr in targets:
        donor = donors.get(lyr.name)
        if donor is not None:
            lyr.set_weights(donor.get_weights())

    return dscale_model
Exemplo n.º 3
0
# Mapping-network width is taken from the last entry of N
# (presumably N lists per-level channel counts — confirm where N is defined).
mapping_size = N[-1]

# overwrite flags
# NOTE(review): each flag list has six entries, one per data variable/channel;
# the variable order is defined elsewhere (data loader) — confirm there.
input_flag = [False, True, False, False, True, True]  # LR T2, HR elev, LR elev
output_flag = [True, False, False, False, False, False]  # HR T2
# Element-wise OR of input_flag and output_flag.
inout_flag = [True, True, False, False, True, True]
labels = ['batch', 'batch']  # input and output labels
# NOTE(review): purpose of this key is not visible here (likely a model/file tag) — confirm.
key = 'SGAN'

# NOTE(review): this loop probably continues beyond this excerpt.
for sea in seasons:
    # ----- G ----- #
    # Generator: UNET_STYLE with N_input input channels and unspecified
    # spatial size (None, None).
    input_size = (None, None, N_input)
    G_style = mu.UNET_STYLE(N,
                            input_size,
                            latent_lev,
                            latent_size,
                            mapping_size,
                            pool=pool,
                            activation=activation,
                            noise=[0.2, 0.1])
    # lr=0: Adam steps are scaled by the learning rate, so G is not updated
    # by its own optimizer.
    opt_G = keras.optimizers.Adam(lr=0)
    print('Compiling G')
    G_style.compile(loss=keras.losses.mean_squared_error, optimizer=opt_G)
    # ------------- #

    # ----- D ----- #
    # Discriminator takes one extra channel compared with G's input
    # (presumably the real/generated target stacked on the inputs — confirm).
    # NOTE(review): this rebinds the global input_size; any later code that
    # reads it as the generator input shape will see this D shape instead.
    input_size = (None, None, N_input + 1)
    D = mu.vgg_descriminator(N, input_size)
    # NOTE(review): l[1] is assumed to be the discriminator learning rate — confirm.
    opt_D = keras.optimizers.Adam(lr=l[1])
    print('Compiling D')
    D.compile(loss=keras.losses.mean_squared_error, optimizer=opt_D)
    # Set after D.compile(): D keeps training when fit on its own, while models
    # compiled after this point treat D as frozen (usual Keras GAN pattern).
    D.trainable = False