def forward_batch(self, batch):
    if not self.policy_mode:
        goal = masked_symbolic_state_index(batch['goal'], batch['goal_mask'])
        goal = tu.to_onehot(goal, 3)
        focus_goal = masked_symbolic_state_index(batch['goal'],
                                                 batch['focus_mask'])
        focus_goal = tu.to_onehot(focus_goal, 3)
        satisfied_info = batch['satisfied'][:, :-1]
        dependency_info = batch['dependency'][:, :-1]
        return self.forward(
            object_states=batch['states'],
            entity_states=batch.get('entity_states', batch['states']),
            goal=goal,
            focus_goal=focus_goal,
            satisfied_info=satisfied_info,
            dependency_info=dependency_info,
            graph=batch.get('graph', None),
            num_entities=batch.get('num_entities', None))
    else:
        goal = masked_symbolic_state_index(batch['goal'], batch['goal_mask'])
        goal = tu.to_onehot(goal, 3)
        return self.forward_policy(
            object_states=batch['states'],
            entity_states=batch.get('entity_states', batch['states']),
            goal=goal,
            graphs=batch.get('graphs', None),
        )
def log_outputs(outputs, batch, summarizer, global_step, prefix):
    preimage = masked_symbolic_state_index(batch['preimage'],
                                           batch['preimage_mask'])
    preimage_preds = outputs['preimage_preds'].argmax(-1)
    preimage_preds.masked_fill_(batch['preimage_loss_mask'] == 0, 2)
    preimage.masked_fill_(batch['preimage_loss_mask'] == 0, 2)
    preimage_acc, preimage_mask_acc = masked_binary_accuracy(
        tu.to_onehot(preimage_preds, 3), preimage)
    focus = masked_symbolic_state_index(batch['goal'], batch['focus_mask'])
    focus_preds = outputs['focus_preds'].argmax(-1)
    focus_acc, focus_mask_acc = masked_binary_accuracy(
        tu.to_onehot(focus_preds, 3), focus)
    reachable_acc = classification_accuracy(outputs['reachable_preds'],
                                            batch['reachable'])
    summarizer.add_scalar(prefix + 'acc/focus', focus_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/focus_mask', focus_mask_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/preimage', preimage_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/preimage_mask', preimage_mask_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/reachable', reachable_acc,
                          global_step=global_step)
def forward_batch(self, batch):
    goal = masked_symbolic_state_index(batch['goal'], batch['goal_mask'])
    goal = tu.to_onehot(goal, 3)
    subgoal = None
    if not self.policy_mode:
        subgoal = masked_symbolic_state_index(batch['subgoal'],
                                              batch['subgoal_mask'])
        subgoal = tu.to_onehot(subgoal, 3)
    return self(
        states=batch['states'],
        goal=goal,
        subgoal=subgoal,
    )
def log_outputs(outputs, batch, summarizer, global_step, prefix):
    preimage = masked_symbolic_state_index(batch['preimage'],
                                           batch['preimage_mask'])
    preimage_preds = outputs['preimage_preds'].argmax(-1)
    preimage_preds.masked_fill_(batch['preimage_loss_mask'] == 0, 2)
    preimage.masked_fill_(batch['preimage_loss_mask'] == 0, 2)
    preimage_acc, preimage_mask_acc = masked_binary_accuracy(
        tu.to_onehot(preimage_preds, 3), preimage)
    reachable_acc = classification_accuracy(outputs['reachable_preds'],
                                            batch['reachable'])
    satisfied_acc = classification_accuracy(
        outputs['satisfied_preds'], batch['satisfied'][:, -1].long())
    dependency_acc = classification_accuracy(
        outputs['dependency_preds'], batch['dependency'][:, -1].long())
    summarizer.add_scalar(prefix + 'acc/preimage', preimage_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/preimage_mask', preimage_mask_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/reachable', reachable_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/satisfied', satisfied_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/dependency', dependency_acc,
                          global_step=global_step)
def find_subgoal(self, object_state, entity_state, goal, graphs, max_depth=10):
    return {
        'subgoal': tu.to_onehot(
            self.plan(object_state, goal)['subgoal_preds'].argmax(-1), 3),
        'ret': -1
    }
def forward(self, states, goal, subgoal=None):
    states = tu.flatten(states)
    goal = tu.flatten(goal)
    # get sub-goal prediction
    sg_out = self._sg_net(torch.cat((states, goal), dim=-1))
    sg_preds = sg_out.view(states.shape[0], -1, self.c.symbol_size, 3)
    if self.policy_mode:
        sg_cls = sg_preds.argmax(dim=-1)
        subgoal = tu.to_onehot(sg_cls, 3)  # [false, true, masked]
    return {'subgoal_preds': sg_preds, 'subgoal': subgoal}
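# NOTE: illustrative sketch only (not part of the original module). The real
# `masked_symbolic_state_index` and `tu.to_onehot` live in this repo's utility
# code; the hypothetical versions below merely document the 3-class symbolic
# encoding (0 = false, 1 = true, 2 = masked) that the forward/logging code
# above is assumed to rely on.
import torch
import torch.nn.functional as F


def _sketch_masked_symbolic_state_index(value, mask):
    # value, mask: binary tensors of the same shape. Entries with mask == 0
    # are mapped to class 2 ("masked"); unmasked entries keep their 0/1 value.
    index = value.long().clone()
    index[mask == 0] = 2
    return index


def _sketch_to_onehot(index, num_classes=3):
    # One-hot encode the class indices along a trailing dimension.
    return F.one_hot(index.long(), num_classes).float()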
def log_outputs(outputs, batch, summarizer, global_step, prefix):
    subgoal = masked_symbolic_state_index(batch['subgoal'],
                                          batch['subgoal_mask'])
    subgoal_preds = outputs['subgoal_preds'].argmax(-1)
    subgoal_acc, subgoal_mask_acc = masked_binary_accuracy(
        tu.to_onehot(subgoal_preds, 3), subgoal)
    satisfied_acc = classification_accuracy(
        outputs['satisfied_preds'], batch['satisfied'][:, -1].long())
    dependency_acc = classification_accuracy(
        outputs['dependency_preds'], batch['dependency'][:, -1].long())
    summarizer.add_scalar(prefix + 'acc/subgoal', subgoal_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/subgoal_mask', subgoal_mask_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/satisfied', satisfied_acc,
                          global_step=global_step)
    summarizer.add_scalar(prefix + 'acc/dependency', dependency_acc,
                          global_step=global_step)
def find_subgoal(self, object_state, entity_state, goal, graphs, max_depth=10):
    """Resolve the next subgoal directly."""
    curr_goal = goal.argmax(dim=-1)  # [1, num_object, num_predicate]
    focus_group, ret = self._serialize_subgoals(entity_state, object_state,
                                                curr_goal)
    subgoal = None
    if focus_group is not None:
        bp_out = self.backward_plan(entity_state, object_state, focus_group,
                                    graphs)
        subgoal_preds = bp_out['subgoal_preds'].argmax(dim=-1)
        subgoal = tu.to_onehot(subgoal_preds, 3)
        if self.verbose and self.env is not None:
            curr_goal_np = tu.to_numpy(curr_goal[0])
            print('[bp] current goals: ',
                  self.env.deserialize_goals(curr_goal_np, curr_goal_np != 2))
            focus_group_np = tu.to_numpy(focus_group[0])
            print('[bp] focus group: ',
                  self.env.deserialize_goals(focus_group_np[..., 1],
                                             (1 - focus_group_np[..., 2])))
            subgoal_np = tu.to_numpy(subgoal[0])
            print('[bp] subgoal: ',
                  self.env.deserialize_goals(subgoal_np[..., 1],
                                             (1 - subgoal_np[..., 2])))
    return {'subgoal': subgoal, 'ret': ret}
def _serialize_subgoals(self, entity_state, object_state, curr_goal):
    curr_goal = tu.to_onehot(curr_goal, 3)
    focus_group = self.focus(entity_state, object_state,
                             curr_goal)['focus_preds']
    focus_group = tu.to_onehot(focus_group.argmax(-1), 3)
    return focus_group, -1
def _serialize_subgoals(self, entity_state, object_state, curr_goal):
    assert len(curr_goal.shape) == 3
    assert len(entity_state.shape) == 3
    assert entity_state.shape[1] == curr_goal.shape[1]
    state_np = tu.to_numpy(entity_state)[0]
    num_predicate = curr_goal.shape[-1]
    curr_goal_np = tu.to_numpy(curr_goal[0])
    goal_object_index, goal_predicates_index = np.where(curr_goal_np != 2)
    goal_index = np.stack([goal_object_index,
                           goal_predicates_index]).transpose()
    goal_predicates_value = curr_goal_np[(goal_object_index,
                                          goal_predicates_index)]
    num_goal = goal_index.shape[0]

    # predict satisfaction
    sat_state_inputs = state_np[goal_object_index]
    sat_predicate_mask = npu.to_onehot(goal_predicates_index,
                                       num_predicate).astype(np.float32)
    sat_predicate = sat_predicate_mask.copy()
    sat_predicate[sat_predicate_mask.astype(bool)] = goal_predicates_value
    sat_state_inputs = tu.to_tensor(sat_state_inputs[None, ...],
                                    device=entity_state.device)
    sat_sym_inputs_np = np.concatenate((sat_predicate, sat_predicate_mask),
                                       axis=-1)[None, ...]
    sat_sym_inputs = tu.to_tensor(sat_sym_inputs_np,
                                  device=entity_state.device)
    sat_preds = tu.to_numpy(
        self.forward_sat(sat_state_inputs, sat_sym_inputs).argmax(-1))[0]  # [ng]
    assert sat_preds.shape[0] == num_goal
    if self.verbose and self.env is not None:
        for sat_p, sat_m, sp, oi in zip(sat_predicate, sat_predicate_mask,
                                        sat_preds, goal_object_index):
            sat_pad = np.hstack([[oi], sat_p, sat_m, [sp]])
            print('[bp] sat: ', self.env.deserialize_satisfied_entry(sat_pad))

    # Construct dependency graphs
    nodes, edges = construct_full_graph(num_goal)
    src_object_index = goal_object_index[edges[:, 0]]  # object index for each edge source
    tgt_object_index = goal_object_index[edges[:, 1]]
    src_inputs = state_np[src_object_index]  # object states for edge sources
    tgt_inputs = state_np[tgt_object_index]
    src_predicate_value = goal_predicates_value[edges[:, 0]]  # predicate value for each edge source
    src_predicate_index = goal_predicates_index[edges[:, 0]]
    tgt_predicate_value = goal_predicates_value[edges[:, 1]]
    tgt_predicate_index = goal_predicates_index[edges[:, 1]]
    src_predicate_mask = npu.to_onehot(src_predicate_index,
                                       num_predicate).astype(np.float32)
    src_predicate = np.zeros_like(src_predicate_mask)
    src_predicate[src_predicate_mask.astype(bool)] = src_predicate_value
    tgt_predicate_mask = npu.to_onehot(tgt_predicate_index,
                                       num_predicate).astype(np.float32)
    tgt_predicate = np.zeros_like(tgt_predicate_mask)
    tgt_predicate[tgt_predicate_mask.astype(bool)] = tgt_predicate_value
    # dependency_inputs_np = np.concatenate(
    #     (src_inputs, tgt_inputs, src_predicate, src_predicate_mask,
    #      tgt_predicate, tgt_predicate_mask), axis=-1)
    dependency_state_inputs_np = np.concatenate((src_inputs, tgt_inputs),
                                                axis=-1)
    dependency_sym_inputs_np = np.concatenate(
        (src_predicate, src_predicate_mask, tgt_predicate,
         tgt_predicate_mask), axis=-1)
    if dependency_state_inputs_np.shape[0] > 0:
        dependency_state_inputs = tu.to_tensor(
            dependency_state_inputs_np,
            device=entity_state.device).unsqueeze(0)
        dependency_sym_inputs = tu.to_tensor(
            dependency_sym_inputs_np,
            device=entity_state.device).unsqueeze(0)
        deps_preds = tu.to_numpy(
            self.forward_dep(dependency_state_inputs,
                             dependency_sym_inputs).argmax(-1))[0]
        dep_graph_edges = edges[deps_preds > 0]
    else:
        dep_graph_edges = np.array([])

    sorted_goal_groups = sort_goal_graph(dep_graph_edges, nodes)
    focus_group_idx = None
    for gg in reversed(sorted_goal_groups):
        if not np.any(sat_preds[gg]):  # group is unsatisfied
            focus_group_idx = gg
            break
    if focus_group_idx is None:
        return None, 'NETWORK_ALL_SATISFIED'

    # focus_group_idx is a list of goal indices
    focus_group_np = np.ones_like(curr_goal_np) * 2
    for fg_idx in focus_group_idx:
        fg_obj_i, fg_pred_i = goal_index[fg_idx]
        focus_group_np[fg_obj_i, fg_pred_i] = curr_goal_np[fg_obj_i, fg_pred_i]
    focus_group = tu.to_tensor(focus_group_np,
                               device=entity_state.device).unsqueeze(0)
    focus_group = tu.to_onehot(focus_group, 3)
    return focus_group, -1
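# NOTE: assumption-level sketch (not part of the original module). The real
# `construct_full_graph` comes from this repo's graph utilities; based on how
# its outputs are indexed above (edges[:, 0] / edges[:, 1] index into the
# serialized goal list), it is assumed to return the node index array plus
# every ordered pair of distinct goal indices, over which the dependency
# classifier is then evaluated edge by edge.
import numpy as np


def _sketch_construct_full_graph(num_nodes):
    # Nodes are the serialized goal literals; edges are all ordered
    # (src, tgt) pairs without self-loops.
    nodes = np.arange(num_nodes)
    edges = np.array([(i, j) for i in range(num_nodes)
                      for j in range(num_nodes) if i != j],
                     dtype=np.int64).reshape(-1, 2)
    return nodes, edges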