def forward(self, obs, extract_sym_features: bool = False):
    imgs = obs.img
    lead_dim, T, B, img_shape = infer_leading_dims(imgs, 3)
    imgs = imgs.view(T * B, *img_shape)
    # NHWC -> NCHW
    imgs = imgs.transpose(3, 2).transpose(2, 1)
    # TODO: don't do uint8 -> float32 conversion twice (in detector and here)
    if self.sym_extractor is not None and extract_sym_features:
        sym_feats = self.sym_extractor(imgs[:, -1:])
    else:
        sym_feats = None
    # convert from [0, 255] uint8 to [0, 1] float32
    imgs = imgs.to(torch.float32)
    imgs.mul_(1.0 / 255)
    feats = self.feature_extractor(imgs)
    vector_obs = obs.vector.view(-1, obs.vector.shape[-1])
    feats = torch.cat((feats, vector_obs), -1)
    feats = self.relu(self.linear(feats))
    value = self.value_predictor(feats).squeeze(-1)  # squeeze the singleton value dim
    if self.categorical:
        action_dist = self.action_predictor(feats)
        action_dist, value = restore_leading_dims((action_dist, value), lead_dim, T, B)
        if self.sym_extractor is not None and extract_sym_features:
            # restore leading dims for sym_feats manually
            sym_feats = sym_feats.reshape((*action_dist.shape[:-1], -1))
        return action_dist, value, sym_feats
    else:
        mu = self.mu_predictor(feats)
        log_std = self.log_std_predictor(feats)
        mu, log_std, value = restore_leading_dims((mu, log_std, value), lead_dim, T, B)
        if self.sym_extractor is not None and extract_sym_features:
            # restore leading dims for sym_feats manually
            sym_feats = sym_feats.reshape((*mu.shape[:-1], -1))
        return mu, log_std, value, sym_feats
def forward(self, observation, prev_action, prev_reward):
    """
    Feedforward layers process as [T*B,H]. Returns the same leading dims as
    the input, which can be [T,B], [B], or [].

    forward() is called implicitly from the agent's (DqnAgent) step() method:
    because torch.nn.Module defines __call__(), calling
    self.model(*model_inputs) inside DqnAgent.step() is equivalent to calling
    forward().
    """
    img = observation.type(torch.float)  # Expect torch.uint8 inputs
    img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    conv_out = self.conv(img.view(T * B, *img_shape))  # Fold if T dimension.
    q = self.head(conv_out.view(T * B, -1))
    # Restore leading dimensions: [T,B], [B], or [], as input.
    q = restore_leading_dims(q, lead_dim, T, B)
    return q
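Most of the forward methods in this collection repeat the same leading-dims bookkeeping. Below is a minimal, self-contained sketch of that contract, assuming rlpyt's infer_leading_dims / restore_leading_dims helpers (imported here from rlpyt.utils.tensor); the tensor shapes are made up purely for illustration.

import torch
from rlpyt.utils.tensor import infer_leading_dims, restore_leading_dims

# Three ways the sampler/algorithm may pass a [C,H,W] image observation:
obs_tb = torch.zeros(5, 8, 4, 84, 84, dtype=torch.uint8)  # [T,B,C,H,W]
obs_b = torch.zeros(8, 4, 84, 84, dtype=torch.uint8)      # [B,C,H,W]
obs_0 = torch.zeros(4, 84, 84, dtype=torch.uint8)         # [C,H,W]

for obs in (obs_tb, obs_b, obs_0):
    # 3 = number of trailing data dims (the image dims).
    lead_dim, T, B, img_shape = infer_leading_dims(obs, 3)
    flat = obs.view(T * B, *img_shape)        # fold leading dims: [T*B,C,H,W]
    q = torch.zeros(T * B, 6)                 # stand-in for a network output
    q = restore_leading_dims(q, lead_dim, T, B)
    print(lead_dim, tuple(q.shape))           # 2 (5, 8, 6) / 1 (8, 6) / 0 (6,)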
def forward(self, image, prev_action=None, prev_reward=None):
    # Input normalization: cast to float, then scale to [0, 1].
    img = image.type(torch.float)
    img = img.mul_(1. / 255)
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    img = img.view(T * B, *img_shape)
    if self.augment_obs is not None:
        b, c, h, w = img.shape
        mask_vbox = torch.zeros(size=img.shape, dtype=torch.bool, device=img.device)
        mh = math.ceil(h * 0.20)  # two squares side by side
        mw = mh * 2
        # Create velocity mask -> True where the velocity box is, False on the
        # rest of the screen.
        vmask = torch.ones((b, c, mh, mw), dtype=torch.bool, device=img.device)
        mask_vbox[:, :, :mh, :mw] = vmask
        obs_without_vbox = torch.where(mask_vbox, torch.zeros_like(img), img)
        if self.augment_obs == 'cutout':
            augmented = random_cutout_color(obs_without_vbox)
        elif self.augment_obs == 'jitter':
            if self.transform is None:
                self.transform = ColorJitterLayer(b)
            augmented = self.transform(obs_without_vbox)
        elif self.augment_obs == 'rand_conv':
            augmented = random_convolution(obs_without_vbox)
        # Paste the original velocity-box pixels back over the augmented image.
        fixed = torch.where(mask_vbox, img, augmented)
        img = fixed
    fc_out = self.conv(img)
    pi = F.softmax(self.pi(fc_out), dim=-1)
    v = self.value(fc_out).squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    # T: time dimension, B: batch dimension.
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    return pi, v
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    """Feedforward layers process as [T*B,H]. Return same leading dims as
    input, can be [T,B], [B], or []."""
    img = observation.type(torch.float)  # Expect torch.uint8 inputs
    img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    conv_out = self.conv(img.view(T * B, *img_shape))  # Fold if T dimension.
    rnn_input = torch.cat(
        [
            conv_out.view(T, B, -1),
            prev_action.view(T, B, -1),  # Assumed onehot.
            prev_reward.view(T, B, 1),
        ],
        dim=2)
    if self.use_recurrence:
        init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    else:
        init_rnn_state = None
    rnn_out, (hn,) = self.run_rnn(rnn_input, init_rnn_state)
    q = self.head(rnn_out.view(T * B, -1))
    # Restore leading dimensions: [T,B], [B], or [], as input.
    q = restore_leading_dims(q, lead_dim, T, B)
    # Model should always leave B-dimension in rnn state: [N,B,H].
    next_rnn_state = RnnState(h=hn)
    return q, next_rnn_state
def forward(self, observation, prev_action, prev_reward):
    """
    Compute mean, log_std, and value estimate from input state.  Infers
    leading dimensions of input: can be [T,B], [B], or []; provides returns
    with same leading dims.  Intermediate feedforward layers process as
    [T*B,H], with T=1,B=1 when not given.  Used both in sampler and in
    algorithm (both via the agent).
    """
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    obs_flat = observation.view(T * B, -1)
    mu = self.mu(obs_flat)
    v = self.v(obs_flat).squeeze(-1)
    log_std = self.log_std.repeat(T * B, 1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)
    return mu, log_std, v
def forward(self, obs, extract_sym_features: bool = False):
    lead_dim, T, B, _ = infer_leading_dims(obs, 1)
    feats = self.feature_extractor(obs.view(T * B, -1))
    value = self.value_predictor(feats).squeeze(-1)  # squeeze the singleton value dim
    sym_feats = obs if extract_sym_features else None
    if self.categorical:
        action_dist = self.action_predictor(feats)
        action_dist, value = restore_leading_dims((action_dist, value), lead_dim, T, B)
        return action_dist, value, sym_feats
    else:
        mu = self.mu_predictor(feats)
        log_std = self.log_std_predictor(feats)
        mu, log_std, value = restore_leading_dims((mu, log_std, value), lead_dim, T, B)
        return mu, log_std, value, sym_feats
def forward(self, image, prev_action, prev_reward):
    """
    Compute action probabilities and value estimate from input state.
    Infers leading dimensions of input: can be [T,B], [B], or []; provides
    returns with same leading dims.  Convolution layers process as
    [T*B, *image_shape], with T=1,B=1 when not given.  Expects uint8 images
    in [0,255] and converts them to float32 in [0,1] (to minimize image data
    storage and transfer).  Used in both sampler and in algorithm (both via
    the agent).
    """
    img = image.type(torch.float)  # Expect torch.uint8 inputs
    img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    fc_out = self.conv(img.view(T * B, *img_shape))
    pi = F.softmax(self.pi(fc_out), dim=-1)
    v = self.value(fc_out).squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    return pi, v
def forward(self, obs, q_or_pi: str, extract_sym_features: bool = False):
    """
    :returns: (q1, q2, sym_features) or (mu, log_std, sym_features)
    """
    if q_or_pi == "q":
        obs, actions = obs
    imgs = obs.img
    lead_dim, T, B, img_shape = infer_leading_dims(imgs, 3)
    imgs = imgs.view(T * B, *img_shape)
    # NHWC -> NCHW
    imgs = imgs.transpose(3, 2).transpose(2, 1)
    # TODO: don't do uint8 -> float32 conversion twice (in detector and here)
    if self.sym_extractor is not None and extract_sym_features:
        sym_feats = self.sym_extractor(imgs[:, -1:])
    else:
        sym_feats = None
    # convert from [0, 255] uint8 to [0, 1] float32
    imgs = imgs.to(torch.float32)
    imgs = imgs.mul(1.0 / 255)
    vector_obs = obs.vector.view(-1, obs.vector.shape[-1])
    feats = self.feature_extractor(imgs, vector_obs)
    if q_or_pi == "q":
        feats = torch.cat((feats, actions.view(-1, actions.shape[-1])), -1)
        r1 = self.q1(feats)
        r2 = self.q2(feats)
    elif q_or_pi == "pi":
        r = self.pi(feats)
        r1 = r[:, :self.action_dim]  # mu
        r2 = r[:, self.action_dim:]  # log_std
    else:
        raise ValueError("q_or_pi must be 'q' or 'pi'.")
    r1, r2 = restore_leading_dims((r1, r2), lead_dim, T, B)
    if self.sym_extractor is not None and extract_sym_features:
        # restore leading dims for sym_feats manually
        sym_feats = sym_feats.reshape((*r1.shape[:-1], -1))
    return r1, r2, sym_feats
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    """
    Compute mean, log_std, and value estimate from input state.  Infers
    leading dimensions of input: can be [T,B], [B], or []; provides returns
    with same leading dims.  Intermediate feedforward layers process as
    [T*B,H], and recurrent layers as [T,B,H], with T=1,B=1 when not given.
    Used both in sampler and in algorithm (both via the agent).  Also returns
    the next RNN state.
    """
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_n_dim)
    if self.normalize_observation:
        obs_var = self.obs_rms.var
        if self.norm_obs_var_clip is not None:
            obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
        observation = torch.clamp(
            (observation - self.obs_rms.mean) / obs_var.sqrt(),
            -self.norm_obs_clip, self.norm_obs_clip)
    mlp_out = self.mlp(observation.view(T * B, -1))
    lstm_input = torch.cat([
        mlp_out.view(T, B, -1),
        prev_action.view(T, B, -1),
        prev_reward.view(T, B, 1),
    ], dim=2)
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
    outputs = self.head(lstm_out.view(T * B, -1))
    mu = outputs[:, :self._action_size]
    log_std = outputs[:, self._action_size:-1]
    v = outputs[:, -1]
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, log_std, v = restore_leading_dims((mu, log_std, v), lead_dim, T, B)
    # Model should always leave B-dimension in rnn state: [N,B,H].
    next_rnn_state = RnnState(h=hn, c=cn)
    return mu, log_std, v, next_rnn_state
def next(self, actions, observation, prev_action, prev_reward):
    if isinstance(observation, tuple):
        observation = torch.cat(observation, dim=-1)
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    input_obs = observation.view(T * B, -1)
    if self._counter == 0:
        output = self.mlp_loc(input_obs)
        mu, log_std = output.chunk(2, dim=-1)
    elif self._counter == 1:
        assert len(actions) == 1
        action_loc = actions[0].view(T * B, -1)
        model_input = torch.cat(
            (input_obs, action_loc.repeat((1, self._n_tile))), dim=-1)
        output = self.mlp_delta(model_input)
        mu, log_std = output.chunk(2, dim=-1)
    else:
        raise Exception('Invalid self._counter', self._counter)
    mu, log_std = restore_leading_dims((mu, log_std), lead_dim, T, B)
    self._counter += 1
    return mu, log_std
def forward(self, image, prev_action, prev_reward):
    """
    Overrides AtariFfModel forward to also run a separate intrinsic value head.
    """
    img = image.type(torch.float)  # Expect torch.uint8 inputs
    img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    fc_out = self.conv(img.view(T * B, *img_shape))
    pi = F.softmax(self.pi(fc_out), dim=-1)
    ext_val = self.value(fc_out).squeeze(-1)
    int_val = self.int_value(fc_out).squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    pi, ext_val, int_val = restore_leading_dims(
        (pi, ext_val, int_val), lead_dim, T, B)
    return pi, ext_val, int_val
def forward(self, observation: torch.Tensor, prev_action: torch.Tensor = None,
            prev_state: RSSMState = None):
    lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
    observation = observation.reshape(T * B, *img_shape).type(self.dtype) / 255.0 - 0.5
    prev_action = prev_action.reshape(T * B, -1).to(self.dtype)
    if prev_state is None:
        prev_state = self.representation.initial_state(
            prev_action.size(0), device=prev_action.device, dtype=self.dtype)
    state = self.get_state_representation(observation, prev_action, prev_state)
    action, action_dist = self.policy(state)
    return_spec = ModelReturnSpec(action, state)
    return_spec = buffer_func(return_spec, restore_leading_dims, lead_dim, T, B)
    return return_spec
def write_videos(self, observation, action, image_pred, post, step=None, n=4, t=25):
    """
    observation shape: T,N,C,H,W

    Generates n rollouts with the model.  For the first t time steps,
    observations are used to generate state representations; for time steps
    t+1:T, the state transition model is used.  Outputs 3 rows of frames to
    video: ground truth, reconstruction, and error.
    """
    lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(observation, 3)
    model = self.agent.model
    ground_truth = observation[:, :n] + 0.5
    reconstruction = image_pred.mean[:t, :n]
    prev_state = post[t - 1, :n]
    prior = model.rollout.rollout_transition(batch_t - t, action[t:, :n], prev_state)
    imagined = model.observation_decoder(get_feat(prior)).mean
    prediction = torch.cat((reconstruction, imagined), dim=0) + 0.5
    error = (prediction - ground_truth + 1) / 2
    # Concatenate vertically on the height dimension.
    openl = torch.cat((ground_truth, prediction, error), dim=3)
    openl = openl.transpose(1, 0)  # N,T,C,H,W
    video_summary('videos/model_error', torch.clamp(openl, 0., 1.), step)
def forward(self, observation, prev_action, prev_reward):
    """
    Compute action Q-value estimates from input state.  Infers leading
    dimensions of input: can be [T,B], [B], or []; provides returns with same
    leading dims.  Convolution layers process as
    [T*B, image_shape[0], image_shape[1], ..., image_shape[-1]], with T=1,B=1
    when not given.  Expects uint8 images in [0,255] and converts them to
    float32 in [0,1] (to minimize image data storage and transfer).  Used in
    both sampler and in algorithm (both via the agent).
    """
    img = observation.type(torch.float)  # Expect torch.uint8 inputs
    img = img.mul_(1. / 255)  # From [0-255] to [0-1], in place.
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    conv_out = self.conv(img.view(T * B, *img_shape))  # Fold if T dimension.
    q = self.head(conv_out.view(T * B, -1))
    # Restore leading dimensions: [T,B], [B], or [], as input.
    q = restore_leading_dims(q, lead_dim, T, B)
    return q
def forward(self, image, prev_action, prev_reward, init_rnn_state):
    """
    Compute action probabilities and value estimate from input state.
    Infers leading dimensions of input: can be [T,B], [B], or []; provides
    returns with same leading dims.  Convolution layers process as
    [T*B, *image_shape], with T=1,B=1 when not given.  Expects uint8 images
    in [0,255] and converts them to float32 in [0,1] (to minimize image data
    storage and transfer).  Recurrent layers processed as [T,B,H].  Used in
    both sampler and in algorithm (both via the agent).  Also returns the
    next RNN state.
    """
    if self.obs_stats is not None:
        # Normalize observation with precomputed statistics.
        image = (image - self.obs_mean) / self.obs_std
    img = image.type(torch.float)  # Expect torch.uint8 inputs
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    fc_out = self.conv(img.view(T * B, *img_shape))
    lstm_input = torch.cat(
        [
            fc_out.view(T, B, -1),
            prev_action.view(T, B, -1),  # Assumed onehot.
        ],
        dim=2)
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
    pi = F.softmax(self.pi(lstm_out.view(T * B, -1)), dim=-1)
    v = self.value(lstm_out.view(T * B, -1)).squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    # Model should always leave B-dimension in rnn state: [N,B,H].
    next_rnn_state = RnnState(h=hn, c=cn)
    return pi, v, next_rnn_state
def forward(self, observation, prev_action=None, prev_reward=None):
    self.device = observation.device
    lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
    obs = observation.view(T * B, *img_shape)
    obs = obs.permute(0, 3, 1, 2).float() / 255.  # NHWC -> NCHW, scale to [0, 1]
    if self.greyscale:
        obs = torch.mean(obs, dim=1, keepdim=True)
    noise_idx = None
    if self.noise_prob:
        obs, noise_idx = salt_and_pepper(obs, self.noise_prob)
    z, mu, logsd = self.encoder(obs)
    reconstruction = self.decoder(z)
    extractor_in = mu if self.deterministic else z
    if self.detach_vae:
        extractor_in = extractor_in.detach()
    extractor_out = self.shared_extractor(extractor_in)
    policy_in = extractor_out.detach() if self.detach_policy else extractor_out
    value_in = extractor_out.detach() if self.detach_value else extractor_out
    act_dist = self.policy(policy_in)
    value = self.value(value_in).squeeze(-1)
    latent = mu if self.rae else torch.cat((mu, logsd), dim=1)
    act_dist, value, latent, reconstruction = restore_leading_dims(
        (act_dist, value, latent, reconstruction), lead_dim, T, B)
    return act_dist, value, latent, reconstruction, noise_idx
def forward(self, observation, prev_action, prev_reward):
    """
    Compute mean, log_std, q-value, and termination estimates from input
    state.  Infers leading dimensions of input: can be [T,B], [B], or [];
    provides returns with same leading dims.  Intermediate feedforward layers
    process as [T*B,H], with T=1,B=1 when not given.  Used both in sampler
    and in algorithm (both via the agent).
    """
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    if self.normalize_observation:
        obs_var = self.obs_rms.var
        if self.norm_obs_var_clip is not None:
            obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
        observation = torch.clamp(
            (observation - self.obs_rms.mean) / obs_var.sqrt(),
            -self.norm_obs_clip, self.norm_obs_clip)
    obs_flat = observation.view(T * B, -1)
    (mu, logstd), beta, q, pi_I, q_ent = self.model(obs_flat)
    log_std = logstd.repeat(T * B, 1, 1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    mu, log_std, q, beta, pi, q_ent = restore_leading_dims(
        (mu, log_std, q, beta, pi_I, q_ent), lead_dim, T, B)
    return mu, log_std, beta, q, pi
def head_forward(self, conv_out, prev_action, prev_reward, logits=False):
    if len(conv_out.shape) > 2:
        lead_dim, T, B, img_shape = infer_leading_dims(conv_out, 3)  # e.g. 1, 1, 32, [64, 7, 7]
    p = self.head(conv_out)  # e.g. [32, 4, 51]
    if self.distributional:
        if logits:
            p = F.log_softmax(p, dim=-1)
        else:
            p = F.softmax(p, dim=-1)
    else:
        p = p.squeeze(-1)
    # Restore leading dimensions: [T,B], [B], or [], as input.
    if len(conv_out.shape) > 2:
        p = restore_leading_dims(p, lead_dim, T, B)  # e.g. [32, 4, 51]
    return p
def forward(self, observation, prev_action, prev_reward):
    if observation.dtype == torch.uint8:
        img = observation.type(torch.float)
        img = img.mul_(1.0 / 255)
    else:
        img = observation
    lead_dim, T, B, img_shape = infer_leading_dims(img, 3)
    conv = self.conv(img.view(T * B, *img_shape))
    if self.stop_conv_grad:
        conv = conv.detach()
    if self.normalize_conv_out:
        conv_var = self.conv_rms.var
        conv_var = torch.clamp(conv_var, min=self.var_clip)
        # stddev of uniform [a,b] is (b-a)/sqrt(12), and 1/sqrt(12) ~ 0.29;
        # then clamp the normalized activations to [0, 10].
        conv = torch.clamp(0.29 * conv / conv_var.sqrt(), 0, 10)
    q = self.q_mlp(conv.view(T * B, -1))
    q, conv = restore_leading_dims((q, conv), lead_dim, T, B)
    return q, conv
def forward(self, observation, prev_action=None, prev_reward=None):
    """
    Compute action Q-value estimates from input state.  Infers leading
    dimensions of input: can be [T,B], [B], or []; provides returns with same
    leading dims.  Convolution layers process as
    [T*B, image_shape[0], image_shape[1], ..., image_shape[-1]], with T=1,B=1
    when not given.  Expects uint8 images in [0,255] and converts them to
    float32.  Used in both sampler and in algorithm (both via the agent).
    """
    obs = observation.type(torch.float)  # Expect torch.uint8 inputs
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    if len(observation.shape) < 3:
        reshape = (observation.shape[0], -1) if len(observation.shape) == 2 else (-1,)
        view = obs.view(*reshape)
        return self.head(view)
    lead_dim, T, B, _ = infer_leading_dims(obs, 3)
    q = self.head(obs.view(T * B, -1))
    # Restore leading dimensions: [T,B], [B], or [], as input.
    q = restore_leading_dims(q, lead_dim, T, B)
    return q
def forward(self, observation, prev_action, prev_reward):
    """
    Mixture-of-experts policy head with noisy top-k gating.

    Args:
        observation: tensor of shape [batch_size, input_size]; the last 19
            entries are interpreted as a boolean action mask.

    Returns:
        y: action probabilities with shape [batch_size, output_size].
        value: state-value estimate.

    During training, the gating network also produces a load-balancing term;
    backpropagating that term (added to the overall training loss) encourages
    all experts to be used approximately equally across a batch.
    """
    train = self.training
    observation = observation.float()
    self.device = observation.device
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, obs_shape = infer_leading_dims(observation, 1)
    observation = observation.view(T * B, *obs_shape)
    action_mask = observation[:, -19:].type(torch.bool)
    observation = observation[:, :-19]
    z = self.encoder(observation)
    gates, load = self.noisy_top_k_gating(z, train)
    dispatcher = SparseDispatcher(self.num_experts, gates)
    expert_inputs = dispatcher.dispatch(z)
    gates = dispatcher.expert_to_gates()
    expert_outputs = [
        self.experts[i](expert_inputs[i]) for i in range(self.num_experts)
    ]
    y = dispatcher.combine(expert_outputs, device=self.device)
    value = self.value(observation).squeeze(-1)
    y[~action_mask] = -1e24  # mask out invalid actions before the softmax
    y = nn.functional.softmax(y, dim=-1)
    y, value = restore_leading_dims((y, value), lead_dim, T, B)
    return y, value
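For reference, a minimal sketch of what one noisy top-k gating step can look like, independent of the SparseDispatcher API used above. The helper name noisy_top_k_gating_sketch, the 2-D [batch, num_experts] logits, and the noise handling are illustrative assumptions, not the model's actual implementation.

import torch
import torch.nn.functional as F

def noisy_top_k_gating_sketch(logits, k, noise_std=1.0, training=True):
    # logits: [batch, num_experts] raw gate scores.
    if training and noise_std > 0:
        logits = logits + torch.randn_like(logits) * noise_std  # exploration noise
    top_vals, top_idx = logits.topk(k, dim=1)                   # keep the k best experts
    gates = torch.zeros_like(logits)
    gates.scatter_(1, top_idx, F.softmax(top_vals, dim=1))      # sparse, normalized gates
    return gates  # a load-balancing loss would be computed from these gates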
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    """Feedforward layers process as [T*B,H]. Return same leading dims as
    input, can be [T,B], [B], or []."""
    observation = observation.type(torch.float)  # Expect torch.uint8 inputs
    # Infer (presence of) leading dimensions: [T,B], [B], or [].
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_n_dim)
    if self.normalize_observation:
        obs_var = self.obs_rms.var
        if self.norm_obs_var_clip is not None:
            obs_var = torch.clamp(obs_var, min=self.norm_obs_var_clip)
        observation = torch.clamp(
            (observation - self.obs_rms.mean) / obs_var.sqrt(),
            -self.norm_obs_clip, self.norm_obs_clip)
    mlp_out = self.mlp(observation.view(T * B, -1))
    lstm_input = torch.cat([
        mlp_out.view(T, B, -1),
        prev_action.view(T, B, -1),
        prev_reward.view(T, B, 1),
    ], dim=2)
    init_rnn_state = None if init_rnn_state is None else tuple(init_rnn_state)
    lstm_out, (hn, cn) = self.lstm(lstm_input, init_rnn_state)
    q = self.head(lstm_out.view(T * B, -1))
    # Restore leading dimensions: [T,B], [B], or [], as input.
    q = restore_leading_dims(q, lead_dim, T, B)
    # Model should always leave B-dimension in rnn state: [N,B,H].
    next_rnn_state = RnnState(h=hn, c=cn)
    return q, next_rnn_state
def forward(self, obs, prev_action, prev_reward):
    """Feedforward layers process as [T*B,H]. Return same leading dims as
    input, can be [T,B], [B], or []."""
    lead_dim, T, B, img_shape = infer_leading_dims(obs.target_im, 3)
    # Encode the target and current images with the shared conv trunk.
    x1 = self.conv(obs.target_im.view(T * B, *img_shape))
    x2 = self.conv(obs.cur_im.view(T * B, *img_shape))
    x = torch.cat((x1, x2), dim=1)
    x = self.fc(x.view(T * B, -1))
    pi = F.softmax(self.pi(x), dim=-1)
    v = self.value(x).squeeze(-1)
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    return pi, v
def forward(self, observation, prev_action, prev_reward, init_rnn_state):
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_dim)
    if init_rnn_state is not None:
        if self.rnn_is_lstm:
            # DualRnnState -> RnnState_pi, RnnState_v
            init_rnn_pi, init_rnn_v = tuple(init_rnn_state)
            init_rnn_pi, init_rnn_v = tuple(init_rnn_pi), tuple(init_rnn_v)
        else:
            init_rnn_pi, init_rnn_v = tuple(init_rnn_state)  # DualRnnState -> h, h
    else:
        init_rnn_pi, init_rnn_v = None, None
    o_flat = observation.view(T * B, -1)
    b_pi, b_v = self.body_pi(o_flat), self.body_v(o_flat)
    rnn_input_pi = torch.cat([
        b_pi.view(T, B, -1),
        prev_action.view(T, B, -1),
        prev_reward.view(T, B, 1),
    ], dim=2)
    rnn_input_v = torch.cat([
        b_v.view(T, B, -1),
        prev_action.view(T, B, -1),
        prev_reward.view(T, B, 1),
    ], dim=2)
    rnn_pi, next_rnn_state_pi = self.rnn_pi(rnn_input_pi, init_rnn_pi)
    rnn_v, next_rnn_state_v = self.rnn_v(rnn_input_v, init_rnn_v)
    rnn_pi = rnn_pi.view(T * B, -1)
    rnn_v = rnn_v.view(T * B, -1)
    pi, v = self.pi(rnn_pi), self.v(rnn_v).squeeze(-1)
    pi, v = restore_leading_dims((pi, v), lead_dim, T, B)
    if self.rnn_is_lstm:
        next_rnn_state = DualRnnState(RnnState(*next_rnn_state_pi),
                                      RnnState(*next_rnn_state_v))
    else:
        next_rnn_state = DualRnnState(next_rnn_state_pi, next_rnn_state_v)
    return pi, v, next_rnn_state
def forward(self, observation, prev_action, prev_reward):
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    v = self.mlp(observation.view(T * B, -1)).squeeze(-1)
    v = restore_leading_dims(v, lead_dim, T, B)
    return v
def forward(self, observation, prev_action, prev_reward):
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    mu = self._output_max * torch.tanh(self.mlp(observation.view(T * B, -1)))
    mu = restore_leading_dims(mu, lead_dim, T, B)
    return mu
def evaluate(self, obs_tuple, act_tensor, update_stats=True):
    # Put the reward model into eval mode if necessary.
    old_training = self.reward_model.training
    if old_training:
        self.reward_model.eval()
    with torch.no_grad():
        # Flatten observations & actions.
        obs_image = obs_tuple.observation
        old_dev = obs_image.device
        lead_dim, T, B, _ = infer_leading_dims(obs_image, self.obs_dims)
        # Use tree_map so we are able to handle the namedtuple directly.
        obs_flat = tree_map(
            lambda t: t.view((T * B,) + t.shape[lead_dim:]), obs_tuple)
        act_flat = act_tensor.view((T * B,) + act_tensor.shape[lead_dim:])
        # Now evaluate one batch at a time.
        reward_tensors = []
        for b_start in range(0, T * B, self.batch_size):
            obs_batch = obs_flat[b_start:b_start + self.batch_size]
            act_batch = act_flat[b_start:b_start + self.batch_size]
            dev_obs = tree_map(lambda t: t.to(self.dev), obs_batch)
            dev_acts = act_batch.to(self.dev)
            dev_reward = self.reward_model(dev_obs, dev_acts)
            reward_tensors.append(dev_reward.to(old_dev))
        # Join together the batch results.
        new_reward_flat = torch.cat(reward_tensors, 0)
        new_reward = restore_leading_dims(new_reward_flat, lead_dim, T, B)
    task_ids = obs_tuple.task_id
    assert new_reward.shape == task_ids.shape
    # Put back into training mode if necessary.
    if old_training:
        self.reward_model.train(old_training)
    # Normalise if necessary.
    if self.normalise:
        mus = []
        stds = []
        for task_id, averager in enumerate(self.rew_running_averages):
            if update_stats:
                id_sub = task_ids.view((-1,)) == task_id
                if not torch.any(id_sub):
                    continue
                rew_sub = new_reward.view((-1,))[id_sub]
                averager.update(rew_sub)
            mus.append(averager.mean.item())
            stds.append(averager.std.item())
        mus = new_reward.new_tensor(mus)
        stds = new_reward.new_tensor(stds)
        denom = torch.max(stds.new_tensor(1e-3), stds / self.target_std)
        denom_sub = denom[task_ids]
        mu_sub = mus[task_ids]
        # Only bother applying the result if we've actually seen an update
        # before (otherwise the reward would be insane).
        new_reward = (new_reward - mu_sub) / denom_sub
    return new_reward
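The per-task averagers in self.rew_running_averages are used above only through .update(), .mean, and .std. A minimal sketch of such an averager, assuming a Chan-style parallel mean/variance update; the class name and internals are illustrative, not the actual implementation.

import torch

class RunningMeanStdSketch:
    def __init__(self):
        self.count = 0
        self.mean = torch.tensor(0.0)
        self._m2 = torch.tensor(0.0)  # sum of squared deviations from the mean

    def update(self, x):
        # Merge a batch of values into the running statistics.
        x = x.view(-1).float()
        n = x.numel()
        if n == 0:
            return
        batch_mean = x.mean()
        batch_m2 = ((x - batch_mean) ** 2).sum()
        delta = batch_mean - self.mean
        total = self.count + n
        self.mean = self.mean + delta * n / total
        self._m2 = self._m2 + batch_m2 + delta ** 2 * self.count * n / total
        self.count = total

    @property
    def std(self):
        return (self._m2 / max(self.count - 1, 1)).sqrt()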
def loss(self, samples: SamplesFromReplay, sample_itr: int, opt_itr: int):
    """
    Compute the loss for a batch of data.  This includes computing the model
    and reward losses on the given data, as well as using the dynamics model
    to generate additional rollouts, which are used for the actor and value
    components of the loss.

    :param samples: samples from replay
    :param sample_itr: sample iteration
    :param opt_itr: optimization iteration
    :return: model_loss, actor_loss, value_loss, loss_info
    """
    model = self.agent.model
    observation = samples.all_observation[:-1]  # [t, t+batch_length+1] -> [t, t+batch_length]
    action = samples.all_action[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
    reward = samples.all_reward[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
    reward = reward.unsqueeze(2)
    done = samples.done
    done = done.unsqueeze(2)

    # Extract tensors from the Samples object.  They all have the batch_t
    # dimension first, and the image observations are converted to floats so
    # they can be fed into our models.
    lead_dim, batch_t, batch_b, img_shape = infer_leading_dims(observation, 3)

    # Squeeze batch sizes to a single batch dimension for the imagination roll-out.
    batch_size = batch_t * batch_b

    # Normalize image.
    observation = observation.type(self.type) / 255.0 - 0.5
    # Embed the image.
    embed = model.observation_encoder(observation)

    prev_state = model.representation.initial_state(
        batch_b, device=action.device, dtype=action.dtype)
    # Rollout model by taking the same series of actions as the real model.
    prior, post = model.rollout.rollout_representation(
        batch_t, embed, action, prev_state)

    # Compute losses for each component of the model.
    # Model loss:
    feat = get_feat(post)
    image_pred = model.observation_decoder(feat)
    reward_pred = model.reward_model(feat)
    reward_loss = -torch.mean(reward_pred.log_prob(reward))
    image_loss = -torch.mean(image_pred.log_prob(observation))
    pcont_loss = torch.tensor(0.)  # placeholder if use_pcont = False
    if self.use_pcont:
        pcont_pred = model.pcont(feat)
        pcont_target = self.discount * (1 - done.float())
        pcont_loss = -torch.mean(pcont_pred.log_prob(pcont_target))
    prior_dist = get_dist(prior)
    post_dist = get_dist(post)
    div = torch.mean(torch.distributions.kl.kl_divergence(post_dist, prior_dist))
    div = torch.max(div, div.new_full(div.size(), self.free_nats))
    model_loss = self.kl_scale * div + reward_loss + image_loss
    if self.use_pcont:
        model_loss += self.pcont_scale * pcont_loss

    # ------------------------------ Gradient Barrier ------------------------------
    # Don't let gradients pass through, to prevent overwriting gradients.
    # Actor loss:
    # Remove gradients from previously calculated tensors.
    with torch.no_grad():
        if self.use_pcont:
            # "Last step could be terminal."  Done in the TF2 code, but unclear why.
            flat_post = buffer_method(post[:-1, :], 'reshape', (batch_t - 1) * batch_b, -1)
        else:
            flat_post = buffer_method(post, 'reshape', batch_size, -1)
    # Rollout the policy for self.horizon steps.  Variable names with imag_
    # indicate this data is imagined, not real.
    # imag_feat shape is [horizon, batch_t * batch_b, feature_size].
    with FreezeParameters(self.model_modules):
        imag_dist, _ = model.rollout.rollout_policy(
            self.horizon, model.policy, flat_post)

    # Use state features (deterministic and stochastic) to predict the image and reward.
    imag_feat = get_feat(imag_dist)  # [horizon, batch_t * batch_b, feature_size]

    # Assumes these are normal distributions.  In the TF code it's the mode,
    # but for a normal distribution mean = mode.  If we want to use other
    # distributions we'll have to fix this.
    # We calculate the target here, so no grad is necessary.
    # Freeze model parameters, as only action-model gradients are needed.
    with FreezeParameters(self.model_modules + self.value_modules):
        imag_reward = model.reward_model(imag_feat).mean
        value = model.value_model(imag_feat).mean

    # Compute the exponential discounted sum of rewards.
    if self.use_pcont:
        with FreezeParameters([model.pcont]):
            discount_arr = model.pcont(imag_feat).mean
    else:
        discount_arr = self.discount * torch.ones_like(imag_reward)
    returns = self.compute_return(imag_reward[:-1], value[:-1], discount_arr[:-1],
                                  bootstrap=value[-1], lambda_=self.discount_lambda)
    # Make the top row 1 so the cumulative product starts with discount^0.
    discount_arr = torch.cat([torch.ones_like(discount_arr[:1]), discount_arr[1:]])
    discount = torch.cumprod(discount_arr[:-1], 0)
    actor_loss = -torch.mean(discount * returns)

    # ------------------------------ Gradient Barrier ------------------------------
    # Don't let gradients pass through, to prevent overwriting gradients.
    # Value loss:
    # Remove gradients from previously calculated tensors.
    with torch.no_grad():
        value_feat = imag_feat[:-1].detach()
        value_discount = discount.detach()
        value_target = returns.detach()
    value_pred = model.value_model(value_feat)
    log_prob = value_pred.log_prob(value_target)
    value_loss = -torch.mean(value_discount * log_prob.unsqueeze(2))

    # ------------------------------ Gradient Barrier ------------------------------
    # Loss info.
    with torch.no_grad():
        prior_ent = torch.mean(prior_dist.entropy())
        post_ent = torch.mean(post_dist.entropy())
        loss_info = LossInfo(model_loss, actor_loss, value_loss, prior_ent,
                             post_ent, div, reward_loss, image_loss, pcont_loss)
        if self.log_video:
            if opt_itr == self.train_steps - 1 and sample_itr % self.video_every == 0:
                self.write_videos(observation, action, image_pred, post,
                                  step=sample_itr, n=self.video_summary_b,
                                  t=self.video_summary_t)
    return model_loss, actor_loss, value_loss, loss_info
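compute_return above is the Dreamer-style lambda-return over the imagined trajectory. Below is a minimal sketch of that computation, assuming inputs of shape [horizon-1, batch, 1] as in the call above (which passes imag_reward[:-1], value[:-1], discount_arr[:-1] and bootstraps with value[-1]); the function name is illustrative, not the actual helper.

import torch

def lambda_return_sketch(reward, value, discount, bootstrap, lambda_):
    # R_t = r_t + d_t * ((1 - lambda) * V_{t+1} + lambda * R_{t+1}), computed backward in time.
    next_values = torch.cat([value[1:], bootstrap.unsqueeze(0)], dim=0)
    target = reward + discount * next_values * (1.0 - lambda_)
    outputs = []
    accumulated = bootstrap
    for t in reversed(range(reward.shape[0])):
        accumulated = target[t] + discount[t] * lambda_ * accumulated
        outputs.append(accumulated)
    return torch.stack(list(reversed(outputs)), dim=0)  # [horizon-1, batch, 1]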
def forward(self, observation, prev_action, prev_reward):
    lead_dim, T, B, img_shape = infer_leading_dims(observation, 3)
    action = torch.randint(low=0, high=self.num_actions, size=(T * B,))
    action = restore_leading_dims(action, lead_dim, T, B)
    return action
def compute_bonus(self, observations, prev_actions, actions):
    lead_dim, T, B, img_shape = infer_leading_dims(observations, 3)

    # Hacky dimension add for when you have only one environment.
    if prev_actions.dim() == 1:
        prev_actions = prev_actions.view(1, 1, -1)
    if actions.dim() == 1:
        actions = actions.view(1, 1, -1)

    # Generate belief states.
    belief_states, gru_output_states = self.forward(observations, prev_actions)
    self.gru_states = None  # only because we're processing exactly 1 episode per batch

    # Slice beliefs and actions: drop trailing timesteps so each belief has a
    # full action sequence of length horizon (or horizon+1) following it.
    belief_states_t = belief_states.clone()[:-self.horizon]
    belief_states_tm1 = belief_states.clone()[:-self.horizon - 1]
    action_seqs_t = torch.zeros(
        (T - self.horizon, B, self.horizon * self.action_size),
        device=self.device)  # placeholder
    action_seqs_tm1 = torch.zeros(
        (T - self.horizon - 1, B, (self.horizon + 1) * self.action_size),
        device=self.device)  # placeholder
    for i in range(len(actions) - self.horizon):
        if i != len(actions) - self.horizon - 1:
            action_seq_tm1 = actions.clone()[i:i + self.horizon + 1]
            action_seq_tm1 = torch.transpose(action_seq_tm1, 0, 1)
            action_seq_tm1 = torch.reshape(action_seq_tm1, (action_seq_tm1.shape[0], -1))
            action_seqs_tm1[i] = action_seq_tm1
        action_seq_t = actions.clone()[i:i + self.horizon]
        action_seq_t = torch.transpose(action_seq_t, 0, 1)
        action_seq_t = torch.reshape(action_seq_t, (action_seq_t.shape[0], -1))
        action_seqs_t[i] = action_seq_t

    # Make forward-model predictions.  Forward model k predicts k steps ahead,
    # so the t-1 beliefs are rolled out with horizon+1 actions and the t
    # beliefs with horizon actions (forward_model_1 ... forward_model_11).
    obs_feat = img_shape[0] * img_shape[1] * img_shape[2]
    forward_model_tm1 = getattr(self, f"forward_model_{self.horizon + 1}")
    forward_model_t = getattr(self, f"forward_model_{self.horizon}")
    predicted_states_tm1 = forward_model_tm1(
        belief_states_tm1, action_seqs_tm1.detach()).view(-1, B, obs_feat)  # (T-horizon-1, B, obs_feat)
    predicted_states_t = forward_model_t(
        belief_states_t, action_seqs_t.detach()).view(-1, B, obs_feat)  # (T-horizon, B, obs_feat)

    predicted_states_tm1 = torch.sigmoid(predicted_states_tm1)
    true_obs_tm1 = observations.clone()[self.horizon:-1].view(
        -1, *predicted_states_tm1.shape[1:]).type(torch.float)
    predicted_states_t = torch.sigmoid(predicted_states_t)
    true_obs_t = observations.clone()[self.horizon:].view(
        -1, *predicted_states_t.shape[1:]).type(torch.float)

    # Generate losses: average binary cross-entropy over the feature dimension,
    # per environment and timestep.
    losses_tm1 = nn.functional.binary_cross_entropy(predicted_states_tm1, true_obs_tm1, reduction='none')
    losses_tm1 = torch.sum(losses_tm1, dim=-1) / losses_tm1.shape[-1]
    losses_t = nn.functional.binary_cross_entropy(predicted_states_t, true_obs_t, reduction='none')
    losses_t = torch.sum(losses_t, dim=-1) / losses_t.shape[-1]

    # Subtract losses to get rewards (r[t+H-1] = losses[t-1] - losses[t]);
    # the time-zero reward is set to 0 (L[-1] doesn't exist).
    r_int = torch.zeros((T, B), device=self.device)
    r_int[self.horizon:len(losses_t) + self.horizon - 1] = losses_tm1 - losses_t[1:]
    return r_int * self.beta