Example 1
def Layernorm(name, norm_axes, inputs):
    mean, var = tf.nn.moments(inputs, norm_axes, keep_dims=True)

    # Assume the 'neurons' axis is the first of norm_axes. This is the case for fully-connected and BCHW conv layers.
    n_neurons = inputs.get_shape().as_list()[norm_axes[0]]

    offset = lib.param(name+'.offset', np.zeros(n_neurons, dtype='float32'))
    scale = lib.param(name+'.scale', np.ones(n_neurons, dtype='float32'))

    # Add broadcasting dims to offset and scale (e.g. BCHW conv data)
    offset = tf.reshape(offset, [-1] + [1] * (len(norm_axes) - 1))
    scale = tf.reshape(scale, [-1] + [1] * (len(norm_axes) - 1))

    result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5)

    return result
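
All of these snippets come from a TF 1.x module and rely on context that is not shown here: numpy as np, tensorflow as tf, a lib.param parameter factory, and the module globals _default_weightnorm and _weights_stdev. As a rough, self-contained NumPy sketch of what Layernorm computes for BCHW inputs with norm_axes=[1, 2, 3] (names are illustrative, not part of the original code):

import numpy as np

def layernorm_reference(inputs, norm_axes, scale, offset, eps=1e-5):
    # Normalize over norm_axes, then apply a per-channel scale and offset.
    mean = inputs.mean(axis=tuple(norm_axes), keepdims=True)
    var = inputs.var(axis=tuple(norm_axes), keepdims=True)
    # scale/offset are vectors over the first normalized axis; the trailing
    # singleton dims make them broadcast over the remaining axes (H and W).
    bshape = [-1] + [1] * (len(norm_axes) - 1)
    return scale.reshape(bshape) * (inputs - mean) / np.sqrt(var + eps) + offset.reshape(bshape)

x = np.random.randn(4, 8, 16, 16).astype('float32')  # BCHW
y = layernorm_reference(x, [1, 2, 3], np.ones(8, 'float32'), np.zeros(8, 'float32'))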
Example 2
def Batchnorm(name,
              axes,
              inputs,
              is_training=None,
              stats_iter=None,
              update_moving_stats=True,
              fused=True,
              labels=None,
              n_labels=None):
    """conditional batchnorm (dumoulin et al 2016) for BCHW conv filtermaps"""
    if axes != [0, 2, 3]:
        raise Exception('unsupported axes for conditional batchnorm: {}'.format(axes))
    mean, var = tf.nn.moments(inputs, axes, keep_dims=True)
    shape = mean.get_shape().as_list()  # shape is [1,n,1,1]
    offset_m = lib.param(name + '.offset',
                         np.zeros([n_labels, shape[1]], dtype='float32'))
    scale_m = lib.param(name + '.scale',
                        np.ones([n_labels, shape[1]], dtype='float32'))
    offset = tf.nn.embedding_lookup(offset_m, labels)
    scale = tf.nn.embedding_lookup(scale_m, labels)
    result = tf.nn.batch_normalization(inputs, mean, var,
                                       offset[:, :, None, None],
                                       scale[:, :, None, None], 1e-5)
    return result
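
The conditional part above reduces to a per-label lookup table of offsets and scales: embedding_lookup is just row indexing into a [n_labels, n_channels] table, followed by broadcasting over H and W. A minimal NumPy sketch of that selection (illustrative sizes):

import numpy as np

n_labels, n_channels = 10, 64
offset_table = np.zeros((n_labels, n_channels), dtype='float32')  # name + '.offset'
scale_table = np.ones((n_labels, n_channels), dtype='float32')    # name + '.scale'

labels = np.array([3, 0, 7, 3])                   # one class label per sample
offset = offset_table[labels][:, :, None, None]   # (batch, channels, 1, 1)
scale = scale_table[labels][:, :, None, None]     # broadcasts over H and W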
Example 3
def Deconv2D(
    name,
    input_dim,
    output_dim,
    filter_size,
    inputs,
    he_init=True,
    weightnorm=None,
    biases=True,
    gain=1.,
    mask_type=None,
):
    """
    inputs: tensor of shape (batch size, height, width, input_dim)
    returns: tensor of shape (batch size, 2*height, 2*width, output_dim)
    """
    with tf.name_scope(name) as scope:

        if mask_type is not None:
            raise Exception('Unsupported configuration')

        def uniform(stdev, size):
            return np.random.uniform(low=-stdev * np.sqrt(3),
                                     high=stdev * np.sqrt(3),
                                     size=size).astype('float32')

        stride = 2
        fan_in = input_dim * filter_size**2 / (stride**2)
        fan_out = output_dim * filter_size**2

        if he_init:
            filters_stdev = np.sqrt(4. / (fan_in + fan_out))
        else:  # Normalized init (Glorot & Bengio)
            filters_stdev = np.sqrt(2. / (fan_in + fan_out))

        if _weights_stdev is not None:
            filter_values = uniform(
                _weights_stdev,
                (filter_size, filter_size, output_dim, input_dim))
        else:
            filter_values = uniform(
                filters_stdev,
                (filter_size, filter_size, output_dim, input_dim))

        filter_values *= gain

        filters = lib.param(name + '.Filters', filter_values)

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(
                np.sum(np.square(filter_values), axis=(0, 1, 3)))
            target_norms = lib.param(name + '.g', norm_values)
            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(
                    tf.reduce_sum(tf.square(filters),
                                  reduction_indices=[0, 1, 3]))
                filters = filters * tf.expand_dims(target_norms / norms, 1)

        inputs = tf.transpose(inputs, [0, 2, 3, 1], name='NCHW_to_NHWC')

        input_shape = tf.shape(inputs)
        try:  # tf pre-1.0 (top) vs 1.0 (bottom)
            output_shape = tf.pack([
                input_shape[0], 2 * input_shape[1], 2 * input_shape[2],
                output_dim
            ])
        except Exception as e:
            output_shape = tf.stack([
                input_shape[0], 2 * input_shape[1], 2 * input_shape[2],
                output_dim
            ])

        result = tf.nn.conv2d_transpose(value=inputs,
                                        filter=filters,
                                        output_shape=output_shape,
                                        strides=[1, 2, 2, 1],
                                        padding='SAME')

        if biases:
            _biases = lib.param(name + '.Biases',
                                np.zeros(output_dim, dtype='float32'))
            result = tf.nn.bias_add(result, _biases)

        result = tf.transpose(result, [0, 3, 1, 2], name='NHWC_to_NCHW')

        return result
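
The initializer above draws weights uniformly with a standard deviation given by the He or Glorot formula, with fan_in divided by stride**2 since each output position of a stride-2 transposed convolution only sees about filter_size**2 / stride**2 inputs. A small stand-alone computation of that bound (arbitrary example sizes, not from the original code):

import numpy as np

input_dim, output_dim, filter_size, stride = 128, 64, 5, 2
fan_in = input_dim * filter_size ** 2 / stride ** 2
fan_out = output_dim * filter_size ** 2
he_stdev = np.sqrt(4. / (fan_in + fan_out))   # he_init=True branch
bound = he_stdev * np.sqrt(3)                 # uniform(-bound, bound) has stdev he_stdev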
Example 4
def Conv2D(name,
           input_dim,
           output_dim,
           filter_size,
           inputs,
           he_init=True,
           mask_type=None,
           stride=1,
           weightnorm=None,
           biases=True,
           gain=1.):
    """
    inputs: tensor of shape (batch size, num channels, height, width)
    mask_type: one of None, 'a', 'b'

    returns: tensor of shape (batch size, num channels, height, width)
    """
    with tf.name_scope(name) as scope:

        if mask_type is not None:
            mask_type, mask_n_channels = mask_type

            mask = np.ones((filter_size, filter_size, input_dim, output_dim),
                           dtype='float32')
            center = filter_size // 2

            # Mask out future locations
            # filter shape is (height, width, input channels, output channels)
            mask[center + 1:, :, :, :] = 0.
            mask[center, center + 1:, :, :] = 0.

            # Mask out future channels
            for i in range(mask_n_channels):
                for j in range(mask_n_channels):
                    if (mask_type == 'a' and i >= j) or (mask_type == 'b'
                                                         and i > j):
                        mask[center, center, i::mask_n_channels,
                             j::mask_n_channels] = 0.

        def uniform(stdev, size):
            return np.random.uniform(low=-stdev * np.sqrt(3),
                                     high=stdev * np.sqrt(3),
                                     size=size).astype('float32')

        fan_in = input_dim * filter_size**2
        fan_out = output_dim * filter_size**2 / (stride**2)

        if mask_type is not None:  # only approximately correct
            fan_in /= 2.
            fan_out /= 2.

        if he_init:
            filters_stdev = np.sqrt(4. / (fan_in + fan_out))
        else:  # Normalized init (Glorot & Bengio)
            filters_stdev = np.sqrt(2. / (fan_in + fan_out))

        if _weights_stdev is not None:
            filter_values = uniform(
                _weights_stdev,
                (filter_size, filter_size, input_dim, output_dim))
        else:
            filter_values = uniform(
                filters_stdev,
                (filter_size, filter_size, input_dim, output_dim))

        # print "WARNING IGNORING GAIN"
        filter_values *= gain

        filters = lib.param(name + '.Filters', filter_values)

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(
                np.sum(np.square(filter_values), axis=(0, 1, 2)))
            target_norms = lib.param(name + '.g', norm_values)
            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(
                    tf.reduce_sum(tf.square(filters),
                                  reduction_indices=[0, 1, 2]))
                filters = filters * (target_norms / norms)

        if mask_type is not None:
            with tf.name_scope('filter_mask'):
                filters = filters * mask

        result = tf.nn.conv2d(input=inputs,
                              filter=filters,
                              strides=[1, 1, stride, stride],
                              padding='SAME',
                              data_format='NCHW')

        if biases:
            _biases = lib.param(name + '.Biases',
                                np.zeros(output_dim, dtype='float32'))

            result = tf.nn.bias_add(result, _biases, data_format='NCHW')

        return result
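
The masking above is the PixelCNN-style autoregressive mask: everything strictly below, or to the right of, the center tap is zeroed, and at the center tap channel groups are blocked so that type 'a' also hides the current channel group while type 'b' keeps it. A self-contained NumPy sketch of the same construction (illustrative sizes):

import numpy as np

filter_size, input_dim, output_dim, n_ch = 3, 3, 3, 3
mask = np.ones((filter_size, filter_size, input_dim, output_dim), dtype='float32')
center = filter_size // 2
mask[center + 1:, :, :, :] = 0.        # rows below the center pixel
mask[center, center + 1:, :, :] = 0.   # columns to the right of it
for i in range(n_ch):
    for j in range(n_ch):
        if i >= j:                     # use i > j instead for mask type 'b'
            mask[center, center, i::n_ch, j::n_ch] = 0.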
Example 5
def Linear(name,
           input_dim,
           output_dim,
           inputs,
           biases=True,
           initialization=None,
           weightnorm=None,
           gain=1.):
    """
    initialization: None, `lecun`, 'glorot', `he`, 'glorot_he', `orthogonal`, `("uniform", range)`
    """
    with tf.name_scope(name) as scope:

        def uniform(stdev, size):
            if _weights_stdev is not None:
                stdev = _weights_stdev
            return np.random.uniform(low=-stdev * np.sqrt(3),
                                     high=stdev * np.sqrt(3),
                                     size=size).astype('float32')

        if initialization == 'lecun':  # and input_dim != output_dim):
            # disabling orth. init for now because it's too slow
            weight_values = uniform(np.sqrt(1. / input_dim),
                                    (input_dim, output_dim))

        elif initialization == 'glorot' or initialization is None:

            weight_values = uniform(np.sqrt(2. / (input_dim + output_dim)),
                                    (input_dim, output_dim))

        elif initialization == 'he':

            weight_values = uniform(np.sqrt(2. / input_dim),
                                    (input_dim, output_dim))

        elif initialization == 'glorot_he':

            weight_values = uniform(np.sqrt(4. / (input_dim + output_dim)),
                                    (input_dim, output_dim))

        elif initialization == 'orthogonal' or \
            (initialization is None and input_dim == output_dim):

            # From lasagne
            def sample(shape):
                if len(shape) < 2:
                    raise RuntimeError("Only shapes of length 2 or more are "
                                       "supported.")
                flat_shape = (shape[0], np.prod(shape[1:]))
                # TODO: why normal and not uniform?
                a = np.random.normal(0.0, 1.0, flat_shape)
                u, _, v = np.linalg.svd(a, full_matrices=False)
                # pick the one with the correct shape
                q = u if u.shape == flat_shape else v
                q = q.reshape(shape)
                return q.astype('float32')

            weight_values = sample((input_dim, output_dim))

        elif initialization[0] == 'uniform':

            weight_values = np.random.uniform(
                low=-initialization[1],
                high=initialization[1],
                size=(input_dim, output_dim)).astype('float32')

        else:

            raise Exception('Invalid initialization!')

        weight_values *= gain

        weight = lib.param(name + '.W', weight_values)

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(np.sum(np.square(weight_values), axis=0))
            # norm_values = np.linalg.norm(weight_values, axis=0)

            target_norms = lib.param(name + '.g', norm_values)

            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(
                    tf.reduce_sum(tf.square(weight), reduction_indices=[0]))
                weight = weight * (target_norms / norms)

        # if 'Discriminator' in name:
        #     print "WARNING weight constraint on {}".format(name)
        #     weight = tf.nn.softsign(10.*weight)*.1

        if inputs.get_shape().ndims == 2:
            result = tf.matmul(inputs, weight)
        else:
            reshaped_inputs = tf.reshape(inputs, [-1, input_dim])
            result = tf.matmul(reshaped_inputs, weight)
            # tf.stack/tf.unstack were named tf.pack/tf.unpack before TF 1.0
            result = tf.reshape(
                result,
                tf.stack(tf.unstack(tf.shape(inputs))[:-1] + [output_dim]))

        if biases:
            result = tf.nn.bias_add(
                result,
                lib.param(name + '.b', np.zeros((output_dim, ),
                                                dtype='float32')))

        return result
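
The weightnorm branch above rescales every output column of W so its norm equals a learned target; since the '.g' parameter is initialized to the column norms of the initial weights, the rescaling is the identity at initialization. In NumPy terms (illustrative shapes, not from the original code):

import numpy as np

W = np.random.randn(256, 128).astype('float32')   # (input_dim, output_dim)
g = np.ones(128, dtype='float32')                 # learned per-column target norms ('.g')
col_norms = np.sqrt(np.square(W).sum(axis=0))     # current column norms
W_normed = W * (g / col_norms)                    # column j of W_normed now has norm g[j]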
Example 6
def Batchnorm(name,
              axes,
              inputs,
              is_training=None,
              stats_iter=None,
              update_moving_stats=True,
              fused=True):
    """

    :param name:
    :param axes: the remaining axis represents CHANNEL, we want to normalize CHANNEL
    :param inputs:
    :param is_training:
    :param stats_iter:
    :param update_moving_stats:
    :param fused:
    :return:
    """

    if ((axes == [0, 2, 3]) or (axes == [0, 2])) and fused:
        if axes == [0, 2]:
            inputs = tf.expand_dims(inputs, 3)

        # Variables declaration
        offset = lib.param(name + '.offset',
                           np.zeros(inputs.get_shape()[1], dtype='float32'))
        scale = lib.param(name + '.scale',
                          np.ones(inputs.get_shape()[1], dtype='float32'))
        moving_mean = lib.param(name + '.moving_mean',
                                np.zeros(inputs.get_shape()[1],
                                         dtype='float32'),
                                trainable=False)
        moving_variance = lib.param(name + '.moving_variance',
                                    np.ones(inputs.get_shape()[1],
                                            dtype='float32'),
                                    trainable=False)

        # train
        def _fused_batch_norm_training():
            return tf.nn.fused_batch_norm(inputs,
                                          scale,
                                          offset,
                                          epsilon=1e-5,
                                          data_format='NCHW')

        # test
        def _fused_batch_norm_inference():
            # Version which blends in the current item's statistics
            batch_size = tf.cast(tf.shape(inputs)[0], 'float32')
            mean, var = tf.nn.moments(inputs, [2, 3], keep_dims=True)
            mean = ((1. / batch_size) * mean) + (
                ((batch_size - 1.) / batch_size) * moving_mean)[None, :, None, None]
            var = ((1. / batch_size) * var) + (
                ((batch_size - 1.) / batch_size) * moving_variance)[None, :, None, None]
            return tf.nn.batch_normalization(inputs, mean, var,
                                             offset[None, :, None, None],
                                             scale[None, :, None, None],
                                             1e-5), mean, var

            # Standard version
            # return tf.nn.fused_batch_norm(
            #     inputs,
            #     scale,
            #     offset,
            #     epsilon=1e-2,
            #     mean=moving_mean,
            #     variance=moving_variance,
            #     is_training=False,
            #     data_format='NCHW'
            # )

        if is_training is None:
            outputs, batch_mean, batch_var = _fused_batch_norm_training()
        else:
            outputs, batch_mean, batch_var = tf.cond(
                is_training, _fused_batch_norm_training,
                _fused_batch_norm_inference)
            if update_moving_stats:
                no_updates = lambda: outputs

                def _force_updates():
                    """Internal function forces updates moving_vars if is_training."""
                    float_stats_iter = tf.cast(stats_iter, tf.float32)

                    update_moving_mean = tf.assign(
                        moving_mean, ((float_stats_iter /
                                       (float_stats_iter + 1)) * moving_mean) +
                        ((1 / (float_stats_iter + 1)) * batch_mean))
                    update_moving_variance = tf.assign(
                        moving_variance,
                        ((float_stats_iter /
                          (float_stats_iter + 1)) * moving_variance) +
                        ((1 / (float_stats_iter + 1)) * batch_var))

                    with tf.control_dependencies(
                        [update_moving_mean, update_moving_variance]):
                        return tf.identity(outputs)

                outputs = tf.cond(is_training, _force_updates, no_updates)

        if axes == [0, 2]:
            return outputs[:, :, :, 0]  # collapse last dim
        else:
            return outputs
    else:
        # raise Exception('old BN')
        # TODO we can probably use nn.fused_batch_norm here too for speedup
        mean, var = tf.nn.moments(inputs, axes, keep_dims=True)
        shape = mean.get_shape().as_list()
        if 0 not in axes:
            print("WARNING ({}): didn't find 0 in axes, but not using separate BN params for each item in batch".format(name))
            shape[0] = 1
        offset = lib.param(name + '.offset', np.zeros(shape, dtype='float32'))
        scale = lib.param(name + '.scale', np.ones(shape, dtype='float32'))
        # result = tf.cond(tf.equal(tf.shape(inputs)[0], 1), lambda: inputs, lambda: tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5))

        result = tf.nn.batch_normalization(inputs, mean, var, offset, scale,
                                           1e-5)

        return result
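
The _force_updates branch keeps moving_mean/moving_variance as a running average over all batches seen so far, indexed by stats_iter. A plain NumPy sketch of the same update rule (names illustrative):

import numpy as np

def update_moving(moving, batch_value, stats_iter):
    # After processing batch t (0-based), `moving` equals the mean of batch
    # values 0..t: (t/(t+1)) * mean_of_first_t_values + (1/(t+1)) * new_value.
    t = float(stats_iter)
    return (t / (t + 1.)) * moving + (1. / (t + 1.)) * batch_value

moving_mean = np.zeros(64, dtype='float32')
for t, batch_mean in enumerate(np.random.randn(5, 64).astype('float32')):
    moving_mean = update_moving(moving_mean, batch_mean, t)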