# Example #1
0
def build_model(data_tensor,
                reuse,
                training,
                output_shape,
                data_format='NHWC'):
    """Create the hgru from Learning long-range...

    Pipeline: a pretrained Gabor conv layer, one hGRU layer (8 timesteps),
    batch normalization, multiplication by a fixed spatial mask loaded from
    ``weights/cardena_mask.npy``, global spatial averaging, and a dense
    readout.

    Args:
        data_tensor: Input image tensor.
        reuse: Forwarded to ``tf.variable_scope`` for the ``gammanet`` and
            ``cv_readout`` scopes.
        training: Sets the Gabor layer's ``trainable`` flag, the hGRU
            ``train`` flag and the batch-norm ``training`` mode.
        output_shape: Number of readout units. If a list, its last entry is
            used; if a dict, its ``'output'`` entry is used.
        data_format: Short data-format string, forwarded to
            ``tf_fun.interpret_data_format``.

    Returns:
        Tuple ``(activity, extra_activities)``: ``activity`` is the dense
        readout cast to float32, and ``extra_activities`` is an empty dict.
    """
    if isinstance(output_shape, list):
        output_shape = output_shape[-1]
    elif isinstance(output_shape, dict):
        output_shape = output_shape['output']
    data_tensor, long_data_format = tf_fun.interpret_data_format(
        data_tensor=data_tensor, data_format=data_format)

    # Build model
    with tf.variable_scope('gammanet', reuse=reuse):
        conv_aux = {
            'pretrained': os.path.join('weights',
                                       'gabors_for_contours_11.npy'),
            'pretrained_key': 's1',
            'nonlinearity': 'square'
        }
        activity = conv.conv_layer(bottom=data_tensor,
                                   name='gabor_input',
                                   stride=[1, 1, 1, 1],
                                   padding='SAME',
                                   trainable=training,
                                   use_bias=True,
                                   aux=conv_aux)
        layer_hgru = hgru.hGRU('hgru_1',
                               x_shape=activity.get_shape().as_list(),
                               timesteps=8,
                               h_ext=15,
                               strides=[1, 1, 1, 1],
                               padding='SAME',
                               aux={
                                   'reuse': False,
                                   'constrain': False
                               },
                               train=training)
        h2 = layer_hgru.build(activity)
        h2 = normalization.batch_contrib(bottom=h2,
                                         name='hgru_bn',
                                         training=training)
        # Lift the stored mask to rank 4 so it broadcasts over the first and
        # last axes of h2. This indexing assumes h2 is NHWC - TODO confirm.
        mask = np.load(os.path.join('weights', 'cardena_mask.npy'))[
            None, :, :, None]
        activity = h2 * mask
    with tf.variable_scope('cv_readout', reuse=reuse):
        # Average over axes 1 and 2 (the spatial axes under NHWC).
        activity = tf.reduce_mean(activity, axis=[1, 2])
        activity = tf.layers.dense(activity, output_shape)
    # NOTE(review): after the dense layer, activity is rank 2, so this
    # rank-4 transpose cannot succeed for channels_first data. The mask and
    # reduce axes above also assume channels_last. Confirm whether this
    # path is ever taken.
    if long_data_format == 'channels_first':
        activity = tf.transpose(activity, (0, 2, 3, 1))
    extra_activities = {}
    if activity.dtype != tf.float32:
        activity = tf.cast(activity, tf.float32)
    return activity, extra_activities
# Example #2
0
def readout_layer(
        activity,
        reuse,
        training,
        output_shape,
        dtype=tf.float32,
        var_scope='readout_1',
        pool_type='max',
        renorm=False,
        use_bn=False,
        features=2,
        return_fc=False):
    """Readout layer for recurrent experiments in Kim et al., 2019.

    Applies a bias-free 1x1 conv to ``features`` channels, then reduces the
    spatial dimensions. With ``pool_type == 'select'`` it takes the centre
    pixel; otherwise it calls ``pooling.global_pool`` with ``pool_type``.
    Batch norm is optional. A flatten + dense layer maps the result to
    ``output_shape`` units.

    Args:
        activity: Input feature map. The 1x1 conv uses the default
            channels_last layout, so NHWC is expected.
        reuse: Forwarded to both ``tf.variable_scope`` calls.
        training: Sets conv trainability and batch-norm training mode.
        output_shape: Number of units in the final dense layer.
        dtype: Dtype forwarded to the optional batch norm.
        var_scope: Variable scope for the conv/pool/BN stage.
        pool_type: ``'select'`` for the centre pixel, otherwise passed to
            ``pooling.global_pool`` via ``aux={'pool_type': pool_type}``.
        renorm: Forwarded to the optional batch norm.
        use_bn: Whether to batch-normalize the pooled activity.
        features: Number of 1x1 conv output channels.
        return_fc: If True, also return the pre-pool conv activity.

    Returns:
        ``out_activity``, or ``(out_activity, prepool_activity)`` when
        ``return_fc`` is True.

    Raises:
        NotImplementedError: If ``pool_type == 'select'`` and the spatial
            dimensions of the conv output are not statically known.
    """
    with tf.variable_scope(var_scope, reuse=reuse):
        prepool_activity = tf.layers.conv2d(
            inputs=activity,
            filters=features,
            kernel_size=1,
            name='pre_readout_conv',
            strides=(1, 1),
            padding='same',
            activation=None,
            trainable=training,
            use_bias=False)
        if pool_type == 'select':
            # Gather the centre column of activity (spatial midpoint).
            act_shape = prepool_activity.get_shape().as_list()
            if act_shape[1] is None or act_shape[2] is None:
                raise NotImplementedError(
                    'pool_type="select" requires static spatial dimensions.')
            h = act_shape[1] // 2
            w = act_shape[2] // 2
            # Integer indexing already drops the H and W axes -> (N, C),
            # so no squeeze is needed.
            activity = prepool_activity[:, h, w, :]
        else:
            activity = pooling.global_pool(
                bottom=prepool_activity,
                name='pre_readout_pool',
                aux={'pool_type': pool_type})
        if use_bn:
            # Name is fixed (not derived from var_scope). It is left
            # unchanged so existing variable names are preserved.
            activity = normalization.batch_contrib(
                bottom=activity,
                renorm=renorm,
                dtype=dtype,
                name='readout_1_bn',
                training=training)
    # NOTE(review): this scope name is fixed regardless of var_scope, so two
    # calls with reuse=False may collide on 'readout_2' - confirm intended.
    with tf.variable_scope('readout_2', reuse=reuse):
        out_activity = tf.layers.flatten(
            activity,
            name='flat_readout')
        out_activity = tf.layers.dense(
            inputs=out_activity,
            units=output_shape)
    if return_fc:
        return out_activity, prepool_activity
    return out_activity
# Example #3
0
def build_model(data_tensor,
                reuse,
                training,
                output_shape,
                data_format='NHWC'):
    """Create the hgru from Learning long-range...

    Pipeline: a 7x7 ReLU conv embedding (16 filters), a gammanet (fGRU)
    recurrent layer with 2 timesteps, an output normalization selected by
    ``normalization_type``, and a 1x1 conv readout with ``output_shape``
    channels.

    Args:
        data_tensor: Input image tensor. When ``data_format == 'NCHW'`` it is
            transposed with perm (0, 3, 1, 2), so it is presumably supplied
            as NHWC - TODO confirm with callers.
        reuse: Forwarded to ``tf.variable_scope`` and the readout conv.
        training: Sets trainability, gammanet ``train`` and norm mode.
        output_shape: Number of readout channels. If a list, its last entry
            is used; if a dict, its ``'output'`` entry is used.
        data_format: ``'NCHW'`` selects channels_first; anything else
            selects channels_last.

    Returns:
        Tuple ``(activity, extra_activities)``. ``activity`` is the float32
        readout, transposed back to channels-last when channels_first was
        used. ``extra_activities`` maps ``'activity'`` to the normalized
        gammanet output.

    Raises:
        NotImplementedError: If ``normalization_type`` is not recognized.
    """
    if isinstance(output_shape, list):
        output_shape = output_shape[-1]
    elif isinstance(output_shape, dict):
        output_shape = output_shape['output']
    if data_format == 'NCHW':
        data_tensor = tf.transpose(data_tensor, (0, 3, 1, 2))
        long_data_format = 'channels_first'
    else:
        long_data_format = 'channels_last'

    with tf.variable_scope('cnn', reuse=reuse):
        normalization_type = 'instance_norm'

        # Add input
        in_emb = tf.layers.conv2d(inputs=data_tensor,
                                  filters=16,
                                  kernel_size=7,
                                  strides=(1, 1),
                                  padding='same',
                                  data_format=long_data_format,
                                  activation=tf.nn.relu,
                                  trainable=training,
                                  use_bias=True,
                                  name='l0')

        # Run fGRU
        hgru_kernels = OrderedDict()
        hgru_kernels['h1'] = [9, 9]  # height/width
        hgru_kernels['h2'] = [3, 3]
        hgru_kernels['fb1'] = [1, 1]
        hgru_features = OrderedDict()
        hgru_features['h1'] = [16, 16]  # Fan-in/fan-out, I and E (match fb1)
        hgru_features['h2'] = [48, 48]
        hgru_features['fb1'] = [16, 16]  # (match h1)
        intermediate_ff = [16, 48]  # Last feature must match h2
        intermediate_ks = [[3, 3], [3, 3]]
        intermediate_repeats = [4, 4]  # Repeat each FF this many times
        gammanet_layer = gammanet.GN(
            layer_name='fgru',
            x=in_emb,
            data_format=data_format,
            reuse=reuse,
            timesteps=2,
            strides=[1, 1, 1, 1],
            hgru_features=hgru_features,
            hgru_kernels=hgru_kernels,
            intermediate_ff=intermediate_ff,
            intermediate_ks=intermediate_ks,
            intermediate_repeats=intermediate_repeats,
            horizontal_kernel_initializer=lambda x: tf.initializers.orthogonal(
                gain=0.1),
            kernel_initializer=lambda x: tf.initializers.orthogonal(gain=0.1),
            padding='SAME',
            aux={
                'readout': 'fb',
                'attention': 'se',  # alternatives: 'gala'
                'attention_layers': 2,
                'saliency_filter': 7,
                'use_homunculus': True,
                'upsample_convs': True,
                'separable_convs': 4,  # Multiplier
                'separable_upsample': True,
                'td_cell_state': True,
                'td_gate': False,  # Add top-down activity to the in-gate
                'normalization_type': normalization_type,
                'residual': True,  # intermediate resid connections
                'while_loop': False,
                'skip': False,
                'symmetric_weights': True,
                'bilinear_init': True,
                'include_pooling': True,
                'time_homunculus': True,
                'squeeze_fb': False,  # Compress Inh-hat with a 1x1 conv
                'force_horizontal': False,
                'time_skips': False,
                'timestep_output': False,
                'excite_se': False,  # Add S/E in the excitation stage
            },
            pool_strides=[2, 2],
            pooling_kernel=[4, 4],  # 4, 4 helped but also try 3, 3
            up_kernel=[4, 4],
            train=training)
        h2 = gammanet_layer(in_emb)
        # Compare strings by value; `is` tests object identity.
        if normalization_type in ('batch_norm', 'ada_batch_norm'):
            # Both variants currently use the same batch-norm call.
            h2 = normalization.batch_contrib(bottom=h2,
                                             renorm=False,
                                             name='hgru_bn',
                                             dtype=h2.dtype,
                                             data_format=data_format,
                                             training=training)
        elif normalization_type == 'instance_norm':
            h2 = normalization.instance(bottom=h2,
                                        data_format=data_format,
                                        training=training)
        else:
            raise NotImplementedError(normalization_type)
    with tf.variable_scope('cv_readout', reuse=reuse):
        activity = tf.layers.conv2d(inputs=h2,
                                    filters=output_shape,
                                    kernel_size=(1, 1),
                                    padding='same',
                                    data_format=long_data_format,
                                    name='pre_readout_conv',
                                    use_bias=True,
                                    reuse=reuse)
    if long_data_format == 'channels_first':
        # Return the readout in channels-last order: NCHW -> NHWC.
        activity = tf.transpose(activity, (0, 2, 3, 1))
    extra_activities = {'activity': h2}
    if activity.dtype != tf.float32:
        activity = tf.cast(activity, tf.float32)
    return activity, extra_activities