Ejemplo n.º 1
0
    def reset(self, expert=False):
        """
        Reset the environment to a fresh sorting episode.

        Draws a random episode length in [2, max_length], fills the list to
        sort with random digits 0-9, moves both pointers back to position 1,
        and caches the initial observation in ``last_obs``.

        :param expert: unused here; kept for interface compatibility
        :return: DictTree whose ``value`` and ``expert_value`` are both the
            root argument (a 1-element tensor holding the list length)
        """
        # Length in [2, max_length]: choice picks from [0, max_length - 2].
        self.length = 2 + np.random.choice(self.max_length - 1)
        self.list_to_sort = list(np.random.randint(0, 10, self.length))
        self.pointer1 = 1
        self.pointer2 = 1
        first_digit = torch_utils.one_hot(self.list_to_sort[self.pointer1 - 1], 10)
        second_digit = torch_utils.one_hot(self.list_to_sort[self.pointer2 - 1], 10)
        # Both pointers start at the left edge, so both edge flags are 1.
        self.last_obs = [first_digit, torch.tensor([1.]),
                         second_digit, torch.tensor([1.])]
        root_arg = torch.Tensor([self.length])
        return DictTree(value=root_arg, expert_value=root_arg)
Ejemplo n.º 2
0
 async def __call__(self, iput):
     """
     Step pointer2 one slot to the left (never below 1) and refresh the
     cached observation on the environment.

     :param iput: input carrying the environment in ``iput.env``
     :return: an empty tensor (this action produces no output value)
     """
     env = iput.env
     env.pointer2 = max(env.pointer2 - 1, 1)
     # Edge flag is 1.0 when a pointer sits at either end of the list.
     at_edge1 = (env.pointer1 == 1) or (env.pointer1 == env.length)
     at_edge2 = (env.pointer2 == 1) or (env.pointer2 == env.length)
     digit1 = torch_utils.one_hot(env.list_to_sort[env.pointer1 - 1], 10)
     digit2 = torch_utils.one_hot(env.list_to_sort[env.pointer2 - 1], 10)
     env.last_obs = [
         digit1,
         torch.tensor([float(at_edge1)]),
         digit2,
         torch.tensor([float(at_edge2)]),
     ]
     return torch.empty(0)
Ejemplo n.º 3
0
 def _get_loss(self, batch, evaluate=False):
     """
     Score a batch of traces against the policy's action log-probabilities.

     :param batch: iterable of traces; each ``trace.data.steps`` yields steps
         with ``.obs`` and ``.act_name``
     :param evaluate: if True, return the number of traces that contain at
         least one mispredicted action; otherwise return the negative
         log-likelihood training loss
     :return: int error count when ``evaluate``, else a scalar CPU tensor
     """
     # TODO: optionally use any available annotations
     packed_obs = rnn_utils.pack_sequence([
         torch.stack([step.obs for step in trace.data.steps])
         for trace in batch
     ])
     # TODO: handle act_arg
     packed_act_logprob, _ = self.act_logits(packed_obs.to(self.device))
     if not evaluate:
         # Training loss: cross-entropy against one-hot action targets,
         # i.e. -sum(target * logprob) over all packed steps.
         packed_target_1h = rnn_utils.pack_sequence([
             torch.stack([
                 torch_utils.one_hot(step.act_name,
                                     self.act_names,
                                     device=self.device)
                 for step in trace.data.steps
             ])
             for trace in batch
         ])
         return -(packed_target_1h.data * packed_act_logprob.data).sum().cpu()
     # Evaluation: count traces where the argmax action disagrees with the
     # recorded action at any step.
     packed_target_idx = rnn_utils.pack_sequence([
         torch.tensor(
             [self.act_names.index(step.act_name)
              for step in trace.data.steps],
             device=self.device)
         for trace in batch
     ])
     packed_wrong = torch_utils.apply2packed(
         lambda logp: logp.argmax(1) != packed_target_idx.data,
         packed_act_logprob)
     padded_wrong, _ = rnn_utils.pad_packed_sequence(packed_wrong,
                                                     batch_first=True)
     # max over the time dimension marks a trace wrong if any step is wrong.
     return padded_wrong.max(1)[0].long().sum().item()