示例#1
0
 def sampleiter(self, bs=1):
     """
     Endless generator of ancestral samples drawn through the per-variable MLPs.

     Every yielded item is a tensor of hard one-hot samples of shape (bs, M, N):
     M variables, each a one-hot vector of length N. Variable i is sampled
     conditioned on the variables already drawn (0..i-1); the slots of the
     variables not drawn yet are filled with zeros.
     """
     while True:
         with torch.no_grad():
             drawn = []  # one (bs, 1, N) one-hot tensor per variable sampled so far
             for idx in range(self.M):
                 # Zero placeholder for variables idx..M-1, so the input is always (bs, M, N).
                 padding = torch.zeros(bs, self.M - idx, self.N)
                 context = torch.cat([*drawn, padding], dim=1)
                 # First layer: per-parent weights gated by gammagt[idx], then bias and ReLU.
                 hidden = torch.einsum("hik,i,bik->bh", self.W0gt[idx],
                                       self.gammagt[idx], context)
                 hidden = (hidden + self.B0gt[idx].unsqueeze(0)).relu()
                 # Second layer: class scores turned into probabilities over the N values.
                 scores = torch.einsum("oh,bh->bo", self.W1gt[idx], hidden)
                 scores = scores + self.B1gt[idx].unsqueeze(0)
                 probs = scores.softmax(dim=1).unsqueeze(1)
                 drawn.append(OneHotCategorical(probs).sample())
             batch = torch.cat(drawn, dim=1)
         yield batch
    def latent_prior_sample(self, y, n_batch, n_samples):
        """Draw samples from the prior over the latent variables.

        ``z2`` is drawn from a standard normal of size ``self.n_latent``. The
        label one-hots ``ys`` are drawn uniformly over ``self.n_labels``
        classes when ``y`` is None; otherwise they are built from the given
        labels ``y`` and shared by every sample. ``z1`` is then drawn from a
        normal whose two parameters come from ``self.decoder_z1_z2`` applied
        to the concatenation ``[z2, ys]``.

        Returns a dict with keys ``z1``, ``z2`` and ``ys``.
        """
        num_classes = self.n_labels
        latent_dim = self.n_latent

        standard_normal = Normal(torch.zeros(latent_dim),
                                 torch.ones(latent_dim))
        u = standard_normal.sample((n_samples, n_batch))

        if y is not None:
            # Write a 1 into the column named by each label, then view the
            # same (n_batch, num_classes) matrix once per sample.
            onehot = torch.zeros(n_batch, num_classes, dtype=torch.float32)
            onehot.scatter_(1, y.view(-1, 1), 1)
            ys = onehot.view(1, n_batch, num_classes).expand(
                n_samples, n_batch, num_classes)
        else:
            uniform_probs = (1.0 / num_classes) * torch.ones(num_classes)
            ys = OneHotCategorical(probs=uniform_probs).sample(
                (n_samples, n_batch))

        # NOTE(review): the name pz1_z2_v suggests a variance, but Normal takes
        # a standard deviation as its second argument -- confirm what
        # decoder_z1_z2 returns.
        pz1_z2m, pz1_z2_v = self.decoder_z1_z2(torch.cat([u, ys], dim=-1))
        z1 = Normal(pz1_z2m, pz1_z2_v).sample()
        return {"z1": z1, "z2": u, "ys": ys}
示例#3
0
        self.net = nn.Sequential(nn.Linear(s_dim, hidden), nn.ReLU(),
                                 nn.Linear(hidden, hidden), nn.ReLU(),
                                 nn.Linear(hidden, z_num))

    def forward(self, s, log=False):
        """Return the output of ``self.net`` on ``s``, normalized over the last dimension.

        With ``log=True`` the log-softmax is returned, otherwise the softmax.
        """
        normalize = F.log_softmax if log else F.softmax
        return normalize(self.net(s), dim=-1)


if __name__ == "__main__":
    # Smoke test: build every network with small sizes and push one
    # (state, skill) pair through the policy and the value network.
    from torch.distributions import Categorical, OneHotCategorical

    S_DIM, Z_NUM, HIDDEN, A_NUM = 2, 4, 32, 4

    state = torch.FloatTensor([1, 2])
    # Uniform one-hot draw over the skills, e.g. torch.LongTensor([0,0,1,0]).
    skill = OneHotCategorical(torch.ones(Z_NUM)).sample()
    print(state, skill)

    policy = Policy(s_dim=S_DIM, z_num=Z_NUM, hidden=HIDDEN, a_num=A_NUM)
    vnet = VNet(s_dim=S_DIM, z_num=Z_NUM, hidden=HIDDEN)
    qnet = QNet(s_dim=S_DIM, z_num=Z_NUM, hidden=HIDDEN, a_num=A_NUM)
    dis = Discriminator(s_dim=S_DIM, z_num=Z_NUM, hidden=HIDDEN)

    # The policy output is used as the probabilities of a categorical action distribution.
    action_probs = policy(state, skill)
    print(action_probs)
    action = Categorical(action_probs).sample()
    print(action)

    index = torch.LongTensor(range(1))
    value = vnet(state, skill)
示例#4
0
 def variational_posterior(self, logits: torch.Tensor):
     """Build the one-hot categorical posterior defined by unnormalized ``logits``.

     The logits are handed to the distribution directly instead of being
     turned into probabilities with a softmax first. The distribution
     normalizes logits in log-space, so a category whose probability
     underflows to zero keeps its exact log-probability; when built from
     ``probs`` the same category gets a clamped log-probability.

     Args:
         logits: unnormalized log-probabilities; the last dimension indexes
             the categories.

     Returns:
         A ``OneHotCategorical`` whose probabilities equal
         ``logits.softmax(dim=-1)``.
     """
     return OneHotCategorical(logits=logits)
示例#5
0
    def forward(self, input, targets, args, n_particles, criterion, test=False):
        """
        This version takes the inputs, and does not expose the logits, but instead
        computes the losses directly

        Runs a particle filter over the sequence: at every step each particle
        samples a latent z from the proposal q, the decoder scores the data
        given z, and the particle weights are updated with the decoder
        log-likelihood plus the annealed (prior - proposal) log-ratio.
        Particles are resampled when the effective sample size drops.

        Args:
            input: index tensor of size (seq_len, batch_sz); also used as the
                decoder target at every step.
            targets: teacher-forcing inputs fed through ``self.dec_embedding``.
            args: only ``args.anneal`` is read here (weight of the
                prior/proposal log-ratio).
            n_particles: number of particles per batch element.
            criterion: loss applied to (decoder_logits, targets); presumably
                un-reduced (one value per row), since the result is later
                viewed as (n_particles, batch_sz) -- TODO confirm.
            test: if True, sample hard one-hot z's; otherwise relaxed samples
                with reparameterization.

        Returns:
            Tuple ``(-loss.sum(), nll, seq_len * batch_sz, resamples)`` where
            ``loss`` accumulates the per-step log normalizers, ``nll`` is the
            data NLL averaged over particles and summed over time and batch,
            and ``resamples`` counts the steps at which resampling happened.
        """

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        # the encoder's final (h, c) are not used below; h is re-initialised further down
        hidden_states, (h, c) = self.encoder(emb, hidden)

        # teacher-forcing
        out_emb = self.dropout(self.dec_embedding(targets))

        # now, we'll replicate it for each particle - it's currently [seq_len x batch_sz x nhid]
        # repeat() tiles whole copies of the batch along dim 1, so row
        # (particle * batch_sz + b) belongs to batch element b
        hidden_states = hidden_states.repeat(1, n_particles, 1)
        out_emb = out_emb.repeat(1, n_particles, 1)
        # now [seq_len x (n_particles x batch_sz) x nhid]
        # out_emb, hidden_states should be viewed as (n_particles x batch_sz) - this means that's true for h as well

        # run the z-decoder at this point, evaluating the NLL at each step
        p_h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)  # initially zero
        h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)
        d_h = self.init_hidden(batch_sz * n_particles, self.nhid, squeeze=True)

        # per-step, per-particle NLL values kept only for reporting (no gradient)
        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # start from uniform particle weights: log(1 / n_particles)
        accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}
        resamples = 0

        for i in range(seq_len):
            # proposal q(z_t | ...): recurrent update on the encoder state, then logits
            h = self.z_decoder(hidden_states[i], h)
            logits = self.logits(h)

            # build the next z sample
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
                z = q.rsample()
            # the sample itself becomes the recurrent state for the next step
            h = z

            # prior
            if test:
                p = OneHotCategorical(logits=p_h)
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=p_h)

            # now, compute the log-likelihood of the data given this mean, and the input out_emb
            d_h = self.decoder(torch.cat([z, out_emb[i]], 1), d_h)
            decoder_logits = self.out_embedding(d_h)
            # targets are tiled the same way as the particles (whole copies of the batch)
            NLL = criterion(decoder_logits, input[i].repeat(n_particles))
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + args.anneal * (f_term - r_term)

            # unnormalised log-weights, one row per particle
            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            # sample ancestors, and reindex everything
            Z = log_sum_exp(wa, dim=0)  # line 7
            # NOTE(review): leftover debugging hook -- drops into the debugger
            # whenever any log normaliser exceeds 0.1
            if (Z.data > 0.1).any():
                pdb.set_trace()

            loss += Z  # line 8
            accumulated_weights = wa - Z  # line 9
            # normalised weights used only for the resampling decision (no gradient);
            # the 0.01 added below smooths them before renormalising
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1./probs.pow(2).sum(0)

            # resample as soon as at least one batch element has ESS / n_particles < 0.3
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                resamples += 1
                # one row per batch element, n_particles ancestor indices drawn with replacement
                ancestors = torch.multinomial(probs.transpose(0, 1), n_particles, True)

                # now, reindex, which is the most important thing
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                # NOTE(review): ancestors.t() is (n_particles, batch_sz) while offsets is
                # (batch_sz, n_particles), so the sum only lines up when the two sizes
                # are equal. Also, n_particles * b + ancestor is a batch-major index,
                # whereas the rows above are laid out particle-major -- confirm the
                # intended layout.
                unrolled_idx = Variable(ancestors.t().contiguous()+offsets).view(-1)
                h = torch.index_select(h, 0, unrolled_idx)
                p_h = torch.index_select(p_h, 0, unrolled_idx)
                d_h = torch.index_select(d_h, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # build the next mean prediction, feeding in the correct ancestor
                p_h = self.ar_prior_logits(torch.cat([h, out_emb[i]], 1), p_h)

        # now, we calculate the final log-marginal estimator
        nll = nlls.view(seq_len, n_particles, batch_sz).mean(1).sum()
        return -loss.sum(), nll, (seq_len * batch_sz), resamples
示例#6
0
def test(epoch):
    """Evaluate the model on the test set and save a reconstruction comparison.

    Accumulates ``loss_function`` over every batch of ``test_loader`` without
    tracking gradients and prints the loss averaged over the dataset size.
    For the first batch, up to 8 inputs are stacked above their
    reconstructions and written to ``results/reconstruction_<epoch>.png``.

    Args:
        epoch: current epoch number; only used in the output file name.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, log_prob, entropy = model(data)
            test_loss += loss_function(recon_batch, data, log_prob, entropy).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape by the size of this batch rather than the configured
                # batch size, so a first batch smaller than args.batch_size
                # (small dataset) does not break the view.
                recon_images = recon_batch.view(data.size(0), 1, 28, 28)
                comparison = torch.cat([data[:n], recon_images[:n]])
                save_image(comparison.cpu(),
                           'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    # After every epoch, decode random codes from a uniform prior and save them as images.
    with torch.no_grad():
        m = OneHotCategorical(torch.ones(256)/256.)
        # 64 codes, each made of 20 one-hot vectors over 256 categories.
        sample = model.decode(m.sample((64, 20)).to(device)).cpu()
        images = sample.view(64, 1, 28, 28)
        save_image(images, 'results/sample_' + str(epoch) + '.png')
示例#7
0
 def sample(self, params):
     """Draw an action from the Gaussian mixture described by ``params``.

     ``params`` holds the mixture weights ``pi``, the per-component ``mean``
     and the per-component ``log_std``. One component is picked per row by a
     one-hot draw from ``pi``; every component mean is perturbed with Gaussian
     noise, and the one-hot mask keeps only the picked component when summing
     over dim 1.
     """
     weights = params['pi']
     mean = params['mean']
     log_std = params['log_std']
     # The component draw happens before the noise draw.
     component_mask = OneHotCategorical(weights).sample().unsqueeze(-1)
     noisy_means = mean + torch.randn_like(mean) * torch.exp(log_std)
     return torch.sum(noisy_means * component_mask, 1)
示例#8
0
import storch
import torch
from torch.distributions import Bernoulli, OneHotCategorical
from storch.method import RELAX, REBAR, ARM

# Fix the RNG seed so the samples drawn below are reproducible.
torch.manual_seed(0)

# Bernoulli distribution whose probability parameter is tracked for gradients.
p = torch.tensor(0.5, requires_grad=True)
d = Bernoulli(p)
# Sample from d through storch's RELAX method (presumably a gradient estimator
# for discrete samples -- see the storch documentation).
sample = RELAX("sample", in_dim=1)(d)
# sample = ARM('sample', n_samples=10)(d)
# Register the sample itself as a cost named "cost" and run storch's backward pass.
storch.add_cost(sample, "cost")
storch.backward()

# Second check: apply REBAR to a batch of four 3-way one-hot categorical
# distributions (every row of x is a probability vector summing to 1).
method = REBAR("test", n_samples=1)
x = torch.Tensor([[0.2, 0.4, 0.4], [0.5, 0.1, 0.4], [0.2, 0.2, 0.6],
                  [0.15, 0.15, 0.7]])
qx = OneHotCategorical(x)
print(method(qx))
示例#9
0
class GaussianMixture(Distribution):
    """Mixture of diagonal Gaussians with per-row mixture weights.

    NOTE(review): the indexing below suggests ``normal_means`` and
    ``normal_stds`` are (batch, dim, num_gaussians) and ``weights`` is
    (batch, num_gaussians, 1) -- confirm against the callers.
    """

    def __init__(self, normal_means, normal_stds, weights):
        """Store the parameters and build the component distributions.

        Args:
            normal_means: means of all components; component ``i`` is
                ``normal_means[:, :, i]``.
            normal_stds: standard deviations, indexed like ``normal_means``.
            weights: mixture weights; ``weights.shape[1]`` is the number of
                components and ``weights[:, :, 0]`` parameterizes the
                categorical over components.
        """
        self.num_gaussians = weights.shape[1]
        self.normal_means = normal_means
        self.normal_stds = normal_stds
        # Joint distribution over all components at once (used for sampling).
        self.normal = MultivariateDiagonalNormal(normal_means, normal_stds)
        # One distribution per component (used for the density).
        self.normals = [
            MultivariateDiagonalNormal(normal_means[:, :, i],
                                       normal_stds[:, :, i])
            for i in range(self.num_gaussians)
        ]
        self.weights = weights
        self.categorical = OneHotCategorical(self.weights[:, :, 0])

    def log_prob(
        self,
        value,
    ):
        """Return the log-density of ``value`` under the mixture.

        Computes ``log(sum_i w_i * p_i(value))``. The largest component
        log-density of each row is factored out before exponentiating, so the
        result does not collapse to ``-inf`` when every ``p_i(value)``
        underflows to zero.
        """
        log_p = torch.stack(
            [normal.log_prob(value) for normal in self.normals], -1)
        weights = self.weights[:, :, 0]

        # Per-row shift, treated as a constant: the value of the expression
        # does not depend on it, so detaching leaves the gradient unchanged.
        shift = log_p.max(dim=1, keepdim=True)[0].detach()
        # A non-finite maximum cannot be subtracted safely; such rows fall
        # back to the unshifted computation.
        shift = torch.where(torch.isfinite(shift), shift,
                            torch.zeros_like(shift))

        p = (torch.exp(log_p - shift) * weights).sum(dim=1)
        return torch.log(p) + shift.squeeze(1)

    def sample(self):
        """Draw a sample without gradients.

        Samples every component, then keeps the one selected by a one-hot
        draw from the categorical (the matmul with the one-hot column picks
        that component).
        """
        z = self.normal.sample().detach()
        c = self.categorical.sample()[:, :, None]
        s = torch.matmul(z, c)
        return torch.squeeze(s, 2)

    def rsample(self):
        """Draw a sample that is differentiable w.r.t. means and stds.

        The Gaussian part uses the reparameterization ``mean + std * eps``
        with ``eps`` drawn from a standard normal; the component choice is a
        plain (non-differentiable) categorical draw.
        """
        z = (self.normal_means + self.normal_stds * MultivariateDiagonalNormal(
            ptu.zeros(self.normal_means.size()),
            ptu.ones(self.normal_stds.size())).sample())
        z.requires_grad_()
        c = self.categorical.sample()[:, :, None]
        s = torch.matmul(z, c)
        return torch.squeeze(s, 2)

    def mle_estimate(self):
        """Return the mean of the most likely component.

        This often computes the mode of the distribution, but not always.
        """
        # One-hot mask over the components with a 1 at the largest weight.
        c = ptu.zeros(self.weights.shape[:2])
        ind = torch.argmax(self.weights, dim=1)  # [:, 0]
        c.scatter_(1, ind, 1)
        s = torch.matmul(self.normal_means, c[:, :, None])
        return torch.squeeze(s, 2)

    def __repr__(self):
        s = "GaussianMixture(normal_means=%s, normal_stds=%s, weights=%s)"
        return s % (self.normal_means, self.normal_stds, self.weights)
示例#10
0
class PPOTorchPolicy(TorchPolicy):
    """PPO policy conditioned on a one-hot skill vector.

    Holds a ``SkilledA2C`` model with separate optimizers for its policy
    (``pi``), value function (``vf``) and skill discriminator. When
    ``use_diayn`` is set, rewards are replaced (or augmented) by a
    discriminator-based skill reward in ``postprocess_trajectory``.
    """

    def __init__(self, observation_space, action_space, config):
        """Read hyperparameters from ``config`` and build the model and optimizers."""
        super().__init__(observation_space, action_space, config)
        self.device = torch.device('cpu')

        # Get hyperparameters
        self.alpha = config['alpha']
        self.clip_ratio = config['clip_ratio']
        self.gamma = config['gamma']
        self.lam = config['lambda']
        self.lr_pi = config['lr_pi']
        self.lr_vf = config['lr_vf']
        self.model_hidden_sizes = config['model_hidden_sizes']
        self.num_skills = config['num_skills']
        self.skill_input = config['skill_input']
        self.target_kl = config['target_kl']
        self.use_diayn = config['use_diayn']
        self.use_env_rewards = config['use_env_rewards']
        self.use_gae = config['use_gae']

        # Initialize actor-critic model
        # Uniform prior over skills (equal weights), with a leading batch dimension of 1.
        self.skills = OneHotCategorical(torch.ones((1, self.num_skills)))
        if self.skill_input is not None:
            # A fixed skill was requested: build its one-hot vector by hand.
            skill_vec = [0.] * (self.num_skills - 1)
            skill_vec.insert(self.skill_input, 1.)
            self.z = torch.as_tensor([skill_vec], dtype=torch.float32)
        else:
            # No fixed skill: one is sampled lazily in compute_actions.
            self.z = None
        self.model = SkilledA2C(observation_space,
                                action_space,
                                hidden_sizes=self.model_hidden_sizes,
                                skills=self.skills).to(self.device)

        # Set up optimizers for policy and value function
        self.pi_optimizer = Adam(self.model.pi.parameters(), self.lr_pi)
        self.vf_optimizer = Adam(self.model.vf.parameters(), self.lr_vf)
        # The discriminator reuses the value-function learning rate.
        self.disc_optimizer = Adam(self.model.discriminator.parameters(),
                                   self.lr_vf)

    def compute_loss_d(self, batch):
        """Discriminator loss: NLL of the active skill given the observation.

        Assumes the discriminator returns log-probabilities over skills
        (``nll_loss`` expects log-probabilities) -- TODO confirm.
        """
        obs, z = batch[SampleBatch.CUR_OBS], batch[SKILLS]
        logq_z = self.model.discriminator(obs)
        # Skills are stored one-hot; argmax recovers the class index.
        return nn.functional.nll_loss(logq_z, z.argmax(dim=-1))

    def compute_loss_pi(self, batch):
        """Clipped-surrogate PPO policy loss plus diagnostics.

        Returns:
            ``(loss_pi, pi_info)`` where ``pi_info`` holds the approximate KL
            to the old policy (``kl``), the mean policy entropy (``ent``) and
            the fraction of clipped ratios (``cf``).
        """
        obs, act, z = batch[
            SampleBatch.CUR_OBS], batch[ACTIVATIONS], batch[SKILLS]
        adv, logp_old = batch[Postprocessing.ADVANTAGES], batch[
            SampleBatch.ACTION_LOGP]
        clip_ratio = self.clip_ratio

        # Policy loss
        # The policy is conditioned on the observation concatenated with the skill.
        oz = torch.cat([obs, z], dim=-1)
        pi, logp = self.model.pi(oz, act)
        ratio = torch.exp(logp - logp_old)
        clip_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
        loss_pi = -(torch.min(ratio * adv, clip_adv)).mean()

        # Useful extra info
        approx_kl = (logp_old - logp).mean().item()
        ent = pi.entropy().mean().item()
        clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
        clip_frac = torch.as_tensor(clipped, dtype=torch.float32).mean().item()
        pi_info = dict(kl=approx_kl, ent=ent, cf=clip_frac)

        return loss_pi, pi_info

    def compute_loss_v(self, batch):
        """Clipped value-function loss against the value targets.

        NOTE(review): the prediction is computed on NEXT_OBS while the stored
        VF_PREDS come from compute_actions on the current observation --
        confirm that NEXT_OBS (rather than CUR_OBS) is intended here.
        """
        obs, z = batch[SampleBatch.NEXT_OBS], batch[SKILLS]
        v_pred_old, v_targ = batch[SampleBatch.VF_PREDS], batch[
            Postprocessing.VALUE_TARGETS]

        oz = torch.cat([obs, z], dim=-1)
        v_pred = self.model.vf(oz)
        # Keep the new prediction within clip_ratio of the old one.
        v_pred_clipped = v_pred_old + torch.clamp(
            v_pred - v_pred_old, -self.clip_ratio, self.clip_ratio)

        loss_clipped = (v_pred_clipped - v_targ).pow(2)
        loss_unclipped = (v_pred - v_targ).pow(2)

        # Element-wise worse of the two squared errors.
        return 0.5 * torch.max(loss_unclipped, loss_clipped).mean()

    def _convert_activation_to_action(self, activation):
        """Map a model activation to the action space bounds via ``tanh_to_action``."""
        min_ = self.action_space.low
        max_ = self.action_space.high
        return tanh_to_action(activation, min_, max_)

    def _normalize_obs(self, obs):
        """Normalize observations with the observation space bounds via ``normalize_obs``."""
        min_ = self.observation_space.low
        max_ = self.observation_space.high
        return normalize_obs(obs, min_, max_)

    @override(Policy)
    def compute_actions(self, obs, **kwargs):
        """Compute actions for ``obs`` under the currently active skill.

        Returns:
            ``(actions, [], extras)``; ``extras`` carries the raw
            activations, value predictions, action log-probabilities, the
            skill vector and the discriminator log-probabilities.
        """
        # Sample a skill at the start of each episode
        if self.z is None:
            self.z = self.skills.sample()

        o = self._normalize_obs(obs)
        a, v, logp_a, logq_z = self.model.step(
            torch.as_tensor(o, dtype=torch.float32), self.z)

        actions = self._convert_activation_to_action(a)
        extras = {
            ACTIVATIONS: a,
            SampleBatch.VF_PREDS: v,
            SampleBatch.ACTION_LOGP: logp_a,
            SKILLS: self.z.numpy(),
            SKILL_LOGQ: logq_z
        }
        return actions, [], extras

    @override(Policy)
    def postprocess_trajectory(self,
                               batch,
                               other_agent_batches=None,
                               episode=None):
        """Adds the policy logits, VF preds, and advantages to the trajectory."""

        completed = batch["dones"][-1]
        if completed:
            # Force end of episode reward
            last_r = 0.0

            # Reset skill at the end of each episode
            # NOTE(review): this also clears a fixed skill set from
            # config['skill_input'] -- confirm that is intended.
            self.z = None
        else:
            # NOTE(review): next_state is built but never used below.
            next_state = []
            for i in range(self.num_state_tensors()):
                next_state.append([batch["state_out_{}".format(i)][-1]])
            # Bootstrap from the value estimate of the last next-observation.
            obs = [batch[SampleBatch.NEXT_OBS][-1]]
            o = self._normalize_obs(obs)
            _, last_r, _, _ = self.model.step(
                torch.as_tensor(o, dtype=torch.float32), self.z)
            last_r = last_r.item()

        # Compute DIAYN rewards
        if self.use_diayn:
            z = torch.as_tensor(batch[SKILLS], dtype=torch.float32)
            logp_z = self.skills.log_prob(z).numpy()
            # The column is chosen from the first row's skill only, i.e. this
            # assumes every row of the batch shares one skill.
            logq_z = batch[SKILL_LOGQ][:, z.argmax(dim=-1)[0].item()]
            entropy_reg = self.alpha * batch[SampleBatch.ACTION_LOGP]
            diayn_rewards = logq_z - logp_z - entropy_reg

            if self.use_env_rewards:
                batch[SampleBatch.REWARDS] += diayn_rewards
            else:
                batch[SampleBatch.REWARDS] = diayn_rewards

        batch = compute_advantages(batch,
                                   last_r,
                                   gamma=self.gamma,
                                   lambda_=self.lam,
                                   use_gae=self.use_gae)
        return batch

    @override(Policy)
    def learn_on_batch(self, postprocessed_batch):
        """Run one gradient step each for the policy, value function and discriminator.

        Returns:
            ``{LEARNER_STATS_KEY: stats}`` with the three loss values and the
            policy diagnostics from ``compute_loss_pi``.
        """
        # Only CUR_OBS is normalized here; NEXT_OBS is left as stored.
        postprocessed_batch[SampleBatch.CUR_OBS] = self._normalize_obs(
            postprocessed_batch[SampleBatch.CUR_OBS])
        train_batch = self._lazy_tensor_dict(postprocessed_batch)

        # Train policy with a single gradient step (the KL-based early stop
        # below is disabled)
        self.pi_optimizer.zero_grad()
        loss_pi, pi_info = self.compute_loss_pi(train_batch)
        # if pi_info['kl'] > 1.5 * self.target_kl:
        #     logger.info('Early stopping at step %d due to reaching max kl.' % i)
        #     return
        loss_pi.backward()
        self.pi_optimizer.step()

        # Value function learning
        self.vf_optimizer.zero_grad()
        loss_v = self.compute_loss_v(train_batch)
        loss_v.backward()
        self.vf_optimizer.step()

        # Discriminator learning
        self.disc_optimizer.zero_grad()
        loss_d = self.compute_loss_d(train_batch)
        loss_d.backward()
        self.disc_optimizer.step()

        grad_info = dict(pi_loss=loss_pi.item(),
                         vf_loss=loss_v.item(),
                         d_loss=loss_d.item(),
                         **pi_info)
        return {LEARNER_STATS_KEY: grad_info}
示例#11
0
 def forward(self, context=None):
     """Return the one-hot categorical distribution over the current logits.

     When ``context`` is given, the logits are first recomputed from it with
     ``self.network`` and stored on ``self.logits``; otherwise the stored
     logits are reused unchanged.
     """
     if context is None:
         return OneHotCategorical(logits=self.logits)
     self.logits = self.network(context)
     return OneHotCategorical(logits=self.logits)
示例#12
0
 def forward(self, x):
     """Map ``x`` through ``self.categories_net`` and wrap the resulting logits
     in a one-hot categorical distribution."""
     return OneHotCategorical(logits=self.categories_net(x))
示例#13
0
    def forward(self, input, args, n_particles, test=False):
        """Compute the per-step NLL + KL loss of the sequence under the model.

        At every step the proposal logits over z are produced from the
        encoder state and the previous sample; the data NLL marginalizes over
        z analytically, and a KL term ties the proposal to the prior obtained
        by pushing the previous sample through the transition matrix.

        Args:
            input: index tensor of size (seq_len, batch_sz).
            args: not read in this method.
            n_particles: number of copies of the batch processed in parallel.
            test: if True, sample hard one-hot z's; otherwise relaxed samples
                with reparameterization.

        Returns:
            Tuple ``(loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0)``.
        """
        T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
        pi = nn.Softmax(dim=0)(self.pi)
        emit = self.calc_emit()

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # tile whole copies of the batch along dim 1, one copy per particle
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        z = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

        # per-step NLL values kept only for reporting (no gradient)
        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # initial prior over z, in probability space (pi is a softmax and the
        # KL term below takes its log)
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)

        for i in range(seq_len):
            # logits = self.logits(torch.cat([hidden_states[i], h], 1))
            # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
            # recurrent update: the previous logits act as the recurrent state of
            # z_decoder, the previous sample z is part of its input
            logits = self.logits(
                nn.functional.relu(
                    self.z_decoder(torch.cat([hidden_states[i], z], 1),
                                   logits)))

            # build the next z sample
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()

            # normalise the logits into log-probabilities of q
            lse = log_sum_exp(logits, dim=1).view(-1, 1)
            log_probs = logits - lse

            # now, compute the log-likelihood of the data given this z-sample
            # so emission is [batch_sz x z_dim], i.e. emission[i, j] is the log-probability of getting this
            # data for element i given choice z
            emission = F.embedding(input[i].repeat(n_particles), emit)

            # marginalise z out: -log sum_z q(z) p(x | z)
            NLL = -log_sum_exp(emission + log_probs, 1)
            nlls[i] = NLL.data
            # KL(q || prior); the 1e-16 guards the log against zero prior mass
            KL = (log_probs.exp() * (log_probs -
                                     (prior_probs + 1e-16).log())).sum(1)
            loss += (NLL + KL)

            if i != seq_len - 1:
                # next prior: apply the transition matrix to the sampled z
                # (row-wise T @ z)
                prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)

        # now, we calculate the final log-marginal estimator
        return loss.sum(), nlls.sum(), (seq_len * batch_sz * n_particles), 0
示例#14
0
    def forward(self, input, args, n_particles, test=False):
        """Particle-filter estimate of the sequence log-marginal.

        At every step each particle samples z from the proposal q, the
        decoder scores the data given z, and the particle weights are updated
        with the log-likelihood plus the (prior - proposal) log-ratio. With
        ``args.filter`` set, particles are resampled when the effective
        sample size drops.

        Args:
            input: index tensor of size (seq_len, batch_sz).
            args: only ``args.filter`` is read here (enables resampling).
            n_particles: number of particles per batch element.
            test: if True, use hard one-hot distributions; otherwise relaxed
                ones with reparameterized samples.

        Returns:
            Tuple ``(-loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles,
            resamples)`` where ``loss`` accumulates the per-step log
            normalizers and ``resamples`` counts the resampling steps.

        NOTE(review): several ``pdb.set_trace()`` debugging hooks remain in
        this method; they fire when NaNs are detected.
        """
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
        emit = self.calc_emit()

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # tile whole copies of the batch along dim 1, so row
        # (particle * batch_sz + b) belongs to batch element b
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        h = self.init_hidden(batch_sz * n_particles, self.z_dim, squeeze=True)

        # per-step, per-particle NLL values kept only for reporting (no gradient)
        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # start from uniform particle weights: log(1 / n_particles)
        accumulated_weights = -math.log(
            n_particles)  # will contain log w_{t - 1}
        resamples = 0

        # in log probability space
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)

        for i in range(seq_len):
            # logits = self.logits(nn.functional.relu(self.z_decoder(hidden_states[i], logits)))
            # recurrent update: the previous logits act as the recurrent state of
            # z_decoder, the previous sample h is part of its input
            logits = self.logits(
                nn.functional.relu(
                    self.z_decoder(torch.cat([hidden_states[i], h], 1),
                                   logits)))

            # build the next z sample
            if any_nans(logits):
                pdb.set_trace()
            if test:
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()
            # the sample itself is fed back as input at the next step
            h = z

            # prior
            if any_nans(prior_probs):
                pdb.set_trace()
            if test:
                p = OneHotCategorical(logits=prior_probs)
            else:
                p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                             logits=prior_probs)

            if any_nans(prior_probs):
                pdb.set_trace()
            if any_nans(logits):
                pdb.set_trace()

            # now, compute the log-likelihood of the data given this z-sample
            NLL = -self.decode(z, input[i].repeat(n_particles),
                               (emit, ))  # diff. w.r.t. z
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            # unnormalised log-weights, one row per particle
            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            # sample ancestors, and reindex everything
            Z = log_sum_exp(wa, dim=0)  # line 7

            loss += Z  # line 8
            accumulated_weights = wa - Z  # line 9

            if args.filter:
                # normalised weights used only for the resampling decision
                # (no gradient); the 0.01 smooths them before renormalising
                probs = accumulated_weights.data.exp()
                probs += 0.01
                probs = probs / probs.sum(0, keepdim=True)
                effective_sample_size = 1. / probs.pow(2).sum(0)

                # probs is [n_particles, batch_sz]
                # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
                # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

                # resample as soon as at least one batch element has
                # ESS / n_particles < 0.3
                if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                    resamples += 1
                    ancestors = torch.multinomial(probs.transpose(0, 1),
                                                  n_particles, True)

                    # now, reindex, which is the most important thing
                    offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                        1).repeat(1, n_particles).long()
                    if ancestors.is_cuda:
                        offsets = offsets.cuda()
                    # NOTE(review): n_particles * b + ancestor is a batch-major
                    # index, whereas the rows above are laid out particle-major
                    # -- confirm the intended layout. Only h is reindexed here.
                    unrolled_idx = Variable(ancestors + offsets).view(-1)
                    h = torch.index_select(h, 0, unrolled_idx)

                    # reset accumulated_weights
                    accumulated_weights = -math.log(
                        n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # propagate through the log-space transition matrix; the result
                # is used as the `logits` of the next step's prior.
                # NOTE(review): z (not log z) is added to log T -- confirm.
                prior_probs = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

                # let's normalize things - slower, but safer
                # prior_probs += 0.01
                # prior_probs = prior_probs / prior_probs.sum(1, keepdim=True)

            # # if ((prior_probs.sum(1) - 1) > 1e-3).any()[0]:
            #     pdb.set_trace()

        if any_nans(loss):
            pdb.set_trace()

        # now, we calculate the final log-marginal estimator
        return -loss.sum(), nlls.sum(), (seq_len * batch_sz *
                                         n_particles), resamples
示例#15
0
    def get_actions(self,
                    obs,
                    prev_actions,
                    actor_rnn_states,
                    available_actions=None,
                    use_target=False,
                    t_env=None,
                    use_gumbel=False,
                    explore=False):
        """Compute actions from the (target) actor, optionally with exploration.

        Args:
            obs: observations of shape (batch_size, obs_dim) or
                (seq_len, batch_size, obs_dim).
            prev_actions: previous actions passed to the actor, or None; must
                have the same number of dimensions as ``obs``.
            actor_rnn_states: recurrent state passed to the actor.
            available_actions: optional mask forwarded to the discrete action
                helpers (not used in the multidiscrete branch).
            use_target: use ``self.target_actor`` instead of ``self.actor``.
            t_env: environment step used to evaluate the exploration schedule.
            use_gumbel: take gumbel-softmax samples instead of the plain
                one-hot of the logits (discrete actions only).
            explore: add exploration (epsilon-greedy for discrete actions,
                Gaussian noise for continuous ones).

        Returns:
            Tuple ``(actions, new_rnn_states, eps)``; ``eps`` is the
            exploration rate when epsilon-greedy was applied, otherwise None.
            Note that the single-discrete exploration branch returns
            ``actions`` as a numpy array while the other branches return
            tensors.
        """

        assert prev_actions is None or len(obs.shape) == len(
            prev_actions.shape)
        # obs is either an array of shape (batch_size, obs_dim) or (seq_len, batch_size, obs_dim)
        if len(obs.shape) == 2:
            batch_size = obs.shape[0]
            no_sequence = True
        else:
            batch_size = obs.shape[1]
            no_sequence = False

        eps = None
        if use_target:
            actor_out, new_rnn_states = self.target_actor(
                obs, prev_actions, actor_rnn_states)
        else:
            actor_out, new_rnn_states = self.actor(obs, prev_actions,
                                                   actor_rnn_states)

        if self.discrete_action:
            if self.multidiscrete:
                # actor_out is iterated here: one set of logits per action dimension
                if use_gumbel or explore or use_target:
                    onehot_actions = list(
                        map(lambda a: gumbel_softmax(a, hard=True), actor_out))
                else:
                    onehot_actions = list(map(onehot_from_logits, actor_out))

                onehot_actions = torch.cat(onehot_actions, dim=-1)
                if explore:
                    # eps greedy exploration
                    # overrides the batch size computed above with the first
                    # dimension of obs, whatever the rank of obs
                    batch_size = obs.shape[0]
                    eps = self.exploration.eval(t_env)
                    rand_numbers = torch.rand((batch_size, 1))
                    # 1 where the random action replaces the actor's action
                    take_random = (rand_numbers < eps).int().view(-1, 1)

                    # random actions sample uniformly from action space
                    random_actions = [
                        OneHotCategorical(logits=torch.ones(
                            batch_size, self.act_dim[i])).sample()
                        for i in range(len(self.act_dim))
                    ]
                    random_actions = torch.cat(random_actions, dim=1)
                    actions = (
                        1 - take_random
                    ) * onehot_actions + take_random * random_actions
                else:
                    actions = onehot_actions
            else:
                if use_gumbel or explore or use_target:
                    onehot_actions = gumbel_softmax(
                        actor_out, available_actions,
                        hard=True)  # gumbel has a gradient
                else:
                    onehot_actions = onehot_from_logits(
                        actor_out, available_actions)  # no gradient

                if explore:
                    assert no_sequence, "Doesn't make sense to do exploration on a sequence!"
                    # eps greedy exploration
                    eps = self.exploration.eval(t_env)
                    rand_numbers = np.random.rand(batch_size, 1)
                    # random actions sample uniformly from action space
                    logits = torch.ones(batch_size, self.act_dim)
                    random_actions = avail_choose(logits,
                                                  available_actions).sample()
                    random_actions = make_onehot(random_actions, batch_size,
                                                 self.act_dim)
                    take_random = (rand_numbers < eps).astype(float)
                    # mixing is done in numpy, so the result is a numpy array
                    # detached from the graph
                    actions = (
                        1.0 - take_random) * onehot_actions.detach().cpu(
                        ).numpy() + take_random * random_actions.cpu().numpy()
                else:
                    actions = onehot_actions
        else:
            if explore:
                assert no_sequence, "Cannot do exploration on a sequence!"
                actions = gaussian_noise(actor_out.shape,
                                         self.args.act_noise_std) + actor_out
            elif use_target:
                # target actions get noise clipped to
                # [-target_noise_clip, target_noise_clip]
                target_noise = gaussian_noise(
                    actor_out.shape, self.args.target_noise_std).clamp(
                        -self.args.target_noise_clip,
                        self.args.target_noise_clip)
                actions = actor_out + target_noise
            else:
                actions = actor_out
            # # clip the actions at the bounds of the action space
            # actions = torch.max(torch.min(actions, torch.from_numpy(self.act_space.high)), torch.from_numpy(self.act_space.low))

        return actions, new_rnn_states, eps
示例#16
0
    def sampled_filter(self, input, args, n_particles, emb, hidden_states):
        """Particle-filter pass over ``input`` that accumulates per-step log-normalizers.

        At every time step a latent ``z`` is drawn from a proposal built from
        ``self.logits(hidden_states[i])``. Each particle is weighted by
        ``-NLL + log p(z) - log q(z)``, and the log-sum-exp of the particle
        weights is added to the running objective. When ``args.filter`` is
        truthy, particles are resampled whenever any batch element's relative
        effective sample size falls below 0.3.

        Args:
            input: token ids; ``input.size()`` is unpacked as
                ``(seq_len, batch_sz)``.
            args: options object; only ``args.filter`` is read here.
            n_particles: number of particles per batch element.
            emb: not used anywhere in this method.
            hidden_states: per-time-step states; tiled ``n_particles`` times
                along dim 1 before use.

        Returns:
            The 4-tuple
            ``(-loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0)``.

        Side effects:
            When ``self.training`` is true, ``backward(retain_graph=True)`` is
            called on the normalized negative objective before returning.
        """
        seq_len, batch_sz = input.size()
        T = F.log_softmax(self.T, 0)  # NOTE: in log-space
        pi = F.log_softmax(self.pi, 0)  # NOTE: in log-space
        emit = self.calc_emit()

        # Tile the states so each particle gets its own copy along dim 1.
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # Uninitialized buffer; row i is filled inside the loop.
        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # Starts as a Python float (uniform log-weight) and becomes a tensor
        # after the first step via broadcasting.
        accumulated_weights = -math.log(
            n_particles)  # will contain log w_{t - 1}
        # Counted below but never returned or logged.
        resamples = 0

        # in log probability space
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)

        for i in range(seq_len):
            # the approximate posterior comes from the same thing as before
            logits = self.logits(hidden_states[i])

            if not self.training:
                # this is crucial!!
                # Evaluation: hard one-hot samples, no gradient through z.
                p = OneHotCategorical(logits=prior_logits)
                q = OneHotCategorical(logits=logits)
                z = q.sample()
            else:
                # Training: relaxed samples so gradients flow through rsample().
                p = RelaxedOneHotCategorical(temperature=self.temp_prior,
                                             logits=prior_logits)
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                z = q.rsample()

            # now, compute the log-likelihood of the data given this z-sample
            # Rows of `emit` are looked up by token id and dotted with z;
            # presumably each row holds per-state emission log-probs — TODO confirm.
            emission = F.embedding(input[i].repeat(n_particles), emit)
            NLL = -(emission * z).sum(1)
            # NLL = -self.decode(z, input[i].repeat(n_particles), (emit,))  # diff. w.r.t. z
            nlls[i] = NLL.data

            # compute the weight using `reweight` on page (4)
            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            # The flat particle axis is viewed as (n_particles, batch_sz),
            # i.e. flat index = particle * batch_sz + batch.
            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            Z = log_sum_exp(wa, dim=0)  # line 7

            loss += Z  # line 8
            accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

            # sample ancestors, and reindex everything
            if args.filter:
                # Detached weights, smoothed by 0.01 and renormalized over
                # particles, are used only to decide on and draw the resampling.
                probs = accumulated_weights.data.exp()
                probs += 0.01
                probs = probs / probs.sum(0, keepdim=True)
                effective_sample_size = 1. / probs.pow(2).sum(0)

                # probs is [n_particles, batch_sz]
                # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
                # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

                # resample / RSAMP
                # Triggered for the whole batch as soon as one element is degenerate.
                if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                    resamples += 1
                    ancestors = torch.multinomial(probs.transpose(0, 1),
                                                  n_particles, True)

                    # now, reindex, which is the most important thing
                    # NOTE(review): the index built here is
                    # `ancestor + n_particles * batch`, while the view above
                    # implies `particle * batch_sz + batch` — confirm that the
                    # two layouts are meant to agree.
                    offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                        1).repeat(1, n_particles).long()
                    if ancestors.is_cuda:
                        offsets = offsets.cuda()
                    unrolled_idx = Variable(ancestors + offsets).view(-1)
                    z = torch.index_select(z, 0, unrolled_idx)

                    # reset accumulated_weights
                    accumulated_weights = -math.log(
                        n_particles)  # will contain log w_{t - 1}

            if i != seq_len - 1:
                # now in log-probability space
                # NOTE(review): z is a (relaxed) one-hot sample rather than a
                # log-probability, yet it is added to log-space T — confirm intent.
                prior_logits = log_sum_exp(T.unsqueeze(0) + z.unsqueeze(1), 2)

        if self.training:
            # Backward is run here, inside the method, on the per-particle,
            # per-token average of the negative objective.
            (-loss.sum() /
             (seq_len * batch_sz * n_particles)).backward(retain_graph=True)
        return -loss.sum(), nlls.sum(), seq_len * batch_sz * n_particles, 0
    def sample(self,
               obs,
               prev_acts,
               rnn_hidden_states,
               available_actions=None,
               sample_gumbel=False):
        """Sample one-hot actions from the policy and report related statistics.

        Args:
            obs: observations, forwarded unchanged to ``self.forward``.
            prev_acts: previous actions, forwarded unchanged to ``self.forward``.
            rnn_hidden_states: recurrent state, forwarded to ``self.forward``.
            available_actions: optional mask (tensor or ``np.ndarray``) where
                ``0`` marks an unavailable action. Only honoured in the
                single-discrete branch; the multi-discrete branch ignores it.
            sample_gumbel: if true, draw the sample with ``gumbel_softmax`` so
                it stays differentiable; otherwise draw a plain categorical
                sample.

        Returns:
            ``(sampled_actions, mean_action_logprobs, max_prob_actions, h_outs)``.
            ``mean_action_logprobs`` is ``sum_a p(a) * log p(a)`` with a
            trailing singleton dimension. In the multi-discrete case the
            per-head results are concatenated along the last dimension.
        """
        act_logits, h_outs = self.forward(obs, prev_acts, rnn_hidden_states)

        def _expected_logprob(probs):
            # Add 1e-6 only where a probability is exactly zero so log() stays
            # finite; the offset is detached and carries no gradient.
            eps = (probs == 0.0) * 1e-6
            logprobs = torch.log(probs + eps.float().detach())
            return (logprobs * probs).sum(dim=-1).unsqueeze(-1)

        if self.multidiscrete:
            sampled_actions = []
            mean_action_logprobs = []
            max_prob_actions = []
            for act_logit in act_logits:
                categorical = OneHotCategorical(logits=act_logit)
                mean_action_logprobs.append(
                    _expected_logprob(categorical.probs))

                if sample_gumbel:
                    # get a differentiable sample of the action
                    sampled_actions.append(
                        gumbel_softmax(act_logit, hard=True))
                else:
                    sampled_actions.append(categorical.sample())

                max_prob_actions.append(onehot_from_logits(act_logit))

            sampled_actions = torch.cat(sampled_actions, dim=-1)
            mean_action_logprobs = torch.cat(mean_action_logprobs, dim=-1)
            max_prob_actions = torch.cat(max_prob_actions, dim=-1)

            return sampled_actions, mean_action_logprobs, max_prob_actions, h_outs

        categorical = OneHotCategorical(logits=act_logits)
        mean_action_logprobs = _expected_logprob(categorical.probs)

        # Logits handed to the greedy (arg-max) helper; replaced by the masked
        # copy when masking is applied below.
        greedy_logits = act_logits

        if sample_gumbel:
            # get a differentiable sample of the action
            sampled_actions = gumbel_softmax(act_logits,
                                             available_actions,
                                             hard=True)
        elif available_actions is not None:
            if isinstance(available_actions, np.ndarray):
                available_actions = torch.from_numpy(available_actions)
            # Mask a copy: writing into `act_logits` in place would modify a
            # tensor that autograd saved for `mean_action_logprobs`, which
            # makes a later backward() through it fail.
            greedy_logits = act_logits.clone()
            greedy_logits[available_actions == 0] = -1e10
            sampled_actions = OneHotCategorical(
                logits=greedy_logits).sample()
        else:
            sampled_actions = categorical.sample()

        max_prob_actions = onehot_from_logits(greedy_logits,
                                              available_actions)
        return sampled_actions, mean_action_logprobs, max_prob_actions, h_outs
示例#18
0
文件: networks.py 项目: ag8/mrl
 def forward(self, x):
     """Map ``x`` through ``self.body`` then ``self.fc`` and return a one-hot categorical over the resulting logits."""
     features = self.body(x)
     logits = self.fc(features)
     return OneHotCategorical(logits=logits)
    def inference(
        self,
        x,
        y=None,
        temperature=None,
        n_samples=1,
        reparam=True,
        encoder_key="default",
        counts=None,
    ):
        """Sample the latent variables and score them under both models.

        The method draws ``z1 ~ q(z1 | x)``, obtains class probabilities from
        the classifier on ``z1``, picks (or is given) a label, draws
        ``z2 ~ q(z2 | z1, c)`` and evaluates the log-densities of the
        generative and variational models at those samples.

        Args:
            x: observations; ``len(x)`` is used as the batch size.
            y: optional integer labels. When ``None`` a label is sampled from
                the classifier output.
            temperature: temperature of the relaxed one-hot distribution.
                Required even when ``y`` is given.
            n_samples: number of samples requested from the ``z1`` encoder.
            reparam: whether to use reparameterized sampling.
            encoder_key: key selecting the encoder/classifier to use.
            counts: when not ``None``, the call is delegated to
                ``self.inference_defensive_sampling``.

        Returns:
            A dict of samples, distribution parameters and log-densities. The
            assertions below require ``z1`` to be 3-D and the gathered
            ``qc_z1`` to have shape ``(n_samples, n_batch)``.

        Raises:
            ValueError: if ``temperature`` is ``None``.
        """
        if temperature is None:
            raise ValueError(
                "Please provide a temperature for the relaxed OneHot distribution"
            )

        if counts is not None:
            return self.inference_defensive_sampling(
                x=x, y=y, temperature=temperature, counts=counts
            )
        n_cat = self.n_labels
        n_batch = len(x)

        # Z1 | X
        q_z1 = self.encoder_z1[encoder_key](
            x, n_samples=n_samples, reparam=reparam, squeeze=False
        )
        qz1_m = q_z1["q_m"]
        qz1_v = q_z1["q_v"]
        z1 = q_z1["latent"]
        assert z1.dim() == 3
        log_qz1_x = q_z1["dist"].log_prob(z1)
        dfs = q_z1.get("df", None)
        # Some encoder distributions return per-dimension log-probs that still
        # need to be summed over the latent dimension.
        if q_z1["sum_last"]:
            log_qz1_x = log_qz1_x.sum(-1)

        # C | Z1
        qc_z1 = self.classifier[encoder_key](z1)
        log_qc_z1 = qc_z1.log()
        qc_z1_all_probas = qc_z1
        if y is None:
            if reparam:
                cat_dist = RelaxedOneHotCategorical(
                    temperature=temperature, probs=qc_z1
                )
                ys_probs = cat_dist.rsample()
            else:
                cat_dist = OneHotCategorical(probs=qc_z1)
                ys_probs = cat_dist.sample()
            # Harden the (possibly relaxed) sample to a one-hot vector.
            ys = (ys_probs == ys_probs.max(-1, keepdim=True).values).float()
            y_int = ys.argmax(-1)
        else:
            # Allocate on z1's device rather than hard-coding CUDA so the
            # method also works on CPU and on non-default GPUs.
            ys = torch.zeros(n_batch, n_cat, dtype=torch.float, device=z1.device)
            ys.scatter_(1, y.view(-1, 1), 1)
            ys = ys.view(1, n_batch, n_cat).expand(n_samples, n_batch, n_cat)
            y_int = y.view(1, -1).expand(n_samples, n_batch)
        log_pc = self.y_prior.log_prob(y_int)
        assert y_int.unsqueeze(-1).shape == (n_samples, n_batch, 1), y_int.shape
        # Keep only the probability of the selected label.
        log_qc_z1 = torch.gather(log_qc_z1, dim=-1, index=y_int.unsqueeze(-1)).squeeze(
            -1
        )
        qc_z1 = torch.gather(qc_z1, dim=-1, index=y_int.unsqueeze(-1)).squeeze(-1)
        assert qc_z1.shape == (n_samples, n_batch)
        pc = log_pc.exp()

        # Z2 | Z1, C
        z1_y = torch.cat([z1, ys], dim=-1)
        q_z2_z1 = self.encoder_z2_z1[encoder_key](z1_y, n_samples=1, reparam=reparam)
        z2 = q_z2_z1["latent"]
        qz2_z1_m = q_z2_z1["q_m"]
        qz2_z1_v = q_z2_z1["q_v"]
        log_qz2_z1 = q_z2_z1["dist"].log_prob(z2)
        if q_z2_z1["sum_last"]:
            log_qz2_z1 = log_qz2_z1.sum(-1)

        # Generative side: p(z2), p(z1 | z2, c) and p(x | z1).
        z2_y = torch.cat([z2, ys], dim=-1)
        pz1_z2m, pz1_z2_v = self.decoder_z1_z2(z2_y)
        # The decoder's second output is used as a variance, hence the sqrt.
        log_pz1_z2 = Normal(pz1_z2m, pz1_z2_v.sqrt()).log_prob(z1).sum(-1)

        log_pz2 = Normal(torch.zeros_like(z2), torch.ones_like(z2)).log_prob(z2).sum(-1)

        px_z_loc = self.x_decoder(z1)
        log_px_z = Bernoulli(px_z_loc).log_prob(x).sum(-1)
        generative_density = log_pz2 + log_pc + log_pz1_z2 + log_px_z
        # The label term log_qc_z1 is returned separately and is not part of
        # this sum.
        variational_density = log_qz1_x + log_qz2_z1
        log_ratio = generative_density - variational_density

        variables = dict(
            z1=z1,
            ys=ys,
            z2=z2,
            qz1_m=qz1_m,
            qz1_v=qz1_v,
            qz2_z1_m=qz2_z1_m,
            qz2_z1_v=qz2_z1_v,
            pz1_z2m=pz1_z2m,
            pz1_z2_v=pz1_z2_v,
            px_z_m=px_z_loc,
            log_qz1_x=log_qz1_x,
            qc_z1=qc_z1,
            log_qc_z1=log_qc_z1,
            log_qz2_z1=log_qz2_z1,
            log_pz2=log_pz2,
            log_pc=log_pc,
            pc=pc,
            log_pz1_z2=log_pz1_z2,
            log_px_z=log_px_z,
            generative_density=generative_density,
            variational_density=variational_density,
            log_ratio=log_ratio,
            qc_z1_all_probas=qc_z1_all_probas,
            df=dfs,
        )
        return variables
示例#20
0
import storch
import torch
from torch.distributions import Bernoulli, OneHotCategorical

# Reference gradient: sample "x" with storch's `Expect` method, attach a cost
# and backpropagate. Presumably `Expect` evaluates the expectation exactly
# rather than by Monte Carlo — TODO confirm against the storch docs.
expect = storch.method.Expect("x")
# Strongly peaked categorical over six outcomes; gradients are taken w.r.t. it.
probs = torch.tensor([0.95, 0.01, 0.01, 0.01, 0.01, 0.01], requires_grad=True)
# Per-outcome values used to turn the one-hot sample into a scalar cost.
indices = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
b = OneHotCategorical(probs=probs)
z = expect.sample(b)
# Cost is 2.4 times the value of the selected outcome.
c = (2.4 * z * indices).sum(-1)
storch.add_cost(c, "no_baseline_cost")

storch.backward()

# Cloned so later backward passes cannot overwrite the reference; read by
# `eval` below as a module-level global.
expect_grad = z.grad["probs"].clone()


def eval(grads):
    """Print summary statistics comparing ``grads`` with the module-level ``expect_grad``.

    The gradient estimates are gathered along the "variance" plate. The mean
    estimate, its squared deviation from the reference, the mean squared
    error and the squared bias are printed. The squared bias is returned.
    """
    print("----------------------------------")
    per_sample = storch.gather_samples(grads, "variance")
    avg = storch.reduce_plates(per_sample, plates=["variance"])
    print("mean grad", avg)
    print("expected grad", expect_grad)

    # Squared deviation of the averaged estimate from the reference gradient.
    sq_dev = (avg - expect_grad)**2
    print("specific_diffs", sq_dev)

    sq_err = (per_sample - expect_grad)**2
    mse = storch.reduce_plates(sq_err).sum()
    print("MSE", mse)

    bias = (storch.reduce_plates(sq_dev)).sum()
    print("bias", bias)
    return bias

 def distribution(self, output_net):
     """Interpret ``output_net`` as unnormalized logits of a one-hot categorical distribution."""
     dist = OneHotCategorical(logits=output_net)
     return dist
示例#22
0
    def forward(self, input, args, n_particles, test=False):
        """Compute the NLL + KL objective with a recurrent prior over ``z``.

        Original author's note: the major difference is that a GRU now
        predicts the prior z logits instead of a linear map T; fitting this
        GRU may be hard.

        Args:
            input: token ids; ``input.size()`` is unpacked as
                ``(seq_len, batch_sz)``.
            args: not read in this method.
            n_particles: ignored — overwritten with 10 when ``test`` is true
                and with 1 otherwise.
            test: selects the particle count as described above.

        Returns:
            The 4-tuple ``(loss.sum(), NLL.sum(), seq_len * batch_sz, 0)``.
        """
        # The caller-supplied particle count is discarded here.
        if test:
            n_particles = 10
        else:
            n_particles = 1
        pi = F.log_softmax(self.pi, 0)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # Tile encoder states so every particle has its own copy along dim 1.
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        h = Variable(
            hidden_states.data.new(batch_sz * n_particles,
                                   self.hidden_size).zero_())

        # Uninitialized buffer; row i is written inside the loop.
        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # now a log-prob
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)
        # NOTE(review): state width 50 and `.cuda()` are hard-coded, so this
        # path requires a CUDA device — confirm that is intended.
        prior_h = Variable(torch.zeros(batch_sz * n_particles, 50).cuda())

        # Overwritten at the top of the first loop iteration before being read.
        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)
        feed = None

        # use dropout on the teacher-forcing
        x_emb = self.lockdrop(emb, self.dropout_x)

        for i in range(seq_len):
            # build the next z sample - not differentiable! we don't train the inference network
            logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
            z = OneHotCategorical(logits=logits).sample()

            # this should be batch_sz x x_dim
            scores = torch.mm(self.project(torch.cat([h, z], 1)),
                              self.emit.t())

            # Per-element cross entropy (no reduction) against the tiled targets.
            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            # Analytic KL(q || p) between the two categoricals, both in log-space.
            KL = (logits.exp() * (logits - prior_logits)).sum(1)
            loss += (NLL + KL)

            nlls[i] = NLL.data

            # set things up for next time
            if i != seq_len - 1:
                # The prior for step i+1 is conditioned on the current token
                # embedding (without dropout) and the embedded z sample.
                feed = torch.cat(
                    [emb[i].repeat(n_particles, 1),
                     self.z_emb(z)], 1)
                prior_h = self.z_decoder(feed, prior_h)
                prior_logits = F.log_softmax(self.project_z(prior_h), 1)
                h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        if n_particles != 1:
            # Importance-weighted combination over particles: log-mean-exp of
            # the negated per-particle losses.
            loss = -log_sum_exp(-loss.view(n_particles, batch_sz),
                                0) + math.log(n_particles)
            NLL = -log_sum_exp(
                -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                    n_particles)  # not quite accurate, but what can you do
        else:
            NLL = nlls

        # now, we calculate the final log-marginal estimator
        return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
示例#23
0
# Shape check for storch's score-function sampling without replacement.
# NOTE(review): `k` and `d_yv` are used below but are not defined in this
# snippet — they must be provided by code preceding it.
event = 2  # number of rows the low-entropy logits are repeated into (see d1)
plt_n1 = 3  # sample count for the "n1" score-function method
plt_n2 = 2  # not used in the visible part of the script

# Define swr method
swr_method = storch.method.ScoreFunctionWOR("z", k, biased=True, use_baseline=False)
normal_method1 = storch.method.ScoreFunction("n1", n_samples=plt_n1)


# Two logit vectors over four outcomes: one peaked, one close to uniform.
l_entropy = torch.tensor([-3.0, -3.0, 2, -2.0], requires_grad=True)
h_entropy = torch.tensor([-0.1, 0.1, 0.05, -0.05], requires_grad=True)

n_params = torch.tensor(0.0, requires_grad=True)


# d1 stacks `event` copies of the peaked logits; d2 is a single categorical.
d1 = OneHotCategorical(logits=l_entropy.repeat((event, 1)))
d2 = OneHotCategorical(logits=h_entropy)

# Defined here but not sampled in the visible part of the script.
dn1 = Normal(n_params, 1.0)

# k x event x |D_yv|
z_1 = swr_method.sample(d1)
# k x |D_yv|
z_2 = swr_method.sample(d2)

print("z1", z_1)
print("z2", z_2)

# The leading dimension is capped by `d_yv ** event` for the first sample and
# by `d_yv ** (event + 1)` for the second.
assert z_1.shape == (min(k, d_yv ** event), event, d_yv)
assert z_2.shape == (min(k, d_yv ** (event + 1)), d_yv)
示例#24
0
    def forward(self, input, args, n_particles, test=False):
        """Compute the NLL + KL objective using tuple-valued recurrent states.

        Original author's note: evaluation is the IWAE-10 bound.

        In test mode ``z`` is a hard one-hot sample; otherwise it is a relaxed
        reparameterized sample, so gradients can reach the inference network.

        Args:
            input: token ids; ``input.size()`` is unpacked as
                ``(seq_len, batch_sz)``.
            args: not read in this method.
            n_particles: ignored — overwritten with 10 when ``test`` is true
                and with 1 otherwise.
            test: selects the particle count and the sampling mode.

        Returns:
            The 4-tuple ``(loss.sum(), NLL.sum(), seq_len * batch_sz, 0)``.
        """
        # The caller-supplied particle count is discarded here.
        if test:
            n_particles = 10
        else:
            n_particles = 1
        pi = F.log_softmax(self.pi, 0)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # Tile encoder states so every particle has its own copy along dim 1.
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        # `h` is a pair of zero tensors; only h[0] is fed to the projection below.
        h = (Variable(
            hidden_states.data.new(batch_sz * n_particles,
                                   self.hidden_size).zero_()),
             Variable(
                 hidden_states.data.new(batch_sz * n_particles,
                                        self.hidden_size).zero_()))

        # Uninitialized buffer; row i is written inside the loop.
        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # now a log-prob
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)
        # NOTE(review): state width 50 and `.cuda()` are hard-coded, so this
        # path requires a CUDA device — confirm that is intended.
        prior_h = (Variable(torch.zeros(batch_sz * n_particles, 50).cuda()),
                   Variable(torch.zeros(batch_sz * n_particles, 50).cuda()))

        # Overwritten at the top of the first loop iteration before being read.
        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)
        feed = None

        x_emb = self.lockdrop(emb, self.dropout_x)

        for i in range(seq_len):
            # Proposal logits for this step. Unlike the sibling variants these
            # are not detached, so gradients flow into the inference network.
            logits = F.log_softmax(self.logits(hidden_states[i]), 1)

            if test:
                q = OneHotCategorical(logits=logits)
                # p = OneHotCategorical(logits=prior_logits)
                z = q.sample()
            else:
                q = RelaxedOneHotCategorical(temperature=self.temp,
                                             logits=logits)
                # p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
                z = q.rsample()

            # this should be batch_sz x x_dim
            scores = torch.mm(self.project(torch.cat([h[0], z], 1)),
                              self.emit.t())

            # Per-element cross entropy (no reduction) against the tiled targets.
            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            # KL = q.log_prob(z) - p.log_prob(z)
            # Analytic categorical KL(q || p), independent of the sampled z.
            KL = (logits.exp() * (logits - prior_logits)).sum(1)
            loss += (NLL + KL)
            # else:
            #     loss += (NLL + args.anneal * KL)

            nlls[i] = NLL.data

            # set things up for next time
            if i != seq_len - 1:
                # The prior for step i+1 is conditioned on the current token
                # embedding (without dropout) and the embedded z sample.
                feed = torch.cat(
                    [emb[i].repeat(n_particles, 1),
                     self.z_emb(z)], 1)
                prior_h = self.z_decoder(feed, prior_h)
                prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
                h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        if n_particles != 1:
            # Importance-weighted combination over particles: log-mean-exp of
            # the negated per-particle losses.
            loss = -log_sum_exp(-loss.view(n_particles, batch_sz),
                                0) + math.log(n_particles)
            NLL = -log_sum_exp(
                -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                    n_particles)  # not quite accurate, but what can you do
        else:
            NLL = nlls

        # now, we calculate the final log-marginal estimator
        return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
示例#25
0
 def __init__(self, probs: torch.Tensor, sections: Tuple):
     """Split ``probs`` along its last dimension and wrap each piece in a one-hot categorical.

     ``sections`` is stored as given and forwarded to ``torch.split``.
     """
     self._sections = sections
     pieces = torch.split(probs, sections, dim=-1)
     dists = []
     for piece in pieces:
         dists.append(OneHotCategorical(probs=piece))
     self._dists = dists
示例#26
0
    def forward(self, input, args, n_particles, test=False):
        """Particle-filter objective with resampling of the recurrent states.

        Original author's note: evaluation is the IWAE-10 bound. Unlike the
        sibling variants, this method uses ``n_particles`` exactly as passed.

        At every step a hard ``z`` is drawn from the detached proposal and
        weighted by ``-NLL + log p(z) - log q(z)``; the log-sum-exp of the
        weights is accumulated. Particles and both recurrent states are
        resampled when any batch element's relative effective sample size
        drops below 0.3.

        Args:
            input: token ids; ``input.size()`` is unpacked as
                ``(seq_len, batch_sz)``.
            args: not read in this method.
            n_particles: number of particles per batch element.
            test: when true, a debugger breakpoint is hit before the loop.

        Returns:
            The 4-tuple
            ``(loss.sum(), NLL.sum(), seq_len * batch_sz * n_particles, 0)``.
        """
        pi = F.log_softmax(self.pi, 0)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # Tile encoder states so every particle has its own copy along dim 1.
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        # `h` is a pair of zero tensors; only h[0] is fed to the projection below.
        h = (Variable(
            hidden_states.data.new(batch_sz * n_particles,
                                   self.hidden_size).zero_()),
             Variable(
                 hidden_states.data.new(batch_sz * n_particles,
                                        self.hidden_size).zero_()))

        # Uninitialized buffer; row i is written inside the loop.
        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # now a log-prob
        prior_logits = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                              self.z_dim)
        # NOTE(review): state width 50 and `.cuda()` are hard-coded, so this
        # path requires a CUDA device — confirm that is intended.
        prior_h = (Variable(torch.zeros(batch_sz * n_particles, 50).cuda()),
                   Variable(torch.zeros(batch_sz * n_particles, 50).cuda()))

        # Starts as a Python float (uniform log-weight) and becomes a tensor
        # after the first step via broadcasting.
        accumulated_weights = -math.log(
            n_particles)  # will contain log w_{t - 1}
        # Overwritten at the top of the first loop iteration before being read.
        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)
        feed = None

        x_emb = self.lockdrop(emb, self.dropout_x)

        # NOTE(review): leftover debugging hook — every test-mode call stops in
        # pdb and will hang a non-interactive run. Confirm before shipping.
        if test:
            pdb.set_trace()

        for i in range(seq_len):
            # build the next z sample - not differentiable! we don't train the inference network
            logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()

            # if test:
            q = OneHotCategorical(logits=logits)
            p = OneHotCategorical(logits=prior_logits)
            a = q.sample()
            # else:
            #     q = RelaxedOneHotCategorical(temperature=self.temp, logits=logits)
            #     p = RelaxedOneHotCategorical(temperature=self.temp_prior, logits=prior_logits)
            #     a = q.rsample()

            # to guard against being too crazy
            # NOTE(review): after this smoothing z is no longer exactly
            # one-hot; torch versions that validate distribution arguments may
            # reject it in log_prob below — confirm on the target version.
            b = a + 1e-16
            z = b / b.sum(1, keepdim=True)

            # this should be batch_sz x x_dim
            scores = torch.mm(self.project(torch.cat([h[0], z], 1)),
                              self.emit.t())

            # Per-element cross entropy (no reduction) against the tiled targets.
            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            nlls[i] = NLL.data

            f_term = p.log_prob(z)  # prior
            r_term = q.log_prob(z)  # proposal
            alpha = -NLL + (f_term - r_term)

            # The flat particle axis is viewed as (n_particles, batch_sz),
            # i.e. flat index = particle * batch_sz + batch.
            wa = accumulated_weights + alpha.view(n_particles, batch_sz)

            Z = log_sum_exp(wa, dim=0)  # line 7

            loss += Z  # line 8
            accumulated_weights = wa - Z  # F.log_softmax(wa, dim=0)  # line 9

            # Detached weights, smoothed by 0.01 and renormalized over
            # particles, are used only to decide on and draw the resampling.
            probs = accumulated_weights.data.exp()
            probs += 0.01
            probs = probs / probs.sum(0, keepdim=True)
            effective_sample_size = 1. / probs.pow(2).sum(0)

            # NOTE(review): second leftover debugging hook; drops into pdb
            # when the weights contain NaNs.
            if any_nans(probs):
                pdb.set_trace()

            # probs is [n_particles, batch_sz]
            # ancestors [2 x 15] = [[0, 0, 0, ..., 0], [0, 1, 2, 3, ...]]
            # offsets   [2 x 15] = [[0, 0, 0, ..., 0], [1, 1, 1, 1, ...]]

            # resample / RSAMP
            # Triggered for the whole batch as soon as one element is degenerate.
            if ((effective_sample_size / n_particles) < 0.3).sum() > 0:
                ancestors = torch.multinomial(probs.transpose(0, 1),
                                              n_particles, True)

                # now, reindex, which is the most important thing
                # NOTE(review): the index built here is
                # `ancestor + n_particles * batch`, while the view above
                # implies `particle * batch_sz + batch` — confirm that the two
                # layouts are meant to agree.
                offsets = n_particles * torch.arange(batch_sz).unsqueeze(
                    1).repeat(1, n_particles).long()
                if ancestors.is_cuda:
                    offsets = offsets.cuda()
                unrolled_idx = Variable(ancestors + offsets).view(-1)

                # shuffle!
                # z and both halves of each recurrent state are reindexed with
                # the same ancestor indices so they stay aligned.
                z = torch.index_select(z, 0, unrolled_idx)
                a, b = h
                h = torch.index_select(a, 0, unrolled_idx), torch.index_select(
                    b, 0, unrolled_idx)
                a, b = prior_h
                prior_h = torch.index_select(a, 0,
                                             unrolled_idx), torch.index_select(
                                                 b, 0, unrolled_idx)

                # reset accumulated_weights
                accumulated_weights = -math.log(
                    n_particles)  # will contain log w_{t - 1}

            # set things up for next time
            if i != seq_len - 1:
                # The prior for step i+1 is conditioned on the current token
                # embedding (without dropout) and the embedded z sample.
                feed = torch.cat(
                    [emb[i].repeat(n_particles, 1),
                     self.z_emb(z)], 1)
                prior_h = self.z_decoder(feed, prior_h)
                prior_logits = F.log_softmax(self.project_z(prior_h[0]), 1)
                h = self.hidden_rnn(x_emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        NLL = nlls

        # now, we calculate the final log-marginal estimator
        return loss.sum(), NLL.sum(), (seq_len * batch_sz * n_particles), 0
示例#27
0
 def prior(self, posterior: Distribution):
     """Return a uniform one-hot categorical prior shaped like ``posterior``.

     The category count is read from the last dimension of
     ``posterior.probs`` instead of being fixed at 10, so the probabilities
     sum to one for any number of categories.
     """
     n_categories = posterior.probs.shape[-1]
     uniform = torch.ones_like(posterior.probs) / n_categories
     return OneHotCategorical(probs=uniform)
示例#28
0
    def forward(self, input, args, n_particles, test=False):
        """Compute the NLL + KL objective with a transition-matrix prior in probability space.

        Original author's note: n_particles is interpreted as 1 for now to not
        screw anything up.

        Args:
            input: token ids; ``input.size()`` is unpacked as
                ``(seq_len, batch_sz)``.
            args: not read in this method.
            n_particles: ignored — overwritten with 10 when ``test`` is true
                and with 1 otherwise.
            test: selects the particle count as described above.

        Returns:
            The 4-tuple ``(loss.sum(), NLL.sum(), seq_len * batch_sz, 0)``.
        """
        # The caller-supplied particle count is discarded here.
        if test:
            n_particles = 10
        else:
            n_particles = 1
        T = nn.Softmax(dim=0)(self.T)  # NOTE: not in log-space
        pi = nn.Softmax(dim=0)(self.pi)

        # run the input and teacher-forcing inputs through the embedding layers here
        seq_len, batch_sz = input.size()
        emb = self.inp_embedding(input)
        hidden = self.init_hidden(batch_sz, self.nhid, 2)  # bidirectional
        hidden_states, (_, _) = self.encoder(emb, hidden)
        # Tile encoder states so every particle has its own copy along dim 1.
        hidden_states = hidden_states.repeat(1, n_particles, 1)

        # run the z-decoder at this point, evaluating the NLL at each step
        h = Variable(
            hidden_states.data.new(batch_sz * n_particles,
                                   self.hidden_size).zero_())

        # Uninitialized buffer; row i is written inside the loop.
        nlls = hidden_states.data.new(seq_len, batch_sz * n_particles)
        loss = 0

        # Prior over z as probabilities (not log-probabilities), one row per particle.
        prior_probs = pi.unsqueeze(0).expand(batch_sz * n_particles,
                                             self.z_dim)

        # Overwritten at the top of the first loop iteration before being read.
        logits = self.init_hidden(batch_sz * n_particles,
                                  self.z_dim,
                                  squeeze=True)
        feed = None

        for i in range(seq_len):
            # build the next z sample - not differentiable! we don't train the inference network
            logits = F.log_softmax(self.logits(hidden_states[i]), 1).detach()
            z = OneHotCategorical(logits=logits).sample()

            # this should be batch_sz x x_dim
            feed = self.project(torch.cat([h, z], 1))  # batch_sz x hidden_dim
            scores = torch.mm(feed, self.emit.t())  # batch_sz x x_dim

            # Per-element cross entropy (no reduction) against the tiled targets.
            NLL = nn.CrossEntropyLoss(reduce=False)(
                scores, input[i].repeat(n_particles))
            # Analytic KL(q || p); 1e-16 keeps the log finite when a prior
            # probability is zero.
            KL = (logits.exp() * (logits - (prior_probs + 1e-16).log())).sum(1)
            loss += (NLL + KL)

            nlls[i] = NLL.data

            # set things up for next time
            if i != seq_len - 1:
                # With a one-hot z this selects, for each row, the column of T
                # indexed by the sampled state.
                prior_probs = (T.unsqueeze(0) * z.unsqueeze(1)).sum(2)
                h = self.hidden_rnn(emb[i].repeat(n_particles, 1),
                                    h)  # feed the next word into the RNN

        if n_particles != 1:
            # Importance-weighted combination over particles: log-mean-exp of
            # the negated per-particle losses.
            loss = -log_sum_exp(-loss.view(n_particles, batch_sz),
                                0) + math.log(n_particles)
            NLL = -log_sum_exp(
                -nlls.view(seq_len, n_particles, batch_sz), 1) + math.log(
                    n_particles)  # not quite accurate, but what can you do
        else:
            NLL = nlls

        # now, we calculate the final log-marginal estimator
        return loss.sum(), NLL.sum(), (seq_len * batch_sz), 0
示例#29
0
 def forward(self, x):
     """Run ``x`` through ``self.network`` and return a one-hot categorical over the output logits."""
     logits = self.network(x)
     dist = OneHotCategorical(logits=logits)
     return dist
示例#30
0
 def kl_categorical(self, logits_q):
     """Return the analytic KL(q || p) between the categorical given by ``logits_q`` and the prior.

     The prior logits ``self.logits_p`` are expanded to the shape of
     ``logits_q`` before the divergence is evaluated.
     """
     prior_logits = self.logits_p.expand_as(logits_q)
     posterior = OneHotCategorical(logits=logits_q)
     prior = OneHotCategorical(logits=prior_logits)
     return kl_divergence(posterior, prior)