コード例 #1
0
    def backward_plan(self,
                      entity_states,
                      object_states,
                      focus_goal,
                      graphs=None):
        """Predict the preimage and the reachability of a focused goal.

        Args:
            entity_states: per-entity state tensor, flattened from axis 2
                before encoding (presumably (batch, entity, ...) — TODO
                confirm against callers).
            object_states: object-level state tensor, fully flattened with
                ``tu.flatten`` before the preimage head.
            focus_goal: goal tensor; fully flattened for the preimage head
                and flattened from axis 2 for the reachability head. Its
                leading dimension is used as the batch size.
            graphs: accepted for interface compatibility; not used by this
                implementation.

        Returns:
            dict with 'preimage_preds' reshaped to
            (focus_goal.shape[0], -1, self.c.symbol_size, 3) and
            'reachable_preds' as produced by ``self._reachable``.
        """
        batch_size = focus_goal.shape[0]

        # Preimage head: concatenation of fully flattened objects and goal.
        preimage_in = torch.cat(
            [tu.flatten(object_states), tu.flatten(focus_goal)], dim=-1)
        preimage_out = self._preimage(preimage_in)

        # Reachability head: encode entity/goal pairs with the shared
        # encoder, reduce with a max over axis -2, then classify.
        per_entity_in = torch.cat(
            [tu.flatten(entity_states, begin_axis=2),
             tu.flatten(focus_goal, begin_axis=2)],
            dim=-1)
        per_entity_enc = time_distributed(per_entity_in,
                                          self._reachable_encoder)
        pooled_enc = per_entity_enc.max(dim=-2).values
        reachable_out = self._reachable(pooled_enc)

        return dict(
            preimage_preds=preimage_out.reshape(batch_size, -1,
                                                self.c.symbol_size, 3),
            reachable_preds=reachable_out,
        )
コード例 #2
0
ファイル: rpn_pb.py プロジェクト: zhuyifengzju/RPN
 def plan(self, object_states, goal):
     """Predict subgoal outputs from the flattened object states and goal.

     Args:
         object_states: object state tensor, flattened with ``tu.flatten``.
         goal: goal tensor, flattened with ``tu.flatten``; its leading
             dimension is used as the batch size.

     Returns:
         dict with 'subgoal_preds' reshaped to
         (goal.shape[0], -1, self.c.symbol_size, 3).
     """
     flat_pair = [tu.flatten(object_states), tu.flatten(goal)]
     raw = self._subgoal(torch.cat(flat_pair, dim=-1))
     out_shape = (goal.shape[0], -1, self.c.symbol_size, 3)
     return {'subgoal_preds': raw.reshape(out_shape)}
コード例 #3
0
ファイル: rpn_pb.py プロジェクト: zhuyifengzju/RPN
 def focus(self, entity_states, object_states, goal):
     """Predict focus outputs from the flattened object states and goal.

     ``entity_states`` is part of the interface but is not read here.

     Returns:
         dict with 'focus_preds' reshaped to
         (goal.shape[0], -1, self.c.symbol_size, 3).
     """
     flat_objects = tu.flatten(object_states)
     flat_goal = tu.flatten(goal)
     raw = self._focus(torch.cat([flat_objects, flat_goal], dim=-1))
     target_shape = (goal.shape[0], -1, self.c.symbol_size, 3)
     return {'focus_preds': raw.reshape(target_shape)}
コード例 #4
0
ファイル: rpn_pb.py プロジェクト: zhuyifengzju/RPN
 def backward_plan(self,
                   entity_states,
                   object_states,
                   focus_goal,
                   graph=None):
     """Predict subgoal outputs for a focused goal.

     Only ``object_states`` and ``focus_goal`` are read; ``entity_states``
     and ``graph`` are accepted for interface compatibility but ignored.

     Returns:
         dict with 'subgoal_preds' reshaped to
         (focus_goal.shape[0], -1, self.c.symbol_size, 3).
     """
     joint_input = torch.cat(
         [tu.flatten(object_states), tu.flatten(focus_goal)], dim=-1)
     raw = self._subgoal(joint_input)
     return {
         'subgoal_preds': raw.reshape(
             (focus_goal.shape[0], -1, self.c.symbol_size, 3)),
     }
コード例 #5
0
    def forward(self, states, goal, subgoal=None):
        """Predict subgoal outputs and, in policy mode, select a subgoal.

        Args:
            states: state tensor; flattened with ``tu.flatten`` before use.
                Its leading dimension (after flattening) is the batch size.
            goal: goal tensor; flattened with ``tu.flatten`` before use.
            subgoal: returned unchanged when ``self.policy_mode`` is false;
                replaced by the selected subgoal when it is true.

        Returns:
            dict with 'subgoal_preds' shaped
            (batch, -1, self.c.symbol_size, 3) and 'subgoal' (the
            ``tu.to_onehot`` encoding of the argmax over the last axis in
            policy mode, otherwise the ``subgoal`` argument as given).
        """
        states = tu.flatten(states)
        goal = tu.flatten(goal)

        # Sub-goal prediction from the concatenated flat inputs.
        sg_out = self._sg_net(torch.cat((states, goal), dim=-1))

        # reshape rather than view: same result for contiguous outputs, but
        # it does not raise if the network returns a non-contiguous tensor
        # (and matches the other planning methods, which all use reshape).
        sg_preds = sg_out.reshape(states.shape[0], -1, self.c.symbol_size, 3)
        if self.policy_mode:
            # Greedy pick over the last (3-way) axis, re-encoded via to_onehot.
            sg_cls = sg_preds.argmax(dim=-1)
            subgoal = tu.to_onehot(sg_cls, 3)  # [false, true, masked]

        return {'subgoal_preds': sg_preds, 'subgoal': subgoal}
コード例 #6
0
ファイル: rpn_pb.py プロジェクト: zhuyifengzju/RPN
 def forward(self, entity_enc, focus_goal):
     """Score reachability of a focus goal from entity encodings.

     ``focus_goal`` is flattened from axis 2 and encoded with
     ``self._sym_encoder`` via ``time_distributed`` (presumably one row per
     entity — confirm). The result is concatenated with ``entity_enc`` on
     the last axis, passed through ``self._reachable_encoder``, reduced with
     a max over axis -2, and scored by ``self._reachable``.
     """
     goal_enc = time_distributed(tu.flatten(focus_goal, begin_axis=2),
                                 self._sym_encoder)
     joint = torch.cat([entity_enc, goal_enc], dim=-1)
     pooled = self._reachable_encoder(joint).max(dim=-2).values
     return self._reachable(pooled)
コード例 #7
0
ファイル: rpn_pb.py プロジェクト: zhuyifengzju/RPN
 def backward_plan(self,
                   entity_states,
                   object_states,
                   focus_goal,
                   graph=None):
     """Predict the preimage and the reachability of a focused goal.

     The preimage head sees the fully flattened object states and goal;
     reachability is delegated to ``self._reachable(entity_states,
     focus_goal)``. ``graph`` is accepted but unused here.

     Returns:
         dict with 'preimage_preds' reshaped to
         (focus_goal.shape[0], -1, self.c.symbol_size, 3) and
         'reachable_preds' exactly as returned by ``self._reachable``.
     """
     flat_inputs = torch.cat(
         [tu.flatten(object_states), tu.flatten(focus_goal)], dim=-1)
     preimage = self._preimage(flat_inputs)
     reachable = self._reachable(entity_states, focus_goal)
     n_batch = focus_goal.shape[0]
     return dict(
         preimage_preds=preimage.reshape(n_batch, -1, self.c.symbol_size, 3),
         reachable_preds=reachable,
     )
コード例 #8
0
ファイル: rpn_pb.py プロジェクト: zhuyifengzju/RPN
 def forward(self, input_im):
     """Encode ``input_im`` (presumably an image batch) into features.

     Runs ``self._net``, flattens the result from axis 1, applies
     ``self._fc1``, then ``self._out_act`` if one is configured.
     """
     features = tu.flatten(self._net(input_im), begin_axis=1)
     encoded = self._fc1(features)
     return encoded if self._out_act is None else self._out_act(encoded)
コード例 #9
0
 def forward(self, inputs):
     """Return ``flatten(inputs, self._begin_index)``.

     ``self._begin_index`` is passed as the second argument of ``flatten``
     (presumably the first axis to flatten — confirm against its definition).
     """
     start_axis = self._begin_index
     return flatten(inputs, start_axis)