Example #1
def run(g_weights_fname, d_weights_fname, selected_outputs, nb_samples,
        out_fname):
    generator = load_model(g_weights_fname, render_gan_custom_objects())
    discriminator = load_model(d_weights_fname, render_gan_custom_objects())
    generator._make_predict_function()
    discriminator._make_predict_function()
    dist_json = get_hdf5_attr(g_weights_fname, 'distribution').decode('utf-8')
    dist = diktya.distributions.load_from_json(dist_json)
    os.makedirs(os.path.dirname(out_fname), exist_ok=True)
    dset = DistributionHDF5Dataset(out_fname,
                                   mode='w',
                                   nb_samples=nb_samples,
                                   distribution=dist)
    batch_size = 100
    available_datasets = [
        name for name in generator.output_names if name != 'labels'
    ]
    print("Available outputs: " + ", ".join(available_datasets))
    generator_predict = predict_wrapper(
        lambda x: generator.predict(x, batch_size), generator.output_names)

    def sample_generator():
        z_shape = get_layer(generator.inputs[0]).batch_input_shape
        while True:
            z = np.random.uniform(-1, 1, (batch_size, ) + z_shape[1:])
            outs = generator_predict(z)
            raw_labels = outs.pop('labels')
            pos = 0
            labels = np.zeros(len(raw_labels), dtype=dist.norm_dtype)
            for name, size in dist.norm_nb_elems.items():
                labels[name] = raw_labels[:, pos:pos + size]
                pos += size
            deleted_keys = []
            if selected_outputs != 'all':
                for name in list(outs.keys()):
                    if name not in selected_outputs:
                        del outs[name]
                        deleted_keys.append(name)
            if not outs:
                raise Exception(
                    "Got no outputs. Removed {}. Selected outputs {}".format(
                        deleted_keys, selected_outputs))
            outs['labels'] = labels
            outs['discriminator'] = discriminator.predict(outs['fake'])
            yield outs

    bar = progressbar.ProgressBar(max_value=nb_samples)
    for batch in sample_generator():
        pos = dset.append(**batch)
        bar.update(pos)
        if pos >= nb_samples:
            break
    dset.close()
    print("Saved dataset with fakes and labels to: {}".format(out_fname))
Example #2
def run(g_weights_fname, d_weights_fname, selected_outputs, nb_samples, out_fname):
    generator = load_model(g_weights_fname, render_gan_custom_objects())
    discriminator = load_model(d_weights_fname, render_gan_custom_objects())
    generator._make_predict_function()
    discriminator._make_predict_function()
    dist_json = get_hdf5_attr(g_weights_fname, 'distribution').decode('utf-8')
    dist = diktya.distributions.load_from_json(dist_json)
    os.makedirs(os.path.dirname(out_fname), exist_ok=True)
    dset = DistributionHDF5Dataset(out_fname, mode='w', nb_samples=nb_samples,
                                   distribution=dist)
    batch_size = 100
    available_datasets = [name for name in generator.output_names
                          if name != 'labels']
    print("Available outputs: " + ", ".join(available_datasets))
    generator_predict = predict_wrapper(lambda x: generator.predict(x, batch_size),
                                        generator.output_names)

    def sample_generator():
        z_shape = get_layer(generator.inputs[0]).batch_input_shape
        while True:
            z = np.random.uniform(-1, 1, (batch_size, ) + z_shape[1:])
            outs = generator_predict(z)
            raw_labels = outs.pop('labels')
            pos = 0
            labels = np.zeros(len(raw_labels), dtype=dist.norm_dtype)
            for name, size in dist.norm_nb_elems.items():
                labels[name] = raw_labels[:, pos:pos+size]
                pos += size
            deleted_keys = []
            if selected_outputs != 'all':
                for name in list(outs.keys()):
                    if name not in selected_outputs:
                        del outs[name]
                        deleted_keys.append(name)
            if not outs:
                raise Exception("Got no outputs. Removed {}. Selected outputs {}"
                                .format(deleted_keys, selected_outputs))
            outs['labels'] = labels
            outs['discriminator'] = discriminator.predict(outs['fake'])
            yield outs

    bar = progressbar.ProgressBar(max_value=nb_samples)
    for batch in sample_generator():
        pos = dset.append(**batch)
        bar.update(pos)
        if pos >= nb_samples:
            break
    dset.close()
    print("Saved dataset with fakes and labels to: {}".format(out_fname))
Example #3
 def __init__(self, **config):
     self.model = load_model(config['model_path'])
     # We can't use model.compile because it requires an optimizer and a loss function.
     # Since we only use the model for inference, we call the private function
     # _make_predict_function(). This is exactly what Keras would otherwise do the
     # first time model.predict() is called.
     self.model._make_predict_function()
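The comment above reflects the usual motivation for this call: in TF1-era Keras the predict function was built lazily on the first predict() call, which could fail when that first call happened on a different thread than the one that loaded the model. A minimal sketch of the pattern, with a hypothetical Predictor class and model path:

import threading
from keras.models import load_model

class Predictor:
    def __init__(self, model_path):
        self.model = load_model(model_path)
        # Build the predict function eagerly, on the loading thread, so that
        # later predict() calls from worker threads find it already built.
        self.model._make_predict_function()

    def predict_async(self, batch):
        # predict() no longer has to construct graph ops inside this thread.
        t = threading.Thread(target=self.model.predict, args=(batch,))
        t.start()
        return t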
Example #4
 def __init__(self, model_path):
     self.model = load_model(model_path)
     self.uses_hist_equalization = get_hdf5_attr(
         model_path, 'decoder_uses_hist_equalization', True)
     self.distribution = DistributionCollection.from_hdf5(model_path)
     self._predict = predict_wrapper(self.model.predict, self.model.output_names)
     self.model._make_predict_function()
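predict_wrapper comes from the surrounding project (diktya); judging from its use here and in Example #1, it pairs a multi-output predict call with the model's output names. A plausible minimal version, offered as an assumption rather than the library's actual code:

def predict_wrapper(predict_fn, output_names):
    # Return model outputs as a dict keyed by output name
    # instead of a positional list.
    def predict(inputs):
        outputs = predict_fn(inputs)
        if not isinstance(outputs, list):
            outputs = [outputs]
        return dict(zip(output_names, outputs))
    return predict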
Example #5
 def __init__(self, **config):
     self.model = load_model(config['model_path'])
     # We can't use model.compile because it requires an optimizer and a loss function.
     # Since we only use the model for inference, we call the private function
     # _make_predict_function(). This is exactly what Keras would otherwise do the
     # first time model.predict() is called.
     self.model._make_predict_function()
Example #6
 def __init__(self, model_path):
     self.model = load_model(model_path)
     self.uses_hist_equalization = get_hdf5_attr(
         model_path, 'decoder_uses_hist_equalization', True)
     self.distribution = DistributionCollection.from_hdf5(model_path)
     self._predict = predict_wrapper(self.model.predict,
                                     self.model.output_names)
     self.model._make_predict_function()
Example #7
def load_tag3d_network(weight_fname):
    model = load_model(weight_fname)
    for layer in model.layers:
        layer.trainable = False
    return model
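Freezing every layer matters once the loaded network is embedded in a larger trainable model: after the enclosing model is compiled, the frozen weights are excluded from gradient updates. A small sanity check, with a hypothetical weight file:

tag3d = load_tag3d_network('tag3d_network.hdf5')  # hypothetical file name
# With every layer frozen, Keras reports no trainable weights.
assert len(tag3d.trainable_weights) == 0
assert len(tag3d.non_trainable_weights) > 0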
Example #8
 def __init__(self, model_path, threshold=0.6):
     self.saliency_threshold = float(threshold)
     self.model = load_model(model_path)
     self.model._make_predict_function()
Example #9
 def __init__(self, model_path, threshold=0.6):
     self.saliency_threshold = threshold
     self.model = load_model(model_path)
     self.model._make_predict_function()
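The threshold in Examples #8 and #9 suggests the model predicts a saliency map that is binarized afterwards. A hedged sketch of such a step, meant as a method of the class above; the method name, input shape, and map layout are assumptions:

import numpy as np

def salient_positions(self, image):
    # Hypothetical use: predict a per-pixel saliency map for a single image
    # and keep the coordinates whose score exceeds the threshold.
    saliency = self.model.predict(image[np.newaxis])[0]
    return np.argwhere(saliency > self.saliency_threshold)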
Example #10
def load_tag3d_network(weight_fname):
    model = load_model(weight_fname)
    for layer in model.layers:
        layer.trainable = False
    return model
Example #11
 def __init__(self, model_path=None):
     if model_path is None:
         model_path = os.path.join(os.path.dirname(__file__), '../model/tag_matcher.model')
     self.model = load_model(model_path)
     self.model._make_predict_function()
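The fallback path resolves relative to the directory of the module that defines the class. A small sketch showing where it points; __file__ stands in for that (hypothetical) module location:

import os

default = os.path.join(os.path.dirname(__file__), '../model/tag_matcher.model')
print(os.path.abspath(default))  # normalized absolute path of the default model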