def __call__(self, x, training, start_core_index=0, final_core_index=1):
  if training:
    logging.info(f'Call {self.name} for `training`')
  else:
    logging.info(f'Call {self.name} for `eval`')

  params = self.params
  if params.use_bfloat16:
    ops.use_bfloat16()
  if params.use_xla_sharding:
    ops.set_xla_sharding(params.num_cores_per_replica)

  def _block_fn(inputs, num_out_filters, stride, name):
    return ops.resnet_block(inputs,
                            params=params,
                            num_out_filters=num_out_filters,
                            stride=stride,
                            training=training,
                            name=name)

  with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
    with tf.variable_scope('stem'):
      x = ops.conv2d(x, 7, 64, 2)
      x = ops.batch_norm(x, params, training)
      x = ops.relu(x, leaky=0.)
      ops.log_tensor(x, True)

      x = ops.max_pool(x, 3, 2)
      ops.log_tensor(x, True)

    x = _block_fn(x, 256, 1, name='block_1')
    x = _block_fn(x, 256, 1, name='block_2')
    x = _block_fn(x, 256, 1, name='block_3')

    x = _block_fn(x, 512, 2, name='block_4')
    x = _block_fn(x, 512, 1, name='block_5')
    x = _block_fn(x, 512, 1, name='block_6')
    x = _block_fn(x, 512, 1, name='block_7')

    x = _block_fn(x, 1024, 2, name='block_8')
    x = _block_fn(x, 1024, 1, name='block_9')
    x = _block_fn(x, 1024, 1, name='block_10')
    x = _block_fn(x, 1024, 1, name='block_11')
    x = _block_fn(x, 1024, 1, name='block_12')
    x = _block_fn(x, 1024, 1, name='block_13')

    x = _block_fn(x, 2048, 2, name='block_14')
    x = _block_fn(x, 2048, 1, name='block_15')
    x = _block_fn(x, 2048, 1, name='block_16')

    with tf.variable_scope('head'):
      x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
      ops.log_tensor(x, True)

      x = ops.dropout(x, params.dense_dropout_rate, training)
      x = ops.dense(x, params.num_classes)
      x = tf.cast(x, dtype=tf.float32, name='logits')
      ops.log_tensor(x, True)

  return x
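# Usage sketch (illustrative, not part of the original file). Assuming the
# enclosing class is constructed with a `params` object exposing the fields
# read above (`use_bfloat16`, `use_xla_sharding`, `num_cores_per_replica`,
# `dense_dropout_rate`, `num_classes`), a forward pass might look like:
#
#   model = Resnet50(params, name='resnet-50')                 # hypothetical class name
#   images = tf.placeholder(tf.float32, [None, 224, 224, 3])   # NHWC inputs
#   logits = model(images, training=True)                      # -> [batch, num_classes]
#
# The stem (stride-2 conv + stride-2 max pool) downsamples by 4x, and the
# stride-2 blocks at `block_4`, `block_8`, and `block_14` bring the total
# downsampling to 32x before global average pooling.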
def __call__(self, x, training, start_core_index=0, final_core_index=1):
  if training:
    logging.info(f'Call {self.name} for `training`')
  else:
    logging.info(f'Call {self.name} for `eval`')

  params = self.params
  k = self.k
  if params.use_bfloat16:
    ops.use_bfloat16()
  if params.use_xla_sharding:
    ops.set_xla_sharding(params.num_cores_per_replica)

  # Channel width of the stem and of the three block stages; `k` is the
  # widening factor, with `k == 135` selecting a fixed wide configuration.
  s = ([16, 135, 135 * 2, 135 * 4] if k == 135
       else [16 * k, 16 * k, 32 * k, 64 * k])

  with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
    with tf.variable_scope('stem'):
      x = ops.conv2d(x, 3, s[0], 1)
      ops.log_tensor(x, True)

    x = ops.wrn_block(x, params, s[1], 1, training, 'block_1')
    x = ops.wrn_block(x, params, s[1], 1, training, 'block_2')
    x = ops.wrn_block(x, params, s[1], 1, training, 'block_3')
    x = ops.wrn_block(x, params, s[1], 1, training, 'block_4')

    x = ops.wrn_block(x, params, s[2], 2, training, 'block_5')
    x = ops.wrn_block(x, params, s[2], 1, training, 'block_6')
    x = ops.wrn_block(x, params, s[2], 1, training, 'block_7')
    x = ops.wrn_block(x, params, s[2], 1, training, 'block_8')

    x = ops.wrn_block(x, params, s[3], 2, training, 'block_9')
    x = ops.wrn_block(x, params, s[3], 1, training, 'block_10')
    x = ops.wrn_block(x, params, s[3], 1, training, 'block_11')
    x = ops.wrn_block(x, params, s[3], 1, training, 'block_12')

    with tf.variable_scope('head'):
      x = ops.batch_norm(x, params, training)
      x = ops.relu(x)
      x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
      ops.log_tensor(x, True)

      x = ops.dropout(x, params.dense_dropout_rate, training)
      x = ops.dense(x, params.num_classes)
      x = tf.cast(x, dtype=tf.float32, name='logits')
      ops.log_tensor(x, True)

  return x
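# Usage sketch (illustrative, not part of the original file). With `k` as the
# widening factor, the method builds a WideResNet-style network: a 3x3 stem
# conv followed by three stages of four `wrn_block`s at widths s[1], s[2],
# s[3], where the second and third stages start with stride 2, so a 32x32
# input (e.g. CIFAR-10) is reduced to 8x8 before global average pooling.
# A hypothetical invocation:
#
#   model = Wrn28k(params, k=2, name='wrn-28-2')               # hypothetical class name
#   images = tf.placeholder(tf.float32, [None, 32, 32, 3])
#   logits = model(images, training=False)                     # -> [batch, num_classes]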
def __call__(self, x, training, start_core_index=0, final_core_index=1):
  if training:
    logging.info(f'Call {self.name} for `training`')
  else:
    logging.info(f'Call {self.name} for `eval`')

  params = self.params
  if params.use_bfloat16:
    ops.use_bfloat16()
  if params.use_xla_sharding:
    ops.set_xla_sharding(params.num_cores_per_replica)

  # Per-stage repeat counts after applying the depth multiplier, and the
  # index of each stage's first block (used to schedule stochastic depth).
  num_repeats = [self._rep(rp) for rp in [1, 2, 2, 3, 3, 4, 1]]
  num_blocks = sum(num_repeats)
  start = 0
  block_start = []
  for r in num_repeats:
    block_start.append(start)
    start += r

  s = 1 if self._is_small_net() else 2  # smaller strides for CIFAR

  def stack_fn(inputs, repeats, filter_size, num_out_filters, stride,
               expand_ratio, block_start):
    """Build a stack of multiple `mb_conv_block`."""
    for i in range(self._rep(repeats)):
      inputs = ops.mb_conv_block(
          x=inputs,
          params=params,
          filter_size=filter_size,
          num_out_filters=self._fil(num_out_filters),
          stride=stride if i == 0 else 1,  # only first block uses `stride`
          training=training,
          stochastic_depth_drop_rate=(
              params.stochastic_depth_drop_rate * float(block_start + i)
              / num_blocks),
          expand_ratio=expand_ratio,
          use_se=True,
          se_ratio=0.25,
          name=f'block_{block_start+i}')
    return inputs

  with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE):
    with tf.variable_scope('stem'):
      x = ops.conv2d(x, 3, self._fil(32), s)
      x = ops.batch_norm(x, params, training=training)
      x = ops.swish(x)

    x = stack_fn(x, 1, 3, 16, 1, 1, block_start[0])
    x = stack_fn(x, 2, 3, 24, s, 6, block_start[1])
    x = stack_fn(x, 2, 5, 40, s, 6, block_start[2])
    x = stack_fn(x, 3, 3, 80, 2, 6, block_start[3])
    x = stack_fn(x, 3, 5, 112, 1, 6, block_start[4])
    x = stack_fn(x, 4, 5, 192, 2, 6, block_start[5])
    x = stack_fn(x, 1, 3, 320, 1, 6, block_start[6])

    with tf.variable_scope('head'):
      x = ops.conv2d(x, 1, self._fil(1280), 1)
      x = ops.batch_norm(x, params, training)
      x = ops.swish(x)
      ops.log_tensor(x, True)

      x = tf.reduce_mean(x, axis=[1, 2], name='global_avg_pool')
      ops.log_tensor(x, True)

      x = ops.dropout(x, params.dense_dropout_rate, training)
      x = ops.dense(x, params.num_classes)
      x = tf.cast(x, dtype=tf.float32, name='logits')
      ops.log_tensor(x, True)

  return x
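# Usage sketch (illustrative, not part of the original file). `self._rep` and
# `self._fil` are assumed to apply depth and width multipliers, so the seven
# `stack_fn` calls follow the EfficientNet-B0 stage layout (repeats
# 1/2/2/3/3/4/1, filters 16/24/40/80/112/192/320) scaled to the chosen
# variant. The stochastic-depth drop rate grows linearly with the block
# index, from 0 at `block_0` up to nearly `params.stochastic_depth_drop_rate`
# at the final block. A hypothetical invocation:
#
#   model = EfficientNet(params, name='efficientnet-b0')       # hypothetical class name
#   images = tf.placeholder(tf.float32, [None, 224, 224, 3])
#   logits = model(images, training=True)                      # -> [batch, num_classes]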