Example 1
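(Assumed context, not in the original snippets: all three examples target the TensorFlow 1.x API and rely on repo-local helpers such as the `resnet` module, `batch_norm_relu`, and `print_layer_shape`.)

# Imports assumed by the snippets below (hypothetical; the original page
# shows only the function bodies).
import tensorflow as tf
import resnet  # repo-local module: conv2d_fixed_padding, block_layer, batch_norm_relu
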
def linear_layer(x,
                 is_training,
                 num_classes,
                 use_bias=True,
                 use_bn=False,
                 name='linear_layer'):
    """Linear head for linear evaluation.

  Args:
    x: hidden state tensor of shape (bsz, dim).
    is_training: boolean indicator for training or test.
    num_classes: number of classes.
    use_bias: whether or not to use bias.
    use_bn: whether or not to use BN for output units.
    name: the name for variable scope.

  Returns:
    logits of shape (bsz, num_classes)
  """
    assert x.shape.ndims == 2, x.shape
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        # When batch norm follows, the dense bias is redundant: BN's center
        # (beta) term plays the same role, so the bias is dropped here.
        x = tf.layers.dense(
            inputs=x,
            units=num_classes,
            use_bias=use_bias and not use_bn,
            kernel_initializer=tf.random_normal_initializer(stddev=.01))
        if use_bn:
            x = resnet.batch_norm_relu(x,
                                       is_training,
                                       relu=False,
                                       center=use_bias)
        x = tf.identity(x, '%s_out' % name)
    return x
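
A minimal usage sketch for the head above (hypothetical: `features`, `labels`, and the cross-entropy wiring are assumptions, not taken from the original repo):

# Hypothetical linear-evaluation driver: train only the linear head on top
# of (typically frozen) encoder features.
logits = linear_layer(features, is_training=True, num_classes=10)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                   logits=logits))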
Example 2
    def __call__(self, images, training):
        """Add operations to classify a batch of input images.

    Args:
      inputs: A Tensor representing a batch of input images.
      training: A boolean. Set to True to add operations required only when
        training the classifier.

    Returns:
      A logits Tensor with shape [<batch_size>, self.num_classes].
    """

        if self.data_format == 'channels_first':
            # Convert the inputs from channels_last (NHWC) to channels_first (NCHW).
            # This provides a large performance boost on GPU. See
            # https://www.tensorflow.org/performance/performance_guide#data_formats
            images = tf.transpose(images, [0, 3, 1, 2])

        # Representation-learning backbone and classifier network.
        with tf.variable_scope('resnet'):
            inputs = resnet.conv2d_fixed_padding(inputs=images,
                                                 filters=self.num_filters,
                                                 kernel_size=self.kernel_size,
                                                 strides=self.conv_stride,
                                                 data_format=self.data_format)
            inputs = tf.identity(inputs, 'initial_conv')

            if self.first_pool_size:
                inputs = tf.layers.max_pooling2d(
                    inputs=inputs,
                    pool_size=self.first_pool_size,
                    strides=self.first_pool_stride,
                    padding='SAME',
                    data_format=self.data_format)
                inputs = tf.identity(inputs, 'initial_max_pool')

            for i, num_blocks in enumerate(self.block_sizes):
                num_filters = self.num_filters * (2**i)
                inputs = resnet.block_layer(inputs=inputs,
                                            filters=num_filters,
                                            block_fn=self.block_fn,
                                            blocks=num_blocks,
                                            strides=self.block_strides[i],
                                            training=training,
                                            name='block_layer{}'.format(i + 1),
                                            data_format=self.data_format)

            inputs = resnet.batch_norm_relu(inputs, training, self.data_format)
            inputs = tf.layers.average_pooling2d(
                inputs=inputs,
                pool_size=self.second_pool_size,
                strides=self.second_pool_stride,
                padding='VALID',
                data_format=self.data_format)
            inputs = tf.identity(inputs, 'final_avg_pool')
            inputs = tf.reshape(inputs, [-1, self.final_size])
            features = tf.identity(inputs, 'features')
            inputs = tf.layers.dense(inputs=inputs, units=self.num_classes)
            inputs = tf.identity(inputs, 'final_dense')
            scope = tf.get_variable_scope()
            scope.reuse_variables()

        variables_graph = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                            scope='resnet')

        # Distilling ("stored") network: a second copy of the same
        # architecture, built with training=False so batch norm always uses
        # its moving statistics.
        with tf.variable_scope('store_resnet'):
            dis_inputs = resnet.conv2d_fixed_padding(
                inputs=images,
                filters=self.num_filters,
                kernel_size=self.kernel_size,
                strides=self.conv_stride,
                data_format=self.data_format)
            dis_inputs = tf.identity(dis_inputs, 'initial_conv')

            if self.first_pool_size:
                dis_inputs = tf.layers.max_pooling2d(
                    inputs=dis_inputs,
                    pool_size=self.first_pool_size,
                    strides=self.first_pool_stride,
                    padding='SAME',
                    data_format=self.data_format)
                dis_inputs = tf.identity(dis_inputs, 'initial_max_pool')

            for i, num_blocks in enumerate(self.block_sizes):
                num_filters = self.num_filters * (2**i)
                dis_inputs = resnet.block_layer(inputs=dis_inputs,
                                                filters=num_filters,
                                                block_fn=self.block_fn,
                                                blocks=num_blocks,
                                                strides=self.block_strides[i],
                                                training=False,
                                                name='block_layer{}'.format(i + 1),
                                                data_format=self.data_format)

            dis_inputs = resnet.batch_norm_relu(dis_inputs, False,
                                                self.data_format)
            dis_inputs = tf.layers.average_pooling2d(
                inputs=dis_inputs,
                pool_size=self.second_pool_size,
                strides=self.second_pool_stride,
                padding='VALID',
                data_format=self.data_format)
            dis_inputs = tf.identity(dis_inputs, 'final_avg_pool')
            dis_inputs = tf.reshape(dis_inputs, [-1, self.final_size])
            dis_features = tf.identity(dis_inputs, 'features')
            dis_inputs = tf.layers.dense(inputs=dis_inputs,
                                         units=self.num_classes)
            dis_inputs = tf.identity(dis_inputs, 'final_dense')
            scope = tf.get_variable_scope()
            scope.reuse_variables()

        store_variables_graph = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope='store_resnet')

        # Note: store_variables_graph is collected but never returned; callers
        # can re-fetch it with tf.get_collection(..., scope='store_resnet').
        return features, inputs, variables_graph, dis_inputs
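
One plausible way to consume the returned values (a sketch; the teacher/student pairing and the weight-sync step are inferred from the two variable collections, not shown in the original):

# Hypothetical: distill the frozen 'store_resnet' logits into the live
# 'resnet' network, then sync the stored copy after a training phase.
# Both scopes build identical architectures, so the collections line up
# in creation order.
features, logits, live_vars, teacher_logits = model(images, training=True)
distill_loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(
        labels=tf.stop_gradient(tf.nn.softmax(teacher_logits)),
        logits=logits))
store_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                               scope='store_resnet')
sync_ops = [tf.assign(dst, src) for dst, src in zip(store_vars, live_vars)]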
Example 3
    def generate(self, Z, Y):
        # The generator is always built in training mode here, with NHWC
        # ('channels_last') tensors.
        is_training = True
        data_format = 'channels_last'
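        # Conditional-GAN conditioning: reshape the one-hot labels to
        # (batch, 1, 1, dim_y) so they can be tiled onto feature maps,
        # and append them to the noise vector Z as well.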
        yb = tf.reshape(Y, tf.stack([-1, 1, 1, self.dim_y]))
        Z = tf.concat(axis=1, values=[Z, Y])
        print_layer_shape(Z, 'gen z ')
        h1 = tf.layers.dense(Z, self.dim_W1, use_bias=False, name='gen_h1')
        # h1 = batch_norm_relu_1d(h1, is_training, data_format)
        h1 = tf.layers.batch_normalization(inputs=h1,
                                           axis=1,
                                           momentum=0.997,
                                           epsilon=1e-5,
                                           center=True,
                                           scale=True,
                                           training=is_training,
                                           fused=True)
        h1 = tf.nn.relu(h1)
        print_layer_shape(h1, 'gen h1 ')
        h1 = tf.concat(axis=1, values=[h1, Y])
        h2 = tf.layers.dense(h1,
                             self.dim_W2 * 3 * 1,
                             use_bias=False,
                             name='gen_h2')
        h2 = tf.layers.batch_normalization(inputs=h2,
                                           axis=1,
                                           momentum=0.997,
                                           epsilon=1e-5,
                                           center=True,
                                           scale=True,
                                           training=is_training,
                                           fused=True)
        h2 = tf.nn.relu(h2)
        # h2 = batch_norm_relu_1d(h2, is_training, data_format)
        # h2 = tf.nn.relu(batchnormalize(tf.matmul(h1, self.gen_W2)))
        h2 = tf.reshape(h2, tf.stack([-1, 3, 1, self.dim_W2]))
        pattern = tf.stack([1, 3, 1, 1])
        yb2 = tf.tile(yb, pattern)
        h2 = tf.concat(axis=3, values=[h2, yb2])
        print_layer_shape(h2, 'gen h2 ')
        # h2 = tf.concat(axis=3, values=[h2, yb*tf.ones([self.batch_size, 7, 7, self.dim_y])])

        # output_shape_l3 = [-1,14,14,self.dim_W3]
        # h3 = tf.nn.conv2d_transpose(h2, self.gen_W3, output_shape=output_shape_l3, strides=[1,2,2,1])
        # h3 = tf.nn.relu( batchnormalize(h3) )
        # h3 = tf.concat(axis=3, values=[h3, yb*tf.ones([self.batch_size, 14,14,self.dim_y])] )
        # Output spatial size after this stride-2 deconv: 6 x 2.
        h3 = tf.layers.conv2d_transpose(h2,
                                        self.dim_W3,
                                        kernel_size=5,
                                        strides=2,
                                        padding='same',
                                        name='gen_h3')
        h3 = batch_norm_relu(h3, is_training, data_format)
        # pattern = tf.stack([1, 49, 20, 1])
        # yb3 = tf.tile(yb, pattern)
        # h3 = tf.concat(axis=3, values=[h3, yb3] )
        h3 = self.concat_yb(h3, yb)
        print_layer_shape(h3, 'gen h3 ')
        # output_shape_l4 = [self.batch_size,28,28,self.dim_channel]
        # h4 = tf.nn.conv2d_transpose(h3, self.gen_W4, output_shape=output_shape_l4, strides=[1,2,2,1])
        # Output spatial size: 12 x 4.
        h4 = tf.layers.conv2d_transpose(h3,
                                        kernel_size=5,
                                        filters=128,
                                        strides=2,
                                        padding='same',
                                        name='gen_h4_1')
        h4 = batch_norm_relu(h4, is_training, data_format)
        # Pad width 4 -> 5: spatial size 12 x 5.
        h4 = tf.pad(h4, [[0, 0], [0, 0], [0, 1], [0, 0]])
        h4 = self.concat_yb(h4, yb)
        print_layer_shape(h4, 'gen h4 1 ')
        # Output spatial size: 24 x 10.
        h4 = tf.layers.conv2d_transpose(h4,
                                        kernel_size=5,
                                        filters=256,
                                        strides=2,
                                        padding='same',
                                        name='gen_h4_2')
        h4 = batch_norm_relu(h4, is_training, data_format)
        h4 = self.concat_yb(h4, yb)
        print_layer_shape(h4, 'gen h4 2 ')
        # Output spatial size: 48 x 20.
        h4 = tf.layers.conv2d_transpose(h4,
                                        kernel_size=5,
                                        filters=512,
                                        strides=2,
                                        padding='same',
                                        name='gen_h4_3')
        h4 = batch_norm_relu(h4, is_training, data_format)
        # Pad height 48 -> 49: spatial size 49 x 20.
        h4 = tf.pad(h4, [[0, 0], [1, 0], [0, 0], [0, 0]])
        h4 = self.concat_yb(h4, yb)
        print_layer_shape(h4, 'gen h4 3 ')

        h4 = tf.layers.conv2d_transpose(h4,
                                        kernel_size=5,
                                        filters=self.dim_channel,
                                        strides=2,
                                        padding='same',
                                        name='gen_h4')
        print_layer_shape(h4, 'gen h4')
        return h4
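
A minimal call sketch (hypothetical: `batch_size`, `dim_z`, `dim_y`, `model`, and the enclosing variable scope are assumptions, not part of the original snippet):

# Hypothetical driver: sample noise and random one-hot labels, then build
# the generator graph. Following the size annotations above, the final
# stride-2 deconv maps 49 x 20 up to (batch, 98, 40, dim_channel).
Z = tf.random_normal([batch_size, dim_z])
labels = tf.random_uniform([batch_size], maxval=dim_y, dtype=tf.int32)
Y = tf.one_hot(labels, depth=dim_y)
with tf.variable_scope('generator'):
    fake = model.generate(Z, Y)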