Example #1
def Layernorm(name, norm_axes, inputs):
    mean, var = tf.nn.moments(inputs, norm_axes, keep_dims=True)

    # Assume the 'neurons' axis is the first of norm_axes. This is the case for fully-connected and BCHW conv layers.
    n_neurons = inputs.get_shape().as_list()[norm_axes[0]]
    offset = lib.param(name + '.offset', np.zeros(n_neurons, dtype='float32'))
    scale = lib.param(name + '.scale', np.ones(n_neurons, dtype='float32'))

    # Add broadcasting dims to offset and scale (e.g. BCHW conv data)
    offset = tf.reshape(offset, [-1] + [1 for i in range(len(norm_axes) - 1)])
    scale = tf.reshape(scale, [-1] + [1 for i in range(len(norm_axes) - 1)])

    result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5)
    return result
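
A minimal usage sketch (not from the source): it assumes the usual repository imports (import numpy as np, import tensorflow as tf, and the tflib-style helper module as lib, which provides lib.param), plus a BCHW feature map so that norm_axes = [1, 2, 3] puts the channel axis first, as the comment in the function requires.

# Hypothetical call; `feature_maps` and the layer name are made up for illustration.
feature_maps = tf.random_normal([8, 64, 32, 32])  # BCHW activations
normed = Layernorm('Generator.LN1', [1, 2, 3], feature_maps)  # one offset/scale per channel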
Example #2
def Batchnorm(name,
              axes,
              inputs,
              is_training=None,
              stats_iter=None,
              update_moving_stats=True,
              fused=True,
              labels=None,
              n_labels=None):
    """conditional batchnorm (dumoulin et al 2016) for BCHW conv filtermaps"""
    if axes != [0, 2, 3]:
        raise Exception('unsupported')
    mean, var = tf.nn.moments(inputs, axes, keep_dims=True)
    shape = mean.get_shape().as_list()  # shape is [1,n,1,1]
    offset_m = lib.param(name + '.offset',
                         np.zeros([n_labels, shape[1]], dtype='float32'))
    scale_m = lib.param(name + '.scale',
                        np.ones([n_labels, shape[1]], dtype='float32'))
    offset = tf.nn.embedding_lookup(offset_m, labels)
    scale = tf.nn.embedding_lookup(scale_m, labels)
    result = tf.nn.batch_normalization(inputs, mean, var,
                                       offset[:, :, None, None],
                                       scale[:, :, None, None], 1e-5)
    return result
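
A hedged sketch of how this conditional batchnorm might be called (assumed, not shown in the source): labels are integer class ids in [0, n_labels), and tf.nn.embedding_lookup then selects a per-class offset/scale row whose width equals the number of channels.

# Hypothetical call; variable names are illustrative only.
feature_maps = tf.random_normal([8, 128, 16, 16])     # BCHW activations
class_labels = tf.constant([0, 3, 1, 2, 0, 1, 3, 2])  # one class id per batch item
out = Batchnorm('Generator.CBN1', [0, 2, 3], feature_maps,
                labels=class_labels, n_labels=4)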
Example #3
def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True, fused=True):
    if ((axes == [0,2,3]) or (axes == [0,2])) and fused==True:
        if axes==[0,2]:
            inputs = tf.expand_dims(inputs, 3)
        # Old (working but pretty slow) implementation:
        ##########

        # inputs = tf.transpose(inputs, [0,2,3,1])

        # mean, var = tf.nn.moments(inputs, [0,1,2], keep_dims=False)
        # offset = lib.param(name+'.offset', np.zeros(mean.get_shape()[-1], dtype='float32'))
        # scale = lib.param(name+'.scale', np.ones(var.get_shape()[-1], dtype='float32'))
        # result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-4)

        # return tf.transpose(result, [0,3,1,2])

        # New (super fast but untested) implementation:
        offset = lib.param(name+'.offset', np.zeros(inputs.get_shape()[1], dtype='float32'))
        scale = lib.param(name+'.scale', np.ones(inputs.get_shape()[1], dtype='float32'))

        moving_mean = lib.param(name+'.moving_mean', np.zeros(inputs.get_shape()[1], dtype='float32'), trainable=False)
        moving_variance = lib.param(name+'.moving_variance', np.ones(inputs.get_shape()[1], dtype='float32'), trainable=False)

        def _fused_batch_norm_training():
            return tf.nn.fused_batch_norm(inputs, scale, offset, epsilon=1e-5, data_format='NCHW')
        def _fused_batch_norm_inference():
            # Version which blends in the current item's statistics
            batch_size = tf.cast(tf.shape(inputs)[0], 'float32')
            mean, var = tf.nn.moments(inputs, [2,3], keep_dims=True)
            mean = ((1./batch_size)*mean) + (((batch_size-1.)/batch_size)*moving_mean)[None,:,None,None]
            var = ((1./batch_size)*var) + (((batch_size-1.)/batch_size)*moving_variance)[None,:,None,None]
            return tf.nn.batch_normalization(inputs, mean, var, offset[None,:,None,None], scale[None,:,None,None], 1e-5), mean, var

            # Standard version
            # return tf.nn.fused_batch_norm(
            #     inputs,
            #     scale,
            #     offset,
            #     epsilon=1e-2, 
            #     mean=moving_mean,
            #     variance=moving_variance,
            #     is_training=False,
            #     data_format='NCHW'
            # )

        if is_training is None:
            outputs, batch_mean, batch_var = _fused_batch_norm_training()
        else:
            outputs, batch_mean, batch_var = tf.cond(is_training,
                                                       _fused_batch_norm_training,
                                                       _fused_batch_norm_inference)
            if update_moving_stats:
                no_updates = lambda: outputs
                def _force_updates():
                    """Internal function forces updates moving_vars if is_training."""
                    float_stats_iter = tf.cast(stats_iter, tf.float32)

                    update_moving_mean = tf.assign(moving_mean, ((float_stats_iter/(float_stats_iter+1))*moving_mean) + ((1/(float_stats_iter+1))*batch_mean))
                    update_moving_variance = tf.assign(moving_variance, ((float_stats_iter/(float_stats_iter+1))*moving_variance) + ((1/(float_stats_iter+1))*batch_var))

                    with tf.control_dependencies([update_moving_mean, update_moving_variance]):
                        return tf.identity(outputs)
                outputs = tf.cond(is_training, _force_updates, no_updates)

        if axes == [0,2]:
            return outputs[:,:,:,0] # collapse last dim
        else:
            return outputs
    else:
        # raise Exception('old BN')
        # TODO we can probably use nn.fused_batch_norm here too for speedup
        mean, var = tf.nn.moments(inputs, axes, keep_dims=True)
        shape = mean.get_shape().as_list()
        if 0 not in axes:
            print("WARNING ({}): didn't find 0 in axes, but not using separate BN params for each item in batch".format(name))
            shape[0] = 1
        offset = lib.param(name+'.offset', np.zeros(shape, dtype='float32'))
        scale = lib.param(name+'.scale', np.ones(shape, dtype='float32'))
        result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5)

        return result
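
One possible way to wire up the training/inference switch (an assumption based on the signature, not code from the source): is_training is a boolean placeholder and stats_iter counts how many batches have already been folded into the moving averages, which yields the 1/(t+1) running-average update used in _force_updates.

# Hypothetical setup in TF 1.x style, matching the placeholder-era code above.
is_training = tf.placeholder(tf.bool, name='is_training')
stats_iter = tf.Variable(0, dtype=tf.int32, trainable=False, name='stats_iter')
feature_maps = tf.random_normal([8, 128, 16, 16])
out = Batchnorm('Discriminator.BN1', [0, 2, 3], feature_maps,
                is_training=is_training, stats_iter=stats_iter)
inc_stats_iter = tf.assign_add(stats_iter, 1)  # run once per training batch alongside `out`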
Example #4
def Linear(name,
           input_dim,
           output_dim,
           inputs,
           biases=True,
           initialization=None,
           weightnorm=None,
           gain=1.):
    """
    initialization: one of None, 'lecun', 'glorot', 'he', 'glorot_he', 'orthogonal', ('uniform', range)
    """
    with tf.name_scope(name) as scope:

        def uniform(stdev, size):
            if _weights_stdev is not None:
                stdev = _weights_stdev
            return np.random.uniform(low=-stdev * np.sqrt(3),
                                     high=stdev * np.sqrt(3),
                                     size=size).astype('float32')

        if initialization == 'lecun':  # and input_dim != output_dim):
            # disabling orth. init for now because it's too slow
            weight_values = uniform(np.sqrt(1. / input_dim),
                                    (input_dim, output_dim))

        elif initialization == 'glorot' or (initialization == None):

            weight_values = uniform(np.sqrt(2. / (input_dim + output_dim)),
                                    (input_dim, output_dim))

        elif initialization == 'he':

            weight_values = uniform(np.sqrt(2. / input_dim),
                                    (input_dim, output_dim))

        elif initialization == 'glorot_he':

            weight_values = uniform(np.sqrt(4. / (input_dim + output_dim)),
                                    (input_dim, output_dim))

        elif initialization == 'orthogonal' or \
            (initialization == None and input_dim == output_dim):

            # From lasagne
            def sample(shape):
                if len(shape) < 2:
                    raise RuntimeError("Only shapes of length 2 or more are "
                                       "supported.")
                flat_shape = (shape[0], np.prod(shape[1:]))
                # TODO: why normal and not uniform?
                a = np.random.normal(0.0, 1.0, flat_shape)
                u, _, v = np.linalg.svd(a, full_matrices=False)
                # pick the one with the correct shape
                q = u if u.shape == flat_shape else v
                q = q.reshape(shape)
                return q.astype('float32')

            weight_values = sample((input_dim, output_dim))

        elif initialization[0] == 'uniform':

            weight_values = np.random.uniform(
                low=-initialization[1],
                high=initialization[1],
                size=(input_dim, output_dim)).astype('float32')

        else:

            raise Exception('Invalid initialization!')

        weight_values *= gain

        weight = lib.param(name + '.W', weight_values)

        if weightnorm == None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(np.sum(np.square(weight_values), axis=0))
            # norm_values = np.linalg.norm(weight_values, axis=0)

            target_norms = lib.param(name + '.g', norm_values)

            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(
                    tf.reduce_sum(tf.square(weight), reduction_indices=[0]))
                weight = weight * (target_norms / norms)

        # if 'Discriminator' in name:
        #     print "WARNING weight constraint on {}".format(name)
        #     weight = tf.nn.softsign(10.*weight)*.1

        if inputs.get_shape().ndims == 2:
            result = tf.matmul(inputs, weight)
        else:
            reshaped_inputs = tf.reshape(inputs, [-1, input_dim])
            result = tf.matmul(reshaped_inputs, weight)
            result = tf.reshape(
                result,
                # tf.pack/tf.unpack were renamed tf.stack/tf.unstack in TF 1.0
                tf.stack(tf.unstack(tf.shape(inputs))[:-1] + [output_dim]))

        if biases:
            result = tf.nn.bias_add(
                result,
                lib.param(name + '.b', np.zeros((output_dim, ),
                                                dtype='float32')))

        return result
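
A short usage sketch (assumed): a 2-D input of shape [batch, input_dim] goes straight through tf.matmul, while higher-rank inputs are flattened to [-1, input_dim] and reshaped back; the layer name and sizes below are made up.

# Hypothetical call with He init and weight normalization enabled.
noise = tf.random_normal([64, 128])
hidden = Linear('Generator.Input', 128, 4 * 4 * 512, noise,
                initialization='he', weightnorm=True)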
Example #5
def Conv2D(name,
           input_dim,
           output_dim,
           filter_size,
           inputs,
           he_init=True,
           mask_type=None,
           stride=1,
           weightnorm=None,
           biases=True,
           gain=1.):
    """
    inputs: tensor of shape (batch size, num channels, height, width)
    mask_type: one of None, 'a', 'b'

    returns: tensor of shape (batch size, num channels, height, width)
    """
    with tf.name_scope(name) as scope:

        if mask_type is not None:
            mask_type, mask_n_channels = mask_type

            mask = np.ones((filter_size, filter_size, input_dim, output_dim),
                           dtype='float32')
            center = filter_size // 2

            # Mask out future locations
            # filter shape is (height, width, input channels, output channels)
            mask[center + 1:, :, :, :] = 0.
            mask[center, center + 1:, :, :] = 0.

            # Mask out future channels
            for i in range(mask_n_channels):
                for j in range(mask_n_channels):
                    if (mask_type == 'a' and i >= j) or (mask_type == 'b'
                                                         and i > j):
                        mask[center, center, i::mask_n_channels,
                             j::mask_n_channels] = 0.

        def uniform(stdev, size):
            return np.random.uniform(low=-stdev * np.sqrt(3),
                                     high=stdev * np.sqrt(3),
                                     size=size).astype('float32')

        fan_in = input_dim * filter_size**2
        fan_out = output_dim * filter_size**2 / (stride**2)

        if mask_type is not None:  # only approximately correct
            fan_in /= 2.
            fan_out /= 2.

        if he_init:
            filters_stdev = np.sqrt(4. / (fan_in + fan_out))
        else:  # Normalized init (Glorot & Bengio)
            filters_stdev = np.sqrt(2. / (fan_in + fan_out))

        if _weights_stdev is not None:
            filter_values = uniform(
                _weights_stdev,
                (filter_size, filter_size, input_dim, output_dim))
        else:
            filter_values = uniform(
                filters_stdev,
                (filter_size, filter_size, input_dim, output_dim))

        # print "WARNING IGNORING GAIN"
        filter_values *= gain

        filters = lib.param(name + '.Filters', filter_values)

        if weightnorm == None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(
                np.sum(np.square(filter_values), axis=(0, 1, 2)))
            target_norms = lib.param(name + '.g', norm_values)
            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(
                    tf.reduce_sum(tf.square(filters),
                                  reduction_indices=[0, 1, 2]))
                filters = filters * (target_norms / norms)

        if mask_type is not None:
            with tf.name_scope('filter_mask'):
                filters = filters * mask

        result = tf.nn.conv2d(input=inputs,
                              filter=filters,
                              strides=[1, 1, stride, stride],
                              padding='SAME',
                              data_format='NCHW')

        if biases:
            _biases = lib.param(name + '.Biases',
                                np.zeros(output_dim, dtype='float32'))
            result = tf.nn.bias_add(result, _biases, data_format='NCHW')

        return result
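
A usage sketch (assumed): inputs are BCHW, striding is applied spatially via strides=[1, 1, stride, stride], and passing mask_type=('a', n_channels) would enable the PixelCNN-style causal mask built at the top of the function. Note that the hard-coded data_format='NCHW' typically requires a GPU build of TF 1.x.

# Hypothetical call; a plain (unmasked) strided 5x5 convolution.
images = tf.random_normal([64, 3, 32, 32])  # BCHW images
conv_out = Conv2D('Discriminator.1', 3, 64, 5, images, stride=2)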