Example #1
    def apply_fun(params, inputs):
        conv_params, pair_params, conv_block_params, serial_params = params

        # Apply the primary convolutional layer.
        conv_out = conv_apply(conv_params, inputs)
        conv_out = relu(conv_out)

        # Group all possible pairs.
        W, b = pair_params
        pair_1 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,1), dim_nums) + b
        pair_2 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,2), dim_nums) + b
        pair_3 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,3), dim_nums) + b
        pair_4 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,4), dim_nums) + b
        pair_5 = conv_general_dilated(conv_out, W, unit_stride, zero_pad, (1,1), (1,5), dim_nums) + b
        pair_out = jnp.dstack([pair_1, pair_2, pair_3, pair_4, pair_5])
        pair_out = relu(pair_out)

        # Convolutional block.
        conv_block_out = conv_block_apply(conv_block_params, pair_out)

        # Residual connection.
        res_out = conv_block_out + pair_out
        res_out = relu(res_out)

        # Forward pass.
        out = serial_apply(serial_params, res_out)
        return out
Example #2
def predict_proba(params, teamA_rating, teamB_rating, has_tie):
    dr = (teamA_rating - teamB_rating) * params["beta"]
    gamma = nn.relu(params["gamma"]) * has_tie
    pA = jnp.clip(nn.sigmoid(dr - gamma), __EPS__, 1 - __EPS__)
    pB = jnp.clip(nn.sigmoid(-dr - gamma), __EPS__, 1 - __EPS__)
    pD = nn.relu(1.0 - pA - pB) * has_tie
    s = pA + pB + pD
    return [jnp.array(x, float) for x in [pA / s, pD / s, pB / s]]
Example #3
 def __call__(self, x):
     # Orthogonal initializer with He-style gain (sqrt(2)).
     w_init = hk.initializers.Orthogonal(scale=np.sqrt(2))
     # Convert the image to float32 and scale pixel values to [0, 1].
     x = x.astype(jnp.float32) / 255.0
     # Apply CNN.
     x = hk.Conv2D(32, kernel_shape=8, stride=4, padding="VALID", w_init=w_init)(x)
     x = nn.relu(x)
     x = hk.Conv2D(64, kernel_shape=4, stride=2, padding="VALID", w_init=w_init)(x)
     x = nn.relu(x)
     x = hk.Conv2D(64, kernel_shape=3, stride=1, padding="VALID", w_init=w_init)(x)
     x = nn.relu(x)
     # Flatten the feature map.
     return hk.Flatten()(x)
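A minimal usage sketch, not part of the original example: it assumes the __call__ above belongs to an hk.Module named CNNTorso (a hypothetical name) and shows how the Haiku network could be transformed, initialized, and applied.

import haiku as hk
import jax
import jax.numpy as jnp

def forward(obs):
    # CNNTorso is an assumed wrapper module containing the __call__ above.
    return CNNTorso()(obs)

forward = hk.without_apply_rng(hk.transform(forward))
obs = jnp.zeros((1, 84, 84, 4), dtype=jnp.uint8)   # Atari-style frame stack.
params = forward.init(jax.random.PRNGKey(0), obs)
features = forward.apply(params, obs)              # Flattened feature vectors.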
Example #4
    def get_latents(self, encodings, probs_b, training):
        """Read out latents (z) form input encodings for a single segment."""
        readout_mask = probs_b[:, 1:, None]  # Offset readout by 1 to left.
        readout = (encodings[:, :-1] * readout_mask).sum(1)
        hidden = nn.relu(self.head_z_1(readout))
        logits_z = self.head_z_2(hidden)

        # Gaussian latents.
        if self.latent_dist == 'gaussian':
            if training:
                mu, log_var = jnp.split(logits_z, 2, axis=1)
                sample_z = utils.gaussian_sample(hk.next_rng_key(), mu,
                                                 log_var)
            else:
                sample_z = logits_z[:, :self.latent_dim]

        # Concrete / Gumbel softmax latents.
        elif self.latent_dist == 'concrete':
            if training:
                sample_z = utils.gumbel_softmax_sample(hk.next_rng_key(),
                                                       logits_z,
                                                       temp=self.temp_z)
            else:
                sample_z_idx = jnp.argmax(logits_z, axis=1)
                sample_z = utils.to_one_hot(sample_z_idx, logits_z.shape[1])
        else:
            raise ValueError('Invalid argument for `latent_dist`.')

        return logits_z, sample_z
Example #5
    def identity_block(inputs):
        main = Sequential(Conv(filters1, (1, 1)), BatchNorm(), relu,
                          Conv(filters2, (ks, ks), padding='SAME'),
                          BatchNorm(), relu, Conv(inputs.shape[3], (1, 1)),
                          BatchNorm())

        return relu(sum((main(inputs), inputs)))
Example #6
 def conv_block(inputs):
     main = Sequential(Conv(filters1, (1, 1), strides), BatchNorm(), relu,
                       Conv(filters2, (ks, ks),
                            padding='SAME'), BatchNorm(), relu,
                       Conv(filters3, (1, 1)), BatchNorm())
     shortcut = Sequential(Conv(filters3, (1, 1), strides), BatchNorm())
     return relu(sum((main(inputs), shortcut(inputs))))
Example #7
    def get_boundaries(self, encodings, segment_id, lengths, training):
        """Get boundaries (b) for a single segment in batch."""
        if segment_id == self.max_num_segments - 1:
            # Last boundary is always placed on last sequence element.
            logits_b = None
            # Place a 1 at the last valid position (given by `lengths`) of
            # each sequence in the batch.
            sample_b = jnp.zeros_like(encodings[:, :, 0])
            sample_b = sample_b.at[jnp.arange(len(lengths)), lengths - 1].set(1)
        else:
            hidden = nn.relu(self.head_b_1(encodings))
            logits_b = jnp.squeeze(self.head_b_2(hidden), -1)
            # Mask out first position with large neg. value.
            neg_inf = jnp.ones((encodings.shape[0], 1)) * utils.NEG_INF
            # TODO(tkipf): Mask out padded positions with large neg. value.
            logits_b = jnp.concatenate([neg_inf, logits_b[:, 1:]], axis=1)
            if training:
                sample_b = utils.gumbel_softmax_sample(hk.next_rng_key(),
                                                       logits_b,
                                                       temp=self.temp_b)
            else:
                sample_b_idx = jnp.argmax(logits_b, axis=1)
                sample_b = nn.one_hot(sample_b_idx, logits_b.shape[1])

        return logits_b, sample_b
Example #8
        def _fn(x, cum_p):
            if len(x.shape) == 4:
                x = DQNBody()(x)

            # NOTE: For IQN and FQF, the number of quantiles is variable.
            feature_dim = x.shape[1]
            num_quantiles = cum_p.shape[1]
            # Calculate features.
            cosine = jnp.cos(jnp.expand_dims(cum_p, 2) * self.pi).reshape(
                -1, self.num_cosines)
            cosine_feature = nn.relu(hk.Linear(feature_dim)(cosine)).reshape(
                -1, num_quantiles, feature_dim)
            x = (x.reshape(-1, 1, feature_dim) * cosine_feature).reshape(
                -1, feature_dim)
            # Apply quantile network.
            output = MLP(
                self.action_space.n,
                self.hidden_units,
                hidden_activation=nn.relu,
                hidden_scale=np.sqrt(2),
            )(x)
            output = output.reshape(-1, num_quantiles, self.action_space.n)
            if self.dueling_net:
                baseline = MLP(
                    1,
                    self.hidden_units,
                    hidden_activation=nn.relu,
                    hidden_scale=np.sqrt(2),
                )(x)
                baseline = baseline.reshape(-1, num_quantiles, 1)
                return output + baseline - output.mean(axis=2, keepdims=True)
            else:
                return output
Example #9
        def update_ratings(params, teamA_rating, teamB_rating, teamA_idx,
                           teamB_idx, winner, rating):
            '''Update rating step'''

            pA, _, pB = predict_proba(params, teamA_rating, teamB_rating,
                                      self.has_tie)

            operand = nn.relu(params["lr"])
            delta_A_d = lax.cond(
                winner == 1.0,
                lambda x: x * (pB - pA),
                lambda x: 0.0,
                operand,
            )
            delta_B_d = -delta_A_d

            delta_A_win = lax.cond(winner == 0.0, lambda x: x * (1 - pA),
                                   lambda x: 0.0, operand)
            delta_B_lose = lax.cond(winner == 0.0, lambda x: x * (0 - pB),
                                    lambda x: 0.0, operand)

            delta_A_lose = lax.cond(winner == 2.0, lambda x: x * (0 - pA),
                                    lambda x: 0.0, operand)
            delta_B_win = lax.cond(winner == 2.0, lambda x: x * (1 - pB),
                                   lambda x: 0.0, operand)

            delta_A = delta_A_d + delta_A_win + delta_A_lose
            delta_B = delta_B_d + delta_B_lose + delta_B_win

            rating = rating.at[teamA_idx].add(jnp.tanh(delta_A))
            rating = rating.at[teamB_idx].add(jnp.tanh(delta_B))
            return rating
Example #10
import jax.numpy as jnp
from jax import nn
from jax.scipy.special import logsumexp

def predict(params, inputs):
    activations = inputs
    # Hidden layers: affine transform followed by ReLU.
    for w, b in params[:-1]:
        outputs = jnp.dot(activations, w) + b
        activations = nn.relu(outputs)

    # Output layer, normalized to log-probabilities via log-softmax.
    final_w, final_b = params[-1]
    logits = jnp.dot(activations, final_w) + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)
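A short sketch of how predict might be driven; the layer sizes and the 0.01 init scale below are illustrative assumptions, not taken from the original source.

import jax.numpy as jnp
from jax import random

def init_params(key, sizes):
    # One (weights, biases) pair per layer.
    params = []
    for in_dim, out_dim in zip(sizes[:-1], sizes[1:]):
        key, w_key = random.split(key)
        params.append((0.01 * random.normal(w_key, (in_dim, out_dim)),
                       jnp.zeros(out_dim)))
    return params

params = init_params(random.PRNGKey(0), [784, 512, 256, 10])
log_probs = predict(params, jnp.ones((32, 784)))  # shape (32, 10) of log-probabilities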
Example #11
def developing_step(key, state, state_timer, recovery_probabilities,
                    state_length_sampler):
  to_develop = np.logical_and(state_timer == 1, is_transitional(state))
  state_timer = relu(state_timer - 1)
  key, new_state = sample_development(key, state, recovery_probabilities)
  key, new_state_timer = state_length_sampler(key, new_state)
  return (key,
          state * (1 - to_develop) + new_state * to_develop,
          state_timer * (1 - to_develop) + new_state_timer * to_develop)
Example #12
def gaussian_and_tanh_log_prob(
    log_std: jnp.ndarray,
    noise: jnp.ndarray,
    action: jnp.ndarray,
) -> jnp.ndarray:
    """
    Calculate the log probability of a tanh-transformed Gaussian sample
    (Gaussian log-density with the tanh change-of-variables correction).
    """
    return gaussian_log_prob(
        log_std, noise) - jnp.log(nn.relu(1.0 - jnp.square(action)) + 1e-6)
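For context, here is a plausible companion definition and a usage sketch; both are assumptions rather than code taken from the source. gaussian_log_prob is assumed to return the element-wise log-density of the reparameterized sample, and the action is assumed to be the tanh of that sample.

import math

import jax
import jax.numpy as jnp
from jax import nn  # needed by gaussian_and_tanh_log_prob above

def gaussian_log_prob(log_std, noise):
    # Element-wise log-density of mean + exp(log_std) * noise under N(mean, exp(log_std)^2).
    return -0.5 * (jnp.square(noise) + 2.0 * log_std + math.log(2.0 * math.pi))

def sample_and_log_prob(mean, log_std, key):
    # Reparameterized Gaussian sample squashed through tanh.
    noise = jax.random.normal(key, mean.shape)
    action = jnp.tanh(mean + noise * jnp.exp(log_std))
    # Sum the per-dimension terms over the action dimension.
    return action, gaussian_and_tanh_log_prob(log_std, noise, action).sum(axis=-1)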
Example #13
    def apply_fun(params, x, adj, is_training=False, **kwargs):
        rng = kwargs.pop('rng', None)
        k1, k2, k3, k4 = random.split(rng, 4)

        x = drop_fun(None, x, is_training=is_training, rng=k1)
        x = gc1_fun(params[0], x, adj, rng=k2)
        x = nn.relu(x)
        x = drop_fun(None, x, is_training=is_training, rng=k3)
        x = gc2_fun(params[1], x, adj, rng=k4)
        x = nn.log_softmax(x)
        return x
Example #14
def compute_feature_energy(params, x):
    A_1, A_2, *_ = params
    return relu(dense(A_2, relu(dense(A_1, x))))
Example #15
 def decode(self, sample_z, length):
     """Decode single time step from latents and repeat over full seq."""
     hidden = nn.relu(self.decode_1(sample_z))
     pred = self.decode_2(hidden)
     return jnp.tile(jnp.expand_dims(pred, 1), (1, length, 1))
Example #16
 def _relu(x):
     return relu(x)
Example #17
 def __call__(self, x: DeviceArray) -> DeviceArray:
     x = self.linear1(x)
     x = relu(x)
     x = self.linear2(x)
     return x