Exemplo n.º 1
0
def block(x, labels, out_channels, num_classes, name, training=True):
    """Build one residual up-sampling block of the generator.

    Main branch:  cbn_0 -> ReLU -> usample -> snconv1 -> cbn_1 -> ReLU -> snconv2.
    Skip branch:  usample -> snconv3.
    The two branches are summed to form the block output.

    Args:
        x: The 4D input tensor.
        labels: Class labels used to condition both batch-norm layers.
        out_channels: Integer number of output feature channels for every
            convolution in the block.
        num_classes: Integer number of classes the labels can take.
        name: Variable scope under which all of the block's variables live.
        training: Forwarded unchanged to every `ops.snconv2d` call.

    Returns:
        A `Tensor` equal to (skip branch) + (main branch).
    """
    with tf.variable_scope(name):
        # Both conditional batch-norm layers are constructed before either is
        # applied, and every op below is created in a fixed order.
        first_norm = ops.ConditionalBatchNorm(num_classes, name='cbn_0')
        second_norm = ops.ConditionalBatchNorm(num_classes, name='cbn_1')

        # Main branch.
        # NOTE(review): the positional `3, 3, 1, 1` / `1, 1, 1, 1` arguments
        # are presumably kernel height/width then strides — confirm against
        # ops.snconv2d.
        hidden = first_norm(x, labels)
        hidden = tf.nn.relu(hidden)
        hidden = usample(hidden)
        hidden = ops.snconv2d(hidden, out_channels, 3, 3, 1, 1, training, 'snconv1')
        hidden = second_norm(hidden, labels)
        hidden = tf.nn.relu(hidden)
        hidden = ops.snconv2d(hidden, out_channels, 3, 3, 1, 1, training, 'snconv2')

        # Skip branch: bring the untouched input to the same spatial size and
        # channel count so it can be added to the main branch.
        skip = usample(x)
        skip = ops.snconv2d(skip, out_channels, 1, 1, 1, 1, training, 'snconv3')

        return skip + hidden
Exemplo n.º 2
0
    def test_conditional_batch_norm_shapes(self):
        """Check that ConditionalBatchNorm keeps the input tensor's shape.

        Minimal smoke test: only the static output shape is compared; the
        normalised values themselves are not inspected.
        """
        batch_size = 10
        expected_shape = [batch_size, 32, 32, 3]

        cond_norm = ops.ConditionalBatchNorm(num_categories=1000)
        # One integer class id per example; every example uses class 1.
        class_ids = tf.ones([batch_size], dtype=tf.int32)
        images = tf.random.normal(expected_shape)

        normalized = cond_norm(images, class_ids)
        self.assertEqual(expected_shape, normalized.shape.as_list())