Example #1
0
    def __call__(self, x, training, start_core_index=0, final_core_index=1):
        """Build the ResNet-50-style graph for `x` and return float32 logits.

        Stages follow the 3-4-6-3 bottleneck layout (256/512/1024/2048
        filters); every stage after the first opens with a stride-2 block.
        `params` is the project config object — assumed to carry
        `use_bfloat16`, `use_xla_sharding`, `num_cores_per_replica`,
        `dense_dropout_rate` and `num_classes` (TODO confirm against config).
        """
        if training:
            logging.info(f'Call {self.name} for `training`')
        else:
            logging.info(f'Call {self.name} for `eval`')

        params = self.params
        if params.use_bfloat16:
            ops.use_bfloat16()
        if params.use_xla_sharding:
            ops.set_xla_sharding(params.num_cores_per_replica)

        def _residual(inputs, num_out_filters, stride, name):
            # Thin wrapper so the stage loop below stays one line per block.
            return ops.resnet_block(inputs,
                                    params=params,
                                    num_out_filters=num_out_filters,
                                    stride=stride,
                                    training=training,
                                    name=name)

        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            with tf.variable_scope('stem'):
                x = ops.conv2d(x, 7, 64, 2)
                x = ops.batch_norm(x, params, training)
                x = ops.relu(x, leaky=0.)
                ops.log_tensor(x, True)

                x = ops.max_pool(x, 3, 2)
                ops.log_tensor(x, True)

            # (num_out_filters, repeats) per stage; block names run
            # block_1 .. block_16 across all stages, matching the original
            # hand-unrolled sequence exactly.
            stage_specs = [(256, 3), (512, 4), (1024, 6), (2048, 3)]
            block_id = 0
            for stage_idx, (filters, repeats) in enumerate(stage_specs):
                for rep in range(repeats):
                    block_id += 1
                    # First block of every stage except the first downsamples.
                    stride = 2 if (stage_idx > 0 and rep == 0) else 1
                    x = _residual(x, filters, stride, name=f'block_{block_id}')

            with tf.variable_scope('head'):
                x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
                ops.log_tensor(x, True)

                x = ops.dropout(x, params.dense_dropout_rate, training)
                x = ops.dense(x, params.num_classes)
                # Cast back to float32 so logits are full precision even
                # when the body ran in bfloat16.
                x = tf.cast(x, dtype=tf.float32, name='logits')
                ops.log_tensor(x, True)

        return x
Example #2
0
    def __call__(self, x, training, start_core_index=0, final_core_index=1):
        """Build the WideResNet-style graph for `x` and return float32 logits.

        Three groups of four `wrn_block`s each; the second and third groups
        open with a stride-2 block. Channel widths come from `s`, which has
        a special case when the widening factor `k == 135`.
        `params` is the project config object — assumed to carry
        `use_bfloat16`, `use_xla_sharding`, `num_cores_per_replica`,
        `dense_dropout_rate` and `num_classes` (TODO confirm against config).
        """
        if training:
            logging.info(f'Call {self.name} for `training`')
        else:
            logging.info(f'Call {self.name} for `eval`')

        params = self.params
        k = self.k
        if params.use_bfloat16:
            ops.use_bfloat16()
        if params.use_xla_sharding:
            ops.set_xla_sharding(params.num_cores_per_replica)

        # s[0] is the stem width; s[1:] are the three group widths.
        if k == 135:
            s = [16, 135, 135 * 2, 135 * 4]
        else:
            s = [16 * k, 16 * k, 32 * k, 64 * k]

        with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
            with tf.variable_scope('stem'):
                x = ops.conv2d(x, 3, s[0], 1)
                ops.log_tensor(x, True)

            # Block names run block_1 .. block_12 across the three groups,
            # matching the original hand-unrolled sequence exactly.
            block_id = 0
            for group_idx, width in enumerate(s[1:], start=1):
                for rep in range(4):
                    block_id += 1
                    # Groups 2 and 3 downsample on their first block.
                    stride = 2 if (group_idx > 1 and rep == 0) else 1
                    x = ops.wrn_block(x, params, width, stride, training,
                                      f'block_{block_id}')

            with tf.variable_scope('head'):
                x = ops.batch_norm(x, params, training)
                x = ops.relu(x)
                x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
                ops.log_tensor(x, True)

                x = ops.dropout(x, params.dense_dropout_rate, training)
                x = ops.dense(x, params.num_classes)
                # Cast back to float32 so logits are full precision even
                # when the body ran in bfloat16.
                x = tf.cast(x, dtype=tf.float32, name='logits')
                ops.log_tensor(x, True)

        return x