Example #1
    def apply(
        self,
        x,
        action_dim,
        max_action,
        key=None,
        MPO=False,
        sample=False,
        log_sig_min=-20,
        log_sig_max=2,
    ):
        x = nn.Dense(x, features=200)
        x = nn.LayerNorm(x)
        x = nn.tanh(x)
        x = nn.Dense(x, features=200)
        x = nn.elu(x)
        x = nn.Dense(x, features=2 * action_dim)

        mu, log_sig = jnp.split(x, 2, axis=-1)
        log_sig = nn.softplus(log_sig)
        log_sig = jnp.clip(log_sig, log_sig_min, log_sig_max)

        if MPO:
            return mu, log_sig

        if not sample:
            # Deterministic path: squash the mean into the action range.
            return max_action * nn.tanh(mu), log_sig
        else:
            # Reparameterized sample, tanh-squashed, with the SAC-style
            # log-probability correction for the squashing.
            pi = mu + random.normal(key, mu.shape) * jnp.exp(log_sig)
            log_pi = gaussian_likelihood(pi, mu, log_sig)
            pi = nn.tanh(pi)
            log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1)
            return max_action * pi, log_pi
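
The `gaussian_likelihood` helper above is defined elsewhere in the source file; a minimal sketch of the standard diagonal-Gaussian log-density it presumably computes (the epsilon and the reduction axis are assumptions):

import jax.numpy as jnp

# Assumed implementation, not the original helper: log-density of a
# diagonal Gaussian, summed over the action dimension (axis=1, matching
# the reduction in the tanh correction above).
def gaussian_likelihood(sample, mu, log_sig):
    pre_sum = -0.5 * (((sample - mu) / (jnp.exp(log_sig) + 1e-6)) ** 2
                      + 2.0 * log_sig + jnp.log(2.0 * jnp.pi))
    return jnp.sum(pre_sum, axis=1)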
Example #2
 def apply(self, x):
     x = nn.Dense(x, features=50)
     x = nn.tanh(x)
     x = nn.Dense(x, features=50)
     x = nn.tanh(x)
     x = nn.Dense(x, features=50)
     x = nn.tanh(x)
     x = nn.Dense(x, features=50)
     x = nn.tanh(x)
     x = nn.Dense(x, features=1)
     return x
Example #3
 def apply(self, x, action_dim, max_action):
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=256)
     x = nn.relu(x)
     x = nn.Dense(x, features=action_dim)
     return max_action * nn.tanh(x)
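
These snippets use the pre-Linen `flax.nn` Module API, in which hyperparameters are arguments of `apply` and are bound with `Module.partial`. A sketch of how the actor above might be instantiated; the `Actor` class name and the state dimension are assumptions:

import jax
import jax.numpy as jnp
from flax.deprecated import nn  # `from flax import nn` in older releases

# Hypothetical instantiation: assumes the apply() above belongs to a
# Module subclass named Actor; sizes are placeholders.
actor_module = Actor.partial(action_dim=6, max_action=1.0)
_, params = actor_module.init_by_shape(
    jax.random.PRNGKey(0), [((1, 17), jnp.float32)])  # (batch, state_dim)
actor = nn.Model(actor_module, params)
action = actor(jnp.ones((1, 17)))  # deterministic forward pass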
Example #4
  def apply(self, x, rep_size, m_layers, m_features, m_kernel_sizes, conv_rep_size, padding_mask=None):
        
    H_0 = nn.relu(nn.Dense(x, conv_rep_size))
    G_0 = nn.relu(nn.Dense(x, conv_rep_size))
    H, G = jnp.expand_dims(H_0, axis=2), jnp.expand_dims(G_0, axis=2)

    for layer in range(1, m_layers+1):
      
      if layer < m_layers:
        H_features, G_features = m_features[layer-1]
      else:
        H_features, G_features = conv_rep_size, conv_rep_size
      
      H_kernel_size, G_kernel_size = m_kernel_sizes[layer-1]

      H = nn.Conv(H, features=H_features, kernel_size=(H_kernel_size, 1))
      G = nn.Conv(G, features=G_features, kernel_size=(G_kernel_size, 1)) 

      if layer < m_layers:
        H = nn.relu(H)
        G = nn.relu(G)
      else:
        H = nn.tanh(H)
        G = nn.sigmoid(G)

    H, G = jnp.squeeze(H, axis=2), jnp.squeeze(G, axis=2)
    
    F = H * G + G_0
    
    rep = linear_max_pool(F, padding_mask=padding_mask, rep_size=rep_size)
    
    return rep
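
`linear_max_pool` is an external helper that is not shown. A plausible sketch, under the assumption that it projects features to `rep_size` and max-pools over the sequence axis while ignoring padded positions:

import jax.numpy as jnp
from flax.deprecated import nn

# Hypothetical stand-in for the external helper: project each position to
# rep_size, then max-pool over the sequence axis, skipping padded positions.
def linear_max_pool(x, padding_mask=None, rep_size=256):
    x = nn.Dense(x, features=rep_size)            # (batch, seq, rep_size)
    if padding_mask is not None:
        x = jnp.where(padding_mask, x, -jnp.inf)  # mask: (batch, seq, 1)
    return jnp.max(x, axis=1)                     # (batch, rep_size)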
Example #5
  def apply(self,
            x,
            num_classes=1000,
            train=False,
            resnet=None,
            patches=None,
            hidden_size=None,
            transformer=None,
            representation_size=None,
            classifier='gap'):

    n, h, w, c = x.shape

    # Embed the grid or patches of the grid.
    fh, fw = patches.size
    gh, gw = h // fh, w // fw
    if hidden_size:  # We can merge s2d+emb into a single conv; it's the same.
      x = nn.Conv(
          x,
          hidden_size, (fh, fw),
          strides=(fh, fw),
          padding='VALID',
          name='embedding')
    else:
      # This path often results in excessive padding.
      x = jnp.reshape(x, [n, gh, fh, gw, fw, c])
      x = jnp.transpose(x, [0, 1, 3, 2, 4, 5])
      x = jnp.reshape(x, [n, gh, gw, -1])

    # Here, x is a grid of embeddings.

    # (Possibly partial) Transformer.
    if transformer is not None:
      n, h, w, c = x.shape
      x = jnp.reshape(x, [n, h * w, c])

      # If we want to add a class token, add it here.
      if classifier == 'token':
        cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

      x = Encoder(x, train=train, name='Transformer', **transformer)

    if classifier == 'token':
      x = x[:, 0]
    elif classifier == 'gap':
      x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

    if representation_size is not None:
      x = nn.Dense(x, representation_size, name='pre_logits')
      x = nn.tanh(x)
    else:
      x = IdentityLayer(x, name='pre_logits')

    x = nn.Dense(x, num_classes, name='head', kernel_init=nn.initializers.zeros)
    return x
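
The "merge s2d+emb into a single conv" comment holds because a convolution whose kernel and stride are both (fh, fw) applies one shared linear map to each non-overlapping patch, which is exactly space-to-depth followed by a dense embedding. A quick numeric check in plain JAX (shapes are illustrative):

import jax
import jax.numpy as jnp

n, h, w, c, fh, fw, d = 2, 8, 8, 3, 4, 4, 16
x = jax.random.normal(jax.random.PRNGKey(0), (n, h, w, c))
kernel = jax.random.normal(jax.random.PRNGKey(1), (fh, fw, c, d))

# Path 1: patch embedding as a strided convolution.
conv = jax.lax.conv_general_dilated(
    x, kernel, window_strides=(fh, fw), padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

# Path 2: space-to-depth reshape, then a dense projection.
gh, gw = h // fh, w // fw
s2d = x.reshape(n, gh, fh, gw, fw, c).transpose(0, 1, 3, 2, 4, 5)
s2d = s2d.reshape(n, gh, gw, fh * fw * c)
emb = s2d @ kernel.reshape(fh * fw * c, d)

print(jnp.allclose(conv, emb, atol=1e-4))  # True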
Example #6
    def apply(self, state, action, Q1=False):
        # Twin Q-heads (TD3-style); Q1=True returns only the first head.
        state_action = jnp.concatenate([state, action], axis=1)

        q1 = nn.Dense(state_action, features=500)
        q1 = nn.LayerNorm(q1)
        q1 = nn.tanh(q1)
        q1 = nn.Dense(q1, features=500)
        q1 = nn.elu(q1)
        q1 = nn.Dense(q1, features=1)

        if Q1:
            return q1

        q2 = nn.Dense(state_action, features=500)
        q2 = nn.LayerNorm(q2)
        q2 = nn.tanh(q2)
        q2 = nn.Dense(q2, features=500)
        q2 = nn.elu(q2)
        q2 = nn.Dense(q2, features=1)

        return q1, q2
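
A sketch of how this twin critic might be initialized and queried with the same pre-Linen API; the `Critic` class name and the dimensions are assumptions:

import jax
import jax.numpy as jnp
from flax.deprecated import nn

# Hypothetical instantiation with placeholder state/action sizes.
state = jnp.ones((1, 17))
action = jnp.ones((1, 6))
_, params = Critic.init_by_shape(
    jax.random.PRNGKey(0),
    [((1, 17), jnp.float32), ((1, 6), jnp.float32)])
critic = nn.Model(Critic, params)

q1, q2 = critic(state, action)             # both heads (critic update)
q1_only = critic(state, action, Q1=True)   # first head only (actor update)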
Example #7
    def loss_fn(mlo, slo, actor):
        mu, log_sig = actor(state, MPO=True)
        sig = jnp.exp(log_sig)
        target_mu, target_log_sig = actor_target(state, MPO=True)
        target_sig = jnp.exp(target_log_sig)

        actor_log_prob = gaussian_likelihood(sampled_actions, target_mu, sig)
        actor_log_prob += gaussian_likelihood(sampled_actions, mu, target_sig)
        actor_log_prob = actor_log_prob.transpose((0, 1))

        mu, target_mu = nn.tanh(mu), nn.tanh(target_mu)

        reg_mu = eps_mu - kl_mvg_diag(target_mu, target_sig, mu, target_sig).mean()
        reg_sig = eps_sig - kl_mvg_diag(target_mu, target_sig, target_mu, sig).mean()

        mlo = lagrange_step(mlo, reg_mu)
        slo = lagrange_step(slo, reg_sig)

        actor_loss = -(actor_log_prob[:, None] * weights).sum(axis=1).mean()
        actor_loss -= mlo.target() * reg_mu
        actor_loss -= slo.target() * reg_sig
        return actor_loss.mean(), (mlo, slo)
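
`kl_mvg_diag` and `lagrange_step` come from elsewhere in the source file. For the former, the name suggests the closed-form KL divergence between multivariate Gaussians with diagonal covariance; a sketch, with the argument convention (standard deviations, not variances) and reduction axis as assumptions:

import jax.numpy as jnp

# Assumed implementation, not the original helper:
# KL(N(mu_p, diag(sig_p^2)) || N(mu_q, diag(sig_q^2))), summed over dims.
def kl_mvg_diag(mu_p, sig_p, mu_q, sig_q):
    return jnp.sum(
        jnp.log(sig_q / sig_p)
        + (sig_p ** 2 + (mu_p - mu_q) ** 2) / (2.0 * sig_q ** 2)
        - 0.5,
        axis=-1)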
Example #8
 def apply(self,
           inputs: jnp.ndarray,
           hidden_size: int = None,
           output_size: int = None,
           output_bias: bool = False,
           dropout: float = None,
           train: bool = None):
     # inputs.shape = <float32>[batch_size, seq_length, hidden_size]
     hidden = nn.Dense(inputs, hidden_size, name='hidden')
     hidden = nn.tanh(hidden)
     if train:
         hidden = nn.dropout(hidden, rate=dropout)
     output = nn.Dense(hidden, output_size, bias=output_bias, name='output')
     return output
Example #9
 def apply(self, x):
     return nn.sigmoid(x) * nn.sigmoid(-x) * nn.tanh(x) * (1 / 0.15)
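
Here sigmoid(x) * sigmoid(-x) is the derivative of the sigmoid, which peaks at 0.25 at x = 0; multiplied by the odd function tanh(x), the product reaches a maximum magnitude of roughly 0.15 near x = ±1, so the 1/0.15 factor rescales the activation to approximately [-1, 1]. A quick check:

import jax
import jax.numpy as jnp

x = jnp.linspace(-5.0, 5.0, 1001)
f = jax.nn.sigmoid(x) * jax.nn.sigmoid(-x) * jnp.tanh(x)
print(jnp.max(jnp.abs(f)))  # ~0.15, hence the 1/0.15 rescaling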
Example #10
    def apply(self,
              x,
              num_classes=1000,
              train=False,
              resnet=None,
              patches=None,
              hidden_size=None,
              transformer=None,
              representation_size=None,
              classifier='gap'):

        # (Possibly partial) ResNet root.
        if resnet is not None:
            width = int(64 * resnet.width_factor)

            # Root block.
            x = models_resnet.StdConv(x,
                                      width, (7, 7), (2, 2),
                                      bias=False,
                                      name='conv_root')
            x = nn.GroupNorm(x, name='gn_root')
            x = nn.relu(x)
            x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

            # ResNet stages.
            x = models_resnet.ResNetStage(x,
                                          resnet.num_layers[0],
                                          width,
                                          first_stride=(1, 1),
                                          name='block1')
            for i, block_size in enumerate(resnet.num_layers[1:], 1):
                x = models_resnet.ResNetStage(x,
                                              block_size,
                                              width * 2**i,
                                              first_stride=(2, 2),
                                              name=f'block{i + 1}')

        n, h, w, c = x.shape

        # We can merge s2d+emb into a single conv; it's the same.
        x = nn.Conv(x,
                    hidden_size,
                    patches.size,
                    strides=patches.size,
                    padding='VALID',
                    name='embedding')

        # Here, x is a grid of embeddings.

        # (Possibly partial) Transformer.
        if transformer is not None:
            n, h, w, c = x.shape
            x = jnp.reshape(x, [n, h * w, c])

            # If we want to add a class token, add it here.
            if classifier == 'token':
                cls = self.param('cls', (1, 1, c), nn.initializers.zeros)
                cls = jnp.tile(cls, [n, 1, 1])
                x = jnp.concatenate([cls, x], axis=1)

            x = Encoder(x, train=train, name='Transformer', **transformer)

        if classifier == 'token':
            x = x[:, 0]
        elif classifier == 'gap':
            x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)

        if representation_size is not None:
            x = nn.Dense(x, representation_size, name='pre_logits')
            x = nn.tanh(x)
        else:
            x = IdentityLayer(x, name='pre_logits')

        x = nn.Dense(x,
                     num_classes,
                     name='head',
                     kernel_init=nn.initializers.zeros)
        return x
Example #11
  def apply(self,
            x,
            num_classes=1,
            train=False,
            hidden_size=None,
            transformer=None,
            resnet_emb=None,
            representation_size=None):
    """Apply model on inputs.

    Args:
      x: the processed input patches and position annotations.
      num_classes: the number of output classes. 1 for single model.
      train: train or eval.
      hidden_size: the hidden dimension for patch embedding tokens.
      transformer: the model config for Transformer backbone.
      resnet_emb: the config for patch embedding w/ small resnet.
      representation_size: size of the last FC before prediction.

    Returns:
      Model prediction output.
    """
    assert transformer is not None
    # Either 3: (batch size, seq len, channel) or
    # 4: (batch size, crops, seq len, channel)
    assert len(x.shape) in [3, 4]

    multi_crops_input = False
    if len(x.shape) == 4:
      multi_crops_input = True
      batch_size, num_crops, l, channel = x.shape
      x = jnp.reshape(x, [batch_size * num_crops, l, channel])

    # We concat (x, spatial_positions, scale_positions, input_masks)
    # when preprocessing.
    inputs_spatial_positions = x[:, :, -3]
    inputs_spatial_positions = inputs_spatial_positions.astype(jnp.int32)
    inputs_scale_positions = x[:, :, -2]
    inputs_scale_positions = inputs_scale_positions.astype(jnp.int32)
    inputs_masks = x[:, :, -1]
    inputs_masks = inputs_masks.astype(jnp.bool_)
    x = x[:, :, :-3]
    n, l, channel = x.shape
    if hidden_size:
      if resnet_emb:
        # channel = patch_size * patch_size * 3
        patch_size = int(np.sqrt(channel // 3))
        x = jnp.reshape(x, [-1, patch_size, patch_size, 3])
        x = resnet.StdConv(
            x, RESNET_TOKEN_DIM, (7, 7), (2, 2), bias=False, name="conv_root")
        x = nn.GroupNorm(x, name="gn_root")
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")

        if resnet_emb.num_layers > 0:
          blocks, bottleneck = resnet.get_block_desc(resnet_emb.num_layers)
          if blocks:
            x = resnet.ResNetStage(
                x,
                blocks[0],
                RESNET_TOKEN_DIM,
                first_stride=(1, 1),
                bottleneck=bottleneck,
                name="block1")
            for i, block_size in enumerate(blocks[1:], 1):
              x = resnet.ResNetStage(
                  x,
                  block_size,
                  RESNET_TOKEN_DIM * 2**i,
                  first_stride=(2, 2),
                  bottleneck=bottleneck,
                  name=f"block{i + 1}")
        x = jnp.reshape(x, [n, l, -1])

      x = nn.Dense(x, hidden_size, name="embedding")

    # Here, x is a list of embeddings.
    x = utils.Encoder(
        x,
        inputs_spatial_positions,
        inputs_scale_positions,
        inputs_masks,
        train=train,
        name="Transformer",
        **transformer)

    x = x[:, 0]

    if representation_size:
      x = nn.Dense(x, representation_size, name="pre_logits")
      x = nn.tanh(x)
    else:
      x = resnet.IdentityLayer(x, name="pre_logits")

    x = nn.Dense(x, num_classes, name="head", kernel_init=nn.initializers.zeros)
    if multi_crops_input:
      _, channel = x.shape
      x = jnp.reshape(x, [batch_size, num_crops, channel])
    return x