def apply(self, x, num_actions, minatar, env, normalize_obs, noisy, dueling,
          num_atoms, hidden_layer=2, neurons=512):
  del normalize_obs
  if minatar:
    x = x.squeeze(3)
    x = x[None, ...]
    x = x.astype(jnp.float32)
    # 2-D kernel to match the (H, W) spatial dims of MinAtar observations;
    # the original used kernel_size=(3, 3, 3).
    x = nn.Conv(x, features=16, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))
  else:
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = x.reshape((x.shape[0], -1))

  if env is not None:
    # Rescale observations to [-1, 1] using per-environment bounds.
    x = x - env_inf[env]['MIN_VALS']
    x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
    x = 2.0 * x - 1.0

  if noisy:
    def net(x, features):
      return NoisyNetwork(x, features)
  else:
    def net(x, features):
      return nn.Dense(x, features,
                      kernel_init=nn.initializers.xavier_uniform())

  for _ in range(hidden_layer):
    x = net(x, features=neurons)
    x = jax.nn.relu(x)

  if dueling:
    adv = net(x, features=num_actions * num_atoms)
    value = net(x, features=num_atoms)
    adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
    value = value.reshape((value.shape[0], 1, num_atoms))
    # Center advantages across the action axis (the dueling architecture);
    # the original averaged over the atom axis (-1).
    logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
  else:
    x = net(x, features=num_actions * num_atoms)
    logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)

def apply_activation(intermediate_output, intermediate_activation):
  """Applies the selected activation function to the intermediate output."""
  if intermediate_activation is None:
    return intermediate_output

  if intermediate_activation == 'gelu':
    intermediate_output = nn.gelu(intermediate_output)
  elif intermediate_activation == 'relu':
    intermediate_output = nn.relu(intermediate_output)
  elif intermediate_activation == 'sigmoid':
    intermediate_output = nn.sigmoid(intermediate_output)
  elif intermediate_activation == 'softmax':
    intermediate_output = nn.softmax(intermediate_output)
  elif intermediate_activation == 'celu':
    intermediate_output = nn.celu(intermediate_output)
  elif intermediate_activation == 'elu':
    intermediate_output = nn.elu(intermediate_output)
  elif intermediate_activation == 'log_sigmoid':
    intermediate_output = nn.log_sigmoid(intermediate_output)
  elif intermediate_activation == 'log_softmax':
    intermediate_output = nn.log_softmax(intermediate_output)
  elif intermediate_activation == 'soft_sign':
    intermediate_output = nn.soft_sign(intermediate_output)
  elif intermediate_activation == 'softplus':
    intermediate_output = nn.softplus(intermediate_output)
  elif intermediate_activation == 'swish':
    intermediate_output = nn.swish(intermediate_output)
  elif intermediate_activation == 'tanh':
    intermediate_output = jnp.tanh(intermediate_output)
  else:
    raise NotImplementedError(
        '%s activation function is not yet supported.' %
        intermediate_activation)
  return intermediate_output

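# Hedged usage sketch (illustrative, not from the source): apply_activation
# above dispatches purely on the activation name, using the same `nn` module
# imported by the surrounding code.
import jax.numpy as jnp

activated = apply_activation(jnp.array([-1.0, 0.0, 1.0]), 'relu')
# activated == [0., 0., 1.]; an unknown name raises NotImplementedError.
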
def apply(self, x, n_bins, stage_sizes, block_cls, num_filters=64,
          dtype=jnp.float32, act=nn.leaky_relu, train=True):
  b = x.shape[0]
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  norm = nn.BatchNorm.partial(use_running_average=not train, dtype=dtype)
  # Stem: strided 1-D convolution followed by max pooling.
  x = conv(x, num_filters, kernel_size=(7,), strides=(2,), padding=[(3, 3)])
  x = norm(x)
  x = nn.leaky_relu(x)
  x = nn.max_pool(x, window_shape=(3,), strides=(2,), padding='SAME')
  # Residual stages; the first block of each later stage downsamples.
  for i, block_size in enumerate(stage_sizes):
    for j in range(block_size):
      strides = (2,) if i > 0 and j == 0 else (1,)
      x = block_cls(x, num_filters * 2 ** i, strides=strides,
                    conv=conv, norm=norm, act=act)
  x = x.reshape(b, -1)
  x = nn.Dense(x, n_bins, dtype=dtype)
  x = nn.softmax(x)
  return x

def apply(self, x, num_actions, num_atoms):
  initializer = jax.nn.initializers.variance_scaling(
      scale=1.0 / jnp.sqrt(3.0), mode='fan_in', distribution='uniform')
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32) / 255.
  x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer)
  logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)

def rnn_dim2(carry, x):
  newCarry = actFun(cellInH(carry[1]) + cellInV(x[1]) +
                    cellCarryH(carry[0]) + cellCarryV(x[0]))
  logits = outputDense(newCarry)
  sampleOut = jax.random.categorical(x[2], logits)
  sample = jax.nn.one_hot(sampleOut, inputDim)
  logProb = jnp.log(jnp.sum(nn.softmax(logits) * sample, axis=1))
  output = (newCarry, logProb, sampleOut)
  return (newCarry, sample), output

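# Hedged illustration (assumed shapes, not from the source) of the one-hot
# log-prob trick used in the sampling cells here: multiplying the softmax by
# a one-hot sample and summing selects log p(sample).
import jax
import jax.numpy as jnp

logits = jnp.array([[2.0, 0.0, -1.0]])
sample = jax.nn.one_hot(jnp.array([0]), 3)  # the sampled class is 0
log_prob = jnp.log(jnp.sum(jax.nn.softmax(logits) * sample, axis=1))
# Equivalent (and numerically safer): jax.nn.log_softmax(logits)[:, 0]
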
def apply(self, x, num_actions, num_atoms, support, noisy, dueling):
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten
  # Rescaling to [-1, 1] is intentionally disabled here:
  # x -= gym_lib.CARTPOLE_MIN_VALS
  # x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
  # x = 2.0 * x - 1.0

  if noisy:
    initializer = None
    bias = True
    def net(x, features, bias, kernel_init):
      return NoisyNetwork(x, features, bias, kernel_init)
  else:
    initializer = nn.initializers.xavier_uniform()
    bias = None
    def net(x, features, bias, kernel_init):
      # kernel_init must be passed by keyword; passing it positionally would
      # bind it to nn.Dense's `bias` argument.
      return nn.Dense(x, features, kernel_init=kernel_init)

  x = net(x, features=512, bias=bias, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = net(x, features=512, bias=bias, kernel_init=initializer)
  x = jax.nn.relu(x)

  if dueling:
    adv = net(x, features=num_actions * num_atoms, bias=bias,
              kernel_init=initializer)
    value = net(x, features=num_atoms, bias=bias, kernel_init=initializer)
    adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
    value = value.reshape((value.shape[0], 1, num_atoms))
    # Center advantages across actions (the original averaged over atoms).
    logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
  else:
    x = net(x, features=num_actions * num_atoms, bias=bias,
            kernel_init=initializer)
    logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.sum(support * probabilities, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)

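# Hedged sketch (assumed names, not from the source) of how the distributional
# head above turns logits into scalar Q-values: softmax over the atom axis
# gives a return distribution per action, and its expectation under the fixed
# `support` (e.g. jnp.linspace(-10., 10., num_atoms) in C51) is the Q-value.
import jax
import jax.numpy as jnp

def expected_q(logits, support):
  # logits.shape: (batch, num_actions, num_atoms)
  probabilities = jax.nn.softmax(logits, axis=-1)
  return jnp.sum(support * probabilities, axis=2)  # (batch, num_actions)
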
def apply(self, x):
  # n_bin is expected to be defined in the enclosing scope.
  net = nn.Dense(x, 500, name='fc1')
  net = nn.leaky_relu(net)
  net = nn.BatchNorm(net)
  net = nn.Dense(net, 500, name='fc2')
  net = nn.leaky_relu(net)
  net = nn.BatchNorm(net)
  net = nn.Dense(net, 500, name='fc3')
  net = nn.leaky_relu(net)
  net = nn.BatchNorm(net)
  return nn.softmax(nn.Dense(net, n_bin))

def apply(self, x, num_actions, num_atoms, support, noisy, dueling):
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  initializer_conv = nn.initializers.xavier_uniform()
  x = x[None, ...]
  x = x.astype(jnp.float32)
  # 2-D kernel to match the (H, W) spatial dims; the original used
  # kernel_size=(3, 3, 3).
  x = nn.Conv(x, features=16, kernel_size=(3, 3), strides=(1, 1),
              kernel_init=initializer_conv)
  x = jax.nn.relu(x)
  x = x.reshape((x.shape[0], -1))  # flatten.

  if noisy:
    initializer = None
    bias = True
    def net(x, features, bias, kernel_init):
      return NoisyNetwork(x, features, bias, kernel_init)
  else:
    initializer = nn.initializers.xavier_uniform()
    bias = None
    def net(x, features, bias, kernel_init):
      # kernel_init must be passed by keyword; passing it positionally would
      # bind it to nn.Dense's `bias` argument.
      return nn.Dense(x, features, kernel_init=kernel_init)

  if dueling:
    adv = net(x, features=num_actions * num_atoms, bias=bias,
              kernel_init=initializer)
    value = net(x, features=num_atoms, bias=bias, kernel_init=initializer)
    adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
    value = value.reshape((value.shape[0], 1, num_atoms))
    # Center advantages across actions (the original averaged over atoms).
    logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
  else:
    x = net(x, features=num_actions * num_atoms, bias=bias,
            kernel_init=initializer)
    logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.sum(support * probabilities, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)

def apply(self, x):
  # n_bins is expected to be defined in the enclosing scope.
  b = x.shape[0]
  x = nn.Conv(x, features=128, kernel_size=(4,), padding='SAME')
  x = nn.BatchNorm(x)
  x = nn.leaky_relu(x)
  x = nn.avg_pool(x, window_shape=(2,), padding='SAME')
  x = nn.Conv(x, features=256, kernel_size=(4,), padding='SAME')
  x = nn.BatchNorm(x)
  x = nn.leaky_relu(x)
  x = nn.avg_pool(x, window_shape=(2,), padding='SAME')
  x = x.reshape(b, -1)
  x = nn.Dense(x, features=128)
  x = nn.BatchNorm(x)
  x = nn.leaky_relu(x)
  x = nn.Dense(x, features=n_bins)
  x = nn.softmax(x)
  return x

def apply(self, x, num_actions, num_atoms, support):
  def custom_init(key, shape, dtype=jnp.float32):
    del key
    # All-ones kernel, except the first action's atoms get ascending weights
    # so the untrained network prefers the first action.
    to_pick_first_action = onp.ones(shape, dtype)
    to_pick_first_action[:, :num_atoms] = onp.arange(1, num_atoms + 1)
    return to_pick_first_action

  x = x[None, :]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten
  x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=custom_init,
               bias_init=jax.nn.initializers.ones)
  logits = x.reshape((-1, num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  qs = jnp.sum(support * probabilities, axis=2)
  return atari_lib.RainbowNetworkType(qs, logits, probabilities)

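# Hedged worked example (not from the source) of what `custom_init` above
# produces for a small shape: with two actions and num_atoms = 3, the first
# action's atom columns get weights 1..3 and every other entry is 1.
import numpy as onp

num_atoms = 3
kernel = onp.ones((4, 2 * num_atoms), onp.float32)  # 4 inputs, 2 actions
kernel[:, :num_atoms] = onp.arange(1, num_atoms + 1)
# kernel[0] == [1., 2., 3., 1., 1., 1.]
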
def step_single_example(hidden_states, instruction_pointer, node_embeddings,
                        true_indexes, false_indexes, exit_index):
  # Execution (e.g. apply RNN)
  # leaves(hidden_states).shape: num_nodes, hidden_size
  # instruction_pointer.shape: num_nodes,
  # node_embeddings.shape: num_nodes, statement_length, hidden_size
  hidden_state_contributions = execute(hidden_states, node_embeddings)
  # leaves(hidden_state_contributions).shape: num_nodes, hidden_size

  # Use the exit node's hidden state as its hidden state contribution
  # to avoid "executing" the exit node.
  def mask_h(h_contribution, h):
    return h_contribution.at[exit_index, :].set(h[exit_index, :])
  hidden_state_contributions = jax.tree_multimap(
      mask_h, hidden_state_contributions, hidden_states)

  # Branch decisions (e.g. Dense layer)
  branch_decision_logits = branch_decide(hidden_state_contributions)
  branch_decisions = nn.softmax(branch_decision_logits, axis=-1)

  # Update state
  instruction_pointer_new = update_instruction_pointer(
      instruction_pointer, branch_decisions, true_indexes, false_indexes)
  hidden_states_new = aggregate(
      hidden_state_contributions, instruction_pointer, branch_decisions,
      true_indexes, false_indexes)

  to_tag = {
      'branch_decisions': branch_decisions,
      'hidden_state_contributions': hidden_state_contributions,
      'hidden_states_before': hidden_states,
      'hidden_states': hidden_states_new,
      'instruction_pointer_before': instruction_pointer,
      'instruction_pointer': instruction_pointer_new,
      'true_indexes': true_indexes,
      'false_indexes': false_indexes,
  }
  return hidden_states_new, instruction_pointer_new, to_tag

def apply(self, hidden_states, mask=None, *, d_qkv=64,
          attention_dropout_rate=0.0, output_dropout_rate=0.0,
          deterministic=False, kernel_init=nn.linear.default_kernel_init,
          output_kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.zeros, bias=True):
  """Applies attention for a single batch element and head."""
  d_model = hidden_states.shape[-1]
  dense = nn.DenseGeneral.partial(
      axis=-1, features=(d_qkv,), kernel_init=kernel_init,
      bias_init=bias_init, bias=bias)
  query, key, value = (dense(hidden_states, name='query'),
                       dense(hidden_states, name='key'),
                       dense(hidden_states, name='value'))
  attention_scores = jnp.einsum('TN,FN->FT', key, query)
  attention_scores = attention_scores / jnp.sqrt(d_qkv)
  if mask is not None:
    padding_mask = (1.0 - mask[None, :]) * NEG_INFINITY
    attention_scores = attention_scores + padding_mask
  attention_scores = nn.softmax(attention_scores)
  attention_probs = nn.dropout(
      attention_scores, rate=attention_dropout_rate,
      deterministic=deterministic)
  hidden_states = jnp.einsum('FT,TH->FH', attention_probs, value)
  hidden_states = nn.linear.DenseGeneral(
      hidden_states, features=d_model, axis=(-1,),
      kernel_init=output_kernel_init, name='output')
  hidden_states = nn.dropout(
      hidden_states, rate=output_dropout_rate, deterministic=deterministic)
  return hidden_states

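# Hedged shape walk-through (illustrative values, not from the source) of the
# two einsum contractions above, for a single head with no batch dimension:
# F = query length, T = key length, N = H = d_qkv.
import jax
import jax.numpy as jnp

F, T, N = 5, 7, 64
query, key, value = jnp.ones((F, N)), jnp.ones((T, N)), jnp.ones((T, N))
scores = jnp.einsum('TN,FN->FT', key, query) / jnp.sqrt(N)  # (F, T)
probs = jax.nn.softmax(scores, axis=-1)  # each query row sums to 1
context = jnp.einsum('FT,TH->FH', probs, value)  # (F, N)
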
def apply(self, x, num_actions, num_atoms):
  initializer = nn.initializers.xavier_uniform()
  # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
  # have removed the true batch dimension.
  x = x[None, ...]
  x = x.astype(jnp.float32)
  x = x.reshape((x.shape[0], -1))  # flatten
  x -= gym_lib.CARTPOLE_MIN_VALS
  x /= gym_lib.CARTPOLE_MAX_VALS - gym_lib.CARTPOLE_MIN_VALS
  x = 2.0 * x - 1.0  # Rescale in range [-1, 1].
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=512, kernel_init=initializer)
  x = jax.nn.relu(x)
  x = nn.Dense(x, features=num_actions * num_atoms, kernel_init=initializer)
  logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)

def generate_text(model, vocab, max_length=100, temperature=0.5, top_k=3,
                  start_letter="T"):
  output_text = start_letter
  carry = nn.GRUCell.initialize_carry(
      jax.random.PRNGKey(0), (1,), FLAGS.hidden_size)
  for _ in range(max_length):
    token = vocab.numericalize(output_text[-1])
    input_t = jnp.array(token, dtype=jnp.int32).reshape(1, 1)
    carry, pred = model(input_t, carry)
    prob = nn.softmax(pred / temperature, axis=1)
    # Greedy alternative:
    # output_text += vocab.textify(prob.argmax().tolist())[0]
    prob_np = np.array(prob)[0]
    top_k_index = prob_np.argsort()[-top_k:]
    # Renormalize over the top-k tokens. np.random.choice takes the
    # probabilities via the keyword `p`; the original passed them
    # positionally, which bound them to `replace` and sampled uniformly.
    p = prob_np[top_k_index] / prob_np[top_k_index].sum()
    next_char = np.random.choice(top_k_index, 1, p=p)
    output_text += vocab.textify(next_char.item())[0]
  return output_text

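# Hedged standalone sketch (assumed helper, not part of the original code) of
# the top-k / temperature sampling step used in generate_text above.
import numpy as np

def sample_top_k(logits, temperature=0.5, top_k=3):
  scaled = logits / temperature
  probs = np.exp(scaled - np.max(scaled))  # stable softmax
  probs /= probs.sum()
  top = probs.argsort()[-top_k:]  # indices of the k most likely tokens
  p = probs[top] / probs[top].sum()  # renormalize over the top k
  return np.random.choice(top, p=p)  # p must be passed by keyword
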
def step_single_example(hidden_states, instruction_pointer, node_embeddings,
                        true_indexes, false_indexes, exit_index):
  # Execution (e.g. apply RNN)
  # leaves(hidden_states).shape: num_nodes, hidden_size
  # instruction_pointer.shape: num_nodes,
  # node_embeddings.shape: num_nodes, statement_length, hidden_size
  if config.model.interpolant.apply_code_rnn:
    hidden_state_contributions = execute(hidden_states, node_embeddings)
    # leaves(hidden_state_contributions).shape: num_nodes, hidden_size
  else:
    hidden_state_contributions = hidden_states

  if config.model.interpolant.apply_dense:
    parent_to_true_child = jax.tree_map(
        dense_parent_to_true_child, hidden_state_contributions)
    parent_to_false_child = jax.tree_map(
        dense_parent_to_false_child, hidden_state_contributions)
    true_child_to_parent = jax.tree_map(
        dense_true_child_to_parent, hidden_state_contributions)
    false_child_to_parent = jax.tree_map(
        dense_false_child_to_parent, hidden_state_contributions)
  else:
    parent_to_true_child = hidden_state_contributions
    parent_to_false_child = hidden_state_contributions
    true_child_to_parent = hidden_state_contributions
    false_child_to_parent = hidden_state_contributions

  # Use the exit node's hidden state as its hidden state contribution
  # to avoid "executing" the exit node.
  def mask_h(h_contribution, h):
    return h_contribution.at[exit_index, :].set(h[exit_index, :])
  hidden_state_contributions = jax.tree_multimap(
      mask_h, hidden_state_contributions, hidden_states)

  # Branch decisions (e.g. Dense layer)
  branch_decision_logits = branch_decide(hidden_state_contributions)
  branch_decisions = nn.softmax(branch_decision_logits, axis=-1)

  # Update state
  if config.model.interpolant.use_ipa:
    instruction_pointer_new = update_instruction_pointer(
        instruction_pointer, branch_decisions, true_indexes, false_indexes)
    hidden_states_new = aggregate(
        hidden_state_contributions, instruction_pointer, branch_decisions,
        true_indexes, false_indexes)
  else:
    assert config.model.interpolant.use_parent_embeddings
    assert config.model.interpolant.use_child_embeddings
    instruction_pointer_new = instruction_pointer
    normalization = jnp.sqrt(
        2  # Each node has a true and false child.
        # jnp.bincount(true_indexes, minlength=num_nodes)
        + jax.ops.scatter.segment_sum(
            jnp.ones_like(true_indexes), true_indexes,
            num_segments=num_nodes)
        # + jnp.bincount(false_indexes, minlength=num_nodes)
        + jax.ops.scatter.segment_sum(
            jnp.ones_like(false_indexes), false_indexes,
            num_segments=num_nodes))
    # normalization.shape: num_nodes,

    def aggregate_parent_and_child_contributions(p1, p2, c3, c4):
      return (jax.ops.scatter.segment_sum(
                  p1, true_indexes, num_segments=num_nodes)
              + jax.ops.scatter.segment_sum(
                  p2, false_indexes, num_segments=num_nodes)
              + c3[true_indexes]
              + c4[false_indexes]) / normalization[:, None]

    hidden_states_new = jax.tree_multimap(
        aggregate_parent_and_child_contributions,
        parent_to_true_child, parent_to_false_child,
        true_child_to_parent,  # true_child_to_parent[child] -> parent
        false_child_to_parent)

  if config.model.interpolant.apply_gru:
    def apply_gru(h2, h1):
      output, _ = gru_cell(h2, h1)
      return output
    hidden_states_new = jax.tree_multimap(
        apply_gru, hidden_states_new, hidden_states)

  to_tag = {
      'branch_decisions': branch_decisions,
      'hidden_state_contributions': hidden_state_contributions,
      'hidden_states_before': hidden_states,
      'hidden_states': hidden_states_new,
      'instruction_pointer_before': instruction_pointer,
      'instruction_pointer': instruction_pointer_new,
      'true_indexes': true_indexes,
      'false_indexes': false_indexes,
  }
  return hidden_states_new, instruction_pointer_new, to_tag

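# Hedged mini-example (assumed setup, not from the source) of the segment-sum
# aggregation above: routing each parent's contribution to its child's slot.
# jax.ops.segment_sum is the public entry point for this operation.
import jax
import jax.numpy as jnp

num_nodes = 3
true_indexes = jnp.array([1, 2, 2])  # node i's true-branch child
contributions = jnp.array([[1.], [10.], [100.]])
per_child = jax.ops.segment_sum(
    contributions, true_indexes, num_segments=num_nodes)
# per_child == [[0.], [1.], [110.]]
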
def apply(self, x, num_actions, net_conf, env, normalize_obs, noisy, dueling,
          num_atoms, hidden_layer=2, neurons=512):
  del normalize_obs
  if net_conf == 'minatar':
    x = x.squeeze(3)
    x = x[None, ...]
    x = x.astype(jnp.float32)
    # 2-D kernel to match the (H, W) spatial dims; the original used
    # kernel_size=(3, 3, 3).
    x = nn.Conv(x, features=16, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))
  elif net_conf == 'atari':
    # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
    # have removed the true batch dimension.
    x = x[None, ...]
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(x, features=32, kernel_size=(8, 8), strides=(4, 4),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(4, 4), strides=(2, 2),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = nn.Conv(x, features=64, kernel_size=(3, 3), strides=(1, 1),
                kernel_init=nn.initializers.xavier_uniform())
    x = jax.nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten
  elif net_conf == 'classic':
    # Classic control environments.
    x = x[None, ...]
    x = x.astype(jnp.float32)
    x = x.reshape((x.shape[0], -1))

  if env is not None:
    # Rescale observations to [-1, 1] using per-environment bounds.
    x = x - env_inf[env]['MIN_VALS']
    x /= env_inf[env]['MAX_VALS'] - env_inf[env]['MIN_VALS']
    x = 2.0 * x - 1.0

  if noisy:
    def net(x, features):
      return NoisyNetwork(x, features)
  else:
    def net(x, features):
      return nn.Dense(x, features,
                      kernel_init=nn.initializers.xavier_uniform())

  for _ in range(hidden_layer):
    x = net(x, features=neurons)
    x = jax.nn.relu(x)

  if dueling:
    adv = net(x, features=num_actions * num_atoms)
    value = net(x, features=num_atoms)
    adv = adv.reshape((adv.shape[0], num_actions, num_atoms))
    value = value.reshape((value.shape[0], 1, num_atoms))
    # Center advantages across actions (the original averaged over atoms).
    logits = value + (adv - jnp.mean(adv, axis=1, keepdims=True))
  else:
    x = net(x, features=num_actions * num_atoms)
    logits = x.reshape((x.shape[0], num_actions, num_atoms))
  probabilities = nn.softmax(logits)
  q_values = jnp.mean(logits, axis=2)
  return atari_lib.RainbowNetworkType(q_values, logits, probabilities)

def lstm_cell(carry, x):
  newCarry, logits = jax.vmap(lstmCell)(carry[0], carry[1])
  sampleOut = jax.random.categorical(x, logits)
  sample = jax.nn.one_hot(sampleOut, inputDim)
  logProb = jnp.log(jnp.sum(nn.softmax(logits) * sample, axis=1))
  return (newCarry, sample), (logProb, sampleOut)

def rnn_dim2(carry, x):
  newCarry = actFun(cellInH(x[0]) + cellInV(x[1]) +
                    cellCarryH(carry) + cellCarryV(x[2]))
  out = jnp.concatenate((newCarry, nn.softmax(outputDense(newCarry))), axis=1)
  return newCarry, out

def lstm_cell(carry, x):
  newCarry, out = lstmCell(carry[0], carry[1])
  prob = nn.softmax(out)
  # Log-probability of the one-hot target x under the predicted distribution.
  prob = jnp.log(jnp.sum(prob * x, axis=-1))
  return (newCarry, x), prob