示例#1
0
    def __init__(self, generator: Model, discriminator: Model):
        """Wrap a compiled generator/discriminator pair and build the model.

        Both models must already have been compiled with ``model.compile(...)``.
        A copy of each model's optimizer is taken with ``keras_copy`` and
        stored as ``g_optimizer`` / ``d_optimizer``.

        Args:
            generator: compiled generator model. Its first input tensor is
                used as the latent input ``z``.
            discriminator: compiled discriminator model.

        Raises:
            AssertionError: if either model has no optimizer (missing or
                ``None``), i.e. it was presumably never compiled.
        """
        def _copied_optimizer(model, role):
            # Explicit check rather than ``assert`` so the validation is not
            # stripped under ``python -O``. A missing attribute and a
            # ``None`` optimizer are both treated as "not compiled", since
            # ``keras_copy(None)`` would only fail later and less clearly.
            optimizer = getattr(model, 'optimizer', None)
            if optimizer is None:
                raise AssertionError(
                    "The {} has no optimizer. Did you forget to call "
                    "model.compile(...)?".format(role))
            return keras_copy(optimizer)

        self.g = generator
        self.g_optimizer = _copied_optimizer(self.g, 'generator')

        self.d = discriminator
        self.d_optimizer = _copied_optimizer(self.d, 'discriminator')

        # The generator's first input tensor is the latent input z; its
        # shape is read from the layer that produces it.
        self.z = self.g.inputs[0]
        self.z_input_layer = get_layer(self.z)
        self.z_shape = self.z_input_layer.get_output_shape_at(0)
        self._build()
示例#2
0
def test_simple_gan_generator():
    """Run simple_gan_generator on random inputs and check output shapes."""
    nb_units = 2
    bs = 8
    z = Input(batch_shape=(bs, 50))
    labels = Input(batch_shape=(bs, 22))
    depth_map = Input(batch_shape=(bs, 1, 16, 16))
    tag3d = Input(batch_shape=(bs, 1, 64, 64))
    generated = simple_gan_generator(nb_units, z, labels, depth_map, tag3d,
                                     depth=2)
    blur, lights, background, details = generated
    light_sb, light_sw, light_t = lights

    inputs = [z, labels, depth_map, tag3d]
    fn = K.function([K.learning_phase()] + inputs,
                    [blur, light_sb, light_sw, light_t, background, details])

    # The value 1 is fed for K.learning_phase(), followed by one random
    # array per input, shaped like that input's layer output.
    feed = [1]
    for tensor in inputs:
        feed.append(np.random.sample(get_layer(tensor).output_shape))
    out = fn(feed)

    blur_value = out[0]
    map_values = out[1:]
    assert blur_value.shape == (bs, 1)
    for idx, arr in enumerate(map_values, start=1):
        assert arr.shape == (bs, 1, 64, 64), idx
示例#3
0
def test_simple_gan_generator():
    """Check the output shapes of a depth-2 simple_gan_generator."""
    nb_units, bs = 2, 8
    z = Input(batch_shape=(bs, 50))
    labels = Input(batch_shape=(bs, 22))
    depth_map = Input(batch_shape=(bs, 1, 16, 16))
    tag3d = Input(batch_shape=(bs, 1, 64, 64))
    result = simple_gan_generator(nb_units, z, labels, depth_map, tag3d,
                                  depth=2)
    blur, (light_sb, light_sw, light_t), background, details = result

    inputs = [z, labels, depth_map, tag3d]
    fetches = [blur, light_sb, light_sw, light_t, background, details]
    fn = K.function([K.learning_phase()] + inputs, fetches)

    random_feed = [np.random.sample(get_layer(t).output_shape)
                   for t in inputs]
    out = fn([1] + random_feed)

    assert out[0].shape == (bs, 1)
    expected_map_shape = (bs, 1, 64, 64)
    for position, value in enumerate(out[1:], 1):
        assert value.shape == expected_map_shape, position
 def sample_generator():
     """Yield an endless stream of batches sampled from the generator.

     Each batch is the dict returned by ``generator_predict(z)`` for uniform
     noise ``z`` in [-1, 1). It is restricted to ``selected_outputs`` unless
     that is ``'all'``, and two keys are always added:

     * ``'labels'``: the flat label matrix split column-wise into the fields
       of the structured dtype ``dist.norm_dtype``;
     * ``'discriminator'``: ``discriminator.predict`` on the ``'fake'``
       output, which is used even when ``'fake'`` is not selected.

     Raises:
         KeyError: if the generator output lacks ``'labels'`` or ``'fake'``.
         ValueError: if filtering by ``selected_outputs`` removes every
             output.
     """
     z_shape = get_layer(generator.inputs[0]).batch_input_shape
     while True:
         z = np.random.uniform(-1, 1, (batch_size, ) + z_shape[1:])
         outs = generator_predict(z)
         raw_labels = outs.pop('labels')
         # Take the fake samples before filtering. The discriminator always
         # needs them, even when 'fake' itself is not a selected output.
         fake = outs['fake']

         # Split the flat label matrix into the named fields of the
         # structured label dtype, in dist.norm_nb_elems order.
         labels = np.zeros(len(raw_labels), dtype=dist.norm_dtype)
         pos = 0
         for name, size in dist.norm_nb_elems.items():
             labels[name] = raw_labels[:, pos:pos+size]
             pos += size

         deleted_keys = []
         if selected_outputs != 'all':
             for name in list(outs):
                 if name not in selected_outputs:
                     del outs[name]
                     deleted_keys.append(name)
         if not outs:
             raise ValueError("Got no outputs. Removed {}. Selected outputs {}"
                              .format(deleted_keys, selected_outputs))
         outs['labels'] = labels
         outs['discriminator'] = discriminator.predict(fake)
         yield outs
示例#5
0
 def sample_generator():
     """Yield an endless stream of batches sampled from the generator.

     Each batch is the dict returned by ``generator_predict(z)`` for uniform
     noise ``z`` in [-1, 1). It is restricted to ``selected_outputs`` unless
     that is ``'all'``, and two keys are always added:

     * ``'labels'``: the flat label matrix split column-wise into the fields
       of the structured dtype ``dist.norm_dtype``;
     * ``'discriminator'``: ``discriminator.predict`` on the ``'fake'``
       output, which is used even when ``'fake'`` is not selected.

     Raises:
         KeyError: if the generator output lacks ``'labels'`` or ``'fake'``.
         ValueError: if filtering by ``selected_outputs`` removes every
             output.
     """
     z_shape = get_layer(generator.inputs[0]).batch_input_shape
     while True:
         z = np.random.uniform(-1, 1, (batch_size, ) + z_shape[1:])
         outs = generator_predict(z)
         raw_labels = outs.pop('labels')
         # Take the fake samples before filtering. The discriminator always
         # needs them, even when 'fake' itself is not a selected output.
         fake = outs['fake']

         # Split the flat label matrix into the named fields of the
         # structured label dtype, in dist.norm_nb_elems order.
         labels = np.zeros(len(raw_labels), dtype=dist.norm_dtype)
         pos = 0
         for name, size in dist.norm_nb_elems.items():
             labels[name] = raw_labels[:, pos:pos + size]
             pos += size

         deleted_keys = []
         if selected_outputs != 'all':
             for name in list(outs):
                 if name not in selected_outputs:
                     del outs[name]
                     deleted_keys.append(name)
         if not outs:
             raise ValueError(
                 "Got no outputs. Removed {}. Selected outputs {}".format(
                     deleted_keys, selected_outputs))
         outs['labels'] = labels
         outs['discriminator'] = discriminator.predict(fake)
         yield outs