def evaluate_actions(self, inputs, z_eps, rnn_hxs, masks, action,
                     get_entropy=True):
    state_features, rnn_hxs = self.base(inputs, rnn_hxs, masks)

    # Encode (state, goal) into the parameters of the latent z distribution.
    enc_input = torch.cat([state_features, inputs['goal_vector']], 1)
    concat_params = self.z_enc_net(enc_input)
    mu, std = utils.get_mean_std(concat_params)
    std = torch.clamp(std, max=self.z_std_clip_max)
    z_gauss_dist = ds.normal.Normal(loc=mu, scale=std)

    # Reparameterize with the externally supplied noise z_eps so that the
    # rollout-time latent sample is reproduced under the current parameters.
    z_sample = mu + (z_eps * std)

    # Actor and critic both condition on the state features and the latent.
    actor_inp = torch.cat([state_features, z_sample], 1)
    actor_features = self.actor_net(actor_inp)
    value = self.critic_net(actor_inp)

    dist = self.dist(actor_features)
    action_log_probs = dist.log_probs(action)
    if get_entropy:
        dist_entropy = dist.entropy().mean()
    else:
        dist_entropy = None

    return z_sample, z_gauss_dist, value, action_log_probs, \
        dist_entropy, rnn_hxs, dist
def _z_encode(self, state_features, goal_vector, do_z_sampling):
    # Map (state, goal) features to the parameters of a diagonal Gaussian
    # over the latent z, clipping the standard deviation for stability.
    enc_input = torch.cat([state_features, goal_vector], 1)
    concat_params = self.z_enc_net(enc_input)
    mu, std = utils.get_mean_std(concat_params)
    std = torch.clamp(std, max=self.z_std_clip_max)
    z_gauss_dist = ds.normal.Normal(loc=mu, scale=std)

    if do_z_sampling:
        # Reparameterized sample: keeps z differentiable w.r.t. mu and std.
        z_sample = z_gauss_dist.rsample()
    else:
        # Deterministic latent: use the mean of the distribution.
        z_sample = z_gauss_dist.mean
    return z_sample, z_gauss_dist
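# Why `_z_encode` and `evaluate_actions` share the same reparameterization: a
# minimal, self-contained sketch (illustrative shapes; not part of the model
# code). Storing the standard-normal noise `eps` at rollout time and
# recomputing z = mu + eps * std, as `evaluate_actions` does with `z_eps`,
# re-evaluates the *same* latent sample under the current encoder parameters
# while keeping z differentiable w.r.t. mu and std (the reparameterization
# trick behind `rsample()`).
def _reparam_sketch():
    import torch
    mu = torch.zeros(4, 8, requires_grad=True)
    log_std = torch.zeros(4, 8, requires_grad=True)
    eps = torch.randn(4, 8)            # noise drawn once at rollout time
    z = mu + eps * log_std.exp()       # same noise, current parameters
    z.sum().backward()                 # gradients flow into mu and log_std
    assert mu.grad is not None and log_std.grad is not None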
def _encode(self, obs, rnn_hxs, masks):
    obs_feats, rnn_hxs = self.base(obs, rnn_hxs=rnn_hxs, masks=masks.clone())
    hid = obs_feats

    # Condition the latent distribution on both the state features and the goal.
    goal_vector = obs['goal_vector']
    hid_cat = torch.cat([hid, goal_vector], 1)
    concat_params = self.fc12(hid_cat)
    mu, std = utils.get_mean_std(concat_params)
    std = torch.clamp(std, max=self.z_std_clip_max)
    gauss_dist = ds.normal.Normal(loc=mu, scale=std)
    return gauss_dist, hid, rnn_hxs
def forward(self, trajectory, resizing_shape=None, masks=None):
    # Encode the trajectory, depending on how much of it the classifier sees:
    # the full state sequence (VALOR-style), the final state only, or the
    # initial and final states concatenated.
    if self.ic_mode == 'valor':
        obs_feats = self.encode_state_sequence(trajectory=trajectory,
                                               masks=masks)
    elif self.input_type == 'final_state':
        final_state = trajectory
        obs_feats, _ = self.base(final_state)
    elif self.input_type == 'final_and_initial_state':
        i_state, f_state = trajectory
        i_feats, _ = self.base(i_state)
        f_feats, _ = self.base(f_state)
        obs_feats = torch.cat([i_feats, f_feats], 1)
    else:
        raise ValueError("Unsupported input_type: {}".format(self.input_type))

    feats = self.fc(obs_feats)

    if self.option_space == 'discrete':
        opt_features = self.fc_logits(feats)
        if resizing_shape is not None:
            opt_features = opt_features.view(*resizing_shape,
                                             *opt_features.shape[1:])
        dist = self.dist(opt_features)
        return dist
    else:
        concat_params = self.fc12(feats)
        mu, std = utils.get_mean_std(concat_params)
        if resizing_shape is not None:
            mu = mu.view(*resizing_shape, *mu.shape[1:])
            std = std.view(*resizing_shape, *std.shape[1:])
        gauss_dist = ds.normal.Normal(loc=mu, scale=std)
        return gauss_dist
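# A small sketch of what `resizing_shape` does in `forward` above (shapes here
# are illustrative assumptions, not from the original code): rollouts are
# flattened to a (num_steps * num_processes, ...) batch for the forward pass,
# and `resizing_shape=(num_steps, num_processes)` restores the time and
# process axes on the predicted distribution parameters.
def _resizing_shape_sketch():
    import torch
    num_steps, num_processes, option_dims = 5, 3, 7
    flat_mu = torch.randn(num_steps * num_processes, option_dims)
    resizing_shape = (num_steps, num_processes)
    mu = flat_mu.view(*resizing_shape, *flat_mu.shape[1:])
    assert mu.shape == (num_steps, num_processes, option_dims)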
def _encode(self, obs, rnn_hxs, masks):
    obs_feats, rnn_hxs = self.base(obs, rnn_hxs=rnn_hxs, masks=masks.clone())
    hid = obs_feats

    if self.latent_space == 'gaussian':
        concat_params = self.fc12(hid)
        mu, std = utils.get_mean_std(concat_params)
        std = torch.clamp(std, max=self.z_std_clip_max)
        gauss_dist = ds.normal.Normal(loc=mu, scale=std)
        return gauss_dist, hid, rnn_hxs
    else:
        # Discrete latents (via self.fc_logits) are not supported here yet.
        raise NotImplementedError
def _encode(self, obs, specifications=None):
    if self.use_state_encoder:
        obs_feats, _ = self.base(obs)

    if self.encoder_type == 'single':
        # [NOTE]: We can evaluate recall for this case as well. Options are:
        #   1. Predict a random value for the missing attribute
        #   2. Do something else
        emb = self.main_embed(obs['mission'])
        if self.use_state_encoder:
            full_feats = torch.cat([emb, obs_feats], 1)
            hid = self.fc(full_feats)
        else:
            hid = emb
        if self.option_space == 'continuous':
            concat_params = self.fc12(hid)
            mu, std = utils.get_mean_std(concat_params)
            gauss_dist = ds.normal.Normal(loc=mu, scale=std)
        else:
            opt_features = self.fc_logits(hid)
    elif self.encoder_type == 'poe':
        # Product of experts (PoE): compose the Gaussians of all specified
        # attributes together with the prior.
        # 'specifications': mask tensor with ones for specified attributes
        # and zeros otherwise.
        attr_indices = np.arange(len(self.input_attr_dims))
        if specifications is None:
            # Default: treat every attribute as specified.
            specifications = obs['mission'].new_ones(
                (obs['mission'].shape[0], len(self.input_attr_dims)))

        # Zero out the entries of *unspecified* attributes so that any
        # sentinel (e.g. negative) indices stay inside the embedding range.
        mission = obs['mission'].masked_fill(torch.eq(specifications, 0), 0)
        assert mission.min().item() >= 0, \
            "Negative index given as input to nn.Embedding table"

        # Embed the goal value of every attribute slot.
        goal_embeds = [
            self.poe_embed[idx](mission[:, idx])
            for idx in attr_indices]
        if self.use_state_encoder:
            # Forward pass over each goal embedding concatenated with the
            # state observation features.
            cats = [self.fc_poe[attr_indices[idx]](
                torch.cat([emb, obs_feats], 1))
                for idx, emb in enumerate(goal_embeds)]
        else:
            cats = goal_embeds
        concat_params = [self.fc12[attr_indices[idx]](cat)
                         for idx, cat in enumerate(cats)]

        # Split each expert's output into mean and standard deviation.
        concat_params = [utils.get_mean_std(par) for par in concat_params]
        mus, stds = zip(*concat_params)

        # Standard normal prior N(0, I) acts as the first expert.
        prior_mu = obs['agent_pos'].new_zeros(
            (obs['agent_pos'].shape[0], self.omega_option_dims))
        prior_std = obs['agent_pos'].new_ones(
            (obs['agent_pos'].shape[0], self.omega_option_dims))

        # Multiply the Gaussian experts into the prior one by one:
        # precisions (1 / std^2) add, and the mean is the precision-weighted
        # average. Unspecified attributes are masked out of the product.
        sum_sig = 1.0 / (prior_std ** 2)
        sum_mu = prior_mu / (prior_std ** 2)
        for idx, (mu, std) in enumerate(zip(mus, stds)):
            _mask = specifications[:, idx:idx + 1].float()
            sum_sig += _mask / (std ** 2)
            sum_mu += (_mask * mu) / (std ** 2)
        std_poe_sq = 1.0 / sum_sig
        std_poe = torch.sqrt(std_poe_sq)
        mu_poe = std_poe_sq * sum_mu

        gauss_dist = ds.normal.Normal(loc=mu_poe, scale=std_poe)
        # PoE only produces a Gaussian, i.e. a continuous option space.
        hid = None
    else:
        raise ValueError("Only 'single' and 'poe' encoder types are supported")

    if self.option_space == 'continuous':
        return gauss_dist, hid
    else:
        return opt_features
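# A self-contained sketch of the product-of-experts fusion performed in
# `_encode` above (standalone helper with illustrative names; not part of the
# original code). A product of Gaussian densities is itself Gaussian:
# precisions (1 / std^2) add up, and the mean is the precision-weighted
# average of the expert means. Masked-out experts contribute nothing, so with
# no attribute specified the product collapses back to the prior.
def product_of_gaussians_sketch():
    import torch

    def product_of_gaussians(mus, stds, mask, prior_mu, prior_std):
        # mus, stds: lists of (batch, dim) expert parameters
        # mask: (batch, num_experts), ones for specified attributes
        sum_prec = 1.0 / prior_std ** 2
        sum_mu = prior_mu / prior_std ** 2
        for idx, (mu, std) in enumerate(zip(mus, stds)):
            m = mask[:, idx:idx + 1].float()
            sum_prec = sum_prec + m / std ** 2
            sum_mu = sum_mu + m * mu / std ** 2
        var = 1.0 / sum_prec
        return var * sum_mu, torch.sqrt(var)

    # With every expert masked out, the product reduces to the prior N(0, I).
    mu, std = product_of_gaussians(
        [torch.randn(2, 4)], [torch.rand(2, 4) + 0.1],
        torch.zeros(2, 1), torch.zeros(2, 4), torch.ones(2, 4))
    assert torch.allclose(mu, torch.zeros(2, 4))
    assert torch.allclose(std, torch.ones(2, 4))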