Example #1
    def _select_action(self,
                       logit,
                       ended,
                       is_prob=False,
                       fix_action_ended=True):
        logit_cpu = logit.clone().cpu()
        if is_prob:
            probs = logit_cpu
        else:
            probs = F.softmax(logit_cpu, 1)

        if self.feedback == 'argmax':
            _, action = probs.max(1)  # student forcing - argmax
            action = action.detach()
        elif self.feedback == 'sample':
            # sampling an action from model
            m = D.Categorical(probs)
            action = m.sample()
        else:
            raise ValueError('Invalid feedback option: {}'.format(
                self.feedback))

        # set action to 0 if already ended
        if fix_action_ended:
            for i, _ended in enumerate(ended):
                if _ended:
                    action[i] = 0

        return action
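As a quick reference, the two feedback modes above reduce to the following standalone sketch on a hypothetical batch of action probabilities (only torch is assumed):

import torch
import torch.distributions as D

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.3, 0.2]])          # [batch, num_actions]
greedy_action = probs.max(1)[1]                  # 'argmax' feedback -> tensor([1, 0])
sampled_action = D.Categorical(probs).sample()   # 'sample' feedback, one draw per row

# Mirror fix_action_ended: force action 0 for episodes that already ended.
ended = torch.tensor([False, True])
sampled_action[ended] = 0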
Example #2
def sample(lnprobs, temperature=1.0):
    if temperature == 0.0:
        return lnprobs.argmax()
    prob = F.softmax(lnprobs / temperature, dim=0)
    cdf = dist.Categorical(prob)

    return cdf.sample()
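A usage sketch for the sample helper above, with hypothetical log-probabilities; it assumes the snippet's F (torch.nn.functional) and dist (torch.distributions) aliases are imported. Temperature 0 falls back to argmax, while higher temperatures give flatter, more random draws.

import torch

lnprobs = torch.tensor([1.0, 2.0, 0.5, -1.0])   # hypothetical scores over a 4-token vocabulary
greedy = sample(lnprobs, temperature=0.0)       # always the argmax (index 1)
sharp = sample(lnprobs, temperature=0.5)        # concentrates on high-scoring tokens
diverse = sample(lnprobs, temperature=2.0)      # flatter distribution, more varied draws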
Example #3
    def _sample_posterior(self, x, num_samples, context=None):
        log_weights = torch.log(self.module.soft_max(self.module.soft_weights))
        T = self.module.covars[None, :, :, :] + x[1][:, None, :, :]

        p_weights = log_weights + dist.MultivariateNormal(
            loc=self.module.means, covariance_matrix=T
        ).log_prob(x[0][:, None, :])
        p_weights -= torch.logsumexp(p_weights, dim=1)[:, None]

        L_t = torch.cholesky(T)
        T_inv = torch.cholesky_solve(
            torch.eye(self.d, device=self.device), L_t)

        diff = x[0][:, None, :] - self.module.means
        T_prod = torch.matmul(T_inv, diff[:, :, :, None])
        p_means = self.module.means + torch.matmul(
            self.module.covars,
            T_prod
        ).squeeze()

        p_covars = self.module.covars - torch.matmul(
            self.module.covars,
            torch.matmul(T_inv, self.module.covars)
        )

        idx = dist.Categorical(logits=p_weights).sample([num_samples])
        samples = dist.MultivariateNormal(
            loc=p_means, covariance_matrix=p_covars).sample([num_samples])

        return samples.transpose(0, 1)[
            torch.arange(len(x), device=self.device)[:, None, None, None],
            torch.arange(num_samples, device=self.device)[None, :, None, None],
            idx.T[:, :, None, None],
            torch.arange(self.d, device=self.device)[None, None, None, :]
        ].squeeze()
Example #4
def compose_losses(outputs, log_selected_policies, total_advantages, targets,
                   batch, args):
    """Caluculate loss value

    Returns:
        tuple: losses and statistic values and the number of training data
    """

    tmasks = batch['turn_mask']
    omasks = batch['observation_mask']

    losses = {}
    dcnt = tmasks.sum().item()
    turn_advantages = total_advantages.mul(tmasks).sum(2, keepdim=True)

    losses['p'] = (-log_selected_policies * turn_advantages).sum()
    if 'value' in outputs:
        losses['v'] = (
            (outputs['value'] - targets['value'])**2).mul(omasks).sum() / 2
    if 'return' in outputs:
        losses['r'] = F.smooth_l1_loss(outputs['return'],
                                       targets['return'],
                                       reduction='none').mul(omasks).sum()

    entropy = dist.Categorical(logits=outputs['policy']).entropy().mul(
        tmasks.sum(-1))
    losses['ent'] = entropy.sum()

    base_loss = losses['p'] + losses.get('v', 0) + losses.get('r', 0)
    entropy_loss = entropy.mul(1 - batch['progress'] *
                               (1 - args['entropy_regularization_decay'])).sum(
                               ) * -args['entropy_regularization']
    losses['total'] = base_loss + entropy_loss

    return losses, dcnt
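The entropy term above comes straight from a Categorical over the policy logits; a minimal sketch of that piece in isolation, with hypothetical shapes ([B, T, A] logits, a [B, T, 1] turn mask) and a hypothetical entropy coefficient of 0.01:

import torch
import torch.distributions as dist

policy_logits = torch.randn(4, 8, 3)                          # hypothetical [B, T, A] policy logits
turn_mask = torch.ones(4, 8, 1)                               # hypothetical [B, T, 1] turn mask

entropy = dist.Categorical(logits=policy_logits).entropy()    # [B, T], per-step policy entropy
masked_entropy = entropy.mul(turn_mask.sum(-1))               # zero out padded turns
entropy_loss = -0.01 * masked_entropy.sum()                   # negative sign: bonus encouraging exploration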
Example #5
def test_words(net, chars, setence_len=50, iscuda=False):
    ''' Given a network and the valid characters from the training text, let the network generate a sentence.
        This is used while training the RNN model.
    '''
    # create hidden state
    ho = net.init_hidden()
    # create random word index
    x_in = torch.LongTensor([random.randint(0, len(chars) - 1)])
    # create output index
    output = [int(x_in)]

    if iscuda:
        ho = ho.cuda()
        x_in = x_in.cuda()

    # now we iterate through our sentence, passing x_in to get y_out, sampling from y_out, and feeding the sample back as x_in for the next time step
    for i in range(setence_len):
        y_out, ho = net.forward(x_in, ho)
        dist = distributions.Categorical(probs=y_out.exp())
        # sample the next character index from the categorical distribution
        sample = dist.sample()
        output.append(int(sample))
        x_in = sample

    # now we print our words
    words = ''
    for item in output:
        words += chars[item]
    return words
Example #6
    def __init__(self,
                 latent_dim=2,
                 num_classes=10,
                 distribution=None,
                 categorical=None):
        """
        Initializes a new dataset where noise and a label are sampled from the given distribution.
        If no distribution is given, noise is sampled from a multivariate normal distribution
        with a certain latent dimension and the label is sampled from a categorical distribution.

        Parameters
        ----------
        latent_dim: int
            The latent dimension for the Normal Distribution the noise is sampled from.
        num_classes: int
            Number of classes for the Categorical Distribution the label is sampled from.
        distribution: torch.distributions.Distribution
            The noise type to use. Overrides setting of latent_dim if specified.
        categorical: torch.distributions.Distribution
            The distribution to sample labels from. Overrides setting of num_classes if specified.
        """
        super().__init__(latent_dim=latent_dim, distribution=distribution)

        if categorical is None:
            self.categorical = D.Categorical(
                torch.Tensor([1.0 / num_classes] * num_classes))
        else:
            self.categorical = categorical
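For reference, the default branch above builds a uniform Categorical over the classes; a standalone sketch of what that label sampler does (names here are illustrative, not part of the original class):

import torch
import torch.distributions as D

num_classes = 10
label_dist = D.Categorical(torch.Tensor([1.0 / num_classes] * num_classes))
labels = label_dist.sample((4,))   # e.g. tensor([3, 7, 0, 9]); each class has probability 1/10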
Example #7
    def _goal_likelihood(self, y: torch.Tensor, goal: torch.Tensor,
                         **hyperparams) -> torch.Tensor:
        """Returns the goal-likelihood of a plan `y`, given `goal`.

    Args:
      y: A plan under evaluation, with shape `[B, T, 2]`.
      goal: The goal locations, with shape `[B, K, 2]`.
      hyperparams: (keyword arguments) The goal-likelihood hyperparameters.

    Returns:
      The log-likelihood of the plan `y` under the `goal` distribution.
    """
        # Parses tensor dimensions.
        B, K, _ = goal.shape

        # Fetches goal-likelihood hyperparameters.
        epsilon = hyperparams.get("epsilon", 1.0)

        # TODO(filangel): implement other goal likelihoods from the DIM paper
        # Initializes the goal distribution.
        goal_distribution = D.MixtureSameFamily(
            mixture_distribution=D.Categorical(
                probs=torch.ones((B, K)).to(goal.device)),  # pylint: disable=no-member
            component_distribution=D.Independent(
                D.Normal(loc=goal, scale=torch.ones_like(goal) * epsilon),  # pylint: disable=no-member
                reinterpreted_batch_ndims=1,
            ))

        return torch.mean(goal_distribution.log_prob(y[:, -1, :]), dim=0)  # pylint: disable=no-member
Example #8
    def get_dist(self):
        n = len(self.mean)
        mix = D.Categorical(torch.ones(n, ))
        comp = D.Independent(D.Normal(self.mean, self.var * torch.ones(n, 2)),
                             1)

        return D.MixtureSameFamily(mix, comp)
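A brief sketch of consuming the mixture that get_dist returns, with hypothetical values for self.mean ([n, 2]) and the scalar self.var implied by the code above:

import torch
import torch.distributions as D

mean = torch.tensor([[0.0, 0.0], [3.0, 3.0]])              # two 2-D component means
var = 0.5
mix = D.Categorical(torch.ones(2, ))                       # equal mixture weights
comp = D.Independent(D.Normal(mean, var * torch.ones(2, 2)), 1)
gmm = D.MixtureSameFamily(mix, comp)

samples = gmm.sample((5,))            # [5, 2] points drawn from the mixture
log_density = gmm.log_prob(samples)   # [5] per-sample log-densities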
Example #9
 def predict(self, x, deterministic=True):
     out = self.actor(x)
     if deterministic:
         out = torch.max(out, dim=1)[1]
     else:
         out = distributions.Categorical(probs=out).sample()
     return out.cpu().numpy()
Example #10
    def act(self, s, epsilon):
        '''epsilon greedy action selection
        
        Arguments:
            s {np array} -- state selection
            epsilon {float} -- epsilon value
        
        Returns:
            action -- action index
        '''
        # get action logits
        action_logits = self.brain(s)

        # create a categorical distribution from logits
        categorical_distribution = distributions.Categorical(
            logits=action_logits)

        # sample actions according to the distribution
        actions = categorical_distribution.sample()
        # print(actions.shape)

        # collect relevant log probabilities
        relevant_log_probs = categorical_distribution.log_prob(actions)
        # print(relevant_log_probs.shape)

        return actions[0].item(), relevant_log_probs
Example #11
    def act_intrinsic(self, obs):
        assert self.intrinsic  # Only usable with random network distillation

        obs = torch.FloatTensor(obs)

        if self.action_type == "Discrete":
            logits, state_values, int_state_values = self.net(obs)
            state_values = state_values.squeeze()
            int_state_values = int_state_values.squeeze()

            dist = distributions.Categorical(F.softmax(logits, dim=-1))
            actions = dist.sample().squeeze()
            action_log_probs = dist.log_prob(actions).squeeze()

        elif self.action_type == "Box":
            logits, sd, state_values, int_state_values = self.net.forward_continuous(
                obs)
            state_values = state_values.squeeze()
            int_state_values = int_state_values.squeeze()

            dist = distributions.Normal(logits, torch.exp(sd))
            actions = dist.sample()
            action_log_probs = dist.log_prob(actions)
            dist_entropy = dist.entropy()

        return actions, state_values, int_state_values, action_log_probs
Example #12
    def test_mutual_info_penalty(self):
        real_loss_mean = 2.600133
        real_loss_sum = 5.200266
        real_losses = [0.7086121, 4.491654]
        mean = torch.Tensor([[1.3, 4.6, 7.1], [0.2, 11.4, 1.0]])
        std = torch.Tensor([[1.0, 0.5, 3.1], [0.2, 3.5, 4.9]])
        logits = torch.Tensor([[0.5, 0.5], [0.75, 0.25]])

        c_dis = torch.Tensor([[0, 1], [1, 0]])
        c_cont = torch.Tensor([[1.4, 4.0, 5.0], [-1.0, 7.0, 2.0]])

        q_cont = ds.Normal(loc=mean, scale=std)
        q_cat = ds.Categorical(logits=logits)

        mutualinfo = MutualInformationPenalty()
        loss_mean = mutualinfo(c_dis, c_cont, q_cat, q_cont)
        self.assertAlmostEqual(loss_mean.item(), real_loss_mean, 5)

        mutualinfo.reduction = "sum"
        loss_sum = mutualinfo(c_dis, c_cont, q_cat, q_cont)
        self.assertAlmostEqual(loss_sum.item(), real_loss_sum, 5)

        mutualinfo.reduction = "none"
        loss = mutualinfo(c_dis, c_cont, q_cat, q_cont)
        for i in range(2):
            self.assertAlmostEqual(loss[i].item(), real_losses[i], 5)
Example #13
 def select_action(self, obs):
     output = self.actor(obs, rnncs=self.rnncs)  # [B, A]
     self.rnncs_ = self.actor.get_rnncs()
     value = self.critic(obs, rnncs=self.rnncs)  # [B, 1]
     if self.is_continuous:
         mu, log_std = output  # [B, A]
         dist = td.Independent(td.Normal(mu, log_std.exp()), 1)
         action = dist.sample().clamp(-1, 1)  # [B, A]
         log_prob = dist.log_prob(action).unsqueeze(-1)  # [B, 1]
     else:
         logits = output  # [B, A]
         logp_all = logits.log_softmax(-1)  # [B, A]
         norm_dist = td.Categorical(logits=logp_all)
         action = norm_dist.sample()  # [B,]
         log_prob = norm_dist.log_prob(action).unsqueeze(-1)  # [B, 1]
     acts_info = Data(action=action,
                      value=value,
                      log_prob=log_prob + th.finfo().eps)
     if self.use_rnn:
         acts_info.update(rnncs=self.rnncs)
     if self.is_continuous:
         acts_info.update(mu=mu, log_std=log_std)
     else:
         acts_info.update(logp_all=logp_all)
     return action, acts_info
Example #14
    def decoder(self,
                z,
                encoded_history,
                current_state,
                y_e=None,
                train=False):
        bs = encoded_history.shape[0]
        a_0 = F.dropout(self.action(current_state.reshape(bs, -1)),
                        self.dropout_p)
        state = F.dropout(self.state(encoded_history.reshape(bs, -1)),
                          self.dropout_p)

        current_state = current_state.unsqueeze(1)
        gauses = []
        inp = F.dropout(
            torch.cat((encoded_history.reshape(bs, -1), a_0), dim=-1),
            self.dropout_p)
        for i in range(12):
            h_state = self.gru(inp.reshape(bs, -1), state)

            _, deltas, log_sigmas, corrs = self.project_to_GMM_params(h_state)
            deltas = torch.clamp(deltas, max=1.5, min=-1.5)
            deltas = deltas.reshape(bs, -1, 2)
            log_sigmas = log_sigmas.reshape(bs, -1, 2)
            corrs = corrs.reshape(bs, -1, 1)

            mus = deltas + current_state
            current_state = mus
            variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2,
                                   max=1e3)

            m_diag = variance * torch.eye(2).to(variance.device)
            sigma_xy = torch.clamp(torch.prod(torch.exp(log_sigmas), dim=-1),
                                   min=1e-8,
                                   max=1e3)

            if train:
                # log_pis = z.reshape(bs, 1) * torch.ones(bs, self.num_modes).cuda()
                log_pis = to_one_hot(z, n_dims=self.num_modes).cuda()

            else:
                log_pis = to_one_hot(z, n_dims=self.num_modes).cuda()
            log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
            mix = D.Categorical(logits=log_pis)
            comp = D.MultivariateNormal(mus, m_diag)
            gmm = D.MixtureSameFamily(mix, comp)
            t = (sigma_xy * corrs.squeeze()).reshape(-1, 1, 1)
            cov_matrix = m_diag  # + anti_diag
            gauses.append(gmm)
            a_t = gmm.sample()  # possible grad problems?
            a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
            state = h_state
            inp = F.dropout(
                torch.cat((encoded_history.reshape(bs, -1), a_tt), dim=-1),
                self.dropout_p)

        return gauses
Example #15
    def forward(self, output_sizes, hold_seed=None, hold_initial_set=False):
        """
        Sample from prior
        :param output_sizes: Tensor([B,])
        :param hold_seed: if not None, fix the random seed so the same seed noise is drawn on every call
        :param hold_initial_set: if True, use a fixed mask derived from output_sizes instead of sampling one
        :return: Tensor([B, N, D]), Tensor([B, N])
        """
        bsize = output_sizes.shape[0]
        if hold_initial_set:  # [B, N]
            x_mask = get_mask(output_sizes, self.max_outputs)
        else:
            x_mask = sample_mask(output_sizes, self.max_outputs)

        if hold_seed is not None:  # [B, N, Ds]
            torch.random.manual_seed(hold_seed)
            eps = torch.randn([1, self.max_outputs, self.dim_seed
                               ]).to(x_mask.device).repeat(bsize, 1, 1)
        else:
            eps = torch.randn([bsize, self.max_outputs,
                               self.dim_seed]).to(x_mask.device)

        if self.n_mixtures == 1:
            x = self.mu + torch.exp(self.logvar / 2.) * eps
        else:
            if self.train_gmm:
                if hold_seed is not None:
                    torch.random.manual_seed(hold_seed)
                    logits = self.logits.reshape([1, 1,
                                                  self.n_mixtures]).repeat(
                                                      1, self.max_outputs,
                                                      1)  # [1, N, M]
                    onehot = F.gumbel_softmax(
                        logits, tau=self.tau,
                        hard=True).repeat(bsize, 1,
                                          1).unsqueeze(-1)  # [B, N, M, 1]
                else:
                    logits = self.logits.reshape([1, 1,
                                                  self.n_mixtures]).repeat(
                                                      bsize, self.max_outputs,
                                                      1)  # [B, N, M]
                    onehot = F.gumbel_softmax(logits, tau=self.tau,
                                              hard=True).unsqueeze(
                                                  -1)  # [B, N, M, 1]
                mu = self.mu.reshape([1, 1, self.n_mixtures,
                                      self.dim_seed])  # [1, 1, M, D]
                sig = self.sig.reshape([1, 1, self.n_mixtures,
                                        self.dim_seed])  # [1, 1, M, D]
                mu = (mu * onehot).sum(2)  # [B, N, D]
                sig = (sig * onehot).sum(2)  # [B, N, D]
                x = mu + sig * eps
            else:
                mix = D.Categorical(logits=self.logits)
                comp = D.Independent(D.Normal(self.mu, self.sig.abs()), 1)
                mixture = D.MixtureSameFamily(mix, comp)
                x = mixture.sample((output_sizes.size(0), self.max_outputs))

        x = self.output(x)  # [B, N, D]
        return x, x_mask
Example #16
 def generate(self, bs):
     a = torch.zeros(
         (bs, self.Number_qubits)).type(torch.LongTensor).to(args.device)
     hidden = self.init_hidden.repeat(1, bs, 1)
     # BOS input
     beginning = self.BOS.view(1, 1, -1)
     beginning = beginning.repeat(1, bs, 1)
     output, hidden = self.gru(beginning, hidden)
     output = self.logsoftmax(self.out(output[0]))
     sampled_op = dist.Categorical(output.squeeze(0).exp()).sample()
     a[:, 0] = sampled_op
     for i in range(0, self.Number_qubits - 1):
         output, hidden = self.forward(
             a[:, i], hidden)  #output: [1,bs,charset_length]
         sampled_op = dist.Categorical(output.squeeze(0).exp()).sample()
         a[:, i + 1] = sampled_op
     return a
Example #17
 def label(self, i: int, j: int) -> dist.Distribution:
     """
     Observed label distribution for each item (i) and label (j).
     """
     labeler = self.labelers[i, j].item()
     return dist.Categorical(
         self.confusion_matrix(labeler,
                               self.true_label(i).item()))
Example #18
 def select_action(self, obs):
     q_values = self.q_net(obs, rnncs=self.rnncs)  # [B, A]
     self.rnncs_ = self.q_net.get_rnncs()
     probs = ((q_values - self._get_v(q_values)) / self.alpha).exp()  # > 0   # [B, A]
     probs /= probs.sum(-1, keepdim=True)  # normalize to probabilities  # [B, A]
     cate_dist = td.Categorical(probs=probs)  # already normalized, so pass probs=, not logits=
     actions = cate_dist.sample()  # [B,]
     return actions, Data(action=actions)
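Note the probs/logits distinction the snippet above relies on: Categorical(logits=...) applies its own softmax, so values that are already normalized probabilities must be passed as probs=. A small standalone check:

import torch
import torch.distributions as td

p = torch.tensor([0.7, 0.2, 0.1])        # already a probability vector
td.Categorical(probs=p).probs            # kept as-is: [0.7, 0.2, 0.1]
td.Categorical(logits=p).probs           # re-softmaxed: roughly [0.46, 0.28, 0.25]
td.Categorical(logits=p.log()).probs     # log first, then softmax recovers p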
Example #19
 def forward(self, x, z_prev):
     """
     x: shape=(BS,N)
     z_prev: shape=(BS,N)
     """
     logits = self.transform_x(x) + self.P[z_prev]  # shape=(BS,N,B)
     dist_z = dist.Categorical(logits=logits)
     return dist_z.sample()
Example #20
    def select_action(self,
                      state,
                      rand_flag=False,
                      eps_flag=False,
                      eps_value=1.0,
                      train_flag=False):
        def get_reverse_prob(probs):
            # assume probs.size() is size([1, action_size])
            rev_idxs = torch.arange(probs.size(-1) - 1,
                                    -1,
                                    -1,
                                    device=probs.device).long()
            with torch.no_grad():
                rev_probs = torch.index_select(probs, -1, rev_idxs)
            return rev_probs

        if train_flag:
            self.policy_net.train()
            action_probs = self.policy_net(state)  # size([1, action_size])
        else:
            self.policy_net.eval()
            with torch.no_grad():
                action_probs = self.policy_net(state)  # size([1, action_size])
        action_logps = probs_to_logits(action_probs)  # size([1, action_size])

        if rand_flag:
            if eps_flag and random.random() < eps_value:
                # print('use epsilon random policy')
                action_rev_probs = get_reverse_prob(action_probs)
                m = dist.Categorical(probs=action_rev_probs)
            else:
                m = dist.Categorical(probs=action_probs)
            action = m.sample()  # size([1])
            # action_logp = m.log_prob(action)  # size([1])
        else:
            action = torch.argmax(action_probs, dim=-1)  # size([1])

        assert action.requires_grad is False
        action_logp = action_logps.gather(-1, action.unsqueeze(0)).squeeze(
            -1)  # size([1])
        action = action.item()
        self.episode_actions.append(action)
        self.episode_action_logps.append(action_logp)
        self.episode_action_probs.append(action_probs)

        return action
Example #21
def select_action(env, model, side, hidden, config: Config):
    x = get_model_input(env, side).to(config.device).float()
    output, value, hidden = model.train().to(config.device)(x, hidden)
    distribution = dist.Categorical(F.softmax(output, dim=-1))
    action = distribution.sample()
    log_prob = distribution.log_prob(action)
    entropy = -(log_prob * output).sum(-1)
    return log_prob, action.item() + 1, value, hidden, entropy
Example #22
 def forward(self, x, **kwargs):
     p = self.p.expand(x.shape[0], self.p.shape[-1])
     if isinstance(self.action_space, spaces.Discrete):
         dist = distributions.Categorical(probs=F.softmax(p, dim=1))
     elif isinstance(self.action_space, spaces.Box):
         p = torch.chunk(p, 2, dim=1)
         dist = distributions.Normal(loc=p[0], scale=p[1])
     return dist, torch.ones_like(x)[:, :1]
Example #23
 def encode(self, x):
     feats = self.pointnet(x)
     log_prob_y = self.cat_encoder(feats)
     prob_y = torch.exp(log_prob_y)
     y_dis = distrib.Categorical(probs=prob_y).sample()
     y = one_hot(y_dis, self.clusters).to(x.device)
     z, _, _ = self.encode_z(y, feats)
     return y, z
Example #24
 def forward(self, observation: torch.FloatTensor) -> Any:
     if self.discrete:
         action_probs = self.logits_na(observation)
         return distributions.Categorical(action_probs)
     else:
         mean = self.mean_net(observation)
         dist = distributions.Normal(loc=mean, scale=torch.exp(self.logstd))
         return dist
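Both branches of the forward above return a torch distribution, so downstream code can treat them uniformly; a runnable sketch for the discrete branch, using a stand-in for the network output:

import torch
import torch.distributions as distributions

action_probs = torch.softmax(torch.randn(1, 4), dim=-1)   # stand-in for logits_na(observation)
dist = distributions.Categorical(action_probs)

action = dist.sample()              # [1] sampled action index
log_prob = dist.log_prob(action)    # [1] used in policy-gradient losses
entropy = dist.entropy()            # [1] optional exploration bonus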
Example #25
def _sample(dist, sampling_mode='greedy'):
    if sampling_mode == 'greedy':
        _, sample = torch.topk(dist, 1, dim=-1)
    elif sampling_mode == 'random':
        p = F.softmax(dist, dim=-1)
        sample = dis.Categorical(p).sample()
    sample = sample.squeeze()
    return sample
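A quick sketch of both modes of _sample above, on hypothetical logits, assuming the snippet's F (torch.nn.functional) and dis (torch.distributions) aliases are imported; 'greedy' is deterministic while 'random' draws from the softmax:

import torch

logits = torch.tensor([[0.1, 2.0, -1.0],
                       [1.5, 0.2, 0.3]])
greedy_ids = _sample(logits, sampling_mode='greedy')      # tensor([1, 0])
stochastic_ids = _sample(logits, sampling_mode='random')  # one draw per row from softmax(logits)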
Example #26
 def __init__(self, lib):
     assert isinstance(lib, Library)
     super(CharacterTypeDist, self).__init__(lib)
     # override part type dist
     self.pdist = StrokeTypeDist(lib)
     # distribution of 'k' (number of strokes)
     assert len(lib.pkappa.shape) == 1
     self.kappa = dist.Categorical(probs=lib.pkappa)
Example #27
 def select_action(self, state):
     action_prob, value_pred = self.policy(state)
     dist = distributions.Categorical(action_prob)
     action = dist.sample()
     log_prob_action = dist.log_prob(action)
     self.log_prob_actions.append(log_prob_action)
     self.values.append(value_pred)
     return action
Example #28
def ppo_update(config, f_actor, diff_actor_opt, critic, critic_opt, memory_cache, update_type='meta'):
    # Actor is functional in meta, and normal in rl.
    summed_policy_loss = torch.zeros(1)
    summed_value_loss = torch.zeros(1)

    states, next_states, actions_init, rewards, dones, log_prob_actions_init = get_shaped_memory_sample(config, memory_cache)
    # Use the critic to predict the last reward, as a placeholder in case the trajectory in the batch is incomplete.
    final_predicted_reward = 0.
    if dones[-1] == 0.:  # Then last step is not done. Last value has to be predicted.
        final_state = next_states[-1]
        with torch.no_grad():
            final_predicted_reward = critic(final_state).detach().item()
    returns = calculate_returns(config, rewards, dones, predicted_end_reward=final_predicted_reward) #Returns(samples,1)
    # At this point, returns and the critic values are both tensors, so the advantage computation below stays tensor-based.
    values_init = critic(states)
    advantages = returns - values_init
    if config.normalize_rewards_and_advantages:
        advantages = (advantages - advantages.mean()) / advantages.std()
    advantages = advantages.detach()  # Necessary to keep the advantages from having a connection to the value model.
    # Now the actor makes steps and recalculates actions and log_probs based on the current values for k epochs.

    for ppo_step in range(config.num_ppo_steps):
        action_prob = f_actor(states)
        # print('action_prob', type(action_prob), action_prob.shape, action_prob)
        values_pred = critic(states)
        if config.env_config.action_space_type == 'discrete':
            dist = distributions.Categorical(action_prob)
            actions_init = actions_init.squeeze(-1)
            new_log_prob_actions = dist.log_prob(actions_init)
            new_log_prob_actions = new_log_prob_actions.view(-1, 1)
        elif config.env_config.action_space_type == 'continuous':
            action_mean_vector = action_prob * f_actor.action_upper_limit  # Direct code from actor get_action, refer there
            dist = distributions.MultivariateNormal(action_mean_vector, f_actor.covariance_matrix)
            actions_init = actions_init.view(-1, config.action_dim)
            new_log_prob_actions = dist.log_prob(actions_init)
            new_log_prob_actions = new_log_prob_actions.view(-1, 1)

        policy_ratio = (new_log_prob_actions - log_prob_actions_init).exp()
        policy_loss_1 = policy_ratio * advantages
        policy_loss_2 = torch.clamp(policy_ratio, min=1.0 - config.ppo_clip, max=1.0 + config.ppo_clip) * advantages
        if config.include_entropy_in_ppo:
            inner_policy_loss = (
                        -torch.min(policy_loss_1, policy_loss_2) - config.entropy_coefficient * dist.entropy()).sum()
        else:
            inner_policy_loss = -torch.min(policy_loss_1, policy_loss_2).sum()
        if update_type == 'meta':
            diff_actor_opt.step(inner_policy_loss)
        else:
            # In this case, it's normal RL, and so there is no updating that happens outside in the main function.
            diff_actor_opt.zero_grad()
            inner_policy_loss.backward()
            diff_actor_opt.step()
        inner_value_loss = F.smooth_l1_loss(values_pred, returns).sum()
        inner_value_loss.backward()
        critic_opt.step()
        summed_policy_loss += inner_policy_loss
        summed_value_loss += inner_value_loss
    return summed_policy_loss, summed_value_loss.item()
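The clipped-ratio update inside the loop above is the core of PPO; here is that step in isolation, with hypothetical per-sample numbers and clip = 0.2:

import torch

new_log_prob = torch.tensor([-0.9, -1.2, -0.3])
old_log_prob = torch.tensor([-1.0, -1.0, -1.0])
advantages = torch.tensor([1.0, -0.5, 2.0])

ratio = (new_log_prob - old_log_prob).exp()                            # pi_new / pi_old
unclipped = ratio * advantages
clipped = torch.clamp(ratio, min=1.0 - 0.2, max=1.0 + 0.2) * advantages
policy_loss = -torch.min(unclipped, clipped).sum()                     # maximize the clipped surrogate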
Example #29
def generate_rollout(
    world: GridWorld,
    agent: nn.Module,
    grammar_goal: str,
    critic: nn.Module,
    task_idx: int,
    deterministic: bool = False,
) -> Trajectory:
    samples: Sequence[Sample] = []

    # Perform typical RL loop.
    agent.reset(grammar_goal, device=args.device)
    obs_raw: Observation = world.reset()
    while True:
        obs = encode_observation(obs_raw).to(args.device)  # size[D]

        primitive_idx = None
        if args.agent_type == "ppg":
            agent_state, action_probs = agent(obs.unsqueeze(0))  # size[1, *]
        elif args.agent_type == "sketch":
            agent_state, action_probs, primitive_idx = agent(obs.unsqueeze(0))

        state_value = critic(agent_state)[:, task_idx]

        action_probs = action_probs.squeeze(0)
        state_value = state_value.squeeze(0)

        action_dist = dist.Categorical(action_probs)
        action = (action_dist.sample()
                  if not deterministic else action_probs.argmax())
        log_prob = action_dist.log_prob(action)

        action_raw: Action = Action(action.item())
        obs_raw, reward_raw, done, info = world.step(action_raw)
        reward = torch.tensor(float(reward_raw)).to(args.device)

        # Must detach results computed by neural networks from computational graph as these values
        # are just used to compute gradients for the model. We don't actually want the gradients to
        # be propagating through them.
        samples.append(
            Sample(
                obs=obs,
                action=action,
                reward=reward,
                log_prob=log_prob.detach(),
                state_value=state_value.detach(),
                ret=None,
                advantage=None,
                primitive_idx=primitive_idx,
            ))
        if done: break

    samples = compute_returns(samples,
                              discount=args.discount_factor,
                              device=args.device)
    samples = compute_advantages(samples)

    return Trajectory(samples)
Example #30
    def step(self, t, state, prev_output, detections, seq, *args, mode='teacher_forcing'):
        assert (mode in ['teacher_forcing', 'feedback'])
        device = detections.device
        b_s = detections.size(0)
        bos_idx = self.bos_idx
        state_1, state_2 = state[:2], state[2:]
        detections_mask = (torch.sum(detections, -1, keepdim=True) != 0).float()
        detections_mean = torch.sum(detections, 1) / torch.sum(detections_mask, 1)

        if mode == 'teacher_forcing':
            if self.training and t > 0 and self.ss_prob > .0:
                # Scheduled sampling
                coin = detections.data.new(b_s).uniform_(0, 1)
                coin = (coin < self.ss_prob).long()
                distr = distributions.Categorical(logits=prev_output)
                action = distr.sample()
                it = coin * action.data + (1 - coin) * seq[:, t - 1].data
                it = it.to(device)
            else:
                it = seq[:, t]
        elif mode == 'feedback': # test
            if t == 0:
                it = detections.data.new_full((b_s,), bos_idx).long()
            else:
                it = prev_output

        xt = self.embed(it)
        if self.with_relu:
            xt = F.relu(xt)
        input_1 = torch.cat([state_2[0], detections_mean, xt], 1)

        if self.with_visual_sentinel:
            g_t = torch.sigmoid(self.W_sx(input_1) + self.W_sh(state_1[0]))
        state_1 = self.lstm_cell_1(input_1, state_1)

        att_weights = torch.tanh(self.att_va(detections) + self.att_ha(state_1[0]).unsqueeze(1))
        att_weights = self.att_a(att_weights)

        if self.with_visual_sentinel:
            s_t = g_t * torch.tanh(state_1[1])
            fc_sentinel = self.fc_sentinel(s_t).unsqueeze(1)
            if self.with_relu:
                fc_sentinel = F.relu(fc_sentinel)
            detections = torch.cat([fc_sentinel, detections], 1)
            detections_mask = (torch.sum(detections, -1, keepdim=True) != 0).float()
            sent_att_weights = torch.tanh(self.W_sas(s_t) + self.att_ha(state_1[0])).unsqueeze(1)
            sent_att_weights = self.W_sa(sent_att_weights)
            att_weights = torch.cat([sent_att_weights, att_weights], 1)

        att_weights = F.softmax(att_weights, 1)
        att_weights = detections_mask * att_weights
        att_weights = att_weights / torch.sum(att_weights, 1, keepdim=True)
        att_detections = torch.sum(detections * att_weights, 1)
        input_2 = torch.cat([state_1[0], att_detections], 1)

        state_2 = self.lstm_cell_2(input_2, state_2)
        out = F.log_softmax(self.out_fc(state_2[0]), dim=-1)
        return out, (state_1[0], state_1[1], state_2[0], state_2[1])