def get_all_activation_values(self, initial_hidden_state, actions, batch_size=10):
    """Compute per-object state-action attention, pairwise interaction attention,
    delta values, and posterior lambda deltas for a set of actions, in minibatches."""
    B = actions.shape[0]
    num_batches = int(np.ceil(B / batch_size))
    state_action_attention = ptu.zeros((B, self.K), torch_device=actions.device)  # (B,K)
    interaction_attention = ptu.zeros((B, self.K, self.K - 1), torch_device=actions.device)  # (B,K,K-1)
    all_delta_vals = ptu.zeros((B, self.K), torch_device=actions.device)  # (B,K)
    all_lambdas_deltas = ptu.zeros((B, self.K), torch_device=actions.device)  # (B,K)

    for i in range(num_batches):
        start_index = i * batch_size
        end_index = min(start_index + batch_size, B)

        batch_initial_hidden_state = self.replicate_state(initial_hidden_state, end_index - start_index)  # (b,...)
        states = self.get_full_tensor_state(batch_initial_hidden_state)  # (b,K,R)
        states = self._flatten_first_two(states)  # (b*K,R)
        input_actions = actions[start_index:end_index].unsqueeze(1).repeat(1, self.K, 1)  # (b,K,A)
        input_actions = self._flatten_first_two(input_actions)  # (b*K,A)

        state_vals, inter_vals, delta_vals = self.dynamics_net.get_all_attention_values(
            states, input_actions, self.K)  # (b*K,1), (b*K,K-1,1), (b*K,1)
        state_action_attention[start_index:end_index] = self._unflatten_first(state_vals, self.K)[..., 0]  # (b,K)
        interaction_attention[start_index:end_index] = self._unflatten_first(inter_vals, self.K)[..., 0]  # (b,K,K-1)
        all_delta_vals[start_index:end_index] = self._unflatten_first(delta_vals, self.K)[..., 0]  # (b,K)

        deter_state, lambdas1, lambdas2 = self.dynamics_net(states, input_actions)  # (b*K,Rd), (b*K,Rs), (b*K,Rs)
        lambdas_deltas = self._flatten_first_two(batch_initial_hidden_state["post"]["lambdas1"])  # (b,K,Rs)->(b*K,Rs)
        lambdas_deltas = torch.abs(lambdas_deltas - lambdas1).sum(1)  # (b*K,Rs)->(b*K)
        if deter_state is not None:
            deter_state_deltas = torch.abs(states[:, :self.det_size] - deter_state).sum(1)  # (b*K,Rd)->(b*K)
            lambdas_deltas += deter_state_deltas
        all_lambdas_deltas[start_index:end_index] = self._unflatten_first(lambdas_deltas, self.K)  # (b,K)

    return (state_action_attention.detach(), interaction_attention.detach(),
            all_delta_vals.detach(), all_lambdas_deltas.detach())
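# Usage sketch for get_all_activation_values (illustrative only: `model`,
# `initial_hidden_state`, B, and A stand in for objects defined elsewhere):
#
#   actions = torch.randn(B, A)  # one candidate action per row
#   sa_att, inter_att, deltas, lam_deltas = model.get_all_activation_values(
#       initial_hidden_state, actions, batch_size=10)
#   # sa_att: (B,K)  inter_att: (B,K,K-1)  deltas: (B,K)  lam_deltas: (B,K)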
def forward(self, sampled_state, actions):
    """One-step dynamics: encode per-object states (plus actions, if given),
    aggregate attention-weighted pairwise interaction effects, and output the
    deterministic state and the posterior parameters (lambdas1, lambdas2)."""
    K = self.K
    bs = sampled_state.shape[0] // K

    state_enc_flat = self.inertia_encoder(sampled_state)  # Encode sampled state

    if actions is not None:
        if self.action_size == 4 and actions.shape[-1] == 6:
            # A 6-D action fed to a 4-D action encoder: keep dims 0, 1, 3, 4 only
            action_enc = self.action_encoder(actions[:, torch.LongTensor([0, 1, 3, 4])])
        else:
            action_enc = self.action_encoder(actions)  # Encode actions
        state_enc_actions = torch.cat([state_enc_flat, action_enc], -1)
        state_action_effect = self.action_effect_network(state_enc_actions)  # (bs*K,h)
        state_action_attention = self.action_attention_network(state_enc_actions)  # (bs*K,1)
        state_enc = (state_action_effect * state_action_attention).view(bs, K, self.full_rep_size)  # (bs,K,h)
    else:
        state_enc = state_enc_flat.view(bs, K, self.full_rep_size)  # (bs,K,h)

    if K != 1:
        # All ordered pairs (i, j), i != j, of per-object encodings
        pairs = []
        for i in range(K):
            for j in range(K):
                if i == j:
                    continue
                pairs.append(torch.cat([state_enc[:, i], state_enc[:, j]], -1))
        all_pairs = torch.stack(pairs, 1).view(bs * K, K - 1, -1)

        pairwise_interaction = self.pairwise_encoder_network(all_pairs)  # (bs*K,K-1,h)
        effect = self.interaction_effect_network(pairwise_interaction)  # (bs*K,K-1,h)
        attention = self.interaction_attention_network(pairwise_interaction)  # (bs*K,K-1,1)
        total_effect = (effect * attention).sum(1)  # (bs*K,h)
    else:
        total_effect = ptu.zeros((bs, self.effect_size)).to(sampled_state.device)  # bs == bs*K when K == 1

    state_and_effect = torch.cat([state_enc.view(bs * K, self.full_rep_size), total_effect], -1)  # (bs*K,h)
    aggregate_state = self.final_merge_network(state_and_effect)

    deter_state = None if self.det_size == 0 else self.det_output(aggregate_state)
    lambdas1 = self.lambdas1_output(aggregate_state)
    lambdas2 = self.lambdas2_output(aggregate_state)
    return deter_state, lambdas1, lambdas2
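# Shape sketch for forward (illustrative; bs, K, R, and A stand in for values
# fixed by the surrounding model, and the inputs are random placeholders):
#
#   sampled_state = torch.randn(bs * K, R)  # K object latents per sample, flattened
#   actions = torch.randn(bs * K, A)        # the action, repeated per object slot
#   deter_state, lambdas1, lambdas2 = self.forward(sampled_state, actions)
#   # deter_state: (bs*K, det_size) or None; lambdas1, lambdas2: posterior params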
def get_activation_values(self, initial_hidden_state, actions, batch_size=10):
    """Compute the state-action attention value of every object slot for each
    action, in minibatches."""
    B = actions.shape[0]
    num_batches = int(np.ceil(B / batch_size))
    act_values = ptu.zeros((B, self.K), torch_device=actions.device)  # (B,K)

    for i in range(num_batches):
        start_index = i * batch_size
        end_index = min(start_index + batch_size, B)

        batch_initial_hidden_state = self.replicate_state(initial_hidden_state, end_index - start_index)  # (b,...)
        states = self.get_full_tensor_state(batch_initial_hidden_state)  # (b,K,R)
        states = self._flatten_first_two(states)  # (b*K,R)
        # Repeat each action K times consecutively so rows line up with the
        # sample-major flattening of states (cf. get_all_activation_values)
        input_actions = actions[start_index:end_index].unsqueeze(1).repeat(1, self.K, 1)  # (b,K,A)
        input_actions = self._flatten_first_two(input_actions)  # (b*K,A)

        vals = self.dynamics_net.get_state_action_attention_values(states, input_actions)  # (b*K,1)
        act_values[start_index:end_index] = self._unflatten_first(vals, self.K)[:, :, 0]  # (b,K)
    return act_values
def initialize_hidden(self, bs):
    # Zero-initialized (h0, c0) for the LSTM, each of shape (1, bs, lstm_size)
    return ptu.zeros((1, bs, self.lstm_size)), ptu.zeros((1, bs, self.lstm_size))
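# Usage sketch (assumes this class owns an nn.LSTM, here hypothetically named
# `self.lstm`, consuming sequence input x of shape (T, bs, input_size)):
#
#   hidden = self.initialize_hidden(bs)  # (h0, c0), each (1, bs, lstm_size)
#   out, hidden = self.lstm(x, hidden)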