コード例 #1
0
ファイル: model.py プロジェクト: matpalm/large_batch
def haiku_model(x, dense_kernel_size=64, max_conv_size=256, num_classes=10):
    """Apply a four-stage strided-conv classifier to ``x``.

    Creates Haiku modules, so it has to run inside a Haiku-transformed
    function.

    Args:
      x: input batch fed to the first ``hk.Conv2D``.
      dense_kernel_size: width of the hidden ``hk.Linear`` layer.
      max_conv_size: upper bound on the channel count of every conv stage
        (the requested widths are 32, 64, 128 and 256).
      num_classes: width of the final ``'logits'`` layer.

    Returns:
      The output of the ``'logits'`` layer.
    """
    # Four 3x3 stride-2 convs, each followed by GELU; widths are capped.
    trunk = []
    for stage, wanted in enumerate((32, 64, 128, 256)):
        width = min(wanted, max_conv_size)
        trunk.extend([
            hk.Conv2D(output_channels=width,
                      kernel_shape=3,
                      stride=2,
                      name="conv%d_%d" % (stage, width)),
            jax.nn.gelu,
        ])
    # Head: pooling (presumably a mean over the spatial axes), then a
    # one-hidden-layer MLP.
    head = [
        global_spatial_mean_pooling,
        hk.Linear(dense_kernel_size, name="dense_%d" % dense_kernel_size),
        jax.nn.gelu,
        hk.Linear(num_classes, name='logits'),
    ]
    return hk.Sequential(trunk + head)(x)
コード例 #2
0
ファイル: models.py プロジェクト: yuanbochd/google-research
def _resnet_layer(inputs,
                  num_filters,
                  normalization_layer,
                  kernel_size=3,
                  strides=1,
                  activation=lambda x: x,
                  use_bias=True,
                  is_training=True):
    """Conv -> normalization -> activation.

    Args:
      inputs: activations fed to the ``hk.Conv2D``.
      num_filters: number of output channels of the convolution.
      normalization_layer: zero-argument factory; the object it returns is
        called as ``norm(x, is_training=is_training)``.
      kernel_size: convolution kernel size.
      strides: convolution stride.
      activation: callable applied last (identity by default).
      use_bias: whether the convolution has a bias term.
      is_training: forwarded to the normalization layer.

    Returns:
      ``activation(norm(conv(inputs)))``.
    """
    convolved = hk.Conv2D(num_filters,
                          kernel_size,
                          stride=strides,
                          padding="same",
                          w_init=he_normal,
                          with_bias=use_bias)(inputs)
    normalized = normalization_layer()(convolved, is_training=is_training)
    return activation(normalized)
コード例 #3
0
    def __call__(self, image, debug=False):
        """Apply six 3x3 convolutions and a linear classifier to ``image``.

        Args:
          image: input batch for the first ``hk.Conv2D``.
          debug: if true, print activation shapes (passed to the
            ``debug_layer`` hook placed after every layer).

        Returns:
          The output of the final ``hk.Linear`` with ``self.n_classes`` units.
        """
        # Strides alternate 1 and 2, so the trunk downsamples three times.
        stage_strides = (1, 2, 1, 2, 1, 2)
        # NOTE(review): the first conv and the final Linear are both named
        # 'misc'; Haiku uniquifies the later one. Presumably the names are
        # used to group parameters -- confirm before renaming.
        stage_names = ['misc'] + ['conv'] * 5

        layers = []
        for stage_stride, stage_name in zip(stage_strides, stage_names):
            layers.append(hk.Conv2D(self.n_channels,
                                    kernel_shape=3,
                                    w_init=self.initializer,
                                    b_init=self.initializer,
                                    with_bias=False,
                                    stride=stage_stride,
                                    name=stage_name))
            layers.append(jax.nn.relu)
            layers.append(debug_layer(debug))

        layers.append(hk.Flatten())
        layers.append(hk.Linear(self.n_classes,
                                w_init=self.initializer,
                                b_init=self.initializer,
                                name='misc'))
        layers.append(debug_layer(debug))
        return hk.Sequential(layers)(image)
コード例 #4
0
    def __init__(
        self,
        num_layers: int = 1,
        num_channels: int = 64,
        use_batchnorm: bool = True,
        bn_config: Optional[Mapping[str, float]] = None,
        name: Optional[str] = None,
    ):
        """Constructs a Conv2DDownsample model.

        Args:
          num_layers: Number of downsampling layers. Each gets a 7x7 stride-2
            conv and, optionally, a batchnorm (described upstream as
            conv->max_pool; pooling presumably happens in ``__call__``).
          num_channels: Output channels of every conv.
          use_batchnorm: Whether to create a :class:`~haiku.BatchNorm` per
            layer; when false each layer's ``batchnorm`` entry is ``None``.
          bn_config: Extra keyword arguments for the batchnorm layers. Missing
            keys default to ``decay_rate=0.9``, ``eps=1e-5``,
            ``create_scale=True`` and ``create_offset=True``.
          name: Name of the module.
        """
        super().__init__(name=name)

        self._num_layers = num_layers
        self._use_batchnorm = use_batchnorm

        # Caller-supplied values win; only missing keys get a default.
        bn_config = dict(bn_config or {})
        for key, default in (('decay_rate', 0.9), ('eps', 1e-5),
                             ('create_scale', True), ('create_offset', True)):
            bn_config.setdefault(key, default)

        self.layers = []
        for _ in range(self._num_layers):
            # 7x7, stride 2, 'SAME' padding, no bias.
            conv = hk.Conv2D(output_channels=num_channels,
                             kernel_shape=7,
                             stride=2,
                             with_bias=False,
                             padding='SAME',
                             name='conv')
            batchnorm = (hk.BatchNorm(name='batchnorm', **bn_config)
                         if use_batchnorm else None)
            self.layers.append({'conv': conv, 'batchnorm': batchnorm})
コード例 #5
0
    def __init__(self, num_blocks=(5, 5, 5), num_classes=10):
        """Build a three-stage ResNet operating on channel-first (NCHW) inputs.

        Args:
          num_blocks: Number of blocks in each of the three stages (any
            indexable of length >= 3). The default is a tuple so that no
            mutable list is shared across instances.
          num_classes: Output width of the final linear layer.
        """
        super().__init__()

        # 3x3 stem conv in NCHW; the batchnorm uses the matching
        # channel-first 'NC...' layout.
        self.conv1 = hk.Conv2D(
            output_channels=16,
            kernel_shape=3,
            stride=1,
            with_bias=False,
            padding='SAME',
            data_format='NCHW')
        self.bn1 = hk.BatchNorm(create_scale=True,
                                create_offset=True,
                                decay_rate=0.9,
                                data_format='NC...')
        # Per stage: the first block gets the stage stride (1, 2, 2) and the
        # remaining blocks stride 1. MultiBlock's positional arguments are
        # presumably (in_channels, out_channels, strides) -- confirm.
        self.layer1 = MultiBlock(16, 16, [1] + [1] * (num_blocks[0] - 1))
        self.layer2 = MultiBlock(16, 32, [2] + [1] * (num_blocks[1] - 1))
        self.layer3 = MultiBlock(32, 64, [2] + [1] * (num_blocks[2] - 1))
        self.linear = hk.Linear(num_classes)
コード例 #6
0
    def __init__(self,
                 num_classes: int = 10,
                 depth: int = 28,
                 width: int = 10,
                 activation: str = 'relu',
                 norm_args: Optional[Dict[str, Any]] = None,
                 name: Optional[str] = None):
        """Build a Wide ResNet with ``(depth - 4) / 6`` blocks per group.

        Args:
          num_classes: Output width of the zero-initialised ``'logits'`` layer.
          depth: Total depth; must be of the form ``6n + 4``.
          width: Channel multiplier applied to the base widths 16, 32 and 64.
          activation: Name of a function in ``jax.nn`` (e.g. ``'relu'``).
          norm_args: Keyword arguments for every ``hk.BatchNorm``; when
            ``None``, scale and offset are learned with ``decay_rate=.99``.
          name: Module name.

        Raises:
          ValueError: if ``depth`` is not of the form ``6n + 4``.
        """
        super(WideResNet, self).__init__(name=name)
        if (depth - 4) % 6:
            raise ValueError('depth should be 6n+4.')
        self._activation = getattr(jax.nn, activation)
        norm_args = norm_args if norm_args is not None else {
            'create_offset': True,
            'create_scale': True,
            'decay_rate': .99,
        }

        self._conv = hk.Conv2D(output_channels=16,
                               kernel_shape=(3, 3),
                               stride=1,
                               with_bias=False,
                               name='init_conv')  # pytype: disable=not-callable
        self._bn = hk.BatchNorm(name='batchnorm', **norm_args)
        self._linear = hk.Linear(num_classes, w_init=jnp.zeros, name='logits')

        blocks_per_group = (depth - 4) // 6
        self._blocks = []
        for group, base_width in enumerate((16, 32, 64)):
            group_blocks = []
            for i in range(blocks_per_group):
                first = i == 0
                group_blocks.append(
                    _WideResNetBlock(
                        num_filters=width * base_width,
                        # Only the first block of groups 1 and 2 downsamples.
                        stride=2 if (group != 0 and first) else 1,
                        # The first block of every group projects its shortcut.
                        projection_shortcut=first,
                        activation=self._activation,
                        norm_args=norm_args,
                        name='resnet_lay_{}_block_{}'.format(group, i)))
            self._blocks.append(group_blocks)
コード例 #7
0
ファイル: models_haiku.py プロジェクト: ldsec/projects-data
    def __call__(self, x: jnp.ndarray, is_train: bool):
        """Run three conv stages and an MLP head on ``x``.

        Each stage is two 3x3 'VALID' convs (each followed by the
        activation), a 2x2 'VALID' average pool with stride 1, and dropout.
        The dropout rates per stage are 0.2, 0.3 and 0.4. The head is
        Flatten -> Linear(128) -> activation -> Dropout(0.5) ->
        Linear(num_classes).

        Args:
          x: input batch.
          is_train: forwarded to every ``Dropout`` call.

        Returns:
          The output of the final ``hk.Linear(self._num_classes)``.
        """
        def activate(h):
            # When an interval is configured (truthy), ``_act_fn`` takes it as
            # a second argument; otherwise the one-argument form is used.
            if self._interval:
                return self._act_fn(h, self._interval)
            return self._act_fn(h)

        # Module construction order matches the original unrolled code, so
        # Haiku's auto-generated names (conv2_d, conv2_d_1, ...) are unchanged.
        # NOTE(review): pooling uses stride 1 and every Dropout receives the
        # same ``self._seed`` -- confirm both are intended.
        for channels, drop_rate in ((64, 0.2), (96, 0.3), (128, 0.4)):
            for _ in range(2):
                x = hk.Conv2D(output_channels=channels,
                              kernel_shape=3,
                              padding='VALID')(x)
                x = activate(x)
            x = hk.AvgPool(window_shape=2, strides=1, padding='VALID')(x)
            x = Dropout(drop_rate, self._seed)(x, is_train)

        x = hk.Flatten()(x)
        x = hk.Linear(128)(x)
        x = activate(x)
        x = Dropout(0.5, self._seed)(x, is_train)
        x = hk.Linear(self._num_classes)(x)

        return x
コード例 #8
0
 def __init__(self):
     """Create the layers of a three-block conv net with a two-layer head.

     Every conv is 3x3. Every BatchNorm is ``hk.BatchNorm(True, True, 0.99)``.
     ``Vscaling(2.0)`` initialises the hidden layers and
     ``Vscaling(1.0, "fan_avg")`` the 10-unit output layer.
     """
     super().__init__()

     def conv(channels):
         return hk.Conv2D(channels, 3, w_init=Vscaling(2.0))

     def batchnorm():
         # Positional order is (create_scale, create_offset, decay_rate).
         return hk.BatchNorm(True, True, 0.99)

     # Block 1: 32 channels.
     self.conv1_1 = conv(32)
     self.bn1_1 = batchnorm()
     self.conv1_2 = conv(32)
     self.bn1_2 = batchnorm()
     # Block 2: 64 channels.
     self.conv2_1 = conv(64)
     self.bn2_1 = batchnorm()
     self.conv2_2 = conv(64)
     self.bn2_2 = batchnorm()
     # Block 3: 128 channels.
     self.conv3_1 = conv(128)
     self.bn3_1 = batchnorm()
     self.conv3_2 = conv(128)
     self.bn3_2 = batchnorm()
     # Linear part.
     self.lin1 = hk.Linear(128, w_init=Vscaling(2.0))
     self.bn4 = batchnorm()
     self.lin2 = hk.Linear(10, w_init=Vscaling(1.0, "fan_avg"))
コード例 #9
0
    def __init__(self):
        """Create a strided conv trunk, a 1x1 conv branch and a linear head.

        NOTE(review): the 3x3 convs in ``conv_block`` use stride 3
        (non-overlapping windows) -- confirm this is intended.
        """
        super().__init__()

        bn_config = {
            'create_scale': True,
            'create_offset': True,
            'decay_rate': 0.999
        }

        def relu_conv_stack(kernel, stride):
            # Three conv+ReLU pairs with 32, 32 and 64 output channels.
            layers = []
            for channels in (32, 32, 64):
                layers += [hk.Conv2D(channels, kernel, stride=stride, rate=1),
                           jax.nn.relu]
            return hk.Sequential(layers)

        self.conv_block = relu_conv_stack((3, 3), 3)
        self.conv_res_block = relu_conv_stack((1, 1), 1)

        self.reshape_mod = hk.Flatten()

        # (Linear, BatchNorm) pairs sharing one batchnorm configuration.
        self.lin_res_block = [
            (hk.Linear(128), hk.BatchNorm(name='lin_batchnorm_0', **bn_config)),
            (hk.Linear(256), hk.BatchNorm(name='lin_batchnorm_1', **bn_config)),
        ]

        self.final_linear = hk.Linear(10)
コード例 #10
0
ファイル: networks.py プロジェクト: yynst2/deepmind-research
    def __init__(
        self,
        blocks_per_group: Sequence[int],
        num_classes: Optional[int] = None,
        bn_config: Optional[Mapping[str, float]] = None,
        resnet_v2: bool = False,
        bottleneck: bool = True,
        channels_per_group: Sequence[int] = (256, 512, 1024, 2048),
        use_projection: Sequence[bool] = (True, True, True, True),
        width_multiplier: int = 1,
        name: Optional[str] = None,
    ):
        """Constructs a ResNet model.

        Args:
          blocks_per_group: A sequence of length 4 giving the number of blocks
            created in each group.
          num_classes: Output width of the zero-initialised ``'logits'`` layer.
            NOTE(review): typed Optional, but always passed to ``hk.Linear``.
          bn_config: Extra keyword arguments for the ``BatchNorm`` layers (for
            example ``decay_rate``, ``eps``, ``cross_replica_axis``). Missing
            keys default to ``decay_rate=0.9``, ``eps=1e-5``,
            ``create_scale=True`` and ``create_offset=True``.
          resnet_v2: Whether to use the v1 or v2 ResNet implementation.
            Defaults to False.
          bottleneck: Whether the block should bottleneck or not. Defaults to
            True.
          channels_per_group: A sequence of length 4 giving the number of
            channels used for each block in each group.
          use_projection: A sequence of length 4 saying whether each group's
            residual blocks use projection.
          width_multiplier: An integer multiplying the number of channels of
            the stem and of every group.
          name: Name of the module.
        """
        super().__init__(name=name)
        self.resnet_v2 = resnet_v2

        # Caller-supplied values win; only missing keys get a default.
        bn_config = dict(bn_config or {})
        for key, default in (('decay_rate', 0.9), ('eps', 1e-5),
                             ('create_scale', True), ('create_offset', True)):
            bn_config.setdefault(key, default)

        # Number of blocks in each group for ResNet.
        check_length(4, blocks_per_group, 'blocks_per_group')
        check_length(4, channels_per_group, 'channels_per_group')

        self.initial_conv = hk.Conv2D(output_channels=64 * width_multiplier,
                                      kernel_shape=7,
                                      stride=2,
                                      with_bias=False,
                                      padding='SAME',
                                      name='initial_conv')

        # v1 normalises right after the stem; v2 normalises once at the end.
        if not resnet_v2:
            self.initial_batchnorm = hk.BatchNorm(name='initial_batchnorm',
                                                  **bn_config)

        self.block_groups = []
        for i, stride in enumerate((1, 2, 2, 2)):
            group = hk.nets.ResNet.BlockGroup(
                channels=width_multiplier * channels_per_group[i],
                num_blocks=blocks_per_group[i],
                stride=stride,
                bn_config=bn_config,
                resnet_v2=resnet_v2,
                bottleneck=bottleneck,
                use_projection=use_projection[i],
                name='block_group_%d' % (i))
            self.block_groups.append(group)

        if resnet_v2:
            self.final_batchnorm = hk.BatchNorm(name='final_batchnorm',
                                                **bn_config)

        self.logits = hk.Linear(num_classes, w_init=jnp.zeros, name='logits')
コード例 #11
0
ファイル: crownibp_test.py プロジェクト: zeta1999/jax_verify
 def conv2d_model(inp):
   """Apply a 1-output-channel 2x2 'VALID' conv (stride 1, with bias)."""
   layer = hk.Conv2D(output_channels=1,
                     kernel_shape=(2, 2),
                     padding='VALID',
                     stride=1,
                     with_bias=True)
   return layer(inp)
コード例 #12
0
    def __init__(self,
                 depth=50,
                 num_classes: Optional[int] = 1000,
                 width_mult: int = 1,
                 normalize_fn: Optional[types.NormalizeFn] = None,
                 name: Optional[Text] = None,
                 remat: bool = False):
        """Creates ResNetV2 Haiku module.

        Args:
          depth: depth of the desired ResNet (18, 34, 50, 101, 152 or 200).
            Depths >= 50 use ``BottleneckBlock`` groups with
            (256, 512, 1024, 2048) channels; shallower ones use ``BasicBlock``
            groups with (64, 128, 256, 512) channels.
          num_classes: (int) Number of outputs in final layer. If None, no
            classification head is created and the output embedding is
            returned.
          width_mult: multiplier for channel width (stem and every group).
          normalize_fn: normalization function, see helpers/utils.py
          name: Name of the module.
          remat: Whether to rematerialize intermediate activations (saves
            memory). Wraps the stem conv and is forwarded to every group.

        Raises:
          ValueError: if ``depth`` is not one of the supported depths.
        """
        super(ResNetV2, self).__init__(name=name)
        self._normalize_fn = normalize_fn
        self._num_classes = num_classes
        self._width_mult = width_mult

        self._strides = [1, 2, 2, 2]
        # Blocks per group for each supported depth.
        num_blocks = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
            200: [3, 24, 36, 3],
        }
        if depth not in num_blocks:
            raise ValueError(
                f'`depth` should be in {list(num_blocks.keys())} ({depth} given).'
            )
        self._num_blocks = num_blocks[depth]

        if depth >= 50:
            self._block_module = BottleneckBlock
            self._channels = [256, 512, 1024, 2048]
        else:
            self._block_module = BasicBlock
            self._channels = [64, 128, 256, 512]

        self._initial_conv = hk.Conv2D(output_channels=64 * self._width_mult,
                                       kernel_shape=7,
                                       stride=2,
                                       with_bias=False,
                                       padding='SAME',
                                       name='initial_conv')

        if remat:
            self._initial_conv = hk.remat(self._initial_conv)

        self._block_groups = []
        for i in range(4):
            self._block_groups.append(
                ResNetUnit(channels=self._channels[i] * self._width_mult,
                           num_blocks=self._num_blocks[i],
                           block_module=self._block_module,
                           stride=self._strides[i],
                           normalize_fn=self._normalize_fn,
                           name='block_group_%d' % i,
                           remat=remat))

        if num_classes is not None:
            self._logits_layer = hk.Linear(output_size=num_classes,
                                           w_init=jnp.zeros,
                                           name='logits')
コード例 #13
0
ファイル: mnist.py プロジェクト: jirufengyu/ode
 def __init__(self, **kwargs):
     """Store an ``hk.Conv2D`` built from ``kwargs`` as ``self._layer``."""
     super().__init__()
     self._layer = hk.Conv2D(**kwargs)
コード例 #14
0
    def __init__(self,
                 blocks_per_group,
                 bn_config,
                 bottleneck,
                 channels_per_group,
                 use_projection,
                 strides,
                 n_output_channels=1,
                 use_bn=True,
                 pad_crop=False,
                 name=None):
        """Constructs a Residual UNet model based on a traditional ResNet.

        The encoder has ``len(strides)`` block groups. The decoder has the same
        number of transposed block groups, built from the same per-group
        channels, block counts, strides and projection flags.

        Args:
          blocks_per_group: A sequence of length ``len(strides)`` giving the
            number of blocks created in each group.
          bn_config: A mapping of extra keyword arguments for the
            :class:`~haiku.BatchNorm` layers (``None`` allowed). Missing keys
            default to ``decay_rate=0.9``, ``eps=1e-5``,
            ``create_scale=True`` and ``create_offset=True``.
          bottleneck: Whether the blocks should bottleneck or not (required,
            no default).
          channels_per_group: A sequence of length ``len(strides)`` giving the
            number of channels used in each group.
          use_projection: A sequence indexed per group, forwarded to each
            ``BlockGroup`` to say whether its residual blocks use projection.
          strides: A sequence with the stride of each group; its length sets
            the number of groups.
          n_output_channels: The number of output channels, for example to
            change in the case of a complex denoising. Defaults to 1.
          use_bn: Whether the network should use batch normalisation. Defaults
            to ``True``. When false, the initial conv gets a bias instead.
          pad_crop: Whether to use cropping/padding to make sure the images can
            be downsampled and upsampled correctly. Defaults to ``False``.
          name: Name of the module.

        Note:
          Only the ResNet v1 layout is built: ``self.resnet_v2`` is fixed to
          ``False`` and there is no ``resnet_v2`` argument.
        """
        super().__init__(name=name)
        self.resnet_v2 = False
        self.use_bn = use_bn
        self.pad_crop = pad_crop
        self.n_output_channels = n_output_channels
        bn_config = dict(bn_config or {})
        bn_config.setdefault("decay_rate", 0.9)
        bn_config.setdefault("eps", 1e-5)
        bn_config.setdefault("create_scale", True)
        bn_config.setdefault("create_offset", True)
        self.strides = strides
        bl = len(self.strides)

        # Number of blocks in each group for ResNet.
        check_length(bl, blocks_per_group, "blocks_per_group")
        check_length(bl, channels_per_group, "channels_per_group")
        self.upsampling = upsample
        self.pooling = hk.AvgPool(window_shape=2, strides=2, padding='SAME')

        self.initial_conv = hk.Conv2D(output_channels=32,
                                      kernel_shape=7,
                                      stride=1,
                                      with_bias=not self.use_bn,
                                      padding="SAME",
                                      name="initial_conv")

        if not self.resnet_v2 and self.use_bn:
            self.initial_batchnorm = hk.BatchNorm(name="initial_batchnorm",
                                                  **bn_config)

        # Encoder groups are built first, then the transposed decoder groups,
        # so Haiku module construction order is unchanged.
        self.block_groups = []
        self.up_block_groups = []
        for i in range(bl):
            self.block_groups.append(
                BlockGroup(channels=channels_per_group[i],
                           num_blocks=blocks_per_group[i],
                           stride=strides[i],
                           bn_config=bn_config,
                           bottleneck=bottleneck,
                           use_projection=use_projection[i],
                           transpose=False,
                           use_bn=self.use_bn,
                           name="block_group_%d" % (i)))

        for i in range(bl):
            self.up_block_groups.append(
                BlockGroup(channels=channels_per_group[i],
                           num_blocks=blocks_per_group[i],
                           stride=strides[i],
                           bn_config=bn_config,
                           bottleneck=bottleneck,
                           use_projection=use_projection[i],
                           transpose=True,
                           use_bn=self.use_bn,
                           name="up_block_group_%d" % (i)))

        if self.resnet_v2 and self.use_bn:
            self.final_batchnorm = hk.BatchNorm(name="final_batchnorm",
                                                **bn_config)

        self.final_conv = hk.Conv2D(
            output_channels=self.n_output_channels,
            kernel_shape=5,
            stride=1,
            padding='SAME',
            name='final_conv',
        )
コード例 #15
0
    def __init__(self,
                 channels: int,
                 stride: Union[int, Sequence[int]],
                 use_projection: bool,
                 bn_config: Mapping[str, float],
                 bottleneck: bool,
                 transpose: bool = False,
                 name: Optional[str] = None):
        """Build the layers of one (optionally transposed) residual block.

        Args:
          channels: Output channels of the block.
          stride: Stride of the projection shortcut and of ``conv_1``.
          use_projection: Whether to create a 1x1 projection shortcut
            (``proj_conv`` + ``proj_batchnorm``).
          bn_config: Keyword arguments for every ``hk.BatchNorm``. Missing keys
            default to ``create_scale=True``, ``create_offset=True`` and
            ``decay_rate=0.999``.
          bottleneck: If true, build 1x1 -> 3x3 -> 1x1 convs, with the first
            two at ``channels // 4``. Otherwise build two 3x3 convs at
            ``channels``.
          transpose: If true, the strided convs (shortcut and ``conv_1``) are
            ``hk.Conv2DTranspose``; the stride-1 convs stay ``hk.Conv2D``.
          name: Module name.
        """
        super().__init__(name=name)
        self.use_projection = use_projection

        bn_config = dict(bn_config)
        for key, default in (("create_scale", True), ("create_offset", True),
                             ("decay_rate", 0.999)):
            bn_config.setdefault(key, default)

        strided_conv = hk.Conv2DTranspose if transpose else hk.Conv2D

        if use_projection:
            self.proj_conv = strided_conv(output_channels=channels,
                                          kernel_shape=1,
                                          stride=stride,
                                          with_bias=False,
                                          padding="SAME",
                                          name="shortcut_conv")
            self.proj_batchnorm = hk.BatchNorm(name="shortcut_batchnorm",
                                               **bn_config)

        inner_channels = channels // (4 if bottleneck else 1)

        # (conv, batchnorm) pairs, constructed in the same order as applied.
        stages = [
            (hk.Conv2D(output_channels=inner_channels,
                       kernel_shape=1 if bottleneck else 3,
                       stride=1,
                       with_bias=False,
                       padding="SAME",
                       name="conv_0"),
             hk.BatchNorm(name="batchnorm_0", **bn_config)),
            (strided_conv(output_channels=inner_channels,
                          kernel_shape=3,
                          stride=stride,
                          with_bias=False,
                          padding="SAME",
                          name="conv_1"),
             hk.BatchNorm(name="batchnorm_1", **bn_config)),
        ]

        if bottleneck:
            # The last batchnorm's scale is zero-initialised.
            stages.append(
                (hk.Conv2D(output_channels=channels,
                           kernel_shape=1,
                           stride=1,
                           with_bias=False,
                           padding="SAME",
                           name="conv_2"),
                 hk.BatchNorm(name="batchnorm_2",
                              scale_init=jnp.zeros,
                              **bn_config)))

        self.layers = tuple(stages)
コード例 #16
0
    def __call__(self,
                 x: Tensor,
                 training: bool = False) -> Tuple[Tensor, Tensor]:
        """Run the SSD detector: VGG16 backbone, extra feature maps, heads.

        Args:
          x: input batch for the VGG16 backbone.
          training: forwarded to every ``self._additional_conv`` call.

        Returns:
          ``(clf, reg)``: the first and second outputs of ``self._head`` for
          every feature map, each concatenated along axis 1.
        """
        params = self.ssd_initial_weights

        def constant_or(key, leaf, fallback):
            # Use the stored weights as a constant initialiser when given.
            if params is None:
                return fallback
            return hk.initializers.Constant(params[key][leaf])

        feature_maps = aj.zoo.VGG16(
            include_top=False,
            pretrained=False,
            initial_weights=self.backbone_initial_weights,
            output_feature_maps=True)(x)

        conv4_3, x = feature_maps[-2:]
        conv4_3 = aj.nn.layers.L2Norm(init_fn=constant_or(
            'ssd/l2_norm', 'gamma', hk.initializers.Constant(20.)))(conv4_3)

        x = hk.MaxPool(window_shape=(1, 3, 3, 1), strides=1, padding='SAME')(x)

        # Replace fully connected by FCN: dilated (rate 6) 3x3 conv, then 1x1.
        # Keep the construction order: these two convs are presumably
        # auto-named 'conv2_d' and 'conv2_d_1', the keys used for their weights.
        conv6 = hk.Conv2D(
            output_channels=1024,
            kernel_shape=3,
            stride=1,
            rate=6,
            w_init=constant_or('ssd/conv2_d', 'w', self.xavier_init_fn),
            b_init=constant_or('ssd/conv2_d', 'b', None),
            padding='SAME')(x)
        conv6 = jax.nn.relu(conv6)

        conv7 = hk.Conv2D(
            output_channels=1024,
            kernel_shape=1,
            stride=1,
            w_init=constant_or('ssd/conv2_d_1', 'w', self.xavier_init_fn),
            b_init=constant_or('ssd/conv2_d_1', 'b', None),
            padding='SAME')(conv6)
        conv7 = jax.nn.relu(conv7)

        # Additional feature maps, each computed from the previous one:
        # (first positional arg, second positional arg, stride, name).
        extra_specs = ((256, 512, 2, 'conv_8'),
                       (128, 256, 2, 'conv_9'),
                       (128, 256, 1, 'conv_10'),
                       (128, 256, 1, 'conv_11'))
        features = [conv4_3, conv7]
        current = conv7
        for first, second, stride, block_name in extra_specs:
            current = self._additional_conv(current,
                                            first,
                                            second,
                                            stride=stride,
                                            training=training,
                                            name=block_name)
            features.append(current)

        head_outputs = [
            self._head(fmap, k=k, name=f"fm_{i}")
            for i, (k, fmap) in enumerate(zip(itertools.cycle(self.k), features))
        ]
        clf, reg = zip(*head_outputs)

        # NOTE(review): if ``np`` is plain NumPy rather than ``jax.numpy``,
        # this concatenation cannot run on traced values under jit -- confirm
        # the alias.
        return np.concatenate(clf, axis=1), np.concatenate(reg, axis=1)
コード例 #17
0
    def __call__(self,
                 inputs: types.TensorLike,
                 is_training: bool = True) -> jnp.ndarray:
        """Apply one pre-activation (ResNet V2) bottleneck block.

        Args:
          inputs: A 4-D float array, documented upstream as ``[B, H, W, C]``.
            NOTE(review): the upstream return doc says
            ``[B * num_frames, ...]``, but no op here changes the batch size,
            so B presumably already folds in the frames -- confirm.
          is_training: Whether to use training mode (forwarded to
            ``self._normalize_fn``).

        Returns:
          ``shortcut + residual``, a 4-D array with ``self._output_channels``
          channels, spatially strided by ``self._stride``.
        """
        def norm_relu(h):
            # Pre-activation: optional normalisation, then ReLU.
            if self._normalize_fn is not None:
                h = self._normalize_fn(h, is_training=is_training)
            return jax.nn.relu(h)

        # ResNet V2 puts normalisation and ReLU before the convolutions.
        preact = norm_relu(inputs)

        # The projection shortcut reads the pre-activated input; the identity
        # shortcut reads the raw input.
        if self._use_projection:
            shortcut = hk.Conv2D(output_channels=self._output_channels,
                                 kernel_shape=1,
                                 stride=self._stride,
                                 with_bias=False,
                                 padding='SAME',
                                 name='shortcut_conv')(preact)
        else:
            shortcut = inputs

        # Apply the Temporal Shift Module when a shift fraction is set.
        if self._channel_shift_fraction != 0:
            preact = tsmu.apply_temporal_shift(
                preact,
                tsm_mode=self._tsm_mode,
                num_frames=self._num_frames,
                channel_shift_fraction=self._channel_shift_fraction)

        # 1x1 reduce -> 3x3 (strided) -> 1x1 expand.
        residual = hk.Conv2D(self._bottleneck_channels,
                             kernel_shape=1,
                             stride=1,
                             with_bias=False,
                             padding='SAME',
                             name='conv_0')(preact)

        hidden = norm_relu(residual)
        residual = hk.Conv2D(output_channels=self._bottleneck_channels,
                             kernel_shape=3,
                             stride=self._stride,
                             with_bias=False,
                             padding='SAME',
                             name='conv_1')(hidden)

        hidden = norm_relu(residual)
        residual = hk.Conv2D(output_channels=self._output_channels,
                             kernel_shape=1,
                             stride=1,
                             with_bias=False,
                             padding='SAME',
                             name='conv_2')(hidden)

        # NOTE: no block multiplier is applied to the residual.
        return shortcut + residual
コード例 #18
0
    def __call__(self,
                 inputs: types.TensorLike,
                 is_training: bool = True,
                 final_endpoint: str = 'Embeddings') -> jnp.ndarray:
        """Connect the TSM ResNetV2 module, stopping at ``final_endpoint``.

        Args:
          inputs: A 4-D float array, documented upstream as ``[B, H, W, C]``.
          is_training: Whether to use training mode.
          final_endpoint: Up to which endpoint to run / return. Also stored on
            ``self._final_endpoint``.

        Returns:
          The network output at ``final_endpoint``. For ``'Embeddings'`` this
          is the spatially averaged features after ``tsmu.prepare_outputs``.

        Raises:
          ValueError: If ``final_endpoint`` is not in ``self.VALID_ENDPOINTS``.
        """
        # Prepare inputs for TSM; the frame count falls back to the module's.
        inputs, tsm_mode, num_frames = tsmu.prepare_inputs(inputs)
        num_frames = num_frames or self._num_frames

        self._final_endpoint = final_endpoint
        if self._final_endpoint not in self.VALID_ENDPOINTS:
            raise ValueError(f'Unknown final endpoint {self._final_endpoint}')

        # Stem: 7x7 stride-2 conv, then 3x3 stride-2 max pooling.
        end_point = 'tsm_resnet_stem'
        net = hk.Conv2D(output_channels=64 * self._width_mult,
                        kernel_shape=7,
                        stride=2,
                        with_bias=False,
                        name=end_point,
                        padding='SAME')(inputs)
        net = hk.MaxPool(window_shape=(1, 3, 3, 1),
                         strides=(1, 2, 2, 1),
                         padding='SAME')(net)
        if self._final_endpoint == end_point:
            return net

        # Residual units; each one is a possible early-exit endpoint.
        unit_specs = zip(self._channels, self._num_blocks, self._strides)
        for unit_id, (channels, num_blocks, stride) in enumerate(unit_specs):
            end_point = f'tsm_resnet_unit_{unit_id}'
            unit = TSMResNetUnit(
                output_channels=channels * self._width_mult,
                num_blocks=num_blocks,
                stride=stride,
                normalize_fn=self._normalize_fn,
                channel_shift_fraction=self._channel_shift_fraction,
                num_frames=num_frames,
                tsm_mode=tsm_mode,
                name=end_point)
            net = unit(net, is_training=is_training)
            if self._final_endpoint == end_point:
                return net

        # Final pre-activation-style normalisation + ReLU.
        if self._normalize_fn is not None:
            net = self._normalize_fn(net, is_training=is_training)
        net = jax.nn.relu(net)

        end_point = 'last_conv'
        if self._final_endpoint == end_point:
            return net

        # Average over axes 1 and 2, then the TSM temporal average of features.
        net = jnp.mean(net, axis=(1, 2))
        net = tsmu.prepare_outputs(net, tsm_mode, num_frames)
        assert self._final_endpoint == 'Embeddings'
        return net
コード例 #19
0
def conv(c):
    """Return a 3x3, stride-2 ``hk.Conv2D`` producing ``c`` output channels.

    Every other ``hk.Conv2D`` option (padding, bias, initializers, data
    format) is left at Haiku's default.
    """
    layer = hk.Conv2D(c, kernel_shape=3, stride=2)
    return layer
コード例 #20
0
ファイル: models.py プロジェクト: njunge94/jax-styletransfer
def augmented_vgg19(fp: str,
                    content_image: jnp.ndarray,
                    style_image: jnp.ndarray,
                    mean: jnp.ndarray,
                    std: jnp.ndarray,
                    content_layers: List[str] = None,
                    style_layers: List[str] = None,
                    pooling: str = "avg") -> hk.Sequential:
    """Build a VGG19 network augmented by content and style loss layers.

    The network is assembled from the parameters returned by
    ``get_model_params(fp=fp)``, in iteration order. Conv layers are named
    ``conv_1``, ``conv_2``, ... After every conv whose name appears in
    ``style_layers`` (resp. ``content_layers``), a ``StyleLoss``
    (resp. ``ContentLoss``) layer is inserted. Its target is the output of
    the network built so far, evaluated on ``style_image``
    (resp. ``content_image``). The returned network is truncated right after
    the last inserted loss layer.

    Args:
        fp: Path forwarded to ``get_model_params`` to load the weights.
        content_image: Image used to compute content targets; also passed
            to the leading ``Normalization`` layer.
        style_image: Image used to compute style targets.
        mean: Forwarded to the leading ``Normalization`` layer.
        std: Forwarded to the leading ``Normalization`` layer.
        content_layers: Conv layer names (e.g. ``"conv_4"``) after which a
            ``ContentLoss`` is inserted. ``None`` means none.
        style_layers: Conv layer names after which a ``StyleLoss`` is
            inserted. ``None`` means none.
        pooling: ``"avg"`` or ``"max"`` (case-insensitive). Selects the
            pooling module used for every parameter group whose key
            contains ``"pool"``.

    Returns:
        An ``hk.Sequential`` whose layers end at the last loss layer.

    Raises:
        ValueError: If ``pooling`` is not recognized, or if a loaded conv
            kernel is not square.
    """
    pooling = pooling.lower()
    if pooling not in ("avg", "max"):
        raise ValueError("Pooling method not recognized. Options are: "
                         "\"avg\", \"max\".")
    # Both pooling variants share one configuration; only the module class
    # and the name prefix ("avg_pool_" / "max_pool_") differ.
    pool_cls = hk.AvgPool if pooling == "avg" else hk.MaxPool

    params = get_model_params(fp=fp)

    # prepend a normalization layer
    layers = [Normalization(content_image, mean, std, "norm")]

    # number of conv layers added so far; used to build layer names
    n_conv = 0

    # desired depth layers to compute style/content losses
    content_layers = content_layers or []
    style_layers = style_layers or []

    # A single Sequential is reused. Its layers are replaced by the current
    # prefix whenever a loss target must be computed, and by the final
    # (truncated) layer list at the end.
    model = hk.Sequential(layers=layers)

    for k, p_dict in params.items():
        if "pool" in k:
            # Pools are named after the number of convs that precede them
            # (e.g. "avg_pool_2").
            layers.append(
                pool_cls(window_shape=2,
                         strides=2,
                         padding="VALID",
                         channel_axis=1,  # NCHW layout: channels on axis 1
                         name=f"{pooling}_pool_{n_conv}"))
        elif "conv" in k:
            n_conv += 1
            name = f"conv_{n_conv}"

            # weights are unpacked as (kernel_h, kernel_w, in_ch, out_ch)
            kernel_h, kernel_w, _, out_ch = p_dict["w"].shape

            # VGG only has square conv kernels. Raise rather than `assert`
            # so the check is not stripped under `python -O`.
            if kernel_h != kernel_w:
                raise ValueError(
                    "VGG19 only has square conv kernels, got "
                    f"{kernel_h}x{kernel_w} for parameter group {k!r}")

            layers.append(
                hk.Conv2D(output_channels=out_ch,
                          kernel_shape=kernel_h,
                          stride=1,
                          padding="SAME",
                          data_format="NCHW",
                          w_init=hk.initializers.Constant(p_dict["w"]),
                          b_init=hk.initializers.Constant(p_dict["b"]),
                          name=name))

            # Loss targets are the outputs of the network built so far
            # (ending at this conv, before its ReLU) on the target images.
            if name in style_layers:
                model.layers = tuple(layers)
                style_target = model(style_image, is_training=False)
                layers.append(
                    StyleLoss(target=style_target,
                              name=f"style_loss_{n_conv}"))

            if name in content_layers:
                model.layers = tuple(layers)
                content_target = model(content_image, is_training=False)
                layers.append(
                    ContentLoss(target=content_target,
                                name=f"content_loss_{n_conv}"))

            layers.append(jax.nn.relu)

    # Drop everything after the last loss layer, because later layers cannot
    # affect any loss. If no loss layer was inserted, the fallback index 0
    # keeps only the leading normalization layer.
    # NOTE(review): in that case the model computes no losses at all (e.g.
    # a typo in the requested layer names); consider raising instead.
    last_loss_idx = next(
        (i for i in range(len(layers) - 1, -1, -1)
         if isinstance(layers[i], (StyleLoss, ContentLoss))),
        0)
    layers = layers[:last_loss_idx + 1]

    model.layers = tuple(layers)

    return model
コード例 #21
0
     shape=(BATCH_SIZE, 2, 2)),
 ModuleDescriptor(
     name="ConvNDTranspose",
     create=lambda: hk.ConvNDTranspose(1, 3, 3),
     shape=(BATCH_SIZE, 2, 2)),
 ModuleDescriptor(
     name="Conv1D",
     create=lambda: hk.Conv1D(3, 3),
     shape=(BATCH_SIZE, 2, 2)),
 ModuleDescriptor(
     name="Conv1DTranspose",
     create=lambda: hk.Conv1DTranspose(3, 3),
     shape=(BATCH_SIZE, 2, 2)),
 ModuleDescriptor(
     name="Conv2D",
     create=lambda: hk.Conv2D(3, 3),
     shape=(BATCH_SIZE, 2, 2, 2)),
 ModuleDescriptor(
     name="Conv2DTranspose",
     create=lambda: hk.Conv2DTranspose(3, 3),
     shape=(BATCH_SIZE, 2, 2, 2)),
 ModuleDescriptor(
     name="Conv3D",
     create=lambda: hk.Conv3D(3, 3),
     shape=(BATCH_SIZE, 2, 2, 2, 2)),
 ModuleDescriptor(
     name="Conv3DTranspose",
     create=lambda: hk.Conv3DTranspose(3, 3),
     shape=(BATCH_SIZE, 2, 2, 2, 2)),
 ModuleDescriptor(
     name="DepthwiseConv2D",
コード例 #22
0
    def __init__(self,
                 channels: int,
                 stride: Union[int, Sequence[int]],
                 use_projection: bool,
                 bn_config: Mapping[str, float],
                 bottleneck: bool,
                 use_bn: bool = True,
                 transpose: bool = False,
                 name: Optional[str] = None):
        """Create the sub-modules of a (possibly transposed) residual block.

        Every convolution created here uses stride 1. Spatial resampling is
        instead held in ``self.pooling_or_upsampling``, which is presumably
        applied according to ``self.stride`` in the forward pass
        (NOTE(review): confirm against ``__call__``).

        Args:
            channels: Output channels of the projection shortcut and of the
                block's final conv. In bottleneck mode the inner convs use
                ``channels // 4``.
            stride: Stored as ``self.stride``; not passed to any conv here.
            use_projection: If True, create a 1x1 shortcut conv
                (``self.proj_conv``), plus ``self.proj_batchnorm`` when
                ``use_bn``.
            bn_config: Keyword arguments forwarded to every ``hk.BatchNorm``.
                When ``use_bn`` is True it is copied, and defaults are filled
                in for ``create_scale``, ``create_offset`` and
                ``decay_rate`` (0.999).
            bottleneck: If True, the main path is 1x1 -> 3x3 -> 1x1 convs.
                Otherwise it is two 3x3 convs.
            use_bn: Whether to create BatchNorm modules. Convs get a bias
                only when this is False.
            transpose: If True, ``self.pooling_or_upsampling`` is
                ``upsample``; otherwise it is a 2x2, stride-2, 'SAME'
                average pool.
            name: Module name passed to the Haiku base class.
        """
        super().__init__(name=name)
        self.use_projection = use_projection
        self.use_bn = use_bn
        if self.use_bn:
            # Copy so the caller's mapping is not mutated by setdefault.
            bn_config = dict(bn_config)
            bn_config.setdefault("create_scale", True)
            bn_config.setdefault("create_offset", True)
            bn_config.setdefault("decay_rate", 0.999)

        # Resampling module: upsample for the transposed (decoder-style)
        # block, 2x2 average pooling otherwise.
        if transpose:
            self.pooling_or_upsampling = upsample
        else:
            self.pooling_or_upsampling = hk.AvgPool(window_shape=2,
                                                    strides=2,
                                                    padding='SAME')
        self.stride = stride

        if self.use_projection:
            # this is just used for the skip connection
            self.proj_conv = hk.Conv2D(
                output_channels=channels,
                kernel_shape=1,
                # depending on whether it's transpose or not this stride must be
                # replaced by upsampling or avg pooling
                stride=1,
                # bias is redundant when a BatchNorm (with offset) follows
                with_bias=not self.use_bn,
                padding="SAME",
                name="shortcut_conv")
            if self.use_bn:
                self.proj_batchnorm = hk.BatchNorm(name="shortcut_batchnorm",
                                                   **bn_config)

        # Bottleneck blocks run the inner convs at a quarter of the width.
        channel_div = 4 if bottleneck else 1
        conv_0 = hk.Conv2D(output_channels=channels // channel_div,
                           kernel_shape=1 if bottleneck else 3,
                           stride=1,
                           with_bias=not self.use_bn,
                           padding="SAME",
                           name="conv_0")
        if self.use_bn:
            bn_0 = hk.BatchNorm(name="batchnorm_0", **bn_config)

        conv_1 = hk.Conv2D(output_channels=channels // channel_div,
                           kernel_shape=3,
                           stride=1,
                           with_bias=not self.use_bn,
                           padding="SAME",
                           name="conv_1")
        # Main path as (conv, batchnorm-or-None) pairs.
        if self.use_bn:
            bn_1 = hk.BatchNorm(name="batchnorm_1", **bn_config)
            layers = ((conv_0, bn_0), (conv_1, bn_1))
        else:
            layers = ((conv_0, None), (conv_1, None))

        if bottleneck:
            # Final 1x1 conv restores the full `channels` width.
            conv_2 = hk.Conv2D(output_channels=channels,
                               kernel_shape=1,
                               stride=1,
                               with_bias=not self.use_bn,
                               padding="SAME",
                               name="conv_2")
            if self.use_bn:
                # Zero-initialized scale, presumably so the residual branch
                # starts out contributing ~0 (TODO confirm intent).
                # NOTE(review): a "scale_init" key in bn_config would raise
                # TypeError here (duplicate keyword argument).
                bn_2 = hk.BatchNorm(name="batchnorm_2",
                                    scale_init=jnp.zeros,
                                    **bn_config)
                layers = layers + ((conv_2, bn_2), )
            else:
                layers = layers + ((conv_2, None), )

        # Tuple of (conv, batchnorm-or-None) pairs for the main path.
        self.layers = layers