Пример #1
0
def bottleneck_resnet_block(grid_structured,
                            hp,
                            kernel_size=3,
                            out_multiple=4):
    """Apply a bottleneck ResNet block with a projected residual connection.

    The input is routed through two branches:

    * shortcut branch: 1x1 convolution to ``out_multiple * hp.hidden_size``
      filters, then batch norm with ``relu=False``;
    * main branch: 1x1 conv -> batch norm + ReLU, ``kernel_size`` conv ->
      batch norm + ReLU, 1x1 conv to the widened filter count -> batch norm
      with ``relu=False, init_zero=False``.

    The two branch outputs are summed and passed through ``tf.nn.relu``.
    All convolutions use stride 1 and the hard-coded "channels_last" format.

    Args:
      grid_structured: input tensor (expected in "channels_last" layout,
        since that format is passed to every layer below).
      hp: hyperparameters. ``hp.mode`` is compared with
        ``tf.estimator.ModeKeys.TRAIN`` to decide training mode, and
        ``hp.hidden_size`` is the bottleneck width.
      kernel_size: spatial kernel size of the middle convolution.
      out_multiple: output channel count as a multiple of ``hp.hidden_size``.

    Returns:
      ``tf.nn.relu(main_branch + shortcut)``.
    """
    fmt = "channels_last"
    training = hp.mode == tf.estimator.ModeKeys.TRAIN
    width = hp.hidden_size
    widened = out_multiple * width

    def conv_bn(x, filters, ksize, **bn_overrides):
        # Stride-1 conv followed by batch norm. Only the batch-norm keyword
        # arguments a given stage sets explicitly are forwarded.
        x = resnet.conv2d_fixed_padding(inputs=x,
                                        filters=filters,
                                        kernel_size=ksize,
                                        strides=1,
                                        data_format=fmt,
                                        is_training=training)
        return resnet.batch_norm_relu(x, training, data_format=fmt,
                                      **bn_overrides)

    # Keep the shortcut ahead of the main branch: auto-generated layer /
    # variable names may depend on creation order (checkpoint compatibility).
    shortcut = conv_bn(grid_structured, widened, 1, relu=False)

    # (filters, kernel size, batch-norm overrides) for each main-branch stage.
    stages = (
        (width, 1, {}),
        (width, kernel_size, {}),
        (widened, 1, {"relu": False, "init_zero": False}),
    )
    net = grid_structured
    for n_filters, ksize, bn_overrides in stages:
        net = conv_bn(net, n_filters, ksize, **bn_overrides)
    return tf.nn.relu(net + shortcut)
Пример #2
0
 def projection_shortcut(inputs):
     """Project the identity branch with a 1x1 convolution plus batch norm.

     Uses ``filters_out``, ``data_format`` and ``is_training`` from the
     enclosing scope, which is not part of this snippet. Batch norm is applied
     with ``relu=False``, so no activation follows the projection.
     """
     projected = resnet.conv2d_fixed_padding(inputs=inputs,
                                             filters=filters_out,
                                             kernel_size=1,
                                             strides=1,
                                             data_format=data_format,
                                             is_training=is_training)
     normalized = resnet.batch_norm_relu(projected,
                                         is_training,
                                         data_format=data_format,
                                         relu=False)
     return normalized
Пример #3
0
 def projection_shortcut(inputs):
     """Project the identity branch to ``filters_out`` channels.

     Applies a 1x1 convolution with stride ``strides``, forwarding the
     ``use_td``, ``targeting_rate`` and ``keep_prob`` settings unchanged.
     It then applies batch norm with ``relu=False``. All of these names come
     from the enclosing scope, which is outside this snippet.
     """
     conv_kwargs = dict(filters=filters_out,
                        kernel_size=1,
                        strides=strides,
                        data_format=data_format,
                        use_td=use_td,
                        targeting_rate=targeting_rate,
                        keep_prob=keep_prob,
                        is_training=is_training)
     projected = resnet.conv2d_fixed_padding(inputs=inputs, **conv_kwargs)
     return resnet.batch_norm_relu(projected, is_training,
                                   relu=False, data_format=data_format)