コード例 #1
0
def write_images():
    """Load each generator checkpoint in order and render a visualization for it.

    Generator checkpoints are read from ``checkpoints/<RUN_NAME>/G`` in sorted
    filename order; checkpoint ``i`` (0-based) is titled ``Epoch {i + 1}``.
    Every checkpoint renders the same messages
    (``make_first_n_messages(num_messages, VECTOR_DIM, ENCODING)``).

    If ``PREVIEW`` is truthy, each figure is only passed to ``visualize`` with
    ``get_fig=False``; otherwise the returned figure is laid out and saved as
    ``images/<RUN_NAME>/NNN.png`` (``NNN`` = zero-padded ``i + 1``).
    """
    G_names = sorted(os.listdir(os.path.join('checkpoints', RUN_NAME, 'G')))

    os.makedirs(os.path.join('images', RUN_NAME), exist_ok=True)
    custom_objects = {'swish': swish}

    # The messages and their symbols do not depend on the checkpoint, so build
    # them once instead of once per checkpoint.
    # NOTE(review): assumes make_first_n_messages is deterministic (as its name
    # suggests) — confirm if it ever draws random messages.
    messages = make_first_n_messages(num_messages, VECTOR_DIM, ENCODING)
    symbols = [get_glyph_symbol(message, ENCODING, VECTOR_DIM) for message in messages]

    # Only the generator is needed to render glyphs, so iterate over the G
    # checkpoints alone. Zipping with the discriminator checkpoints would stop
    # at the shorter list and silently skip G checkpoints whenever fewer D
    # checkpoints exist. Passing the list itself to tqdm lets it show a total.
    for i, G_name in enumerate(tqdm(G_names)):
        G = tf.keras.models.load_model(
            os.path.join('checkpoints', RUN_NAME, 'G', G_name),
            custom_objects=custom_objects,
            compile=False,
        )
        glyphs = G(messages)
        title = f'Epoch {i + 1}'

        if PREVIEW:
            visualize(symbols, glyphs, title, get_fig=False, use_titles=False)
            continue

        fig = visualize(symbols, glyphs, title, get_fig=True, use_titles=False)
        # Layout (pixels). Computed only here because message_dim and
        # num_message_dim are only needed when writing images.
        margin = 20
        top_margin = 40
        grid_size = message_dim * num_message_dim * 1.1
        fig.update_layout(width=grid_size + margin * 2,
                          height=grid_size + margin + top_margin,
                          margin=go.layout.Margin(l=margin,
                                                  r=margin,
                                                  b=margin,
                                                  t=top_margin,
                                                  pad=0))
        fig.update_xaxes(showticklabels=False)
        fig.update_yaxes(showticklabels=False)
        fig.write_image(os.path.join('images', RUN_NAME, f'{i+1:03d}.png'))
コード例 #2
0
def visualize_samples(dim, title):
    """Render a ``dim`` x ``dim`` grid of generated glyphs under ``title``.

    ``dim ** 2`` messages are produced by ``make_messages`` using the encoding
    and vector dimension from ``opt``. Each message is labelled with its glyph
    symbol, and the batch is rendered by the generator ``G``.
    """
    encoding = opt.encoding
    vector_dim = opt.vector_dim

    sample_count = dim ** 2
    messages, _ = make_messages(sample_count, encoding, vector_dim)

    labels = [get_glyph_symbol(msg, encoding, vector_dim) for msg in messages]
    rendered = G(messages)
    visualize(labels, rendered, title)
コード例 #3
0
                this_difficulty = None
                if randomize:
                    this_difficulty = random.choice(list(range(DIFFICULTY +
                                                               1)))
                else:
                    this_difficulty = DIFFICULTY
                if this_difficulty != 0:
                    glyphs = func(glyphs, this_difficulty)
            glyphs = random_augmentation(glyphs)
            return glyphs

    funcs = []
    for func_name in func_names:
        assert func_name in dir(
            Differentiable_Augment), f"Function '{func_name}' doesn't exist"
        funcs.append(getattr(Differentiable_Augment, func_name))
    return lambda glyphs, DIFFICULTY, randomize=False: noise_pipeline(
        glyphs, funcs, DIFFICULTY, randomize)


# Preview the augmentation pipeline when run as a script: show a batch of
# glyphs from random_glyphs, then the same batch passed through the noisy
# channel at each difficulty level from 0 to 9.
if __name__ == "__main__":
    num_previews = 9  # number of glyphs shown in every preview figure
    symbols = ['random'] * num_previews
    # NOTE(review): [64, 64, 3] is presumably (height, width, channels) — confirm.
    glyphs = random_glyphs(num_previews, [64, 64, 3])
    noisy_channel = get_noisy_channel()
    visualize(symbols, glyphs, 'Before Augmentation')
    # NOTE(review): this loop rebinds the module-level name DIFFICULTY; the
    # name is kept in case other code reads that global — confirm before renaming.
    for DIFFICULTY in range(10):
        new_glyphs = noisy_channel(glyphs, DIFFICULTY)
        visualize(symbols, new_glyphs,
                  f'After Augmentation (difficulty = {DIFFICULTY})')