Exemplo n.º 1
0
def masked_binary_accuracy(preds, labels):
    """Accuracy split by whether the label is class 2 or not.

    :param preds: class scores; the predicted class is the argmax over the
        last axis
    :param labels: integer labels; every entry must be <= 2 (asserted)
    :return: ``(acc_where_label_not_2, acc_where_label_is_2)``, each a float64
        mean over the selected entries (NumPy yields NaN for an empty
        selection)
    """
    preds, labels = to_numpy(preds), to_numpy(labels)
    assert (np.all(labels <= 2))
    is_correct = np.argmax(preds, axis=-1) == labels
    label_is_two = labels == 2
    non_two_accuracy = is_correct[~label_is_two].astype(np.float64).mean()
    two_accuracy = is_correct[label_is_two].astype(np.float64).mean()
    return non_two_accuracy, two_accuracy
Exemplo n.º 2
0
def binary_accuracy(preds, labels, threshold=0.5):
    """Thresholded binary accuracy plus per-class false-positive/negative rates.

    An entry of ``preds`` (or ``labels``) counts as positive when it is
    strictly greater than ``threshold`` and as negative otherwise. The same
    rule drives the accuracy and the fp/fn rates. Every element is therefore
    exactly one of correct / false positive / false negative, including values
    sitting exactly on the threshold.

    :param preds: predicted scores; the last axis is treated as the class axis
    :param labels: targets, broadcast-compatible with ``preds`` (presumably the
        same shape)
    :param threshold: decision threshold, default 0.5
    :return: ``(acc, fp_rate, fn_rate)``. ``acc`` is the overall float64 mean
        accuracy. ``fp_rate`` and ``fn_rate`` are float64 arrays of length
        ``num_c``, averaged over all leading axes.
    """
    preds = to_numpy(preds)
    labels = to_numpy(labels)
    num_c = preds.shape[-1]
    # Derive both masks once so acc, fp and fn agree on boundary values.
    pred_pos = preds > threshold
    label_pos = labels > threshold
    fp = np.logical_and(pred_pos, np.logical_not(label_pos)).astype(np.float64)
    fn = np.logical_and(np.logical_not(pred_pos), label_pos).astype(np.float64)
    acc = (pred_pos == label_pos).astype(np.float64).mean()
    return (acc,
            fp.reshape([-1, num_c]).mean(axis=0),
            fn.reshape([-1, num_c]).mean(axis=0))
Exemplo n.º 3
0
def test_graph_collation():
    """Round-trip two 5-node graphs through collation and separation.

    Each graph builds its edge features from its own half of ``node_input``.
    The two graphs therefore carry different edge features, so a separation
    that mixed up or swapped the graphs would make the check fail.
    """
    num_nodes = 5
    node_index, edge_index = construct_full_graph(num_nodes)
    node_input = torch.randn(2 * num_nodes, 10)
    edge_input = torch.cat([
        get_edge_features(node_input[:num_nodes], edge_index,
                          lambda a, b: b - a),
        get_edge_features(node_input[num_nodes:], edge_index,
                          lambda a, b: b - a),
    ], dim=0)
    node_index = [node_index, node_index]
    edge_index = [edge_index, edge_index]
    gs = collate_torch_graphs(node_input, edge_input, node_index, edge_index)
    ni, ei = separate_graph_collated_features(gs.x, node_index, edge_index)
    assert bool(torch.all(ei == edge_input))
    assert bool(torch.all(ni == node_input))
Exemplo n.º 4
0
 def _add_stats(self, key, val, n_iter):
     if key not in self._stats:
         self._stats[key] = []
         self._stats_iter[key] = []
     if isinstance(val, torch.Tensor):
         val = to_numpy(val)
     self._stats[key].append(val)
     self._stats_iter[key].append(n_iter)
Exemplo n.º 5
0
 def debug(outputs, batch, env=None):
     """Print every batch entry whose predicted subgoal differs from its label.

     For each mismatch, prints the prediction, the label and the focused goal.

     :param outputs: reads ``outputs['subgoal_preds']``; the prediction is the
         argmax over its last axis
     :param batch: reads ``'subgoal'``, ``'subgoal_mask'``, ``'goal'`` and
         ``'focus_mask'``
     :param env: optional; when given, states are decoded with
         ``env.deserialize_symbolic_state`` and ``env.masked_symbolic_state``.
         When ``None``, the raw arrays are printed instead of raising
         AttributeError.
     """
     def fmt_state(state):
         # Decode through the env when one is available.
         if env is None:
             return state
         return env.masked_symbolic_state(env.deserialize_symbolic_state(state))

     subgoal = tu.to_numpy(
         masked_symbolic_state_index(batch['subgoal'],
                                     batch['subgoal_mask']))
     subgoal_preds = tu.to_numpy(outputs['subgoal_preds'].argmax(-1))
     focus_goal = tu.to_numpy(
         masked_symbolic_state_index(batch['goal'], batch['focus_mask']))
     print('subgoal')
     for label, pred, focused in zip(subgoal, subgoal_preds, focus_goal):
         if np.all(label == pred):
             continue
         print('- preds: ', fmt_state(pred))
         print('label: ', fmt_state(label))
         print('focused: ', fmt_state(focused))
Exemplo n.º 6
0
    def find_subgoal(self,
                     object_state,
                     entity_state,
                     goal,
                     graphs,
                     max_depth=10):
        """
        Resolve the next subgoal directly, in a single backward-planning step.

        :param object_state: forwarded to ``_serialize_subgoals`` and
            ``backward_plan``
        :param entity_state: forwarded to ``_serialize_subgoals`` and
            ``backward_plan``
        :param goal: goal tensor; argmax over its last axis gives the goal
            index per (object, predicate)
        :param graphs: forwarded to ``backward_plan``
        :param max_depth: not used by this single-step variant
        :return: ``{'subgoal': ..., 'ret': ...}``. ``'subgoal'`` is
            ``tu.to_onehot(..., 3)`` of the argmax of
            ``backward_plan(...)['subgoal_preds']``, or ``None`` when
            ``_serialize_subgoals`` yields no focus group. ``'ret'`` is the
            status returned by ``_serialize_subgoals``.
        """
        curr_goal = goal.argmax(dim=-1)  # [1, num_object, num_predicate]
        focus_group, ret = self._serialize_subgoals(entity_state, object_state,
                                                    curr_goal)
        if focus_group is None:
            return {'subgoal': None, 'ret': ret}

        plan = self.backward_plan(entity_state, object_state, focus_group,
                                  graphs)
        subgoal = tu.to_onehot(plan['subgoal_preds'].argmax(dim=-1), 3)

        if self.verbose and self.env is not None:
            goal_np = tu.to_numpy(curr_goal[0])
            print('[bp] current goals: ',
                  self.env.deserialize_goals(goal_np, goal_np != 2))
            # Both groups are one-hot: channel 1 is the goal value and
            # channel 2 marks "no goal", so (1 - channel 2) is the mask.
            for title, group in (('[bp] focus group: ', focus_group),
                                 ('[bp] subgoal: ', subgoal)):
                group_np = tu.to_numpy(group[0])
                print(title,
                      self.env.deserialize_goals(group_np[..., 1],
                                                 (1 - group_np[..., 2])))

        return {'subgoal': subgoal, 'ret': ret}
Exemplo n.º 7
0
    def debug(outputs, batch, env=None):
        """Print preimage, reachability and dependency predictions that disagree with labels.

        :param outputs: reads ``'preimage_preds'``, ``'reachable_preds'`` and
            ``'dependency_preds'``; each prediction is the argmax over the
            last axis
        :param batch: reads ``'preimage'``, ``'preimage_mask'``,
            ``'preimage_loss_mask'``, ``'goal'``, ``'focus_mask'``,
            ``'reachable'`` and ``'dependency'``
        :param env: optional decoder for states and dependency entries. When
            ``None``, the raw arrays are printed instead of raising
            AttributeError.
        """
        def fmt_state(state):
            # Decode through the env when one is available.
            if env is None:
                return state
            return env.masked_symbolic_state(
                env.deserialize_symbolic_state(state))

        preimage = tu.to_numpy(
            masked_symbolic_state_index(batch['preimage'],
                                        batch['preimage_mask']))
        preimage_preds = tu.to_numpy(outputs['preimage_preds'].argmax(-1))
        print('preimage')
        for pi, pip, pm in zip(preimage, preimage_preds,
                               tu.to_numpy(batch['preimage_loss_mask'])):
            # Skip correct predictions and entries fully excluded by the
            # loss mask.
            if np.all(pi == pip) or np.all(pm == 0):
                continue
            print('preds: ', fmt_state(pip))
            print('label: ', fmt_state(pi))

        focus_goal = tu.to_numpy(
            masked_symbolic_state_index(batch['goal'], batch['focus_mask']))

        print('reachable')
        reachable_preds = tu.to_numpy(outputs['reachable_preds'].argmax(-1))
        reachable_label = tu.to_numpy(batch['reachable'])
        for i, (rp, rl) in enumerate(zip(reachable_preds, reachable_label)):
            if int(rp) == int(rl):
                continue
            msg = 'fp' if int(rl) == 0 else 'fn'
            print(msg, fmt_state(focus_goal[i]))

        print('dependency')
        dep_preds = tu.to_numpy(outputs['dependency_preds'].argmax(-1))
        dep_label = tu.to_numpy(batch['dependency'])
        for dp, dl in zip(dep_preds, dep_label):
            # The prediction is compared against the last element of the
            # label row.
            if int(dp) == int(dl[-1]):
                continue
            msg = 'fp' if int(dl[-1]) == 0 else 'fn'
            entry = dl if env is None else env.deserialize_dependency_entry(dl)
            print(msg, entry)
Exemplo n.º 8
0
    def find_subgoal(self,
                     object_state,
                     entity_state,
                     goal,
                     graphs,
                     max_depth=10):
        """
        Resolve the next subgoal by planning backwards through preimages.

        Each iteration:
        1. Stop with ``'NETWORK_EMPTY_GOAL'`` if every entry of the current
           goal is 2.
        2. Ask ``self._serialize_subgoals`` for a focus group. If it returns
           ``None``, stop and keep the status it returned.
        3. Run ``self.backward_plan`` on the focus group. If any reachability
           prediction is 1, the focus group becomes the subgoal; stop.
        4. Otherwise the predicted preimage becomes the new current goal.
        If ``max_depth`` iterations pass without stopping, the status is
        ``'NETWORK_MAX_DEPTH'``.

        :param object_state: current state of objects (forwarded to helpers)
        :param entity_state: forwarded to ``_serialize_subgoals`` and
            ``backward_plan``
        :param goal: global goal; argmax over its last axis gives the goal
            index per (object, predicate)
        :param graphs: forwarded to ``backward_plan``
        :param max_depth: maximum number of backward-planning iterations
        :return: ``{'subgoal': focus group or None, 'ret': status}``
        """
        def report(goal_idx, group, reachable):
            # Verbose trace of one planning step.
            goal_np = tu.to_numpy(goal_idx[0])
            print('[bp] current goals: ',
                  self.env.deserialize_goals(goal_np, goal_np != 2))
            group_np = tu.to_numpy(group[0])
            print('[bp] focus group: ',
                  self.env.deserialize_goals(group_np[..., 1],
                                             (1 - group_np[..., 2])))
            print('[bp] reachable: ', reachable)

        curr_goal = goal.argmax(dim=-1)  # [1, num_object, num_predicate]
        subgoal = None
        ret = -1
        depth = 0
        while depth < max_depth:
            if self.verbose:
                print(f'[bp] Depth: {depth} ==== ')
            if (curr_goal == 2).all():
                ret = 'NETWORK_EMPTY_GOAL'
                break
            focus_group, ret = self._serialize_subgoals(
                entity_state, object_state, curr_goal)
            if focus_group is None:
                break
            plan = self.backward_plan(entity_state, object_state,
                                      focus_group, graphs)
            preimage_preds = plan['preimage_preds'].argmax(dim=-1)
            reachable_preds = plan['reachable_preds'].argmax(dim=-1)
            if self.verbose and self.env is not None:
                report(curr_goal, focus_group, reachable_preds)
            if (reachable_preds == 1).any():
                subgoal = focus_group
                break
            curr_goal = preimage_preds
            depth += 1
        else:
            # Loop ran out of depth without any break.
            ret = 'NETWORK_MAX_DEPTH'
        if self.verbose:
            print('[bp] EOP###########')
        return {'subgoal': subgoal, 'ret': ret}
Exemplo n.º 9
0
    def _serialize_subgoals(self, entity_state, object_state, curr_goal):
        """
        Select the next group of goals to focus on.

        Outline of what this method does:
        1. Collect every (object, predicate) entry of ``curr_goal[0]`` whose
           value is not 2; these are the active goals.
        2. Run ``self.forward_sat`` on each goal together with its object's
           state to get a per-goal satisfaction prediction (argmax of the
           output).
        3. Build the full graph over the goals (``construct_full_graph``).
           Run ``self.forward_dep`` on every edge and keep the edges
           predicted > 0 as dependencies.
        4. Order the goals with ``sort_goal_graph``. Scan the resulting
           groups in reverse order and pick the first group in which no goal
           has a non-zero satisfaction prediction.

        :param entity_state: 3-D tensor. Only batch entry 0 is read, and its
            object axis must match ``curr_goal.shape[1]``.
        :param object_state: not used by this method
        :param curr_goal: 3-D integer goal tensor; only ``curr_goal[0]``
            (objects x predicates) is read, and entries equal to 2 are
            treated as "no goal"
        :return: ``(focus_group, -1)`` where ``focus_group`` is
            ``tu.to_onehot(..., 3)`` of a copy of ``curr_goal[0]`` with every
            entry outside the chosen group set to 2 (with a leading batch
            dim). Returns ``(None, 'NETWORK_ALL_SATISFIED')`` when no group
            qualifies.
        """
        assert (len(curr_goal.shape) == 3)
        assert (len(entity_state.shape) == 3)
        assert (entity_state.shape[1] == curr_goal.shape[1])
        state_np = tu.to_numpy(entity_state)[0]
        num_predicate = curr_goal.shape[-1]
        curr_goal_np = tu.to_numpy(curr_goal[0])
        goal_object_index, goal_predicates_index = np.where(curr_goal_np != 2)

        goal_index = np.stack([goal_object_index,
                               goal_predicates_index]).transpose()
        goal_predicates_value = curr_goal_np[(goal_object_index,
                                              goal_predicates_index)]
        num_goal = goal_index.shape[0]

        # Predict satisfaction of each goal from its object's state.
        sat_state_inputs = state_np[goal_object_index]
        sat_predicate_mask = npu.to_onehot(goal_predicates_index,
                                           num_predicate).astype(np.float32)
        sat_predicate = sat_predicate_mask.copy()
        # Builtin ``bool`` here: the ``np.bool`` alias was deprecated in
        # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
        sat_predicate[sat_predicate_mask.astype(bool)] = goal_predicates_value

        sat_state_inputs = tu.to_tensor(sat_state_inputs[None, ...],
                                        device=entity_state.device)
        sat_sym_inputs_np = np.concatenate((sat_predicate, sat_predicate_mask),
                                           axis=-1)[None, ...]
        sat_sym_inputs = tu.to_tensor(sat_sym_inputs_np,
                                      device=entity_state.device)
        sat_preds = tu.to_numpy(
            self.forward_sat(sat_state_inputs,
                             sat_sym_inputs).argmax(-1))[0]  # [ng]
        assert (sat_preds.shape[0] == num_goal)
        if self.verbose and self.env is not None:
            for sat_p, sat_m, sp, oi in zip(sat_predicate, sat_predicate_mask,
                                            sat_preds, goal_object_index):
                sat_pad = np.hstack([[oi], sat_p, sat_m, [sp]])
                print('[bp] sat: ',
                      self.env.deserialize_satisfied_entry(sat_pad))

        # Construct dependency graphs over the goals.
        nodes, edges = construct_full_graph(num_goal)
        # Object index of the source / target goal of each edge.
        src_object_index = goal_object_index[edges[:, 0]]
        tgt_object_index = goal_object_index[edges[:, 1]]
        src_inputs = state_np[src_object_index]  # object state per edge source
        tgt_inputs = state_np[tgt_object_index]

        # Predicate value / index of the source and target goal of each edge.
        src_predicate_value = goal_predicates_value[edges[:, 0]]
        src_predicate_index = goal_predicates_index[edges[:, 0]]
        tgt_predicate_value = goal_predicates_value[edges[:, 1]]
        tgt_predicate_index = goal_predicates_index[edges[:, 1]]
        src_predicate_mask = npu.to_onehot(src_predicate_index,
                                           num_predicate).astype(np.float32)
        src_predicate = np.zeros_like(src_predicate_mask)
        src_predicate[src_predicate_mask.astype(bool)] = src_predicate_value
        tgt_predicate_mask = npu.to_onehot(tgt_predicate_index,
                                           num_predicate).astype(np.float32)
        tgt_predicate = np.zeros_like(tgt_predicate_mask)
        tgt_predicate[tgt_predicate_mask.astype(bool)] = tgt_predicate_value
        dependency_state_inputs_np = np.concatenate((src_inputs, tgt_inputs),
                                                    axis=-1)
        dependency_sym_inputs_np = np.concatenate(
            (src_predicate, src_predicate_mask, tgt_predicate,
             tgt_predicate_mask),
            axis=-1)

        if dependency_state_inputs_np.shape[0] > 0:
            dependency_state_inputs = tu.to_tensor(
                dependency_state_inputs_np,
                device=entity_state.device).unsqueeze(0)
            dependency_sym_inputs = tu.to_tensor(
                dependency_sym_inputs_np,
                device=entity_state.device).unsqueeze(0)
            deps_preds = tu.to_numpy(
                self.forward_dep(dependency_state_inputs,
                                 dependency_sym_inputs).argmax(-1))[0]
            dep_graph_edges = edges[deps_preds > 0]
        else:
            # No edges (fewer than two goals): no dependencies to predict.
            dep_graph_edges = np.array([])
        sorted_goal_groups = sort_goal_graph(dep_graph_edges, nodes)
        focus_group_idx = None
        for gg in reversed(sorted_goal_groups):
            if not np.any(sat_preds[gg]):  # if unsatisfied
                focus_group_idx = gg
                break

        if focus_group_idx is None:
            return None, 'NETWORK_ALL_SATISFIED'

        # focus_group_idx is a list of goal indices; keep only those goals.
        focus_group_np = np.ones_like(curr_goal_np) * 2
        for fg_idx in focus_group_idx:
            fg_obj_i, fg_pred_i = goal_index[fg_idx]
            focus_group_np[fg_obj_i, fg_pred_i] = curr_goal_np[fg_obj_i,
                                                               fg_pred_i]
        focus_group = tu.to_tensor(focus_group_np,
                                   device=entity_state.device).unsqueeze(0)
        focus_group = tu.to_onehot(focus_group, 3)
        return focus_group, -1
Exemplo n.º 10
0
def classification_accuracy(preds, labels):
    """Fraction of entries whose argmax over the last axis of ``preds`` equals ``labels``.

    :param preds: class scores, converted with ``to_numpy``
    :param labels: integer labels; must have exactly the shape of ``preds``
        minus its last axis (asserted)
    :return: float64 mean accuracy
    """
    scores = to_numpy(preds)
    targets = to_numpy(labels)
    predicted = np.argmax(scores, axis=-1)
    assert (predicted.shape == targets.shape)
    hits = (predicted == targets).astype(np.float64)
    return hits.mean()