Ejemplo n.º 1
0
    def _output_from_pre_logits(self, contact_pre_logits, features_forward,
                                layers_forward, output_dimension, data_format,
                                crop_x, crop_y, use_on_the_fly_stats):
        """Given pre-logits, compute the final distogram/contact activations."""
        config_2d_deep = self._network_2d_deep
        if self._reshape_layer:
            in_channels = config_2d_deep.num_filters
            concat_features = [contact_pre_logits]
            if features_forward is not None:
                concat_features.append(features_forward)
                in_channels += self._features_forward
            if layers_forward is not None:
                concat_features.append(layers_forward)
                in_channels += 2 * config_2d_deep.num_filters
            if len(concat_features) > 1:
                contact_pre_logits = tf.concat(
                    concat_features, 1 if data_format == 'NCHW' else 3)

            contact_logits = two_dim_convnet.make_conv_layer(
                contact_pre_logits,
                in_channels=in_channels,
                out_channels=output_dimension,
                layer_name='output_reshape_1x1h',
                filter_size=1,
                filter_size_2=1,
                non_linearity=False,
                batch_norm=config_2d_deep.use_batch_norm,
                is_training=use_on_the_fly_stats,
                data_format=data_format
            )  #sets up the neural net layer that will take into account the logits of the contacts as the item of difference
        else:
            contact_logits = contact_pre_logits

        if data_format == 'NCHW':
            contact_logits = tf.transpose(contact_logits, perm=[0, 2, 3, 1])

        if self._position_specific_bias_size:
            # Make 2D pos-specific biases: NHWC.
            biases = build_crops_biases(self._position_specific_bias_size,
                                        self._position_specific_bias,
                                        crop_x,
                                        crop_y,
                                        back_prop=True)
            contact_logits += biases

        # Will be NHWC.
        return contact_logits
Ejemplo n.º 2
0
def make_sep_res_layer(input_node,
                       in_channels,
                       out_channels,
                       layer_name,
                       filter_size,
                       filter_size_2=None,
                       batch_norm=False,
                       is_training=True,
                       divide_channels_by=2,
                       atrou_rate=1,
                       channel_multiplier=0,
                       data_format='NHWC',
                       stddev=0.01,
                       dropout_keep_prob=1.0):
    """A separable (bottleneck) resnet block.

    Inside `tf.name_scope(layer_name)` this computes:
      [batch norm] -> ELU
      -> 1x1 conv down to `in_channels // divide_channels_by` channels
      -> filter_size x filter_size_2 conv at that width (a regular conv when
         `channel_multiplier == 0`, otherwise a separable conv)
      -> 1x1 conv to `out_channels` (no non-linearity, no batch norm)
      -> [dropout when dropout_keep_prob < 1.0]
      -> + input_node (residual connection).

    Args:
      input_node: Block input; also added to the block output.
      in_channels: Channel count of `input_node`.
      out_channels: Channel count of the last 1x1 conv. It must match the
        channels of `input_node` for the residual addition to be valid.
      layer_name: Name scope and prefix for every sub-layer name.
      filter_size: First spatial size of the middle conv.
      filter_size_2: Second spatial size of the middle conv; defaults to
        `filter_size` when None.
      batch_norm: Whether to batch-normalise the input and the first two
        convs.
      is_training: Forwarded to batch norm and the conv helpers.
      divide_channels_by: Bottleneck reduction factor.
      atrou_rate: Dilation rate of the middle conv.
      channel_multiplier: 0 for a regular middle conv, otherwise the
        depthwise multiplier of a separable middle conv.
      data_format: Tensor layout forwarded to every sub-layer.
      stddev: Weight-initialisation stddev forwarded to the convs.
      dropout_keep_prob: Keep probability; dropout is skipped at 1.0.

    Returns:
      The residual block output.
    """
    # Bottleneck width, computed once. Floor division keeps it an integer;
    # `/` would make it a float under Python 3 (64 / 2 == 32.0) and give a
    # fractional width for odd channel counts.
    mid_channels = in_channels // divide_channels_by

    with tf.name_scope(layer_name):
        input_times_almost_1 = input_node
        h_conv = input_times_almost_1

        if batch_norm:
            h_conv = two_dim_convnet.batch_norm_layer(h_conv,
                                                      layer_name=layer_name,
                                                      is_training=is_training,
                                                      data_format=data_format)

        h_conv = tf.nn.elu(h_conv)

        if filter_size_2 is None:
            filter_size_2 = filter_size

        # 1x1 down to the bottleneck width.
        h_conv = two_dim_convnet.make_conv_layer(
            h_conv,
            in_channels=in_channels,
            out_channels=mid_channels,
            layer_name=layer_name + '_1x1h',
            filter_size=1,
            filter_size_2=1,
            non_linearity=True,
            batch_norm=batch_norm,
            is_training=is_training,
            data_format=data_format,
            stddev=stddev)

        # Spatial conv at the bottleneck width.
        if channel_multiplier == 0:
            h_conv = two_dim_convnet.make_conv_layer(
                h_conv,
                in_channels=mid_channels,
                out_channels=mid_channels,
                layer_name=layer_name + '_%dx%dh' %
                (filter_size, filter_size_2),
                filter_size=filter_size,
                filter_size_2=filter_size_2,
                non_linearity=True,
                batch_norm=batch_norm,
                is_training=is_training,
                atrou_rate=atrou_rate,
                data_format=data_format,
                stddev=stddev)
        else:
            # Separable convolution for the spatial conv.
            h_conv = two_dim_convnet.make_conv_sep2d_layer(
                h_conv,
                in_channels=mid_channels,
                channel_multiplier=channel_multiplier,
                out_channels=mid_channels,
                layer_name=layer_name + '_sep%dx%dh' %
                (filter_size, filter_size_2),
                filter_size=filter_size,
                filter_size_2=filter_size_2,
                batch_norm=batch_norm,
                is_training=is_training,
                atrou_rate=atrou_rate,
                data_format=data_format,
                stddev=stddev)

        # 1x1 back to `out_channels`, linear, so the residual sum is not
        # clipped by an activation.
        h_conv = two_dim_convnet.make_conv_layer(
            h_conv,
            in_channels=mid_channels,
            out_channels=out_channels,
            layer_name=layer_name + '_1x1',
            filter_size=1,
            filter_size_2=1,
            non_linearity=False,
            batch_norm=False,
            is_training=is_training,
            data_format=data_format,
            stddev=stddev)

        # NOTE(review): dropout is applied whenever dropout_keep_prob < 1.0,
        # independent of is_training -- confirm callers pass 1.0 at eval.
        if dropout_keep_prob < 1.0:
            logging.info('dropout keep prob %f', dropout_keep_prob)
            h_conv = tf.nn.dropout(h_conv, keep_prob=dropout_keep_prob)

        return h_conv + input_times_almost_1
Ejemplo n.º 3
0
def make_two_dim_resnet(input_node,
                        num_residues=50,
                        num_features=40,
                        num_predictions=1,
                        num_channels=32,
                        num_layers=2,
                        filter_size=3,
                        filter_size_2=None,
                        final_non_linearity=False,
                        name_prefix='',
                        fancy=True,
                        batch_norm=False,
                        is_training=False,
                        atrou_rates=None,
                        channel_multiplier=0,
                        divide_channels_by=2,
                        resize_features_with_1x1=False,
                        data_format='NHWC',
                        stddev=0.01,
                        dropout_keep_prob=1.0):
    """Build a tower of 2D residual layers.

    The first and last layers are plain convolutions named
    `<name_prefix>conv<k>`. The first maps `num_features` channels to
    `num_channels`, and the last maps `num_channels` to `num_predictions`.
    Every layer in between is a `make_sep_res_layer` block named
    `<name_prefix>res<k>` that keeps `num_channels` channels. Dilation
    rates cycle through `atrou_rates` by layer index.

    Args:
      input_node: Input tensor to the tower.
      num_residues: Unused; accepted for interface compatibility.
      num_features: Input channel count.
      num_predictions: Output channel count.
      num_channels: Width of the hidden layers.
      num_layers: Total number of layers; 0 returns `input_node` unchanged.
      filter_size: Spatial filter size (the plain convs use 1 instead when
        `resize_features_with_1x1` is set).
      filter_size_2: Second spatial filter size, forwarded to every layer.
      final_non_linearity: Non-linearity flag for the last layer. All other
        plain convs use a non-linearity.
      name_prefix: Prefix for every layer name.
      fancy: Must be True; the non-fancy variant has been removed.
      batch_norm, is_training, channel_multiplier, divide_channels_by,
      dropout_keep_prob: Forwarded to the residual blocks only.
      atrou_rates: Sequence of dilation rates; defaults to [1].
      resize_features_with_1x1: Use 1x1 filters for the first/last convs.
      data_format, stddev: Forwarded to every layer.

    Returns:
      The output of the last layer.

    Raises:
      ValueError: If `fancy` is False.
    """
    del num_residues  # Unused.

    rates = [1] if atrou_rates is None else atrou_rates
    if not fancy:
        raise ValueError('non fancy deprecated')

    logging.info('atrou rates %s', rates)
    logging.info('name prefix %s', name_prefix)

    last_index = num_layers - 1
    layer = input_node
    for index in range(num_layers):
        rate = rates[index % len(rates)]
        is_first = index == 0
        is_last = index == last_index

        if is_first or is_last:
            layer = two_dim_convnet.make_conv_layer(
                input_node=layer,
                in_channels=num_features if is_first else num_channels,
                out_channels=num_predictions if is_last else num_channels,
                layer_name=name_prefix + 'conv%d' % (index + 1),
                filter_size=1 if resize_features_with_1x1 else filter_size,
                filter_size_2=filter_size_2,
                non_linearity=final_non_linearity if is_last else True,
                atrou_rate=rate,
                data_format=data_format,
                stddev=stddev)
            continue

        # Interior layers: width-preserving residual blocks.
        layer = make_sep_res_layer(
            input_node=layer,
            in_channels=num_channels,
            out_channels=num_channels,
            layer_name=name_prefix + 'res%d' % (index + 1),
            filter_size=filter_size,
            filter_size_2=filter_size_2,
            batch_norm=batch_norm,
            is_training=is_training,
            atrou_rate=rate,
            channel_multiplier=channel_multiplier,
            divide_channels_by=divide_channels_by,
            data_format=data_format,
            stddev=stddev,
            dropout_keep_prob=dropout_keep_prob)

    return layer