Example #1
def _conv_with_fixed_kernel(params,
                            input_filters_or_mask,
                            output_filters_or_mask,
                            kernel_size,
                            strides=(1, 1),
                            activation=None,
                            use_batch_norm=True):
  """Construct a Conv2D, followed by a batch norm and optional activation."""
  result = []
  if (isinstance(input_filters_or_mask, int) and
      isinstance(output_filters_or_mask, int)):
    result.append(
        layers.Conv2D(
            filters=output_filters_or_mask,
            kernel_size=kernel_size,
            kernel_initializer=params['kernel_initializer'],
            kernel_regularizer=params['kernel_regularizer'],
            use_bias=not use_batch_norm,
            strides=strides))
  else:
    result.append(
        layers.MaskedConv2D(
            input_mask=_to_filters_mask(input_filters_or_mask),
            output_mask=_to_filters_mask(output_filters_or_mask),
            kernel_size=kernel_size,
            kernel_initializer=params['kernel_initializer'],
            kernel_regularizer=params['kernel_regularizer'],
            use_bias=not use_batch_norm,
            strides=strides))

  if use_batch_norm:
    result.append(_batch_norm(params, output_filters_or_mask))
  if activation is not None:
    result.append(activation)
  return layers.Sequential(result)
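
A minimal usage sketch for the static path above (both filter arguments are Python ints), assuming TensorFlow 2.x and that the enclosing module's layers package is in scope. The two params keys shown are the only ones this function reads directly; the omitted _batch_norm helper may read more, so batch norm is disabled here, and the initializer/regularizer choices are purely illustrative:

import tensorflow as tf

params = {
    'kernel_initializer': tf.keras.initializers.HeNormal(),
    'kernel_regularizer': tf.keras.regularizers.L2(1e-5),
}

# Static path: both filter counts are ints, so a plain Conv2D is constructed.
block = _conv_with_fixed_kernel(
    params,
    input_filters_or_mask=32,
    output_filters_or_mask=64,
    kernel_size=(3, 3),
    strides=(2, 2),
    use_batch_norm=False)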
Example #2
def _squeeze_and_excite(params,
                        input_filters_or_mask,
                        inner_activation,
                        gating_activation):
  """Generate a squeeze-and-excite layer."""
  # We provide two code paths:
  # 1. For the case where the number of input filters is known at graph
  #    construction time, and input_filters_or_mask is an int. This typically
  #    happens during stand-alone model training.
  # 2. For the case where the number of input filters is not known until
  #    runtime, and input_filters_or_mask is a 1D float tensor. This often
  #    happens during an architecture search.
  if isinstance(input_filters_or_mask, int):
    input_filters = input_filters_or_mask
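    # Round the squeezed width to a multiple of params['filters_base']
    # (MobileNet-style make_divisible), keeping the hidden width
    # hardware-friendly.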
    hidden_filters = search_space_utils.make_divisible(
        input_filters * _SQUEEZE_AND_EXCITE_RATIO,
        divisor=params['filters_base'])

    return layers.ParallelProduct([
        layers.Identity(),
        layers.Sequential([
            layers.GlobalAveragePool(keepdims=True),
            layers.Conv2D(
                filters=hidden_filters,
                kernel_size=(1, 1),
                kernel_initializer=params['kernel_initializer'],
                kernel_regularizer=params['kernel_regularizer'],
                use_bias=True),
            inner_activation,
            layers.Conv2D(
                filters=input_filters,
                kernel_size=(1, 1),
                kernel_initializer=params['kernel_initializer'],
                kernel_regularizer=params['kernel_regularizer'],
                use_bias=True),
            gating_activation,
        ]),
    ])
  else:
    input_mask = input_filters_or_mask
    input_filters = tf.reduce_sum(input_mask)
    hidden_filters = search_space_utils.tf_make_divisible(
        input_filters * _SQUEEZE_AND_EXCITE_RATIO,
        divisor=params['filters_base'])

    max_input_filters = int(input_mask.shape[0])
    max_hidden_filters = search_space_utils.make_divisible(
        max_input_filters * _SQUEEZE_AND_EXCITE_RATIO,
        divisor=params['filters_base'])

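    # tf.sequence_mask turns the (possibly dynamic) hidden filter count into a
    # float vector whose first hidden_filters entries are 1.0 and the rest 0.0.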
    hidden_mask = tf.sequence_mask(
        hidden_filters, max_hidden_filters, dtype=tf.float32)

    return layers.ParallelProduct([
        layers.Identity(),
        layers.Sequential([
            layers.GlobalAveragePool(keepdims=True),
            layers.MaskedConv2D(
                input_mask=input_mask,
                output_mask=hidden_mask,
                kernel_size=(1, 1),
                kernel_initializer=params['kernel_initializer'],
                kernel_regularizer=params['kernel_regularizer'],
                use_bias=True),
            inner_activation,
            layers.MaskedConv2D(
                input_mask=hidden_mask,
                output_mask=input_mask,
                kernel_size=(1, 1),
                kernel_initializer=params['kernel_initializer'],
                kernel_regularizer=params['kernel_regularizer'],
                use_bias=True),
            gating_activation,
        ]),
    ])
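
To make the dynamic path concrete, here is a worked sketch of the mask arithmetic on its own, using only plain TensorFlow ops that already appear above. The squeeze-and-excite ratio of 0.25 and filters_base of 8 are assumptions, since the module-level constant _SQUEEZE_AND_EXCITE_RATIO and the params dict are not shown:

import tensorflow as tf

SE_RATIO = 0.25   # assumed value of _SQUEEZE_AND_EXCITE_RATIO
FILTERS_BASE = 8  # assumed value of params['filters_base']

max_input_filters = 64
# Suppose the architecture search keeps 48 of the 64 possible input filters.
input_mask = tf.sequence_mask(48, max_input_filters, dtype=tf.float32)
input_filters = tf.reduce_sum(input_mask)  # 48.0

# MobileNet-style make_divisible rounding, inlined for illustration (the real
# tf_make_divisible helper also guards against dropping below 90% of raw):
raw = input_filters * SE_RATIO  # 12.0
hidden_filters = tf.maximum(
    FILTERS_BASE,
    tf.cast(raw + FILTERS_BASE / 2, tf.int32) // FILTERS_BASE * FILTERS_BASE)
# hidden_filters == 16: 12 rounds up to the next multiple of 8.

max_hidden_filters = 16  # make_divisible(64 * 0.25, 8)
hidden_mask = tf.sequence_mask(
    hidden_filters, max_hidden_filters, dtype=tf.float32)
# hidden_mask == [1.0] * 16: every hidden channel is active in this case.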