def CifarBasicBlockv2(planes, stride=1, option="A", normalization_method=None,
                      use_fixup=False, num_layers=None, w_init=None, actfn=stax.Relu):
    # Pre-activation ("v2") basic block: normalization and activation precede each conv.
    assert not use_fixup, "Fixup is not supported in the pre-activation (v2) block"
    Main = stax.serial(
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3), strides=(stride, stride), padding="SAME", W_init=w_init, bias=False),
        maybe_use_normalization(normalization_method),
        actfn,
        Conv(planes, (3, 3), padding="SAME", W_init=w_init, bias=False),
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # For CIFAR10, the ResNet paper uses option A (parameter-free padded shortcut).
            Shortcut = LambdaLayer(_shortcut_pad)
        elif option == "B":
            Shortcut = Conv(planes, (1, 1), strides=(stride, stride), W_init=w_init, bias=False)
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum)
def CifarBasicBlock(planes, stride=1, option="A", normalization_method=None,
                    use_fixup=False, num_layers=None, w_init=None, actfn=stax.Relu):
    # Post-activation CIFAR basic block with optional Fixup scalar biases/scale.
    Main = stax.serial(
        FixupBias() if use_fixup else Identity,
        Conv(planes, (3, 3), strides=(stride, stride), padding="SAME",
             W_init=fixup_init(num_layers) if use_fixup else w_init, bias=False),
        maybe_use_normalization(normalization_method),
        FixupBias() if use_fixup else Identity,
        actfn,
        FixupBias() if use_fixup else Identity,
        Conv(planes, (3, 3), padding="SAME", bias=False,
             W_init=zeros if use_fixup else w_init),
        maybe_use_normalization(normalization_method),
        FixupScale() if use_fixup else Identity,
        FixupBias() if use_fixup else Identity,
    )
    Shortcut = Identity
    if stride > 1:
        if option == "A":
            # For CIFAR10, the ResNet paper uses option A (parameter-free padded shortcut).
            Shortcut = stax.serial(
                # FixupBias() if use_fixup else Identity,
                LambdaLayer(_shortcut_pad))
        elif option == "B":
            Shortcut = stax.serial(
                FixupBias() if use_fixup else Identity,
                Conv(planes, (1, 1), strides=(stride, stride), W_init=w_init, bias=False),
                maybe_use_normalization(normalization_method))
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, actfn)
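# The blocks above rely on FixupBias / FixupScale helpers defined elsewhere in
# this repo. As a rough sketch of the idea (these names and bodies are
# illustrative stand-ins, not the repo's implementation): in the stax
# convention a layer is an (init_fun, apply_fun) pair, and Fixup
# (Zhang et al., 2019) inserts a learnable scalar bias initialized at 0 and a
# learnable scalar scale initialized at 1 around each residual branch.
import jax.numpy as jnp


def _example_fixup_bias():
    """Hypothetical stand-in mirroring what FixupBias() likely returns."""
    def init_fun(rng, input_shape):
        return input_shape, (jnp.zeros(()),)   # single scalar bias, starts at 0

    def apply_fun(params, inputs, **kwargs):
        (bias,) = params
        return inputs + bias

    return init_fun, apply_fun


def _example_fixup_scale():
    """Hypothetical stand-in mirroring what FixupScale() likely returns."""
    def init_fun(rng, input_shape):
        return input_shape, (jnp.ones(()),)    # single scalar scale, starts at 1

    def apply_fun(params, inputs, **kwargs):
        (scale,) = params
        return inputs * scale

    return init_fun, apply_fun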
def BasicBlock(planes, stride=1, downsample=None, base_width=64,
               norm_layer=Identity, actfn=stax.Relu):
    if base_width != 64:
        raise ValueError("BasicBlock only supports base_width=64")
    Main = stax.serial(
        Conv(planes, (3, 3), strides=(stride, stride), padding="SAME", bias=False),
        norm_layer,
        actfn,
        Conv(planes, (3, 3), padding="SAME", bias=False),
        norm_layer,
    )
    Shortcut = downsample if downsample is not None else Identity
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, actfn)
def BottleneckBlock(planes, stride=1, downsample=None, base_width=64,
                    norm_layer=Identity, actfn=stax.Relu):
    width = int(planes * (base_width / 64.))
    Main = stax.serial(
        Conv(width, (1, 1), bias=False),
        norm_layer,
        actfn,
        Conv(width, (3, 3), strides=(stride, stride), padding="SAME", bias=False),
        norm_layer,
        actfn,
        Conv(planes * 4, (1, 1), bias=False),
        norm_layer,
    )
    Shortcut = downsample if downsample is not None else Identity
    return stax.serial(FanOut(2), stax.parallel(Main, Shortcut), FanInSum, actfn)
def ResNet(block, expansion, layers, normalization_method=None, width_per_group=64, actfn=stax.Relu):
    norm_layer = Identity
    if normalization_method == "group_norm":
        norm_layer = GroupNorm(32)
    elif normalization_method == "batch_norm":
        norm_layer = BatchNorm()
    base_width = width_per_group

    def _make_layer(block, planes, blocks, stride=1):
        downsample = None
        if stride != 1:
            downsample = stax.serial(
                Conv(planes * expansion, (1, 1), strides=(stride, stride), bias=False),
                norm_layer,
            )
        layers = []
        layers.append(block(planes, stride, downsample, base_width, norm_layer, actfn=actfn))
        for _ in range(1, blocks):
            layers.append(
                block(planes, base_width=base_width, norm_layer=norm_layer, actfn=actfn))
        return stax.serial(*layers)

    # CIFAR-sized stem: 3x3 stride-1 conv, no max-pool.
    return [
        Conv(64, (3, 3), strides=(1, 1), padding="SAME", bias=False),
        norm_layer,
        actfn,
        # MaxPool((3, 3), strides=(2, 2), padding="SAME"),
        _make_layer(block, 64, layers[0]),
        _make_layer(block, 128, layers[1], stride=2),
        _make_layer(block, 256, layers[2], stride=2),
        _make_layer(block, 512, layers[3], stride=2),
        AvgPool((4, 4)),
        Flatten,
    ]
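# Usage sketch (an assumption about how the pieces fit together, not code from
# this repo): ResNet returns a list of stax layers meant to be spliced into
# stax.serial together with a classifier head. The 10-class stax.Dense head,
# the NHWC layout, and the 32x32x3 input shape are illustrative choices.
def _example_build_resnet18(rng, num_classes=10):
    init_fn, apply_fn = stax.serial(
        *ResNet(BasicBlock, expansion=1, layers=[2, 2, 2, 2]),
        stax.Dense(num_classes),
    )
    # Initialize parameters for CIFAR-sized inputs; apply_fn(params, x) gives logits.
    _, params = init_fn(rng, (-1, 32, 32, 3))
    return params, apply_fn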
def CifarResNet(block, num_blocks, expansion=1, option="A", normalization_method=None,
                use_fixup=False, init=None, actfn=stax.Relu):
    w_init = None
    if init == "he":
        w_init = variance_scaling(2.0, "fan_out", "truncated_normal")
    num_layers = sum(num_blocks)

    def _make_layer(block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(
                block(planes, stride, option=option, normalization_method=normalization_method,
                      use_fixup=use_fixup, num_layers=num_layers, w_init=w_init, actfn=actfn))
        return stax.serial(*layers)

    return [
        Conv(16 * expansion, (3, 3), padding="SAME", W_init=w_init, bias=False),
        maybe_use_normalization(normalization_method),
        FixupBias() if use_fixup else Identity,
        actfn,
        _make_layer(block, 16 * expansion, num_blocks[0], stride=1),
        _make_layer(block, 32 * expansion, num_blocks[1], stride=2),
        _make_layer(block, 64 * expansion, num_blocks[2], stride=2),
        AvgPoolAll(),
        Flatten,
        FixupBias() if use_fixup else Identity,
    ]
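# Usage sketch (illustrative, not code from this repo): a ResNet-20-style
# CIFAR-10 classifier built from CifarBasicBlock with three stages of three
# blocks, using the Fixup path instead of normalization. The 10-class
# stax.Dense head and the NHWC 32x32x3 input shape are assumptions.
def _example_build_cifar_resnet20(rng, num_classes=10):
    init_fn, apply_fn = stax.serial(
        *CifarResNet(CifarBasicBlock, num_blocks=[3, 3, 3],
                     normalization_method=None, use_fixup=True, init="he"),
        stax.Dense(num_classes),
    )
    # Initialize parameters for CIFAR-sized inputs; apply_fn(params, x) gives logits.
    _, params = init_fn(rng, (-1, 32, 32, 3))
    return params, apply_fn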