def forward(self, imgs, extract_sym_features: bool = False):
    lead_dim, T, B, img_shape = infer_leading_dims(imgs, 3)
    imgs = imgs.view(T * B, *img_shape)
    # NHWC -> NCHW
    imgs = imgs.transpose(3, 2).transpose(2, 1)
    # TODO: don't do uint8 -> float32 conversion twice (in detector and here)
    if self.sym_extractor is not None and extract_sym_features:
        sym_feats = self.sym_extractor(imgs[:, -1:])
    else:
        sym_feats = None
    # convert from [0, 255] uint8 to [0, 1] float32
    imgs = imgs.to(torch.float32)
    imgs.mul_(1.0 / 255)
    m = self.normalize_mean.to(imgs.device)
    s = self.normalize_std.to(imgs.device)
    imgs.sub_(m).div_(s)
    feats = self.conv1(imgs)
    feats = self.bn1(feats)
    feats = self.relu(feats)
    feats = self.maxpool(feats)
    feats = self.layer1(feats)
    feats = self.relu(feats)
    feats = self.layer2(feats)
    feats = self.relu(feats)
    feats = self.layer3(feats)
    feats = self.relu(feats)
    feats = self.flatten(feats)
    feats = self.linear(feats)
    feats = self.relu(feats)
    value = self.value_predictor(feats).squeeze(-1)  # squeezing seems expected?
    if self.categorical:
        action_dist = self.action_predictor(feats)
        action_dist, value = restore_leading_dims((action_dist, value), lead_dim, T, B)
        if self.sym_extractor is not None and extract_sym_features:
            # have to "restore dims" for sym_feats manually...
            sym_feats = sym_feats.reshape((*action_dist.shape[:-1], -1))
            # sym_feats = torch.rand_like(action_dist)
        return action_dist, value, sym_feats
    else:
        mu = self.mu_predictor(feats)
        log_std = self.log_std_predictor(feats)
        mu, log_std, value = restore_leading_dims((mu, log_std, value), lead_dim, T, B)
        if self.sym_extractor is not None and extract_sym_features:
            # have to "restore dims" for sym_feats manually...
            sym_feats = sym_feats.reshape((*mu.shape[:-1], -1))
            # sym_feats = torch.rand_like(action_dist)
        return mu, log_std, value, sym_feats
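# A minimal standalone sketch (not part of the model above) of the image
# preprocessing in that forward pass: the chained transposes are equivalent to
# a single NHWC -> NCHW permute, followed by [0, 255] uint8 -> [0, 1] float32
# scaling and per-channel mean/std normalization. The mean/std values here are
# placeholders, not the model's buffers.
import torch

imgs = torch.randint(0, 256, (4, 84, 84, 3), dtype=torch.uint8)  # NHWC
nchw_a = imgs.transpose(3, 2).transpose(2, 1)   # as written in the forward above
nchw_b = imgs.permute(0, 3, 1, 2)               # equivalent single permute
assert torch.equal(nchw_a, nchw_b)

x = nchw_a.to(torch.float32) / 255.0
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)  # placeholder stats
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
x = (x - mean) / std  # corresponds to imgs.sub_(m).div_(s) above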
def forward(self, observation):
    lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
    if observation.dtype == torch.uint8:
        img = observation.type(torch.float)
        img = img.mul_(1.0 / 255)
    else:
        img = observation
    conv, conv_layers = self.conv(img.view(T * B, *img_shape))  # lists all layers
    c = self.head(conv_layers[-1].view(T * B, -1))
    c, conv = restore_leading_dims((c, conv), lead_dim, T, B)
    conv_layers = restore_leading_dims(conv_layers, lead_dim, T, B)
    return c, conv, conv_layers  # include conv_outs for local-stdim losses
def forward(self, observation, prev_action, prev_reward, action, detach_encoder=False):  # dummy
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    q_input = torch.cat(
        [observation.view(T * B, -1), action.view(T * B, -1)], dim=1)
    q1 = self.q1_mlp(q_input).squeeze(-1)
    q1 = restore_leading_dims(q1, lead_dim, T, B)
    q2 = self.q2_mlp(q_input).squeeze(-1)
    q2 = restore_leading_dims(q2, lead_dim, T, B)
    return q1, q2
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    img = observation.type(torch.float)  # Expect torch.uint8 inputs
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    conv_out = self.conv(img.reshape(T * B, *img_shape))  # Fold if T dimension.
    features = conv_out.reshape(T * B, -1)
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hidden_state, cell_state) = self.lstm(
        features.reshape(T, B, -1), init_rnn_state)
    head_input = torch.cat((features, lstm_out.reshape(T * B, -1)), dim=-1)
    pi = self.pi_head(head_input)
    pi = torch.softmax(pi, dim=-1)
    # pi = torch.sigmoid(pi - 2)
    value = self.value_head(head_input).squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    pi, value = restore_leading_dims((pi, value), lead_dim, T, B)
    state = RnnState(h=hidden_state, c=cell_state)
    return pi, value, state
def forward(self, observation, prev_action, prev_reward):
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    output = self.mlp(observation.view(T * B, -1))
    alpha, beta = output[:, :self._action_size], output[:, self._action_size:]
    alpha, beta = restore_leading_dims((alpha, beta), lead_dim, T, B)
    return alpha, beta
def forward(self, observation, prev_action, prev_reward, init_rnn_state=None):
    """
    Compute mean, std, and value estimate from input state. Infers
    leading dimensions of input: can be [T,B], [B], or []; provides
    returns with same leading dims. Intermediate feedforward layers
    process as [T*B,H], with T=1, B=1 when not given. Used both in
    sampler and in algorithm (both via the agent).
    """
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation.state, self._obs_ndim)
    assert not torch.any(torch.isnan(observation.state)), 'obs elem is nan'
    obs_flat = observation.state.reshape(T * B, -1)
    # obs_flat = self.layer_norm(obs_flat)
    # features = self.shared_mlp(obs_flat)
    action = self.mu_head(obs_flat)
    v = self.v_head(obs_flat).squeeze(-1)
    # mu, std = (action[:, :self.action_size], action[:, self.action_size:])
    # std = self.softplus(std)
    pi_output = self.softplus(action * 8) + 1
    mu, std = pi_output[:, :self.action_size], pi_output[:, self.action_size:]
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, std, v = restore_leading_dims((mu, std, v), lead_dim, T, B)
    return mu, std, v, torch.zeros(1, B, 1), None  # return fake rnn state
def forward(self, image, prev_action, prev_reward, init_rnn_state):
    """Feedforward layers process as [T*B,H], recurrent ones as [T,B,H].
    Return same leading dims as input, can be [T,B], [B], or [].
    (Same forward used for sampling and training.)"""
    img = image.type(torch.float)  # Expect torch.uint8 inputs
    img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    fc_out = self.conv(img.view(T * B, *img_shape))
    lstm_input = torch.cat([
        fc_out.view(T, B, -1),
        prev_action.view(T, B, -1),  # Assumed onehot.
        prev_reward.view(T, B, 1),
    ], dim=2)
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
    pi = F.softmax(self.pi(lstm_out.view(T * B, -1)), dim=-1)
    v = self.value(lstm_out.view(T * B, -1)).squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    # Model should always leave B-dimension in rnn state: [N,B,H].
    next_rnn_state = RnnState(h=hn, c=cn)
    return pi, v, next_rnn_state
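# A minimal usage sketch of the leading-dims pattern these forward methods
# share, assuming rlpyt's tensor helpers: infer_leading_dims detects whether
# the input has [T,B], [B], or [] leading dims, the body runs on a folded
# [T*B, ...] view, and restore_leading_dims reshapes outputs back to the
# input's layout. Shapes below are illustrative only.
import torch
from rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims

obs = torch.randn(5, 8, 4, 84, 84)                        # [T, B, C, H, W]
lead_dim, T, B, img_shape = infer_leading_dims(obs, 3)    # lead_dim=2, T=5, B=8
flat = obs.view(T * B, *img_shape)                        # [40, 4, 84, 84]
out = flat.sum(dim=(1, 2, 3))                             # stand-in head output: [40]
out = restore_leading_dims(out, lead_dim, T, B)           # back to [5, 8]
assert out.shape == (5, 8)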
def forward(self, image, prev_action=None, prev_reward=None, init_rnn_state=None):
    img = image.type(torch.float)
    img = img.mul_(1. / 255)
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    img = img.view(T * B, *img_shape)
    fc_out = self.conv(img)  # img already reshaped to [T*B, *img_shape] above.
    lstm_input = torch.cat([
        fc_out.view(T, B, -1),
        prev_action.view(T, B, -1),  # Assumed onehot.
        prev_reward.view(T, B, 1),
    ], dim=2)
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
    pi = F.softmax(self.pi(lstm_out.view(T * B, -1)), dim=-1)
    v = self.value(lstm_out.view(T * B, -1)).squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    # Model should always leave B-dimension in rnn state: [N,B,H].
    next_rnn_state = RnnState(h=hn, c=cn)
    return pi, v, next_rnn_state
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_n_dim)
    if self.normalize_observation:
        obs_var = self.obs_rms.var
        if self.norm_obs_var_clip is not None:
            obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
        observation = torch.clamp(
            (observation - self.obs_rms.mean) / obs_var.sqrt(),
            -self.norm_obs_clip, self.norm_obs_clip)
    mlp_out = self.mlp(observation.view(T * B, -1))
    lstm_input = torch.cat(
        [
            mlp_out.view(T, B, -1),
            prev_action.view(T, B, -1),  # Assumed onehot.
            prev_reward.view(T, B, 1),
        ], dim=2)
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
    pi = F.softmax(self.pi(lstm_out.view(T * B, -1)), dim=-1)
    v = self.value(lstm_out.view(T * B, -1)).squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    # Model should always leave B-dimension in rnn state: [N,B,H].
    next_rnn_state = RnnState(h=hn, c=cn)
    return pi, v, next_rnn_state
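# A minimal standalone sketch (with placeholder running statistics) of the
# observation-normalization branch used above and in several later models:
# clamp the running variance from below, standardize with the running
# mean/std, then clip the result to [-norm_obs_clip, norm_obs_clip] so
# outliers cannot blow up the network input.
import torch

obs = torch.randn(16, 11) * 5 + 2
running_mean, running_var = obs.mean(0), obs.var(0)   # placeholders for obs_rms
norm_obs_var_clip, norm_obs_clip = 1e-6, 10.0
var = torch.clamp(running_var, min=norm_obs_var_clip)
obs_norm = torch.clamp((obs - running_mean) / var.sqrt(),
                       -norm_obs_clip, norm_obs_clip)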
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_dim)
    if init_rnn_state is not None:
        if self.rnn_is_lstm:
            init_rnn_pi, init_rnn_v = tuple(init_rnn_state)  # DualRnnState -> RnnState_pi, RnnState_v
            init_rnn_pi, init_rnn_v = tuple(init_rnn_pi), tuple(init_rnn_v)
        else:
            init_rnn_pi, init_rnn_v = tuple(init_rnn_state)  # DualRnnState -> h, h
    else:
        init_rnn_pi, init_rnn_v = None, None
    o_flat = self.preprocessor(observation.view(T * B))
    b_pi, b_v = self.body_pi(o_flat), self.body_v(o_flat)
    p_a, p_r = prev_action.view(T, B, -1), prev_reward.view(T, B, 1)
    pi_inp_list = ([b_pi.view(T, B, -1)]
                   + ([p_a] if self.p_a in [1, 3] else [])
                   + ([p_r] if self.p_r in [1, 3] else []))
    v_inp_list = ([b_v.view(T, B, -1)]  # value body output feeds the value rnn.
                  + ([p_a] if self.p_a in [2, 3] else [])
                  + ([p_r] if self.p_r in [2, 3] else []))
    rnn_input_pi = torch.cat(pi_inp_list, dim=2)
    rnn_input_v = torch.cat(v_inp_list, dim=2)
    rnn_pi, next_rnn_state_pi = self.rnn_pi(rnn_input_pi, init_rnn_pi)
    rnn_v, next_rnn_state_v = self.rnn_v(rnn_input_v, init_rnn_v)
    rnn_pi, rnn_v = rnn_pi.view(T * B, -1), rnn_v.view(T * B, -1)
    pi, v = self.pi(rnn_pi), self.v(rnn_v).squeeze(-1)
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    if self.rnn_is_lstm:
        next_rnn_state = DualRnnState(RnnState(*next_rnn_state_pi),
                                      RnnState(*next_rnn_state_v))
    else:
        next_rnn_state = DualRnnState(next_rnn_state_pi, next_rnn_state_v)
    return pi, v, next_rnn_state
def forward(self, image, prev_action, prev_reward):
    """
    Compute action mean, log_std, and value estimate from input state.
    Infers leading dimensions of input: can be [T,B], [B], or []; provides
    returns with same leading dims. Convolution layers process as
    [T*B, *image_shape], with T=1, B=1 when not given. Expects float32
    images already scaled to [0,1]; the uint8 conversion is left to the
    caller. Used in both sampler and in algorithm (both via the agent).
    """
    # img = image.type(torch.float)  # Expect torch.uint8 inputs
    # img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
    # Expects [0-1] float
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(image, 3)
    fc_out = self.conv(image.view(T * B, *img_shape))
    mu = self.mu(fc_out)
    v = self.v(fc_out).squeeze(-1)
    log_std = self.log_std.repeat(T * B, 1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)
    return mu, log_std, v
def forward(self, observation):
    lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
    obs = observation.view(T * B, *img_shape)
    obs = (obs.permute(0, 3, 1, 2).float() / 255.) * 2 - 1
    z_fake = torch.normal(torch.zeros(obs.shape[0], self.zdim),
                          torch.ones(obs.shape[0], self.zdim)).to(obs.device)
    z_real = self.e(obs).reshape(obs.shape[0], self.zdim)
    x_fake = self.g(z_fake).reshape(obs.shape[0], -1)
    x_real = obs.view(obs.shape[0], -1)
    # extractor_in = z_fake
    # if self.detach_encoder:
    #     extractor_in = extractor_in.detach()
    # extractor_out = self.shared_extractor(extractor_in)
    # if self.detach_policy:
    #     policy_in = extractor_out.detach()
    # else:
    #     policy_in = extractor_out
    # if self.detach_value:
    #     value_in = extractor_out.detach()
    # else:
    #     value_in = extractor_out
    # act_dist = self.policy(policy_in)
    # value = self.value(value_in).squeeze(-1)
    act_dist, value = 0, 0
    label_real = self.d(z_real, x_real)
    label_fake = self.d(z_fake, x_fake)
    latent, reconstruction = restore_leading_dims((z_fake, x_fake), lead_dim, T, B)
    return act_dist, value, latent, reconstruction, label_real, label_fake
def forward(self, observation, prev_action, prev_reward):
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    obs_flat = observation.view(T * B, -1)
    pi = F.softmax(self.pi(obs_flat), dim=-1)
    v = self.v(obs_flat).squeeze(-1)
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    return pi, v
def forward(self, observation, prev_action, prev_reward):
    """
    Compute mean, log_std, and value estimates from input state. Includes
    value estimates of both distinct intrinsic and extrinsic returns.
    See rlpyt MujocoFfModel for more information on this function.
    """
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    if self.normalize_observation:
        obs_var = self.obs_rms.var
        if self.norm_obs_var_clip is not None:
            obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
        observation = torch.clamp(
            (observation - self.obs_rms.mean) / obs_var.sqrt(),
            -self.norm_obs_clip, self.norm_obs_clip)
    obs_flat = observation.view(T * B, -1)
    mu = self.mu(obs_flat)
    ev = self.v(obs_flat).squeeze(-1)
    iv = self.iv(obs_flat).squeeze(-1)  # Added intrinsic value MLP forward pass
    log_std = self.log_std.repeat(T * B, 1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, log_std, ev, iv = restore_leading_dims((mu, log_std, ev, iv), lead_dim, T, B)
    return mu, log_std, ev, iv
def forward(self, image, prev_action=None, prev_reward=None):
    img = image.type(torch.float)
    img = img.mul_(1. / 255)
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    img = img.view(T * B, *img_shape)
    relu = torch.nn.functional.relu
    y = relu(self.conv1(img))
    y = self.attention(y)
    y = relu(self.conv2(y))
    y = relu(self.conv3(y))
    fc_out = self.fc(y.view(T * B, -1))
    pi = torch.nn.functional.softmax(self.pi(fc_out), dim=-1)
    v = self.value(fc_out).squeeze(-1)
    # Restore leading dimensions ([T,B], [B], or []) as in the input;
    # T is the time dimension, B is the batch dimension.
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    return pi, v
def forward(self, image, prev_action=None, prev_reward=None):
    img = image.type(torch.float)
    img = img.mul_(1. / 255)
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    img = img.view(T * B, *img_shape)
    relu = torch.nn.functional.relu
    l1 = relu(self.conv1(img))
    l2 = relu(self.conv2(l1))
    y = relu(self.conv3(l2))
    # fc_out = self.fc(y.view(T * B, -1))
    fc_out = self.fc(y)
    l1 = self.projector(l1)
    l2 = self.projector2(l2)
    c1, g1 = self.attn1(l1, fc_out)
    c2, g2 = self.attn2(l2, fc_out)
    g = torch.cat((g1, g2), dim=1)  # [batch_size, C]
    # classification layer
    pi = torch.nn.functional.softmax(self.pi(g), dim=-1)
    v = self.value(g).squeeze(-1)
    # Restore leading dimensions ([T,B], [B], or []) as in the input;
    # T is the time dimension, B is the batch dimension.
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    return pi, v
def forward(self, observation, prev_action, prev_reward):
    """
    Compute action Q-value estimates from input state. Infers leading
    dimensions of input: can be [T,B], [B], or []; provides returns with
    same leading dims. Convolution layers process as [T*B, *image_shape],
    with T=1, B=1 when not given. Expects uint8 images in [0,255] and
    converts them to float32 (no rescaling to [0,1] is applied here).
    Used in both sampler and in algorithm (both via the agent).
    """
    img = observation.type(torch.float)  # Expect torch.uint8 inputs
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    # features = self.mlp(img.reshape(T * B, -1))
    conv_out = self.conv(img.view(T * B, *img_shape))  # Fold if T dimension.
    # q = self.head(self.dropout(conv_out.view(T * B, -1)))
    q = self.head(conv_out.view(T * B, -1))
    # q = self.head(features)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    q = restore_leading_dims(q, lead_dim, T, B)
    return q
def forward(self, observation, prev_action, prev_reward):
    """
    Compute mean, log_std, q-value, and termination estimates from input
    state. Infers leading dimensions of input: can be [T,B], [B], or [];
    provides returns with same leading dims. Intermediate feedforward
    layers process as [T*B,H], with T=1, B=1 when not given. Used both in
    sampler and in algorithm (both via the agent).
    """
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    if self.normalize_observation:
        obs_var = self.obs_rms.var
        if self.norm_obs_var_clip is not None:
            obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
        observation = torch.clamp(
            (observation - self.obs_rms.mean) / obs_var.sqrt(),
            -self.norm_obs_clip, self.norm_obs_clip)
    obs_flat = observation.view(T * B, -1)
    mu, logstd = self.mu(obs_flat)
    q = self.q(obs_flat)
    log_std = logstd.repeat(T * B, 1, 1)
    beta = self.beta(obs_flat)
    pi = self.pi_omega(obs_flat)
    I = self.pi_omega_I(obs_flat)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, log_std, q, beta, pi, I = restore_leading_dims(
        (mu, log_std, q, beta, pi, I), lead_dim, T, B)
    pi = pi * I  # Torch multinomial will normalize
    return mu, log_std, beta, q, pi
def forward(self, observation, prev_action, prev_reward):
    """Feedforward layers process as [T*B,H]. Return same leading dims as
    input, can be [T,B], [B], or []."""
    observation = self.encoder(observation)
    if self.pooling is not None:
        if self.pooling == 'average':
            pooled = observation.mean(-2)
        elif self.pooling == 'max':
            pooled = observation.max(-2).values  # max(dim) returns (values, indices)
        pooled = pooled.unsqueeze(-2).expand_as(observation)
        observation = torch.cat([observation, pooled], dim=-1)
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    obs_flat = observation.view(T * B * self._n_pop, -1)
    mu = self.mu(obs_flat)
    v = self.v(obs_flat).squeeze(-1)
    log_std = self.log_std.repeat(T * B, 1)
    mu = mu.view(T * B, self._n_pop, -1)
    v = v.view(T * B, self._n_pop)
    log_std = log_std.view(T * B, self._n_pop, -1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)
    return mu, log_std, v
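# A quick standalone illustration of the pooling branch above: Tensor.mean(dim)
# returns a tensor, while Tensor.max(dim) returns a (values, indices) named
# tuple, so only the .values field should be broadcast back over the pooled
# dimension. Shapes here are illustrative, not the model's.
import torch

x = torch.randn(4, 6, 32)                    # [..., n_pop, features]
avg = x.mean(-2)                             # [4, 32]
mx = x.max(-2).values                        # [4, 32]; .max(-2) alone is a tuple
pooled = mx.unsqueeze(-2).expand_as(x)       # broadcast back to [4, 6, 32]
x_aug = torch.cat([x, pooled], dim=-1)       # [4, 6, 64]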
def forward(self, observation, prev_action, prev_reward, state):
    lead_dim, T, B, _ = infer_leading_dims(observation.state, 1)
    # print(f'step in episode {observation.step_in_episode}')
    aux_loss = None
    if T == 1:
        transformer_output, state = self.sample_forward(observation.state, state)
        value = torch.zeros(B)
    elif T == self.sequence_length:
        transformer_output, aux_loss = self.optim_forward(observation.state, state)
        value = self.value_head(transformer_output).reshape(T * B, -1)
    else:
        raise NotImplementedError
    pi_output = self.pi_head(transformer_output).view(T * B, -1)
    mu, std = pi_output[:, :self.action_size], pi_output[:, self.action_size:]
    std = self.softplus(std - 1)
    mu = torch.tanh(mu)
    # pi_output = self.softplus(pi_output * 1) + 1
    # mu, std = pi_output[:, :self.action_size], pi_output[:, self.action_size:]
    mu, std, value = restore_leading_dims((mu, std, value), lead_dim, T, B)
    return mu, std, value, state, aux_loss
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    if observation.dtype == torch.uint8:
        img = observation.type(torch.float)
        img = img.mul_(1.0 / 255)
    else:
        img = observation
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    conv = self.conv(img.view(T * B, *img_shape))
    if self.stop_conv_grad:
        conv = conv.detach()
    fc1 = F.relu(self.fc1(conv.view(T * B, -1)))
    lstm_input = torch.cat(
        [
            fc1.view(T, B, -1),
            prev_action.view(T, B, -1),  # Assumed onehot
            prev_reward.view(T, B, 1),
        ],
        dim=2,
    )
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
    if self._skip_lstm:
        lstm_out = lstm_out.view(T * B, -1) + fc1
    pi_v = self.pi_v_head(lstm_out.view(T * B, -1))
    pi = F.softmax(pi_v[:, :-1], dim=-1)
    v = pi_v[:, -1]
    pi, v, conv = restore_leading_dims((pi, v, conv), lead_dim, T, B)
    next_rnn_state = RnnState(h=hn, c=cn)
    return pi, v, next_rnn_state, conv
def forward(self, observation, prev_action=None, prev_reward=None):
    lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
    obs = observation.view(T * B, *img_shape)
    z, mu, logsd = self.encoder(obs)
    reconstruction = self.decoder(z)
    extractor_in = mu if self.deterministic else z
    if self.detach_vae:
        extractor_in = extractor_in.detach()
    extractor_out = self.shared_extractor(extractor_in)
    if self.detach_policy:
        policy_in = extractor_out.detach()
    else:
        policy_in = extractor_out
    if self.detach_value:
        value_in = extractor_out.detach()
    else:
        value_in = extractor_out
    act_dist = self.policy(policy_in)
    value = self.value(value_in).squeeze(-1)
    if self.rae:
        latent = mu
    else:
        latent = torch.cat((mu, logsd), dim=1)
    act_dist, value, latent, reconstruction = restore_leading_dims(
        (act_dist, value, latent, reconstruction), lead_dim, T, B)
    return act_dist, value, latent, reconstruction
def forward(self, image, prev_action, prev_reward):
    """
    Compute action probabilities and value estimate from input state.
    Infers leading dimensions of input: can be [T,B], [B], or []; provides
    returns with same leading dims. Convolution layers process as
    [T*B, *image_shape], with T=1, B=1 when not given. Expects uint8 images
    in [0,255] and converts them to float32 in [0,1] (to minimize image
    data storage and transfer). Used in both sampler and in algorithm
    (both via the agent).
    """
    img = image.type(torch.float)  # Expect torch.uint8 inputs
    img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    fc_out = self.conv(img.view(T * B, *img_shape))
    pi = self.pi(fc_out)
    q = self.q(fc_out)
    beta = self.beta(fc_out)
    pi_omega_I = self.pi_omega(fc_out)
    I = self.pi_omega_I(fc_out)
    pi_omega_I = pi_omega_I * I
    # Restore leading dimensions: [T,B], [B], or [], as input.
    pi, q, beta, pi_omega_I = restore_leading_dims(
        (pi, q, beta, pi_omega_I), lead_dim, T, B)
    return pi, q, beta, pi_omega_I
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    """
    Compute action probabilities and value estimate.
    NOTE: Rnn concatenates previous action and reward to input.
    """
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    # Convert init_rnn_state appropriately
    if init_rnn_state is not None:
        if self.rnn_type == 'gru':
            init_rnn_state = init_rnn_state.h  # namedarraytuple -> h
        else:
            init_rnn_state = tuple(init_rnn_state)  # namedarraytuple -> tuple (h, c)
    oh = self.preprocessor(observation)
    # Leave in TxB format for lstm
    rnn_input = torch.cat([
        oh.view(T, B, -1),
        prev_action.view(T, B, -1),  # Assumed onehot.
        prev_reward.view(T, B, 1),
    ], dim=2)
    rnn_out, h = self.rnn(rnn_input, init_rnn_state)
    rnn_out = rnn_out.view(T * B, -1)
    pi, beta, q, pi_omega, q_ent = self.model(rnn_out)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    pi, beta, q, pi_omega, q_ent = restore_leading_dims(
        (pi, beta, q, pi_omega, q_ent), lead_dim, T, B)
    # Model should always leave B-dimension in rnn state: [N,B,H].
    if self.rnn_type == 'gru':
        next_rnn_state = GruState(h=h)
    else:
        next_rnn_state = RnnState(*h)
    return pi, beta, q, pi_omega, next_rnn_state
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    """Feedforward layers process as [T*B,H]. Return same leading dims as
    input, can be [T,B], [B], or []."""
    img = observation.type(torch.float)  # Expect torch.uint8 inputs
    img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    conv_out = self.conv(img.view(T * B, *img_shape))  # Fold if T dimension.
    lstm_input = torch.cat(
        [
            conv_out.view(T, B, -1),
            prev_action.view(T, B, -1),  # Assumed onehot.
            prev_reward.view(T, B, 1),
        ], dim=2)
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
    q = self.head(lstm_out.view(T * B, -1))
    # Restore leading dimensions: [T,B], [B], or [], as input.
    q = restore_leading_dims(q, lead_dim, T, B)
    # Model should always leave B-dimension in rnn state: [N,B,H].
    next_rnn_state = RnnState(h=hn, c=cn)
    return q, next_rnn_state
def forward(self, observation, prev_action, prev_reward):
    """
    Compute mean, log_std, and value estimate from input state. Infers
    leading dimensions of input: can be [T,B], [B], or []; provides
    returns with same leading dims. Intermediate feedforward layers
    process as [T*B,H], with T=1, B=1 when not given. Used both in
    sampler and in algorithm (both via the agent).
    """
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    if self.normalize_observation:
        obs_var = self.obs_rms.var
        if self.norm_obs_var_clip is not None:
            obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
        observation = torch.clamp(
            (observation - self.obs_rms.mean) / obs_var.sqrt(),
            -self.norm_obs_clip, self.norm_obs_clip)
    obs_flat = observation.view(T * B, -1)
    mu = self.mu(obs_flat)
    v = self.v(obs_flat).squeeze(-1)
    log_std = self.log_std.repeat(T * B, 1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)
    return mu, log_std, v
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    """Feedforward layers process as [T*B,H]. Return same leading dims as
    input, can be [T,B], [B], or []."""
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    obs_shape, T, B, has_T, has_B = infer_leading_dims(
        observation, self._obs_n_dim)
    mlp_out = self.mlp(observation.view(T * B, -1))
    lstm_input = torch.cat([
        mlp_out.view(T, B, -1),
        prev_action.view(T, B, -1),
        prev_reward.view(T, B, 1),
    ], dim=2)
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
    outputs = self.head(lstm_out.view(T * B, -1))  # Fold [T,B] before the head.
    mu = outputs[:, :self._action_size]
    log_std = outputs[:, self._action_size:-1]
    v = outputs[:, -1].squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, log_std, v = restore_leading_dims((mu, log_std, v), T, B, has_T, has_B)
    # Model should always leave B-dimension in rnn state: [N,B,H]
    next_rnn_state = RnnState(h=hn, c=cn)
    return mu, log_std, v, next_rnn_state
def forward(self, observation, prev_action, prev_reward):
    if observation.dtype == torch.uint8:
        img = observation.type(torch.float)
        img = img.mul_(1.0 / 255)
    else:
        img = observation
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    conv = self.conv(img.view(T * B, *img_shape))
    if self.stop_conv_grad:
        conv = conv.detach()
    if self.normalize_conv_out:
        conv_var = self.conv_rms.var
        conv_var = torch.clamp(conv_var, min=self.var_clip)
        # stddev of uniform [a,b] = (b-a)/sqrt(12), 1/sqrt(12) ~ 0.29
        # then allow [0, 10]?
        conv = torch.clamp(0.29 * conv / conv_var.sqrt(), 0, 10)
    pi_v = self.pi_v_mlp(conv.view(T * B, -1))
    pi = F.softmax(pi_v[:, :-1], dim=-1)
    v = pi_v[:, -1]
    pi, v, conv = restore_leading_dims((pi, v, conv), lead_dim, T, B)
    return pi, v, conv
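# A standalone check of the constant used above: a uniform distribution on
# [a, b] has standard deviation (b - a) / sqrt(12), so for [0, 1] that is
# 1 / sqrt(12) ~= 0.2887, which the code rounds to 0.29 when rescaling the
# (roughly unit-range) conv activations by their running standard deviation.
import math

std_uniform_01 = 1 / math.sqrt(12)
print(f"{std_uniform_01:.4f}")  # 0.2887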
def forward(self, observation, prev_action, prev_reward):
    obs_shape, T, B, has_T, has_B = infer_leading_dims(
        observation, self._obs_ndim)
    mu = self._output_max * torch.tanh(self.mlp(observation.view(T * B, -1)))
    mu = restore_leading_dims(mu, T, B, has_T, has_B)
    return mu
def forward(self, observation, prev_action, prev_reward):
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    output = self.mlp(observation.view(T * B, -1))
    mu, log_std = output[:, :self._action_size], output[:, self._action_size:]
    mu, log_std = restore_leading_dims((mu, log_std), lead_dim, T, B)
    return mu, log_std