def test_one_hot_categorical_shape(self):
    dist = OneHotCategorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
    self.assertEqual(dist._batch_shape, torch.Size((3,)))
    self.assertEqual(dist._event_shape, torch.Size((2,)))
    self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
    self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
    self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
    self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
    self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3)))
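(`tensor_sample_1` and `tensor_sample_2` are fixtures defined elsewhere in the test class.) A self-contained sketch of the batch/event-shape semantics being asserted above, with illustrative probabilities:

import torch
from torch.distributions import OneHotCategorical

# Each row of probs is one categorical, so batch_shape is (3,) and event_shape is (2,).
probs = torch.tensor([[0.6, 0.4], [0.3, 0.7], [0.5, 0.5]])
dist = OneHotCategorical(probs)
print(dist.batch_shape, dist.event_shape)  # torch.Size([3]) torch.Size([2])
print(dist.sample().shape)                 # torch.Size([3, 2]) -- one one-hot vector per batch row
print(dist.sample((5,)).shape)             # torch.Size([5, 3, 2])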
Example #2
class TestMultiOneHotCategorical(unittest.TestCase):

    def setUp(self) -> None:
        self.test_probs = torch.tensor([[0.3, 0.2, 0.4, 0.1, 0.25, 0.5, 0.25, 0.3, 0.4, 0.1, 0.1, 0.1],
                                        [0.2, 0.3, 0.1, 0.4, 0.5, 0.3, 0.2, 0.2, 0.3, 0.2, 0.2, 0.1]])
        self.test_sections = (4, 3, 5)

        self.test_actions = torch.tensor([[0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0.],
                                          [0., 0., 0., 1., 1., 0., 0., 0., 0., 1., 0., 0.]]).long()

        self.test_sectioned_actions = torch.split(self.test_actions, self.test_sections, dim=-1)

        self.test_multi_onehot_categorical = MultiOneHotCategorical(self.test_probs, self.test_sections)

        self.test_onehot_categorical1 = OneHotCategorical(self.test_probs[:, :4])
        self.test_onehot_categorical2 = OneHotCategorical(self.test_probs[:, 4:7])
        self.test_onehot_categorical3 = OneHotCategorical(self.test_probs[:, 7:])

    def test_log_prob(self):
        test_cat1_log_prob = self.test_onehot_categorical1.log_prob(self.test_sectioned_actions[0])
        test_cat2_log_prob = self.test_onehot_categorical2.log_prob(self.test_sectioned_actions[1])
        test_cat3_log_prob = self.test_onehot_categorical3.log_prob(self.test_sectioned_actions[2])

        test_multi_cat_log_prob = self.test_multi_onehot_categorical.log_prob(self.test_actions)
        self.assertEqual(test_cat1_log_prob.shape, test_multi_cat_log_prob.shape)

        self.assertTrue(
            torch.equal(test_cat1_log_prob + test_cat2_log_prob + test_cat3_log_prob, test_multi_cat_log_prob))

    def test_sample(self):
        test_cat1_sample = self.test_onehot_categorical1.sample()
        test_cat2_sample = self.test_onehot_categorical2.sample()
        test_cat3_sample = self.test_onehot_categorical3.sample()

        test_cat_sample = torch.cat([test_cat1_sample, test_cat2_sample, test_cat3_sample], dim=-1)
        test_multi_cat_sample = self.test_multi_onehot_categorical.sample()

        self.assertEqual(test_cat_sample.shape, test_multi_cat_sample.shape)
        self.assertTrue(torch.equal(test_cat_sample.sum(dim=-1),
                                    test_multi_cat_sample.sum(dim=-1)))

    def test_entropy(self):
        test_cat1_entropy = self.test_onehot_categorical1.entropy()
        test_cat2_entropy = self.test_onehot_categorical2.entropy()
        test_cat3_entropy = self.test_onehot_categorical3.entropy()

        test_multi_cat_entropy = self.test_multi_onehot_categorical.entropy()

        self.assertTrue(torch.equal(test_cat1_entropy + test_cat2_entropy + test_cat3_entropy, test_multi_cat_entropy),
                        "Expected same entropy!!!")
Example #3
    def forward(self,
                inp: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """
        Returns a sample of the policy on the input with the mean and log
        probability of the sample.

        Args:
            inp: The input tensor to put through the network.categorie

        Returns:
            A multi-categorical distribution of the network.
        """
        linear = self.linear(inp)

        value = self.value(linear)
        value = value.view(*value.shape[:-1], -1, self.num_classes)

        probs = self.probs(value)
        dist = OneHotCategorical(probs)

        sample = dist.sample()
        log_prob = dist.log_prob(sample)

        # Straight through gradient trick
        sample = sample + probs - probs.detach()

        mean = torch.argmax(probs, dim=-1)  # modal class indices (currently unused)

        return sample, log_prob, None
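The "straight through gradient trick" above keeps the sampled one-hot value in the forward pass while routing gradients through `probs`. A minimal standalone sketch of the idea (the leaf `logits` tensor and the toy loss are illustrative, not part of the model above):

import torch
from torch.distributions import OneHotCategorical

logits = torch.zeros(4, requires_grad=True)   # stands in for the network output
probs = torch.softmax(logits, dim=-1)
dist = OneHotCategorical(probs=probs)

hard = dist.sample()                          # discrete one-hot, no gradient path
st = hard + probs - probs.detach()            # same value as `hard`, gradient of `probs`

loss = (st * torch.arange(4.0)).sum()         # any downstream loss of the "discrete" sample
loss.backward()
print(logits.grad)                            # non-zero: gradients reach the logits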
Example #4
    def forward(self, x):
        # For convenience we use torch.distributions to sample and compute the values of
        # interest for the distribution; see https://pytorch.org/docs/stable/distributions.html
        # for more details.
        probs = self.encode(x.view(-1, 784))
        m = OneHotCategorical(probs)
        action = m.sample()
        log_prob = m.log_prob(action)
        entropy = m.entropy()
        return self.decode(action), log_prob, entropy
Example #5
    def test_one_hot_categorical_2d(self):
        probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
        probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
        p = Variable(torch.Tensor(probabilities), requires_grad=True)
        s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
        self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
        self.assertEqual(OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
        self.assertEqual(OneHotCategorical(p).sample_n(6).size(), (6, 2, 3))
        self._gradcheck_log_prob(OneHotCategorical, (p,))

        dist = OneHotCategorical(p)
        x = dist.sample()
        self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))
Example #6
    def forward(self, input, targets, args, n_particles, criterion, test=False):
        """
        This version takes the inputs, and does not expose the logits, but instead
        computes the losses directly
        """

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (h, c) = self.encoder(emb, hidden)

        # teacher-forcing
        out_emb = self.dropout(self.dec_embedding(targets))

        # now, we'll replicate it for each particle - it's currently [seq_len x batch_sz x nhid]
        hidden_states = hidden_states.repeat(1, n_particles, 1)
        out_emb = out_emb.repeat(1, n_particles, 1)
        # now [seq_len x (n_particles x batch_sz) x nhid]
        # out_emb, hidden_states should be viewed as (n_particles x batch_sz) - this means that's true for h as well

        # run the z-decoder at this point, evaluating the NLL at each step
        p_h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)  # initially zero
        h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
        d_h = self.init_hidden(batch_sz * n_particles, self.nhid, squeeze=True)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
        resamples = 0

        for i in range(seq_len):
            h = self.z_decoder(hidden_states[i], h)
            logits = self.logits(h)

            # build the next z sample
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
                z = q.rsample()
            h = z

            # prior
            if test:
                p = OneHotCategorical(logits=p_h)
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=p_h)

            # now, compute the log-likelihood of the data given this mean, and the input out_emb
            d_h = self.decoder(torch.cat([z, out_emb[i]], 1), d_h)
            decoder_logits = self.out_embedding(d_h)
            NLL = criterion(decoder_logits, input[i].repeat(n_particles))
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + args.anneal * (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            # sample ancestors, and reindex everything
            Z = log_sum_exp(wa, dim=0)  # line 7
            if (Z.data > 0.1).any():
                pdb.set_trace()

            loss += Z  # line 8
            accumulated_weights = wa - Z  # line 9
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1./probs.pow(2).sum(0)

            # resample / RSAMP if any batch element's ESS fraction drops below 0.3
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors.t().contiguous()+offsets).view(-1)
                h = torch.index_select(h, 0, unrolled_idx)
                p_h = torch.index_select(p_h, 0, unrolled_idx)
                d_h = torch.index_select(d_h, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # build the next mean prediction, feeding in the correct ancestor
                p_h = self.ar_prior_logits(torch.cat([h, out_emb[i]], 1), p_h)

        # now, we calculate the final log-marginal estimator
        nll = nlls.view(seq_len, n_particles, batch_sz).mean(1).sum()
        return -loss.sum(), nll, (seq_len * batch_sz), resamples
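`log_sum_exp` is a helper defined elsewhere in this codebase; presumably it is a numerically stable log-sum-exp over the particle dimension. A minimal sketch of what these snippets assume it does:

import torch

def log_sum_exp(x, dim=0):
    # Numerically stable log(sum(exp(x), dim)); torch.logsumexp applies the
    # max-shift trick internally.
    return torch.logsumexp(x, dim=dim)

# With wa of shape [n_particles, batch_sz], Z = log_sum_exp(wa, dim=0) is the per-batch
# log normalizer, and wa - Z are the normalized log weights carried to the next step.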
Example #7
    def sampled_filter(self, input, args, n_particles, emb, hidden_states):
        seq_len, batch_sz = input.size()
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
        emit = self.calc_emit()

        hidden_states = hidden_states.repeat(1, n_particles, 1)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        accumulated_weights = -math.log(
            n_particles)  # will contain log w_{t - 1}
        resamples = 0

        # in log probability space
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)

        for i in range(seq_len):
            # the approximate posterior comes from the same thing as before
            logits = self.logits(hidden_states[i])

            if not self.training:
                # this is crucial!!
                p = OneHotCategorical(logits=prior_logits)
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                             logits=prior_logits)
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()

            # now, compute the log-likelihood of the data given this z-sample
            emission = F.embedding(input[i].repeat(n_particles), emit)
            NLL = -(emission * z).sum(1)
            # NLL = -self.decode(z, input[i].repeat(n_particles), (emit,))  # diff. w.r.t. z
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            Z = log_sum_exp(wa, dim=0)  # line 7

            loss += Z  # line 8
            accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

            # sample ancestors, and reindex everything
            if args.filter:
                probs = accumulated_weights.data.exp()
                probs += 0.01
                probs = probs / probs.sum(0, keepdim=True)
                effective_sample_size = 1. / probs.pow(2).sum(0)

                # probs is [n_particles, batch_sz]
                # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
                # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

                # resample / RSAMP
                if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                    resamples += 1
                    ancestors = torch.multinomial(probs.transpose(0, 1),
                                                  n_particles, True)

                    # now, reindex, which is the most important thing
                    offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                        1).repeat(1, n_particles).long()
                    if ancestors.is_cuda:
                        offsets = offsets.cuda()
                    unrolled_idx = Variable(ancestors + offsets).view(-1)
                    z = torch.index_select(z, 0, unrolled_idx)

                    # reset accumulated_weights
                    accumulated_weights = -math.log(
                        n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # now in log-probability space
                prior_logits = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

        if self.training:
            (-loss.sum() /
             (seq_len * batch_sz * n_particles)).backward(retain_graph=True)
        return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0
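The resampling trigger above is based on the effective sample size of the normalized particle weights; a small standalone illustration with toy weights:

import torch

def effective_sample_size(w):
    # For normalized weights w (summing to 1 along dim 0), ESS = 1 / sum_i w_i**2.
    # It equals n_particles for uniform weights and tends to 1 as one particle
    # dominates, which is when the code above resamples (ESS / n_particles < 0.3).
    return 1.0 / w.pow(2).sum(0)

uniform = torch.full((8, 1), 1.0 / 8)
degenerate = torch.tensor([[1.0]] + [[0.0]] * 7)
print(effective_sample_size(uniform))     # tensor([8.])
print(effective_sample_size(degenerate))  # tensor([1.])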
Example #8
    def forward(self, input, args, n_particles, test=False):
        """
        evaluation is the IWAE-10 bound
        """
        pi = F.log_softmax(self.pi, 0)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        h = (Variable(
            hidden_states.data.new(batch_sz * n_particles,
                                   self.hidden_size).zero_()),
             Variable(
                 hidden_states.data.new(batch_sz * n_particles,
                                        self.hidden_size).zero_()))

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # now a log-prob
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)
        prior_h = (Variable(torch.zeros(batch_sz * n_particles, 50).cuda()),
                   Variable(torch.zeros(batch_sz * n_particles, 50).cuda()))

        accumulated_weights = -math.log(
            n_particles)  # will contain log w_{t - 1}
        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)
        feed = None

        x_emb = self.lockdrop(emb, self.dropout_x)

        if test:
            pdb.set_trace()

        for i in range(seq_len):
            # build the next z sample - not differentiable! we don't train the inference network
            logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()

            # if test:
            q = OneHotCategorical(logits=logits)
            p = OneHotCategorical(logits=prior_logits)
            a = q.sample()
            # else:
            #     q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            #     p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
            #     a = q.rsample()

            # guard against exactly-zero probabilities (which would give -inf log-probs) before renormalizing
            b = a + 1e-16
            z = b / b.sum(1, keepdim=True)

            # this should be batch_sz x x_dim
            scores = torch.mm(self.project(torch.cat([h[0], z], 1)),
                              self.emit.t())

            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            nlls[i] = NLL.data

            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            Z = log_sum_exp(wa, dim=0)  # line 7

            loss += Z  # line 8
            accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            if any_nans(probs):
                pdb.set_trace()

            # probs is [n_particles, batch_sz]
            # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
            # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

            # resample / RSAMP
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                ancestors = torch.multinomial(probs.transpose(0, 1),
                                              n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                    1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)

                # shuffle!
                z = torch.index_select(z, 0, unrolled_idx)
                a, b = h
                h = torch.index_select(a, 0, unrolled_idx), torch.index_select(
                    b, 0, unrolled_idx)
                a, b = prior_h
                prior_h = torch.index_select(a, 0,
                                             unrolled_idx), torch.index_select(
                                                 b, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(
                    n_particles)  # will contain log w_{t - 1}

            # set things up for next time
            if i != seq_len - 1:
                feed = torch.cat(
                    [emb[i].repeat(n_particles, 1),
                     self.z_emb(z)], 1)
                prior_h = self.z_decoder(feed, prior_h)
                prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
                h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        NLL = nlls

        # now, we calculate the final log-marginal estimator
        return loss.sum(), NLL.sum(), (seq_len * batch_sz * n_particles), 0
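The docstring's "IWAE-10 bound" refers to the importance-weighted bound; in its basic single-step form it is the log of the average importance weight over K particles. A sketch of that form, assuming `log_w` holds per-particle log weights log p(x, z_k) - log q(z_k | x) of shape [K, batch] (the code above accumulates the analogous quantity sequentially over time steps, which is why accumulated_weights starts at -log(n_particles)):

import math
import torch

def iwae_bound(log_w):
    # log( (1/K) * sum_k exp(log_w_k) ), computed stably with logsumexp.
    K = log_w.size(0)
    return torch.logsumexp(log_w, dim=0) - math.log(K)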
Example #9
class PPOTorchPolicy(TorchPolicy):
    def __init__(self, observation_space, action_space, config):
        super().__init__(observation_space, action_space, config)
        self.device = torch.device('cpu')

        # Get hyperparameters
        self.alpha = config['alpha']
        self.clip_ratio = config['clip_ratio']
        self.gamma = config['gamma']
        self.lam = config['lambda']
        self.lr_pi = config['lr_pi']
        self.lr_vf = config['lr_vf']
        self.model_hidden_sizes = config['model_hidden_sizes']
        self.num_skills = config['num_skills']
        self.skill_input = config['skill_input']
        self.target_kl = config['target_kl']
        self.use_diayn = config['use_diayn']
        self.use_env_rewards = config['use_env_rewards']
        self.use_gae = config['use_gae']

        # Initialize actor-critic model
        self.skills = OneHotCategorical(torch.ones((1, self.num_skills)))
        if self.skill_input is not None:
            skill_vec = [0.] * (self.num_skills - 1)
            skill_vec.insert(self.skill_input, 1.)
            self.z = torch.as_tensor([skill_vec], dtype=torch.float32)
        else:
            self.z = None
        self.model = SkilledA2C(observation_space,
                                action_space,
                                hidden_sizes=self.model_hidden_sizes,
                                skills=self.skills).to(self.device)

        # Set up optimizers for policy and value function
        self.pi_optimizer = Adam(self.model.pi.parameters(), self.lr_pi)
        self.vf_optimizer = Adam(self.model.vf.parameters(), self.lr_vf)
        self.disc_optimizer = Adam(self.model.discriminator.parameters(),
                                   self.lr_vf)

    def compute_loss_d(self, batch):
        obs, z = batch[SampleBatch.CUR_OBS], batch[SKILLS]
        logq_z = self.model.discriminator(obs)
        return nn.functional.nll_loss(logq_z, z.argmax(dim=-1))

    def compute_loss_pi(self, batch):
        obs, act, z = batch[
            SampleBatch.CUR_OBS], batch[ACTIVATIONS], batch[SKILLS]
        adv, logp_old = batch[Postprocessing.ADVANTAGES], batch[
            SampleBatch.ACTION_LOGP]
        clip_ratio = self.clip_ratio

        # Policy loss
        oz = torch.cat([obs, z], dim=-1)
        pi, logp = self.model.pi(oz, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clip_frac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clip_frac)

        return loss_pi, pi_info

    def compute_loss_v(self, batch):
        obs, z = batch[SampleBatch.NEXT_OBS], batch[SKILLS]
        v_pred_old, v_targ = batch[SampleBatch.VF_PREDS], batch[
            Postprocessing.VALUE_TARGETS]

        oz = torch.cat([obs, z], dim=-1)
        v_pred = self.model.vf(oz)
        v_pred_clipped = v_pred_old + torch.clamp(
            v_pred - v_pred_old, -self.clip_ratio, self.clip_ratio)

        loss_clipped = (v_pred_clipped - v_targ).pow(2)
        loss_unclipped = (v_pred - v_targ).pow(2)

        return 0.5 * torch.max(loss_unclipped, loss_clipped).mean()

    def _convert_activation_to_action(self, activation):
        min_ = self.action_space.low
        max_ = self.action_space.high
        return tanh_to_action(activation, min_, max_)

    def _normalize_obs(self, obs):
        min_ = self.observation_space.low
        max_ = self.observation_space.high
        return normalize_obs(obs, min_, max_)

    @override(Policy)
    def compute_actions(self, obs, **kwargs):
        # Sample a skill at the start of each episode
        if self.z is None:
            self.z = self.skills.sample()

        o = self._normalize_obs(obs)
        a, v, logp_a, logq_z = self.model.step(
            torch.as_tensor(o, dtype=torch.float32), self.z)

        actions = self._convert_activation_to_action(a)
        extras = {
            ACTIVATIONS: a,
            SampleBatch.VF_PREDS: v,
            SampleBatch.ACTION_LOGP: logp_a,
            SKILLS: self.z.numpy(),
            SKILL_LOGQ: logq_z
        }
        return actions, [], extras

    @override(Policy)
    def postprocess_trajectory(self,
                               batch,
                               other_agent_batches=None,
                               episode=None):
        """Adds the policy logits, VF preds, and advantages to the trajectory."""

        completed = batch["dones"][-1]
        if completed:
            # Force end of episode reward
            last_r = 0.0

            # Reset skill at the end of each episode
            self.z = None
        else:
            next_state = []
            for i in range(self.num_state_tensors()):
                next_state.append([batch["state_out_{}".format(i)][-1]])
            obs = [batch[SampleBatch.NEXT_OBS][-1]]
            o = self._normalize_obs(obs)
            _, last_r, _, _ = self.model.step(
                torch.as_tensor(o, dtype=torch.float32), self.z)
            last_r = last_r.item()

        # Compute DIAYN rewards
        if self.use_diayn:
            z = torch.as_tensor(batch[SKILLS], dtype=torch.float32)
            logp_z = self.skills.log_prob(z).numpy()
            logq_z = batch[SKILL_LOGQ][:, z.argmax(dim=-1)[0].item()]
            entropy_reg = self.alpha * batch[SampleBatch.ACTION_LOGP]
            diayn_rewards = logq_z - logp_z - entropy_reg

            if self.use_env_rewards:
                batch[SampleBatch.REWARDS] += diayn_rewards
            else:
                batch[SampleBatch.REWARDS] = diayn_rewards

        batch = compute_advantages(batch,
                                   last_r,
                                   gamma=self.gamma,
                                   lambda_=self.lam,
                                   use_gae=self.use_gae)
        return batch

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        postprocessed_batch[SampleBatch.CUR_OBS] = self._normalize_obs(
            postprocessed_batch[SampleBatch.CUR_OBS])
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        # Train policy with multiple steps of gradient descent
        self.pi_optimizer.zero_grad()
        loss_pi, pi_info = self.compute_loss_pi(train_batch)
        # if pi_info['kl'] > 1.5 * self.target_kl:
        #     logger.info('Early stopping at step %d due to reaching max kl.' % i)
        #     return
        loss_pi.backward()
        self.pi_optimizer.step()

        # Value function learning
        self.vf_optimizer.zero_grad()
        loss_v = self.compute_loss_v(train_batch)
        loss_v.backward()
        self.vf_optimizer.step()

        # Discriminator learning
        self.disc_optimizer.zero_grad()
        loss_d = self.compute_loss_d(train_batch)
        loss_d.backward()
        self.disc_optimizer.step()

        grad_info = dict(pi_loss=loss_pi.item(),
                         vf_loss=loss_v.item(),
                         d_loss=loss_d.item(),
                         **pi_info)
        return {LEARNER_STATS_KEY: grad_info}
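The skill prior `self.skills` above is a uniform OneHotCategorical, so `log_prob` of any sampled skill is simply -log(num_skills); a tiny illustration (num_skills chosen arbitrarily):

import torch
from torch.distributions import OneHotCategorical

num_skills = 4
skills = OneHotCategorical(torch.ones((1, num_skills)))  # probs are normalized to uniform
z = skills.sample()
print(z, skills.log_prob(z))  # e.g. tensor([[0., 0., 1., 0.]]) tensor([-1.3863])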
Example #10
    def forward(self, input, args, n_particles, test=False):
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
        emit = self.calc_emit()

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        accumulated_weights = -math.log(
            n_particles)  # will contain log w_{t - 1}
        resamples = 0

        # in log probability space
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)

        for i in range(seq_len):
            # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
            logits = self.logits(
                nn.functional.relu(
                    self.z_decoder(torch.cat([hidden_states[i], h], 1),
                                   logits)))

            # build the next z sample
            if any_nans(logits):
                pdb.set_trace()
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()
            h = z

            # prior
            if any_nans(prior_probs):
                pdb.set_trace()
            if test:
                p = OneHotCategorical(logits=prior_probs)
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                             logits=prior_probs)

            if any_nans(prior_probs):
                pdb.set_trace()
            if any_nans(logits):
                pdb.set_trace()

            # now, compute the log-likelihood of the data given this z-sample
            NLL = -self.decode(z, input[i].repeat(n_particles),
                               (emit, ))  # diff. w.r.t. z
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            # sample ancestors, and reindex everything
            Z = log_sum_exp(wa, dim=0)  # line 7

            loss += Z  # line 8
            accumulated_weights = wa - Z  # line 9

            if args.filter:
                probs = accumulated_weights.data.exp()
                probs += 0.01
                probs = probs / probs.sum(0, keepdim=True)
                effective_sample_size = 1. / probs.pow(2).sum(0)

                # probs is [n_particles, batch_sz]
                # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
                # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

                # resample / RSAMP
                if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                    resamples += 1
                    ancestors = torch.multinomial(probs.transpose(0, 1),
                                                  n_particles, True)

                    # now, reindex, which is the most important thing
                    offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                        1).repeat(1, n_particles).long()
                    if ancestors.is_cuda:
                        offsets = offsets.cuda()
                    unrolled_idx = Variable(ancestors + offsets).view(-1)
                    h = torch.index_select(h, 0, unrolled_idx)

                    # reset accumulated_weights
                    accumulated_weights = -math.log(
                        n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # still in log-probability space
                prior_probs = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

                # let's normalize things - slower, but safer
                # prior_probs += 0.01
                # prior_probs = prior_probs / prior_probs.sum(1, keepdim=True)

            # if ((prior_probs.sum(1) - 1) > 1e-3).any()[0]:
            #     pdb.set_trace()

        if any_nans(loss):
            pdb.set_trace()

        # now, we calculate the final log-marginal estimator
        return -loss.sum(), nlls.sum(), (seq_len * batch_sz *
                                         n_particles), resamples
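The prior update `log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)` combines the log transition matrix with the current state in log space. The underlying identity, sketched standalone for a belief vector (in the code `z` is a sampled, possibly relaxed, one-hot rather than a full belief, so the update is the corresponding batched approximation):

import torch

# With logT[i, j] = log p(z_t = i | z_{t-1} = j) (columns normalized, as with
# F.log_softmax(self.T, 0) above) and logq[j] = log q(z_{t-1} = j):
#   log p(z_t = i) = logsumexp_j ( logT[i, j] + logq[j] )
logT = torch.log_softmax(torch.randn(5, 5), dim=0)
logq = torch.log_softmax(torch.randn(5), dim=0)
log_prior = torch.logsumexp(logT + logq.unsqueeze(0), dim=1)
print(torch.logsumexp(log_prior, dim=0))  # ~0.0: the result is itself a normalized log-distribution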