def __init__(self):
    """Build a conv stack mapping a 1-channel input to a 1-channel logit map.

    Four (Conv2D -> BatchNorm2D -> gelu) stages, each with a stride-2 bias-free
    convolution, are followed by a bias-free 1x1 convolution down to one channel.
    """
    # (input channels, output channels, kernel size) of each stride-2 stage.
    stage_specs = ((1, 8, 5), (8, 16, 3), (16, 32, 3), (32, 64, 3))
    layers = []
    for cin, cout, kernel in stage_specs:
        layers += [Conv2D(cin, cout, strides=2, k=kernel, use_bias=False), BatchNorm2D(cout), gelu]
    # Final 1x1 projection to a single channel: the logits.
    layers.append(Conv2D(64, 1, strides=1, k=1, use_bias=False))
    self.model = Sequential(layers)
def __init__(
        self,
        in_channels: int,
        num_classes: int,
        blocks_per_group: Sequence[int],
        bottleneck: bool = True,
        channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
        group_strides: Sequence[int] = (1, 2, 2, 2),
        group_use_projection: Sequence[bool] = (True, True, True, True),
        normalization_fn: Callable[..., objax.Module] = objax.nn.BatchNorm2D,
        activation_fn: Callable[[JaxArray], JaxArray] = objax.functional.relu):
    """Creates ResNetV2 instance.

    The network is a 7x7/2 stem convolution, a padded 3x3/2 max pool, one
    ResNetV2BlockGroup per entry of ``blocks_per_group``, then a final
    normalization + activation, spatial average pooling and a linear classifier.

    Args:
        in_channels: number of channels in the input image.
        num_classes: number of output classes.
        blocks_per_group: number of blocks in each block group.
        bottleneck: if True then use bottleneck blocks.
        channels_per_group: number of output channels for each block group.
        group_strides: strides for each block group.
        group_use_projection: for each block group, the ``use_projection`` flag
            passed to that ResNetV2BlockGroup.
        normalization_fn: module which used as normalization function.
        activation_fn: activation function.

    ``channels_per_group``, ``group_strides`` and ``group_use_projection`` must
    each have exactly one entry per block group, i.e. ``len(blocks_per_group)``.
    """
    num_groups = len(blocks_per_group)
    assert len(channels_per_group) == num_groups, 'channels_per_group needs one entry per block group'
    assert len(group_strides) == num_groups, 'group_strides needs one entry per block group'
    assert len(group_use_projection) == num_groups, 'group_use_projection needs one entry per block group'

    stem_channels = 64
    ops = [
        # conv_args receives the same output width as the Conv2D it configures.
        Conv2D(in_channels, stem_channels, k=7, strides=2, **conv_args(7, stem_channels, (3, 3))),
        # Pad axes 2 and 3 (the axes averaged by the final pooling, presumably H/W in
        # objax's NCHW layout) by 1 on each side before the VALID max pool.
        functools.partial(jn.pad, pad_width=((0, 0), (0, 0), (1, 1), (1, 1))),
        functools.partial(objax.functional.max_pool_2d, size=3, strides=2, padding=ConvPadding.VALID),
    ]

    nout = stem_channels
    for i, num_blocks in enumerate(blocks_per_group):
        # Each group consumes the previous group's output width.
        nin, nout = nout, channels_per_group[i]
        ops.append(
            ResNetV2BlockGroup(nin, nout,
                               num_blocks=num_blocks,
                               stride=group_strides[i],
                               bottleneck=bottleneck,
                               use_projection=group_use_projection[i],
                               normalization_fn=normalization_fn,
                               activation_fn=activation_fn))

    ops.extend([
        normalization_fn(nout),
        activation_fn,
        # Global average pooling over axes 2 and 3.
        lambda x: x.mean((2, 3)),
        objax.nn.Linear(nout, num_classes, w_init=objax.nn.init.xavier_normal),
    ])
    super().__init__(ops)
def __init__(self, nin, nclass, scales, filters, filters_max):
    """Build a tanh-activated convolutional classifier.

    Args:
        nin: number of input channels.
        nclass: number of output classes (channels of the last convolution).
        scales: number of stages; each stage is two 3x3 convs followed by a
            2x2 average pool with stride 2.
        filters: channel width at stage 0; it doubles each stage.
        filters_max: upper bound on the channel width.
    """

    def activation(x):
        """Return tanh as activation function.

        Tanh has better utility for differentially private SGD https://arxiv.org/abs/2007.14191 .
        """
        return tanh(x)

    def width(level):
        # filters * 2**level, capped at filters_max.
        return min(filters_max, filters << level)

    layers = [Conv2D(nin, width(0), 3), activation]
    for level in range(scales):
        cur, nxt = width(level), width(level + 1)
        layers += [
            Conv2D(cur, cur, 3), activation,
            Conv2D(cur, nxt, 3), activation,
            partial(average_pool_2d, size=2, strides=2),
        ]
    # Per-class score map, then averaged over axes 2 and 3.
    layers += [Conv2D(width(scales), nclass, 3), lambda x: x.mean((2, 3))]
    super().__init__(layers)
def __init__(self):
    """Build the encoder stack, decoder stack and 1x1 logits convolution.

    Input has 4 channels (RGB_t1 plus dither_t0). Encoders widen to
    32/64/128/128 channels (7x7 kernel for the first, 3x3 afterwards);
    decoders narrow to 128/64/32/16; a 1x1 conv produces one logit channel.
    """
    channels = 4  # 3 from RGB_t1 + 1 from dither_t0

    self.encoders = objax.ModuleList()
    for depth, width in enumerate((32, 64, 128, 128)):
        kernel = 7 if depth == 0 else 3
        self.encoders.append(EncoderBlock(channels, width, kernel))
        channels = width

    self.decoders = objax.ModuleList()
    for width in (128, 64, 32, 16):
        self.decoders.append(DecoderBlock(channels, width))
        channels = width

    self.logits = Conv2D(channels, nout=1, strides=1, k=1, w_init=xavier_normal)
def __init__(self,
             nin: int,
             nout: int,
             stride: Union[int, Sequence[int]],
             use_projection: bool,
             bottleneck: bool,
             normalization_fn: Callable[..., objax.Module] = objax.nn.BatchNorm2D,
             activation_fn: Callable[[JaxArray], JaxArray] = objax.functional.relu):
    """Creates ResNetV2Block instance.

    Args:
        nin: number of input filters.
        nout: number of output filters.
        stride: stride for the 3x3 convolution that carries it and for the
            projection convolution.
        use_projection: if True then include a 1x1 projection convolution.
        bottleneck: if True then build norm/conv pairs 1x1 -> 3x3 -> 1x1 with
            an inner width of ``nout // 4``; otherwise two 3x3 pairs.
        normalization_fn: module which used as normalization function.
        activation_fn: activation function.

    ``self.layers`` holds (normalization, convolution) pairs in order; each
    normalization is sized to its paired convolution's input channels.
    """
    self.use_projection = use_projection
    self.activation_fn = activation_fn
    self.stride = stride
    if use_projection:
        self.proj_conv = Conv2D(nin, nout, 1, strides=stride, **conv_args(1, nout))

    if not bottleneck:
        # Basic block: two 3x3 convs, the second one strided.
        self.norm_0 = normalization_fn(nin)
        self.conv_0 = Conv2D(nin, nout, 3, strides=1, **conv_args(3, nout, (1, 1)))
        self.norm_1 = normalization_fn(nout)
        self.conv_1 = Conv2D(nout, nout, 3, strides=stride, **conv_args(3, nout, (1, 1)))
        self.layers = ((self.norm_0, self.conv_0), (self.norm_1, self.conv_1))
        return

    # Bottleneck block: 1x1 reduce, strided 3x3, 1x1 expand.
    inner = nout // 4
    self.norm_0 = normalization_fn(nin)
    self.conv_0 = Conv2D(nin, inner, 1, strides=1, **conv_args(1, inner))
    self.norm_1 = normalization_fn(inner)
    self.conv_1 = Conv2D(inner, inner, 3, strides=stride, **conv_args(3, inner, (1, 1)))
    self.norm_2 = normalization_fn(inner)
    self.conv_2 = Conv2D(inner, nout, 1, strides=1, **conv_args(1, nout))
    self.layers = ((self.norm_0, self.conv_0),
                   (self.norm_1, self.conv_1),
                   (self.norm_2, self.conv_2))
def __init__(self, nin, nout):
    """Create the block's stride-1 convolutions.

    Three 3x3 convs (shortcut, conv1, conv2) plus a 1x1 ``skip_conv`` that maps
    ``2 * nout`` channels back to ``nout``. Presumably the forward pass feeds it a
    concatenation of two ``nout``-channel tensors; confirm against the caller.
    """

    def unit_stride_conv(cin, cout, kernel):
        return Conv2D(cin, cout, strides=1, k=kernel)

    self.shortcut = unit_stride_conv(nin, nout, 3)
    self.conv1 = unit_stride_conv(nin, nout, 3)
    self.conv2 = unit_stride_conv(nout, nout, 3)
    self.skip_conv = unit_stride_conv(2 * nout, nout, 1)
def __init__(self, nin, nout, k):
    """Create the block's convolutions.

    Args:
        nin: number of input channels.
        nout: number of output channels.
        k: kernel size of ``conv1`` (``shortcut`` and ``conv2`` always use 3).

    ``shortcut`` and ``conv1`` are both stride 2; ``conv2`` is stride 1.
    """
    downsample = 2
    self.shortcut = Conv2D(nin, nout, k=3, strides=downsample)
    self.conv1 = Conv2D(nin, nout, k=k, strides=downsample)
    self.conv2 = Conv2D(nout, nout, k=3, strides=1)