def forward(self, entity_enc, focus_goal):
    """Predict reachability from per-entity encodings and a symbolic focus goal.

    The goal tensor is flattened from axis 2 onward, encoded per entity,
    concatenated with ``entity_enc``, passed through the reachability
    encoder, max-reduced over ``dim=-2``, and fed to the output head.
    """
    # Encode the flattened symbolic goal, one encoding per entity slot.
    goal_flat = tu.flatten(focus_goal, begin_axis=2)
    goal_enc = time_distributed(goal_flat, self._sym_encoder)
    # Joint entity+goal features; reduce across dim=-2 with an element-wise max.
    joint = torch.cat([entity_enc, goal_enc], dim=-1)
    joint_enc = self._reachable_encoder(joint)
    pooled = torch.max(joint_enc, dim=-2)[0]
    return self._reachable(pooled)
def backward_plan(self, entity_states, object_states, focus_goal, graphs=None):
    """Run the two backward-planning heads: preimage and reachability.

    Returns a dict with:
      - ``'preimage_preds'``: preimage head output reshaped to
        ``(batch, -1, symbol_size, 3)``.
      - ``'reachable_preds'``: reachability head output after a max
        reduction over ``dim=-2``.
    ``graphs`` is currently unused (see the note below).
    """
    # Preimage head: fully-flattened object states + focus goal.
    preimage_in = torch.cat(
        (tu.flatten(object_states), tu.flatten(focus_goal)), dim=-1)
    preimage_out = self._preimage(preimage_in)

    # Reachability head: per-entity features (flattened from axis 2) + goal.
    per_entity = torch.cat(
        (tu.flatten(entity_states, begin_axis=2),
         tu.flatten(focus_goal, begin_axis=2)),
        dim=-1)
    encoded = time_distributed(per_entity, self._reachable_encoder)
    # NOTE(review): a graph-net pass over `graphs` was disabled here:
    # reachable_enc = self._reachable_gn(encoded, graphs)
    pooled, _ = torch.max(encoded, dim=-2)
    reachable_out = self._reachable(pooled)

    return {
        'preimage_preds': preimage_out.reshape(
            focus_goal.shape[0], -1, self.c.symbol_size, 3),
        'reachable_preds': reachable_out,
    }
def forward_policy(self, object_states, entity_states, goal, graphs=None):
    """Encode object states and search for a subgoal via the planner.

    The policy step is not implemented yet, so the planner output is
    returned whether or not a subgoal was found. ``entity_states`` is
    currently unused.
    """
    obj_feat = time_distributed(object_states, self._state_encoder)
    plan = self.find_subgoal(obj_feat, None, goal, graphs)
    # Surface the planner's return code when verbose (skip the -1 sentinel).
    if self.verbose and plan['ret'] != -1:
        print(plan['ret'])
    if plan['subgoal'] is None:
        return plan
    # TODO: policy
    return plan
def forward(
    self,
    object_states,
    entity_states,
    goal,
    focus_goal,
    satisfied_info,
    dependency_info,
    graph,
    num_entities,
):
    """Encode object states and delegate to the planner.

    Only ``object_states`` and ``goal`` are consumed here; the remaining
    arguments are accepted for interface compatibility with sibling
    ``forward`` implementations.
    """
    feats = time_distributed(object_states, self._state_encoder)
    return self.plan(feats, goal)
def forward(
    self,
    object_states,
    entity_states,
    goal,
    focus_goal,
    satisfied_info,
    dependency_info,
    graph,
    num_entities,
):
    """Run focus prediction and backward planning on encoded states.

    Encodes object states, derives entity features from them, runs the
    focus head on the goal, then merges in the backward-planning
    predictions for ``focus_goal``. Several arguments are accepted only
    for interface compatibility and are unused here.
    """
    obj_feat = time_distributed(object_states, self._state_encoder)
    ent_feat = self.batch_entity_features(obj_feat)
    out = self.focus(ent_feat, obj_feat, goal)
    out.update(self.backward_plan(ent_feat, obj_feat, focus_goal, graph))
    return out