def probabilities_from_activation(self, vmap):
    """Return the softmax-with-zero probabilities for this unit's activation.

    Args:
        vmap: Mapping from which this unit's activation is looked up using
            ``self`` as the key (presumably unit object -> symbolic
            activation; verify against callers).

    Returns:
        Whatever ``activation_functions.softmax_with_zero`` produces for
        that activation.
    """
    activation = vmap[self]
    return activation_functions.softmax_with_zero(activation)
def sample_from_activation(self, vmap):
    """Draw a multinomial sample from this unit's activation, minus the zero state.

    The activation stored under ``self`` in ``vmap`` is turned into
    probabilities with ``activation_functions.softmax_with_zero``. Those
    probabilities are passed to ``samplers.multinomial``. The last entry
    along axis 2 of the drawn sample is then discarded.

    Args:
        vmap: Mapping from which this unit's activation is looked up using
            ``self`` as the key.

    Returns:
        The sample with its final axis-2 entry removed. The indexing below
        requires the sample to have at least 3 dimensions. Presumably the
        layout is (minibatch, units, states); TODO confirm.
    """
    probs = activation_functions.softmax_with_zero(vmap[self])
    drawn = samplers.multinomial(probs)
    # The last state along axis 2 is the extra "zero" state, so chop it off.
    trimmed = drawn[:, :, :-1]
    return trimmed