Example 1
    def apply(self,
              inputs,
              num_classes,
              filter_shape=(5, 5),
              filters=(16, 32),
              dense_size=64,
              train=True,
              init_fn=flax.nn.initializers.kaiming_normal,
              activation_fn=flax.nn.relu,
              masks=None,
              masked_layer_indices=None):
        """Applies a convolution to the inputs.

    Args:
      inputs: Input data with dimensions (batch, spatial_dims..., features).
      num_classes: Number of classes in the dataset.
      filter_shape: Shape of the convolutional filters.
      filters: Number of filters in each convolutional layer, and number of conv
        layers (given by length of sequence).
      dense_size: Number of filters in each convolutional layer, and number of
        conv layers (given by length of sequence).
      train: If model is being evaluated in training mode or not.
      init_fn: Initialization function used for convolutional layers.
      activation_fn: Activation function to be used for convolutional layers.
      masks: Masks of the layers in this model, in the same form as
             module params, or None.
      masked_layer_indices: The layer indices of layers in model to be masked.

    Returns:
      A tensor of shape (batch, num_classes), containing the logit output.
    Raises:
      ValueError if the number of pooling layers is too many for the given input
        size.
    """
        # Note: First dim is batch, last dim is channels, other dims are "spatial".
        if not all(dim >= 2**len(filters) for dim in inputs.shape[1:-1]):
            raise ValueError(
                f'Input spatial size, {inputs.shape[1:-1]}, does not allow '
                f'{len(filters)} pooling layers.')

        depth = 2 + len(filters)
        masks = masked.generate_model_masks(depth, masks, masked_layer_indices)
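        # generate_model_masks returns a dict keyed by module name (e.g.
        # 'MaskedModule_0'); layers without an entry are applied unmasked.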

        batch_norm = flax.nn.BatchNorm.partial(use_running_average=not train,
                                               momentum=0.99,
                                               epsilon=1e-5)
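        # batch_norm uses batch statistics during training and running
        # averages during evaluation (use_running_average=not train).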

        for i, filter_num in enumerate(filters):
            if f'MaskedModule_{i}' in masks:
                logging.info('Layer %d is masked in model', i)
                mask = masks[f'MaskedModule_{i}']
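                # masked.masked wraps Conv so the given mask is applied to its
                # parameters; sparse_init adapts the base initializer to the
                # kernel's mask.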
                inputs = masked.masked(flax.nn.Conv, mask)(
                    inputs,
                    features=filter_num,
                    kernel_size=filter_shape,
                    kernel_init=init.sparse_init(
                        init_fn(),
                        mask['kernel'] if mask is not None else None))
            else:
                inputs = flax.nn.Conv(inputs,
                                      features=filter_num,
                                      kernel_size=filter_shape,
                                      kernel_init=init_fn())
            inputs = batch_norm(inputs, name=f'bn_conv_{i}')
            inputs = activation_fn(inputs)

            if i < len(filters) - 1:
                inputs = flax.nn.max_pool(inputs,
                                          window_shape=(2, 2),
                                          strides=(2, 2),
                                          padding='VALID')

        # Global average pool at end of convolutional layers.
        inputs = flax.nn.avg_pool(inputs,
                                  window_shape=inputs.shape[1:-1],
                                  padding='VALID')

        # This is effectively a Dense layer, but we cast it as a convolution layer
        # to allow us to easily propagate masks, avoiding b/156135283.
        if f'MaskedModule_{depth - 2}' in masks:
            mask_dense_1 = masks[f'MaskedModule_{depth - 2}']
            inputs = masked.masked(flax.nn.Conv, mask_dense_1)(
                inputs,
                features=dense_size,
                kernel_size=inputs.shape[1:-1],
                kernel_init=init.sparse_init(
                    init_fn(), mask_dense_1['kernel']
                    if mask_dense_1 is not None else None))
        else:
            inputs = flax.nn.Conv(inputs,
                                  features=dense_size,
                                  kernel_size=inputs.shape[1:-1],
                                  kernel_init=init_fn())
        inputs = batch_norm(inputs, name='bn_dense_1')
        inputs = activation_fn(inputs)

        inputs = flax.nn.Dense(
            inputs,
            features=num_classes,
            kernel_init=flax.nn.initializers.xavier_normal())
        inputs = batch_norm(inputs, name='bn_dense_2')
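        # Squeeze out the singleton spatial dimensions left by the 1x1
        # convolutions, giving shape (batch, num_classes).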
        inputs = jnp.squeeze(inputs)
        return flax.nn.log_softmax(inputs)
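A minimal usage sketch for the method above, assuming the pre-Linen Flax API
(flax.nn.Module, init_by_shape, flax.nn.stateful) and a hypothetical wrapper
class named CNN containing this apply(); BatchNorm keeps its running
statistics in a state collection, hence the stateful() scopes:

import jax
import jax.numpy as jnp
import flax

# CNN is a hypothetical flax.nn.Module subclass holding the apply() above.
module = CNN.partial(num_classes=10, train=False)
rng = jax.random.PRNGKey(0)
with flax.nn.stateful() as init_state:
    # Initialize parameters from the shape of a dummy 28x28x1 image batch.
    _, params = module.init_by_shape(rng, [((8, 28, 28, 1), jnp.float32)])
model = flax.nn.Model(module, params)
with flax.nn.stateful(init_state, mutable=False):
    log_probs = model(jnp.ones((8, 28, 28, 1)))  # shape (8, 10)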
Example 2
    def apply(self,
              inputs,
              num_classes,
              features=(32, 32),
              train=True,
              init_fn=flax.deprecated.nn.initializers.kaiming_normal,
              activation_fn=flax.deprecated.nn.relu,
              masks=None,
              masked_layer_indices=None,
              dropout_rate=0.):
        """Applies fully-connected neural network to the inputs.

    Args:
      inputs: Input data with dimensions (batch, features), if features has more
        than one dimension, it is flattened.
      num_classes: Number of classes in the dataset.
      features: Number of neurons in each layer, and number of layers (given by
        length of sequence) + one layer for softmax.
      train: If model is being evaluated in training mode or not.
      init_fn: Initialization function used for dense layers.
      activation_fn: Activation function to be used for dense layers.
      masks: Masks of the layers in this model, in the same form as module
        params, or None.
      masked_layer_indices: The layer indices of layers in model to be masked.
      dropout_rate: Dropout rate, if 0 then dropout is not used (default).

    Returns:
      A tensor of shape (batch, num_classes), containing the logit output.
    """
        batch_norm = flax.deprecated.nn.BatchNorm.partial(
            use_running_average=not train, momentum=0.99, epsilon=1e-5)

        depth = 1 + len(features)
        masks = masked.generate_model_masks(depth, masks, masked_layer_indices)

        # If inputs are in image dimensions, flatten image.
        inputs = inputs.reshape(inputs.shape[0], -1)

        for i, feature_num in enumerate(features):
            if f'MaskedModule_{i}' in masks:
                logging.info('Layer %d is masked in model', i)
                mask = masks[f'MaskedModule_{i}']
                inputs = masked.masked(flax.deprecated.nn.Dense, mask)(
                    inputs,
                    features=feature_num,
                    kernel_init=init.sparse_init(
                        init_fn(),
                        mask['kernel'] if mask is not None else None))
            else:
                inputs = flax.deprecated.nn.Dense(inputs,
                                                  features=feature_num,
                                                  kernel_init=init_fn())
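            # Note: the batch norm name keeps the 'bn_conv' prefix even though
            # these are dense layers.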
            inputs = batch_norm(inputs, name=f'bn_conv_{i}')
            inputs = activation_fn(inputs)
            if dropout_rate > 0.0:
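                # The deprecated dropout draws its RNG from an enclosing
                # flax.deprecated.nn.stochastic() scope when deterministic is
                # False (i.e. during training).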
                inputs = flax.deprecated.nn.dropout(inputs,
                                                    dropout_rate,
                                                    deterministic=not train)

        inputs = flax.deprecated.nn.Dense(
            inputs,
            features=num_classes,
            kernel_init=flax.deprecated.nn.initializers.xavier_normal())

        return flax.deprecated.nn.log_softmax(inputs)
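Again a minimal usage sketch, under the same assumptions (the wrapper class
name MLP is hypothetical; the deprecated pre-Linen API is used as imported
above). Dropout needs an RNG from a stochastic() scope, and batch norm state
lives in a stateful() scope:

import jax
import jax.numpy as jnp
import flax

# MLP is a hypothetical flax.deprecated.nn.Module holding the apply() above.
module = MLP.partial(num_classes=10, dropout_rate=0.1)
rng = jax.random.PRNGKey(0)
with flax.deprecated.nn.stateful() as init_state:
    with flax.deprecated.nn.stochastic(rng):
        _, params = module.init_by_shape(rng, [((8, 784), jnp.float32)])
model = flax.deprecated.nn.Model(module, params)
with flax.deprecated.nn.stateful(init_state) as new_state:
    with flax.deprecated.nn.stochastic(rng):
        # train defaults to True, so new_state holds updated batch norm stats.
        log_probs = model(jnp.ones((8, 784)))  # shape (8, 10)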