# The snippets below use the deprecated (pre-Linen) flax.nn Module API:
# `nn` refers to flax.nn and `jnp` to jax.numpy.


def apply(self, x):
  # Three Dense -> leaky_relu -> BatchNorm blocks, then a softmax head.
  # `n_bin` (the number of output bins) is assumed to be defined in the
  # enclosing scope.
  net = nn.Dense(x, 500, name='fc1')
  net = nn.leaky_relu(net)
  net = nn.BatchNorm(net)
  net = nn.Dense(net, 500, name='fc2')
  net = nn.leaky_relu(net)
  net = nn.BatchNorm(net)
  net = nn.Dense(net, 500, name='fc3')
  net = nn.leaky_relu(net)
  net = nn.BatchNorm(net)
  return nn.softmax(nn.Dense(net, n_bin))
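# A minimal sketch of how a module like the one above is initialized and
# called under the pre-Linen flax.nn API. The module name `MLP`, the input
# shape (1, 784), and `n_bin = 10` are assumptions for illustration.
import jax
import jax.numpy as jnp
from flax import nn  # deprecated pre-Linen API

rng = jax.random.PRNGKey(0)
with nn.stateful() as init_state:  # collects BatchNorm running statistics
  _, params = MLP.init_by_shape(rng, [((1, 784), jnp.float32)])
model = nn.Model(MLP, params)
with nn.stateful(init_state) as new_state:
  probs = model(jnp.ones((1, 784)))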
def apply(self, x, blocks_per_group, channel_multiplier, num_outputs,
          dropout_rate=0.0, train=True):
  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 32 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = WideResnetGroup(x, blocks_per_group, 64 * channel_multiplier, (2, 2),
                      dropout_rate=dropout_rate, train=train)
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
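# `WideResnetGroup` is referenced but not shown. A minimal sketch under the
# same pre-Linen flax.nn API, assuming the standard WideResNet layout where
# `strides` (the fourth positional argument above) applies only to the first
# block of the group; `WideResnetBlock` is a hypothetical residual block:
class WideResnetGroup(nn.Module):

  def apply(self, x, blocks_per_group, channels, strides=(1, 1),
            dropout_rate=0.0, train=True):
    for i in range(blocks_per_group):
      x = WideResnetBlock(x, channels, strides if i == 0 else (1, 1),
                          dropout_rate=dropout_rate, train=train)
    return x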
def apply(self, x, num_filters=64, block_sizes=(3, 4, 6, 3), train=True,
          block=BottleneckBlock, small_inputs=False):
  if small_inputs:
    x = nn.Conv(x, num_filters, kernel_size=(3, 3), strides=(1, 1),
                bias=False, name="init_conv")
  else:
    x = nn.Conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2),
                bias=False, name="init_conv")
  x = nn.BatchNorm(x, use_running_average=not train, epsilon=1e-5,
                   name="init_bn")
  if not small_inputs:
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = block(x, num_filters * 2**i, strides=strides, train=train)
  return x
def apply(
    self,
    x,
    num_outputs,
    num_filters=64,
    block_sizes=[3, 4, 6, 3],  # pylint: disable=dangerous-default-value
    train=True):
  x = nn.Conv(x, num_filters, (7, 7), (2, 2), bias=False, name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, epsilon=1e-5,
                   name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = BottleneckBlock(x, num_filters * 2**i, strides=strides,
                          train=train)
  x = jnp.mean(x, axis=(1, 2))
  x_clf = nn.Dense(x, num_outputs, name='clf')
  # We return both the outputs from the dense layer *and* the features
  # that go into it.
  return x_clf, x
def apply(self, x, num_classes, num_filters=64, num_layers=50, train=True,
          dtype=jnp.float32):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]
  x = nn.Conv(x, num_filters, (7, 7), (2, 2), padding=[(3, 3), (3, 3)],
              bias=False, dtype=dtype, name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, dtype=dtype, name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(x, num_filters * 2**i, strides=strides, train=train,
                        dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes)
  x = nn.log_softmax(x)
  return x
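# `_block_size_options` is not shown above. A plausible definition, following
# the standard ResNet depth-to-block-count table, would be:
_block_size_options = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
    200: [3, 24, 36, 3],
}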
def conv2d(inputs: tf.Tensor,
           conv_filters: Optional[int],
           config: ModelConfig,
           kernel_size: Union[int, Tuple[int, int]] = (1, 1),
           strides: Tuple[int, int] = (1, 1),
           use_batch_norm: bool = True,
           use_bias: bool = False,
           activation: Any = None,
           depthwise: bool = False,
           train: bool = True,
           conv_name: Optional[str] = None,
           bn_name: Optional[str] = None) -> jnp.ndarray:
  """Convolutional layer with optional batch norm and activation.

  Args:
    inputs: Input data with dimensions (batch, spatial_dims..., features).
    conv_filters: Number of convolution filters.
    config: Configuration for the model.
    kernel_size: Size of the kernel, as a tuple of int.
    strides: Strides for the convolution, as a tuple of int.
    use_batch_norm: Whether batch norm should be applied to the output.
    use_bias: Whether we should add bias to the output of the first
      convolution.
    activation: Name of the activation function to use.
    depthwise: If true, will use depthwise convolutions.
    train: Whether the model should behave in training or inference mode.
    conv_name: Name to give to the convolution layer.
    bn_name: Name to give to the batch norm layer.

  Returns:
    The output of the convolutional layer.
  """
  conv_fn = DepthwiseConv if depthwise else flax.nn.Conv
  kernel_size = ((kernel_size, kernel_size)
                 if isinstance(kernel_size, int) else tuple(kernel_size))
  conv_name = conv_name if conv_name else 'conv2d'
  bn_name = bn_name if bn_name else 'batch_normalization'
  x = conv_fn(
      inputs,
      conv_filters,
      kernel_size,
      tuple(strides),
      padding='SAME',
      bias=use_bias,
      kernel_init=conv_kernel_init_fn,
      name=conv_name)
  if use_batch_norm:
    x = nn.BatchNorm(
        x,
        use_running_average=not train or FLAGS.from_pretrained_checkpoint,
        momentum=config.bn_momentum,
        epsilon=config.bn_epsilon,
        name=bn_name,
        axis_name='batch')
  if activation is not None:
    x = getattr(flax.nn.activation, activation.lower())(x)
  return x
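# A hypothetical call site for the helper above, assuming a ModelConfig with
# bn_momentum and bn_epsilon fields; the names and shapes are illustrative:
y = conv2d(x, 32, config, kernel_size=3, strides=(2, 2),
           activation='relu', train=True,
           conv_name='stem_conv', bn_name='stem_bn')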
def apply(self, x):
  b = x.shape[0]  # batch size, kept for the flatten below
  x = nn.Conv(x, features=128, kernel_size=(4,), padding='SAME')
  x = nn.BatchNorm(x)
  x = nn.leaky_relu(x)
  x = nn.avg_pool(x, window_shape=(2,), padding='SAME')
  x = nn.Conv(x, features=256, kernel_size=(4,), padding='SAME')
  x = nn.BatchNorm(x)
  x = nn.leaky_relu(x)
  x = nn.avg_pool(x, window_shape=(2,), padding='SAME')
  x = x.reshape(b, -1)  # flatten per example
  x = nn.Dense(x, features=128)
  x = nn.BatchNorm(x)
  x = nn.leaky_relu(x)
  # `n_bins` (the number of output bins) is assumed to be defined in the
  # enclosing scope.
  x = nn.Dense(x, features=n_bins)
  x = nn.softmax(x)
  return x
def apply(self, g, x, in_feats, hidden_feats, out_feats, num_layers,
          dropout):
  # Note: a fixed PRNG key makes the dropout masks deterministic across
  # calls; a per-call key would be preferable in practice.
  with nn.stochastic(jax.random.PRNGKey(0)):
    x = SAGEConv(g, x, in_feats, hidden_feats)
    for _ in range(num_layers - 2):
      x = SAGEConv(g, x, hidden_feats, hidden_feats)
      x = nn.BatchNorm(x)
      x = nn.dropout(x, rate=dropout)
    x = SAGEConv(g, x, hidden_feats, out_feats)
  return jax.nn.log_softmax(x, axis=-1)
def apply(self, x, num_classes, num_filters=64, num_layers=50, train=True,
          axis_name=None, axis_index_groups=None, dtype=jnp.float32,
          conv0_space_to_depth=False):
  if num_layers not in _block_size_options:
    raise ValueError('Please provide a valid number of layers')
  block_sizes = _block_size_options[num_layers]
  if conv0_space_to_depth:
    conv = SpaceToDepthConv.partial(block_size=(2, 2),
                                    padding=[(2, 1), (2, 1)])
  else:
    conv = nn.Conv.partial(padding=[(3, 3), (3, 3)])
  x = conv(x, num_filters, kernel_size=(7, 7), strides=(2, 2), bias=False,
           dtype=dtype, name='conv0')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='init_bn', axis_name=axis_name,
                   axis_index_groups=axis_index_groups, dtype=dtype)
  x = nn.relu(x)  # MLPerf-required
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = ResidualBlock(x, num_filters * 2**i, strides=strides, train=train,
                        axis_name=axis_name,
                        axis_index_groups=axis_index_groups, dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
               dtype=dtype)
  x = nn.log_softmax(x)
  return x
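# For reference, the space-to-depth input transform behind SpaceToDepthConv
# rearranges 2x2 spatial tiles into channels, (N, H, W, C) ->
# (N, H//2, W//2, 4*C), which makes the stride-2 7x7 conv0 denser and more
# TPU-friendly. A minimal standalone sketch (not the module used above):
def space_to_depth(x, block=2):
  n, h, w, c = x.shape
  x = x.reshape(n, h // block, block, w // block, block, c)
  x = x.transpose((0, 1, 3, 2, 4, 5))
  return x.reshape(n, h // block, w // block, block * block * c)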
def apply(self, x, num_classes, parameters, num_filters=64, train=True,
          axis_name=None, num_layers='34'):
  block_sizes = [3, 4, 6]
  data_format = 'channels_last'
  if ('conv0_space_to_depth' in parameters and
      parameters['conv0_space_to_depth']):
    # conv0 uses space-to-depth transform for TPU performance.
    x = func_conv0_space_to_depth(inputs=x, data_format=data_format,
                                  dtype=parameters['dtype'])
  else:
    x = conv2d_fixed_padding(
        inputs=x,
        filters=num_filters,
        kernel_size=7,
        strides=2,
        data_format=data_format,
        name='init_conv')
  replica_groups = _make_replica_groups(parameters)
  x = nn.BatchNorm(
      x,
      use_running_average=not train,
      momentum=0.9,
      epsilon=1e-5,
      name='init_bn',
      axis_name=axis_name,
      dtype=parameters['dtype'],
      axis_index_groups=replica_groups)
  x = nn.relu(x)
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(block_sizes):
    for j in range(block_size):
      strides = 2 if i == 1 and j == 0 else 1
      # Project the shortcut only on the first block of each later group.
      use_projection = i > 0 and j == 0
      x = ResidualBlock(
          x,
          num_filters * 2**i,
          parameters,
          strides=strides,
          train=train,
          axis_name=axis_name,
          use_projection=use_projection,
          data_format=data_format)
  if num_layers == '34':
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
                 dtype=jnp.float32)  # TODO(deveci): dtype=dtype
    x = nn.log_softmax(x)
  return x
def apply(self, x, num_classes, train=True, batch_stats=None, axis_name=None,
          dtype=jnp.float32):
  x = nn.BatchNorm(x, batch_stats=batch_stats,
                   use_running_average=not train, momentum=0.9, epsilon=1e-5,
                   name='init_bn', axis_name=axis_name, dtype=dtype)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_classes, kernel_init=nn.initializers.normal(),
               dtype=dtype)
  x = nn.log_softmax(x)
  return x
def apply(self, x, num_outputs, train=True):
  x = nn.Conv(x, self.NUM_FILTERS, (7, 7), (2, 2), bias=False,
              name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='init_bn')
  x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
  for i, block_size in enumerate(self.BLOCK_SIZES):
    for j in range(block_size):
      strides = (2, 2) if i > 0 and j == 0 else (1, 1)
      x = BottleneckBlock(x, self.NUM_FILTERS * 2**i, strides=strides,
                          groups=self.GROUPS,
                          base_width=self.WIDTH_PER_GROUP, train=train)
  x = jnp.mean(x, axis=(1, 2))
  x = nn.Dense(x, num_outputs, name='clf')
  return x
def apply(
    self,
    x: Array,
    upsampling_factor: int,
    max_num_features: int,
) -> Array:
  activation = nn.silu
  num_upsampling_layers = np.log2(upsampling_factor).astype(int)
  resizes = [1] + [2] * num_upsampling_layers
  feature_sizes = [
      max(max_num_features // 2**n, 2) for n in range(len(resizes))
  ]
  for fs, rs in zip(feature_sizes, resizes):
    x = Upsample2D(x, factor=rs)
    x = PeriodicSpaceConv(x, fs, kernel_size=(3, 3, 3))
    x = nn.BatchNorm(x)
    x = activation(x)
  x = PeriodicSpaceConv(x, features=2, kernel_size=(3, 3, 3))
  return x
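# Worked example of the schedule above: with upsampling_factor=4 and
# max_num_features=64, resizes = [1, 2, 2] and feature_sizes = [64, 32, 16],
# i.e. one non-upsampling conv block followed by two 2x upsampling blocks,
# before the final 2-feature projection.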
def apply(self, x, num_outputs, pyramid_alpha=200, pyramid_depth=272,
          train=True):
  assert (pyramid_depth - 2) % 9 == 0

  # Shake-drop hyper-params
  mask_prob = 0.5
  alpha_min, alpha_max = (-1.0, 1.0)
  beta_min, beta_max = (0.0, 1.0)

  # Bottleneck network size
  blocks_per_group = (pyramid_depth - 2) // 9
  # See Eqn 2 in https://arxiv.org/abs/1610.02915
  num_channels = 16  # N in https://arxiv.org/abs/1610.02915
  total_blocks = blocks_per_group * 3
  delta_channels = pyramid_alpha / total_blocks

  x = nn.Conv(x, 16, (3, 3), padding='SAME', name='init_conv')
  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='init_bn')

  layer_num = 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels), (1, 1), layer_mask_prob,
                            alpha_min, alpha_max, beta_min, beta_max,
                            train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  for block_i in range(blocks_per_group):
    num_channels += delta_channels
    layer_mask_prob = _calc_shakedrop_mask_prob(layer_num, total_blocks,
                                                mask_prob)
    x = BottleneckShakeDrop(x, int(num_channels),
                            ((2, 2) if block_i == 0 else (1, 1)),
                            layer_mask_prob, alpha_min, alpha_max, beta_min,
                            beta_max, train=train)
    layer_num += 1

  assert layer_num - 1 == total_blocks

  x = nn.BatchNorm(x, use_running_average=not train, momentum=0.9,
                   epsilon=1e-5, name='final_bn')
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs)
  return x
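# `_calc_shakedrop_mask_prob` is not shown above. A plausible definition,
# matching the linear decay rule for per-layer survival probability in the
# ShakeDrop paper (https://arxiv.org/abs/1802.02375), would be:
def _calc_shakedrop_mask_prob(curr_layer, total_layers, mask_prob):
  """Linearly decays the keep probability from 1.0 down to 1 - mask_prob."""
  return 1. - (float(curr_layer) / total_layers) * mask_prob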