def backward_plan(self, entity_states, object_states, focus_goal, graphs=None):
    # predict the preimage (predecessor subgoal) from the flattened object
    # states and the focused goal
    inputs_focus = torch.cat(
        (tu.flatten(object_states), tu.flatten(focus_goal)), dim=-1)
    preimage_preds = self._preimage(inputs_focus)
    # encode each entity together with the focused goal, then max-pool over
    # the entity axis for a permutation-invariant reachability input
    reachable_inputs = torch.cat(
        (tu.flatten(entity_states, begin_axis=2),
         tu.flatten(focus_goal, begin_axis=2)),
        dim=-1)
    focus_enc = time_distributed(reachable_inputs, self._reachable_encoder)
    # reachable_enc = self._reachable_gn(focus_enc, graphs)
    reachable_enc_red, _ = torch.max(focus_enc, dim=-2)
    reachable_preds = self._reachable(reachable_enc_red)
    return {
        'preimage_preds': preimage_preds.reshape(
            focus_goal.shape[0], -1, self.c.symbol_size, 3),
        'reachable_preds': reachable_preds,
    }
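# Illustrative sketch (not from the source): `time_distributed` is assumed to
# apply a module independently to each element along a leading entity axis by
# folding that axis into the batch dimension and then restoring it. The
# argument order matches the calls here: inputs first, module second.
def time_distributed_sketch(inputs, module):
    batch, num_entities = inputs.shape[0], inputs.shape[1]
    # merge [B, N, D] -> [B * N, D] so the module sees a flat batch
    merged = inputs.reshape(batch * num_entities, *inputs.shape[2:])
    out = module(merged)
    # restore the [B, N, ...] layout so per-entity structure is preserved
    return out.reshape(batch, num_entities, *out.shape[1:])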
def plan(self, object_states, goal):
    inputs = torch.cat((tu.flatten(object_states), tu.flatten(goal)), dim=-1)
    sg_preds = self._subgoal(inputs)
    return {
        'subgoal_preds': sg_preds.reshape(
            goal.shape[0], -1, self.c.symbol_size, 3),
    }
def focus(self, entity_states, object_states, goal):
    # predict the focused goal; entity_states is not used in this variant
    inputs = torch.cat((tu.flatten(object_states), tu.flatten(goal)), dim=-1)
    focus_preds = self._focus(inputs)
    return {
        'focus_preds': focus_preds.reshape(
            goal.shape[0], -1, self.c.symbol_size, 3),
    }
def backward_plan(self, entity_states, object_states, focus_goal, graph=None):
    # this variant only predicts a subgoal; entity_states and graph are unused
    inputs_focus = torch.cat(
        (tu.flatten(object_states), tu.flatten(focus_goal)), dim=-1)
    subgoal_preds = self._subgoal(inputs_focus)
    return {
        'subgoal_preds': subgoal_preds.reshape(
            focus_goal.shape[0], -1, self.c.symbol_size, 3),
    }
def forward(self, states, goal, subgoal=None):
    states = tu.flatten(states)
    goal = tu.flatten(goal)
    # get sub-goal prediction
    sg_out = self._sg_net(torch.cat((states, goal), dim=-1))
    sg_preds = sg_out.view(states.shape[0], -1, self.c.symbol_size, 3)
    if self.policy_mode:
        # at execution time, discretize the prediction into a one-hot subgoal
        sg_cls = sg_preds.argmax(dim=-1)
        subgoal = tu.to_onehot(sg_cls, 3)  # [false, true, masked]
    # get action prediction
    return {'subgoal_preds': sg_preds, 'subgoal': subgoal}
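# Hypothetical usage sketch (not from the source): decoding `subgoal_preds`
# of shape [B, num_entities, symbol_size, 3] into discrete symbolic values,
# assuming the three classes follow the [false, true, masked] convention
# noted in the comment above.
def decode_subgoal_sketch(sg_preds):
    sg_cls = sg_preds.argmax(dim=-1)   # [B, num_entities, symbol_size]
    is_masked = sg_cls == 2            # symbols the subgoal leaves unconstrained
    is_true = sg_cls == 1              # truth value of the constrained symbols
    return is_true, is_masked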
def forward(self, entity_enc, focus_goal):
    # encode the symbolic focused goal per entity, combine with the entity
    # encodings, and max-pool over entities before the reachability head
    sym_enc = time_distributed(
        tu.flatten(focus_goal, begin_axis=2), self._sym_encoder)
    focus_enc = self._reachable_encoder(
        torch.cat([entity_enc, sym_enc], dim=-1))
    reachable_enc_red, _ = torch.max(focus_enc, dim=-2)
    return self._reachable(reachable_enc_red)
def backward_plan(self, entity_states, object_states, focus_goal, graph=None):
    # predict the preimage from object states and the focused goal, and score
    # reachability with the dedicated reachability module
    inputs_focus = torch.cat(
        (tu.flatten(object_states), tu.flatten(focus_goal)), dim=-1)
    preimage_preds = self._preimage(inputs_focus)
    reachable_preds = self._reachable(entity_states, focus_goal)
    return {
        'preimage_preds': preimage_preds.reshape(
            focus_goal.shape[0], -1, self.c.symbol_size, 3),
        'reachable_preds': reachable_preds,
    }
def forward(self, input_im):
    # convolutional features, flattened per sample, then a linear projection
    im_enc = self._net(input_im)
    im_enc = self._fc1(tu.flatten(im_enc, begin_axis=1))
    if self._out_act is not None:
        im_enc = self._out_act(im_enc)
    return im_enc
def forward(self, inputs):
    # flatten all dimensions from self._begin_index onward into a single axis
    return flatten(inputs, self._begin_index)
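# Illustrative sketch (not from the source): a `flatten(x, begin_axis)` helper
# of the kind used above is assumed to keep the first `begin_axis` dimensions
# and collapse the rest into one, e.g. [B, N, H, W] with begin_axis=1
# becomes [B, N * H * W].
def flatten_sketch(x, begin_axis=1):
    return x.reshape(*x.shape[:begin_axis], -1)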