Example #1
  def apply(self, x, num_actions, minatar, env, normalize_obs, noisy, dueling, num_atoms, hidden_layer=2, neurons=512):
    del normalize_obs

    if minatar:
      x = x.squeeze(3)
      x = x[None, ...]
      x = x.astype(jnp.float32)
      x = nn.Conv(x, features=16, kernel_size=(3, 3, 3), strides=(1, 1, 1),  kernel_init=nn.initializers.xavier_uniform())
      x = jax.nn.relu(x)
      x = x.reshape((x.shape[0], -1))

    else:
      x = x[None, ...]
      x = x.astype(jnp.float32)
      x = x.reshape((x.shape[0], -1))


    if env is not None:
      x = x - env_inf[env]['MIN_VALS']
      x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
      x = 2.0 * x - 1.0


    if noisy:
      def net(x, features):
        return NoisyNetwork(x, features)
    else:
      def net(x, features):
        return nn.Dense(x, features, kernel_init=nn.initializers.xavier_uniform())


    for _ in range(hidden_layer):
      x = net(x, features=neurons)
      #print('x:',x)
      x = jax.nn.relu(x)

    if dueling:
      print('dueling')
      adv = net(x, features=num_actions * num_atoms)
      value = net(x, features=num_atoms)
      adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
      value = value.reshape((value.shape[0], 1, num_atoms))
      logits = value + (adv - (jnp.mean(adv, -1, keepdims=True)))
      probabilities = nn.softmax(logits)
      q_values = jnp.mean(logits, axis=2)

    else:
      #print('No dueling')
      x = net(x, features=num_actions * num_atoms)
      logits = x.reshape((x.shape[0], num_actions, num_atoms))
      probabilities = nn.softmax(logits)
      q_values = jnp.mean(logits, axis=2)


    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
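A minimal sketch of the dueling aggregation used above, with hypothetical toy shapes (the array names below are illustrative, not from the source): the state-value stream broadcasts over actions and the mean-centered advantage is added per atom.

import jax.numpy as jnp

adv = jnp.zeros((1, 6, 51))     # (batch, num_actions, num_atoms), hypothetical shapes
value = jnp.zeros((1, 1, 51))   # (batch, 1, num_atoms)
logits = value + (adv - jnp.mean(adv, -1, keepdims=True))  # value broadcasts over actions
q_values = jnp.mean(logits, axis=2)                        # (batch, num_actions), as in the example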
Example #2
def apply_activation(intermediate_output, intermediate_activation):
    """Applies selected activation function to intermediate output."""
    if intermediate_activation is None:
        return intermediate_output

    if intermediate_activation == 'gelu':
        intermediate_output = nn.gelu(intermediate_output)
    elif intermediate_activation == 'relu':
        intermediate_output = nn.relu(intermediate_output)
    elif intermediate_activation == 'sigmoid':
        intermediate_output = nn.sigmoid(intermediate_output)
    elif intermediate_activation == 'softmax':
        intermediate_output = nn.softmax(intermediate_output)
    elif intermediate_activation == 'celu':
        intermediate_output = nn.celu(intermediate_output)
    elif intermediate_activation == 'elu':
        intermediate_output = nn.elu(intermediate_output)
    elif intermediate_activation == 'log_sigmoid':
        intermediate_output = nn.log_sigmoid(intermediate_output)
    elif intermediate_activation == 'log_softmax':
        intermediate_output = nn.log_softmax(intermediate_output)
    elif intermediate_activation == 'soft_sign':
        intermediate_output = nn.soft_sign(intermediate_output)
    elif intermediate_activation == 'softplus':
        intermediate_output = nn.softplus(intermediate_output)
    elif intermediate_activation == 'swish':
        intermediate_output = nn.swish(intermediate_output)
    elif intermediate_activation == 'tanh':
        intermediate_output = jnp.tanh(intermediate_output)
    else:
        raise NotImplementedError(
            '%s activation function is not yet supported.' %
            intermediate_activation)

    return intermediate_output
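A short usage sketch of the dispatcher above (hypothetical values; assumes the same `nn` module is in scope):

import jax.numpy as jnp

x = jnp.array([[-1.0, 0.0, 2.0]])
y = apply_activation(x, 'softmax')   # rows sum to 1
z = apply_activation(x, 'relu')      # negative entries clipped to 0
w = apply_activation(x, None)        # passthrough when no activation is selected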
Example #3
            def apply(self, x, n_bins, stage_sizes, block_cls, num_filters=64, 
                    dtype=jnp.float32, act=nn.leaky_relu, train=True):
                b = x.shape[0]
                conv = nn.Conv.partial(bias=False, dtype=dtype)
                norm = nn.BatchNorm.partial(
                    use_running_average=not train,
                    dtype=dtype
                )

                x = conv(x, num_filters, kernel_size=(7,), strides=(2,), padding=[(3, 3)])
                x = norm(x)
                x = nn.leaky_relu(x)
                x = nn.max_pool(x, window_shape=(3,), strides=(2,), padding='SAME')

                for i, block_size in enumerate(stage_sizes):
                    for j in range(block_size):
                        strides = (2,) if i > 0 and j == 0 else (1,)
                        x = block_cls(x, num_filters * 2 ** i,
                                  strides=strides,
                                  conv=conv,
                                  norm=norm,
                                  act=act)
                x = x.reshape(b, -1)
                x = nn.Dense(x, n_bins, dtype=dtype)
                x = nn.softmax(x)
                return x
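The `nn.Conv.partial(...)` and `nn.BatchNorm.partial(...)` calls above follow the pre-Linen Flax pattern of binding layer options once and reusing them; a minimal sketch under that assumed API:

conv = nn.Conv.partial(bias=False, dtype=jnp.float32)
# Later calls only supply the input and the remaining options:
h = conv(x, 64, kernel_size=(3,), strides=(1,), padding='SAME')
# Equivalent to nn.Conv(x, 64, kernel_size=(3,), strides=(1,), padding='SAME', bias=False, dtype=jnp.float32)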
Example #4
 def apply(self, x, num_actions, num_atoms):
     initializer = jax.nn.initializers.variance_scaling(
         scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
     # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
     # have removed the true batch dimension.
     x = x[None, ...]
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(x,
                 features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Conv(x,
                 features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)
     x = jax.nn.relu(x)
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(x, features=512, kernel_init=initializer)
     x = jax.nn.relu(x)
     x = nn.Dense(x,
                  features=num_actions * num_atoms,
                  kernel_init=initializer)
     logits = x.reshape((x.shape[0], num_actions, num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.mean(logits, axis=2)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #5
 def rnn_dim2(carry, x):
     newCarry = actFun(cellInH(carry[1]) + cellInV(x[1]) + cellCarryH(carry[0]) + cellCarryV(x[0]))
     logits = outputDense(newCarry)
     sampleOut = jax.random.categorical(x[2], logits)
     sample = jax.nn.one_hot(sampleOut, inputDim)
     logProb = jnp.log(jnp.sum(nn.softmax(logits) * sample, axis=1))
     output = (newCarry, logProb, sampleOut)
     return (newCarry, sample), output
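The log-probability trick above (softmax times a one-hot sample, then log of the sum) can be checked in isolation; a small sketch with hypothetical logits, using jax.nn directly so it is self-contained:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
logits = jnp.array([[2.0, 0.5, -1.0]])
sampleOut = jax.random.categorical(key, logits)                # sampled class index per row
sample = jax.nn.one_hot(sampleOut, logits.shape[-1])
logProb = jnp.log(jnp.sum(jax.nn.softmax(logits) * sample, axis=1))
# Numerically safer equivalent: jax.nn.log_softmax(logits)[jnp.arange(logits.shape[0]), sampleOut]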
Example #6
  def apply(self, x, num_actions, num_atoms, support, noisy, dueling):
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
    # have removed the true batch dimension.

    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = x.reshape((x.shape[0], -1))  # flatten
    #x -= gym_lib.CARTPOLE_MIN_VALS
    #x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
    #x = 2.0 * x - 1.0  # Rescale in range [-1, 1].

    if noisy:
        print('LunarLander-Noisy[Johan]')
        initializer = None
        bias = True
        def net(x, features, bias, kernel_init):
            return NoisyNetwork(x, features, bias, kernel_init)
    else:
        initializer = nn.initializers.xavier_uniform()
        bias = None
        def net(x, features, bias, kernel_init):
            return nn.Dense(x, features, kernel_init=kernel_init)

    x = net(x, features=512, bias=bias, kernel_init=initializer)
    x = jax.nn.relu(x)
    x = net(x, features=512, bias=bias, kernel_init=initializer)
    x = jax.nn.relu(x)

    if dueling:
        print('LunarLanderRainbowFull-Dueling')
        adv = net(x, features=num_actions * num_atoms, bias=bias, kernel_init=initializer)
        value = net(x, features=num_atoms, bias=bias, kernel_init=initializer)
        adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
        value = value.reshape((value.shape[0], 1, num_atoms))
        logits = value + (adv - (jnp.mean(adv, -1, keepdims=True)))
        probabilities = nn.softmax(logits)
        q_values = jnp.sum(support * probabilities, axis=2)

    else:
        x = net(x, features=num_actions * num_atoms, bias=bias, kernel_init=initializer)
        logits = x.reshape((x.shape[0], num_actions, num_atoms))
        probabilities = nn.softmax(logits)
        q_values = jnp.sum(support * probabilities, axis=2)
    
    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
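For the `support`-weighted variants, the Q-value of an action is the expectation of its categorical return distribution; a self-contained sketch with assumed sizes:

import jax
import jax.numpy as jnp

num_actions, num_atoms = 4, 51
support = jnp.linspace(-10.0, 10.0, num_atoms)        # hypothetical return support
logits = jnp.zeros((1, num_actions, num_atoms))       # uniform distribution per action
probabilities = jax.nn.softmax(logits)                # softmax over the atom (last) axis
q_values = jnp.sum(support * probabilities, axis=2)   # expected return, shape (1, num_actions)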
Example #7
 def apply(self, x):
     net = nn.Dense(x, 500, name='fc1')
     net = nn.leaky_relu(net)
     net = nn.BatchNorm(net)
     net = nn.Dense(net, 500, name='fc2')
     net = nn.leaky_relu(net)
     net = nn.BatchNorm(net)
     net = nn.Dense(net, 500, name='fc3')
     net = nn.leaky_relu(net)
     net = nn.BatchNorm(net)
     return nn.softmax(nn.Dense(net, n_bin))
Example #8
  def apply(self, x, num_actions, num_atoms, support, noisy, dueling):
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
    # have removed the true batch dimension.
    initializer_conv = nn.initializers.xavier_uniform()
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = nn.Conv(x, features=16, kernel_size=(3, 3, 3), strides=(1, 1, 1),  kernel_init=initializer_conv)
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten.

    if noisy:
        print('InvadersRainbowFull-Noisy[Johan]')
        initializer = None
        bias = True
        def net(x, features, bias, kernel_init):
            return NoisyNetwork(x, features, bias, kernel_init)
    else:
        initializer = nn.initializers.xavier_uniform()
        bias = None
        def net(x, features, bias, kernel_init):
            return nn.Dense(x, features, kernel_init=kernel_init)

    if dueling:
        print('InvadersRainbowFull-Dueling')
        adv = net(x, features=num_actions * num_atoms, bias=bias, kernel_init=initializer)
        value = net(x, features=num_atoms, bias=bias, kernel_init=initializer)
        adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
        value = value.reshape((value.shape[0], 1, num_atoms))
        logits = value + (adv - (jnp.mean(adv, -1, keepdims=True)))
        probabilities = nn.softmax(logits)
        q_values = jnp.sum(support * probabilities, axis=2)

    else:
        x = net(x, features=num_actions * num_atoms, bias=bias, kernel_init=initializer)
        logits = x.reshape((x.shape[0], num_actions, num_atoms))
        probabilities = nn.softmax(logits)
        q_values = jnp.sum(support * probabilities, axis=2)
    
    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)  
Example #9
 def apply(self, x):
     b = x.shape[0]
     x = nn.Conv(x, features=128, kernel_size=(4, ), padding='SAME')
     x = nn.BatchNorm(x)
     x = nn.leaky_relu(x)
     x = nn.avg_pool(x, window_shape=(2, ), padding='SAME')
     x = nn.Conv(x, features=256, kernel_size=(4, ), padding='SAME')
     x = nn.BatchNorm(x)
     x = nn.leaky_relu(x)
     x = nn.avg_pool(x, window_shape=(2, ), padding='SAME')
     x = x.reshape(b, -1)
     x = nn.Dense(x, features=128)
     x = nn.BatchNorm(x)
     x = nn.leaky_relu(x)
     x = nn.Dense(x, features=n_bins)
     x = nn.softmax(x)
     return x
Example #10
      def apply(self, x, num_actions, num_atoms, support):
        def custom_init(key, shape, dtype=jnp.float32):
          del key
          to_pick_first_action = onp.ones(shape, dtype)
          to_pick_first_action[:, :num_atoms] = onp.arange(1, num_atoms + 1)
          return to_pick_first_action

        x = x[None, :]
        x = x.astype(jnp.float32)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(x, features=num_actions * num_atoms,
                     kernel_init=custom_init,
                     bias_init=jax.nn.initializers.ones)
        logits = x.reshape((-1, num_actions, num_atoms))
        probabilities = nn.softmax(logits)
        qs = jnp.sum(support * probabilities, axis=2)
        return atari_lib.RainbowNetworkType(qs, logits, probabilities)
Example #11
        def step_single_example(hidden_states, instruction_pointer,
                                node_embeddings, true_indexes, false_indexes,
                                exit_index):
            # Execution (e.g. apply RNN)
            # leaves(hidden_states).shape: num_nodes, hidden_size
            # instruction_pointer.shape: num_nodes,
            # node_embeddings.shape: num_nodes, statement_length, hidden_size
            hidden_state_contributions = execute(hidden_states,
                                                 node_embeddings)

            # leaves(hidden_state_contributions).shape: num_nodes, hidden_size

            # Use the exit node's hidden state as its hidden state contribution
            # to avoid "executing" the exit node.
            def mask_h(h_contribution, h):
                return h_contribution.at[exit_index, :].set(h[exit_index, :])

            hidden_state_contributions = jax.tree_multimap(
                mask_h, hidden_state_contributions, hidden_states)

            # Branch decisions (e.g. Dense layer)
            branch_decision_logits = branch_decide(hidden_state_contributions)
            branch_decisions = nn.softmax(branch_decision_logits, axis=-1)

            # Update state
            instruction_pointer_new = update_instruction_pointer(
                instruction_pointer, branch_decisions, true_indexes,
                false_indexes)
            hidden_states_new = aggregate(hidden_state_contributions,
                                          instruction_pointer,
                                          branch_decisions, true_indexes,
                                          false_indexes)

            to_tag = {
                'branch_decisions': branch_decisions,
                'hidden_state_contributions': hidden_state_contributions,
                'hidden_states_before': hidden_states,
                'hidden_states': hidden_states_new,
                'instruction_pointer_before': instruction_pointer,
                'instruction_pointer': instruction_pointer_new,
                'true_indexes': true_indexes,
                'false_indexes': false_indexes,
            }
            return hidden_states_new, instruction_pointer_new, to_tag
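The helpers `update_instruction_pointer` and `aggregate` are defined elsewhere in the source; purely as a hedged illustration (not the repository's actual implementation), a plausible soft instruction-pointer update routes each node's probability mass to its true/false successors according to the softmaxed branch decisions:

import jax
import jax.numpy as jnp

def update_instruction_pointer_sketch(instruction_pointer, branch_decisions,
                                      true_indexes, false_indexes):
    # instruction_pointer.shape: num_nodes,
    # branch_decisions.shape: num_nodes, 2  (P(true branch), P(false branch))
    num_nodes = instruction_pointer.shape[0]
    mass_true = instruction_pointer * branch_decisions[:, 0]
    mass_false = instruction_pointer * branch_decisions[:, 1]
    return (jax.ops.segment_sum(mass_true, true_indexes, num_segments=num_nodes)
            + jax.ops.segment_sum(mass_false, false_indexes, num_segments=num_nodes))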
Example #12
 def apply(self,
           hidden_states,
           mask=None,
           *,
           d_qkv=64,
           attention_dropout_rate=0.0,
           output_dropout_rate=0.0,
           deterministic=False,
           kernel_init=nn.linear.default_kernel_init,
           output_kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.zeros,
           bias=True):
     """Applies attention for a single batch element and head."""
     d_model = hidden_states.shape[-1]
     dense = nn.DenseGeneral.partial(axis=-1,
                                     features=(d_qkv, ),
                                     kernel_init=kernel_init,
                                     bias_init=bias_init,
                                     bias=bias)
     query, key, value = (dense(hidden_states, name='query'),
                          dense(hidden_states, name='key'),
                          dense(hidden_states, name='value'))
     attention_scores = jnp.einsum('TN,FN->FT', key, query)
     attention_scores = attention_scores / jnp.sqrt(d_qkv)
     if mask is not None:
         padding_mask = (1.0 - mask[None, :]) * NEG_INFINITY
         attention_scores = attention_scores + padding_mask
     attention_scores = nn.softmax(attention_scores)
     attention_probs = nn.dropout(attention_scores,
                                  rate=attention_dropout_rate,
                                  deterministic=deterministic)
     hidden_states = jnp.einsum('FT,TH->FH', attention_probs, value)
     hidden_states = nn.linear.DenseGeneral(hidden_states,
                                            features=d_model,
                                            axis=(-1, ),
                                            kernel_init=output_kernel_init,
                                            name='output')
     hidden_states = nn.dropout(hidden_states,
                                rate=output_dropout_rate,
                                deterministic=deterministic)
     return hidden_states
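The masking step above pushes padded source positions to roughly zero attention weight; a small standalone sketch (toy scores, and `NEG_INFINITY` assumed to be a large negative constant as in the source):

import jax
import jax.numpy as jnp

NEG_INFINITY = -1e9                     # assumed value of the source's constant
scores = jnp.array([[1.0, 2.0, 3.0]])   # (target_len, source_len) toy attention scores
mask = jnp.array([1.0, 1.0, 0.0])       # last source position is padding
scores = scores + (1.0 - mask[None, :]) * NEG_INFINITY
weights = jax.nn.softmax(scores)        # ~[0.27, 0.73, 0.0]: padded position ignored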
Example #13
  def apply(self, x, num_actions, num_atoms):
    
    initializer = nn.initializers.xavier_uniform()
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
    # have removed the true batch dimension.
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = x.reshape((x.shape[0], -1))  # flatten
    x -= gym_lib.CARTPOLE_MIN_VALS
    x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
    x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
    x = nn.Dense(x, features=512, kernel_init=initializer)
    x = jax.nn.relu(x)
    x = nn.Dense(x, features=512, kernel_init=initializer)
    x = jax.nn.relu(x)
    x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer)

    logits = x.reshape((x.shape[0], num_actions, num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=2)
    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
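The three rescaling lines map each observation dimension from its [MIN, MAX] range to [-1, 1]; a quick check with hypothetical bounds (not the actual `gym_lib` constants):

import jax.numpy as jnp

MIN_VALS = jnp.array([-2.4, -5.0, -0.2, -5.0])   # hypothetical per-dimension bounds
MAX_VALS = jnp.array([2.4, 5.0, 0.2, 5.0])
obs = jnp.array([[0.0, 5.0, -0.2, 2.5]])
obs = (obs - MIN_VALS) / (MAX_VALS - MIN_VALS)
obs = 2.0 * obs - 1.0                            # -> [[0., 1., -1., 0.5]]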
Example #14
def generate_text(
    model,
    vocab,
    max_length=100,
    temperature=0.5,
    top_k=3,
    start_letter="T",
):
    output_text = start_letter
    carry = nn.GRUCell.initialize_carry(jax.random.PRNGKey(0), (1, ),
                                        FLAGS.hidden_size)
    for i in range(max_length):
        input = vocab.numericalize(output_text[-1])
        input_t = jnp.array(input, dtype=jnp.int32).reshape(1, 1)
        carry, pred = model(input_t, carry)
        prob = nn.softmax(pred / temperature, axis=1)
        # output_text += vocab.textify(prob.argmax().tolist())[0]
        prob_np = np.array(prob)[0]
        top_k_index = prob_np.argsort()[-top_k:]
        # Renormalize the top-k probabilities and sample the next character index.
        top_k_probs = prob_np[top_k_index] / prob_np[top_k_index].sum()
        next_char = np.random.choice(top_k_index, 1, p=top_k_probs)
        output_text += vocab.textify(next_char.item())[0]

    return output_text
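The `pred / temperature` step controls how peaked the sampling distribution is: temperatures below 1 sharpen the softmax toward the arg-max, temperatures above 1 flatten it. A toy illustration with hypothetical logits:

import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 1.0, 0.0]])
jax.nn.softmax(logits / 0.5)   # sharper: ~[0.87, 0.12, 0.02]
jax.nn.softmax(logits / 2.0)   # flatter: ~[0.51, 0.31, 0.19]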
Example #15
        def step_single_example(hidden_states, instruction_pointer,
                                node_embeddings, true_indexes, false_indexes,
                                exit_index):
            # Execution (e.g. apply RNN)
            # leaves(hidden_states).shape: num_nodes, hidden_size
            # instruction_pointer.shape: num_nodes,
            # node_embeddings.shape: num_nodes, statement_length, hidden_size
            if config.model.interpolant.apply_code_rnn:
                hidden_state_contributions = execute(hidden_states,
                                                     node_embeddings)
                # leaves(hidden_state_contributions).shape: num_nodes, hidden_size
            else:
                hidden_state_contributions = hidden_states

            if config.model.interpolant.apply_dense:
                parent_to_true_child = jax.tree_map(
                    dense_parent_to_true_child, hidden_state_contributions)
                parent_to_false_child = jax.tree_map(
                    dense_parent_to_false_child, hidden_state_contributions)
                true_child_to_parent = jax.tree_map(
                    dense_true_child_to_parent, hidden_state_contributions)
                false_child_to_parent = jax.tree_map(
                    dense_false_child_to_parent, hidden_state_contributions)
            else:
                parent_to_true_child = hidden_state_contributions
                parent_to_false_child = hidden_state_contributions
                true_child_to_parent = hidden_state_contributions
                false_child_to_parent = hidden_state_contributions

            # Use the exit node's hidden state as its hidden state contribution
            # to avoid "executing" the exit node.
            def mask_h(h_contribution, h):
                return h_contribution.at[exit_index, :].set(h[exit_index, :])

            hidden_state_contributions = jax.tree_multimap(
                mask_h, hidden_state_contributions, hidden_states)

            # Branch decisions (e.g. Dense layer)
            branch_decision_logits = branch_decide(hidden_state_contributions)
            branch_decisions = nn.softmax(branch_decision_logits, axis=-1)

            # Update state
            if config.model.interpolant.use_ipa:
                instruction_pointer_new = update_instruction_pointer(
                    instruction_pointer, branch_decisions, true_indexes,
                    false_indexes)
                hidden_states_new = aggregate(hidden_state_contributions,
                                              instruction_pointer,
                                              branch_decisions, true_indexes,
                                              false_indexes)
            else:
                assert config.model.interpolant.use_parent_embeddings
                assert config.model.interpolant.use_child_embeddings
                instruction_pointer_new = instruction_pointer
                normalization = jnp.sqrt(
                    2 + (  # Each node has a true and false child.
                        # jnp.bincount(true_indexes, minlength=num_nodes)
                        jax.ops.scatter.segment_sum(jnp.ones_like(
                            true_indexes),
                                                    true_indexes,
                                                    num_segments=num_nodes)
                        # + jnp.bincount(false_indexes, minlength=num_nodes)
                        + jax.ops.scatter.segment_sum(jnp.ones_like(
                            false_indexes),
                                                      false_indexes,
                                                      num_segments=num_nodes)))

                # normalization.shape: num_nodes,
                def aggregate_parent_and_child_contributions(p1, p2, c3, c4):
                    return (jax.ops.scatter.segment_sum(
                        p1, true_indexes, num_segments=num_nodes) +
                            jax.ops.scatter.segment_sum(
                                p2, false_indexes, num_segments=num_nodes) +
                            c3[true_indexes] +
                            c4[false_indexes]) / normalization[:, None]

                hidden_states_new = jax.tree_multimap(
                    aggregate_parent_and_child_contributions,
                    parent_to_true_child,
                    parent_to_false_child,
                    true_child_to_parent,  # true_child_to_parent[child] -> parent
                    false_child_to_parent)
            if config.model.interpolant.apply_gru:

                def apply_gru(h2, h1):
                    output, _ = gru_cell(h2, h1)
                    return output

                hidden_states_new = (jax.tree_multimap(apply_gru,
                                                       hidden_states_new,
                                                       hidden_states))

            to_tag = {
                'branch_decisions': branch_decisions,
                'hidden_state_contributions': hidden_state_contributions,
                'hidden_states_before': hidden_states,
                'hidden_states': hidden_states_new,
                'instruction_pointer_before': instruction_pointer,
                'instruction_pointer': instruction_pointer_new,
                'true_indexes': true_indexes,
                'false_indexes': false_indexes,
            }
            return hidden_states_new, instruction_pointer_new, to_tag
Example #16
    def apply(self,
              x,
              num_actions,
              net_conf,
              env,
              normalize_obs,
              noisy,
              dueling,
              num_atoms,
              hidden_layer=2,
              neurons=512):
        del normalize_obs

        if net_conf == 'minatar':
            x = x.squeeze(3)
            x = x[None, ...]
            x = x.astype(jnp.float32)
            x = nn.Conv(x,
                        features=16,
                        kernel_size=(3, 3, 3),
                        strides=(1, 1, 1),
                        kernel_init=nn.initializers.xavier_uniform())
            x = jax.nn.relu(x)
            x = x.reshape((x.shape[0], -1))

        elif net_conf == 'atari':
            # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
            # have removed the true batch dimension.
            x = x[None, ...]
            x = x.astype(jnp.float32) / 255.
            x = nn.Conv(x,
                        features=32,
                        kernel_size=(8, 8),
                        strides=(4, 4),
                        kernel_init=nn.initializers.xavier_uniform())
            x = jax.nn.relu(x)
            x = nn.Conv(x,
                        features=64,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        kernel_init=nn.initializers.xavier_uniform())
            x = jax.nn.relu(x)
            x = nn.Conv(x,
                        features=64,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=nn.initializers.xavier_uniform())
            x = jax.nn.relu(x)
            x = x.reshape((x.shape[0], -1))  # flatten

        elif net_conf == 'classic':
            #classic environments
            x = x[None, ...]
            x = x.astype(jnp.float32)
            x = x.reshape((x.shape[0], -1))

        if env is not None:
            x = x - env_inf[env]['MIN_VALS']
            x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
            x = 2.0 * x - 1.0

        if noisy:

            def net(x, features):
                return NoisyNetwork(x, features)
        else:

            def net(x, features):
                return nn.Dense(x,
                                features,
                                kernel_init=nn.initializers.xavier_uniform())

        for _ in range(hidden_layer):
            x = net(x, features=neurons)
            x = jax.nn.relu(x)

        if dueling:
            adv = net(x, features=num_actions * num_atoms)
            value = net(x, features=num_atoms)
            adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
            value = value.reshape((value.shape[0], 1, num_atoms))
            logits = value + (adv - (jnp.mean(adv, -1, keepdims=True)))
            probabilities = nn.softmax(logits)
            q_values = jnp.mean(logits, axis=2)

        else:
            x = net(x, features=num_actions * num_atoms)
            logits = x.reshape((x.shape[0], num_actions, num_atoms))
            probabilities = nn.softmax(logits)
            q_values = jnp.mean(logits, axis=2)

        return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
Example #17
 def lstm_cell(carry, x):
     newCarry, logits = jax.vmap(lstmCell)(carry[0], carry[1])
     sampleOut = jax.random.categorical(x, logits)
     sample = jax.nn.one_hot(sampleOut, inputDim)
     logProb = jnp.log(jnp.sum(nn.softmax(logits) * sample, axis=1))
     return (newCarry, sample), (logProb, sampleOut)
Example #18
 def rnn_dim2(carry, x):
     newCarry = actFun(cellInH(x[0]) + cellInV(x[1]) + cellCarryH(carry) + cellCarryV(x[2]))
     out = jnp.concatenate((newCarry, nn.softmax(outputDense(newCarry))), axis=1)
     return newCarry, out
Example #19
 def lstm_cell(carry, x):
     newCarry, out = lstmCell(carry[0], carry[1])
     prob = nn.softmax(out)
     prob = jnp.log(jnp.sum(prob * x, axis=-1))
     return (newCarry, x), prob