Example No. 1
    def get_reg(self, **kwargs):
        """
        Get weights regularization (KL(q(w)||p(w)) approximation)
        """

        log_alp = self.clip_alp(self.log_alp)
        alpha = torch.exp(log_alp)
        element_wise_kl = .5 * log_alp \
                   + 1.16145124 * alpha \
                   - 1.50204118 * alpha ** 2 \
                   + 0.58629921 * alpha ** 3

        sum_kl = element_wise_kl.sum(dim=(1, 2, 3))
        beta = torch.sigmoid(self.clip_beta(self.beta_r))

        qz = torch.cat([self.ONE, torch.cumprod(beta, dim=0)]) * torch.cat(
            [1 - beta, self.ONE])
        coef0 = torch.cumsum(qz, dim=0)[:-1]
        coef1 = torch.sum(qz) - coef0
        coef1 = torch.cat([self.ONE.repeat(self.FREEZE_PART),
                           coef1]).repeat_interleave(self.EVERY)

        kl_w = coef1.dot(sum_kl)

        # qz is unchanged since its computation above and is reused here
        log_frac_qz_pz = torch.log(qz / self.pz)
        kl_z = qz.dot(log_frac_qz_pz)

        kl = -(kl_w - kl_z)
        return kl
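The cubic polynomial above matches the approximation of the negative KL term for Gaussian dropout from Kingma et al. (2015), -KL ≈ const + 0.5·log α + c1·α + c2·α² + c3·α³ with c1 = 1.16145124, c2 = -1.50204118, c3 = 0.58629921. A minimal standalone sketch of just that term (the clipping range mentioned in the comment is an assumption, not taken from the snippet):

import torch

def neg_kl_approx(log_alpha: torch.Tensor) -> torch.Tensor:
    # Cubic approximation of -KL(q(w)||p(w)) per weight, up to an additive constant
    # (Kingma et al., 2015). Assumes log_alpha has already been clipped, e.g. to [-8, 0].
    alpha = torch.exp(log_alpha)
    return (0.5 * log_alpha
            + 1.16145124 * alpha
            - 1.50204118 * alpha ** 2
            + 0.58629921 * alpha ** 3)

print(neg_kl_approx(torch.tensor([-4.0, -1.0, 0.0])))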
Example No. 2
def cumprod(x, axis: int = 0, exclusive: bool = False):
    if exclusive:
        x = torch.transpose(x, axis, -1)
        x = torch.cat((torch.ones_like(x[..., -1:]), x[..., :-1]), -1)
        res = torch.cumprod(x, -1)
        return torch.transpose(res, axis, -1)
    return torch.cumprod(x, axis)
Example No. 3
    def gen_sampler(self):
        # Initialize lists for the vertebra and intervertebral-disc annotations
        v_annos, d_annos = [], []
        # Iterate over each key and its corresponding vertebra/disc annotation pair
        for key, (v_anno, d_anno) in self.annotations:
            # Append the corresponding data to each list
            v_annos.append(v_anno[:, -1])
            d_annos.append(d_anno[:, -1])
        # Stack all vertebra and disc annotations
        v_annos = torch.stack(v_annos, dim=0)
        d_annos = torch.stack(d_annos, dim=0)

        # Count the occurrences of each unique label in the annotations
        v_count = torch.unique(v_annos, return_counts=True)[1]
        d_count = torch.unique(d_annos, return_counts=True)[1]

        # Take the cumulative product of the per-element counts along the last dim,
        # keep its last column, and divide 1 by it using true (floating-point) division
        v_weights = torch.true_divide(
            1,
            torch.cumprod(v_count[v_annos], dim=-1)[:, -1])
        d_weights = torch.true_divide(
            1,
            torch.cumprod(d_count[d_annos], dim=-1)[:, -1])

        # Combine the vertebra and disc weights into the overall sample weights
        weights = v_weights * d_weights

        # Build a weighted random sampler from the computed weights and return it
        return WeightedRandomSampler(weights=weights,
                                     num_samples=len(self),
                                     replacement=True)
Example No. 4
    def fb(self, x, mode='batch'):
        if mode == 'batch':
            # TODO(j-pong): multi-target energy
            x = 1.0 - self.minimaxn(x[:, 0])
            bsz, tsz = x.size()
            xs = x.unsqueeze(-1).repeat(1, 1, tsz)
            ret = []
            for x in xs:
                m_f = torch.triu(torch.ones(tsz, tsz)).to(x.device)
                x_f = torch.cumprod(torch.tril(x) + m_f, dim=-2) - m_f
                m_b = torch.tril(torch.ones(tsz, tsz)).to(x.device)
                x_b = torch.cumprod((torch.triu(x) + m_b).flip(dims=[-2]),
                                    dim=-2).flip(dims=[-2]) - m_b
                ret.append(x_b + x_f)
            xs = torch.stack(ret, dim=0).unsqueeze(1) / 2  # B, 1, T, T
        else:
            x = 1.0 - self.minimaxn(x[:, 0:1])
            xs = []
            for i in range(x.size(-1)):
                if i != 0:
                    x_f = torch.cumprod(x[:, :, i:], dim=-1)
                    x_b = torch.cumprod(x[:, :, :i].flip(dims=[-1]),
                                        dim=-1).flip(dims=[-1])
                    xs.append(torch.cat([x_b, x_f], dim=-1))
                else:
                    xs.append(torch.cumprod(x[:, :, :], dim=-1))
            xs = torch.stack(xs, dim=-1)  # B, tnum, T, T

        return xs
Example No. 5
 def cpt_gate(self, semantic_score: T) -> Tuple[T, T]:
     assert semantic_score.size()[0] > 4
     score = semantic_score[1:-1]  # (num_score - 2)
     fwd_score = torch.cat([torch.zeros(score.size(0)), score], dim=0)
     bwd_score = torch.cat([score, torch.zeros(score.size(0))], dim=0)
     fwd_score_hat = torch.stack([
         fwd_score[i:i + score.size()[0]]
         for i in range(score.size()[0] - 1, 0, -1)
     ],
                                 dim=0)
     bwd_score_hat = torch.stack([
         bwd_score[i:i + score.size(0)] for i in range(1,
                                                       score.size()[0])
     ],
                                 dim=0)
     if self.hard:
         fwd_gate = (F.hardtanh(
             (fwd_score_hat - score[None, :]) / self.resolution * 2 + 1) +
                     1) / 2
         bwd_gate = (F.hardtanh(
             (bwd_score_hat - score[None, :]) / self.resolution * 2 + 1) +
                     1) / 2
     else:
         fwd_gate = torch.sigmoid(
             (fwd_score_hat - score[None, :]) / self.resolution * 10 + 5)
         bwd_gate = torch.sigmoid(
             (bwd_score_hat - score[None, :]) / self.resolution * 10 + 5)
     fwd_gate = torch.cumprod(fwd_gate, dim=0)  # seq x seq - 1
     bwd_gate = torch.cumprod(bwd_gate, dim=0)  # seq x seq - 1
     return (fwd_gate, bwd_gate)
Example No. 6
    def fb(x, mode='batch'):
        # TODO(j-pong): multi-target energy
        with torch.no_grad():
            if mode == 'batch':
                x = x[:, 0]
                bsz, tsz = x.size()
                xs = x.unsqueeze(-1).repeat(1, 1, tsz)
                ret = []
                for x in xs:
                    ones = torch.ones(tsz, tsz).to(x.device)

                    m_f = torch.triu(ones, diagonal=1)
                    x_f = torch.cumprod(torch.tril(x) + m_f, dim=-2)
                    x_f = torch.cat([ones[0:1], x_f[:-1]], dim=0) - m_f

                    ret.append(x_f)
                xs = torch.stack(ret, dim=0).unsqueeze(1)  # B, 1, T, T
            else:
                xs = []
                for i in range(x.size(-1)):
                    if i != 0:
                        x_f = torch.cumprod(x[:, :, i:], dim=-1)
                        x_b = torch.cumprod(x[:, :, :i].flip(dims=[-1]), dim=-1).flip(dims=[-1])
                        xs.append(torch.cat([x_b, x_f], dim=-1))
                    else:
                        xs.append(torch.cumprod(x[:, :, :], dim=-1))
                xs = torch.stack(xs, dim=-1)  # B, tnum, T, T

        if torch.isnan(xs.sum()):
            raise ValueError("kernel target value has nan")

        return xs
Example No. 7
 def compute_gate(self, score: T) -> Tuple[T, T]:
     #assert score.size()[0] > 4
     score = score[1:-1]
     fwd_gate = self.compute_prob(self.pad_score(score), score)
     bwd_gate = self.compute_prob(
         self.pad_score(torch.flip(score, dims=(0, ))),
         torch.flip(score, dims=(0, )))
     fwd_gate = torch.cumprod(fwd_gate, dim=0)  # seq x seq - 1
     bwd_gate = torch.cumprod(bwd_gate, dim=0)  # seq x seq - 1
     return (fwd_gate, bwd_gate)
Example No. 8
def raw2outputs_blending(raw_dy, raw_rigid, raw_blend_w, z_vals, rays_d,
                         raw_noise_std):
    act_fn = F.relu

    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat(
        [dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)],
        -1)  # [N_rays, N_samples]
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb_dy = torch.sigmoid(raw_dy[..., :3])  # [N_rays, N_samples, 3]
    rgb_rigid = torch.sigmoid(raw_rigid[..., :3])  # [N_rays, N_samples, 3]

    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw_dy[..., 3].shape) * raw_noise_std

    opacity_dy = act_fn(raw_dy[..., 3] + noise)  #.detach() #* raw_blend_w
    opacity_rigid = act_fn(raw_rigid[..., 3] +
                           noise)  #.detach() #* (1. - raw_blend_w)

    # alpha with blending weights
    alpha_dy = (1. - torch.exp(-opacity_dy * dists)) * raw_blend_w
    alpha_rig = (1. - torch.exp(-opacity_rigid * dists)) * (1. - raw_blend_w)

    Ts = torch.cumprod(
        torch.cat([
            torch.ones((alpha_dy.shape[0], 1)),
            (1. - alpha_dy) * (1. - alpha_rig) + 1e-10
        ], -1), -1)[:, :-1]

    weights_dy = Ts * alpha_dy
    weights_rig = Ts * alpha_rig

    # union map
    rgb_map = torch.sum(weights_dy[..., None] * rgb_dy + \
                        weights_rig[..., None] * rgb_rigid, -2)

    weights_mix = weights_dy + weights_rig
    depth_map = torch.sum(weights_mix * z_vals, -1)

    # compute dynamic depth only
    alpha_dynamic = 1. - torch.exp(-opacity_dy * dists)
    weights_dynamic = alpha_dynamic * torch.cumprod(
        torch.cat([
            torch.ones((alpha_dynamic.shape[0], 1)), 1. - alpha_dynamic + 1e-10
        ], -1), -1)[:, :-1]
    depth_map_dynamic = torch.sum(weights_dynamic * z_vals, -1)
    rgb_map_dy = torch.sum(
        weights_dynamic[..., None] * torch.sigmoid(raw_dy[..., :3]), -2)

    return rgb_map, depth_map, \
           rgb_map_dy, depth_map_dynamic, weights_dynamic
Example No. 9
def torch_ideal_err(sorted_labels, k=10, point=True, gpu=False):
    assert sorted_labels.size(0) >= k

    max_label = torch.max(sorted_labels)

    labels = sorted_labels[0:k]
    satis_pros = (torch.pow(2.0, labels) - 1.0) / torch.pow(2.0, max_label)

    unsatis_pros = torch.ones_like(labels) - satis_pros
    cum_unsatis_pros = torch.cumprod(unsatis_pros, dim=0)

    if gpu:
        ranks = torch.arange(k, dtype=torch.float, device=sorted_labels.device) + 1.0
    else:
        ranks = torch.arange(k, dtype=torch.float) + 1.0
    expt_ranks = 1.0 / ranks

    cascad_unsatis_pros = ranks
    cascad_unsatis_pros[1:k] = cum_unsatis_pros[0:k - 1]

    expt_satis_ranks = expt_ranks * satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions

    if point:  # a specific position
        ideal_err = torch.sum(expt_satis_ranks, dim=0)
        return ideal_err

    else:
        ideal_err_at_ks = torch.cumsum(expt_satis_ranks, dim=0)
        return ideal_err_at_ks
Example No. 10
    def update_allocation_weights(self, sorted_indices, ordered_usage):
        '''
        Update allocation weights (B, N)

        :param (B, R) free gates

        '''

        ones = torch.ones(self.batch_size, device=self.device).view(-1, 1)
        usage_ones = torch.cat((ones, ordered_usage[:, :-1]), dim=1)

        prod = torch.cumprod(usage_ones, dim=1)
        #prod = torch.cat((ones, prod), dim=1)

        allocation_weights_ordered = (1 - ordered_usage) * prod

        # reorder allocation weights - inefficient but simpler version
        #for b in self.batch_size:
        #    self.allocation_weights[b][sorted_indices[b]] = allocation_weights_ordered[b]

        # more efficient version than above - avoid a loop over B dimension
        # N.B. this part is crucial for the convergence of the Copy task
        allocation_weights = self.unorder_tensor(allocation_weights_ordered,
                                                 sorted_indices)

        return allocation_weights
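`unorder_tensor` is not shown in this snippet; one way such a helper can avoid the batch loop is a single `scatter_` with the sort indices, since each value computed in sorted order just needs to go back to its original slot. A hypothetical sketch (the real helper in this class may differ):

import torch

def unorder_tensor(ordered: torch.Tensor, sorted_indices: torch.Tensor) -> torch.Tensor:
    # ordered[b, i] belongs at column sorted_indices[b, i] of the output;
    # a single scatter_ restores the original order without looping over B.
    out = torch.zeros_like(ordered)
    return out.scatter_(dim=1, index=sorted_indices, src=ordered)

# Round-trip check: sorting and then scattering back recovers the original tensor.
usage = torch.rand(2, 5)
sorted_usage, idx = torch.sort(usage, dim=1)
assert torch.allclose(unorder_tensor(sorted_usage, idx), usage)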
Example No. 11
    def _compute_loss_actor(self,
                            imag_beliefs,
                            imag_states,
                            imag_ac_logps=None):
        # reward and value prediction of imagined trajectories
        imag_rewards = bottle(self.reward_model, (imag_beliefs, imag_states))
        imag_values = bottle(self.value_model, (imag_beliefs, imag_states))

        with torch.no_grad():
            if self.args.pcont:
                pcont = bottle(self.pcont_model, (imag_beliefs, imag_states))
            else:
                pcont = self.args.discount * torch.ones_like(imag_rewards)
        pcont = pcont.detach()

        if imag_ac_logps is not None:
            imag_values[
                1:] -= self.args.temp * imag_ac_logps  # add entropy here

        returns = cal_returns(imag_rewards[:-1],
                              imag_values[:-1],
                              imag_values[-1],
                              pcont[:-1],
                              lambda_=self.args.disclam)

        discount = torch.cumprod(
            torch.cat([torch.ones_like(pcont[:1]), pcont[:-2]], 0),
            0).detach()

        actor_loss = -torch.mean(discount * returns)
        return actor_loss
Example No. 12
def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
    r"""Mimick functionality of tf.math.cumprod(..., exclusive=True), as it isn't available in PyTorch.

    Args:
    tensor (torch.Tensor): Tensor whose cumprod (cumulative product, see `torch.cumprod`) along dim=-1
      is to be computed.

    Returns:
    cumprod (torch.Tensor): cumprod of Tensor along dim=-1, mimiciking the functionality of
      tf.math.cumprod(..., exclusive=True) (see `tf.math.cumprod` for details).
    """
    # TESTED
    # Only works for the last dimension (dim=-1)
    dim = -1

    # Compute regular cumprod first (this is equivalent to `tf.math.cumprod(..., exclusive=False)`).
    cumprod = torch.cumprod(tensor, dim)

    # "Roll" the elements along dimension 'dim' by 1 element.
    cumprod = torch.roll(cumprod, 1, dim)

    # Replace the first element by "1" as this is what tf.cumprod(..., exclusive=True) does.
    cumprod[..., 0] = 1.0

    return cumprod
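A quick sanity check of the roll-based trick against the explicit "prepend ones, drop the last element" formulation (valid only along the last dimension, as the comment above notes):

import torch

x = torch.tensor([[2.0, 3.0, 4.0]])
reference = torch.cumprod(
    torch.cat([torch.ones_like(x[..., :1]), x[..., :-1]], dim=-1), dim=-1)
assert torch.equal(cumprod_exclusive(x), reference)
print(cumprod_exclusive(x))  # tensor([[1., 2., 6.]]) -- vs. inclusive [[2., 6., 24.]]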
Example No. 13
def torch_batch_ideal_err(batch_sorted_labels, k=10, gpu=False, point=True):
    assert batch_sorted_labels.size(1) > k

    batch_max = torch.max(batch_sorted_labels, dim=1, keepdim=True).values

    batch_labels = batch_sorted_labels[:, 0:k]
    batch_satis_pros = (torch.pow(2.0, batch_labels) - 1.0) / torch.pow(
        2.0, batch_max)

    batch_unsatis_pros = torch.ones_like(batch_labels) - batch_satis_pros
    batch_cum_unsatis_pros = torch.cumprod(batch_unsatis_pros, dim=1)

    positions = torch.arange(k) + 1.0
    positions = positions.view(1, -1)
    positions = torch.repeat_interleave(positions,
                                        batch_sorted_labels.size(0),
                                        dim=0)

    batch_expt_ranks = 1.0 / positions

    cascad_unsatis_pros = positions
    cascad_unsatis_pros[:, 1:k] = batch_cum_unsatis_pros[:, 0:k - 1]

    expt_satis_ranks = batch_expt_ranks * batch_satis_pros * cascad_unsatis_pros  # w.r.t. all rank positions

    if point:
        batch_errs = torch.sum(expt_satis_ranks, dim=1)
        return batch_errs
    else:
        batch_err_at_ks = torch.cumsum(expt_satis_ranks, dim=1)
        return batch_err_at_ks
Example No. 14
    def __init__(self, betas, model_mean_type, model_var_type, loss_type):
        super().__init__()

        betas = betas.type(torch.float64)
        timesteps = betas.shape[0]
        self.num_timesteps = int(timesteps)

        self.model_mean_type = model_mean_type  # xprev, xstart, eps
        self.model_var_type = model_var_type  # learned, fixedsmall, fixedlarge
        self.loss_type = loss_type  # kl, mse

        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, 0)
        alphas_cumprod_prev = torch.cat(
            (torch.tensor([1], dtype=torch.float64), alphas_cumprod[:-1]), 0
        )
        posterior_variance = betas * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)

        self.register("betas", betas)
        self.register("alphas_cumprod", alphas_cumprod)
        self.register("alphas_cumprod_prev", alphas_cumprod_prev)

        self.register("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register("sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod))
        self.register("log_one_minus_alphas_cumprod", torch.log(1 - alphas_cumprod))
        self.register("sqrt_recip_alphas_cumprod", torch.rsqrt(alphas_cumprod))
        self.register("sqrt_recipm1_alphas_cumprod", torch.sqrt(1 / alphas_cumprod - 1))
        self.register("posterior_variance", posterior_variance)
        self.register("posterior_log_variance_clipped",
                      torch.log(torch.cat((posterior_variance[1].view(1, 1),
                                           posterior_variance[1:].view(-1, 1)), 0)).view(-1))
        self.register("posterior_mean_coef1", (betas * torch.sqrt(alphas_cumprod_prev) / (1 - alphas_cumprod)))
        self.register("posterior_mean_coef2", ((1 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1 - alphas_cumprod)))
Example No. 15
def off_policy_weight(eval_log_p, behavior_log_p, full_trajectory=False, clamp_max=5.0):
    """Compute off-policy weight.

    Parameters
    ----------
    eval_log_p: torch.tensor.
        Evaluation log probabilities.
    behavior_log_p: torch.tensor.
        Behavior log probabilities.
    full_trajectory: bool, optional (default=False).
        Flag that indicates whether the off-policy weight is for a single step or for
        the full trajectory.
    clamp_max: float.
        Value to clamp max.

    Returns
    -------
    weight: torch.Tensor.
        Importance sample weights of the trajectory.
    """
    weight = torch.exp(eval_log_p - behavior_log_p)
    if full_trajectory:
        weight = torch.cumprod(weight, dim=-1)

    return weight.clamp_max(clamp_max)
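A small usage sketch with hand-picked probabilities, showing the per-step ratios versus the cumulative trajectory weight:

import torch

eval_log_p = torch.log(torch.tensor([0.5, 0.4, 0.9]))
behavior_log_p = torch.log(torch.tensor([0.25, 0.8, 0.9]))

print(off_policy_weight(eval_log_p, behavior_log_p))
# tensor([2.0000, 0.5000, 1.0000]) -- per-step ratios pi_eval / pi_behavior
print(off_policy_weight(eval_log_p, behavior_log_p, full_trajectory=True))
# tensor([2.0000, 1.0000, 1.0000]) -- running product along the trajectory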
Example No. 16
    def _allocation(self, usage):

        usage = _EPSILON + (1 - _EPSILON) * usage

        sorted_usage, indices = torch.topk(usage,
                                           memory_size,
                                           largest=False,
                                           sorted=True)

        #this is not quite right, check

        prod_sorted_usage = torch.cumprod(sorted_usage, dim=1)
        exclusive_prod_sorted_usage = torch.ones(1, memory_size)
        exclusive_prod_sorted_usage[0][1:] = prod_sorted_usage[0][:-1]

        ##end

        sorted_allocation = (1 - sorted_usage) * exclusive_prod_sorted_usage

        final_allocation = torch.zeros((1, memory_size))
        #undo the initial permutation
        for i in range(memory_size):
            a_index = indices[0][i]
            final_allocation[0][a_index] = sorted_allocation[0][i]

        return final_allocation
Example No. 17
    def _apply_discount(self, rewards):
        cum_discount = torch.cumprod(
            self.hp.gamma * torch.ones(*rewards[0].size()),
            dim=1) / self.hp.gamma
        discounted_rewards = rewards * cum_discount

        return discounted_rewards
Example No. 18
def flat_to_net(flat_net):
    test_net = SimpleMNISTCNN()
    index = 0
    with torch.no_grad():
        for p in test_net.parameters():
            end_index = torch.cumprod(torch.tensor(p.shape), 0)[-1].item() + index
            # copy_ writes into the parameter in place; rebinding the loop variable would not
            p.copy_(flat_net[index:end_index].reshape(p.shape))
            index = end_index
    return test_net
Example No. 19
    def forward(self, output, dropoutl):
        # output: [seq_len x batch_size x nhidlast]
        latent = self.latent(output)  # h
        latent = self.lockdrop(latent, dropoutl)  # h after variational dropout [seq_len x batch_size x n_experts * ninp]
        logit = self.decoder(latent.view(-1, self.ninp))  # HW [seq_len * batch_size * n_experts x voc_size]

        a = self.reduce(output.view(-1, self.nhidlast)) + 1e-8 # [seq_len * batch_size x n_experts]
        b = torch.sum(a, 1)
        c = torch.cumsum(a, 1)
        d = torch.abs(b[:, None] - c)
        a = a.to('cpu')
        d = d.to('cpu')
        beta = torch.distributions.Beta(a, d)
        sample = beta.rsample()  # [seq_len * batch_size x n_experts]
        sample = sample.to('cuda')
        rem = 1 - sample
        D = torch.diag(torch.ones(self.n_experts - 1, device=device), 1)
        rem = rem @ D
        rem[:, 0] = 1
        remprod = torch.cumprod(rem, 1)
        pis = remprod * sample
        prob = nn.functional.softmax(logit.view(-1, self.ntoken), dim=1).view(-1, self.n_experts, self.ntoken)  # exp(hw) / sum(exp(hw))
        prob = (prob * pis.unsqueeze(2).expand_as(prob)).sum(1)  # weighted sum
        # TODO maybe we can do this with logsoftmax
        return prob
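The `rem @ D` / `cumprod` / `remprod * sample` sequence above is a stick-breaking construction: each expert weight is its Beta sample times the product of (1 - sample) over all previous experts, so the weights stay non-negative and sum to at most one. A minimal sketch of the same computation without the shift matrix (the helper name is illustrative only):

import torch

def stick_breaking_weights(v: torch.Tensor) -> torch.Tensor:
    # pi_k = v_k * prod_{j<k} (1 - v_j), via an exclusive cumprod over 1 - v.
    rem = 1.0 - v
    exclusive = torch.cumprod(
        torch.cat([torch.ones_like(rem[:, :1]), rem[:, :-1]], dim=1), dim=1)
    return v * exclusive

v = torch.tensor([[0.5, 0.5, 0.5]])
print(stick_breaking_weights(v))  # tensor([[0.5000, 0.2500, 0.1250]])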
Example No. 20
    def _composite(self, optic_d, s_samples, rays_d):
        # distances between each samples
        dists = s_samples[..., 1:] - s_samples[
            ..., :-1]  # (chunk_size, N_samples - 1)
        dists = torch.cat(
            [dists, torch.tensor([1e10]).expand(dists[..., :1].shape)],
            -1)  # (chunk_size, N_samples)

        dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

        # retrieve display colors and alphas for each samples by a transfer function
        rgbs, alphas = self._transfer(optic_d, dists)

        weights = alphas * torch.cumprod(torch.cat(
            [torch.ones(
                (alphas.shape[0], 1)), 1.0 - alphas + 1e-10], dim=-1)[:, :-1],
                                         dim=-1)  # (chunk_size, N_samples)
        rgb_map = torch.sum(weights[..., None] * rgbs,
                            dim=-2)  # (chunk_size, 3)
        acc_map = torch.sum(weights, -1)  # (chunk_size)

        if self.config.white_bkgd:
            rgb_map = rgb_map + (1.0 - acc_map[..., None])

        return rgb_map, weights
Example No. 21
def raw2outputs_warp(raw_p, z_vals, rays_d, raw_noise_std=0):

    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat(
        [dists, torch.Tensor([1e10]).expand(dists[..., :1].shape)],
        -1)  # [N_rays, N_samples]

    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw_p[..., :3])  # [N_rays, N_samples, 3]
    noise = 0.

    if raw_noise_std > 0.:
        noise = torch.randn(raw_p[..., 3].shape) * raw_noise_std

    act_fn = F.relu
    opacity = act_fn(raw_p[..., 3] + noise)

    alpha = 1. - torch.exp(-opacity * dists)

    weights = alpha * torch.cumprod(
        torch.cat([torch.ones(
            (alpha.shape[0], 1)), 1. - alpha + 1e-10], -1), -1)[:, :-1]
    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [N_rays, 3]

    depth_map = torch.sum(weights * z_vals, -1)

    return rgb_map, depth_map, weights  #, alpha #alpha#, 1. - probs
Example No. 22
 def _discount_rewards(self, rewards):
     gammas = torch.ones_like(rewards)
     if len(rewards) > 1:
         gammas[1:] = torch.cumprod(torch.tensor(
             self.config['discount_rate']).repeat(len(rewards) - 1),
                                    dim=0)
     return gammas * rewards
Example No. 23
    def _build_mask_and_subsample(
        self, not_dones, valids,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int]:
        t = not_dones.size(1)

        valid_mask = (not_dones > 0) & valids
        valid_mask = valid_mask.float()

        valid_mask_unfolded = self._build_unfolded(
            valid_mask[:, :-1].to(dtype=torch.bool), self.k
        )

        time_subsample = torch.randperm(
            t - 1, device=valid_mask.device, dtype=torch.long
        )[0:self.time_subsample]

        forward_mask = (
            torch.cumprod(valid_mask_unfolded.index_select(2, time_subsample), dim=0)
            .to(dtype=torch.bool)
            .flatten(1, 2)
        )

        max_k = forward_mask.flatten(1).any(-1).nonzero().max().item() + 1

        unroll_subsample = torch.randperm(max_k, dtype=torch.long)[0:self.forward_subsample]

        max_k = unroll_subsample.max().item() + 1

        unroll_subsample = unroll_subsample.to(device=valid_mask.device)
        forward_mask = forward_mask.index_select(0, unroll_subsample)

        return forward_mask, unroll_subsample, time_subsample, max_k
Example No. 24
def make_objective(batch_data, v_fn, tau=1.0, dice_lambda=1.0, use_dice=False, gamma_weighted=True):
    """Make objective for DiCE, Loaded DiCE, or classic surrogate loss"""
    (batch_states, batch_pi_taken, batch_a_taken,
        batch_r, batch_dones, batch_mask, ep_returns) = batch_data
    empty_mask = (1 - batch_mask[:, :-1]).bool()
    batch_pi_taken[empty_mask] = 1.0
    log_pi = torch.log(batch_pi_taken)
    batch_values = v_fn(batch_states).detach()
    advantages = gae(batch_r, batch_values, batch_mask, tau=tau, gamma_weighted=gamma_weighted)
    if use_dice == "old":
        log_pi_cumsum = torch.cumsum(log_pi, 1)
        deps = magic_box(log_pi_cumsum)
        batch_r[:,-1] = batch_r[:,-1] + batch_values[:,-1]*gamma
        if gamma_weighted:
            gamma_weights = torch.cumprod(torch.ones_like(advantages) * gamma, 1) / gamma
            batch_r = batch_r * gamma_weights
        obj = (batch_r * deps).sum(1).mean()
    elif use_dice == "loaded":
        weighted_cumsum = torch.zeros_like(log_pi)
        weighted_cumsum[:,0] = log_pi[:,0]
        for t in range(1, log_pi.size(1)):
            weighted_cumsum[:,t] = dice_lambda * weighted_cumsum[:,t-1] + log_pi[:,t]
        deps_exclusive = weighted_cumsum - log_pi
        full_deps = magic_box(weighted_cumsum) - magic_box(deps_exclusive)
        obj = (advantages * full_deps).sum(1).mean()
    else:
        obj = (advantages * log_pi).sum(1).mean()
    return obj
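This objective relies on a `magic_box` helper that the snippet does not define. In the DiCE paper (Foerster et al., 2018) it is exp(x - ⊥(x)), where ⊥ is stop-gradient: the forward value is exactly 1, but the score-function gradient survives. A minimal sketch:

import torch

def magic_box(x: torch.Tensor) -> torch.Tensor:
    # Evaluates to 1 in the forward pass; its gradient w.r.t. x is exp(0) * dx = dx,
    # so log-probability terms re-enter the backward pass.
    return torch.exp(x - x.detach())

log_pi = torch.tensor([0.3, -0.2], requires_grad=True)
out = magic_box(log_pi.sum())
print(out)          # tensor(1., grad_fn=...)
out.backward()
print(log_pi.grad)  # tensor([1., 1.])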
Example No. 25
    def allocation_weighting(self):
        '''
        Sorts the memory by usages first.
        Then perform calculation depending on the sort order.

        The allocation_weighting of the third least used memory is calculated as follows:
        Find the least used and second least used. Multiply their usages.
        Multiply the product with (1-usage of the third least), return.

        Do not confuse the sort order and the memory's natural location.
        Verify backprop.

        :param usage_vector: u_t, (N), [0,1]
        :return: allocation_weighting: a_t, (N), simplex bound
        '''

        # not the last usage, since we will update usage before this
        sorted, indices = self.last_usage_vector.sort(dim=1)
        cum_prod = torch.cumprod(sorted, 1)
        # notice the index on the product
        cum_prod = torch.cat(
            [Variable(torch.ones(self.bs, 1).cuda()), cum_prod], 1)[:, :-1]
        sorted_inv = 1 - sorted
        allocation_weighting = sorted_inv * cum_prod
        # shuffle back into the natural slot order (gather with the inverse permutation)
        ret = torch.gather(allocation_weighting, 1, indices.argsort(dim=1))
        if debug:
            if (ret != ret).any():
                raise ValueError("NA found in allocation weighting")
        return ret
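A small numeric walk-through of the weighting described in the docstring: the exclusive cumulative product of the sorted usages is the "all less-used slots are already taken" factor, so the least-used slot receives the largest allocation weight. The values below are arbitrary, and the inverse permutation via argsort is what maps the result back to the natural slot order.

import torch

usage = torch.tensor([[0.9, 0.1, 0.5]])             # one batch row, 3 memory slots
sorted_usage, indices = usage.sort(dim=1)           # [[0.1, 0.5, 0.9]]
cum_prod = torch.cumprod(sorted_usage, dim=1)
cum_prod = torch.cat([torch.ones(1, 1), cum_prod], dim=1)[:, :-1]  # [[1.0, 0.1, 0.05]]
alloc_sorted = (1 - sorted_usage) * cum_prod        # [[0.9, 0.05, 0.005]]
alloc = torch.gather(alloc_sorted, 1, indices.argsort(dim=1))
print(alloc)  # tensor([[0.0050, 0.9000, 0.0500]]) -- the least-used slot gets 0.9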
Example No. 26
    def dice_objective(self):
        self_logprobs = torch.stack(self.self_logprobs, dim=1)
        other_logprobs = torch.stack(self.other_logprobs, dim=1)
        values = torch.stack(self.values, dim=1)
        rewards = torch.stack(self.rewards, dim=1)

        # apply discount:
        cum_discount = torch.cumprod(hp.gamma * torch.ones(*rewards.size()),
                                     dim=1) / hp.gamma
        discounted_rewards = rewards * cum_discount
        discounted_values = values * cum_discount

        # stochastics nodes involved in rewards dependencies:
        dependencies = torch.cumsum(self_logprobs + other_logprobs, dim=1)

        # logprob of each stochastic nodes:
        stochastic_nodes = self_logprobs + other_logprobs

        # dice objective:
        dice_objective = torch.mean(
            torch.sum(magic_box(dependencies) * discounted_rewards, dim=1))

        if hp.use_baseline:
            # variance_reduction:
            baseline_term = torch.mean(
                torch.sum(
                    (1 - magic_box(stochastic_nodes)) * discounted_values,
                    dim=1))
            dice_objective = dice_objective + baseline_term

        return -dice_objective  # want to minimize -objective
Example No. 27
    def forward(self, **kwargs):
        h = kwargs['x_path']

        A, h = self.attention_net(h)
        A = torch.transpose(A, 1, 0)

        if 'attention_only' in kwargs.keys():
            if kwargs['attention_only']:
                return A

        A_raw = A
        A = F.softmax(A, dim=1)
        M = torch.mm(A, h)
        logits = self.classifier(M)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        hazards = torch.sigmoid(logits)
        # hazards = F.softmax(logits, dim=1)
        S = torch.cumprod(1 - hazards, dim=1)
        # S = 1 - torch.cumsum(hazards, dim=1)
        results_dict = {}

        if 'return_features' in kwargs.keys():
            if kwargs['return_features']:
                results_dict.update({'features': M})

        return hazards, S, Y_hat, A_raw, results_dict
Example No. 28
def jet_cross_entropy_loss(prediction: Tensor, target_data: Tensor,
                           target_mask: Tensor, gamma: float) -> Tensor:
    batch_size = prediction.shape[0]
    prediction_shape = prediction.shape[1:]

    # Remove missing jets
    target_data = target_data.clamp(0, None)

    # Find the unravelling shape required to flatten the target indices
    ravel_sizes = torch.tensor(prediction_shape).flip(0)
    ravel_sizes = torch.cumprod(ravel_sizes, 0)
    ravel_sizes = ravel_sizes // ravel_sizes[0]
    ravel_sizes = ravel_sizes.flip(0).unsqueeze(0)
    ravel_sizes = ravel_sizes.to(target_data.device)

    # Flatten the target and predicted data to be one dimensional
    ravel_target = (target_data * ravel_sizes).sum(1)
    ravel_prediction = prediction.reshape(batch_size, -1).contiguous()

    log_probability = ravel_prediction.gather(-1,
                                              ravel_target.view(-1,
                                                                1)).squeeze()
    focal_scale = (1 - torch.exp(log_probability))**gamma

    return -log_probability * focal_scale * target_mask
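The flip/cumprod sequence builds the strides used to flatten the per-axis target indices into an index over the flattened prediction, in the spirit of numpy.ravel_multi_index. A small check with equal trailing dimensions, which matches the jets x jets x ... assignment tensors this loss appears to target (the shape below is an arbitrary illustration):

import torch

shape = (5, 5, 5)
ravel_sizes = torch.tensor(shape).flip(0)
ravel_sizes = torch.cumprod(ravel_sizes, 0)
ravel_sizes = (ravel_sizes // ravel_sizes[0]).flip(0)   # tensor([25,  5,  1])

idx = torch.tensor([2, 3, 4])
flat = torch.arange(5 ** 3).reshape(shape)
assert flat[2, 3, 4] == (idx * ravel_sizes).sum()       # both equal 69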
Example No. 29
    def render(self, volume, axis=2):
        padding = 1
        volume = torch.nn.functional.pad(
            volume, (padding, padding, padding, padding, padding, padding))

        density = volume[:, [3]]
        signal = volume[:, :3]

        bs = density.shape[0]
        sample_coords = self.sample_coords.expand(bs, self.sample_size_z,
                                                  self.param.data.cube_len,
                                                  self.param.data.cube_len, 3)

        density = density * self.density_factor
        density = density / self.sample_size_z
        density = torch.nn.functional.grid_sample(density, sample_coords)
        transmission = torch.cumprod(1.0 - density, dim=axis)

        weight = density * transmission
        weight_sum = torch.sum(weight, dim=axis)

        signal = torch.nn.functional.grid_sample(signal, sample_coords)

        rendering = torch.sum(weight * signal, dim=axis)
        rendering = rendering / (weight_sum + 1e-8)

        alpha = 1.0 - torch.prod(1 - density, dim=axis)

        rendering = rendering * alpha

        rendering = torch.cat([rendering, alpha], dim=1)

        return rendering
Example No. 30
def retrace(q_values, rewards, importance_weights=None, discount=0.99, l=0.75):
    """
    Retrace estimate (Munos et al., 2016).

    Args:
        q_values (torch.Tensor): the Q-value estimates at each time step [time_steps+1, batch_size, 1]
        rewards (torch.Tensor): the rewards at each time step [time_steps, batch_size, 1]
        importance_weights (torch.Tensor): the importance weights at each time step [time_steps, batch_size, 1]
        discount (float): the temporal discount factor
        l (float): the lambda weighting factor
    """
    if q_values.shape[0] == 1 or l == -1:
        # degenerate case
        return q_values

    if importance_weights is None:
        # On-policy
        importance_weights = torch.ones_like(q_values)

    deltas = rewards + discount * q_values[1:] - q_values[:-1]
    importance_weights = l * torch.clamp(importance_weights, 0, 1)[:-2]
    importance_weights = torch.cat([torch.ones_like(q_values[:1]), importance_weights], 0)
    discounts = torch.cat([(discount*torch.ones_like(q_values[:1]))**i for i in range(q_values.shape[0])], 0)
    q_estimates = q_values[:1] + torch.sum(discounts[:-1] * torch.cumprod(importance_weights, 0) * deltas, 0, keepdim=True)

    return q_estimates
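A shape-level usage sketch with random tensors, just to show the expected layout: T+1 Q-value estimates and T rewards produce a single corrected estimate per batch element.

import torch

T, B = 4, 2
q_values = torch.randn(T + 1, B, 1)
rewards = torch.randn(T, B, 1)

q_ret = retrace(q_values, rewards, discount=0.99, l=0.75)  # on-policy: weights default to 1
print(q_ret.shape)  # torch.Size([1, 2, 1])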
Example No. 31
 def forward(self, x):
     x0 = self.conv.forward(x.float())
     x = self.pool_mil(x0)
     x = x.squeeze(2).squeeze(2)
     x1 = torch.add(torch.mul(x0.view(x.size(0), 1000, -1), -1), 1)
     cumprod = torch.cumprod(x1, 2)
     out = torch.max(x, torch.add(torch.mul(cumprod[:, :, -1], -1), 1))
     #out = F.softmax(out)
     return out
Example No. 32
 def forward(self, img, att_size=14):
     x0 = self.conv(img)
     x = self.pool_mil(x0)
     x = x.squeeze(2).squeeze(2)
     x = self.l1(x)
     x1 = torch.add(torch.mul(x.view(x.size(0), 1000, -1), -1), 1)
     cumprod = torch.cumprod(x1, 2)
     out = torch.max(x, torch.add(torch.mul(cumprod[:, :, -1], -1), 1))
     return out
Example No. 33
 def forward(ctx, input, dim):
     ctx.dim = dim
     ctx.save_for_backward(input)
     return torch.cumprod(input, dim=ctx.dim)
Example No. 34
    def backward(ctx, grad_output):
        '''
        There are two algorithms to do this. The first one
        is very efficient, but works only when there are no
        zero elements in the input.

        The second one is much more complex, but it doesn't
        assume anything on the input. The main downside is
        that it takes time O(n^2), where n = input.size(self.dim)
        (i.e. the length of the cumulative product). This is in
        contrast to the forward pass and the efficient algorithm,
        which are both O(n).

        The second algorithm is a simple application of the chain
        rule. If x is an n-dimensional vector, and y = cumprod(x),
        and F is the final cost, then

        dF / dx_k = sum_j (dF / dy_j) * (dy_j / dx_k)   (1)

        The term dF / dy_j is just grad_output[j] (assuming again
        everything is one-dimensional).

        The term (dy_j / dx_k) is easily seen to be

        if j >= k
            dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i
        else:
            dy_j / dx_k = 0

        Note that the indicator (j>=k) can be taken out
        by replacing the sum in (1) with a sum from
        j = k to n.

        Thus,
        df / dx_k = sum_{k <= j <= n} grad_output[j] * (dy_j / dx_k)

        with
        dy_j / dx_k = prod_{1 <= i <= j, i != k} x_i     (2)

        Note that this last term is just the cumulative product
        with k omitted. Thus, if x_k (the input) is nonzero, we can
        just express this as

        dy_j / dx_k = (prod_{1 <= i <= j} x_i) / x_k
                    = y_j / x_k

        So therefore,

        df / dx_k = sum_{k <= j <= n} grad_output[j] * y_j / x_k

        so

        grad_input = sum_scan_exclusive(grad_output * output) / input

        If the input contains zero elements, we need to calculate the dy_j / dx_k
        by using the formula (2), called in the code omitted_products.

        The way the code calculates it is simply by noting that

        prod_{1 <= i <= j, i != k} x_i
            = (prod_{1 <= i <= k} x_i) * (prod_{k + 1 <= i <= j} x_i)

        the first term is calculated as prods_until_k, which,
        since it doesn't depend on j, is easy to vectorize.

        The second term (indexed by j) is the cumulative product of
        x_{k+1}, x_{k+2}, ..., x_n, and it's named in the code
        prods_from_k_plus_1, and it's calculated as a cumprod.

        In order to vectorize this properly, we need to add to
        omitted_products the dimensions where k > j, and therefore
        dy_j / dx_k = 0, which is done right after the assert.
        '''

        input, = ctx.saved_variables
        dim_size = input.size(ctx.dim)
        if dim_size == 1:
            return grad_output, None

        #  Simple case with nonzero elements in the input
        if (input != 0).data.all():
            output = torch.cumprod(input, dim=ctx.dim)
            return sum_scan_exclusive(output * grad_output, dim=ctx.dim) / input, None

        positive_dim = ctx.dim if ctx.dim >= 0 else input.dim() + ctx.dim
        dim_padding = (slice(None, None),) * (positive_dim)

        ones_size = list(input.size())
        ones_size[ctx.dim] = 1
        ones = Variable(input.data.new([1]).expand(ones_size))
        grad_input = Variable(grad_output.data.new(input.size()).zero_())
        for k in range(dim_size):
            if k == 0:
                prods_from_k_plus_1 = torch.cumprod(
                    input[dim_padding + (slice(k + 1, None),)],
                    dim=ctx.dim
                )

                omitted_products = torch.cat(
                    (ones, prods_from_k_plus_1),
                    dim=ctx.dim
                )

            elif k == dim_size - 1:
                prods_until_k = torch.prod(
                    input[dim_padding + (slice(None, k),)],
                    dim=ctx.dim,
                    keepdim=True
                )

                omitted_products = prods_until_k

            else:
                prods_until_k = torch.prod(
                    input[dim_padding + (slice(None, k),)],
                    dim=ctx.dim,
                    keepdim=True
                )

                prods_from_k_plus_1 = torch.cumprod(
                    input[dim_padding + (slice(k + 1, None),)],
                    dim=ctx.dim
                )

                omitted_products = prods_until_k.expand_as(
                    prods_from_k_plus_1) * prods_from_k_plus_1

                omitted_products = torch.cat(
                    (prods_until_k, omitted_products), ctx.dim)

            # At this point omitted_products is the same size
            # as input, except on the dimension dim where it's
            # dim_size - k
            assert omitted_products.size(ctx.dim) == dim_size - k

            # should we implement copy_ or _set_item in variable?
            index = tuple(slice(None, None) for _ in range(positive_dim)) + (k,)
            grad_input[index] = torch.sum(
                grad_output[dim_padding + (slice(k, None),)] * omitted_products,
                dim=ctx.dim)

        return grad_input, None
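The derivation in the docstring, including the O(n^2) branch for inputs containing zeros, can be sanity-checked numerically: torch.autograd.gradcheck compares the analytic cumprod gradient against finite differences. The same call would work on the custom Function above via its .apply, whose class name is not shown in this snippet.

import torch
from torch.autograd import gradcheck

# Double precision, and an exact zero to exercise the slow branch of the backward.
x = torch.tensor([0.5, -1.2, 0.0, 2.0], dtype=torch.float64, requires_grad=True)
assert gradcheck(lambda t: torch.cumprod(t, dim=0), (x,))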