Ejemplo n.º 1
0
    def _latent_batch_internal_inference(self, model, actions, initial_hidden_state):
        """Roll a batch of candidate action sequences forward, one step at a time.

        The initial hidden state is replicated once per candidate. At every
        timestep a pick location is looked up for each candidate's current
        hidden state, combined with the candidate's raw action to form an
        environment action, and fed to ``model.batch_internal_inference`` to
        obtain the next hidden states.

        Args:
            model: Object providing ``replicate_state``,
                ``select_specific_state`` and ``batch_internal_inference``.
            actions: Tensor indexed as ``actions[:, t, :-2]`` (per the original
                notes, a (B,K) one-hot selector) and ``actions[:, t, -2:]``
                (the two trailing action components).
            initial_hidden_state: Hidden state to replicate for every candidate.

        Returns:
            Tuple ``(predicted_info, all_env_actions)``: the output of the last
            ``batch_internal_inference`` call and the (B,T,4) tensor of
            environment actions that were applied.
        """
        batch = actions.shape[0]
        hidden = model.replicate_state(initial_hidden_state, batch)  # (B=Na)
        pick_table = ptu.zeros(batch, self.time_horizon, self.action_type, 2)  # (B,T,K,2)
        applied_actions = ptu.zeros(batch, self.time_horizon, 4)  # (B,T,4)
        one_step = np.array([1])  # Each model call advances by a single action

        for step in range(self.time_horizon):
            if step == 0:
                # Every candidate starts from the same replicated state, so the
                # pick locations are computed once and broadcast over the batch.
                if batch > 0:
                    first_state = model.select_specific_state(hidden, [0])
                    pick_table[:, step] = self._get_latent_locs(model, first_state)  # (K,2)
            else:
                # Hidden states have diverged: query each candidate separately.
                for idx in range(batch):
                    own_state = model.select_specific_state(hidden, [idx])
                    pick_table[idx, step] = self._get_latent_locs(model, own_state)  # (K,2)

            selector = actions[:, step, :-2]  # (B,K)
            # The selector acts as a mask over the K candidate pick locations.
            masked_locs = selector.unsqueeze(2) * pick_table[:, step]  # (B,K,2)
            chosen_locs = masked_locs.sum(1)  # (B,2)
            trailing = actions[:, step, -2:]  # (B,2)
            step_actions = torch.cat([chosen_locs, trailing], dim=1)  # (B,4)
            applied_actions[:, step] = step_actions

            predicted_info = model.batch_internal_inference(
                obs=None,
                actions=step_actions.unsqueeze(1),  # (B,1,4)
                initial_hidden_state=hidden,
                schedule=one_step,
                figure_path=None)
            hidden = predicted_info["state"]

        return predicted_info, applied_actions
Ejemplo n.º 2
0
    def get_all_activation_values(self,
                                  initial_hidden_state,
                                  actions,
                                  batch_size=10):
        """Collect attention and state-change statistics for every action.

        The actions are processed in chunks of at most ``batch_size``. For each
        chunk the initial hidden state is replicated, flattened over the
        (batch, slot) axes and paired with the per-slot tiled actions before
        being passed to the dynamics network.

        Args:
            initial_hidden_state: Hidden state replicated for every action.
            actions: Tensor whose first axis indexes the B candidate actions
                (per the original notes, shape (B,A)).
            batch_size: Maximum number of actions evaluated per chunk.

        Returns:
            Four detached tensors:
            state-action attention (B,K), interaction attention (B,K,K-1),
            delta values (B,K), and the summed absolute change between the
            stored ``post/lambdas1`` and the predicted ``lambdas1`` (plus the
            deterministic-state change when the network returns one) (B,K).
        """
        total = actions.shape[0]
        n_slots = self.K
        device = actions.device
        num_chunks = int(np.ceil(total / batch_size))

        # Output buffers, filled chunk by chunk.
        sa_attention = ptu.zeros((total, n_slots), torch_device=device)  # (B,K)
        pair_attention = ptu.zeros((total, n_slots, n_slots - 1),
                                   torch_device=device)  # (B,K,K-1)
        delta_out = ptu.zeros((total, n_slots), torch_device=device)  # (B,K)
        change_out = ptu.zeros((total, n_slots), torch_device=device)  # (B,K)

        for chunk in range(num_chunks):
            lo = chunk * batch_size
            hi = min(lo + batch_size, total)
            rows = slice(lo, hi)

            chunk_state = self.replicate_state(initial_hidden_state,
                                               hi - lo)  # (b,...)
            flat_states = self._flatten_first_two(
                self.get_full_tensor_state(chunk_state))  # (b,K,R)->(b*K,R)
            # Tile each action across the K slots, then flatten with the same
            # helper used for the states so rows stay aligned.
            tiled_actions = actions[rows].unsqueeze(1).repeat(
                1, n_slots, 1)  # (b,K,A)
            flat_actions = self._flatten_first_two(tiled_actions)

            sa_vals, pair_vals, d_vals = self.dynamics_net.get_all_attention_values(
                flat_states, flat_actions,
                n_slots)  # (b*k,1), (b*k,k-1,1), (b*k,1)
            sa_attention[rows] = self._unflatten_first(
                sa_vals, n_slots)[..., 0]  # (b,k)
            pair_attention[rows] = self._unflatten_first(
                pair_vals, n_slots)[..., 0]  # (b,k,k-1)
            delta_out[rows] = self._unflatten_first(
                d_vals, n_slots)[..., 0]  # (b,k)

            deter_state, lambdas1, _ = self.dynamics_net(
                flat_states, flat_actions)  # (b*k,Rd), (b*k,Rs), (b*k,Rs)
            previous = self._flatten_first_two(
                chunk_state["post"]["lambdas1"])  # (b,k,Rs)->(b*k,Rs)
            change = torch.abs(previous - lambdas1).sum(1)  # (b*k)
            if deter_state is not None:
                # Add the change of the deterministic part of the state.
                change += torch.abs(flat_states[:, :self.det_size] -
                                    deter_state).sum(1)  # (b*k)
            change_out[rows] = self._unflatten_first(change,
                                                     n_slots)  # (b,k)

        return (sa_attention.detach(), pair_attention.detach(),
                delta_out.detach(), change_out.detach())
Ejemplo n.º 3
0
    def get_activation_values(self,
                              initial_hidden_state,
                              actions,
                              batch_size=10):
        """Return the state-action attention value of every slot for each action.

        The actions are processed in chunks of at most ``batch_size``. For each
        chunk the initial hidden state is replicated, flattened over the
        (batch, slot) axes and paired with the matching action before being
        passed to ``dynamics_net.get_state_action_attention_values``.

        Args:
            initial_hidden_state: Hidden state replicated for every action.
            actions: 2-D tensor whose first axis indexes the B candidate
                actions (per the original notes, shape (B,A)).
            batch_size: Maximum number of actions evaluated per chunk.

        Returns:
            Tensor of shape (B,K) on ``actions.device`` holding the attention
            value of each slot for each action.
        """
        B = actions.shape[0]
        num_batches = int(np.ceil(B / batch_size))

        act_values = ptu.zeros((B, self.K),
                               torch_device=actions.device)  # (B,K)

        for i in range(num_batches):
            start_index = i * batch_size
            end_index = min(start_index + batch_size, B)

            batch_initial_hidden_state = self.replicate_state(
                initial_hidden_state, end_index - start_index)  # (b,...)
            states = self.get_full_tensor_state(
                batch_initial_hidden_state)  # (b,K,R)
            states = self._flatten_first_two(states)  # (b*K,R)
            # Tile each action across the K slots and flatten with the SAME
            # helper used for the states. A plain ``.repeat(self.K, 1)`` tiles
            # the whole chunk K times ([a0..ab-1, a0..ab-1, ...]) and therefore
            # pairs state rows with the wrong action whenever b > 1 and K > 1.
            input_actions = actions[start_index:end_index].unsqueeze(1).repeat(
                1, self.K, 1)  # (b,K,A)
            input_actions = self._flatten_first_two(input_actions)  # (b*K,A)
            vals = self.dynamics_net.get_state_action_attention_values(
                states, input_actions)  # (b*k, 1)
            act_values[start_index:end_index] = self._unflatten_first(
                vals, self.K)[:, :, 0]  # (b,k)

        return act_values
Ejemplo n.º 4
0
def get_object_subimage_recon(frames,
                              actions,
                              model,
                              model_type,
                              T,
                              image_type="normal"):
    """Build per-object reconstruction images and the per-frame MSE.

    Runs ``model.run_schedule`` with a schedule that depends on ``model_type``
    and stacks, for every frame, the full reconstruction followed by the K
    per-object images.

    Args:
        frames: Image tensor, sliced as ``frames[:, i:i+1]`` and
            ``frames[:, :T]`` (per the original notes, (bs, T, 3, D, D)).
        actions: Forwarded unchanged to ``model.run_schedule``.
        model: Object exposing ``run_schedule(...)`` whose first two return
            values are the colors and masks lists (per the original notes,
            (B,T1,K,3,D,D) and (B,T1,K,1,D,D)).
        model_type: One of ``'static'``, ``'rprp'``, ``'rprp_pred'`` or
            ``'next_step'``.
        T: Total number of frames to produce.
        image_type: When ``'masks'``, the per-object images are the masks
            repeated over 3 channels instead of ``colors * masks``. The full
            reconstruction in slot 0 is always ``sum(colors * masks)``.

    Returns:
        Tuple ``(all_object_recons, mse)`` where ``all_object_recons`` is
        (bs, T, K+1, 3, D, D) with the full reconstruction at index 0 of the
        K+1 axis, and ``mse`` is the (bs, T) mean squared error between
        ``frames[:, :T]`` and that full reconstruction. For ``'rprp_pred'`` and
        ``'next_step'`` the first time step is an all-zero padding frame.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    def get_image_mse(frames, preds):  # Both of size (bs, T, ch, D, D)
        return torch.pow(frames - preds, 2).mean((-1, -2, -3))  # (bs, T)

    seed_steps = 5
    if model_type == 'static':
        all_object_recons = []
        for i in range(T):
            schedule = np.zeros(5)  # Do 5 refine steps
            # Each frame is processed on its own, starting from no hidden state.
            x_hats, masks, *_ = model.run_schedule(
                frames[:, i:i + 1], actions, initial_hidden_state=None,
                schedule=schedule, loss_schedule=schedule, should_detach=True)

            object_recons = x_hats * masks  # (bs, 5, K, 3, D, D)
            object_recons = object_recons[:, -1]  # (bs, K, 3, D, D)
            final_recons = object_recons.sum(1,
                                             keepdim=True)  # (bs, 1, 3, D, D)

            # If plotting masks
            if image_type == 'masks':
                object_recons = masks.repeat(1, 1, 1, 3, 1,
                                             1)  # (bs, 5, K, 3, D, D)
                object_recons = object_recons[:, -1]  # (bs, K, 3, D, D)

            all_object_recons.append(
                torch.cat([final_recons, object_recons],
                          dim=1))  # (bs, K+1, 3, D, D)
        all_object_recons = torch.stack(all_object_recons,
                                        dim=0)  # (T, bs, K+1, 3, D, D)
        all_object_recons = all_object_recons.permute(
            1, 0, 2, 3, 4, 5).contiguous()  # (bs, T, K+1, 3, D, D)
        mse = get_image_mse(frames[:, :T],
                            all_object_recons[:, :, 0])  # (bs, T)
        return all_object_recons, mse
    elif model_type == 'rprp':  # We store p(x_t|x0:t, a0:t-1)
        # T is the total number of frames, so we do T-1 physics steps
        num_refine_per_phys = 2 + 1
        schedule = np.zeros(seed_steps + (T - 1) *
                            num_refine_per_phys)  # len(schedule) = T2
        # Entries equal to 1 start at index seed_steps and recur every
        # num_refine_per_phys steps; all other entries stay 0.
        schedule[seed_steps::num_refine_per_phys] = 1
        x_hats, masks, *_ = model.run_schedule(
            frames, actions, initial_hidden_state=None, schedule=schedule,
            loss_schedule=schedule, should_detach=True)

        object_recons = x_hats * masks  # (bs, T2, K, 3, D, D)
        # Take the step just before each 1-entry of the schedule.
        object_recons = object_recons[:, seed_steps - 1::
                                      num_refine_per_phys]  # (bs, T, K, 3, D, D)
        final_recons = object_recons.sum(2,
                                         keepdim=True)  # (bs, T, 1, 3, D, D)

        # If plotting masks
        if image_type == 'masks':
            object_recons = masks.repeat(1, 1, 1, 3, 1,
                                         1)  # (bs, T2, K, 3, D, D)
            object_recons = object_recons[:, seed_steps - 1::
                                          num_refine_per_phys]  # (bs, T, K, 3, D, D)

        all_object_recons = torch.cat([final_recons, object_recons],
                                      dim=2)  # (bs, T, K+1, 3, D, D)
        mse = get_image_mse(frames[:, :T],
                            all_object_recons[:, :, 0])  # (bs, T)
        return all_object_recons, mse
    elif model_type == 'rprp_pred':  # We store p(x_t|x0:t-1, a0:t-1). Note we end at x_t-1, so we are predicting here
        # T is the total number of frames, so we do T-1 physics steps
        num_refine_per_phys = 2 + 1
        schedule = np.zeros(
            seed_steps + (T - 1) * num_refine_per_phys)  # len(schedule) = T2
        # Entries equal to 1 start at index seed_steps and recur every
        # num_refine_per_phys steps; all other entries stay 0.
        schedule[seed_steps::num_refine_per_phys] = 1
        x_hats, masks, *_ = model.run_schedule(
            frames, actions, initial_hidden_state=None, schedule=schedule,
            loss_schedule=schedule, should_detach=True)

        object_recons = x_hats * masks  # (bs, T2, K, 3, D, D)
        # Take the outputs at the 1-entries of the schedule themselves.
        object_recons = object_recons[:, seed_steps::
                                      num_refine_per_phys]  # (bs, T-1, K, 3, D, D)
        final_recons = object_recons.sum(2,
                                         keepdim=True)  # (bs, T-1, 1, 3, D, D)

        # If plotting masks
        if image_type == 'masks':
            object_recons = masks.repeat(1, 1, 1, 3, 1,
                                         1)  # (bs, T2, K, 3, D, D)
            object_recons = object_recons[:, seed_steps::
                                          num_refine_per_phys]  # (bs, T-1, K, 3, D, D)

        all_object_recons = torch.cat([final_recons, object_recons],
                                      dim=2)  # (bs, T-1, K+1, 3, D, D)
        # There is no prediction for the first frame: pad it with zeros.
        padding = ptu.zeros([
            all_object_recons.shape[0], 1, *list(all_object_recons.shape[2:])
        ])  # (bs, 1, K+1, 3, D, D)
        all_object_recons = torch.cat([padding, all_object_recons],
                                      dim=1)  # (bs, T, K+1, 3, D, D)
        mse = get_image_mse(frames[:, :T],
                            all_object_recons[:, :, 0])  # (bs, T)
        return all_object_recons, mse
    elif model_type == 'next_step':
        schedule = np.ones(T - 1) * 2
        x_hats, masks, *_ = model.run_schedule(
            frames, actions, initial_hidden_state=None, schedule=schedule,
            loss_schedule=schedule, should_detach=True)

        object_recons = x_hats * masks  # (bs, T-1, K, 3, D, D)
        final_recons = object_recons.sum(2,
                                         keepdim=True)  # (bs, T-1, 1, 3, D, D)

        # If plotting masks
        if image_type == 'masks':
            object_recons = masks.repeat(1, 1, 1, 3, 1,
                                         1)  # (bs, T-1, K, 3, D, D)

        all_object_recons = torch.cat([final_recons, object_recons],
                                      dim=2)  # (bs, T-1, K+1, 3, D, D)
        # There is no prediction for the first frame: pad it with zeros.
        padding = ptu.zeros([
            all_object_recons.shape[0], 1, *list(all_object_recons.shape[2:])
        ])  # (bs, 1, K+1, 3, D, D)
        all_object_recons = torch.cat([padding, all_object_recons],
                                      dim=1)  # (bs, T, K+1, 3, D, D)
        mse = get_image_mse(frames[:, :T],
                            all_object_recons[:, :, 0])  # (bs, T)
        return all_object_recons, mse
    else:
        # Raise (rather than return) so an unsupported type fails loudly here
        # instead of surfacing later as an unpacking error in the caller.
        raise ValueError("Invalid model_type: {}".format(model_type))
Ejemplo n.º 5
0
    def forward(self, sampled_state, actions):
        """Predict the next-state parameters for every slot.

        Each slot state is encoded, optionally modulated by the action, and
        combined with the attention-weighted effects of all other slots before
        being merged and projected to the output heads.

        Args:
            sampled_state: Tensor whose first axis has ``bs * K`` rows (one per
                batch element and slot).
            actions: ``None`` or a tensor with ``bs * K`` rows. When the
                module's ``action_size`` is 4 and the actions have 6 columns,
                only columns 0, 1, 3 and 4 are encoded.

        Returns:
            Tuple ``(deter_state, lambdas1, lambdas2)``. ``deter_state`` is
            ``None`` when ``self.det_size`` is 0.
        """
        n_slots = self.K
        batch = sampled_state.shape[0] // n_slots

        encoded = self.inertia_encoder(sampled_state)  # Encode sample

        if actions is None:
            slot_enc = encoded.view(batch, n_slots,
                                    self.full_rep_size)  # (bs, k, h)
        else:
            if self.action_size == 4 and actions.shape[-1] == 6:
                # Keep only the four action columns the encoder was built for.
                kept_columns = torch.LongTensor([0, 1, 3, 4])
                action_input = actions[:, kept_columns]
            else:
                action_input = actions
            action_enc = self.action_encoder(action_input)  # Encode actions
            joint = torch.cat([encoded, action_enc], -1)

            action_effect = self.action_effect_network(joint)  # (bs*k, h)
            action_attention = self.action_attention_network(
                joint)  # (bs*k, 1)
            slot_enc = (action_effect * action_attention).view(
                batch, n_slots, self.full_rep_size)  # (bs, k, h)

        if n_slots == 1:
            # A single slot has nothing to interact with.
            total_effect = ptu.zeros(
                (batch, self.effect_size)).to(sampled_state.device)
        else:
            # Ordered pairs (i, j) with i != j, grouped by the receiving slot i.
            ordered_pairs = [
                torch.cat([slot_enc[:, i], slot_enc[:, j]], -1)
                for i in range(n_slots)
                for j in range(n_slots)
                if i != j
            ]
            pair_tensor = torch.stack(ordered_pairs, 1).view(
                batch * n_slots, n_slots - 1, -1)

            pair_enc = self.pairwise_encoder_network(
                pair_tensor)  # (bs*k,k-1,h)
            pair_effect = self.interaction_effect_network(
                pair_enc)  # (bs*k,k-1,h)
            pair_attention = self.interaction_attention_network(
                pair_enc)  # (bs*k,k-1,1)
            total_effect = (pair_effect * pair_attention).sum(1)  # (bs*k,h)

        flat_slot_enc = slot_enc.view(batch * n_slots, self.full_rep_size)
        merged_input = torch.cat([flat_slot_enc, total_effect],
                                 -1)  # (bs*k,h)
        aggregate = self.final_merge_network(merged_input)

        deter_state = None
        if self.det_size != 0:
            deter_state = self.det_output(aggregate)
        lambdas1 = self.lambdas1_output(aggregate)
        lambdas2 = self.lambdas2_output(aggregate)

        return deter_state, lambdas1, lambdas2
Ejemplo n.º 6
0
 def initialize_hidden(self, bs):
     """Return two separate all-zero tensors of shape (1, bs, lstm_size).

     Args:
         bs: Size of the second axis of each returned tensor.

     Returns:
         Tuple of two independently allocated zero tensors.
     """
     shape = (1, bs, self.lstm_size)
     first = ptu.zeros(shape)
     second = ptu.zeros(shape)
     return first, second