Example #1
def up_block(
        layer_name,
        bottom,
        skip_activity,
        reuse,
        kernel_size,
        num_filters,
        training,
        trainable=None,
        stride=[2, 2],
        padding='same',
        kernel_initializer=None,
        normalization_type='batch_norm',
        data_format='channels_last',
        renorm=False,
        use_bias=False):
    """Forward block for seung model."""
    if trainable is None:
        trainable = training
    with tf.variable_scope('%s_block' % layer_name, reuse=reuse):
        with tf.variable_scope('%s_layer_1' % layer_name, reuse=reuse):
            x = tf.layers.conv2d_transpose(
                inputs=bottom,
                filters=num_filters,
                kernel_size=kernel_size,
                name='%s_1' % layer_name,
                strides=stride,
                padding=padding,
                kernel_initializer=kernel_initializer,
                trainable=trainable,
                use_bias=use_bias)
            x = x + skip_activity  # Rethink if this is valid
            if normalization_type == 'batch_norm':
                x = normalization.batch(
                    bottom=x,
                    name='%s_bn_1' % layer_name,
                    data_format=data_format,
                    renorm=renorm,
                    trainable=trainable,
                    training=training)
            else:
                x = normalization.instance(
                    bottom=x,
                    training=training)
            x = tf.nn.elu(x)
    return x
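A minimal usage sketch for the block above (not from the original source): it assumes a TF1-style graph, that the project-local normalization module used inside up_block is importable, and that the tensor shapes are illustrative only.

import tensorflow as tf

# Hypothetical shapes: an 8x8 bottleneck upsampled to 16x16 and fused with an encoder skip.
bottleneck = tf.placeholder(tf.float32, [None, 8, 8, 64], name='bottleneck')
encoder_skip = tf.placeholder(tf.float32, [None, 16, 16, 32], name='encoder_skip')

up = up_block(
    layer_name='up1',
    bottom=bottleneck,
    skip_activity=encoder_skip,
    reuse=False,
    kernel_size=(4, 4),
    num_filters=32,   # must match the skip tensor's channels for the addition
    training=True,
    stride=[2, 2])    # stride 2 doubles the spatial resolution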
Example #2
def build_model(data_tensor,
                reuse,
                training,
                output_shape,
                data_format='NHWC'):
    """Create the hgru from Learning long-range..."""
    if isinstance(output_shape, list):
        output_shape = output_shape[-1]
    elif isinstance(output_shape, dict):
        output_shape = output_shape['output']
    if data_format == 'NCHW':
        data_tensor = tf.transpose(data_tensor, (0, 3, 1, 2))
        long_data_format = 'channels_first'
    else:
        long_data_format = 'channels_last'

    with tf.variable_scope('cnn', reuse=reuse):
        normalization_type = 'instance_norm'
        # # Concatenate standard deviation
        # _, var = tf.nn.moments(data_tensor, axes=[3])
        # std = tf.expand_dims(tf.sqrt(var), axis=-1)
        # data_tensor = tf.concat([data_tensor, std], axis=-1)

        # Add input
        in_emb = tf.layers.conv2d(inputs=data_tensor,
                                  filters=16,
                                  kernel_size=7,
                                  strides=(1, 1),
                                  padding='same',
                                  data_format=long_data_format,
                                  activation=tf.nn.relu,
                                  trainable=training,
                                  use_bias=True,
                                  name='l0')

        # Run fGRU
        hgru_kernels = OrderedDict()
        hgru_kernels['h1'] = [9, 9]  # height/width
        hgru_kernels['h2'] = [3, 3]
        hgru_kernels['fb1'] = [1, 1]
        hgru_features = OrderedDict()
        hgru_features['h1'] = [16, 16]  # Fan-in/fan-out, I and E (match fb1)
        hgru_features['h2'] = [48, 48]
        hgru_features['fb1'] = [16, 16]  # (match h1)
        # intermediate_ff = [16, 16, 24, 24, 32, 32]  # Last feature must match h2
        # intermediate_ks = [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3], [3, 3]]
        # intermediate_repeats = [4, 4, 4, 4, 4, 4]  # Repeat each FF this many times
        intermediate_ff = [16, 48]  # Last feature must match h2
        intermediate_ks = [[3, 3], [3, 3]]
        intermediate_repeats = [4, 4]  # Repeat each FF this many times
        gammanet_layer = gammanet.GN(
            layer_name='fgru',
            x=in_emb,
            data_format=data_format,
            reuse=reuse,
            timesteps=2,
            strides=[1, 1, 1, 1],
            hgru_features=hgru_features,
            hgru_kernels=hgru_kernels,
            intermediate_ff=intermediate_ff,
            intermediate_ks=intermediate_ks,
            intermediate_repeats=intermediate_repeats,
            horizontal_kernel_initializer=lambda x: tf.initializers.orthogonal(
                gain=0.1),
            kernel_initializer=lambda x: tf.initializers.orthogonal(gain=0.1),
            padding='SAME',
            aux={
                'readout': 'fb',
                'attention': 'se',  # alternatives: 'gala'
                'attention_layers': 2,
                'saliency_filter': 7,
                'use_homunculus': True,
                'upsample_convs': True,
                'separable_convs': 4,  # Multiplier
                'separable_upsample': True,
                'td_cell_state': True,
                'td_gate': False,  # Add top-down activity to the in-gate
                'normalization_type': normalization_type,
                'residual': True,  # intermediate resid connections
                'while_loop': False,
                'skip': False,
                'symmetric_weights': True,
                'bilinear_init': True,
                'include_pooling': True,
                'time_homunculus': True,
                'squeeze_fb': False,  # Compress Inh-hat with a 1x1 conv
                'force_horizontal': False,
                'time_skips': False,
                'timestep_output': False,
                'excite_se': False,  # Add S/E in the excitation stage
            },
            pool_strides=[2, 2],
            pooling_kernel=[4, 4],  # 4, 4 helped but also try 3, 3
            up_kernel=[4, 4],
            train=training)
        h2 = gammanet_layer(in_emb)
        if normalization_type == 'batch_norm':
            h2 = normalization.batch_contrib(bottom=h2,
                                             renorm=False,
                                             name='hgru_bn',
                                             dtype=h2.dtype,
                                             data_format=data_format,
                                             training=training)
        elif normalization_type == 'instance_norm':
            h2 = normalization.instance(bottom=h2,
                                        data_format=data_format,
                                        training=training)
        elif normalization_type == 'ada_batch_norm':
            h2 = normalization.batch_contrib(bottom=h2,
                                             renorm=False,
                                             name='hgru_bn',
                                             dtype=h2.dtype,
                                             data_format=data_format,
                                             training=training)
        else:
            raise NotImplementedError(normalization_type)
    with tf.variable_scope('cv_readout', reuse=reuse):
        activity = tf.layers.conv2d(inputs=h2,
                                    filters=output_shape,
                                    kernel_size=(1, 1),
                                    padding='same',
                                    data_format=long_data_format,
                                    name='pre_readout_conv',
                                    use_bias=True,
                                    reuse=reuse)
    if long_data_format == 'channels_first':
        activity = tf.transpose(activity, (0, 2, 3, 1))
    extra_activities = {'activity': h2}
    if activity.dtype != tf.float32:
        activity = tf.cast(activity, tf.float32)
    return activity, extra_activities
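A hedged sketch (not from the original source) of wiring this build_model into a TF1 graph; the input resolution and readout channel count are placeholders, and the project-local gammanet / normalization modules are assumed to be importable.

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 128, 128, 3], name='images')  # NHWC input
logits, extras = build_model(
    data_tensor=images,
    reuse=False,
    training=True,
    output_shape=2)           # two channels for the 1x1 conv readout
# logits is a per-pixel activity map cast to float32; extras['activity'] is the post-fGRU state.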
Example #3
def build_model(data_tensor, reuse, training, output_shape):
    """Create the hgru from Learning long-range..."""
    if isinstance(output_shape, list):
        output_shape = output_shape[0]
    elif isinstance(output_shape, dict):
        output_shape = output_shape['output']
    with tf.variable_scope('cnn', reuse=reuse):
        normalization_type = 'ada_batch_norm'  # 'instance_norm'
        # # Concatenate standard deviation
        # _, var = tf.nn.moments(data_tensor, axes=[3])
        # std = tf.expand_dims(tf.sqrt(var), axis=-1)
        # data_tensor = tf.concat([data_tensor, std], axis=-1)

        # Add input
        in_emb = conv.skinny_input_layer(X=data_tensor,
                                         reuse=reuse,
                                         training=training,
                                         features=16,
                                         conv_activation=tf.nn.relu,
                                         conv_kernel_size=7,
                                         pool=False,
                                         name='l0')

        # Run fGRU
        hgru_kernels = OrderedDict()
        hgru_kernels['h1'] = [15, 15]  # height/width
        hgru_kernels['h2'] = [5, 5]
        hgru_kernels['fb1'] = [1, 1]
        hgru_features = OrderedDict()
        hgru_features['h1'] = [16, 16]  # Fan-in/fan-out, I and E (match fb1)
        hgru_features['h2'] = [48, 48]
        hgru_features['fb1'] = [16, 16]  # (match h1)
        # hgru_features['fb1'] = [24, 12]  # (match h1 unless squeeze_fb)
        intermediate_ff = [24, 24, 48]  # Last feature must match h2
        intermediate_ks = [[5, 5], [5, 5], [5, 5]]
        intermediate_repeats = [1, 1, 1]  # Repeat each interm this many times
        layer_hgru = hgru.hGRU(
            'fgru',
            x_shape=in_emb.get_shape().as_list(),
            timesteps=8,
            strides=[1, 1, 1, 1],
            hgru_features=hgru_features,
            hgru_kernels=hgru_kernels,
            intermediate_ff=intermediate_ff,
            intermediate_ks=intermediate_ks,
            intermediate_repeats=intermediate_repeats,
            padding='SAME',
            aux={
                'readout': 'fb',
                'squeeze_fb': False,  # Compress Inh-hat with a 1x1 conv
                'td_gate': True,  # Add top-down activity to the in-gate
                'attention': 'gala',  # 'gala',
                'attention_layers': 2,
                'upsample_convs': True,
                'td_cell_state': True,
                'normalization_type': normalization_type,
                'excite_se': False,  # Add S/E in the excitation stage
                'residual': True,  # intermediate resid connections
                'while_loop': False,
                'skip': True,
                'time_skips': False,
                'force_horizontal': False,
                'symmetric_weights': True,
                'timestep_output': False,
                'bilinear_init': True,
                'include_pooling': True
            },
            pool_strides=[2, 2],
            pooling_kernel=[4, 4],
            up_kernel=[4, 4],
            train=training)
        h2 = layer_hgru.build(in_emb)
        if normalization_type == 'batch_norm':
            h2 = normalization.batch(bottom=h2,
                                     renorm=False,
                                     name='hgru_bn',
                                     training=training)
        elif normalization_type == 'instance_norm':
            h2 = normalization.instance(bottom=h2, training=training)
        elif normalization_type == 'ada_batch_norm':
            h2 = normalization.batch(bottom=h2,
                                     renorm=False,
                                     name='hgru_bn',
                                     training=True)
        else:
            raise NotImplementedError(normalization_type)
        fc = tf.layers.conv2d(inputs=h2,
                              filters=output_shape,
                              kernel_size=1,
                              name='fc')
        # activity = tf.reduce_mean(fc, axis=[1, 2])
        activity = tf.reduce_max(fc, axis=[1, 2])
    extra_activities = {'activity': h2}
    return activity, extra_activities
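An illustrative call (not part of the source) for the classification-style variant above, assuming the project-local hgru, conv, and normalization modules are available; the shapes and class count are made up for the example.

import tensorflow as tf

images = tf.placeholder(tf.float32, [None, 64, 64, 3], name='images')
logits, extras = build_model(
    data_tensor=images,
    reuse=False,
    training=True,
    output_shape=10)          # ten output channels, max-pooled over space into class logits
# logits has shape [batch, 10] after the spatial reduce_max readout.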
Example #4
def down_block(layer_name,
               bottom,
               reuse,
               kernel_size,
               num_filters,
               training,
               trainable=None,
               stride=(1, 1),
               normalization_type='batch_norm',
               padding='same',
               data_format='channels_last',
               kernel_initializer=None,
               renorm=False,
               use_bias=False,
               activation=tf.nn.elu,
               include_pool=True):
    """Forward block for seung model."""
    if trainable is None:
        trainable = training
    with tf.variable_scope('%s_block' % layer_name, reuse=reuse):
        with tf.variable_scope('%s_layer_1' % layer_name, reuse=reuse):
            x = tf.layers.conv2d(inputs=bottom,
                                 filters=num_filters,
                                 kernel_size=kernel_size[0],
                                 name='%s_1' % layer_name,
                                 strides=stride,
                                 padding=padding,
                                 data_format=data_format,
                                 kernel_initializer=kernel_initializer,
                                 trainable=trainable,
                                 use_bias=use_bias)
            if normalization_type == 'batch_norm':
                x = normalization.batch(bottom=x,
                                        name='%s_bn_1' % layer_name,
                                        data_format=data_format,
                                        renorm=renorm,
                                        training=training,
                                        trainable=trainable)
            else:
                x = normalization.instance(bottom=x, training=trainable)
            x = activation(x)
            skip = tf.identity(x)

        with tf.variable_scope('%s_layer_2' % layer_name, reuse=reuse):
            x = tf.layers.conv2d(inputs=x,
                                 filters=num_filters,
                                 kernel_size=kernel_size[1],
                                 name='%s_2' % layer_name,
                                 strides=stride,
                                 padding=padding,
                                 data_format=data_format,
                                 kernel_initializer=kernel_initializer,
                                 trainable=trainable,
                                 use_bias=use_bias)
            if normalization_type == 'batch_norm':
                x = normalization.batch(bottom=x,
                                        name='%s_bn_2' % layer_name,
                                        data_format=data_format,
                                        renorm=renorm,
                                        trainable=trainable,
                                        training=training)
            else:
                x = normalization.instance(bottom=x, training=training)
            x = activation(x)

        with tf.variable_scope('%s_layer_3' % layer_name, reuse=reuse):
            x = tf.layers.conv2d(inputs=x,
                                 filters=num_filters,
                                 kernel_size=kernel_size[2],
                                 name='%s_3' % layer_name,
                                 strides=stride,
                                 padding=padding,
                                 data_format=data_format,
                                 kernel_initializer=kernel_initializer,
                                 trainable=trainable,
                                 activation=activation,
                                 use_bias=use_bias)
            x = x + skip
            if normalization_type == 'batch_norm':
                x = normalization.batch(bottom=x,
                                        name='%s_bn_3' % layer_name,
                                        data_format=data_format,
                                        renorm=renorm,
                                        trainable=trainable,
                                        training=training)
            else:
                x = normalization.instance(bottom=x, training=training)

        if include_pool:
            with tf.variable_scope('%s_pool' % layer_name, reuse=reuse):
                x = tf.layers.max_pooling2d(inputs=x,
                                            pool_size=(2, 2),
                                            strides=(2, 2),
                                            padding=padding,
                                            data_format='channels_last',
                                            name='%s_pool' % layer_name)
    return x
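A minimal sketch (illustrative, not from the source) that calls the residual down-block; the three per-layer kernel sizes and the channel count are arbitrary, and with the default include_pool=True the 2x2 max pool halves the spatial resolution.

import tensorflow as tf

features = tf.placeholder(tf.float32, [None, 32, 32, 16], name='features')
down = down_block(
    layer_name='down1',
    bottom=features,
    reuse=False,
    kernel_size=[(3, 3), (3, 3), (3, 3)],  # one kernel per conv layer in the block
    num_filters=32,
    training=True)
# down has spatial size 16x16 with 32 channels after pooling.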