Ejemplo n.º 1
0
    def forward_batch(self, batch):
        """Run the model on a batch dict.

        The goal is always encoded as a 3-class one-hot of
        ``masked_symbolic_state_index(batch['goal'], batch['goal_mask'])``.
        In policy mode the batch is routed to ``self.forward_policy``;
        otherwise it is routed to ``self.forward`` together with the focus
        goal and the satisfied/dependency tensors.

        Args:
            batch: dict of tensors. Always reads ``'goal'``, ``'goal_mask'``
                and ``'states'``; ``'entity_states'`` is optional and defaults
                to ``'states'``. Outside policy mode it also reads
                ``'focus_mask'``, ``'satisfied'``, ``'dependency'`` and the
                optional ``'graph'`` / ``'num_entities'``; in policy mode it
                reads the optional ``'graphs'``.

        Returns:
            The result of ``self.forward`` or ``self.forward_policy``.
        """
        goal_onehot = tu.to_onehot(
            masked_symbolic_state_index(batch['goal'], batch['goal_mask']), 3)
        states = batch['states']
        entity_states = batch.get('entity_states', states)

        if self.policy_mode:
            return self.forward_policy(
                object_states=states,
                entity_states=entity_states,
                goal=goal_onehot,
                graphs=batch.get('graphs', None),
            )

        focus_onehot = tu.to_onehot(
            masked_symbolic_state_index(batch['goal'], batch['focus_mask']), 3)
        # The last entry along dim 1 of 'satisfied' / 'dependency' is dropped
        # before being passed as conditioning info.
        return self.forward(object_states=states,
                            entity_states=entity_states,
                            goal=goal_onehot,
                            focus_goal=focus_onehot,
                            satisfied_info=batch['satisfied'][:, :-1],
                            dependency_info=batch['dependency'][:, :-1],
                            graph=batch.get('graph', None),
                            num_entities=batch.get('num_entities', None))
Ejemplo n.º 2
0
    def log_outputs(outputs, batch, summarizer, global_step, prefix):
        """Log focus, preimage and reachability accuracies to ``summarizer``.

        Args:
            outputs: dict of model outputs; reads ``'preimage_preds'``,
                ``'focus_preds'`` and ``'reachable_preds'``.
            batch: dict of batch tensors; reads ``'preimage'``,
                ``'preimage_mask'``, ``'preimage_loss_mask'``, ``'goal'``,
                ``'focus_mask'`` and ``'reachable'``.
            summarizer: object exposing ``add_scalar(tag, value, global_step=)``.
            global_step: step value forwarded to every ``add_scalar`` call.
            prefix: string prepended to every scalar tag.
        """
        preimage_target = masked_symbolic_state_index(batch['preimage'],
                                                      batch['preimage_mask'])
        preimage_cls = outputs['preimage_preds'].argmax(-1)
        # Entries excluded from the loss are overwritten (in place) with class
        # 2 in both prediction and target, so those entries always match.
        excluded = batch['preimage_loss_mask'] == 0
        preimage_cls.masked_fill_(excluded, 2)
        preimage_target.masked_fill_(excluded, 2)
        preimage_acc, preimage_mask_acc = masked_binary_accuracy(
            tu.to_onehot(preimage_cls, 3), preimage_target)

        focus_target = masked_symbolic_state_index(batch['goal'],
                                                   batch['focus_mask'])
        focus_cls = outputs['focus_preds'].argmax(-1)
        focus_acc, focus_mask_acc = masked_binary_accuracy(
            tu.to_onehot(focus_cls, 3), focus_target)

        reachable_acc = classification_accuracy(outputs['reachable_preds'],
                                                batch['reachable'])

        scalars = (
            ('acc/focus', focus_acc),
            ('acc/focus_mask', focus_mask_acc),
            ('acc/preimage', preimage_acc),
            ('acc/preimage_mask', preimage_mask_acc),
            ('acc/reachable', reachable_acc),
        )
        for tag, value in scalars:
            summarizer.add_scalar(prefix + tag,
                                  value,
                                  global_step=global_step)
Ejemplo n.º 3
0
 def forward_batch(self, batch):
     """Encode goal (and, outside policy mode, subgoal) and call the model.

     Args:
         batch: dict of tensors; reads ``'goal'``, ``'goal_mask'`` and
             ``'states'``, plus ``'subgoal'`` / ``'subgoal_mask'`` when
             ``self.policy_mode`` is falsy.

     Returns:
         The result of ``self(states=..., goal=..., subgoal=...)``; ``subgoal``
         is ``None`` in policy mode.
     """
     goal_onehot = tu.to_onehot(
         masked_symbolic_state_index(batch['goal'], batch['goal_mask']), 3)
     if self.policy_mode:
         subgoal_onehot = None
     else:
         subgoal_onehot = tu.to_onehot(
             masked_symbolic_state_index(batch['subgoal'],
                                         batch['subgoal_mask']), 3)
     return self(states=batch['states'],
                 goal=goal_onehot,
                 subgoal=subgoal_onehot)
Ejemplo n.º 4
0
    def log_outputs(outputs, batch, summarizer, global_step, prefix):
        """Log preimage, reachability, satisfied and dependency accuracies.

        Args:
            outputs: dict of model outputs; reads ``'preimage_preds'``,
                ``'reachable_preds'``, ``'satisfied_preds'`` and
                ``'dependency_preds'``.
            batch: dict of batch tensors; reads ``'preimage'``,
                ``'preimage_mask'``, ``'preimage_loss_mask'``, ``'reachable'``,
                ``'satisfied'`` and ``'dependency'``.
            summarizer: object exposing ``add_scalar(tag, value, global_step=)``.
            global_step: step value forwarded to every ``add_scalar`` call.
            prefix: string prepended to every scalar tag.
        """
        preimage_target = masked_symbolic_state_index(batch['preimage'],
                                                      batch['preimage_mask'])
        preimage_cls = outputs['preimage_preds'].argmax(-1)
        # Entries excluded from the loss are overwritten (in place) with class
        # 2 in both prediction and target, so those entries always match.
        excluded = batch['preimage_loss_mask'] == 0
        preimage_cls.masked_fill_(excluded, 2)
        preimage_target.masked_fill_(excluded, 2)
        preimage_acc, preimage_mask_acc = masked_binary_accuracy(
            tu.to_onehot(preimage_cls, 3), preimage_target)

        reachable_acc = classification_accuracy(outputs['reachable_preds'],
                                                batch['reachable'])
        # The final entry along dim 1 of 'satisfied' / 'dependency' is used
        # as the integer class label.
        satisfied_label = batch['satisfied'][:, -1].long()
        satisfied_acc = classification_accuracy(outputs['satisfied_preds'],
                                                satisfied_label)
        dependency_label = batch['dependency'][:, -1].long()
        dependency_acc = classification_accuracy(outputs['dependency_preds'],
                                                 dependency_label)

        scalars = (
            ('acc/preimage', preimage_acc),
            ('acc/preimage_mask', preimage_mask_acc),
            ('acc/reachable', reachable_acc),
            ('acc/satisfied', satisfied_acc),
            ('acc/dependency', dependency_acc),
        )
        for tag, value in scalars:
            summarizer.add_scalar(prefix + tag,
                                  value,
                                  global_step=global_step)
Ejemplo n.º 5
0
 def find_subgoal(self,
                  object_state,
                  entity_state,
                  goal,
                  graphs,
                  max_depth=10):
     """Predict the next subgoal with ``self.plan``.

     Only ``object_state`` and ``goal`` are used; ``entity_state``, ``graphs``
     and ``max_depth`` are accepted for interface compatibility and ignored.

     Returns:
         dict with ``'subgoal'`` (3-class one-hot of the argmax of
         ``self.plan(...)['subgoal_preds']``) and ``'ret'`` fixed to -1.
     """
     plan_out = self.plan(object_state, goal)
     subgoal_cls = plan_out['subgoal_preds'].argmax(-1)
     subgoal_onehot = tu.to_onehot(subgoal_cls, 3)
     return {'subgoal': subgoal_onehot, 'ret': -1}
Ejemplo n.º 6
0
    def forward(self, states, goal, subgoal=None):
        """Predict sub-goal logits from the (flattened) state and goal.

        Args:
            states: state tensor, flattened with ``tu.flatten`` before use.
            goal: goal tensor, flattened with ``tu.flatten`` before use.
            subgoal: returned unchanged when not in policy mode; in policy
                mode it is replaced by the one-hot of the predicted classes.

        Returns:
            dict with ``'subgoal_preds'`` (network output viewed as
            ``[N, -1, self.c.symbol_size, 3]``, N being the leading dim of the
            flattened states) and ``'subgoal'``.
        """
        flat_states = tu.flatten(states)
        flat_goal = tu.flatten(goal)

        net_in = torch.cat((flat_states, flat_goal), dim=-1)
        logits = self._sg_net(net_in).view(flat_states.shape[0], -1,
                                           self.c.symbol_size, 3)

        if self.policy_mode:
            # classes: [false, true, masked]
            subgoal = tu.to_onehot(logits.argmax(dim=-1), 3)

        return {'subgoal_preds': logits, 'subgoal': subgoal}
Ejemplo n.º 7
0
 def log_outputs(outputs, batch, summarizer, global_step, prefix):
     """Log subgoal, satisfied and dependency accuracies to ``summarizer``.

     Args:
         outputs: dict of model outputs; reads ``'subgoal_preds'``,
             ``'satisfied_preds'`` and ``'dependency_preds'``.
         batch: dict of batch tensors; reads ``'subgoal'``,
             ``'subgoal_mask'``, ``'satisfied'`` and ``'dependency'``.
         summarizer: object exposing ``add_scalar(tag, value, global_step=)``.
         global_step: step value forwarded to every ``add_scalar`` call.
         prefix: string prepended to every scalar tag.
     """
     subgoal_target = masked_symbolic_state_index(batch['subgoal'],
                                                  batch['subgoal_mask'])
     subgoal_cls = outputs['subgoal_preds'].argmax(-1)
     subgoal_acc, subgoal_mask_acc = masked_binary_accuracy(
         tu.to_onehot(subgoal_cls, 3), subgoal_target)
     # The final entry along dim 1 of 'satisfied' / 'dependency' is used as
     # the integer class label.
     satisfied_label = batch['satisfied'][:, -1].long()
     satisfied_acc = classification_accuracy(outputs['satisfied_preds'],
                                             satisfied_label)
     dependency_label = batch['dependency'][:, -1].long()
     dependency_acc = classification_accuracy(outputs['dependency_preds'],
                                              dependency_label)

     scalars = (
         ('acc/subgoal', subgoal_acc),
         ('acc/subgoal_mask', subgoal_mask_acc),
         ('acc/satisfied', satisfied_acc),
         ('acc/dependency', dependency_acc),
     )
     for tag, value in scalars:
         summarizer.add_scalar(prefix + tag,
                               value,
                               global_step=global_step)
Ejemplo n.º 8
0
    def find_subgoal(self,
                     object_state,
                     entity_state,
                     goal,
                     graphs,
                     max_depth=10):
        """
        Resolve the next subgoal directly.

        The goal is reduced to class indices and handed to
        ``self._serialize_subgoals`` to obtain a focus group. If one is
        returned, ``self.backward_plan`` predicts the subgoal for it.
        ``max_depth`` is accepted but unused.

        Returns:
            dict with ``'subgoal'`` (3-class one-hot tensor, or ``None`` when
            no focus group was produced) and ``'ret'`` (the second value
            returned by ``_serialize_subgoals``).
        """
        goal_cls = goal.argmax(dim=-1)  # [1, num_object, num_predicate]
        focus_group, ret = self._serialize_subgoals(entity_state, object_state,
                                                    goal_cls)
        if focus_group is None:
            return {'subgoal': None, 'ret': ret}

        bp_out = self.backward_plan(entity_state, object_state, focus_group,
                                    graphs)
        subgoal = tu.to_onehot(bp_out['subgoal_preds'].argmax(dim=-1), 3)

        if self.verbose and self.env is not None:
            # Debug dump: class 2 marks entries that are not part of a goal,
            # so (1 - onehot[..., 2]) is used as the "present" mask.
            goal_np = tu.to_numpy(goal_cls[0])
            print('[bp] current goals: ',
                  self.env.deserialize_goals(goal_np, goal_np != 2))
            focus_np = tu.to_numpy(focus_group[0])
            print('[bp] focus group: ',
                  self.env.deserialize_goals(focus_np[..., 1],
                                             (1 - focus_np[..., 2])))
            subgoal_np = tu.to_numpy(subgoal[0])
            print('[bp] subgoal: ',
                  self.env.deserialize_goals(subgoal_np[..., 1],
                                             (1 - subgoal_np[..., 2])))

        return {'subgoal': subgoal, 'ret': ret}
Ejemplo n.º 9
0
 def _serialize_subgoals(self, entity_state, object_state, curr_goal):
     """Choose the focus group with the learned focus head.

     Args:
         entity_state: passed through to ``self.focus``.
         object_state: passed through to ``self.focus``.
         curr_goal: goal class indices; one-hot encoded (3 classes) before
             being passed to ``self.focus``.

     Returns:
         ``(focus_group, -1)`` where ``focus_group`` is the 3-class one-hot of
         the argmax of ``self.focus(...)['focus_preds']``.
     """
     goal_onehot = tu.to_onehot(curr_goal, 3)
     focus_out = self.focus(entity_state, object_state, goal_onehot)
     focus_cls = focus_out['focus_preds'].argmax(-1)
     return tu.to_onehot(focus_cls, 3), -1
Ejemplo n.º 10
0
    def _serialize_subgoals(self, entity_state, object_state, curr_goal):
        """Select the next group of goal entries to pursue, using the network.

        Predicts, for every active goal entry, whether it is already satisfied
        (``self.forward_sat``). It also predicts pairwise dependencies between
        goal entries (``self.forward_dep``) and orders the entries with
        ``sort_goal_graph``. It then walks the resulting groups from last to
        first and picks the first group in which no entry has a non-zero
        (presumably "satisfied") predicted class.

        Only batch element 0 of ``entity_state`` / ``curr_goal`` is processed.

        Args:
            entity_state: 3-D tensor whose dim 1 equals ``curr_goal``'s dim 1;
                rows of element ``[0]`` are indexed by goal object index.
            object_state: unused here.
            curr_goal: 3-D tensor of per-(object, predicate) goal classes;
                entries equal to 2 are treated as "not a goal".

        Returns:
            ``(focus_group, -1)``, where ``focus_group`` is
            ``tu.to_onehot(..., 3)`` of a ``[1, ...]`` tensor that copies the
            selected goal entries and sets every other entry to 2. If no group
            qualifies, returns ``(None, 'NETWORK_ALL_SATISFIED')``.
        """
        assert (len(curr_goal.shape) == 3)
        assert (len(entity_state.shape) == 3)
        assert (entity_state.shape[1] == curr_goal.shape[1])
        device = entity_state.device
        state_np = tu.to_numpy(entity_state)[0]
        num_predicate = curr_goal.shape[-1]
        curr_goal_np = tu.to_numpy(curr_goal[0])
        # Every (object, predicate) entry whose class is not 2 is a goal entry.
        goal_object_index, goal_predicates_index = np.where(curr_goal_np != 2)

        # [num_goal, 2] rows of (object_idx, predicate_idx)
        goal_index = np.stack([goal_object_index,
                               goal_predicates_index]).transpose()
        goal_predicates_value = curr_goal_np[(goal_object_index,
                                              goal_predicates_index)]
        num_goal = goal_index.shape[0]

        # ---- predict satisfaction of each goal entry ----
        sat_state_inputs = state_np[goal_object_index]
        sat_predicate_mask = npu.to_onehot(goal_predicates_index,
                                           num_predicate).astype(np.float32)
        sat_predicate = sat_predicate_mask.copy()
        # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
        # NumPy 1.20 and removed in 1.24 (AttributeError on modern NumPy).
        sat_predicate[sat_predicate_mask.astype(bool)] = goal_predicates_value

        sat_state_inputs = tu.to_tensor(sat_state_inputs[None, ...],
                                        device=device)
        sat_sym_inputs_np = np.concatenate((sat_predicate, sat_predicate_mask),
                                           axis=-1)[None, ...]
        sat_sym_inputs = tu.to_tensor(sat_sym_inputs_np, device=device)
        sat_preds = tu.to_numpy(
            self.forward_sat(sat_state_inputs,
                             sat_sym_inputs).argmax(-1))[0]  # [ng]
        assert (sat_preds.shape[0] == num_goal)
        if self.verbose and self.env is not None:
            for sat_p, sat_m, sp, oi in zip(sat_predicate, sat_predicate_mask,
                                            sat_preds, goal_object_index):
                sat_pad = np.hstack([[oi], sat_p, sat_m, [sp]])
                print('[bp] sat: ',
                      self.env.deserialize_satisfied_entry(sat_pad))

        # ---- predict dependencies between goal entries ----
        # edges[:, 0] / edges[:, 1] are source / target goal indices.
        nodes, edges = construct_full_graph(num_goal)
        # object index of each edge's source / target goal entry
        src_object_index = goal_object_index[edges[:, 0]]
        tgt_object_index = goal_object_index[edges[:, 1]]
        src_inputs = state_np[src_object_index]  # object state per edge source
        tgt_inputs = state_np[tgt_object_index]

        # predicate value / index of each edge's source / target goal entry
        src_predicate_value = goal_predicates_value[edges[:, 0]]
        src_predicate_index = goal_predicates_index[edges[:, 0]]
        tgt_predicate_value = goal_predicates_value[edges[:, 1]]
        tgt_predicate_index = goal_predicates_index[edges[:, 1]]
        src_predicate_mask = npu.to_onehot(src_predicate_index,
                                           num_predicate).astype(np.float32)
        src_predicate = np.zeros_like(src_predicate_mask)
        src_predicate[src_predicate_mask.astype(bool)] = src_predicate_value
        tgt_predicate_mask = npu.to_onehot(tgt_predicate_index,
                                           num_predicate).astype(np.float32)
        tgt_predicate = np.zeros_like(tgt_predicate_mask)
        tgt_predicate[tgt_predicate_mask.astype(bool)] = tgt_predicate_value

        dependency_state_inputs_np = np.concatenate((src_inputs, tgt_inputs),
                                                    axis=-1)
        dependency_sym_inputs_np = np.concatenate(
            (src_predicate, src_predicate_mask, tgt_predicate,
             tgt_predicate_mask),
            axis=-1)

        if dependency_state_inputs_np.shape[0] > 0:
            dependency_state_inputs = tu.to_tensor(
                dependency_state_inputs_np, device=device).unsqueeze(0)
            dependency_sym_inputs = tu.to_tensor(
                dependency_sym_inputs_np, device=device).unsqueeze(0)
            deps_preds = tu.to_numpy(
                self.forward_dep(dependency_state_inputs,
                                 dependency_sym_inputs).argmax(-1))[0]
            # keep only edges predicted as dependencies (non-zero class)
            dep_graph_edges = edges[deps_preds > 0]
        else:
            # no candidate edges: skip the dependency network entirely
            dep_graph_edges = np.array([])
        sorted_goal_groups = sort_goal_graph(dep_graph_edges, nodes)

        # ---- choose the focus group ----
        focus_group_idx = None
        for gg in reversed(sorted_goal_groups):
            if not np.any(sat_preds[gg]):  # no entry predicted satisfied
                focus_group_idx = gg
                break

        if focus_group_idx is None:
            return None, 'NETWORK_ALL_SATISFIED'

        # focus_group_idx holds goal indices; copy those goal entries and mark
        # every other entry as 2.
        focus_group_np = np.ones_like(curr_goal_np) * 2
        for fg_idx in focus_group_idx:
            fg_obj_i, fg_pred_i = goal_index[fg_idx]
            focus_group_np[fg_obj_i, fg_pred_i] = curr_goal_np[fg_obj_i,
                                                               fg_pred_i]
        focus_group = tu.to_tensor(focus_group_np, device=device).unsqueeze(0)
        focus_group = tu.to_onehot(focus_group, 3)
        return focus_group, -1