Exemplo n.º 1
0
    def log_outputs(outputs, batch, summarizer, global_step, prefix):
        """Log focus, preimage and reachable accuracies for one batch.

        Args:
            outputs: mapping providing 'preimage_preds', 'focus_preds' and
                'reachable_preds'.
            batch: mapping providing 'preimage', 'preimage_mask',
                'preimage_loss_mask', 'goal', 'focus_mask' and 'reachable'.
            summarizer: object exposing
                ``add_scalar(tag, value, global_step=...)``.
            global_step: step value forwarded to every ``add_scalar`` call.
            prefix: string prepended to every scalar tag.
        """
        preimage_label = masked_symbolic_state_index(batch['preimage'],
                                                     batch['preimage_mask'])
        preimage_pred = outputs['preimage_preds'].argmax(-1)
        # Positions whose 'preimage_loss_mask' is 0 are overwritten in place
        # with index 2 in both the prediction and the label, so they can
        # never count as a disagreement.
        excluded = batch['preimage_loss_mask'] == 0
        preimage_pred.masked_fill_(excluded, 2)
        preimage_label.masked_fill_(excluded, 2)
        preimage_acc, preimage_mask_acc = masked_binary_accuracy(
            tu.to_onehot(preimage_pred, 3), preimage_label)

        focus_label = masked_symbolic_state_index(batch['goal'],
                                                  batch['focus_mask'])
        focus_pred = outputs['focus_preds'].argmax(-1)
        focus_acc, focus_mask_acc = masked_binary_accuracy(
            tu.to_onehot(focus_pred, 3), focus_label)

        reachable_acc = classification_accuracy(outputs['reachable_preds'],
                                                batch['reachable'])

        # Emitted in this fixed order; each tag is prefix + 'acc/<name>'.
        scalars = (
            ('acc/focus', focus_acc),
            ('acc/focus_mask', focus_mask_acc),
            ('acc/preimage', preimage_acc),
            ('acc/preimage_mask', preimage_mask_acc),
            ('acc/reachable', reachable_acc),
        )
        for tag, value in scalars:
            summarizer.add_scalar(prefix + tag,
                                  value,
                                  global_step=global_step)
Exemplo n.º 2
0
    def forward_batch(self, batch):
        """Unpack ``batch`` and run the forward pass matching the mode.

        When ``self.policy_mode`` is truthy only the goal and the states are
        forwarded to ``self.forward_policy``; otherwise the focus goal and
        the satisfied/dependency inputs are passed to ``self.forward`` too.

        Args:
            batch: mapping providing at least 'states', 'goal' and
                'goal_mask'; outside policy mode also 'focus_mask',
                'satisfied' and 'dependency'. 'entity_states', 'graph',
                'graphs' and 'num_entities' are optional.

        Returns:
            Whatever the selected forward routine returns.
        """
        policy_mode = self.policy_mode
        # The goal encoding is needed by both routines.
        goal = tu.to_onehot(
            masked_symbolic_state_index(batch['goal'], batch['goal_mask']), 3)

        if policy_mode:
            return self.forward_policy(
                object_states=batch['states'],
                # Falls back to the object states when no separate entity
                # states are present in the batch.
                entity_states=batch.get('entity_states', batch['states']),
                goal=goal,
                graphs=batch.get('graphs', None),
            )

        focus_goal = tu.to_onehot(
            masked_symbolic_state_index(batch['goal'], batch['focus_mask']),
            3)
        # Only the columns before the last one are fed to the network.
        satisfied_inputs = batch['satisfied'][:, :-1]
        dependency_inputs = batch['dependency'][:, :-1]
        return self.forward(
            object_states=batch['states'],
            entity_states=batch.get('entity_states', batch['states']),
            goal=goal,
            focus_goal=focus_goal,
            satisfied_info=satisfied_inputs,
            dependency_info=dependency_inputs,
            graph=batch.get('graph', None),
            num_entities=batch.get('num_entities', None),
        )
Exemplo n.º 3
0
 def forward_batch(self, batch):
     """Encode the goal (and, outside policy mode, the subgoal) and call self.

     Args:
         batch: mapping providing 'states', 'goal' and 'goal_mask'; when
             ``self.policy_mode`` is falsy also 'subgoal' and 'subgoal_mask'.

     Returns:
         The result of ``self(states=..., goal=..., subgoal=...)``;
         ``subgoal`` is ``None`` in policy mode.
     """
     goal = tu.to_onehot(
         masked_symbolic_state_index(batch['goal'], batch['goal_mask']), 3)
     if self.policy_mode:
         # No subgoal is read from the batch in policy mode.
         subgoal = None
     else:
         subgoal = tu.to_onehot(
             masked_symbolic_state_index(batch['subgoal'],
                                         batch['subgoal_mask']), 3)
     return self(states=batch['states'], goal=goal, subgoal=subgoal)
Exemplo n.º 4
0
    def log_outputs(outputs, batch, summarizer, global_step, prefix):
        """Log preimage, reachable, satisfied and dependency accuracies.

        Every metric is written through ``summarizer.add_scalar`` with the
        tag ``prefix + 'acc/<name>'`` at ``global_step``.

        Args:
            outputs: mapping providing 'preimage_preds', 'reachable_preds',
                'satisfied_preds' and 'dependency_preds'.
            batch: mapping providing 'preimage', 'preimage_mask',
                'preimage_loss_mask', 'reachable', 'satisfied' and
                'dependency'.
            summarizer: object exposing
                ``add_scalar(tag, value, global_step=...)``.
            global_step: step value forwarded to every ``add_scalar`` call.
            prefix: string prepended to every scalar tag.
        """
        preimage_label = masked_symbolic_state_index(batch['preimage'],
                                                     batch['preimage_mask'])
        preimage_pred = outputs['preimage_preds'].argmax(-1)
        # Positions whose 'preimage_loss_mask' is 0 are overwritten in place
        # with index 2 on both sides, so they never register as an error.
        excluded = batch['preimage_loss_mask'] == 0
        preimage_pred.masked_fill_(excluded, 2)
        preimage_label.masked_fill_(excluded, 2)
        preimage_acc, preimage_mask_acc = masked_binary_accuracy(
            tu.to_onehot(preimage_pred, 3), preimage_label)

        reachable_acc = classification_accuracy(outputs['reachable_preds'],
                                                batch['reachable'])
        # The last column of each table is used as the class label.
        satisfied_acc = classification_accuracy(
            outputs['satisfied_preds'], batch['satisfied'][:, -1].long())
        dependency_acc = classification_accuracy(
            outputs['dependency_preds'], batch['dependency'][:, -1].long())

        scalars = (
            ('acc/preimage', preimage_acc),
            ('acc/preimage_mask', preimage_mask_acc),
            ('acc/reachable', reachable_acc),
            ('acc/satisfied', satisfied_acc),
            ('acc/dependency', dependency_acc),
        )
        for tag, value in scalars:
            summarizer.add_scalar(prefix + tag,
                                  value,
                                  global_step=global_step)
Exemplo n.º 5
0
    def compute_losses(outputs, batch):
        """Return the cross-entropy losses for one batch, keyed by name.

        Args:
            outputs: mapping providing 'preimage_preds', 'reachable_preds',
                'satisfied_preds' and 'dependency_preds'.
            batch: mapping providing 'preimage', 'preimage_mask',
                'preimage_loss_mask', 'reachable', 'satisfied' and
                'dependency'.

        Returns:
            dict with the keys 'preimage', 'reachable', 'satisfied' and
            'dependency'.
        """
        criterion = nn.CrossEntropyLoss()

        preimage_target = masked_symbolic_state_index(batch['preimage'],
                                                      batch['preimage_mask'])
        flat_logits = outputs['preimage_preds'].reshape(-1, 3)
        flat_mask = batch['preimage_loss_mask'].reshape(-1)
        # Masking is done by multiplication: both the logits and the target
        # index are scaled by the mask before the loss is taken (presumably
        # a 0/1 mask, so masked rows become all-zero logits with target 0 --
        # TODO confirm).
        preimage_loss = criterion(flat_logits * flat_mask.unsqueeze(-1),
                                  preimage_target.reshape(-1) *
                                  flat_mask.long())

        reachable_loss = criterion(outputs['reachable_preds'],
                                   batch['reachable'].long())
        # The last column of each table is the class label.
        satisfied_loss = criterion(outputs['satisfied_preds'],
                                   batch['satisfied'][:, -1].long())
        dependency_loss = criterion(outputs['dependency_preds'],
                                    batch['dependency'][:, -1].long())

        # NOTE: an action loss (outputs['action_preds'] against
        # batch['action_labels']) is not computed here and no 'action' key
        # is returned.
        return {
            'preimage': preimage_loss,
            'reachable': reachable_loss,
            'satisfied': satisfied_loss,
            'dependency': dependency_loss,
        }
Exemplo n.º 6
0
 def compute_losses(outputs, batch):
     """Return the subgoal cross-entropy loss for one batch.

     Args:
         outputs: mapping providing 'subgoal_preds'.
         batch: mapping providing 'subgoal' and 'subgoal_mask'.

     Returns:
         dict with the single key 'subgoal'.
     """
     target = masked_symbolic_state_index(batch['subgoal'],
                                          batch['subgoal_mask'])
     # Predictions are flattened to rows of three scores and targets to one
     # class index per row.
     logits = outputs['subgoal_preds'].reshape(-1, 3)
     criterion = nn.CrossEntropyLoss()
     return {'subgoal': criterion(logits, target.reshape(-1))}
Exemplo n.º 7
0
 def compute_losses(outputs, batch):
     """Return the focus, preimage and reachable losses for one batch.

     Args:
         outputs: mapping providing 'preimage_preds', 'focus_preds' and
             'reachable_preds'.
         batch: mapping providing 'preimage', 'preimage_mask',
             'preimage_loss_mask', 'goal', 'focus_mask' and 'reachable'.

     Returns:
         dict with the keys 'subgoal' (holding the focus loss), 'preimage'
         and 'reachable'.
     """
     criterion = nn.CrossEntropyLoss()

     preimage_target = masked_symbolic_state_index(batch['preimage'],
                                                   batch['preimage_mask'])
     flat_logits = outputs['preimage_preds'].reshape(-1, 3)
     flat_mask = batch['preimage_loss_mask'].reshape(-1)
     # Masking is done by multiplication: both the logits and the target
     # index are scaled by the mask before the loss is taken (presumably a
     # 0/1 mask -- TODO confirm).
     preimage_loss = criterion(flat_logits * flat_mask.unsqueeze(-1),
                               preimage_target.reshape(-1) *
                               flat_mask.long())

     focus_target = masked_symbolic_state_index(batch['goal'],
                                                batch['focus_mask'])
     focus_loss = criterion(outputs['focus_preds'].reshape(-1, 3),
                            focus_target.reshape(-1))

     reachable_loss = criterion(outputs['reachable_preds'],
                                batch['reachable'].long())

     # The focus loss is reported under the 'subgoal' key.
     return {
         'subgoal': focus_loss,
         'preimage': preimage_loss,
         'reachable': reachable_loss
     }
Exemplo n.º 8
0
 def log_outputs(outputs, batch, summarizer, global_step, prefix):
     """Log the subgoal accuracy pair for one batch.

     Args:
         outputs: mapping providing 'subgoal_preds'.
         batch: mapping providing 'subgoal' and 'subgoal_mask'.
         summarizer: object exposing
             ``add_scalar(tag, value, global_step=...)``.
         global_step: step value forwarded to every ``add_scalar`` call.
         prefix: string prepended to every scalar tag.
     """
     label = masked_symbolic_state_index(batch['subgoal'],
                                         batch['subgoal_mask'])
     # NOTE(review): the raw 'subgoal_preds' are handed over without an
     # argmax/one-hot step -- confirm masked_binary_accuracy accepts scores.
     accuracies = masked_binary_accuracy(outputs['subgoal_preds'], label)
     subgoal_acc, subgoal_mask_acc = accuracies

     scalars = (
         ('acc/subgoal', subgoal_acc),
         ('acc/subgoal_mask', subgoal_mask_acc),
     )
     for tag, value in scalars:
         summarizer.add_scalar(prefix + tag,
                               value,
                               global_step=global_step)
Exemplo n.º 9
0
 def debug(outputs, batch, env=None):
     """Print every sample whose predicted subgoal differs from its label.

     For each mismatching sample the prediction, the label and the focused
     goal are printed.

     Args:
         outputs: mapping providing 'subgoal_preds'.
         batch: mapping providing 'subgoal', 'subgoal_mask', 'goal' and
             'focus_mask'.
         env: optional object providing ``deserialize_symbolic_state`` and
             ``masked_symbolic_state``. When ``None`` (the default) the raw
             index arrays are printed instead of decoded states.
     """
     def render(indices):
         # ``env`` defaults to None; without it there is nothing to decode
         # with, so show the raw indices rather than raising AttributeError.
         if env is None:
             return indices
         return env.masked_symbolic_state(
             env.deserialize_symbolic_state(indices))

     subgoal = tu.to_numpy(
         masked_symbolic_state_index(batch['subgoal'],
                                     batch['subgoal_mask']))
     subgoal_preds = tu.to_numpy(outputs['subgoal_preds'].argmax(-1))
     focus_goal = tu.to_numpy(
         masked_symbolic_state_index(batch['goal'], batch['focus_mask']))

     print('subgoal')
     for label, pred, focused in zip(subgoal, subgoal_preds, focus_goal):
         # Only samples with at least one differing entry are reported.
         if np.all(label == pred):
             continue
         print('- preds: ', render(pred))
         print('label: ', render(label))
         print('focused: ', render(focused))
Exemplo n.º 10
0
    def debug(outputs, batch, env=None):
        """Print the preimage, reachable and dependency prediction errors.

        Three sections are printed, each introduced by its name:
        'preimage' lists mismatching prediction/label pairs, 'reachable'
        and 'dependency' list misclassified samples tagged 'fp' (label is 0)
        or 'fn' (any other label).

        Args:
            outputs: mapping providing 'preimage_preds', 'reachable_preds'
                and 'dependency_preds'.
            batch: mapping providing 'preimage', 'preimage_mask',
                'preimage_loss_mask', 'goal', 'focus_mask', 'reachable' and
                'dependency'.
            env: optional object providing ``deserialize_symbolic_state``,
                ``masked_symbolic_state`` and
                ``deserialize_dependency_entry``. When ``None`` (the
                default) raw arrays are printed instead of decoded values.
        """
        def render_state(indices):
            # ``env`` defaults to None; fall back to the raw indices rather
            # than raising AttributeError on the first mismatch.
            if env is None:
                return indices
            return env.masked_symbolic_state(
                env.deserialize_symbolic_state(indices))

        def render_dependency(entry):
            if env is None:
                return entry
            return env.deserialize_dependency_entry(entry)

        preimage = tu.to_numpy(
            masked_symbolic_state_index(batch['preimage'],
                                        batch['preimage_mask']))
        preimage_preds = tu.to_numpy(outputs['preimage_preds'].argmax(-1))
        print('preimage')
        for label, pred, loss_mask in zip(
                preimage, preimage_preds,
                tu.to_numpy(batch['preimage_loss_mask'])):
            # Skip exact matches and samples whose loss mask is all zero.
            if np.all(label == pred) or np.all(loss_mask == 0):
                continue
            print('preds: ', render_state(pred))
            print('label: ', render_state(label))

        focus_goal = tu.to_numpy(
            masked_symbolic_state_index(batch['goal'], batch['focus_mask']))

        print('reachable')
        reachable_preds = tu.to_numpy(outputs['reachable_preds'].argmax(-1))
        reachable_label = tu.to_numpy(batch['reachable'])
        for i, (pred, label) in enumerate(
                zip(reachable_preds, reachable_label)):
            if int(pred) == int(label):
                continue
            msg = 'fp' if int(label) == 0 else 'fn'
            print(msg, render_state(focus_goal[i]))

        print('dependency')
        dep_preds = tu.to_numpy(outputs['dependency_preds'].argmax(-1))
        dep_label = tu.to_numpy(batch['dependency'])
        for pred, entry in zip(dep_preds, dep_label):
            # The last element of each dependency entry is its class label.
            label = int(entry[-1])
            if int(pred) == label:
                continue
            msg = 'fp' if label == 0 else 'fn'
            print(msg, render_dependency(entry))
Exemplo n.º 11
0
 def compute_losses(outputs, batch):
     """Return the subgoal, satisfied and dependency losses for one batch.

     Args:
         outputs: mapping providing 'subgoal_preds', 'satisfied_preds' and
             'dependency_preds'.
         batch: mapping providing 'subgoal', 'subgoal_mask', 'satisfied' and
             'dependency'.

     Returns:
         dict with the keys 'subgoal', 'satisfied' and 'dependency'.
     """
     criterion = nn.CrossEntropyLoss()

     subgoal_target = masked_symbolic_state_index(batch['subgoal'],
                                                  batch['subgoal_mask'])
     subgoal_loss = criterion(outputs['subgoal_preds'].reshape(-1, 3),
                              subgoal_target.reshape(-1))

     # The last column of each table is the class label.
     satisfied_loss = criterion(outputs['satisfied_preds'],
                                batch['satisfied'][:, -1].long())
     dependency_loss = criterion(outputs['dependency_preds'],
                                 batch['dependency'][:, -1].long())

     return {
         'subgoal': subgoal_loss,
         'satisfied': satisfied_loss,
         'dependency': dependency_loss,
     }
Exemplo n.º 12
0
 def log_outputs(outputs, batch, summarizer, global_step, prefix):
     """Log subgoal, satisfied and dependency accuracies for one batch.

     Every metric is written through ``summarizer.add_scalar`` with the tag
     ``prefix + 'acc/<name>'`` at ``global_step``.

     Args:
         outputs: mapping providing 'subgoal_preds', 'satisfied_preds' and
             'dependency_preds'.
         batch: mapping providing 'subgoal', 'subgoal_mask', 'satisfied' and
             'dependency'.
         summarizer: object exposing
             ``add_scalar(tag, value, global_step=...)``.
         global_step: step value forwarded to every ``add_scalar`` call.
         prefix: string prepended to every scalar tag.
     """
     subgoal_label = masked_symbolic_state_index(batch['subgoal'],
                                                 batch['subgoal_mask'])
     # The arg-max class index is re-encoded via tu.to_onehot(..., 3)
     # before it is scored.
     subgoal_pred = outputs['subgoal_preds'].argmax(-1)
     subgoal_acc, subgoal_mask_acc = masked_binary_accuracy(
         tu.to_onehot(subgoal_pred, 3), subgoal_label)

     # The last column of each table is used as the class label.
     satisfied_acc = classification_accuracy(
         outputs['satisfied_preds'], batch['satisfied'][:, -1].long())
     dependency_acc = classification_accuracy(
         outputs['dependency_preds'], batch['dependency'][:, -1].long())

     scalars = (
         ('acc/subgoal', subgoal_acc),
         ('acc/subgoal_mask', subgoal_mask_acc),
         ('acc/satisfied', satisfied_acc),
         ('acc/dependency', dependency_acc),
     )
     for tag, value in scalars:
         summarizer.add_scalar(prefix + tag,
                               value,
                               global_step=global_step)