Example #1
 def apply(self, x):
     x = nn.Dense(x, features=32)
     x = nn.sigmoid(x)
     x = nn.Dense(x, features=32)
     x = nn.sigmoid(x)
     x = nn.Dense(x, features=1)
     return nn.sigmoid(x)
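A minimal sketch for orientation (not from the source): the nn.sigmoid used throughout these pre-Linen-style Flax examples is the elementwise logistic function, numerically the same as jax.nn.sigmoid.

# Sketch (assumption): nn.sigmoid is the elementwise logistic 1 / (1 + exp(-x)).
import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)
assert jnp.allclose(jax.nn.sigmoid(x), 1.0 / (1.0 + jnp.exp(-x)))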
Example #2
def apply_activation(intermediate_output, intermediate_activation):
    """Applies selected activation function to intermediate output."""
    if intermediate_activation is None:
        return intermediate_output

    if intermediate_activation == 'gelu':
        intermediate_output = nn.gelu(intermediate_output)
    elif intermediate_activation == 'relu':
        intermediate_output = nn.relu(intermediate_output)
    elif intermediate_activation == 'sigmoid':
        intermediate_output = nn.sigmoid(intermediate_output)
    elif intermediate_activation == 'softmax':
        intermediate_output = nn.softmax(intermediate_output)
    elif intermediate_activation == 'celu':
        intermediate_output = nn.celu(intermediate_output)
    elif intermediate_activation == 'elu':
        intermediate_output = nn.elu(intermediate_output)
    elif intermediate_activation == 'log_sigmoid':
        intermediate_output = nn.log_sigmoid(intermediate_output)
    elif intermediate_activation == 'log_softmax':
        intermediate_output = nn.log_softmax(intermediate_output)
    elif intermediate_activation == 'soft_sign':
        intermediate_output = nn.soft_sign(intermediate_output)
    elif intermediate_activation == 'softplus':
        intermediate_output = nn.softplus(intermediate_output)
    elif intermediate_activation == 'swish':
        intermediate_output = nn.swish(intermediate_output)
    elif intermediate_activation == 'tanh':
        intermediate_output = jnp.tanh(intermediate_output)
    else:
        raise NotImplementedError(
            '%s activation function is not yet supported.' %
            intermediate_activation)

    return intermediate_output
Example #3
  def apply(self, x, rep_size, m_layers, m_features, m_kernel_sizes, conv_rep_size, padding_mask=None):

    H_0 = nn.relu(nn.Dense(x, conv_rep_size))
    G_0 = nn.relu(nn.Dense(x, conv_rep_size))
    H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2)

    for layer in range(1, m_layers+1):
      
      if layer < m_layers:
        H_features, G_features = m_features[layer-1]
      else:
        H_features, G_features = conv_rep_size, conv_rep_size
      
      H_kernel_size, G_kernel_size = m_kernel_sizes[layer-1]

      H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1))
      G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1)) 

      if layer < m_layers:
        H = nn.relu(H)
        G = nn.relu(G)
      else:
        H = nn.tanh(H)
        G = nn.sigmoid(G)

    H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2)
    
    F = H * G + G_0
    
    rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size)
    
    return rep
Example #4
 def apply(self, x, reduction=16):
     num_channels = x.shape[-1]
     y = x.mean(axis=(1, 2))
     y = nn.Dense(y, features=num_channels // reduction, bias=False)
     y = nn.relu(y)
     y = nn.Dense(y, features=num_channels, bias=False)
     y = nn.sigmoid(y)
     return x * y[:, None, None, :]
Example #5
def multilabel_accuracy(logits, labels, stats):
    """Strict multilabel classification accuracy (all labels have to be correct).

  Args:
    logits: Tensor, shape [batch_size, num_classes]
    labels: Tensor, shape [batch_size, num_classes], values in {0, 1}
    stats: Dict of statistics output by the model.

  Returns:
    per_sample_success: Tensor of shape [batch_size]
  """
    del stats
    error = jnp.abs(labels - jnp.round(nn.sigmoid(logits)))
    return 1. - jnp.max(error, axis=-1)
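A hypothetical usage sketch (the inputs below are made up, not from the source): the per-sample score is 1 only when every class prediction matches its label after thresholding the sigmoid at 0.5.

# Illustrative check with hypothetical data: any single wrong label zeroes the score.
import jax.numpy as jnp

logits = jnp.array([[2.0, -3.0],   # sigmoid ~[0.88, 0.05] -> rounds to [1, 0]
                    [2.0, 3.0]])   # sigmoid ~[0.88, 0.95] -> rounds to [1, 1]
labels = jnp.array([[1.0, 0.0],
                    [1.0, 0.0]])
print(multilabel_accuracy(logits, labels, stats={}))  # [1., 0.]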
Example #6
    def apply_params(self, c, params):
        # Make an activation array of shape [decision_day, location] which satisfies
        # the pivot table constraint.
        site_activation = nn.sigmoid(params)
        for i in range(len(self.decision_day_idx)):
            new_activation, persistence_and_deactivation = (
                self._new_and_old_activation(site_activation[:i + 1]))
            for group, group_idx in self.group_to_idx.items():
                num_allowed = self.allowed_activations.loc[group].iloc[i]
                num_activated = new_activation[group_idx].sum()
                if num_activated > num_allowed:
                    squashed = (num_allowed /
                                num_activated) * new_activation[group_idx]
                    new_activation = jax.ops.index_update(
                        new_activation, jax.ops.index[group_idx], squashed)
                elif self.force_hit_cap:
                    # Same idea as above, but squash towards 1 instead of towards 0.
                    max_new_possible = 1 - persistence_and_deactivation[
                        group_idx]
                    excess_available = max_new_possible - new_activation[
                        group_idx]
                    total_excess_available = excess_available.sum()
                    total_excess_required = num_allowed - num_activated
                    scaler = 1.0
                    if total_excess_available > total_excess_required:
                        scaler = total_excess_required / total_excess_available
                    squashed = new_activation[
                        group_idx] + scaler * excess_available
                    new_activation = jax.ops.index_update(
                        new_activation, jax.ops.index[group_idx], squashed)
            adjusted_activation = persistence_and_deactivation + new_activation
            site_activation = jax.ops.index_update(site_activation,
                                                   jax.ops.index[i, :],
                                                   adjusted_activation)

        # Expand from [decision_day, location] to [time, location].
        c.site_activation = jnp.zeros(c.site_activation.shape)
        for i in range(len(self.decision_day_idx) - 1):
            day_idx = self.decision_day_idx[i]
            next_day_idx = self.decision_day_idx[i + 1]
            c.site_activation = jax.ops.index_update(
                c.site_activation, jax.ops.index[:, day_idx:next_day_idx],
                site_activation[i, :, jnp.newaxis])
        c.site_activation = jax.ops.index_update(
            c.site_activation, jax.ops.index[:, self.decision_day_idx[-1]:],
            site_activation[-1, :, jnp.newaxis])
Example #7
def GatedResnet(inputs,
                aux=None,
                conv_module=None,
                nonlinearity=concat_elu,
                dropout_p=0.):
    c = inputs.shape[-1]
    y = conv_module(nonlinearity(inputs), c)
    if aux is not None:
        y = nonlinearity(y + ConvOneByOne(nonlinearity(aux), c))

    if dropout_p > 0:
        y = nn.dropout(y, dropout_p)

    # Set init_scale=0.1 so that the res block is close to the identity at
    # initialization.
    a, b = np.split(conv_module(y, 2 * c, init_scale=0.1), 2, axis=-1)
    return inputs + a * nn.sigmoid(b)
Example #8
def sigmoid_mean_squared_error(logits, targets, weights=None):
    """Computes the sigmoid mean squared error between logits and targets.

  Args:
    logits: float array of shape (batch, output_shape)
    targets: float array of shape (batch, output_shape)
    weights: None or float array of shape (batch,)

  Returns:
    float array of sigmoid mean squared error between logits and targets
    of shape (batch, output_shape)
  """
    loss = jnp.sum(jnp.square(nn.sigmoid(logits) - targets).reshape(
        targets.shape[0], -1),
                   axis=-1)
    if weights is None:
        weights = jnp.ones(loss.shape[0])
    weights = weights / sum(weights)
    return jnp.sum(jnp.dot(loss, weights))
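A hypothetical call (shapes are made up): per-example squared errors are summed over the output dimensions, then averaged with the normalized weights, so the result is a scalar.

# Illustrative call with made-up shapes; sigmoid(0) == 0.5, so the loss here is 0.
import jax.numpy as jnp

logits = jnp.zeros((4, 10))        # batch of 4 examples, 10 outputs each
targets = jnp.full((4, 10), 0.5)
print(sigmoid_mean_squared_error(logits, targets))  # 0.0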
Example #9
 def loss_fn(control_arm_events):
     cum_events = control_arm_events.cumsum(axis=-1)
     successiness = nn.sigmoid((cum_events - center) / width)
     return -successiness.mean(axis=-1)
Example #10
 def apply(self, x):
     tau = self.param('tau', x.shape[-1:], nn.initializers.ones)
     return x * nn.sigmoid(x * tau)
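A sanity sketch (assumption, not part of the source): with tau at its ones initialization, this module reduces to the standard swish/SiLU activation x * sigmoid(x).

# Assumed equivalence at initialization (tau == 1): x * sigmoid(x) is swish/SiLU.
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, 0.0, 2.0])
assert jnp.allclose(x * jax.nn.sigmoid(x * 1.0), jax.nn.swish(x))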
Example #11
 def apply(self, x):
     return nn.sigmoid(x) * nn.sigmoid(-x) * nn.tanh(x) * (1 / 0.15)
Example #12
 def generate(self, z, **unused_kwargs):
     decoder = self._create_decoder()
     return nn.sigmoid(decoder(z))
Example #13
 def generate_one_liner(self, z):
     return nn.sigmoid(Decoder(z, name='decoder'))
Example #14
def differentiable_greater_than(x, threshold, width):
    """A smoothed version of (x > threshold).astype(float)."""
    return nn.sigmoid((x - threshold) / width)
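An illustrative sketch (values are made up): shrinking width sharpens the sigmoid toward the hard step it approximates.

# Hypothetical values: smaller width -> output closer to a hard 0/1 step.
import jax.numpy as jnp

x = jnp.array([-1.0, -0.1, 0.1, 1.0])
print(differentiable_greater_than(x, threshold=0.0, width=1.0))   # gentle slope
print(differentiable_greater_than(x, threshold=0.0, width=0.01))  # ~[0., 0., 1., 1.]
print((x > 0.0).astype(jnp.float32))                              # hard threshold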
Example #15
    def apply(self, x, communication=Communication.NONE, train=True):
        """Forward pass."""
        batch_size = x.shape[0]

        if communication is Communication.SQUEEZE_EXCITE_X:
            x = sample_patches.SqueezeExciteLayer(x)
        # end if squeeze excite x

        d1 = nn.relu(
            nn.Conv(x,
                    128,
                    kernel_size=(3, 3),
                    strides=(1, 1),
                    bias=True,
                    name="down1"))
        d2 = nn.relu(
            nn.Conv(d1,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down2"))
        d3 = nn.relu(
            nn.Conv(d2,
                    128,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    bias=True,
                    name="down3"))

        if communication is Communication.SQUEEZE_EXCITE_D:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            num_channels = d_together.shape[-1]
            y = d_together.mean(axis=1)
            y = nn.Dense(y, features=num_channels // 4, bias=False)
            y = nn.relu(y)
            y = nn.Dense(y, features=num_channels, bias=False)
            y = nn.sigmoid(y)

            d_together = d_together * y[:, None, :]

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        elif communication is Communication.TRANSFORMER:
            d1_flatten = einops.rearrange(d1, "b h w c -> b (h w) c")
            d2_flatten = einops.rearrange(d2, "b h w c -> b (h w) c")
            d3_flatten = einops.rearrange(d3, "b h w c -> b (h w) c")

            nd1 = d1_flatten.shape[1]
            nd2 = d2_flatten.shape[1]

            d_together = jnp.concatenate([d1_flatten, d2_flatten, d3_flatten],
                                         axis=1)

            positional_encodings = self.param(
                "scale_ratio_position_encodings",
                shape=(1, ) + d_together.shape[1:],
                initializer=jax.nn.initializers.normal(1. /
                                                       d_together.shape[-1]))
            d_together = transformer.Transformer(d_together +
                                                 positional_encodings,
                                                 num_layers=2,
                                                 num_heads=8,
                                                 is_training=train)

            # split and reshape
            d1 = d_together[:, :nd1].reshape(d1.shape)
            d2 = d_together[:, nd1:nd1 + nd2].reshape(d2.shape)
            d3 = d_together[:, nd1 + nd2:].reshape(d3.shape)

        t1 = nn.Conv(d1,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy1")
        t2 = nn.Conv(d2,
                     6,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy2")
        t3 = nn.Conv(d3,
                     9,
                     kernel_size=(1, 1),
                     strides=(1, 1),
                     bias=True,
                     name="tidy3")

        raw_scores = (jnp.split(t1, 6, axis=-1) + jnp.split(t2, 6, axis=-1) +
                      jnp.split(t3, 9, axis=-1))

        # The following is for normalization.
        t = jnp.concatenate((jnp.reshape(
            t1, [batch_size, -1]), jnp.reshape(
                t2, [batch_size, -1]), jnp.reshape(t3, [batch_size, -1])),
                            axis=1)
        t_min = jnp.reshape(jnp.min(t, axis=-1), [batch_size, 1, 1, 1])
        t_max = jnp.reshape(jnp.max(t, axis=-1), [batch_size, 1, 1, 1])
        normalized_scores = zeroone(raw_scores, t_min, t_max)

        stats = {
            "scores": normalized_scores,
            "raw_scores": t,
        }
        # removes the split dimension. scores are now b x h' x w' shaped
        normalized_scores = [s.squeeze(-1) for s in normalized_scores]

        return normalized_scores, stats
Example #16
  def apply(self,
            x,
            layer=LAYER_EVONORM_B0,
            nonlinearity=True,
            num_groups=32,
            group_size=None,
            batch_stats=None,
            use_running_average=False,
            axis=-1,
            momentum=0.99,
            epsilon=1e-5,
            dtype=jnp.float32,
            bias=True,
            scale=True,
            bias_init=initializers.zeros,
            scale_init=initializers.ones,
            axis_name=None,
            axis_index_groups=None):
    """Normalizes the input using batch statistics.

    Args:
      x: the input to be normalized.
      layer: LAYER_EVONORM_B0 or LAYER_EVONORM_S0.
      nonlinearity: use the EvoNorm nonlinearity.
      num_groups: number of groups to use for group statistics.
      group_size: size of groups, see nn.GroupNorm.
      batch_stats: a `flax.nn.Collection` used to store an exponential moving
        average of the batch statistics (default: None).
      use_running_average: if true, the statistics stored in batch_stats will be
        used instead of computing the batch statistics on the input.
      axis: the feature or non-batch axis of the input.
      momentum: decay rate for the exponential moving average of the batch
        statistics.
      epsilon: a small float added to variance to avoid dividing by zero.
      dtype: the dtype of the computation (default: float32).
      bias:  if True, bias (beta) is added.
      scale: if True, multiply by scale (gamma). When the next layer is linear
        (also e.g. nn.relu), this can be disabled since the scaling will be done
        by the next layer.
      bias_init: initializer for bias, by default, zero.
      scale_init: initializer for scale, by default, one.
      axis_name: the axis name used to combine batch statistics from multiple
        devices. See `jax.pmap` for a description of axis names (default: None).
      axis_index_groups: groups of axis indices within that named axis
        representing subsets of devices to reduce over (default: None). For
          example, `[[0, 1], [2, 3]]` would independently batch-normalize over
          the examples on the first two and last two devices. See `jax.lax.psum`
          for more details.

    Returns:
      Normalized inputs (the same shape as inputs).
    """
    x = jnp.asarray(x, jnp.float32)

    axis = axis if isinstance(axis, tuple) else (axis,)
    # pylint: disable=protected-access
    axis = nn.normalization._absolute_dims(x.ndim, axis)
    # pylint: enable=protected-access
    feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
    reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
    instance_reduction_axis = tuple(
        i for i in range(x.ndim) if i not in axis and i > 0)
    batch_reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

    if nonlinearity:
      v = self.param('v', reduced_feature_shape,
                     jax.nn.initializers.ones).reshape(feature_shape)
      if layer == LAYER_EVONORM_S0:
        den, group_shape, input_shape = _GroupStd(
            x,
            num_groups=num_groups,
            group_size=group_size,
            epsilon=epsilon,
            dtype=dtype,
        )
        x = x * nn.sigmoid(v * x)
        x = x.reshape(group_shape)
        x /= den
        x = x.reshape(input_shape)
      elif layer == LAYER_EVONORM_B0:
        if self.is_stateful() or batch_stats:
          ra_var = self.state(
              'var',
              reduced_feature_shape,
              initializers.ones,
              collection=batch_stats)
        else:
          ra_var = None

        if use_running_average:
          if ra_var is None:
            raise ValueError(
                'when use_running_averages is True '
                'either use a stateful context or provide batch_stats')
          var = ra_var.value
        else:
          mean = jnp.mean(x, axis=batch_reduction_axis, keepdims=False)
          mean2 = jnp.mean(
              lax.square(x), axis=batch_reduction_axis, keepdims=False)
          if axis_name is not None and not self.is_initializing():
            concatenated_mean = jnp.concatenate([mean, mean2])
            mean, mean2 = jnp.split(
                lax.pmean(
                    concatenated_mean,
                    axis_name=axis_name,
                    axis_index_groups=axis_index_groups), 2)
          var = mean2 - lax.square(mean)

          if ra_var and not self.is_initializing():
            ra_var.value = momentum * ra_var.value + (1 - momentum) * var

        left = lax.sqrt(var + epsilon)

        instance_std = jnp.sqrt(
            x.var(axis=instance_reduction_axis, keepdims=True) + epsilon)
        right = v * x + instance_std
        x = x / jnp.maximum(left, right)
      else:
        raise ValueError('Unknown EvoNorm layer: {}'.format(layer))

    if scale:
      x *= self.param('scale', reduced_feature_shape,
                      scale_init).reshape(feature_shape)
    if bias:
      x = x + self.param('bias', reduced_feature_shape,
                         bias_init).reshape(feature_shape)
    return jnp.asarray(x, dtype)
Example #17
 def generate_shared(self, z):
     return nn.sigmoid(self._created_decoder()(z))
Example #18
 def apply_params(self, c, params):
     params = jnp.broadcast_to(params[:, jnp.newaxis],
                               c.site_capacity.shape)
     c.site_activation = nn.sigmoid(params)
Example #19
 def apply_params(self, c, params):
     c.site_activation = nn.sigmoid(params)
Example #20
 def generate(self, z):
     params = self.get_param('decoder')
     return nn.sigmoid(Decoder.call(params, z))