def step(self, observation, prev_action, prev_reward, device="cpu"): """ Compute policy's option and action distributions from inputs. Calls model to get mean, std for all pi_w, q, beta for all options, pi over options Moves inputs to device and returns outputs back to CPU, for the sampler. (no grad) """ model_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device) mu, log_std, beta, q, pi = self.model(*model_inputs) dist_info_omega = DistInfo(prob=pi) new_o, terminations = self.sample_option( beta, dist_info_omega) # Sample terminations and options dist_info = DistInfoStd(mean=mu, log_std=log_std) dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu), log_std=select_at_indexes(new_o, log_std)) action = self.distribution.sample(dist_info_o) agent_info = AgentInfoOC(dist_info=dist_info, dist_info_o=dist_info_o, q=q, value=(pi * q).sum(-1), termination=terminations, dist_info_omega=dist_info_omega, prev_o=self._prev_option, o=new_o) action, agent_info = buffer_to((action, agent_info), device=device) self.advance_oc_state(new_o) return AgentStep(action=action, agent_info=agent_info)
def beta_dist_infos(self, observation, prev_action, prev_reward, init_rnn_state):
    model_inputs = buffer_to(
        (observation, prev_action, prev_reward, init_rnn_state),
        device=self.device)
    r_mu, r_log_std, _, _ = self.beta_r_model(*model_inputs)
    c_mu, c_log_std, _, _ = self.beta_c_model(*model_inputs)
    return buffer_to((DistInfoStd(mean=r_mu, log_std=r_log_std),
                      DistInfoStd(mean=c_mu, log_std=c_log_std)),
                     device="cpu")
def sample(self, dist_info):
    logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist
    u = torch.rand_like(logits)
    u = torch.clamp(u, 1e-5, 1 - 1e-5)
    gumbel = -torch.log(-torch.log(u))
    prob = F.softmax((logits + gumbel) / 10, dim=-1)
    cat_sample = torch.argmax(prob, dim=-1)
    one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
    if len(prob.shape) == 1:  # Edge case for when it gets buffer shapes.
        cat_sample = cat_sample.unsqueeze(0)
    if self._all_corners:
        mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
        mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
        mu = select_at_indexes(cat_sample, mu)
        log_std = select_at_indexes(cat_sample, log_std)
        if len(prob.shape) == 1:  # Edge case for when it gets buffer shapes.
            mu, log_std = mu.squeeze(0), log_std.squeeze(0)
        new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
    else:
        new_dist_info = delta_dist_info
    if self.training:
        self.delta_distribution.set_std(None)
    else:
        self.delta_distribution.set_std(0)
    delta_sample = self.delta_distribution.sample(new_dist_info)
    return torch.cat((one_hot, delta_sample), dim=-1)
def step(self, observation, prev_action, prev_reward, device="cpu"): """ Compute policy's action distribution from inputs, and sample an action. Calls the model to produce mean, log_std, value estimate, and next recurrent state. Moves inputs to device and returns outputs back to CPU, for the sampler. Advances the recurrent state of the agent. (no grad) """ agent_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device) mu, log_std, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state) dist_info = DistInfoStd(mean=mu, log_std=log_std) action = self.distribution.sample(dist_info) # Model handles None, but Buffer does not, make zeros if needed: prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func( rnn_state, torch.zeros_like) # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage. # (Special case: model should always leave B dimension in.) prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1) agent_info = AgentInfoRnn(dist_info=dist_info, value=value, prev_rnn_state=prev_rnn_state) action, agent_info = buffer_to((action, agent_info), device=device) self.advance_rnn_state(rnn_state) # Keep on device. return AgentStep(action=action, agent_info=agent_info)
def __call__(self, observation, prev_action, prev_reward, device='cpu'):
    """Performs forward pass on training data, for algorithm."""
    model_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    mu, log_std, value = self.model(*model_inputs)
    return buffer_to((DistInfoStd(mean=mu, log_std=log_std), value),
                     device=device)
def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
    # Assume init_rnn_state already shaped: [N,B,H].
    model_inputs = buffer_to((observation, prev_action, prev_reward, init_rnn_state),
                             device=self.device)
    mu, log_std, value, next_rnn_state = self.model(*model_inputs)
    dist_info, value = buffer_to((DistInfoStd(mean=mu, log_std=log_std), value),
                                 device="cpu")
    return dist_info, value, next_rnn_state  # Leave rnn_state on device.
def __call__(self, observation, prev_action, prev_reward):
    model_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    mu, log_std, value = self.model(*model_inputs)
    samples = (DistInfoStd(mean=mu, log_std=log_std), value)
    return buffer_to(samples, device="cpu")
def sample_loglikelihood(self, dist_info):
    logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist
    u = torch.rand_like(logits)
    u = torch.clamp(u, 1e-5, 1 - 1e-5)
    gumbel = -torch.log(-torch.log(u))
    prob = F.softmax((logits + gumbel) / 10, dim=-1)
    cat_sample = torch.argmax(prob, dim=-1)
    cat_loglikelihood = select_at_indexes(cat_sample, prob)
    one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
    one_hot = (one_hot - prob).detach() + prob  # Make action differentiable through prob.
    if self._all_corners:
        mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
        mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
        mu = mu[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
        log_std = log_std[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
        new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
    else:
        new_dist_info = delta_dist_info
    delta_sample, delta_loglikelihood = self.delta_distribution.sample_loglikelihood(
        new_dist_info)
    action = torch.cat((one_hot, delta_sample), dim=-1)
    log_likelihood = cat_loglikelihood + delta_loglikelihood
    return action, log_likelihood
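# Illustrative sketch of the straight-through trick used above, isolated from the class
# (the temperature of 10 and the number of categories mirror the hard-coded values in the
# method; all names here are assumptions). The returned tensor equals the hard one-hot
# sample in the forward pass, while gradients flow through the soft softmax probabilities.
import torch
import torch.nn.functional as F

def straight_through_gumbel_sample(logits, temperature=10.0):
    u = torch.clamp(torch.rand_like(logits), 1e-5, 1 - 1e-5)
    gumbel = -torch.log(-torch.log(u))                     # Gumbel(0, 1) noise
    prob = F.softmax((logits + gumbel) / temperature, dim=-1)
    hard = F.one_hot(prob.argmax(dim=-1), logits.shape[-1]).to(prob.dtype)
    return (hard - prob).detach() + prob                   # hard forward, soft backward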
def pi(self, observation, prev_action, prev_reward):
    model_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    mean, log_std = self.model(*model_inputs)
    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    action, log_pi = self.distribution.sample_loglikelihood(dist_info)
    log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
    return action, log_pi, dist_info  # Action stays on device for q models.
def step(self, observation, prev_action, prev_reward):
    model_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    mu, log_std, value = self.model(*model_inputs)
    dist_info = DistInfoStd(mean=mu, log_std=log_std)
    action = self.distribution.sample(dist_info)
    agent_info = AgentInfo(dist_info=dist_info, value=value)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    return AgentStep(action=action, agent_info=agent_info)
def __call__(self, observation, prev_action, prev_reward, init_rnn_state):
    """Performs forward pass on training data, for algorithm (requires
    recurrent state input)."""
    # Assume init_rnn_state already shaped: [N,B,H].
    model_inputs = buffer_to(
        (observation, prev_action, prev_reward, init_rnn_state),
        device=self.device)
    mu, log_std, value, next_rnn_state = self.model(*model_inputs)
    dist_info, value = buffer_to(
        (DistInfoStd(mean=mu, log_std=log_std), value), device="cpu")
    return dist_info, value, next_rnn_state  # Leave rnn_state on device.
def pi(self, observation, prev_action, prev_reward):
    """Compute action log-probabilities for state/observation, and sample new
    action (with grad). Uses special ``sample_loglikelihood()`` method of
    Gaussian distribution, which handles action squashing through this
    process."""
    model_inputs = buffer_to(observation, device=self.device)
    mean, log_std, _ = self.model(model_inputs, "pi")
    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    action, log_pi = self.distribution.sample_loglikelihood(dist_info)
    log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
    return action, log_pi, dist_info  # Action stays on device for q models.
def step(self, observation, prev_action, prev_reward):
    model_inputs = buffer_to(observation, device=self.device)
    mean, log_std, sym_features = self.model(model_inputs, "pi",
                                             extract_sym_features=True)
    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    action = self.distribution.sample(dist_info)
    agent_info = SafeSacAgentInfo(dist_info=dist_info, sym_features=sym_features)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    return AgentStep(action=action, agent_info=agent_info)
def forward(self, observation, prev_action, prev_reward):
    if isinstance(observation, tuple):
        observation = torch.cat(observation, dim=-1)
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    output = self.mlp(observation.view(T * B, -1))
    logits = output[:, :4]
    mu, log_std = (output[:, 4:4 + self._delta_dim],
                   output[:, 4 + self._delta_dim:])
    logits, mu, log_std = restore_leading_dims((logits, mu, log_std), lead_dim, T, B)
    return GumbelDistInfo(cat_dist=logits,
                          delta_dist=DistInfoStd(mean=mu, log_std=log_std))
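# Illustrative sketch of the column split performed above: the MLP output packs 4 location
# logits, then `_delta_dim` Gaussian means, then `_delta_dim` log-stds. Sizes below are
# assumptions for the example only.
import torch

delta_dim = 3
output_example = torch.randn(5, 4 + 2 * delta_dim)        # [T*B, 4 + 2*delta_dim]
logits_example = output_example[:, :4]                    # categorical location logits
mu_example = output_example[:, 4:4 + delta_dim]           # delta mean
log_std_example = output_example[:, 4 + delta_dim:]       # delta log-std
assert mu_example.shape == log_std_example.shape == (5, delta_dim)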
def step(self, observation, prev_action, prev_reward):
    model_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    mean, log_std = self.pi_model(*model_inputs)
    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    action = self.distribution.sample(dist_info)
    agent_info = AgentInfo(dist_info=dist_info)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    if np.any(np.isnan(action.numpy())):  # Debugging guard: drop into the debugger on NaN actions.
        breakpoint()
    return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward):
    observation, prev_action, prev_reward = buffer_to(
        (observation, prev_action, prev_reward), device=self.device)
    # self.model includes encoder + actor MLP.
    mean, log_std, latent, conv = self.model(observation, prev_action, prev_reward)
    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    action = self.distribution.sample(dist_info)
    agent_info = AgentInfo(dist_info=dist_info,
                           conv=conv if self.store_latent else None)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward):
    model_inputs = buffer_to((observation,), device=self.device)[0]
    mu, log_std, value, sym_features = self.model(model_inputs,
                                                  extract_sym_features=True)
    dist_info = DistInfoStd(mean=mu, log_std=log_std)
    action = self.distribution.sample(dist_info)
    action = action.clamp(-1, 1)
    agent_info = SafeAgentInfo(dist_info=dist_info, value=value,
                               sym_features=sym_features)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward):
    model_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    actions, means, log_stds = [], [], []
    self.model.start()
    while self.model.has_next():
        mean, log_std = self.model.next(actions, *model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action = self.distribution.sample(dist_info)
        actions.append(action)
        means.append(mean)
        log_stds.append(log_std)
    mean, log_std = torch.cat(means, dim=-1), torch.cat(log_stds, dim=-1)
    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    agent_info = AgentInfo(dist_info=dist_info)
    action = torch.cat(actions, dim=-1)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward, device="cpu"): """ Compute policy's action distribution from inputs, and sample an action. Calls the model to produce mean, log_std, value estimate, and next recurrent state. Moves inputs to device and returns outputs back to CPU, for the sampler. Advances the recurrent state of the agent. (no grad) """ agent_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device) mu, log_std, beta, q, pi, rnn_state = self.model( *agent_inputs, self.prev_rnn_state) terminations = torch.bernoulli(beta).bool() # Sample terminations dist_info_omega = DistInfo(prob=pi) new_o = self.sample_option(terminations, dist_info_omega) dist_info = DistInfoStd(mean=mu, log_std=log_std) dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu), log_std=select_at_indexes(new_o, log_std)) action = self.distribution.sample(dist_info_o) # Model handles None, but Buffer does not, make zeros if needed: prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func( rnn_state, torch.zeros_like) # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage. # (Special case: model should always leave B dimension in.) prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1) agent_info = AgentInfoOCRnn(dist_info=dist_info, dist_info_o=dist_info_o, q=q, value=(pi * q).sum(-1), termination=terminations, inter_option_dist_info=dist_info_omega, prev_o=self._prev_option, o=new_o, prev_rnn_state=prev_rnn_state) action, agent_info = buffer_to((action, agent_info), device=device) self.advance_rnn_state(rnn_state) # Keep on device. self.advance_oc_state(new_o) return AgentStep(action=action, agent_info=agent_info)
def pi(self, observation, prev_action, prev_reward):
    model_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    actions, means, log_stds = [], [], []
    log_pi_total = 0
    self.model.start()
    while self.model.has_next():
        mean, log_std = self.model.next(actions, *model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action, log_pi = self.distribution.sample_loglikelihood(dist_info)
        log_pi_total += log_pi
        actions.append(action)
        means.append(mean)
        log_stds.append(log_std)
    mean, log_std = torch.cat(means, dim=-1), torch.cat(log_stds, dim=-1)
    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    log_pi_total, dist_info = buffer_to((log_pi_total, dist_info), device="cpu")
    action = torch.cat(actions, dim=-1)
    return action, log_pi_total, dist_info  # Action stays on device for q models.
def pi(self, conv_out, prev_action, prev_reward):
    """Compute action log-probabilities for state/observation, and sample new
    action (with grad). Uses special ``sample_loglikelihood()`` method of
    Gaussian distribution, which handles action squashing through this
    process. Assumes variables are already on device."""
    # Call just the actor mlp, not the encoder.
    latent = self.pi_fc1(conv_out)
    mean, log_std = self.pi_mlp(latent, prev_action, prev_reward)
    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    action, log_pi = self.distribution.sample_loglikelihood(dist_info)
    # action = self.distribution.sample(dist_info)
    # log_pi = self.distribution.log_likelihood(action, dist_info)
    log_pi, dist_info = buffer_to((log_pi, dist_info), device="cpu")
    return action, log_pi, dist_info  # Action stays on device for q models.
def step(self, observation, prev_action, prev_reward, device="cpu"): """ Compute policy's action distribution from inputs, and sample an action. Calls the model to produce mean, log_std, and value estimate. Moves inputs to device and returns outputs back to CPU, for the sampler. (no grad) """ model_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device) mu, log_std, value = self.model(*model_inputs) dist_info = DistInfoStd(mean=mu, log_std=log_std) action = self.distribution.sample(dist_info) agent_info = AgentInfo(dist_info=dist_info, value=value) action, agent_info = buffer_to((action, agent_info), device=device) return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward):
    agent_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    mu, log_std, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
    dist_info = DistInfoStd(mean=mu, log_std=log_std)
    action = self.distribution.sample(dist_info)
    # Model handles None, but Buffer does not, make zeros if needed:
    prev_rnn_state = (self.prev_rnn_state if self.prev_rnn_state is not None
                      else buffer_func(rnn_state, torch.zeros_like))
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case: model should always leave B dimension in.)
    prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
    agent_info = AgentInfoRnn(dist_info=dist_info, value=value,
                              prev_rnn_state=prev_rnn_state)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    self.advance_rnn_state(rnn_state)  # Keep on device.
    return AgentStep(action=action, agent_info=agent_info)
def __call__(self, observation, prev_action, prev_reward, sampled_option, device="cpu"): """Performs forward pass on training data, for algorithm. Returns sampled distinfo, q, beta, and piomega distinfo""" model_inputs = buffer_to( (observation, prev_action, prev_reward, sampled_option), device=self.device) mu, log_std, beta, q, pi = self.model(*model_inputs[:-1]) # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo) return buffer_to( (DistInfoStd(mean=select_at_indexes(sampled_option, mu), log_std=select_at_indexes(sampled_option, log_std)), q, beta, DistInfo(prob=pi)), device=device)
def __call__(self, observation, prev_action, prev_reward, sampled_option, init_rnn_state, device="cpu"): """Performs forward pass on training data, for algorithm (requires recurrent state input). Returnssampled distinfo, q, beta, and piomega distinfo""" # Assume init_rnn_state already shaped: [N,B,H] model_inputs = buffer_to((observation, prev_action, prev_reward, init_rnn_state, sampled_option), device=self.device) mu, log_std, beta, q, pi, next_rnn_state = self.model( *model_inputs[:-1]) # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo) dist_info, q, beta, dist_info_omega = buffer_to( (DistInfoStd(mean=select_at_indexes(sampled_option, mu), log_std=select_at_indexes(sampled_option, log_std)), q, beta, DistInfo(prob=pi)), device=device) return dist_info, q, beta, dist_info_omega, next_rnn_state # Leave rnn_state on device.
def next(self, actions, observation, prev_action, prev_reward):
    if isinstance(observation, tuple):
        observation = torch.cat(observation, dim=-1)
    lead_dim, T, B, _ = infer_leading_dims(observation, self._obs_ndim)
    input_obs = observation.view(T * B, -1)
    if self._counter == 0:
        logits = self.mlp_loc(input_obs)
        logits = restore_leading_dims(logits, lead_dim, T, B)
        self._counter += 1
        return logits
    elif self._counter == 1:
        assert len(actions) == 1
        action_loc = actions[0].view(T * B, -1)
        model_input = torch.cat((input_obs,
                                 action_loc.repeat((1, self._n_tile))), dim=-1)
        output = self.mlp_delta(model_input)
        mu, log_std = output.chunk(2, dim=-1)
        mu, log_std = restore_leading_dims((mu, log_std), lead_dim, T, B)
        self._counter += 1
        return DistInfoStd(mean=mu, log_std=log_std)
    else:
        raise Exception('Invalid self._counter', self._counter)
def step(self, observation, prev_action, prev_reward):
    threshold = 0.2
    model_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    if self._max_q_eval_mode == 'none':
        mean, log_std = self.model(*model_inputs)
        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        action = self.distribution.sample(dist_info)
        agent_info = AgentInfo(dist_info=dist_info)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)
    else:
        global MaxQInput
        observation, prev_action, prev_reward = model_inputs
        fields = observation._fields
        if 'position' in fields:
            no_batch = len(observation.position.shape) == 1
        else:
            no_batch = len(observation.pixels.shape) == 3
        if no_batch:
            if 'state' in self._max_q_eval_mode:
                observation = [observation.position.unsqueeze(0)]
            else:
                observation = [observation.pixels.unsqueeze(0)]
        else:
            if 'state' in self._max_q_eval_mode:
                observation = [observation.position]
            else:
                observation = [observation.pixels]

        if self._max_q_eval_mode == 'state_rope':
            locations = np.arange(25).astype('float32')
            locations = locations[:, None]
            locations = np.tile(locations, (1, 50)) / 24
        elif self._max_q_eval_mode == 'state_cloth_corner':
            locations = np.array(
                [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
                dtype='float32')
            locations = np.tile(locations, (1, 50))
        elif self._max_q_eval_mode == 'state_cloth_point':
            locations = np.mgrid[0:9, 0:9].reshape(2, 81).T.astype('float32')
            locations = np.tile(locations, (1, 50)) / 8
        elif self._max_q_eval_mode == 'pixel_rope':
            image = observation[0].squeeze(0).cpu().numpy()
            locations = np.transpose(np.where(np.all(
                image > 150, axis=2))).astype('float32')
            if locations.shape[0] == 0:
                locations = np.array([[-1, -1]], dtype='float32')
            locations = np.tile(locations, (1, 50)) / 63
        elif self._max_q_eval_mode == 'pixel_cloth':
            image = observation[0].squeeze(0).cpu().numpy()
            locations = np.transpose(np.where(np.any(
                image < 100, axis=-1))).astype('float32')
            locations = np.tile(locations, (1, 50)) / 63
        else:
            raise Exception()

        observation_pi = self.model.forward_embedding(observation)
        observation_qs = [q.forward_embedding(observation)
                          for q in self.q_models]
        n_locations = len(locations)
        observation_pi_i = [repeat(o[[i]], [n_locations] + [1] * len(o.shape[1:]))
                            for o in observation_pi]
        observation_qs_i = [[repeat(o, [n_locations] + [1] * len(o.shape[1:]))
                             for o in observation_q]
                            for observation_q in observation_qs]
        locations = torch.from_numpy(locations).to(self.device)

        if MaxQInput is None:
            MaxQInput = namedtuple('MaxQPolicyInput', fields)
        aug_observation_pi = [locations] + list(observation_pi_i)
        aug_observation_pi = MaxQInput(*aug_observation_pi)
        aug_observation_qs = [[locations] + list(observation_q_i)
                              for observation_q_i in observation_qs_i]
        aug_observation_qs = [MaxQInput(*aug_observation_q)
                              for aug_observation_q in aug_observation_qs]

        mean, log_std = self.model.forward_output(
            aug_observation_pi)  # , prev_action, prev_reward)
        qs = [q.forward_output(aug_obs, mean)
              for q, aug_obs in zip(self.q_models, aug_observation_qs)]
        q = torch.min(torch.stack(qs, dim=0), dim=0)[0]
        # q = q.view(batch_size, n_locations)

        values, indices = torch.topk(q, math.ceil(threshold * n_locations), dim=-1)
        # vmin, vmax = values.min(dim=-1, keepdim=True)[0], values.max(dim=-1, keepdim=True)[0]
        # values = (values - vmin) / (vmax - vmin)
        # values = F.log_softmax(values, -1)
        #
        # uniform = torch.rand_like(values)
        # uniform = torch.clamp(uniform, 1e-5, 1 - 1e-5)
        # gumbel = -torch.log(-torch.log(uniform))
        # sampled_idx = torch.argmax(values + gumbel, dim=-1)
        sampled_idx = torch.randint(high=math.ceil(threshold * n_locations),
                                    size=(1,)).to(self.device)
        actual_idxs = indices[sampled_idx]
        # actual_idxs += (torch.arange(batch_size) * n_locations).to(self.device)

        location = locations[actual_idxs][:, :1]
        location = (location - 0.5) / 0.5
        delta = torch.tanh(mean[actual_idxs])
        action = torch.cat((location, delta), dim=-1)
        mean, log_std = mean[actual_idxs], log_std[actual_idxs]
        if no_batch:
            action = action.squeeze(0)
            mean = mean.squeeze(0)
            log_std = log_std.squeeze(0)

        dist_info = DistInfoStd(mean=mean, log_std=log_std)
        agent_info = AgentInfo(dist_info=dist_info)
        action, agent_info = buffer_to((action, agent_info), device="cpu")
        return AgentStep(action=action, agent_info=agent_info)
def __call__(self, observation, prev_action, prev_reward):
    model_inputs = buffer_to((observation, prev_action, prev_reward),
                             device=self.device)
    mu, log_std, ev, iv = self.model(*model_inputs)
    return buffer_to((DistInfoStd(mean=mu, log_std=log_std), ev, iv),
                     device="cpu")