def discriminator_fn_specgram(images, **kwargs):
    """Builds the discriminator network plus a classification head.

    NOTE(review): this module also defines another
    `discriminator_fn_specgram` later on (using `contrib_layers.flatten`);
    at import time the later definition replaces this one.

    Args:
      images: Tensor to be scored. It is normalized before being passed to
        the discriminator, and its original static shape is re-applied.
      **kwargs: Hyperparameter dict. Keys read directly here:
        'data_normalizer', 'progress', 'resolution_schedule', 'num_blocks',
        'kernel_size', 'simple_arch', 'num_tokens'. The whole dict is also
        handed to the normalizer constructor and to `_num_filters_fn`.

    Returns:
      A `(logits, end_points)` tuple as produced by `networks.discriminator`.
      `end_points` is extended in place with a 'classification_logits'
      entry computed from `end_points['last_conv']`.
    """
    static_shape = images.shape

    # Look up the normalizer class by name and build it from the full config.
    normalizer_cls = data_normalizer.registry[kwargs['data_normalizer']]
    normalizer = normalizer_cls(kwargs)
    normalized = normalizer.normalize_op(images)
    # Re-attach the pre-normalization static shape to the normalized tensor
    # (presumably normalize_op does not preserve it — TODO confirm).
    normalized.set_shape(static_shape)

    def num_filters(block_id):
        # Per-block filter count, forwarding every hyperparameter.
        return _num_filters_fn(block_id, **kwargs)

    logits, end_points = networks.discriminator(
        normalized,
        kwargs['progress'],
        num_filters,
        kwargs['resolution_schedule'],
        num_blocks=kwargs['num_blocks'],
        kernel_size=kwargs['kernel_size'],
        simple_arch=kwargs['simple_arch'],
    )

    # Extra head: dense layer over the flattened last conv activations,
    # one output unit per token.
    with tf.variable_scope('discriminator_cond'):
        last_conv_flat = tf.contrib.layers.flatten(end_points['last_conv'])
        end_points['classification_logits'] = layers.custom_dense(
            x=last_conv_flat,
            units=kwargs['num_tokens'],
            scope='classification_logits',
        )

    return logits, end_points
def discriminator_fn_specgram(images, **kwargs):
    """Builds the discriminator network plus a classification head.

    Args:
      images: Tensor to be scored. It is normalized before being passed to
        the discriminator, and its original static shape is re-applied.
      **kwargs: Hyperparameter dict. Keys read directly here:
        'data_normalizer', 'progress', 'resolution_schedule', 'num_blocks',
        'kernel_size', 'simple_arch', 'num_tokens'. The whole dict is also
        handed to the normalizer constructor and to `_num_filters_fn`.

    Returns:
      A `(logits, end_points)` tuple as produced by `networks.discriminator`.
      `end_points` is extended in place with a 'classification_logits'
      entry computed from `end_points['last_conv']`.
    """
    static_shape = images.shape

    # Look up the normalizer class by name and build it from the full config.
    normalizer_cls = data_normalizer.registry[kwargs['data_normalizer']]
    normalizer = normalizer_cls(kwargs)
    normalized = normalizer.normalize_op(images)
    # Re-attach the pre-normalization static shape to the normalized tensor
    # (presumably normalize_op does not preserve it — TODO confirm).
    normalized.set_shape(static_shape)

    def num_filters(block_id):
        # Per-block filter count, forwarding every hyperparameter.
        return _num_filters_fn(block_id, **kwargs)

    logits, end_points = networks.discriminator(
        normalized,
        kwargs['progress'],
        num_filters,
        kwargs['resolution_schedule'],
        num_blocks=kwargs['num_blocks'],
        kernel_size=kwargs['kernel_size'],
        simple_arch=kwargs['simple_arch'],
    )

    # Extra head: dense layer over the flattened last conv activations,
    # one output unit per token.
    with tf.variable_scope('discriminator_cond'):
        last_conv_flat = contrib_layers.flatten(end_points['last_conv'])
        end_points['classification_logits'] = layers.custom_dense(
            x=last_conv_flat,
            units=kwargs['num_tokens'],
            scope='classification_logits',
        )

    return logits, end_points