Example #1
0
 def forward(self, entity_enc, focus_goal):
     """Fuse entity encodings with the focused goal and apply the reachability head.

     Flattens ``focus_goal`` from axis 2, encodes it per slot, concatenates
     with ``entity_enc``, max-pools over the penultimate axis, and returns
     the output of the ``_reachable`` head.
     """
     # Per-slot encoding of the (flattened) symbolic goal.
     goal_flat = tu.flatten(focus_goal, begin_axis=2)
     goal_enc = time_distributed(goal_flat, self._sym_encoder)
     # Join entity and goal features on the last axis, then encode.
     fused = torch.cat([entity_enc, goal_enc], dim=-1)
     reach_feat = self._reachable_encoder(fused)
     # Max-pool across the penultimate axis (indices discarded).
     pooled, _ = torch.max(reach_feat, dim=-2)
     return self._reachable(pooled)
Example #2
0
    def backward_plan(self,
                      entity_states,
                      object_states,
                      focus_goal,
                      graphs=None):
        """Run the preimage and reachability heads for the focused goal.

        Returns a dict with:
          - 'preimage_preds': preimage head output reshaped to
            (batch, -1, symbol_size, 3)
          - 'reachable_preds': reachability head output after max-pooling
            over the penultimate axis
        ``graphs`` is currently unused here.
        """
        # Preimage head: fully-flattened object states joined with the goal.
        preimage_in = torch.cat(
            (tu.flatten(object_states), tu.flatten(focus_goal)), dim=-1)
        preimage_out = self._preimage(preimage_in)

        # Reachability head: per-slot entity features joined with the goal,
        # both flattened from axis 2.
        entity_flat = tu.flatten(entity_states, begin_axis=2)
        goal_flat = tu.flatten(focus_goal, begin_axis=2)
        reach_in = torch.cat((entity_flat, goal_flat), dim=-1)
        reach_feat = time_distributed(reach_in, self._reachable_encoder)
        # Max-pool across the penultimate axis before the final head.
        pooled, _ = torch.max(reach_feat, dim=-2)
        reach_out = self._reachable(pooled)

        batch = focus_goal.shape[0]
        return {
            'preimage_preds':
            preimage_out.reshape(batch, -1, self.c.symbol_size, 3),
            'reachable_preds':
            reach_out,
        }
Example #3
0
    def forward_policy(self, object_states, entity_states, goal, graphs=None):
        """Encode object states and search for a subgoal.

        Returns the planner output from ``find_subgoal`` unchanged; the
        policy rollout for a found subgoal is still TODO.
        NOTE(review): ``entity_states`` is accepted but never used on this
        path — kept for interface compatibility.
        """
        object_feat = time_distributed(object_states, self._state_encoder)
        planner_out = self.find_subgoal(object_feat, None, goal, graphs)
        # -1 appears to be the "nothing to report" code — only surface other
        # return codes when verbose.
        if self.verbose and planner_out['ret'] != -1:
            print(planner_out['ret'])

        # Keep the key access: a planner result without 'subgoal' is a bug
        # and should fail loudly here.
        subgoal = planner_out['subgoal']
        if subgoal is not None:
            # TODO: policy — act on the found subgoal. Until then both
            # branches returned the same object, so the duplicated early
            # return has been collapsed into a single exit.
            pass
        return planner_out
Example #4
0
 def forward(
     self,
     object_states,
     entity_states,
     goal,
     focus_goal,
     satisfied_info,
     dependency_info,
     graph,
     num_entities,
 ):
     """Encode object states and run the planner on the goal.

     Only ``object_states`` and ``goal`` are consumed here; the remaining
     arguments are accepted for interface compatibility.
     """
     encoded = time_distributed(object_states, self._state_encoder)
     return self.plan(encoded, goal)
Example #5
0
 def forward(
     self,
     object_states,
     entity_states,
     goal,
     focus_goal,
     satisfied_info,
     dependency_info,
     graph,
     num_entities,
 ):
     """Run the focus stage, then merge in backward-plan predictions.

     Only ``object_states``, ``goal``, ``focus_goal``, and ``graph`` are
     consumed here; the remaining arguments are accepted for interface
     compatibility.
     """
     obj_feat = time_distributed(object_states, self._state_encoder)
     ent_feat = self.batch_entity_features(obj_feat)
     # Focus output first, then extend it with the backward-plan results
     # (backward-plan keys overwrite any overlapping focus keys).
     out = self.focus(ent_feat, obj_feat, goal)
     backward = self.backward_plan(ent_feat, obj_feat, focus_goal, graph)
     out.update(backward)
     return out