def test_lstm(self):
    """Check LSTMCell carry/output shapes and its parameter layout.

    Input-to-hidden projections (``ii``, ``if``, ``ig``, ``io``) are expected
    to hold only a kernel, while hidden-to-hidden projections (``hi``, ``hf``,
    ``hg``, ``ho``) hold both a kernel and a bias.
    """
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    x = random.normal(key1, (2, 3))
    # Carry for a batch of 2 with hidden size 4.
    c0, h0 = nn.LSTMCell.initialize_carry(rng, (2, ), 4)
    self.assertEqual(c0.shape, (2, 4))
    self.assertEqual(h0.shape, (2, 4))
    lstm = nn.LSTMCell()
    (carry, y), initial_params = lstm.init_with_output(key2, (c0, h0), x)
    self.assertEqual(carry[0].shape, (2, 4))
    self.assertEqual(carry[1].shape, (2, 4))
    # The cell's output must equal the second element of the new carry.
    np.testing.assert_allclose(y, carry[1])
    # `jax.tree_map` is a deprecated alias (removed in later JAX releases);
    # `jax.tree_util.tree_map` is the same function under a stable name.
    param_shapes = jax.tree_util.tree_map(np.shape, initial_params['params'])
    self.assertEqual(
        param_shapes,
        {
            'ii': {'kernel': (3, 4)},
            'if': {'kernel': (3, 4)},
            'ig': {'kernel': (3, 4)},
            'io': {'kernel': (3, 4)},
            'hi': {'kernel': (4, 4), 'bias': (4, )},
            'hf': {'kernel': (4, 4), 'bias': (4, )},
            'hg': {'kernel': (4, 4), 'bias': (4, )},
            'ho': {'kernel': (4, 4), 'bias': (4, )},
        },
    )
def __call__(self, graph, feat):
    r"""
    Compute set2set pooling.

    A single LSTM cell is applied on every one of the ``self.n_iters``
    processing steps, so all steps share the same parameters.

    Parameters
    ----------
    graph : DGLGraph
        The input graph.
    feat : jax.numpy.ndarray
        The input feature with shape :math:`(N, D)` where :math:`N` is the
        number of nodes in the graph, and :math:`D` means the size of
        features.

    Returns
    -------
    jax.numpy.ndarray
        The output feature with shape :math:`(B, 2D)`, where :math:`B`
        refers to the batch size. It is the LSTM query (width
        ``self.input_dim``, expected to equal :math:`D`) concatenated with
        the attention readout (width :math:`D`).
    """
    with graph.local_scope():
        batch_size = graph.batch_size
        # LSTM carry (c, h).
        # NOTE(review): a single LSTMCell only broadcasts over the leading
        # ``n_layers`` axis and does not stack layers. Only ``q[0]`` is used
        # below. Confirm that n_layers > 1 is not expected to mean a deep LSTM.
        carry = (jnp.zeros((self.n_layers, batch_size, self.input_dim)),
                 jnp.zeros((self.n_layers, batch_size, self.input_dim)))
        q_star = jnp.zeros((batch_size, self.output_dim))
        # Build the cell once, outside the loop. Instantiating it inside the
        # loop created a new LSTMCell, with its own parameters, on every
        # iteration, whereas Set2Set reuses one LSTM across all steps.
        # Assumes self.output_dim == 2 * self.input_dim (as in DGL's Set2Set),
        # so the cell sees the same input width on every step.
        lstm = nn.LSTMCell()
        for _ in range(self.n_iters):
            carry, q = lstm(carry, jnp.expand_dims(q_star, 0))
            q = q[0].reshape((batch_size, self.input_dim))
            # Per-node attention logit: feature . query of the node's graph.
            e = (feat * broadcast_nodes(graph, q)).sum(axis=-1, keepdims=True)
            graph.ndata['e'] = e
            alpha = softmax_nodes(graph, 'e')
            graph.ndata['r'] = feat * alpha
            readout = sum_nodes(graph, 'r')
            q_star = jnp.concatenate([q, readout], axis=-1)
        return q_star
def test_optimized_lstm_cell_matches_regular(self):
    """OptimizedLSTMCell must reproduce LSTMCell's output and parameters.

    Both cells are initialised from identical PRNG keys, inputs and carries.
    The test then asserts that their outputs are close and that their
    parameter trees are equal.
    """

    def init_cell(cell_cls):
        # Derive keys, input and carry identically for whichever cell class.
        seed = random.PRNGKey(0)
        input_key, init_key = random.split(seed)
        inputs = random.normal(input_key, (2, 3))
        c_init, h_init = cell_cls.initialize_carry(seed, (2,), 4)
        self.assertEqual(c_init.shape, (2, 4))
        self.assertEqual(h_init.shape, (2, 4))
        (_, out), params = cell_cls().init_with_output(
            init_key, (c_init, h_init), inputs)
        return out, params

    y, lstm_params = init_cell(nn.LSTMCell)
    y_opt, lstm_opt_params = init_cell(nn.OptimizedLSTMCell)

    np.testing.assert_allclose(y, y_opt, rtol=1e-6)
    jtu.check_eq(lstm_params, lstm_opt_params)
def __call__(self, c, xs):
    """Apply an ``nn.LSTMCell`` named ``lstm_cell`` to carry ``c`` and input ``xs``.

    Returns the cell's result unchanged.
    """
    cell = nn.LSTMCell(name="lstm_cell")
    return cell(c, xs)
def __call__(self, carry, x):
    """Apply an auto-named ``nn.LSTMCell`` to ``carry`` and ``x``.

    Returns the cell's result unchanged.
    """
    step = nn.LSTMCell()
    return step(carry, x)
def __call__(self, c, b, xs):
    """Apply an ``nn.LSTMCell`` named ``lstm_cell`` to carry ``c`` and input ``xs``.

    ``b`` is not passed to the cell; it is only asserted to have shape ``(4,)``.
    """
    assert b.shape == (4, )
    cell = nn.LSTMCell(name="lstm_cell")
    return cell(c, xs)
episode_reward = 0 episode_timesteps = 0 episode_num = 0 for t in range(int(max_episodes)): language_states = np.zeros((episode_max_time, args.language_dim)) vision_states = np.zeros((episode_max_time, *vision_dim)) actions = np.zeros((episode_max_time, args.action_dim)) rewards = np.zeros((episode_max_time, 1)) discounts = np.zeros((episode_max_time, 1)) memory_hiddens = np.zeros((episode_max_time, args.memory_hidden_dim)) reconstruction_hiddens = np.zeros( (episode_max_time, args.embedding_dim)) episode_logits = np.zeros((episode_max_time, wrapped_env.action_dim)) h_prev = nn.LSTMCell().initialize_carry(next(policy.rng), (1, ), args.memory_hidden_dim) decoder_h_prev = nn.LSTMCell().initialize_carry( next(policy.rng), (1, ), args.embedding_dim) while not timestep.last(): language_state = policy.get_tokens(timestep) vision_state = timestep.observation["RGB_INTERLEAVED"] language_states[episode_timesteps] = language_state vision_states[episode_timesteps] = vision_state memory_hiddens[episode_timesteps] = h_prev[1].squeeze() reconstruction_hiddens[episode_timesteps] = decoder_h_prev[ 1].squeeze() if t < start_episodes: logits = np.random.rand(wrapped_env.action_dim)