def __init__(self, model_path):
    """Load the decoder model stored at *model_path* together with its metadata.

    Args:
        model_path: Path of the saved model file. The same file is also read
            for the ``decoder_uses_hist_equalization`` attribute and for the
            label distribution.
    """
    model = load_model(model_path)
    self.model = model
    # NOTE(review): the trailing ``True`` is presumably the fallback used when
    # the attribute is missing from the file -- confirm against get_hdf5_attr.
    self.uses_hist_equalization = get_hdf5_attr(
        model_path, 'decoder_uses_hist_equalization', True)
    self.distribution = DistributionCollection.from_hdf5(model_path)
    # Wrap the raw predict call; presumably predict_wrapper maps the model's
    # outputs onto ``output_names`` -- verify against its definition.
    output_names = model.output_names
    self._predict = predict_wrapper(model.predict, output_names)
    # Keras-private call that builds the predict function up front.
    model._make_predict_function()
def run(g_weights_fname, d_weights_fname, selected_outputs, nb_samples,
        out_fname):
    """Write a dataset of generated fakes, labels and discriminator scores.

    Loads the generator and discriminator, repeatedly feeds uniform noise in
    ``[-1, 1)`` through the generator in batches of 100, and appends each
    batch to a ``DistributionHDF5Dataset`` at *out_fname*. Sampling stops as
    soon as the position reported by ``dset.append`` reaches *nb_samples*.

    Args:
        g_weights_fname: Path of the generator model file. Its
            ``distribution`` attribute holds the JSON label distribution.
        d_weights_fname: Path of the discriminator model file.
        selected_outputs: ``'all'`` to keep every generator output, otherwise
            a collection of output names to keep.
        nb_samples: Number of samples to write.
        out_fname: Path of the HDF5 file to create (opened with ``mode='w'``).

    Raises:
        Exception: If *selected_outputs* removes every generator output.
    """
    generator = load_model(g_weights_fname, render_gan_custom_objects())
    discriminator = load_model(d_weights_fname, render_gan_custom_objects())
    generator._make_predict_function()
    discriminator._make_predict_function()

    dist_json = get_hdf5_attr(g_weights_fname, 'distribution')
    # h5py >= 3 may return string attributes as str rather than bytes.
    if isinstance(dist_json, bytes):
        dist_json = dist_json.decode('utf-8')
    dist = diktya.distributions.load_from_json(dist_json)

    # os.path.dirname() is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(out_fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    dset = DistributionHDF5Dataset(out_fname, mode='w', nb_samples=nb_samples,
                                   distribution=dist)
    batch_size = 100
    available_datasets = [name for name in generator.output_names
                          if name != 'labels']
    print("Available outputs: " + ", ".join(available_datasets))
    generator_predict = predict_wrapper(
        lambda x: generator.predict(x, batch_size), generator.output_names)

    def sample_generator():
        """Yield one dict of output arrays per generated batch, forever."""
        z_shape = get_layer(generator.inputs[0]).batch_input_shape
        while True:
            z = np.random.uniform(-1, 1, (batch_size, ) + z_shape[1:])
            outs = generator_predict(z)
            raw_labels = outs.pop('labels')
            # Split the label matrix column-wise into the structured dtype,
            # one field per entry of dist.norm_nb_elems, in iteration order.
            labels = np.zeros(len(raw_labels), dtype=dist.norm_dtype)
            col = 0
            for name, size in dist.norm_nb_elems.items():
                labels[name] = raw_labels[:, col:col + size]
                col += size
            # Score the fakes before filtering; otherwise a selection that
            # omits 'fake' would fail with a KeyError below.
            discriminator_scores = discriminator.predict(outs['fake'])
            if selected_outputs != 'all':
                deleted_keys = [name for name in outs
                                if name not in selected_outputs]
                for name in deleted_keys:
                    del outs[name]
                if not outs:
                    raise Exception(
                        "Got no outputs. Removed {}. Selected outputs {}".format(
                            deleted_keys, selected_outputs))
            outs['labels'] = labels
            outs['discriminator'] = discriminator_scores
            yield outs

    bar = progressbar.ProgressBar(max_value=nb_samples)
    try:
        for batch in sample_generator():
            pos = dset.append(**batch)
            # Clamp: some progressbar versions reject values above max_value.
            bar.update(min(pos, nb_samples))
            if pos >= nb_samples:
                break
    finally:
        # Close the file even if sampling or appending fails.
        dset.close()
    bar.finish()
    print("Saved dataset with fakes and labels to: {}".format(out_fname))
def run(g_weights_fname, d_weights_fname, selected_outputs, nb_samples,
        out_fname):
    """Write a dataset of generated fakes, labels and discriminator scores.

    Loads the generator and discriminator, repeatedly feeds uniform noise in
    ``[-1, 1)`` through the generator in batches of 100, and appends each
    batch to a ``DistributionHDF5Dataset`` at *out_fname*. Sampling stops as
    soon as the position reported by ``dset.append`` reaches *nb_samples*.

    Args:
        g_weights_fname: Path of the generator model file. Its
            ``distribution`` attribute holds the JSON label distribution.
        d_weights_fname: Path of the discriminator model file.
        selected_outputs: ``'all'`` to keep every generator output, otherwise
            a collection of output names to keep.
        nb_samples: Number of samples to write.
        out_fname: Path of the HDF5 file to create (opened with ``mode='w'``).

    Raises:
        Exception: If *selected_outputs* removes every generator output.
    """
    generator = load_model(g_weights_fname, render_gan_custom_objects())
    discriminator = load_model(d_weights_fname, render_gan_custom_objects())
    generator._make_predict_function()
    discriminator._make_predict_function()

    dist_json = get_hdf5_attr(g_weights_fname, 'distribution')
    # h5py >= 3 may return string attributes as str rather than bytes.
    if isinstance(dist_json, bytes):
        dist_json = dist_json.decode('utf-8')
    dist = diktya.distributions.load_from_json(dist_json)

    # os.path.dirname() is '' for a bare file name, and os.makedirs('') raises.
    out_dir = os.path.dirname(out_fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    dset = DistributionHDF5Dataset(out_fname, mode='w', nb_samples=nb_samples,
                                   distribution=dist)
    batch_size = 100
    available_datasets = [name for name in generator.output_names
                          if name != 'labels']
    print("Available outputs: " + ", ".join(available_datasets))
    generator_predict = predict_wrapper(
        lambda x: generator.predict(x, batch_size), generator.output_names)

    def sample_generator():
        """Yield one dict of output arrays per generated batch, forever."""
        z_shape = get_layer(generator.inputs[0]).batch_input_shape
        while True:
            z = np.random.uniform(-1, 1, (batch_size, ) + z_shape[1:])
            outs = generator_predict(z)
            raw_labels = outs.pop('labels')
            # Split the label matrix column-wise into the structured dtype,
            # one field per entry of dist.norm_nb_elems, in iteration order.
            labels = np.zeros(len(raw_labels), dtype=dist.norm_dtype)
            col = 0
            for name, size in dist.norm_nb_elems.items():
                labels[name] = raw_labels[:, col:col + size]
                col += size
            # Score the fakes before filtering; otherwise a selection that
            # omits 'fake' would fail with a KeyError below.
            discriminator_scores = discriminator.predict(outs['fake'])
            if selected_outputs != 'all':
                deleted_keys = [name for name in outs
                                if name not in selected_outputs]
                for name in deleted_keys:
                    del outs[name]
                if not outs:
                    raise Exception(
                        "Got no outputs. Removed {}. Selected outputs {}".format(
                            deleted_keys, selected_outputs))
            outs['labels'] = labels
            outs['discriminator'] = discriminator_scores
            yield outs

    bar = progressbar.ProgressBar(max_value=nb_samples)
    try:
        for batch in sample_generator():
            pos = dset.append(**batch)
            # Clamp: some progressbar versions reject values above max_value.
            bar.update(min(pos, nb_samples))
            if pos >= nb_samples:
                break
    finally:
        # Close the file even if sampling or appending fails.
        dset.close()
    bar.finish()
    print("Saved dataset with fakes and labels to: {}".format(out_fname))
def get_distribution(hdf5_fname):
    """Load the distribution serialized in an HDF5 file's ``distribution`` attribute.

    Args:
        hdf5_fname: Path of the HDF5 file. Its ``distribution`` attribute
            holds the JSON produced for ``diktya.distributions``.

    Returns:
        The object built by ``diktya.distributions.load_from_json``.
    """
    dist_json = get_hdf5_attr(hdf5_fname, "distribution")
    # h5py >= 3 may return string attributes as str rather than bytes;
    # accept both (numpy.bytes_ is a bytes subclass, so it is decoded too).
    if isinstance(dist_json, bytes):
        dist_json = dist_json.decode("utf-8")
    return diktya.distributions.load_from_json(dist_json)
def from_hdf5(cls, fname):
    """Build an instance from the JSON config in *fname*'s ``distribution`` attribute.

    Args:
        fname: Path of the HDF5 file to read.

    Returns:
        The result of ``cls.from_config`` applied to the parsed JSON config.
    """
    dist_json = get_hdf5_attr(fname, 'distribution')
    # h5py >= 3 may return string attributes as str rather than bytes;
    # accept both (numpy.bytes_ is a bytes subclass, so it is decoded too).
    if isinstance(dist_json, bytes):
        dist_json = dist_json.decode('utf-8')
    dist_config = json.loads(dist_json)
    return cls.from_config(dist_config)