def masked_binary_accuracy(preds, labels):
    """Accuracy split by mask state for 3-way labels (2 == masked).

    Returns a pair: mean accuracy over unmasked entries (label != 2) and
    mean accuracy over masked entries (label == 2).
    """
    preds = to_numpy(preds)
    labels = to_numpy(labels)
    assert np.all(labels <= 2)
    pred_labels = np.argmax(preds, axis=-1)
    hits = pred_labels == labels
    unmasked_hits = hits[labels != 2]
    masked_hits = hits[labels == 2]
    return (unmasked_hits.astype(np.float64).mean(),
            masked_hits.astype(np.float64).mean())
def binary_accuracy(preds, labels):
    """Thresholded (0.5) binary accuracy plus per-class FP/FN rates.

    Returns (overall accuracy, per-class false-positive rate,
    per-class false-negative rate), where the class axis is the last
    dimension of `preds`.
    """
    preds = to_numpy(preds)
    labels = to_numpy(labels)
    num_c = preds.shape[-1]
    pred_pos = preds > 0.5
    label_pos = labels > 0.5
    fp = np.logical_and(pred_pos, labels < 0.5).astype(np.float64)
    fn = np.logical_and(preds < 0.5, label_pos).astype(np.float64)
    acc = (pred_pos == label_pos).astype(np.float64).mean()
    fp_rate = fp.reshape([-1, num_c]).mean(axis=0)
    fn_rate = fn.reshape([-1, num_c]).mean(axis=0)
    return acc, fp_rate, fn_rate
def test_graph_collation():
    """Collating two identical graphs and separating again must round-trip
    both node and edge features exactly."""
    node_index, edge_index = construct_full_graph(5)
    node_input = torch.randn(10, 10)

    def _delta(a, b):
        # edge feature = target minus source
        return b - a

    edge_parts = [
        get_edge_features(node_input[:5], edge_index, _delta),
        get_edge_features(node_input[:5], edge_index, _delta),
    ]
    edge_input = torch.cat(edge_parts, dim=0)
    node_index = [node_index, node_index]
    edge_index = [edge_index, edge_index]
    gs = collate_torch_graphs(node_input, edge_input, node_index, edge_index)
    ni, ei = separate_graph_collated_features(gs.x, node_index, edge_index)
    assert to_numpy(torch.all(ei == edge_input)) == 1
    assert to_numpy(torch.all(ni == node_input)) == 1
def _add_stats(self, key, val, n_iter): if key not in self._stats: self._stats[key] = [] self._stats_iter[key] = [] if isinstance(val, torch.Tensor): val = to_numpy(val) self._stats[key].append(val) self._stats_iter[key].append(n_iter)
def debug(outputs, batch, env=None):
    """Print every mispredicted subgoal row with its label and focused goal."""
    subgoal = tu.to_numpy(
        masked_symbolic_state_index(batch['subgoal'], batch['subgoal_mask']))
    subgoal_preds = tu.to_numpy(outputs['subgoal_preds'].argmax(-1))
    focus_goal = tu.to_numpy(
        masked_symbolic_state_index(batch['goal'], batch['focus_mask']))
    print('subgoal')
    for label_row, pred_row, focus_row in zip(subgoal, subgoal_preds,
                                              focus_goal):
        # only report rows where the prediction differs from the label
        if np.all(label_row == pred_row):
            continue
        print('- preds: ',
              env.masked_symbolic_state(
                  env.deserialize_symbolic_state(pred_row)))
        print('label: ',
              env.masked_symbolic_state(
                  env.deserialize_symbolic_state(label_row)))
        print('focused: ',
              env.masked_symbolic_state(
                  env.deserialize_symbolic_state(focus_row)))
def find_subgoal(self, object_state, entity_state, goal, graphs, max_depth=10):
    """Resolve the next subgoal directly (single backward-planning step).

    :param object_state: current state of objects
    :param entity_state: current state of entities
    :param goal: global goal, one-hot over 3 classes in the last dim
    :param graphs: graph structures consumed by backward_plan
    :param max_depth: unused here; kept for interface parity with the
        recursive variant
    :return: dict with 'subgoal' (one-hot tensor or None) and 'ret' status
    """
    curr_goal = goal.argmax(dim=-1)  # [1, num_object, num_predicate]
    focus_group, ret = self._serialize_subgoals(entity_state, object_state,
                                                curr_goal)
    if focus_group is None:
        # no unsatisfied goal group to focus on
        return {'subgoal': None, 'ret': ret}
    bp_out = self.backward_plan(entity_state, object_state, focus_group,
                                graphs)
    subgoal = tu.to_onehot(bp_out['subgoal_preds'].argmax(dim=-1), 3)
    if self.verbose and self.env is not None:
        goal_np = tu.to_numpy(curr_goal[0])
        print('[bp] current goals: ',
              self.env.deserialize_goals(goal_np, goal_np != 2))
        fg_np = tu.to_numpy(focus_group[0])
        print('[bp] focus group: ',
              self.env.deserialize_goals(fg_np[..., 1], (1 - fg_np[..., 2])))
        sg_np = tu.to_numpy(subgoal[0])
        print('[bp] subgoal: ',
              self.env.deserialize_goals(sg_np[..., 1], (1 - sg_np[..., 2])))
    return {'subgoal': subgoal, 'ret': ret}
def debug(outputs, batch, env=None):
    """Print mispredicted preimage, reachability, and dependency entries."""
    preimage = tu.to_numpy(
        masked_symbolic_state_index(batch['preimage'],
                                    batch['preimage_mask']))
    preimage_preds = tu.to_numpy(outputs['preimage_preds'].argmax(-1))
    loss_mask = tu.to_numpy(batch['preimage_loss_mask'])
    print('preimage')
    for label_row, pred_row, mask_row in zip(preimage, preimage_preds,
                                             loss_mask):
        # skip correct rows and rows fully excluded from the loss
        if np.all(label_row == pred_row) or np.all(mask_row == 0):
            continue
        print('preds: ',
              env.masked_symbolic_state(
                  env.deserialize_symbolic_state(pred_row)))
        print('label: ',
              env.masked_symbolic_state(
                  env.deserialize_symbolic_state(label_row)))
    focus_goal = tu.to_numpy(
        masked_symbolic_state_index(batch['goal'], batch['focus_mask']))
    print('reachable')
    reachable_preds = tu.to_numpy(outputs['reachable_preds'].argmax(-1))
    reachable_label = tu.to_numpy(batch['reachable'])
    for i, (pred_r, label_r) in enumerate(zip(reachable_preds,
                                              reachable_label)):
        if int(pred_r) == int(label_r):
            continue
        msg = 'fp' if int(label_r) == 0 else 'fn'
        print(msg,
              env.masked_symbolic_state(
                  env.deserialize_symbolic_state(focus_goal[i])))
    print('dependency')
    dep_preds = tu.to_numpy(outputs['dependency_preds'].argmax(-1))
    dep_label = tu.to_numpy(batch['dependency'])
    for pred_d, label_d in zip(dep_preds, dep_label):
        if int(pred_d) == int(label_d[-1]):
            continue
        msg = 'fp' if int(label_d[-1]) == 0 else 'fn'
        print(msg, env.deserialize_dependency_entry(label_d))
def find_subgoal(self, object_state, entity_state, goal, graphs, max_depth=10):
    """
    Resolve the next subgoal recursively
    Planner logic:
    1. Use Bron–Kerbosch to find all maximal cliques in the (disconnected) dependency graph
    2. Form a DAG by using the cliques as nodes
    3. Sort the DAG topologically. Find the first group that is not satisfied and name it root.
    4. Use root as the mask for the current goal to form the focused group
    5. Predict the preimage and reachability of the focused group
    6. If the focus group is reachable, stop and feed the focused group to the policy.
    7. Otherwise, treat the focus goal group as the new goal and go back to 1.
    :param object_state: current state of objects
    :param goal: global goal
    :return: the next subgoal
    """
    curr_goal = goal.argmax(dim=-1)  # [1, num_object, num_predicate]
    subgoal = None
    depth = 0
    ret = -1  # -1 means "no terminal status yet"
    while depth < max_depth:
        if self.verbose:
            print('[bp] Depth: %i ==== ' % depth)
        # Every entry masked (== 2): nothing left to plan for.
        if (curr_goal == 2).all():
            ret = 'NETWORK_EMPTY_GOAL'
            break
        # Steps 1-4: pick the first unsatisfied goal group as the focus.
        focus_group, ret = self._serialize_subgoals(
            entity_state, object_state, curr_goal)
        if focus_group is None:
            # _serialize_subgoals already set ret (e.g. all satisfied)
            break
        # preimage
        # Step 5: predict the preimage and reachability of the focus group.
        bp_out = self.backward_plan(entity_state, object_state, focus_group,
                                    graphs)
        preimage_preds = bp_out['preimage_preds'].argmax(dim=-1)
        reachable_preds = bp_out['reachable_preds'].argmax(dim=-1)
        if self.verbose and self.env is not None:
            curr_goal_np = tu.to_numpy(curr_goal[0])
            print(
                '[bp] current goals: ',
                self.env.deserialize_goals(curr_goal_np, curr_goal_np != 2))
            focus_group_np = tu.to_numpy(focus_group[0])
            print(
                '[bp] focus group: ',
                self.env.deserialize_goals(focus_group_np[..., 1],
                                           (1 - focus_group_np[..., 2])))
            print('[bp] reachable: ', reachable_preds)
        # Step 6: reachable focus group becomes the subgoal; stop.
        if (reachable_preds == 1).any():
            subgoal = focus_group
            break
        # Step 7: recurse with the predicted preimage as the new goal.
        curr_goal = preimage_preds
        depth += 1
    else:
        # while-else: loop exhausted max_depth without a break.
        ret = 'NETWORK_MAX_DEPTH'
    if self.verbose:
        print('[bp] EOP###########')
    return {'subgoal': subgoal, 'ret': ret}
def _serialize_subgoals(self, entity_state, object_state, curr_goal):
    """Pick the focus goal group for backward planning.

    Predicts per-goal satisfaction, builds a dependency graph over the
    active goal entries, topologically sorts the goal groups, and masks
    the current goal down to the last unsatisfied group.

    :param entity_state: [1, num_entity, state_dim] entity features; the
        entity axis must match curr_goal's object axis
    :param object_state: unused in this method; kept for interface parity
    :param curr_goal: [1, num_object, num_predicate] integer goal where
        2 marks a masked (inactive) entry
    :return: (focus_group, ret) — focus_group is a one-hot
        [1, num_object, num_predicate, 3] tensor, or None with
        ret == 'NETWORK_ALL_SATISFIED' when every group is satisfied;
        ret is -1 on success.
    """
    assert len(curr_goal.shape) == 3
    assert len(entity_state.shape) == 3
    assert entity_state.shape[1] == curr_goal.shape[1]
    state_np = tu.to_numpy(entity_state)[0]
    num_predicate = curr_goal.shape[-1]
    curr_goal_np = tu.to_numpy(curr_goal[0])

    # Flatten the active (non-masked) goal entries into parallel arrays of
    # object index, predicate index, and predicate value.
    goal_object_index, goal_predicates_index = np.where(curr_goal_np != 2)
    goal_index = np.stack([goal_object_index,
                           goal_predicates_index]).transpose()
    goal_predicates_value = curr_goal_np[(goal_object_index,
                                          goal_predicates_index)]
    num_goal = goal_index.shape[0]

    # Predict satisfaction of each goal entry.
    sat_state_inputs = state_np[goal_object_index]
    sat_predicate_mask = npu.to_onehot(goal_predicates_index,
                                       num_predicate).astype(np.float32)
    sat_predicate = sat_predicate_mask.copy()
    # NOTE: the np.bool alias was removed in NumPy 1.24; use builtin bool.
    sat_predicate[sat_predicate_mask.astype(bool)] = goal_predicates_value
    sat_state_inputs = tu.to_tensor(sat_state_inputs[None, ...],
                                    device=entity_state.device)
    sat_sym_inputs_np = np.concatenate((sat_predicate, sat_predicate_mask),
                                       axis=-1)[None, ...]
    sat_sym_inputs = tu.to_tensor(sat_sym_inputs_np,
                                  device=entity_state.device)
    sat_preds = tu.to_numpy(
        self.forward_sat(sat_state_inputs,
                         sat_sym_inputs).argmax(-1))[0]  # [num_goal]
    assert sat_preds.shape[0] == num_goal
    if self.verbose and self.env is not None:
        for sat_p, sat_m, sp, oi in zip(sat_predicate, sat_predicate_mask,
                                        sat_preds, goal_object_index):
            sat_pad = np.hstack([[oi], sat_p, sat_m, [sp]])
            print('[bp] sat: ',
                  self.env.deserialize_satisfied_entry(sat_pad))

    # Construct dependency graphs: start fully connected, then keep only the
    # edges the dependency network predicts as real dependencies.
    nodes, edges = construct_full_graph(num_goal)
    src_object_index = goal_object_index[
        edges[:, 0]]  # list of [object_idx, predicate_idx] for each edge source
    tgt_object_index = goal_object_index[edges[:, 1]]
    src_inputs = state_np[src_object_index]  # list of object states
    tgt_inputs = state_np[tgt_object_index]
    src_predicate_value = goal_predicates_value[
        edges[:, 0]]  # list of predicate values for each edge source
    src_predicate_index = goal_predicates_index[edges[:, 0]]
    tgt_predicate_value = goal_predicates_value[edges[:, 1]]
    tgt_predicate_index = goal_predicates_index[edges[:, 1]]
    src_predicate_mask = npu.to_onehot(src_predicate_index,
                                       num_predicate).astype(np.float32)
    src_predicate = np.zeros_like(src_predicate_mask)
    src_predicate[src_predicate_mask.astype(bool)] = src_predicate_value
    tgt_predicate_mask = npu.to_onehot(tgt_predicate_index,
                                       num_predicate).astype(np.float32)
    tgt_predicate = np.zeros_like(tgt_predicate_mask)
    tgt_predicate[tgt_predicate_mask.astype(bool)] = tgt_predicate_value
    dependency_state_inputs_np = np.concatenate((src_inputs, tgt_inputs),
                                                axis=-1)
    dependency_sym_inputs_np = np.concatenate(
        (src_predicate, src_predicate_mask, tgt_predicate,
         tgt_predicate_mask),
        axis=-1)
    if dependency_state_inputs_np.shape[0] > 0:
        dependency_state_inputs = tu.to_tensor(
            dependency_state_inputs_np,
            device=entity_state.device).unsqueeze(0)
        dependency_sym_inputs = tu.to_tensor(
            dependency_sym_inputs_np,
            device=entity_state.device).unsqueeze(0)
        deps_preds = tu.to_numpy(
            self.forward_dep(dependency_state_inputs,
                             dependency_sym_inputs).argmax(-1))[0]
        dep_graph_edges = edges[deps_preds > 0]
    else:
        # No edges (zero or one goal entry): empty dependency graph.
        dep_graph_edges = np.array([])

    # Topologically sort the goal groups and focus on the last group in the
    # reversed order that is entirely unsatisfied.
    sorted_goal_groups = sort_goal_graph(dep_graph_edges, nodes)
    focus_group_idx = None
    for gg in reversed(sorted_goal_groups):
        if not np.any(sat_preds[gg]):  # if unsatisfied
            focus_group_idx = gg
            break
    if focus_group_idx is None:
        return None, 'NETWORK_ALL_SATISFIED'
    # focus_group_idx is a list of goal index; mask everything else with 2.
    focus_group_np = np.ones_like(curr_goal_np) * 2
    for fg_idx in focus_group_idx:
        fg_obj_i, fg_pred_i = goal_index[fg_idx]
        focus_group_np[fg_obj_i, fg_pred_i] = curr_goal_np[fg_obj_i,
                                                           fg_pred_i]
    focus_group = tu.to_tensor(focus_group_np,
                               device=entity_state.device).unsqueeze(0)
    focus_group = tu.to_onehot(focus_group, 3)
    return focus_group, -1
def classification_accuracy(preds, labels):
    """Fraction of samples whose argmax over the last axis matches the label."""
    preds = to_numpy(preds)
    labels = to_numpy(labels)
    top1 = np.argmax(preds, axis=-1)
    assert top1.shape == labels.shape
    hits = (top1 == labels).astype(np.float64)
    return hits.mean()