Example no. 1
0
 def _process_step(step):
     """Encode one trace step as a DictTree of fixed-size tensors.

     The return/action names are one-hot encoded against ``self.act_names``
     and the value/argument tensors are zero-padded to the configured sizes.
     NOTE(review): ``self`` is not a parameter here — it appears to be
     captured from an enclosing scope (nested function); confirm in context.
     """
     ret_one_hot = torch_utils.one_hot(
         step.ret_name, self.act_names, dtype=step.ret_val.dtype, device=step.ret_val.device)
     act_one_hot = torch_utils.one_hot(
         step.act_name, self.act_names, dtype=step.act_arg.dtype, device=step.act_arg.device)
     padded_ret = torch_utils.pad(step.ret_val, self.max_act_ret_size, 0)
     padded_arg = torch_utils.pad(step.act_arg, self.max_act_arg_size, 0)
     return DictTree(
         ret1h=ret_one_hot,
         ret_val=padded_ret,
         obs=step.obs,
         act1h=act_one_hot,
         act_arg=padded_arg,
     )
Example no. 2
0
 async def __call__(self, iput):
     """Decrement pointer2 (clamped at 1), refresh the cached observation.

     The observation is [one-hot digit at pointer1, pointer1-at-edge flag,
     one-hot digit at pointer2, pointer2-at-edge flag]; an edge means the
     pointer is at position 1 or at the end of the list.
     Returns an empty tensor (no return value for this action).
     """
     env = iput.env
     env.pointer2 = max(env.pointer2 - 1, 1)
     digit_at_p1 = torch_utils.one_hot(env.list_to_sort[env.pointer1 - 1], 10)
     digit_at_p2 = torch_utils.one_hot(env.list_to_sort[env.pointer2 - 1], 10)
     p1_at_edge = float((env.pointer1 == 1) or (env.pointer1 == env.length))
     p2_at_edge = float((env.pointer2 == 1) or (env.pointer2 == env.length))
     env.last_obs = [
         digit_at_p1,
         torch.tensor([p1_at_edge]),
         digit_at_p2,
         torch.tensor([p2_at_edge]),
     ]
     return torch.empty(0)
Example no. 3
0
    def reset(self, expert=False):
        """
        Reset the environment to a fresh random sorting problem.

        Samples a list length in [2, max_length], fills the list with random
        digits 0-9, resets both pointers to position 1, and caches the initial
        observation.  (The original docstring mentioned a "one-hot target
        floor", which belongs to a different environment.)

        :param expert: unused here; kept for signature compatibility.
        :return: DictTree whose `value` and `expert_value` are the list length
            (the root argument).
        """
        # choice(max_length - 1) is 0..max_length-2, so length is 2..max_length.
        self.length = np.random.choice(self.max_length - 1) + 2
        self.list_to_sort = list(np.random.randint(0, 10, self.length))
        self.pointer1 = 1
        self.pointer2 = 1
        root_arg = torch.Tensor([self.length])
        num1 = torch_utils.one_hot(self.list_to_sort[self.pointer1 - 1], 10)
        num2 = torch_utils.one_hot(self.list_to_sort[self.pointer2 - 1], 10)
        # Both pointers start at position 1, so both boundary flags are 1.
        self.last_obs = [num1, torch.tensor([1.]), num2, torch.tensor([1.])]
        return DictTree(value=root_arg, expert_value=root_arg)
Example no. 4
0
 def _get_loss(self, batch, evaluate=False):
     """
     Behavior-cloning objective over a batch of traces.

     :param batch: traces; each exposes `data.steps`, and each step provides
         `obs` and `act_name` (project types — not fully verifiable here).
         NOTE(review): pack_sequence requires sequences sorted by decreasing
         length — assumed to hold for `batch`; confirm at the call site.
     :param evaluate: if True, return the number of traces with at least one
         mispredicted action (a Python int); otherwise return the negative
         log-likelihood loss as a CPU tensor.
     """
     # TODO: optionally use any available annotations
     packed_obs = rnn_utils.pack_sequence([
         torch.stack([step.obs for step in trace.data.steps])
         for trace in batch
     ])
     # TODO: handle act_arg
     packed_act_logprob, _ = self.act_logits(packed_obs.to(self.device))
     if evaluate:
         # True action indices, packed in the same layout as the predictions.
         packed_true_act_idx = rnn_utils.pack_sequence([
             torch.tensor([
                 self.act_names.index(step.act_name)
                 for step in trace.data.steps
             ],
                          device=self.device) for trace in batch
         ])
         # Per-step mismatch between predicted argmax and the true action.
         packed_error = torch_utils.apply2packed(
             lambda x: x.argmax(1) != packed_true_act_idx.data,
             packed_act_logprob)
         padded_error, _ = rnn_utils.pad_packed_sequence(packed_error,
                                                         batch_first=True)
         # max over time: a trace counts as an error if ANY step is wrong.
         return padded_error.max(1)[0].long().sum().item()
     else:
         # One-hot true actions; NLL = -(one_hot * log_prob) summed over all.
         packed_true_act1h = rnn_utils.pack_sequence([
             torch.stack([
                 torch_utils.one_hot(step.act_name,
                                     self.act_names,
                                     device=self.device)
                 for step in trace.data.steps
             ]) for trace in batch
         ])
         return -(packed_true_act1h.data *
                  packed_act_logprob.data).sum().cpu()
Example no. 5
0
def mrcnn_cls_loss(rois_labels, predict_logits):
    """
    Mask R-CNN classification loss.

    :param rois_labels: torch tensor [rois_num], integer class indices
    :param predict_logits: torch tensor [rois_num, num_classes], raw scores
    :return: scalar loss tensor
    """
    # Bug fix: F.cross_entropy expects integer class indices as the target
    # (it applies log-softmax + NLL internally).  Converting the labels to a
    # one-hot encoding first was incorrect — index targets must stay indices.
    # .long() makes the call robust to labels arriving in another int dtype.
    loss = F.cross_entropy(predict_logits, rois_labels.long())  # scalar
    return loss
Example no. 6
0
File: acausal.py Project: royf/hvil
 def observe(self):
     """Return the current observation as a DictTree.

     First call (no digit sampled yet): draw a random (image, label) pair
     from the current split and show the image with a zeroed category slot.
     Subsequent calls: zero the image slot and show only the one-hot
     category of the sampled digit.  The teacher value one-hot encodes the
     label in the first half (image phase) or second half (category phase)
     of a 2*num_cats vector.
     """
     if self.digit is not None:
         # Category phase: image zeroed, one-hot category visible.
         cat1h = torch_utils.one_hot(self.digit[1], self.num_cats)
         return DictTree(
             value=torch.cat([torch.zeros(self.img_size), cat1h]),
             teacher_value=torch_utils.one_hot(
                 self.num_cats + self.digit[1], 2 * self.num_cats),
         )
     # Image phase: sample a fresh digit from the train/test split.
     split = self.mnist[self.training]
     idx = np.random.choice(len(split))
     self.digit = split[idx]  # type: Tuple[torch.Tensor, int]
     return DictTree(
         value=torch.cat([
             self.digit[0].view(self.img_size),
             torch.zeros(self.num_cats),
         ]),
         teacher_value=torch_utils.one_hot(self.digit[1], 2 * self.num_cats),
     )
Example no. 7
0
File: php.py Project: royf/hvil
 async def forward(self, iput):
     """
     iput:
         is_root
         arg
         cnt
         ret = (ret_name, ret_val)
         obs
         ctx                       [optional]
         sub = (sub_name, sub_arg) [optional]
         act = (act_name, act_arg) [optional]
     modes:
         no ctx + no sub + no act = rollout:                sample p(sub | arg, cnt, ret, obs)
         no ctx + no sub +    act = evaluate:               sample p(sub | arg, cnt, ret, obs),
                                                                and compute error(sub, iput.act)
            ctx +    sub +    act = get_loss (annotated):   compute -log p(sub | arg, cnt, ret, obs)
                                                                   - log q(sub | arg, cnt, ret, ctx, act)
            ctx + no sub +    act = get_loss (unannotated): sample q(sub | arg, cnt, ret, ctx, act),
                                                                and compute D[q(. | arg, cnt, ret, ctx, act)
                                                                           || p(. | arg, cnt, ret, obs)]
                                                                    and log q(sub | arg, cnt, ret, ctx, act)
     oput:
         sub = (sub_name, sub_arg)
         error                     [in evaluate]
         loss                      [in get_loss]
         log_q                     [in get_loss]
     """
     # All tensor inputs must arrive detached (no grad history).
     assert not iput.arg.requires_grad
     assert not iput.cnt.requires_grad
     assert not iput.ret_val.requires_grad
     assert not iput.obs.requires_grad
     has_sub = 'sub_name' in iput
     if has_sub:
         assert not iput.sub_arg.requires_grad
     has_act = 'act_name' in iput
     if has_act:
         assert not iput.act_arg.requires_grad
     # Fixed-size encoding of the return: pad the value, one-hot the name.
     iput.ret_val = torch_utils.pad(iput.ret_val, self.ret_in_size, 0)
     iput.ret1h = torch_utils.one_hot(iput.ret_name,
                                      self.sub_names,
                                      dtype=iput.arg.dtype,
                                      device=iput.arg.device)
     # get_loss mode: a posterior context is available, so build both the
     # prior (p_*) and posterior (q_*) inputs for the sub-choice module.
     if 'ctx' in iput:
         sub_iput = DictTree(
             p_iput=self._get_iput(iput, is_posterior=False,
                                   is_arg=False)[None],
             p_mask=self._get_mask(iput, is_posterior=False)[None],
             q_iput=self._get_iput(iput, is_posterior=True,
                                   is_arg=False)[None],
             q_mask=self._get_mask(iput, is_posterior=True)[None],
         )
         if has_sub:
             # Annotated trace: supervise with the true sub choice.
             sub1h = torch_utils.one_hot(iput.sub_name,
                                         self.sub_names,
                                         dtype=iput.arg.dtype,
                                         device=iput.arg.device)
             sub_iput.true_oput = sub1h[None]
             sub_oput = await self.sub_module(sub_iput)
             sub_name = iput.sub_name
         else:
             # Unannotated: sample the sub choice from the module's output.
             sub_oput = await self.sub_module(sub_iput)
             [sub1h] = sub_oput.oput
             sub_name = torch_utils.lookup(sub1h, self.sub_names)
         [sub_loss] = sub_oput.loss
         [sub_log_p] = sub_oput.log_p
         [sub_log_q] = sub_oput.log_q
         # Without an argument module, sub-arguments are empty / pass-through
         # and contribute zero loss and log-probabilities.
         if self.arg_module is None:
             assert self.arg_out_size == 0
             if has_sub:
                 sub_arg = iput.sub_arg
             else:
                 sub_arg = iput.arg.new_empty(0)
             arg_loss = iput.arg.new_zeros(())
             arg_log_p = iput.arg.new_zeros(())
             arg_log_q = iput.arg.new_zeros(())
         else:
             # Output size: the chosen sub's input size, or this PHP's return
             # size when sub_name is None (termination).
             sub_arg_size = self.ret_out_size if sub_name is None else self.subs[
                 sub_name].arg_in_size
             sub_is_arg = self._sub_is_arg(iput
                                           | DictTree(sub_name=sub_name))
             arg_iput = DictTree(
                 p_iput=self._get_iput(iput | DictTree(sub1h=sub1h),
                                       is_posterior=False,
                                       is_arg=True)[None],
                 oput_size=iput.arg.new_full((),
                                             sub_arg_size,
                                             dtype=torch.int64)[None],
             )
             if sub_is_arg:
                 if has_sub:
                     # noinspection PyUnresolvedReferences
                     assert (iput.sub_arg == iput.act_arg).all()
                 # TODO: use iput.act_arg as auxiliary task
                 arg_iput.true_oput = torch_utils.pad(
                     iput.act_arg, self.arg_out_size, 0)[None]
             else:
                 arg_iput.q_iput = self._get_iput(iput
                                                  | DictTree(sub1h=sub1h),
                                                  is_posterior=True,
                                                  is_arg=True)[None]
                 if has_sub:
                     arg_iput.true_oput = torch_utils.pad(
                         iput.sub_arg, self.arg_out_size, 0)[None]
             arg_oput = await self.arg_module(arg_iput)
             if has_sub:
                 sub_arg = iput.sub_arg
             else:
                 [sub_arg] = arg_oput.oput.detach()[:, :sub_arg_size]
             [arg_loss] = arg_oput.loss
             [arg_log_p] = arg_oput.log_p
             [arg_log_q] = arg_oput.log_q
         oput = DictTree(
             sub_name=sub_name,
             sub_arg=sub_arg,
             loss=[sub_loss, arg_loss],
             log_p=[sub_log_p, arg_log_p],
             log_q=[sub_log_q, arg_log_q],
         )
     else:
         # rollout / evaluate mode: no context, use the prior only.
         sub_iput = DictTree(
             p_iput=self._get_iput(iput, is_posterior=False,
                                   is_arg=False)[None],
             p_mask=self._get_mask(iput, is_posterior=False)[None],
         )
         if has_act and iput.act_name in self.sub_names:
             sub_iput.eval_oput = torch_utils.one_hot(
                 iput.act_name,
                 self.sub_names,
                 dtype=iput.arg.dtype,
                 device=iput.arg.device)[None]
         sub_oput = await self.sub_module(sub_iput)
         [sub1h] = sub_oput.oput
         sub_name = torch_utils.lookup(sub1h, self.sub_names)
         if has_act and iput.act_name in self.sub_names:
             [sub_error] = sub_oput.error.detach()
         else:
             # The demonstrated action is not among this PHP's subs, so the
             # sub choice counts as an error.
             sub_error = True
         if self.arg_module is None:
             assert self.arg_out_size == 0
             sub_arg = iput.arg.new_empty(0)
             arg_error = False
         else:
             sub_arg_size = self.ret_out_size if sub_name is None else self.subs[
                 sub_name].arg_in_size
             arg_iput = DictTree(
                 p_iput=self._get_iput(iput | DictTree(sub1h=sub1h),
                                       is_posterior=False,
                                       is_arg=True)[None],
                 oput_size=iput.arg.new_full((),
                                             sub_arg_size,
                                             dtype=torch.int64)[None],
             )
             # Only score the argument when the sub choice already matched.
             if has_act and not sub_error:
                 arg_iput.eval_oput = torch_utils.pad(
                     iput.act_arg, self.arg_out_size, 0)[None]
             arg_oput = await self.arg_module(arg_iput)
             [sub_arg] = arg_oput.oput.detach()[:, :sub_arg_size]
             if has_act and not sub_error:
                 [arg_error] = arg_oput.error.detach()
             else:
                 arg_error = True
         oput = DictTree(
             sub_name=sub_name,
             sub_arg=sub_arg,
         )
         if has_act:
             oput.error = bool(sub_error) or bool(arg_error)
     return oput