示例#1
0
def predict(network_ver, dict_src_name):
    """Classify unlabeled worlds and write their previews into pro/notpro folders.

    Loads ``latest.h5`` of the ``pro_classifier`` model version
    ``network_ver``, runs it over worlds loaded from the ``worlds`` folder
    under ``res`` in batches of 50, and saves a PNG preview of every world
    whose id is not already in the label dictionary ``dict_src_name``.
    Worlds whose first model output is below 0.5 go to
    ``classifications/notpro``; all others go to ``classifications/pro``.
    The ``classifications`` directory is emptied first.

    Args:
        network_ver: Name of the classifier version directory.
        dict_src_name: Label dictionary name passed to
            ``utils.load_label_dict``; worlds it contains are skipped.
    """
    cur_dir = os.getcwd()
    res_dir = os.path.abspath(os.path.join(cur_dir, '..', 'res'))
    all_models_dir = os.path.abspath(os.path.join(cur_dir, '..', 'models'))
    model_dir = utils.check_or_create_local_path('pro_classifier',
                                                 all_models_dir)
    version_dir = utils.check_or_create_local_path(network_ver, model_dir)
    model_save_dir = utils.check_or_create_local_path('models', version_dir)

    classifications_dir = utils.check_or_create_local_path(
        'classifications', model_dir)
    utils.delete_files_in_path(classifications_dir)

    pro_dir = utils.check_or_create_local_path('pro', classifications_dir)
    notpro_dir = utils.check_or_create_local_path('notpro',
                                                  classifications_dir)

    print('Loading model...')
    classifier = load_model(f'{model_save_dir}\\latest.h5')

    print('Loading block images...')
    block_images = utils.load_block_images(res_dir)

    print('Loading encoding dictionaries...')
    # Only the forward mapping is needed to encode the model input here.
    block_forward, _ = utils.load_encoding_dict(res_dir, 'blocks_optimized')

    x_data, x_files = load_worlds_with_files(5000, f'{res_dir}\\worlds\\',
                                             (112, 112), block_forward)

    x_labeled = utils.load_label_dict(res_dir, dict_src_name)

    batch_size = 50
    world_total = x_data.shape[0]

    # Walk the data in steps of batch_size. The final slice may be shorter
    # than batch_size, so trailing worlds are classified too instead of being
    # dropped by a floor division.
    for start in range(0, world_total, batch_size):
        y_batch = classifier.predict(x_data[start:start + batch_size])

        for offset, prediction in enumerate(y_batch):
            world_id = utils.get_world_id(x_files[start + offset])

            # Ignore worlds we've already labeled
            if world_id in x_labeled:
                continue

            world_data = utils.load_world_data_ver3(
                f'{res_dir}\\worlds\\{world_id}.world')

            # A first output >= 0.5 is treated as "pro".
            target_dir = notpro_dir if prediction[0] < 0.5 else pro_dir
            utils.save_world_preview(block_images, world_data,
                                     f'{target_dir}\\{world_id}.png')
示例#2
0
def test(version_name, samples):
    """Write world, true-minimap and predicted-minimap images for inspection.

    Loads ``best_loss.h5`` of translator version ``version_name`` and loads
    up to ``samples`` worlds with their minimaps at the model's input size.
    For each loaded sample ``i`` it writes three images into the version's
    ``tests`` directory, which is emptied first:

    - ``world{i}.png``: the decoded world preview.
    - ``truth{i}.png``: the ground-truth minimap.
    - ``test{i}.png``: the minimap predicted by the translator.

    Args:
        version_name: Name of the translator version directory.
        samples: Number of samples requested from the loader.
    """
    cur_dir = os.getcwd()
    res_dir = os.path.abspath(os.path.join(cur_dir, '..', 'res'))
    all_models_dir = os.path.abspath(os.path.join(cur_dir, '..', 'models'))
    model_dir = utils.check_or_create_local_path('translator', all_models_dir)
    version_dir = utils.check_or_create_local_path(version_name, model_dir)
    model_save_dir = utils.check_or_create_local_path('models', version_dir)

    tests_dir = utils.check_or_create_local_path('tests', version_dir)
    utils.delete_files_in_path(tests_dir)

    print('Loading minimap values...')
    minimap_values = utils.load_minimap_values(res_dir)

    print('Loading block images...')
    block_images = utils.load_block_images(res_dir)

    print('Loading encoding dictionaries...')
    block_forward, block_backward = utils.load_encoding_dict(
        res_dir, 'blocks_optimized')

    print('Loading model...')
    translator = load_model(f'{model_save_dir}\\best_loss.h5')
    size = translator.input_shape[1]

    print('Loading worlds...')
    x_train, y_train = load_worlds_with_minimaps(samples,
                                                 f'{res_dir}\\worlds\\',
                                                 (size, size), block_forward,
                                                 minimap_values)

    # Don't assume the loader returned exactly `samples` entries; indexing
    # past what was actually loaded would raise IndexError.
    sample_count = min(samples, len(x_train))

    for i in range(sample_count):
        world = x_train[i]

        utils.save_world_preview(
            block_images, utils.decode_world_sigmoid(block_backward, world),
            f'{tests_dir}\\world{i}.png')
        utils.save_rgb_map(utils.decode_world_minimap(y_train[i]),
                           f'{tests_dir}\\truth{i}.png')

        # Predict on a batch containing just this world.
        y_predict = translator.predict(np.array([world]))
        utils.save_rgb_map(utils.decode_world_minimap(y_predict[0]),
                           f'{tests_dir}\\test{i}.png')
示例#3
0
def train(epochs, batch_size, world_count, version_name=None):
    """Train, or resume training of, the 112x112 world auto-encoder.

    When ``version_name`` is None, a new ``ver{N}`` directory is created
    under the ``auto_encoder`` models directory and old generated data in it
    is cleared. Otherwise training resumes in the named version from the
    epoch after the newest checkpoint. It loads the full ``autoencoder.h5``
    model if present, or else ``autoencoder.weights`` into a freshly built
    model.

    Each epoch does the following:

    - Shuffles the training worlds.
    - Trains on consecutive full batches; a trailing partial batch is unused.
    - Logs the loss to TensorBoard.
    - On the last batch, saves the reconstructed worlds/previews and
      previews of the actual input worlds.
    - Checkpoints the model every 100 batches and on the last batch.

    Args:
        epochs: Exclusive upper bound of the epoch index to train to.
        batch_size: Number of worlds per training batch.
        world_count: Number of worlds requested from the loader.
        version_name: Existing version to resume, or None for a new one.
    """
    cur_dir = os.getcwd()
    res_dir = os.path.abspath(os.path.join(cur_dir, '..', 'res'))
    all_models_dir = os.path.abspath(os.path.join(cur_dir, '..', 'models'))
    model_dir = utils.check_or_create_local_path('auto_encoder',
                                                 all_models_dir)
    utils.delete_empty_versions(model_dir, 1)

    no_version = version_name is None
    if no_version:
        latest = utils.get_latest_version(model_dir)
        version_name = f'ver{latest + 1}'

    version_dir = utils.check_or_create_local_path(version_name, model_dir)
    graph_dir = utils.check_or_create_local_path('graph', model_dir)
    graph_version_dir = utils.check_or_create_local_path(
        version_name, graph_dir)

    worlds_dir = utils.check_or_create_local_path('worlds', version_dir)
    previews_dir = utils.check_or_create_local_path('previews', version_dir)
    model_save_dir = utils.check_or_create_local_path('models', version_dir)

    # Resume from the epoch after the newest checkpoint; the code below
    # treats -1 as "no checkpoint".
    latest_epoch = utils.get_latest_epoch(model_save_dir)
    initial_epoch = latest_epoch + 1

    print('Saving source...')
    utils.save_source_to_dir(version_dir)

    # Load block images
    print('Loading block images...')
    block_images = utils.load_block_images(res_dir)

    print('Loading encoding dictionaries...')
    block_forward, block_backward = utils.load_encoding_dict(
        res_dir, 'blocks_optimized')

    # Load model and existing weights
    print('Loading model...')

    def compile_autoencoder(model):
        """Compile ``model`` with the auto-encoder loss and optimizer."""
        print('Compiling model...')
        model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0001))

    # Try to load full model, otherwise try to load weights
    ae = None
    if not no_version and latest_epoch != -1:
        checkpoint_dir = f'{version_dir}\\models\\epoch{latest_epoch}'
        if os.path.exists(f'{checkpoint_dir}\\autoencoder.h5'):
            print('Found models.')
            ae = load_model(f'{checkpoint_dir}\\autoencoder.h5')
        elif os.path.exists(f'{checkpoint_dir}\\autoencoder.weights'):
            print('Found weights.')
            ae = autoencoder_model(112)
            ae.load_weights(f'{checkpoint_dir}\\autoencoder.weights')
            compile_autoencoder(ae)

    # Model was not loaded, build and compile a new one
    if ae is None:
        print('Building model...')
        ae = autoencoder_model(112)
        compile_autoencoder(ae)

    if no_version:
        # Delete existing worlds and previews if any
        print('Checking for old generated data...')
        utils.delete_files_in_path(worlds_dir)
        utils.delete_files_in_path(previews_dir)

    print('Saving model images...')
    keras.utils.plot_model(ae,
                           to_file=f'{version_dir}\\autoencoder.png',
                           show_shapes=True,
                           show_layer_names=True)

    # Load Data
    print('Loading worlds...')
    x_train = load_worlds(world_count, f'{res_dir}\\worlds\\', (112, 112),
                          block_forward)

    # Use the number of worlds actually loaded; only full batches are used.
    world_count = x_train.shape[0]
    batch_cnt = world_count // batch_size

    # Set up tensorboard
    print('Setting up tensorboard...')
    tb_manager = TensorboardManager(graph_version_dir, batch_cnt)

    for epoch in range(initial_epoch, epochs):

        # Create directories for current epoch
        cur_worlds_dir = utils.check_or_create_local_path(
            f'epoch{epoch}', worlds_dir)
        cur_previews_dir = utils.check_or_create_local_path(
            f'epoch{epoch}', previews_dir)
        cur_models_dir = utils.check_or_create_local_path(
            f'epoch{epoch}', model_save_dir)

        print('Shuffling data...')
        np.random.shuffle(x_train)

        for batch in range(batch_cnt):
            is_last_batch = batch == batch_cnt - 1

            # Get real set of images
            world_batch = x_train[batch * batch_size:(batch + 1) * batch_size]

            # Train: the auto-encoder learns to reproduce its own input
            loss = ae.train_on_batch(world_batch, world_batch)

            # Save a snapshot of generated and actual worlds on the last batch
            if is_last_batch:
                generated = ae.predict(world_batch)

                for image_num in range(batch_size):
                    decoded_generated = utils.decode_world_sigmoid(
                        block_backward, generated[image_num])
                    utils.save_world_data(
                        decoded_generated,
                        f'{cur_worlds_dir}\\world{image_num}.world')
                    utils.save_world_preview(
                        block_images, decoded_generated,
                        f'{cur_previews_dir}\\preview{image_num}.png')

                    decoded_actual = utils.decode_world_sigmoid(
                        block_backward, world_batch[image_num])
                    utils.save_world_preview(
                        block_images, decoded_actual,
                        f'{cur_previews_dir}\\actual{image_num}.png')

            # Write loss
            tb_manager.log_var('ae_loss', epoch, batch, loss)

            print(
                f'epoch [{epoch}/{epochs}] :: batch [{batch}/{batch_cnt}] :: loss = {loss}'
            )

            # Save models every 100 batches and at the end of the epoch
            if batch % 100 == 99 or is_last_batch:
                print('Saving models...')
                try:
                    ae.save(f'{cur_models_dir}\\autoencoder.h5')
                    ae.save_weights(f'{cur_models_dir}\\autoencoder.weights')
                except ImportError:
                    # NOTE(review): presumably raised when the HDF5 backend
                    # is unavailable; saving is treated as best-effort.
                    print('Failed to save data.')
示例#4
0
def predict_sample_matlab(network_ver, samples):
    """Plot actual vs. auto-encoded previews for randomly chosen worlds.

    Loads the newest epoch's ``autoencoder.h5`` of auto-encoder version
    ``network_ver`` and walks the worlds folder in random order. For each
    world that yields at least one region, it plots the decoded original
    region ("Actual") next to the decoded reconstruction ("Encoded"). It
    stops after ``samples`` rows. The intermediate PNGs and the final
    ``plot.png`` are written to the auto-encoder ``plots`` directory, which
    is emptied first.

    Args:
        network_ver: Name of the auto-encoder version directory.
        samples: Number of rows (actual/encoded pairs) in the figure.
    """
    cur_dir = os.getcwd()
    res_dir = os.path.abspath(os.path.join(cur_dir, '..', 'res'))
    all_models_dir = os.path.abspath(os.path.join(cur_dir, '..', 'models'))
    model_dir = utils.check_or_create_local_path('auto_encoder',
                                                 all_models_dir)
    version_dir = utils.check_or_create_local_path(network_ver, model_dir)
    model_save_dir = utils.check_or_create_local_path('models', version_dir)

    plots_dir = utils.check_or_create_local_path('plots', model_dir)
    utils.delete_files_in_path(plots_dir)

    print('Loading model...')
    latest_epoch = utils.get_latest_epoch(model_save_dir)
    auto_encoder = load_model(
        f'{model_save_dir}\\epoch{latest_epoch}\\autoencoder.h5')

    print('Loading block images...')
    block_images = utils.load_block_images(res_dir)

    print('Loading encoding dictionaries...')
    block_forward, block_backward = utils.load_encoding_dict(
        res_dir, 'blocks_optimized')

    x_worlds = os.listdir(f'{res_dir}\\worlds\\')
    np.random.shuffle(x_worlds)

    world_size = auto_encoder.input_shape[1]
    dpi = 96
    rows = samples
    cols = 2  # left column: actual world, right column: reconstruction
    hpixels = 520 * cols
    hfigsize = hpixels / dpi
    vpixels = 530 * rows
    vfigsize = vpixels / dpi
    fig = plt.figure(figsize=(hfigsize, vfigsize), dpi=dpi)

    def set_ticks():
        """Put two tick labels (0 and world_size) on both axes of the current plot."""
        no_labels = 2  # how many labels to see on axis x
        step = (16 * world_size) / (no_labels - 1)  # step between labels
        # Positions are preview pixels; labels divide by 16 (presumably
        # pixels per block in the preview).
        positions = np.arange(0, (16 * world_size) + 1, step)
        labels = positions // 16
        plt.xticks(positions, labels)
        plt.yticks(positions, labels)

    # Index of the next free subplot; advances by 2 per plotted world.
    sample_num = 0
    for world_filename in x_worlds:
        world_file = os.path.join(f'{res_dir}\\worlds\\', world_filename)
        world_id = utils.get_world_id(world_filename)

        # Load world and save preview
        encoded_regions = load_world(world_file, (world_size, world_size),
                                     block_forward)
        if len(encoded_regions) == 0:
            continue

        # Create prediction on a batch of one region, fed as int8. The
        # channel count comes from the data instead of a hard-coded 10.
        batch_input = np.expand_dims(encoded_regions[0],
                                     axis=0).astype(np.int8)
        encoded_world = auto_encoder.predict(batch_input)

        before_path = f'{plots_dir}\\before{sample_num}.png'
        before = utils.decode_world_sigmoid(block_backward, encoded_regions[0])
        utils.save_world_preview(block_images, before, before_path)

        after_path = f'{plots_dir}\\after{sample_num}.png'
        after = utils.decode_world_sigmoid(block_backward, encoded_world[0])
        utils.save_world_preview(block_images, after, after_path)

        # Create before plot
        before_img = mpimg.imread(before_path)
        encoded_subplt = fig.add_subplot(rows, cols, sample_num + 1)
        encoded_subplt.set_title(f'{world_id}\nActual')
        set_ticks()
        plt.imshow(before_img)

        # Create after plot
        after_img = mpimg.imread(after_path)
        encoded_subplt = fig.add_subplot(rows, cols, sample_num + 2)
        encoded_subplt.set_title(f'{world_id}\nEncoded')
        set_ticks()
        plt.imshow(after_img)

        # Integer row number (true division would print e.g. "1.0").
        print(f'Added plot {sample_num // 2 + 1} of {samples}')

        sample_num += 2
        if sample_num >= rows * cols:
            break

    print('Saving figure...')
    fig.tight_layout()
    fig.savefig(f'{plots_dir}\\plot.png', transparent=True)
    # Release the figure held by pyplot's figure manager.
    plt.close(fig)
示例#5
0
def train(epochs, batch_size, world_count, version_name=None, initial_epoch=0):
    """Adversarially train the 32x32 helper model against a judge model.

    When ``version_name`` is None, a new ``ver{N}`` directory is created
    under the ``helper`` models directory and old worlds/previews in it are
    cleared. A fresh judge and a fresh helper (with its feedback model) are
    always built and compiled.

    Each epoch shuffles the worlds and runs these steps for every full batch:

    - Masks the batch with ``utils.mask_batch_low``.
    - Runs the helper on the masked batch plus Gaussian noise; its second
      output is used as the generated worlds.
    - Trains the judge to output 1 for real worlds and 0 for generated ones.
      The targets have shape ``(batch_size, 32 * 32, 1)``.
    - Trains the helper through ``helper_feedback`` towards target 1.

    Losses and accuracies go to TensorBoard. Every 1000 batches, and on the
    last batch of an epoch, previews of the masked input and the generated
    worlds are written and both models are saved.

    Args:
        epochs: Exclusive upper bound of the epoch index to train to.
        batch_size: Number of worlds per batch.
        world_count: Number of worlds requested from the loader.
        version_name: Version directory to use, or None for a new one.
        initial_epoch: First epoch index to run.
    """
    cur_dir = os.getcwd()
    res_dir = os.path.abspath(os.path.join(cur_dir, '..', 'res'))
    all_models_dir = os.path.abspath(os.path.join(cur_dir, '..', 'models'))
    model_dir = utils.check_or_create_local_path('helper', all_models_dir)

    utils.delete_empty_versions(model_dir, 1)

    no_version = version_name is None
    if no_version:
        latest = utils.get_latest_version(model_dir)
        version_name = f'ver{latest + 1}'

    version_dir = utils.check_or_create_local_path(version_name, model_dir)
    graph_dir = utils.check_or_create_local_path('graph', model_dir)
    graph_version_dir = utils.check_or_create_local_path(
        version_name, graph_dir)

    worlds_dir = utils.check_or_create_local_path('worlds', version_dir)
    previews_dir = utils.check_or_create_local_path('previews', version_dir)
    model_save_dir = utils.check_or_create_local_path('models', version_dir)

    print('Saving source...')
    utils.save_source_to_dir(version_dir)

    # Load block images
    print('Loading block images...')
    block_images = utils.load_block_images(res_dir)

    print('Loading encoding dictionaries...')
    block_forward, block_backward = utils.load_encoding_dict(
        res_dir, 'blocks_optimized')

    if no_version:
        # Delete existing worlds and previews if any
        print('Checking for old generated data...')
        utils.delete_files_in_path(worlds_dir)
        utils.delete_files_in_path(previews_dir)

    # Load model and existing weights
    print('Loading models...')

    judge = build_judge_model(32)
    judge_optimizer = Adam(lr=0.0001)
    judge.compile(loss='binary_crossentropy',
                  optimizer=judge_optimizer,
                  metrics=['accuracy'])

    # NOTE(review): judge is still trainable when helper_feedback is compiled
    # below. In Keras, changes to `trainable` only take effect at the next
    # compile(), so toggling judge.trainable inside the loop does not freeze
    # the judge within helper_feedback unless build_helper_feedback_model
    # handles that itself. Confirm.
    helper_optimizer = Adam(lr=0.001)
    helper = build_helper_model(32)
    helper_feedback = build_helper_feedback_model(helper, judge, 32)
    helper_feedback.compile(loss='binary_crossentropy',
                            optimizer=helper_optimizer)

    # Load Data
    print('Loading worlds...')
    x_train = load_worlds(world_count, f'{res_dir}\\worlds\\', (32, 32),
                          block_forward)

    # Start Training loop; only full batches are used.
    world_count = x_train.shape[0]
    batch_cnt = world_count // batch_size
    tb_manager = TensorboardManager(graph_version_dir, batch_cnt)

    # Per-cell judge targets; identical for every batch.
    real_labels = np.ones((batch_size, 32 * 32, 1))
    fake_labels = np.zeros((batch_size, 32 * 32, 1))

    for epoch in range(initial_epoch, epochs):

        print(f'Epoch = {epoch} ')
        # Create directories for current epoch. The worlds directory is
        # created but nothing is written to it by this function.
        utils.check_or_create_local_path(f'epoch{epoch}', worlds_dir)
        cur_previews_dir = utils.check_or_create_local_path(
            f'epoch{epoch}', previews_dir)
        cur_models_dir = utils.check_or_create_local_path(
            f'epoch{epoch}', model_save_dir)

        print('Shuffling data...')
        np.random.shuffle(x_train)

        for batch in range(batch_cnt):

            # Get real set of worlds and a masked copy of them
            world_batch = x_train[batch * batch_size:(batch + 1) * batch_size]
            world_batch_masked, _ = utils.mask_batch_low(world_batch)

            # Get fake set of worlds (second helper output)
            noise = np.random.normal(0, 1, size=(batch_size, 128))
            generated = helper.predict([world_batch_masked, noise])
            fake_worlds = generated[1]

            judge.trainable = True
            j_real = judge.train_on_batch([world_batch_masked, world_batch],
                                          real_labels)
            j_fake = judge.train_on_batch([world_batch_masked, fake_worlds],
                                          fake_labels)

            tb_manager.log_var('j_loss_real', epoch, batch, j_real[0])
            tb_manager.log_var('j_loss_fake', epoch, batch, j_fake[0])
            tb_manager.log_var('j_acc_real', epoch, batch, j_real[1])
            tb_manager.log_var('j_acc_fake', epoch, batch, j_fake[1])

            judge.trainable = False
            h_loss = helper_feedback.train_on_batch(
                [world_batch_masked, noise], real_labels)
            tb_manager.log_var('h_loss', epoch, batch, h_loss)

            print(
                f'epoch [{epoch}/{epochs}] :: batch [{batch}/{batch_cnt}] :: fake_loss = {j_fake[0]} :: fake_acc = '
                f'{j_fake[1]} :: real_loss = {j_real[0]} :: real_acc = {j_real[1]} :: h_loss = {h_loss}'
            )

            if batch % 1000 == 999 or batch == batch_cnt - 1:

                # Save masked input ("actual") and generated previews
                for i in range(batch_size):
                    a_decoded = utils.decode_world_sigmoid(
                        block_backward, world_batch_masked[i])
                    utils.save_world_preview(
                        block_images, a_decoded,
                        f'{cur_previews_dir}\\actual{i}.png')

                    decoded = utils.decode_world_sigmoid(
                        block_backward, fake_worlds[i])
                    utils.save_world_preview(
                        block_images, decoded,
                        f'{cur_previews_dir}\\preview{i}.png')

                # Save models
                try:
                    judge.save(f'{cur_models_dir}\\judge.h5')
                    helper.save(f'{cur_models_dir}\\helper.h5')
                    judge.save_weights(f'{cur_models_dir}\\judge.weights')
                    helper.save_weights(f'{cur_models_dir}\\helper.weights')
                except ImportError:
                    print('Failed to save data.')
示例#6
0
def predict_sample_matlab(network_ver, dict_src_name, cols, rows):
    """Plot a grid of unlabeled worlds spread across the classifier's score range.

    Loads ``latest.h5`` of ``pro_classifier`` version ``network_ver`` and
    walks the worlds folder in random order, skipping worlds already in the
    label dictionary ``dict_src_name``. With ``N = rows * cols``, subplot
    ``k`` (0-based) is filled by the first world whose first-output score
    lies in ``[k / N, (k + 1) / N]``, so the grid shows scores in increasing
    order. Each subplot is titled with the world id and labelled with its
    score. The previews and the final ``plot.png`` are written to the
    classifier ``plots`` directory, which is emptied first.

    Args:
        network_ver: Name of the classifier version directory.
        dict_src_name: Label dictionary name; worlds it contains are skipped.
        cols: Number of subplot columns.
        rows: Number of subplot rows.
    """
    cur_dir = os.getcwd()
    res_dir = os.path.abspath(os.path.join(cur_dir, '..', 'res'))
    all_models_dir = os.path.abspath(os.path.join(cur_dir, '..', 'models'))
    model_dir = utils.check_or_create_local_path('pro_classifier',
                                                 all_models_dir)
    version_dir = utils.check_or_create_local_path(network_ver, model_dir)
    model_save_dir = utils.check_or_create_local_path('models', version_dir)

    plots_dir = utils.check_or_create_local_path('plots', model_dir)
    utils.delete_files_in_path(plots_dir)

    print('Loading model...')
    classifier = load_model(f'{model_save_dir}\\latest.h5')

    print('Loading block images...')
    block_images = utils.load_block_images(res_dir)

    print('Loading encoding dictionaries...')
    block_forward, block_backward = utils.load_encoding_dict(
        res_dir, 'blocks_optimized')

    x_labeled = utils.load_label_dict(res_dir, dict_src_name)
    x_worlds = os.listdir(f'{res_dir}\\worlds\\')
    np.random.shuffle(x_worlds)

    world_size = classifier.input_shape[1]
    dpi = 96
    hpixels = 320 * cols
    hfigsize = hpixels / dpi
    vpixels = 330 * rows
    vfigsize = vpixels / dpi
    fig = plt.figure(figsize=(hfigsize, vfigsize), dpi=dpi)

    def set_ticks():
        """Put two tick labels (0 and world_size) on both axes of the current plot."""
        no_labels = 2  # how many labels to see on axis x
        step = (16 * world_size) / (no_labels - 1)  # step between labels
        # Positions are preview pixels; labels divide by 16 (presumably
        # pixels per block in the preview).
        positions = np.arange(0, (16 * world_size) + 1, step)
        labels = positions // 16
        plt.xticks(positions, labels)
        plt.yticks(positions, labels)

    total_plots = rows * cols
    sample_num = 0
    for world_filename in x_worlds:
        world_id = utils.get_world_id(world_filename)

        # Ignore worlds we've already labeled
        if world_id in x_labeled:
            continue

        world_file = os.path.join(f'{res_dir}\\worlds\\', world_filename)

        # Load world
        encoded_regions = load_world(world_file, (world_size, world_size),
                                     block_forward)
        if len(encoded_regions) == 0:
            continue

        # Create prediction on a batch of one region, fed as int8. The
        # channel count comes from the data instead of a hard-coded 10.
        batch_input = np.expand_dims(encoded_regions[0],
                                     axis=0).astype(np.int8)
        pro_score = classifier.predict(batch_input)[0][0]

        # Score window of the next free subplot. Deriving the bounds from
        # sample_num, rather than accumulating 1/N repeatedly, keeps them
        # free of float drift. In particular, the last ceiling is exactly
        # 1.0, so a saturated score of 1.0 can still fill the last cell.
        pro_score_floor = sample_num / total_plots
        pro_score_ceiling = (sample_num + 1) / total_plots
        if pro_score < pro_score_floor or pro_score > pro_score_ceiling:
            continue

        preview_path = f'{plots_dir}\\preview{sample_num}.png'
        decoded_region = utils.decode_world_sigmoid(block_backward,
                                                    encoded_regions[0])
        utils.save_world_preview(block_images, decoded_region, preview_path)

        # Create plot
        img = mpimg.imread(preview_path)
        subplt = fig.add_subplot(rows, cols, sample_num + 1)
        subplt.set_title(world_id)
        subplt.set_xlabel('P = %.2f%%' % (pro_score * 100))
        set_ticks()
        plt.imshow(img)

        print(f'Adding plot {sample_num + 1} of {total_plots}')

        sample_num += 1
        if sample_num >= total_plots:
            break

    print('Saving figure...')
    fig.tight_layout()
    fig.savefig(f'{plots_dir}\\plot.png', transparent=True)
    # Release the figure held by pyplot's figure manager.
    plt.close(fig)
示例#7
0
def train(epochs,
          batch_size,
          world_count,
          latent_dim,
          version_name=None,
          initial_epoch=0):
    """Train, or resume training of, the 64x64 world GAN.

    When ``version_name`` is None, a new ``ver{N}`` directory is created
    under the ``gan`` models directory, old worlds/previews in it are
    cleared, and model diagrams are written. The discriminator and generator
    are restored from the checkpoint of epoch ``initial_epoch - 1`` when one
    exists (full ``.h5`` models first, then ``.weights``); otherwise they are
    built from scratch. The stacked generator+discriminator model is always
    assembled from those two instances.

    Training data come from ``load_worlds_with_label`` using the
    ``pro_labels_b`` dictionary with label ``1``. Each epoch shuffles the
    data and runs these steps for every full batch:

    - Trains the discriminator on real worlds (target 1).
    - Trains the discriminator on generated worlds (target 0).
    - Trains the generator through the stacked model towards target 1.

    Metrics go to TensorBoard. Every 5 minutes, and on the last batch of an
    epoch, the current fake batch is saved as worlds/previews and all models
    are checkpointed.

    Args:
        epochs: Exclusive upper bound of the epoch index to train to.
        batch_size: Number of worlds per batch.
        world_count: Number of worlds requested from the loader.
        latent_dim: Size of the generator's noise input.
        version_name: Version directory to use, or None for a new one.
        initial_epoch: First epoch index to run; the previous epoch's
            checkpoint is loaded if present.
    """
    cur_dir = os.getcwd()
    res_dir = os.path.abspath(os.path.join(cur_dir, '..', 'res'))
    all_models_dir = os.path.abspath(os.path.join(cur_dir, '..', 'models'))
    model_dir = utils.check_or_create_local_path('gan', all_models_dir)

    utils.delete_empty_versions(model_dir, 1)
    no_version = version_name is None
    if no_version:
        latest = utils.get_latest_version(model_dir)
        version_name = f'ver{latest + 1}'

    version_dir = utils.check_or_create_local_path(version_name, model_dir)
    graph_dir = utils.check_or_create_local_path('graph', model_dir)
    graph_version_dir = utils.check_or_create_local_path(
        version_name, graph_dir)

    worlds_dir = utils.check_or_create_local_path('worlds', version_dir)
    previews_dir = utils.check_or_create_local_path('previews', version_dir)
    model_save_dir = utils.check_or_create_local_path('models', version_dir)

    print('Saving source...')
    utils.save_source_to_dir(version_dir)

    print('Loading block images...')
    block_images = utils.load_block_images(res_dir)

    print('Loading encoding dictionaries...')
    block_forward, block_backward = utils.load_encoding_dict(
        res_dir, 'blocks_optimized')

    # Load model and existing weights
    print('Loading model...')

    # Try to load full model, otherwise try to load weights
    size = 64
    cur_models = f'{model_save_dir}\\epoch{initial_epoch - 1}'
    if os.path.exists(f'{cur_models}\\discriminator.h5') and os.path.exists(
            f'{cur_models}\\generator.h5'):
        print('Building model from files...')
        d = load_model(f'{cur_models}\\discriminator.h5')
        g = load_model(f'{cur_models}\\generator.h5')
    elif os.path.exists(
            f'{cur_models}\\discriminator.weights') and os.path.exists(
                f'{cur_models}\\generator.weights'):
        print('Building model with weights...')
        d = build_discriminator(size)
        d.load_weights(f'{cur_models}\\discriminator.weights')
        d.compile(loss='binary_crossentropy',
                  optimizer=Adam(lr=0.00001),
                  metrics=['accuracy'])

        # Build with the same latent size as a from-scratch run, so the
        # restored weights match the noise sampled in the training loop.
        g = build_generator(size, latent_dim)
        g.load_weights(f'{cur_models}\\generator.weights')
    else:
        print('Building model from scratch...')
        d = build_discriminator(size)
        d.compile(loss='binary_crossentropy',
                  optimizer=Adam(lr=0.00001),
                  metrics=['accuracy'])
        d.summary()

        g = build_generator(size, latent_dim)
        g.summary()

    # Always stack the *same* g and d objects that are trained below.
    # Loading d_g.h5 with load_model would create independent copies of both
    # networks. Generator updates would then never reach `g`, which produces
    # the fake worlds and is what gets checkpointed. d_g.h5 is still written
    # at save time.
    d_on_g = generator_containing_discriminator(g, d)
    d_on_g.compile(loss='binary_crossentropy',
                   optimizer=Adam(lr=0.0001, beta_1=0.5))

    if no_version:
        # Delete existing worlds and previews if any
        print('Checking for old generated data...')
        utils.delete_files_in_path(worlds_dir)
        utils.delete_files_in_path(previews_dir)

        print('Saving model images...')
        keras.utils.plot_model(d,
                               to_file=f'{version_dir}\\discriminator.png',
                               show_shapes=True,
                               show_layer_names=True)
        keras.utils.plot_model(g,
                               to_file=f'{version_dir}\\generator.png',
                               show_shapes=True,
                               show_layer_names=True)

    # Load Data
    print('Loading worlds...')
    label_dict = utils.load_label_dict(res_dir, 'pro_labels_b')
    x_train = load_worlds_with_label(world_count,
                                     f'{res_dir}\\worlds\\',
                                     label_dict,
                                     1, (size, size),
                                     block_forward,
                                     overlap_x=0.1,
                                     overlap_y=0.1)

    # Use the number of worlds actually loaded; only full batches are used.
    world_count = x_train.shape[0]
    batch_cnt = world_count // batch_size

    # Set up tensorboard
    print('Setting up tensorboard...')
    tb_manager = TensorboardManager(graph_version_dir, batch_cnt)

    # Discriminator targets; identical for every batch.
    real_labels = np.ones((batch_size, 1))
    fake_labels = np.zeros((batch_size, 1))

    preview_frequency_sec = 5 * 60.0
    for epoch in range(initial_epoch, epochs):

        # Create directories for current epoch
        cur_worlds_dir = utils.check_or_create_local_path(
            f'epoch{epoch}', worlds_dir)
        cur_previews_dir = utils.check_or_create_local_path(
            f'epoch{epoch}', previews_dir)
        cur_models_dir = utils.check_or_create_local_path(
            f'epoch{epoch}', model_save_dir)

        print('Shuffling data...')
        np.random.shuffle(x_train)

        last_save_time = time.time()
        for batch in range(batch_cnt):

            # Get real set of images
            real_worlds = x_train[batch * batch_size:(batch + 1) * batch_size]

            # Get fake set of images
            noise = np.random.normal(0, 1, size=(batch_size, latent_dim))
            fake_worlds = g.predict(noise)

            # Train discriminator on real worlds
            d.trainable = True
            d_loss = d.train_on_batch(real_worlds, real_labels)
            loss_real, acc_real = d_loss[0], d_loss[1]
            tb_manager.log_var('d_acc_real', epoch, batch, acc_real)
            tb_manager.log_var('d_loss_real', epoch, batch, loss_real)

            # Train discriminator on fake worlds
            d_loss = d.train_on_batch(fake_worlds, fake_labels)
            d.trainable = False
            loss_fake, acc_fake = d_loss[0], d_loss[1]
            tb_manager.log_var('d_acc_fake', epoch, batch, acc_fake)
            tb_manager.log_var('d_loss_fake', epoch, batch, loss_fake)

            # Train generator to make the discriminator output "real"
            g_loss = d_on_g.train_on_batch(noise, real_labels)
            tb_manager.log_var('g_loss', epoch, batch, g_loss)

            print(
                f'epoch [{epoch}/{epochs}] :: batch [{batch}/{batch_cnt}] :: fake_acc = {acc_fake} :: '
                f'real_acc = {acc_real} :: fake_loss = {loss_fake} :: real_loss = {loss_real} :: gen_loss = {g_loss}'
            )

            # Save previews and models periodically and at the end of the epoch
            time_since_save = time.time() - last_save_time
            if time_since_save >= preview_frequency_sec or batch == batch_cnt - 1:
                print('Saving previews...')
                for i in range(batch_size):
                    decoded_world = utils.decode_world_sigmoid(
                        block_backward, fake_worlds[i])
                    utils.save_world_data(decoded_world,
                                          f'{cur_worlds_dir}\\world{i}.world')
                    utils.save_world_preview(
                        block_images, decoded_world,
                        f'{cur_previews_dir}\\preview{i}.png')

                print('Saving models...')
                try:
                    d.save(f'{cur_models_dir}\\discriminator.h5')
                    g.save(f'{cur_models_dir}\\generator.h5')
                    d_on_g.save(f'{cur_models_dir}\\d_g.h5')
                    d.save_weights(f'{cur_models_dir}\\discriminator.weights')
                    g.save_weights(f'{cur_models_dir}\\generator.weights')
                    d_on_g.save_weights(f'{cur_models_dir}\\d_g.weights')
                except ImportError:
                    print('Failed to save data.')

                last_save_time = time.time()
示例#8
0
def _save_unet_snapshot(unet, block_images, block_backward, world_batch,
                        world_batch_masked, world_masks, models_dir,
                        previews_dir, batch):
    """Checkpoint ``unet`` and write preview images for the current batch.

    The full model (``unet.h5``) and its weights (``unet.weights``) are saved
    to ``models_dir``. An ``ImportError`` while saving is reported and ignored
    (presumably raised when the HDF5 backend is unavailable). Then the first
    world of the batch is decoded and rendered into ``previews_dir`` three
    ways, in this order: the original world (``{batch}_orig.png``), the
    network's prediction (``{batch}_fixed.png``) and the masked input
    (``{batch}_masked.png``).
    """
    try:
        unet.save(os.path.join(models_dir, 'unet.h5'))
        unet.save_weights(os.path.join(models_dir, 'unet.weights'))
    except ImportError:
        print('Failed to save data.')

    predicted = unet.predict([world_batch_masked, world_masks])

    previews = (
        ('orig', world_batch[0]),
        ('fixed', predicted[0]),
        ('masked', world_batch_masked[0]),
    )
    for suffix, encoded_world in previews:
        decoded = utils.decode_world_sigmoid(block_backward, encoded_world)
        utils.save_world_preview(
            block_images, decoded,
            os.path.join(previews_dir, f'{batch}_{suffix}.png'))


def train(epochs, batch_size, world_count, version_name=None, initial_epoch=0):
    """Train the partial-convolution inpainting U-Net on saved worlds.

    Outputs go under ``../models/inpainting/<version_name>/`` (source copy,
    model diagram, per-epoch ``models/`` and ``previews/`` directories), and
    TensorBoard summaries go under ``../models/inpainting/graph/<version_name>``.

    Args:
        epochs: Exclusive upper bound of the epoch range to run.
        batch_size: Number of worlds per training batch. The trailing partial
            batch of each epoch is dropped.
        world_count: Number of worlds requested from ``load_worlds``. The
            count actually used is taken from the loaded array.
        version_name: Version directory name. When ``None``, ``ver{latest}``
            (from ``utils.get_latest_version``) is used, and any previously
            generated worlds/previews in that version are deleted.
        initial_epoch: First epoch index, e.g. when resuming a run.

    Every 1000th batch, and the last batch of each epoch, the model is
    checkpointed and previews are written *before* training on that batch.
    The loss, divided by 1000, is logged to TensorBoard after every batch.
    """
    cur_dir = os.getcwd()
    res_dir = os.path.abspath(os.path.join(cur_dir, '..', 'res'))
    all_models_dir = os.path.abspath(os.path.join(cur_dir, '..', 'models'))
    model_dir = utils.check_or_create_local_path('inpainting', all_models_dir)

    utils.delete_empty_versions(model_dir, 0)

    no_version = version_name is None
    if no_version:
        latest = utils.get_latest_version(model_dir)
        version_name = f'ver{latest}'

    version_dir = utils.check_or_create_local_path(version_name, model_dir)
    graph_dir = utils.check_or_create_local_path('graph', model_dir)
    graph_version_dir = utils.check_or_create_local_path(version_name, graph_dir)

    worlds_dir = utils.check_or_create_local_path('worlds', version_dir)
    previews_dir = utils.check_or_create_local_path('previews', version_dir)
    model_save_dir = utils.check_or_create_local_path('models', version_dir)

    print('Saving source...')
    utils.save_source_to_dir(version_dir)

    print('Loading block images...')
    block_images = utils.load_block_images(res_dir)

    print('Loading encoding dictionaries...')
    block_forward, block_backward = utils.load_encoding_dict(res_dir, 'blocks_optimized')

    # PConvUnet is handed a pretrained auto-encoder plus the indices of the
    # layers it should use as features (presumably for a feature-based loss;
    # NOTE(review): confirm in PConvUnet). The checkpoint path is fixed.
    print('Loading model...')
    feature_model = auto_encoder.autoencoder_model()
    feature_model.load_weights(os.path.join(
        all_models_dir, 'auto_encoder', 'ver15', 'models', 'epoch28',
        'autoencoder.weights'))
    feature_layers = [7, 14, 21]

    contextnet = PConvUnet(feature_model, feature_layers, inference_only=False)
    unet = contextnet.build_pconv_unet(train_bn=True, lr=0.0001)
    unet.summary()

    if no_version:
        # A fresh auto-selected version: remove leftovers from earlier runs.
        print('Checking for old generated data...')
        utils.delete_files_in_path(worlds_dir)
        utils.delete_files_in_path(previews_dir)

    print('Saving model images...')
    keras.utils.plot_model(unet, to_file=os.path.join(version_dir, 'unet.png'),
                           show_shapes=True, show_layer_names=True)

    print('Setting up tensorboard...')
    tb_writer = tf.summary.FileWriter(logdir=graph_version_dir)
    unet_loss_summary = tf.Summary()
    unet_loss_summary.value.add(tag='unet_loss', simple_value=None)

    # The trailing '' keeps the directory argument ending in a separator,
    # matching the form load_worlds was previously given.
    x_train = load_worlds(world_count, os.path.join(res_dir, 'worlds', ''),
                          (128, 128), block_forward)

    world_count = x_train.shape[0]
    batch_cnt = world_count // batch_size  # trailing partial batch is dropped
    if batch_cnt == 0:
        print(f'Warning: only {world_count} worlds loaded, fewer than '
              f'batch_size={batch_size}; no training batches will run.')

    for epoch in range(initial_epoch, epochs):

        print(f'Epoch = {epoch}')
        # Per-epoch output directories. The worlds directory is only created
        # here (nothing is written to it) so the layout matches other runs.
        utils.check_or_create_local_path(f'epoch{epoch}', worlds_dir)
        cur_previews_dir = utils.check_or_create_local_path(f'epoch{epoch}', previews_dir)
        cur_models_dir = utils.check_or_create_local_path(f'epoch{epoch}', model_save_dir)

        print('Shuffling data...')
        np.random.shuffle(x_train)

        for batch in range(batch_cnt):

            start = batch * batch_size
            world_batch = x_train[start:start + batch_size]
            world_batch_masked, world_masks = utils.mask_batch_high(world_batch)

            # Snapshot (checkpoint + previews) every 1000 batches and at the
            # end of the epoch, taken before this batch's training step.
            if batch % 1000 == 999 or batch == batch_cnt - 1:
                _save_unet_snapshot(unet, block_images, block_backward,
                                    world_batch, world_batch_masked, world_masks,
                                    cur_models_dir, cur_previews_dir, batch)

            loss = unet.train_on_batch([world_batch_masked, world_masks], world_batch)

            # Divide by 1000 for better Y-axis values in TensorBoard.
            unet_loss_summary.value[0].simple_value = loss / 1000.0
            tb_writer.add_summary(unet_loss_summary, (epoch * batch_cnt) + batch)
            tb_writer.flush()

            print(f'epoch [{epoch}/{epochs}] :: batch [{batch}/{batch_cnt}] :: unet_loss = {loss}')