Example 1
    def __init__(self):
        # Four stride-2 convolutions downsample the input 16x; the final 1x1
        # convolution maps 64 channels to a single-channel logit map.
        self.model = Sequential([
            Conv2D(1, 8, strides=2, k=5, use_bias=False),
            BatchNorm2D(8), gelu,
            Conv2D(8, 16, strides=2, k=3, use_bias=False),
            BatchNorm2D(16), gelu,
            Conv2D(16, 32, strides=2, k=3, use_bias=False),
            BatchNorm2D(32), gelu,
            Conv2D(32, 64, strides=2, k=3, use_bias=False),
            BatchNorm2D(64), gelu,
            Conv2D(64, 1, strides=1, k=1, use_bias=False)
        ])  # logits
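
This fragment is only a constructor; a minimal runnable sketch around it follows. The class name, the `__call__` method, and the input shape are assumptions, and the stack is abbreviated to two strided stages:

    import jax.numpy as jn
    import objax
    from jax.nn import gelu
    from objax.nn import BatchNorm2D, Conv2D, Sequential

    class SmallConvNet(objax.Module):  # hypothetical name; the original class is not shown
        def __init__(self):
            self.model = Sequential([
                Conv2D(1, 8, strides=2, k=5, use_bias=False),
                BatchNorm2D(8), gelu,
                Conv2D(8, 16, strides=2, k=3, use_bias=False),
                BatchNorm2D(16), gelu,
                Conv2D(16, 1, strides=1, k=1, use_bias=False)  # logits
            ])

        def __call__(self, x, training):
            # Sequential forwards the `training` flag only to layers that
            # accept it (here, the BatchNorm2D layers).
            return self.model(x, training=training)

    net = SmallConvNet()
    x = jn.zeros((4, 1, 64, 64))         # NCHW, single-channel input
    print(net(x, training=False).shape)  # (4, 1, 16, 16): each stride-2 conv halves H and W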
Example 2
    def __init__(
            self,
            in_channels: int,
            num_classes: int,
            blocks_per_group: Sequence[int],
            bottleneck: bool = True,
            channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
            group_strides: Sequence[int] = (1, 2, 2, 2),
            group_use_projection: Sequence[bool] = (True, True, True, True),
            normalization_fn: Callable[...,
                                       objax.Module] = objax.nn.BatchNorm2D,
            activation_fn: Callable[[JaxArray],
                                    JaxArray] = objax.functional.relu):
        """Creates ResNetV2 instance.

        Args:
            in_channels: number of channels in the input image.
            num_classes: number of output classes.
            blocks_per_group: number of blocks in each block group.
            bottleneck: if True then use bottleneck blocks.
            channels_per_group: number of output channels for each block group.
            group_strides: strides for each block group.
            group_use_projection: for each block group, whether to include a
                projection convolution on the shortcut.
            normalization_fn: constructor of the module used for normalization.
            activation_fn: activation function.
        """
        assert len(channels_per_group) == len(blocks_per_group)
        assert len(group_strides) == len(blocks_per_group)
        assert len(group_use_projection) == len(blocks_per_group)
        nin = in_channels
        nout = 64
        ops = [
            Conv2D(nin, nout, k=7, strides=2, **conv_args(7, 64, (3, 3))),
            functools.partial(jn.pad,
                              pad_width=((0, 0), (0, 0), (1, 1), (1, 1))),
            functools.partial(objax.functional.max_pool_2d,
                              size=3,
                              strides=2,
                              padding=ConvPadding.VALID)
        ]
        for i in range(len(blocks_per_group)):
            nin = nout
            nout = channels_per_group[i]
            ops.append(
                ResNetV2BlockGroup(nin,
                                   nout,
                                   num_blocks=blocks_per_group[i],
                                   stride=group_strides[i],
                                   bottleneck=bottleneck,
                                   use_projection=group_use_projection[i],
                                   normalization_fn=normalization_fn,
                                   activation_fn=activation_fn))

        ops.extend([
            normalization_fn(nout), activation_fn, lambda x: x.mean((2, 3)),
            objax.nn.Linear(nout,
                            num_classes,
                            w_init=objax.nn.init.xavier_normal)
        ])
        super().__init__(ops)
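
This constructor appears to match the zoo-style ResNetV2 in Objax; a usage sketch follows. The import path and the (3, 4, 6, 3) ResNet-50 block layout are assumptions:

    import jax.numpy as jn
    from objax.zoo.resnet_v2 import ResNetV2  # assumed import path

    # blocks_per_group=(3, 4, 6, 3) with bottleneck=True is the classic ResNet-50 layout.
    model = ResNetV2(in_channels=3,
                     num_classes=1000,
                     blocks_per_group=(3, 4, 6, 3))

    x = jn.zeros((2, 3, 224, 224))     # NCHW ImageNet-sized batch
    logits = model(x, training=False)  # shape (2, 1000)

Because the class derives from a Sequential (note the `super().__init__(ops)`), the `training` keyword is routed to the normalization layers automatically.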
Example 3
    def __init__(self, nin, nclass, scales, filters, filters_max):
        def nl(x):
            """Return tanh as activation function. Tanh has better utility for
            differentially private SGD https://arxiv.org/abs/2007.14191 .
            """
            return tanh(x)

        def nf(scale):
            return min(filters_max, filters << scale)

        ops = [Conv2D(nin, nf(0), 3), nl]
        for i in range(scales):
            ops.extend([
                Conv2D(nf(i), nf(i), 3), nl,
                Conv2D(nf(i), nf(i + 1), 3), nl,
                partial(average_pool_2d, size=2, strides=2)
            ])
        ops.extend([Conv2D(nf(scales), nclass, 3), lambda x: x.mean((2, 3))])
        super().__init__(ops)
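
To make the sizing concrete: nf doubles the filter count per scale via a left shift and caps it at filters_max, while each scale's average pool halves the spatial resolution. A small check with hypothetical values:

    filters, filters_max, scales = 16, 64, 3  # hypothetical values

    def nf(scale):
        return min(filters_max, filters << scale)

    print([nf(i) for i in range(scales + 1)])  # [16, 32, 64, 64]: doubles, then saturates
    # A 32x32 input is halved by each of the 3 pools: 32 -> 16 -> 8 -> 4; the final
    # Conv2D maps to nclass channels and mean((2, 3)) reduces to (N, nclass) logits.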
Example 4
    def __init__(self):

        num_channels = 4  # 3 from RGB_t1 + 1 from dither_t0

        self.encoders = objax.ModuleList()
        k = 7
        for num_output_channels in [32, 64, 128, 128]:
            self.encoders.append(
                EncoderBlock(num_channels, num_output_channels, k))
            k = 3
            num_channels = num_output_channels

        self.decoders = objax.ModuleList()
        for num_output_channels in [128, 64, 32, 16]:
            self.decoders.append(
                DecoderBlock(num_channels, num_output_channels))
            num_channels = num_output_channels

        self.logits = Conv2D(num_channels,
                             nout=1,
                             strides=1,
                             k=1,
                             w_init=xavier_normal)
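
Examples 6 and 7 below look like the DecoderBlock and EncoderBlock constructors used here. The forward pass is not shown; one plausible sketch assumes each encoder downsamples 2x and each decoder upsamples 2x with no cross connections between them (the last decoder's 16 output channels have no matching encoder stage, so U-Net-style skips seem unlikely):

    def __call__(self, x):
        # Hypothetical forward pass; the original __call__ is not shown.
        for encoder in self.encoders:  # each block halves H and W (stride-2 convs)
            x = encoder(x)
        for decoder in self.decoders:  # each block presumably upsamples 2x
            x = decoder(x)
        return self.logits(x)          # (N, 1, H, W) logit map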
Example 5
    def __init__(self,
                 nin: int,
                 nout: int,
                 stride: Union[int, Sequence[int]],
                 use_projection: bool,
                 bottleneck: bool,
                 normalization_fn: Callable[..., objax.Module] = objax.nn.BatchNorm2D,
                 activation_fn: Callable[[JaxArray], JaxArray] = objax.functional.relu):
        """Creates ResNetV2Block instance.

        Args:
            nin: number of input filters.
            nout: number of output filters.
            stride: stride for 3x3 convolution and projection convolution in this block.
            use_projection: if True then include projection convolution into this block.
            bottleneck: if True then make bottleneck block.
            normalization_fn: constructor of the module used for normalization.
            activation_fn: activation function.
        """
        self.use_projection = use_projection
        self.activation_fn = activation_fn
        self.stride = stride

        if self.use_projection:
            self.proj_conv = Conv2D(nin, nout, 1, strides=stride, **conv_args(1, nout))

        if bottleneck:
            self.norm_0 = normalization_fn(nin)
            self.conv_0 = Conv2D(nin, nout // 4, 1, strides=1, **conv_args(1, nout // 4))
            self.norm_1 = normalization_fn(nout // 4)
            self.conv_1 = Conv2D(nout // 4, nout // 4, 3, strides=stride, **conv_args(3, nout // 4, (1, 1)))
            self.norm_2 = normalization_fn(nout // 4)
            self.conv_2 = Conv2D(nout // 4, nout, 1, strides=1, **conv_args(1, nout))
            self.layers = ((self.norm_0, self.conv_0), (self.norm_1, self.conv_1), (self.norm_2, self.conv_2))
        else:
            self.norm_0 = normalization_fn(nin)
            self.conv_0 = Conv2D(nin, nout, 3, strides=1, **conv_args(3, nout, (1, 1)))
            self.norm_1 = normalization_fn(nout)
            self.conv_1 = Conv2D(nout, nout, 3, strides=stride, **conv_args(3, nout, (1, 1)))
            self.layers = ((self.norm_0, self.conv_0), (self.norm_1, self.conv_1))
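
For context, a sketch of the matching forward pass, following the standard pre-activation ResNetV2 pattern: norm and activation precede each convolution, and the projection branches off the first pre-activated tensor. The original `__call__` is not shown, so treat this as an assumption:

    def __call__(self, x, training):
        shortcut = x
        for i, (norm, conv) in enumerate(self.layers):
            x = norm(x, training)
            x = self.activation_fn(x)
            if i == 0 and self.use_projection:
                shortcut = self.proj_conv(x)  # project off the pre-activated input
            x = conv(x)
        return x + shortcut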
Example 6
    def __init__(self, nin, nout):
        self.shortcut = Conv2D(nin, nout, strides=1, k=3)
        self.conv1 = Conv2D(nin, nout, strides=1, k=3)
        self.conv2 = Conv2D(nout, nout, strides=1, k=3)
        # 1x1 convolution over 2 * nout channels, suggesting two nout-channel
        # feature maps are concatenated before it.
        self.skip_conv = Conv2D(2 * nout, nout, strides=1, k=1)
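
A plausible forward pass for this block, assuming it upsamples 2x and fuses the shortcut and conv paths by concatenation (which would explain skip_conv's 2 * nout input channels). The upsampling method and activation are assumptions, with `import jax`, `jax.numpy as jn`, and `relu` from objax.functional:

    def __call__(self, x):
        # Hypothetical forward pass; the original __call__ is not shown.
        n, c, h, w = x.shape
        x = jax.image.resize(x, (n, c, 2 * h, 2 * w), method='nearest')  # upsample 2x
        a = self.shortcut(x)                                    # linear path
        b = self.conv2(relu(self.conv1(x)))                     # two-conv path (relu assumed)
        return self.skip_conv(jn.concatenate((a, b), axis=1))   # 1x1 fuse: 2*nout -> nout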
Example 7
    def __init__(self, nin, nout, k):
        # Both the shortcut and the first convolution use stride 2, so the
        # two paths stay spatially aligned while downsampling.
        self.shortcut = Conv2D(nin, nout, strides=2, k=3)
        self.conv1 = Conv2D(nin, nout, strides=2, k=k)
        self.conv2 = Conv2D(nout, nout, strides=1, k=3)
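
And a plausible forward pass for this encoder block: both paths downsample by 2, so they can be summed residually. The activation placement is an assumption, with `relu` from objax.functional:

    def __call__(self, x):
        # Hypothetical forward pass; the original __call__ is not shown.
        return relu(self.conv2(relu(self.conv1(x)))) + self.shortcut(x)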