def __call__(self, x, training, start_core_index=0, final_core_index=1):
  """Builds a bottleneck-ResNet forward graph over `x`; returns float32 logits.

  Args:
    x: input image tensor (presumably NHWC -- TODO confirm against `ops.conv2d`).
    training: bool; selects train-mode behavior for batch norm and dropout.
    start_core_index: unused in this body; kept for interface compatibility.
    final_core_index: unused in this body; kept for interface compatibility.

  Returns:
    Float32 logits tensor named 'logits' with `params.num_classes` units.
  """
  logging.info(f'Call {self.name} for `training`' if training
               else f'Call {self.name} for `eval`')
  params = self.params
  if params.use_bfloat16:
    ops.use_bfloat16()
  if params.use_xla_sharding:
    ops.set_xla_sharding(params.num_cores_per_replica)

  def _residual(inputs, num_out_filters, stride, name):
    # Thin wrapper so every stage shares the same `params` and `training`.
    return ops.resnet_block(
        inputs,
        params=params,
        num_out_filters=num_out_filters,
        stride=stride,
        training=training,
        name=name)

  with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
    with tf.variable_scope('stem'):
      x = ops.conv2d(x, 7, 64, 2)
      x = ops.batch_norm(x, params, training)
      x = ops.relu(x, leaky=0.)
      ops.log_tensor(x, True)
      x = ops.max_pool(x, 3, 2)
      ops.log_tensor(x, True)

    # (filters, num_blocks) per stage. The first block of every stage after
    # the first downsamples with stride 2; this enumerates exactly the same
    # sixteen blocks (block_1 .. block_16) the hand-unrolled version built.
    block_id = 0
    for stage, (filters, num_blocks) in enumerate(
        [(256, 3), (512, 4), (1024, 6), (2048, 3)]):
      for i in range(num_blocks):
        block_id += 1
        stride = 2 if stage > 0 and i == 0 else 1
        x = _residual(x, filters, stride, name=f'block_{block_id}')

    with tf.variable_scope('head'):
      x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
      ops.log_tensor(x, True)
      x = ops.dropout(x, params.dense_dropout_rate, training)
      x = ops.dense(x, params.num_classes)
      x = tf.cast(x, dtype=tf.float32, name='logits')
      ops.log_tensor(x, True)
  return x
def __call__(self, x, training, start_core_index=0, final_core_index=1):
  """Builds a Wide-ResNet forward graph (3 groups x 4 blocks); returns logits.

  Args:
    x: input image tensor (presumably NHWC -- TODO confirm against `ops.conv2d`).
    training: bool; selects train-mode behavior for batch norm and dropout.
    start_core_index: unused in this body; kept for interface compatibility.
    final_core_index: unused in this body; kept for interface compatibility.

  Returns:
    Float32 logits tensor named 'logits' with `params.num_classes` units.
  """
  logging.info(f'Call {self.name} for `training`' if training
               else f'Call {self.name} for `eval`')
  params = self.params
  k = self.k
  if params.use_bfloat16:
    ops.use_bfloat16()
  if params.use_xla_sharding:
    ops.set_xla_sharding(params.num_cores_per_replica)

  # Channel widths: [stem, group_1, group_2, group_3]. k == 135 is a special
  # hard-coded configuration. NOTE(review): standard WRN stems use 16 filters
  # regardless of k; the `16 * k` stem below is preserved as-is -- confirm it
  # is intentional.
  if k == 135:
    widths = [16, 135, 135 * 2, 135 * 4]
  else:
    widths = [16 * k, 16 * k, 32 * k, 64 * k]

  with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
    with tf.variable_scope('stem'):
      x = ops.conv2d(x, 3, widths[0], 1)
      ops.log_tensor(x, True)

    # Three groups of four blocks each; the first block of groups 2 and 3
    # downsamples with stride 2. This enumerates exactly the same twelve
    # blocks (block_1 .. block_12) the hand-unrolled version built.
    block_id = 0
    for group, width in enumerate(widths[1:]):
      for i in range(4):
        block_id += 1
        stride = 2 if group > 0 and i == 0 else 1
        x = ops.wrn_block(x, params, width, stride, training,
                          f'block_{block_id}')

    with tf.variable_scope('head'):
      x = ops.batch_norm(x, params, training)
      x = ops.relu(x)
      x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
      ops.log_tensor(x, True)
      x = ops.dropout(x, params.dense_dropout_rate, training)
      x = ops.dense(x, params.num_classes)
      x = tf.cast(x, dtype=tf.float32, name='logits')
      ops.log_tensor(x, True)
  return x