Example #1
def Layernorm(name, norm_axes, inputs):
    mean, var = tf.nn.moments(inputs, norm_axes, keep_dims=True)
    # Assume the 'neurons' axis is the first of norm_axes. This is the case for fully-connected and BCHW conv layers.
    n_neurons = inputs.get_shape().as_list()[norm_axes[0]]

    offset = lib.param(name + '.offset', np.zeros(n_neurons, dtype='float32'))
    scale = lib.param(name + '.scale', np.ones(n_neurons, dtype='float32'))

    # Add broadcasting dims to offset and scale (e.g. BCHW conv data)
    offset = tf.reshape(offset, [-1] + [1 for i in range(len(norm_axes) - 1)])
    scale = tf.reshape(scale, [-1] + [1 for i in range(len(norm_axes) - 1)])

    result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-5)

    return result
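A minimal usage sketch for the Layernorm helper above, assuming TF 1.x graph mode and that the surrounding tflib-style module (with its lib.param helper) is importable; the tensor shape and parameter name are illustrative only.

import tensorflow as tf

# BCHW conv activations; the channels axis (1) is the first normalization axis
x = tf.placeholder(tf.float32, [None, 64, 32, 32])
# Normalize each sample over channels, height and width
y = Layernorm('Generator.LN1', [1, 2, 3], x)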
Example #2
def Conv2D(name,
           input_dim,
           output_dim,
           filter_size,
           inputs,
           he_init=True,
           mask_type=None,
           stride=1,
           weightnorm=None,
           biases=True,
           gain=1.,
           cpu=False):
    """
    inputs: tensor of shape (batch size, num channels, height, width)
    mask_type: one of None, 'a', 'b'

    returns: tensor of shape (batch size, num channels, height, width)
    """
    with tf.name_scope(name) as scope:

        if mask_type is not None:
            mask_type, mask_n_channels = mask_type

            mask = np.ones((filter_size, filter_size, input_dim, output_dim),
                           dtype='float32')
            center = filter_size // 2

            # Mask out future locations
            # filter shape is (height, width, input channels, output channels)
            mask[center + 1:, :, :, :] = 0.
            mask[center, center + 1:, :, :] = 0.

            # Mask out future channels
            for i in range(mask_n_channels):
                for j in range(mask_n_channels):
                    if (mask_type == 'a' and i >= j) or \
                            (mask_type == 'b' and i > j):
                        mask[center, center, i::mask_n_channels,
                             j::mask_n_channels] = 0.

        def uniform(stdev, size):
            return np.random.uniform(low=-stdev * np.sqrt(3),
                                     high=stdev * np.sqrt(3),
                                     size=size).astype('float32')

        fan_in = input_dim * filter_size**2
        fan_out = output_dim * filter_size**2 / (stride**2)

        if mask_type is not None:  # only approximately correct
            fan_in /= 2.
            fan_out /= 2.

        if he_init:
            filters_stdev = np.sqrt(4. / (fan_in + fan_out))
        else:  # Normalized init (Glorot & Bengio)
            filters_stdev = np.sqrt(2. / (fan_in + fan_out))

        if _weights_stdev is not None:
            filter_values = uniform(
                _weights_stdev,
                (filter_size, filter_size, input_dim, output_dim))
        else:
            filter_values = uniform(
                filters_stdev,
                (filter_size, filter_size, input_dim, output_dim))

        # print "WARNING IGNORING GAIN"
        filter_values *= gain

        filters = lib.param(name + '.Filters', filter_values)

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(
                np.sum(np.square(filter_values), axis=(0, 1, 2)))
            target_norms = lib.param(name + '.g', norm_values)
            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(
                    tf.reduce_sum(tf.square(filters),
                                  reduction_indices=[0, 1, 2]))
                filters = filters * (target_norms / norms)

        if mask_type is not None:
            with tf.name_scope('filter_mask'):
                filters = filters * mask

        if cpu:
            data_format = 'NHWC'
            strides = [1, stride, stride, 1]
        else:
            data_format = 'NCHW'
            strides = [1, 1, stride, stride]

        result = tf.nn.conv2d(input=inputs,
                              filter=filters,
                              strides=strides,
                              padding='SAME',
                              data_format=data_format)

        if biases:
            _biases = lib.param(name + '.Biases',
                                np.zeros(output_dim, dtype='float32'))

            result = tf.nn.bias_add(result, _biases, data_format=data_format)

        return result
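A brief usage sketch for Conv2D, with the caveat that it assumes the module-level globals _weights_stdev and _default_weightnorm and the lib.param helper referenced above are defined, and that the default NCHW data format runs on a GPU session; names and shapes are placeholders.

import tensorflow as tf

# NCHW image batch: 3 input channels, 32x32 spatial size
images = tf.placeholder(tf.float32, [None, 3, 32, 32])
# 5x5 convolution, 3 -> 64 channels, downsampling with stride 2 and 'SAME' padding
features = Conv2D('Discriminator.Conv1', 3, 64, 5, images, stride=2)
# For a masked (PixelCNN-style) layer, mask_type is a tuple, e.g. mask_type=('a', 3)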
Example #3
def Deconv2D(
    name,
    input_dim,
    output_dim,
    filter_size,
    inputs,
    he_init=True,
    weightnorm=None,
    biases=True,
    gain=1.,
    mask_type=None,
):
    """
    inputs: tensor of shape (batch size, input_dim, height, width)
    returns: tensor of shape (batch size, output_dim, 2*height, 2*width)
    """
    with tf.name_scope(name) as scope:

        if mask_type is not None:
            raise Exception('Unsupported configuration')

        def uniform(stdev, size):
            return np.random.uniform(low=-stdev * np.sqrt(3),
                                     high=stdev * np.sqrt(3),
                                     size=size).astype('float32')

        stride = 2
        fan_in = input_dim * filter_size**2 / (stride**2)
        fan_out = output_dim * filter_size**2

        if he_init:
            filters_stdev = np.sqrt(4. / (fan_in + fan_out))
        else:  # Normalized init (Glorot & Bengio)
            filters_stdev = np.sqrt(2. / (fan_in + fan_out))

        if _weights_stdev is not None:
            filter_values = uniform(
                _weights_stdev,
                (filter_size, filter_size, output_dim, input_dim))
        else:
            filter_values = uniform(
                filters_stdev,
                (filter_size, filter_size, output_dim, input_dim))

        filter_values *= gain

        filters = lib.param(name + '.Filters', filter_values)

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(
                np.sum(np.square(filter_values), axis=(0, 1, 3)))
            target_norms = lib.param(name + '.g', norm_values)
            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(
                    tf.reduce_sum(tf.square(filters),
                                  reduction_indices=[0, 1, 3]))
                filters = filters * tf.expand_dims(target_norms / norms, 1)

        inputs = tf.transpose(inputs, [0, 2, 3, 1], name='NCHW_to_NHWC')

        input_shape = tf.shape(inputs)
        try:  # tf pre-1.0 (top) vs 1.0 (bottom)
            output_shape = tf.pack([
                input_shape[0], 2 * input_shape[1], 2 * input_shape[2],
                output_dim
            ])
        except Exception as e:
            output_shape = tf.stack([
                input_shape[0], 2 * input_shape[1], 2 * input_shape[2],
                output_dim
            ])

        result = tf.nn.conv2d_transpose(value=inputs,
                                        filter=filters,
                                        output_shape=output_shape,
                                        strides=[1, 2, 2, 1],
                                        padding='SAME')

        if biases:
            _biases = lib.param(name + '.Biases',
                                np.zeros(output_dim, dtype='float32'))
            result = tf.nn.bias_add(result, _biases)

        result = tf.transpose(result, [0, 3, 1, 2], name='NHWC_to_NCHW')

        return result
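A usage sketch for Deconv2D under the same assumptions (TF 1.x, lib.param, module globals); note that although conv2d_transpose runs in NHWC internally, the function expects and returns NCHW tensors, as the transposes above show.

import tensorflow as tf

# NCHW feature map to be upsampled 2x spatially
z = tf.placeholder(tf.float32, [None, 128, 8, 8])
# 5x5 transposed conv, 128 -> 64 channels; output shape is [None, 64, 16, 16]
up = Deconv2D('Generator.Deconv1', 128, 64, 5, z)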
Example #4
def Batchnorm(name,
              axes,
              inputs,
              is_training,
              stats_iter=None,
              update_moving_stats=True,
              fused=False,
              decay=0.9,
              cpu=False):
    eps = 1e-5
    if ((axes == [0, 2, 3]) or (axes == [0, 2])) and fused == True:
        if axes == [0, 2]:
            inputs = tf.expand_dims(inputs, 3)

        offset = lib.param(name + '.offset', tf.zeros(inputs.get_shape()[1]))
        scale = lib.param(name + '.scale', tf.ones(inputs.get_shape()[1]))
        moving_mean = lib.param(name + '.moving_mean',
                                tf.zeros(inputs.get_shape()[1]),
                                trainable=False)
        moving_variance = lib.param(name + '.moving_variance',
                                    tf.ones(inputs.get_shape()[1]),
                                    trainable=False)

        def train_bn():
            outputs, batch_mean, batch_var = tf.nn.fused_batch_norm(
                inputs, scale, offset, epsilon=eps, data_format='NCHW')
            update_moving_mean = tf.assign(
                moving_mean, moving_mean * decay + batch_mean * (1. - decay))
            update_moving_variance = tf.assign(
                moving_variance,
                moving_variance * decay + batch_var * (1. - decay))
            with tf.control_dependencies(
                [update_moving_mean, update_moving_variance]):
                return tf.identity(outputs)

        def infer_bn():
            outputs, _, _ = tf.nn.fused_batch_norm(inputs,
                                                   scale,
                                                   offset,
                                                   epsilon=eps,
                                                   mean=moving_mean,
                                                   variance=moving_variance,
                                                   data_format='NCHW',
                                                   is_training=False)
            return outputs

        outputs = tf.cond(is_training, train_bn, infer_bn)

        if axes == [0, 2]:
            return outputs[:, :, :, 0]  # collapse last dim
        return outputs

    else:
        offset = lib.param(name + '.offset',
                           tf.zeros([inputs.get_shape()[-1]]))
        scale = lib.param(name + '.scale', tf.ones([inputs.get_shape()[-1]]))
        moving_mean = lib.param(name + '.moving_mean',
                                tf.zeros([inputs.get_shape()[-1]]),
                                trainable=False)
        moving_variance = lib.param(name + '.moving_variance',
                                    tf.ones([inputs.get_shape()[-1]]),
                                    trainable=False)

        def train_bn():
            # NOTE: this non-fused branch hard-codes moments over axis 0
            # (the fully-connected case); the `axes` argument is not used here.
            batch_mean, batch_var = tf.nn.moments(inputs, [0])
            update_moving_mean = tf.assign(
                moving_mean, moving_mean * decay + batch_mean * (1. - decay))
            update_moving_variance = tf.assign(
                moving_variance,
                moving_variance * decay + batch_var * (1. - decay))

            with tf.control_dependencies(
                [update_moving_mean, update_moving_variance]):
                return tf.nn.batch_normalization(inputs, batch_mean, batch_var,
                                                 offset, scale, eps)

        def infer_bn():
            return tf.nn.batch_normalization(inputs, moving_mean,
                                             moving_variance, offset, scale,
                                             eps)

        return tf.cond(is_training, train_bn, infer_bn)
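A usage sketch for Batchnorm with the same caveats about lib.param and TF 1.x graph mode; is_training must be a boolean tensor because both branches are built under tf.cond. Names and shapes are illustrative.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64, 32, 32])  # NCHW activations
is_training = tf.placeholder(tf.bool, [])            # scalar switch for tf.cond
# Fused NCHW batchnorm over batch, height and width (axes [0, 2, 3])
y = Batchnorm('Generator.BN1', [0, 2, 3], x, is_training, fused=True)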
Example #5
def Linear(
        name, 
        input_dim, 
        output_dim, 
        inputs,
        biases=True,
        initialization=None,
        weightnorm=None,
        gain=1.
        ):
    """
    initialization: None, 'lecun', 'glorot', 'he', 'glorot_he', 'orthogonal', or ('uniform', range)
    """
    with tf.name_scope(name) as scope:

        def uniform(stdev, size):
            if _weights_stdev is not None:
                stdev = _weights_stdev
            return np.random.uniform(
                low=-stdev * np.sqrt(3),
                high=stdev * np.sqrt(3),
                size=size
            ).astype('float32')

        if initialization == 'lecun':
            # disabling orth. init for now because it's too slow
            weight_values = uniform(
                np.sqrt(1./input_dim),
                (input_dim, output_dim)
            )

        elif initialization == 'glorot' or initialization is None:

            weight_values = uniform(
                np.sqrt(2./(input_dim+output_dim)),
                (input_dim, output_dim)
            )

        elif initialization == 'he':

            weight_values = uniform(
                np.sqrt(2./input_dim),
                (input_dim, output_dim)
            )

        elif initialization == 'glorot_he':

            weight_values = uniform(
                np.sqrt(4./(input_dim+output_dim)),
                (input_dim, output_dim)
            )

        elif initialization == 'orthogonal' or \
                (initialization is None and input_dim == output_dim):
            
            # From lasagne
            def sample(shape):
                if len(shape) < 2:
                    raise RuntimeError("Only shapes of length 2 or more are "
                                       "supported.")
                flat_shape = (shape[0], np.prod(shape[1:]))
                 # TODO: why normal and not uniform?
                a = np.random.normal(0.0, 1.0, flat_shape)
                u, _, v = np.linalg.svd(a, full_matrices=False)
                # pick the one with the correct shape
                q = u if u.shape == flat_shape else v
                q = q.reshape(shape)
                return q.astype('float32')
            weight_values = sample((input_dim, output_dim))
        
        elif initialization[0] == 'uniform':
        
            weight_values = np.random.uniform(
                low=-initialization[1],
                high=initialization[1],
                size=(input_dim, output_dim)
            ).astype('float32')

        else:

            raise Exception('Invalid initialization!')

        weight_values *= gain

        weight = lib.param(
            name + '.W',
            weight_values
        )

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            norm_values = np.sqrt(np.sum(np.square(weight_values), axis=0))
            # norm_values = np.linalg.norm(weight_values, axis=0)

            target_norms = lib.param(
                name + '.g',
                norm_values
            )

            with tf.name_scope('weightnorm') as scope:
                norms = tf.sqrt(tf.reduce_sum(tf.square(weight), reduction_indices=[0]))
                weight = weight * (target_norms / norms)

        # if 'Discriminator' in name:
        #     print "WARNING weight constraint on {}".format(name)
        #     weight = tf.nn.softsign(10.*weight)*.1

        if inputs.get_shape().ndims == 2:
            result = tf.matmul(inputs, weight)
        else:
            reshaped_inputs = tf.reshape(inputs, [-1, input_dim])
            result = tf.matmul(reshaped_inputs, weight)
            # tf.pack/tf.unpack were renamed tf.stack/tf.unstack in TF 1.0
            result = tf.reshape(result, tf.stack(tf.unstack(tf.shape(inputs))[:-1] + [output_dim]))

        if biases:
            result = tf.nn.bias_add(
                result,
                lib.param(
                    name + '.b',
                    np.zeros((output_dim,), dtype='float32')
                )
            )

        return result
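Finally, a usage sketch for Linear under the same assumptions; the initialization keyword takes one of the strings listed in the docstring or a ('uniform', range) tuple. The layer and tensor names are placeholders.

import tensorflow as tf

noise = tf.placeholder(tf.float32, [None, 128])
# Fully-connected layer 128 -> 512 with He-style uniform initialization
hidden = Linear('Generator.Input', 128, 512, noise, initialization='he')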