def test_nested(self):
    eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
    agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
    f = [
        {
            "a": np.array([1, 2, 3, 4, 13, 14, 15, 16]),
            "b": {"ba": np.array([5, 6, 7, 8, 9, 10, 11, 12])},
        }
    ]
    s = [[209, 208, 207, 109, 108, 107, 106, 105]]
    f_pad, s_init, seq_lens = chop_into_sequences(
        episode_ids=eps_ids,
        unroll_ids=np.ones_like(eps_ids),
        agent_indices=agent_ids,
        feature_columns=f,
        state_columns=s,
        max_seq_len=4,
        handle_nested_data=True,
    )
    check(
        f_pad,
        [
            [
                [1, 2, 3, 0, 4, 13, 14, 15, 16, 0, 0, 0],
                [5, 6, 7, 0, 8, 9, 10, 11, 12, 0, 0, 0],
            ]
        ],
    )
    self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
    self.assertEqual(seq_lens.tolist(), [3, 4, 1])

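# A minimal sketch of the imports the test snippets in this section assume.
# The exact module paths vary across Ray versions (recent releases keep
# chop_into_sequences in rnn_sequencing; older ones kept it under the LSTM
# model utilities), so adjust to your installed version:
import numpy as np

from ray.rllib.policy.rnn_sequencing import chop_into_sequences
from ray.rllib.utils.test_utils import check
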
def _get_loss_inputs_dict(self, batch, shuffle, cross_policy_obj,
                          policy_id=None):
    """When training, add the required data into the feed_dict."""
    feed_dict = {}
    if hasattr(self, "_AddLossMixin_initialized"):
        assert self._AddLossMixin_initialized
        # Parse the cross-policy info and put it in the feed_dict.
        joint_obs_ph = self._loss_input_dict[JOINT_OBS]
        feed_dict[joint_obs_ph] = cross_policy_obj[JOINT_OBS]
        replay_ph = self._loss_input_dict[PEER_ACTION]
        feed_dict[replay_ph] = np.concatenate([
            act for pid, act in cross_policy_obj[PEER_ACTION].items()
            if pid != policy_id
        ])  # exclude the policy's own actions

    # The code below is copied from RLlib.
    if self._batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
            == 0 and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
    else:
        meets_divisibility_reqs = True

    # Simple case: not RNN nor do we need to pad
    if not self._state_inputs and meets_divisibility_reqs:
        if shuffle:
            batch.shuffle()
        for k, ph in self._loss_inputs:
            if k in batch:  # Attention! We add this condition here.
                feed_dict[ph] = batch[k]
        return feed_dict

    if self._state_inputs:
        max_seq_len = self._max_seq_len
        dynamic_max = True
    else:
        max_seq_len = self._batch_divisibility_req
        dynamic_max = False

    # RNN or multi-agent case
    feature_keys = [k for k, v in self._loss_inputs]
    state_keys = [
        "state_in_{}".format(i) for i in range(len(self._state_inputs))
    ]
    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        batch[SampleBatch.EPS_ID],
        batch[SampleBatch.UNROLL_ID],
        batch[SampleBatch.AGENT_INDEX],
        [batch[k] for k in feature_keys],
        [batch[k] for k in state_keys],
        max_seq_len,
        dynamic_max=dynamic_max,
        shuffle=shuffle)
    for k, v in zip(feature_keys, feature_sequences):
        feed_dict[self._loss_input_dict[k]] = v
    for k, v in zip(state_keys, initial_states):
        feed_dict[self._loss_input_dict[k]] = v
    feed_dict[self._seq_lens] = seq_lens
    return feed_dict

def testBatchId(self):
    eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
    batch_ids = [1, 1, 2, 2, 3, 3, 4, 4]
    agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
    f = [[101, 102, 103, 201, 202, 203, 204, 205],
         [[101], [102], [103], [201], [202], [203], [204], [205]]]
    s = [[209, 208, 207, 109, 108, 107, 106, 105]]
    _, _, seq_lens = chop_into_sequences(eps_ids, batch_ids, agent_ids, f,
                                         s, 4)
    self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2])

def test_chop_into_sequences_long_seq(self):
    """Test pad_batch where episodes are longer than max_seq_len.

    The long seq should be split into smaller seqs no longer than
    max_seq_len.
    """
    max_seq_len = 2
    # Input seq lens, corresponding to ep_ids, unroll_ids, etc.
    seq_lens = [2, 3, 1]  # noqa: F841
    ep_ids = [0, 0, 1, 1, 1, 2]
    unroll_ids = [2, 2, 3, 3, 3, 4]
    feats = [[1, 1, 2, 2, 2, 3]]
    # Input states, i.e. states[3] is the input state at
    # t = 3 and the output state at t = 2
    states = [[1, 2, 3, 4, 5, 6]]
    agent = [0, 0, 0, 0, 0, 0]
    f_pad, s_init, s_lens = chop_into_sequences(
        feature_columns=feats,
        state_columns=states,
        max_seq_len=max_seq_len,
        episode_ids=ep_ids,
        unroll_ids=unroll_ids,
        agent_indices=agent,
        dynamic_max=False,
    )
    # The 3-step episode is chopped into chunks of lengths 2 and 1, so the
    # chunks start at indices 0, 2, 4, 5 and the initial states are the
    # state values at those indices: 1, 3, 5, 6.
    expected_f_pad = [[1, 1, 2, 2, 2, 0, 3, 0]]
    expected_seq_lens = [2, 2, 1, 1]
    expected_states = [[1, 3, 5, 6]]
    check(f_pad, expected_f_pad)
    check(s_lens, expected_seq_lens)
    check(s_init, expected_states)

    # Try again with dynamic max: the longest chunk already equals
    # max_seq_len, so the padded output is identical.
    f_pad, s_init, s_lens = chop_into_sequences(
        feature_columns=feats,
        state_columns=states,
        max_seq_len=max_seq_len,
        episode_ids=ep_ids,
        unroll_ids=unroll_ids,
        agent_indices=agent,
        dynamic_max=True,
    )
    check(f_pad, expected_f_pad)
    check(s_lens, expected_seq_lens)
    check(s_init, expected_states)

def testDynamicMaxLen(self):
    eps_ids = [5, 2, 2]
    agent_ids = [2, 2, 2]
    f = [[1, 1, 1]]
    s = [[1, 1, 1]]
    f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
                                                  np.ones_like(eps_ids),
                                                  agent_ids, f, s, 4)
    self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
    self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
    self.assertEqual(seq_lens.tolist(), [1, 2])

def test_batch_id(self):
    eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
    batch_ids = [1, 1, 2, 2, 3, 3, 4, 4]
    agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
    f = [[101, 102, 103, 201, 202, 203, 204, 205],
         [[101], [102], [103], [201], [202], [203], [204], [205]]]
    s = [[209, 208, 207, 109, 108, 107, 106, 105]]
    _, _, seq_lens = chop_into_sequences(
        episode_ids=eps_ids,
        unroll_ids=batch_ids,
        agent_indices=agent_ids,
        feature_columns=f,
        state_columns=s,
        max_seq_len=4)
    self.assertEqual(seq_lens.tolist(), [2, 1, 1, 2, 2])

def testMultiDim(self):
    eps_ids = [1, 1, 1]
    agent_ids = [1, 1, 1]
    obs = np.ones((84, 84, 4))
    f = [[obs, obs * 2, obs * 3]]
    s = [[209, 208, 207]]
    f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
                                                  np.ones_like(eps_ids),
                                                  agent_ids, f, s, 4)
    self.assertEqual([f.tolist() for f in f_pad], [
        np.array([obs, obs * 2, obs * 3]).tolist(),
    ])
    self.assertEqual([s.tolist() for s in s_init], [[209]])
    self.assertEqual(seq_lens.tolist(), [3])

def test_dynamic_max_len(self):
    eps_ids = [5, 2, 2]
    agent_ids = [2, 2, 2]
    f = [[1, 1, 1]]
    s = [[1, 1, 1]]
    f_pad, s_init, seq_lens = chop_into_sequences(
        episode_ids=eps_ids,
        unroll_ids=np.ones_like(eps_ids),
        agent_indices=agent_ids,
        feature_columns=f,
        state_columns=s,
        max_seq_len=4)
    self.assertEqual([f.tolist() for f in f_pad], [[1, 0, 1, 1]])
    self.assertEqual([s.tolist() for s in s_init], [[1, 1]])
    self.assertEqual(seq_lens.tolist(), [1, 2])

def testMultiAgent(self):
    eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
    agent_ids = [1, 1, 2, 1, 1, 2, 2, 3]
    f = [[101, 102, 103, 201, 202, 203, 204, 205],
         [[101], [102], [103], [201], [202], [203], [204], [205]]]
    s = [[209, 208, 207, 109, 108, 107, 106, 105]]
    f_pad, s_init, seq_lens = chop_into_sequences(
        eps_ids,
        np.ones_like(eps_ids),
        agent_ids,
        f,
        s,
        4,
        dynamic_max=False)
    self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1])
    self.assertEqual(len(f_pad[0]), 20)
    self.assertEqual(len(s_init[0]), 5)

def testBasic(self):
    eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
    agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
    f = [[101, 102, 103, 201, 202, 203, 204, 205],
         [[101], [102], [103], [201], [202], [203], [204], [205]]]
    s = [[209, 208, 207, 109, 108, 107, 106, 105]]
    f_pad, s_init, seq_lens = chop_into_sequences(eps_ids,
                                                  np.ones_like(eps_ids),
                                                  agent_ids, f, s, 4)
    self.assertEqual([f.tolist() for f in f_pad], [
        [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
        [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
         [0], [0]],
    ])
    self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
    self.assertEqual(seq_lens.tolist(), [3, 4, 1])

def test_multi_agent(self):
    eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
    agent_ids = [1, 1, 2, 1, 1, 2, 2, 3]
    f = [[101, 102, 103, 201, 202, 203, 204, 205],
         [[101], [102], [103], [201], [202], [203], [204], [205]]]
    s = [[209, 208, 207, 109, 108, 107, 106, 105]]
    f_pad, s_init, seq_lens = chop_into_sequences(
        episode_ids=eps_ids,
        unroll_ids=np.ones_like(eps_ids),
        agent_indices=agent_ids,
        feature_columns=f,
        state_columns=s,
        max_seq_len=4,
        dynamic_max=False)
    self.assertEqual(seq_lens.tolist(), [2, 1, 2, 2, 1])
    self.assertEqual(len(f_pad[0]), 20)
    self.assertEqual(len(s_init[0]), 5)

def test_multi_dim(self):
    eps_ids = [1, 1, 1]
    agent_ids = [1, 1, 1]
    obs = np.ones((84, 84, 4))
    f = [[obs, obs * 2, obs * 3]]
    s = [[209, 208, 207]]
    f_pad, s_init, seq_lens = chop_into_sequences(
        episode_ids=eps_ids,
        unroll_ids=np.ones_like(eps_ids),
        agent_indices=agent_ids,
        feature_columns=f,
        state_columns=s,
        max_seq_len=4)
    self.assertEqual([f.tolist() for f in f_pad], [
        np.array([obs, obs * 2, obs * 3]).tolist(),
    ])
    self.assertEqual([s.tolist() for s in s_init], [[209]])
    self.assertEqual(seq_lens.tolist(), [3])

def test_basic(self):
    eps_ids = [1, 1, 1, 5, 5, 5, 5, 5]
    agent_ids = [1, 1, 1, 1, 1, 1, 1, 1]
    f = [[101, 102, 103, 201, 202, 203, 204, 205],
         [[101], [102], [103], [201], [202], [203], [204], [205]]]
    s = [[209, 208, 207, 109, 108, 107, 106, 105]]
    f_pad, s_init, seq_lens = chop_into_sequences(
        episode_ids=eps_ids,
        unroll_ids=np.ones_like(eps_ids),
        agent_indices=agent_ids,
        feature_columns=f,
        state_columns=s,
        max_seq_len=4)
    self.assertEqual([f.tolist() for f in f_pad], [
        [101, 102, 103, 0, 201, 202, 203, 204, 205, 0, 0, 0],
        [[101], [102], [103], [0], [201], [202], [203], [204], [205], [0],
         [0], [0]],
    ])
    self.assertEqual([s.tolist() for s in s_init], [[209, 109, 105]])
    self.assertEqual(seq_lens.tolist(), [3, 4, 1])

def _get_loss_inputs_dict(self, batch, shuffle):
    """Return a feed dict from a batch.

    Arguments:
        batch (SampleBatch): batch of data to derive inputs from
        shuffle (bool): whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.

    Returns:
        feed dict of data
    """
    feed_dict = {}
    if self._batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
            == 0 and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
    else:
        meets_divisibility_reqs = True

    # Simple case: not RNN nor do we need to pad
    if not self._state_inputs and meets_divisibility_reqs:
        if shuffle:
            batch.shuffle()
        for k, ph in self._loss_inputs:
            feed_dict[ph] = batch[k]
        return feed_dict

    if self._state_inputs:
        max_seq_len = self._max_seq_len
        dynamic_max = True
    else:
        max_seq_len = self._batch_divisibility_req
        dynamic_max = False

    # RNN or multi-agent case
    feature_keys = [k for k, v in self._loss_inputs]
    state_keys = [
        "state_in_{}".format(i) for i in range(len(self._state_inputs))
    ]
    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        batch[SampleBatch.EPS_ID],
        batch[SampleBatch.UNROLL_ID],
        batch[SampleBatch.AGENT_INDEX],
        [batch[k] for k in feature_keys],
        [batch[k] for k in state_keys],
        max_seq_len,
        dynamic_max=dynamic_max,
        shuffle=shuffle)
    for k, v in zip(feature_keys, feature_sequences):
        feed_dict[self._loss_input_dict[k]] = v
    for k, v in zip(state_keys, initial_states):
        feed_dict[self._loss_input_dict[k]] = v
    feed_dict[self._seq_lens] = seq_lens

    if log_once("rnn_feed_dict"):
        logger.info("Padded input for RNN:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
    return feed_dict

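# A minimal sketch (to_time_batches is a hypothetical helper, not RLlib API)
# of the reshape that consumers apply to the padded [B * T, ...] columns
# returned by chop_into_sequences; the learn_on_batch variants below
# implement the same pattern with torch tensors:
import numpy as np

def to_time_batches(flat_column, seq_lens):
    """Reshape a padded flat column from [B * T, ...] to [B, T, ...]."""
    B, T = len(seq_lens), max(seq_lens)
    return np.reshape(flat_column, [B, T] + list(flat_column.shape[1:]))
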
def learn_on_batch(self, samples):
    obs_batch, action_mask = self._unpack_observation(
        samples[SampleBatch.CUR_OBS])
    next_obs_batch, next_action_mask = self._unpack_observation(
        samples[SampleBatch.NEXT_OBS])
    group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

    # These will be padded to shape [B * T, ...]
    [rew, action_mask, next_action_mask, act, dones, obs, next_obs], \
        initial_states, seq_lens = \
        chop_into_sequences(
            samples[SampleBatch.EPS_ID],
            samples[SampleBatch.UNROLL_ID],
            samples[SampleBatch.AGENT_INDEX], [
                group_rewards, action_mask, next_action_mask,
                samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
                obs_batch, next_obs_batch
            ],
            [samples["state_in_{}".format(k)]
             for k in range(len(self.get_initial_state()))],
            max_seq_len=self.config["model"]["max_seq_len"],
            dynamic_max=True)
    B, T = len(seq_lens), max(seq_lens)

    def to_batches(arr):
        new_shape = [B, T] + list(arr.shape[1:])
        return th.from_numpy(np.reshape(arr, new_shape))

    rewards = to_batches(rew).float()
    actions = to_batches(act).long()
    obs = to_batches(obs).reshape([B, T, self.n_agents,
                                   self.obs_size]).float()
    action_mask = to_batches(action_mask)
    next_obs = to_batches(next_obs).reshape(
        [B, T, self.n_agents, self.obs_size]).float()
    next_action_mask = to_batches(next_action_mask)

    # TODO(ekl) this treats group termination as individual termination
    terminated = to_batches(dones.astype(np.float32)).unsqueeze(2).expand(
        B, T, self.n_agents)

    # Create mask for where index is < unpadded sequence length
    filled = (np.reshape(np.tile(np.arange(T), B), [B, T]) <
              np.expand_dims(seq_lens, 1)).astype(np.float32)
    mask = th.from_numpy(filled).unsqueeze(2).expand(B, T, self.n_agents)

    # Compute loss
    loss_out, mask, masked_td_error, chosen_action_qvals, targets = \
        self.loss(rewards, actions, terminated, mask, obs, next_obs,
                  action_mask, next_action_mask)

    # Optimise
    self.optimiser.zero_grad()
    loss_out.backward()
    grad_norm = th.nn.utils.clip_grad_norm_(
        self.params, self.config["grad_norm_clipping"])
    self.optimiser.step()

    mask_elems = mask.sum().item()
    stats = {
        "loss": loss_out.item(),
        "grad_norm": grad_norm
        if isinstance(grad_norm, float) else grad_norm.item(),
        "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
        "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
        mask_elems,
        "target_mean": (targets * mask).sum().item() / mask_elems,
    }
    return {LEARNER_STATS_KEY: stats}

def learn_on_batch(self, samples):
    obs_batch, action_mask, env_global_state = self._unpack_observation(
        samples[SampleBatch.CUR_OBS])
    (next_obs_batch, next_action_mask,
     next_env_global_state) = self._unpack_observation(
         samples[SampleBatch.NEXT_OBS])
    group_rewards = self._get_group_rewards(samples[SampleBatch.INFOS])

    input_list = [
        group_rewards, action_mask, next_action_mask,
        samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
        obs_batch, next_obs_batch
    ]
    if self.has_env_global_state:
        input_list.extend([env_global_state, next_env_global_state])

    output_list, _, seq_lens = \
        chop_into_sequences(
            samples[SampleBatch.EPS_ID],
            samples[SampleBatch.UNROLL_ID],
            samples[SampleBatch.AGENT_INDEX],
            input_list,
            [],  # RNN states not used here
            max_seq_len=self.config["model"]["max_seq_len"],
            dynamic_max=True)
    # These will be padded to shape [B * T, ...]
    if self.has_env_global_state:
        (rew, action_mask, next_action_mask, act, dones, obs, next_obs,
         env_global_state, next_env_global_state) = output_list
    else:
        (rew, action_mask, next_action_mask, act, dones, obs,
         next_obs) = output_list
    B, T = len(seq_lens), max(seq_lens)

    def to_batches(arr, dtype):
        new_shape = [B, T] + list(arr.shape[1:])
        return torch.as_tensor(
            np.reshape(arr, new_shape), dtype=dtype, device=self.device)

    rewards = to_batches(rew, torch.float)
    actions = to_batches(act, torch.long)
    obs = to_batches(obs, torch.float).reshape(
        [B, T, self.n_agents, self.obs_size])
    action_mask = to_batches(action_mask, torch.float)
    next_obs = to_batches(next_obs, torch.float).reshape(
        [B, T, self.n_agents, self.obs_size])
    next_action_mask = to_batches(next_action_mask, torch.float)
    if self.has_env_global_state:
        env_global_state = to_batches(env_global_state, torch.float)
        next_env_global_state = to_batches(next_env_global_state,
                                           torch.float)

    # TODO(ekl) this treats group termination as individual termination
    terminated = to_batches(dones, torch.float).unsqueeze(2).expand(
        B, T, self.n_agents)

    # Create mask for where index is < unpadded sequence length
    filled = np.reshape(
        np.tile(np.arange(T, dtype=np.float32), B),
        [B, T]) < np.expand_dims(seq_lens, 1)
    mask = torch.as_tensor(
        filled, dtype=torch.float, device=self.device).unsqueeze(2).expand(
            B, T, self.n_agents)

    # Compute loss
    loss_out, mask, masked_td_error, chosen_action_qvals, targets = (
        self.loss(rewards, actions, terminated, mask, obs, next_obs,
                  action_mask, next_action_mask, env_global_state,
                  next_env_global_state))

    # Optimise
    self.optimiser.zero_grad()
    loss_out.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.params, self.config["grad_norm_clipping"])
    self.optimiser.step()

    mask_elems = mask.sum().item()
    stats = {
        "loss": loss_out.item(),
        "grad_norm": grad_norm
        if isinstance(grad_norm, float) else grad_norm.item(),
        "td_error_abs": masked_td_error.abs().sum().item() / mask_elems,
        "q_taken_mean": (chosen_action_qvals * mask).sum().item() /
        mask_elems,
        "target_mean": (targets * mask).sum().item() / mask_elems,
    }
    return {LEARNER_STATS_KEY: stats}

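# Self-contained check (illustrative values, not from the source) of the
# validity-mask expression both learn_on_batch variants use: timesteps at or
# beyond a sequence's true length are padding and get masked out of the loss.
import numpy as np

seq_lens = np.array([3, 1])
B, T = len(seq_lens), seq_lens.max()
filled = np.reshape(np.tile(np.arange(T), B),
                    [B, T]) < np.expand_dims(seq_lens, 1)
# filled -> [[ True,  True,  True],
#            [ True, False, False]]
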
def _get_loss_inputs_dict(self, batch, neighbor_batch_dic, shuffle):
    """Return a feed dict from a batch.

    Arguments:
        batch (SampleBatch): batch of data to derive inputs from
        neighbor_batch_dic (dict of SampleBatch): batches of data for the
            neighbors of the main policy
        shuffle (bool): whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.

    Returns:
        feed dict of data
    """
    feed_dict = {}
    if self._batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % self._batch_divisibility_req
            == 0 and max(batch[SampleBatch.AGENT_INDEX]) == 0)  # not multiagent
    else:
        meets_divisibility_reqs = True

    # Collect the ids of all neighbor policies.
    neighbor_list = list(neighbor_batch_dic)
    tmp_dic = {}

    # Simple case: not RNN nor do we need to pad
    if not self._state_inputs and meets_divisibility_reqs:
        if shuffle:
            batch.shuffle()
        for k, ph in self._loss_inputs:
            # For attention: this is the replay buffer used for the
            # Q target; assemble the neighbor_obs information from the
            # SampleBatch.
            if 'neighbor_obs' in k:
                tmp_dic[ph] = {}
                feed_dict[ph] = []
                # neighbor_id = neighbor_list[int(k.split('_')[2])]
                # if neighbor_id is None:
                #     continue
                # else:
                for neighbor_id in neighbor_list:
                    tmp_dic[ph][neighbor_id] = neighbor_batch_dic[
                        neighbor_id]['obs']
                neighbor_id = neighbor_list[0]
                # [neighbor, batch, feature] -> [batch, neighbor, feature]
                for batch_item in range(len(tmp_dic[ph][neighbor_id])):
                    feed_dict[ph].append([])
                    for neighbor_id in neighbor_list:
                        feed_dict[ph][batch_item].append(
                            tmp_dic[ph][neighbor_id][batch_item])
                feed_dict[ph] = np.array(feed_dict[ph])
                # feed_dict[ph] = Lambda(
                #     lambda x: K.permute_dimensions(x, (0, 2, 1)))(
                #         tmp_dic[ph])
            else:
                feed_dict[ph] = batch[k]
        return feed_dict

    if self._state_inputs:
        max_seq_len = self._max_seq_len
        dynamic_max = True
    else:
        max_seq_len = self._batch_divisibility_req
        dynamic_max = False

    # RNN or multi-agent case
    feature_keys = [k for k, v in self._loss_inputs]
    state_keys = [
        "state_in_{}".format(i) for i in range(len(self._state_inputs))
    ]
    feature_sequences, initial_states, seq_lens = chop_into_sequences(
        batch[SampleBatch.EPS_ID],
        batch[SampleBatch.UNROLL_ID],
        batch[SampleBatch.AGENT_INDEX],
        [batch[k] for k in feature_keys],
        [batch[k] for k in state_keys],
        max_seq_len,
        dynamic_max=dynamic_max,
        shuffle=shuffle)
    for k, v in zip(feature_keys, feature_sequences):
        feed_dict[self._loss_input_dict[k]] = v
    for k, v in zip(state_keys, initial_states):
        feed_dict[self._loss_input_dict[k]] = v
    feed_dict[self._seq_lens] = seq_lens

    if log_once("rnn_feed_dict"):
        logger.info("Padded input for RNN:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
    return feed_dict

def build_q_losses_recurrent(policy, model, dist_class, samples):
    # Observations
    obs_batch, action_mask, _, _ = unpack_train_observations(
        policy, samples[SampleBatch.CUR_OBS], policy.device)
    next_obs_batch, next_action_mask, _, _ = unpack_train_observations(
        policy, samples[SampleBatch.NEXT_OBS], policy.device)
    rewards = samples[SampleBatch.REWARDS]

    # Obtain sequences
    input_list = [
        rewards, action_mask, next_action_mask,
        samples[SampleBatch.ACTIONS], samples[SampleBatch.DONES],
        obs_batch, next_obs_batch
    ]
    output_list, _, seq_lens = \
        chop_into_sequences(
            episode_ids=samples[SampleBatch.EPS_ID],
            unroll_ids=samples[SampleBatch.UNROLL_ID],
            agent_indices=samples[SampleBatch.AGENT_INDEX],
            feature_columns=input_list,
            state_columns=[],  # RNN states not used here
            max_seq_len=policy.config["model"]["max_seq_len"],
            dynamic_max=True)
    (rew, action_mask, next_action_mask, act, dones, obs,
     next_obs) = output_list
    B, T = len(seq_lens), max(seq_lens)

    def to_batches(arr, dtype):
        new_shape = [B, T] + list(arr.shape[1:])
        return torch.as_tensor(
            np.reshape(arr, new_shape), dtype=dtype, device=policy.device)

    rewards = to_batches(rew, torch.float)
    actions = to_batches(act, torch.long)
    obs = to_batches(obs, torch.float).reshape([B, T, -1])
    action_mask = to_batches(action_mask, torch.float)
    next_obs = to_batches(next_obs, torch.float).reshape([B, T, -1])
    next_action_mask = to_batches(next_action_mask, torch.float)
    terminated = to_batches(dones, torch.float).unsqueeze(2).expand(B, T, 1)

    # Create mask for where index is < unpadded sequence length
    filled = np.reshape(np.tile(np.arange(T, dtype=np.float32), B),
                        [B, T]) < np.expand_dims(seq_lens, 1)
    mask = torch.as_tensor(
        filled, dtype=torch.float,
        device=policy.device).unsqueeze(2).expand(B, T, 1)

    # Q-values of the chosen actions at time t.
    q_t = compute_sequence_q_values(
        policy, policy.q_model, obs, explore=False, is_training=True)
    chosen_action_qvals = torch.gather(
        q_t, dim=-1, index=actions.unsqueeze(-1))

    # Target max-Q at t+1, with unavailable actions masked out so they can
    # never be selected as the bootstrap target.
    q_tp1 = compute_sequence_q_values(
        policy, policy.target_q_model, next_obs, explore=False,
        is_training=True)
    ignore_action_tp1 = (next_action_mask == 0) & (mask == 1)
    q_tp1[ignore_action_tp1] = -np.inf
    target_max_qvals = q_tp1.max(dim=-1)[0]
    targets = rewards.squeeze(-1) + policy.config["gamma"] * (
        1 - terminated.squeeze(-1)) * target_max_qvals

    # TD error, averaged over the valid (unpadded) timesteps only.
    td_error = chosen_action_qvals.squeeze(-1) - targets.detach()
    mask = mask.squeeze(-1).expand_as(td_error)
    masked_td_error = td_error * mask
    loss = (masked_td_error**2).sum() / mask.sum()
    policy.td_error = td_error
    return loss

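# Sketch (illustrative values, not from the source) of why unavailable
# next-step actions are set to -inf before the max in
# build_q_losses_recurrent: a masked action cannot win the argmax.
import torch

q_tp1 = torch.tensor([[1.0, 5.0, 2.0]])
next_action_mask = torch.tensor([[1.0, 0.0, 1.0]])  # action 1 unavailable
q_tp1[next_action_mask == 0] = float("-inf")
assert q_tp1.max(dim=-1)[0].item() == 2.0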