Ejemplo n.º 1
0
Archivo: crf.py Proyecto: zysite/post
    def get_logZ(self, emit, mask):
        """Return the CRF log-partition function, summed over the batch.

        Runs the forward algorithm in log space. ``self.strans`` is added at
        the first step, ``self.trans`` (indexed ``[previous, current]``) at
        every transition, and ``self.etrans`` after the last step.

        Args:
            emit: emission scores, unpacked as (T, B, N).
            mask: per-step mask; ``mask[t]`` selects the batch entries that are
                advanced at step ``t`` (presumably a bool tensor of shape
                (T, B) -- confirm against callers). Entries not selected keep
                their scores from the previous step.

        Returns:
            Scalar tensor: the sum over the batch of log Z.
        """
        steps, _, _ = emit.shape

        # Running log scores per (batch, tag) after the first step: (B, N).
        log_alpha = self.strans + emit[0]
        # Transition scores broadcast over the batch: (1, N, N).
        transitions = self.trans.unsqueeze(0)

        for step in range(1, steps):
            # candidates[b, prev, cur] =
            #     trans[prev, cur] + emit[step, b, cur] + log_alpha[b, prev]
            candidates = (transitions + emit[step].unsqueeze(1)
                          + log_alpha.unsqueeze(2))
            # Marginalise out the previous tag -> (B, N).
            updated = torch.logsumexp(candidates, dim=1)
            # Only the selected batch entries take the new scores.
            active = mask[step].unsqueeze(1).expand_as(log_alpha)
            log_alpha[active] = updated[active]

        return torch.logsumexp(log_alpha + self.etrans, dim=1).sum()
Ejemplo n.º 2
0
 def loss(self, predict, target):
     """Return the length-normalised negative log-likelihood of ``target``.

     An ``HMM`` is built from the predicted initial, transition and emission
     parameters in ``predict``. Its ``p_x`` scores ``target`` (ignoring
     ``PAD_IDX``), and the scores are reduced with ``logsumexp`` over the last
     axis, averaged, negated and divided by ``target.shape[0]``.
     """
     assert isinstance(predict, HMMState)
     normaliser = target.shape[0]
     num_states = predict.init.shape[-1]
     model = HMM(num_states, self.trg_vocab_size, predict.init, predict.trans,
                 predict.emiss)
     log_probs = model.p_x(target, ignore_index=PAD_IDX)
     return -torch.logsumexp(log_probs, dim=-1).mean() / normaliser
Ejemplo n.º 3
0
def get_thermo_loss_different_samples(generative_model,
                                      inference_network,
                                      obs,
                                      partition=None,
                                      num_particles=1,
                                      integration='left'):
    """Thermo loss gradient estimator computed using two sets of importance
    samples.

    Two independent sets of ``num_particles`` samples are drawn. Set 0
    supplies three quantities:

    - the detached heated normalized weights,
    - the log weights they are paired with,
    - the thermodynamic log-density
      ``partition * log_p + (1 - partition) * log_q``.

    Set 1 supplies only the baseline ``sum_k w_k * log_weight_k`` subtracted
    in the covariance term. Each weighted summand must refer to a single
    particle of set 0, so set-0 log weights are used wherever they multiply
    set-0 weights.

    Args:
        generative_model: models.GenerativeModel object
        inference_network: models.InferenceNetwork object
        obs: tensor of shape [batch_size]
        partition: partition of [0, 1];
            tensor of shape [num_partitions + 1] where partition[0] is zero and
            partition[-1] is one;
            see https://en.wikipedia.org/wiki/Partition_of_an_interval
            (required despite the ``None`` default).
        num_particles: int
        integration: 'left', 'right' or 'trapz' rule used to weight the
            partition points.

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
        elbo: average elbo over data, i.e. the batch mean of
            logsumexp(log_weight) - log(num_particles) for the second sample
            set.

    Raises:
        ValueError: if ``integration`` is not 'left', 'right' or 'trapz'.
    """
    # Validate before drawing samples. An unknown rule would otherwise leave
    # the multiplier all zeros, i.e. a silent zero loss.
    if integration not in ('left', 'right', 'trapz'):
        raise ValueError(
            "integration must be 'left', 'right' or 'trapz', got {!r}".format(
                integration))

    log_weights, log_ps, log_qs, heated_normalized_weights = [], [], [], []
    for _ in range(2):
        log_weight, log_p, log_q = get_log_weight_log_p_log_q(
            generative_model, inference_network, obs, num_particles)
        log_weights.append(log_weight)
        log_ps.append(log_p)
        log_qs.append(log_q)

        # One heated copy of the log weights per partition point, normalized
        # over dim=1 (the particle dimension, presumably -- confirm).
        heated_log_weight = log_weight.unsqueeze(-1) * partition
        heated_normalized_weights.append(
            util.exponentiate_and_normalize(heated_log_weight, dim=1))

    # Set 0: weights, log weights and thermodynamic log-density, all paired
    # particle by particle.
    log_weight_0 = log_weights[0]
    w_detached = heated_normalized_weights[0].detach()
    thermo_logp = partition * log_ps[0].unsqueeze(-1) + \
        (1 - partition) * log_qs[0].unsqueeze(-1)
    # Set 1: independent weighted log weights used only for the baseline.
    wf = heated_normalized_weights[1] * log_weights[1].unsqueeze(-1)

    if num_particles == 1:
        correction = 1
    else:
        correction = num_particles / (num_particles - 1)

    # Covariance term: sum_k w_k * (f_k - baseline) * thermo_logp_k. Both w
    # and the centred f are detached, so gradients flow only through
    # thermo_logp.
    thing_to_add = correction * torch.sum(
        w_detached *
        (log_weight_0.unsqueeze(-1) -
         torch.sum(wf, dim=1, keepdim=True)).detach() * thermo_logp,
        dim=1)

    # Quadrature weights of the chosen Riemann rule over the partition.
    multiplier = torch.zeros_like(partition)
    if integration == 'trapz':
        multiplier[0] = 0.5 * (partition[1] - partition[0])
        multiplier[1:-1] = 0.5 * (partition[2:] - partition[0:-2])
        multiplier[-1] = 0.5 * (partition[-1] - partition[-2])
    elif integration == 'left':
        multiplier[:-1] = partition[1:] - partition[:-1]
    else:  # 'right'
        multiplier[1:] = partition[1:] - partition[:-1]

    loss = -torch.mean(
        torch.sum(multiplier *
                  (thing_to_add +
                   torch.sum(w_detached * log_weight_0.unsqueeze(-1), dim=1)),
                  dim=1))

    # The evidence estimate stays on the second sample set; either set yields
    # a valid estimate.
    log_evidence = torch.logsumexp(log_weights[1], dim=1) - np.log(num_particles)
    elbo = torch.mean(log_evidence)

    return loss, elbo
Ejemplo n.º 4
0
 def mixing_log_prob(self):
     """Return ``self._mixing_logits`` normalised into log-probabilities.

     The log-sum-exp over axis 1 (kept for broadcasting) is subtracted, i.e.
     a log-softmax along that axis.
     """
     logits = self._mixing_logits
     log_norm = torch.logsumexp(logits, 1, keepdim=True)
     return logits - log_norm
Ejemplo n.º 5
0
 def E(self, z):
     """Return the energy of each row of ``z``.

     E(z) = 0.5 * ((||z|| - 2) / 0.4)**2
            - log(exp(-0.5 * ((z0 - 2) / 0.6)**2)
                  + exp(-0.5 * ((z0 + 2) / 0.6)**2))

     The first term is lowest on the circle ||z|| = 2; the second is lowest
     near z0 = +2 and z0 = -2. ``z`` is indexed as ``z[:, 0]``, so rows are
     points and the norm is taken over dim=1.
     """
     ring = -0.5 * ((torch.norm(z, dim=1) - 2) / 0.4) ** 2
     first_coord = z[:, 0]
     two_modes = torch.stack((
         -0.5 * ((first_coord - 2) / 0.6) ** 2,
         -0.5 * ((first_coord + 2) / 0.6) ** 2,
     ))
     return -torch.logsumexp(two_modes, dim=0) - ring
Ejemplo n.º 6
0
    def forward(self, qk, v):
        """LSH (locality-sensitive hashing) self-attention forward pass.

        ``qk`` is used as both queries and keys. The steps are:

        1. Positions are hashed into buckets by ``self.hash_vectors`` and
           sorted by (bucket, position).
        2. The sorted items are split into ``self.n_hashes * n_bins`` chunks.
        3. Each chunk attends to itself and to the chunk before it.
        4. The results are un-sorted. When ``self.n_hashes > 1``, the hashing
           rounds are combined with weights given by a softmax, over rounds,
           of each round's attention log-sum-exp.

        Args:
            qk: shared query/key tensor, unpacked as (batch_size, seqlen, _).
            v: value tensor; the output is asserted to have ``v``'s shape.

        Returns:
            (out, buckets): the attention output and the bucket ids from
            ``self.hash_vectors``. ``buckets`` is asserted to have
            ``n_hashes * seqlen`` entries along dim 1.
        """
        batch_size, seqlen, _ = qk.shape
        device = qk.device

        # NOTE(review): the reshapes below appear to require seqlen to be a
        # multiple of self.bucket_size -- confirm.
        n_buckets = seqlen // self.bucket_size
        n_bins = n_buckets

        buckets = self.hash_vectors(n_buckets, qk)
        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        # ticker enumerates all items of all hash rounds. buckets_and_t orders
        # items by bucket first, then by position within the sequence.
        ticker = torch.arange(self.n_hashes * seqlen,
                              device=device).unsqueeze(0)
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = buckets_and_t.detach()

        # Hash-based sort ("s" at the start of variable names means "sorted")
        # sticker[j] is the unsorted index of the j-th sorted item. Sorting
        # sticker gives undo_sort, its inverse permutation (assuming
        # sort_key_val sorts keys along dim and permutes values alongside).
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)
        _, undo_sort = sort_key_val(sticker, ticker, dim=-1)

        sbuckets_and_t = sbuckets_and_t.detach()
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        # Sequence position of every sorted item; gathers qk and v rows.
        st = (sticker % seqlen)
        sqk = batched_index_select(qk, st)
        sv = batched_index_select(v, st)

        # Split off a "bin" axis so that attention only occurs within chunks.
        bq_t = bkv_t = torch.reshape(st,
                                     (batch_size, self.n_hashes * n_bins, -1))
        bqk = torch.reshape(
            sqk, (batch_size, self.n_hashes * n_bins, -1, sqk.shape[-1]))
        bv = torch.reshape(
            sv, (batch_size, self.n_hashes * n_bins, -1, sv.shape[-1]))
        # Bucket id of every sorted item, chunked the same way.
        bq_buckets = bkv_buckets = torch.reshape(
            sbuckets_and_t // seqlen, (batch_size, self.n_hashes * n_bins, -1))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = make_unit_length(bqk)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        def look_one_back(x):
            # x_extra[:, i] is chunk i - 1; chunk 0 wraps around to the last one.
            x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
            return torch.cat([x, x_extra], dim=2)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)
        bkv_buckets = look_one_back(bkv_buckets)

        # Dot-product attention.
        # Scaled scores indexed [batch, chunk, query item, key item].
        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (bq.shape[-1]**-0.5)

        # Causal masking
        # A query may not attend to keys at later sequence positions.
        if self.causal:
            mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
            dots.masked_fill_(mask, float('-inf'))

        # Mask out attention to self except when no other targets are available.
        # A large finite penalty (not -inf) keeps self as a fallback target.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots.masked_fill_(self_mask, -1e5)

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :,
                                                                   None, :]
            dots.masked_fill_(bucket_mask, float('-inf'))

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurence of each query-key pair.
        # NOTE(review): only strategy (1) is implemented here; there is no
        # hard_k in this method.
        if not self._allow_duplicate_attention:
            # Chunk index of each item (and of the chunk after it) in every
            # hash round.
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % (self.n_hashes * n_bins)
            if not self._attend_across_buckets:
                locs1 = buckets * (self.n_hashes * n_bins) + locs1
                locs2 = buckets * (self.n_hashes * n_bins) + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, self.n_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, self.n_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(
                slocs,
                (batch_size, self.n_hashes * n_bins, -1, 2 * self.n_hashes))

            b_locs1 = b_locs[:, :, :, None, :self.n_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, self.n_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :,
                                                                None, :, :])
            # for memory considerations, chunk summation of last dimension for counting duplicates
            dup_counts = chunked_sum(dup_counts,
                                     chunks=(self.n_hashes * batch_size))
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            # Subtract log(count); the 1e-9 avoids log(0).
            dots = dots - torch.log(dup_counts + 1e-9)

        # Softmax.
        # dots_logsumexp is kept to weight the hash rounds further down.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp)
        dots = self.dropout(dots)

        bo = torch.einsum('buij,buje->buie', dots, bv)
        so = torch.reshape(bo, (batch_size, -1, bo.shape[-1]))
        slogits = torch.reshape(dots_logsumexp, (
            batch_size,
            -1,
        ))

        # Custom autograd function; it closes over sticker, undo_sort and
        # buckets_and_t from this call. forward puts outputs (via undo_sort)
        # and logits (by sorting on sticker) back into the original order.
        # backward applies the sorted order to the incoming gradients.
        class UnsortLogits(Function):
            @staticmethod
            def forward(ctx, so, slogits):
                so = so.detach()
                slogits = slogits.detach()
                o = batched_index_select(so, undo_sort)
                _, logits = sort_key_val(sticker, slogits, dim=-1)
                return o, logits

            @staticmethod
            def backward(ctx, grad_x, grad_y):
                so_grad = batched_index_select(grad_x, sticker)
                _, slogits_grad = sort_key_val(buckets_and_t, grad_y, dim=-1)
                return so_grad, slogits_grad

        o, logits = UnsortLogits.apply(so, slogits)

        if self.n_hashes == 1:
            out = o
        else:
            # Weight each round's output by the softmax (over rounds, dim=1)
            # of its attention log-sum-exp.
            # NOTE(review): `keepdims` relies on PyTorch's numpy-style alias
            # for `keepdim` -- confirm on the targeted PyTorch version.
            o = torch.reshape(o,
                              (batch_size, self.n_hashes, seqlen, o.shape[-1]))
            logits = torch.reshape(logits,
                                   (batch_size, self.n_hashes, seqlen, 1))
            probs = torch.exp(logits -
                              torch.logsumexp(logits, dim=1, keepdims=True))
            out = torch.sum(o * probs, dim=1)

        assert out.shape == v.shape
        return out, buckets
def logsumexp_tensors(tensors: List[torch.Tensor]) -> torch.Tensor:
    """Element-wise log-sum-exp across a list of tensors.

    The tensors are stacked along a new leading axis, which is then reduced,
    so the result has the (common) shape of each input.
    """
    return torch.stack(tensors, dim=0).logsumexp(dim=0)
Ejemplo n.º 8
0
    def __init__(self,
                 data=None,
                 xlim=None,
                 ylim=None,
                 nbin_per_side=None,
                 prior_count=None):
        """
        Build a 2D histogram model of the data

        :param data: [(n,2) tensor] data to model
        :param xlim: [list of 2 ints] (xmin,xmax); range of x-dimension
        :param ylim: [list of 2 ints] (ymin,ymax); range of y-dimension
        :param nbin_per_side: [int] number of bins per dimension
        :param prior_count: [float] prior count added to every cell of the
                            histogram returned by ``myhist3`` (default 0).
                            NOTE(review): an earlier docstring said the prior
                            is not added to edge cells, but the code adds it
                            everywhere -- confirm the intent.

        If ``data``, ``xlim``, ``ylim`` and ``nbin_per_side`` are all None, the
        instance is left empty (its properties are set later via
        ``set_properties``). Otherwise this sets ``logpYX`` (log-probability
        per cell, normalised to sum to one), ``xlab``/``ylab`` (bin edges),
        ``xlim``/``ylim`` (first and last edge), ``rg_bin`` and
        ``prior_count``.
        """
        # if params are empty, return; model properties will be set later
        # using 'set_properties' method
        params = [data, xlim, ylim, nbin_per_side]
        if all(item is None for item in params):
            return

        # set default value of 0 for prior_count
        if prior_count is None:
            prior_count = 0.

        ndata, dim = data.shape
        assert len(xlim) == 2
        assert len(ylim) == 2
        assert dim == 2

        # compute the "edges" of the histogram
        xtick = torch.linspace(xlim[0], xlim[1], nbin_per_side + 1)
        ytick = torch.linspace(ylim[0], ylim[1], nbin_per_side + 1)
        assert len(xtick) - 1 == nbin_per_side
        assert len(ytick) - 1 == nbin_per_side
        edges = [xtick, ytick]

        # side length of a bin along each axis, in the units of xlim / ylim
        rg_bin = torch.tensor([(xlim[1] - xlim[0]), (ylim[1] - ylim[0])],
                              dtype=torch.float32)
        rg_bin = rg_bin / nbin_per_side

        # Compute the histogram
        N = myhist3(data, edges)
        diff = ndata - torch.sum(N)
        if diff > 0:
            warnings.warn('%i position points are out of bounds' % diff)

        # Add in the prior counts
        N = torch.transpose(N, 0, 1)
        N = N + prior_count
        logN = torch.log(N)

        # Convert to probability distribution.
        # reshape, not view: elementwise ops on the transposed N can keep its
        # non-contiguous strides, and .view(-1) raises on such tensors.
        logpN = logN - torch.logsumexp(logN.reshape(-1), 0)
        assert aeq(torch.sum(torch.exp(logpN)), torch.tensor(1.))

        self.logpYX = logpN
        self.xlab = xtick
        self.xlim = xtick[[0, -1]]
        self.ylab = ytick
        self.ylim = ytick[[0, -1]]
        self.rg_bin = rg_bin
        self.prior_count = prior_count
Ejemplo n.º 9
0
    def train_from_torch(self, batch):
        """Run one training step (SAC-style actor/critic plus the CQL term).

        The updates happen in this order:

        1. the entropy temperature (if ``use_automatic_entropy_tuning``);
        2. alpha_prime (if ``with_lagrange``);
        3. Q-function 1, then Q-function 2 (if ``num_qs > 1``);
        4. the policy.

        The target networks are then soft-updated. When
        ``_need_to_update_eval_statistics`` is set, statistics for this batch
        are recorded.

        Args:
            batch: mapping with keys 'rewards', 'terminals', 'observations',
                'actions' and 'next_observations' (presumably torch tensors --
                confirm against the data loader).
        """
        self._current_epoch += 1
        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']
        """
        Policy and Alpha Loss
        """
        new_obs_actions, policy_mean, policy_log_std, log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )

        if self.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha *
                           (log_pi + self.target_entropy).detach()).mean()
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        if self.num_qs == 1:
            q_new_actions = self.qf1(obs, new_obs_actions)
        else:
            q_new_actions = torch.min(
                self.qf1(obs, new_obs_actions),
                self.qf2(obs, new_obs_actions),
            )

        policy_loss = (alpha * log_pi - q_new_actions).mean()

        # While the per-call counter is below policy_eval_start, the policy is
        # trained by behaviour cloning (log-likelihood of dataset actions).
        if self._current_epoch < self.policy_eval_start:
            """
            For the initial few epochs, try doing behaivoral cloning, if needed
            conventionally, there's not much difference in performance with having 20k 
            gradient steps here, or not having it
            """
            policy_log_prob = self.policy.log_prob(obs, actions)
            policy_loss = (alpha * log_pi - policy_log_prob).mean()
        """
        QF Loss
        """
        q1_pred = self.qf1(obs, actions)
        # NOTE(review): q2_pred exists only when num_qs > 1, yet the CQL part
        # below uses it (and self.qf2) unconditionally, so num_qs == 1 would
        # raise NameError there.
        if self.num_qs > 1:
            q2_pred = self.qf2(obs, actions)

        new_next_actions, _, _, new_log_pi, *_ = self.policy(
            next_obs,
            reparameterize=True,
            return_log_prob=True,
        )
        new_curr_actions, _, _, new_curr_log_pi, *_ = self.policy(
            obs,
            reparameterize=True,
            return_log_prob=True,
        )

        if not self.max_q_backup:
            if self.num_qs == 1:
                target_q_values = self.target_qf1(next_obs, new_next_actions)
            else:
                target_q_values = torch.min(
                    self.target_qf1(next_obs, new_next_actions),
                    self.target_qf2(next_obs, new_next_actions),
                )

            # Soft (entropy-regularised) backup unless deterministic_backup.
            if not self.deterministic_backup:
                target_q_values = target_q_values - alpha * new_log_pi

        if self.max_q_backup:
            """when using max q backup"""
            # Max over 10 sampled next actions per target network, then the
            # min of the two (uses target_qf2 regardless of num_qs).
            next_actions_temp, _ = self._get_policy_actions(
                next_obs, num_actions=10, network=self.policy)
            target_qf1_values = self._get_tensor_values(
                next_obs, next_actions_temp,
                network=self.target_qf1).max(1)[0].view(-1, 1)
            target_qf2_values = self._get_tensor_values(
                next_obs, next_actions_temp,
                network=self.target_qf2).max(1)[0].view(-1, 1)
            target_q_values = torch.min(target_qf1_values, target_qf2_values)

        # Bellman target; detached so it acts as a fixed regression target.
        q_target = self.reward_scale * rewards + (
            1. - terminals) * self.discount * target_q_values
        q_target = q_target.detach()

        qf1_loss = self.qf_criterion(q1_pred, q_target)
        if self.num_qs > 1:
            qf2_loss = self.qf_criterion(q2_pred, q_target)

        ## add CQL
        # Q-values at num_random random, current-policy and next-obs-policy
        # actions per state. Shapes are presumably
        # [batch, num_random, 1] -- confirm _get_tensor_values.
        random_actions_tensor = ptu.uniform(
            (q2_pred.shape[0] * self.num_random, actions.shape[-1]))
        curr_actions_tensor, curr_log_pis = self._get_policy_actions(
            obs, num_actions=self.num_random, network=self.policy)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions(
            next_obs, num_actions=self.num_random, network=self.policy)
        q1_rand = self._get_tensor_values(obs,
                                          random_actions_tensor,
                                          network=self.qf1)
        q2_rand = self._get_tensor_values(obs,
                                          random_actions_tensor,
                                          network=self.qf2)
        q1_curr_actions = self._get_tensor_values(obs,
                                                  curr_actions_tensor,
                                                  network=self.qf1)
        q2_curr_actions = self._get_tensor_values(obs,
                                                  curr_actions_tensor,
                                                  network=self.qf2)
        # NOTE(review): actions sampled at next_obs are evaluated at obs here
        # -- confirm this is intended.
        q1_next_actions = self._get_tensor_values(obs,
                                                  new_curr_actions_tensor,
                                                  network=self.qf1)
        q2_next_actions = self._get_tensor_values(obs,
                                                  new_curr_actions_tensor,
                                                  network=self.qf2)

        cat_q1 = torch.cat(
            [q1_rand,
             q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1)
        cat_q2 = torch.cat(
            [q2_rand,
             q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1)
        # Stds are taken before min_q_version 3 replaces cat_q1 / cat_q2.
        std_q1 = torch.std(cat_q1, dim=1)
        std_q2 = torch.std(cat_q2, dim=1)

        if self.min_q_version == 3:
            # importance sammpled version
            # log(0.5 ** action_dim): log-density of a uniform distribution on
            # [-1, 1]^action_dim. Presumably ptu.uniform samples from that
            # range -- confirm.
            random_density = np.log(0.5**curr_actions_tensor.shape[-1])
            cat_q1 = torch.cat([
                q1_rand - random_density, q1_next_actions -
                new_log_pis.detach(), q1_curr_actions - curr_log_pis.detach()
            ], 1)
            cat_q2 = torch.cat([
                q2_rand - random_density, q2_next_actions -
                new_log_pis.detach(), q2_curr_actions - curr_log_pis.detach()
            ], 1)

        # CQL regulariser: temp * logsumexp(Q / temp) over the sampled
        # actions, averaged over the batch and scaled by min_q_weight ...
        min_qf1_loss = torch.logsumexp(
            cat_q1 / self.temp,
            dim=1,
        ).mean() * self.min_q_weight * self.temp
        min_qf2_loss = torch.logsumexp(
            cat_q2 / self.temp,
            dim=1,
        ).mean() * self.min_q_weight * self.temp
        """Subtract the log likelihood of data"""
        # ... minus the (weighted) mean Q-value of the dataset actions.
        min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
        min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight

        if self.with_lagrange:
            # alpha_prime = exp(log_alpha_prime), clamped to [0, 1e6]; it is
            # updated to maximise alpha_prime * (regulariser - target gap).
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(),
                                      min=0.0,
                                      max=1000000.0)
            min_qf1_loss = alpha_prime * (min_qf1_loss -
                                          self.target_action_gap)
            min_qf2_loss = alpha_prime * (min_qf2_loss -
                                          self.target_action_gap)

            self.alpha_prime_optimizer.zero_grad()
            alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
            # retain_graph: the Q losses below backpropagate through the
            # same graph.
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_optimizer.step()

        qf1_loss = qf1_loss + min_qf1_loss
        qf2_loss = qf2_loss + min_qf2_loss
        """
        Update networks
        """
        # Update the Q-functions iff
        # NOTE(review): the Q optimizers step before policy_loss.backward(),
        # and policy_loss depends on qf1/qf2 outputs. On PyTorch versions
        # that check tensors saved for backward for in-place modification,
        # this backward may raise -- confirm.
        self._num_q_update_steps += 1
        self.qf1_optimizer.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.qf1_optimizer.step()

        if self.num_qs > 1:
            self.qf2_optimizer.zero_grad()
            qf2_loss.backward(retain_graph=True)
            self.qf2_optimizer.step()

        self._num_policy_update_steps += 1
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()
        """
        Soft Updates
        """
        ptu.soft_update_from_to(self.qf1, self.target_qf1,
                                self.soft_target_tau)
        if self.num_qs > 1:
            ptu.soft_update_from_to(self.qf2, self.target_qf2,
                                    self.soft_target_tau)
        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False
            """
            Eval should set this to None.
            This way, these statistics are only computed for one batch.
            """
            # NOTE(review): this overwrites policy_loss for logging only; the
            # logged value omits alpha (and ignores the behaviour-cloning
            # variant), so it differs from the loss that was optimised.
            policy_loss = (log_pi - q_new_actions).mean()

            self.eval_statistics['QF1 Loss'] = np.mean(ptu.get_numpy(qf1_loss))
            self.eval_statistics['min QF1 Loss'] = np.mean(
                ptu.get_numpy(min_qf1_loss))
            if self.num_qs > 1:
                self.eval_statistics['QF2 Loss'] = np.mean(
                    ptu.get_numpy(qf2_loss))
                self.eval_statistics['min QF2 Loss'] = np.mean(
                    ptu.get_numpy(min_qf2_loss))

            if not self.discrete:
                self.eval_statistics['Std QF1 values'] = np.mean(
                    ptu.get_numpy(std_q1))
                self.eval_statistics['Std QF2 values'] = np.mean(
                    ptu.get_numpy(std_q2))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 in-distribution values',
                        ptu.get_numpy(q1_curr_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 in-distribution values',
                        ptu.get_numpy(q2_curr_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 random values',
                        ptu.get_numpy(q1_rand),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 random values',
                        ptu.get_numpy(q2_rand),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF1 next_actions values',
                        ptu.get_numpy(q1_next_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'QF2 next_actions values',
                        ptu.get_numpy(q2_next_actions),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict('actions',
                                              ptu.get_numpy(actions)))
                self.eval_statistics.update(
                    create_stats_ordered_dict('rewards',
                                              ptu.get_numpy(rewards)))

            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics[
                'Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['Policy Loss'] = np.mean(
                ptu.get_numpy(policy_loss))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q1 Predictions',
                    ptu.get_numpy(q1_pred),
                ))
            if self.num_qs > 1:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Q2 Predictions',
                        ptu.get_numpy(q2_pred),
                    ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Q Targets',
                    ptu.get_numpy(q_target),
                ))
            self.eval_statistics.update(
                create_stats_ordered_dict(
                    'Log Pis',
                    ptu.get_numpy(log_pi),
                ))
            if not self.discrete:
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy mu',
                        ptu.get_numpy(policy_mean),
                    ))
                self.eval_statistics.update(
                    create_stats_ordered_dict(
                        'Policy log std',
                        ptu.get_numpy(policy_log_std),
                    ))

            if self.use_automatic_entropy_tuning:
                self.eval_statistics['Alpha'] = alpha.item()
                self.eval_statistics['Alpha Loss'] = alpha_loss.item()

            if self.with_lagrange:
                self.eval_statistics['Alpha_prime'] = alpha_prime.item()
                self.eval_statistics['min_q1_loss'] = ptu.get_numpy(
                    min_qf1_loss).mean()
                self.eval_statistics['min_q2_loss'] = ptu.get_numpy(
                    min_qf2_loss).mean()
                self.eval_statistics[
                    'threshold action gap'] = self.target_action_gap
                self.eval_statistics[
                    'alpha prime loss'] = alpha_prime_loss.item()

        self._n_train_steps_total += 1
Ejemplo n.º 10
0
    def cpc_loss(self, gru_input_feats, gru_output_feats, feats_len):
        """Compute the contrastive predictive coding (CPC) loss.

        The loss runs over every offset ``k`` in
        ``range(self.k_step, self.k, self.k_step)`` and every sequence ``b``.
        The encoder feature ``z[t + k]`` is scored against the context
        ``c[t]`` with ``self.W[k](z, c)``. It is contrasted with
        ``min(self.N, seq_len - 1)`` negative encoder features, drawn at random
        from positions near ``t + k``.

        Args:
            gru_input_feats: encoder features, flattened to (B, T, -1).
            gru_output_feats: context features, flattened to (B, T, -1).
            feats_len: 1-D tensor holding the valid length of each sequence.

        Returns:
            ``(tot_loss, details)``.

            ``tot_loss`` depends on ``self.reduction``. With ``"sum"`` it is
            the sum over ``k`` of the per-example mean loss for that ``k``;
            otherwise it is the mean over all examples of all ``k``.

            ``details`` always holds ``"loss"``. It also holds
            ``"cer_k<k+1>"`` when ``self.cpc_compute_kcer``: the percentage of
            positives not scored above their own negatives. It holds
            ``"loss_k<k+1>"`` when ``self.loss_details``: the per-example mean
            loss for that ``k``.
        """
        zt_feats = gru_input_feats.contiguous().view(
            gru_input_feats.size(0), gru_input_feats.size(1), -1)
        ct = gru_output_feats.contiguous().view(
            gru_output_feats.size(0), gru_output_feats.size(1), -1)

        loss_per_k = {}  # k -> (summed loss, number of positives)
        err_per_k = {}  # k -> (number of errors, number of positives)

        # change from BxTx(FxC) to TxBxF
        zt_feats = zt_feats.permute(1, 0, 2)
        ct = ct.permute(1, 0, 2)

        for b in range(feats_len.size(0)):
            seq_len = feats_len[b].item()

            # Row j of mat_k holds the time indices shifted by j. Row 0
            # indexes the contexts; row k indexes the targets k steps ahead:
            # ct_i     (0 1 2 3 4 5 6)
            # zt_i k=1 (1 2 3 4 5 6 7)
            # zt_i k=2 (2 3 4 5 6 7 8)
            mat_k = np.arange(self.k + 1)[:, np.newaxis] + np.arange(0, seq_len)

            noise = min(self.N, seq_len - 1)
            # Context index of each negative, e.g. noise=3 -> (0 0 0 1 1 1 ...)
            noise_c_ind = np.arange(seq_len * noise) // noise

            for k in range(self.k_step, self.k, self.k_step):
                # Targets z[t + k] still inside the sequence, e.g.
                # k=1 -> (1 2 ... seq_len-1).
                z_ind = mat_k[k][mat_k[k] < seq_len]
                in_feats_z = zt_feats[z_ind, b]
                n_pos = in_feats_z.size(0)
                # Matching contexts c[t]: the first n_pos entries of row 0.
                in_feats_c = ct[mat_k[0][:n_pos], b]

                # Integer dtype: this array indexes zt_feats below, so keep it
                # explicitly integral instead of relying on a float->long cast.
                noise_ind = np.zeros((n_pos, noise), dtype=np.int64)
                for i, z in enumerate(z_ind):
                    # All other positions of the sequence, in random order.
                    perm = np.random.permutation(seq_len)
                    others = perm[perm != z]
                    near = others[np.abs(others - z) < self.n_around]
                    if near.shape[0] < noise:
                        # Widen the window to radius noise + 1. Since
                        # noise <= seq_len - 1, it then holds >= noise items.
                        near = others[np.abs(others - z) < noise + 1]
                    noise_ind[i] = near[:noise]

                # Negatives: encoder features at the sampled positions, each
                # paired with the context of its positive.
                noise_feats_z = zt_feats[noise_ind.reshape(n_pos * noise), b]
                noise_feats_c = ct[noise_c_ind[:n_pos * noise], b]

                fxt_all = self.W[k](
                    torch.cat((in_feats_z, noise_feats_z), 0),
                    torch.cat((in_feats_c, noise_feats_c), 0),
                )
                f_x_t_k = fxt_all[:n_pos]

                # Each positive score is normalised in log space over dim 0 of
                # fxt_all, i.e. over all positives and negatives of this (b, k).
                loss = -(f_x_t_k - torch.logsumexp(fxt_all, dim=0)).sum()

                n_err = 0
                if self.cpc_compute_kcer:
                    # Rank each positive against its own negatives only.
                    for pred in range(n_pos):
                        offset = pred * noise
                        f_x_t_n = F.bilinear(
                            torch.cat((in_feats_z[pred].unsqueeze(0),
                                       noise_feats_z[offset:offset + noise]),
                                      0),
                            torch.cat((in_feats_c[pred].unsqueeze(0),
                                       noise_feats_c[offset:offset + noise]),
                                      0),
                            self.W[k].weight,
                            self.W[k].bias,
                        )
                        # Error when the positive (row 0) is not top-scored.
                        # The argmax of the raw scores gives the same ranking
                        # as the softmax, but avoids exp() overflowing to NaN.
                        n_err += f_x_t_n.argmax() != 0

                prev_loss, prev_count = loss_per_k.get(k, (0, 0))
                loss_per_k[k] = (prev_loss + loss, prev_count + n_pos)
                if self.cpc_compute_kcer:
                    prev_err, _ = err_per_k.get(k, (0, 0))
                    err_per_k[k] = (prev_err + n_err, loss_per_k[k][1])

        if self.reduction == "sum":
            tot_loss = sum(loss_k / count
                           for loss_k, count in loss_per_k.values())
        else:
            tot_loss = (sum(loss_k for loss_k, _ in loss_per_k.values()) /
                        sum(count for _, count in loss_per_k.values()))

        details = {"loss": tot_loss}
        for k, (loss_k, count) in loss_per_k.items():
            if self.cpc_compute_kcer:
                n_err, n_ex = err_per_k[k]
                details["cer_k" + str(k + 1)] = (
                    torch.as_tensor(n_err / n_ex) * 100)
            if self.loss_details:
                details["loss_k" + str(k + 1)] = loss_k / count

        if abs(tot_loss.item()) == float("inf"):
            print("Inf loss !!")

        return tot_loss, details
Ejemplo n.º 11
0
    def forward(self, scene: torch.Tensor):
        """Predict per-agent Gaussian-mixture distributions for the next 12 steps.

        Position, velocity and acceleration histories are encoded by the three
        ``node_hist_encoder_*`` modules and their outputs summed; the pairwise
        position offsets between agents at the last observed step are encoded
        by ``self.edge_encoder``. A GRU cell then rolls out 12 steps. Each step
        produces a ``MixtureSameFamily`` of 2-D Gaussians, and a sample from it
        is fed back (through ``self.action``) as the next input.

        :param scene: tensor of shape (num_peds, history_size, data_dim);
            columns 0:2 are used as positions, 2:4 as velocities and 4:6 as
            accelerations.
        :return: list of 12 ``torch.distributions.MixtureSameFamily`` objects,
            one per predicted timestep, each with batch shape (num_peds,).
        """
        bs = scene.shape[0]
        poses = scene[:, :, :2]
        vel = scene[:, :, 2:4]
        acc = scene[:, :, 4:6]

        # Each encoder returns (outputs, hidden); only the outputs are used and
        # they are summed, so the encoders must share an output width.
        lstm_out_acc, _ = self.node_hist_encoder_acc(acc)
        lstm_out_vel, _ = self.node_hist_encoder_vel(vel)
        lstm_out_poses, _ = self.node_hist_encoder_poses(poses)
        lstm_out = lstm_out_vel + lstm_out_poses + lstm_out_acc

        current_pose = scene[:, -1, :2]  # (num_peds, 2)
        current_state = poses[:, -1, :]
        # `num_peds` rather than `np`, so the numpy module name is not shadowed.
        num_peds, data_dim = current_pose.shape
        # deltas[i, j] = current_pose[j] - current_pose[i]
        stacked = current_pose.flatten().repeat(num_peds).reshape(
            num_peds, num_peds * data_dim)
        deltas = (stacked - current_pose.repeat(1, num_peds)).reshape(
            num_peds, num_peds, data_dim)

        edge_out, _ = self.edge_encoder(deltas)
        catted = torch.cat((lstm_out[:, -1:, :], edge_out[:, -1:, :]), dim=1)
        context = catted.reshape(bs, -1)
        # NOTE(review): F.dropout defaults to training=True, so dropout is also
        # applied in eval mode. Pass training=self.training unless this is
        # deliberate (e.g. MC-dropout sampling).
        a_0 = F.dropout(self.action(current_state.reshape(bs, -1)),
                        self.dropout_p)
        state = F.dropout(self.state(context), self.dropout_p)

        current_state = current_state.unsqueeze(1)  # (bs, 1, 2)
        gauses = []
        inp = F.dropout(torch.cat((context, a_0), dim=-1), self.dropout_p)
        for _ in range(12):
            h_state = self.gru(inp.reshape(bs, -1), state)

            # Correlations are not used: the covariances built below are diagonal.
            log_pis, step_deltas, log_sigmas, _corrs = self.project_to_GMM_params(
                h_state)
            step_deltas = torch.clamp(step_deltas, max=1.5, min=-1.5)

            log_pis = log_pis.reshape(bs, -1)
            # Normalise to log mixture weights.
            log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
            step_deltas = step_deltas.reshape(bs, -1, 2)
            log_sigmas = log_sigmas.reshape(bs, -1, 2)

            # Component means are accumulated offsets from the last position.
            mus = step_deltas + current_state
            current_state = mus
            variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2,
                                   max=1e3)
            # (bs, K, 1, 2) * I_2 -> diagonal covariance matrices (bs, K, 2, 2).
            m_diag = variance * torch.eye(2).to(variance.device)

            # log_pis are log-weights, so they must go in as `logits`:
            # Categorical's first positional argument is `probs`.
            mix = D.Categorical(logits=log_pis)
            comp = D.MultivariateNormal(mus, m_diag)
            gmm = D.MixtureSameFamily(mix, comp)
            gauses.append(gmm)
            # MixtureSameFamily only supports non-reparameterised sampling,
            # so no gradient flows through a_t.
            a_t = gmm.sample()
            a_tt = F.dropout(self.action(a_t.reshape(bs, -1)), self.dropout_p)
            state = h_state
            inp = F.dropout(torch.cat((context, a_tt), dim=-1), self.dropout_p)

        return gauses
Ejemplo n.º 12
0
    def forward(self, scene: torch.Tensor):
        """Predict per-agent Gaussian-mixture distributions for the next 12 steps.

        Velocity and acceleration histories are encoded and summed. The
        position history, together with the flattened pairwise offsets
        between agents at every observed step, is fed step by step to
        ``self.att``. A GRU cell then rolls out 12 steps. Each step produces a
        ``MixtureSameFamily`` of 2-D Gaussians, and a sample from it is fed
        back as the next input.

        :param scene: tensor of shape (num_peds, history_size, data_dim);
            columns 0:2 are positions, 2:4 velocities and 4:6 accelerations.
        :return: list of 12 ``torch.distributions.MixtureSameFamily`` objects,
            one per predicted timestep, each with batch shape (num_peds,).
        """
        bs = scene.shape[0]
        poses = scene[:, :, :2]
        vel = scene[:, :, 2:4]
        acc = scene[:, :, 4:6]

        # Each encoder returns (outputs, hidden); only the outputs are used.
        lstm_poses_out, _ = self.node_hist_encoder_poses(poses)
        lstm_out_acc, _ = self.node_hist_encoder_acc(acc)
        lstm_out_vel, _ = self.node_hist_encoder_vel(vel)
        lstm_out = lstm_out_vel + lstm_out_acc

        current_state = poses[:, -1, :]
        bs, seq, data_dim = poses.shape
        # deltas[i, t, j] = poses[j, t] - poses[i, t]
        poses_by_step = poses.permute(1, 0, 2)  # (seq, bs, data_dim)
        stacked = poses_by_step.reshape(seq, -1).repeat(1, bs).reshape(
            seq, bs, bs * data_dim)
        deltas = stacked - poses_by_step.repeat(1, 1, bs)
        deltas = deltas.permute(1, 0, 2).reshape(bs, seq, bs, data_dim)
        # Stay on the input's device instead of forcing the default CUDA device.
        deltas_flat = deltas.reshape(bs, seq, -1)

        # self.att receives a fixed-width offset vector per step: truncate or
        # zero-pad the flattened offsets to `max_size` features.
        max_size = 50  # TODO: fix
        if deltas_flat.shape[2] >= max_size:
            prep_for_deltas = deltas_flat[:, :, :max_size]
        else:
            prep_for_deltas = torch.zeros(bs, seq, max_size,
                                          device=scene.device)
            prep_for_deltas[:, :, :deltas_flat.shape[2]] = deltas_flat

        at_hidden = self.att.init_hidden(bs=bs)
        # Feed every observed history step to the attention module.
        for i in range(seq):
            at_output, at_hidden, _ = self.att(
                at_hidden, lstm_poses_out[:, i:i + 1, :],
                prep_for_deltas[:, i:i + 1, :])

        catted = torch.cat((lstm_out[:, -1:, :], at_output[:, -1:, :]), dim=2)
        context = catted.reshape(bs, -1)
        # NOTE(review): F.dropout defaults to training=True, so dropout is also
        # applied in eval mode. Pass training=self.training unless deliberate.
        a_0 = F.dropout(self.action(current_state.reshape(bs, -1)),
                        self.dropout_p)
        state = F.dropout(self.state(context), self.dropout_p)

        current_state = current_state.unsqueeze(1)  # (bs, 1, 2)
        gauses = []
        inp = F.dropout(torch.cat((context, a_0), dim=-1), self.dropout_p)
        for _ in range(12):
            h_state = self.gru(inp.reshape(bs, -1), state)

            # Correlations are not used: the covariances built below are diagonal.
            log_pis, step_deltas, log_sigmas, _corrs = self.project_to_GMM_params(
                h_state)
            step_deltas = torch.clamp(step_deltas, max=1.5, min=-1.5)

            log_pis = log_pis.reshape(bs, -1)
            # Normalise to log mixture weights.
            log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True)
            step_deltas = step_deltas.reshape(bs, -1, 2)
            log_sigmas = log_sigmas.reshape(bs, -1, 2)

            # Component means are accumulated offsets from the last position.
            mus = step_deltas + current_state
            current_state = mus
            variance = torch.clamp(torch.exp(log_sigmas).unsqueeze(2)**2,
                                   max=1e3)
            # (bs, K, 1, 2) * I_2 -> diagonal covariance matrices (bs, K, 2, 2).
            m_diag = variance * torch.eye(2).to(variance.device)

            # log_pis are log-weights, so they must go in as `logits`:
            # Categorical's first positional argument is `probs`.
            mix = D.Categorical(logits=log_pis)
            comp = D.MultivariateNormal(mus, m_diag)
            gmm = D.MixtureSameFamily(mix, comp)
            gauses.append(gmm)
            # Non-reparameterised sample: no gradient flows through a_t.
            a_t = gmm.sample()
            state = h_state
            # NOTE(review): the raw 2-D sample is fed back (not self.action(a_t)
            # as for a_0), so the GRU input width only matches when
            # self.action outputs 2 features. Confirm this is intended.
            inp = F.dropout(torch.cat((context, a_t), dim=-1), self.dropout_p)
        return gauses
Ejemplo n.º 13
0
def log_matrix_product(A, B):
    """Multiply two matrices whose entries are stored as logarithms.

    Equivalent to ``log(exp(A) @ exp(B))``, but evaluated with logsumexp so it
    stays finite where the exponentials would under- or overflow. Useful for
    chaining Jacobian terms (vector chain rule) in log space.

    Args:
        A: log-space tensor of shape (..., n, k).
        B: log-space tensor of shape (..., k, m); leading dims broadcast.

    Returns:
        Log-space tensor of shape (..., n, m).
    """
    lhs = A.unsqueeze(-1)  # (..., n, k, 1)
    rhs = B.unsqueeze(-3)  # (..., 1, k, m)
    # Sum over the shared k axis in log space.
    return (lhs + rhs).logsumexp(dim=-2)
Ejemplo n.º 14
0
def ensemble_test_stats(model,
                        savefile_list,
                        dset,
                        data_dir,
                        corruption=None,
                        rotation=None,
                        batch_size=256,
                        cuda=True,
                        gpu=None,
                        MC_samples=0,
                        workers=4,
                        iterate=False):
    """Score an ensemble on a clean, corrupted or rotated validation set.

    At most one of ``corruption`` / ``rotation`` may be given. With neither,
    the validation split returned by ``get_image_loader`` is used.

    Args:
        model: passed, together with ``savefile_list``, to
            ``ensemble_get_preds_targets``.
        savefile_list: checkpoints that make up the ensemble.
        dset: dataset name given to the loader helpers.
        data_dir: dataset root directory.
        corruption: severity for ``load_corrupted_dataset``.
        rotation: rotation for ``rotate_load_dataset``.
        batch_size, cuda, workers: data-loader settings.
        gpu: forwarded to ``ensemble_get_preds_targets``.
        MC_samples: not used by this function.
        iterate: if True, score growing prefixes of the ensemble.

    Returns:
        ``(err, ll, brier, ece)``. With ``iterate=True`` each entry is a list
        whose i-th element scores the first ``i + 1`` members, combined as the
        log of their mean predicted probability (assumes the predictions come
        back as (examples, members, classes); confirm).
    """
    assert not (corruption is not None and rotation is not None)

    if corruption is not None:
        val_loader = load_corrupted_dataset(dset,
                                            severity=corruption,
                                            data_dir=data_dir,
                                            batch_size=batch_size,
                                            cuda=cuda,
                                            workers=workers)
    elif rotation is not None:
        val_loader = rotate_load_dataset(dset,
                                         rotation,
                                         data_dir=data_dir,
                                         batch_size=batch_size,
                                         cuda=cuda,
                                         workers=workers)
    else:
        _, _, val_loader, _, _, _ = get_image_loader(
            dset, batch_size, cuda=cuda, workers=workers,
            distributed=False, data_dir=data_dir)

    logprob_vec, target_vec = ensemble_get_preds_targets(
        model, savefile_list, val_loader,
        cuda=cuda, gpu=gpu, return_vector=iterate)

    def _metrics(log_probs):
        # Evaluated in the order Brier, error, log-likelihood, ECE;
        # returned as (err, ll, brier, ece).
        brier = class_brier(y=target_vec, log_probs=log_probs, probs=None)
        err = class_err(y=target_vec, model_out=log_probs)
        ll = class_ll(y=target_vec, log_probs=log_probs, probs=None,
                      eps=1e-40)
        ece = class_ECE(y=target_vec, log_probs=log_probs, probs=None,
                        nbins=10)
        return err, ll, brier, ece

    if not iterate:
        return _metrics(logprob_vec)

    err_vec, ll_vec, brier_vec, ece_vec = [], [], [], []
    for n_samples in range(1, logprob_vec.shape[1] + 1):
        # log((1/n) * sum_i p_i): average the first n members in probability space.
        comb_logprobs = torch.logsumexp(logprob_vec[:, :n_samples, :],
                                        dim=1,
                                        keepdim=False) - np.log(n_samples)
        for bucket, value in zip((err_vec, ll_vec, brier_vec, ece_vec),
                                 _metrics(comb_logprobs)):
            bucket.append(value)
    return err_vec, ll_vec, brier_vec, ece_vec
Ejemplo n.º 15
0
def cal_lstm_crf_loss(crf_scores, targets, tag2id):
    """Compute the negative log-likelihood loss of a BiLSTM-CRF model.

    For the derivation of this loss see https://arxiv.org/pdf/1603.01360.pdf

    Args:
        crf_scores: scores of shape [B, L, T, T]. The forward recursion below
            reads crf_scores[b, t, i, j] as the score of going from tag i at
            step t-1 to tag j at step t; row ``start_id`` is used at t = 0.
        targets: gold tag ids of shape [B, L], padded with ``tag2id['<pad>']``.
        tag2id: tag-to-id mapping, expected to contain '<pad>', '<start>' and
            '<end>'.

    Returns:
        Scalar tensor: (sum of log partition functions - sum of gold path
        scores) divided by the batch size.
    """
    pad_id = tag2id.get('<pad>')
    start_id = tag2id.get('<start>')
    end_id = tag2id.get('<end>')

    device = crf_scores.device

    # targets:[B, L] crf_scores:[B, L, T, T]
    batch_size, max_len = targets.size()
    target_size = len(tag2id)

    # mask = 1 - ((targets == pad_id) + (targets == end_id))  # [B, L]
    mask = (targets != pad_id)
    lengths = mask.sum(dim=1)
    # NOTE(review): `indexed` is defined elsewhere. The gather below treats its
    # output as flat indices into the T*T (prev_tag, cur_tag) grid, presumably
    # prev_tag * T + cur_tag with `start_id` before step 0. Confirm.
    targets = indexed(targets, target_size, start_id)

    # Gold-path score, method 1: select each real position's
    # (prev_tag, cur_tag) score from the flattened T*T grid and sum them.
    targets = targets.masked_select(mask)  # [real_L]

    flatten_scores = crf_scores.masked_select(
        mask.view(batch_size, max_len, 1, 1).expand_as(crf_scores)).view(
            -1, target_size * target_size).contiguous()

    golden_scores = flatten_scores.gather(dim=1,
                                          index=targets.unsqueeze(1)).sum()

    # Gold-path score, method 2 (unused alternative): via pack_padded_sequence
    # targets[targets == end_id] = pad_id
    # scores_at_targets = torch.gather(
    #     crf_scores.view(batch_size, max_len, -1), 2, targets.unsqueeze(2)).squeeze(2)
    # scores_at_targets, _ = pack_padded_sequence(
    #     scores_at_targets, lengths-1, batch_first=True
    # )
    # golden_scores = scores_at_targets.sum()

    # Scores of all paths (log partition function) via the forward algorithm.
    # scores_upto_t[i, j] is the log-sum-exp of the scores of all partial paths
    # up to step t in which word t of sentence i is labelled with tag j.
    scores_upto_t = torch.zeros(batch_size, target_size).to(device)
    for t in range(max_len):
        # Number of sequences still active at step t (shorter ones have ended).
        # NOTE(review): slicing [:batch_size_t] assumes the batch is sorted by
        # length in descending order.
        batch_size_t = (lengths > t).sum().item()
        if t == 0:
            scores_upto_t[:batch_size_t] = crf_scores[:batch_size_t, t,
                                                      start_id, :]
        else:
            # We add scores at current timestep to scores accumulated up to previous
            # timestep, and log-sum-exp Remember, the cur_tag of the previous
            # timestep is the prev_tag of this timestep
            # So, broadcast prev. timestep's cur_tag scores
            # along cur. timestep's cur_tag dimension
            scores_upto_t[:batch_size_t] = torch.logsumexp(
                crf_scores[:batch_size_t, t, :, :] +
                scores_upto_t[:batch_size_t].unsqueeze(2),
                dim=1)
    # Every complete path is required to end in the <end> tag.
    all_path_scores = scores_upto_t[:, end_id].sum()

    # Author's observation: after about two epochs of training the loss turns
    # negative, even though mathematically loss = -log P.
    loss = (all_path_scores - golden_scores) / batch_size
    return loss
def visualize_inference(model,
                        inputs,
                        results,
                        savedir,
                        name,
                        cfg,
                        energy_threshold=None):
    """Draw up to 20 predicted boxes on an image and save it as a JPEG.

    Debug helper for inspecting inference output. Labels, scores, the raw
    ``inter_feat`` rows and per-box energies (logsumexp over every column of
    ``inter_feat`` except the last) are printed. When ``energy_threshold`` is
    given, every box whose energy is below it is relabelled as class 8.

    Args:
        model: not used; kept for call-site compatibility.
        inputs: handed to ``MyVisualizer`` as the image to draw on.
        results (dict): its ``'instances'`` entry must expose ``pred_boxes``,
            ``det_labels``, ``scores`` and ``inter_feat``.
        savedir (str): directory the image is written to.
        name (str): file stem; output is ``savedir + '/' + name + '.jpg'``.
        cfg: config whose ``DATASETS.TRAIN[0]`` selects the metadata.
        energy_threshold (float, optional): relabel threshold. ``None`` (the
            default) disables relabelling.

    Returns:
        None. Returns early without writing anything when there are no boxes
        or no positive score.
    """
    import cv2
    from detectron2.data import MetadataCatalog
    from src.engine.myvisualizer import MyVisualizer
    max_boxes = 20

    results = results['instances']
    predicted_boxes = results.pred_boxes.tensor.cpu().numpy()

    v_pred = MyVisualizer(inputs, MetadataCatalog.get(cfg.DATASETS.TRAIN[0]))
    # NOTE(review): if det_labels is a tensor this slice is a view, so the
    # relabelling below also modifies results.det_labels in place.
    labels = results.det_labels[0:max_boxes]
    scores = results.scores[0:max_boxes]
    print(labels)
    print(scores)

    inter_feat = results.inter_feat[0:max_boxes]
    print(inter_feat)
    # Per-box energy: logsumexp over all inter_feat columns but the last.
    energies = torch.logsumexp(inter_feat[:, :-1], dim=1).detach().cpu().numpy()
    print(energies)
    # Compare only when a threshold is given; `array < None` raises TypeError.
    if energy_threshold is not None:
        low_energy_idx = np.argwhere(energies < energy_threshold).reshape(-1)
        print(low_energy_idx)
        labels[low_energy_idx] = 8
    print(labels)

    if len(scores) == 0 or max(scores) <= 0.0:
        return

    v_pred = v_pred.overlay_covariance_instances(
        labels=labels,
        scores=scores,
        boxes=predicted_boxes[0:max_boxes],
        covariance_matrices=None,
        score_threshold=0.0)

    prop_img = v_pred.get_image()
    # No HighGUI window is opened here, so cv2.waitKey() is not called: it
    # needs an open window and is unavailable in headless OpenCV builds.
    cv2.imwrite(savedir + '/' + name + '.jpg', prop_img)
Ejemplo n.º 17
0
 def log_prob(self, value):
     """Log-density of the mixture at ``value``.

     Computes ``log sum_k pi_k p_k(value)`` with logsumexp, where ``p_k`` are
     ``self.models`` and ``pi_k`` are ``self.categorical.probs``.

     Args:
         value: point(s) at which every component is evaluated.

     Returns:
         Tensor with the shape of a single component's ``log_prob(value)``.
     """
     # (K, *component_shape): one slice of log-densities per component.
     log_probs = torch.stack([sub_model.log_prob(value)
                              for sub_model in self.models])
     # Shape the K log-weights so they broadcast along the leading component
     # axis whatever the rank of each component's result. A fixed
     # view(-1, 1) mis-broadcasts for scalar or multi-dimensional results.
     log_weights = self.categorical.probs.log()
     log_weights = log_weights.view(-1, *([1] * (log_probs.dim() - 1)))
     return torch.logsumexp(log_probs + log_weights, dim=0)
Ejemplo n.º 18
0
    def train(self):
        """Run one CQL (conservative Q-learning) update on a replay-buffer batch.

        Steps, in order:
          1. Update both critics with MSE to ``self.calc_target_q`` plus
             ``self.cql_weight`` times a conservative penalty.
          2. Update the actor with the SAC objective. While
             ``total_steps < policy_eval_start``, use a log-likelihood
             objective on the dataset actions instead.
          3. Update the entropy temperature ``log_alpha``.
          4. Every ``target_update_interval`` calls, soft-update the target
             critics.

        Returns early, without counting a step, while the replay buffer holds
        fewer than ``batch_size`` transitions.
        """
        if len(self.replay_buffer) < self.batch_size:
            return

        self.train_cnt += 1
        self.total_steps = self.train_cnt

        transitions = self.replay_buffer.sample(self.batch_size)
        # NOTE(review): each sampled item appears to be a sequence whose first
        # element is the transition record. Confirm against replay_buffer.sample.
        map_func = lambda x: x[0]
        batch = Transition(*zip(*map(map_func, transitions)))

        state_batch = torch.tensor(
            np.array(batch.state, dtype=np.float32), device=self.device)
        action_batch = torch.tensor(
            np.array(batch.action, dtype=np.float32), device=self.device)
        next_state_batch = torch.tensor(
            np.array(batch.next_state, dtype=np.float32), device=self.device)
        reward_batch = torch.tensor(
            np.array(batch.reward, dtype=np.float32), device=self.device).unsqueeze(1)
        valid_batch = torch.tensor(
            np.array(batch.valid, dtype=np.float32), device=self.device).unsqueeze(1)

        # --- Critic update: Bellman MSE ---
        target_q = self.calc_target_q(
            state_batch, action_batch, reward_batch, next_state_batch, valid_batch)
        q1_data = self.critic1(state_batch, action_batch)
        q2_data = self.critic2(state_batch, action_batch)
        q1_mse_loss = F.mse_loss(q1_data, target_q)
        q2_mse_loss = F.mse_loss(q2_data, target_q)

        # --- Conservative penalty: num_random sampled actions per state from
        # three sources (uniform in [-1, 1], policy at s, policy at s') ---
        num_random = 10
        random_actions = torch.FloatTensor(q2_data.shape[0] * num_random, action_batch.shape[-1]).uniform_(-1, 1).to(self.device)
        current_states = state_batch.unsqueeze(1).repeat(1, num_random, 1).view(-1, state_batch.shape[-1])
        current_actions, current_logpis, _ = self.try_act(current_states)
        current_actions, current_logpis = current_actions.detach(), current_logpis.detach()
        next_states = next_state_batch.unsqueeze(1).repeat(1, num_random, 1).view(-1, state_batch.shape[-1])
        next_actions, next_logpis, _ = self.try_act(next_states)
        next_actions, next_logpis = next_actions.detach(), next_logpis.detach()

        q1_rand = self.critic1(current_states, random_actions).view(-1, num_random)
        q2_rand = self.critic2(current_states, random_actions).view(-1, num_random)
        q1_curr = self.critic1(current_states, current_actions).view(-1, num_random)
        q2_curr = self.critic2(current_states, current_actions).view(-1, num_random)
        # NOTE(review): Q is evaluated at the *current* states for actions
        # drawn from the policy at the next states. The public rlkit CQL
        # implementation appears to do the same; confirm it is intended.
        q1_next = self.critic1(current_states, next_actions).view(-1, num_random)
        q2_next = self.critic2(current_states, next_actions).view(-1, num_random)
        current_logpis = current_logpis.view(-1, num_random)
        next_logpis = next_logpis.view(-1, num_random)

        # log-density of the uniform distribution on [-1, 1]^action_dim.
        random_density = np.log(0.5 ** current_actions.shape[-1])
        # Importance-weight each sampled Q-value by subtracting its sampling
        # log-density (the log-probs were already detached above).
        cat_q1 = torch.cat([q1_rand - random_density,
                            q1_curr - current_logpis.detach(),
                            q1_next - next_logpis.detach()], 1)
        cat_q2 = torch.cat([q2_rand - random_density,
                            q2_curr - current_logpis.detach(),
                            q2_next - next_logpis.detach()], 1)

        # Push down a soft maximum over sampled actions, push up dataset actions.
        cql_loss1 = (torch.logsumexp(cat_q1, dim=1, keepdim=True) - q1_data).mean()
        cql_loss2 = (torch.logsumexp(cat_q2, dim=1, keepdim=True) - q2_data).mean()

        q1_loss = q1_mse_loss + self.cql_weight * cql_loss1
        q2_loss = q2_mse_loss + self.cql_weight * cql_loss2
        q_loss = q1_loss + q2_loss

        self.q1_optim.zero_grad()
        self.q2_optim.zero_grad()
        q_loss.backward()
        self.q1_optim.step()
        self.q2_optim.step()

        # --- Actor update ---
        pi, log_pi, _ = self.try_act(state_batch)

        qf1_pi = self.critic1(state_batch, pi)
        qf2_pi = self.critic2(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # SAC policy objective. policy_loss.backward() also writes gradients
        # into the critics' parameters; the critics' zero_grad() calls above
        # clear them on the next call before the critic step.
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        if self.total_steps < self.policy_eval_start:
            # Log-likelihood of the dataset actions under the tanh-squashed
            # Gaussian policy (change of variables through tanh).
            # NOTE(review): atanh is infinite at exactly +/-1; this assumes
            # dataset actions lie strictly inside (-1, 1).
            raw_action = atanh(action_batch)
            mean, log_std = self.actor(state_batch)
            normal = Normal(mean, log_std.exp())
            eps = 1e-6
            logprob = (normal.log_prob(raw_action)
                      - torch.log(1 - action_batch.pow(2) + eps))
            logprob = logprob.sum(1, keepdims=True)

            policy_loss = ((self.alpha * log_pi) - logprob).mean()

        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()

        # --- Temperature (alpha) update ---
        alpha_loss = -(self.log_alpha *
                       (self.target_entropy + log_pi).detach()).mean()

        # Also discards any gradient policy_loss left on log_alpha
        # (self.alpha is not detached from log_alpha).
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.alpha = self.log_alpha.exp()

        # --- Target networks ---
        if self.train_cnt % self.target_update_interval == 0:
            self.soft_update(self.target_critic1, self.critic1)
            self.soft_update(self.target_critic2, self.critic2)
            self.prev_target_update_time = self.total_steps
Ejemplo n.º 19
0
    def forward_decoder(
        self,
        tokens,
        encoder_outs: List[EncoderOut],
        incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]],
        temperature: float = 1.0,
    ):
        """Run one decoding step of every ensemble member and combine them.

        Args:
            tokens: decoder input tokens passed to each model's decoder.
            encoder_outs: one encoder output per model, used only when
                ``self.has_encoder()``.
            incremental_states: one incremental state per model, used only
                when ``self.has_incremental_states()``.
            temperature: the last step's decoder output is divided by this
                before normalisation.

        Returns:
            ``(log_probs, attn)`` for the last position. With one model these
            are that model's values. Otherwise the log-probabilities are
            averaged in probability space (``logsumexp - log(models_size)``)
            and attention is averaged arithmetically. ``attn`` may be ``None``.
        """
        log_probs = []
        avg_attn: Optional[Tensor] = None
        encoder_out: Optional[EncoderOut] = None
        for i, model in enumerate(self.models):
            if self.has_encoder():
                encoder_out = encoder_outs[i]
            # decode each model
            if self.has_incremental_states():
                decoder_out = model.decoder.forward(
                    tokens,
                    encoder_out=encoder_out,
                    incremental_state=incremental_states[i],
                )
            else:
                decoder_out = model.decoder.forward(tokens,
                                                    encoder_out=encoder_out)

            # Attention may come back as a Tensor, or in a dict under "attn"
            # (a Tensor, or a sequence whose first item is used). Only the
            # last target position is kept.
            attn: Optional[Tensor] = None
            decoder_len = len(decoder_out)
            if decoder_len > 1 and decoder_out[1] is not None:
                if isinstance(decoder_out[1], Tensor):
                    attn = decoder_out[1]
                else:
                    attn_holder = decoder_out[1]["attn"]
                    if isinstance(attn_holder, Tensor):
                        attn = attn_holder
                    elif attn_holder is not None:
                        attn = attn_holder[0]
                if attn is not None:
                    attn = attn[:, -1, :]

            # div_ is in place: the last-step slice of decoder_out[0] itself
            # is rescaled by the temperature.
            decoder_out_tuple = (
                decoder_out[0][:, -1:, :].div_(temperature),
                None if decoder_len <= 1 else decoder_out[1],
            )

            probs = model.get_normalized_probs(decoder_out_tuple,
                                               log_probs=True,
                                               sample=None)
            probs = probs[:, -1, :]
            # A single model needs no averaging.
            if self.models_size == 1:
                return probs, attn

            log_probs.append(probs)
            if attn is not None:
                if avg_attn is None:
                    avg_attn = attn
                else:
                    # Accumulates in place into the first model's attention tensor.
                    avg_attn.add_(attn)
        # log of the mean probability across models.
        avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0),
                                    dim=0) - math.log(self.models_size)
        if avg_attn is not None:
            avg_attn.div_(self.models_size)
        return avg_probs, avg_attn
Ejemplo n.º 20
0
    def _estimate_latent_entropies(self,
                                   samples_zCx,
                                   params_zCX,
                                   n_samples=10000):
        r"""Estimate :math:`H(z_j) = E_{q(z_j)} [-log q(z_j)] = E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]`
        using the empirical distribution of :math:`p(x)`.

        Note
        ----
        - the expectation over the empirical distribution is: :math:`q(z) = 1/N sum_{n=1}^N q(z|x_n)`.
        - we assume that q(z|x) is factorial i.e. :math:`q(z|x) = \prod_j q(z_j|x)`.
        - computes numerically stable NLL: :math:`- log q(z) = log N - logsumexp_n=1^N log q(z|x_n)`.

        Parameters
        ----------
        samples_zCx: torch.tensor
            Tensor of shape (len_dataset, latent_dim) containing a sample of
            q(z|x) for every x in the dataset.

        params_zCX: tuple of torch.Tensor
            Sufficient statistics q(z|x) for each training example. E.g. for
            gaussian (mean, log_var) each of shape : (len_dataset, latent_dim).

        n_samples: int, optional
            Number of samples to use to estimate the entropies. Capped at
            len_dataset, since samples are drawn without replacement.

        Return
        ------
        H_z: torch.Tensor
            Tensor of shape (latent_dim) containing the marginal entropies H(z_j)
        """
        len_dataset, latent_dim = samples_zCx.shape
        device = samples_zCx.device
        H_z = torch.zeros(latent_dim, device=device)

        # Samples are drawn without replacement, so at most len_dataset are
        # available; a larger value would make the expand below fail.
        n_samples = min(n_samples, len_dataset)

        # sample from p(x)
        samples_x = torch.randperm(len_dataset, device=device)[:n_samples]
        # sample from p(z|x): (n_samples, latent_dim) -> (latent_dim, n_samples).
        # This must be a transpose: a .view() to that shape would mix values
        # from different latent dimensions into the same row.
        samples_zCx = samples_zCx.index_select(0, samples_x).t()

        mini_batch_size = 10
        samples_zCx = samples_zCx.expand(len_dataset, latent_dim, n_samples)
        mean = params_zCX[0].unsqueeze(-1).expand(len_dataset, latent_dim,
                                                  n_samples)
        log_var = params_zCX[1].unsqueeze(-1).expand(len_dataset, latent_dim,
                                                     n_samples)
        log_N = math.log(len_dataset)
        # Show the progress bar only when progress bars are requested.
        with trange(n_samples, leave=False,
                    disable=not self.is_progress_bar) as t:
            for k in range(0, n_samples, mini_batch_size):
                # log q(z_j|x_n) for every datapoint n and every sample in this
                # mini-batch: (len_dataset, latent_dim, batch), assuming
                # log_density_gaussian is elementwise.
                idcs = slice(k, k + mini_batch_size)
                log_q_zCx = log_density_gaussian(samples_zCx[..., idcs],
                                                 mean[..., idcs],
                                                 log_var[..., idcs])
                # Numerically stable log q(z_j) for each sample:
                # log q(z_j) = -log N + logsumexp_{n=1}^N log q(z_j|x_n).
                # q(z) is unknown, so it is approximated by the Monte Carlo
                # average of q(z_j|x_n) over the whole dataset: fix one z and
                # look at the probability of every x generating it.
                log_q_z = -log_N + torch.logsumexp(
                    log_q_zCx, dim=0, keepdim=False)
                # H(z_j) = E_{z_j}[-log q(z_j)]: accumulate over the samples
                # (dimension 1, since dimension 0 was summed out above).
                H_z += (-log_q_z).sum(1)

                # The last mini-batch may be smaller than mini_batch_size.
                t.update(log_q_z.shape[1])

        H_z /= n_samples

        return H_z
Ejemplo n.º 21
0
 def _iwbo(self, x, k):
     """Importance-weighted bound on the negative log-likelihood, in bits/dim.

     ``x`` is replicated ``k`` times along its first axis and scored with
     ``self.log_prob``. For each example the ``k`` log-likelihoods are
     combined as ``logsumexp - log k`` (the log of their mean likelihood).
     The negated total is divided by ``x.numel() * log 2`` to convert nats to
     bits per dimension.

     Args:
         x: batch tensor whose first axis is the batch axis.
         k: number of importance samples per example.

     Returns:
         Scalar tensor: bits per dimension.
     """
     replicated = torch.cat([x] * k, dim=0)
     log_liks = self.log_prob(replicated)
     # Rows are laid out replica by replica, so splitting the leading axis
     # into (k, batch) lines up the k estimates of each example.
     log_liks = log_liks.reshape(k, -1, *log_liks.shape[1:])
     log_mean_lik = torch.logsumexp(log_liks, dim=0) - math.log(k)
     total_nats = -log_mean_lik.sum()
     return total_nats / (x.numel() * math.log(2))
Ejemplo n.º 22
0
    def forward(self, xs: torch.Tensor, **kwargs):
        """Route a batch through this decision node and mix its children's outputs.

        ``self.g`` gives, per example, the probability of taking the right
        subtree. When ``self._log_probabilities`` is set, ``self.g`` and the
        children are treated as working in log space. Before the children are
        evaluated, the probability of reaching each child is stored in the
        shared ``kwargs['attr']`` dict under ``(child, 'pa')``. The node
        returns ``(1 - p) * left + p * right``, or its log-space equivalent.

        Args:
            xs: input batch; its first axis is the batch axis.
            **kwargs: forwarded to ``self.g`` and to the children. ``attr`` is
                the node-attribute dict shared across the tree and is created
                if absent.

        Returns:
            ``(dists, node_attr)``: the mixed child distributions, shape
            (bs, k) as returned by the children, and the attribute dict.
        """
        batch_size = xs.size(0)

        node_attr = kwargs.setdefault('attr', dict())
        # A parent stores this node's arrival probability under (self, 'pa')
        # before calling forward. If it is missing, this node is the root and
        # arrival is certain: probability 1, i.e. log-probability 0.
        if not self._log_probabilities:
            root_pa = torch.ones(batch_size, device=xs.device)
        else:
            root_pa = torch.zeros(batch_size, device=xs.device)
        pa = node_attr.setdefault((self, 'pa'), root_pa)

        # (Log-)probability of taking the right subtree, shape (bs,).
        ps = self.g(xs, **kwargs)
        node_attr[self, 'ps'] = ps

        if not self._log_probabilities:
            node_attr[self.l, 'pa'] = (1 - ps) * pa
            node_attr[self.r, 'pa'] = ps * pa

            l_dists, _ = self.l.forward(xs, **kwargs)  # (bs, k)
            r_dists, _ = self.r.forward(xs, **kwargs)  # (bs, k)
            ps = ps.view(batch_size, 1)
            return (1 - ps) * l_dists + ps * r_dists, node_attr  # (bs, k)

        # log(1 - p) from log p, computed stably (log1mexp). Rewritten in
        # PyTorch from TensorFlow Probability:
        # https://github.com/tensorflow/probability/blob/v0.9.0/tensorflow_probability/python/math/generic.py#L447-L471
        x = torch.abs(ps) + 1e-7  # epsilon keeps log(1 - p) finite as p -> 1
        oneminusp = torch.where(x < np.log(2), torch.log(-torch.expm1(-x)),
                                torch.log1p(-torch.exp(-x)))

        node_attr[self.l, 'pa'] = oneminusp + pa
        node_attr[self.r, 'pa'] = ps + pa

        l_dists, _ = self.l.forward(xs, **kwargs)  # (bs, k)
        r_dists, _ = self.r.forward(xs, **kwargs)  # (bs, k)

        ps = ps.view(batch_size, 1)
        oneminusp = oneminusp.view(batch_size, 1)
        logs_stacked = torch.stack((oneminusp + l_dists, ps + r_dists))
        return torch.logsumexp(logs_stacked, dim=0), node_attr  # (bs, k)
Ejemplo n.º 23
0
 def _energy(self, x):
     energies = torch.stack([c.energy(x) for c in self._components], dim=-1)
     return -torch.logsumexp(-energies + self._log_weights.view(1, 1, -1),
                             dim=-1)
Ejemplo n.º 24
0
 def E(self, x):
     """Return the mixture energy ``-logsumexp_k(-E_k(x))`` for each row of ``x``.

     ``E_k`` is ``batch_mvn.E`` evaluated with the k-th mean ``self.mu[k]``
     and the k-th inverse factor ``self.inv_L[k]``.  The per-component
     negative energies are collected in a ``(x.size(0), self.K)`` buffer
     allocated on the same device as ``self.L``.
     """
     n_rows = x.size(0)
     neg_energy = torch.zeros((n_rows, self.K), device=self.L.device)
     for comp in range(self.K):
         mean = self.mu[comp].unsqueeze(0)
         neg_energy[:, comp] = -batch_mvn.E(x, mean, self.inv_L[comp])
     return torch.logsumexp(neg_energy, dim=1).neg()
Ejemplo n.º 25
0
    def forward(self, *inputs):
        """Run one round of smoothed max-product message passing on a 2-D grid.

        ``inputs`` is ``(horizontal_unary, vertical_unary, horizontal_pairwise,
        vertical_pairwise)``.  Judging by the indexing below, the unary tensors
        are laid out as ``[B, C, H, W]`` (labels on dim 1, rows on dim -2,
        columns on dim -1), ``horizontal_pairwise`` as ``[B, C, C, H, W - 1]``
        and ``vertical_pairwise`` as ``[B, C, C, H - 1, W]`` -- TODO confirm
        against the caller.

        The hard ``max`` of max-product is replaced by the soft maximum
        ``gamma * logsumexp(scores / gamma)`` with ``gamma = self.gamma``.
        Forward and backward passes are run along every row (horizontal
        chains) and every column (vertical chains); the two passes are summed
        and the doubly-counted unary is subtracted to obtain marginals.  Each
        unary term is then nudged towards the average of the two marginals.

        Returns:
            ``(new_horizontal_unary, new_vertical_unary, horizontal_marginals,
            vertical_marginals)``.  The input tensors are not modified.
        """
        horizontal_unary = inputs[0]
        vertical_unary = inputs[1]
        horizontal_pairwise = inputs[2]
        vertical_pairwise = inputs[3]

        gamma = self.gamma

        def smooth_max(scores, dim):
            # Soft maximum over ``dim``: gamma * log(sum(exp(scores / gamma))).
            return gamma * torch.logsumexp(scores / gamma, dim=dim, keepdim=False)

        width = horizontal_unary.size(-1)
        height = horizontal_unary.size(-2)

        # Horizontal chains: left-to-right and right-to-left passes, each
        # seeded with the unary of its boundary column.
        horizontal_left_cache = [horizontal_unary[:, :, :, 0]]
        horizontal_right_cache = [horizontal_unary[:, :, :, -1]]
        for i in range(width - 1):
            j = width - i - 2
            horizontal_left_cache.append(
                horizontal_unary[:, :, :, i + 1] +
                smooth_max(horizontal_pairwise[:, :, :, :, i] +
                           horizontal_left_cache[-1].unsqueeze(2), dim=1))
            horizontal_right_cache.append(
                horizontal_unary[:, :, :, j] +
                smooth_max(horizontal_pairwise[:, :, :, :, j] +
                           horizontal_right_cache[-1].unsqueeze(1), dim=2))

        horizontal_left_cache = torch.stack(horizontal_left_cache, dim=-1)
        # The right-to-left pass was collected from the last column backwards.
        horizontal_right_cache = torch.stack(horizontal_right_cache[::-1],
                                             dim=-1)
        # The unary is included in both passes; subtract it once.
        horizontal_marginals = horizontal_left_cache + horizontal_right_cache - horizontal_unary

        # Vertical chains: top-to-bottom and bottom-to-top passes.
        vertical_top_cache = [vertical_unary[:, :, 0, :]]
        vertical_bottom_cache = [vertical_unary[:, :, -1, :]]
        for i in range(height - 1):
            j = height - i - 2
            vertical_top_cache.append(
                vertical_unary[:, :, i + 1, :] +
                smooth_max(vertical_pairwise[:, :, :, i, :] +
                           vertical_top_cache[-1].unsqueeze(2), dim=1))
            vertical_bottom_cache.append(
                vertical_unary[:, :, j, :] +
                smooth_max(vertical_pairwise[:, :, :, j, :] +
                           vertical_bottom_cache[-1].unsqueeze(1), dim=2))

        vertical_top_cache = torch.stack(vertical_top_cache, dim=-2)
        vertical_bottom_cache = torch.stack(vertical_bottom_cache[::-1],
                                            dim=-2)
        vertical_marginals = vertical_top_cache + vertical_bottom_cache - vertical_unary

        # Update and return new unary terms.  Out-of-place on purpose: an
        # in-place ``-=`` would silently modify the caller's tensors (and fail
        # on leaf tensors that require grad).
        average_marginals = (horizontal_marginals + vertical_marginals) / 2.0
        new_horizontal_unary = horizontal_unary - 1.0 / width * (
            horizontal_marginals - average_marginals)
        new_vertical_unary = vertical_unary - 1.0 / height * (
            vertical_marginals - average_marginals)

        return new_horizontal_unary, new_vertical_unary, horizontal_marginals, vertical_marginals
Ejemplo n.º 26
0
    def forward_log(self, s, nrows=None, ncols=None, dummy_row=False):
        """Compute sinkhorn with row/column normalization in the log space.

        ``log(s)`` is alternately row- and column-normalized with
        ``logsumexp`` for ``self.max_iter`` iterations (even iterations
        normalize rows, odd iterations normalize columns of the internally
        "wide" matrix), and ``exp`` of the result is returned.

        Args:
            s: positive scores, either a ``[rows, cols]`` matrix or a
                ``[batch, rows, cols]`` tensor.
            nrows: optional per-batch count of valid rows of ``s`` (as given
                by the caller, i.e. before any internal transposition).
                Defaults to all rows.
            ncols: optional per-batch count of valid columns of ``s``.
                Defaults to all columns.
            dummy_row: if True, the matrix is padded with dummy rows to make
                it square before normalizing, and the padding is removed
                afterwards.

        Returns:
            The normalized matrix, with the same shape as ``s``.

        Raises:
            ValueError: if ``s`` is neither 2-D nor 3-D.
        """
        if len(s.shape) == 2:
            s = s.unsqueeze(0)
            matrix_input = True
        elif len(s.shape) == 3:
            matrix_input = False
        else:
            raise ValueError('input data shape not understood.')

        batch_size = s.shape[0]

        # Work on a matrix with rows <= cols.  When transposing, the caller's
        # row/column counts swap roles too, otherwise they would index the
        # wrong axes below.
        if s.shape[2] >= s.shape[1]:
            transposed = False
        else:
            s = s.transpose(1, 2)
            nrows, ncols = ncols, nrows
            transposed = True

        if nrows is None:
            nrows = [s.shape[1] for _ in range(batch_size)]
        if ncols is None:
            ncols = [s.shape[2] for _ in range(batch_size)]

        # operations are performed on log_s
        s = torch.log(s)

        if dummy_row:
            assert s.shape[2] >= s.shape[1]
            dummy_shape = list(s.shape)
            dummy_shape[1] = s.shape[2] - s.shape[1]
            ori_nrows = nrows
            nrows = ncols
            # Build the padding with s's dtype/device so cat does not mix dtypes.
            s = torch.cat(
                (s, torch.full(dummy_shape, -float('inf'),
                               device=s.device, dtype=s.dtype)),
                dim=1)
            for b in range(batch_size):
                # Dummy rows get a tiny (non-zero) mass on the valid columns;
                # everything outside the valid block is excluded with -inf.
                s[b, ori_nrows[b]:nrows[b], :ncols[b]] = -100
                s[b, nrows[b]:, :] = -float('inf')
                s[b, :, ncols[b]:] = -float('inf')

        if self.batched_operation:
            log_s = s

            for i in range(self.max_iter):
                norm_dim = 2 if i % 2 == 0 else 1
                log_sum = torch.logsumexp(log_s, norm_dim, keepdim=True)
                log_s = log_s - log_sum
                # An all -inf row/column gives -inf - (-inf) = nan; map back.
                log_s[torch.isnan(log_s)] = -float('inf')

            if dummy_row and dummy_shape[1] > 0:
                log_s = log_s[:, :-dummy_shape[1]]
                for b in range(batch_size):
                    log_s[b, ori_nrows[b]:nrows[b], :ncols[b]] = -float('inf')

            # Undo the internal transposition so the output matches the input.
            if transposed:
                log_s = log_s.transpose(1, 2)
            if matrix_input:
                log_s = log_s.squeeze(0)

            return torch.exp(log_s)
        else:
            ret_log_s = torch.full((batch_size, s.shape[1], s.shape[2]),
                                   -float('inf'),
                                   device=s.device,
                                   dtype=s.dtype)

            for b in range(batch_size):
                row_slice = slice(0, nrows[b])
                col_slice = slice(0, ncols[b])
                log_s = s[b, row_slice, col_slice]

                for i in range(self.max_iter):
                    norm_dim = 1 if i % 2 == 0 else 0
                    log_sum = torch.logsumexp(log_s, norm_dim, keepdim=True)
                    log_s = log_s - log_sum

                ret_log_s[b, row_slice, col_slice] = log_s

            if dummy_row:
                if dummy_shape[1] > 0:
                    ret_log_s = ret_log_s[:, :-dummy_shape[1]]
                for b in range(batch_size):
                    ret_log_s[b,
                              ori_nrows[b]:nrows[b], :ncols[b]] = -float('inf')

            if transposed:
                ret_log_s = ret_log_s.transpose(1, 2)
            if matrix_input:
                ret_log_s = ret_log_s.squeeze(0)

            return torch.exp(ret_log_s)
Ejemplo n.º 27
0
 def agg_softmax(x, gamma=3, dim=3):
     """Soft maximum of ``x`` along ``dim``: ``logsumexp(gamma * x) / gamma``.

     For positive ``gamma`` the result approaches the hard maximum as
     ``gamma`` grows.  The reduced dimension is kept with size 1.

     Args:
         x: input tensor that has a dimension ``dim``.
         gamma: sharpness / inverse temperature; must be non-zero.
         dim: dimension to aggregate over.  Defaults to 3, the axis that was
             previously hard-coded, so existing callers are unaffected.

     Returns:
         Tensor shaped like ``x`` except that ``dim`` has size 1.
     """
     res = 1.0 / gamma * \
         torch.logsumexp(gamma * x, dim=dim, keepdim=True)
     return res
Ejemplo n.º 28
0
def get_thermo_loss_from_log_weight_log_p_log_q(log_weight,
                                                log_p,
                                                log_q,
                                                partition,
                                                num_particles=1,
                                                integration='left',
                                                mode='covariance'):
    """Compute the thermo loss and the ELBO from per-particle log terms.

    Args:
        log_weight: tensor of shape [batch_size, num_particles]
        log_p: tensor of shape [batch_size, num_particles]
        log_q: tensor of shape [batch_size, num_particles]
        partition: partition of [0, 1];
            tensor of shape [num_partitions + 1] where partition[0] is zero and
            partition[-1] is one;
            see https://en.wikipedia.org/wiki/Partition_of_an_interval
        num_particles: int
        integration: left, right or trapz
        mode: covariance or baselined_reinforce

    Returns:
        loss: scalar that we call .backward() on and step the optimizer.
        elbo: average elbo over data

    Raises:
        ValueError: if ``mode`` or ``integration`` is not one of the supported
            values.  Without this check an unknown ``mode`` failed later with
            an ``UnboundLocalError`` and an unknown ``integration`` silently
            produced an all-zero ``multiplier`` (a zero loss).
    """
    if mode not in ('covariance', 'baselined_reinforce'):
        raise ValueError(
            f"mode must be 'covariance' or 'baselined_reinforce', got {mode!r}")
    if integration not in ('left', 'right', 'trapz'):
        raise ValueError(
            f"integration must be 'left', 'right' or 'trapz', got {integration!r}")

    # [batch_size, num_particles, num_partitions + 1]: log weights scaled by
    # each point of the partition, then normalized over particles (dim=1).
    heated_log_weight = log_weight.unsqueeze(-1) * partition
    heated_normalized_weight = util.exponentiate_and_normalize(
        heated_log_weight, dim=1)
    # Per-partition-point interpolation between log_q and log_p.
    thermo_logp = partition * log_p.unsqueeze(-1) + \
        (1 - partition) * log_q.unsqueeze(-1)

    wf = heated_normalized_weight * log_weight.unsqueeze(-1)
    w_detached = heated_normalized_weight.detach()
    # Bessel-style n / (n - 1) correction; undefined for a single particle.
    if num_particles == 1:
        correction = 1
    else:
        correction = num_particles / (num_particles - 1)

    if mode == 'covariance':
        thing_to_add = correction * torch.sum(
            w_detached * (log_weight.unsqueeze(-1) -
                          torch.sum(wf, dim=1, keepdim=True)).detach() *
            (thermo_logp -
             torch.sum(thermo_logp * w_detached, dim=1, keepdim=True)),
            dim=1)
    else:  # mode == 'baselined_reinforce' (validated above)
        thing_to_add = correction * torch.sum(
            w_detached *
            (log_weight.unsqueeze(-1) -
             torch.sum(wf, dim=1, keepdim=True)).detach() * thermo_logp,
            dim=1)

    # Quadrature weights for integrating over the partition points.
    multiplier = torch.zeros_like(partition)
    if integration == 'trapz':
        multiplier[0] = 0.5 * (partition[1] - partition[0])
        multiplier[1:-1] = 0.5 * (partition[2:] - partition[0:-2])
        multiplier[-1] = 0.5 * (partition[-1] - partition[-2])
    elif integration == 'left':
        multiplier[:-1] = partition[1:] - partition[:-1]
    else:  # integration == 'right' (validated above)
        multiplier[1:] = partition[1:] - partition[:-1]

    loss = -torch.mean(
        torch.sum(multiplier *
                  (thing_to_add +
                   torch.sum(w_detached * log_weight.unsqueeze(-1), dim=1)),
                  dim=1))

    # log of the average importance weight, averaged over the batch.
    log_evidence = torch.logsumexp(log_weight, dim=1) - np.log(num_particles)
    elbo = torch.mean(log_evidence)

    return loss, elbo
Ejemplo n.º 29
0
 def log_prob(self, x):
     """Return the mixture log-density of sample ``x``.

     A singleton axis is inserted just before the event dimensions of ``x``
     so it broadcasts against every mixture component.  The component
     log-densities are added to the mixture logits and reduced with
     log-sum-exp over the last (component) axis.  The original annotations
     give the shapes as [S, B, k] for the component term, [B, k] for the
     logits and [S, B] for the result.
     """
     expanded = x.unsqueeze(-(self._event_ndims + 1))
     component_lp = self._component_distribution.log_prob(expanded)
     mix_logits = self._mixture_distribution.logits
     return torch.logsumexp(component_lp + mix_logits, dim=-1)
    opt = optim.Adam(model.parameters(), lr, weight_decay=args.wd)
    pbar = tqdm(range(args.n_iter))

    for i in pbar:
        X_mb, t_mb = mnist.train.next_batch(m)
        X_mb, t_mb = torch.from_numpy(X_mb).cuda(), torch.from_numpy(t_mb).long().cuda()

        if not args.use_dropout:
            log_p_y = []

            for _ in range(S):
                y_s, D_KL = model.forward(X_mb)
                log_p_y_s = dists.Categorical(logits=y_s).log_prob(t_mb)
                log_p_y.append(log_p_y_s)

            loss = -torch.mean(torch.logsumexp(torch.stack(log_p_y), 0) - math.log(S))
            loss += args.lam*D_KL
        else:
            y = model.forward(X_mb)
            loss = F.cross_entropy(y, t_mb)

        loss.backward()
        nn.utils.clip_grad_value_(model.parameters(), 5)
        opt.step()
        opt.zero_grad()

        if i % args.info_interval == 0:
            val_acc = validate(m)
            pbar.set_description(f'[Loss: {loss.data.item():.3f}; val acc: {val_acc:.3f}]')

# Save model
Ejemplo n.º 31
0
    def forward(self, qk, v, query_len = None, input_mask = None, input_attn_mask = None, **kwargs):
        """LSH (locality-sensitive-hashing) bucketed attention.

        ``qk`` serves as both queries and keys (keys are L2-normalized
        below).  Positions are hashed ``self.n_hashes`` times, sorted by
        (bucket, position), split into equal-sized chunks, and each chunk
        attends to itself plus the previous chunk (wrapping around).  The
        per-hash-round outputs are then combined, weighted by a softmax over
        each round's attention log-normalizer.

        Args:
            qk: shared query/key tensor, unpacked as ``[batch, seqlen, dim]``.
            v: value tensor, gathered with the same sort indices as ``qk``;
                presumably ``[batch, seqlen, dim]`` -- TODO confirm.
            query_len: number of leading positions whose outputs are kept
                (also limits causal masking); defaults to ``seqlen``.
            input_mask: boolean padding mask, right-padded with ``True`` up
                to ``seqlen``; a query/key pair is kept only if both are True.
            input_attn_mask: boolean query-by-key mask, padded with ``True``
                up to ``seqlen x seqlen``.
            **kwargs: ``_reverse`` and ``_depth`` are popped and passed to
                ``self.hash_vectors`` as ``fetch`` / ``key_namespace``.

        Returns:
            ``(out, attn, buckets)``: ``out`` is ``[batch, query_len, dim]``;
            ``attn`` is an empty tensor unless ``self._return_attn`` is set;
            ``buckets`` is what ``self.hash_vectors`` returned.

        Raises:
            AssertionError: if ``seqlen`` is not divisible by
                ``2 * self.bucket_size``.
        """
        batch_size, seqlen, dim, device = *qk.shape, qk.device

        query_len = default(query_len, seqlen)
        is_reverse = kwargs.pop('_reverse', False)
        depth = kwargs.pop('_depth', None)

        assert seqlen % (self.bucket_size * 2) == 0, f'Sequence length ({seqlen}) needs to be divisible by target bucket size  x 2 - {self.bucket_size * 2}'

        n_buckets = seqlen // self.bucket_size
        buckets = self.hash_vectors(n_buckets, qk, key_namespace=depth, fetch=is_reverse, set_cache=self.training)

        # We use the same vector as both a query and a key.
        assert int(buckets.shape[1]) == self.n_hashes * seqlen

        total_hashes = self.n_hashes

        # ticker enumerates (hash round, position); the sort key orders by
        # bucket first and position second within each bucket.
        ticker = torch.arange(total_hashes * seqlen, device=device).unsqueeze(0).expand_as(buckets)
        buckets_and_t = seqlen * buckets + (ticker % seqlen)
        buckets_and_t = buckets_and_t.detach()

        # Hash-based sort ("s" at the start of variable names means "sorted")
        sbuckets_and_t, sticker = sort_key_val(buckets_and_t, ticker, dim=-1)
        _, undo_sort = sticker.sort(dim=-1)
        del ticker

        sbuckets_and_t = sbuckets_and_t.detach()
        sticker = sticker.detach()
        undo_sort = undo_sort.detach()

        # Original sequence positions in sorted order.
        st = (sticker % seqlen)
        sqk = batched_index_select(qk, st)
        sv = batched_index_select(v, st)

        # Split off a "bin" axis so that attention only occurs within chunks.
        chunk_size = total_hashes * n_buckets
        bq_t = bkv_t = torch.reshape(st, (batch_size, chunk_size, -1))
        bqk = torch.reshape(sqk, (batch_size, chunk_size, -1, dim))
        bv = torch.reshape(sv, (batch_size, chunk_size, -1, dim))

        # Hashing operates on unit-length vectors. Unnormalized query vectors are
        # fine because they effectively provide a learnable temperature for the
        # attention softmax, but normalizing keys is needed so that similarity for
        # the purposes of attention correctly corresponds to hash locality.
        bq = bqk
        bk = F.normalize(bqk, p=2, dim=-1).type_as(bq)

        # Allow each chunk to attend within itself, and also one chunk back. Chunk
        # boundaries might occur in the middle of a sequence of items from the
        # same bucket, so this increases the chances of attending to relevant items.
        # (The shift wraps around: chunk 0 looks back at the last chunk.)
        def look_one_back(x):
            x_extra = torch.cat([x[:, -1:, ...], x[:, :-1, ...]], dim=1)
            return torch.cat([x, x_extra], dim=2)

        bk = look_one_back(bk)
        bv = look_one_back(bv)
        bkv_t = look_one_back(bkv_t)

        # Dot-product attention.
        dots = torch.einsum('bhie,bhje->bhij', bq, bk) * (dim ** -0.5)
        masked_value = max_neg_value(dots)

        # Mask for post qk attention logits of the input sequence
        if input_attn_mask is not None:
            input_attn_mask = F.pad(input_attn_mask, (0, seqlen - input_attn_mask.shape[-1], 0, seqlen - input_attn_mask.shape[-2]), value=True)
            # Flat index (query_pos * seqlen + key_pos) into the padded mask.
            dot_attn_indices = ((bq_t * seqlen)[:, :, :, None] + bkv_t[:, :, None, :])
            input_attn_mask = input_attn_mask.reshape(batch_size, -1)
            dot_attn_indices = dot_attn_indices.reshape(batch_size, -1)
            mask = input_attn_mask.gather(1, dot_attn_indices).reshape_as(dots)
            dots.masked_fill_(~mask, masked_value)
            del mask

        # Input mask for padding in variable-length sequences
        if input_mask is not None:
            input_mask = F.pad(input_mask, (0, seqlen - input_mask.shape[1]), value=True)
            mq = input_mask.gather(1, st).reshape((batch_size, chunk_size, -1))
            mkv = look_one_back(mq)
            mask = mq[:, :, :, None] * mkv[:, :, None, :]
            dots.masked_fill_(~mask, masked_value)
            del mask

        # Causal masking
        if self.causal:
            mask = bq_t[:, :, :, None] < bkv_t[:, :, None, :]
            if seqlen > query_len:
                mask = mask & (bkv_t[:, :, None, :] < query_len)
            dots.masked_fill_(mask, masked_value)
            del mask

        # Mask out attention to self except when no other targets are available.
        self_mask = bq_t[:, :, :, None] == bkv_t[:, :, None, :]
        dots.masked_fill_(self_mask, TOKEN_SELF_ATTN_VALUE)
        del self_mask

        # Mask out attention to other hash buckets.
        if not self._attend_across_buckets:
            bq_buckets = bkv_buckets = torch.reshape(sbuckets_and_t // seqlen, (batch_size, chunk_size, -1))
            bkv_buckets = look_one_back(bkv_buckets)
            bucket_mask = bq_buckets[:, :, :, None] != bkv_buckets[:, :, None, :]
            dots.masked_fill_(bucket_mask, masked_value)
            del bucket_mask

        # Don't double-count query-key pairs across multiple rounds of hashing.
        # There are two possible strategies here. (1) The default is to count how
        # many times a query-key pair is repeated, and to lower its log-prob
        # correspondingly at each repetition. (2) When hard_k is set, the code
        # instead masks all but the first occurrence of each query-key pair.
        # NOTE(review): only strategy (1) is implemented in this method; no
        # hard_k option is visible here.
        if not self._allow_duplicate_attention:
            # For every (round, position): the chunk it landed in and the next
            # chunk (which sees it through look_one_back).
            locs1 = undo_sort // bq_t.shape[-1]
            locs2 = (locs1 + 1) % chunk_size
            if not self._attend_across_buckets:
                locs1 = buckets * chunk_size + locs1
                locs2 = buckets * chunk_size + locs2
            locs = torch.cat([
                torch.reshape(locs1, (batch_size, total_hashes, seqlen)),
                torch.reshape(locs2, (batch_size, total_hashes, seqlen)),
            ], 1).permute((0, 2, 1))

            slocs = batched_index_select(locs, st)
            b_locs = torch.reshape(slocs, (batch_size, chunk_size, -1, 2 * total_hashes))

            b_locs1 = b_locs[:, :, :, None, :total_hashes]

            bq_locs = b_locs1.expand(b_locs.shape[:3] + (2, total_hashes))
            bq_locs = torch.reshape(bq_locs, b_locs.shape)
            bkv_locs = look_one_back(b_locs)

            dup_counts = (bq_locs[:, :, :, None, :] == bkv_locs[:, :, None, :, :])
            # for memory considerations, chunk summation of last dimension for counting duplicates
            dup_counts = chunked_sum(dup_counts, chunks=(total_hashes * batch_size))
            dup_counts = dup_counts.detach()
            assert dup_counts.shape == dots.shape
            # Subtracting log(count) divides each duplicated pair's weight by
            # its number of occurrences; 1e-9 avoids log(0).
            dots = dots - torch.log(dup_counts + 1e-9)
            del dup_counts

        # Softmax.
        dots_logsumexp = torch.logsumexp(dots, dim=-1, keepdim=True)
        dots = torch.exp(dots - dots_logsumexp).type_as(dots)
        dropped_dots = self.dropout(dots)

        bo = torch.einsum('buij,buje->buie', dropped_dots, bv)
        so = torch.reshape(bo, (batch_size, -1, dim))
        slogits = torch.reshape(dots_logsumexp, (batch_size, -1,))

        # unsort logits
        o = batched_index_select(so, undo_sort)
        logits = slogits.gather(1, undo_sort)

        o = torch.reshape(o, (batch_size, total_hashes, seqlen, dim))
        logits = torch.reshape(logits, (batch_size, total_hashes, seqlen, 1))

        if query_len != seqlen:
            query_slice = (slice(None), slice(None), slice(0, query_len))
            o, logits = o[query_slice], logits[query_slice]

        # Combine hash rounds: softmax over rounds of each round's log-normalizer.
        probs = torch.exp(logits - torch.logsumexp(logits, dim=1, keepdim=True))
        out = torch.sum(o * probs, dim=1)

        attn = torch.empty(0, device=device)

        # return unsorted attention weights
        if self._return_attn:
            attn_unsort = ((bq_t * seqlen)[:, :, :, None] + bkv_t[:, :, None, :])
            attn_unsort = attn_unsort.view(batch_size * total_hashes, -1).long()
            unsorted_dots = torch.zeros(batch_size * total_hashes, seqlen * seqlen, device=device)
            unsorted_dots.scatter_add_(1, attn_unsort, dots.view_as(attn_unsort))
            del attn_unsort
            unsorted_dots = unsorted_dots.reshape(batch_size, total_hashes, seqlen, seqlen)
            attn = torch.sum(unsorted_dots[:, :, 0:query_len, :] * probs, dim=1)

        # return output, attention matrix, and bucket distribution
        return out, attn, buckets
Ejemplo n.º 32
0
def logsumexp(x, dim=0):
    """Return ``log(sum(exp(x)))`` reduced over ``dim``.

    Thin wrapper around :func:`torch.logsumexp`, which stays finite for
    large-magnitude inputs where a naive ``exp``/``sum``/``log`` would
    overflow.

    Args:
        x: input tensor.
        dim: dimension to reduce.  Defaults to 0, the axis that was
            previously hard-coded, so existing callers are unaffected.
    """
    return torch.logsumexp(x, dim)
Ejemplo n.º 33
0
 def forward(self, input):
     """Prepend one extra channel to ``input`` along ``self.dim``.

     The extra channel is ``-logsumexp(input)`` taken over ``self.dim``
     (kept as a size-1 axis); the original channels follow it unchanged.
     Judging by the original ``non_bg`` naming, the new channel is meant as a
     background score -- TODO confirm.
     """
     axis = self.dim
     non_bg = torch.logsumexp(input, dim=axis, keepdim=True)
     return torch.cat((non_bg.neg(), input), axis)