def _latent_batch_internal_inference(self, model, actions, initial_hidden_state):
    # Replicate initial_hidden_state to go from B=1 to B=Na
    # For t = 1 to T:
    #   For each hidden state, get its pick locations pick_locs (K,2)
    #   Once we have pick locations for all hidden states, create the appropriate env actions (B,A=raw)
    #   Apply model.batch_internal_inference() with the proper actions
    #   Get the new hidden states and repeat
    num_actions = actions.shape[0]
    all_hidden_states = model.replicate_state(initial_hidden_state, num_actions)  # (B=Na)
    all_pick_locs = ptu.zeros(num_actions, self.time_horizon, self.action_type, 2)  # (B,T,K,2)
    all_env_actions = ptu.zeros(num_actions, self.time_horizon, 4)  # (B,T,4)
    schedule = np.array([1])  # We only want to roll out a single action per step

    for t in range(self.time_horizon):
        for i in range(num_actions):
            single_state = model.select_specific_state(all_hidden_states, [i])
            pick_locs = self._get_latent_locs(model, single_state)  # (K,2)
            if t == 0:
                # All rollouts share the same initial state, so broadcast across the batch
                all_pick_locs[:, t] = pick_locs
                break
            else:
                all_pick_locs[i, t] = pick_locs

        raw_actions = actions[:, t, :-2]  # (B,K); note these are one-hot vectors
        # Use the one-hot encoding as a mask: (B,K,1)*(B,K,2) -> (B,K,2) -> (B,2)
        selected_pick_locs = (raw_actions.unsqueeze(2) * all_pick_locs[:, t]).sum(1)
        env_actions = torch.cat([selected_pick_locs, actions[:, t, -2:]], dim=1)  # (B,2),(B,2) -> (B,4)
        all_env_actions[:, t] = env_actions  # (B,4)
        env_actions = env_actions.unsqueeze(1)  # (B,1,4)

        predicted_info = model.batch_internal_inference(
            obs=None,
            actions=env_actions,
            initial_hidden_state=all_hidden_states,
            schedule=schedule,
            figure_path=None)
        all_hidden_states = predicted_info["state"]

    return predicted_info, all_env_actions
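# Illustrative sketch, not part of the original class: a minimal, runnable
# check of the one-hot masking trick used above to pick one entity's location
# per rollout. The shapes (B=2 candidate actions, K=3 entities) and the
# function name are hypothetical.
def _demo_one_hot_pick_selection():
    import torch
    raw_actions = torch.tensor([[1., 0., 0.],
                                [0., 0., 1.]])  # (B,K) one-hot entity selectors
    pick_locs = torch.rand(2, 3, 2)  # (B,K,2) candidate pick locations per entity
    # (B,K,1) * (B,K,2) -> (B,K,2); summing over K keeps only the selected entity
    selected = (raw_actions.unsqueeze(2) * pick_locs).sum(1)  # (B,2)
    assert torch.allclose(selected[0], pick_locs[0, 0])
    assert torch.allclose(selected[1], pick_locs[1, 2])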
def get_all_activation_values(self, initial_hidden_state, actions, batch_size=10):
    B = actions.shape[0]
    num_batches = int(np.ceil(B / batch_size))
    state_action_attention = ptu.zeros((B, self.K), torch_device=actions.device)  # (B,K)
    interaction_attention = ptu.zeros((B, self.K, self.K - 1), torch_device=actions.device)  # (B,K,K-1)
    all_delta_vals = ptu.zeros((B, self.K), torch_device=actions.device)  # (B,K)
    all_lambdas_deltas = ptu.zeros((B, self.K), torch_device=actions.device)  # (B,K)

    for i in range(num_batches):
        start_index = i * batch_size
        end_index = min(start_index + batch_size, B)
        batch_initial_hidden_state = self.replicate_state(
            initial_hidden_state, end_index - start_index)  # (b,...)
        states = self.get_full_tensor_state(batch_initial_hidden_state)  # (b,K,R)
        states = self._flatten_first_two(states)  # (b*K,R)
        input_actions = actions[start_index:end_index].unsqueeze(1).repeat(1, self.K, 1)  # (b,K,A)
        input_actions = self._flatten_first_two(input_actions)  # (b*K,A)

        state_vals, inter_vals, delta_vals = self.dynamics_net.get_all_attention_values(
            states, input_actions, self.K)  # (b*K,1), (b*K,K-1,1), (b*K,1)
        state_action_attention[start_index:end_index] = self._unflatten_first(
            state_vals, self.K)[..., 0]  # (b,K)
        interaction_attention[start_index:end_index] = self._unflatten_first(
            inter_vals, self.K)[..., 0]  # (b,K,K-1)
        all_delta_vals[start_index:end_index] = self._unflatten_first(
            delta_vals, self.K)[..., 0]  # (b,K)

        deter_state, lambdas1, lambdas2 = self.dynamics_net(
            states, input_actions)  # (b*K,Rd), (b*K,Rs), (b*K,Rs)
        lambdas_deltas = self._flatten_first_two(
            batch_initial_hidden_state["post"]["lambdas1"])  # (b,K,Rs) -> (b*K,Rs)
        lambdas_deltas = torch.abs(lambdas_deltas - lambdas1).sum(1)  # (b*K,Rs) -> (b*K)
        if deter_state is not None:
            deter_state_deltas = torch.abs(
                states[:, :self.det_size] - deter_state).sum(1)  # (b*K,Rd) -> (b*K)
            lambdas_deltas += deter_state_deltas
        all_lambdas_deltas[start_index:end_index] = self._unflatten_first(
            lambdas_deltas, self.K)  # (b,K)

    return (state_action_attention.detach(), interaction_attention.detach(),
            all_delta_vals.detach(), all_lambdas_deltas.detach())
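# Illustrative sketch, not original code: the chunked-batch pattern that
# get_all_activation_values uses to bound peak memory, reduced to a generic
# helper. The helper name, `fn`, and the sizes are hypothetical.
def _demo_chunked_apply(x, fn, batch_size=10):
    import numpy as np
    import torch
    B = x.shape[0]
    chunks = []
    for i in range(int(np.ceil(B / batch_size))):
        start = i * batch_size
        end = min(start + batch_size, B)  # last chunk may be smaller than batch_size
        chunks.append(fn(x[start:end]))
    return torch.cat(chunks, dim=0)  # same result as fn(x) for a per-row fn

# e.g. _demo_chunked_apply(torch.rand(25, 3), lambda t: t * 2, batch_size=10)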
def get_activation_values(self, initial_hidden_state, actions, batch_size=10):
    B = actions.shape[0]
    num_batches = int(np.ceil(B / batch_size))
    act_values = ptu.zeros((B, self.K), torch_device=actions.device)  # (B,K)

    for i in range(num_batches):
        start_index = i * batch_size
        end_index = min(start_index + batch_size, B)
        batch_initial_hidden_state = self.replicate_state(
            initial_hidden_state, end_index - start_index)  # (b,...)
        states = self.get_full_tensor_state(batch_initial_hidden_state)  # (b,K,R)
        states = self._flatten_first_two(states)  # (b*K,R)
        # Tile actions per entity b-major so rows line up with the flattened
        # states (see the ordering demo below); a plain .repeat(self.K, 1)
        # would interleave them k-major and misalign states and actions
        input_actions = actions[start_index:end_index].unsqueeze(1).repeat(
            1, self.K, 1)  # (b,K,A)
        input_actions = self._flatten_first_two(input_actions)  # (b*K,A)
        vals = self.dynamics_net.get_state_action_attention_values(
            states, input_actions)  # (b*K,1)
        act_values[start_index:end_index] = self._unflatten_first(
            vals, self.K)[:, :, 0]  # (b,K)

    return act_values
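# Ordering demo referenced above (illustrative, not original code): reshaping
# (b,K,R) to (b*K,R) orders rows b-major, so the tiled actions must be b-major
# too. All sizes here are hypothetical.
def _demo_tiling_order():
    import torch
    b, K, A = 2, 3, 4
    actions = torch.arange(b * A, dtype=torch.float32).view(b, A)
    b_major = actions.unsqueeze(1).repeat(1, K, 1).view(b * K, A)  # rows: a0,a0,a0,a1,a1,a1
    k_major = actions.repeat(K, 1)                                 # rows: a0,a1,a0,a1,a0,a1
    assert torch.equal(b_major[:K], actions[0].expand(K, A))
    assert not torch.equal(b_major, k_major)  # the two tilings genuinely differ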
def get_object_subimage_recon(frames, actions, model, model_type, T, image_type="normal"):
    def get_image_mse(frames, preds):  # Both of size (bs, T, ch, D, D)
        return torch.pow(frames - preds, 2).mean((-1, -2, -3))  # (bs, T)

    seed_steps = 5
    if model_type == 'static':
        all_object_recons = []
        for i in range(T):
            schedule = np.zeros(5)  # Do 5 refine steps
            # Inputs for model.run_schedule(...):
            #   images: None or (B, T_obs, 3, D, D), actions: None or (B, T_acs, A),
            #   initial_hidden_state or None, schedule: (T1), loss_schedule: (T1)
            # Outputs: colors_list (B,T1,K,3,D,D), masks_list (B,T1,K,1,D,D),
            #   final_recon (B,3,D,D), total_loss, total_kle_loss, total_clog_prob,
            #   mse (all scalars), end_hidden_state
            colors_list, masks_list, final_recon, total_loss, total_kle_loss, total_clog_prob, mse, end_hidden_state = \
                model.run_schedule(frames[:, i:i+1], actions, initial_hidden_state=None,
                                   schedule=schedule, loss_schedule=schedule, should_detach=True)
            x_hats = colors_list
            masks = masks_list
            object_recons = x_hats * masks  # (bs, 5, K, 3, D, D)
            object_recons = object_recons[:, -1]  # (bs, K, 3, D, D)
            final_recons = object_recons.sum(1, keepdim=True)  # (bs, 1, 3, D, D)

            # If plotting masks
            if image_type == 'masks':
                object_recons = masks.repeat(1, 1, 1, 3, 1, 1)  # (bs, 5, K, 3, D, D)
                object_recons = object_recons[:, -1]  # (bs, K, 3, D, D)

            # If plotting subimages with a white background:
            # masks = masks.repeat(1, 1, 1, 3, 1, 1)  # (bs, 5, K, 3, D, D)
            # masks = masks[:, -1]  # (bs, K, 3, D, D)
            # object_recons = torch.where(masks < 0.01, ptu.ones_like(object_recons), object_recons)

            tmp = torch.cat([final_recons, object_recons], dim=1)  # (bs, K+1, 3, D, D)
            all_object_recons.append(tmp)

        all_object_recons = torch.stack(all_object_recons, dim=0)  # (T, bs, K+1, 3, D, D)
        all_object_recons = all_object_recons.permute(
            1, 0, 2, 3, 4, 5).contiguous()  # (bs, T, K+1, 3, D, D)
        mse = get_image_mse(frames[:, :T], all_object_recons[:, :, 0])  # (bs, T)
        return all_object_recons, mse
    elif model_type == 'rprp':
        # We store p(x_t | x_0:t, a_0:t-1)
        # T is the total number of frames, so we do T-1 physics steps
        num_refine_per_phys = 2  # refine steps per physics step
        num_refine_per_phys += 1  # +1 so the schedule stride also covers the physics step itself
        schedule = np.zeros(seed_steps + (T - 1) * num_refine_per_phys)  # len(schedule) = T2
        # 1 marks a physics step, 0 a refine step,
        # e.g. [0,0,0,0,1,0,1,0,1,0] with seed_steps=4 and a stride of 2
        schedule[seed_steps::num_refine_per_phys] = 1
        colors_list, masks_list, final_recon, total_loss, total_kle_loss, total_clog_prob, mse, end_hidden_state = \
            model.run_schedule(frames, actions, initial_hidden_state=None,
                               schedule=schedule, loss_schedule=schedule, should_detach=True)
        x_hats = colors_list
        masks = masks_list
        object_recons = x_hats * masks  # (bs, T2, K, 3, D, D)
        object_recons = object_recons[:, seed_steps - 1::num_refine_per_phys]  # (bs, T, K, 3, D, D)
        final_recons = object_recons.sum(2, keepdim=True)  # (bs, T, 1, 3, D, D)

        # If plotting masks
        if image_type == 'masks':
            object_recons = masks.repeat(1, 1, 1, 3, 1, 1)  # (bs, T2, K, 3, D, D)
            object_recons = object_recons[:, seed_steps - 1::num_refine_per_phys]  # (bs, T, K, 3, D, D)

        # If plotting subimages with a white background:
        # masks = masks.repeat(1, 1, 1, 3, 1, 1)  # (bs, T2, K, 3, D, D)
        # masks = masks[:, seed_steps - 1::num_refine_per_phys]  # (bs, T, K, 3, D, D)
        # object_recons = torch.where(masks < 0.01, ptu.ones_like(object_recons), object_recons)

        all_object_recons = torch.cat([final_recons, object_recons], dim=2)  # (bs, T, K+1, 3, D, D)
        mse = get_image_mse(frames[:, :T], all_object_recons[:, :, 0])  # (bs, T)
        return all_object_recons, mse
    elif model_type == 'rprp_pred':
        # We store p(x_t | x_0:t-1, a_0:t-1). Conditioning ends at x_t-1, so these are predictions
        # T is the total number of frames, so we do T-1 physics steps
        num_refine_per_phys = 2  # refine steps per physics step
        num_refine_per_phys += 1  # +1 so the schedule stride also covers the physics step itself
        schedule = np.zeros(seed_steps + (T - 1) * num_refine_per_phys)  # len(schedule) = T2
        schedule[seed_steps::num_refine_per_phys] = 1  # 1 marks a physics step, 0 a refine step
        colors_list, masks_list, final_recon, total_loss, total_kle_loss, total_clog_prob, mse, end_hidden_state = \
            model.run_schedule(frames, actions, initial_hidden_state=None,
                               schedule=schedule, loss_schedule=schedule, should_detach=True)
        x_hats = colors_list
        masks = masks_list
        object_recons = x_hats * masks  # (bs, T2, K, 3, D, D)
        object_recons = object_recons[:, seed_steps::num_refine_per_phys]  # (bs, T-1, K, 3, D, D)
        final_recons = object_recons.sum(2, keepdim=True)  # (bs, T-1, 1, 3, D, D)

        # If plotting masks
        if image_type == 'masks':
            object_recons = masks.repeat(1, 1, 1, 3, 1, 1)  # (bs, T2, K, 3, D, D)
            object_recons = object_recons[:, seed_steps::num_refine_per_phys]  # (bs, T-1, K, 3, D, D)

        all_object_recons = torch.cat([final_recons, object_recons], dim=2)  # (bs, T-1, K+1, 3, D, D)
        # Pad t=0, for which there is no prediction
        padding = ptu.zeros([all_object_recons.shape[0], 1,
                             *list(all_object_recons.shape[2:])])  # (bs, 1, K+1, 3, D, D)
        all_object_recons = torch.cat([padding, all_object_recons], dim=1)  # (bs, T, K+1, 3, D, D)
        mse = get_image_mse(frames[:, :T], all_object_recons[:, :, 0])  # (bs, T)
        return all_object_recons, mse
    elif model_type == 'next_step':
        schedule = np.ones(T - 1) * 2
        colors_list, masks_list, final_recon, total_loss, total_kle_loss, total_clog_prob, mse, end_hidden_state = \
            model.run_schedule(frames, actions, initial_hidden_state=None,
                               schedule=schedule, loss_schedule=schedule, should_detach=True)
        x_hats = colors_list
        masks = masks_list
        object_recons = x_hats * masks  # (bs, T-1, K, 3, D, D)
        final_recons = object_recons.sum(2, keepdim=True)  # (bs, T-1, 1, 3, D, D)

        # If plotting masks
        if image_type == 'masks':
            object_recons = masks.repeat(1, 1, 1, 3, 1, 1)  # (bs, T-1, K, 3, D, D)

        all_object_recons = torch.cat([final_recons, object_recons], dim=2)  # (bs, T-1, K+1, 3, D, D)
        # Pad t=0, for which there is no prediction
        padding = ptu.zeros([all_object_recons.shape[0], 1,
                             *list(all_object_recons.shape[2:])])  # (bs, 1, K+1, 3, D, D)
        all_object_recons = torch.cat([padding, all_object_recons], dim=1)  # (bs, T, K+1, 3, D, D)
        mse = get_image_mse(frames[:, :T], all_object_recons[:, :, 0])  # (bs, T)
        return all_object_recons, mse
    else:
        raise ValueError("Invalid model_type: {}".format(model_type))
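# Illustrative sketch, not original code: the schedule layout built in the
# 'rprp'/'rprp_pred' branches above, with hypothetical sizes. 1 marks a physics
# step and 0 a refine step: seed_steps refines, then T-1 blocks of
# [physics, refine, refine] when num_refine_per_phys ends up as 3.
def _demo_rprp_schedule():
    import numpy as np
    seed_steps, T, num_refine_per_phys = 5, 4, 3
    schedule = np.zeros(seed_steps + (T - 1) * num_refine_per_phys)
    schedule[seed_steps::num_refine_per_phys] = 1
    assert list(schedule) == [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0]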
def forward(self, sampled_state, actions):
    K = self.K
    bs = sampled_state.shape[0] // K
    state_enc_flat = self.inertia_encoder(sampled_state)  # Encode sampled state

    if actions is not None:
        if self.action_size == 4 and actions.shape[-1] == 6:
            # Index out the 4 action dims the encoder expects, dropping
            # dims 2 and 5 (presumably the unused z components)
            action_enc = self.action_encoder(actions[:, torch.LongTensor([0, 1, 3, 4])])
        else:
            action_enc = self.action_encoder(actions)  # Encode actions
        state_enc_actions = torch.cat([state_enc_flat, action_enc], -1)
        state_action_effect = self.action_effect_network(state_enc_actions)  # (bs*K,h)
        state_action_attention = self.action_attention_network(state_enc_actions)  # (bs*K,1)
        state_enc = (state_action_effect * state_action_attention).view(
            bs, K, self.full_rep_size)  # (bs,K,h)
    else:
        state_enc = state_enc_flat.view(bs, K, self.full_rep_size)  # (bs,K,h)

    if K != 1:
        # Create the list of all ordered pairs (i, j), i != j
        pairs = []
        for i in range(K):
            for j in range(K):
                if i == j:
                    continue
                pairs.append(torch.cat([state_enc[:, i], state_enc[:, j]], -1))
        all_pairs = torch.stack(pairs, 1).view(bs * K, K - 1, -1)  # (bs*K,K-1,2h)

        pairwise_interaction = self.pairwise_encoder_network(all_pairs)  # (bs*K,K-1,h)
        effect = self.interaction_effect_network(pairwise_interaction)  # (bs*K,K-1,h)
        attention = self.interaction_attention_network(pairwise_interaction)  # (bs*K,K-1,1)
        total_effect = (effect * attention).sum(1)  # (bs*K,h)
    else:
        total_effect = ptu.zeros((bs, self.effect_size)).to(sampled_state.device)

    state_and_effect = torch.cat(
        [state_enc.view(bs * K, self.full_rep_size), total_effect], -1)  # (bs*K,h)
    aggregate_state = self.final_merge_network(state_and_effect)

    if self.det_size == 0:
        deter_state = None
    else:
        deter_state = self.det_output(aggregate_state)
    lambdas1 = self.lambdas1_output(aggregate_state)
    lambdas2 = self.lambdas2_output(aggregate_state)
    return deter_state, lambdas1, lambdas2
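# Illustrative sketch, not original code: the ordered-pair layout built in
# forward(). For K=3 the loop emits (0,1),(0,2),(1,0),(1,2),(2,0),(2,1), and
# viewing as (bs*K, K-1, -1) groups the K-1 partners of each focus object i,
# so the attention-weighted sum runs over j != i. Sizes are hypothetical.
def _demo_pair_layout():
    import torch
    bs, K, h = 2, 3, 5
    state_enc = torch.rand(bs, K, h)
    pairs = [torch.cat([state_enc[:, i], state_enc[:, j]], -1)
             for i in range(K) for j in range(K) if i != j]
    all_pairs = torch.stack(pairs, 1).view(bs * K, K - 1, -1)  # (bs*K, K-1, 2h)
    # First row holds the (focus=0, partner=1) pair of batch element 0
    assert torch.allclose(all_pairs[0, 0],
                          torch.cat([state_enc[0, 0], state_enc[0, 1]]))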
def initialize_hidden(self, bs):
    # (h_0, c_0) for a single-layer LSTM, each of shape (1, bs, lstm_size)
    return (ptu.zeros((1, bs, self.lstm_size)),
            ptu.zeros((1, bs, self.lstm_size)))
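# Illustrative sketch, not original code: the (h_0, c_0) pair returned above
# has the shapes torch.nn.LSTM expects for a single layer. lstm_size, the
# input size, and the demo name are hypothetical.
def _demo_initialize_hidden_usage():
    import torch
    import torch.nn as nn
    lstm_size, bs, T = 8, 4, 6
    lstm = nn.LSTM(input_size=3, hidden_size=lstm_size)  # seq-first LSTM
    h0 = torch.zeros(1, bs, lstm_size)  # (num_layers, bs, lstm_size)
    c0 = torch.zeros(1, bs, lstm_size)
    inputs = torch.rand(T, bs, 3)  # (T, bs, input_size)
    outputs, (hT, cT) = lstm(inputs, (h0, c0))
    assert outputs.shape == (T, bs, lstm_size)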