def _forward(self, state):
    """Minimal concentration features: whose turn it is (2-way one-hot)
    followed by the face of the currently flipped card, or zeros when no
    card is flipped.  Returns a (1, 1, 8) tensor.

    # assumes entries of state.faces are length-6 tensors — TODO confirm
    """
    whose_turn = util.zeros(self._t, 2)
    whose_turn[state.t % 2] = 1
    if state.flipped is None:
        face = util.zeros(self._t, 6)
    else:
        face = state.faces[state.card[state.flipped]]
    return torch.cat([whose_turn, face]).view(1, 1, -1)
def _forward(self, state):
    """Pendulum observation as a (1, 1, 4) tensor:
    [bias 1, cos(theta), sin(theta), angular velocity]."""
    obs = util.zeros(self._t.weight, 1, 1, 4)
    components = (1., np.cos(state.th), np.sin(state.th), state.th_dot)
    for slot, value in enumerate(components):
        obs[0, 0, slot] = value
    return Varng(obs)
def _forward(self, state):
    """Blackjack observation as a (1, 1, 4) tensor:
    [bias 1, player hand total, dealer upcard, usable-ace flag]."""
    obs = util.zeros(self._t.weight, 1, 1, 4)
    components = (1.,
                  float(sum_hand(state.player)),
                  float(state.dealer[0]),
                  float(usable_ace(state.player)))
    for slot, value in enumerate(components):
        obs[0, 0, slot] = value
    return Varng(obs)
def _forward(self, state):
    """Binary-decode the last `history_length` observations, 10 bits each.

    Slot h*10 + i is set when bit i of the h-th most recent observation is
    on; when fewer observations exist than history_length, the oldest one
    is reused (index clamped to 0).  Returns a Varng over a
    (1, 1, 10 * history_length) tensor.
    """
    view = util.zeros(self._t, 1, 1, 10 * self.history_length)
    for h in range(self.history_length):
        # clamp so short histories fall back to the earliest observation
        obs = state.obs[max(0, len(state.obs) - h - 1)]
        for i in range(10):
            # BUG FIX: the original tested `(obs & i) > 0`, which can never
            # fire for i == 0 and tests multi-bit masks for non-power-of-two
            # i (e.g. i == 3 matched bits 0|1).  Decode individual bits.
            if (obs >> i) & 1:
                view[0, 0, h * 10 + i] = 1.
    return Varng(view)
def _forward_batch(self, envs):
    """Build a padded bag-of-words tensor for a batch of environments.

    Returns (bow, txt_len): bow wraps a (batch, max_len, dim) tensor padded
    to the longest text; txt_len lists each environment's text length.
    """
    txts = [util.getattr_deep(env, self.input_field) for env in envs]
    lengths = [len(txt) for txt in txts]
    bow = util.zeros(self, len(envs), max(lengths), self.dim)
    for row, txt in enumerate(txts):
        self.set_bow(bow, row, txt)
    return Varng(bow), lengths
def _forward(self, state):
    """Indicator features for blocked moves from the current location:
    one slot per neighbor, set to 1 when stepping there is illegal.
    Offset order matches the original hand-unrolled checks."""
    feats = util.zeros(self._t.weight, 1, 1, self.dim)
    neighbor_offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
    for slot, (dx, dy) in enumerate(neighbor_offsets):
        target = (state.loc[0] + dx, state.loc[1] + dy)
        if not state.is_legal(target):
            feats[0, 0, slot] = 1.
    return Varng(feats)
def _forward(self, state):
    """Full concentration feature vector: a per-position one-hot (over card
    types) for every card already seen, the currently flipped card, whose
    turn it is, and — when cheating is enabled — the reference policy's
    action.  All pieces are concatenated into a (1, 1, -1) tensor."""
    pieces = []
    # one card-type one-hot per board position, populated only once seen
    for pos in range(2 * self.n_card_types):
        onehot = util.zeros(self._t, self.n_card_types)
        if pos in state.seen:
            onehot[state.card[pos]] = 1
        pieces.append(onehot)
    # the card currently flipped face-up, if any
    flipped = util.zeros(self._t, self.n_card_types)
    if state.flipped is not None:
        flipped[state.card[state.flipped]] = 1
    pieces.append(flipped)
    # whose turn it is (2-way one-hot)
    turn = util.zeros(self._t, 2)
    turn[state.t % 2] = 1
    pieces.append(turn)
    if self.cheat:
        # leak the reference policy's chosen action as a feature
        ref_action = util.zeros(self._t, 2 * self.n_card_types)
        ref_action[ConcentrationReference()(state)] = 1
        pieces.append(ref_action)
    return torch.cat(pieces).view(1, 1, -1)
def _forward(self, state, x):
    """One recurrent step: embed the previous action, concatenate it with
    the input features x, and advance the hidden state."""
    w = self.rnn.weight_ih
    # previous action, or the sentinel value n_actions before any action
    last_a = state._trajectory[-1] if len(state._trajectory) > 0 \
             else self.n_actions
    if self.d_actemb is None:
        # raw one-hot over n_actions + 1 slots (extra slot = "no action yet")
        onehot = util.zeros(w, 1, 1 + self.n_actions)
        onehot[0, last_a] = 1
        prev_a = Varng(onehot)
    else:
        # learned action embedding
        prev_a = self.embed_a(util.onehot(w, last_a))
    # action embedding first, then the input features, along dim 1
    self.h = self.rnn(torch.cat([prev_a] + x, 1), self.h)
    return self.hidden()
def _forward(self, state):
    """Global pocman view: an 8-channel one-hot per board cell encoding
    wall / food / power-food / ghost / pacman occupancy, flattened into a
    (1, 1, 8 * width * height) tensor.
    """
    ghost_positions = set(state.ghost_pos)
    # BUG FIX: the buffer must hold 8 channels per cell; the original sized
    # it 10 * history_length (copy-pasted from the local-history features),
    # which does not match the (x * height + y) * 8 indexing below.
    view = util.zeros(self._t, 1, 1, 8 * self.width * self.height)
    for y in range(self.height):
        for x in range(self.width):
            idx = (x * self.height + y) * 8
            pos = (x, y)
            # later checks deliberately take priority:
            # ghost/pacman override food, which overrides wall
            c = 0
            if not state.passable(pos):
                c = 1
            if pos in state.food:
                c = 3 if state.check(pos, POWER) else 2
            if pos in ghost_positions:
                c = 5 if state.power_steps == 0 else 7
            elif pos == state.pocman:
                c = 4 if state.power_steps == 0 else 6
            view[0, 0, idx + c] = 1
    return Varng(view)
def _forward(self, state, x):
    """Extend the input features x with recent-action one-hots and stored
    observation history, then rotate the observation ring buffer.

    Returns the concatenation (along dim 1) of x, an action-history block
    (if act_history_length > 0), and the buffered observations (if
    obs_history_length > 0).
    """
    feats = x[:]
    if self.act_history_length > 0:
        f = util.zeros(self, 1, self.act_history_length * self.n_actions)
        for i in range(min(self.act_history_length, len(state._trajectory))):
            # BUG FIX: the original read _trajectory[-i], which at i == 0 is
            # _trajectory[0] — the FIRST action, not the most recent one.
            # -(i + 1) walks backwards from the end as intended.
            a = state._trajectory[-(i + 1)]
            f[0, i * self.n_actions + a] = 1
        feats.append(Varng(f))
    if self.obs_history_length > 0:
        # oldest-to-newest, starting at the current ring-buffer position
        for i in range(self.obs_history_length):
            feats.append(Varng(self.obs_history[
                (self.obs_history_pos + i) % self.obs_history_length]))
        # update history; kept inside this guard so a zero-length history
        # neither indexes an empty list nor takes `% 0`
        self.obs_history[self.obs_history_pos] = torch.cat(x, dim=1).data
        self.obs_history_pos = \
            (self.obs_history_pos + 1) % self.obs_history_length
    return torch.cat(feats, dim=1)
def _reset(self):
    """Re-initialize the observation ring buffer with zero tensors and
    rewind its write position."""
    self.obs_history_pos = 0
    self.obs_history = [util.zeros(self, 1, self.att_dim)
                        for _ in range(self.obs_history_length)]
def _forward(self, state):
    """One-hot of the agent's grid location, flattened with stride
    example.height along the first coordinate."""
    loc_onehot = util.zeros(self._t.weight, 1, 1, self.dim)
    flat_index = state.loc[0] * state.example.height + state.loc[1]
    loc_onehot[0, 0, flat_index] = 1
    return Varng(loc_onehot)
def _reset(self):
    """Zero the recurrent state; LSTM cells carry an (h, c) pair of two
    distinct zero tensors, other cell types a single tensor."""
    def fresh_state():
        return Varng(util.zeros(self.rnn.weight_ih, 1, self.d_hid))
    if self.cell_type == 'LSTM':
        self.h = (fresh_state(), fresh_state())
    else:
        self.h = fresh_state()
def _forward(self, state):
    """Noisy one-hot of the current MDP state: with probability noise_rate
    the observation is dropped and an all-zero vector is returned."""
    obs = util.zeros(self._t.weight, 1, 1, self.n_states)
    observed = np.random.random() > self.noise_rate
    if observed:
        obs[0, 0, state.s] = 1
    return Varng(obs)
def _forward(self, env):
    """Encode one environment's text field as a (1, n_tokens, dim)
    bag-of-words tensor."""
    tokens = util.getattr_deep(env, self.input_field)
    encoded = util.zeros(self, 1, len(tokens), self.dim)
    self.set_bow(encoded, 0, tokens)
    return Varng(encoded)