Exemplo n.º 1
0
 def _forward(self, state):
     """Build the feature tensor for ``state``.

     The result concatenates a 2-slot indicator (slot ``state.t % 2`` set
     to 1) with either a 6-element zero block, when ``state.flipped`` is
     ``None``, or ``state.faces[state.card[state.flipped]]`` otherwise.
     The concatenation is reshaped with ``view(1, 1, -1)``.
     """
     parity = util.zeros(self._t, 2)
     parity[state.t % 2] = 1
     if state.flipped is not None:
         tail = state.faces[state.card[state.flipped]]
     else:
         # Nothing flipped: pad with zeros (assumes face vectors are also
         # 6 wide -- TODO confirm against where ``state.faces`` is built).
         tail = util.zeros(self._t, 6)
     return torch.cat([parity, tail]).view(1, 1, -1)
Exemplo n.º 2
0
 def _forward(self, state):
     """Return a 4-slot feature tensor for ``state``.

     Slots along the last axis, in order: a constant 1, ``cos(state.th)``,
     ``sin(state.th)`` and ``state.th_dot``. The buffer comes from
     ``util.zeros(self._t.weight, 1, 1, 4)`` and is wrapped in ``Varng``.
     """
     feats = util.zeros(self._t.weight, 1, 1, 4)
     values = (1., np.cos(state.th), np.sin(state.th), state.th_dot)
     for slot, value in enumerate(values):
         feats[0, 0, slot] = value
     return Varng(feats)
Exemplo n.º 3
0
 def _forward(self, state):
     """Return a 4-slot feature tensor for ``state``.

     Slots along the last axis, in order: a constant 1,
     ``sum_hand(state.player)``, ``state.dealer[0]`` and
     ``usable_ace(state.player)``, each converted with ``float``.
     """
     feats = util.zeros(self._t.weight, 1, 1, 4)
     values = (
         1.,
         float(sum_hand(state.player)),
         float(state.dealer[0]),
         float(usable_ace(state.player)),
     )
     for slot, value in enumerate(values):
         feats[0, 0, slot] = value
     return Varng(feats)
Exemplo n.º 4
0
 def _forward(self, state):
     """Encode the last ``self.history_length`` observations as bit flags.

     For each step ``h`` back in time, the observation is read from
     ``state.obs`` (clamped to index 0 when the history is shorter than
     ``h + 1``) and its 10 low bits are written as 0/1 indicators into
     slots ``h * 10 .. h * 10 + 9`` of the last axis.

     Returns the filled buffer wrapped in ``Varng``.
     """
     n_bits = 10
     view = util.zeros(self._t, 1, 1, n_bits * self.history_length)
     newest = len(state.obs) - 1
     for h in range(self.history_length):
         # Reuse the oldest observation when fewer than h + 1 exist.
         obs = state.obs[max(0, newest - h)]
         base = h * n_bits
         for i in range(n_bits):
             # Test bit i of the observation. Masking with the bare index
             # (``obs & i``) could never set slot 0 and mixed several bits
             # together for i in {3, 5, 6, 7, 9}.
             # NOTE(review): assumes ``obs`` is an integer bitmask --
             # confirm against where observations are constructed.
             if (obs & (1 << i)) != 0:
                 view[0, 0, base + i] = 1.
     return Varng(view)
Exemplo n.º 5
0
 def _forward_batch(self, envs):
     """Build one padded bag-of-words buffer for a batch of environments.

     The text of each environment is fetched with
     ``util.getattr_deep(env, self.input_field)``. A buffer is requested
     via ``util.zeros(self, len(envs), longest_text, self.dim)`` and row
     ``j`` is filled by ``self.set_bow(bow, j, txt)``.

     Returns a pair ``(Varng(bow), txt_len)`` where ``txt_len`` is the list
     of per-environment text lengths. ``max`` raises ``ValueError`` when
     ``envs`` is empty.
     """
     n_envs = len(envs)
     txts = [util.getattr_deep(env, self.input_field) for env in envs]
     txt_len = [len(txt) for txt in txts]
     bow = util.zeros(self, n_envs, max(txt_len), self.dim)
     for row, txt in enumerate(txts):
         self.set_bow(bow, row, txt)
     return Varng(bow), txt_len
Exemplo n.º 6
0
 def _forward(self, state):
     """Flag which of the four neighbouring cells are not legal.

     Starting from ``state.loc``, the neighbours at offsets ``(-1, 0)``,
     ``(+1, 0)``, ``(0, -1)`` and ``(0, +1)`` are passed to
     ``state.is_legal``; slots 0..3 of the last axis are set to 1 for each
     neighbour that is rejected. Remaining slots (up to ``self.dim``) stay
     as allocated by ``util.zeros``.
     """
     feats = util.zeros(self._t.weight, 1, 1, self.dim)
     first, second = state.loc[0], state.loc[1]
     neighbours = (
         (first - 1, second),
         (first + 1, second),
         (first, second - 1),
         (first, second + 1),
     )
     for slot, cell in enumerate(neighbours):
         if not state.is_legal(cell):
             feats[0, 0, slot] = 1.
     return Varng(feats)
Exemplo n.º 7
0
    def _forward(self, state):
        """Concatenate indicator blocks describing ``state``.

        Blocks, in order:

        * one block of width ``self.n_card_types`` for every position
          ``n`` in ``range(2 * self.n_card_types)``, with entry
          ``state.card[n]`` set when ``n`` is in ``state.seen``;
        * one block of the same width marking ``state.card[state.flipped]``
          when ``state.flipped`` is not ``None``;
        * a 2-slot block with slot ``state.t % 2`` set;
        * only when ``self.cheat`` is truthy, a block of width
          ``2 * self.n_card_types`` with the entry returned by
          ``ConcentrationReference()(state)`` set.

        The blocks are joined with ``torch.cat`` and reshaped with
        ``view(1, 1, -1)``.
        """
        def blank(width):
            # All blocks are allocated the same way, from ``self._t``.
            return util.zeros(self._t, width)

        blocks = []

        for position in range(2 * self.n_card_types):
            seen_block = blank(self.n_card_types)
            if position in state.seen:
                seen_block[state.card[position]] = 1
            blocks.append(seen_block)

        flipped_block = blank(self.n_card_types)
        if state.flipped is not None:
            flipped_block[state.card[state.flipped]] = 1
        blocks.append(flipped_block)

        parity_block = blank(2)
        parity_block[state.t % 2] = 1
        blocks.append(parity_block)

        if self.cheat:
            hint_block = blank(2 * self.n_card_types)
            hint_block[ConcentrationReference()(state)] = 1
            blocks.append(hint_block)

        return torch.cat(blocks).view(1, 1, -1)
Exemplo n.º 8
0
Arquivo: rnn.py Projeto: yyht/macarico
    def _forward(self, state, x):
        """Advance the recurrent cell by one step and return ``self.hidden()``.

        ``x`` is a list of inputs; it is concatenated (along dim 1) after
        the representation of the previous action. The previous action is
        ``state._trajectory[-1]``, or the extra index ``self.n_actions``
        when the trajectory is empty. It is represented as a raw one-hot of
        width ``1 + self.n_actions`` when ``self.d_actemb`` is ``None``,
        and through ``self.embed_a`` otherwise.

        Side effect: ``self.h`` is replaced by the cell's new state.
        """
        w = self.rnn.weight_ih

        # Pick the previous action, falling back to the "no action yet" id.
        if len(state._trajectory) == 0:
            last_a = self.n_actions
        else:
            last_a = state._trajectory[-1]

        if self.d_actemb is not None:
            prev_a = self.embed_a(util.onehot(w, last_a))
        else:
            indicator = util.zeros(w, 1, 1 + self.n_actions)
            indicator[0, last_a] = 1
            prev_a = Varng(indicator)

        # Feed [previous action; x] together with the previous hidden state.
        step_input = torch.cat([prev_a] + x, 1)
        self.h = self.rnn(step_input, self.h)
        return self.hidden()
Exemplo n.º 9
0
 def _forward(self, state):
     """Encode every grid cell as a one-of-8 category indicator.

     Each cell ``(x, y)`` owns 8 consecutive slots starting at
     ``(x * self.height + y) * 8``; exactly one of them is set. Codes,
     with later rules overriding earlier ones:

     * 0 -- none of the conditions below hold
     * 1 -- ``state.passable(pos)`` is false
     * 2 / 3 -- ``pos`` is in ``state.food`` (3 when
       ``state.check(pos, POWER)`` is true)
     * 5 / 7 -- ``pos`` is one of ``state.ghost_pos`` (5 when
       ``state.power_steps == 0``, 7 otherwise)
     * 4 / 6 -- ``pos == state.pocman`` and no ghost is there (4 when
       ``state.power_steps == 0``, 6 otherwise)

     Returns the filled buffer wrapped in ``Varng``.
     """
     n_codes = 8
     ghost_positions = set(state.ghost_pos)
     # The buffer must hold 8 slots per cell. It was previously sized
     # ``10 * self.history_length``, which has no relation to the indices
     # written below.
     view = util.zeros(self._t, 1, 1, self.width * self.height * n_codes)
     no_power = state.power_steps == 0
     for y in range(self.height):
         for x in range(self.width):
             pos = (x, y)
             c = 0
             if not state.passable(pos):
                 c = 1
             if pos in state.food:
                 c = 3 if state.check(pos, POWER) else 2
             if pos in ghost_positions:
                 c = 5 if no_power else 7
             elif pos == state.pocman:
                 c = 4 if no_power else 6
             view[0, 0, (x * self.height + y) * n_codes + c] = 1
     return Varng(view)
Exemplo n.º 10
0
 def _forward(self, state, x):
     """Append action- and observation-history features to ``x``.

     ``x`` is a list of feature tensors; it is copied, not mutated.

     * When ``self.act_history_length > 0``, a block of
       ``act_history_length * n_actions`` indicators is added, where
       sub-block ``i`` marks the action taken ``i + 1`` steps ago
       (sub-block 0 is the most recent action). Sub-blocks beyond the
       length of ``state._trajectory`` stay zero.
     * When ``self.obs_history_length > 0``, the stored entries of
       ``self.obs_history`` are added, starting at
       ``self.obs_history_pos`` and wrapping around. The current ``x``
       (concatenated along dim 1) then overwrites the entry at
       ``self.obs_history_pos`` and the position advances by one.

     Returns all features concatenated along dim 1.
     """
     feats = x[:]
     if self.act_history_length > 0:
         f = util.zeros(self, 1, self.act_history_length * self.n_actions)
         n_known = min(self.act_history_length, len(state._trajectory))
         for i in range(n_known):
             # -(i + 1) walks backwards from the latest action. A plain
             # ``-i`` would read index 0 (the oldest action) when i == 0.
             a = state._trajectory[-(i + 1)]
             f[0, i * self.n_actions + a] = 1
         feats.append(Varng(f))
     if self.obs_history_length > 0:
         for i in range(self.obs_history_length):
             slot = (self.obs_history_pos + i) % self.obs_history_length
             feats.append(Varng(self.obs_history[slot]))
         # Update history only after reading it, so the current input is
         # not included in its own history features.
         self.obs_history[self.obs_history_pos] = torch.cat(x, dim=1).data
         self.obs_history_pos = (self.obs_history_pos +
                                 1) % self.obs_history_length
     return torch.cat(feats, dim=1)
Exemplo n.º 11
0
 def _reset(self):
     """Reset the observation history.

     ``self.obs_history`` becomes a list of ``self.obs_history_length``
     fresh buffers, each from ``util.zeros(self, 1, self.att_dim)``, and
     ``self.obs_history_pos`` is set back to 0.
     """
     self.obs_history = [
         util.zeros(self, 1, self.att_dim)
         for _ in range(self.obs_history_length)
     ]
     self.obs_history_pos = 0
Exemplo n.º 12
0
 def _forward(self, state):
     """Return a single-slot indicator of the current location.

     The slot set to 1 is
     ``state.loc[0] * state.example.height + state.loc[1]`` along the last
     axis of a buffer from ``util.zeros(self._t.weight, 1, 1, self.dim)``.
     """
     feats = util.zeros(self._t.weight, 1, 1, self.dim)
     flat_index = state.loc[0] * state.example.height + state.loc[1]
     feats[0, 0, flat_index] = 1
     return Varng(feats)
Exemplo n.º 13
0
Arquivo: rnn.py Projeto: yyht/macarico
 def _reset(self):
     """Reset the recurrent state ``self.h``.

     ``self.h`` becomes a ``Varng``-wrapped buffer from
     ``util.zeros(self.rnn.weight_ih, 1, self.d_hid)``. When
     ``self.cell_type`` is ``'LSTM'`` it is instead a 2-tuple of two such
     buffers.
     """
     def fresh_state():
         return Varng(util.zeros(self.rnn.weight_ih, 1, self.d_hid))

     hidden = fresh_state()
     if self.cell_type == 'LSTM':
         self.h = (hidden, fresh_state())
     else:
         self.h = hidden
Exemplo n.º 14
0
Arquivo: mdp.py Projeto: yyht/macarico
 def _forward(self, state):
     """Return an indicator of ``state.s`` that is sometimes left blank.

     A single draw from ``np.random.random()`` is compared with
     ``self.noise_rate``: slot ``state.s`` is set only when the draw is
     strictly greater; otherwise the buffer is returned as allocated.
     """
     feats = util.zeros(self._t.weight, 1, 1, self.n_states)
     keep_signal = np.random.random() > self.noise_rate
     if keep_signal:
         feats[0, 0, state.s] = 1
     return Varng(feats)
Exemplo n.º 15
0
 def _forward(self, env):
     """Build the bag-of-words buffer for a single environment.

     The text is fetched with ``util.getattr_deep(env, self.input_field)``.
     A buffer from ``util.zeros(self, 1, len(text), self.dim)`` is filled
     by ``self.set_bow(bow, 0, text)`` and returned wrapped in ``Varng``.
     """
     text = util.getattr_deep(env, self.input_field)
     n_tokens = len(text)
     bow = util.zeros(self, 1, n_tokens, self.dim)
     self.set_bow(bow, 0, text)
     return Varng(bow)