示例#1
0
 def __init__(self, model_path):
     """Load the model saved at ``model_path`` together with its metadata.

     Besides the model itself, this reads the
     ``decoder_uses_hist_equalization`` HDF5 attribute (``True`` is passed
     as the third argument of ``get_hdf5_attr`` -- presumably the fallback
     when the attribute is missing), builds a ``DistributionCollection``
     from the same file and wraps ``model.predict`` with
     ``predict_wrapper`` using the model's output names.
     """
     model = load_model(model_path)
     hist_equalization = get_hdf5_attr(
         model_path, 'decoder_uses_hist_equalization', True)

     self.model = model
     self.uses_hist_equalization = hist_equalization
     self.distribution = DistributionCollection.from_hdf5(model_path)
     self._predict = predict_wrapper(model.predict, model.output_names)
     model._make_predict_function()
示例#2
0
 def __init__(self, model_path):
     """Restore a saved model and the metadata stored next to it.

     ``model_path`` is used three times: to load the model, to read the
     ``decoder_uses_hist_equalization`` HDF5 attribute (with ``True``
     passed as the third argument of ``get_hdf5_attr``) and to build the
     ``DistributionCollection``. ``self._predict`` is the model's
     ``predict`` wrapped by ``predict_wrapper`` with the output names.
     """
     self.model = keras_model = load_model(model_path)

     attr_name = 'decoder_uses_hist_equalization'
     self.uses_hist_equalization = get_hdf5_attr(model_path, attr_name, True)
     self.distribution = DistributionCollection.from_hdf5(model_path)

     output_names = keras_model.output_names
     self._predict = predict_wrapper(keras_model.predict, output_names)
     keras_model._make_predict_function()
示例#3
0
def run(g_weights_fname, d_weights_fname, selected_outputs, nb_samples,
        out_fname):
    """Sample from a saved generator and write the results to an HDF5 file.

    Batches of uniform noise are fed through the generator, the generator's
    ``fake`` output is scored by the discriminator, and every batch is
    appended to a ``DistributionHDF5Dataset`` until the position returned
    by ``dset.append`` reaches ``nb_samples``.

    Args:
        g_weights_fname: Path of the saved generator. Its HDF5
            ``distribution`` attribute is decoded as UTF-8 and parsed as
            JSON to rebuild the label distribution.
        d_weights_fname: Path of the saved discriminator.
        selected_outputs: Either ``'all'`` or a container of generator
            output names to keep in the dataset.
        nb_samples: Number of samples the dataset is created for.
        out_fname: Destination file. Its parent directory is created when
            the path has one.

    Raises:
        Exception: If filtering by ``selected_outputs`` removes every
            generator output.
    """
    generator = load_model(g_weights_fname, render_gan_custom_objects())
    discriminator = load_model(d_weights_fname, render_gan_custom_objects())
    generator._make_predict_function()
    discriminator._make_predict_function()
    dist_json = get_hdf5_attr(g_weights_fname, 'distribution').decode('utf-8')
    dist = diktya.distributions.load_from_json(dist_json)

    # os.path.dirname() yields '' for a bare file name and os.makedirs('')
    # raises, so only create the directory when the path contains one.
    out_dir = os.path.dirname(out_fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    dset = DistributionHDF5Dataset(out_fname,
                                   mode='w',
                                   nb_samples=nb_samples,
                                   distribution=dist)
    # From here on the dataset is open: close it even if sampling fails.
    try:
        batch_size = 100
        available_outputs = [
            name for name in generator.output_names if name != 'labels'
        ]
        print("Avialable outputs: " + ", ".join(available_outputs))
        generator_predict = predict_wrapper(
            lambda x: generator.predict(x, batch_size),
            generator.output_names)

        def sample_generator():
            """Yield dicts of generator outputs, labels and d-scores."""
            z_shape = get_layer(generator.inputs[0]).batch_input_shape
            while True:
                z = np.random.uniform(-1, 1, (batch_size, ) + z_shape[1:])
                outs = generator_predict(z)
                raw_labels = outs.pop('labels')
                # Keep a handle on the fakes before filtering so that the
                # discriminator can score them even when 'fake' itself is
                # not among the selected outputs.
                fake = outs['fake']

                # Split the flat label matrix into the named fields of the
                # distribution's structured dtype.
                labels = np.zeros(len(raw_labels), dtype=dist.norm_dtype)
                pos = 0
                for name, size in dist.norm_nb_elems.items():
                    labels[name] = raw_labels[:, pos:pos + size]
                    pos += size

                deleted_keys = []
                if selected_outputs != 'all':
                    for name in list(outs.keys()):
                        if name not in selected_outputs:
                            del outs[name]
                            deleted_keys.append(name)
                if not outs:
                    raise Exception(
                        "Got no outputs. Removed {}. Selected outputs {}"
                        .format(deleted_keys, selected_outputs))
                outs['labels'] = labels
                outs['discriminator'] = discriminator.predict(fake)
                yield outs

        bar = progressbar.ProgressBar(max_value=nb_samples)
        for batch in sample_generator():
            pos = dset.append(**batch)
            bar.update(pos)
            if pos >= nb_samples:
                break
    finally:
        dset.close()
    print("Saved dataset with fakes and labels to: {}".format(out_fname))
def run(g_weights_fname, d_weights_fname, selected_outputs, nb_samples, out_fname):
    """Generate samples with a trained generator and store them as HDF5.

    The generator is driven with uniform noise in ``[-1, 1)``; its ``fake``
    output is passed to the discriminator, and each batch (selected
    outputs, structured labels and discriminator scores) is appended to a
    ``DistributionHDF5Dataset``. Sampling stops once ``dset.append``
    reports a position of at least ``nb_samples``.

    Args:
        g_weights_fname: Saved generator file. The ``distribution`` HDF5
            attribute of this file is decoded as UTF-8 JSON.
        d_weights_fname: Saved discriminator file.
        selected_outputs: ``'all'``, or a container with the names of the
            generator outputs that should be stored.
        nb_samples: Number of samples the dataset is created for.
        out_fname: Output file name; the parent directory is created if
            the name contains one.

    Raises:
        Exception: If no generator output is left after applying
            ``selected_outputs``.
    """
    generator = load_model(g_weights_fname, render_gan_custom_objects())
    discriminator = load_model(d_weights_fname, render_gan_custom_objects())
    generator._make_predict_function()
    discriminator._make_predict_function()
    dist_json = get_hdf5_attr(g_weights_fname, 'distribution').decode('utf-8')
    dist = diktya.distributions.load_from_json(dist_json)

    # A bare file name has an empty dirname, and os.makedirs('') raises.
    parent_dir = os.path.dirname(out_fname)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    dset = DistributionHDF5Dataset(out_fname, mode='w', nb_samples=nb_samples,
                                   distribution=dist)
    batch_size = 100

    def sample_generator(generator_predict):
        """Endlessly yield one dict per generated batch."""
        z_shape = get_layer(generator.inputs[0]).batch_input_shape
        keep_all = selected_outputs == 'all'
        while True:
            z = np.random.uniform(-1, 1, (batch_size, ) + z_shape[1:])
            outs = generator_predict(z)
            raw_labels = outs.pop('labels')
            # Look up the fakes before filtering: the discriminator needs
            # them even if 'fake' is not part of the selected outputs.
            fake = outs['fake']

            # Copy consecutive column ranges of the flat label matrix into
            # the named fields of the structured label array.
            labels = np.zeros(len(raw_labels), dtype=dist.norm_dtype)
            offset = 0
            for name, size in dist.norm_nb_elems.items():
                labels[name] = raw_labels[:, offset:offset + size]
                offset += size

            deleted_keys = []
            if not keep_all:
                for name in list(outs.keys()):
                    if name not in selected_outputs:
                        del outs[name]
                        deleted_keys.append(name)
            if not outs:
                raise Exception("Got no outputs. Removed {}. Selected outputs {}"
                                .format(deleted_keys, selected_outputs))
            outs['labels'] = labels
            outs['discriminator'] = discriminator.predict(fake)
            yield outs

    # The dataset is open from here on; make sure it is closed on errors too.
    try:
        available_outputs = [name for name in generator.output_names
                             if name != 'labels']
        print("Avialable outputs: " + ", ".join(available_outputs))
        predict = predict_wrapper(lambda x: generator.predict(x, batch_size),
                                  generator.output_names)

        bar = progressbar.ProgressBar(max_value=nb_samples)
        for batch in sample_generator(predict):
            pos = dset.append(**batch)
            bar.update(pos)
            if pos >= nb_samples:
                break
    finally:
        dset.close()
    print("Saved dataset with fakes and labels to: {}".format(out_fname))
示例#5
0
def train_callbacks(rendergan,
                    output_dir,
                    nb_visualise,
                    real_hdf5_fname,
                    distribution,
                    lr_schedule=None,
                    overwrite=False):
    """Assemble the callbacks used while training ``rendergan``.

    Args:
        rendergan: Object providing ``gan`` (with ``g_optimizer``,
            ``d_optimizer`` and ``random_z``), ``discriminator``,
            ``sample_generator_given_z`` and
            ``sample_generator_given_z_output_names``.
        output_dir: Root directory for everything the callbacks write. The
            ``history``, ``samples`` and ``d_score_hist`` sub-directories
            are created here; the model and visualisation paths are only
            handed to their callbacks.
        nb_visualise: Number of samples for the visualisation callbacks.
        real_hdf5_fname: File passed to ``train_data_generator``; the
            ``'data'`` entry of its first item is used as the real samples.
        distribution: Passed to ``get_distribution_hdf5_attrs`` and
            ``StoreSamples``.
        lr_schedule: Callable that receives an optimizer's current learning
            rate and returns the schedule for ``LearningRateScheduler``.
            Defaults to ``{200: lr / 4, 250: lr / 16, 300: lr / 64}``
            (keys presumably are epoch numbers).
        overwrite: Passed to ``StoreSamples``.

    Returns:
        list: ``[sample_cb, save_gan_cb, hist, hist_save]`` followed by the
        learning-rate schedulers of the generator and the discriminator
        optimizer.
    """
    def clip_to_unit_range(x):
        # Shared preprocess step of all visualisation callbacks.
        return np.clip(x, -1, 1)

    save_gan_cb = SaveGAN(rendergan,
                          join(output_dir, "models/{epoch:03d}/{name}.hdf5"),
                          every_epoch=10,
                          hdf5_attrs=get_distribution_hdf5_attrs(distribution))
    nb_score = 1000

    sample_fn = predict_wrapper(
        rendergan.sample_generator_given_z.predict,
        rendergan.sample_generator_given_z_output_names)

    real = next(train_data_generator(real_hdf5_fname, nb_score, 1))['data']

    vis_cb = VisualiseTag3dAndFake(
        nb_samples=nb_visualise // 2,
        output_dir=join(output_dir, 'visualise_tag3d_fake'),
        show=False,
        preprocess=clip_to_unit_range)
    nb_generator_outputs = len(rendergan.sample_generator_given_z.outputs)
    vis_all = VisualiseAll(
        nb_samples=nb_visualise // nb_generator_outputs,
        output_dir=join(output_dir, 'visualise_all'),
        show=False,
        preprocess=clip_to_unit_range)

    vis_fake_sorted = VisualiseFakesSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_fakes_sorted'),
        show=False,
        preprocess=clip_to_unit_range)

    vis_real_sorted = VisualiseRealsSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_reals_sorted'),
        show=False,
        preprocess=clip_to_unit_range)

    def default_lr_schedule(lr):
        return {
            200: lr / 4,
            250: lr / 4**2,
            300: lr / 4**3,
        }

    if lr_schedule is None:
        lr_schedule = default_lr_schedule

    def lr_scheduler(opt):
        # The schedule is derived from the optimizer's current learning rate.
        return LearningRateScheduler(opt,
                                     lr_schedule(float(K.get_value(opt.lr))))

    g_optimizer = rendergan.gan.g_optimizer
    d_optimizer = rendergan.gan.d_optimizer
    lr_schedulers = [
        lr_scheduler(g_optimizer),
        lr_scheduler(d_optimizer),
    ]
    hist_dir = join(output_dir, "history")
    os.makedirs(hist_dir, exist_ok=True)
    hist = HistoryPerBatch(hist_dir)

    # ``logs`` is unused; it defaults to None rather than a shared mutable
    # dict.
    def history_plot(e, logs=None):
        fig, _ = hist.plot(save_as="{:03d}.png".format(e),
                           metrics=['g_loss', 'd_loss'])
        plt.close(fig)  # allows fig to be garbage collected

    hist_save = OnEpochEnd(history_plot, every_nth_epoch=20)

    sample_outdir = join(output_dir, 'samples')
    os.makedirs(sample_outdir, exist_ok=True)
    store_samples_cb = StoreSamples(sample_outdir, distribution, overwrite)

    dscore_outdir = join(output_dir, 'd_score_hist')
    os.makedirs(dscore_outdir, exist_ok=True)
    dscore = DScoreHistogram(dscore_outdir)

    # Draw enough noise vectors for both scoring and visualisation.
    nb_sample = max(nb_score, nb_visualise)
    sample_cb = SampleGAN(sample_fn,
                          rendergan.discriminator.predict,
                          rendergan.gan.random_z(nb_sample),
                          real,
                          callbacks=[
                              vis_cb, vis_fake_sorted, vis_all,
                              vis_real_sorted, dscore, store_samples_cb
                          ])
    return [sample_cb, save_gan_cb, hist, hist_save] + lr_schedulers
示例#6
0
def train_callbacks(rendergan, output_dir, nb_visualise, real_hdf5_fname,
                    distribution, lr_schedule=None, overwrite=False):
    """Create the training callbacks for ``rendergan``.

    Args:
        rendergan: Provides ``gan`` (``g_optimizer``, ``d_optimizer``,
            ``random_z``), ``discriminator``, ``sample_generator_given_z``
            and ``sample_generator_given_z_output_names``.
        output_dir: Base directory for all callback output. ``history``,
            ``samples`` and ``d_score_hist`` are created below it by this
            function.
        nb_visualise: Sample count for the visualisation callbacks.
        real_hdf5_fname: Source file for the real samples, read through
            ``train_data_generator`` (first item, ``'data'`` entry).
        distribution: Forwarded to ``get_distribution_hdf5_attrs`` and
            ``StoreSamples``.
        lr_schedule: Optional callable mapping the current learning rate of
            an optimizer to the schedule given to
            ``LearningRateScheduler``. When ``None``, the schedule is
            ``{200: lr / 4, 250: lr / 16, 300: lr / 64}`` (keys presumably
            are epoch numbers).
        overwrite: Forwarded to ``StoreSamples``.

    Returns:
        list: The sampling, model-saving, history and history-plot
        callbacks, followed by one learning-rate scheduler per optimizer
        (generator first).
    """
    def clip(x):
        # One preprocess function shared by every visualisation callback.
        return np.clip(x, -1, 1)

    def default_lr_schedule(lr):
        return {
            200: lr / 4,
            250: lr / 4**2,
            300: lr / 4**3,
        }

    save_gan_cb = SaveGAN(rendergan, join(output_dir, "models/{epoch:03d}/{name}.hdf5"),
                          every_epoch=10, hdf5_attrs=get_distribution_hdf5_attrs(distribution))
    nb_score = 1000

    sample_fn = predict_wrapper(rendergan.sample_generator_given_z.predict,
                                rendergan.sample_generator_given_z_output_names)

    real = next(train_data_generator(real_hdf5_fname, nb_score, 1))['data']

    vis_cb = VisualiseTag3dAndFake(
        nb_samples=nb_visualise // 2,
        output_dir=join(output_dir, 'visualise_tag3d_fake'),
        show=False,
        preprocess=clip,
    )
    vis_all = VisualiseAll(
        nb_samples=nb_visualise // len(rendergan.sample_generator_given_z.outputs),
        output_dir=join(output_dir, 'visualise_all'),
        show=False,
        preprocess=clip,
    )
    vis_fake_sorted = VisualiseFakesSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_fakes_sorted'),
        show=False,
        preprocess=clip,
    )
    vis_real_sorted = VisualiseRealsSorted(
        nb_samples=nb_visualise,
        output_dir=join(output_dir, 'visualise_reals_sorted'),
        show=False,
        preprocess=clip,
    )

    if lr_schedule is None:
        lr_schedule = default_lr_schedule

    # One scheduler per optimizer, each derived from that optimizer's
    # current learning rate.
    optimizers = (rendergan.gan.g_optimizer, rendergan.gan.d_optimizer)
    lr_schedulers = [
        LearningRateScheduler(opt, lr_schedule(float(K.get_value(opt.lr))))
        for opt in optimizers
    ]

    hist_dir = join(output_dir, "history")
    os.makedirs(hist_dir, exist_ok=True)
    hist = HistoryPerBatch(hist_dir)

    # ``logs`` is unused; None replaces the shared mutable default dict.
    def history_plot(e, logs=None):
        fig, _ = hist.plot(save_as="{:03d}.png".format(e), metrics=['g_loss', 'd_loss'])
        plt.close(fig)  # allows fig to be garbage collected

    hist_save = OnEpochEnd(history_plot, every_nth_epoch=20)

    sample_outdir = join(output_dir, 'samples')
    os.makedirs(sample_outdir, exist_ok=True)
    store_samples_cb = StoreSamples(sample_outdir, distribution, overwrite)

    dscore_outdir = join(output_dir, 'd_score_hist')
    os.makedirs(dscore_outdir, exist_ok=True)
    dscore = DScoreHistogram(dscore_outdir)

    # Enough noise vectors for both the scoring and the visualisations.
    nb_sample = max(nb_score, nb_visualise)
    sample_cb = SampleGAN(sample_fn, rendergan.discriminator.predict,
                          rendergan.gan.random_z(nb_sample), real,
                          callbacks=[vis_cb, vis_fake_sorted, vis_all, vis_real_sorted,
                                     dscore, store_samples_cb])
    return [sample_cb, save_gan_cb, hist, hist_save] + lr_schedulers