コード例 #1
0
    def evaluate_actions(
        self,
        inputs,
        z_eps,
        rnn_hxs,
        masks,
        action,
        get_entropy=True,
    ):
        """Score the given actions under the current latent-conditioned policy.

        The state features returned by ``self.base`` are concatenated with
        ``inputs['goal_vector']`` and passed through ``self.z_enc_net`` to
        parameterise a Normal distribution over the latent ``z``. The latent
        is rebuilt from the supplied noise as ``mu + z_eps * std`` instead of
        being freshly sampled, then concatenated with the state features and
        fed to both the actor and the critic heads.

        Args:
            inputs: Observation container; ``inputs['goal_vector']`` is read.
            z_eps: Noise used to reconstruct the latent sample.
            rnn_hxs: Recurrent hidden state forwarded to ``self.base``.
            masks: Masks forwarded to ``self.base``.
            action: Actions whose log-probabilities are evaluated.
            get_entropy: When false, the entropy is not computed and
                ``None`` is returned in its place.

        Returns:
            Tuple ``(z_sample, z_gauss_dist, value, action_log_probs,
            dist_entropy, rnn_hxs, dist)``.
        """
        features, rnn_hxs = self.base(inputs, rnn_hxs, masks)

        # Latent encoder: state features + goal -> Normal over z.
        encoder_in = torch.cat([features, inputs['goal_vector']], 1)
        z_mean, z_std = utils.get_mean_std(self.z_enc_net(encoder_in))
        # Upper-bound the standard deviation at the configured maximum.
        z_std = torch.clamp(z_std, max=self.z_std_clip_max)
        z_gauss_dist = ds.normal.Normal(loc=z_mean, scale=z_std)
        # Reparameterised sample reconstructed from the caller's noise.
        z_sample = z_mean + (z_eps * z_std)

        # Actor and critic share the same latent-conditioned input.
        policy_in = torch.cat([features, z_sample], 1)
        actor_features = self.actor_net(policy_in)
        value = self.critic_net(policy_in)

        dist = self.dist(actor_features)
        action_log_probs = dist.log_probs(action)
        dist_entropy = dist.entropy().mean() if get_entropy else None

        return (z_sample, z_gauss_dist, value, action_log_probs,
                dist_entropy, rnn_hxs, dist)
コード例 #2
0
 def _z_encode(self, state_features, goal_vector, do_z_sampling):
     """Build the latent Normal distribution and draw a latent from it.

     ``state_features`` and ``goal_vector`` are concatenated along dim 1 and
     passed through ``self.z_enc_net``; the output is split into mean and
     standard deviation by ``utils.get_mean_std`` and the standard deviation
     is clamped from above at ``self.z_std_clip_max``.

     Args:
         state_features: Features of the current state.
         goal_vector: Goal representation concatenated to the features.
         do_z_sampling: If truthy, return a reparameterised sample
             (``rsample``); otherwise return the distribution mean.

     Returns:
         Tuple ``(z_sample, z_gauss_dist)``.
     """
     params = self.z_enc_net(torch.cat([state_features, goal_vector], 1))
     loc, scale = utils.get_mean_std(params)
     scale = torch.clamp(scale, max=self.z_std_clip_max)
     z_gauss_dist = ds.normal.Normal(loc=loc, scale=scale)
     z_sample = (z_gauss_dist.rsample() if do_z_sampling
                 else z_gauss_dist.mean)
     return z_sample, z_gauss_dist
コード例 #3
0
    def _encode(self, obs, rnn_hxs, masks):
        """Encode an observation and its goal into a Normal distribution.

        The features produced by ``self.base`` are concatenated with
        ``obs['goal_vector']`` along dim 1 and mapped by ``self.fc12`` to the
        distribution parameters; the standard deviation is clamped from above
        at ``self.z_std_clip_max``.

        Args:
            obs: Observation container; ``obs['goal_vector']`` is read.
            rnn_hxs: Recurrent hidden state forwarded to ``self.base``.
            masks: Masks; a clone is handed to ``self.base``.

        Returns:
            Tuple ``(gauss_dist, hid, rnn_hxs)`` where ``hid`` is the feature
            tensor returned by ``self.base`` and ``rnn_hxs`` is the hidden
            state it returned.
        """
        # The base network receives a copy of the masks, not the caller's
        # tensor.
        hid, rnn_hxs = self.base(obs, rnn_hxs=rnn_hxs, masks=masks.clone())

        encoder_in = torch.cat([hid, obs['goal_vector']], 1)
        loc, scale = utils.get_mean_std(self.fc12(encoder_in))
        scale = torch.clamp(scale, max=self.z_std_clip_max)
        return ds.normal.Normal(loc=loc, scale=scale), hid, rnn_hxs
コード例 #4
0
    def forward(self, trajectory, resizing_shape=None, masks=None):
        """Predict a distribution over options from a trajectory summary.

        Args:
            trajectory: Input whose form depends on the configuration. It is
                passed to ``self.encode_state_sequence`` when
                ``self.ic_mode`` is ``'valor'``; it is a single state when
                ``self.input_type`` is ``'final_state'``; and it is an
                ``(initial_state, final_state)`` pair when
                ``self.input_type`` is ``'final_and_initial_state'``.
            resizing_shape: Optional iterable of leading dimensions. When
                given, the first dimension of the head output is replaced by
                these dimensions through ``Tensor.view``.
            masks: Forwarded to ``self.encode_state_sequence`` in ``'valor'``
                mode; not used otherwise.

        Returns:
            ``self.dist(logits)`` when ``self.option_space`` is
            ``'discrete'``; otherwise a ``Normal`` distribution built from
            the mean and standard deviation derived from ``self.fc12``.

        Raises:
            ValueError: If ``self.ic_mode`` is not ``'valor'`` and
                ``self.input_type`` is not one of the supported values.
        """
        if self.ic_mode == 'valor':
            obs_feats = self.encode_state_sequence(trajectory=trajectory,
                                                   masks=masks)
        elif self.input_type == 'final_state':
            obs_feats, _ = self.base(trajectory)
        elif self.input_type == 'final_and_initial_state':
            i_state, f_state = trajectory
            i_feats, _ = self.base(i_state)
            f_feats, _ = self.base(f_state)
            obs_feats = torch.cat([i_feats, f_feats], 1)
        else:
            # Name the offending configuration instead of raising a bare
            # ValueError that gives the caller nothing to act on.
            raise ValueError(
                "Unsupported input_type {!r} for ic_mode {!r}; expected "
                "'final_state' or 'final_and_initial_state'".format(
                    self.input_type, self.ic_mode))

        feats = self.fc(obs_feats)

        def _restore_leading_dims(tensor):
            # Swap the first dimension for ``resizing_shape`` when requested.
            if resizing_shape is None:
                return tensor
            return tensor.view(*resizing_shape, *tensor.shape[1:])

        if self.option_space == 'discrete':
            logits = _restore_leading_dims(self.fc_logits(feats))
            return self.dist(logits)

        mu, std = utils.get_mean_std(self.fc12(feats))
        return ds.normal.Normal(loc=_restore_leading_dims(mu),
                                scale=_restore_leading_dims(std))
コード例 #5
0
    def _encode(self, obs, rnn_hxs, masks):
        """Encode an observation into a Gaussian latent distribution.

        The features produced by ``self.base`` are mapped by ``self.fc12`` to
        the distribution parameters; the standard deviation is clamped from
        above at ``self.z_std_clip_max``.

        Args:
            obs: Observation forwarded to ``self.base``.
            rnn_hxs: Recurrent hidden state forwarded to ``self.base``.
            masks: Masks; a clone is handed to ``self.base``.

        Returns:
            Tuple ``(gauss_dist, hid, rnn_hxs)`` where ``hid`` is the feature
            tensor returned by ``self.base`` and ``rnn_hxs`` is the hidden
            state it returned.

        Raises:
            NotImplementedError: If ``self.latent_space`` is not
                ``'gaussian'``.
        """
        if self.latent_space != 'gaussian':
            # Only the Gaussian head exists; fail before running the base
            # network and say which setting was rejected.
            raise NotImplementedError(
                "latent_space {!r} is not supported; only 'gaussian' is "
                "implemented".format(self.latent_space))

        # The base network receives a copy of the masks, not the caller's
        # tensor.
        hid, rnn_hxs = self.base(obs, rnn_hxs=rnn_hxs, masks=masks.clone())

        concat_params = self.fc12(hid)
        mu, std = utils.get_mean_std(concat_params)
        std = torch.clamp(std, max=self.z_std_clip_max)
        gauss_dist = ds.normal.Normal(loc=mu, scale=std)
        return gauss_dist, hid, rnn_hxs
コード例 #6
0
    def _encode(self, obs, specifications=None):
        if self.use_state_encoder:
            obs_feats, _ = self.base(obs)

        if self.encoder_type == 'single':
            # [NOTE] : We can evaluate recall for this case as well.
            # Options are:
            # 1. Predict a random value for missing attribute
            # 2. Do something else :P

            emb = self.main_embed(obs['mission'])

            if self.use_state_encoder:
                full_feats = torch.cat([emb, obs_feats], 1)
                hid = self.fc(full_feats)
            else:
                hid = emb

            if self.option_space == 'continuous':
                concat_params = self.fc12(hid)
                mu, std = utils.get_mean_std(concat_params)
                gauss_dist = ds.normal.Normal(loc=mu, scale=std)
            else:
                opt_features = self.fc_logits(hid)

        elif self.encoder_type == 'poe':
            '''
            Product of experts for composing gaussians of all
            specified attributes along with the prior.

            'specifications': mask tensor with ones for
                specified attributes and zeros otherwise
            '''

            # attr_indices = [0, 1, 2, 3]
            attr_indices = np.arange(len(self.input_attr_dims))
            if specifications is None:
                specifications = obs['mission'].new_ones(
                    (obs['mission'].shape[0], len(self.input_attr_dims)))
            # else:
            #     attr_indices = specifications
            #     assert len(attr_indices) >= 0

            # if target[:, attr_indices].min().item() < 0:
            #     import pdb; pdb.set_trace()
            #     pass

            mission = obs['mission'] * (torch.eq(specifications, 0).long())
            # obs['mission'].masked_fill_(torch.eq(specifications, 0), 0)

            # assert target[:, attr_indices].min().item() >= 0, \
            assert mission.min().item() >= 0, \
                "Negative index given as input to nn.embedding table"

            # Embed goals for specified attributes
            goal_embeds = [
                self.poe_embed[idx](mission[:, idx]) \
                    for idx in attr_indices]

            if self.use_state_encoder:
                # Forward pass goal embed and state observation
                cats = [self.fc_poe[attr_indices[idx]](
                    torch.cat([emb, obs_feats], 1)) \
                        for idx, emb in enumerate(goal_embeds)]
            else:
                cats = goal_embeds

            concat_params = [self.fc12[attr_indices[idx]](cat)\
                for idx, cat in enumerate(cats)]

            # Get mean and standard deviation
            concat_params = [utils.get_mean_std(par) \
                for par in concat_params]

            mus, stds = zip(*concat_params)

            # Initialize mu, std of prior before multiplying experts
            prior_mu = obs['agent_pos'].new_zeros(
                (obs['agent_pos'].shape[0], self.omega_option_dims))
            prior_std = obs['agent_pos'].new_ones(
                (obs['agent_pos'].shape[0], self.omega_option_dims))

            # Multiply gaussians to prior one by one in a for loop
            sum_sig = 1.0 / (prior_std**2)
            sum_mu = prior_mu / (prior_std**2)
            for idx, (mu, std) in enumerate(zip(mus, stds)):
                _mask = specifications[:, idx:idx + 1].float()

                sum_sig += (_mask * 1.0) / (std**2)
                sum_mu += (_mask * mu) / (std**2)

            std_poe_sq = 1.0 / sum_sig
            std_poe = torch.sqrt(std_poe_sq)
            mu_poe = std_poe_sq * sum_mu

            # Get distributions object
            gauss_dist = ds.normal.Normal(loc=mu_poe, scale=std_poe)
            # return gauss_dist
            hid = None

        else:
            raise ValueError("Only 'single' and 'poe' supported")

        if self.option_space == 'continuous':
            return gauss_dist, hid
        else:
            return opt_features