Example #1
0
    def unpack_ppo_batch(self, batch):
        """Unpack a replay-buffer batch into the (non-Variable) tensors used by PPO.

        Expected fields of ``batch`` (one entry per episode, B = num_episodes):
            batch.state: tuple of FloatTensor (1, t, state-dim), where t is variable
            batch.action: tuple of (LongTensor (1), LongTensor (1)) pairs for
                (action, index)
            batch.reward: tuple of scalars in {0, 1}
            batch.mask: tuple of scalars in {0, 1}
            batch.value: tuple of Variable FloatTensor (1, 1)
            batch.logprob: tuple of (FloatTensor (1), FloatTensor (1)) pairs for
                (action_logprob, index_logprob)

        Returns, in this order (none of them is a Variable):
            states: the ``batch.state`` tuple, unchanged
            actions: tuple of length B of LongTensor (1)
            secondary_actions: tuple of length B of LongTensor (1)
            action_logprobs: FloatTensor (B)
            secondary_log_probs: FloatTensor (B)
            values: FloatTensor (B, 1)
            rewards: FloatTensor (B), passed through u.cuda_if_needed
            masks: FloatTensor (B), passed through u.cuda_if_needed
            perm_idx: tuple from u.sort_decr over the state lengths
                (presumably the indices sorting them in decreasing order)
            group_idx: array from u.group_by_element over the sorted lengths
        """
        lengths = [e.size(1) for e in batch.state]
        perm_idx, sorted_lengths = u.sort_decr(lengths)
        group_idx, _ = u.group_by_element(sorted_lengths)

        states = batch.state  # tuple of num_episodes of FloatTensor (1, t, state-dim), where t is variable
        actions, secondary_actions = zip(*batch.action)
        action_logprobs, secondary_log_probs = zip(*batch.logprob)
        action_logprobs = torch.cat(action_logprobs).data  # FloatTensor (B)
        secondary_log_probs = torch.cat(secondary_log_probs).data  # FloatTensor (B)
        values = torch.cat(batch.value).data  # FloatTensor (B, 1)
        rewards = u.cuda_if_needed(torch.from_numpy(np.stack(batch.reward)).float(), self.args)  # FloatTensor (B)
        masks = u.cuda_if_needed(torch.from_numpy(np.stack(batch.mask)).float(), self.args)  # FloatTensor (B)
        return states, actions, secondary_actions, action_logprobs, secondary_log_probs, values, rewards, masks, perm_idx, group_idx
Example #2
0
 def get_log_prob(self, state, action, secondary_action):
     """Return the log-probabilities of ``action`` and ``secondary_action``.

     The secondary distribution is selected by the FIRST element of
     ``action`` only, so the whole batch is presumably expected to share one
     action type (TODO confirm with callers):
       2 (STOP): a one-element dummy distribution of ones
       1 (REDUCTION): softmax over the reducer scores, or a (b, 1)
         distribution of ones when get_indices(state.data) sums to zero
       0 (TRANSLATE): softmax over the translator scores
     Any other action value fails the assertion.

     Returns:
         (action_log_prob, secondary_log_prob)
     """
     b, _, _ = state.size()
     action_dist, rnn_out, summarized_rnn_out = self.forward(state)
     action_log_prob = logprob_categorical_dist(action_dist, action)
     kind = action.data[0]

     if kind == 2:  # STOP
         secondary_dist = cuda_if_needed(Variable(torch.ones(1)), self.args)
     elif kind == 1:  # REDUCTION
         if self.get_indices(state.data).sum() > 0:
             secondary_dist = F.softmax(self.get_reducer_dist(rnn_out), dim=-1)
         else:
             # nothing to reduce: fall back to a trivial (b, 1) distribution
             secondary_dist = cuda_if_needed(
                 Variable(torch.ones(b, 1), volatile=action.volatile),
                 self.args)
     elif kind == 0:  # TRANSLATE
         secondary_dist = F.softmax(
             self.get_translator_dist(summarized_rnn_out), dim=-1)
     else:
         assert False

     secondary_log_prob = logprob_categorical_dist(secondary_dist,
                                                   secondary_action)
     return action_log_prob, secondary_log_prob
Example #3
0
    def ppo_step(self, num_value_iters, states, actions, indices, returns, advantages, fixed_action_logprobs, fixed_index_logprobs, lr_mult, lr, clip_epsilon, l2_reg):
        """Run one PPO update on a minibatch: the critic first, then the policy.

        Args:
            num_value_iters: number of gradient steps taken on the value function.
            states: state tensor fed to self.valuefn and self.policy
                (presumably (mb, t, indim) with a shared t -- TODO confirm).
            actions: primary actions, one per row.
            indices: secondary actions, one per row.
            returns: value-function targets (commented as (mb, 1)).
            advantages: advantage estimates, flattened to (mb) below.
            fixed_action_logprobs, fixed_index_logprobs: log-probabilities of
                the taken actions under the policy that collected the data.
            lr_mult: multiplier applied to ``clip_epsilon`` for this call only.
            lr: not used by this method.
            clip_epsilon: PPO clipping range before scaling by ``lr_mult``.
            l2_reg: coefficient of the L2 penalty on the value-function weights.
        """
        clip_epsilon = clip_epsilon * lr_mult

        # ---- update critic ----
        values_target = Variable(u.cuda_if_needed(returns, self.args))  # (mb, 1)
        for _ in range(num_value_iters):
            values_pred = self.valuefn(Variable(states))  # (mb, 1)
            value_loss = (values_pred - values_target).pow(2).mean()
            # weight decay
            for param in self.valuefn.parameters():
                value_loss += param.pow(2).sum() * l2_reg
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()

        # ---- update policy ----
        advantages_var = Variable(u.cuda_if_needed(advantages, self.args)).view(-1)  # (mb)

        # Log-probs are evaluated per group of equal action type and then put
        # back in the original minibatch order (presumably because
        # policy.get_log_prob branches on the first action of its batch).
        perm_idx, sorted_actions = u.sort_decr(actions)
        inverse_perm_idx = u.invert_permutation(perm_idx)
        group_idx, _ = u.group_by_element(sorted_actions)

        # permute everything by action type
        states_ap, actions_ap, indices_ap = map(lambda x: u.permute(x, perm_idx), [states, actions, indices])

        # group everything by action type
        states_ag, actions_ag, indices_ag = map(lambda x: u.group_by_indices(x, group_idx), [states_ap, actions_ap, indices_ap])

        action_logprobs, index_logprobs = [], []
        # range rather than the Python 2-only xrange, so this also runs on Python 3
        for grp in range(len(group_idx)):
            states_grp = torch.stack(states_ag[grp])  # (g, grp_length, indim)
            actions_grp = torch.LongTensor(np.stack(actions_ag[grp]))  # (g)
            indices_grp = torch.LongTensor(np.stack(indices_ag[grp]))  # (g)

            actions_grp = u.cuda_if_needed(actions_grp, self.args)
            indices_grp = u.cuda_if_needed(indices_grp, self.args)

            alp, ilp = self.policy.get_log_prob(Variable(states_grp), Variable(actions_grp), Variable(indices_grp))

            action_logprobs.append(alp)
            index_logprobs.append(ilp)

        action_logprobs = torch.cat(action_logprobs)
        index_logprobs = torch.cat(index_logprobs)

        # unpermute back to the original minibatch order
        inverse_perm_idx = u.cuda_if_needed(torch.LongTensor(inverse_perm_idx), self.args)
        action_logprobs = action_logprobs[inverse_perm_idx]
        index_logprobs = index_logprobs[inverse_perm_idx]

        # The joint log-prob of (action, index) is the sum of the two terms.
        ratio = torch.exp(action_logprobs + index_logprobs - Variable(fixed_action_logprobs) - Variable(fixed_index_logprobs))
        surr1 = ratio * advantages_var  # (mb)
        surr2 = torch.clamp(ratio, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantages_var  # (mb)
        policy_surr = -torch.min(surr1, surr2).mean()
        self.policy_optimizer.zero_grad()
        policy_surr.backward()
        torch.nn.utils.clip_grad_norm(self.policy.parameters(), 40)
        self.policy_optimizer.step()
Example #4
0
    def improve_policy_ppo(self):
        """Run PPO optimisation over one batch sampled from the replay buffer.

        The batch is reordered by ``perm_idx`` and split into the ``group_idx``
        groups returned by ``unpack_ppo_batch`` (grouped by state length, per
        the field comments), so that each group concatenates into one tensor.
        For every epoch and every group, the rows are shuffled and passed to
        ``self.ppo_step`` in minibatches of ``self.args.ppo_minibatch_size``.
        """
        optim_epochs = self.args.ppo_optim_epochs  # can anneal this
        minibatch_size = self.args.ppo_minibatch_size
        num_value_iters = self.args.ppo_value_iters
        clip_epsilon = self.args.ppo_clip
        gamma = self.args.gamma
        tau = 0.95
        l2_reg = 1e-3

        batch = self.replay_buffer.sample()
        (all_states, all_actions, all_indices,
         all_fixed_action_logprobs, all_fixed_index_logprobs,
         all_values, all_rewards, all_masks,
         perm_idx, group_idx) = self.unpack_ppo_batch(batch)
        all_advantages, all_returns = self.estimate_advantages(
            all_rewards, all_masks, all_values, gamma, tau)  # (b, 1) (b, 1)

        # Reorder every field by perm_idx, then split each into its groups.
        fields = [all_states, all_actions, all_indices, all_returns,
                  all_advantages, all_fixed_action_logprobs,
                  all_fixed_index_logprobs]
        permuted = [u.permute(field, perm_idx) for field in fields]
        grouped = [u.group_by_indices(field, group_idx) for field in permuted]
        (states_g, actions_g, indices_g, returns_g, advantages_g,
         fixed_action_logprobs_g, fixed_index_logprobs_g) = grouped

        for _ in range(optim_epochs):
            for grp in range(len(group_idx)):
                group_tensors = [
                    torch.cat(states_g[grp], dim=0),  # FloatTensor (g, grp_length, indim)
                    torch.cat(actions_g[grp]),  # LongTensor (g)
                    torch.cat(indices_g[grp]),  # LongTensor (g)
                    torch.cat(returns_g[grp]),  # FloatTensor (g)
                    torch.cat(advantages_g[grp]),  # FloatTensor (g)
                    u.cuda_if_needed(torch.FloatTensor(fixed_action_logprobs_g[grp]), self.args),  # FloatTensor (g)
                    u.cuda_if_needed(torch.FloatTensor(fixed_index_logprobs_g[grp]), self.args),  # FloatTensor (g)
                ]
                for tensor in group_tensors:
                    assert not isinstance(tensor, torch.autograd.variable.Variable)

                # shuffle the rows of this group before minibatching
                order = np.random.permutation(range(group_tensors[0].shape[0]))
                order = u.cuda_if_needed(torch.LongTensor(order), self.args)
                (states, actions, indices, returns, advantages,
                 fixed_action_logprobs, fixed_index_logprobs) = [
                    tensor[order] for tensor in group_tensors]

                n_rows = states.shape[0]
                optim_iter_num = int(np.ceil(n_rows / float(minibatch_size)))
                for i in range(optim_iter_num):
                    window = slice(i * minibatch_size,
                                   min((i + 1) * minibatch_size, n_rows))
                    self.ppo_step(num_value_iters,
                                  states[window], actions[window],
                                  indices[window], returns[window],
                                  advantages[window],
                                  fixed_action_logprobs[window],
                                  fixed_index_logprobs[window],
                                  1, self.args.plr, clip_epsilon, l2_reg)
Example #5
0
 def pad(self, encoder_out):
     """Pad a (b, d) encoder output along a new time axis with zeros.

     ``encoder_out`` becomes time step 0 and is followed by
     ``self.outlength - 1`` zero steps. The zero padding is wrapped in a
     Variable and passed through cuda_if_needed. The result is a contiguous
     (b, self.outlength, d) tensor.
     """
     batch_size, dim = encoder_out.size()
     first_step = encoder_out.unsqueeze(1)  # (b, 1, d): add the time dimension
     zeros = cuda_if_needed(
         Variable(torch.zeros(batch_size, self.outlength - 1, dim)), self.args)
     return torch.cat((first_step, zeros), dim=1).contiguous()
Example #6
0
def create_lang_batch(env, bsize, mode, args):
    """Sample ``bsize`` examples from ``env`` and pack them into tensors.

    Each ``env.reset(mode, z)`` call yields ``(initial, target)``. The items
    of ``initial[0]`` are one-hot encoded over ``env.vocabsize``,
    ``initial[1]`` over ``env.langsize`` and ``initial[2]`` over ``env.zsize``
    via ``du.num2onehot``. ``env.change_mt()`` is called once after sampling.

    Args:
        env: environment providing reset, change_mt, vocabsize, langsize, zsize.
        bsize: number of examples to sample.
        mode: passed to env.reset; targets are volatile unless mode == 'train'.
        args: configuration forwarded to cuda_if_needed.

    Returns:
        ((enc_inps, target_tokens, zs), targets) where enc_inps is
        (b, inp_seq_length, vocabsize), target_tokens is (b, langsize),
        zs is (b, zsize), and targets is a LongTensor wrapped in a Variable.
    """
    volatile = mode != 'train'
    z = 1
    # Unused, but kept on purpose: the draw advances NumPy's global RNG, so
    # removing it would change every later random sample (env.reset may rely
    # on that stream).
    whole_expr = np.random.binomial(n=1, p=0.5)
    enc_inps = []
    target_tokens = []
    zs = []
    targets = []

    for _ in range(bsize):
        initial, target = env.reset(mode, z)
        enc_inps.append(
            np.stack([du.num2onehot(x, env.vocabsize) for x in initial[0]]))

        target_tokens.append(du.num2onehot(initial[1], env.langsize))
        zs.append(du.num2onehot(initial[2], env.zsize))
        targets.append(target)

    env.change_mt()

    # Stack each list of per-example one-hot arrays into a single ndarray
    # before building the tensor (as enc_inps already did). Creating a tensor
    # from a Python list of ndarrays is extremely slow (recent PyTorch warns
    # about it).
    enc_inps = torch.FloatTensor(
        np.array(enc_inps))  # (b, inp_seq_length, vocabsize)

    target_tokens = torch.FloatTensor(np.array(target_tokens))  # (b, langsize)
    zs = torch.FloatTensor(np.array(zs))  # (b, zsize)

    targets = torch.LongTensor(targets)  # (b, 1)

    enc_inps, target_tokens, zs, targets = map(
        lambda x: cuda_if_needed(x, args),
        (enc_inps, target_tokens, zs, targets))

    targets = Variable(targets, volatile=volatile)

    return (enc_inps, target_tokens, zs), targets
Example #7
0
 def init_hidden(self, bsize):
     """Return a zero initial hidden state for a batch of ``bsize``.

     The shape is (self.nlayers * num_directions, bsize, h_hdim). When
     self.args.bidirectional is set, num_directions is 2 and each direction
     gets self.hdim // 2 units; otherwise there is one direction of self.hdim
     units. The tensor is wrapped in a Variable and passed through
     cuda_if_needed.
     """
     if self.args.bidirectional:
         num_directions, h_hdim = 2, self.hdim // 2
     else:
         num_directions, h_hdim = 1, self.hdim
     shape = (self.nlayers * num_directions, bsize, h_hdim)
     return cuda_if_needed(Variable(torch.zeros(*shape)), self.args)
Example #8
0
    def improve(self, args):
        """Run PPO optimisation over one batch sampled from the replay buffer.

        Values and the behaviour policy's log-probabilities are computed once
        through volatile Variables. Advantages and returns are then estimated,
        and ``self.optim_epochs`` passes of shuffled minibatch
        ``self.ppo_step`` calls follow. Each epoch reshuffles the tensors left
        by the previous epoch, so the permutations compound.
        """
        sampled = self.replay_buffer.sample()

        # unpack_ppo_batch returns plain tensors, not Variables
        states, actions, rewards, masks = self.unpack_ppo_batch(sampled)
        states, actions, rewards, masks = [
            u.cuda_if_needed(tensor, args)
            for tensor in (states, actions, rewards, masks)]

        values = self.valuefn(Variable(states, volatile=True)).data  # (b, 1)
        fixed_log_probs = self.policy.get_log_prob(
            Variable(states, volatile=True), Variable(actions)).data  # (b)

        # (b, 1) each
        advantages, returns = self.estimate_advantages(rewards, masks, values)

        optim_iter_num = int(
            np.ceil(states.shape[0] / float(self.minibatch_size)))
        for _ in range(self.optim_epochs):
            shuffle = np.random.permutation(range(states.shape[0]))
            shuffle = u.cuda_if_needed(torch.LongTensor(shuffle), args)
            states = states[shuffle]
            actions = actions[shuffle]
            returns = returns[shuffle]
            advantages = advantages[shuffle]
            fixed_log_probs = fixed_log_probs[shuffle]

            for i in range(optim_iter_num):
                start = i * self.minibatch_size
                stop = min((i + 1) * self.minibatch_size, states.shape[0])
                window = slice(start, stop)
                self.ppo_step(
                    minibatch={
                        'states': states[window],
                        'actions': actions[window],
                        'returns': returns[window],
                        'advantages': advantages[window],
                        'fixed_log_probs': fixed_log_probs[window],
                    },
                    args=args)
Example #9
0
    def select_action(self, state):
        """Sample a primary action and a secondary action conditioned on it.

        ``self.forward(state)`` yields action_dist (b, 3), rnn_out
        (b, t, hdim) and summarized_rnn_out (b, hdim). The secondary
        distribution is chosen by the first sampled action only:
          2 (STOP): a one-element dummy distribution of ones
          1 (REDUCE): softmax of the reducer scores, or ones of shape (b, 1)
            when get_indices(state.data) sums to zero
          0 (TRANSLATE): softmax of the translator scores

        Returns:
            (action.data, secondary_action.data,
             (dist_type, choice_dist, meta_dist))
            where dist_type is the primary action value, and choice_dist and
            meta_dist are the secondary and primary distributions as squeezed
            CPU numpy arrays.
        """
        b, _, _ = state.size()

        action_dist, rnn_out, summarized_rnn_out = self.forward(state)
        action = sample_from_categorical_dist(action_dist)  # Variable (b)
        kind = action.data[0]

        if kind == 2:  # STOP
            choice = cuda_if_needed(Variable(torch.ones(1)),
                                    self.args)  # dummy
        elif kind == 1:  # REDUCE
            if self.get_indices(state.data).sum() > 0:
                choice = F.softmax(self.get_reducer_dist(rnn_out), dim=-1)
            else:
                choice = cuda_if_needed(
                    Variable(torch.ones(b, 1), volatile=action.volatile),
                    self.args)
        elif kind == 0:  # TRANSLATE
            choice = F.softmax(self.get_translator_dist(summarized_rnn_out),
                               dim=-1)
        else:
            assert False
        secondary_action = sample_from_categorical_dist(choice)

        dist_type = kind
        choice_dist = choice.data.cpu().squeeze().numpy()
        meta_dist = action_dist.data.cpu().squeeze().numpy()
        return action.data, secondary_action.data, (dist_type, choice_dist,
                                                    meta_dist)
Example #10
0
    def ppo_step(self, minibatch, args):
        """Run one PPO update on a minibatch: the critic first, then the policy.

        Args:
            minibatch: dict with 'states', 'actions', 'returns', 'advantages'
                and 'fixed_log_probs' tensors (plain tensors, not Variables).
            args: configuration forwarded to u.cuda_if_needed.

        Reads self.clip_epsilon, self.lr_mult, self.value_iters, self.l2_reg
        and self.entropy_coeff, and steps self.value_optimizer and
        self.policy_optimizer. self.clip_epsilon itself is not modified.
        """
        states = minibatch['states']
        actions = minibatch['actions']
        returns = minibatch['returns']
        advantages = minibatch['advantages']
        fixed_log_probs = minibatch['fixed_log_probs']

        # Scale the clip range for this step only. Writing the product back
        # to self.clip_epsilon would compound lr_mult on every call, shrinking
        # the clip range geometrically whenever lr_mult != 1.
        # NOTE: lr_mult is deprecated; keep self.lr_mult=1 and anneal with
        # PyTorch's scheduler instead.
        clip_epsilon = self.clip_epsilon * self.lr_mult

        # ---- update critic ----
        values_target = Variable(u.cuda_if_needed(returns, args))  # (mb, 1)
        for _ in range(self.value_iters):
            values_pred = self.valuefn(Variable(states))  # (mb, 1)
            value_loss = (values_pred - values_target).pow(2).mean()
            # weight decay
            for param in self.valuefn.parameters():
                value_loss += param.pow(2).sum() * self.l2_reg
            self.value_optimizer.zero_grad()
            value_loss.backward()
            self.value_optimizer.step()

        # ---- update policy ----
        advantages_var = Variable(u.cuda_if_needed(advantages,
                                                   args)).view(-1)  # (mb)
        log_probs = self.policy.get_log_prob(Variable(states),
                                             Variable(actions))  # (mb)
        probs = torch.exp(log_probs)  # (mb)
        # NOTE(review): sum of -p*log(p) over the taken actions only. This is
        # a sample-based proxy rather than the full distribution's entropy,
        # and it scales with the minibatch size because it is a sum.
        entropy = torch.sum(-(log_probs * probs))  # (1)
        ratio = torch.exp(log_probs - Variable(fixed_log_probs))  # (mb)
        surr1 = ratio * advantages_var  # (mb)
        surr2 = torch.clamp(ratio, 1.0 - clip_epsilon,
                            1.0 + clip_epsilon) * advantages_var  # (mb)
        # subtracting the entropy term rewards higher entropy when minimized
        policy_surr = -torch.min(surr1, surr2).mean(
        ) - self.entropy_coeff * entropy  # (1)
        self.policy_optimizer.zero_grad()
        policy_surr.backward()
        torch.nn.utils.clip_grad_norm(self.policy.parameters(), 40)
        self.policy_optimizer.step()